summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang79
-rw-r--r--source/slang/diff.meta.slang76
-rw-r--r--source/slang/slang-check-overload.cpp9
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp94
-rw-r--r--source/slang/slang-ir-peephole.cpp23
-rw-r--r--source/slang/slang-lower-to-ir.cpp11
-rw-r--r--source/slang/slang-syntax.cpp19
7 files changed, 193 insertions, 118 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 05963bd11..a37124bdc 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -162,7 +162,7 @@ interface __BuiltinRealType : __BuiltinSignedArithmeticType {}
/// A type that uses a floating-point representation
[sealed]
[builtin]
-interface __BuiltinFloatingPointType : __BuiltinRealType
+interface __BuiltinFloatingPointType : __BuiltinRealType, IDifferentiable
{
/// Initialize from a 32-bit floating-point value.
__init(float value);
@@ -369,6 +369,26 @@ ${{{{
case BaseType::Double:
}}}}
static $(kBaseTypes[tt].name) getPi() { return $(kBaseTypes[tt].name)(3.14159265358979323846264338328); }
+
+ typedef $(kBaseTypes[tt].name) Differential;
+
+ [__unsafeForceInlineEarly]
+ static Differential dzero()
+ {
+ return Differential(0);
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dadd(Differential a, Differential b)
+ {
+ return a + b;
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dmul(Differential a, Differential b)
+ {
+ return a * b;
+ }
${{{{
break;
}
@@ -891,7 +911,6 @@ for (int tt = 0; tt < kBaseTypeCount; ++tt)
sb << " __init(" << kBaseTypes[ff].name << " value);\n";
}
}
-
sb << "}\n";
}
@@ -926,7 +945,6 @@ for( int C = 2; C <= 4; ++C )
if(rr == R && cc == C) continue;
sb << "__init(matrix<T," << rr << "," << cc << "> value);\n";
}
-
sb << "}\n";
}
@@ -935,6 +953,7 @@ for (int tt = 0; tt < kBaseTypeCount; ++tt)
if(kBaseTypes[tt].tag == BaseType::Void) continue;
auto toType = kBaseTypes[tt].name;
}}}}
+
__generic<let R : int, let C : int> extension matrix<$(toType),R,C>
{
${{{{
@@ -958,6 +977,60 @@ ${{{{
}
}}}}
+__generic<T, U>
+__intrinsic_op(0)
+T __slang_noop_cast(U u);
+
+__generic<T:__BuiltinFloatingPointType, let N: int>
+extension vector<T, N> : IDifferentiable
+{
+ typedef vector<T, N> Differential;
+
+ [__unsafeForceInlineEarly]
+ static Differential dzero()
+ {
+ return Differential(__slang_noop_cast<T>(T.dzero()));
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dadd(Differential a, Differential b)
+ {
+ return a + b;
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dmul(This a, Differential b)
+ {
+ return a * b;
+ }
+}
+
+__generic<T:__BuiltinFloatingPointType, let R: int, let C: int>
+extension matrix<T, R, C> : IDifferentiable
+{
+ typedef matrix<T, R, C> Differential;
+
+ __init(T val);
+
+ [__unsafeForceInlineEarly]
+ static Differential dzero()
+ {
+ return matrix<T, R, C>(__slang_noop_cast<T>(T.dzero()));
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dadd(Differential a, Differential b)
+ {
+ return a + b;
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dmul(This a, Differential b)
+ {
+ return a * b;
+ }
+}
+
//@ public:
/// Sampling state for filtered texture fetches.
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 2625d79b0..c95f8e1ac 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -9,56 +9,10 @@ attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute;
__attributeTarget(FunctionDeclBase)
attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute;
-// Add extensions for the standard types
-extension float : IDifferentiable
-{
- typedef float Differential;
-
- [__unsafeForceInlineEarly]
- static Differential dzero()
- {
- return float(0.f);
- }
-
- [__unsafeForceInlineEarly]
- static Differential dadd(Differential a, Differential b)
- {
- return a + b;
- }
-
- [__unsafeForceInlineEarly]
- static Differential dmul(This a, Differential b)
- {
- return a * b;
- }
-}
-
-__generic<let N:int>
-extension vector<float, N> : IDifferentiable
-{
- typedef vector<float, N> Differential;
-
- [__unsafeForceInlineEarly]
- static Differential dzero()
- {
- return vector<float, N>(0.f);
- }
-
- [__unsafeForceInlineEarly]
- static Differential dadd(Differential a, Differential b)
- {
- return a + b;
- }
- [__unsafeForceInlineEarly]
- static Differential dmul(This a, Differential b)
- {
- return a * b;
- }
-}
+/// Pair type that serves to wrap the primal and
+/// differential types of an arbitrary type T.
- /// Pair type that serves to wrap the primal and
- /// differential types of an arbitrary type T.
__generic<T : IDifferentiable>
__magic_type(DifferentialPairType)
__intrinsic_type($(kIROp_DifferentialPairType))
@@ -126,15 +80,13 @@ struct DifferentialPair : IDifferentiable
}
};
-typealias IDFloat = IFloat & IDifferentiable;
-
#define VECTOR_MAP_UNARY(TYPE, COUNT, FUNC, VALUE) \
vector<TYPE,COUNT> result; for(int i = 0; i < COUNT; ++i) { result[i] = FUNC(VALUE[i]); } return result
namespace dstd
{
// Natural Exponent
- __generic<T : IDFloat>
+ __generic<T : __BuiltinFloatingPointType>
__target_intrinsic(hlsl)
__target_intrinsic(glsl)
__target_intrinsic(cuda, "$P_exp($0)")
@@ -143,16 +95,16 @@ namespace dstd
[ForwardDerivative(d_exp<T>)]
T exp(T x);
- __generic<T : IDFloat>
+ __generic<T : __BuiltinFloatingPointType>
DifferentialPair<T> d_exp(DifferentialPair<T> dpx)
{
return DifferentialPair<T>(
- exp(dpx.p),
- T.dmul(exp(dpx.p), dpx.d));
+ dstd.exp(dpx.p),
+ T.dmul(dstd.exp(dpx.p), dpx.d));
}
// Sine
- __generic<T : IDFloat>
+ __generic<T : __BuiltinFloatingPointType>
__target_intrinsic(hlsl)
__target_intrinsic(glsl)
__target_intrinsic(cuda, "$P_sin($0)")
@@ -161,16 +113,16 @@ namespace dstd
[ForwardDerivative(d_sin<T>)]
T sin(T x);
- __generic<T : IDFloat>
+ __generic<T : __BuiltinFloatingPointType>
DifferentialPair<T> d_sin(DifferentialPair<T> dpx)
{
return DifferentialPair<T>(
- sin(dpx.p),
- T.dmul(cos(dpx.p), dpx.d));
+ dstd.sin(dpx.p),
+ T.dmul(dstd.cos(dpx.p), dpx.d));
}
// Cosine
- __generic<T : IDFloat>
+ __generic<T : __BuiltinFloatingPointType>
__target_intrinsic(hlsl)
__target_intrinsic(glsl)
__target_intrinsic(cuda, "$P_cos($0)")
@@ -179,12 +131,12 @@ namespace dstd
[ForwardDerivative(d_cos<T>)]
T cos(T x);
- __generic<T : IDFloat>
+ __generic<T : __BuiltinFloatingPointType>
DifferentialPair<T> d_cos(DifferentialPair<T> dpx)
{
return DifferentialPair<T>(
- cos(dpx.p),
- T.dmul(-sin(dpx.p), dpx.d));
+ dstd.cos(dpx.p),
+ T.dmul(-dstd.sin(dpx.p), dpx.d));
}
__generic<let N : int>
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index 42fab94a6..38754d170 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -1573,10 +1573,15 @@ namespace Slang
{
for (auto item : overloadExpr->lookupResult2.items)
{
+ auto funcType = as<FuncType>(GetTypeForDeclRef(item.declRef, item.declRef.decl->loc));
+ if (!funcType)
+ continue;
+ funcType = as<FuncType>(processJVPFuncType(funcType));
+ if (!funcType)
+ continue;
OverloadCandidate candidate;
candidate.flavor = OverloadCandidate::Flavor::Expr;
- candidate.funcType = as<FuncType>(processJVPFuncType(
- as<FuncType>(GetTypeForDeclRef(item.declRef, item.declRef.decl->loc))));
+ candidate.funcType = funcType;
candidate.resultType = candidate.funcType->getResultType();
candidate.item = LookupResultItem(item.declRef);
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index 3135f300d..574db2036 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -681,17 +681,6 @@ struct JVPTranscriber
return builder->getFuncType(newParameterTypes, diffReturnType);
}
- IRWitnessTable* getDifferentialBottomWitness()
- {
- IRBuilder builder(sharedBuilder);
- builder.setInsertInto(sharedBuilder->getModule()->getModuleInst());
- auto result =
- as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType(
- builder.getDifferentialBottomType()));
- SLANG_ASSERT(result);
- return result;
- }
-
// Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType)
{
@@ -699,20 +688,23 @@ struct JVPTranscriber
builder.setInsertInto(inDiffPairType->parent);
auto diffPairType = as<IRDifferentialPairType>(inDiffPairType);
SLANG_ASSERT(diffPairType);
- auto result =
- as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType(
- builder.getDifferentialBottomType()));
- if (result)
- return result;
-
- auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType);
auto diffType = differentiateType(&builder, diffPairType->getValueType());
- auto differentialType = builder.getDifferentialPairType(diffType, getDifferentialBottomWitness());
- builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, differentialType);
- // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`.
- differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table;
- return table;
+ IRInst* tableInst = nullptr;
+ if (!differentiableTypeConformanceContext.differentiableWitnessDictionary.TryGetValue(diffPairType, tableInst))
+ {
+ IRWitnessTable* table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType);
+ // The witness that `diffType`
+ auto differentialType = builder.getDifferentialPairType(
+ diffType,
+ differentiableTypeConformanceContext.differentiableWitnessDictionary[diffType]
+ .GetValue());
+ builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, differentialType);
+ // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`.
+ differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table;
+ tableInst = table;
+ }
+ return as<IRWitnessTable>(tableInst);
}
IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness)
@@ -730,8 +722,10 @@ struct JVPTranscriber
builder.setInsertInto(primalType->parent);
auto witness = as<IRWitnessTable>(
differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType));
- if (!witness)
- witness = getDifferentialBottomWitness();
+ if (!witness && as<IRDifferentialPairType>(primalType))
+ {
+ witness = getDifferentialPairWitness(primalType);
+ }
return builder.getDifferentialPairType(
(IRType*)primalType,
witness);
@@ -2205,29 +2199,41 @@ struct JVPDerivativeContext : public InstPassBase
bool processPairTypes(IRBuilder* builder, IRInst* instWithChildren)
{
bool modified = false;
- // Hoist all pair types to global scope when possible.
+ // Hoist and deduplicate all pair types to global scope when possible.
+ // This avoids emitting different struct types for equivalent pair types.
auto moduleInst = module->getModuleInst();
- processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRInst* originalPairType)
- {
- if (originalPairType->parent != moduleInst)
+ Dictionary<IRInst*, IRInst*> diffPairTypes;
+ for (;;)
+ {
+ bool changed = false;
+ sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
+ processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRDifferentialPairType* originalPairType)
{
- originalPairType->removeFromParent();
- ShortList<IRInst*> operands;
- for (UInt i = 0; i < originalPairType->getOperandCount(); i++)
+ IRInst* finalType = nullptr;
+ if (diffPairTypes.TryGetValue(originalPairType->getValueType(), finalType))
{
- operands.add(originalPairType->getOperand(i));
+ if (finalType != originalPairType)
+ {
+ originalPairType->replaceUsesWith(finalType);
+ originalPairType->removeAndDeallocate();
+ changed = true;
+ return;
+ }
}
- auto newPairType = builder->findOrEmitHoistableInst(
- originalPairType->getFullType(),
- originalPairType->getOp(),
- originalPairType->getOperandCount(),
- operands.getArrayView().getBuffer());
- originalPairType->replaceUsesWith(newPairType);
- originalPairType->removeAndDeallocate();
- }
- });
-
- sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
+ diffPairTypes[originalPairType->getValueType()] = originalPairType;
+ if (originalPairType->parent != moduleInst)
+ {
+ if (originalPairType->getValueType()->getParent() != originalPairType->getParent())
+ {
+ originalPairType->insertAfter(originalPairType->getValueType());
+ changed = true;
+ return;
+ }
+ }
+ });
+ if (!changed)
+ break;
+ }
processAllInsts([&](IRInst* inst)
{
diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp
index 32950edc9..788110330 100644
--- a/source/slang/slang-ir-peephole.cpp
+++ b/source/slang/slang-ir-peephole.cpp
@@ -197,6 +197,29 @@ struct PeepholeContext : InstPassBase
}
}
break;
+ case kIROp_lookup_interface_method:
+ {
+ if (inst->getOperand(0)->getOp() == kIROp_WitnessTable)
+ {
+ auto wt = as<IRWitnessTable>(inst->getOperand(0));
+ auto key = inst->getOperand(1);
+ for (auto item : wt->getChildren())
+ {
+ if (auto entry = as<IRWitnessTableEntry>(item))
+ {
+ if (entry->getRequirementKey() == key)
+ {
+ auto value = entry->getSatisfyingVal();
+ inst->replaceUsesWith(value);
+ inst->removeAndDeallocate();
+ changed = true;
+ break;
+ }
+ }
+ }
+ }
+ }
+ break;
default:
break;
}
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index f8d8282d8..12a9f73e6 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1349,6 +1349,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
// The base (subToMid) will turn into a value with
// witness-table type.
IRInst* baseWitnessTable = lowerSimpleVal(context, val->subToMid);
+ IRInst* midToSup = nullptr;
// The next step should map to an interface requirement
// that is itself an interface conformance, so the result
@@ -1366,7 +1367,6 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
// produce transitive witnesses in shapes that will cuase us
// problems here.
//
- IRInst* midToSup = lowerSimpleVal(context, val->midToSup);
if (!baseWitnessTable)
{
@@ -1380,6 +1380,15 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
return LoweredValInfo::simple(midToSup);
}
+ if (auto declaredMidToSup = as<DeclaredSubtypeWitness>(val->midToSup))
+ {
+ midToSup = getInterfaceRequirementKey(context, declaredMidToSup->declRef.decl);
+ }
+ else
+ {
+ midToSup = lowerSimpleVal(context, val->midToSup);
+ }
+
return LoweredValInfo::simple(getBuilder()->emitLookupInterfaceMethodInst(
getBuilder()->getWitnessTableType(lowerType(context, val->sup)),
baseWitnessTable,
diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp
index 12b9dab42..f3b590acf 100644
--- a/source/slang/slang-syntax.cpp
+++ b/source/slang/slang-syntax.cpp
@@ -234,6 +234,10 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
case RequirementWitness::Flavor::none:
return RequirementWitness();
+ case RequirementWitness::Flavor::witnessTable:
+ SLANG_ASSERT(!subst);
+ return *this;
+
case RequirementWitness::Flavor::declRef:
{
int diff = 0;
@@ -321,16 +325,19 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
}
else if (auto transitiveTypeWitness = as<TransitiveSubtypeWitness>(subtypeWitness))
{
- // Hard code witness entry that `T.Differential = DifferentialBottom` for `T` that
- // coerce to `DifferentialBottom`.
- if (astBuilder->getDifferentialBottomType()->equals(transitiveTypeWitness->subToMid->sup))
+ if (auto declaredSubtypeWitnessMidToSup = as<DeclaredSubtypeWitness>(transitiveTypeWitness->midToSup))
{
- if (auto builtinAttr = requirementKey->findModifier<BuiltinRequirementModifier>())
+ auto midKey = declaredSubtypeWitnessMidToSup->declRef;
+ auto midWitness = tryLookUpRequirementWitness(astBuilder, as<SubtypeWitness>(transitiveTypeWitness->subToMid), midKey);
+ if (midWitness.getFlavor() == RequirementWitness::Flavor::witnessTable)
{
- if (builtinAttr->kind == BuiltinRequirementKind::DifferentialType)
+ auto table = midWitness.getWitnessTable();
+ RequirementWitness result;
+ if (table->requirementDictionary.TryGetValue(requirementKey, result))
{
- return RequirementWitness(astBuilder->getDifferentialBottomType());
+ result = result.specialize(astBuilder, midKey.substitutions);
}
+ return result;
}
}
}