diff options
| author | Tim Foley <tfoleyNV@users.noreply.github.com> | 2017-11-28 19:49:21 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2017-11-28 19:49:21 -0800 |
| commit | 713938038a87b9e4a69f198f09f1bf231be6f72f (patch) | |
| tree | ec3ecd8d1dca5ea6779ac0dd30daab6c7a9f672d | |
| parent | 49510035d52c12d9c63f7b04ea748764282a9b01 (diff) | |
Enable HLSL/GLSL "rewrite" + IR-based Slang codegen (#300)
The big picture here is that the AST-to-AST pass in `ast-legalize` will now detect when a declaration being referenced comes from an `import`ed module, and (if IR codegen is enabled), it will trigger cloning of the IR for the chosen symbol into an IR module that will sit alongside the legalized AST.
Then, during HLSL/GLSL code emit, we emit all the IR-based code first, and then the AST-based code. Whenever the AST code references a symbol that was lowered via IR (we keep track of these) we emit the mangled name of the IR symbol.
Notes/details:
- A lot of the logic for cloning IR symbols referenced by the AST matches the same logic that would clone them for completely IR-based codegen, so I tried to hoist out the common logic and share it (e.g., so that we apply the same guaranteed transformations in both cases). This required basically rewriting the logic in `emit.cpp` that decomposed the various cases.
- There is a new compute test case added to test this functionality. `tests/compute/rewriter.hlsl` confirms that we can use the `-no-checking` mode for the HLSL code, but still make use of a library of Slang code that employs generics, etc.
- Adding this test case required adding a new compute test mode that invokes `render-test` with the `-hlsl-rewrite` flag.
- It turns out that the existing `tests/render/cross-compile0.hlsl` test should have been using this functionality already. It was opting into the use of the IR via `-use-ir`, and the `render-test` application already tries to set `-no-checking` for non-Slang input languages by default. Fixing the code path this test triggers means that it is now a second test of rewriter+IR codegen.
- The `translateDeclRef` logic in `ast-legalize.cpp` seemed sloppy in places, and would potentially clone declarations, when declaration references were desired. I tried to clean a bit of this up, so some call sites are now changed.
- This change tries to clean up some work around cloning of global values
- All global value kinds (not just functions) now go through the logic of trying to pick a "best" definition, so that they can be used when we are linking multiple modules
- The logic for registering cloned values has been unified a bit, so that clients always pass in an `IROriginalValuesForClone` that either wraps a single value (maybe just null), or an `IRSpecSymbol*` that gives a list of values to regsiter the new value as a clone for.
- I made one piece of code that was cloning witness tables as part of generic specializations *not* register a clone. I think this is correct because we may specialize the same generic multiple ways, so registering any values we clone is not the right idea, but I might be missing something...
- I also reorganized this logic so that it would be easier to clone a global value when we only know its mangled name (which is the case when it is the AST that triggers cloning)
- I made sure that when loading a module via `import`, the translation unit for the new module copies the `-use-ir` flag from the overall compile request, if it is present (otherwise we wouldn't generate IR for loaded modules at all... oops).
- Note that `getSpecializedGlobalValueForDeclRef()`, which is the main routine used by the AST legalization to trigger cloning of an IR value does *not* currently handle declaration references that require specialization.
- This change does *not* deal with trying to unify the type legalization logic between the AST-to-AST rewriter and the IR-based codegen, so if you call an imported function with types that require legalization, Bad Things are expected to happen right now.
| -rw-r--r-- | source/slang/ast-legalize.cpp | 100 | ||||
| -rw-r--r-- | source/slang/ast-legalize.h | 17 | ||||
| -rw-r--r-- | source/slang/emit.cpp | 217 | ||||
| -rw-r--r-- | source/slang/ir-insts.h | 32 | ||||
| -rw-r--r-- | source/slang/ir.cpp | 319 | ||||
| -rw-r--r-- | source/slang/ir.h | 1 | ||||
| -rw-r--r-- | source/slang/options.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 13 | ||||
| -rw-r--r-- | tests/compute/rewriter.hlsl | 19 | ||||
| -rw-r--r-- | tests/compute/rewriter.hlsl.expected.txt | 4 | ||||
| -rw-r--r-- | tests/compute/rewriter.slang | 30 | ||||
| -rw-r--r-- | tools/render-test/slang-support.cpp | 17 | ||||
| -rw-r--r-- | tools/slang-test/main.cpp | 7 |
13 files changed, 586 insertions, 196 deletions
diff --git a/source/slang/ast-legalize.cpp b/source/slang/ast-legalize.cpp index 8e4da2717..afc8b31c8 100644 --- a/source/slang/ast-legalize.cpp +++ b/source/slang/ast-legalize.cpp @@ -2,6 +2,7 @@ #include "ast-legalize.h" #include "emit.h" +#include "ir-insts.h" #include "type-layout.h" #include "visitor.h" @@ -434,6 +435,10 @@ struct SharedLoweringContext CompileRequest* compileRequest; EntryPointRequest* entryPointRequest; + // The "main" module that is being translated (as opposed + // to any of the modules that might have been imported). + ModuleDecl* mainModuleDecl; + ExtensionUsageTracker* extensionUsageTracker; ProgramLayout* programLayout; @@ -463,6 +468,12 @@ struct SharedLoweringContext bool isRewrite = false; bool requiresCopyGLPositionToPositionPerView = false; + + // State for lowering imported declarations to IR as needed + IRSpecializationState* irSpecializationState = nullptr; + + // The actual result we want to return + LoweredEntryPoint result; }; static void attachLayout( @@ -2123,7 +2134,7 @@ struct LoweringVisitor RefPtr<ScopeStmt> loweredStmt, RefPtr<ScopeStmt> originalStmt) { - loweredStmt->scopeDecl = translateDeclRef(originalStmt->scopeDecl).As<ScopeDecl>(); + loweredStmt->scopeDecl = translateDeclRef(originalStmt->scopeDecl).getDecl()->As<ScopeDecl>(); LoweringVisitor subVisitor = *this; subVisitor.isBuildingStmt = true; @@ -2286,7 +2297,7 @@ struct LoweringVisitor ScopeStmt* originalStmt) { lowerStmtFields(loweredStmt, originalStmt); - loweredStmt->scopeDecl = translateDeclRef(originalStmt->scopeDecl).As<ScopeDecl>(); + loweredStmt->scopeDecl = translateDeclRef(originalStmt->scopeDecl).getDecl()->As<ScopeDecl>(); } // Child statements reference their parent statement, @@ -2586,7 +2597,7 @@ struct LoweringVisitor if (auto genSubst = dynamic_cast<GenericSubstitution*>(inSubstitutions)) { RefPtr<GenericSubstitution> result = new GenericSubstitution(); - result->genericDecl = translateDeclRef(genSubst->genericDecl).As<GenericDecl>(); + result->genericDecl = translateDeclRef(genSubst->genericDecl).getDecl()->As<GenericDecl>(); for (auto arg : genSubst->args) { result->args.Add(translateVal(arg)); @@ -2612,17 +2623,36 @@ struct LoweringVisitor } LoweredDeclRef translateDeclRef( - DeclRef<Decl> const& decl) + DeclRef<Decl> const& declRef) { LoweredDeclRef result; - result.decl = translateDeclRef(decl.decl); - result.substitutions = translateSubstitutions(decl.substitutions); + result.decl = translateDeclRefImpl(declRef); + result.substitutions = translateSubstitutions(declRef.substitutions); return result; } LoweredDecl translateDeclRef( - Decl* decl) + Decl* decl) { + return translateDeclRefImpl(DeclRef<Decl>(decl, nullptr)); + } + + // Try to find the module that (recursively) contains a given declaration. + ModuleDecl* findModuleForDecl( + Decl* decl) + { + for (auto dd = decl; dd; dd = dd->ParentDecl) + { + if (auto moduleDecl = dynamic_cast<ModuleDecl*>(dd)) + return moduleDecl; + } + return nullptr; + } + + LoweredDecl translateDeclRefImpl( + DeclRef<Decl> declRef) + { + Decl* decl = declRef.getDecl(); if (!decl) return LoweredDecl(); // We don't want to translate references to built-in declarations, @@ -2641,6 +2671,38 @@ struct LoweringVisitor if (getModifiedDecl(decl)->HasModifier<BuiltinModifier>()) return decl; + // If we are using the IR, and the declaration comes from + // an imported module (rather than the "rewrite-mode" module + // being translated), then we need to ensure that it gets lowered + // to IR instead. + if (shared->compileRequest->compileFlags & SLANG_COMPILE_FLAG_USE_IR) + { + auto parentModule = findModuleForDecl(decl); + if (parentModule && (parentModule != shared->mainModuleDecl)) + { + // Ensure that the IR code for the given declaration + // gets included in the output IR module, and *also* + // that we generate a suitable specialization of it + // if there are any substitutions in effect. + + getSpecializedGlobalValueForDeclRef( + shared->irSpecializationState, + declRef); + + // Remember that this declaration is handled via IR, + // rather than being present in the legalized AST. + shared->result.irDecls.Add(declRef.getDecl()); + + // We don't actually use the `IRGlobalValue` that the + // above operation returns, and instead just keep + // using the original declaration in the legalized + // AST. The step of mapping that declaration over + // to reference the IR symbol will happen later. + + return decl; + } + } + LoweredDecl loweredDecl; if (shared->loweredDecls.TryGetValue(decl, loweredDecl)) return loweredDecl; @@ -2649,10 +2711,10 @@ struct LoweringVisitor return lowerDecl(decl); } - RefPtr<ContainerDecl> translateDeclRef( - ContainerDecl* decl) + DeclRef<ContainerDecl> translateDeclRef( + DeclRef<ContainerDecl> declRef) { - return translateDeclRef((Decl*)decl).getDecl()->As<ContainerDecl>(); + return translateDeclRef(declRef).As<ContainerDecl>(); } LoweredDecl lowerDeclBase( @@ -2759,9 +2821,9 @@ struct LoweringVisitor { RefPtr<Decl> loweredParent; if (auto genericParentDecl = decl->ParentDecl->As<GenericDecl>()) - loweredParent = translateDeclRef(genericParentDecl->ParentDecl); + loweredParent = translateDeclRef(genericParentDecl->ParentDecl).getDecl(); else - loweredParent = translateDeclRef(decl->ParentDecl); + loweredParent = translateDeclRef(decl->ParentDecl).getDecl(); if (loweredParent) { auto layoutMod = loweredParent->FindModifier<ComputedLayoutModifier>(); @@ -3518,7 +3580,7 @@ struct LoweringVisitor if (auto parentModuleDecl = pp.As<ModuleDecl>()) { LoweringVisitor subVisitor = *this; - subVisitor.parentDecl = translateDeclRef(parentModuleDecl); + subVisitor.parentDecl = translateDeclRef(parentModuleDecl).getDecl()->As<ContainerDecl>(); subVisitor.isBuildingStmt = false; return subVisitor.lowerVarDeclCommonInner(decl, loweredDeclClass); @@ -4659,7 +4721,8 @@ LoweredEntryPoint lowerEntryPoint( EntryPointRequest* entryPoint, ProgramLayout* programLayout, CodeGenTarget target, - ExtensionUsageTracker* extensionUsageTracker) + ExtensionUsageTracker* extensionUsageTracker, + IRSpecializationState* irSpecializationState) { SharedLoweringContext sharedContext; sharedContext.compileRequest = entryPoint->compileRequest; @@ -4667,8 +4730,10 @@ LoweredEntryPoint lowerEntryPoint( sharedContext.programLayout = programLayout; sharedContext.target = target; sharedContext.extensionUsageTracker = extensionUsageTracker; + sharedContext.irSpecializationState = irSpecializationState; auto translationUnit = entryPoint->getTranslationUnit(); + sharedContext.mainModuleDecl = translationUnit->SyntaxNode; // Create a single module/program to hold all the lowered code // (with the exception of instrinsic/stdlib declarations, which @@ -4711,7 +4776,6 @@ LoweredEntryPoint lowerEntryPoint( sharedContext.entryPointLayout = visitor.findEntryPointLayout(entryPoint); - LoweredEntryPoint result; if (isRewrite) { for (auto dd : translationUnit->SyntaxNode->Members) @@ -4722,11 +4786,11 @@ LoweredEntryPoint lowerEntryPoint( else { auto loweredEntryPoint = visitor.lowerEntryPoint(entryPoint); - result.entryPoint = loweredEntryPoint; + sharedContext.result.entryPoint = loweredEntryPoint; } - result.program = sharedContext.loweredProgram; + sharedContext.result.program = sharedContext.loweredProgram; - return result; + return sharedContext.result; } } diff --git a/source/slang/ast-legalize.h b/source/slang/ast-legalize.h index 071ff6c51..9046e8df0 100644 --- a/source/slang/ast-legalize.h +++ b/source/slang/ast-legalize.h @@ -37,11 +37,12 @@ namespace Slang { - class EntryPointRequest; - class ProgramLayout; - class TranslationUnitRequest; - + class EntryPointRequest; struct ExtensionUsageTracker; + struct IRSpecializationState; + class ProgramLayout; + class TranslationUnitRequest; + struct LoweredEntryPoint { @@ -52,6 +53,11 @@ namespace Slang // contains the entry point and any // other declarations it uses RefPtr<ModuleDecl> program; + + // A set of declarations that are not present + // in the generated AST, and are instead stored + // in the companion IR module + HashSet<Decl*> irDecls; }; // Emit code for a single entry point, based on @@ -60,6 +66,7 @@ namespace Slang EntryPointRequest* entryPoint, ProgramLayout* programLayout, CodeGenTarget target, - ExtensionUsageTracker* extensionUsageTracker); + ExtensionUsageTracker* extensionUsageTracker, + IRSpecializationState* irSpecializationState); } #endif diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index d6f1f8e1a..4a084c714 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -102,6 +102,9 @@ struct SharedEmitContext Dictionary<IRBlock*, IRBlock*> irMapContinueTargetToLoopHead; HashSet<String> irTupleTypes; + + // Map used to tell AST lowering what decls are represented by IR. + HashSet<Decl*>* irDeclSetForAST = nullptr; }; struct EmitContext @@ -2942,6 +2945,19 @@ struct EmitVisitor void EmitDeclRef(DeclRef<Decl> declRef) { + // Are we emitting an AST in a context where some declarations + // are actually stored as IR code? + if(auto irDeclSet = context->shared->irDeclSetForAST) + { + Decl* decl = declRef.getDecl(); + if(irDeclSet->Contains(decl)) + { + emit(getMangledName(declRef)); + return; + } + } + + // TODO: need to qualify a declaration name based on parent scopes/declarations // Emit the name for the declaration itself @@ -6870,55 +6886,142 @@ String emitEntryPoint( // Depending on how the compiler was invoked, we may need to perform // some amount of preocessing on the code before we can emit it. // - // For our purposes, there are basically three different "modes" we - // care about: + // We try to partition the cases we need to handle into a few broad + // categories, each of which is reflected as a different code path + // below: + // + // 1. "Full rewriter" mode, where the user provides HLSL/GLSL, opts + // out of semantic checking, and doesn't make use of any Slang + // code via `import`. + // + // 2. "Partial rewriter" modes, where the user starts with HLSL/GLSL + // and opts out of checking for that code, but also imports some + // Slang code which may need cross-compilation. They may also + // need us to rewrite the AST for some of their HLSL/GLSL function + // bodies to make things work. This actually has two main sub-modes: // - // 1. "Full rewriter" mode, where the user provides HLSL/GLSL, and - // doesn't make use of any Slang code via `import`. + // a) "Without IR." If the user doesn't opt into using the IR, then + // the imported Slang code gets translated to the target languge + // via the same AST-to-AST pass that legalized the user's code. This + // mode will eventually go away, but it is the main one used right now. // - // 2. "Partial rewriter" mode, where the user starts with HLSL/GLSL, - // but also imports some Slang code, and may need us to rewrite - // their HLSL/GLSL function bodies to make things work. + // b) "With IR." If the user opts into using the IR, then we need to + // apply the AST-to-AST pass to their HLSL/GLSL code, but *also* use + // the IR to compile everything else. // - // 3. "Full" mode, where all of the input code is in Slang (and/or - // the subset of HLSL we can fully type-check). + // 3. "Full IR" mode, where we can assume all the input code is in Slang + // (or the subset of HLSL we understand) that has undergone full + // semantic checking, and the user has opted into using the IR. // - // We'll try to detect the cases here: + // We'll try to detect the cases here, starting with case (1): // - if((translationUnit->compileRequest->compileFlags & SLANG_COMPILE_FLAG_USE_IR) - && !(translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING )) + if ((translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING) + && translationUnit->compileRequest->loadedModulesList.Count() == 0) { - // This seems to be case (3), because the user is asking for full - // checking, and so we can assume we understand the code fully. + // The user has opted out of semantic checking for their own code + // (in the "main" module), and also hasn't `import`ed any Slang + // modules that would require cross-compilation. // - // The IR code for the module should already have been generated, - // so that we "just" need to specialize it as needed for the - // specific target and entry point in use. + // Our goal in this mode is to print out the AST we parsed and + // hopefully reproduce something as close to the original as possible. // - // The first pass is to extract the IR code of the entry point, - // and any other symbols it references. At the same time, - // we go ahead and select the target-specific version of - // any such functions if they are available. We also go - // ahead and apply the layout information (from `programLayout`) - // to the IR code (which previously had no layout). + // The only deviation we *want* from the original code is that we will + // add new parameter binding annotations. + + sharedContext.program = translationUnitSyntax; + visitor.EmitDeclsInContainerUsingLayout( + translationUnitSyntax, + globalStructLayout); + } + // + // Next we will check for case (2a): + else if (!(translationUnit->compileRequest->compileFlags & SLANG_COMPILE_FLAG_USE_IR)) + { + // This case means the user has opted out of using the IR (so we can't use the + // cases below), but they either turned on semantic checking *or* imported some + // Slang code, so they can't use the case above. // - // Note: it is important that we extract a *copy* of all the - // relevant IR, so that transformations we make for one - // entry point (or target) don't mess up the IR used for other - // entry points (targets). + // Note: This case should go away completely once the IR is able to be relied + // upon for all cross-compilation scenarios. + + // We will apply our AST-to-AST legalization pass before we emit + // any code, and we will emit code for the AST that comes out + // of this pass instead of the original. + + // We perform legalization of the program before emitting *anything*, + // because the lowering process might change how we emit some + // boilerplate at the start of the ouput for GLSL (e.g., what + // version we require). + + auto lowered = lowerEntryPoint( + entryPoint, + programLayout, + target, + &sharedContext.extensionUsageTracker, + nullptr); + sharedContext.program = lowered.program; + + // Note that we emit the main body code of the program *before* + // we emit any leading preprocessor directives for GLSL. + // This is to give the emit logic a change to make last-minute + // adjustments like changing the required GLSL version. // - auto lowered = specializeIRForEntryPoint( + // TODO: All such adjustments would be better handled during + // lowering, but that requires having a semantic rather than + // textual format for the HLSL->GLSL mapping. + visitor.EmitDeclsInContainer(lowered.program.Ptr()); + } + // + // The remaining cases all require the use of our IR, and so there + // are certain steps that need to be shared. + else + { + // We are going to create a fresh IR module that we will use to + // clone any code needed by the user's entry point. + IRSpecializationState* irSpecializationState = createIRSpecializationState( entryPoint, programLayout, - target, + target, targetRequest); + IRModule* irModule = getIRModule(irSpecializationState); + + LoweredEntryPoint lowered; + if(translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING) + { + // We are in case (2b), where the main module is in unchecked + // HLSL/GLSL that we need to "rewrite," and any library code + // is in Slang that will need to be cross-compiled via the IR. + + // Initially, we will apply the AST-to-AST pass to legalize + // the user's code, much like we would for any other target. + // Along the way, this pass will discover any IR declarations + // that we use, and try to emit code for them into our IR module. + + lowered = lowerEntryPoint( + entryPoint, + programLayout, + target, + &sharedContext.extensionUsageTracker, + irSpecializationState); + } + else + { + // We are in case (3), where all of the code is in Slang, and + // has already been lowered to IR as part of the front-end + // compilation work. We thus start by cloning any code needed + // by the entry point over to our fresh IR module. + + specializeIRForEntryPoint( + irSpecializationState, + entryPoint); + } // If the user specified the flag that they want us to dump // IR, then do it here, for the target-specific, but // un-specialized IR. if (translationUnit->compileRequest->shouldDumpIR) { - dumpIR(lowered); + dumpIR(irModule); } // Next, we need to ensure that the code we emit for @@ -6927,7 +7030,7 @@ String emitEntryPoint( // none of our target supports generics, or interfaces, // so we need to specialize those away. // - specializeGenerics(lowered); + specializeGenerics(irModule); // Debugging code for IR transformations... #if 0 @@ -6941,7 +7044,7 @@ String emitEntryPoint( // we need to ensure that the code only uses types // that are legal on the chosen target. // - legalizeTypes(lowered); + legalizeTypes(irModule); // Debugging output of legalization #if 0 @@ -6950,50 +7053,24 @@ String emitEntryPoint( fprintf(stderr, "###\n"); #endif + // After all of the required optimization and legalization + // passes have been performed, we can emit target code from + // the IR module. + // // TODO: do we want to emit directly from IR, or translate the // IR back into AST for emission? + visitor.emitIRModule(&context, irModule); - visitor.emitIRModule(&context, lowered); + // If we are in case (2b) and the user *also* has AST-based code + // that we need to output, we'll do it now. + if (translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING) + { + sharedContext.irDeclSetForAST = &lowered.irDecls; + visitor.EmitDeclsInContainer(lowered.program); + } // TODO: need to clean up the IR module here } - else if(!(translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING ) || - translationUnit->compileRequest->loadedModulesList.Count() != 0) - { - // The user has `import`ed some Slang modules, and so we are in case (2) - // - // We need to apply a "rewriting" pass to the code the user wrote, - // and then emit the result. - - // We perform lowering of the program before emitting *anything*, - // because the lowering process might change how we emit some - // boilerplate at the start of the ouput for GLSL (e.g., what - // version we require). - auto lowered = lowerEntryPoint(entryPoint, programLayout, target, &sharedContext.extensionUsageTracker); - sharedContext.program = lowered.program; - - // Note that we emit the main body code of the program *before* - // we emit any leading preprocessor directives for GLSL. - // This is to give the emit logic a change to make last-minute - // adjustments like changing the required GLSL version. - // - // TODO: All such adjustments would be better handled during - // lowering, but that requires having a semantic rather than - // textual format for the HLSL->GLSL mapping. - visitor.EmitDeclsInContainer(lowered.program.Ptr()); - } - else - { - // We are in case (1). - // - // We should be able to just emit the AST we parsed right back out, - // along with whatever annotations we added along the way. - - sharedContext.program = translationUnitSyntax; - visitor.EmitDeclsInContainerUsingLayout( - translationUnitSyntax, - globalStructLayout); - } String code = sharedContext.sb.ProduceString(); sharedContext.sb.Clear(); diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h index b95aea2fe..b5e1d5a1d 100644 --- a/source/slang/ir-insts.h +++ b/source/slang/ir-insts.h @@ -559,19 +559,45 @@ struct IRBuilder IRLayoutDecoration* addLayoutDecoration(IRValue* value, Layout* layout); }; -// Generate a clone of an IR module that is specialized for -// a particular entry point, target, etc. -IRModule* specializeIRForEntryPoint( +// + +// Interface to IR specialization for use when cloning target-specific +// IR as part of compiling an entry point. +// +// TODO: we really need to move all of this logic to its own files. + +// `IRSpecializationState` is used as an opaque type to wrap up all +// the data needed to perform IR specialization, without exposing +// implementation details. +struct IRSpecializationState; +IRSpecializationState* createIRSpecializationState( EntryPointRequest* entryPointRequest, ProgramLayout* programLayout, CodeGenTarget target, TargetRequest* targetReq); +void destroyIRSpecializationState(IRSpecializationState* state); +IRModule* getIRModule(IRSpecializationState* state); + +IRGlobalValue* getSpecializedGlobalValueForDeclRef( + IRSpecializationState* state, + DeclRef<Decl> const& declRef); + +// Clone the IR values reachable from the given entry point +// into the IR module assocaited with the specialization state. +// When multiple definitions of a symbol are found, the one +// that is best specialized for the given `targetReq` will be +// used. +void specializeIRForEntryPoint( + IRSpecializationState* state, + EntryPointRequest* entryPointRequest); // Find suitable uses of the `specialize` instruction that // can be replaced with references to specialized functions. void specializeGenerics( IRModule* module); +// + } #endif diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index ae7b71172..3a8aabd85 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -3041,9 +3041,40 @@ namespace Slang IRValue* clonedValue, IRValue* originalValue) { + if(!originalValue) + return; context->getClonedValues().Add(originalValue, clonedValue); } + // Information on values to use when registering a cloned value + struct IROriginalValuesForClone + { + IRValue* originalVal = nullptr; + IRSpecSymbol* sym = nullptr; + + IROriginalValuesForClone() {} + + IROriginalValuesForClone(IRValue* originalValue) + : originalVal(originalValue) + {} + + IROriginalValuesForClone(IRSpecSymbol* symbol) + : sym(symbol) + {} + }; + + void registerClonedValue( + IRSpecContextBase* context, + IRValue* clonedValue, + IROriginalValuesForClone const& originalValues) + { + registerClonedValue(context, clonedValue, originalValues.originalVal); + for( auto s = originalValues.sym; s; s = s->nextWithSameName ) + { + registerClonedValue(context, clonedValue, s->irGlobalValue); + } + } + void cloneDecorations( IRSpecContextBase* context, IRValue* clonedValue, @@ -3100,9 +3131,7 @@ namespace Slang }; - IRGlobalVar* cloneGlobalVar(IRSpecContext* context, IRGlobalVar* originalVar); - IRFunc* cloneFunc(IRSpecContext* context, IRFunc* originalFunc); - IRWitnessTable* cloneWitnessTable(IRSpecContext* context, IRWitnessTable* originalVar); + IRGlobalValue* cloneGlobalValue(IRSpecContext* context, IRGlobalValue* originalVal); RefPtr<Substitutions> cloneSubstitutions( IRSpecContext* context, Substitutions* subst); @@ -3117,16 +3146,9 @@ namespace Slang switch (originalValue->op) { case kIROp_global_var: - return cloneGlobalVar(this, (IRGlobalVar*)originalValue); - break; - case kIROp_Func: - return cloneFunc(this, (IRFunc*)originalValue); - break; - case kIROp_witness_table: - return cloneWitnessTable(this, (IRWitnessTable*)originalValue); - break; + return cloneGlobalValue(this, (IRGlobalValue*) originalValue); case kIROp_boolConst: { @@ -3334,10 +3356,13 @@ namespace Slang IRGlobalValueWithCode* clonedValue, IRGlobalValueWithCode* originalValue); - IRGlobalVar* cloneGlobalVar(IRSpecContext* context, IRGlobalVar* originalVar) + IRGlobalVar* cloneGlobalVarImpl( + IRSpecContext* context, + IRGlobalVar* originalVar, + IROriginalValuesForClone const& originalValues) { auto clonedVar = context->builder->createGlobalVar(context->maybeCloneType(originalVar->getType()->getValueType())); - registerClonedValue(context, clonedVar, originalVar); + registerClonedValue(context, clonedVar, originalValues); auto mangledName = originalVar->mangledName; clonedVar->mangledName = mangledName; @@ -3360,10 +3385,13 @@ namespace Slang return clonedVar; } - IRWitnessTable* cloneWitnessTable(IRSpecContext* context, IRWitnessTable* originalTable) + IRWitnessTable* cloneWitnessTableImpl( + IRSpecContext* context, + IRWitnessTable* originalTable, + IROriginalValuesForClone const& originalValues) { auto clonedTable = context->builder->createWitnessTable(); - registerClonedValue(context, clonedTable, originalTable); + registerClonedValue(context, clonedTable, originalValues); auto mangledName = originalTable->mangledName; clonedTable->mangledName = mangledName; @@ -3384,6 +3412,13 @@ namespace Slang return clonedTable; } + IRWitnessTable* cloneWitnessTableWithoutRegistering( + IRSpecContext* context, + IRWitnessTable* originalTable) + { + return cloneWitnessTableImpl(context, originalTable, IROriginalValuesForClone()); + } + void cloneGlobalValueWithCodeCommon( IRSpecContextBase* context, IRGlobalValueWithCode* clonedValue, @@ -3556,15 +3591,6 @@ namespace Slang return clonedFunc; } - - IRFunc* cloneSimpleFunc(IRSpecContextBase* context, IRFunc* originalFunc) - { - auto clonedFunc = context->builder->createFunc(); - registerClonedValue(context, clonedFunc, originalFunc); - cloneFunctionCommon(context, clonedFunc, originalFunc); - return clonedFunc; - } - // Get a string form of the target so that we can // use it to match against target-specialization modifiers // @@ -3687,57 +3713,111 @@ namespace Slang return false; } - IRFunc* cloneFunc(IRSpecContext* context, IRFunc* originalFunc) + IRFunc* cloneFuncImpl( + IRSpecContext* context, + IRFunc* originalFunc, + IROriginalValuesForClone const& originalValues) { - // We are being asked to clone a particular function, but in - // the IR that comes out of the front-end there could still - // be multiple, target-specific, declarations of any given - // function, all of which share the same mangled name. - auto mangledName = originalFunc->mangledName; + auto clonedFunc = context->builder->createFunc(); + registerClonedValue(context, clonedFunc, originalValues); + cloneFunctionCommon(context, clonedFunc, originalFunc); + return clonedFunc; + } + + // Directly clone a global value, based on a single definition/declaration, `originalVal`. + // The symbol `sym` will thread together other declarations of the same value, and + // we will register the new value as the cloned version of all of those. + IRGlobalValue* cloneGlobalValueImpl( + IRSpecContext* context, + IRGlobalValue* originalVal, + IRSpecSymbol* sym) + { + if( !originalVal ) + { + SLANG_UNEXPECTED("cloning a null value"); + UNREACHABLE_RETURN(nullptr); + } + + switch( originalVal->op ) + { + case kIROp_Func: + return cloneFuncImpl(context, (IRFunc*) originalVal, sym); + + case kIROp_global_var: + return cloneGlobalVarImpl(context, (IRGlobalVar*)originalVal, sym); + case kIROp_witness_table: + return cloneWitnessTableImpl(context, (IRWitnessTable*)originalVal, sym); + + default: + SLANG_UNEXPECTED("unknown global value kind"); + UNREACHABLE_RETURN(nullptr); + } + + } + + // Clone a global value, which has the given `mangledName`. + // The `originalVal` is a known global IR value with that name, if one is available. + // (It is okay for this parameter to be null). + IRGlobalValue* cloneGlobalValueWithMangledName( + IRSpecContext* context, + String const& mangledName, + IRGlobalValue* originalVal) + { if(mangledName.Length() == 0) { - return cloneSimpleFunc(context, originalFunc); + // If there is no mangled name, then we assume this is a local symbol, + // and it can't possibly have multiple declarations. + return cloneGlobalValueImpl(context, originalVal, nullptr); } // - // We will scan through all of the available function declarations - // with the same mangled name as `originalFunc` and try + // We will scan through all of the available declarations + // with the same mangled name as `originalVal` and try // to pick the "best" one for our target. RefPtr<IRSpecSymbol> sym; - if( !context->getSymbols().TryGetValue(originalFunc->mangledName, sym) ) + if( !context->getSymbols().TryGetValue(mangledName, sym) ) { // This shouldn't happen! - SLANG_UNEXPECTED("no matching function registered"); - UNREACHABLE_RETURN(cloneSimpleFunc(context, originalFunc)); + SLANG_UNEXPECTED("no matching values registered"); + UNREACHABLE_RETURN(cloneGlobalValueImpl(context, originalVal, nullptr)); } - // We will try to track the "best" definition we can find. - IRFunc* bestFunc = (IRFunc*) sym->irGlobalValue; - + // We will try to track the "best" declaration we can find. + // + // Generally, one declaration wil lbe better than another if it is + // more specialized for the chosen target. Otherwise, we simply favor + // definitions over declarations. + // + IRGlobalValue* bestVal = sym->irGlobalValue; for( auto ss = sym->nextWithSameName; ss; ss = ss->nextWithSameName ) { - IRFunc* newFunc = (IRFunc*) ss->irGlobalValue; - if(isBetterForTarget(context, newFunc, bestFunc)) - bestFunc = newFunc; + IRGlobalValue* newVal = ss->irGlobalValue; + if(isBetterForTarget(context, newVal, bestVal)) + bestVal = newVal; } - // All right, we are now in a position to clone the "best" - // definition that was found. - auto clonedFunc = context->builder->createFunc(); - - // The resulting function will be used as the cloned version - // of every declaration/definition in the original IR. - for( auto ss = sym; ss; ss = ss->nextWithSameName ) - { - registerClonedValue(context, clonedFunc, ss->irGlobalValue); - } + return cloneGlobalValueImpl(context, bestVal, sym); + } - // Clone the "best" definition into our context - cloneFunctionCommon(context, clonedFunc, bestFunc); + IRGlobalValue* cloneGlobalValueWithMangledName(IRSpecContext* context, String const& mangledName) + { + return cloneGlobalValueWithMangledName(context, mangledName, nullptr); + } - return clonedFunc; + // Clone a global value, where `originalVal` is one declaration/definition, but we might + // have to consider others, in order to find the "best" version of the symbol. + IRGlobalValue* cloneGlobalValue(IRSpecContext* context, IRGlobalValue* originalVal) + { + // We are being asked to clone a particular global value, but in + // the IR that comes out of the front-end there could still + // be multiple, target-specific, declarations of any given + // global value, all of which share the same mangled name. + return cloneGlobalValueWithMangledName( + context, + originalVal->mangledName, + originalVal); } StructTypeLayout* getGlobalStructLayout( @@ -3857,7 +3937,7 @@ namespace Slang globalVar = globalVar->getNextValue(); } SLANG_ASSERT(table); - table = cloneWitnessTable(context, (IRWitnessTable*)(table)); + table = cloneWitnessTableWithoutRegistering(context, (IRWitnessTable*)(table)); IRProxyVal * tableVal = new IRProxyVal(); tableVal->inst = table; paramSubst->witnessTables.Add(KeyValuePair<RefPtr<Type>, RefPtr<Val>>(subtypeWitness->sup, tableVal)); @@ -3868,55 +3948,57 @@ namespace Slang return globalParamSubst; } - IRModule* specializeIRForEntryPoint( + struct IRSpecializationState + { + ProgramLayout* programLayout; + CodeGenTarget target; + TargetRequest* targetReq; + + IRModule* irModule = nullptr; + RefPtr<ProgramLayout> newProgramLayout; + + IRSharedSpecContext sharedContextStorage; + IRSpecContext contextStorage; + + IRSharedSpecContext* getSharedContext() { return &sharedContextStorage; } + IRSpecContext* getContext() { return &contextStorage; } + }; + + IRSpecializationState* createIRSpecializationState( EntryPointRequest* entryPointRequest, ProgramLayout* programLayout, CodeGenTarget target, TargetRequest* targetReq) { + IRSpecializationState* state = new IRSpecializationState(); + + state->programLayout = programLayout; + state->target = target; + state->targetReq = targetReq; + + auto compileRequest = entryPointRequest->compileRequest; - auto session = compileRequest->mSession; auto translationUnit = entryPointRequest->getTranslationUnit(); auto originalIRModule = translationUnit->irModule; - if (!originalIRModule) - { - // We should already have emitted IR for the original - // translation unit, and it we don't have it, then - // we are now in trouble. - return nullptr; - } - - // We now need to start cloning IR symbols from `originalIRModule` - // into a fresh IR module for this entry point. Along the way we need to: - // - // 1. Attach layout information from `programLayout` and/or `entryPointLayout` - // onto the cloned IR symbols, to drive later code generation. - // - // 2. In cases where a function might have multiple target-specific definitions, - // we need to pick the "best" one for the chosen code generation target. - // - - IRSharedSpecContext sharedContextStorage; + auto sharedContext = state->getSharedContext(); initializeSharedSpecContext( - &sharedContextStorage, + sharedContext, compileRequest->mSession, nullptr, originalIRModule); + state->irModule = sharedContext->module; // We also need to attach the IR definitions for symbols from // any loaded modules: for (auto loadedModule : compileRequest->loadedModulesList) { - insertGlobalValueSymbols(&sharedContextStorage, loadedModule->irModule); + insertGlobalValueSymbols(sharedContext, loadedModule->irModule); } - // any loaded modules - - IRSpecContext contextStorage; - IRSpecContext* context = &contextStorage; - context->shared = &sharedContextStorage; - context->builder = &sharedContextStorage.builderStorage; + auto context = state->getContext(); + context->shared = sharedContext; + context->builder = &sharedContext->builderStorage; context->target = target; // Create the GlobalGenericParamSubstitution for substituting global generic types @@ -3928,7 +4010,7 @@ namespace Slang // now specailize the program layout using the substitution RefPtr<ProgramLayout> newProgramLayout = specializeProgramLayout(targetReq, programLayout, globalParamSubst); - auto entryPointLayout = findEntryPointLayout(newProgramLayout, entryPointRequest); + state->newProgramLayout = newProgramLayout; // Next, we want to optimize lookup for layout infromation // associated with global declarations, so that we can @@ -3940,6 +4022,69 @@ namespace Slang context->globalVarLayouts.AddIfNotExists(mangledName, globalVarLayout); } + return state; + } + + void destroyIRSpecializationState(IRSpecializationState* state) + { + delete state; + } + + IRModule* getIRModule(IRSpecializationState* state) + { + return state->irModule; + } + + IRGlobalValue* getSpecializedGlobalValueForDeclRef( + IRSpecializationState* state, + DeclRef<Decl> const& declRef) + { + // We will start be ensuring that we have code for + // the declaration itself. + auto decl = declRef.getDecl(); + auto mangledDeclName = getMangledName(decl); + + IRGlobalValue* irDeclVal = cloneGlobalValueWithMangledName( + state->getContext(), + mangledDeclName); + if(!irDeclVal) + return nullptr; + + // Now we need to deal with specializing the given + // IR value based on the substitutions applied to + // our declaration reference. + + if(!declRef.substitutions) + return irDeclVal; + + SLANG_UNEXPECTED("unhandled"); + UNREACHABLE_RETURN(nullptr); + } + + void specializeIRForEntryPoint( + IRSpecializationState* state, + EntryPointRequest* entryPointRequest) + { + auto target = state->target; + + auto compileRequest = entryPointRequest->compileRequest; + auto session = compileRequest->mSession; + auto translationUnit = entryPointRequest->getTranslationUnit(); + auto originalIRModule = translationUnit->irModule; + if (!originalIRModule) + { + // We should already have emitted IR for the original + // translation unit, and it we don't have it, then + // we are now in trouble. + return; + } + + auto context = state->getContext(); + auto newProgramLayout = state->newProgramLayout; + + auto entryPointLayout = findEntryPointLayout(newProgramLayout, entryPointRequest); + + // Next, we make sure to clone the global value for // the entry point function itself, and rely on // this step to recursively copy over anything else @@ -3964,8 +4109,6 @@ namespace Slang default: break; } - - return sharedContextStorage.module; } // diff --git a/source/slang/ir.h b/source/slang/ir.h index f2069e7c3..a6f5b30a1 100644 --- a/source/slang/ir.h +++ b/source/slang/ir.h @@ -466,6 +466,7 @@ void printSlangIRAssembly(StringBuilder& builder, IRModule* module); String getSlangIRAssembly(IRModule* module); void dumpIR(IRModule* module); + } diff --git a/source/slang/options.cpp b/source/slang/options.cpp index 452e7c439..97deeb544 100644 --- a/source/slang/options.cpp +++ b/source/slang/options.cpp @@ -241,6 +241,12 @@ struct OptionsParser int argc, char const* const* argv) { + // Copy some state out of the current request, in case we've been called + // after some other initialization has been performed. + flags = requestImpl->compileFlags; + + // + char const* const* argCursor = &argv[0]; char const* const* argEnd = &argv[argc]; while (argCursor != argEnd) diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 204313e84..ea86663ea 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -203,15 +203,11 @@ void CompileRequest::generateIR() // in isolation. for( auto& translationUnit : translationUnits ) { - // If the user opted out of semantic checking for - // the translation unit, then IR code generation - // is not in general even possible; there might - // be semantics errors (diagnosed or not) in the - // code, and we don't want to deal with those. - if (translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING) + // Also skip IR generation if semantic checking is turned off + // for a given translation unit. + if(translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING) continue; - // Okay, we seem to be in the clear now. translationUnit->irModule = generateIRForTranslationUnit(translationUnit); } } @@ -462,6 +458,7 @@ RefPtr<ModuleDecl> CompileRequest::loadModule( // semantic checking to be enabled. // // TODO: decide which options, if any, should be inherited. + translationUnit->compileFlags = this->compileFlags & (SLANG_COMPILE_FLAG_USE_IR); RefPtr<SourceFile> sourceFile = getSourceManager()->allocateSourceFile(path, source); @@ -486,6 +483,8 @@ void CompileRequest::handlePoundImport( RefPtr<TranslationUnitRequest> translationUnit = new TranslationUnitRequest(); translationUnit->compileRequest = this; + translationUnit->compileFlags = this->compileFlags & (SLANG_COMPILE_FLAG_USE_IR); + // Imported code is always native Slang code RefPtr<Scope> languageScope = mSession->slangLanguageScope; diff --git a/tests/compute/rewriter.hlsl b/tests/compute/rewriter.hlsl new file mode 100644 index 000000000..35a630c62 --- /dev/null +++ b/tests/compute/rewriter.hlsl @@ -0,0 +1,19 @@ +//TEST(compute):HLSL_COMPUTE:-xslang -no-checking -xslang -use-ir +//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):dxbinding(0),glbinding(0),out + +// Test that we can use Slang libraries that require IR cross-compilation +// (e.g., libraries that use generics) while writing the main code in +// vanilla HLSL/GLSL without checking enabled. + +import rewriter; + +RWStructuredBuffer<int> outputBuffer : register(u0); + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + int inVal = outputBuffer[tid]; + int outVal = test(inVal); + outputBuffer[tid] = outVal; +}
\ No newline at end of file diff --git a/tests/compute/rewriter.hlsl.expected.txt b/tests/compute/rewriter.hlsl.expected.txt new file mode 100644 index 000000000..a0d427709 --- /dev/null +++ b/tests/compute/rewriter.hlsl.expected.txt @@ -0,0 +1,4 @@ +10 +11 +12 +13 diff --git a/tests/compute/rewriter.slang b/tests/compute/rewriter.slang new file mode 100644 index 000000000..2895dfaca --- /dev/null +++ b/tests/compute/rewriter.slang @@ -0,0 +1,30 @@ +//TEST_IGNORE_FILE: + +// This file is a "library" used by the `rewriter.hlsl` test. +// It intentionally uses Slang features that can't be supported +// by naive source-to-source translation. + +interface IHelper +{ + int help(int inVal); +} + +struct MyHelper : IHelper +{ + int help(int inVal) + { + return 16 + inVal; + } +}; + +__generic<H : IHelper> +int doTest(H helper, int inVal) +{ + return helper.help(inVal); +} + +int test(int inVal) +{ + MyHelper helper; + return doTest<MyHelper>(helper, inVal); +} diff --git a/tools/render-test/slang-support.cpp b/tools/render-test/slang-support.cpp index 2465bfd99..cfbc24382 100644 --- a/tools/render-test/slang-support.cpp +++ b/tools/render-test/slang-support.cpp @@ -34,7 +34,19 @@ struct SlangShaderCompilerWrapper : public ShaderCompiler } spAddPreprocessorDefine(slangRequest, langDefine, "1"); + // If we aren't dealing with true Slang input, then don't enable checking. + // + // Note: do this before using command-line arguments to set flags, so + // that we don't accidentally clobber other flags. + if (sourceLanguage != SLANG_SOURCE_LANGUAGE_SLANG) + { + spSetCompileFlags(slangRequest, SLANG_COMPILE_FLAG_NO_CHECKING); + } + + // Preocess any additional command-line options specified for Slang using + // the `-xslang <arg>` option to `render-test`. spProcessCommandLineArguments(slangRequest, &gOptions.slangArgs[0], gOptions.slangArgCount); + int computeTranslationUnit = 0; int vertexTranslationUnit = 0; int fragmentTranslationUnit = 0; @@ -76,11 +88,6 @@ struct SlangShaderCompilerWrapper : public ShaderCompiler } - // If we aren't dealing with true Slang input, then don't enable checking. - if (sourceLanguage != SLANG_SOURCE_LANGUAGE_SLANG) - { - spSetCompileFlags(slangRequest, SLANG_COMPILE_FLAG_NO_CHECKING); - } ShaderProgram * result = nullptr; Slang::List<const char*> rawTypeNames; for (auto typeName : request.entryPointTypeArguments) diff --git a/tools/slang-test/main.cpp b/tools/slang-test/main.cpp index 0210b2558..5c05cc4ce 100644 --- a/tools/slang-test/main.cpp +++ b/tools/slang-test/main.cpp @@ -1207,6 +1207,11 @@ TestResult runSlangComputeComparisonTest(TestInput& input) return runComputeComparisonImpl(input, "-slang -compute", input.outputStem + ".expected.txt"); } +TestResult runHLSLComputeTest(TestInput& input) +{ + return runComputeComparisonImpl(input, "-hlsl-rewrite -compute", input.outputStem + ".expected.txt"); +} + TestResult runSlangRenderComputeComparisonTest(TestInput& input) { return runComputeComparisonImpl(input, "-slang -gcompute", input.outputStem + ".expected.txt"); @@ -1411,6 +1416,7 @@ TestResult runTest( { "COMPARE_HLSL_CROSS_COMPILE_RENDER", &runHLSLCrossCompileRenderComparisonTest}, { "COMPARE_HLSL_GLSL_RENDER", &runHLSLAndGLSLComparisonTest }, { "COMPARE_COMPUTE", runSlangComputeComparisonTest}, + { "HLSL_COMPUTE", runHLSLComputeTest}, { "COMPARE_RENDER_COMPUTE", &runSlangRenderComputeComparisonTest }, #else @@ -1419,6 +1425,7 @@ TestResult runTest( { "COMPARE_HLSL_CROSS_COMPILE_RENDER", &skipTest}, { "COMPARE_HLSL_GLSL_RENDER", &skipTest }, { "COMPARE_COMPUTE", &skipTest}, + { "HLSL_COMPUTE", &skipTest}, { "COMPARE_RENDER_COMPUTE", &skipTest }, #endif { "COMPARE_GLSL", &runGLSLComparisonTest }, |
