Viewing file: MatrixBuilder.h (10.58 KB) -rw-r--r-- Select action/file-type: (+) | (+) | (+) | Code (+) | Session (+) | (+) | SDB (+) | (+) | (+) | (+) | (+) | (+) |
//===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file defines the MatrixBuilder class, which is used as a convenient way // to lower matrix operations to LLVM IR. // //===----------------------------------------------------------------------===//
#ifndef LLVM_IR_MATRIXBUILDER_H #define LLVM_IR_MATRIXBUILDER_H
#include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/Support/Alignment.h"
namespace llvm {
// Forward declarations of LLVM types that MatrixBuilder below refers to by
// pointer or reference (Function *, const Twine &, Module *).
class Function; class Twine; class Module;
class MatrixBuilder { IRBuilderBase &B; Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }
std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS, Value *RHS) { assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) && "One of the operands must be a matrix (embedded in a vector)"); if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) { assert(!isa<ScalableVectorType>(LHS->getType()) && "LHS Assumed to be fixed width"); RHS = B.CreateVectorSplat( cast<VectorType>(LHS->getType())->getElementCount(), RHS, "scalar.splat"); } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) { assert(!isa<ScalableVectorType>(RHS->getType()) && "RHS Assumed to be fixed width"); LHS = B.CreateVectorSplat( cast<VectorType>(RHS->getType())->getElementCount(), LHS, "scalar.splat"); } return {LHS, RHS}; }
public: MatrixBuilder(IRBuilderBase &Builder) : B(Builder) {}
/// Create a column major, strided matrix load. /// \p EltTy - Matrix element type /// \p DataPtr - Start address of the matrix read /// \p Rows - Number of rows in matrix (must be a constant) /// \p Columns - Number of columns in matrix (must be a constant) /// \p Stride - Space between columns CallInst *CreateColumnMajorLoad(Type *EltTy, Value *DataPtr, Align Alignment, Value *Stride, bool IsVolatile, unsigned Rows, unsigned Columns, const Twine &Name = "") { auto *RetType = FixedVectorType::get(EltTy, Rows * Columns);
Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows), B.getInt32(Columns)}; Type *OverloadedTypes[] = {RetType, Stride->getType()};
Function *TheFn = Intrinsic::getDeclaration( getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes);
CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); Attribute AlignAttr = Attribute::getWithAlignment(Call->getContext(), Alignment); Call->addParamAttr(0, AlignAttr); return Call; }
/// Create a column major, strided matrix store. /// \p Matrix - Matrix to store /// \p Ptr - Pointer to write back to /// \p Stride - Space between columns CallInst *CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment, Value *Stride, bool IsVolatile, unsigned Rows, unsigned Columns, const Twine &Name = "") { Value *Ops[] = {Matrix, Ptr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows), B.getInt32(Columns)}; Type *OverloadedTypes[] = {Matrix->getType(), Stride->getType()};
Function *TheFn = Intrinsic::getDeclaration( getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes);
CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); Attribute AlignAttr = Attribute::getWithAlignment(Call->getContext(), Alignment); Call->addParamAttr(1, AlignAttr); return Call; }
/// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows /// rows and \p Columns columns. CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows, unsigned Columns, const Twine &Name = "") { auto *OpType = cast<VectorType>(Matrix->getType()); auto *ReturnType = FixedVectorType::get(OpType->getElementType(), Rows * Columns);
Type *OverloadedTypes[] = {ReturnType}; Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)}; Function *TheFn = Intrinsic::getDeclaration( getModule(), Intrinsic::matrix_transpose, OverloadedTypes);
return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); }
/// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p /// RHS. CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, unsigned LHSColumns, unsigned RHSColumns, const Twine &Name = "") { auto *LHSType = cast<VectorType>(LHS->getType()); auto *RHSType = cast<VectorType>(RHS->getType());
auto *ReturnType = FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns);
Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns), B.getInt32(RHSColumns)}; Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType};
Function *TheFn = Intrinsic::getDeclaration( getModule(), Intrinsic::matrix_multiply, OverloadedTypes); return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); }
/// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p /// ColumnIdx). Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx, Value *ColumnIdx, unsigned NumRows) { return B.CreateInsertElement( Matrix, NewVal, B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get( ColumnIdx->getType(), NumRows)), RowIdx)); }
/// Add matrixes \p LHS and \p RHS. Support both integer and floating point /// matrixes. Value *CreateAdd(Value *LHS, Value *RHS) { assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()); if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) { assert(!isa<ScalableVectorType>(LHS->getType()) && "LHS Assumed to be fixed width"); RHS = B.CreateVectorSplat( cast<VectorType>(LHS->getType())->getElementCount(), RHS, "scalar.splat"); } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) { assert(!isa<ScalableVectorType>(RHS->getType()) && "RHS Assumed to be fixed width"); LHS = B.CreateVectorSplat( cast<VectorType>(RHS->getType())->getElementCount(), LHS, "scalar.splat"); }
return cast<VectorType>(LHS->getType()) ->getElementType() ->isFloatingPointTy() ? B.CreateFAdd(LHS, RHS) : B.CreateAdd(LHS, RHS); }
/// Subtract matrixes \p LHS and \p RHS. Support both integer and floating /// point matrixes. Value *CreateSub(Value *LHS, Value *RHS) { assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()); if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) { assert(!isa<ScalableVectorType>(LHS->getType()) && "LHS Assumed to be fixed width"); RHS = B.CreateVectorSplat( cast<VectorType>(LHS->getType())->getElementCount(), RHS, "scalar.splat"); } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) { assert(!isa<ScalableVectorType>(RHS->getType()) && "RHS Assumed to be fixed width"); LHS = B.CreateVectorSplat( cast<VectorType>(RHS->getType())->getElementCount(), LHS, "scalar.splat"); }
return cast<VectorType>(LHS->getType()) ->getElementType() ->isFloatingPointTy() ? B.CreateFSub(LHS, RHS) : B.CreateSub(LHS, RHS); }
/// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p /// RHS. Value *CreateScalarMultiply(Value *LHS, Value *RHS) { std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS); if (LHS->getType()->getScalarType()->isFloatingPointTy()) return B.CreateFMul(LHS, RHS); return B.CreateMul(LHS, RHS); }
/// Divide matrix \p LHS by scalar \p RHS. If the operands are integers, \p /// IsUnsigned indicates whether UDiv or SDiv should be used. Value *CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned) { assert(LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()); assert(!isa<ScalableVectorType>(LHS->getType()) && "LHS Assumed to be fixed width"); RHS = B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getElementCount(), RHS, "scalar.splat"); return cast<VectorType>(LHS->getType()) ->getElementType() ->isFloatingPointTy() ? B.CreateFDiv(LHS, RHS) : (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS)); }
/// Create an assumption that \p Idx is less than \p NumElements. void CreateIndexAssumption(Value *Idx, unsigned NumElements, Twine const &Name = "") { Value *NumElts = B.getIntN(Idx->getType()->getScalarSizeInBits(), NumElements); auto *Cmp = B.CreateICmpULT(Idx, NumElts); if (isa<ConstantInt>(Cmp)) assert(cast<ConstantInt>(Cmp)->isOne() && "Index must be valid!"); else B.CreateAssumption(Cmp); }
/// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from /// a matrix with \p NumRows embedded in a vector. Value *CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows, Twine const &Name = "") { unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(), ColumnIdx->getType()->getScalarSizeInBits()); Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth); RowIdx = B.CreateZExt(RowIdx, IntTy); ColumnIdx = B.CreateZExt(ColumnIdx, IntTy); Value *NumRowsV = B.getIntN(MaxWidth, NumRows); return B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx); } };
} // end namespace llvm
#endif // LLVM_IR_MATRIXBUILDER_H
|