diff options
| author | Tim Foley <tfoleyNV@users.noreply.github.com> | 2017-12-18 15:14:59 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2017-12-18 15:14:59 -0800 |
| commit | 393e25fd2e2b8c5ff82ff4c6b14a9d7152d37a5e (patch) | |
| tree | a3b0617e7ce5a5bfcf43454893f6c2962b7ec382 | |
| parent | 46b68ed41daecfaf1761e299cf040156e0f65ac0 (diff) | |
Work on getting rewriter + IR playing nice together. (#314)
* Work on getting rewriter + IR playing nice together.
There are a few different changes here, with the goal of improving the interaction between the "rewriter" code generation approach and the new IR and type legalization code.
The main changes are:
- Add a new pass that occurs before the AST legalization pass, which walks the (used) AST declarations and tries to discover (1) which declarations need to be specialized/lowered via the IR, and (2) which declarations need to be included in the resulting AST module.
- AST-based legalization now uses the generated list when in "rewriter" mode, so that we should be working around issues that users were seeing with types not getting emitted.
- TODO: we still need an equivalent fixup in the case of non-"rewriter" emit, so this may still be a problem for `.slang` files.
- IR type legalization now precedes AST legalization, so that we can record information on how any IR global values got legalized (e.g., if they got split). Then AST legalization includes logic to reconstruct suitable tuple expressions to reference a split global.
- When emitting using IR + AST, we walk all of the declarations that we decided belonged to the IR, but which were subsequently referenced in the AST, to make sure they get output (this would include `struct` types that are declared in a file compiled via IR, but never used in IR-based code).
The rewriter+IR use case still doesn't *quite* work, but the logic for walking the AST in a pre-pass ends up being needed/useful to fix some pure rewriter bugs, so I'm getting this checked in sooner rather than later.
* Fixup: walk arguments to generic declaration reference
The gotcha here is that the code for walking the AST would walk a line of code like:
SomeType a;
and know to traverse the declaration of `SomeType`, but if it saw a line of code like:
ParameterBlock<SomeType> b;
it would traverse the declaration of `ParameterBlock`, but fail to visit that of `SomeType`.
| -rw-r--r-- | source/core/list.h | 2 | ||||
| -rw-r--r-- | source/slang/ast-legalize.cpp | 644 | ||||
| -rw-r--r-- | source/slang/ast-legalize.h | 9 | ||||
| -rw-r--r-- | source/slang/emit.cpp | 111 | ||||
| -rw-r--r-- | source/slang/ir-legalize-types.cpp | 164 | ||||
| -rw-r--r-- | source/slang/legalize-types.h | 103 | ||||
| -rw-r--r-- | tests/bindings/multiple-parameter-blocks.slang | 20 | ||||
| -rw-r--r-- | tests/bugs/split-nested-types.hlsl | 30 | ||||
| -rw-r--r-- | tests/bugs/split-nested-types.slang | 14 | ||||
| -rw-r--r-- | tests/compute/rewriter-parameter-block-complex.hlsl | 2 | ||||
| -rw-r--r-- | tests/compute/rewriter-parameter-block.hlsl | 2 | ||||
| -rw-r--r-- | tests/compute/rewriter-use-ir-type.hlsl | 24 | ||||
| -rw-r--r-- | tests/compute/rewriter-use-ir-type.hlsl.expected.txt | 4 | ||||
| -rw-r--r-- | tests/compute/rewriter-use-ir-type.slang | 6 |
14 files changed, 973 insertions, 162 deletions
diff --git a/source/core/list.h b/source/core/list.h index b1461a260..5a94d8b83 100644 --- a/source/core/list.h +++ b/source/core/list.h @@ -381,7 +381,7 @@ namespace Slang void Reverse() { - for (int i = 0; i < (_count >> 1); i++) + for (UInt i = 0; i < (_count >> 1); i++) { Swap(buffer, i, _count - i - 1); } diff --git a/source/slang/ast-legalize.cpp b/source/slang/ast-legalize.cpp index 144277301..2f7814055 100644 --- a/source/slang/ast-legalize.cpp +++ b/source/slang/ast-legalize.cpp @@ -4,6 +4,7 @@ #include "emit.h" #include "ir-insts.h" #include "legalize-types.h" +#include "mangle.h" #include "type-layout.h" #include "visitor.h" @@ -868,7 +869,7 @@ struct LoweringVisitor RefPtr<Expr> createUncheckedVarRef( - char const* name) + String const& name) { return createUncheckedVarRef( shared->compileRequest->getNamePool()->getName(name)); @@ -2672,6 +2673,105 @@ struct LoweringVisitor return translateDeclRefImpl(DeclRef<Decl>(decl, nullptr)); } + LegalExpr translateSimpleLegalValToLegalExpr(IRValue* irVal) + { + switch (irVal->op) + { + case kIROp_global_var: + { + IRGlobalVar* globalVar = (IRGlobalVar*)irVal; + String mangledName = globalVar->mangledName; + SLANG_ASSERT(mangledName.Length() != 0); + + return LegalExpr(createUncheckedVarRef(mangledName)); + } + break; + + default: + SLANG_UNEXPECTED("unhandled opcode"); + UNREACHABLE_RETURN(LegalExpr()); + } + } + + LegalExpr translateLegalValToLegalExpr(LegalVal legalVal) + { + switch (legalVal.flavor) + { + case LegalVal::Flavor::none: + return LegalExpr(); + + case LegalVal::Flavor::simple: + return translateSimpleLegalValToLegalExpr(legalVal.getSimple()); + break; + + case LegalVal::Flavor::pair: + { + auto pairVal = legalVal.getPair(); + RefPtr<PairPseudoExpr> pairExpr = new PairPseudoExpr(); + pairExpr->pairInfo = pairVal->pairInfo; + pairExpr->ordinary = translateLegalValToLegalExpr(pairVal->ordinaryVal); + pairExpr->special = translateLegalValToLegalExpr(pairVal->specialVal); + return LegalExpr(pairExpr); + } + break; + + case LegalVal::Flavor::tuple: + { + auto tupleVal = legalVal.getTuple(); + RefPtr<TuplePseudoExpr> tupleExpr = new TuplePseudoExpr(); + for (auto ee : tupleVal->elements) + { + TuplePseudoExpr::Element element; + element.fieldDeclRef = ee.fieldDeclRef; + element.expr = translateLegalValToLegalExpr(ee.val); + tupleExpr->elements.Add(element); + } + return LegalExpr(tupleExpr); + } + break; + + case LegalVal::Flavor::implicitDeref: + { + auto implicitDerefVal = legalVal.getImplicitDeref(); + RefPtr<ImplicitDerefPseudoExpr> implicitDerefExpr = new ImplicitDerefPseudoExpr(); + implicitDerefExpr->valueExpr = translateLegalValToLegalExpr(implicitDerefVal); + return LegalExpr(implicitDerefExpr); + } + break; + + default: + SLANG_UNEXPECTED("unhandled flavor"); + UNREACHABLE_RETURN(LegalExpr()); + } + } + + void maybeLegalizeIRGlobal( + DeclRef<Decl> declRef) + { + // We've been given a decl-ref to a value that was translated via IR, + // and we need to determine if it needs custom handling for legalization, + // because it was an IR global that got split. + + // TODO: this code is using decls in places it should use decl-refs, + // and that likely needs to get cleaned up... + auto decl = declRef.getDecl(); + + // If we already have an expression registered, then don't bother. + if (shared->mapOriginalDeclToExpr.ContainsKey(decl)) + return; + + String mangledName = getMangledName(declRef); + if (mangledName.Length() == 0) + return; + + LegalVal legalVal; + if (!shared->typeLegalizationContext->mapMangledNameToLegalIRValue.TryGetValue(mangledName, legalVal)) + return; + + LegalExpr legalExpr = translateLegalValToLegalExpr(legalVal); + shared->mapOriginalDeclToExpr.Add(decl, legalExpr); + } + RefPtr<Decl> translateDeclRefImpl( DeclRef<Decl> declRef) { @@ -2703,24 +2803,25 @@ struct LoweringVisitor auto parentModule = findModuleForDecl(decl); if (parentModule && (parentModule != shared->mainModuleDecl)) { - // Ensure that the IR code for the given declaration - // gets included in the output IR module, and *also* - // that we generate a suitable specialization of it - // if there are any substitutions in effect. - - getSpecializedGlobalValueForDeclRef( - shared->irSpecializationState, - declRef); + // This declaration should already have been lowered to + // the IR during the "walk" pass that happened earlier, + // and so we won't do it again here. + // + // Instead, we need to check if the particular + // declaration is one that needs to be swapped for + // a legalized expression (e.g., because it was an IR + // global that got split) + // + maybeLegalizeIRGlobal(declRef); // Remember that this declaration is handled via IR, // rather than being present in the legalized AST. shared->result.irDecls.Add(declRef.getDecl()); - // We don't actually use the `IRGlobalValue` that the - // above operation returns, and instead just keep - // using the original declaration in the legalized - // AST. The step of mapping that declaration over - // to reference the IR symbol will happen later. + // This method can't actually return a `LegalExpr`, + // so for now we just assume that the original + // declaration is the right stand-in for the IR + // value we want. return decl; } @@ -4671,7 +4772,8 @@ LoweredEntryPoint lowerEntryPoint( CodeGenTarget target, ExtensionUsageTracker* extensionUsageTracker, IRSpecializationState* irSpecializationState, - TypeLegalizationContext* typeLegalizationContext) + TypeLegalizationContext* typeLegalizationContext, + List<Decl*> astDecls) { SharedLoweringContext sharedContext; sharedContext.compileRequest = entryPoint->compileRequest; @@ -4732,10 +4834,21 @@ LoweredEntryPoint lowerEntryPoint( if (isRewrite) { + for (auto dd : astDecls) + { + // Skip non-global decls + if (!dd->ParentDecl) + continue; + if (!dynamic_cast<ModuleDecl*>(dd->ParentDecl)) + continue; + visitor.translateDeclRef(dd); + } +#if 0 for (auto dd : translationUnit->SyntaxNode->Members) { visitor.translateDeclRef(dd); } +#endif } else { @@ -4747,4 +4860,505 @@ LoweredEntryPoint lowerEntryPoint( return sharedContext.result; } + +struct FindIRDeclUsedByASTVisitor + : ExprVisitor<FindIRDeclUsedByASTVisitor, void> + , StmtVisitor<FindIRDeclUsedByASTVisitor, void> + , DeclVisitor<FindIRDeclUsedByASTVisitor, void> + , ValVisitor<FindIRDeclUsedByASTVisitor, void, void> + +{ + CompileRequest* compileRequest; + IRSpecializationState* irSpecializationState; + ModuleDecl* mainModuleDecl; + + // Declarations to be processed by the AST lowering pass + List<Decl*>* astDecls; + + HashSet<DeclBase*> seenDecls; + HashSet<DeclBase*> addedDecls; + + void walkType(Type* type) + { + if(!type) return; + + TypeVisitor::dispatch(type); + } + + void walkVal(Val* val) + { + if(!val) return; + + ValVisitor::dispatch(val); + } + + void walkExpr(Expr* expr) + { + if(!expr) return; + + ExprVisitor::dispatch(expr); + } + + void walkStmt(Stmt* stmt) + { + if(!stmt) return; + + StmtVisitor::dispatch(stmt); + } + + void walkSubst(Substitutions* subst) + { + if( auto genericSubst = dynamic_cast<GenericSubstitution*>(subst) ) + { + for( auto arg : genericSubst->args ) + { + walkVal(arg); + } + } + // TODO: handle other cases here + } + + void walkDeclRef(DeclRef<Decl> const& declRef) + { + Decl* decl = declRef.getDecl(); + if (!decl) return; + + // If this is a specialized declaration reference, then any + // of the arguments also need to be walked. + for(auto subst = declRef.substitutions; subst; subst = subst->outer) + { + walkSubst(subst); + } + + // If any parent of the declaration was in the stdlib, or + // is registered as a builtin, then skip it. + for (auto pp = decl; pp; pp = pp->ParentDecl) + { + if (pp->HasModifier<FromStdLibModifier>()) + return; + + if (pp->HasModifier<BuiltinModifier>()) + return; + } + + // If we are using the IR, and the declaration comes from + // an imported module (rather than the "rewrite-mode" module + // being translated), then we need to ensure that it gets lowered + // to IR instead. + if (compileRequest->compileFlags & SLANG_COMPILE_FLAG_USE_IR) + { + auto parentModule = findModuleForDecl(decl); + if (parentModule && (parentModule != mainModuleDecl)) + { + // Ensure that the IR code for the given declaration + // gets included in the output IR module, and *also* + // that we generate a suitable specialization of it + // if there are any substitutions in effect. + + getSpecializedGlobalValueForDeclRef( + irSpecializationState, + declRef); + + // TODO: we probably need to track this value... + + return; + } + } + + // If none of the above triggered, then we seemingly have + // a declaration from the current module, and we should + // add it to our work list so we can walk it too. + addDecl(decl); + } + + // Vals + + void visitIRProxyVal(IRProxyVal*) + {} + + void visitConstantIntVal(ConstantIntVal*) + {} + + void visitGenericParamIntVal(GenericParamIntVal* val) + { + walkDeclRef(val->declRef); + } + + void visitWitness(Witness*) + {} + + // Types + + void visitOverloadGroupType(OverloadGroupType*) + {} + + void visitInitializerListType(InitializerListType*) + {} + + void visitErrorType(ErrorType*) + {} + + void visitIRBasicBlockType(IRBasicBlockType*) + {} + + void visitDeclRefType(DeclRefType* type) + { + walkDeclRef(type->declRef); + } + + void visitGenericDeclRefType(GenericDeclRefType* type) + { + walkDeclRef(type->declRef); + } + + void visitNamedExpressionType(NamedExpressionType* type) + { + walkDeclRef(type->declRef); + } + + void visitFuncType(FuncType* type) + { + for( auto p : type->paramTypes ) + { + walkType(p); + } + walkType(type->resultType); + } + + void visitTypeType(TypeType* type) + { + walkType(type->type); + } + + void visitGroupSharedType(GroupSharedType* type) + { + walkType(type->valueType); + } + + void visitArrayExpressionType(ArrayExpressionType* type) + { + walkType(type->baseType); + walkVal(type->ArrayLength); + } + + // Exprs + + void visitVarExpr(VarExpr* expr) + { + walkDeclRef(expr->declRef); + } + + void visitMemberExpr(MemberExpr* expr) + { + walkExpr(expr->BaseExpression); + walkDeclRef(expr->declRef); + } + + void visitStaticMemberExpr(StaticMemberExpr* expr) + { + walkExpr(expr->BaseExpression); + walkDeclRef(expr->declRef); + } + + void visitOverloadedExpr(OverloadedExpr* expr) + { + walkExpr(expr->base); + + // TODO: need to walk the lookup result too + } + + void visitConstantExpr(ConstantExpr*) + {} + + void visitInitializerListExpr(InitializerListExpr* expr) + { + for(auto a : expr->args) + walkExpr(a); + } + + void visitAppExprBase(AppExprBase* expr) + { + walkExpr(expr->FunctionExpr); + for(auto a : expr->Arguments) + walkExpr(a); + } + + void visitAggTypeCtorExpr(AggTypeCtorExpr* expr) + { + walkType(expr->base); + for(auto a : expr->Arguments) + walkExpr(a); + } + + void visitSharedTypeExpr(SharedTypeExpr* expr) + { + walkType(expr->base); + } + + void visitAssignExpr(AssignExpr* expr) + { + walkExpr(expr->left); + walkExpr(expr->right); + } + + void visitIndexExpr(IndexExpr* expr) + { + walkExpr(expr->BaseExpression); + walkExpr(expr->IndexExpression); + } + + void visitSwizzleExpr(SwizzleExpr* expr) + { + walkExpr(expr->base); + } + + void visitDerefExpr(DerefExpr* expr) + { + walkExpr(expr->base); + } + + void visitParenExpr(ParenExpr* expr) + { + walkExpr(expr->base); + } + + void visitThisExpr(ThisExpr*) + {} + + // Stmts + + void visitSeqStmt(SeqStmt* stmt) + { + for( auto s : stmt->stmts ) + { + walkStmt(s); + } + } + + void visitBlockStmt(BlockStmt* stmt) + { + walkStmt(stmt->body); + } + + void visitUnparsedStmt(UnparsedStmt*) + {} + + void visitEmptyStmt(EmptyStmt*) + {} + + void visitDiscardStmt(DiscardStmt*) + {} + + void visitDeclStmt(DeclStmt* stmt) + { + addDecl(stmt->decl); + } + + void visitIfStmt(IfStmt* stmt) + { + walkExpr(stmt->Predicate); + walkStmt(stmt->PositiveStatement); + walkStmt(stmt->NegativeStatement); + } + + void visitSwitchStmt(SwitchStmt* stmt) + { + walkExpr(stmt->condition); + walkStmt(stmt->body); + } + + void visitCaseStmt(CaseStmt* stmt) + { + walkExpr(stmt->expr); + } + + void visitDefaultStmt(DefaultStmt*) + {} + + void visitForStmt(ForStmt* stmt) + { + walkStmt(stmt->InitialStatement); + walkExpr(stmt->SideEffectExpression); + walkExpr(stmt->PredicateExpression); + walkStmt(stmt->Statement); + } + + void visitWhileStmt(WhileStmt* stmt) + { + walkExpr(stmt->Predicate); + walkStmt(stmt->Statement); + } + + void visitDoWhileStmt(DoWhileStmt* stmt) + { + walkExpr(stmt->Predicate); + walkStmt(stmt->Statement); + } + + void visitCompileTimeForStmt(CompileTimeForStmt* stmt) + { + addDecl(stmt->varDecl); + walkExpr(stmt->rangeBeginExpr); + walkExpr(stmt->rangeEndExpr); + walkStmt(stmt->body); + } + + void visitReturnStmt(ReturnStmt* stmt) + { + walkExpr(stmt->Expression); + } + + void visitExpressionStmt(ExpressionStmt* stmt) + { + walkExpr(stmt->Expression); + } + + void visitJumpStmt(JumpStmt*) + {} + + // Decls + + void visitDeclGroup(DeclGroup* declGroup) + { + for( auto dd : declGroup->decls ) + { + addDecl(dd); + } + } + + void visitContainerDeclCommon(ContainerDecl* decl) + { + for( auto mm : decl->Members ) + { + addDecl(mm); + } + } + + void visitContainerDecl(ContainerDecl* decl) + { + visitContainerDeclCommon(decl); + } + + void visitVarDeclBase(VarDeclBase* decl) + { + walkType(decl->type); + walkExpr(decl->initExpr); + } + + void visitAggTypeDeclBase(AggTypeDeclBase* decl) + { + visitContainerDeclCommon(decl); + } + + void visitInheritanceDecl(InheritanceDecl* decl) + { + walkType(decl->base); + } + + void visitTypeDefDecl(TypeDefDecl* decl) + { + walkType(decl->type); + } + + void visitCallableDeclCommon(CallableDecl* decl) + { + visitContainerDeclCommon(decl); + walkType(decl->ReturnType); + } + + void visitCallableDecl(CallableDecl* decl) + { + visitCallableDeclCommon(decl); + } + + void visitFunctionDeclBase(FunctionDeclBase* decl) + { + visitCallableDeclCommon(decl); + walkStmt(decl->Body); + } + + void visitImportDecl(ImportDecl*) + {} + + void visitGenericTypeParamDecl(GenericTypeParamDecl*) + {} + + void visitGenericTypeConstraintDecl(GenericTypeConstraintDecl*) + {} + + void visitEmptyDecl(EmptyDecl*) + {} + + void visitSyntaxDecl(SyntaxDecl*) + {} + + // + + void addDecl(DeclBase* decl) + { + // Has this decl already been added + // to the output list? + if(addedDecls.Contains(decl)) + return; + + // Are we in the middel of processing this + // decl? + // + // TODO: this implies a cycle, and we need to + // break it! + if (seenDecls.Contains(decl)) + return; + + seenDecls.Add(decl); + + // Recurse on the given decl + DeclVisitor::dispatch(decl); + + // Add it to the output list, if needed + if (auto dd = dynamic_cast<Decl*>(decl)) + { + (*astDecls).Add(dd); + } + + // Mark it as completely processed + addedDecls.Add(decl); + } + + void flush() + { + } +}; + + +void findDeclsUsedByASTEntryPoint( + EntryPointRequest* entryPoint, + CodeGenTarget target, + IRSpecializationState* irSpecializationState, + List<Decl*>& outASTDecls) +{ + auto translationUnit = entryPoint->getTranslationUnit(); + auto mainModuleDecl = translationUnit->SyntaxNode; + + FindIRDeclUsedByASTVisitor visitor; + visitor.compileRequest = entryPoint->compileRequest; + visitor.irSpecializationState = irSpecializationState; + visitor.mainModuleDecl = mainModuleDecl; + visitor.astDecls = &outASTDecls; + + bool isRewrite = isRewriteRequest(translationUnit->sourceLanguage, target); + + if (isRewrite) + { + visitor.addDecl(mainModuleDecl); + } + else + { + visitor.addDecl(entryPoint->decl); + } + + visitor.flush(); +} + + + } diff --git a/source/slang/ast-legalize.h b/source/slang/ast-legalize.h index 23a150002..ab06d7a21 100644 --- a/source/slang/ast-legalize.h +++ b/source/slang/ast-legalize.h @@ -69,6 +69,13 @@ namespace Slang CodeGenTarget target, ExtensionUsageTracker* extensionUsageTracker, IRSpecializationState* irSpecializationState, - TypeLegalizationContext* typeLegalizationContext); + TypeLegalizationContext* typeLegalizationContext, + List<Decl*> astDecls); + + void findDeclsUsedByASTEntryPoint( + EntryPointRequest* entryPoint, + CodeGenTarget target, + IRSpecializationState* irSpecializationState, + List<Decl*>& outASTDecls); } #endif diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index bf7ad0c3a..ce17c8d03 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -100,7 +100,7 @@ struct SharedEmitContext Dictionary<IRValue*, UInt> mapIRValueToID; Dictionary<Decl*, UInt> mapDeclToID; - HashSet<Decl*> irDeclsVisited; + HashSet<String> irDeclsVisited; Dictionary<IRBlock*, IRBlock*> irMapContinueTargetToLoopHead; @@ -2415,7 +2415,8 @@ struct EmitVisitor } else { - emit(memberExpr->declRef.GetName()); + EmitDeclRef(memberExpr->declRef); +// emit(memberExpr->declRef.GetName()); } if(needClose) Emit(")"); @@ -2453,7 +2454,8 @@ struct EmitVisitor } else { - emit(memberExpr->declRef.GetName()); + EmitDeclRef(memberExpr->declRef); +// emit(memberExpr->declRef.GetName()); } if(needClose) Emit(")"); @@ -6892,15 +6894,11 @@ emitDeclImpl(decl, nullptr); EmitContext* ctx, DeclRef<StructDecl> declRef) { - // TODO: Eventually need to deal with the case where - // we have user-defined generic types. - // - auto decl = declRef.getDecl(); - - if(ctx->shared->irDeclsVisited.Contains(decl)) + auto mangledName = getMangledName(declRef); + if(ctx->shared->irDeclsVisited.Contains(mangledName)) return; - ctx->shared->irDeclsVisited.Add(decl); + ctx->shared->irDeclsVisited.Add(mangledName); // First emit any types used by fields of this type for( auto ff : GetFields(declRef) ) @@ -6935,6 +6933,25 @@ emitDeclImpl(decl, nullptr); Emit("};\n"); } + void emitIRUsedDeclRef( + EmitContext* ctx, + DeclRef<Decl> declRef) + { + auto decl = declRef.getDecl(); + + if(decl->HasModifier<BuiltinTypeModifier>() + || decl->HasModifier<MagicTypeModifier>()) + { + return; + } + + if( auto structDeclRef = declRef.As<StructDecl>() ) + { + // + ensureStructDecl(ctx, structDeclRef); + } + } + // A type is going to be used by the IR, so // make sure that we have emitted whatever // it needs. @@ -6970,19 +6987,7 @@ emitDeclImpl(decl, nullptr); else if( auto declRefType = type->As<DeclRefType>() ) { auto declRef = declRefType->declRef; - auto decl = declRef.getDecl(); - - if(decl->HasModifier<BuiltinTypeModifier>() - || decl->HasModifier<MagicTypeModifier>()) - { - return; - } - - if( auto structDeclRef = declRef.As<StructDecl>() ) - { - // - ensureStructDecl(ctx, structDeclRef); - } + emitIRUsedDeclRef(ctx, declRef); } else {} @@ -7249,13 +7254,21 @@ String emitEntryPoint( // boilerplate at the start of the ouput for GLSL (e.g., what // version we require). + List<Decl*> astDecls; + findDeclsUsedByASTEntryPoint( + entryPoint, + target, + nullptr, + astDecls); + auto lowered = lowerEntryPoint( entryPoint, programLayout, target, &sharedContext.extensionUsageTracker, nullptr, - &typeLegalizationContext); + &typeLegalizationContext, + astDecls); sharedContext.program = lowered.program; // Note that we emit the main body code of the program *before* @@ -7287,25 +7300,23 @@ String emitEntryPoint( typeLegalizationContext.irModule = irModule; - LoweredEntryPoint lowered; + List<Decl*> astDecls; if(translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING) { // We are in case (2b), where the main module is in unchecked // HLSL/GLSL that we need to "rewrite," and any library code // is in Slang that will need to be cross-compiled via the IR. - // Initially, we will apply the AST-to-AST pass to legalize - // the user's code, much like we would for any other target. - // Along the way, this pass will discover any IR declarations - // that we use, and try to emit code for them into our IR module. + // We first need to walk the AST part of the code to look + // for any places where it references declarations that + // are implemented in the IR, so that we can be sure to + // generate suitable IR code for them. - lowered = lowerEntryPoint( + findDeclsUsedByASTEntryPoint( entryPoint, - programLayout, target, - &sharedContext.extensionUsageTracker, irSpecializationState, - &typeLegalizationContext); + astDecls); } else { @@ -7358,6 +7369,33 @@ String emitEntryPoint( fprintf(stderr, "###\n"); #endif + LoweredEntryPoint lowered; + if(translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING) + { + // In the (2b) case, once we have legalized the IR code, + // we now need to go in and legalize the AST code. + // This order is important because when referring to a variable + // that is defined in the IR, we need to legalize it first (which + // might split it into many decls) before we can legalize an AST + // expression that references that decl (which will also need + // to get split). + // + // We don't have to worry about references in the other direction; + // we don't allow the user to define something in unchecked AST + // code and then use it from the IR shader library. + + lowered = lowerEntryPoint( + entryPoint, + programLayout, + target, + &sharedContext.extensionUsageTracker, + irSpecializationState, + &typeLegalizationContext, + astDecls); + } + + // When emitting IR-based declarations, we wnat to + // track which decls have already been lowered. sharedContext.irDeclSetForAST = &lowered.irDecls; // After all of the required optimization and legalization @@ -7372,6 +7410,13 @@ String emitEntryPoint( // that we need to output, we'll do it now. if (translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING) { + // First make sure that we've emitted any types that were declared + // in the IR, but then subsequently only used by the AST + for( auto decl : lowered.irDecls ) + { + visitor.emitIRUsedDeclRef(&context, makeDeclRef(decl)); + } + visitor.EmitDeclsInContainer(lowered.program); } diff --git a/source/slang/ir-legalize-types.cpp b/source/slang/ir-legalize-types.cpp index fedec4f87..5db3b4d33 100644 --- a/source/slang/ir-legalize-types.cpp +++ b/source/slang/ir-legalize-types.cpp @@ -17,78 +17,6 @@ namespace Slang { - - -struct LegalValImpl : RefObject -{ -}; -struct TuplePseudoVal; -struct PairPseudoVal; - -struct LegalVal -{ - enum class Flavor - { - none, - simple, - implicitDeref, - tuple, - pair, - }; - - Flavor flavor = Flavor::none; - RefPtr<RefObject> obj; - IRValue* irValue = nullptr; - - static LegalVal simple(IRValue* irValue) - { - LegalVal result; - result.flavor = Flavor::simple; - result.irValue = irValue; - return result; - } - - IRValue* getSimple() - { - assert(flavor == Flavor::simple); - return irValue; - } - - static LegalVal tuple(RefPtr<TuplePseudoVal> tupleVal); - - RefPtr<TuplePseudoVal> getTuple() - { - assert(flavor == Flavor::tuple); - return obj.As<TuplePseudoVal>(); - } - - static LegalVal implicitDeref(LegalVal const& val); - LegalVal getImplicitDeref(); - - static LegalVal pair(RefPtr<PairPseudoVal> pairInfo); - static LegalVal pair( - LegalVal const& ordinaryVal, - LegalVal const& specialVal, - RefPtr<PairInfo> pairInfo); - - RefPtr<PairPseudoVal> getPair() - { - assert(flavor == Flavor::pair); - return obj.As<PairPseudoVal>(); - } -}; - -struct TuplePseudoVal : LegalValImpl -{ - struct Element - { - DeclRef<VarDeclBase> fieldDeclRef; - LegalVal val; - }; - - List<Element> elements; -}; - LegalVal LegalVal::tuple(RefPtr<TuplePseudoVal> tupleVal) { LegalVal result; @@ -97,16 +25,6 @@ LegalVal LegalVal::tuple(RefPtr<TuplePseudoVal> tupleVal) return result; } -struct PairPseudoVal : LegalValImpl -{ - LegalVal ordinaryVal; - LegalVal specialVal; - - // The info to tell us which fields - // are on which side(s) - RefPtr<PairInfo> pairInfo; -}; - LegalVal LegalVal::pair(RefPtr<PairPseudoVal> pairInfo) { LegalVal result; @@ -135,11 +53,6 @@ LegalVal LegalVal::pair( return LegalVal::pair(obj); } -struct ImplicitDerefVal : LegalValImpl -{ - LegalVal val; -}; - LegalVal LegalVal::implicitDeref(LegalVal const& val) { RefPtr<ImplicitDerefVal> implicitDerefVal = new ImplicitDerefVal(); @@ -189,12 +102,37 @@ static void registerLegalizedValue( context->mapValToLegalVal.Add(irValue, legalVal); } +static void maybeRegisterLegalizedGlobal( + IRTypeLegalizationContext* context, + IRGlobalVar* irGlobalVar, + LegalVal const& legalVal) +{ + // Check the mangled name of the symbol and don't register + // symbols that don't have an external name (currently + // indicated by them having an empty name string). + String mangledName = irGlobalVar->mangledName; + if (mangledName.Length() == 0) + return; + + // Otherwise, register the legalized value for this symbol + // under its mangled name, so that other code can still + // find the right value(s) to use after legalization. + context->typeLegalizationContext->mapMangledNameToLegalIRValue.AddIfNotExists(mangledName, legalVal); +} + +struct IRGlobalNameInfo +{ + IRGlobalVar* globalVar; + UInt counter; +}; + static LegalVal declareVars( IRTypeLegalizationContext* context, IROp op, LegalType type, TypeLayout* typeLayout, - LegalVarChain* varChain); + LegalVarChain* varChain, + IRGlobalNameInfo* globalNameInfo); static LegalType legalizeType( IRTypeLegalizationContext* context, @@ -608,7 +546,7 @@ static LegalVal legalizeLocalVar( varChain = &varChainStorage; } - LegalVal newVal = declareVars(context, kIROp_Var, legalValueType, typeLayout, varChain); + LegalVal newVal = declareVars(context, kIROp_Var, legalValueType, typeLayout, varChain, nullptr); // Remove the old local var. irLocalVar->removeFromParent(); @@ -761,7 +699,7 @@ static void legalizeFunc( context->insertBeforeParam = pp; context->builder->curBlock = nullptr; - auto paramVal = declareVars(context, kIROp_Param, legalParamType, nullptr, nullptr); + auto paramVal = declareVars(context, kIROp_Param, legalParamType, nullptr, nullptr, nullptr); paramVals.Add(paramVal); if (pp == bb->getFirstParam()) { @@ -807,7 +745,8 @@ static LegalVal declareSimpleVar( IROp op, Type* type, TypeLayout* typeLayout, - LegalVarChain* varChain) + LegalVarChain* varChain, + IRGlobalNameInfo* globalNameInfo) { RefPtr<VarLayout> varLayout = createVarLayout(varChain, typeLayout); @@ -830,6 +769,25 @@ static LegalVal declareSimpleVar( globalVar->removeFromParent(); globalVar->insertBefore(context->insertBeforeGlobal, builder->getModule()); + // The legalization of a global variable with linkage (one that has + // a mangled name), must also have an exported name, so that code + // can link against it. + // + // For now we do something *really* simplistic, and just append + // a counter to each leaf variable generated from the original + if (globalNameInfo) + { + String mangledName = globalNameInfo->globalVar->mangledName; + if (mangledName.Length() != 0) + { + mangledName.append("L"); + mangledName.append(globalNameInfo->counter++); + globalVar->mangledName = mangledName; + } + } + + + irVar = globalVar; legalVarVal = LegalVal::simple(irVar); } @@ -887,7 +845,8 @@ static LegalVal declareVars( IROp op, LegalType type, TypeLayout* typeLayout, - LegalVarChain* varChain) + LegalVarChain* varChain, + IRGlobalNameInfo* globalNameInfo) { switch (type.flavor) { @@ -895,7 +854,7 @@ static LegalVal declareVars( return LegalVal(); case LegalType::Flavor::simple: - return declareSimpleVar(context, op, type.getSimple(), typeLayout, varChain); + return declareSimpleVar(context, op, type.getSimple(), typeLayout, varChain, globalNameInfo); break; case LegalType::Flavor::implicitDeref: @@ -908,7 +867,8 @@ static LegalVal declareVars( op, type.getImplicitDeref()->valueType, getDerefTypeLayout(typeLayout), - varChain); + varChain, + globalNameInfo); return LegalVal::implicitDeref(val); } break; @@ -916,8 +876,8 @@ static LegalVal declareVars( case LegalType::Flavor::pair: { auto pairType = type.getPair(); - auto ordinaryVal = declareVars(context, op, pairType->ordinaryType, typeLayout, varChain); - auto specialVal = declareVars(context, op, pairType->specialType, typeLayout, varChain); + auto ordinaryVal = declareVars(context, op, pairType->ordinaryType, typeLayout, varChain, globalNameInfo); + auto specialVal = declareVars(context, op, pairType->specialType, typeLayout, varChain, globalNameInfo); return LegalVal::pair(ordinaryVal, specialVal, pairType->pairInfo); } @@ -951,7 +911,8 @@ static LegalVal declareVars( op, ee.type, fieldTypeLayout, - newVarChain); + newVarChain, + globalNameInfo); TuplePseudoVal::Element element; element.fieldDeclRef = ee.fieldDeclRef; @@ -1003,11 +964,18 @@ static void legalizeGlobalVar( varChain = &varChainStorage; } - LegalVal newVal = declareVars(context, kIROp_global_var, legalValueType, typeLayout, varChain); + IRGlobalNameInfo globalNameInfo; + globalNameInfo.globalVar = irGlobalVar; + globalNameInfo.counter = 0; + + LegalVal newVal = declareVars(context, kIROp_global_var, legalValueType, typeLayout, varChain, &globalNameInfo); // Register the new value as the replacement for the old registerLegalizedValue(context, irGlobalVar, newVal); + // Also register the variable according to its mangled name, if any. + maybeRegisterLegalizedGlobal(context, irGlobalVar, newVal); + // Remove the old global from the module. irGlobalVar->removeFromParent(); // TODO: actually clean up the global! diff --git a/source/slang/legalize-types.h b/source/slang/legalize-types.h index 36e4223b6..853b9f47f 100644 --- a/source/slang/legalize-types.h +++ b/source/slang/legalize-types.h @@ -244,6 +244,106 @@ RefPtr<VarLayout> createVarLayout( TypeLayout* typeLayout); // +// The result of legalizing an IR value will be +// represented with the `LegalVal` type. It is exposed +// in this header (rather than kept as an implementation +// detail, because the AST-based legalization logic needs +// a way to find the post-legalization version of a +// global name). +// +// TODO: We really shouldn't have this structure exposed, +// and instead should really be constructing AST-side +// `LegalExpr` values on-demand whenever we legalize something +// in the IR that will need to be used by the AST, and then +// store *those* in a map indexed in mangled names. +// + +struct LegalValImpl : RefObject +{ +}; +struct TuplePseudoVal; +struct PairPseudoVal; + +struct LegalVal +{ + enum class Flavor + { + none, + simple, + implicitDeref, + tuple, + pair, + }; + + Flavor flavor = Flavor::none; + RefPtr<RefObject> obj; + IRValue* irValue = nullptr; + + static LegalVal simple(IRValue* irValue) + { + LegalVal result; + result.flavor = Flavor::simple; + result.irValue = irValue; + return result; + } + + IRValue* getSimple() + { + assert(flavor == Flavor::simple); + return irValue; + } + + static LegalVal tuple(RefPtr<TuplePseudoVal> tupleVal); + + RefPtr<TuplePseudoVal> getTuple() + { + assert(flavor == Flavor::tuple); + return obj.As<TuplePseudoVal>(); + } + + static LegalVal implicitDeref(LegalVal const& val); + LegalVal getImplicitDeref(); + + static LegalVal pair(RefPtr<PairPseudoVal> pairInfo); + static LegalVal pair( + LegalVal const& ordinaryVal, + LegalVal const& specialVal, + RefPtr<PairInfo> pairInfo); + + RefPtr<PairPseudoVal> getPair() + { + assert(flavor == Flavor::pair); + return obj.As<PairPseudoVal>(); + } +}; + +struct TuplePseudoVal : LegalValImpl +{ + struct Element + { + DeclRef<VarDeclBase> fieldDeclRef; + LegalVal val; + }; + + List<Element> elements; +}; + +struct PairPseudoVal : LegalValImpl +{ + LegalVal ordinaryVal; + LegalVal specialVal; + + // The info to tell us which fields + // are on which side(s) + RefPtr<PairInfo> pairInfo; +}; + +struct ImplicitDerefVal : LegalValImpl +{ + LegalVal val; +}; + +// struct TypeLegalizationContext { @@ -280,6 +380,9 @@ struct TypeLegalizationContext /// emitting declarations of legalized `struct` types /// multiple times). Dictionary<DeclRef<Decl>, LegalType> mapDeclRefToLegalType; + + // + Dictionary<String, LegalVal> mapMangledNameToLegalIRValue; }; diff --git a/tests/bindings/multiple-parameter-blocks.slang b/tests/bindings/multiple-parameter-blocks.slang index 0a73fdcbd..5fcb9c6d5 100644 --- a/tests/bindings/multiple-parameter-blocks.slang +++ b/tests/bindings/multiple-parameter-blocks.slang @@ -29,20 +29,20 @@ float4 main(float v : V) : SV_Target #else -Texture2D _S1 : register(t0, space0); -Texture2D _S2[4] : register(t1, space0); -SamplerState _S3 : register(s0, space0); +Texture2D _SV01pL0 : register(t0, space0); +Texture2D _SV01pL1[4] : register(t1, space0); +SamplerState _SV01pL2 : register(s0, space0); -Texture2D _S12 : register(t0, space1); -Texture2D _S13[4] : register(t1, space1); -SamplerState _S14 : register(s0, space1); +Texture2D _SV02p1L0 : register(t0, space1); +Texture2D _SV02p1L1[4] : register(t1, space1); +SamplerState _SV02p1L2 : register(s0, space1); float4 main(float v : V) : SV_Target { - return use(_S1, _S3) - + use(_S2[int(v)], _S3) - + use(_S12, _S14) - + use(_S13[int(v)], _S14); + return use(_SV01pL0, _SV01pL2) + + use(_SV01pL1[int(v)], _SV01pL2) + + use(_SV02p1L0, _SV02p1L2) + + use(_SV02p1L1[int(v)], _SV02p1L2); } #endif diff --git a/tests/bugs/split-nested-types.hlsl b/tests/bugs/split-nested-types.hlsl new file mode 100644 index 000000000..210c119df --- /dev/null +++ b/tests/bugs/split-nested-types.hlsl @@ -0,0 +1,30 @@ +// array-size-static-const.hlsl +//TEST:COMPARE_HLSL: -profile ps_5_0 -target dxbc-assembly + +#ifdef __SLANG__ +import split_nested_types; +#else + +struct A { int x; }; + +struct B { float y; }; + +struct C { Texture2D t; SamplerState s; }; + +struct M +{ + A a; + B b; +}; + +#endif + +cbuffer C +{ + M m; +} + +float4 main() : SV_target +{ + return m.b.y; +} diff --git a/tests/bugs/split-nested-types.slang b/tests/bugs/split-nested-types.slang new file mode 100644 index 000000000..ccf95d906 --- /dev/null +++ b/tests/bugs/split-nested-types.slang @@ -0,0 +1,14 @@ +//TEST_IGNORE_FILE: + +struct A { int x; }; + +struct B { float y; }; + +struct C { Texture2D t; SamplerState s; }; + +struct M +{ + A a; + B b; + C c; +}; diff --git a/tests/compute/rewriter-parameter-block-complex.hlsl b/tests/compute/rewriter-parameter-block-complex.hlsl index 4dc312f95..fe7aae4a6 100644 --- a/tests/compute/rewriter-parameter-block-complex.hlsl +++ b/tests/compute/rewriter-parameter-block-complex.hlsl @@ -1,7 +1,5 @@ //TEST(compute):HLSL_COMPUTE:-xslang -no-checking //TEST(compute):COMPARE_COMPUTE:-xslang -use-ir - -// Doesn't work with IR yet. //DISABLED_TEST(compute):HLSL_COMPUTE:-xslang -no-checking -xslang -use-ir //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out diff --git a/tests/compute/rewriter-parameter-block.hlsl b/tests/compute/rewriter-parameter-block.hlsl index 0cc06cc10..9d3140475 100644 --- a/tests/compute/rewriter-parameter-block.hlsl +++ b/tests/compute/rewriter-parameter-block.hlsl @@ -1,7 +1,5 @@ //TEST(compute):HLSL_COMPUTE:-xslang -no-checking //TEST(compute):COMPARE_COMPUTE:-xslang -use-ir - -// Doesn't work with rewriter + IR yet. //DISABLED_TEST(compute):HLSL_COMPUTE:-xslang -no-checking -xslang -use-ir //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out diff --git a/tests/compute/rewriter-use-ir-type.hlsl b/tests/compute/rewriter-use-ir-type.hlsl new file mode 100644 index 000000000..8d388addf --- /dev/null +++ b/tests/compute/rewriter-use-ir-type.hlsl @@ -0,0 +1,24 @@ +//TEST(compute):HLSL_COMPUTE:-xslang -no-checking -xslang -use-ir + +//TEST_INPUT:cbuffer(data=[1 2 3 4 16 32 48 64]):dxbinding(0),glbinding(0) +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out + +import rewriter_use_ir_type; + +RWStructuredBuffer<int> outputBuffer : register(u0); + +cbuffer C : register(b0) +{ + Helper helper; +} + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + int inVal = tid; + + int outVal = helper.a[inVal]; + + outputBuffer[tid] = outVal; +}
\ No newline at end of file diff --git a/tests/compute/rewriter-use-ir-type.hlsl.expected.txt b/tests/compute/rewriter-use-ir-type.hlsl.expected.txt new file mode 100644 index 000000000..94ebaf900 --- /dev/null +++ b/tests/compute/rewriter-use-ir-type.hlsl.expected.txt @@ -0,0 +1,4 @@ +1 +2 +3 +4 diff --git a/tests/compute/rewriter-use-ir-type.slang b/tests/compute/rewriter-use-ir-type.slang new file mode 100644 index 000000000..8870f000f --- /dev/null +++ b/tests/compute/rewriter-use-ir-type.slang @@ -0,0 +1,6 @@ +//TEST_IGNORE_FILE: + +struct Helper +{ + int4 a; +};
\ No newline at end of file |
