From 78517dc392f0d2ebba25f0ac3f4d4e004b0f0ab0 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 14 Mar 2025 17:15:36 -0700 Subject: Fix lowering of associated types in generic interfaces (#6600) * Fix lowering of associated types in generic interfaces. * Update diff-assoctype-generic-interface.slang * Fix-up lowering of differentiable witnesses for implicit ops * Update slang-ir-autodiff-transcriber-base.cpp * Fix issue with differentiating type-packs --- source/slang/slang-ast-dump.cpp | 2 +- source/slang/slang-ast-modifier.cpp | 2 +- source/slang/slang-ast-modifier.h | 10 +- source/slang/slang-check-expr.cpp | 28 ++++-- source/slang/slang-check-impl.h | 2 +- .../slang/slang-ir-autodiff-transcriber-base.cpp | 36 +------ source/slang/slang-ir-autodiff.cpp | 16 +-- source/slang/slang-ir-autodiff.h | 9 +- source/slang/slang-lower-to-ir.cpp | 41 +++++--- tests/autodiff/autopybind-printf.slang | 47 +++++++++ .../diff-assoctype-generic-interface.slang | 110 +++++++++++++++++++++ 11 files changed, 234 insertions(+), 69 deletions(-) create mode 100644 tests/autodiff/autopybind-printf.slang create mode 100644 tests/autodiff/diff-assoctype-generic-interface.slang diff --git a/source/slang/slang-ast-dump.cpp b/source/slang/slang-ast-dump.cpp index f6cdf50d8..bd366be19 100644 --- a/source/slang/slang-ast-dump.cpp +++ b/source/slang/slang-ast-dump.cpp @@ -647,7 +647,7 @@ struct ASTDumpContext void dump(SourceLanguage language) { m_writer->emit((int)language); } - void dump(KeyValuePair pair) + void dump(KeyValuePair pair) { m_writer->emit("("); dump(pair.key); diff --git a/source/slang/slang-ast-modifier.cpp b/source/slang/slang-ast-modifier.cpp index 2a245130e..383389c39 100644 --- a/source/slang/slang-ast-modifier.cpp +++ b/source/slang/slang-ast-modifier.cpp @@ -5,7 +5,7 @@ namespace Slang { -const OrderedDictionary& DifferentiableAttribute:: +const OrderedDictionary& DifferentiableAttribute:: getMapTypeToIDifferentiableWitness() { for (Index i = m_mapToIDifferentiableWitness.getCount(); diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index e4d5ccd09..5f9ccb5bb 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1391,25 +1391,25 @@ class DifferentiableAttribute : public Attribute { SLANG_AST_CLASS(DifferentiableAttribute) - List> m_typeToIDifferentiableWitnessMappings; + List> m_typeToIDifferentiableWitnessMappings; - void addType(DeclRefBase* declRef, SubtypeWitness* witness) + void addType(Type* declRef, SubtypeWitness* witness) { getMapTypeToIDifferentiableWitness(); if (m_mapToIDifferentiableWitness.addIfNotExists(declRef, witness)) { m_typeToIDifferentiableWitnessMappings.add( - KeyValuePair(declRef, witness)); + KeyValuePair(declRef, witness)); } } /// Mapping from types to subtype witnesses for conformance to IDifferentiable. - const OrderedDictionary& getMapTypeToIDifferentiableWitness(); + const OrderedDictionary& getMapTypeToIDifferentiableWitness(); SLANG_UNREFLECTED ValSet m_typeRegistrationWorkingSet; private: - OrderedDictionary m_mapToIDifferentiableWitness; + OrderedDictionary m_mapToIDifferentiableWitness; }; class DllImportAttribute : public Attribute diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index b730069b6..2f91a6a77 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1405,14 +1405,12 @@ Type* SemanticsVisitor::getDifferentialType(ASTBuilder* builder, Type* type, Sou return result; } -void SemanticsVisitor::addDifferentiableTypeToDiffTypeRegistry( - DeclRefType* type, - SubtypeWitness* witness) +void SemanticsVisitor::addDifferentiableTypeToDiffTypeRegistry(Type* type, SubtypeWitness* witness) { SLANG_RELEASE_ASSERT(m_parentDifferentiableAttr); if (witness) { - m_parentDifferentiableAttr->addType(type->getDeclRef(), witness); + m_parentDifferentiableAttr->addType(type, witness); } } @@ -1468,14 +1466,14 @@ void SemanticsVisitor::maybeRegisterDifferentiableTypeImplRecursive(ASTBuilder* type, getASTBuilder()->getDifferentiableInterfaceType()))) { - addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness); + addDifferentiableTypeToDiffTypeRegistry(type, subtypeWitness); } if (auto subtypeWitness = as(tryGetInterfaceConformanceWitness( type, getASTBuilder()->getDifferentiableRefInterfaceType()))) { - addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness); + addDifferentiableTypeToDiffTypeRegistry(type, subtypeWitness); } if (auto aggTypeDeclRef = declRefType->getDeclRef().as()) @@ -1515,6 +1513,15 @@ void SemanticsVisitor::maybeRegisterDifferentiableTypeImplRecursive(ASTBuilder* maybeRegisterDifferentiableTypeImplRecursive(builder, typePack->getElementType(i)); return; } + + // General check for types that may not be decl-ref-type, but still have some conformance to + // IDifferentiable/IDifferentiablePtrType + if (auto subtypeWitness = as(tryGetInterfaceConformanceWitness( + type, + getASTBuilder()->getDifferentiableInterfaceType()))) + { + addDifferentiableTypeToDiffTypeRegistry(type, subtypeWitness); + } } @@ -4846,7 +4853,14 @@ Expr* SemanticsVisitor::checkBaseForMemberExpr( auto baseExpr = inBaseExpr; baseExpr = CheckTerm(baseExpr); - return maybeInsertImplicitOpForMemberBase(baseExpr, checkBaseContext, outNeedDeref); + auto resultBaseExpr = + maybeInsertImplicitOpForMemberBase(baseExpr, checkBaseContext, outNeedDeref); + + // We might want to register differentiability on any implicit ops that we add in. + if (this->m_parentFunc && this->m_parentFunc->findModifier()) + maybeRegisterDifferentiableType(getASTBuilder(), resultBaseExpr->type.type); + + return resultBaseExpr; } Expr* SemanticsVisitor::checkGeneralMemberLookupExpr(MemberExpr* expr, Type* baseType) diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index ba3792af7..95716744c 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1512,7 +1512,7 @@ public: /// Registers a type as conforming to IDifferentiable, along with a witness /// describing the relationship. /// - void addDifferentiableTypeToDiffTypeRegistry(DeclRefType* type, SubtypeWitness* witness); + void addDifferentiableTypeToDiffTypeRegistry(Type* type, SubtypeWitness* witness); void maybeRegisterDifferentiableTypeImplRecursive(ASTBuilder* builder, Type* type); // Construct the differential for 'type', if it exists. diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 8356e5f81..d67d75997 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -720,9 +720,7 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I if (auto diffType = differentiateType(builder, originalType)) { - IRInst* diffWitnessTable = nullptr; - IRType* diffOuterType = nullptr; - if (isExistentialType(diffType)) + if (isExistentialType(diffType) && !as(diffType)) { // Emit null differential & pack it into an IDifferentiable existential. @@ -789,25 +787,8 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I return result; } - // Since primalType has a corresponding differential type, we can lookup the - // definition for zero(). - IRInst* zeroMethod = nullptr; - if (auto lookupInterface = as(diffType)) - { - // if the differential type itself comes from a witness lookup, we can just lookup the - // zero method from the same witness table. - auto wt = lookupInterface->getWitnessTable(); - zeroMethod = builder->emitLookupInterfaceMethodInst( - builder->getFuncType(List(), diffType), - wt, - autoDiffSharedContext->zeroMethodStructKey); - builder->markInstAsPrimal(zeroMethod); - } - else - { - zeroMethod = - differentiableTypeConformanceContext.getZeroMethodForType(builder, originalType); - } + auto zeroMethod = + differentiableTypeConformanceContext.getZeroMethodForType(builder, originalType); SLANG_RELEASE_ASSERT(zeroMethod); auto emptyArgList = List(); @@ -815,16 +796,7 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I auto callInst = builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList); builder->markInstAsDifferential(callInst, primalType); - if (diffOuterType && isExistentialType(diffOuterType)) - { - // Need to wrap the result back into an existential. - auto existentialZero = - builder->emitMakeExistential(diffOuterType, callInst, diffWitnessTable); - builder->markInstAsDifferential(existentialZero, primalType); - return existentialZero; - } - else - return callInst; + return callInst; } else { diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index df657476a..f3f32add2 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -1362,9 +1362,10 @@ IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod( IRBuilder* builder, IRType* origType, IRStructKey* key, - IRType* resultType) + IRType* resultType, + DiffConformanceKind kind) { - if (auto conformance = tryGetDifferentiableWitness(builder, origType, DiffConformanceKind::Any)) + if (auto conformance = tryGetDifferentiableWitness(builder, origType, kind)) return _lookupWitness(builder, conformance, key, resultType); return nullptr; } @@ -2097,8 +2098,6 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness( IRWitnessTable* table = nullptr; if (target == DiffConformanceKind::Value) { - SLANG_ASSERT(isDifferentiableValueType((IRType*)inTupleType)); - auto addMethod = builder->createFunc(); auto zeroMethod = builder->createFunc(); @@ -2138,6 +2137,8 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness( &b, (IRType*)elementType, DiffConformanceKind::Value); + + SLANG_ASSERT(isDifferentiableValueType((IRType*)elementType)); IRInst* elementResult = nullptr; if (!innerWitness) { @@ -2171,9 +2172,9 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness( { // Zero method. IRBuilder b = *builder; - b.setInsertInto(addMethod); - b.addBackwardDifferentiableDecoration(addMethod); - addMethod->setFullType(b.getFuncType(0, nullptr, diffTupleType)); + b.setInsertInto(zeroMethod); + b.addBackwardDifferentiableDecoration(zeroMethod); + zeroMethod->setFullType(b.getFuncType(0, nullptr, diffTupleType)); b.emitBlock(); List results; for (UInt i = 0; i < inTupleType->getOperandCount(); i++) @@ -2214,7 +2215,6 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness( else if (target == DiffConformanceKind::Ptr) { SLANG_ASSERT(isDifferentiablePtrType((IRType*)inTupleType)); - table = builder->createWitnessTable( sharedContext->differentiablePtrInterfaceType, (IRType*)inTupleType); diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index 4698408e3..2cd08eb28 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -252,7 +252,8 @@ struct DifferentiableTypeConformanceContext IRBuilder* builder, IRType* origType, IRStructKey* key, - IRType* resultType = nullptr); + IRType* resultType = nullptr, + DiffConformanceKind kind = DiffConformanceKind::Any); IRType* differentiateType(IRBuilder* builder, IRInst* primalType); @@ -411,7 +412,8 @@ struct DifferentiableTypeConformanceContext builder, origType, sharedContext->zeroMethodStructKey, - sharedContext->zeroMethodType); + sharedContext->zeroMethodType, + DiffConformanceKind::Value); return result; } @@ -421,7 +423,8 @@ struct DifferentiableTypeConformanceContext builder, origType, sharedContext->addMethodStructKey, - sharedContext->addMethodType); + sharedContext->addMethodType, + DiffConformanceKind::Value); return result; } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 775986a9a..decfe4a91 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1919,6 +1919,28 @@ struct ValLoweringVisitor : ValVisitorgetIntValue(type, val->getValue())); } + IRType* visitDifferentialPairType(DifferentialPairType* pairType) + { + IRType* primalType = lowerType(context, pairType->getPrimalType()); + if (as(primalType) || as(primalType)) + { + List operands; + SubstitutionSet(pairType->getDeclRef()) + .forEachSubstitutionArg( + [&](Val* arg) + { + auto argVal = lowerVal(context, arg).val; + SLANG_ASSERT(argVal); + operands.add(argVal); + }); + + auto undefined = getBuilder()->emitUndefined(operands[1]->getFullType()); + return getBuilder()->getDifferentialPairUserCodeType(primalType, undefined); + } + else + return lowerSimpleIntrinsicType(pairType); + } + IRFuncType* visitFuncType(FuncType* type) { IRType* resultType = lowerType(context, type->getResultType()); @@ -10195,15 +10217,17 @@ struct DeclLoweringVisitor : DeclVisitor // If our function is differentiable, register a callback so the derivative // annotations for types can be lowered. // - if (auto diffAttr = decl->findModifier()) + if (decl->findModifier() && !isInterfaceRequirement(decl)) { + auto diffAttr = decl->findModifier(); + auto diffTypeWitnessMap = diffAttr->getMapTypeToIDifferentiableWitness(); - OrderedDictionary resolveddiffTypeWitnessMap; + OrderedDictionary resolveddiffTypeWitnessMap; // Go through each entry in the map and resolve the key. for (auto& entry : diffTypeWitnessMap) { - auto resolvedKey = as(entry.key->resolve()); + auto resolvedKey = as(entry.key->resolve()); resolveddiffTypeWitnessMap[resolvedKey] = as(as(entry.value)->resolve()); } @@ -10211,14 +10235,9 @@ struct DeclLoweringVisitor : DeclVisitor subContext->registerTypeCallback( [=](IRGenContext* context, Type* type, IRType* irType) { - if (!as(type)) - return irType; - - DeclRefBase* declRefBase = as(type)->getDeclRefBase(); - if (resolveddiffTypeWitnessMap.containsKey(declRefBase)) + if (resolveddiffTypeWitnessMap.containsKey(type)) { - auto irWitness = - lowerVal(subContext, resolveddiffTypeWitnessMap[declRefBase]).val; + auto irWitness = lowerVal(subContext, resolveddiffTypeWitnessMap[type]).val; if (irWitness) { IRInst* args[] = {irType, irWitness}; @@ -11328,7 +11347,7 @@ LoweredValInfo emitDeclRef(IRGenContext* context, Decl* decl, DeclRefBase* subst // interface definitions. return emitDeclRef( context, - createDefaultSpecializedDeclRef(context, nullptr, decl), + decl->getDefaultDeclRef(), context->irBuilder->getTypeKind()); } diff --git a/tests/autodiff/autopybind-printf.slang b/tests/autodiff/autopybind-printf.slang new file mode 100644 index 000000000..add1923ef --- /dev/null +++ b/tests/autodiff/autopybind-printf.slang @@ -0,0 +1,47 @@ +//TEST:SIMPLE(filecheck=CUDA): -target cuda -line-directive-mode none +//TEST:SIMPLE(filecheck=TORCH): -target torch -line-directive-mode none + +// CUDA: __device__ void s_primal_ctx_myKernel_0( +// CUDA: printf("%f\n", +// CUDA: __global__ void __kernel__myKernel_bwd_diff(DiffTensorView_[[#]] {{[[:alnum:]_]+}}, DiffTensorView_[[#]] {{[[:alnum:]_]+}}) +// CUDA: __global__ void __kernel__myKernel_fwd_diff(DiffTensorView_[[#]] {{[[:alnum:]_]+}}, DiffTensorView_[[#]] {{[[:alnum:]_]+}}) +// CUDA: __global__ void __kernel__myKernel(DiffTensorView_[[#]] {{[[:alnum:]_]+}}, DiffTensorView_[[#]] {{[[:alnum:]_]+}}) + +[AutoPyBindCUDA] +[Differentiable] +[CudaKernel] +void myKernel(DiffTensorView inValues, DiffTensorView outValues) +{ + if (cudaThreadIdx().x > 0) + return; + printf("%f\n", inValues[cudaThreadIdx().x]); + outValues[cudaThreadIdx().x] = sin(inValues[cudaThreadIdx().x]); +} + +// TORCH: {{^SLANG_PRELUDE_EXPORT$}} +// TORCH-NEXT: void __kernel__myKernel_bwd_diff(DiffTensorView_[[#]] {{[[:alnum:]_]+}}, DiffTensorView_[[#]] {{[[:alnum:]_]+}}) +// +// TORCH: {{^SLANG_PRELUDE_EXPORT$}} +// TORCH-NEXT: void __kernel__myKernel_fwd_diff(DiffTensorView_[[#]] {{[[:alnum:]_]+}}, DiffTensorView_[[#]] {{[[:alnum:]_]+}}) +// +// TORCH: {{^SLANG_PRELUDE_EXPORT$}} +// TORCH-NEXT: void __kernel__myKernel(DiffTensorView_[[#]] {{[[:alnum:]_]+}}, DiffTensorView_[[#]] {{[[:alnum:]_]+}}) +// +// TORCH: {{^SLANG_PRELUDE_EXPORT$}} +// TORCH-NEXT: void myKernel(std::tuple {{[[:alnum:]_]+}}, std::tuple {{[[:alnum:]_]+}}, std::tuple> {{[[:alnum:]_]+}}, std::tuple> {{[[:alnum:]_]+}}) +// +// TORCH: {{^SLANG_PRELUDE_EXPORT$}} +// TORCH-NEXT: std::tuple, std::tuple, const char*, const char*> __funcinfo__myKernel() +// +// TORCH: {{^SLANG_PRELUDE_EXPORT$}} +// TORCH-NEXT: void myKernel_fwd_diff(std::tuple {{[[:alnum:]_]+}}, std::tuple {{[[:alnum:]_]+}}, std::tuple> {{[[:alnum:]_]+}}, std::tuple> {{[[:alnum:]_]+}}) +// +// TORCH: {{^SLANG_PRELUDE_EXPORT$}} +// TORCH-NEXT: void myKernel_bwd_diff(std::tuple {{[[:alnum:]_]+}}, std::tuple {{[[:alnum:]_]+}}, std::tuple> {{[[:alnum:]_]+}}, std::tuple> {{[[:alnum:]_]+}}) +// +// TORCH: {{^SLANG_PRELUDE_EXPORT$}} +// TORCH-NEXT: std::tuple, std::tuple> __typeinfo__DiffTensorView() +// +// TORCH: {{^SLANG_PRELUDE_EXPORT$}} +// TORCH-NEXT: std::tuple, std::tuple> __typeinfo__AtomicAdd() +// \ No newline at end of file diff --git a/tests/autodiff/diff-assoctype-generic-interface.slang b/tests/autodiff/diff-assoctype-generic-interface.slang new file mode 100644 index 000000000..79e0eff08 --- /dev/null +++ b/tests/autodiff/diff-assoctype-generic-interface.slang @@ -0,0 +1,110 @@ +// Test calling differentiable function through dynamic dispatch. + +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type -compile-arg -skip-spirv-validation -emit-spirv-directly + +//TEST_INPUT:ubuffer(data=[2 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface IGetter : IDifferentiable +{ + [Differentiable] + float get(uint id); +} + +struct GetterImpl : IGetter +{ + float[8] data; + + __init(float[8] data) + { this.data = data; } + + [Differentiable] + float get(uint id) + { + return data[id]; + } +} +interface IFoo +{ + associatedtype Params : IGetter; + + [Differentiable] + Params bar(); +} + +[BackwardDerivative(load_bwd)] +[ForwardDerivative(load_fwd)] +float load(uint id) +{ + return outputBuffer[id] + 2; +} + +DifferentialPair load_fwd(uint id) +{ + return DifferentialPair(load(id), 3.f); +} + +void load_bwd(uint id, float.Differential dOut) +{ + outputBuffer[id + 8] = dOut; +} + +struct FooImpl1: IFoo<8> +{ + typealias Params = GetterImpl; + + __init() + { } + + [Differentiable] + Params bar() + { + float x = load(0); + return GetterImpl({x, x+1, x+2, x+3, x+4, x+5, x+6, x+7}); + } +} + +/* +// There's a slight issue with dynamic dispatch over generic interfaces. Uncomment after that is fixed. + +struct FooImpl2: IFoo<8> +{ + typealias Params = GetterImpl; + + __init() + { } + + [Differentiable] + Params bar() + { + float x = 2 * load(0); + return GetterImpl({x, x+5, x+7, x+9, x+11, x+13, x+15, x+17}); + } +} +*/ + +IFoo<8> getFoo(uint id) +{ + /*if (id == 0) + return FooImpl1(); + else + return FooImpl2();*/ + return FooImpl1(); +} + +[Differentiable] +float doThing(uint id) +{ + IFoo<8> foo = getFoo(id); + return foo.bar().get(0); +} + +[shader("compute")] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = doThing(0); // CHECK: 2.0 + outputBuffer[1] = doThing(1); // CHECK: 4.0 + + outputBuffer[2] = fwd_diff(doThing)(0).d; // CHECK: 3.0 + outputBuffer[3] = fwd_diff(doThing)(1).d; // CHECK: 3.0 +} \ No newline at end of file -- cgit v1.2.3