summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp295
1 files changed, 269 insertions, 26 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 4c1e967e3..2d009c28c 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -2574,6 +2574,135 @@ namespace Slang
return false;
}
+ GenericDecl* SemanticsVisitor::synthesizeGenericSignatureForRequirementWitness(
+ ConformanceCheckingContext* context,
+ DeclRef<GenericDecl> requiredMemberDeclRef,
+ List<Expr*>& synArgs,
+ List<Expr*>& synGenericArgs,
+ ThisExpr*& synThis)
+ {
+ auto synGenericDecl = m_astBuilder->create<GenericDecl>();
+
+ // For now our synthesized method will use the name and source
+ // location of the requirement we are trying to satisfy.
+ //
+ // TODO: as it stands right now our syntesized method will
+ // get a mangled name, which we don't actually want. Leaving
+ // out the name here doesn't help matters, because then *all*
+ // snthesized methods on a given type would share the same
+ // mangled name!
+ //
+ synGenericDecl->nameAndLoc = requiredMemberDeclRef.getDecl()->nameAndLoc;
+ if (synGenericDecl->nameAndLoc.name)
+ {
+ synGenericDecl->nameAndLoc.name = getSession()->getNameObj("$__syn_" + synGenericDecl->nameAndLoc.name->text);
+ }
+
+ // Dictionary to map from the original type parameters to the synthesized ones.
+ Dictionary<GenericTypeParamDecl*, GenericTypeParamDecl*> mapOrigToSynTypeParams;
+
+ // Our synthesized method will have parameters matching the names
+ // and types of those on the requirement, and it will use expressions
+ // that reference those parametesr as arguments for the call expresison
+ // that makes up the body.
+ //
+ for (auto member : requiredMemberDeclRef.getDecl()->members)
+ {
+ if (auto typeParamDecl = as<GenericTypeParamDecl>(member))
+ {
+ auto synTypeParamDecl = m_astBuilder->create<GenericTypeParamDecl>();
+ synTypeParamDecl->nameAndLoc = typeParamDecl->getNameAndLoc();
+ synTypeParamDecl->initType = typeParamDecl->initType;
+ synTypeParamDecl->parentDecl = synGenericDecl;
+ synGenericDecl->members.add(synTypeParamDecl);
+
+ mapOrigToSynTypeParams.add(typeParamDecl, synTypeParamDecl);
+
+ // Construct a DeclRefExpr from the type parameter.
+ auto synTypeParamDeclRef = makeDeclRef(synTypeParamDecl);
+
+ auto synTypeParamDeclRefExpr = m_astBuilder->create<VarExpr>();
+ synTypeParamDeclRefExpr->declRef = synTypeParamDeclRef;
+ synTypeParamDeclRefExpr->type = getTypeForDeclRef(m_astBuilder, synTypeParamDeclRef, SourceLoc());
+
+ synGenericArgs.add(synTypeParamDeclRefExpr);
+ }
+ }
+
+ for (auto member : requiredMemberDeclRef.getDecl()->members)
+ {
+ if (auto constraintDecl = as<GenericTypeConstraintDecl>(member))
+ {
+ getASTBuilder()->getSpecializedDeclRef(
+ constraintDecl, requiredMemberDeclRef.getSubst());
+
+ auto synConstraintDecl = m_astBuilder->create<GenericTypeConstraintDecl>();
+ synConstraintDecl->nameAndLoc = constraintDecl->getNameAndLoc();
+ synConstraintDecl->parentDecl = synGenericDecl;
+
+ // For constraints of type T : Interface, where T is a simple type parameter,
+ // find the declaration of T
+ //
+ if (auto typeParamDecl = as<DeclRefType>(constraintDecl->sub.type)->declRef.as<GenericTypeParamDecl>().getDecl())
+ {
+ auto synTypeParamDecl = mapOrigToSynTypeParams[typeParamDecl];
+
+ // Construct a DeclRefExpr from the type parameter.
+ auto synTypeParamDeclRef = makeDeclRef(synTypeParamDecl.getValue());
+
+ auto synTypeParamDeclRefExpr = m_astBuilder->create<VarExpr>();
+ synTypeParamDeclRefExpr->declRef = synTypeParamDeclRef;
+ synTypeParamDeclRefExpr->type = getTypeForDeclRef(m_astBuilder, synTypeParamDeclRef, SourceLoc());
+
+ synConstraintDecl->sub = TypeExp(synTypeParamDeclRefExpr);
+ synConstraintDecl->sup = constraintDecl->sup;
+ synGenericDecl->members.add(synConstraintDecl);
+ }
+ else
+ {
+ SLANG_UNEXPECTED("Cannot perform synthesis for requirements with complex type constraints.");
+ }
+ }
+ }
+
+ // Get outer substitutions. (This inner-most substition
+ // must be a ThisTypeSubstition)
+ //
+ Substitutions* outer = nullptr;
+ if (auto thisTypeSubst = findThisTypeSubstitution(
+ requiredMemberDeclRef.getSubst(),
+ as<InterfaceDecl>(requiredMemberDeclRef.getParent(m_astBuilder)).getDecl()))
+ {
+ outer = thisTypeSubst;
+ }
+
+ // Override generic pointer to point to the original generic container.
+ // This will create a substitution of the synthesized parameters for the
+ // original parameters.
+ //
+ GenericSubstitution* requiredFuncSubsts = createDefaultSubstitutionsForGeneric(m_astBuilder, this, requiredMemberDeclRef.getDecl(), outer);
+ DeclRef<Decl> requiredFuncDeclRef = m_astBuilder->getSpecializedDeclRef(requiredMemberDeclRef.getDecl()->inner, requiredFuncSubsts);
+
+ GenericSubstitution* substSynParamsForOrigGeneric = m_astBuilder->getOrCreateGenericSubstitution(
+ outer,
+ requiredMemberDeclRef.getDecl(),
+ createDefaultSubstitutionsForGeneric(m_astBuilder, this, synGenericDecl, nullptr)->getArgs());
+
+ // Substitute parameters of the synthesized generic for the parameters of the original generic.
+ requiredFuncDeclRef = substituteDeclRef(substSynParamsForOrigGeneric, m_astBuilder, requiredFuncDeclRef);
+
+ SLANG_ASSERT(requiredFuncDeclRef.as<FuncDecl>());
+
+ synGenericDecl->inner = synthesizeMethodSignatureForRequirementWitness(
+ context,
+ requiredFuncDeclRef.as<FuncDecl>(),
+ synArgs,
+ synThis);
+ synGenericDecl->inner->parentDecl = synGenericDecl;
+
+ return synGenericDecl;
+ }
+
FuncDecl* SemanticsVisitor::synthesizeMethodSignatureForRequirementWitness(
ConformanceCheckingContext* context,
DeclRef<FuncDecl> requiredMemberDeclRef,
@@ -3274,12 +3403,30 @@ namespace Slang
switch (builtinAttr->kind)
{
case BuiltinRequirementKind::DAddFunc:
- case BuiltinRequirementKind::DMulFunc:
case BuiltinRequirementKind::DZeroFunc:
return trySynthesizeDifferentialMethodRequirementWitness(
context,
requiredFuncDeclRef,
- witnessTable);
+ witnessTable,
+ SynthesisPattern::AllInductive);
+ }
+ }
+ return false;
+ }
+
+ // For generic decl, check if we match DMulFunc, and synthesize the method.
+ if (auto requiredGenericDeclRef = requiredMemberDeclRef.as<GenericDecl>())
+ {
+ if (auto builtinAttr = getInner(requiredGenericDeclRef)->findModifier<BuiltinRequirementModifier>())
+ {
+ switch (builtinAttr->kind)
+ {
+ case BuiltinRequirementKind::DMulFunc:
+ return trySynthesizeDifferentialMethodRequirementWitness(
+ context,
+ requiredGenericDeclRef,
+ witnessTable,
+ SynthesisPattern::FixedFirstArg);
}
}
return false;
@@ -3330,7 +3477,15 @@ namespace Slang
return false;
}
- Stmt* _synthesizeMemberAssignMemberHelper(ASTSynthesizer& synth, Name* funcName, Type* leftType, Expr* leftValue, List<Expr*>&& args, int nestingLevel = 0)
+ Stmt* _synthesizeMemberAssignMemberHelper(
+ ASTSynthesizer& synth,
+ Name* funcName,
+ Type* leftType,
+ Expr* leftValue,
+ List<Expr*>&& args,
+ List<Expr*>&& genericArgs,
+ List<bool>&& inductiveArgMask,
+ int nestingLevel = 0)
{
if (nestingLevel > 16)
return nullptr;
@@ -3342,11 +3497,24 @@ namespace Slang
auto forStmt = synth.emitFor(synth.emitIntConst(0), synth.emitGetArrayLengthExpr(leftValue), indexVar);
addModifier(forStmt, synth.getBuilder()->create<ForceUnrollAttribute>());
auto innerLeft = synth.emitIndexExpr(leftValue, synth.emitVarExpr(indexVar));
- for (auto& arg : args)
+
+ for (auto ii = 0; ii < args.getCount(); ++ii)
{
- arg = synth.emitIndexExpr(arg, synth.emitVarExpr(indexVar));
+ auto& arg = args[ii];
+ if (inductiveArgMask[ii])
+ arg = synth.emitIndexExpr(arg, synth.emitVarExpr(indexVar));
}
- auto assignStmt = _synthesizeMemberAssignMemberHelper(synth, funcName, arrayType->getElementType(), innerLeft, _Move(args), nestingLevel + 1);
+
+ auto assignStmt = _synthesizeMemberAssignMemberHelper(
+ synth,
+ funcName,
+ arrayType->getElementType(),
+ innerLeft,
+ _Move(args),
+ _Move(genericArgs),
+ _Move(inductiveArgMask),
+ nestingLevel + 1);
+
synth.popScope();
if (!assignStmt)
return nullptr;
@@ -3354,13 +3522,18 @@ namespace Slang
}
auto callee = synth.emitMemberExpr(leftType, funcName);
+
+ if (genericArgs.getCount() > 0)
+ callee = synth.emitGenericAppExpr(callee, _Move(genericArgs));
+
return synth.emitAssignStmt(leftValue, synth.emitInvokeExpr(callee, _Move(args)));
}
bool SemanticsVisitor::trySynthesizeDifferentialMethodRequirementWitness(
ConformanceCheckingContext* context,
DeclRef<Decl> requirementDeclRef,
- RefPtr<WitnessTable> witnessTable)
+ RefPtr<WitnessTable> witnessTable,
+ SynthesisPattern pattern)
{
// We support two cases of synthesis here.
// Case 1 is that there the associated Differential type is defined to be `DifferentialBottom`.
@@ -3371,9 +3544,10 @@ namespace Slang
// ```
// static TResult requiredMethod(TParam1 p0, TParam2 p1, ...)
// ```
- // Where TResult, TParam1, TParam2 is either `This` or `Differential`,
- // We synthesize a memberwise dispatch to compute each field of `TResult`,
- // resulting an implementation of the form:
+ // Where TResult,TParam1, TParam2 is either `This` or `Differential`,
+ // We synthesize a memberwise dispatch to compute each field of `TResult`.
+ // Multiple patterns are supported (see SemanticsVisitor::SynthesisPattern for a full list)
+ // For AllInductive, we synthesize an implementation of the form:
// ```
// [BackwardDifferentiable]
// static TResult requiredMethod(TParam1 p0, TParam2 p1, ...)
@@ -3404,13 +3578,32 @@ namespace Slang
ASTSynthesizer synth(m_astBuilder, getNamePool());
List<Expr*> synArgs;
+ List<Expr*> synGenericArgs;
ThisExpr* synThis = nullptr;
- auto synFunc = synthesizeMethodSignatureForRequirementWitness(
- context, requirementDeclRef.as<FuncDecl>(), synArgs, synThis);
+ FuncDecl* synFunc = nullptr;
+ GenericDecl* synGeneric = nullptr;
+
+ if (auto genericDeclRef = requirementDeclRef.as<GenericDecl>())
+ {
+ synGeneric = synthesizeGenericSignatureForRequirementWitness(
+ context, genericDeclRef, synArgs, synGenericArgs, synThis);
+ synFunc = as<FuncDecl>(synGeneric->inner);
+ }
+ else if (auto funcDeclRef = requirementDeclRef.as<FuncDecl>())
+ {
+ synFunc = synthesizeMethodSignatureForRequirementWitness(
+ context, funcDeclRef, synArgs, synThis);
+ }
+
+ SLANG_ASSERT(synFunc);
addModifier(synFunc, m_astBuilder->create<BackwardDifferentiableAttribute>());
- synFunc->parentDecl = context->parentDecl;
+ if (synGeneric)
+ synGeneric->parentDecl = context->parentDecl;
+ else
+ synFunc->parentDecl = context->parentDecl;
+
synth.pushContainerScope(synFunc);
auto blockStmt = m_astBuilder->create<BlockStmt>();
synFunc->body = blockStmt;
@@ -3438,23 +3631,71 @@ namespace Slang
// Construct reference exprs to the member's corresponding fields in each parameter.
List<Expr*> paramFields;
- int paramIndex = 0;
- for (auto arg : synArgs)
+ List<bool> inductiveArgMask;
+
+ switch (pattern)
{
- auto memberExpr = m_astBuilder->create<MemberExpr>();
- memberExpr->baseExpression = arg;
- // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is
- // Differential type.
- memberExpr->name = varMember->getName();
- paramFields.add(memberExpr);
- paramIndex++;
+ case SynthesisPattern::AllInductive:
+ {
+ int paramIndex = 0;
+ for (auto arg : synArgs)
+ {
+ auto memberExpr = m_astBuilder->create<MemberExpr>();
+ memberExpr->baseExpression = arg;
+ // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is
+ // Differential type.
+ memberExpr->name = varMember->getName();
+ paramFields.add(memberExpr);
+ inductiveArgMask.add(true);
+
+ paramIndex++;
+ }
+ break;
+ }
+ case SynthesisPattern::FixedFirstArg:
+ {
+ int paramIndex = 0;
+ for (auto arg : synArgs)
+ {
+ if (paramIndex == 0)
+ {
+ paramFields.add(arg);
+ inductiveArgMask.add(false);
+
+ paramIndex++;
+ }
+ else
+ {
+ auto memberExpr = m_astBuilder->create<MemberExpr>();
+ memberExpr->baseExpression = arg;
+ // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is
+ // Differential type.
+ memberExpr->name = varMember->getName();
+ paramFields.add(memberExpr);
+ inductiveArgMask.add(true);
+
+ paramIndex++;
+ }
+ }
+ break;
+ }
+ default:
+ SLANG_UNIMPLEMENTED_X("unhandled synthesis pattern");
+ break;
}
// Invoke the method for the field and assign the value to resultVar.
// TODO: we should probably fetch the name from `[DerivativeMember]` if `resultVarExpr`
// is Differential type.
auto leftVal = synth.emitMemberExpr(resultVarExpr, varMember->getName());
- if (!_synthesizeMemberAssignMemberHelper(synth, requirementDeclRef.getName(), memberType, leftVal, _Move(paramFields)))
+ if (!_synthesizeMemberAssignMemberHelper(
+ synth,
+ requirementDeclRef.getName(),
+ memberType,
+ leftVal,
+ _Move(paramFields),
+ _Move(synGenericArgs),
+ _Move(inductiveArgMask)))
return false;
}
@@ -3473,11 +3714,11 @@ namespace Slang
// This can be done by obtaining ThisTypeSubstitution from requirementDeclRef to get the
// generic substitution for outer generic parameters, and apply it here.
SubstitutionSet substSet;
- if (auto thisTypeSusbt = findThisTypeSubstitution(
+ if (auto thisTypeSubst = findThisTypeSubstitution(
requirementDeclRef.getSubst(),
as<InterfaceDecl>(requirementDeclRef.getDecl()->parentDecl)))
{
- if (auto declRefType = as<DeclRefType>(thisTypeSusbt->witness->sub))
+ if (auto declRefType = as<DeclRefType>(thisTypeSubst->witness->sub))
{
substSet = declRefType->declRef.getSubst();
}
@@ -3610,7 +3851,9 @@ namespace Slang
// requirement, it may be possible that we can still synthesis the
// implementation if this is one of the known builtin requirements.
// Otherwise, report diagnostic now.
- if (!requiredMemberDeclRef.getDecl()->hasModifier<BuiltinRequirementModifier>())
+ if (!requiredMemberDeclRef.getDecl()->hasModifier<BuiltinRequirementModifier>() &&
+ !(requiredMemberDeclRef.as<GenericDecl>() &&
+ getInner(requiredMemberDeclRef.as<GenericDecl>())->hasModifier<BuiltinRequirementModifier>()))
{
getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, subType, requiredMemberDeclRef);
getSink()->diagnose(requiredMemberDeclRef, Diagnostics::seeDeclarationOf, requiredMemberDeclRef);