diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-08 10:07:57 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-08 10:07:57 -0800 |
| commit | bf67309454032b4f92d0bc9735b608e56b16882f (patch) | |
| tree | a321fe7db0b49fa67608b935c1389354a020f59c /source | |
| parent | ca882a1ef46a5a8bbff50e3a1a6f973e16358634 (diff) | |
Make `__BuiltinFloatingPointType` conform to `IDifferentiable`. (#2499)
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/core.meta.slang | 79 | ||||
| -rw-r--r-- | source/slang/diff.meta.slang | 76 | ||||
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 94 | ||||
| -rw-r--r-- | source/slang/slang-ir-peephole.cpp | 23 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang-syntax.cpp | 19 |
7 files changed, 193 insertions, 118 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 05963bd11..a37124bdc 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -162,7 +162,7 @@ interface __BuiltinRealType : __BuiltinSignedArithmeticType {} /// A type that uses a floating-point representation [sealed] [builtin] -interface __BuiltinFloatingPointType : __BuiltinRealType +interface __BuiltinFloatingPointType : __BuiltinRealType, IDifferentiable { /// Initialize from a 32-bit floating-point value. __init(float value); @@ -369,6 +369,26 @@ ${{{{ case BaseType::Double: }}}} static $(kBaseTypes[tt].name) getPi() { return $(kBaseTypes[tt].name)(3.14159265358979323846264338328); } + + typedef $(kBaseTypes[tt].name) Differential; + + [__unsafeForceInlineEarly] + static Differential dzero() + { + return Differential(0); + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + return a + b; + } + + [__unsafeForceInlineEarly] + static Differential dmul(Differential a, Differential b) + { + return a * b; + } ${{{{ break; } @@ -891,7 +911,6 @@ for (int tt = 0; tt < kBaseTypeCount; ++tt) sb << " __init(" << kBaseTypes[ff].name << " value);\n"; } } - sb << "}\n"; } @@ -926,7 +945,6 @@ for( int C = 2; C <= 4; ++C ) if(rr == R && cc == C) continue; sb << "__init(matrix<T," << rr << "," << cc << "> value);\n"; } - sb << "}\n"; } @@ -935,6 +953,7 @@ for (int tt = 0; tt < kBaseTypeCount; ++tt) if(kBaseTypes[tt].tag == BaseType::Void) continue; auto toType = kBaseTypes[tt].name; }}}} + __generic<let R : int, let C : int> extension matrix<$(toType),R,C> { ${{{{ @@ -958,6 +977,60 @@ ${{{{ } }}}} +__generic<T, U> +__intrinsic_op(0) +T __slang_noop_cast(U u); + +__generic<T:__BuiltinFloatingPointType, let N: int> +extension vector<T, N> : IDifferentiable +{ + typedef vector<T, N> Differential; + + [__unsafeForceInlineEarly] + static Differential dzero() + { + return Differential(__slang_noop_cast<T>(T.dzero())); + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + return a + b; + } + + [__unsafeForceInlineEarly] + static Differential dmul(This a, Differential b) + { + return a * b; + } +} + +__generic<T:__BuiltinFloatingPointType, let R: int, let C: int> +extension matrix<T, R, C> : IDifferentiable +{ + typedef matrix<T, R, C> Differential; + + __init(T val); + + [__unsafeForceInlineEarly] + static Differential dzero() + { + return matrix<T, R, C>(__slang_noop_cast<T>(T.dzero())); + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + return a + b; + } + + [__unsafeForceInlineEarly] + static Differential dmul(This a, Differential b) + { + return a * b; + } +} + //@ public: /// Sampling state for filtered texture fetches. diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 2625d79b0..c95f8e1ac 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -9,56 +9,10 @@ attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute; -// Add extensions for the standard types -extension float : IDifferentiable -{ - typedef float Differential; - - [__unsafeForceInlineEarly] - static Differential dzero() - { - return float(0.f); - } - - [__unsafeForceInlineEarly] - static Differential dadd(Differential a, Differential b) - { - return a + b; - } - - [__unsafeForceInlineEarly] - static Differential dmul(This a, Differential b) - { - return a * b; - } -} - -__generic<let N:int> -extension vector<float, N> : IDifferentiable -{ - typedef vector<float, N> Differential; - - [__unsafeForceInlineEarly] - static Differential dzero() - { - return vector<float, N>(0.f); - } - - [__unsafeForceInlineEarly] - static Differential dadd(Differential a, Differential b) - { - return a + b; - } - [__unsafeForceInlineEarly] - static Differential dmul(This a, Differential b) - { - return a * b; - } -} +/// Pair type that serves to wrap the primal and +/// differential types of an arbitrary type T. - /// Pair type that serves to wrap the primal and - /// differential types of an arbitrary type T. __generic<T : IDifferentiable> __magic_type(DifferentialPairType) __intrinsic_type($(kIROp_DifferentialPairType)) @@ -126,15 +80,13 @@ struct DifferentialPair : IDifferentiable } }; -typealias IDFloat = IFloat & IDifferentiable; - #define VECTOR_MAP_UNARY(TYPE, COUNT, FUNC, VALUE) \ vector<TYPE,COUNT> result; for(int i = 0; i < COUNT; ++i) { result[i] = FUNC(VALUE[i]); } return result namespace dstd { // Natural Exponent - __generic<T : IDFloat> + __generic<T : __BuiltinFloatingPointType> __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_exp($0)") @@ -143,16 +95,16 @@ namespace dstd [ForwardDerivative(d_exp<T>)] T exp(T x); - __generic<T : IDFloat> + __generic<T : __BuiltinFloatingPointType> DifferentialPair<T> d_exp(DifferentialPair<T> dpx) { return DifferentialPair<T>( - exp(dpx.p), - T.dmul(exp(dpx.p), dpx.d)); + dstd.exp(dpx.p), + T.dmul(dstd.exp(dpx.p), dpx.d)); } // Sine - __generic<T : IDFloat> + __generic<T : __BuiltinFloatingPointType> __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_sin($0)") @@ -161,16 +113,16 @@ namespace dstd [ForwardDerivative(d_sin<T>)] T sin(T x); - __generic<T : IDFloat> + __generic<T : __BuiltinFloatingPointType> DifferentialPair<T> d_sin(DifferentialPair<T> dpx) { return DifferentialPair<T>( - sin(dpx.p), - T.dmul(cos(dpx.p), dpx.d)); + dstd.sin(dpx.p), + T.dmul(dstd.cos(dpx.p), dpx.d)); } // Cosine - __generic<T : IDFloat> + __generic<T : __BuiltinFloatingPointType> __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_cos($0)") @@ -179,12 +131,12 @@ namespace dstd [ForwardDerivative(d_cos<T>)] T cos(T x); - __generic<T : IDFloat> + __generic<T : __BuiltinFloatingPointType> DifferentialPair<T> d_cos(DifferentialPair<T> dpx) { return DifferentialPair<T>( - cos(dpx.p), - T.dmul(-sin(dpx.p), dpx.d)); + dstd.cos(dpx.p), + T.dmul(-dstd.sin(dpx.p), dpx.d)); } __generic<let N : int> diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 42fab94a6..38754d170 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1573,10 +1573,15 @@ namespace Slang { for (auto item : overloadExpr->lookupResult2.items) { + auto funcType = as<FuncType>(GetTypeForDeclRef(item.declRef, item.declRef.decl->loc)); + if (!funcType) + continue; + funcType = as<FuncType>(processJVPFuncType(funcType)); + if (!funcType) + continue; OverloadCandidate candidate; candidate.flavor = OverloadCandidate::Flavor::Expr; - candidate.funcType = as<FuncType>(processJVPFuncType( - as<FuncType>(GetTypeForDeclRef(item.declRef, item.declRef.decl->loc)))); + candidate.funcType = funcType; candidate.resultType = candidate.funcType->getResultType(); candidate.item = LookupResultItem(item.declRef); diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 3135f300d..574db2036 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -681,17 +681,6 @@ struct JVPTranscriber return builder->getFuncType(newParameterTypes, diffReturnType); } - IRWitnessTable* getDifferentialBottomWitness() - { - IRBuilder builder(sharedBuilder); - builder.setInsertInto(sharedBuilder->getModule()->getModuleInst()); - auto result = - as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType( - builder.getDifferentialBottomType())); - SLANG_ASSERT(result); - return result; - } - // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType) { @@ -699,20 +688,23 @@ struct JVPTranscriber builder.setInsertInto(inDiffPairType->parent); auto diffPairType = as<IRDifferentialPairType>(inDiffPairType); SLANG_ASSERT(diffPairType); - auto result = - as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType( - builder.getDifferentialBottomType())); - if (result) - return result; - - auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType); auto diffType = differentiateType(&builder, diffPairType->getValueType()); - auto differentialType = builder.getDifferentialPairType(diffType, getDifferentialBottomWitness()); - builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, differentialType); - // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`. - differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table; - return table; + IRInst* tableInst = nullptr; + if (!differentiableTypeConformanceContext.differentiableWitnessDictionary.TryGetValue(diffPairType, tableInst)) + { + IRWitnessTable* table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType); + // The witness that `diffType` + auto differentialType = builder.getDifferentialPairType( + diffType, + differentiableTypeConformanceContext.differentiableWitnessDictionary[diffType] + .GetValue()); + builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, differentialType); + // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`. + differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table; + tableInst = table; + } + return as<IRWitnessTable>(tableInst); } IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness) @@ -730,8 +722,10 @@ struct JVPTranscriber builder.setInsertInto(primalType->parent); auto witness = as<IRWitnessTable>( differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType)); - if (!witness) - witness = getDifferentialBottomWitness(); + if (!witness && as<IRDifferentialPairType>(primalType)) + { + witness = getDifferentialPairWitness(primalType); + } return builder.getDifferentialPairType( (IRType*)primalType, witness); @@ -2205,29 +2199,41 @@ struct JVPDerivativeContext : public InstPassBase bool processPairTypes(IRBuilder* builder, IRInst* instWithChildren) { bool modified = false; - // Hoist all pair types to global scope when possible. + // Hoist and deduplicate all pair types to global scope when possible. + // This avoids emitting different struct types for equivalent pair types. auto moduleInst = module->getModuleInst(); - processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRInst* originalPairType) - { - if (originalPairType->parent != moduleInst) + Dictionary<IRInst*, IRInst*> diffPairTypes; + for (;;) + { + bool changed = false; + sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); + processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRDifferentialPairType* originalPairType) { - originalPairType->removeFromParent(); - ShortList<IRInst*> operands; - for (UInt i = 0; i < originalPairType->getOperandCount(); i++) + IRInst* finalType = nullptr; + if (diffPairTypes.TryGetValue(originalPairType->getValueType(), finalType)) { - operands.add(originalPairType->getOperand(i)); + if (finalType != originalPairType) + { + originalPairType->replaceUsesWith(finalType); + originalPairType->removeAndDeallocate(); + changed = true; + return; + } } - auto newPairType = builder->findOrEmitHoistableInst( - originalPairType->getFullType(), - originalPairType->getOp(), - originalPairType->getOperandCount(), - operands.getArrayView().getBuffer()); - originalPairType->replaceUsesWith(newPairType); - originalPairType->removeAndDeallocate(); - } - }); - - sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); + diffPairTypes[originalPairType->getValueType()] = originalPairType; + if (originalPairType->parent != moduleInst) + { + if (originalPairType->getValueType()->getParent() != originalPairType->getParent()) + { + originalPairType->insertAfter(originalPairType->getValueType()); + changed = true; + return; + } + } + }); + if (!changed) + break; + } processAllInsts([&](IRInst* inst) { diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index 32950edc9..788110330 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -197,6 +197,29 @@ struct PeepholeContext : InstPassBase } } break; + case kIROp_lookup_interface_method: + { + if (inst->getOperand(0)->getOp() == kIROp_WitnessTable) + { + auto wt = as<IRWitnessTable>(inst->getOperand(0)); + auto key = inst->getOperand(1); + for (auto item : wt->getChildren()) + { + if (auto entry = as<IRWitnessTableEntry>(item)) + { + if (entry->getRequirementKey() == key) + { + auto value = entry->getSatisfyingVal(); + inst->replaceUsesWith(value); + inst->removeAndDeallocate(); + changed = true; + break; + } + } + } + } + } + break; default: break; } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index f8d8282d8..12a9f73e6 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1349,6 +1349,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower // The base (subToMid) will turn into a value with // witness-table type. IRInst* baseWitnessTable = lowerSimpleVal(context, val->subToMid); + IRInst* midToSup = nullptr; // The next step should map to an interface requirement // that is itself an interface conformance, so the result @@ -1366,7 +1367,6 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower // produce transitive witnesses in shapes that will cuase us // problems here. // - IRInst* midToSup = lowerSimpleVal(context, val->midToSup); if (!baseWitnessTable) { @@ -1380,6 +1380,15 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower return LoweredValInfo::simple(midToSup); } + if (auto declaredMidToSup = as<DeclaredSubtypeWitness>(val->midToSup)) + { + midToSup = getInterfaceRequirementKey(context, declaredMidToSup->declRef.decl); + } + else + { + midToSup = lowerSimpleVal(context, val->midToSup); + } + return LoweredValInfo::simple(getBuilder()->emitLookupInterfaceMethodInst( getBuilder()->getWitnessTableType(lowerType(context, val->sup)), baseWitnessTable, diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 12b9dab42..f3b590acf 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -234,6 +234,10 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt case RequirementWitness::Flavor::none: return RequirementWitness(); + case RequirementWitness::Flavor::witnessTable: + SLANG_ASSERT(!subst); + return *this; + case RequirementWitness::Flavor::declRef: { int diff = 0; @@ -321,16 +325,19 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt } else if (auto transitiveTypeWitness = as<TransitiveSubtypeWitness>(subtypeWitness)) { - // Hard code witness entry that `T.Differential = DifferentialBottom` for `T` that - // coerce to `DifferentialBottom`. - if (astBuilder->getDifferentialBottomType()->equals(transitiveTypeWitness->subToMid->sup)) + if (auto declaredSubtypeWitnessMidToSup = as<DeclaredSubtypeWitness>(transitiveTypeWitness->midToSup)) { - if (auto builtinAttr = requirementKey->findModifier<BuiltinRequirementModifier>()) + auto midKey = declaredSubtypeWitnessMidToSup->declRef; + auto midWitness = tryLookUpRequirementWitness(astBuilder, as<SubtypeWitness>(transitiveTypeWitness->subToMid), midKey); + if (midWitness.getFlavor() == RequirementWitness::Flavor::witnessTable) { - if (builtinAttr->kind == BuiltinRequirementKind::DifferentialType) + auto table = midWitness.getWitnessTable(); + RequirementWitness result; + if (table->requirementDictionary.TryGetValue(requirementKey, result)) { - return RequirementWitness(astBuilder->getDifferentialBottomType()); + result = result.specialize(astBuilder, midKey.substitutions); } + return result; } } } |
