diff options
| author | Yong He <yonghe@outlook.com> | 2020-07-16 13:09:17 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2020-07-16 13:09:17 -0700 |
| commit | 62079c534407abe300d24a6d759641779e48bc67 (patch) | |
| tree | e00e89f76194493e5ba425c866a8e59a5fd1925c | |
| parent | 5758d16612eda0f902d7d4c02535afe44dec2ac2 (diff) | |
Support associatedtype local variables and return values in dynamic dispatch code (#1444)
* Refactor lower-generics pass into separate subpasses.
* IR pass to generate witness table wrappers.
* Support associatedtype local variables and return values in dynamic dispatch code.
| -rw-r--r-- | source/slang/slang-ir-generics-lowering-context.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-generic-call.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-generic-function.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-generic-var.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 5 | ||||
| -rw-r--r-- | tests/compute/dynamic-dispatch-7.slang | 85 | ||||
| -rw-r--r-- | tests/compute/dynamic-dispatch-7.slang.expected.txt | 4 |
7 files changed, 116 insertions, 5 deletions
diff --git a/source/slang/slang-ir-generics-lowering-context.cpp b/source/slang/slang-ir-generics-lowering-context.cpp index 6ee0a17f0..0fcb85378 100644 --- a/source/slang/slang-ir-generics-lowering-context.cpp +++ b/source/slang/slang-ir-generics-lowering-context.cpp @@ -15,6 +15,7 @@ namespace Slang case kIROp_ThisType: case kIROp_AssociatedType: case kIROp_InterfaceType: + case kIROp_lookup_interface_method: return true; case kIROp_Specialize: { diff --git a/source/slang/slang-ir-lower-generic-call.cpp b/source/slang/slang-ir-lower-generic-call.cpp index f339d5309..5095fadd3 100644 --- a/source/slang/slang-ir-lower-generic-call.cpp +++ b/source/slang/slang-ir-lower-generic-call.cpp @@ -127,10 +127,21 @@ namespace Slang translateCallInst(callInst, funcType, loweredFunc, specializeInst); } + void lowerCallToInterfaceMethod(IRCall* callInst, IRLookupWitnessMethod* lookupInst) + { + // If we see a call(lookup_interface_method(...), ...), we need to translate + // all occurences of associatedtypes. + auto funcType = cast<IRFuncType>(lookupInst->getDataType()); + auto loweredFunc = lookupInst; + translateCallInst(callInst, funcType, loweredFunc, nullptr); + } + void lowerCall(IRCall* callInst) { if (auto specializeInst = as<IRSpecialize>(callInst->getCallee())) lowerCallToSpecializedFunc(callInst, specializeInst); + else if (auto lookupInst = as<IRLookupWitnessMethod>(callInst->getCallee())) + lowerCallToInterfaceMethod(callInst, lookupInst); } void processInst(IRInst* inst) diff --git a/source/slang/slang-ir-lower-generic-function.cpp b/source/slang/slang-ir-lower-generic-function.cpp index e930c6cc8..92f00c509 100644 --- a/source/slang/slang-ir-lower-generic-function.cpp +++ b/source/slang/slang-ir-lower-generic-function.cpp @@ -122,7 +122,8 @@ namespace Slang auto paramType = param->getDataType(); if (auto ptrType = as<IRPtrTypeBase>(paramType)) paramType = ptrType->getValueType(); - if (isPointerOfType(paramType->getDataType(), kIROp_RTTIType)) + if (isPointerOfType(paramType->getDataType(), kIROp_RTTIType) || + paramType->op == kIROp_lookup_interface_method) { // Lower into a function parameter of raw pointer type. param->setFullType(builder.getRawPointerType()); @@ -277,6 +278,7 @@ namespace Slang // Update the type of lookupInst to the lowered type of the corresponding interface requirement val. // If the requirement is a function, interfaceRequirementVal will be the lowered function type. + // If the requirement is an associatedtype, interfaceRequirementVal will be Ptr<RTTIObject>. IRInst* interfaceRequirementVal = nullptr; auto witnessTableType = cast<IRWitnessTableType>(lookupInst->getWitnessTable()->getDataType()); auto interfaceType = maybeLowerInterfaceType(cast<IRInterfaceType>(witnessTableType->getConformanceType())); diff --git a/source/slang/slang-ir-lower-generic-var.cpp b/source/slang/slang-ir-lower-generic-var.cpp index 0c45e8e38..8799dea6b 100644 --- a/source/slang/slang-ir-lower-generic-var.cpp +++ b/source/slang/slang-ir-lower-generic-var.cpp @@ -19,12 +19,15 @@ namespace Slang void processVarInst(IRInst* varInst) { // We process only var declarations that have type - // `Ptr<IRParam>`. + // `Ptr<IRParam>` or `Ptr<IRLookupInterfaceMethod>`. + // // Due to the processing of `lowerGenericFunction`, // A local variable of generic type now appears as - // `var X:Ptr<irParam:Ptr<RTTIType>>` + // `var X:Ptr<y:Ptr<RTTIType>>`, + // where y can be an IRParam if it is a generic type, + // or an `lookup_interface_method` if it is an associated type. // We match this pattern and turn this inst into - // `X:RawPtr = alloca(rtti_extract_size(irParam))` + // `X:RTTIPtr(irParam) = alloca(irParam)` auto varTypeInst = varInst->getDataType(); if (!varTypeInst) return; @@ -34,7 +37,7 @@ namespace Slang // `varTypeParam` represents a pointer to the RTTI object. auto varTypeParam = ptrType->getValueType(); - if (varTypeParam->op != kIROp_Param) + if (varTypeParam->op != kIROp_Param && varTypeParam->op != kIROp_lookup_interface_method) return; if (!varTypeParam->getDataType()) return; diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 175f9264f..322489d89 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -3437,6 +3437,11 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> UNREACHABLE_RETURN(LoweredValInfo()); } + LoweredValInfo visitAssocTypeDecl(AssocTypeDecl* decl) + { + return LoweredValInfo::simple(context->irBuilder->getAssociatedType()); + } + LoweredValInfo visitAssignExpr(AssignExpr* expr) { // Because our representation of lowered "values" diff --git a/tests/compute/dynamic-dispatch-7.slang b/tests/compute/dynamic-dispatch-7.slang new file mode 100644 index 000000000..62ab94e48 --- /dev/null +++ b/tests/compute/dynamic-dispatch-7.slang @@ -0,0 +1,85 @@ +//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -allow-dynamic-code +//DISABLE_TEST(compute):COMPARE_COMPUTE:-cuda -xslang -allow-dynamic-code + +// Test dynamic dispatch code gen for associated-typed return values +// and local variables. +// TODO: test arguments of associated type. + +interface IAssoc +{ + int Compute(); +} + +interface IInterface +{ + associatedtype TAssoc : IAssoc; + + [mutating] + void SetVal(int inVal); + + TAssoc GetAssoc(); +}; + +T.TAssoc CreateT_Assoc_Inner<T:IInterface>(int inVal) +{ + T obj; + obj.SetVal(inVal); + return obj.GetAssoc(); +} + +T.TAssoc CreateT_Assoc<T:IInterface>(int inVal) +{ + return CreateT_Assoc_Inner<T>(inVal); +} + +T CreateT<T:IInterface>(int inVal) +{ + T obj; + obj.SetVal(inVal); + return obj; +} + +struct Impl : IInterface +{ + struct TAssoc : IAssoc + { + int base; + int Compute() + { + return base; + } + }; + + TAssoc assoc; + [mutating] + void SetVal(int inVal) + { + assoc.base = inVal; + } + + TAssoc GetAssoc() + { + return assoc; + } +}; + +int test() +{ + var obj = CreateT<Impl>(2); + var obj2 = CreateT_Assoc<Impl>(1); + // TODO: compiler crash if type parameter is missing. + // (hitting lowering logic of TypeEqualityWitness) + var obj3 = CreateT_Assoc_Inner<Impl>(1); + return obj.GetAssoc().Compute() + obj2.Compute() + obj3.Compute(); +} + +//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer : register(u0); + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + int outVal = test(); + outputBuffer[tid] = outVal; +} diff --git a/tests/compute/dynamic-dispatch-7.slang.expected.txt b/tests/compute/dynamic-dispatch-7.slang.expected.txt new file mode 100644 index 000000000..e785149d2 --- /dev/null +++ b/tests/compute/dynamic-dispatch-7.slang.expected.txt @@ -0,0 +1,4 @@ +4 +4 +4 +4 |
