summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2017-10-31 11:12:08 -0400
committerYong He <yonghe@outlook.com>2017-10-31 11:12:08 -0400
commit093bf1eb9149ba82258b5a5a159b2f54263b17c2 (patch)
tree8ee5c2bd4b730d3bd446546dd50f0284d3e47161
parent84f381cc180b3176d6a58da4085ee8470f246922 (diff)
work in-progress: type checking associated types
-rw-r--r--source/slang/check.cpp96
-rw-r--r--source/slang/decl-defs.h5
-rw-r--r--source/slang/diagnostic-defs.h79
-rw-r--r--source/slang/emit.cpp12
-rw-r--r--source/slang/lower.cpp15
-rw-r--r--source/slang/parser.cpp31
-rw-r--r--source/slang/syntax.cpp14
-rw-r--r--source/slang/type-defs.h2
-rw-r--r--tests/compute/assoctype-simple.slang13
9 files changed, 156 insertions, 111 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp
index 0c35c4bf4..fe7498d86 100644
--- a/source/slang/check.cpp
+++ b/source/slang/check.cpp
@@ -1538,7 +1538,48 @@ namespace Slang
requiredInitDecl);
}
}
+ else if (auto subStructTypeDecl = dynamic_cast<AggTypeDecl*>(memberDecl))
+ {
+ // this is a sub type (e.g. nested struct declaration) in an aggregate type
+ // check if this sub type declaration satisfies the constraints defined by the associated type
+ if (auto requiredTypeDecl = requiredMemberDeclRef.As<AssocTypeDecl>())
+ {
+ bool conformance = true;
+ for (auto & inheritanceDecl : requiredTypeDecl.getDecl()->getMembersOfType<InheritanceDecl>())
+ {
+ conformance = conformance && checkConformance(subStructTypeDecl, inheritanceDecl.Ptr());
+ }
+ return conformance;
+ }
+ }
+ else if (auto typedefDecl = dynamic_cast<TypeDefDecl*>(memberDecl))
+ {
+ // this is a type-def decl in an aggregate type
+ // check if the specified type satisfies the constraints defined by the associated type
+ if (auto requiredTypeDecl = requiredMemberDeclRef.As<AssocTypeDecl>())
+ {
+ auto constraintList = requiredTypeDecl.getDecl()->getMembersOfType<InheritanceDecl>();
+ if (constraintList.Count())
+ {
+ auto declRefType = typedefDecl->type->AsDeclRefType();
+ if (!declRefType)
+ return false;
+ auto subStructTypeDecl = declRefType->declRef.getDecl()->As<AggTypeDecl>();
+ //TODO: What do we do if type is a generic specialization?
+ // i.e. if the struct defines typedef Generic<float> T;
+ // how to check if T satisfies the associatedtype constraints?
+ // the code below will only work when T is defined to be a simple aggregated type (no generics).
+ bool conformance = true;
+ for (auto & inheritanceDecl : constraintList)
+ {
+ conformance = conformance && checkConformance(subStructTypeDecl, inheritanceDecl.Ptr());
+ }
+ return conformance;
+ }
+ return true;
+ }
+ }
// Default: just assume that thing aren't being satisfied.
return false;
}
@@ -1623,11 +1664,13 @@ namespace Slang
// declares conformance to the interface `interfaceDeclRef`,
// (via the given `inheritanceDecl`) actually provides
// members to satisfy all the requirements in the interface.
- void checkInterfaceConformance(
+ bool checkInterfaceConformance(
AggTypeDecl* typeDecl,
InheritanceDecl* inheritanceDecl,
DeclRef<InterfaceDecl> interfaceDeclRef)
{
+ bool result = true;
+
// We need to check the declaration of the interface
// before we can check that we conform to it.
EnsureDecl(interfaceDeclRef.getDecl());
@@ -1654,7 +1697,7 @@ namespace Slang
// to the inherited interface.
//
// TODO: we *really* need a linearization step here!!!!
- checkConformanceToType(
+ result = result && checkConformanceToType(
typeDecl,
inheritanceDecl,
getBaseType(requiredInheritanceDeclRef));
@@ -1670,16 +1713,20 @@ namespace Slang
requiredMemberDeclRef);
if (!conformanceWitness)
+ {
+ result = false;
continue;
+ }
// Store that witness into a table stored on the `inheritnaceDecl`
// so that it can be used for downstream code generation.
inheritanceDecl->requirementWitnesses.Add(requiredMemberDeclRef, conformanceWitness);
}
+ return result;
}
- void checkConformanceToType(
+ bool checkConformanceToType(
AggTypeDecl* typeDecl,
InheritanceDecl* inheritanceDecl,
Type* baseType)
@@ -1692,29 +1739,29 @@ namespace Slang
// The type is stating that it conforms to an interface.
// We need to check that it provides all of the members
// required by that interface.
- checkInterfaceConformance(
+ return checkInterfaceConformance(
typeDecl,
inheritanceDecl,
baseInterfaceDeclRef);
- return;
}
}
getSink()->diagnose(inheritanceDecl, Diagnostics::unimplemented, "type not supported for inheritance");
+ return false;
}
// Check that the type declaration `typeDecl`, which
// declares that it inherits from another type via
// `inheritanceDecl` actually does what it needs to
// for that inheritance to be valid.
- void checkConformance(
+ bool checkConformance(
AggTypeDecl* typeDecl,
InheritanceDecl* inheritanceDecl)
{
// Look at the type being inherited from, and validate
// appropriately.
auto baseType = inheritanceDecl->base.type;
- checkConformanceToType(typeDecl, inheritanceDecl, baseType);
+ return checkConformanceToType(typeDecl, inheritanceDecl, baseType);
}
void visitAggTypeDecl(AggTypeDecl* decl)
@@ -1758,6 +1805,11 @@ namespace Slang
// Don't check that an interface conforms to the
// things it inherits from.
}
+ else if (auto assocTypeDecl = dynamic_cast<AssocTypeDecl*>(decl))
+ {
+ // Don't check that an associated type decl conforms to the
+ // things it inherits from.
+ }
else
{
// For non-interface types we need to check conformance.
@@ -1794,6 +1846,24 @@ namespace Slang
decl->SetCheckState(DeclCheckState::Checked);
}
+ void visitAssocTypeDecl(AssocTypeDecl* decl)
+ {
+ if (decl->IsChecked(DeclCheckState::Checked)) return;
+ decl->SetCheckState(DeclCheckState::CheckedHeader);
+
+ // assoctype only allowed in an interface
+ auto interfaceDecl = decl->ParentDecl->As<InterfaceDecl>();
+ if (!interfaceDecl)
+ getSink()->diagnose(decl, Slang::Diagnostics::assocTypeInInterfaceOnly);
+
+ // Now check all of the member declarations.
+ for (auto member : decl->Members)
+ {
+ checkDecl(member);
+ }
+ decl->SetCheckState(DeclCheckState::Checked);
+ }
+
void checkStmt(Stmt* stmt)
{
if (!stmt) return;
@@ -5647,8 +5717,13 @@ namespace Slang
RefPtr<Expr> visitInvokeExpr(InvokeExpr *expr)
{
+ if (auto appExpr = expr->FunctionExpr->As<GenericAppExpr>())
+ if (auto varExpr = appExpr->FunctionExpr->As<VarExpr>())
+ if (varExpr->name->text == "test")
+ printf("break");
// check the base expression first
expr->FunctionExpr = CheckExpr(expr->FunctionExpr);
+
// Next check the argument expressions
for (auto & arg : expr->Arguments)
@@ -6386,7 +6461,12 @@ namespace Slang
auto type = getFuncType(session, funcDeclRef);
return QualType(type);
}
-
+ else if (auto assocTypeDeclRef = declRef.As<AssocTypeDecl>())
+ {
+ auto type = DeclRefType::Create(session, assocTypeDeclRef);
+ *outTypeResult = type;
+ return QualType(getTypeType(type));
+ }
if( sink )
{
sink->diagnose(declRef, Diagnostics::unimplemented, "cannot form reference to this kind of declaration");
diff --git a/source/slang/decl-defs.h b/source/slang/decl-defs.h
index 4021f5a38..bb8c26f58 100644
--- a/source/slang/decl-defs.h
+++ b/source/slang/decl-defs.h
@@ -122,9 +122,8 @@ SYNTAX_CLASS(TypeDefDecl, SimpleTypeDecl)
SYNTAX_FIELD(TypeExp, type)
END_SYNTAX_CLASS()
-// An 'assoctype' declaration
-SYNTAX_CLASS(AssocTypeDecl, SimpleTypeDecl)
- SYNTAX_FIELD(TypeExp, constraint)
+// An 'assoctype' declaration, it is a container of inheritance clauses
+SYNTAX_CLASS(AssocTypeDecl, ContainerDecl)
END_SYNTAX_CLASS()
// A scope for local declarations (e.g., as part of a statement)
diff --git a/source/slang/diagnostic-defs.h b/source/slang/diagnostic-defs.h
index cca1c4869..86b9b74ec 100644
--- a/source/slang/diagnostic-defs.h
+++ b/source/slang/diagnostic-defs.h
@@ -191,83 +191,10 @@ DIAGNOSTIC(30035, Error, componentOverloadTypeMismatch, "'$0': type of overloade
DIAGNOSTIC(30041, Error, bitOperationNonIntegral, "bit operation: operand must be integral type.")
DIAGNOSTIC(30047, Error, argumentExpectedLValue, "argument passed to parameter '$0' must be l-value.")
DIAGNOSTIC(30051, Error, invalidValueForArgument, "invalid value for argument '$0'")
-DIAGNOSTIC(30052, Error, ordinaryFunctionAsModuleArgument, "ordinary functions not allowed as argument to function-typed module parameter.")
-DIAGNOSTIC(30079, Error, selectPrdicateTypeMismatch, "selector must evaluate to bool.");
-DIAGNOSTIC(30080, Error, selectValuesTypeMismatch, "the two value expressions in a select clause must have same type.");
-DIAGNOSTIC(31040, Error, undefinedTypeName, "undefined type name: '$0'.")
-DIAGNOSTIC(32013, Error, circularReferenceNotAllowed, "'$0': circular reference is not allowed.");
-DIAGNOSTIC(32014, Error, shaderDoesProvideRequirement, "shader '$0' does not provide '$1' as required by '$2'.")
-DIAGNOSTIC(32015, Error, argumentNotAvilableInWorld, "argument '$0' is not available in world '$1' as required by '$2'.")
-DIAGNOSTIC(32015, Error, componentNotAvilableInWorld, "component '$0' is not available in world '$1' as required by '$2'.")
-DIAGNOSTIC(32047, Error, firstArgumentToImportNotComponent, "first argument of an import operator call does not resolve to a component.");
-DIAGNOSTIC(32051, Error, componentTypeNotWhatPipelineRequires, "component '$0' has type '$1', but pipeline '$2' requires it to be '$3'.")
-DIAGNOSTIC(32052, Error, shaderDoesNotDefineComponentAsRequiredByPipeline, "shader '$0' does not define '$1' as required by pipeline '$2''.")
-DIAGNOSTIC(33001, Error, worldNameAlreadyDefined, "world '$0' is already defined.")
-DIAGNOSTIC(33002, Error, explicitPipelineSpecificationRequiredForShader, "explicit pipeline specification required for shader '$0' because multiple pipelines are defined in current context.")
-DIAGNOSTIC(33003, Error, cannotDefineComponentsInAPipeline, "cannot define components in a pipeline.")
-DIAGNOSTIC(33004, Error, undefinedWorldName, "undefined world name '$0'.")
-DIAGNOSTIC(33005, Error, abstractWorldAsTargetOfImport, "abstract world cannot appear as target as an import operator.")
-
-// Note(tfoley): This is a duplicate of 33004 above.
-DIAGNOSTIC(33006, Error, undefinedWorldName2, "undefined world name '$0'.")
-
-DIAGNOSTIC(33007, Error, importOperatorCircularity, "import operator '$0' creates a circular dependency between world '$1' and '$2'")
-DIAGNOSTIC(33009, Error, parametersOnlyAllowedInModules, "parameters can only be defined in modules.")
-DIAGNOSTIC(33010, Error, undefinedPipelineName, "pipeline '$0' is undefined.")
-DIAGNOSTIC(33011, Error, shaderCircularity, "shader '$0' involves circular reference.")
-DIAGNOSTIC(33012, Error, worldIsNotDefinedInPipeline, "'$0' is not a defined world in '$1'.")
-DIAGNOSTIC(33013, Error, abstractWorldCannotAppearWithOthers, "abstract world cannot appear with other worlds.")
-DIAGNOSTIC(33014, Error, nonAbstractComponentMustHaveImplementation, "non-abstract component must have an implementation.")
-DIAGNOSTIC(33016, Error, usingInComponentDefinition, "'using': importing not allowed in component definition.")
-DIAGNOSTIC(33018, Error, nameAlreadyDefined, "'$0' is already defined.")
-DIAGNOSTIC(33018, Error, shaderAlreadyDefined, "shader '$0' has already been defined.")
-DIAGNOSTIC(33019, Error, componentMarkedExportMustHaveWorld, "component '$0': definition marked as 'export' must have an explicitly specified world.")
-DIAGNOSTIC(33020, Error, componentIsAlreadyDefined, "'$0' is already defined.")
-DIAGNOSTIC(33020, Error, componentIsAlreadyDefinedInThatWorld, "'$0' is already defined at '$1'.")
-DIAGNOSTIC(33021, Error, inconsistentSignatureForComponent, "'$0': inconsistent signature.")
-DIAGNOSTIC(33022, Error, nameAlreadyDefinedInCurrentScope, "'$0' is already defined in current scope.")
-DIAGNOSTIC(33022, Error, parameterNameConflictsWithExistingDefinition, "'$0': parameter name conflicts with existing definition.")
-DIAGNOSTIC(33023, Error, parameterOfModuleIsUnassigned, "parameter '$0' of module '$1' is unassigned.")
-DIAGNOSTIC(33027, Error, argumentTypeDoesNotMatchParameterType, "argument type ($0) does not match parameter type ($1)")
-DIAGNOSTIC(33028, Error, nameIsNotAParameterOfCallee, "'$0' is not a parameter of '$1'.")
-DIAGNOSTIC(33029, Error, requirementsClashWithPreviousDef, "'$0': requirement clash with previous definition.")
-DIAGNOSTIC(33030, Error, positionArgumentAfterNamed, "positional argument cannot appear after a named argument.")
-DIAGNOSTIC(33032, Error, functionRedefinition, "'$0': function redefinition.")
-DIAGNOSTIC(33034, Error, recordTypeVariableInImportOperator, "cannot declare a record-typed variable in an import operator.")
-DIAGNOSTIC(33037, Error, componetMarkedExportCannotHaveParameters, "component '$0': definition marked as 'export' cannot have parameters.")
-DIAGNOSTIC(33039, Error, componentInInputWorldCantHaveCode, "'$0': no code allowed for component defined in input world.")
-DIAGNOSTIC(33040, Error, requireWithComputation, "'require': cannot define computation on component requirements.")
-DIAGNOSTIC(33042, Error, paramWithComputation, "'param': cannot define computation on parameters.")
-DIAGNOSTIC(33041, Error, pipelineOfModuleIncompatibleWithPipelineOfShader, "pipeline '$0' targeted by module '$1' is incompatible with pipeline '$2' targeted by shader '$3'.")
DIAGNOSTIC(33070, Error, expectedFunction, "expression preceding parenthesis of apparent call must have function type.")
-DIAGNOSTIC(33071, Error, importOperatorCalledFromAutoPlacedComponent, "cannot call an import operator from an auto-placed component '$0'. try qualify the component with explicit worlds.")
-DIAGNOSTIC(33072, Error, noApplicableImportOperator, "'$0' is an import operator defined in pipeline '$1', but none of the import operator overloads converting to world '$2' matches argument list ($3).")
-DIAGNOSTIC(33073, Error, importOperatorCalledFromMultiWorldComponent, "cannot call an import operator from a multi-world component definition. consider qualify the component with only one explicit world.")
-DIAGNOSTIC(33080, Error, componentTypeDoesNotMatchInterface, "'$0': component type does not match definition in interface '$1'.")
-DIAGNOSTIC(33081, Error, shaderDidNotDefineComponentFunction, "shader '$0' did not define component function $1 as required by interface '$2'.")
-DIAGNOSTIC(33082, Error, shaderDidNotDefineComponent, "shader '$0' did not define component '$1' as required by interface '$2'.")
-DIAGNOSTIC(33083, Error, interfaceImplMustBePublic, "'$0': component fulfilling interface '$1' must be declared as 'public'.")
-DIAGNOSTIC(33084, Error, defaultParamNotAllowedInInterface, "'$0': default parameter value not allowed in interface definition.")
-
-DIAGNOSTIC(33100, Error, componentCantBeComputedAtWorldBecauseDependentNotAvailable, "'$0' cannot be computed at '$1' because the dependent component '$2' is not accessible.")
-DIAGNOSTIC(33101, Warning, worldIsNotAValidChoiceForKey, "'$0' is not a valid choice for '$1'.")
-DIAGNOSTIC(33102, Error, componentDefinitionCircularity, "component definition '$0' involves circular reference.")
-DIAGNOSTIC(34024, Error, componentAlreadyDefinedWhenCompiling, "component named '$0' is already defined when compiling '$1'.")
-DIAGNOSTIC(34025, Error, globalComponentConflictWithPreviousDeclaration, "'$0': global component conflicts with previous declaration.")
-DIAGNOSTIC(34026, Warning, componentIsAlreadyDefinedUseRequire, "'$0': component is already defined when compiling shader '$1'. use 'require' to declare it as a parameter.")
-DIAGNOSTIC(34062, Error, cylicReference, "cyclic reference: $0");
-DIAGNOSTIC(34064, Error, noApplicableImplicitImportOperator, "cannot find import operator to import component '$0' to world '$1' when compiling '$2'.")
-DIAGNOSTIC(34065, Error, resourceTypeMustBeParamOrRequire, "'$0': resource typed component must be declared as 'param' or 'require'.");
-DIAGNOSTIC(34066, Error, cannotDefineComputationOnResourceType, "'$0': cannot define computation on resource typed component.");
-
-DIAGNOSTIC(35001, Error, fragDepthAttributeCanOnlyApplyToOutput, "FragDepth attribute can only apply to an output component.");
-DIAGNOSTIC(35002, Error, fragDepthAttributeCanOnlyApplyToFloatComponent, "FragDepth attribute can only apply to a float component.");
-
-
-DIAGNOSTIC(36001, Error, insufficientTemplateShaderArguments, "instantiating template shader '$0': insufficient arguments.");
-DIAGNOSTIC(36002, Error, tooManyTemplateShaderArguments, "instantiating template shader '$0': too many arguments.");
-DIAGNOSTIC(36003, Error, templateShaderArgumentIsNotDefined, "'$0' provided as template shader argument to '$1' is not a defined module.");
-DIAGNOSTIC(36004, Error, templateShaderArgumentDidNotImplementRequiredInterface, "module '$0' provided as template shader argument to '$1' did not implement required interface '$2'.");
+
+// 303xx: interfaces and associated types
+DIAGNOSTIC(30300, Error, assocTypeInInterfaceOnly, "'associatedtype' can only be defined in an 'interface'.")
// TODO: need to assign numbers to all these extra diagnostics...
diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp
index 2e8ab58d4..3454a3f85 100644
--- a/source/slang/emit.cpp
+++ b/source/slang/emit.cpp
@@ -1039,7 +1039,6 @@ struct EmitVisitor
UNEXPECTED(PtrType);
#undef UNEXPECTED
-
void visitNamedExpressionType(NamedExpressionType* type, TypeEmitArg const& arg)
{
// Named types are valid for GLSL
@@ -1053,6 +1052,11 @@ struct EmitVisitor
EmitDeclarator(arg.declarator);
}
+ void visitAssocTypeDeclRefType(AssocTypeDeclRefType* /*type*/, TypeEmitArg const& /*arg*/)
+ {
+ SLANG_UNREACHABLE("visitAssocTypeDeclRefType in EmitVisitor");
+ }
+
void visitBasicExpressionType(BasicExpressionType* basicType, TypeEmitArg const& arg)
{
auto declarator = arg.declarator;
@@ -3009,6 +3013,12 @@ struct EmitVisitor
Emit(";\n");
}
+ void visitAssocTypeDecl(AssocTypeDecl * /*assocType*/, DeclEmitArg const&)
+ {
+ SLANG_UNREACHABLE("visitAssocTypeDecl in EmitVisitor");
+ }
+
+
void visitImportDecl(ImportDecl* decl, DeclEmitArg const&)
{
// When in "rewriter" mode, we need to emit the code of the imported
diff --git a/source/slang/lower.cpp b/source/slang/lower.cpp
index 0f16a8ad7..776a55530 100644
--- a/source/slang/lower.cpp
+++ b/source/slang/lower.cpp
@@ -765,6 +765,7 @@ struct LoweringVisitor
loweredDeclRef.As<Decl>());
}
+
RefPtr<Type> visitNamedExpressionType(NamedExpressionType* type)
{
if (shared->target == CodeGenTarget::GLSL)
@@ -778,6 +779,13 @@ struct LoweringVisitor
translateDeclRef(DeclRef<Decl>(type->declRef)).As<TypeDefDecl>());
}
+ RefPtr<Type> visitAssocTypeDeclRefType(AssocTypeDeclRefType* type)
+ {
+ // not supported by lowering
+ SLANG_UNREACHABLE("visitAssocTypeDeclRefType in LowerVisitor");
+ return nullptr;
+ }
+
RefPtr<Type> visitTypeType(TypeType* type)
{
return getTypeType(lowerType(type->type));
@@ -2820,6 +2828,13 @@ struct LoweringVisitor
return LoweredDecl();
}
+ LoweredDecl visitAssocTypeDecl(AssocTypeDecl * /*assocType*/)
+ {
+ // not supported
+ SLANG_UNREACHABLE("visitAssocTypeDecl in LowerVisitor");
+ return LoweredDecl();
+ }
+
LoweredDecl visitTypeDefDecl(TypeDefDecl* decl)
{
if (shared->target == CodeGenTarget::GLSL)
diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp
index 554eebc18..224450a66 100644
--- a/source/slang/parser.cpp
+++ b/source/slang/parser.cpp
@@ -544,21 +544,6 @@ namespace Slang
return typeDefDecl;
}
- RefPtr<RefObject> ParseAssocType(Parser * parser, void *)
- {
- RefPtr<AssocTypeDecl> assocTypeDecl = new AssocTypeDecl();
-
- auto nameToken = parser->ReadToken(TokenType::Identifier);
- assocTypeDecl->nameAndLoc = NameLoc(nameToken);
- assocTypeDecl->loc = nameToken.loc;
- if (parser->LookAheadToken(TokenType::Colon))
- {
- auto type = parser->ParseTypeExp();
- assocTypeDecl->constraint = type;
- }
- return assocTypeDecl;
- }
-
// Add a modifier to a list of modifiers being built
static void AddModifier(RefPtr<Modifier>** ioModifierLink, RefPtr<Modifier> modifier)
{
@@ -2102,7 +2087,7 @@ namespace Slang
return decl;
}
- static void parseOptionalInheritanceClause(Parser* parser, AggTypeDecl* decl)
+ static void parseOptionalInheritanceClause(Parser* parser, ContainerDecl* decl)
{
if( AdvanceIf(parser, TokenType::Colon) )
{
@@ -2121,6 +2106,18 @@ namespace Slang
}
}
+ RefPtr<RefObject> ParseAssocType(Parser * parser, void *)
+ {
+ RefPtr<AssocTypeDecl> assocTypeDecl = new AssocTypeDecl();
+
+ auto nameToken = parser->ReadToken(TokenType::Identifier);
+ assocTypeDecl->nameAndLoc = NameLoc(nameToken);
+ assocTypeDecl->loc = nameToken.loc;
+ parseOptionalInheritanceClause(parser, assocTypeDecl.Ptr());
+ parser->ReadToken(TokenType::Semicolon);
+ return assocTypeDecl;
+ }
+
static RefPtr<RefObject> parseInterfaceDecl(Parser* parser, void* /*userData*/)
{
RefPtr<InterfaceDecl> decl = new InterfaceDecl();
@@ -4062,7 +4059,7 @@ namespace Slang
#define DECL(KEYWORD, CALLBACK) \
addBuiltinSyntax<Decl>(session, scope, #KEYWORD, &CALLBACK)
DECL(typedef, ParseTypeDef);
- DECL(assoctype, ParseAssocType);
+ DECL(associatedtype,ParseAssocType);
DECL(cbuffer, parseHLSLCBufferDecl);
DECL(tbuffer, parseHLSLTBufferDecl);
DECL(__generic, ParseGenericDecl);
diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp
index 165b2d132..97f3cfb15 100644
--- a/source/slang/syntax.cpp
+++ b/source/slang/syntax.cpp
@@ -957,6 +957,20 @@ void Type::accept(IValVisitor* visitor, void* extra)
return false;
}
+ RefPtr<Val> AssocTypeDeclRefType::SubstituteImpl(Substitutions* subst, int* ioDiff)
+ {
+ auto parentType = this->GetDeclRef().GetParent().SubstituteImpl(subst, ioDiff);
+ if (auto aggDeclRef = parentType.As<AggTypeDecl>())
+ {
+ Decl* targetTypeDecl = nullptr;
+ if (aggDeclRef.getDecl()->memberDictionary.TryGetValue(this->GetDeclRef().decl->getName(), targetTypeDecl))
+ {
+ return DeclRefType::Create(this->session, DeclRef<Decl>(targetTypeDecl, parentType.substitutions));
+ }
+ }
+ return this;
+ }
+
int AssocTypeDeclRefType::GetHashCode()
{
return declRef.GetHashCode();
diff --git a/source/slang/type-defs.h b/source/slang/type-defs.h
index 7648e0b87..15dfe3e07 100644
--- a/source/slang/type-defs.h
+++ b/source/slang/type-defs.h
@@ -495,6 +495,7 @@ END_SYNTAX_CLASS()
// The "type" of an expression that references a asscoiated type decl (via 'assoctype' keyword).
SYNTAX_CLASS(AssocTypeDeclRefType, Type)
DECL_FIELD(DeclRef<AssocTypeDecl>, declRef)
+
RAW(
AssocTypeDeclRefType()
{}
@@ -509,6 +510,7 @@ SYNTAX_CLASS(AssocTypeDeclRefType, Type)
virtual String ToString() override;
protected:
+ virtual RefPtr<Val> SubstituteImpl(Substitutions* subst, int* ioDiff) override;
virtual bool EqualsImpl(Type * type) override;
virtual int GetHashCode() override;
virtual Type* CreateCanonicalType() override;
diff --git a/tests/compute/assoctype-simple.slang b/tests/compute/assoctype-simple.slang
index 5a2c339a6..e03bb4e54 100644
--- a/tests/compute/assoctype-simple.slang
+++ b/tests/compute/assoctype-simple.slang
@@ -7,21 +7,21 @@ RWStructuredBuffer<float> outputBuffer;
interface ISimple
{
- assoctype T;
- T add(T v0, T v1);
+ associatedtype U;
+ U add(U v0, U v1);
}
struct Simple : ISimple
{
- typedef float T;
- T add(T v0, float v1)
+ typedef float U;
+ U add(U v0, float v1)
{
return v0 + v1;
}
};
__generic<T:ISimple>
-T.T test(T simple, T.T v0, T.T v1)
+T.U test(T simple, T.U v0, T.U v1)
{
return simple.add(v0, v1);
}
@@ -29,6 +29,7 @@ T.T test(T simple, T.T v0, T.T v1)
[numthreads(4, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
{
- float outVal = test<Simple>(Simple(), 2.0, 1.0); // == 3.0
+ Simple s;
+ float outVal = test<Simple>(s, 2.0, 1.0); // == 3.0
outputBuffer[tid] = outVal;
} \ No newline at end of file