From 62079c534407abe300d24a6d759641779e48bc67 Mon Sep 17 00:00:00 2001 From: Yong He Date: Thu, 16 Jul 2020 13:09:17 -0700 Subject: Support associatedtype local variables and return values in dynamic dispatch code (#1444) * Refactor lower-generics pass into separate subpasses. * IR pass to generate witness table wrappers. * Support associatedtype local variables and return values in dynamic dispatch code. --- .../slang/slang-ir-generics-lowering-context.cpp | 1 + source/slang/slang-ir-lower-generic-call.cpp | 11 +++ source/slang/slang-ir-lower-generic-function.cpp | 4 +- source/slang/slang-ir-lower-generic-var.cpp | 11 ++- source/slang/slang-lower-to-ir.cpp | 5 ++ tests/compute/dynamic-dispatch-7.slang | 85 ++++++++++++++++++++++ .../compute/dynamic-dispatch-7.slang.expected.txt | 4 + 7 files changed, 116 insertions(+), 5 deletions(-) create mode 100644 tests/compute/dynamic-dispatch-7.slang create mode 100644 tests/compute/dynamic-dispatch-7.slang.expected.txt diff --git a/source/slang/slang-ir-generics-lowering-context.cpp b/source/slang/slang-ir-generics-lowering-context.cpp index 6ee0a17f0..0fcb85378 100644 --- a/source/slang/slang-ir-generics-lowering-context.cpp +++ b/source/slang/slang-ir-generics-lowering-context.cpp @@ -15,6 +15,7 @@ namespace Slang case kIROp_ThisType: case kIROp_AssociatedType: case kIROp_InterfaceType: + case kIROp_lookup_interface_method: return true; case kIROp_Specialize: { diff --git a/source/slang/slang-ir-lower-generic-call.cpp b/source/slang/slang-ir-lower-generic-call.cpp index f339d5309..5095fadd3 100644 --- a/source/slang/slang-ir-lower-generic-call.cpp +++ b/source/slang/slang-ir-lower-generic-call.cpp @@ -127,10 +127,21 @@ namespace Slang translateCallInst(callInst, funcType, loweredFunc, specializeInst); } + void lowerCallToInterfaceMethod(IRCall* callInst, IRLookupWitnessMethod* lookupInst) + { + // If we see a call(lookup_interface_method(...), ...), we need to translate + // all occurences of associatedtypes. + auto funcType = cast(lookupInst->getDataType()); + auto loweredFunc = lookupInst; + translateCallInst(callInst, funcType, loweredFunc, nullptr); + } + void lowerCall(IRCall* callInst) { if (auto specializeInst = as(callInst->getCallee())) lowerCallToSpecializedFunc(callInst, specializeInst); + else if (auto lookupInst = as(callInst->getCallee())) + lowerCallToInterfaceMethod(callInst, lookupInst); } void processInst(IRInst* inst) diff --git a/source/slang/slang-ir-lower-generic-function.cpp b/source/slang/slang-ir-lower-generic-function.cpp index e930c6cc8..92f00c509 100644 --- a/source/slang/slang-ir-lower-generic-function.cpp +++ b/source/slang/slang-ir-lower-generic-function.cpp @@ -122,7 +122,8 @@ namespace Slang auto paramType = param->getDataType(); if (auto ptrType = as(paramType)) paramType = ptrType->getValueType(); - if (isPointerOfType(paramType->getDataType(), kIROp_RTTIType)) + if (isPointerOfType(paramType->getDataType(), kIROp_RTTIType) || + paramType->op == kIROp_lookup_interface_method) { // Lower into a function parameter of raw pointer type. param->setFullType(builder.getRawPointerType()); @@ -277,6 +278,7 @@ namespace Slang // Update the type of lookupInst to the lowered type of the corresponding interface requirement val. // If the requirement is a function, interfaceRequirementVal will be the lowered function type. + // If the requirement is an associatedtype, interfaceRequirementVal will be Ptr. IRInst* interfaceRequirementVal = nullptr; auto witnessTableType = cast(lookupInst->getWitnessTable()->getDataType()); auto interfaceType = maybeLowerInterfaceType(cast(witnessTableType->getConformanceType())); diff --git a/source/slang/slang-ir-lower-generic-var.cpp b/source/slang/slang-ir-lower-generic-var.cpp index 0c45e8e38..8799dea6b 100644 --- a/source/slang/slang-ir-lower-generic-var.cpp +++ b/source/slang/slang-ir-lower-generic-var.cpp @@ -19,12 +19,15 @@ namespace Slang void processVarInst(IRInst* varInst) { // We process only var declarations that have type - // `Ptr`. + // `Ptr` or `Ptr`. + // // Due to the processing of `lowerGenericFunction`, // A local variable of generic type now appears as - // `var X:Ptr>` + // `var X:Ptr>`, + // where y can be an IRParam if it is a generic type, + // or an `lookup_interface_method` if it is an associated type. // We match this pattern and turn this inst into - // `X:RawPtr = alloca(rtti_extract_size(irParam))` + // `X:RTTIPtr(irParam) = alloca(irParam)` auto varTypeInst = varInst->getDataType(); if (!varTypeInst) return; @@ -34,7 +37,7 @@ namespace Slang // `varTypeParam` represents a pointer to the RTTI object. auto varTypeParam = ptrType->getValueType(); - if (varTypeParam->op != kIROp_Param) + if (varTypeParam->op != kIROp_Param && varTypeParam->op != kIROp_lookup_interface_method) return; if (!varTypeParam->getDataType()) return; diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 175f9264f..322489d89 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -3437,6 +3437,11 @@ struct ExprLoweringVisitorBase : ExprVisitor UNREACHABLE_RETURN(LoweredValInfo()); } + LoweredValInfo visitAssocTypeDecl(AssocTypeDecl* decl) + { + return LoweredValInfo::simple(context->irBuilder->getAssociatedType()); + } + LoweredValInfo visitAssignExpr(AssignExpr* expr) { // Because our representation of lowered "values" diff --git a/tests/compute/dynamic-dispatch-7.slang b/tests/compute/dynamic-dispatch-7.slang new file mode 100644 index 000000000..62ab94e48 --- /dev/null +++ b/tests/compute/dynamic-dispatch-7.slang @@ -0,0 +1,85 @@ +//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -allow-dynamic-code +//DISABLE_TEST(compute):COMPARE_COMPUTE:-cuda -xslang -allow-dynamic-code + +// Test dynamic dispatch code gen for associated-typed return values +// and local variables. +// TODO: test arguments of associated type. + +interface IAssoc +{ + int Compute(); +} + +interface IInterface +{ + associatedtype TAssoc : IAssoc; + + [mutating] + void SetVal(int inVal); + + TAssoc GetAssoc(); +}; + +T.TAssoc CreateT_Assoc_Inner(int inVal) +{ + T obj; + obj.SetVal(inVal); + return obj.GetAssoc(); +} + +T.TAssoc CreateT_Assoc(int inVal) +{ + return CreateT_Assoc_Inner(inVal); +} + +T CreateT(int inVal) +{ + T obj; + obj.SetVal(inVal); + return obj; +} + +struct Impl : IInterface +{ + struct TAssoc : IAssoc + { + int base; + int Compute() + { + return base; + } + }; + + TAssoc assoc; + [mutating] + void SetVal(int inVal) + { + assoc.base = inVal; + } + + TAssoc GetAssoc() + { + return assoc; + } +}; + +int test() +{ + var obj = CreateT(2); + var obj2 = CreateT_Assoc(1); + // TODO: compiler crash if type parameter is missing. + // (hitting lowering logic of TypeEqualityWitness) + var obj3 = CreateT_Assoc_Inner(1); + return obj.GetAssoc().Compute() + obj2.Compute() + obj3.Compute(); +} + +//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer : register(u0); + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + int outVal = test(); + outputBuffer[tid] = outVal; +} diff --git a/tests/compute/dynamic-dispatch-7.slang.expected.txt b/tests/compute/dynamic-dispatch-7.slang.expected.txt new file mode 100644 index 000000000..e785149d2 --- /dev/null +++ b/tests/compute/dynamic-dispatch-7.slang.expected.txt @@ -0,0 +1,4 @@ +4 +4 +4 +4 -- cgit v1.2.3