Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jl-generators.cpp changes #447

Closed
wants to merge 13 commits into from
119 changes: 85 additions & 34 deletions deps/ReactantExtra/tblgen/jl-generators.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <optional>
#include <regex>
#include <string>
#include <tuple>

#include "mlir/TableGen/Argument.h"
#include "mlir/TableGen/CodeGenHelpers.h"
Expand Down Expand Up @@ -175,19 +176,17 @@ std::string emitEnum(EnumAttr e) {

const llvm::StringMap<std::string> cppToJuliaTypeMap = {
{"int64_t", "Int"},
{"uint64_t", "Int"},
{"bool", "Bool"},
{"unsigned", "Int"},
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be UInt

{"Type", "IR.Type"},
{"APInt", "Int64"}, //TODO: add support in reactant to AP
{"Attribute", "IR.Attribute"}};

const llvm::StringMap<std::string> standardAttributeToJuliaTypeMap = {
{"I32Attr", "Int32"},
{"I64Attr", "Int64"},
{"DenseI64ArrayAttr", "Vector{Int64}"},
{"F32Attr", "Float32"},
{"F64Attr", "Float64"},
{"BoolAttr", "Bool"},
{"StrAttr", "String"},
{"FunctionType", "IR.Type"},
{"APInt", "Int64"}, // TODO: add support in reactant to AP
{"Attribute", "IR.Attribute"},
{"StringRef", "String"},
{"ArrayAttr", "Vector{Attribute}"},
{"FlatSymbolRefAttr", "IR.FlatSymbol"},
{"ArrayRef<int64_t>", "Vector{Int}"},
};

std::string toPascalCase(std::string s) {
Expand Down Expand Up @@ -216,20 +215,48 @@ std::string removeNamespace(std::string s) {
return s.substr(pos + 2);
}

llvm::StringMap<std::string> structMap;
std::string assemblyFormatToJulia(std::string s) {
auto p = 0;
auto output = std::string();
for (auto [i, c] : llvm::enumerate(s)) {
if (c == '`')
continue;
if (c == '$')
p = i;

if (p != 0 && c == ' ') {
auto name = s.substr(p + 1, i - p - 1);
auto new_name = llvm::formatv("$(s.{})", sanitizeName(name));
output.append(new_name);
p = 0;
continue;
}

if (p == 0 && c != ' ')
output.push_back(c);
}
return output;
}

llvm::StringMap<std::string> structMap;
std::optional<std::string> emitStruct(llvm::Record def, std::string dialect) {
auto assembly = def.getValueAsOptionalString("assemblyFormat");
auto standardStructAssembly =
!assembly || *assembly == "`<` struct(params) `>`";
auto mnemonic = def.getValueAsString("mnemonic").str();
auto StructName = toPascalCase(mnemonic);
auto params = def.getValueAsDag("parameters");
auto predicate = def.getValue("predicate")->getValue()->getAsString();
auto structDef = "struct " + StructName + '\n';
auto mlirAttributeDef = "IR.Attribute(s::" + StructName +
") = parse(IR.Attribute,\"#" + dialect + "." +
mnemonic + '<';
mnemonic;
if (standardStructAssembly)
mlirAttributeDef.push_back('<');
for (auto [arg, name_] :
llvm::zip(params->getArgs(), params->getArgNames())) {
auto name = name_->getAsUnquotedString();
auto sanitizedName = sanitizeName(name);
auto isArray = false;
std::string cppType;
std::optional<std::string> juliaType;
Expand Down Expand Up @@ -269,15 +296,21 @@ std::optional<std::string> emitStruct(llvm::Record def, std::string dialect) {
juliaType = juliaTypeEntry->getValue();
}
structDef +=
'\t' + name + "::" +
'\t' + sanitizedName + "::" +
(isArray ? llvm::formatv("Vector{{{}}", *juliaType) : *juliaType) +
'\n';
mlirAttributeDef += llvm::formatv("{0} = $(s.{0}), ", name);
if (standardStructAssembly)
mlirAttributeDef +=
llvm::formatv("{0} = $(s.{1}), ", name, sanitizedName);
}
structDef += "end";

mlirAttributeDef.resize(mlirAttributeDef.length() - 2);
mlirAttributeDef += ">\")";
if (standardStructAssembly) {
mlirAttributeDef.resize(mlirAttributeDef.length() - 2); // remove ,
mlirAttributeDef += ">";
} else
mlirAttributeDef +=
assemblyFormatToJulia(def.getValueAsString("assemblyFormat").str());
mlirAttributeDef += "\")";

structMap.insert({predicate, StructName});
return structDef + "\n\n" + mlirAttributeDef + "\n\n";
Expand All @@ -302,23 +335,27 @@ bool emitOpTableDefs(const llvm::RecordKeeper &recordKeeper,
llvm::ArrayRef<const llvm::Record *> attrdefs =
recordKeeper.getAllDerivedDefinitionsIfDefined("Attr");

auto enumAttr = recordKeeper.getClass("StrAttr");
recordKeeper.getDef("StrAttr");
auto attribs = std::string();
for (auto a : attrdefs) {
mlir::tblgen::Attribute attr(a);
if (attr.isDerivedAttr())
continue;
auto name = attr.getDefName();
if (attr.isEnumAttr()) {
if (attr.isEnumAttr()) { // Some EnumAttribute don't have an EnumAttrInfo
// bound : see VHLO dialect
EnumAttr enumAttr(a);
attribs += emitEnum(enumAttr) + '\n';
auto s = enumAttr.getCppNamespace() +
"::" + enumAttr.getSpecializedAttrClassName();
attrMap.insert({s.str(), enumAttr});
continue;
}
if (attr.getDef().getType()->getAsString() == "EnumAttr")

if (attr.isSubClassOf("EnumAttr"))
continue;

if (attr.getDef().getValue("attrName")) { // detect "struct" attributes
auto structAttr = emitStruct(attr.getDef(), moduleName);
if (!structAttr)
Expand Down Expand Up @@ -475,7 +512,7 @@ end
bool variadic = namedResult.isVariadic();

if (variadic) {
type = "Tuple{Vararg{" + type + "}}";
type = "Union{Vector{" + type + "}, Tuple{Vararg{" + type + "}}}";
}

if (optional) {
Expand Down Expand Up @@ -532,20 +569,34 @@ end
DialectName, VarName);
}
} else {
auto entry = standardAttributeToJuliaTypeMap.find(
attr.getDefName()); // TODO: use predicate & fuse with structMap
if (entry != standardAttributeToJuliaTypeMap.end()) {
varType = entry->second;
pushedExpression = "Attribute(" + VarName + ")";
attr = optional ? attr.getBaseAttr() : attr;

auto attr_entry = cppToJuliaTypeMap.find(attr.getAttrDefName());
if (attr_entry != cppToJuliaTypeMap.end()) {
varType = attr_entry->getValue();
pushedExpression = VarName;
} else {
auto entry = structMap.find(
attr.getDef().getValue("predicate")->getValue()->getAsString());
if (entry != structMap.end()) {
varType = entry->second;
pushedExpression = "Attribute(" + VarName + ")";
} else {
auto fullCppType = attr.getDef()
.getValue("returnType")
->getValue()
->getAsUnquotedString();
auto cppType = removeNamespace(fullCppType);
cppType.erase(std::remove(cppType.begin(), cppType.end(), ' '),
cppType.end());
auto entry = cppToJuliaTypeMap.find(cppType);
if (entry != cppToJuliaTypeMap.end()) {
varType = entry->getValue();
pushedExpression = VarName;
varType = "Attribute";
} else {
auto entry = structMap.find(
attr.getDef().getValue("predicate")->getValue()->getAsString());
if (entry != structMap.end()) {
varType = entry->getValue();
pushedExpression = "Attribute(" + VarName + ")";
} else {
pushedExpression = VarName;
varType = "Any";
}
}
}
}
Expand Down
Loading