
TensorGraph

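The TensorGraph class creates and dispatches tensor operations on a device. After create() is called for a Device, each dispatch() call issues one Operation on destination and source Tensor objects through a Compute interface, and per-dispatch Flags select options such as the data format (Rf32 or Rf16), transposition, wrap mode, and read scale or bias. The Flags and Masks arguments of create() select the flag variants and operations that the graph supports. Additional dispatch() overloads copy data between Texture and Tensor objects. The supported operations range from basic arithmetic and matrix multiplication to convolution, normalization, pooling, and activation functions.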

#include <parallel/TellusimTensorGraph.h>

Constructors

TensorGraph()

Methods

Clear graph.

void clear()

Check whether the graph is created.

bool isCreated() const

Create graph.

bool create(const Device &device, Flags flags = FlagsAll, Masks masks = MasksAll, Async *async = nullptr)

Dispatch Tensor operation.

bool dispatch(Compute &compute, Operation op, Tensor dest, Flags flags = FlagNone) const
bool dispatch(Compute &compute, Operation op, Tensor dest, Tensor src_0, Flags flags = FlagNone) const
bool dispatch(Compute &compute, Operation op, Tensor dest, Tensor src_0, Tensor src_1, Flags flags = FlagNone) const
bool dispatch(Compute &compute, Operation op, Tensor dest, Tensor src_0, Tensor src_1, Tensor src_2, Flags flags = FlagNone) const
Type                      Name                  Description
Compute                   compute               Compute interface.
TensorGraph::Operation    op                    Graph operation.
TensorGraph::Flags        flags                 Operation flags.
Tensor                    dest                  Destination tensor.
Tensor                    src_0, src_1, src_2   Source tensors.

Dispatch Texture to Tensor.

bool dispatch(Compute &compute, Tensor dest, Texture &src, Region region, Slice slice = Slice()) const
bool dispatch(Compute &compute, Tensor dest, Texture &src, Slice slice = Slice()) const

Dispatch Tensor to Texture.

bool dispatch(Compute &compute, Texture &dest, Tensor src, Region region, Slice slice = Slice()) const
bool dispatch(Compute &compute, Texture &dest, Tensor src, Slice slice = Slice()) const

Enums

Operation

Graph operations.

Name            Value   Description
Clear           0       Clear tensor values.
Range           1       Initialize tensor values from 0 to size.
Copy            2       Copy a tensor with the same or a different layout.
Cat             3       Concatenate two tensors.
Transpose       4       Transpose a tensor.
MatMul          5       Matrix multiplication.
Mul             6       Value multiplication.
Mad             7       Value multiplication and addition.
Div             8       Value division.
Add             9       Value addition.
Conv            10      Convolution.
DeConv          11      Deconvolution.
BatchNorm       12      Batch normalization.
BatchMad        13      Batch multiplication and addition.
SoftMin         14      Softmin function.
SoftMax         15      Softmax function.
MaxPool         16      Maximum pooling.
AvgPool         17      Average pooling.
GELU            18      Gaussian error linear unit function.
ReLU            19      Rectified linear unit function.
SiLU            20      Sigmoid linear unit function.
Sigm            21      Sigmoid function.
Tanh            22      Hyperbolic tangent function.
Sin             23      Sine function.
Cos             24      Cosine function.
Exp             25      Exponential function.
NumOperations   26      Number of operations.
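For reference, several of the element-wise and normalization functions named in the Operation table have standard definitions that can be written as plain C++. This is a CPU sketch of the textbook formulas (GELU in its exact erf form, a max-subtracted softmax, and softmin as softmax of the negated input); it is not TensorGraph's kernel code, and the GPU implementation may differ, for example by approximating GELU with tanh. The function names relu, sigm, silu, gelu, softmax, and softmin are illustrative only.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// CPU reference sketches of standard functions named in the Operation table.
// These are textbook definitions, not the TensorGraph GPU kernels.

// ReLU(x) = max(x, 0)
inline float relu(float x) { return std::max(x, 0.0f); }

// Logistic sigmoid: 1 / (1 + e^-x)
inline float sigm(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// SiLU(x) = x * sigmoid(x)
inline float silu(float x) { return x * sigm(x); }

// Exact GELU: x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
inline float gelu(float x) {
    return 0.5f * x * (1.0f + std::erf(x / std::sqrt(2.0f)));
}

// Numerically stable softmax: subtract the maximum before exponentiating.
inline std::vector<float> softmax(const std::vector<float> &x) {
    assert(!x.empty());
    const float max_value = *std::max_element(x.begin(), x.end());
    std::vector<float> y(x.size());
    float sum = 0.0f;
    for (std::size_t i = 0; i < x.size(); i++) {
        y[i] = std::exp(x[i] - max_value);
        sum += y[i];
    }
    for (float &v : y) v /= sum;
    return y;
}

// Softmin is softmax of the negated input: smaller values get larger weights.
inline std::vector<float> softmin(const std::vector<float> &x) {
    std::vector<float> negated(x.size());
    std::transform(x.begin(), x.end(), negated.begin(), [](float v) { return -v; });
    return softmax(negated);
}
```

Reference functions like these can be handy for spot-checking a few values of a destination tensor after it has been read back to the CPU.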

Flags

Graph flags.

Name            Value
FlagNone        0
FlagSizeQuery   (1 << 0)
FlagFormatRf32  (1 << 1)
FlagFormatRf16  (1 << 2)
FlagTranspose   (1 << 3)
FlagWrapClamp   (1 << 4)
FlagWrapRepeat  (1 << 5)
FlagReadScale   (1 << 6)
FlagReadBias    (1 << 7)
FlagConvert     (1 << 8)
FlagKernel      (1 << 9)
FlagGELU        (1 << 10)
FlagReLU        (1 << 11)
FlagSiLU        (1 << 12)
FlagSigm        (1 << 13)
FlagTanh        (1 << 14)
FlagSin         (1 << 15)
FlagCos         (1 << 16)
FlagExp         (1 << 17)
FlagFormat      FlagFormatRf32 | FlagFormatRf16
FlagWrap        FlagWrapClamp | FlagWrapRepeat
FlagRead        FlagReadScale | FlagReadBias
FlagUnit        FlagGELU | FlagReLU | FlagSiLU
FlagMath        FlagSigm | FlagTanh | FlagSin | FlagCos | FlagExp
FlagsAll        FlagFormat | FlagTranspose | FlagWrap | FlagRead | FlagConvert | FlagKernel | FlagUnit | FlagMath
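The composite flags are bitwise ORs of the single-bit flags, so their values can be checked at compile time. The sketch below uses a hypothetical local copy of the enumeration (namespace flags_mirror) instead of the Tellusim header, plus a hypothetical hasFlags() helper. It confirms one consequence of the table: FlagsAll covers bits 1 through 17 (0x3FFFE) and does not include FlagSizeQuery (bit 0).

```cpp
#include <cstdint>

// Hypothetical local mirror of TensorGraph::Flags, copied from the table above.
// It is not the Tellusim header; it only lets the composite values be checked.
namespace flags_mirror {

enum Flags : std::uint32_t {
    FlagNone       = 0,
    FlagSizeQuery  = 1u << 0,
    FlagFormatRf32 = 1u << 1,
    FlagFormatRf16 = 1u << 2,
    FlagTranspose  = 1u << 3,
    FlagWrapClamp  = 1u << 4,
    FlagWrapRepeat = 1u << 5,
    FlagReadScale  = 1u << 6,
    FlagReadBias   = 1u << 7,
    FlagConvert    = 1u << 8,
    FlagKernel     = 1u << 9,
    FlagGELU       = 1u << 10,
    FlagReLU       = 1u << 11,
    FlagSiLU       = 1u << 12,
    FlagSigm       = 1u << 13,
    FlagTanh       = 1u << 14,
    FlagSin        = 1u << 15,
    FlagCos        = 1u << 16,
    FlagExp        = 1u << 17,
    // Composite flags, as listed in the table.
    FlagFormat = FlagFormatRf32 | FlagFormatRf16,
    FlagWrap   = FlagWrapClamp | FlagWrapRepeat,
    FlagRead   = FlagReadScale | FlagReadBias,
    FlagUnit   = FlagGELU | FlagReLU | FlagSiLU,
    FlagMath   = FlagSigm | FlagTanh | FlagSin | FlagCos | FlagExp,
    FlagsAll   = FlagFormat | FlagTranspose | FlagWrap | FlagRead |
                 FlagConvert | FlagKernel | FlagUnit | FlagMath,
};

// Hypothetical helper: true if every bit of `required` is set in `flags`.
constexpr bool hasFlags(std::uint32_t flags, std::uint32_t required) {
    return (flags & required) == required;
}

// FlagsAll sets bits 1..17; FlagSizeQuery (bit 0) is not part of it.
static_assert(FlagsAll == 0x3FFFEu, "FlagsAll covers bits 1..17");
static_assert(!hasFlags(FlagsAll, FlagSizeQuery), "FlagsAll excludes FlagSizeQuery");

} // namespace flags_mirror
```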

Masks

Graph masks.

Name            Value
MaskNone        0
MaskClear       (1 << Clear)
MaskRange       (1 << Range)
MaskCopy        (1 << Copy)
MaskCat         (1 << Cat)
MaskTranspose   (1 << Transpose)
MaskMatMul      (1 << MatMul)
MaskMul         (1 << Mul)
MaskMad         (1 << Mad)
MaskDiv         (1 << Div)
MaskAdd         (1 << Add)
MaskConv        (1 << Conv)
MaskDeConv      (1 << DeConv)
MaskBatchNorm   (1 << BatchNorm)
MaskBatchMad    (1 << BatchMad)
MaskSoftMin     (1 << SoftMin)
MaskSoftMax     (1 << SoftMax)
MaskMaxPool     (1 << MaxPool)
MaskAvgPool     (1 << AvgPool)
MaskGELU        (1 << GELU)
MaskReLU        (1 << ReLU)
MaskSiLU        (1 << SiLU)
MaskSigm        (1 << Sigm)
MaskTanh        (1 << Tanh)
MaskSin         (1 << Sin)
MaskCos         (1 << Cos)
MaskExp         (1 << Exp)
MasksAll        (1 << NumOperations) - 1
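Each mask is one bit per Operation value, so a mask for a subset of operations is an OR of those bits, and MasksAll sets the low NumOperations (26) bits, which is 0x3FFFFFF. The sketch below mirrors the Operation values in a hypothetical masks_mirror namespace; maskOf(), isEnabled(), and DenseLayerMasks are illustrative names, not part of the Tellusim API. Passing a subset such as DenseLayerMasks as the masks argument of create() to limit the graph to those operations is an assumption about intended usage.

```cpp
#include <cstdint>

// Hypothetical local mirror of TensorGraph::Operation (values 0..26, in table
// order) and of the mask scheme Mask<Name> = (1 << <Name>). Not the Tellusim header.
namespace masks_mirror {

enum Operation : std::uint32_t {
    Clear, Range, Copy, Cat, Transpose, MatMul, Mul, Mad, Div, Add,
    Conv, DeConv, BatchNorm, BatchMad, SoftMin, SoftMax, MaxPool, AvgPool,
    GELU, ReLU, SiLU, Sigm, Tanh, Sin, Cos, Exp, NumOperations,
};

// Mask bit for a single operation.
constexpr std::uint32_t maskOf(Operation op) { return 1u << op; }

// MasksAll = (1 << NumOperations) - 1 enables every operation.
constexpr std::uint32_t MasksAll = (1u << NumOperations) - 1u;

// True if the operation's bit is set in the mask.
constexpr bool isEnabled(std::uint32_t masks, Operation op) {
    return (masks & maskOf(op)) != 0;
}

// Hypothetical example: only the operations of a small dense layer.
constexpr std::uint32_t DenseLayerMasks = maskOf(MatMul) | maskOf(Add) | maskOf(ReLU);

} // namespace masks_mirror
```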