diff options
| author | Yong He <yonghe@outlook.com> | 2023-04-12 22:58:22 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-04-12 22:58:22 -0700 |
| commit | ca7bf79df3a3f5f4494912cb0572c36662755b9d (patch) | |
| tree | 64b14034326be8285c0265e74ad3ed11e29ff062 | |
| parent | 12ec9b832fc74faba7162e54e04f7f48878ea88e (diff) | |
Combine lookupWitness lowering with specialization. (#2794)
31 files changed, 735 insertions, 60 deletions
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj index c36dd269d..875c6d399 100644 --- a/build/visual-studio/slang/slang.vcxproj +++ b/build/visual-studio/slang/slang.vcxproj @@ -411,6 +411,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla <ClInclude Include="..\..\..\source\slang\slang-ir-lower-reinterpret.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-result-type.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-tuple-types.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-lower-witness-lookup.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-marshal-native-call.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-metadata.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-missing-return.h" />
@@ -602,6 +603,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla <ClCompile Include="..\..\..\source\slang\slang-ir-lower-reinterpret.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-result-type.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-tuple-types.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-lower-witness-lookup.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-marshal-native-call.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-metadata.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-missing-return.cpp" />
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters index f353468e4..9ef8a68b3 100644 --- a/build/visual-studio/slang/slang.vcxproj.filters +++ b/build/visual-studio/slang/slang.vcxproj.filters @@ -321,6 +321,9 @@ <ClInclude Include="..\..\..\source\slang\slang-ir-lower-tuple-types.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-lower-witness-lookup.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-marshal-native-call.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -890,6 +893,9 @@ <ClCompile Include="..\..\..\source\slang\slang-ir-lower-tuple-types.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-lower-witness-lookup.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-marshal-native-call.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index 1fed2d52a..72f1764df 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -69,8 +69,8 @@ Type* Type::getCanonicalType() // TODO(tfoley): worry about thread safety here? auto canType = et->createCanonicalType(); et->canonicalType = canType; - - SLANG_ASSERT(et->canonicalType); + if (!et->canonicalType) + return getASTBuilder()->getErrorType(); } return et->canonicalType; } @@ -481,7 +481,9 @@ Type* NamedExpressionType::_createCanonicalTypeOverride() { if (!innerType) innerType = getType(m_astBuilder, declRef); - return innerType->getCanonicalType(); + if (innerType) + return innerType->getCanonicalType(); + return nullptr; } HashCode NamedExpressionType::_getHashCodeOverride() diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index ca141bd7a..e2948561c 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -374,7 +374,9 @@ Result linkAndOptimizeIR( dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-SPECIALIZE"); if (!codeGenContext->isSpecializationDisabled()) - changed |= specializeModule(irModule); + changed |= specializeModule(irModule, codeGenContext->getSink()); + if (codeGenContext->getSink()->getErrorCount() != 0) + return SLANG_FAIL; dumpIRIfEnabled(codeGenContext, irModule, "AFTER-SPECIALIZE"); eliminateDeadCode(irModule); diff --git a/source/slang/slang-ir-any-value-marshalling.cpp b/source/slang/slang-ir-any-value-marshalling.cpp index 79a41f0d5..27bba3fde 100644 --- a/source/slang/slang-ir-any-value-marshalling.cpp +++ b/source/slang/slang-ir-any-value-marshalling.cpp @@ -666,6 +666,7 @@ namespace Slang case kIROp_IntType: case kIROp_FloatType: case kIROp_UIntType: + case kIROp_BoolType: return alignUp(offset, 4) + 4; case kIROp_UInt64Type: case kIROp_Int64Type: @@ -755,6 +756,21 @@ namespace Slang size += kRTTIHeaderSize; return alignUp(offset, 4) + alignUp((SlangInt)size, 4); } + case kIROp_AssociatedType: + { + auto associatedType = cast<IRAssociatedType>(type); + SlangInt maxSize = 0; + for (UInt i = 0; i < associatedType->getOperandCount(); i++) + maxSize = Math::Max(maxSize, _getAnyValueSizeRaw((IRType*)associatedType->getOperand(i), offset)); + return maxSize; + } + case kIROp_ThisType: + { + auto thisType = cast<IRThisType>(type); + auto interfaceType = thisType->getConstraintType(); + auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc); + return alignUp(offset, 4) + alignUp((SlangInt)size, 4); + } case kIROp_ExtractExistentialType: { auto existentialValue = type->getOperand(0); diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp index fe3c70bde..80ee37988 100644 --- a/source/slang/slang-ir-autodiff-cfg-norm.cpp +++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp @@ -410,6 +410,7 @@ struct CFGNormalizationPass } case kIROp_loop: + case kIROp_Switch: { auto breakBlock = normalizeBreakableRegion(terminator); diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 6c3d6a934..2be394537 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -877,19 +877,6 @@ InstPair ForwardDiffTranscriber::transcribeConst(IRBuilder*, IRInst* origInst) return InstPair(nullptr, nullptr); } -IRInst* ForwardDiffTranscriber::findInterfaceRequirement(IRInterfaceType* type, IRInst* key) -{ - for (UInt i = 0; i < type->getOperandCount(); i++) - { - if (auto req = as<IRInterfaceRequirementEntry>(type->getOperand(i))) - { - if (req->getRequirementKey() == key) - return req->getRequirementVal(); - } - } - return nullptr; -} - InstPair ForwardDiffTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize) { auto primalBase = findOrTranscribePrimalInst(builder, origSpecialize->getBase()); @@ -1810,6 +1797,7 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_CastIntToFloat: case kIROp_CastFloatToInt: case kIROp_DetachDerivative: + case kIROp_GetSequentialID: return trascribeNonDiffInst(builder, origInst); // A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value, diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index 8fd271fd8..a9193acbe 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -62,8 +62,6 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase InstPair transcribeConst(IRBuilder* builder, IRInst* origInst); - IRInst* findInterfaceRequirement(IRInterfaceType* type, IRInst* key); - InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize); InstPair transcribeFieldExtract(IRBuilder* builder, IRInst* originalInst); diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index 04d5560d9..363572c86 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -771,6 +771,7 @@ static bool shouldStoreInst(IRInst* inst) case kIROp_ExtractExistentialType: case kIROp_ExtractExistentialWitnessTable: case kIROp_undefined: + case kIROp_GetSequentialID: return false; case kIROp_GetElement: case kIROp_FieldExtract: diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 9cbea7873..81c0ab235 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -645,19 +645,6 @@ IRType* AutoDiffTranscriberBase::tryGetDiffPairType(IRBuilder* builder, IRType* return nullptr; } -IRInst* AutoDiffTranscriberBase::findInterfaceRequirement(IRInterfaceType* type, IRInst* key) -{ - for (UInt i = 0; i < type->getOperandCount(); i++) - { - if (auto req = as<IRInterfaceRequirementEntry>(type->getOperand(i))) - { - if (req->getRequirementKey() == key) - return req->getRequirementVal(); - } - } - return nullptr; -} - InstPair AutoDiffTranscriberBase::transcribeParam(IRBuilder* builder, IRParam* origParam) { auto primalDataType = findOrTranscribePrimalInst(builder, origParam->getDataType()); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h index d5070689e..86af2fb8e 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.h +++ b/source/slang/slang-ir-autodiff-transcriber-base.h @@ -111,8 +111,6 @@ struct AutoDiffTranscriberBase IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType); - IRInst* findInterfaceRequirement(IRInterfaceType* type, IRInst* key); - IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType); InstPair trascribeNonDiffInst(IRBuilder* builder, IRInst* origInst); diff --git a/source/slang/slang-ir-generics-lowering-context.cpp b/source/slang/slang-ir-generics-lowering-context.cpp index 8e726cba1..0dbc84e51 100644 --- a/source/slang/slang-ir-generics-lowering-context.cpp +++ b/source/slang/slang-ir-generics-lowering-context.cpp @@ -339,7 +339,7 @@ namespace Slang } } - List<IRWitnessTable*> SharedGenericsLoweringContext::getWitnessTablesFromInterfaceType(IRInst* interfaceType) + List<IRWitnessTable*> getWitnessTablesFromInterfaceType(IRModule* module, IRInst* interfaceType) { List<IRWitnessTable*> witnessTables; for (auto globalInst : module->getGlobalInsts()) @@ -354,6 +354,11 @@ namespace Slang return witnessTables; } + List<IRWitnessTable*> SharedGenericsLoweringContext::getWitnessTablesFromInterfaceType(IRInst* interfaceType) + { + return Slang::getWitnessTablesFromInterfaceType(module, interfaceType); + } + IRIntegerValue SharedGenericsLoweringContext::getInterfaceAnyValueSize(IRInst* type, SourceLoc usageLoc) { SLANG_UNUSED(usageLoc); diff --git a/source/slang/slang-ir-generics-lowering-context.h b/source/slang/slang-ir-generics-lowering-context.h index 0c07a93c9..8030751d0 100644 --- a/source/slang/slang-ir-generics-lowering-context.h +++ b/source/slang/slang-ir-generics-lowering-context.h @@ -74,7 +74,7 @@ namespace Slang IRInst* maybeEmitRTTIObject(IRInst* typeInst); static IRIntegerValue getInterfaceAnyValueSize(IRInst* type, SourceLoc usageLoc); - IRType* lowerAssociatedType(IRBuilder* builder, IRInst* type); + static IRType* lowerAssociatedType(IRBuilder* builder, IRInst* type); IRType* lowerType(IRBuilder* builder, IRInst* paramType, const Dictionary<IRInst*, IRInst*>& typeMapping, IRType* concreteType); @@ -86,20 +86,12 @@ namespace Slang // Get a list of all witness tables whose conformance type is `interfaceType`. List<IRWitnessTable*> getWitnessTablesFromInterfaceType(IRInst* interfaceType); - IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key) - { - for (auto entry : table->getEntries()) - { - if (entry->getRequirementKey() == key) - return entry->getSatisfyingVal(); - } - return nullptr; - } - - /// Does the given `concreteType` fit within the any-value size deterined by `interfaceType`? + /// Does the given `concreteType` fit within the any-value size deterined by `interfaceType`? bool doesTypeFitInAnyValue(IRType* concreteType, IRInterfaceType* interfaceType, IRIntegerValue* outTypeSize = nullptr, IRIntegerValue* outLimit = nullptr); }; + List<IRWitnessTable*> getWitnessTablesFromInterfaceType(IRModule* module, IRInst* interfaceType); + bool isPolymorphicType(IRInst* typeInst); // Returns true if typeInst represents a type and should be lowered into diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 8d06c6970..a8ec5a66f 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -740,7 +740,8 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(AnyValueSizeDecoration, AnyValueSize, 1, 0) INST(SpecializeDecoration, SpecializeDecoration, 0, 0) INST(SequentialIDDecoration, SequentialIDDecoration, 1, 0) - + INST(StaticRequirementDecoration, StaticRequirementDecoration, 0, 0) + INST(DispatchFuncDecoration, DispatchFuncDecoration, 1, 0) INST(TypeConstraintDecoration, TypeConstraintDecoration, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index a19896287..bf0a5d4cd 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -198,6 +198,14 @@ struct IRAnyValueSizeDecoration : IRDecoration } }; +struct IRDispatchFuncDecoration : IRDecoration +{ + enum { kOp = kIROp_DispatchFuncDecoration }; + IR_LEAF_ISA(DispatchFuncDecoration) + + IRInst* getFunc() { return getOperand(0); } +}; + struct IRSpecializeDecoration : IRDecoration { enum { kOp = kIROp_SpecializeDecoration }; @@ -265,6 +273,7 @@ IR_SIMPLE_DECORATION(VulkanHitAttributesDecoration) /// to it. IR_SIMPLE_DECORATION(VulkanHitObjectAttributesDecoration) + struct IRRequireGLSLVersionDecoration : IRDecoration { enum { kOp = kIROp_RequireGLSLVersionDecoration }; @@ -326,6 +335,7 @@ IR_SIMPLE_DECORATION(KeepAliveDecoration) IR_SIMPLE_DECORATION(RequiresNVAPIDecoration) IR_SIMPLE_DECORATION(NoInlineDecoration) IR_SIMPLE_DECORATION(AlwaysFoldIntoUseSiteDecoration) +IR_SIMPLE_DECORATION(StaticRequirementDecoration) struct IRNVAPIMagicDecoration : IRDecoration { @@ -3913,6 +3923,16 @@ public: addDecoration(inst, kIROp_AnyValueSizeDecoration, getIntValue(getIntType(), value)); } + void addDispatchFuncDecoration(IRInst* inst, IRInst* func) + { + addDecoration(inst, kIROp_DispatchFuncDecoration, func); + } + + void addStaticRequirementDecoration(IRInst* inst) + { + addDecoration(inst, kIROp_StaticRequirementDecoration); + } + void addSpecializeDecoration(IRInst* inst) { addDecoration(inst, kIROp_SpecializeDecoration); diff --git a/source/slang/slang-ir-lower-generic-function.cpp b/source/slang/slang-ir-lower-generic-function.cpp index e45b20563..12be27f07 100644 --- a/source/slang/slang-ir-lower-generic-function.cpp +++ b/source/slang/slang-ir-lower-generic-function.cpp @@ -28,14 +28,18 @@ namespace Slang auto genericParent = as<IRGeneric>(genericValue); SLANG_ASSERT(genericParent); SLANG_ASSERT(genericParent->getDataType()); - auto func = as<IRFunc>(findGenericReturnVal(genericParent)); + auto genericRetVal = findGenericReturnVal(genericParent); + auto func = as<IRFunc>(genericRetVal); if (!func) { // Nested generic functions are supposed to be flattened before entering // this pass. The reason we are still seeing them must be that they are // intrinsic functions. In this case we ignore the function. - SLANG_ASSERT(findInnerMostGenericReturnVal(genericParent) - ->findDecoration<IRTargetIntrinsicDecoration>() != nullptr); + if (as<IRGeneric>(genericRetVal)) + { + SLANG_ASSERT(findInnerMostGenericReturnVal(genericParent) + ->findDecoration<IRTargetIntrinsicDecoration>() != nullptr); + } return genericValue; } SLANG_ASSERT(func); diff --git a/source/slang/slang-ir-lower-witness-lookup.cpp b/source/slang/slang-ir-lower-witness-lookup.cpp new file mode 100644 index 000000000..1a1900cd2 --- /dev/null +++ b/source/slang/slang-ir-lower-witness-lookup.cpp @@ -0,0 +1,407 @@ +// slang-ir-lower-generic-existential.cpp + +#include "slang-ir-lower-witness-lookup.h" +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-ir-util.h" +#include "slang-ir-clone.h" +#include "slang-ir-generics-lowering-context.h" + +namespace Slang +{ + +struct WitnessLookupLoweringContext +{ + IRModule* module; + DiagnosticSink* sink; + + Dictionary<IRStructKey*, IRInst*> witnessDispatchFunctions; + + void init() + { + // Reconstruct the witness dispatch functions map. + for (auto inst : module->getGlobalInsts()) + { + if (auto key = as<IRStructKey>(inst)) + { + for (auto decor : key->getDecorations()) + { + if (auto witnessDispatchFunc = as<IRDispatchFuncDecoration>(decor)) + { + witnessDispatchFunctions.Add(key, witnessDispatchFunc->getFunc()); + } + } + } + } + } + + bool hasAssocType(IRInst* type) + { + if (!type) + return false; + if (type->getOp() == kIROp_AssociatedType) + return true; + for (UInt i = 0; i < type->getOperandCount(); i++) + { + if (hasAssocType(type->getOperand(i))) + return true; + } + return false; + } + + IRType* translateType(IRBuilder builder, IRInst* type) + { + if (!type) + return nullptr; + if (auto genType = as<IRGeneric>(type)) + { + IRCloneEnv cloneEnv; + builder.setInsertBefore(genType); + auto newGeneric = as<IRGeneric>(cloneInst(&cloneEnv, &builder, genType)); + newGeneric->setFullType(builder.getGenericKind()); + auto retVal = findGenericReturnVal(newGeneric); + builder.setInsertBefore(retVal); + auto translated = translateType(builder, retVal); + retVal->replaceUsesWith(translated); + return (IRType*)newGeneric; + } + else if (auto thisType = as<IRThisType>(type)) + { + return (IRType*)thisType->getConstraintType(); + } + else if (auto assocType = as<IRAssociatedType>(type)) + { + return assocType; + } + + if (as<IRBasicType>(type)) + return (IRType*)type; + + switch (type->getOp()) + { + case kIROp_Param: + case kIROp_VectorType: + case kIROp_MatrixType: + case kIROp_StructType: + case kIROp_ClassType: + case kIROp_InterfaceType: + return (IRType*)type; + default: + { + List<IRInst*> translatedOperands; + for (UInt i = 0; i < type->getOperandCount(); i++) + { + translatedOperands.add(translateType(builder, type->getOperand(i))); + } + auto translated = builder.emitIntrinsicInst( + builder.getTypeKind(), + type->getOp(), + (UInt)translatedOperands.getCount(), + translatedOperands.getBuffer()); + return (IRType*)translated; + } + } + } + + IRInst* findOrCreateDispatchFunc(IRLookupWitnessMethod* lookupInst) + { + IRInst* func = nullptr; + auto requirementKey = cast<IRStructKey>(lookupInst->getRequirementKey()); + if (witnessDispatchFunctions.TryGetValue(requirementKey, func)) + { + return func; + } + + auto witnessTableOperand = lookupInst->getWitnessTable(); + auto witnessTableType = as<IRWitnessTableTypeBase>(witnessTableOperand->getDataType()); + SLANG_RELEASE_ASSERT(witnessTableType); + auto interfaceType = as<IRInterfaceType>(unwrapAttributedType(witnessTableType->getConformanceType())); + SLANG_RELEASE_ASSERT(interfaceType); + if (interfaceType->findDecoration<IRComInterfaceDecoration>()) + return nullptr; + auto requirementType = findInterfaceRequirement(interfaceType, requirementKey); + SLANG_RELEASE_ASSERT(requirementType); + + // We only lower non-static function requirement lookups for now. + // Our front end will stick a StaticRequirementDecoration on the IRStructKey for static member requirements. + if (lookupInst->getRequirementKey()->findDecoration<IRStaticRequirementDecoration>()) + return nullptr; + auto interfaceMethodFuncType = as<IRFuncType>(getResolvedInstForDecorations(requirementType)); + if (interfaceMethodFuncType) + { + // Detect cases that we currently does not support and exit. + + // If this is a non static function requirement, we should + // make sure the first parameter is the interface type. If not, something has gone wrong. + if (interfaceMethodFuncType->getParamCount() == 0) + return nullptr; + if (!as<IRThisType>(unwrapAttributedType(interfaceMethodFuncType->getParamType(0)))) + return nullptr; + + // The function has any associated type parameter, we currently can't lower it early in this pass. + // We will lower it in the catch all generic lowering pass. + for (UInt i = 1; i < interfaceMethodFuncType->getParamCount(); i++) + { + if (hasAssocType(interfaceMethodFuncType->getParamType(i))) + return nullptr; + } + + // If return type is a composite type containing an assoc type, we won't lower it now. + // Supporting general use of assoc type is possible, but would require more complex logic + // in this pass to marshal things to and from existential types. + if (interfaceMethodFuncType->getResultType()->getOp() != kIROp_AssociatedType && + hasAssocType(interfaceMethodFuncType->getResultType())) + return nullptr; + } + else + { + return nullptr; + } + + + IRBuilder builder(module); + builder.setInsertBefore(getParentFunc(lookupInst)); + + // Create a dispatch func. + IRFunc* dispatchFunc = nullptr; + IRFuncType* dispatchFuncType = nullptr; + IRGeneric* parentGeneric = nullptr; + + // If requirementType is a generic, we need to create a new generic that has the same parameters. + if (auto genericRequirement = as<IRGeneric>(requirementType)) + { + IRCloneEnv cloneEnv; + parentGeneric = as<IRGeneric>(cloneInst(&cloneEnv, &builder, genericRequirement)); + + auto returnInst = as<IRReturn>(parentGeneric->getFirstBlock()->getLastInst()); + SLANG_RELEASE_ASSERT(returnInst); + builder.setInsertBefore(returnInst); + auto oldDispatchFuncType = as<IRFuncType>(returnInst->getVal()); + if (!oldDispatchFuncType) + return nullptr; + + dispatchFuncType = as<IRFuncType>(translateType(builder, oldDispatchFuncType)); + + SLANG_RELEASE_ASSERT(dispatchFuncType); + + dispatchFunc = builder.createFunc(); + dispatchFunc->setFullType(dispatchFuncType); + builder.emitReturn(dispatchFunc); + returnInst->removeAndDeallocate(); + + parentGeneric->setFullType(translateType(builder, requirementType)); + } + else + { + dispatchFuncType = as<IRFuncType>(translateType(builder, requirementType)); + dispatchFunc = builder.createFunc(); + dispatchFunc->setFullType(dispatchFuncType); + } + + // We need to inline this function if the requirement is differentiable, + // so that the autodiff pass doesn't need to handle the dispatch function. + if (requirementKey->findDecoration<IRForwardDerivativeDecoration>()|| + requirementKey->findDecoration<IRBackwardDerivativeDecoration>()) + { + builder.addForceInlineDecoration(dispatchFunc); + } + + // Collect generic params. + List<IRInst*> genericParams; + if (parentGeneric) + { + for (auto param : parentGeneric->getParams()) + genericParams.add(param); + } + + // Emit the body of the dispatch func. + builder.setInsertInto(dispatchFunc); + auto firstBlock = builder.emitBlock(); + auto firstBlockBuilder = builder; + // Emit parameters. + List<IRInst*> params; + + for (UInt i = 0; i < dispatchFuncType->getParamCount(); i++) + { + params.add(builder.emitParam(dispatchFuncType->getParamType(i))); + } + auto witness = builder.emitExtractExistentialWitnessTable(params[0]); + + auto witnessTables = getWitnessTablesFromInterfaceType(module, interfaceType); + if (witnessTables.getCount() == 0) + { + // If there is no witness table, we should emit an error. + sink->diagnose(lookupInst, Diagnostics::noTypeConformancesFoundForInterface, interfaceType); + return nullptr; + } + else + { + List<IRInst*> cases; + for (auto witnessTable : witnessTables) + { + IRBlock* block = builder.emitBlock(); + auto caseValue = firstBlockBuilder.emitGetSequentialIDInst(witnessTable); + cases.add(caseValue); + cases.add(block); + auto entry = findWitnessTableEntry(witnessTable, requirementKey); + SLANG_RELEASE_ASSERT(entry); + // If the entry is a generic, we need to specialize it. + if (auto genericEntry = as<IRGeneric>(entry)) + { + auto specializedFuncType = builder.emitSpecializeInst( + builder.getTypeKind(), + entry->getFullType(), + (UInt)genericParams.getCount(), + genericParams.getBuffer()); + entry = builder.emitSpecializeInst( + (IRType*)specializedFuncType, + entry, + (UInt)genericParams.getCount(), + genericParams.getBuffer()); + } + auto args = params; + // Reinterpret the first arg into the concrete type. + args[0] = builder.emitReinterpret(witnessTable->getConcreteType(), + builder.emitExtractExistentialValue(builder.emitExtractExistentialType(args[0]), args[0])); + + auto calleeFuncType = as<IRFuncType>(getResolvedInstForDecorations(entry)->getFullType()); + auto callReturnType = calleeFuncType->getResultType(); + if (callReturnType->getParent() != module->getModuleInst()) + { + // the return type is dependent on generic parameter, use the type from dispatchFuncType instead. + callReturnType = dispatchFuncType->getResultType(); + } + + auto call = builder.emitCallInst( + callReturnType, + entry, + (UInt)args.getCount(), + args.getBuffer()); + // If result type is an associated type, we need to pack it into an anyValue. + if (as<IRAssociatedType>(dispatchFuncType->getResultType())) + { + call = builder.emitPackAnyValue(dispatchFuncType->getResultType(), call); + } + builder.emitReturn(call); + } + builder.setInsertInto(firstBlock); + if (witnessTables.getCount() == 1) + { + builder.emitBranch((IRBlock*)cases[1]); + } + else + { + auto witnessId = firstBlockBuilder.emitGetSequentialIDInst(witness); + auto breakLabel = builder.emitBlock(); + builder.emitUnreachable(); + firstBlockBuilder.emitSwitch( + witnessId, + breakLabel, + (IRBlock*)cases.getLast(), + (UInt)(cases.getCount() - 2), + cases.getBuffer()); + } + } + + // Stick a decoration on the requirement key so we can find the dispatch func later. + IRInst* resultValue = parentGeneric ? (IRInst*)parentGeneric : dispatchFunc; + builder.addDispatchFuncDecoration(requirementKey, resultValue); + + // Register the dispatch func to witnessDispatchFunctions dictionary. + witnessDispatchFunctions[requirementKey] = resultValue; + + return resultValue; + } + + void rewriteCallSite(IRCall* call, IRInst* dispatchFunc, IRInst* existentialObject) + { + SLANG_RELEASE_ASSERT(call->getArgCount() != 0); + call->setOperand(0, dispatchFunc); + call->setOperand(1, existentialObject); + } + + bool processWitnessLookup(IRLookupWitnessMethod* lookupInst) + { + auto witnessTableOperand = lookupInst->getWitnessTable(); + auto extractInst = as<IRExtractExistentialWitnessTable>(witnessTableOperand); + if (!extractInst) + return false; + auto dispatchFunc = findOrCreateDispatchFunc(lookupInst); + if (!dispatchFunc) + return false; + bool changed = false; + auto existentialObject = extractInst->getOperand(0); + + IRBuilder builder(lookupInst); + builder.setInsertBefore(lookupInst); + traverseUses(lookupInst, [&](IRUse* use) + { + if (auto specialize = as<IRSpecialize>(use->getUser())) + { + List<IRInst*> args; + for (UInt i = 0; i < specialize->getArgCount(); i++) + args.add(specialize->getArg(i)); + auto specializedType = builder.emitSpecializeInst( + builder.getTypeKind(), + dispatchFunc->getFullType(), + (UInt)args.getCount(), + args.getBuffer()); + auto newSpecialize = builder.emitSpecializeInst( + (IRType*)specializedType, + dispatchFunc, + (UInt)args.getCount(), + args.getBuffer()); + traverseUses(specialize, [&](IRUse* specializeUse) + { + if (auto call = as<IRCall>(specializeUse->getUser())) + { + changed = true; + rewriteCallSite(call, newSpecialize, existentialObject); + } + }); + } + else if (auto call = as<IRCall>(use->getUser())) + { + changed = true; + rewriteCallSite(call, dispatchFunc, existentialObject); + } + }); + return changed; + } + + bool processFunc(IRFunc* func) + { + bool changed = false; + for (auto bb : func->getBlocks()) + { + for (auto inst : bb->getChildren()) + { + if (auto witnessLookupInst = as<IRLookupWitnessMethod>(inst)) + { + changed |= processWitnessLookup(witnessLookupInst); + } + } + } + return changed; + } +}; + +bool lowerWitnessLookup(IRModule* module, DiagnosticSink* sink) +{ + bool changed = false; + WitnessLookupLoweringContext context; + context.module = module; + context.sink = sink; + context.init(); + + for (auto inst : module->getGlobalInsts()) + { + // Process all fully specialized functions and look for + // witness lookup instructions. If we see a lookup for a non-static function, + // create a dispatch function and replace the lookup with a call to the dispatch function. + if (auto func = as<IRFunc>(inst)) + changed |= context.processFunc(func); + } + return changed; +} +} diff --git a/source/slang/slang-ir-lower-witness-lookup.h b/source/slang/slang-ir-lower-witness-lookup.h new file mode 100644 index 000000000..4ae447210 --- /dev/null +++ b/source/slang/slang-ir-lower-witness-lookup.h @@ -0,0 +1,16 @@ +// slang-ir-lower-witness-lookup.h +#pragma once + +namespace Slang +{ + struct IRModule; + class DiagnosticSink; + + /// Lower calls to a witness lookup into a call to a dispatch function. + /// For example, if we see call(witnessLookup(wt, key)), we will create a + /// dispatch function that calls into different implementations based on witness table + /// ID. The dispatch function will be called instead of witnessLookup. + bool lowerWitnessLookup(IRModule* module, DiagnosticSink* sink); + +} + diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index ab3f0ceab..2244b480a 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -827,6 +827,50 @@ struct PeepholeContext : InstPassBase } } break; + case kIROp_swizzle: + { + // If we see a swizzle(makeVector) then we can replace it with the values from makeVector. + auto makeVector = inst->getOperand(0); + if (makeVector->getOp() != kIROp_MakeVector) + break; + auto swizzle = as<IRSwizzle>(inst); + List<IRInst*> vals; + auto vectorType = as<IRVectorType>(makeVector->getDataType()); + auto vectorSize = as<IRIntLit>(vectorType->getElementCount()); + if (!vectorSize) + break; + if (makeVector->getOperandCount() != (UInt)vectorSize->getValue()) + break; + for (UInt i = 0; i < swizzle->getElementCount(); i++) + { + auto index = swizzle->getElementIndex(i); + auto intLitIndex = as<IRIntLit>(index); + if (!intLitIndex) + return; + if (intLitIndex->getValue() < (Int)makeVector->getOperandCount()) + vals.add(makeVector->getOperand((UInt)intLitIndex->getValue())); + else + return; + } + if (vals.getCount() == 1) + { + inst->replaceUsesWith(vals[0]); + maybeRemoveOldInst(inst); + changed = true; + } + else + { + IRBuilder builder(module); + builder.setInsertBefore(inst); + auto newMakeVector = builder.emitMakeVector( + swizzle->getDataType(), (UInt)vals.getCount(), vals.getBuffer()); + inst->replaceUsesWith(newMakeVector); + maybeRemoveOldInst(inst); + changed = true; + } + break; + } + default: break; } diff --git a/source/slang/slang-ir-specialize-dispatch.cpp b/source/slang/slang-ir-specialize-dispatch.cpp index 4d9b15bbd..074aedb06 100644 --- a/source/slang/slang-ir-specialize-dispatch.cpp +++ b/source/slang/slang-ir-specialize-dispatch.cpp @@ -3,6 +3,7 @@ #include "slang-ir-generics-lowering-context.h" #include "slang-ir-insts.h" #include "slang-ir.h" +#include "slang-ir-util.h" namespace Slang { @@ -112,7 +113,7 @@ IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, builder->setInsertInto(defaultBlock); } - auto callee = sharedContext->findWitnessTableEntry(witnessTable, requirementKey); + auto callee = findWitnessTableEntry(witnessTable, requirementKey); SLANG_ASSERT(callee); auto specializedCallInst = builder->emitCallInst(callInst->getFullType(), callee, params); if (callInst->getDataType()->getOp() == kIROp_VoidType) diff --git a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp index e68ac5b73..1a1186cda 100644 --- a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp +++ b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp @@ -75,7 +75,7 @@ struct AssociatedTypeLookupSpecializationContext builder.setInsertInto(defaultBlock); } - auto resultWitnessTable = sharedContext->findWitnessTableEntry(witnessTable, key); + auto resultWitnessTable = findWitnessTableEntry(witnessTable, key); auto resultWitnessTableIDDecoration = resultWitnessTable->findDecoration<IRSequentialIDDecoration>(); SLANG_ASSERT(resultWitnessTableIDDecoration); diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index d2e042363..eb3677653 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -5,6 +5,7 @@ #include "slang-ir-clone.h" #include "slang-ir-insts.h" #include "slang-ir-ssa-simplification.h" +#include "slang-ir-lower-witness-lookup.h" namespace Slang { @@ -43,6 +44,7 @@ struct SpecializationContext // For convenience, we will keep a pointer to the module // we are specializing. IRModule* module; + DiagnosticSink* sink; bool changed = false; @@ -932,7 +934,11 @@ struct SpecializationContext } else { - break; + // If we run out of specialization opportunities, consider + // lower lookupWitnessMethod insts into dynamic dispatch calls. + iterChanged = lowerWitnessLookup(module, sink); + if (!iterChanged || sink->getErrorCount()) + break; } } @@ -1323,7 +1329,6 @@ struct SpecializationContext IRInst* curInst = localWorkList.getLast(); localWorkList.removeLast(); - processedInsts.Remove(curInst); switch (curInst->getOp()) { @@ -2329,10 +2334,12 @@ struct SpecializationContext }; bool specializeModule( - IRModule* module) + IRModule* module, + DiagnosticSink* sink) { SpecializationContext context; context.module = module; + context.sink = sink; context.processModule(); return context.changed; } @@ -2349,6 +2356,7 @@ void finalizeSpecialization(IRModule* module) case kIROp_ExistentialFuncSpecializationDictionary: case kIROp_ExistentialTypeSpecializationDictionary: case kIROp_GenericSpecializationDictionary: + case kIROp_DispatchFuncDecoration: decor->removeAndDeallocate(); break; default: diff --git a/source/slang/slang-ir-specialize.h b/source/slang/slang-ir-specialize.h index 20d65cb67..1feff3e4b 100644 --- a/source/slang/slang-ir-specialize.h +++ b/source/slang/slang-ir-specialize.h @@ -4,10 +4,12 @@ namespace Slang { struct IRModule; +class DiagnosticSink; /// Specialize generic and interface-based code to use concrete types. bool specializeModule( - IRModule* module); + IRModule* module, + DiagnosticSink* sink); void finalizeSpecialization(IRModule* module); diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index c5cebb8b5..83f6735bd 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -698,6 +698,29 @@ bool isPureFunctionalCall(IRCall* call) return false; } +IRInst* findInterfaceRequirement(IRInterfaceType* type, IRInst* key) +{ + for (UInt i = 0; i < type->getOperandCount(); i++) + { + if (auto req = as<IRInterfaceRequirementEntry>(type->getOperand(i))) + { + if (req->getRequirementKey() == key) + return req->getRequirementVal(); + } + } + return nullptr; +} + +IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key) +{ + for (auto entry : table->getEntries()) + { + if (entry->getRequirementKey() == key) + return entry->getSatisfyingVal(); + } + return nullptr; +} + struct GenericChildrenMigrationContextImpl { IRCloneEnv cloneEnv; diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index f8e53c38f..ef7ff47bb 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -194,6 +194,10 @@ void sortBlocksInFunc(IRGlobalValueWithCode* func); // Remove all linkage decorations from func. void removeLinkageDecorations(IRGlobalValueWithCode* func); + +IRInst* findInterfaceRequirement(IRInterfaceType* type, IRInst* key); + +IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key); } #endif diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index f0c30dd3c..ada2e043e 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -7145,6 +7145,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> entry->setRequirementVal(requirementVal); break; } + if (requirementDecl->findModifier<HLSLStaticModifier>()) + { + getBuilder()->addStaticRequirementDecoration(requirementKey); + } } } irInterface->setOperand(entryIndex, entry); @@ -7807,6 +7811,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // to the newly construct generic value. typeBuilder.setInsertBefore(parentGeneric); auto typeGeneric = typeBuilder.emitGeneric(); + typeGeneric->setFullType(typeBuilder.getGenericKind()); typeBuilder.setInsertInto(typeGeneric); typeBuilder.emitBlock(); diff --git a/tests/diagnostics/no-type-conformance.slang.expected b/tests/diagnostics/no-type-conformance.slang.expected index bc38fa7f1..5f5eda6af 100644 --- a/tests/diagnostics/no-type-conformance.slang.expected +++ b/tests/diagnostics/no-type-conformance.slang.expected @@ -1,8 +1,8 @@ result code = -1 standard error = { -tests/diagnostics/no-type-conformance.slang(4): error 50100: No type conformances are found for interface 'IFoo'. Code generation for current target requires at least one implementation type present in the linkage. -interface IFoo - ^~~~ +tests/diagnostics/no-type-conformance.slang(12): error 50100: No type conformances are found for interface 'IFoo'. Code generation for current target requires at least one implementation type present in the linkage. + obj.get(); + ^ } standard output = { } diff --git a/tests/ir/dynamic-generic-method-specialize.slang b/tests/ir/dynamic-generic-method-specialize.slang new file mode 100644 index 000000000..92ce8158e --- /dev/null +++ b/tests/ir/dynamic-generic-method-specialize.slang @@ -0,0 +1,65 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -profile sm_5_0 -output-using-type + +// Test that we can specialize a generic method called through a dynamic interface. + +interface IValue +{ + float getVal(); +} + +struct SimpleVal : IValue +{ + float val; + float getVal() { return val; } +} + +[anyValueSize(16)] +interface IInterface +{ + associatedtype V : IValue; + V run<let N : int>(float arr[N]); +} + +struct Add : IInterface +{ + float base; + typealias V = SimpleVal; + V run<let N : int>(float arr[N]) + { + float sum = base; + for (int i = 0; i < N; i++) + sum += arr[i]; + V rs; + rs.val = sum; + return rs; + } +} + +struct Mul : IInterface +{ + float base; + typealias V = SimpleVal; + V run<let N : int>(float arr[N]) + { + float sum = base; + for (int i = 0; i < N; i++) + sum *= arr[i]; + V rs; + rs.val = sum; + return rs; + } +} + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=gOutputBuffer +RWStructuredBuffer<float> gOutputBuffer; + +//TEST_INPUT:type_conformance Add:IInterface=1 +//TEST_INPUT:type_conformance Mul:IInterface=2 + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + var obj = createDynamicObject<IInterface>(1, 1.0); // Add. + float arr[3] = { 2, 3, 4 }; + gOutputBuffer[0] = obj.run(arr).getVal(); +}
\ No newline at end of file diff --git a/tests/ir/dynamic-generic-method-specialize.slang.expected.txt b/tests/ir/dynamic-generic-method-specialize.slang.expected.txt new file mode 100644 index 000000000..0bc25648a --- /dev/null +++ b/tests/ir/dynamic-generic-method-specialize.slang.expected.txt @@ -0,0 +1,2 @@ +type: float +10.0
\ No newline at end of file diff --git a/tests/language-server/robustness-7.slang b/tests/language-server/robustness-7.slang new file mode 100644 index 000000000..ebc34e078 --- /dev/null +++ b/tests/language-server/robustness-7.slang @@ -0,0 +1,63 @@ +//TEST:LANG_SERVER: +//HOVER:6,15 + +// Test that we can specialize a generic method called through a dynamic interface. + +interface IValue +{ + float getVal(); +} + +struct SimpleVal : IValue +{ + float val; + float getVal() { return val; } +} + +[anyValueSize(16)] +interface IInterface +{ + associatedtype V : IValue; + V run<let N : int>(float arr[N]); +} + +struct Add : IInterface +{ + float base; + typealias V + float run<let N : int>(float arr[N]) + { + float sum = base; + for (int i = 0; i < N; i++) + sum += arr[i]; + return sum; + } +} + +struct Mul : IInterface +{ + float base; + + float run<let N : int>(float arr[N]) + { + float sum = base; + for (int i = 0; i < N; i++) + sum *= arr[i]; + return sum; + } +} + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=gOutputBuffer +RWStructuredBuffer<float> gOutputBuffer; + +//TEST_INPUT:type_conformance Add:IInterface=1 +//TEST_INPUT:type_conformance Mul:IInterface=2 + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + var obj = createDynamicObject<IInterface>(1, 1.0); // Add. + float arr[3] = { 2, 3, 4 }; + gOutputBuffer[0] = obj.run(arr); + +}
\ No newline at end of file diff --git a/tests/language-server/robustness-7.slang.expected.txt b/tests/language-server/robustness-7.slang.expected.txt new file mode 100644 index 000000000..d5f8ed9f8 --- /dev/null +++ b/tests/language-server/robustness-7.slang.expected.txt @@ -0,0 +1,12 @@ +-------- +range: 5,10 - 5,16 +content: +``` +interface IValue +``` + +Test that we can specialize a generic method called through a dynamic interface. + +{REDACTED}.slang(6) + + |
