summaryrefslogtreecommitdiff
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/core/list.h2
-rw-r--r--source/slang/ast-legalize.cpp644
-rw-r--r--source/slang/ast-legalize.h9
-rw-r--r--source/slang/emit.cpp111
-rw-r--r--source/slang/ir-legalize-types.cpp164
-rw-r--r--source/slang/legalize-types.h103
6 files changed, 885 insertions, 148 deletions
diff --git a/source/core/list.h b/source/core/list.h
index b1461a260..5a94d8b83 100644
--- a/source/core/list.h
+++ b/source/core/list.h
@@ -381,7 +381,7 @@ namespace Slang
void Reverse()
{
- for (int i = 0; i < (_count >> 1); i++)
+ for (UInt i = 0; i < (_count >> 1); i++)
{
Swap(buffer, i, _count - i - 1);
}
diff --git a/source/slang/ast-legalize.cpp b/source/slang/ast-legalize.cpp
index 144277301..2f7814055 100644
--- a/source/slang/ast-legalize.cpp
+++ b/source/slang/ast-legalize.cpp
@@ -4,6 +4,7 @@
#include "emit.h"
#include "ir-insts.h"
#include "legalize-types.h"
+#include "mangle.h"
#include "type-layout.h"
#include "visitor.h"
@@ -868,7 +869,7 @@ struct LoweringVisitor
RefPtr<Expr> createUncheckedVarRef(
- char const* name)
+ String const& name)
{
return createUncheckedVarRef(
shared->compileRequest->getNamePool()->getName(name));
@@ -2672,6 +2673,105 @@ struct LoweringVisitor
return translateDeclRefImpl(DeclRef<Decl>(decl, nullptr));
}
+ LegalExpr translateSimpleLegalValToLegalExpr(IRValue* irVal)
+ {
+ switch (irVal->op)
+ {
+ case kIROp_global_var:
+ {
+ IRGlobalVar* globalVar = (IRGlobalVar*)irVal;
+ String mangledName = globalVar->mangledName;
+ SLANG_ASSERT(mangledName.Length() != 0);
+
+ return LegalExpr(createUncheckedVarRef(mangledName));
+ }
+ break;
+
+ default:
+ SLANG_UNEXPECTED("unhandled opcode");
+ UNREACHABLE_RETURN(LegalExpr());
+ }
+ }
+
+ LegalExpr translateLegalValToLegalExpr(LegalVal legalVal)
+ {
+ switch (legalVal.flavor)
+ {
+ case LegalVal::Flavor::none:
+ return LegalExpr();
+
+ case LegalVal::Flavor::simple:
+ return translateSimpleLegalValToLegalExpr(legalVal.getSimple());
+ break;
+
+ case LegalVal::Flavor::pair:
+ {
+ auto pairVal = legalVal.getPair();
+ RefPtr<PairPseudoExpr> pairExpr = new PairPseudoExpr();
+ pairExpr->pairInfo = pairVal->pairInfo;
+ pairExpr->ordinary = translateLegalValToLegalExpr(pairVal->ordinaryVal);
+ pairExpr->special = translateLegalValToLegalExpr(pairVal->specialVal);
+ return LegalExpr(pairExpr);
+ }
+ break;
+
+ case LegalVal::Flavor::tuple:
+ {
+ auto tupleVal = legalVal.getTuple();
+ RefPtr<TuplePseudoExpr> tupleExpr = new TuplePseudoExpr();
+ for (auto ee : tupleVal->elements)
+ {
+ TuplePseudoExpr::Element element;
+ element.fieldDeclRef = ee.fieldDeclRef;
+ element.expr = translateLegalValToLegalExpr(ee.val);
+ tupleExpr->elements.Add(element);
+ }
+ return LegalExpr(tupleExpr);
+ }
+ break;
+
+ case LegalVal::Flavor::implicitDeref:
+ {
+ auto implicitDerefVal = legalVal.getImplicitDeref();
+ RefPtr<ImplicitDerefPseudoExpr> implicitDerefExpr = new ImplicitDerefPseudoExpr();
+ implicitDerefExpr->valueExpr = translateLegalValToLegalExpr(implicitDerefVal);
+ return LegalExpr(implicitDerefExpr);
+ }
+ break;
+
+ default:
+ SLANG_UNEXPECTED("unhandled flavor");
+ UNREACHABLE_RETURN(LegalExpr());
+ }
+ }
+
+ void maybeLegalizeIRGlobal(
+ DeclRef<Decl> declRef)
+ {
+ // We've been given a decl-ref to a value that was translated via IR,
+ // and we need to determine if it needs custom handling for legalization,
+ // because it was an IR global that got split.
+
+ // TODO: this code is using decls in places it should use decl-refs,
+ // and that likely needs to get cleaned up...
+ auto decl = declRef.getDecl();
+
+ // If we already have an expression registered, then don't bother.
+ if (shared->mapOriginalDeclToExpr.ContainsKey(decl))
+ return;
+
+ String mangledName = getMangledName(declRef);
+ if (mangledName.Length() == 0)
+ return;
+
+ LegalVal legalVal;
+ if (!shared->typeLegalizationContext->mapMangledNameToLegalIRValue.TryGetValue(mangledName, legalVal))
+ return;
+
+ LegalExpr legalExpr = translateLegalValToLegalExpr(legalVal);
+ shared->mapOriginalDeclToExpr.Add(decl, legalExpr);
+ }
+
RefPtr<Decl> translateDeclRefImpl(
DeclRef<Decl> declRef)
{
@@ -2703,24 +2803,25 @@ struct LoweringVisitor
auto parentModule = findModuleForDecl(decl);
if (parentModule && (parentModule != shared->mainModuleDecl))
{
- // Ensure that the IR code for the given declaration
- // gets included in the output IR module, and *also*
- // that we generate a suitable specialization of it
- // if there are any substitutions in effect.
-
- getSpecializedGlobalValueForDeclRef(
- shared->irSpecializationState,
- declRef);
+ // This declaration should already have been lowered to
+ // the IR during the "walk" pass that happened earlier,
+ // and so we won't do it again here.
+ //
+ // Instead, we need to check if the particular
+ // declaration is one that needs to be swapped for
+ // a legalized expression (e.g., because it was an IR
+ // global that got split)
+ //
+ maybeLegalizeIRGlobal(declRef);
// Remember that this declaration is handled via IR,
// rather than being present in the legalized AST.
shared->result.irDecls.Add(declRef.getDecl());
- // We don't actually use the `IRGlobalValue` that the
- // above operation returns, and instead just keep
- // using the original declaration in the legalized
- // AST. The step of mapping that declaration over
- // to reference the IR symbol will happen later.
+ // This method can't actually return a `LegalExpr`,
+ // so for now we just assume that the original
+ // declaration is the right stand-in for the IR
+ // value we want.
return decl;
}
@@ -4671,7 +4772,8 @@ LoweredEntryPoint lowerEntryPoint(
CodeGenTarget target,
ExtensionUsageTracker* extensionUsageTracker,
IRSpecializationState* irSpecializationState,
- TypeLegalizationContext* typeLegalizationContext)
+ TypeLegalizationContext* typeLegalizationContext,
+ List<Decl*> astDecls)
{
SharedLoweringContext sharedContext;
sharedContext.compileRequest = entryPoint->compileRequest;
@@ -4732,10 +4834,21 @@ LoweredEntryPoint lowerEntryPoint(
if (isRewrite)
{
+ for (auto dd : astDecls)
+ {
+ // Skip non-global decls
+ if (!dd->ParentDecl)
+ continue;
+ if (!dynamic_cast<ModuleDecl*>(dd->ParentDecl))
+ continue;
+ visitor.translateDeclRef(dd);
+ }
+#if 0
for (auto dd : translationUnit->SyntaxNode->Members)
{
visitor.translateDeclRef(dd);
}
+#endif
}
else
{
@@ -4747,4 +4860,505 @@ LoweredEntryPoint lowerEntryPoint(
return sharedContext.result;
}
+
+struct FindIRDeclUsedByASTVisitor
+ : ExprVisitor<FindIRDeclUsedByASTVisitor, void>
+ , StmtVisitor<FindIRDeclUsedByASTVisitor, void>
+ , DeclVisitor<FindIRDeclUsedByASTVisitor, void>
+ , ValVisitor<FindIRDeclUsedByASTVisitor, void, void>
+
+{
+ CompileRequest* compileRequest;
+ IRSpecializationState* irSpecializationState;
+ ModuleDecl* mainModuleDecl;
+
+ // Declarations to be processed by the AST lowering pass
+ List<Decl*>* astDecls;
+
+ HashSet<DeclBase*> seenDecls;
+ HashSet<DeclBase*> addedDecls;
+
+ void walkType(Type* type)
+ {
+ if(!type) return;
+
+ TypeVisitor::dispatch(type);
+ }
+
+ void walkVal(Val* val)
+ {
+ if(!val) return;
+
+ ValVisitor::dispatch(val);
+ }
+
+ void walkExpr(Expr* expr)
+ {
+ if(!expr) return;
+
+ ExprVisitor::dispatch(expr);
+ }
+
+ void walkStmt(Stmt* stmt)
+ {
+ if(!stmt) return;
+
+ StmtVisitor::dispatch(stmt);
+ }
+
+ void walkSubst(Substitutions* subst)
+ {
+ if( auto genericSubst = dynamic_cast<GenericSubstitution*>(subst) )
+ {
+ for( auto arg : genericSubst->args )
+ {
+ walkVal(arg);
+ }
+ }
+ // TODO: handle other cases here
+ }
+
+ void walkDeclRef(DeclRef<Decl> const& declRef)
+ {
+ Decl* decl = declRef.getDecl();
+ if (!decl) return;
+
+ // If this is a specialized declaration reference, then any
+ // of the arguments also need to be walked.
+ for(auto subst = declRef.substitutions; subst; subst = subst->outer)
+ {
+ walkSubst(subst);
+ }
+
+ // If any parent of the declaration was in the stdlib, or
+ // is registered as a builtin, then skip it.
+ for (auto pp = decl; pp; pp = pp->ParentDecl)
+ {
+ if (pp->HasModifier<FromStdLibModifier>())
+ return;
+
+ if (pp->HasModifier<BuiltinModifier>())
+ return;
+ }
+
+ // If we are using the IR, and the declaration comes from
+ // an imported module (rather than the "rewrite-mode" module
+ // being translated), then we need to ensure that it gets lowered
+ // to IR instead.
+ if (compileRequest->compileFlags & SLANG_COMPILE_FLAG_USE_IR)
+ {
+ auto parentModule = findModuleForDecl(decl);
+ if (parentModule && (parentModule != mainModuleDecl))
+ {
+ // Ensure that the IR code for the given declaration
+ // gets included in the output IR module, and *also*
+ // that we generate a suitable specialization of it
+ // if there are any substitutions in effect.
+
+ getSpecializedGlobalValueForDeclRef(
+ irSpecializationState,
+ declRef);
+
+ // TODO: we probably need to track this value...
+
+ return;
+ }
+ }
+
+ // If none of the above triggered, then we seemingly have
+ // a declaration from the current module, and we should
+ // add it to our work list so we can walk it too.
+ addDecl(decl);
+ }
+
+ // Vals
+
+ void visitIRProxyVal(IRProxyVal*)
+ {}
+
+ void visitConstantIntVal(ConstantIntVal*)
+ {}
+
+ void visitGenericParamIntVal(GenericParamIntVal* val)
+ {
+ walkDeclRef(val->declRef);
+ }
+
+ void visitWitness(Witness*)
+ {}
+
+ // Types
+
+ void visitOverloadGroupType(OverloadGroupType*)
+ {}
+
+ void visitInitializerListType(InitializerListType*)
+ {}
+
+ void visitErrorType(ErrorType*)
+ {}
+
+ void visitIRBasicBlockType(IRBasicBlockType*)
+ {}
+
+ void visitDeclRefType(DeclRefType* type)
+ {
+ walkDeclRef(type->declRef);
+ }
+
+ void visitGenericDeclRefType(GenericDeclRefType* type)
+ {
+ walkDeclRef(type->declRef);
+ }
+
+ void visitNamedExpressionType(NamedExpressionType* type)
+ {
+ walkDeclRef(type->declRef);
+ }
+
+ void visitFuncType(FuncType* type)
+ {
+ for( auto p : type->paramTypes )
+ {
+ walkType(p);
+ }
+ walkType(type->resultType);
+ }
+
+ void visitTypeType(TypeType* type)
+ {
+ walkType(type->type);
+ }
+
+ void visitGroupSharedType(GroupSharedType* type)
+ {
+ walkType(type->valueType);
+ }
+
+ void visitArrayExpressionType(ArrayExpressionType* type)
+ {
+ walkType(type->baseType);
+ walkVal(type->ArrayLength);
+ }
+
+ // Exprs
+
+ void visitVarExpr(VarExpr* expr)
+ {
+ walkDeclRef(expr->declRef);
+ }
+
+ void visitMemberExpr(MemberExpr* expr)
+ {
+ walkExpr(expr->BaseExpression);
+ walkDeclRef(expr->declRef);
+ }
+
+ void visitStaticMemberExpr(StaticMemberExpr* expr)
+ {
+ walkExpr(expr->BaseExpression);
+ walkDeclRef(expr->declRef);
+ }
+
+ void visitOverloadedExpr(OverloadedExpr* expr)
+ {
+ walkExpr(expr->base);
+
+ // TODO: need to walk the lookup result too
+ }
+
+ void visitConstantExpr(ConstantExpr*)
+ {}
+
+ void visitInitializerListExpr(InitializerListExpr* expr)
+ {
+ for(auto a : expr->args)
+ walkExpr(a);
+ }
+
+ void visitAppExprBase(AppExprBase* expr)
+ {
+ walkExpr(expr->FunctionExpr);
+ for(auto a : expr->Arguments)
+ walkExpr(a);
+ }
+
+ void visitAggTypeCtorExpr(AggTypeCtorExpr* expr)
+ {
+ walkType(expr->base);
+ for(auto a : expr->Arguments)
+ walkExpr(a);
+ }
+
+ void visitSharedTypeExpr(SharedTypeExpr* expr)
+ {
+ walkType(expr->base);
+ }
+
+ void visitAssignExpr(AssignExpr* expr)
+ {
+ walkExpr(expr->left);
+ walkExpr(expr->right);
+ }
+
+ void visitIndexExpr(IndexExpr* expr)
+ {
+ walkExpr(expr->BaseExpression);
+ walkExpr(expr->IndexExpression);
+ }
+
+ void visitSwizzleExpr(SwizzleExpr* expr)
+ {
+ walkExpr(expr->base);
+ }
+
+ void visitDerefExpr(DerefExpr* expr)
+ {
+ walkExpr(expr->base);
+ }
+
+ void visitParenExpr(ParenExpr* expr)
+ {
+ walkExpr(expr->base);
+ }
+
+ void visitThisExpr(ThisExpr*)
+ {}
+
+ // Stmts
+
+ void visitSeqStmt(SeqStmt* stmt)
+ {
+ for( auto s : stmt->stmts )
+ {
+ walkStmt(s);
+ }
+ }
+
+ void visitBlockStmt(BlockStmt* stmt)
+ {
+ walkStmt(stmt->body);
+ }
+
+ void visitUnparsedStmt(UnparsedStmt*)
+ {}
+
+ void visitEmptyStmt(EmptyStmt*)
+ {}
+
+ void visitDiscardStmt(DiscardStmt*)
+ {}
+
+ void visitDeclStmt(DeclStmt* stmt)
+ {
+ addDecl(stmt->decl);
+ }
+
+ void visitIfStmt(IfStmt* stmt)
+ {
+ walkExpr(stmt->Predicate);
+ walkStmt(stmt->PositiveStatement);
+ walkStmt(stmt->NegativeStatement);
+ }
+
+ void visitSwitchStmt(SwitchStmt* stmt)
+ {
+ walkExpr(stmt->condition);
+ walkStmt(stmt->body);
+ }
+
+ void visitCaseStmt(CaseStmt* stmt)
+ {
+ walkExpr(stmt->expr);
+ }
+
+ void visitDefaultStmt(DefaultStmt*)
+ {}
+
+ void visitForStmt(ForStmt* stmt)
+ {
+ walkStmt(stmt->InitialStatement);
+ walkExpr(stmt->SideEffectExpression);
+ walkExpr(stmt->PredicateExpression);
+ walkStmt(stmt->Statement);
+ }
+
+ void visitWhileStmt(WhileStmt* stmt)
+ {
+ walkExpr(stmt->Predicate);
+ walkStmt(stmt->Statement);
+ }
+
+ void visitDoWhileStmt(DoWhileStmt* stmt)
+ {
+ walkExpr(stmt->Predicate);
+ walkStmt(stmt->Statement);
+ }
+
+ void visitCompileTimeForStmt(CompileTimeForStmt* stmt)
+ {
+ addDecl(stmt->varDecl);
+ walkExpr(stmt->rangeBeginExpr);
+ walkExpr(stmt->rangeEndExpr);
+ walkStmt(stmt->body);
+ }
+
+ void visitReturnStmt(ReturnStmt* stmt)
+ {
+ walkExpr(stmt->Expression);
+ }
+
+ void visitExpressionStmt(ExpressionStmt* stmt)
+ {
+ walkExpr(stmt->Expression);
+ }
+
+ void visitJumpStmt(JumpStmt*)
+ {}
+
+ // Decls
+
+ void visitDeclGroup(DeclGroup* declGroup)
+ {
+ for( auto dd : declGroup->decls )
+ {
+ addDecl(dd);
+ }
+ }
+
+ void visitContainerDeclCommon(ContainerDecl* decl)
+ {
+ for( auto mm : decl->Members )
+ {
+ addDecl(mm);
+ }
+ }
+
+ void visitContainerDecl(ContainerDecl* decl)
+ {
+ visitContainerDeclCommon(decl);
+ }
+
+ void visitVarDeclBase(VarDeclBase* decl)
+ {
+ walkType(decl->type);
+ walkExpr(decl->initExpr);
+ }
+
+ void visitAggTypeDeclBase(AggTypeDeclBase* decl)
+ {
+ visitContainerDeclCommon(decl);
+ }
+
+ void visitInheritanceDecl(InheritanceDecl* decl)
+ {
+ walkType(decl->base);
+ }
+
+ void visitTypeDefDecl(TypeDefDecl* decl)
+ {
+ walkType(decl->type);
+ }
+
+ void visitCallableDeclCommon(CallableDecl* decl)
+ {
+ visitContainerDeclCommon(decl);
+ walkType(decl->ReturnType);
+ }
+
+ void visitCallableDecl(CallableDecl* decl)
+ {
+ visitCallableDeclCommon(decl);
+ }
+
+ void visitFunctionDeclBase(FunctionDeclBase* decl)
+ {
+ visitCallableDeclCommon(decl);
+ walkStmt(decl->Body);
+ }
+
+ void visitImportDecl(ImportDecl*)
+ {}
+
+ void visitGenericTypeParamDecl(GenericTypeParamDecl*)
+ {}
+
+ void visitGenericTypeConstraintDecl(GenericTypeConstraintDecl*)
+ {}
+
+ void visitEmptyDecl(EmptyDecl*)
+ {}
+
+ void visitSyntaxDecl(SyntaxDecl*)
+ {}
+
+ //
+
+ void addDecl(DeclBase* decl)
+ {
+ // Has this decl already been added
+ // to the output list?
+ if(addedDecls.Contains(decl))
+ return;
+
+ // Are we in the middel of processing this
+ // decl?
+ //
+ // TODO: this implies a cycle, and we need to
+ // break it!
+ if (seenDecls.Contains(decl))
+ return;
+
+ seenDecls.Add(decl);
+
+ // Recurse on the given decl
+ DeclVisitor::dispatch(decl);
+
+ // Add it to the output list, if needed
+ if (auto dd = dynamic_cast<Decl*>(decl))
+ {
+ (*astDecls).Add(dd);
+ }
+
+ // Mark it as completely processed
+ addedDecls.Add(decl);
+ }
+
+ void flush()
+ {
+ }
+};
+
+
+void findDeclsUsedByASTEntryPoint(
+ EntryPointRequest* entryPoint,
+ CodeGenTarget target,
+ IRSpecializationState* irSpecializationState,
+ List<Decl*>& outASTDecls)
+{
+ auto translationUnit = entryPoint->getTranslationUnit();
+ auto mainModuleDecl = translationUnit->SyntaxNode;
+
+ FindIRDeclUsedByASTVisitor visitor;
+ visitor.compileRequest = entryPoint->compileRequest;
+ visitor.irSpecializationState = irSpecializationState;
+ visitor.mainModuleDecl = mainModuleDecl;
+ visitor.astDecls = &outASTDecls;
+
+ bool isRewrite = isRewriteRequest(translationUnit->sourceLanguage, target);
+
+ if (isRewrite)
+ {
+ visitor.addDecl(mainModuleDecl);
+ }
+ else
+ {
+ visitor.addDecl(entryPoint->decl);
+ }
+
+ visitor.flush();
+}
+
+
+
}
diff --git a/source/slang/ast-legalize.h b/source/slang/ast-legalize.h
index 23a150002..ab06d7a21 100644
--- a/source/slang/ast-legalize.h
+++ b/source/slang/ast-legalize.h
@@ -69,6 +69,13 @@ namespace Slang
CodeGenTarget target,
ExtensionUsageTracker* extensionUsageTracker,
IRSpecializationState* irSpecializationState,
- TypeLegalizationContext* typeLegalizationContext);
+ TypeLegalizationContext* typeLegalizationContext,
+ List<Decl*> astDecls);
+
+ void findDeclsUsedByASTEntryPoint(
+ EntryPointRequest* entryPoint,
+ CodeGenTarget target,
+ IRSpecializationState* irSpecializationState,
+ List<Decl*>& outASTDecls);
}
#endif
diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp
index bf7ad0c3a..ce17c8d03 100644
--- a/source/slang/emit.cpp
+++ b/source/slang/emit.cpp
@@ -100,7 +100,7 @@ struct SharedEmitContext
Dictionary<IRValue*, UInt> mapIRValueToID;
Dictionary<Decl*, UInt> mapDeclToID;
- HashSet<Decl*> irDeclsVisited;
+ HashSet<String> irDeclsVisited;
Dictionary<IRBlock*, IRBlock*> irMapContinueTargetToLoopHead;
@@ -2415,7 +2415,8 @@ struct EmitVisitor
}
else
{
- emit(memberExpr->declRef.GetName());
+ EmitDeclRef(memberExpr->declRef);
+// emit(memberExpr->declRef.GetName());
}
if(needClose) Emit(")");
@@ -2453,7 +2454,8 @@ struct EmitVisitor
}
else
{
- emit(memberExpr->declRef.GetName());
+ EmitDeclRef(memberExpr->declRef);
+// emit(memberExpr->declRef.GetName());
}
if(needClose) Emit(")");
@@ -6892,15 +6894,11 @@ emitDeclImpl(decl, nullptr);
EmitContext* ctx,
DeclRef<StructDecl> declRef)
{
- // TODO: Eventually need to deal with the case where
- // we have user-defined generic types.
- //
- auto decl = declRef.getDecl();
-
- if(ctx->shared->irDeclsVisited.Contains(decl))
+ auto mangledName = getMangledName(declRef);
+ if(ctx->shared->irDeclsVisited.Contains(mangledName))
return;
- ctx->shared->irDeclsVisited.Add(decl);
+ ctx->shared->irDeclsVisited.Add(mangledName);
// First emit any types used by fields of this type
for( auto ff : GetFields(declRef) )
@@ -6935,6 +6933,25 @@ emitDeclImpl(decl, nullptr);
Emit("};\n");
}
+ void emitIRUsedDeclRef(
+ EmitContext* ctx,
+ DeclRef<Decl> declRef)
+ {
+ auto decl = declRef.getDecl();
+
+ if(decl->HasModifier<BuiltinTypeModifier>()
+ || decl->HasModifier<MagicTypeModifier>())
+ {
+ return;
+ }
+
+ if( auto structDeclRef = declRef.As<StructDecl>() )
+ {
+ //
+ ensureStructDecl(ctx, structDeclRef);
+ }
+ }
+
// A type is going to be used by the IR, so
// make sure that we have emitted whatever
// it needs.
@@ -6970,19 +6987,7 @@ emitDeclImpl(decl, nullptr);
else if( auto declRefType = type->As<DeclRefType>() )
{
auto declRef = declRefType->declRef;
- auto decl = declRef.getDecl();
-
- if(decl->HasModifier<BuiltinTypeModifier>()
- || decl->HasModifier<MagicTypeModifier>())
- {
- return;
- }
-
- if( auto structDeclRef = declRef.As<StructDecl>() )
- {
- //
- ensureStructDecl(ctx, structDeclRef);
- }
+ emitIRUsedDeclRef(ctx, declRef);
}
else
{}
@@ -7249,13 +7254,21 @@ String emitEntryPoint(
// boilerplate at the start of the ouput for GLSL (e.g., what
// version we require).
+ List<Decl*> astDecls;
+ findDeclsUsedByASTEntryPoint(
+ entryPoint,
+ target,
+ nullptr,
+ astDecls);
+
auto lowered = lowerEntryPoint(
entryPoint,
programLayout,
target,
&sharedContext.extensionUsageTracker,
nullptr,
- &typeLegalizationContext);
+ &typeLegalizationContext,
+ astDecls);
sharedContext.program = lowered.program;
// Note that we emit the main body code of the program *before*
@@ -7287,25 +7300,23 @@ String emitEntryPoint(
typeLegalizationContext.irModule = irModule;
- LoweredEntryPoint lowered;
+ List<Decl*> astDecls;
if(translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING)
{
// We are in case (2b), where the main module is in unchecked
// HLSL/GLSL that we need to "rewrite," and any library code
// is in Slang that will need to be cross-compiled via the IR.
- // Initially, we will apply the AST-to-AST pass to legalize
- // the user's code, much like we would for any other target.
- // Along the way, this pass will discover any IR declarations
- // that we use, and try to emit code for them into our IR module.
+ // We first need to walk the AST part of the code to look
+ // for any places where it references declarations that
+ // are implemented in the IR, so that we can be sure to
+ // generate suitable IR code for them.
- lowered = lowerEntryPoint(
+ findDeclsUsedByASTEntryPoint(
entryPoint,
- programLayout,
target,
- &sharedContext.extensionUsageTracker,
irSpecializationState,
- &typeLegalizationContext);
+ astDecls);
}
else
{
@@ -7358,6 +7369,33 @@ String emitEntryPoint(
fprintf(stderr, "###\n");
#endif
+ LoweredEntryPoint lowered;
+ if(translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING)
+ {
+ // In the (2b) case, once we have legalized the IR code,
+ // we now need to go in and legalize the AST code.
+ // This order is important because when referring to a variable
+ // that is defined in the IR, we need to legalize it first (which
+ // might split it into many decls) before we can legalize an AST
+ // expression that references that decl (which will also need
+ // to get split).
+ //
+ // We don't have to worry about references in the other direction;
+ // we don't allow the user to define something in unchecked AST
+ // code and then use it from the IR shader library.
+
+ lowered = lowerEntryPoint(
+ entryPoint,
+ programLayout,
+ target,
+ &sharedContext.extensionUsageTracker,
+ irSpecializationState,
+ &typeLegalizationContext,
+ astDecls);
+ }
+
+ // When emitting IR-based declarations, we wnat to
+ // track which decls have already been lowered.
sharedContext.irDeclSetForAST = &lowered.irDecls;
// After all of the required optimization and legalization
@@ -7372,6 +7410,13 @@ String emitEntryPoint(
// that we need to output, we'll do it now.
if (translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING)
{
+ // First make sure that we've emitted any types that were declared
+ // in the IR, but then subsequently only used by the AST
+ for( auto decl : lowered.irDecls )
+ {
+ visitor.emitIRUsedDeclRef(&context, makeDeclRef(decl));
+ }
+
visitor.EmitDeclsInContainer(lowered.program);
}
diff --git a/source/slang/ir-legalize-types.cpp b/source/slang/ir-legalize-types.cpp
index fedec4f87..5db3b4d33 100644
--- a/source/slang/ir-legalize-types.cpp
+++ b/source/slang/ir-legalize-types.cpp
@@ -17,78 +17,6 @@
namespace Slang
{
-
-
-struct LegalValImpl : RefObject
-{
-};
-struct TuplePseudoVal;
-struct PairPseudoVal;
-
-struct LegalVal
-{
- enum class Flavor
- {
- none,
- simple,
- implicitDeref,
- tuple,
- pair,
- };
-
- Flavor flavor = Flavor::none;
- RefPtr<RefObject> obj;
- IRValue* irValue = nullptr;
-
- static LegalVal simple(IRValue* irValue)
- {
- LegalVal result;
- result.flavor = Flavor::simple;
- result.irValue = irValue;
- return result;
- }
-
- IRValue* getSimple()
- {
- assert(flavor == Flavor::simple);
- return irValue;
- }
-
- static LegalVal tuple(RefPtr<TuplePseudoVal> tupleVal);
-
- RefPtr<TuplePseudoVal> getTuple()
- {
- assert(flavor == Flavor::tuple);
- return obj.As<TuplePseudoVal>();
- }
-
- static LegalVal implicitDeref(LegalVal const& val);
- LegalVal getImplicitDeref();
-
- static LegalVal pair(RefPtr<PairPseudoVal> pairInfo);
- static LegalVal pair(
- LegalVal const& ordinaryVal,
- LegalVal const& specialVal,
- RefPtr<PairInfo> pairInfo);
-
- RefPtr<PairPseudoVal> getPair()
- {
- assert(flavor == Flavor::pair);
- return obj.As<PairPseudoVal>();
- }
-};
-
-struct TuplePseudoVal : LegalValImpl
-{
- struct Element
- {
- DeclRef<VarDeclBase> fieldDeclRef;
- LegalVal val;
- };
-
- List<Element> elements;
-};
-
LegalVal LegalVal::tuple(RefPtr<TuplePseudoVal> tupleVal)
{
LegalVal result;
@@ -97,16 +25,6 @@ LegalVal LegalVal::tuple(RefPtr<TuplePseudoVal> tupleVal)
return result;
}
-struct PairPseudoVal : LegalValImpl
-{
- LegalVal ordinaryVal;
- LegalVal specialVal;
-
- // The info to tell us which fields
- // are on which side(s)
- RefPtr<PairInfo> pairInfo;
-};
-
LegalVal LegalVal::pair(RefPtr<PairPseudoVal> pairInfo)
{
LegalVal result;
@@ -135,11 +53,6 @@ LegalVal LegalVal::pair(
return LegalVal::pair(obj);
}
-struct ImplicitDerefVal : LegalValImpl
-{
- LegalVal val;
-};
-
LegalVal LegalVal::implicitDeref(LegalVal const& val)
{
RefPtr<ImplicitDerefVal> implicitDerefVal = new ImplicitDerefVal();
@@ -189,12 +102,37 @@ static void registerLegalizedValue(
context->mapValToLegalVal.Add(irValue, legalVal);
}
+static void maybeRegisterLegalizedGlobal(
+ IRTypeLegalizationContext* context,
+ IRGlobalVar* irGlobalVar,
+ LegalVal const& legalVal)
+{
+ // Check the mangled name of the symbol and don't register
+ // symbols that don't have an external name (currently
+ // indicated by them having an empty name string).
+ String mangledName = irGlobalVar->mangledName;
+ if (mangledName.Length() == 0)
+ return;
+
+ // Otherwise, register the legalized value for this symbol
+ // under its mangled name, so that other code can still
+ // find the right value(s) to use after legalization.
+ context->typeLegalizationContext->mapMangledNameToLegalIRValue.AddIfNotExists(mangledName, legalVal);
+}
+
+struct IRGlobalNameInfo
+{
+ IRGlobalVar* globalVar;
+ UInt counter;
+};
+
static LegalVal declareVars(
IRTypeLegalizationContext* context,
IROp op,
LegalType type,
TypeLayout* typeLayout,
- LegalVarChain* varChain);
+ LegalVarChain* varChain,
+ IRGlobalNameInfo* globalNameInfo);
static LegalType legalizeType(
IRTypeLegalizationContext* context,
@@ -608,7 +546,7 @@ static LegalVal legalizeLocalVar(
varChain = &varChainStorage;
}
- LegalVal newVal = declareVars(context, kIROp_Var, legalValueType, typeLayout, varChain);
+ LegalVal newVal = declareVars(context, kIROp_Var, legalValueType, typeLayout, varChain, nullptr);
// Remove the old local var.
irLocalVar->removeFromParent();
@@ -761,7 +699,7 @@ static void legalizeFunc(
context->insertBeforeParam = pp;
context->builder->curBlock = nullptr;
- auto paramVal = declareVars(context, kIROp_Param, legalParamType, nullptr, nullptr);
+ auto paramVal = declareVars(context, kIROp_Param, legalParamType, nullptr, nullptr, nullptr);
paramVals.Add(paramVal);
if (pp == bb->getFirstParam())
{
@@ -807,7 +745,8 @@ static LegalVal declareSimpleVar(
IROp op,
Type* type,
TypeLayout* typeLayout,
- LegalVarChain* varChain)
+ LegalVarChain* varChain,
+ IRGlobalNameInfo* globalNameInfo)
{
RefPtr<VarLayout> varLayout = createVarLayout(varChain, typeLayout);
@@ -830,6 +769,25 @@ static LegalVal declareSimpleVar(
globalVar->removeFromParent();
globalVar->insertBefore(context->insertBeforeGlobal, builder->getModule());
+ // The legalization of a global variable with linkage (one that has
+ // a mangled name), must also have an exported name, so that code
+ // can link against it.
+ //
+ // For now we do something *really* simplistic, and just append
+ // a counter to each leaf variable generated from the original
+ if (globalNameInfo)
+ {
+ String mangledName = globalNameInfo->globalVar->mangledName;
+ if (mangledName.Length() != 0)
+ {
+ mangledName.append("L");
+ mangledName.append(globalNameInfo->counter++);
+ globalVar->mangledName = mangledName;
+ }
+ }
+
+
+
irVar = globalVar;
legalVarVal = LegalVal::simple(irVar);
}
@@ -887,7 +845,8 @@ static LegalVal declareVars(
IROp op,
LegalType type,
TypeLayout* typeLayout,
- LegalVarChain* varChain)
+ LegalVarChain* varChain,
+ IRGlobalNameInfo* globalNameInfo)
{
switch (type.flavor)
{
@@ -895,7 +854,7 @@ static LegalVal declareVars(
return LegalVal();
case LegalType::Flavor::simple:
- return declareSimpleVar(context, op, type.getSimple(), typeLayout, varChain);
+ return declareSimpleVar(context, op, type.getSimple(), typeLayout, varChain, globalNameInfo);
break;
case LegalType::Flavor::implicitDeref:
@@ -908,7 +867,8 @@ static LegalVal declareVars(
op,
type.getImplicitDeref()->valueType,
getDerefTypeLayout(typeLayout),
- varChain);
+ varChain,
+ globalNameInfo);
return LegalVal::implicitDeref(val);
}
break;
@@ -916,8 +876,8 @@ static LegalVal declareVars(
case LegalType::Flavor::pair:
{
auto pairType = type.getPair();
- auto ordinaryVal = declareVars(context, op, pairType->ordinaryType, typeLayout, varChain);
- auto specialVal = declareVars(context, op, pairType->specialType, typeLayout, varChain);
+ auto ordinaryVal = declareVars(context, op, pairType->ordinaryType, typeLayout, varChain, globalNameInfo);
+ auto specialVal = declareVars(context, op, pairType->specialType, typeLayout, varChain, globalNameInfo);
return LegalVal::pair(ordinaryVal, specialVal, pairType->pairInfo);
}
@@ -951,7 +911,8 @@ static LegalVal declareVars(
op,
ee.type,
fieldTypeLayout,
- newVarChain);
+ newVarChain,
+ globalNameInfo);
TuplePseudoVal::Element element;
element.fieldDeclRef = ee.fieldDeclRef;
@@ -1003,11 +964,18 @@ static void legalizeGlobalVar(
varChain = &varChainStorage;
}
- LegalVal newVal = declareVars(context, kIROp_global_var, legalValueType, typeLayout, varChain);
+ IRGlobalNameInfo globalNameInfo;
+ globalNameInfo.globalVar = irGlobalVar;
+ globalNameInfo.counter = 0;
+
+ LegalVal newVal = declareVars(context, kIROp_global_var, legalValueType, typeLayout, varChain, &globalNameInfo);
// Register the new value as the replacement for the old
registerLegalizedValue(context, irGlobalVar, newVal);
+ // Also register the variable according to its mangled name, if any.
+ maybeRegisterLegalizedGlobal(context, irGlobalVar, newVal);
+
// Remove the old global from the module.
irGlobalVar->removeFromParent();
// TODO: actually clean up the global!
diff --git a/source/slang/legalize-types.h b/source/slang/legalize-types.h
index 36e4223b6..853b9f47f 100644
--- a/source/slang/legalize-types.h
+++ b/source/slang/legalize-types.h
@@ -244,6 +244,106 @@ RefPtr<VarLayout> createVarLayout(
TypeLayout* typeLayout);
//
+// The result of legalizing an IR value will be
+// represented with the `LegalVal` type. It is exposed
+// in this header (rather than kept as an implementation
+// detail, because the AST-based legalization logic needs
+// a way to find the post-legalization version of a
+// global name).
+//
+// TODO: We really shouldn't have this structure exposed,
+// and instead should really be constructing AST-side
+// `LegalExpr` values on-demand whenever we legalize something
+// in the IR that will need to be used by the AST, and then
+// store *those* in a map indexed in mangled names.
+//
+
+struct LegalValImpl : RefObject
+{
+};
+struct TuplePseudoVal;
+struct PairPseudoVal;
+
+struct LegalVal
+{
+ enum class Flavor
+ {
+ none,
+ simple,
+ implicitDeref,
+ tuple,
+ pair,
+ };
+
+ Flavor flavor = Flavor::none;
+ RefPtr<RefObject> obj;
+ IRValue* irValue = nullptr;
+
+ static LegalVal simple(IRValue* irValue)
+ {
+ LegalVal result;
+ result.flavor = Flavor::simple;
+ result.irValue = irValue;
+ return result;
+ }
+
+ IRValue* getSimple()
+ {
+ assert(flavor == Flavor::simple);
+ return irValue;
+ }
+
+ static LegalVal tuple(RefPtr<TuplePseudoVal> tupleVal);
+
+ RefPtr<TuplePseudoVal> getTuple()
+ {
+ assert(flavor == Flavor::tuple);
+ return obj.As<TuplePseudoVal>();
+ }
+
+ static LegalVal implicitDeref(LegalVal const& val);
+ LegalVal getImplicitDeref();
+
+ static LegalVal pair(RefPtr<PairPseudoVal> pairInfo);
+ static LegalVal pair(
+ LegalVal const& ordinaryVal,
+ LegalVal const& specialVal,
+ RefPtr<PairInfo> pairInfo);
+
+ RefPtr<PairPseudoVal> getPair()
+ {
+ assert(flavor == Flavor::pair);
+ return obj.As<PairPseudoVal>();
+ }
+};
+
+struct TuplePseudoVal : LegalValImpl
+{
+ struct Element
+ {
+ DeclRef<VarDeclBase> fieldDeclRef;
+ LegalVal val;
+ };
+
+ List<Element> elements;
+};
+
+struct PairPseudoVal : LegalValImpl
+{
+ LegalVal ordinaryVal;
+ LegalVal specialVal;
+
+ // The info to tell us which fields
+ // are on which side(s)
+ RefPtr<PairInfo> pairInfo;
+};
+
+struct ImplicitDerefVal : LegalValImpl
+{
+ LegalVal val;
+};
+
+//
struct TypeLegalizationContext
{
@@ -280,6 +380,9 @@ struct TypeLegalizationContext
/// emitting declarations of legalized `struct` types
/// multiple times).
Dictionary<DeclRef<Decl>, LegalType> mapDeclRefToLegalType;
+
+ //
+ Dictionary<String, LegalVal> mapMangledNameToLegalIRValue;
};