summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-02-12 14:58:00 -0800
committerGitHub <noreply@github.com>2025-02-12 14:58:00 -0800
commit74852ceb6b3bcc018042aba3e30933b7b6fc09ef (patch)
treeb77dd62f0e74510fae0d2af8b7afa260ce9d49b8 /source
parent3f102afe1038882f336dc052a9954811150fa700 (diff)
Allow LHS of `where` to be any type. (#6333)
* Allow LHS of `where` to be any type. * Register free-form extensions when loading precompiled module. * Fix test. * Fix. * Fix `as<IRType>`. * try fix precompiled module test.
Diffstat (limited to 'source')
-rw-r--r--source/slang/hlsl.meta.slang13
-rw-r--r--source/slang/slang-check-decl.cpp133
-rw-r--r--source/slang/slang-check-impl.h13
-rw-r--r--source/slang/slang-ir-hlsl-legalize.cpp2
-rw-r--r--source/slang/slang-ir-insts.h2
-rw-r--r--source/slang/slang-ir-link.cpp4
-rw-r--r--source/slang/slang-ir-use-uninitialized-values.cpp2
-rw-r--r--source/slang/slang-lower-to-ir.cpp6
-rw-r--r--source/slang/slang-serialize-container.cpp16
9 files changed, 127 insertions, 64 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index a10e747c0..81b28b30a 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -524,6 +524,13 @@ interface ITexelElement
__init(Element x);
}
+extension<T:__BuiltinArithmeticType> T : ITexelElement
+{
+ typealias Element = T;
+ static const int elementCount = 1;
+ __intrinsic_op(0) __init(Element x);
+}
+
${{{
// Scalar types that can be used as texel element.
const char* texeElementScalarTypes[] = {
@@ -539,12 +546,6 @@ const char* texeElementScalarTypes[] = {
for (auto elementType : texeElementScalarTypes)
{
}}}
-extension $(elementType) : ITexelElement
-{
- typealias Element = $(elementType);
- static const int elementCount = 1;
- __intrinsic_op(0) __init(Element x);
-}
extension<int N> vector<$(elementType), N> : ITexelElement
{
typealias Element = $(elementType);
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 76074f551..1ef5b1cec 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -9789,6 +9789,46 @@ void SemanticsVisitor::validateArraySizeForVariable(VarDeclBase* varDecl)
}
}
+bool getExtensionTargetDeclList(
+ ASTBuilder* astBuilder,
+ DeclRefType* targetDeclRefType,
+ ExtensionDecl* extDecl,
+ ShortList<AggTypeDecl*>& targetDecls)
+{
+ if (auto aggTypeDeclRef = targetDeclRefType->getDeclRef().as<AggTypeDecl>())
+ {
+ auto aggTypeDecl = aggTypeDeclRef.getDecl();
+
+ targetDecls.add(aggTypeDecl);
+ return true;
+ }
+
+ auto genericParamDeclRef = targetDeclRefType->getDeclRef().as<GenericTypeParamDeclBase>();
+ if (!genericParamDeclRef)
+ return false;
+
+ auto genericParent = as<GenericDecl>(genericParamDeclRef.getParent().getDecl());
+ if (!genericParent)
+ return false;
+
+ if (genericParent != extDecl->parentDecl)
+ return false;
+
+ for (auto member : getMembersOfType<GenericTypeConstraintDecl>(astBuilder, genericParent))
+ {
+ if (getSub(astBuilder, member) == targetDeclRefType)
+ {
+ auto baseType = getSup(astBuilder, member);
+ if (auto baseTypeDecl = isDeclRefTypeOf<AggTypeDecl>(baseType))
+ {
+ targetDecls.add(baseTypeDecl.getDecl());
+ }
+ }
+ }
+ return targetDecls.getCount() != 0;
+}
+
+
void SemanticsDeclBasesVisitor::_validateExtensionDeclTargetType(ExtensionDecl* decl)
{
if (auto targetDeclRefType = as<DeclRefType>(decl->targetType))
@@ -11582,8 +11622,8 @@ void checkDerivativeAttributeImpl(
auto derivativeFuncThisType = getTypeForThisExpr(visitor, calleeFuncDeclRef);
// If the function is a member function, we need to check that the
- // `this` type matches the expected type. This will ensure that after lowering to
- // IR, the two functions are compatible.
+ // `this` type matches the expected type. This will ensure that after lowering
+ // to IR, the two functions are compatible.
//
if (!areTypesCompatibile(visitor, funcThisType, derivativeFuncThisType))
{
@@ -11971,8 +12011,9 @@ void checkDerivativeOfAttributeImpl(
if (as<ErrorType>(resolved->type.type))
{
- // If we can't resolve a type, something went wrong. If we're working with a generic
- // decl, the most likely cause is a failure of generic argument inference.
+ // If we can't resolve a type, something went wrong. If we're working with a
+ // generic decl, the most likely cause is a failure of generic argument
+ // inference.
//
visitor->getSink()->diagnose(
derivativeOfAttr,
@@ -12284,8 +12325,8 @@ bool SemanticsDeclAttributesVisitor::collectInitializableMembers(
// Find the base type's members first
for (auto inheritanceMember : structDecl->getMembersOfType<InheritanceDecl>())
{
- // For base types, we need to pick their parameters of the constructor to the derived type's
- // constructor
+ // For base types, we need to pick their parameters of the constructor to the derived
+ // type's constructor
if (auto baseTypeDeclRef = isDeclRefTypeOf<StructDecl>(inheritanceMember->base.type))
{
// We should only find the member initialization constructor because it is the
@@ -12294,15 +12335,15 @@ bool SemanticsDeclAttributesVisitor::collectInitializableMembers(
baseTypeDeclRef.getDecl(),
ConstructorDecl::ConstructorFlavor::SynthesizedMemberInit);
- // The constructor has to have higher or equal visibility level than the struct itself,
- // otherwise, it's not accessible so we will not pick up.
+ // The constructor has to have higher or equal visibility level than the struct
+ // itself, otherwise, it's not accessible so we will not pick up.
if (ctor && getDeclVisibility(ctor) >= ctorVisibility)
{
for (ParamDecl* param : ctor->getParameters())
{
- // Because the parameters in the ctor must have the higher or equal visibility
- // than the ctor itself, we don't need to check the visibility level of the
- // parameter.
+ // Because the parameters in the ctor must have the higher or equal
+ // visibility than the ctor itself, we don't need to check the visibility
+ // level of the parameter.
resultMembers.add(param);
}
}
@@ -12342,10 +12383,9 @@ static Expr* _getParamDefaultValue(SemanticsVisitor* visitor, VarDeclBase* varDe
bool SemanticsDeclAttributesVisitor::_synthesizeCtorSignature(StructDecl* structDecl)
{
- // If a type or its base type already defines any explicit constructors, do not synthesize any
- // constructors.
- // See
- // https://github.com/shader-slang/spec/blob/main/proposals/004-initialization.md#inheritance-initialization
+ // If a type or its base type already defines any explicit constructors, do not synthesize
+ // any constructors. see:
+ // https://github.com/shader-slang/slang/blob/master/docs/proposals/004-initialization.md#inheritance-initialization
if (_hasExplicitConstructor(structDecl, true))
return false;
@@ -12397,9 +12437,9 @@ bool SemanticsDeclAttributesVisitor::_synthesizeCtorSignature(StructDecl* struct
ctorParam->loc = ctor->loc;
ctor->members.add(ctorParam);
- // We need to ensure member is `no_diff` if it cannot be differentiated, `ctor` modifiers do
- // not matter in this case since member-wise ctor is always differentiable or "treat as
- // differentiable".
+ // We need to ensure member is `no_diff` if it cannot be differentiated, `ctor`
+ // modifiers do not matter in this case since member-wise ctor is always differentiable
+ // or "treat as differentiable".
if (!isTypeDifferentiable(member->getType()) || member->hasModifier<NoDiffModifier>())
{
auto noDiffMod = m_astBuilder->create<NoDiffModifier>();
@@ -12559,7 +12599,8 @@ void SemanticsDeclAttributesVisitor::visitStructDecl(StructDecl* structDecl)
totalWidth += int(thisFieldWidth);
groupInfo.add({memberIndex, int(thisFieldWidth), t, bfm});
}
- // If the struct ended with a bitpacked member, then make sure we don't forget the last group
+ // If the struct ended with a bitpacked member, then make sure we don't forget the last
+ // group
dispatchSomeBitPackedMembers();
}
@@ -12630,8 +12671,8 @@ static void _propagateRequirement(
if (!isAnyInvalid && resultCaps.isInvalid())
{
// If joining the referenced decl's requirements results an invalid capability set,
- // then the decl is using things that require conflicting set of capabilities, and we should
- // diagnose an error.
+ // then the decl is using things that require conflicting set of capabilities, and we
+ // should diagnose an error.
if (referencedDecl && decl)
{
maybeDiagnose(
@@ -12736,17 +12777,17 @@ struct CapabilityDeclReferenceVisitor
// `calling_functions_targets`:
// ``` default_target = calling_functions_targets-{other_case_targets} ```
//
- // * `calling_functions_capability` = `requirement attribute` of the calling function;
- // if missing
+ // * `calling_functions_capability` = `requirement attribute` of the calling
+ // function; if missing
// we can assume it is `any_target`
//
- // * `{other_case_targets}` = set of all capabilities all `case` statments target inside
- // the `__target_switch`
+ // * `{other_case_targets}` = set of all capabilities all `case` statments target
+ // inside the `__target_switch`
- // If we do not handle `default:`, the codegen will fail when trying to find a specific
- // codegen target not handled explicitly by a `case` statment.
- // We must also ensure the `default` case is last so we have priority to hit `case`
- // statments and can preprocess `case` statments before the `default` case.
+ // If we do not handle `default:`, the codegen will fail when trying to find a
+ // specific codegen target not handled explicitly by a `case` statment. We must also
+ // ensure the `default` case is last so we have priority to hit `case` statments and
+ // can preprocess `case` statments before the `default` case.
CapabilitySet targetCap;
if (CapabilityName(stmt->targetCases[targetCaseIndex]->capability) ==
CapabilityName::Invalid)
@@ -12901,7 +12942,8 @@ CapabilitySet SemanticsDeclCapabilityVisitor::getDeclaredCapabilitySet(Decl* dec
// For every existing target, we want to join their requirements together.
// If the the parent defines additional targets, we want to add them to the disjunction set.
// For example:
- // [require(glsl)] struct Parent { [require(glsl, glsl_ext_1)] [require(spirv)] void foo(); }
+ // [require(glsl)] struct Parent { [require(glsl, glsl_ext_1)] [require(spirv)] void
+ // foo(); }
// The requirement for `foo` should be glsl+glsl_ext_1 | spirv.
//
CapabilitySet declaredCaps;
@@ -12990,8 +13032,8 @@ static inline void _dispatchCapabilitiesVisitorOfFunctionDecl(
void SemanticsDeclCapabilityVisitor::visitFunctionDeclBase(FunctionDeclBase* funcDecl)
{
- // If the function is an entrypoint and specifies a target stage, add the capabilities to our
- // function capabilities.
+ // If the function is an entrypoint and specifies a target stage, add the capabilities to
+ // our function capabilities.
_dispatchCapabilitiesVisitorOfFunctionDecl(
this,
funcDecl,
@@ -13012,8 +13054,8 @@ void SemanticsDeclCapabilityVisitor::visitFunctionDeclBase(FunctionDeclBase* fun
auto vis = getDeclVisibility(funcDecl);
- // If 0 capabilities were annotated on a function, capabilities are inferred from the function
- // body
+ // If 0 capabilities were annotated on a function, capabilities are inferred from the
+ // function body
if (declaredCaps.isEmpty())
{
declaredCaps = funcDecl->inferredCapabilityRequirements;
@@ -13133,7 +13175,8 @@ DeclVisibility getDeclVisibility(Decl* decl)
: parentModule->defaultVisibility;
}
- // Members of other agg type decls will have their default visibility capped to the parents'.
+ // Members of other agg type decls will have their default visibility capped to the
+ // parents'.
if (as<NamespaceDecl>(decl))
{
return DeclVisibility::Public;
@@ -13345,15 +13388,16 @@ void SemanticsDeclCapabilityVisitor::diagnoseUndeclaredCapability(
// There are two causes for why type checking failed on failedAvailableSet.
// The first scenario is that failedAvailableSet defines a set of capabilities on a
- // compilation target (e.g. hlsl) that isn't defined by some callees, for example, if we have
- // a function:
+ // compilation target (e.g. hlsl) that isn't defined by some callees, for example, if we
+ // have a function:
// [require(hlsl)] // <-- failedAvailableSet
// [require(cpp)]
// void caller()
// {
// printf(); // assume this is defined for (cpp | cuda).
// }
- // In this case we should diagnose error reporting printf isn't defined on a required target.
+ // In this case we should diagnose error reporting printf isn't defined on a required
+ // target.
//
// Now, we detect if we are case 1.
@@ -13370,8 +13414,8 @@ void SemanticsDeclCapabilityVisitor::diagnoseUndeclaredCapability(
decl,
outFailedAtom);
- // Anything defined on a non-failed target atom may be the culprit to why we fail having
- // a target capability. Print out all possible culprits.
+ // Anything defined on a non-failed target atom may be the culprit to why we fail
+ // having a target capability. Print out all possible culprits.
CapabilityAtomSet failedAtomSet;
failedAtomSet.add((UInt)outFailedAtom);
CapabilityAtomSet targetsNotUsedSet;
@@ -13395,9 +13439,12 @@ void SemanticsDeclCapabilityVisitor::diagnoseUndeclaredCapability(
}
}
- //// The second scenario is when the callee is using a capability that is not provided by the
- /// requirement. / For example: / [require(hlsl,b,c)] / void caller() / { / useD();
- ///// require capability (hlsl,d) / } / In this case we should report that useD() is using a
+ //// The second scenario is when the callee is using a capability that is not provided by
+ /// the
+ /// requirement. / For example: / [require(hlsl,b,c)] / void caller() / { /
+ /// useD();
+ ///// require capability (hlsl,d) / } / In this case we should report that useD() is
+ /// using a
/// capability that is not declared by caller.
////
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 350362a0e..59290f8ad 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -3066,4 +3066,17 @@ bool resolveStageOfProfileWithEntryPoint(
const List<RefPtr<TargetRequest>>& targets,
FuncDecl* entryPointFuncDecl,
DiagnosticSink* sink);
+
+// For an extensions decl, collect a list of decls on which the extension might be applying to.
+// For example, if we see a `extension Foo`, return a `Decl*` that represents `struct Foo`.
+// In the case of free-form generic extensions i.e. `extension<T:IFoo> T : IBar`, return `IFoo`.
+// These are the decls that we need to register the extension with in
+// `mapTypeToCandidateExtensions`.
+// Returns true when any base decls are found.
+bool getExtensionTargetDeclList(
+ ASTBuilder* astBuilder,
+ DeclRefType* targetDeclRefType,
+ ExtensionDecl* extDeclRef,
+ ShortList<AggTypeDecl*>& targetDecls);
+
} // namespace Slang
diff --git a/source/slang/slang-ir-hlsl-legalize.cpp b/source/slang/slang-ir-hlsl-legalize.cpp
index ec2419985..7116d635e 100644
--- a/source/slang/slang-ir-hlsl-legalize.cpp
+++ b/source/slang/slang-ir-hlsl-legalize.cpp
@@ -36,7 +36,7 @@ void searchChildrenForForceVarIntoStructTemporarily(IRModule* module, IRInst* in
continue;
auto forceStructArg = arg->getOperand(0);
auto forceStructBaseType =
- as<IRType>(forceStructArg->getDataType()->getOperand(0));
+ (IRType*)(forceStructArg->getDataType()->getOperand(0));
IRBuilder builder(call);
if (forceStructBaseType->getOp() == kIROp_StructType)
{
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 9c3892c0e..dbefa68c7 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -1143,7 +1143,7 @@ struct IRMixedDifferentialInstDecoration : IRAutodiffInstDecoration
IRUse pairType;
IR_LEAF_ISA(MixedDifferentialInstDecoration)
- IRType* getPairType() { return as<IRType>(getOperand(0)); }
+ IRType* getPairType() { return (IRType*)(getOperand(0)); }
};
struct IRRecomputeBlockDecoration : IRAutodiffInstDecoration
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index 364e58c48..53bbbda9e 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -614,8 +614,8 @@ IRWitnessTable* cloneWitnessTableImpl(
IRWitnessTable* clonedTable = dstTable;
if (!clonedTable)
{
- auto clonedBaseType = cloneType(context, as<IRType>(originalTable->getConformanceType()));
- auto clonedSubType = cloneType(context, as<IRType>(originalTable->getConcreteType()));
+ auto clonedBaseType = cloneType(context, (IRType*)(originalTable->getConformanceType()));
+ auto clonedSubType = cloneType(context, (IRType*)(originalTable->getConcreteType()));
clonedTable = builder->createWitnessTable(clonedBaseType, clonedSubType);
}
cloneSimpleGlobalValueImpl(context, originalTable, originalValues, clonedTable, registerValue);
diff --git a/source/slang/slang-ir-use-uninitialized-values.cpp b/source/slang/slang-ir-use-uninitialized-values.cpp
index 51de2117c..71aae5923 100644
--- a/source/slang/slang-ir-use-uninitialized-values.cpp
+++ b/source/slang/slang-ir-use-uninitialized-values.cpp
@@ -186,7 +186,7 @@ static bool canIgnoreType(IRType* type, IRType* upper)
if (auto spec = as<IRSpecialize>(type))
{
IRInst* inner = getResolvedInstForDecorations(spec);
- IRType* innerType = as<IRType>(inner);
+ IRType* innerType = (IRType*)(inner);
return canIgnoreType(innerType, upper);
}
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 06cdf430b..ed8a52b9e 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -9453,10 +9453,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// should handle propgation of value-size information from constraints
// back to generic parameters?
//
- if (auto declRefType = as<DeclRefType>(constraintDecl->sub.type))
+ if (auto genParamDeclRef =
+ isDeclRefTypeOf<GenericTypeParamDeclBase>(constraintDecl->sub.type))
{
- auto typeParamDeclVal =
- subContext->findLoweredDecl(declRefType->getDeclRef().getDecl());
+ auto typeParamDeclVal = subContext->findLoweredDecl(genParamDeclRef.getDecl());
SLANG_ASSERT(typeParamDeclVal && typeParamDeclVal->val);
subBuilder->addTypeConstraintDecoration(typeParamDeclVal->val, supType);
}
diff --git a/source/slang/slang-serialize-container.cpp b/source/slang/slang-serialize-container.cpp
index e5775f526..b5121d373 100644
--- a/source/slang/slang-serialize-container.cpp
+++ b/source/slang/slang-serialize-container.cpp
@@ -5,6 +5,7 @@
#include "../core/slang-math.h"
#include "../core/slang-stream.h"
#include "../core/slang-text-io.h"
+#include "slang-check-impl.h"
#include "slang-compiler.h"
#include "slang-mangled-lexer.h"
#include "slang-parser.h"
@@ -813,15 +814,16 @@ static List<ExtensionDecl*>& _getCandidateExtensionList(
if (auto targetDeclRefType =
as<DeclRefType>(extensionDecl->targetType))
{
- // Attach our extension to that type as a candidate...
- if (auto aggTypeDeclRef =
- targetDeclRefType->getDeclRef()
- .as<AggTypeDecl>())
+ ShortList<AggTypeDecl*> baseDecls;
+ getExtensionTargetDeclList(
+ astBuilder,
+ targetDeclRefType,
+ extensionDecl,
+ baseDecls);
+ for (auto baseDecl : baseDecls)
{
- auto aggTypeDecl = aggTypeDeclRef.getDecl();
-
_getCandidateExtensionList(
- aggTypeDecl,
+ baseDecl,
moduleDecl->mapTypeToCandidateExtensions)
.add(extensionDecl);
}