diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/core/list.h | 2 | ||||
| -rw-r--r-- | source/slang/ast-legalize.cpp | 644 | ||||
| -rw-r--r-- | source/slang/ast-legalize.h | 9 | ||||
| -rw-r--r-- | source/slang/emit.cpp | 111 | ||||
| -rw-r--r-- | source/slang/ir-legalize-types.cpp | 164 | ||||
| -rw-r--r-- | source/slang/legalize-types.h | 103 |
6 files changed, 885 insertions, 148 deletions
diff --git a/source/core/list.h b/source/core/list.h index b1461a260..5a94d8b83 100644 --- a/source/core/list.h +++ b/source/core/list.h @@ -381,7 +381,7 @@ namespace Slang void Reverse() { - for (int i = 0; i < (_count >> 1); i++) + for (UInt i = 0; i < (_count >> 1); i++) { Swap(buffer, i, _count - i - 1); } diff --git a/source/slang/ast-legalize.cpp b/source/slang/ast-legalize.cpp index 144277301..2f7814055 100644 --- a/source/slang/ast-legalize.cpp +++ b/source/slang/ast-legalize.cpp @@ -4,6 +4,7 @@ #include "emit.h" #include "ir-insts.h" #include "legalize-types.h" +#include "mangle.h" #include "type-layout.h" #include "visitor.h" @@ -868,7 +869,7 @@ struct LoweringVisitor RefPtr<Expr> createUncheckedVarRef( - char const* name) + String const& name) { return createUncheckedVarRef( shared->compileRequest->getNamePool()->getName(name)); @@ -2672,6 +2673,105 @@ struct LoweringVisitor return translateDeclRefImpl(DeclRef<Decl>(decl, nullptr)); } + LegalExpr translateSimpleLegalValToLegalExpr(IRValue* irVal) + { + switch (irVal->op) + { + case kIROp_global_var: + { + IRGlobalVar* globalVar = (IRGlobalVar*)irVal; + String mangledName = globalVar->mangledName; + SLANG_ASSERT(mangledName.Length() != 0); + + return LegalExpr(createUncheckedVarRef(mangledName)); + } + break; + + default: + SLANG_UNEXPECTED("unhandled opcode"); + UNREACHABLE_RETURN(LegalExpr()); + } + } + + LegalExpr translateLegalValToLegalExpr(LegalVal legalVal) + { + switch (legalVal.flavor) + { + case LegalVal::Flavor::none: + return LegalExpr(); + + case LegalVal::Flavor::simple: + return translateSimpleLegalValToLegalExpr(legalVal.getSimple()); + break; + + case LegalVal::Flavor::pair: + { + auto pairVal = legalVal.getPair(); + RefPtr<PairPseudoExpr> pairExpr = new PairPseudoExpr(); + pairExpr->pairInfo = pairVal->pairInfo; + pairExpr->ordinary = translateLegalValToLegalExpr(pairVal->ordinaryVal); + pairExpr->special = translateLegalValToLegalExpr(pairVal->specialVal); + return LegalExpr(pairExpr); + } + break; + + case LegalVal::Flavor::tuple: + { + auto tupleVal = legalVal.getTuple(); + RefPtr<TuplePseudoExpr> tupleExpr = new TuplePseudoExpr(); + for (auto ee : tupleVal->elements) + { + TuplePseudoExpr::Element element; + element.fieldDeclRef = ee.fieldDeclRef; + element.expr = translateLegalValToLegalExpr(ee.val); + tupleExpr->elements.Add(element); + } + return LegalExpr(tupleExpr); + } + break; + + case LegalVal::Flavor::implicitDeref: + { + auto implicitDerefVal = legalVal.getImplicitDeref(); + RefPtr<ImplicitDerefPseudoExpr> implicitDerefExpr = new ImplicitDerefPseudoExpr(); + implicitDerefExpr->valueExpr = translateLegalValToLegalExpr(implicitDerefVal); + return LegalExpr(implicitDerefExpr); + } + break; + + default: + SLANG_UNEXPECTED("unhandled flavor"); + UNREACHABLE_RETURN(LegalExpr()); + } + } + + void maybeLegalizeIRGlobal( + DeclRef<Decl> declRef) + { + // We've been given a decl-ref to a value that was translated via IR, + // and we need to determine if it needs custom handling for legalization, + // because it was an IR global that got split. + + // TODO: this code is using decls in places it should use decl-refs, + // and that likely needs to get cleaned up... + auto decl = declRef.getDecl(); + + // If we already have an expression registered, then don't bother. + if (shared->mapOriginalDeclToExpr.ContainsKey(decl)) + return; + + String mangledName = getMangledName(declRef); + if (mangledName.Length() == 0) + return; + + LegalVal legalVal; + if (!shared->typeLegalizationContext->mapMangledNameToLegalIRValue.TryGetValue(mangledName, legalVal)) + return; + + LegalExpr legalExpr = translateLegalValToLegalExpr(legalVal); + shared->mapOriginalDeclToExpr.Add(decl, legalExpr); + } + RefPtr<Decl> translateDeclRefImpl( DeclRef<Decl> declRef) { @@ -2703,24 +2803,25 @@ struct LoweringVisitor auto parentModule = findModuleForDecl(decl); if (parentModule && (parentModule != shared->mainModuleDecl)) { - // Ensure that the IR code for the given declaration - // gets included in the output IR module, and *also* - // that we generate a suitable specialization of it - // if there are any substitutions in effect. - - getSpecializedGlobalValueForDeclRef( - shared->irSpecializationState, - declRef); + // This declaration should already have been lowered to + // the IR during the "walk" pass that happened earlier, + // and so we won't do it again here. + // + // Instead, we need to check if the particular + // declaration is one that needs to be swapped for + // a legalized expression (e.g., because it was an IR + // global that got split) + // + maybeLegalizeIRGlobal(declRef); // Remember that this declaration is handled via IR, // rather than being present in the legalized AST. shared->result.irDecls.Add(declRef.getDecl()); - // We don't actually use the `IRGlobalValue` that the - // above operation returns, and instead just keep - // using the original declaration in the legalized - // AST. The step of mapping that declaration over - // to reference the IR symbol will happen later. + // This method can't actually return a `LegalExpr`, + // so for now we just assume that the original + // declaration is the right stand-in for the IR + // value we want. return decl; } @@ -4671,7 +4772,8 @@ LoweredEntryPoint lowerEntryPoint( CodeGenTarget target, ExtensionUsageTracker* extensionUsageTracker, IRSpecializationState* irSpecializationState, - TypeLegalizationContext* typeLegalizationContext) + TypeLegalizationContext* typeLegalizationContext, + List<Decl*> astDecls) { SharedLoweringContext sharedContext; sharedContext.compileRequest = entryPoint->compileRequest; @@ -4732,10 +4834,21 @@ LoweredEntryPoint lowerEntryPoint( if (isRewrite) { + for (auto dd : astDecls) + { + // Skip non-global decls + if (!dd->ParentDecl) + continue; + if (!dynamic_cast<ModuleDecl*>(dd->ParentDecl)) + continue; + visitor.translateDeclRef(dd); + } +#if 0 for (auto dd : translationUnit->SyntaxNode->Members) { visitor.translateDeclRef(dd); } +#endif } else { @@ -4747,4 +4860,505 @@ LoweredEntryPoint lowerEntryPoint( return sharedContext.result; } + +struct FindIRDeclUsedByASTVisitor + : ExprVisitor<FindIRDeclUsedByASTVisitor, void> + , StmtVisitor<FindIRDeclUsedByASTVisitor, void> + , DeclVisitor<FindIRDeclUsedByASTVisitor, void> + , ValVisitor<FindIRDeclUsedByASTVisitor, void, void> + +{ + CompileRequest* compileRequest; + IRSpecializationState* irSpecializationState; + ModuleDecl* mainModuleDecl; + + // Declarations to be processed by the AST lowering pass + List<Decl*>* astDecls; + + HashSet<DeclBase*> seenDecls; + HashSet<DeclBase*> addedDecls; + + void walkType(Type* type) + { + if(!type) return; + + TypeVisitor::dispatch(type); + } + + void walkVal(Val* val) + { + if(!val) return; + + ValVisitor::dispatch(val); + } + + void walkExpr(Expr* expr) + { + if(!expr) return; + + ExprVisitor::dispatch(expr); + } + + void walkStmt(Stmt* stmt) + { + if(!stmt) return; + + StmtVisitor::dispatch(stmt); + } + + void walkSubst(Substitutions* subst) + { + if( auto genericSubst = dynamic_cast<GenericSubstitution*>(subst) ) + { + for( auto arg : genericSubst->args ) + { + walkVal(arg); + } + } + // TODO: handle other cases here + } + + void walkDeclRef(DeclRef<Decl> const& declRef) + { + Decl* decl = declRef.getDecl(); + if (!decl) return; + + // If this is a specialized declaration reference, then any + // of the arguments also need to be walked. + for(auto subst = declRef.substitutions; subst; subst = subst->outer) + { + walkSubst(subst); + } + + // If any parent of the declaration was in the stdlib, or + // is registered as a builtin, then skip it. + for (auto pp = decl; pp; pp = pp->ParentDecl) + { + if (pp->HasModifier<FromStdLibModifier>()) + return; + + if (pp->HasModifier<BuiltinModifier>()) + return; + } + + // If we are using the IR, and the declaration comes from + // an imported module (rather than the "rewrite-mode" module + // being translated), then we need to ensure that it gets lowered + // to IR instead. + if (compileRequest->compileFlags & SLANG_COMPILE_FLAG_USE_IR) + { + auto parentModule = findModuleForDecl(decl); + if (parentModule && (parentModule != mainModuleDecl)) + { + // Ensure that the IR code for the given declaration + // gets included in the output IR module, and *also* + // that we generate a suitable specialization of it + // if there are any substitutions in effect. + + getSpecializedGlobalValueForDeclRef( + irSpecializationState, + declRef); + + // TODO: we probably need to track this value... + + return; + } + } + + // If none of the above triggered, then we seemingly have + // a declaration from the current module, and we should + // add it to our work list so we can walk it too. + addDecl(decl); + } + + // Vals + + void visitIRProxyVal(IRProxyVal*) + {} + + void visitConstantIntVal(ConstantIntVal*) + {} + + void visitGenericParamIntVal(GenericParamIntVal* val) + { + walkDeclRef(val->declRef); + } + + void visitWitness(Witness*) + {} + + // Types + + void visitOverloadGroupType(OverloadGroupType*) + {} + + void visitInitializerListType(InitializerListType*) + {} + + void visitErrorType(ErrorType*) + {} + + void visitIRBasicBlockType(IRBasicBlockType*) + {} + + void visitDeclRefType(DeclRefType* type) + { + walkDeclRef(type->declRef); + } + + void visitGenericDeclRefType(GenericDeclRefType* type) + { + walkDeclRef(type->declRef); + } + + void visitNamedExpressionType(NamedExpressionType* type) + { + walkDeclRef(type->declRef); + } + + void visitFuncType(FuncType* type) + { + for( auto p : type->paramTypes ) + { + walkType(p); + } + walkType(type->resultType); + } + + void visitTypeType(TypeType* type) + { + walkType(type->type); + } + + void visitGroupSharedType(GroupSharedType* type) + { + walkType(type->valueType); + } + + void visitArrayExpressionType(ArrayExpressionType* type) + { + walkType(type->baseType); + walkVal(type->ArrayLength); + } + + // Exprs + + void visitVarExpr(VarExpr* expr) + { + walkDeclRef(expr->declRef); + } + + void visitMemberExpr(MemberExpr* expr) + { + walkExpr(expr->BaseExpression); + walkDeclRef(expr->declRef); + } + + void visitStaticMemberExpr(StaticMemberExpr* expr) + { + walkExpr(expr->BaseExpression); + walkDeclRef(expr->declRef); + } + + void visitOverloadedExpr(OverloadedExpr* expr) + { + walkExpr(expr->base); + + // TODO: need to walk the lookup result too + } + + void visitConstantExpr(ConstantExpr*) + {} + + void visitInitializerListExpr(InitializerListExpr* expr) + { + for(auto a : expr->args) + walkExpr(a); + } + + void visitAppExprBase(AppExprBase* expr) + { + walkExpr(expr->FunctionExpr); + for(auto a : expr->Arguments) + walkExpr(a); + } + + void visitAggTypeCtorExpr(AggTypeCtorExpr* expr) + { + walkType(expr->base); + for(auto a : expr->Arguments) + walkExpr(a); + } + + void visitSharedTypeExpr(SharedTypeExpr* expr) + { + walkType(expr->base); + } + + void visitAssignExpr(AssignExpr* expr) + { + walkExpr(expr->left); + walkExpr(expr->right); + } + + void visitIndexExpr(IndexExpr* expr) + { + walkExpr(expr->BaseExpression); + walkExpr(expr->IndexExpression); + } + + void visitSwizzleExpr(SwizzleExpr* expr) + { + walkExpr(expr->base); + } + + void visitDerefExpr(DerefExpr* expr) + { + walkExpr(expr->base); + } + + void visitParenExpr(ParenExpr* expr) + { + walkExpr(expr->base); + } + + void visitThisExpr(ThisExpr*) + {} + + // Stmts + + void visitSeqStmt(SeqStmt* stmt) + { + for( auto s : stmt->stmts ) + { + walkStmt(s); + } + } + + void visitBlockStmt(BlockStmt* stmt) + { + walkStmt(stmt->body); + } + + void visitUnparsedStmt(UnparsedStmt*) + {} + + void visitEmptyStmt(EmptyStmt*) + {} + + void visitDiscardStmt(DiscardStmt*) + {} + + void visitDeclStmt(DeclStmt* stmt) + { + addDecl(stmt->decl); + } + + void visitIfStmt(IfStmt* stmt) + { + walkExpr(stmt->Predicate); + walkStmt(stmt->PositiveStatement); + walkStmt(stmt->NegativeStatement); + } + + void visitSwitchStmt(SwitchStmt* stmt) + { + walkExpr(stmt->condition); + walkStmt(stmt->body); + } + + void visitCaseStmt(CaseStmt* stmt) + { + walkExpr(stmt->expr); + } + + void visitDefaultStmt(DefaultStmt*) + {} + + void visitForStmt(ForStmt* stmt) + { + walkStmt(stmt->InitialStatement); + walkExpr(stmt->SideEffectExpression); + walkExpr(stmt->PredicateExpression); + walkStmt(stmt->Statement); + } + + void visitWhileStmt(WhileStmt* stmt) + { + walkExpr(stmt->Predicate); + walkStmt(stmt->Statement); + } + + void visitDoWhileStmt(DoWhileStmt* stmt) + { + walkExpr(stmt->Predicate); + walkStmt(stmt->Statement); + } + + void visitCompileTimeForStmt(CompileTimeForStmt* stmt) + { + addDecl(stmt->varDecl); + walkExpr(stmt->rangeBeginExpr); + walkExpr(stmt->rangeEndExpr); + walkStmt(stmt->body); + } + + void visitReturnStmt(ReturnStmt* stmt) + { + walkExpr(stmt->Expression); + } + + void visitExpressionStmt(ExpressionStmt* stmt) + { + walkExpr(stmt->Expression); + } + + void visitJumpStmt(JumpStmt*) + {} + + // Decls + + void visitDeclGroup(DeclGroup* declGroup) + { + for( auto dd : declGroup->decls ) + { + addDecl(dd); + } + } + + void visitContainerDeclCommon(ContainerDecl* decl) + { + for( auto mm : decl->Members ) + { + addDecl(mm); + } + } + + void visitContainerDecl(ContainerDecl* decl) + { + visitContainerDeclCommon(decl); + } + + void visitVarDeclBase(VarDeclBase* decl) + { + walkType(decl->type); + walkExpr(decl->initExpr); + } + + void visitAggTypeDeclBase(AggTypeDeclBase* decl) + { + visitContainerDeclCommon(decl); + } + + void visitInheritanceDecl(InheritanceDecl* decl) + { + walkType(decl->base); + } + + void visitTypeDefDecl(TypeDefDecl* decl) + { + walkType(decl->type); + } + + void visitCallableDeclCommon(CallableDecl* decl) + { + visitContainerDeclCommon(decl); + walkType(decl->ReturnType); + } + + void visitCallableDecl(CallableDecl* decl) + { + visitCallableDeclCommon(decl); + } + + void visitFunctionDeclBase(FunctionDeclBase* decl) + { + visitCallableDeclCommon(decl); + walkStmt(decl->Body); + } + + void visitImportDecl(ImportDecl*) + {} + + void visitGenericTypeParamDecl(GenericTypeParamDecl*) + {} + + void visitGenericTypeConstraintDecl(GenericTypeConstraintDecl*) + {} + + void visitEmptyDecl(EmptyDecl*) + {} + + void visitSyntaxDecl(SyntaxDecl*) + {} + + // + + void addDecl(DeclBase* decl) + { + // Has this decl already been added + // to the output list? + if(addedDecls.Contains(decl)) + return; + + // Are we in the middel of processing this + // decl? + // + // TODO: this implies a cycle, and we need to + // break it! + if (seenDecls.Contains(decl)) + return; + + seenDecls.Add(decl); + + // Recurse on the given decl + DeclVisitor::dispatch(decl); + + // Add it to the output list, if needed + if (auto dd = dynamic_cast<Decl*>(decl)) + { + (*astDecls).Add(dd); + } + + // Mark it as completely processed + addedDecls.Add(decl); + } + + void flush() + { + } +}; + + +void findDeclsUsedByASTEntryPoint( + EntryPointRequest* entryPoint, + CodeGenTarget target, + IRSpecializationState* irSpecializationState, + List<Decl*>& outASTDecls) +{ + auto translationUnit = entryPoint->getTranslationUnit(); + auto mainModuleDecl = translationUnit->SyntaxNode; + + FindIRDeclUsedByASTVisitor visitor; + visitor.compileRequest = entryPoint->compileRequest; + visitor.irSpecializationState = irSpecializationState; + visitor.mainModuleDecl = mainModuleDecl; + visitor.astDecls = &outASTDecls; + + bool isRewrite = isRewriteRequest(translationUnit->sourceLanguage, target); + + if (isRewrite) + { + visitor.addDecl(mainModuleDecl); + } + else + { + visitor.addDecl(entryPoint->decl); + } + + visitor.flush(); +} + + + } diff --git a/source/slang/ast-legalize.h b/source/slang/ast-legalize.h index 23a150002..ab06d7a21 100644 --- a/source/slang/ast-legalize.h +++ b/source/slang/ast-legalize.h @@ -69,6 +69,13 @@ namespace Slang CodeGenTarget target, ExtensionUsageTracker* extensionUsageTracker, IRSpecializationState* irSpecializationState, - TypeLegalizationContext* typeLegalizationContext); + TypeLegalizationContext* typeLegalizationContext, + List<Decl*> astDecls); + + void findDeclsUsedByASTEntryPoint( + EntryPointRequest* entryPoint, + CodeGenTarget target, + IRSpecializationState* irSpecializationState, + List<Decl*>& outASTDecls); } #endif diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index bf7ad0c3a..ce17c8d03 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -100,7 +100,7 @@ struct SharedEmitContext Dictionary<IRValue*, UInt> mapIRValueToID; Dictionary<Decl*, UInt> mapDeclToID; - HashSet<Decl*> irDeclsVisited; + HashSet<String> irDeclsVisited; Dictionary<IRBlock*, IRBlock*> irMapContinueTargetToLoopHead; @@ -2415,7 +2415,8 @@ struct EmitVisitor } else { - emit(memberExpr->declRef.GetName()); + EmitDeclRef(memberExpr->declRef); +// emit(memberExpr->declRef.GetName()); } if(needClose) Emit(")"); @@ -2453,7 +2454,8 @@ struct EmitVisitor } else { - emit(memberExpr->declRef.GetName()); + EmitDeclRef(memberExpr->declRef); +// emit(memberExpr->declRef.GetName()); } if(needClose) Emit(")"); @@ -6892,15 +6894,11 @@ emitDeclImpl(decl, nullptr); EmitContext* ctx, DeclRef<StructDecl> declRef) { - // TODO: Eventually need to deal with the case where - // we have user-defined generic types. - // - auto decl = declRef.getDecl(); - - if(ctx->shared->irDeclsVisited.Contains(decl)) + auto mangledName = getMangledName(declRef); + if(ctx->shared->irDeclsVisited.Contains(mangledName)) return; - ctx->shared->irDeclsVisited.Add(decl); + ctx->shared->irDeclsVisited.Add(mangledName); // First emit any types used by fields of this type for( auto ff : GetFields(declRef) ) @@ -6935,6 +6933,25 @@ emitDeclImpl(decl, nullptr); Emit("};\n"); } + void emitIRUsedDeclRef( + EmitContext* ctx, + DeclRef<Decl> declRef) + { + auto decl = declRef.getDecl(); + + if(decl->HasModifier<BuiltinTypeModifier>() + || decl->HasModifier<MagicTypeModifier>()) + { + return; + } + + if( auto structDeclRef = declRef.As<StructDecl>() ) + { + // + ensureStructDecl(ctx, structDeclRef); + } + } + // A type is going to be used by the IR, so // make sure that we have emitted whatever // it needs. @@ -6970,19 +6987,7 @@ emitDeclImpl(decl, nullptr); else if( auto declRefType = type->As<DeclRefType>() ) { auto declRef = declRefType->declRef; - auto decl = declRef.getDecl(); - - if(decl->HasModifier<BuiltinTypeModifier>() - || decl->HasModifier<MagicTypeModifier>()) - { - return; - } - - if( auto structDeclRef = declRef.As<StructDecl>() ) - { - // - ensureStructDecl(ctx, structDeclRef); - } + emitIRUsedDeclRef(ctx, declRef); } else {} @@ -7249,13 +7254,21 @@ String emitEntryPoint( // boilerplate at the start of the ouput for GLSL (e.g., what // version we require). + List<Decl*> astDecls; + findDeclsUsedByASTEntryPoint( + entryPoint, + target, + nullptr, + astDecls); + auto lowered = lowerEntryPoint( entryPoint, programLayout, target, &sharedContext.extensionUsageTracker, nullptr, - &typeLegalizationContext); + &typeLegalizationContext, + astDecls); sharedContext.program = lowered.program; // Note that we emit the main body code of the program *before* @@ -7287,25 +7300,23 @@ String emitEntryPoint( typeLegalizationContext.irModule = irModule; - LoweredEntryPoint lowered; + List<Decl*> astDecls; if(translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING) { // We are in case (2b), where the main module is in unchecked // HLSL/GLSL that we need to "rewrite," and any library code // is in Slang that will need to be cross-compiled via the IR. - // Initially, we will apply the AST-to-AST pass to legalize - // the user's code, much like we would for any other target. - // Along the way, this pass will discover any IR declarations - // that we use, and try to emit code for them into our IR module. + // We first need to walk the AST part of the code to look + // for any places where it references declarations that + // are implemented in the IR, so that we can be sure to + // generate suitable IR code for them. - lowered = lowerEntryPoint( + findDeclsUsedByASTEntryPoint( entryPoint, - programLayout, target, - &sharedContext.extensionUsageTracker, irSpecializationState, - &typeLegalizationContext); + astDecls); } else { @@ -7358,6 +7369,33 @@ String emitEntryPoint( fprintf(stderr, "###\n"); #endif + LoweredEntryPoint lowered; + if(translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING) + { + // In the (2b) case, once we have legalized the IR code, + // we now need to go in and legalize the AST code. + // This order is important because when referring to a variable + // that is defined in the IR, we need to legalize it first (which + // might split it into many decls) before we can legalize an AST + // expression that references that decl (which will also need + // to get split). + // + // We don't have to worry about references in the other direction; + // we don't allow the user to define something in unchecked AST + // code and then use it from the IR shader library. + + lowered = lowerEntryPoint( + entryPoint, + programLayout, + target, + &sharedContext.extensionUsageTracker, + irSpecializationState, + &typeLegalizationContext, + astDecls); + } + + // When emitting IR-based declarations, we wnat to + // track which decls have already been lowered. sharedContext.irDeclSetForAST = &lowered.irDecls; // After all of the required optimization and legalization @@ -7372,6 +7410,13 @@ String emitEntryPoint( // that we need to output, we'll do it now. if (translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING) { + // First make sure that we've emitted any types that were declared + // in the IR, but then subsequently only used by the AST + for( auto decl : lowered.irDecls ) + { + visitor.emitIRUsedDeclRef(&context, makeDeclRef(decl)); + } + visitor.EmitDeclsInContainer(lowered.program); } diff --git a/source/slang/ir-legalize-types.cpp b/source/slang/ir-legalize-types.cpp index fedec4f87..5db3b4d33 100644 --- a/source/slang/ir-legalize-types.cpp +++ b/source/slang/ir-legalize-types.cpp @@ -17,78 +17,6 @@ namespace Slang { - - -struct LegalValImpl : RefObject -{ -}; -struct TuplePseudoVal; -struct PairPseudoVal; - -struct LegalVal -{ - enum class Flavor - { - none, - simple, - implicitDeref, - tuple, - pair, - }; - - Flavor flavor = Flavor::none; - RefPtr<RefObject> obj; - IRValue* irValue = nullptr; - - static LegalVal simple(IRValue* irValue) - { - LegalVal result; - result.flavor = Flavor::simple; - result.irValue = irValue; - return result; - } - - IRValue* getSimple() - { - assert(flavor == Flavor::simple); - return irValue; - } - - static LegalVal tuple(RefPtr<TuplePseudoVal> tupleVal); - - RefPtr<TuplePseudoVal> getTuple() - { - assert(flavor == Flavor::tuple); - return obj.As<TuplePseudoVal>(); - } - - static LegalVal implicitDeref(LegalVal const& val); - LegalVal getImplicitDeref(); - - static LegalVal pair(RefPtr<PairPseudoVal> pairInfo); - static LegalVal pair( - LegalVal const& ordinaryVal, - LegalVal const& specialVal, - RefPtr<PairInfo> pairInfo); - - RefPtr<PairPseudoVal> getPair() - { - assert(flavor == Flavor::pair); - return obj.As<PairPseudoVal>(); - } -}; - -struct TuplePseudoVal : LegalValImpl -{ - struct Element - { - DeclRef<VarDeclBase> fieldDeclRef; - LegalVal val; - }; - - List<Element> elements; -}; - LegalVal LegalVal::tuple(RefPtr<TuplePseudoVal> tupleVal) { LegalVal result; @@ -97,16 +25,6 @@ LegalVal LegalVal::tuple(RefPtr<TuplePseudoVal> tupleVal) return result; } -struct PairPseudoVal : LegalValImpl -{ - LegalVal ordinaryVal; - LegalVal specialVal; - - // The info to tell us which fields - // are on which side(s) - RefPtr<PairInfo> pairInfo; -}; - LegalVal LegalVal::pair(RefPtr<PairPseudoVal> pairInfo) { LegalVal result; @@ -135,11 +53,6 @@ LegalVal LegalVal::pair( return LegalVal::pair(obj); } -struct ImplicitDerefVal : LegalValImpl -{ - LegalVal val; -}; - LegalVal LegalVal::implicitDeref(LegalVal const& val) { RefPtr<ImplicitDerefVal> implicitDerefVal = new ImplicitDerefVal(); @@ -189,12 +102,37 @@ static void registerLegalizedValue( context->mapValToLegalVal.Add(irValue, legalVal); } +static void maybeRegisterLegalizedGlobal( + IRTypeLegalizationContext* context, + IRGlobalVar* irGlobalVar, + LegalVal const& legalVal) +{ + // Check the mangled name of the symbol and don't register + // symbols that don't have an external name (currently + // indicated by them having an empty name string). + String mangledName = irGlobalVar->mangledName; + if (mangledName.Length() == 0) + return; + + // Otherwise, register the legalized value for this symbol + // under its mangled name, so that other code can still + // find the right value(s) to use after legalization. + context->typeLegalizationContext->mapMangledNameToLegalIRValue.AddIfNotExists(mangledName, legalVal); +} + +struct IRGlobalNameInfo +{ + IRGlobalVar* globalVar; + UInt counter; +}; + static LegalVal declareVars( IRTypeLegalizationContext* context, IROp op, LegalType type, TypeLayout* typeLayout, - LegalVarChain* varChain); + LegalVarChain* varChain, + IRGlobalNameInfo* globalNameInfo); static LegalType legalizeType( IRTypeLegalizationContext* context, @@ -608,7 +546,7 @@ static LegalVal legalizeLocalVar( varChain = &varChainStorage; } - LegalVal newVal = declareVars(context, kIROp_Var, legalValueType, typeLayout, varChain); + LegalVal newVal = declareVars(context, kIROp_Var, legalValueType, typeLayout, varChain, nullptr); // Remove the old local var. irLocalVar->removeFromParent(); @@ -761,7 +699,7 @@ static void legalizeFunc( context->insertBeforeParam = pp; context->builder->curBlock = nullptr; - auto paramVal = declareVars(context, kIROp_Param, legalParamType, nullptr, nullptr); + auto paramVal = declareVars(context, kIROp_Param, legalParamType, nullptr, nullptr, nullptr); paramVals.Add(paramVal); if (pp == bb->getFirstParam()) { @@ -807,7 +745,8 @@ static LegalVal declareSimpleVar( IROp op, Type* type, TypeLayout* typeLayout, - LegalVarChain* varChain) + LegalVarChain* varChain, + IRGlobalNameInfo* globalNameInfo) { RefPtr<VarLayout> varLayout = createVarLayout(varChain, typeLayout); @@ -830,6 +769,25 @@ static LegalVal declareSimpleVar( globalVar->removeFromParent(); globalVar->insertBefore(context->insertBeforeGlobal, builder->getModule()); + // The legalization of a global variable with linkage (one that has + // a mangled name), must also have an exported name, so that code + // can link against it. + // + // For now we do something *really* simplistic, and just append + // a counter to each leaf variable generated from the original + if (globalNameInfo) + { + String mangledName = globalNameInfo->globalVar->mangledName; + if (mangledName.Length() != 0) + { + mangledName.append("L"); + mangledName.append(globalNameInfo->counter++); + globalVar->mangledName = mangledName; + } + } + + + irVar = globalVar; legalVarVal = LegalVal::simple(irVar); } @@ -887,7 +845,8 @@ static LegalVal declareVars( IROp op, LegalType type, TypeLayout* typeLayout, - LegalVarChain* varChain) + LegalVarChain* varChain, + IRGlobalNameInfo* globalNameInfo) { switch (type.flavor) { @@ -895,7 +854,7 @@ static LegalVal declareVars( return LegalVal(); case LegalType::Flavor::simple: - return declareSimpleVar(context, op, type.getSimple(), typeLayout, varChain); + return declareSimpleVar(context, op, type.getSimple(), typeLayout, varChain, globalNameInfo); break; case LegalType::Flavor::implicitDeref: @@ -908,7 +867,8 @@ static LegalVal declareVars( op, type.getImplicitDeref()->valueType, getDerefTypeLayout(typeLayout), - varChain); + varChain, + globalNameInfo); return LegalVal::implicitDeref(val); } break; @@ -916,8 +876,8 @@ static LegalVal declareVars( case LegalType::Flavor::pair: { auto pairType = type.getPair(); - auto ordinaryVal = declareVars(context, op, pairType->ordinaryType, typeLayout, varChain); - auto specialVal = declareVars(context, op, pairType->specialType, typeLayout, varChain); + auto ordinaryVal = declareVars(context, op, pairType->ordinaryType, typeLayout, varChain, globalNameInfo); + auto specialVal = declareVars(context, op, pairType->specialType, typeLayout, varChain, globalNameInfo); return LegalVal::pair(ordinaryVal, specialVal, pairType->pairInfo); } @@ -951,7 +911,8 @@ static LegalVal declareVars( op, ee.type, fieldTypeLayout, - newVarChain); + newVarChain, + globalNameInfo); TuplePseudoVal::Element element; element.fieldDeclRef = ee.fieldDeclRef; @@ -1003,11 +964,18 @@ static void legalizeGlobalVar( varChain = &varChainStorage; } - LegalVal newVal = declareVars(context, kIROp_global_var, legalValueType, typeLayout, varChain); + IRGlobalNameInfo globalNameInfo; + globalNameInfo.globalVar = irGlobalVar; + globalNameInfo.counter = 0; + + LegalVal newVal = declareVars(context, kIROp_global_var, legalValueType, typeLayout, varChain, &globalNameInfo); // Register the new value as the replacement for the old registerLegalizedValue(context, irGlobalVar, newVal); + // Also register the variable according to its mangled name, if any. + maybeRegisterLegalizedGlobal(context, irGlobalVar, newVal); + // Remove the old global from the module. irGlobalVar->removeFromParent(); // TODO: actually clean up the global! diff --git a/source/slang/legalize-types.h b/source/slang/legalize-types.h index 36e4223b6..853b9f47f 100644 --- a/source/slang/legalize-types.h +++ b/source/slang/legalize-types.h @@ -244,6 +244,106 @@ RefPtr<VarLayout> createVarLayout( TypeLayout* typeLayout); // +// The result of legalizing an IR value will be +// represented with the `LegalVal` type. It is exposed +// in this header (rather than kept as an implementation +// detail, because the AST-based legalization logic needs +// a way to find the post-legalization version of a +// global name). +// +// TODO: We really shouldn't have this structure exposed, +// and instead should really be constructing AST-side +// `LegalExpr` values on-demand whenever we legalize something +// in the IR that will need to be used by the AST, and then +// store *those* in a map indexed in mangled names. +// + +struct LegalValImpl : RefObject +{ +}; +struct TuplePseudoVal; +struct PairPseudoVal; + +struct LegalVal +{ + enum class Flavor + { + none, + simple, + implicitDeref, + tuple, + pair, + }; + + Flavor flavor = Flavor::none; + RefPtr<RefObject> obj; + IRValue* irValue = nullptr; + + static LegalVal simple(IRValue* irValue) + { + LegalVal result; + result.flavor = Flavor::simple; + result.irValue = irValue; + return result; + } + + IRValue* getSimple() + { + assert(flavor == Flavor::simple); + return irValue; + } + + static LegalVal tuple(RefPtr<TuplePseudoVal> tupleVal); + + RefPtr<TuplePseudoVal> getTuple() + { + assert(flavor == Flavor::tuple); + return obj.As<TuplePseudoVal>(); + } + + static LegalVal implicitDeref(LegalVal const& val); + LegalVal getImplicitDeref(); + + static LegalVal pair(RefPtr<PairPseudoVal> pairInfo); + static LegalVal pair( + LegalVal const& ordinaryVal, + LegalVal const& specialVal, + RefPtr<PairInfo> pairInfo); + + RefPtr<PairPseudoVal> getPair() + { + assert(flavor == Flavor::pair); + return obj.As<PairPseudoVal>(); + } +}; + +struct TuplePseudoVal : LegalValImpl +{ + struct Element + { + DeclRef<VarDeclBase> fieldDeclRef; + LegalVal val; + }; + + List<Element> elements; +}; + +struct PairPseudoVal : LegalValImpl +{ + LegalVal ordinaryVal; + LegalVal specialVal; + + // The info to tell us which fields + // are on which side(s) + RefPtr<PairInfo> pairInfo; +}; + +struct ImplicitDerefVal : LegalValImpl +{ + LegalVal val; +}; + +// struct TypeLegalizationContext { @@ -280,6 +380,9 @@ struct TypeLegalizationContext /// emitting declarations of legalized `struct` types /// multiple times). Dictionary<DeclRef<Decl>, LegalType> mapDeclRefToLegalType; + + // + Dictionary<String, LegalVal> mapMangledNameToLegalIRValue; }; |
