Skip to content

Commit

Permalink
[TOSA] Update tosa.mul's shift as input
Browse files Browse the repository at this point in the history
* In TOSA v1.0, tosa.mul's `shift` is an input operand. This commit updates
  the Torch-to-TOSA conversion to align with that change.
* Update `tosa::createMulOpAndCast()` function to create a const shift
  TOSA tensor based on the provided shift value. This should be the API
  used to create tosa.mul from now on (instead of using
  `rewriter.create<tosa::MulOp>()`).
* Update `tosa::CreateOpAndInfer()` function to call
  `tosa::CreateOpAndInferShape()` function.
* Update LIT tests.

Signed-off-by: Justin Ngo <[email protected]>
Change-Id: I84aeccacbb33eee65e5923725ace86a78f877869
  • Loading branch information
justin-ngo-arm committed Feb 10, 2025
1 parent 460c9f3 commit af2aff6
Show file tree
Hide file tree
Showing 6 changed files with 490 additions and 417 deletions.
66 changes: 20 additions & 46 deletions include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H

#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project

namespace mlir {
namespace tosa {
Expand Down Expand Up @@ -45,6 +46,10 @@ bool isScale32(mlir::quant::UniformQuantizedType output_element_type);
Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
float val);

// Create the const tensor used as the `shift` input of tosa.mul. The shift is
// supplied as an int32_t and, per this declaration's original note, the
// resulting tensor is an int8_t const tensor.
// NOTE(review): handling of shift values outside the int8_t range is not
// visible from this declaration — confirm against the definition.
Value getTosaMulShiftConstTensor(PatternRewriter &rewriter, Operation *op,
int32_t shift);

// Create a zero constant tensor of the desired type and shape.
// The result is wrapped in std::optional.
// NOTE(review): an empty optional presumably signals that no zero tensor
// could be built for `type` — confirm against the definition.
std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
Operation *op, Type type);
Expand All @@ -65,48 +70,17 @@ Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType);

// Creates a TOSA operation of type `TosaOp` and runs shape inference on that
// single op, so shapes can be refined while lowering from the framework to
// TOSA. This overload takes an ImplicitLocOpBuilder and hands the result type
// and all trailing arguments to CreateOpAndInferShape, returning the op it
// produces.
template <typename TosaOp, typename... Args>
TosaOp CreateOpAndInfer(ImplicitLocOpBuilder &builder, Type result_ty,
                        Args &&...args) {
  TosaOp inferredOp =
      CreateOpAndInferShape<TosaOp>(builder, result_ty, args...);
  return inferredOp;
}

// PatternRewriter overload of CreateOpAndInfer: creates a `TosaOp` at `loc`
// with result type `result_ty` and the given `args`, then tries to refine the
// type of the op's first result using the op's own shape inference.
//
// NOTE(review): this listing comes from a diff view with the +/- markers
// stripped. As written, every statement after the unconditional `return op;`
// near the end is unreachable. The two trailing statements (building an
// ImplicitLocOpBuilder and delegating to the builder-based CreateOpAndInfer)
// look like the replacement body, and the statements before them like the
// removed body — confirm against the actual header before relying on either.
template <typename TosaOp, typename... Args>
TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty,
Args &&...args) {
auto op = rewriter.create<TosaOp>(loc, result_ty, args...);

// Ops that do not implement InferShapedTypeOpInterface are returned exactly
// as created, with the caller-supplied result type.
InferShapedTypeOpInterface shapeInterface =
dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
if (!shapeInterface)
return op;

// Ask the op to infer its return shapes; if inference fails, the op is
// returned with its result type left untouched.
SmallVector<ShapedTypeComponents> returnedShapes;
if (shapeInterface
.inferReturnTypeComponents(op.getContext(), op.getLoc(),
op->getOperands(), op->getAttrDictionary(),
op->getPropertiesStorage(),
op->getRegions(), returnedShapes)
.failed())
return op;

// We need to use the element type of the existing result type to generate
// the new result shaped type. This is because rescale can include a cast to
// different bit-width types and does not have a TypeAttr to define the
// target type.
// Only the first result (index 0) and the first inferred shape are used.
auto result = op->getResult(0);
auto predictedShape = returnedShapes[0];
auto currentKnowledge = ValueKnowledge::getKnowledgeFromType(result_ty);

// Compute the knowledge based on the inferred type.
// The element type is taken from `result_ty`; rank and dimension sizes are
// taken from the inferred shape.
auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
inferredKnowledge.dtype = cast<ShapedType>(result_ty).getElementType();
inferredKnowledge.hasRank = predictedShape.hasRank();
if (predictedShape.hasRank()) {
for (auto dim : predictedShape.getDims()) {
inferredKnowledge.sizes.push_back(dim);
}
}

// Compute the new type based on the joined version.
// The joined type is written back onto the op's first result in place.
auto newKnowledge = ValueKnowledge::join(currentKnowledge, inferredKnowledge);
auto new_ty = newKnowledge.getType();
result.setType(new_ty);
return op;
// NOTE(review): unreachable as listed (see the note above the template).
// These statements wrap `rewriter` and `loc` in an ImplicitLocOpBuilder and
// delegate to the builder-based CreateOpAndInfer overload.
ImplicitLocOpBuilder builder(loc, rewriter);
return CreateOpAndInfer<TosaOp>(builder, result_ty, args...);
}

template <typename TosaOp, typename... Args>
Expand Down
Loading

0 comments on commit af2aff6

Please sign in to comment.