From 0ca75fe002f346f6ab9b77f40c0576d2905560f1 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 24 Jun 2020 13:16:11 -0700 Subject: Dynamic dispatch for generic interface requirements. -Lower interfaces into actual `IRInterfaceType` insts. -Lower `DeclRef` into `IRAssociatedType` -Generate proper IRType for generic functions. -Add a test case exercising dynamic dispatching a generic static function through an associated type. -Bug fixes for the test case. --- tests/compute/dynamic-dispatch-3.slang | 60 ++++++++++++++++++++++ .../compute/dynamic-dispatch-3.slang.expected.txt | 4 ++ 2 files changed, 64 insertions(+) create mode 100644 tests/compute/dynamic-dispatch-3.slang create mode 100644 tests/compute/dynamic-dispatch-3.slang.expected.txt (limited to 'tests') diff --git a/tests/compute/dynamic-dispatch-3.slang b/tests/compute/dynamic-dispatch-3.slang new file mode 100644 index 000000000..7011a2f4e --- /dev/null +++ b/tests/compute/dynamic-dispatch-3.slang @@ -0,0 +1,60 @@ +//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -allow-dynamic-code + +// Test dynamic dispatch code gen for static member functions +// of associated type. +interface IGetter +{ + int getVal(); +}; +interface IAssoc +{ + int get(); + static int getBase(T getter); +} +interface IInterface +{ + associatedtype Assoc : IAssoc; + int Compute(int inVal); +}; + +struct GetterImpl : IGetter +{ + int getVal() { return 1; } +}; + +int GenericCompute(T obj, int inVal) +{ + GetterImpl getter; + return obj.Compute(inVal) + T.Assoc.getBase(getter); +} + +struct Impl : IInterface +{ + struct Assoc : IAssoc + { + int val; + int get() { return val; } + static int getBase(T t) { return t.getVal(); } + }; + int base; + int Compute(int inVal) { return base + inVal * inVal; } +}; + +int test(int inVal) +{ + Impl obj; + obj.base = 1; + return GenericCompute(obj, inVal); +} + +//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer : register(u0); + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + int inVal = outputBuffer[tid]; + int outVal = test(inVal); + outputBuffer[tid] = outVal; +} diff --git a/tests/compute/dynamic-dispatch-3.slang.expected.txt b/tests/compute/dynamic-dispatch-3.slang.expected.txt new file mode 100644 index 000000000..a6bafb7ca --- /dev/null +++ b/tests/compute/dynamic-dispatch-3.slang.expected.txt @@ -0,0 +1,4 @@ +2 +3 +6 +B -- cgit v1.2.3 From ffa9a3575ff888dc494ba4878f52441c64a9e08c Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 24 Jun 2020 18:09:40 -0700 Subject: Fix `lowerFuncType` and small bug fixes. --- source/slang/slang-emit-c-like.cpp | 6 ++-- source/slang/slang-emit.cpp | 3 +- source/slang/slang-ir-link.cpp | 28 +++++++++++++-- source/slang/slang-ir-lower-generics.cpp | 2 +- source/slang/slang-ir-specialize.cpp | 5 +++ source/slang/slang-lower-to-ir.cpp | 62 ++++++++++++++++---------------- tests/compute/dynamic-dispatch-2.slang | 2 -- 7 files changed, 68 insertions(+), 40 deletions(-) (limited to 'tests') diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index faf7b8c1d..516a8ff22 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -305,7 +305,8 @@ void CLikeSourceEmitter::emitWitnessTable(IRWitnessTable* witnessTable) void CLikeSourceEmitter::emitInterface(IRInterfaceType* interfaceType) { SLANG_UNUSED(interfaceType); - SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "Unimplemented emit: IROpInterfaceType."); + // By default, don't emit anything for interface types. + // This behavior is overloaded by concrete emitters. } void CLikeSourceEmitter::emitTypeImpl(IRType* type, const StringSliceLoc* nameAndLoc) @@ -2289,7 +2290,6 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO m_writer->emit(")"); } break; - default: diagnoseUnhandledInst(inst); break; @@ -3654,6 +3654,8 @@ void CLikeSourceEmitter::ensureGlobalInst(ComputeEmitActionsContext* ctx, IRInst if (!m_compileRequest->allowDynamicCode) return; break; + + case kIROp_InterfaceRequirementEntry: case kIROp_Generic: return; diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index f2552f95d..59b059e91 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -280,7 +280,8 @@ Result linkAndOptimizeIR( // For targets that supports dynamic dispatch, we need to lower the // generics / interface types to ordinary functions and types using // function pointers. - lowerGenerics(irModule); + if (compileRequest->allowDynamicCode) + lowerGenerics(irModule); break; default: break; diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 4e6ad74a4..e556fd738 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -89,6 +89,25 @@ struct IRSpecContextBase } }; +enum class GlobalValueClass +{ + StructKey, + Other +}; + +// Get the "class" of a global value. If there are more than one value of the same class, +// only one value in each class will be selected during linking. +GlobalValueClass getGlobalValueClass(IRInst* value) +{ + switch (value->op) + { + case kIROp_StructKey: + return GlobalValueClass::StructKey; + default: + return GlobalValueClass::Other; + } +} + void registerClonedValue( IRSpecContextBase* context, IRInst* clonedValue, @@ -128,9 +147,11 @@ void registerClonedValue( IROriginalValuesForClone const& originalValues) { registerClonedValue(context, clonedValue, originalValues.originalVal); + auto valueClass = getGlobalValueClass(clonedValue); for( auto s = originalValues.sym; s; s = s->nextWithSameName ) { - registerClonedValue(context, clonedValue, s->irGlobalValue); + if (getGlobalValueClass(s->irGlobalValue) == valueClass) + registerClonedValue(context, clonedValue, s->irGlobalValue); } } @@ -1153,7 +1174,6 @@ IRInst* cloneGlobalValueImpl( return clonedValue; } - /// Clone a global value, which has the given `originalLinkage`. /// /// The `originalVal` is a known global IR value with that linkage, if one is available. @@ -1214,10 +1234,12 @@ IRInst* cloneGlobalValueWithLinkage( // definitions over declarations. // IRInst* bestVal = nullptr; + auto valueClass = getGlobalValueClass(originalVal); for(IRSpecSymbol* ss = sym; ss; ss = ss->nextWithSameName ) { IRInst* newVal = ss->irGlobalValue; - if(isBetterForTarget(context, newVal, bestVal)) + if (getGlobalValueClass(newVal) == valueClass && + isBetterForTarget(context, newVal, bestVal)) bestVal = newVal; } diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp index fe0fa3364..4701a3cce 100644 --- a/source/slang/slang-ir-lower-generics.cpp +++ b/source/slang/slang-ir-lower-generics.cpp @@ -162,7 +162,7 @@ namespace Slang newOperands.add(paramType); } } - if (!translated) + if (!translated && additionalParamCount == 0) return funcType; for (UInt i = 0; i < additionalParamCount; i++) { diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index cf475f1ff..3acb34c87 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -416,6 +416,11 @@ struct SpecializationContext case kIROp_BindExistentialsType: break; + // An interface type is always fully specialized. + case kIROp_InterfaceType: + markInstAsFullySpecialized(inst); + break; + case kIROp_Specialize: // The `specialize` instruction is a bit sepcial, // because it is possible to have a `specialize` diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index ff356fd48..3e211e7ca 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -4466,7 +4466,7 @@ struct DeclLoweringVisitor : DeclVisitor auto type = lowerType(subContext, decl->type.type); - return LoweredValInfo::simple(finishOuterGenerics(subBuilder, type)); + return LoweredValInfo::simple(finishOuterGenerics(subBuilder, type, outerGeneric)); } LoweredValInfo visitGenericTypeParamDecl(GenericTypeParamDecl* /*decl*/) @@ -4536,7 +4536,7 @@ struct DeclLoweringVisitor : DeclVisitor Dictionary mapASTToIRWitnessTable) { auto subBuilder = subContext->irBuilder; - + for(auto entry : astWitnessTable->requirementDictionary) { auto requiredMemberDecl = entry.Key; @@ -4676,7 +4676,7 @@ struct DeclLoweringVisitor : DeclVisitor NestedContext nested(this); auto subBuilder = nested.getBuilder(); auto subContext = nested.getContext(); - emitOuterGenerics(subContext, inheritanceDecl, inheritanceDecl); + auto outerGeneric = emitOuterGenerics(subContext, inheritanceDecl, inheritanceDecl); // Lower the super-type to force its declaration to be lowered. // @@ -4705,7 +4705,7 @@ struct DeclLoweringVisitor : DeclVisitor irWitnessTable->moveToEnd(); - return LoweredValInfo::simple(finishOuterGenerics(subBuilder, irWitnessTable)); + return LoweredValInfo::simple(finishOuterGenerics(subBuilder, irWitnessTable, outerGeneric)); } LoweredValInfo visitDeclGroup(DeclGroup* declGroup) @@ -5106,7 +5106,7 @@ struct DeclLoweringVisitor : DeclVisitor auto subBuilder = nestedContext.getBuilder(); auto subContext = nestedContext.getContext(); subBuilder->setInsertInto(subBuilder->getModule()->getModuleInst()); - emitOuterGenerics(subContext, decl, decl); + auto outerGeneric = emitOuterGenerics(subContext, decl, decl); IRType* subVarType = lowerType(subContext, decl->getType()); @@ -5207,7 +5207,7 @@ struct DeclLoweringVisitor : DeclVisitor } irGlobal->moveToEnd(); - finishOuterGenerics(subBuilder, irGlobal); + finishOuterGenerics(subBuilder, irGlobal, outerGeneric); return globalVal; } @@ -5317,7 +5317,7 @@ struct DeclLoweringVisitor : DeclVisitor } // Emit any generics that should wrap the actual type. - emitOuterGenerics(subContext, decl, decl); + auto outerGeneric = emitOuterGenerics(subContext, decl, decl); IRInterfaceType* irInterface = subBuilder->createInterfaceType( requirementEntries.getCount(), @@ -5333,7 +5333,7 @@ struct DeclLoweringVisitor : DeclVisitor addTargetIntrinsicDecorations(irInterface, decl); - return LoweredValInfo::simple(finishOuterGenerics(subBuilder, irInterface)); + return LoweredValInfo::simple(finishOuterGenerics(subBuilder, irInterface, outerGeneric)); } LoweredValInfo visitEnumCaseDecl(EnumCaseDecl* decl) @@ -5367,7 +5367,7 @@ struct DeclLoweringVisitor : DeclVisitor NestedContext nestedContext(this); auto subBuilder = nestedContext.getBuilder(); auto subContext = nestedContext.getContext(); - emitOuterGenerics(subContext, decl, decl); + auto outerGeneric = emitOuterGenerics(subContext, decl, decl); // An `enum` declaration will currently lower directly to its "tag" // type, so that any references to the `enum` become referenes to @@ -5379,7 +5379,7 @@ struct DeclLoweringVisitor : DeclVisitor IRType* loweredTagType = lowerType(subContext, decl->tagType); - return LoweredValInfo::simple(finishOuterGenerics(subBuilder, loweredTagType)); + return LoweredValInfo::simple(finishOuterGenerics(subBuilder, loweredTagType, outerGeneric)); } LoweredValInfo visitAggTypeDecl(AggTypeDecl* decl) @@ -5406,7 +5406,7 @@ struct DeclLoweringVisitor : DeclVisitor auto subContext = nestedContext.getContext(); // Emit any generics that should wrap the actual type. - emitOuterGenerics(subContext, decl, decl); + auto outerGeneric = emitOuterGenerics(subContext, decl, decl); IRInst* resultType = nullptr; if (as(decl)) @@ -5490,7 +5490,7 @@ struct DeclLoweringVisitor : DeclVisitor resultType->moveToEnd(); addTargetIntrinsicDecorations(resultType, decl); - return LoweredValInfo::simple(finishOuterGenerics(subBuilder, resultType)); + return LoweredValInfo::simple(finishOuterGenerics(subBuilder, resultType, outerGeneric)); } LoweredValInfo lowerMemberVarDecl(VarDecl* fieldDecl) @@ -5819,24 +5819,25 @@ struct DeclLoweringVisitor : DeclVisitor // IRInst* finishOuterGenerics( IRBuilder* subBuilder, - IRInst* val) + IRInst* val, + IRGeneric* parentGeneric) { IRInst* v = val; - for(;;) + while (parentGeneric) { - auto parentBlock = as(v->getParent()); - if (!parentBlock) break; - - auto parentGeneric = as(parentBlock->getParent()); - if (!parentGeneric) break; - - subBuilder->setInsertInto(parentBlock); + subBuilder->setInsertInto(parentGeneric->getFirstBlock()); subBuilder->emitReturn(v); parentGeneric->moveToEnd(); // There might be more outer generics, // so we need to loop until we run out. v = parentGeneric; + auto parentBlock = as(v->getParent()); + if (!parentBlock) break; + + parentGeneric = as(parentBlock->getParent()); + if (!parentGeneric) break; + } return v; } @@ -6066,10 +6067,10 @@ struct DeclLoweringVisitor : DeclVisitor // Simple case of a by-value input parameter. break; - // If the parameter is declared `out` or `inout`, - // then we will represent it with a pointer type in - // the IR, but we will use a specialized pointer - // type that encodes the parameter direction information. + // If the parameter is declared `out` or `inout`, + // then we will represent it with a pointer type in + // the IR, but we will use a specialized pointer + // type that encodes the parameter direction information. case kParameterDirection_Out: irParamType = subBuilder->getOutType(irParamType); break; @@ -6153,7 +6154,7 @@ struct DeclLoweringVisitor : DeclVisitor auto funcTypeBuilder = nestedContextFuncType.getBuilder(); auto funcTypeContext = nestedContextFuncType.getContext(); - emitOuterGenerics(funcTypeContext, decl, decl); + auto outerGenerics = emitOuterGenerics(funcTypeContext, decl, decl); ParameterLists parameterLists; List paramTypes; @@ -6166,7 +6167,7 @@ struct DeclLoweringVisitor : DeclVisitor funcTypeContext, decl); - return finishOuterGenerics(funcTypeBuilder, irFuncType); + return finishOuterGenerics(funcTypeBuilder, irFuncType, outerGenerics); } LoweredValInfo lowerFuncDecl(FunctionDeclBase* decl) @@ -6178,7 +6179,7 @@ struct DeclLoweringVisitor : DeclVisitor auto subBuilder = nestedContextFunc.getBuilder(); auto subContext = nestedContextFunc.getContext(); - emitOuterGenerics(subContext, decl, decl); + auto outerGeneric = emitOuterGenerics(subContext, decl, decl); // need to create an IR function here @@ -6550,12 +6551,11 @@ struct DeclLoweringVisitor : DeclVisitor // If this function is defined inside an interface, add a reference to the IRFunc from // the interface's type definition. - auto finalVal = finishOuterGenerics(subBuilder, irFunc); - + auto finalVal = finishOuterGenerics(subBuilder, irFunc, outerGeneric); if (auto genericVal = as(finalVal)) { auto funcType = lowerFuncType(decl); - genericVal->typeUse.set(funcType); + genericVal->setFullType((IRType*)funcType); } maybeAssociateToInterfaceType(decl, finalVal); diff --git a/tests/compute/dynamic-dispatch-2.slang b/tests/compute/dynamic-dispatch-2.slang index ade8aeb84..6b8b0e633 100644 --- a/tests/compute/dynamic-dispatch-2.slang +++ b/tests/compute/dynamic-dispatch-2.slang @@ -12,7 +12,6 @@ interface IInterface { associatedtype Assoc : IAssoc; int Compute(int inVal); - Assoc getAssoc(); }; int GenericCompute(T obj, int inVal) @@ -30,7 +29,6 @@ struct Impl : IInterface }; int base; int Compute(int inVal) { return base + inVal * inVal; } - Assoc getAssoc() { Assoc rs; rs.val = 1; return rs; } }; int test(int inVal) -- cgit v1.2.3