From 54bf54bd0dda378f8400860b25855558f39cb52b Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 17 Nov 2017 21:26:21 -0500 Subject: Add support for global generic parameters (#285) * Add support for global generic parameters (In-progress work) This commit include: 1. Update Slang API to allow specification of generic type arguments in an `EntryPointRequest` 2. Add parsing of `__generic_param` construct, which becomes a GlobalGenericParamDecl, contains members of `GenericTypeConstraintDecl`. 3. Semantics checking will check whether the provided type arguments conform to the interfaces as defined by the generic parameter, and store SubtypeWitness values in the EntryPointRequest, which will be used by `specializeIRForEntryPoint` when generating final IR. 4. Add a new type of substitution - `GlobalGenericParamSubstitution` for subsittuting references to `__generic_param` decls or to its member `GenericTypeConsraintDecl` with the actual type argument or witness tables. 5. Update `IRSpecContext` to apply `GlobalGenericParamSubstitution` when specializing the IR for an EntryPointRequest. 6. Update `render-test` to take additional `type` inputs, which specifies the type arguments to substitute into the global `__generic_param` types. This commit does not include ProgramLayout specialization. * IR: pass through `[unroll]` attribute (#284) The initial lowering was adding an `IRLoopControlDecoration` to the instruction at the head of a loop, but this was getting dropped when the IR gets cloned for a particular entry point. The fix was simply to add a case for loop-control decorations to `cloneDecoration`. * fix warnings * IR: support `CompileTimeForStmt` (#286) This statement type is a bit of a hack, to support loops that *must* be unrolled. The AST-to-AST pass handles them by cloning the AST for the loop body N times, and it was easy enough to do the same thing for the IR: emit the instructions for the body N times. The only thing that requires a bit of care is that now we might see the same variable declarations multiple times, so we need to play it safe and overwrite existing entries in our map from declarations to their IR values. Of course a better answer long-term would be to do the actual unrolling in the IR. This is especially true because we might some day want to support compile-time/must-unroll loops in functions, where the loop counter comes in as a parameter (but must still be compile-time-constant at every call site). * Add support for global generic parameters (In-progress work) This commit include: 1. Update Slang API to allow specification of generic type arguments in an `EntryPointRequest` 2. Add parsing of `__generic_param` construct, which becomes a GlobalGenericParamDecl, contains members of `GenericTypeConstraintDecl`. 3. Semantics checking will check whether the provided type arguments conform to the interfaces as defined by the generic parameter, and store SubtypeWitness values in the EntryPointRequest, which will be used by `specializeIRForEntryPoint` when generating final IR. 4. Add a new type of substitution - `GlobalGenericParamSubstitution` for subsittuting references to `__generic_param` decls or to its member `GenericTypeConsraintDecl` with the actual type argument or witness tables. 5. Update `IRSpecContext` to apply `GlobalGenericParamSubstitution` when specializing the IR for an EntryPointRequest. 6. Update `render-test` to take additional `type` inputs, which specifies the type arguments to substitute into the global `__generic_param` types. progress on parameter binding * Add a more contrived test case for specializing parameter bindings * update render-test to align buffers to 256 bytes (to get rid of D3D complains on minimal buffer size). * adding one more test case for parameter binding specialization. * Cleanup according to @tfoleyNV 's suggestions. * fix a bug introduced in the cleanup --- source/slang/ir.cpp | 119 ++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 112 insertions(+), 7 deletions(-) (limited to 'source/slang/ir.cpp') diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index 9068e717b..bfc26643c 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -3089,12 +3089,16 @@ namespace Slang // to the layout to use for it. Dictionary globalVarLayouts; + RefPtr subst; + // Override the "maybe clone" logic so that we always clone virtual IRValue* maybeCloneValue(IRValue* originalVal) override; // Override teh "maybe clone" logic so that we carefully // clone any IR proxy values inside substitutions virtual DeclRef maybeCloneDeclRef(DeclRef const& declRef) override; + + virtual RefPtr maybeCloneType(Type* originalType) override; }; @@ -3102,6 +3106,11 @@ namespace Slang IRFunc* cloneFunc(IRSpecContext* context, IRFunc* originalFunc); IRWitnessTable* cloneWitnessTable(IRSpecContext* context, IRWitnessTable* originalVar); + RefPtr IRSpecContext::maybeCloneType(Type* originalType) + { + return originalType->Substitute(subst).As(); + } + IRValue* IRSpecContext::maybeCloneValue(IRValue* originalValue) { switch (originalValue->op) @@ -3143,6 +3152,33 @@ namespace Slang case kIROp_decl_ref: { IRDeclRef* od = (IRDeclRef*)originalValue; + + // if the declRef is one of the __generic_param decl being substituted by subst + // return the substituted decl + if (subst) + { + if (od->declRef.getDecl() == subst->paramDecl) + return builder->getTypeVal(subst->actualType.As()); + else if (auto genConstraint = od->declRef.As()) + { + // a decl-ref to GenericTypeConstraintDecl as a result of + // referencing a generic parameter type should be replaced with + // the actual witness table + if (genConstraint.getDecl()->ParentDecl == subst->paramDecl) + { + // find the witness table from subst + for (auto witness : subst->witnessTables) + { + if (witness.Key->EqualsVal(GetSup(genConstraint))) + { + auto proxyVal = witness.Value.As(); + SLANG_ASSERT(proxyVal); + return proxyVal->inst; + } + } + } + } + } auto declRef = maybeCloneDeclRef(od->declRef); return builder->getDeclRefVal(declRef); } @@ -3150,7 +3186,9 @@ namespace Slang case kIROp_TypeType: { IRValue* od = (IRValue*)originalValue; - return builder->getTypeVal(od->type); + int ioDiff = 0; + auto newType = od->type->SubstituteImpl(subst, &ioDiff); + return builder->getTypeVal(newType.As()); } break; default: @@ -3207,7 +3245,9 @@ namespace Slang newSubst->outer = cloneSubstitutions(context, subst->outer); return newSubst; } - return nullptr; + else + SLANG_UNREACHABLE("unimplemented cloneSubstitution"); + UNREACHABLE_RETURN(nullptr); } DeclRef IRSpecContext::maybeCloneDeclRef(DeclRef const& declRef) @@ -3281,7 +3321,7 @@ namespace Slang IRGlobalVar* cloneGlobalVar(IRSpecContext* context, IRGlobalVar* originalVar) { - auto clonedVar = context->builder->createGlobalVar(originalVar->getType()->getValueType()); + auto clonedVar = context->builder->createGlobalVar(context->maybeCloneType(originalVar->getType()->getValueType())); registerClonedValue(context, clonedVar, originalVar); auto mangledName = originalVar->mangledName; @@ -3703,10 +3743,67 @@ namespace Slang } } + // implementation provided in parameter-binding.cpp + RefPtr specializeProgramLayout( + TargetRequest * targetReq, + ProgramLayout* programLayout, + Substitutions * typeSubst); + + RefPtr createGlobalGenericParamSubstitution( + EntryPointRequest * entryPointRequest, + ProgramLayout * programLayout, + IRSpecContext* context, + IRModule* originalIRModule) + { + RefPtr globalParamSubst; + Substitutions * curTailSubst = nullptr; + for (auto param : programLayout->globalGenericParams) + { + auto paramSubst = new GlobalGenericParamSubstitution(); + if (!globalParamSubst) + globalParamSubst = paramSubst; + if (curTailSubst) + curTailSubst->outer = paramSubst; + curTailSubst = paramSubst; + paramSubst->paramDecl = param->decl; + SLANG_ASSERT((UInt)param->index < entryPointRequest->genericParameterTypes.Count()); + paramSubst->actualType = entryPointRequest->genericParameterTypes[param->index]; + // find witness tables + for (auto witness : entryPointRequest->genericParameterWitnesses) + { + if (auto subtypeWitness = witness.As()) + { + if (subtypeWitness->sub->EqualsVal(paramSubst->actualType)) + { + auto witnessTableName = getMangledNameForConformanceWitness(subtypeWitness->sub, subtypeWitness->sup); + auto globalVar = originalIRModule->getFirstGlobalValue(); + IRGlobalValue * table = nullptr; + while (globalVar) + { + if (globalVar->mangledName == witnessTableName) + { + table = globalVar; + break; + } + globalVar = globalVar->getNextValue(); + } + SLANG_ASSERT(table); + table = cloneWitnessTable(context, (IRWitnessTable*)(table)); + IRProxyVal * tableVal = new IRProxyVal(); + tableVal->inst = table; + paramSubst->witnessTables.Add(KeyValuePair, RefPtr>(subtypeWitness->sup, tableVal)); + } + } + } + } + return globalParamSubst; + } + IRModule* specializeIRForEntryPoint( EntryPointRequest* entryPointRequest, ProgramLayout* programLayout, - CodeGenTarget target) + CodeGenTarget target, + TargetRequest* targetReq) { auto compileRequest = entryPointRequest->compileRequest; auto session = compileRequest->mSession; @@ -3720,8 +3817,6 @@ namespace Slang return nullptr; } - auto entryPointLayout = findEntryPointLayout(programLayout, entryPointRequest); - // We now need to start cloning IR symbols from `originalIRModule` // into a fresh IR module for this entry point. Along the way we need to: // @@ -3746,11 +3841,21 @@ namespace Slang context->builder = &sharedContextStorage.builderStorage; context->target = target; + // Create the GlobalGenericParamSubstitution for substituting global generic types + // into user-provided type arguments + auto globalParamSubst = createGlobalGenericParamSubstitution(entryPointRequest, programLayout, context, originalIRModule); + + context->subst = globalParamSubst; + + // now specailize the program layout using the substitution + RefPtr newProgramLayout = specializeProgramLayout(targetReq, programLayout, globalParamSubst); + + auto entryPointLayout = findEntryPointLayout(newProgramLayout, entryPointRequest); // Next, we want to optimize lookup for layout infromation // associated with global declarations, so that we can // look things up based on the IR values (using mangled names) - auto globalStructLayout = getGlobalStructLayout(programLayout); + auto globalStructLayout = getGlobalStructLayout(newProgramLayout); for (auto globalVarLayout : globalStructLayout->fields) { String mangledName = getMangledName(globalVarLayout->varDecl); -- cgit v1.2.3