summaryrefslogtreecommitdiff
path: root/source/slang/ast-legalize.cpp
diff options
context:
space:
mode:
authorTim Foley <tfoleyNV@users.noreply.github.com>2017-12-18 15:14:59 -0800
committerGitHub <noreply@github.com>2017-12-18 15:14:59 -0800
commit393e25fd2e2b8c5ff82ff4c6b14a9d7152d37a5e (patch)
treea3b0617e7ce5a5bfcf43454893f6c2962b7ec382 /source/slang/ast-legalize.cpp
parent46b68ed41daecfaf1761e299cf040156e0f65ac0 (diff)
Work on getting rewriter + IR playing nice together. (#314)
* Work on getting rewriter + IR playing nice together. There are a few different changes here, with the goal of improving the interaction between the "rewriter" code generation approach and the new IR and type legalization code. The main changes are: - Add a new pass that occurs before the AST legalization pass, which walks the (used) AST declarations and tries to discover (1) which declarations need to be specialized/lowered via the IR, and (2) which declarations need to be included in the resulting AST module. - AST-based legalization now uses the generated list when in "rewriter" mode, so that we should be working around issues that users were seeing with types not getting emitted. - TODO: we still need an equivalent fixup in the case of non-"rewriter" emit, so this may still be a problem for `.slang` files. - IR type legalization now precedes AST legalization, so that we can record information on how any IR global values got legalized (e.g., if they got split). Then AST legalization includes logic to reconstruct suitable tuple expressions to reference a split global. - When emitting using IR + AST, we walk all of the declarations that we decided belonged to the IR, but which were subsequently referenced in the AST, to make sure they get output (this would include `struct` types that are declared in a file compiled via IR, but never used in IR-based code). The rewriter+IR use case still doesn't *quite* work, but the logic for walking the AST in a pre-pass ends up being needed/useful to fix some pure rewriter bugs, so I'm getting this checked in sooner rather than later. * Fixup: walk arguments to generic declaration reference The gotcha here is that the code for walking the AST would walk a line of code like: SomeType a; and know to traverse the declaration of `SomeType`, but if it saw a line of code like: ParameterBlock<SomeType> b; it would traverse the declaration of `ParameterBlock`, but fail to visit that of `SomeType`.
Diffstat (limited to 'source/slang/ast-legalize.cpp')
-rw-r--r--source/slang/ast-legalize.cpp644
1 files changed, 629 insertions, 15 deletions
diff --git a/source/slang/ast-legalize.cpp b/source/slang/ast-legalize.cpp
index 144277301..2f7814055 100644
--- a/source/slang/ast-legalize.cpp
+++ b/source/slang/ast-legalize.cpp
@@ -4,6 +4,7 @@
#include "emit.h"
#include "ir-insts.h"
#include "legalize-types.h"
+#include "mangle.h"
#include "type-layout.h"
#include "visitor.h"
@@ -868,7 +869,7 @@ struct LoweringVisitor
RefPtr<Expr> createUncheckedVarRef(
- char const* name)
+ String const& name)
{
return createUncheckedVarRef(
shared->compileRequest->getNamePool()->getName(name));
@@ -2672,6 +2673,105 @@ struct LoweringVisitor
return translateDeclRefImpl(DeclRef<Decl>(decl, nullptr));
}
+ LegalExpr translateSimpleLegalValToLegalExpr(IRValue* irVal)
+ {
+ switch (irVal->op)
+ {
+ case kIROp_global_var:
+ {
+ IRGlobalVar* globalVar = (IRGlobalVar*)irVal;
+ String mangledName = globalVar->mangledName;
+ SLANG_ASSERT(mangledName.Length() != 0);
+
+ return LegalExpr(createUncheckedVarRef(mangledName));
+ }
+ break;
+
+ default:
+ SLANG_UNEXPECTED("unhandled opcode");
+ UNREACHABLE_RETURN(LegalExpr());
+ }
+ }
+
+ LegalExpr translateLegalValToLegalExpr(LegalVal legalVal)
+ {
+ switch (legalVal.flavor)
+ {
+ case LegalVal::Flavor::none:
+ return LegalExpr();
+
+ case LegalVal::Flavor::simple:
+ return translateSimpleLegalValToLegalExpr(legalVal.getSimple());
+ break;
+
+ case LegalVal::Flavor::pair:
+ {
+ auto pairVal = legalVal.getPair();
+ RefPtr<PairPseudoExpr> pairExpr = new PairPseudoExpr();
+ pairExpr->pairInfo = pairVal->pairInfo;
+ pairExpr->ordinary = translateLegalValToLegalExpr(pairVal->ordinaryVal);
+ pairExpr->special = translateLegalValToLegalExpr(pairVal->specialVal);
+ return LegalExpr(pairExpr);
+ }
+ break;
+
+ case LegalVal::Flavor::tuple:
+ {
+ auto tupleVal = legalVal.getTuple();
+ RefPtr<TuplePseudoExpr> tupleExpr = new TuplePseudoExpr();
+ for (auto ee : tupleVal->elements)
+ {
+ TuplePseudoExpr::Element element;
+ element.fieldDeclRef = ee.fieldDeclRef;
+ element.expr = translateLegalValToLegalExpr(ee.val);
+ tupleExpr->elements.Add(element);
+ }
+ return LegalExpr(tupleExpr);
+ }
+ break;
+
+ case LegalVal::Flavor::implicitDeref:
+ {
+ auto implicitDerefVal = legalVal.getImplicitDeref();
+ RefPtr<ImplicitDerefPseudoExpr> implicitDerefExpr = new ImplicitDerefPseudoExpr();
+ implicitDerefExpr->valueExpr = translateLegalValToLegalExpr(implicitDerefVal);
+ return LegalExpr(implicitDerefExpr);
+ }
+ break;
+
+ default:
+ SLANG_UNEXPECTED("unhandled flavor");
+ UNREACHABLE_RETURN(LegalExpr());
+ }
+ }
+
+ void maybeLegalizeIRGlobal(
+ DeclRef<Decl> declRef)
+ {
+ // We've been given a decl-ref to a value that was translated via IR,
+ // and we need to determine if it needs custom handling for legalization,
+ // because it was an IR global that got split.
+
+ // TODO: this code is using decls in places it should use decl-refs,
+ // and that likely needs to get cleaned up...
+ auto decl = declRef.getDecl();
+
+ // If we already have an expression registered, then don't bother.
+ if (shared->mapOriginalDeclToExpr.ContainsKey(decl))
+ return;
+
+ String mangledName = getMangledName(declRef);
+ if (mangledName.Length() == 0)
+ return;
+
+ LegalVal legalVal;
+ if (!shared->typeLegalizationContext->mapMangledNameToLegalIRValue.TryGetValue(mangledName, legalVal))
+ return;
+
+ LegalExpr legalExpr = translateLegalValToLegalExpr(legalVal);
+ shared->mapOriginalDeclToExpr.Add(decl, legalExpr);
+ }
+
RefPtr<Decl> translateDeclRefImpl(
DeclRef<Decl> declRef)
{
@@ -2703,24 +2803,25 @@ struct LoweringVisitor
auto parentModule = findModuleForDecl(decl);
if (parentModule && (parentModule != shared->mainModuleDecl))
{
- // Ensure that the IR code for the given declaration
- // gets included in the output IR module, and *also*
- // that we generate a suitable specialization of it
- // if there are any substitutions in effect.
-
- getSpecializedGlobalValueForDeclRef(
- shared->irSpecializationState,
- declRef);
+ // This declaration should already have been lowered to
+ // the IR during the "walk" pass that happened earlier,
+ // and so we won't do it again here.
+ //
+ // Instead, we need to check if the particular
+ // declaration is one that needs to be swapped for
+ // a legalized expression (e.g., because it was an IR
+ // global that got split)
+ //
+ maybeLegalizeIRGlobal(declRef);
// Remember that this declaration is handled via IR,
// rather than being present in the legalized AST.
shared->result.irDecls.Add(declRef.getDecl());
- // We don't actually use the `IRGlobalValue` that the
- // above operation returns, and instead just keep
- // using the original declaration in the legalized
- // AST. The step of mapping that declaration over
- // to reference the IR symbol will happen later.
+ // This method can't actually return a `LegalExpr`,
+ // so for now we just assume that the original
+ // declaration is the right stand-in for the IR
+ // value we want.
return decl;
}
@@ -4671,7 +4772,8 @@ LoweredEntryPoint lowerEntryPoint(
CodeGenTarget target,
ExtensionUsageTracker* extensionUsageTracker,
IRSpecializationState* irSpecializationState,
- TypeLegalizationContext* typeLegalizationContext)
+ TypeLegalizationContext* typeLegalizationContext,
+ List<Decl*> astDecls)
{
SharedLoweringContext sharedContext;
sharedContext.compileRequest = entryPoint->compileRequest;
@@ -4732,10 +4834,21 @@ LoweredEntryPoint lowerEntryPoint(
if (isRewrite)
{
+ for (auto dd : astDecls)
+ {
+ // Skip non-global decls
+ if (!dd->ParentDecl)
+ continue;
+ if (!dynamic_cast<ModuleDecl*>(dd->ParentDecl))
+ continue;
+ visitor.translateDeclRef(dd);
+ }
+#if 0
for (auto dd : translationUnit->SyntaxNode->Members)
{
visitor.translateDeclRef(dd);
}
+#endif
}
else
{
@@ -4747,4 +4860,505 @@ LoweredEntryPoint lowerEntryPoint(
return sharedContext.result;
}
+
+struct FindIRDeclUsedByASTVisitor
+ : ExprVisitor<FindIRDeclUsedByASTVisitor, void>
+ , StmtVisitor<FindIRDeclUsedByASTVisitor, void>
+ , DeclVisitor<FindIRDeclUsedByASTVisitor, void>
+ , ValVisitor<FindIRDeclUsedByASTVisitor, void, void>
+
+{
+ CompileRequest* compileRequest;
+ IRSpecializationState* irSpecializationState;
+ ModuleDecl* mainModuleDecl;
+
+ // Declarations to be processed by the AST lowering pass
+ List<Decl*>* astDecls;
+
+ HashSet<DeclBase*> seenDecls;
+ HashSet<DeclBase*> addedDecls;
+
+ void walkType(Type* type)
+ {
+ if(!type) return;
+
+ TypeVisitor::dispatch(type);
+ }
+
+ void walkVal(Val* val)
+ {
+ if(!val) return;
+
+ ValVisitor::dispatch(val);
+ }
+
+ void walkExpr(Expr* expr)
+ {
+ if(!expr) return;
+
+ ExprVisitor::dispatch(expr);
+ }
+
+ void walkStmt(Stmt* stmt)
+ {
+ if(!stmt) return;
+
+ StmtVisitor::dispatch(stmt);
+ }
+
+ void walkSubst(Substitutions* subst)
+ {
+ if( auto genericSubst = dynamic_cast<GenericSubstitution*>(subst) )
+ {
+ for( auto arg : genericSubst->args )
+ {
+ walkVal(arg);
+ }
+ }
+ // TODO: handle other cases here
+ }
+
+ void walkDeclRef(DeclRef<Decl> const& declRef)
+ {
+ Decl* decl = declRef.getDecl();
+ if (!decl) return;
+
+ // If this is a specialized declaration reference, then any
+ // of the arguments also need to be walked.
+ for(auto subst = declRef.substitutions; subst; subst = subst->outer)
+ {
+ walkSubst(subst);
+ }
+
+ // If any parent of the declaration was in the stdlib, or
+ // is registered as a builtin, then skip it.
+ for (auto pp = decl; pp; pp = pp->ParentDecl)
+ {
+ if (pp->HasModifier<FromStdLibModifier>())
+ return;
+
+ if (pp->HasModifier<BuiltinModifier>())
+ return;
+ }
+
+ // If we are using the IR, and the declaration comes from
+ // an imported module (rather than the "rewrite-mode" module
+ // being translated), then we need to ensure that it gets lowered
+ // to IR instead.
+ if (compileRequest->compileFlags & SLANG_COMPILE_FLAG_USE_IR)
+ {
+ auto parentModule = findModuleForDecl(decl);
+ if (parentModule && (parentModule != mainModuleDecl))
+ {
+ // Ensure that the IR code for the given declaration
+ // gets included in the output IR module, and *also*
+ // that we generate a suitable specialization of it
+ // if there are any substitutions in effect.
+
+ getSpecializedGlobalValueForDeclRef(
+ irSpecializationState,
+ declRef);
+
+ // TODO: we probably need to track this value...
+
+ return;
+ }
+ }
+
+ // If none of the above triggered, then we seemingly have
+ // a declaration from the current module, and we should
+ // add it to our work list so we can walk it too.
+ addDecl(decl);
+ }
+
+ // Vals
+
+ void visitIRProxyVal(IRProxyVal*)
+ {}
+
+ void visitConstantIntVal(ConstantIntVal*)
+ {}
+
+ void visitGenericParamIntVal(GenericParamIntVal* val)
+ {
+ walkDeclRef(val->declRef);
+ }
+
+ void visitWitness(Witness*)
+ {}
+
+ // Types
+
+ void visitOverloadGroupType(OverloadGroupType*)
+ {}
+
+ void visitInitializerListType(InitializerListType*)
+ {}
+
+ void visitErrorType(ErrorType*)
+ {}
+
+ void visitIRBasicBlockType(IRBasicBlockType*)
+ {}
+
+ void visitDeclRefType(DeclRefType* type)
+ {
+ walkDeclRef(type->declRef);
+ }
+
+ void visitGenericDeclRefType(GenericDeclRefType* type)
+ {
+ walkDeclRef(type->declRef);
+ }
+
+ void visitNamedExpressionType(NamedExpressionType* type)
+ {
+ walkDeclRef(type->declRef);
+ }
+
+ void visitFuncType(FuncType* type)
+ {
+ for( auto p : type->paramTypes )
+ {
+ walkType(p);
+ }
+ walkType(type->resultType);
+ }
+
+ void visitTypeType(TypeType* type)
+ {
+ walkType(type->type);
+ }
+
+ void visitGroupSharedType(GroupSharedType* type)
+ {
+ walkType(type->valueType);
+ }
+
+ void visitArrayExpressionType(ArrayExpressionType* type)
+ {
+ walkType(type->baseType);
+ walkVal(type->ArrayLength);
+ }
+
+ // Exprs
+
+ void visitVarExpr(VarExpr* expr)
+ {
+ walkDeclRef(expr->declRef);
+ }
+
+ void visitMemberExpr(MemberExpr* expr)
+ {
+ walkExpr(expr->BaseExpression);
+ walkDeclRef(expr->declRef);
+ }
+
+ void visitStaticMemberExpr(StaticMemberExpr* expr)
+ {
+ walkExpr(expr->BaseExpression);
+ walkDeclRef(expr->declRef);
+ }
+
+ void visitOverloadedExpr(OverloadedExpr* expr)
+ {
+ walkExpr(expr->base);
+
+ // TODO: need to walk the lookup result too
+ }
+
+ void visitConstantExpr(ConstantExpr*)
+ {}
+
+ void visitInitializerListExpr(InitializerListExpr* expr)
+ {
+ for(auto a : expr->args)
+ walkExpr(a);
+ }
+
+ void visitAppExprBase(AppExprBase* expr)
+ {
+ walkExpr(expr->FunctionExpr);
+ for(auto a : expr->Arguments)
+ walkExpr(a);
+ }
+
+ void visitAggTypeCtorExpr(AggTypeCtorExpr* expr)
+ {
+ walkType(expr->base);
+ for(auto a : expr->Arguments)
+ walkExpr(a);
+ }
+
+ void visitSharedTypeExpr(SharedTypeExpr* expr)
+ {
+ walkType(expr->base);
+ }
+
+ void visitAssignExpr(AssignExpr* expr)
+ {
+ walkExpr(expr->left);
+ walkExpr(expr->right);
+ }
+
+ void visitIndexExpr(IndexExpr* expr)
+ {
+ walkExpr(expr->BaseExpression);
+ walkExpr(expr->IndexExpression);
+ }
+
+ void visitSwizzleExpr(SwizzleExpr* expr)
+ {
+ walkExpr(expr->base);
+ }
+
+ void visitDerefExpr(DerefExpr* expr)
+ {
+ walkExpr(expr->base);
+ }
+
+ void visitParenExpr(ParenExpr* expr)
+ {
+ walkExpr(expr->base);
+ }
+
+ void visitThisExpr(ThisExpr*)
+ {}
+
+ // Stmts
+
+ void visitSeqStmt(SeqStmt* stmt)
+ {
+ for( auto s : stmt->stmts )
+ {
+ walkStmt(s);
+ }
+ }
+
+ void visitBlockStmt(BlockStmt* stmt)
+ {
+ walkStmt(stmt->body);
+ }
+
+ void visitUnparsedStmt(UnparsedStmt*)
+ {}
+
+ void visitEmptyStmt(EmptyStmt*)
+ {}
+
+ void visitDiscardStmt(DiscardStmt*)
+ {}
+
+ void visitDeclStmt(DeclStmt* stmt)
+ {
+ addDecl(stmt->decl);
+ }
+
+ void visitIfStmt(IfStmt* stmt)
+ {
+ walkExpr(stmt->Predicate);
+ walkStmt(stmt->PositiveStatement);
+ walkStmt(stmt->NegativeStatement);
+ }
+
+ void visitSwitchStmt(SwitchStmt* stmt)
+ {
+ walkExpr(stmt->condition);
+ walkStmt(stmt->body);
+ }
+
+ void visitCaseStmt(CaseStmt* stmt)
+ {
+ walkExpr(stmt->expr);
+ }
+
+ void visitDefaultStmt(DefaultStmt*)
+ {}
+
+ void visitForStmt(ForStmt* stmt)
+ {
+ walkStmt(stmt->InitialStatement);
+ walkExpr(stmt->SideEffectExpression);
+ walkExpr(stmt->PredicateExpression);
+ walkStmt(stmt->Statement);
+ }
+
+ void visitWhileStmt(WhileStmt* stmt)
+ {
+ walkExpr(stmt->Predicate);
+ walkStmt(stmt->Statement);
+ }
+
+ void visitDoWhileStmt(DoWhileStmt* stmt)
+ {
+ walkExpr(stmt->Predicate);
+ walkStmt(stmt->Statement);
+ }
+
+ void visitCompileTimeForStmt(CompileTimeForStmt* stmt)
+ {
+ addDecl(stmt->varDecl);
+ walkExpr(stmt->rangeBeginExpr);
+ walkExpr(stmt->rangeEndExpr);
+ walkStmt(stmt->body);
+ }
+
+ void visitReturnStmt(ReturnStmt* stmt)
+ {
+ walkExpr(stmt->Expression);
+ }
+
+ void visitExpressionStmt(ExpressionStmt* stmt)
+ {
+ walkExpr(stmt->Expression);
+ }
+
+ void visitJumpStmt(JumpStmt*)
+ {}
+
+ // Decls
+
+ void visitDeclGroup(DeclGroup* declGroup)
+ {
+ for( auto dd : declGroup->decls )
+ {
+ addDecl(dd);
+ }
+ }
+
+ void visitContainerDeclCommon(ContainerDecl* decl)
+ {
+ for( auto mm : decl->Members )
+ {
+ addDecl(mm);
+ }
+ }
+
+ void visitContainerDecl(ContainerDecl* decl)
+ {
+ visitContainerDeclCommon(decl);
+ }
+
+ void visitVarDeclBase(VarDeclBase* decl)
+ {
+ walkType(decl->type);
+ walkExpr(decl->initExpr);
+ }
+
+ void visitAggTypeDeclBase(AggTypeDeclBase* decl)
+ {
+ visitContainerDeclCommon(decl);
+ }
+
+ void visitInheritanceDecl(InheritanceDecl* decl)
+ {
+ walkType(decl->base);
+ }
+
+ void visitTypeDefDecl(TypeDefDecl* decl)
+ {
+ walkType(decl->type);
+ }
+
+ void visitCallableDeclCommon(CallableDecl* decl)
+ {
+ visitContainerDeclCommon(decl);
+ walkType(decl->ReturnType);
+ }
+
+ void visitCallableDecl(CallableDecl* decl)
+ {
+ visitCallableDeclCommon(decl);
+ }
+
+ void visitFunctionDeclBase(FunctionDeclBase* decl)
+ {
+ visitCallableDeclCommon(decl);
+ walkStmt(decl->Body);
+ }
+
+ void visitImportDecl(ImportDecl*)
+ {}
+
+ void visitGenericTypeParamDecl(GenericTypeParamDecl*)
+ {}
+
+ void visitGenericTypeConstraintDecl(GenericTypeConstraintDecl*)
+ {}
+
+ void visitEmptyDecl(EmptyDecl*)
+ {}
+
+ void visitSyntaxDecl(SyntaxDecl*)
+ {}
+
+ //
+
+ void addDecl(DeclBase* decl)
+ {
+ // Has this decl already been added
+ // to the output list?
+ if(addedDecls.Contains(decl))
+ return;
+
+ // Are we in the middel of processing this
+ // decl?
+ //
+ // TODO: this implies a cycle, and we need to
+ // break it!
+ if (seenDecls.Contains(decl))
+ return;
+
+ seenDecls.Add(decl);
+
+ // Recurse on the given decl
+ DeclVisitor::dispatch(decl);
+
+ // Add it to the output list, if needed
+ if (auto dd = dynamic_cast<Decl*>(decl))
+ {
+ (*astDecls).Add(dd);
+ }
+
+ // Mark it as completely processed
+ addedDecls.Add(decl);
+ }
+
+ void flush()
+ {
+ }
+};
+
+
+void findDeclsUsedByASTEntryPoint(
+ EntryPointRequest* entryPoint,
+ CodeGenTarget target,
+ IRSpecializationState* irSpecializationState,
+ List<Decl*>& outASTDecls)
+{
+ auto translationUnit = entryPoint->getTranslationUnit();
+ auto mainModuleDecl = translationUnit->SyntaxNode;
+
+ FindIRDeclUsedByASTVisitor visitor;
+ visitor.compileRequest = entryPoint->compileRequest;
+ visitor.irSpecializationState = irSpecializationState;
+ visitor.mainModuleDecl = mainModuleDecl;
+ visitor.astDecls = &outASTDecls;
+
+ bool isRewrite = isRewriteRequest(translationUnit->sourceLanguage, target);
+
+ if (isRewrite)
+ {
+ visitor.addDecl(mainModuleDecl);
+ }
+ else
+ {
+ visitor.addDecl(entryPoint->decl);
+ }
+
+ visitor.flush();
+}
+
+
+
}