From f65d756bff8d4c5cbc15bd0322a2ae8e6b896a21 Mon Sep 17 00:00:00 2001 From: Ellie Hermaszewska Date: Tue, 29 Oct 2024 14:49:26 +0800 Subject: format * format * Minor test fixes * enable checking cpp format in ci --- source/slang/slang-api.cpp | 418 +- source/slang/slang-artifact-output-util.cpp | 104 +- source/slang/slang-artifact-output-util.h | 76 +- source/slang/slang-ast-all.h | 3 +- source/slang/slang-ast-base.cpp | 63 +- source/slang/slang-ast-base.h | 306 +- source/slang/slang-ast-builder.cpp | 202 +- source/slang/slang-ast-builder.h | 309 +- source/slang/slang-ast-decl-ref.cpp | 82 +- source/slang/slang-ast-decl.cpp | 20 +- source/slang/slang-ast-decl.h | 148 +- source/slang/slang-ast-dump.cpp | 234 +- source/slang/slang-ast-dump.h | 5 +- source/slang/slang-ast-expr.h | 219 +- source/slang/slang-ast-iterator.h | 64 +- source/slang/slang-ast-modifier.cpp | 12 +- source/slang/slang-ast-modifier.h | 688 +- source/slang/slang-ast-natural-layout.cpp | 45 +- source/slang/slang-ast-natural-layout.h | 91 +- source/slang/slang-ast-print.cpp | 92 +- source/slang/slang-ast-print.h | 128 +- source/slang/slang-ast-reflect.cpp | 60 +- source/slang/slang-ast-reflect.h | 61 +- source/slang/slang-ast-stmt.h | 65 +- source/slang/slang-ast-support-types.cpp | 11 +- source/slang/slang-ast-support-types.h | 2775 +-- source/slang/slang-ast-synthesis.cpp | 16 +- source/slang/slang-ast-synthesis.h | 13 +- source/slang/slang-ast-type.cpp | 336 +- source/slang/slang-ast-type.h | 216 +- source/slang/slang-ast-val.cpp | 482 +- source/slang/slang-ast-val.h | 235 +- source/slang/slang-capability.cpp | 491 +- source/slang/slang-capability.h | 121 +- source/slang/slang-check-conformance.cpp | 618 +- source/slang/slang-check-constraint.cpp | 2050 +- source/slang/slang-check-conversion.cpp | 2433 ++- source/slang/slang-check-decl.cpp | 20443 ++++++++++--------- source/slang/slang-check-expr.cpp | 8539 ++++---- source/slang/slang-check-impl.h | 5191 +++-- source/slang/slang-check-inheritance.cpp | 1961 +- source/slang/slang-check-modifier.cpp | 2986 +-- source/slang/slang-check-overload.cpp | 4674 ++--- source/slang/slang-check-resolve-val.cpp | 10 +- source/slang/slang-check-shader.cpp | 2997 +-- source/slang/slang-check-stmt.cpp | 1196 +- source/slang/slang-check-type.cpp | 716 +- source/slang/slang-check.cpp | 343 +- source/slang/slang-check.h | 31 +- source/slang/slang-compiler-options.cpp | 583 +- source/slang/slang-compiler-options.h | 638 +- source/slang/slang-compiler-tu.cpp | 474 +- source/slang/slang-compiler.cpp | 4229 ++-- source/slang/slang-compiler.h | 6027 +++--- source/slang/slang-container-pool.h | 11 +- source/slang/slang-content-assist-info.h | 6 +- source/slang/slang-core-module-textures.cpp | 238 +- source/slang/slang-core-module-textures.h | 41 +- source/slang/slang-core-module.cpp | 25 +- source/slang/slang-diagnostic-defs.h | 2596 ++- source/slang/slang-diagnostics.cpp | 44 +- source/slang/slang-diagnostics.h | 49 +- source/slang/slang-doc-ast.cpp | 34 +- source/slang/slang-doc-ast.h | 43 +- source/slang/slang-doc-markdown-writer.cpp | 512 +- source/slang/slang-doc-markdown-writer.h | 67 +- source/slang/slang-emit-base.cpp | 3 +- source/slang/slang-emit-base.h | 5 +- source/slang/slang-emit-c-like.cpp | 1663 +- source/slang/slang-emit-c-like.h | 394 +- source/slang/slang-emit-cpp.cpp | 761 +- source/slang/slang-emit-cpp.h | 75 +- source/slang/slang-emit-cuda.cpp | 591 +- source/slang/slang-emit-cuda.h | 69 +- source/slang/slang-emit-glsl.cpp | 1749 +- source/slang/slang-emit-glsl.h | 124 +- source/slang/slang-emit-hlsl.cpp | 1036 +- source/slang/slang-emit-hlsl.h | 80 +- source/slang/slang-emit-metal.cpp | 868 +- source/slang/slang-emit-metal.h | 44 +- source/slang/slang-emit-precedence.cpp | 70 +- source/slang/slang-emit-precedence.h | 27 +- source/slang/slang-emit-source-writer.cpp | 124 +- source/slang/slang-emit-source-writer.h | 58 +- source/slang/slang-emit-spirv-ops-debug-info-ext.h | 500 +- source/slang/slang-emit-spirv-ops.h | 840 +- source/slang/slang-emit-spirv.cpp | 3505 ++-- source/slang/slang-emit-torch.cpp | 204 +- source/slang/slang-emit-torch.h | 11 +- source/slang/slang-emit-wgsl.cpp | 1112 +- source/slang/slang-emit-wgsl.h | 31 +- source/slang/slang-emit.cpp | 687 +- source/slang/slang-glsl-extension-tracker.cpp | 39 +- source/slang/slang-glsl-extension-tracker.h | 22 +- .../slang/slang-hlsl-to-vulkan-layout-options.cpp | 61 +- source/slang/slang-hlsl-to-vulkan-layout-options.h | 150 +- source/slang/slang-intrinsic-expand.cpp | 203 +- source/slang/slang-intrinsic-expand.h | 20 +- source/slang/slang-ir-addr-inst-elimination.cpp | 55 +- source/slang/slang-ir-addr-inst-elimination.h | 6 +- source/slang/slang-ir-address-analysis.cpp | 251 +- source/slang/slang-ir-address-analysis.h | 30 +- source/slang/slang-ir-any-value-inference.cpp | 359 +- source/slang/slang-ir-any-value-inference.h | 8 +- source/slang/slang-ir-any-value-marshalling.cpp | 938 +- source/slang/slang-ir-any-value-marshalling.h | 21 +- source/slang/slang-ir-augment-make-existential.cpp | 15 +- source/slang/slang-ir-autodiff-cfg-norm.cpp | 97 +- source/slang/slang-ir-autodiff-cfg-norm.h | 36 +- source/slang/slang-ir-autodiff-fwd.cpp | 750 +- source/slang/slang-ir-autodiff-fwd.h | 23 +- source/slang/slang-ir-autodiff-pairs.cpp | 74 +- source/slang/slang-ir-autodiff-pairs.h | 17 +- source/slang/slang-ir-autodiff-primal-hoist.cpp | 504 +- source/slang/slang-ir-autodiff-primal-hoist.h | 545 +- source/slang/slang-ir-autodiff-propagate.h | 23 +- source/slang/slang-ir-autodiff-region.cpp | 87 +- source/slang/slang-ir-autodiff-region.h | 39 +- source/slang/slang-ir-autodiff-rev.cpp | 2386 +-- source/slang/slang-ir-autodiff-rev.h | 129 +- .../slang/slang-ir-autodiff-transcriber-base.cpp | 489 +- source/slang/slang-ir-autodiff-transcriber-base.h | 62 +- source/slang/slang-ir-autodiff-transpose.h | 1616 +- source/slang/slang-ir-autodiff-unzip.cpp | 114 +- source/slang/slang-ir-autodiff-unzip.h | 249 +- source/slang/slang-ir-autodiff.cpp | 1102 +- source/slang/slang-ir-autodiff.h | 236 +- source/slang/slang-ir-bind-existentials.cpp | 82 +- source/slang/slang-ir-bind-existentials.h | 8 +- source/slang/slang-ir-bit-field-accessors.cpp | 44 +- source/slang/slang-ir-bit-field-accessors.h | 6 +- source/slang/slang-ir-byte-address-legalize.cpp | 615 +- source/slang/slang-ir-byte-address-legalize.h | 30 +- source/slang/slang-ir-call-graph.cpp | 53 +- source/slang/slang-ir-call-graph.h | 12 +- source/slang/slang-ir-check-differentiability.cpp | 221 +- source/slang/slang-ir-check-recursive-type.cpp | 82 +- source/slang/slang-ir-check-recursive-type.h | 10 +- .../slang/slang-ir-check-shader-parameter-type.cpp | 93 +- .../slang/slang-ir-check-shader-parameter-type.h | 16 +- source/slang/slang-ir-check-unsupported-inst.cpp | 132 +- source/slang/slang-ir-check-unsupported-inst.h | 10 +- source/slang/slang-ir-cleanup-void.cpp | 275 +- source/slang/slang-ir-cleanup-void.h | 11 +- source/slang/slang-ir-clone.cpp | 101 +- source/slang/slang-ir-clone.h | 191 +- source/slang/slang-ir-collect-global-uniforms.cpp | 42 +- source/slang/slang-ir-collect-global-uniforms.h | 12 +- source/slang/slang-ir-com-interface.cpp | 40 +- source/slang/slang-ir-com-interface.h | 6 +- source/slang/slang-ir-composite-reg-to-mem.cpp | 307 +- source/slang/slang-ir-composite-reg-to-mem.h | 14 +- source/slang/slang-ir-constexpr.cpp | 199 +- source/slang/slang-ir-constexpr.h | 10 +- source/slang/slang-ir-dce.cpp | 78 +- source/slang/slang-ir-dce.h | 62 +- .../slang-ir-deduplicate-generic-children.cpp | 5 +- .../slang/slang-ir-deduplicate-generic-children.h | 12 +- source/slang/slang-ir-deduplicate.cpp | 110 +- source/slang/slang-ir-defunctionalization.cpp | 4 +- source/slang/slang-ir-defunctionalization.h | 24 +- source/slang/slang-ir-diff-call.cpp | 35 +- source/slang/slang-ir-diff-call.h | 18 +- source/slang/slang-ir-dll-export.cpp | 17 +- source/slang/slang-ir-dll-export.h | 10 +- source/slang/slang-ir-dll-import.cpp | 48 +- source/slang/slang-ir-dll-import.h | 12 +- source/slang/slang-ir-dominators.cpp | 162 +- source/slang/slang-ir-dominators.h | 356 +- ...r-early-raytracing-intrinsic-simplification.cpp | 208 +- ...-ir-early-raytracing-intrinsic-simplification.h | 17 +- .../slang/slang-ir-eliminate-multilevel-break.cpp | 93 +- source/slang/slang-ir-eliminate-multilevel-break.h | 10 +- source/slang/slang-ir-eliminate-phis.cpp | 92 +- source/slang/slang-ir-eliminate-phis.h | 44 +- source/slang/slang-ir-entry-point-pass.cpp | 6 +- source/slang/slang-ir-entry-point-pass.h | 4 +- .../slang/slang-ir-entry-point-raw-ptr-params.cpp | 23 +- source/slang/slang-ir-entry-point-raw-ptr-params.h | 7 +- source/slang/slang-ir-entry-point-uniforms.cpp | 74 +- source/slang/slang-ir-entry-point-uniforms.h | 13 +- source/slang/slang-ir-explicit-global-context.cpp | 140 +- source/slang/slang-ir-explicit-global-context.h | 8 +- source/slang/slang-ir-explicit-global-init.cpp | 34 +- source/slang/slang-ir-explicit-global-init.h | 7 +- source/slang/slang-ir-extract-value-from-type.cpp | 93 +- source/slang/slang-ir-extract-value-from-type.h | 8 +- source/slang/slang-ir-fuse-satcoop.cpp | 177 +- source/slang/slang-ir-fuse-satcoop.h | 12 +- .../slang/slang-ir-generics-lowering-context.cpp | 394 +- source/slang/slang-ir-generics-lowering-context.h | 250 +- source/slang/slang-ir-glsl-legalize.cpp | 1861 +- source/slang/slang-ir-glsl-legalize.h | 14 +- source/slang/slang-ir-glsl-liveness.cpp | 73 +- source/slang/slang-ir-glsl-liveness.h | 14 +- source/slang/slang-ir-hlsl-legalize.cpp | 122 +- source/slang/slang-ir-hlsl-legalize.h | 4 +- source/slang/slang-ir-init-local-var.cpp | 8 +- source/slang/slang-ir-init-local-var.h | 12 +- source/slang/slang-ir-inline.cpp | 276 +- source/slang/slang-ir-inline.h | 55 +- source/slang/slang-ir-insert-debug-value-store.cpp | 333 +- source/slang/slang-ir-insert-debug-value-store.h | 6 +- source/slang/slang-ir-inst-pass-base.h | 217 +- source/slang/slang-ir-insts.h | 2292 +-- source/slang/slang-ir-layout.cpp | 466 +- source/slang/slang-ir-layout.h | 88 +- .../slang/slang-ir-legalize-array-return-type.cpp | 8 +- source/slang/slang-ir-legalize-array-return-type.h | 10 +- source/slang/slang-ir-legalize-image-subscript.cpp | 326 +- source/slang/slang-ir-legalize-image-subscript.h | 8 +- source/slang/slang-ir-legalize-mesh-outputs.cpp | 25 +- source/slang/slang-ir-legalize-mesh-outputs.h | 8 +- source/slang/slang-ir-legalize-types.cpp | 1822 +- .../slang-ir-legalize-uniform-buffer-load.cpp | 9 +- .../slang/slang-ir-legalize-uniform-buffer-load.h | 8 +- source/slang/slang-ir-legalize-varying-params.cpp | 477 +- source/slang/slang-ir-legalize-varying-params.h | 76 +- source/slang/slang-ir-legalize-vector-types.cpp | 316 +- source/slang/slang-ir-legalize-vector-types.h | 14 +- source/slang/slang-ir-link.cpp | 809 +- source/slang/slang-ir-link.h | 52 +- source/slang/slang-ir-liveness.cpp | 611 +- source/slang/slang-ir-liveness.h | 179 +- source/slang/slang-ir-loop-inversion.cpp | 170 +- source/slang/slang-ir-loop-inversion.h | 9 +- source/slang/slang-ir-loop-unroll.cpp | 82 +- source/slang/slang-ir-loop-unroll.h | 38 +- ...g-ir-lower-append-consume-structured-buffer.cpp | 553 +- ...ang-ir-lower-append-consume-structured-buffer.h | 19 +- source/slang/slang-ir-lower-binding-query.cpp | 912 +- source/slang/slang-ir-lower-binding-query.h | 24 +- source/slang/slang-ir-lower-bit-cast.cpp | 70 +- source/slang/slang-ir-lower-bit-cast.h | 6 +- .../slang/slang-ir-lower-buffer-element-type.cpp | 1821 +- source/slang/slang-ir-lower-buffer-element-type.h | 33 +- source/slang/slang-ir-lower-com-methods.cpp | 36 +- source/slang/slang-ir-lower-com-methods.h | 5 +- .../slang-ir-lower-combined-texture-sampler.cpp | 377 +- .../slang-ir-lower-combined-texture-sampler.h | 19 +- source/slang/slang-ir-lower-cuda-builtin-types.cpp | 796 +- source/slang/slang-ir-lower-cuda-builtin-types.h | 82 +- source/slang/slang-ir-lower-error-handling.cpp | 33 +- source/slang/slang-ir-lower-error-handling.h | 12 +- source/slang/slang-ir-lower-existential.cpp | 474 +- source/slang/slang-ir-lower-existential.h | 10 +- source/slang/slang-ir-lower-expand-type.cpp | 261 +- source/slang/slang-ir-lower-expand-type.h | 50 +- source/slang/slang-ir-lower-generic-call.cpp | 653 +- source/slang/slang-ir-lower-generic-call.h | 11 +- source/slang/slang-ir-lower-generic-function.cpp | 678 +- source/slang/slang-ir-lower-generic-function.h | 25 +- source/slang/slang-ir-lower-generic-type.cpp | 116 +- source/slang/slang-ir-lower-generic-type.h | 11 +- source/slang/slang-ir-lower-generics.cpp | 487 +- source/slang/slang-ir-lower-generics.h | 23 +- source/slang/slang-ir-lower-glsl-ssbo-types.cpp | 310 +- source/slang/slang-ir-lower-glsl-ssbo-types.h | 16 +- source/slang/slang-ir-lower-l-value-cast.cpp | 63 +- source/slang/slang-ir-lower-l-value-cast.h | 20 +- source/slang/slang-ir-lower-optional-type.cpp | 475 +- source/slang/slang-ir-lower-optional-type.h | 12 +- source/slang/slang-ir-lower-reinterpret.cpp | 48 +- source/slang/slang-ir-lower-reinterpret.h | 6 +- source/slang/slang-ir-lower-result-type.cpp | 492 +- source/slang/slang-ir-lower-result-type.h | 12 +- source/slang/slang-ir-lower-tuple-types.cpp | 661 +- source/slang/slang-ir-lower-tuple-types.h | 12 +- source/slang/slang-ir-lower-witness-lookup.cpp | 79 +- source/slang/slang-ir-lower-witness-lookup.h | 17 +- source/slang/slang-ir-marshal-native-call.cpp | 506 +- source/slang/slang-ir-marshal-native-call.h | 115 +- source/slang/slang-ir-metadata.cpp | 34 +- source/slang/slang-ir-metadata.h | 2 +- source/slang/slang-ir-metal-legalize.cpp | 3136 +-- source/slang/slang-ir-metal-legalize.h | 6 +- source/slang/slang-ir-missing-return.cpp | 23 +- source/slang/slang-ir-missing-return.h | 10 +- source/slang/slang-ir-obfuscate-loc.cpp | 88 +- source/slang/slang-ir-obfuscate-loc.h | 9 +- source/slang/slang-ir-operator-shift-overflow.cpp | 64 +- source/slang/slang-ir-operator-shift-overflow.h | 14 +- .../slang/slang-ir-optix-entry-point-uniforms.cpp | 72 +- source/slang/slang-ir-optix-entry-point-uniforms.h | 2 +- source/slang/slang-ir-peephole.cpp | 257 +- source/slang/slang-ir-peephole.h | 40 +- .../slang/slang-ir-propagate-func-properties.cpp | 35 +- source/slang/slang-ir-propagate-func-properties.h | 2 +- source/slang/slang-ir-pytorch-cpp-binding.cpp | 601 +- source/slang/slang-ir-pytorch-cpp-binding.h | 3 +- source/slang/slang-ir-reachability.cpp | 127 +- source/slang/slang-ir-reachability.h | 5 +- source/slang/slang-ir-redundancy-removal.cpp | 34 +- source/slang/slang-ir-redundancy-removal.h | 14 +- .../slang/slang-ir-remove-unused-generic-param.cpp | 19 +- .../slang/slang-ir-remove-unused-generic-param.h | 6 +- source/slang/slang-ir-resolve-texture-format.cpp | 158 +- source/slang/slang-ir-resolve-texture-format.h | 2 +- source/slang/slang-ir-restructure-scoping.cpp | 119 +- source/slang/slang-ir-restructure-scoping.h | 2 +- source/slang/slang-ir-restructure.cpp | 1111 +- source/slang/slang-ir-restructure.h | 429 +- source/slang/slang-ir-sccp.cpp | 433 +- source/slang/slang-ir-sccp.h | 39 +- source/slang/slang-ir-simplify-cfg.cpp | 127 +- source/slang/slang-ir-simplify-cfg.h | 30 +- source/slang/slang-ir-simplify-for-emit.cpp | 80 +- source/slang/slang-ir-simplify-for-emit.h | 8 +- source/slang/slang-ir-single-return.cpp | 17 +- source/slang/slang-ir-single-return.h | 14 +- source/slang/slang-ir-specialize-address-space.cpp | 561 +- source/slang/slang-ir-specialize-address-space.h | 42 +- source/slang/slang-ir-specialize-arrays.cpp | 8 +- source/slang/slang-ir-specialize-arrays.h | 32 +- .../slang/slang-ir-specialize-buffer-load-arg.cpp | 10 +- source/slang/slang-ir-specialize-buffer-load-arg.h | 64 +- source/slang/slang-ir-specialize-dispatch.cpp | 44 +- source/slang/slang-ir-specialize-dispatch.h | 2 +- ...ir-specialize-dynamic-associatedtype-lookup.cpp | 112 +- source/slang/slang-ir-specialize-function-call.cpp | 167 +- source/slang/slang-ir-specialize-function-call.h | 52 +- source/slang/slang-ir-specialize-matrix-layout.cpp | 58 +- source/slang/slang-ir-specialize-matrix-layout.h | 14 +- source/slang/slang-ir-specialize-resources.cpp | 316 +- source/slang/slang-ir-specialize-resources.h | 46 +- source/slang/slang-ir-specialize-target-switch.cpp | 149 +- source/slang/slang-ir-specialize-target-switch.h | 14 +- source/slang/slang-ir-specialize.cpp | 393 +- source/slang/slang-ir-specialize.h | 9 +- source/slang/slang-ir-spirv-legalize.cpp | 627 +- source/slang/slang-ir-spirv-legalize.h | 37 +- source/slang/slang-ir-spirv-snippet.cpp | 115 +- source/slang/slang-ir-spirv-snippet.h | 22 +- source/slang/slang-ir-ssa-register-allocate.cpp | 46 +- source/slang/slang-ir-ssa-register-allocate.h | 7 +- source/slang/slang-ir-ssa-simplification.cpp | 261 +- source/slang/slang-ir-ssa-simplification.h | 67 +- source/slang/slang-ir-ssa.cpp | 226 +- source/slang/slang-ir-ssa.h | 14 +- source/slang/slang-ir-string-hash.cpp | 16 +- source/slang/slang-ir-string-hash.h | 9 +- source/slang/slang-ir-strip-cached-dict.cpp | 10 +- source/slang/slang-ir-strip-cached-dict.h | 10 +- source/slang/slang-ir-strip-witness-tables.cpp | 8 +- source/slang/slang-ir-strip-witness-tables.h | 6 +- source/slang/slang-ir-strip.cpp | 38 +- source/slang/slang-ir-strip.h | 20 +- source/slang/slang-ir-synthesize-active-mask.cpp | 179 +- source/slang/slang-ir-synthesize-active-mask.h | 28 +- .../slang/slang-ir-translate-glsl-global-var.cpp | 528 +- source/slang/slang-ir-translate-glsl-global-var.h | 12 +- source/slang/slang-ir-uniformity.cpp | 553 +- source/slang/slang-ir-uniformity.h | 8 +- source/slang/slang-ir-use-uninitialized-values.cpp | 1098 +- source/slang/slang-ir-use-uninitialized-values.h | 10 +- source/slang/slang-ir-user-type-hint.cpp | 9 +- source/slang/slang-ir-user-type-hint.h | 6 +- source/slang/slang-ir-util.cpp | 653 +- source/slang/slang-ir-util.h | 126 +- source/slang/slang-ir-validate.cpp | 652 +- source/slang/slang-ir-validate.h | 66 +- .../slang/slang-ir-variable-scope-correction.cpp | 120 +- source/slang/slang-ir-variable-scope-correction.h | 2 +- source/slang/slang-ir-vk-invert-y.cpp | 35 +- source/slang/slang-ir-vk-invert-y.h | 8 +- source/slang/slang-ir-wgsl-legalize.cpp | 643 +- source/slang/slang-ir-wgsl-legalize.h | 6 +- source/slang/slang-ir-witness-table-wrapper.cpp | 404 +- source/slang/slang-ir-witness-table-wrapper.h | 31 +- source/slang/slang-ir-wrap-structured-buffers.cpp | 258 +- source/slang/slang-ir-wrap-structured-buffers.h | 16 +- source/slang/slang-ir.cpp | 14841 +++++++------- source/slang/slang-ir.h | 1057 +- source/slang/slang-language-server-ast-lookup.cpp | 106 +- source/slang/slang-language-server-ast-lookup.h | 2 +- source/slang/slang-language-server-auto-format.cpp | 63 +- source/slang/slang-language-server-auto-format.h | 10 +- source/slang/slang-language-server-completion.cpp | 169 +- source/slang/slang-language-server-completion.h | 22 +- .../slang-language-server-document-symbols.cpp | 369 +- .../slang/slang-language-server-document-symbols.h | 9 +- source/slang/slang-language-server-inlay-hints.cpp | 34 +- source/slang/slang-language-server-inlay-hints.h | 11 +- .../slang-language-server-semantic-tokens.cpp | 104 +- .../slang/slang-language-server-semantic-tokens.h | 21 +- source/slang/slang-language-server.cpp | 528 +- source/slang/slang-language-server.h | 86 +- source/slang/slang-legalize-types.cpp | 377 +- source/slang/slang-legalize-types.h | 442 +- source/slang/slang-lookup-spirv.h | 3 +- source/slang/slang-lookup.cpp | 494 +- source/slang/slang-lookup.h | 82 +- source/slang/slang-lower-to-ir.cpp | 3621 ++-- source/slang/slang-lower-to-ir.h | 75 +- source/slang/slang-mangle.cpp | 1382 +- source/slang/slang-mangle.h | 34 +- source/slang/slang-mangled-lexer.cpp | 36 +- source/slang/slang-mangled-lexer.h | 18 +- source/slang/slang-module-library.cpp | 44 +- source/slang/slang-module-library.h | 33 +- source/slang/slang-options.cpp | 1887 +- source/slang/slang-options.h | 9 +- source/slang/slang-parameter-binding.cpp | 1582 +- source/slang/slang-parameter-binding.h | 15 +- source/slang/slang-parser.cpp | 13777 +++++++------ source/slang/slang-parser.h | 77 +- source/slang/slang-preprocessor.cpp | 1669 +- source/slang/slang-preprocessor.h | 72 +- source/slang/slang-profile-defs.h | 364 +- source/slang/slang-profile.cpp | 22 +- source/slang/slang-profile.h | 193 +- source/slang/slang-ref-object-reflect.cpp | 79 +- source/slang/slang-ref-object-reflect.h | 46 +- source/slang/slang-reflection-api.cpp | 3124 +-- source/slang/slang-repro.cpp | 487 +- source/slang/slang-repro.h | 144 +- source/slang/slang-serialize-ast-type-info.h | 140 +- source/slang/slang-serialize-ast.cpp | 94 +- source/slang/slang-serialize-ast.h | 30 +- source/slang/slang-serialize-container.cpp | 392 +- source/slang/slang-serialize-container.h | 103 +- source/slang/slang-serialize-factory.cpp | 22 +- source/slang/slang-serialize-factory.h | 14 +- source/slang/slang-serialize-ir-types.cpp | 85 +- source/slang/slang-serialize-ir-types.h | 184 +- source/slang/slang-serialize-ir.cpp | 404 +- source/slang/slang-serialize-ir.h | 123 +- source/slang/slang-serialize-misc-type-info.h | 106 +- source/slang/slang-serialize-reflection.cpp | 30 +- source/slang/slang-serialize-reflection.h | 36 +- source/slang/slang-serialize-source-loc.cpp | 163 +- source/slang/slang-serialize-source-loc.h | 144 +- source/slang/slang-serialize-type-info.h | 224 +- source/slang/slang-serialize-types.cpp | 124 +- source/slang/slang-serialize-types.h | 163 +- source/slang/slang-serialize-value-type-info.h | 106 +- source/slang/slang-serialize.cpp | 169 +- source/slang/slang-serialize.h | 374 +- source/slang/slang-spirv-val.cpp | 11 +- source/slang/slang-spirv-val.h | 4 +- source/slang/slang-syntax.cpp | 1089 +- source/slang/slang-syntax.h | 592 +- source/slang/slang-type-layout.cpp | 1927 +- source/slang/slang-type-layout.h | 764 +- source/slang/slang-type-system-shared.h | 184 +- source/slang/slang-value-reflect.cpp | 6 +- source/slang/slang-value-reflect.h | 3 +- source/slang/slang-visitor.h | 138 +- source/slang/slang-workspace-version.cpp | 69 +- source/slang/slang-workspace-version.h | 364 +- source/slang/slang.cpp | 1944 +- 451 files changed, 106613 insertions(+), 97598 deletions(-) (limited to 'source/slang') diff --git a/source/slang/slang-api.cpp b/source/slang/slang-api.cpp index ce1103610..18d2a5083 100644 --- a/source/slang/slang-api.cpp +++ b/source/slang/slang-api.cpp @@ -1,13 +1,13 @@ // slang-api.cpp -#include "slang-compiler.h" -#include "slang-repro.h" -#include "slang-capability.h" -#include "../core/slang-rtti-info.h" #include "../core/slang-performance-profiler.h" +#include "../core/slang-rtti-info.h" #include "../core/slang-shared-library.h" #include "../slang-record-replay/record/slang-global-session.h" #include "../slang-record-replay/util/record-utility.h" +#include "slang-capability.h" +#include "slang-compiler.h" +#include "slang-repro.h" // implementation of C interface @@ -22,9 +22,9 @@ SLANG_API SlangSession* spCreateSession(const char*) return globalSession.detach(); } -// Attempt to load a previously compiled core module from the same file system location as the slang dll. -// Returns SLANG_OK when the cache is sucessfully loaded. -// Also returns the filename to the core module cache and the timestamp of current slang dll. +// Attempt to load a previously compiled core module from the same file system location as the slang +// dll. Returns SLANG_OK when the cache is sucessfully loaded. Also returns the filename to the core +// module cache and the timestamp of current slang dll. SlangResult tryLoadCoreModuleFromCache( slang::IGlobalSession* globalSession, Slang::String& outCachePath, @@ -65,37 +65,43 @@ SlangResult trySaveCoreModuleToCache( if (dllTimestamp != 0 && cacheFilename.getLength() != 0) { Slang::ComPtr coreModuleBlobPtr; - SLANG_RETURN_ON_FAIL( - globalSession->saveCoreModule(SLANG_ARCHIVE_TYPE_RIFF_LZ4, coreModuleBlobPtr.writeRef())); + SLANG_RETURN_ON_FAIL(globalSession->saveCoreModule( + SLANG_ARCHIVE_TYPE_RIFF_LZ4, + coreModuleBlobPtr.writeRef())); Slang::FileStream fileStream; SLANG_RETURN_ON_FAIL(fileStream.init(cacheFilename, Slang::FileMode::Create)); SLANG_RETURN_ON_FAIL(fileStream.write(&dllTimestamp, sizeof(dllTimestamp))); - SLANG_RETURN_ON_FAIL(fileStream.write(coreModuleBlobPtr->getBufferPointer(), coreModuleBlobPtr->getBufferSize())) + SLANG_RETURN_ON_FAIL(fileStream.write( + coreModuleBlobPtr->getBufferPointer(), + coreModuleBlobPtr->getBufferSize())) } return SLANG_OK; } -SLANG_API SlangResult slang_createGlobalSession( - SlangInt apiVersion, - slang::IGlobalSession** outGlobalSession) +SLANG_API SlangResult +slang_createGlobalSession(SlangInt apiVersion, slang::IGlobalSession** outGlobalSession) { Slang::ComPtr globalSession; #ifdef SLANG_ENABLE_IR_BREAK_ALLOC - // Set inst debug alloc counter to 0 so IRInsts for core module always starts from a large value. + // Set inst debug alloc counter to 0 so IRInsts for core module always starts from a large + // value. Slang::_debugGetIRAllocCounter() = 0x80000000; #endif - SLANG_RETURN_ON_FAIL(slang_createGlobalSessionWithoutCoreModule(apiVersion, globalSession.writeRef())); + SLANG_RETURN_ON_FAIL( + slang_createGlobalSessionWithoutCoreModule(apiVersion, globalSession.writeRef())); // If we have the embedded core module, load from that, else compile it ISlangBlob* coreModuleBlob = slang_getEmbeddedCoreModule(); if (coreModuleBlob) { - SLANG_RETURN_ON_FAIL(globalSession->loadCoreModule(coreModuleBlob->getBufferPointer(), coreModuleBlob->getBufferSize())); + SLANG_RETURN_ON_FAIL(globalSession->loadCoreModule( + coreModuleBlob->getBufferPointer(), + coreModuleBlob->getBufferSize())); } else { @@ -149,7 +155,7 @@ SLANG_API void slang_shutdown() } SLANG_API SlangResult slang_createGlobalSessionWithoutCoreModule( - SlangInt apiVersion, + SlangInt apiVersion, slang::IGlobalSession** outGlobalSession) { if (apiVersion != 0) @@ -167,10 +173,10 @@ SLANG_API SlangResult slang_createGlobalSessionWithoutCoreModule( return SLANG_OK; } -SLANG_API void spDestroySession( - SlangSession* inSession) +SLANG_API void spDestroySession(SlangSession* inSession) { - if (!inSession) return; + if (!inSession) + return; Slang::Session* session = Slang::asInternal(inSession); // It is assumed there is only a single reference on the session (the one placed @@ -186,42 +192,38 @@ SLANG_API const char* spGetBuildTagString() } SLANG_API void spAddBuiltins( - SlangSession* session, - char const* sourcePath, - char const* sourceString) + SlangSession* session, + char const* sourcePath, + char const* sourceString) { session->addBuiltins(sourcePath, sourceString); } SLANG_API void spSessionSetSharedLibraryLoader( - SlangSession* session, + SlangSession* session, ISlangSharedLibraryLoader* loader) { session->setSharedLibraryLoader(loader); } -SLANG_API ISlangSharedLibraryLoader* spSessionGetSharedLibraryLoader( - SlangSession* session) +SLANG_API ISlangSharedLibraryLoader* spSessionGetSharedLibraryLoader(SlangSession* session) { return session->getSharedLibraryLoader(); } -SLANG_API SlangResult spSessionCheckCompileTargetSupport( - SlangSession* session, - SlangCompileTarget target) +SLANG_API SlangResult +spSessionCheckCompileTargetSupport(SlangSession* session, SlangCompileTarget target) { return session->checkCompileTargetSupport(target); } -SLANG_API SlangResult spSessionCheckPassThroughSupport( - SlangSession* session, - SlangPassThrough passThrough) +SLANG_API SlangResult +spSessionCheckPassThroughSupport(SlangSession* session, SlangPassThrough passThrough) { return session->checkPassThroughSupport(passThrough); } -SLANG_API SlangCompileRequest* spCreateCompileRequest( - SlangSession* session) +SLANG_API SlangCompileRequest* spCreateCompileRequest(SlangSession* session) { slang::ICompileRequest* request = nullptr; // Will return with suitable ref count @@ -231,16 +233,12 @@ SLANG_API SlangCompileRequest* spCreateCompileRequest( return request; } -SLANG_API SlangProfileID spFindProfile( - SlangSession* session, - char const* name) +SLANG_API SlangProfileID spFindProfile(SlangSession* session, char const* name) { return session->findProfile(name); } -SLANG_API SlangCapabilityID spFindCapability( - SlangSession* session, - char const* name) +SLANG_API SlangCapabilityID spFindCapability(SlangSession* session, char const* name) { return session->findCapability(name); } @@ -250,8 +248,7 @@ SLANG_API SlangCapabilityID spFindCapability( /*! @brief Destroy a compile request. */ -SLANG_API void spDestroyCompileRequest( - slang::ICompileRequest* request) +SLANG_API void spDestroyCompileRequest(slang::ICompileRequest* request) { if (request) { @@ -261,61 +258,55 @@ SLANG_API void spDestroyCompileRequest( /* All other functions just call into the ICompileResult interface. */ -SLANG_API void spSetFileSystem( - slang::ICompileRequest* request, - ISlangFileSystem* fileSystem) +SLANG_API void spSetFileSystem(slang::ICompileRequest* request, ISlangFileSystem* fileSystem) { SLANG_ASSERT(request); request->setFileSystem(fileSystem); } -SLANG_API void spSetCompileFlags( - slang::ICompileRequest* request, - SlangCompileFlags flags) +SLANG_API void spSetCompileFlags(slang::ICompileRequest* request, SlangCompileFlags flags) { SLANG_ASSERT(request); request->setCompileFlags(flags); } -SLANG_API SlangCompileFlags spGetCompileFlags( - slang::ICompileRequest* request) +SLANG_API SlangCompileFlags spGetCompileFlags(slang::ICompileRequest* request) { SLANG_ASSERT(request); return request->getCompileFlags(); } -SLANG_API void spSetDumpIntermediates( - slang::ICompileRequest* request, - int enable) +SLANG_API void spSetDumpIntermediates(slang::ICompileRequest* request, int enable) { SLANG_ASSERT(request); request->setDumpIntermediates(enable); } -SLANG_API void spSetDumpIntermediatePrefix( - slang::ICompileRequest* request, - const char* prefix) +SLANG_API void spSetDumpIntermediatePrefix(slang::ICompileRequest* request, const char* prefix) { SLANG_ASSERT(request); request->setDumpIntermediatePrefix(prefix); } -SLANG_API void spSetLineDirectiveMode( - slang::ICompileRequest* request, - SlangLineDirectiveMode mode) +SLANG_API void spSetLineDirectiveMode(slang::ICompileRequest* request, SlangLineDirectiveMode mode) { SLANG_ASSERT(request); request->setLineDirectiveMode(mode); } SLANG_API void spSetTargetForceGLSLScalarBufferLayout( - slang::ICompileRequest* request, int targetIndex, bool forceScalarLayout) + slang::ICompileRequest* request, + int targetIndex, + bool forceScalarLayout) { SLANG_ASSERT(request); request->setTargetForceGLSLScalarBufferLayout(targetIndex, forceScalarLayout); } -SLANG_API void spSetTargetUseMinimumSlangOptimization(slang::ICompileRequest* request, int targetIndex, bool val) +SLANG_API void spSetTargetUseMinimumSlangOptimization( + slang::ICompileRequest* request, + int targetIndex, + bool val) { SLANG_ASSERT(request); request->setTargetUseMinimumSlangOptimization(targetIndex, val); @@ -336,51 +327,46 @@ SLANG_API void spSetTargetLineDirectiveMode( request->setTargetLineDirectiveMode(targetIndex, mode); } -SLANG_API void spSetCommandLineCompilerMode( - slang::ICompileRequest* request) +SLANG_API void spSetCommandLineCompilerMode(slang::ICompileRequest* request) { SLANG_ASSERT(request); request->setCommandLineCompilerMode(); } -SLANG_API void spSetCodeGenTarget( - slang::ICompileRequest* request, - SlangCompileTarget target) +SLANG_API void spSetCodeGenTarget(slang::ICompileRequest* request, SlangCompileTarget target) { SLANG_ASSERT(request); request->setCodeGenTarget(target); } -SLANG_API int spAddCodeGenTarget( - slang::ICompileRequest* request, - SlangCompileTarget target) +SLANG_API int spAddCodeGenTarget(slang::ICompileRequest* request, SlangCompileTarget target) { SLANG_ASSERT(request); return request->addCodeGenTarget(target); } SLANG_API void spSetTargetProfile( - slang::ICompileRequest* request, - int targetIndex, - SlangProfileID profile) + slang::ICompileRequest* request, + int targetIndex, + SlangProfileID profile) { SLANG_ASSERT(request); request->setTargetProfile(targetIndex, profile); } SLANG_API void spSetTargetFlags( - slang::ICompileRequest* request, - int targetIndex, - SlangTargetFlags flags) + slang::ICompileRequest* request, + int targetIndex, + SlangTargetFlags flags) { SLANG_ASSERT(request); request->setTargetFlags(targetIndex, flags); } SLANG_API void spSetTargetFloatingPointMode( - slang::ICompileRequest* request, - int targetIndex, - SlangFloatingPointMode mode) + slang::ICompileRequest* request, + int targetIndex, + SlangFloatingPointMode mode) { SLANG_ASSERT(request); request->setTargetFloatingPointMode(targetIndex, mode); @@ -388,123 +374,107 @@ SLANG_API void spSetTargetFloatingPointMode( SLANG_API void spAddTargetCapability( slang::ICompileRequest* request, - int targetIndex, - SlangCapabilityID capability) + int targetIndex, + SlangCapabilityID capability) { SLANG_ASSERT(request); request->addTargetCapability(targetIndex, capability); } -SLANG_API void spSetMatrixLayoutMode( - slang::ICompileRequest* request, - SlangMatrixLayoutMode mode) +SLANG_API void spSetMatrixLayoutMode(slang::ICompileRequest* request, SlangMatrixLayoutMode mode) { SLANG_ASSERT(request); request->setMatrixLayoutMode(mode); } SLANG_API void spSetTargetMatrixLayoutMode( - slang::ICompileRequest* request, - int targetIndex, - SlangMatrixLayoutMode mode) + slang::ICompileRequest* request, + int targetIndex, + SlangMatrixLayoutMode mode) { SLANG_ASSERT(request); request->setTargetMatrixLayoutMode(targetIndex, mode); } -SLANG_API void spSetDebugInfoLevel( - slang::ICompileRequest* request, - SlangDebugInfoLevel level) +SLANG_API void spSetDebugInfoLevel(slang::ICompileRequest* request, SlangDebugInfoLevel level) { SLANG_ASSERT(request); request->setDebugInfoLevel(level); } -SLANG_API void spSetDebugInfoFormat( - slang::ICompileRequest* request, - SlangDebugInfoFormat format) +SLANG_API void spSetDebugInfoFormat(slang::ICompileRequest* request, SlangDebugInfoFormat format) { SLANG_ASSERT(request); request->setDebugInfoFormat(format); } -SLANG_API void spSetOptimizationLevel( - slang::ICompileRequest* request, - SlangOptimizationLevel level) +SLANG_API void spSetOptimizationLevel(slang::ICompileRequest* request, SlangOptimizationLevel level) { SLANG_ASSERT(request); request->setOptimizationLevel(level); } SLANG_API void spSetOutputContainerFormat( - slang::ICompileRequest* request, - SlangContainerFormat format) + slang::ICompileRequest* request, + SlangContainerFormat format) { SLANG_ASSERT(request); request->setOutputContainerFormat(format); } -SLANG_API void spSetPassThrough( - slang::ICompileRequest* request, - SlangPassThrough passThrough) +SLANG_API void spSetPassThrough(slang::ICompileRequest* request, SlangPassThrough passThrough) { SLANG_ASSERT(request); request->setPassThrough(passThrough); } SLANG_API void spSetDiagnosticCallback( - slang::ICompileRequest* request, + slang::ICompileRequest* request, SlangDiagnosticCallback callback, - void const* userData) + void const* userData) { SLANG_ASSERT(request); request->setDiagnosticCallback(callback, userData); } SLANG_API void spSetWriter( - slang::ICompileRequest* request, - SlangWriterChannel chan, - ISlangWriter* writer) + slang::ICompileRequest* request, + SlangWriterChannel chan, + ISlangWriter* writer) { SLANG_ASSERT(request); request->setWriter(chan, writer); } -SLANG_API ISlangWriter* spGetWriter( - slang::ICompileRequest* request, - SlangWriterChannel chan) +SLANG_API ISlangWriter* spGetWriter(slang::ICompileRequest* request, SlangWriterChannel chan) { SLANG_ASSERT(request); return request->getWriter(chan); } -SLANG_API void spAddSearchPath( - slang::ICompileRequest* request, - const char* path) +SLANG_API void spAddSearchPath(slang::ICompileRequest* request, const char* path) { SLANG_ASSERT(request); request->addSearchPath(path); } SLANG_API void spAddPreprocessorDefine( - slang::ICompileRequest* request, - const char* key, - const char* value) + slang::ICompileRequest* request, + const char* key, + const char* value) { SLANG_ASSERT(request); request->addPreprocessorDefine(key, value); } -SLANG_API char const* spGetDiagnosticOutput( - slang::ICompileRequest* request) +SLANG_API char const* spGetDiagnosticOutput(slang::ICompileRequest* request) { SLANG_ASSERT(request); return request->getDiagnosticOutput(); } -SLANG_API SlangResult spGetDiagnosticOutputBlob( - slang::ICompileRequest* request, - ISlangBlob** outBlob) +SLANG_API SlangResult +spGetDiagnosticOutputBlob(slang::ICompileRequest* request, ISlangBlob** outBlob) { SLANG_ASSERT(request); return request->getDiagnosticOutputBlob(outBlob); @@ -513,16 +483,16 @@ SLANG_API SlangResult spGetDiagnosticOutputBlob( // New-fangled compilation API SLANG_API int spAddTranslationUnit( - slang::ICompileRequest* request, - SlangSourceLanguage language, - char const* inName) + slang::ICompileRequest* request, + SlangSourceLanguage language, + char const* inName) { SLANG_ASSERT(request); return request->addTranslationUnit(language, inName); } SLANG_API void spSetDefaultModuleName( - slang::ICompileRequest* request, + slang::ICompileRequest* request, const char* defaultModuleName) { SLANG_ASSERT(request); @@ -530,7 +500,7 @@ SLANG_API void spSetDefaultModuleName( } SLANG_API SlangResult spAddLibraryReference( - slang::ICompileRequest* request, + slang::ICompileRequest* request, const char* basePath, const void* libData, size_t libDataSize) @@ -540,170 +510,168 @@ SLANG_API SlangResult spAddLibraryReference( } SLANG_API void spTranslationUnit_addPreprocessorDefine( - slang::ICompileRequest* request, - int translationUnitIndex, - const char* key, - const char* value) + slang::ICompileRequest* request, + int translationUnitIndex, + const char* key, + const char* value) { SLANG_ASSERT(request); request->addTranslationUnitPreprocessorDefine(translationUnitIndex, key, value); } SLANG_API void spAddTranslationUnitSourceFile( - slang::ICompileRequest* request, - int translationUnitIndex, - char const* path) + slang::ICompileRequest* request, + int translationUnitIndex, + char const* path) { SLANG_ASSERT(request); request->addTranslationUnitSourceFile(translationUnitIndex, path); } SLANG_API void spAddTranslationUnitSourceString( - slang::ICompileRequest* request, - int translationUnitIndex, - char const* path, - char const* source) + slang::ICompileRequest* request, + int translationUnitIndex, + char const* path, + char const* source) { SLANG_ASSERT(request); request->addTranslationUnitSourceString(translationUnitIndex, path, source); } SLANG_API void spAddTranslationUnitSourceStringSpan( - slang::ICompileRequest* request, - int translationUnitIndex, - char const* path, - char const* sourceBegin, - char const* sourceEnd) + slang::ICompileRequest* request, + int translationUnitIndex, + char const* path, + char const* sourceBegin, + char const* sourceEnd) { SLANG_ASSERT(request); request->addTranslationUnitSourceStringSpan(translationUnitIndex, path, sourceBegin, sourceEnd); } SLANG_API void spAddTranslationUnitSourceBlob( - slang::ICompileRequest* request, - int translationUnitIndex, - char const* path, - ISlangBlob* sourceBlob) + slang::ICompileRequest* request, + int translationUnitIndex, + char const* path, + ISlangBlob* sourceBlob) { SLANG_ASSERT(request); request->addTranslationUnitSourceBlob(translationUnitIndex, path, sourceBlob); } SLANG_API int spAddEntryPoint( - slang::ICompileRequest* request, - int translationUnitIndex, - char const* name, - SlangStage stage) + slang::ICompileRequest* request, + int translationUnitIndex, + char const* name, + SlangStage stage) { SLANG_ASSERT(request); return request->addEntryPoint(translationUnitIndex, name, stage); } SLANG_API int spAddEntryPointEx( - slang::ICompileRequest* request, - int translationUnitIndex, - char const* name, - SlangStage stage, - int genericParamTypeNameCount, - char const ** genericParamTypeNames) + slang::ICompileRequest* request, + int translationUnitIndex, + char const* name, + SlangStage stage, + int genericParamTypeNameCount, + char const** genericParamTypeNames) { SLANG_ASSERT(request); - return request->addEntryPointEx(translationUnitIndex, name, stage, genericParamTypeNameCount, genericParamTypeNames); + return request->addEntryPointEx( + translationUnitIndex, + name, + stage, + genericParamTypeNameCount, + genericParamTypeNames); } SLANG_API SlangResult spSetGlobalGenericArgs( - slang::ICompileRequest* request, - int genericArgCount, - char const** genericArgs) + slang::ICompileRequest* request, + int genericArgCount, + char const** genericArgs) { SLANG_ASSERT(request); return request->setGlobalGenericArgs(genericArgCount, genericArgs); } SLANG_API SlangResult spSetTypeNameForGlobalExistentialTypeParam( - slang::ICompileRequest* request, - int slotIndex, - char const* typeName) + slang::ICompileRequest* request, + int slotIndex, + char const* typeName) { SLANG_ASSERT(request); return request->setTypeNameForGlobalExistentialTypeParam(slotIndex, typeName); } SLANG_API SlangResult spSetTypeNameForEntryPointExistentialTypeParam( - slang::ICompileRequest* request, - int entryPointIndex, - int slotIndex, - char const* typeName) + slang::ICompileRequest* request, + int entryPointIndex, + int slotIndex, + char const* typeName) { SLANG_ASSERT(request); - return request->setTypeNameForEntryPointExistentialTypeParam(entryPointIndex, slotIndex, typeName); + return request->setTypeNameForEntryPointExistentialTypeParam( + entryPointIndex, + slotIndex, + typeName); } -SLANG_API SlangResult spCompile( - slang::ICompileRequest* request) +SLANG_API SlangResult spCompile(slang::ICompileRequest* request) { SLANG_ASSERT(request); return request->compile(); } -SLANG_API int -spGetDependencyFileCount( - slang::ICompileRequest* request) +SLANG_API int spGetDependencyFileCount(slang::ICompileRequest* request) { SLANG_ASSERT(request); return request->getDependencyFileCount(); } -SLANG_API char const* -spGetDependencyFilePath( - slang::ICompileRequest* request, - int index) +SLANG_API char const* spGetDependencyFilePath(slang::ICompileRequest* request, int index) { SLANG_ASSERT(request); return request->getDependencyFilePath(index); } -SLANG_API int -spGetTranslationUnitCount( - slang::ICompileRequest* request) +SLANG_API int spGetTranslationUnitCount(slang::ICompileRequest* request) { SLANG_ASSERT(request); return request->getTranslationUnitCount(); } SLANG_API void const* spGetEntryPointCode( - slang::ICompileRequest* request, - int entryPointIndex, - size_t* outSize) + slang::ICompileRequest* request, + int entryPointIndex, + size_t* outSize) { SLANG_ASSERT(request); return request->getEntryPointCode(entryPointIndex, outSize); } SLANG_API SlangResult spGetEntryPointCodeBlob( - slang::ICompileRequest* request, - int entryPointIndex, - int targetIndex, - ISlangBlob** outBlob) + slang::ICompileRequest* request, + int entryPointIndex, + int targetIndex, + ISlangBlob** outBlob) { SLANG_ASSERT(request); return request->getEntryPointCodeBlob(entryPointIndex, targetIndex, outBlob); } SLANG_API SlangResult spGetEntryPointHostCallable( - slang::ICompileRequest* request, - int entryPointIndex, - int targetIndex, - ISlangSharedLibrary** outSharedLibrary) + slang::ICompileRequest* request, + int entryPointIndex, + int targetIndex, + ISlangSharedLibrary** outSharedLibrary) { SLANG_ASSERT(request); return request->getEntryPointHostCallable(entryPointIndex, targetIndex, outSharedLibrary); } -SLANG_API SlangResult spGetTargetCodeBlob( - slang::ICompileRequest* request, - int targetIndex, - ISlangBlob** outBlob) +SLANG_API SlangResult +spGetTargetCodeBlob(slang::ICompileRequest* request, int targetIndex, ISlangBlob** outBlob) { SLANG_ASSERT(request); return request->getTargetCodeBlob(targetIndex, outBlob); @@ -718,25 +686,19 @@ SLANG_API SlangResult spGetTargetHostCallable( return request->getTargetHostCallable(targetIndex, outSharedLibrary); } -SLANG_API char const* spGetEntryPointSource( - slang::ICompileRequest* request, - int entryPointIndex) +SLANG_API char const* spGetEntryPointSource(slang::ICompileRequest* request, int entryPointIndex) { SLANG_ASSERT(request); return request->getEntryPointSource(entryPointIndex); } -SLANG_API void const* spGetCompileRequestCode( - slang::ICompileRequest* request, - size_t* outSize) +SLANG_API void const* spGetCompileRequestCode(slang::ICompileRequest* request, size_t* outSize) { SLANG_ASSERT(request); return request->getCompileRequestCode(outSize); } -SLANG_API SlangResult spGetContainerCode( - slang::ICompileRequest* request, - ISlangBlob** outBlob) +SLANG_API SlangResult spGetContainerCode(slang::ICompileRequest* request, ISlangBlob** outBlob) { SLANG_ASSERT(request); return request->getContainerCode(outBlob); @@ -752,31 +714,27 @@ SLANG_API SlangResult spLoadRepro( return request->loadRepro(fileSystem, data, size); } -SLANG_API SlangResult spSaveRepro( - slang::ICompileRequest* request, - ISlangBlob** outBlob) +SLANG_API SlangResult spSaveRepro(slang::ICompileRequest* request, ISlangBlob** outBlob) { SLANG_ASSERT(request); return request->saveRepro(outBlob); } -SLANG_API SlangResult spEnableReproCapture( - slang::ICompileRequest* request) +SLANG_API SlangResult spEnableReproCapture(slang::ICompileRequest* request) { SLANG_ASSERT(request); return request->enableReproCapture(); } -SLANG_API SlangResult spCompileRequest_getProgram( - slang::ICompileRequest* request, - slang::IComponentType** outProgram) +SLANG_API SlangResult +spCompileRequest_getProgram(slang::ICompileRequest* request, slang::IComponentType** outProgram) { SLANG_ASSERT(request); return request->getProgram(outProgram); } SLANG_API SlangResult spCompileRequest_getProgramWithEntryPoints( - slang::ICompileRequest* request, + slang::ICompileRequest* request, slang::IComponentType** outProgram) { SLANG_ASSERT(request); @@ -784,25 +742,24 @@ SLANG_API SlangResult spCompileRequest_getProgramWithEntryPoints( } SLANG_API SlangResult spCompileRequest_getModule( - slang::ICompileRequest* request, - SlangInt translationUnitIndex, - slang::IModule** outModule) + slang::ICompileRequest* request, + SlangInt translationUnitIndex, + slang::IModule** outModule) { SLANG_ASSERT(request); return request->getModule(translationUnitIndex, outModule); } -SLANG_API SlangResult spCompileRequest_getSession( - slang::ICompileRequest* request, - slang::ISession** outSession) +SLANG_API SlangResult +spCompileRequest_getSession(slang::ICompileRequest* request, slang::ISession** outSession) { SLANG_ASSERT(request); return request->getSession(outSession); } SLANG_API SlangResult spCompileRequest_getEntryPoint( - slang::ICompileRequest* request, - SlangInt entryPointIndex, + slang::ICompileRequest* request, + SlangInt entryPointIndex, slang::IComponentType** outEntryPoint) { SLANG_ASSERT(request); @@ -821,25 +778,23 @@ SLANG_API SlangResult spGetCompileTimeProfile( // Get the output code associated with a specific translation unit SLANG_API char const* spGetTranslationUnitSource( - slang::ICompileRequest* /*request*/, - int /*translationUnitIndex*/) + slang::ICompileRequest* /*request*/, + int /*translationUnitIndex*/ +) { fprintf(stderr, "DEPRECATED: spGetTranslationUnitSource()\n"); return nullptr; } -SLANG_API SlangResult spProcessCommandLineArguments( - SlangCompileRequest* request, - char const* const* args, - int argCount) +SLANG_API SlangResult +spProcessCommandLineArguments(SlangCompileRequest* request, char const* const* args, int argCount) { return request->processCommandLineArguments(args, argCount); } // Reflection API -SLANG_API SlangReflection* spGetReflection( - slang::ICompileRequest* request) +SLANG_API SlangReflection* spGetReflection(slang::ICompileRequest* request) { SLANG_ASSERT(request); return request->getReflection(); @@ -849,7 +804,11 @@ SLANG_API SlangReflection* spGetReflection( /* !!!!!!!!!!!!!!!!!!!!!!!!!!!!! Session !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! */ -SLANG_API SlangResult spExtractRepro(SlangSession* session, const void* reproData, size_t reproDataSize, ISlangMutableFileSystem* fileSystem) +SLANG_API SlangResult spExtractRepro( + SlangSession* session, + const void* reproData, + size_t reproDataSize, + ISlangMutableFileSystem* fileSystem) { using namespace Slang; SLANG_UNUSED(session); @@ -894,7 +853,8 @@ SLANG_API SlangResult spLoadReproAsFileSystem( base.set(buffer.getBuffer(), buffer.getCount()); ComPtr fileSystem; - SLANG_RETURN_ON_FAIL(ReproUtil::loadFileSystem(base, requestState, replaceFileSystem, fileSystem)); + SLANG_RETURN_ON_FAIL( + ReproUtil::loadFileSystem(base, requestState, replaceFileSystem, fileSystem)); *outFileSystem = fileSystem.detach(); return SLANG_OK; diff --git a/source/slang/slang-artifact-output-util.cpp b/source/slang/slang-artifact-output-util.cpp index dcd6eb84e..2d9dae1cf 100644 --- a/source/slang/slang-artifact-output-util.cpp +++ b/source/slang/slang-artifact-output-util.cpp @@ -1,44 +1,49 @@ #include "slang-artifact-output-util.h" -#include "../core/slang-platform.h" - +#include "../core/slang-hex-dump-util.h" #include "../core/slang-io.h" +#include "../core/slang-platform.h" #include "../core/slang-string-util.h" -#include "../core/slang-hex-dump-util.h" - #include "../core/slang-type-text-util.h" // Artifact #include "../compiler-core/slang-artifact-desc-util.h" #include "../compiler-core/slang-artifact-util.h" - #include "slang-compiler.h" namespace Slang { -/* static */SlangResult ArtifactOutputUtil::dissassembleWithDownstream(Session* session, IArtifact* artifact, DiagnosticSink* sink, IArtifact** outArtifact) +/* static */ SlangResult ArtifactOutputUtil::dissassembleWithDownstream( + Session* session, + IArtifact* artifact, + DiagnosticSink* sink, + IArtifact** outArtifact) { auto desc = artifact->getDesc(); auto assemblyDesc = desc; assemblyDesc.kind = ArtifactKind::Assembly; - // Check it seems like a plausbile disassembly + // Check it seems like a plausbile disassembly if (!ArtifactDescUtil::isDisassembly(desc, assemblyDesc)) { if (sink) { - sink->diagnose(SourceLoc(), Diagnostics::cannotDisassemble, ArtifactDescUtil::getText(desc)); + sink->diagnose( + SourceLoc(), + Diagnostics::cannotDisassemble, + ArtifactDescUtil::getText(desc)); } return SLANG_FAIL; } // Get the downstream disassembler that can be used for this target // TODO(JS): - // This could perhaps be performed in some other manner if there was more than one way to produce - // disassembly from a binary. + // This could perhaps be performed in some other manner if there was more than one way to + // produce disassembly from a binary. - const CodeGenTarget target = (CodeGenTarget)ArtifactDescUtil::getCompileTargetFromDesc(assemblyDesc); + const CodeGenTarget target = + (CodeGenTarget)ArtifactDescUtil::getCompileTargetFromDesc(assemblyDesc); if (target == CodeGenTarget::Unknown) { return SLANG_FAIL; @@ -53,7 +58,8 @@ namespace Slang { if (sink) { - auto compilerName = TypeTextUtil::getPassThroughAsHumanText((SlangPassThrough)downstreamCompiler); + auto compilerName = + TypeTextUtil::getPassThroughAsHumanText((SlangPassThrough)downstreamCompiler); sink->diagnose(SourceLoc(), Diagnostics::passThroughCompilerNotFound, compilerName); } return SLANG_FAIL; @@ -66,7 +72,11 @@ namespace Slang return SLANG_OK; } -SlangResult ArtifactOutputUtil::maybeDisassemble(Session* session, IArtifact* artifact, DiagnosticSink* sink, ComPtr& outArtifact) +SlangResult ArtifactOutputUtil::maybeDisassemble( + Session* session, + IArtifact* artifact, + DiagnosticSink* sink, + ComPtr& outArtifact) { const auto desc = artifact->getDesc(); if (ArtifactDescUtil::isText(desc)) @@ -82,7 +92,11 @@ SlangResult ArtifactOutputUtil::maybeDisassemble(Session* session, IArtifact* ar { ComPtr disassemblyArtifact; - if (SLANG_SUCCEEDED(dissassembleWithDownstream(session, artifact, sink, disassemblyArtifact.writeRef()))) + if (SLANG_SUCCEEDED(dissassembleWithDownstream( + session, + artifact, + sink, + disassemblyArtifact.writeRef()))) { // Check it is now text SLANG_ASSERT(ArtifactDescUtil::isText(disassemblyArtifact->getDesc())); @@ -95,7 +109,10 @@ SlangResult ArtifactOutputUtil::maybeDisassemble(Session* session, IArtifact* ar return SLANG_OK; } -/* static */SlangResult ArtifactOutputUtil::write(const ArtifactDesc& desc, ISlangBlob* blob, ISlangWriter* writer) +/* static */ SlangResult ArtifactOutputUtil::write( + const ArtifactDesc& desc, + ISlangBlob* blob, + ISlangWriter* writer) { // If is text, we can just output if (ArtifactDescUtil::isText(desc)) @@ -108,7 +125,11 @@ SlangResult ArtifactOutputUtil::maybeDisassemble(Session* session, IArtifact* ar if (writer->isConsole()) { // Else just dump as text - return HexDumpUtil::dumpWithMarkers((const uint8_t*)blob->getBufferPointer(), blob->getBufferSize(), 24, writer); + return HexDumpUtil::dumpWithMarkers( + (const uint8_t*)blob->getBufferPointer(), + blob->getBufferSize(), + 24, + writer); } else { @@ -119,14 +140,17 @@ SlangResult ArtifactOutputUtil::maybeDisassemble(Session* session, IArtifact* ar } } -/* static */SlangResult ArtifactOutputUtil::write(IArtifact* artifact, ISlangWriter* writer) +/* static */ SlangResult ArtifactOutputUtil::write(IArtifact* artifact, ISlangWriter* writer) { ComPtr blob; SLANG_RETURN_ON_FAIL(artifact->loadBlob(ArtifactKeep::No, blob.writeRef())); return write(artifact->getDesc(), blob, writer); } -static SlangResult _requireBlob(IArtifact* artifact, DiagnosticSink* sink, ComPtr& outBlob) +static SlangResult _requireBlob( + IArtifact* artifact, + DiagnosticSink* sink, + ComPtr& outBlob) { const auto res = artifact->loadBlob(ArtifactKeep::No, outBlob.writeRef()); if (SLANG_FAILED(res)) @@ -137,7 +161,11 @@ static SlangResult _requireBlob(IArtifact* artifact, DiagnosticSink* sink, ComPt return SLANG_OK; } -/* static */SlangResult ArtifactOutputUtil::write(IArtifact* artifact, DiagnosticSink* sink, const UnownedStringSlice& writerName, ISlangWriter* writer) +/* static */ SlangResult ArtifactOutputUtil::write( + IArtifact* artifact, + DiagnosticSink* sink, + const UnownedStringSlice& writerName, + ISlangWriter* writer) { if (sink == nullptr) { @@ -156,7 +184,12 @@ static SlangResult _requireBlob(IArtifact* artifact, DiagnosticSink* sink, ComPt return res; } -/* static */SlangResult ArtifactOutputUtil::maybeConvertAndWrite(Session* session, IArtifact* artifact, DiagnosticSink* sink, const UnownedStringSlice& writerName, ISlangWriter* writer) +/* static */ SlangResult ArtifactOutputUtil::maybeConvertAndWrite( + Session* session, + IArtifact* artifact, + DiagnosticSink* sink, + const UnownedStringSlice& writerName, + ISlangWriter* writer) { // If the output is console we will try and turn into disassembly if (writer->isConsole()) @@ -173,12 +206,17 @@ static SlangResult _requireBlob(IArtifact* artifact, DiagnosticSink* sink, ComPt return write(artifact, sink, writerName, writer); } -/* static */SlangResult ArtifactOutputUtil::writeToFile(const ArtifactDesc& desc, const void* data, size_t size, const String& path) +/* static */ SlangResult ArtifactOutputUtil::writeToFile( + const ArtifactDesc& desc, + const void* data, + size_t size, + const String& path) { - const SlangResult res = ArtifactDescUtil::isText(desc) - ? File::writeAllTextIfChanged(path, UnownedStringSlice((const char*)data, size)) - : File::writeAllBytes(path, data, size); - if(desc.kind == ArtifactKind::Executable) + const SlangResult res = + ArtifactDescUtil::isText(desc) + ? File::writeAllTextIfChanged(path, UnownedStringSlice((const char*)data, size)) + : File::writeAllBytes(path, data, size); + if (desc.kind == ArtifactKind::Executable) { // Ignore any success code here, assume the one from the actual write is more important. SLANG_RETURN_ON_FAIL(File::makeExecutable(path)); @@ -186,21 +224,27 @@ static SlangResult _requireBlob(IArtifact* artifact, DiagnosticSink* sink, ComPt return res; } -/* static */SlangResult ArtifactOutputUtil::writeToFile(const ArtifactDesc& desc, ISlangBlob* blob, const String& path) +/* static */ SlangResult ArtifactOutputUtil::writeToFile( + const ArtifactDesc& desc, + ISlangBlob* blob, + const String& path) { SLANG_RETURN_ON_FAIL(writeToFile(desc, blob->getBufferPointer(), blob->getBufferSize(), path)); return SLANG_OK; } -/* static */SlangResult ArtifactOutputUtil::writeToFile(IArtifact* artifact, const String& path) +/* static */ SlangResult ArtifactOutputUtil::writeToFile(IArtifact* artifact, const String& path) { // Get the blob ComPtr blob; SLANG_RETURN_ON_FAIL(artifact->loadBlob(ArtifactKeep::No, blob.writeRef())); return writeToFile(artifact->getDesc(), blob, path); } - -/* static */SlangResult ArtifactOutputUtil::writeToFile(IArtifact* artifact, DiagnosticSink* sink, const String& path) + +/* static */ SlangResult ArtifactOutputUtil::writeToFile( + IArtifact* artifact, + DiagnosticSink* sink, + const String& path) { if (!sink) { @@ -219,4 +263,4 @@ static SlangResult _requireBlob(IArtifact* artifact, DiagnosticSink* sink, ComPt return res; } -} +} // namespace Slang diff --git a/source/slang/slang-artifact-output-util.h b/source/slang/slang-artifact-output-util.h index 13682c579..88fe5284b 100644 --- a/source/slang/slang-artifact-output-util.h +++ b/source/slang/slang-artifact-output-util.h @@ -1,11 +1,9 @@ #ifndef SLANG_ARTIFACT_OUTPUT_UTIL_H #define SLANG_ARTIFACT_OUTPUT_UTIL_H -#include "../core/slang-basic.h" - #include "../compiler-core/slang-artifact.h" #include "../compiler-core/slang-diagnostic-sink.h" - +#include "../core/slang-basic.h" #include "slang-com-ptr.h" namespace Slang @@ -15,32 +13,54 @@ class Session; struct ArtifactOutputUtil { - /// Attempts to disassembly artifact into outArtifact. - /// Errors are output to sink if set. If not desired pass nullptr - static SlangResult dissassembleWithDownstream(Session* session, IArtifact* artifact, DiagnosticSink* sink, IArtifact** outArtifact); - - /// Disassembles if that is plausible - /// Errors are output to sink if set. If not desired pass nullptr - static SlangResult maybeDisassemble(Session* session, IArtifact* artifact, DiagnosticSink* sink, ComPtr& outArtifact); - - /// Writes output to writer, will convert into disassembly if that is possible and appropriate (if outputting to console for example). - /// Errors are output to sink if set. If not desired pass nullptr - static SlangResult maybeConvertAndWrite(Session* session, IArtifact* artifact, DiagnosticSink* sink, const UnownedStringSlice& writerName, ISlangWriter* writer); - - /// Write (without any diagnostics) - static SlangResult write(IArtifact* artifact, ISlangWriter* writer); - static SlangResult write(const ArtifactDesc& desc, ISlangBlob* blob, ISlangWriter* writer); - - /// Writes the artifact with diagnostics - static SlangResult write(IArtifact* artifact, DiagnosticSink* sink, const UnownedStringSlice& writerName, ISlangWriter* writer); - - /// Write to the specified path - static SlangResult writeToFile(const ArtifactDesc& desc, const void* data, size_t size, const String& path); - static SlangResult writeToFile(const ArtifactDesc& desc, ISlangBlob* blob, const String& path); - static SlangResult writeToFile(IArtifact* artifact, const String& path); - static SlangResult writeToFile(IArtifact* artifact, DiagnosticSink* sink, const String& path); + /// Attempts to disassembly artifact into outArtifact. + /// Errors are output to sink if set. If not desired pass nullptr + static SlangResult dissassembleWithDownstream( + Session* session, + IArtifact* artifact, + DiagnosticSink* sink, + IArtifact** outArtifact); + + /// Disassembles if that is plausible + /// Errors are output to sink if set. If not desired pass nullptr + static SlangResult maybeDisassemble( + Session* session, + IArtifact* artifact, + DiagnosticSink* sink, + ComPtr& outArtifact); + + /// Writes output to writer, will convert into disassembly if that is possible and appropriate + /// (if outputting to console for example). Errors are output to sink if set. If not desired + /// pass nullptr + static SlangResult maybeConvertAndWrite( + Session* session, + IArtifact* artifact, + DiagnosticSink* sink, + const UnownedStringSlice& writerName, + ISlangWriter* writer); + + /// Write (without any diagnostics) + static SlangResult write(IArtifact* artifact, ISlangWriter* writer); + static SlangResult write(const ArtifactDesc& desc, ISlangBlob* blob, ISlangWriter* writer); + + /// Writes the artifact with diagnostics + static SlangResult write( + IArtifact* artifact, + DiagnosticSink* sink, + const UnownedStringSlice& writerName, + ISlangWriter* writer); + + /// Write to the specified path + static SlangResult writeToFile( + const ArtifactDesc& desc, + const void* data, + size_t size, + const String& path); + static SlangResult writeToFile(const ArtifactDesc& desc, ISlangBlob* blob, const String& path); + static SlangResult writeToFile(IArtifact* artifact, const String& path); + static SlangResult writeToFile(IArtifact* artifact, DiagnosticSink* sink, const String& path); }; -} +} // namespace Slang #endif diff --git a/source/slang/slang-ast-all.h b/source/slang/slang-ast-all.h index 096c95f53..b2dc91c55 100644 --- a/source/slang/slang-ast-all.h +++ b/source/slang/slang-ast-all.h @@ -3,9 +3,8 @@ #pragma once #include "slang-ast-base.h" - -#include "slang-ast-expr.h" #include "slang-ast-decl.h" +#include "slang-ast-expr.h" #include "slang-ast-modifier.h" #include "slang-ast-stmt.h" #include "slang-ast-type.h" diff --git a/source/slang/slang-ast-base.cpp b/source/slang/slang-ast-base.cpp index d4904d2de..a1624fcba 100644 --- a/source/slang/slang-ast-base.cpp +++ b/source/slang/slang-ast-base.cpp @@ -1,45 +1,46 @@ #include "slang-ast-base.h" + #include "slang-ast-builder.h" namespace Slang { - void NodeBase::_initDebug(ASTNodeType inAstNodeType, ASTBuilder* inAstBuilder) - { +void NodeBase::_initDebug(ASTNodeType inAstNodeType, ASTBuilder* inAstBuilder) +{ #ifdef _DEBUG - SLANG_UNUSED(inAstNodeType); - static int32_t uidCounter = 0; - static int32_t breakValue = 0; - uidCounter++; - _debugUID = uidCounter; - if (inAstBuilder->getId() == -1) - _debugUID = -_debugUID; - if (breakValue != 0 && _debugUID == breakValue) - SLANG_BREAKPOINT(0) + SLANG_UNUSED(inAstNodeType); + static int32_t uidCounter = 0; + static int32_t breakValue = 0; + uidCounter++; + _debugUID = uidCounter; + if (inAstBuilder->getId() == -1) + _debugUID = -_debugUID; + if (breakValue != 0 && _debugUID == breakValue) + SLANG_BREAKPOINT(0) #else - SLANG_UNUSED(inAstNodeType); - SLANG_UNUSED(inAstBuilder); + SLANG_UNUSED(inAstNodeType); + SLANG_UNUSED(inAstBuilder); #endif - } - DeclRefBase* Decl::getDefaultDeclRef() +} +DeclRefBase* Decl::getDefaultDeclRef() +{ + if (auto astBuilder = getCurrentASTBuilder()) { - if (auto astBuilder = getCurrentASTBuilder()) + const Index currentEpoch = astBuilder->getEpoch(); + if (currentEpoch != m_defaultDeclRefEpoch || !m_defaultDeclRef) { - const Index currentEpoch = astBuilder->getEpoch(); - if (currentEpoch != m_defaultDeclRefEpoch || !m_defaultDeclRef) - { - m_defaultDeclRef = astBuilder->getOrCreate(this); - m_defaultDeclRefEpoch = currentEpoch; - } + m_defaultDeclRef = astBuilder->getOrCreate(this); + m_defaultDeclRefEpoch = currentEpoch; } - return m_defaultDeclRef; - } - - bool Decl::isChildOf(Decl* other) const - { - for (auto parent = parentDecl; parent; parent = parent->parentDecl) - if (parent == other) - return true; - return false; } + return m_defaultDeclRef; +} +bool Decl::isChildOf(Decl* other) const +{ + for (auto parent = parentDecl; parent; parent = parent->parentDecl) + if (parent == other) + return true; + return false; } + +} // namespace Slang diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h index 99d10457c..2b5de61e8 100644 --- a/source/slang/slang-ast-base.h +++ b/source/slang/slang-ast-base.h @@ -2,11 +2,10 @@ #pragma once -#include "slang-ast-support-types.h" - -#include "slang-generated-ast.h" #include "slang-ast-reflect.h" +#include "slang-ast-support-types.h" #include "slang-capability.h" +#include "slang-generated-ast.h" #include "slang-serialize-reflection.h" // This file defines the primary base classes for the hierarchy of @@ -14,17 +13,17 @@ // basic `Decl`, `Stmt`, `Expr`, `type`, etc. definitions come from. namespace Slang -{ +{ class ASTBuilder; struct SemanticsVisitor; -class NodeBase +class NodeBase { SLANG_ABSTRACT_AST_CLASS(NodeBase) - // MUST be called before used. Called automatically via the ASTBuilder. - // Note that the astBuilder is not stored in the NodeBase derived types by default. + // MUST be called before used. Called automatically via the ASTBuilder. + // Note that the astBuilder is not stored in the NodeBase derived types by default. SLANG_FORCE_INLINE void init(ASTNodeType inAstNodeType, ASTBuilder* inAstBuilder) { SLANG_UNUSED(inAstBuilder); @@ -36,14 +35,17 @@ class NodeBase void _initDebug(ASTNodeType inAstNodeType, ASTBuilder* inAstBuilder); - /// Get the class info - SLANG_FORCE_INLINE const ReflectClassInfo& getClassInfo() const { return *ASTClassInfo::getInfo(astNodeType); } + /// Get the class info + SLANG_FORCE_INLINE const ReflectClassInfo& getClassInfo() const + { + return *ASTClassInfo::getInfo(astNodeType); + } SyntaxClass getClass() { return SyntaxClass(&getClassInfo()); } - /// The type of the node. ASTNodeType(-1) is an invalid node type, and shouldn't appear on any - /// correctly constructed (through ASTBuilder) NodeBase derived class. - /// The actual type is set when constructed on the ASTBuilder. + /// The type of the node. ASTNodeType(-1) is an invalid node type, and shouldn't appear on any + /// correctly constructed (through ASTBuilder) NodeBase derived class. + /// The actual type is set when constructed on the ASTBuilder. ASTNodeType astNodeType = ASTNodeType(-1); #ifdef _DEBUG @@ -56,60 +58,79 @@ class NodeBase template SLANG_FORCE_INLINE T* dynamicCast(NodeBase* node) { - return (node && ReflectClassInfo::isSubClassOf(uint32_t(node->astNodeType), T::kReflectClassInfo)) ? static_cast(node) : nullptr; + return (node && + ReflectClassInfo::isSubClassOf(uint32_t(node->astNodeType), T::kReflectClassInfo)) + ? static_cast(node) + : nullptr; } template SLANG_FORCE_INLINE const T* dynamicCast(const NodeBase* node) { - return (node && ReflectClassInfo::isSubClassOf(uint32_t(node->astNodeType), T::kReflectClassInfo)) ? static_cast(node) : nullptr; + return (node && + ReflectClassInfo::isSubClassOf(uint32_t(node->astNodeType), T::kReflectClassInfo)) + ? static_cast(node) + : nullptr; } template SLANG_FORCE_INLINE T* as(NodeBase* node) { - return (node && ReflectClassInfo::isSubClassOf(uint32_t(node->astNodeType), T::kReflectClassInfo)) ? static_cast(node) : nullptr; + return (node && + ReflectClassInfo::isSubClassOf(uint32_t(node->astNodeType), T::kReflectClassInfo)) + ? static_cast(node) + : nullptr; } template SLANG_FORCE_INLINE const T* as(const NodeBase* node) { - return (node && ReflectClassInfo::isSubClassOf(uint32_t(node->astNodeType), T::kReflectClassInfo)) ? static_cast(node) : nullptr; + return (node && + ReflectClassInfo::isSubClassOf(uint32_t(node->astNodeType), T::kReflectClassInfo)) + ? static_cast(node) + : nullptr; } // Because DeclRefBase is now a `Val`, we prevent casting it directly into other nodes // to avoid confusion and bugs. Instead, use the `as<>()` method on `DeclRefBase` to // get a `DeclRef` for a specific node type. template -T* as(DeclRefBase* declRefBase, typename EnableIf::Value, void*>::type arg = nullptr) = delete; +T* as( + DeclRefBase* declRefBase, + typename EnableIf::Value, void*>::type arg = nullptr) = delete; template -T* as(DeclRefBase* declRefBase, typename EnableIf::Value, void*>::type arg = nullptr) +T* as( + DeclRefBase* declRefBase, + typename EnableIf::Value, void*>::type arg = nullptr) { SLANG_UNUSED(arg); return dynamicCast(declRefBase); } template -DeclRef as(DeclRef declRef) { return DeclRef(declRef); } +DeclRef as(DeclRef declRef) +{ + return DeclRef(declRef); +} struct Scope : public NodeBase { SLANG_AST_CLASS(Scope) - + // The container to use for lookup // // Note(tfoley): This is kept as an unowned pointer // so that a scope can't keep parts of the AST alive, // but the opposite it allowed. - ContainerDecl* containerDecl = nullptr; + ContainerDecl* containerDecl = nullptr; // The parent of this scope (where lookup should go if nothing is found locally) - Scope* parent = nullptr; + Scope* parent = nullptr; SLANG_UNREFLECTED // The next sibling of this scope (a peer for lookup) - Scope* nextSibling = nullptr; + Scope* nextSibling = nullptr; }; // Base class for all nodes representing actual syntax @@ -138,14 +159,11 @@ struct ValNodeOperand int64_t intOperand; } values; - ValNodeOperand() - { - values.intOperand = 0; - } + ValNodeOperand() { values.intOperand = 0; } explicit ValNodeOperand(NodeBase* node) { - if constexpr(sizeof(values.nodeOperand) < sizeof(values.intOperand)) + if constexpr (sizeof(values.nodeOperand) < sizeof(values.intOperand)) values.intOperand = 0; if (as(node)) @@ -162,10 +180,11 @@ struct ValNodeOperand template explicit ValNodeOperand(DeclRef declRef) - { + { if constexpr (sizeof(values.nodeOperand) < sizeof(values.intOperand)) values.intOperand = 0; - values.nodeOperand = declRef.declRefBase; kind = ValNodeOperandKind::ValNode; + values.nodeOperand = declRef.declRefBase; + kind = ValNodeOperandKind::ValNode; } template @@ -185,15 +204,21 @@ struct ValNodeOperand } else { - static_assert(std::is_base_of::value || std::is_base_of::value, "pointer used as Val operand must be an AST node."); + static_assert( + std::is_base_of::value || std::is_base_of::value, + "pointer used as Val operand must be an AST node."); } } template explicit ValNodeOperand(EnumType intVal) { - static_assert(std::is_trivial::value, "Type to construct NodeOperand must be trivial."); - static_assert(sizeof(EnumType) <= sizeof(values), "size of operand must be less than pointer size."); + static_assert( + std::is_trivial::value, + "Type to construct NodeOperand must be trivial."); + static_assert( + sizeof(EnumType) <= sizeof(values), + "size of operand must be less than pointer size."); values.intOperand = 0; memcpy(&values, &intVal, sizeof(intVal)); kind = ValNodeOperandKind::ConstantValue; @@ -204,15 +229,19 @@ struct ValNodeDesc { private: HashCode hashCode = 0; + public: - ASTNodeType type; + ASTNodeType type; ShortList operands; inline bool operator==(ValNodeDesc const& that) const { - if (hashCode != that.hashCode) return false; - if (type != that.type) return false; - if (operands.getCount() != that.operands.getCount()) return false; + if (hashCode != that.hashCode) + return false; + if (type != that.type) + return false; + if (operands.getCount() != that.operands.getCount()) + return false; for (Index i = 0; i < operands.getCount(); ++i) { // Note: we are comparing the operands directly for identity @@ -222,22 +251,26 @@ public: // The rationale here is that nodes that will be created // via a `NodeDesc` *should* all be going through the // deduplication path anyway, as should their operands. - // - if (operands[i].values.intOperand != that.operands[i].values.intOperand) return false; + // + if (operands[i].values.intOperand != that.operands[i].values.intOperand) + return false; } return true; } HashCode getHashCode() const { return hashCode; } void init(); - }; template static void addOrAppendToNodeList(ShortList&) -{} +{ +} template -static void addOrAppendToNodeList(ShortList& list, ExpandedSpecializationArgs e, Ts... ts) +static void addOrAppendToNodeList( + ShortList& list, + ExpandedSpecializationArgs e, + Ts... ts) { for (auto arg : e) { @@ -278,11 +311,13 @@ static void addOrAppendToNodeList(ShortList& list, ArrayView< addOrAppendToNodeList(list, ts...); } -inline void addOrAppendToNodeList(List&) -{} +inline void addOrAppendToNodeList(List&) {} template -static void addOrAppendToNodeList(List& list, ExpandedSpecializationArgs e, Ts... ts) +static void addOrAppendToNodeList( + List& list, + ExpandedSpecializationArgs e, + Ts... ts) { for (auto arg : e) { @@ -331,7 +366,7 @@ static void addOrAppendToNodeList(List& list, ArrayView l, Ts class Val : public NodeBase { SLANG_ABSTRACT_AST_CLASS(Val) - + template struct OperandView { @@ -344,31 +379,22 @@ class Val : public NodeBase offset = 0; count = 0; } - OperandView(const Val* val, Index offset, Index count) : val(val), offset(offset), count(count) {} - Index getCount() { return count; } - T* operator[](Index index) const + OperandView(const Val* val, Index offset, Index count) + : val(val), offset(offset), count(count) { - return as(val->getOperand(index + offset)); } + Index getCount() { return count; } + T* operator[](Index index) const { return as(val->getOperand(index + offset)); } struct ConstIterator { const Val* val; Index i; - bool operator==(ConstIterator other) const - { - return val == other.val && i == other.i; - } - bool operator!=(ConstIterator other) const - { - return val != other.val || i != other.i; - } - T *const & operator*() const - { - return *(this->operator->()); - } - T *const * operator->() const + bool operator==(ConstIterator other) const { return val == other.val && i == other.i; } + bool operator!=(ConstIterator other) const { return val != other.val || i != other.i; } + T* const& operator*() const { return *(this->operator->()); } + T* const* operator->() const { - return reinterpret_cast(&val->m_operands[i].values.nodeOperand); + return reinterpret_cast(&val->m_operands[i].values.nodeOperand); } ConstIterator& operator++() { @@ -376,8 +402,8 @@ class Val : public NodeBase return *this; } }; - ConstIterator begin() const { return ConstIterator { val, offset }; } - ConstIterator end() const { return ConstIterator{ val, offset + count }; } + ConstIterator begin() const { return ConstIterator{val, offset}; } + ConstIterator end() const { return ConstIterator{val, offset + count}; } }; typedef IValVisitor Visitor; @@ -405,10 +431,7 @@ class Val : public NodeBase String toString(); HashCode getHashCode(); - bool operator == (const Val & v) const - { - return equals(const_cast(&v)); - } + bool operator==(const Val& v) const { return equals(const_cast(&v)); } // Overrides should be public so base classes can access Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); @@ -439,13 +462,13 @@ class Val : public NodeBase Index getOperandCount() const { return m_operands.getCount(); } - template + template void setOperands(TArgs... args) { m_operands.clear(); addOrAppendToNodeList(m_operands, args...); } - template + template void addOperands(TArgs... args) { addOrAppendToNodeList(m_operands, args...); @@ -458,18 +481,24 @@ class Val : public NodeBase } List m_operands; - // Private use by the core module deserialization only. Since we know the Vals serialized into the core module is already - // unique, we can just use `this` pointer as the `m_resolvedVal` so we don't need to resolve them again. + // Private use by the core module deserialization only. Since we know the Vals serialized into + // the core module is already unique, we can just use `this` pointer as the `m_resolvedVal` so + // we don't need to resolve them again. void _setUnique(); + protected: Val* defaultResolveImpl(); + private: mutable Val* m_resolvedVal = nullptr; SLANG_UNREFLECTED mutable Index m_resolvedValEpoch = 0; }; template -static void addOrAppendToNodeList(ShortList& list, Val::OperandView l, Ts... ts) +static void addOrAppendToNodeList( + ShortList& list, + Val::OperandView l, + Ts... ts) { for (auto t : l) list.add(ValNodeOperand(t)); @@ -482,12 +511,12 @@ struct ValSet { Val* val = nullptr; ValItem() = default; - ValItem(Val* v) : val(v) {} - - HashCode getHashCode() const + ValItem(Val* v) + : val(v) { - return val ? val->getHashCode() : 0; } + + HashCode getHashCode() const { return val ? val->getHashCode() : 0; } bool operator==(const ValItem other) const { if (val == other.val) @@ -500,34 +529,33 @@ struct ValSet } }; HashSet set; - bool add(Val* val) - { - return set.add(ValItem(val)); - } - bool contains(Val* val) - { - return set.contains(ValItem(val)); - } + bool add(Val* val) { return set.add(ValItem(val)); } + bool contains(Val* val) { return set.contains(ValItem(val)); } }; -SLANG_FORCE_INLINE StringBuilder& operator<<(StringBuilder& io, Val* val) { SLANG_ASSERT(val); val->toText(io); return io; } +SLANG_FORCE_INLINE StringBuilder& operator<<(StringBuilder& io, Val* val) +{ + SLANG_ASSERT(val); + val->toText(io); + return io; +} - /// Given a `value` that refers to a `param` of some generic, attempt to apply - /// the `subst` to it and produce a new `Val` as a result. - /// - /// If the `subst` does not include anything to replace `value`, then this function - /// returns null. - /// +/// Given a `value` that refers to a `param` of some generic, attempt to apply +/// the `subst` to it and produce a new `Val` as a result. +/// +/// If the `subst` does not include anything to replace `value`, then this function +/// returns null. +/// Val* maybeSubstituteGenericParam(Val* value, Decl* param, SubstitutionSet subst, int* ioDiff); class Type; -template +template SLANG_FORCE_INLINE T* as(Type* obj); -template +template SLANG_FORCE_INLINE const T* as(const Type* obj); - + // A type, representing a classifier for some term in the AST. // // Types can include "sugar" in that they may refer to a @@ -539,7 +567,7 @@ SLANG_FORCE_INLINE const T* as(const Type* obj); // "canonical" type. The representation caches a pointer to // a canonical type on every type, so we can easily // operate on the raw representation when needed. -class Type: public Val +class Type : public Val { SLANG_ABSTRACT_AST_CLASS(Type) @@ -547,8 +575,8 @@ class Type: public Val void accept(ITypeVisitor* visitor, void* extra); - /// Type derived types store the AST builder they were constructed on. The builder calls this function - /// after constructing. + /// Type derived types store the AST builder they were constructed on. The builder calls this + /// function after constructing. SLANG_FORCE_INLINE void init(ASTNodeType inAstNodeType, ASTBuilder* inAstBuilder) { Val::init(inAstNodeType, inAstBuilder); @@ -560,12 +588,10 @@ class Type: public Val Type* _createCanonicalTypeOverride(); Val* _resolveImplOverride(); - Type* getCanonicalType() - { - return as(resolve()); - } + Type* getCanonicalType() { return as(resolve()); } ASTBuilder* getASTBuilderForReflection() const { return m_astBuilderForReflection; } + protected: Type* createCanonicalType(); @@ -577,10 +603,16 @@ protected: SLANG_UNREFLECTED ASTBuilder* m_astBuilderForReflection; }; -template -SLANG_FORCE_INLINE T* as(Type* obj) { return obj ? dynamicCast(obj->getCanonicalType()) : nullptr; } -template -SLANG_FORCE_INLINE const T* as(const Type* obj) { return obj ? dynamicCast(const_cast(obj)->getCanonicalType()) : nullptr; } +template +SLANG_FORCE_INLINE T* as(Type* obj) +{ + return obj ? dynamicCast(obj->getCanonicalType()) : nullptr; +} +template +SLANG_FORCE_INLINE const T* as(const Type* obj) +{ + return obj ? dynamicCast(const_cast(obj)->getCanonicalType()) : nullptr; +} class Decl; @@ -620,8 +652,11 @@ class DeclRefBase : public Val } // Returns true if 'as' will return a valid cast - template - bool is() const { return Slang::as(getDecl()) != nullptr; } + template + bool is() const + { + return Slang::as(getDecl()) != nullptr; + } // Convenience accessors for common properties of declarations Name* getName() const; @@ -638,7 +673,12 @@ class DeclRefBase : public Val void toText(StringBuilder& out); }; -SLANG_FORCE_INLINE StringBuilder& operator<<(StringBuilder& io, const DeclRefBase* declRef) { if (declRef) const_cast(declRef)->toText(io); return io; } +SLANG_FORCE_INLINE StringBuilder& operator<<(StringBuilder& io, const DeclRefBase* declRef) +{ + if (declRef) + const_cast(declRef)->toText(io); + return io; +} SLANG_FORCE_INLINE StringBuilder& operator<<(StringBuilder& io, Decl* decl) { @@ -682,7 +722,10 @@ class ModifiableSyntaxNode : public SyntaxNode Modifiers modifiers; template - FilteredModifierList getModifiersOfType() { return FilteredModifierList(modifiers.first); } + FilteredModifierList getModifiersOfType() + { + return FilteredModifierList(modifiers.first); + } // Find the first modifier of a given type, or return `nullptr` if none is found. template @@ -692,7 +735,10 @@ class ModifiableSyntaxNode : public SyntaxNode } template - bool hasModifier() { return findModifier() != nullptr; } + bool hasModifier() + { + return findModifier() != nullptr; + } }; struct DeclReferenceWithLoc @@ -725,9 +771,9 @@ public: RefPtr markup; - Name* getName() const { return nameAndLoc.name; } - SourceLoc getNameLoc() const { return nameAndLoc.loc ; } - NameLoc getNameAndLoc() const { return nameAndLoc ; } + Name* getName() const { return nameAndLoc.name; } + SourceLoc getNameLoc() const { return nameAndLoc.loc; } + NameLoc getNameAndLoc() const { return nameAndLoc; } DeclCheckStateExt checkState = DeclCheckState::Unchecked; @@ -744,6 +790,7 @@ public: // Track the decl reference that caused the requirement of a capability atom. SLANG_UNREFLECTED List capabilityRequirementProvenance; + private: SLANG_UNREFLECTED DeclRefBase* m_defaultDeclRef = nullptr; SLANG_UNREFLECTED Index m_defaultDeclRefEpoch = -1; @@ -808,28 +855,32 @@ Name* DeclRef::getName() const template SourceLoc DeclRef::getNameLoc() const { - if (declRefBase) return declRefBase->getNameLoc(); + if (declRefBase) + return declRefBase->getNameLoc(); return SourceLoc(); } template SourceLoc DeclRef::getLoc() const { - if (declRefBase) return declRefBase->getLoc(); + if (declRefBase) + return declRefBase->getLoc(); return SourceLoc(); } template DeclRef DeclRef::getParent() const { - if (declRefBase) return DeclRef(declRefBase->getParent()); + if (declRefBase) + return DeclRef(declRefBase->getParent()); return DeclRef((DeclRefBase*)nullptr); } template HashCode DeclRef::getHashCode() const { - if (declRefBase) return declRefBase->getHashCode(); + if (declRefBase) + return declRefBase->getHashCode(); return 0; } @@ -837,7 +888,8 @@ template Type* DeclRef::substitute(ASTBuilder* astBuilder, Type* type) const { SLANG_UNUSED(astBuilder); - if (!declRefBase) return type; + if (!declRefBase) + return type; return SubstitutionSet(*this).applyToType(astBuilder, type); } @@ -845,7 +897,8 @@ template SubstExpr DeclRef::substitute(ASTBuilder* astBuilder, Expr* expr) const { SLANG_UNUSED(astBuilder); - if (!declRefBase) return expr; + if (!declRefBase) + return expr; return applySubstitutionToExpr(SubstitutionSet(*this), expr); } @@ -855,16 +908,19 @@ template DeclRef DeclRef::substitute(ASTBuilder* astBuilder, DeclRef declRef) const { SLANG_UNUSED(astBuilder); - if (!declRefBase) return declRef; + if (!declRefBase) + return declRef; return DeclRef(SubstitutionSet(*this).applyToDeclRef(astBuilder, declRef.declRefBase)); } // Apply substitutions to this declaration reference template -DeclRef DeclRef::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) const +DeclRef DeclRef::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) + const { SLANG_UNUSED(astBuilder); - if (!declRefBase) return *this; + if (!declRefBase) + return *this; return DeclRef(declRefBase->substituteImpl(astBuilder, subst, ioDiff)); } diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index 7edbe750a..575c7268b 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -1,16 +1,16 @@ // slang-ast-builder.cpp #include "slang-ast-builder.h" -#include #include "slang-compiler.h" -namespace Slang { +#include + +namespace Slang +{ // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! SharedASTBuilder !!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -SharedASTBuilder::SharedASTBuilder() -{ -} +SharedASTBuilder::SharedASTBuilder() {} void SharedASTBuilder::init(Session* session) { @@ -90,7 +90,8 @@ Type* SharedASTBuilder::getNativeStringType() if (!m_nativeStringType) { auto nativeStringTypeDecl = findMagicDecl("NativeStringType"); - m_nativeStringType = DeclRefType::create(m_astBuilder, makeDeclRef(nativeStringTypeDecl)); + m_nativeStringType = + DeclRefType::create(m_astBuilder, makeDeclRef(nativeStringTypeDecl)); } return m_nativeStringType; } @@ -190,7 +191,9 @@ void SharedASTBuilder::registerBuiltinDecl(Decl* decl, BuiltinTypeModifier* modi m_builtinTypes[Index(modifier->tag)] = type; } -void SharedASTBuilder::registerBuiltinRequirementDecl(Decl* decl, BuiltinRequirementModifier* modifier) +void SharedASTBuilder::registerBuiltinRequirementDecl( + Decl* decl, + BuiltinRequirementModifier* modifier) { m_builtinRequirementDecls[modifier->kind] = decl; } @@ -221,11 +224,11 @@ Decl* SharedASTBuilder::tryFindMagicDecl(const String& name) // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ASTBuilder !!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -ASTBuilder::ASTBuilder(SharedASTBuilder* sharedASTBuilder, const String& name): - m_sharedASTBuilder(sharedASTBuilder), - m_name(name), - m_id(sharedASTBuilder->m_id++), - m_arena(2097152) +ASTBuilder::ASTBuilder(SharedASTBuilder* sharedASTBuilder, const String& name) + : m_sharedASTBuilder(sharedASTBuilder) + , m_name(name) + , m_id(sharedASTBuilder->m_id++) + , m_arena(2097152) { SLANG_ASSERT(sharedASTBuilder); // Copy Val deduplication map over so we don't create duplicate Vals that are already @@ -233,10 +236,8 @@ ASTBuilder::ASTBuilder(SharedASTBuilder* sharedASTBuilder, const String& name): m_cachedNodes = sharedASTBuilder->getInnerASTBuilder()->m_cachedNodes; } -ASTBuilder::ASTBuilder(): - m_sharedASTBuilder(nullptr), - m_id(-1), - m_arena(2097152) +ASTBuilder::ASTBuilder() + : m_sharedASTBuilder(nullptr), m_id(-1), m_arena(2097152) { m_name = "SharedASTBuilder::m_astBuilder"; } @@ -265,7 +266,7 @@ void ASTBuilder::incrementEpoch() NodeBase* ASTBuilder::createByNodeType(ASTNodeType nodeType) { const ReflectClassInfo* info = ASTClassInfo::getInfo(nodeType); - + auto createFunc = info->m_createFunc; SLANG_ASSERT(createFunc); if (!createFunc) @@ -327,9 +328,12 @@ PtrTypeBase* ASTBuilder::getPtrType(Type* valueType, char const* ptrTypeName) return as(getSpecializedBuiltinType(valueType, ptrTypeName)); } -PtrTypeBase* ASTBuilder::getPtrType(Type* valueType, AddressSpace addrSpace, char const* ptrTypeName) +PtrTypeBase* ASTBuilder::getPtrType( + Type* valueType, + AddressSpace addrSpace, + char const* ptrTypeName) { - Val* args[] = { valueType, getIntVal(getUInt64Type(), (IntegerLiteralValue)addrSpace) }; + Val* args[] = {valueType, getIntVal(getUInt64Type(), (IntegerLiteralValue)addrSpace)}; return as(getSpecializedBuiltinType(makeArrayView(args), ptrTypeName)); } @@ -350,7 +354,8 @@ ArrayExpressionType* ASTBuilder::getArrayType(Type* elementType, IntVal* element } } Val* args[] = {elementType, elementCount}; - return as(getSpecializedBuiltinType(makeArrayView(args), "ArrayExpressionType")); + return as( + getSpecializedBuiltinType(makeArrayView(args), "ArrayExpressionType")); } ConstantBufferType* ASTBuilder::getConstantBufferType(Type* elementType) @@ -365,12 +370,14 @@ ParameterBlockType* ASTBuilder::getParameterBlockType(Type* elementType) HLSLStructuredBufferType* ASTBuilder::getStructuredBufferType(Type* elementType) { - return as(getSpecializedBuiltinType(elementType, "HLSLStructuredBufferType")); + return as( + getSpecializedBuiltinType(elementType, "HLSLStructuredBufferType")); } HLSLRWStructuredBufferType* ASTBuilder::getRWStructuredBufferType(Type* elementType) { - return as(getSpecializedBuiltinType(elementType, "HLSLRWStructuredBufferType")); + return as( + getSpecializedBuiltinType(elementType, "HLSLRWStructuredBufferType")); } SamplerStateType* ASTBuilder::getSamplerStateType() @@ -378,20 +385,23 @@ SamplerStateType* ASTBuilder::getSamplerStateType() return as(getSpecializedBuiltinType(nullptr, "HLSLStructuredBufferType")); } -VectorExpressionType* ASTBuilder::getVectorType( - Type* elementType, - IntVal* elementCount) +VectorExpressionType* ASTBuilder::getVectorType(Type* elementType, IntVal* elementCount) { // Canonicalize constant elementCount to int. if (auto elementCountConstantInt = as(elementCount)) { elementCount = getIntVal(getIntType(), elementCountConstantInt->getValue()); } - Val* args[] = { elementType, elementCount }; - return as(getSpecializedBuiltinType(makeArrayView(args), "VectorExpressionType")); + Val* args[] = {elementType, elementCount}; + return as( + getSpecializedBuiltinType(makeArrayView(args), "VectorExpressionType")); } -MatrixExpressionType* ASTBuilder::getMatrixType(Type* elementType, IntVal* rowCount, IntVal* colCount, IntVal* layout) +MatrixExpressionType* ASTBuilder::getMatrixType( + Type* elementType, + IntVal* rowCount, + IntVal* colCount, + IntVal* layout) { // Canonicalize constant size arguments to int. if (auto rowCountConstantInt = as(rowCount)) @@ -402,35 +412,38 @@ MatrixExpressionType* ASTBuilder::getMatrixType(Type* elementType, IntVal* rowCo { colCount = getIntVal(getIntType(), colCountConstantInt->getValue()); } - Val* args[] = { elementType, rowCount, colCount, layout }; - return as(getSpecializedBuiltinType(makeArrayView(args), "MatrixExpressionType")); + Val* args[] = {elementType, rowCount, colCount, layout}; + return as( + getSpecializedBuiltinType(makeArrayView(args), "MatrixExpressionType")); } -DifferentialPairType* ASTBuilder::getDifferentialPairType( - Type* valueType, - Witness* diffTypeWitness) +DifferentialPairType* ASTBuilder::getDifferentialPairType(Type* valueType, Witness* diffTypeWitness) { - Val* args[] = { valueType, diffTypeWitness }; - return as(getSpecializedBuiltinType(makeArrayView(args), "DifferentialPairType")); + Val* args[] = {valueType, diffTypeWitness}; + return as( + getSpecializedBuiltinType(makeArrayView(args), "DifferentialPairType")); } DifferentialPtrPairType* ASTBuilder::getDifferentialPtrPairType( Type* valueType, Witness* diffRefTypeWitness) { - Val* args[] = { valueType, diffRefTypeWitness }; - return as(getSpecializedBuiltinType(makeArrayView(args), "DifferentialPtrPairType")); + Val* args[] = {valueType, diffRefTypeWitness}; + return as( + getSpecializedBuiltinType(makeArrayView(args), "DifferentialPtrPairType")); } DeclRef ASTBuilder::getDifferentiableInterfaceDecl() { - DeclRef declRef = DeclRef(getBuiltinDeclRef("DifferentiableType", nullptr)); + DeclRef declRef = + DeclRef(getBuiltinDeclRef("DifferentiableType", nullptr)); return declRef; } DeclRef ASTBuilder::getDifferentiableRefInterfaceDecl() { - DeclRef declRef = DeclRef(getBuiltinDeclRef("DifferentiablePtrType", nullptr)); + DeclRef declRef = + DeclRef(getBuiltinDeclRef("DifferentiablePtrType", nullptr)); return declRef; } @@ -441,12 +454,15 @@ bool ASTBuilder::isDifferentiableInterfaceAvailable() DeclRef ASTBuilder::getDefaultInitializableTypeInterfaceDecl() { - DeclRef declRef = DeclRef(getBuiltinDeclRef("DefaultInitializableType", nullptr)); + DeclRef declRef = + DeclRef(getBuiltinDeclRef("DefaultInitializableType", nullptr)); return declRef; } Type* ASTBuilder::getDefaultInitializableType() { - return DeclRefType::create(m_sharedASTBuilder->m_astBuilder, getDefaultInitializableTypeInterfaceDecl()); + return DeclRefType::create( + m_sharedASTBuilder->m_astBuilder, + getDefaultInitializableTypeInterfaceDecl()); } MeshOutputType* ASTBuilder::getMeshOutputTypeFromModifier( @@ -458,13 +474,13 @@ MeshOutputType* ASTBuilder::getMeshOutputTypeFromModifier( SLANG_ASSERT(elementType); SLANG_ASSERT(maxElementCount); - const char* declName - = as(modifier) ? "VerticesType" - : as(modifier) ? "IndicesType" - : as(modifier) ? "PrimitivesType" - : (SLANG_UNEXPECTED("Unhandled mesh output modifier"), nullptr); + const char* declName = as(modifier) ? "VerticesType" + : as(modifier) ? "IndicesType" + : as(modifier) + ? "PrimitivesType" + : (SLANG_UNEXPECTED("Unhandled mesh output modifier"), nullptr); - Val* args[] = { elementType, maxElementCount }; + Val* args[] = {elementType, maxElementCount}; return as(getSpecializedBuiltinType(makeArrayView(args), declName)); } @@ -483,7 +499,8 @@ DeclRef ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Va auto decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName); if (auto genericDecl = as(decl)) { - auto declRef = getGenericAppDeclRef(makeDeclRef(genericDecl), makeConstArrayViewSingle(genericArg)); + auto declRef = + getGenericAppDeclRef(makeDeclRef(genericDecl), makeConstArrayViewSingle(genericArg)); return declRef; } else @@ -493,7 +510,9 @@ DeclRef ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Va return makeDeclRef(decl); } -DeclRef ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, ArrayView genericArgs) +DeclRef ASTBuilder::getBuiltinDeclRef( + const char* builtinMagicTypeName, + ArrayView genericArgs) { auto decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName); if (auto genericDecl = as(decl)) @@ -544,8 +563,9 @@ FuncType* ASTBuilder::getFuncType(ArrayView parameters, Type* result, Typ TupleType* ASTBuilder::getTupleType(ArrayView types) { - // The canonical form of a tuple type is always a DeclRefType(GenAppDeclRef(TupleDecl, ConcreteTypePack(types...))). - // If `types` is already a single ConcreteTypePack, then we can use that directly. + // The canonical form of a tuple type is always a DeclRefType(GenAppDeclRef(TupleDecl, + // ConcreteTypePack(types...))). If `types` is already a single ConcreteTypePack, then we can + // use that directly. if (types.getCount() == 1) { if (isTypePack(types[0])) @@ -572,7 +592,8 @@ Type* ASTBuilder::getEachType(Type* baseType) return expandType->getPatternType(); } - // each Tuple ==> each X, because we know that Tuple type must be in the form of Tuple>. + // each Tuple ==> each X, because we know that Tuple type must be in the form of + // Tuple>. if (auto tupleType = as(baseType)) { return getEachType(tupleType->getTypePack()); @@ -613,20 +634,23 @@ ConcreteTypePack* ASTBuilder::getTypePack(ArrayView types) return getOrCreate(flattenedTypes.getArrayView().arrayView); } -TypeEqualityWitness* ASTBuilder::getTypeEqualityWitness( - Type* type) +TypeEqualityWitness* ASTBuilder::getTypeEqualityWitness(Type* type) { return getOrCreate(type, type); } TypePackSubtypeWitness* ASTBuilder::getSubtypeWitnessPack( - Type* subType, Type* superType, ArrayView witnesses) + Type* subType, + Type* superType, + ArrayView witnesses) { return getOrCreate(subType, superType, witnesses); } SubtypeWitness* ASTBuilder::getExpandSubtypeWitness( - Type* subType, Type* superType, SubtypeWitness* patternWitness) + Type* subType, + Type* superType, + SubtypeWitness* patternWitness) { if (auto eachWitness = as(patternWitness)) return eachWitness->getPatternTypeWitness(); @@ -634,7 +658,9 @@ SubtypeWitness* ASTBuilder::getExpandSubtypeWitness( } SubtypeWitness* ASTBuilder::getEachSubtypeWitness( - Type* subType, Type* superType, SubtypeWitness* patternWitness) + Type* subType, + Type* superType, + SubtypeWitness* patternWitness) { if (auto expandWitness = as(patternWitness)) return expandWitness->getPatternTypeWitness(); @@ -642,12 +668,11 @@ SubtypeWitness* ASTBuilder::getEachSubtypeWitness( } DeclaredSubtypeWitness* ASTBuilder::getDeclaredSubtypeWitness( - Type* subType, - Type* superType, - DeclRef const& declRef) + Type* subType, + Type* superType, + DeclRef const& declRef) { - auto witness = getOrCreate( - subType, superType, declRef.declRefBase); + auto witness = getOrCreate(subType, superType, declRef.declRefBase); return witness; } @@ -666,7 +691,7 @@ top: // // If `a == b`, then the `b <: c` witness is also a witness of `a <: c`. // - if(as(aIsSubtypeOfBWitness)) + if (as(aIsSubtypeOfBWitness)) { return bIsSubtypeOfCWitness; } @@ -694,9 +719,8 @@ top: // We can recursively call this operation to produce a witness that // `a <: x`, based on the witnesses we already have for `a <: b` and `b <: x`: // - auto aIsSubtypeOfXWitness = getTransitiveSubtypeWitness( - aIsSubtypeOfBWitness, - bIsSubtypeOfXWitness); + auto aIsSubtypeOfXWitness = + getTransitiveSubtypeWitness(aIsSubtypeOfBWitness, bIsSubtypeOfXWitness); // Now we can perform a "tail recursive" call to this function (via `goto` // to combine the `a <: x` witness with our `x <: c` witness: @@ -714,7 +738,7 @@ top: // and we'd rather form a conjunction witness for `A <: C` // that is of the form `(A <: X)&(A <: Y)`. // - if(auto bIsSubtypeOfXAndY = as(bIsSubtypeOfCWitness)) + if (auto bIsSubtypeOfXAndY = as(bIsSubtypeOfCWitness)) { auto bIsSubtypeOfXWitness = bIsSubtypeOfXAndY->getLeftWitness(); auto bIsSubtypeOfYWitness = bIsSubtypeOfXAndY->getRightWitness(); @@ -730,7 +754,8 @@ top: // `W` is a witness that `B <: X&Y&...` for some conjunction, where `C` // is one component of that conjunction. // - if(auto bIsSubtypeViaExtraction = as(bIsSubtypeOfCWitness)) + if (auto bIsSubtypeViaExtraction = + as(bIsSubtypeOfCWitness)) { // We decompose the witness `extract(i, W)` to get both // the witness `W` that `B <: X&Y&...` as well as the index @@ -761,12 +786,10 @@ top: List newWitnesses; for (Index i = 0; i < witnessPack->getCount(); i++) { - newWitnesses.add(getTransitiveSubtypeWitness(witnessPack->getWitness(i), bIsSubtypeOfCWitness)); + newWitnesses.add( + getTransitiveSubtypeWitness(witnessPack->getWitness(i), bIsSubtypeOfCWitness)); } - return getSubtypeWitnessPack( - aType, - cType, - newWitnesses.getArrayView()); + return getSubtypeWitnessPack(aType, cType, newWitnesses.getArrayView()); } // If left hand is a ExpandSubtypeWitness, then we want to perform the transitive lookup @@ -774,7 +797,9 @@ top: // if (auto expandWitness = as(aIsSubtypeOfBWitness)) { - auto innerTransitiveWitness = getTransitiveSubtypeWitness(expandWitness->getPatternTypeWitness(), bIsSubtypeOfCWitness); + auto innerTransitiveWitness = getTransitiveSubtypeWitness( + expandWitness->getPatternTypeWitness(), + bIsSubtypeOfCWitness); return getExpandSubtypeWitness(expandWitness->getSub(), cType, innerTransitiveWitness); } @@ -787,8 +812,12 @@ top: { if (declRefType->getDeclRef().as()) { - auto newLeftHandWitness = getEachSubtypeWitness(getEachType(declaredWitness->getSub()), declaredWitness->getSup(), declaredWitness); - auto transitiveWitness = getTransitiveSubtypeWitness(newLeftHandWitness, bIsSubtypeOfCWitness); + auto newLeftHandWitness = getEachSubtypeWitness( + getEachType(declaredWitness->getSub()), + declaredWitness->getSup(), + declaredWitness); + auto transitiveWitness = + getTransitiveSubtypeWitness(newLeftHandWitness, bIsSubtypeOfCWitness); return getExpandSubtypeWitness(aType, cType, transitiveWitness); } } @@ -813,10 +842,10 @@ top: } SubtypeWitness* ASTBuilder::getExtractFromConjunctionSubtypeWitness( - Type* subType, - Type* superType, + Type* subType, + Type* superType, SubtypeWitness* conjunctionWitness, - int indexOfSuperTypeInConjunction) + int indexOfSuperTypeInConjunction) { // We are taking a witness `W` for `S <: L&R` and // using it to produce a witness for `S <: L` @@ -845,8 +874,8 @@ SubtypeWitness* ASTBuilder::getExtractFromConjunctionSubtypeWitness( } SubtypeWitness* ASTBuilder::getConjunctionSubtypeWitness( - Type* sub, - Type* lAndR, + Type* sub, + Type* lAndR, SubtypeWitness* subIsLWitness, SubtypeWitness* subIsRWitness) { @@ -858,10 +887,9 @@ SubtypeWitness* ASTBuilder::getConjunctionSubtypeWitness( // auto lExtract = as(subIsLWitness); auto rExtract = as(subIsRWitness); - if(lExtract && rExtract) + if (lExtract && rExtract) { - if (lExtract->getIndexInConjunction() == 0 - && rExtract->getIndexInConjunction() == 1) + if (lExtract->getIndexInConjunction() == 0 && rExtract->getIndexInConjunction() == 1) { auto lInner = lExtract->getConjunctionWitness(); auto rInner = rExtract->getConjunctionWitness(); @@ -883,11 +911,7 @@ SubtypeWitness* ASTBuilder::getConjunctionSubtypeWitness( // witness) deeper, so that we have more chances to expose a // conjunction witness at higher levels. - auto witness = getOrCreate( - sub, - lAndR, - subIsLWitness, - subIsRWitness); + auto witness = getOrCreate(sub, lAndR, subIsLWitness, subIsRWitness); return witness; } diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index a683e523c..cbd4f86c3 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -2,13 +2,13 @@ #ifndef SLANG_AST_BUILDER_H #define SLANG_AST_BUILDER_H -#include - -#include "slang-ast-support-types.h" +#include "../core/slang-memory-arena.h" +#include "../core/slang-type-traits.h" #include "slang-ast-all.h" +#include "slang-ast-support-types.h" #include "slang-ir.h" -#include "../core/slang-type-traits.h" -#include "../core/slang-memory-arena.h" + +#include namespace Slang { @@ -16,27 +16,27 @@ namespace Slang class SharedASTBuilder : public RefObject { friend class ASTBuilder; + public: - void registerBuiltinDecl(Decl* decl, BuiltinTypeModifier* modifier); void registerBuiltinRequirementDecl(Decl* decl, BuiltinRequirementModifier* modifier); void registerMagicDecl(Decl* decl, MagicTypeModifier* modifier); - /// Get the string type + /// Get the string type Type* getStringType(); - /// Get the native string type + /// Get the native string type Type* getNativeStringType(); - /// Get the enum type type + /// Get the enum type type Type* getEnumTypeType(); - /// Get the __Dynamic type + /// Get the __Dynamic type Type* getDynamicType(); - /// Get the NullPtr type + /// Get the NullPtr type Type* getNullPtrType(); - /// Get the NullPtr type + /// Get the NullPtr type Type* getNoneType(); - /// Get the `IDifferentiable` type + /// Get the `IDifferentiable` type Type* getDiffInterfaceType(); Type* getErrorType(); @@ -50,7 +50,7 @@ public: const ReflectClassInfo* findClassInfo(const UnownedStringSlice& slice); SyntaxClass findSyntaxClass(const UnownedStringSlice& slice); - // Look up a magic declaration by its name + // Look up a magic declaration by its name Decl* findMagicDecl(String const& name); Decl* tryFindMagicDecl(String const& name); @@ -60,10 +60,11 @@ public: return m_builtinRequirementDecls.getValue(kind); } - /// A name pool that can be used for lookup for findClassInfo etc. It is the same pool as the Session. + /// A name pool that can be used for lookup for findClassInfo etc. It is the same pool as the + /// Session. NamePool* getNamePool() { return m_namePool; } - /// Must be called before used + /// Must be called before used void init(Session* session); SharedASTBuilder(); @@ -80,6 +81,7 @@ public: } return m_thisTypeName; } + protected: // State shared between ASTBuilders @@ -90,7 +92,7 @@ protected: // The following types are created lazily, such that part of their definition // can be in the core module. - // + // // Note(tfoley): These logically belong to `Type`, // but order-of-declaration stuff makes that tricky // @@ -110,7 +112,7 @@ protected: Dictionary m_sliceToTypeMap; Dictionary m_nameToTypeMap; - + NamePool* m_namePool = nullptr; Name* m_thisTypeName = nullptr; @@ -138,8 +140,10 @@ struct ValKey } bool operator==(ValKey other) const { - if (val == other.val) return true; - if (hashCode != other.hashCode) return false; + if (val == other.val) + return true; + if (hashCode != other.hashCode) + return false; if (val->astNodeType != other.val->astNodeType) return false; if (val->m_operands.getCount() != other.val->m_operands.getCount()) @@ -151,7 +155,8 @@ struct ValKey } bool operator==(const ValNodeDesc& desc) const { - if (hashCode != desc.getHashCode()) return false; + if (hashCode != desc.getHashCode()) + return false; if (val->astNodeType != desc.type) return false; if (val->m_operands.getCount() != desc.operands.getCount()) @@ -169,28 +174,16 @@ template<> struct Hash { using is_transparent = void; - auto operator()(const ValKey& k) const - { - return k.getHashCode(); - } - auto operator()(const ValNodeDesc& k) const - { - return Hash{}(k); - } + auto operator()(const ValKey& k) const { return k.getHashCode(); } + auto operator()(const ValNodeDesc& k) const { return Hash{}(k); } }; // A functor which can compare ValKey for equality with ValNodeDesc struct ValKeyEqual { using is_transparent = void; - bool operator()(const Slang::ValKey& a, const Slang::ValKey& b) const - { - return a == b; - } - bool operator()(const Slang::ValNodeDesc& a, const Slang::ValKey& b) const - { - return b == a; - } + bool operator()(const Slang::ValKey& a, const Slang::ValKey& b) const { return a == b; } + bool operator()(const Slang::ValNodeDesc& a, const Slang::ValKey& b) const { return b == a; } }; class ASTBuilder : public RefObject @@ -198,7 +191,6 @@ class ASTBuilder : public RefObject friend class SharedASTBuilder; public: - Val* _getOrCreateImpl(ValNodeDesc&& desc) { if (auto found = m_cachedNodes.tryGetValue(desc)) @@ -220,7 +212,7 @@ public: Dictionary> m_cachedGenericDefaultArgs; /// Create AST types - template + template T* createImpl() { auto alloced = m_arena.allocate(sizeof(T)); @@ -238,24 +230,27 @@ public: return result; } - template + template T* create() { - static_assert(!IsBaseOf::Value, "ASTBuilder::create cannot be used to create a Val, use getOrCreate instead."); + static_assert( + !IsBaseOf::Value, + "ASTBuilder::create cannot be used to create a Val, use getOrCreate instead."); return createImpl(); } template T* create(TArgs&&... args) { - static_assert(!IsBaseOf::Value, "ASTBuilder::create cannot be used to create a Val, use getOrCreate instead."); + static_assert( + !IsBaseOf::Value, + "ASTBuilder::create cannot be used to create a Val, use getOrCreate instead."); return createImpl(args...); } public: - // For compile time check to see if thing being constructed is an AST type - template + template struct IsValidType { enum @@ -270,8 +265,8 @@ public: MemoryArena& getArena() { return m_arena; } - template - SLANG_FORCE_INLINE T* getOrCreate(TArgs ... args) + template + SLANG_FORCE_INLINE T* getOrCreate(TArgs... args) { SLANG_COMPILE_TIME_ASSERT(IsValidType::Value); ValNodeDesc desc; @@ -309,7 +304,9 @@ public: } template - DeclRef getDirectDeclRef(T* decl, typename std::enable_if_t>* = nullptr) + DeclRef getDirectDeclRef( + T* decl, + typename std::enable_if_t>* = nullptr) { return DeclRef(decl); } @@ -337,7 +334,8 @@ public: // Lookup of x from This is a lookup from w directly. // - Member(Lookup(w, someExtension), x) ==> Lookup(w, X) // Lookup of a decl defined in an extension is to lookup directly. - // - Member(Lookup(w, AssociatedType), TypeConstraintDecl) ==> Lookup(w, TypeConstraintDecl) + // - Member(Lookup(w, AssociatedType), TypeConstraintDecl) ==> Lookup(w, + // TypeConstraintDecl) // Type constraint of an associated type is defined directly in w. auto parentDeclKind = lookupDeclRef->getDecl()->astNodeType; @@ -346,9 +344,12 @@ public: case ASTNodeType::ThisTypeDecl: case ASTNodeType::ExtensionDecl: case ASTNodeType::AssocTypeDecl: - return getLookupDeclRef(lookupDeclRef->getLookupSource(), lookupDeclRef->getWitness(), memberDecl).template as(); - default: - break; + return getLookupDeclRef( + lookupDeclRef->getLookupSource(), + lookupDeclRef->getWitness(), + memberDecl) + .template as(); + default: break; } } else if (auto directDeclRef = as(parent.declRefBase)) @@ -386,11 +387,14 @@ public: // Unwrap any existing type casts. while (auto baseTypeCast = as(base)) base = baseTypeCast->getBase(); - + return getOrCreate(type, base); } - DeclRef getGenericAppDeclRef(DeclRef genericDeclRef, ConstArrayView args, Decl* innerDecl = nullptr) + DeclRef getGenericAppDeclRef( + DeclRef genericDeclRef, + ConstArrayView args, + Decl* innerDecl = nullptr) { if (!innerDecl) innerDecl = genericDeclRef.getDecl()->inner; @@ -398,7 +402,10 @@ public: return getOrCreate(innerDecl, genericDeclRef, args); } - DeclRef getGenericAppDeclRef(DeclRef genericDeclRef, Val::OperandView args, Decl* innerDecl = nullptr) + DeclRef getGenericAppDeclRef( + DeclRef genericDeclRef, + Val::OperandView args, + Decl* innerDecl = nullptr) { if (!innerDecl) innerDecl = genericDeclRef.getDecl()->inner; @@ -419,21 +426,57 @@ public: NodeBase* createByNodeType(ASTNodeType nodeType); - /// Get the built in types - SLANG_FORCE_INLINE Type* getBoolType() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::Bool)]; } - SLANG_FORCE_INLINE Type* getHalfType() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::Half)]; } - SLANG_FORCE_INLINE Type* getFloatType() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::Float)]; } - SLANG_FORCE_INLINE Type* getDoubleType() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::Double)]; } - SLANG_FORCE_INLINE Type* getIntType() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::Int)]; } - SLANG_FORCE_INLINE Type* getInt64Type() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::Int64)]; } - SLANG_FORCE_INLINE Type* getIntPtrType() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::IntPtr)]; } - SLANG_FORCE_INLINE Type* getUIntType() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::UInt)]; } - SLANG_FORCE_INLINE Type* getUInt64Type() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::UInt64)]; } - SLANG_FORCE_INLINE Type* getUIntPtrType() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::UIntPtr)]; } - SLANG_FORCE_INLINE Type* getVoidType() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::Void)]; } - - /// Get a builtin type by the BaseType - SLANG_FORCE_INLINE Type* getBuiltinType(BaseType flavor) { return m_sharedASTBuilder->m_builtinTypes[Index(flavor)]; } + /// Get the built in types + SLANG_FORCE_INLINE Type* getBoolType() + { + return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::Bool)]; + } + SLANG_FORCE_INLINE Type* getHalfType() + { + return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::Half)]; + } + SLANG_FORCE_INLINE Type* getFloatType() + { + return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::Float)]; + } + SLANG_FORCE_INLINE Type* getDoubleType() + { + return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::Double)]; + } + SLANG_FORCE_INLINE Type* getIntType() + { + return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::Int)]; + } + SLANG_FORCE_INLINE Type* getInt64Type() + { + return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::Int64)]; + } + SLANG_FORCE_INLINE Type* getIntPtrType() + { + return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::IntPtr)]; + } + SLANG_FORCE_INLINE Type* getUIntType() + { + return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::UInt)]; + } + SLANG_FORCE_INLINE Type* getUInt64Type() + { + return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::UInt64)]; + } + SLANG_FORCE_INLINE Type* getUIntPtrType() + { + return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::UIntPtr)]; + } + SLANG_FORCE_INLINE Type* getVoidType() + { + return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::Void)]; + } + + /// Get a builtin type by the BaseType + SLANG_FORCE_INLINE Type* getBuiltinType(BaseType flavor) + { + return m_sharedASTBuilder->m_builtinTypes[Index(flavor)]; + } Type* getSpecializedBuiltinType(Type* typeParam, const char* magicTypeName); Type* getSpecializedBuiltinType(ArrayView genericArgs, const char* magicTypeName); @@ -447,27 +490,27 @@ public: Type* getNoneType() { return m_sharedASTBuilder->getNoneType(); } Type* getEnumTypeType() { return m_sharedASTBuilder->getEnumTypeType(); } Type* getDiffInterfaceType() { return m_sharedASTBuilder->getDiffInterfaceType(); } - // Construct the type `Ptr`, where `Ptr` - // is looked up as a builtin type. + // Construct the type `Ptr`, where `Ptr` + // is looked up as a builtin type. PtrType* getPtrType(Type* valueType, AddressSpace addrSpace); - // Construct the type `Out` + // Construct the type `Out` OutType* getOutType(Type* valueType); - // Construct the type `InOut` + // Construct the type `InOut` InOutType* getInOutType(Type* valueType); - // Construct the type `Ref` + // Construct the type `Ref` RefType* getRefType(Type* valueType, AddressSpace addrSpace); - // Construct the type `ConstRef` + // Construct the type `ConstRef` ConstRefType* getConstRefType(Type* valueType); - // Construct the type `Optional` + // Construct the type `Optional` OptionalType* getOptionalType(Type* valueType); - // Construct a pointer type like `Ptr`, but where - // the actual type name for the pointer type is given by `ptrTypeName` + // Construct a pointer type like `Ptr`, but where + // the actual type name for the pointer type is given by `ptrTypeName` PtrTypeBase* getPtrType(Type* valueType, char const* ptrTypeName); PtrTypeBase* getPtrType(Type* valueType, AddressSpace addrSpace, char const* ptrTypeName); @@ -475,7 +518,11 @@ public: VectorExpressionType* getVectorType(Type* elementType, IntVal* elementCount); - MatrixExpressionType* getMatrixType(Type* elementType, IntVal* rowCount, IntVal* colCount, IntVal* layout); + MatrixExpressionType* getMatrixType( + Type* elementType, + IntVal* rowCount, + IntVal* colCount, + IntVal* layout); ConstantBufferType* getConstantBufferType(Type* elementType); @@ -487,10 +534,8 @@ public: SamplerStateType* getSamplerStateType(); - DifferentialPairType* getDifferentialPairType( - Type* valueType, - Witness* diffTypeWitness); - + DifferentialPairType* getDifferentialPairType(Type* valueType, Witness* diffTypeWitness); + DifferentialPtrPairType* getDifferentialPtrPairType( Type* valueType, Witness* diffRefTypeWitness); @@ -537,61 +582,81 @@ public: ConcreteTypePack* getTypePack(ArrayView types); - /// Produce a witness that `T : T` for any type `T` - TypeEqualityWitness* getTypeEqualityWitness( - Type* type); + /// Produce a witness that `T : T` for any type `T` + TypeEqualityWitness* getTypeEqualityWitness(Type* type); DeclaredSubtypeWitness* getDeclaredSubtypeWitness( - Type* subType, - Type* superType, - DeclRef const& declRef); - - TypePackSubtypeWitness* getSubtypeWitnessPack(Type* subType, Type* superType, ArrayView witnesses); - - SubtypeWitness* getExpandSubtypeWitness(Type* subType, Type* superType, SubtypeWitness* patternWitness); - - SubtypeWitness* getEachSubtypeWitness(Type* subType, Type* superType, SubtypeWitness* patternWitness); - - /// Produce a witness that `A <: C` given witnesses that `A <: B` and `B <: C` + Type* subType, + Type* superType, + DeclRef const& declRef); + + TypePackSubtypeWitness* getSubtypeWitnessPack( + Type* subType, + Type* superType, + ArrayView witnesses); + + SubtypeWitness* getExpandSubtypeWitness( + Type* subType, + Type* superType, + SubtypeWitness* patternWitness); + + SubtypeWitness* getEachSubtypeWitness( + Type* subType, + Type* superType, + SubtypeWitness* patternWitness); + + /// Produce a witness that `A <: C` given witnesses that `A <: B` and `B <: C` SubtypeWitness* getTransitiveSubtypeWitness( - SubtypeWitness* aIsSubtypeOfBWitness, - SubtypeWitness* bIsSubtypeOfCWitness); + SubtypeWitness* aIsSubtypeOfBWitness, + SubtypeWitness* bIsSubtypeOfCWitness); - /// Produce a witness that `T <: L` or `T <: R` given `T <: L&R` + /// Produce a witness that `T <: L` or `T <: R` given `T <: L&R` SubtypeWitness* getExtractFromConjunctionSubtypeWitness( - Type* subType, - Type* superType, + Type* subType, + Type* superType, SubtypeWitness* subIsSubtypeOfConjunction, - int indexOfSuperTypeInConjunction); + int indexOfSuperTypeInConjunction); - /// Produce a witnes that `S <: L&R` given witnesses that `S <: L` and `S <: R` + /// Produce a witnes that `S <: L&R` given witnesses that `S <: L` and `S <: R` SubtypeWitness* getConjunctionSubtypeWitness( - Type* sub, - Type* lAndR, + Type* sub, + Type* lAndR, SubtypeWitness* subIsLWitness, SubtypeWitness* subIsRWitness); - /// Helpers to get type info from the SharedASTBuilder - const ReflectClassInfo* findClassInfo(const UnownedStringSlice& slice) { return m_sharedASTBuilder->findClassInfo(slice); } - SyntaxClass findSyntaxClass(const UnownedStringSlice& slice) { return m_sharedASTBuilder->findSyntaxClass(slice); } + /// Helpers to get type info from the SharedASTBuilder + const ReflectClassInfo* findClassInfo(const UnownedStringSlice& slice) + { + return m_sharedASTBuilder->findClassInfo(slice); + } + SyntaxClass findSyntaxClass(const UnownedStringSlice& slice) + { + return m_sharedASTBuilder->findSyntaxClass(slice); + } - const ReflectClassInfo* findClassInfo(Name* name) { return m_sharedASTBuilder->findClassInfo(name); } - SyntaxClass findSyntaxClass(Name* name) { return m_sharedASTBuilder->findSyntaxClass(name); } + const ReflectClassInfo* findClassInfo(Name* name) + { + return m_sharedASTBuilder->findClassInfo(name); + } + SyntaxClass findSyntaxClass(Name* name) + { + return m_sharedASTBuilder->findSyntaxClass(name); + } MemoryArena& getMemoryArena() { return m_arena; } - /// Get the shared AST builder + /// Get the shared AST builder SharedASTBuilder* getSharedASTBuilder() { return m_sharedASTBuilder; } - /// Get the global session + /// Get the global session Session* getGlobalSession() { return m_sharedASTBuilder->m_session; } Index getId() { return m_id; } - /// Ctor - ASTBuilder(SharedASTBuilder* sharedASTBuilder, const String& name); + /// Ctor + ASTBuilder(SharedASTBuilder* sharedASTBuilder, const String& name); - /// Dtor + /// Dtor ~ASTBuilder(); protected: @@ -599,7 +664,7 @@ protected: ASTBuilder(); - template + template SLANG_FORCE_INLINE T* _initAndAdd(T* node) { SLANG_COMPILE_TIME_ASSERT(IsValidType::Value); @@ -623,7 +688,7 @@ protected: String m_name; Index m_id; - /// List of all nodes that require being dtored when ASTBuilder is dtored + /// List of all nodes that require being dtored when ASTBuilder is dtored List m_dtorNodes; SharedASTBuilder* m_sharedASTBuilder; @@ -645,13 +710,11 @@ struct SetASTBuilderContextRAII previousASTBuilder = getCurrentASTBuilder(); setCurrentASTBuilder(astBuilder); } - ~SetASTBuilderContextRAII() - { - setCurrentASTBuilder(previousASTBuilder); - } + ~SetASTBuilderContextRAII() { setCurrentASTBuilder(previousASTBuilder); } }; -#define SLANG_AST_BUILDER_RAII(astBuilder) SetASTBuilderContextRAII _setASTBuilderContextRAII(astBuilder) +#define SLANG_AST_BUILDER_RAII(astBuilder) \ + SetASTBuilderContextRAII _setASTBuilderContextRAII(astBuilder) } // namespace Slang diff --git a/source/slang/slang-ast-decl-ref.cpp b/source/slang/slang-ast-decl-ref.cpp index 6087efb26..b2cc99ae4 100644 --- a/source/slang/slang-ast-decl-ref.cpp +++ b/source/slang/slang-ast-decl-ref.cpp @@ -1,13 +1,16 @@ #include "slang-ast-builder.h" #include "slang-ast-reflect.h" -#include "slang-generated-ast.h" -#include "slang-generated-ast-macro.h" #include "slang-check-impl.h" +#include "slang-generated-ast-macro.h" +#include "slang-generated-ast.h" namespace Slang { -DeclRefBase* DirectDeclRef::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +DeclRefBase* DirectDeclRef::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { SLANG_UNUSED(astBuilder); SLANG_UNUSED(subst); @@ -53,7 +56,10 @@ DeclRefBase* _resolveAsDeclRef(DeclRefBase* declRefToResolve) return declRefToResolve; } -DeclRefBase* MemberDeclRef::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +DeclRefBase* MemberDeclRef::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; auto substParent = getParentOperand()->substituteImpl(astBuilder, subst, &diff); @@ -101,15 +107,18 @@ Decl* LookupDeclRef::getSupDecl() SLANG_UNEXPECTED("Invalid lookup declref"); } -DeclRefBase* LookupDeclRef::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +DeclRefBase* LookupDeclRef::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; - + auto substWitness = as(getWitness()->substituteImpl(astBuilder, subst, &diff)); if (diff == 0) return this; (*ioDiff)++; - + auto substSource = as(getLookupSource()->substituteImpl(astBuilder, subst, &diff)); SLANG_ASSERT(substSource); @@ -162,7 +171,8 @@ Val* LookupDeclRef::tryResolve(SubtypeWitness* newWitness, Type* newLookupSource { auto astBuilder = getCurrentASTBuilder(); Decl* requirementKey = getDecl(); - RequirementWitness requirementWitness = tryLookUpRequirementWitness(astBuilder, newWitness, requirementKey); + RequirementWitness requirementWitness = + tryLookUpRequirementWitness(astBuilder, newWitness, requirementKey); switch (requirementWitness.getFlavor()) { default: @@ -170,11 +180,11 @@ Val* LookupDeclRef::tryResolve(SubtypeWitness* newWitness, Type* newLookupSource break; case RequirementWitness::Flavor::val: - { - auto satisfyingVal = requirementWitness.getVal()->resolve(); - return satisfyingVal; - } - break; + { + auto satisfyingVal = requirementWitness.getVal()->resolve(); + return satisfyingVal; + } + break; } // Hard code implementation of T.Differential.Differential == T.Differential rule. @@ -196,7 +206,8 @@ Val* LookupDeclRef::tryResolve(SubtypeWitness* newWitness, Type* newLookupSource auto innerDeclRefType = as(newLookupSource); if (!innerDeclRefType) return nullptr; - auto innerBuiltinReq = innerDeclRefType->getDeclRef().getDecl()->findModifier(); + auto innerBuiltinReq = + innerDeclRefType->getDeclRef().getDecl()->findModifier(); if (!innerBuiltinReq) return nullptr; if (innerBuiltinReq->kind != BuiltinRequirementKind::DifferentialType) @@ -212,7 +223,10 @@ Val* LookupDeclRef::tryResolve(SubtypeWitness* newWitness, Type* newLookupSource return innerDeclRefType; } -DeclRefBase* GenericAppDeclRef::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +DeclRefBase* GenericAppDeclRef::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; auto substGenericDeclRef = getGenericDeclRef()->substituteImpl(astBuilder, subst, &diff); @@ -224,10 +238,16 @@ DeclRefBase* GenericAppDeclRef::_substituteImplOverride(ASTBuilder* astBuilder, if (diff == 0) return this; (*ioDiff)++; - return astBuilder->getGenericAppDeclRef(substGenericDeclRef, substArgs.getArrayView(), getDecl()); + return astBuilder->getGenericAppDeclRef( + substGenericDeclRef, + substArgs.getArrayView(), + getDecl()); } -GenericDecl* GenericAppDeclRef::getGenericDecl() { return as(getGenericDeclRef()->getDecl()); } +GenericDecl* GenericAppDeclRef::getGenericDecl() +{ + return as(getGenericDeclRef()->getDecl()); +} void GenericAppDeclRef::_toTextOverride(StringBuilder& out) @@ -243,7 +263,8 @@ void GenericAppDeclRef::_toTextOverride(StringBuilder& out) Index argCount = args.getCount(); for (Index aa = 0; aa < Math::Min(paramCount, argCount); ++aa) { - if (aa != 0) out << ", "; + if (aa != 0) + out << ", "; args[aa]->toText(out); } out << ">"; @@ -266,7 +287,10 @@ Val* GenericAppDeclRef::_resolveImplOverride() diff = true; } if (diff) - resolvedVal = astBuilder->getGenericAppDeclRef(resolvedGenericDeclRef, resolvedArgs.getArrayView(), getDecl()); + resolvedVal = astBuilder->getGenericAppDeclRef( + resolvedGenericDeclRef, + resolvedArgs.getArrayView(), + getDecl()); return resolvedVal; } @@ -282,8 +306,14 @@ DeclRefBase* DeclRefBase::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet SLANG_AST_NODE_VIRTUAL_CALL(DeclRefBase, substituteImpl, (astBuilder, subst, ioDiff)); } -DeclRefBase* DeclRefBase::getBase() { SLANG_AST_NODE_VIRTUAL_CALL(DeclRefBase, getBase, ()); } -void DeclRefBase::toText(StringBuilder& out) { SLANG_AST_NODE_VIRTUAL_CALL(DeclRefBase, toText, (out)); } +DeclRefBase* DeclRefBase::getBase() +{ + SLANG_AST_NODE_VIRTUAL_CALL(DeclRefBase, getBase, ()); +} +void DeclRefBase::toText(StringBuilder& out) +{ + SLANG_AST_NODE_VIRTUAL_CALL(DeclRefBase, toText, (out)); +} Name* DeclRefBase::getName() const { @@ -371,7 +401,8 @@ SubstExpr applySubstitutionToExpr(SubstitutionSet substSet, Expr* expr) } -DeclRefBase* SubstitutionSet::applyToDeclRef(ASTBuilder* astBuilder, DeclRefBase* otherDeclRef) const +DeclRefBase* SubstitutionSet::applyToDeclRef(ASTBuilder* astBuilder, DeclRefBase* otherDeclRef) + const { int diff = 0; return otherDeclRef->substituteImpl(astBuilder, *this, &diff); @@ -430,7 +461,7 @@ GenericAppDeclRef* SubstitutionSet::findGenericAppDeclRef() const DeclRef createDefaultSubstitutionsIfNeeded( ASTBuilder* astBuilder, SemanticsVisitor* semantics, - DeclRef declRef) + DeclRef declRef) { if (declRef.as()) return declRef; @@ -464,7 +495,8 @@ DeclRef createDefaultSubstitutionsIfNeeded( { parentDeclRef = astBuilder->getDirectDeclRef(current); } - parentDeclRef = astBuilder->getGenericAppDeclRef(parentDeclRef.as(), args.getArrayView()); + parentDeclRef = + astBuilder->getGenericAppDeclRef(parentDeclRef.as(), args.getArrayView()); } if (!parentDeclRef) return declRef; @@ -472,4 +504,4 @@ DeclRef createDefaultSubstitutionsIfNeeded( return parentDeclRef; return astBuilder->getMemberDeclRef(parentDeclRef, declRef.getDecl()); } -} +} // namespace Slang diff --git a/source/slang/slang-ast-decl.cpp b/source/slang/slang-ast-decl.cpp index cd9c43410..c0d0e9242 100644 --- a/source/slang/slang-ast-decl.cpp +++ b/source/slang/slang-ast-decl.cpp @@ -1,12 +1,14 @@ // slang-ast-decl.cpp +#include "slang-ast-decl.h" + #include "slang-ast-builder.h" +#include "slang-generated-ast-macro.h" #include "slang-syntax.h" -#include -#include "slang-generated-ast-macro.h" -#include "slang-ast-decl.h" +#include -namespace Slang { +namespace Slang +{ const TypeExp& TypeConstraintDecl::getSup() const { @@ -16,7 +18,7 @@ const TypeExp& TypeConstraintDecl::getSup() const const TypeExp& TypeConstraintDecl::_getSupOverride() const { SLANG_UNEXPECTED("TypeConstraintDecl::_getSupOverride not overridden"); - //return TypeExp::empty; + // return TypeExp::empty; } InterfaceDecl* findParentInterfaceDecl(Decl* decl) @@ -106,14 +108,14 @@ void ContainerDecl::buildMemberDictionary() bool isLocalVar(const Decl* decl) { const auto varDecl = as(decl); - if(!varDecl) + if (!varDecl) return false; const Decl* pp = varDecl->parentDecl; - if(as(pp)) + if (as(pp)) return true; - while(auto genericDecl = as(pp)) + while (auto genericDecl = as(pp)) pp = genericDecl->inner; - if(as(pp)) + if (as(pp)) return true; return false; diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index 36aa3a313..400a9635a 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -4,12 +4,13 @@ #include "slang-ast-base.h" -namespace Slang { +namespace Slang +{ // Syntax class definitions for declarations. // A group of declarations that should be treated as a unit -class DeclGroup: public DeclBase +class DeclGroup : public DeclBase { SLANG_AST_CLASS(DeclGroup) @@ -22,7 +23,7 @@ class UnresolvedDecl : public Decl }; // A "container" decl is a parent to other declarations -class ContainerDecl: public Decl +class ContainerDecl : public Decl { SLANG_ABSTRACT_AST_CLASS(ContainerDecl) @@ -65,14 +66,15 @@ class ContainerDecl: public Decl } } - SLANG_UNREFLECTED // We don't want to reflect the following fields + SLANG_UNREFLECTED // We don't want to reflect the following fields -private: - // Denotes how much of Members has been placed into the dictionary/transparentMembers. - // If this value equals the Members.getCount(), the dictionary is completely full and valid. - // If it's >= 0, then the Members after dictionaryLastCount are all that need to be added. - // If it < 0 it means that the dictionary/transparentMembers is invalid and needs to be recreated. - Index dictionaryLastCount = 0; + private : + // Denotes how much of Members has been placed into the dictionary/transparentMembers. + // If this value equals the Members.getCount(), the dictionary is completely full and valid. + // If it's >= 0, then the Members after dictionaryLastCount are all that need to be added. + // If it < 0 it means that the dictionary/transparentMembers is invalid and needs to be + // recreated. + Index dictionaryLastCount = 0; // Dictionary for looking up members by name. // This is built on demand before performing lookup. @@ -112,14 +114,14 @@ class LetDecl : public VarDecl SLANG_AST_CLASS(LetDecl) }; - // An `AggTypeDeclBase` captures the shared functionality - // between true aggregate type declarations and extension - // declarations: - // - // - Both can contain members (they are `ContainerDecl`s) - // - Both can have declared bases - // - Both expose a `this` variable in their body - // +// An `AggTypeDeclBase` captures the shared functionality +// between true aggregate type declarations and extension +// declarations: +// +// - Both can contain members (they are `ContainerDecl`s) +// - Both can have declared bases +// - Both expose a `this` variable in their body +// class AggTypeDeclBase : public ContainerDecl { SLANG_ABSTRACT_AST_CLASS(AggTypeDeclBase); @@ -142,7 +144,7 @@ enum class TypeTag }; // Declaration of a type that represents some sort of aggregate -class AggTypeDecl : public AggTypeDeclBase +class AggTypeDecl : public AggTypeDeclBase { SLANG_ABSTRACT_AST_CLASS(AggTypeDecl) @@ -156,13 +158,10 @@ class AggTypeDecl : public AggTypeDeclBase void addTag(TypeTag tag); bool hasTag(TypeTag tag); - FilteredMemberList getFields() - { - return getMembersOfType(); - } + FilteredMemberList getFields() { return getMembersOfType(); } }; -class StructDecl: public AggTypeDecl +class StructDecl : public AggTypeDecl { SLANG_AST_CLASS(StructDecl); }; @@ -219,20 +218,21 @@ class ThisTypeDecl : public AggTypeDecl }; // An interface which other types can conform to -class InterfaceDecl : public AggTypeDecl +class InterfaceDecl : public AggTypeDecl { SLANG_AST_CLASS(InterfaceDecl) ThisTypeDecl* getThisTypeDecl(); }; -class TypeConstraintDecl : public Decl +class TypeConstraintDecl : public Decl { SLANG_ABSTRACT_AST_CLASS(TypeConstraintDecl) const TypeExp& getSup() const; // Overrides should be public so base classes can access - // Implement _getSupOverride on derived classes to change behavior of getSup, as if getSup is virtual + // Implement _getSupOverride on derived classes to change behavior of getSup, as if getSup is + // virtual const TypeExp& _getSupOverride() const; }; @@ -311,7 +311,7 @@ class GlobalGenericValueParamDecl : public VarDeclBase }; // A scope for local declarations (e.g., as part of a statement) -class ScopeDecl : public ContainerDecl +class ScopeDecl : public ContainerDecl { SLANG_AST_CLASS(ScopeDecl) }; @@ -322,7 +322,8 @@ class ParamDecl : public VarDeclBase SLANG_AST_CLASS(ParamDecl) }; -// A parameter of a function declared in "modern" types (immutable unless explicitly `out` or `inout`) +// A parameter of a function declared in "modern" types (immutable unless explicitly `out` or +// `inout`) class ModernParamDecl : public ParamDecl { SLANG_AST_CLASS(ModernParamDecl) @@ -333,10 +334,7 @@ class CallableDecl : public ContainerDecl { SLANG_ABSTRACT_AST_CLASS(CallableDecl) - FilteredMemberList getParameters() - { - return getMembersOfType(); - } + FilteredMemberList getParameters() { return getMembersOfType(); } TypeExp returnType; @@ -359,7 +357,8 @@ class CallableDecl : public ContainerDecl CallableDecl* nextDecl = nullptr; }; -// Base class for callable things that may also have a body that is evaluated to produce their result +// Base class for callable things that may also have a body that is evaluated to produce their +// result class FunctionDeclBase : public CallableDecl { SLANG_ABSTRACT_AST_CLASS(FunctionDeclBase) @@ -419,21 +418,21 @@ class NamespaceDeclBase : public ContainerDecl SLANG_AST_CLASS(NamespaceDeclBase) }; - // A `namespace` declaration inside some module, that provides - // a named scope for declarations inside it. - // - // Note: Multiple `namespace` declarations with the same name - // in a given module/file will be collapsed into a single - // `NamespaceDecl` during parsing, so this declaration does - // not directly represent what is present in the input syntax. - // +// A `namespace` declaration inside some module, that provides +// a named scope for declarations inside it. +// +// Note: Multiple `namespace` declarations with the same name +// in a given module/file will be collapsed into a single +// `NamespaceDecl` during parsing, so this declaration does +// not directly represent what is present in the input syntax. +// class NamespaceDecl : public NamespaceDeclBase { SLANG_AST_CLASS(NamespaceDecl) }; - // A "module" of code (essentially, a single translation unit) - // that provides a scope for some number of declarations. +// A "module" of code (essentially, a single translation unit) +// that provides a scope for some number of declarations. class ModuleDecl : public NamespaceDeclBase { SLANG_AST_CLASS(ModuleDecl) @@ -445,29 +444,30 @@ class ModuleDecl : public NamespaceDeclBase // Module* module = nullptr; - /// Map a decl to the list of its associated decls. - /// - /// This mapping is filled in during semantic checking, as the decl declarations get checked or generated. - /// - OrderedDictionary> mapDeclToAssociatedDecls; - - /// Whether the module is defined in legacy language. - /// The legacy Slang language does not have visibility modifiers and everything is treated as - /// `public`. Newer version of the language introduces visibility and makes `internal` as the - /// default. To prevent this from breaking existing code, we need to know whether a module is - /// written in the legacy language. We detect this by checking whether the module has any - /// visibility modifiers, or if the module uses new language constructs, e.g. `module`, `__include`, - /// `__implementing` etc. + /// Map a decl to the list of its associated decls. + /// + /// This mapping is filled in during semantic checking, as the decl declarations get checked or + /// generated. + /// + OrderedDictionary> mapDeclToAssociatedDecls; + + /// Whether the module is defined in legacy language. + /// The legacy Slang language does not have visibility modifiers and everything is treated as + /// `public`. Newer version of the language introduces visibility and makes `internal` as the + /// default. To prevent this from breaking existing code, we need to know whether a module is + /// written in the legacy language. We detect this by checking whether the module has any + /// visibility modifiers, or if the module uses new language constructs, e.g. `module`, + /// `__include`, + /// `__implementing` etc. bool isInLegacyLanguage = true; SLANG_UNREFLECTED - /// Map a type to the list of extensions of that type (if any) declared in this module - /// - /// This mapping is filled in during semantic checking, as `ExtensionDecl`s get checked. - /// + /// Map a type to the list of extensions of that type (if any) declared in this module + /// + /// This mapping is filled in during semantic checking, as `ExtensionDecl`s get checked. + /// Dictionary> mapTypeToCandidateExtensions; - }; // Represents a transparent scope of declarations that are defined in a single source file. @@ -627,8 +627,8 @@ class SyntaxDecl : public Decl SLANG_UNREFLECTED // Callback to invoke in order to parse syntax with this keyword. - SyntaxParseCallback parseCallback = nullptr; - void* parseUserData = nullptr; + SyntaxParseCallback parseCallback = nullptr; + void* parseUserData = nullptr; }; // A declaration of an attribute to be used with `[name(...)]` syntax. @@ -640,15 +640,14 @@ class AttributeDecl : public ContainerDecl SyntaxClass syntaxClass; }; -// A synthesized decl used as a placeholder for a differentiable function requirement. This decl will -// be a child of interface decl. -// This allows us to form an interface requirement key for the derivative of an interface function. -// The synthesized `DerivativeRequirementDecl` will be a child of the original function requirement -// decl after an interface type is checked. +// A synthesized decl used as a placeholder for a differentiable function requirement. This decl +// will be a child of interface decl. This allows us to form an interface requirement key for the +// derivative of an interface function. The synthesized `DerivativeRequirementDecl` will be a child +// of the original function requirement decl after an interface type is checked. class DerivativeRequirementDecl : public FunctionDeclBase { SLANG_AST_CLASS(DerivativeRequirementDecl) - + // The original requirement decl. Decl* originalRequirementDecl = nullptr; @@ -656,8 +655,8 @@ class DerivativeRequirementDecl : public FunctionDeclBase Type* diffThisType; }; -// A reference to a synthesized decl representing a differentiable function requirement, this decl will -// be a child in the orignal function. +// A reference to a synthesized decl representing a differentiable function requirement, this decl +// will be a child in the orignal function. class DerivativeRequirementReferenceDecl : public FunctionDeclBase { SLANG_AST_CLASS(DerivativeRequirementReferenceDecl) @@ -681,7 +680,10 @@ bool isLocalVar(const Decl* decl); // Add a sibling lookup scope for `dest` to refer to `source`. -void addSiblingScopeForContainerDecl(ASTBuilder* builder, ContainerDecl* dest, ContainerDecl* source); +void addSiblingScopeForContainerDecl( + ASTBuilder* builder, + ContainerDecl* dest, + ContainerDecl* source); void addSiblingScopeForContainerDecl(ASTBuilder* builder, Scope* destScope, ContainerDecl* source); } // namespace Slang diff --git a/source/slang/slang-ast-dump.cpp b/source/slang/slang-ast-dump.cpp index b3a554e6d..dd80b1d4b 100644 --- a/source/slang/slang-ast-dump.cpp +++ b/source/slang/slang-ast-dump.cpp @@ -1,15 +1,15 @@ // slang-ast-dump.cpp #include "slang-ast-dump.h" -#include -#include - -#include "slang-compiler.h" #include "../core/slang-string.h" - +#include "slang-compiler.h" #include "slang-generated-ast-macro.h" -namespace Slang { +#include +#include + +namespace Slang +{ struct ASTDumpContext @@ -23,8 +23,8 @@ struct ASTDumpContext struct ScopeWrite { - ScopeWrite(ASTDumpContext* context): - m_context(context) + ScopeWrite(ASTDumpContext* context) + : m_context(context) { if (m_context->m_scopeWriteCount == 0) { @@ -148,7 +148,7 @@ struct ASTDumpContext cur = cur->parent; } - for (Index i = scopes.getCount() - 1; i >= 0 ; --i) + for (Index i = scopes.getCount() - 1; i >= 0; --i) { buf << "::"; const Scope* curScope = scopes[i]; @@ -182,7 +182,7 @@ struct ASTDumpContext } } - template + template void dump(const List& list) { m_writer->emit(" { \n"); @@ -203,7 +203,7 @@ struct ASTDumpContext m_writer->emit("}"); } - template + template void dump(const ShortList& list) { m_writer->emit(" { \n"); @@ -241,14 +241,12 @@ struct ASTDumpContext if (manager && sourceLoc.isValid()) { HumaneSourceLoc humaneLoc = manager->getHumaneLoc(sourceLoc); - ScopeWrite(this).getBuf() << " " << humaneLoc.pathInfo.foundPath << ":" << humaneLoc.line; + ScopeWrite(this).getBuf() + << " " << humaneLoc.pathInfo.foundPath << ":" << humaneLoc.line; } } - static char _getHexDigit(UInt32 v) - { - return (v < 10) ? char(v + '0') : char('a' + v - 10); - } + static char _getHexDigit(UInt32 v) { return (v < 10) ? char(v + '0') : char('a' + v - 10); } static bool _charNeedsEscaping(char c) { @@ -258,7 +256,7 @@ struct ASTDumpContext void dump(const UnownedStringSlice& slice) { - + ScopeWrite scope(this); auto& buf = scope.getBuf(); buf.appendChar('\"'); @@ -270,12 +268,12 @@ struct ASTDumpContext } else { - buf << "\\0x" << _getHexDigit(UInt32(c) >> 4) << _getHexDigit(c & 0xf); + buf << "\\0x" << _getHexDigit(UInt32(c) >> 4) << _getHexDigit(c & 0xf); } } buf.appendChar('\"'); } - + void dump(const Token& token) { ScopeWrite(this).getBuf() << " { " << TokenTypeToString(token.type) << ", "; @@ -302,35 +300,18 @@ struct ASTDumpContext return m_objects.getCount() - 1; } - void dump(uint32_t v) - { - m_writer->emit((uint64_t)v); - } - void dump(uint64_t v) - { - m_writer->emit(v); - } - void dump(int32_t v) - { - m_writer->emit(v); - } - void dump(FloatingPointLiteralValue v) - { - m_writer->emit(v); - } + void dump(uint32_t v) { m_writer->emit((uint64_t)v); } + void dump(uint64_t v) { m_writer->emit(v); } + void dump(int32_t v) { m_writer->emit(v); } + void dump(FloatingPointLiteralValue v) { m_writer->emit(v); } - void dump(IntegerLiteralValue v) - { - m_writer->emit(v); - } - void dump(CapabilityName v) - { - m_writer->emit(capabilityNameToString(v)); - } + void dump(IntegerLiteralValue v) { m_writer->emit(v); } + void dump(CapabilityName v) { m_writer->emit(capabilityNameToString(v)); } void dump(const SemanticVersion& version) { - ScopeWrite(this).getBuf() << UInt(version.m_major) << "." << UInt(version.m_minor) << "." << UInt(version.m_patch); + ScopeWrite(this).getBuf() << UInt(version.m_major) << "." << UInt(version.m_minor) << "." + << UInt(version.m_patch); } void dump(const NameLoc& nameLoc) { @@ -347,38 +328,14 @@ struct ASTDumpContext dump(nameLoc.loc); m_writer->emit(" }"); } - void dump(BaseType baseType) - { - m_writer->emit(BaseTypeInfo::asText(baseType)); - } - void dump(Stage stage) - { - m_writer->emit(getStageName(stage)); - } - void dump(ImageFormat imageFormat) - { - m_writer->emit(getGLSLNameForImageFormat(imageFormat)); - } - void dump(TryClauseType clauseType) - { - m_writer->emit(getTryClauseTypeName(clauseType)); - } - void dump(BuiltinRequirementKind kind) - { - m_writer->emit((int)kind); - } - void dump(MarkupVisibility v) - { - m_writer->emit((int)v); - } - void dump(TypeTag tag) - { - m_writer->emit((int)tag); - } - void dump(const String& string) - { - dump(string.getUnownedSlice()); - } + void dump(BaseType baseType) { m_writer->emit(BaseTypeInfo::asText(baseType)); } + void dump(Stage stage) { m_writer->emit(getStageName(stage)); } + void dump(ImageFormat imageFormat) { m_writer->emit(getGLSLNameForImageFormat(imageFormat)); } + void dump(TryClauseType clauseType) { m_writer->emit(getTryClauseTypeName(clauseType)); } + void dump(BuiltinRequirementKind kind) { m_writer->emit((int)kind); } + void dump(MarkupVisibility v) { m_writer->emit((int)v); } + void dump(TypeTag tag) { m_writer->emit((int)tag); } + void dump(const String& string) { dump(string.getUnownedSlice()); } void dump(const DiagnosticInfo* info) { @@ -406,13 +363,13 @@ struct ASTDumpContext m_writer->emit("}"); } - template + template void dump(const SyntaxClass& cls) { m_writer->emit(cls.classInfo->m_name); } - template + template void dump(const Dictionary& dict) { m_writer->emit(" { \n"); @@ -431,7 +388,7 @@ struct ASTDumpContext m_writer->emit("}"); } - template + template void dump(const OrderedDictionary& dict) { m_writer->emit(" { \n"); @@ -463,8 +420,9 @@ struct ASTDumpContext void dump(const DeclCheckStateExt& extState) { auto state = extState.getState(); - - ScopeWrite(this).getBuf() << "DeclCheckStateExt{" << extState.isBeingChecked() << ", " << Index(state) << "}"; + + ScopeWrite(this).getBuf() << "DeclCheckStateExt{" << extState.isBeingChecked() << ", " + << Index(state) << "}"; } void dump(FeedbackType::Kind kind) @@ -473,8 +431,8 @@ struct ASTDumpContext const char* name = nullptr; switch (kind) { - case FeedbackType::Kind::MinMip: name = "MinMip"; break; - case FeedbackType::Kind::MipRegionUsed: name = "MipRegionUsed"; break; + case FeedbackType::Kind::MinMip: name = "MinMip"; break; + case FeedbackType::Kind::MipRegionUsed: name = "MipRegionUsed"; break; } m_buf << "FeedbackType::Kind{" << name << "}"; @@ -486,9 +444,9 @@ struct ASTDumpContext { switch (flavor) { - case SamplerStateFlavor::SamplerState: m_writer->emit("sampler"); break; - case SamplerStateFlavor::SamplerComparisonState: m_writer->emit("samplerComparison"); break; - default: m_writer->emit("unknown"); break; + case SamplerStateFlavor::SamplerState: m_writer->emit("sampler"); break; + case SamplerStateFlavor::SamplerComparisonState: m_writer->emit("samplerComparison"); break; + default: m_writer->emit("unknown"); break; } } @@ -505,9 +463,12 @@ struct ASTDumpContext dump(qualType.type); } - void dump(SyntaxParseCallback callback) { _dumpPtr(UnownedStringSlice::fromLiteral("SyntaxParseCallback"), (const void*)callback); } + void dump(SyntaxParseCallback callback) + { + _dumpPtr(UnownedStringSlice::fromLiteral("SyntaxParseCallback"), (const void*)callback); + } - template + template void dump(const T (&in)[SIZE]) { m_writer->emit(" { \n"); @@ -565,15 +526,9 @@ struct ASTDumpContext m_writer->dedent(); m_writer->emit("}"); } - void dump(const ExpandedSpecializationArg& arg) - { - dump(arg.witness); - } + void dump(const ExpandedSpecializationArg& arg) { dump(arg.witness); } - void dump(const TransparentMemberInfo& memInfo) - { - dump(memInfo.decl); - } + void dump(const TransparentMemberInfo& memInfo) { dump(memInfo.decl); } void dumpRemaining() { @@ -609,7 +564,7 @@ struct ASTDumpContext } } - template + template void dumpField(const char* name, const T& value) { m_writer->emit(name); @@ -640,15 +595,9 @@ struct ASTDumpContext { switch (operand.kind) { - case ValNodeOperandKind::ConstantValue: - dump(operand.values.intOperand); - break; - case ValNodeOperandKind::ValNode: - dump(operand.values.nodeOperand); - break; - case ValNodeOperandKind::ASTNode: - dump(operand.values.nodeOperand); - break; + case ValNodeOperandKind::ConstantValue: dump(operand.values.intOperand); break; + case ValNodeOperandKind::ValNode: dump(operand.values.nodeOperand); break; + case ValNodeOperandKind::ASTNode: dump(operand.values.nodeOperand); break; } } @@ -671,29 +620,16 @@ struct ASTDumpContext void dump(const SPIRVAsmOperand& operand) { - switch(operand.flavor) + switch (operand.flavor) { - case SPIRVAsmOperand::Id: - m_writer->emit("%"); - break; - case SPIRVAsmOperand::ResultMarker: - m_writer->emit("result"); - break; + case SPIRVAsmOperand::Id: m_writer->emit("%"); break; + case SPIRVAsmOperand::ResultMarker: m_writer->emit("result"); break; case SPIRVAsmOperand::Literal: - case SPIRVAsmOperand::NamedValue: - break; - case SPIRVAsmOperand::SlangValue: - m_writer->emit("$"); - break; - case SPIRVAsmOperand::SlangValueAddr: - m_writer->emit("&"); - break; - case SPIRVAsmOperand::SlangType: - m_writer->emit("$$"); - break; - case SPIRVAsmOperand::SlangImmediateValue: - m_writer->emit("!"); - break; + case SPIRVAsmOperand::NamedValue: break; + case SPIRVAsmOperand::SlangValue: m_writer->emit("$"); break; + case SPIRVAsmOperand::SlangValueAddr: m_writer->emit("&"); break; + case SPIRVAsmOperand::SlangType: m_writer->emit("$$"); break; + case SPIRVAsmOperand::SlangImmediateValue: m_writer->emit("!"); break; case SPIRVAsmOperand::RayPayloadFromLocation: m_writer->emit("__rayPayloadFromLocation"); break; @@ -703,13 +639,10 @@ struct ASTDumpContext case SPIRVAsmOperand::RayCallableFromLocation: m_writer->emit("__rayCallableFromLocation"); break; - case SPIRVAsmOperand::BuiltinVar: - m_writer->emit("builtin"); - break; - default: - SLANG_UNREACHABLE("Unhandled case in ast dump for SPIRVAsmOperand"); + case SPIRVAsmOperand::BuiltinVar: m_writer->emit("builtin"); break; + default: SLANG_UNREACHABLE("Unhandled case in ast dump for SPIRVAsmOperand"); } - if(operand.expr) + if (operand.expr) dump(operand.expr); else dump(operand.token); @@ -718,7 +651,7 @@ struct ASTDumpContext void dump(const SPIRVAsmInst& inst) { dump(inst.opcode); - for(const auto& o : inst.operands) + for (const auto& o : inst.operands) dump(o); } @@ -727,7 +660,7 @@ struct ASTDumpContext m_writer->emit("spirv_asm\n"); m_writer->emit("{\n"); m_writer->indent(); - for(const auto& i : expr.insts) + for (const auto& i : expr.insts) { dump(i); m_writer->emit(";\n"); @@ -764,11 +697,8 @@ struct ASTDumpContext void dumpObjectFull(NodeBase* node); - ASTDumpContext(SourceWriter* writer, ASTDumpUtil::Flags flags, ASTDumpUtil::Style dumpStyle): - m_writer(writer), - m_scopeWriteCount(0), - m_dumpStyle(dumpStyle), - m_dumpFlags(flags) + ASTDumpContext(SourceWriter* writer, ASTDumpUtil::Flags flags, ASTDumpUtil::Style dumpStyle) + : m_writer(writer), m_scopeWriteCount(0), m_dumpStyle(dumpStyle), m_dumpFlags(flags) { } @@ -780,7 +710,7 @@ struct ASTDumpContext // Using the SourceWriter, for automatic indentation. SourceWriter* m_writer; - Dictionary m_objectMap; ///< Object index + Dictionary m_objectMap; ///< Object index List m_objects; StringBuilder m_buf; @@ -792,26 +722,28 @@ struct ASTDumpContext struct ASTDumpAccess { -#define SLANG_AST_DUMP_FIELD(FIELD_NAME, TYPE, param) context.dumpField(#FIELD_NAME, static_cast(base)->FIELD_NAME); +#define SLANG_AST_DUMP_FIELD(FIELD_NAME, TYPE, param) \ + context.dumpField(#FIELD_NAME, static_cast(base)->FIELD_NAME); #define SLANG_AST_DUMP_FIELDS_IMPL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ -case ASTNodeType::NAME: \ -{ \ - SLANG_FIELDS_ASTNode_##NAME(SLANG_AST_DUMP_FIELD, NAME) \ - break; \ -} + case ASTNodeType::NAME: \ + { \ + SLANG_FIELDS_ASTNode_##NAME(SLANG_AST_DUMP_FIELD, NAME) break; \ + } static void dump(ASTNodeType type, NodeBase* base, ASTDumpContext& context) { switch (type) { - SLANG_ALL_ASTNode_NodeBase(SLANG_AST_DUMP_FIELDS_IMPL, _) - default: break; + SLANG_ALL_ASTNode_NodeBase(SLANG_AST_DUMP_FIELDS_IMPL, _) default : break; } } }; -void ASTDumpContext::dumpObjectReference(const ReflectClassInfo& type, NodeBase* obj, Index objIndex) +void ASTDumpContext::dumpObjectReference( + const ReflectClassInfo& type, + NodeBase* obj, + Index objIndex) { SLANG_UNUSED(obj); ScopeWrite(this).getBuf() << type.m_name << ":" << objIndex; @@ -878,7 +810,7 @@ void ASTDumpContext::dumpObjectFull(NodeBase* node) } } -/* static */void ASTDumpUtil::dump(NodeBase* node, Style style, Flags flags, SourceWriter* writer) +/* static */ void ASTDumpUtil::dump(NodeBase* node, Style style, Flags flags, SourceWriter* writer) { ASTDumpContext context(writer, flags, style); context.dumpObjectFull(node); diff --git a/source/slang/slang-ast-dump.h b/source/slang/slang-ast-dump.h index 7a2b30c3e..4bf2a3985 100644 --- a/source/slang/slang-ast-dump.h +++ b/source/slang/slang-ast-dump.h @@ -2,9 +2,8 @@ #ifndef SLANG_AST_DUMP_H #define SLANG_AST_DUMP_H -#include "slang-syntax.h" - #include "slang-emit-source-writer.h" +#include "slang-syntax.h" namespace Slang { @@ -25,7 +24,7 @@ struct ASTDumpUtil enum Enum : Flags { HideSourceLoc = 0x1, - HideScope = 0x2, + HideScope = 0x2, }; }; diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index 8b779e8db..1f44e31c2 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -4,23 +4,24 @@ #include "slang-ast-base.h" -namespace Slang { +namespace Slang +{ using SpvWord = uint32_t; // Syntax class definitions for expressions. -// - // A placeholder for where an Expr is expected but is missing from source. +// +// A placeholder for where an Expr is expected but is missing from source. class IncompleteExpr : public Expr { SLANG_AST_CLASS(IncompleteExpr) }; - // Base class for expressions that will reference declarations -class DeclRefExpr: public Expr +// Base class for expressions that will reference declarations +class DeclRefExpr : public Expr { SLANG_ABSTRACT_AST_CLASS(DeclRefExpr) - + // The declaration of the symbol being referenced DeclRef declRef; @@ -66,7 +67,7 @@ class OverloadedExpr : public Expr // An expression that references an overloaded set of declarations // having the same name. -class OverloadedExpr2: public Expr +class OverloadedExpr2 : public Expr { SLANG_AST_CLASS(OverloadedExpr2) @@ -94,7 +95,7 @@ class IntegerLiteralExpr : public LiteralExpr IntegerLiteralValue value; }; -class FloatingPointLiteralExpr: public LiteralExpr +class FloatingPointLiteralExpr : public LiteralExpr { SLANG_AST_CLASS(FloatingPointLiteralExpr) FloatingPointLiteralValue value; @@ -122,7 +123,7 @@ class StringLiteralExpr : public LiteralExpr // TODO: consider storing the "segments" of the string // literal, in the case where multiple literals were - //lined up at the lexer level, e.g.: + // lined up at the lexer level, e.g.: // // "first" "second" "third" // @@ -187,7 +188,7 @@ class AppExprBase : public ExprWithArgsBase List argumentDelimeterLocs; }; -class InvokeExpr: public AppExprBase +class InvokeExpr : public AppExprBase { SLANG_AST_CLASS(InvokeExpr) }; @@ -197,7 +198,8 @@ enum class TryClauseType None, Standard, // Normal `try` clause Optional, // (Not implemented) `try?` clause that returns an optional value. - Assert, // (Not implemented) `try!` clause that should always succeed and triggers runtime error if failed. + Assert, // (Not implemented) `try!` clause that should always succeed and triggers runtime error + // if failed. }; char const* getTryClauseTypeName(TryClauseType value); @@ -219,25 +221,25 @@ class NewExpr : public InvokeExpr SLANG_AST_CLASS(NewExpr) }; -class OperatorExpr: public InvokeExpr +class OperatorExpr : public InvokeExpr { SLANG_AST_CLASS(OperatorExpr) }; -class InfixExpr: public OperatorExpr +class InfixExpr : public OperatorExpr { SLANG_AST_CLASS(InfixExpr) }; -class PrefixExpr: public OperatorExpr +class PrefixExpr : public OperatorExpr { SLANG_AST_CLASS(PrefixExpr) }; -class PostfixExpr: public OperatorExpr +class PostfixExpr : public OperatorExpr { SLANG_AST_CLASS(PostfixExpr) }; -class IndexExpr: public Expr +class IndexExpr : public Expr { SLANG_AST_CLASS(IndexExpr) Expr* baseExpression; @@ -248,7 +250,7 @@ class IndexExpr: public Expr List argumentDelimeterLocs; }; -class MemberExpr: public DeclRefExpr +class MemberExpr : public DeclRefExpr { SLANG_AST_CLASS(MemberExpr) Expr* baseExpression = nullptr; @@ -262,7 +264,7 @@ class DerefMemberExpr : public MemberExpr }; // Member looked up on a type, rather than a value -class StaticMemberExpr: public DeclRefExpr +class StaticMemberExpr : public DeclRefExpr { SLANG_AST_CLASS(StaticMemberExpr) Expr* baseExpression = nullptr; @@ -273,7 +275,7 @@ struct MatrixCoord { bool operator==(const MatrixCoord& rhs) const { return row == rhs.row && col == rhs.col; }; bool operator!=(const MatrixCoord& rhs) const { return !(*this == rhs); }; - // Rows and columns are zero indexed + // Rows and columns are zero indexed int row; int col; }; @@ -287,7 +289,7 @@ class MatrixSwizzleExpr : public Expr SourceLoc memberOpLoc; }; -class SwizzleExpr: public Expr +class SwizzleExpr : public Expr { SLANG_AST_CLASS(SwizzleExpr) Expr* base = nullptr; @@ -303,22 +305,22 @@ class MakeRefExpr : public Expr }; // A dereference of a pointer or pointer-like type -class DerefExpr: public Expr +class DerefExpr : public Expr { SLANG_AST_CLASS(DerefExpr) Expr* base = nullptr; }; // Any operation that performs type-casting -class TypeCastExpr: public InvokeExpr +class TypeCastExpr : public InvokeExpr { SLANG_AST_CLASS(TypeCastExpr) -// TypeExp TargetType; -// Expr* Expression = nullptr; + // TypeExp TargetType; + // Expr* Expression = nullptr; }; // An explicit type-cast that appear in the user's code with `(type) expr` syntax -class ExplicitCastExpr: public TypeCastExpr +class ExplicitCastExpr : public TypeCastExpr { SLANG_AST_CLASS(ExplicitCastExpr) }; @@ -341,7 +343,10 @@ class LValueImplicitCastExpr : public TypeCastExpr { SLANG_AST_CLASS(LValueImplicitCastExpr) - explicit LValueImplicitCastExpr(const TypeCastExpr& rhs) :Super(rhs) {} + explicit LValueImplicitCastExpr(const TypeCastExpr& rhs) + : Super(rhs) + { + } }; // To work around situations like int += uint @@ -351,8 +356,11 @@ class OutImplicitCastExpr : public LValueImplicitCastExpr { SLANG_AST_CLASS(OutImplicitCastExpr) - /// Allow explict construction from any TypeCastExpr - explicit OutImplicitCastExpr(const TypeCastExpr& rhs) :Super(rhs) {} + /// Allow explict construction from any TypeCastExpr + explicit OutImplicitCastExpr(const TypeCastExpr& rhs) + : Super(rhs) + { + } }; class InOutImplicitCastExpr : public LValueImplicitCastExpr @@ -360,14 +368,17 @@ class InOutImplicitCastExpr : public LValueImplicitCastExpr SLANG_AST_CLASS(InOutImplicitCastExpr) /// Allow explict construction from any TypeCastExpr - explicit InOutImplicitCastExpr(const TypeCastExpr& rhs) :Super(rhs) {} + explicit InOutImplicitCastExpr(const TypeCastExpr& rhs) + : Super(rhs) + { + } }; - /// A cast of a value to a super-type of its type. - /// - /// The type being cast to is stored as this expression's `type`. - /// -class CastToSuperTypeExpr: public Expr +/// A cast of a value to a super-type of its type. +/// +/// The type being cast to is stored as this expression's `type`. +/// +class CastToSuperTypeExpr : public Expr { SLANG_AST_CLASS(CastToSuperTypeExpr) @@ -377,12 +388,12 @@ class CastToSuperTypeExpr: public Expr /// Expr* valueArg = nullptr; - /// A witness showing that `valueArg`'s type is a sub-type of this expression's `type` + /// A witness showing that `valueArg`'s type is a sub-type of this expression's `type` Val* witnessArg = nullptr; }; - /// A `value is Type` expression that evaluates to `true` if type of `value` is a sub-type of - /// `Type`. +/// A `value is Type` expression that evaluates to `true` if type of `value` is a sub-type of +/// `Type`. class IsTypeExpr : public Expr { SLANG_AST_CLASS(IsTypeExpr) @@ -397,8 +408,8 @@ class IsTypeExpr : public Expr BoolLiteralExpr* constantVal = nullptr; }; - /// A `value as Type` expression that casts `value` to `Type` within type hierarchy. - /// The result is undefined if `value` is not `Type`. +/// A `value as Type` expression that casts `value` to `Type` within type hierarchy. +/// The result is undefined if `value` is not `Type`. class AsTypeExpr : public Expr { SLANG_AST_CLASS(AsTypeExpr) @@ -408,7 +419,6 @@ class AsTypeExpr : public Expr // A witness showing that `typeExpr` is a subtype of `typeof(value)`. Val* witnessArg = nullptr; - }; class SizeOfLikeExpr : public Expr @@ -441,15 +451,15 @@ class MakeOptionalExpr : public Expr { SLANG_AST_CLASS(MakeOptionalExpr) - // If `value` is null, this constructs an `Optional` that doesn't have a value. + // If `value` is null, this constructs an `Optional` that doesn't have a value. Expr* value = nullptr; Expr* typeExpr = nullptr; }; - /// A cast of a value to the same type, with different modifiers. - /// - /// The type being cast to is stored as this expression's `type`. - /// +/// A cast of a value to the same type, with different modifiers. +/// +/// The type being cast to is stored as this expression's `type`. +/// class ModifierCastExpr : public Expr { SLANG_AST_CLASS(ModifierCastExpr) @@ -461,39 +471,39 @@ class ModifierCastExpr : public Expr Expr* valueArg = nullptr; }; -class SelectExpr: public OperatorExpr +class SelectExpr : public OperatorExpr { SLANG_AST_CLASS(SelectExpr) }; -class LogicOperatorShortCircuitExpr: public OperatorExpr +class LogicOperatorShortCircuitExpr : public OperatorExpr { SLANG_AST_CLASS(LogicOperatorShortCircuitExpr) public: enum Flavor { - And, // && - Or, // || + And, // && + Or, // || }; Flavor flavor; }; -class GenericAppExpr: public AppExprBase +class GenericAppExpr : public AppExprBase { SLANG_AST_CLASS(GenericAppExpr) }; // An expression representing re-use of the syntax for a type in more // than once conceptually-distinct declaration -class SharedTypeExpr: public Expr +class SharedTypeExpr : public Expr { SLANG_AST_CLASS(SharedTypeExpr) // The underlying type expression that we want to share TypeExp base; }; -class AssignExpr: public Expr +class AssignExpr : public Expr { SLANG_AST_CLASS(AssignExpr) Expr* left = nullptr; @@ -504,7 +514,7 @@ class AssignExpr: public Expr // // We keep this around explicitly to be sure we don't lose any structure // when we do rewriter stuff. -class ParenExpr: public Expr +class ParenExpr : public Expr { SLANG_AST_CLASS(ParenExpr) Expr* base = nullptr; @@ -512,7 +522,7 @@ class ParenExpr: public Expr // An object-oriented `this` expression, used to // refer to the current instance of an enclosing type. -class ThisExpr: public Expr +class ThisExpr : public Expr { SLANG_AST_CLASS(ThisExpr) @@ -531,14 +541,14 @@ class ReturnValExpr : public Expr }; // An expression that binds a temporary variable in a local expression context -class LetExpr: public Expr +class LetExpr : public Expr { SLANG_AST_CLASS(LetExpr) VarDecl* decl = nullptr; Expr* body = nullptr; }; -class ExtractExistentialValueExpr: public Expr +class ExtractExistentialValueExpr : public Expr { SLANG_AST_CLASS(ExtractExistentialValueExpr) DeclRef declRef; @@ -559,9 +569,9 @@ class DetachExpr : public Expr Expr* inner = nullptr; }; - /// Base class for higher-order function application - /// Eg: foo(fn) where fn is a function expression. - /// +/// Base class for higher-order function application +/// Eg: foo(fn) where fn is a function expression. +/// class HigherOrderInvokeExpr : public Expr { SLANG_ABSTRACT_AST_CLASS(HigherOrderInvokeExpr) @@ -579,25 +589,25 @@ class DifferentiateExpr : public HigherOrderInvokeExpr SLANG_ABSTRACT_AST_CLASS(DifferentiateExpr) }; - /// An expression of the form `__fwd_diff(fn)` to access the - /// forward-mode derivative version of the function `fn` - /// -class ForwardDifferentiateExpr: public DifferentiateExpr +/// An expression of the form `__fwd_diff(fn)` to access the +/// forward-mode derivative version of the function `fn` +/// +class ForwardDifferentiateExpr : public DifferentiateExpr { SLANG_AST_CLASS(ForwardDifferentiateExpr) }; - /// An expression of the form `__bwd_diff(fn)` to access the - /// forward-mode derivative version of the function `fn` - /// -class BackwardDifferentiateExpr: public DifferentiateExpr +/// An expression of the form `__bwd_diff(fn)` to access the +/// forward-mode derivative version of the function `fn` +/// +class BackwardDifferentiateExpr : public DifferentiateExpr { SLANG_AST_CLASS(BackwardDifferentiateExpr) }; - /// An expression of the form `__dispatch_kernel(fn, threadGroupSize, dispatchSize)` to - /// dispatch a compute kernel from host. - /// +/// An expression of the form `__dispatch_kernel(fn, threadGroupSize, dispatchSize)` to +/// dispatch a compute kernel from host. +/// class DispatchKernelExpr : public HigherOrderInvokeExpr { SLANG_AST_CLASS(DispatchKernelExpr) @@ -605,37 +615,37 @@ class DispatchKernelExpr : public HigherOrderInvokeExpr Expr* dispatchSize; }; - /// An express to mark its inner expression as an intended non-differential call. +/// An express to mark its inner expression as an intended non-differential call. class TreatAsDifferentiableExpr : public Expr { SLANG_AST_CLASS(TreatAsDifferentiableExpr) Expr* innerExpr; Scope* scope; - - enum Flavor + + enum Flavor { /// Represents a no_diff wrapper over /// a non-differentiable method. /// i.e. no_diff(fn(...)) - /// + /// NoDiff, /// Represents a call to a method that /// is either marked differentiable, or has /// a user-defined derivative in scope. - /// + /// Differentiable }; Flavor flavor; }; - /// A type expression of the form `This` - /// - /// Refers to the type of `this` in the current context. - /// -class ThisTypeExpr: public Expr +/// A type expression of the form `This` +/// +/// Refers to the type of `this` in the current context. +/// +class ThisTypeExpr : public Expr { SLANG_AST_CLASS(ThisTypeExpr) @@ -643,7 +653,7 @@ class ThisTypeExpr: public Expr Scope* scope = nullptr; }; - /// A type expression of the form `Left & Right`. +/// A type expression of the form `Left & Right`. class AndTypeExpr : public Expr { SLANG_AST_CLASS(AndTypeExpr); @@ -652,7 +662,7 @@ class AndTypeExpr : public Expr TypeExp right; }; - /// A type exprssion that applies one or more modifiers to another type +/// A type exprssion that applies one or more modifiers to another type class ModifiedTypeExpr : public Expr { SLANG_AST_CLASS(ModifiedTypeExpr); @@ -661,7 +671,7 @@ class ModifiedTypeExpr : public Expr TypeExp base; }; - /// A type expression that rrepresents a pointer type, e.g. T* +/// A type expression that rrepresents a pointer type, e.g. T* class PointerTypeExpr : public Expr { SLANG_AST_CLASS(PointerTypeExpr); @@ -669,7 +679,7 @@ class PointerTypeExpr : public Expr TypeExp base; }; - /// A type expression that represents a function type, e.g. (bool, int) -> float +/// A type expression that represents a function type, e.g. (bool, int) -> float class FuncTypeExpr : public Expr { SLANG_AST_CLASS(FuncTypeExpr); @@ -685,9 +695,9 @@ class TupleTypeExpr : public Expr List members; }; - /// An expression that applies a generic to arguments for some, - /// but not all, of its explicit parameters. - /// +/// An expression that applies a generic to arguments for some, +/// but not all, of its explicit parameters. +/// class PartiallyAppliedGenericExpr : public Expr { SLANG_AST_CLASS(PartiallyAppliedGenericExpr); @@ -695,17 +705,17 @@ class PartiallyAppliedGenericExpr : public Expr public: Expr* originalExpr = nullptr; - /// The generic being applied + /// The generic being applied DeclRef baseGenericDeclRef; - /// A substitution that includes the generic arguments known so far + /// A substitution that includes the generic arguments known so far List knownGenericArgs; }; - - /// An expression that holds a set of argument exprs that got matched to a pack parameter - /// during overload resolution. - /// + +/// An expression that holds a set of argument exprs that got matched to a pack parameter +/// during overload resolution. +/// class PackExpr : public Expr { SLANG_AST_CLASS(PackExpr) @@ -720,24 +730,29 @@ class SPIRVAsmOperand public: enum Flavor { - Literal, // No prefix - Id, // Prefixed with % + Literal, // No prefix + Id, // Prefixed with % ResultMarker, // "result" (without quotes) - NamedValue, // Any other identifier + NamedValue, // Any other identifier SlangValue, SlangValueAddr, SlangImmediateValue, SlangType, SampledType, // __sampledType(T), this becomes a 4 vector of the component type of T - ImageType, // __imageType(texture), returns the equivalaent OpTypeImage of a given texture typed value. - SampledImageType, // __sampledImageType(texture), returns the equivalent OpTypeSampledImage of a given texture typed value. - ConvertTexel, // __convertTexel(value), converts `value` to the native texel type of a texture. - TruncateMarker, // __truncate, an invented instruction which coerces to the result type by truncating the element count - EntryPoint, // __entryPoint, a placeholder for the id of a referencing entryPoint. + ImageType, // __imageType(texture), returns the equivalaent OpTypeImage of a given texture + // typed value. + SampledImageType, // __sampledImageType(texture), returns the equivalent OpTypeSampledImage + // of a given texture typed value. + ConvertTexel, // __convertTexel(value), converts `value` to the native texel type of a + // texture. + TruncateMarker, // __truncate, an invented instruction which coerces to the result type by + // truncating the element count + EntryPoint, // __entryPoint, a placeholder for the id of a referencing entryPoint. BuiltinVar, GLSL450Set, NonSemanticDebugPrintfExtSet, - RayPayloadFromLocation, //insert from scope of all payloads in the spir-v shader the payload identified by the integer value provided + RayPayloadFromLocation, // insert from scope of all payloads in the spir-v shader the + // payload identified by the integer value provided RayAttributeFromLocation, RayCallableFromLocation, }; diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h index bbabb8ab5..436e97c1b 100644 --- a/source/slang/slang-ast-iterator.h +++ b/source/slang/slang-ast-iterator.h @@ -4,15 +4,15 @@ namespace Slang { -template +template struct ASTIterator { const Callback& callback; const Filter& filter; ASTIterator(const Callback& func, const Filter& filterFunc) - : callback(func) - , filter(filterFunc) - {} + : callback(func), filter(filterFunc) + { + } void visitDecl(DeclBase* decl); void visitExpr(Expr* expr); @@ -32,7 +32,8 @@ struct ASTIterator ASTIterator* iterator; ASTIteratorExprVisitor(ASTIterator* iter) : iterator(iter) - {} + { + } void dispatchIfNotNull(Expr* expr) { if (!expr) @@ -45,18 +46,12 @@ struct ASTIterator { iterator->maybeDispatchCallback(expr); } - void visitNoneLiteralExpr(NoneLiteralExpr* expr) - { - iterator->maybeDispatchCallback(expr); - } + void visitNoneLiteralExpr(NoneLiteralExpr* expr) { iterator->maybeDispatchCallback(expr); } void visitIntegerLiteralExpr(IntegerLiteralExpr* expr) { iterator->maybeDispatchCallback(expr); } - void visitOpenRefExpr(OpenRefExpr* expr) - { - dispatchIfNotNull(expr->innerExpr); - } + void visitOpenRefExpr(OpenRefExpr* expr) { dispatchIfNotNull(expr->innerExpr); } void visitFloatingPointLiteralExpr(FloatingPointLiteralExpr* expr) { iterator->maybeDispatchCallback(expr); @@ -74,10 +69,7 @@ struct ASTIterator dispatchIfNotNull(arg); } - void visitBuiltinCastExpr(BuiltinCastExpr* expr) - { - dispatchIfNotNull(expr->base); - } + void visitBuiltinCastExpr(BuiltinCastExpr* expr) { dispatchIfNotNull(expr->base); } void visitParenExpr(ParenExpr* expr) { iterator->maybeDispatchCallback(expr); @@ -262,14 +254,14 @@ struct ASTIterator void visitFuncTypeExpr(FuncTypeExpr* expr) { iterator->maybeDispatchCallback(expr); - for(const auto& t : expr->parameters) + for (const auto& t : expr->parameters) dispatchIfNotNull(t.exp); dispatchIfNotNull(expr->result.exp); } void visitTupleTypeExpr(TupleTypeExpr* expr) { iterator->maybeDispatchCallback(expr); - for(auto t : expr->members) + for (auto t : expr->members) dispatchIfNotNull(t.exp); } void visitPointerTypeExpr(PointerTypeExpr* expr) @@ -314,10 +306,10 @@ struct ASTIterator void visitSPIRVAsmExpr(SPIRVAsmExpr* expr) { iterator->maybeDispatchCallback(expr); - for(const auto& i : expr->insts) + for (const auto& i : expr->insts) { dispatchIfNotNull(i.opcode.expr); - for(const auto& o : i.operands) + for (const auto& o : i.operands) dispatchIfNotNull(o.expr); } } @@ -328,7 +320,8 @@ struct ASTIterator ASTIterator* iterator; ASTIteratorStmtVisitor(ASTIterator* iter) : iterator(iter) - {} + { + } void dispatchIfNotNull(Stmt* stmt) { @@ -454,7 +447,7 @@ struct ASTIterator }; }; -template +template void ASTIterator::visitDecl(DeclBase* decl) { // Don't look at the decl if it is defined in a different file. @@ -527,20 +520,20 @@ void ASTIterator::visitDecl(DeclBase* decl) } } } -template +template void ASTIterator::visitExpr(Expr* expr) { ASTIteratorExprVisitor visitor(this); visitor.dispatchIfNotNull(expr); } -template +template void ASTIterator::visitStmt(Stmt* stmt) { ASTIteratorStmtVisitor visitor(this); visitor.dispatchIfNotNull(stmt); } -template +template void iterateAST(SyntaxNode* node, const FilterFunc& filterFunc, const Func& f) { ASTIterator iter(f, filterFunc); @@ -558,17 +551,20 @@ void iterateAST(SyntaxNode* node, const FilterFunc& filterFunc, const Func& f) } } -template +template void iterateASTWithLanguageServerFilter( - UnownedStringSlice fileName, SourceManager* sourceManager, SyntaxNode* node, const Func& f) + UnownedStringSlice fileName, + SourceManager* sourceManager, + SyntaxNode* node, + const Func& f) { auto filter = [&](DeclBase* decl) - { - return as(decl) || - sourceManager->getHumaneLoc(decl->loc, SourceLocType::Actual) - .pathInfo.foundPath.getUnownedSlice() - .endsWithCaseInsensitive(fileName); - }; + { + return as(decl) || + sourceManager->getHumaneLoc(decl->loc, SourceLocType::Actual) + .pathInfo.foundPath.getUnownedSlice() + .endsWithCaseInsensitive(fileName); + }; iterateAST(node, filter, f); } } // namespace Slang diff --git a/source/slang/slang-ast-modifier.cpp b/source/slang/slang-ast-modifier.cpp index ba30a547d..2a245130e 100644 --- a/source/slang/slang-ast-modifier.cpp +++ b/source/slang/slang-ast-modifier.cpp @@ -1,13 +1,19 @@ // slang-ast-modifier.cpp #include "slang-ast-modifier.h" + #include "slang-ast-expr.h" namespace Slang { -const OrderedDictionary& DifferentiableAttribute::getMapTypeToIDifferentiableWitness() +const OrderedDictionary& DifferentiableAttribute:: + getMapTypeToIDifferentiableWitness() { - for (Index i = m_mapToIDifferentiableWitness.getCount(); i < m_typeToIDifferentiableWitnessMappings.getCount(); i++) - m_mapToIDifferentiableWitness.add(m_typeToIDifferentiableWitnessMappings[i].key, m_typeToIDifferentiableWitnessMappings[i].value); + for (Index i = m_mapToIDifferentiableWitness.getCount(); + i < m_typeToIDifferentiableWitnessMappings.getCount(); + i++) + m_mapToIDifferentiableWitness.add( + m_typeToIDifferentiableWitnessMappings[i].key, + m_typeToIDifferentiableWitnessMappings[i].value); return m_mapToIDifferentiableWitness; } diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 04b7da74d..956704b4c 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -4,42 +4,118 @@ #include "slang-ast-base.h" -namespace Slang { +namespace Slang +{ // Syntax class definitions for modifiers. // Simple modifiers have no state beyond their identity -class InModifier : public Modifier { SLANG_AST_CLASS(InModifier)}; -class OutModifier : public Modifier { SLANG_AST_CLASS(OutModifier)}; -class ConstModifier : public Modifier { SLANG_AST_CLASS(ConstModifier)}; -class BuiltinModifier : public Modifier { SLANG_AST_CLASS(BuiltinModifier)}; -class InlineModifier : public Modifier { SLANG_AST_CLASS(InlineModifier)}; -class VisibilityModifier : public Modifier {SLANG_AST_CLASS(VisibilityModifier)}; -class PublicModifier : public VisibilityModifier { SLANG_AST_CLASS(PublicModifier)}; -class PrivateModifier : public VisibilityModifier { SLANG_AST_CLASS(PrivateModifier) }; -class InternalModifier : public VisibilityModifier { SLANG_AST_CLASS(InternalModifier) }; -class RequireModifier : public Modifier { SLANG_AST_CLASS(RequireModifier)}; -class ParamModifier : public Modifier { SLANG_AST_CLASS(ParamModifier)}; -class ExternModifier : public Modifier { SLANG_AST_CLASS(ExternModifier)}; -class HLSLExportModifier : public Modifier { SLANG_AST_CLASS(HLSLExportModifier) }; -class TransparentModifier : public Modifier { SLANG_AST_CLASS(TransparentModifier)}; -class FromCoreModuleModifier : public Modifier { SLANG_AST_CLASS(FromCoreModuleModifier)}; -class PrefixModifier : public Modifier { SLANG_AST_CLASS(PrefixModifier)}; -class PostfixModifier : public Modifier { SLANG_AST_CLASS(PostfixModifier)}; -class ExportedModifier : public Modifier { SLANG_AST_CLASS(ExportedModifier)}; -class ConstExprModifier : public Modifier { SLANG_AST_CLASS(ConstExprModifier)}; -class ExternCppModifier : public Modifier { SLANG_AST_CLASS(ExternCppModifier)}; -class GLSLPrecisionModifier : public Modifier { SLANG_AST_CLASS(GLSLPrecisionModifier)}; -class GLSLModuleModifier : public Modifier {SLANG_AST_CLASS(GLSLModuleModifier)}; +class InModifier : public Modifier +{ + SLANG_AST_CLASS(InModifier) +}; +class OutModifier : public Modifier +{ + SLANG_AST_CLASS(OutModifier) +}; +class ConstModifier : public Modifier +{ + SLANG_AST_CLASS(ConstModifier) +}; +class BuiltinModifier : public Modifier +{ + SLANG_AST_CLASS(BuiltinModifier) +}; +class InlineModifier : public Modifier +{ + SLANG_AST_CLASS(InlineModifier) +}; +class VisibilityModifier : public Modifier +{ + SLANG_AST_CLASS(VisibilityModifier) +}; +class PublicModifier : public VisibilityModifier +{ + SLANG_AST_CLASS(PublicModifier) +}; +class PrivateModifier : public VisibilityModifier +{ + SLANG_AST_CLASS(PrivateModifier) +}; +class InternalModifier : public VisibilityModifier +{ + SLANG_AST_CLASS(InternalModifier) +}; +class RequireModifier : public Modifier +{ + SLANG_AST_CLASS(RequireModifier) +}; +class ParamModifier : public Modifier +{ + SLANG_AST_CLASS(ParamModifier) +}; +class ExternModifier : public Modifier +{ + SLANG_AST_CLASS(ExternModifier) +}; +class HLSLExportModifier : public Modifier +{ + SLANG_AST_CLASS(HLSLExportModifier) +}; +class TransparentModifier : public Modifier +{ + SLANG_AST_CLASS(TransparentModifier) +}; +class FromCoreModuleModifier : public Modifier +{ + SLANG_AST_CLASS(FromCoreModuleModifier) +}; +class PrefixModifier : public Modifier +{ + SLANG_AST_CLASS(PrefixModifier) +}; +class PostfixModifier : public Modifier +{ + SLANG_AST_CLASS(PostfixModifier) +}; +class ExportedModifier : public Modifier +{ + SLANG_AST_CLASS(ExportedModifier) +}; +class ConstExprModifier : public Modifier +{ + SLANG_AST_CLASS(ConstExprModifier) +}; +class ExternCppModifier : public Modifier +{ + SLANG_AST_CLASS(ExternCppModifier) +}; +class GLSLPrecisionModifier : public Modifier +{ + SLANG_AST_CLASS(GLSLPrecisionModifier) +}; +class GLSLModuleModifier : public Modifier +{ + SLANG_AST_CLASS(GLSLModuleModifier) +}; // Marks that the definition of a decl is not yet synthesized. -class ToBeSynthesizedModifier : public Modifier {SLANG_AST_CLASS(ToBeSynthesizedModifier)}; +class ToBeSynthesizedModifier : public Modifier +{ + SLANG_AST_CLASS(ToBeSynthesizedModifier) +}; // Marks that the definition of a decl is synthesized. -class SynthesizedModifier : public Modifier { SLANG_AST_CLASS(SynthesizedModifier) }; +class SynthesizedModifier : public Modifier +{ + SLANG_AST_CLASS(SynthesizedModifier) +}; // Marks a synthesized variable as local temporary variable. -class LocalTempVarModifier : public Modifier { SLANG_AST_CLASS(LocalTempVarModifier) }; +class LocalTempVarModifier : public Modifier +{ + SLANG_AST_CLASS(LocalTempVarModifier) +}; // An `extern` variable in an extension is used to introduce additional attributes on an existing // field. @@ -49,21 +125,28 @@ class ExtensionExternVarModifier : public Modifier DeclRef originalDecl; }; -// An 'ActualGlobal' is a global that is output as a normal global in CPU code. -// Globals in HLSL/Slang are constant state passed into kernel execution -class ActualGlobalModifier : public Modifier { SLANG_AST_CLASS(ActualGlobalModifier)}; +// An 'ActualGlobal' is a global that is output as a normal global in CPU code. +// Globals in HLSL/Slang are constant state passed into kernel execution +class ActualGlobalModifier : public Modifier +{ + SLANG_AST_CLASS(ActualGlobalModifier) +}; - /// A modifier that indicates an `InheritanceDecl` should be ignored during name lookup (and related checks). -class IgnoreForLookupModifier : public Modifier { SLANG_AST_CLASS(IgnoreForLookupModifier) }; +/// A modifier that indicates an `InheritanceDecl` should be ignored during name lookup (and related +/// checks). +class IgnoreForLookupModifier : public Modifier +{ + SLANG_AST_CLASS(IgnoreForLookupModifier) +}; // A modifier that marks something as an operation that // has a one-to-one translation to the IR, and thus // has no direct definition in the high-level language. // -class IntrinsicOpModifier : public Modifier +class IntrinsicOpModifier : public Modifier { SLANG_AST_CLASS(IntrinsicOpModifier) - + // Token that names the intrinsic op. Token opToken; @@ -74,10 +157,10 @@ class IntrinsicOpModifier : public Modifier // A modifier that marks something as an intrinsic function, // for some subset of targets. -class TargetIntrinsicModifier : public Modifier +class TargetIntrinsicModifier : public Modifier { SLANG_AST_CLASS(TargetIntrinsicModifier) - + // Token that names the target that the operation // is an intrisic for. Token targetToken; @@ -97,10 +180,10 @@ class TargetIntrinsicModifier : public Modifier // A modifier that marks a declaration as representing a // specialization that should be preferred on a particular // target. -class SpecializedForTargetModifier : public Modifier +class SpecializedForTargetModifier : public Modifier { SLANG_AST_CLASS(SpecializedForTargetModifier) - + // Token that names the target that the operation // has been specialized for. Token targetToken; @@ -108,49 +191,49 @@ class SpecializedForTargetModifier : public Modifier // A modifier to tag something as an intrinsic that requires // a certain GLSL extension to be enabled when used -class RequiredGLSLExtensionModifier : public Modifier +class RequiredGLSLExtensionModifier : public Modifier { SLANG_AST_CLASS(RequiredGLSLExtensionModifier) - + Token extensionNameToken; }; // A modifier to tag something as an intrinsic that requires // a certain GLSL version to be enabled when used -class RequiredGLSLVersionModifier : public Modifier +class RequiredGLSLVersionModifier : public Modifier { SLANG_AST_CLASS(RequiredGLSLVersionModifier) - + Token versionNumberToken; }; // A modifier to tag something as an intrinsic that requires // a certain SPIRV version to be enabled when used. Specified as "major.minor" -class RequiredSPIRVVersionModifier : public Modifier +class RequiredSPIRVVersionModifier : public Modifier { SLANG_AST_CLASS(RequiredSPIRVVersionModifier) - + SemanticVersion version; }; // A modifier to tag something as an intrinsic that requires // a certain CUDA SM version to be enabled when used. Specified as "major.minor" -class RequiredCUDASMVersionModifier : public Modifier +class RequiredCUDASMVersionModifier : public Modifier { SLANG_AST_CLASS(RequiredCUDASMVersionModifier) - + SemanticVersion version; }; -class InOutModifier : public OutModifier +class InOutModifier : public OutModifier { SLANG_AST_CLASS(InOutModifier) }; // `__ref` modifier for by-reference parameter passing -class RefModifier : public Modifier +class RefModifier : public Modifier { SLANG_AST_CLASS(RefModifier) }; @@ -175,7 +258,7 @@ class ConstRefModifier : public Modifier // / // b: RegisterModifier("x0") / // -class SharedModifiers : public Modifier +class SharedModifiers : public Modifier { SLANG_AST_CLASS(SharedModifiers) }; @@ -191,10 +274,10 @@ class SharedModifiers : public Modifier // so that we can recover good source location info // for modifiers that were part of the same vs. // different constructs. -class GLSLLayoutModifier : public Modifier +class GLSLLayoutModifier : public Modifier { SLANG_ABSTRACT_AST_CLASS(GLSLLayoutModifier) - + // The token used to introduce the modifier is stored // as the `nameToken` field. @@ -204,17 +287,17 @@ class GLSLLayoutModifier : public Modifier }; // AST nodes to represent the begin/end of a `layout` modifier group -class GLSLLayoutModifierGroupMarker : public Modifier +class GLSLLayoutModifierGroupMarker : public Modifier { SLANG_ABSTRACT_AST_CLASS(GLSLLayoutModifierGroupMarker) }; -class GLSLLayoutModifierGroupBegin : public GLSLLayoutModifierGroupMarker +class GLSLLayoutModifierGroupBegin : public GLSLLayoutModifierGroupMarker { SLANG_AST_CLASS(GLSLLayoutModifierGroupBegin) }; -class GLSLLayoutModifierGroupEnd : public GLSLLayoutModifierGroupMarker +class GLSLLayoutModifierGroupEnd : public GLSLLayoutModifierGroupMarker { SLANG_AST_CLASS(GLSLLayoutModifierGroupEnd) }; @@ -223,12 +306,12 @@ class GLSLLayoutModifierGroupEnd : public GLSLLayoutModifierGroupMarker // We divide GLSL `layout` modifiers into those we have parsed // (in the sense of having some notion of their semantics), and // those we have not. -class GLSLParsedLayoutModifier : public GLSLLayoutModifier +class GLSLParsedLayoutModifier : public GLSLLayoutModifier { SLANG_ABSTRACT_AST_CLASS(GLSLParsedLayoutModifier) }; -class GLSLUnparsedLayoutModifier : public GLSLLayoutModifier +class GLSLUnparsedLayoutModifier : public GLSLLayoutModifier { SLANG_AST_CLASS(GLSLUnparsedLayoutModifier) }; @@ -236,7 +319,7 @@ class GLSLUnparsedLayoutModifier : public GLSLLayoutModifier // Specific cases for known GLSL `layout` modifiers that we need to work with -class GLSLLocationLayoutModifier : public GLSLParsedLayoutModifier +class GLSLLocationLayoutModifier : public GLSLParsedLayoutModifier { SLANG_AST_CLASS(GLSLLocationLayoutModifier) }; @@ -263,7 +346,7 @@ class GLSLScalarModifier : public GLSLBufferDataLayoutModifier // A catch-all for single-keyword modifiers -class SimpleModifier : public Modifier +class SimpleModifier : public Modifier { SLANG_AST_CLASS(SimpleModifier) }; @@ -271,7 +354,7 @@ class SimpleModifier : public Modifier // Indicates that this is a variable declaration that corresponds to // a parameter block declaration in the source program. -class ImplicitParameterGroupVariableModifier : public Modifier +class ImplicitParameterGroupVariableModifier : public Modifier { SLANG_AST_CLASS(ImplicitParameterGroupVariableModifier) }; @@ -279,39 +362,39 @@ class ImplicitParameterGroupVariableModifier : public Modifier // Indicates that this is a type that corresponds to the element // type of a parameter block declaration in the source program. -class ImplicitParameterGroupElementTypeModifier : public Modifier +class ImplicitParameterGroupElementTypeModifier : public Modifier { SLANG_AST_CLASS(ImplicitParameterGroupElementTypeModifier) }; // An HLSL semantic -class HLSLSemantic : public Modifier +class HLSLSemantic : public Modifier { SLANG_ABSTRACT_AST_CLASS(HLSLSemantic) - + Token name; }; // An HLSL semantic that affects layout -class HLSLLayoutSemantic : public HLSLSemantic +class HLSLLayoutSemantic : public HLSLSemantic { SLANG_AST_CLASS(HLSLLayoutSemantic) - + Token registerName; Token componentMask; }; // An HLSL `register` semantic -class HLSLRegisterSemantic : public HLSLLayoutSemantic +class HLSLRegisterSemantic : public HLSLLayoutSemantic { SLANG_AST_CLASS(HLSLRegisterSemantic) - + Token spaceName; }; // TODO(tfoley): `packoffset` -class HLSLPackOffsetSemantic : public HLSLLayoutSemantic +class HLSLPackOffsetSemantic : public HLSLLayoutSemantic { SLANG_AST_CLASS(HLSLPackOffsetSemantic) @@ -320,7 +403,7 @@ class HLSLPackOffsetSemantic : public HLSLLayoutSemantic // An HLSL semantic that just associated a declaration with a semantic name -class HLSLSimpleSemantic : public HLSLSemantic +class HLSLSimpleSemantic : public HLSLSemantic { SLANG_AST_CLASS(HLSLSimpleSemantic) }; @@ -348,17 +431,17 @@ class RayPayloadWriteSemantic : public RayPayloadAccessSemantic // Directives that came in via the preprocessor, but // that we need to keep around for later steps -class GLSLPreprocessorDirective : public Modifier +class GLSLPreprocessorDirective : public Modifier { SLANG_AST_CLASS(GLSLPreprocessorDirective) }; // A GLSL `#version` directive -class GLSLVersionDirective : public GLSLPreprocessorDirective +class GLSLVersionDirective : public GLSLPreprocessorDirective { SLANG_AST_CLASS(GLSLVersionDirective) - + // Token giving the version number to use Token versionNumberToken; @@ -368,10 +451,10 @@ class GLSLVersionDirective : public GLSLPreprocessorDirective }; // A GLSL `#extension` directive -class GLSLExtensionDirective : public GLSLPreprocessorDirective +class GLSLExtensionDirective : public GLSLPreprocessorDirective { SLANG_AST_CLASS(GLSLExtensionDirective) - + // Token giving the version number to use Token extensionNameToken; @@ -380,31 +463,31 @@ class GLSLExtensionDirective : public GLSLPreprocessorDirective Token dispositionToken; }; -class ParameterGroupReflectionName : public Modifier +class ParameterGroupReflectionName : public Modifier { SLANG_AST_CLASS(ParameterGroupReflectionName) - + NameLoc nameAndLoc; }; // A modifier that indicates a built-in base type (e.g., `float`) -class BuiltinTypeModifier : public Modifier +class BuiltinTypeModifier : public Modifier { SLANG_AST_CLASS(BuiltinTypeModifier) - + BaseType tag; }; // A modifier that indicates a built-in type that isn't a base type (e.g., `vector`) // // TODO(tfoley): This deserves a better name than "magic" -class MagicTypeModifier : public Modifier +class MagicTypeModifier : public Modifier { SLANG_AST_CLASS(MagicTypeModifier) ASTNodeType magicNodeType = ASTNodeType(-1); - /// Modifier has a name so call this magicModifier to disambiguate + /// Modifier has a name so call this magicModifier to disambiguate String magicName; uint32_t tag = uint32_t(0); }; @@ -423,10 +506,10 @@ class BuiltinRequirementModifier : public Modifier // // TODO: This should really subsume `BuiltinTypeModifier` and // `MagicTypeModifier` so that we don't have to apply all of them. -class IntrinsicTypeModifier : public Modifier +class IntrinsicTypeModifier : public Modifier { SLANG_AST_CLASS(IntrinsicTypeModifier) - + // The IR opcode to use when constructing a type uint32_t irOp; @@ -438,30 +521,30 @@ class IntrinsicTypeModifier : public Modifier }; // Modifiers that affect the storage layout for matrices -class MatrixLayoutModifier : public Modifier +class MatrixLayoutModifier : public Modifier { SLANG_AST_CLASS(MatrixLayoutModifier) }; // Modifiers that specify row- and column-major layout, respectively -class RowMajorLayoutModifier : public MatrixLayoutModifier +class RowMajorLayoutModifier : public MatrixLayoutModifier { SLANG_AST_CLASS(RowMajorLayoutModifier) }; -class ColumnMajorLayoutModifier : public MatrixLayoutModifier +class ColumnMajorLayoutModifier : public MatrixLayoutModifier { SLANG_AST_CLASS(ColumnMajorLayoutModifier) }; // The HLSL flavor of those modifiers -class HLSLRowMajorLayoutModifier : public RowMajorLayoutModifier +class HLSLRowMajorLayoutModifier : public RowMajorLayoutModifier { SLANG_AST_CLASS(HLSLRowMajorLayoutModifier) }; -class HLSLColumnMajorLayoutModifier : public ColumnMajorLayoutModifier +class HLSLColumnMajorLayoutModifier : public ColumnMajorLayoutModifier { SLANG_AST_CLASS(HLSLColumnMajorLayoutModifier) }; @@ -474,12 +557,12 @@ class HLSLColumnMajorLayoutModifier : public ColumnMajorLayoutModifier // we actually interpret that as requesting column-major. This makes // sense because we interpret matrix conventions backwards from how // GLSL specifies them. -class GLSLRowMajorLayoutModifier : public ColumnMajorLayoutModifier +class GLSLRowMajorLayoutModifier : public ColumnMajorLayoutModifier { SLANG_AST_CLASS(GLSLRowMajorLayoutModifier) }; -class GLSLColumnMajorLayoutModifier : public RowMajorLayoutModifier +class GLSLColumnMajorLayoutModifier : public RowMajorLayoutModifier { SLANG_AST_CLASS(GLSLColumnMajorLayoutModifier) }; @@ -487,47 +570,46 @@ class GLSLColumnMajorLayoutModifier : public RowMajorLayoutModifier // More HLSL Keyword -class InterpolationModeModifier : public Modifier +class InterpolationModeModifier : public Modifier { SLANG_ABSTRACT_AST_CLASS(InterpolationModeModifier) - }; // HLSL `nointerpolation` modifier -class HLSLNoInterpolationModifier : public InterpolationModeModifier +class HLSLNoInterpolationModifier : public InterpolationModeModifier { SLANG_AST_CLASS(HLSLNoInterpolationModifier) }; // HLSL `noperspective` modifier -class HLSLNoPerspectiveModifier : public InterpolationModeModifier +class HLSLNoPerspectiveModifier : public InterpolationModeModifier { SLANG_AST_CLASS(HLSLNoPerspectiveModifier) }; // HLSL `linear` modifier -class HLSLLinearModifier : public InterpolationModeModifier +class HLSLLinearModifier : public InterpolationModeModifier { SLANG_AST_CLASS(HLSLLinearModifier) }; // HLSL `sample` modifier -class HLSLSampleModifier : public InterpolationModeModifier +class HLSLSampleModifier : public InterpolationModeModifier { SLANG_AST_CLASS(HLSLSampleModifier) }; // HLSL `centroid` modifier -class HLSLCentroidModifier : public InterpolationModeModifier +class HLSLCentroidModifier : public InterpolationModeModifier { SLANG_AST_CLASS(HLSLCentroidModifier) }; - /// Slang-defined `pervertex` modifier +/// Slang-defined `pervertex` modifier class PerVertexModifier : public InterpolationModeModifier { SLANG_AST_CLASS(PerVertexModifier) @@ -535,7 +617,7 @@ class PerVertexModifier : public InterpolationModeModifier // HLSL `precise` modifier -class PreciseModifier : public Modifier +class PreciseModifier : public Modifier { SLANG_AST_CLASS(PreciseModifier) }; @@ -543,14 +625,14 @@ class PreciseModifier : public Modifier // HLSL `shared` modifier (which is used by the effect system, // and shouldn't be confused with `groupshared`) -class HLSLEffectSharedModifier : public Modifier +class HLSLEffectSharedModifier : public Modifier { SLANG_AST_CLASS(HLSLEffectSharedModifier) }; // HLSL `groupshared` modifier -class HLSLGroupSharedModifier : public Modifier +class HLSLGroupSharedModifier : public Modifier { SLANG_AST_CLASS(HLSLGroupSharedModifier) }; @@ -558,7 +640,7 @@ class HLSLGroupSharedModifier : public Modifier // HLSL `static` modifier (probably doesn't need to be // treated as HLSL-specific) -class HLSLStaticModifier : public Modifier +class HLSLStaticModifier : public Modifier { SLANG_AST_CLASS(HLSLStaticModifier) }; @@ -566,33 +648,33 @@ class HLSLStaticModifier : public Modifier // HLSL `uniform` modifier (distinct meaning from GLSL // use of the keyword) -class HLSLUniformModifier : public Modifier +class HLSLUniformModifier : public Modifier { SLANG_AST_CLASS(HLSLUniformModifier) }; // HLSL `volatile` modifier (ignored) -class HLSLVolatileModifier : public Modifier +class HLSLVolatileModifier : public Modifier { SLANG_AST_CLASS(HLSLVolatileModifier) }; -class AttributeTargetModifier : public Modifier +class AttributeTargetModifier : public Modifier { SLANG_AST_CLASS(AttributeTargetModifier) - + // A class to which the declared attribute type is applicable SyntaxClass syntaxClass; }; // Base class for checked and unchecked `[name(arg0, ...)]` style attribute. -class AttributeBase : public Modifier +class AttributeBase : public Modifier { SLANG_AST_CLASS(AttributeBase) - + AttributeDecl* attributeDecl = nullptr; // The original identifier token representing the last part of the qualified name. @@ -603,7 +685,7 @@ class AttributeBase : public Modifier // A `[name(...)]` attribute that hasn't undergone any semantic analysis. // After analysis, this will be transformed into a more specific case. -class UncheckedAttribute : public AttributeBase +class UncheckedAttribute : public AttributeBase { SLANG_AST_CLASS(UncheckedAttribute) @@ -612,22 +694,22 @@ class UncheckedAttribute : public AttributeBase }; // A `[name(arg0, ...)]` style attribute that has been validated. -class Attribute : public AttributeBase +class Attribute : public AttributeBase { SLANG_AST_CLASS(Attribute) - + List intArgVals; }; -class UserDefinedAttribute : public Attribute +class UserDefinedAttribute : public Attribute { SLANG_AST_CLASS(UserDefinedAttribute) }; -class AttributeUsageAttribute : public Attribute +class AttributeUsageAttribute : public Attribute { SLANG_AST_CLASS(AttributeUsageAttribute) - + SyntaxClass targetSyntaxClass; }; @@ -644,10 +726,9 @@ class RequireCapabilityAttribute : public Attribute // An `[unroll]` or `[unroll(count)]` attribute -class UnrollAttribute : public Attribute +class UnrollAttribute : public Attribute { SLANG_AST_CLASS(UnrollAttribute) - }; // An `[unroll]` or `[unroll(count)]` attribute @@ -659,10 +740,10 @@ class ForceUnrollAttribute : public Attribute }; // An `[maxiters(count)]` -class MaxItersAttribute : public Attribute +class MaxItersAttribute : public Attribute { SLANG_AST_CLASS(MaxItersAttribute) - + int32_t value = 0; }; @@ -674,48 +755,48 @@ class InferredMaxItersAttribute : public Attribute int32_t value = 0; }; -class LoopAttribute : public Attribute +class LoopAttribute : public Attribute { SLANG_AST_CLASS(LoopAttribute) }; - // `[loop]` -class FastOptAttribute : public Attribute +// `[loop]` +class FastOptAttribute : public Attribute { SLANG_AST_CLASS(FastOptAttribute) }; - // `[fastopt]` -class AllowUAVConditionAttribute : public Attribute +// `[fastopt]` +class AllowUAVConditionAttribute : public Attribute { SLANG_AST_CLASS(AllowUAVConditionAttribute) }; - // `[allow_uav_condition]` -class BranchAttribute : public Attribute +// `[allow_uav_condition]` +class BranchAttribute : public Attribute { SLANG_AST_CLASS(BranchAttribute) }; - // `[branch]` -class FlattenAttribute : public Attribute +// `[branch]` +class FlattenAttribute : public Attribute { SLANG_AST_CLASS(FlattenAttribute) }; - // `[flatten]` -class ForceCaseAttribute : public Attribute +// `[flatten]` +class ForceCaseAttribute : public Attribute { SLANG_AST_CLASS(ForceCaseAttribute) }; - // `[forcecase]` -class CallAttribute : public Attribute +// `[forcecase]` +class CallAttribute : public Attribute { SLANG_AST_CLASS(CallAttribute) }; - // `[call]` +// `[call]` class UnscopedEnumAttribute : public Attribute { SLANG_AST_CLASS(UnscopedEnumAttribute) }; - // Marks a enum to have `flags` semantics, where each enum case is a bitfield. +// Marks a enum to have `flags` semantics, where each enum case is a bitfield. class FlagsAttribute : public Attribute { SLANG_AST_CLASS(FlagsAttribute); @@ -741,17 +822,17 @@ class VkConstantIdAttribute : public Attribute }; // [[vk_shader_record]] [[shader_record]] -class ShaderRecordAttribute : public Attribute +class ShaderRecordAttribute : public Attribute { SLANG_AST_CLASS(ShaderRecordAttribute) }; // [[vk_binding]] -class GLSLBindingAttribute : public Attribute +class GLSLBindingAttribute : public Attribute { SLANG_AST_CLASS(GLSLBindingAttribute) - + int32_t binding = 0; int32_t set = 0; }; @@ -766,17 +847,17 @@ class VkRestrictPointerAttribute : public Attribute SLANG_AST_CLASS(VkRestrictPointerAttribute) }; -class GLSLOffsetLayoutAttribute : public Attribute +class GLSLOffsetLayoutAttribute : public Attribute { SLANG_AST_CLASS(GLSLOffsetLayoutAttribute) int64_t offset; }; -class GLSLSimpleIntegerLayoutAttribute : public Attribute +class GLSLSimpleIntegerLayoutAttribute : public Attribute { SLANG_AST_CLASS(GLSLSimpleIntegerLayoutAttribute) - + int32_t value = 0; }; @@ -789,14 +870,14 @@ class GLSLInputAttachmentIndexLayoutAttribute : public Attribute }; // [[vk_location]] -class GLSLLocationAttribute : public GLSLSimpleIntegerLayoutAttribute +class GLSLLocationAttribute : public GLSLSimpleIntegerLayoutAttribute { SLANG_AST_CLASS(GLSLLocationAttribute) }; // [[vk_index]] -class GLSLIndexAttribute : public GLSLSimpleIntegerLayoutAttribute +class GLSLIndexAttribute : public GLSLSimpleIntegerLayoutAttribute { SLANG_AST_CLASS(GLSLIndexAttribute) }; @@ -846,49 +927,49 @@ class GLSLLayoutDerivativeGroupLinearAttribute : public Attribute // TODO: for attributes that take arguments, the syntax node // classes should provide accessors for the values of those arguments. -class MaxTessFactorAttribute : public Attribute +class MaxTessFactorAttribute : public Attribute { SLANG_AST_CLASS(MaxTessFactorAttribute) }; -class OutputControlPointsAttribute : public Attribute +class OutputControlPointsAttribute : public Attribute { SLANG_AST_CLASS(OutputControlPointsAttribute) }; -class OutputTopologyAttribute : public Attribute +class OutputTopologyAttribute : public Attribute { SLANG_AST_CLASS(OutputTopologyAttribute) }; -class PartitioningAttribute : public Attribute +class PartitioningAttribute : public Attribute { SLANG_AST_CLASS(PartitioningAttribute) }; -class PatchConstantFuncAttribute : public Attribute +class PatchConstantFuncAttribute : public Attribute { SLANG_AST_CLASS(PatchConstantFuncAttribute) - + FuncDecl* patchConstantFuncDecl = nullptr; }; -class DomainAttribute : public Attribute +class DomainAttribute : public Attribute { SLANG_AST_CLASS(DomainAttribute) }; -class EarlyDepthStencilAttribute : public Attribute +class EarlyDepthStencilAttribute : public Attribute { SLANG_AST_CLASS(EarlyDepthStencilAttribute) }; - // `[earlydepthstencil]` +// `[earlydepthstencil]` // An HLSL `[numthreads(x,y,z)]` attribute -class NumThreadsAttribute : public Attribute +class NumThreadsAttribute : public Attribute { SLANG_AST_CLASS(NumThreadsAttribute) - + // The number of threads to use along each axis // // TODO: These should be accessors that use the @@ -909,10 +990,10 @@ class WaveSizeAttribute : public Attribute IntVal* numLanes; }; -class MaxVertexCountAttribute : public Attribute +class MaxVertexCountAttribute : public Attribute { SLANG_AST_CLASS(MaxVertexCountAttribute) - + // The number of max vertex count for geometry shader // // TODO: This should be an accessor that uses the @@ -920,10 +1001,10 @@ class MaxVertexCountAttribute : public Attribute int32_t value; }; -class InstanceAttribute : public Attribute +class InstanceAttribute : public Attribute { SLANG_AST_CLASS(InstanceAttribute) - + // The number of instances to run for geometry shader // // TODO: This should be an accessor that uses the @@ -932,7 +1013,7 @@ class InstanceAttribute : public Attribute }; // A `[shader("stageName")]`/`[shader("capability")]` attribute which -// marks an entry point for compiling. This attribute also specifies +// marks an entry point for compiling. This attribute also specifies // the 'capabilities' implicitly supported by an entry point class EntryPointAttribute : public Attribute { @@ -946,13 +1027,13 @@ class EntryPointAttribute : public Attribute // core module implementation to indicate that a variable // actually represents the input/output interface for a Vulkan // ray tracing shader to pass per-ray payload information. -class VulkanRayPayloadAttribute : public Attribute +class VulkanRayPayloadAttribute : public Attribute { SLANG_AST_CLASS(VulkanRayPayloadAttribute) int location; }; -class VulkanRayPayloadInAttribute : public Attribute +class VulkanRayPayloadInAttribute : public Attribute { SLANG_AST_CLASS(VulkanRayPayloadInAttribute) @@ -963,13 +1044,13 @@ class VulkanRayPayloadInAttribute : public Attribute // core module implementation to indicate that a variable // actually represents the input/output interface for a Vulkan // ray tracing shader to pass payload information to/from a callee. -class VulkanCallablePayloadAttribute : public Attribute +class VulkanCallablePayloadAttribute : public Attribute { SLANG_AST_CLASS(VulkanCallablePayloadAttribute) int location; }; -class VulkanCallablePayloadInAttribute : public Attribute +class VulkanCallablePayloadInAttribute : public Attribute { SLANG_AST_CLASS(VulkanCallablePayloadInAttribute) @@ -980,7 +1061,7 @@ class VulkanCallablePayloadInAttribute : public Attribute // core module implementation to indicate that a variable // actually represents the output interface for a Vulkan // intersection shader to pass hit attribute information. -class VulkanHitAttributesAttribute : public Attribute +class VulkanHitAttributesAttribute : public Attribute { SLANG_AST_CLASS(VulkanHitAttributesAttribute) }; @@ -1000,7 +1081,7 @@ class VulkanHitObjectAttributesAttribute : public Attribute // function is allowed to modify things through its `this` // argument. // -class MutatingAttribute : public Attribute +class MutatingAttribute : public Attribute { SLANG_AST_CLASS(MutatingAttribute) }; @@ -1035,7 +1116,7 @@ class RefAttribute : public Attribute // reading or writing through any pointer arguments, or any other // state that could be observed by a caller. // -class ReadNoneAttribute : public Attribute +class ReadNoneAttribute : public Attribute { SLANG_AST_CLASS(ReadNoneAttribute) }; @@ -1052,32 +1133,32 @@ class GLSLRequireShaderInputParameterAttribute : public Attribute }; // HLSL modifiers for geometry shader input topology -class HLSLGeometryShaderInputPrimitiveTypeModifier : public Modifier +class HLSLGeometryShaderInputPrimitiveTypeModifier : public Modifier { SLANG_AST_CLASS(HLSLGeometryShaderInputPrimitiveTypeModifier) }; -class HLSLPointModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier +class HLSLPointModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier { SLANG_AST_CLASS(HLSLPointModifier) }; -class HLSLLineModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier +class HLSLLineModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier { SLANG_AST_CLASS(HLSLLineModifier) }; -class HLSLTriangleModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier +class HLSLTriangleModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier { SLANG_AST_CLASS(HLSLTriangleModifier) }; -class HLSLLineAdjModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier +class HLSLLineAdjModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier { SLANG_AST_CLASS(HLSLLineAdjModifier) }; -class HLSLTriangleAdjModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier +class HLSLTriangleAdjModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier { SLANG_AST_CLASS(HLSLTriangleAdjModifier) }; @@ -1112,10 +1193,10 @@ class HLSLPayloadModifier : public Modifier // A modifier to indicate that a constructor/initializer can be used // to perform implicit type conversion, and to specify the cost of // the conversion, if applied. -class ImplicitConversionModifier : public Modifier +class ImplicitConversionModifier : public Modifier { SLANG_AST_CLASS(ImplicitConversionModifier) - + // The conversion cost, used to rank conversions ConversionCost cost; @@ -1123,24 +1204,24 @@ class ImplicitConversionModifier : public Modifier BuiltinConversionKind builtinConversionKind; }; -class FormatAttribute : public Attribute +class FormatAttribute : public Attribute { SLANG_AST_CLASS(FormatAttribute) - + ImageFormat format; }; -class AllowAttribute : public Attribute +class AllowAttribute : public Attribute { SLANG_AST_CLASS(AllowAttribute) - + DiagnosticInfo const* diagnostic = nullptr; }; // A `[__extern]` attribute, which indicates that a function/type is defined externally // -class ExternAttribute : public Attribute +class ExternAttribute : public Attribute { SLANG_AST_CLASS(ExternAttribute) }; @@ -1149,7 +1230,7 @@ class ExternAttribute : public Attribute // An `[__unsafeForceInlineExternal]` attribute indicates that the callee should be inlined // into call sites after initial IR generation (that is, as early as possible). // -class UnsafeForceInlineEarlyAttribute : public Attribute +class UnsafeForceInlineEarlyAttribute : public Attribute { SLANG_AST_CLASS(UnsafeForceInlineEarlyAttribute) }; @@ -1163,30 +1244,41 @@ class ForceInlineAttribute : public Attribute }; - /// An attribute that marks a type declaration as either allowing or - /// disallowing the type to be inherited from in other modules. -class InheritanceControlAttribute : public Attribute { SLANG_AST_CLASS(InheritanceControlAttribute) }; +/// An attribute that marks a type declaration as either allowing or +/// disallowing the type to be inherited from in other modules. +class InheritanceControlAttribute : public Attribute +{ + SLANG_AST_CLASS(InheritanceControlAttribute) +}; - /// An attribute that marks a type declaration as allowing the type to be inherited from in other modules. -class OpenAttribute : public InheritanceControlAttribute { SLANG_AST_CLASS(OpenAttribute) }; +/// An attribute that marks a type declaration as allowing the type to be inherited from in other +/// modules. +class OpenAttribute : public InheritanceControlAttribute +{ + SLANG_AST_CLASS(OpenAttribute) +}; - /// An attribute that marks a type declaration as disallowing the type to be inherited from in other modules. -class SealedAttribute : public InheritanceControlAttribute { SLANG_AST_CLASS(SealedAttribute) }; +/// An attribute that marks a type declaration as disallowing the type to be inherited from in other +/// modules. +class SealedAttribute : public InheritanceControlAttribute +{ + SLANG_AST_CLASS(SealedAttribute) +}; - /// An attribute that marks a decl as a compiler built-in object. +/// An attribute that marks a decl as a compiler built-in object. class BuiltinAttribute : public Attribute { SLANG_AST_CLASS(BuiltinAttribute) }; - - /// An attribute that marks a decl as a compiler built-in object for the autodiff system. + +/// An attribute that marks a decl as a compiler built-in object for the autodiff system. class AutoDiffBuiltinAttribute : public Attribute { SLANG_AST_CLASS(AutoDiffBuiltinAttribute) }; - /// An attribute that defines the size of `AnyValue` type to represent a polymoprhic value that conforms to - /// the decorated interface type. +/// An attribute that defines the size of `AnyValue` type to represent a polymoprhic value that +/// conforms to the decorated interface type. class AnyValueSizeAttribute : public Attribute { SLANG_AST_CLASS(AnyValueSizeAttribute) @@ -1194,26 +1286,26 @@ class AnyValueSizeAttribute : public Attribute int32_t size; }; - /// This is a stop-gap solution to break overload ambiguity in the core module. - /// When there is a function overload ambiguity, the compiler will pick the one with higher rank - /// specified by this attribute. An overload without this attribute will have a rank of 0. - /// In the future, we should enhance our type system to take into account the "specialized"-ness - /// of an overload, such that `T overload1()` is more specialized than `T overload2()` - /// and preferred during overload resolution. +/// This is a stop-gap solution to break overload ambiguity in the core module. +/// When there is a function overload ambiguity, the compiler will pick the one with higher rank +/// specified by this attribute. An overload without this attribute will have a rank of 0. +/// In the future, we should enhance our type system to take into account the "specialized"-ness +/// of an overload, such that `T overload1()` is more specialized than `T +/// overload2()` and preferred during overload resolution. class OverloadRankAttribute : public Attribute { SLANG_AST_CLASS(OverloadRankAttribute) int32_t rank; }; - /// An attribute that marks an interface for specialization use only. Any operation that triggers dynamic - /// dispatch through the interface is a compile-time error. +/// An attribute that marks an interface for specialization use only. Any operation that triggers +/// dynamic dispatch through the interface is a compile-time error. class SpecializeAttribute : public Attribute { SLANG_AST_CLASS(SpecializeAttribute) }; - /// An attribute that marks a type, function or variable as differentiable. +/// An attribute that marks a type, function or variable as differentiable. class DifferentiableAttribute : public Attribute { SLANG_AST_CLASS(DifferentiableAttribute) @@ -1225,7 +1317,8 @@ class DifferentiableAttribute : public Attribute getMapTypeToIDifferentiableWitness(); if (m_mapToIDifferentiableWitness.addIfNotExists(declRef, witness)) { - m_typeToIDifferentiableWitnessMappings.add(KeyValuePair(declRef, witness)); + m_typeToIDifferentiableWitnessMappings.add( + KeyValuePair(declRef, witness)); } } @@ -1233,6 +1326,7 @@ class DifferentiableAttribute : public Attribute const OrderedDictionary& getMapTypeToIDifferentiableWitness(); SLANG_UNREFLECTED ValSet m_typeRegistrationWorkingSet; + private: OrderedDictionary m_mapToIDifferentiableWitness; }; @@ -1286,7 +1380,7 @@ class PyExportAttribute : public Attribute class PreferRecomputeAttribute : public Attribute { SLANG_AST_CLASS(PreferRecomputeAttribute) - + enum SideEffectBehavior { Warn = 0, @@ -1308,7 +1402,7 @@ class DerivativeMemberAttribute : public Attribute DeclRefExpr* memberDeclRef; }; - /// An attribute that marks an interface type as a COM interface declaration. +/// An attribute that marks an interface type as a COM interface declaration. class ComInterfaceAttribute : public Attribute { SLANG_AST_CLASS(ComInterfaceAttribute) @@ -1316,15 +1410,15 @@ class ComInterfaceAttribute : public Attribute String guid; }; - /// A `[__requiresNVAPI]` attribute indicates that the declaration being modifed - /// requires NVAPI operations for its implementation on D3D. +/// A `[__requiresNVAPI]` attribute indicates that the declaration being modifed +/// requires NVAPI operations for its implementation on D3D. class RequiresNVAPIAttribute : public Attribute { SLANG_AST_CLASS(RequiresNVAPIAttribute) }; - /// A `[RequirePrelude(target, "string")]` attribute indicates that the declaration being modifed - /// requires a textual prelude to be injected in the resulting target code. +/// A `[RequirePrelude(target, "string")]` attribute indicates that the declaration being modifed +/// requires a textual prelude to be injected in the resulting target code. class RequirePreludeAttribute : public Attribute { SLANG_AST_CLASS(RequirePreludeAttribute) @@ -1333,9 +1427,9 @@ class RequirePreludeAttribute : public Attribute String prelude; }; - /// A `[__AlwaysFoldIntoUseSite]` attribute indicates that the calls into the modified - /// function should always be folded into use sites during source emit. -class AlwaysFoldIntoUseSiteAttribute :public Attribute +/// A `[__AlwaysFoldIntoUseSite]` attribute indicates that the calls into the modified +/// function should always be folded into use sites during source emit. +class AlwaysFoldIntoUseSiteAttribute : public Attribute { SLANG_AST_CLASS(AlwaysFoldIntoUseSiteAttribute) }; @@ -1348,7 +1442,7 @@ class TreatAsDifferentiableAttribute : public DifferentiableAttribute SLANG_AST_CLASS(TreatAsDifferentiableAttribute) }; - /// The `[ForwardDifferentiable]` attribute indicates that a function can be forward-differentiated. +/// The `[ForwardDifferentiable]` attribute indicates that a function can be forward-differentiated. class ForwardDifferentiableAttribute : public DifferentiableAttribute { SLANG_AST_CLASS(ForwardDifferentiableAttribute) @@ -1361,8 +1455,8 @@ class UserDefinedDerivativeAttribute : public DifferentiableAttribute Expr* funcExpr; }; - /// The `[ForwardDerivative(function)]` attribute specifies a custom function that should - /// be used as the derivative for the decorated function. +/// The `[ForwardDerivative(function)]` attribute specifies a custom function that should +/// be used as the derivative for the decorated function. class ForwardDerivativeAttribute : public UserDefinedDerivativeAttribute { SLANG_AST_CLASS(ForwardDerivativeAttribute) @@ -1377,46 +1471,48 @@ class DerivativeOfAttribute : public DifferentiableAttribute Expr* backDeclRef; // DeclRef to this derivative function when initiated from primalFunction. }; - /// The `[ForwardDerivativeOf(primalFunction)]` attribute marks the decorated function as custom - /// derivative implementation for `primalFunction`. - /// ForwardDerivativeOfAttribute inherits from DifferentiableAttribute because a derivative - /// function itself is considered differentiable. +/// The `[ForwardDerivativeOf(primalFunction)]` attribute marks the decorated function as custom +/// derivative implementation for `primalFunction`. +/// ForwardDerivativeOfAttribute inherits from DifferentiableAttribute because a derivative +/// function itself is considered differentiable. class ForwardDerivativeOfAttribute : public DerivativeOfAttribute { SLANG_AST_CLASS(ForwardDerivativeOfAttribute) }; - /// The `[BackwardDifferentiable]` attribute indicates that a function can be backward-differentiated. +/// The `[BackwardDifferentiable]` attribute indicates that a function can be +/// backward-differentiated. class BackwardDifferentiableAttribute : public DifferentiableAttribute { SLANG_AST_CLASS(BackwardDifferentiableAttribute) int maxOrder = 0; }; - /// The `[BackwardDerivative(function)]` attribute specifies a custom function that should - /// be used as the backward-derivative for the decorated function. +/// The `[BackwardDerivative(function)]` attribute specifies a custom function that should +/// be used as the backward-derivative for the decorated function. class BackwardDerivativeAttribute : public UserDefinedDerivativeAttribute { SLANG_AST_CLASS(BackwardDerivativeAttribute) }; - /// The `[BackwardDerivativeOf(primalFunction)]` attribute marks the decorated function as custom - /// backward-derivative implementation for `primalFunction`. +/// The `[BackwardDerivativeOf(primalFunction)]` attribute marks the decorated function as custom +/// backward-derivative implementation for `primalFunction`. class BackwardDerivativeOfAttribute : public DerivativeOfAttribute { SLANG_AST_CLASS(BackwardDerivativeOfAttribute) }; - /// The `[PrimalSubstitute(function)]` attribute specifies a custom function that should - /// be used as the primal function substitute when differentiating code that calls the primal function. +/// The `[PrimalSubstitute(function)]` attribute specifies a custom function that should +/// be used as the primal function substitute when differentiating code that calls the primal +/// function. class PrimalSubstituteAttribute : public Attribute { SLANG_AST_CLASS(PrimalSubstituteAttribute) Expr* funcExpr; }; - /// The `[PrimalSubstituteOf(primalFunction)]` attribute marks the decorated function as - /// the substitute primal function in a forward or backward derivative function. +/// The `[PrimalSubstituteOf(primalFunction)]` attribute marks the decorated function as +/// the substitute primal function in a forward or backward derivative function. class PrimalSubstituteOfAttribute : public Attribute { SLANG_AST_CLASS(PrimalSubstituteOfAttribute) @@ -1425,62 +1521,62 @@ class PrimalSubstituteOfAttribute : public Attribute Expr* backDeclRef; // DeclRef to this derivative function when initiated from primalFunction. }; - /// The `[NoDiffThis]` attribute is used to specify that the `this` parameter should not be - /// included for differentiation. +/// The `[NoDiffThis]` attribute is used to specify that the `this` parameter should not be +/// included for differentiation. class NoDiffThisAttribute : public Attribute { SLANG_AST_CLASS(NoDiffThisAttribute) }; - /// Indicates that the modified declaration is one of the "magic" declarations - /// that NVAPI uses to communicate extended operations. When NVAPI is being included - /// via the prelude for downstream compilation, declarations with this modifier - /// will not be emitted, instead allowing the versions from the prelude to be used. +/// Indicates that the modified declaration is one of the "magic" declarations +/// that NVAPI uses to communicate extended operations. When NVAPI is being included +/// via the prelude for downstream compilation, declarations with this modifier +/// will not be emitted, instead allowing the versions from the prelude to be used. class NVAPIMagicModifier : public Modifier { SLANG_AST_CLASS(NVAPIMagicModifier) }; - /// A modifier that attaches to a `ModuleDecl` to indicate the register/space binding - /// that NVAPI wants to use, as indicated by, e.g., the `NV_SHADER_EXTN_SLOT` and - /// `NV_SHADER_EXTN_REGISTER_SPACE` preprocessor definitions. +/// A modifier that attaches to a `ModuleDecl` to indicate the register/space binding +/// that NVAPI wants to use, as indicated by, e.g., the `NV_SHADER_EXTN_SLOT` and +/// `NV_SHADER_EXTN_REGISTER_SPACE` preprocessor definitions. class NVAPISlotModifier : public Modifier { SLANG_AST_CLASS(NVAPISlotModifier) - /// The name of the register that is to be used (e.g., `"u3"`) - /// - /// This value will come from the `NV_SHADER_EXTN_SLOT` macro, if set. - /// - /// The `registerName` field must always be filled in when adding - /// an `NVAPISlotModifier` to a module; if no register name is defined, - /// then the modifier should not be added. - /// + /// The name of the register that is to be used (e.g., `"u3"`) + /// + /// This value will come from the `NV_SHADER_EXTN_SLOT` macro, if set. + /// + /// The `registerName` field must always be filled in when adding + /// an `NVAPISlotModifier` to a module; if no register name is defined, + /// then the modifier should not be added. + /// String registerName; - /// The name of the register space to be used (e.g., `space1`) - /// - /// This value will come from the `NV_SHADER_EXTN_REGISTER_SPACE` macro, - /// if set. - /// - /// It is valid for a user to specify a register name but not a space name, - /// and in that case `spaceName` will be set to `"space0"`. + /// The name of the register space to be used (e.g., `space1`) + /// + /// This value will come from the `NV_SHADER_EXTN_REGISTER_SPACE` macro, + /// if set. + /// + /// It is valid for a user to specify a register name but not a space name, + /// and in that case `spaceName` will be set to `"space0"`. String spaceName; }; - /// A `[noinline]` attribute represents a request by the application that, - /// to the extent possible, a function should not be inlined into call sites. - /// - /// Note that due to various limitations of different targets, it is entirely - /// possible for such functions to be inlined or specialized to call sites. - /// +/// A `[noinline]` attribute represents a request by the application that, +/// to the extent possible, a function should not be inlined into call sites. +/// +/// Note that due to various limitations of different targets, it is entirely +/// possible for such functions to be inlined or specialized to call sites. +/// class NoInlineAttribute : public Attribute { SLANG_AST_CLASS(NoInlineAttribute) }; - /// A `[noRefInline]` attribute represents a request to not force inline a - /// function specifically due to a refType parameter. +/// A `[noRefInline]` attribute represents a request to not force inline a +/// function specifically due to a refType parameter. class NoRefInlineAttribute : public Attribute { SLANG_AST_CLASS(NoRefInlineAttribute) @@ -1496,21 +1592,21 @@ class DerivativeGroupLinearAttribute : public Attribute SLANG_AST_CLASS(DerivativeGroupLinearAttribute) }; - /// A `[payload]` attribute indicates that a `struct` type will be used as - /// a ray payload for `TraceRay()` calls, and thus also as input/output - /// for shaders in the ray tracing pipeline that might be invoked for - /// such a ray. - /// +/// A `[payload]` attribute indicates that a `struct` type will be used as +/// a ray payload for `TraceRay()` calls, and thus also as input/output +/// for shaders in the ray tracing pipeline that might be invoked for +/// such a ray. +/// class PayloadAttribute : public Attribute { SLANG_AST_CLASS(PayloadAttribute) }; - /// A `[deprecated("message")]` attribute indicates the target is - /// deprecated. - /// A compiler warning including the message will be raised if the - /// deprecated value is used. - /// +/// A `[deprecated("message")]` attribute indicates the target is +/// deprecated. +/// A compiler warning including the message will be raised if the +/// deprecated value is used. +/// class DeprecatedAttribute : public Attribute { SLANG_AST_CLASS(DeprecatedAttribute) @@ -1528,10 +1624,10 @@ class NoSideEffectAttribute : public Attribute SLANG_AST_CLASS(NoSideEffectAttribute) }; - /// A `[KnownBuiltin("name")]` attribute allows the compiler to - /// identify this declaration during compilation, despite obfuscation or - /// linkage removing optimizations - /// +/// A `[KnownBuiltin("name")]` attribute allows the compiler to +/// identify this declaration during compilation, despite obfuscation or +/// linkage removing optimizations +/// class KnownBuiltinAttribute : public Attribute { SLANG_AST_CLASS(KnownBuiltinAttribute) @@ -1539,48 +1635,48 @@ class KnownBuiltinAttribute : public Attribute String name; }; - /// A modifier that applies to types rather than declarations. - /// - /// In most cases, the Slang compiler assumes that a modifier should - /// inhere to a declaration. Given input like: - /// - /// mod1 mod2 int myVar = ...; - /// - /// The default assumption is that `mod1` and `mod2` apply to `myVar` - /// and *not* to the `int` type. - /// - /// In order to allow modifiers to inhere to the type instead, we introduce - /// a base class for modifiers that really don't want to belong to the declaration, - /// and instead want to belong to the type (or rather the type *specifier* - /// from a parsing standpoint). - /// +/// A modifier that applies to types rather than declarations. +/// +/// In most cases, the Slang compiler assumes that a modifier should +/// inhere to a declaration. Given input like: +/// +/// mod1 mod2 int myVar = ...; +/// +/// The default assumption is that `mod1` and `mod2` apply to `myVar` +/// and *not* to the `int` type. +/// +/// In order to allow modifiers to inhere to the type instead, we introduce +/// a base class for modifiers that really don't want to belong to the declaration, +/// and instead want to belong to the type (or rather the type *specifier* +/// from a parsing standpoint). +/// class TypeModifier : public Modifier { SLANG_AST_CLASS(TypeModifier) }; - /// A kind of syntax element which appears as a modifier in the syntax, but - /// we represent as a function over type expressions +/// A kind of syntax element which appears as a modifier in the syntax, but +/// we represent as a function over type expressions class WrappingTypeModifier : public TypeModifier { SLANG_AST_CLASS(WrappingTypeModifier) }; - /// A modifier that applies to a type and implies information about the - /// underlying format of a resource that uses that type as its element type. - /// +/// A modifier that applies to a type and implies information about the +/// underlying format of a resource that uses that type as its element type. +/// class ResourceElementFormatModifier : public TypeModifier { SLANG_AST_CLASS(ResourceElementFormatModifier) }; - /// HLSL `unorm` modifier +/// HLSL `unorm` modifier class UNormModifier : public ResourceElementFormatModifier { SLANG_AST_CLASS(UNormModifier) }; - /// HLSL `snorm` modifier +/// HLSL `snorm` modifier class SNormModifier : public ResourceElementFormatModifier { SLANG_AST_CLASS(SNormModifier) @@ -1657,12 +1753,12 @@ public: { enum MemoryQualifiersBit { - kNone = 0b0, - kCoherent = 0b1, - kReadOnly = 0b10, - kWriteOnly = 0b100, - kVolatile = 0b1000, - kRestrict = 0b10000, + kNone = 0b0, + kCoherent = 0b1, + kReadOnly = 0b10, + kWriteOnly = 0b100, + kVolatile = 0b1000, + kRestrict = 0b10000, kRasterizerOrdered = 0b100000, }; }; diff --git a/source/slang/slang-ast-natural-layout.cpp b/source/slang/slang-ast-natural-layout.cpp index 4a4ef37fb..8bfc5f8ce 100644 --- a/source/slang/slang-ast-natural-layout.cpp +++ b/source/slang/slang-ast-natural-layout.cpp @@ -25,15 +25,15 @@ NaturalSize NaturalSize::operator*(Count count) const // If the count is 0, in effect the result doesn't take up any space return makeEmpty(); } - else + else { - // We don't want to produce an aligned size, as we allow the last element to not + // We don't want to produce an aligned size, as we allow the last element to not // take up a whole stride (only up to size) return make(size + (getStride() * (count - 1)), alignment); } } -/* static */NaturalSize NaturalSize::makeFromBaseType(BaseType baseType) +/* static */ NaturalSize NaturalSize::makeFromBaseType(BaseType baseType) { // Special case void if (baseType == BaseType::Void) @@ -49,7 +49,7 @@ NaturalSize NaturalSize::operator*(Count count) const } } -/* static */NaturalSize NaturalSize::calcUnion(NaturalSize a, NaturalSize b) +/* static */ NaturalSize NaturalSize::calcUnion(NaturalSize a, NaturalSize b) { const auto alignment = maxAlignment(a.alignment, b.alignment); Count size = (alignment == kInvalidAlignment) ? 0 : Math::Max(a.size, b.size); @@ -58,9 +58,8 @@ NaturalSize NaturalSize::operator*(Count count) const /* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ASTNaturalLayoutContext !!!!!!!!!!!!!!!!!!!!!!!!!!!! */ -ASTNaturalLayoutContext::ASTNaturalLayoutContext(ASTBuilder* astBuilder, DiagnosticSink* sink): - m_astBuilder(astBuilder), - m_sink(sink) +ASTNaturalLayoutContext::ASTNaturalLayoutContext(ASTBuilder* astBuilder, DiagnosticSink* sink) + : m_astBuilder(astBuilder), m_sink(sink) { // A null type always maps to invalid m_typeToSize.add(nullptr, NaturalSize::makeInvalid()); @@ -94,10 +93,12 @@ NaturalSize ASTNaturalLayoutContext::calcSize(Type* type) // Calc the size const NaturalSize size = _calcSizeImpl(type); - // We want to add to the cache, but we need to special case - // in case there is an aggregate type that `poisoned` the cache entry, to stop infinite recursion. - // - // A requirement is that when the agg type completes it must set the cache entry, and return the same result. + // We want to add to the cache, but we need to special case + // in case there is an aggregate type that `poisoned` the cache entry, to stop infinite + // recursion. + // + // A requirement is that when the agg type completes it must set the cache entry, and return the + // same result. if (auto foundSize = m_typeToSize.tryGetValueOrAdd(type, size)) { // If there is a found size, it must match. If not we update the state as invalid. @@ -116,17 +117,16 @@ NaturalSize ASTNaturalLayoutContext::_calcSizeImpl(Type* type) if (VectorExpressionType* vecType = as(type)) { const Count elementCount = _getCount(vecType->getElementCount()); - return (elementCount > 0) ? - calcSize(vecType->getElementType()) * elementCount : - NaturalSize::makeInvalid(); + return (elementCount > 0) ? calcSize(vecType->getElementType()) * elementCount + : NaturalSize::makeInvalid(); } else if (auto matType = as(type)) { const Count colCount = _getCount(matType->getColumnCount()); const Count rowCount = _getCount(matType->getRowCount()); - return (colCount > 0 && rowCount > 0) ? - calcSize(matType->getElementType()) * (colCount * rowCount) : - NaturalSize::makeInvalid(); + return (colCount > 0 && rowCount > 0) + ? calcSize(matType->getElementType()) * (colCount * rowCount) + : NaturalSize::makeInvalid(); } else if (auto basicType = as(type)) { @@ -140,9 +140,8 @@ NaturalSize ASTNaturalLayoutContext::_calcSizeImpl(Type* type) else if (auto arrayType = as(type)) { const Count elementCount = _getCount(arrayType->getElementCount()); - return (elementCount > 0) ? - calcSize(arrayType->getElementType()) * elementCount : - NaturalSize::makeInvalid(); + return (elementCount > 0) ? calcSize(arrayType->getElementType()) * elementCount + : NaturalSize::makeInvalid(); } else if (auto namedType = as(type)) { @@ -166,14 +165,14 @@ NaturalSize ASTNaturalLayoutContext::_calcSizeImpl(Type* type) return size; } - else if( auto declRefType = as(type) ) + else if (auto declRefType = as(type)) { if (const auto enumDeclRef = declRefType->getDeclRef().as()) { Type* tagType = getTagType(m_astBuilder, enumDeclRef); return calcSize(tagType); } - else if(const auto structDeclRef = declRefType->getDeclRef().as()) + else if (const auto structDeclRef = declRefType->getDeclRef().as()) { // Poison the cache whilst we construct m_typeToSize.add(type, NaturalSize::makeInvalid()); @@ -213,7 +212,7 @@ NaturalSize ASTNaturalLayoutContext::_calcSizeImpl(Type* type) // Set the cached result to the size. m_typeToSize.set(type, size); - return size; + return size; } else if (const auto typeDef = declRefType->getDeclRef().as()) { diff --git a/source/slang/slang-ast-natural-layout.h b/source/slang/slang-ast-natural-layout.h index 4a165973d..f38bf1a95 100644 --- a/source/slang/slang-ast-natural-layout.h +++ b/source/slang/slang-ast-natural-layout.h @@ -10,39 +10,44 @@ struct NaturalSize { typedef NaturalSize ThisType; - // We are going to use 0 as invalid for alignment. This has a few nice propeties - // - // * Will naturally produce 0 size when used with `calcAligned` operation - // * Is fast to test - // * Is easy to make a fast 'max' such that a max with invalid always returns `invalid` - // - // We also desire that when invalid the `size` member is 0. - // This is so that equality testing doesn't require anything special. - SLANG_FORCE_INLINE static Count calcAligned(Count size, Count alignment) { return (size + alignment - 1) & ~(alignment - 1); } - // Use to get the max of two alignments. Uses some maths such that `invalid` is always max - SLANG_FORCE_INLINE static Count maxAlignment(Count a, Count b) { return (UCount(a) - 1) > (UCount(b) - 1) ? a : b; } - - /// Given two sizes, returns a result that can hold the union. + // We are going to use 0 as invalid for alignment. This has a few nice propeties + // + // * Will naturally produce 0 size when used with `calcAligned` operation + // * Is fast to test + // * Is easy to make a fast 'max' such that a max with invalid always returns `invalid` + // + // We also desire that when invalid the `size` member is 0. + // This is so that equality testing doesn't require anything special. + SLANG_FORCE_INLINE static Count calcAligned(Count size, Count alignment) + { + return (size + alignment - 1) & ~(alignment - 1); + } + // Use to get the max of two alignments. Uses some maths such that `invalid` is always max + SLANG_FORCE_INLINE static Count maxAlignment(Count a, Count b) + { + return (UCount(a) - 1) > (UCount(b) - 1) ? a : b; + } + + /// Given two sizes, returns a result that can hold the union. static NaturalSize calcUnion(NaturalSize a, NaturalSize b); - /// Value chosen such that normal combining operations produce an invalid result - /// as typically a max. + /// Value chosen such that normal combining operations produce an invalid result + /// as typically a max. static const Count kInvalidAlignment = 0; - /// Get the stride, which is equivalent to the size aligned + /// Get the stride, which is equivalent to the size aligned SLANG_FORCE_INLINE Count getStride() const { return calcAligned(size, alignment); } - /// Append rhs to this. - /// If rhs is invalid or this is the result will also be invalid + /// Append rhs to this. + /// If rhs is invalid or this is the result will also be invalid void append(const ThisType& rhs) { const auto newAlignment = maxAlignment(alignment, rhs.alignment); // If the new alignment is valid we calculate the size, else it's 0 - size = (newAlignment != kInvalidAlignment) ? - (calcAligned(size, rhs.alignment) + rhs.size) : - 0; - + size = + (newAlignment != kInvalidAlignment) ? (calcAligned(size, rhs.alignment) + rhs.size) : 0; + // Set the new alignment alignment = newAlignment; } @@ -50,46 +55,48 @@ struct NaturalSize SLANG_FORCE_INLINE bool isInvalid() const { return alignment == kInvalidAlignment; } SLANG_FORCE_INLINE bool isValid() const { return !isInvalid(); } - bool operator==(const ThisType& rhs) const { return size == rhs.size && alignment == rhs.alignment; } + bool operator==(const ThisType& rhs) const + { + return size == rhs.size && alignment == rhs.alignment; + } bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } - /// Converts to bool to make testing convenient + /// Converts to bool to make testing convenient operator bool() const { return isValid(); } - /// An empty size. It consumes 0 bytes and has the lowest alignment (1) - static ThisType makeEmpty() { return ThisType{ 0, 1 }; } - /// Make an invalid size. - static ThisType makeInvalid() { return ThisType{ 0, kInvalidAlignment }; } - /// Make a size with an amount of bytes and the alignment + /// An empty size. It consumes 0 bytes and has the lowest alignment (1) + static ThisType makeEmpty() { return ThisType{0, 1}; } + /// Make an invalid size. + static ThisType makeInvalid() { return ThisType{0, kInvalidAlignment}; } + /// Make a size with an amount of bytes and the alignment static ThisType make(Count size, Count alignment) { return ThisType{size, alignment}; } - /// Given a base type returns it's size + /// Given a base type returns it's size static ThisType makeFromBaseType(BaseType baseType); - /// Multiply by a count. - /// Will return invalid if count < 0 or this is already invalid + /// Multiply by a count. + /// Will return invalid if count < 0 or this is already invalid ThisType operator*(Count count) const; Count size; - Count alignment; + Count alignment; }; struct ASTNaturalLayoutContext { - /// Given a type returns it's natural size. - /// Returns invalid size if types size could not be calculated + /// Given a type returns it's natural size. + /// Returns invalid size if types size could not be calculated NaturalSize calcSize(Type* type); - /// Ctor + /// Ctor ASTNaturalLayoutContext(ASTBuilder* astBuilder, DiagnosticSink* sink = nullptr); - -protected: - /// Gets a count (positivie integer including 0). - /// <0 indicates error +protected: + /// Gets a count (positivie integer including 0). + /// <0 indicates error Count _getCount(IntVal* intVal); - - /// The main implementation, assumes outer `calcSize` will perform caching + + /// The main implementation, assumes outer `calcSize` will perform caching NaturalSize _calcSizeImpl(Type* type); Dictionary m_typeToSize; diff --git a/source/slang/slang-ast-print.cpp b/source/slang/slang-ast-print.cpp index f3ed4ee0b..8900706fa 100644 --- a/source/slang/slang-ast-print.cpp +++ b/source/slang/slang-ast-print.cpp @@ -2,10 +2,10 @@ #include "slang-ast-print.h" #include "core/slang-char-util.h" - #include "slang-check-impl.h" -namespace Slang { +namespace Slang +{ ASTPrinter::Part::Kind ASTPrinter::Part::getKind(ASTPrinter::Part::Type type) { @@ -14,14 +14,14 @@ ASTPrinter::Part::Kind ASTPrinter::Part::getKind(ASTPrinter::Part::Type type) switch (type) { - case Type::ParamType: return Kind::Type; - case Type::ParamName: return Kind::Name; - case Type::ReturnType: return Kind::Type; - case Type::DeclPath: return Kind::Name; - case Type::GenericParamType: return Kind::Type; - case Type::GenericParamValue: return Kind::Value; - case Type::GenericParamValueType: return Kind::Type; - default: break; + case Type::ParamType: return Kind::Type; + case Type::ParamName: return Kind::Name; + case Type::ReturnType: return Kind::Type; + case Type::DeclPath: return Kind::Name; + case Type::GenericParamType: return Kind::Type; + case Type::GenericParamValue: return Kind::Value; + case Type::GenericParamValueType: return Kind::Type; + default: break; } return Kind::None; } @@ -71,7 +71,7 @@ void ASTPrinter::addVal(Val* val) val->toText(m_builder); } -/* static */void ASTPrinter::appendDeclName(Decl* decl, StringBuilder& out) +/* static */ void ASTPrinter::appendDeclName(Decl* decl, StringBuilder& out) { if (as(decl)) { @@ -132,14 +132,15 @@ void ASTPrinter::_addDeclPathRec(const DeclRef& declRef, Index depth) else if (auto namespaceDeclRef = parentDeclRef.as()) { _addDeclPathRec(namespaceDeclRef, depth + 1); - // Hmm, it could be argued that we follow the . as seen in AggType as is followed in some other languages - // like Java. - // That it is useful to have a distinction between something that is a member/method and something that is - // in a scope (such as a namespace), and is something that has returned to later languages probably for that - // reason (Slang accepts . or ::). So for now this is follows the :: convention. + // Hmm, it could be argued that we follow the . as seen in AggType as is followed in some + // other languages like Java. That it is useful to have a distinction between something that + // is a member/method and something that is in a scope (such as a namespace), and is + // something that has returned to later languages probably for that reason (Slang accepts . + // or ::). So for now this is follows the :: convention. // - // It could be argued them that the previous '.' use should vary depending on that distinction. - + // It could be argued them that the previous '.' use should vary depending on that + // distinction. + sb << toSlice("::"); } else if (auto extensionDeclRef = parentDeclRef.as()) @@ -150,7 +151,9 @@ void ASTPrinter::_addDeclPathRec(const DeclRef& declRef, Index depth) { if (auto unspecializedDeclRef = isDeclRefTypeOf(type)) { - type = DeclRefType::create(m_astBuilder, unspecializedDeclRef.getDecl()->getDefaultDeclRef()); + type = DeclRefType::create( + m_astBuilder, + unspecializedDeclRef.getDecl()->getDefaultDeclRef()); } } addType(type); @@ -172,7 +175,7 @@ void ASTPrinter::_addDeclPathRec(const DeclRef& declRef, Index depth) Name* moduleName = moduleDecl->getName(); if ((m_optionFlags & OptionFlag::ModuleName) && moduleName) { - sb << moduleName->text; + sb << moduleName->text; } return; } @@ -181,11 +184,11 @@ void ASTPrinter::_addDeclPathRec(const DeclRef& declRef, Index depth) // If the parent declaration is a generic, then we need to print out its // signature - if (parentGenericDeclRef && - !declRef.as() && + if (parentGenericDeclRef && !declRef.as() && !declRef.as()) { - auto substArgs = tryGetGenericArguments(SubstitutionSet(declRef), parentGenericDeclRef.getDecl()); + auto substArgs = + tryGetGenericArguments(SubstitutionSet(declRef), parentGenericDeclRef.getDecl()); if (substArgs.getCount()) { // If the name we printed previously was an operator @@ -214,7 +217,8 @@ void ASTPrinter::_addDeclPathRec(const DeclRef& declRef, Index depth) if (as(arg)) continue; - if (!first) sb << ", "; + if (!first) + sb << ", "; addVal(arg); first = false; } @@ -238,7 +242,8 @@ void ASTPrinter::addGenericParams(const DeclRef& genericDeclRef) { if (auto genericTypeParam = paramDeclRef.as()) { - if (!first) sb << ", "; + if (!first) + sb << ", "; first = false; { @@ -248,7 +253,8 @@ void ASTPrinter::addGenericParams(const DeclRef& genericDeclRef) } else if (auto genericValParam = paramDeclRef.as()) { - if (!first) sb << ", "; + if (!first) + sb << ", "; first = false; { @@ -265,7 +271,8 @@ void ASTPrinter::addGenericParams(const DeclRef& genericDeclRef) } else if (auto genericTypePackParam = paramDeclRef.as()) { - if (!first) sb << ", "; + if (!first) + sb << ", "; first = false; { ScopePart scopePart(this, Part::Type::GenericParamType); @@ -299,7 +306,8 @@ void ASTPrinter::addDeclParams(const DeclRef& declRef, List>* auto addParamElement = [&](Type* type, Index elementIndex) { - if (!first) sb << ", "; + if (!first) + sb << ", "; // Type part. { @@ -350,7 +358,8 @@ void ASTPrinter::addDeclParams(const DeclRef& declRef, List>* }; if (auto typePack = as(paramType)) { - for (Index elementIndex = 0; elementIndex < typePack->getTypeCount(); ++elementIndex) + for (Index elementIndex = 0; elementIndex < typePack->getTypeCount(); + ++elementIndex) { addParamElement(typePack->getElementType(elementIndex), elementIndex); } @@ -367,7 +376,9 @@ void ASTPrinter::addDeclParams(const DeclRef& declRef, List>* { addGenericParams(genericDeclRef); - addDeclParams(m_astBuilder->getMemberDeclRef(genericDeclRef, genericDeclRef.getDecl()->inner), outParamRange); + addDeclParams( + m_astBuilder->getMemberDeclRef(genericDeclRef, genericDeclRef.getDecl()->inner), + outParamRange); } else { @@ -493,7 +504,8 @@ void ASTPrinter::addDeclResultType(const DeclRef& inDeclRef) DeclRef declRef = inDeclRef; if (auto genericDeclRef = declRef.as()) { - declRef = m_astBuilder->getMemberDeclRef(genericDeclRef, genericDeclRef.getDecl()->inner); + declRef = + m_astBuilder->getMemberDeclRef(genericDeclRef, genericDeclRef.getDecl()->inner); } if (declRef.as()) @@ -518,7 +530,7 @@ void ASTPrinter::addDeclResultType(const DeclRef& inDeclRef) } } -/* static */void ASTPrinter::addDeclSignature(const DeclRef& declRef) +/* static */ void ASTPrinter::addDeclSignature(const DeclRef& declRef) { addDeclKindPrefix(declRef.getDecl()); addDeclPath(declRef); @@ -526,7 +538,9 @@ void ASTPrinter::addDeclResultType(const DeclRef& inDeclRef) addDeclResultType(declRef); } -/* static */String ASTPrinter::getDeclSignatureString(DeclRef declRef, ASTBuilder* astBuilder) +/* static */ String ASTPrinter::getDeclSignatureString( + DeclRef declRef, + ASTBuilder* astBuilder) { ASTPrinter astPrinter( astBuilder, @@ -535,14 +549,20 @@ void ASTPrinter::addDeclResultType(const DeclRef& inDeclRef) return astPrinter.getString(); } -/* static */String ASTPrinter::getDeclSignatureString(const LookupResultItem& item, ASTBuilder* astBuilder) +/* static */ String ASTPrinter::getDeclSignatureString( + const LookupResultItem& item, + ASTBuilder* astBuilder) { return getDeclSignatureString(item.declRef, astBuilder); } -/* static */UnownedStringSlice ASTPrinter::getPart(Part::Type partType, const UnownedStringSlice& slice, const List& parts) +/* static */ UnownedStringSlice ASTPrinter::getPart( + Part::Type partType, + const UnownedStringSlice& slice, + const List& parts) { - const Index index = parts.findFirstIndex([&](const Part& part) -> bool { return part.type == partType; }); + const Index index = + parts.findFirstIndex([&](const Part& part) -> bool { return part.type == partType; }); return index >= 0 ? getPart(slice, parts[index]) : UnownedStringSlice(); } diff --git a/source/slang/slang-ast-print.h b/source/slang/slang-ast-print.h index 0befc89d6..c7b20b26b 100644 --- a/source/slang/slang-ast-print.h +++ b/source/slang/slang-ast-print.h @@ -6,7 +6,8 @@ #include "../core/slang-range.h" #include "slang-ast-all.h" -namespace Slang { +namespace Slang +{ class ASTPrinter { @@ -16,10 +17,12 @@ public: { enum Enum : OptionFlags { - ParamNames = 0x01, ///< If set will output parameter names - ModuleName = 0x02, ///< Writes out module names - NoInternalKeywords = 0x04, ///< Omits internal decoration keywords (e.g. __target_intrinsic). - SimplifiedBuiltinType = 0x08, ///< Prints simplified builtin generic types (e.g. float3) instead of its generic form. + ParamNames = 0x01, ///< If set will output parameter names + ModuleName = 0x02, ///< Writes out module names + NoInternalKeywords = + 0x04, ///< Omits internal decoration keywords (e.g. __target_intrinsic). + SimplifiedBuiltinType = 0x08, ///< Prints simplified builtin generic types (e.g. float3) + ///< instead of its generic form. /// Use the original generic type name instead of the specialized /// type name defined on an extension when @@ -31,12 +34,14 @@ public: /// Note that we could/can have a hierarchy of Parts - with overlapping spans. /// Moreover we could have less kinds, if we used the overlaps to signal out sections /// - /// For example we could have a 'Param', 'Generic' span, and then have 'Name', 'Type' and 'Value'. - /// So a param type, would be the 'Type' defined in a Param span. Moreover you could have the hierarchy of Types, and then - /// such that you can pull out specific parts that make up a type. + /// For example we could have a 'Param', 'Generic' span, and then have 'Name', 'Type' and + /// 'Value'. So a param type, would be the 'Type' defined in a Param span. Moreover you could + /// have the hierarchy of Types, and then such that you can pull out specific parts that make up + /// a type. /// - /// This is powerful/flexible - but requires more complexity at the use sites, so for now we use this simpler mechanism. - + /// This is powerful/flexible - but requires more complexity at the use sites, so for now we use + /// this simpler mechanism. + /// Defines part of the structure of the output printed. struct Part { @@ -51,17 +56,24 @@ public: enum class Type { None, - ParamType, ///< The type associated with a parameter - ParamName, ///< The name associated with a parameter - ReturnType, ///< The return type - DeclPath, ///< The declaration path (NOT including the actual decl name) - GenericParamType, ///< Generic parameter type - GenericParamValue, ///< Generic parameter value - GenericParamValueType, ///< The type requirement for a value type + ParamType, ///< The type associated with a parameter + ParamName, ///< The name associated with a parameter + ReturnType, ///< The return type + DeclPath, ///< The declaration path (NOT including the actual decl name) + GenericParamType, ///< Generic parameter type + GenericParamValue, ///< Generic parameter value + GenericParamValueType, ///< The type requirement for a value type }; static Kind getKind(Type type); - static Part make(Type type, Index start, Index end) { Part part; part.type = type; part.start = start; part.end = end; return part; } + static Part make(Type type, Index start, Index end) + { + Part part; + part.type = type; + part.start = start; + part.end = end; + return part; + } Type type = Type::None; Index start; @@ -76,10 +88,8 @@ public: struct ScopePart { - ScopePart(ASTPrinter* printer, Part::Type type): - m_printer(printer), - m_type(type), - m_startIndex(printer->m_builder.getLength()) + ScopePart(ASTPrinter* printer, Part::Type type) + : m_printer(printer), m_type(type), m_startIndex(printer->m_builder.getLength()) { } ~ScopePart() @@ -96,67 +106,74 @@ public: ASTPrinter* m_printer; }; - /// We might want options to change how things are output, for example we may want to output parameter names - /// if there are any + /// We might want options to change how things are output, for example we may want to output + /// parameter names if there are any - /// Get the currently built up string + /// Get the currently built up string StringBuilder& getStringBuilder() { return m_builder; } - /// Get the current offset, for the end of the string builder - useful for building up ranges + /// Get the current offset, for the end of the string builder - useful for building up ranges Index getOffset() const { return m_builder.getLength(); } - /// Reset the state + /// Reset the state void reset() { m_builder.clear(); } - /// Get the current string + /// Get the current string String getString() { return m_builder.produceString(); } - /// Get contents as a slice + /// Get contents as a slice UnownedStringSlice getSlice() const { return m_builder.getUnownedSlice(); } - /// Add a type + /// Add a type void addType(Type* type); - /// Add a value + /// Add a value void addVal(Val* val); - /// Add the path to the declaration including the declaration name + /// Add the path to the declaration including the declaration name void addDeclPath(const DeclRef& declRef); - /// Add the path such that it encapsulates all overridable decls (ie is without terminal generic parameters) + /// Add the path such that it encapsulates all overridable decls (ie is without terminal generic + /// parameters) void addOverridableDeclPath(const DeclRef& declRef); - /// Add just the parameters from a declaration. - /// Will output the generic parameters (if it's a generic) in <> before the parameters () + /// Add just the parameters from a declaration. + /// Will output the generic parameters (if it's a generic) in <> before the parameters () void addDeclParams(const DeclRef& declRef, List>* outParamRange = nullptr); - /// Add a prefix that describes the kind of declaration + /// Add a prefix that describes the kind of declaration void addDeclKindPrefix(Decl* decl); - /// Add the result type - /// Should be called after the decl params + /// Add the result type + /// Should be called after the decl params void addDeclResultType(const DeclRef& inDeclRef); - /// Add the signature for the decl + /// Add the signature for the decl void addDeclSignature(const DeclRef& declRef); - /// Add generic parameters + /// Add generic parameters void addGenericParams(const DeclRef& genericDeclRef); - /// Get the specified part type. Returns empty slice if not found + /// Get the specified part type. Returns empty slice if not found UnownedStringSlice getPartSlice(Part::Type partType) const; - /// Get the slice for a part + /// Get the slice for a part UnownedStringSlice getPartSlice(const Part& part) const { return getPart(getSlice(), part); } - /// Gets the specified part type - static UnownedStringSlice getPart(const UnownedStringSlice& slice, const Part& part) { return (part.type != Part::Type::None) ? UnownedStringSlice(slice.begin() + part.start, slice.begin() + part.end) : UnownedStringSlice(); } - static UnownedStringSlice getPart(Part::Type partType, const UnownedStringSlice& slice, const List& parts); + /// Gets the specified part type + static UnownedStringSlice getPart(const UnownedStringSlice& slice, const Part& part) + { + return (part.type != Part::Type::None) + ? UnownedStringSlice(slice.begin() + part.start, slice.begin() + part.end) + : UnownedStringSlice(); + } + static UnownedStringSlice getPart( + Part::Type partType, + const UnownedStringSlice& slice, + const List& parts); static void appendDeclName(Decl* decl, StringBuilder& out); - /// Ctor - ASTPrinter(ASTBuilder* astBuilder, OptionFlags optionFlags = 0, List* parts = nullptr): - m_astBuilder(astBuilder), - m_parts(parts), - m_optionFlags(optionFlags) + /// Ctor + ASTPrinter(ASTBuilder* astBuilder, OptionFlags optionFlags = 0, List* parts = nullptr) + : m_astBuilder(astBuilder), m_parts(parts), m_optionFlags(optionFlags) { } @@ -164,14 +181,13 @@ public: static String getDeclSignatureString(DeclRef declRef, ASTBuilder* astBuilder); protected: - void _addDeclPathRec(const DeclRef& declRef, Index depth); void _addDeclName(Decl* decl); - - OptionFlags m_optionFlags; ///< Flags controlling output - List* m_parts; ///< Optional parts list - ASTBuilder* m_astBuilder; ///< Required as types are setup as part of printing - StringBuilder m_builder; ///< The output of the 'printing' process + + OptionFlags m_optionFlags; ///< Flags controlling output + List* m_parts; ///< Optional parts list + ASTBuilder* m_astBuilder; ///< Required as types are setup as part of printing + StringBuilder m_builder; ///< The output of the 'printing' process }; } // namespace Slang diff --git a/source/slang/slang-ast-reflect.cpp b/source/slang/slang-ast-reflect.cpp index a91ec0618..c6c51cb73 100644 --- a/source/slang/slang-ast-reflect.cpp +++ b/source/slang/slang-ast-reflect.cpp @@ -1,39 +1,35 @@ -#include "slang.h" - #include "slang-ast-reflect.h" #include "../core/slang-smart-pointer.h" - #include "slang-ast-all.h" - -#include -#include - +#include "slang-generated-ast-macro.h" #include "slang-visitor.h" +#include "slang.h" -#include "slang-generated-ast-macro.h" +#include +#include namespace Slang { -#define SLANG_REFLECT_GET_REFLECT_CLASS_INFO(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) infos.infos[int(ASTNodeType::NAME)] = &NAME::kReflectClassInfo; +#define SLANG_REFLECT_GET_REFLECT_CLASS_INFO(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ + infos.infos[int(ASTNodeType::NAME)] = &NAME::kReflectClassInfo; static ASTClassInfo::Infos _calcInfos() { ASTClassInfo::Infos infos; memset(&infos, 0, sizeof(infos)); - SLANG_ALL_ASTNode_NodeBase(SLANG_REFLECT_GET_REFLECT_CLASS_INFO, _) - return infos; + SLANG_ALL_ASTNode_NodeBase(SLANG_REFLECT_GET_REFLECT_CLASS_INFO, _) return infos; } -/* static */const ASTClassInfo::Infos ASTClassInfo::kInfos = _calcInfos(); +/* static */ const ASTClassInfo::Infos ASTClassInfo::kInfos = _calcInfos(); // Now try and implement all of the classes // Macro generated is of the format struct ASTConstructAccess { - template + template struct Impl { static void* create(void* context) @@ -61,25 +57,34 @@ struct ASTConstructAccess #define SLANG_GET_DESTROY_FUNC_AST(NAME) &ASTConstructAccess::Impl::destroy #define SLANG_REFLECT_CLASS_INFO(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - /* static */const ReflectClassInfo NAME::kReflectClassInfo = { uint32_t(ASTNodeType::NAME), uint32_t(ASTNodeType::LAST), SLANG_GET_SUPER_##TYPE(SUPER), #NAME, SLANG_GET_CREATE_FUNC_##MARKER(NAME), SLANG_GET_DESTROY_FUNC_##MARKER(NAME), uint32_t(sizeof(NAME)), uint8_t(SLANG_ALIGN_OF(NAME)) }; + /* static */ const ReflectClassInfo NAME::kReflectClassInfo = { \ + uint32_t(ASTNodeType::NAME), \ + uint32_t(ASTNodeType::LAST), \ + SLANG_GET_SUPER_##TYPE(SUPER), \ + #NAME, \ + SLANG_GET_CREATE_FUNC_##MARKER(NAME), \ + SLANG_GET_DESTROY_FUNC_##MARKER(NAME), \ + uint32_t(sizeof(NAME)), \ + uint8_t(SLANG_ALIGN_OF(NAME))}; SLANG_ALL_ASTNode_NodeBase(SLANG_REFLECT_CLASS_INFO, _) // We dispatch to non 'abstract' types -#define SLANG_CASE_AST(NAME) case ASTNodeType::NAME: return visitor->dispatch_##NAME(static_cast(this), extra); +#define SLANG_CASE_AST(NAME) \ + case ASTNodeType::NAME: return visitor->dispatch_##NAME(static_cast(this), extra); #define SLANG_CASE_ABSTRACT_AST(NAME) -#define SLANG_CASE_DISPATCH(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) SLANG_CASE_##MARKER(NAME) +#define SLANG_CASE_DISPATCH(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ + SLANG_CASE_##MARKER(NAME) -void Val::accept(IValVisitor* visitor, void* extra) + void Val::accept(IValVisitor* visitor, void* extra) { const ReflectClassInfo& classInfo = getClassInfo(); const ASTNodeType astType = ASTNodeType(classInfo.m_classId); switch (astType) { - SLANG_CHILDREN_ASTNode_Val(SLANG_CASE_DISPATCH, _) - default: SLANG_ASSERT(!"Unknown type"); + SLANG_CHILDREN_ASTNode_Val(SLANG_CASE_DISPATCH, _) default : SLANG_ASSERT(!"Unknown type"); } } @@ -90,8 +95,7 @@ void Type::accept(ITypeVisitor* visitor, void* extra) switch (astType) { - SLANG_CHILDREN_ASTNode_Type(SLANG_CASE_DISPATCH, _) - default: SLANG_ASSERT(!"Unknown type"); + SLANG_CHILDREN_ASTNode_Type(SLANG_CASE_DISPATCH, _) default : SLANG_ASSERT(!"Unknown type"); } } @@ -102,8 +106,8 @@ void Modifier::accept(IModifierVisitor* visitor, void* extra) switch (astType) { - SLANG_CHILDREN_ASTNode_Modifier(SLANG_CASE_DISPATCH, _) - default: SLANG_ASSERT(!"Unknown type"); + SLANG_CHILDREN_ASTNode_Modifier(SLANG_CASE_DISPATCH, _) default + : SLANG_ASSERT(!"Unknown type"); } } @@ -114,8 +118,8 @@ void DeclBase::accept(IDeclVisitor* visitor, void* extra) switch (astType) { - SLANG_CHILDREN_ASTNode_DeclBase(SLANG_CASE_DISPATCH, _) - default: SLANG_ASSERT(!"Unknown type"); + SLANG_CHILDREN_ASTNode_DeclBase(SLANG_CASE_DISPATCH, _) default + : SLANG_ASSERT(!"Unknown type"); } } @@ -126,8 +130,7 @@ void Expr::accept(IExprVisitor* visitor, void* extra) switch (astType) { - SLANG_CHILDREN_ASTNode_Expr(SLANG_CASE_DISPATCH, _) - default: SLANG_ASSERT(!"Unknown type"); + SLANG_CHILDREN_ASTNode_Expr(SLANG_CASE_DISPATCH, _) default : SLANG_ASSERT(!"Unknown type"); } } @@ -138,8 +141,7 @@ void Stmt::accept(IStmtVisitor* visitor, void* extra) switch (astType) { - SLANG_CHILDREN_ASTNode_Stmt(SLANG_CASE_DISPATCH, _) - default: SLANG_ASSERT(!"Unknown type"); + SLANG_CHILDREN_ASTNode_Stmt(SLANG_CASE_DISPATCH, _) default : SLANG_ASSERT(!"Unknown type"); } } diff --git a/source/slang/slang-ast-reflect.h b/source/slang/slang-ast-reflect.h index 9bce74587..5bf412955 100644 --- a/source/slang/slang-ast-reflect.h +++ b/source/slang/slang-ast-reflect.h @@ -3,50 +3,55 @@ #ifndef SLANG_AST_REFLECT_H #define SLANG_AST_REFLECT_H -#include "slang-serialize-reflection.h" - #include "slang-generated-ast.h" +#include "slang-serialize-reflection.h" -// Implementation for SLANG_ABSTRACT_CLASS(x) using reflection from C++ extractor in slang-ast-generated.h +// Implementation for SLANG_ABSTRACT_CLASS(x) using reflection from C++ extractor in +// slang-ast-generated.h #define SLANG_AST_CLASS_REFLECT_IMPL(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - protected: \ - NAME() = default; \ - public: \ - typedef NAME This; \ - static constexpr ASTNodeType kType = ASTNodeType::NAME; \ - static const ReflectClassInfo kReflectClassInfo; \ - SLANG_FORCE_INLINE static bool isDerivedFrom(ASTNodeType type) { return int(type) >= int(kType) && int(type) <= int(ASTNodeType::LAST); } \ - SLANG_CLASS_REFLECT_SUPER_##TYPE(SUPER) \ - friend class ASTBuilder; \ - friend struct ASTConstructAccess; \ - friend struct ASTFieldAccess; \ +protected: \ + NAME() = default; \ + \ +public: \ + typedef NAME This; \ + static constexpr ASTNodeType kType = ASTNodeType::NAME; \ + static const ReflectClassInfo kReflectClassInfo; \ + SLANG_FORCE_INLINE static bool isDerivedFrom(ASTNodeType type) \ + { \ + return int(type) >= int(kType) && int(type) <= int(ASTNodeType::LAST); \ + } \ + SLANG_CLASS_REFLECT_SUPER_##TYPE(SUPER) friend class ASTBuilder; \ + friend struct ASTConstructAccess; \ + friend struct ASTFieldAccess; \ friend struct ASTDumpAccess; // Macro definitions - use the SLANG_ASTNode_ definitions to invoke the IMPL to produce the code // injected into AST classes -#define SLANG_ABSTRACT_AST_CLASS(NAME) SLANG_ASTNode_##NAME(SLANG_AST_CLASS_REFLECT_IMPL, _) -#define SLANG_AST_CLASS(NAME) SLANG_ASTNode_##NAME(SLANG_AST_CLASS_REFLECT_IMPL, _) +#define SLANG_ABSTRACT_AST_CLASS(NAME) SLANG_ASTNode_##NAME(SLANG_AST_CLASS_REFLECT_IMPL, _) +#define SLANG_AST_CLASS(NAME) SLANG_ASTNode_##NAME(SLANG_AST_CLASS_REFLECT_IMPL, _) // Macros for simulating virtual methods without virtual methods #define SLANG_AST_NODE_INVOKE(method, methodParams) _##method##Override methodParams -#define SLANG_AST_NODE_CASE(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) case ASTNodeType::NAME: return static_cast(this)-> SLANG_AST_NODE_INVOKE param; +#define SLANG_AST_NODE_CASE(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ + case ASTNodeType::NAME: return static_cast(this)->SLANG_AST_NODE_INVOKE param; -#define SLANG_AST_NODE_VIRTUAL_CALL(base, methodName, methodParams) \ - switch (astNodeType) \ - { \ - SLANG_ALL_ASTNode_##base(SLANG_AST_NODE_CASE, (methodName, methodParams)) \ - default: return SLANG_AST_NODE_INVOKE (methodName, methodParams); \ +#define SLANG_AST_NODE_VIRTUAL_CALL(base, methodName, methodParams) \ + switch (astNodeType) \ + { \ + SLANG_ALL_ASTNode_##base(SLANG_AST_NODE_CASE, (methodName, methodParams)) default \ + : return SLANG_AST_NODE_INVOKE(methodName, methodParams); \ } // Same but for a method that's const -#define SLANG_AST_NODE_CONST_CASE(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) case ASTNodeType::NAME: return static_cast(this)-> SLANG_AST_NODE_INVOKE param; -#define SLANG_AST_NODE_CONST_VIRTUAL_CALL(base, methodName, methodParams) \ - switch (astNodeType) \ - { \ - SLANG_ALL_ASTNode_##base(SLANG_AST_NODE_CONST_CASE, (methodName, methodParams)) \ - default: return SLANG_AST_NODE_INVOKE (methodName, methodParams); \ +#define SLANG_AST_NODE_CONST_CASE(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ + case ASTNodeType::NAME: return static_cast(this)->SLANG_AST_NODE_INVOKE param; +#define SLANG_AST_NODE_CONST_VIRTUAL_CALL(base, methodName, methodParams) \ + switch (astNodeType) \ + { \ + SLANG_ALL_ASTNode_##base(SLANG_AST_NODE_CONST_CASE, (methodName, methodParams)) default \ + : return SLANG_AST_NODE_INVOKE(methodName, methodParams); \ } #endif // SLANG_AST_REFLECT_H diff --git a/source/slang/slang-ast-stmt.h b/source/slang/slang-ast-stmt.h index 0342cdc50..8c21f78a5 100644 --- a/source/slang/slang-ast-stmt.h +++ b/source/slang/slang-ast-stmt.h @@ -4,11 +4,12 @@ #include "slang-ast-base.h" -namespace Slang { +namespace Slang +{ // Syntax class definitions for statements. -class ScopeStmt : public Stmt +class ScopeStmt : public Stmt { SLANG_ABSTRACT_AST_CLASS(ScopeStmt) @@ -16,7 +17,7 @@ class ScopeStmt : public Stmt }; // A sequence of statements, treated as a single statement -class SeqStmt : public Stmt +class SeqStmt : public Stmt { SLANG_AST_CLASS(SeqStmt) @@ -33,19 +34,19 @@ class LabelStmt : public Stmt }; // The simplest kind of scope statement: just a `{...}` block -class BlockStmt : public ScopeStmt +class BlockStmt : public ScopeStmt { SLANG_AST_CLASS(BlockStmt) - /// TODO(JS): Having ranges of sourcelocs might be a good addition to AST nodes in general. - SourceLoc closingSourceLoc; ///< The source location of the closing brace + /// TODO(JS): Having ranges of sourcelocs might be a good addition to AST nodes in general. + SourceLoc closingSourceLoc; ///< The source location of the closing brace Stmt* body = nullptr; }; // A statement that we aren't going to parse or check, because // we want to let a downstream compiler handle any issues -class UnparsedStmt : public Stmt +class UnparsedStmt : public Stmt { SLANG_AST_CLASS(UnparsedStmt) @@ -53,24 +54,24 @@ class UnparsedStmt : public Stmt List tokens; }; -class EmptyStmt : public Stmt +class EmptyStmt : public Stmt { SLANG_AST_CLASS(EmptyStmt) }; -class DiscardStmt : public Stmt +class DiscardStmt : public Stmt { SLANG_AST_CLASS(DiscardStmt) }; -class DeclStmt : public Stmt +class DeclStmt : public Stmt { SLANG_AST_CLASS(DeclStmt) DeclBase* decl = nullptr; }; -class IfStmt : public Stmt +class IfStmt : public Stmt { SLANG_AST_CLASS(IfStmt) @@ -80,13 +81,12 @@ class IfStmt : public Stmt }; // A statement that can be escaped with a `break` -class BreakableStmt : public ScopeStmt +class BreakableStmt : public ScopeStmt { SLANG_ABSTRACT_AST_CLASS(BreakableStmt) - }; -class SwitchStmt : public BreakableStmt +class SwitchStmt : public BreakableStmt { SLANG_AST_CLASS(SwitchStmt) @@ -121,7 +121,7 @@ class IntrinsicAsmStmt : public Stmt // A statement that is expected to appear lexically nested inside // some other construct, and thus needs to keep track of the // outer statement that it is associated with... -class ChildStmt : public Stmt +class ChildStmt : public Stmt { SLANG_ABSTRACT_AST_CLASS(ChildStmt) @@ -133,14 +133,13 @@ class ChildStmt : public Stmt // Note(tfoley): A correct AST for a C-like language would treat // these as a labelled statement, and so they would contain a // sub-statement. I'm leaving that out for now for simplicity. -class CaseStmtBase : public ChildStmt +class CaseStmtBase : public ChildStmt { SLANG_ABSTRACT_AST_CLASS(CaseStmtBase) - }; // a `case` statement inside a `switch` -class CaseStmt : public CaseStmtBase +class CaseStmt : public CaseStmtBase { SLANG_AST_CLASS(CaseStmt) @@ -150,13 +149,13 @@ class CaseStmt : public CaseStmtBase }; // a `default` statement inside a `switch` -class DefaultStmt : public CaseStmtBase +class DefaultStmt : public CaseStmtBase { SLANG_AST_CLASS(DefaultStmt) }; // a `default` statement inside a `switch` -class GpuForeachStmt : public ScopeStmt +class GpuForeachStmt : public ScopeStmt { SLANG_AST_CLASS(GpuForeachStmt) @@ -167,14 +166,13 @@ class GpuForeachStmt : public ScopeStmt }; // A statement that represents a loop, and can thus be escaped with a `continue` -class LoopStmt : public BreakableStmt +class LoopStmt : public BreakableStmt { SLANG_ABSTRACT_AST_CLASS(LoopStmt) - }; // A `for` statement -class ForStmt : public LoopStmt +class ForStmt : public LoopStmt { SLANG_AST_CLASS(ForStmt) @@ -186,13 +184,12 @@ class ForStmt : public LoopStmt // A `for` statement in a language that doesn't restrict the scope // of the loop variable to the body. -class UnscopedForStmt : public ForStmt +class UnscopedForStmt : public ForStmt { - SLANG_AST_CLASS(UnscopedForStmt) -; + SLANG_AST_CLASS(UnscopedForStmt); }; -class WhileStmt : public LoopStmt +class WhileStmt : public LoopStmt { SLANG_AST_CLASS(WhileStmt) @@ -200,7 +197,7 @@ class WhileStmt : public LoopStmt Stmt* statement = nullptr; }; -class DoWhileStmt : public LoopStmt +class DoWhileStmt : public LoopStmt { SLANG_AST_CLASS(DoWhileStmt) @@ -209,7 +206,7 @@ class DoWhileStmt : public LoopStmt }; // A compile-time, range-based `for` loop, which will not appear in the output code -class CompileTimeForStmt : public ScopeStmt +class CompileTimeForStmt : public ScopeStmt { SLANG_AST_CLASS(CompileTimeForStmt) @@ -223,31 +220,31 @@ class CompileTimeForStmt : public ScopeStmt // The case of child statements that do control flow relative // to their parent statement. -class JumpStmt : public ChildStmt +class JumpStmt : public ChildStmt { SLANG_ABSTRACT_AST_CLASS(JumpStmt) }; -class BreakStmt : public JumpStmt +class BreakStmt : public JumpStmt { SLANG_AST_CLASS(BreakStmt) Token targetLabel; }; -class ContinueStmt : public JumpStmt +class ContinueStmt : public JumpStmt { SLANG_AST_CLASS(ContinueStmt) }; -class ReturnStmt : public Stmt +class ReturnStmt : public Stmt { SLANG_AST_CLASS(ReturnStmt) Expr* expression = nullptr; }; -class ExpressionStmt : public Stmt +class ExpressionStmt : public Stmt { SLANG_AST_CLASS(ExpressionStmt) diff --git a/source/slang/slang-ast-support-types.cpp b/source/slang/slang-ast-support-types.cpp index 6a957e427..a06fa2b88 100644 --- a/source/slang/slang-ast-support-types.cpp +++ b/source/slang/slang-ast-support-types.cpp @@ -1,14 +1,14 @@ #include "slang-ast-support-types.h" + #include "slang-ast-base.h" -#include "slang-ast-type.h" #include "slang-ast-expr.h" +#include "slang-ast-type.h" #include "slang-check-impl.h" namespace Slang { QualType::QualType(Type* type) - : type(type) - , isLeftValue(false) + : type(type), isLeftValue(false) { if (as(type)) { @@ -45,7 +45,8 @@ Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr, FunctionDifferentiableLeve { if (as(expr)) outLevel = FunctionDifferentiableLevel::Backward; - else if (as(expr) && outLevel == FunctionDifferentiableLevel::None) + else if ( + as(expr) && outLevel == FunctionDifferentiableLevel::None) outLevel = FunctionDifferentiableLevel::Forward; if (workListSet.add(higherOrder)) { @@ -69,4 +70,4 @@ UnownedStringSlice getHigherOrderOperatorName(HigherOrderInvokeExpr* expr) return UnownedStringSlice(); } -} +} // namespace Slang diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 3b10539cf..beacbaebe 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -1,1669 +1,1702 @@ #ifndef SLANG_AST_SUPPORT_TYPES_H #define SLANG_AST_SUPPORT_TYPES_H -#include "../core/slang-basic.h" - +#include "../compiler-core/slang-doc-extractor.h" #include "../compiler-core/slang-lexer.h" #include "../compiler-core/slang-name.h" -#include "../compiler-core/slang-doc-extractor.h" - -#include "slang-profile.h" -#include "slang-type-system-shared.h" -#include "slang.h" - +#include "../core/slang-basic.h" #include "../core/slang-semantic-version.h" - -#include "slang-generated-ast.h" - -#include "slang-serialize-reflection.h" - #include "slang-ast-reflect.h" +#include "slang-generated-ast.h" +#include "slang-profile.h" #include "slang-ref-object-reflect.h" +#include "slang-serialize-reflection.h" +#include "slang-type-system-shared.h" +#include "slang.h" #include #include namespace Slang { - class Module; - class Name; - class Session; - class SyntaxVisitor; - class FuncDecl; - class Layout; - - struct IExprVisitor; - struct IDeclVisitor; - struct IModifierVisitor; - struct IStmtVisitor; - struct ITypeVisitor; - struct IValVisitor; - - class Parser; - class SyntaxNode; - - class Decl; - struct QualType; - class Type; - struct TypeExp; - class Val; - - class NodeBase; - class LookupDeclRef; - class GenericAppDeclRef; - struct CapabilitySet; - - template - T* as(NodeBase* node); - - template - const T* as(const NodeBase* node); - - void printDiagnosticArg(StringBuilder& sb, Decl* decl); - void printDiagnosticArg(StringBuilder& sb, Type* type); - void printDiagnosticArg(StringBuilder& sb, TypeExp const& type); - void printDiagnosticArg(StringBuilder& sb, QualType const& type); - void printDiagnosticArg(StringBuilder& sb, Val* val); - void printDiagnosticArg(StringBuilder& sb, DeclRefBase* declRefBase); - void printDiagnosticArg(StringBuilder& sb, ASTNodeType nodeType); - void printDiagnosticArg(StringBuilder& sb, const CapabilitySet& set); - void printDiagnosticArg(StringBuilder& sb, List& set); - - struct QualifiedDeclPath - { - DeclRefBase* declRef; - QualifiedDeclPath() = default; - QualifiedDeclPath(DeclRefBase* declRef) - : declRef(declRef) - {} - }; - // Prints the fully qualified decl name. - void printDiagnosticArg(StringBuilder& sb, QualifiedDeclPath path); - - - class SyntaxNode; - SourceLoc getDiagnosticPos(SyntaxNode const* syntax); - SourceLoc getDiagnosticPos(TypeExp const& typeExp); - SourceLoc getDiagnosticPos(DeclRefBase* declRef); - - typedef NodeBase* (*SyntaxParseCallback)(Parser* parser, void* userData); - - typedef unsigned int ConversionCost; - enum : ConversionCost +class Module; +class Name; +class Session; +class SyntaxVisitor; +class FuncDecl; +class Layout; + +struct IExprVisitor; +struct IDeclVisitor; +struct IModifierVisitor; +struct IStmtVisitor; +struct ITypeVisitor; +struct IValVisitor; + +class Parser; +class SyntaxNode; + +class Decl; +struct QualType; +class Type; +struct TypeExp; +class Val; + +class NodeBase; +class LookupDeclRef; +class GenericAppDeclRef; +struct CapabilitySet; + +template +T* as(NodeBase* node); + +template +const T* as(const NodeBase* node); + +void printDiagnosticArg(StringBuilder& sb, Decl* decl); +void printDiagnosticArg(StringBuilder& sb, Type* type); +void printDiagnosticArg(StringBuilder& sb, TypeExp const& type); +void printDiagnosticArg(StringBuilder& sb, QualType const& type); +void printDiagnosticArg(StringBuilder& sb, Val* val); +void printDiagnosticArg(StringBuilder& sb, DeclRefBase* declRefBase); +void printDiagnosticArg(StringBuilder& sb, ASTNodeType nodeType); +void printDiagnosticArg(StringBuilder& sb, const CapabilitySet& set); +void printDiagnosticArg(StringBuilder& sb, List& set); + +struct QualifiedDeclPath +{ + DeclRefBase* declRef; + QualifiedDeclPath() = default; + QualifiedDeclPath(DeclRefBase* declRef) + : declRef(declRef) { - // No conversion at all - kConversionCost_None = 0, - - kConversionCost_GenericParamUpcast = 1, - kConversionCost_UnconstraintGenericParam = 20, - kConversionCost_SizedArrayToUnsizedArray = 30, + } +}; +// Prints the fully qualified decl name. +void printDiagnosticArg(StringBuilder& sb, QualifiedDeclPath path); - // Convert between matrices of different layout - kConversionCost_MatrixLayout = 5, - // Conversion from a buffer to the type it carries needs to add a minimal - // extra cost, just so we can distinguish an overload on `ConstantBuffer` - // from one on `Foo` - kConversionCost_GetRef = 5, - kConversionCost_ImplicitDereference = 10, - kConversionCost_InRangeIntLitConversion = 23, - kConversionCost_InRangeIntLitSignedToUnsignedConversion = 32, - kConversionCost_InRangeIntLitUnsignedToSignedConversion = 81, +class SyntaxNode; +SourceLoc getDiagnosticPos(SyntaxNode const* syntax); +SourceLoc getDiagnosticPos(TypeExp const& typeExp); +SourceLoc getDiagnosticPos(DeclRefBase* declRef); - kConversionCost_MutablePtrToConstPtr = 20, +typedef NodeBase* (*SyntaxParseCallback)(Parser* parser, void* userData); - // Conversions based on explicit sub-typing relationships are the cheapest - // - // TODO(tfoley): We will eventually need a discipline for ranking - // when two up-casts are comparable. - kConversionCost_CastToInterface = 50, +typedef unsigned int ConversionCost; +enum : ConversionCost +{ + // No conversion at all + kConversionCost_None = 0, - // Conversion that is lossless and keeps the "kind" of the value the same - kConversionCost_BoolToInt = 120, // Converting bool to int has lower cost than other integer types to prevent ambiguity. - kConversionCost_RankPromotion = 150, - kConversionCost_NoneToOptional = 150, - kConversionCost_ValToOptional = 150, - kConversionCost_NullPtrToPtr = 150, - kConversionCost_PtrToVoidPtr = 150, + kConversionCost_GenericParamUpcast = 1, + kConversionCost_UnconstraintGenericParam = 20, + kConversionCost_SizedArrayToUnsizedArray = 30, - // Conversions that are lossless, but change "kind" - kConversionCost_UnsignedToSignedPromotion = 200, + // Convert between matrices of different layout + kConversionCost_MatrixLayout = 5, - // Same-size size unsigned->signed conversions are potentially lossy, but they are commonly allowed silently. - kConversionCost_SameSizeUnsignedToSignedConversion = 250, + // Conversion from a buffer to the type it carries needs to add a minimal + // extra cost, just so we can distinguish an overload on `ConstantBuffer` + // from one on `Foo` + kConversionCost_GetRef = 5, + kConversionCost_ImplicitDereference = 10, + kConversionCost_InRangeIntLitConversion = 23, + kConversionCost_InRangeIntLitSignedToUnsignedConversion = 32, + kConversionCost_InRangeIntLitUnsignedToSignedConversion = 81, - // Conversion from signed->unsigned integer of same or greater size - kConversionCost_SignedToUnsignedConversion = 300, + kConversionCost_MutablePtrToConstPtr = 20, - // Cost of converting an integer to a floating-point type - kConversionCost_IntegerToFloatConversion = 400, + // Conversions based on explicit sub-typing relationships are the cheapest + // + // TODO(tfoley): We will eventually need a discipline for ranking + // when two up-casts are comparable. + kConversionCost_CastToInterface = 50, - // Cost of converting a pointer to bool - kConversionCost_PtrToBool = 400, + // Conversion that is lossless and keeps the "kind" of the value the same + kConversionCost_BoolToInt = + 120, // Converting bool to int has lower cost than other integer types to prevent ambiguity. + kConversionCost_RankPromotion = 150, + kConversionCost_NoneToOptional = 150, + kConversionCost_ValToOptional = 150, + kConversionCost_NullPtrToPtr = 150, + kConversionCost_PtrToVoidPtr = 150, - // Cost of converting an integer to int16_t - kConversionCost_IntegerTruncate = 450, + // Conversions that are lossless, but change "kind" + kConversionCost_UnsignedToSignedPromotion = 200, - // Cost of converting an integer to a half type - kConversionCost_IntegerToHalfConversion = 500, + // Same-size size unsigned->signed conversions are potentially lossy, but they are commonly + // allowed silently. + kConversionCost_SameSizeUnsignedToSignedConversion = 250, - // Cost of using a concrete argument pack - kConversionCost_ParameterPack = 500, + // Conversion from signed->unsigned integer of same or greater size + kConversionCost_SignedToUnsignedConversion = 300, - // Default case (usable for user-defined conversions) - kConversionCost_Default = 500, + // Cost of converting an integer to a floating-point type + kConversionCost_IntegerToFloatConversion = 400, - // Catch-all for conversions that should be discouraged - // (i.e., that really shouldn't be made implicitly) - // - // TODO: make these conversions not be allowed implicitly in "Slang mode" - kConversionCost_GeneralConversion = 900, + // Cost of converting a pointer to bool + kConversionCost_PtrToBool = 400, - // This is the cost of an explicit conversion, which should - // not actually be performed. - kConversionCost_Explicit = 90000, + // Cost of converting an integer to int16_t + kConversionCost_IntegerTruncate = 450, - // Additional conversion cost to add when promoting from a scalar to - // a vector (this will be added to the cost, if any, of converting - // the element type of the vector) - kConversionCost_OneVectorToScalar = 1, - kConversionCost_ScalarToVector = 2, - kConversionCost_ScalarToMatrix = 10, - kConversionCost_ScalarIntegerToFloatMatrix = kConversionCost_IntegerToFloatConversion + kConversionCost_ScalarToMatrix, + // Cost of converting an integer to a half type + kConversionCost_IntegerToHalfConversion = 500, - // Additional cost when casting an LValue. - kConversionCost_LValueCast = 800, + // Cost of using a concrete argument pack + kConversionCost_ParameterPack = 500, - // Conversion is impossible - kConversionCost_Impossible = 0xFFFFFFFF, - }; + // Default case (usable for user-defined conversions) + kConversionCost_Default = 500, - typedef unsigned int BuiltinConversionKind; - enum : BuiltinConversionKind - { - kBuiltinConversion_Unknown = 0, - kBuiltinConversion_FloatToDouble = 1, - }; + // Catch-all for conversions that should be discouraged + // (i.e., that really shouldn't be made implicitly) + // + // TODO: make these conversions not be allowed implicitly in "Slang mode" + kConversionCost_GeneralConversion = 900, + + // This is the cost of an explicit conversion, which should + // not actually be performed. + kConversionCost_Explicit = 90000, + + // Additional conversion cost to add when promoting from a scalar to + // a vector (this will be added to the cost, if any, of converting + // the element type of the vector) + kConversionCost_OneVectorToScalar = 1, + kConversionCost_ScalarToVector = 2, + kConversionCost_ScalarToMatrix = 10, + kConversionCost_ScalarIntegerToFloatMatrix = + kConversionCost_IntegerToFloatConversion + kConversionCost_ScalarToMatrix, + + // Additional cost when casting an LValue. + kConversionCost_LValueCast = 800, + + // Conversion is impossible + kConversionCost_Impossible = 0xFFFFFFFF, +}; + +typedef unsigned int BuiltinConversionKind; +enum : BuiltinConversionKind +{ + kBuiltinConversion_Unknown = 0, + kBuiltinConversion_FloatToDouble = 1, +}; - enum class ImageFormat - { +enum class ImageFormat +{ #define SLANG_FORMAT(NAME, OTHER) NAME, #include "slang-image-format-defs.h" #undef SLANG_FORMAT - }; +}; - struct ImageFormatInfo +struct ImageFormatInfo +{ + SlangScalarType scalarType; ///< If image format is not made up of channels of set sizes this + ///< will be SLANG_SCALAR_TYPE_NONE + uint8_t channelCount; ///< The number of channels + uint8_t sizeInBytes; ///< Size in bytes + UnownedStringSlice name; ///< The name associated with this type. NOTE! Currently these names + ///< *are* the GLSL format names. +}; + +const ImageFormatInfo& getImageFormatInfo(ImageFormat format); + +bool findImageFormatByName(const UnownedStringSlice& name, ImageFormat* outFormat); +bool findVkImageFormatByName(const UnownedStringSlice& name, ImageFormat* outFormat); + +char const* getGLSLNameForImageFormat(ImageFormat format); + +// TODO(tfoley): We should ditch this enumeration +// and just use the IR opcodes that represent these +// types directly. The one major complication there +// is that the order of the enum values currently +// matters, since it determines promotion rank. +// We either need to keep that restriction, or +// look up promotion rank by some other means. +// + +class Decl; +class Val; + +// Helper type for pairing up a name and the location where it appeared +struct NameLoc +{ + Name* name; + SourceLoc loc; + + NameLoc() + : name(nullptr) { - SlangScalarType scalarType; ///< If image format is not made up of channels of set sizes this will be SLANG_SCALAR_TYPE_NONE - uint8_t channelCount; ///< The number of channels - uint8_t sizeInBytes; ///< Size in bytes - UnownedStringSlice name; ///< The name associated with this type. NOTE! Currently these names *are* the GLSL format names. - }; + } - const ImageFormatInfo& getImageFormatInfo(ImageFormat format); + explicit NameLoc(Name* inName) + : name(inName) + { + } - bool findImageFormatByName(const UnownedStringSlice& name, ImageFormat* outFormat); - bool findVkImageFormatByName(const UnownedStringSlice& name, ImageFormat* outFormat); - char const* getGLSLNameForImageFormat(ImageFormat format); + NameLoc(Name* inName, SourceLoc inLoc) + : name(inName), loc(inLoc) + { + } - // TODO(tfoley): We should ditch this enumeration - // and just use the IR opcodes that represent these - // types directly. The one major complication there - // is that the order of the enum values currently - // matters, since it determines promotion rank. - // We either need to keep that restriction, or - // look up promotion rank by some other means. - // + NameLoc(Token const& token) + : name(token.getNameOrNull()), loc(token.getLoc()) + { + } +}; - class Decl; - class Val; +struct StringSliceLoc +{ + UnownedStringSlice name; + SourceLoc loc; - // Helper type for pairing up a name and the location where it appeared - struct NameLoc + StringSliceLoc() + : name(nullptr) + { + } + explicit StringSliceLoc(const UnownedStringSlice& inName) + : name(inName) + { + } + StringSliceLoc(const UnownedStringSlice& inName, SourceLoc inLoc) + : name(inName), loc(inLoc) { - Name* name; - SourceLoc loc; + } + StringSliceLoc(Token const& token) + : loc(token.getLoc()) + { + Name* tokenName = token.getNameOrNull(); + if (tokenName) + { + name = tokenName->text.getUnownedSlice(); + } + } +}; - NameLoc() - : name(nullptr) - {} +// Helper class for iterating over a list of heap-allocated modifiers +struct ModifierList +{ + struct Iterator + { + Modifier* current = nullptr; - explicit NameLoc(Name* inName) - : name(inName) - {} + Modifier* operator*() { return current; } + void operator++(); - NameLoc(Name* inName, SourceLoc inLoc) - : name(inName) - , loc(inLoc) - {} + bool operator!=(Iterator other) { return current != other.current; }; - NameLoc(Token const& token) - : name(token.getNameOrNull()) - , loc(token.getLoc()) - {} - }; + Iterator() + : current(nullptr) + { + } - struct StringSliceLoc - { - UnownedStringSlice name; - SourceLoc loc; - - StringSliceLoc() - : name(nullptr) - {} - explicit StringSliceLoc(const UnownedStringSlice& inName) - : name(inName) - {} - StringSliceLoc(const UnownedStringSlice& inName, SourceLoc inLoc) - : name(inName) - , loc(inLoc) - {} - StringSliceLoc(Token const& token) - : loc(token.getLoc()) + Iterator(Modifier* modifier) + : current(modifier) { - Name* tokenName = token.getNameOrNull(); - if (tokenName) - { - name = tokenName->text.getUnownedSlice(); - } } }; - // Helper class for iterating over a list of heap-allocated modifiers - struct ModifierList + ModifierList() + : modifiers(nullptr) { - struct Iterator - { - Modifier* current = nullptr; + } - Modifier* operator*() - { - return current; - } + ModifierList(Modifier* modifiers) + : modifiers(modifiers) + { + } - void operator++(); + Iterator begin() { return Iterator(modifiers); } + Iterator end() { return Iterator(nullptr); } - bool operator!=(Iterator other) - { - return current != other.current; - }; + Modifier* modifiers = nullptr; +}; - Iterator() - : current(nullptr) - {} +// Helper class for iterating over heap-allocated modifiers +// of a specific type. +template +struct FilteredModifierList +{ + struct Iterator + { + Modifier* current = nullptr; - Iterator(Modifier* modifier) - : current(modifier) - {} - }; + T* operator*() { return (T*)current; } - ModifierList() - : modifiers(nullptr) - {} + void operator++(); - ModifierList(Modifier* modifiers) - : modifiers(modifiers) - {} + bool operator!=(Iterator other) { return current != other.current; }; - Iterator begin() { return Iterator(modifiers); } - Iterator end() { return Iterator(nullptr); } + Iterator() + : current(nullptr) + { + } - Modifier* modifiers = nullptr; + Iterator(Modifier* modifier) + : current(modifier) + { + } }; - // Helper class for iterating over heap-allocated modifiers - // of a specific type. - template - struct FilteredModifierList + FilteredModifierList() + : modifiers(nullptr) { - struct Iterator - { - Modifier* current = nullptr; + } - T* operator*() - { - return (T*)current; - } + FilteredModifierList(Modifier* modifiers) + : modifiers(adjust(modifiers)) + { + } - void operator++(); - - bool operator!=(Iterator other) - { - return current != other.current; - }; + Iterator begin() { return Iterator(modifiers); } + Iterator end() { return Iterator(nullptr); } - Iterator() - : current(nullptr) - {} + static Modifier* adjust(Modifier* modifier); - Iterator(Modifier* modifier) - : current(modifier) - {} - }; + Modifier* modifiers = nullptr; +}; - FilteredModifierList() - : modifiers(nullptr) - {} +// A set of modifiers attached to a syntax node +struct Modifiers +{ + // The first modifier in the linked list of heap-allocated modifiers + Modifier* first = nullptr; - FilteredModifierList(Modifier* modifiers) - : modifiers(adjust(modifiers)) - {} + template + FilteredModifierList getModifiersOfType() + { + return FilteredModifierList(first); + } - Iterator begin() { return Iterator(modifiers); } - Iterator end() { return Iterator(nullptr); } + // Find the first modifier of a given type, or return `nullptr` if none is found. + template + T* findModifier() + { + return *getModifiersOfType().begin(); + } - static Modifier* adjust(Modifier* modifier); + template + bool hasModifier() + { + return findModifier() != nullptr; + } - Modifier* modifiers = nullptr; - }; + /// True if has no modifiers + bool isEmpty() const { return first == nullptr; } - // A set of modifiers attached to a syntax node - struct Modifiers + FilteredModifierList::Iterator begin() { - // The first modifier in the linked list of heap-allocated modifiers - Modifier* first = nullptr; + return FilteredModifierList::Iterator(first); + } + FilteredModifierList::Iterator end() + { + return FilteredModifierList::Iterator(nullptr); + } +}; - template - FilteredModifierList getModifiersOfType() { return FilteredModifierList(first); } +class NamedExpressionType; +class GenericDecl; +class ContainerDecl; - // Find the first modifier of a given type, or return `nullptr` if none is found. - template - T* findModifier() - { - return *getModifiersOfType().begin(); - } +// Try to extract a simple integer value from an `IntVal`. +// This fill assert-fail if the object doesn't represent a literal value. +IntegerLiteralValue getIntVal(IntVal* val); - template - bool hasModifier() { return findModifier() != nullptr; } +/// Represents how much checking has been applied to a declaration. +enum class DeclCheckState : uint8_t +{ + /// The declaration has been parsed, but + /// is otherwise completely unchecked. + /// + Unchecked, - /// True if has no modifiers - bool isEmpty() const { return first == nullptr; } + /// Basic checks on the modifiers of the declaration have been applied. + /// + /// For example, when a declaration has attributes, the transformation + /// of an attribute from the parsed-but-unchecked form into a checked + /// form (in which it has the appropriate C++ subclass) happens here. + /// + ModifiersChecked, - FilteredModifierList::Iterator begin() { return FilteredModifierList::Iterator(first); } - FilteredModifierList::Iterator end() { return FilteredModifierList::Iterator(nullptr); } - }; + /// Wiring up scopes of namespaces with their siblings defined in different + /// files/modules, and other namespaces imported via `using`. + ScopesWired, - class NamedExpressionType; - class GenericDecl; - class ContainerDecl; - - // Try to extract a simple integer value from an `IntVal`. - // This fill assert-fail if the object doesn't represent a literal value. - IntegerLiteralValue getIntVal(IntVal* val); - - /// Represents how much checking has been applied to a declaration. - enum class DeclCheckState : uint8_t - { - /// The declaration has been parsed, but - /// is otherwise completely unchecked. - /// - Unchecked, - - /// Basic checks on the modifiers of the declaration have been applied. - /// - /// For example, when a declaration has attributes, the transformation - /// of an attribute from the parsed-but-unchecked form into a checked - /// form (in which it has the appropriate C++ subclass) happens here. - /// - ModifiersChecked, - - /// Wiring up scopes of namespaces with their siblings defined in different - /// files/modules, and other namespaces imported via `using`. - ScopesWired, - - /// The type/signature of the declaration has been checked. - /// - /// For a value declaration like a variable or function, this means that - /// the type of the declaration can be queried. - /// - /// For a type declaration like a `struct` or `typedef` this means - /// that a `Type` referring to that declaration can be formed. - /// - SignatureChecked, - - /// The declaration's basic signature has been checked to the point that - /// it is ready to be referenced in other places. - /// - /// For a function, this means that it has been organized into a - /// "redeclration group" if there are multiple functions with the - /// same name in a scope. - /// - ReadyForReference, - - /// The declaration is ready for lookup operations to be performed. - /// - /// For type declarations (e.g., aggregate types, generic type parameters) - /// this means that any base type or constraint clauses have been - /// sufficiently checked so that we can enumerate the inheritance - /// hierarchy of the type and discover all its members. - /// - ReadyForLookup, - - /// Any conformance declared on the declaration have been validated. - /// - /// In particular, this step means that a "witness table" has been - /// created to show how a type satisfies the requirements of any - /// interfaces it conforms to. - /// - ReadyForConformances, - - /// Any DeclRefTypes with substitutions have been fully resolved - /// to concrete type. E.g. `T.X` with `T=A` should resolve to `A.X`. - /// We need a separate pass to resolve these types because `A.X` - /// maybe synthesized and made available only after conformance checking. - TypesFullyResolved, - - /// All attributes are fully checked. This is the final step before - /// checking the function body. - AttributesChecked, - - /// The body/definition is checked. - /// - /// This step includes any validation of the declaration that is - /// immaterial to clients code using the declaration, but that is - /// nonetheless relevant to checking correctness. - /// - /// The canonical example here is checking the body of functions. - /// Client code cannot depend on *how* a function is implemented, - /// but we still need to (eventually) check the bodies of all - /// functions, so it belongs in the last phase of checking. - /// - DefinitionChecked, - DefaultConstructorReadyForUse = DefinitionChecked, - - /// The capabilities required by the decl is infered and validated. - /// - CapabilityChecked, - - // For convenience at sites that call `ensureDecl()`, we define - // some aliases for the above states that are expressed in terms - // of what client code needs to be able to do with a declaration. - // - // These aliases can be changed over time if we decide to add - // more phases to semantic checking. - - CanEnumerateBases = ReadyForLookup, - CanUseBaseOfInheritanceDecl = ReadyForLookup, - CanUseTypeOfValueDecl = ReadyForReference, - CanUseExtensionTargetType = ReadyForLookup, - CanUseAsType = ReadyForReference, - CanUseFuncSignature = ReadyForReference, - CanSpecializeGeneric = ReadyForReference, - CanReadInterfaceRequirements = ReadyForLookup, - }; + /// The type/signature of the declaration has been checked. + /// + /// For a value declaration like a variable or function, this means that + /// the type of the declaration can be queried. + /// + /// For a type declaration like a `struct` or `typedef` this means + /// that a `Type` referring to that declaration can be formed. + /// + SignatureChecked, - /// A `DeclCheckState` plus a bit to track whether a declaration is currently being checked. - struct DeclCheckStateExt - { - SLANG_VALUE_CLASS(DeclCheckStateExt) + /// The declaration's basic signature has been checked to the point that + /// it is ready to be referenced in other places. + /// + /// For a function, this means that it has been organized into a + /// "redeclration group" if there are multiple functions with the + /// same name in a scope. + /// + ReadyForReference, - typedef uint8_t RawType; - DeclCheckStateExt() {} - DeclCheckStateExt(DeclCheckState state) - : m_raw(uint8_t(state)) - {} + /// The declaration is ready for lookup operations to be performed. + /// + /// For type declarations (e.g., aggregate types, generic type parameters) + /// this means that any base type or constraint clauses have been + /// sufficiently checked so that we can enumerate the inheritance + /// hierarchy of the type and discover all its members. + /// + ReadyForLookup, - enum : RawType - { - /// A flag to indicate that a declaration is being checked. - /// - /// The value of this flag is chosen so that it can be - /// represented in the bits of a `DeclCheckState` without - /// colliding with the bits that represent actual states. - /// - kBeingCheckedBit = 0x80, - }; - - DeclCheckState getState() const { return DeclCheckState(m_raw & ~kBeingCheckedBit); } - void setState(DeclCheckState state) - { - m_raw = (m_raw & kBeingCheckedBit) | RawType(state); - } + /// Any conformance declared on the declaration have been validated. + /// + /// In particular, this step means that a "witness table" has been + /// created to show how a type satisfies the requirements of any + /// interfaces it conforms to. + /// + ReadyForConformances, - bool isBeingChecked() const { return (m_raw & kBeingCheckedBit) != 0; } + /// Any DeclRefTypes with substitutions have been fully resolved + /// to concrete type. E.g. `T.X` with `T=A` should resolve to `A.X`. + /// We need a separate pass to resolve these types because `A.X` + /// maybe synthesized and made available only after conformance checking. + TypesFullyResolved, - void setIsBeingChecked(bool isBeingChecked) - { - m_raw = (m_raw & ~kBeingCheckedBit) - | (isBeingChecked ? kBeingCheckedBit : 0); - } + /// All attributes are fully checked. This is the final step before + /// checking the function body. + AttributesChecked, - bool operator>=(DeclCheckState state) const - { - return getState() >= state; - } + /// The body/definition is checked. + /// + /// This step includes any validation of the declaration that is + /// immaterial to clients code using the declaration, but that is + /// nonetheless relevant to checking correctness. + /// + /// The canonical example here is checking the body of functions. + /// Client code cannot depend on *how* a function is implemented, + /// but we still need to (eventually) check the bodies of all + /// functions, so it belongs in the last phase of checking. + /// + DefinitionChecked, + DefaultConstructorReadyForUse = DefinitionChecked, - RawType getRaw() const { return m_raw; } - void setRaw(RawType raw) { m_raw = raw; } + /// The capabilities required by the decl is infered and validated. + /// + CapabilityChecked, - // TODO(JS): - // Unfortunately for automatic serialization to see this member, it has to be public. - //private: - RawType m_raw = 0; + // For convenience at sites that call `ensureDecl()`, we define + // some aliases for the above states that are expressed in terms + // of what client code needs to be able to do with a declaration. + // + // These aliases can be changed over time if we decide to add + // more phases to semantic checking. + + CanEnumerateBases = ReadyForLookup, + CanUseBaseOfInheritanceDecl = ReadyForLookup, + CanUseTypeOfValueDecl = ReadyForReference, + CanUseExtensionTargetType = ReadyForLookup, + CanUseAsType = ReadyForReference, + CanUseFuncSignature = ReadyForReference, + CanSpecializeGeneric = ReadyForReference, + CanReadInterfaceRequirements = ReadyForLookup, +}; + +/// A `DeclCheckState` plus a bit to track whether a declaration is currently being checked. +struct DeclCheckStateExt +{ + SLANG_VALUE_CLASS(DeclCheckStateExt) + + typedef uint8_t RawType; + DeclCheckStateExt() {} + DeclCheckStateExt(DeclCheckState state) + : m_raw(uint8_t(state)) + { + } + + enum : RawType + { + /// A flag to indicate that a declaration is being checked. + /// + /// The value of this flag is chosen so that it can be + /// represented in the bits of a `DeclCheckState` without + /// colliding with the bits that represent actual states. + /// + kBeingCheckedBit = 0x80, }; - void addModifier( - ModifiableSyntaxNode* syntax, - Modifier* modifier); + DeclCheckState getState() const { return DeclCheckState(m_raw & ~kBeingCheckedBit); } + void setState(DeclCheckState state) { m_raw = (m_raw & kBeingCheckedBit) | RawType(state); } - void removeModifier( - ModifiableSyntaxNode* syntax, - Modifier* modifier); + bool isBeingChecked() const { return (m_raw & kBeingCheckedBit) != 0; } - struct QualType + void setIsBeingChecked(bool isBeingChecked) { - SLANG_VALUE_CLASS(QualType) + m_raw = (m_raw & ~kBeingCheckedBit) | (isBeingChecked ? kBeingCheckedBit : 0); + } - Type* type = nullptr; - bool isLeftValue = false; - bool hasReadOnlyOnTarget = false; - bool isWriteOnly = false; + bool operator>=(DeclCheckState state) const { return getState() >= state; } - QualType() = default; + RawType getRaw() const { return m_raw; } + void setRaw(RawType raw) { m_raw = raw; } - QualType(Type* type); + // TODO(JS): + // Unfortunately for automatic serialization to see this member, it has to be public. + // private: + RawType m_raw = 0; +}; - QualType(Type* type, bool isLVal) - : QualType(type) - { - isLeftValue = isLVal; - } +void addModifier(ModifiableSyntaxNode* syntax, Modifier* modifier); +void removeModifier(ModifiableSyntaxNode* syntax, Modifier* modifier); - Type* Ptr() { return type; } +struct QualType +{ + SLANG_VALUE_CLASS(QualType) - operator Type*() { return type; } - Type* operator->() { return type; } - }; + Type* type = nullptr; + bool isLeftValue = false; + bool hasReadOnlyOnTarget = false; + bool isWriteOnly = false; - class ASTBuilder; + QualType() = default; - struct ASTClassInfo + QualType(Type* type); + + QualType(Type* type, bool isLVal) + : QualType(type) { - struct Infos - { - const ReflectClassInfo* infos[int(ASTNodeType::CountOf)]; - }; - SLANG_FORCE_INLINE static const ReflectClassInfo* getInfo(ASTNodeType type) { return kInfos.infos[int(type)]; } - static const Infos kInfos; + isLeftValue = isLVal; + } + + + Type* Ptr() { return type; } + + operator Type*() { return type; } + Type* operator->() { return type; } +}; + +class ASTBuilder; + +struct ASTClassInfo +{ + struct Infos + { + const ReflectClassInfo* infos[int(ASTNodeType::CountOf)]; }; + SLANG_FORCE_INLINE static const ReflectClassInfo* getInfo(ASTNodeType type) + { + return kInfos.infos[int(type)]; + } + static const Infos kInfos; +}; + +// A reference to a class of syntax node, that can be +// used to create instances on the fly +struct SyntaxClassBase +{ + SyntaxClassBase() {} - // A reference to a class of syntax node, that can be - // used to create instances on the fly - struct SyntaxClassBase + SyntaxClassBase(ReflectClassInfo const* inClassInfo) + : classInfo(inClassInfo) { - SyntaxClassBase() - {} + } - SyntaxClassBase(ReflectClassInfo const* inClassInfo) - : classInfo(inClassInfo) - {} + void* createInstanceImpl(ASTBuilder* astBuilder) const + { + auto ci = classInfo; + if (!ci) + return nullptr; - void* createInstanceImpl(ASTBuilder* astBuilder) const - { - auto ci = classInfo; - if (!ci) return nullptr; + auto cf = ci->m_createFunc; + if (!cf) + return nullptr; - auto cf = ci->m_createFunc; - if (!cf) return nullptr; + return cf(astBuilder); + } - return cf(astBuilder); - } + SLANG_FORCE_INLINE bool isSubClassOfImpl(SyntaxClassBase const& super) const + { + return classInfo ? classInfo->isSubClassOf(*super.classInfo) : false; + } - SLANG_FORCE_INLINE bool isSubClassOfImpl(SyntaxClassBase const& super) const - { - return classInfo ? classInfo->isSubClassOf(*super.classInfo) : false; - } + ReflectClassInfo const* classInfo = nullptr; +}; - ReflectClassInfo const* classInfo = nullptr; - }; +template +struct SyntaxClass : SyntaxClassBase +{ + SyntaxClass() {} - template - struct SyntaxClass : SyntaxClassBase + template + SyntaxClass( + SyntaxClass const& other, + typename EnableIf::Value, void>::type* = 0) + : SyntaxClassBase(other.classInfo) { - SyntaxClass() - {} + } - template - SyntaxClass(SyntaxClass const& other, - typename EnableIf::Value, void>::type* = 0) - : SyntaxClassBase(other.classInfo) - { - } + T* createInstance(ASTBuilder* astBuilder) const { return (T*)createInstanceImpl(astBuilder); } - T* createInstance(ASTBuilder* astBuilder) const - { - return (T*)createInstanceImpl(astBuilder); - } + SyntaxClass(const ReflectClassInfo* inClassInfo) + : SyntaxClassBase(inClassInfo) + { + } - SyntaxClass(const ReflectClassInfo* inClassInfo): - SyntaxClassBase(inClassInfo) - {} + static SyntaxClass getClass() { return SyntaxClass(&T::kReflectClassInfo); } - static SyntaxClass getClass() - { - return SyntaxClass(&T::kReflectClassInfo); - } + template + bool isSubClassOf(SyntaxClass super) + { + return isSubClassOfImpl(super); + } - template - bool isSubClassOf(SyntaxClass super) - { - return isSubClassOfImpl(super); - } + template + bool isSubClassOf() + { + return isSubClassOf(SyntaxClass::getClass()); + } - template - bool isSubClassOf() - { - return isSubClassOf(SyntaxClass::getClass()); - } + template + bool operator==(const SyntaxClass other) const + { + return classInfo == other.classInfo; + } - template - bool operator==(const SyntaxClass other) const - { - return classInfo == other.classInfo; - } + template + bool operator!=(const SyntaxClass other) const + { + return classInfo != other.classInfo; + } +}; - template - bool operator!=(const SyntaxClass other) const - { - return classInfo != other.classInfo; - } - }; +template +SyntaxClass getClass() +{ + return SyntaxClass::getClass(); +} - template - SyntaxClass getClass() +struct SubstitutionSet +{ + DeclRefBase* declRef = nullptr; + + // The element index if the substitution is happening inside a pack expansion. + // For example, if we are substituting the pattern type of `expand each T`, where + // `T` is a type pack, then packExpansionIndex will have a value starting from 0 + // to the count of the type pack during expansion of the `expand` type when we + // substitute `each T` with the element of `T` at index `packExpansionIndex`. + Index packExpansionIndex = -1; + + SubstitutionSet() = default; + SubstitutionSet(DeclRefBase* declRefBase) + : declRef(declRefBase) { - return SyntaxClass::getClass(); } + explicit operator bool() const; + + template + void forEachGenericSubstitution(F func) const; + + template + void forEachSubstitutionArg(F func) const; + + Type* applyToType(ASTBuilder* astBuilder, Type* type) const; + DeclRefBase* applyToDeclRef(ASTBuilder* astBuilder, DeclRefBase* declRef) const; + + LookupDeclRef* findLookupDeclRef() const; + GenericAppDeclRef* findGenericAppDeclRef(GenericDecl* genericDecl) const; + GenericAppDeclRef* findGenericAppDeclRef() const; + DeclRefBase* getInnerMostNodeWithSubstInfo() const; +}; + +/// An expression together with (optional) substutions to apply to it +/// +/// Under the hood this is a pair of an `Expr*` and a `SubstitutionSet`. +/// Conceptually it represents the result of applying the substitutions, +/// recursively, to the given expression. +/// +/// `SubstExprBase` exists primarily to provide a non-templated base type +/// for `SubstExpr`. Code should prefer to use `SubstExpr` instead +/// of `SubstExprBase` as often as possible. +/// +struct SubstExprBase +{ +public: + /// Initialize as a null expression + SubstExprBase() {} - struct SubstitutionSet + /// Initialize as the given `expr` with no subsitutions applied + SubstExprBase(Expr* expr) + : m_expr(expr) { - DeclRefBase* declRef = nullptr; + } - // The element index if the substitution is happening inside a pack expansion. - // For example, if we are substituting the pattern type of `expand each T`, where - // `T` is a type pack, then packExpansionIndex will have a value starting from 0 - // to the count of the type pack during expansion of the `expand` type when we - // substitute `each T` with the element of `T` at index `packExpansionIndex`. - Index packExpansionIndex = -1; + /// Initialize as the given `expr` with the given `substs` applied + SubstExprBase(Expr* expr, SubstitutionSet const& substs) + : m_expr(expr), m_substs(substs) + { + } - SubstitutionSet() = default; - SubstitutionSet(DeclRefBase* declRefBase) - :declRef(declRefBase) - {} - explicit operator bool() const; + /// Get the underlying expression without any substitutions + Expr* getExpr() const { return m_expr; } - template - void forEachGenericSubstitution(F func) const; + /// Get the subsitutions being applied, if any + SubstitutionSet const& getSubsts() const { return m_substs; } - template - void forEachSubstitutionArg(F func) const; +private: + Expr* m_expr = nullptr; + SubstitutionSet m_substs; - Type* applyToType(ASTBuilder* astBuilder, Type* type) const; - DeclRefBase* applyToDeclRef(ASTBuilder* astBuilder, DeclRefBase* declRef) const; + typedef void (SubstExprBase::*SafeBool)(); + void SafeBoolTrue() {} - LookupDeclRef* findLookupDeclRef() const; - GenericAppDeclRef* findGenericAppDeclRef(GenericDecl* genericDecl) const; - GenericAppDeclRef* findGenericAppDeclRef() const; - DeclRefBase* getInnerMostNodeWithSubstInfo() const; +public: + /// Test whether this is a non-null expression + operator SafeBool() { return m_expr ? &SubstExprBase::SafeBoolTrue : nullptr; } - }; + /// Test whether this is a null expression + bool operator!() const { return m_expr == nullptr; } +}; - /// An expression together with (optional) substutions to apply to it - /// - /// Under the hood this is a pair of an `Expr*` and a `SubstitutionSet`. - /// Conceptually it represents the result of applying the substitutions, - /// recursively, to the given expression. - /// - /// `SubstExprBase` exists primarily to provide a non-templated base type - /// for `SubstExpr`. Code should prefer to use `SubstExpr` instead - /// of `SubstExprBase` as often as possible. - /// - struct SubstExprBase +/// An expression together with (optional) substutions to apply to it +/// +/// Under the hood this is a pair of an `T*` (there `T: Expr`) and a `SubstitutionSet`. +/// Conceptually it represents the result of applying the substitutions, +/// recursively, to the given expression. +/// +template +struct SubstExpr : SubstExprBase +{ +private: + typedef SubstExprBase Super; + +public: + /// Initialize as a null expression + SubstExpr() {} + + /// Initialize as the given `expr` with no subsitutions applied + SubstExpr(T* expr) + : Super(expr) { - public: - /// Initialize as a null expression - SubstExprBase() - {} + } - /// Initialize as the given `expr` with no subsitutions applied - SubstExprBase(Expr* expr) - : m_expr(expr) - {} + /// Initialize as the given `expr` with the given `substs` applied + SubstExpr(T* expr, SubstitutionSet const& substs) + : Super(expr, substs) + { + } - /// Initialize as the given `expr` with the given `substs` applied - SubstExprBase(Expr* expr, SubstitutionSet const& substs) - : m_expr(expr) - , m_substs(substs) - {} + /// Initialize as a copy of the given `other` expression + template + SubstExpr( + SubstExpr const& other, + typename EnableIf::Value, void>::type* = 0) + : Super(other.getExpr(), other.getSubsts()) + { + } - /// Get the underlying expression without any substitutions - Expr* getExpr() const { return m_expr; } + /// Get the underlying expression without any substitutions + T* getExpr() const { return (T*)Super::getExpr(); } - /// Get the subsitutions being applied, if any - SubstitutionSet const& getSubsts() const { return m_substs; } + /// Dynamic cast to an expression of type `U` + /// + /// Returns a null expression if the cast fails, or if this expression was null. + template + SubstExpr as() + { + return SubstExpr(Slang::as(getExpr()), getSubsts()); + } +}; - private: - Expr* m_expr = nullptr; - SubstitutionSet m_substs; +SubstExpr applySubstitutionToExpr(SubstitutionSet substSet, Expr* expr); - typedef void (SubstExprBase::*SafeBool)(); - void SafeBoolTrue() {} +class ASTBuilder; - public: - /// Test whether this is a non-null expression - operator SafeBool() - { - return m_expr ? &SubstExprBase::SafeBoolTrue : nullptr; - } +template +struct DeclRef; +Module* getModule(Decl* decl); - /// Test whether this is a null expression - bool operator!() const { return m_expr == nullptr; } - }; +// If this is a declref to an associatedtype with a ThisTypeSubsitution, +// try to find the concrete decl that satisfies the associatedtype requirement from the +// concrete type supplied by ThisTypeSubstittution. +Val* _tryLookupConcreteAssociatedTypeFromThisTypeSubst(ASTBuilder* builder, DeclRef declRef); - /// An expression together with (optional) substutions to apply to it - /// - /// Under the hood this is a pair of an `T*` (there `T: Expr`) and a `SubstitutionSet`. - /// Conceptually it represents the result of applying the substitutions, - /// recursively, to the given expression. - /// - template - struct SubstExpr : SubstExprBase - { - private: - typedef SubstExprBase Super; - - public: - /// Initialize as a null expression - SubstExpr() - {} - - /// Initialize as the given `expr` with no subsitutions applied - SubstExpr(T* expr) - : Super(expr) - {} - - /// Initialize as the given `expr` with the given `substs` applied - SubstExpr(T* expr, SubstitutionSet const& substs) - : Super(expr, substs) - {} - - /// Initialize as a copy of the given `other` expression - template - SubstExpr(SubstExpr const& other, - typename EnableIf::Value, void>::type* = 0) - : Super(other.getExpr(), other.getSubsts()) - { - } +template +struct DeclRef +{ + friend class ASTBuilder; - /// Get the underlying expression without any substitutions - T* getExpr() const { return (T*) Super::getExpr(); } +public: + typedef T DeclType; + DeclRefBase* declRefBase; + DeclRef() + : declRefBase(nullptr) + { + } - /// Dynamic cast to an expression of type `U` - /// - /// Returns a null expression if the cast fails, or if this expression was null. - template - SubstExpr as() - { - return SubstExpr(Slang::as(getExpr()), getSubsts()); - } - }; + void init(DeclRefBase* base); - SubstExpr applySubstitutionToExpr(SubstitutionSet substSet, Expr* expr); + DeclRef(Decl* decl); - class ASTBuilder; + DeclRef(DeclRefBase* base) { init(base); } - template - struct DeclRef; - Module* getModule(Decl* decl); + template::Value, void>::type> + DeclRef(DeclRef const& other) + : declRefBase(other.declRefBase) + { + } + T* getDecl() const; - // If this is a declref to an associatedtype with a ThisTypeSubsitution, - // try to find the concrete decl that satisfies the associatedtype requirement from the - // concrete type supplied by ThisTypeSubstittution. - Val* _tryLookupConcreteAssociatedTypeFromThisTypeSubst(ASTBuilder* builder, DeclRef declRef); + Name* getName() const; - template - struct DeclRef - { - friend class ASTBuilder; - public: - typedef T DeclType; - DeclRefBase* declRefBase; - DeclRef() - :declRefBase(nullptr) - {} - - void init(DeclRefBase* base); + SourceLoc getNameLoc() const; + SourceLoc getLoc() const; + DeclRef getParent() const; + HashCode getHashCode() const; + Type* substitute(ASTBuilder* astBuilder, Type* type) const; - DeclRef(Decl* decl); + SubstExpr substitute(ASTBuilder* astBuilder, Expr* expr) const; - DeclRef(DeclRefBase* base) - { - init(base); - } + // Apply substitutions to a type or declaration + template + DeclRef substitute(ASTBuilder* astBuilder, DeclRef declRef) const; - template ::Value, void>::type> - DeclRef(DeclRef const& other) - : declRefBase(other.declRefBase) - {} + // Apply substitutions to this declaration reference + DeclRef substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) const; - T* getDecl() const; + template + DeclRef as() const + { + DeclRef result = DeclRef(declRefBase); + return result; + } - Name* getName() const; - - SourceLoc getNameLoc() const; - SourceLoc getLoc() const; - DeclRef getParent() const; - HashCode getHashCode() const; - Type* substitute(ASTBuilder* astBuilder, Type* type) const; + template + bool is() const + { + return Slang::as(static_cast(getDecl())) != nullptr; + } - SubstExpr substitute(ASTBuilder* astBuilder, Expr* expr) const; + operator DeclRefBase*() const { return declRefBase; } - // Apply substitutions to a type or declaration - template - DeclRef substitute(ASTBuilder* astBuilder, DeclRef declRef) const; + operator DeclRef() const { return DeclRef(declRefBase); } - // Apply substitutions to this declaration reference - DeclRef substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) const; + template + bool equals(DeclRef other) const + { + return declRefBase == other.declRefBase; + } - template - DeclRef as() const - { - DeclRef result = DeclRef(declRefBase); - return result; - } + template + bool operator==(DeclRef other) const + { + return equals(other); + } - template - bool is() const - { - return Slang::as(static_cast(getDecl())) != nullptr; - } + template + bool operator!=(DeclRef other) const + { + return !equals(other); + } - operator DeclRefBase* () const - { - return declRefBase; - } + explicit operator bool() const { return declRefBase; } +}; - operator DeclRef() const - { - return DeclRef(declRefBase); - } +template +inline DeclRef makeDeclRef(T* decl) +{ + return DeclRef(decl); +} - template - bool equals(DeclRef other) const - { - return declRefBase == other.declRefBase; - } +SubstExpr substituteExpr(SubstitutionSet const& substs, Expr* expr); +DeclRef substituteDeclRef( + SubstitutionSet const& substs, + ASTBuilder* astBuilder, + DeclRef const& declRef); +Type* substituteType(SubstitutionSet const& substs, ASTBuilder* astBuilder, Type* type); - template - bool operator == (DeclRef other) const - { - return equals(other); - } +enum class MemberFilterStyle +{ + All, ///< All members + Instance, ///< Only instance members + Static, ///< Only static (ie non instance) members +}; + +Decl* const* adjustFilterCursorImpl( + const ReflectClassInfo& clsInfo, + MemberFilterStyle filterStyle, + Decl* const* ptr, + Decl* const* end); +Decl* const* getFilterCursorByIndexImpl( + const ReflectClassInfo& clsInfo, + MemberFilterStyle filterStyle, + Decl* const* ptr, + Decl* const* end, + Index index); +Index getFilterCountImpl( + const ReflectClassInfo& clsInfo, + MemberFilterStyle filterStyle, + Decl* const* ptr, + Decl* const* end); + + +template +Decl* const* adjustFilterCursor(MemberFilterStyle filterStyle, Decl* const* ptr, Decl* const* end) +{ + return adjustFilterCursorImpl(T::kReflectClassInfo, filterStyle, ptr, end); +} + +/// Finds the element at index. If there is no element at the index (for example has too few +/// elements), returns nullptr. +template +Decl* const* getFilterCursorByIndex( + MemberFilterStyle filterStyle, + Decl* const* ptr, + Decl* const* end, + Index index) +{ + return getFilterCursorByIndexImpl(T::kReflectClassInfo, filterStyle, ptr, end, index); +} - template - bool operator != (DeclRef other) const - { - return !equals(other); - } +template +Index getFilterCount(MemberFilterStyle filterStyle, Decl* const* ptr, Decl* const* end) +{ + return getFilterCountImpl(T::kReflectClassInfo, filterStyle, ptr, end); +} - explicit operator bool() const - { - return declRefBase; - } - }; +template +bool isFilterNonEmpty(MemberFilterStyle filterStyle, Decl* const* ptr, Decl* const* end) +{ + return adjustFilterCursorImpl(T::kReflectClassInfo, filterStyle, ptr, end) != end; +} - template - inline DeclRef makeDeclRef(T* decl) +template +struct FilteredMemberList +{ + typedef Decl* Element; + + FilteredMemberList() + : m_begin(nullptr), m_end(nullptr) { - return DeclRef(decl); } - SubstExpr substituteExpr(SubstitutionSet const& substs, Expr* expr); - DeclRef substituteDeclRef(SubstitutionSet const& substs, ASTBuilder* astBuilder, DeclRef const& declRef); - Type* substituteType(SubstitutionSet const& substs, ASTBuilder* astBuilder, Type* type); + explicit FilteredMemberList( + List const& list, + MemberFilterStyle filterStyle = MemberFilterStyle::All) + : m_begin(adjustFilterCursor(filterStyle, list.begin(), list.end())) + , m_end(list.end()) + , m_filterStyle(filterStyle) + { + } - enum class MemberFilterStyle + struct Iterator { - All, ///< All members - Instance, ///< Only instance members - Static, ///< Only static (ie non instance) members - }; + const Element* m_cursor; + const Element* m_end; + MemberFilterStyle m_filterStyle; + + bool operator!=(Iterator const& other) const { return m_cursor != other.m_cursor; } - Decl*const* adjustFilterCursorImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filterStyle, Decl*const* ptr, Decl*const* end); - Decl*const* getFilterCursorByIndexImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filterStyle, Decl*const* ptr, Decl*const* end, Index index); - Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filterStyle, Decl*const* ptr, Decl*const* end); + void operator++() { m_cursor = adjustFilterCursor(m_filterStyle, m_cursor + 1, m_end); } + T* operator*() { return static_cast(*m_cursor); } + }; - template - Decl*const* adjustFilterCursor(MemberFilterStyle filterStyle, Decl*const* ptr, Decl*const* end) + Iterator begin() { - return adjustFilterCursorImpl(T::kReflectClassInfo, filterStyle, ptr, end); + Iterator iter = {m_begin, m_end, m_filterStyle}; + return iter; } - /// Finds the element at index. If there is no element at the index (for example has too few elements), returns nullptr. - template - Decl*const* getFilterCursorByIndex(MemberFilterStyle filterStyle, Decl*const* ptr, Decl*const* end, Index index) + Iterator end() { - return getFilterCursorByIndexImpl(T::kReflectClassInfo, filterStyle, ptr, end, index); + Iterator iter = {m_end, m_end, m_filterStyle}; + return iter; } - - template - Index getFilterCount(MemberFilterStyle filterStyle, Decl*const* ptr, Decl*const* end) + + // TODO(tfoley): It is ugly to have these. + // We should probably fix the call sites instead. + T* getFirst() { return *begin(); } + Index getCount() { return getFilterCount(m_filterStyle, m_begin, m_end); } + + T* operator[](Index index) const { - return getFilterCountImpl(T::kReflectClassInfo, filterStyle, ptr, end); + Decl* const* ptr = getFilterCursorByIndex(m_filterStyle, m_begin, m_end, index); + SLANG_ASSERT(ptr); + return static_cast(*ptr); } - - template - bool isFilterNonEmpty(MemberFilterStyle filterStyle, Decl*const* ptr, Decl*const* end) + + /// Returns true if empty (equivalent to getCount() == 0) + bool isEmpty() const { - return adjustFilterCursorImpl(T::kReflectClassInfo, filterStyle, ptr, end) != end; + /// Note we don't have to scan, because m_begin has already been adjusted, when the + /// FilteredMemberList is constructed + return m_begin == m_end; } + /// Returns true if non empty (equivalent to getCount() != 0 but faster) + bool isNonEmpty() const { return !isEmpty(); } - template - struct FilteredMemberList + List toList() { - typedef Decl* Element; + List result; + for (auto element : (*this)) + { + result.add(element); + } + return result; + } - FilteredMemberList() - : m_begin(nullptr) - , m_end(nullptr) - {} + const Element* + m_begin; ///< Is either equal to m_end, or points to first *valid* filtered member + const Element* m_end; + MemberFilterStyle m_filterStyle; +}; - explicit FilteredMemberList( - List const& list, - MemberFilterStyle filterStyle = MemberFilterStyle::All) - : m_begin(adjustFilterCursor(filterStyle, list.begin(), list.end())) - , m_end(list.end()) - , m_filterStyle(filterStyle) - {} +struct TransparentMemberInfo +{ + // The declaration of the transparent member + Decl* decl = nullptr; +}; - struct Iterator - { - const Element* m_cursor; - const Element* m_end; - MemberFilterStyle m_filterStyle; +template +struct FilteredMemberRefList +{ + List const& m_decls; + DeclRef m_parent; + MemberFilterStyle m_filterStyle; + ASTBuilder* m_astBuilder; + + FilteredMemberRefList( + ASTBuilder* astBuilder, + List const& decls, + DeclRef parent, + MemberFilterStyle filterStyle = MemberFilterStyle::All) + : m_decls(decls), m_parent(parent), m_filterStyle(filterStyle), m_astBuilder(astBuilder) + { + } - bool operator!=(Iterator const& other) const { return m_cursor != other.m_cursor; } + Index getCount() const + { + return getFilterCount(m_filterStyle, m_decls.begin(), m_decls.end()); + } - void operator++() { m_cursor = adjustFilterCursor(m_filterStyle, m_cursor + 1, m_end); } + /// True if empty (equivalent to getCount == 0, but faster) + bool isEmpty() const { return !isNonEmpty(); } + /// True if non empty (equivalent to getCount() != 0 but faster) + bool isNonEmpty() const + { + return isFilterNonEmpty(m_filterStyle, m_decls.begin(), m_decls.end()); + } - T* operator*() { return static_cast(*m_cursor); } - }; + DeclRef getFirstOrNull() { return isEmpty() ? DeclRef() : (*this)[0]; } - Iterator begin() - { - Iterator iter = { m_begin, m_end, m_filterStyle }; - return iter; - } + DeclRef operator[](Index index) const + { + Decl* const* decl = + getFilterCursorByIndex(m_filterStyle, m_decls.begin(), m_decls.end(), index); + SLANG_ASSERT(decl); + return _getMemberDeclRef(m_astBuilder, m_parent, (T*)*decl).template as(); + } - Iterator end() - { - Iterator iter = { m_end, m_end, m_filterStyle }; - return iter; - } + List> toArray() const + { + List> result; + for (auto d : *this) + result.add(d); + return result; + } - // TODO(tfoley): It is ugly to have these. - // We should probably fix the call sites instead. - T* getFirst() { return *begin(); } - Index getCount() { return getFilterCount(m_filterStyle, m_begin, m_end); } + struct Iterator + { + FilteredMemberRefList const* m_list; + Decl* const* m_ptr; + Decl* const* m_end; + MemberFilterStyle m_filterStyle; - T* operator[](Index index) const + Iterator() + : m_list(nullptr), m_ptr(nullptr), m_filterStyle(MemberFilterStyle::All) { - Decl*const* ptr = getFilterCursorByIndex(m_filterStyle, m_begin, m_end, index); - SLANG_ASSERT(ptr); - return static_cast(*ptr); } - - /// Returns true if empty (equivalent to getCount() == 0) - bool isEmpty() const + Iterator( + FilteredMemberRefList const* list, + Decl* const* ptr, + Decl* const* end, + MemberFilterStyle filterStyle) + : m_list(list), m_ptr(ptr), m_end(end), m_filterStyle(filterStyle) { - /// Note we don't have to scan, because m_begin has already been adjusted, when the FilteredMemberList is constructed - return m_begin == m_end; } - /// Returns true if non empty (equivalent to getCount() != 0 but faster) - bool isNonEmpty() const { return !isEmpty(); } - List toList() + bool operator!=(const Iterator& other) const { return m_ptr != other.m_ptr; } + + void operator++() { m_ptr = adjustFilterCursor(m_filterStyle, m_ptr + 1, m_end); } + + DeclRef operator*() { - List result; - for (auto element : (*this)) - { - result.add(element); - } - return result; + return _getMemberDeclRef(m_list->m_astBuilder, m_list->m_parent, (T*)*m_ptr) + .template as(); } - - const Element* m_begin; ///< Is either equal to m_end, or points to first *valid* filtered member - const Element* m_end; - MemberFilterStyle m_filterStyle; }; - struct TransparentMemberInfo + Iterator begin() const { - // The declaration of the transparent member - Decl* decl = nullptr; - }; + return Iterator( + this, + adjustFilterCursor(m_filterStyle, m_decls.begin(), m_decls.end()), + m_decls.end(), + m_filterStyle); + } + Iterator end() const { return Iterator(this, m_decls.end(), m_decls.end(), m_filterStyle); } +}; - template - struct FilteredMemberRefList - { - List const& m_decls; - DeclRef m_parent; - MemberFilterStyle m_filterStyle; - ASTBuilder* m_astBuilder; - - FilteredMemberRefList( - ASTBuilder* astBuilder, - List const& decls, - DeclRef parent, - MemberFilterStyle filterStyle = MemberFilterStyle::All) - : m_decls(decls) - , m_parent(parent) - , m_filterStyle(filterStyle) - , m_astBuilder(astBuilder) - {} - - Index getCount() const { return getFilterCount(m_filterStyle, m_decls.begin(), m_decls.end()); } - - /// True if empty (equivalent to getCount == 0, but faster) - bool isEmpty() const { return !isNonEmpty(); } - /// True if non empty (equivalent to getCount() != 0 but faster) - bool isNonEmpty() const { return isFilterNonEmpty(m_filterStyle, m_decls.begin(), m_decls.end()); } - - DeclRef getFirstOrNull() { return isEmpty() ? DeclRef() : (*this)[0]; } - - DeclRef operator[](Index index) const - { - Decl*const* decl = getFilterCursorByIndex(m_filterStyle, m_decls.begin(), m_decls.end(), index); - SLANG_ASSERT(decl); - return _getMemberDeclRef(m_astBuilder, m_parent, (T*)*decl).template as(); - } +// +// type Expressions +// - List> toArray() const - { - List> result; - for (auto d : *this) - result.add(d); - return result; - } +// A "type expression" is a term that we expect to resolve to a type during checking. +// We store both the original syntax and the resolved type here. +struct TypeExp +{ + SLANG_VALUE_CLASS(TypeExp) + typedef TypeExp ThisType; - struct Iterator - { - FilteredMemberRefList const* m_list; - Decl*const* m_ptr; - Decl*const* m_end; - MemberFilterStyle m_filterStyle; - - Iterator() : m_list(nullptr), m_ptr(nullptr), m_filterStyle(MemberFilterStyle::All) {} - Iterator( - FilteredMemberRefList const* list, - Decl*const* ptr, - Decl*const* end, - MemberFilterStyle filterStyle - ) - : m_list(list) - , m_ptr(ptr) - , m_end(end) - , m_filterStyle(filterStyle) - {} - - bool operator!=(const Iterator& other) const { return m_ptr != other.m_ptr; } - - void operator++() { m_ptr = adjustFilterCursor(m_filterStyle, m_ptr + 1, m_end); } - - DeclRef operator*() { return _getMemberDeclRef(m_list->m_astBuilder, m_list->m_parent, (T*)*m_ptr).template as(); } - }; - - Iterator begin() const { return Iterator(this, adjustFilterCursor(m_filterStyle, m_decls.begin(), m_decls.end()), m_decls.end(), m_filterStyle); } - Iterator end() const { return Iterator(this, m_decls.end(), m_decls.end(), m_filterStyle); } - }; + TypeExp() {} + TypeExp(TypeExp const& other) + : exp(other.exp), type(other.type) + { + } + explicit TypeExp(Expr* exp) + : exp(exp) + { + } + explicit TypeExp(Type* type) + : type(type) + { + } + TypeExp(Expr* exp, Type* type) + : exp(exp), type(type) + { + } - // - // type Expressions - // + Expr* exp = nullptr; + Type* type = nullptr; - // A "type expression" is a term that we expect to resolve to a type during checking. - // We store both the original syntax and the resolved type here. - struct TypeExp - { - SLANG_VALUE_CLASS(TypeExp) - typedef TypeExp ThisType; - - TypeExp() {} - TypeExp(TypeExp const& other) - : exp(other.exp) - , type(other.type) - {} - explicit TypeExp(Expr* exp) - : exp(exp) - {} - explicit TypeExp(Type* type) - : type(type) - {} - TypeExp(Expr* exp, Type* type) - : exp(exp) - , type(type) - {} - - Expr* exp = nullptr; - Type* type = nullptr; - - bool equals(Type* other); - - Type* Ptr() { return type; } - operator Type*() - { - return type; - } - Type* operator->() { return Ptr(); } + bool equals(Type* other); - ThisType& operator=(const ThisType& rhs) = default; + Type* Ptr() { return type; } + operator Type*() { return type; } + Type* operator->() { return Ptr(); } - //TypeExp accept(SyntaxVisitor* visitor); + ThisType& operator=(const ThisType& rhs) = default; - /// A global immutable TypeExp, that has no type or exp set. - static const TypeExp empty; - }; + // TypeExp accept(SyntaxVisitor* visitor); + + /// A global immutable TypeExp, that has no type or exp set. + static const TypeExp empty; +}; + +// Masks to be applied when lookup up declarations +enum class LookupMask : uint8_t +{ + type = 0x1, + Function = 0x2, + Value = 0x4, + Attribute = 0x8, + Default = type | Function | Value, +}; + +/// Flags for options to be used when looking up declarations +enum class LookupOptions : uint8_t +{ + None = 0, + IgnoreBaseInterfaces = 1 << 0, + Completion = 1 << 1, ///< Lookup all applicable decls for code completion suggestions + NoDeref = 1 << 2, + ConsiderAllLocalNamesInScope = 1 << 3, + ///^ Normally we rely on the checking state of local names to determine + /// if they have been declared. If the scopes are currently + /// "under-construction" and not being checked, then it's safe to + /// consider all names we've inserted so far. This is used when + /// checking to see if a keyword is shadowed. + IgnoreInheritance = + 1 << 4, ///< Lookup only non inheritance children of a struct (including `extension`) +}; +inline LookupOptions operator&(LookupOptions a, LookupOptions b) +{ + return (LookupOptions)((std::underlying_type_t)a & + (std::underlying_type_t)b); +} + +class SerialRefObject; - // Masks to be applied when lookup up declarations - enum class LookupMask : uint8_t +// Make sure C++ extractor can see the base class. +SLANG_PRE_DECLARE(OBJ, class SerialRefObject) + +SLANG_TYPE_SET(OBJ, RefObject) +SLANG_TYPE_SET(VALUE, Value) +SLANG_TYPE_SET(AST, ASTNode) + +class LookupResultItem_Breadcrumb : public SerialRefObject +{ +public: + SLANG_OBJ_CLASS(LookupResultItem_Breadcrumb) + + enum class Kind : uint8_t { - type = 0x1, - Function = 0x2, - Value = 0x4, - Attribute = 0x8, - Default = type | Function | Value, + // The lookup process looked "through" an in-scope + // declaration to the fields inside of it, so that + // even if lookup started with a simple name `f`, + // it needs to result in a member expression `obj.f`. + Member, + + // The lookup process took a pointer(-like) value, and then + // proceeded to derefence it and look at the thing(s) + // it points to instead, so that the final expression + // needs to have `(*obj)` + Deref, + + // The lookup process saw a value `obj` of type `T` and + // took into account an in-scope constraint that says + // `T` is a subtype of some other type `U`, so that + // lookup was able to find a member through type `U` + // instead. + SuperType, + + // The lookup process considered a member of an + // enclosing type as being in scope, so that any + // reference to that member needs to use a `this` + // expression as appropriate. + This, }; - /// Flags for options to be used when looking up declarations - enum class LookupOptions : uint8_t - { - None = 0, - IgnoreBaseInterfaces = 1 << 0, - Completion = 1 << 1, ///< Lookup all applicable decls for code completion suggestions - NoDeref = 1 << 2, - ConsiderAllLocalNamesInScope = 1 << 3, - ///^ Normally we rely on the checking state of local names to determine - /// if they have been declared. If the scopes are currently - /// "under-construction" and not being checked, then it's safe to - /// consider all names we've inserted so far. This is used when - /// checking to see if a keyword is shadowed. - IgnoreInheritance = 1 << 4, ///< Lookup only non inheritance children of a struct (including `extension`) - }; - inline LookupOptions operator&(LookupOptions a, LookupOptions b) + // The kind of lookup step that was performed + Kind kind; + + // For the `Kind::This` case, what does the implicit + // `this` or `This` parameter refer to? + // + enum class ThisParameterMode : uint8_t { - return (LookupOptions)((std::underlying_type_t)a & (std::underlying_type_t)b); - } + ImmutableValue, // An immutable `this` value + MutableValue, // A mutable `this` value + Type, // A `This` type - class SerialRefObject; + Default = ImmutableValue, + }; + ThisParameterMode thisParameterMode = ThisParameterMode::Default; - // Make sure C++ extractor can see the base class. - SLANG_PRE_DECLARE(OBJ, class SerialRefObject) + // As needed, a reference to the declaration that faciliated + // the lookup step. + // + // For a `Member` lookup step, this is the declaration whose + // members were implicitly pulled into scope. + // + // For a `Constraint` lookup step, this is the `ConstraintDecl` + // that serves to witness the subtype relationship. + // + DeclRef declRef; + + Val* val = nullptr; - SLANG_TYPE_SET(OBJ, RefObject) - SLANG_TYPE_SET(VALUE, Value) - SLANG_TYPE_SET(AST, ASTNode) + // The next implicit step that the lookup process took to + // arrive at a final value. + RefPtr next; - class LookupResultItem_Breadcrumb : public SerialRefObject + LookupResultItem_Breadcrumb( + Kind kind, + DeclRef declRef, + Val* val, + RefPtr next, + ThisParameterMode thisParameterMode = ThisParameterMode::Default) + : kind(kind), thisParameterMode(thisParameterMode), declRef(declRef), val(val), next(next) { - public: - SLANG_OBJ_CLASS(LookupResultItem_Breadcrumb) + } - enum class Kind : uint8_t - { - // The lookup process looked "through" an in-scope - // declaration to the fields inside of it, so that - // even if lookup started with a simple name `f`, - // it needs to result in a member expression `obj.f`. - Member, - - // The lookup process took a pointer(-like) value, and then - // proceeded to derefence it and look at the thing(s) - // it points to instead, so that the final expression - // needs to have `(*obj)` - Deref, - - // The lookup process saw a value `obj` of type `T` and - // took into account an in-scope constraint that says - // `T` is a subtype of some other type `U`, so that - // lookup was able to find a member through type `U` - // instead. - SuperType, - - // The lookup process considered a member of an - // enclosing type as being in scope, so that any - // reference to that member needs to use a `this` - // expression as appropriate. - This, - }; - - // The kind of lookup step that was performed - Kind kind; - - // For the `Kind::This` case, what does the implicit - // `this` or `This` parameter refer to? - // - enum class ThisParameterMode : uint8_t - { - ImmutableValue, // An immutable `this` value - MutableValue, // A mutable `this` value - Type, // A `This` type - - Default = ImmutableValue, - }; - ThisParameterMode thisParameterMode = ThisParameterMode::Default; - - // As needed, a reference to the declaration that faciliated - // the lookup step. - // - // For a `Member` lookup step, this is the declaration whose - // members were implicitly pulled into scope. - // - // For a `Constraint` lookup step, this is the `ConstraintDecl` - // that serves to witness the subtype relationship. - // - DeclRef declRef; - - Val* val = nullptr; - - // The next implicit step that the lookup process took to - // arrive at a final value. - RefPtr next; - - LookupResultItem_Breadcrumb( - Kind kind, - DeclRef declRef, - Val* val, - RefPtr next, - ThisParameterMode thisParameterMode = ThisParameterMode::Default) - : kind(kind) - , thisParameterMode(thisParameterMode) - , declRef(declRef) - , val(val) - , next(next) - {} - protected: - // Needed for serialization - LookupResultItem_Breadcrumb() = default; - }; +protected: + // Needed for serialization + LookupResultItem_Breadcrumb() = default; +}; - // Represents one item found during lookup - struct LookupResultItem - { - SLANG_VALUE_CLASS(LookupResultItem) - - typedef LookupResultItem_Breadcrumb Breadcrumb; - - // Sometimes lookup finds an item, but there were additional - // "hops" taken to reach it. We need to remember these steps - // so that if/when we consturct a full expression we generate - // appropriate AST nodes for all the steps. - // - // We build up a list of these "breadcrumbs" while doing - // lookup, and store them alongside each item found. - // - // As an example, suppose we have an HLSL `cbuffer` declaration: - // - // cbuffer C { float4 f; } - // - // This is syntax sugar for a global-scope variable of - // type `ConstantBuffer` where `T` is a `struct` containing - // all the members: - // - // struct Anon0 { float4 f; }; - // __transparent ConstantBuffer anon1; - // - // The `__transparent` modifier there captures the fact that - // when somebody writes `f` in their code, they expect it to - // "see through" the `cbuffer` declaration (or the global variable, - // in this case) and find the member inside. - // - // But when the user writes `f` we can't just create a simple - // `VarExpr` that refers directly to that field, because that - // doesn't actually reflect the required steps in a way that - // code generation can use. - // - // Instead we need to construct an expression like `(*anon1).f`, - // where there is are two additional steps in the process: - // - // 1. We needed to dereference the pointer-like type `ConstantBuffer` - // to get at a value of type `Anon0` - // 2. We needed to access a sub-field of the aggregate type `Anon0` - // - // We *could* just create these full-formed expressions during - // lookup, but this might mean creating a large number of - // AST nodes in cases where the user calls an overloaded function. - // At the very least we'd rather not heap-allocate in the common - // case where no "extra" steps need to be performed to get to - // the declarations. - // - // This is where "breadcrumbs" come in. A breadcrumb represents - // an extra "step" that must be performed to turn a declaration - // found by lookup into a valid expression to splice into the - // AST. Most of the time lookup result items don't have any - // breadcrumbs, so that no extra heap allocation takes place. - // When an item does have breadcrumbs, and it is chosen as - // the unique result (perhaps by overload resolution), then - // we can walk the list of breadcrumbs to create a full - // expression. - - - // A properly-specialized reference to the declaration that was found. - DeclRef declRef; - - // Any breadcrumbs needed in order to turn that declaration - // reference into a well-formed expression. - // - // This is unused in the simple case where a declaration - // is being referenced directly (rather than through - // transparent members). - RefPtr breadcrumbs; - - LookupResultItem() = default; - explicit LookupResultItem(DeclRef declRef) - : declRef(declRef) - {} - LookupResultItem(DeclRef declRef, RefPtr breadcrumbs) - : declRef(declRef) - , breadcrumbs(breadcrumbs) - {} - }; +// Represents one item found during lookup +struct LookupResultItem +{ + SLANG_VALUE_CLASS(LookupResultItem) + typedef LookupResultItem_Breadcrumb Breadcrumb; - // Result of looking up a name in some lexical/semantic environment. - // Can be used to enumerate all the declarations matching that name, - // in the case where the result is overloaded. - struct LookupResult + // Sometimes lookup finds an item, but there were additional + // "hops" taken to reach it. We need to remember these steps + // so that if/when we consturct a full expression we generate + // appropriate AST nodes for all the steps. + // + // We build up a list of these "breadcrumbs" while doing + // lookup, and store them alongside each item found. + // + // As an example, suppose we have an HLSL `cbuffer` declaration: + // + // cbuffer C { float4 f; } + // + // This is syntax sugar for a global-scope variable of + // type `ConstantBuffer` where `T` is a `struct` containing + // all the members: + // + // struct Anon0 { float4 f; }; + // __transparent ConstantBuffer anon1; + // + // The `__transparent` modifier there captures the fact that + // when somebody writes `f` in their code, they expect it to + // "see through" the `cbuffer` declaration (or the global variable, + // in this case) and find the member inside. + // + // But when the user writes `f` we can't just create a simple + // `VarExpr` that refers directly to that field, because that + // doesn't actually reflect the required steps in a way that + // code generation can use. + // + // Instead we need to construct an expression like `(*anon1).f`, + // where there is are two additional steps in the process: + // + // 1. We needed to dereference the pointer-like type `ConstantBuffer` + // to get at a value of type `Anon0` + // 2. We needed to access a sub-field of the aggregate type `Anon0` + // + // We *could* just create these full-formed expressions during + // lookup, but this might mean creating a large number of + // AST nodes in cases where the user calls an overloaded function. + // At the very least we'd rather not heap-allocate in the common + // case where no "extra" steps need to be performed to get to + // the declarations. + // + // This is where "breadcrumbs" come in. A breadcrumb represents + // an extra "step" that must be performed to turn a declaration + // found by lookup into a valid expression to splice into the + // AST. Most of the time lookup result items don't have any + // breadcrumbs, so that no extra heap allocation takes place. + // When an item does have breadcrumbs, and it is chosen as + // the unique result (perhaps by overload resolution), then + // we can walk the list of breadcrumbs to create a full + // expression. + + + // A properly-specialized reference to the declaration that was found. + DeclRef declRef; + + // Any breadcrumbs needed in order to turn that declaration + // reference into a well-formed expression. + // + // This is unused in the simple case where a declaration + // is being referenced directly (rather than through + // transparent members). + RefPtr breadcrumbs; + + LookupResultItem() = default; + explicit LookupResultItem(DeclRef declRef) + : declRef(declRef) { - // The one item that was found, in the simple case - LookupResultItem item; + } + LookupResultItem(DeclRef declRef, RefPtr breadcrumbs) + : declRef(declRef), breadcrumbs(breadcrumbs) + { + } +}; + - // All of the items that were found, in the complex case. - // Note: if there was no overloading, then this list isn't - // used at all, to avoid allocation. - // - // Additionally, if `items` is used, then `item` *must* hold an item that - // is also in the items list (typically the first entry), as an invariant. - // Otherwise isValid/begin will not function correctly. - List items; +// Result of looking up a name in some lexical/semantic environment. +// Can be used to enumerate all the declarations matching that name, +// in the case where the result is overloaded. +struct LookupResult +{ + // The one item that was found, in the simple case + LookupResultItem item; - // Was at least one result found? - bool isValid() const { return item.declRef.getDecl() != nullptr; } + // All of the items that were found, in the complex case. + // Note: if there was no overloading, then this list isn't + // used at all, to avoid allocation. + // + // Additionally, if `items` is used, then `item` *must* hold an item that + // is also in the items list (typically the first entry), as an invariant. + // Otherwise isValid/begin will not function correctly. + List items; - bool isOverloaded() const { return items.getCount() > 1; } + // Was at least one result found? + bool isValid() const { return item.declRef.getDecl() != nullptr; } - Name* getName() const - { - return items.getCount() > 1 ? items[0].declRef.getName() : item.declRef.getName(); - } - LookupResultItem* begin() const + bool isOverloaded() const { return items.getCount() > 1; } + + Name* getName() const + { + return items.getCount() > 1 ? items[0].declRef.getName() : item.declRef.getName(); + } + LookupResultItem* begin() const + { + if (isValid()) { - if (isValid()) - { - if (isOverloaded()) - return const_cast(items.begin()); - else - return const_cast(&item); - } + if (isOverloaded()) + return const_cast(items.begin()); else - return nullptr; + return const_cast(&item); } - LookupResultItem* end() const + else + return nullptr; + } + LookupResultItem* end() const + { + if (isValid()) { - if (isValid()) - { - if (isOverloaded()) - return const_cast(items.end()); - else - return const_cast(&item + 1); - } + if (isOverloaded()) + return const_cast(items.end()); else - return nullptr; + return const_cast(&item + 1); } - }; - - // A helper to avoid having to include slang-check-impl.h in slang-syntax.h - struct SemanticsVisitor; - ASTBuilder* semanticsVisitorGetASTBuilder(SemanticsVisitor*); - - struct LookupRequest - { - SemanticsVisitor* semantics = nullptr; - Scope* scope = nullptr; - Scope* endScope = nullptr; + else + return nullptr; + } +}; - // A decl to exclude from the lookup, used to exclude the current decl being checked, such as in typedef Foo Foo; - // to avoid finding itself. - Decl* declToExclude = nullptr; - LookupMask mask = LookupMask::Default; - LookupOptions options = LookupOptions::None; +// A helper to avoid having to include slang-check-impl.h in slang-syntax.h +struct SemanticsVisitor; +ASTBuilder* semanticsVisitorGetASTBuilder(SemanticsVisitor*); - bool isCompletionRequest() const { return (options & LookupOptions::Completion) != LookupOptions::None; } - bool shouldConsiderAllLocalNames() const { return (options & LookupOptions::ConsiderAllLocalNamesInScope) != LookupOptions::None; } - }; +struct LookupRequest +{ + SemanticsVisitor* semantics = nullptr; + Scope* scope = nullptr; + Scope* endScope = nullptr; - struct WitnessTable; + // A decl to exclude from the lookup, used to exclude the current decl being checked, such as in + // typedef Foo Foo; to avoid finding itself. + Decl* declToExclude = nullptr; + LookupMask mask = LookupMask::Default; + LookupOptions options = LookupOptions::None; - // A value that witnesses the satisfaction of an interface - // requirement by a particular declaration or value. - struct RequirementWitness + bool isCompletionRequest() const { - SLANG_VALUE_CLASS(RequirementWitness) - - RequirementWitness() - : m_flavor(Flavor::none) - {} - - RequirementWitness(DeclRefBase* declRef) - : m_flavor(Flavor::declRef) - , m_declRef(declRef) - {} - - RequirementWitness(Val* val); - - RequirementWitness(RefPtr witnessTable); + return (options & LookupOptions::Completion) != LookupOptions::None; + } + bool shouldConsiderAllLocalNames() const + { + return (options & LookupOptions::ConsiderAllLocalNamesInScope) != LookupOptions::None; + } +}; - enum class Flavor - { - none, - declRef, - val, - witnessTable, - }; +struct WitnessTable; - Flavor getFlavor() const - { - return m_flavor; - } +// A value that witnesses the satisfaction of an interface +// requirement by a particular declaration or value. +struct RequirementWitness +{ + SLANG_VALUE_CLASS(RequirementWitness) - DeclRef getDeclRef() - { - SLANG_ASSERT(getFlavor() == Flavor::declRef); - return m_declRef; - } + RequirementWitness() + : m_flavor(Flavor::none) + { + } - Val* getVal() - { - SLANG_ASSERT(getFlavor() == Flavor::val); - return m_val; - } + RequirementWitness(DeclRefBase* declRef) + : m_flavor(Flavor::declRef), m_declRef(declRef) + { + } - RefPtr getWitnessTable(); + RequirementWitness(Val* val); - RequirementWitness specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst); + RequirementWitness(RefPtr witnessTable); - Flavor m_flavor; - DeclRef m_declRef; - RefPtr m_obj; - Val* m_val = nullptr; + enum class Flavor + { + none, + declRef, + val, + witnessTable, }; - typedef OrderedDictionary RequirementDictionary; + Flavor getFlavor() const { return m_flavor; } - struct WitnessTable : SerialRefObject + DeclRef getDeclRef() { - SLANG_OBJ_CLASS(WitnessTable) + SLANG_ASSERT(getFlavor() == Flavor::declRef); + return m_declRef; + } - const RequirementDictionary& getRequirementDictionary() - { - return m_requirementDictionary; - } + Val* getVal() + { + SLANG_ASSERT(getFlavor() == Flavor::val); + return m_val; + } - void add(Decl* decl, RequirementWitness const& witness); + RefPtr getWitnessTable(); - // The type that the witness table witnesses conformance to (e.g. an Interface) - Type* baseType; + RequirementWitness specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst); - // The type witnessesd by the witness table (a concrete type). - Type* witnessedType; + Flavor m_flavor; + DeclRef m_declRef; + RefPtr m_obj; + Val* m_val = nullptr; +}; - // Whether or not this witness table is an extern declaration. - bool isExtern = false; +typedef OrderedDictionary RequirementDictionary; - // Cached dictionary for looking up satisfying values. - RequirementDictionary m_requirementDictionary; +struct WitnessTable : SerialRefObject +{ + SLANG_OBJ_CLASS(WitnessTable) - RefPtr specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst); + const RequirementDictionary& getRequirementDictionary() { return m_requirementDictionary; } - }; + void add(Decl* decl, RequirementWitness const& witness); - struct SpecializationParam - { - enum class Flavor - { - GenericType, - GenericValue, - ExistentialType, - ExistentialValue, - }; - Flavor flavor; - SourceLoc loc; - NodeBase* object = nullptr; - }; - typedef List SpecializationParams; + // The type that the witness table witnesses conformance to (e.g. an Interface) + Type* baseType; - struct SpecializationArg - { - SLANG_VALUE_CLASS(SpecializationArg) - Val* val = nullptr; - }; - typedef List SpecializationArgs; + // The type witnessesd by the witness table (a concrete type). + Type* witnessedType; - struct ExpandedSpecializationArg : SpecializationArg - { - SLANG_VALUE_CLASS(ExpandedSpecializationArg) - Val* witness = nullptr; - }; - typedef List ExpandedSpecializationArgs; + // Whether or not this witness table is an extern declaration. + bool isExtern = false; - /// A reference-counted object to hold a list of candidate extensions - /// that might be applicable to a type based on its declaration. - /// - struct CandidateExtensionList : RefObject - { - List candidateExtensions; - }; + // Cached dictionary for looking up satisfying values. + RequirementDictionary m_requirementDictionary; + RefPtr specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst); +}; - enum class DeclAssociationKind +struct SpecializationParam +{ + enum class Flavor { - ForwardDerivativeFunc, BackwardDerivativeFunc, PrimalSubstituteFunc + GenericType, + GenericValue, + ExistentialType, + ExistentialValue, }; + Flavor flavor; + SourceLoc loc; + NodeBase* object = nullptr; +}; +typedef List SpecializationParams; - struct DeclAssociation : SerialRefObject - { - SLANG_OBJ_CLASS(DeclAssociation) - DeclAssociationKind kind; - Decl* decl; - }; +struct SpecializationArg +{ + SLANG_VALUE_CLASS(SpecializationArg) + Val* val = nullptr; +}; +typedef List SpecializationArgs; - /// A reference-counted object to hold a list of associated decls for a decl. - /// - struct DeclAssociationList : SerialRefObject - { - SLANG_OBJ_CLASS(DeclAssociationList) +struct ExpandedSpecializationArg : SpecializationArg +{ + SLANG_VALUE_CLASS(ExpandedSpecializationArg) + Val* witness = nullptr; +}; +typedef List ExpandedSpecializationArgs; + +/// A reference-counted object to hold a list of candidate extensions +/// that might be applicable to a type based on its declaration. +/// +struct CandidateExtensionList : RefObject +{ + List candidateExtensions; +}; - List> associations; - }; - /// Represents the "direction" that a parameter is being passed (e.g., `in` or `out` - enum ParameterDirection - { - kParameterDirection_In, ///< Copy in - kParameterDirection_Out, ///< Copy out - kParameterDirection_InOut, ///< Copy in, copy out - kParameterDirection_Ref, ///< By-reference - kParameterDirection_ConstRef, ///< By-const-reference - }; +enum class DeclAssociationKind +{ + ForwardDerivativeFunc, + BackwardDerivativeFunc, + PrimalSubstituteFunc +}; - /// The kind of a builtin interface requirement that can be automatically synthesized. - enum class BuiltinRequirementKind - { - DefaultInitializableConstructor, ///< The `IDefaultInitializable.__init()` method - - DifferentialType, ///< The `IDifferentiable.Differential` associated type requirement - DifferentialPtrType, ///< The `IDifferentiable.DifferentialPtr` associated type requirement - DZeroFunc, ///< The `IDifferentiable.dzero` function requirement - DAddFunc, ///< The `IDifferentiable.dadd` function requirement - DMulFunc, ///< The `IDifferentiable.dmul` function requirement - - InitLogicalFromInt, ///< The `ILogical.__init` mtehod. - Equals, ///< The `ILogical.equals` mtehod. - LessThan, ///< The `ILogical.lessThan` mtehod. - LessThanOrEquals, ///< The `ILogical.lessThanOrEquals` mtehod. - Shl, ///< The `ILogical.shl` mtehod. - Shr, ///< The `ILogical.shr` mtehod. - BitAnd, ///< The `ILogical.bitAnd` mtehod. - BitOr, ///< The `ILogical.bitOr` mtehod. - BitXor, ///< The `ILogical.bitXor` mtehod. - BitNot, ///< The `ILogical.bitNot` mtehod. - And, ///< The `ILogical.and` mtehod. - Or, ///< The `ILogical.or` mtehod. - Not, ///< The `ILogical.not` mtehod. - }; +struct DeclAssociation : SerialRefObject +{ + SLANG_OBJ_CLASS(DeclAssociation) + DeclAssociationKind kind; + Decl* decl; +}; + +/// A reference-counted object to hold a list of associated decls for a decl. +/// +struct DeclAssociationList : SerialRefObject +{ + SLANG_OBJ_CLASS(DeclAssociationList) - enum class FunctionDifferentiableLevel - { - None, - Forward, - Backward - }; + List> associations; +}; - /// Represents a markup (documentation) associated with a decl. - struct MarkupEntry : public SerialRefObject - { - SLANG_OBJ_CLASS(MarkupEntry) +/// Represents the "direction" that a parameter is being passed (e.g., `in` or `out` +enum ParameterDirection +{ + kParameterDirection_In, ///< Copy in + kParameterDirection_Out, ///< Copy out + kParameterDirection_InOut, ///< Copy in, copy out + kParameterDirection_Ref, ///< By-reference + kParameterDirection_ConstRef, ///< By-const-reference +}; + +/// The kind of a builtin interface requirement that can be automatically synthesized. +enum class BuiltinRequirementKind +{ + DefaultInitializableConstructor, ///< The `IDefaultInitializable.__init()` method + + DifferentialType, ///< The `IDifferentiable.Differential` associated type requirement + DifferentialPtrType, ///< The `IDifferentiable.DifferentialPtr` associated type requirement + DZeroFunc, ///< The `IDifferentiable.dzero` function requirement + DAddFunc, ///< The `IDifferentiable.dadd` function requirement + DMulFunc, ///< The `IDifferentiable.dmul` function requirement + + InitLogicalFromInt, ///< The `ILogical.__init` mtehod. + Equals, ///< The `ILogical.equals` mtehod. + LessThan, ///< The `ILogical.lessThan` mtehod. + LessThanOrEquals, ///< The `ILogical.lessThanOrEquals` mtehod. + Shl, ///< The `ILogical.shl` mtehod. + Shr, ///< The `ILogical.shr` mtehod. + BitAnd, ///< The `ILogical.bitAnd` mtehod. + BitOr, ///< The `ILogical.bitOr` mtehod. + BitXor, ///< The `ILogical.bitXor` mtehod. + BitNot, ///< The `ILogical.bitNot` mtehod. + And, ///< The `ILogical.and` mtehod. + Or, ///< The `ILogical.or` mtehod. + Not, ///< The `ILogical.not` mtehod. +}; + +enum class FunctionDifferentiableLevel +{ + None, + Forward, + Backward +}; - NodeBase* m_node; ///< The node this documentation is associated with - String m_markup; ///< The raw contents of of markup associated with the decoration - MarkupVisibility m_visibility = MarkupVisibility::Public; ///< How visible this decl is - }; +/// Represents a markup (documentation) associated with a decl. +struct MarkupEntry : public SerialRefObject +{ + SLANG_OBJ_CLASS(MarkupEntry) - /// Get the inner most expr from an higher order expr chain, e.g. `__fwd_diff(__fwd_diff(f))`'s - /// inner most expr is `f`. - Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr, FunctionDifferentiableLevel& outDiffLevel); - inline Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr) - { - FunctionDifferentiableLevel level; - return getInnerMostExprFromHigherOrderExpr(expr, level); - } + NodeBase* m_node; ///< The node this documentation is associated with + String m_markup; ///< The raw contents of of markup associated with the decoration + MarkupVisibility m_visibility = MarkupVisibility::Public; ///< How visible this decl is +}; +/// Get the inner most expr from an higher order expr chain, e.g. `__fwd_diff(__fwd_diff(f))`'s +/// inner most expr is `f`. +Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr, FunctionDifferentiableLevel& outDiffLevel); +inline Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr) +{ + FunctionDifferentiableLevel level; + return getInnerMostExprFromHigherOrderExpr(expr, level); +} - /// Get the operator name from the higher order invoke expr. - UnownedStringSlice getHigherOrderOperatorName(HigherOrderInvokeExpr* expr); - enum class DeclVisibility - { - Private, - Internal, - Public, - Default = Internal, - }; +/// Get the operator name from the higher order invoke expr. +UnownedStringSlice getHigherOrderOperatorName(HigherOrderInvokeExpr* expr); + +enum class DeclVisibility +{ + Private, + Internal, + Public, + Default = Internal, +}; } // namespace Slang diff --git a/source/slang/slang-ast-synthesis.cpp b/source/slang/slang-ast-synthesis.cpp index cb7d338c8..46ba81d16 100644 --- a/source/slang/slang-ast-synthesis.cpp +++ b/source/slang/slang-ast-synthesis.cpp @@ -5,7 +5,8 @@ namespace Slang Expr* ASTSynthesizer::emitBinaryExpr(UnownedStringSlice operatorToken, Expr* left, Expr* right) { auto infixExpr = m_builder->create(); - infixExpr->functionExpr = emitVarExpr(m_namePool->getName(operatorToken));; + infixExpr->functionExpr = emitVarExpr(m_namePool->getName(operatorToken)); + ; infixExpr->arguments.add(left); infixExpr->arguments.add(right); return infixExpr; @@ -14,7 +15,8 @@ Expr* ASTSynthesizer::emitBinaryExpr(UnownedStringSlice operatorToken, Expr* lef Expr* ASTSynthesizer::emitPrefixExpr(UnownedStringSlice operatorToken, Expr* base) { auto prefixExpr = m_builder->create(); - prefixExpr->functionExpr = emitVarExpr(m_namePool->getName(operatorToken));; + prefixExpr->functionExpr = emitVarExpr(m_namePool->getName(operatorToken)); + ; prefixExpr->arguments.add(base); return prefixExpr; } @@ -22,12 +24,13 @@ Expr* ASTSynthesizer::emitPrefixExpr(UnownedStringSlice operatorToken, Expr* bas Expr* ASTSynthesizer::emitPostfixExpr(UnownedStringSlice operatorToken, Expr* base) { auto postfixExpr = m_builder->create(); - postfixExpr->functionExpr = emitVarExpr(m_namePool->getName(operatorToken));; + postfixExpr->functionExpr = emitVarExpr(m_namePool->getName(operatorToken)); + ; postfixExpr->arguments.add(base); return postfixExpr; } -ForStmt* ASTSynthesizer::emitFor(Expr* initVal, Expr* finalVal, VarDecl* &outIndexVar) +ForStmt* ASTSynthesizer::emitFor(Expr* initVal, Expr* finalVal, VarDecl*& outIndexVar) { auto parentStmt = getCurrentScope().m_parentSeqStmt; auto seqStmt = m_builder->create(); @@ -38,7 +41,8 @@ ForStmt* ASTSynthesizer::emitFor(Expr* initVal, Expr* finalVal, VarDecl* &outInd auto declStmt = emitVarDeclStmt(nullptr, m_namePool->getName("S_synth_loop_index"), initVal); stmt->initialStatement = declStmt; outIndexVar = (VarDecl*)declStmt->decl; - auto predicateExpr = emitBinaryExpr(UnownedStringSlice("<"), emitVarExpr(outIndexVar), finalVal); + auto predicateExpr = + emitBinaryExpr(UnownedStringSlice("<"), emitVarExpr(outIndexVar), finalVal); stmt->predicateExpression = predicateExpr; stmt->sideEffectExpression = emitPrefixExpr(UnownedStringSlice("++"), emitVarExpr(outIndexVar)); parentStmt->stmts.add(stmt); @@ -183,4 +187,4 @@ DeclStmt* ASTSynthesizer::emitVarDeclStmt(Type* type, Name* name, Expr* initVal) return stmt; } -} +} // namespace Slang diff --git a/source/slang/slang-ast-synthesis.h b/source/slang/slang-ast-synthesis.h index 0726360d3..b68bea39c 100644 --- a/source/slang/slang-ast-synthesis.h +++ b/source/slang/slang-ast-synthesis.h @@ -4,7 +4,8 @@ #include "slang-syntax.h" -namespace Slang { +namespace Slang +{ struct ASTEmitScope { @@ -18,10 +19,10 @@ private: ASTBuilder* m_builder; NamePool* m_namePool; List m_scopeStack; + public: ASTSynthesizer(ASTBuilder* builder, NamePool* namePool) - : m_builder(builder) - , m_namePool(namePool) + : m_builder(builder), m_namePool(namePool) { } @@ -97,10 +98,7 @@ public: return scope.m_parentSeqStmt; } - void popScope() - { - m_scopeStack.removeLast(); - } + void popScope() { m_scopeStack.removeLast(); } ASTEmitScope getCurrentScope() { @@ -146,7 +144,6 @@ public: ExpressionStmt* emitExprStmt(Expr* expr); ReturnStmt* emitReturnStmt(Expr* expr); - }; } // namespace Slang diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index 0feea3fc9..529cbb1be 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -1,13 +1,13 @@ // slang-ast-type.cpp #include "slang-ast-builder.h" #include "slang-ast-modifier.h" -#include -#include - +#include "slang-generated-ast-macro.h" #include "slang-syntax.h" -#include "slang-generated-ast-macro.h" -namespace Slang { +#include +#include +namespace Slang +{ bool isAbstractTypePack(Type* type) { @@ -90,17 +90,27 @@ Type* ErrorType::_createCanonicalTypeOverride() return this; } -Val* ErrorType::_substituteImplOverride(ASTBuilder* /* astBuilder */, SubstitutionSet /*subst*/, int* /*ioDiff*/) +Val* ErrorType::_substituteImplOverride( + ASTBuilder* /* astBuilder */, + SubstitutionSet /*subst*/, + int* /*ioDiff*/ +) { return this; } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! BottomType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -void BottomType::_toTextOverride(StringBuilder& out) { out << toSlice("never"); } +void BottomType::_toTextOverride(StringBuilder& out) +{ + out << toSlice("never"); +} Val* BottomType::_substituteImplOverride( - ASTBuilder* /* astBuilder */, SubstitutionSet /*subst*/, int* /*ioDiff*/) + ASTBuilder* /* astBuilder */, + SubstitutionSet /*subst*/, + int* /*ioDiff*/ +) { return this; } @@ -112,11 +122,19 @@ void DeclRefType::_toTextOverride(StringBuilder& out) out << getDeclRef(); } -Val* maybeSubstituteGenericParam(Val* paramVal, Decl* paramDecl, SubstitutionSet subst, int* ioDiff); +Val* maybeSubstituteGenericParam( + Val* paramVal, + Decl* paramDecl, + SubstitutionSet subst, + int* ioDiff); -Val* DeclRefType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* DeclRefType::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { - if (!subst) return this; + if (!subst) + return this; int diff = 0; DeclRef substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff); @@ -135,9 +153,12 @@ Val* DeclRefType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSe return lookupDeclRef->getLookupSource(); } } - else if (as(substDeclRef.getDecl()) || as(substDeclRef.getDecl())) + else if ( + as(substDeclRef.getDecl()) || + as(substDeclRef.getDecl())) { - auto resultVal = maybeSubstituteGenericParam(nullptr, substDeclRef.getDecl(), subst, ioDiff); + auto resultVal = + maybeSubstituteGenericParam(nullptr, substDeclRef.getDecl(), subst, ioDiff); if (resultVal) { (*ioDiff)++; @@ -173,15 +194,13 @@ Val* DeclRefType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSe // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ArithmeticExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -BasicExpressionType* ArithmeticExpressionType::getScalarType() -{ - SLANG_AST_NODE_VIRTUAL_CALL(ArithmeticExpressionType, getScalarType, ()) -} +BasicExpressionType* ArithmeticExpressionType::getScalarType(){ + SLANG_AST_NODE_VIRTUAL_CALL(ArithmeticExpressionType, getScalarType, ())} BasicExpressionType* ArithmeticExpressionType::_getScalarTypeOverride() { SLANG_UNEXPECTED("ArithmeticExpressionType::_getScalarTypeOverride not overridden"); - //return nullptr; + // return nullptr; } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! BasicExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! @@ -226,7 +245,8 @@ IntVal* VectorExpressionType::getElementCount() void VectorExpressionType::_toTextOverride(StringBuilder& out) { - out << toSlice("vector<") << getElementType() << toSlice(",") << getElementCount() << toSlice(">"); + out << toSlice("vector<") << getElementType() << toSlice(",") << getElementCount() + << toSlice(">"); } BasicExpressionType* VectorExpressionType::_getScalarTypeOverride() @@ -238,7 +258,8 @@ BasicExpressionType* VectorExpressionType::_getScalarTypeOverride() void MatrixExpressionType::_toTextOverride(StringBuilder& out) { - out << toSlice("matrix<") << getElementType() << toSlice(",") << getRowCount() << toSlice(",") << getColumnCount() << toSlice(">"); + out << toSlice("matrix<") << getElementType() << toSlice(",") << getRowCount() << toSlice(",") + << getColumnCount() << toSlice(">"); } BasicExpressionType* MatrixExpressionType::_getScalarTypeOverride() @@ -363,7 +384,7 @@ Type* GenericDeclRefType::_createCanonicalTypeOverride() void NamespaceType::_toTextOverride(StringBuilder& out) { - out << toSlice("namespace ") << getDeclRef(); + out << toSlice("namespace ") << getDeclRef(); } Type* NamespaceType::_createCanonicalTypeOverride() @@ -415,22 +436,12 @@ void maybePrintAddrSpaceOperand(StringBuilder& out, AddressSpace addrSpace) switch (addrSpace) { case AddressSpace::Generic: - case AddressSpace::UserPointer: - break; - case AddressSpace::GroupShared: - out << toSlice(", groupshared"); - break; - case AddressSpace::Global: - out << toSlice(", global"); - break; - case AddressSpace::ThreadLocal: - out << toSlice(", threadlocal"); - break; - case AddressSpace::Uniform: - out << toSlice(", uniform"); - break; - default: - break; + case AddressSpace::UserPointer: break; + case AddressSpace::GroupShared: out << toSlice(", groupshared"); break; + case AddressSpace::Global: out << toSlice(", global"); break; + case AddressSpace::ThreadLocal: out << toSlice(", threadlocal"); break; + case AddressSpace::Uniform: out << toSlice(", uniform"); break; + default: break; } } @@ -531,7 +542,7 @@ Val* FuncType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet s // parameter types List substParamTypes; - for (Index pp = 0; pp < getParamCount(); pp++ ) + for (Index pp = 0; pp < getParamCount(); pp++) { substParamTypes.add(as(getParamType(pp)->substituteImpl(astBuilder, subst, &diff))); } @@ -541,7 +552,8 @@ Val* FuncType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet s return this; (*ioDiff)++; - FuncType* substType = astBuilder->getFuncType(substParamTypes.getArrayView(), substResultType, substErrorType); + FuncType* substType = + astBuilder->getFuncType(substParamTypes.getArrayView(), substResultType, substErrorType); return substType; } @@ -558,7 +570,10 @@ Type* FuncType::_createCanonicalTypeOverride() canParamTypes.add(getParamType(pp)->getCanonicalType()); } - FuncType* canType = getCurrentASTBuilder()->getFuncType(canParamTypes.getArrayView(), canResultType, canErrorType); + FuncType* canType = getCurrentASTBuilder()->getFuncType( + canParamTypes.getArrayView(), + canResultType, + canErrorType); return canType; } @@ -624,7 +639,9 @@ Type* ExpandType::_createCanonicalTypeOverride() { capturedPacks.add(getCapturedTypePack(i)); } - return getCurrentASTBuilder()->getExpandType(canonicalPatternType, capturedPacks.getArrayView().arrayView); + return getCurrentASTBuilder()->getExpandType( + canonicalPatternType, + capturedPacks.getArrayView().arrayView); } Val* ExpandType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) @@ -640,7 +657,8 @@ Val* ExpandType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet ShortList concreteTypePacks; for (Index i = 0; i < getCapturedTypePackCount(); i++) { - auto substCapturedTypePack = getCapturedTypePack(i)->substituteImpl(astBuilder, subst, &diff); + auto substCapturedTypePack = + getCapturedTypePack(i)->substituteImpl(astBuilder, subst, &diff); if (auto expandType = as(substCapturedTypePack)) { for (Index j = 0; j < expandType->getCapturedTypePackCount(); j++) @@ -655,7 +673,7 @@ Val* ExpandType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet } } } - + if (!diff || concreteTypePacks.getCount() != capturedPacks.getCount()) { auto substPatternType = getPatternType()->substituteImpl(astBuilder, subst, &diff); @@ -667,15 +685,17 @@ Val* ExpandType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet // create a new ExpandType with the substituted pattern/capture types, instead of actually // expanding into a concrete type pack. (*ioDiff)++; - return astBuilder->getExpandType(as(substPatternType), capturedPacks.getArrayView().arrayView); + return astBuilder->getExpandType( + as(substPatternType), + capturedPacks.getArrayView().arrayView); } else { - // All type pack parameters are now concrete type packs, so we can construct a concrete type pack - // by substituting the pattern type with each element of the captured type pack. + // All type pack parameters are now concrete type packs, so we can construct a concrete type + // pack by substituting the pattern type with each element of the captured type pack. ShortList expandedTypes; SLANG_ASSERT(capturedPacks.getCount() != 0); - + for (Index i = 0; i < concreteTypePacks[0]->getTypeCount(); i++) { subst.packExpansionIndex = i; @@ -710,7 +730,10 @@ Type* ConcreteTypePack::_createCanonicalTypeOverride() return getCurrentASTBuilder()->getTypePack(canonicalElementTypes.getArrayView().arrayView); } -Val* ConcreteTypePack::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* ConcreteTypePack::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; ShortList substElementTypes; @@ -737,19 +760,26 @@ Type* ExtractExistentialType::_createCanonicalTypeOverride() return this; } -Val* ExtractExistentialType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* ExtractExistentialType::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; auto substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff); - auto substOriginalInterfaceType = getOriginalInterfaceType()->substituteImpl(astBuilder, subst, &diff); - auto substOriginalInterfaceDeclRef = getOriginalInterfaceDeclRef().substituteImpl(astBuilder, subst, &diff); + auto substOriginalInterfaceType = + getOriginalInterfaceType()->substituteImpl(astBuilder, subst, &diff); + auto substOriginalInterfaceDeclRef = + getOriginalInterfaceDeclRef().substituteImpl(astBuilder, subst, &diff); if (!diff) return this; (*ioDiff)++; ExtractExistentialType* substValue = astBuilder->getOrCreate( - substDeclRef, as(substOriginalInterfaceType), substOriginalInterfaceDeclRef); + substDeclRef, + as(substOriginalInterfaceType), + substOriginalInterfaceDeclRef); return substValue; } @@ -758,7 +788,11 @@ SubtypeWitness* ExtractExistentialType::getSubtypeWitness() if (auto cachedValue = this->cachedSubtypeWitness) return cachedValue; - ExtractExistentialSubtypeWitness* openedWitness = getCurrentASTBuilder()->getOrCreate(this, getOriginalInterfaceType(), getDeclRef()); + ExtractExistentialSubtypeWitness* openedWitness = + getCurrentASTBuilder()->getOrCreate( + this, + getOriginalInterfaceType(), + getDeclRef()); this->cachedSubtypeWitness = openedWitness; return openedWitness; } @@ -781,7 +815,8 @@ DeclRef ExtractExistentialType::getThisTypeDeclRef() } SLANG_ASSERT(thisTypeDecl); - DeclRef specialiedInterfaceDeclRef = getCurrentASTBuilder()->getLookupDeclRef(openedWitness, thisTypeDecl).as(); + DeclRef specialiedInterfaceDeclRef = + getCurrentASTBuilder()->getLookupDeclRef(openedWitness, thisTypeDecl).as(); this->cachedThisTypeDeclRef = specialiedInterfaceDeclRef; return specialiedInterfaceDeclRef; @@ -825,20 +860,25 @@ Type* ExistentialSpecializedType::_createCanonicalTypeOverride() newArgs.add(canArg); } - ExistentialSpecializedType* canType = getCurrentASTBuilder()->getOrCreate( - getBaseType()->getCanonicalType(), - newArgs); + ExistentialSpecializedType* canType = + getCurrentASTBuilder()->getOrCreate( + getBaseType()->getCanonicalType(), + newArgs); return canType; } static Val* _substituteImpl(ASTBuilder* astBuilder, Val* val, SubstitutionSet subst, int* ioDiff) { - if (!val) return nullptr; + if (!val) + return nullptr; return val->substituteImpl(astBuilder, subst, ioDiff); } -Val* ExistentialSpecializedType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* ExistentialSpecializedType::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; @@ -859,7 +899,8 @@ Val* ExistentialSpecializedType::_substituteImplOverride(ASTBuilder* astBuilder, (*ioDiff)++; - ExistentialSpecializedType* substType = astBuilder->getOrCreate(substBaseType, substArgs); + ExistentialSpecializedType* substType = + astBuilder->getOrCreate(substBaseType, substArgs); return substType; } @@ -922,10 +963,10 @@ Val* AndType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet su { int diff = 0; - auto substLeft = as(getLeft()->substituteImpl(astBuilder, subst, &diff)); + auto substLeft = as(getLeft()->substituteImpl(astBuilder, subst, &diff)); auto substRight = as(getRight()->substituteImpl(astBuilder, subst, &diff)); - if(!diff) + if (!diff) return this; (*ioDiff)++; @@ -938,7 +979,7 @@ Val* AndType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet su void ModifiedType::_toTextOverride(StringBuilder& out) { - for( Index i = 0; i < getModifierCount(); i++ ) + for (Index i = 0; i < getModifierCount(); i++) { getModifier(i)->toText(out); out.appendChar(' '); @@ -954,11 +995,16 @@ Type* ModifiedType::_createCanonicalTypeOverride() auto modifier = this->getModifier(i); modifiers.add(modifier); } - ModifiedType* canonical = getCurrentASTBuilder()->getOrCreate(getBase()->getCanonicalType(), modifiers.getArrayView()); + ModifiedType* canonical = getCurrentASTBuilder()->getOrCreate( + getBase()->getCanonicalType(), + modifiers.getArrayView()); return canonical; } -Val* ModifiedType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* ModifiedType::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; Type* substBase = as(getBase()->substituteImpl(astBuilder, subst, &diff)); @@ -971,12 +1017,13 @@ Val* ModifiedType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionS substModifiers.add(substModifier); } - if(!diff) + if (!diff) return this; *ioDiff = 1; - ModifiedType* substType = getCurrentASTBuilder()->getOrCreate(substBase, substModifiers.getArrayView()); + ModifiedType* substType = + getCurrentASTBuilder()->getOrCreate(substBase, substModifiers.getArrayView()); return substType; } @@ -1083,16 +1130,12 @@ SlangResourceAccess ResourceType::getAccess() { switch (constIntVal->getValue()) { - case kCoreModule_ResourceAccessReadOnly: - return SLANG_RESOURCE_ACCESS_READ; - case kCoreModule_ResourceAccessReadWrite: - return SLANG_RESOURCE_ACCESS_READ_WRITE; + case kCoreModule_ResourceAccessReadOnly: return SLANG_RESOURCE_ACCESS_READ; + case kCoreModule_ResourceAccessReadWrite: return SLANG_RESOURCE_ACCESS_READ_WRITE; case kCoreModule_ResourceAccessRasterizerOrdered: return SLANG_RESOURCE_ACCESS_RASTER_ORDERED; - case kCoreModule_ResourceAccessFeedback: - return SLANG_RESOURCE_ACCESS_FEEDBACK; - default: - break; + case kCoreModule_ResourceAccessFeedback: return SLANG_RESOURCE_ACCESS_FEEDBACK; + default: break; } } return SLANG_RESOURCE_ACCESS_NONE; @@ -1117,91 +1160,78 @@ Type* ResourceType::getElementType() void ResourceType::_toTextOverride(StringBuilder& out) { auto tryPrintSimpleName = [&](String& outString) -> bool + { + StringBuilder resultSB; + auto access = getAccess(); + switch (access) { - StringBuilder resultSB; - auto access = getAccess(); - switch (access) - { - case SLANG_RESOURCE_ACCESS_READ: - break; - case SLANG_RESOURCE_ACCESS_READ_WRITE: - resultSB << "RW";; - break; - case SLANG_RESOURCE_ACCESS_RASTER_ORDERED: - resultSB << "RasterizerOrdered"; - break; - case SLANG_RESOURCE_ACCESS_FEEDBACK: - resultSB << "Feedback"; - break; - default: - return false; - } - auto combined = as(_getGenericTypeArg(this, 7)); - auto shapeVal = _getGenericTypeArg(this, 1); - if (!as(shapeVal)) - return false; - auto shape = getBaseShape(); - if (!combined) - return false; - if (combined->getValue() != 0) - resultSB << "Sampler"; + case SLANG_RESOURCE_ACCESS_READ: break; + case SLANG_RESOURCE_ACCESS_READ_WRITE: + resultSB << "RW"; + ; + break; + case SLANG_RESOURCE_ACCESS_RASTER_ORDERED: resultSB << "RasterizerOrdered"; break; + case SLANG_RESOURCE_ACCESS_FEEDBACK: resultSB << "Feedback"; break; + default: return false; + } + auto combined = as(_getGenericTypeArg(this, 7)); + auto shapeVal = _getGenericTypeArg(this, 1); + if (!as(shapeVal)) + return false; + auto shape = getBaseShape(); + if (!combined) + return false; + if (combined->getValue() != 0) + resultSB << "Sampler"; + else + { + if (shape == SLANG_TEXTURE_BUFFER) + resultSB << "Buffer"; else + resultSB << "Texture"; + } + switch (shape) + { + case SLANG_TEXTURE_1D: resultSB << "1D"; break; + case SLANG_TEXTURE_2D: resultSB << "2D"; break; + case SLANG_TEXTURE_3D: resultSB << "3D"; break; + case SLANG_TEXTURE_CUBE: resultSB << "Cube"; break; + } + auto isArrayVal = as(_getGenericTypeArg(this, 2)); + if (!isArrayVal) + return false; + if (isArray()) + resultSB << "Array"; + auto isMultisampleVal = as(_getGenericTypeArg(this, 3)); + if (!isMultisampleVal) + return false; + if (isMultisample()) + resultSB << "MS"; + auto isShadowVal = as(_getGenericTypeArg(this, 6)); + if (!isShadowVal) + return false; + if (isShadow()) + return false; + auto elementType = getElementType(); + if (elementType) + { + resultSB << "<"; + resultSB << elementType->toString(); + auto sampleCount = _getGenericTypeArg(this, 4); + if (auto constIntVal = as(sampleCount)) { - if (shape == SLANG_TEXTURE_BUFFER) - resultSB << "Buffer"; - else - resultSB << "Texture"; + if (constIntVal->getValue() != 0) + resultSB << ", " << constIntVal->getValue(); } - switch (shape) + else { - case SLANG_TEXTURE_1D: - resultSB << "1D"; - break; - case SLANG_TEXTURE_2D: - resultSB << "2D"; - break; - case SLANG_TEXTURE_3D: - resultSB << "3D"; - break; - case SLANG_TEXTURE_CUBE: - resultSB << "Cube"; - break; - } - auto isArrayVal = as(_getGenericTypeArg(this, 2)); - if (!isArrayVal) - return false; - if (isArray()) - resultSB << "Array"; - auto isMultisampleVal = as(_getGenericTypeArg(this, 3)); - if (!isMultisampleVal) - return false; - if (isMultisample()) - resultSB << "MS"; - auto isShadowVal = as(_getGenericTypeArg(this, 6)); - if (!isShadowVal) return false; - if (isShadow()) - return false; - auto elementType = getElementType(); - if (elementType) - { - resultSB << "<"; - resultSB << elementType->toString(); - auto sampleCount = _getGenericTypeArg(this, 4); - if (auto constIntVal = as(sampleCount)) - { - if (constIntVal->getValue() != 0) - resultSB << ", " << constIntVal->getValue(); - } - else - { - return false; - } - resultSB << ">"; } - outString = resultSB.toString(); - return true; - }; + resultSB << ">"; + } + outString = resultSB.toString(); + return true; + }; String simpleName; diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 1731cdb53..c55e0011f 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -4,12 +4,13 @@ #include "slang-ast-base.h" -namespace Slang { +namespace Slang +{ // Syntax class definitions for types. // The type of a reference to an overloaded name -class OverloadGroupType : public Type +class OverloadGroupType : public Type { SLANG_AST_CLASS(OverloadGroupType) @@ -20,7 +21,7 @@ class OverloadGroupType : public Type // The type of an initializer-list expression (before it has // been coerced to some other type) -class InitializerListType : public Type +class InitializerListType : public Type { SLANG_AST_CLASS(InitializerListType) @@ -30,7 +31,7 @@ class InitializerListType : public Type }; // The type of an expression that was erroneous -class ErrorType : public Type +class ErrorType : public Type { SLANG_AST_CLASS(ErrorType) @@ -65,10 +66,7 @@ class DeclRefType : public Type Type* _createCanonicalTypeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); - DeclRefType(DeclRefBase* declRefBase) - { - setOperands(declRefBase); - } + DeclRefType(DeclRefBase* declRefBase) { setOperands(declRefBase); } }; template @@ -85,7 +83,7 @@ bool isTypePack(Type* type); bool isAbstractTypePack(Type* type); // Base class for types that can be used in arithmetic expressions -class ArithmeticExpressionType : public DeclRefType +class ArithmeticExpressionType : public DeclRefType { SLANG_ABSTRACT_AST_CLASS(ArithmeticExpressionType) @@ -95,7 +93,7 @@ class ArithmeticExpressionType : public DeclRefType BasicExpressionType* _getScalarTypeOverride(); }; -class BasicExpressionType : public ArithmeticExpressionType +class BasicExpressionType : public ArithmeticExpressionType { SLANG_AST_CLASS(BasicExpressionType) @@ -104,16 +102,13 @@ class BasicExpressionType : public ArithmeticExpressionType // Overrides should be public so base classes can access BasicExpressionType* _getScalarTypeOverride(); - BasicExpressionType(DeclRefBase* inDeclRef) - { - setOperands(inDeclRef); - } + BasicExpressionType(DeclRefBase* inDeclRef) { setOperands(inDeclRef); } }; // Base type for things that are built in to the compiler, // and will usually have special behavior or a custom // mapping to the IR level. -class BuiltinType : public DeclRefType +class BuiltinType : public DeclRefType { SLANG_ABSTRACT_AST_CLASS(BuiltinType) }; @@ -124,8 +119,8 @@ class FeedbackType : public BuiltinType enum class Kind : uint8_t { - MinMip, /// SAMPLER_FEEDBACK_MIN_MIP - MipRegionUsed, /// SAMPLER_FEEDBACK_MIP_REGION_USED + MinMip, /// SAMPLER_FEEDBACK_MIN_MIP + MipRegionUsed, /// SAMPLER_FEEDBACK_MIP_REGION_USED }; Kind getKind() const; @@ -158,7 +153,7 @@ class TextureShapeBufferType : public TextureShapeType }; // Resources that contain "elements" that can be fetched -class ResourceType : public BuiltinType +class ResourceType : public BuiltinType { SLANG_ABSTRACT_AST_CLASS(ResourceType) @@ -174,7 +169,7 @@ class ResourceType : public BuiltinType void _toTextOverride(StringBuilder& out); }; -class TextureTypeBase : public ResourceType +class TextureTypeBase : public ResourceType { SLANG_ABSTRACT_AST_CLASS(TextureTypeBase) @@ -182,13 +177,13 @@ class TextureTypeBase : public ResourceType Val* getFormat(); }; -class TextureType : public TextureTypeBase +class TextureType : public TextureTypeBase { SLANG_AST_CLASS(TextureType) }; // This is a base type for `image*` types, as they exist in GLSL -class GLSLImageType : public TextureTypeBase +class GLSLImageType : public TextureTypeBase { SLANG_AST_CLASS(GLSLImageType) }; @@ -201,7 +196,7 @@ class SubpassInputType : public BuiltinType Type* getElementType(); }; -class SamplerStateType : public BuiltinType +class SamplerStateType : public BuiltinType { SLANG_AST_CLASS(SamplerStateType) @@ -210,7 +205,7 @@ class SamplerStateType : public BuiltinType }; // Other cases of generic types known to the compiler -class BuiltinGenericType : public BuiltinType +class BuiltinGenericType : public BuiltinType { SLANG_AST_CLASS(BuiltinGenericType) @@ -220,7 +215,7 @@ class BuiltinGenericType : public BuiltinType // Types that behave like pointers, in that they can be // dereferenced (implicitly) to access members defined // in the element type. -class PointerLikeType : public BuiltinGenericType +class PointerLikeType : public BuiltinGenericType { SLANG_AST_CLASS(PointerLikeType) }; @@ -232,59 +227,59 @@ class DynamicResourceType : public BuiltinType // HLSL buffer-type resources -class HLSLStructuredBufferTypeBase : public BuiltinGenericType +class HLSLStructuredBufferTypeBase : public BuiltinGenericType { SLANG_AST_CLASS(HLSLStructuredBufferTypeBase) }; -class HLSLStructuredBufferType : public HLSLStructuredBufferTypeBase +class HLSLStructuredBufferType : public HLSLStructuredBufferTypeBase { SLANG_AST_CLASS(HLSLStructuredBufferType) }; -class HLSLRWStructuredBufferType : public HLSLStructuredBufferTypeBase +class HLSLRWStructuredBufferType : public HLSLStructuredBufferTypeBase { SLANG_AST_CLASS(HLSLRWStructuredBufferType) }; -class HLSLRasterizerOrderedStructuredBufferType : public HLSLStructuredBufferTypeBase +class HLSLRasterizerOrderedStructuredBufferType : public HLSLStructuredBufferTypeBase { SLANG_AST_CLASS(HLSLRasterizerOrderedStructuredBufferType) }; -class UntypedBufferResourceType : public BuiltinType +class UntypedBufferResourceType : public BuiltinType { SLANG_AST_CLASS(UntypedBufferResourceType) }; -class HLSLByteAddressBufferType : public UntypedBufferResourceType +class HLSLByteAddressBufferType : public UntypedBufferResourceType { SLANG_AST_CLASS(HLSLByteAddressBufferType) }; -class HLSLRWByteAddressBufferType : public UntypedBufferResourceType +class HLSLRWByteAddressBufferType : public UntypedBufferResourceType { SLANG_AST_CLASS(HLSLRWByteAddressBufferType) }; -class HLSLRasterizerOrderedByteAddressBufferType : public UntypedBufferResourceType +class HLSLRasterizerOrderedByteAddressBufferType : public UntypedBufferResourceType { SLANG_AST_CLASS(HLSLRasterizerOrderedByteAddressBufferType) }; -class RaytracingAccelerationStructureType : public UntypedBufferResourceType +class RaytracingAccelerationStructureType : public UntypedBufferResourceType { SLANG_AST_CLASS(RaytracingAccelerationStructureType) }; -class HLSLAppendStructuredBufferType : public HLSLStructuredBufferTypeBase +class HLSLAppendStructuredBufferType : public HLSLStructuredBufferTypeBase { SLANG_AST_CLASS(HLSLAppendStructuredBufferType) }; -class HLSLConsumeStructuredBufferType : public HLSLStructuredBufferTypeBase +class HLSLConsumeStructuredBufferType : public HLSLStructuredBufferTypeBase { SLANG_AST_CLASS(HLSLConsumeStructuredBufferType) }; @@ -294,7 +289,7 @@ class GLSLAtomicUintType : public BuiltinType SLANG_AST_CLASS(GLSLAtomicUintType) }; -class HLSLPatchType : public BuiltinType +class HLSLPatchType : public BuiltinType { SLANG_AST_CLASS(HLSLPatchType) @@ -302,12 +297,12 @@ class HLSLPatchType : public BuiltinType IntVal* getElementCount(); }; -class HLSLInputPatchType : public HLSLPatchType +class HLSLInputPatchType : public HLSLPatchType { SLANG_AST_CLASS(HLSLInputPatchType) }; -class HLSLOutputPatchType : public HLSLPatchType +class HLSLOutputPatchType : public HLSLPatchType { SLANG_AST_CLASS(HLSLOutputPatchType) }; @@ -315,22 +310,22 @@ class HLSLOutputPatchType : public HLSLPatchType // HLSL geometry shader output stream types -class HLSLStreamOutputType : public BuiltinGenericType +class HLSLStreamOutputType : public BuiltinGenericType { SLANG_AST_CLASS(HLSLStreamOutputType) }; -class HLSLPointStreamType : public HLSLStreamOutputType +class HLSLPointStreamType : public HLSLStreamOutputType { SLANG_AST_CLASS(HLSLPointStreamType) }; -class HLSLLineStreamType : public HLSLStreamOutputType +class HLSLLineStreamType : public HLSLStreamOutputType { SLANG_AST_CLASS(HLSLLineStreamType) }; -class HLSLTriangleStreamType : public HLSLStreamOutputType +class HLSLTriangleStreamType : public HLSLStreamOutputType { SLANG_AST_CLASS(HLSLTriangleStreamType) }; @@ -363,7 +358,7 @@ class PrimitivesType : public MeshOutputType // -class GLSLInputAttachmentType : public BuiltinType +class GLSLInputAttachmentType : public BuiltinType { SLANG_AST_CLASS(GLSLInputAttachmentType) }; @@ -371,17 +366,17 @@ class GLSLInputAttachmentType : public BuiltinType // Base class for types used when desugaring parameter block // declarations, includeing HLSL `cbuffer` or GLSL `uniform` blocks. -class ParameterGroupType : public PointerLikeType +class ParameterGroupType : public PointerLikeType { SLANG_AST_CLASS(ParameterGroupType) }; -class UniformParameterGroupType : public ParameterGroupType +class UniformParameterGroupType : public ParameterGroupType { SLANG_AST_CLASS(UniformParameterGroupType) }; -class VaryingParameterGroupType : public ParameterGroupType +class VaryingParameterGroupType : public ParameterGroupType { SLANG_AST_CLASS(VaryingParameterGroupType) }; @@ -389,26 +384,26 @@ class VaryingParameterGroupType : public ParameterGroupType // type for HLSL `cbuffer` declarations, and `ConstantBuffer` // ALso used for GLSL `uniform` blocks. -class ConstantBufferType : public UniformParameterGroupType +class ConstantBufferType : public UniformParameterGroupType { SLANG_AST_CLASS(ConstantBufferType) }; // type for HLSL `tbuffer` declarations, and `TextureBuffer` -class TextureBufferType : public UniformParameterGroupType +class TextureBufferType : public UniformParameterGroupType { SLANG_AST_CLASS(TextureBufferType) }; // type for GLSL `in` and `out` blocks -class GLSLInputParameterGroupType : public VaryingParameterGroupType +class GLSLInputParameterGroupType : public VaryingParameterGroupType { SLANG_AST_CLASS(GLSLInputParameterGroupType) }; -class GLSLOutputParameterGroupType : public VaryingParameterGroupType +class GLSLOutputParameterGroupType : public VaryingParameterGroupType { SLANG_AST_CLASS(GLSLOutputParameterGroupType) }; @@ -422,12 +417,12 @@ class GLSLShaderStorageBufferType : public PointerLikeType // type for Slang `ParameterBlock` type -class ParameterBlockType : public UniformParameterGroupType +class ParameterBlockType : public UniformParameterGroupType { SLANG_AST_CLASS(ParameterBlockType) }; -class ArrayExpressionType : public DeclRefType +class ArrayExpressionType : public DeclRefType { SLANG_AST_CLASS(ArrayExpressionType) @@ -447,7 +442,7 @@ class AtomicType : public DeclRefType // The "type" of an expression that resolves to a type. // For example, in the expression `float(2)` the sub-expression, // `float` would have the type `TypeType(float)`. -class TypeType : public Type +class TypeType : public Type { SLANG_AST_CLASS(TypeType) @@ -457,20 +452,17 @@ class TypeType : public Type Type* getType() { return as(getOperand(0)); } - TypeType(Type* type) - { - setOperands(type); - } + TypeType(Type* type) { setOperands(type); } }; // A differential pair type, e.g., `__DifferentialPair` -class DifferentialPairType : public ArithmeticExpressionType +class DifferentialPairType : public ArithmeticExpressionType { SLANG_AST_CLASS(DifferentialPairType) Type* getPrimalType(); }; -class DifferentialPtrPairType : public ArithmeticExpressionType +class DifferentialPtrPairType : public ArithmeticExpressionType { SLANG_AST_CLASS(DifferentialPtrPairType) Type* getPrimalRefType(); @@ -492,7 +484,7 @@ class DefaultInitializableType : public BuiltinType }; // A vector type, e.g., `vector` -class VectorExpressionType : public ArithmeticExpressionType +class VectorExpressionType : public ArithmeticExpressionType { SLANG_AST_CLASS(VectorExpressionType) @@ -505,14 +497,14 @@ class VectorExpressionType : public ArithmeticExpressionType }; // A matrix type, e.g., `matrix` -class MatrixExpressionType : public ArithmeticExpressionType +class MatrixExpressionType : public ArithmeticExpressionType { SLANG_AST_CLASS(MatrixExpressionType) - Type* getElementType(); - IntVal* getRowCount(); - IntVal* getColumnCount(); - IntVal* getLayout(); + Type* getElementType(); + IntVal* getRowCount(); + IntVal* getColumnCount(); + IntVal* getLayout(); Type* getRowType(); @@ -556,7 +548,7 @@ class DynamicType : public BuiltinType }; // Type built-in `__EnumType` type -class EnumTypeType : public BuiltinType +class EnumTypeType : public BuiltinType { SLANG_AST_CLASS(EnumTypeType) @@ -565,7 +557,7 @@ class EnumTypeType : public BuiltinType // Base class for types that map down to // simple pointers as part of code generation. -class PtrTypeBase : public BuiltinType +class PtrTypeBase : public BuiltinType { SLANG_AST_CLASS(PtrTypeBase) @@ -586,7 +578,7 @@ class NullPtrType : public BuiltinType }; // A true (user-visible) pointer type, e.g., `T*` -class PtrType : public PtrTypeBase +class PtrType : public PtrTypeBase { SLANG_AST_CLASS(PtrType) @@ -615,13 +607,13 @@ class OutTypeBase : public ParamDirectionType }; // The type for an `out` parameter, e.g., `out T` -class OutType : public OutTypeBase +class OutType : public OutTypeBase { SLANG_AST_CLASS(OutType) }; // The type for an `in out` parameter, e.g., `in out T` -class InOutType : public OutTypeBase +class InOutType : public OutTypeBase { SLANG_AST_CLASS(InOutType) }; @@ -658,7 +650,7 @@ class NativeRefType : public BuiltinType }; // A type alias of some kind (e.g., via `typedef`) -class NamedExpressionType : public Type +class NamedExpressionType : public Type { SLANG_AST_CLASS(NamedExpressionType) @@ -668,15 +660,12 @@ class NamedExpressionType : public Type void _toTextOverride(StringBuilder& out); Type* _createCanonicalTypeOverride(); - NamedExpressionType(DeclRef inDeclRef) - { - setOperands(inDeclRef); - } + NamedExpressionType(DeclRef inDeclRef) { setOperands(inDeclRef); } }; // A function type is defined by its parameter types // and its result type. -class FuncType : public Type +class FuncType : public Type { SLANG_AST_CLASS(FuncType) @@ -717,7 +706,6 @@ class TupleType : public DeclRefType Index getMemberCount() const; Type* getMember(Index i) const; Type* getTypePack() const; - }; class EachType : public Type @@ -726,10 +714,7 @@ class EachType : public Type Type* getElementType() const { return as(getOperand(0)); } DeclRefType* getElementDeclRefType() const { return as(getOperand(0)); } - EachType(Type* elementType) - { - m_operands.add(ValNodeOperand(elementType)); - } + EachType(Type* elementType) { m_operands.add(ValNodeOperand(elementType)); } void _toTextOverride(StringBuilder& out); Type* _createCanonicalTypeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); @@ -769,7 +754,7 @@ class ConcreteTypePack : public Type }; // The "type" of an expression that names a generic declaration. -class GenericDeclRefType : public Type +class GenericDeclRefType : public Type { SLANG_AST_CLASS(GenericDeclRefType) @@ -779,23 +764,17 @@ class GenericDeclRefType : public Type void _toTextOverride(StringBuilder& out); Type* _createCanonicalTypeOverride(); - GenericDeclRefType(DeclRef declRef) - { - setOperands(declRef); - } + GenericDeclRefType(DeclRef declRef) { setOperands(declRef); } }; // The "type" of a reference to a module or namespace -class NamespaceType : public Type +class NamespaceType : public Type { SLANG_AST_CLASS(NamespaceType) DeclRef getDeclRef() const { return as(getOperand(0)); } - NamespaceType(DeclRef inDeclRef) - { - setOperands(inDeclRef); - } + NamespaceType(DeclRef inDeclRef) { setOperands(inDeclRef); } // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); @@ -804,7 +783,7 @@ class NamespaceType : public Type // The concrete type for a value wrapped in an existential, accessible // when the existential is "opened" in some context. -class ExtractExistentialType : public Type +class ExtractExistentialType : public Type { SLANG_AST_CLASS(ExtractExistentialType) @@ -824,8 +803,8 @@ class ExtractExistentialType : public Type setOperands(inDeclRef, inOriginalInterfaceType, inOriginalInterfaceDeclRef); } -// Following fields will not be reflected (and thus won't be serialized, etc.) -SLANG_UNREFLECTED + // Following fields will not be reflected (and thus won't be serialized, etc.) + SLANG_UNREFLECTED // A cached decl-ref to the original interface's ThisType Decl, with // a witness that refers to the type extracted here. @@ -846,21 +825,21 @@ SLANG_UNREFLECTED Type* _createCanonicalTypeOverride(); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); - /// Get a witness that shows how this type is a subtype of `originalInterfaceType`. - /// - /// This operation may create the witness on demand and cache it. - /// + /// Get a witness that shows how this type is a subtype of `originalInterfaceType`. + /// + /// This operation may create the witness on demand and cache it. + /// SubtypeWitness* getSubtypeWitness(); - /// Get a decl-ref to the interface's ThisType decl, which represents a substitutable type - /// from which lookup can be performed. - /// - /// This operation may create the decl-ref on demand and cache it. - /// + /// Get a decl-ref to the interface's ThisType decl, which represents a substitutable type + /// from which lookup can be performed. + /// + /// This operation may create the decl-ref on demand and cache it. + /// DeclRef getThisTypeDeclRef(); }; -class ExistentialSpecializedType : public Type +class ExistentialSpecializedType : public Type { SLANG_AST_CLASS(ExistentialSpecializedType) @@ -874,9 +853,7 @@ class ExistentialSpecializedType : public Type } Index getArgCount() { return (getOperandCount() - 1) / 2; } - ExistentialSpecializedType( - Type* inBaseType, - ExpandedSpecializationArgs const& inArgs) + ExistentialSpecializedType(Type* inBaseType, ExpandedSpecializationArgs const& inArgs) { m_operands.add(ValNodeOperand(inBaseType)); for (auto arg : inArgs) @@ -892,30 +869,30 @@ class ExistentialSpecializedType : public Type Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; - /// The type of `this` within a polymorphic declaration +/// The type of `this` within a polymorphic declaration class ThisType : public DeclRefType { SLANG_AST_CLASS(ThisType) - ThisType(DeclRefBase* declRef) : DeclRefType(declRef) {} + ThisType(DeclRefBase* declRef) + : DeclRefType(declRef) + { + } DeclRef getInterfaceDeclRef(); }; - /// The type of `A & B` where `A` and `B` are types - /// - /// A value `v` is of type `A & B` if it is both of type `A` and of type `B`. +/// The type of `A & B` where `A` and `B` are types +/// +/// A value `v` is of type `A & B` if it is both of type `A` and of type `B`. class AndType : public Type { SLANG_AST_CLASS(AndType) Type* getLeft() { return as(getOperand(0)); } Type* getRight() { return as(getOperand(1)); } - - AndType(Type* leftType, Type* rightType) - { - setOperands(leftType, rightType); - } + + AndType(Type* leftType, Type* rightType) { setOperands(leftType, rightType); } // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); @@ -927,10 +904,7 @@ class ModifiedType : public Type { SLANG_AST_CLASS(ModifiedType) - Type* getBase() - { - return as(getOperand(0)); - } + Type* getBase() { return as(getOperand(0)); } Index getModifierCount() { return getOperandCount() - 1; } Val* getModifier(Index index) { return getOperand(index + 1); } diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index b8c5e6ee1..2a2f275ee 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -1,16 +1,18 @@ // slang-ast-type.cpp -#include "slang-ast-builder.h" -#include -#include +#include "slang-ast-val.h" -#include "slang-generated-ast-macro.h" +#include "slang-ast-builder.h" +#include "slang-check-impl.h" #include "slang-diagnostics.h" -#include "slang-syntax.h" -#include "slang-ast-val.h" +#include "slang-generated-ast-macro.h" #include "slang-mangle.h" -#include "slang-check-impl.h" +#include "slang-syntax.h" + +#include +#include -namespace Slang { +namespace Slang +{ void ValNodeDesc::init() { @@ -30,7 +32,8 @@ void ValNodeDesc::init() Val* Val::substitute(ASTBuilder* astBuilder, SubstitutionSet subst) { - if (!subst) return this; + if (!subst) + return this; int diff = 0; return substituteImpl(astBuilder, subst, &diff); } @@ -40,12 +43,9 @@ Val* Val::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioD SLANG_AST_NODE_VIRTUAL_CALL(Val, substituteImpl, (astBuilder, subst, ioDiff)) } -void Val::toText(StringBuilder& out) -{ - SLANG_AST_NODE_VIRTUAL_CALL(Val, toText, (out)) -} +void Val::toText(StringBuilder& out){SLANG_AST_NODE_VIRTUAL_CALL(Val, toText, (out))} -Val* Val::_resolveImplOverride() +Val* Val::_resolveImplOverride() { SLANG_UNEXPECTED("Val::_resolveImplOverride not overridden"); } @@ -60,7 +60,7 @@ Val* Val::resolve() auto astBuilder = getCurrentASTBuilder(); // If we are not in a proper checking context, just return the previously resolved val. if (!astBuilder) - return m_resolvedVal? m_resolvedVal : this; + return m_resolvedVal ? m_resolvedVal : this; if (m_resolvedVal && m_resolvedValEpoch == astBuilder->getEpoch()) { SLANG_ASSERT(as(m_resolvedVal)); @@ -72,7 +72,8 @@ Val* Val::resolve() #ifdef _DEBUG if (m_resolvedVal->_debugUID > 0 && this->_debugUID < 0) { - SLANG_ASSERT_FAILURE("should not be modifying the core module vals outside of the core module checking."); + SLANG_ASSERT_FAILURE( + "should not be modifying the core module vals outside of the core module checking."); } #endif return m_resolvedVal; @@ -86,7 +87,8 @@ void Val::_setUnique() Val* Val::defaultResolveImpl() { - // Default resolve implementation is to recursively resolve all operands, and lookup in deduplication cache. + // Default resolve implementation is to recursively resolve all operands, and lookup in + // deduplication cache. ValNodeDesc newDesc; newDesc.type = astNodeType; bool diff = false; @@ -107,7 +109,7 @@ Val* Val::defaultResolveImpl() } newDesc.operands.add(operand); } - + if (!diff) return this; @@ -220,10 +222,12 @@ Val* maybeSubstituteGenericParam(Val* paramVal, Decl* paramDecl, SubstitutionSet // Nothing found: don't substitute. return paramVal; - } -Val* GenericParamIntVal::_substituteImplOverride(ASTBuilder* /* astBuilder */, SubstitutionSet subst, int* ioDiff) +Val* GenericParamIntVal::_substituteImplOverride( + ASTBuilder* /* astBuilder */, + SubstitutionSet subst, + int* ioDiff) { if (auto result = maybeSubstituteGenericParam(this, getDeclRef().getDecl(), subst, ioDiff)) return result; @@ -252,7 +256,10 @@ void ErrorIntVal::_toTextOverride(StringBuilder& out) out << toSlice(""); } -Val* ErrorIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* ErrorIntVal::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { SLANG_UNUSED(astBuilder); SLANG_UNUSED(subst); @@ -260,7 +267,10 @@ Val* ErrorIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSe return this; } -Val* TypeEqualityWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) +Val* TypeEqualityWitness::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { auto type = as(getSub()->substituteImpl(astBuilder, subst, ioDiff)); TypeEqualityWitness* rs = astBuilder->getOrCreate(type, type); @@ -274,7 +284,10 @@ void TypeEqualityWitness::_toTextOverride(StringBuilder& out) // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TypePackSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -Val* TypePackSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* TypePackSubtypeWitness::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; ShortList newWitnesses; @@ -289,7 +302,10 @@ Val* TypePackSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub if (!diff) return this; (*ioDiff)++; - return getCurrentASTBuilder()->getSubtypeWitnessPack(newSub, newSup, newWitnesses.getArrayView().arrayView); + return getCurrentASTBuilder()->getSubtypeWitnessPack( + newSub, + newSup, + newWitnesses.getArrayView().arrayView); } Val* TypePackSubtypeWitness::_resolveImplOverride() @@ -313,7 +329,10 @@ Val* TypePackSubtypeWitness::_resolveImplOverride() if (!diff) return this; - return getCurrentASTBuilder()->getSubtypeWitnessPack(newSub, newSup, newWitnesses.getArrayView().arrayView); + return getCurrentASTBuilder()->getSubtypeWitnessPack( + newSub, + newSup, + newWitnesses.getArrayView().arrayView); } void TypePackSubtypeWitness::_toTextOverride(StringBuilder& out) @@ -330,7 +349,10 @@ void TypePackSubtypeWitness::_toTextOverride(StringBuilder& out) // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExpandSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -Val* ExpandSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* ExpandSubtypeWitness::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; auto newSub = as(getSub()->substituteImpl(astBuilder, subst, &diff)); @@ -346,16 +368,24 @@ Val* ExpandSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Subst { auto elementType = subTypePack->getElementType(i); subst.packExpansionIndex = i; - auto elementWitness = as(getPatternTypeWitness()->substituteImpl(astBuilder, subst, &diff)); - auto newWitness = getCurrentASTBuilder()->getExpandSubtypeWitness(elementType, newSup, elementWitness); + auto elementWitness = as( + getPatternTypeWitness()->substituteImpl(astBuilder, subst, &diff)); + auto newWitness = getCurrentASTBuilder()->getExpandSubtypeWitness( + elementType, + newSup, + elementWitness); newWitnesses.add(as(newWitness)); } (*ioDiff)++; - return getCurrentASTBuilder()->getSubtypeWitnessPack(newSub, newSup, newWitnesses.getArrayView().arrayView); + return getCurrentASTBuilder()->getSubtypeWitnessPack( + newSub, + newSup, + newWitnesses.getArrayView().arrayView); } (*ioDiff)++; - auto newPatternWitness = as(getPatternTypeWitness()->substituteImpl(astBuilder, subst, &diff)); + auto newPatternWitness = + as(getPatternTypeWitness()->substituteImpl(astBuilder, subst, &diff)); return getCurrentASTBuilder()->getExpandSubtypeWitness(newSub, newSup, newPatternWitness); } @@ -385,10 +415,14 @@ void ExpandSubtypeWitness::_toTextOverride(StringBuilder& out) // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! EachSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -Val* EachSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* EachSubtypeWitness::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; - auto newPatternWitness = as(getPatternTypeWitness()->substituteImpl(astBuilder, subst, &diff)); + auto newPatternWitness = + as(getPatternTypeWitness()->substituteImpl(astBuilder, subst, &diff)); if (auto witnessPack = as(newPatternWitness)) { if (subst.packExpansionIndex >= 0 && subst.packExpansionIndex < witnessPack->getCount()) @@ -464,7 +498,10 @@ ConversionCost DeclaredSubtypeWitness::_getOverloadResolutionCostOverride() return kConversionCost_None; } -Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) +Val* DeclaredSubtypeWitness::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { if (auto genConstraintDeclRef = getDeclRef().as()) { @@ -493,7 +530,8 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub } if (found) { - auto ordinaryParamCount = genericDecl->getMembersOfType().getCount() + + auto ordinaryParamCount = + genericDecl->getMembersOfType().getCount() + genericDecl->getMembersOfType().getCount(); if (index + ordinaryParamCount < args.getCount()) { @@ -502,8 +540,8 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub } else { - // When the `subst` represents a partial substitution, we may not have a corresponding argument. - // In this case we just return the original witness. + // When the `subst` represents a partial substitution, we may not have a + // corresponding argument. In this case we just return the original witness. // goto breakLabel; } @@ -512,7 +550,8 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub else if (auto thisTypeConstraintDeclRef = getDeclRef().as()) { auto lookupSubst = subst.findLookupDeclRef(); - if (lookupSubst && lookupSubst->getSupDecl() == thisTypeConstraintDeclRef.getDecl()->getInterfaceDecl()) + if (lookupSubst && + lookupSubst->getSupDecl() == thisTypeConstraintDeclRef.getDecl()->getInterfaceDecl()) { (*ioDiff)++; return lookupSubst->getWitness(); @@ -525,7 +564,7 @@ breakLabel:; int diff = 0; auto substSub = as(getSub()->substituteImpl(astBuilder, subst, &diff)); auto substSup = as(getSup()->substituteImpl(astBuilder, subst, &diff)); - + if (!diff) return this; @@ -555,13 +594,13 @@ breakLabel:; // We need to look up the declaration that satisfies // the requirement named by the associated type. Decl* requirementKey = substTypeConstraintDecl; - RequirementWitness requirementWitness = tryLookUpRequirementWitness(astBuilder, thisTypeWitness, requirementKey); + RequirementWitness requirementWitness = + tryLookUpRequirementWitness(astBuilder, thisTypeWitness, requirementKey); switch (requirementWitness.getFlavor()) { - default: - break; + default: break; - case RequirementWitness::Flavor::val: + case RequirementWitness::Flavor::val: { auto satisfyingVal = requirementWitness.getVal(); return satisfyingVal; @@ -573,24 +612,29 @@ breakLabel:; } auto substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff); - auto rs = astBuilder->getDeclaredSubtypeWitness( - substSub, substSup, substDeclRef); + auto rs = astBuilder->getDeclaredSubtypeWitness(substSub, substSup, substDeclRef); return rs; } void DeclaredSubtypeWitness::_toTextOverride(StringBuilder& out) { - out << toSlice("DeclaredSubtypeWitness(") << getSub() << toSlice(", ") << getSup() << toSlice(", ") << getDeclRef() << toSlice(")"); + out << toSlice("DeclaredSubtypeWitness(") << getSub() << toSlice(", ") << getSup() + << toSlice(", ") << getDeclRef() << toSlice(")"); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TransitiveSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -Val* TransitiveSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) +Val* TransitiveSubtypeWitness::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; - SubtypeWitness* substSubToMid = as(getSubToMid()->substituteImpl(astBuilder, subst, &diff)); - SubtypeWitness* substMidToSup = as(getMidToSup()->substituteImpl(astBuilder, subst, &diff)); + SubtypeWitness* substSubToMid = + as(getSubToMid()->substituteImpl(astBuilder, subst, &diff)); + SubtypeWitness* substMidToSup = + as(getMidToSup()->substituteImpl(astBuilder, subst, &diff)); // If nothing changed, then we can bail out early. if (!diff) @@ -611,7 +655,8 @@ Val* TransitiveSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, S ConversionCost TransitiveSubtypeWitness::_getOverloadResolutionCostOverride() { - return getSubToMid()->getOverloadResolutionCost() + getMidToSup()->getOverloadResolutionCost() + kConversionCost_GenericParamUpcast; + return getSubToMid()->getOverloadResolutionCost() + getMidToSup()->getOverloadResolutionCost() + + kConversionCost_GenericParamUpcast; } void TransitiveSubtypeWitness::_toTextOverride(StringBuilder& out) @@ -619,19 +664,25 @@ void TransitiveSubtypeWitness::_toTextOverride(StringBuilder& out) // Note: we only print the constituent // witnesses, and rely on them to print // the starting and ending types. - - out << toSlice("TransitiveSubtypeWitness(") << getSubToMid() << toSlice(", ") << getMidToSup() << toSlice(")"); + + out << toSlice("TransitiveSubtypeWitness(") << getSubToMid() << toSlice(", ") << getMidToSup() + << toSlice(")"); } -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractFromConjunctionSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractFromConjunctionSubtypeWitness +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) +Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; auto substSub = as(getSub()->substituteImpl(astBuilder, subst, &diff)); auto substSup = as(getSup()->substituteImpl(astBuilder, subst, &diff)); - auto substWitness = as(getConjunctionWitness()->substituteImpl(astBuilder, subst, &diff)); + auto substWitness = + as(getConjunctionWitness()->substituteImpl(astBuilder, subst, &diff)); // If nothing changed, then we can bail out early. if (!diff) @@ -651,7 +702,10 @@ Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* a // simplification logic as needed. // return astBuilder->getExtractFromConjunctionSubtypeWitness( - substSub, substSup, substWitness, getIndexInConjunction()); + substSub, + substSup, + substWitness, + getIndexInConjunction()); } ConversionCost ExtractFromConjunctionSubtypeWitness::_getOverloadResolutionCostOverride() @@ -665,14 +719,18 @@ ConversionCost ExtractFromConjunctionSubtypeWitness::_getOverloadResolutionCostO return kConversionCost_None; } -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractExistentialSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractExistentialSubtypeWitness +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void ExtractExistentialSubtypeWitness::_toTextOverride(StringBuilder& out) { out << toSlice("extractExistentialValue(") << getDeclRef() << toSlice(")"); } -Val* ExtractExistentialSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* ExtractExistentialSubtypeWitness::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; @@ -685,8 +743,8 @@ Val* ExtractExistentialSubtypeWitness::_substituteImplOverride(ASTBuilder* astBu (*ioDiff)++; - ExtractExistentialSubtypeWitness* substValue = astBuilder->getOrCreate( - substSub, substSup, substDeclRef); + ExtractExistentialSubtypeWitness* substValue = + astBuilder->getOrCreate(substSub, substSup, substDeclRef); return substValue; } @@ -695,15 +753,20 @@ void ConjunctionSubtypeWitness::_toTextOverride(StringBuilder& out) out << "ConjunctionSubtypeWitness("; for (Index i = 0; i < kComponentCount; ++i) { - if (i != 0) out << ","; + if (i != 0) + out << ","; auto w = getComponentWitness(i); - if (w) out << w; + if (w) + out << w; } out << ")"; } -Val* ConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* ConjunctionSubtypeWitness::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; Val* substComponentWitnesses[kComponentCount]; @@ -717,7 +780,7 @@ Val* ConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, substComponentWitnesses[i] = w ? w->substituteImpl(astBuilder, subst, &diff) : nullptr; } - if(!diff) + if (!diff) return this; *ioDiff += diff; @@ -764,7 +827,10 @@ void UNormModifierVal::_toTextOverride(StringBuilder& out) out.append("unorm"); } -Val* UNormModifierVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* UNormModifierVal::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { SLANG_UNUSED(astBuilder); SLANG_UNUSED(subst); @@ -779,7 +845,10 @@ void SNormModifierVal::_toTextOverride(StringBuilder& out) out.append("snorm"); } -Val* SNormModifierVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* SNormModifierVal::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { SLANG_UNUSED(astBuilder); SLANG_UNUSED(subst); @@ -793,7 +862,10 @@ void NoDiffModifierVal::_toTextOverride(StringBuilder& out) out.append("no_diff"); } -Val* NoDiffModifierVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* NoDiffModifierVal::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { SLANG_UNUSED(astBuilder); SLANG_UNUSED(subst); @@ -864,7 +936,8 @@ struct PolynomialIntValBuilder PolynomialIntValBuilder(ASTBuilder* inAstBuilder) : astBuilder(inAstBuilder) - {} + { + } // compute val += opreand*multiplier; bool addToPolynomialTerm(IntVal* operand, IntegerLiteralValue multiplier) @@ -880,7 +953,8 @@ struct PolynomialIntValBuilder for (auto term : poly->getTerms()) { auto newTerm = astBuilder->getOrCreate( - multiplier * term->getConstFactor(), term->getParamFactors()); + multiplier * term->getConstFactor(), + term->getParamFactors()); terms.add(newTerm); } return true; @@ -888,7 +962,9 @@ struct PolynomialIntValBuilder else if (auto genVal = as(operand)) { auto factor = astBuilder->getOrCreate(genVal, 1); - auto term = astBuilder->getOrCreate(multiplier, makeArrayViewSingle(factor)); + auto term = astBuilder->getOrCreate( + multiplier, + makeArrayViewSingle(factor)); terms.add(term); return true; } @@ -931,10 +1007,14 @@ struct PolynomialIntValBuilder if (!factorIsDifferent[j]) { factorIsDifferent[j] = true; - auto clonedFactor = astBuilder->getOrCreate(newFactor->getParam(), newFactor->getPower()); + auto clonedFactor = astBuilder->getOrCreate( + newFactor->getParam(), + newFactor->getPower()); newFactor = clonedFactor; } - newFactor = astBuilder->getOrCreate(newFactor->getParam(), newFactor->getPower() + factor->getPower()); + newFactor = astBuilder->getOrCreate( + newFactor->getParam(), + newFactor->getPower() + factor->getPower()); factorFound = true; break; } @@ -957,7 +1037,8 @@ struct PolynomialIntValBuilder newConstantTerm += term->getConstFactor(); continue; } - newFactors2.sort([](PolynomialIntValFactor* t1, PolynomialIntValFactor* t2) {return *t1 < *t2; }); + newFactors2.sort([](PolynomialIntValFactor* t1, PolynomialIntValFactor* t2) + { return *t1 < *t2; }); bool isDifferent = false; if (newFactors2.getCount() != term->getParamFactors().getCount()) isDifferent = true; @@ -976,7 +1057,9 @@ struct PolynomialIntValBuilder } else { - auto newTerm = astBuilder->getOrCreate(term->getConstFactor(), newFactors2.getArrayView()); + auto newTerm = astBuilder->getOrCreate( + term->getConstFactor(), + newFactors2.getArrayView()); addTerm(newTerm); } } @@ -987,10 +1070,12 @@ struct PolynomialIntValBuilder continue; newTerms2.add(term); } - newTerms2.sort([](PolynomialIntValTerm* t1, PolynomialIntValTerm* t2) {return *t1 < *t2; }); + newTerms2.sort([](PolynomialIntValTerm* t1, PolynomialIntValTerm* t2) + { return *t1 < *t2; }); terms = _Move(newTerms2); constantTerm = newConstantTerm; - if (terms.getCount() == 1 && constantTerm == 0 && terms[0]->getConstFactor() == 1 && terms[0]->getParamFactors().getCount() == 1 && + if (terms.getCount() == 1 && constantTerm == 0 && terms[0]->getConstFactor() == 1 && + terms[0]->getParamFactors().getCount() == 1 && terms[0]->getParamFactors()[0]->getPower() == 1) { return terms[0]->getParamFactors()[0]->getParam(); @@ -1008,7 +1093,10 @@ struct PolynomialIntValBuilder } }; -Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* PolynomialIntVal::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; PolynomialIntValBuilder builder(astBuilder); @@ -1021,14 +1109,15 @@ Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitut for (auto& factor : term->getParamFactors()) { auto substResult = factor->getParam()->substituteImpl(astBuilder, subst, &diff); - + if (auto constantVal = as(substResult)) { evaluatedTermConstFactor *= constantVal->getValue(); } else if (auto intResult = as(substResult)) { - auto newFactor = astBuilder->getOrCreate(intResult, factor->getPower()); + auto newFactor = + astBuilder->getOrCreate(intResult, factor->getPower()); evaluatedTermParamFactors.add(newFactor); } } @@ -1038,7 +1127,8 @@ Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitut } else { - if (evaluatedTermParamFactors.getCount() == 1 && evaluatedTermParamFactors[0]->getPower() == 1) + if (evaluatedTermParamFactors.getCount() == 1 && + evaluatedTermParamFactors[0]->getPower() == 1) { if (auto polyTerm = as(evaluatedTermParamFactors[0]->getParam())) { @@ -1047,7 +1137,8 @@ Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitut } } auto newTerm = astBuilder->getOrCreate( - evaluatedTermConstFactor, evaluatedTermParamFactors.getArrayView()); + evaluatedTermConstFactor, + evaluatedTermParamFactors.getArrayView()); builder.terms.add(newTerm); } } @@ -1101,7 +1192,8 @@ IntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) for (auto term : poly1->getTerms()) { auto newTerm = astBuilder->getOrCreate( - poly0->getConstantTerm() * term->getConstFactor(), term->getParamFactors()); + poly0->getConstantTerm() * term->getConstFactor(), + term->getParamFactors()); builder.terms.add(newTerm); } } @@ -1122,10 +1214,13 @@ IntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) for (auto term1 : poly1->getTerms()) { List newFactors; - for (auto f : term0->getParamFactors()) newFactors.add(f); - for (auto f : term1->getParamFactors()) newFactors.add(f); + for (auto f : term0->getParamFactors()) + newFactors.add(f); + for (auto f : term1->getParamFactors()) + newFactors.add(f); auto newTerm = astBuilder->getOrCreate( - term0->getConstFactor() * term1->getConstFactor(), newFactors.getArrayView()); + term0->getConstFactor() * term1->getConstFactor(), + newFactors.getArrayView()); builder.terms.add(newTerm); } } @@ -1137,7 +1232,9 @@ IntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) builder.constantTerm = poly0->getConstantTerm() * cVal1->getValue(); for (auto term : poly0->getTerms()) { - auto newTerm = astBuilder->getOrCreate(term->getConstFactor() * cVal1->getValue(), term->getParamFactors()); + auto newTerm = astBuilder->getOrCreate( + term->getConstFactor() * cVal1->getValue(), + term->getParamFactors()); builder.terms.add(newTerm); } return builder.getIntVal(poly0->getType()); @@ -1148,17 +1245,20 @@ IntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) auto factor1 = astBuilder->getOrCreate(val1, 1); if (poly0->getConstantTerm() != 0) { - auto term0 = astBuilder->getOrCreate(poly0->getConstantTerm(), makeArrayViewSingle(factor1)); + auto term0 = astBuilder->getOrCreate( + poly0->getConstantTerm(), + makeArrayViewSingle(factor1)); builder.terms.add(term0); } for (auto term : poly0->getTerms()) { List newFactors; - for (auto f: term->getParamFactors()) + for (auto f : term->getParamFactors()) newFactors.add(f); newFactors.add(factor1); auto newTerm = astBuilder->getOrCreate( - term->getConstFactor(), newFactors.getArrayView()); + term->getConstFactor(), + newFactors.getArrayView()); builder.terms.add(newTerm); } return builder.getIntVal(poly0->getType()); @@ -1181,7 +1281,8 @@ IntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) PolynomialIntValBuilder builder(astBuilder); auto factor0 = astBuilder->getOrCreate(val0, 1); auto term = astBuilder->getOrCreate( - cVal1->getValue(), makeArrayView(&factor0, 1)); + cVal1->getValue(), + makeArrayView(&factor0, 1)); builder.terms.add(term); return builder.getIntVal(val0->getType()); } @@ -1190,7 +1291,7 @@ IntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) PolynomialIntValBuilder builder(astBuilder); auto factor0 = astBuilder->getOrCreate(val0, 1); auto factor1 = astBuilder->getOrCreate(val1, 1); - PolynomialIntValFactor* newFactors[] = { factor0, factor1 }; + PolynomialIntValFactor* newFactors[] = {factor0, factor1}; auto term = astBuilder->getOrCreate(1, makeArrayView(newFactors)); builder.terms.add(term); return builder.getIntVal(val0->getType()); @@ -1209,43 +1310,30 @@ void TypeCastIntVal::_toTextOverride(StringBuilder& out) out << ")"; } -Val* TypeCastIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, Val* base, DiagnosticSink* sink) +Val* TypeCastIntVal::tryFoldImpl( + ASTBuilder* astBuilder, + Type* resultType, + Val* base, + DiagnosticSink* sink) { SLANG_UNUSED(sink); auto convertValue = [&](BasicExpressionType* baseType, IntegerLiteralValue& resultValue) -> bool + { + switch (baseType->getBaseType()) { - switch (baseType->getBaseType()) - { - case BaseType::Int: - resultValue = (int)resultValue; - return true; - case BaseType::UInt: - resultValue = (unsigned int)resultValue; - return true; - case BaseType::Int64: - case BaseType::IntPtr: - resultValue = (Int64)resultValue; - return true; - case BaseType::UInt64: - case BaseType::UIntPtr: - resultValue = (UInt64)resultValue; - return true; - case BaseType::Int16: - resultValue = (int16_t)resultValue; - return true; - case BaseType::UInt16: - resultValue = (uint16_t)resultValue; - return true; - case BaseType::Int8: - resultValue = (int8_t)resultValue; - return true; - case BaseType::UInt8: - resultValue = (uint8_t)resultValue; - return true; - default: - return false; - } - }; + case BaseType::Int: resultValue = (int)resultValue; return true; + case BaseType::UInt: resultValue = (unsigned int)resultValue; return true; + case BaseType::Int64: + case BaseType::IntPtr: resultValue = (Int64)resultValue; return true; + case BaseType::UInt64: + case BaseType::UIntPtr: resultValue = (UInt64)resultValue; return true; + case BaseType::Int16: resultValue = (int16_t)resultValue; return true; + case BaseType::UInt16: resultValue = (uint16_t)resultValue; return true; + case BaseType::Int8: resultValue = (int8_t)resultValue; return true; + case BaseType::UInt8: resultValue = (uint8_t)resultValue; return true; + default: return false; + } + }; if (auto c = as(base)) { IntegerLiteralValue resultValue = c->getValue(); @@ -1275,7 +1363,10 @@ Val* TypeCastIntVal::_linkTimeResolveOverride(Dictionary& map) return tryFoldImpl(getCurrentASTBuilder(), getType(), resolvedBase, nullptr); } -Val* TypeCastIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* TypeCastIntVal::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; auto substBase = getBase()->substituteImpl(astBuilder, subst, &diff); @@ -1332,7 +1423,8 @@ void FuncCallIntVal::_toTextOverride(StringBuilder& out) { argToText(0); out << (name ? name->text : ""); - argToText(1);; + argToText(1); + ; } else if (args.getCount() == 1) { @@ -1356,7 +1448,8 @@ void FuncCallIntVal::_toTextOverride(StringBuilder& out) out << "("; for (Index i = 0; i < args.getCount(); i++) { - if (i > 0) out << ", "; + if (i > 0) + out << ", "; args[i]->toText(out); } out << ")"; @@ -1371,7 +1464,7 @@ Val* FuncCallIntVal::_resolveImplOverride() auto funcType = getFuncType(); Val* resolvedVal = this; - + auto newFuncDeclRef = as(funcDeclRef.declRefBase->resolve()); if (!newFuncDeclRef) return this; @@ -1391,12 +1484,21 @@ Val* FuncCallIntVal::_resolveImplOverride() resolvedVal = resolved; else if (diff) { - resolvedVal = astBuilder->getOrCreate(getType(), newFuncDeclRef, funcType, newArgs.getArrayView()); + resolvedVal = astBuilder->getOrCreate( + getType(), + newFuncDeclRef, + funcType, + newArgs.getArrayView()); } return resolvedVal; } -Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclRef newFuncDecl, List& newArgs, DiagnosticSink* sink) +Val* FuncCallIntVal::tryFoldImpl( + ASTBuilder* astBuilder, + Type* resultType, + DeclRef newFuncDecl, + List& newArgs, + DiagnosticSink* sink) { // Are all args const now? List constArgs; @@ -1422,46 +1524,48 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR const auto opNameSlice = opName->text.getUnownedSlice(); IntegerLiteralValue resultValue = 0; - - // Define convenience macros. + + // Define convenience macros. // The last macro used in the list *must* be // TERMINATING_CASE, as this handles the closing else, and matches if nothing else does. -#define BINARY_OPERATOR_CASE(op) \ - if (opNameSlice == toSlice(#op)) \ - { \ - resultValue = constArgs[0]->getValue() op constArgs[1]->getValue(); \ - } else - -#define DIV_OPERATOR_CASE(op) \ - if (opNameSlice == toSlice(#op)) \ - { \ - if (constArgs[1]->getValue() == 0) \ - { \ - if (sink) \ - sink->diagnose(newFuncDecl.getLoc(), Diagnostics::divideByZero); \ - return nullptr; \ - } \ - resultValue = constArgs[0]->getValue() op constArgs[1]->getValue(); \ - } else - -#define LOGICAL_OPERATOR_CASE(op) \ - if (opNameSlice == toSlice(#op)) \ - { \ - resultValue = (((constArgs[0]->getValue()!=0) op (constArgs[1]->getValue()!=0)) ? 1 : 0); \ - } else +#define BINARY_OPERATOR_CASE(op) \ + if (opNameSlice == toSlice(#op)) \ + { \ + resultValue = constArgs[0]->getValue() op constArgs[1]->getValue(); \ + } \ + else + +#define DIV_OPERATOR_CASE(op) \ + if (opNameSlice == toSlice(#op)) \ + { \ + if (constArgs[1]->getValue() == 0) \ + { \ + if (sink) \ + sink->diagnose(newFuncDecl.getLoc(), Diagnostics::divideByZero); \ + return nullptr; \ + } \ + resultValue = constArgs[0]->getValue() op constArgs[1]->getValue(); \ + } \ + else + +#define LOGICAL_OPERATOR_CASE(op) \ + if (opNameSlice == toSlice(#op)) \ + { \ + resultValue = \ + (((constArgs[0]->getValue() != 0) op(constArgs[1]->getValue() != 0)) ? 1 : 0); \ + } \ + else #define SPECIAL_OPERATOR_CASE(op, IF_MATCH) \ - if (opNameSlice == toSlice(op)) \ - { \ - IF_MATCH \ - } else - -#define TERMINATING_CASE(MATCH) \ - { \ - MATCH \ - } + if (opNameSlice == toSlice(op)) \ + { \ + IF_MATCH \ + } \ + else + +#define TERMINATING_CASE(MATCH) {MATCH} // Handle the cases using the macros BINARY_OPERATOR_CASE(>=) @@ -1482,16 +1586,19 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR // Special cases need their "operator" names quoted. SPECIAL_OPERATOR_CASE("!", resultValue = ((constArgs[0]->getValue() != 0) ? 1 : 0);) SPECIAL_OPERATOR_CASE("~", resultValue = ~constArgs[0]->getValue();) - SPECIAL_OPERATOR_CASE("?:", resultValue = constArgs[0]->getValue() != 0 ? constArgs[1]->getValue() : constArgs[2]->getValue();) + SPECIAL_OPERATOR_CASE("?:", + resultValue = constArgs[0]->getValue() != 0 + ? constArgs[1]->getValue() + : constArgs[2]->getValue();) TERMINATING_CASE(SLANG_UNREACHABLE("constant folding of FuncCallIntVal");) return astBuilder->getIntVal(resultType, resultValue); // The macros for the cases are no longer needed so undef them all. #undef BINARY_OPERATOR_CASE -#undef DIV_OPERATOR_CASE +#undef DIV_OPERATOR_CASE #undef LOGICAL_OPERATOR_CASE -#undef SPECIAL_OPERATOR_CASE +#undef SPECIAL_OPERATOR_CASE #undef TERMINATING_CASE } return nullptr; @@ -1505,7 +1612,10 @@ Val* FuncCallIntVal::_linkTimeResolveOverride(Dictionary& map) return tryFoldImpl(getCurrentASTBuilder(), getType(), getFuncDeclRef(), newArgs, nullptr); } -Val* FuncCallIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* FuncCallIntVal::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; auto newFuncDeclRef = getFuncDeclRef().substituteImpl(astBuilder, subst, &diff); @@ -1526,7 +1636,11 @@ Val* FuncCallIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitutio return newVal; else { - auto result = astBuilder->getOrCreate(getType(), newFuncDeclRef, getFuncType(), newArgs.getArrayView()); + auto result = astBuilder->getOrCreate( + getType(), + newFuncDeclRef, + getFuncType(), + newArgs.getArrayView()); return result; } } @@ -1590,7 +1704,10 @@ Val* CountOfIntVal::tryFold(ASTBuilder* astBuilder, Type* intType, Type* newType return result; } -Val* CountOfIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* CountOfIntVal::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; auto newType = as(getTypeArg()->substituteImpl(astBuilder, subst, &diff)); @@ -1644,7 +1761,10 @@ Val* WitnessLookupIntVal::_resolveImplOverride() return this; } -Val* WitnessLookupIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* WitnessLookupIntVal::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; auto newWitness = getWitness()->substituteImpl(astBuilder, subst, &diff); @@ -1669,16 +1789,17 @@ Val* WitnessLookupIntVal::tryFoldOrNull(ASTBuilder* astBuilder, SubtypeWitness* auto witnessEntry = tryLookUpRequirementWitness(astBuilder, witness, key); switch (witnessEntry.getFlavor()) { - case RequirementWitness::Flavor::val: - return witnessEntry.getVal(); - break; - default: - break; + case RequirementWitness::Flavor::val: return witnessEntry.getVal(); break; + default: break; } return nullptr; } -Val* WitnessLookupIntVal::tryFold(ASTBuilder* astBuilder, SubtypeWitness* witness, Decl* key, Type* type) +Val* WitnessLookupIntVal::tryFold( + ASTBuilder* astBuilder, + SubtypeWitness* witness, + Decl* key, + Type* type) { if (auto result = tryFoldOrNull(astBuilder, witness, key)) return result; @@ -1693,7 +1814,10 @@ void DifferentiateVal::_toTextOverride(StringBuilder& out) out << ")"; } -Val* DifferentiateVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* DifferentiateVal::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) { int diff = 0; auto newFunc = getFunc().substituteImpl(astBuilder, subst, &diff); @@ -1742,7 +1866,9 @@ Val* PolynomialIntValTerm::_resolveImplOverride() } if (diff) - return astBuilder->getOrCreate(getConstFactor(), newFactors.getArrayView()); + return astBuilder->getOrCreate( + getConstFactor(), + newFactors.getArrayView()); return this; } diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index 98161596d..7b33a8111 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -5,7 +5,8 @@ #include "slang-ast-base.h" #include "slang-ast-decl.h" -namespace Slang { +namespace Slang +{ // Syntax class definitions for compile-time values. @@ -14,12 +15,12 @@ class DirectDeclRef : public DeclRefBase public: SLANG_AST_CLASS(DirectDeclRef) - DirectDeclRef(Decl* decl) - { - setOperands(decl); - } + DirectDeclRef(Decl* decl) { setOperands(decl); } - DeclRefBase* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + DeclRefBase* _substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff); void _toTextOverride(StringBuilder& out); Val* _resolveImplOverride(); DeclRefBase* _getBaseOverride(); @@ -37,12 +38,12 @@ public: DeclRefBase* getParentOperand() { return as(getOperand(1)); } - MemberDeclRef(Decl* decl, DeclRefBase* parent) - { - setOperands(decl, parent); - } + MemberDeclRef(Decl* decl, DeclRefBase* parent) { setOperands(decl, parent); } - DeclRefBase* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + DeclRefBase* _substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff); void _toTextOverride(StringBuilder& out); @@ -52,7 +53,8 @@ public: }; -// Represent a lookup of SuperType::`m_decl` from `lookupSourceType` type that we know conforms to SuperType. +// Represent a lookup of SuperType::`m_decl` from `lookupSourceType` type that we know conforms to +// SuperType. class LookupDeclRef : public DeclRefBase { public: @@ -61,16 +63,10 @@ public: // m_decl represents the decl in SuperType that we want to lookup. // The source type that we are looking up from. - Type* getLookupSource() - { - return as(getOperand(1)); - } + Type* getLookupSource() { return as(getOperand(1)); } // Witness that `lookupSourceType`:SuperType. - SubtypeWitness* getWitness() - { - return as(getOperand(2)); - } + SubtypeWitness* getWitness() { return as(getOperand(2)); } LookupDeclRef(Decl* declToLookup, Type* lookupSource, SubtypeWitness* witness) { @@ -79,7 +75,10 @@ public: Decl* getSupDecl(); - DeclRefBase* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + DeclRefBase* _substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff); void _toTextOverride(StringBuilder& out); @@ -125,7 +124,10 @@ public: OperandView getArgs() { return OperandView(this, 2, getArgCount()); } - DeclRefBase* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + DeclRefBase* _substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff); void _toTextOverride(StringBuilder& out); @@ -135,7 +137,7 @@ public: }; // A compile-time integer (may not have a specific concrete value) -class IntVal : public Val +class IntVal : public Val { SLANG_ABSTRACT_AST_CLASS(IntVal) @@ -150,7 +152,7 @@ class IntVal : public Val }; // Trivial case of a value that is just a constant integer -class ConstantIntVal : public IntVal +class ConstantIntVal : public IntVal { SLANG_AST_CLASS(ConstantIntVal) @@ -159,15 +161,12 @@ class ConstantIntVal : public IntVal // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); - ConstantIntVal(Type* inType, IntegerLiteralValue inValue) - { - setOperands(inType, inValue); - } + ConstantIntVal(Type* inType, IntegerLiteralValue inValue) { setOperands(inType, inValue); } bool _isLinkTimeValOverride() { return false; } }; // The logical "value" of a reference to a generic value parameter -class GenericParamIntVal : public IntVal +class GenericParamIntVal : public IntVal { SLANG_AST_CLASS(GenericParamIntVal) @@ -195,12 +194,13 @@ class TypeCastIntVal : public IntVal Val* _resolveImplOverride(); Val* getBase() { return getOperand(1); } - TypeCastIntVal(Type* inType, Val* inBase) - { - setOperands(inType, inBase); - } + TypeCastIntVal(Type* inType, Val* inBase) { setOperands(inType, inBase); } - static Val* tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, Val* base, DiagnosticSink* sink); + static Val* tryFoldImpl( + ASTBuilder* astBuilder, + Type* resultType, + Val* base, + DiagnosticSink* sink); bool _isLinkTimeValOverride() { @@ -210,7 +210,6 @@ class TypeCastIntVal : public IntVal } Val* _linkTimeResolveOverride(Dictionary& map); - }; // An compile time int val as result of some general computation. @@ -227,14 +226,23 @@ class FuncCallIntVal : public IntVal OperandView getArgs() { return OperandView(this, 3, getOperandCount() - 3); } Index getArgCount() { return getOperandCount() - 3; } - FuncCallIntVal(Type* inType, DeclRef inFuncDeclRef, Type* inFuncType, ArrayView inArgs) + FuncCallIntVal( + Type* inType, + DeclRef inFuncDeclRef, + Type* inFuncType, + ArrayView inArgs) { setOperands(inType, inFuncDeclRef, inFuncType); for (auto arg : inArgs) m_operands.add(ValNodeOperand(arg)); } - static Val* tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclRef newFuncDecl, List& newArgs, DiagnosticSink* sink); + static Val* tryFoldImpl( + ASTBuilder* astBuilder, + Type* resultType, + DeclRef newFuncDecl, + List& newArgs, + DiagnosticSink* sink); bool _isLinkTimeValOverride() { @@ -253,20 +261,14 @@ class CountOfIntVal : public IntVal { SLANG_AST_CLASS(CountOfIntVal) - CountOfIntVal(Type* inType, Type* typeArg) - { - setOperands(inType, typeArg); - } + CountOfIntVal(Type* inType, Type* typeArg) { setOperands(inType, typeArg); } Val* getTypeArg() { return getOperand(1); } void _toTextOverride(StringBuilder& out); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); Val* _resolveImplOverride(); - bool _isLinkTimeValOverride() - { - return false; - } + bool _isLinkTimeValOverride() { return false; } static Val* tryFoldOrNull(ASTBuilder* astBuilder, Type* intType, Type* newType); @@ -293,10 +295,7 @@ class WitnessLookupIntVal : public IntVal static Val* tryFold(ASTBuilder* astBuilder, SubtypeWitness* witness, Decl* key, Type* type); - bool _isLinkTimeValOverride() - { - return false; - } + bool _isLinkTimeValOverride() { return false; } }; // polynomial expression "2*a*b^3 + 1" will be represented as: @@ -325,7 +324,8 @@ public: if (thisGenParam->equals(thatGenParam)) return getPower() < other.getPower(); else - return thisGenParam->getDeclRef().getDecl() < thatGenParam->getDeclRef().getDecl(); + return thisGenParam->getDeclRef().getDecl() < + thatGenParam->getDeclRef().getDecl(); } else { @@ -338,9 +338,9 @@ public: { return false; } - return getParam() == other.getParam() ? getPower() < other.getPower() : getParam() < other.getParam(); + return getParam() == other.getParam() ? getPower() < other.getPower() + : getParam() < other.getParam(); } - } // for sorting only. bool operator==(const PolynomialIntValFactor& other) const @@ -360,24 +360,30 @@ public: { return getPower() == other.getPower() && getParam()->equals(other.getParam()); } - }; class PolynomialIntValTerm : public Val { SLANG_AST_CLASS(PolynomialIntValTerm) public: IntegerLiteralValue getConstFactor() const { return getIntConstOperand(0); } - OperandView getParamFactors() const { return OperandView(this, 1, getOperandCount() - 1); } + OperandView getParamFactors() const + { + return OperandView(this, 1, getOperandCount() - 1); + } Val* _resolveImplOverride(); - PolynomialIntValTerm(IntegerLiteralValue inConstFactor, ArrayView inParamFactors) + PolynomialIntValTerm( + IntegerLiteralValue inConstFactor, + ArrayView inParamFactors) { setOperands(inConstFactor); addOperands(inParamFactors); } - PolynomialIntValTerm(IntegerLiteralValue inConstFactor, OperandView inParamFactors) + PolynomialIntValTerm( + IntegerLiteralValue inConstFactor, + OperandView inParamFactors) { setOperands(inConstFactor); addOperands(inParamFactors); @@ -438,9 +444,11 @@ class PolynomialIntVal : public IntVal { SLANG_AST_CLASS(PolynomialIntVal) public: - IntegerLiteralValue getConstantTerm() { return getIntConstOperand(1); }; - OperandView getTerms() { return OperandView(this, 2, getOperandCount() - 2); }; + OperandView getTerms() + { + return OperandView(this, 2, getOperandCount() - 2); + }; bool isConstant() { return getOperandCount() == 1; } @@ -453,7 +461,10 @@ public: static IntVal* add(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1); static IntVal* sub(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1); static IntVal* mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1); - PolynomialIntVal(Type* inType, IntegerLiteralValue inConstantTerm, ArrayView inTerms) + PolynomialIntVal( + Type* inType, + IntegerLiteralValue inConstantTerm, + ArrayView inTerms) { setOperands(inType, inConstantTerm); addOperands(inTerms); @@ -470,8 +481,8 @@ public: } }; - /// An unknown integer value indicating an erroneous sub-expression -class ErrorIntVal : public IntVal +/// An unknown integer value indicating an erroneous sub-expression +class ErrorIntVal : public IntVal { SLANG_AST_CLASS(ErrorIntVal) @@ -484,11 +495,8 @@ class ErrorIntVal : public IntVal // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); - Val* _resolveImplOverride() { return this; } - bool _isLinkTimeValOverride() - { - return false; - } + Val* _resolveImplOverride() { return this; } + bool _isLinkTimeValOverride() { return false; } }; // A witness to the fact that some proposition is true, encoded @@ -523,8 +531,8 @@ class ErrorIntVal : public IntVal // navigate from the knowledge that `X : ILight` to // the concrete declarations that provide the implementation // of `ILight` for `X`. -// -class Witness : public Val +// +class Witness : public Val { SLANG_ABSTRACT_AST_CLASS(Witness) }; @@ -534,7 +542,7 @@ class Witness : public Val // relationships and type-conforms-to-interface relationships) // // TODO: we may need to tease those apart. -class SubtypeWitness : public Witness +class SubtypeWitness : public Witness { SLANG_ABSTRACT_AST_CLASS(SubtypeWitness) @@ -561,7 +569,7 @@ class TypePackSubtypeWitness : public SubtypeWitness { setOperands(sub); addOperands(sup); - for(auto w : witnesses) + for (auto w : witnesses) addOperands(ValNodeOperand(w)); } @@ -602,14 +610,11 @@ class ExpandSubtypeWitness : public SubtypeWitness Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; -class TypeEqualityWitness : public SubtypeWitness +class TypeEqualityWitness : public SubtypeWitness { SLANG_AST_CLASS(TypeEqualityWitness) - TypeEqualityWitness(Type* subType, Type* supType) - { - setOperands(subType, supType); - } + TypeEqualityWitness(Type* subType, Type* supType) { setOperands(subType, supType); } // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); @@ -618,14 +623,11 @@ class TypeEqualityWitness : public SubtypeWitness // A witness that one type is a subtype of another // because some in-scope declaration says so -class DeclaredSubtypeWitness : public SubtypeWitness +class DeclaredSubtypeWitness : public SubtypeWitness { SLANG_AST_CLASS(DeclaredSubtypeWitness) - DeclRef getDeclRef() - { - return as(getOperand(2)); - } + DeclRef getDeclRef() { return as(getOperand(2)); } bool isEquality() { @@ -648,27 +650,25 @@ class DeclaredSubtypeWitness : public SubtypeWitness }; // A witness that `sub : sup` because `sub : mid` and `mid : sup` -class TransitiveSubtypeWitness : public SubtypeWitness +class TransitiveSubtypeWitness : public SubtypeWitness { SLANG_AST_CLASS(TransitiveSubtypeWitness) // Witness that `sub : mid` - SubtypeWitness* getSubToMid() - { - return as(getOperand(2)); - } + SubtypeWitness* getSubToMid() { return as(getOperand(2)); } // Witness that `mid : sup` - SubtypeWitness* getMidToSup() - { - return as(getOperand(3)); - } + SubtypeWitness* getMidToSup() { return as(getOperand(3)); } // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); - TransitiveSubtypeWitness(Type* subType, Type* supType, SubtypeWitness* inSubToMid, SubtypeWitness* inMidToSup) + TransitiveSubtypeWitness( + Type* subType, + Type* supType, + SubtypeWitness* inSubToMid, + SubtypeWitness* inMidToSup) { setOperands(subType, supType, inSubToMid, inMidToSup); } @@ -678,7 +678,7 @@ class TransitiveSubtypeWitness : public SubtypeWitness // A witness that `sub : sup` because `sub` was wrapped into // an existential of type `sup`. -class ExtractExistentialSubtypeWitness : public SubtypeWitness +class ExtractExistentialSubtypeWitness : public SubtypeWitness { SLANG_AST_CLASS(ExtractExistentialSubtypeWitness) @@ -695,18 +695,15 @@ class ExtractExistentialSubtypeWitness : public SubtypeWitness Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; - /// A witness of the fact that a user provided "__Dynamic" type argument is a - /// subtype to the existential type parameter. +/// A witness of the fact that a user provided "__Dynamic" type argument is a +/// subtype to the existential type parameter. class DynamicSubtypeWitness : public SubtypeWitness { SLANG_AST_CLASS(DynamicSubtypeWitness) - DynamicSubtypeWitness(Type* inSub, Type* inSup) - { - setOperands(inSub, inSup); - } + DynamicSubtypeWitness(Type* inSub, Type* inSup) { setOperands(inSub, inSup); } }; - /// A witness that `T : L & R` because `T : L` and `T : R` +/// A witness that `T : L & R` because `T : L` and `T : R` class ConjunctionSubtypeWitness : public SubtypeWitness { SLANG_AST_CLASS(ConjunctionSubtypeWitness) @@ -738,7 +735,7 @@ class ConjunctionSubtypeWitness : public SubtypeWitness ConversionCost _getOverloadResolutionCostOverride(); }; - /// A witness that `T <: L` or `T <: R` because `T <: L&R` +/// A witness that `T <: L` or `T <: R` because `T <: L&R` class ExtractFromConjunctionSubtypeWitness : public SubtypeWitness { SLANG_AST_CLASS(ExtractFromConjunctionSubtypeWitness) @@ -751,16 +748,20 @@ class ExtractFromConjunctionSubtypeWitness : public SubtypeWitness /// Witness that `T < L & R` SubtypeWitness* getConjunctionWitness() { return as(getOperand(2)); }; - ExtractFromConjunctionSubtypeWitness(Type* inSub, Type* inSup, SubtypeWitness* witness, int index) + ExtractFromConjunctionSubtypeWitness( + Type* inSub, + Type* inSup, + SubtypeWitness* witness, + int index) { setOperands(inSub, inSup, witness, index); } - /// The zero-based index of the super-type we care about in the conjunction - /// - /// If `conjunctionWitness` is `T < L & R` then this index should be zero if - /// we want to represent `T < L` and one if we want `T < R`. - /// + /// The zero-based index of the super-type we care about in the conjunction + /// + /// If `conjunctionWitness` is `T < L & R` then this index should be zero if + /// we want to represent `T < L` and one if we want `T < R`. + /// int getIndexInConjunction() { return (int)getIntConstOperand(3); }; void _toTextOverride(StringBuilder& out); @@ -769,7 +770,7 @@ class ExtractFromConjunctionSubtypeWitness : public SubtypeWitness ConversionCost _getOverloadResolutionCostOverride(); }; - /// A value that represents a modifier attached to some other value +/// A value that represents a modifier attached to some other value class ModifierVal : public Val { SLANG_AST_CLASS(ModifierVal) @@ -811,15 +812,12 @@ class NoDiffModifierVal : public TypeModifierVal Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; - /// Represents the result of differentiating a function. +/// Represents the result of differentiating a function. class DifferentiateVal : public Val { SLANG_AST_CLASS(DifferentiateVal) - DifferentiateVal(DeclRef inFunc) - { - setOperands(inFunc); - } + DifferentiateVal(DeclRef inFunc) { setOperands(inFunc); } DeclRef getFunc() { return as(getOperand(0)); } @@ -833,7 +831,8 @@ class ForwardDifferentiateVal : public DifferentiateVal SLANG_AST_CLASS(ForwardDifferentiateVal) ForwardDifferentiateVal(DeclRef inFunc) : DifferentiateVal(inFunc) - {} + { + } }; class BackwardDifferentiateVal : public DifferentiateVal @@ -842,7 +841,8 @@ class BackwardDifferentiateVal : public DifferentiateVal BackwardDifferentiateVal(DeclRef inFunc) : DifferentiateVal(inFunc) - {} + { + } }; class BackwardDifferentiateIntermediateTypeVal : public DifferentiateVal @@ -851,7 +851,8 @@ class BackwardDifferentiateIntermediateTypeVal : public DifferentiateVal BackwardDifferentiateIntermediateTypeVal(DeclRef inFunc) : DifferentiateVal(inFunc) - {} + { + } }; class BackwardDifferentiatePrimalVal : public DifferentiateVal @@ -860,7 +861,8 @@ class BackwardDifferentiatePrimalVal : public DifferentiateVal BackwardDifferentiatePrimalVal(DeclRef inFunc) : DifferentiateVal(inFunc) - {} + { + } }; class BackwardDifferentiatePropagateVal : public DifferentiateVal @@ -869,7 +871,8 @@ class BackwardDifferentiatePropagateVal : public DifferentiateVal BackwardDifferentiatePropagateVal(DeclRef inFunc) : DifferentiateVal(inFunc) - {} + { + } }; diff --git a/source/slang/slang-capability.cpp b/source/slang/slang-capability.cpp index bae5a1c72..f56e246dd 100644 --- a/source/slang/slang-capability.cpp +++ b/source/slang/slang-capability.cpp @@ -26,9 +26,9 @@ enum class CapabilityNameFlavor : int32_t // An abstract capability represents a class of feature // where multiple distinct implementations might be possible. - // 'raytracing' may be allowed with a 'raygen' "stage", but + // 'raytracing' may be allowed with a 'raygen' "stage", but // not a 'vertex' "stage" - // For more information (and a clearer description of the rules), + // For more information (and a clearer description of the rules), // read `slang-capabilities.capdef` Abstract, @@ -52,13 +52,13 @@ struct CapabilityAtomInfo UnownedStringSlice name; /// Flavor of atom: concrete, abstract, or alias - CapabilityNameFlavor flavor; + CapabilityNameFlavor flavor; /// If the atom is a direct descendent of an abstract base, keep that for reference here. CapabilityName abstractBase; /// Ranking to use when deciding if this atom is a "better" one to select. - uint32_t rank; + uint32_t rank; /// Canonical representation of atoms in the form of disjoint conjunctions of atoms. ArrayView canonicalRepresentation; @@ -136,8 +136,11 @@ CapabilityAtom getLatestSpirvAtom() if (result == CapabilityAtom::Invalid) { CapabilitySet latestSpirvCapSet = CapabilitySet(CapabilityName::_spirv_latest); - auto latestSpirvCapSetElements = latestSpirvCapSet.getAtomSets()->getElements(); - result = asAtom(latestSpirvCapSetElements[latestSpirvCapSetElements.getCount() - 2]); //-1 gets shader stage + auto latestSpirvCapSetElements = + latestSpirvCapSet.getAtomSets()->getElements(); + result = asAtom( + latestSpirvCapSetElements[latestSpirvCapSetElements.getCount() - 2]); //-1 gets shader + // stage } return result; } @@ -148,8 +151,11 @@ CapabilityAtom getLatestMetalAtom() if (result == CapabilityAtom::Invalid) { CapabilitySet latestMetalCapSet = CapabilitySet(CapabilityName::metallib_latest); - auto latestMetalCapSetElements = latestMetalCapSet.getAtomSets()->getElements(); - result = asAtom(latestMetalCapSetElements[latestMetalCapSetElements.getCount() - 2]); //-1 gets shader stage + auto latestMetalCapSetElements = + latestMetalCapSet.getAtomSets()->getElements(); + result = asAtom( + latestMetalCapSetElements[latestMetalCapSetElements.getCount() - 2]); //-1 gets shader + // stage } return result; } @@ -171,7 +177,7 @@ bool isCapabilityDerivedFrom(CapabilityAtom atom, CapabilityAtom base) return false; } -//CapabilityAtomSet +// CapabilityAtomSet CapabilityAtomSet CapabilityAtomSet::newSetWithoutImpliedAtoms() const { @@ -182,8 +188,8 @@ CapabilityAtomSet CapabilityAtomSet::newSetWithoutImpliedAtoms() const for (auto atom1UInt : *this) { CapabilityAtom atom1 = (CapabilityAtom)atom1UInt; - if (!candidateForSimplifiedList.addIfNotExists(atom1, true) - && candidateForSimplifiedList[atom1] == false) + if (!candidateForSimplifiedList.addIfNotExists(atom1, true) && + candidateForSimplifiedList[atom1] == false) continue; for (auto atom2UInt : *this) @@ -192,8 +198,8 @@ CapabilityAtomSet CapabilityAtomSet::newSetWithoutImpliedAtoms() const continue; CapabilityAtom atom2 = (CapabilityAtom)atom2UInt; - if (!candidateForSimplifiedList.addIfNotExists(atom2, true) - && candidateForSimplifiedList[atom2] == false) + if (!candidateForSimplifiedList.addIfNotExists(atom2, true) && + candidateForSimplifiedList[atom2] == false) continue; auto atomInfo1 = _getInfo(atom1).canonicalRepresentation; @@ -217,7 +223,7 @@ CapabilityAtomSet CapabilityAtomSet::newSetWithoutImpliedAtoms() const } } for (auto i : candidateForSimplifiedList) - if(i.second) + if (i.second) simplifiedSet.add((UInt)i.first); return simplifiedSet; } @@ -247,11 +253,15 @@ CapabilityAtom getStageAtomInSet(const CapabilityAtomSet& atomSet) } template -void CapabilitySet::addPermutationsOfConjunctionForEachInContainer(CapabilityAtomSet& setToPermutate, const CapabilityAtomSet& elementsToPermutateWith, CapabilityAtom knownTargetAtom, CapabilityAtom knownStageAtom) +void CapabilitySet::addPermutationsOfConjunctionForEachInContainer( + CapabilityAtomSet& setToPermutate, + const CapabilityAtomSet& elementsToPermutateWith, + CapabilityAtom knownTargetAtom, + CapabilityAtom knownStageAtom) { SLANG_UNUSED(knownTargetAtom); SLANG_UNUSED(knownStageAtom); - for(auto i : elementsToPermutateWith) + for (auto i : elementsToPermutateWith) { CapabilityName atom = (CapabilityName)i; CapabilityAtomSet conjunctionPermutation = setToPermutate; @@ -273,7 +283,10 @@ void CapabilitySet::addPermutationsOfConjunctionForEachInContainer(CapabilityAto } } -void CapabilitySet::addConjunction(CapabilityAtomSet conjunction, CapabilityAtom knownTargetAtom, CapabilityAtom knownStageAtom) +void CapabilitySet::addConjunction( + CapabilityAtomSet conjunction, + CapabilityAtom knownTargetAtom, + CapabilityAtom knownStageAtom) { if (knownTargetAtom == CapabilityAtom::Invalid) { @@ -281,7 +294,11 @@ void CapabilitySet::addConjunction(CapabilityAtomSet conjunction, CapabilityAtom // if no target in conjunction, add a permutation of the conjunction with every target if (knownTargetAtom == CapabilityAtom::Invalid) { - addPermutationsOfConjunctionForEachInContainer(conjunction, getAtomSetOfTargets(), CapabilityAtom::Invalid, getStageAtomInSet(conjunction)); + addPermutationsOfConjunctionForEachInContainer( + conjunction, + getAtomSetOfTargets(), + CapabilityAtom::Invalid, + getStageAtomInSet(conjunction)); return; } } @@ -295,7 +312,11 @@ void CapabilitySet::addConjunction(CapabilityAtomSet conjunction, CapabilityAtom if (knownStageAtom == CapabilityAtom::Invalid) { capabilitySetToTargetSet.shaderStageSets.reserve(kCapabilityStageCount); - addPermutationsOfConjunctionForEachInContainer(conjunction, getAtomSetOfStages(), knownTargetAtom, CapabilityAtom::Invalid); + addPermutationsOfConjunctionForEachInContainer( + conjunction, + getAtomSetOfStages(), + knownTargetAtom, + CapabilityAtom::Invalid); return; } } @@ -328,8 +349,7 @@ CapabilityAtom CapabilitySet::getUniquelyImpliedStageAtom() const return result; } -CapabilitySet::CapabilitySet() -{} +CapabilitySet::CapabilitySet() {} CapabilitySet::CapabilitySet(Int atomCount, CapabilityName const* atoms) { @@ -383,7 +403,7 @@ bool CapabilitySet::isIncompatibleWith(CapabilityAtom other) const if (isEmpty()) return false; - + CapabilitySet otherSet((CapabilityName)other); return isIncompatibleWith(otherSet); } @@ -403,7 +423,8 @@ bool CapabilitySet::isIncompatibleWith(CapabilitySet const& other) const if (other.isEmpty()) return false; - // Incompatible means there are 0 intersecting abstract nodes from sets in `other` with sets in `this` + // Incompatible means there are 0 intersecting abstract nodes from sets in `other` with sets in + // `this` for (auto& otherSet : other.m_targetSets) { auto targetSet = this->m_targetSets.tryGetValue(otherSet.first); @@ -453,7 +474,9 @@ bool CapabilitySet::implies(CapabilityAtom atom) const return this->implies(tmpSet); } -CapabilitySet::ImpliesReturnFlags CapabilitySet::_implies(CapabilitySet const& otherSet, ImpliesFlags flags) const +CapabilitySet::ImpliesReturnFlags CapabilitySet::_implies( + CapabilitySet const& otherSet, + ImpliesFlags flags) const { // x implies (c | d) only if (x implies c) and (x implies d). @@ -512,9 +535,11 @@ CapabilitySet::ImpliesReturnFlags CapabilitySet::_implies(CapabilitySet const& o bool CapabilitySet::implies(CapabilitySet const& other) const { - return (int)_implies(other, ImpliesFlags::None) & (int)CapabilitySet::ImpliesReturnFlags::Implied; + return (int)_implies(other, ImpliesFlags::None) & + (int)CapabilitySet::ImpliesReturnFlags::Implied; } -CapabilitySet::ImpliesReturnFlags CapabilitySet::atLeastOneSetImpliedInOther(CapabilitySet const& other) const +CapabilitySet::ImpliesReturnFlags CapabilitySet::atLeastOneSetImpliedInOther( + CapabilitySet const& other) const { return _implies(other, ImpliesFlags::OnlyRequireASingleValidImply); } @@ -528,9 +553,8 @@ void CapabilityTargetSet::unionWith(const CapabilityTargetSet& other) if (!thisStageSet.atomSet) thisStageSet.atomSet = otherStageSet.second.atomSet; - else - if(otherStageSet.second.atomSet) - thisStageSet.atomSet->unionWith(*otherStageSet.second.atomSet); + else if (otherStageSet.second.atomSet) + thisStageSet.atomSet->unionWith(*otherStageSet.second.atomSet); } } @@ -549,8 +573,9 @@ void CapabilitySet::unionWith(const CapabilitySet& other) } } -/// Join sets, but: -/// 1. do not destroy target set's which are incompatible with `other` (destroying shaderStageSets is fine) +/// Join sets, but: +/// 1. do not destroy target set's which are incompatible with `other` (destroying shaderStageSets +/// is fine) /// 2. do not create an `CapabilityAtom::Invalid` target set. void CapabilitySet::nonDestructiveJoin(const CapabilitySet& other) { @@ -607,7 +632,7 @@ bool CapabilityStageSet::tryJoin(const CapabilityTargetSet& other) return false; // should not exceed far beyond 2*2 or 1*1 elements - if(otherStageSet->atomSet && this->atomSet) + if (otherStageSet->atomSet && this->atomSet) this->atomSet->add(*otherStageSet->atomSet); return true; @@ -665,7 +690,9 @@ void CapabilitySet::join(const CapabilitySet& other) this->m_targetSets[CapabilityAtom::Invalid].target = CapabilityAtom::Invalid; } -static uint32_t _calcAtomListDifferenceScore(List const& thisList, List const& thatList) +static uint32_t _calcAtomListDifferenceScore( + List const& thisList, + List const& thatList) { uint32_t score = 0; @@ -680,8 +707,10 @@ static uint32_t _calcAtomListDifferenceScore(List const& thisLis Index thatIndex = 0; for (;;) { - if (thisIndex == thisCount) break; - if (thatIndex == thatCount) break; + if (thisIndex == thisCount) + break; + if (thatIndex == thatCount) + break; auto thisAtom = thisList[thisIndex]; auto thatAtom = thatList[thatIndex]; @@ -731,13 +760,16 @@ bool CapabilitySet::hasSameTargets(const CapabilitySet& other) const // MSVC incorrectly throws warning #if defined(_MSC_VER) #pragma warning(push) -#pragma warning(disable:4702) +#pragma warning(disable : 4702) #endif -bool CapabilitySet::isBetterForTarget(CapabilitySet const& that, CapabilitySet const& targetCaps, bool& isEqual) const +bool CapabilitySet::isBetterForTarget( + CapabilitySet const& that, + CapabilitySet const& targetCaps, + bool& isEqual) const { if (this->isEmpty() && (that.isEmpty() || that.isInvalid())) { - if(this->isEmpty() && that.isEmpty()) + if (this->isEmpty() && that.isEmpty()) isEqual = true; return true; } @@ -761,18 +793,22 @@ bool CapabilitySet::isBetterForTarget(CapabilitySet const& that, CapabilitySet c // required to have shader stage for (auto& shaderStageSetsWeNeed : targetWeNeed.second.shaderStageSets) { - auto thisStageSets = thisTarget->shaderStageSets.tryGetValue(shaderStageSetsWeNeed.first); + auto thisStageSets = + thisTarget->shaderStageSets.tryGetValue(shaderStageSetsWeNeed.first); if (!thisStageSets) return false; - auto thatStageSets = thatTarget->shaderStageSets.tryGetValue(shaderStageSetsWeNeed.first); + auto thatStageSets = + thatTarget->shaderStageSets.tryGetValue(shaderStageSetsWeNeed.first); if (!thatStageSets) return true; - // We want the smallest (most specialized) set which is still contained by this/that. This means: + // We want the smallest (most specialized) set which is still contained by this/that. + // This means: // 1. target.contains(this/that) // 2. choose smallest super set - // 3. rank each super set and their atoms, choose the smallest rank'd set (most specialized) - if(shaderStageSetsWeNeed.second.atomSet) + // 3. rank each super set and their atoms, choose the smallest rank'd set (most + // specialized) + if (shaderStageSetsWeNeed.second.atomSet) { auto& shaderStageSetWeNeed = shaderStageSetsWeNeed.second.atomSet.value(); @@ -783,16 +819,21 @@ bool CapabilitySet::isBetterForTarget(CapabilitySet const& that, CapabilitySet c CapabilityAtomSet thatSet{}; Index thatSetCount = 0; - // subtraction of the set we want gets us the "elements which 'targetSet' has but `this/that` is less specialized for" - if(thisStageSets->atomSet) + // subtraction of the set we want gets us the "elements which 'targetSet' has but + // `this/that` is less specialized for" + if (thisStageSets->atomSet) { auto& thisStageSet = thisStageSets->atomSet.value(); - // if `thisStageSet` is more specialized than the target, `thisStageSet` should not be a candidate + // if `thisStageSet` is more specialized than the target, `thisStageSet` should + // not be a candidate if (thisStageSet == shaderStageSetWeNeed) - return true; + return true; if (shaderStageSetWeNeed.contains(thisStageSet)) { - CapabilityAtomSet::calcSubtract(tmp_set, shaderStageSetWeNeed, thisStageSet); + CapabilityAtomSet::calcSubtract( + tmp_set, + shaderStageSetWeNeed, + thisStageSet); tmpCount = tmp_set.countElements(); if (thisSetCount < tmpCount) { @@ -808,7 +849,10 @@ bool CapabilitySet::isBetterForTarget(CapabilitySet const& that, CapabilitySet c return false; if (shaderStageSetWeNeed.contains(thatStageSet)) { - CapabilityAtomSet::calcSubtract(tmp_set, shaderStageSetWeNeed, thatStageSet); + CapabilityAtomSet::calcSubtract( + tmp_set, + shaderStageSetWeNeed, + thatStageSet); tmpCount = tmp_set.countElements(); if (thatSetCount < tmpCount) { @@ -820,8 +864,8 @@ bool CapabilitySet::isBetterForTarget(CapabilitySet const& that, CapabilitySet c if (thisSet == thatSet) isEqual = true; - - //empty means no candidate + + // empty means no candidate if (thisSet.areAllZero()) return false; if (thatSet.areAllZero()) @@ -830,13 +874,16 @@ bool CapabilitySet::isBetterForTarget(CapabilitySet const& that, CapabilitySet c return true; else if (thisSetCount > thatSetCount) return false; - + auto thisSetElements = thisSet.getElements(); auto thatSetElements = thisSet.getElements(); - auto shaderStageSetWeNeedElements = shaderStageSetWeNeed.getElements(); + auto shaderStageSetWeNeedElements = + shaderStageSetWeNeed.getElements(); - auto thisDiffScore = _calcAtomListDifferenceScore(thisSetElements, shaderStageSetWeNeedElements); - auto thatDiffScore = _calcAtomListDifferenceScore(thisSetElements, shaderStageSetWeNeedElements); + auto thisDiffScore = + _calcAtomListDifferenceScore(thisSetElements, shaderStageSetWeNeedElements); + auto thatDiffScore = + _calcAtomListDifferenceScore(thisSetElements, shaderStageSetWeNeedElements); return thisDiffScore < thatDiffScore; } @@ -853,17 +900,21 @@ CapabilitySet::AtomSets::Iterator CapabilitySet::getAtomSets() const return CapabilitySet::AtomSets::Iterator(&this->getCapabilityTargetSets()).begin(); } -bool CapabilitySet::checkCapabilityRequirement(CapabilitySet const& available, CapabilitySet const& required, CapabilityAtomSet& outFailedAvailableSet) +bool CapabilitySet::checkCapabilityRequirement( + CapabilitySet const& available, + CapabilitySet const& required, + CapabilityAtomSet& outFailedAvailableSet) { // Requirements x are met by available disjoint capabilities (a | b) iff // both 'a' satisfies x and 'b' satisfies x. // If we have a caller function F() decorated with: // [require(hlsl, _sm_6_3)] [require(spirv, _spv_ray_tracing)] void F() { g(); } - // We'd better make sure that `g()` can be compiled with both (hlsl+_sm_6_3) and (spirv+_spv_ray_tracing) capability sets. - // In this method, F()'s capability declaration is represented by `available`, - // and g()'s capability is represented by `required`. - // We will check that for every capability conjunction X of F(), there is one capability conjunction Y in g() such that X implies Y. - // + // We'd better make sure that `g()` can be compiled with both (hlsl+_sm_6_3) and + // (spirv+_spv_ray_tracing) capability sets. In this method, F()'s capability declaration is + // represented by `available`, and g()'s capability is represented by `required`. We will check + // that for every capability conjunction X of F(), there is one capability conjunction Y in g() + // such that X implies Y. + // // if empty there is no body, all capabilities are supported. if (required.isEmpty()) @@ -879,9 +930,10 @@ bool CapabilitySet::checkCapabilityRequirement(CapabilitySet const& available, C // if (available.isEmpty() && !required.isEmpty()) return false; - - - // if all sets in `available` are not a super-set to at least 1 `required` set, then we have an err + + + // if all sets in `available` are not a super-set to at least 1 `required` set, then we have an + // err for (auto& availableTarget : available.m_targetSets) { auto reqTarget = required.m_targetSets.tryGetValue(availableTarget.first); @@ -901,26 +953,29 @@ bool CapabilitySet::checkCapabilityRequirement(CapabilitySet const& available, C } const CapabilityAtomSet* lastBadStage = nullptr; - if(availableStage.second.atomSet) + if (availableStage.second.atomSet) { const auto& availableStageSet = availableStage.second.atomSet.value(); lastBadStage = nullptr; - if(reqStage->atomSet) + if (reqStage->atomSet) { const auto& reqStageSet = reqStage->atomSet.value(); if (availableStageSet.contains(reqStageSet)) break; - else + else lastBadStage = &reqStageSet; } if (lastBadStage) { // get missing atoms - CapabilityAtomSet::calcSubtract(outFailedAvailableSet, *lastBadStage, availableStageSet); + CapabilityAtomSet::calcSubtract( + outFailedAvailableSet, + *lastBadStage, + availableStageSet); return false; } } - } + } } return true; @@ -931,7 +986,8 @@ inline CapabilityName maybeConvertSpirvVersionToGlslSpirvVersion(CapabilityName& { if (atom >= CapabilityName::_spirv_1_0 && asAtom(atom) <= getLatestSpirvAtom()) { - return (CapabilityName)((Int)CapabilityName::glsl_spirv_1_0 + ((Int)atom - (Int)CapabilityName::_spirv_1_0)); + return (CapabilityName)((Int)CapabilityName::glsl_spirv_1_0 + + ((Int)atom - (Int)CapabilityName::_spirv_1_0)); } return CapabilityName::Invalid; } @@ -963,7 +1019,8 @@ void CapabilitySet::addSpirvVersionFromOtherAsGlslSpirvVersion(CapabilitySet& ot otherAtom = otherStageSet.second.atomSet->end(); continue; } - auto maybeConvertedSpirvVersionAtom = maybeConvertSpirvVersionToGlslSpirvVersion(otherAtomName); + auto maybeConvertedSpirvVersionAtom = + maybeConvertSpirvVersionToGlslSpirvVersion(otherAtomName); if (maybeConvertedSpirvVersionAtom == CapabilityName::Invalid) continue; @@ -994,29 +1051,29 @@ void printDiagnosticArg(StringBuilder& sb, const CapabilityAtomSet atomSet) } } -// Collection of stages which have same atom sets to compress reprisentation of atom and stage per target +// Collection of stages which have same atom sets to compress reprisentation of atom and stage per +// target struct CompressedCapabilitySet { - /// Collection of stages which have same atom sets to compress reprisentation of atom and stage: {vertex/fragment, ... } + /// Collection of stages which have same atom sets to compress reprisentation of atom and stage: + /// {vertex/fragment, ... } struct StageAndAtomSet { CapabilityAtomSet stages; CapabilityAtomSet atomsWithoutStage; }; - auto begin() - { - return atomSetsOfTargets.begin(); - } + auto begin() { return atomSetsOfTargets.begin(); } - /// Compress 1 capabilitySet into a reprisentation which merges stages that share all of their atoms for printing. + /// Compress 1 capabilitySet into a reprisentation which merges stages that share all of their + /// atoms for printing. Dictionary> atomSetsOfTargets; CompressedCapabilitySet(const CapabilitySet& capabilitySet) { for (auto& atomSet : capabilitySet.getAtomSets()) { auto target = getTargetAtomInSet(atomSet); - + auto stageInSetAtom = getStageAtomInSet(atomSet); CapabilityAtomSet stageInSet; stageInSet.add((UInt)stageInSetAtom); @@ -1025,15 +1082,17 @@ struct CompressedCapabilitySet CapabilityAtomSet::calcSubtract(atomsWithoutStage, atomSet, stageInSet); if (!atomSetsOfTargets.containsKey(target)) { - atomSetsOfTargets[target].add({ stageInSet, atomsWithoutStage }); + atomSetsOfTargets[target].add({stageInSet, atomsWithoutStage}); continue; } - // try to find an equivlent atom set by iterating all of the same `atomSetsOfTarget[target]` and merge these 2 together. + // try to find an equivlent atom set by iterating all of the same + // `atomSetsOfTarget[target]` and merge these 2 together. auto& atomSetsOfTarget = atomSetsOfTargets[target]; for (auto& i : atomSetsOfTarget) { - if (i.atomsWithoutStage.contains(atomsWithoutStage) && atomsWithoutStage.contains(i.atomsWithoutStage)) + if (i.atomsWithoutStage.contains(atomsWithoutStage) && + atomsWithoutStage.contains(i.atomsWithoutStage)) { i.stages.add((UInt)stageInSetAtom); } @@ -1041,18 +1100,19 @@ struct CompressedCapabilitySet } for (auto& targetSets : atomSetsOfTargets) for (auto& targetSet : targetSets.second) - targetSet.atomsWithoutStage = targetSet.atomsWithoutStage.newSetWithoutImpliedAtoms(); + targetSet.atomsWithoutStage = + targetSet.atomsWithoutStage.newSetWithoutImpliedAtoms(); } }; void printDiagnosticArg(StringBuilder& sb, const CompressedCapabilitySet& capabilitySet) { - ////Secondly we will print our new list of atomSet's. + ////Secondly we will print our new list of atomSet's. sb << "{"; bool firstSet = true; for (auto targetSets : capabilitySet.atomSetsOfTargets) { - if(!firstSet) + if (!firstSet) sb << " || "; for (auto targetSet : targetSets.second) { @@ -1108,18 +1168,14 @@ void printDiagnosticArg(StringBuilder& sb, List& list) #ifdef UNIT_TEST_CAPABILITIES -#define CHECK_CAPS(inData) SLANG_ASSERT(inData>0) +#define CHECK_CAPS(inData) SLANG_ASSERT(inData > 0) int TEST_findTargetCapSet(CapabilitySet& capSet, CapabilityAtom target) { - return true - && capSet.getCapabilityTargetSets().containsKey(target); + return true && capSet.getCapabilityTargetSets().containsKey(target); } -int TEST_findTargetStage( - CapabilitySet& capSet, - CapabilityAtom target, - CapabilityAtom stage) +int TEST_findTargetStage(CapabilitySet& capSet, CapabilityAtom target, CapabilityAtom stage) { return capSet.getCapabilityTargetSets()[target].shaderStageSets.containsKey(stage); } @@ -1131,7 +1187,8 @@ int TEST_targetCapSetWithSpecificAtomInStage( CapabilityAtom stage, CapabilityAtom atom) { - return capSet.getCapabilityTargetSets()[target].shaderStageSets[stage].atomSet->contains((UInt)atom); + return capSet.getCapabilityTargetSets()[target].shaderStageSets[stage].atomSet->contains( + (UInt)atom); } int TEST_targetCapSetWithSpecificSetInStage( @@ -1141,8 +1198,9 @@ int TEST_targetCapSetWithSpecificSetInStage( List setToFind) { - bool containsStageKey = capSet.getCapabilityTargetSets()[target].shaderStageSets.containsKey(stage); - if (!containsStageKey) + bool containsStageKey = + capSet.getCapabilityTargetSets()[target].shaderStageSets.containsKey(stage); + if (!containsStageKey) return 0; auto& stageSet = capSet.getCapabilityTargetSets()[target].shaderStageSets[stage]; @@ -1172,35 +1230,63 @@ void TEST_CapabilitySet_addAtom() testCapSet = CapabilitySet(CapabilityName::TEST_ADD_1); CHECK_CAPS(TEST_findTargetCapSet(testCapSet, CapabilityAtom::hlsl)); - CHECK_CAPS(TEST_targetCapSetWithSpecificSetInStage(testCapSet, CapabilityAtom::hlsl, CapabilityAtom::vertex, - { CapabilityAtom::textualTarget, CapabilityAtom::hlsl, CapabilityAtom::vertex, - CapabilityAtom::_sm_4_0, CapabilityAtom::_sm_4_1 })); + CHECK_CAPS(TEST_targetCapSetWithSpecificSetInStage( + testCapSet, + CapabilityAtom::hlsl, + CapabilityAtom::vertex, + {CapabilityAtom::textualTarget, + CapabilityAtom::hlsl, + CapabilityAtom::vertex, + CapabilityAtom::_sm_4_0, + CapabilityAtom::_sm_4_1})); CHECK_CAPS(TEST_findTargetCapSet(testCapSet, CapabilityAtom::glsl)); - CHECK_CAPS(TEST_targetCapSetWithSpecificSetInStage(testCapSet, CapabilityAtom::glsl, CapabilityAtom::vertex, - { CapabilityAtom::textualTarget, CapabilityAtom::glsl, CapabilityAtom::vertex, - CapabilityAtom::_GLSL_130 })); + CHECK_CAPS(TEST_targetCapSetWithSpecificSetInStage( + testCapSet, + CapabilityAtom::glsl, + CapabilityAtom::vertex, + {CapabilityAtom::textualTarget, + CapabilityAtom::glsl, + CapabilityAtom::vertex, + CapabilityAtom::_GLSL_130})); CHECK_CAPS(TEST_findTargetCapSet(testCapSet, CapabilityAtom::spirv_1_0)); - CHECK_CAPS(TEST_targetCapSetWithSpecificSetInStage(testCapSet, CapabilityAtom::spirv_1_0, CapabilityAtom::vertex, - { CapabilityAtom::spirv_1_0, CapabilityAtom::vertex, - CapabilityAtom::spirv_1_1 })); + CHECK_CAPS(TEST_targetCapSetWithSpecificSetInStage( + testCapSet, + CapabilityAtom::spirv_1_0, + CapabilityAtom::vertex, + {CapabilityAtom::spirv_1_0, CapabilityAtom::vertex, CapabilityAtom::spirv_1_1})); CHECK_CAPS(TEST_findTargetCapSet(testCapSet, CapabilityAtom::metal)); - CHECK_CAPS(TEST_targetCapSetWithSpecificSetInStage(testCapSet, CapabilityAtom::metal, CapabilityAtom::vertex, - { CapabilityAtom::textualTarget, CapabilityAtom::metal, CapabilityAtom::vertex })); + CHECK_CAPS(TEST_targetCapSetWithSpecificSetInStage( + testCapSet, + CapabilityAtom::metal, + CapabilityAtom::vertex, + {CapabilityAtom::textualTarget, CapabilityAtom::metal, CapabilityAtom::vertex})); // ------------------------------------------------------------ testCapSet = CapabilitySet(CapabilityName::TEST_ADD_2); CHECK_CAPS(TEST_findTargetCapSet(testCapSet, CapabilityAtom::hlsl)); - CHECK_CAPS(TEST_targetCapSetWithSpecificSetInStage(testCapSet, CapabilityAtom::hlsl, CapabilityAtom::compute, - { CapabilityAtom::textualTarget, CapabilityAtom::hlsl, CapabilityAtom::compute, - CapabilityAtom::_sm_4_0, CapabilityAtom::_sm_4_1 })); - CHECK_CAPS(TEST_targetCapSetWithSpecificSetInStage(testCapSet, CapabilityAtom::hlsl, CapabilityAtom::fragment, - { CapabilityAtom::textualTarget, CapabilityAtom::hlsl, CapabilityAtom::fragment, - CapabilityAtom::_sm_4_0, CapabilityAtom::_sm_4_1 })); + CHECK_CAPS(TEST_targetCapSetWithSpecificSetInStage( + testCapSet, + CapabilityAtom::hlsl, + CapabilityAtom::compute, + {CapabilityAtom::textualTarget, + CapabilityAtom::hlsl, + CapabilityAtom::compute, + CapabilityAtom::_sm_4_0, + CapabilityAtom::_sm_4_1})); + CHECK_CAPS(TEST_targetCapSetWithSpecificSetInStage( + testCapSet, + CapabilityAtom::hlsl, + CapabilityAtom::fragment, + {CapabilityAtom::textualTarget, + CapabilityAtom::hlsl, + CapabilityAtom::fragment, + CapabilityAtom::_sm_4_0, + CapabilityAtom::_sm_4_1})); // ------------------------------------------------------------ @@ -1208,9 +1294,14 @@ void TEST_CapabilitySet_addAtom() CHECK_CAPS((int)!TEST_findTargetCapSet(testCapSet, CapabilityAtom::spirv_1_0)); CHECK_CAPS(TEST_findTargetCapSet(testCapSet, CapabilityAtom::glsl)); - CHECK_CAPS(TEST_targetCapSetWithSpecificSetInStage(testCapSet, CapabilityAtom::glsl, CapabilityAtom::fragment, - { CapabilityAtom::textualTarget, CapabilityAtom::glsl, CapabilityAtom::fragment, - CapabilityAtom::_GLSL_130 })); + CHECK_CAPS(TEST_targetCapSetWithSpecificSetInStage( + testCapSet, + CapabilityAtom::glsl, + CapabilityAtom::fragment, + {CapabilityAtom::textualTarget, + CapabilityAtom::glsl, + CapabilityAtom::fragment, + CapabilityAtom::_GLSL_130})); // ------------------------------------------------------------ @@ -1220,8 +1311,16 @@ void TEST_CapabilitySet_addAtom() CHECK_CAPS((int)!TEST_findTargetCapSet(testCapSet, CapabilityAtom::glsl)); CHECK_CAPS(TEST_findTargetStage(testCapSet, CapabilityAtom::hlsl, CapabilityAtom::vertex)); CHECK_CAPS(TEST_findTargetStage(testCapSet, CapabilityAtom::hlsl, CapabilityAtom::fragment)); - CHECK_CAPS(TEST_targetCapSetWithSpecificAtomInStage(testCapSet, CapabilityAtom::hlsl, CapabilityAtom::fragment, CapabilityAtom::_sm_6_0)); - CHECK_CAPS(TEST_targetCapSetWithSpecificAtomInStage(testCapSet, CapabilityAtom::hlsl, CapabilityAtom::fragment, CapabilityAtom::_sm_5_0)); + CHECK_CAPS(TEST_targetCapSetWithSpecificAtomInStage( + testCapSet, + CapabilityAtom::hlsl, + CapabilityAtom::fragment, + CapabilityAtom::_sm_6_0)); + CHECK_CAPS(TEST_targetCapSetWithSpecificAtomInStage( + testCapSet, + CapabilityAtom::hlsl, + CapabilityAtom::fragment, + CapabilityAtom::_sm_5_0)); // ------------------------------------------------------------ @@ -1230,8 +1329,16 @@ void TEST_CapabilitySet_addAtom() CHECK_CAPS(TEST_findTargetCapSet(testCapSet, CapabilityAtom::hlsl)); CHECK_CAPS((int)!TEST_findTargetCapSet(testCapSet, CapabilityAtom::glsl)); CHECK_CAPS(TEST_findTargetStage(testCapSet, CapabilityAtom::hlsl, CapabilityAtom::fragment)); - CHECK_CAPS(TEST_targetCapSetWithSpecificAtomInStage(testCapSet, CapabilityAtom::hlsl, CapabilityAtom::fragment, CapabilityAtom::_sm_6_5)); - CHECK_CAPS(TEST_targetCapSetWithSpecificAtomInStage(testCapSet, CapabilityAtom::hlsl, CapabilityAtom::fragment, CapabilityAtom::_sm_5_0)); + CHECK_CAPS(TEST_targetCapSetWithSpecificAtomInStage( + testCapSet, + CapabilityAtom::hlsl, + CapabilityAtom::fragment, + CapabilityAtom::_sm_6_5)); + CHECK_CAPS(TEST_targetCapSetWithSpecificAtomInStage( + testCapSet, + CapabilityAtom::hlsl, + CapabilityAtom::fragment, + CapabilityAtom::_sm_5_0)); // ------------------------------------------------------------ @@ -1239,8 +1346,16 @@ void TEST_CapabilitySet_addAtom() CHECK_CAPS(TEST_findTargetCapSet(testCapSet, CapabilityAtom::glsl)); CHECK_CAPS(TEST_findTargetStage(testCapSet, CapabilityAtom::glsl, CapabilityAtom::fragment)); - CHECK_CAPS(TEST_targetCapSetWithSpecificAtomInStage(testCapSet, CapabilityAtom::glsl, CapabilityAtom::fragment, CapabilityAtom::_GL_NV_shader_texture_footprint)); - CHECK_CAPS(TEST_targetCapSetWithSpecificAtomInStage(testCapSet, CapabilityAtom::glsl, CapabilityAtom::fragment, CapabilityAtom::_GL_NV_compute_shader_derivatives)); + CHECK_CAPS(TEST_targetCapSetWithSpecificAtomInStage( + testCapSet, + CapabilityAtom::glsl, + CapabilityAtom::fragment, + CapabilityAtom::_GL_NV_shader_texture_footprint)); + CHECK_CAPS(TEST_targetCapSetWithSpecificAtomInStage( + testCapSet, + CapabilityAtom::glsl, + CapabilityAtom::fragment, + CapabilityAtom::_GL_NV_compute_shader_derivatives)); // ------------------------------------------------------------ @@ -1248,17 +1363,37 @@ void TEST_CapabilitySet_addAtom() CHECK_CAPS(TEST_findTargetCapSet(testCapSet, CapabilityAtom::glsl)); CHECK_CAPS(TEST_findTargetStage(testCapSet, CapabilityAtom::glsl, CapabilityAtom::fragment)); - CHECK_CAPS(TEST_targetCapSetWithSpecificAtomInStage(testCapSet, CapabilityAtom::glsl, CapabilityAtom::fragment, CapabilityAtom::_GL_NV_shader_texture_footprint)); - CHECK_CAPS(TEST_targetCapSetWithSpecificAtomInStage(testCapSet, CapabilityAtom::glsl, CapabilityAtom::fragment, CapabilityAtom::_GL_ARB_shader_image_size)); + CHECK_CAPS(TEST_targetCapSetWithSpecificAtomInStage( + testCapSet, + CapabilityAtom::glsl, + CapabilityAtom::fragment, + CapabilityAtom::_GL_NV_shader_texture_footprint)); + CHECK_CAPS(TEST_targetCapSetWithSpecificAtomInStage( + testCapSet, + CapabilityAtom::glsl, + CapabilityAtom::fragment, + CapabilityAtom::_GL_ARB_shader_image_size)); // ------------------------------------------------------------ testCapSet = CapabilitySet(CapabilityName::TEST_GEN_5); CHECK_CAPS(TEST_findTargetCapSet(testCapSet, CapabilityAtom::hlsl)); - CHECK_CAPS(TEST_targetCapSetWithSpecificAtomInStage(testCapSet, CapabilityAtom::hlsl, CapabilityAtom::fragment, CapabilityAtom::_sm_6_5)); - CHECK_CAPS(TEST_targetCapSetWithSpecificAtomInStage(testCapSet, CapabilityAtom::hlsl, CapabilityAtom::fragment, CapabilityAtom::_sm_6_4)); - CHECK_CAPS(TEST_targetCapSetWithSpecificAtomInStage(testCapSet, CapabilityAtom::hlsl, CapabilityAtom::fragment, CapabilityAtom::_sm_6_0)); + CHECK_CAPS(TEST_targetCapSetWithSpecificAtomInStage( + testCapSet, + CapabilityAtom::hlsl, + CapabilityAtom::fragment, + CapabilityAtom::_sm_6_5)); + CHECK_CAPS(TEST_targetCapSetWithSpecificAtomInStage( + testCapSet, + CapabilityAtom::hlsl, + CapabilityAtom::fragment, + CapabilityAtom::_sm_6_4)); + CHECK_CAPS(TEST_targetCapSetWithSpecificAtomInStage( + testCapSet, + CapabilityAtom::hlsl, + CapabilityAtom::fragment, + CapabilityAtom::_sm_6_0)); } void TEST_CapabilitySet_join() @@ -1282,10 +1417,16 @@ void TEST_CapabilitySet_join() testCapSetA.join(testCapSetB); CHECK_CAPS(TEST_findTargetCapSet(testCapSetA, CapabilityAtom::hlsl)); - CHECK_CAPS(TEST_targetCapSetWithSpecificSetInStage(testCapSetA, CapabilityAtom::hlsl, CapabilityAtom::vertex, - { CapabilityAtom::textualTarget, CapabilityAtom::hlsl, CapabilityAtom::vertex, - CapabilityAtom::_sm_4_0, CapabilityAtom::_sm_4_1 })); - + CHECK_CAPS(TEST_targetCapSetWithSpecificSetInStage( + testCapSetA, + CapabilityAtom::hlsl, + CapabilityAtom::vertex, + {CapabilityAtom::textualTarget, + CapabilityAtom::hlsl, + CapabilityAtom::vertex, + CapabilityAtom::_sm_4_0, + CapabilityAtom::_sm_4_1})); + // ------------------------------------------------------------ testCapSetA = CapabilitySet(CapabilityName::TEST_JOIN_3A); @@ -1294,19 +1435,43 @@ void TEST_CapabilitySet_join() CHECK_CAPS((int)!TEST_findTargetCapSet(testCapSetA, CapabilityAtom::spirv_1_0)); CHECK_CAPS(TEST_findTargetCapSet(testCapSetA, CapabilityAtom::glsl)); - CHECK_CAPS((int)!TEST_findTargetStage(testCapSetA, CapabilityAtom::glsl, CapabilityAtom::raygen)); - CHECK_CAPS(TEST_targetCapSetWithSpecificSetInStage(testCapSetA, CapabilityAtom::glsl, CapabilityAtom::fragment, - { CapabilityAtom::textualTarget, CapabilityAtom::glsl, CapabilityAtom::fragment, - CapabilityAtom::_GLSL_130, CapabilityAtom::_GLSL_140 })); - CHECK_CAPS(TEST_targetCapSetWithSpecificSetInStage(testCapSetA, CapabilityAtom::glsl, CapabilityAtom::vertex, - { CapabilityAtom::textualTarget, CapabilityAtom::glsl, CapabilityAtom::vertex, - CapabilityAtom::_GLSL_130, CapabilityAtom::_GLSL_140 })); - CHECK_CAPS(TEST_targetCapSetWithSpecificSetInStage(testCapSetA, CapabilityAtom::hlsl, CapabilityAtom::fragment, - { CapabilityAtom::textualTarget, CapabilityAtom::hlsl, CapabilityAtom::fragment, - CapabilityAtom::_sm_4_0, CapabilityAtom::_sm_4_1 })); - CHECK_CAPS(TEST_targetCapSetWithSpecificSetInStage(testCapSetA, CapabilityAtom::hlsl, CapabilityAtom::vertex, - { CapabilityAtom::textualTarget, CapabilityAtom::hlsl, CapabilityAtom::vertex, - CapabilityAtom::_sm_4_0 })); + CHECK_CAPS( + (int)!TEST_findTargetStage(testCapSetA, CapabilityAtom::glsl, CapabilityAtom::raygen)); + CHECK_CAPS(TEST_targetCapSetWithSpecificSetInStage( + testCapSetA, + CapabilityAtom::glsl, + CapabilityAtom::fragment, + {CapabilityAtom::textualTarget, + CapabilityAtom::glsl, + CapabilityAtom::fragment, + CapabilityAtom::_GLSL_130, + CapabilityAtom::_GLSL_140})); + CHECK_CAPS(TEST_targetCapSetWithSpecificSetInStage( + testCapSetA, + CapabilityAtom::glsl, + CapabilityAtom::vertex, + {CapabilityAtom::textualTarget, + CapabilityAtom::glsl, + CapabilityAtom::vertex, + CapabilityAtom::_GLSL_130, + CapabilityAtom::_GLSL_140})); + CHECK_CAPS(TEST_targetCapSetWithSpecificSetInStage( + testCapSetA, + CapabilityAtom::hlsl, + CapabilityAtom::fragment, + {CapabilityAtom::textualTarget, + CapabilityAtom::hlsl, + CapabilityAtom::fragment, + CapabilityAtom::_sm_4_0, + CapabilityAtom::_sm_4_1})); + CHECK_CAPS(TEST_targetCapSetWithSpecificSetInStage( + testCapSetA, + CapabilityAtom::hlsl, + CapabilityAtom::vertex, + {CapabilityAtom::textualTarget, + CapabilityAtom::hlsl, + CapabilityAtom::vertex, + CapabilityAtom::_sm_4_0})); // ------------------------------------------------------------ @@ -1315,13 +1480,20 @@ void TEST_CapabilitySet_join() testCapSetA.join(testCapSetB); CHECK_CAPS(TEST_findTargetCapSet(testCapSetA, CapabilityAtom::glsl)); - CHECK_CAPS(TEST_targetCapSetWithSpecificSetInStage(testCapSetA, CapabilityAtom::glsl, CapabilityAtom::fragment, - { CapabilityAtom::textualTarget, CapabilityAtom::glsl, CapabilityAtom::fragment, - CapabilityAtom::_GLSL_130, CapabilityAtom::_GLSL_140, CapabilityAtom::_GLSL_150, CapabilityAtom::_GL_EXT_texture_query_lod, CapabilityAtom::_GL_EXT_texture_shadow_lod })); + CHECK_CAPS(TEST_targetCapSetWithSpecificSetInStage( + testCapSetA, + CapabilityAtom::glsl, + CapabilityAtom::fragment, + {CapabilityAtom::textualTarget, + CapabilityAtom::glsl, + CapabilityAtom::fragment, + CapabilityAtom::_GLSL_130, + CapabilityAtom::_GLSL_140, + CapabilityAtom::_GLSL_150, + CapabilityAtom::_GL_EXT_texture_query_lod, + CapabilityAtom::_GL_EXT_texture_shadow_lod})); // ------------------------------------------------------------ - - } void TEST_CapabilitySet() @@ -1339,9 +1511,9 @@ alias TEST_ADD_3 = _GLSL_130 + compute_fragment_geometry_vertex; alias TEST_GEN_1 = _sm_6_5 + fragment | _sm_6_0 + vertex; alias TEST_GEN_2 = _sm_6_5 + fragment; -alias TEST_GEN_3 = GL_NV_shader_texture_footprint + GL_NV_compute_shader_derivatives + fragment | _GL_NV_shader_texture_footprint + fragment; -alias TEST_GEN_4 = GL_ARB_shader_image_size |& GL_NV_shader_texture_footprint + fragment; -alias TEST_GEN_5 = sm_6_0 + compute_fragment| sm_6_5; +alias TEST_GEN_3 = GL_NV_shader_texture_footprint + GL_NV_compute_shader_derivatives + fragment +| _GL_NV_shader_texture_footprint + fragment; alias TEST_GEN_4 = GL_ARB_shader_image_size |& +GL_NV_shader_texture_footprint + fragment; alias TEST_GEN_5 = sm_6_0 + compute_fragment| sm_6_5; alias TEST_JOIN_1A = hlsl; alias TEST_JOIN_1B = glsl; @@ -1349,16 +1521,16 @@ alias TEST_JOIN_1B = glsl; alias TEST_JOIN_2A = hlsl; alias TEST_JOIN_2B = _sm_4_1 | glsl; -alias TEST_JOIN_3A = glsl + fragment | _sm_4_0 + fragment +alias TEST_JOIN_3A = glsl + fragment | _sm_4_0 + fragment | glsl + vertex | hlsl + vertex ; -alias TEST_JOIN_3B = _sm_4_1 + fragment +alias TEST_JOIN_3B = _sm_4_1 + fragment | _sm_4_0 + vertex | _sm_4_0 + compute | _GLSL_140 + vertex | _GLSL_140 + fragment | spirv_1_0 + fragment - | glsl + raygen + | glsl + raygen | hlsl + raygen ; @@ -1366,10 +1538,11 @@ alias TEST_JOIN_4A = _GLSL_140 + _GL_EXT_texture_query_lod; alias TEST_JOIN_4B = _GLSL_150 + _GL_EXT_texture_shadow_lod; // Will cause capability generator failiure -alias TEST_ERROR_GEN_1 = GL_NV_shader_texture_footprint + GL_NV_compute_shader_derivatives + fragment | _GL_NV_shader_texture_footprint + _GL_NV_shader_atomic_fp16_vector + fragment; -alias TEST_ERROR_GEN_2 = GL_NV_shader_texture_footprint | GL_NV_ray_tracing_motion_blur; -alias TEST_ERROR_GEN_3 = GL_ARB_shader_image_size | GL_NV_shader_texture_footprint + fragment; -alias TEST_ERROR_GEN_4 = _sm_6_5 + fragment + vertex + cpp; +alias TEST_ERROR_GEN_1 = GL_NV_shader_texture_footprint + GL_NV_compute_shader_derivatives + +fragment | _GL_NV_shader_texture_footprint + _GL_NV_shader_atomic_fp16_vector + fragment; alias +TEST_ERROR_GEN_2 = GL_NV_shader_texture_footprint | GL_NV_ray_tracing_motion_blur; alias +TEST_ERROR_GEN_3 = GL_ARB_shader_image_size | GL_NV_shader_texture_footprint + fragment; alias +TEST_ERROR_GEN_4 = _sm_6_5 + fragment + vertex + cpp; /// */ @@ -1377,4 +1550,4 @@ alias TEST_ERROR_GEN_4 = _sm_6_5 + fragment + vertex + cpp; #endif -} +} // namespace Slang diff --git a/source/slang/slang-capability.h b/source/slang/slang-capability.h index 2e12af622..631cd307a 100644 --- a/source/slang/slang-capability.h +++ b/source/slang/slang-capability.h @@ -1,12 +1,12 @@ // slang-capability.h #pragma once +#include "../core/slang-dictionary.h" #include "../core/slang-list.h" #include "../core/slang-string.h" -#include "../core/slang-dictionary.h" -#include #include +#include namespace Slang { @@ -58,14 +58,14 @@ struct CapabilityAtomSet : UIntSet struct CapabilityTargetSet; typedef Dictionary CapabilityTargetSets; -/// CapabilityStageSet encapsulates all capabilities of a specific shader stage for a specific target. -/// Capabilities may be disjoint, but only in rare cases: +/// CapabilityStageSet encapsulates all capabilities of a specific shader stage for a specific +/// target. Capabilities may be disjoint, but only in rare cases: /// {{glsl, _GLSL_130, GL_EXT_FOO1}, {glsl, _GLSL_130, _GLSL_140, _GLSL_150}} struct CapabilityStageSet { CapabilityAtom stage{}; - /// LinkedList of all disjoint sets for fast remove/add of unconstrained list positions. + /// LinkedList of all disjoint sets for fast remove/add of unconstrained list positions. std::optional atomSet{}; void addNewSet(CapabilityAtomSet&& setToAdd) @@ -159,41 +159,59 @@ public: /// `this` may be made invalid if other is fully disjoint. void join(const CapabilitySet& other); - /// Join two capability sets to form ('this' & 'other'). + /// Join two capability sets to form ('this' & 'other'). /// If a target/set has an incompatible atom, do not destroy the target/set. void nonDestructiveJoin(const CapabilitySet& other); /// Add all targets/sets of 'other' into 'this'. Overlapping sets are removed. void unionWith(const CapabilitySet& other); - /// Return a capability set of 'target' atoms 'this' has, but 'other' does not. + /// Return a capability set of 'target' atoms 'this' has, but 'other' does not. CapabilitySet getTargetsThisHasButOtherDoesNot(const CapabilitySet& other); - /// Are these two capability sets equal? + /// Are these two capability sets equal? bool operator==(CapabilitySet const& that) const; void addCapability(List>& atomLists); - /// Calculate a list of "compacted" atoms, which excludes any atoms from the expanded list that are implies by another item in the list. + /// Calculate a list of "compacted" atoms, which excludes any atoms from the expanded list that + /// are implies by another item in the list. /// returns true if 'this' is a better target for 'targetCaps' than 'that' /// isEqual: is `this` and `that` equal - bool isBetterForTarget(CapabilitySet const& that, CapabilitySet const& targetCaps, bool& isEqual) const; - - /// Find any capability sets which are in 'available' but not in 'required'. Return false if this situation occurs. - static bool checkCapabilityRequirement(CapabilitySet const& available, CapabilitySet const& required, CapabilityAtomSet& outFailedAvailableSet); - - // For each element in `elementsToPermutateWith`, create and add a different conjunction permutation by adding to `setToPermutate`. + bool isBetterForTarget( + CapabilitySet const& that, + CapabilitySet const& targetCaps, + bool& isEqual) const; + + /// Find any capability sets which are in 'available' but not in 'required'. Return false if + /// this situation occurs. + static bool checkCapabilityRequirement( + CapabilitySet const& available, + CapabilitySet const& required, + CapabilityAtomSet& outFailedAvailableSet); + + // For each element in `elementsToPermutateWith`, create and add a different conjunction + // permutation by adding to `setToPermutate`. template - void addPermutationsOfConjunctionForEachInContainer(CapabilityAtomSet& setToPermutate, const CapabilityAtomSet& elementsToPermutateWith, CapabilityAtom knownTargetAtom, CapabilityAtom knownStageAtom); - // This is used for adding conjunctions directly and efficently, this is not functionally a join. - // if `knownStage`/`knownTarget` is not CapabilityAtom::Invalid, the given atom will be assumed as an assigned key atom (faster) - inline void addConjunction(CapabilityAtomSet conjunction, CapabilityAtom knownTarget, CapabilityAtom knownStage); + void addPermutationsOfConjunctionForEachInContainer( + CapabilityAtomSet& setToPermutate, + const CapabilityAtomSet& elementsToPermutateWith, + CapabilityAtom knownTargetAtom, + CapabilityAtom knownStageAtom); + // This is used for adding conjunctions directly and efficently, this is not functionally a + // join. if `knownStage`/`knownTarget` is not CapabilityAtom::Invalid, the given atom will be + // assumed as an assigned key atom (faster) + inline void addConjunction( + CapabilityAtomSet conjunction, + CapabilityAtom knownTarget, + CapabilityAtom knownStage); inline void addUnexpandedCapabilites(CapabilityName atom); - + CapabilityTargetSets& getCapabilityTargetSets() { return m_targetSets; } const CapabilityTargetSets& getCapabilityTargetSets() const { return m_targetSets; } - // If this capability set uniquely implies one stage atom, return it. Otherwise returns CapabilityAtom::Invalid. + // If this capability set uniquely implies one stage atom, return it. Otherwise returns + // CapabilityAtom::Invalid. CapabilityAtom getUniquelyImpliedStageAtom() const; struct AtomSets @@ -207,38 +225,24 @@ public: const std::optional* atomSetNode = {}; public: - operator bool() const - { - return (atomSetNode) ? atomSetNode->has_value() : false; - } - const CapabilityAtomSet& operator*() const - { - return *(*this->atomSetNode); - } - const CapabilityAtomSet* operator->() const - { - return &(*(*this->atomSetNode)); - } + operator bool() const { return (atomSetNode) ? atomSetNode->has_value() : false; } + const CapabilityAtomSet& operator*() const { return *(*this->atomSetNode); } + const CapabilityAtomSet* operator->() const { return &(*(*this->atomSetNode)); } bool operator==(const Iterator& other) const { - return other.context == this->context - && other.targetNode == this->targetNode - && other.stageNode == this->stageNode - ; - } - bool operator!=(const Iterator& other) const - { - return !(other == *this); + return other.context == this->context && other.targetNode == this->targetNode && + other.stageNode == this->stageNode; } + bool operator!=(const Iterator& other) const { return !(other == *this); } Iterator& operator++() { - for(;;) + for (;;) { this->stageNode++; if (this->stageNode == (*this->targetNode).second.shaderStageSets.end()) { - for(;;) + for (;;) { this->targetNode++; if (this->targetNode == this->context->end()) @@ -260,10 +264,7 @@ public: } return *this; } - Iterator& operator++(int) - { - return ++(*this); - } + Iterator& operator++(int) { return ++(*this); } Iterator begin() const { Iterator tmp(this->context); @@ -289,23 +290,21 @@ public: tmp.targetNode = this->context->end(); return tmp; } - Iterator(const CapabilityTargetSets* mainContext) - { - context = mainContext; - } + Iterator(const CapabilityTargetSets* mainContext) { context = mainContext; } }; }; /// Get access to the raw atomic capabilities that define this set. /// Get all bottom level UIntSets for each CapabilityTargetSet. CapabilitySet::AtomSets::Iterator getAtomSets() const; - /// Add spirv version capabilities from 'spirv CapabilityTargetSet' as glsl_spirv version capability in 'glsl CapabilityTargetSet' + /// Add spirv version capabilities from 'spirv CapabilityTargetSet' as glsl_spirv version + /// capability in 'glsl CapabilityTargetSet' void addSpirvVersionFromOtherAsGlslSpirvVersion(CapabilitySet& other); /// Gets the first valid compile-target found in the CapabilitySet CapabilityAtom getCompileTarget() { - if(isEmpty() || isInvalid()) + if (isEmpty() || isInvalid()) return CapabilityAtom::Invalid; return (*m_targetSets.begin()).first; } @@ -313,7 +312,7 @@ public: /// Gets the first valid stage found in the CapabilitySet CapabilityAtom getTargetStage() { - if(isEmpty() || isInvalid()) + if (isEmpty() || isInvalid()) return CapabilityAtom::Invalid; return (*(*m_targetSets.begin()).second.shaderStageSets.begin()).first; } @@ -334,13 +333,13 @@ private: ImpliesReturnFlags _implies(CapabilitySet const& other, ImpliesFlags flags) const; }; - /// Returns true if atom is derived from base +/// Returns true if atom is derived from base bool isCapabilityDerivedFrom(CapabilityAtom atom, CapabilityAtom base); - /// Find a capability atom with the given `name`, or return CapabilityAtom::Invalid. +/// Find a capability atom with the given `name`, or return CapabilityAtom::Invalid. CapabilityName findCapabilityName(UnownedStringSlice const& name); - /// Check if 'name' is an '_Internal' or 'External' capability. +/// Check if 'name' is an '_Internal' or 'External' capability. bool isInternalCapabilityName(CapabilityName name); CapabilityAtom getLatestSpirvAtom(); @@ -354,7 +353,7 @@ inline CapabilityAtom asAtom(T name) return CapabilityAtom(name); } - /// Gets the capability names. +/// Gets the capability names. void getCapabilityNames(List& ioNames); UnownedStringSlice capabilityNameToString(CapabilityName name); @@ -362,7 +361,7 @@ UnownedStringSlice capabilityNameToString(CapabilityName name); bool isDirectChildOfAbstractAtom(CapabilityAtom name); - /// Return true if `name` represents an atom for a target version, e.g. spirv_1_5. +/// Return true if `name` represents an atom for a target version, e.g. spirv_1_5. bool isTargetVersionAtom(CapabilityAtom name); bool isSpirvExtensionAtom(CapabilityAtom name); @@ -376,9 +375,9 @@ bool hasTargetAtom(const CapabilityAtomSet& setIn, CapabilityAtom& targetAtom); void freeCapabilityDefs(); -//#define UNIT_TEST_CAPABILITIES +// #define UNIT_TEST_CAPABILITIES #ifdef UNIT_TEST_CAPABILITIES void TEST_CapabilitySet(); #endif -} +} // namespace Slang diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp index 9d9047e41..499b409ea 100644 --- a/source/slang/slang-check-conformance.cpp +++ b/source/slang/slang-check-conformance.cpp @@ -7,367 +7,379 @@ namespace Slang { - bool SemanticsVisitor::isInterfaceSafeForTaggedUnion( - DeclRef interfaceDeclRef) +bool SemanticsVisitor::isInterfaceSafeForTaggedUnion(DeclRef interfaceDeclRef) +{ + for (auto memberDeclRef : getMembers(m_astBuilder, interfaceDeclRef)) { - for( auto memberDeclRef : getMembers(m_astBuilder, interfaceDeclRef) ) - { - if(!isInterfaceRequirementSafeForTaggedUnion(interfaceDeclRef, memberDeclRef)) - return false; - } + if (!isInterfaceRequirementSafeForTaggedUnion(interfaceDeclRef, memberDeclRef)) + return false; + } + + return true; +} + +bool SemanticsVisitor::isInterfaceRequirementSafeForTaggedUnion( + DeclRef interfaceDeclRef, + DeclRef requirementDeclRef) +{ + SLANG_UNUSED(interfaceDeclRef); + + if (auto callableDeclRef = requirementDeclRef.as()) + { + // A `static` method requirement can't be satisfied by a + // tagged union, because there is no tag to dispatch on. + // + if (requirementDeclRef.getDecl()->hasModifier()) + return false; + + // TODO: We will eventually want to check that any callable + // requirements do not use the `This` type or any associated + // types in ways that could lead to errors. + // + // For now we are disallowing interfaces that have associated + // types completely, and we haven't implemented the `This` + // type, so we should be safe. return true; } - - bool SemanticsVisitor::isInterfaceRequirementSafeForTaggedUnion( - DeclRef interfaceDeclRef, - DeclRef requirementDeclRef) + else { - SLANG_UNUSED(interfaceDeclRef); + return false; + } +} - if(auto callableDeclRef = requirementDeclRef.as()) - { - // A `static` method requirement can't be satisfied by a - // tagged union, because there is no tag to dispatch on. - // - if(requirementDeclRef.getDecl()->hasModifier()) - return false; +SubtypeWitness* SemanticsVisitor::isSubtype( + Type* subType, + Type* superType, + IsSubTypeOptions isSubTypeOptions) +{ + SubtypeWitness* result = nullptr; + if (getShared()->tryGetSubtypeWitnessFromCache(subType, superType, result)) + return result; + result = checkAndConstructSubtypeWitness(subType, superType, isSubTypeOptions); - // TODO: We will eventually want to check that any callable - // requirements do not use the `This` type or any associated - // types in ways that could lead to errors. - // - // For now we are disallowing interfaces that have associated - // types completely, and we haven't implemented the `This` - // type, so we should be safe. + if (!result && (int(isSubTypeOptions) & int(IsSubTypeOptions::NoCaching))) + return result; - return true; - } - else + getShared()->cacheSubtypeWitness(subType, superType, result); + return result; +} + +SubtypeWitness* SemanticsVisitor::checkAndConstructSubtypeWitness( + Type* subType, + Type* superType, + IsSubTypeOptions isSubTypeOptions) +{ + // TODO: The Slang codebase is currently being quite slippery by conflating + // multiple concepts, all under the banner of a "subtype" test: + // + // * Struct/class inheritance: When concrete type `A` inherits from concrete + // type `B`, we can directly convert any value of type `A` into a value of type `B` + // + // * Derived interfaces: When interface `X` derives from interface `Y`, we know + // that any concrete type conforming to `X` must also conform to `Y`, so we can + // derive a witness that `A : Y` from a witness tbale that `A : X` for some concrete `A` + // + // * Conformance: When concrete type `A` conforms to interface `X`, we know that there exists + // a witness table for that conformance. + // + // The problem is that these relationships mean different things. If we use the same + // `isSubtype()` test for all of the above cases, then we risk determining that `IFoo` + // *conforms* to `IBar` just because it was declared as `interface IFoo : IBar`. Or + // even more simply that `IFoo` conforms to `IFoo`. + // + // It is dangerous to start treating an interface type like it conforms to itself: + // + // interface IFoo { static int getValue(); } + // int get< T : IFoo >() { return T.getValue(); } + // + // int x = get(); // This needs to be an error!!! + // + // We will eventually need to clarify the distinction between the different kinds of + // subtype-ish relationships, *or* we will need to ensure that `interface`s are not + // treated as proper types (such that they can be passed as generic arguments, etc.) + // + // Note that there is one more case of a subtype-ish relationship that is not covered + // by this function, but that is relevant if/when we do more serious type inference: + // + // * Convertibility: When any value of type `A` can be converted to a value of type + // `B` (even if that conversion might involve computation or a change of representation), + // and that conversion is one that the compiler considers "okay" to do implicitly. + // + // For now we are continuing to conflate all the subtype-ish relationships but not + // tangling convertibility into it. + + // First, make sure both sub type and super type decl are ready for lookup. + if (!(int(isSubTypeOptions) & int(IsSubTypeOptions::NoCaching))) + { + if (auto subDeclRefType = as(subType)) { - return false; + ensureDecl(subDeclRefType->getDeclRef().getDecl(), DeclCheckState::ReadyForLookup); } } - - SubtypeWitness* SemanticsVisitor::isSubtype( - Type* subType, - Type* superType, - IsSubTypeOptions isSubTypeOptions - ) + if (auto superDeclRefType = as(superType)) { - SubtypeWitness* result = nullptr; - if (getShared()->tryGetSubtypeWitnessFromCache(subType, superType, result)) - return result; - result = checkAndConstructSubtypeWitness(subType, superType, isSubTypeOptions); - - if(!result && (int(isSubTypeOptions) & int(IsSubTypeOptions::NoCaching))) - return result; - - getShared()->cacheSubtypeWitness(subType, superType, result); - return result; + ensureDecl(superDeclRefType->getDeclRef().getDecl(), DeclCheckState::ReadyForLookup); } - SubtypeWitness* SemanticsVisitor::checkAndConstructSubtypeWitness( - Type* subType, - Type* superType, - IsSubTypeOptions isSubTypeOptions) + // In the common case, we can use the pre-computed inheritance information for `subType` + // to enumerate all the types it transitively inherits from. + // + auto inheritanceInfo = getShared()->getInheritanceInfo(subType); + for (auto facet : inheritanceInfo.facets) { - // TODO: The Slang codebase is currently being quite slippery by conflating - // multiple concepts, all under the banner of a "subtype" test: - // - // * Struct/class inheritance: When concrete type `A` inherits from concrete - // type `B`, we can directly convert any value of type `A` into a value of type `B` - // - // * Derived interfaces: When interface `X` derives from interface `Y`, we know - // that any concrete type conforming to `X` must also conform to `Y`, so we can - // derive a witness that `A : Y` from a witness tbale that `A : X` for some concrete `A` - // - // * Conformance: When concrete type `A` conforms to interface `X`, we know that there exists - // a witness table for that conformance. + // The `subType` will have a `facet` for each type + // that it transitively inherits from, as well as + // for each `extension` that was found to apply to it. // - // The problem is that these relationships mean different things. If we use the same - // `isSubtype()` test for all of the above cases, then we risk determining that `IFoo` - // *conforms* to `IBar` just because it was declared as `interface IFoo : IBar`. Or - // even more simply that `IFoo` conforms to `IFoo`. + // For subtype testing, we are only interested in + // the facets that represent supertypes, and those + // will be the ones that store a type on the facet. // - // It is dangerous to start treating an interface type like it conforms to itself: - // - // interface IFoo { static int getValue(); } - // int get< T : IFoo >() { return T.getValue(); } - // - // int x = get(); // This needs to be an error!!! - // - // We will eventually need to clarify the distinction between the different kinds of - // subtype-ish relationships, *or* we will need to ensure that `interface`s are not - // treated as proper types (such that they can be passed as generic arguments, etc.) - // - // Note that there is one more case of a subtype-ish relationship that is not covered - // by this function, but that is relevant if/when we do more serious type inference: - // - // * Convertibility: When any value of type `A` can be converted to a value of type - // `B` (even if that conversion might involve computation or a change of representation), - // and that conversion is one that the compiler considers "okay" to do implicitly. - // - // For now we are continuing to conflate all the subtype-ish relationships but not - // tangling convertibility into it. + auto facetType = facet->getType(); + if (!facetType) + continue; - // First, make sure both sub type and super type decl are ready for lookup. - if ( !(int(isSubTypeOptions) & int(IsSubTypeOptions::NoCaching)) ) - { - if (auto subDeclRefType = as(subType)) - { - ensureDecl(subDeclRefType->getDeclRef().getDecl(), DeclCheckState::ReadyForLookup); - } - } - if (auto superDeclRefType = as(superType)) - { - ensureDecl(superDeclRefType->getDeclRef().getDecl(), DeclCheckState::ReadyForLookup); - } + // We will scan until we find a facet that corresponds + // to `superType`, or fail to find such a facet. + // + if (!facetType->equals(superType)) + continue; - // In the common case, we can use the pre-computed inheritance information for `subType` - // to enumerate all the types it transitively inherits from. + // If the `superType` appears in the flattened inheritance list + // for the `subType`, then we know that the subtype relationship + // holds. Conveniently, the `facet` stores a pre-computed witness + // for the subtype relationship, which we can return here. // - auto inheritanceInfo = getShared()->getInheritanceInfo(subType); - for (auto facet : inheritanceInfo.facets) - { - // The `subType` will have a `facet` for each type - // that it transitively inherits from, as well as - // for each `extension` that was found to apply to it. - // - // For subtype testing, we are only interested in - // the facets that represent supertypes, and those - // will be the ones that store a type on the facet. - // - auto facetType = facet->getType(); - if (!facetType) - continue; + return facet->subtypeWitness; + } + // + // TODO: We could expand upon the test using the facet list above + // by taking the facet lists of both `subType` and `superType` + // and then checking if all of the facets that appear in `superType`'s + // linearization also appear in the linearization for `subType` + // (and occur in the same order). + // + // That test could potentially handle certain cases of interface + // conjunctions that the simpler algorithm above can't, but it wouldn't + // seem to be a complete algorithm unless we ensured that interfaces + // have a canonical sorting order for how they appear in linearizations. + // + // One of the main reasons why we don't implement such a test right now + // is that it isn't obvious how to directly produce a witness value + // as collateral from the test. - // We will scan until we find a facet that corresponds - // to `superType`, or fail to find such a facet. - // - if (!facetType->equals(superType)) - continue; + // We expect the logic above to cover the vast majority of subtype + // tests, but there are a few remaining cases of subtype testing + // that cannot be folded into the type linearizations above. + // + // A few of these cases case if the `superType` is a `DeclRefType` + // and, if so, want to compare its `DeclRef` against others. As + // such, we will extract the `DeclRef` here, if it exists, + // as a convienience. + // + DeclRef superTypeDeclRef; + if (auto superDeclRefType = as(superType)) + { + superTypeDeclRef = superDeclRefType->getDeclRef(); + } - // If the `superType` appears in the flattened inheritance list - // for the `subType`, then we know that the subtype relationship - // holds. Conveniently, the `facet` stores a pre-computed witness - // for the subtype relationship, which we can return here. - // - return facet->subtypeWitness; - } + if (as(subType)) + { + // A __Dynamic type always conforms to the interface via its witness table. + auto witness = m_astBuilder->getOrCreate(subType, superType); + return witness; + } + else if (auto conjunctionSuperType = as(superType)) + { + // We know that `T <: L & R` if `T <: L` and `T <: R`. // - // TODO: We could expand upon the test using the facet list above - // by taking the facet lists of both `subType` and `superType` - // and then checking if all of the facets that appear in `superType`'s - // linearization also appear in the linearization for `subType` - // (and occur in the same order). + // We therefore simply recursively test both `T <: L` + // and `T <: R`. // - // That test could potentially handle certain cases of interface - // conjunctions that the simpler algorithm above can't, but it wouldn't - // seem to be a complete algorithm unless we ensured that interfaces - // have a canonical sorting order for how they appear in linearizations. + auto leftWitness = + isSubtype(subType, conjunctionSuperType->getLeft(), IsSubTypeOptions::None); + if (!leftWitness) + return nullptr; // - // One of the main reasons why we don't implement such a test right now - // is that it isn't obvious how to directly produce a witness value - // as collateral from the test. + auto rightWitness = + isSubtype(subType, conjunctionSuperType->getRight(), IsSubTypeOptions::None); + if (!rightWitness) + return nullptr; - // We expect the logic above to cover the vast majority of subtype - // tests, but there are a few remaining cases of subtype testing - // that cannot be folded into the type linearizations above. + // If both of the sub-relationships hold, we can construct + // a conjunction of those witnesses to witness `T <: L&R` // - // A few of these cases case if the `superType` is a `DeclRefType` - // and, if so, want to compare its `DeclRef` against others. As - // such, we will extract the `DeclRef` here, if it exists, - // as a convienience. + return m_astBuilder->getConjunctionSubtypeWitness( + subType, + conjunctionSuperType, + leftWitness, + rightWitness); + } + else if (auto extractExistentialType = as(subType)) + { + // An ExtractExistentialType from an existential value of type I + // is a subtype of I. + // We need to check and make sure the interface type of the `ExtractExistentialType` + // is equal to `superType`. // - DeclRef superTypeDeclRef; - if (auto superDeclRefType = as(superType)) - { - superTypeDeclRef = superDeclRefType->getDeclRef(); - } - - if (as(subType)) + // TODO(tfoley): We could add support for `ExtractExistentialType` to + // the inheritance linearization logic, and eliminate this case. + // + auto interfaceDeclRef = extractExistentialType->getOriginalInterfaceDeclRef(); + if (interfaceDeclRef.equals(superTypeDeclRef)) { - // A __Dynamic type always conforms to the interface via its witness table. - auto witness = m_astBuilder->getOrCreate(subType, superType); + auto witness = extractExistentialType->getSubtypeWitness(); return witness; } - else if (auto conjunctionSuperType = as(superType)) - { - // We know that `T <: L & R` if `T <: L` and `T <: R`. - // - // We therefore simply recursively test both `T <: L` - // and `T <: R`. - // - auto leftWitness = isSubtype(subType, conjunctionSuperType->getLeft(), IsSubTypeOptions::None); - if (!leftWitness) return nullptr; - // - auto rightWitness = isSubtype(subType, conjunctionSuperType->getRight(), IsSubTypeOptions::None); - if (!rightWitness) return nullptr; - - // If both of the sub-relationships hold, we can construct - // a conjunction of those witnesses to witness `T <: L&R` - // - return m_astBuilder->getConjunctionSubtypeWitness( - subType, - conjunctionSuperType, - leftWitness, - rightWitness); - } - else if (auto extractExistentialType = as(subType)) - { - // An ExtractExistentialType from an existential value of type I - // is a subtype of I. - // We need to check and make sure the interface type of the `ExtractExistentialType` - // is equal to `superType`. - // - // TODO(tfoley): We could add support for `ExtractExistentialType` to - // the inheritance linearization logic, and eliminate this case. - // - auto interfaceDeclRef = extractExistentialType->getOriginalInterfaceDeclRef(); - if (interfaceDeclRef.equals(superTypeDeclRef)) - { - auto witness = extractExistentialType->getSubtypeWitness(); - return witness; - } - return nullptr; - } - else if (auto subTypePack = as(subType)) - { - // A type pack (T0, T1, ...) is a subtype of supType, if each of its elements - // is a subtype of the supType. - ShortList elementWitnesses; - for (Index i = 0; i < subTypePack->getTypeCount(); i++) - { - auto elementWitness = isSubtype(subTypePack->getElementType(i), superType, IsSubTypeOptions::None); - if (!elementWitness) - return nullptr; - elementWitnesses.add(elementWitness); - } - return m_astBuilder->getSubtypeWitnessPack(subType, superType, elementWitnesses.getArrayView().arrayView); - } - else if (auto expandType = as(subType)) - { - // A expand type `expand patternType, captureList` is a subtype of supType, if patternType is a subtype of supType. - auto elementWitness = isSubtype(expandType->getPatternType(), superType, IsSubTypeOptions::None); - if (!elementWitness) - return nullptr; - return m_astBuilder->getExpandSubtypeWitness(subType, superType, elementWitness); - } - else if (auto eachType = as(subType)) + return nullptr; + } + else if (auto subTypePack = as(subType)) + { + // A type pack (T0, T1, ...) is a subtype of supType, if each of its elements + // is a subtype of the supType. + ShortList elementWitnesses; + for (Index i = 0; i < subTypePack->getTypeCount(); i++) { - auto elementWitness = isSubtype(eachType->getElementType(), superType, IsSubTypeOptions::None); + auto elementWitness = + isSubtype(subTypePack->getElementType(i), superType, IsSubTypeOptions::None); if (!elementWitness) return nullptr; - return m_astBuilder->getEachSubtypeWitness(subType, superType, elementWitness); + elementWitnesses.add(elementWitness); } - // default is failure - return nullptr; + return m_astBuilder->getSubtypeWitnessPack( + subType, + superType, + elementWitnesses.getArrayView().arrayView); } - - bool SemanticsVisitor::isValidGenericConstraintType(Type* type) + else if (auto expandType = as(subType)) { - if (auto andType = as(type)) - { - return isValidGenericConstraintType(andType->getLeft()) && isValidGenericConstraintType(andType->getRight()); - } - return isInterfaceType(type); + // A expand type `expand patternType, captureList` is a subtype of supType, if patternType + // is a subtype of supType. + auto elementWitness = + isSubtype(expandType->getPatternType(), superType, IsSubTypeOptions::None); + if (!elementWitness) + return nullptr; + return m_astBuilder->getExpandSubtypeWitness(subType, superType, elementWitness); } + else if (auto eachType = as(subType)) + { + auto elementWitness = + isSubtype(eachType->getElementType(), superType, IsSubTypeOptions::None); + if (!elementWitness) + return nullptr; + return m_astBuilder->getEachSubtypeWitness(subType, superType, elementWitness); + } + // default is failure + return nullptr; +} - SubtypeWitness* SemanticsVisitor::isTypeDifferentiable(Type* type) +bool SemanticsVisitor::isValidGenericConstraintType(Type* type) +{ + if (auto andType = as(type)) { - if (auto valueWitness = isSubtype(type, m_astBuilder->getDiffInterfaceType(), IsSubTypeOptions::None)) - return valueWitness; - else if (auto ptrWitness = isSubtype(type, m_astBuilder->getDifferentiableRefInterfaceType(), IsSubTypeOptions::None)) - return ptrWitness; - - return nullptr; + return isValidGenericConstraintType(andType->getLeft()) && + isValidGenericConstraintType(andType->getRight()); } + return isInterfaceType(type); +} - bool SemanticsVisitor::doesTypeHaveTag(Type* type, TypeTag tag) +SubtypeWitness* SemanticsVisitor::isTypeDifferentiable(Type* type) +{ + if (auto valueWitness = + isSubtype(type, m_astBuilder->getDiffInterfaceType(), IsSubTypeOptions::None)) + return valueWitness; + else if ( + auto ptrWitness = isSubtype( + type, + m_astBuilder->getDifferentiableRefInterfaceType(), + IsSubTypeOptions::None)) + return ptrWitness; + + return nullptr; +} + +bool SemanticsVisitor::doesTypeHaveTag(Type* type, TypeTag tag) +{ + if (auto arrayType = as(type)) { - if (auto arrayType = as(type)) - { - return doesTypeHaveTag(arrayType->getElementType(), tag); - } - if (auto modifiedType = as(type)) - { - return doesTypeHaveTag(modifiedType->getBase(), tag); - } - if (auto declRefType = as(type)) - { - if (auto aggTypeDecl = as(declRefType->getDeclRef())) - return aggTypeDecl.getDecl()->hasTag(tag); - } - return false; + return doesTypeHaveTag(arrayType->getElementType(), tag); + } + if (auto modifiedType = as(type)) + { + return doesTypeHaveTag(modifiedType->getBase(), tag); } + if (auto declRefType = as(type)) + { + if (auto aggTypeDecl = as(declRefType->getDeclRef())) + return aggTypeDecl.getDecl()->hasTag(tag); + } + return false; +} - TypeTag SemanticsVisitor::getTypeTags(Type* type) +TypeTag SemanticsVisitor::getTypeTags(Type* type) +{ + if (auto arrayType = as(type)) { - if (auto arrayType = as(type)) + auto typeTag = getTypeTags(arrayType->getElementType()); + bool sized = false; + if (auto cint = as(arrayType->getElementCount())) { - auto typeTag = getTypeTags(arrayType->getElementType()); - bool sized = false; - if (auto cint = as(arrayType->getElementCount())) - { - if (cint->getValue() != kUnsizedArrayMagicLength) - { - sized = true; - } - } - else if (arrayType->getElementCount()) + if (cint->getValue() != kUnsizedArrayMagicLength) { sized = true; - typeTag = (TypeTag)((int)typeTag | (int)TypeTag::LinkTimeSized); } - if (!sized) - typeTag = (TypeTag)((int)typeTag | (int)TypeTag::Unsized); - - return typeTag; - } - if (auto modifiedType = as(type)) - { - return getTypeTags(modifiedType->getBase()); - } - if (auto parameterGroupType = as(type)) - { - auto elementTags = getTypeTags(parameterGroupType->getElementType()); - elementTags = (TypeTag)((int)elementTags & ~(int)TypeTag::Unsized); - return elementTags; } - else if (auto declRefType = as(type)) + else if (arrayType->getElementCount()) { - if (auto aggTypeDecl = as(declRefType->getDeclRef())) - return aggTypeDecl.getDecl()->typeTags; + sized = true; + typeTag = (TypeTag)((int)typeTag | (int)TypeTag::LinkTimeSized); } - return TypeTag::None; - } + if (!sized) + typeTag = (TypeTag)((int)typeTag | (int)TypeTag::Unsized); - - Type* SemanticsVisitor::getConstantBufferElementType(Type* type) + return typeTag; + } + if (auto modifiedType = as(type)) { - if (auto arrType = as(type)) - return getConstantBufferElementType(arrType->getElementType()); - if (auto modifiedType = as(type)) - return getConstantBufferElementType(modifiedType->getBase()); - if (auto constantBuffer = as(type)) - return constantBuffer->getElementType(); - if (auto parameterBlock = as(type)) - return parameterBlock->getElementType(); - return nullptr; + return getTypeTags(modifiedType->getBase()); } - - - SubtypeWitness* SemanticsVisitor::tryGetInterfaceConformanceWitness( - Type* type, - Type* interfaceType) + if (auto parameterGroupType = as(type)) { - return isSubtype(type, interfaceType, IsSubTypeOptions::None); + auto elementTags = getTypeTags(parameterGroupType->getElementType()); + elementTags = (TypeTag)((int)elementTags & ~(int)TypeTag::Unsized); + return elementTags; } - - TypeEqualityWitness* SemanticsVisitor::createTypeEqualityWitness( - Type* type) + else if (auto declRefType = as(type)) { - return m_astBuilder->getTypeEqualityWitness(type); + if (auto aggTypeDecl = as(declRefType->getDeclRef())) + return aggTypeDecl.getDecl()->typeTags; } + return TypeTag::None; +} + + +Type* SemanticsVisitor::getConstantBufferElementType(Type* type) +{ + if (auto arrType = as(type)) + return getConstantBufferElementType(arrType->getElementType()); + if (auto modifiedType = as(type)) + return getConstantBufferElementType(modifiedType->getBase()); + if (auto constantBuffer = as(type)) + return constantBuffer->getElementType(); + if (auto parameterBlock = as(type)) + return parameterBlock->getElementType(); + return nullptr; +} + + +SubtypeWitness* SemanticsVisitor::tryGetInterfaceConformanceWitness(Type* type, Type* interfaceType) +{ + return isSubtype(type, interfaceType, IsSubTypeOptions::None); +} + +TypeEqualityWitness* SemanticsVisitor::createTypeEqualityWitness(Type* type) +{ + return m_astBuilder->getTypeEqualityWitness(type); } +} // namespace Slang diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index b21365338..13e3e4be6 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -56,1233 +56,1327 @@ namespace Slang { - Type* SemanticsVisitor::TryJoinVectorAndScalarType( - ConstraintSystem* constraints, - VectorExpressionType* vectorType, - BasicExpressionType* scalarType) +Type* SemanticsVisitor::TryJoinVectorAndScalarType( + ConstraintSystem* constraints, + VectorExpressionType* vectorType, + BasicExpressionType* scalarType) +{ + // Join( vector, S ) -> vetor + // + // That is, the join of a vector and a scalar type is + // a vector type with a joined element type. + auto joinElementType = TryJoinTypes(constraints, vectorType->getElementType(), scalarType); + if (!joinElementType) + return nullptr; + + return createVectorType(joinElementType, vectorType->getElementCount()); +} + +Type* SemanticsVisitor::_tryJoinTypeWithInterface( + ConstraintSystem* constraints, + Type* type, + Type* interfaceType) +{ + // The most basic test here should be: does the type declare conformance to the trait. + + if (constraints->subTypeForAdditionalWitnesses == type) { - // Join( vector, S ) -> vetor - // - // That is, the join of a vector and a scalar type is - // a vector type with a joined element type. - auto joinElementType = TryJoinTypes( - constraints, - vectorType->getElementType(), - scalarType); - if(!joinElementType) - return nullptr; - - return createVectorType( - joinElementType, - vectorType->getElementCount()); + // If additional subtype witnesses are provided for `type` in `constraints`, + // try to use them to see if the interface is satisfied. + if (constraints->additionalSubtypeWitnesses->containsKey(interfaceType)) + return type; } - - Type* SemanticsVisitor::_tryJoinTypeWithInterface( - ConstraintSystem* constraints, - Type* type, - Type* interfaceType) + else { - // The most basic test here should be: does the type declare conformance to the trait. - - if (constraints->subTypeForAdditionalWitnesses == type) - { - // If additional subtype witnesses are provided for `type` in `constraints`, - // try to use them to see if the interface is satisfied. - if (constraints->additionalSubtypeWitnesses->containsKey(interfaceType)) - return type; - } - else - { - if (isSubtype( + if (isSubtype( type, interfaceType, - constraints->additionalSubtypeWitnesses ? IsSubTypeOptions::NoCaching : IsSubTypeOptions::None)) - return type; - } + constraints->additionalSubtypeWitnesses ? IsSubTypeOptions::NoCaching + : IsSubTypeOptions::None)) + return type; + } - // Just because `type` doesn't conform to the given `interfaceDeclRef`, that - // doesn't necessarily indicate a failure. It is possible that we have a call - // like `sqrt(2)` so that `type` is `int` and `interfaceDeclRef` is - // `__BuiltinFloatingPointType`. The "obvious" answer is that we should infer - // the type `float`, but it seems like the compiler would have to synthesize - // that answer from thin air. - // - // A robsut/correct solution here might be to enumerate set of types types `S` - // such that for each type `X` in `S`: - // - // * `type` is implicitly convertible to `X` - // * `X` conforms to the interface named by `interfaceDeclRef` - // - // If the set `S` is non-empty then we would try to pick the "best" type from `S`. - // The "best" type would be a type `Y` such that `Y` is implicitly convertible to - // every other type in `S`. - // - // We are going to implement a much simpler strategy for now, where we only apply - // the search process if `type` is a builtin scalar type, and then we only search - // through types `X` that are also builtin scalar types. - // - Type* bestType = nullptr; - ConversionCost bestCost = kConversionCost_Explicit; - if(auto basicType = dynamicCast(type)) + // Just because `type` doesn't conform to the given `interfaceDeclRef`, that + // doesn't necessarily indicate a failure. It is possible that we have a call + // like `sqrt(2)` so that `type` is `int` and `interfaceDeclRef` is + // `__BuiltinFloatingPointType`. The "obvious" answer is that we should infer + // the type `float`, but it seems like the compiler would have to synthesize + // that answer from thin air. + // + // A robsut/correct solution here might be to enumerate set of types types `S` + // such that for each type `X` in `S`: + // + // * `type` is implicitly convertible to `X` + // * `X` conforms to the interface named by `interfaceDeclRef` + // + // If the set `S` is non-empty then we would try to pick the "best" type from `S`. + // The "best" type would be a type `Y` such that `Y` is implicitly convertible to + // every other type in `S`. + // + // We are going to implement a much simpler strategy for now, where we only apply + // the search process if `type` is a builtin scalar type, and then we only search + // through types `X` that are also builtin scalar types. + // + Type* bestType = nullptr; + ConversionCost bestCost = kConversionCost_Explicit; + if (auto basicType = dynamicCast(type)) + { + for (Int baseTypeFlavorIndex = 0; baseTypeFlavorIndex < Int(BaseType::CountOf); + baseTypeFlavorIndex++) { - for(Int baseTypeFlavorIndex = 0; baseTypeFlavorIndex < Int(BaseType::CountOf); baseTypeFlavorIndex++) + // Don't consider `type`, since we already know it doesn't work. + if (baseTypeFlavorIndex == Int(basicType->getBaseType())) + continue; + + // Look up the type in our session. + auto candidateType = + getCurrentASTBuilder()->getBuiltinType(BaseType(baseTypeFlavorIndex)); + if (!candidateType) + continue; + + // We only want to consider types that implement the target interface. + if (!isSubtype(candidateType, interfaceType, IsSubTypeOptions::None)) + continue; + + // We only want to consider types where we can implicitly convert from `type` + auto conversionCost = getConversionCost(candidateType, type); + if (!canConvertImplicitly(conversionCost)) + continue; + + // At this point, we have a candidate type that is usable. + // + // If this is our first viable candidate, then it is our best one: + // + if (!bestType) { - // Don't consider `type`, since we already know it doesn't work. - if(baseTypeFlavorIndex == Int(basicType->getBaseType())) - continue; - - // Look up the type in our session. - auto candidateType = getCurrentASTBuilder()->getBuiltinType(BaseType(baseTypeFlavorIndex)); - if(!candidateType) - continue; - - // We only want to consider types that implement the target interface. - if(!isSubtype(candidateType, interfaceType, IsSubTypeOptions::None)) - continue; - - // We only want to consider types where we can implicitly convert from `type` - auto conversionCost = getConversionCost(candidateType, type); - if(!canConvertImplicitly(conversionCost)) - continue; - - // At this point, we have a candidate type that is usable. + bestType = candidateType; + } + else + { + // Otherwise, we want to pick the "better" type between `candidateType` + // and `bestType`. // - // If this is our first viable candidate, then it is our best one: + // The candidate type that has lower conversion cost from `type` is better. // - if(!bestType) - { - bestType = candidateType; - } - else + if (conversionCost < bestCost) { - // Otherwise, we want to pick the "better" type between `candidateType` - // and `bestType`. + // Our candidate can convert to the current "best" type, so + // it is logically a more specific type that satisfies our + // constraints, therefore we should keep it. // - // The candidate type that has lower conversion cost from `type` is better. - // - if(conversionCost < bestCost) - { - // Our candidate can convert to the current "best" type, so - // it is logically a more specific type that satisfies our - // constraints, therefore we should keep it. - // - bestType = candidateType; - bestCost = conversionCost; - } + bestType = candidateType; + bestCost = conversionCost; } } - if(bestType) - return bestType; } + if (bestType) + return bestType; + } - // If `interfaceType` represents some generic interface type, such as `IFoo`, and `type` conforms to - // some `IFoo`, then we should attempt to unify the them to discover constraints for - // `T`. - if (auto interfaceDeclRef = isDeclRefTypeOf(interfaceType)) + // If `interfaceType` represents some generic interface type, such as `IFoo`, and `type` + // conforms to some `IFoo`, then we should attempt to unify the them to discover constraints + // for `T`. + if (auto interfaceDeclRef = isDeclRefTypeOf(interfaceType)) + { + if (as(interfaceDeclRef.declRefBase)) { - if (as(interfaceDeclRef.declRefBase)) + auto inheritanceInfo = getShared()->getInheritanceInfo(type); + for (auto facet : inheritanceInfo.facets) { - auto inheritanceInfo = getShared()->getInheritanceInfo(type); - for (auto facet : inheritanceInfo.facets) + if (facet->origin.declRef.getDecl() == interfaceDeclRef.getDecl()) { - if (facet->origin.declRef.getDecl() == interfaceDeclRef.getDecl()) - { - auto unificationResult = TryUnifyTypes( - *constraints, - ValUnificationContext(), - QualType(facet->getType()), - interfaceType); - - if (unificationResult) - return type; - } + auto unificationResult = TryUnifyTypes( + *constraints, + ValUnificationContext(), + QualType(facet->getType()), + interfaceType); + + if (unificationResult) + return type; } - if (constraints->subTypeForAdditionalWitnesses) + } + if (constraints->subTypeForAdditionalWitnesses) + { + for (auto witnessKV : *constraints->additionalSubtypeWitnesses) { - for (auto witnessKV : *constraints->additionalSubtypeWitnesses) - { - auto unificationResult = TryUnifyTypes(*constraints, ValUnificationContext(), QualType(witnessKV.first), interfaceType); - if (unificationResult) - return type; - } + auto unificationResult = TryUnifyTypes( + *constraints, + ValUnificationContext(), + QualType(witnessKV.first), + interfaceType); + if (unificationResult) + return type; } } } + } - // For all other cases, we will just bail out for now. - // - // TODO: In the future we should build some kind of side data structure - // to accelerate either one or both of these queries: - // - // * Given a type `T`, what types `U` can it convert to implicitly? - // - // * Given an interface `I`, what types `U` conform to it? - // - // The intersection of the sets returned by these two queries is - // the set of candidates we would like to consider here. + // For all other cases, we will just bail out for now. + // + // TODO: In the future we should build some kind of side data structure + // to accelerate either one or both of these queries: + // + // * Given a type `T`, what types `U` can it convert to implicitly? + // + // * Given an interface `I`, what types `U` conform to it? + // + // The intersection of the sets returned by these two queries is + // the set of candidates we would like to consider here. + + return nullptr; +} - return nullptr; - } +Type* SemanticsVisitor::TryJoinTypes(ConstraintSystem* constraints, QualType left, QualType right) +{ + // Easy case: they are the same type! + if (left->equals(right)) + return left; - Type* SemanticsVisitor::TryJoinTypes( - ConstraintSystem* constraints, - QualType left, - QualType right) + // We can join two basic types by picking the "better" of the two + if (auto leftBasic = as(left)) { - // Easy case: they are the same type! - if (left->equals(right)) - return left; - - // We can join two basic types by picking the "better" of the two - if (auto leftBasic = as(left)) + if (auto rightBasic = as(right)) { - if (auto rightBasic = as(right)) - { - auto costConvertRightToLeft = getConversionCost(leftBasic, right); - auto costConvertLeftToRight = getConversionCost(rightBasic, left); + auto costConvertRightToLeft = getConversionCost(leftBasic, right); + auto costConvertLeftToRight = getConversionCost(rightBasic, left); - // Return the one that had lower conversion cost. - if (costConvertRightToLeft > costConvertLeftToRight) - return right; - else - { - return left; - } - } - - // We can also join a vector and a scalar - if(auto rightVector = as(right)) + // Return the one that had lower conversion cost. + if (costConvertRightToLeft > costConvertLeftToRight) + return right; + else { - return TryJoinVectorAndScalarType(constraints, rightVector, leftBasic); + return left; } } - // We can join two vector types by joining their element types - // (and also their sizes...) - if( auto leftVector = as(left)) + // We can also join a vector and a scalar + if (auto rightVector = as(right)) { - if(auto rightVector = as(right)) - { - // Check if the vector sizes match - if(!leftVector->getElementCount()->equals(rightVector->getElementCount())) - return nullptr; - - // Try to join the element types - auto joinElementType = TryJoinTypes( - constraints, - QualType(leftVector->getElementType(), left.isLeftValue), - QualType(rightVector->getElementType(), right.isLeftValue)); - if(!joinElementType) - return nullptr; + return TryJoinVectorAndScalarType(constraints, rightVector, leftBasic); + } + } - return createVectorType( - joinElementType, - leftVector->getElementCount()); - } + // We can join two vector types by joining their element types + // (and also their sizes...) + if (auto leftVector = as(left)) + { + if (auto rightVector = as(right)) + { + // Check if the vector sizes match + if (!leftVector->getElementCount()->equals(rightVector->getElementCount())) + return nullptr; + + // Try to join the element types + auto joinElementType = TryJoinTypes( + constraints, + QualType(leftVector->getElementType(), left.isLeftValue), + QualType(rightVector->getElementType(), right.isLeftValue)); + if (!joinElementType) + return nullptr; + + return createVectorType(joinElementType, leftVector->getElementCount()); + } - // We can also join a vector and a scalar - if(auto rightBasic = as(right)) - { - return TryJoinVectorAndScalarType(constraints, leftVector, rightBasic); - } + // We can also join a vector and a scalar + if (auto rightBasic = as(right)) + { + return TryJoinVectorAndScalarType(constraints, leftVector, rightBasic); } + } - // HACK: trying to work trait types in here... - if(auto leftDeclRefType = as(left)) + // HACK: trying to work trait types in here... + if (auto leftDeclRefType = as(left)) + { + if (auto leftInterfaceRef = leftDeclRefType->getDeclRef().as()) { - if( auto leftInterfaceRef = leftDeclRefType->getDeclRef().as() ) - { - // - return _tryJoinTypeWithInterface(constraints, right, left); - } + // + return _tryJoinTypeWithInterface(constraints, right, left); } - if(auto rightDeclRefType = as(right)) + } + if (auto rightDeclRefType = as(right)) + { + if (auto rightInterfaceRef = rightDeclRefType->getDeclRef().as()) { - if( auto rightInterfaceRef = rightDeclRefType->getDeclRef().as() ) - { - // - return _tryJoinTypeWithInterface(constraints, left, right); - } + // + return _tryJoinTypeWithInterface(constraints, left, right); } + } - // We can recursively join two TypePacks. - if (auto leftTypePack = as(left)) + // We can recursively join two TypePacks. + if (auto leftTypePack = as(left)) + { + if (auto rightTypePack = as(right)) { - if (auto rightTypePack = as(right)) + if (leftTypePack->getTypeCount() != rightTypePack->getTypeCount()) + return nullptr; + ShortList joinedTypes; + for (Index i = 0; i < leftTypePack->getTypeCount(); ++i) { - if(leftTypePack->getTypeCount() != rightTypePack->getTypeCount()) + auto joinedType = TryJoinTypes( + constraints, + QualType(leftTypePack->getElementType(i), left.isLeftValue), + QualType(rightTypePack->getElementType(i), right.isLeftValue)); + if (!joinedType) return nullptr; - ShortList joinedTypes; - for (Index i = 0; i < leftTypePack->getTypeCount(); ++i) - { - auto joinedType = TryJoinTypes( - constraints, - QualType(leftTypePack->getElementType(i), left.isLeftValue), - QualType(rightTypePack->getElementType(i), right.isLeftValue)); - if(!joinedType) - return nullptr; - joinedTypes.add(joinedType); - } - return m_astBuilder->getTypePack(joinedTypes.getArrayView().arrayView); + joinedTypes.add(joinedType); } + return m_astBuilder->getTypePack(joinedTypes.getArrayView().arrayView); } - - // TODO: all the cases for vectors apply to matrices too! - - // Default case is that we just fail. - return nullptr; } - DeclRef SemanticsVisitor::trySolveConstraintSystem( - ConstraintSystem* system, - DeclRef genericDeclRef, - ArrayView knownGenericArgs, - ConversionCost& outBaseCost) - { - ensureDecl(genericDeclRef.getDecl(), DeclCheckState::ReadyForLookup); + // TODO: all the cases for vectors apply to matrices too! - outBaseCost = kConversionCost_None; + // Default case is that we just fail. + return nullptr; +} - // For now the "solver" is going to be ridiculously simplistic. +DeclRef SemanticsVisitor::trySolveConstraintSystem( + ConstraintSystem* system, + DeclRef genericDeclRef, + ArrayView knownGenericArgs, + ConversionCost& outBaseCost) +{ + ensureDecl(genericDeclRef.getDecl(), DeclCheckState::ReadyForLookup); + + outBaseCost = kConversionCost_None; + + // For now the "solver" is going to be ridiculously simplistic. + + // The generic itself will have some constraints, and for now we add these + // to the system of constrains we will use for solving for the type variables. + // + // TODO: we need to decide whether constraints are used like this to influence + // how we solve for type/value variables, or whether constraints in the parameter + // list just work as a validation step *after* we've solved for the types. + // + // That is, should we allow `` to be written, and cause us to "infer" + // that `T` should be the type `Int`? That seems a little silly. + // + // Eventually, though, we may want to support type identity constraints, especially + // on associated types, like `` + // These seem more reasonable to have influence constraint solving, since it could + // conceivably let us specialize a `X : IContainer` to `X` if we find + // that `X.IndexType == T`. + for (auto constraintDeclRef : + getMembersOfType(m_astBuilder, genericDeclRef)) + { + if (!TryUnifyTypes( + *system, + ValUnificationContext(), + getSub(m_astBuilder, constraintDeclRef), + getSup(m_astBuilder, constraintDeclRef))) + return DeclRef(); + } - // The generic itself will have some constraints, and for now we add these - // to the system of constrains we will use for solving for the type variables. - // - // TODO: we need to decide whether constraints are used like this to influence - // how we solve for type/value variables, or whether constraints in the parameter - // list just work as a validation step *after* we've solved for the types. - // - // That is, should we allow `` to be written, and cause us to "infer" - // that `T` should be the type `Int`? That seems a little silly. - // - // Eventually, though, we may want to support type identity constraints, especially - // on associated types, like `` - // These seem more reasonable to have influence constraint solving, since it could - // conceivably let us specialize a `X : IContainer` to `X` if we find - // that `X.IndexType == T`. - for( auto constraintDeclRef : getMembersOfType(m_astBuilder, genericDeclRef) ) + // Once have built up the initial list of constraints we are trying to satisfy, + // we will attempt to solve for each parameter in a way that satisfies all + // the constraints that apply to that parameter. + // + // Note: this is a very limited kind of solver, in that it doesn't have a + // way to make use of constraints between two or more parameters. + // + // As we go, we will build up a list of argument values for a possible + // solution for how to assign the parameters in a way that satisfies all + // the constraints. + // + ShortList args; + + // If the context is such that some of the arguments are already specified + // or known, we need to go ahead and use those arguments direclty (whether + // or not they are compatible with the constraints). + // + Count knownGenericArgCount = 0; + if (knownGenericArgs.getCount()) + { + knownGenericArgCount = knownGenericArgs.getCount(); + for (auto arg : knownGenericArgs) { - if(!TryUnifyTypes(*system, ValUnificationContext(), getSub(m_astBuilder, constraintDeclRef), getSup(m_astBuilder, constraintDeclRef))) - return DeclRef(); + args.add(arg); } + } - // Once have built up the initial list of constraints we are trying to satisfy, - // we will attempt to solve for each parameter in a way that satisfies all - // the constraints that apply to that parameter. - // - // Note: this is a very limited kind of solver, in that it doesn't have a - // way to make use of constraints between two or more parameters. - // - // As we go, we will build up a list of argument values for a possible - // solution for how to assign the parameters in a way that satisfies all - // the constraints. - // - ShortList args; - - // If the context is such that some of the arguments are already specified - // or known, we need to go ahead and use those arguments direclty (whether - // or not they are compatible with the constraints). + // The state of currently solved arguments. + struct SolvedArg + { + IntVal* val = nullptr; + bool isOptional = true; + ShortList types; + }; + ShortList solvedArgs; + + // We will then iterate over the constraints trying to solve all generic parameters. + // Note that we do not use ranged for here, because processing one constraint may lead to + // new constraints being discovered. + for (Index constraintIndex = 0; constraintIndex < system->constraints.getCount(); + constraintIndex++) + { + // Note: it is important to keep a copy of the constraint here instead of + // using a reference, because the constraint list may be modified during the + // loop as we discover new constraints. // - Count knownGenericArgCount = 0; - if (knownGenericArgs.getCount()) + auto c = system->constraints[constraintIndex]; + if (auto typeParam = as(c.decl)) { - knownGenericArgCount = knownGenericArgs.getCount(); - for (auto arg : knownGenericArgs) + SLANG_ASSERT(typeParam->parameterIndex != -1); + // If the parameter is one where we already know + // the argument value to use, we don't bother with + // trying to solve for it, and treat any constraints + // on such a parameter as implicitly solved-for. + // + if (typeParam->parameterIndex < knownGenericArgCount) { - args.add(arg); + system->constraints[constraintIndex].satisfied = true; + continue; } - } - // The state of currently solved arguments. - struct SolvedArg - { - IntVal* val = nullptr; - bool isOptional = true; - ShortList types; - }; - ShortList solvedArgs; + // If the parameter is a type pack, then we may have + // constraints that apply to invidual elements of the pack. + // We will need to handle the type pack case slightly differently. + // + bool isPack = as(typeParam) != nullptr; - // We will then iterate over the constraints trying to solve all generic parameters. - // Note that we do not use ranged for here, because processing one constraint may lead to - // new constraints being discovered. - for (Index constraintIndex = 0; constraintIndex < system->constraints.getCount(); constraintIndex++) - { - // Note: it is important to keep a copy of the constraint here instead of - // using a reference, because the constraint list may be modified during the - // loop as we discover new constraints. + // We will use a temporary list to hold the resolved types + // for this generic parameter. + // For normal type parameters, there should be only one type + // in the list. For type pack parameters, there can be one type + // for each element in the pack. // - auto c = system->constraints[constraintIndex]; - if (auto typeParam = as(c.decl)) + if (solvedArgs.getCount() <= typeParam->parameterIndex) { - SLANG_ASSERT(typeParam->parameterIndex != -1); - // If the parameter is one where we already know - // the argument value to use, we don't bother with - // trying to solve for it, and treat any constraints - // on such a parameter as implicitly solved-for. - // - if (typeParam->parameterIndex < knownGenericArgCount) - { - system->constraints[constraintIndex].satisfied = true; - continue; - } - - // If the parameter is a type pack, then we may have - // constraints that apply to invidual elements of the pack. - // We will need to handle the type pack case slightly differently. - // - bool isPack = as(typeParam) != nullptr; - - // We will use a temporary list to hold the resolved types - // for this generic parameter. - // For normal type parameters, there should be only one type - // in the list. For type pack parameters, there can be one type - // for each element in the pack. - // - if (solvedArgs.getCount() <= typeParam->parameterIndex) - { - solvedArgs.setCount(typeParam->parameterIndex + 1); - } - auto& types = solvedArgs[typeParam->parameterIndex].types; - if (!isPack) - types.setCount(1); + solvedArgs.setCount(typeParam->parameterIndex + 1); + } + auto& types = solvedArgs[typeParam->parameterIndex].types; + if (!isPack) + types.setCount(1); - bool& typeConstraintOptional = solvedArgs[typeParam->parameterIndex].isOptional; + bool& typeConstraintOptional = solvedArgs[typeParam->parameterIndex].isOptional; - QualType* ptype = nullptr; - if (isPack) - { - types.setCount(Math::Max(types.getCount(), c.indexInPack + 1)); - ptype = &types[c.indexInPack]; - } - else - ptype = &types[0]; - QualType& type = *ptype; - - auto cType = QualType(as(c.val), c.isUsedAsLValue); - SLANG_RELEASE_ASSERT(cType); + QualType* ptype = nullptr; + if (isPack) + { + types.setCount(Math::Max(types.getCount(), c.indexInPack + 1)); + ptype = &types[c.indexInPack]; + } + else + ptype = &types[0]; + QualType& type = *ptype; - if (!type || (typeConstraintOptional && !c.isOptional)) - { - type = cType; - typeConstraintOptional = c.isOptional; - } - else if (!typeConstraintOptional) - { - // If the type parameter is already constrained to a known type, - // we need to make sure our resolved type can satisfy both constraints. - // We do so by updating the resolved type to be the "join" of the current - // solution and the type in the new constraint. If such join cannot be found, - // it means it is not possible to have a compatible solution that meets all - // constraints and we should fail. - // - // Another detail here is that during type joining, we may discover - // new constraints from the base types of the types being joined. - // We will pass the constraint system to `TryJoinTypes` which can - // add new constraints to the system, and we will process the new constraints - // in the next iteration. - // - auto joinType = TryJoinTypes(system, type, cType); - if (!joinType) - { - // failure! - return DeclRef(); - } - type = QualType(joinType, type.isLeftValue || cType.isLeftValue); - } + auto cType = QualType(as(c.val), c.isUsedAsLValue); + SLANG_RELEASE_ASSERT(cType); - c.satisfied = true; + if (!type || (typeConstraintOptional && !c.isOptional)) + { + type = cType; + typeConstraintOptional = c.isOptional; } - else if (auto valParam = as(c.decl)) + else if (!typeConstraintOptional) { - SLANG_ASSERT(valParam->parameterIndex != -1); - - // If the parameter is one where we already know - // the argument value to use, we don't bother with - // trying to solve for it, and treat any constraints - // on such a parameter as implicitly solved-for. + // If the type parameter is already constrained to a known type, + // we need to make sure our resolved type can satisfy both constraints. + // We do so by updating the resolved type to be the "join" of the current + // solution and the type in the new constraint. If such join cannot be found, + // it means it is not possible to have a compatible solution that meets all + // constraints and we should fail. // - if (valParam->parameterIndex < knownGenericArgCount) - { - system->constraints[constraintIndex].satisfied = true; - continue; - } - - if (solvedArgs.getCount() <= valParam->parameterIndex) - solvedArgs.setCount(valParam->parameterIndex + 1); - IntVal*& val = solvedArgs[valParam->parameterIndex].val; - bool& valOptional = solvedArgs[valParam->parameterIndex].isOptional; - - auto cVal = as(c.val); - SLANG_RELEASE_ASSERT(cVal); - - if (!val || (valOptional && !c.isOptional)) - { - val = cVal; - valOptional = c.isOptional; - } - else + // Another detail here is that during type joining, we may discover + // new constraints from the base types of the types being joined. + // We will pass the constraint system to `TryJoinTypes` which can + // add new constraints to the system, and we will process the new constraints + // in the next iteration. + // + auto joinType = TryJoinTypes(system, type, cType); + if (!joinType) { - if(!valOptional && !val->equals(cVal)) - { - // failure! - return DeclRef(); - } + // failure! + return DeclRef(); } - - c.satisfied = true; + type = QualType(joinType, type.isLeftValue || cType.isLeftValue); } - system->constraints[constraintIndex].satisfied = c.satisfied; - } - // After we processed all constraints, `solvedTypes` and `solvedVals` - // should have been filled with the resolved types and values for the - // generic parameters. We can now verify if they are complete and consolidate - // them into final argument list. - for (auto member : genericDeclRef.getDecl()->members) + c.satisfied = true; + } + else if (auto valParam = as(c.decl)) { - if (auto typeParam = as(member)) - { - SLANG_ASSERT(typeParam->parameterIndex != -1); - - if (typeParam->parameterIndex < knownGenericArgCount) - continue; - bool isPack = as(typeParam) != nullptr; - if (typeParam->parameterIndex >= solvedArgs.getCount()) - { - // If the parameter is not a type pack and we don't have a - // resolved type for it, we should fail. - if (!isPack) - return DeclRef(); - // If the parameter is a type pack, we should add an empty - // type list to solvedTypes. - solvedArgs.setCount(typeParam->parameterIndex + 1); - } - auto& types = solvedArgs[typeParam->parameterIndex].types; - // Fail if any of the resolved type element is empty. - for (auto t : types) - { - if (!t) - return DeclRef(); - } - if (!isPack) - { - // If the generic parameter is not a pack, we can simply add the first type. - if (types.getCount() != 1) - return DeclRef(); + SLANG_ASSERT(valParam->parameterIndex != -1); - args.add(types[0]); - } - else - { - // If the generic parameter is a pack, and we are supplying one single pack argument, - // we can use it as is. - if (types.getCount() == 1 && isTypePack(types[0])) - { - args.add(types[0]); - } - else - { - // If we are supplying 0 or multiple arguments for the pack, we need to create a type pack - // and add it to the argument list. - ShortList typeList; - bool isLVal = true; - for (auto t : types) - { - typeList.add(t); - isLVal = isLVal && t.isLeftValue; - } - args.add(QualType(m_astBuilder->getTypePack(typeList.getArrayView().arrayView), isLVal)); - } - } - } - else if (auto valParam = as(member)) + // If the parameter is one where we already know + // the argument value to use, we don't bother with + // trying to solve for it, and treat any constraints + // on such a parameter as implicitly solved-for. + // + if (valParam->parameterIndex < knownGenericArgCount) { - SLANG_ASSERT(valParam->parameterIndex != -1); + system->constraints[constraintIndex].satisfied = true; + continue; + } - if (valParam->parameterIndex < knownGenericArgCount) - continue; + if (solvedArgs.getCount() <= valParam->parameterIndex) + solvedArgs.setCount(valParam->parameterIndex + 1); + IntVal*& val = solvedArgs[valParam->parameterIndex].val; + bool& valOptional = solvedArgs[valParam->parameterIndex].isOptional; - if (valParam->parameterIndex >= solvedArgs.getCount()) - return DeclRef(); + auto cVal = as(c.val); + SLANG_RELEASE_ASSERT(cVal); - auto val = solvedArgs[valParam->parameterIndex].val; - if (!val) + if (!val || (valOptional && !c.isOptional)) + { + val = cVal; + valOptional = c.isOptional; + } + else + { + if (!valOptional && !val->equals(cVal)) { // failure! return DeclRef(); } - args.add(val); } - } - - // After we've solved for the explicit arguments, we need to - // make a second pass and consider the implicit arguments, - // based on what we've already determined to be the values - // for the explicit arguments. - - // Before we begin, we are going to go ahead and create the - // "solved" substitution that we will return if everything works. - // This is because we are going to use this substitution, - // partially filled in with the results we know so far, - // in order to specialize any constraints on the generic. - // - // E.g., if the generic parameters were ``, and - // we've already decided that `T` is `Robin`, then we want to - // search for a conformance `Robin : ISidekick`, which involved - // apply the substitutions we already know... - HashSet constrainedGenericParams; + c.satisfied = true; + } + system->constraints[constraintIndex].satisfied = c.satisfied; + } - for (auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType()) + // After we processed all constraints, `solvedTypes` and `solvedVals` + // should have been filled with the resolved types and values for the + // generic parameters. We can now verify if they are complete and consolidate + // them into final argument list. + for (auto member : genericDeclRef.getDecl()->members) + { + if (auto typeParam = as(member)) { - DeclRef constraintDeclRef = m_astBuilder->getGenericAppDeclRef( - genericDeclRef, args.getArrayView().arrayView, constraintDecl).as(); + SLANG_ASSERT(typeParam->parameterIndex != -1); - // Extract the (substituted) sub- and super-type from the constraint. - auto sub = getSub(m_astBuilder, constraintDeclRef); - auto sup = getSup(m_astBuilder, constraintDeclRef); - - // Mark sub type as constrained. - if (auto subDeclRefType = as(constraintDeclRef.getDecl()->sub.type)) - constrainedGenericParams.add(subDeclRefType->getDeclRef().getDecl()); - else if (auto subEachType = as(constraintDeclRef.getDecl()->sub.type)) - constrainedGenericParams.add(as(subEachType->getElementType())->getDeclRef().getDecl()); - - if (sub->equals(sup) && isDeclRefTypeOf(sup)) + if (typeParam->parameterIndex < knownGenericArgCount) + continue; + bool isPack = as(typeParam) != nullptr; + if (typeParam->parameterIndex >= solvedArgs.getCount()) { - // We are trying to use an interface type itself to conform to the - // type constraint. We can reach this case when the user code does - // not provide an explicit type parameter to specialize a generic - // and the type parameter cannot be inferred from any arguments. - // In this case, we should fail the constraint check. - return DeclRef(); - } - - // Search for a witness that shows the constraint is satisfied. - SubtypeWitness* subTypeWitness = nullptr; - if (sub == system->subTypeForAdditionalWitnesses) - { - // If we are trying to find the subtype info for a type whose inheritance info is - // being calculated, use what we have already known about the type. - system->additionalSubtypeWitnesses->tryGetValue(sup, subTypeWitness); + // If the parameter is not a type pack and we don't have a + // resolved type for it, we should fail. + if (!isPack) + return DeclRef(); + // If the parameter is a type pack, we should add an empty + // type list to solvedTypes. + solvedArgs.setCount(typeParam->parameterIndex + 1); } - else + auto& types = solvedArgs[typeParam->parameterIndex].types; + // Fail if any of the resolved type element is empty. + for (auto t : types) { - // The general case is to initiate a subtype query. - subTypeWitness = isSubtype( - sub, - sup, - system->additionalSubtypeWitnesses ? IsSubTypeOptions::NoCaching : IsSubTypeOptions::None); + if (!t) + return DeclRef(); } - - if (constraintDecl->isEqualityConstraint) + if (!isPack) { - // If constraint is an equality constraint, we need to make sure - // the witness is equality witness. - if (!isTypeEqualityWitness(subTypeWitness)) - subTypeWitness = nullptr; - } + // If the generic parameter is not a pack, we can simply add the first type. + if (types.getCount() != 1) + return DeclRef(); - if(subTypeWitness) - { - // We found a witness, so it will become an (implicit) argument. - args.add(subTypeWitness); - outBaseCost += subTypeWitness->getOverloadResolutionCost(); + args.add(types[0]); } else { - // No witness was found, so the inference will now fail. - // - // TODO: Ideally we should print an error message in - // this case, to let the user know why things failed. - return DeclRef(); + // If the generic parameter is a pack, and we are supplying one single pack + // argument, we can use it as is. + if (types.getCount() == 1 && isTypePack(types[0])) + { + args.add(types[0]); + } + else + { + // If we are supplying 0 or multiple arguments for the pack, we need to create a + // type pack and add it to the argument list. + ShortList typeList; + bool isLVal = true; + for (auto t : types) + { + typeList.add(t); + isLVal = isLVal && t.isLeftValue; + } + args.add(QualType( + m_astBuilder->getTypePack(typeList.getArrayView().arrayView), + isLVal)); + } } - - // TODO: We may need to mark some constrains in our constraint - // system as being solved now, as a result of the witness we found. } - - // Add a flat cost to all unconstrained generic params. - for (auto typeParamDecl : genericDeclRef.getDecl()->getMembersOfType()) + else if (auto valParam = as(member)) { - if (!constrainedGenericParams.contains(typeParamDecl)) - outBaseCost += kConversionCost_UnconstraintGenericParam; - } + SLANG_ASSERT(valParam->parameterIndex != -1); - // Make sure we haven't constructed any spurious constraints - // that we aren't able to satisfy: - for (auto c : system->constraints) - { - if (!c.satisfied) + if (valParam->parameterIndex < knownGenericArgCount) + continue; + + if (valParam->parameterIndex >= solvedArgs.getCount()) + return DeclRef(); + + auto val = solvedArgs[valParam->parameterIndex].val; + if (!val) { + // failure! return DeclRef(); } + args.add(val); } - - return m_astBuilder->getGenericAppDeclRef(genericDeclRef, args.getArrayView().arrayView); } - bool SemanticsVisitor::TryUnifyVals( - ConstraintSystem& constraints, - ValUnificationContext unifyCtx, - Val* fst, - bool fstLVal, - Val* snd, - bool sndLVal) + // After we've solved for the explicit arguments, we need to + // make a second pass and consider the implicit arguments, + // based on what we've already determined to be the values + // for the explicit arguments. + + // Before we begin, we are going to go ahead and create the + // "solved" substitution that we will return if everything works. + // This is because we are going to use this substitution, + // partially filled in with the results we know so far, + // in order to specialize any constraints on the generic. + // + // E.g., if the generic parameters were ``, and + // we've already decided that `T` is `Robin`, then we want to + // search for a conformance `Robin : ISidekick`, which involved + // apply the substitutions we already know... + + HashSet constrainedGenericParams; + + for (auto constraintDecl : + genericDeclRef.getDecl()->getMembersOfType()) { - // if both values are types, then unify types - if (auto fstType = as(fst)) + DeclRef constraintDeclRef = + m_astBuilder + ->getGenericAppDeclRef( + genericDeclRef, + args.getArrayView().arrayView, + constraintDecl) + .as(); + + // Extract the (substituted) sub- and super-type from the constraint. + auto sub = getSub(m_astBuilder, constraintDeclRef); + auto sup = getSup(m_astBuilder, constraintDeclRef); + + // Mark sub type as constrained. + if (auto subDeclRefType = as(constraintDeclRef.getDecl()->sub.type)) + constrainedGenericParams.add(subDeclRefType->getDeclRef().getDecl()); + else if (auto subEachType = as(constraintDeclRef.getDecl()->sub.type)) + constrainedGenericParams.add( + as(subEachType->getElementType())->getDeclRef().getDecl()); + + if (sub->equals(sup) && isDeclRefTypeOf(sup)) { - if (auto sndType = as(snd)) - { - return TryUnifyTypes(constraints, unifyCtx, QualType(fstType, fstLVal), QualType(sndType, sndLVal)); - } + // We are trying to use an interface type itself to conform to the + // type constraint. We can reach this case when the user code does + // not provide an explicit type parameter to specialize a generic + // and the type parameter cannot be inferred from any arguments. + // In this case, we should fail the constraint check. + return DeclRef(); } - // if both values are constant integers, then compare them - if (auto fstIntVal = as(fst)) + // Search for a witness that shows the constraint is satisfied. + SubtypeWitness* subTypeWitness = nullptr; + if (sub == system->subTypeForAdditionalWitnesses) { - if (auto sndIntVal = as(snd)) - { - return fstIntVal->getValue() == sndIntVal->getValue(); - } + // If we are trying to find the subtype info for a type whose inheritance info is + // being calculated, use what we have already known about the type. + system->additionalSubtypeWitnesses->tryGetValue(sup, subTypeWitness); } - - // Check if both are integer values in general - const auto fstInt = as(fst); - const auto sndInt = as(snd); - if (fstInt && sndInt) + else { - const auto paramUnderCast = [](IntVal* i){ - if(const auto c = as(i)) - i = as(c->getBase()); - return as(i); - }; - auto fstParam = paramUnderCast(fstInt); - auto sndParam = paramUnderCast(sndInt); - - bool okay = false; - if (fstParam) - okay |= TryUnifyIntParam(constraints, unifyCtx, fstParam->getDeclRef(), sndInt); - if (sndParam) - okay |= TryUnifyIntParam(constraints, unifyCtx, sndParam->getDeclRef(), fstInt); - return okay; + // The general case is to initiate a subtype query. + subTypeWitness = isSubtype( + sub, + sup, + system->additionalSubtypeWitnesses ? IsSubTypeOptions::NoCaching + : IsSubTypeOptions::None); } - if (auto fstWit = as(fst)) + if (constraintDecl->isEqualityConstraint) { - if (auto sndWit = as(snd)) - { - auto constraintDecl1 = fstWit->getDeclRef().as(); - auto constraintDecl2 = sndWit->getDeclRef().as(); - SLANG_ASSERT(constraintDecl1); - SLANG_ASSERT(constraintDecl2); - return TryUnifyTypes(constraints, - unifyCtx, - getSup(m_astBuilder, constraintDecl1), - getSup(m_astBuilder, constraintDecl2)); - } + // If constraint is an equality constraint, we need to make sure + // the witness is equality witness. + if (!isTypeEqualityWitness(subTypeWitness)) + subTypeWitness = nullptr; } - // Two subtype witnesses can be unified if they exist (non-null) and - // prove that some pair of types are subtypes of types that can be unified. - // - if (auto fstWit = as(fst)) + if (subTypeWitness) { - if (auto sndWit = as(snd)) - { - return TryUnifyTypes(constraints, - unifyCtx, - fstWit->getSup(), - sndWit->getSup()); - } + // We found a witness, so it will become an (implicit) argument. + args.add(subTypeWitness); + outBaseCost += subTypeWitness->getOverloadResolutionCost(); + } + else + { + // No witness was found, so the inference will now fail. + // + // TODO: Ideally we should print an error message in + // this case, to let the user know why things failed. + return DeclRef(); } - SLANG_UNIMPLEMENTED_X("value unification case"); - - // default: fail - //return false; + // TODO: We may need to mark some constrains in our constraint + // system as being solved now, as a result of the witness we found. } - bool SemanticsVisitor::tryUnifyDeclRef( - ConstraintSystem& constraints, - ValUnificationContext unifyCtx, - DeclRefBase* fst, - bool fstIsLVal, - DeclRefBase* snd, - bool sndIsLVal) + // Add a flat cost to all unconstrained generic params. + for (auto typeParamDecl : genericDeclRef.getDecl()->getMembersOfType()) { - if (fst == snd) - return true; - if (fst == nullptr || snd == nullptr) - return false; - auto fstGen = SubstitutionSet(fst).findGenericAppDeclRef(); - auto sndGen = SubstitutionSet(snd).findGenericAppDeclRef(); - if (fstGen == sndGen) - return true; - if (fstGen == nullptr || sndGen == nullptr) - return false; - return tryUnifyGenericAppDeclRef(constraints, unifyCtx, fstGen, fstIsLVal, sndGen, sndIsLVal); + if (!constrainedGenericParams.contains(typeParamDecl)) + outBaseCost += kConversionCost_UnconstraintGenericParam; } - bool SemanticsVisitor::tryUnifyGenericAppDeclRef( - ConstraintSystem& constraints, - ValUnificationContext unifyCtx, - GenericAppDeclRef* fst, - bool fstIsLVal, - GenericAppDeclRef* snd, - bool sndIsLVal) + // Make sure we haven't constructed any spurious constraints + // that we aren't able to satisfy: + for (auto c : system->constraints) { - SLANG_ASSERT(fst); - SLANG_ASSERT(snd); - - auto fstGen = fst; - auto sndGen = snd; - // They must be specializing the same generic - if (fstGen->getGenericDecl() != sndGen->getGenericDecl()) - return false; - - // Their arguments must unify - SLANG_RELEASE_ASSERT(fstGen->getArgs().getCount() == sndGen->getArgs().getCount()); - Index argCount = fstGen->getArgs().getCount(); - bool okay = true; - for (Index aa = 0; aa < argCount; ++aa) + if (!c.satisfied) { - if (!TryUnifyVals(constraints, unifyCtx, fstGen->getArgs()[aa], fstIsLVal, sndGen->getArgs()[aa], sndIsLVal)) - { - okay = false; - } + return DeclRef(); } + } - // Their "base" specializations must unify - auto fstBase = fst->getBase(); - auto sndBase = snd->getBase(); + return m_astBuilder->getGenericAppDeclRef(genericDeclRef, args.getArrayView().arrayView); +} - if (!tryUnifyDeclRef(constraints, unifyCtx, fstBase, fstIsLVal, sndBase, sndIsLVal)) +bool SemanticsVisitor::TryUnifyVals( + ConstraintSystem& constraints, + ValUnificationContext unifyCtx, + Val* fst, + bool fstLVal, + Val* snd, + bool sndLVal) +{ + // if both values are types, then unify types + if (auto fstType = as(fst)) + { + if (auto sndType = as(snd)) { - okay = false; + return TryUnifyTypes( + constraints, + unifyCtx, + QualType(fstType, fstLVal), + QualType(sndType, sndLVal)); } - - return okay; } - bool SemanticsVisitor::TryUnifyTypeParam( - ConstraintSystem& constraints, - ValUnificationContext unificationContext, - GenericTypeParamDeclBase* typeParamDecl, - QualType type) + // if both values are constant integers, then compare them + if (auto fstIntVal = as(fst)) { - // We want to constrain the given type parameter - // to equal the given type. - Constraint constraint; - constraint.decl = typeParamDecl; - constraint.indexInPack = unificationContext.indexInTypePack; - constraint.val = type; - constraint.isUsedAsLValue = type.isLeftValue; - constraints.constraints.add(constraint); + if (auto sndIntVal = as(snd)) + { + return fstIntVal->getValue() == sndIntVal->getValue(); + } + } - return true; + // Check if both are integer values in general + const auto fstInt = as(fst); + const auto sndInt = as(snd); + if (fstInt && sndInt) + { + const auto paramUnderCast = [](IntVal* i) + { + if (const auto c = as(i)) + i = as(c->getBase()); + return as(i); + }; + auto fstParam = paramUnderCast(fstInt); + auto sndParam = paramUnderCast(sndInt); + + bool okay = false; + if (fstParam) + okay |= TryUnifyIntParam(constraints, unifyCtx, fstParam->getDeclRef(), sndInt); + if (sndParam) + okay |= TryUnifyIntParam(constraints, unifyCtx, sndParam->getDeclRef(), fstInt); + return okay; } - bool SemanticsVisitor::TryUnifyIntParam( - ConstraintSystem& constraints, - ValUnificationContext unifyCtx, - GenericValueParamDecl* paramDecl, - IntVal* val) + if (auto fstWit = as(fst)) { - SLANG_UNUSED(unifyCtx); + if (auto sndWit = as(snd)) + { + auto constraintDecl1 = fstWit->getDeclRef().as(); + auto constraintDecl2 = sndWit->getDeclRef().as(); + SLANG_ASSERT(constraintDecl1); + SLANG_ASSERT(constraintDecl2); + return TryUnifyTypes( + constraints, + unifyCtx, + getSup(m_astBuilder, constraintDecl1), + getSup(m_astBuilder, constraintDecl2)); + } + } - // We only want to accumulate constraints on - // the parameters of the declarations being - // specialized (don't accidentially constrain - // parameters of a generic function based on - // calls in its body). - if(paramDecl->parentDecl != constraints.genericDecl) - return false; + // Two subtype witnesses can be unified if they exist (non-null) and + // prove that some pair of types are subtypes of types that can be unified. + // + if (auto fstWit = as(fst)) + { + if (auto sndWit = as(snd)) + { + return TryUnifyTypes(constraints, unifyCtx, fstWit->getSup(), sndWit->getSup()); + } + } - // We want to constrain the given parameter to equal the given value. - Constraint constraint; - constraint.decl = paramDecl; - constraint.val = val; + SLANG_UNIMPLEMENTED_X("value unification case"); - constraints.constraints.add(constraint); + // default: fail + // return false; +} +bool SemanticsVisitor::tryUnifyDeclRef( + ConstraintSystem& constraints, + ValUnificationContext unifyCtx, + DeclRefBase* fst, + bool fstIsLVal, + DeclRefBase* snd, + bool sndIsLVal) +{ + if (fst == snd) return true; - } + if (fst == nullptr || snd == nullptr) + return false; + auto fstGen = SubstitutionSet(fst).findGenericAppDeclRef(); + auto sndGen = SubstitutionSet(snd).findGenericAppDeclRef(); + if (fstGen == sndGen) + return true; + if (fstGen == nullptr || sndGen == nullptr) + return false; + return tryUnifyGenericAppDeclRef(constraints, unifyCtx, fstGen, fstIsLVal, sndGen, sndIsLVal); +} + +bool SemanticsVisitor::tryUnifyGenericAppDeclRef( + ConstraintSystem& constraints, + ValUnificationContext unifyCtx, + GenericAppDeclRef* fst, + bool fstIsLVal, + GenericAppDeclRef* snd, + bool sndIsLVal) +{ + SLANG_ASSERT(fst); + SLANG_ASSERT(snd); + + auto fstGen = fst; + auto sndGen = snd; + // They must be specializing the same generic + if (fstGen->getGenericDecl() != sndGen->getGenericDecl()) + return false; - bool SemanticsVisitor::TryUnifyIntParam( - ConstraintSystem& constraints, - ValUnificationContext unifyCtx, - DeclRef const& varRef, - IntVal* val) + // Their arguments must unify + SLANG_RELEASE_ASSERT(fstGen->getArgs().getCount() == sndGen->getArgs().getCount()); + Index argCount = fstGen->getArgs().getCount(); + bool okay = true; + for (Index aa = 0; aa < argCount; ++aa) { - if(auto genericValueParamRef = varRef.as()) + if (!TryUnifyVals( + constraints, + unifyCtx, + fstGen->getArgs()[aa], + fstIsLVal, + sndGen->getArgs()[aa], + sndIsLVal)) { - return TryUnifyIntParam(constraints, unifyCtx, genericValueParamRef.getDecl(), val); - } - else - { - return false; + okay = false; } } - bool SemanticsVisitor::TryUnifyTypesByStructuralMatch( - ConstraintSystem& constraints, - ValUnificationContext unifyCtx, - QualType fst, - QualType snd) + // Their "base" specializations must unify + auto fstBase = fst->getBase(); + auto sndBase = snd->getBase(); + + if (!tryUnifyDeclRef(constraints, unifyCtx, fstBase, fstIsLVal, sndBase, sndIsLVal)) { - if (auto fstDeclRefType = as(fst)) + okay = false; + } + + return okay; +} + +bool SemanticsVisitor::TryUnifyTypeParam( + ConstraintSystem& constraints, + ValUnificationContext unificationContext, + GenericTypeParamDeclBase* typeParamDecl, + QualType type) +{ + // We want to constrain the given type parameter + // to equal the given type. + Constraint constraint; + constraint.decl = typeParamDecl; + constraint.indexInPack = unificationContext.indexInTypePack; + constraint.val = type; + constraint.isUsedAsLValue = type.isLeftValue; + constraints.constraints.add(constraint); + + return true; +} + +bool SemanticsVisitor::TryUnifyIntParam( + ConstraintSystem& constraints, + ValUnificationContext unifyCtx, + GenericValueParamDecl* paramDecl, + IntVal* val) +{ + SLANG_UNUSED(unifyCtx); + + // We only want to accumulate constraints on + // the parameters of the declarations being + // specialized (don't accidentially constrain + // parameters of a generic function based on + // calls in its body). + if (paramDecl->parentDecl != constraints.genericDecl) + return false; + + // We want to constrain the given parameter to equal the given value. + Constraint constraint; + constraint.decl = paramDecl; + constraint.val = val; + + constraints.constraints.add(constraint); + + return true; +} + +bool SemanticsVisitor::TryUnifyIntParam( + ConstraintSystem& constraints, + ValUnificationContext unifyCtx, + DeclRef const& varRef, + IntVal* val) +{ + if (auto genericValueParamRef = varRef.as()) + { + return TryUnifyIntParam(constraints, unifyCtx, genericValueParamRef.getDecl(), val); + } + else + { + return false; + } +} + +bool SemanticsVisitor::TryUnifyTypesByStructuralMatch( + ConstraintSystem& constraints, + ValUnificationContext unifyCtx, + QualType fst, + QualType snd) +{ + if (auto fstDeclRefType = as(fst)) + { + auto fstDeclRef = fstDeclRefType->getDeclRef(); + + if (auto typeParamDecl = as(fstDeclRef.getDecl())) + if (typeParamDecl->parentDecl == constraints.genericDecl) + return TryUnifyTypeParam(constraints, unifyCtx, typeParamDecl, snd); + + if (auto sndDeclRefType = as(snd)) { - auto fstDeclRef = fstDeclRefType->getDeclRef(); + auto sndDeclRef = sndDeclRefType->getDeclRef(); - if (auto typeParamDecl = as(fstDeclRef.getDecl())) + if (auto typeParamDecl = as(sndDeclRef.getDecl())) if (typeParamDecl->parentDecl == constraints.genericDecl) - return TryUnifyTypeParam(constraints, unifyCtx, typeParamDecl, snd); + return TryUnifyTypeParam(constraints, unifyCtx, typeParamDecl, fst); - if (auto sndDeclRefType = as(snd)) + // If they refer to different declarations, we need to check if one type's super type + // matches the other type, if so we can unify them. + if (fstDeclRef.getDecl() != sndDeclRef.getDecl()) { - auto sndDeclRef = sndDeclRefType->getDeclRef(); - - if (auto typeParamDecl = as(sndDeclRef.getDecl())) - if (typeParamDecl->parentDecl == constraints.genericDecl) - return TryUnifyTypeParam(constraints, unifyCtx, typeParamDecl, fst); - - // If they refer to different declarations, we need to check if one type's super type - // matches the other type, if so we can unify them. - if (fstDeclRef.getDecl() != sndDeclRef.getDecl()) { + auto fstTypeInheritanceInfo = getShared()->getInheritanceInfo(fstDeclRefType); + for (auto supType : fstTypeInheritanceInfo.facets) { - auto fstTypeInheritanceInfo = getShared()->getInheritanceInfo(fstDeclRefType); - for (auto supType : fstTypeInheritanceInfo.facets) + if (supType->origin.declRef.getDecl() == sndDeclRef.getDecl()) { - if (supType->origin.declRef.getDecl() == sndDeclRef.getDecl()) - { - fstDeclRef = supType->origin.declRef; - goto endMatch; - } + fstDeclRef = supType->origin.declRef; + goto endMatch; } } - // try the other direction + } + // try the other direction + { + auto sndTypeInheritanceInfo = getShared()->getInheritanceInfo(sndDeclRefType); + for (auto supType : sndTypeInheritanceInfo.facets) { - auto sndTypeInheritanceInfo = getShared()->getInheritanceInfo(sndDeclRefType); - for (auto supType : sndTypeInheritanceInfo.facets) + if (supType->origin.declRef.getDecl() == fstDeclRef.getDecl()) { - if (supType->origin.declRef.getDecl() == fstDeclRef.getDecl()) - { - sndDeclRef = supType->origin.declRef; - goto endMatch; - } + sndDeclRef = supType->origin.declRef; + goto endMatch; } } - endMatch:; - // If they still refer to different decls, then we can't unify them. - if (fstDeclRef.getDecl() != sndDeclRef.getDecl()) - return false; } + endMatch:; + // If they still refer to different decls, then we can't unify them. + if (fstDeclRef.getDecl() != sndDeclRef.getDecl()) + return false; + } - // next we need to unify the substitutions applied - // to each declaration reference. - if (!tryUnifyDeclRef( + // next we need to unify the substitutions applied + // to each declaration reference. + if (!tryUnifyDeclRef( constraints, unifyCtx, fstDeclRef, fst.isLeftValue, sndDeclRef, snd.isLeftValue)) - { - return false; - } - - return true; + { + return false; } + + return true; } - else if(auto fstFunType = as(fst)) + } + else if (auto fstFunType = as(fst)) + { + if (auto sndFunType = as(snd)) { - if (auto sndFunType = as(snd)) + const Index numParams = fstFunType->getParamCount(); + if (numParams != sndFunType->getParamCount()) + return false; + for (Index i = 0; i < numParams; ++i) { - const Index numParams = fstFunType->getParamCount(); - if(numParams != sndFunType->getParamCount()) + if (!TryUnifyTypes( + constraints, + unifyCtx, + fstFunType->getParamType(i), + sndFunType->getParamType(i))) return false; - for(Index i = 0; i < numParams; ++i) - { - if(!TryUnifyTypes(constraints, unifyCtx, fstFunType->getParamType(i), sndFunType->getParamType(i))) - return false; - } - return TryUnifyTypes(constraints, unifyCtx, fstFunType->getResultType(), sndFunType->getResultType()); } + return TryUnifyTypes( + constraints, + unifyCtx, + fstFunType->getResultType(), + sndFunType->getResultType()); } - else if (auto expandType = as(fst)) + } + else if (auto expandType = as(fst)) + { + if (auto sndExpandType = as(snd)) { - if (auto sndExpandType = as(snd)) - { - return TryUnifyTypes(constraints, unifyCtx, expandType->getPatternType(), sndExpandType->getPatternType()); - } + return TryUnifyTypes( + constraints, + unifyCtx, + expandType->getPatternType(), + sndExpandType->getPatternType()); } - else if (auto eachType = as(fst)) + } + else if (auto eachType = as(fst)) + { + if (auto sndEachType = as(snd)) { - if (auto sndEachType = as(snd)) - { - return TryUnifyTypes(constraints, unifyCtx, eachType->getElementType(), sndEachType->getElementType()); - } + return TryUnifyTypes( + constraints, + unifyCtx, + eachType->getElementType(), + sndEachType->getElementType()); } - else if (auto typePack = as(fst)) + } + else if (auto typePack = as(fst)) + { + if (auto sndTypePack = as(snd)) { - if (auto sndTypePack = as(snd)) + if (typePack->getTypeCount() != sndTypePack->getTypeCount()) + return false; + for (Index i = 0; i < typePack->getTypeCount(); ++i) { - if (typePack->getTypeCount() != sndTypePack->getTypeCount()) + if (!TryUnifyTypes( + constraints, + unifyCtx, + QualType(typePack->getElementType(i), fst.isLeftValue), + QualType(sndTypePack->getElementType(i), snd.isLeftValue))) return false; - for (Index i = 0; i < typePack->getTypeCount(); ++i) - { - if (!TryUnifyTypes(constraints, unifyCtx, QualType(typePack->getElementType(i), fst.isLeftValue), QualType(sndTypePack->getElementType(i), snd.isLeftValue))) - return false; - } - return true; } + return true; } - return false; } + return false; +} - bool SemanticsVisitor::TryUnifyConjunctionType( - ConstraintSystem& constraints, - ValUnificationContext unifyCtx, - QualType fst, - QualType snd) +bool SemanticsVisitor::TryUnifyConjunctionType( + ConstraintSystem& constraints, + ValUnificationContext unifyCtx, + QualType fst, + QualType snd) +{ + // Unifying a type `A & B` with `T` amounts to unifying + // `A` with `T` and also `B` with `T` while + // unifying a type `T` with `A & B` amounts to either + // unifying `T` with `A` or `T` with `B` + // + // If either unification is impossible, then the full + // case is also impossible. + // + if (auto fstAndType = as(fst)) { - // Unifying a type `A & B` with `T` amounts to unifying - // `A` with `T` and also `B` with `T` while - // unifying a type `T` with `A & B` amounts to either - // unifying `T` with `A` or `T` with `B` - // - // If either unification is impossible, then the full - // case is also impossible. - // - if (auto fstAndType = as(fst)) - { - return TryUnifyTypes(constraints, unifyCtx, QualType(fstAndType->getLeft(), fst.isLeftValue), snd) - && TryUnifyTypes(constraints, unifyCtx, QualType(fstAndType->getRight(), fst.isLeftValue), snd); - } - else if (auto sndAndType = as(snd)) - { - return TryUnifyTypes(constraints, unifyCtx, fst, QualType(sndAndType->getLeft(), snd.isLeftValue)) - || TryUnifyTypes(constraints, unifyCtx, fst, QualType(sndAndType->getRight(), snd.isLeftValue)); - } - else - return false; + return TryUnifyTypes( + constraints, + unifyCtx, + QualType(fstAndType->getLeft(), fst.isLeftValue), + snd) && + TryUnifyTypes( + constraints, + unifyCtx, + QualType(fstAndType->getRight(), fst.isLeftValue), + snd); } - - void SemanticsVisitor::maybeUnifyUnconstraintIntParam( - ConstraintSystem& constraints, ValUnificationContext unifyCtx, IntVal* param, IntVal* arg, bool paramIsLVal) + else if (auto sndAndType = as(snd)) { - SLANG_UNUSED(unifyCtx); - - // If `param` is an unconstrained integer val param, and `arg` is a const int val, - // we add a constraint to the system that `param` must be equal to `arg`. - // If `param` is already constrained, ignore and do nothing. - if (auto typeCastParam = as(param)) - { - param = as(typeCastParam->getBase()); - } - auto intParam = as(param); - if (!intParam) - return; - for (auto c : constraints.constraints) - if (c.decl == intParam->getDeclRef().getDecl()) - return; - Constraint c; - c.decl = intParam->getDeclRef().getDecl(); - c.isUsedAsLValue = paramIsLVal; - c.val = arg; - c.isOptional = true; - constraints.constraints.add(c); + return TryUnifyTypes( + constraints, + unifyCtx, + fst, + QualType(sndAndType->getLeft(), snd.isLeftValue)) || + TryUnifyTypes( + constraints, + unifyCtx, + fst, + QualType(sndAndType->getRight(), snd.isLeftValue)); } + else + return false; +} - bool SemanticsVisitor::TryUnifyTypes( - ConstraintSystem& constraints, - ValUnificationContext unifyCtx, - QualType fst, - QualType snd) +void SemanticsVisitor::maybeUnifyUnconstraintIntParam( + ConstraintSystem& constraints, + ValUnificationContext unifyCtx, + IntVal* param, + IntVal* arg, + bool paramIsLVal) +{ + SLANG_UNUSED(unifyCtx); + + // If `param` is an unconstrained integer val param, and `arg` is a const int val, + // we add a constraint to the system that `param` must be equal to `arg`. + // If `param` is already constrained, ignore and do nothing. + if (auto typeCastParam = as(param)) { - if (!fst) return false; + param = as(typeCastParam->getBase()); + } + auto intParam = as(param); + if (!intParam) + return; + for (auto c : constraints.constraints) + if (c.decl == intParam->getDeclRef().getDecl()) + return; + Constraint c; + c.decl = intParam->getDeclRef().getDecl(); + c.isUsedAsLValue = paramIsLVal; + c.val = arg; + c.isOptional = true; + constraints.constraints.add(c); +} - if (fst->equals(snd)) return true; +bool SemanticsVisitor::TryUnifyTypes( + ConstraintSystem& constraints, + ValUnificationContext unifyCtx, + QualType fst, + QualType snd) +{ + if (!fst) + return false; - // An error type can unify with anything, just so we avoid cascading errors. + if (fst->equals(snd)) + return true; - if (const auto fstErrorType = as(fst)) - return true; + // An error type can unify with anything, just so we avoid cascading errors. - if (const auto sndErrorType = as(snd)) - return true; + if (const auto fstErrorType = as(fst)) + return true; - // If one or the other of the types is a conjunction `X & Y`, - // then we want to recurse on both `X` and `Y`. - // - // Note that we check this case *before* we check if one of - // the types is a generic parameter below, so that we should - // never end up trying to match up a type parameter with - // a conjunction directly, and will instead find all of the - // "leaf" types we need to constrain it to. - // - if (as(fst) || as(snd)) - { - return TryUnifyConjunctionType(constraints, unifyCtx, fst, snd); - } + if (const auto sndErrorType = as(snd)) + return true; + + // If one or the other of the types is a conjunction `X & Y`, + // then we want to recurse on both `X` and `Y`. + // + // Note that we check this case *before* we check if one of + // the types is a generic parameter below, so that we should + // never end up trying to match up a type parameter with + // a conjunction directly, and will instead find all of the + // "leaf" types we need to constrain it to. + // + if (as(fst) || as(snd)) + { + return TryUnifyConjunctionType(constraints, unifyCtx, fst, snd); + } - // If one of the types is a type pack, we need to recursively unify the element types. - if (auto fstTypePack = as(fst)) + // If one of the types is a type pack, we need to recursively unify the element types. + if (auto fstTypePack = as(fst)) + { + if (auto sndTypePack = as(snd)) { - if (auto sndTypePack = as(snd)) + if (fstTypePack->getTypeCount() != sndTypePack->getTypeCount()) + return false; + for (Index i = 0; i < fstTypePack->getTypeCount(); ++i) { - if (fstTypePack->getTypeCount() != sndTypePack->getTypeCount()) + if (!TryUnifyTypes( + constraints, + unifyCtx, + QualType(fstTypePack->getElementType(i), fst.isLeftValue), + QualType(sndTypePack->getElementType(i), snd.isLeftValue))) return false; - for (Index i = 0; i < fstTypePack->getTypeCount(); ++i) - { - if (!TryUnifyTypes(constraints, unifyCtx,QualType(fstTypePack->getElementType(i), fst.isLeftValue), QualType(sndTypePack->getElementType(i), snd.isLeftValue))) - return false; - } - return true; } - else if (auto sndExpandType = as(snd)) + return true; + } + else if (auto sndExpandType = as(snd)) + { + for (Index i = 0; i < fstTypePack->getTypeCount(); ++i) { - for (Index i = 0; i < fstTypePack->getTypeCount(); ++i) - { - ValUnificationContext subUnifyCtx = unifyCtx; - subUnifyCtx.indexInTypePack = i; - if (!TryUnifyTypes(constraints, subUnifyCtx, QualType(fstTypePack->getElementType(i), fst.isLeftValue), QualType(sndExpandType->getPatternType(), snd.isLeftValue))) - return false; - } - return true; + ValUnificationContext subUnifyCtx = unifyCtx; + subUnifyCtx.indexInTypePack = i; + if (!TryUnifyTypes( + constraints, + subUnifyCtx, + QualType(fstTypePack->getElementType(i), fst.isLeftValue), + QualType(sndExpandType->getPatternType(), snd.isLeftValue))) + return false; } + return true; } + } - if (auto sndTypePack = as(snd)) + if (auto sndTypePack = as(snd)) + { + if (auto fstExpandType = as(fst)) { - if (auto fstExpandType = as(fst)) + for (Index i = 0; i < sndTypePack->getTypeCount(); ++i) { - for (Index i = 0; i < sndTypePack->getTypeCount(); ++i) - { - ValUnificationContext subUnifyCtx = unifyCtx; - subUnifyCtx.indexInTypePack = i; - if (!TryUnifyTypes(constraints, subUnifyCtx, QualType(fstExpandType->getPatternType(), fst.isLeftValue), QualType(sndTypePack->getElementType(i), snd.isLeftValue))) - return false; - } - return true; + ValUnificationContext subUnifyCtx = unifyCtx; + subUnifyCtx.indexInTypePack = i; + if (!TryUnifyTypes( + constraints, + subUnifyCtx, + QualType(fstExpandType->getPatternType(), fst.isLeftValue), + QualType(sndTypePack->getElementType(i), snd.isLeftValue))) + return false; } + return true; } + } - // A generic parameter type can unify with anything. - // TODO: there actually needs to be some kind of "occurs check" sort - // of thing here... + // A generic parameter type can unify with anything. + // TODO: there actually needs to be some kind of "occurs check" sort + // of thing here... - if (auto fstDeclRefType = as(fst)) - { - auto fstDeclRef = fstDeclRefType->getDeclRef(); + if (auto fstDeclRefType = as(fst)) + { + auto fstDeclRef = fstDeclRefType->getDeclRef(); - if (auto typeParamDecl = as(fstDeclRef.getDecl())) - { - if(typeParamDecl->parentDecl == constraints.genericDecl) - return TryUnifyTypeParam(constraints, unifyCtx, typeParamDecl, snd); - } - else if (auto typePackParamDecl = as(fstDeclRef.getDecl())) - { - if (typePackParamDecl->parentDecl == constraints.genericDecl - && isTypePack(snd)) - return TryUnifyTypeParam(constraints, unifyCtx, typePackParamDecl, snd); - } + if (auto typeParamDecl = as(fstDeclRef.getDecl())) + { + if (typeParamDecl->parentDecl == constraints.genericDecl) + return TryUnifyTypeParam(constraints, unifyCtx, typeParamDecl, snd); } - - if (auto sndDeclRefType = as(snd)) + else if (auto typePackParamDecl = as(fstDeclRef.getDecl())) { - auto sndDeclRef = sndDeclRefType->getDeclRef(); + if (typePackParamDecl->parentDecl == constraints.genericDecl && isTypePack(snd)) + return TryUnifyTypeParam(constraints, unifyCtx, typePackParamDecl, snd); + } + } - if (auto typeParamDecl = as(sndDeclRef.getDecl())) - { - if(typeParamDecl->parentDecl == constraints.genericDecl) - return TryUnifyTypeParam(constraints, unifyCtx, typeParamDecl, fst); - } - else if (auto typePackParamDecl = as(sndDeclRef.getDecl())) - { - if (typePackParamDecl->parentDecl == constraints.genericDecl - && isTypePack(fst)) - return TryUnifyTypeParam(constraints, unifyCtx, typePackParamDecl, fst); - } + if (auto sndDeclRefType = as(snd)) + { + auto sndDeclRef = sndDeclRefType->getDeclRef(); + + if (auto typeParamDecl = as(sndDeclRef.getDecl())) + { + if (typeParamDecl->parentDecl == constraints.genericDecl) + return TryUnifyTypeParam(constraints, unifyCtx, typeParamDecl, fst); } + else if (auto typePackParamDecl = as(sndDeclRef.getDecl())) + { + if (typePackParamDecl->parentDecl == constraints.genericDecl && isTypePack(fst)) + return TryUnifyTypeParam(constraints, unifyCtx, typePackParamDecl, fst); + } + } - // If we can unify the types structurally, then we are golden - if(TryUnifyTypesByStructuralMatch(constraints, unifyCtx, fst, snd)) - return true; + // If we can unify the types structurally, then we are golden + if (TryUnifyTypesByStructuralMatch(constraints, unifyCtx, fst, snd)) + return true; - // Now we need to consider cases where coercion might - // need to be applied. For now we can try to do this - // in a completely ad hoc fashion, but eventually we'd - // want to do it more formally. + // Now we need to consider cases where coercion might + // need to be applied. For now we can try to do this + // in a completely ad hoc fashion, but eventually we'd + // want to do it more formally. - if(auto fstVectorType = as(fst)) + if (auto fstVectorType = as(fst)) + { + if (auto sndScalarType = as(snd)) { - if(auto sndScalarType = as(snd)) - { - // Try unify the vector count param. In case the vector count is defined by a generic value - // parameter, we want to be able to infer that parameter should be 1. - // However, we don't want a failed unification to fail the entire generic argument inference, - // because a scalar can still be casted into a vector of any length. - - maybeUnifyUnconstraintIntParam(constraints, unifyCtx, fstVectorType->getElementCount(), m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1), fst.isLeftValue); - return TryUnifyTypes( - constraints, - unifyCtx, - QualType(fstVectorType->getElementType(), fst.isLeftValue), - QualType(sndScalarType, snd.isLeftValue)); - } + // Try unify the vector count param. In case the vector count is defined by a generic + // value parameter, we want to be able to infer that parameter should be 1. However, we + // don't want a failed unification to fail the entire generic argument inference, + // because a scalar can still be casted into a vector of any length. + + maybeUnifyUnconstraintIntParam( + constraints, + unifyCtx, + fstVectorType->getElementCount(), + m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1), + fst.isLeftValue); + return TryUnifyTypes( + constraints, + unifyCtx, + QualType(fstVectorType->getElementType(), fst.isLeftValue), + QualType(sndScalarType, snd.isLeftValue)); } + } - if(auto fstScalarType = as(fst)) + if (auto fstScalarType = as(fst)) + { + if (auto sndVectorType = as(snd)) { - if(auto sndVectorType = as(snd)) - { - maybeUnifyUnconstraintIntParam(constraints, unifyCtx, sndVectorType->getElementCount(), m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1), snd.isLeftValue); - return TryUnifyTypes( - constraints, - unifyCtx, - QualType(fstScalarType, fst.isLeftValue), - QualType(sndVectorType->getElementType(), snd.isLeftValue)); - } + maybeUnifyUnconstraintIntParam( + constraints, + unifyCtx, + sndVectorType->getElementCount(), + m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1), + snd.isLeftValue); + return TryUnifyTypes( + constraints, + unifyCtx, + QualType(fstScalarType, fst.isLeftValue), + QualType(sndVectorType->getElementType(), snd.isLeftValue)); } + } - if (auto fstUniformParamGroupType = as(fst)) - return TryUnifyTypes(constraints, unifyCtx, QualType(fstUniformParamGroupType->getElementType(), fst.isLeftValue), snd); - if (auto sndUniformParamGroupType = as(snd)) - return TryUnifyTypes(constraints, unifyCtx, fst, QualType(sndUniformParamGroupType->getElementType(), snd.isLeftValue)); + if (auto fstUniformParamGroupType = as(fst)) + return TryUnifyTypes( + constraints, + unifyCtx, + QualType(fstUniformParamGroupType->getElementType(), fst.isLeftValue), + snd); + if (auto sndUniformParamGroupType = as(snd)) + return TryUnifyTypes( + constraints, + unifyCtx, + fst, + QualType(sndUniformParamGroupType->getElementType(), snd.isLeftValue)); - // Each T can coerce with any DeclRefType. - if (auto eachSnd = as(snd)) + // Each T can coerce with any DeclRefType. + if (auto eachSnd = as(snd)) + { + if (auto innerSnd = eachSnd->getElementDeclRefType()) { - if (auto innerSnd = eachSnd->getElementDeclRefType()) + if (auto sndTypePackParamDecl = + as(innerSnd->getDeclRef().getDecl())) { - if (auto sndTypePackParamDecl = as(innerSnd->getDeclRef().getDecl())) + if (innerSnd->getDeclRef().getDecl()->parentDecl == constraints.genericDecl) { - if (innerSnd->getDeclRef().getDecl()->parentDecl == constraints.genericDecl) - { - return TryUnifyTypeParam(constraints, unifyCtx, sndTypePackParamDecl, fst); - } + return TryUnifyTypeParam(constraints, unifyCtx, sndTypePackParamDecl, fst); } } } - if (auto eachFst = as(fst)) + } + if (auto eachFst = as(fst)) + { + if (auto innerFst = eachFst->getElementDeclRefType()) { - if (auto innerFst = eachFst->getElementDeclRefType()) + if (auto fstTypePackParamDecl = + as(innerFst->getDeclRef().getDecl())) { - if (auto fstTypePackParamDecl = as(innerFst->getDeclRef().getDecl())) + if (innerFst->getDeclRef().getDecl()->parentDecl == constraints.genericDecl) { - if (innerFst->getDeclRef().getDecl()->parentDecl == constraints.genericDecl) - { - return TryUnifyTypeParam(constraints, unifyCtx, fstTypePackParamDecl, snd); - } + return TryUnifyTypeParam(constraints, unifyCtx, fstTypePackParamDecl, snd); } } } - return false; } + return false; +} -} +} // namespace Slang diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index 9ca96ee49..7e33158e6 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -12,259 +12,310 @@ namespace Slang { - ConversionCost SemanticsVisitor::getImplicitConversionCost( - Decl* decl) +ConversionCost SemanticsVisitor::getImplicitConversionCost(Decl* decl) +{ + if (auto modifier = decl->findModifier()) { - if(auto modifier = decl->findModifier()) - { - return modifier->cost; - } - - return kConversionCost_Explicit; + return modifier->cost; } - BuiltinConversionKind SemanticsVisitor::getImplicitConversionBuiltinKind( - Decl* decl) - { - if (auto modifier = decl->findModifier()) - { - return modifier->builtinConversionKind; - } - - return kBuiltinConversion_Unknown; - } + return kConversionCost_Explicit; +} - bool SemanticsVisitor::isEffectivelyScalarForInitializerLists( - Type* type) +BuiltinConversionKind SemanticsVisitor::getImplicitConversionBuiltinKind(Decl* decl) +{ + if (auto modifier = decl->findModifier()) { - if(as(type)) return false; - if(as(type)) return false; - if(as(type)) return false; - - if(as(type)) - { - return true; - } + return modifier->builtinConversionKind; + } - if(as(type)) - { - return true; - } - if(as(type)) - { - return true; - } - if(as(type)) - { - return true; - } + return kBuiltinConversion_Unknown; +} - if(auto declRefType = as(type)) - { - if(as(declRefType->getDeclRef())) - return false; - } +bool SemanticsVisitor::isEffectivelyScalarForInitializerLists(Type* type) +{ + if (as(type)) + return false; + if (as(type)) + return false; + if (as(type)) + return false; + if (as(type)) + { return true; } - bool SemanticsVisitor::shouldUseInitializerDirectly( - Type* toType, - Expr* fromExpr) + if (as(type)) { - // A nested initializer list should always be used directly. - // - if(as(fromExpr)) - { - return true; - } - - // If the desired type is a scalar, then we should always initialize - // directly, since it isn't an aggregate. - // - if(isEffectivelyScalarForInitializerLists(toType)) - return true; + return true; + } + if (as(type)) + { + return true; + } + if (as(type)) + { + return true; + } - // If the type we are initializing isn't effectively scalar, - // but the initialization expression *is*, then it doesn't - // seem like direct initialization is intended. - // - if(isEffectivelyScalarForInitializerLists(fromExpr->type)) + if (auto declRefType = as(type)) + { + if (as(declRefType->getDeclRef())) return false; - - // Once the above cases are handled, the main thing - // we want to check for is whether a direct initialization - // is possible (a type conversion exists). - // - return canCoerce(toType, fromExpr->type, fromExpr); } - bool SemanticsVisitor::_readValueFromInitializerList( - Type* toType, - Expr** outToExpr, - InitializerListExpr* fromInitializerListExpr, - UInt &ioInitArgIndex) + return true; +} + +bool SemanticsVisitor::shouldUseInitializerDirectly(Type* toType, Expr* fromExpr) +{ + // A nested initializer list should always be used directly. + // + if (as(fromExpr)) { - // First, we will check if we have run out of arguments - // on the initializer list. - // - UInt initArgCount = fromInitializerListExpr->args.getCount(); - if(ioInitArgIndex >= initArgCount) - { - // If we are at the end of the initializer list, - // then our ability to read an argument depends - // on whether the type we are trying to read - // is default-initializable. - // - // For now, we will just pretend like everything - // is default-initializable and move along. - return true; - } + return true; + } - // Okay, we have at least one initializer list expression, - // so we will look at the next expression and decide - // whether to use it to initialize the desired type - // directly (possibly via casts), or as the first sub-expression - // for aggregate initialization. - // - auto firstInitExpr = fromInitializerListExpr->args[ioInitArgIndex]; - if(shouldUseInitializerDirectly(toType, firstInitExpr)) - { - ioInitArgIndex++; - return _coerce( - CoercionSite::Initializer, - toType, - outToExpr, - firstInitExpr->type, - firstInitExpr, - nullptr); - } + // If the desired type is a scalar, then we should always initialize + // directly, since it isn't an aggregate. + // + if (isEffectivelyScalarForInitializerLists(toType)) + return true; - // If there is somehow an error in one of the initialization - // expressions, then everything could be thrown off and we - // shouldn't keep trying to read arguments. - // - if( IsErrorExpr(firstInitExpr) ) - { - // Stop reading arguments, as if we'd reached - // the end of the list. - // - ioInitArgIndex = initArgCount; - return true; - } + // If the type we are initializing isn't effectively scalar, + // but the initialization expression *is*, then it doesn't + // seem like direct initialization is intended. + // + if (isEffectivelyScalarForInitializerLists(fromExpr->type)) + return false; + + // Once the above cases are handled, the main thing + // we want to check for is whether a direct initialization + // is possible (a type conversion exists). + // + return canCoerce(toType, fromExpr->type, fromExpr); +} - // The fallback case is to recursively read the - // type from the same list as an aggregate. +bool SemanticsVisitor::_readValueFromInitializerList( + Type* toType, + Expr** outToExpr, + InitializerListExpr* fromInitializerListExpr, + UInt& ioInitArgIndex) +{ + // First, we will check if we have run out of arguments + // on the initializer list. + // + UInt initArgCount = fromInitializerListExpr->args.getCount(); + if (ioInitArgIndex >= initArgCount) + { + // If we are at the end of the initializer list, + // then our ability to read an argument depends + // on whether the type we are trying to read + // is default-initializable. // - return _readAggregateValueFromInitializerList( + // For now, we will just pretend like everything + // is default-initializable and move along. + return true; + } + + // Okay, we have at least one initializer list expression, + // so we will look at the next expression and decide + // whether to use it to initialize the desired type + // directly (possibly via casts), or as the first sub-expression + // for aggregate initialization. + // + auto firstInitExpr = fromInitializerListExpr->args[ioInitArgIndex]; + if (shouldUseInitializerDirectly(toType, firstInitExpr)) + { + ioInitArgIndex++; + return _coerce( + CoercionSite::Initializer, toType, outToExpr, - fromInitializerListExpr, - ioInitArgIndex); + firstInitExpr->type, + firstInitExpr, + nullptr); } - DeclRefType* findBaseStructType(ASTBuilder* astBuilder, DeclRef structTypeDeclRef) + // If there is somehow an error in one of the initialization + // expressions, then everything could be thrown off and we + // shouldn't keep trying to read arguments. + // + if (IsErrorExpr(firstInitExpr)) { - auto inheritanceDecl = getMembersOfType(astBuilder, structTypeDeclRef).getFirstOrNull(); - if(!inheritanceDecl) - return nullptr; - - auto baseType = getBaseType(astBuilder, inheritanceDecl); - auto baseDeclRefType = as(baseType); - if(!baseDeclRefType) - return nullptr; + // Stop reading arguments, as if we'd reached + // the end of the list. + // + ioInitArgIndex = initArgCount; + return true; + } - auto baseDeclRef = baseDeclRefType->getDeclRef(); - auto baseStructDeclRef = baseDeclRef.as(); - if(!baseStructDeclRef) - return nullptr; + // The fallback case is to recursively read the + // type from the same list as an aggregate. + // + return _readAggregateValueFromInitializerList( + toType, + outToExpr, + fromInitializerListExpr, + ioInitArgIndex); +} - return baseDeclRefType; - } +DeclRefType* findBaseStructType(ASTBuilder* astBuilder, DeclRef structTypeDeclRef) +{ + auto inheritanceDecl = + getMembersOfType(astBuilder, structTypeDeclRef).getFirstOrNull(); + if (!inheritanceDecl) + return nullptr; + + auto baseType = getBaseType(astBuilder, inheritanceDecl); + auto baseDeclRefType = as(baseType); + if (!baseDeclRefType) + return nullptr; + + auto baseDeclRef = baseDeclRefType->getDeclRef(); + auto baseStructDeclRef = baseDeclRef.as(); + if (!baseStructDeclRef) + return nullptr; + + return baseDeclRefType; +} - DeclRef findBaseStructDeclRef(ASTBuilder* astBuilder, DeclRef structTypeDeclRef) - { - auto inheritanceDecl = getMembersOfType(astBuilder, structTypeDeclRef).getFirstOrNull(); - if (!inheritanceDecl) - return DeclRef(); +DeclRef findBaseStructDeclRef( + ASTBuilder* astBuilder, + DeclRef structTypeDeclRef) +{ + auto inheritanceDecl = + getMembersOfType(astBuilder, structTypeDeclRef).getFirstOrNull(); + if (!inheritanceDecl) + return DeclRef(); + + auto baseType = getBaseType(astBuilder, inheritanceDecl); + auto baseDeclRefType = as(baseType); + if (!baseDeclRefType) + return DeclRef(); + + auto baseDeclRef = baseDeclRefType->getDeclRef(); + auto baseStructDeclRef = baseDeclRef.as(); + if (!baseStructDeclRef) + return DeclRef(); + + return baseStructDeclRef; +} - auto baseType = getBaseType(astBuilder, inheritanceDecl); - auto baseDeclRefType = as(baseType); - if (!baseDeclRefType) - return DeclRef(); +bool SemanticsVisitor::_readAggregateValueFromInitializerList( + Type* inToType, + Expr** outToExpr, + InitializerListExpr* fromInitializerListExpr, + UInt& ioArgIndex) +{ + auto toType = inToType; + UInt argCount = fromInitializerListExpr->args.getCount(); - auto baseDeclRef = baseDeclRefType->getDeclRef(); - auto baseStructDeclRef = baseDeclRef.as(); - if (!baseStructDeclRef) - return DeclRef(); + // In the case where we need to build a result expression, + // we will collect the new arguments here + List coercedArgs; - return baseStructDeclRef; + if (isEffectivelyScalarForInitializerLists(toType)) + { + // For any type that is effectively a non-aggregate, + // we expect to read a single value from the initializer list + // + if (ioArgIndex < argCount) + { + auto arg = fromInitializerListExpr->args[ioArgIndex++]; + return _coerce(CoercionSite::Initializer, toType, outToExpr, arg->type, arg, nullptr); + } + else + { + // If there wasn't an initialization + // expression to be found, then we need + // to perform default initialization here. + // + // We will let this case come through the front-end + // as an `InitializerListExpr` with zero arguments, + // and then have the IR generation logic deal with + // synthesizing default values. + } } - - bool SemanticsVisitor::_readAggregateValueFromInitializerList( - Type* inToType, - Expr** outToExpr, - InitializerListExpr* fromInitializerListExpr, - UInt &ioArgIndex) + else if (auto toVecType = as(toType)) { - auto toType = inToType; - UInt argCount = fromInitializerListExpr->args.getCount(); + auto toElementCount = toVecType->getElementCount(); + auto toElementType = toVecType->getElementType(); - // In the case where we need to build a result expression, - // we will collect the new arguments here - List coercedArgs; - - if(isEffectivelyScalarForInitializerLists(toType)) + UInt elementCount = 0; + if (auto constElementCount = as(toElementCount)) + { + elementCount = (UInt)constElementCount->getValue(); + } + else { - // For any type that is effectively a non-aggregate, - // we expect to read a single value from the initializer list + // We don't know the element count statically, + // so what are we supposed to be doing? // - if(ioArgIndex < argCount) + if (outToExpr) { - auto arg = fromInitializerListExpr->args[ioArgIndex++]; - return _coerce( - CoercionSite::Initializer, - toType, - outToExpr, - arg->type, - arg, - nullptr); + getSink()->diagnose( + fromInitializerListExpr, + Diagnostics::cannotUseInitializerListForVectorOfUnknownSize, + toElementCount); } - else + return false; + } + + for (UInt ee = 0; ee < elementCount; ++ee) + { + Expr* coercedArg = nullptr; + bool argResult = _readValueFromInitializerList( + toElementType, + outToExpr ? &coercedArg : nullptr, + fromInitializerListExpr, + ioArgIndex); + + // No point in trying further if any argument fails + if (!argResult) + return false; + + if (coercedArg) { - // If there wasn't an initialization - // expression to be found, then we need - // to perform default initialization here. - // - // We will let this case come through the front-end - // as an `InitializerListExpr` with zero arguments, - // and then have the IR generation logic deal with - // synthesizing default values. + coercedArgs.add(coercedArg); } } - else if (auto toVecType = as(toType)) + } + else if (auto toArrayType = as(toType)) + { + // TODO(tfoley): If we can compute the size of the array statically, + // then we want to check that there aren't too many initializers present + + auto toElementType = toArrayType->getElementType(); + if (!toArrayType->isUnsized()) { - auto toElementCount = toVecType->getElementCount(); - auto toElementType = toVecType->getElementType(); + auto toElementCount = toArrayType->getElementCount(); + // In the case of a sized array, we need to check that the number + // of elements being initialized matches what was declared. + // UInt elementCount = 0; if (auto constElementCount = as(toElementCount)) { - elementCount = (UInt) constElementCount->getValue(); + elementCount = (UInt)constElementCount->getValue(); } else { // We don't know the element count statically, // so what are we supposed to be doing? // - if(outToExpr) + if (outToExpr) { - getSink()->diagnose(fromInitializerListExpr, Diagnostics::cannotUseInitializerListForVectorOfUnknownSize, toElementCount); + getSink()->diagnose( + fromInitializerListExpr, + Diagnostics::cannotUseInitializerListForArrayOfUnknownSize, + toElementCount); } return false; } - for(UInt ee = 0; ee < elementCount; ++ee) + for (UInt ee = 0; ee < elementCount; ++ee) { Expr* coercedArg = nullptr; bool argResult = _readValueFromInitializerList( @@ -274,1300 +325,1236 @@ namespace Slang ioArgIndex); // No point in trying further if any argument fails - if(!argResult) + if (!argResult) return false; - if( coercedArg ) + if (coercedArg) { coercedArgs.add(coercedArg); } } } - else if(auto toArrayType = as(toType)) + else { - // TODO(tfoley): If we can compute the size of the array statically, - // then we want to check that there aren't too many initializers present - - auto toElementType = toArrayType->getElementType(); - if(!toArrayType->isUnsized()) + // In the case of an unsized array type, we will use the + // number of arguments to the initializer to determine + // the element count. + // + UInt elementCount = 0; + while (ioArgIndex < argCount) { - auto toElementCount = toArrayType->getElementCount(); + Expr* coercedArg = nullptr; + bool argResult = _readValueFromInitializerList( + toElementType, + outToExpr ? &coercedArg : nullptr, + fromInitializerListExpr, + ioArgIndex); - // In the case of a sized array, we need to check that the number - // of elements being initialized matches what was declared. - // - UInt elementCount = 0; - if (auto constElementCount = as(toElementCount)) - { - elementCount = (UInt) constElementCount->getValue(); - } - else - { - // We don't know the element count statically, - // so what are we supposed to be doing? - // - if(outToExpr) - { - getSink()->diagnose(fromInitializerListExpr, Diagnostics::cannotUseInitializerListForArrayOfUnknownSize, toElementCount); - } + // No point in trying further if any argument fails + if (!argResult) return false; - } - - for(UInt ee = 0; ee < elementCount; ++ee) - { - Expr* coercedArg = nullptr; - bool argResult = _readValueFromInitializerList( - toElementType, - outToExpr ? &coercedArg : nullptr, - fromInitializerListExpr, - ioArgIndex); - // No point in trying further if any argument fails - if(!argResult) - return false; + elementCount++; - if( coercedArg ) - { - coercedArgs.add(coercedArg); - } + if (coercedArg) + { + coercedArgs.add(coercedArg); } } - else - { - // In the case of an unsized array type, we will use the - // number of arguments to the initializer to determine - // the element count. - // - UInt elementCount = 0; - while(ioArgIndex < argCount) - { - Expr* coercedArg = nullptr; - bool argResult = _readValueFromInitializerList( - toElementType, - outToExpr ? &coercedArg : nullptr, - fromInitializerListExpr, - ioArgIndex); - // No point in trying further if any argument fails - if(!argResult) - return false; + // We have a new type for the conversion, based on what + // we learned. + toType = m_astBuilder->getArrayType( + toElementType, + m_astBuilder->getIntVal(m_astBuilder->getIntType(), elementCount)); + } + } + else if (auto toMatrixType = as(toType)) + { + // In the general case, the initializer list might comprise + // both vectors and scalars. + // + // The traditional HLSL compilers treat any vectors in + // the initializer list exactly equivalent to their sequence + // of scalar elements, and don't care how this might, or + // might not, align with the rows of the matrix. + // + // We will draw a line in the sand and say that an initializer + // list for a matrix will act as if the matrix type were an + // array of vectors for the rows. - elementCount++; - if( coercedArg ) - { - coercedArgs.add(coercedArg); - } - } + UInt rowCount = 0; + auto toRowType = + createVectorType(toMatrixType->getElementType(), toMatrixType->getColumnCount()); - // We have a new type for the conversion, based on what - // we learned. - toType = m_astBuilder->getArrayType(toElementType, - m_astBuilder->getIntVal(m_astBuilder->getIntType(), elementCount)); - } + if (auto constRowCount = as(toMatrixType->getRowCount())) + { + rowCount = (UInt)constRowCount->getValue(); } - else if(auto toMatrixType = as(toType)) + else { - // In the general case, the initializer list might comprise - // both vectors and scalars. + // We don't know the element count statically, + // so what are we supposed to be doing? // - // The traditional HLSL compilers treat any vectors in - // the initializer list exactly equivalent to their sequence - // of scalar elements, and don't care how this might, or - // might not, align with the rows of the matrix. - // - // We will draw a line in the sand and say that an initializer - // list for a matrix will act as if the matrix type were an - // array of vectors for the rows. - - - UInt rowCount = 0; - auto toRowType = createVectorType( - toMatrixType->getElementType(), - toMatrixType->getColumnCount()); - - if (auto constRowCount = as(toMatrixType->getRowCount())) + if (outToExpr) { - rowCount = (UInt) constRowCount->getValue(); + getSink()->diagnose( + fromInitializerListExpr, + Diagnostics::cannotUseInitializerListForMatrixOfUnknownSize, + toMatrixType->getRowCount()); } - else - { - // We don't know the element count statically, - // so what are we supposed to be doing? - // - if(outToExpr) - { - getSink()->diagnose(fromInitializerListExpr, Diagnostics::cannotUseInitializerListForMatrixOfUnknownSize, toMatrixType->getRowCount()); - } + return false; + } + + for (UInt rr = 0; rr < rowCount; ++rr) + { + Expr* coercedArg = nullptr; + bool argResult = _readValueFromInitializerList( + toRowType, + outToExpr ? &coercedArg : nullptr, + fromInitializerListExpr, + ioArgIndex); + + // No point in trying further if any argument fails + if (!argResult) return false; - } - for(UInt rr = 0; rr < rowCount; ++rr) + if (coercedArg) + { + coercedArgs.add(coercedArg); + } + } + } + else if (auto toDeclRefType = as(toType)) + { + auto toTypeDeclRef = toDeclRefType->getDeclRef(); + if (auto toStructDeclRef = toTypeDeclRef.as()) + { + // Trying to initialize a `struct` type given an initializer list. + // + // Before we iterate over the fields, we want to check if this struct + // inherits from another `struct` type. If so, we want to read + // an initializer for that base type first. + // + if (auto baseStructType = findBaseStructType(m_astBuilder, toStructDeclRef)) { Expr* coercedArg = nullptr; bool argResult = _readValueFromInitializerList( - toRowType, + baseStructType, outToExpr ? &coercedArg : nullptr, fromInitializerListExpr, ioArgIndex); // No point in trying further if any argument fails - if(!argResult) + if (!argResult) return false; - if( coercedArg ) + if (coercedArg) { coercedArgs.add(coercedArg); } } - } - else if(auto toDeclRefType = as(toType)) - { - auto toTypeDeclRef = toDeclRefType->getDeclRef(); - if(auto toStructDeclRef = toTypeDeclRef.as()) - { - // Trying to initialize a `struct` type given an initializer list. - // - // Before we iterate over the fields, we want to check if this struct - // inherits from another `struct` type. If so, we want to read - // an initializer for that base type first. - // - if (auto baseStructType = findBaseStructType(m_astBuilder, toStructDeclRef)) - { - Expr* coercedArg = nullptr; - bool argResult = _readValueFromInitializerList( - baseStructType, - outToExpr ? &coercedArg : nullptr, - fromInitializerListExpr, - ioArgIndex); - // No point in trying further if any argument fails - if (!argResult) - return false; + // We will go through the fields in order and try to match them + // up with initializer arguments. + // + for (auto fieldDeclRef : getMembersOfType( + m_astBuilder, + toStructDeclRef, + MemberFilterStyle::Instance)) + { + Expr* coercedArg = nullptr; + bool argResult = _readValueFromInitializerList( + getType(m_astBuilder, fieldDeclRef), + outToExpr ? &coercedArg : nullptr, + fromInitializerListExpr, + ioArgIndex); - if (coercedArg) - { - coercedArgs.add(coercedArg); - } - } + // No point in trying further if any argument fails + if (!argResult) + return false; - // We will go through the fields in order and try to match them - // up with initializer arguments. - // - for(auto fieldDeclRef : getMembersOfType(m_astBuilder, toStructDeclRef, MemberFilterStyle::Instance)) + if (coercedArg) { - Expr* coercedArg = nullptr; - bool argResult = _readValueFromInitializerList( - getType(m_astBuilder, fieldDeclRef), - outToExpr ? &coercedArg : nullptr, - fromInitializerListExpr, - ioArgIndex); - - // No point in trying further if any argument fails - if(!argResult) - return false; - - if( coercedArg ) - { - coercedArgs.add(coercedArg); - } + coercedArgs.add(coercedArg); } } } - else + } + else + { + // We shouldn't get to this case in practice, + // but just in case we'll consider an initializer + // list invalid if we are trying to read something + // off of it that wasn't handled by the cases above. + // + if (outToExpr) { - // We shouldn't get to this case in practice, - // but just in case we'll consider an initializer - // list invalid if we are trying to read something - // off of it that wasn't handled by the cases above. - // - if(outToExpr) - { - getSink()->diagnose(fromInitializerListExpr, Diagnostics::cannotUseInitializerListForType, inToType); - } - return false; + getSink()->diagnose( + fromInitializerListExpr, + Diagnostics::cannotUseInitializerListForType, + inToType); } + return false; + } - // We were able to coerce all the arguments given, and so - // we need to construct a suitable expression to remember the result + // We were able to coerce all the arguments given, and so + // we need to construct a suitable expression to remember the result + // + if (outToExpr) + { + auto toInitializerListExpr = m_astBuilder->create(); + toInitializerListExpr->loc = fromInitializerListExpr->loc; + toInitializerListExpr->type = QualType(toType); + toInitializerListExpr->args = coercedArgs; + + // Wrap initalizer list args if we're creating a non-differentiable struct within a + // differentiable function. // - if(outToExpr) + if (auto func = getParentFuncOfVisitor()) { - auto toInitializerListExpr = m_astBuilder->create(); - toInitializerListExpr->loc = fromInitializerListExpr->loc; - toInitializerListExpr->type = QualType(toType); - toInitializerListExpr->args = coercedArgs; - - // Wrap initalizer list args if we're creating a non-differentiable struct within a - // differentiable function. - // - if (auto func = getParentFuncOfVisitor()) + if (func->findModifier() && !isTypeDifferentiable(toType)) { - if (func->findModifier() && - !isTypeDifferentiable(toType)) + for (auto& arg : toInitializerListExpr->args) { - for (auto &arg : toInitializerListExpr->args) + if (isTypeDifferentiable(arg->type.type)) { - if (isTypeDifferentiable(arg->type.type)) - { - auto detachedArg = m_astBuilder->create(); - detachedArg->inner = arg; - detachedArg->type = arg->type; - arg = detachedArg; - } + auto detachedArg = m_astBuilder->create(); + detachedArg->inner = arg; + detachedArg->type = arg->type; + arg = detachedArg; } } } - - *outToExpr = toInitializerListExpr; } - return true; + *outToExpr = toInitializerListExpr; } - bool SemanticsVisitor::_coerceInitializerList( - Type* toType, - Expr** outToExpr, - InitializerListExpr* fromInitializerListExpr) - { - UInt argCount = fromInitializerListExpr->args.getCount(); - UInt argIndex = 0; - - // TODO: we should handle the special case of `{0}` as an initializer - // for arbitrary `struct` types here. - - // If this initializer list has a more specific type than just - // InitializerListType (i.e. it's already undergone a coercion) we - // should ensure that we're allowed to coerce from that type to our - // desired type. - // If this isn't prohibited, then we can proceed to try and coerce from - // the initializer list itself; assuming that coercion is closed under - // composition this shouldn't fail. - if(!as(fromInitializerListExpr->type) && - !canCoerce(toType, fromInitializerListExpr->type, nullptr)) - return _failedCoercion(toType, outToExpr, fromInitializerListExpr); - - if(!_readAggregateValueFromInitializerList(toType, outToExpr, fromInitializerListExpr, argIndex)) - return false; - - if(argIndex != argCount) - { - if( outToExpr ) - { - getSink()->diagnose(fromInitializerListExpr, Diagnostics::tooManyInitializers, argIndex, argCount); - } - } + return true; +} - return true; - } +bool SemanticsVisitor::_coerceInitializerList( + Type* toType, + Expr** outToExpr, + InitializerListExpr* fromInitializerListExpr) +{ + UInt argCount = fromInitializerListExpr->args.getCount(); + UInt argIndex = 0; + + // TODO: we should handle the special case of `{0}` as an initializer + // for arbitrary `struct` types here. + + // If this initializer list has a more specific type than just + // InitializerListType (i.e. it's already undergone a coercion) we + // should ensure that we're allowed to coerce from that type to our + // desired type. + // If this isn't prohibited, then we can proceed to try and coerce from + // the initializer list itself; assuming that coercion is closed under + // composition this shouldn't fail. + if (!as(fromInitializerListExpr->type) && + !canCoerce(toType, fromInitializerListExpr->type, nullptr)) + return _failedCoercion(toType, outToExpr, fromInitializerListExpr); + + if (!_readAggregateValueFromInitializerList( + toType, + outToExpr, + fromInitializerListExpr, + argIndex)) + return false; - bool SemanticsVisitor::_failedCoercion( - Type* toType, - Expr** outToExpr, - Expr* fromExpr) + if (argIndex != argCount) { - if(outToExpr) + if (outToExpr) { - // As a special case, if the expression we are trying to convert - // from is overloaded (implying an ambiguous reference), then we - // will try to produce a more appropriately tailored error message. - // - auto fromType = fromExpr->type.type; - if( as(fromType) ) - { - diagnoseAmbiguousReference(fromExpr); - } - else - { - getSink()->diagnose(fromExpr->loc, Diagnostics::typeMismatch, toType, fromExpr->type); - } + getSink()->diagnose( + fromInitializerListExpr, + Diagnostics::tooManyInitializers, + argIndex, + argCount); } - return false; } - /// Do the `left` and `right` modifiers represent the same thing? - static bool _doModifiersMatch(Val* left, Val* right) - { - if( left == right ) - return true; - - if( left->equals(right) ) - return true; - - return false; - } + return true; +} - /// Does `type` have a modifier that matches `modifier`? - static bool _hasMatchingModifier(ModifiedType* type, Val* modifier) +bool SemanticsVisitor::_failedCoercion(Type* toType, Expr** outToExpr, Expr* fromExpr) +{ + if (outToExpr) { - if(!type) return false; - - for (Index m = 0; m < type->getModifierCount(); m++) + // As a special case, if the expression we are trying to convert + // from is overloaded (implying an ambiguous reference), then we + // will try to produce a more appropriately tailored error message. + // + auto fromType = fromExpr->type.type; + if (as(fromType)) { - if(_doModifiersMatch(type->getModifier(m), modifier)) - return true; + diagnoseAmbiguousReference(fromExpr); + } + else + { + getSink()->diagnose(fromExpr->loc, Diagnostics::typeMismatch, toType, fromExpr->type); } + } + return false; +} + +/// Do the `left` and `right` modifiers represent the same thing? +static bool _doModifiersMatch(Val* left, Val* right) +{ + if (left == right) + return true; + + if (left->equals(right)) + return true; + + return false; +} +/// Does `type` have a modifier that matches `modifier`? +static bool _hasMatchingModifier(ModifiedType* type, Val* modifier) +{ + if (!type) return false; - } - /// Can `modifier` be added to a type as part of a coercion? - /// - /// For example, it is generally safe to convert from a value - /// of type `T` to a value of type `const T` in C/C++. - /// - static bool _canModifierBeAddedDuringCoercion(Val* modifier) + for (Index m = 0; m < type->getModifierCount(); m++) { - switch( modifier->astNodeType ) - { - default: - return false; - - case ASTNodeType::UNormModifierVal: - case ASTNodeType::SNormModifierVal: - case ASTNodeType::NoDiffModifierVal: + if (_doModifiersMatch(type->getModifier(m), modifier)) return true; - } } - /// Can `modifier` be dropped from a type as part of a coercion? - /// - /// For example, it is generally safe to convert from a value - /// of type `const T` to a value of type `T` in C/C++. - /// - static bool _canModifierBeDroppedDuringCoercion(Val* modifier) + return false; +} + +/// Can `modifier` be added to a type as part of a coercion? +/// +/// For example, it is generally safe to convert from a value +/// of type `T` to a value of type `const T` in C/C++. +/// +static bool _canModifierBeAddedDuringCoercion(Val* modifier) +{ + switch (modifier->astNodeType) { - switch( modifier->astNodeType ) - { - default: - return false; + default: return false; - case ASTNodeType::UNormModifierVal: - case ASTNodeType::SNormModifierVal: - case ASTNodeType::NoDiffModifierVal: - return true; - } + case ASTNodeType::UNormModifierVal: + case ASTNodeType::SNormModifierVal: + case ASTNodeType::NoDiffModifierVal: return true; } +} - static bool isSigned(Type* t) +/// Can `modifier` be dropped from a type as part of a coercion? +/// +/// For example, it is generally safe to convert from a value +/// of type `const T` to a value of type `T` in C/C++. +/// +static bool _canModifierBeDroppedDuringCoercion(Val* modifier) +{ + switch (modifier->astNodeType) { - auto basicType = as(t); - if (!basicType) return false; - switch (basicType->getBaseType()) - { - case BaseType::Int8: - case BaseType::Int16: - case BaseType::Int: - case BaseType::Int64: - case BaseType::IntPtr: - return true; - default: - return false; - } + default: return false; + + case ASTNodeType::UNormModifierVal: + case ASTNodeType::SNormModifierVal: + case ASTNodeType::NoDiffModifierVal: return true; } +} - int getTypeBitSize(Type* t) +static bool isSigned(Type* t) +{ + auto basicType = as(t); + if (!basicType) + return false; + switch (basicType->getBaseType()) { - auto basicType = as(t); - if (!basicType) return 0; + case BaseType::Int8: + case BaseType::Int16: + case BaseType::Int: + case BaseType::Int64: + case BaseType::IntPtr: return true; + default: return false; + } +} - switch (basicType->getBaseType()) - { - case BaseType::Int8: - case BaseType::UInt8: - return 8; - case BaseType::Int16: - case BaseType::UInt16: - return 16; - case BaseType::Int: - case BaseType::UInt: - return 32; - case BaseType::Int64: - case BaseType::UInt64: - return 64; - case BaseType::IntPtr: - case BaseType::UIntPtr: +int getTypeBitSize(Type* t) +{ + auto basicType = as(t); + if (!basicType) + return 0; + + switch (basicType->getBaseType()) + { + case BaseType::Int8: + case BaseType::UInt8: return 8; + case BaseType::Int16: + case BaseType::UInt16: return 16; + case BaseType::Int: + case BaseType::UInt: return 32; + case BaseType::Int64: + case BaseType::UInt64: return 64; + case BaseType::IntPtr: + case BaseType::UIntPtr: #if SLANG_PTR_IS_32 - return 32; + return 32; #else - return 64; + return 64; #endif - default: - return 0; - } + default: return 0; } +} - ConversionCost SemanticsVisitor::getImplicitConversionCostWithKnownArg(Decl* decl, Type* toType, Expr* arg) - { - ConversionCost candidateCost = getImplicitConversionCost(decl); +ConversionCost SemanticsVisitor::getImplicitConversionCostWithKnownArg( + Decl* decl, + Type* toType, + Expr* arg) +{ + ConversionCost candidateCost = getImplicitConversionCost(decl); - // Fix up the cost if the operand is a const lit. - if (isScalarIntegerType(toType)) + // Fix up the cost if the operand is a const lit. + if (isScalarIntegerType(toType)) + { + auto knownVal = as(arg); + if (!knownVal) + return candidateCost; + if (getIntValueBitSize(knownVal->value) <= getTypeBitSize(toType)) { - auto knownVal = as(arg); - if (!knownVal) - return candidateCost; - if (getIntValueBitSize(knownVal->value) <= getTypeBitSize(toType)) - { - bool toTypeIsSigned = isSigned(toType); - bool fromTypeIsSigned = isSigned(knownVal->type); - if (toTypeIsSigned == fromTypeIsSigned) - candidateCost = kConversionCost_InRangeIntLitConversion; - else if (toTypeIsSigned) - candidateCost = kConversionCost_InRangeIntLitUnsignedToSignedConversion; - else - candidateCost = kConversionCost_InRangeIntLitSignedToUnsignedConversion; - } + bool toTypeIsSigned = isSigned(toType); + bool fromTypeIsSigned = isSigned(knownVal->type); + if (toTypeIsSigned == fromTypeIsSigned) + candidateCost = kConversionCost_InRangeIntLitConversion; + else if (toTypeIsSigned) + candidateCost = kConversionCost_InRangeIntLitUnsignedToSignedConversion; + else + candidateCost = kConversionCost_InRangeIntLitSignedToUnsignedConversion; } - return candidateCost; } + return candidateCost; +} - bool SemanticsVisitor::_coerce( - CoercionSite site, - Type* toType, - Expr** outToExpr, - QualType fromType, - Expr* fromExpr, - ConversionCost* outCost) +bool SemanticsVisitor::_coerce( + CoercionSite site, + Type* toType, + Expr** outToExpr, + QualType fromType, + Expr* fromExpr, + ConversionCost* outCost) +{ + // If we are about to try and coerce an overloaded expression, + // then we should start by trying to resolve the ambiguous reference + // based on prioritization of the different candidates. + // + // TODO: A more powerful model would be to try to coerce each + // of the constituent overload candidates, filtering down to + // those that are coercible, and then disambiguating the result. + // Such an approach would let us disambiguate between overloaded + // symbols based on their type (e.g., by casting the name of + // an overloaded function to the type of the overload we mean + // to reference). + // + if (auto fromOverloadedExpr = as(fromExpr)) { - // If we are about to try and coerce an overloaded expression, - // then we should start by trying to resolve the ambiguous reference - // based on prioritization of the different candidates. - // - // TODO: A more powerful model would be to try to coerce each - // of the constituent overload candidates, filtering down to - // those that are coercible, and then disambiguating the result. - // Such an approach would let us disambiguate between overloaded - // symbols based on their type (e.g., by casting the name of - // an overloaded function to the type of the overload we mean - // to reference). - // - if( auto fromOverloadedExpr = as(fromExpr) ) - { - auto resolvedExpr = maybeResolveOverloadedExpr(fromOverloadedExpr, LookupMask::Default, nullptr); + auto resolvedExpr = + maybeResolveOverloadedExpr(fromOverloadedExpr, LookupMask::Default, nullptr); - fromExpr = resolvedExpr; - fromType = resolvedExpr->type; - } + fromExpr = resolvedExpr; + fromType = resolvedExpr->type; + } - // An important and easy case is when the "to" and "from" types are equal. - // - if( toType->equals(fromType) ) - { - if(outToExpr) - *outToExpr = fromExpr; - if(outCost) - *outCost = kConversionCost_None; - return true; - } + // An important and easy case is when the "to" and "from" types are equal. + // + if (toType->equals(fromType)) + { + if (outToExpr) + *outToExpr = fromExpr; + if (outCost) + *outCost = kConversionCost_None; + return true; + } - // If both are string types we assume they are convertable in both directions - if (as(fromType) && as(toType)) - { - if (outToExpr) - *outToExpr = fromExpr; - if (outCost) - *outCost = kConversionCost_None; - return true; - } + // If both are string types we assume they are convertable in both directions + if (as(fromType) && as(toType)) + { + if (outToExpr) + *outToExpr = fromExpr; + if (outCost) + *outCost = kConversionCost_None; + return true; + } - // Allow implicit conversion from sized array to unsized array when - // calling a function. - // Note: we implement the logic here instead of an implicit_conversion - // intrinsic in the core module because we only want to allow this conversion - // when calling a function. - // - if (site == CoercionSite::Argument) + // Allow implicit conversion from sized array to unsized array when + // calling a function. + // Note: we implement the logic here instead of an implicit_conversion + // intrinsic in the core module because we only want to allow this conversion + // when calling a function. + // + if (site == CoercionSite::Argument) + { + if (auto fromArrayType = as(fromType)) { - if (auto fromArrayType = as(fromType)) + if (auto toArrayType = as(toType)) { - if (auto toArrayType = as(toType)) + if (fromArrayType->getElementType()->equals(toArrayType->getElementType()) && + toArrayType->isUnsized()) { - if (fromArrayType->getElementType()->equals(toArrayType->getElementType()) - && toArrayType->isUnsized()) - { - if (outToExpr) - *outToExpr = fromExpr; - if (outCost) - *outCost = kConversionCost_SizedArrayToUnsizedArray; - return true; - } + if (outToExpr) + *outToExpr = fromExpr; + if (outCost) + *outCost = kConversionCost_SizedArrayToUnsizedArray; + return true; } } } + } + + // Another important case is when either the "to" or "from" type + // represents an error. In such a case we must have already + // reporeted the error, so it is better to allow the conversion + // to pass than to report a "cascading" error that might not + // make any sense. + // + if (as(toType) || as(fromType)) + { + if (outToExpr) + *outToExpr = CreateImplicitCastExpr(toType, fromExpr); + if (outCost) + *outCost = kConversionCost_None; + return true; + } - // Another important case is when either the "to" or "from" type - // represents an error. In such a case we must have already - // reporeted the error, so it is better to allow the conversion - // to pass than to report a "cascading" error that might not - // make any sense. + { + // It is possible that one or more of the types involved might have modifiers + // on it, but the underlying types are otherwise the same. // - if(as(toType) || as(fromType)) - { - if(outToExpr) - *outToExpr = CreateImplicitCastExpr(toType, fromExpr); - if(outCost) - *outCost = kConversionCost_None; - return true; - } + auto toModified = as(toType); + auto toBase = toModified ? toModified->getBase() : toType; + // + auto fromModified = as(fromType); + auto fromBase = + fromModified ? QualType(fromModified->getBase(), fromType.isLeftValue) : fromType; + + if ((toModified || fromModified) && toBase->equals(fromBase)) { - // It is possible that one or more of the types involved might have modifiers - // on it, but the underlying types are otherwise the same. + // We need to check each modifier present on either `toType` + // or `fromType`. For each modifier, it will either be: // - auto toModified = as(toType); - auto toBase = toModified ? toModified->getBase() : toType; + // * Present on both types; these are a non-issue + // * Present only on `toType` + // * Present only on `fromType` // - auto fromModified = as(fromType); - auto fromBase = fromModified ? QualType(fromModified->getBase(), fromType.isLeftValue) : fromType; - - - if((toModified || fromModified) && toBase->equals(fromBase)) + if (toModified) { - // We need to check each modifier present on either `toType` - // or `fromType`. For each modifier, it will either be: - // - // * Present on both types; these are a non-issue - // * Present only on `toType` - // * Present only on `fromType` - // - if( toModified ) + for (Index m = 0; m < toModified->getModifierCount(); m++) { - for (Index m = 0; m < toModified->getModifierCount(); m++) + auto modifier = toModified->getModifier(m); + if (_hasMatchingModifier(fromModified, modifier)) + continue; + + // If `modifier` is present on `toType`, but not `fromType`, + // then we need to know whether this modifier can be added + // to the type of an expression as part of coercion. + // + if (!_canModifierBeAddedDuringCoercion(modifier)) { - auto modifier = toModified->getModifier(m); - if(_hasMatchingModifier(fromModified, modifier)) - continue; - - // If `modifier` is present on `toType`, but not `fromType`, - // then we need to know whether this modifier can be added - // to the type of an expression as part of coercion. - // - if( !_canModifierBeAddedDuringCoercion(modifier) ) - { - return _failedCoercion(toType, outToExpr, fromExpr); - } + return _failedCoercion(toType, outToExpr, fromExpr); } } - if( fromModified ) + } + if (fromModified) + { + for (Index m = 0; m < fromModified->getModifierCount(); m++) { - for (Index m = 0; m < fromModified->getModifierCount(); m++) - { - auto modifier = fromModified->getModifier(m); + auto modifier = fromModified->getModifier(m); - if(_hasMatchingModifier(toModified, modifier)) - continue; + if (_hasMatchingModifier(toModified, modifier)) + continue; - // If `modifier` is present on `fromType`, but not `toType`, - // then we need to know whether this modifier can be dropped - // to the type of an expression as part of coercion. - // - if( !_canModifierBeDroppedDuringCoercion(modifier) ) - { - return _failedCoercion(toType, outToExpr, fromExpr); - } + // If `modifier` is present on `fromType`, but not `toType`, + // then we need to know whether this modifier can be dropped + // to the type of an expression as part of coercion. + // + if (!_canModifierBeDroppedDuringCoercion(modifier)) + { + return _failedCoercion(toType, outToExpr, fromExpr); } } + } - // If all the modifiers were okay, we can convert. - - // TODO: we may need a cost to allow disambiguation of overloads based on modifiers? - if(outCost) - { - *outCost = kConversionCost_None; - } - if( outToExpr ) - { - *outToExpr = createModifierCastExpr(toType, fromExpr); - } + // If all the modifiers were okay, we can convert. - return true; + // TODO: we may need a cost to allow disambiguation of overloads based on modifiers? + if (outCost) + { + *outCost = kConversionCost_None; } - } - - // Coercion from an initializer list is allowed for many types, - // so we will farm that out to its own subroutine. - // - if (fromExpr && as(fromExpr->type.type)) - { - if (auto fromInitializerListExpr = as(fromExpr)) + if (outToExpr) { - if (!_coerceInitializerList( - toType, - outToExpr, - fromInitializerListExpr)) - { - return false; - } - - // For now, we treat coercion from an initializer list - // as having no cost, so that all conversions from initializer - // lists are equally valid. This is fine given where initializer - // lists are allowed to appear now, but might need to be made - // more strict if we allow for initializer lists in more - // places in the language (e.g., as function arguments). - // - if (outCost) - { - *outCost = kConversionCost_None; - } - - return true; + *outToExpr = createModifierCastExpr(toType, fromExpr); } + + return true; } + } - // nullptr_t can be cast into any pointer type. - if (as(fromType) && as(toType)) + // Coercion from an initializer list is allowed for many types, + // so we will farm that out to its own subroutine. + // + if (fromExpr && as(fromExpr->type.type)) + { + if (auto fromInitializerListExpr = as(fromExpr)) { - if (outCost) + if (!_coerceInitializerList(toType, outToExpr, fromInitializerListExpr)) { - *outCost = kConversionCost_NullPtrToPtr; + return false; } - if (outToExpr) + + // For now, we treat coercion from an initializer list + // as having no cost, so that all conversions from initializer + // lists are equally valid. This is fine given where initializer + // lists are allowed to appear now, but might need to be made + // more strict if we allow for initializer lists in more + // places in the language (e.g., as function arguments). + // + if (outCost) { - auto* defaultExpr = getASTBuilder()->create(); - defaultExpr->type = QualType(toType); - *outToExpr = defaultExpr; + *outCost = kConversionCost_None; } + return true; } - // none_t can be cast into any Optional type. - if (as(fromType) && as(toType)) + } + + // nullptr_t can be cast into any pointer type. + if (as(fromType) && as(toType)) + { + if (outCost) + { + *outCost = kConversionCost_NullPtrToPtr; + } + if (outToExpr) + { + auto* defaultExpr = getASTBuilder()->create(); + defaultExpr->type = QualType(toType); + *outToExpr = defaultExpr; + } + return true; + } + // none_t can be cast into any Optional type. + if (as(fromType) && as(toType)) + { + if (outCost) + { + *outCost = kConversionCost_NoneToOptional; + } + if (outToExpr) + { + auto resultExpr = getASTBuilder()->create(); + resultExpr->loc = fromExpr->loc; + resultExpr->type = toType; + *outToExpr = resultExpr; + } + return true; + } + + // A enum type can be converted into its underlying tag type. + if (auto enumDecl = isEnumType(fromType)) + { + Type* tagType = enumDecl->tagType; + if (tagType == toType) { if (outCost) { - *outCost = kConversionCost_NoneToOptional; + *outCost = kConversionCost_RankPromotion; } if (outToExpr) { - auto resultExpr = getASTBuilder()->create(); - resultExpr->loc = fromExpr->loc; - resultExpr->type = toType; - *outToExpr = resultExpr; + auto rsExpr = getASTBuilder()->create(); + rsExpr->type = toType; + rsExpr->loc = fromExpr->loc; + rsExpr->base = fromExpr; + *outToExpr = rsExpr; } return true; } + } - // A enum type can be converted into its underlying tag type. - if (auto enumDecl = isEnumType(fromType)) + // matrix type with different layouts are convertible + if (auto fromMatrixType = as(fromType)) + { + if (auto toMatrixType = as(toType)) { - Type* tagType = enumDecl->tagType; - if (tagType == toType) + if (fromMatrixType->getElementType()->equals(toMatrixType->getElementType()) && + fromMatrixType->getRowCount()->equals(toMatrixType->getRowCount()) && + fromMatrixType->getColumnCount()->equals(toMatrixType->getColumnCount())) { if (outCost) { - *outCost = kConversionCost_RankPromotion; + *outCost = kConversionCost_MatrixLayout; } if (outToExpr) { - auto rsExpr = getASTBuilder()->create(); - rsExpr->type = toType; - rsExpr->loc = fromExpr->loc; - rsExpr->base = fromExpr; - *outToExpr = rsExpr; + *outToExpr = fromExpr; } return true; } } + } - // matrix type with different layouts are convertible - if (auto fromMatrixType = as(fromType)) + // A type is always convertible to any of its supertypes. + // + if (auto witness = tryGetSubtypeWitness(fromType, toType)) + { + if (outToExpr) { - if (auto toMatrixType = as(toType)) + *outToExpr = createCastToSuperTypeExpr(toType, fromExpr, witness); + + // If the original expression was an l-value, then the result + // of the cast may be an l-value itself. We want to be able + // to invoke `[mutating]` methods on a value that is cast to + // an interface it conforms to, and we also expect to be able + // to pass a value of a derived `struct` type into methods that + // expect a value of its base type. + // + if (fromExpr && fromExpr->type.isLeftValue) { - if (fromMatrixType->getElementType()->equals(toMatrixType->getElementType()) && - fromMatrixType->getRowCount()->equals(toMatrixType->getRowCount()) && - fromMatrixType->getColumnCount()->equals(toMatrixType->getColumnCount())) - { - if (outCost) - { - *outCost = kConversionCost_MatrixLayout; - } - if (outToExpr) - { - *outToExpr = fromExpr; - } - return true; - } + // If the original type is a concrete type and toType is an interface type, + // we need to wrap the original expression into a MakeExistential, and the + // result of MakeExistential is not an l-value. + bool toTypeIsInterface = isInterfaceType(toType); + bool fromTypeIsInterface = isInterfaceType(fromType); + if (!toTypeIsInterface || toTypeIsInterface == fromTypeIsInterface) + (*outToExpr)->type.isLeftValue = true; } - } - - // A type is always convertible to any of its supertypes. + if (outCost) + *outCost = kConversionCost_CastToInterface; + return true; + } + else if (auto fromIsToWitness = tryGetSubtypeWitness(toType, fromType)) + { + // Is toType and fromType the same via some type equality witness? + // If so there is no need to do any conversion. // - if(auto witness = tryGetSubtypeWitness(fromType, toType)) + if (isTypeEqualityWitness(fromIsToWitness)) { if (outToExpr) { - *outToExpr = createCastToSuperTypeExpr(toType, fromExpr, witness); - - // If the original expression was an l-value, then the result - // of the cast may be an l-value itself. We want to be able - // to invoke `[mutating]` methods on a value that is cast to - // an interface it conforms to, and we also expect to be able - // to pass a value of a derived `struct` type into methods that - // expect a value of its base type. - // - if (fromExpr && fromExpr->type.isLeftValue) - { - // If the original type is a concrete type and toType is an interface type, - // we need to wrap the original expression into a MakeExistential, and the - // result of MakeExistential is not an l-value. - bool toTypeIsInterface = isInterfaceType(toType); - bool fromTypeIsInterface = isInterfaceType(fromType); - if (!toTypeIsInterface || toTypeIsInterface == fromTypeIsInterface) - (*outToExpr)->type.isLeftValue = true; - } + *outToExpr = createCastToSuperTypeExpr(toType, fromExpr, fromIsToWitness); } if (outCost) - *outCost = kConversionCost_CastToInterface; + *outCost = 0; return true; } - else if (auto fromIsToWitness = tryGetSubtypeWitness(toType, fromType)) + } + + // Disallow converting to a ParameterGroupType. + // + // TODO(tfoley): Under what circumstances would this check ever be needed? + // + if (as(toType)) + { + return _failedCoercion(toType, outToExpr, fromExpr); + } + + // We allow implicit conversion of a parameter group type like + // `ConstantBuffer` or `ParameterBlock` to its element + // type `X`. + // + if (auto fromParameterGroupType = as(fromType)) + { + auto fromElementType = fromParameterGroupType->getElementType(); + + // If we convert, e.g., `ConstantBuffer to `A`, we will allow + // subsequent conversion of `A` to `B` if such a conversion + // is possible. + // + ConversionCost subCost = kConversionCost_None; + + DerefExpr* derefExpr = nullptr; + if (outToExpr) { - // Is toType and fromType the same via some type equality witness? - // If so there is no need to do any conversion. - // - if (isTypeEqualityWitness(fromIsToWitness)) - { - if (outToExpr) - { - *outToExpr = createCastToSuperTypeExpr(toType, fromExpr, fromIsToWitness); - } - if (outCost) - *outCost = 0; - return true; - } + derefExpr = m_astBuilder->create(); + derefExpr->base = fromExpr; + derefExpr->type = QualType(fromElementType); } - // Disallow converting to a ParameterGroupType. - // - // TODO(tfoley): Under what circumstances would this check ever be needed? - // - if (as(toType)) + if (!_coerce(site, toType, outToExpr, fromElementType, derefExpr, &subCost)) { - return _failedCoercion(toType, outToExpr, fromExpr); + return false; } - // We allow implicit conversion of a parameter group type like - // `ConstantBuffer` or `ParameterBlock` to its element - // type `X`. - // - if(auto fromParameterGroupType = as(fromType)) + if (outCost) + *outCost = subCost + kConversionCost_ImplicitDereference; + return true; + } + + if (auto refType = as(toType)) + { + ConversionCost cost; + if (!canCoerce(refType->getValueType(), fromType, fromExpr, &cost)) + return false; + if (as(toType) && !fromExpr->type.isLeftValue) + return false; + ConversionCost subCost = kConversionCost_GetRef; + + MakeRefExpr* refExpr = nullptr; + if (outToExpr) { - auto fromElementType = fromParameterGroupType->getElementType(); + refExpr = m_astBuilder->create(); + refExpr->base = fromExpr; + refExpr->type = QualType(refType); + refExpr->type.isLeftValue = false; + *outToExpr = refExpr; + } + if (outCost) + *outCost = subCost; + return true; + } - // If we convert, e.g., `ConstantBuffer to `A`, we will allow - // subsequent conversion of `A` to `B` if such a conversion - // is possible. - // - ConversionCost subCost = kConversionCost_None; - DerefExpr* derefExpr = nullptr; - if(outToExpr) - { - derefExpr = m_astBuilder->create(); - derefExpr->base = fromExpr; - derefExpr->type = QualType(fromElementType); - } + // Allow implicit dereferencing a reference type. + if (auto fromRefType = as(fromType)) + { + auto fromValueType = fromRefType->getValueType(); - if(!_coerce( - site, - toType, - outToExpr, - fromElementType, - derefExpr, - &subCost)) - { - return false; - } + // If we convert, e.g., `ConstantBuffer to `A`, we will allow + // subsequent conversion of `A` to `B` if such a conversion + // is possible. + // + ConversionCost subCost = kConversionCost_None; - if(outCost) - *outCost = subCost + kConversionCost_ImplicitDereference; - return true; + Expr* openRefExpr = nullptr; + if (outToExpr) + { + openRefExpr = maybeOpenRef(fromExpr); } - if (auto refType = as(toType)) + if (!_coerce(site, toType, outToExpr, fromValueType, openRefExpr, &subCost)) { - ConversionCost cost; - if (!canCoerce(refType->getValueType(), fromType, fromExpr, &cost)) - return false; - if (as(toType) && !fromExpr->type.isLeftValue) - return false; - ConversionCost subCost = kConversionCost_GetRef; - - MakeRefExpr* refExpr = nullptr; - if (outToExpr) - { - refExpr = m_astBuilder->create(); - refExpr->base = fromExpr; - refExpr->type = QualType(refType); - refExpr->type.isLeftValue = false; - *outToExpr = refExpr; - } - if (outCost) - *outCost = subCost; - return true; + return false; } + if (outCost) + *outCost = subCost + kConversionCost_ImplicitDereference; + return true; + } - // Allow implicit dereferencing a reference type. - if (auto fromRefType = as(fromType)) - { - auto fromValueType = fromRefType->getValueType(); - // If we convert, e.g., `ConstantBuffer to `A`, we will allow - // subsequent conversion of `A` to `B` if such a conversion - // is possible. - // - ConversionCost subCost = kConversionCost_None; + // The main general-purpose approach for conversion is + // using suitable marked initializer ("constructor") + // declarations on the target type. + // + // This is treated as a form of overload resolution, + // since we are effectively forming an overloaded + // call to one of the initializers in the target type. + + OverloadResolveContext overloadContext; + overloadContext.disallowNestedConversions = (site != CoercionSite::ExplicitCoercion); + overloadContext.argCount = 1; + List args; + args.add(fromExpr); + overloadContext.argTypes = &fromType.type; + overloadContext.args = &args; + overloadContext.sourceScope = m_outerScope; + overloadContext.originalExpr = nullptr; + if (fromExpr) + { + overloadContext.loc = fromExpr->loc; + overloadContext.funcLoc = fromExpr->loc; + } - Expr* openRefExpr = nullptr; - if (outToExpr) - { - openRefExpr = maybeOpenRef(fromExpr); - } + overloadContext.baseExpr = nullptr; + overloadContext.mode = OverloadResolveContext::Mode::JustTrying; - if (!_coerce( - site, - toType, - outToExpr, - fromValueType, - openRefExpr, - &subCost)) - { - return false; - } + // Since the lookup and resolution of all possible implicit conversions + // can be very costly, we use a cache to store the checking results. + ImplicitCastMethodKey implicitCastKey = ImplicitCastMethodKey(fromType, toType, fromExpr); + ImplicitCastMethod* cachedMethod = getShared()->tryGetImplicitCastMethod(implicitCastKey); + if (cachedMethod) + { + if (cachedMethod->conversionFuncOverloadCandidate.status != + OverloadCandidate::Status::Applicable) + { + return _failedCoercion(toType, outToExpr, fromExpr); + } + overloadContext.bestCandidateStorage = cachedMethod->conversionFuncOverloadCandidate; + overloadContext.bestCandidate = &overloadContext.bestCandidateStorage; + if (!outToExpr) + { + // If we are not requesting to create an expression, we can return early. if (outCost) - *outCost = subCost + kConversionCost_ImplicitDereference; + *outCost = cachedMethod->cost; return true; } - - - // The main general-purpose approach for conversion is - // using suitable marked initializer ("constructor") - // declarations on the target type. - // - // This is treated as a form of overload resolution, - // since we are effectively forming an overloaded - // call to one of the initializers in the target type. - - OverloadResolveContext overloadContext; - overloadContext.disallowNestedConversions = (site != CoercionSite::ExplicitCoercion); - overloadContext.argCount = 1; - List args; - args.add(fromExpr); - overloadContext.argTypes = &fromType.type; - overloadContext.args = &args; - overloadContext.sourceScope = m_outerScope; - overloadContext.originalExpr = nullptr; - if(fromExpr) + else { - overloadContext.loc = fromExpr->loc; - overloadContext.funcLoc = fromExpr->loc; + if (cachedMethod->isAmbiguous) + { + overloadContext.bestCandidate = nullptr; + overloadContext.bestCandidates.add(cachedMethod->conversionFuncOverloadCandidate); + } } + } - overloadContext.baseExpr = nullptr; - overloadContext.mode = OverloadResolveContext::Mode::JustTrying; - - // Since the lookup and resolution of all possible implicit conversions - // can be very costly, we use a cache to store the checking results. - ImplicitCastMethodKey implicitCastKey = ImplicitCastMethodKey(fromType, toType, fromExpr); - ImplicitCastMethod* cachedMethod = getShared()->tryGetImplicitCastMethod(implicitCastKey); + if (!overloadContext.bestCandidate) + { + AddTypeOverloadCandidates(toType, overloadContext); + } - if (cachedMethod) + // After all of the overload candidates have been added + // to the context and processed, we need to see whether + // there was one best overload or not. + // + if (overloadContext.bestCandidates.getCount() != 0) + { + // In this case there were multiple equally-good candidates to call. + // + // We will start by checking if the candidates are + // even applicable, because if not, then we shouldn't + // consider the conversion as possible. + // + if (overloadContext.bestCandidates[0].status != OverloadCandidate::Status::Applicable) { - if (cachedMethod->conversionFuncOverloadCandidate.status != OverloadCandidate::Status::Applicable) - { - return _failedCoercion(toType, outToExpr, fromExpr); - } - overloadContext.bestCandidateStorage = cachedMethod->conversionFuncOverloadCandidate; - overloadContext.bestCandidate = &overloadContext.bestCandidateStorage; - if (!outToExpr) - { - // If we are not requesting to create an expression, we can return early. - if (outCost) - *outCost = cachedMethod->cost; - return true; - } - else + if (!cachedMethod) { - if (cachedMethod->isAmbiguous) - { - overloadContext.bestCandidate = nullptr; - overloadContext.bestCandidates.add(cachedMethod->conversionFuncOverloadCandidate); - } + getShared()->cacheImplicitCastMethod(implicitCastKey, ImplicitCastMethod{}); } + return _failedCoercion(toType, outToExpr, fromExpr); } - if (!overloadContext.bestCandidate) - { - AddTypeOverloadCandidates(toType, overloadContext); - } - - // After all of the overload candidates have been added - // to the context and processed, we need to see whether - // there was one best overload or not. + // If all of the candidates in `bestCandidates` are applicable, + // then we have an ambiguity. + // + // We will compute a nominal conversion cost as the minimum over + // all the conversions available. // - if(overloadContext.bestCandidates.getCount() != 0) + ConversionCost bestCost = kConversionCost_Explicit; + ImplicitCastMethod method; + for (auto candidate : overloadContext.bestCandidates) { - // In this case there were multiple equally-good candidates to call. - // - // We will start by checking if the candidates are - // even applicable, because if not, then we shouldn't - // consider the conversion as possible. - // - if (overloadContext.bestCandidates[0].status != OverloadCandidate::Status::Applicable) + ConversionCost candidateCost = getImplicitConversionCostWithKnownArg( + candidate.item.declRef.getDecl(), + toType, + fromExpr); + if (candidateCost < bestCost) { - if (!cachedMethod) - { - getShared()->cacheImplicitCastMethod(implicitCastKey, ImplicitCastMethod{}); - } - return _failedCoercion(toType, outToExpr, fromExpr); + method.conversionFuncOverloadCandidate = candidate; + bestCost = candidateCost; } + } - // If all of the candidates in `bestCandidates` are applicable, - // then we have an ambiguity. - // - // We will compute a nominal conversion cost as the minimum over - // all the conversions available. - // - ConversionCost bestCost = kConversionCost_Explicit; - ImplicitCastMethod method; - for(auto candidate : overloadContext.bestCandidates) - { - ConversionCost candidateCost = getImplicitConversionCostWithKnownArg( - candidate.item.declRef.getDecl(), toType, fromExpr); - if (candidateCost < bestCost) - { - method.conversionFuncOverloadCandidate = candidate; - bestCost = candidateCost; - } - } + // Conceptually, we want to treat the conversion as + // possible, but report it as ambiguous if we actually + // need to reify the result as an expression. + // + if (outToExpr) + { + getSink()->diagnose(fromExpr, Diagnostics::ambiguousConversion, fromType, toType); - // Conceptually, we want to treat the conversion as - // possible, but report it as ambiguous if we actually - // need to reify the result as an expression. - // - if(outToExpr) - { - getSink()->diagnose(fromExpr, Diagnostics::ambiguousConversion, fromType, toType); + *outToExpr = CreateErrorExpr(fromExpr); + } - *outToExpr = CreateErrorExpr(fromExpr); - } + if (!cachedMethod) + { + method.isAmbiguous = true; + method.cost = bestCost; + getShared()->cacheImplicitCastMethod(implicitCastKey, method); + } + if (outCost) + *outCost = bestCost; + return true; + } + else if (overloadContext.bestCandidate) + { + // If there is a single best candidate for conversion, + // then we want to use it. + // + // It is possible that there was a single best candidate, + // but it wasn't actually usable, so we will check for + // that case first. + // + if (overloadContext.bestCandidate->status != OverloadCandidate::Status::Applicable) + { if (!cachedMethod) { - method.isAmbiguous = true; - method.cost = bestCost; - getShared()->cacheImplicitCastMethod(implicitCastKey, method); + getShared()->cacheImplicitCastMethod(implicitCastKey, ImplicitCastMethod{}); } - - if(outCost) - *outCost = bestCost; - return true; + return _failedCoercion(toType, outToExpr, fromExpr); } - else if(overloadContext.bestCandidate) + + // Next, we need to look at the implicit conversion + // cost associated with the initializer we are invoking. + // + ConversionCost cost = getImplicitConversionCostWithKnownArg( + overloadContext.bestCandidate->item.declRef.getDecl(), + toType, + fromExpr); + + // If the cost is too high to be usable as an + // implicit conversion, then we will report the + // conversion as possible (so that an overload involving + // this conversion will be selected over one without), + // but then emit a diagnostic when actually reifying + // the result expression. + // + if (outToExpr && site != CoercionSite::ExplicitCoercion) { - // If there is a single best candidate for conversion, - // then we want to use it. - // - // It is possible that there was a single best candidate, - // but it wasn't actually usable, so we will check for - // that case first. - // - if (overloadContext.bestCandidate->status != OverloadCandidate::Status::Applicable) + if (cost >= kConversionCost_Explicit) { - if (!cachedMethod) - { - getShared()->cacheImplicitCastMethod(implicitCastKey, ImplicitCastMethod{}); - } - return _failedCoercion(toType, outToExpr, fromExpr); + getSink()->diagnose(fromExpr, Diagnostics::typeMismatch, toType, fromType); + getSink()->diagnoseWithoutSourceView( + fromExpr, + Diagnostics::noteExplicitConversionPossible, + fromType, + toType); } - - // Next, we need to look at the implicit conversion - // cost associated with the initializer we are invoking. - // - ConversionCost cost = getImplicitConversionCostWithKnownArg( - overloadContext.bestCandidate->item.declRef.getDecl(), toType, fromExpr); - - // If the cost is too high to be usable as an - // implicit conversion, then we will report the - // conversion as possible (so that an overload involving - // this conversion will be selected over one without), - // but then emit a diagnostic when actually reifying - // the result expression. - // - if (outToExpr && site != CoercionSite::ExplicitCoercion) + else if (cost >= kConversionCost_Default) { - if (cost >= kConversionCost_Explicit) + // For general types of implicit conversions, we issue a warning, unless `fromExpr` + // is a known constant and we know it won't cause a problem. + bool shouldEmitGeneralWarning = true; + if (isScalarIntegerType(toType)) { - getSink()->diagnose(fromExpr, Diagnostics::typeMismatch, toType, fromType); - getSink()->diagnoseWithoutSourceView( - fromExpr, Diagnostics::noteExplicitConversionPossible, fromType, toType); - } - else if (cost >= kConversionCost_Default) - { - // For general types of implicit conversions, we issue a warning, unless `fromExpr` is a known constant - // and we know it won't cause a problem. - bool shouldEmitGeneralWarning = true; - if (isScalarIntegerType(toType)) + if (auto intVal = tryFoldIntegerConstantExpression( + fromExpr, + ConstantFoldingKind::CompileTime, + nullptr)) { - if (auto intVal = tryFoldIntegerConstantExpression(fromExpr, ConstantFoldingKind::CompileTime, nullptr)) + if (auto val = as(intVal)) { - if (auto val = as(intVal)) + if (isIntValueInRangeOfType(val->getValue(), toType)) { - if (isIntValueInRangeOfType(val->getValue(), toType)) - { - // OK. - shouldEmitGeneralWarning = false; - } + // OK. + shouldEmitGeneralWarning = false; } } } - if (shouldEmitGeneralWarning) - { - getSink()->diagnose(fromExpr, Diagnostics::unrecommendedImplicitConversion, fromType, toType); - } } - - if (site == CoercionSite::Argument) + if (shouldEmitGeneralWarning) { - auto builtinConversionKind = getImplicitConversionBuiltinKind( - overloadContext.bestCandidate->item.declRef.getDecl()); - if (builtinConversionKind == kBuiltinConversion_FloatToDouble) - { - if (!as(fromExpr)) - getSink()->diagnose(fromExpr, Diagnostics::implicitConversionToDouble); - } + getSink()->diagnose( + fromExpr, + Diagnostics::unrecommendedImplicitConversion, + fromType, + toType); } } - if (fromType.isLeftValue) - { - // If we are implicitly casting the type of an l-value, we need to impose additional cost. - cost += kConversionCost_LValueCast; - } - if(outCost) - *outCost = cost; - if(outToExpr) + if (site == CoercionSite::Argument) { - // The logic here is a bit ugly, to deal with the fact that - // `CompleteOverloadCandidate` will, left to its own devices, - // construct a vanilla `InvokeExpr` to represent the call - // to the initializer we found, while we *want* it to - // create some variety of `ImplicitCastExpr`. - // - // Now, it just so happens that `CompleteOverloadCandidate` - // will use the "original" expression if one is available, - // so we'll create one and initialize it here. - // We fill in the location and arguments, but not the - // base expression (the callee), since that will come - // from the selected overload candidate. - // - InvokeExpr* castExpr = (site == CoercionSite::ExplicitCoercion) - ? m_astBuilder->create() - : createImplicitCastExpr(); - castExpr->loc = fromExpr->loc; - castExpr->arguments.add(fromExpr); - // - // Next we need to set our cast expression as the "original" - // expression and then complete the overload process. - // - overloadContext.originalExpr = castExpr; - *outToExpr = CompleteOverloadCandidate(overloadContext, *overloadContext.bestCandidate); - // - // However, the above isn't *quite* enough, because - // the process of completing the overload candidate - // might overwrite the argument list that was passed - // in to overload resolution, and in this case that - // "argument list" was just a pointer to `fromExpr`. - // - // That means we need to clear the argument list and - // reload it from `args[0]` to make sure that we - // got the arguments *after* any transformations - // were applied. - // For right now this probably doesn't matter, - // because we don't allow nested implicit conversions, - // but I'd rather play it safe. - // - castExpr->arguments.clear(); - castExpr->arguments.add(args[0]); + auto builtinConversionKind = getImplicitConversionBuiltinKind( + overloadContext.bestCandidate->item.declRef.getDecl()); + if (builtinConversionKind == kBuiltinConversion_FloatToDouble) + { + if (!as(fromExpr)) + getSink()->diagnose(fromExpr, Diagnostics::implicitConversionToDouble); + } } - if (!cachedMethod) - getShared()->cacheImplicitCastMethod(implicitCastKey, ImplicitCastMethod{ *overloadContext.bestCandidate, cost }); - return true; } - if (!cachedMethod) + if (fromType.isLeftValue) { - getShared()->cacheImplicitCastMethod(implicitCastKey, ImplicitCastMethod{}); + // If we are implicitly casting the type of an l-value, we need to impose additional + // cost. + cost += kConversionCost_LValueCast; } - return _failedCoercion(toType, outToExpr, fromExpr); - } - - bool SemanticsVisitor::canCoerce( - Type* toType, - QualType fromType, - Expr* fromExpr, - ConversionCost* outCost) - { - // As an optimization, we will maintain a cache of conversion results - // for basic types such as scalars and vectors. - // - - bool shouldAddToCache = false; - ConversionCost cost; - TypeCheckingCache* typeCheckingCache = getLinkage()->getTypeCheckingCache(); + if (outCost) + *outCost = cost; - BasicTypeKeyPair cacheKey; - cacheKey.type1 = makeBasicTypeKey(toType); - cacheKey.type2 = makeBasicTypeKey(fromType, fromExpr); - - if( cacheKey.isValid()) + if (outToExpr) { - if (typeCheckingCache->conversionCostCache.tryGetValue(cacheKey, cost)) - { - if (outCost) - *outCost = cost; - return cost != kConversionCost_Impossible; - } - else - shouldAddToCache = true; + // The logic here is a bit ugly, to deal with the fact that + // `CompleteOverloadCandidate` will, left to its own devices, + // construct a vanilla `InvokeExpr` to represent the call + // to the initializer we found, while we *want* it to + // create some variety of `ImplicitCastExpr`. + // + // Now, it just so happens that `CompleteOverloadCandidate` + // will use the "original" expression if one is available, + // so we'll create one and initialize it here. + // We fill in the location and arguments, but not the + // base expression (the callee), since that will come + // from the selected overload candidate. + // + InvokeExpr* castExpr = (site == CoercionSite::ExplicitCoercion) + ? m_astBuilder->create() + : createImplicitCastExpr(); + castExpr->loc = fromExpr->loc; + castExpr->arguments.add(fromExpr); + // + // Next we need to set our cast expression as the "original" + // expression and then complete the overload process. + // + overloadContext.originalExpr = castExpr; + *outToExpr = CompleteOverloadCandidate(overloadContext, *overloadContext.bestCandidate); + // + // However, the above isn't *quite* enough, because + // the process of completing the overload candidate + // might overwrite the argument list that was passed + // in to overload resolution, and in this case that + // "argument list" was just a pointer to `fromExpr`. + // + // That means we need to clear the argument list and + // reload it from `args[0]` to make sure that we + // got the arguments *after* any transformations + // were applied. + // For right now this probably doesn't matter, + // because we don't allow nested implicit conversions, + // but I'd rather play it safe. + // + castExpr->arguments.clear(); + castExpr->arguments.add(args[0]); } + if (!cachedMethod) + getShared()->cacheImplicitCastMethod( + implicitCastKey, + ImplicitCastMethod{*overloadContext.bestCandidate, cost}); + return true; + } + if (!cachedMethod) + { + getShared()->cacheImplicitCastMethod(implicitCastKey, ImplicitCastMethod{}); + } + return _failedCoercion(toType, outToExpr, fromExpr); +} - // If there was no suitable entry in the cache, - // then we fall back to the general-purpose - // conversion checking logic. - // - // Note that we are passing in `nullptr` as - // the output expression to be constructed, - // which suppresses emission of any diagnostics - // during the coercion process. - // - bool rs = _coerce( - CoercionSite::General, - toType, - nullptr, - fromType, - fromExpr, - &cost); +bool SemanticsVisitor::canCoerce( + Type* toType, + QualType fromType, + Expr* fromExpr, + ConversionCost* outCost) +{ + // As an optimization, we will maintain a cache of conversion results + // for basic types such as scalars and vectors. + // - if (outCost) - *outCost = cost; + bool shouldAddToCache = false; + ConversionCost cost; + TypeCheckingCache* typeCheckingCache = getLinkage()->getTypeCheckingCache(); + + BasicTypeKeyPair cacheKey; + cacheKey.type1 = makeBasicTypeKey(toType); + cacheKey.type2 = makeBasicTypeKey(fromType, fromExpr); - if (shouldAddToCache) + if (cacheKey.isValid()) + { + if (typeCheckingCache->conversionCostCache.tryGetValue(cacheKey, cost)) { - if (!rs) - cost = kConversionCost_Impossible; - typeCheckingCache->conversionCostCache[cacheKey] = cost; + if (outCost) + *outCost = cost; + return cost != kConversionCost_Impossible; } - - return rs; + else + shouldAddToCache = true; } - TypeCastExpr* SemanticsVisitor::createImplicitCastExpr() + // If there was no suitable entry in the cache, + // then we fall back to the general-purpose + // conversion checking logic. + // + // Note that we are passing in `nullptr` as + // the output expression to be constructed, + // which suppresses emission of any diagnostics + // during the coercion process. + // + bool rs = _coerce(CoercionSite::General, toType, nullptr, fromType, fromExpr, &cost); + + if (outCost) + *outCost = cost; + + if (shouldAddToCache) { - return m_astBuilder->create(); + if (!rs) + cost = kConversionCost_Impossible; + typeCheckingCache->conversionCostCache[cacheKey] = cost; } - Expr* SemanticsVisitor::CreateImplicitCastExpr( - Type* toType, - Expr* fromExpr) - { - TypeCastExpr* castExpr = createImplicitCastExpr(); + return rs; +} - auto typeType = m_astBuilder->getTypeType(toType); +TypeCastExpr* SemanticsVisitor::createImplicitCastExpr() +{ + return m_astBuilder->create(); +} - auto typeExpr = m_astBuilder->create(); - typeExpr->type.type = typeType; - typeExpr->base.type = toType; +Expr* SemanticsVisitor::CreateImplicitCastExpr(Type* toType, Expr* fromExpr) +{ + TypeCastExpr* castExpr = createImplicitCastExpr(); - castExpr->loc = fromExpr->loc; - castExpr->functionExpr = typeExpr; - castExpr->type = QualType(toType); - castExpr->arguments.add(fromExpr); - return castExpr; - } + auto typeType = m_astBuilder->getTypeType(toType); - Expr* SemanticsVisitor::createCastToSuperTypeExpr( - Type* toType, - Expr* fromExpr, - Val* witness) - { - CastToSuperTypeExpr* expr = m_astBuilder->create(); - expr->loc = fromExpr->loc; - expr->type = QualType(toType); - expr->valueArg = fromExpr; - expr->witnessArg = witness; - return expr; - } + auto typeExpr = m_astBuilder->create(); + typeExpr->type.type = typeType; + typeExpr->base.type = toType; - Expr* SemanticsVisitor::createModifierCastExpr( - Type* toType, - Expr* fromExpr) - { - ModifierCastExpr* expr = m_astBuilder->create(); - expr->loc = fromExpr->loc; - expr->type = QualType(toType); - expr->valueArg = fromExpr; - return expr; - } + castExpr->loc = fromExpr->loc; + castExpr->functionExpr = typeExpr; + castExpr->type = QualType(toType); + castExpr->arguments.add(fromExpr); + return castExpr; +} +Expr* SemanticsVisitor::createCastToSuperTypeExpr(Type* toType, Expr* fromExpr, Val* witness) +{ + CastToSuperTypeExpr* expr = m_astBuilder->create(); + expr->loc = fromExpr->loc; + expr->type = QualType(toType); + expr->valueArg = fromExpr; + expr->witnessArg = witness; + return expr; +} - Expr* SemanticsVisitor::coerce( - CoercionSite site, - Type* toType, - Expr* fromExpr) - { - Expr* expr = nullptr; - if (!_coerce( - site, - toType, - &expr, - fromExpr->type, - fromExpr, - nullptr)) - { - // Note(tfoley): We don't call `CreateErrorExpr` here, because that would - // clobber the type on `fromExpr`, and an invariant here is that coercion - // really shouldn't *change* the expression that is passed in, but should - // introduce new AST nodes to coerce its value to a different type... - return CreateImplicitCastExpr( - m_astBuilder->getErrorType(), - fromExpr); - } +Expr* SemanticsVisitor::createModifierCastExpr(Type* toType, Expr* fromExpr) +{ + ModifierCastExpr* expr = m_astBuilder->create(); + expr->loc = fromExpr->loc; + expr->type = QualType(toType); + expr->valueArg = fromExpr; + return expr; +} - return expr; - } - bool SemanticsVisitor::canConvertImplicitly( - ConversionCost conversionCost) +Expr* SemanticsVisitor::coerce(CoercionSite site, Type* toType, Expr* fromExpr) +{ + Expr* expr = nullptr; + if (!_coerce(site, toType, &expr, fromExpr->type, fromExpr, nullptr)) { - // Is the conversion cheap enough to be done implicitly? - if (conversionCost >= kConversionCost_GeneralConversion) - return false; - return true; + // Note(tfoley): We don't call `CreateErrorExpr` here, because that would + // clobber the type on `fromExpr`, and an invariant here is that coercion + // really shouldn't *change* the expression that is passed in, but should + // introduce new AST nodes to coerce its value to a different type... + return CreateImplicitCastExpr(m_astBuilder->getErrorType(), fromExpr); } - bool SemanticsVisitor::canConvertImplicitly( - Type* toType, - QualType fromType) - { - auto conversionCost = getConversionCost(toType, fromType); + return expr; +} - // Is the conversion cheap enough to be done implicitly? - if (canConvertImplicitly(conversionCost)) - return false; +bool SemanticsVisitor::canConvertImplicitly(ConversionCost conversionCost) +{ + // Is the conversion cheap enough to be done implicitly? + if (conversionCost >= kConversionCost_GeneralConversion) + return false; + return true; +} - return true; - } +bool SemanticsVisitor::canConvertImplicitly(Type* toType, QualType fromType) +{ + auto conversionCost = getConversionCost(toType, fromType); - ConversionCost SemanticsVisitor::getConversionCost(Type* toType, QualType fromType) - { - ConversionCost conversionCost = kConversionCost_Impossible; - if (!canCoerce(toType, fromType, nullptr, &conversionCost)) - return kConversionCost_Impossible; - return conversionCost; - } + // Is the conversion cheap enough to be done implicitly? + if (canConvertImplicitly(conversionCost)) + return false; + + return true; +} + +ConversionCost SemanticsVisitor::getConversionCost(Type* toType, QualType fromType) +{ + ConversionCost conversionCost = kConversionCost_Impossible; + if (!canCoerce(toType, fromType, nullptr, &conversionCost)) + return kConversionCost_Impossible; + return conversionCost; } +} // namespace Slang diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index dea6c6038..5b3f692be 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -12,4576 +12,4838 @@ // logic also orchestrates the overall flow and how // and when things get checked. +#include "slang-ast-iterator.h" +#include "slang-ast-reflect.h" +#include "slang-ast-synthesis.h" #include "slang-lookup.h" #include "slang-syntax.h" -#include "slang-ast-synthesis.h" -#include "slang-ast-reflect.h" -#include "slang-ast-iterator.h" + #include namespace Slang { - static ConstructorDecl* _getDefaultCtor(StructDecl* structDecl); - static List _getCtorList(ASTBuilder* m_astBuilder, SemanticsVisitor* visitor, StructDecl* structDecl, ConstructorDecl** defaultCtorOut); - - /// Visitor to transition declarations to `DeclCheckState::CheckedModifiers` - struct SemanticsDeclModifiersVisitor - : public SemanticsDeclVisitorBase - , public DeclVisitor +static ConstructorDecl* _getDefaultCtor(StructDecl* structDecl); +static List _getCtorList( + ASTBuilder* m_astBuilder, + SemanticsVisitor* visitor, + StructDecl* structDecl, + ConstructorDecl** defaultCtorOut); + +/// Visitor to transition declarations to `DeclCheckState::CheckedModifiers` +struct SemanticsDeclModifiersVisitor : public SemanticsDeclVisitorBase, + public DeclVisitor +{ + SemanticsDeclModifiersVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) { - SemanticsDeclModifiersVisitor(SemanticsContext const& outer) - : SemanticsDeclVisitorBase(outer) - {} + } - void visitDeclGroup(DeclGroup*) {} - - void visitDecl(Decl* decl) - { - checkModifiers(decl); - } + void visitDeclGroup(DeclGroup*) {} - void visitStructDecl(StructDecl* structDecl); - }; + void visitDecl(Decl* decl) { checkModifiers(decl); } + + void visitStructDecl(StructDecl* structDecl); +}; - struct SemanticsDeclScopeWiringVisitor : public SemanticsDeclVisitorBase, public DeclVisitor +struct SemanticsDeclScopeWiringVisitor : public SemanticsDeclVisitorBase, + public DeclVisitor +{ + SemanticsDeclScopeWiringVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) { - SemanticsDeclScopeWiringVisitor(SemanticsContext const& outer) - : SemanticsDeclVisitorBase(outer) - {} + } - void visitDeclGroup(DeclGroup*) {} + void visitDeclGroup(DeclGroup*) {} - void visitDecl(Decl*) {} + void visitDecl(Decl*) {} - void visitUsingDecl(UsingDecl* decl); + void visitUsingDecl(UsingDecl* decl); - void visitImplementingDecl(ImplementingDecl* decl); + void visitImplementingDecl(ImplementingDecl* decl); - void visitNamespaceDecl(NamespaceDecl* decl); - }; + void visitNamespaceDecl(NamespaceDecl* decl); +}; - struct SemanticsDeclAttributesVisitor - : public SemanticsDeclVisitorBase - , public DeclVisitor +struct SemanticsDeclAttributesVisitor : public SemanticsDeclVisitorBase, + public DeclVisitor +{ + SemanticsDeclAttributesVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) { - SemanticsDeclAttributesVisitor(SemanticsContext const& outer) - : SemanticsDeclVisitorBase(outer) - {} + } - void visitDecl(Decl*) {} - void visitDeclGroup(DeclGroup*) {} + void visitDecl(Decl*) {} + void visitDeclGroup(DeclGroup*) {} - void visitStructDecl(StructDecl* structDecl); + void visitStructDecl(StructDecl* structDecl); - void visitFunctionDeclBase(FunctionDeclBase* decl); + void visitFunctionDeclBase(FunctionDeclBase* decl); - void checkForwardDerivativeOfAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeOfAttribute* attr); + void checkForwardDerivativeOfAttribute( + FunctionDeclBase* funcDecl, + ForwardDerivativeOfAttribute* attr); - void checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl, BackwardDerivativeOfAttribute* attr); + void checkBackwardDerivativeOfAttribute( + FunctionDeclBase* funcDecl, + BackwardDerivativeOfAttribute* attr); - void checkPrimalSubstituteOfAttribute(FunctionDeclBase* funcDecl, PrimalSubstituteOfAttribute* attr); - }; + void checkPrimalSubstituteOfAttribute( + FunctionDeclBase* funcDecl, + PrimalSubstituteOfAttribute* attr); +}; - struct SemanticsDeclHeaderVisitor - : public SemanticsDeclVisitorBase - , public DeclVisitor +struct SemanticsDeclHeaderVisitor : public SemanticsDeclVisitorBase, + public DeclVisitor +{ + SemanticsDeclHeaderVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) { - SemanticsDeclHeaderVisitor(SemanticsContext const& outer) - : SemanticsDeclVisitorBase(outer) - {} + } - void visitDecl(Decl*) {} - void visitDeclGroup(DeclGroup*) {} + void visitDecl(Decl*) {} + void visitDeclGroup(DeclGroup*) {} - void checkDerivativeMemberAttributeParent(VarDeclBase* varDecl, DerivativeMemberAttribute* attr); - void checkExtensionExternVarAttribute(VarDeclBase* varDecl, ExtensionExternVarModifier* m); - void checkMeshOutputDecl(VarDeclBase* varDecl); - void maybeApplyLayoutModifier(VarDeclBase* varDecl); - void checkVarDeclCommon(VarDeclBase* varDecl); + void checkDerivativeMemberAttributeParent( + VarDeclBase* varDecl, + DerivativeMemberAttribute* attr); + void checkExtensionExternVarAttribute(VarDeclBase* varDecl, ExtensionExternVarModifier* m); + void checkMeshOutputDecl(VarDeclBase* varDecl); + void maybeApplyLayoutModifier(VarDeclBase* varDecl); + void checkVarDeclCommon(VarDeclBase* varDecl); - void visitVarDecl(VarDecl* varDecl) - { - checkVarDeclCommon(varDecl); - } + void visitVarDecl(VarDecl* varDecl) { checkVarDeclCommon(varDecl); } - void visitGlobalGenericValueParamDecl(GlobalGenericValueParamDecl* decl) - { - checkVarDeclCommon(decl); - } + void visitGlobalGenericValueParamDecl(GlobalGenericValueParamDecl* decl) + { + checkVarDeclCommon(decl); + } - void visitImportDecl(ImportDecl* decl); + void visitImportDecl(ImportDecl* decl); - void visitIncludeDecl(IncludeDecl* decl); + void visitIncludeDecl(IncludeDecl* decl); - void visitGenericTypeParamDecl(GenericTypeParamDecl* decl); + void visitGenericTypeParamDecl(GenericTypeParamDecl* decl); - void visitGenericValueParamDecl(GenericValueParamDecl* decl); + void visitGenericValueParamDecl(GenericValueParamDecl* decl); - void visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl); + void visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl); - void validateGenericConstraintSubType(GenericTypeConstraintDecl* decl, TypeExp type); + void validateGenericConstraintSubType(GenericTypeConstraintDecl* decl, TypeExp type); - void visitGenericDecl(GenericDecl* genericDecl); + void visitGenericDecl(GenericDecl* genericDecl); - void visitTypeDefDecl(TypeDefDecl* decl); + void visitTypeDefDecl(TypeDefDecl* decl); - void visitGlobalGenericParamDecl(GlobalGenericParamDecl* decl); + void visitGlobalGenericParamDecl(GlobalGenericParamDecl* decl); - void visitAssocTypeDecl(AssocTypeDecl* decl); + void visitAssocTypeDecl(AssocTypeDecl* decl); - void checkCallableDeclCommon(CallableDecl* decl); + void checkCallableDeclCommon(CallableDecl* decl); - void visitFuncDecl(FuncDecl* funcDecl); + void visitFuncDecl(FuncDecl* funcDecl); - void visitParamDecl(ParamDecl* paramDecl); + void visitParamDecl(ParamDecl* paramDecl); - void visitConstructorDecl(ConstructorDecl* decl); + void visitConstructorDecl(ConstructorDecl* decl); - void visitAbstractStorageDeclCommon(ContainerDecl* decl); + void visitAbstractStorageDeclCommon(ContainerDecl* decl); - void visitSubscriptDecl(SubscriptDecl* decl); + void visitSubscriptDecl(SubscriptDecl* decl); - void visitPropertyDecl(PropertyDecl* decl); + void visitPropertyDecl(PropertyDecl* decl); - void visitStructDecl(StructDecl* decl); + void visitStructDecl(StructDecl* decl); - void visitClassDecl(ClassDecl* decl); + void visitClassDecl(ClassDecl* decl); - /// Get the type of the storage accessed by an accessor. - /// - /// The type of storage is determined by the parent declaration. - Type* _getAccessorStorageType(AccessorDecl* decl); + /// Get the type of the storage accessed by an accessor. + /// + /// The type of storage is determined by the parent declaration. + Type* _getAccessorStorageType(AccessorDecl* decl); - /// Perform checks common to all types of accessors. - void _visitAccessorDeclCommon(AccessorDecl* decl); + /// Perform checks common to all types of accessors. + void _visitAccessorDeclCommon(AccessorDecl* decl); - void visitAccessorDecl(AccessorDecl* decl); - void visitSetterDecl(SetterDecl* decl); + void visitAccessorDecl(AccessorDecl* decl); + void visitSetterDecl(SetterDecl* decl); - void cloneModifiers(Decl* dest, Decl* src); - void setFuncTypeIntoRequirementDecl(CallableDecl* decl, FuncType* funcType); - }; + void cloneModifiers(Decl* dest, Decl* src); + void setFuncTypeIntoRequirementDecl(CallableDecl* decl, FuncType* funcType); +}; - struct SemanticsDeclRedeclarationVisitor - : public SemanticsDeclVisitorBase - , public DeclVisitor +struct SemanticsDeclRedeclarationVisitor : public SemanticsDeclVisitorBase, + public DeclVisitor +{ + SemanticsDeclRedeclarationVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) { - SemanticsDeclRedeclarationVisitor(SemanticsContext const& outer) - : SemanticsDeclVisitorBase(outer) - {} + } - void visitDecl(Decl*) {} - void visitDeclGroup(DeclGroup*) {} + void visitDecl(Decl*) {} + void visitDeclGroup(DeclGroup*) {} -#define CASE(TYPE) void visit##TYPE(TYPE* decl) { checkForRedeclaration(decl); } +#define CASE(TYPE) \ + void visit##TYPE(TYPE* decl) { checkForRedeclaration(decl); } - CASE(EnumCaseDecl) - CASE(FuncDecl) - CASE(VarDeclBase) - CASE(SimpleTypeDecl) - CASE(AggTypeDecl) + CASE(EnumCaseDecl) + CASE(FuncDecl) + CASE(VarDeclBase) + CASE(SimpleTypeDecl) + CASE(AggTypeDecl) #undef CASE - }; +}; - struct SemanticsDeclBasesVisitor - : public SemanticsDeclVisitorBase - , public DeclVisitor +struct SemanticsDeclBasesVisitor : public SemanticsDeclVisitorBase, + public DeclVisitor +{ + SemanticsDeclBasesVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) { - SemanticsDeclBasesVisitor(SemanticsContext const& outer) - : SemanticsDeclVisitorBase(outer) - {} + } - void visitDecl(Decl*) {} - void visitDeclGroup(DeclGroup*) {} + void visitDecl(Decl*) {} + void visitDeclGroup(DeclGroup*) {} - void visitInheritanceDecl(InheritanceDecl* inheritanceDecl); + void visitInheritanceDecl(InheritanceDecl* inheritanceDecl); - void visitThisTypeConstraintDecl(ThisTypeConstraintDecl* thisTypeConstraintDecl); + void visitThisTypeConstraintDecl(ThisTypeConstraintDecl* thisTypeConstraintDecl); - /// Validate that `decl` isn't illegally inheriting from a type in another module. - /// - /// This call checks a single `inheritanceDecl` to make sure that it either - /// * names a base type from the same module as `decl`, or - /// * names a type that allows cross-module inheritance - void _validateCrossModuleInheritance( - AggTypeDeclBase* decl, - InheritanceDecl* inheritanceDecl); + /// Validate that `decl` isn't illegally inheriting from a type in another module. + /// + /// This call checks a single `inheritanceDecl` to make sure that it either + /// * names a base type from the same module as `decl`, or + /// * names a type that allows cross-module inheritance + void _validateCrossModuleInheritance(AggTypeDeclBase* decl, InheritanceDecl* inheritanceDecl); - void visitInterfaceDecl(InterfaceDecl* decl); + void visitInterfaceDecl(InterfaceDecl* decl); - void visitStructDecl(StructDecl* decl); + void visitStructDecl(StructDecl* decl); - void visitClassDecl(ClassDecl* decl); + void visitClassDecl(ClassDecl* decl); - void visitEnumDecl(EnumDecl* decl); + void visitEnumDecl(EnumDecl* decl); - /// Validate that the target type of an extension `decl` is valid. - void _validateExtensionDeclTargetType(ExtensionDecl* decl); - void _validateExtensionDeclMembers(ExtensionDecl* decl); + /// Validate that the target type of an extension `decl` is valid. + void _validateExtensionDeclTargetType(ExtensionDecl* decl); + void _validateExtensionDeclMembers(ExtensionDecl* decl); - void visitExtensionDecl(ExtensionDecl* decl); - }; + void visitExtensionDecl(ExtensionDecl* decl); +}; - struct SemanticsDeclTypeResolutionVisitor - : public SemanticsDeclVisitorBase - , public DeclVisitor +struct SemanticsDeclTypeResolutionVisitor : public SemanticsDeclVisitorBase, + public DeclVisitor +{ + SemanticsDeclTypeResolutionVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) { - SemanticsDeclTypeResolutionVisitor(SemanticsContext const& outer) - : SemanticsDeclVisitorBase(outer) - {} + } - void visitDecl(Decl*) {} - void visitDeclGroup(DeclGroup*) {} + void visitDecl(Decl*) {} + void visitDeclGroup(DeclGroup*) {} - void visitTypeExp(TypeExp& exp) - { - exp.type = resolveType(exp.type); - } + void visitTypeExp(TypeExp& exp) { exp.type = resolveType(exp.type); } - void visitVarDeclBase(VarDeclBase* varDecl) - { - visitTypeExp(varDecl->type); - } + void visitVarDeclBase(VarDeclBase* varDecl) { visitTypeExp(varDecl->type); } - void visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl) - { - visitTypeExp(decl->sup); - } + void visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl) + { + visitTypeExp(decl->sup); + } - void visitTypeDefDecl(TypeDefDecl* decl) - { - visitTypeExp(decl->type); - } + void visitTypeDefDecl(TypeDefDecl* decl) { visitTypeExp(decl->type); } - void visitGenericTypeParamDecl(GenericTypeParamDecl* paramDecl) - { - visitTypeExp(paramDecl->initType); - } + void visitGenericTypeParamDecl(GenericTypeParamDecl* paramDecl) + { + visitTypeExp(paramDecl->initType); + } - void visitInheritanceDecl(InheritanceDecl* inheritanceDecl) - { - visitTypeExp(inheritanceDecl->base); - } + void visitInheritanceDecl(InheritanceDecl* inheritanceDecl) + { + visitTypeExp(inheritanceDecl->base); + } - void visitCallableDecl(CallableDecl* decl) - { - for (auto paramDecl : decl->getMembersOfType()) - visitTypeExp(paramDecl->type); + void visitCallableDecl(CallableDecl* decl) + { + for (auto paramDecl : decl->getMembersOfType()) + visitTypeExp(paramDecl->type); - visitTypeExp(decl->returnType); - visitTypeExp(decl->errorType); - } + visitTypeExp(decl->returnType); + visitTypeExp(decl->errorType); + } - void visitPropertyDecl(PropertyDecl* decl) - { - visitTypeExp(decl->type); - } - }; + void visitPropertyDecl(PropertyDecl* decl) { visitTypeExp(decl->type); } +}; - struct SemanticsDeclBodyVisitor - : public SemanticsDeclVisitorBase - , public DeclVisitor +struct SemanticsDeclBodyVisitor : public SemanticsDeclVisitorBase, + public DeclVisitor +{ + SemanticsDeclBodyVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) { - SemanticsDeclBodyVisitor(SemanticsContext const& outer) - : SemanticsDeclVisitorBase(outer) - {} - - void visitDecl(Decl*) {} - void visitDeclGroup(DeclGroup*) {} + } - void checkVarDeclCommon(VarDeclBase* varDecl); + void visitDecl(Decl*) {} + void visitDeclGroup(DeclGroup*) {} - void visitVarDecl(VarDecl* varDecl) - { - checkVarDeclCommon(varDecl); - } + void checkVarDeclCommon(VarDeclBase* varDecl); - void visitGenericValueParamDecl(GenericValueParamDecl* genValDecl) - { - checkVarDeclCommon(genValDecl); - } + void visitVarDecl(VarDecl* varDecl) { checkVarDeclCommon(varDecl); } - void visitGlobalGenericValueParamDecl(GlobalGenericValueParamDecl* decl) - { - checkVarDeclCommon(decl); - } + void visitGenericValueParamDecl(GenericValueParamDecl* genValDecl) + { + checkVarDeclCommon(genValDecl); + } - void visitEnumCaseDecl(EnumCaseDecl* decl); + void visitGlobalGenericValueParamDecl(GlobalGenericValueParamDecl* decl) + { + checkVarDeclCommon(decl); + } - void visitEnumDecl(EnumDecl* decl); + void visitEnumCaseDecl(EnumCaseDecl* decl); - void visitFunctionDeclBase(FunctionDeclBase* funcDecl); + void visitEnumDecl(EnumDecl* decl); - void visitParamDecl(ParamDecl* paramDecl); + void visitFunctionDeclBase(FunctionDeclBase* funcDecl); - void visitAggTypeDecl(AggTypeDecl* aggTypeDecl); + void visitParamDecl(ParamDecl* paramDecl); - SemanticsContext registerDifferentiableTypesForFunc(FunctionDeclBase* funcDecl); + void visitAggTypeDecl(AggTypeDecl* aggTypeDecl); - struct DeclAndCtorInfo - { - StructDecl* parent = nullptr; - ConstructorDecl* defaultCtor = nullptr; - List ctorList; - DeclAndCtorInfo() - { - } - DeclAndCtorInfo(ASTBuilder* m_astBuilder, SemanticsVisitor* visitor, StructDecl* parent, const bool getOnlyDefault) - { - if (getOnlyDefault) - defaultCtor = _getDefaultCtor(parent); - else - ctorList = _getCtorList(m_astBuilder, visitor, parent, &defaultCtor); - } - }; + SemanticsContext registerDifferentiableTypesForFunc(FunctionDeclBase* funcDecl); - void synthesizeCtorBody(DeclAndCtorInfo& structDeclInfo, List& inheritanceDefaultCtorList, StructDecl* structDecl); - void synthesizeCtorBodyForBases(ConstructorDecl* ctor, List& inheritanceDefaultCtorList, ThisExpr* thisExpr, SeqStmt* seqStmtChild); - void synthesizeCtorBodyForMember(ConstructorDecl* ctor, Decl* member, ThisExpr* thisExpr, Dictionary& cachedDeclToCheckedVar, SeqStmt* seqStmtChild); + struct DeclAndCtorInfo + { + StructDecl* parent = nullptr; + ConstructorDecl* defaultCtor = nullptr; + List ctorList; + DeclAndCtorInfo() {} + DeclAndCtorInfo( + ASTBuilder* m_astBuilder, + SemanticsVisitor* visitor, + StructDecl* parent, + const bool getOnlyDefault) + { + if (getOnlyDefault) + defaultCtor = _getDefaultCtor(parent); + else + ctorList = _getCtorList(m_astBuilder, visitor, parent, &defaultCtor); + } }; - template - struct SemanticsDeclReferenceVisitor - : public SemanticsDeclVisitorBase - , public StmtVisitor - , public ExprVisitor - , public ValVisitor - , public DeclVisitor + void synthesizeCtorBody( + DeclAndCtorInfo& structDeclInfo, + List& inheritanceDefaultCtorList, + StructDecl* structDecl); + void synthesizeCtorBodyForBases( + ConstructorDecl* ctor, + List& inheritanceDefaultCtorList, + ThisExpr* thisExpr, + SeqStmt* seqStmtChild); + void synthesizeCtorBodyForMember( + ConstructorDecl* ctor, + Decl* member, + ThisExpr* thisExpr, + Dictionary& cachedDeclToCheckedVar, + SeqStmt* seqStmtChild); +}; + +template +struct SemanticsDeclReferenceVisitor : public SemanticsDeclVisitorBase, + public StmtVisitor, + public ExprVisitor, + public ValVisitor, + public DeclVisitor +{ + SemanticsDeclReferenceVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) { - SemanticsDeclReferenceVisitor(SemanticsContext const& outer) - : SemanticsDeclVisitorBase(outer) - {} + } - List sourceLocStack; + List sourceLocStack; - struct PushSourceLocRAII + struct PushSourceLocRAII + { + List& stack; + bool shouldPop = false; + PushSourceLocRAII(List& sourceLocStack, SourceLoc loc) + : stack(sourceLocStack) { - List& stack; - bool shouldPop = false; - PushSourceLocRAII(List& sourceLocStack, SourceLoc loc) - : stack(sourceLocStack) + if (loc.isValid()) { - if (loc.isValid()) - { - stack.add(loc); - shouldPop = true; - } - } - ~PushSourceLocRAII() - { - if (shouldPop) - { - stack.removeLast(); - } + stack.add(loc); + shouldPop = true; } - }; - - virtual void processReferencedDecl(Decl* decl) = 0; - - virtual void processDeclModifiers(Decl* decl, SourceLoc refLoc) = 0; - - void dispatchIfNotNull(Stmt* stmt) - { - if (!stmt) - return; - PushSourceLocRAII sourceLocRAII(sourceLocStack, stmt->loc); - return StmtVisitor::dispatch(stmt); - } - void dispatchIfNotNull(Expr* expr) - { - if (!expr) - return; - PushSourceLocRAII sourceLocRAII(sourceLocStack, expr->loc); - return ExprVisitor::dispatch(expr); - } - void dispatchIfNotNull(Val* val) - { - if (!val) - return; - return ValVisitor::dispatch(val); - } - void dispatchIfNotNull(DeclBase* val) - { - if (!val) - return; - return DeclVisitor::dispatch(val); - } - // Expr Visitor - void visitExpr(Expr*) { } - void visitIndexExpr(IndexExpr* subscriptExpr) - { - for (auto arg : subscriptExpr->indexExprs) - dispatchIfNotNull(arg); - dispatchIfNotNull(subscriptExpr->baseExpression); - } - - void visitParenExpr(ParenExpr* expr) - { - dispatchIfNotNull(expr->base); - } - - void visitAssignExpr(AssignExpr* expr) - { - dispatchIfNotNull(expr->left); - dispatchIfNotNull(expr->right); } - - void visitGenericAppExpr(GenericAppExpr* genericAppExpr) + ~PushSourceLocRAII() { - dispatchIfNotNull(genericAppExpr->functionExpr); - for (auto arg : genericAppExpr->arguments) - dispatchIfNotNull(arg); + if (shouldPop) + { + stack.removeLast(); + } } + }; - void visitSharedTypeExpr(SharedTypeExpr* expr) { dispatchIfNotNull(expr->base.exp); } - - void visitInvokeExpr(InvokeExpr* expr) - { - dispatchIfNotNull(expr->functionExpr); - for (auto arg : expr->arguments) - dispatchIfNotNull(arg); - } + virtual void processReferencedDecl(Decl* decl) = 0; - void visitTypeCastExpr(TypeCastExpr* expr) - { - dispatchIfNotNull(expr->functionExpr); - for (auto arg : expr->arguments) - dispatchIfNotNull(arg); - } + virtual void processDeclModifiers(Decl* decl, SourceLoc refLoc) = 0; - void visitDerefExpr(DerefExpr* expr) { dispatchIfNotNull(expr->base); } - void visitMatrixSwizzleExpr(MatrixSwizzleExpr* expr) - { - dispatchIfNotNull(expr->base); - } - void visitSwizzleExpr(SwizzleExpr* expr) - { - dispatchIfNotNull(expr->base); - } - void visitOverloadedExpr(OverloadedExpr*) - { + void dispatchIfNotNull(Stmt* stmt) + { + if (!stmt) return; - } - void visitOverloadedExpr2(OverloadedExpr2*) - { + PushSourceLocRAII sourceLocRAII(sourceLocStack, stmt->loc); + return StmtVisitor::dispatch(stmt); + } + void dispatchIfNotNull(Expr* expr) + { + if (!expr) return; - } - void visitAggTypeCtorExpr(AggTypeCtorExpr*) - { + PushSourceLocRAII sourceLocRAII(sourceLocStack, expr->loc); + return ExprVisitor::dispatch(expr); + } + void dispatchIfNotNull(Val* val) + { + if (!val) return; - } - void visitCastToSuperTypeExpr(CastToSuperTypeExpr* expr) - { - dispatchIfNotNull(expr->valueArg); - } - void visitModifierCastExpr(ModifierCastExpr* expr) { dispatchIfNotNull(expr->valueArg); } - void visitLetExpr(LetExpr* expr) - { - dispatchIfNotNull(expr->body); - } - void visitExtractExistentialValueExpr(ExtractExistentialValueExpr* expr) - { - dispatchIfNotNull(expr->declRef.declRefBase); - } - - void visitDeclRefExpr(DeclRefExpr* expr) - { - dispatchIfNotNull(expr->type.type); - dispatchIfNotNull(expr->declRef.declRefBase); - - // Pass down the callee location - processDeclModifiers(expr->declRef.getDecl(), expr->loc); - } - void visitStaticMemberExpr(StaticMemberExpr* expr) - { - dispatchIfNotNull(expr->declRef.declRefBase); - } - void visitInitializerListExpr(InitializerListExpr* expr) - { - for (auto arg : expr->args) - { - dispatchIfNotNull(arg); - } - } - - void visitThisExpr(ThisExpr*) - { + return ValVisitor::dispatch(val); + } + void dispatchIfNotNull(DeclBase* val) + { + if (!val) return; - } + return DeclVisitor::dispatch(val); + } + // Expr Visitor + void visitExpr(Expr*) {} + void visitIndexExpr(IndexExpr* subscriptExpr) + { + for (auto arg : subscriptExpr->indexExprs) + dispatchIfNotNull(arg); + dispatchIfNotNull(subscriptExpr->baseExpression); + } - void visitThisTypeExpr(ThisTypeExpr*) { return; } - void visitAndTypeExpr(AndTypeExpr* expr) - { - dispatchIfNotNull(expr->left.type); - dispatchIfNotNull(expr->right.type); - } - void visitPointerTypeExpr(PointerTypeExpr* expr) - { - dispatchIfNotNull(expr->base.type); - } - void visitAsTypeExpr(AsTypeExpr* expr) - { - dispatchIfNotNull(expr->value); - dispatchIfNotNull(expr->witnessArg); - } - void visitIsTypeExpr(IsTypeExpr* expr) - { - dispatchIfNotNull(expr->value); - dispatchIfNotNull(expr->witnessArg); - } - void visitMakeOptionalExpr(MakeOptionalExpr* expr) - { - dispatchIfNotNull(expr->value); - dispatchIfNotNull(expr->typeExpr); - } - void visitPartiallyAppliedGenericExpr(PartiallyAppliedGenericExpr*) - { - return; - } - void visitSPIRVAsmExpr(SPIRVAsmExpr*) - { - return; - } - void visitModifiedTypeExpr(ModifiedTypeExpr* expr) { dispatchIfNotNull(expr->base.type); } - void visitFuncTypeExpr(FuncTypeExpr* expr) - { - for (const auto& t : expr->parameters) - { - dispatchIfNotNull(t.type); - } - dispatchIfNotNull(expr->result.type); - } - void visitTupleTypeExpr(TupleTypeExpr* expr) - { - for (auto t : expr->members) - { - dispatchIfNotNull(t.type); - } - } - void visitTryExpr(TryExpr* expr) { dispatchIfNotNull(expr->base); } - void visitHigherOrderInvokeExpr(HigherOrderInvokeExpr* expr) - { - dispatchIfNotNull(expr->baseFunction); - } - void visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr) - { - dispatchIfNotNull(expr->innerExpr); - } + void visitParenExpr(ParenExpr* expr) { dispatchIfNotNull(expr->base); } - // Stmt Visitor + void visitAssignExpr(AssignExpr* expr) + { + dispatchIfNotNull(expr->left); + dispatchIfNotNull(expr->right); + } - void visitDeclStmt(DeclStmt* stmt) - { - dispatchIfNotNull(stmt->decl); - } + void visitGenericAppExpr(GenericAppExpr* genericAppExpr) + { + dispatchIfNotNull(genericAppExpr->functionExpr); + for (auto arg : genericAppExpr->arguments) + dispatchIfNotNull(arg); + } - void visitBlockStmt(BlockStmt* stmt) - { - dispatchIfNotNull(stmt->body); - } + void visitSharedTypeExpr(SharedTypeExpr* expr) { dispatchIfNotNull(expr->base.exp); } - void visitSeqStmt(SeqStmt* seqStmt) - { - for (auto stmt : seqStmt->stmts) - dispatchIfNotNull(stmt); - } + void visitInvokeExpr(InvokeExpr* expr) + { + dispatchIfNotNull(expr->functionExpr); + for (auto arg : expr->arguments) + dispatchIfNotNull(arg); + } - void visitLabelStmt(LabelStmt* stmt) - { - dispatchIfNotNull(stmt->innerStmt); - } + void visitTypeCastExpr(TypeCastExpr* expr) + { + dispatchIfNotNull(expr->functionExpr); + for (auto arg : expr->arguments) + dispatchIfNotNull(arg); + } - void visitBreakStmt(BreakStmt*) { return; } + void visitDerefExpr(DerefExpr* expr) { dispatchIfNotNull(expr->base); } + void visitMatrixSwizzleExpr(MatrixSwizzleExpr* expr) { dispatchIfNotNull(expr->base); } + void visitSwizzleExpr(SwizzleExpr* expr) { dispatchIfNotNull(expr->base); } + void visitOverloadedExpr(OverloadedExpr*) { return; } + void visitOverloadedExpr2(OverloadedExpr2*) { return; } + void visitAggTypeCtorExpr(AggTypeCtorExpr*) { return; } + void visitCastToSuperTypeExpr(CastToSuperTypeExpr* expr) { dispatchIfNotNull(expr->valueArg); } + void visitModifierCastExpr(ModifierCastExpr* expr) { dispatchIfNotNull(expr->valueArg); } + void visitLetExpr(LetExpr* expr) { dispatchIfNotNull(expr->body); } + void visitExtractExistentialValueExpr(ExtractExistentialValueExpr* expr) + { + dispatchIfNotNull(expr->declRef.declRefBase); + } - void visitContinueStmt(ContinueStmt*) { return; } + void visitDeclRefExpr(DeclRefExpr* expr) + { + dispatchIfNotNull(expr->type.type); + dispatchIfNotNull(expr->declRef.declRefBase); - void visitDoWhileStmt(DoWhileStmt* stmt) + // Pass down the callee location + processDeclModifiers(expr->declRef.getDecl(), expr->loc); + } + void visitStaticMemberExpr(StaticMemberExpr* expr) + { + dispatchIfNotNull(expr->declRef.declRefBase); + } + void visitInitializerListExpr(InitializerListExpr* expr) + { + for (auto arg : expr->args) { - dispatchIfNotNull(stmt->predicate); - dispatchIfNotNull(stmt->statement); + dispatchIfNotNull(arg); } + } - void visitForStmt(ForStmt* stmt) - { - dispatchIfNotNull(stmt->initialStatement); - dispatchIfNotNull(stmt->predicateExpression); - dispatchIfNotNull(stmt->sideEffectExpression); - dispatchIfNotNull(stmt->statement); - } + void visitThisExpr(ThisExpr*) { return; } - void visitCompileTimeForStmt(CompileTimeForStmt* stmt) + void visitThisTypeExpr(ThisTypeExpr*) { return; } + void visitAndTypeExpr(AndTypeExpr* expr) + { + dispatchIfNotNull(expr->left.type); + dispatchIfNotNull(expr->right.type); + } + void visitPointerTypeExpr(PointerTypeExpr* expr) { dispatchIfNotNull(expr->base.type); } + void visitAsTypeExpr(AsTypeExpr* expr) + { + dispatchIfNotNull(expr->value); + dispatchIfNotNull(expr->witnessArg); + } + void visitIsTypeExpr(IsTypeExpr* expr) + { + dispatchIfNotNull(expr->value); + dispatchIfNotNull(expr->witnessArg); + } + void visitMakeOptionalExpr(MakeOptionalExpr* expr) + { + dispatchIfNotNull(expr->value); + dispatchIfNotNull(expr->typeExpr); + } + void visitPartiallyAppliedGenericExpr(PartiallyAppliedGenericExpr*) { return; } + void visitSPIRVAsmExpr(SPIRVAsmExpr*) { return; } + void visitModifiedTypeExpr(ModifiedTypeExpr* expr) { dispatchIfNotNull(expr->base.type); } + void visitFuncTypeExpr(FuncTypeExpr* expr) + { + for (const auto& t : expr->parameters) { - dispatchIfNotNull(stmt->rangeBeginExpr); - dispatchIfNotNull(stmt->rangeEndExpr); - dispatchIfNotNull(stmt->body); + dispatchIfNotNull(t.type); } - - void visitSwitchStmt(SwitchStmt* stmt) + dispatchIfNotNull(expr->result.type); + } + void visitTupleTypeExpr(TupleTypeExpr* expr) + { + for (auto t : expr->members) { - dispatchIfNotNull(stmt->condition); - dispatchIfNotNull(stmt->body); + dispatchIfNotNull(t.type); } + } + void visitTryExpr(TryExpr* expr) { dispatchIfNotNull(expr->base); } + void visitHigherOrderInvokeExpr(HigherOrderInvokeExpr* expr) + { + dispatchIfNotNull(expr->baseFunction); + } + void visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr) + { + dispatchIfNotNull(expr->innerExpr); + } - void visitCaseStmt(CaseStmt* stmt) { dispatchIfNotNull(stmt->expr); } + // Stmt Visitor - void visitTargetSwitchStmt(TargetSwitchStmt* stmt) - { - for (auto targetCase : stmt->targetCases) - dispatchIfNotNull(targetCase); - } + void visitDeclStmt(DeclStmt* stmt) { dispatchIfNotNull(stmt->decl); } - void visitTargetCaseStmt(TargetCaseStmt* stmt) - { - dispatchIfNotNull(stmt->body); - } + void visitBlockStmt(BlockStmt* stmt) { dispatchIfNotNull(stmt->body); } - void visitIntrinsicAsmStmt(IntrinsicAsmStmt*) { return; } + void visitSeqStmt(SeqStmt* seqStmt) + { + for (auto stmt : seqStmt->stmts) + dispatchIfNotNull(stmt); + } - void visitDefaultStmt(DefaultStmt*) { return; } + void visitLabelStmt(LabelStmt* stmt) { dispatchIfNotNull(stmt->innerStmt); } - void visitIfStmt(IfStmt* stmt) - { - dispatchIfNotNull(stmt->predicate); - dispatchIfNotNull(stmt->positiveStatement); - dispatchIfNotNull(stmt->negativeStatement); - } + void visitBreakStmt(BreakStmt*) { return; } - void visitUnparsedStmt(UnparsedStmt*) { return; } + void visitContinueStmt(ContinueStmt*) { return; } - void visitEmptyStmt(EmptyStmt*) { return; } + void visitDoWhileStmt(DoWhileStmt* stmt) + { + dispatchIfNotNull(stmt->predicate); + dispatchIfNotNull(stmt->statement); + } - void visitDiscardStmt(DiscardStmt*) { return; } + void visitForStmt(ForStmt* stmt) + { + dispatchIfNotNull(stmt->initialStatement); + dispatchIfNotNull(stmt->predicateExpression); + dispatchIfNotNull(stmt->sideEffectExpression); + dispatchIfNotNull(stmt->statement); + } - void visitReturnStmt(ReturnStmt* stmt) { dispatchIfNotNull(stmt->expression); } + void visitCompileTimeForStmt(CompileTimeForStmt* stmt) + { + dispatchIfNotNull(stmt->rangeBeginExpr); + dispatchIfNotNull(stmt->rangeEndExpr); + dispatchIfNotNull(stmt->body); + } - void visitWhileStmt(WhileStmt* stmt) - { - dispatchIfNotNull(stmt->predicate); - dispatchIfNotNull(stmt->statement); - } + void visitSwitchStmt(SwitchStmt* stmt) + { + dispatchIfNotNull(stmt->condition); + dispatchIfNotNull(stmt->body); + } - void visitGpuForeachStmt(GpuForeachStmt*) { return; } + void visitCaseStmt(CaseStmt* stmt) { dispatchIfNotNull(stmt->expr); } - void visitExpressionStmt(ExpressionStmt* stmt) - { - dispatchIfNotNull(stmt->expression); - } + void visitTargetSwitchStmt(TargetSwitchStmt* stmt) + { + for (auto targetCase : stmt->targetCases) + dispatchIfNotNull(targetCase); + } - // Val Visitor + void visitTargetCaseStmt(TargetCaseStmt* stmt) { dispatchIfNotNull(stmt->body); } - void visitDirectDeclRef(DirectDeclRef* declRef) - { - // If we have already visited, return. - // Otherwise add it to visited set. - if (!visitedVals.add(declRef)) - return; + void visitIntrinsicAsmStmt(IntrinsicAsmStmt*) { return; } - processReferencedDecl(declRef->getDecl()); - } + void visitDefaultStmt(DefaultStmt*) { return; } - void visitVal(Val* val) - { - // If we have already visited, return. - // Otherwise add it to visited set. - if (!visitedVals.add(val)) - return; + void visitIfStmt(IfStmt* stmt) + { + dispatchIfNotNull(stmt->predicate); + dispatchIfNotNull(stmt->positiveStatement); + dispatchIfNotNull(stmt->negativeStatement); + } - for (Index i = 0; i < val->getOperandCount(); i++) - { - auto& operand = val->m_operands[i]; - switch (operand.kind) - { - case ValNodeOperandKind::ValNode: - dispatchIfNotNull(val->getOperand(i)); - break; - default: - break; - } - } + void visitUnparsedStmt(UnparsedStmt*) { return; } + + void visitEmptyStmt(EmptyStmt*) { return; } + + void visitDiscardStmt(DiscardStmt*) { return; } + + void visitReturnStmt(ReturnStmt* stmt) { dispatchIfNotNull(stmt->expression); } + + void visitWhileStmt(WhileStmt* stmt) + { + dispatchIfNotNull(stmt->predicate); + dispatchIfNotNull(stmt->statement); + } + + void visitGpuForeachStmt(GpuForeachStmt*) { return; } + + void visitExpressionStmt(ExpressionStmt* stmt) { dispatchIfNotNull(stmt->expression); } + + // Val Visitor + + void visitDirectDeclRef(DirectDeclRef* declRef) + { + // If we have already visited, return. + // Otherwise add it to visited set. + if (!visitedVals.add(declRef)) return; - } - HashSet visitedVals; + processReferencedDecl(declRef->getDecl()); + } - // Decl visitor - void visitDeclBase(DeclBase*) - {} + void visitVal(Val* val) + { + // If we have already visited, return. + // Otherwise add it to visited set. + if (!visitedVals.add(val)) + return; - void visitContainerDecl(ContainerDecl* decl) + for (Index i = 0; i < val->getOperandCount(); i++) { - for (auto m : decl->members) + auto& operand = val->m_operands[i]; + switch (operand.kind) { - dispatchIfNotNull(m); + case ValNodeOperandKind::ValNode: dispatchIfNotNull(val->getOperand(i)); break; + default: break; } } + return; + } - void visitFunctionDeclBase(FunctionDeclBase* decl) - { - visitContainerDecl(decl); - dispatchIfNotNull(decl->body); - } + HashSet visitedVals; + + // Decl visitor + void visitDeclBase(DeclBase*) {} - void visitVarDeclBase(VarDeclBase* varDecl) + void visitContainerDecl(ContainerDecl* decl) + { + for (auto m : decl->members) { - dispatchIfNotNull(varDecl->type.type); - dispatchIfNotNull(varDecl->initExpr); + dispatchIfNotNull(m); } - }; + } + + void visitFunctionDeclBase(FunctionDeclBase* decl) + { + visitContainerDecl(decl); + dispatchIfNotNull(decl->body); + } - struct SemanticsDeclCapabilityVisitor - : public SemanticsDeclVisitorBase - , public DeclVisitor + void visitVarDeclBase(VarDeclBase* varDecl) { - CapabilitySet m_anyPlatfromCapabilitySet; + dispatchIfNotNull(varDecl->type.type); + dispatchIfNotNull(varDecl->initExpr); + } +}; + +struct SemanticsDeclCapabilityVisitor : public SemanticsDeclVisitorBase, + public DeclVisitor +{ + CapabilitySet m_anyPlatfromCapabilitySet; - SemanticsDeclCapabilityVisitor(SemanticsContext const& outer) - : SemanticsDeclVisitorBase(outer) - {} + SemanticsDeclCapabilityVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) + { + } - CapabilitySet& getAnyPlatformCapabilitySet() + CapabilitySet& getAnyPlatformCapabilitySet() + { + if (m_anyPlatfromCapabilitySet.isEmpty()) { - if (m_anyPlatfromCapabilitySet.isEmpty()) - { - m_anyPlatfromCapabilitySet = CapabilitySet(CapabilityName::any_target); - } - return m_anyPlatfromCapabilitySet; + m_anyPlatfromCapabilitySet = CapabilitySet(CapabilityName::any_target); } + return m_anyPlatfromCapabilitySet; + } - CapabilitySet getDeclaredCapabilitySet(Decl* decl); - + CapabilitySet getDeclaredCapabilitySet(Decl* decl); - void visitDecl(Decl*) {} - void visitDeclGroup(DeclGroup*) {} - void checkVarDeclCommon(VarDeclBase* varDecl); - void visitAggTypeDeclBase(AggTypeDeclBase* decl); - void visitNamespaceDeclBase(NamespaceDeclBase* decl); - void visitVarDecl(VarDecl* varDecl) - { - checkVarDeclCommon(varDecl); - } + void visitDecl(Decl*) {} + void visitDeclGroup(DeclGroup*) {} + void checkVarDeclCommon(VarDeclBase* varDecl); + void visitAggTypeDeclBase(AggTypeDeclBase* decl); + void visitNamespaceDeclBase(NamespaceDeclBase* decl); - void visitParamDecl(ParamDecl* paramDecl) - { - checkVarDeclCommon(paramDecl); - } + void visitVarDecl(VarDecl* varDecl) { checkVarDeclCommon(varDecl); } - void visitFunctionDeclBase(FunctionDeclBase* funcDecl); + void visitParamDecl(ParamDecl* paramDecl) { checkVarDeclCommon(paramDecl); } - void visitInheritanceDecl(InheritanceDecl* inheritanceDecl); + void visitFunctionDeclBase(FunctionDeclBase* funcDecl); - void diagnoseUndeclaredCapability(Decl* decl, const DiagnosticInfo& diagnosticInfo, const CapabilityAtomSet& failedAtomsInsideAvailableSet); - }; + void visitInheritanceDecl(InheritanceDecl* inheritanceDecl); + void diagnoseUndeclaredCapability( + Decl* decl, + const DiagnosticInfo& diagnosticInfo, + const CapabilityAtomSet& failedAtomsInsideAvailableSet); +}; - /// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration? - bool isEffectivelyStatic( - Decl* decl, - ContainerDecl* parentDecl) - { - // Things at the global scope are always "members" of their module. - // - if(as(parentDecl)) - return false; - if (as(parentDecl)) - return false; - // Anything explicitly marked `static` and not at module scope - // counts as a static rather than instance declaration. - // - if(decl->hasModifier()) - return true; +/// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance +/// declaration? +bool isEffectivelyStatic(Decl* decl, ContainerDecl* parentDecl) +{ + // Things at the global scope are always "members" of their module. + // + if (as(parentDecl)) + return false; + if (as(parentDecl)) + return false; - // Next we need to deal with cases where a declaration is - // effectively `static` even if the language doesn't make - // the user say so. Most languages make the default assumption - // that nested types are `static` even if they don't say - // so (Java is an exception here, perhaps due to some - // influence from the Scandanavian OOP tradition of Beta/gbeta). - // - if(as(decl)) - return true; - if(as(decl)) - return true; + // Anything explicitly marked `static` and not at module scope + // counts as a static rather than instance declaration. + // + if (decl->hasModifier()) + return true; - // Initializer/constructor declarations are effectively `static` - // in Slang. They behave like functions that return an instance - // of the enclosing type, rather than as functions that are - // called on a pre-existing value. - // - if(as(decl)) - return true; + // Next we need to deal with cases where a declaration is + // effectively `static` even if the language doesn't make + // the user say so. Most languages make the default assumption + // that nested types are `static` even if they don't say + // so (Java is an exception here, perhaps due to some + // influence from the Scandanavian OOP tradition of Beta/gbeta). + // + if (as(decl)) + return true; + if (as(decl)) + return true; - if (as(decl)) - return true; + // Initializer/constructor declarations are effectively `static` + // in Slang. They behave like functions that return an instance + // of the enclosing type, rather than as functions that are + // called on a pre-existing value. + // + if (as(decl)) + return true; - // Things nested inside functions may have dependencies - // on values from the enclosing scope, but this needs to - // be dealt with via "capture" so they are also effectively - // `static` - // - if(as(parentDecl)) - return true; + if (as(decl)) + return true; - // Type constraint declarations are used in member-reference - // context as a form of casting operation, so we treat them - // as if they are instance members. This is a bit of a hack, - // but it achieves the result we want until we have an - // explicit representation of up-cast operations in the - // AST. - // - if(as(decl)) - return false; + // Things nested inside functions may have dependencies + // on values from the enclosing scope, but this needs to + // be dealt with via "capture" so they are also effectively + // `static` + // + if (as(parentDecl)) + return true; + // Type constraint declarations are used in member-reference + // context as a form of casting operation, so we treat them + // as if they are instance members. This is a bit of a hack, + // but it achieves the result we want until we have an + // explicit representation of up-cast operations in the + // AST. + // + if (as(decl)) return false; - } - bool isEffectivelyStatic( - Decl* decl) - { - // For the purposes of an ordinary declaration, when determining if - // it is static or per-instance, the "parent" declaration we really - // care about is the next outer non-generic declaration. - // - // TODO: This idiom of getting the "next outer non-generic declaration" - // comes up just enough that we should probably have a convenience - // function for it. + return false; +} - auto parentDecl = decl->parentDecl; - if(auto genericDecl = as(parentDecl)) - parentDecl = genericDecl->parentDecl; +bool isEffectivelyStatic(Decl* decl) +{ + // For the purposes of an ordinary declaration, when determining if + // it is static or per-instance, the "parent" declaration we really + // care about is the next outer non-generic declaration. + // + // TODO: This idiom of getting the "next outer non-generic declaration" + // comes up just enough that we should probably have a convenience + // function for it. + + auto parentDecl = decl->parentDecl; + if (auto genericDecl = as(parentDecl)) + parentDecl = genericDecl->parentDecl; + + return isEffectivelyStatic(decl, parentDecl); +} - return isEffectivelyStatic(decl, parentDecl); - } +bool isGlobalDecl(Decl* decl) +{ + if (!decl) + return false; + auto parentDecl = decl->parentDecl; + if (auto genericDecl = as(parentDecl)) + parentDecl = genericDecl->parentDecl; + return as(parentDecl) != nullptr || as(parentDecl) != nullptr; +} - bool isGlobalDecl(Decl* decl) - { - if (!decl) - return false; - auto parentDecl = decl->parentDecl; - if (auto genericDecl = as(parentDecl)) - parentDecl = genericDecl->parentDecl; - return as(parentDecl) != nullptr || as(parentDecl) != nullptr; - } +bool isUnsafeForceInlineFunc(FunctionDeclBase* funcDecl) +{ + return funcDecl->hasModifier(); +} - bool isUnsafeForceInlineFunc(FunctionDeclBase* funcDecl) +/// Is `decl` a global shader parameter declaration? +bool isGlobalShaderParameter(VarDeclBase* decl) +{ + // If it's an *actual* global it is not a global shader parameter + if (decl->hasModifier()) { - return funcDecl->hasModifier(); + return false; } - /// Is `decl` a global shader parameter declaration? - bool isGlobalShaderParameter(VarDeclBase* decl) - { - // If it's an *actual* global it is not a global shader parameter - if (decl->hasModifier()) { return false; } - - // A global shader parameter must be declared at global or namespace - // scope, so that it has a single definition across the module. - // - if(!isGlobalDecl(decl)) return false; + // A global shader parameter must be declared at global or namespace + // scope, so that it has a single definition across the module. + // + if (!isGlobalDecl(decl)) + return false; - // A global variable marked `static` indicates a traditional - // global variable (albeit one that is implicitly local to - // the translation unit) - // - if(decl->hasModifier()) return false; + // A global variable marked `static` indicates a traditional + // global variable (albeit one that is implicitly local to + // the translation unit) + // + if (decl->hasModifier()) + return false; - // While not normally allowed, out variables are not constant - // parameters, this can happen for example in GLSL mode - if(decl->hasModifier()) return false; - if(decl->hasModifier()) return false; + // While not normally allowed, out variables are not constant + // parameters, this can happen for example in GLSL mode + if (decl->hasModifier()) + return false; + if (decl->hasModifier()) + return false; - // The `groupshared` modifier indicates that a variable cannot - // be a shader parameters, but is instead transient storage - // allocated for the duration of a thread-group's execution. - // - if(decl->hasModifier()) return false; + // The `groupshared` modifier indicates that a variable cannot + // be a shader parameters, but is instead transient storage + // allocated for the duration of a thread-group's execution. + // + if (decl->hasModifier()) + return false; - return true; - } + return true; +} - [[maybe_unused]] - static bool _isUncheckedLocalVar(const Decl* decl) - { - auto checkStateExt = decl->checkState; - auto isUnchecked = checkStateExt.getState() == DeclCheckState::Unchecked || checkStateExt.isBeingChecked(); - return isUnchecked && isLocalVar(decl); - } +[[maybe_unused]] static bool _isUncheckedLocalVar(const Decl* decl) +{ + auto checkStateExt = decl->checkState; + auto isUnchecked = + checkStateExt.getState() == DeclCheckState::Unchecked || checkStateExt.isBeingChecked(); + return isUnchecked && isLocalVar(decl); +} - // Get the type to use when referencing a declaration - QualType getTypeForDeclRef( - ASTBuilder* astBuilder, - SemanticsVisitor* sema, - DiagnosticSink* sink, - DeclRef declRef, - Type** outTypeResult, - SourceLoc loc) +// Get the type to use when referencing a declaration +QualType getTypeForDeclRef( + ASTBuilder* astBuilder, + SemanticsVisitor* sema, + DiagnosticSink* sink, + DeclRef declRef, + Type** outTypeResult, + SourceLoc loc) +{ + if (sema) { - if( sema ) - { - // If this is a local variable which hasn't been checked yet then - // it's probably a declare-after-use which has incorrectly got - // through declref resolution. - SLANG_ASSERT(!_isUncheckedLocalVar(declRef.getDecl())); + // If this is a local variable which hasn't been checked yet then + // it's probably a declare-after-use which has incorrectly got + // through declref resolution. + SLANG_ASSERT(!_isUncheckedLocalVar(declRef.getDecl())); - // Once we've ruled out the case of referencing a local declaration - // before it has been checked, we will go ahead and ensure that - // semantic checking has been performed on the chosen declaration, - // at least up to the point where we can query its type. - // - sema->ensureDecl(declRef, DeclCheckState::CanUseTypeOfValueDecl); - } + // Once we've ruled out the case of referencing a local declaration + // before it has been checked, we will go ahead and ensure that + // semantic checking has been performed on the chosen declaration, + // at least up to the point where we can query its type. + // + sema->ensureDecl(declRef, DeclCheckState::CanUseTypeOfValueDecl); + } - // We need to insert an appropriate type for the expression, based on - // what we found. - if (auto varDeclRef = declRef.as()) - { - QualType qualType; - qualType.type = getType(astBuilder, varDeclRef); + // We need to insert an appropriate type for the expression, based on + // what we found. + if (auto varDeclRef = declRef.as()) + { + QualType qualType; + qualType.type = getType(astBuilder, varDeclRef); - bool isLValue = true; - if(varDeclRef.getDecl()->findModifier()) - isLValue = false; + bool isLValue = true; + if (varDeclRef.getDecl()->findModifier()) + isLValue = false; - // Global-scope shader parameters should not be writable, - // since they are effectively program inputs. - // - // TODO: We could eventually treat a mutable global shader - // parameter as a shorthand for an immutable parameter and - // a global variable that gets initialized from that parameter, - // but in order to do so we'd need to support global variables - // with resource types better in the back-end. - // - if(isGlobalShaderParameter(varDeclRef.getDecl())) - isLValue = false; + // Global-scope shader parameters should not be writable, + // since they are effectively program inputs. + // + // TODO: We could eventually treat a mutable global shader + // parameter as a shorthand for an immutable parameter and + // a global variable that gets initialized from that parameter, + // but in order to do so we'd need to support global variables + // with resource types better in the back-end. + // + if (isGlobalShaderParameter(varDeclRef.getDecl())) + isLValue = false; - // Variables declared with `let` are always immutable. - if(varDeclRef.is()) - isLValue = false; + // Variables declared with `let` are always immutable. + if (varDeclRef.is()) + isLValue = false; - // Generic value parameters are always immutable - if(varDeclRef.is()) - isLValue = false; + // Generic value parameters are always immutable + if (varDeclRef.is()) + isLValue = false; - // Function parameters declared in the "modern" style - // are immutable unless they have an `out` or `inout` modifier. - if(varDeclRef.is()) + // Function parameters declared in the "modern" style + // are immutable unless they have an `out` or `inout` modifier. + if (varDeclRef.is()) + { + // Note: the `inout` modifier AST class inherits from + // the class for the `out` modifier so that we can + // make simple checks like this. + // + if (!varDeclRef.getDecl()->hasModifier()) { - // Note: the `inout` modifier AST class inherits from - // the class for the `out` modifier so that we can - // make simple checks like this. - // - if( !varDeclRef.getDecl()->hasModifier() ) - { - isLValue = false; - } + isLValue = false; } + } - // Ensures child of struct is set read-only or not - bool isWriteOnly = false; - if(auto collection = varDeclRef.getDecl()->findModifier()) + // Ensures child of struct is set read-only or not + bool isWriteOnly = false; + if (auto collection = varDeclRef.getDecl()->findModifier()) + { + if (collection->getMemoryQualifierBit() & MemoryQualifierSetModifier::Flags::kReadOnly) { - if(collection->getMemoryQualifierBit() & MemoryQualifierSetModifier::Flags::kReadOnly) - { - isLValue = false; - qualType.hasReadOnlyOnTarget = true; - } - if(collection->getMemoryQualifierBit() & MemoryQualifierSetModifier::Flags::kWriteOnly) - isWriteOnly = true; + isLValue = false; + qualType.hasReadOnlyOnTarget = true; } - - qualType.isLeftValue = isLValue; - qualType.isWriteOnly = isWriteOnly; - return qualType; + if (collection->getMemoryQualifierBit() & MemoryQualifierSetModifier::Flags::kWriteOnly) + isWriteOnly = true; } - else if( auto propertyDeclRef = declRef.as() ) - { - // Access to a declared `property` is similar to - // access to a variable/field, except that it - // is mediated through accessors (getters, seters, etc.). - QualType qualType; - qualType.type = getType(astBuilder, propertyDeclRef); + qualType.isLeftValue = isLValue; + qualType.isWriteOnly = isWriteOnly; + return qualType; + } + else if (auto propertyDeclRef = declRef.as()) + { + // Access to a declared `property` is similar to + // access to a variable/field, except that it + // is mediated through accessors (getters, seters, etc.). - bool isLValue = false; + QualType qualType; + qualType.type = getType(astBuilder, propertyDeclRef); - // If the property has any declared accessors that - // can be used to set the property, then the resulting - // expression behaves as an l-value. - // - if(propertyDeclRef.getDecl()->getMembersOfType().isNonEmpty()) - isLValue = true; - if(propertyDeclRef.getDecl()->getMembersOfType().isNonEmpty()) - isLValue = true; + bool isLValue = false; - qualType.isLeftValue = isLValue; - return qualType; + // If the property has any declared accessors that + // can be used to set the property, then the resulting + // expression behaves as an l-value. + // + if (propertyDeclRef.getDecl()->getMembersOfType().isNonEmpty()) + isLValue = true; + if (propertyDeclRef.getDecl()->getMembersOfType().isNonEmpty()) + isLValue = true; - } - else if( auto enumCaseDeclRef = declRef.as() ) - { - sema->ensureDecl(declRef.declRefBase, DeclCheckState::DefinitionChecked); - QualType qualType; - qualType.type = getType(astBuilder, enumCaseDeclRef); - qualType.isLeftValue = false; - return qualType; - } - else if (auto typeAliasDeclRef = declRef.as()) - { - auto type = getNamedType(astBuilder, typeAliasDeclRef); - *outTypeResult = type; - return QualType(astBuilder->getTypeType(type)); - } - else if (auto aggTypeDeclRef = declRef.as()) - { - auto type = DeclRefType::create(astBuilder, aggTypeDeclRef); - *outTypeResult = type; - return QualType(astBuilder->getTypeType(type)); - } - else if (auto simpleTypeDeclRef = declRef.as()) - { - auto type = DeclRefType::create(astBuilder, simpleTypeDeclRef); - *outTypeResult = type; - return QualType(astBuilder->getTypeType(type)); - } - else if (auto genericDeclRef = declRef.as()) - { - auto type = getGenericDeclRefType(astBuilder, genericDeclRef); - *outTypeResult = type; - return QualType(astBuilder->getTypeType(type)); - } - else if (auto funcDeclRef = declRef.as()) - { - auto type = getFuncType(astBuilder, funcDeclRef); - return QualType(type); - } - else if (auto constraintDeclRef = declRef.as()) - { - // When we access a constraint or an inheritance decl (as a member), - // we are conceptually performing a "cast" to the given super-type, - // with the declaration showing that such a cast is legal. - auto type = getSup(astBuilder, constraintDeclRef); - return QualType(type); - } - else if( auto namespaceDeclRef = declRef.as()) - { - auto type = getNamespaceType(astBuilder, namespaceDeclRef); - return QualType(type); - } - if( sink ) - { - // The compiler is trying to form a reference to a declaration - // that doesn't appear to be usable as an expression or type. - // - // In practice, this arises when user code has an undefined-identifier - // error, but the name that was undefined in context also matches - // a contextual keyword. Rather than confuse the user with the - // details of contextual keywords in the compiler, we will diagnose - // this as an undefined identifier. - // - // TODO: This code could break if we ever go down this path with - // an identifier that doesn't have a name. - // - sink->diagnose(loc, Diagnostics::undefinedIdentifier2, declRef.getName()); - } - return QualType(astBuilder->getErrorType()); + qualType.isLeftValue = isLValue; + return qualType; } - - QualType getTypeForDeclRef( - ASTBuilder* astBuilder, - DeclRef declRef, - SourceLoc loc) + else if (auto enumCaseDeclRef = declRef.as()) { - Type* typeResult = nullptr; - return getTypeForDeclRef(astBuilder, nullptr, nullptr, declRef, &typeResult, loc); + sema->ensureDecl(declRef.declRefBase, DeclCheckState::DefinitionChecked); + QualType qualType; + qualType.type = getType(astBuilder, enumCaseDeclRef); + qualType.isLeftValue = false; + return qualType; } - - DeclRef applyExtensionToType( - SemanticsVisitor* semantics, - ExtensionDecl* extDecl, - Type* type, - Dictionary* additionalSubtypeWitness) + else if (auto typeAliasDeclRef = declRef.as()) { - if(!semantics) - return DeclRef(); - - return semantics->applyExtensionToType(extDecl, type, additionalSubtypeWitness); + auto type = getNamedType(astBuilder, typeAliasDeclRef); + *outTypeResult = type; + return QualType(astBuilder->getTypeType(type)); } - - bool SemanticsVisitor::isDeclUsableAsStaticMember( - Decl* decl) + else if (auto aggTypeDeclRef = declRef.as()) { - if (m_allowStaticReferenceToNonStaticMember) - return true; + auto type = DeclRefType::create(astBuilder, aggTypeDeclRef); + *outTypeResult = type; + return QualType(astBuilder->getTypeType(type)); + } + else if (auto simpleTypeDeclRef = declRef.as()) + { + auto type = DeclRefType::create(astBuilder, simpleTypeDeclRef); + *outTypeResult = type; + return QualType(astBuilder->getTypeType(type)); + } + else if (auto genericDeclRef = declRef.as()) + { + auto type = getGenericDeclRefType(astBuilder, genericDeclRef); + *outTypeResult = type; + return QualType(astBuilder->getTypeType(type)); + } + else if (auto funcDeclRef = declRef.as()) + { + auto type = getFuncType(astBuilder, funcDeclRef); + return QualType(type); + } + else if (auto constraintDeclRef = declRef.as()) + { + // When we access a constraint or an inheritance decl (as a member), + // we are conceptually performing a "cast" to the given super-type, + // with the declaration showing that such a cast is legal. + auto type = getSup(astBuilder, constraintDeclRef); + return QualType(type); + } + else if (auto namespaceDeclRef = declRef.as()) + { + auto type = getNamespaceType(astBuilder, namespaceDeclRef); + return QualType(type); + } + if (sink) + { + // The compiler is trying to form a reference to a declaration + // that doesn't appear to be usable as an expression or type. + // + // In practice, this arises when user code has an undefined-identifier + // error, but the name that was undefined in context also matches + // a contextual keyword. Rather than confuse the user with the + // details of contextual keywords in the compiler, we will diagnose + // this as an undefined identifier. + // + // TODO: This code could break if we ever go down this path with + // an identifier that doesn't have a name. + // + sink->diagnose(loc, Diagnostics::undefinedIdentifier2, declRef.getName()); + } + return QualType(astBuilder->getErrorType()); +} - if(auto genericDecl = as(decl)) - decl = genericDecl->inner; +QualType getTypeForDeclRef(ASTBuilder* astBuilder, DeclRef declRef, SourceLoc loc) +{ + Type* typeResult = nullptr; + return getTypeForDeclRef(astBuilder, nullptr, nullptr, declRef, &typeResult, loc); +} - if(decl->hasModifier()) - return true; +DeclRef applyExtensionToType( + SemanticsVisitor* semantics, + ExtensionDecl* extDecl, + Type* type, + Dictionary* additionalSubtypeWitness) +{ + if (!semantics) + return DeclRef(); - if(as(decl)) - return true; + return semantics->applyExtensionToType(extDecl, type, additionalSubtypeWitness); +} - if(as(decl)) - return true; +bool SemanticsVisitor::isDeclUsableAsStaticMember(Decl* decl) +{ + if (m_allowStaticReferenceToNonStaticMember) + return true; - if(as(decl)) - return true; + if (auto genericDecl = as(decl)) + decl = genericDecl->inner; - if(as(decl)) - return true; + if (decl->hasModifier()) + return true; - if(as(decl)) - return true; + if (as(decl)) + return true; - return false; - } + if (as(decl)) + return true; - bool SemanticsVisitor::isUsableAsStaticMember( - LookupResultItem const& item) - { - if (m_allowStaticReferenceToNonStaticMember) - return true; + if (as(decl)) + return true; - // There's a bit of a gotcha here, because a lookup result - // item might include "breadcrumbs" that indicate more steps - // along the lookup path. As a result it isn't always - // valid to just check whether the final decl is usable - // as a static member, because it might not even be a - // member of the thing we are trying to work with. - // + if (as(decl)) + return true; - Decl* decl = item.declRef.getDecl(); - for(auto bb = item.breadcrumbs; bb; bb = bb->next) + if (as(decl)) + return true; + + return false; +} + +bool SemanticsVisitor::isUsableAsStaticMember(LookupResultItem const& item) +{ + if (m_allowStaticReferenceToNonStaticMember) + return true; + + // There's a bit of a gotcha here, because a lookup result + // item might include "breadcrumbs" that indicate more steps + // along the lookup path. As a result it isn't always + // valid to just check whether the final decl is usable + // as a static member, because it might not even be a + // member of the thing we are trying to work with. + // + + Decl* decl = item.declRef.getDecl(); + for (auto bb = item.breadcrumbs; bb; bb = bb->next) + { + switch (bb->kind) { - switch(bb->kind) - { - // In case lookup went through a `__transparent` member, - // we are interested in the static-ness of that transparent - // member, and *not* the static-ness of whatever was inside - // of it. - // - // TODO: This would need some work if we ever had - // transparent *type* members. - // - case LookupResultItem::Breadcrumb::Kind::Member: - decl = bb->declRef.getDecl(); - break; + // In case lookup went through a `__transparent` member, + // we are interested in the static-ness of that transparent + // member, and *not* the static-ness of whatever was inside + // of it. + // + // TODO: This would need some work if we ever had + // transparent *type* members. + // + case LookupResultItem::Breadcrumb::Kind::Member: + decl = bb->declRef.getDecl(); + break; // TODO: Are there any other cases that need special-case // handling here? - default: - break; - } + default: break; } + } - // Okay, we've found the declaration we should actually - // be checking, so lets validate that. + // Okay, we've found the declaration we should actually + // be checking, so lets validate that. - return isDeclUsableAsStaticMember(decl); - } + return isDeclUsableAsStaticMember(decl); +} + +/// Dispatch an appropriate visitor to check `decl` up to state `state` +/// +/// The current state of `decl` must be `state-1`. +/// This call does *not* handle updating the state of `decl`; the +/// caller takes responsibility for doing so. +/// +static void _dispatchDeclCheckingVisitor( + Decl* decl, + DeclCheckState state, + SemanticsContext& shared); + +// Make sure a declaration has been checked, so we can refer to it. +// Note that this may lead to us recursively invoking checking, +// so this may not be the best way to handle things. +void SemanticsVisitor::ensureDecl(Decl* decl, DeclCheckState state, SemanticsContext* baseContext) +{ + // If the `decl` has already been checked up to or beyond `state` + // then there is nothing for us to do. + // + if (decl->isChecked(state)) + return; + + // Is the declaration already being checked, somewhere up the + // call stack from us? + // + if (decl->checkState.isBeingChecked()) + { + // We tried to reference the same declaration while checking it! + // + // TODO: we should ideally be tracking a "chain" of declarations + // being checked on the stack, so that we can report the full + // chain that leads from this declaration back to itself. + // + getSink()->diagnose(decl, Diagnostics::cyclicReference, decl); + return; + } + + // If we should skip the checking, return now. + // A common case to skip checking is for the function bodies when we are in + // the language server. In that case we only care about the function bodies in a + // specific module and can skip checking the reference modules until they + // are being opened/edited later. + if (shouldSkipChecking(decl, state)) + { + decl->setCheckState(state); + return; + } + + // Set the flag that indicates we are checking this declaration, + // so that the cycle check above will catch us before we go + // into any infinite loops. + // + decl->checkState.setIsBeingChecked(true); + + // Our task is to bring the `decl` up to `state` which may be + // one or more steps ahead of where it currently is. We can + // invoke a visitor designed to bring a declaration from state + // N to state N+1, and in general we might need multiple such + // passes to get `decl` to where we need it. + // + // The coding of this loop is somewhat defensive to deal + // with special cases that will be described along the way. + // + auto outerScope = getScope(decl); + for (;;) + { + // The first thing is to check what state the decl is + // currently in at the start of this loop iteration, + // and to bail out if it has been checked up to + // (or beyond) our target state. + // + auto currentState = decl->checkState.getState(); + if (currentState >= state) + break; - /// Dispatch an appropriate visitor to check `decl` up to state `state` - /// - /// The current state of `decl` must be `state-1`. - /// This call does *not* handle updating the state of `decl`; the - /// caller takes responsibility for doing so. - /// - static void _dispatchDeclCheckingVisitor(Decl* decl, DeclCheckState state, SemanticsContext& shared); + // Because our visitors are only designed to go from state + // N to N+1 in general, we will aspire to transition to + // a state that is one greater than `currentState`. + // + auto nextState = DeclCheckState(Int(currentState) + 1); - // Make sure a declaration has been checked, so we can refer to it. - // Note that this may lead to us recursively invoking checking, - // so this may not be the best way to handle things. - void SemanticsVisitor::ensureDecl(Decl* decl, DeclCheckState state, SemanticsContext* baseContext) - { - // If the `decl` has already been checked up to or beyond `state` - // then there is nothing for us to do. + // We now dispatch an appropriate visitor based on `nextState`. + // + // Note that we always dispatch the visitor in a "fresh" semantic-checking + // context, so that the state at the point where a declaration is *referenced* + // cannot affect the state in which the declaration is *checked*. // - if (decl->isChecked(state)) return; + SemanticsContext subContext = + baseContext ? SemanticsContext(*baseContext) : SemanticsContext(getShared()); + if (outerScope) + subContext = subContext.withOuterScope(outerScope); + _dispatchDeclCheckingVisitor(decl, nextState, subContext); - // Is the declaration already being checked, somewhere up the - // call stack from us? + // In the common case, the visitor will have done the necessary + // checking, but will *not* have updated the `checkState` on + // `decl`. In that case we will do the update here, to save + // us the complication of having to deal with state update in + // every single visitor method. // - if(decl->checkState.isBeingChecked()) + // However, sometimes a visitor *will* want to manually update + // the state of a declaration, and it may actually update it + // *past* the `nextState` we asked for (or even past the + // eventual target `state`). In those cases we don't want to + // accidentally set the state of `decl` to something lower + // than what has actually been checked, so we test for + // such cases here. + // + if (nextState > decl->checkState.getState()) { - // We tried to reference the same declaration while checking it! - // - // TODO: we should ideally be tracking a "chain" of declarations - // being checked on the stack, so that we can report the full - // chain that leads from this declaration back to itself. - // - getSink()->diagnose(decl, Diagnostics::cyclicReference, decl); - return; + decl->setCheckState(nextState); } + } - // If we should skip the checking, return now. - // A common case to skip checking is for the function bodies when we are in - // the language server. In that case we only care about the function bodies in a - // specific module and can skip checking the reference modules until they - // are being opened/edited later. - if (shouldSkipChecking(decl, state)) - { - decl->setCheckState(state); - return; - } + // Once we are done here, the state of `decl` should have + // been upgraded to (at least) `state`. + // + SLANG_ASSERT(decl->isChecked(state)); - // Set the flag that indicates we are checking this declaration, - // so that the cycle check above will catch us before we go - // into any infinite loops. - // - decl->checkState.setIsBeingChecked(true); + // Now that we are done checking `decl` we need to restore + // its "is being checked" flag so that we don't generate + // errors the next time somebody calls `ensureDecl()` on it. + // + decl->checkState.setIsBeingChecked(false); +} - // Our task is to bring the `decl` up to `state` which may be - // one or more steps ahead of where it currently is. We can - // invoke a visitor designed to bring a declaration from state - // N to state N+1, and in general we might need multiple such - // passes to get `decl` to where we need it. - // - // The coding of this loop is somewhat defensive to deal - // with special cases that will be described along the way. +/// Recursively ensure the tree of declarations under `decl` is in `state`. +/// +/// This function does *not* handle declarations nested in function bodies +/// because those cannot be meaningfully checked outside of the context +/// of their surrounding statement(s). +/// +void SemanticsVisitor::ensureAllDeclsRec(Decl* decl, DeclCheckState state) +{ + // Ensure `decl` itself first. + ensureDecl(decl, state); + + // If `decl` is a container, then we want to ensure its children. + if (auto containerDecl = as(decl)) + { + // NOTE! We purposefully do not iterate with the for(auto childDecl : + // containerDecl->members) here, because the visitor may add to `members` whilst iteration + // takes place, invalidating the iterator and likely a crash. // - auto outerScope = getScope(decl); - for(;;) + // Accessing the members via index side steps the issue. + const auto& members = containerDecl->members; + for (Index i = 0; i < members.getCount(); ++i) { - // The first thing is to check what state the decl is - // currently in at the start of this loop iteration, - // and to bail out if it has been checked up to - // (or beyond) our target state. - // - auto currentState = decl->checkState.getState(); - if(currentState >= state) - break; - - // Because our visitors are only designed to go from state - // N to N+1 in general, we will aspire to transition to - // a state that is one greater than `currentState`. - // - auto nextState = DeclCheckState(Int(currentState) + 1); + Decl* childDecl = members[i]; - // We now dispatch an appropriate visitor based on `nextState`. + // As an exception, if any of the child is a `ScopeDecl`, + // then that indicates that it represents a scope for local + // declarations under a statement (e.g., in a function body), + // and we don't want to check such local declarations here. // - // Note that we always dispatch the visitor in a "fresh" semantic-checking - // context, so that the state at the point where a declaration is *referenced* - // cannot affect the state in which the declaration is *checked*. - // - SemanticsContext subContext = baseContext ? SemanticsContext(*baseContext) : SemanticsContext(getShared()); - if (outerScope) - subContext = subContext.withOuterScope(outerScope); - _dispatchDeclCheckingVisitor(decl, nextState, subContext); - - // In the common case, the visitor will have done the necessary - // checking, but will *not* have updated the `checkState` on - // `decl`. In that case we will do the update here, to save - // us the complication of having to deal with state update in - // every single visitor method. - // - // However, sometimes a visitor *will* want to manually update - // the state of a declaration, and it may actually update it - // *past* the `nextState` we asked for (or even past the - // eventual target `state`). In those cases we don't want to - // accidentally set the state of `decl` to something lower - // than what has actually been checked, so we test for - // such cases here. - // - if(nextState > decl->checkState.getState()) - { - decl->setCheckState(nextState); - } - } - // Once we are done here, the state of `decl` should have - // been upgraded to (at least) `state`. - // - SLANG_ASSERT(decl->isChecked(state)); + if (as(childDecl)) + continue; - // Now that we are done checking `decl` we need to restore - // its "is being checked" flag so that we don't generate - // errors the next time somebody calls `ensureDecl()` on it. - // - decl->checkState.setIsBeingChecked(false); + ensureAllDeclsRec(childDecl, state); + } } - /// Recursively ensure the tree of declarations under `decl` is in `state`. - /// - /// This function does *not* handle declarations nested in function bodies - /// because those cannot be meaningfully checked outside of the context - /// of their surrounding statement(s). - /// - void SemanticsVisitor::ensureAllDeclsRec( - Decl* decl, - DeclCheckState state) + // Note: the "inner" declaration of a `GenericDecl` is currently + // not exposed as one of its children (despite a `GenericDecl` + // being a `ContainerDecl`), so we need to handle the inner + // declaration of a generic as another case here. + // + if (auto genericDecl = as(decl)) { - // Ensure `decl` itself first. - ensureDecl(decl, state); - - // If `decl` is a container, then we want to ensure its children. - if(auto containerDecl = as(decl)) - { - // NOTE! We purposefully do not iterate with the for(auto childDecl : containerDecl->members) here, - // because the visitor may add to `members` whilst iteration takes place, invalidating the iterator - // and likely a crash. - // - // Accessing the members via index side steps the issue. - const auto& members = containerDecl->members; - for(Index i = 0; i < members.getCount(); ++i) - { - Decl* childDecl = members[i]; - - // As an exception, if any of the child is a `ScopeDecl`, - // then that indicates that it represents a scope for local - // declarations under a statement (e.g., in a function body), - // and we don't want to check such local declarations here. - // - - if(as(childDecl)) - continue; + ensureAllDeclsRec(genericDecl->inner, state); + } +} - ensureAllDeclsRec(childDecl, state); - } - } +bool isUnsizedArrayType(Type* type) +{ + // Not an array? + auto arrayType = as(type); + if (!arrayType) + return false; - // Note: the "inner" declaration of a `GenericDecl` is currently - // not exposed as one of its children (despite a `GenericDecl` - // being a `ContainerDecl`), so we need to handle the inner - // declaration of a generic as another case here. - // - if(auto genericDecl = as(decl)) - { - ensureAllDeclsRec(genericDecl->inner, state); - } - } + // Explicit element count given? + return arrayType->isUnsized(); +} - bool isUnsizedArrayType(Type* type) +bool isInterfaceType(Type* type) +{ + if (auto declRefType = as(type)) { - // Not an array? - auto arrayType = as(type); - if (!arrayType) return false; - - // Explicit element count given? - return arrayType->isUnsized(); + if (auto interfaceDeclRef = declRefType->getDeclRef().as()) + return true; } + return false; +} - bool isInterfaceType(Type* type) +EnumDecl* isEnumType(Type* type) +{ + if (auto declRefType = as(type)) { - if (auto declRefType = as(type)) - { - if (auto interfaceDeclRef = declRefType->getDeclRef().as()) - return true; - } - return false; + return as(declRefType->getDeclRef().getDecl()); } + return nullptr; +} - EnumDecl* isEnumType(Type* type) +bool SemanticsVisitor::shouldSkipChecking(Decl* decl, DeclCheckState state) +{ + if (state < DeclCheckState::DefinitionChecked) + return false; + // If we are in language server, we should skip checking all the function bodies + // except for the module or function that the user cared about. + // This optimization helps reduce the response time. + if (!getLinkage()->isInLanguageServer()) { - if (auto declRefType = as(type)) - { - return as(declRefType->getDeclRef().getDecl()); - } - return nullptr; + return false; } - - bool SemanticsVisitor::shouldSkipChecking(Decl* decl, DeclCheckState state) + if (auto funcDecl = as(decl)) { - if (state < DeclCheckState::DefinitionChecked) - return false; - // If we are in language server, we should skip checking all the function bodies - // except for the module or function that the user cared about. - // This optimization helps reduce the response time. - if (!getLinkage()->isInLanguageServer()) - { - return false; - } - if (auto funcDecl = as(decl)) + auto& assistInfo = getLinkage()->contentAssistInfo; + // If this func is not defined in the primary module, skip checking its body. + auto moduleDecl = getModuleDecl(decl); + if (moduleDecl && moduleDecl->getName() != assistInfo.primaryModuleName) + return true; + if (funcDecl->body) { - auto& assistInfo = getLinkage()->contentAssistInfo; - // If this func is not defined in the primary module, skip checking its body. - auto moduleDecl = getModuleDecl(decl); - if (moduleDecl && moduleDecl->getName() != assistInfo.primaryModuleName) + auto humaneLoc = + getLinkage()->getSourceManager()->getHumaneLoc(decl->loc, SourceLocType::Actual); + if (humaneLoc.pathInfo.foundPath != assistInfo.primaryModulePath) + { return true; - if (funcDecl->body) + } + if (assistInfo.checkingMode == ContentAssistCheckingMode::Completion) { - auto humaneLoc = getLinkage()->getSourceManager()->getHumaneLoc( - decl->loc, SourceLocType::Actual); - if (humaneLoc.pathInfo.foundPath != assistInfo.primaryModulePath) + // For completion requests, we skip all funtion bodies except for the one + // that the current cursor is in. + auto startingLine = humaneLoc.line; + for (auto modifier : funcDecl->modifiers) { - return true; + auto modifierLoc = getLinkage()->getSourceManager()->getHumaneLoc( + modifier->loc, + SourceLocType::Actual); + if (modifierLoc.line < startingLine) + startingLine = modifierLoc.line; } - if (assistInfo.checkingMode == ContentAssistCheckingMode::Completion) - { - // For completion requests, we skip all funtion bodies except for the one - // that the current cursor is in. - auto startingLine = humaneLoc.line; - for (auto modifier : funcDecl->modifiers) - { - auto modifierLoc = getLinkage()->getSourceManager()->getHumaneLoc( - modifier->loc, SourceLocType::Actual); - if (modifierLoc.line < startingLine) - startingLine = modifierLoc.line; - } - auto closingLoc = getLinkage()->getSourceManager()->getHumaneLoc( - funcDecl->closingSourceLoc, SourceLocType::Actual); + auto closingLoc = getLinkage()->getSourceManager()->getHumaneLoc( + funcDecl->closingSourceLoc, + SourceLocType::Actual); - if (assistInfo.cursorLine < startingLine || - assistInfo.cursorLine > closingLoc.line) - return true; - } + if (assistInfo.cursorLine < startingLine || assistInfo.cursorLine > closingLoc.line) + return true; } } - return false; } + return false; +} + +void SemanticsVisitor::_validateCircularVarDefinition(VarDeclBase* varDecl) +{ + // The easiest way to test if the declaration is circular is to + // validate it as a constant. + // + // TODO: The logic here will only apply for `static const` declarations + // of integer type, given that our constant folding currently only + // applies to such types. A more robust fix would involve a truly + // recursive walk of the AST declarations, and an even *more* robust + // fix would wait until after IR linking to detect and diagnose circularity + // in case it crosses module boundaries. + // + // + if (!isScalarIntegerType(varDecl->type)) + return; + tryConstantFoldDeclRef(DeclRef(varDecl), ConstantFoldingKind::LinkTime, nullptr); +} + +void SemanticsDeclModifiersVisitor::visitStructDecl(StructDecl* structDecl) +{ + checkModifiers(structDecl); + + // Replace any bitfield member with a property, do this here before + // name lookup to avoid the original var decl being referenced + for (auto& m : structDecl->members) + { + const auto bfm = m->findModifier(); + if (!bfm) + continue; + + auto property = m_astBuilder->create(); + property->modifiers = m->modifiers; + property->type = as(m)->type; + property->loc = m->loc; + property->nameAndLoc = m->getNameAndLoc(); + property->parentDecl = structDecl; + property->ownedScope = m_astBuilder->create(); + property->ownedScope->containerDecl = property; + property->ownedScope->parent = getScope(structDecl); + m = property; + + const auto get = m_astBuilder->create(); + get->ownedScope = m_astBuilder->create(); + get->ownedScope->containerDecl = get; + get->ownedScope->parent = getScope(property); + property->addMember(get); + + const auto set = m_astBuilder->create(); + addModifier(set, m_astBuilder->create()); + set->ownedScope = m_astBuilder->create(); + set->ownedScope->containerDecl = set; + set->ownedScope->parent = getScope(property); + property->addMember(set); + + structDecl->invalidateMemberDictionary(); + } + structDecl->buildMemberDictionary(); +} - void SemanticsVisitor::_validateCircularVarDefinition(VarDeclBase* varDecl) +void SemanticsDeclHeaderVisitor::checkDerivativeMemberAttributeParent( + VarDeclBase* varDecl, + DerivativeMemberAttribute* derivativeMemberAttr) +{ + auto memberType = checkProperType(getLinkage(), varDecl->type, getSink()); + auto diffType = getDifferentialType(m_astBuilder, memberType, varDecl->loc); + if (as(diffType)) { - // The easiest way to test if the declaration is circular is to - // validate it as a constant. - // - // TODO: The logic here will only apply for `static const` declarations - // of integer type, given that our constant folding currently only - // applies to such types. A more robust fix would involve a truly - // recursive walk of the AST declarations, and an even *more* robust - // fix would wait until after IR linking to detect and diagnose circularity - // in case it crosses module boundaries. - // - // - if(!isScalarIntegerType(varDecl->type)) - return; - tryConstantFoldDeclRef(DeclRef(varDecl), ConstantFoldingKind::LinkTime, nullptr); + getSink()->diagnose(derivativeMemberAttr, Diagnostics::typeIsNotDifferentiable, memberType); } - - void SemanticsDeclModifiersVisitor::visitStructDecl(StructDecl* structDecl) + auto thisType = calcThisType(makeDeclRef(varDecl->parentDecl)); + if (!thisType) { - checkModifiers(structDecl); - - // Replace any bitfield member with a property, do this here before - // name lookup to avoid the original var decl being referenced - for(auto& m : structDecl->members) - { - const auto bfm = m->findModifier(); - if(!bfm) - continue; - - auto property = m_astBuilder->create(); - property->modifiers = m->modifiers; - property->type = as(m)->type; - property->loc = m->loc; - property->nameAndLoc = m->getNameAndLoc(); - property->parentDecl = structDecl; - property->ownedScope = m_astBuilder->create(); - property->ownedScope->containerDecl = property; - property->ownedScope->parent = getScope(structDecl); - m = property; - - const auto get = m_astBuilder->create(); - get->ownedScope = m_astBuilder->create(); - get->ownedScope->containerDecl = get; - get->ownedScope->parent = getScope(property); - property->addMember(get); - - const auto set = m_astBuilder->create(); - addModifier(set, m_astBuilder->create()); - set->ownedScope = m_astBuilder->create(); - set->ownedScope->containerDecl = set; - set->ownedScope->parent = getScope(property); - property->addMember(set); - - structDecl->invalidateMemberDictionary(); - } - structDecl->buildMemberDictionary(); + getSink()->diagnose( + derivativeMemberAttr, + Diagnostics::derivativeMemberAttributeCanOnlyBeUsedOnMembers); } - - void SemanticsDeclHeaderVisitor::checkDerivativeMemberAttributeParent( - VarDeclBase* varDecl, DerivativeMemberAttribute* derivativeMemberAttr) + auto diffThisType = getDifferentialType(m_astBuilder, thisType, derivativeMemberAttr->loc); + if (!diffThisType) { - auto memberType = checkProperType(getLinkage(), varDecl->type, getSink()); - auto diffType = getDifferentialType(m_astBuilder, memberType, varDecl->loc); - if (as(diffType)) - { - getSink()->diagnose(derivativeMemberAttr, Diagnostics::typeIsNotDifferentiable, memberType); - } - auto thisType = calcThisType(makeDeclRef(varDecl->parentDecl)); - if (!thisType) - { - getSink()->diagnose( - derivativeMemberAttr, - Diagnostics:: - derivativeMemberAttributeCanOnlyBeUsedOnMembers); - } - auto diffThisType = getDifferentialType(m_astBuilder, thisType, derivativeMemberAttr->loc); - if (!diffThisType) - { - getSink()->diagnose( - derivativeMemberAttr, - Diagnostics::invalidUseOfDerivativeMemberAttributeParentTypeIsNotDifferentiable); - } + getSink()->diagnose( + derivativeMemberAttr, + Diagnostics::invalidUseOfDerivativeMemberAttributeParentTypeIsNotDifferentiable); } +} - void SemanticsDeclHeaderVisitor::checkExtensionExternVarAttribute(VarDeclBase* varDecl, ExtensionExternVarModifier* extensionExternMemberModifier) +void SemanticsDeclHeaderVisitor::checkExtensionExternVarAttribute( + VarDeclBase* varDecl, + ExtensionExternVarModifier* extensionExternMemberModifier) +{ + if (const auto parentExtension = as(varDecl->parentDecl)) { - if (const auto parentExtension = as(varDecl->parentDecl)) + if (auto originalVarDecl = extensionExternMemberModifier->originalDecl.as()) { - if (auto originalVarDecl = extensionExternMemberModifier->originalDecl.as()) + auto originalType = GetTypeForDeclRef(originalVarDecl, originalVarDecl.getLoc()); + auto extVarType = varDecl->type; + if (!extVarType.type || !extVarType.type->equals(originalType)) { - auto originalType = GetTypeForDeclRef(originalVarDecl, originalVarDecl.getLoc()); - auto extVarType = varDecl->type; - if (!extVarType.type || !extVarType.type->equals(originalType)) - { - getSink()->diagnose(varDecl, Diagnostics::typeOfExternDeclMismatchesOriginalDefinition, varDecl, originalType); - } - else - { - return; - } + getSink()->diagnose( + varDecl, + Diagnostics::typeOfExternDeclMismatchesOriginalDefinition, + varDecl, + originalType); } else { - getSink()->diagnose(varDecl, Diagnostics::definitionOfExternDeclMismatchesOriginalDefinition, varDecl); + return; } } + else + { + getSink()->diagnose( + varDecl, + Diagnostics::definitionOfExternDeclMismatchesOriginalDefinition, + varDecl); + } } +} - ImageFormat inferImageFormatFromTextureType(VarDeclBase* varDecl, TextureTypeBase* textureType, bool &outIsInferred) +ImageFormat inferImageFormatFromTextureType( + VarDeclBase* varDecl, + TextureTypeBase* textureType, + bool& outIsInferred) +{ + outIsInferred = false; + ImageFormat format = ImageFormat::unknown; + if (auto formatVal = as(textureType->getFormat())) { - outIsInferred = false; - ImageFormat format = ImageFormat::unknown; - if (auto formatVal = as(textureType->getFormat())) - { - format = (ImageFormat)formatVal->getValue(); - } - if (format != ImageFormat::unknown) - return format; + format = (ImageFormat)formatVal->getValue(); + } + if (format != ImageFormat::unknown) + return format; - if (auto formatAttrib = varDecl->findModifier()) + if (auto formatAttrib = varDecl->findModifier()) + { + format = formatAttrib->format; + } + else + { + // If format is not specified explicitly through format attribute, we will derive a default + // value from the element format. + outIsInferred = true; + auto elementType = textureType->getElementType(); + Int vectorWidth = 1; + if (auto elementVecType = as(elementType)) { - format = formatAttrib->format; + if (auto intLitVal = as(elementVecType->getElementCount())) + { + vectorWidth = (Int)intLitVal->getValue(); + } + else + { + vectorWidth = 1; + } + elementType = elementVecType->getElementType(); } - else + if (auto basicType = as(elementType)) { - // If format is not specified explicitly through format attribute, we will derive a default - // value from the element format. - outIsInferred = true; - auto elementType = textureType->getElementType(); - Int vectorWidth = 1; - if (auto elementVecType = as(elementType)) + switch (basicType->getBaseType()) { - if (auto intLitVal = as(elementVecType->getElementCount())) + case BaseType::UInt: + switch (vectorWidth) { - vectorWidth = (Int)intLitVal->getValue(); + case 1: format = ImageFormat::r32ui; break; + case 2: format = ImageFormat::rg32ui; break; + case 4: format = ImageFormat::rgba32ui; break; } - else + break; + case BaseType::Int: + switch (vectorWidth) { - vectorWidth = 1; + case 1: format = ImageFormat::r32i; break; + case 2: format = ImageFormat::rg32i; break; + case 4: format = ImageFormat::rgba32i; break; } - elementType = elementVecType->getElementType(); - } - if (auto basicType = as(elementType)) - { - switch (basicType->getBaseType()) + break; + case BaseType::UInt16: + switch (vectorWidth) { - case BaseType::UInt: - switch (vectorWidth) - { - case 1: format = ImageFormat::r32ui; break; - case 2: format = ImageFormat::rg32ui; break; - case 4: format = ImageFormat::rgba32ui; break; - } - break; - case BaseType::Int: - switch (vectorWidth) - { - case 1: format = ImageFormat::r32i; break; - case 2: format = ImageFormat::rg32i; break; - case 4: format = ImageFormat::rgba32i; break; - } - break; - case BaseType::UInt16: - switch (vectorWidth) - { - case 1: format = ImageFormat::r16ui; break; - case 2: format = ImageFormat::rg16ui; break; - case 4: format = ImageFormat::rgba16ui; break; - } - break; - case BaseType::Int16: - switch (vectorWidth) - { - case 1: format = ImageFormat::r16i; break; - case 2: format = ImageFormat::rg16i; break; - case 4: format = ImageFormat::rgba16i; break; - } - break; - case BaseType::UInt8: - switch (vectorWidth) - { - case 1: format = ImageFormat::r8ui; break; - case 2: format = ImageFormat::rg8ui; break; - case 4: format = ImageFormat::rgba8ui; break; - } - break; - case BaseType::Int8: - switch (vectorWidth) - { - case 1: format = ImageFormat::r8i; break; - case 2: format = ImageFormat::rg8i; break; - case 4: format = ImageFormat::rgba8i; break; - } - break; - case BaseType::Int64: - switch (vectorWidth) - { - case 1: format = ImageFormat::r64i; break; - default: break; - } - break; - case BaseType::UInt64: - switch (vectorWidth) - { - case 1: format = ImageFormat::r64ui; break; - default: break; - } - break; - case BaseType::Half: - switch (vectorWidth) - { - case 1: format = ImageFormat::r16f; break; - case 2: format = ImageFormat::rg16f; break; - case 4: format = ImageFormat::rgba16f; break; - } - break; + case 1: format = ImageFormat::r16ui; break; + case 2: format = ImageFormat::rg16ui; break; + case 4: format = ImageFormat::rgba16ui; break; + } + break; + case BaseType::Int16: + switch (vectorWidth) + { + case 1: format = ImageFormat::r16i; break; + case 2: format = ImageFormat::rg16i; break; + case 4: format = ImageFormat::rgba16i; break; + } + break; + case BaseType::UInt8: + switch (vectorWidth) + { + case 1: format = ImageFormat::r8ui; break; + case 2: format = ImageFormat::rg8ui; break; + case 4: format = ImageFormat::rgba8ui; break; + } + break; + case BaseType::Int8: + switch (vectorWidth) + { + case 1: format = ImageFormat::r8i; break; + case 2: format = ImageFormat::rg8i; break; + case 4: format = ImageFormat::rgba8i; break; + } + break; + case BaseType::Int64: + switch (vectorWidth) + { + case 1: format = ImageFormat::r64i; break; + default: break; + } + break; + case BaseType::UInt64: + switch (vectorWidth) + { + case 1: format = ImageFormat::r64ui; break; + default: break; } + break; + case BaseType::Half: + switch (vectorWidth) + { + case 1: format = ImageFormat::r16f; break; + case 2: format = ImageFormat::rg16f; break; + case 4: format = ImageFormat::rgba16f; break; + } + break; } } - return format; } + return format; +} - void SemanticsDeclHeaderVisitor::maybeApplyLayoutModifier(VarDeclBase* varDecl) +void SemanticsDeclHeaderVisitor::maybeApplyLayoutModifier(VarDeclBase* varDecl) +{ + if (auto matrixType = as(varDecl->type.type)) { - if (auto matrixType = as(varDecl->type.type)) + if (auto matrixLayoutModifier = varDecl->findModifier()) { - if (auto matrixLayoutModifier = varDecl->findModifier()) - { - auto matrixLayout = as(matrixLayoutModifier) ? SLANG_MATRIX_LAYOUT_COLUMN_MAJOR : SLANG_MATRIX_LAYOUT_ROW_MAJOR; - auto newMatrixType = getASTBuilder()->getMatrixType( - matrixType->getElementType(), - matrixType->getRowCount(), - matrixType->getColumnCount(), - getASTBuilder()->getIntVal(getASTBuilder()->getIntType(), matrixLayout)); - varDecl->type.type = newMatrixType; - } + auto matrixLayout = as(matrixLayoutModifier) + ? SLANG_MATRIX_LAYOUT_COLUMN_MAJOR + : SLANG_MATRIX_LAYOUT_ROW_MAJOR; + auto newMatrixType = getASTBuilder()->getMatrixType( + matrixType->getElementType(), + matrixType->getRowCount(), + matrixType->getColumnCount(), + getASTBuilder()->getIntVal(getASTBuilder()->getIntType(), matrixLayout)); + varDecl->type.type = newMatrixType; } - else if (auto textureType = as(unwrapArrayType(varDecl->type.type))) - { - if (getLinkage()->m_optionSet.getBoolOption(CompilerOptionName::DefaultImageFormatUnknown)) - return; + } + else if (auto textureType = as(unwrapArrayType(varDecl->type.type))) + { + if (getLinkage()->m_optionSet.getBoolOption(CompilerOptionName::DefaultImageFormatUnknown)) + return; - // For texture types, we will ensure there is a [format] attribute declared on the decl, - // if not, we will infer the format from the texture type if it is not specified. - // - bool isInferred = false; - auto format = inferImageFormatFromTextureType(varDecl, textureType, isInferred); - if (format != ImageFormat::unknown && isInferred) - { - auto formatAttrib = m_astBuilder->create(); - formatAttrib->format = format; - addModifier(varDecl, formatAttrib); - } + // For texture types, we will ensure there is a [format] attribute declared on the decl, + // if not, we will infer the format from the texture type if it is not specified. + // + bool isInferred = false; + auto format = inferImageFormatFromTextureType(varDecl, textureType, isInferred); + if (format != ImageFormat::unknown && isInferred) + { + auto formatAttrib = m_astBuilder->create(); + formatAttrib->format = format; + addModifier(varDecl, formatAttrib); } } +} - void SemanticsDeclHeaderVisitor::checkVarDeclCommon(VarDeclBase* varDecl) +void SemanticsDeclHeaderVisitor::checkVarDeclCommon(VarDeclBase* varDecl) +{ + // A variable that didn't have an explicit type written must + // have its type inferred from the initial-value expression. + // + if (!varDecl->type.exp) { - // A variable that didn't have an explicit type written must - // have its type inferred from the initial-value expression. - // - if(!varDecl->type.exp) - { - // In this case we need to perform all checking of the - // variable (including semantic checking of the initial-value - // expression) during the first phase of checking. + // In this case we need to perform all checking of the + // variable (including semantic checking of the initial-value + // expression) during the first phase of checking. - auto initExpr = varDecl->initExpr; - if(!initExpr) - { - if (!varDecl->type.type) - { - getSink()->diagnose(varDecl, Diagnostics::varWithoutTypeMustHaveInitializer); - varDecl->type.type = m_astBuilder->getErrorType(); - } - } - else + auto initExpr = varDecl->initExpr; + if (!initExpr) + { + if (!varDecl->type.type) { - SemanticsVisitor subVisitor(withDeclToExcludeFromLookup(varDecl)); - initExpr = subVisitor.CheckExpr(initExpr); - - // TODO: We might need some additional steps here to ensure - // that the type of the expression is one we are okay with - // inferring. E.g., if we ever decide that integer and floating-point - // literals have a distinct type from the standard int/float types, - // then we would need to "decay" a literal to an explicit type here. - - varDecl->initExpr = initExpr; - varDecl->type.type = initExpr->type; - _validateCircularVarDefinition(varDecl); + getSink()->diagnose(varDecl, Diagnostics::varWithoutTypeMustHaveInitializer); + varDecl->type.type = m_astBuilder->getErrorType(); } - - // If we've gone down this path, then the variable - // declaration is actually pretty far along in checking - varDecl->setCheckState(DeclCheckState::DefinitionChecked); } else { - // A variable with an explicit type is simpler, for the - // most part. SemanticsVisitor subVisitor(withDeclToExcludeFromLookup(varDecl)); - TypeExp typeExp = subVisitor.CheckUsableType(varDecl->type, varDecl); - varDecl->type = typeExp; - if (varDecl->type.equals(m_astBuilder->getVoidType())) - { - getSink()->diagnose(varDecl, Diagnostics::invalidTypeVoid); - } - - // If this is an unsized array variable, then we first want to give - // it a chance to infer an array size from its initializer - // - // TODO(tfoley): May need to extend this to handle the - // multi-dimensional case... - // - if(isUnsizedArrayType(varDecl->type)) - { - if (auto initExpr = varDecl->initExpr) - { - initExpr = CheckTerm(initExpr); - initExpr = coerce(CoercionSite::Initializer, varDecl->type.Ptr(), initExpr); - varDecl->initExpr = initExpr; + initExpr = subVisitor.CheckExpr(initExpr); - maybeInferArraySizeForVariable(varDecl); + // TODO: We might need some additional steps here to ensure + // that the type of the expression is one we are okay with + // inferring. E.g., if we ever decide that integer and floating-point + // literals have a distinct type from the standard int/float types, + // then we would need to "decay" a literal to an explicit type here. - varDecl->setCheckState(DeclCheckState::DefinitionChecked); - } - } - // - // Next we want to make sure that the declared (or inferred) - // size for the array meets whatever language-specific - // constraints we want to enforce (e.g., disallow empty - // arrays in specific cases) - // - validateArraySizeForVariable(varDecl); + varDecl->initExpr = initExpr; + varDecl->type.type = initExpr->type; + _validateCircularVarDefinition(varDecl); } - // If there is a matrix layout modifier or texture format modifier, we will modify the type now. - maybeApplyLayoutModifier(varDecl); - - if (varDecl->initExpr) + // If we've gone down this path, then the variable + // declaration is actually pretty far along in checking + varDecl->setCheckState(DeclCheckState::DefinitionChecked); + } + else + { + // A variable with an explicit type is simpler, for the + // most part. + SemanticsVisitor subVisitor(withDeclToExcludeFromLookup(varDecl)); + TypeExp typeExp = subVisitor.CheckUsableType(varDecl->type, varDecl); + varDecl->type = typeExp; + if (varDecl->type.equals(m_astBuilder->getVoidType())) { - if (as(varDecl->type.type)) - { - auto parentDecl = getParentDecl(varDecl); - if (varDecl->findModifier() && - (as(parentDecl) || as(parentDecl) || varDecl->findModifier())) - { - varDecl->val = tryConstantFoldExpr(varDecl->initExpr, ConstantFoldingKind::LinkTime, nullptr); - } - } + getSink()->diagnose(varDecl, Diagnostics::invalidTypeVoid); } - checkMeshOutputDecl(varDecl); - - // The NVAPI library allows user code to express extended operations - // (not supported natively by D3D HLSL) by communicating with - // a specially identified shader parameter called `g_NvidiaExt`. + // If this is an unsized array variable, then we first want to give + // it a chance to infer an array size from its initializer // - // By default, that shader parameter would look like an ordinary - // global shader parameter to Slang, but we want to be able to - // associate special behavior with it to make downstream compilation - // work nicely (especially in the case where certain cross-platform - // operations in the Slang core module need to use NVAPI). + // TODO(tfoley): May need to extend this to handle the + // multi-dimensional case... // - // We will detect a global variable declaration that appears to - // be declaring `g_NvidiaExt` from NVAPI, and mark it with a special - // modifier to allow downstream steps to detect it whether or - // not it has an associated name. - // - if( as(varDecl->parentDecl) - && varDecl->getName() - && varDecl->getName()->text == "g_NvidiaExt" ) + if (isUnsizedArrayType(varDecl->type)) + { + if (auto initExpr = varDecl->initExpr) + { + initExpr = CheckTerm(initExpr); + initExpr = coerce(CoercionSite::Initializer, varDecl->type.Ptr(), initExpr); + varDecl->initExpr = initExpr; + + maybeInferArraySizeForVariable(varDecl); + + varDecl->setCheckState(DeclCheckState::DefinitionChecked); + } + } + // + // Next we want to make sure that the declared (or inferred) + // size for the array meets whatever language-specific + // constraints we want to enforce (e.g., disallow empty + // arrays in specific cases) + // + validateArraySizeForVariable(varDecl); + } + + // If there is a matrix layout modifier or texture format modifier, we will modify the type now. + maybeApplyLayoutModifier(varDecl); + + if (varDecl->initExpr) + { + if (as(varDecl->type.type)) + { + auto parentDecl = getParentDecl(varDecl); + if (varDecl->findModifier() && + (as(parentDecl) || as(parentDecl) || + varDecl->findModifier())) + { + varDecl->val = + tryConstantFoldExpr(varDecl->initExpr, ConstantFoldingKind::LinkTime, nullptr); + } + } + } + + checkMeshOutputDecl(varDecl); + + // The NVAPI library allows user code to express extended operations + // (not supported natively by D3D HLSL) by communicating with + // a specially identified shader parameter called `g_NvidiaExt`. + // + // By default, that shader parameter would look like an ordinary + // global shader parameter to Slang, but we want to be able to + // associate special behavior with it to make downstream compilation + // work nicely (especially in the case where certain cross-platform + // operations in the Slang core module need to use NVAPI). + // + // We will detect a global variable declaration that appears to + // be declaring `g_NvidiaExt` from NVAPI, and mark it with a special + // modifier to allow downstream steps to detect it whether or + // not it has an associated name. + // + if (as(varDecl->parentDecl) && varDecl->getName() && + varDecl->getName()->text == "g_NvidiaExt") + { + addModifier(varDecl, m_astBuilder->create()); + } + // + // One thing that the `NVAPIMagicModifier` is going to do is ensure + // that `g_NvidiaExt` always gets emitted with *exactly* that name, + // whether or not obfuscation or other steps are enabled. + // + // The `g_NvidiaExt` variable is declared as a: + // + // RWStructuredBuffer + // + // and we also want to make sure that the fields of that struct + // retain their original names in output code. We will detect + // variable declarations that represent fields of that struct + // and flag them as "magic" as well. + // + // Note: The goal here is to make it so that generated HLSL output + // can either use these declarations as they have been preocessed + // by the Slang front-end *or* they can use declarations directly + // from the NVAPI header during downstream compilation. + // + // TODO: It would be nice if we had a way to identify *all* of the + // declarations that come from the NVAPI header and mark them, so + // that the Slang front-end doesn't have to take responsibility + // for generating code from them (and can instead rely on the downstream + // compiler alone). + // + // The NVAPI header doesn't put any kind of macro-defined modifier + // (defaulting to an empty macro) in front of its declarations, + // so the most plausible way to add a modifier to all the declarations + // would be to tag the `nvHLSLExtns.h` header in a list of "magic" + // headers which should get all their declarations flagged during + // front-end processing, and then use the same header again during + // downstream compilation. + // + // For now, the current hackery seems a bit less complicated. + // + if (auto structDecl = as(varDecl->parentDecl)) + { + if (structDecl->getName() && structDecl->getName()->text == "NvShaderExtnStruct") { addModifier(varDecl, m_astBuilder->create()); } - // - // One thing that the `NVAPIMagicModifier` is going to do is ensure - // that `g_NvidiaExt` always gets emitted with *exactly* that name, - // whether or not obfuscation or other steps are enabled. - // - // The `g_NvidiaExt` variable is declared as a: - // - // RWStructuredBuffer - // - // and we also want to make sure that the fields of that struct - // retain their original names in output code. We will detect - // variable declarations that represent fields of that struct - // and flag them as "magic" as well. - // - // Note: The goal here is to make it so that generated HLSL output - // can either use these declarations as they have been preocessed - // by the Slang front-end *or* they can use declarations directly - // from the NVAPI header during downstream compilation. - // - // TODO: It would be nice if we had a way to identify *all* of the - // declarations that come from the NVAPI header and mark them, so - // that the Slang front-end doesn't have to take responsibility - // for generating code from them (and can instead rely on the downstream - // compiler alone). - // - // The NVAPI header doesn't put any kind of macro-defined modifier - // (defaulting to an empty macro) in front of its declarations, - // so the most plausible way to add a modifier to all the declarations - // would be to tag the `nvHLSLExtns.h` header in a list of "magic" - // headers which should get all their declarations flagged during - // front-end processing, and then use the same header again during - // downstream compilation. - // - // For now, the current hackery seems a bit less complicated. - // - if( auto structDecl = as(varDecl->parentDecl)) + } + + if (const auto interfaceDecl = as(varDecl->parentDecl)) + { + if (auto basicType = as(varDecl->getType())) { - if( structDecl->getName() - && structDecl->getName()->text == "NvShaderExtnStruct" ) + switch (basicType->getBaseType()) { - addModifier(varDecl, m_astBuilder->create()); - } - } - - if (const auto interfaceDecl = as(varDecl->parentDecl)) - { - if (auto basicType = as(varDecl->getType())) - { - switch (basicType->getBaseType()) - { - case BaseType::Bool: - case BaseType::Int8: - case BaseType::Int16: - case BaseType::Int: - case BaseType::Int64: - case BaseType::IntPtr: - case BaseType::UInt8: - case BaseType::UInt16: - case BaseType::UInt: - case BaseType::UInt64: - case BaseType::UIntPtr: - break; - default: - getSink()->diagnose(varDecl, Diagnostics::staticConstRequirementMustBeIntOrBool); - break; - } - } - if (!varDecl->findModifier() || !varDecl->findModifier()) - { - getSink()->diagnose(varDecl, Diagnostics::valueRequirementMustBeCompileTimeConst); + case BaseType::Bool: + case BaseType::Int8: + case BaseType::Int16: + case BaseType::Int: + case BaseType::Int64: + case BaseType::IntPtr: + case BaseType::UInt8: + case BaseType::UInt16: + case BaseType::UInt: + case BaseType::UInt64: + case BaseType::UIntPtr: break; + default: + getSink()->diagnose(varDecl, Diagnostics::staticConstRequirementMustBeIntOrBool); + break; } } - - // Check modifiers that can't be checked earlier during modifier checking stage. - if (auto derivativeMemberAttr = varDecl->findModifier()) - { - checkDerivativeMemberAttributeParent(varDecl, derivativeMemberAttr); - } - if (auto extensionExternAttr = varDecl->findModifier()) + if (!varDecl->findModifier() || !varDecl->findModifier()) { - checkExtensionExternVarAttribute(varDecl, extensionExternAttr); + getSink()->diagnose(varDecl, Diagnostics::valueRequirementMustBeCompileTimeConst); } + } - // If a var decl has no_diff type, move the no_diff modifier from the type to the var. - if (auto modifiedType = as(varDecl->type.type)) + // Check modifiers that can't be checked earlier during modifier checking stage. + if (auto derivativeMemberAttr = varDecl->findModifier()) + { + checkDerivativeMemberAttributeParent(varDecl, derivativeMemberAttr); + } + if (auto extensionExternAttr = varDecl->findModifier()) + { + checkExtensionExternVarAttribute(varDecl, extensionExternAttr); + } + + // If a var decl has no_diff type, move the no_diff modifier from the type to the var. + if (auto modifiedType = as(varDecl->type.type)) + { + if (auto nodiffModifier = modifiedType->findModifier()) { - if (auto nodiffModifier = modifiedType->findModifier()) - { - varDecl->type.type = getRemovedModifierType(modifiedType, nodiffModifier); - auto noDiffModifier = m_astBuilder->create(); - noDiffModifier->loc = varDecl->loc; - addModifier(varDecl, noDiffModifier); - } + varDecl->type.type = getRemovedModifierType(modifiedType, nodiffModifier); + auto noDiffModifier = m_astBuilder->create(); + noDiffModifier->loc = varDecl->loc; + addModifier(varDecl, noDiffModifier); } + } - if (as(varDecl->parentDecl)) - { - // If this is a global variable with [vk::push_constant] attribute, - // we need to make sure to wrap it in a `ConstantBuffer`. - - if (!as(varDecl->type)) - { - if (varDecl->findModifier()) - { - varDecl->type.type = m_astBuilder->getConstantBufferType(varDecl->type); - } - } + if (as(varDecl->parentDecl)) + { + // If this is a global variable with [vk::push_constant] attribute, + // we need to make sure to wrap it in a `ConstantBuffer`. - if (getModuleDecl(varDecl)->hasModifier()) + if (!as(varDecl->type)) + { + if (varDecl->findModifier()) { - // If we are in GLSL compatiblity mode, we want to treat all global variables - // without any `uniform` modifiers as true global variables by default. - if (!varDecl->findModifier() && - !varDecl->findModifier() && - !varDecl->findModifier() && - !varDecl->findModifier()) - { - if (!isUniformParameterType(varDecl->type)) - { - auto staticModifier = m_astBuilder->create(); - addModifier(varDecl, staticModifier); - } - } + varDecl->type.type = m_astBuilder->getConstantBufferType(varDecl->type); } } - // Propagate type tags. - if (auto parentAggTypeDecl = as(getParentDecl(varDecl))) + if (getModuleDecl(varDecl)->hasModifier()) { - if (auto varDeclRefType = as(varDecl->type.type)) + // If we are in GLSL compatiblity mode, we want to treat all global variables + // without any `uniform` modifiers as true global variables by default. + if (!varDecl->findModifier() && + !varDecl->findModifier() && !varDecl->findModifier() && + !varDecl->findModifier()) { - parentAggTypeDecl->unionTagsWith(getTypeTags(varDeclRefType)); + if (!isUniformParameterType(varDecl->type)) + { + auto staticModifier = m_astBuilder->create(); + addModifier(varDecl, staticModifier); + } } } - if (getOptionSet().getBoolOption(CompilerOptionName::NoMangle) && - isGlobalDecl(varDecl)) - { - // If -no-mangle option is set, we will add `ExternCpp` modifier to all - // global variables and struct fields to prevent mangling. - addModifier(varDecl, m_astBuilder->create()); - } - checkVisibility(varDecl); - } - - static ConstructorDecl* _createCtor(SemanticsDeclVisitorBase* visitor, ASTBuilder* m_astBuilder, AggTypeDecl* decl) - { - auto ctor = m_astBuilder->create(); - addModifier(ctor, m_astBuilder->create()); - auto ctorName = visitor->getName("$init"); - ctor->ownedScope = m_astBuilder->create(); - ctor->ownedScope->containerDecl = ctor; - ctor->ownedScope->parent = visitor->getScope(decl); - ctor->parentDecl = decl; - ctor->loc = decl->loc; - ctor->closingSourceLoc = ctor->loc; - ctor->nameAndLoc.name = ctorName; - ctor->nameAndLoc.loc = ctor->loc; - ctor->returnType.type = visitor->calcThisType(makeDeclRef(decl)); - auto body = m_astBuilder->create(); - body->scopeDecl = m_astBuilder->create(); - body->scopeDecl->ownedScope = m_astBuilder->create(); - body->scopeDecl->ownedScope->parent = visitor->getScope(ctor); - body->scopeDecl->parentDecl = ctor; - body->scopeDecl->loc = ctor->loc; - body->scopeDecl->closingSourceLoc = ctor->loc; - body->closingSourceLoc = ctor->closingSourceLoc; - ctor->body = body; - body->body = m_astBuilder->create(); - ctor->isSynthesized = true; - decl->addMember(ctor); - return ctor; } - static ConstructorDecl* _getDefaultCtor(StructDecl* structDecl) + // Propagate type tags. + if (auto parentAggTypeDecl = as(getParentDecl(varDecl))) { - for (auto ctor : structDecl->getMembersOfType()) + if (auto varDeclRefType = as(varDecl->type.type)) { - if (!ctor->body || ctor->members.getCount() != 0) - continue; - return ctor; + parentAggTypeDecl->unionTagsWith(getTypeTags(varDeclRefType)); } - return nullptr; } - - - static List _getCtorList(ASTBuilder* m_astBuilder, SemanticsVisitor* visitor, StructDecl* structDecl, ConstructorDecl** defaultCtorOut) + if (getOptionSet().getBoolOption(CompilerOptionName::NoMangle) && isGlobalDecl(varDecl)) { - List ctorList; + // If -no-mangle option is set, we will add `ExternCpp` modifier to all + // global variables and struct fields to prevent mangling. + addModifier(varDecl, m_astBuilder->create()); + } + checkVisibility(varDecl); +} - auto ctorLookupResult = lookUpMember( - m_astBuilder, - visitor, - visitor->getName("$init"), - DeclRefType::create(m_astBuilder, structDecl), - structDecl->ownedScope, - LookupMask::Function, - (LookupOptions)((Index)LookupOptions::IgnoreInheritance | (Index)LookupOptions::NoDeref)); +static ConstructorDecl* _createCtor( + SemanticsDeclVisitorBase* visitor, + ASTBuilder* m_astBuilder, + AggTypeDecl* decl) +{ + auto ctor = m_astBuilder->create(); + addModifier(ctor, m_astBuilder->create()); + auto ctorName = visitor->getName("$init"); + ctor->ownedScope = m_astBuilder->create(); + ctor->ownedScope->containerDecl = ctor; + ctor->ownedScope->parent = visitor->getScope(decl); + ctor->parentDecl = decl; + ctor->loc = decl->loc; + ctor->closingSourceLoc = ctor->loc; + ctor->nameAndLoc.name = ctorName; + ctor->nameAndLoc.loc = ctor->loc; + ctor->returnType.type = visitor->calcThisType(makeDeclRef(decl)); + auto body = m_astBuilder->create(); + body->scopeDecl = m_astBuilder->create(); + body->scopeDecl->ownedScope = m_astBuilder->create(); + body->scopeDecl->ownedScope->parent = visitor->getScope(ctor); + body->scopeDecl->parentDecl = ctor; + body->scopeDecl->loc = ctor->loc; + body->scopeDecl->closingSourceLoc = ctor->loc; + body->closingSourceLoc = ctor->closingSourceLoc; + ctor->body = body; + body->body = m_astBuilder->create(); + ctor->isSynthesized = true; + decl->addMember(ctor); + return ctor; +} - if (!ctorLookupResult.isValid()) - return ctorList; +static ConstructorDecl* _getDefaultCtor(StructDecl* structDecl) +{ + for (auto ctor : structDecl->getMembersOfType()) + { + if (!ctor->body || ctor->members.getCount() != 0) + continue; + return ctor; + } + return nullptr; +} - auto lookupResultHandle = [&](LookupResultItem& item) - { - auto ctor = as(item.declRef.getDecl()); - if (!ctor || !ctor->body) - return; - ctorList.add(ctor); - if (ctor->members.getCount() != 0) - return; - *defaultCtorOut = ctor; - }; - if (ctorLookupResult.items.getCount() == 0) - { - lookupResultHandle(ctorLookupResult.item); - return ctorList; - } - for (auto m : ctorLookupResult.items) - { - lookupResultHandle(m); - } +static List _getCtorList( + ASTBuilder* m_astBuilder, + SemanticsVisitor* visitor, + StructDecl* structDecl, + ConstructorDecl** defaultCtorOut) +{ + List ctorList; + + auto ctorLookupResult = lookUpMember( + m_astBuilder, + visitor, + visitor->getName("$init"), + DeclRefType::create(m_astBuilder, structDecl), + structDecl->ownedScope, + LookupMask::Function, + (LookupOptions)((Index)LookupOptions::IgnoreInheritance | (Index)LookupOptions::NoDeref)); + + if (!ctorLookupResult.isValid()) + return ctorList; + auto lookupResultHandle = [&](LookupResultItem& item) + { + auto ctor = as(item.declRef.getDecl()); + if (!ctor || !ctor->body) + return; + ctorList.add(ctor); + if (ctor->members.getCount() != 0) + return; + *defaultCtorOut = ctor; + }; + if (ctorLookupResult.items.getCount() == 0) + { + lookupResultHandle(ctorLookupResult.item); return ctorList; } - void SemanticsDeclHeaderVisitor::visitStructDecl(StructDecl* structDecl) + for (auto m : ctorLookupResult.items) { - // As described above in `SemanticsDeclHeaderVisitor::checkVarDeclCommon`, - // we want to identify and tag the "magic" declarations that make NVAPI - // work, so that downstream passes can identify them and act accordingly. - // - // In this case, we are looking for the `NvShaderExtnStruct` type, which - // is used by `g_NvidiaExt`. - // - if( structDecl->getName() - && structDecl->getName()->text == "NvShaderExtnStruct" ) - { - addModifier(structDecl, m_astBuilder->create()); - } + lookupResultHandle(m); + } - if (structDecl->hasModifier()) - { - structDecl->addTag(TypeTag::Incomplete); - } + return ctorList; +} - // Slang supports a convenient syntax to create a wrapper type from - // an existing type that implements a given interface. For example, - // the user can write: struct FooWrapper:IFoo = Foo; - // In this case we will synthesize the FooWrapper type with an inner - // member of type `Foo`, and use it to implement all requirements of - // IFoo. - // If this is a wrapper struct, synthesize the inner member now. - if (structDecl->wrappedType.exp) - { - structDecl->wrappedType = CheckProperType(structDecl->wrappedType); - auto member = m_astBuilder->create(); - member->type = structDecl->wrappedType; - member->nameAndLoc.name = getName("inner"); - member->nameAndLoc.loc = structDecl->wrappedType.exp->loc; - member->loc = member->nameAndLoc.loc; - structDecl->addMember(member); - } - checkVisibility(structDecl); +void SemanticsDeclHeaderVisitor::visitStructDecl(StructDecl* structDecl) +{ + // As described above in `SemanticsDeclHeaderVisitor::checkVarDeclCommon`, + // we want to identify and tag the "magic" declarations that make NVAPI + // work, so that downstream passes can identify them and act accordingly. + // + // In this case, we are looking for the `NvShaderExtnStruct` type, which + // is used by `g_NvidiaExt`. + // + if (structDecl->getName() && structDecl->getName()->text == "NvShaderExtnStruct") + { + addModifier(structDecl, m_astBuilder->create()); } - void SemanticsDeclHeaderVisitor::visitClassDecl(ClassDecl* classDecl) + if (structDecl->hasModifier()) { - if (classDecl->hasModifier()) - { - classDecl->addTag(TypeTag::Incomplete); - } - checkVisibility(classDecl); + structDecl->addTag(TypeTag::Incomplete); } - bool DiagnoseIsAllowedInitExpr(VarDeclBase* varDecl, DiagnosticSink* sink) + // Slang supports a convenient syntax to create a wrapper type from + // an existing type that implements a given interface. For example, + // the user can write: struct FooWrapper:IFoo = Foo; + // In this case we will synthesize the FooWrapper type with an inner + // member of type `Foo`, and use it to implement all requirements of + // IFoo. + // If this is a wrapper struct, synthesize the inner member now. + if (structDecl->wrappedType.exp) { - // find groupshared modifier - if (varDecl->findModifier()) - { - if (sink && varDecl->initExpr) - sink->diagnose(varDecl, Diagnostics::cannotHaveInitializer, varDecl, "groupshared"); - return false; - } + structDecl->wrappedType = CheckProperType(structDecl->wrappedType); + auto member = m_astBuilder->create(); + member->type = structDecl->wrappedType; + member->nameAndLoc.name = getName("inner"); + member->nameAndLoc.loc = structDecl->wrappedType.exp->loc; + member->loc = member->nameAndLoc.loc; + structDecl->addMember(member); + } + checkVisibility(structDecl); +} - return true; +void SemanticsDeclHeaderVisitor::visitClassDecl(ClassDecl* classDecl) +{ + if (classDecl->hasModifier()) + { + classDecl->addTag(TypeTag::Incomplete); } + checkVisibility(classDecl); +} - bool isDefaultInitializable(VarDeclBase* varDecl) +bool DiagnoseIsAllowedInitExpr(VarDeclBase* varDecl, DiagnosticSink* sink) +{ + // find groupshared modifier + if (varDecl->findModifier()) { - if (!DiagnoseIsAllowedInitExpr(varDecl, nullptr)) - return false; + if (sink && varDecl->initExpr) + sink->diagnose(varDecl, Diagnostics::cannotHaveInitializer, varDecl, "groupshared"); + return false; + } + + return true; +} - // Find struct and modifiers associated with varDecl - StructDecl* structDecl = as(varDecl); - if (auto declRefType = as(varDecl->getType())) +bool isDefaultInitializable(VarDeclBase* varDecl) +{ + if (!DiagnoseIsAllowedInitExpr(varDecl, nullptr)) + return false; + + // Find struct and modifiers associated with varDecl + StructDecl* structDecl = as(varDecl); + if (auto declRefType = as(varDecl->getType())) + { + if (auto genericAppRefDecl = as(declRefType->getDeclRefBase())) { - if (auto genericAppRefDecl = as(declRefType->getDeclRefBase())) + auto baseGenericRefType = genericAppRefDecl->getBase()->getDecl(); + if (auto baseTypeStruct = as(baseGenericRefType)) { - auto baseGenericRefType = genericAppRefDecl->getBase()->getDecl(); - if (auto baseTypeStruct = as(baseGenericRefType)) - { - structDecl = baseTypeStruct; - } - else if (auto genericDecl = as(baseGenericRefType)) - { - if(auto innerTypeStruct = as(genericDecl->inner)) - structDecl = innerTypeStruct; - } + structDecl = baseTypeStruct; + } + else if (auto genericDecl = as(baseGenericRefType)) + { + if (auto innerTypeStruct = as(genericDecl->inner)) + structDecl = innerTypeStruct; } } - if (structDecl) + } + if (structDecl) + { + // find if a type is non-copyable + if (structDecl->findModifier()) + return false; + } + + return true; +} + +static Expr* constructDefaultInitExprForVar(SemanticsVisitor* visitor, VarDeclBase* varDecl) +{ + if (!varDecl->type || !varDecl->type.type) + return nullptr; + + if (!isDefaultInitializable(varDecl)) + return nullptr; + + ConstructorDecl* defaultCtor = nullptr; + auto declRefType = as(varDecl->type.type); + if (declRefType) + { + if (auto structDecl = as(declRefType->getDeclRef().getDecl())) { - // find if a type is non-copyable - if (structDecl->findModifier()) - return false; + defaultCtor = _getDefaultCtor(structDecl); } - - return true; } - static Expr* constructDefaultInitExprForVar(SemanticsVisitor* visitor, VarDeclBase* varDecl) + if (defaultCtor) { - if (!varDecl->type || !varDecl->type.type) - return nullptr; - - if (!isDefaultInitializable(varDecl)) - return nullptr; + auto* invoke = visitor->getASTBuilder()->create(); + auto member = + visitor->getASTBuilder()->getMemberDeclRef(declRefType->getDeclRef(), defaultCtor); + invoke->functionExpr = visitor->ConstructDeclRefExpr( + member, + nullptr, + defaultCtor->getName(), + defaultCtor->loc, + nullptr); + return invoke; + } + else + { + auto* defaultCall = visitor->getASTBuilder()->create(); + defaultCall->type = QualType(varDecl->type); + return defaultCall; + } +} - ConstructorDecl* defaultCtor = nullptr; - auto declRefType = as(varDecl->type.type); - if (declRefType) - { - if (auto structDecl = as(declRefType->getDeclRef().getDecl())) - { - defaultCtor = _getDefaultCtor(structDecl); - } - } +void SemanticsDeclBodyVisitor::checkVarDeclCommon(VarDeclBase* varDecl) +{ + DiagnoseIsAllowedInitExpr(varDecl, getSink()); - if (defaultCtor) - { - auto* invoke = visitor->getASTBuilder()->create(); - auto member = visitor->getASTBuilder()->getMemberDeclRef(declRefType->getDeclRef(), defaultCtor); - invoke->functionExpr = visitor->ConstructDeclRefExpr(member, nullptr, defaultCtor->getName(), defaultCtor->loc, nullptr); - return invoke; - } - else - { - auto* defaultCall = visitor->getASTBuilder()->create(); - defaultCall->type = QualType(varDecl->type); - return defaultCall; - } + // if zero initialize is true, set everything to a default + if (getOptionSet().hasOption(CompilerOptionName::ZeroInitialize) && !varDecl->initExpr && + as(varDecl)) + { + varDecl->initExpr = constructDefaultInitExprForVar(this, varDecl); } - void SemanticsDeclBodyVisitor::checkVarDeclCommon(VarDeclBase* varDecl) + if (auto initExpr = varDecl->initExpr) { - DiagnoseIsAllowedInitExpr(varDecl, getSink()); + // Disable the short-circuiting for static const variable init expression + bool isStaticConst = + varDecl->hasModifier() && varDecl->hasModifier(); - // if zero initialize is true, set everything to a default - if (getOptionSet().hasOption(CompilerOptionName::ZeroInitialize) - && !varDecl->initExpr - && as(varDecl) - ) - { - varDecl->initExpr = constructDefaultInitExprForVar(this, varDecl); - } - - if (auto initExpr = varDecl->initExpr) - { - // Disable the short-circuiting for static const variable init expression - bool isStaticConst = varDecl->hasModifier() && - varDecl->hasModifier(); + auto subVisitor = + isStaticConst ? SemanticsVisitor(disableShortCircuitLogicalExpr()) : *this; + // If the variable has an explicit initial-value expression, + // then we simply need to check that expression and coerce + // it to the type of the variable. + // + initExpr = subVisitor.CheckTerm(initExpr); - auto subVisitor = isStaticConst? - SemanticsVisitor(disableShortCircuitLogicalExpr()) : *this; - // If the variable has an explicit initial-value expression, - // then we simply need to check that expression and coerce - // it to the type of the variable. - // - initExpr = subVisitor.CheckTerm(initExpr); + if (initExpr->type.isWriteOnly) + getSink()->diagnose(initExpr, Diagnostics::readingFromWriteOnly); + initExpr = coerce(CoercionSite::Initializer, varDecl->type.Ptr(), initExpr); + varDecl->initExpr = initExpr; - if (initExpr->type.isWriteOnly) - getSink()->diagnose(initExpr, Diagnostics::readingFromWriteOnly); - initExpr = coerce(CoercionSite::Initializer, varDecl->type.Ptr(), initExpr); - varDecl->initExpr = initExpr; + // We need to ensure that any variable doesn't introduce + // a constant with a circular definition. + // + varDecl->setCheckState(DeclCheckState::DefinitionChecked); + _validateCircularVarDefinition(varDecl); + } + else + { + // If a variable doesn't have an explicit initial-value + // expression, it is still possible that it should + // be initialized implicitly, because the type of the + // variable has a default (zero parameter) initializer. + // That is, for types where it is possible, we will + // treat a variable declared like this: + // + // MyType myVar; + // + // as if it were declared as: + // + // MyType myVar = MyType(); + // + // Rather than try to code up an ad hoc search for an + // appropriate initializer here, we will instead fall + // back on the general-purpose overload-resolution + // machinery, which can handle looking up initializers + // and filtering them to ones that are applicable + // to our "call site" with zero arguments. + // + OverloadResolveContext overloadContext; + overloadContext.loc = varDecl->nameAndLoc.loc; + overloadContext.mode = OverloadResolveContext::Mode::JustTrying; + overloadContext.sourceScope = m_outerScope; - // We need to ensure that any variable doesn't introduce - // a constant with a circular definition. - // - varDecl->setCheckState(DeclCheckState::DefinitionChecked); - _validateCircularVarDefinition(varDecl); + auto type = varDecl->getType(); + ImplicitCastMethodKey key = ImplicitCastMethodKey(QualType(), type, nullptr); + auto ctorMethod = getShared()->tryGetImplicitCastMethod(key); + if (ctorMethod) + { + overloadContext.bestCandidateStorage = ctorMethod->conversionFuncOverloadCandidate; + overloadContext.bestCandidate = &overloadContext.bestCandidateStorage; } else { - // If a variable doesn't have an explicit initial-value - // expression, it is still possible that it should - // be initialized implicitly, because the type of the - // variable has a default (zero parameter) initializer. - // That is, for types where it is possible, we will - // treat a variable declared like this: - // - // MyType myVar; - // - // as if it were declared as: - // - // MyType myVar = MyType(); + AddTypeOverloadCandidates(type, overloadContext); + } + + if (overloadContext.bestCandidates.getCount() != 0) + { + // If there were multiple equally-good candidates to call, + // then might have an ambiguity. // - // Rather than try to code up an ad hoc search for an - // appropriate initializer here, we will instead fall - // back on the general-purpose overload-resolution - // machinery, which can handle looking up initializers - // and filtering them to ones that are applicable - // to our "call site" with zero arguments. + // Before issuing any kind of diagnostic we need to check + // if any of those candidates are actually applicable, + // because if they aren't then we actually just have + // an uninitialized varaible. // - OverloadResolveContext overloadContext; - overloadContext.loc = varDecl->nameAndLoc.loc; - overloadContext.mode = OverloadResolveContext::Mode::JustTrying; - overloadContext.sourceScope = m_outerScope; - - auto type = varDecl->getType(); - ImplicitCastMethodKey key = ImplicitCastMethodKey(QualType(), type, nullptr); - auto ctorMethod = getShared()->tryGetImplicitCastMethod(key); - if (ctorMethod) + if (overloadContext.bestCandidates[0].status != OverloadCandidate::Status::Applicable) { - overloadContext.bestCandidateStorage = ctorMethod->conversionFuncOverloadCandidate; - overloadContext.bestCandidate = &overloadContext.bestCandidateStorage; + getShared()->cacheImplicitCastMethod(key, ImplicitCastMethod{}); } else { - AddTypeOverloadCandidates(type, overloadContext); + getSink()->diagnose(varDecl, Diagnostics::ambiguousDefaultInitializerForType, type); } - - if(overloadContext.bestCandidates.getCount() != 0) + } + else if (overloadContext.bestCandidate) + { + // If we are in the single-candidate case, then we again + // want to ignore the case where that candidate wasn't + // actually applicable, because declaring a variable + // of a type that *doesn't* have a default initializer + // isn't actually an error. + // + if (overloadContext.bestCandidate->status != OverloadCandidate::Status::Applicable) { - // If there were multiple equally-good candidates to call, - // then might have an ambiguity. - // - // Before issuing any kind of diagnostic we need to check - // if any of those candidates are actually applicable, - // because if they aren't then we actually just have - // an uninitialized varaible. - // - if (overloadContext.bestCandidates[0].status != OverloadCandidate::Status::Applicable) - { - getShared()->cacheImplicitCastMethod(key, ImplicitCastMethod{}); - } - else - { - getSink()->diagnose(varDecl, Diagnostics::ambiguousDefaultInitializerForType, type); - } + getShared()->cacheImplicitCastMethod(key, ImplicitCastMethod{}); } - else if(overloadContext.bestCandidate) + else { - // If we are in the single-candidate case, then we again - // want to ignore the case where that candidate wasn't - // actually applicable, because declaring a variable - // of a type that *doesn't* have a default initializer - // isn't actually an error. + // If we had a single best candidate *and* it was applicable, + // then we use it to construct a new initial-value expression + // for the variable, that will be used for all downstream + // code generation. // - if (overloadContext.bestCandidate->status != OverloadCandidate::Status::Applicable) - { - getShared()->cacheImplicitCastMethod(key, ImplicitCastMethod{}); - } - else - { - // If we had a single best candidate *and* it was applicable, - // then we use it to construct a new initial-value expression - // for the variable, that will be used for all downstream - // code generation. - // - varDecl->initExpr = CompleteOverloadCandidate(overloadContext, *overloadContext.bestCandidate); - getShared()->cacheImplicitCastMethod(key, ImplicitCastMethod{*overloadContext.bestCandidate, 0}); - } + varDecl->initExpr = + CompleteOverloadCandidate(overloadContext, *overloadContext.bestCandidate); + getShared()->cacheImplicitCastMethod( + key, + ImplicitCastMethod{*overloadContext.bestCandidate, 0}); } } + } - TypeTag varTypeTags = getTypeTags(varDecl->getType()); - auto parentDecl = as(getParentDecl(varDecl)); - if (parentDecl) + TypeTag varTypeTags = getTypeTags(varDecl->getType()); + auto parentDecl = as(getParentDecl(varDecl)); + if (parentDecl) + { + parentDecl->addTag(varTypeTags); + auto unsizedMask = (int)TypeTag::Unsized; + bool isUnknownSize = (((int)varTypeTags & unsizedMask) != 0); + if (isUnknownSize) { - parentDecl->addTag(varTypeTags); - auto unsizedMask = (int)TypeTag::Unsized; - bool isUnknownSize = (((int)varTypeTags & unsizedMask) != 0); - if (isUnknownSize) + // Unsized decl must appear as the last member of the struct. + for (auto memberIdx = parentDecl->members.getCount() - 1; memberIdx >= 0; memberIdx--) { - // Unsized decl must appear as the last member of the struct. - for (auto memberIdx = parentDecl->members.getCount() - 1; memberIdx >= 0; memberIdx--) + if (parentDecl->members[memberIdx] == varDecl) { - if (parentDecl->members[memberIdx] == varDecl) - { - break; - } - if (auto memberVarDecl = as(parentDecl->members[memberIdx])) + break; + } + if (auto memberVarDecl = as(parentDecl->members[memberIdx])) + { + if (!memberVarDecl->hasModifier()) { - if (!memberVarDecl->hasModifier()) - { - getSink()->diagnose(varDecl, Diagnostics::unsizedMemberMustAppearLast); - } - break; + getSink()->diagnose(varDecl, Diagnostics::unsizedMemberMustAppearLast); } + break; } } } - bool isGlobalOrLocalVar = !isGlobalShaderParameter(varDecl) && !as(varDecl) && - (!parentDecl || isEffectivelyStatic(varDecl)); - if (isGlobalOrLocalVar) + } + bool isGlobalOrLocalVar = !isGlobalShaderParameter(varDecl) && !as(varDecl) && + (!parentDecl || isEffectivelyStatic(varDecl)); + if (isGlobalOrLocalVar) + { + bool isUnsized = (((int)varTypeTags & (int)TypeTag::Unsized) != 0); + if (isUnsized) { - bool isUnsized = (((int)varTypeTags & (int)TypeTag::Unsized) != 0); - if (isUnsized) - { - getSink()->diagnose(varDecl, Diagnostics::varCannotBeUnsized); - } + getSink()->diagnose(varDecl, Diagnostics::varCannotBeUnsized); } + } - if (auto elementType = getConstantBufferElementType(varDecl->getType())) + if (auto elementType = getConstantBufferElementType(varDecl->getType())) + { + if (doesTypeHaveTag(elementType, TypeTag::Incomplete)) { - if (doesTypeHaveTag(elementType, TypeTag::Incomplete)) - { - getSink()->diagnose(varDecl->type.exp->loc, Diagnostics::incompleteTypeCannotBeUsedInBuffer, elementType); - } - if (doesTypeHaveTag(elementType, TypeTag::Unsized)) - { - // If the element type is unsized, it can only be an array of resource types that we can legalize out. - // Ordinary unsized arrays are not allowed in a constant buffer since we cannot translate it to - // valid HLSL or SPIRV. - ArrayExpressionType* trailingArrayType = nullptr; - VarDeclBase* trailingArrayField = getTrailingUnsizedArrayElement(elementType, varDecl, trailingArrayType); - if (trailingArrayField && !isOpaqueHandleType(trailingArrayType->getElementType())) - { - getSink()->diagnose(trailingArrayField->loc, Diagnostics::cannotUseUnsizedTypeInConstantBuffer, trailingArrayType); - getSink()->diagnose(varDecl->loc, Diagnostics::seeConstantBufferDefinition); - } - } + getSink()->diagnose( + varDecl->type.exp->loc, + Diagnostics::incompleteTypeCannotBeUsedInBuffer, + elementType); } - else if (varDecl->findModifier()) + if (doesTypeHaveTag(elementType, TypeTag::Unsized)) { - auto varType = varDecl->getType(); - if (doesTypeHaveTag(varType, TypeTag::Incomplete)) + // If the element type is unsized, it can only be an array of resource types that we can + // legalize out. Ordinary unsized arrays are not allowed in a constant buffer since we + // cannot translate it to valid HLSL or SPIRV. + ArrayExpressionType* trailingArrayType = nullptr; + VarDeclBase* trailingArrayField = + getTrailingUnsizedArrayElement(elementType, varDecl, trailingArrayType); + if (trailingArrayField && !isOpaqueHandleType(trailingArrayType->getElementType())) { - getSink()->diagnose(varDecl->type.exp->loc, Diagnostics::incompleteTypeCannotBeUsedInUniformParameter, varType); + getSink()->diagnose( + trailingArrayField->loc, + Diagnostics::cannotUseUnsizedTypeInConstantBuffer, + trailingArrayType); + getSink()->diagnose(varDecl->loc, Diagnostics::seeConstantBufferDefinition); } } - maybeRegisterDifferentiableType(getASTBuilder(), varDecl->getType()); } - - // Fill in default substitutions for the 'subtype' part of a type constraint decl - void SemanticsVisitor::CheckConstraintSubType(TypeExp& typeExp) + else if (varDecl->findModifier()) { - if (auto sharedTypeExpr = as(typeExp.exp)) + auto varType = varDecl->getType(); + if (doesTypeHaveTag(varType, TypeTag::Incomplete)) { - if (auto declRefType = as(sharedTypeExpr->base)) - { - auto newDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, declRefType->getDeclRef()); - auto newType = DeclRefType::create(m_astBuilder, newDeclRef); - sharedTypeExpr->base.type = newType; - if (as(typeExp.exp->type)) - typeExp.exp->type = m_astBuilder->getTypeType(newType); - } + getSink()->diagnose( + varDecl->type.exp->loc, + Diagnostics::incompleteTypeCannotBeUsedInUniformParameter, + varType); } } + maybeRegisterDifferentiableType(getASTBuilder(), varDecl->getType()); +} - void addVisibilityModifier(ASTBuilder* builder, Decl* decl, DeclVisibility vis) +// Fill in default substitutions for the 'subtype' part of a type constraint decl +void SemanticsVisitor::CheckConstraintSubType(TypeExp& typeExp) +{ + if (auto sharedTypeExpr = as(typeExp.exp)) { - switch (vis) + if (auto declRefType = as(sharedTypeExpr->base)) { - case DeclVisibility::Public: - addModifier(decl, builder->create()); - break; - case DeclVisibility::Internal: - addModifier(decl, builder->create()); - break; - case DeclVisibility::Private: - addModifier(decl, builder->create()); - break; - default: - break; + auto newDeclRef = + createDefaultSubstitutionsIfNeeded(m_astBuilder, this, declRefType->getDeclRef()); + auto newType = DeclRefType::create(m_astBuilder, newDeclRef); + sharedTypeExpr->base.type = newType; + if (as(typeExp.exp->type)) + typeExp.exp->type = m_astBuilder->getTypeType(newType); } } +} - bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness( - ConformanceCheckingContext* context, - DeclRef requirementDeclRef, - RefPtr witnessTable) +void addVisibilityModifier(ASTBuilder* builder, Decl* decl, DeclVisibility vis) +{ + switch (vis) { - ASTSynthesizer synth(m_astBuilder, getNamePool()); - Decl* existingDecl = nullptr; - AggTypeDecl* aggTypeDecl = nullptr; - if (context->parentDecl->getMemberDictionary().tryGetValue(requirementDeclRef.getName(), existingDecl)) - { - // Remove the `ToBeSynthesizedModifier`. - if (as(existingDecl->modifiers.first)) - { - existingDecl->modifiers.first = existingDecl->modifiers.first->next; - } - else - { - // The user has defined an associatedtype explicitly but that we reach here because - // that type failed to satisfy the `IDifferential` requirement. - // We stop the synthesis and let the follow-up logic to report a diagnostic. - return false; - } + case DeclVisibility::Public: addModifier(decl, builder->create()); break; + case DeclVisibility::Internal: addModifier(decl, builder->create()); break; + case DeclVisibility::Private: addModifier(decl, builder->create()); break; + default: break; + } +} - aggTypeDecl = as(existingDecl); - SLANG_RELEASE_ASSERT(aggTypeDecl); - synth.pushContainerScope(aggTypeDecl); +bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness( + ConformanceCheckingContext* context, + DeclRef requirementDeclRef, + RefPtr witnessTable) +{ + ASTSynthesizer synth(m_astBuilder, getNamePool()); + Decl* existingDecl = nullptr; + AggTypeDecl* aggTypeDecl = nullptr; + if (context->parentDecl->getMemberDictionary().tryGetValue( + requirementDeclRef.getName(), + existingDecl)) + { + // Remove the `ToBeSynthesizedModifier`. + if (as(existingDecl->modifiers.first)) + { + existingDecl->modifiers.first = existingDecl->modifiers.first->next; } - - // If we did not find an existing empty struct, we may need to synthesize one. - // But first, we check if the parent type can be used as its own differential type. - // - if (!aggTypeDecl - && as(context->parentDecl) - && canStructBeUsedAsSelfDifferentialType(as(context->parentDecl))) + else { - // If the parent type can be used as its own differential type, we will create a typealias - // to itself as the differential type. - // - auto assocTypeDef = m_astBuilder->create(); - assocTypeDef->nameAndLoc.name = getName("Differential"); - assocTypeDef->type.type = context->conformingType; - assocTypeDef->parentDecl = context->parentDecl; - assocTypeDef->setCheckState(DeclCheckState::DefinitionChecked); - context->parentDecl->members.add(assocTypeDef); - - markSelfDifferentialMembersOfType(as(context->parentDecl), context->conformingType); - - witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(context->conformingType)); - if (doesTypeSatisfyAssociatedTypeConstraintRequirement(context->conformingType, requirementDeclRef, witnessTable)) - { - - // Increase the epoch so that future calls to Type::getCanonicalType will return the up-to-date folded types. - m_astBuilder->incrementEpoch(); - return true; - } - else - { - witnessTable->m_requirementDictionary.remove(requirementDeclRef.getDecl()); - } - - // Something went wrong. + // The user has defined an associatedtype explicitly but that we reach here because + // that type failed to satisfy the `IDifferential` requirement. + // We stop the synthesis and let the follow-up logic to report a diagnostic. return false; } - if (!aggTypeDecl) - { - aggTypeDecl = m_astBuilder->create(); - aggTypeDecl->parentDecl = context->parentDecl; - context->parentDecl->members.add((aggTypeDecl)); - aggTypeDecl->nameAndLoc.name = requirementDeclRef.getName(); - aggTypeDecl->loc = context->parentDecl->nameAndLoc.loc; - context->parentDecl->invalidateMemberDictionary(); - synth.pushScopeForContainer(aggTypeDecl); - } + aggTypeDecl = as(existingDecl); + SLANG_RELEASE_ASSERT(aggTypeDecl); + synth.pushContainerScope(aggTypeDecl); + } - // If `This` is nested inside a generic, we need to form a complete declref type to the - // newly synthesized aggTypeDecl here. This can be done by obtaining the this type witness - // from requirementDeclRef to get the generic arguments for the outer generic, and - // apply it to the newly synthesized decl. - SubstitutionSet substSet; - Type* thisType = nullptr; - if (auto thisWitness = findThisTypeWitness( - SubstitutionSet(requirementDeclRef), - as(requirementDeclRef.getParent()).getDecl())) - { - thisType = thisWitness->getSub(); - if (auto declRefType = as(thisType)) - { - substSet = SubstitutionSet(declRefType->getDeclRef()); - } - } - if (!substSet.declRef) - return false; - Type* satisfyingType = nullptr; - if (substSet.declRef->getDecl() == context->parentDecl) - { - // The type we are synthesizing conformance for is direct inside a type itself. - // We need to copy the outer generic arguments to the synthesized type. - satisfyingType = DeclRefType::create(m_astBuilder, m_astBuilder->getMemberDeclRef(substSet.declRef, aggTypeDecl)); - } - else if (auto parentExtDecl = as(context->parentDecl)) - { - // The type is defined in an extension, we need to form a declref to the parent - // extension from the requirementDeclRef. - auto extDeclRef = applyExtensionToType(parentExtDecl, thisType); - satisfyingType = DeclRefType::create(m_astBuilder, m_astBuilder->getMemberDeclRef(extDeclRef, aggTypeDecl)); - } - if (!satisfyingType) - return false; - - // Helper function to add a `diffType` field into the synthesized type for the original - // `member`. - auto differentialType = DeclRefType::create(m_astBuilder, DeclRef(makeDeclRef(aggTypeDecl))); - auto addDiffMember = [&](Decl* member, Type* diffMemberType) - { - // If the field is differentiable, add a corresponding field in the associated Differential type. - auto diffField = m_astBuilder->create(); - diffField->nameAndLoc = member->nameAndLoc; - diffField->type.type = diffMemberType; - diffField->checkState = DeclCheckState::SignatureChecked; - diffField->parentDecl = aggTypeDecl; - aggTypeDecl->members.add(diffField); - - auto visibility = getDeclVisibility(member); - addVisibilityModifier(m_astBuilder, diffField, visibility); - - aggTypeDecl->invalidateMemberDictionary(); - - // Inject a `DerivativeMember` modifier on the differential field to point to itself. - { - auto derivativeMemberModifier = m_astBuilder->create(); - auto fieldLookupExpr = m_astBuilder->create(); - fieldLookupExpr->type.type = diffMemberType; - auto baseTypeExpr = m_astBuilder->create(); - baseTypeExpr->base.type = differentialType; - auto baseTypeType = m_astBuilder->getOrCreate(differentialType); - baseTypeExpr->type.type = baseTypeType; - fieldLookupExpr->baseExpression = baseTypeExpr; - fieldLookupExpr->declRef = makeDeclRef(diffField); - derivativeMemberModifier->memberDeclRef = fieldLookupExpr; - addModifier(diffField, derivativeMemberModifier); - } + // If we did not find an existing empty struct, we may need to synthesize one. + // But first, we check if the parent type can be used as its own differential type. + // + if (!aggTypeDecl && as(context->parentDecl) && + canStructBeUsedAsSelfDifferentialType(as(context->parentDecl))) + { + // If the parent type can be used as its own differential type, we will create a typealias + // to itself as the differential type. + // + auto assocTypeDef = m_astBuilder->create(); + assocTypeDef->nameAndLoc.name = getName("Differential"); + assocTypeDef->type.type = context->conformingType; + assocTypeDef->parentDecl = context->parentDecl; + assocTypeDef->setCheckState(DeclCheckState::DefinitionChecked); + context->parentDecl->members.add(assocTypeDef); - // Inject a `DerivativeMember` modifier on the original decl. - { - auto derivativeMemberModifier = m_astBuilder->create(); - auto fieldLookupExpr = m_astBuilder->create(); - fieldLookupExpr->type.type = diffMemberType; - auto baseTypeExpr = m_astBuilder->create(); - baseTypeExpr->base.type = differentialType; - auto baseTypeType = m_astBuilder->getOrCreate(differentialType); - baseTypeExpr->type.type = baseTypeType; - fieldLookupExpr->baseExpression = baseTypeExpr; - fieldLookupExpr->declRef = makeDeclRef(diffField); - derivativeMemberModifier->memberDeclRef = fieldLookupExpr; - addModifier(member, derivativeMemberModifier); - } - }; + markSelfDifferentialMembersOfType( + as(context->parentDecl), + context->conformingType); - // Make the Differential type itself conform to `IDifferential` interface. - bool hasDifferentialConformance = false; - for (auto inheritanceDecl : aggTypeDecl->getMembersOfType()) - { - if (auto declRefType = as(inheritanceDecl->base.type)) - { - if (declRefType->getDeclRef() == m_astBuilder->getDifferentiableInterfaceDecl()) - { - hasDifferentialConformance = true; - break; - } - } - } - if (!hasDifferentialConformance) + witnessTable->add( + requirementDeclRef.getDecl(), + RequirementWitness(context->conformingType)); + if (doesTypeSatisfyAssociatedTypeConstraintRequirement( + context->conformingType, + requirementDeclRef, + witnessTable)) { - auto inheritanceIDiffernetiable = m_astBuilder->create(); - inheritanceIDiffernetiable->base.type = m_astBuilder->getDiffInterfaceType(); - inheritanceIDiffernetiable->parentDecl = aggTypeDecl; - aggTypeDecl->members.add(inheritanceIDiffernetiable); - } - // The `Differential` type of a `Differential` type is always itself. - bool hasDifferentialTypeDef = false; - for (auto member : aggTypeDecl->members) - { - if (auto name = member->getName()) - { - if (name->text == "Differential") - { - hasDifferentialTypeDef = true; - break; - } - } + // Increase the epoch so that future calls to Type::getCanonicalType will return the + // up-to-date folded types. + m_astBuilder->incrementEpoch(); + return true; } - if (!hasDifferentialTypeDef) + else { - auto assocTypeDef = m_astBuilder->create(); - assocTypeDef->nameAndLoc.name = getName("Differential"); - assocTypeDef->type.type = satisfyingType; - assocTypeDef->parentDecl = aggTypeDecl; - assocTypeDef->setCheckState(DeclCheckState::DefinitionChecked); - aggTypeDecl->members.add(assocTypeDef); + witnessTable->m_requirementDictionary.remove(requirementDeclRef.getDecl()); } - // Go through all members and collect their differential types. - // Go through super types. - for (auto inheritance : context->parentDecl->getMembersOfType()) - { - if (auto baseDeclRefType = as(inheritance->base.type)) - { - // Skip interface super types. - if (baseDeclRefType->getDeclRef().as()) - continue; - if (auto superDiffType = tryGetDifferentialType(m_astBuilder, baseDeclRefType)) - { - addDiffMember(inheritance, superDiffType); - } - } - } - // Go through all var members. - for (auto member : context->parentDecl->getMembersOfType()) - { - if (member->hasModifier()) - continue; - auto diffType = tryGetDifferentialType(m_astBuilder, member->type.type); - if (!diffType) - continue; - addDiffMember(member, diffType); - } + // Something went wrong. + return false; + } - addModifier(aggTypeDecl, m_astBuilder->create()); + if (!aggTypeDecl) + { + aggTypeDecl = m_astBuilder->create(); + aggTypeDecl->parentDecl = context->parentDecl; + context->parentDecl->members.add((aggTypeDecl)); + aggTypeDecl->nameAndLoc.name = requirementDeclRef.getName(); + aggTypeDecl->loc = context->parentDecl->nameAndLoc.loc; + context->parentDecl->invalidateMemberDictionary(); + synth.pushScopeForContainer(aggTypeDecl); + } - // The visibility of synthesized decl should be the min of the parent decl and the requirement. - if (requirementDeclRef.getDecl()->findModifier()) + // If `This` is nested inside a generic, we need to form a complete declref type to the + // newly synthesized aggTypeDecl here. This can be done by obtaining the this type witness + // from requirementDeclRef to get the generic arguments for the outer generic, and + // apply it to the newly synthesized decl. + SubstitutionSet substSet; + Type* thisType = nullptr; + if (auto thisWitness = findThisTypeWitness( + SubstitutionSet(requirementDeclRef), + as(requirementDeclRef.getParent()).getDecl())) + { + thisType = thisWitness->getSub(); + if (auto declRefType = as(thisType)) { - auto requirementVisibility = getDeclVisibility(requirementDeclRef.getDecl()); - auto thisVisibility = getDeclVisibility(context->parentDecl); - auto visibility = Math::Min(thisVisibility, requirementVisibility); - addVisibilityModifier(m_astBuilder, aggTypeDecl, visibility); + substSet = SubstitutionSet(declRefType->getDeclRef()); } + } + if (!substSet.declRef) + return false; + Type* satisfyingType = nullptr; + if (substSet.declRef->getDecl() == context->parentDecl) + { + // The type we are synthesizing conformance for is direct inside a type itself. + // We need to copy the outer generic arguments to the synthesized type. + satisfyingType = DeclRefType::create( + m_astBuilder, + m_astBuilder->getMemberDeclRef(substSet.declRef, aggTypeDecl)); + } + else if (auto parentExtDecl = as(context->parentDecl)) + { + // The type is defined in an extension, we need to form a declref to the parent + // extension from the requirementDeclRef. + auto extDeclRef = applyExtensionToType(parentExtDecl, thisType); + satisfyingType = DeclRefType::create( + m_astBuilder, + m_astBuilder->getMemberDeclRef(extDeclRef, aggTypeDecl)); + } + if (!satisfyingType) + return false; - // Synthesize the rest of IDifferential method conformances by recursively checking - // conformance on the synthesized decl. - checkAggTypeConformance(aggTypeDecl); - - witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(satisfyingType)); - if (!doesTypeSatisfyAssociatedTypeConstraintRequirement(satisfyingType, requirementDeclRef, witnessTable)) - { - // Note: the call to `doesTypeSatisfyAssociatedTypeConstraintRequirement` should always succeed. - // If not, there is something wrong with the code synthesis logic. For now we just return false - // instead of crashing so the user can work around the issues. - witnessTable->m_requirementDictionary.remove(requirementDeclRef.getDecl()); - return false; + // Helper function to add a `diffType` field into the synthesized type for the original + // `member`. + auto differentialType = + DeclRefType::create(m_astBuilder, DeclRef(makeDeclRef(aggTypeDecl))); + auto addDiffMember = [&](Decl* member, Type* diffMemberType) + { + // If the field is differentiable, add a corresponding field in the associated Differential + // type. + auto diffField = m_astBuilder->create(); + diffField->nameAndLoc = member->nameAndLoc; + diffField->type.type = diffMemberType; + diffField->checkState = DeclCheckState::SignatureChecked; + diffField->parentDecl = aggTypeDecl; + aggTypeDecl->members.add(diffField); + + auto visibility = getDeclVisibility(member); + addVisibilityModifier(m_astBuilder, diffField, visibility); + + aggTypeDecl->invalidateMemberDictionary(); + + // Inject a `DerivativeMember` modifier on the differential field to point to itself. + { + auto derivativeMemberModifier = m_astBuilder->create(); + auto fieldLookupExpr = m_astBuilder->create(); + fieldLookupExpr->type.type = diffMemberType; + auto baseTypeExpr = m_astBuilder->create(); + baseTypeExpr->base.type = differentialType; + auto baseTypeType = m_astBuilder->getOrCreate(differentialType); + baseTypeExpr->type.type = baseTypeType; + fieldLookupExpr->baseExpression = baseTypeExpr; + fieldLookupExpr->declRef = makeDeclRef(diffField); + derivativeMemberModifier->memberDeclRef = fieldLookupExpr; + addModifier(diffField, derivativeMemberModifier); + } + + // Inject a `DerivativeMember` modifier on the original decl. + { + auto derivativeMemberModifier = m_astBuilder->create(); + auto fieldLookupExpr = m_astBuilder->create(); + fieldLookupExpr->type.type = diffMemberType; + auto baseTypeExpr = m_astBuilder->create(); + baseTypeExpr->base.type = differentialType; + auto baseTypeType = m_astBuilder->getOrCreate(differentialType); + baseTypeExpr->type.type = baseTypeType; + fieldLookupExpr->baseExpression = baseTypeExpr; + fieldLookupExpr->declRef = makeDeclRef(diffField); + derivativeMemberModifier->memberDeclRef = fieldLookupExpr; + addModifier(member, derivativeMemberModifier); } - return true; - } + }; - void SemanticsDeclHeaderVisitor::validateGenericConstraintSubType(GenericTypeConstraintDecl* decl, TypeExp type) + // Make the Differential type itself conform to `IDifferential` interface. + bool hasDifferentialConformance = false; + for (auto inheritanceDecl : aggTypeDecl->getMembersOfType()) { - // Validate that the sub type of a constraint is in valid form. - // - if (auto subDeclRef = isDeclRefTypeOf(type.type)) + if (auto declRefType = as(inheritanceDecl->base.type)) { - if (subDeclRef.getDecl()->parentDecl == decl->parentDecl) - { - // OK, sub type is one of the generic parameter type. - return; - } - if (as(decl->parentDecl)) + if (declRefType->getDeclRef() == m_astBuilder->getDifferentiableInterfaceDecl()) { - // If the constraint is in a generic decl, then the sub type must be dependent on at least one - // of the generic type parameters defined in the same generic decl. - // For example, it is invalid to define a constraint like `void foo() where int : float` since - // `int` isn't dependent on any generic type parameter. - auto dependentGeneric = getShared()->getDependentGenericParent(subDeclRef); - if (dependentGeneric.getDecl() != decl->parentDecl) - { - getSink()->diagnose(type.exp, Diagnostics::invalidConstraintSubType, type); - return; - } - } - else if (as(decl->parentDecl)) - { - // If the constraint is on an associated type, then it should either be the associated type itself, - // or a associated type of the associated type. - // For example, - // ``` - // interface IFoo { - // associatedtype T - // where T : IFoo // OK, constraint is on the associatedtype T itself. - // where T.T == X // OK, constraint is on the associated type of T. - // where int == X; // Error, int is not a valid left hand side of a constraint. - // } - // ``` - auto lookupDeclRef = as(subDeclRef.declRefBase); - if (!lookupDeclRef) - { - getSink()->diagnose(type.exp, Diagnostics::invalidConstraintSubType, type); - return; - } - - // We allow `associatedtype T where This.T : ...`. - // In this case, the left hand side will be in the form of - // LookupDeclRef(ThisType, T). i.e. lookupDeclRef->getDecl() == T. - // - if (lookupDeclRef->getDecl()->parentDecl == decl->parentDecl || - lookupDeclRef->getDecl() == decl->parentDecl) - return; - auto baseType = as(lookupDeclRef->getLookupSource()); - if (!baseType) - { - getSink()->diagnose(type.exp, Diagnostics::invalidConstraintSubType, type); - return; - } - type.type = baseType; - validateGenericConstraintSubType(decl, type); + hasDifferentialConformance = true; + break; } } } - - void SemanticsDeclHeaderVisitor::visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl) + if (!hasDifferentialConformance) { - // TODO: are there any other validations we can do at this point? - // - // There probably needs to be a kind of "occurs check" to make - // sure that the constraint actually applies to at least one - // of the parameters of the generic. - // - CheckConstraintSubType(decl->sub); - - if (!decl->sub.type) - decl->sub = TranslateTypeNodeForced(decl->sub); - if (!decl->sup.type) - decl->sup = TranslateTypeNodeForced(decl->sup); - - if (getLinkage()->m_optionSet.shouldRunNonEssentialValidation()) - { - validateGenericConstraintSubType(decl, decl->sub); - } - - if (!decl->isEqualityConstraint && !isValidGenericConstraintType(decl->sup) && !as(decl->sub.type)) - { - getSink()->diagnose(decl->sup.exp, Diagnostics::invalidTypeForConstraint, decl->sup); - } + auto inheritanceIDiffernetiable = m_astBuilder->create(); + inheritanceIDiffernetiable->base.type = m_astBuilder->getDiffInterfaceType(); + inheritanceIDiffernetiable->parentDecl = aggTypeDecl; + aggTypeDecl->members.add(inheritanceIDiffernetiable); } - void SemanticsDeclHeaderVisitor::visitGenericTypeParamDecl(GenericTypeParamDecl* decl) + // The `Differential` type of a `Differential` type is always itself. + bool hasDifferentialTypeDef = false; + for (auto member : aggTypeDecl->members) { - // TODO: could probably push checking the default value - // for a generic type parameter later. - // - decl->initType = CheckProperType(decl->initType); + if (auto name = member->getName()) + { + if (name->text == "Differential") + { + hasDifferentialTypeDef = true; + break; + } + } } - - void SemanticsDeclHeaderVisitor::visitGenericValueParamDecl(GenericValueParamDecl* decl) + if (!hasDifferentialTypeDef) { - checkVarDeclCommon(decl); + auto assocTypeDef = m_astBuilder->create(); + assocTypeDef->nameAndLoc.name = getName("Differential"); + assocTypeDef->type.type = satisfyingType; + assocTypeDef->parentDecl = aggTypeDecl; + assocTypeDef->setCheckState(DeclCheckState::DefinitionChecked); + aggTypeDecl->members.add(assocTypeDef); } - void SemanticsDeclHeaderVisitor::visitGenericDecl(GenericDecl* genericDecl) + // Go through all members and collect their differential types. + // Go through super types. + for (auto inheritance : context->parentDecl->getMembersOfType()) { - genericDecl->setCheckState(DeclCheckState::ReadyForLookup); - - // NOTE! We purposefully do not iterate with the for(auto m : genericDecl->members) here, - // because the visitor may add to `members` whilst iteration takes place, invalidating the iterator - // and likely a crash. - // - // Accessing the members via index side steps the issue. - - Index parameterIndex = 0; - const auto& members = genericDecl->members; - for (Index i = 0; i < members.getCount(); ++i) + if (auto baseDeclRefType = as(inheritance->base.type)) { - Decl* m = members[i]; - - if (auto typeParam = as(m)) - { - ensureDecl(typeParam, DeclCheckState::ReadyForReference); - typeParam->parameterIndex = parameterIndex++; - } - else if (auto valParam = as(m)) - { - ensureDecl(valParam, DeclCheckState::ReadyForReference); - valParam->parameterIndex = parameterIndex++; - } - else if (auto constraint = as(m)) + // Skip interface super types. + if (baseDeclRefType->getDeclRef().as()) + continue; + if (auto superDiffType = tryGetDifferentialType(m_astBuilder, baseDeclRefType)) { - ensureDecl(constraint, DeclCheckState::ReadyForReference); + addDiffMember(inheritance, superDiffType); } } } - - void SemanticsDeclBasesVisitor::visitInheritanceDecl(InheritanceDecl* inheritanceDecl) + // Go through all var members. + for (auto member : context->parentDecl->getMembersOfType()) { - // check the type being inherited from - auto base = inheritanceDecl->base; - Decl* toExclude = nullptr; - Decl* parent = getParentDecl(inheritanceDecl); - // We exclude in the case that a circular reference is possible. This is when a parent is a transparent decl. - // If we just blanket "block" all ensure's of a parent a generic may fail when trying to fetch a parent - if (parent->findModifier()) - toExclude = parent; - SemanticsDeclVisitorBase baseVistor(this->withDeclToExcludeFromLookup(toExclude)); - baseVistor.CheckConstraintSubType(base); - base = baseVistor.TranslateTypeNode(base); - inheritanceDecl->base = base; - - // Note: we do not check whether the type being inherited from - // is valid to use for inheritance here, because there could - // be contextual factors that need to be taken into account - // based on the declaration that is doing the inheriting. + if (member->hasModifier()) + continue; + auto diffType = tryGetDifferentialType(m_astBuilder, member->type.type); + if (!diffType) + continue; + addDiffMember(member, diffType); } - void SemanticsDeclBasesVisitor::visitThisTypeConstraintDecl(ThisTypeConstraintDecl* thisTypeConstraintDecl) + addModifier(aggTypeDecl, m_astBuilder->create()); + + // The visibility of synthesized decl should be the min of the parent decl and the requirement. + if (requirementDeclRef.getDecl()->findModifier()) { - // Make sure IFoo.This.ThisIsIFooConstraint.base.type is properly set - // to DeclRefType(IFoo) with default generic arguments. - if (!thisTypeConstraintDecl->base.type) - { - auto parentTypeDecl = getParentDecl(getParentDecl(thisTypeConstraintDecl)); - thisTypeConstraintDecl->base.type = DeclRefType::create( - m_astBuilder, - createDefaultSubstitutionsIfNeeded( - m_astBuilder, - this, - getDefaultDeclRef(parentTypeDecl))); - } + auto requirementVisibility = getDeclVisibility(requirementDeclRef.getDecl()); + auto thisVisibility = getDeclVisibility(context->parentDecl); + auto visibility = Math::Min(thisVisibility, requirementVisibility); + addVisibilityModifier(m_astBuilder, aggTypeDecl, visibility); } - // Concretize interface conformances so that we have witnesses as required for lookup. - // for lookup. - struct SemanticsDeclConformancesVisitor - : public SemanticsDeclVisitorBase - , public DeclVisitor - { - SemanticsDeclConformancesVisitor(SemanticsContext const& outer) - : SemanticsDeclVisitorBase(outer) - {} + // Synthesize the rest of IDifferential method conformances by recursively checking + // conformance on the synthesized decl. + checkAggTypeConformance(aggTypeDecl); - void visitDecl(Decl*) {} - void visitDeclGroup(DeclGroup*) {} + witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(satisfyingType)); + if (!doesTypeSatisfyAssociatedTypeConstraintRequirement( + satisfyingType, + requirementDeclRef, + witnessTable)) + { + // Note: the call to `doesTypeSatisfyAssociatedTypeConstraintRequirement` should always + // succeed. If not, there is something wrong with the code synthesis logic. For now we just + // return false instead of crashing so the user can work around the issues. + witnessTable->m_requirementDictionary.remove(requirementDeclRef.getDecl()); + return false; + } + return true; +} - // Any user-defined type may have declared interface conformances, - // which we should check. - // - void visitAggTypeDecl(AggTypeDecl* aggTypeDecl) +void SemanticsDeclHeaderVisitor::validateGenericConstraintSubType( + GenericTypeConstraintDecl* decl, + TypeExp type) +{ + // Validate that the sub type of a constraint is in valid form. + // + if (auto subDeclRef = isDeclRefTypeOf(type.type)) + { + if (subDeclRef.getDecl()->parentDecl == decl->parentDecl) { - checkAggTypeConformance(aggTypeDecl); + // OK, sub type is one of the generic parameter type. + return; } - - // Conformances can also come via `extension` declarations, and - // we should check them against the type(s) being extended. - // - void visitExtensionDecl(ExtensionDecl* extensionDecl) + if (as(decl->parentDecl)) { - checkExtensionConformance(extensionDecl); + // If the constraint is in a generic decl, then the sub type must be dependent on at + // least one of the generic type parameters defined in the same generic decl. For + // example, it is invalid to define a constraint like `void foo() where int : float` + // since `int` isn't dependent on any generic type parameter. + auto dependentGeneric = getShared()->getDependentGenericParent(subDeclRef); + if (dependentGeneric.getDecl() != decl->parentDecl) + { + getSink()->diagnose(type.exp, Diagnostics::invalidConstraintSubType, type); + return; + } } - }; - - // Check that types used as `Differential` type use themselves as their own `Differential` type. - struct SemanticsDeclDifferentialConformanceVisitor - : public SemanticsDeclVisitorBase - , public DeclVisitor - { - SemanticsDeclDifferentialConformanceVisitor(SemanticsContext const& outer) - : SemanticsDeclVisitorBase(outer) - {} - void visitDecl(Decl*) {} - void visitDeclGroup(DeclGroup*) {} - - void visitInheritanceDecl(InheritanceDecl* inheritanceDecl) + else if (as(decl->parentDecl)) { - if (as(inheritanceDecl->parentDecl)) + // If the constraint is on an associated type, then it should either be the associated + // type itself, or a associated type of the associated type. For example, + // ``` + // interface IFoo { + // associatedtype T + // where T : IFoo // OK, constraint is on the associatedtype T itself. + // where T.T == X // OK, constraint is on the associated type of T. + // where int == X; // Error, int is not a valid left hand side of a constraint. + // } + // ``` + auto lookupDeclRef = as(subDeclRef.declRefBase); + if (!lookupDeclRef) + { + getSink()->diagnose(type.exp, Diagnostics::invalidConstraintSubType, type); return; + } - if (!inheritanceDecl->witnessTable) + // We allow `associatedtype T where This.T : ...`. + // In this case, the left hand side will be in the form of + // LookupDeclRef(ThisType, T). i.e. lookupDeclRef->getDecl() == T. + // + if (lookupDeclRef->getDecl()->parentDecl == decl->parentDecl || + lookupDeclRef->getDecl() == decl->parentDecl) return; - auto baseType = as(inheritanceDecl->witnessTable->baseType); + auto baseType = as(lookupDeclRef->getLookupSource()); if (!baseType) - return; - if (baseType->getDeclRef().getDecl() != m_astBuilder->getDifferentiableInterfaceDecl().getDecl()) - return; - RequirementWitness witnessValue; - auto requirementDecl = m_astBuilder->getSharedASTBuilder()->findBuiltinRequirementDecl(BuiltinRequirementKind::DifferentialType); - if (!inheritanceDecl->witnessTable->getRequirementDictionary().tryGetValue(requirementDecl, witnessValue)) - return; - - if (witnessValue.getFlavor() != RequirementWitness::Flavor::val) - return; - auto differentialType = as(witnessValue.getVal()); - if (!differentialType) - return; - - // Check that the type used as differential type must have itself as its own differential type. - auto diffDiffType = tryGetDifferentialType(m_astBuilder, differentialType); - if (!differentialType->equals(diffDiffType)) { - SourceLoc sourceLoc = differentialType->getDeclRef().getDecl()->loc; - getSink()->diagnose( - inheritanceDecl, - Diagnostics::differentialTypeShouldServeAsItsOwnDifferentialType, - differentialType, - diffDiffType); - getSink()->diagnose(sourceLoc, Diagnostics::seeDefinitionOf, differentialType); + getSink()->diagnose(type.exp, Diagnostics::invalidConstraintSubType, type); + return; } + type.type = baseType; + validateGenericConstraintSubType(decl, type); + } + } +} - // Check that all [DerivativeMember(...)] attributes have their references checked. - for (auto member : inheritanceDecl->parentDecl->getMembersOfType()) - { - if (member->findModifier()) - continue; - auto derivativeMemberAttr = member->findModifier(); - if (!derivativeMemberAttr) - continue; - checkDerivativeMemberAttributeReferences(member, derivativeMemberAttr); - } +void SemanticsDeclHeaderVisitor::visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl) +{ + // TODO: are there any other validations we can do at this point? + // + // There probably needs to be a kind of "occurs check" to make + // sure that the constraint actually applies to at least one + // of the parameters of the generic. + // + CheckConstraintSubType(decl->sub); - // Check that either the differential type is the same as the base type, or all fields of the base type that are differentiable - // have a corresponding field in the differential type through the [DerivativeMember(...)] attribute. - // - // We only need to check the fields of the base type that are differentiable. - auto baseDecl = as(inheritanceDecl->parentDecl); - if (!baseDecl) - return; + if (!decl->sub.type) + decl->sub = TranslateTypeNodeForced(decl->sub); + if (!decl->sup.type) + decl->sup = TranslateTypeNodeForced(decl->sup); + + if (getLinkage()->m_optionSet.shouldRunNonEssentialValidation()) + { + validateGenericConstraintSubType(decl, decl->sub); + } + + if (!decl->isEqualityConstraint && !isValidGenericConstraintType(decl->sup) && + !as(decl->sub.type)) + { + getSink()->diagnose(decl->sup.exp, Diagnostics::invalidTypeForConstraint, decl->sup); + } +} - auto thisType = calcThisType(getDefaultDeclRef(baseDecl)); +void SemanticsDeclHeaderVisitor::visitGenericTypeParamDecl(GenericTypeParamDecl* decl) +{ + // TODO: could probably push checking the default value + // for a generic type parameter later. + // + decl->initType = CheckProperType(decl->initType); +} - bool typeIsSelfDifferential = thisType->equals(differentialType); +void SemanticsDeclHeaderVisitor::visitGenericValueParamDecl(GenericValueParamDecl* decl) +{ + checkVarDeclCommon(decl); +} - for (auto member : baseDecl->getMembersOfType()) - { - if (member->findModifier()) - continue; - auto diffType = tryGetDifferentialType(m_astBuilder, member->type.type); - if (!diffType) - continue; +void SemanticsDeclHeaderVisitor::visitGenericDecl(GenericDecl* genericDecl) +{ + genericDecl->setCheckState(DeclCheckState::ReadyForLookup); - if (member->findModifier()) - continue; - else if (!typeIsSelfDifferential) - getSink()->diagnose( - member, - Diagnostics::differentiableMemberShouldHaveCorrespondingFieldInDiffType, - member->nameAndLoc.name, - differentialType); - else - { - // If the type is its own differential type, we can infer the differential - // members from the original type. - // - // Add a derivative member attribute referencing itself. - // - auto derivativeMemberModifier = m_astBuilder->create(); - auto fieldLookupExpr = m_astBuilder->create(); - fieldLookupExpr->type.type = diffType; - auto baseTypeExpr = m_astBuilder->create(); - baseTypeExpr->base.type = differentialType; - auto baseTypeType = m_astBuilder->getOrCreate(differentialType); - baseTypeExpr->type.type = baseTypeType; - fieldLookupExpr->baseExpression = baseTypeExpr; - fieldLookupExpr->declRef = makeDeclRef(member); - derivativeMemberModifier->memberDeclRef = fieldLookupExpr; - addModifier(member, derivativeMemberModifier); - } - } - } - }; + // NOTE! We purposefully do not iterate with the for(auto m : genericDecl->members) here, + // because the visitor may add to `members` whilst iteration takes place, invalidating the + // iterator and likely a crash. + // + // Accessing the members via index side steps the issue. - /// Recursively register any builtin declarations that need to be attached to the `session`. - /// - /// This function should only be needed for declarations in the core module. - /// - static void _registerBuiltinDeclsRec(Session* session, Decl* decl) + Index parameterIndex = 0; + const auto& members = genericDecl->members; + for (Index i = 0; i < members.getCount(); ++i) { - SharedASTBuilder* sharedASTBuilder = session->m_sharedASTBuilder; + Decl* m = members[i]; - if (auto builtinMod = decl->findModifier()) + if (auto typeParam = as(m)) { - sharedASTBuilder->registerBuiltinDecl(decl, builtinMod); + ensureDecl(typeParam, DeclCheckState::ReadyForReference); + typeParam->parameterIndex = parameterIndex++; } - if (auto magicMod = decl->findModifier()) + else if (auto valParam = as(m)) { - sharedASTBuilder->registerMagicDecl(decl, magicMod); + ensureDecl(valParam, DeclCheckState::ReadyForReference); + valParam->parameterIndex = parameterIndex++; } - if (auto builtinRequirement = decl->findModifier()) + else if (auto constraint = as(m)) { - sharedASTBuilder->registerBuiltinRequirementDecl(decl, builtinRequirement); + ensureDecl(constraint, DeclCheckState::ReadyForReference); } - if(auto containerDecl = as(decl)) - { - for(auto childDecl : containerDecl->members) - { - if(as(childDecl)) - continue; + } +} - _registerBuiltinDeclsRec(session, childDecl); - } - } - if(auto genericDecl = as(decl)) - { - _registerBuiltinDeclsRec(session, genericDecl->inner); - } +void SemanticsDeclBasesVisitor::visitInheritanceDecl(InheritanceDecl* inheritanceDecl) +{ + // check the type being inherited from + auto base = inheritanceDecl->base; + Decl* toExclude = nullptr; + Decl* parent = getParentDecl(inheritanceDecl); + // We exclude in the case that a circular reference is possible. This is when a parent is a + // transparent decl. If we just blanket "block" all ensure's of a parent a generic may fail when + // trying to fetch a parent + if (parent->findModifier()) + toExclude = parent; + SemanticsDeclVisitorBase baseVistor(this->withDeclToExcludeFromLookup(toExclude)); + baseVistor.CheckConstraintSubType(base); + base = baseVistor.TranslateTypeNode(base); + inheritanceDecl->base = base; + + // Note: we do not check whether the type being inherited from + // is valid to use for inheritance here, because there could + // be contextual factors that need to be taken into account + // based on the declaration that is doing the inheriting. +} + +void SemanticsDeclBasesVisitor::visitThisTypeConstraintDecl( + ThisTypeConstraintDecl* thisTypeConstraintDecl) +{ + // Make sure IFoo.This.ThisIsIFooConstraint.base.type is properly set + // to DeclRefType(IFoo) with default generic arguments. + if (!thisTypeConstraintDecl->base.type) + { + auto parentTypeDecl = getParentDecl(getParentDecl(thisTypeConstraintDecl)); + thisTypeConstraintDecl->base.type = DeclRefType::create( + m_astBuilder, + createDefaultSubstitutionsIfNeeded( + m_astBuilder, + this, + getDefaultDeclRef(parentTypeDecl))); } +} - void registerBuiltinDecls(Session* session, Decl* decl) +// Concretize interface conformances so that we have witnesses as required for lookup. +// for lookup. +struct SemanticsDeclConformancesVisitor : public SemanticsDeclVisitorBase, + public DeclVisitor +{ + SemanticsDeclConformancesVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) { - _registerBuiltinDeclsRec(session, decl); } - Type* unwrapArrayType(Type* type) + void visitDecl(Decl*) {} + void visitDeclGroup(DeclGroup*) {} + + // Any user-defined type may have declared interface conformances, + // which we should check. + // + void visitAggTypeDecl(AggTypeDecl* aggTypeDecl) { checkAggTypeConformance(aggTypeDecl); } + + // Conformances can also come via `extension` declarations, and + // we should check them against the type(s) being extended. + // + void visitExtensionDecl(ExtensionDecl* extensionDecl) { - for (;;) - { - if (auto arrayType = as(type)) - type = arrayType->getElementType(); - else - return type; - } + checkExtensionConformance(extensionDecl); } +}; - void discoverExtensionDecls(List& decls, Decl* parent) +// Check that types used as `Differential` type use themselves as their own `Differential` type. +struct SemanticsDeclDifferentialConformanceVisitor + : public SemanticsDeclVisitorBase, + public DeclVisitor +{ + SemanticsDeclDifferentialConformanceVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) { - if (auto extDecl = as(parent)) - decls.add(extDecl); - if (auto containerDecl = as(parent)) - { - for (auto child : containerDecl->members) - { - discoverExtensionDecls(decls, child); - } - } - if (auto genericDecl = as(parent)) - { - discoverExtensionDecls(decls, genericDecl->inner); - } } + void visitDecl(Decl*) {} + void visitDeclGroup(DeclGroup*) {} - void SemanticsDeclVisitorBase::checkModule(ModuleDecl* moduleDecl) + void visitInheritanceDecl(InheritanceDecl* inheritanceDecl) { - // When we are dealing with code from the core modules, - // there is a potential problem where we might need to look - // up built-in types like `Int` through the session (e.g., - // to determine the type for an integer literal), but those - // types might not have been registered yet. We solve that - // by doing a pre-process on the core module code to find - // and register any built-in declarations. - // - // TODO: This could be factored into another visitor pass - // that fits the more standard checking below, but that would - // seemingly add overhead to checking things other than - // the core module. - // - if(isFromCoreModule(moduleDecl)) - { - _registerBuiltinDeclsRec(getSession(), moduleDecl); - } + if (as(inheritanceDecl->parentDecl)) + return; + + if (!inheritanceDecl->witnessTable) + return; + auto baseType = as(inheritanceDecl->witnessTable->baseType); + if (!baseType) + return; + if (baseType->getDeclRef().getDecl() != + m_astBuilder->getDifferentiableInterfaceDecl().getDecl()) + return; + RequirementWitness witnessValue; + auto requirementDecl = m_astBuilder->getSharedASTBuilder()->findBuiltinRequirementDecl( + BuiltinRequirementKind::DifferentialType); + if (!inheritanceDecl->witnessTable->getRequirementDictionary().tryGetValue( + requirementDecl, + witnessValue)) + return; + + if (witnessValue.getFlavor() != RequirementWitness::Flavor::val) + return; + auto differentialType = as(witnessValue.getVal()); + if (!differentialType) + return; - if (moduleDecl->members.getCount() > 0) + // Check that the type used as differential type must have itself as its own differential + // type. + auto diffDiffType = tryGetDifferentialType(m_astBuilder, differentialType); + if (!differentialType->equals(diffDiffType)) { - auto firstMember = moduleDecl->members[0]; - if (as(firstMember)) - { - if (!getShared()->isInLanguageServer()) - { - // A primary module file can't start with an "implementing" declaration. - getSink()->diagnose(firstMember, Diagnostics::primaryModuleFileCannotStartWithImplementingDecl); - } - } - else if (!as(firstMember)) - { - // A primary module file must start with a `module` declaration. - // TODO: this warning is disabled for now to free users from massive change for now. -#if 0 - getSink()->diagnose(firstMember, Diagnostics::primaryModuleFileMustStartWithModuleDecl); -#endif - } - } - - // We need/want to visit any `import` declarations before - // anything else, to make sure that scoping works. - // - // TODO: This could be factored into another visitor pass - // that fits more with the standard checking below. - // - for(auto importDecl : moduleDecl->getMembersOfType()) - { - ensureDecl(importDecl, DeclCheckState::DefinitionChecked); + SourceLoc sourceLoc = differentialType->getDeclRef().getDecl()->loc; + getSink()->diagnose( + inheritanceDecl, + Diagnostics::differentialTypeShouldServeAsItsOwnDifferentialType, + differentialType, + diffDiffType); + getSink()->diagnose(sourceLoc, Diagnostics::seeDefinitionOf, differentialType); } - // Next, make sure all `__include` decls are processed and the referenced - // files are parsed. - auto visitIncludeDecls = [&](ContainerDecl* fileDecl) - { - for (Index i = 0; i < fileDecl->members.getCount(); i++) - { - auto decl = fileDecl->members[i]; - if (auto includeDecl = as(decl)) - { - ensureDecl(includeDecl, DeclCheckState::DefinitionChecked); - } - else if (auto implementingDecl = as(decl)) - { - ensureDecl(implementingDecl, DeclCheckState::DefinitionChecked); - } - else if (auto importDecl = as(decl)) - { - ensureDecl(importDecl, DeclCheckState::DefinitionChecked); - } - } - }; - visitIncludeDecls(moduleDecl); - for (Index i = 0; i < moduleDecl->members.getCount(); i++) + // Check that all [DerivativeMember(...)] attributes have their references checked. + for (auto member : inheritanceDecl->parentDecl->getMembersOfType()) { - if (auto fileDecl = as(moduleDecl->members[i])) - visitIncludeDecls(fileDecl); + if (member->findModifier()) + continue; + auto derivativeMemberAttr = member->findModifier(); + if (!derivativeMemberAttr) + continue; + checkDerivativeMemberAttributeReferences(member, derivativeMemberAttr); } - // The entire goal of semantic checking is to get all of the - // declarations in the module up to `DeclCheckState::DefinitionChecked`. - // - // The main catch is that checking one declaration A up to state M - // may required that declaration B is checked up to state N. - // A call to `ensureDecl(B, N)` can guarantee that things are checked - // when and where we need them, but that runs the risk of creating - // very deep recursion in the semantic checking. - // - // Instead, we would rather do more breadth-first checking, - // where everything gets checked up to state 1, 2, ... - // before anything gets too far ahead. - // We will therefore enumerate the states/phases for checking, - // and then iteratively try to update all declarations to each - // state in turn. - // - // Note: for a simpler language we could eliminate `ensureDecl` - // completely and *just* have these phases of checking. - // Unfortunately, we have some circularity between the phases: - // - // * Checking an overloaded call requires knowing the parameter - // types of all candidate callees. - // - // * Checking the parameter type of a function requires being - // able to check type expressions. + // Check that either the differential type is the same as the base type, or all fields of + // the base type that are differentiable have a corresponding field in the differential type + // through the [DerivativeMember(...)] attribute. // - // * A type expression like `vector` may have an arbitary - // expression for `N`. - // - // * An arbitrary expression may include function calls, which - // may be to overloaded functions. - // - // Languages like C++ solve the apparent problem by making - // restrictions on order of declaration/definition (and by - // requiring forward declarations or the `template`/`typename` - // keywrods in some cases). - // - // TODO: We could eventually eliminate the potential recursion - // in checking by splitting each phase into a "requirements gathering" - // step and an actual execution step. - // - // When checking a declaration D up to state S, the requirements - // gathering step would produce a list of pairs `(someDecl, someState)` - // indicating that `someDecl` must be in `someState` before the - // actual execution of checking for `(D,S)` can proceeed. The checker - // can then produce an elaborated dependency graph and select nodes - // for execution in an order that satisfies all the dependencies. - // - // Such a more elaborate checking scheme will have to wait for another - // day, but might be worth it (or even necessary) if/when we want to - // support incremental compilation. - // - DeclCheckState states[] = - { - DeclCheckState::ScopesWired, - DeclCheckState::ReadyForReference, - DeclCheckState::ReadyForLookup, - DeclCheckState::ReadyForConformances, - DeclCheckState::DefinitionChecked, - DeclCheckState::CapabilityChecked, - }; + // We only need to check the fields of the base type that are differentiable. + auto baseDecl = as(inheritanceDecl->parentDecl); + if (!baseDecl) + return; - // Discover and check all extension decls before anything else. - List extensionDecls; - discoverExtensionDecls(extensionDecls, moduleDecl); - for (auto s : states) + auto thisType = calcThisType(getDefaultDeclRef(baseDecl)); + + bool typeIsSelfDifferential = thisType->equals(differentialType); + + for (auto member : baseDecl->getMembersOfType()) { - for (auto extensionDecl : extensionDecls) + if (member->findModifier()) + continue; + auto diffType = tryGetDifferentialType(m_astBuilder, member->type.type); + if (!diffType) + continue; + + if (member->findModifier()) + continue; + else if (!typeIsSelfDifferential) + getSink()->diagnose( + member, + Diagnostics::differentiableMemberShouldHaveCorrespondingFieldInDiffType, + member->nameAndLoc.name, + differentialType); + else { - ensureDecl(extensionDecl, s); + // If the type is its own differential type, we can infer the differential + // members from the original type. + // + // Add a derivative member attribute referencing itself. + // + auto derivativeMemberModifier = m_astBuilder->create(); + auto fieldLookupExpr = m_astBuilder->create(); + fieldLookupExpr->type.type = diffType; + auto baseTypeExpr = m_astBuilder->create(); + baseTypeExpr->base.type = differentialType; + auto baseTypeType = m_astBuilder->getOrCreate(differentialType); + baseTypeExpr->type.type = baseTypeType; + fieldLookupExpr->baseExpression = baseTypeExpr; + fieldLookupExpr->declRef = makeDeclRef(member); + derivativeMemberModifier->memberDeclRef = fieldLookupExpr; + addModifier(member, derivativeMemberModifier); } - // We only need to check extension decls up to ReadyForLookup - // so they are properly registered in type inheritance infos. - if (s == DeclCheckState::ReadyForLookup) - break; } + } +}; - // With extensions taken care of, we can now check the remaining decls. - for(auto s : states) - { - // When advancing to state `s` we will recursively - // advance all declarations rooted in the module - // up to `s`. - // - // TODO: In cases where a large module is split across files, - // we could potentially parallelize front-end compilation by - // having multiple instances of the front end where each is - // only responsible for those declarations in a given file. - // - // Under that model, we might only apply later phases of - // checking (notably the final push to `DeclState::Checked`) - // to the subset of declarations coming from a given source - // file. - // - ensureAllDeclsRec(moduleDecl, s); - } +/// Recursively register any builtin declarations that need to be attached to the `session`. +/// +/// This function should only be needed for declarations in the core module. +/// +static void _registerBuiltinDeclsRec(Session* session, Decl* decl) +{ + SharedASTBuilder* sharedASTBuilder = session->m_sharedASTBuilder; - // Once we have completed the above loop, all declarations not - // nested in function bodies should be in `DeclState::Checked`. - // Furthermore, because a fully checked function will have checked - // its body, this also means that all function bodies and the - // declarations they contain should be fully checked. + if (auto builtinMod = decl->findModifier()) + { + sharedASTBuilder->registerBuiltinDecl(decl, builtinMod); } - - bool SemanticsVisitor::doesSignatureMatchRequirement( - DeclRef satisfyingMemberDeclRef, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable) + if (auto magicMod = decl->findModifier()) + { + sharedASTBuilder->registerMagicDecl(decl, magicMod); + } + if (auto builtinRequirement = decl->findModifier()) + { + sharedASTBuilder->registerBuiltinRequirementDecl(decl, builtinRequirement); + } + if (auto containerDecl = as(decl)) { - if(satisfyingMemberDeclRef.getDecl()->hasModifier() - != requiredMemberDeclRef.getDecl()->hasModifier()) + for (auto childDecl : containerDecl->members) { - // A `[mutating]` method can't satisfy a non-`[mutating]` requirement. - // The opposite direction is okay, but we will need to synthesize a wrapper - // to ensure type matches, so we will return false here either way. - return false; - } + if (as(childDecl)) + continue; - if (satisfyingMemberDeclRef.getDecl()->hasModifier() - != requiredMemberDeclRef.getDecl()->hasModifier()) - { - // A `[constref]` method can't satisfy a non-`[constref]` requirement. - // The opposite direction is okay, but we will need to synthesize a wrapper - // to ensure type matches, so we will return false here either way. - return false; + _registerBuiltinDeclsRec(session, childDecl); } + } + if (auto genericDecl = as(decl)) + { + _registerBuiltinDeclsRec(session, genericDecl->inner); + } +} - if (satisfyingMemberDeclRef.getDecl()->hasModifier() - != requiredMemberDeclRef.getDecl()->hasModifier()) - { - // A `[ref]` method can't satisfy a non-`[ref]` requirement. - // The opposite direction is okay, but we will need to synthesize a wrapper - // to ensure type matches, so we will return false here either way. - return false; - } +void registerBuiltinDecls(Session* session, Decl* decl) +{ + _registerBuiltinDeclsRec(session, decl); +} - if(satisfyingMemberDeclRef.getDecl()->hasModifier() - != requiredMemberDeclRef.getDecl()->hasModifier()) - { - // A `static` method can't satisfy a non-`static` requirement and vice versa. - return false; - } +Type* unwrapArrayType(Type* type) +{ + for (;;) + { + if (auto arrayType = as(type)) + type = arrayType->getElementType(); + else + return type; + } +} - bool hasBackwardDerivative = false; - bool hasForwardDerivative = false; - if (requiredMemberDeclRef.getDecl()->hasModifier()) +void discoverExtensionDecls(List& decls, Decl* parent) +{ + if (auto extDecl = as(parent)) + decls.add(extDecl); + if (auto containerDecl = as(parent)) + { + for (auto child : containerDecl->members) { - auto funcDecl = as(satisfyingMemberDeclRef.getDecl()); - if (!funcDecl) - return false; - - if (getShared()->getFuncDifferentiableLevel(funcDecl) != FunctionDifferentiableLevel::Backward) - { - // A non-`BackwardDifferentiable` method can't satisfy a `BackwardDifferentiable` requirement and vice versa. - return false; - } - hasBackwardDerivative = true; - hasForwardDerivative = true; + discoverExtensionDecls(decls, child); } - else if (requiredMemberDeclRef.getDecl()->hasModifier()) + } + if (auto genericDecl = as(parent)) + { + discoverExtensionDecls(decls, genericDecl->inner); + } +} + +void SemanticsDeclVisitorBase::checkModule(ModuleDecl* moduleDecl) +{ + // When we are dealing with code from the core modules, + // there is a potential problem where we might need to look + // up built-in types like `Int` through the session (e.g., + // to determine the type for an integer literal), but those + // types might not have been registered yet. We solve that + // by doing a pre-process on the core module code to find + // and register any built-in declarations. + // + // TODO: This could be factored into another visitor pass + // that fits the more standard checking below, but that would + // seemingly add overhead to checking things other than + // the core module. + // + if (isFromCoreModule(moduleDecl)) + { + _registerBuiltinDeclsRec(getSession(), moduleDecl); + } + + if (moduleDecl->members.getCount() > 0) + { + auto firstMember = moduleDecl->members[0]; + if (as(firstMember)) { - auto funcDecl = as(satisfyingMemberDeclRef.getDecl()); - if (!funcDecl) - return false; - if (getShared()->getFuncDifferentiableLevel(funcDecl) == FunctionDifferentiableLevel::None) + if (!getShared()->isInLanguageServer()) { - // A non-`BackwardDifferentiable` method can't satisfy a `BackwardDifferentiable` requirement and vice versa. - return false; + // A primary module file can't start with an "implementing" declaration. + getSink()->diagnose( + firstMember, + Diagnostics::primaryModuleFileCannotStartWithImplementingDecl); } - hasForwardDerivative = true; } - - // A signature matches the required one if it has the right number of parameters, - // and those parameters have the right types, and also the result/return type - // is the required one. - // - auto requiredParams = getParameters(m_astBuilder, requiredMemberDeclRef).toArray(); - auto satisfyingParams = getParameters(m_astBuilder, satisfyingMemberDeclRef).toArray(); - auto paramCount = requiredParams.getCount(); - if(satisfyingParams.getCount() != paramCount) - return false; - - for(Index paramIndex = 0; paramIndex < paramCount; ++paramIndex) + else if (!as(firstMember)) { - auto requiredParam = requiredParams[paramIndex]; - auto satisfyingParam = satisfyingParams[paramIndex]; - - auto requiredParamType = getType(m_astBuilder, requiredParam); - auto satisfyingParamType = getType(m_astBuilder, satisfyingParam); - - if(!requiredParamType->equals(satisfyingParamType)) - return false; + // A primary module file must start with a `module` declaration. + // TODO: this warning is disabled for now to free users from massive change for now. +#if 0 + getSink()->diagnose(firstMember, Diagnostics::primaryModuleFileMustStartWithModuleDecl); +#endif } + } - auto requiredResultType = getResultType(m_astBuilder, requiredMemberDeclRef); - auto satisfyingResultType = getResultType(m_astBuilder, satisfyingMemberDeclRef); - if(!requiredResultType->equals(satisfyingResultType)) - return false; + // We need/want to visit any `import` declarations before + // anything else, to make sure that scoping works. + // + // TODO: This could be factored into another visitor pass + // that fits more with the standard checking below. + // + for (auto importDecl : moduleDecl->getMembersOfType()) + { + ensureDecl(importDecl, DeclCheckState::DefinitionChecked); + } - if (hasForwardDerivative || hasBackwardDerivative) + // Next, make sure all `__include` decls are processed and the referenced + // files are parsed. + auto visitIncludeDecls = [&](ContainerDecl* fileDecl) + { + for (Index i = 0; i < fileDecl->members.getCount(); i++) { - auto parentInterfaceDecl = as(getParentDecl(requiredMemberDeclRef.getDecl())); - if (parentInterfaceDecl) + auto decl = fileDecl->members[i]; + if (auto includeDecl = as(decl)) { - bool noDiffThisSatisfying = !isTypeDifferentiable(witnessTable->witnessedType); - bool noDiffThisRequirement = (requiredMemberDeclRef.getDecl()->findModifier() != nullptr); - if (noDiffThisRequirement != noDiffThisSatisfying) - return false; + ensureDecl(includeDecl, DeclCheckState::DefinitionChecked); + } + else if (auto implementingDecl = as(decl)) + { + ensureDecl(implementingDecl, DeclCheckState::DefinitionChecked); + } + else if (auto importDecl = as(decl)) + { + ensureDecl(importDecl, DeclCheckState::DefinitionChecked); } } + }; + visitIncludeDecls(moduleDecl); + for (Index i = 0; i < moduleDecl->members.getCount(); i++) + { + if (auto fileDecl = as(moduleDecl->members[i])) + visitIncludeDecls(fileDecl); + } + + // The entire goal of semantic checking is to get all of the + // declarations in the module up to `DeclCheckState::DefinitionChecked`. + // + // The main catch is that checking one declaration A up to state M + // may required that declaration B is checked up to state N. + // A call to `ensureDecl(B, N)` can guarantee that things are checked + // when and where we need them, but that runs the risk of creating + // very deep recursion in the semantic checking. + // + // Instead, we would rather do more breadth-first checking, + // where everything gets checked up to state 1, 2, ... + // before anything gets too far ahead. + // We will therefore enumerate the states/phases for checking, + // and then iteratively try to update all declarations to each + // state in turn. + // + // Note: for a simpler language we could eliminate `ensureDecl` + // completely and *just* have these phases of checking. + // Unfortunately, we have some circularity between the phases: + // + // * Checking an overloaded call requires knowing the parameter + // types of all candidate callees. + // + // * Checking the parameter type of a function requires being + // able to check type expressions. + // + // * A type expression like `vector` may have an arbitary + // expression for `N`. + // + // * An arbitrary expression may include function calls, which + // may be to overloaded functions. + // + // Languages like C++ solve the apparent problem by making + // restrictions on order of declaration/definition (and by + // requiring forward declarations or the `template`/`typename` + // keywrods in some cases). + // + // TODO: We could eventually eliminate the potential recursion + // in checking by splitting each phase into a "requirements gathering" + // step and an actual execution step. + // + // When checking a declaration D up to state S, the requirements + // gathering step would produce a list of pairs `(someDecl, someState)` + // indicating that `someDecl` must be in `someState` before the + // actual execution of checking for `(D,S)` can proceeed. The checker + // can then produce an elaborated dependency graph and select nodes + // for execution in an order that satisfies all the dependencies. + // + // Such a more elaborate checking scheme will have to wait for another + // day, but might be worth it (or even necessary) if/when we want to + // support incremental compilation. + // + DeclCheckState states[] = { + DeclCheckState::ScopesWired, + DeclCheckState::ReadyForReference, + DeclCheckState::ReadyForLookup, + DeclCheckState::ReadyForConformances, + DeclCheckState::DefinitionChecked, + DeclCheckState::CapabilityChecked, + }; - _addMethodWitness(witnessTable, requiredMemberDeclRef, satisfyingMemberDeclRef); - - return true; + // Discover and check all extension decls before anything else. + List extensionDecls; + discoverExtensionDecls(extensionDecls, moduleDecl); + for (auto s : states) + { + for (auto extensionDecl : extensionDecls) + { + ensureDecl(extensionDecl, s); + } + // We only need to check extension decls up to ReadyForLookup + // so they are properly registered in type inheritance infos. + if (s == DeclCheckState::ReadyForLookup) + break; } - bool SemanticsVisitor::doesAccessorMatchRequirement( - DeclRef satisfyingMemberDeclRef, - DeclRef requiredMemberDeclRef) + // With extensions taken care of, we can now check the remaining decls. + for (auto s : states) { - // We require the AST node class of the satisfying accessor - // to be a subclass of the one from the required accessor. + // When advancing to state `s` we will recursively + // advance all declarations rooted in the module + // up to `s`. // - // For our current accessor types, this amounts to requiring - // an exact match, but using a subtype test means that if - // we ever add an `ExtraSpecialGetDecl` that is a subclass - // of `GetDecl`, then one of those would be able to satisfy - // a `get` requirement. + // TODO: In cases where a large module is split across files, + // we could potentially parallelize front-end compilation by + // having multiple instances of the front end where each is + // only responsible for those declarations in a given file. // - auto satisfyingMemberClass = satisfyingMemberDeclRef.getDecl()->getClass(); - auto requiredMemberClass = requiredMemberDeclRef.getDecl()->getClass(); - if(!satisfyingMemberClass.isSubClassOfImpl(requiredMemberClass)) - return false; + // Under that model, we might only apply later phases of + // checking (notably the final push to `DeclState::Checked`) + // to the subset of declarations coming from a given source + // file. + // + ensureAllDeclsRec(moduleDecl, s); + } - // We do not check the parameters or return types of accessors - // here, under the assumption that the validity checks for - // the parent `property` declaration would already make sure - // they are in order. + // Once we have completed the above loop, all declarations not + // nested in function bodies should be in `DeclState::Checked`. + // Furthermore, because a fully checked function will have checked + // its body, this also means that all function bodies and the + // declarations they contain should be fully checked. +} - // TODO: There are other checks we need to make here, like not letting - // an ordinary `set` satisfy a `[nonmutating] set` requirement. +bool SemanticsVisitor::doesSignatureMatchRequirement( + DeclRef satisfyingMemberDeclRef, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable) +{ + if (satisfyingMemberDeclRef.getDecl()->hasModifier() != + requiredMemberDeclRef.getDecl()->hasModifier()) + { + // A `[mutating]` method can't satisfy a non-`[mutating]` requirement. + // The opposite direction is okay, but we will need to synthesize a wrapper + // to ensure type matches, so we will return false here either way. + return false; + } - return true; + if (satisfyingMemberDeclRef.getDecl()->hasModifier() != + requiredMemberDeclRef.getDecl()->hasModifier()) + { + // A `[constref]` method can't satisfy a non-`[constref]` requirement. + // The opposite direction is okay, but we will need to synthesize a wrapper + // to ensure type matches, so we will return false here either way. + return false; } - bool SemanticsVisitor::doesPropertyMatchRequirement( - DeclRef satisfyingMemberDeclRef, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable) + if (satisfyingMemberDeclRef.getDecl()->hasModifier() != + requiredMemberDeclRef.getDecl()->hasModifier()) { - // The type of the satisfying member must match the type of the required member. - // - // Note: It is possible that a `get`-only property could be satisfied by - // a declaration that uses a subtype of the requirement, but that would not - // count as an "exact match" and we would rely on the logic to synthesize - // a stub implementation in that case. - // - auto satisfyingType = getType(getASTBuilder(), satisfyingMemberDeclRef); - auto requiredType = getType(getASTBuilder(), requiredMemberDeclRef); - if(!satisfyingType->equals(requiredType)) + // A `[ref]` method can't satisfy a non-`[ref]` requirement. + // The opposite direction is okay, but we will need to synthesize a wrapper + // to ensure type matches, so we will return false here either way. + return false; + } + + if (satisfyingMemberDeclRef.getDecl()->hasModifier() != + requiredMemberDeclRef.getDecl()->hasModifier()) + { + // A `static` method can't satisfy a non-`static` requirement and vice versa. + return false; + } + + bool hasBackwardDerivative = false; + bool hasForwardDerivative = false; + if (requiredMemberDeclRef.getDecl()->hasModifier()) + { + auto funcDecl = as(satisfyingMemberDeclRef.getDecl()); + if (!funcDecl) return false; - // Each accessor in the requirement must be accounted for by an accessor - // in the satisfying member. - // - // Note: it is fine for the satisfying member to provide *more* accessors - // than the original declaration. - // - Dictionary, DeclRef> mapRequiredToSatisfyingAccessorDeclRef; - for( auto requiredAccessorDeclRef : getMembersOfType(m_astBuilder, requiredMemberDeclRef) ) + if (getShared()->getFuncDifferentiableLevel(funcDecl) != + FunctionDifferentiableLevel::Backward) { - // We need to search for an accessor that can satisfy the requirement. - // - // For now we will do the simplest (and slowest) thing of a linear search, - // which is mostly fine because the number of accessors is bounded. - // - bool found = false; - for( auto satisfyingAccessorDeclRef : getMembersOfType(m_astBuilder, satisfyingMemberDeclRef) ) - { - if( doesAccessorMatchRequirement(satisfyingAccessorDeclRef, requiredAccessorDeclRef) ) - { - // When we find a match on an accessor, we record it so that - // we can set up the witness values later, but we do *not* - // record it into the actual witness table yet, in case - // a later accessor comes along that doesn't find a match. - // - mapRequiredToSatisfyingAccessorDeclRef.add(requiredAccessorDeclRef, satisfyingAccessorDeclRef); - found = true; - break; - } - } - if(!found) - return false; + // A non-`BackwardDifferentiable` method can't satisfy a `BackwardDifferentiable` + // requirement and vice versa. + return false; } - - // Once things are done, we will install the satisfying values - // into the witness table for the requirements. - // - for( const auto& [key, value] : mapRequiredToSatisfyingAccessorDeclRef ) + hasBackwardDerivative = true; + hasForwardDerivative = true; + } + else if (requiredMemberDeclRef.getDecl()->hasModifier()) + { + auto funcDecl = as(satisfyingMemberDeclRef.getDecl()); + if (!funcDecl) + return false; + if (getShared()->getFuncDifferentiableLevel(funcDecl) == FunctionDifferentiableLevel::None) { - witnessTable->add( - key.getDecl(), - RequirementWitness(value)); + // A non-`BackwardDifferentiable` method can't satisfy a `BackwardDifferentiable` + // requirement and vice versa. + return false; } - // - // Note: the property declaration itself isn't something that - // has a useful value/representation in downstream passes, so - // we are mostly just installing it into the witness table - // as a way to mark this requirement as being satisfied. - // - // TODO: It is possible that having a witness table entry that - // doesn't actually map to any IR value could create a problem - // in downstream passes. If such propblems arise, we should - // probably create a new `RequirementWitness` case that - // represents a witness value that is only needed by the front-end, - // and that can be ignored by IR and emit logic. - // - witnessTable->add( - requiredMemberDeclRef.getDecl(), - RequirementWitness(satisfyingMemberDeclRef)); - return true; + hasForwardDerivative = true; } - bool SemanticsVisitor::doesSubscriptMatchRequirement( - DeclRef satisfyingMemberDeclRef, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable) + // A signature matches the required one if it has the right number of parameters, + // and those parameters have the right types, and also the result/return type + // is the required one. + // + auto requiredParams = getParameters(m_astBuilder, requiredMemberDeclRef).toArray(); + auto satisfyingParams = getParameters(m_astBuilder, satisfyingMemberDeclRef).toArray(); + auto paramCount = requiredParams.getCount(); + if (satisfyingParams.getCount() != paramCount) + return false; + + for (Index paramIndex = 0; paramIndex < paramCount; ++paramIndex) { - // The result type and parameters of the satisfying member must match the type of the required member. - // - auto requiredParams = getParameters(m_astBuilder, requiredMemberDeclRef).toArray(); - auto satisfyingParams = getParameters(m_astBuilder, satisfyingMemberDeclRef).toArray(); - auto paramCount = requiredParams.getCount(); - if (satisfyingParams.getCount() != paramCount) - return false; + auto requiredParam = requiredParams[paramIndex]; + auto satisfyingParam = satisfyingParams[paramIndex]; - for (Index paramIndex = 0; paramIndex < paramCount; ++paramIndex) - { - auto requiredParam = requiredParams[paramIndex]; - auto satisfyingParam = satisfyingParams[paramIndex]; + auto requiredParamType = getType(m_astBuilder, requiredParam); + auto satisfyingParamType = getType(m_astBuilder, satisfyingParam); - auto requiredParamType = getType(m_astBuilder, requiredParam); - auto satisfyingParamType = getType(m_astBuilder, satisfyingParam); + if (!requiredParamType->equals(satisfyingParamType)) + return false; + } + + auto requiredResultType = getResultType(m_astBuilder, requiredMemberDeclRef); + auto satisfyingResultType = getResultType(m_astBuilder, satisfyingMemberDeclRef); + if (!requiredResultType->equals(satisfyingResultType)) + return false; - if (!requiredParamType->equals(satisfyingParamType)) + if (hasForwardDerivative || hasBackwardDerivative) + { + auto parentInterfaceDecl = + as(getParentDecl(requiredMemberDeclRef.getDecl())); + if (parentInterfaceDecl) + { + bool noDiffThisSatisfying = !isTypeDifferentiable(witnessTable->witnessedType); + bool noDiffThisRequirement = + (requiredMemberDeclRef.getDecl()->findModifier() != nullptr); + if (noDiffThisRequirement != noDiffThisSatisfying) return false; } + } - auto requiredResultType = getResultType(m_astBuilder, requiredMemberDeclRef); - auto satisfyingResultType = getResultType(m_astBuilder, satisfyingMemberDeclRef); - if (!requiredResultType->equals(satisfyingResultType)) - return false; + _addMethodWitness(witnessTable, requiredMemberDeclRef, satisfyingMemberDeclRef); + + return true; +} + +bool SemanticsVisitor::doesAccessorMatchRequirement( + DeclRef satisfyingMemberDeclRef, + DeclRef requiredMemberDeclRef) +{ + // We require the AST node class of the satisfying accessor + // to be a subclass of the one from the required accessor. + // + // For our current accessor types, this amounts to requiring + // an exact match, but using a subtype test means that if + // we ever add an `ExtraSpecialGetDecl` that is a subclass + // of `GetDecl`, then one of those would be able to satisfy + // a `get` requirement. + // + auto satisfyingMemberClass = satisfyingMemberDeclRef.getDecl()->getClass(); + auto requiredMemberClass = requiredMemberDeclRef.getDecl()->getClass(); + if (!satisfyingMemberClass.isSubClassOfImpl(requiredMemberClass)) + return false; + + // We do not check the parameters or return types of accessors + // here, under the assumption that the validity checks for + // the parent `property` declaration would already make sure + // they are in order. + + // TODO: There are other checks we need to make here, like not letting + // an ordinary `set` satisfy a `[nonmutating] set` requirement. + + return true; +} + +bool SemanticsVisitor::doesPropertyMatchRequirement( + DeclRef satisfyingMemberDeclRef, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable) +{ + // The type of the satisfying member must match the type of the required member. + // + // Note: It is possible that a `get`-only property could be satisfied by + // a declaration that uses a subtype of the requirement, but that would not + // count as an "exact match" and we would rely on the logic to synthesize + // a stub implementation in that case. + // + auto satisfyingType = getType(getASTBuilder(), satisfyingMemberDeclRef); + auto requiredType = getType(getASTBuilder(), requiredMemberDeclRef); + if (!satisfyingType->equals(requiredType)) + return false; - // Each accessor in the requirement must be accounted for by an accessor - // in the satisfying member. + // Each accessor in the requirement must be accounted for by an accessor + // in the satisfying member. + // + // Note: it is fine for the satisfying member to provide *more* accessors + // than the original declaration. + // + Dictionary, DeclRef> mapRequiredToSatisfyingAccessorDeclRef; + for (auto requiredAccessorDeclRef : + getMembersOfType(m_astBuilder, requiredMemberDeclRef)) + { + // We need to search for an accessor that can satisfy the requirement. // - // Note: it is fine for the satisfying member to provide *more* accessors - // than the original declaration. + // For now we will do the simplest (and slowest) thing of a linear search, + // which is mostly fine because the number of accessors is bounded. // - Dictionary, DeclRef> mapRequiredToSatisfyingAccessorDeclRef; - for (auto requiredAccessorDeclRef : getMembersOfType(m_astBuilder, requiredMemberDeclRef)) + bool found = false; + for (auto satisfyingAccessorDeclRef : + getMembersOfType(m_astBuilder, satisfyingMemberDeclRef)) { - // We need to search for an accessor that can satisfy the requirement. - // - // For now we will do the simplest (and slowest) thing of a linear search, - // which is mostly fine because the number of accessors is bounded. - // - bool found = false; - for (auto satisfyingAccessorDeclRef : getMembersOfType(m_astBuilder, satisfyingMemberDeclRef)) + if (doesAccessorMatchRequirement(satisfyingAccessorDeclRef, requiredAccessorDeclRef)) { - if (doesAccessorMatchRequirement(satisfyingAccessorDeclRef, requiredAccessorDeclRef)) - { - // When we find a match on an accessor, we record it so that - // we can set up the witness values later, but we do *not* - // record it into the actual witness table yet, in case - // a later accessor comes along that doesn't find a match. - // - mapRequiredToSatisfyingAccessorDeclRef.add(requiredAccessorDeclRef, satisfyingAccessorDeclRef); - found = true; - break; - } + // When we find a match on an accessor, we record it so that + // we can set up the witness values later, but we do *not* + // record it into the actual witness table yet, in case + // a later accessor comes along that doesn't find a match. + // + mapRequiredToSatisfyingAccessorDeclRef.add( + requiredAccessorDeclRef, + satisfyingAccessorDeclRef); + found = true; + break; } - if (!found) - return false; - } - - // Once things are done, we will install the satisfying values - // into the witness table for the requirements. - // - for (const auto& [key, value] : mapRequiredToSatisfyingAccessorDeclRef) - { - witnessTable->add( - key.getDecl(), - RequirementWitness(value)); } - // - // Note: the subscript declaration itself isn't something that - // has a useful value/representation in downstream passes, so - // we are mostly just installing it into the witness table - // as a way to mark this requirement as being satisfied. - // - witnessTable->add( - requiredMemberDeclRef.getDecl(), - RequirementWitness(satisfyingMemberDeclRef)); - return true; + if (!found) + return false; } - bool SemanticsVisitor::doesVarMatchRequirement( - DeclRef satisfyingMemberDeclRef, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable) + // Once things are done, we will install the satisfying values + // into the witness table for the requirements. + // + for (const auto& [key, value] : mapRequiredToSatisfyingAccessorDeclRef) + { + witnessTable->add(key.getDecl(), RequirementWitness(value)); + } + // + // Note: the property declaration itself isn't something that + // has a useful value/representation in downstream passes, so + // we are mostly just installing it into the witness table + // as a way to mark this requirement as being satisfied. + // + // TODO: It is possible that having a witness table entry that + // doesn't actually map to any IR value could create a problem + // in downstream passes. If such propblems arise, we should + // probably create a new `RequirementWitness` case that + // represents a witness value that is only needed by the front-end, + // and that can be ignored by IR and emit logic. + // + witnessTable->add(requiredMemberDeclRef.getDecl(), RequirementWitness(satisfyingMemberDeclRef)); + return true; +} + +bool SemanticsVisitor::doesSubscriptMatchRequirement( + DeclRef satisfyingMemberDeclRef, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable) +{ + // The result type and parameters of the satisfying member must match the type of the required + // member. + // + auto requiredParams = getParameters(m_astBuilder, requiredMemberDeclRef).toArray(); + auto satisfyingParams = getParameters(m_astBuilder, satisfyingMemberDeclRef).toArray(); + auto paramCount = requiredParams.getCount(); + if (satisfyingParams.getCount() != paramCount) + return false; + + for (Index paramIndex = 0; paramIndex < paramCount; ++paramIndex) { - // The type of the satisfying member must match the type of the required member. - auto satisfyingType = getType(getASTBuilder(), satisfyingMemberDeclRef); - auto requiredType = getType(getASTBuilder(), requiredMemberDeclRef); - if (!satisfyingType->equals(requiredType)) + auto requiredParam = requiredParams[paramIndex]; + auto satisfyingParam = satisfyingParams[paramIndex]; + + auto requiredParamType = getType(m_astBuilder, requiredParam); + auto satisfyingParamType = getType(m_astBuilder, satisfyingParam); + + if (!requiredParamType->equals(satisfyingParamType)) return false; + } + + auto requiredResultType = getResultType(m_astBuilder, requiredMemberDeclRef); + auto satisfyingResultType = getResultType(m_astBuilder, satisfyingMemberDeclRef); + if (!requiredResultType->equals(satisfyingResultType)) + return false; - for (auto modifier : requiredMemberDeclRef.getDecl()->modifiers) + // Each accessor in the requirement must be accounted for by an accessor + // in the satisfying member. + // + // Note: it is fine for the satisfying member to provide *more* accessors + // than the original declaration. + // + Dictionary, DeclRef> mapRequiredToSatisfyingAccessorDeclRef; + for (auto requiredAccessorDeclRef : + getMembersOfType(m_astBuilder, requiredMemberDeclRef)) + { + // We need to search for an accessor that can satisfy the requirement. + // + // For now we will do the simplest (and slowest) thing of a linear search, + // which is mostly fine because the number of accessors is bounded. + // + bool found = false; + for (auto satisfyingAccessorDeclRef : + getMembersOfType(m_astBuilder, satisfyingMemberDeclRef)) { - bool found = false; - for (auto satisfyingModifier : satisfyingMemberDeclRef.getDecl()->modifiers) + if (doesAccessorMatchRequirement(satisfyingAccessorDeclRef, requiredAccessorDeclRef)) { - if (satisfyingModifier->astNodeType == modifier->astNodeType) - { - found = true; - break; - } + // When we find a match on an accessor, we record it so that + // we can set up the witness values later, but we do *not* + // record it into the actual witness table yet, in case + // a later accessor comes along that doesn't find a match. + // + mapRequiredToSatisfyingAccessorDeclRef.add( + requiredAccessorDeclRef, + satisfyingAccessorDeclRef); + found = true; + break; } - if (!found) - return false; } - - auto satisfyingVal = tryConstantFoldDeclRef(satisfyingMemberDeclRef, ConstantFoldingKind::LinkTime, nullptr); - if (satisfyingVal) - { - witnessTable->add( - requiredMemberDeclRef.getDecl(), - RequirementWitness(satisfyingVal)); - } - else - { - witnessTable->add( - requiredMemberDeclRef.getDecl(), - RequirementWitness(satisfyingMemberDeclRef)); - } - return true; + if (!found) + return false; } - bool SemanticsVisitor::doesGenericSignatureMatchRequirement( - DeclRef satisfyingGenericDeclRef, - DeclRef requiredGenericDeclRef, - RefPtr witnessTable) + // Once things are done, we will install the satisfying values + // into the witness table for the requirements. + // + for (const auto& [key, value] : mapRequiredToSatisfyingAccessorDeclRef) { - // The signature of a generic is defiend by its members, and we need the - // satisfying value to have the same number of members for it to be an - // exact match. - // - auto memberCount = requiredGenericDeclRef.getDecl()->members.getCount(); - if(satisfyingGenericDeclRef.getDecl()->members.getCount() != memberCount) - return false; + witnessTable->add(key.getDecl(), RequirementWitness(value)); + } + // + // Note: the subscript declaration itself isn't something that + // has a useful value/representation in downstream passes, so + // we are mostly just installing it into the witness table + // as a way to mark this requirement as being satisfied. + // + witnessTable->add(requiredMemberDeclRef.getDecl(), RequirementWitness(satisfyingMemberDeclRef)); + return true; +} - // We then want to check that pairwise members match, in order. - // - auto requiredMemberDeclRefs = getMembers(m_astBuilder, requiredGenericDeclRef); - auto satisfyingMemberDeclRefs = getMembers(m_astBuilder, satisfyingGenericDeclRef); - // - // We start by performing a superficial "structural" match of the parameters - // to ensure that the two generics have an equivalent mix of type, value, - // and constraint parameters in the same order. - // - // Note that in this step we do *not* make any checks on the actual types - // involved in constraints, or on the types of value parameters. The reason - // for this is that the types on those parameters could be dependent on - // type parameters in the generic parameter list, and thus there could be - // a mismatch at this point. For example, if we have: - // - // interface IBase { void doThing>(); } - // struct Derived : IBase { void doThing>(); } - // - // We clearly have a signature match here, but the constraint parameters for - // `U : IThing` and `Y : IThing` have the problem that both the sub-type - // and super-type they reference are not equivalent without substititions. - // - // We will deal with this issue after the structural matching is checked, at - // which point we can actually verify things like types. - // - for (Index i = 0; i < memberCount; i++) - { - auto requiredMemberDeclRef = requiredMemberDeclRefs[i]; - auto satisfyingMemberDeclRef = satisfyingMemberDeclRefs[i]; +bool SemanticsVisitor::doesVarMatchRequirement( + DeclRef satisfyingMemberDeclRef, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable) +{ + // The type of the satisfying member must match the type of the required member. + auto satisfyingType = getType(getASTBuilder(), satisfyingMemberDeclRef); + auto requiredType = getType(getASTBuilder(), requiredMemberDeclRef); + if (!satisfyingType->equals(requiredType)) + return false; - if (as(requiredMemberDeclRef)) - { - if (as(satisfyingMemberDeclRef)) - { - } - else - return false; - } - else if (auto requiredValueParamDeclRef = requiredMemberDeclRef.as()) - { - if (auto satisfyingValueParamDeclRef = satisfyingMemberDeclRef.as()) - { - } - else - return false; - } - else if (auto requiredConstraintDeclRef = requiredMemberDeclRef.as()) + for (auto modifier : requiredMemberDeclRef.getDecl()->modifiers) + { + bool found = false; + for (auto satisfyingModifier : satisfyingMemberDeclRef.getDecl()->modifiers) + { + if (satisfyingModifier->astNodeType == modifier->astNodeType) { - if (auto satisfyingConstraintDeclRef = satisfyingMemberDeclRef.as()) - { - } - else - return false; + found = true; + break; } } + if (!found) + return false; + } - // In order to compare the inner declarations of the two generics, we need to - // align them so that they are expressed in terms of consistent type parameters. - // - // For example, we might have: - // - // interface IBase { void doThing(T val); } - // struct Derived : IBase { void doThing(U val); } - // - // If we directly compare the signatures of the inner `doThing` function declarations, - // we'd find a mismatch between the `T` and `U` types of the `val` parameter. - // - // We can get around this mismatch by constructing a specialized reference and - // then doing the comparison. For example `IBase::doThing` and `Derived::doThing` - // should both have the signature `X -> void`. - // - // The one big detail that we need to be careful about here is that when we - // recursively call `doesMemberSatisfyRequirement`, that will eventually store - // the satisfying `DeclRef` as the value for the given requirement key, and we don't - // want to store a specialized reference like `Derived::doThing` - we need to - // somehow store the original declaration. - // - // The solution here is to specialize the *required* declaration to the parameters - // of the satisfying declaration. In the example above that means we are going to - // compare `Derived::doThing` against `IBase::doThing` where the `U` there is - // the parameter of `Dervived::doThing`. - // - List requiredSubstArgs; - - for (Index i = 0; i < memberCount; i++) - { - auto requiredMemberDeclRef = requiredMemberDeclRefs[i]; - auto satisfyingMemberDeclRef = satisfyingMemberDeclRefs[i]; + auto satisfyingVal = + tryConstantFoldDeclRef(satisfyingMemberDeclRef, ConstantFoldingKind::LinkTime, nullptr); + if (satisfyingVal) + { + witnessTable->add(requiredMemberDeclRef.getDecl(), RequirementWitness(satisfyingVal)); + } + else + { + witnessTable->add( + requiredMemberDeclRef.getDecl(), + RequirementWitness(satisfyingMemberDeclRef)); + } + return true; +} - if(auto requiredTypeParamDeclRef = requiredMemberDeclRef.as()) - { - auto satisfyingTypeParamDeclRef = satisfyingMemberDeclRef.as(); - SLANG_ASSERT(satisfyingTypeParamDeclRef); - auto satisfyingType = DeclRefType::create(m_astBuilder, satisfyingTypeParamDeclRef); +bool SemanticsVisitor::doesGenericSignatureMatchRequirement( + DeclRef satisfyingGenericDeclRef, + DeclRef requiredGenericDeclRef, + RefPtr witnessTable) +{ + // The signature of a generic is defiend by its members, and we need the + // satisfying value to have the same number of members for it to be an + // exact match. + // + auto memberCount = requiredGenericDeclRef.getDecl()->members.getCount(); + if (satisfyingGenericDeclRef.getDecl()->members.getCount() != memberCount) + return false; - requiredSubstArgs.add(satisfyingType); - } - else if (auto requiredValueParamDeclRef = requiredMemberDeclRef.as()) + // We then want to check that pairwise members match, in order. + // + auto requiredMemberDeclRefs = getMembers(m_astBuilder, requiredGenericDeclRef); + auto satisfyingMemberDeclRefs = getMembers(m_astBuilder, satisfyingGenericDeclRef); + // + // We start by performing a superficial "structural" match of the parameters + // to ensure that the two generics have an equivalent mix of type, value, + // and constraint parameters in the same order. + // + // Note that in this step we do *not* make any checks on the actual types + // involved in constraints, or on the types of value parameters. The reason + // for this is that the types on those parameters could be dependent on + // type parameters in the generic parameter list, and thus there could be + // a mismatch at this point. For example, if we have: + // + // interface IBase { void doThing>(); } + // struct Derived : IBase { void doThing>(); } + // + // We clearly have a signature match here, but the constraint parameters for + // `U : IThing` and `Y : IThing` have the problem that both the sub-type + // and super-type they reference are not equivalent without substititions. + // + // We will deal with this issue after the structural matching is checked, at + // which point we can actually verify things like types. + // + for (Index i = 0; i < memberCount; i++) + { + auto requiredMemberDeclRef = requiredMemberDeclRefs[i]; + auto satisfyingMemberDeclRef = satisfyingMemberDeclRefs[i]; + + if (as(requiredMemberDeclRef)) + { + if (as(satisfyingMemberDeclRef)) { - auto satisfyingValueParamDeclRef = satisfyingMemberDeclRef.as(); - SLANG_ASSERT(satisfyingValueParamDeclRef); - - auto satisfyingVal = m_astBuilder->getOrCreate( - requiredValueParamDeclRef.getDecl()->getType(), - satisfyingValueParamDeclRef); - satisfyingVal->getDeclRef() = satisfyingValueParamDeclRef; - - requiredSubstArgs.add(satisfyingVal); } + else + return false; } - for (Index i = 0; i < memberCount; i++) + else if (auto requiredValueParamDeclRef = requiredMemberDeclRef.as()) { - auto requiredMemberDeclRef = requiredMemberDeclRefs[i]; - auto satisfyingMemberDeclRef = satisfyingMemberDeclRefs[i]; - - if(auto requiredConstraintDeclRef = requiredMemberDeclRef.as()) + if (auto satisfyingValueParamDeclRef = + satisfyingMemberDeclRef.as()) { - auto satisfyingConstraintDeclRef = satisfyingMemberDeclRef.as(); - SLANG_ASSERT(satisfyingConstraintDeclRef); - - auto satisfyingWitness = m_astBuilder->getDeclaredSubtypeWitness( - getSub(m_astBuilder, satisfyingConstraintDeclRef), - getSup(m_astBuilder, satisfyingConstraintDeclRef), - satisfyingConstraintDeclRef); - - requiredSubstArgs.add(satisfyingWitness); } + else + return false; } - - // Now that we have computed a set of specialization arguments that will - // specialize the generic requirement at the type parameters of the satisfying - // generic, we can construct a reference to that declaration and re-run some - // of the earlier checking logic with more type information usable. - // - auto specializedRequiredGenericInnerDeclRef = m_astBuilder->getGenericAppDeclRef( - requiredGenericDeclRef, requiredSubstArgs.getArrayView()); - for (Index i = 0; i < memberCount; i++) + else if ( + auto requiredConstraintDeclRef = requiredMemberDeclRef.as()) { - auto requiredMemberDeclRef = requiredMemberDeclRefs[i]; - auto satisfyingMemberDeclRef = satisfyingMemberDeclRefs[i]; - - if(auto requiredTypeParamDeclRef = requiredMemberDeclRef.as()) + if (auto satisfyingConstraintDeclRef = + satisfyingMemberDeclRef.as()) { - [[maybe_unused]] auto satisfyingTypeParamDeclRef = satisfyingMemberDeclRef.as(); - SLANG_ASSERT(satisfyingTypeParamDeclRef); - - // There are no additional checks we need to make on plain old - // type parameters at this point. - // - // TODO: If we ever support having type parameters of higher kinds, - // then this is possibly where we'd want to check that the kinds of - // the two parameters match. - // } - else if (auto requiredValueParamDeclRef = requiredMemberDeclRef.as()) - { - auto satisfyingValueParamDeclRef = satisfyingMemberDeclRef.as(); - SLANG_ASSERT(satisfyingValueParamDeclRef); + else + return false; + } + } - // For a generic value parameter, we need to check that the required - // and satisfying declaration both agree on the type of the parameter. - // - auto requiredParamType = getType(m_astBuilder, requiredValueParamDeclRef); - auto satisfyingParamType = getType(m_astBuilder, satisfyingValueParamDeclRef); - if (!satisfyingParamType->equals(requiredParamType)) - return false; - } - else if(auto requiredConstraintDeclRef = requiredMemberDeclRef.as()) - { - auto satisfyingConstraintDeclRef = satisfyingMemberDeclRef.as(); - SLANG_ASSERT(satisfyingConstraintDeclRef); + // In order to compare the inner declarations of the two generics, we need to + // align them so that they are expressed in terms of consistent type parameters. + // + // For example, we might have: + // + // interface IBase { void doThing(T val); } + // struct Derived : IBase { void doThing(U val); } + // + // If we directly compare the signatures of the inner `doThing` function declarations, + // we'd find a mismatch between the `T` and `U` types of the `val` parameter. + // + // We can get around this mismatch by constructing a specialized reference and + // then doing the comparison. For example `IBase::doThing` and `Derived::doThing` + // should both have the signature `X -> void`. + // + // The one big detail that we need to be careful about here is that when we + // recursively call `doesMemberSatisfyRequirement`, that will eventually store + // the satisfying `DeclRef` as the value for the given requirement key, and we don't + // want to store a specialized reference like `Derived::doThing` - we need to + // somehow store the original declaration. + // + // The solution here is to specialize the *required* declaration to the parameters + // of the satisfying declaration. In the example above that means we are going to + // compare `Derived::doThing` against `IBase::doThing` where the `U` there is + // the parameter of `Dervived::doThing`. + // + List requiredSubstArgs; - // For a generic constraint parameter, we need to check that the sub-type - // and super-type in the constraint both match. - // - // In current code the sub type will always be one of the generic type parameters, - // and the super-type will always be an interface, but there should be no - // need to make use of those additional details here. - auto specializedRequiredConstraintDeclRef = m_astBuilder->getGenericAppDeclRef( - requiredGenericDeclRef, - requiredSubstArgs.getArrayView(), - requiredConstraintDeclRef.getDecl()).as(); - auto requiredSubType = getSub(m_astBuilder, specializedRequiredConstraintDeclRef); - auto satisfyingSubType = getSub(m_astBuilder, satisfyingConstraintDeclRef); - if (!satisfyingSubType->equals(requiredSubType)) - return false; + for (Index i = 0; i < memberCount; i++) + { + auto requiredMemberDeclRef = requiredMemberDeclRefs[i]; + auto satisfyingMemberDeclRef = satisfyingMemberDeclRefs[i]; - auto requiredSuperType = getSup(m_astBuilder, specializedRequiredConstraintDeclRef); - auto satisfyingSuperType = getSup(m_astBuilder, satisfyingConstraintDeclRef); - if (!satisfyingSuperType->equals(requiredSuperType)) - return false; - } + if (auto requiredTypeParamDeclRef = requiredMemberDeclRef.as()) + { + auto satisfyingTypeParamDeclRef = satisfyingMemberDeclRef.as(); + SLANG_ASSERT(satisfyingTypeParamDeclRef); + auto satisfyingType = DeclRefType::create(m_astBuilder, satisfyingTypeParamDeclRef); + + requiredSubstArgs.add(satisfyingType); } + else if (auto requiredValueParamDeclRef = requiredMemberDeclRef.as()) + { + auto satisfyingValueParamDeclRef = satisfyingMemberDeclRef.as(); + SLANG_ASSERT(satisfyingValueParamDeclRef); - // Note: the above logic really only applies to the case of an exact match on signature, - // even down to the way that constraints were declared. We could potentially be more - // relaxed by taking advantage of the way that various different generic signatures will - // actually lower to the same IR generic signature. - // - // In theory, all we really care about when it comes to constraints is that the constraints - // on the required and satisfying declaration are *equivalent*. - // - // More generally, a satisfying generic could actually provide *looser* constraints and - // still work; all that matters is that it can be instantiated at any argument values/types - // that are valid for the requirement. - // - // We leave both of those issues up to the synthesis path: if we do not find a member that - // provides an exact match, then the compiler should try to synthesize one that is an exact - // match and makes use of existing declarations that might have require defaulting of arguments - // or type conversations to fit. + auto satisfyingVal = m_astBuilder->getOrCreate( + requiredValueParamDeclRef.getDecl()->getType(), + satisfyingValueParamDeclRef); + satisfyingVal->getDeclRef() = satisfyingValueParamDeclRef; - // Once we've validated that the generic signatures are in an exact match, and devised type - // arguments for the requirement to make the two align, we can recursively check the inner - // declaration (whatever it is) for an exact match. - // - return doesMemberSatisfyRequirement( - m_astBuilder->getMemberDeclRef(satisfyingGenericDeclRef, getInner(satisfyingGenericDeclRef)), - specializedRequiredGenericInnerDeclRef, - witnessTable); + requiredSubstArgs.add(satisfyingVal); + } } - - bool SemanticsVisitor::doesTypeSatisfyAssociatedTypeConstraintRequirement(Type* satisfyingType, DeclRef requiredAssociatedTypeDeclRef, RefPtr witnessTable) + for (Index i = 0; i < memberCount; i++) { - SLANG_UNUSED(satisfyingType); + auto requiredMemberDeclRef = requiredMemberDeclRefs[i]; + auto satisfyingMemberDeclRef = satisfyingMemberDeclRefs[i]; - // We will enumerate the type constraints placed on the - // associated type and see if they can be satisfied. - // - bool conformance = true; - Val* witness = nullptr; - for (auto requiredConstraintDeclRef : getMembersOfType(m_astBuilder, requiredAssociatedTypeDeclRef)) + if (auto requiredConstraintDeclRef = requiredMemberDeclRef.as()) { - // Grab the type we expect to conform to from the constraint. - auto requiredSuperType = getSup(m_astBuilder, requiredConstraintDeclRef); + auto satisfyingConstraintDeclRef = + satisfyingMemberDeclRef.as(); + SLANG_ASSERT(satisfyingConstraintDeclRef); - auto subType = getSub(m_astBuilder, requiredConstraintDeclRef); + auto satisfyingWitness = m_astBuilder->getDeclaredSubtypeWitness( + getSub(m_astBuilder, satisfyingConstraintDeclRef), + getSup(m_astBuilder, satisfyingConstraintDeclRef), + satisfyingConstraintDeclRef); - // Perform a search for a witness to the subtype relationship. - witness = tryGetSubtypeWitness(subType, requiredSuperType); - if (witness) - { - auto genConstraint = as(requiredConstraintDeclRef.getDecl()); - if (genConstraint && genConstraint->isEqualityConstraint && !isTypeEqualityWitness(witness)) - witness = nullptr; - } - if (witness) - { - // If a subtype witness was found, then the conformance - // appears to hold, and we can satisfy that requirement. - witnessTable->add(requiredConstraintDeclRef.getDecl(), RequirementWitness(witness)); - } - else - { - // If a witness couldn't be found, then the conformance - // seems like it will fail. - conformance = false; - } + requiredSubstArgs.add(satisfyingWitness); } - return conformance; } - bool SemanticsVisitor::doesTypeSatisfyAssociatedTypeRequirement( - Type* satisfyingType, - DeclRef requiredAssociatedTypeDeclRef, - RefPtr witnessTable) + // Now that we have computed a set of specialization arguments that will + // specialize the generic requirement at the type parameters of the satisfying + // generic, we can construct a reference to that declaration and re-run some + // of the earlier checking logic with more type information usable. + // + auto specializedRequiredGenericInnerDeclRef = m_astBuilder->getGenericAppDeclRef( + requiredGenericDeclRef, + requiredSubstArgs.getArrayView()); + for (Index i = 0; i < memberCount; i++) { - if (auto declRefType = as(satisfyingType)) + auto requiredMemberDeclRef = requiredMemberDeclRefs[i]; + auto satisfyingMemberDeclRef = satisfyingMemberDeclRefs[i]; + + if (auto requiredTypeParamDeclRef = requiredMemberDeclRef.as()) + { + [[maybe_unused]] auto satisfyingTypeParamDeclRef = + satisfyingMemberDeclRef.as(); + SLANG_ASSERT(satisfyingTypeParamDeclRef); + + // There are no additional checks we need to make on plain old + // type parameters at this point. + // + // TODO: If we ever support having type parameters of higher kinds, + // then this is possibly where we'd want to check that the kinds of + // the two parameters match. + // + } + else if (auto requiredValueParamDeclRef = requiredMemberDeclRef.as()) { - // If we are seeing a placeholder that awaits synthesis, return false now to trigger - // auto synthesis. - if (declRefType->getDeclRef().getDecl()->hasModifier()) + auto satisfyingValueParamDeclRef = satisfyingMemberDeclRef.as(); + SLANG_ASSERT(satisfyingValueParamDeclRef); + + // For a generic value parameter, we need to check that the required + // and satisfying declaration both agree on the type of the parameter. + // + auto requiredParamType = getType(m_astBuilder, requiredValueParamDeclRef); + auto satisfyingParamType = getType(m_astBuilder, satisfyingValueParamDeclRef); + if (!satisfyingParamType->equals(requiredParamType)) return false; } + else if ( + auto requiredConstraintDeclRef = requiredMemberDeclRef.as()) + { + auto satisfyingConstraintDeclRef = + satisfyingMemberDeclRef.as(); + SLANG_ASSERT(satisfyingConstraintDeclRef); + + // For a generic constraint parameter, we need to check that the sub-type + // and super-type in the constraint both match. + // + // In current code the sub type will always be one of the generic type parameters, + // and the super-type will always be an interface, but there should be no + // need to make use of those additional details here. + auto specializedRequiredConstraintDeclRef = m_astBuilder + ->getGenericAppDeclRef( + requiredGenericDeclRef, + requiredSubstArgs.getArrayView(), + requiredConstraintDeclRef.getDecl()) + .as(); + auto requiredSubType = getSub(m_astBuilder, specializedRequiredConstraintDeclRef); + auto satisfyingSubType = getSub(m_astBuilder, satisfyingConstraintDeclRef); + if (!satisfyingSubType->equals(requiredSubType)) + return false; + + auto requiredSuperType = getSup(m_astBuilder, specializedRequiredConstraintDeclRef); + auto satisfyingSuperType = getSup(m_astBuilder, satisfyingConstraintDeclRef); + if (!satisfyingSuperType->equals(requiredSuperType)) + return false; + } + } - // Register the satisfying type to the witness table - // before checking the constraints, since the subtype of - // the constraints maybe referencing the satisfying type via - // witness lookups. - auto requirementWitness = RequirementWitness(satisfyingType->getCanonicalType()); - witnessTable->m_requirementDictionary[requiredAssociatedTypeDeclRef.getDecl()] - = requirementWitness; + // Note: the above logic really only applies to the case of an exact match on signature, + // even down to the way that constraints were declared. We could potentially be more + // relaxed by taking advantage of the way that various different generic signatures will + // actually lower to the same IR generic signature. + // + // In theory, all we really care about when it comes to constraints is that the constraints + // on the required and satisfying declaration are *equivalent*. + // + // More generally, a satisfying generic could actually provide *looser* constraints and + // still work; all that matters is that it can be instantiated at any argument values/types + // that are valid for the requirement. + // + // We leave both of those issues up to the synthesis path: if we do not find a member that + // provides an exact match, then the compiler should try to synthesize one that is an exact + // match and makes use of existing declarations that might have require defaulting of arguments + // or type conversations to fit. + + // Once we've validated that the generic signatures are in an exact match, and devised type + // arguments for the requirement to make the two align, we can recursively check the inner + // declaration (whatever it is) for an exact match. + // + return doesMemberSatisfyRequirement( + m_astBuilder->getMemberDeclRef( + satisfyingGenericDeclRef, + getInner(satisfyingGenericDeclRef)), + specializedRequiredGenericInnerDeclRef, + witnessTable); +} - // We need to confirm that the chosen type `satisfyingType`, - // meets all the constraints placed on the associated type - // requirement `requiredAssociatedTypeDeclRef`. - // - // We will enumerate the type constraints placed on the - // associated type and see if they can be satisfied. - // - bool conformance = doesTypeSatisfyAssociatedTypeConstraintRequirement( - satisfyingType, requiredAssociatedTypeDeclRef, witnessTable); +bool SemanticsVisitor::doesTypeSatisfyAssociatedTypeConstraintRequirement( + Type* satisfyingType, + DeclRef requiredAssociatedTypeDeclRef, + RefPtr witnessTable) +{ + SLANG_UNUSED(satisfyingType); + + // We will enumerate the type constraints placed on the + // associated type and see if they can be satisfied. + // + bool conformance = true; + Val* witness = nullptr; + for (auto requiredConstraintDeclRef : + getMembersOfType(m_astBuilder, requiredAssociatedTypeDeclRef)) + { + // Grab the type we expect to conform to from the constraint. + auto requiredSuperType = getSup(m_astBuilder, requiredConstraintDeclRef); - // TODO: if any conformance check failed, we should probably include - // that in an error message produced about not satisfying the requirement. + auto subType = getSub(m_astBuilder, requiredConstraintDeclRef); - if (!conformance) + // Perform a search for a witness to the subtype relationship. + witness = tryGetSubtypeWitness(subType, requiredSuperType); + if (witness) + { + auto genConstraint = as(requiredConstraintDeclRef.getDecl()); + if (genConstraint && genConstraint->isEqualityConstraint && + !isTypeEqualityWitness(witness)) + witness = nullptr; + } + if (witness) + { + // If a subtype witness was found, then the conformance + // appears to hold, and we can satisfy that requirement. + witnessTable->add(requiredConstraintDeclRef.getDecl(), RequirementWitness(witness)); + } + else { - witnessTable->m_requirementDictionary.remove(requiredAssociatedTypeDeclRef.getDecl()); + // If a witness couldn't be found, then the conformance + // seems like it will fail. + conformance = false; } - - return conformance; } + return conformance; +} - bool SemanticsVisitor::doesMemberSatisfyRequirement( - DeclRef memberDeclRef, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable) +bool SemanticsVisitor::doesTypeSatisfyAssociatedTypeRequirement( + Type* satisfyingType, + DeclRef requiredAssociatedTypeDeclRef, + RefPtr witnessTable) +{ + if (auto declRefType = as(satisfyingType)) { - // Sanity check: if are checking whether a type `T` - // implements, say, `IFoo::bar` and lookup of `bar` - // in type `T` yielded `IFoo::bar`, then that shouldn't - // be treated as a valid satisfaction of the requirement. - // - // TODO: Ideally this check should be comparing the `DeclRef`s - // and not just the `Decl`s, but we currently don't get exactly - // the same substitutions when we see the inherited `IFoo::bar`. - // - if(memberDeclRef.getDecl() == requiredMemberDeclRef.getDecl()) + // If we are seeing a placeholder that awaits synthesis, return false now to trigger + // auto synthesis. + if (declRefType->getDeclRef().getDecl()->hasModifier()) return false; + } - // At a high level, we want to check that the - // `memberDecl` and the `requiredMemberDeclRef` - // have the same AST node class, and then also - // check that their signatures match. - // - // There are a bunch of detailed decisions that - // have to be made, though, because we might, e.g., - // allow a function with more general parameter - // types to satisfy a requirement with more - // specific parameter types. - // - // If we ever allow for "property" declarations, - // then we would probably need to allow an - // ordinary field to satisfy a property requirement. - // - // An associated type requirement should be allowed - // to be satisfied by any type declaration: - // a typedef, a `struct`, etc. - // - if (auto memberFuncDecl = memberDeclRef.as()) + // Register the satisfying type to the witness table + // before checking the constraints, since the subtype of + // the constraints maybe referencing the satisfying type via + // witness lookups. + auto requirementWitness = RequirementWitness(satisfyingType->getCanonicalType()); + witnessTable->m_requirementDictionary[requiredAssociatedTypeDeclRef.getDecl()] = + requirementWitness; + + // We need to confirm that the chosen type `satisfyingType`, + // meets all the constraints placed on the associated type + // requirement `requiredAssociatedTypeDeclRef`. + // + // We will enumerate the type constraints placed on the + // associated type and see if they can be satisfied. + // + bool conformance = doesTypeSatisfyAssociatedTypeConstraintRequirement( + satisfyingType, + requiredAssociatedTypeDeclRef, + witnessTable); + + // TODO: if any conformance check failed, we should probably include + // that in an error message produced about not satisfying the requirement. + + if (!conformance) + { + witnessTable->m_requirementDictionary.remove(requiredAssociatedTypeDeclRef.getDecl()); + } + + return conformance; +} + +bool SemanticsVisitor::doesMemberSatisfyRequirement( + DeclRef memberDeclRef, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable) +{ + // Sanity check: if are checking whether a type `T` + // implements, say, `IFoo::bar` and lookup of `bar` + // in type `T` yielded `IFoo::bar`, then that shouldn't + // be treated as a valid satisfaction of the requirement. + // + // TODO: Ideally this check should be comparing the `DeclRef`s + // and not just the `Decl`s, but we currently don't get exactly + // the same substitutions when we see the inherited `IFoo::bar`. + // + if (memberDeclRef.getDecl() == requiredMemberDeclRef.getDecl()) + return false; + + // At a high level, we want to check that the + // `memberDecl` and the `requiredMemberDeclRef` + // have the same AST node class, and then also + // check that their signatures match. + // + // There are a bunch of detailed decisions that + // have to be made, though, because we might, e.g., + // allow a function with more general parameter + // types to satisfy a requirement with more + // specific parameter types. + // + // If we ever allow for "property" declarations, + // then we would probably need to allow an + // ordinary field to satisfy a property requirement. + // + // An associated type requirement should be allowed + // to be satisfied by any type declaration: + // a typedef, a `struct`, etc. + // + if (auto memberFuncDecl = memberDeclRef.as()) + { + if (auto requiredFuncDeclRef = requiredMemberDeclRef.as()) { - if (auto requiredFuncDeclRef = requiredMemberDeclRef.as()) - { - // Check signature match. - return doesSignatureMatchRequirement( - memberFuncDecl, - requiredFuncDeclRef, - witnessTable); - } + // Check signature match. + return doesSignatureMatchRequirement(memberFuncDecl, requiredFuncDeclRef, witnessTable); } - else if (auto memberInitDecl = memberDeclRef.as()) + } + else if (auto memberInitDecl = memberDeclRef.as()) + { + if (auto requiredInitDecl = requiredMemberDeclRef.as()) { - if (auto requiredInitDecl = requiredMemberDeclRef.as()) - { - // Check signature match. - return doesSignatureMatchRequirement( - memberInitDecl, - requiredInitDecl, - witnessTable); - } + // Check signature match. + return doesSignatureMatchRequirement(memberInitDecl, requiredInitDecl, witnessTable); } - else if (auto genDecl = memberDeclRef.as()) + } + else if (auto genDecl = memberDeclRef.as()) + { + // For a generic member, we will check if it can satisfy + // a generic requirement in the interface. + // + // TODO: we could also conceivably check that the generic + // could be *specialized* to satisfy the requirement, + // and then install a specialization of the generic into + // the witness table. Actually doing this would seem + // to require performing something akin to overload + // resolution as part of requirement satisfaction. + // + if (auto requiredGenDeclRef = requiredMemberDeclRef.as()) { - // For a generic member, we will check if it can satisfy - // a generic requirement in the interface. - // - // TODO: we could also conceivably check that the generic - // could be *specialized* to satisfy the requirement, - // and then install a specialization of the generic into - // the witness table. Actually doing this would seem - // to require performing something akin to overload - // resolution as part of requirement satisfaction. - // - if (auto requiredGenDeclRef = requiredMemberDeclRef.as()) - { - return doesGenericSignatureMatchRequirement(genDecl, requiredGenDeclRef, witnessTable); - } + return doesGenericSignatureMatchRequirement(genDecl, requiredGenDeclRef, witnessTable); } - else if (auto subAggTypeDeclRef = memberDeclRef.as()) + } + else if (auto subAggTypeDeclRef = memberDeclRef.as()) + { + if (auto requiredTypeDeclRef = requiredMemberDeclRef.as()) { - if(auto requiredTypeDeclRef = requiredMemberDeclRef.as()) - { - ensureDecl(subAggTypeDeclRef, DeclCheckState::CanUseAsType); + ensureDecl(subAggTypeDeclRef, DeclCheckState::CanUseAsType); - auto satisfyingType = DeclRefType::create(m_astBuilder, subAggTypeDeclRef); - return doesTypeSatisfyAssociatedTypeRequirement(satisfyingType, requiredTypeDeclRef, witnessTable); - } + auto satisfyingType = DeclRefType::create(m_astBuilder, subAggTypeDeclRef); + return doesTypeSatisfyAssociatedTypeRequirement( + satisfyingType, + requiredTypeDeclRef, + witnessTable); } - else if (auto typedefDeclRef = memberDeclRef.as()) + } + else if (auto typedefDeclRef = memberDeclRef.as()) + { + // this is a type-def decl in an aggregate type + // check if the specified type satisfies the constraints defined by the associated type + if (auto requiredTypeDeclRef = requiredMemberDeclRef.as()) { - // this is a type-def decl in an aggregate type - // check if the specified type satisfies the constraints defined by the associated type - if (auto requiredTypeDeclRef = requiredMemberDeclRef.as()) - { - ensureDecl(typedefDeclRef, DeclCheckState::ReadyForLookup); + ensureDecl(typedefDeclRef, DeclCheckState::ReadyForLookup); - auto satisfyingType = getNamedType(m_astBuilder, typedefDeclRef); - return doesTypeSatisfyAssociatedTypeRequirement(satisfyingType, requiredTypeDeclRef, witnessTable); - } + auto satisfyingType = getNamedType(m_astBuilder, typedefDeclRef); + return doesTypeSatisfyAssociatedTypeRequirement( + satisfyingType, + requiredTypeDeclRef, + witnessTable); } - else if( auto propertyDeclRef = memberDeclRef.as() ) + } + else if (auto propertyDeclRef = memberDeclRef.as()) + { + if (auto requiredPropertyDeclRef = requiredMemberDeclRef.as()) { - if( auto requiredPropertyDeclRef = requiredMemberDeclRef.as() ) - { - ensureDecl(propertyDeclRef, DeclCheckState::CanUseFuncSignature); - return doesPropertyMatchRequirement(propertyDeclRef, requiredPropertyDeclRef, witnessTable); - } + ensureDecl(propertyDeclRef, DeclCheckState::CanUseFuncSignature); + return doesPropertyMatchRequirement( + propertyDeclRef, + requiredPropertyDeclRef, + witnessTable); } - else if (auto varDeclRef = memberDeclRef.as()) + } + else if (auto varDeclRef = memberDeclRef.as()) + { + if (auto requiredVarDeclRef = requiredMemberDeclRef.as()) { - if (auto requiredVarDeclRef = requiredMemberDeclRef.as()) - { - ensureDecl(varDeclRef, DeclCheckState::SignatureChecked); - return doesVarMatchRequirement(varDeclRef, requiredVarDeclRef, witnessTable); - } + ensureDecl(varDeclRef, DeclCheckState::SignatureChecked); + return doesVarMatchRequirement(varDeclRef, requiredVarDeclRef, witnessTable); } - else if (auto subscriptDeclRef = memberDeclRef.as()) + } + else if (auto subscriptDeclRef = memberDeclRef.as()) + { + if (auto requiredSubscriptDeclRef = requiredMemberDeclRef.as()) { - if (auto requiredSubscriptDeclRef = requiredMemberDeclRef.as()) - { - ensureDecl(subscriptDeclRef, DeclCheckState::CanUseFuncSignature); - return doesSubscriptMatchRequirement(subscriptDeclRef, requiredSubscriptDeclRef, witnessTable); - } + ensureDecl(subscriptDeclRef, DeclCheckState::CanUseFuncSignature); + return doesSubscriptMatchRequirement( + subscriptDeclRef, + requiredSubscriptDeclRef, + witnessTable); } - // Default: just assume that thing aren't being satisfied. - return false; } + // Default: just assume that thing aren't being satisfied. + return false; +} - GenericDecl* SemanticsVisitor::synthesizeGenericSignatureForRequirementWitness( - ConformanceCheckingContext* context, - DeclRef requiredMemberDeclRef, - List& synArgs, - List& synGenericArgs, - ThisExpr*& synThis) - { - auto synGenericDecl = m_astBuilder->create(); - synGenericDecl->parentDecl = context->parentDecl; - synGenericDecl->ownedScope = m_astBuilder->create(); - synGenericDecl->ownedScope->containerDecl = synGenericDecl; - synGenericDecl->ownedScope->parent = getScope(context->parentDecl); +GenericDecl* SemanticsVisitor::synthesizeGenericSignatureForRequirementWitness( + ConformanceCheckingContext* context, + DeclRef requiredMemberDeclRef, + List& synArgs, + List& synGenericArgs, + ThisExpr*& synThis) +{ + auto synGenericDecl = m_astBuilder->create(); + synGenericDecl->parentDecl = context->parentDecl; + synGenericDecl->ownedScope = m_astBuilder->create(); + synGenericDecl->ownedScope->containerDecl = synGenericDecl; + synGenericDecl->ownedScope->parent = getScope(context->parentDecl); - // For now our synthesized method will use the name and source - // location of the requirement we are trying to satisfy. - // - // TODO: as it stands right now our syntesized method will - // get a mangled name, which we don't actually want. Leaving - // out the name here doesn't help matters, because then *all* - // snthesized methods on a given type would share the same - // mangled name! - // - synGenericDecl->nameAndLoc = requiredMemberDeclRef.getDecl()->nameAndLoc; - if (synGenericDecl->nameAndLoc.name) - { - synGenericDecl->nameAndLoc.name = getSession()->getNameObj("$__syn_" + synGenericDecl->nameAndLoc.name->text); - } + // For now our synthesized method will use the name and source + // location of the requirement we are trying to satisfy. + // + // TODO: as it stands right now our syntesized method will + // get a mangled name, which we don't actually want. Leaving + // out the name here doesn't help matters, because then *all* + // snthesized methods on a given type would share the same + // mangled name! + // + synGenericDecl->nameAndLoc = requiredMemberDeclRef.getDecl()->nameAndLoc; + if (synGenericDecl->nameAndLoc.name) + { + synGenericDecl->nameAndLoc.name = + getSession()->getNameObj("$__syn_" + synGenericDecl->nameAndLoc.name->text); + } - // Dictionary to map from the original type parameters to the synthesized ones. - Dictionary mapOrigToSynTypeParams; + // Dictionary to map from the original type parameters to the synthesized ones. + Dictionary mapOrigToSynTypeParams; - // Our synthesized method will have parameters matching the names - // and types of those on the requirement, and it will use expressions - // that reference those parametesr as arguments for the call expresison - // that makes up the body. - // - for (auto member : requiredMemberDeclRef.getDecl()->members) + // Our synthesized method will have parameters matching the names + // and types of those on the requirement, and it will use expressions + // that reference those parametesr as arguments for the call expresison + // that makes up the body. + // + for (auto member : requiredMemberDeclRef.getDecl()->members) + { + if (auto typeParamDeclBase = as(member)) { - if (auto typeParamDeclBase = as(member)) - { - auto synTypeParamDeclBase = (GenericTypeParamDeclBase*)m_astBuilder->createByNodeType(typeParamDeclBase->astNodeType); - synTypeParamDeclBase->nameAndLoc = typeParamDeclBase->getNameAndLoc(); - synTypeParamDeclBase->parameterIndex = typeParamDeclBase->parameterIndex; - synTypeParamDeclBase->parentDecl = synGenericDecl; + auto synTypeParamDeclBase = (GenericTypeParamDeclBase*)m_astBuilder->createByNodeType( + typeParamDeclBase->astNodeType); + synTypeParamDeclBase->nameAndLoc = typeParamDeclBase->getNameAndLoc(); + synTypeParamDeclBase->parameterIndex = typeParamDeclBase->parameterIndex; + synTypeParamDeclBase->parentDecl = synGenericDecl; - // Note: we intentionally do not copy GenericTypeParamDecl::initType here, - // because initType maybe dependent on the original type parameters, - // and if we copy we must also substitute all the original type parameters with the synthesized ones. - // It shouldn't be required for the implementing declaration to define initType anyways, so we'll just - // save ourselves from the trouble. - // - synGenericDecl->members.add(synTypeParamDeclBase); + // Note: we intentionally do not copy GenericTypeParamDecl::initType here, + // because initType maybe dependent on the original type parameters, + // and if we copy we must also substitute all the original type parameters with the + // synthesized ones. It shouldn't be required for the implementing declaration to define + // initType anyways, so we'll just save ourselves from the trouble. + // + synGenericDecl->members.add(synTypeParamDeclBase); - mapOrigToSynTypeParams.add(typeParamDeclBase, synTypeParamDeclBase); - - // Construct a DeclRefExpr from the type parameter. - auto synTypeParamDeclRef = makeDeclRef(synTypeParamDeclBase); + mapOrigToSynTypeParams.add(typeParamDeclBase, synTypeParamDeclBase); - auto synTypeParamDeclRefExpr = m_astBuilder->create(); - synTypeParamDeclRefExpr->declRef = synTypeParamDeclRef; - synTypeParamDeclRefExpr->type = getTypeForDeclRef(m_astBuilder, synTypeParamDeclRef, SourceLoc()); - - synGenericArgs.add(synTypeParamDeclRefExpr); - } - else if (auto valParamDecl = as(member)) - { - auto synValParamDecl = m_astBuilder->create(); - synValParamDecl->nameAndLoc = valParamDecl->nameAndLoc; - synValParamDecl->parentDecl = synGenericDecl; - synValParamDecl->parameterIndex = valParamDecl->parameterIndex; - synValParamDecl->type = valParamDecl->type; - - // Note: we intentionally do not copy GenericValueParamDecl::initExpr here, - // because initType maybe dependent on the original type/value parameters, - // and if we copy we must also substitute all the original type parameters with the synthesized ones. - // It shouldn't be required for the implementing declaration to define initType anyways, so we'll just - // save ourselves from the trouble. - // - synGenericDecl->members.add(synValParamDecl); + // Construct a DeclRefExpr from the type parameter. + auto synTypeParamDeclRef = makeDeclRef(synTypeParamDeclBase); - mapOrigToSynTypeParams.add(valParamDecl, synGenericDecl); + auto synTypeParamDeclRefExpr = m_astBuilder->create(); + synTypeParamDeclRefExpr->declRef = synTypeParamDeclRef; + synTypeParamDeclRefExpr->type = + getTypeForDeclRef(m_astBuilder, synTypeParamDeclRef, SourceLoc()); - // Construct a DeclRefExpr from the value parameter. - auto synValParamDeclRef = makeDeclRef(synValParamDecl); + synGenericArgs.add(synTypeParamDeclRefExpr); + } + else if (auto valParamDecl = as(member)) + { + auto synValParamDecl = m_astBuilder->create(); + synValParamDecl->nameAndLoc = valParamDecl->nameAndLoc; + synValParamDecl->parentDecl = synGenericDecl; + synValParamDecl->parameterIndex = valParamDecl->parameterIndex; + synValParamDecl->type = valParamDecl->type; - auto synValParamDeclRefExpr = m_astBuilder->create(); - synValParamDeclRefExpr->declRef = synValParamDeclRef; - synValParamDeclRefExpr->type = synValParamDecl->type.type; + // Note: we intentionally do not copy GenericValueParamDecl::initExpr here, + // because initType maybe dependent on the original type/value parameters, + // and if we copy we must also substitute all the original type parameters with the + // synthesized ones. It shouldn't be required for the implementing declaration to define + // initType anyways, so we'll just save ourselves from the trouble. + // + synGenericDecl->members.add(synValParamDecl); - synGenericArgs.add(synValParamDeclRefExpr); - } - } + mapOrigToSynTypeParams.add(valParamDecl, synGenericDecl); - // With all generic parameters in place, we can now form a partial substitution argument list - // without taking into account all the generic constraints. + // Construct a DeclRefExpr from the value parameter. + auto synValParamDeclRef = makeDeclRef(synValParamDecl); - // Given `requiredMemberDeclRef` that is `Lookup(ConcreteType:IFoo, IFoo::bar)`, we can now - // form a partial specialized declref to `IFoo::bar` with substitution args comming - // from the synthesized generic decl, i.e. we want to form: - // `Lookup(ConcreteType:IFoo, IFoo::bar)` where `UImpl` is a synthesized generic parameter. - // - auto partialDefaultArgs = getDefaultSubstitutionArgs(m_astBuilder, this, synGenericDecl); - DeclRef partiallySpecializedRequiredGenericDeclRef = m_astBuilder->getGenericAppDeclRef( - requiredMemberDeclRef, partialDefaultArgs.getArrayView()).as(); + auto synValParamDeclRefExpr = m_astBuilder->create(); + synValParamDeclRefExpr->declRef = synValParamDeclRef; + synValParamDeclRefExpr->type = synValParamDecl->type.type; - // With `partiallySpecializedRequiredGenericDeclRef`, we can obtain the right specialized types - // from the original requirement decl. For example, we can simply apply declref substituion on - // the original type constraint `U:IDerived` to get `UImpl : IDerived`. - // - for (auto member : requiredMemberDeclRef.getDecl()->members) - { - if (auto constraintDecl = as(member)) - { - auto synConstraintDecl = m_astBuilder->create(); - synConstraintDecl->nameAndLoc = constraintDecl->getNameAndLoc(); - synConstraintDecl->parentDecl = synGenericDecl; - - // For generic constraint Sub : Sup, we need to substitute them with - // synthesized generic parameters. - // - synConstraintDecl->sub = TypeExp( - (Type*)constraintDecl->sub.type->substitute(m_astBuilder, - SubstitutionSet(partiallySpecializedRequiredGenericDeclRef))); - synConstraintDecl->sup = TypeExp( - (Type*)constraintDecl->sup.type->substitute(m_astBuilder, - SubstitutionSet(partiallySpecializedRequiredGenericDeclRef))); - synGenericDecl->members.add(synConstraintDecl); - } + synGenericArgs.add(synValParamDeclRefExpr); } + } - // Override generic pointer to point to the original generic container. - // This will create a substitution of the synthesized parameters for the - // original parameters. - // - auto defaultArgs = getDefaultSubstitutionArgs(m_astBuilder, this, synGenericDecl); - DeclRef requiredFuncDeclRef = m_astBuilder->getGenericAppDeclRef( - requiredMemberDeclRef, defaultArgs.getArrayView()).as(); - - SLANG_ASSERT(requiredFuncDeclRef); - ConformanceCheckingContext subContext = *context; - subContext.parentDecl = synGenericDecl; + // With all generic parameters in place, we can now form a partial substitution argument list + // without taking into account all the generic constraints. - synGenericDecl->inner = synthesizeMethodSignatureForRequirementWitnessInner( - &subContext, - requiredFuncDeclRef, - synArgs, - synThis); - return synGenericDecl; - } + // Given `requiredMemberDeclRef` that is `Lookup(ConcreteType:IFoo, IFoo::bar)`, we can now + // form a partial specialized declref to `IFoo::bar` with substitution args comming + // from the synthesized generic decl, i.e. we want to form: + // `Lookup(ConcreteType:IFoo, IFoo::bar)` where `UImpl` is a synthesized generic + // parameter. + // + auto partialDefaultArgs = getDefaultSubstitutionArgs(m_astBuilder, this, synGenericDecl); + DeclRef partiallySpecializedRequiredGenericDeclRef = + m_astBuilder->getGenericAppDeclRef(requiredMemberDeclRef, partialDefaultArgs.getArrayView()) + .as(); - void SemanticsVisitor::addModifiersToSynthesizedDecl( - ConformanceCheckingContext* context, - DeclRef requiredMemberDeclRef, - CallableDecl* synthesized, - ThisExpr*& synThis) + // With `partiallySpecializedRequiredGenericDeclRef`, we can obtain the right specialized types + // from the original requirement decl. For example, we can simply apply declref substituion on + // the original type constraint `U:IDerived` to get `UImpl : IDerived`. + // + for (auto member : requiredMemberDeclRef.getDecl()->members) { - // Required interface methods can be `static` or non-`static`, - // and non-`static` methods can be `[mutating]` or non-`[mutating]`. - // All of these details affect how we introduce our `this` parameter, - // if any. - // - if (requiredMemberDeclRef.getDecl()->hasModifier()) + if (auto constraintDecl = as(member)) { - auto synStaticModifier = m_astBuilder->create(); - synthesized->modifiers.first = synStaticModifier; - } - else - { - // For a non-`static` requirement, we need a `this` parameter. - // - synThis = m_astBuilder->create(); - synThis->scope = synthesized->ownedScope; + auto synConstraintDecl = m_astBuilder->create(); + synConstraintDecl->nameAndLoc = constraintDecl->getNameAndLoc(); + synConstraintDecl->parentDecl = synGenericDecl; - // The type of `this` in our method will be the type for - // which we are synthesizing a conformance. + // For generic constraint Sub : Sup, we need to substitute them with + // synthesized generic parameters. // - synThis->type.type = context->conformingType; - - if (requiredMemberDeclRef.getDecl()->hasModifier()) - { - // If the interface requirement is `[mutating]` then our - // synthesized method should be too, and also the `this` - // parameter should be an l-value. - // - synThis->type.isLeftValue = true; + synConstraintDecl->sub = TypeExp((Type*)constraintDecl->sub.type->substitute( + m_astBuilder, + SubstitutionSet(partiallySpecializedRequiredGenericDeclRef))); + synConstraintDecl->sup = TypeExp((Type*)constraintDecl->sup.type->substitute( + m_astBuilder, + SubstitutionSet(partiallySpecializedRequiredGenericDeclRef))); + synGenericDecl->members.add(synConstraintDecl); + } + } + + // Override generic pointer to point to the original generic container. + // This will create a substitution of the synthesized parameters for the + // original parameters. + // + auto defaultArgs = getDefaultSubstitutionArgs(m_astBuilder, this, synGenericDecl); + DeclRef requiredFuncDeclRef = + m_astBuilder->getGenericAppDeclRef(requiredMemberDeclRef, defaultArgs.getArrayView()) + .as(); + + SLANG_ASSERT(requiredFuncDeclRef); + ConformanceCheckingContext subContext = *context; + subContext.parentDecl = synGenericDecl; + + synGenericDecl->inner = synthesizeMethodSignatureForRequirementWitnessInner( + &subContext, + requiredFuncDeclRef, + synArgs, + synThis); + return synGenericDecl; +} - auto synMutatingAttr = m_astBuilder->create(); - addModifier(synthesized, synMutatingAttr); - } - if (requiredMemberDeclRef.getDecl()->hasModifier()) - { - // If the interface requirement is `[constref]` then our - // synthesized method should be too. - // - auto synConstRefAttr = m_astBuilder->create(); - addModifier(synthesized, synConstRefAttr); - } - if (requiredMemberDeclRef.getDecl()->hasModifier()) - { - // If the interface requirement is `[ref]` then our - // synthesized method should be too. - // - synThis->type.isLeftValue = true; +void SemanticsVisitor::addModifiersToSynthesizedDecl( + ConformanceCheckingContext* context, + DeclRef requiredMemberDeclRef, + CallableDecl* synthesized, + ThisExpr*& synThis) +{ + // Required interface methods can be `static` or non-`static`, + // and non-`static` methods can be `[mutating]` or non-`[mutating]`. + // All of these details affect how we introduce our `this` parameter, + // if any. + // + if (requiredMemberDeclRef.getDecl()->hasModifier()) + { + auto synStaticModifier = m_astBuilder->create(); + synthesized->modifiers.first = synStaticModifier; + } + else + { + // For a non-`static` requirement, we need a `this` parameter. + // + synThis = m_astBuilder->create(); + synThis->scope = synthesized->ownedScope; - auto synConstRefAttr = m_astBuilder->create(); - addModifier(synthesized, synConstRefAttr); - } - if (requiredMemberDeclRef.getDecl()->hasModifier()) - { - auto noDiffThisAttr = m_astBuilder->create(); - addModifier(synthesized, noDiffThisAttr); - } + // The type of `this` in our method will be the type for + // which we are synthesizing a conformance. + // + synThis->type.type = context->conformingType; + + if (requiredMemberDeclRef.getDecl()->hasModifier()) + { + // If the interface requirement is `[mutating]` then our + // synthesized method should be too, and also the `this` + // parameter should be an l-value. + // + synThis->type.isLeftValue = true; + + auto synMutatingAttr = m_astBuilder->create(); + addModifier(synthesized, synMutatingAttr); } - if (requiredMemberDeclRef.getDecl()->hasModifier()) + if (requiredMemberDeclRef.getDecl()->hasModifier()) { - auto attr = m_astBuilder->create(); - addModifier(synthesized, attr); + // If the interface requirement is `[constref]` then our + // synthesized method should be too. + // + auto synConstRefAttr = m_astBuilder->create(); + addModifier(synthesized, synConstRefAttr); } - if (requiredMemberDeclRef.getDecl()->hasModifier()) + if (requiredMemberDeclRef.getDecl()->hasModifier()) { - auto attr = m_astBuilder->create(); - addModifier(synthesized, attr); + // If the interface requirement is `[ref]` then our + // synthesized method should be too. + // + synThis->type.isLeftValue = true; + + auto synConstRefAttr = m_astBuilder->create(); + addModifier(synthesized, synConstRefAttr); } - // The visibility of synthesized decl should be the min of the parent decl and the requirement. - if (requiredMemberDeclRef.getDecl()->findModifier()) + if (requiredMemberDeclRef.getDecl()->hasModifier()) { - auto requirementVisibility = getDeclVisibility(requiredMemberDeclRef.getDecl()); - auto thisVisibility = getDeclVisibility(context->parentDecl); - auto visibility = Math::Min(thisVisibility, requirementVisibility); - addVisibilityModifier(m_astBuilder, synthesized, visibility); + auto noDiffThisAttr = m_astBuilder->create(); + addModifier(synthesized, noDiffThisAttr); } } + if (requiredMemberDeclRef.getDecl()->hasModifier()) + { + auto attr = m_astBuilder->create(); + addModifier(synthesized, attr); + } + if (requiredMemberDeclRef.getDecl()->hasModifier()) + { + auto attr = m_astBuilder->create(); + addModifier(synthesized, attr); + } + // The visibility of synthesized decl should be the min of the parent decl and the requirement. + if (requiredMemberDeclRef.getDecl()->findModifier()) + { + auto requirementVisibility = getDeclVisibility(requiredMemberDeclRef.getDecl()); + auto thisVisibility = getDeclVisibility(context->parentDecl); + auto visibility = Math::Min(thisVisibility, requirementVisibility); + addVisibilityModifier(m_astBuilder, synthesized, visibility); + } +} - void SemanticsVisitor::addRequiredParamsToSynthesizedDecl( - DeclRef requirement, - CallableDecl* synthesized, - List& synArgs) +void SemanticsVisitor::addRequiredParamsToSynthesizedDecl( + DeclRef requirement, + CallableDecl* synthesized, + List& synArgs) +{ + // Our synthesized method will have parameters matching the names + // and types of those on the requirement, and it will use expressions + // that reference those parameters as arguments for the call expresison + // that makes up the body. + // + for (auto paramDeclRef : getParameters(m_astBuilder, requirement)) { - // Our synthesized method will have parameters matching the names - // and types of those on the requirement, and it will use expressions - // that reference those parameters as arguments for the call expresison - // that makes up the body. - // - for (auto paramDeclRef : getParameters(m_astBuilder, requirement)) - { - auto paramType = getType(m_astBuilder, paramDeclRef); + auto paramType = getType(m_astBuilder, paramDeclRef); - // For each parameter of the requirement, we create a matching - // parameter (same name and type) for the synthesized method. - // - auto synParamDecl = m_astBuilder->create(); - synParamDecl->nameAndLoc = paramDeclRef.getDecl()->nameAndLoc; - synParamDecl->type.type = paramType; + // For each parameter of the requirement, we create a matching + // parameter (same name and type) for the synthesized method. + // + auto synParamDecl = m_astBuilder->create(); + synParamDecl->nameAndLoc = paramDeclRef.getDecl()->nameAndLoc; + synParamDecl->type.type = paramType; - // We need to add the parameter as a child declaration of - // the method we are building. - // - synParamDecl->parentDecl = synthesized; - synthesized->members.add(synParamDecl); + // We need to add the parameter as a child declaration of + // the method we are building. + // + synParamDecl->parentDecl = synthesized; + synthesized->members.add(synParamDecl); - // Add modifiers - for (auto modifier : paramDeclRef.getDecl()->modifiers) + // Add modifiers + for (auto modifier : paramDeclRef.getDecl()->modifiers) + { + if (as(modifier)) { - if (as(modifier)) - { - auto noDiffModifier = m_astBuilder->create(); - noDiffModifier->keywordName = getSession()->getNameObj("no_diff"); - addModifier(synParamDecl, noDiffModifier); - } - else if (as(modifier) || as(modifier) || as(modifier) || as(modifier)) - { - auto clonedModifier = (Modifier*)m_astBuilder->createByNodeType(modifier->astNodeType); - clonedModifier->keywordName = modifier->keywordName; - addModifier(synParamDecl, clonedModifier); - } + auto noDiffModifier = m_astBuilder->create(); + noDiffModifier->keywordName = getSession()->getNameObj("no_diff"); + addModifier(synParamDecl, noDiffModifier); } - - // Create an expression that references the parameter for use in arguments. - auto synArg = m_astBuilder->create(); - synArg->declRef = makeDeclRef(synParamDecl); - synArg->type = paramType; - - if (auto typePack = as(paramType)) + else if ( + as(modifier) || as(modifier) || + as(modifier) || as(modifier)) { - // If paramType is a concrete type pack, we want to expand it out into - // individual arguments. - for (Index i = 0; i < typePack->getTypeCount(); i++) - { - auto elementType = typePack->getElementType(i); - auto synMemberExpr = m_astBuilder->create(); - synMemberExpr->base = synArg; - synMemberExpr->elementIndices.add((UInt)i); - synMemberExpr->type = elementType; - synArgs.add(synMemberExpr); - } + auto clonedModifier = + (Modifier*)m_astBuilder->createByNodeType(modifier->astNodeType); + clonedModifier->keywordName = modifier->keywordName; + addModifier(synParamDecl, clonedModifier); } - else + } + + // Create an expression that references the parameter for use in arguments. + auto synArg = m_astBuilder->create(); + synArg->declRef = makeDeclRef(synParamDecl); + synArg->type = paramType; + + if (auto typePack = as(paramType)) + { + // If paramType is a concrete type pack, we want to expand it out into + // individual arguments. + for (Index i = 0; i < typePack->getTypeCount(); i++) { - // For ordinary non-pack paramters, we will use synArg directly to - // referencing the parameter for the call in the function body. - // - synArgs.add(synArg); + auto elementType = typePack->getElementType(i); + auto synMemberExpr = m_astBuilder->create(); + synMemberExpr->base = synArg; + synMemberExpr->elementIndices.add((UInt)i); + synMemberExpr->type = elementType; + synArgs.add(synMemberExpr); } } + else + { + // For ordinary non-pack paramters, we will use synArg directly to + // referencing the parameter for the call in the function body. + // + synArgs.add(synArg); + } } +} - CallableDecl* SemanticsVisitor::synthesizeMethodSignatureForRequirementWitness( - ConformanceCheckingContext* context, - DeclRef requiredMemberDeclRef, - List& synArgs, - ThisExpr*& synThis) +CallableDecl* SemanticsVisitor::synthesizeMethodSignatureForRequirementWitness( + ConformanceCheckingContext* context, + DeclRef requiredMemberDeclRef, + List& synArgs, + ThisExpr*& synThis) +{ + if (auto genericDeclRef = as(requiredMemberDeclRef.getParent())) { - if (auto genericDeclRef = as(requiredMemberDeclRef.getParent())) - { - List synGenericArgs; - auto genericDecl = synthesizeGenericSignatureForRequirementWitness( - context, - genericDeclRef, - synArgs, - synGenericArgs, - synThis); - return (CallableDecl*)genericDecl->inner; - } - return synthesizeMethodSignatureForRequirementWitnessInner( + List synGenericArgs; + auto genericDecl = synthesizeGenericSignatureForRequirementWitness( context, - requiredMemberDeclRef, + genericDeclRef, synArgs, + synGenericArgs, synThis); + return (CallableDecl*)genericDecl->inner; } + return synthesizeMethodSignatureForRequirementWitnessInner( + context, + requiredMemberDeclRef, + synArgs, + synThis); +} - CallableDecl* SemanticsVisitor::synthesizeMethodSignatureForRequirementWitnessInner( - ConformanceCheckingContext* context, - DeclRef requiredMemberDeclRef, - List& synArgs, - ThisExpr*& synThis) - { - CallableDecl* synFuncDecl = as(m_astBuilder->createByNodeType(requiredMemberDeclRef.getDecl()->astNodeType)); - SLANG_ASSERT(synFuncDecl); - - synFuncDecl->ownedScope = m_astBuilder->create(); - synFuncDecl->ownedScope->containerDecl = synFuncDecl; - synFuncDecl->ownedScope->parent = getScope(context->parentDecl); - synFuncDecl->parentDecl = context->parentDecl; +CallableDecl* SemanticsVisitor::synthesizeMethodSignatureForRequirementWitnessInner( + ConformanceCheckingContext* context, + DeclRef requiredMemberDeclRef, + List& synArgs, + ThisExpr*& synThis) +{ + CallableDecl* synFuncDecl = as( + m_astBuilder->createByNodeType(requiredMemberDeclRef.getDecl()->astNodeType)); + SLANG_ASSERT(synFuncDecl); + + synFuncDecl->ownedScope = m_astBuilder->create(); + synFuncDecl->ownedScope->containerDecl = synFuncDecl; + synFuncDecl->ownedScope->parent = getScope(context->parentDecl); + synFuncDecl->parentDecl = context->parentDecl; + + // For now our synthesized method will use the name and source + // location of the requirement we are trying to satisfy. + // + // TODO: as it stands right now our syntesized method will + // get a mangled name, which we don't actually want. Leaving + // out the name here doesn't help matters, because then *all* + // snthesized methods on a given type would share the same + // mangled name! + // + synFuncDecl->nameAndLoc = requiredMemberDeclRef.getDecl()->nameAndLoc; + if (synFuncDecl->nameAndLoc.name) + { + synFuncDecl->nameAndLoc.name = + getSession()->getNameObj("$__syn_" + synFuncDecl->nameAndLoc.name->text); + } + + // The result type of our synthesized method will be the expected + // result type from the interface requirement. + // + // TODO: This logic can/will run into problems if the return type + // is an associated type. + // + // The ideal solution is that we should be solving for interface + // conformance in two phases: a first phase to solve for how + // associated types are satisfied, and then a second phase to solve + // for how other requirements are satisfied (where we can substitute + // in the associated type witnesses for the abstract associated + // types as part of `requiredMemberDeclRef`). + // + // TODO: We should also double-check that this logic will work + // with a method that returns `This`. + // + auto resultType = getResultType(m_astBuilder, requiredMemberDeclRef); + synFuncDecl->returnType.type = resultType; + + addRequiredParamsToSynthesizedDecl(requiredMemberDeclRef, synFuncDecl, synArgs); + addModifiersToSynthesizedDecl(context, requiredMemberDeclRef, synFuncDecl, synThis); + + return synFuncDecl; +} - // For now our synthesized method will use the name and source - // location of the requirement we are trying to satisfy. - // - // TODO: as it stands right now our syntesized method will - // get a mangled name, which we don't actually want. Leaving - // out the name here doesn't help matters, because then *all* - // snthesized methods on a given type would share the same - // mangled name! - // - synFuncDecl->nameAndLoc = requiredMemberDeclRef.getDecl()->nameAndLoc; - if (synFuncDecl->nameAndLoc.name) +void SemanticsVisitor::_addMethodWitness( + WitnessTable* witnessTable, + DeclRef requiredMemberDeclRef, + DeclRef satisfyingMemberDeclRef) +{ + for (auto reqRefDecl : + requiredMemberDeclRef.getDecl()->getMembersOfType()) + { + if (auto fwdReq = as(reqRefDecl->referencedDecl)) { - synFuncDecl->nameAndLoc.name = getSession()->getNameObj("$__syn_" + synFuncDecl->nameAndLoc.name->text); + ForwardDifferentiateVal* val = + m_astBuilder->getOrCreate(satisfyingMemberDeclRef); + witnessTable->add(fwdReq, RequirementWitness(val)); + } + else if (auto bwdReq = as(reqRefDecl->referencedDecl)) + { + DifferentiateVal* val = + m_astBuilder->getOrCreate(satisfyingMemberDeclRef); + witnessTable->add(bwdReq, RequirementWitness(val)); } - - // The result type of our synthesized method will be the expected - // result type from the interface requirement. - // - // TODO: This logic can/will run into problems if the return type - // is an associated type. - // - // The ideal solution is that we should be solving for interface - // conformance in two phases: a first phase to solve for how - // associated types are satisfied, and then a second phase to solve - // for how other requirements are satisfied (where we can substitute - // in the associated type witnesses for the abstract associated - // types as part of `requiredMemberDeclRef`). - // - // TODO: We should also double-check that this logic will work - // with a method that returns `This`. - // - auto resultType = getResultType(m_astBuilder, requiredMemberDeclRef); - synFuncDecl->returnType.type = resultType; - - addRequiredParamsToSynthesizedDecl(requiredMemberDeclRef, synFuncDecl, synArgs); - addModifiersToSynthesizedDecl(context, requiredMemberDeclRef, synFuncDecl, synThis); - - return synFuncDecl; } + witnessTable->add(requiredMemberDeclRef.getDecl(), RequirementWitness(satisfyingMemberDeclRef)); +} - void SemanticsVisitor::_addMethodWitness( - WitnessTable* witnessTable, - DeclRef requiredMemberDeclRef, - DeclRef satisfyingMemberDeclRef) +static bool isWrapperTypeDecl(Decl* decl) +{ + if (auto aggTypeDecl = as(decl)) { - for (auto reqRefDecl : requiredMemberDeclRef.getDecl()->getMembersOfType()) - { - if (auto fwdReq = as(reqRefDecl->referencedDecl)) - { - ForwardDifferentiateVal* val = m_astBuilder->getOrCreate(satisfyingMemberDeclRef); - witnessTable->add(fwdReq, RequirementWitness(val)); - } - else if (auto bwdReq = as(reqRefDecl->referencedDecl)) - { - DifferentiateVal* val = m_astBuilder->getOrCreate(satisfyingMemberDeclRef); - witnessTable->add(bwdReq, RequirementWitness(val)); - } - } - witnessTable->add(requiredMemberDeclRef.getDecl(), RequirementWitness(satisfyingMemberDeclRef)); + if (aggTypeDecl->wrappedType) + return true; } + return false; +} - static bool isWrapperTypeDecl(Decl* decl) +bool SemanticsVisitor::trySynthesizeMethodRequirementWitness( + ConformanceCheckingContext* context, + LookupResult const& lookupResult, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable) +{ + // The situation here is that the context of an inheritance + // declaration didn't provide an exact match for a required + // method. E.g.: + // + // interface ICounter { [mutating] int increment(); } + // struct MyCounter : ICounter + // { + // [mutating] int increment(int val = 1) { ... } + // } + // + // It is clear in this case that the `MyCounter` type *can* + // satisfy the signature required by `ICounter`, but it has + // no explicit method declaration that is a perfect match. + // + // The approach in this function will be to construct a + // synthesized method along the lines of: + // + // struct MyCounter ... + // { + // ... + // [murtating] int synthesized() + // { + // return this.increment(); + // } + // } + // + // That is, we construct a method with the exact signature + // of the requirement (same parameter and result types), + // and then provide it with a body that simple `return`s + // the result of applying the desired requirement name + // (`increment` in this case) to those parameters. + // + // If the synthesized method type-checks, then we can say + // that the type must satisfy the requirement structurally, + // even if there isn't an exact signature match. More + // importantly, the method we just synthesized can be + // used as a witness to the fact that the requirement is + // satisfied. + + // With the big picture spelled out, we can settle into + // the work of constructing our synthesized method. + // + + bool isInWrapperType = isWrapperTypeDecl(context->parentDecl); + + // First, we check that the differentiabliity of the method matches the requirement, + // and we don't attempt to synthesize a method if they don't match. + if (!isInWrapperType && getShared()->getFuncDifferentiableLevel( + as(lookupResult.item.declRef.getDecl())) < + getShared()->getFuncDifferentiableLevel( + as(requiredMemberDeclRef.getDecl()))) { - if (auto aggTypeDecl = as(decl)) - { - if (aggTypeDecl->wrappedType) - return true; - } return false; } - bool SemanticsVisitor::trySynthesizeMethodRequirementWitness( - ConformanceCheckingContext* context, - LookupResult const& lookupResult, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable) - { - // The situation here is that the context of an inheritance - // declaration didn't provide an exact match for a required - // method. E.g.: - // - // interface ICounter { [mutating] int increment(); } - // struct MyCounter : ICounter - // { - // [mutating] int increment(int val = 1) { ... } - // } - // - // It is clear in this case that the `MyCounter` type *can* - // satisfy the signature required by `ICounter`, but it has - // no explicit method declaration that is a perfect match. - // - // The approach in this function will be to construct a - // synthesized method along the lines of: - // - // struct MyCounter ... - // { - // ... - // [murtating] int synthesized() - // { - // return this.increment(); - // } - // } - // - // That is, we construct a method with the exact signature - // of the requirement (same parameter and result types), - // and then provide it with a body that simple `return`s - // the result of applying the desired requirement name - // (`increment` in this case) to those parameters. - // - // If the synthesized method type-checks, then we can say - // that the type must satisfy the requirement structurally, - // even if there isn't an exact signature match. More - // importantly, the method we just synthesized can be - // used as a witness to the fact that the requirement is - // satisfied. - - // With the big picture spelled out, we can settle into - // the work of constructing our synthesized method. - // - - bool isInWrapperType = isWrapperTypeDecl(context->parentDecl); - - // First, we check that the differentiabliity of the method matches the requirement, - // and we don't attempt to synthesize a method if they don't match. - if (!isInWrapperType && - getShared()->getFuncDifferentiableLevel( - as(lookupResult.item.declRef.getDecl())) - < getShared()->getFuncDifferentiableLevel( - as(requiredMemberDeclRef.getDecl()))) - { - return false; - } + ThisExpr* synThis = nullptr; + List synArgs; + auto synFuncDecl = as(synthesizeMethodSignatureForRequirementWitness( + context, + requiredMemberDeclRef, + synArgs, + synThis)); - ThisExpr* synThis = nullptr; - List synArgs; - auto synFuncDecl = as(synthesizeMethodSignatureForRequirementWitness( - context, requiredMemberDeclRef, synArgs, synThis)); + auto resultType = synFuncDecl->returnType.type; - auto resultType = synFuncDecl->returnType.type; + // The body of our synthesized method is going to try to + // make a call using the name of the method requirement (e.g., + // the name `increment` in our example at the top of this function). + // + // The caller already passed in a `LookupResult` that represents + // an attempt to look up the given name in the type of `this`, + // and we really just need to wrap that result up as an overloaded + // expression. + // + auto baseOverloadedExpr = m_astBuilder->create(); + baseOverloadedExpr->name = requiredMemberDeclRef.getDecl()->getName(); - // The body of our synthesized method is going to try to - // make a call using the name of the method requirement (e.g., - // the name `increment` in our example at the top of this function). - // - // The caller already passed in a `LookupResult` that represents - // an attempt to look up the given name in the type of `this`, - // and we really just need to wrap that result up as an overloaded - // expression. - // - auto baseOverloadedExpr = m_astBuilder->create(); - baseOverloadedExpr->name = requiredMemberDeclRef.getDecl()->getName(); + if (isInWrapperType) + { + auto aggTypeDecl = as(context->parentDecl); + baseOverloadedExpr->lookupResult2 = lookUpMember( + m_astBuilder, + this, + baseOverloadedExpr->name, + aggTypeDecl->wrappedType.type, + aggTypeDecl->ownedScope, + LookupMask::Default, + LookupOptions::IgnoreBaseInterfaces); + addModifier(synFuncDecl, m_astBuilder->create()); + + synFuncDecl->parentDecl = aggTypeDecl; + } + else + { + baseOverloadedExpr->lookupResult2 = lookupResult; + } + // If `synThis` is non-null, then we will use it as the base of + // the overloaded expression, so that we have an overloaded + // member reference, and not just an overloaded reference to some + // static definitions. + // + if (synThis) + { if (isInWrapperType) { - auto aggTypeDecl = as(context->parentDecl); - baseOverloadedExpr->lookupResult2 = lookUpMember( + // If this is a wrapper type, then use the inner + // object as the actual this parameter for the redirected + // call. + auto innerExpr = m_astBuilder->create(); + innerExpr->scope = synThis->scope; + innerExpr->name = getName("inner"); + baseOverloadedExpr->base = CheckExpr(innerExpr); + SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(synFuncDecl)); + bodyVisitor.maybeRegisterDifferentiableType( m_astBuilder, - this, - baseOverloadedExpr->name, - aggTypeDecl->wrappedType.type, - aggTypeDecl->ownedScope, - LookupMask::Default, - LookupOptions::IgnoreBaseInterfaces); - addModifier(synFuncDecl, m_astBuilder->create()); - - synFuncDecl->parentDecl = aggTypeDecl; + baseOverloadedExpr->base->type); } else { - baseOverloadedExpr->lookupResult2 = lookupResult; + baseOverloadedExpr->base = synThis; } + } - // If `synThis` is non-null, then we will use it as the base of - // the overloaded expression, so that we have an overloaded - // member reference, and not just an overloaded reference to some - // static definitions. - // - if (synThis) + + // In order to know if our call is well-formed, we need to run + // the semantic checking logic for overload resolution. If it + // runs into an error, we don't want that being reported back + // to the user as some kind of overload-resolution failure. + // + // In order to protect the user from whatever errors might + // occur, we will perform the checking in the context of + // a temporary diagnostic sink. + // + DiagnosticSink tempSink(getSourceManager(), nullptr); + ExprLocalScope localScope; + SemanticsVisitor subVisitor( + withSink(&tempSink).withParentFunc(synFuncDecl).withExprLocalScope(&localScope)); + + Expr* synBase = baseOverloadedExpr; + + // If the requirement is a generic decl, fill in all generic arguments explicitly. + if (auto genericDeclRef = as(synFuncDecl->parentDecl)) + { + auto genericAppExpr = m_astBuilder->create(); + genericAppExpr->functionExpr = synBase; + for (auto member : genericDeclRef->members) { - if (isInWrapperType) + if (auto typeParamDecl = as(member)) { - // If this is a wrapper type, then use the inner - // object as the actual this parameter for the redirected - // call. - auto innerExpr = m_astBuilder->create(); - innerExpr->scope = synThis->scope; - innerExpr->name = getName("inner"); - baseOverloadedExpr->base = CheckExpr(innerExpr); - SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(synFuncDecl)); - bodyVisitor.maybeRegisterDifferentiableType(m_astBuilder, baseOverloadedExpr->base->type); + auto synTypeParamDeclRef = makeDeclRef(typeParamDecl); + auto synTypeParamDeclRefExpr = m_astBuilder->create(); + synTypeParamDeclRefExpr->declRef = synTypeParamDeclRef; + synTypeParamDeclRefExpr->type = + getTypeForDeclRef(m_astBuilder, synTypeParamDeclRef, SourceLoc()); + genericAppExpr->arguments.add(synTypeParamDeclRefExpr); } - else + else if (auto valParamDecl = as(member)) { - baseOverloadedExpr->base = synThis; + auto synValParamDeclRef = makeDeclRef(valParamDecl); + auto synValParamDeclRefExpr = m_astBuilder->create(); + synValParamDeclRefExpr->declRef = synValParamDeclRef; + synValParamDeclRefExpr->type = getType(m_astBuilder, synValParamDeclRef); + genericAppExpr->arguments.add(synValParamDeclRefExpr); } } + synBase = subVisitor.checkGenericAppWithCheckedArgs(genericAppExpr); - - // In order to know if our call is well-formed, we need to run - // the semantic checking logic for overload resolution. If it - // runs into an error, we don't want that being reported back - // to the user as some kind of overload-resolution failure. - // - // In order to protect the user from whatever errors might - // occur, we will perform the checking in the context of - // a temporary diagnostic sink. + // If checking the generic app failed, we can't synthesize the witness. // - DiagnosticSink tempSink(getSourceManager(), nullptr); - ExprLocalScope localScope; - SemanticsVisitor subVisitor(withSink(&tempSink).withParentFunc(synFuncDecl).withExprLocalScope(&localScope)); + if (tempSink.getErrorCount() != 0) + return false; + } + + // We now have the reference to the overload group we plan to call, + // and we already built up the argument list, so we can construct + // an `InvokeExpr` that represents the call we want to make. + // + auto synCall = m_astBuilder->create(); + synCall->functionExpr = synBase; + synCall->arguments = synArgs; + + // With our temporary diagnostic sink soaking up any messages + // from overload resolution, we can now try to resolve + // the call to see what happens. + // + auto checkedCall = subVisitor.ResolveInvoke(synCall); + + // Of course, it is possible that the call went through fine, + // but the result isn't of the type we expect/require, + // so we also need to coerce the result of the call to + // the expected type. + // + auto coercedCall = subVisitor.coerce(CoercionSite::Return, resultType, checkedCall); + + // If our overload resolution or type coercion failed, + // then we have not been able to synthesize a witness + // for the requirement. + // + // TODO: We might want to detect *why* overload resolution + // or type coercion failed, and report errors accordingly. + // + // More detailed diagnostics could help users understand + // what they did wrong, e.g.: + // + // * "We tried to use `foo(int)` but the interface requires `foo(String)` + // + // * "You have two methods that can apply as `bar()` and we couldn't tell which one you meant + // + // For now we just bail out here and rely on the caller to + // diagnose a generic "failed to satisfying requirement" error. + // + if (tempSink.getErrorCount() != 0) + return false; + + // If we were able to type-check the call, then we should + // be able to finish construction of a suitable witness. + // + // We've already created the outer declaration (including its + // parameters), and the inner expression, so the main work + // that is left is defining the body of the new function, + // which comprises a single `return` statement. + // + auto synReturn = m_astBuilder->create(); + synReturn->expression = coercedCall; + + synFuncDecl->body = synReturn; + + // Note: we set the parent of the synthesized declaration + // to the parent of the inheritance declaration being + // validated (which is either a type declaration or + // an `extension`), but we do *not* add the syntehsized + // declaration to the list of child declarations at + // this point. + // + // The synthesized decl already has its parent set to + // the current parent decl, so we don't need more actions + // to wire it up to the AST hierarchy. + // + // By leaving the synthesized declaration off of the list + // of members, we ensure that it doesn't get found + // by lookup (e.g., in a module that `import`s this type). + // Unfortunately, we may also break invariants in other parts + // of the code if they assume that all declarations have + // to appear in the parent/child hierarchy of the module. + // + // TODO: We may need to properly wire the synthesized + // declaration into the hierarchy, but then attach a modifier + // to it to indicate that it should be ignored by things like lookup. + // + + // If the synthesized func is differentiable, make sure to populate its + // differential type dictionary. + SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(synFuncDecl)); + bodyVisitor.registerDifferentiableTypesForFunc(synFuncDecl); + + // Once our synthesized declaration is complete, we need + // to install it as the witness that satifies the given + // requirement. + // + // Subsequent code generation should not be able to tell the + // difference between our synthetic method and a hand-written + // one with the same behavior. + // + auto containerDecl = getParentDecl(synFuncDecl); + auto containerDeclRef = getDefaultDeclRef(containerDecl); + auto synDeclRef = m_astBuilder->getMemberDeclRef(containerDeclRef, synFuncDecl); + _addMethodWitness(witnessTable, requiredMemberDeclRef, synDeclRef); + return true; +} + +bool SemanticsVisitor::trySynthesizeConstructorRequirementWitness( + ConformanceCheckingContext* context, + LookupResult const& satisfyingMemberLookupResult, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable) +{ + SLANG_UNUSED(satisfyingMemberLookupResult); + + if (as(context->parentDecl)) + { + if (auto builtinRequirement = + requiredMemberDeclRef.getDecl()->findModifier()) + { + return trySynthesizeEnumTypeMethodRequirementWitness( + context, + requiredMemberDeclRef, + witnessTable, + builtinRequirement->kind); + } + } + + bool isDefaultInitializableType = requiredMemberDeclRef.getParent() == + getASTBuilder()->getDefaultInitializableTypeInterfaceDecl(); + bool isInWrapperType = isWrapperTypeDecl(context->parentDecl); + if (!isInWrapperType && !isDefaultInitializableType && !satisfyingMemberLookupResult.isValid()) + { + return false; + } + + List synArgs; + ThisExpr* synThis = nullptr; + + auto ctorDecl = (ConstructorDecl*)synthesizeMethodSignatureForRequirementWitness( + context, + requiredMemberDeclRef, + synArgs, + synThis); + ctorDecl->loc = context->parentDecl->loc; + ctorDecl->closingSourceLoc = ctorDecl->loc; + auto ctorName = getName("$init"); + ctorDecl->nameAndLoc.name = ctorName; + ctorDecl->nameAndLoc.loc = context->parentDecl->loc; - Expr* synBase = baseOverloadedExpr; + auto seqStmt = m_astBuilder->create(); + ctorDecl->body = seqStmt; - // If the requirement is a generic decl, fill in all generic arguments explicitly. - if (auto genericDeclRef = as(synFuncDecl->parentDecl)) + + if (isInWrapperType) + { + SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(ctorDecl)); + bodyVisitor.maybeRegisterDifferentiableType(m_astBuilder, context->conformingType); + + for (auto member : context->parentDecl->members) { - auto genericAppExpr = m_astBuilder->create(); - genericAppExpr->functionExpr = synBase; - for (auto member : genericDeclRef->members) + if (auto varDecl = as(member)) { - if (auto typeParamDecl = as(member)) - { - auto synTypeParamDeclRef = makeDeclRef(typeParamDecl); - auto synTypeParamDeclRefExpr = m_astBuilder->create(); - synTypeParamDeclRefExpr->declRef = synTypeParamDeclRef; - synTypeParamDeclRefExpr->type = getTypeForDeclRef(m_astBuilder, synTypeParamDeclRef, SourceLoc()); - genericAppExpr->arguments.add(synTypeParamDeclRefExpr); - } - else if (auto valParamDecl = as(member)) - { - auto synValParamDeclRef = makeDeclRef(valParamDecl); - auto synValParamDeclRefExpr = m_astBuilder->create(); - synValParamDeclRefExpr->declRef = synValParamDeclRef; - synValParamDeclRefExpr->type = getType(m_astBuilder, synValParamDeclRef); - genericAppExpr->arguments.add(synValParamDeclRefExpr); - } - } - synBase = subVisitor.checkGenericAppWithCheckedArgs(genericAppExpr); + auto varExpr = m_astBuilder->create(); + varExpr->scope = ctorDecl->ownedScope; + varExpr->name = varDecl->getName(); + auto checkedVarExpr = CheckTerm(varExpr); + if (!checkedVarExpr) + return false; + if (as(checkedVarExpr->type.type)) + return false; + auto assign = m_astBuilder->create(); + assign->left = checkedVarExpr; + auto temp = m_astBuilder->create(); + auto lookupResult = lookUpMember( + m_astBuilder, + this, + ctorName, + varDecl->type.type, + ctorDecl->ownedScope, + LookupMask::Function, + LookupOptions::IgnoreBaseInterfaces); + temp->functionExpr = createLookupResultExpr( + ctorName, + lookupResult, + nullptr, + context->parentDecl->loc, + nullptr); + temp->arguments.addRange(synArgs); + auto resolvedVar = ResolveInvoke(temp); + if (!resolvedVar) + return false; + assign->right = resolvedVar; + assign->type = m_astBuilder->getVoidType(); + bodyVisitor.maybeRegisterDifferentiableType(m_astBuilder, varDecl->type.type); - // If checking the generic app failed, we can't synthesize the witness. - // - if (tempSink.getErrorCount() != 0) - return false; + auto stmt = m_astBuilder->create(); + stmt->expression = assign; + seqStmt->stmts.add(stmt); + break; + } } + } + else if (synArgs.getCount()) + { + // The body of our synthesized method is going to try to + // make a ctor call with the specified arguments (e.g., + // the name `increment` in our example at the top of this function). + // + auto synBase = m_astBuilder->create(); + synBase->name = requiredMemberDeclRef.getDecl()->getName(); + + synBase->lookupResult2 = satisfyingMemberLookupResult; // We now have the reference to the overload group we plan to call, // and we already built up the argument list, so we can construct @@ -4591,365 +4853,521 @@ namespace Slang synCall->functionExpr = synBase; synCall->arguments = synArgs; + // In order to know if our call is well-formed, we need to run + // the semantic checking logic for overload resolution. If it + // runs into an error, we don't want that being reported back + // to the user as some kind of overload-resolution failure. + // + // In order to protect the user from whatever errors might + // occur, we will perform the checking in the context of + // a temporary diagnostic sink. + // + DiagnosticSink tempSink(getSourceManager(), nullptr); + ExprLocalScope localScope; + SemanticsVisitor subVisitor( + withSink(&tempSink).withParentFunc(ctorDecl).withExprLocalScope(&localScope)); + // With our temporary diagnostic sink soaking up any messages // from overload resolution, we can now try to resolve // the call to see what happens. // auto checkedCall = subVisitor.ResolveInvoke(synCall); - // Of course, it is possible that the call went through fine, - // but the result isn't of the type we expect/require, - // so we also need to coerce the result of the call to - // the expected type. - // - auto coercedCall = subVisitor.coerce(CoercionSite::Return, resultType, checkedCall); - - // If our overload resolution or type coercion failed, - // then we have not been able to synthesize a witness - // for the requirement. - // - // TODO: We might want to detect *why* overload resolution - // or type coercion failed, and report errors accordingly. - // - // More detailed diagnostics could help users understand - // what they did wrong, e.g.: - // - // * "We tried to use `foo(int)` but the interface requires `foo(String)` - // - // * "You have two methods that can apply as `bar()` and we couldn't tell which one you meant - // - // For now we just bail out here and rely on the caller to - // diagnose a generic "failed to satisfying requirement" error. - // - if(tempSink.getErrorCount() != 0) + // If any error occurs during overload resolution, we can't synthesize the witness. + if (tempSink.getErrorCount() != 0) return false; // If we were able to type-check the call, then we should - // be able to finish construction of a suitable witness. - // - // We've already created the outer declaration (including its - // parameters), and the inner expression, so the main work - // that is left is defining the body of the new function, - // which comprises a single `return` statement. - // - auto synReturn = m_astBuilder->create(); - synReturn->expression = coercedCall; - - synFuncDecl->body = synReturn; - - // Note: we set the parent of the synthesized declaration - // to the parent of the inheritance declaration being - // validated (which is either a type declaration or - // an `extension`), but we do *not* add the syntehsized - // declaration to the list of child declarations at - // this point. - // - // The synthesized decl already has its parent set to - // the current parent decl, so we don't need more actions - // to wire it up to the AST hierarchy. - // - // By leaving the synthesized declaration off of the list - // of members, we ensure that it doesn't get found - // by lookup (e.g., in a module that `import`s this type). - // Unfortunately, we may also break invariants in other parts - // of the code if they assume that all declarations have - // to appear in the parent/child hierarchy of the module. - // - // TODO: We may need to properly wire the synthesized - // declaration into the hierarchy, but then attach a modifier - // to it to indicate that it should be ignored by things like lookup. - // - - // If the synthesized func is differentiable, make sure to populate its - // differential type dictionary. - SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(synFuncDecl)); - bodyVisitor.registerDifferentiableTypesForFunc(synFuncDecl); - - // Once our synthesized declaration is complete, we need - // to install it as the witness that satifies the given - // requirement. - // - // Subsequent code generation should not be able to tell the - // difference between our synthetic method and a hand-written - // one with the same behavior. + // be able to finish construction of a suitable ctor witness, + // by emitting `this = resolvedCtorCall()`. // - auto containerDecl = getParentDecl(synFuncDecl); - auto containerDeclRef = getDefaultDeclRef(containerDecl); - auto synDeclRef = m_astBuilder->getMemberDeclRef(containerDeclRef, synFuncDecl); - _addMethodWitness(witnessTable, requiredMemberDeclRef, synDeclRef); - return true; + AssignExpr* assignExpr = m_astBuilder->create(); + assignExpr->left = synThis; + assignExpr->right = checkedCall; + assignExpr->type = m_astBuilder->getVoidType(); + ExpressionStmt* exprStmt = m_astBuilder->create(); + exprStmt->expression = assignExpr; + seqStmt->stmts.add(exprStmt); } - bool SemanticsVisitor::trySynthesizeConstructorRequirementWitness( - ConformanceCheckingContext* context, - LookupResult const& satisfyingMemberLookupResult, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable) - { - SLANG_UNUSED(satisfyingMemberLookupResult); - - if (as(context->parentDecl)) - { - if (auto builtinRequirement = requiredMemberDeclRef.getDecl()->findModifier()) - { - return trySynthesizeEnumTypeMethodRequirementWitness(context, requiredMemberDeclRef, witnessTable, builtinRequirement->kind); - } - } + if (isDefaultInitializableType) + context->parentDecl->addMember(ctorDecl); - bool isDefaultInitializableType = requiredMemberDeclRef.getParent() == getASTBuilder()->getDefaultInitializableTypeInterfaceDecl(); - bool isInWrapperType = isWrapperTypeDecl(context->parentDecl); - if (!isInWrapperType && !isDefaultInitializableType && !satisfyingMemberLookupResult.isValid()) - { - return false; - } + auto containerDecl = getParentDecl(ctorDecl); + auto containerDeclRef = getDefaultDeclRef(containerDecl); + auto synDeclRef = m_astBuilder->getMemberDeclRef(containerDeclRef, ctorDecl); + _addMethodWitness(witnessTable, requiredMemberDeclRef, synDeclRef); - List synArgs; - ThisExpr* synThis = nullptr; + return true; +} - auto ctorDecl = (ConstructorDecl*)synthesizeMethodSignatureForRequirementWitness( +bool SemanticsVisitor::trySynthesizePropertyRequirementWitness( + ConformanceCheckingContext* context, + LookupResult const& lookupResult, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable) +{ + if (isWrapperTypeDecl(context->parentDecl)) + return trySynthesizeWrapperTypePropertyRequirementWitness( context, requiredMemberDeclRef, - synArgs, - synThis); - ctorDecl->loc = context->parentDecl->loc; - ctorDecl->closingSourceLoc = ctorDecl->loc; - auto ctorName = getName("$init"); - ctorDecl->nameAndLoc.name = ctorName; - ctorDecl->nameAndLoc.loc = context->parentDecl->loc; - - auto seqStmt = m_astBuilder->create(); - ctorDecl->body = seqStmt; + witnessTable); + // The situation here is that the context of an inheritance + // declaration didn't provide an exact match for a required + // property. E.g.: + // + // interface ICell { property value : int { get; set; } } + // struct MyCell : ICell + // { + // int value; + // } + // + // It is clear in this case that the `MyCell` type *can* + // satisfy the signature required by `ICell`, but it has + // no explicit `property` declaration, and instead just + // a field with the right name and type. + // + // The approach in this function will be to construct a + // synthesized `preoperty` along the lines of: + // + // struct MyCounter ... + // { + // ... + // property value_synthesized : int + // { + // get { return this.value; } + // set(newValue) { this.value = newValue; } + // } + // } + // + // That is, we construct a `property` with the correct type + // and with an accessor for each requirement, where the accesors + // all try to read or write `this.value`. + // + // If those synthesized accessors all type-check, then we can + // say that the type must satisfy the requirement structurally, + // even if there isn't an exact signature match. More + // importantly, the `property` we just synthesized can be + // used as a witness to the fact that the requirement is + // satisfied. + // + // The big-picture flow of the logic here is similar to + // `trySynthesizeMethodRequirementWitness()` above, and we + // will not comment this code as exhaustively, under the + // assumption that readers of the code don't benefit from + // having the exact same information stated twice. + + // With the introduction out of the way, let's get started + // constructing a synthesized `PropertyDecl`. + // + auto synPropertyDecl = m_astBuilder->create(); + + // Synthesize the property name with a prefix to avoid name clashing. + synPropertyDecl->nameAndLoc = requiredMemberDeclRef.getDecl()->nameAndLoc; + synPropertyDecl->nameAndLoc.name = + getName(String("$syn_property_") + getText(requiredMemberDeclRef.getName())); + synPropertyDecl->parentDecl = context->parentDecl; + + + // The type of our synthesized property can be derived from the + // specialized declref to the requirement decl. + // + auto propertyType = getType(m_astBuilder, requiredMemberDeclRef); + synPropertyDecl->type.type = propertyType; + + + // We start by constructing an expression that represents + // `this.name` where `name` is the name of the required + // member. The caller already passed in a `lookupResult` + // that should indicate all the declarations found by + // looking up `name`, so we can start with that. + // + // TODO: Note that there are many cases for member lookup + // that are not handled just by using `createLookupResultExpr` + // because they are currently being special-cased (the most + // notable cases are swizzles, as well as lookup of static + // members in types). + // + // The main result here is that we will not be able to synthesize + // a requirement for a built-in scalar/vector/matrix type to + // a property with a name like `.xy` based on the presence of + // swizles, even though it seems like such a thing should Just Work. + // + // If this is important we could "fix" it by allowing this + // code to dispatch to the special-case logic used when doing + // semantic checking for member expressions. + // + // Note: an alternative would be to change the core module declarations + // of vectors/matrices so that all the swizzles are defined as + // `property` declarations. There are some C++ math libraries (like GLM) + // that implement swizzle syntax by a similar approach of statically + // enumerating all possible swizzles. The down-side to such an + // approach is that the combinatorial space of swizzles is quite + // large (especially for matrices) so that supporting them via + // general-purpose language features is unlikely to be as efficient + // as special-case logic. + // + // We are going to synthesize an expression and then perform + // semantic checking on it, but if there are semantic errors + // we do *not* want to report them to the user as such, and + // instead want the result to be a failure to synthesize + // a valid witness. + // + // We will buffer up diagnostics into a temporary sink and + // then throw them away when we are done. + // + // TODO: This behavior might be something we want to make + // into a more fundamental capability of `DiagnosticSink` and/or + // `SemanticsVisitor` so that code can push/pop the emission + // of diagnostics more easily. + // + DiagnosticSink tempSink(getSourceManager(), nullptr); + SemanticsVisitor subVisitor(withSink(&tempSink)); + + // We need to create a `this` expression to be used in the body + // of the synthesized accessor. + // + // TODO: if we ever allow `static` properties or subscripts, + // we will need to handle that case here, by *not* creating + // a `this` expression. + // + ThisExpr* synThis = m_astBuilder->create(); + synThis->scope = synPropertyDecl->ownedScope; + + // The type of `this` in our accessor will be the type for + // which we are synthesizing a conformance. + // + synThis->type.type = context->conformingType; + synThis->type.isLeftValue = true; + auto synMemberRef = subVisitor.createLookupResultExpr( + requiredMemberDeclRef.getName(), + lookupResult, + synThis, + requiredMemberDeclRef.getLoc(), + nullptr); + synMemberRef->loc = requiredMemberDeclRef.getLoc(); + + bool canSynAccessors = synthesizeAccessorRequirements( + context, + requiredMemberDeclRef, + propertyType, + synMemberRef, + synPropertyDecl, + witnessTable); + if (!canSynAccessors) + return false; - if (isInWrapperType) - { - SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(ctorDecl)); - bodyVisitor.maybeRegisterDifferentiableType(m_astBuilder, context->conformingType); - for (auto member : context->parentDecl->members) - { - if (auto varDecl = as(member)) - { - auto varExpr = m_astBuilder->create(); - varExpr->scope = ctorDecl->ownedScope; - varExpr->name = varDecl->getName(); - auto checkedVarExpr = CheckTerm(varExpr); - if (!checkedVarExpr) - return false; - if (as(checkedVarExpr->type.type)) - return false; - auto assign = m_astBuilder->create(); - assign->left = checkedVarExpr; - auto temp = m_astBuilder->create(); - auto lookupResult = lookUpMember( - m_astBuilder, - this, - ctorName, - varDecl->type.type, - ctorDecl->ownedScope, - LookupMask::Function, - LookupOptions::IgnoreBaseInterfaces); - temp->functionExpr = createLookupResultExpr(ctorName, lookupResult, nullptr, context->parentDecl->loc, nullptr); - temp->arguments.addRange(synArgs); - auto resolvedVar = ResolveInvoke(temp); - if (!resolvedVar) - return false; - assign->right = resolvedVar; - assign->type = m_astBuilder->getVoidType(); - bodyVisitor.maybeRegisterDifferentiableType(m_astBuilder, varDecl->type.type); - - auto stmt = m_astBuilder->create(); - stmt->expression = assign; - seqStmt->stmts.add(stmt); - break; - } - } - } - else if (synArgs.getCount()) - { - // The body of our synthesized method is going to try to - // make a ctor call with the specified arguments (e.g., - // the name `increment` in our example at the top of this function). - // - auto synBase = m_astBuilder->create(); - synBase->name = requiredMemberDeclRef.getDecl()->getName(); + // The visibility of synthesized decl should be the min of the parent decl and the requirement. + if (requiredMemberDeclRef.getDecl()->findModifier()) + { + auto requirementVisibility = getDeclVisibility(requiredMemberDeclRef.getDecl()); + auto thisVisibility = getDeclVisibility(context->parentDecl); + auto visibility = Math::Min(thisVisibility, requirementVisibility); + addVisibilityModifier(m_astBuilder, synPropertyDecl, visibility); + } + return true; +} - synBase->lookupResult2 = satisfyingMemberLookupResult; +bool SemanticsVisitor::trySynthesizeWrapperTypePropertyRequirementWitness( + ConformanceCheckingContext* context, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable) +{ + // We are synthesizing a property requirement for a wrapper type: + // + // interface IFoo { property value : int { get; set; } } + // struct Foo : IFoo = FooImpl; + // + // We need to synthesize Foo to: + // + // struct Foo : IFoo + // { + // FooImpl inner; + // property value : int { get { return inner.value; } + // set { inner.value = newValue; } + // } + // } + // + // To do so, we need to grab the witness table of FooImpl:IFoo, and create + // wrapper property in Foo that forwards the accessors to the inner object. + // + // We get started by constructing a synthesized `PropertyDecl`. + // + auto synPropertyDecl = m_astBuilder->create(); + synPropertyDecl->parentDecl = context->parentDecl; + + // Synthesize the property name with a prefix to avoid name clashing. + // + synPropertyDecl->nameAndLoc = requiredMemberDeclRef.getDecl()->nameAndLoc; + synPropertyDecl->nameAndLoc.name = + getName(String("$syn_property_") + getText(requiredMemberDeclRef.getName())); + + // Find the witness that FooImpl : IFoo. + auto aggTypeDecl = as(context->parentDecl); + auto innerType = aggTypeDecl->wrappedType.type; + DeclRef innerProperty; + auto innerWitness = tryGetSubtypeWitness(innerType, witnessTable->baseType); + if (!innerWitness) + return false; - // We now have the reference to the overload group we plan to call, - // and we already built up the argument list, so we can construct - // an `InvokeExpr` that represents the call we want to make. - // - auto synCall = m_astBuilder->create(); - synCall->functionExpr = synBase; - synCall->arguments = synArgs; - - // In order to know if our call is well-formed, we need to run - // the semantic checking logic for overload resolution. If it - // runs into an error, we don't want that being reported back - // to the user as some kind of overload-resolution failure. - // - // In order to protect the user from whatever errors might - // occur, we will perform the checking in the context of - // a temporary diagnostic sink. + for (auto requiredAccessorDeclRef : + getMembersOfType(m_astBuilder, requiredMemberDeclRef)) + { + auto innerEntry = tryLookUpRequirementWitness( + m_astBuilder, + innerWitness, + requiredAccessorDeclRef.getDecl()); + if (innerEntry.getFlavor() != RequirementWitness::Flavor::declRef) + return false; + auto innerAccessorDeclRef = as(innerEntry.getDeclRef()); + if (!innerAccessorDeclRef) + return false; + + // The synthesized accessor will be an AST node of the same class as + // the required accessor. + // + auto synAccessorDecl = (AccessorDecl*)m_astBuilder->createByNodeType( + requiredAccessorDeclRef.getDecl()->astNodeType); + synAccessorDecl->ownedScope = m_astBuilder->create(); + synAccessorDecl->ownedScope->containerDecl = synAccessorDecl; + synAccessorDecl->ownedScope->parent = getScope(context->parentDecl); + + // The return type should be the same as the inner object's accessor return type. + // + synAccessorDecl->returnType.type = getResultType(m_astBuilder, innerAccessorDeclRef); + + // Similarly, our synthesized accessor will have parameters matching those of the inner + // accessor. + // + List synArgs; + for (auto innerParamDeclRef : getParameters(m_astBuilder, innerAccessorDeclRef)) + { + auto paramType = getType(m_astBuilder, innerParamDeclRef); + + // The synthesized parameter will ahve the same name and + // type as the parameter of the requirement. + // + auto synParamDecl = m_astBuilder->create(); + synParamDecl->nameAndLoc = innerParamDeclRef.getDecl()->nameAndLoc; + synParamDecl->type.type = paramType; + + // We need to add the parameter as a child declaration of + // the accessor we are building. // - DiagnosticSink tempSink(getSourceManager(), nullptr); - ExprLocalScope localScope; - SemanticsVisitor subVisitor(withSink(&tempSink).withParentFunc(ctorDecl).withExprLocalScope(&localScope)); + synParamDecl->parentDecl = synAccessorDecl; + synAccessorDecl->members.add(synParamDecl); - // With our temporary diagnostic sink soaking up any messages - // from overload resolution, we can now try to resolve - // the call to see what happens. + // For each paramter, we will create an argument expression + // to represent it in the body of the accessor. // - auto checkedCall = subVisitor.ResolveInvoke(synCall); + auto synArg = m_astBuilder->create(); + synArg->declRef = makeDeclRef(synParamDecl); + synArg->type = paramType; + synArgs.add(synArg); + } - // If any error occurs during overload resolution, we can't synthesize the witness. - if (tempSink.getErrorCount() != 0) - return false; + // Now synthesize the body of the property accessor. + // The body of the accessor will depend on the class of the accessor + // we are synthesizing (e.g., `get` vs. `set`). + // + Stmt* synBodyStmt = nullptr; + auto propertyRef = m_astBuilder->create(); + propertyRef->scope = synAccessorDecl->ownedScope; + auto base = m_astBuilder->create(); + base->scope = propertyRef->scope; + base->name = getName("inner"); + propertyRef->baseExpression = base; + innerProperty = innerAccessorDeclRef.getParent(); + propertyRef->name = getParentDecl(innerAccessorDeclRef.getDecl())->getName(); + auto checkedPropertyRefExpr = CheckExpr(propertyRef); + + if (as(requiredAccessorDeclRef)) + { + auto synReturn = m_astBuilder->create(); + synReturn->expression = checkedPropertyRefExpr; + + synBodyStmt = synReturn; + } + else if (as(requiredAccessorDeclRef)) + { + auto synAssign = m_astBuilder->create(); + synAssign->left = checkedPropertyRefExpr; + synAssign->right = synArgs[0]; + + auto synCheckedAssign = checkAssignWithCheckedOperands(synAssign); - // If we were able to type-check the call, then we should - // be able to finish construction of a suitable ctor witness, - // by emitting `this = resolvedCtorCall()`. + auto synExprStmt = m_astBuilder->create(); + synExprStmt->expression = synCheckedAssign; + + synBodyStmt = synExprStmt; + } + else + { + // While there are other kinds of accessors than `get` and `set`, + // those are currently only reserved for the internal use in the core module. + // We will not bother with synthesis for those cases. // - AssignExpr* assignExpr = m_astBuilder->create(); - assignExpr->left = synThis; - assignExpr->right = checkedCall; - assignExpr->type = m_astBuilder->getVoidType(); - ExpressionStmt* exprStmt = m_astBuilder->create(); - exprStmt->expression = assignExpr; - seqStmt->stmts.add(exprStmt); + return false; } - if (isDefaultInitializableType) - context->parentDecl->addMember(ctorDecl); + addModifier(synAccessorDecl, m_astBuilder->create()); + synAccessorDecl->body = synBodyStmt; - auto containerDecl = getParentDecl(ctorDecl); - auto containerDeclRef = getDefaultDeclRef(containerDecl); - auto synDeclRef = m_astBuilder->getMemberDeclRef(containerDeclRef, ctorDecl); - _addMethodWitness(witnessTable, requiredMemberDeclRef, synDeclRef); + synAccessorDecl->parentDecl = synPropertyDecl; + synPropertyDecl->members.add(synAccessorDecl); - return true; + // Register the synthesized accessor. + // + witnessTable->add( + requiredAccessorDeclRef.getDecl(), + RequirementWitness(makeDeclRef(synAccessorDecl))); + } + + // The type of our synthesized property will be the same as the inner property. + // + auto propertyType = getType(m_astBuilder, as(innerProperty)); + synPropertyDecl->type.type = propertyType; + + // The visibility of synthesized decl should be the same as the inner requirement + if (innerProperty.getDecl()->findModifier()) + { + auto vis = getDeclVisibility(innerProperty.getDecl()); + addVisibilityModifier(m_astBuilder, synPropertyDecl, vis); + } + + context->parentDecl->addMember(synPropertyDecl); + witnessTable->add( + requiredMemberDeclRef.getDecl(), + RequirementWitness(makeDeclRef(synPropertyDecl))); + return true; +} + +bool SemanticsVisitor::trySynthesizeAssociatedTypeRequirementWitness( + ConformanceCheckingContext* context, + LookupResult const& inLookupResult, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable) +{ + SLANG_UNUSED(inLookupResult); + + // The only case we can synthesize for now is when the conformant type + // is a wrapper type. + if (!isWrapperTypeDecl(context->parentDecl)) + return false; + auto aggTypeDecl = as(context->parentDecl); + auto lookupResult = lookUpMember( + m_astBuilder, + this, + requiredMemberDeclRef.getName(), + aggTypeDecl->wrappedType.type, + aggTypeDecl->ownedScope, + LookupMask::Default, + LookupOptions::IgnoreBaseInterfaces); + if (!lookupResult.isValid() || lookupResult.isOverloaded()) + return false; + auto assocType = DeclRefType::create(m_astBuilder, lookupResult.item.declRef); + witnessTable->add(requiredMemberDeclRef.getDecl(), assocType); + for (auto typeConstraintDecl : + getMembersOfType(m_astBuilder, requiredMemberDeclRef)) + { + auto witness = tryGetSubtypeWitness(assocType, getSup(m_astBuilder, typeConstraintDecl)); + if (!witness) + return false; + witnessTable->add(typeConstraintDecl.getDecl(), witness); } + return true; +} + +bool SemanticsVisitor::trySynthesizeAssociatedConstantRequirementWitness( + ConformanceCheckingContext* context, + LookupResult const& inLookupResult, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable) +{ + SLANG_UNUSED(inLookupResult); + + // The only case we can synthesize for now is when the conformant type + // is a wrapper type, i.e. + // struct Foo:IFoo = FooImpl; + if (!isWrapperTypeDecl(context->parentDecl)) + return false; + + // Find the witness that FooImpl : IFoo. + auto aggTypeDecl = as(context->parentDecl); + auto innerType = aggTypeDecl->wrappedType.type; + DeclRef innerProperty; + auto innerWitness = tryGetSubtypeWitness(innerType, witnessTable->baseType); + if (!innerWitness) + return false; + + auto witness = + tryLookUpRequirementWitness(m_astBuilder, innerWitness, requiredMemberDeclRef.getDecl()); + if (witness.getFlavor() != RequirementWitness::Flavor::val) + return false; + witnessTable->add(requiredMemberDeclRef.getDecl(), witness.getVal()); + return true; +} - bool SemanticsVisitor::trySynthesizePropertyRequirementWitness( - ConformanceCheckingContext* context, - LookupResult const& lookupResult, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable) - { - if (isWrapperTypeDecl(context->parentDecl)) - return trySynthesizeWrapperTypePropertyRequirementWitness(context, requiredMemberDeclRef, witnessTable); - - // The situation here is that the context of an inheritance - // declaration didn't provide an exact match for a required - // property. E.g.: - // - // interface ICell { property value : int { get; set; } } - // struct MyCell : ICell - // { - // int value; - // } - // - // It is clear in this case that the `MyCell` type *can* - // satisfy the signature required by `ICell`, but it has - // no explicit `property` declaration, and instead just - // a field with the right name and type. - // - // The approach in this function will be to construct a - // synthesized `preoperty` along the lines of: - // - // struct MyCounter ... - // { - // ... - // property value_synthesized : int - // { - // get { return this.value; } - // set(newValue) { this.value = newValue; } - // } - // } - // - // That is, we construct a `property` with the correct type - // and with an accessor for each requirement, where the accesors - // all try to read or write `this.value`. - // - // If those synthesized accessors all type-check, then we can - // say that the type must satisfy the requirement structurally, - // even if there isn't an exact signature match. More - // importantly, the `property` we just synthesized can be - // used as a witness to the fact that the requirement is - // satisfied. - // - // The big-picture flow of the logic here is similar to - // `trySynthesizeMethodRequirementWitness()` above, and we - // will not comment this code as exhaustively, under the - // assumption that readers of the code don't benefit from - // having the exact same information stated twice. - - // With the introduction out of the way, let's get started - // constructing a synthesized `PropertyDecl`. - // - auto synPropertyDecl = m_astBuilder->create(); - - // Synthesize the property name with a prefix to avoid name clashing. - synPropertyDecl->nameAndLoc = requiredMemberDeclRef.getDecl()->nameAndLoc; - synPropertyDecl->nameAndLoc.name = getName(String("$syn_property_") + getText(requiredMemberDeclRef.getName())); - synPropertyDecl->parentDecl = context->parentDecl; - - - // The type of our synthesized property can be derived from the - // specialized declref to the requirement decl. - // - auto propertyType = getType(m_astBuilder, requiredMemberDeclRef); - synPropertyDecl->type.type = propertyType; - - - // We start by constructing an expression that represents - // `this.name` where `name` is the name of the required - // member. The caller already passed in a `lookupResult` - // that should indicate all the declarations found by - // looking up `name`, so we can start with that. - // - // TODO: Note that there are many cases for member lookup - // that are not handled just by using `createLookupResultExpr` - // because they are currently being special-cased (the most - // notable cases are swizzles, as well as lookup of static - // members in types). - // - // The main result here is that we will not be able to synthesize - // a requirement for a built-in scalar/vector/matrix type to - // a property with a name like `.xy` based on the presence of - // swizles, even though it seems like such a thing should Just Work. - // - // If this is important we could "fix" it by allowing this - // code to dispatch to the special-case logic used when doing - // semantic checking for member expressions. - // - // Note: an alternative would be to change the core module declarations - // of vectors/matrices so that all the swizzles are defined as - // `property` declarations. There are some C++ math libraries (like GLM) - // that implement swizzle syntax by a similar approach of statically - // enumerating all possible swizzles. The down-side to such an - // approach is that the combinatorial space of swizzles is quite - // large (especially for matrices) so that supporting them via - // general-purpose language features is unlikely to be as efficient - // as special-case logic. +bool SemanticsVisitor::synthesizeAccessorRequirements( + ConformanceCheckingContext* context, + DeclRef requiredMemberDeclRef, + Type* resultType, + Expr* synBoundStorageExpr, + ContainerDecl* synAccesorContainer, + RefPtr witnessTable) +{ + Dictionary, AccessorDecl*> mapRequiredAccessorToSynAccessor; + for (auto requiredAccessorDeclRef : + getMembersOfType(m_astBuilder, requiredMemberDeclRef)) + { + // The synthesized accessor will be an AST node of the same class as + // the required accessor. // - // We are going to synthesize an expression and then perform - // semantic checking on it, but if there are semantic errors - // we do *not* want to report them to the user as such, and - // instead want the result to be a failure to synthesize - // a valid witness. + auto synAccessorDecl = (AccessorDecl*)m_astBuilder->createByNodeType( + requiredAccessorDeclRef.getDecl()->astNodeType); + synAccessorDecl->ownedScope = m_astBuilder->create(); + synAccessorDecl->ownedScope->containerDecl = synAccessorDecl; + synAccessorDecl->ownedScope->parent = getScope(context->parentDecl); + + // Whatever the required accessor returns, that is what our synthesized accessor will + // return. // - // We will buffer up diagnostics into a temporary sink and - // then throw them away when we are done. + synAccessorDecl->returnType.type = resultType; + + // Similarly, our synthesized accessor will have parameters matching those of the + // requirement. // - // TODO: This behavior might be something we want to make - // into a more fundamental capability of `DiagnosticSink` and/or - // `SemanticsVisitor` so that code can push/pop the emission - // of diagnostics more easily. + // Note: in practice we expect that only `set` accessors will have any parameters, + // and they will only have a single parameter. // - DiagnosticSink tempSink(getSourceManager(), nullptr); - SemanticsVisitor subVisitor(withSink(&tempSink)); + List synArgs; + for (auto requiredParamDeclRef : getParameters(m_astBuilder, requiredAccessorDeclRef)) + { + auto paramType = getType(m_astBuilder, requiredParamDeclRef); + + // The synthesized parameter will ahve the same name and + // type as the parameter of the requirement. + // + auto synParamDecl = m_astBuilder->create(); + synParamDecl->nameAndLoc = requiredParamDeclRef.getDecl()->nameAndLoc; + synParamDecl->type.type = paramType; + + // We need to add the parameter as a child declaration of + // the accessor we are building. + // + synParamDecl->parentDecl = synAccessorDecl; + synAccessorDecl->members.add(synParamDecl); + + // For each paramter, we will create an argument expression + // to represent it in the body of the accessor. + // + auto synArg = m_astBuilder->create(); + synArg->declRef = makeDeclRef(synParamDecl); + synArg->type = paramType; + synArgs.add(synArg); + } // We need to create a `this` expression to be used in the body // of the synthesized accessor. @@ -4959,1084 +5377,747 @@ namespace Slang // a `this` expression. // ThisExpr* synThis = m_astBuilder->create(); - synThis->scope = synPropertyDecl->ownedScope; + synThis->scope = synAccessorDecl->ownedScope; // The type of `this` in our accessor will be the type for // which we are synthesizing a conformance. // synThis->type.type = context->conformingType; + + // A `get` accessor should default to an immutable `this`, + // while other accessors default to mutable `this`. + // + // TODO: If we ever add other kinds of accessors, we will + // need to check that this assumption stays valid. + // synThis->type.isLeftValue = true; - auto synMemberRef = subVisitor.createLookupResultExpr( - requiredMemberDeclRef.getName(), - lookupResult, - synThis, - requiredMemberDeclRef.getLoc(), - nullptr); - synMemberRef->loc = requiredMemberDeclRef.getLoc(); + if (as(requiredAccessorDeclRef)) + synThis->type.isLeftValue = false; - bool canSynAccessors = synthesizeAccessorRequirements( - context, - requiredMemberDeclRef, - propertyType, - synMemberRef, - synPropertyDecl, - witnessTable); - if (!canSynAccessors) - return false; - + // If the accessor requirement is `[nonmutating]` then our + // synthesized accessor should be too, and also the `this` + // parameter should *not* be an l-value. + // + if (requiredAccessorDeclRef.getDecl()->hasModifier()) + { + synThis->type.isLeftValue = false; + auto synAttr = m_astBuilder->create(); + synAccessorDecl->modifiers.first = synAttr; + } + // + // Note: we don't currently support `[mutating] get` accessors, + // but the desired behavior in that case is clear, so we go + // ahead and future-proof this code a bit: + // + else if (requiredAccessorDeclRef.getDecl()->hasModifier()) + { + synThis->type.isLeftValue = true; + auto synAttr = m_astBuilder->create(); + synAccessorDecl->modifiers.first = synAttr; + } + else if (requiredAccessorDeclRef.getDecl()->hasModifier()) + { + synThis->type.isLeftValue = true; - // The visibility of synthesized decl should be the min of the parent decl and the requirement. - if (requiredMemberDeclRef.getDecl()->findModifier()) + auto synAttr = m_astBuilder->create(); + synAccessorDecl->modifiers.first = synAttr; + } + else if (requiredAccessorDeclRef.getDecl()->hasModifier()) { - auto requirementVisibility = getDeclVisibility(requiredMemberDeclRef.getDecl()); - auto thisVisibility = getDeclVisibility(context->parentDecl); - auto visibility = Math::Min(thisVisibility, requirementVisibility); - addVisibilityModifier(m_astBuilder, synPropertyDecl, visibility); + auto synAttr = m_astBuilder->create(); + synAccessorDecl->modifiers.first = synAttr; } - return true; - } - - bool SemanticsVisitor::trySynthesizeWrapperTypePropertyRequirementWitness( - ConformanceCheckingContext* context, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable) - { - // We are synthesizing a property requirement for a wrapper type: + // We are going to synthesize an expression and then perform + // semantic checking on it, but if there are semantic errors + // we do *not* want to report them to the user as such, and + // instead want the result to be a failure to synthesize + // a valid witness. // - // interface IFoo { property value : int { get; set; } } - // struct Foo : IFoo = FooImpl; - // - // We need to synthesize Foo to: - // - // struct Foo : IFoo - // { - // FooImpl inner; - // property value : int { get { return inner.value; } - // set { inner.value = newValue; } - // } - // } - // - // To do so, we need to grab the witness table of FooImpl:IFoo, and create - // wrapper property in Foo that forwards the accessors to the inner object. + // We will buffer up diagnostics into a temporary sink and + // then throw them away when we are done. // - // We get started by constructing a synthesized `PropertyDecl`. + // TODO: This behavior might be something we want to make + // into a more fundamental capability of `DiagnosticSink` and/or + // `SemanticsVisitor` so that code can push/pop the emission + // of diagnostics more easily. // - auto synPropertyDecl = m_astBuilder->create(); - synPropertyDecl->parentDecl = context->parentDecl; + DiagnosticSink tempSink(getSourceManager(), nullptr); + SemanticsVisitor subVisitor(withSink(&tempSink)); - // Synthesize the property name with a prefix to avoid name clashing. + // The body of the accessor will depend on the class of the accessor + // we are synthesizing (e.g., `get` vs. `set`). // - synPropertyDecl->nameAndLoc = requiredMemberDeclRef.getDecl()->nameAndLoc; - synPropertyDecl->nameAndLoc.name = getName(String("$syn_property_") + getText(requiredMemberDeclRef.getName())); - - // Find the witness that FooImpl : IFoo. - auto aggTypeDecl = as(context->parentDecl); - auto innerType = aggTypeDecl->wrappedType.type; - DeclRef innerProperty; - auto innerWitness = tryGetSubtypeWitness(innerType, witnessTable->baseType); - if (!innerWitness) - return false; - - for (auto requiredAccessorDeclRef : getMembersOfType(m_astBuilder, requiredMemberDeclRef)) + Stmt* synBodyStmt = nullptr; + if (as(requiredAccessorDeclRef)) { - auto innerEntry = tryLookUpRequirementWitness(m_astBuilder, innerWitness, requiredAccessorDeclRef.getDecl()); - if (innerEntry.getFlavor() != RequirementWitness::Flavor::declRef) - return false; - auto innerAccessorDeclRef = as(innerEntry.getDeclRef()); - if (!innerAccessorDeclRef) - return false; - - // The synthesized accessor will be an AST node of the same class as - // the required accessor. + // A `get` accessor will simply perform: // - auto synAccessorDecl = (AccessorDecl*)m_astBuilder->createByNodeType(requiredAccessorDeclRef.getDecl()->astNodeType); - synAccessorDecl->ownedScope = m_astBuilder->create(); - synAccessorDecl->ownedScope->containerDecl = synAccessorDecl; - synAccessorDecl->ownedScope->parent = getScope(context->parentDecl); - - // The return type should be the same as the inner object's accessor return type. + // return this.name; // - synAccessorDecl->returnType.type = getResultType(m_astBuilder, innerAccessorDeclRef); - - // Similarly, our synthesized accessor will have parameters matching those of the inner accessor. + // which involves coercing the member access `this.name` to + // the expected type of the property. // - List synArgs; - for (auto innerParamDeclRef : getParameters(m_astBuilder, innerAccessorDeclRef)) - { - auto paramType = getType(m_astBuilder, innerParamDeclRef); + auto coercedMemberRef = + subVisitor.coerce(CoercionSite::Return, resultType, synBoundStorageExpr); + auto synReturn = m_astBuilder->create(); + synReturn->expression = coercedMemberRef; - // The synthesized parameter will ahve the same name and - // type as the parameter of the requirement. - // - auto synParamDecl = m_astBuilder->create(); - synParamDecl->nameAndLoc = innerParamDeclRef.getDecl()->nameAndLoc; - synParamDecl->type.type = paramType; - - // We need to add the parameter as a child declaration of - // the accessor we are building. - // - synParamDecl->parentDecl = synAccessorDecl; - synAccessorDecl->members.add(synParamDecl); - - // For each paramter, we will create an argument expression - // to represent it in the body of the accessor. - // - auto synArg = m_astBuilder->create(); - synArg->declRef = makeDeclRef(synParamDecl); - synArg->type = paramType; - synArgs.add(synArg); - } - - // Now synthesize the body of the property accessor. - // The body of the accessor will depend on the class of the accessor - // we are synthesizing (e.g., `get` vs. `set`). + synBodyStmt = synReturn; + } + else if (as(requiredAccessorDeclRef)) + { + // We expect all `set` accessors to have a single argument, + // but we will defensively bail out if that is somehow + // not the case. // - Stmt* synBodyStmt = nullptr; - auto propertyRef = m_astBuilder->create(); - propertyRef->scope = synAccessorDecl->ownedScope; - auto base = m_astBuilder->create(); - base->scope = propertyRef->scope; - base->name = getName("inner"); - propertyRef->baseExpression = base; - innerProperty = innerAccessorDeclRef.getParent(); - propertyRef->name = getParentDecl(innerAccessorDeclRef.getDecl())->getName(); - auto checkedPropertyRefExpr = CheckExpr(propertyRef); - - if (as(requiredAccessorDeclRef)) - { - auto synReturn = m_astBuilder->create(); - synReturn->expression = checkedPropertyRefExpr; - - synBodyStmt = synReturn; - } - else if (as(requiredAccessorDeclRef)) - { - auto synAssign = m_astBuilder->create(); - synAssign->left = checkedPropertyRefExpr; - synAssign->right = synArgs[0]; - - auto synCheckedAssign = checkAssignWithCheckedOperands(synAssign); + SLANG_ASSERT(synArgs.getCount() == 1); + if (synArgs.getCount() != 1) + return false; - auto synExprStmt = m_astBuilder->create(); - synExprStmt->expression = synCheckedAssign; + // A `set` accessor will simply perform: + // + // this.name = newValue; + // + // which involves creating and checking an assignment + // expression. - synBodyStmt = synExprStmt; - } - else - { - // While there are other kinds of accessors than `get` and `set`, - // those are currently only reserved for the internal use in the core module. - // We will not bother with synthesis for those cases. - // - return false; - } + auto synAssign = m_astBuilder->create(); + synAssign->left = synBoundStorageExpr; + synAssign->right = synArgs[0]; - addModifier(synAccessorDecl, m_astBuilder->create()); - synAccessorDecl->body = synBodyStmt; + auto synCheckedAssign = subVisitor.checkAssignWithCheckedOperands(synAssign); - synAccessorDecl->parentDecl = synPropertyDecl; - synPropertyDecl->members.add(synAccessorDecl); + auto synExprStmt = m_astBuilder->create(); + synExprStmt->expression = synCheckedAssign; - // Register the synthesized accessor. + synBodyStmt = synExprStmt; + } + else + { + // While there are other kinds of accessors than `get` and `set`, + // those are currently only reserved for the internal use in the core module. + // We will not bother with synthesis for those cases. // - witnessTable->add(requiredAccessorDeclRef.getDecl(), RequirementWitness(makeDeclRef(synAccessorDecl))); + return false; } - // The type of our synthesized property will be the same as the inner property. + // We bail out if we ran into any errors (meaning that the synthesized + // accessor is not usable). // - auto propertyType = getType(m_astBuilder, as(innerProperty)); - synPropertyDecl->type.type = propertyType; - - // The visibility of synthesized decl should be the same as the inner requirement - if (innerProperty.getDecl()->findModifier()) - { - auto vis = getDeclVisibility(innerProperty.getDecl()); - addVisibilityModifier(m_astBuilder, synPropertyDecl, vis); - } + // TODO: If there were *warnings* emitted to the sink, it would probably + // be good to show those warnings to the user, since they might indicate + // real issues. E.g., with the current logic a `float` field could + // satisfying an `int` property requirement, but the user would probably + // want to be warned when they do such a thing. + // + if (tempSink.getErrorCount() != 0) + return false; - context->parentDecl->addMember(synPropertyDecl); - witnessTable->add(requiredMemberDeclRef.getDecl(), - RequirementWitness(makeDeclRef(synPropertyDecl))); - return true; - } + synAccessorDecl->body = synBodyStmt; - bool SemanticsVisitor::trySynthesizeAssociatedTypeRequirementWitness( - ConformanceCheckingContext* context, - LookupResult const& inLookupResult, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable) - { - SLANG_UNUSED(inLookupResult); + synAccessorDecl->parentDecl = synAccesorContainer; + synAccesorContainer->members.add(synAccessorDecl); - // The only case we can synthesize for now is when the conformant type - // is a wrapper type. - if (!isWrapperTypeDecl(context->parentDecl)) - return false; - auto aggTypeDecl = as(context->parentDecl); - auto lookupResult = lookUpMember( - m_astBuilder, - this, - requiredMemberDeclRef.getName(), - aggTypeDecl->wrappedType.type, - aggTypeDecl->ownedScope, - LookupMask::Default, - LookupOptions::IgnoreBaseInterfaces); - if (!lookupResult.isValid() || lookupResult.isOverloaded()) - return false; - auto assocType = DeclRefType::create(m_astBuilder, lookupResult.item.declRef); - witnessTable->add(requiredMemberDeclRef.getDecl(), assocType); - for (auto typeConstraintDecl : getMembersOfType(m_astBuilder, requiredMemberDeclRef)) - { - auto witness = tryGetSubtypeWitness(assocType, getSup(m_astBuilder, typeConstraintDecl)); - if (!witness) - return false; - witnessTable->add(typeConstraintDecl.getDecl(), witness); - } - return true; + // If synthesis of an accessor worked, then we will record it into + // a local dictionary. We do *not* install the accessor into the + // witness table yet, because it is possible that synthesis will + // succeed for some accessors but not others, and we don't want + // to leave the witness table in a state where a requirement is + // "partially satisfied." + // + mapRequiredAccessorToSynAccessor.add(requiredAccessorDeclRef, synAccessorDecl); } - bool SemanticsVisitor::trySynthesizeAssociatedConstantRequirementWitness( - ConformanceCheckingContext* context, - LookupResult const& inLookupResult, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable) + // Once our synthesized declaration is complete, we need + // to install it as the witness that satifies the given + // requirement. + // + // Subsequent code generation should not be able to tell the + // difference between our synthetic property and a hand-written + // one with the same behavior. + // + auto containerDecl = getParentDecl(synAccesorContainer); + auto containerDeclRef = getDefaultDeclRef(containerDecl); + for (auto& [key, value] : mapRequiredAccessorToSynAccessor) { - SLANG_UNUSED(inLookupResult); + witnessTable->add( + key.getDecl(), + RequirementWitness(m_astBuilder->getMemberDeclRef(containerDeclRef, value))); + } - // The only case we can synthesize for now is when the conformant type - // is a wrapper type, i.e. - // struct Foo:IFoo = FooImpl; - if (!isWrapperTypeDecl(context->parentDecl)) - return false; + witnessTable->add( + requiredMemberDeclRef.getDecl(), + RequirementWitness(m_astBuilder->getMemberDeclRef(containerDeclRef, synAccesorContainer))); + return true; +} - // Find the witness that FooImpl : IFoo. - auto aggTypeDecl = as(context->parentDecl); - auto innerType = aggTypeDecl->wrappedType.type; - DeclRef innerProperty; - auto innerWitness = tryGetSubtypeWitness(innerType, witnessTable->baseType); - if (!innerWitness) - return false; +bool SemanticsVisitor::trySynthesizeWrapperTypeSubscriptRequirementWitness( + ConformanceCheckingContext* context, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable) +{ + // We are synthesizing the subscript requirement for a wrapper type: + // struct Wrapper + // { + // Inner inner; + // subscript(int index)->int { get { return inner[index]; } + // set { inner[index] = newValue; } + // } + // } + // + // // Find the witness that FooImpl : IFoo. + auto aggTypeDecl = as(context->parentDecl); + auto innerType = aggTypeDecl->wrappedType.type; + DeclRef innerProperty; + auto innerWitness = tryGetSubtypeWitness(innerType, witnessTable->baseType); + if (!innerWitness) + return false; + // + List synArgs; + ThisExpr* synThis; + auto synSubscriptDecl = synthesizeMethodSignatureForRequirementWitness( + context, + requiredMemberDeclRef, + synArgs, + synThis); + auto declType = getType(m_astBuilder, getDefaultDeclRef(synSubscriptDecl).as()); + synThis->checked = true; + + // Form a `this[args...]` expression that we will use to coerce from + // in the synthesized subscript accessors. + // + DiagnosticSink tempSink(getSourceManager(), nullptr); + SemanticsVisitor subVisitor(withSink(&tempSink)); + auto base = m_astBuilder->create(); + base->scope = synThis->scope; + base->name = getName("inner"); + + IndexExpr* indexExpr = m_astBuilder->create(); + indexExpr->baseExpression = base; + indexExpr->indexExprs = _Move(synArgs); + auto synBaseStorageExpr = subVisitor.CheckTerm(indexExpr); + + if (tempSink.getErrorCount() != 0) + return false; - auto witness = tryLookUpRequirementWitness(m_astBuilder, innerWitness, requiredMemberDeclRef.getDecl()); - if (witness.getFlavor() != RequirementWitness::Flavor::val) - return false; - witnessTable->add(requiredMemberDeclRef.getDecl(), witness.getVal()); - return true; - } + // Our synthesized subscript will have an accessor declaration for + // each accessor of the requirement. + // + bool canSynAccessors = synthesizeAccessorRequirements( + context, + requiredMemberDeclRef, + declType, + synBaseStorageExpr, + synSubscriptDecl, + witnessTable); + if (!canSynAccessors) + return false; - bool SemanticsVisitor::synthesizeAccessorRequirements( - ConformanceCheckingContext* context, - DeclRef requiredMemberDeclRef, - Type* resultType, - Expr* synBoundStorageExpr, - ContainerDecl* synAccesorContainer, - RefPtr witnessTable) + // The visibility of synthesized decl should be the min of the parent decl and the requirement. + if (requiredMemberDeclRef.getDecl()->findModifier()) { - Dictionary, AccessorDecl*> mapRequiredAccessorToSynAccessor; - for (auto requiredAccessorDeclRef : getMembersOfType(m_astBuilder, requiredMemberDeclRef)) - { - // The synthesized accessor will be an AST node of the same class as - // the required accessor. - // - auto synAccessorDecl = (AccessorDecl*)m_astBuilder->createByNodeType(requiredAccessorDeclRef.getDecl()->astNodeType); - synAccessorDecl->ownedScope = m_astBuilder->create(); - synAccessorDecl->ownedScope->containerDecl = synAccessorDecl; - synAccessorDecl->ownedScope->parent = getScope(context->parentDecl); - - // Whatever the required accessor returns, that is what our synthesized accessor will return. - // - synAccessorDecl->returnType.type = resultType; + auto requirementVisibility = getDeclVisibility(requiredMemberDeclRef.getDecl()); + auto thisVisibility = getDeclVisibility(context->parentDecl); + auto visibility = Math::Min(thisVisibility, requirementVisibility); + addVisibilityModifier(m_astBuilder, synSubscriptDecl, visibility); + } - // Similarly, our synthesized accessor will have parameters matching those of the requirement. - // - // Note: in practice we expect that only `set` accessors will have any parameters, - // and they will only have a single parameter. - // - List synArgs; - for (auto requiredParamDeclRef : getParameters(m_astBuilder, requiredAccessorDeclRef)) - { - auto paramType = getType(m_astBuilder, requiredParamDeclRef); + return true; +} - // The synthesized parameter will ahve the same name and - // type as the parameter of the requirement. - // - auto synParamDecl = m_astBuilder->create(); - synParamDecl->nameAndLoc = requiredParamDeclRef.getDecl()->nameAndLoc; - synParamDecl->type.type = paramType; +bool SemanticsVisitor::trySynthesizeSubscriptRequirementWitness( + ConformanceCheckingContext* context, + const LookupResult& lookupResult, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable) +{ + if (isWrapperTypeDecl(context->parentDecl)) + return trySynthesizeWrapperTypeSubscriptRequirementWitness( + context, + requiredMemberDeclRef, + witnessTable); - // We need to add the parameter as a child declaration of - // the accessor we are building. - // - synParamDecl->parentDecl = synAccessorDecl; - synAccessorDecl->members.add(synParamDecl); + // The situation here is that the context of an inheritance + // declaration didn't provide an exact match for a required + // subscript. E.g.: + // + // interface ICell { subscript(int index)->int {get;} } + // struct MyCell : ICell + // { + // subscript(uint index)->int {ref;} + // } + // + // It is clear in this case that the `MyCell` type *can* + // satisfy the signature required by `ICell`, if we consider + // all the allowed type coercion rules, and use `ref` accessor + // to implement `get`. + // + // The approach in this function will be to construct a + // synthesized `subscript` along the lines of: + // + // struct MyCell ... + // { + // ... + // subscript(int index)->int {get;} + // { + // get { return this.origianl_subscript[index]; } + // } + // } + // + // That is, we construct a `subscript` with the correct type + // and with an accessor for each requirement, where the accesors + // all try to dispatch to the original subscript decl. + // + // If those synthesized accessors all type-check, then we can + // say that the type must satisfy the requirement structurally, + // even if there isn't an exact signature match. More + // importantly, the `property` we just synthesized can be + // used as a witness to the fact that the requirement is + // satisfied. + // + // The big-picture flow of the logic here is similar to + // `trySynthesizePropertyRequirementWitness()` above, and we + // will not comment this code as exhaustively, under the + // assumption that readers of the code don't benefit from + // having the exact same information stated twice. + // + + List synArgs; + ThisExpr* synThis; + auto synSubscriptDecl = synthesizeMethodSignatureForRequirementWitness( + context, + requiredMemberDeclRef, + synArgs, + synThis); + synThis->type.isLeftValue = true; + synThis->checked = true; + + auto declType = getType(m_astBuilder, getDefaultDeclRef(synSubscriptDecl).as()); + + // Form a `this[args...]` expression that we will use to coerce from + // in the synthesized subscript accessors. + // + DiagnosticSink tempSink(getSourceManager(), nullptr); + SemanticsVisitor subVisitor(withSink(&tempSink)); + Expr* synBaseStorageExpr = nullptr; + if (lookupResult.isValid()) + { + auto calleeExpr = m_astBuilder->create(); + calleeExpr->base = synThis; + calleeExpr->lookupResult2 = lookupResult; + auto invokeExpr = m_astBuilder->create(); + invokeExpr->functionExpr = calleeExpr; + invokeExpr->arguments = _Move(synArgs); + synBaseStorageExpr = subVisitor.ResolveInvoke(invokeExpr); + } + else + { + IndexExpr* indexExpr = m_astBuilder->create(); + indexExpr->baseExpression = synThis; + indexExpr->indexExprs = _Move(synArgs); + synBaseStorageExpr = subVisitor.CheckTerm(indexExpr); + } + if (tempSink.getErrorCount() != 0) + return false; - // For each paramter, we will create an argument expression - // to represent it in the body of the accessor. - // - auto synArg = m_astBuilder->create(); - synArg->declRef = makeDeclRef(synParamDecl); - synArg->type = paramType; - synArgs.add(synArg); - } + // Our synthesized subscript will have an accessor declaration for + // each accessor of the requirement. + // + bool canSynAccessors = synthesizeAccessorRequirements( + context, + requiredMemberDeclRef, + declType, + synBaseStorageExpr, + synSubscriptDecl, + witnessTable); + if (!canSynAccessors) + return false; - // We need to create a `this` expression to be used in the body - // of the synthesized accessor. - // - // TODO: if we ever allow `static` properties or subscripts, - // we will need to handle that case here, by *not* creating - // a `this` expression. - // - ThisExpr* synThis = m_astBuilder->create(); - synThis->scope = synAccessorDecl->ownedScope; - // The type of `this` in our accessor will be the type for - // which we are synthesizing a conformance. - // - synThis->type.type = context->conformingType; + // The visibility of synthesized decl should be the min of the parent decl and the requirement. + if (requiredMemberDeclRef.getDecl()->findModifier()) + { + auto requirementVisibility = getDeclVisibility(requiredMemberDeclRef.getDecl()); + auto thisVisibility = getDeclVisibility(context->parentDecl); + auto visibility = Math::Min(thisVisibility, requirementVisibility); + addVisibilityModifier(m_astBuilder, synSubscriptDecl, visibility); + } - // A `get` accessor should default to an immutable `this`, - // while other accessors default to mutable `this`. - // - // TODO: If we ever add other kinds of accessors, we will - // need to check that this assumption stays valid. - // - synThis->type.isLeftValue = true; - if (as(requiredAccessorDeclRef)) - synThis->type.isLeftValue = false; + return true; +} - // If the accessor requirement is `[nonmutating]` then our - // synthesized accessor should be too, and also the `this` - // parameter should *not* be an l-value. - // - if (requiredAccessorDeclRef.getDecl()->hasModifier()) - { - synThis->type.isLeftValue = false; +bool SemanticsVisitor::trySynthesizeRequirementWitness( + ConformanceCheckingContext* context, + LookupResult const& lookupResult, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable) +{ + SLANG_UNUSED(lookupResult); + SLANG_UNUSED(requiredMemberDeclRef); + SLANG_UNUSED(witnessTable); - auto synAttr = m_astBuilder->create(); - synAccessorDecl->modifiers.first = synAttr; - } - // - // Note: we don't currently support `[mutating] get` accessors, - // but the desired behavior in that case is clear, so we go - // ahead and future-proof this code a bit: - // - else if (requiredAccessorDeclRef.getDecl()->hasModifier()) - { - synThis->type.isLeftValue = true; + if (auto requiredFuncDeclRef = requiredMemberDeclRef.as()) + { + // Check signature match. + if (trySynthesizeMethodRequirementWitness( + context, + lookupResult, + requiredFuncDeclRef, + witnessTable)) + return true; - auto synAttr = m_astBuilder->create(); - synAccessorDecl->modifiers.first = synAttr; - } - else if (requiredAccessorDeclRef.getDecl()->hasModifier()) + if (auto builtinAttr = + requiredFuncDeclRef.getDecl()->findModifier()) + { + switch (builtinAttr->kind) { - synThis->type.isLeftValue = true; - - auto synAttr = m_astBuilder->create(); - synAccessorDecl->modifiers.first = synAttr; - } - else if (requiredAccessorDeclRef.getDecl()->hasModifier()) - { - auto synAttr = m_astBuilder->create(); - synAccessorDecl->modifiers.first = synAttr; - } - // We are going to synthesize an expression and then perform - // semantic checking on it, but if there are semantic errors - // we do *not* want to report them to the user as such, and - // instead want the result to be a failure to synthesize - // a valid witness. - // - // We will buffer up diagnostics into a temporary sink and - // then throw them away when we are done. - // - // TODO: This behavior might be something we want to make - // into a more fundamental capability of `DiagnosticSink` and/or - // `SemanticsVisitor` so that code can push/pop the emission - // of diagnostics more easily. - // - DiagnosticSink tempSink(getSourceManager(), nullptr); - SemanticsVisitor subVisitor(withSink(&tempSink)); - - // The body of the accessor will depend on the class of the accessor - // we are synthesizing (e.g., `get` vs. `set`). - // - Stmt* synBodyStmt = nullptr; - if (as(requiredAccessorDeclRef)) - { - // A `get` accessor will simply perform: - // - // return this.name; - // - // which involves coercing the member access `this.name` to - // the expected type of the property. - // - auto coercedMemberRef = subVisitor.coerce(CoercionSite::Return, resultType, synBoundStorageExpr); - auto synReturn = m_astBuilder->create(); - synReturn->expression = coercedMemberRef; - - synBodyStmt = synReturn; - } - else if (as(requiredAccessorDeclRef)) - { - // We expect all `set` accessors to have a single argument, - // but we will defensively bail out if that is somehow - // not the case. - // - SLANG_ASSERT(synArgs.getCount() == 1); - if (synArgs.getCount() != 1) - return false; - - // A `set` accessor will simply perform: - // - // this.name = newValue; - // - // which involves creating and checking an assignment - // expression. - - auto synAssign = m_astBuilder->create(); - synAssign->left = synBoundStorageExpr; - synAssign->right = synArgs[0]; - - auto synCheckedAssign = subVisitor.checkAssignWithCheckedOperands(synAssign); - - auto synExprStmt = m_astBuilder->create(); - synExprStmt->expression = synCheckedAssign; - - synBodyStmt = synExprStmt; - } - else - { - // While there are other kinds of accessors than `get` and `set`, - // those are currently only reserved for the internal use in the core module. - // We will not bother with synthesis for those cases. - // - return false; - } - - // We bail out if we ran into any errors (meaning that the synthesized - // accessor is not usable). - // - // TODO: If there were *warnings* emitted to the sink, it would probably - // be good to show those warnings to the user, since they might indicate - // real issues. E.g., with the current logic a `float` field could - // satisfying an `int` property requirement, but the user would probably - // want to be warned when they do such a thing. - // - if (tempSink.getErrorCount() != 0) - return false; - - synAccessorDecl->body = synBodyStmt; - - synAccessorDecl->parentDecl = synAccesorContainer; - synAccesorContainer->members.add(synAccessorDecl); - - // If synthesis of an accessor worked, then we will record it into - // a local dictionary. We do *not* install the accessor into the - // witness table yet, because it is possible that synthesis will - // succeed for some accessors but not others, and we don't want - // to leave the witness table in a state where a requirement is - // "partially satisfied." - // - mapRequiredAccessorToSynAccessor.add(requiredAccessorDeclRef, synAccessorDecl); - } - - // Once our synthesized declaration is complete, we need - // to install it as the witness that satifies the given - // requirement. - // - // Subsequent code generation should not be able to tell the - // difference between our synthetic property and a hand-written - // one with the same behavior. - // - auto containerDecl = getParentDecl(synAccesorContainer); - auto containerDeclRef = getDefaultDeclRef(containerDecl); - for (auto& [key, value] : mapRequiredAccessorToSynAccessor) - { - witnessTable->add( - key.getDecl(), - RequirementWitness( - m_astBuilder->getMemberDeclRef(containerDeclRef, value))); - } - - witnessTable->add(requiredMemberDeclRef.getDecl(), - RequirementWitness(m_astBuilder->getMemberDeclRef(containerDeclRef, synAccesorContainer))); - return true; - } - - bool SemanticsVisitor::trySynthesizeWrapperTypeSubscriptRequirementWitness( - ConformanceCheckingContext* context, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable) - { - // We are synthesizing the subscript requirement for a wrapper type: - // struct Wrapper - // { - // Inner inner; - // subscript(int index)->int { get { return inner[index]; } - // set { inner[index] = newValue; } - // } - // } - // - // // Find the witness that FooImpl : IFoo. - auto aggTypeDecl = as(context->parentDecl); - auto innerType = aggTypeDecl->wrappedType.type; - DeclRef innerProperty; - auto innerWitness = tryGetSubtypeWitness(innerType, witnessTable->baseType); - if (!innerWitness) - return false; - // - List synArgs; - ThisExpr* synThis; - auto synSubscriptDecl = synthesizeMethodSignatureForRequirementWitness( - context, - requiredMemberDeclRef, - synArgs, - synThis); - auto declType = getType(m_astBuilder, getDefaultDeclRef(synSubscriptDecl).as()); - synThis->checked = true; - - // Form a `this[args...]` expression that we will use to coerce from - // in the synthesized subscript accessors. - // - DiagnosticSink tempSink(getSourceManager(), nullptr); - SemanticsVisitor subVisitor(withSink(&tempSink)); - auto base = m_astBuilder->create(); - base->scope = synThis->scope; - base->name = getName("inner"); - - IndexExpr* indexExpr = m_astBuilder->create(); - indexExpr->baseExpression = base; - indexExpr->indexExprs = _Move(synArgs); - auto synBaseStorageExpr = subVisitor.CheckTerm(indexExpr); - - if (tempSink.getErrorCount() != 0) - return false; - - // Our synthesized subscript will have an accessor declaration for - // each accessor of the requirement. - // - bool canSynAccessors = synthesizeAccessorRequirements( - context, - requiredMemberDeclRef, - declType, - synBaseStorageExpr, - synSubscriptDecl, witnessTable); - if (!canSynAccessors) - return false; - - // The visibility of synthesized decl should be the min of the parent decl and the requirement. - if (requiredMemberDeclRef.getDecl()->findModifier()) - { - auto requirementVisibility = getDeclVisibility(requiredMemberDeclRef.getDecl()); - auto thisVisibility = getDeclVisibility(context->parentDecl); - auto visibility = Math::Min(thisVisibility, requirementVisibility); - addVisibilityModifier(m_astBuilder, synSubscriptDecl, visibility); - } - - return true; - } - - bool SemanticsVisitor::trySynthesizeSubscriptRequirementWitness( - ConformanceCheckingContext* context, - const LookupResult& lookupResult, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable) - { - if (isWrapperTypeDecl(context->parentDecl)) - return trySynthesizeWrapperTypeSubscriptRequirementWitness(context, requiredMemberDeclRef, witnessTable); - - // The situation here is that the context of an inheritance - // declaration didn't provide an exact match for a required - // subscript. E.g.: - // - // interface ICell { subscript(int index)->int {get;} } - // struct MyCell : ICell - // { - // subscript(uint index)->int {ref;} - // } - // - // It is clear in this case that the `MyCell` type *can* - // satisfy the signature required by `ICell`, if we consider - // all the allowed type coercion rules, and use `ref` accessor - // to implement `get`. - // - // The approach in this function will be to construct a - // synthesized `subscript` along the lines of: - // - // struct MyCell ... - // { - // ... - // subscript(int index)->int {get;} - // { - // get { return this.origianl_subscript[index]; } - // } - // } - // - // That is, we construct a `subscript` with the correct type - // and with an accessor for each requirement, where the accesors - // all try to dispatch to the original subscript decl. - // - // If those synthesized accessors all type-check, then we can - // say that the type must satisfy the requirement structurally, - // even if there isn't an exact signature match. More - // importantly, the `property` we just synthesized can be - // used as a witness to the fact that the requirement is - // satisfied. - // - // The big-picture flow of the logic here is similar to - // `trySynthesizePropertyRequirementWitness()` above, and we - // will not comment this code as exhaustively, under the - // assumption that readers of the code don't benefit from - // having the exact same information stated twice. - // - - List synArgs; - ThisExpr* synThis; - auto synSubscriptDecl = synthesizeMethodSignatureForRequirementWitness( - context, - requiredMemberDeclRef, - synArgs, - synThis); - synThis->type.isLeftValue = true; - synThis->checked = true; - - auto declType = getType(m_astBuilder, getDefaultDeclRef(synSubscriptDecl).as()); - - // Form a `this[args...]` expression that we will use to coerce from - // in the synthesized subscript accessors. - // - DiagnosticSink tempSink(getSourceManager(), nullptr); - SemanticsVisitor subVisitor(withSink(&tempSink)); - Expr* synBaseStorageExpr = nullptr; - if (lookupResult.isValid()) - { - auto calleeExpr = m_astBuilder->create(); - calleeExpr->base = synThis; - calleeExpr->lookupResult2 = lookupResult; - auto invokeExpr = m_astBuilder->create(); - invokeExpr->functionExpr = calleeExpr; - invokeExpr->arguments = _Move(synArgs); - synBaseStorageExpr = subVisitor.ResolveInvoke(invokeExpr); - } - else - { - IndexExpr* indexExpr = m_astBuilder->create(); - indexExpr->baseExpression = synThis; - indexExpr->indexExprs = _Move(synArgs); - synBaseStorageExpr = subVisitor.CheckTerm(indexExpr); - } - if (tempSink.getErrorCount() != 0) - return false; - - // Our synthesized subscript will have an accessor declaration for - // each accessor of the requirement. - // - bool canSynAccessors = synthesizeAccessorRequirements( - context, - requiredMemberDeclRef, - declType, - synBaseStorageExpr, - synSubscriptDecl, witnessTable); - if (!canSynAccessors) - return false; - - - // The visibility of synthesized decl should be the min of the parent decl and the requirement. - if (requiredMemberDeclRef.getDecl()->findModifier()) - { - auto requirementVisibility = getDeclVisibility(requiredMemberDeclRef.getDecl()); - auto thisVisibility = getDeclVisibility(context->parentDecl); - auto visibility = Math::Min(thisVisibility, requirementVisibility); - addVisibilityModifier(m_astBuilder, synSubscriptDecl, visibility); - } - - return true; - } - - bool SemanticsVisitor::trySynthesizeRequirementWitness( - ConformanceCheckingContext* context, - LookupResult const& lookupResult, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable) - { - SLANG_UNUSED(lookupResult); - SLANG_UNUSED(requiredMemberDeclRef); - SLANG_UNUSED(witnessTable); - - if (auto requiredFuncDeclRef = requiredMemberDeclRef.as()) - { - // Check signature match. - if (trySynthesizeMethodRequirementWitness( - context, - lookupResult, - requiredFuncDeclRef, - witnessTable)) - return true; - - if (auto builtinAttr = requiredFuncDeclRef.getDecl()->findModifier()) - { - switch (builtinAttr->kind) - { - case BuiltinRequirementKind::DAddFunc: - case BuiltinRequirementKind::DZeroFunc: - return trySynthesizeDifferentialMethodRequirementWitness( + case BuiltinRequirementKind::DAddFunc: + case BuiltinRequirementKind::DZeroFunc: + return trySynthesizeDifferentialMethodRequirementWitness( + context, + requiredFuncDeclRef, + witnessTable, + SynthesisPattern::AllInductive); + case BuiltinRequirementKind::And: + case BuiltinRequirementKind::Or: + case BuiltinRequirementKind::Not: + case BuiltinRequirementKind::BitAnd: + case BuiltinRequirementKind::BitNot: + case BuiltinRequirementKind::BitOr: + case BuiltinRequirementKind::BitXor: + case BuiltinRequirementKind::Shl: + case BuiltinRequirementKind::Shr: + case BuiltinRequirementKind::Equals: + case BuiltinRequirementKind::LessThan: + case BuiltinRequirementKind::LessThanOrEquals: + if (isEnumType(context->conformingType)) + return trySynthesizeEnumTypeMethodRequirementWitness( context, requiredFuncDeclRef, witnessTable, - SynthesisPattern::AllInductive); - case BuiltinRequirementKind::And: - case BuiltinRequirementKind::Or: - case BuiltinRequirementKind::Not: - case BuiltinRequirementKind::BitAnd: - case BuiltinRequirementKind::BitNot: - case BuiltinRequirementKind::BitOr: - case BuiltinRequirementKind::BitXor: - case BuiltinRequirementKind::Shl: - case BuiltinRequirementKind::Shr: - case BuiltinRequirementKind::Equals: - case BuiltinRequirementKind::LessThan: - case BuiltinRequirementKind::LessThanOrEquals: - if (isEnumType(context->conformingType)) - return trySynthesizeEnumTypeMethodRequirementWitness(context, requiredFuncDeclRef, witnessTable, builtinAttr->kind); - break; - } + builtinAttr->kind); + break; } - return false; } + return false; + } - // For generic decl, check if we match DMulFunc, and synthesize the method. - if (auto requiredGenericDeclRef = requiredMemberDeclRef.as()) - { - auto inner = getInner(requiredGenericDeclRef); - - // TODO: we should be able to remove DMul synthesis logic. - if (auto builtinAttr = inner->findModifier()) - { - switch (builtinAttr->kind) - { - case BuiltinRequirementKind::DMulFunc: - return trySynthesizeDifferentialMethodRequirementWitness( - context, - requiredGenericDeclRef, - witnessTable, - SynthesisPattern::FixedFirstArg); - } - } + // For generic decl, check if we match DMulFunc, and synthesize the method. + if (auto requiredGenericDeclRef = requiredMemberDeclRef.as()) + { + auto inner = getInner(requiredGenericDeclRef); - if (as(inner)) + // TODO: we should be able to remove DMul synthesis logic. + if (auto builtinAttr = inner->findModifier()) + { + switch (builtinAttr->kind) { - return trySynthesizeRequirementWitness( + case BuiltinRequirementKind::DMulFunc: + return trySynthesizeDifferentialMethodRequirementWitness( context, - lookupResult, - m_astBuilder->getMemberDeclRef(requiredGenericDeclRef, inner), - witnessTable); + requiredGenericDeclRef, + witnessTable, + SynthesisPattern::FixedFirstArg); } - return false; } - if( auto requiredPropertyDeclRef = requiredMemberDeclRef.as() ) + if (as(inner)) { - return trySynthesizePropertyRequirementWitness( + return trySynthesizeRequirementWitness( context, lookupResult, - requiredPropertyDeclRef, + m_astBuilder->getMemberDeclRef(requiredGenericDeclRef, inner), witnessTable); } + return false; + } - if (auto requiredSubscriptDeclRef = requiredMemberDeclRef.as()) - { - return trySynthesizeSubscriptRequirementWitness( - context, - lookupResult, - requiredSubscriptDeclRef, - witnessTable); - } + if (auto requiredPropertyDeclRef = requiredMemberDeclRef.as()) + { + return trySynthesizePropertyRequirementWitness( + context, + lookupResult, + requiredPropertyDeclRef, + witnessTable); + } - if (auto requiredAssocTypeDeclRef = requiredMemberDeclRef.as()) + if (auto requiredSubscriptDeclRef = requiredMemberDeclRef.as()) + { + return trySynthesizeSubscriptRequirementWitness( + context, + lookupResult, + requiredSubscriptDeclRef, + witnessTable); + } + + if (auto requiredAssocTypeDeclRef = requiredMemberDeclRef.as()) + { + if (auto builtinAttr = + requiredAssocTypeDeclRef.getDecl()->findModifier()) { - if (auto builtinAttr = requiredAssocTypeDeclRef.getDecl()->findModifier()) + switch (builtinAttr->kind) { - switch (builtinAttr->kind) - { - case BuiltinRequirementKind::DifferentialType: - return trySynthesizeDifferentialAssociatedTypeRequirementWitness( - context, - requiredAssocTypeDeclRef, - witnessTable); - } - } - else - { - return trySynthesizeAssociatedTypeRequirementWitness( + case BuiltinRequirementKind::DifferentialType: + return trySynthesizeDifferentialAssociatedTypeRequirementWitness( context, - lookupResult, requiredAssocTypeDeclRef, witnessTable); } } - - if (auto requiredConstantDeclRef = requiredMemberDeclRef.as()) - { - return trySynthesizeAssociatedConstantRequirementWitness( - context, - lookupResult, - requiredConstantDeclRef, - witnessTable); - } - - if (auto requiredCtor = requiredMemberDeclRef.as()) + else { - return trySynthesizeConstructorRequirementWitness( + return trySynthesizeAssociatedTypeRequirementWitness( context, lookupResult, - requiredCtor, + requiredAssocTypeDeclRef, witnessTable); } + } - // TODO: There are other kinds of requirements for which synthesis should - // be possible: - // - // * It should be possible to synthesize required initializers - // using an approach similar to what is used for methods. - // - // * We should be able to synthesize subscripts with different - // signatures (taking into account default parameters). - // - // * For specific kinds of generic requirements, we should be able - // to wrap the synthesis of the inner declaration in synthesis - // of an outer generic with a matching signature. - // - // All of these cases can/should use similar logic to - // `trySynthesizeMethodRequirementWitness` where they construct an AST - // in the form of what the use site ought to look like, and then - // apply existing semantic checking logic to generate the code. - - return false; + if (auto requiredConstantDeclRef = requiredMemberDeclRef.as()) + { + return trySynthesizeAssociatedConstantRequirementWitness( + context, + lookupResult, + requiredConstantDeclRef, + witnessTable); } - Stmt* _synthesizeMemberAssignMemberHelper( - ASTSynthesizer& synth, - Name* funcName, - Type* leftType, - Expr* leftValue, - List&& args, - List&& genericArgs, - List&& inductiveArgMask, - int nestingLevel = 0) + if (auto requiredCtor = requiredMemberDeclRef.as()) { - if (nestingLevel > 16) - return nullptr; + return trySynthesizeConstructorRequirementWitness( + context, + lookupResult, + requiredCtor, + witnessTable); + } - // If field type is an array, assign each element individually. - if (auto arrayType = as(leftType)) - { - VarDecl* indexVar = nullptr; - auto forStmt = synth.emitFor(synth.emitIntConst(0), synth.emitGetArrayLengthExpr(leftValue), indexVar); - addModifier(forStmt, synth.getBuilder()->create()); - auto innerLeft = synth.emitIndexExpr(leftValue, synth.emitVarExpr(indexVar)); + // TODO: There are other kinds of requirements for which synthesis should + // be possible: + // + // * It should be possible to synthesize required initializers + // using an approach similar to what is used for methods. + // + // * We should be able to synthesize subscripts with different + // signatures (taking into account default parameters). + // + // * For specific kinds of generic requirements, we should be able + // to wrap the synthesis of the inner declaration in synthesis + // of an outer generic with a matching signature. + // + // All of these cases can/should use similar logic to + // `trySynthesizeMethodRequirementWitness` where they construct an AST + // in the form of what the use site ought to look like, and then + // apply existing semantic checking logic to generate the code. + + return false; +} - for (auto ii = 0; ii < args.getCount(); ++ii) - { - auto& arg = args[ii]; - if (inductiveArgMask[ii]) - arg = synth.emitIndexExpr(arg, synth.emitVarExpr(indexVar)); - } +Stmt* _synthesizeMemberAssignMemberHelper( + ASTSynthesizer& synth, + Name* funcName, + Type* leftType, + Expr* leftValue, + List&& args, + List&& genericArgs, + List&& inductiveArgMask, + int nestingLevel = 0) +{ + if (nestingLevel > 16) + return nullptr; - auto assignStmt = _synthesizeMemberAssignMemberHelper( - synth, - funcName, - arrayType->getElementType(), - innerLeft, - _Move(args), - _Move(genericArgs), - _Move(inductiveArgMask), - nestingLevel + 1); + // If field type is an array, assign each element individually. + if (auto arrayType = as(leftType)) + { + VarDecl* indexVar = nullptr; + auto forStmt = + synth.emitFor(synth.emitIntConst(0), synth.emitGetArrayLengthExpr(leftValue), indexVar); + addModifier(forStmt, synth.getBuilder()->create()); + auto innerLeft = synth.emitIndexExpr(leftValue, synth.emitVarExpr(indexVar)); - synth.popScope(); - if (!assignStmt) - return nullptr; - return forStmt; + for (auto ii = 0; ii < args.getCount(); ++ii) + { + auto& arg = args[ii]; + if (inductiveArgMask[ii]) + arg = synth.emitIndexExpr(arg, synth.emitVarExpr(indexVar)); } - auto callee = synth.emitMemberExpr(leftType, funcName); + auto assignStmt = _synthesizeMemberAssignMemberHelper( + synth, + funcName, + arrayType->getElementType(), + innerLeft, + _Move(args), + _Move(genericArgs), + _Move(inductiveArgMask), + nestingLevel + 1); + + synth.popScope(); + if (!assignStmt) + return nullptr; + return forStmt; + } - if (genericArgs.getCount() > 0) - callee = synth.emitGenericAppExpr(callee, _Move(genericArgs)); + auto callee = synth.emitMemberExpr(leftType, funcName); - return synth.emitAssignStmt(leftValue, synth.emitInvokeExpr(callee, _Move(args))); - } + if (genericArgs.getCount() > 0) + callee = synth.emitGenericAppExpr(callee, _Move(genericArgs)); - bool SemanticsVisitor::trySynthesizeEnumTypeMethodRequirementWitness(ConformanceCheckingContext* context, - DeclRef funcDeclRef, - RefPtr witnessTable, - BuiltinRequirementKind requirementKind) - { - List synArgs; - ThisExpr* synThis = nullptr; - auto synFunc = synthesizeMethodSignatureForRequirementWitness( - context, funcDeclRef, synArgs, synThis); - auto intrinsicOpModifier = getASTBuilder()->create(); - switch (requirementKind) - { - case BuiltinRequirementKind::And: - intrinsicOpModifier->op = kIROp_And; - break; - case BuiltinRequirementKind::Or: - intrinsicOpModifier->op = kIROp_Or; - break; - case BuiltinRequirementKind::Not: - intrinsicOpModifier->op = kIROp_Not; - break; - case BuiltinRequirementKind::BitAnd: - intrinsicOpModifier->op = kIROp_BitAnd; - break; - case BuiltinRequirementKind::BitNot: - intrinsicOpModifier->op = kIROp_BitNot; - break; - case BuiltinRequirementKind::BitOr: - intrinsicOpModifier->op = kIROp_BitOr; - break; - case BuiltinRequirementKind::BitXor: - intrinsicOpModifier->op = kIROp_BitXor; - break; - case BuiltinRequirementKind::Shl: - intrinsicOpModifier->op = kIROp_Lsh; - break; - case BuiltinRequirementKind::Shr: - intrinsicOpModifier->op = kIROp_Rsh; - break; - case BuiltinRequirementKind::Equals: - intrinsicOpModifier->op = kIROp_Eql; - break; - case BuiltinRequirementKind::LessThan: - intrinsicOpModifier->op = kIROp_Less; - break; - case BuiltinRequirementKind::LessThanOrEquals: - intrinsicOpModifier->op = kIROp_Leq; - break; - case BuiltinRequirementKind::InitLogicalFromInt: - intrinsicOpModifier->op = kIROp_IntCast; - break; - default: - SLANG_UNEXPECTED("unknown builtin requirement kind."); - } - synFunc->loc = context->parentDecl->closingSourceLoc; - synFunc->nameAndLoc.loc = synFunc->loc; - context->parentDecl->members.add(synFunc); - context->parentDecl->invalidateMemberDictionary(); - addModifier(synFunc, intrinsicOpModifier); - witnessTable->add(funcDeclRef.getDecl(), RequirementWitness(m_astBuilder->getDirectDeclRef(synFunc))); - return true; - } + return synth.emitAssignStmt(leftValue, synth.emitInvokeExpr(callee, _Move(args))); +} - bool SemanticsVisitor::trySynthesizeDifferentialMethodRequirementWitness( - ConformanceCheckingContext* context, - DeclRef requirementDeclRef, - RefPtr witnessTable, - SynthesisPattern pattern) - { - // We support two cases of synthesis here. - // Case 1 is that there the associated Differential type is defined to be `DifferentialBottom`. - // In this case we just trivially return `DifferentialBottom` in all synthesized methods. - // Case 2 is that the `Differential` type contains members corresponding to each primal member. - // We will apply a general code synthesis pattern to reflect that structure. - // For requirement of the form: - // ``` - // static TResult requiredMethod(TParam1 p0, TParam2 p1, ...) - // ``` - // Where TResult,TParam1, TParam2 is either `This` or `Differential`, - // We synthesize a memberwise dispatch to compute each field of `TResult`. - // Multiple patterns are supported (see SemanticsVisitor::SynthesisPattern for a full list) - // For AllInductive, we synthesize an implementation of the form: - // ``` - // [BackwardDifferentiable] - // static TResult requiredMethod(TParam1 p0, TParam2 p1, ...) - // { - // TResult result; - // result.member0 = decltype(result.member0).requiredMethod(p0.member0, p1.member0); - // result.member1 = decltype(result.member1).requiredMethod(p0.member1, p1.member1); - // ... - // return result; - // } - // ``` - - // First we need to make sure the associated `Differential` type requirement is satisfied. - bool hasDifferentialAssocType = false; - for (auto& existingEntry : witnessTable->getRequirementDictionary()) - { - if (auto builtinReqAttr = existingEntry.key->findModifier()) - { - if (builtinReqAttr->kind == BuiltinRequirementKind::DifferentialType && - existingEntry.value.getFlavor() != RequirementWitness::Flavor::none) - { - hasDifferentialAssocType = true; - } - } - } - if (!hasDifferentialAssocType) - return false; +bool SemanticsVisitor::trySynthesizeEnumTypeMethodRequirementWitness( + ConformanceCheckingContext* context, + DeclRef funcDeclRef, + RefPtr witnessTable, + BuiltinRequirementKind requirementKind) +{ + List synArgs; + ThisExpr* synThis = nullptr; + auto synFunc = + synthesizeMethodSignatureForRequirementWitness(context, funcDeclRef, synArgs, synThis); + auto intrinsicOpModifier = getASTBuilder()->create(); + switch (requirementKind) + { + case BuiltinRequirementKind::And: intrinsicOpModifier->op = kIROp_And; break; + case BuiltinRequirementKind::Or: intrinsicOpModifier->op = kIROp_Or; break; + case BuiltinRequirementKind::Not: intrinsicOpModifier->op = kIROp_Not; break; + case BuiltinRequirementKind::BitAnd: intrinsicOpModifier->op = kIROp_BitAnd; break; + case BuiltinRequirementKind::BitNot: intrinsicOpModifier->op = kIROp_BitNot; break; + case BuiltinRequirementKind::BitOr: intrinsicOpModifier->op = kIROp_BitOr; break; + case BuiltinRequirementKind::BitXor: intrinsicOpModifier->op = kIROp_BitXor; break; + case BuiltinRequirementKind::Shl: intrinsicOpModifier->op = kIROp_Lsh; break; + case BuiltinRequirementKind::Shr: intrinsicOpModifier->op = kIROp_Rsh; break; + case BuiltinRequirementKind::Equals: intrinsicOpModifier->op = kIROp_Eql; break; + case BuiltinRequirementKind::LessThan: intrinsicOpModifier->op = kIROp_Less; break; + case BuiltinRequirementKind::LessThanOrEquals: intrinsicOpModifier->op = kIROp_Leq; break; + case BuiltinRequirementKind::InitLogicalFromInt: intrinsicOpModifier->op = kIROp_IntCast; break; + default: SLANG_UNEXPECTED("unknown builtin requirement kind."); + } + synFunc->loc = context->parentDecl->closingSourceLoc; + synFunc->nameAndLoc.loc = synFunc->loc; + context->parentDecl->members.add(synFunc); + context->parentDecl->invalidateMemberDictionary(); + addModifier(synFunc, intrinsicOpModifier); + witnessTable->add( + funcDeclRef.getDecl(), + RequirementWitness(m_astBuilder->getDirectDeclRef(synFunc))); + return true; +} - ASTSynthesizer synth(m_astBuilder, getNamePool()); - List synArgs; - List synGenericArgs; - ThisExpr* synThis = nullptr; - FuncDecl* synFunc = nullptr; - GenericDecl* synGeneric = nullptr; +bool SemanticsVisitor::trySynthesizeDifferentialMethodRequirementWitness( + ConformanceCheckingContext* context, + DeclRef requirementDeclRef, + RefPtr witnessTable, + SynthesisPattern pattern) +{ + // We support two cases of synthesis here. + // Case 1 is that there the associated Differential type is defined to be `DifferentialBottom`. + // In this case we just trivially return `DifferentialBottom` in all synthesized methods. + // Case 2 is that the `Differential` type contains members corresponding to each primal member. + // We will apply a general code synthesis pattern to reflect that structure. + // For requirement of the form: + // ``` + // static TResult requiredMethod(TParam1 p0, TParam2 p1, ...) + // ``` + // Where TResult,TParam1, TParam2 is either `This` or `Differential`, + // We synthesize a memberwise dispatch to compute each field of `TResult`. + // Multiple patterns are supported (see SemanticsVisitor::SynthesisPattern for a full list) + // For AllInductive, we synthesize an implementation of the form: + // ``` + // [BackwardDifferentiable] + // static TResult requiredMethod(TParam1 p0, TParam2 p1, ...) + // { + // TResult result; + // result.member0 = decltype(result.member0).requiredMethod(p0.member0, p1.member0); + // result.member1 = decltype(result.member1).requiredMethod(p0.member1, p1.member1); + // ... + // return result; + // } + // ``` + + // First we need to make sure the associated `Differential` type requirement is satisfied. + bool hasDifferentialAssocType = false; + for (auto& existingEntry : witnessTable->getRequirementDictionary()) + { + if (auto builtinReqAttr = existingEntry.key->findModifier()) + { + if (builtinReqAttr->kind == BuiltinRequirementKind::DifferentialType && + existingEntry.value.getFlavor() != RequirementWitness::Flavor::none) + { + hasDifferentialAssocType = true; + } + } + } + if (!hasDifferentialAssocType) + return false; - if (auto genericDeclRef = requirementDeclRef.as()) - { - synGeneric = synthesizeGenericSignatureForRequirementWitness( - context, genericDeclRef, synArgs, synGenericArgs, synThis); - synFunc = as(synGeneric->inner); - } - else if (auto funcDeclRef = requirementDeclRef.as()) - { - synFunc = as(synthesizeMethodSignatureForRequirementWitness( - context, funcDeclRef, synArgs, synThis)); - } + ASTSynthesizer synth(m_astBuilder, getNamePool()); + List synArgs; + List synGenericArgs; + ThisExpr* synThis = nullptr; + FuncDecl* synFunc = nullptr; + GenericDecl* synGeneric = nullptr; - SLANG_ASSERT(synFunc); + if (auto genericDeclRef = requirementDeclRef.as()) + { + synGeneric = synthesizeGenericSignatureForRequirementWitness( + context, + genericDeclRef, + synArgs, + synGenericArgs, + synThis); + synFunc = as(synGeneric->inner); + } + else if (auto funcDeclRef = requirementDeclRef.as()) + { + synFunc = as( + synthesizeMethodSignatureForRequirementWitness(context, funcDeclRef, synArgs, synThis)); + } - addModifier(synFunc, m_astBuilder->create()); + SLANG_ASSERT(synFunc); + addModifier(synFunc, m_astBuilder->create()); - synth.pushContainerScope(synFunc); - auto blockStmt = m_astBuilder->create(); - synFunc->body = blockStmt; - auto seqStmt = synth.pushSeqStmtScope(); - blockStmt->body = seqStmt; - // Create a variable for return value. - synth.pushVarScope(); - auto varStmt = synth.emitVarDeclStmt(synFunc->returnType.type, getName("result")); - auto resultVarExpr = synth.emitVarExpr(varStmt, synFunc->returnType.type); + synth.pushContainerScope(synFunc); + auto blockStmt = m_astBuilder->create(); + synFunc->body = blockStmt; + auto seqStmt = synth.pushSeqStmtScope(); + blockStmt->body = seqStmt; - for (auto member : context->parentDecl->members) - { - auto derivativeAttr = member->findModifier(); - if (!derivativeAttr) - continue; + // Create a variable for return value. + synth.pushVarScope(); + auto varStmt = synth.emitVarDeclStmt(synFunc->returnType.type, getName("result")); + auto resultVarExpr = synth.emitVarExpr(varStmt, synFunc->returnType.type); - auto varMember = as(member); - if (!varMember) - continue; - ensureDecl(varMember, DeclCheckState::ReadyForReference); - auto memberType = varMember->getType(); - auto diffMemberType = tryGetDifferentialType(m_astBuilder, memberType); - if (!diffMemberType) - continue; + for (auto member : context->parentDecl->members) + { + auto derivativeAttr = member->findModifier(); + if (!derivativeAttr) + continue; - // Pull up the derivative member name from the attribute - auto derivMemberName = derivativeAttr->memberDeclRef->declRef.getName(); + auto varMember = as(member); + if (!varMember) + continue; + ensureDecl(varMember, DeclCheckState::ReadyForReference); + auto memberType = varMember->getType(); + auto diffMemberType = tryGetDifferentialType(m_astBuilder, memberType); + if (!diffMemberType) + continue; - // Construct reference exprs to the member's corresponding fields in each parameter. - List paramFields; - List inductiveArgMask; + // Pull up the derivative member name from the attribute + auto derivMemberName = derivativeAttr->memberDeclRef->declRef.getName(); - switch (pattern) - { - case SynthesisPattern::AllInductive: + // Construct reference exprs to the member's corresponding fields in each parameter. + List paramFields; + List inductiveArgMask; + + switch (pattern) + { + case SynthesisPattern::AllInductive: { for (auto arg : synArgs) { @@ -6050,7 +6131,7 @@ namespace Slang } break; } - case SynthesisPattern::FixedFirstArg: + case SynthesisPattern::FixedFirstArg: { int paramIndex = 0; for (auto arg : synArgs) @@ -6076,14 +6157,12 @@ namespace Slang } break; } - default: - SLANG_UNIMPLEMENTED_X("unhandled synthesis pattern"); - break; - } + default: SLANG_UNIMPLEMENTED_X("unhandled synthesis pattern"); break; + } - // Invoke the method for the field and assign the value to resultVar. - auto leftVal = synth.emitMemberExpr(resultVarExpr, derivMemberName); - if (!_synthesizeMemberAssignMemberHelper( + // Invoke the method for the field and assign the value to resultVar. + auto leftVal = synth.emitMemberExpr(resultVarExpr, derivMemberName); + if (!_synthesizeMemberAssignMemberHelper( synth, requirementDeclRef.getName(), memberType, @@ -6091,2650 +6170,2860 @@ namespace Slang _Move(paramFields), _Move(synGenericArgs), _Move(inductiveArgMask))) - return false; - } + return false; + } - // TODO: synthesize assignments for inherited members here. + // TODO: synthesize assignments for inherited members here. - auto synReturn = m_astBuilder->create(); - synReturn->expression = resultVarExpr; - seqStmt->stmts.add(synReturn); + auto synReturn = m_astBuilder->create(); + synReturn->expression = resultVarExpr; + seqStmt->stmts.add(synReturn); - Decl* witnessDecl = synGeneric ? (Decl*)synGeneric : synFunc; - context->parentDecl->members.add(witnessDecl); - context->parentDecl->invalidateMemberDictionary(); - addModifier(synFunc, m_astBuilder->create()); - - // If `This` is nested inside a generic, we need to form a complete declref type to the - // newly synthesized method here in order to fill into the witness table. - // This can be done by obtaining the ThisType witness from requirementDeclRef to get the - // generic substitution for outer generic parameters, and apply it here. - SubstitutionSet substSet; - if (auto thisTypeWitness = findThisTypeWitness( + Decl* witnessDecl = synGeneric ? (Decl*)synGeneric : synFunc; + context->parentDecl->members.add(witnessDecl); + context->parentDecl->invalidateMemberDictionary(); + addModifier(synFunc, m_astBuilder->create()); + + // If `This` is nested inside a generic, we need to form a complete declref type to the + // newly synthesized method here in order to fill into the witness table. + // This can be done by obtaining the ThisType witness from requirementDeclRef to get the + // generic substitution for outer generic parameters, and apply it here. + SubstitutionSet substSet; + if (auto thisTypeWitness = findThisTypeWitness( SubstitutionSet(requirementDeclRef), as(requirementDeclRef.getDecl()->parentDecl))) + { + if (auto declRefType = as(thisTypeWitness->getSub())) { - if (auto declRefType = as(thisTypeWitness->getSub())) - { - substSet = SubstitutionSet(declRefType->getDeclRef()); - } - } - if (!substSet.declRef) - return false; - DeclRef synthesizedWitnessDeclRef; - if (auto parentExtDecl = as(context->parentDecl)) - { - // If the conformance is declared on an extension to ThisType, - // we need to form a new proper decl ref to the parent extension decl - // with the correct specialization arguments. - // - if (GetOuterGeneric(context->parentDecl)) - { - - auto extDeclRef = applyExtensionToType(parentExtDecl, context->conformingType); - synthesizedWitnessDeclRef = m_astBuilder->getMemberDeclRef(extDeclRef, witnessDecl); - } - } - else - { - synthesizedWitnessDeclRef = m_astBuilder->getMemberDeclRef(substSet.declRef, witnessDecl); + substSet = SubstitutionSet(declRefType->getDeclRef()); } - if (!synthesizedWitnessDeclRef) - synthesizedWitnessDeclRef = m_astBuilder->getDirectDeclRef(witnessDecl); - witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(synthesizedWitnessDeclRef)); - return true; } - - bool SemanticsVisitor::findWitnessForInterfaceRequirement( - ConformanceCheckingContext* context, - Type* subType, - Type* superInterfaceType, - InheritanceDecl* inheritanceDecl, - DeclRef superInterfaceDeclRef, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable, - SubtypeWitness* subTypeConformsToSuperInterfaceWitness) + if (!substSet.declRef) + return false; + DeclRef synthesizedWitnessDeclRef; + if (auto parentExtDecl = as(context->parentDecl)) { - SLANG_UNUSED(superInterfaceDeclRef) - - // The goal of this function is to find a suitable - // value to satisfy the requirement. - // - // The 99% case is that the requirement is a named member - // of the interface, and we need to search for a member - // with the same name in the type declaration and - // its (known) extensions. - - // The exception to that is when the requiredMemberDeclRef is already - // resolved to the actual satisfying decl, in which case we simply return - // true without any further lookup. - if (!as(requiredMemberDeclRef.getParent().getDecl())) - return true; - - // If `requiredMemberDeclRef` is a lookup decl ref for an interface requirement - // we attempt to do the loopkup through witness tables. - // - // As a first pass, lets check if we already have a - // witness in the table for the requirement, so - // that we can bail out early. + // If the conformance is declared on an extension to ThisType, + // we need to form a new proper decl ref to the parent extension decl + // with the correct specialization arguments. // - if(witnessTable->getRequirementDictionary().containsKey(requiredMemberDeclRef.getDecl())) + if (GetOuterGeneric(context->parentDecl)) { - return true; - } - // The ThisType requirement is always satisfied. - if (as(requiredMemberDeclRef.getDecl())) - { - return true; + auto extDeclRef = applyExtensionToType(parentExtDecl, context->conformingType); + synthesizedWitnessDeclRef = m_astBuilder->getMemberDeclRef(extDeclRef, witnessDecl); } - - // An important exception to the above is that an - // inheritance declaration in the interface is not going - // to be satisfied by an inheritance declaration in the - // conforming type, but rather by a full "witness table" - // full of the satisfying values for each requirement - // in the inherited-from interface. - // - if( auto requiredInheritanceDeclRef = requiredMemberDeclRef.as() ) - { - // Recursively check that the type conforms - // to the inherited interface. - // - // TODO: we *really* need a linearization step here!!!! - - auto reqType = getBaseType(m_astBuilder, requiredInheritanceDeclRef); - - auto interfaceIsReqWitness = - m_astBuilder->getDeclaredSubtypeWitness( - superInterfaceType, - reqType, - requiredInheritanceDeclRef); - // ... - - auto subIsReqWitness = m_astBuilder->getTransitiveSubtypeWitness( - subTypeConformsToSuperInterfaceWitness, - interfaceIsReqWitness); - // ... - - RefPtr satisfyingWitnessTable = new WitnessTable(); - satisfyingWitnessTable->witnessedType = subType; - satisfyingWitnessTable->baseType = reqType; - - witnessTable->add( - requiredInheritanceDeclRef.getDecl(), - RequirementWitness(satisfyingWitnessTable)); - - if( !checkConformanceToType( - context, - subType, - requiredInheritanceDeclRef.getDecl(), - reqType, - subIsReqWitness, - satisfyingWitnessTable) ) - { - return false; - } - - return true; - } - - // We will look up members with the same name, - // since only same-name members will be able to - // satisfy the requirement. - // - Name* name = requiredMemberDeclRef.getName(); - - // We start by looking up members of the same - // name, on the type that is claiming to conform. - // - // This lookup step could include members that - // we might not actually want to consider: - // - // * Lookup through a type `Foo` where `Foo : IBar` - // will be able to find members of `IBar`, which - // somewhat obviously shouldn't apply when - // determining if `Foo` satisfies the requirements - // of `IBar`. - // - // * Lookup in the presence of `__transparent` members - // may produce references to declarations on a *field* - // of the type rather than the type. Conformance through - // transparent members could be supported in theory, - // but would require synthesizing proxy/forwarding - // implementations in the type itself. - // - // For the first issue, we will use a flag to influence - // lookup so that it doesn't include results looked up - // through interface inheritance clauses (but it *will* - // look up result through inheritance clauses corresponding - // to concrete types). - // - // The second issue of members that require us to proxy/forward - // requests will be handled further down. For now we include - // lookup results that might be usable, but not as-is. - // - LookupResult lookupResult; - if (!isWrapperTypeDecl(context->parentDecl)) - { - lookupResult = lookUpMember(m_astBuilder, this, name, subType, nullptr, LookupMask::Default, LookupOptions::IgnoreBaseInterfaces); - - if (!lookupResult.isValid()) - { - // If we failed to look up a member with the name of the - // requirement, it may be possible that we can still synthesis the - // implementation if this is one of the known builtin requirements. - // Otherwise, report diagnostic now. - - if (requiredMemberDeclRef.getDecl()->hasModifier() || - (requiredMemberDeclRef.as() && - getInner(requiredMemberDeclRef.as())->hasModifier())) - { - } - else if (requiredMemberDeclRef.as() && - (as(context->conformingType) || - as(context->conformingType) || - as(context->conformingType))) - { - } - else - { - getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, subType, requiredMemberDeclRef); - getSink()->diagnose(requiredMemberDeclRef, Diagnostics::seeDeclarationOf, requiredMemberDeclRef); - return false; - } - } - } - - // Iterate over the members and look for one that matches - // the expected signature for the requirement. - for (auto member : lookupResult) - { - // To a first approximation, any lookup result that required a "breadcrumb" - // will not be usable to directly satisfy an interface requirement, since - // each breadcrumb will amount to a manipulation of `this` that is required - // to make the declaration usable (e.g., casting to a base type). - // - if(member.breadcrumbs != nullptr) - continue; - - if (doesMemberSatisfyRequirement(member.declRef, requiredMemberDeclRef, witnessTable)) - { - // The member satisfies the requirement in every other way except that - // it may have a lower visibility than min(parentVisibility, requirementVisibilty), - // in that case we will treat it as an error. - auto minRequiredVisibility = Math::Min(getDeclVisibility(requiredMemberDeclRef.getDecl()), getTypeVisibility(subType)); - if (getDeclVisibility(member.declRef.getDecl()) < minRequiredVisibility) - { - getSink()->diagnose(member.declRef, Diagnostics::satisfyingDeclCannotHaveLowerVisibility, member.declRef); - getSink()->diagnose(requiredMemberDeclRef, Diagnostics::seeDeclarationOf, QualifiedDeclPath(requiredMemberDeclRef)); - return false; - } - return true; - } - } - - // If we reach this point then there were no members suitable - // for satisfying the interface requirement *diretly*. - // - // It is possible that one of the items in `lookupResult` could be - // used to synthesize an exact-match witness, by generating the - // code required to handle all the conversions that might be - // required on `this`. - // - // Another situation that will get us here is that we are dealing with - // a wrapper type (struct Foo:IFoo=FooImpl), and we will synthesize - // wrappers that redirects the call into the inner element. - // - if( trySynthesizeRequirementWitness(context, lookupResult, requiredMemberDeclRef, witnessTable) ) - { - return true; - } - - // We failed to find a member of the type that can be used - // to satisfy the requirement (even via synthesis), so we - // need to report the failure to the user. - // - // TODO: Eventually we might want something akin to the current - // overload resolution logic, where we keep track of a list - // of "candidates" for satisfaction of the requirement, - // and if nothing is found we print the candidates that made it - // furthest in checking. - // - if (!lookupResult.isOverloaded() && lookupResult.isValid()) - { - getSink()->diagnose(lookupResult.item.declRef, Diagnostics::memberDoesNotMatchRequirementSignature, lookupResult.item.declRef); - } - else - { - getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, subType, requiredMemberDeclRef); - } - getSink()->diagnose(requiredMemberDeclRef, Diagnostics::seeDeclarationOfInterfaceRequirement, requiredMemberDeclRef); - return false; } - - RefPtr SemanticsVisitor::checkInterfaceConformance( - ConformanceCheckingContext* context, - Type* subType, - Type* superInterfaceType, - InheritanceDecl* inheritanceDecl, - DeclRef superInterfaceDeclRef, - SubtypeWitness* subTypeConformsToSuperInterfaceWitnes) + else { - // Has somebody already checked this conformance, - // and/or is in the middle of checking it? - RefPtr witnessTable; - if(context->mapInterfaceToWitnessTable.tryGetValue(superInterfaceDeclRef, witnessTable)) - return witnessTable; - - // We need to check the declaration of the interface - // before we can check that we conform to it. - // - ensureDecl(superInterfaceDeclRef, DeclCheckState::CanReadInterfaceRequirements); - - // We will construct the witness table, and register it - // *before* we go about checking fine-grained requirements, - // in order to short-circuit any potential for infinite recursion. - - // Note: we will re-use the witnes table attached to the inheritance decl, - // if there is one. This catches cases where semantic checking might - // have synthesized some of the conformance witnesses for us. - // - witnessTable = inheritanceDecl->witnessTable; - if(!witnessTable) - { - witnessTable = new WitnessTable(); - witnessTable->baseType = DeclRefType::create(m_astBuilder, superInterfaceDeclRef); - witnessTable->witnessedType = subType; - } - context->mapInterfaceToWitnessTable.add(superInterfaceDeclRef, witnessTable); + synthesizedWitnessDeclRef = m_astBuilder->getMemberDeclRef(substSet.declRef, witnessDecl); + } + if (!synthesizedWitnessDeclRef) + synthesizedWitnessDeclRef = m_astBuilder->getDirectDeclRef(witnessDecl); + witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(synthesizedWitnessDeclRef)); + return true; +} - if(!checkInterfaceConformance(context, subType, superInterfaceType, inheritanceDecl, superInterfaceDeclRef, subTypeConformsToSuperInterfaceWitnes, witnessTable)) - return nullptr; +bool SemanticsVisitor::findWitnessForInterfaceRequirement( + ConformanceCheckingContext* context, + Type* subType, + Type* superInterfaceType, + InheritanceDecl* inheritanceDecl, + DeclRef superInterfaceDeclRef, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable, + SubtypeWitness* subTypeConformsToSuperInterfaceWitness) +{ + SLANG_UNUSED(superInterfaceDeclRef) + + // The goal of this function is to find a suitable + // value to satisfy the requirement. + // + // The 99% case is that the requirement is a named member + // of the interface, and we need to search for a member + // with the same name in the type declaration and + // its (known) extensions. + + // The exception to that is when the requiredMemberDeclRef is already + // resolved to the actual satisfying decl, in which case we simply return + // true without any further lookup. + if (!as(requiredMemberDeclRef.getParent().getDecl())) + return true; - return witnessTable; + // If `requiredMemberDeclRef` is a lookup decl ref for an interface requirement + // we attempt to do the loopkup through witness tables. + // + // As a first pass, lets check if we already have a + // witness in the table for the requirement, so + // that we can bail out early. + // + if (witnessTable->getRequirementDictionary().containsKey(requiredMemberDeclRef.getDecl())) + { + return true; } - static bool isAssociatedTypeDecl(Decl* decl) + // The ThisType requirement is always satisfied. + if (as(requiredMemberDeclRef.getDecl())) { - auto d = decl; - while(auto genericDecl = as(d)) - d = genericDecl->inner; - if(as(d)) - return true; - return false; + return true; } - bool SemanticsVisitor::checkInterfaceConformance( - ConformanceCheckingContext* context, - Type* subType, - Type* superInterfaceType, - InheritanceDecl* inheritanceDecl, - DeclRef superInterfaceDeclRef, - SubtypeWitness* subTypeConformsToSuperInterfaceWitness, - WitnessTable* witnessTable) + // An important exception to the above is that an + // inheritance declaration in the interface is not going + // to be satisfied by an inheritance declaration in the + // conforming type, but rather by a full "witness table" + // full of the satisfying values for each requirement + // in the inherited-from interface. + // + if (auto requiredInheritanceDeclRef = requiredMemberDeclRef.as()) { - // We need to check the declaration of the interface - // before we can check that we conform to it. + // Recursively check that the type conforms + // to the inherited interface. // - ensureDecl(superInterfaceDeclRef, DeclCheckState::CanReadInterfaceRequirements); + // TODO: we *really* need a linearization step here!!!! - // When comparing things like signatures, we need to do so in the context - // of a LookupDeclRef that aligns the signatures in the interface - // with those in the concrete type. For example, we need to treat any uses - // of `This` in the interface as equivalent to the concrete type for the - // purpose of signature matching (and similarly for associated types). - // - auto thisTypeDeclRef = m_astBuilder->getLookupDeclRef( - subTypeConformsToSuperInterfaceWitness, superInterfaceDeclRef.getDecl()->getThisTypeDecl()); + auto reqType = getBaseType(m_astBuilder, requiredInheritanceDeclRef); - bool result = true; + auto interfaceIsReqWitness = m_astBuilder->getDeclaredSubtypeWitness( + superInterfaceType, + reqType, + requiredInheritanceDeclRef); + // ... - // TODO: If we ever allow for implementation inheritance, - // then we will need to consider the case where a type - // declares that it conforms to an interface, but one of - // its (non-interface) base types already conforms to - // that interface, so that all of the requirements are - // already satisfied with inherited implementations... + auto subIsReqWitness = m_astBuilder->getTransitiveSubtypeWitness( + subTypeConformsToSuperInterfaceWitness, + interfaceIsReqWitness); + // ... - // Note: we break this logic into two loops, where we first - // check conformance for all associated-type requirements - // and *then* check conformance for all other requirements. - // - // Checking associated-type requirements first ensures that - // we can make use of the identity of the associated types - // when checking other members. - // - // TODO: There could in theory be subtle cases involving - // circular or recursive dependency chains that make such - // a simple ordering impractical (e.g., associated type `A` - // is constrained to `IThing` where `IThing` requires - // that `T : IOtherThing where T.B == int` for another associated - // type `B`). - // - // The only robust solution long-term is probably to treat this - // as a type-inference problem by creating type variables to - // stand in for the associated-type requirements and then to discover - // constraints and solve for those type variables as part of the - // conformance-checking process. - // - for(auto requiredMemberDecl : getMembers(m_astBuilder, superInterfaceDeclRef)) - { - if(!isAssociatedTypeDecl(requiredMemberDecl.getDecl())) - continue; - auto requiredMemberDeclRef = m_astBuilder->getLookupDeclRef(subTypeConformsToSuperInterfaceWitness, requiredMemberDecl.getDecl()); - auto requirementSatisfied = findWitnessForInterfaceRequirement( - context, - subType, - superInterfaceType, - inheritanceDecl, - superInterfaceDeclRef, - requiredMemberDeclRef, - witnessTable, - subTypeConformsToSuperInterfaceWitness); + RefPtr satisfyingWitnessTable = new WitnessTable(); + satisfyingWitnessTable->witnessedType = subType; + satisfyingWitnessTable->baseType = reqType; - result = result && requirementSatisfied; - } - for(auto requiredMemberDecl : getMembers(m_astBuilder, superInterfaceDeclRef)) - { - if(isAssociatedTypeDecl(requiredMemberDecl.getDecl())) - continue; - if (requiredMemberDecl.as()) - continue; - auto requiredMemberDeclRef = m_astBuilder->getLookupDeclRef(subTypeConformsToSuperInterfaceWitness, requiredMemberDecl.getDecl()); - auto requirementSatisfied = findWitnessForInterfaceRequirement( + witnessTable->add( + requiredInheritanceDeclRef.getDecl(), + RequirementWitness(satisfyingWitnessTable)); + + if (!checkConformanceToType( context, subType, - superInterfaceType, - inheritanceDecl, - superInterfaceDeclRef, - requiredMemberDeclRef, - witnessTable, - subTypeConformsToSuperInterfaceWitness); - - result = result && requirementSatisfied; - } - - // Extensions that apply to the interface type can create new conformances - // for the concrete types that inherit from the interface. - // - // These new conformances should not be able to introduce new *requirements* - // for an implementing interface (although they currently can), but we - // still need to go through this logic to find the appropriate value - // that will satisfy the requirement in these cases, and also to put - // the required entry into the witness table for the interface itself. - // - // TODO: This logic is a bit slippery, and we need to figure out what - // it means in the context of separate compilation. If module A defines - // an interface IA, module B defines a type C that conforms to IA, and then - // module C defines an extension that makes IA conform to IC, then it is - // unreasonable to expect the {B:IA} witness table to contain an entry - // corresponding to {IA:IC}. - // - // The simple answer then would be that the {IA:IC} conformance should be - // fixed, with a single witness table for {IA:IC}, but then what should - // happen in B explicitly conformed to IC already? - // - // For now we will just walk through the extensions that are known at - // the time we are compiling and handle those, and punt on the larger issue - // for a bit longer. - // - for(auto candidateExt : getCandidateExtensions(superInterfaceDeclRef, this)) + requiredInheritanceDeclRef.getDecl(), + reqType, + subIsReqWitness, + satisfyingWitnessTable)) { - // We need to apply the extension to the interface type that our - // concrete type is inheriting from. - // - Type* targetType = DeclRefType::create(m_astBuilder, thisTypeDeclRef); - auto parentDeclRef = applyExtensionToType(candidateExt, targetType); - if(!parentDeclRef) - continue; - - // Only inheritance clauses from the extension matter right now. - for(auto requiredInheritanceDecl : getMembersOfType(m_astBuilder, candidateExt)) - { - auto requiredInheritanceDeclRef = m_astBuilder->getLookupDeclRef( - subTypeConformsToSuperInterfaceWitness, requiredInheritanceDecl.getDecl()); - auto requirementSatisfied = findWitnessForInterfaceRequirement( - context, - subType, - superInterfaceType, - inheritanceDecl, - superInterfaceDeclRef, - requiredInheritanceDeclRef, - witnessTable, - subTypeConformsToSuperInterfaceWitness); - - result = result && requirementSatisfied; - } + return false; } - // The conformance was satisfied if all the requirements were satisfied. - // - return result; + return true; } - bool SemanticsVisitor::checkConformanceToType( - ConformanceCheckingContext* context, - Type* subType, - InheritanceDecl* inheritanceDecl, - Type* superType, - SubtypeWitness* subIsSuperWitness, - WitnessTable* witnessTable) - { - if (witnessTable->isExtern) - return true; + // We will look up members with the same name, + // since only same-name members will be able to + // satisfy the requirement. + // + Name* name = requiredMemberDeclRef.getName(); + + // We start by looking up members of the same + // name, on the type that is claiming to conform. + // + // This lookup step could include members that + // we might not actually want to consider: + // + // * Lookup through a type `Foo` where `Foo : IBar` + // will be able to find members of `IBar`, which + // somewhat obviously shouldn't apply when + // determining if `Foo` satisfies the requirements + // of `IBar`. + // + // * Lookup in the presence of `__transparent` members + // may produce references to declarations on a *field* + // of the type rather than the type. Conformance through + // transparent members could be supported in theory, + // but would require synthesizing proxy/forwarding + // implementations in the type itself. + // + // For the first issue, we will use a flag to influence + // lookup so that it doesn't include results looked up + // through interface inheritance clauses (but it *will* + // look up result through inheritance clauses corresponding + // to concrete types). + // + // The second issue of members that require us to proxy/forward + // requests will be handled further down. For now we include + // lookup results that might be usable, but not as-is. + // + LookupResult lookupResult; + if (!isWrapperTypeDecl(context->parentDecl)) + { + lookupResult = lookUpMember( + m_astBuilder, + this, + name, + subType, + nullptr, + LookupMask::Default, + LookupOptions::IgnoreBaseInterfaces); - if (auto supereclRefType = as(superType)) + if (!lookupResult.isValid()) { - auto superTypeDeclRef = supereclRefType->getDeclRef(); - if (auto superInterfaceDeclRef = superTypeDeclRef.as()) + // If we failed to look up a member with the name of the + // requirement, it may be possible that we can still synthesis the + // implementation if this is one of the known builtin requirements. + // Otherwise, report diagnostic now. + + if (requiredMemberDeclRef.getDecl()->hasModifier() || + (requiredMemberDeclRef.as() && + getInner(requiredMemberDeclRef.as()) + ->hasModifier())) { - // The type is stating that it conforms to an interface. - // We need to check that it provides all of the members - // required by that interface. - return checkInterfaceConformance( - context, - subType, - superType, - inheritanceDecl, - superInterfaceDeclRef, - subIsSuperWitness, - witnessTable); } - else if( auto superStructDeclRef = superTypeDeclRef.as() ) + else if ( + requiredMemberDeclRef.as() && + (as(context->conformingType) || + as(context->conformingType) || + as(context->conformingType))) { - // The type is saying it inherits from a `struct`, - // which doesn't require any checking at present - return true; + } + else + { + getSink()->diagnose( + inheritanceDecl, + Diagnostics::typeDoesntImplementInterfaceRequirement, + subType, + requiredMemberDeclRef); + getSink()->diagnose( + requiredMemberDeclRef, + Diagnostics::seeDeclarationOf, + requiredMemberDeclRef); + return false; } } - if (!as(superType)) - { - getSink()->diagnose( - inheritanceDecl, - Diagnostics::invalidTypeForInheritance, - superType); - } - return false; } - static bool _doesTypeDeclHaveDefinition(ContainerDecl* decl) - { - if (auto aggTypeDecl = as(decl)) - return aggTypeDecl->hasBody; - return false; - } - - bool SemanticsVisitor::checkConformance( - Type* subType, - InheritanceDecl* inheritanceDecl, - ContainerDecl* parentDecl) - { - auto superType = inheritanceDecl->base.type; - - if( auto declRefType = as(subType) ) - { - auto declRef = declRefType->getDeclRef(); - - if (auto superDeclRefType = as(superType)) - { - auto superTypeDecl = superDeclRefType->getDeclRef().getDecl(); - if (superTypeDecl->findModifier()) - { - // A struct cannot implement a COM Interface. - if (auto classDecl = as(superTypeDecl)) - { - // OK. - SLANG_UNUSED(classDecl); - } - else if (auto subInterfaceDecl = as(superTypeDecl)) - { - if (!subInterfaceDecl->findModifier()) - { - getSink()->diagnose(inheritanceDecl, Diagnostics::interfaceInheritingComMustBeCom); - } - } - else if (const auto structDecl = as(superTypeDecl)) - { - getSink()->diagnose(inheritanceDecl, Diagnostics::structCannotImplementComInterface); - } - } - } + // Iterate over the members and look for one that matches + // the expected signature for the requirement. + for (auto member : lookupResult) + { + // To a first approximation, any lookup result that required a "breadcrumb" + // will not be usable to directly satisfy an interface requirement, since + // each breadcrumb will amount to a manipulation of `this` that is required + // to make the declaration usable (e.g., casting to a base type). + // + if (member.breadcrumbs != nullptr) + continue; - // Don't check conformances for abstract types that - // are being used to express *required* conformances. - if (auto assocTypeDeclRef = declRef.as()) - { - // An associated type declaration represents a requirement - // in an outer interface declaration, and its members - // (type constraints) represent additional requirements. - return true; - } - else if (auto interfaceDeclRef = declRef.as()) + if (doesMemberSatisfyRequirement(member.declRef, requiredMemberDeclRef, witnessTable)) + { + // The member satisfies the requirement in every other way except that + // it may have a lower visibility than min(parentVisibility, requirementVisibilty), + // in that case we will treat it as an error. + auto minRequiredVisibility = Math::Min( + getDeclVisibility(requiredMemberDeclRef.getDecl()), + getTypeVisibility(subType)); + if (getDeclVisibility(member.declRef.getDecl()) < minRequiredVisibility) { - // HACK: Our semantics as they stand today are that an - // `extension` of an interface that adds a new inheritance - // clause acts *as if* that inheritnace clause had been - // attached to the original `interface` decl: that is, - // it adds additional requirements. - // - // This is *not* a reasonable semantic to keep long-term, - // but it is required for some of our current example - // code to work. - return true; + getSink()->diagnose( + member.declRef, + Diagnostics::satisfyingDeclCannotHaveLowerVisibility, + member.declRef); + getSink()->diagnose( + requiredMemberDeclRef, + Diagnostics::seeDeclarationOf, + QualifiedDeclPath(requiredMemberDeclRef)); + return false; } - - + return true; } + } + + // If we reach this point then there were no members suitable + // for satisfying the interface requirement *diretly*. + // + // It is possible that one of the items in `lookupResult` could be + // used to synthesize an exact-match witness, by generating the + // code required to handle all the conversions that might be + // required on `this`. + // + // Another situation that will get us here is that we are dealing with + // a wrapper type (struct Foo:IFoo=FooImpl), and we will synthesize + // wrappers that redirects the call into the inner element. + // + if (trySynthesizeRequirementWitness(context, lookupResult, requiredMemberDeclRef, witnessTable)) + { + return true; + } - // Look at the type being inherited from, and validate - // appropriately. + // We failed to find a member of the type that can be used + // to satisfy the requirement (even via synthesis), so we + // need to report the failure to the user. + // + // TODO: Eventually we might want something akin to the current + // overload resolution logic, where we keep track of a list + // of "candidates" for satisfaction of the requirement, + // and if nothing is found we print the candidates that made it + // furthest in checking. + // + if (!lookupResult.isOverloaded() && lookupResult.isValid()) + { + getSink()->diagnose( + lookupResult.item.declRef, + Diagnostics::memberDoesNotMatchRequirementSignature, + lookupResult.item.declRef); + } + else + { + getSink()->diagnose( + inheritanceDecl, + Diagnostics::typeDoesntImplementInterfaceRequirement, + subType, + requiredMemberDeclRef); + } + getSink()->diagnose( + requiredMemberDeclRef, + Diagnostics::seeDeclarationOfInterfaceRequirement, + requiredMemberDeclRef); + return false; +} - DeclaredSubtypeWitness* subIsSuperWitness = m_astBuilder->getDeclaredSubtypeWitness(subType, superType, makeDeclRef(inheritanceDecl)); +RefPtr SemanticsVisitor::checkInterfaceConformance( + ConformanceCheckingContext* context, + Type* subType, + Type* superInterfaceType, + InheritanceDecl* inheritanceDecl, + DeclRef superInterfaceDeclRef, + SubtypeWitness* subTypeConformsToSuperInterfaceWitnes) +{ + // Has somebody already checked this conformance, + // and/or is in the middle of checking it? + RefPtr witnessTable; + if (context->mapInterfaceToWitnessTable.tryGetValue(superInterfaceDeclRef, witnessTable)) + return witnessTable; - ConformanceCheckingContext context; - context.conformingType = subType; - context.parentDecl = parentDecl; + // We need to check the declaration of the interface + // before we can check that we conform to it. + // + ensureDecl(superInterfaceDeclRef, DeclCheckState::CanReadInterfaceRequirements); + // We will construct the witness table, and register it + // *before* we go about checking fine-grained requirements, + // in order to short-circuit any potential for infinite recursion. - RefPtr witnessTable = inheritanceDecl->witnessTable; - if(!witnessTable) - { - witnessTable = new WitnessTable(); - witnessTable->baseType = superType; - witnessTable->witnessedType = subType; - witnessTable->isExtern = (!_doesTypeDeclHaveDefinition(parentDecl) - && parentDecl->hasModifier()); - inheritanceDecl->witnessTable = witnessTable; - } + // Note: we will re-use the witnes table attached to the inheritance decl, + // if there is one. This catches cases where semantic checking might + // have synthesized some of the conformance witnesses for us. + // + witnessTable = inheritanceDecl->witnessTable; + if (!witnessTable) + { + witnessTable = new WitnessTable(); + witnessTable->baseType = DeclRefType::create(m_astBuilder, superInterfaceDeclRef); + witnessTable->witnessedType = subType; + } + context->mapInterfaceToWitnessTable.add(superInterfaceDeclRef, witnessTable); - if( !checkConformanceToType(&context, subType, inheritanceDecl, superType, subIsSuperWitness, witnessTable) ) - { - return false; - } + if (!checkInterfaceConformance( + context, + subType, + superInterfaceType, + inheritanceDecl, + superInterfaceDeclRef, + subTypeConformsToSuperInterfaceWitnes, + witnessTable)) + return nullptr; + + return witnessTable; +} +static bool isAssociatedTypeDecl(Decl* decl) +{ + auto d = decl; + while (auto genericDecl = as(d)) + d = genericDecl->inner; + if (as(d)) return true; - } + return false; +} - void SemanticsVisitor::checkExtensionConformance(ExtensionDecl* decl) - { - auto declRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(decl)).as(); - auto targetType = getTargetType(m_astBuilder, declRef); +bool SemanticsVisitor::checkInterfaceConformance( + ConformanceCheckingContext* context, + Type* subType, + Type* superInterfaceType, + InheritanceDecl* inheritanceDecl, + DeclRef superInterfaceDeclRef, + SubtypeWitness* subTypeConformsToSuperInterfaceWitness, + WitnessTable* witnessTable) +{ + // We need to check the declaration of the interface + // before we can check that we conform to it. + // + ensureDecl(superInterfaceDeclRef, DeclCheckState::CanReadInterfaceRequirements); + + // When comparing things like signatures, we need to do so in the context + // of a LookupDeclRef that aligns the signatures in the interface + // with those in the concrete type. For example, we need to treat any uses + // of `This` in the interface as equivalent to the concrete type for the + // purpose of signature matching (and similarly for associated types). + // + auto thisTypeDeclRef = m_astBuilder->getLookupDeclRef( + subTypeConformsToSuperInterfaceWitness, + superInterfaceDeclRef.getDecl()->getThisTypeDecl()); + + bool result = true; + + // TODO: If we ever allow for implementation inheritance, + // then we will need to consider the case where a type + // declares that it conforms to an interface, but one of + // its (non-interface) base types already conforms to + // that interface, so that all of the requirements are + // already satisfied with inherited implementations... + + // Note: we break this logic into two loops, where we first + // check conformance for all associated-type requirements + // and *then* check conformance for all other requirements. + // + // Checking associated-type requirements first ensures that + // we can make use of the identity of the associated types + // when checking other members. + // + // TODO: There could in theory be subtle cases involving + // circular or recursive dependency chains that make such + // a simple ordering impractical (e.g., associated type `A` + // is constrained to `IThing` where `IThing` requires + // that `T : IOtherThing where T.B == int` for another associated + // type `B`). + // + // The only robust solution long-term is probably to treat this + // as a type-inference problem by creating type variables to + // stand in for the associated-type requirements and then to discover + // constraints and solve for those type variables as part of the + // conformance-checking process. + // + for (auto requiredMemberDecl : getMembers(m_astBuilder, superInterfaceDeclRef)) + { + if (!isAssociatedTypeDecl(requiredMemberDecl.getDecl())) + continue; + auto requiredMemberDeclRef = m_astBuilder->getLookupDeclRef( + subTypeConformsToSuperInterfaceWitness, + requiredMemberDecl.getDecl()); + auto requirementSatisfied = findWitnessForInterfaceRequirement( + context, + subType, + superInterfaceType, + inheritanceDecl, + superInterfaceDeclRef, + requiredMemberDeclRef, + witnessTable, + subTypeConformsToSuperInterfaceWitness); - for (auto inheritanceDecl : decl->getMembersOfType()) - { - checkConformance(targetType, inheritanceDecl, decl); - } + result = result && requirementSatisfied; } - - void SemanticsVisitor::checkDifferentiableMembersInType(AggTypeDecl* decl) + for (auto requiredMemberDecl : getMembers(m_astBuilder, superInterfaceDeclRef)) { - for (auto member : decl->getMembersOfType()) - { - if (auto derivativeAttr = member->findModifier()) - { - checkDerivativeMemberAttributeReferences(member, derivativeAttr); - } + if (isAssociatedTypeDecl(requiredMemberDecl.getDecl())) + continue; + if (requiredMemberDecl.as()) + continue; + auto requiredMemberDeclRef = m_astBuilder->getLookupDeclRef( + subTypeConformsToSuperInterfaceWitness, + requiredMemberDecl.getDecl()); + auto requirementSatisfied = findWitnessForInterfaceRequirement( + context, + subType, + superInterfaceType, + inheritanceDecl, + superInterfaceDeclRef, + requiredMemberDeclRef, + witnessTable, + subTypeConformsToSuperInterfaceWitness); + + result = result && requirementSatisfied; + } + + // Extensions that apply to the interface type can create new conformances + // for the concrete types that inherit from the interface. + // + // These new conformances should not be able to introduce new *requirements* + // for an implementing interface (although they currently can), but we + // still need to go through this logic to find the appropriate value + // that will satisfy the requirement in these cases, and also to put + // the required entry into the witness table for the interface itself. + // + // TODO: This logic is a bit slippery, and we need to figure out what + // it means in the context of separate compilation. If module A defines + // an interface IA, module B defines a type C that conforms to IA, and then + // module C defines an extension that makes IA conform to IC, then it is + // unreasonable to expect the {B:IA} witness table to contain an entry + // corresponding to {IA:IC}. + // + // The simple answer then would be that the {IA:IC} conformance should be + // fixed, with a single witness table for {IA:IC}, but then what should + // happen in B explicitly conformed to IC already? + // + // For now we will just walk through the extensions that are known at + // the time we are compiling and handle those, and punt on the larger issue + // for a bit longer. + // + for (auto candidateExt : getCandidateExtensions(superInterfaceDeclRef, this)) + { + // We need to apply the extension to the interface type that our + // concrete type is inheriting from. + // + Type* targetType = DeclRefType::create(m_astBuilder, thisTypeDeclRef); + auto parentDeclRef = applyExtensionToType(candidateExt, targetType); + if (!parentDeclRef) + continue; + + // Only inheritance clauses from the extension matter right now. + for (auto requiredInheritanceDecl : + getMembersOfType(m_astBuilder, candidateExt)) + { + auto requiredInheritanceDeclRef = m_astBuilder->getLookupDeclRef( + subTypeConformsToSuperInterfaceWitness, + requiredInheritanceDecl.getDecl()); + auto requirementSatisfied = findWitnessForInterfaceRequirement( + context, + subType, + superInterfaceType, + inheritanceDecl, + superInterfaceDeclRef, + requiredInheritanceDeclRef, + witnessTable, + subTypeConformsToSuperInterfaceWitness); + + result = result && requirementSatisfied; } } - void SemanticsVisitor::checkAggTypeConformance(AggTypeDecl* decl) - { - // After we've checked members, we need to go through - // any inheritance clauses on the type itself, and - // confirm that the type actually provides whatever - // those clauses require. + // The conformance was satisfied if all the requirements were satisfied. + // + return result; +} - if (const auto interfaceDecl = as(decl)) +bool SemanticsVisitor::checkConformanceToType( + ConformanceCheckingContext* context, + Type* subType, + InheritanceDecl* inheritanceDecl, + Type* superType, + SubtypeWitness* subIsSuperWitness, + WitnessTable* witnessTable) +{ + if (witnessTable->isExtern) + return true; + + if (auto supereclRefType = as(superType)) + { + auto superTypeDeclRef = supereclRefType->getDeclRef(); + if (auto superInterfaceDeclRef = superTypeDeclRef.as()) { - // Don't check that an interface conforms to the - // things it inherits from. + // The type is stating that it conforms to an interface. + // We need to check that it provides all of the members + // required by that interface. + return checkInterfaceConformance( + context, + subType, + superType, + inheritanceDecl, + superInterfaceDeclRef, + subIsSuperWitness, + witnessTable); } - else if (const auto assocTypeDecl = as(decl)) + else if (auto superStructDeclRef = superTypeDeclRef.as()) { - // Don't check that an associated type decl conforms to the - // things it inherits from. + // The type is saying it inherits from a `struct`, + // which doesn't require any checking at present + return true; } - else - { - // For non-interface types we need to check conformance. - // - - auto astBuilder = getASTBuilder(); - - auto declRef = createDefaultSubstitutionsIfNeeded(astBuilder, this, makeDeclRef(decl)).as(); - auto type = DeclRefType::create(astBuilder, declRef); - - // TODO: Need to figure out what this should do for - // `abstract` types if we ever add them. Should they - // be required to implement all interface requirements, - // just with `abstract` methods that replicate things? - // (That's what C# does). - - // Make a copy of inhertanceDecls firstsince `checkConformance` may modify decl->members. - auto inheritanceDecls = decl->getMembersOfType().toList(); - for (auto inheritanceDecl : inheritanceDecls) - { - // Special handling for when we check for conformance against `IDifferentiable` - // We will reference-checking for the [DerivativeMember(DiffType.member)] - // attributes here, since they have to be performed after types can be referenced - // and before conformance checking, where this information can be used to synthesize - // member methods (such as `dzero`, `dadd`, etc..) - // - if (inheritanceDecl->getSup().type->equals( - astBuilder->getDifferentiableInterfaceType())) - checkDifferentiableMembersInType(decl); + } + if (!as(superType)) + { + getSink()->diagnose(inheritanceDecl, Diagnostics::invalidTypeForInheritance, superType); + } + return false; +} - checkConformance(type, inheritanceDecl, decl); - } +static bool _doesTypeDeclHaveDefinition(ContainerDecl* decl) +{ + if (auto aggTypeDecl = as(decl)) + return aggTypeDecl->hasBody; + return false; +} - // Successful conformance checking may have created new witness tables. - // Increment epoch to invalidate the cache, so subsequent canonical types are - // re-calculated. - // - // TODO: Is it really necessary to invalidate globally? Maybe there's a way to invalidate only the - // types that are affected by these interface decls. - // - astBuilder->incrementEpoch(); - } - } +bool SemanticsVisitor::checkConformance( + Type* subType, + InheritanceDecl* inheritanceDecl, + ContainerDecl* parentDecl) +{ + auto superType = inheritanceDecl->base.type; - void SemanticsDeclBasesVisitor::_validateCrossModuleInheritance( - AggTypeDeclBase* decl, - InheritanceDecl* inheritanceDecl) + if (auto declRefType = as(subType)) { - // Within a single module, users should be allowed to inherit - // one type from another more or less freely, so long as they - // don't violate fundamental validity conditions around - // inheritance. - // - // When an inheritance relationship is declared in one module, - // and the base type is in another module, we may want to - // enforce more restrictions. As a strong example, we probably - // don't want people to declare their own subtype of `int` - // or `Texture2D`. - // - // We start by checking if the type being inherited from is - // a decl-ref type, since that means it refers to a declaration - // that can be localized to its original module. - // - auto baseType = inheritanceDecl->base.type; - auto baseDeclRefType = as(baseType); - if( !baseDeclRefType ) - { - return; - } - auto baseDecl = baseDeclRefType->getDeclRef().getDecl(); + auto declRef = declRefType->getDeclRef(); - // Using the parent/child hierarchy baked into `Decl`s we - // can find the modules that contain both the `decl` doing - // the inheriting, and the `baseDeclRefType` that is being - // inherited from. - // - // If those modules are the same, then we aren't seeing any - // kind of cross-module inheritance here, and there is nothing - // that needs enforcing. - // - auto moduleWithInheritance = getModule(decl); - auto moduleWithBaseType = getModule(baseDecl); - if( moduleWithInheritance == moduleWithBaseType ) + if (auto superDeclRefType = as(superType)) { - return; + auto superTypeDecl = superDeclRefType->getDeclRef().getDecl(); + if (superTypeDecl->findModifier()) + { + // A struct cannot implement a COM Interface. + if (auto classDecl = as(superTypeDecl)) + { + // OK. + SLANG_UNUSED(classDecl); + } + else if (auto subInterfaceDecl = as(superTypeDecl)) + { + if (!subInterfaceDecl->findModifier()) + { + getSink()->diagnose( + inheritanceDecl, + Diagnostics::interfaceInheritingComMustBeCom); + } + } + else if (const auto structDecl = as(superTypeDecl)) + { + getSink()->diagnose( + inheritanceDecl, + Diagnostics::structCannotImplementComInterface); + } + } } - if( baseDecl->hasModifier() ) - { - // If the original declaration had the `[sealed]` attribute on it, - // then it explicitly does *not* allow inheritance from other - // modules. - // - getSink()->diagnose(inheritanceDecl, Diagnostics::cannotInheritFromExplicitlySealedDeclarationInAnotherModule, baseType, moduleWithBaseType->getModuleDecl()->getName()); - return; - } - else if( baseDecl->hasModifier() ) + // Don't check conformances for abstract types that + // are being used to express *required* conformances. + if (auto assocTypeDeclRef = declRef.as()) { - // Conversely, if the original declaration had the `[open]` attribute - // on it, then it explicit *does* allow inheritance from other - // modules. - // - // In this case we don't need to check anything: the inheritance - // is allowed. - } - else if( as(baseDecl) ) - { - // If an interface isn't explicitly marked `[open]` or `[sealed]`, - // then the default behavior is to treat it as `[open]`, since - // interfaces are most often used to define protocols that - // users of a module can opt into. + // An associated type declaration represents a requirement + // in an outer interface declaration, and its members + // (type constraints) represent additional requirements. + return true; } - else + else if (auto interfaceDeclRef = declRef.as()) { - // For any non-interface type, if the declaration didn't specify - // `[open]` or `[sealed]` then we assume `[sealed]` is the default. + // HACK: Our semantics as they stand today are that an + // `extension` of an interface that adds a new inheritance + // clause acts *as if* that inheritnace clause had been + // attached to the original `interface` decl: that is, + // it adds additional requirements. // - getSink()->diagnose(inheritanceDecl, Diagnostics::cannotInheritFromImplicitlySealedDeclarationInAnotherModule, baseType, moduleWithBaseType->getModuleDecl()->getName()); - return; + // This is *not* a reasonable semantic to keep long-term, + // but it is required for some of our current example + // code to work. + return true; } } - void SemanticsDeclBasesVisitor::visitInterfaceDecl(InterfaceDecl* decl) - { - SLANG_OUTER_SCOPE_CONTEXT_DECL_RAII(this, decl); - checkVisibility(decl); - for( auto inheritanceDecl : decl->getMembersOfType() ) - { - ensureDecl(inheritanceDecl, DeclCheckState::CanUseBaseOfInheritanceDecl); - auto baseType = inheritanceDecl->base.type; - - // It is possible that there was an error in checking the base type - // expression, and in such a case we shouldn't emit a cascading error. - // - if( const auto baseErrorType = as(baseType) ) - { - continue; - } - - // An `interface` type can only inherit from other `interface` types. - // - // TODO: In the long run it might make sense for an interface to support - // an inheritance clause naming a non-interface type, with the meaning - // that any type that implements the interface must be a sub-type of the - // type named in the inheritance clause. - // - auto baseDeclRefType = as(baseType); - if( !baseDeclRefType ) - { - getSink()->diagnose(inheritanceDecl, Diagnostics::baseOfInterfaceMustBeInterface, decl, baseType); - continue; - } + // Look at the type being inherited from, and validate + // appropriately. - auto baseDeclRef = baseDeclRefType->getDeclRef(); - auto baseInterfaceDeclRef = baseDeclRef.as(); - if( !baseInterfaceDeclRef ) - { - getSink()->diagnose(inheritanceDecl, Diagnostics::baseOfInterfaceMustBeInterface, decl, baseType); - continue; - } + DeclaredSubtypeWitness* subIsSuperWitness = + m_astBuilder->getDeclaredSubtypeWitness(subType, superType, makeDeclRef(inheritanceDecl)); - // TODO: At this point we have the `baseInterfaceDeclRef` - // and could use it to perform further validity checks, - // and/or to build up a more refined representation of - // the inheritance graph for this type (e.g., a "class - // precedence list"). - // - // E.g., we can/should check that we aren't introducing - // a circular inheritance relationship. + ConformanceCheckingContext context; + context.conformingType = subType; + context.parentDecl = parentDecl; - _validateCrossModuleInheritance(decl, inheritanceDecl); - } - if (decl->findModifier()) - { - // `associatedtype` declaration is not allowed in a COM interface declaration. - for (auto associatedType : decl->getMembersOfType()) - { - getSink()->diagnose( - associatedType, Diagnostics::associatedTypeNotAllowInComInterface); - } - } + RefPtr witnessTable = inheritanceDecl->witnessTable; + if (!witnessTable) + { + witnessTable = new WitnessTable(); + witnessTable->baseType = superType; + witnessTable->witnessedType = subType; + witnessTable->isExtern = + (!_doesTypeDeclHaveDefinition(parentDecl) && parentDecl->hasModifier()); + inheritanceDecl->witnessTable = witnessTable; } - void SemanticsDeclBasesVisitor::visitStructDecl(StructDecl* decl) + if (!checkConformanceToType( + &context, + subType, + inheritanceDecl, + superType, + subIsSuperWitness, + witnessTable)) { - // A `struct` type can only inherit from `struct` or `interface` types. - // - // Furthermore, only the first inheritance clause (in source - // order) is allowed to declare a base `struct` type. - // - SLANG_OUTER_SCOPE_CONTEXT_DECL_RAII(this, decl); - - Index inheritanceClauseCounter = 0; - for( auto inheritanceDecl : decl->getMembersOfType() ) - { - Index inheritanceClauseIndex = inheritanceClauseCounter++; - - ensureDecl(inheritanceDecl, DeclCheckState::CanUseBaseOfInheritanceDecl); - auto baseType = inheritanceDecl->base.type; - - // It is possible that there was an error in checking the base type - // expression, and in such a case we shouldn't emit a cascading error. - // - if( const auto baseErrorType = as(baseType) ) - { - continue; - } - - auto baseDeclRefType = as(baseType); - if( !baseDeclRefType ) - { - getSink()->diagnose(inheritanceDecl, Diagnostics::baseOfStructMustBeStructOrInterface, decl, baseType); - continue; - } + return false; + } - auto baseDeclRef = baseDeclRefType->getDeclRef(); - if( auto baseInterfaceDeclRef = baseDeclRef.as() ) - { - } - else if( auto baseStructDeclRef = baseDeclRef.as() ) - { - // To simplify the task of reading and maintaining code, - // we require that when a `struct` inherits from another - // `struct`, the base `struct` is the first item in - // the list of bases (before any interfaces). - // - // This constraint also has the secondary effect of restricting - // it so that a `struct` cannot multiply inherit from other - // `struct` types. - // - if( inheritanceClauseIndex != 0 ) - { - getSink()->diagnose(inheritanceDecl, Diagnostics::baseStructMustBeListedFirst, decl, baseType); - } - } - else - { - getSink()->diagnose(inheritanceDecl, Diagnostics::baseOfStructMustBeStructOrInterface, decl, baseType); - continue; - } + return true; +} - if (this->getOptionSet().getBoolOption(CompilerOptionName::ZeroInitialize) && !isFromCoreModule(decl)) - { - // Force add IDefaultInitializable to any struct missing (transitively) `IDefaultInitializable`. - auto* defaultInitializableType = m_astBuilder->getDefaultInitializableType(); - if(!isSubtype(DeclRefType::create(m_astBuilder, decl), defaultInitializableType, IsSubTypeOptions::NoCaching)) - { - InheritanceDecl* conformanceDecl = m_astBuilder->create(); - conformanceDecl->parentDecl = decl; - conformanceDecl->loc = decl->loc; - conformanceDecl->base.type = defaultInitializableType; - conformanceDecl->nameAndLoc.name = getName("$inheritance"); - decl->members.add(conformanceDecl); - } - } +void SemanticsVisitor::checkExtensionConformance(ExtensionDecl* decl) +{ + auto declRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(decl)) + .as(); + auto targetType = getTargetType(m_astBuilder, declRef); - // TODO: At this point we have the `baseDeclRef` - // and could use it to perform further validity checks, - // and/or to build up a more refined representation of - // the inheritance graph for this type (e.g., a "class - // precedence list"). - // - // E.g., we can/should check that we aren't introducing - // a circular inheritance relationship. + for (auto inheritanceDecl : decl->getMembersOfType()) + { + checkConformance(targetType, inheritanceDecl, decl); + } +} - _validateCrossModuleInheritance(decl, inheritanceDecl); +void SemanticsVisitor::checkDifferentiableMembersInType(AggTypeDecl* decl) +{ + for (auto member : decl->getMembersOfType()) + { + if (auto derivativeAttr = member->findModifier()) + { + checkDerivativeMemberAttributeReferences(member, derivativeAttr); } } +} + +void SemanticsVisitor::checkAggTypeConformance(AggTypeDecl* decl) +{ + // After we've checked members, we need to go through + // any inheritance clauses on the type itself, and + // confirm that the type actually provides whatever + // those clauses require. - void SemanticsDeclBasesVisitor::visitClassDecl(ClassDecl* decl) + if (const auto interfaceDecl = as(decl)) { - // A `class` type can only inherit from `class` or `interface` types. - // - // Furthermore, only the first inheritance clause (in source - // order) is allowed to declare a base `class` type. + // Don't check that an interface conforms to the + // things it inherits from. + } + else if (const auto assocTypeDecl = as(decl)) + { + // Don't check that an associated type decl conforms to the + // things it inherits from. + } + else + { + // For non-interface types we need to check conformance. // - SLANG_OUTER_SCOPE_CONTEXT_DECL_RAII(this, decl); - - Index inheritanceClauseCounter = 0; - for (auto inheritanceDecl : decl->getMembersOfType()) - { - Index inheritanceClauseIndex = inheritanceClauseCounter++; - ensureDecl(inheritanceDecl, DeclCheckState::CanUseBaseOfInheritanceDecl); - auto baseType = inheritanceDecl->base.type; - - // It is possible that there was an error in checking the base type - // expression, and in such a case we shouldn't emit a cascading error. - // - if (const auto baseErrorType = as(baseType)) - { - continue; - } + auto astBuilder = getASTBuilder(); - auto baseDeclRefType = as(baseType); - if (!baseDeclRefType) - { - getSink()->diagnose(inheritanceDecl, Diagnostics::baseOfClassMustBeClassOrInterface, decl, baseType); - continue; - } + auto declRef = createDefaultSubstitutionsIfNeeded(astBuilder, this, makeDeclRef(decl)) + .as(); + auto type = DeclRefType::create(astBuilder, declRef); - auto baseDeclRef = baseDeclRefType->getDeclRef(); - if (auto baseInterfaceDeclRef = baseDeclRef.as()) - { - } - else if (auto baseStructDeclRef = baseDeclRef.as()) - { - // To simplify the task of reading and maintaining code, - // we require that when a `class` inherits from another - // `class`, the base `class` is the first item in - // the list of bases (before any interfaces). - // - // This constraint also has the secondary effect of restricting - // it so that a `struct` cannot multiply inherit from other - // `struct` types. - // - if (inheritanceClauseIndex != 0) - { - getSink()->diagnose(inheritanceDecl, Diagnostics::baseClassMustBeListedFirst, decl, baseType); - } - } - else - { - getSink()->diagnose(inheritanceDecl, Diagnostics::baseOfClassMustBeClassOrInterface, decl, baseType); - continue; - } + // TODO: Need to figure out what this should do for + // `abstract` types if we ever add them. Should they + // be required to implement all interface requirements, + // just with `abstract` methods that replicate things? + // (That's what C# does). - // TODO: At this point we have the `baseDeclRef` - // and could use it to perform further validity checks, - // and/or to build up a more refined representation of - // the inheritance graph for this type (e.g., a "class - // precedence list"). + // Make a copy of inhertanceDecls firstsince `checkConformance` may modify decl->members. + auto inheritanceDecls = decl->getMembersOfType().toList(); + for (auto inheritanceDecl : inheritanceDecls) + { + // Special handling for when we check for conformance against `IDifferentiable` + // We will reference-checking for the [DerivativeMember(DiffType.member)] + // attributes here, since they have to be performed after types can be referenced + // and before conformance checking, where this information can be used to synthesize + // member methods (such as `dzero`, `dadd`, etc..) // - // E.g., we can/should check that we aren't introducing - // a circular inheritance relationship. + if (inheritanceDecl->getSup().type->equals( + astBuilder->getDifferentiableInterfaceType())) + checkDifferentiableMembersInType(decl); - _validateCrossModuleInheritance(decl, inheritanceDecl); + checkConformance(type, inheritanceDecl, decl); } - } - bool SemanticsVisitor::isIntegerBaseType(BaseType baseType) - { - return (BaseTypeInfo::getInfo(baseType).flags & BaseTypeInfo::Flag::Integer) != 0; + // Successful conformance checking may have created new witness tables. + // Increment epoch to invalidate the cache, so subsequent canonical types are + // re-calculated. + // + // TODO: Is it really necessary to invalidate globally? Maybe there's a way to invalidate + // only the types that are affected by these interface decls. + // + astBuilder->incrementEpoch(); } +} - bool SemanticsVisitor::isScalarIntegerType(Type* type) - { - auto basicType = as(type); - if(!basicType) - return false; - auto baseType = basicType->getBaseType(); - return isIntegerBaseType(baseType) || baseType == BaseType::Bool; +void SemanticsDeclBasesVisitor::_validateCrossModuleInheritance( + AggTypeDeclBase* decl, + InheritanceDecl* inheritanceDecl) +{ + // Within a single module, users should be allowed to inherit + // one type from another more or less freely, so long as they + // don't violate fundamental validity conditions around + // inheritance. + // + // When an inheritance relationship is declared in one module, + // and the base type is in another module, we may want to + // enforce more restrictions. As a strong example, we probably + // don't want people to declare their own subtype of `int` + // or `Texture2D`. + // + // We start by checking if the type being inherited from is + // a decl-ref type, since that means it refers to a declaration + // that can be localized to its original module. + // + auto baseType = inheritanceDecl->base.type; + auto baseDeclRefType = as(baseType); + if (!baseDeclRefType) + { + return; + } + auto baseDecl = baseDeclRefType->getDeclRef().getDecl(); + + // Using the parent/child hierarchy baked into `Decl`s we + // can find the modules that contain both the `decl` doing + // the inheriting, and the `baseDeclRefType` that is being + // inherited from. + // + // If those modules are the same, then we aren't seeing any + // kind of cross-module inheritance here, and there is nothing + // that needs enforcing. + // + auto moduleWithInheritance = getModule(decl); + auto moduleWithBaseType = getModule(baseDecl); + if (moduleWithInheritance == moduleWithBaseType) + { + return; + } + + if (baseDecl->hasModifier()) + { + // If the original declaration had the `[sealed]` attribute on it, + // then it explicitly does *not* allow inheritance from other + // modules. + // + getSink()->diagnose( + inheritanceDecl, + Diagnostics::cannotInheritFromExplicitlySealedDeclarationInAnotherModule, + baseType, + moduleWithBaseType->getModuleDecl()->getName()); + return; + } + else if (baseDecl->hasModifier()) + { + // Conversely, if the original declaration had the `[open]` attribute + // on it, then it explicit *does* allow inheritance from other + // modules. + // + // In this case we don't need to check anything: the inheritance + // is allowed. + } + else if (as(baseDecl)) + { + // If an interface isn't explicitly marked `[open]` or `[sealed]`, + // then the default behavior is to treat it as `[open]`, since + // interfaces are most often used to define protocols that + // users of a module can opt into. + } + else + { + // For any non-interface type, if the declaration didn't specify + // `[open]` or `[sealed]` then we assume `[sealed]` is the default. + // + getSink()->diagnose( + inheritanceDecl, + Diagnostics::cannotInheritFromImplicitlySealedDeclarationInAnotherModule, + baseType, + moduleWithBaseType->getModuleDecl()->getName()); + return; } +} - bool SemanticsVisitor::isValidCompileTimeConstantType(Type* type) +void SemanticsDeclBasesVisitor::visitInterfaceDecl(InterfaceDecl* decl) +{ + SLANG_OUTER_SCOPE_CONTEXT_DECL_RAII(this, decl); + checkVisibility(decl); + for (auto inheritanceDecl : decl->getMembersOfType()) { - return isScalarIntegerType(type) || isEnumType(type); - } + ensureDecl(inheritanceDecl, DeclCheckState::CanUseBaseOfInheritanceDecl); + auto baseType = inheritanceDecl->base.type; - bool SemanticsVisitor::isIntValueInRangeOfType(IntegerLiteralValue value, Type* type) - { - auto basicType = as(type); - if (!basicType) - return false; + // It is possible that there was an error in checking the base type + // expression, and in such a case we shouldn't emit a cascading error. + // + if (const auto baseErrorType = as(baseType)) + { + continue; + } - switch (basicType->getBaseType()) + // An `interface` type can only inherit from other `interface` types. + // + // TODO: In the long run it might make sense for an interface to support + // an inheritance clause naming a non-interface type, with the meaning + // that any type that implements the interface must be a sub-type of the + // type named in the inheritance clause. + // + auto baseDeclRefType = as(baseType); + if (!baseDeclRefType) { - case BaseType::UInt8: - return (value >= 0 && value <= std::numeric_limits::max()) || (value == -1); - case BaseType::UInt16: - return (value >= 0 && value <= std::numeric_limits::max()) || (value == -1); - case BaseType::UInt: -#if SLANG_PTR_IS_32 - case BaseType::UIntPtr: -#endif - return (value >= 0 && value <= std::numeric_limits::max()) || (value == -1); - case BaseType::UInt64: -#if SLANG_PTR_IS_64 - case BaseType::UIntPtr: -#endif - return true; - case BaseType::Int8: - return value >= std::numeric_limits::min() && value <= std::numeric_limits::max(); - case BaseType::Int16: - return value >= std::numeric_limits::min() && value <= std::numeric_limits::max(); - case BaseType::Int: -#if SLANG_PTR_IS_32 - case BaseType::IntPtr: -#endif - return value >= std::numeric_limits::min() && value <= std::numeric_limits::max(); - case BaseType::Int64: -#if SLANG_PTR_IS_64 - case BaseType::IntPtr: -#endif - return value >= std::numeric_limits::min() && value <= std::numeric_limits::max(); - default: - return false; + getSink()->diagnose( + inheritanceDecl, + Diagnostics::baseOfInterfaceMustBeInterface, + decl, + baseType); + continue; + } + + auto baseDeclRef = baseDeclRefType->getDeclRef(); + auto baseInterfaceDeclRef = baseDeclRef.as(); + if (!baseInterfaceDeclRef) + { + getSink()->diagnose( + inheritanceDecl, + Diagnostics::baseOfInterfaceMustBeInterface, + decl, + baseType); + continue; } - } - void SemanticsVisitor::validateEnumTagType(Type* type, SourceLoc const& loc) - { - // Allow the built-in integer types. + // TODO: At this point we have the `baseInterfaceDeclRef` + // and could use it to perform further validity checks, + // and/or to build up a more refined representation of + // the inheritance graph for this type (e.g., a "class + // precedence list"). // - if(isScalarIntegerType(type)) - return; + // E.g., we can/should check that we aren't introducing + // a circular inheritance relationship. - // By default, don't allow other types to be used - // as an `enum` tag type. - // - getSink()->diagnose(loc, Diagnostics::invalidEnumTagType, type); + _validateCrossModuleInheritance(decl, inheritanceDecl); } - void SemanticsDeclBasesVisitor::visitEnumDecl(EnumDecl* decl) + if (decl->findModifier()) { - SLANG_OUTER_SCOPE_CONTEXT_DECL_RAII(this, decl); - checkVisibility(decl); - - // An `enum` type can inherit from interfaces, and also - // from a single "tag" type that must: - // - // * be a built-in integer type - // * come first in the list of base types - // - Index inheritanceClauseCounter = 0; - Type* tagType = nullptr; - InheritanceDecl* tagTypeInheritanceDecl = nullptr; - for(auto inheritanceDecl : decl->getMembersOfType()) + // `associatedtype` declaration is not allowed in a COM interface declaration. + for (auto associatedType : decl->getMembersOfType()) { - Index inheritanceClauseIndex = inheritanceClauseCounter++; + getSink()->diagnose(associatedType, Diagnostics::associatedTypeNotAllowInComInterface); + } + } +} - ensureDecl(inheritanceDecl, DeclCheckState::CanUseBaseOfInheritanceDecl); - auto baseType = inheritanceDecl->base.type; +void SemanticsDeclBasesVisitor::visitStructDecl(StructDecl* decl) +{ + // A `struct` type can only inherit from `struct` or `interface` types. + // + // Furthermore, only the first inheritance clause (in source + // order) is allowed to declare a base `struct` type. + // + SLANG_OUTER_SCOPE_CONTEXT_DECL_RAII(this, decl); - // It is possible that there was an error in checking the base type - // expression, and in such a case we shouldn't emit a cascading error. - // - if( const auto baseErrorType = as(baseType) ) - { - continue; - } + Index inheritanceClauseCounter = 0; + for (auto inheritanceDecl : decl->getMembersOfType()) + { + Index inheritanceClauseIndex = inheritanceClauseCounter++; - auto baseDeclRefType = as(baseType); - if( !baseDeclRefType ) - { - getSink()->diagnose(inheritanceDecl, Diagnostics::baseOfEnumMustBeIntegerOrInterface, decl, baseType); - continue; - } + ensureDecl(inheritanceDecl, DeclCheckState::CanUseBaseOfInheritanceDecl); + auto baseType = inheritanceDecl->base.type; - auto baseDeclRef = baseDeclRefType->getDeclRef(); - if( auto baseInterfaceDeclRef = baseDeclRef.as() ) - { - _validateCrossModuleInheritance(decl, inheritanceDecl); - } - else if( auto baseStructDeclRef = baseDeclRef.as() ) - { - // To simplify the task of reading and maintaining code, - // we require that when an `enum` declares an explicit - // underlying tag type using an inheritance clause, that - // type must be the first item in the list of bases. - // - // This constraint also has the secondary effect of restricting - // it so that an `enum` can't possibly have multiple tag - // types declared. - // - if( inheritanceClauseIndex != 0 ) - { - getSink()->diagnose(inheritanceDecl, Diagnostics::tagTypeMustBeListedFirst, decl, baseType); - } - else - { - tagType = baseType; - tagTypeInheritanceDecl = inheritanceDecl; - } + // It is possible that there was an error in checking the base type + // expression, and in such a case we shouldn't emit a cascading error. + // + if (const auto baseErrorType = as(baseType)) + { + continue; + } - // Note: we do *not* apply the code that validates - // cross-module inheritance to a base that represnts - // a tag type, because declaring a tag type for an - // `enum` doesn't actually make it into a subtype - // of the tag type, and thus doesn't violate the - // rules when the tag type is `sealed`. - } - else - { - getSink()->diagnose(inheritanceDecl, Diagnostics::baseOfEnumMustBeIntegerOrInterface, decl, baseType); - continue; - } + auto baseDeclRefType = as(baseType); + if (!baseDeclRefType) + { + getSink()->diagnose( + inheritanceDecl, + Diagnostics::baseOfStructMustBeStructOrInterface, + decl, + baseType); + continue; } - // If a tag type has not been set, then we - // default it to the built-in `int` type. - // - // TODO: In the far-flung future we may want to distinguish - // `enum` types that have a "raw representation" like this from - // ones that are purely abstract and don't expose their - // type of their tag. - // - if(!tagType) + auto baseDeclRef = baseDeclRefType->getDeclRef(); + if (auto baseInterfaceDeclRef = baseDeclRef.as()) { - tagType = m_astBuilder->getIntType(); } - else + else if (auto baseStructDeclRef = baseDeclRef.as()) { - // TODO: Need to establish that the tag - // type is suitable. (e.g., if we are going - // to allow raw values for case tags to be - // derived automatically, then the tag - // type needs to be some kind of integer type...) + // To simplify the task of reading and maintaining code, + // we require that when a `struct` inherits from another + // `struct`, the base `struct` is the first item in + // the list of bases (before any interfaces). // - // For now we will just be harsh and require it - // to be one of a few builtin types. - validateEnumTagType(tagType, tagTypeInheritanceDecl->loc); - - // Note: The `InheritanceDecl` that introduces a tag - // type isn't actually representing a super-type of - // the `enum`, and things like name lookup need to - // know to ignore that "inheritance" relationship. + // This constraint also has the secondary effect of restricting + // it so that a `struct` cannot multiply inherit from other + // `struct` types. // - // We add a modifier to the `InheritanceDecl` to ensure - // that it can be detected and ignored by such steps. - // - addModifier(tagTypeInheritanceDecl, m_astBuilder->create()); + if (inheritanceClauseIndex != 0) + { + getSink()->diagnose( + inheritanceDecl, + Diagnostics::baseStructMustBeListedFirst, + decl, + baseType); + } } - decl->tagType = tagType; - - - // An `enum` type should automatically conform to the `__EnumType` interface. - // The compiler needs to insert this conformance behind the scenes, and this - // seems like the best place to do it. + else { - // First, look up the type of the `__EnumType` interface. - Type* enumTypeType = getASTBuilder()->getEnumTypeType(); - - InheritanceDecl* enumConformanceDecl = m_astBuilder->create(); - enumConformanceDecl->parentDecl = decl; - enumConformanceDecl->loc = decl->loc; - enumConformanceDecl->base.type = getASTBuilder()->getEnumTypeType(); - decl->members.add(enumConformanceDecl); - - // The `__EnumType` interface has one required member, the `__Tag` type. - // We need to satisfy this requirement automatically, rather than require - // the user to actually declare a member with this name (otherwise we wouldn't - // let them define a tag value with the name `__Tag`). - // - RefPtr witnessTable = new WitnessTable(); - witnessTable->baseType = enumConformanceDecl->base.type; - witnessTable->witnessedType = enumTypeType; - enumConformanceDecl->witnessTable = witnessTable; + getSink()->diagnose( + inheritanceDecl, + Diagnostics::baseOfStructMustBeStructOrInterface, + decl, + baseType); + continue; + } - Name* tagAssociatedTypeName = getSession()->getNameObj("__Tag"); - Decl* tagAssociatedTypeDecl = nullptr; - if(auto enumTypeTypeDeclRefType = dynamicCast(enumTypeType)) - { - if(auto enumTypeTypeInterfaceDecl = as(enumTypeTypeDeclRefType->getDeclRef().getDecl())) - { - for(auto memberDecl : enumTypeTypeInterfaceDecl->members) - { - if(memberDecl->getName() == tagAssociatedTypeName) - { - tagAssociatedTypeDecl = memberDecl; - break; - } - } - } - } - if(!tagAssociatedTypeDecl) + if (this->getOptionSet().getBoolOption(CompilerOptionName::ZeroInitialize) && + !isFromCoreModule(decl)) + { + // Force add IDefaultInitializable to any struct missing (transitively) + // `IDefaultInitializable`. + auto* defaultInitializableType = m_astBuilder->getDefaultInitializableType(); + if (!isSubtype( + DeclRefType::create(m_astBuilder, decl), + defaultInitializableType, + IsSubTypeOptions::NoCaching)) { - SLANG_DIAGNOSE_UNEXPECTED(getSink(), decl, "failed to find built-in declaration '__Tag'"); + InheritanceDecl* conformanceDecl = m_astBuilder->create(); + conformanceDecl->parentDecl = decl; + conformanceDecl->loc = decl->loc; + conformanceDecl->base.type = defaultInitializableType; + conformanceDecl->nameAndLoc.name = getName("$inheritance"); + decl->members.add(conformanceDecl); } + } - // Okay, add the conformance witness for `__Tag` being satisfied by `tagType` - witnessTable->add(tagAssociatedTypeDecl, RequirementWitness(tagType)); - - // TODO: we actually also need to synthesize a witness for the conformance of `tagType` - // to the `__BuiltinIntegerType` interface, because that is a constraint on the - // associated type `__Tag`. - - // TODO: eventually we should consider synthesizing other requirements for - // the min/max tag values, or the total number of tags, so that people don't - // have to declare these as additional cases. + // TODO: At this point we have the `baseDeclRef` + // and could use it to perform further validity checks, + // and/or to build up a more refined representation of + // the inheritance graph for this type (e.g., a "class + // precedence list"). + // + // E.g., we can/should check that we aren't introducing + // a circular inheritance relationship. - enumConformanceDecl->setCheckState(DeclCheckState::DefinitionChecked); - } + _validateCrossModuleInheritance(decl, inheritanceDecl); } +} - void SemanticsDeclBodyVisitor::visitEnumDecl(EnumDecl* decl) - { - SLANG_OUTER_SCOPE_CONTEXT_DECL_RAII(this, decl); - - auto enumType = DeclRefType::create(m_astBuilder, makeDeclRef(decl)); +void SemanticsDeclBasesVisitor::visitClassDecl(ClassDecl* decl) +{ + // A `class` type can only inherit from `class` or `interface` types. + // + // Furthermore, only the first inheritance clause (in source + // order) is allowed to declare a base `class` type. + // + SLANG_OUTER_SCOPE_CONTEXT_DECL_RAII(this, decl); - auto tagType = decl->tagType; + Index inheritanceClauseCounter = 0; + for (auto inheritanceDecl : decl->getMembersOfType()) + { + Index inheritanceClauseIndex = inheritanceClauseCounter++; - auto isEnumFlags = decl->hasModifier(); + ensureDecl(inheritanceDecl, DeclCheckState::CanUseBaseOfInheritanceDecl); + auto baseType = inheritanceDecl->base.type; - // Check the enum cases in order. - for(auto caseDecl : decl->getMembersOfType()) + // It is possible that there was an error in checking the base type + // expression, and in such a case we shouldn't emit a cascading error. + // + if (const auto baseErrorType = as(baseType)) { - // Each case defines a value of the enum's type. - // - // TODO: If we ever support enum cases with payloads, - // then they would probably have a type that is a - // `FunctionType` from the payload types to the - // enum type. - // - // TODO(tfoley): the case should grab its type when - // doing its own header checking, rather than rely on this... - caseDecl->type.type = enumType; - - ensureDecl(caseDecl, DeclCheckState::DefinitionChecked); + continue; } - // For any enum case that didn't provide an explicit - // tag value, derived an appropriate tag value. - IntegerLiteralValue defaultTag = isEnumFlags ? 1 : 0; - for(auto caseDecl : decl->getMembersOfType()) + auto baseDeclRefType = as(baseType); + if (!baseDeclRefType) { - if(auto explicitTagValExpr = caseDecl->tagExpr) - { - // This tag has an initializer, so it should establish - // the tag value for a successor case that doesn't - // provide an explicit tag. - - IntVal* explicitTagVal = tryConstantFoldExpr(explicitTagValExpr, ConstantFoldingKind::CompileTime, nullptr); - if(explicitTagVal) - { - if(auto constIntVal = as(explicitTagVal)) - { - defaultTag = constIntVal->getValue(); - } - else - { - // TODO: need to handle other possibilities here - getSink()->diagnose(explicitTagValExpr, Diagnostics::unexpectedEnumTagExpr); - } - } - else - { - // If this happens, then the explicit tag value expression - // doesn't seem to be a constant after all. In this case - // we expect the checking logic to have applied already. - } - } - else - { - // This tag has no initializer, so it should use - // the default tag value we are tracking. - IntegerLiteralExpr* tagValExpr = m_astBuilder->create(); - tagValExpr->loc = caseDecl->loc; - tagValExpr->type = QualType(tagType); - tagValExpr->value = defaultTag; - - caseDecl->tagExpr = tagValExpr; - } + getSink()->diagnose( + inheritanceDecl, + Diagnostics::baseOfClassMustBeClassOrInterface, + decl, + baseType); + continue; + } - // Default tag for the next case will be one more than - // for the most recent case. + auto baseDeclRef = baseDeclRefType->getDeclRef(); + if (auto baseInterfaceDeclRef = baseDeclRef.as()) + { + } + else if (auto baseStructDeclRef = baseDeclRef.as()) + { + // To simplify the task of reading and maintaining code, + // we require that when a `class` inherits from another + // `class`, the base `class` is the first item in + // the list of bases (before any interfaces). // - if (!isEnumFlags) - defaultTag++; - else + // This constraint also has the secondary effect of restricting + // it so that a `struct` cannot multiply inherit from other + // `struct` types. + // + if (inheritanceClauseIndex != 0) { - if (defaultTag == 0) - defaultTag = 1; - else - defaultTag <<= 1; + getSink()->diagnose( + inheritanceDecl, + Diagnostics::baseClassMustBeListedFirst, + decl, + baseType); } } + else + { + getSink()->diagnose( + inheritanceDecl, + Diagnostics::baseOfClassMustBeClassOrInterface, + decl, + baseType); + continue; + } + + // TODO: At this point we have the `baseDeclRef` + // and could use it to perform further validity checks, + // and/or to build up a more refined representation of + // the inheritance graph for this type (e.g., a "class + // precedence list"). + // + // E.g., we can/should check that we aren't introducing + // a circular inheritance relationship. + + _validateCrossModuleInheritance(decl, inheritanceDecl); } +} + +bool SemanticsVisitor::isIntegerBaseType(BaseType baseType) +{ + return (BaseTypeInfo::getInfo(baseType).flags & BaseTypeInfo::Flag::Integer) != 0; +} + +bool SemanticsVisitor::isScalarIntegerType(Type* type) +{ + auto basicType = as(type); + if (!basicType) + return false; + auto baseType = basicType->getBaseType(); + return isIntegerBaseType(baseType) || baseType == BaseType::Bool; +} + +bool SemanticsVisitor::isValidCompileTimeConstantType(Type* type) +{ + return isScalarIntegerType(type) || isEnumType(type); +} + +bool SemanticsVisitor::isIntValueInRangeOfType(IntegerLiteralValue value, Type* type) +{ + auto basicType = as(type); + if (!basicType) + return false; - void SemanticsDeclBodyVisitor::visitEnumCaseDecl(EnumCaseDecl* decl) + switch (basicType->getBaseType()) { - // An enum case had better appear inside an enum! - // - // TODO: Do we need/want to support generic cases some day? - auto parentEnumDecl = as(decl->parentDecl); - SLANG_ASSERT(parentEnumDecl); - - decl->type.type = DeclRefType::create(m_astBuilder, makeDeclRef(parentEnumDecl)); + case BaseType::UInt8: + return (value >= 0 && value <= std::numeric_limits::max()) || (value == -1); + case BaseType::UInt16: + return (value >= 0 && value <= std::numeric_limits::max()) || (value == -1); + case BaseType::UInt: +#if SLANG_PTR_IS_32 + case BaseType::UIntPtr: +#endif + return (value >= 0 && value <= std::numeric_limits::max()) || (value == -1); + case BaseType::UInt64: +#if SLANG_PTR_IS_64 + case BaseType::UIntPtr: +#endif + return true; + case BaseType::Int8: + return value >= std::numeric_limits::min() && + value <= std::numeric_limits::max(); + case BaseType::Int16: + return value >= std::numeric_limits::min() && + value <= std::numeric_limits::max(); + case BaseType::Int: +#if SLANG_PTR_IS_32 + case BaseType::IntPtr: +#endif + return value >= std::numeric_limits::min() && + value <= std::numeric_limits::max(); + case BaseType::Int64: +#if SLANG_PTR_IS_64 + case BaseType::IntPtr: +#endif + return value >= std::numeric_limits::min() && + value <= std::numeric_limits::max(); + default: return false; + } +} - // The tag type should have already been set by - // the surrounding `enum` declaration. - auto tagType = parentEnumDecl->tagType; - SLANG_ASSERT(tagType); +void SemanticsVisitor::validateEnumTagType(Type* type, SourceLoc const& loc) +{ + // Allow the built-in integer types. + // + if (isScalarIntegerType(type)) + return; + + // By default, don't allow other types to be used + // as an `enum` tag type. + // + getSink()->diagnose(loc, Diagnostics::invalidEnumTagType, type); +} - // Need to check the init expression, if present, since - // that represents the explicit tag for this case. - if(auto initExpr = decl->tagExpr) - { - initExpr = CheckTerm(initExpr); - initExpr = coerce(CoercionSite::General, tagType, initExpr); +void SemanticsDeclBasesVisitor::visitEnumDecl(EnumDecl* decl) +{ + SLANG_OUTER_SCOPE_CONTEXT_DECL_RAII(this, decl); + checkVisibility(decl); + + // An `enum` type can inherit from interfaces, and also + // from a single "tag" type that must: + // + // * be a built-in integer type + // * come first in the list of base types + // + Index inheritanceClauseCounter = 0; + Type* tagType = nullptr; + InheritanceDecl* tagTypeInheritanceDecl = nullptr; + for (auto inheritanceDecl : decl->getMembersOfType()) + { + Index inheritanceClauseIndex = inheritanceClauseCounter++; + + ensureDecl(inheritanceDecl, DeclCheckState::CanUseBaseOfInheritanceDecl); + auto baseType = inheritanceDecl->base.type; - // We want to enforce that this is an integer constant - // expression, but we don't actually care to retain - // the value. - CheckIntegerConstantExpression(initExpr, IntegerConstantExpressionCoercionType::AnyInteger, nullptr, ConstantFoldingKind::CompileTime); + // It is possible that there was an error in checking the base type + // expression, and in such a case we shouldn't emit a cascading error. + // + if (const auto baseErrorType = as(baseType)) + { + continue; + } - decl->tagExpr = initExpr; + auto baseDeclRefType = as(baseType); + if (!baseDeclRefType) + { + getSink()->diagnose( + inheritanceDecl, + Diagnostics::baseOfEnumMustBeIntegerOrInterface, + decl, + baseType); + continue; } - } - void SemanticsVisitor::ensureDeclBase(DeclBase* declBase, DeclCheckState state, SemanticsContext* baseContext) - { - if(auto decl = as(declBase)) + auto baseDeclRef = baseDeclRefType->getDeclRef(); + if (auto baseInterfaceDeclRef = baseDeclRef.as()) { - ensureDecl(decl, state, baseContext); + _validateCrossModuleInheritance(decl, inheritanceDecl); } - else if(auto declGroup = as(declBase)) + else if (auto baseStructDeclRef = baseDeclRef.as()) { - for(auto dd : declGroup->decls) + // To simplify the task of reading and maintaining code, + // we require that when an `enum` declares an explicit + // underlying tag type using an inheritance clause, that + // type must be the first item in the list of bases. + // + // This constraint also has the secondary effect of restricting + // it so that an `enum` can't possibly have multiple tag + // types declared. + // + if (inheritanceClauseIndex != 0) + { + getSink()->diagnose( + inheritanceDecl, + Diagnostics::tagTypeMustBeListedFirst, + decl, + baseType); + } + else { - ensureDecl(dd, state, baseContext); + tagType = baseType; + tagTypeInheritanceDecl = inheritanceDecl; } + + // Note: we do *not* apply the code that validates + // cross-module inheritance to a base that represnts + // a tag type, because declaring a tag type for an + // `enum` doesn't actually make it into a subtype + // of the tag type, and thus doesn't violate the + // rules when the tag type is `sealed`. } else { - SLANG_UNEXPECTED("unknown case for declaration"); + getSink()->diagnose( + inheritanceDecl, + Diagnostics::baseOfEnumMustBeIntegerOrInterface, + decl, + baseType); + continue; } } - void SemanticsDeclHeaderVisitor::visitTypeDefDecl(TypeDefDecl* decl) + // If a tag type has not been set, then we + // default it to the built-in `int` type. + // + // TODO: In the far-flung future we may want to distinguish + // `enum` types that have a "raw representation" like this from + // ones that are purely abstract and don't expose their + // type of their tag. + // + if (!tagType) { - SemanticsVisitor visitor(withDeclToExcludeFromLookup(decl)); - decl->type = visitor.CheckProperType(decl->type); - checkVisibility(decl); + tagType = m_astBuilder->getIntType(); } - - void SemanticsDeclHeaderVisitor::visitGlobalGenericParamDecl(GlobalGenericParamDecl* decl) + else { - // global generic param only allowed in global scope - auto program = as(decl->parentDecl); - if (!program) - getSink()->diagnose(decl, Slang::Diagnostics::globalGenParamInGlobalScopeOnly); - } + // TODO: Need to establish that the tag + // type is suitable. (e.g., if we are going + // to allow raw values for case tags to be + // derived automatically, then the tag + // type needs to be some kind of integer type...) + // + // For now we will just be harsh and require it + // to be one of a few builtin types. + validateEnumTagType(tagType, tagTypeInheritanceDecl->loc); - void SemanticsDeclHeaderVisitor::visitAssocTypeDecl(AssocTypeDecl* decl) - { - // assoctype only allowed in an interface - auto interfaceDecl = as(decl->parentDecl); - if (!interfaceDecl) - getSink()->diagnose(decl, Slang::Diagnostics::assocTypeInInterfaceOnly); - checkVisibility(decl); + // Note: The `InheritanceDecl` that introduces a tag + // type isn't actually representing a super-type of + // the `enum`, and things like name lookup need to + // know to ignore that "inheritance" relationship. + // + // We add a modifier to the `InheritanceDecl` to ensure + // that it can be detected and ignored by such steps. + // + addModifier(tagTypeInheritanceDecl, m_astBuilder->create()); } + decl->tagType = tagType; + - SemanticsContext SemanticsDeclBodyVisitor::registerDifferentiableTypesForFunc(FunctionDeclBase* decl) + // An `enum` type should automatically conform to the `__EnumType` interface. + // The compiler needs to insert this conformance behind the scenes, and this + // seems like the best place to do it. { - auto newContext = withParentFunc(decl); - if (newContext.getParentDifferentiableAttribute()) + // First, look up the type of the `__EnumType` interface. + Type* enumTypeType = getASTBuilder()->getEnumTypeType(); + + InheritanceDecl* enumConformanceDecl = m_astBuilder->create(); + enumConformanceDecl->parentDecl = decl; + enumConformanceDecl->loc = decl->loc; + enumConformanceDecl->base.type = getASTBuilder()->getEnumTypeType(); + decl->members.add(enumConformanceDecl); + + // The `__EnumType` interface has one required member, the `__Tag` type. + // We need to satisfy this requirement automatically, rather than require + // the user to actually declare a member with this name (otherwise we wouldn't + // let them define a tag value with the name `__Tag`). + // + RefPtr witnessTable = new WitnessTable(); + witnessTable->baseType = enumConformanceDecl->base.type; + witnessTable->witnessedType = enumTypeType; + enumConformanceDecl->witnessTable = witnessTable; + + Name* tagAssociatedTypeName = getSession()->getNameObj("__Tag"); + Decl* tagAssociatedTypeDecl = nullptr; + if (auto enumTypeTypeDeclRefType = dynamicCast(enumTypeType)) { - // Register additional types outside the function body first. - auto oldAttr = m_parentDifferentiableAttr; - m_parentDifferentiableAttr = newContext.getParentDifferentiableAttribute(); - for (auto param : decl->getParameters()) - maybeRegisterDifferentiableType(m_astBuilder, param->type.type); - maybeRegisterDifferentiableType(m_astBuilder, decl->returnType.type); - if (as(decl) || !isEffectivelyStatic(decl)) + if (auto enumTypeTypeInterfaceDecl = + as(enumTypeTypeDeclRefType->getDeclRef().getDecl())) { - auto parentDecl = getParentDecl(decl); - auto parentDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(parentDecl)); - auto thisType = calcThisType(parentDeclRef); - maybeRegisterDifferentiableType(m_astBuilder, thisType); + for (auto memberDecl : enumTypeTypeInterfaceDecl->members) + { + if (memberDecl->getName() == tagAssociatedTypeName) + { + tagAssociatedTypeDecl = memberDecl; + break; + } + } } - m_parentDifferentiableAttr = oldAttr; } - return newContext; - } - - void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) - { - auto newContext = registerDifferentiableTypesForFunc(decl); - if (const auto body = decl->body) + if (!tagAssociatedTypeDecl) { - checkStmt(decl->body, newContext); + SLANG_DIAGNOSE_UNEXPECTED( + getSink(), + decl, + "failed to find built-in declaration '__Tag'"); } - } - void SemanticsVisitor::getGenericParams( - GenericDecl* decl, - List& outParams, - List& outConstraints) - { - for (auto dd : decl->members) - { - if (dd == decl->inner) - continue; + // Okay, add the conformance witness for `__Tag` being satisfied by `tagType` + witnessTable->add(tagAssociatedTypeDecl, RequirementWitness(tagType)); - if (auto typeParamDecl = as(dd)) - outParams.add(typeParamDecl); - else if (auto valueParamDecl = as(dd)) - outParams.add(valueParamDecl); - else if (auto constraintDecl = as(dd)) - outConstraints.add(constraintDecl); - } + // TODO: we actually also need to synthesize a witness for the conformance of `tagType` + // to the `__BuiltinIntegerType` interface, because that is a constraint on the + // associated type `__Tag`. + + // TODO: eventually we should consider synthesizing other requirements for + // the min/max tag values, or the total number of tags, so that people don't + // have to declare these as additional cases. + + enumConformanceDecl->setCheckState(DeclCheckState::DefinitionChecked); } +} + +void SemanticsDeclBodyVisitor::visitEnumDecl(EnumDecl* decl) +{ + SLANG_OUTER_SCOPE_CONTEXT_DECL_RAII(this, decl); - bool SemanticsVisitor::doGenericSignaturesMatch( - GenericDecl* left, - GenericDecl* right, - DeclRef* outSpecializedRightInner) + auto enumType = DeclRefType::create(m_astBuilder, makeDeclRef(decl)); + + auto tagType = decl->tagType; + + auto isEnumFlags = decl->hasModifier(); + + // Check the enum cases in order. + for (auto caseDecl : decl->getMembersOfType()) { - // Our first goal here is to determine if `left` and - // `right` have equivalent lists of explicit - // generic parameters. - // - // Once we have determined that the explicit generic - // parameters match, we will look at the constraints - // placed on those parameters to see if they are - // equivalent. + // Each case defines a value of the enum's type. // - // We thus start by extracting the explicit parameters - // and the constraints from each declaration. + // TODO: If we ever support enum cases with payloads, + // then they would probably have a type that is a + // `FunctionType` from the payload types to the + // enum type. // - List leftParams; - List leftConstraints; - getGenericParams(left, leftParams, leftConstraints); + // TODO(tfoley): the case should grab its type when + // doing its own header checking, rather than rely on this... + caseDecl->type.type = enumType; - List rightParams; - List rightConstraints; - getGenericParams(right, rightParams, rightConstraints); - - // For there to be any hope of a match, the two decls - // need to have the same number of explicit parameters. - // - Index paramCount = leftParams.getCount(); - if(paramCount != rightParams.getCount()) - return false; + ensureDecl(caseDecl, DeclCheckState::DefinitionChecked); + } - // Next we will walk through the parameters and look - // for a pair-wise match. - // - for(Index pp = 0; pp < paramCount; ++pp) + // For any enum case that didn't provide an explicit + // tag value, derived an appropriate tag value. + IntegerLiteralValue defaultTag = isEnumFlags ? 1 : 0; + for (auto caseDecl : decl->getMembersOfType()) + { + if (auto explicitTagValExpr = caseDecl->tagExpr) { - Decl* leftParam = leftParams[pp]; - Decl* rightParam = rightParams[pp]; + // This tag has an initializer, so it should establish + // the tag value for a successor case that doesn't + // provide an explicit tag. - if (const auto leftTypeParam = as(leftParam)) + IntVal* explicitTagVal = + tryConstantFoldExpr(explicitTagValExpr, ConstantFoldingKind::CompileTime, nullptr); + if (explicitTagVal) { - if (const auto rightTypeParam = as(rightParam)) + if (auto constIntVal = as(explicitTagVal)) { - // Right now any two type parameters are a match. - // Names are irrelevant to matching, and any constraints - // on the type parameters are represented as implicit - // extra parameters of the generic. - // - // TODO: If we ever supported type parameters with - // higher kinds we might need to make a check here - // that the kind of each parameter matches (which - // would in a sense be a kind of recursive check - // of the generic signature of the parameter). - // - continue; + defaultTag = constIntVal->getValue(); } - } - else if (auto leftValueParam = as(leftParam)) - { - if (auto rightValueParam = as(rightParam)) + else { - // In this case we have two generic value parameters, - // and they should only be considered to match if - // they have the same type. - // - // Note: We are assuming here that the type of a value - // parameter cannot be dependent on any of the type - // parameters in the same signature. This is a reasonable - // assumption for now, but could get thorny down the road. - // - if (!leftValueParam->getType()->equals(rightValueParam->getType())) - { - // If the value parameters have non-matching types, - // then the full generic signatures do not match. - // - return false; - } - - // Generic value parameters with the same type are - // always considered to match. - // - continue; + // TODO: need to handle other possibilities here + getSink()->diagnose(explicitTagValExpr, Diagnostics::unexpectedEnumTagExpr); } } - - // If we get to this point, then we have two parameters that - // were of different syntatic categories (e.g., one type parameter - // and one value parameter), so the signatures clearly don't match. - // - return false; + else + { + // If this happens, then the explicit tag value expression + // doesn't seem to be a constant after all. In this case + // we expect the checking logic to have applied already. + } } + else + { + // This tag has no initializer, so it should use + // the default tag value we are tracking. + IntegerLiteralExpr* tagValExpr = m_astBuilder->create(); + tagValExpr->loc = caseDecl->loc; + tagValExpr->type = QualType(tagType); + tagValExpr->value = defaultTag; - // At this point we know that the explicit generic parameters - // of `left` and `right` are aligned, but we need to check - // that the constraints that each declaration places on - // its parameters match. - // - // A first challenge that arises is that `left` and `right` - // will each express the constraints in terms of their - // own parameters. For example, consider the following - // declarations: - // - // void foo1(T value); - // void foo2(U value); - // - // It is "obvious" to a human that the signatures here - // match, but `foo1` has a constraint `T : IFoo` while - // `foo2` has a constraint `U : IFoo`, and since `T` - // and `U` are distinct `Decl`s, those constraints - // are not obviously equivalent. - // - // We will work around this first issue by creating - // a substitution taht lists all the parameters of - // `left`, which we can use to specialize `right` - // so that it aligns. - // - // In terms of the example above, this is like constructing - // `foo2` so that its constraint, after specialization, - // looks like `T : IFoo`. - // - auto& substInnerRightToLeft = *outSpecializedRightInner; - List leftArgs = getDefaultSubstitutionArgs(m_astBuilder, this, left); - substInnerRightToLeft = m_astBuilder->getGenericAppDeclRef(makeDeclRef(right), leftArgs.getArrayView()); - - // We should now be able to enumerate the constraints - // on `right` in a way that uses the same type parameters - // as `left`, using `rightDeclRef`. - // - // At this point a second problem arises: if/when we support - // more flexibility in how generic parameter constraints are - // specified, it will be possible for two declarations to - // list the "same" constraints in very different ways. - // - // For example, if we support a `where` clause for separating - // the constraints from the parameters, then the following - // two declarations should have equivalent signatures: - // - // void foo1(T value) - // where T : IFoo - // { ... } - // - // void foo2(T value) - // { ... } - // - // Similarly, if we allow for general compositions of interfaces - // to be used as constraints, then there can be more than one - // way to specify the same constraints: - // - // void foo1(T value); - // void foo2(T value); - // - // Adding support for equality constraints in `where` clauses - // also creates opportunities for multiple equivalent expressions: - // - // void foo1(...) where T.A == U.A; - // void foo2(...) where U.A == T.A; - // - // A robsut version of the checking logic here should attempt - // to *canonicalize* all of the constraints. Canonicalization - // should involve putting constraints into a deterministic - // order (e.g., for a generic with `` all the constraints - // on `T` should come before those on `U`), rewriting individual - // constraints into a canonical form (e.g., `T : IFoo & IBar` - // should turn into two constraints: `T : IFoo` and `T : IBar`), - // etc. - // - // Once the constraints are in a canonical form we should be able - // to test them for pairwise equivalent. As a safety measure we - // could also try to test whether one set of constraints implies - // the other (since implication in both directions should imply - // equivalence, in which case our canonicalization had better - // have produced the same result). - // - // For now we are taking a simpler short-cut by assuming - // that constraints are already in a canonical form, which - // is reasonable for now as the syntax only allows a single - // constraint per parameter, specified on the parameter itself. - // - // Under the assumption of canonical constraints, we can - // assume that different numbers of constraints must indicate - // a signature mismatch. - // - Index constraintCount = leftConstraints.getCount(); - if(constraintCount != rightConstraints.getCount()) - return false; + caseDecl->tagExpr = tagValExpr; + } - for (Index cc = 0; cc < constraintCount; ++cc) + // Default tag for the next case will be one more than + // for the most recent case. + // + if (!isEnumFlags) + defaultTag++; + else { - // Note that we use a plain `Decl` pointer for the left - // constraint, but need to use a `DeclRef` for the right - // constraint so that we can take the substitution - // arguments into account. - // - GenericTypeConstraintDecl* leftConstraint = leftConstraints[cc]; - auto unspecializedRightConstarintDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(rightConstraints[cc])); - DeclRef rightConstraint = substInnerRightToLeft.substitute( - m_astBuilder, unspecializedRightConstarintDeclRef).as(); + if (defaultTag == 0) + defaultTag = 1; + else + defaultTag <<= 1; + } + } +} - // For now, every constraint has the form `sub : sup` - // to indicate that `sub` must be a subtype of `sup`. - // - // Two such constraints are equivalent if their `sub` - // and `sup` types are pairwise equivalent. - // - auto leftSub = leftConstraint->sub.type; - auto rightSub = substInnerRightToLeft.substitute(m_astBuilder, rightConstraint.getDecl()->sub.type); - if(!leftSub->equals(rightSub)) - return false; +void SemanticsDeclBodyVisitor::visitEnumCaseDecl(EnumCaseDecl* decl) +{ + // An enum case had better appear inside an enum! + // + // TODO: Do we need/want to support generic cases some day? + auto parentEnumDecl = as(decl->parentDecl); + SLANG_ASSERT(parentEnumDecl); + + decl->type.type = DeclRefType::create(m_astBuilder, makeDeclRef(parentEnumDecl)); + + // The tag type should have already been set by + // the surrounding `enum` declaration. + auto tagType = parentEnumDecl->tagType; + SLANG_ASSERT(tagType); + + // Need to check the init expression, if present, since + // that represents the explicit tag for this case. + if (auto initExpr = decl->tagExpr) + { + initExpr = CheckTerm(initExpr); + initExpr = coerce(CoercionSite::General, tagType, initExpr); + + // We want to enforce that this is an integer constant + // expression, but we don't actually care to retain + // the value. + CheckIntegerConstantExpression( + initExpr, + IntegerConstantExpressionCoercionType::AnyInteger, + nullptr, + ConstantFoldingKind::CompileTime); - auto leftSup = leftConstraint->sup.type; - auto rightSup = substInnerRightToLeft.substitute(m_astBuilder, rightConstraint.getDecl()->sup.type); - if(!leftSup->equals(rightSup)) - return false; - } + decl->tagExpr = initExpr; + } +} - // If we have checked all of the (canonicalized) constraints - // and found them to be pairwise equivalent then the two - // generic signatures seem to match. - // - return true; +void SemanticsVisitor::ensureDeclBase( + DeclBase* declBase, + DeclCheckState state, + SemanticsContext* baseContext) +{ + if (auto decl = as(declBase)) + { + ensureDecl(decl, state, baseContext); + } + else if (auto declGroup = as(declBase)) + { + for (auto dd : declGroup->decls) + { + ensureDecl(dd, state, baseContext); + } } - - bool SemanticsVisitor::doFunctionSignaturesMatch( - DeclRef fst, - DeclRef snd) + else { + SLANG_UNEXPECTED("unknown case for declaration"); + } +} - // TODO(tfoley): This copies the parameter array, which is bad for performance. - auto fstParams = getParameters(m_astBuilder, fst).toArray(); - auto sndParams = getParameters(m_astBuilder, snd).toArray(); - - // If the functions have different numbers of parameters, then - // their signatures trivially don't match. - auto fstParamCount = fstParams.getCount(); - auto sndParamCount = sndParams.getCount(); - if (fstParamCount != sndParamCount) - return false; - - for (Index ii = 0; ii < fstParamCount; ++ii) - { - auto fstParam = fstParams[ii]; - auto sndParam = sndParams[ii]; +void SemanticsDeclHeaderVisitor::visitTypeDefDecl(TypeDefDecl* decl) +{ + SemanticsVisitor visitor(withDeclToExcludeFromLookup(decl)); + decl->type = visitor.CheckProperType(decl->type); + checkVisibility(decl); +} - // If a given parameter type doesn't match, then signatures don't match - if (!getType(m_astBuilder, fstParam)->equals(getType(m_astBuilder, sndParam))) - return false; +void SemanticsDeclHeaderVisitor::visitGlobalGenericParamDecl(GlobalGenericParamDecl* decl) +{ + // global generic param only allowed in global scope + auto program = as(decl->parentDecl); + if (!program) + getSink()->diagnose(decl, Slang::Diagnostics::globalGenParamInGlobalScopeOnly); +} - // If one parameter is `out` and the other isn't, then they don't match - // - // Note(tfoley): we don't consider `out` and `inout` as distinct here, - // because there is no way for overload resolution to pick between them. - if (fstParam.getDecl()->hasModifier() != sndParam.getDecl()->hasModifier()) - return false; +void SemanticsDeclHeaderVisitor::visitAssocTypeDecl(AssocTypeDecl* decl) +{ + // assoctype only allowed in an interface + auto interfaceDecl = as(decl->parentDecl); + if (!interfaceDecl) + getSink()->diagnose(decl, Slang::Diagnostics::assocTypeInInterfaceOnly); + checkVisibility(decl); +} - // If one parameter is `ref` and the other isn't, then they don't match. - // - if(fstParam.getDecl()->hasModifier() != sndParam.getDecl()->hasModifier()) - return false; +SemanticsContext SemanticsDeclBodyVisitor::registerDifferentiableTypesForFunc( + FunctionDeclBase* decl) +{ + auto newContext = withParentFunc(decl); + if (newContext.getParentDifferentiableAttribute()) + { + // Register additional types outside the function body first. + auto oldAttr = m_parentDifferentiableAttr; + m_parentDifferentiableAttr = newContext.getParentDifferentiableAttribute(); + for (auto param : decl->getParameters()) + maybeRegisterDifferentiableType(m_astBuilder, param->type.type); + maybeRegisterDifferentiableType(m_astBuilder, decl->returnType.type); + if (as(decl) || !isEffectivelyStatic(decl)) + { + auto parentDecl = getParentDecl(decl); + auto parentDeclRef = + createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(parentDecl)); + auto thisType = calcThisType(parentDeclRef); + maybeRegisterDifferentiableType(m_astBuilder, thisType); + } + m_parentDifferentiableAttr = oldAttr; + } + return newContext; +} - // If one parameter is `constref` and the other isn't, then they don't match. - // - if (fstParam.getDecl()->hasModifier() != sndParam.getDecl()->hasModifier()) - return false; - } +void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) +{ + auto newContext = registerDifferentiableTypesForFunc(decl); + if (const auto body = decl->body) + { + checkStmt(decl->body, newContext); + } +} - // Note(tfoley): return type doesn't enter into it, because we can't take - // calling context into account during overload resolution. +void SemanticsVisitor::getGenericParams( + GenericDecl* decl, + List& outParams, + List& outConstraints) +{ + for (auto dd : decl->members) + { + if (dd == decl->inner) + continue; - return true; + if (auto typeParamDecl = as(dd)) + outParams.add(typeParamDecl); + else if (auto valueParamDecl = as(dd)) + outParams.add(valueParamDecl); + else if (auto constraintDecl = as(dd)) + outConstraints.add(constraintDecl); } +} + +bool SemanticsVisitor::doGenericSignaturesMatch( + GenericDecl* left, + GenericDecl* right, + DeclRef* outSpecializedRightInner) +{ + // Our first goal here is to determine if `left` and + // `right` have equivalent lists of explicit + // generic parameters. + // + // Once we have determined that the explicit generic + // parameters match, we will look at the constraints + // placed on those parameters to see if they are + // equivalent. + // + // We thus start by extracting the explicit parameters + // and the constraints from each declaration. + // + List leftParams; + List leftConstraints; + getGenericParams(left, leftParams, leftConstraints); + + List rightParams; + List rightConstraints; + getGenericParams(right, rightParams, rightConstraints); + + // For there to be any hope of a match, the two decls + // need to have the same number of explicit parameters. + // + Index paramCount = leftParams.getCount(); + if (paramCount != rightParams.getCount()) + return false; - List getDefaultSubstitutionArgs(ASTBuilder* astBuilder, SemanticsVisitor* semantics, GenericDecl* genericDecl) + // Next we will walk through the parameters and look + // for a pair-wise match. + // + for (Index pp = 0; pp < paramCount; ++pp) { - List args; - if (astBuilder->m_cachedGenericDefaultArgs.tryGetValue(genericDecl, args)) - return args; + Decl* leftParam = leftParams[pp]; + Decl* rightParam = rightParams[pp]; - for (auto mm : genericDecl->members) + if (const auto leftTypeParam = as(leftParam)) { - if (auto genericTypeParamDecl = as(mm)) - { - args.add(DeclRefType::create(astBuilder, astBuilder->getDirectDeclRef(genericTypeParamDecl))); - } - else if (auto genericTypePackParamDecl = as(mm)) - { - auto packType = DeclRefType::create(astBuilder, astBuilder->getDirectDeclRef(genericTypePackParamDecl)); - args.add(packType); - } - else if (auto genericValueParamDecl = as(mm)) + if (const auto rightTypeParam = as(rightParam)) { - if (semantics) - semantics->ensureDecl(genericValueParamDecl, DeclCheckState::ReadyForLookup); - - args.add(astBuilder->getOrCreate( - genericValueParamDecl->getType(), - astBuilder->getDirectDeclRef(genericValueParamDecl))); + // Right now any two type parameters are a match. + // Names are irrelevant to matching, and any constraints + // on the type parameters are represented as implicit + // extra parameters of the generic. + // + // TODO: If we ever supported type parameters with + // higher kinds we might need to make a check here + // that the kind of each parameter matches (which + // would in a sense be a kind of recursive check + // of the generic signature of the parameter). + // + continue; } } - - bool shouldCache = true; - - // create default substitution arguments for constraints - for (auto mm : genericDecl->members) + else if (auto leftValueParam = as(leftParam)) { - if (auto genericTypeConstraintDecl = as(mm)) + if (auto rightValueParam = as(rightParam)) { - if (semantics) - semantics->ensureDecl(genericTypeConstraintDecl, DeclCheckState::ReadyForReference); - auto constraintDeclRef = astBuilder->getDirectDeclRef(genericTypeConstraintDecl); - auto supType = getSup(astBuilder, constraintDeclRef); - if (!supType) + // In this case we have two generic value parameters, + // and they should only be considered to match if + // they have the same type. + // + // Note: We are assuming here that the type of a value + // parameter cannot be dependent on any of the type + // parameters in the same signature. This is a reasonable + // assumption for now, but could get thorny down the road. + // + if (!leftValueParam->getType()->equals(rightValueParam->getType())) { - args.add(astBuilder->getErrorType()); - shouldCache = false; - continue; + // If the value parameters have non-matching types, + // then the full generic signatures do not match. + // + return false; } - auto witness = - astBuilder->getDeclaredSubtypeWitness( - getSub(astBuilder, constraintDeclRef), - getSup(astBuilder, constraintDeclRef), - constraintDeclRef); - // TODO: this is an ugly hack to prevent crashing. - // In early stages of compilation witness->sub and witness->sup may not be checked yet. - // When semanticVisitor is present we have used that to ensure the type is checked. - // However due to how the code is written we cannot guarantee semanticVisitor is always available - // here, and if we can't get the checked sup/sub type this subst is incomplete and should not be - // cached. - if (!witness->getSub()) - shouldCache = false; - args.add(witness); - } - } - - if (shouldCache) - astBuilder->m_cachedGenericDefaultArgs[genericDecl] = args; - - return args; - } - - typedef Dictionary TargetDeclDictionary; - - static void _addTargetModifiers(CallableDecl* decl, TargetDeclDictionary& ioDict) - { - if (auto specializedModifier = decl->findModifier()) - { - // If it's specialized for target it should have a body... - if (auto funcDecl = as(decl)) - { - // Normally if we have specialization for target it must have a body. - if (funcDecl->body == nullptr) - { - // If it doesn't have a body but does have a target intrinsic/SPIRVInstructionOp - // it's probably ok - SLANG_ASSERT(funcDecl->findModifier() || - funcDecl->findModifier()); - } + // Generic value parameters with the same type are + // always considered to match. + // + continue; } - Name* targetName = specializedModifier->targetToken.getName(); - - ioDict.addIfNotExists(targetName, decl); } - else - { - for (auto modifier : decl->getModifiersOfType()) - { - Name* targetName = modifier->targetToken.getName(); - ioDict.addIfNotExists(targetName, decl); - } - auto funcDecl = as(decl); - if (funcDecl && funcDecl->body) - { - // Should only be one body if it isn't specialized for target. - // Use nullptr for this scenario - ioDict.addIfNotExists(nullptr, decl); - } - } + // If we get to this point, then we have two parameters that + // were of different syntatic categories (e.g., one type parameter + // and one value parameter), so the signatures clearly don't match. + // + return false; } - Result SemanticsVisitor::checkFuncRedeclaration( - FuncDecl* newDecl, - FuncDecl* oldDecl) + // At this point we know that the explicit generic parameters + // of `left` and `right` are aligned, but we need to check + // that the constraints that each declaration places on + // its parameters match. + // + // A first challenge that arises is that `left` and `right` + // will each express the constraints in terms of their + // own parameters. For example, consider the following + // declarations: + // + // void foo1(T value); + // void foo2(U value); + // + // It is "obvious" to a human that the signatures here + // match, but `foo1` has a constraint `T : IFoo` while + // `foo2` has a constraint `U : IFoo`, and since `T` + // and `U` are distinct `Decl`s, those constraints + // are not obviously equivalent. + // + // We will work around this first issue by creating + // a substitution taht lists all the parameters of + // `left`, which we can use to specialize `right` + // so that it aligns. + // + // In terms of the example above, this is like constructing + // `foo2` so that its constraint, after specialization, + // looks like `T : IFoo`. + // + auto& substInnerRightToLeft = *outSpecializedRightInner; + List leftArgs = getDefaultSubstitutionArgs(m_astBuilder, this, left); + substInnerRightToLeft = + m_astBuilder->getGenericAppDeclRef(makeDeclRef(right), leftArgs.getArrayView()); + + // We should now be able to enumerate the constraints + // on `right` in a way that uses the same type parameters + // as `left`, using `rightDeclRef`. + // + // At this point a second problem arises: if/when we support + // more flexibility in how generic parameter constraints are + // specified, it will be possible for two declarations to + // list the "same" constraints in very different ways. + // + // For example, if we support a `where` clause for separating + // the constraints from the parameters, then the following + // two declarations should have equivalent signatures: + // + // void foo1(T value) + // where T : IFoo + // { ... } + // + // void foo2(T value) + // { ... } + // + // Similarly, if we allow for general compositions of interfaces + // to be used as constraints, then there can be more than one + // way to specify the same constraints: + // + // void foo1(T value); + // void foo2(T value); + // + // Adding support for equality constraints in `where` clauses + // also creates opportunities for multiple equivalent expressions: + // + // void foo1(...) where T.A == U.A; + // void foo2(...) where U.A == T.A; + // + // A robsut version of the checking logic here should attempt + // to *canonicalize* all of the constraints. Canonicalization + // should involve putting constraints into a deterministic + // order (e.g., for a generic with `` all the constraints + // on `T` should come before those on `U`), rewriting individual + // constraints into a canonical form (e.g., `T : IFoo & IBar` + // should turn into two constraints: `T : IFoo` and `T : IBar`), + // etc. + // + // Once the constraints are in a canonical form we should be able + // to test them for pairwise equivalent. As a safety measure we + // could also try to test whether one set of constraints implies + // the other (since implication in both directions should imply + // equivalence, in which case our canonicalization had better + // have produced the same result). + // + // For now we are taking a simpler short-cut by assuming + // that constraints are already in a canonical form, which + // is reasonable for now as the syntax only allows a single + // constraint per parameter, specified on the parameter itself. + // + // Under the assumption of canonical constraints, we can + // assume that different numbers of constraints must indicate + // a signature mismatch. + // + Index constraintCount = leftConstraints.getCount(); + if (constraintCount != rightConstraints.getCount()) + return false; + + for (Index cc = 0; cc < constraintCount; ++cc) { - // There are a few different cases that this function needs - // to check for: - // - // * If `newDecl` and `oldDecl` have different signatures such - // that they can always be distinguished at call sites, then - // they don't conflict and don't count as redeclarations. + // Note that we use a plain `Decl` pointer for the left + // constraint, but need to use a `DeclRef` for the right + // constraint so that we can take the substitution + // arguments into account. // - // * If `newDecl` and `oldDecl` have matching signatures, but - // differ in return type (or other details that would affect - // compatibility), then the declarations conflict and an - // error needs to be diagnosed. - // - // * If `newDecl` and `oldDecl` have matching/compatible sigantures, - // but differ when it comes to target-specific overloading, - // then they can co-exist. - // - // * If `newDecl` and `oldDecl` have matching/compatible signatures - // and are specialized for the same target(s), then only - // one can have a body (in which case the other is a forward declaration), - // or else we have a redefinition error. - - auto newGenericDecl = as(newDecl->parentDecl); - auto oldGenericDecl = as(oldDecl->parentDecl); + GenericTypeConstraintDecl* leftConstraint = leftConstraints[cc]; + auto unspecializedRightConstarintDeclRef = createDefaultSubstitutionsIfNeeded( + m_astBuilder, + this, + makeDeclRef(rightConstraints[cc])); + DeclRef rightConstraint = + substInnerRightToLeft.substitute(m_astBuilder, unspecializedRightConstarintDeclRef) + .as(); - // If one declaration is a prefix/postfix operator, and the - // other is not a matching operator, then don't consider these - // to be re-declarations. + // For now, every constraint has the form `sub : sup` + // to indicate that `sub` must be a subtype of `sup`. // - // Note(tfoley): Any attempt to call such an operator using - // ordinary function-call syntax (if we decided to allow it) - // would be ambiguous in such a case, of course. + // Two such constraints are equivalent if their `sub` + // and `sup` types are pairwise equivalent. // - if (newDecl->hasModifier() != oldDecl->hasModifier()) - return SLANG_OK; - if (newDecl->hasModifier() != oldDecl->hasModifier()) - return SLANG_OK; + auto leftSub = leftConstraint->sub.type; + auto rightSub = + substInnerRightToLeft.substitute(m_astBuilder, rightConstraint.getDecl()->sub.type); + if (!leftSub->equals(rightSub)) + return false; - // If one is generic and the other isn't, then there is no match. - if ((newGenericDecl != nullptr) != (oldGenericDecl != nullptr)) - return SLANG_OK; + auto leftSup = leftConstraint->sup.type; + auto rightSup = + substInnerRightToLeft.substitute(m_astBuilder, rightConstraint.getDecl()->sup.type); + if (!leftSup->equals(rightSup)) + return false; + } - // We are going to be comparing the signatures of the - // two functions, but if they are *generic* functions - // then we will need to compare them with consistent - // specializations in place. - // - // We'll go ahead and create some (unspecialized) declaration - // references here, just to be prepared. - // - DeclRef newDeclRef(newDecl); - DeclRef oldDeclRef(oldDecl); + // If we have checked all of the (canonicalized) constraints + // and found them to be pairwise equivalent then the two + // generic signatures seem to match. + // + return true; +} - // If we are working with generic functions, then we need to - // consider if their generic signatures match. - // - if(newGenericDecl) - { - // If one declaration is generic, the other must be. - // (This condition was already checked above) - // - SLANG_ASSERT(oldGenericDecl); +bool SemanticsVisitor::doFunctionSignaturesMatch(DeclRef fst, DeclRef snd) +{ - // As part of checking if the generic signatures match, - // we will produce a substitution that can be used to - // reference `oldGenericDecl` with the generic parameters - // substituted for those of `newDecl`. - // - // One way to think about it is that if we have these - // declarations (ignore the name differences...): - // - // // oldDecl: - // void foo1(T x); - // - // // newDecl: - // void foo2(U x); - // - // Then we will compare the parameter types of `foo2` - // against the specialization `foo1`. - // - DeclRef specializedOldDeclInner; - if(!doGenericSignaturesMatch(newGenericDecl, oldGenericDecl, &specializedOldDeclInner)) - return SLANG_OK; + // TODO(tfoley): This copies the parameter array, which is bad for performance. + auto fstParams = getParameters(m_astBuilder, fst).toArray(); + auto sndParams = getParameters(m_astBuilder, snd).toArray(); - oldDeclRef = specializedOldDeclInner.as(); - } + // If the functions have different numbers of parameters, then + // their signatures trivially don't match. + auto fstParamCount = fstParams.getCount(); + auto sndParamCount = sndParams.getCount(); + if (fstParamCount != sndParamCount) + return false; - // If the parameter signatures don't match, then don't worry - if (!doFunctionSignaturesMatch(newDeclRef, oldDeclRef)) - return SLANG_OK; + for (Index ii = 0; ii < fstParamCount; ++ii) + { + auto fstParam = fstParams[ii]; + auto sndParam = sndParams[ii]; - // If the declatation is declared by 'extern', and new definition is with 'export', then - // we should let overload resolution to handle it. - if (oldDecl->hasModifier() && newDecl->hasModifier()) - { - return SLANG_OK; - } + // If a given parameter type doesn't match, then signatures don't match + if (!getType(m_astBuilder, fstParam)->equals(getType(m_astBuilder, sndParam))) + return false; - // If we get this far, then we've got two declarations in the same - // scope, with the same name and signature, so they appear - // to be redeclarations. + // If one parameter is `out` and the other isn't, then they don't match // - // We will track that redeclaration occured, so that we can - // take it into account for overload resolution. + // Note(tfoley): we don't consider `out` and `inout` as distinct here, + // because there is no way for overload resolution to pick between them. + if (fstParam.getDecl()->hasModifier() != + sndParam.getDecl()->hasModifier()) + return false; + + // If one parameter is `ref` and the other isn't, then they don't match. // - // A huge complication that we'll need to deal with is that - // multiple declarations might introduce default values for - // (different) parameters, and we might need to merge across - // all of them (which could get complicated if defaults for - // parameters can reference earlier parameters). + if (fstParam.getDecl()->hasModifier() != + sndParam.getDecl()->hasModifier()) + return false; - // If the previous declaration wasn't already recorded - // as being part of a redeclaration family, then make - // it the primary declaration of a new family. - if (!oldDecl->primaryDecl) - { - oldDecl->primaryDecl = oldDecl; - } + // If one parameter is `constref` and the other isn't, then they don't match. + // + if (fstParam.getDecl()->hasModifier() != + sndParam.getDecl()->hasModifier()) + return false; + } - // The new declaration will belong to the family of - // the previous one, and so it will share the same - // primary declaration. - newDecl->primaryDecl = oldDecl->primaryDecl; - newDecl->nextDecl = nullptr; + // Note(tfoley): return type doesn't enter into it, because we can't take + // calling context into account during overload resolution. - // Next we want to chain the new declaration onto - // the linked list of redeclarations. - auto link = &oldDecl->nextDecl; - while (*link) - link = &(*link)->nextDecl; - *link = newDecl; + return true; +} - // Now that we've added things to a group of redeclarations, - // we can do some additional validation. +List getDefaultSubstitutionArgs( + ASTBuilder* astBuilder, + SemanticsVisitor* semantics, + GenericDecl* genericDecl) +{ + List args; + if (astBuilder->m_cachedGenericDefaultArgs.tryGetValue(genericDecl, args)) + return args; - // First, we will ensure that the return types match - // between the declarations, so that they are truly - // interchangeable. - // - // Note(tfoley): If we ever decide to add a beefier type - // system to Slang, we might allow overloads like this, - // so long as the desired result type can be disambiguated - // based on context at the call type. In that case we would - // consider result types earlier, as part of the signature - // matching step. - // - auto resultType = getResultType(m_astBuilder, newDeclRef); - auto prevResultType = getResultType(m_astBuilder, oldDeclRef); - if (!resultType->equals(prevResultType)) + for (auto mm : genericDecl->members) + { + if (auto genericTypeParamDecl = as(mm)) { - // Bad redeclaration - getSink()->diagnose(newDecl, Diagnostics::functionRedeclarationWithDifferentReturnType, newDecl->getName(), resultType, prevResultType); - getSink()->diagnose(oldDecl, Diagnostics::seePreviousDeclarationOf, newDecl->getName()); - - // Don't bother emitting other errors at this point - return SLANG_FAIL; + args.add(DeclRefType::create( + astBuilder, + astBuilder->getDirectDeclRef(genericTypeParamDecl))); + } + else if (auto genericTypePackParamDecl = as(mm)) + { + auto packType = DeclRefType::create( + astBuilder, + astBuilder->getDirectDeclRef(genericTypePackParamDecl)); + args.add(packType); } + else if (auto genericValueParamDecl = as(mm)) + { + if (semantics) + semantics->ensureDecl(genericValueParamDecl, DeclCheckState::ReadyForLookup); - // TODO: Enforce that the new declaration had better - // not specify a default value for any parameter that - // already had a default value in a prior declaration. + args.add(astBuilder->getOrCreate( + genericValueParamDecl->getType(), + astBuilder->getDirectDeclRef(genericValueParamDecl))); + } + } - // We are going to want to enforce that we cannot have - // two declarations of a function both specify bodies. - // Before we make that check, however, we need to deal - // with the case where the two function declarations - // might represent different target-specific versions - // of a function. - - // If both of the declarations have a body, then there - // is trouble, because we wouldn't know which one to - // use during code generation. + bool shouldCache = true; - // Here to cover the 'bodies'/target_intrinsics, we find all the targets that - // that are previously defined, and make sure the new definition - // doesn't try and define what is already defined. + // create default substitution arguments for constraints + for (auto mm : genericDecl->members) + { + if (auto genericTypeConstraintDecl = as(mm)) { - TargetDeclDictionary currentTargets; + if (semantics) + semantics->ensureDecl(genericTypeConstraintDecl, DeclCheckState::ReadyForReference); + auto constraintDeclRef = + astBuilder->getDirectDeclRef(genericTypeConstraintDecl); + auto supType = getSup(astBuilder, constraintDeclRef); + if (!supType) { - CallableDecl* curDecl = newDecl->primaryDecl; - while (curDecl) - { - if (curDecl != newDecl) - { - _addTargetModifiers(curDecl, currentTargets); - } - curDecl = curDecl->nextDecl; - } + args.add(astBuilder->getErrorType()); + shouldCache = false; + continue; } + auto witness = astBuilder->getDeclaredSubtypeWitness( + getSub(astBuilder, constraintDeclRef), + getSup(astBuilder, constraintDeclRef), + constraintDeclRef); + // TODO: this is an ugly hack to prevent crashing. + // In early stages of compilation witness->sub and witness->sup may not be checked yet. + // When semanticVisitor is present we have used that to ensure the type is checked. + // However due to how the code is written we cannot guarantee semanticVisitor is always + // available here, and if we can't get the checked sup/sub type this subst is incomplete + // and should not be cached. + if (!witness->getSub()) + shouldCache = false; + args.add(witness); + } + } - // Add the targets for this new decl - TargetDeclDictionary newTargets; - _addTargetModifiers(newDecl, newTargets); + if (shouldCache) + astBuilder->m_cachedGenericDefaultArgs[genericDecl] = args; - bool hasConflict = false; - for (auto& [target, value] : newTargets) - { - auto found = currentTargets.tryGetValue(target); - if (found) - { - // Redefinition - if (!hasConflict) - { - getSink()->diagnose(newDecl, Diagnostics::functionRedefinition, newDecl->getName()); - hasConflict = true; - } + return args; +} - auto prevDecl = *found; - getSink()->diagnose(prevDecl, Diagnostics::seePreviousDefinitionOf, prevDecl->getName()); - } - } +typedef Dictionary TargetDeclDictionary; - if (hasConflict) +static void _addTargetModifiers(CallableDecl* decl, TargetDeclDictionary& ioDict) +{ + if (auto specializedModifier = decl->findModifier()) + { + // If it's specialized for target it should have a body... + if (auto funcDecl = as(decl)) + { + // Normally if we have specialization for target it must have a body. + if (funcDecl->body == nullptr) { - return SLANG_FAIL; + // If it doesn't have a body but does have a target intrinsic/SPIRVInstructionOp + // it's probably ok + + SLANG_ASSERT( + funcDecl->findModifier() || + funcDecl->findModifier()); } } + Name* targetName = specializedModifier->targetToken.getName(); - // At this point we've processed the redeclaration and - // put it into a group, so there is no reason to keep - // looping and looking at prior declarations. - // - // While no diagnostics have been emitted, we return - // a failure result from the operation to indicate - // to the caller that they should stop looping over - // declarations at this point. - // - return SLANG_FAIL; + ioDict.addIfNotExists(targetName, decl); } - - Result SemanticsVisitor::checkRedeclaration(Decl* newDecl, Decl* oldDecl) + else { - // If either of the declarations being looked at is generic, then - // we want to consider the "inner" declaration instead when - // making decisions about what to allow or not. - // - if(auto newGenericDecl = as(newDecl)) - newDecl = newGenericDecl->inner; - if(auto oldGenericDecl = as(oldDecl)) - oldDecl = oldGenericDecl->inner; - - // Functions are special in that we can have many declarations - // with the same name in a given scope, and it is possible - // for them to co-exist as overloads, or even just be multiple - // declarations of the same function (thanks to the inherited - // legacy of C forward declarations). - // - // If both declarations are functions, we will check that - // they are allowed to co-exist using these more nuanced rules. - // - if( auto newFuncDecl = as(newDecl) ) + for (auto modifier : decl->getModifiersOfType()) { - if(auto oldFuncDecl = as(oldDecl) ) - { - // Both new and old declarations are functions, - // so redeclaration may be valid. - return checkFuncRedeclaration(newFuncDecl, oldFuncDecl); - } + Name* targetName = modifier->targetToken.getName(); + ioDict.addIfNotExists(targetName, decl); } - if (as(oldDecl) || as(newDecl)) + auto funcDecl = as(decl); + if (funcDecl && funcDecl->body) { - // It is allowed to have a decl whose name is the same as the module. - return SLANG_OK; + // Should only be one body if it isn't specialized for target. + // Use nullptr for this scenario + ioDict.addIfNotExists(nullptr, decl); } + } +} +Result SemanticsVisitor::checkFuncRedeclaration(FuncDecl* newDecl, FuncDecl* oldDecl) +{ + // There are a few different cases that this function needs + // to check for: + // + // * If `newDecl` and `oldDecl` have different signatures such + // that they can always be distinguished at call sites, then + // they don't conflict and don't count as redeclarations. + // + // * If `newDecl` and `oldDecl` have matching signatures, but + // differ in return type (or other details that would affect + // compatibility), then the declarations conflict and an + // error needs to be diagnosed. + // + // * If `newDecl` and `oldDecl` have matching/compatible sigantures, + // but differ when it comes to target-specific overloading, + // then they can co-exist. + // + // * If `newDecl` and `oldDecl` have matching/compatible signatures + // and are specialized for the same target(s), then only + // one can have a body (in which case the other is a forward declaration), + // or else we have a redefinition error. + + auto newGenericDecl = as(newDecl->parentDecl); + auto oldGenericDecl = as(oldDecl->parentDecl); + + // If one declaration is a prefix/postfix operator, and the + // other is not a matching operator, then don't consider these + // to be re-declarations. + // + // Note(tfoley): Any attempt to call such an operator using + // ordinary function-call syntax (if we decided to allow it) + // would be ambiguous in such a case, of course. + // + if (newDecl->hasModifier() != oldDecl->hasModifier()) + return SLANG_OK; + if (newDecl->hasModifier() != oldDecl->hasModifier()) + return SLANG_OK; + + // If one is generic and the other isn't, then there is no match. + if ((newGenericDecl != nullptr) != (oldGenericDecl != nullptr)) + return SLANG_OK; + + // We are going to be comparing the signatures of the + // two functions, but if they are *generic* functions + // then we will need to compare them with consistent + // specializations in place. + // + // We'll go ahead and create some (unspecialized) declaration + // references here, just to be prepared. + // + DeclRef newDeclRef(newDecl); + DeclRef oldDeclRef(oldDecl); + + // If we are working with generic functions, then we need to + // consider if their generic signatures match. + // + if (newGenericDecl) + { + // If one declaration is generic, the other must be. + // (This condition was already checked above) + // + SLANG_ASSERT(oldGenericDecl); + + // As part of checking if the generic signatures match, + // we will produce a substitution that can be used to + // reference `oldGenericDecl` with the generic parameters + // substituted for those of `newDecl`. + // + // One way to think about it is that if we have these + // declarations (ignore the name differences...): + // + // // oldDecl: + // void foo1(T x); + // + // // newDecl: + // void foo2(U x); + // + // Then we will compare the parameter types of `foo2` + // against the specialization `foo1`. + // + DeclRef specializedOldDeclInner; + if (!doGenericSignaturesMatch(newGenericDecl, oldGenericDecl, &specializedOldDeclInner)) + return SLANG_OK; - // For all other flavors of declaration, we do not - // allow duplicate declarations with the same name. - // - // TODO: We might consider allowing some other cases - // of overloading that can be safely disambiguated: - // - // * A type and a value (function/variable/etc.) of the same name can usually - // co-exist because we can distinguish which is needed by context. - // - // * Multiple generic types with the same name can co-exist - // if their generic parameter lists are sufficient to - // tell them apart at a use site. - - // We will diagnose a redeclaration error at the new declaration, - // and point to the old declaration for context. - // - getSink()->diagnose(newDecl, Diagnostics::redeclaration, newDecl->getName()); - getSink()->diagnose(oldDecl, Diagnostics::seePreviousDeclarationOf, oldDecl->getName()); + oldDeclRef = specializedOldDeclInner.as(); + } + + // If the parameter signatures don't match, then don't worry + if (!doFunctionSignaturesMatch(newDeclRef, oldDeclRef)) + return SLANG_OK; + + // If the declatation is declared by 'extern', and new definition is with 'export', then + // we should let overload resolution to handle it. + if (oldDecl->hasModifier() && newDecl->hasModifier()) + { + return SLANG_OK; + } + + // If we get this far, then we've got two declarations in the same + // scope, with the same name and signature, so they appear + // to be redeclarations. + // + // We will track that redeclaration occured, so that we can + // take it into account for overload resolution. + // + // A huge complication that we'll need to deal with is that + // multiple declarations might introduce default values for + // (different) parameters, and we might need to merge across + // all of them (which could get complicated if defaults for + // parameters can reference earlier parameters). + + // If the previous declaration wasn't already recorded + // as being part of a redeclaration family, then make + // it the primary declaration of a new family. + if (!oldDecl->primaryDecl) + { + oldDecl->primaryDecl = oldDecl; + } + + // The new declaration will belong to the family of + // the previous one, and so it will share the same + // primary declaration. + newDecl->primaryDecl = oldDecl->primaryDecl; + newDecl->nextDecl = nullptr; + + // Next we want to chain the new declaration onto + // the linked list of redeclarations. + auto link = &oldDecl->nextDecl; + while (*link) + link = &(*link)->nextDecl; + *link = newDecl; + + // Now that we've added things to a group of redeclarations, + // we can do some additional validation. + + // First, we will ensure that the return types match + // between the declarations, so that they are truly + // interchangeable. + // + // Note(tfoley): If we ever decide to add a beefier type + // system to Slang, we might allow overloads like this, + // so long as the desired result type can be disambiguated + // based on context at the call type. In that case we would + // consider result types earlier, as part of the signature + // matching step. + // + auto resultType = getResultType(m_astBuilder, newDeclRef); + auto prevResultType = getResultType(m_astBuilder, oldDeclRef); + if (!resultType->equals(prevResultType)) + { + // Bad redeclaration + getSink()->diagnose( + newDecl, + Diagnostics::functionRedeclarationWithDifferentReturnType, + newDecl->getName(), + resultType, + prevResultType); + getSink()->diagnose(oldDecl, Diagnostics::seePreviousDeclarationOf, newDecl->getName()); + + // Don't bother emitting other errors at this point return SLANG_FAIL; } + // TODO: Enforce that the new declaration had better + // not specify a default value for any parameter that + // already had a default value in a prior declaration. - void SemanticsVisitor::checkForRedeclaration(Decl* decl) - { - // We want to consider a "new" declaration in the context - // of some parent/container declaration, and compare it - // to pre-existing "old" declarations of the same name - // in the same container. - // - auto newDecl = decl; - auto parentDecl = decl->parentDecl; + // We are going to want to enforce that we cannot have + // two declarations of a function both specify bodies. + // Before we make that check, however, we need to deal + // with the case where the two function declarations + // might represent different target-specific versions + // of a function. - // Sanity check: there should always be a parent declaration. - // - SLANG_ASSERT(parentDecl); - if (!parentDecl) return; + // If both of the declarations have a body, then there + // is trouble, because we wouldn't know which one to + // use during code generation. - // If the declaration is the "inner" declaration of a generic, - // then we actually want to look one level up, because the - // peers/siblings of the declaration will belong to the same - // parent as the generic, not to the generic. - // - if( auto genericParentDecl = as(parentDecl) ) + // Here to cover the 'bodies'/target_intrinsics, we find all the targets that + // that are previously defined, and make sure the new definition + // doesn't try and define what is already defined. + { + TargetDeclDictionary currentTargets; { - // Note: we need to check here to be sure `newDecl` - // is the "inner" declaration and not one of the - // generic parameters, or else we will end up - // checking them at the wrong scope. - // - if( newDecl == genericParentDecl->inner ) + CallableDecl* curDecl = newDecl->primaryDecl; + while (curDecl) { - newDecl = parentDecl; - parentDecl = genericParentDecl->parentDecl; + if (curDecl != newDecl) + { + _addTargetModifiers(curDecl, currentTargets); + } + curDecl = curDecl->nextDecl; } } - // We will now look for other declarations with - // the same name in the same parent/container. - // - parentDecl->buildMemberDictionary(); - for (auto oldDecl = newDecl->nextInContainerWithSameName; oldDecl; oldDecl = oldDecl->nextInContainerWithSameName) + // Add the targets for this new decl + TargetDeclDictionary newTargets; + _addTargetModifiers(newDecl, newTargets); + + bool hasConflict = false; + for (auto& [target, value] : newTargets) { - // For each matching declaration, we will check - // whether the redeclaration should be allowed, - // and emit an appropriate diagnostic if not. - // - Result checkResult = checkRedeclaration(newDecl, oldDecl); + auto found = currentTargets.tryGetValue(target); + if (found) + { + // Redefinition + if (!hasConflict) + { + getSink()->diagnose( + newDecl, + Diagnostics::functionRedefinition, + newDecl->getName()); + hasConflict = true; + } - // The `checkRedeclaration` function will return a failure - // status (whether or not it actually emitted a diagnostic) - // if we should stop checking further redeclarations, because - // the declaration in question has been dealt with fully. - // - if(SLANG_FAILED(checkResult)) - break; + auto prevDecl = *found; + getSink()->diagnose( + prevDecl, + Diagnostics::seePreviousDefinitionOf, + prevDecl->getName()); + } + } + + if (hasConflict) + { + return SLANG_FAIL; } } + // At this point we've processed the redeclaration and + // put it into a group, so there is no reason to keep + // looping and looking at prior declarations. + // + // While no diagnostics have been emitted, we return + // a failure result from the operation to indicate + // to the caller that they should stop looping over + // declarations at this point. + // + return SLANG_FAIL; +} - void SemanticsDeclHeaderVisitor::visitParamDecl(ParamDecl* paramDecl) - { - // TODO: This logic should be shared with the other cases of - // variable declarations. The main reason I am not doing it - // yet is that we use a `ParamDecl` with a null type as a - // special case in attribute declarations, and that could - // trip up the ordinary variable checks. +Result SemanticsVisitor::checkRedeclaration(Decl* newDecl, Decl* oldDecl) +{ + // If either of the declarations being looked at is generic, then + // we want to consider the "inner" declaration instead when + // making decisions about what to allow or not. + // + if (auto newGenericDecl = as(newDecl)) + newDecl = newGenericDecl->inner; + if (auto oldGenericDecl = as(oldDecl)) + oldDecl = oldGenericDecl->inner; + + // Functions are special in that we can have many declarations + // with the same name in a given scope, and it is possible + // for them to co-exist as overloads, or even just be multiple + // declarations of the same function (thanks to the inherited + // legacy of C forward declarations). + // + // If both declarations are functions, we will check that + // they are allowed to co-exist using these more nuanced rules. + // + if (auto newFuncDecl = as(newDecl)) + { + if (auto oldFuncDecl = as(oldDecl)) + { + // Both new and old declarations are functions, + // so redeclaration may be valid. + return checkFuncRedeclaration(newFuncDecl, oldFuncDecl); + } + } + + if (as(oldDecl) || as(newDecl)) + { + // It is allowed to have a decl whose name is the same as the module. + return SLANG_OK; + } + + + // For all other flavors of declaration, we do not + // allow duplicate declarations with the same name. + // + // TODO: We might consider allowing some other cases + // of overloading that can be safely disambiguated: + // + // * A type and a value (function/variable/etc.) of the same name can usually + // co-exist because we can distinguish which is needed by context. + // + // * Multiple generic types with the same name can co-exist + // if their generic parameter lists are sufficient to + // tell them apart at a use site. + + // We will diagnose a redeclaration error at the new declaration, + // and point to the old declaration for context. + // + getSink()->diagnose(newDecl, Diagnostics::redeclaration, newDecl->getName()); + getSink()->diagnose(oldDecl, Diagnostics::seePreviousDeclarationOf, oldDecl->getName()); + return SLANG_FAIL; +} - auto typeExpr = paramDecl->type; - if(typeExpr.exp) - { - SemanticsVisitor subVisitor(withDeclToExcludeFromLookup(paramDecl)); - typeExpr = subVisitor.CheckUsableType(typeExpr, paramDecl); - paramDecl->type = typeExpr; - checkMeshOutputDecl(paramDecl); - } - if (auto declRefType = as(paramDecl->type.type)) +void SemanticsVisitor::checkForRedeclaration(Decl* decl) +{ + // We want to consider a "new" declaration in the context + // of some parent/container declaration, and compare it + // to pre-existing "old" declarations of the same name + // in the same container. + // + auto newDecl = decl; + auto parentDecl = decl->parentDecl; + + // Sanity check: there should always be a parent declaration. + // + SLANG_ASSERT(parentDecl); + if (!parentDecl) + return; + + // If the declaration is the "inner" declaration of a generic, + // then we actually want to look one level up, because the + // peers/siblings of the declaration will belong to the same + // parent as the generic, not to the generic. + // + if (auto genericParentDecl = as(parentDecl)) + { + // Note: we need to check here to be sure `newDecl` + // is the "inner" declaration and not one of the + // generic parameters, or else we will end up + // checking them at the wrong scope. + // + if (newDecl == genericParentDecl->inner) + { + newDecl = parentDecl; + parentDecl = genericParentDecl->parentDecl; + } + } + + // We will now look for other declarations with + // the same name in the same parent/container. + // + parentDecl->buildMemberDictionary(); + for (auto oldDecl = newDecl->nextInContainerWithSameName; oldDecl; + oldDecl = oldDecl->nextInContainerWithSameName) + { + // For each matching declaration, we will check + // whether the redeclaration should be allowed, + // and emit an appropriate diagnostic if not. + // + Result checkResult = checkRedeclaration(newDecl, oldDecl); + + // The `checkRedeclaration` function will return a failure + // status (whether or not it actually emitted a diagnostic) + // if we should stop checking further redeclarations, because + // the declaration in question has been dealt with fully. + // + if (SLANG_FAILED(checkResult)) + break; + } +} + + +void SemanticsDeclHeaderVisitor::visitParamDecl(ParamDecl* paramDecl) +{ + // TODO: This logic should be shared with the other cases of + // variable declarations. The main reason I am not doing it + // yet is that we use a `ParamDecl` with a null type as a + // special case in attribute declarations, and that could + // trip up the ordinary variable checks. + + auto typeExpr = paramDecl->type; + if (typeExpr.exp) + { + SemanticsVisitor subVisitor(withDeclToExcludeFromLookup(paramDecl)); + typeExpr = subVisitor.CheckUsableType(typeExpr, paramDecl); + paramDecl->type = typeExpr; + checkMeshOutputDecl(paramDecl); + } + + if (auto declRefType = as(paramDecl->type.type)) + { + if (declRefType->getDeclRef().getDecl()->findModifier()) { - if (declRefType->getDeclRef().getDecl()->findModifier()) + // Always pass a non-copyable type by reference. + // Remove all existing direction modifiers, and replace them with a single Ref modifier. + List newModifiers; + bool hasRefModifier = false; + bool isMutable = false; + for (auto modifier : paramDecl->modifiers) { - // Always pass a non-copyable type by reference. - // Remove all existing direction modifiers, and replace them with a single Ref modifier. - List newModifiers; - bool hasRefModifier = false; - bool isMutable = false; - for (auto modifier : paramDecl->modifiers) + if (as(modifier)) { - if (as(modifier)) - { - continue; - } - else if (as(modifier) || as(modifier)) - { - isMutable = true; - continue; - } - if (as(modifier) || as(modifier)) - { - hasRefModifier = true; - } - newModifiers.add(modifier); + continue; } - if (!hasRefModifier) + else if (as(modifier) || as(modifier)) { - if (isMutable) - newModifiers.add(this->getASTBuilder()->create()); - else - newModifiers.add(this->getASTBuilder()->create()); + isMutable = true; + continue; } - paramDecl->modifiers.first = newModifiers.getFirst(); - for (Index i = 0; i < newModifiers.getCount(); i++) + if (as(modifier) || as(modifier)) { - if (i < newModifiers.getCount() - 1) - newModifiers[i]->next = newModifiers[i + 1]; - else - newModifiers[i]->next = nullptr; + hasRefModifier = true; } + newModifiers.add(modifier); } - } - else if (isTypePack(paramDecl->type.type)) - { - // For now, we only allow parameter packs to be `const`. - bool hasConstModifier = false; - for (auto modifier : paramDecl->modifiers) + if (!hasRefModifier) { - if (as(modifier) || as(modifier) || as(modifier) || as(modifier)) - { - getSink()->diagnose(modifier, Diagnostics::parameterPackMustBeConst); - } - else if (as(modifier)) - { - hasConstModifier = true; - } + if (isMutable) + newModifiers.add(this->getASTBuilder()->create()); + else + newModifiers.add(this->getASTBuilder()->create()); } - if (!hasConstModifier) + paramDecl->modifiers.first = newModifiers.getFirst(); + for (Index i = 0; i < newModifiers.getCount(); i++) { - auto constModifier = this->getASTBuilder()->create(); - addModifier(paramDecl, constModifier); + if (i < newModifiers.getCount() - 1) + newModifiers[i]->next = newModifiers[i + 1]; + else + newModifiers[i]->next = nullptr; } } - - maybeApplyLayoutModifier(paramDecl); - - // Only texture types are allowed to have memory qualifiers on parameters - if(!paramDecl->type || paramDecl->type->astNodeType != ASTNodeType::TextureType) - { - auto MemoryQualifierSet = paramDecl->findModifier(); - if(!MemoryQualifierSet) - return; - for(auto mod : MemoryQualifierSet->getModifiers()) - getSink()->diagnose(paramDecl, Diagnostics::memoryQualifierNotAllowedOnANonImageTypeParameter, mod); - } } - - // This checks that the declaration is marked as "out" and changes the hlsl - // modifier based syntax into a proper type. - void SemanticsDeclHeaderVisitor::checkMeshOutputDecl(VarDeclBase* varDecl) + else if (isTypePack(paramDecl->type.type)) { - auto modifier = varDecl->findModifier(); - auto meshOutputType = as(varDecl->type.type); - bool isMeshOutput = modifier || meshOutputType; - - if(!isMeshOutput) - { - return; - } - // HLSL requires an 'out' modifier here, but since we don't operate - // under such strict compatability we can just not warn here. - if(!varDecl->findModifier() && modifier) - { - getSink()->diagnose(varDecl, Diagnostics::meshOutputMustBeOut); - } - - // - // If necessary, convert to our typed representation - // - if(!modifier) - { - return; - } - if(meshOutputType) - { - getSink()->diagnose(modifier, Diagnostics::unnecessaryHLSLMeshOutputModifier); - varDecl->type.type = m_astBuilder->getErrorType(); - return; - } - auto indexExpr = as(varDecl->type.exp); - if(!indexExpr) + // For now, we only allow parameter packs to be `const`. + bool hasConstModifier = false; + for (auto modifier : paramDecl->modifiers) { - getSink()->diagnose(varDecl, Diagnostics::meshOutputMustBeArray); - varDecl->type.type = m_astBuilder->getErrorType(); - return; + if (as(modifier) || as(modifier) || + as(modifier) || as(modifier)) + { + getSink()->diagnose(modifier, Diagnostics::parameterPackMustBeConst); + } + else if (as(modifier)) + { + hasConstModifier = true; + } } - if(indexExpr->indexExprs.getCount() != 1) + if (!hasConstModifier) { - getSink()->diagnose(varDecl, Diagnostics::meshOutputArrayMustHaveSize); - varDecl->type.type = m_astBuilder->getErrorType(); - return; + auto constModifier = this->getASTBuilder()->create(); + addModifier(paramDecl, constModifier); } - auto base = ExpectAType(indexExpr->baseExpression); - auto index = CheckIntegerConstantExpression( - indexExpr->indexExprs[0], - IntegerConstantExpressionCoercionType::AnyInteger, - nullptr, - ConstantFoldingKind::LinkTime, - getSink()); - - Type* d = m_astBuilder->getMeshOutputTypeFromModifier(modifier, base, index); - varDecl->type.type = d; } - void SemanticsDeclBodyVisitor::visitParamDecl(ParamDecl* paramDecl) + maybeApplyLayoutModifier(paramDecl); + + // Only texture types are allowed to have memory qualifiers on parameters + if (!paramDecl->type || paramDecl->type->astNodeType != ASTNodeType::TextureType) { - auto typeExpr = paramDecl->type; + auto MemoryQualifierSet = paramDecl->findModifier(); + if (!MemoryQualifierSet) + return; + for (auto mod : MemoryQualifierSet->getModifiers()) + getSink()->diagnose( + paramDecl, + Diagnostics::memoryQualifierNotAllowedOnANonImageTypeParameter, + mod); + } +} - if (!as(paramDecl->type) && doesTypeHaveTag(paramDecl->type, TypeTag::Unsized)) - { - getSink()->diagnose(paramDecl, Diagnostics::paramCannotBeUnsized, paramDecl); - } +// This checks that the declaration is marked as "out" and changes the hlsl +// modifier based syntax into a proper type. +void SemanticsDeclHeaderVisitor::checkMeshOutputDecl(VarDeclBase* varDecl) +{ + auto modifier = varDecl->findModifier(); + auto meshOutputType = as(varDecl->type.type); + bool isMeshOutput = modifier || meshOutputType; - // The "initializer" expression for a parameter represents - // a default argument value to use if an explicit one is - // not supplied. - if(auto initExpr = paramDecl->initExpr) - { - // We must check the expression and coerce it to the - // actual type of the parameter. - // - initExpr = CheckTerm(initExpr); - initExpr = coerce(CoercionSite::Initializer, typeExpr.type, initExpr); - paramDecl->initExpr = initExpr; - - // TODO: a default argument expression needs to - // conform to other constraints to be valid. - // For example, it should not be allowed to refer - // to other parameters of the same function (or maybe - // only the parameters to its left...). - - // A default argument value should not be allowed on an - // `out` or `inout` parameter. - // - // TODO: we could relax this by requiring the expression - // to yield an lvalue, but that seems like a feature - // with limited practical utility (and an easy source - // of confusing behavior). - // - // Note: the `InOutModifier` class inherits from `OutModifier`, - // so we only need to check for the base case. - // - if(paramDecl->findModifier()) - { - getSink()->diagnose(initExpr, Diagnostics::outputParameterCannotHaveDefaultValue); - } - } + if (!isMeshOutput) + { + return; + } + // HLSL requires an 'out' modifier here, but since we don't operate + // under such strict compatability we can just not warn here. + if (!varDecl->findModifier() && modifier) + { + getSink()->diagnose(varDecl, Diagnostics::meshOutputMustBeOut); + } + + // + // If necessary, convert to our typed representation + // + if (!modifier) + { + return; + } + if (meshOutputType) + { + getSink()->diagnose(modifier, Diagnostics::unnecessaryHLSLMeshOutputModifier); + varDecl->type.type = m_astBuilder->getErrorType(); + return; + } + auto indexExpr = as(varDecl->type.exp); + if (!indexExpr) + { + getSink()->diagnose(varDecl, Diagnostics::meshOutputMustBeArray); + varDecl->type.type = m_astBuilder->getErrorType(); + return; } + if (indexExpr->indexExprs.getCount() != 1) + { + getSink()->diagnose(varDecl, Diagnostics::meshOutputArrayMustHaveSize); + varDecl->type.type = m_astBuilder->getErrorType(); + return; + } + auto base = ExpectAType(indexExpr->baseExpression); + auto index = CheckIntegerConstantExpression( + indexExpr->indexExprs[0], + IntegerConstantExpressionCoercionType::AnyInteger, + nullptr, + ConstantFoldingKind::LinkTime, + getSink()); + + Type* d = m_astBuilder->getMeshOutputTypeFromModifier(modifier, base, index); + varDecl->type.type = d; +} +void SemanticsDeclBodyVisitor::visitParamDecl(ParamDecl* paramDecl) +{ + auto typeExpr = paramDecl->type; - static SeqStmt* _ensureCtorBodyIsSeqStmt(ASTBuilder* m_astBuilder, ConstructorDecl* decl) + if (!as(paramDecl->type) && + doesTypeHaveTag(paramDecl->type, TypeTag::Unsized)) { - // It is possible BlockStmt has a child with the type of - // `ExpressionStmt` if an existing constructor has only 1 - // expression. This would be a senario we need to - // put the `ExpressionStmt` inside a `SeqStmt`. - auto stmt = as(decl->body); - if (!stmt) - { - auto tmpExpr = decl->body; - auto blockStmt = m_astBuilder->create(); - blockStmt->body = tmpExpr; - decl->body = blockStmt; - stmt = blockStmt; - } - if (!as(stmt->body)) - { - auto tmpExpr = stmt->body; - auto seqStmt = m_astBuilder->create(); - seqStmt->stmts.add(tmpExpr); - stmt->body = seqStmt; - return seqStmt; - } - return as(stmt->body); + getSink()->diagnose(paramDecl, Diagnostics::paramCannotBeUnsized, paramDecl); } - void SemanticsDeclBodyVisitor::synthesizeCtorBodyForBases(ConstructorDecl* ctor, List& inheritanceDefaultCtorList, ThisExpr* thisExpr, SeqStmt* seqStmtChild) + // The "initializer" expression for a parameter represents + // a default argument value to use if an explicit one is + // not supplied. + if (auto initExpr = paramDecl->initExpr) { - // e.g. this->base = BaseType(); - for (auto& declInfo : inheritanceDefaultCtorList) - { - if (!declInfo.defaultCtor) - continue; + // We must check the expression and coerce it to the + // actual type of the parameter. + // + initExpr = CheckTerm(initExpr); + initExpr = coerce(CoercionSite::Initializer, typeExpr.type, initExpr); + paramDecl->initExpr = initExpr; - auto ctorToInvoke = m_astBuilder->create(); - ctorToInvoke->declRef = declInfo.defaultCtor->getDefaultDeclRef(); - ctorToInvoke->name = declInfo.defaultCtor->getName(); - ctorToInvoke->loc = declInfo.defaultCtor->loc; - ctorToInvoke->type = m_astBuilder->getFuncType(ArrayView(), ctor->returnType.type); + // TODO: a default argument expression needs to + // conform to other constraints to be valid. + // For example, it should not be allowed to refer + // to other parameters of the same function (or maybe + // only the parameters to its left...). - auto invoke = m_astBuilder->create(); - invoke->functionExpr = ctorToInvoke; + // A default argument value should not be allowed on an + // `out` or `inout` parameter. + // + // TODO: we could relax this by requiring the expression + // to yield an lvalue, but that seems like a feature + // with limited practical utility (and an easy source + // of confusing behavior). + // + // Note: the `InOutModifier` class inherits from `OutModifier`, + // so we only need to check for the base case. + // + if (paramDecl->findModifier()) + { + getSink()->diagnose(initExpr, Diagnostics::outputParameterCannotHaveDefaultValue); + } + } +} - auto assign = m_astBuilder->create(); - assign->left = coerce(CoercionSite::Initializer, declInfo.defaultCtor->returnType.type, thisExpr); - assign->right = invoke; - auto stmt = m_astBuilder->create(); - stmt->expression = assign; - stmt->loc = ctor->loc; - seqStmtChild->stmts.add(stmt); - } +static SeqStmt* _ensureCtorBodyIsSeqStmt(ASTBuilder* m_astBuilder, ConstructorDecl* decl) +{ + // It is possible BlockStmt has a child with the type of + // `ExpressionStmt` if an existing constructor has only 1 + // expression. This would be a senario we need to + // put the `ExpressionStmt` inside a `SeqStmt`. + auto stmt = as(decl->body); + if (!stmt) + { + auto tmpExpr = decl->body; + auto blockStmt = m_astBuilder->create(); + blockStmt->body = tmpExpr; + decl->body = blockStmt; + stmt = blockStmt; + } + if (!as(stmt->body)) + { + auto tmpExpr = stmt->body; + auto seqStmt = m_astBuilder->create(); + seqStmt->stmts.add(tmpExpr); + stmt->body = seqStmt; + return seqStmt; } + return as(stmt->body); +} - void SemanticsDeclBodyVisitor::synthesizeCtorBodyForMember(ConstructorDecl* ctor, Decl* member, ThisExpr* thisExpr, Dictionary& cachedDeclToCheckedVar, SeqStmt* seqStmtChild) +void SemanticsDeclBodyVisitor::synthesizeCtorBodyForBases( + ConstructorDecl* ctor, + List& inheritanceDefaultCtorList, + ThisExpr* thisExpr, + SeqStmt* seqStmtChild) +{ + // e.g. this->base = BaseType(); + for (auto& declInfo : inheritanceDefaultCtorList) { - auto varDeclBase = as(member); + if (!declInfo.defaultCtor) + continue; - // Static variables are initialized at start of runtime, not inside a constructor - if (!varDeclBase - || !varDeclBase->initExpr - || varDeclBase->hasModifier()) - return; + auto ctorToInvoke = m_astBuilder->create(); + ctorToInvoke->declRef = declInfo.defaultCtor->getDefaultDeclRef(); + ctorToInvoke->name = declInfo.defaultCtor->getName(); + ctorToInvoke->loc = declInfo.defaultCtor->loc; + ctorToInvoke->type = m_astBuilder->getFuncType(ArrayView(), ctor->returnType.type); - MemberExpr* memberExpr = m_astBuilder->create(); - memberExpr->baseExpression = thisExpr; - memberExpr->declRef = member->getDefaultDeclRef(); - memberExpr->scope = ctor->ownedScope; - memberExpr->loc = member->loc; - memberExpr->name = member->getName(); - memberExpr->type = DeclRefType::create(getASTBuilder(), member->getDefaultDeclRef()); + auto invoke = m_astBuilder->create(); + invoke->functionExpr = ctorToInvoke; auto assign = m_astBuilder->create(); - assign->left = memberExpr; - assign->right = varDeclBase->initExpr; - assign->loc = member->loc; - + assign->left = + coerce(CoercionSite::Initializer, declInfo.defaultCtor->returnType.type, thisExpr); + assign->right = invoke; auto stmt = m_astBuilder->create(); stmt->expression = assign; - stmt->loc = member->loc; - - Expr* checkedMemberVarExpr; - if (cachedDeclToCheckedVar.containsKey(member)) - checkedMemberVarExpr = cachedDeclToCheckedVar[member]; - else - { - checkedMemberVarExpr = CheckTerm(memberExpr); - cachedDeclToCheckedVar.add({ member, checkedMemberVarExpr }); - } - - if (!checkedMemberVarExpr->type.isLeftValue) - return; + stmt->loc = ctor->loc; seqStmtChild->stmts.add(stmt); } +} +void SemanticsDeclBodyVisitor::synthesizeCtorBodyForMember( + ConstructorDecl* ctor, + Decl* member, + ThisExpr* thisExpr, + Dictionary& cachedDeclToCheckedVar, + SeqStmt* seqStmtChild) +{ + auto varDeclBase = as(member); - void SemanticsDeclBodyVisitor::synthesizeCtorBody(DeclAndCtorInfo& structDeclInfo, List& inheritanceDefaultCtorList, StructDecl* structDecl) - { - Dictionary cachedDeclToCheckedVar; - for (auto ctor : structDeclInfo.ctorList) - { - auto seqStmt = _ensureCtorBodyIsSeqStmt(m_astBuilder, ctor); - auto seqStmtChild = m_astBuilder->create(); - seqStmtChild->stmts.reserve(inheritanceDefaultCtorList.getCount() + structDecl->members.getCount()); + // Static variables are initialized at start of runtime, not inside a constructor + if (!varDeclBase || !varDeclBase->initExpr || varDeclBase->hasModifier()) + return; - ThisExpr* thisExpr = m_astBuilder->create(); - thisExpr->scope = ctor->ownedScope; - thisExpr->type = ctor->returnType.type; + MemberExpr* memberExpr = m_astBuilder->create(); + memberExpr->baseExpression = thisExpr; + memberExpr->declRef = member->getDefaultDeclRef(); + memberExpr->scope = ctor->ownedScope; + memberExpr->loc = member->loc; + memberExpr->name = member->getName(); + memberExpr->type = DeclRefType::create(getASTBuilder(), member->getDefaultDeclRef()); - // Initialize base type by using its default constructor if it has one. - synthesizeCtorBodyForBases(ctor, inheritanceDefaultCtorList, thisExpr, seqStmtChild); + auto assign = m_astBuilder->create(); + assign->left = memberExpr; + assign->right = varDeclBase->initExpr; + assign->loc = member->loc; - // Initialize member variables by using their default value if they have one - // e.g. this->member = default_value - for (auto& m : structDecl->members) - { - synthesizeCtorBodyForMember(ctor, m, thisExpr, cachedDeclToCheckedVar, seqStmtChild); - } + auto stmt = m_astBuilder->create(); + stmt->expression = assign; + stmt->loc = member->loc; - if (seqStmtChild->stmts.getCount() != 0) - { - seqStmt->stmts.insert(0, seqStmtChild); - } - } + Expr* checkedMemberVarExpr; + if (cachedDeclToCheckedVar.containsKey(member)) + checkedMemberVarExpr = cachedDeclToCheckedVar[member]; + else + { + checkedMemberVarExpr = CheckTerm(memberExpr); + cachedDeclToCheckedVar.add({member, checkedMemberVarExpr}); } - void SemanticsDeclBodyVisitor::visitAggTypeDecl(AggTypeDecl* aggTypeDecl) + if (!checkedMemberVarExpr->type.isLeftValue) + return; + + seqStmtChild->stmts.add(stmt); +} + + +void SemanticsDeclBodyVisitor::synthesizeCtorBody( + DeclAndCtorInfo& structDeclInfo, + List& inheritanceDefaultCtorList, + StructDecl* structDecl) +{ + Dictionary cachedDeclToCheckedVar; + for (auto ctor : structDeclInfo.ctorList) { - if (aggTypeDecl->hasTag(TypeTag::Incomplete) && aggTypeDecl->hasModifier()) - { - getSink()->diagnose(aggTypeDecl->loc, Diagnostics::cannotExportIncompleteType, aggTypeDecl); - } + auto seqStmt = _ensureCtorBodyIsSeqStmt(m_astBuilder, ctor); + auto seqStmtChild = m_astBuilder->create(); + seqStmtChild->stmts.reserve( + inheritanceDefaultCtorList.getCount() + structDecl->members.getCount()); - auto structDecl = as(aggTypeDecl); - if (!structDecl) - return; + ThisExpr* thisExpr = m_astBuilder->create(); + thisExpr->scope = ctor->ownedScope; + thisExpr->type = ctor->returnType.type; - List inheritanceDefaultCtorList{}; - for (auto inheritanceMember : structDecl->getMembersOfType()) - { - auto declRefType = as(inheritanceMember->base.type); - if (!declRefType) - continue; - auto structOfInheritance = as(declRefType->getDeclRef().getDecl()); - if (!structOfInheritance) - continue; - inheritanceDefaultCtorList.add(DeclAndCtorInfo(m_astBuilder, this, structOfInheritance, true)); - } - DeclAndCtorInfo structDeclInfo = DeclAndCtorInfo(m_astBuilder, this, structDecl, false); + // Initialize base type by using its default constructor if it has one. + synthesizeCtorBodyForBases(ctor, inheritanceDefaultCtorList, thisExpr, seqStmtChild); - // ensure all varDecl members are processed up to SemanticsBodyVisitor so we can be sure that if init expressions - // of members are to be synthisised, they are. - bool isDefaultInitializableType = isSubtype(DeclRefType::create(m_astBuilder, structDecl), m_astBuilder->getDefaultInitializableType(), IsSubTypeOptions::None); - for (auto m : structDecl->members) + // Initialize member variables by using their default value if they have one + // e.g. this->member = default_value + for (auto& m : structDecl->members) { - auto varDeclBase = as(m); - if (!varDeclBase) - continue; - ensureDecl(m->getDefaultDeclRef(), DeclCheckState::DefaultConstructorReadyForUse); - if (!isDefaultInitializableType - || varDeclBase->initExpr) - continue; - varDeclBase->initExpr = constructDefaultInitExprForVar(this, varDeclBase); + synthesizeCtorBodyForMember(ctor, m, thisExpr, cachedDeclToCheckedVar, seqStmtChild); } - synthesizeCtorBody(structDeclInfo, inheritanceDefaultCtorList, structDecl); - - if (structDeclInfo.defaultCtor) + if (seqStmtChild->stmts.getCount() != 0) { - auto seqStmt = as(as(structDeclInfo.defaultCtor->body)->body); - if (seqStmt && seqStmt->stmts.getCount() == 0) - { - structDecl->members.remove(structDeclInfo.defaultCtor); - structDecl->invalidateMemberDictionary(); - structDecl->buildMemberDictionary(); - } + seqStmt->stmts.insert(0, seqStmtChild); } } +} + +void SemanticsDeclBodyVisitor::visitAggTypeDecl(AggTypeDecl* aggTypeDecl) +{ + if (aggTypeDecl->hasTag(TypeTag::Incomplete) && aggTypeDecl->hasModifier()) + { + getSink()->diagnose(aggTypeDecl->loc, Diagnostics::cannotExportIncompleteType, aggTypeDecl); + } - void SemanticsDeclHeaderVisitor::cloneModifiers(Decl* dest, Decl* src) + auto structDecl = as(aggTypeDecl); + if (!structDecl) + return; + + List inheritanceDefaultCtorList{}; + for (auto inheritanceMember : structDecl->getMembersOfType()) { - dest->modifiers = src->modifiers; + auto declRefType = as(inheritanceMember->base.type); + if (!declRefType) + continue; + auto structOfInheritance = as(declRefType->getDeclRef().getDecl()); + if (!structOfInheritance) + continue; + inheritanceDefaultCtorList.add( + DeclAndCtorInfo(m_astBuilder, this, structOfInheritance, true)); } - void SemanticsDeclHeaderVisitor::setFuncTypeIntoRequirementDecl(CallableDecl* decl, FuncType* funcType) + DeclAndCtorInfo structDeclInfo = DeclAndCtorInfo(m_astBuilder, this, structDecl, false); + + // ensure all varDecl members are processed up to SemanticsBodyVisitor so we can be sure that if + // init expressions of members are to be synthisised, they are. + bool isDefaultInitializableType = isSubtype( + DeclRefType::create(m_astBuilder, structDecl), + m_astBuilder->getDefaultInitializableType(), + IsSubTypeOptions::None); + for (auto m : structDecl->members) { - if (!funcType) - return; - decl->returnType.type = funcType->getResultType(); - decl->errorType.type = funcType->getErrorType(); - for (Index i = 0; i < funcType->getParamCount(); i++) - { - auto paramType = funcType->getParamType(i); - if (auto dirType = as(paramType)) - paramType = dirType->getValueType(); - auto param = m_astBuilder->create(); - param->type.type = paramType; - auto paramDir = funcType->getParamDirection(i); - switch (paramDir) - { - case ParameterDirection::kParameterDirection_InOut: - addModifier(param, m_astBuilder->create()); - break; - case ParameterDirection::kParameterDirection_Out: - addModifier(param, m_astBuilder->create()); - break; - case ParameterDirection::kParameterDirection_Ref: - addModifier(param, m_astBuilder->create()); - break; - case ParameterDirection::kParameterDirection_ConstRef: - addModifier(param, m_astBuilder->create()); - break; - default: - break; - } - decl->members.add(param); - param->parentDecl = decl; - } + auto varDeclBase = as(m); + if (!varDeclBase) + continue; + ensureDecl(m->getDefaultDeclRef(), DeclCheckState::DefaultConstructorReadyForUse); + if (!isDefaultInitializableType || varDeclBase->initExpr) + continue; + varDeclBase->initExpr = constructDefaultInitExprForVar(this, varDeclBase); } - void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl) + synthesizeCtorBody(structDeclInfo, inheritanceDefaultCtorList, structDecl); + + if (structDeclInfo.defaultCtor) { - for(auto paramDecl : decl->getParameters()) + auto seqStmt = as(as(structDeclInfo.defaultCtor->body)->body); + if (seqStmt && seqStmt->stmts.getCount() == 0) { - ensureDecl(paramDecl, DeclCheckState::ReadyForReference); + structDecl->members.remove(structDeclInfo.defaultCtor); + structDecl->invalidateMemberDictionary(); + structDecl->buildMemberDictionary(); } + } +} - auto errorType = decl->errorType; - if (errorType.exp) - { - errorType = CheckProperType(errorType); +void SemanticsDeclHeaderVisitor::cloneModifiers(Decl* dest, Decl* src) +{ + dest->modifiers = src->modifiers; +} +void SemanticsDeclHeaderVisitor::setFuncTypeIntoRequirementDecl( + CallableDecl* decl, + FuncType* funcType) +{ + if (!funcType) + return; + decl->returnType.type = funcType->getResultType(); + decl->errorType.type = funcType->getErrorType(); + for (Index i = 0; i < funcType->getParamCount(); i++) + { + auto paramType = funcType->getParamType(i); + if (auto dirType = as(paramType)) + paramType = dirType->getValueType(); + auto param = m_astBuilder->create(); + param->type.type = paramType; + auto paramDir = funcType->getParamDirection(i); + switch (paramDir) + { + case ParameterDirection::kParameterDirection_InOut: + addModifier(param, m_astBuilder->create()); + break; + case ParameterDirection::kParameterDirection_Out: + addModifier(param, m_astBuilder->create()); + break; + case ParameterDirection::kParameterDirection_Ref: + addModifier(param, m_astBuilder->create()); + break; + case ParameterDirection::kParameterDirection_ConstRef: + addModifier(param, m_astBuilder->create()); + break; + default: break; } - else + decl->members.add(param); + param->parentDecl = decl; + } +} + +void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl) +{ + for (auto paramDecl : decl->getParameters()) + { + ensureDecl(paramDecl, DeclCheckState::ReadyForReference); + } + + auto errorType = decl->errorType; + if (errorType.exp) + { + errorType = CheckProperType(errorType); + } + else + { + errorType = TypeExp(m_astBuilder->getBottomType()); + } + decl->errorType = errorType; + + if (auto interfaceDecl = findParentInterfaceDecl(decl)) + { + bool isDiffFunc = false; + if (decl->hasModifier() || + decl->hasModifier()) { - errorType = TypeExp(m_astBuilder->getBottomType()); - } - decl->errorType = errorType; + auto reqDecl = m_astBuilder->create(); + reqDecl->originalRequirementDecl = decl; + cloneModifiers(reqDecl, decl); + auto declRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(decl)) + .as(); + auto diffFuncType = getForwardDiffFuncType(getFuncType(m_astBuilder, declRef)); + setFuncTypeIntoRequirementDecl(reqDecl, as(diffFuncType)); + interfaceDecl->members.add(reqDecl); + reqDecl->parentDecl = interfaceDecl; + + if (!decl->hasModifier()) + { + // Build decl-ref-type from interface. + auto interfaceType = + DeclRefType::create(getASTBuilder(), makeDeclRef(interfaceDecl)); + + // If the interface is differentiable, make the this type a pair. + if (tryGetDifferentialType(getASTBuilder(), interfaceType)) + reqDecl->diffThisType = getDifferentialPairType(interfaceType); + } - if (auto interfaceDecl = findParentInterfaceDecl(decl)) + auto reqRef = m_astBuilder->create(); + reqRef->referencedDecl = reqDecl; + reqRef->parentDecl = decl; + decl->members.add(reqRef); + isDiffFunc = true; + } + if (decl->hasModifier()) { - bool isDiffFunc = false; - if (decl->hasModifier() || decl->hasModifier()) + // Requirement for backward derivative. + auto declRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(decl)) + .as(); + auto originalFuncType = getFuncType(m_astBuilder, declRef); + auto diffFuncType = as(getBackwardDiffFuncType(originalFuncType)); { - auto reqDecl = m_astBuilder->create(); + auto reqDecl = m_astBuilder->create(); reqDecl->originalRequirementDecl = decl; cloneModifiers(reqDecl, decl); - auto declRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(decl)).as(); - auto diffFuncType = getForwardDiffFuncType(getFuncType(m_astBuilder, declRef)); - setFuncTypeIntoRequirementDecl(reqDecl, as(diffFuncType)); + setFuncTypeIntoRequirementDecl(reqDecl, diffFuncType); interfaceDecl->members.add(reqDecl); reqDecl->parentDecl = interfaceDecl; - if (!decl->hasModifier()) { // Build decl-ref-type from interface. - auto interfaceType = DeclRefType::create(getASTBuilder(), makeDeclRef(interfaceDecl)); + auto interfaceType = + DeclRefType::create(getASTBuilder(), makeDeclRef(interfaceDecl)); // If the interface is differentiable, make the this type a pair. if (tryGetDifferentialType(getASTBuilder(), interfaceType)) @@ -8745,3181 +9034,3479 @@ namespace Slang reqRef->referencedDecl = reqDecl; reqRef->parentDecl = decl; decl->members.add(reqRef); - isDiffFunc = true; - } - if (decl->hasModifier()) - { - // Requirement for backward derivative. - auto declRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(decl)).as(); - auto originalFuncType = getFuncType(m_astBuilder, declRef); - auto diffFuncType = as(getBackwardDiffFuncType(originalFuncType)); - { - auto reqDecl = m_astBuilder->create(); - reqDecl->originalRequirementDecl = decl; - cloneModifiers(reqDecl, decl); - setFuncTypeIntoRequirementDecl(reqDecl, diffFuncType); - interfaceDecl->members.add(reqDecl); - reqDecl->parentDecl = interfaceDecl; - if (!decl->hasModifier()) - { - // Build decl-ref-type from interface. - auto interfaceType = DeclRefType::create(getASTBuilder(), makeDeclRef(interfaceDecl)); - - // If the interface is differentiable, make the this type a pair. - if (tryGetDifferentialType(getASTBuilder(), interfaceType)) - reqDecl->diffThisType = getDifferentialPairType(interfaceType); - } - - auto reqRef = m_astBuilder->create(); - reqRef->referencedDecl = reqDecl; - reqRef->parentDecl = decl; - decl->members.add(reqRef); - } - isDiffFunc = true; } - if (isDiffFunc) + isDiffFunc = true; + } + if (isDiffFunc) + { + auto interfaceDeclRef = + createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(interfaceDecl)); + auto interfaceType = DeclRefType::create(m_astBuilder, interfaceDeclRef); + bool noDiffThisRequirement = !isTypeDifferentiable(interfaceType); + if (noDiffThisRequirement) { - auto interfaceDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(interfaceDecl)); - auto interfaceType = DeclRefType::create(m_astBuilder, interfaceDeclRef); - bool noDiffThisRequirement = !isTypeDifferentiable(interfaceType); - if (noDiffThisRequirement) - { - auto noDiffThisModifier = m_astBuilder->create(); - addModifier(decl, noDiffThisModifier); - } + auto noDiffThisModifier = m_astBuilder->create(); + addModifier(decl, noDiffThisModifier); } } - if (decl->findModifier()) + } + if (decl->findModifier()) + { + // Add `no_diff` modifiers to parameters. + // This is necessary to preserve no-diff-ness for generic function before and after + // specialization. + for (auto paramDecl : decl->getParameters()) { - // Add `no_diff` modifiers to parameters. - // This is necessary to preserve no-diff-ness for generic function before and after - // specialization. - for (auto paramDecl : decl->getParameters()) + if (!paramDecl->type.type) + continue; + if (!isTypeDifferentiable(paramDecl->type.type)) { - if (!paramDecl->type.type) - continue; - if (!isTypeDifferentiable(paramDecl->type.type)) - { - if (!paramDecl->hasModifier()) - { - auto noDiffModifier = m_astBuilder->create(); - noDiffModifier->keywordName = getSession()->getNameObj("no_diff"); - addModifier(paramDecl, noDiffModifier); - } - } if (!paramDecl->hasModifier()) { - if (auto modifier = paramDecl->findModifier()) - { - getSink()->diagnose(modifier, Diagnostics::cannotUseConstRefOnDifferentiableParameter); - } + auto noDiffModifier = m_astBuilder->create(); + noDiffModifier->keywordName = getSession()->getNameObj("no_diff"); + addModifier(paramDecl, noDiffModifier); } } - if (!isEffectivelyStatic(decl)) + if (!paramDecl->hasModifier()) { - auto constrefAttr = decl->findModifier(); - auto refAttr = decl->findModifier(); - if (constrefAttr || refAttr) + if (auto modifier = paramDecl->findModifier()) { - if (isTypeDifferentiable(calcThisType(getParentDecl(decl)))) - { - getSink()->diagnose(constrefAttr, Diagnostics::cannotUseConstRefOnDifferentiableMemberMethod); - } + getSink()->diagnose( + modifier, + Diagnostics::cannotUseConstRefOnDifferentiableParameter); } } } - - // If this method is intended to be a CUDA kernel, verify that the return type is void. - if (decl->findModifier()) + if (!isEffectivelyStatic(decl)) { - if (decl->returnType.type && !decl->returnType.type->equals(m_astBuilder->getVoidType())) + auto constrefAttr = decl->findModifier(); + auto refAttr = decl->findModifier(); + if (constrefAttr || refAttr) { - getSink()->diagnose(decl, Diagnostics::cudaKernelMustReturnVoid); + if (isTypeDifferentiable(calcThisType(getParentDecl(decl)))) + { + getSink()->diagnose( + constrefAttr, + Diagnostics::cannotUseConstRefOnDifferentiableMemberMethod); + } } } - - checkVisibility(decl); } - void SemanticsDeclHeaderVisitor::visitFuncDecl(FuncDecl* funcDecl) + // If this method is intended to be a CUDA kernel, verify that the return type is void. + if (decl->findModifier()) { - auto resultType = funcDecl->returnType; - if(resultType.exp) - { - resultType = CheckProperType(resultType); - } - else if (!funcDecl->returnType.type) + if (decl->returnType.type && !decl->returnType.type->equals(m_astBuilder->getVoidType())) { - resultType = TypeExp(m_astBuilder->getVoidType()); + getSink()->diagnose(decl, Diagnostics::cudaKernelMustReturnVoid); } - funcDecl->returnType = resultType; - - checkCallableDeclCommon(funcDecl); } - IntegerLiteralValue SemanticsVisitor::GetMinBound(IntVal* val) - { - if (auto constantVal = as(val)) - return constantVal->getValue(); - - // TODO(tfoley): Need to track intervals so that this isn't just a lie... - return 1; - } + checkVisibility(decl); +} - void SemanticsVisitor::maybeInferArraySizeForVariable(VarDeclBase* varDecl) +void SemanticsDeclHeaderVisitor::visitFuncDecl(FuncDecl* funcDecl) +{ + auto resultType = funcDecl->returnType; + if (resultType.exp) { - // Not an array? - auto arrayType = as(varDecl->type); - if (!arrayType) return; - - // Explicit element count given? - if (!isUnsizedArrayType(arrayType)) - return; - - // No initializer? - auto initExpr = varDecl->initExpr; - if(!initExpr) return; - - IntVal* elementCount = nullptr; - - // Is the type of the initializer an array type? - if(auto arrayInitType = as(initExpr->type)) - { - elementCount = arrayInitType->getElementCount(); - } - else - { - // Nothing to do: we couldn't infer a size - return; - } - - // Create a new array type based on the size we found, - // and install it into our type. - varDecl->type.type = getArrayType( - m_astBuilder, - arrayType->getElementType(), - elementCount); + resultType = CheckProperType(resultType); } - - void SemanticsVisitor::validateArraySizeForVariable(VarDeclBase* varDecl) + else if (!funcDecl->returnType.type) { - auto arrayType = as(varDecl->type); - if (!arrayType) return; - - if (arrayType->isUnsized()) - { - // Note(tfoley): For now we allow arrays of unspecified size - // everywhere, because some source languages (e.g., GLSL) - // allow them in specific cases. -#if 0 - getSink()->diagnose(varDecl, Diagnostics::invalidArraySize); -#endif - return; - } - - // TODO(tfoley): How to handle the case where bound isn't known? - auto elementCount = arrayType->getElementCount(); - if (GetMinBound(elementCount) <= 0) - { - getSink()->diagnose(varDecl, Diagnostics::invalidArraySize); - return; - } + resultType = TypeExp(m_astBuilder->getVoidType()); } + funcDecl->returnType = resultType; - void SemanticsDeclBasesVisitor::_validateExtensionDeclTargetType(ExtensionDecl* decl) - { - if (auto targetDeclRefType = as(decl->targetType)) - { - // Attach our extension to that type as a candidate... - if (targetDeclRefType->getDeclRef().as()) - { - getSink()->diagnose(decl->targetType.exp, Diagnostics::invalidExtensionOnInterface, decl->targetType); - return; - } - else if (auto aggTypeDeclRef = targetDeclRefType->getDeclRef().as()) - { - auto aggTypeDecl = aggTypeDeclRef.getDecl(); - - getShared()->registerCandidateExtension(aggTypeDecl, decl); - - return; - } - else if (auto genericTypeParamDecl = targetDeclRefType->getDeclRef().as()) - { - // If we are extending a generic type parameter as in `extension T`, - // we want to register the extension with the interface type `IFoo` instead. - auto genericDecl = as(genericTypeParamDecl.getDecl()->parentDecl); - if (!genericDecl) - goto error; - if (genericDecl != decl->parentDecl) - goto error; - bool isTypeConstrained = false; - for (auto constraintDecl : genericDecl->getMembersOfType()) - { - ensureDecl(constraintDecl, DeclCheckState::ReadyForReference); - if (targetDeclRefType == constraintDecl->sub.type) - { - auto supTypeDeclRef = isDeclRefTypeOf(constraintDecl->sup.type); - getShared()->registerCandidateExtension(supTypeDeclRef.getDecl(), decl); - isTypeConstrained = true; - } - } - if (isTypeConstrained) - return; - } - } - error:; - if (!as(decl->targetType.type)) - { - getSink()->diagnose(decl->targetType.exp, Diagnostics::invalidExtensionOnType, decl->targetType); - } - } + checkCallableDeclCommon(funcDecl); +} - void SemanticsDeclBasesVisitor::_validateExtensionDeclMembers(ExtensionDecl* decl) - { - for (auto m : decl->members) - { - auto ctor = as(m); - if (!ctor || !ctor->body || ctor->members.getCount() != 0) - continue; - getSink()->diagnose(m->loc, Diagnostics::invalidMemberTypeInExtension, m->astNodeType); - } - } +IntegerLiteralValue SemanticsVisitor::GetMinBound(IntVal* val) +{ + if (auto constantVal = as(val)) + return constantVal->getValue(); - void SemanticsDeclBasesVisitor::visitExtensionDecl(ExtensionDecl* decl) - { - // We check the target type expression and members, and then validate - // that the type it names is one that it makes sense - // to extend. - // - decl->targetType = CheckProperType(decl->targetType); + // TODO(tfoley): Need to track intervals so that this isn't just a lie... + return 1; +} - _validateExtensionDeclTargetType(decl); +void SemanticsVisitor::maybeInferArraySizeForVariable(VarDeclBase* varDecl) +{ + // Not an array? + auto arrayType = as(varDecl->type); + if (!arrayType) + return; - _validateExtensionDeclMembers(decl); + // Explicit element count given? + if (!isUnsizedArrayType(arrayType)) + return; - for( auto inheritanceDecl : decl->getMembersOfType() ) - { - ensureDecl(inheritanceDecl, DeclCheckState::CanUseBaseOfInheritanceDecl); - auto baseType = inheritanceDecl->base.type; + // No initializer? + auto initExpr = varDecl->initExpr; + if (!initExpr) + return; - // It is possible that there was an error in checking the base type - // expression, and in such a case we shouldn't emit a cascading error. - // - if( const auto baseErrorType = as(baseType) ) - { - continue; - } + IntVal* elementCount = nullptr; - // An `extension` can only introduce inheritance from `interface` types. - // - // TODO: It might in theory make sense to allow an `extension` to - // introduce a non-`interface` base if we decide that an `extension` - // within the same module as the type it extends counts as just - // a continuation of the type's body (like a `partial class` in C#). - // - auto baseDeclRefType = as(baseType); - if( !baseDeclRefType ) - { - getSink()->diagnose(inheritanceDecl, Diagnostics::baseOfExtensionMustBeInterface, decl, baseType); - continue; - } + // Is the type of the initializer an array type? + if (auto arrayInitType = as(initExpr->type)) + { + elementCount = arrayInitType->getElementCount(); + } + else + { + // Nothing to do: we couldn't infer a size + return; + } - auto baseDeclRef = baseDeclRefType->getDeclRef(); - auto baseInterfaceDeclRef = baseDeclRef.as(); - if( !baseInterfaceDeclRef ) - { - getSink()->diagnose(inheritanceDecl, Diagnostics::baseOfExtensionMustBeInterface, decl, baseType); - continue; - } + // Create a new array type based on the size we found, + // and install it into our type. + varDecl->type.type = getArrayType(m_astBuilder, arrayType->getElementType(), elementCount); +} - // TODO: At this point we have the `baseInterfaceDeclRef` - // and could use it to perform further validity checks, - // and/or to build up a more refined representation of - // the inheritance graph for this extension (e.g., a "class - // precedence list"). - // - // E.g., we can/should check that we aren't introducing - // an inheritance relationship that already existed - // on the type as originally declared. +void SemanticsVisitor::validateArraySizeForVariable(VarDeclBase* varDecl) +{ + auto arrayType = as(varDecl->type); + if (!arrayType) + return; - _validateCrossModuleInheritance(decl, inheritanceDecl); - } + if (arrayType->isUnsized()) + { + // Note(tfoley): For now we allow arrays of unspecified size + // everywhere, because some source languages (e.g., GLSL) + // allow them in specific cases. +#if 0 + getSink()->diagnose(varDecl, Diagnostics::invalidArraySize); +#endif + return; } - Type* SemanticsVisitor::calcThisType(DeclRef declRef) + // TODO(tfoley): How to handle the case where bound isn't known? + auto elementCount = arrayType->getElementCount(); + if (GetMinBound(elementCount) <= 0) { - if( auto interfaceDeclRef = declRef.as() ) - { - // In the body of an `interface`, a `This` type - // refers to the concrete type that will eventually - // conform to the interface and fill in its - // requirements. - // - return DeclRefType::create( - m_astBuilder, - m_astBuilder->getDirectDeclRef(interfaceDeclRef.getDecl()->getThisTypeDecl())); - } - else if (auto aggTypeDeclRef = declRef.as()) - { - // In the body of an ordinary aggregate type, - // such as a `struct`, the `This` type just - // refers to the type itself. - // - // TODO: If/when we support `class` types - // with inheritance, then `This` inside a class - // would need to refer to the eventual concrete - // type, much like the `interface` case above. - // - return DeclRefType::create(m_astBuilder, aggTypeDeclRef); - } - else if (auto genTypeParam = declRef.as()) + getSink()->diagnose(varDecl, Diagnostics::invalidArraySize); + return; + } +} + +void SemanticsDeclBasesVisitor::_validateExtensionDeclTargetType(ExtensionDecl* decl) +{ + if (auto targetDeclRefType = as(decl->targetType)) + { + // Attach our extension to that type as a candidate... + if (targetDeclRefType->getDeclRef().as()) { - // We will reach here when we are checking `extension T {...}`, - // where inside the extension, `This` type is the target type - // of the extension, in this case this is a DeclRefType to - // a GenericTypeParamDecl. - // - return DeclRefType::create(m_astBuilder, declRef); + getSink()->diagnose( + decl->targetType.exp, + Diagnostics::invalidExtensionOnInterface, + decl->targetType); + return; } - else if (auto extDeclRef = declRef.as()) + else if (auto aggTypeDeclRef = targetDeclRefType->getDeclRef().as()) { - // In the body of an `extension`, the `This` - // type refers to the type being extended. - // - // Note: we currently have this loop back - // around through `calcThisType` for the - // type being extended, rather than just - // using it directly. This makes a difference - // for polymorphic types like `interface`s, - // and there are reasonable arguments for - // the validity of either option. - // - // Does `extension IFoo` mean extending - // exactly the type `IFoo` (an existential, - // which could at runtime be a value of - // any type conforming to `IFoo`), or does - // it implicitly extend every type that - // conforms to `IFoo`? The difference is - // significant, and we need to make a choice - // sooner or later. - // - ensureDecl(extDeclRef, DeclCheckState::CanUseExtensionTargetType); - auto targetType = getTargetType(m_astBuilder, extDeclRef); - return calcThisType(targetType); + auto aggTypeDecl = aggTypeDeclRef.getDecl(); + + getShared()->registerCandidateExtension(aggTypeDecl, decl); + + return; } - else + else if ( + auto genericTypeParamDecl = targetDeclRefType->getDeclRef().as()) { - return nullptr; + // If we are extending a generic type parameter as in `extension T`, + // we want to register the extension with the interface type `IFoo` instead. + auto genericDecl = as(genericTypeParamDecl.getDecl()->parentDecl); + if (!genericDecl) + goto error; + if (genericDecl != decl->parentDecl) + goto error; + bool isTypeConstrained = false; + for (auto constraintDecl : genericDecl->getMembersOfType()) + { + ensureDecl(constraintDecl, DeclCheckState::ReadyForReference); + if (targetDeclRefType == constraintDecl->sub.type) + { + auto supTypeDeclRef = isDeclRefTypeOf(constraintDecl->sup.type); + getShared()->registerCandidateExtension(supTypeDeclRef.getDecl(), decl); + isTypeConstrained = true; + } + } + if (isTypeConstrained) + return; } } +error:; + if (!as(decl->targetType.type)) + { + getSink()->diagnose( + decl->targetType.exp, + Diagnostics::invalidExtensionOnType, + decl->targetType); + } +} - Type* SemanticsVisitor::calcThisType(Type* type) +void SemanticsDeclBasesVisitor::_validateExtensionDeclMembers(ExtensionDecl* decl) +{ + for (auto m : decl->members) { - if( auto declRefType = as(type) ) - { - return calcThisType(declRefType->getDeclRef()); - } - else - { - return type; - } + auto ctor = as(m); + if (!ctor || !ctor->body || ctor->members.getCount() != 0) + continue; + getSink()->diagnose(m->loc, Diagnostics::invalidMemberTypeInExtension, m->astNodeType); } +} - Type* SemanticsVisitor::findResultTypeForConstructorDecl(ConstructorDecl* decl) +void SemanticsDeclBasesVisitor::visitExtensionDecl(ExtensionDecl* decl) +{ + // We check the target type expression and members, and then validate + // that the type it names is one that it makes sense + // to extend. + // + decl->targetType = CheckProperType(decl->targetType); + + _validateExtensionDeclTargetType(decl); + + _validateExtensionDeclMembers(decl); + + for (auto inheritanceDecl : decl->getMembersOfType()) { - // We want to look at the parent of the declaration, - // but if the declaration is generic, the parent will be - // the `GenericDecl` and we need to skip past that to - // the grandparent. + ensureDecl(inheritanceDecl, DeclCheckState::CanUseBaseOfInheritanceDecl); + auto baseType = inheritanceDecl->base.type; + + // It is possible that there was an error in checking the base type + // expression, and in such a case we shouldn't emit a cascading error. // - auto parent = decl->parentDecl; - auto genericParent = as(parent); - if (genericParent) + if (const auto baseErrorType = as(baseType)) { - parent = genericParent->parentDecl; + continue; } - // The result type for a constructor is whatever `This` would - // refer to in the body of the outer declaration. + // An `extension` can only introduce inheritance from `interface` types. // - auto thisType = calcThisType(makeDeclRef(parent)); - if( !thisType ) + // TODO: It might in theory make sense to allow an `extension` to + // introduce a non-`interface` base if we decide that an `extension` + // within the same module as the type it extends counts as just + // a continuation of the type's body (like a `partial class` in C#). + // + auto baseDeclRefType = as(baseType); + if (!baseDeclRefType) { - getSink()->diagnose(decl, Diagnostics::initializerNotInsideType); - thisType = m_astBuilder->getErrorType(); + getSink()->diagnose( + inheritanceDecl, + Diagnostics::baseOfExtensionMustBeInterface, + decl, + baseType); + continue; } - return thisType; - } - void SemanticsDeclHeaderVisitor::visitConstructorDecl(ConstructorDecl* decl) - { - // We need to compute the result tyep for this declaration, - // since it wasn't filled in for us. - decl->returnType.type = findResultTypeForConstructorDecl(decl); + auto baseDeclRef = baseDeclRefType->getDeclRef(); + auto baseInterfaceDeclRef = baseDeclRef.as(); + if (!baseInterfaceDeclRef) + { + getSink()->diagnose( + inheritanceDecl, + Diagnostics::baseOfExtensionMustBeInterface, + decl, + baseType); + continue; + } - checkCallableDeclCommon(decl); + // TODO: At this point we have the `baseInterfaceDeclRef` + // and could use it to perform further validity checks, + // and/or to build up a more refined representation of + // the inheritance graph for this extension (e.g., a "class + // precedence list"). + // + // E.g., we can/should check that we aren't introducing + // an inheritance relationship that already existed + // on the type as originally declared. + + _validateCrossModuleInheritance(decl, inheritanceDecl); } +} - void SemanticsDeclHeaderVisitor::visitAbstractStorageDeclCommon(ContainerDecl* decl) +Type* SemanticsVisitor::calcThisType(DeclRef declRef) +{ + if (auto interfaceDeclRef = declRef.as()) + { + // In the body of an `interface`, a `This` type + // refers to the concrete type that will eventually + // conform to the interface and fill in its + // requirements. + // + return DeclRefType::create( + m_astBuilder, + m_astBuilder->getDirectDeclRef(interfaceDeclRef.getDecl()->getThisTypeDecl())); + } + else if (auto aggTypeDeclRef = declRef.as()) { - // If we have a subscript or property declaration with no accessor declarations, - // then we should create a single `GetterDecl` to represent - // the implicit meaning of their declaration, so: + // In the body of an ordinary aggregate type, + // such as a `struct`, the `This` type just + // refers to the type itself. // - // subscript(uint index) -> T; - // property x : Y; + // TODO: If/when we support `class` types + // with inheritance, then `This` inside a class + // would need to refer to the eventual concrete + // type, much like the `interface` case above. // - // becomes: + return DeclRefType::create(m_astBuilder, aggTypeDeclRef); + } + else if (auto genTypeParam = declRef.as()) + { + // We will reach here when we are checking `extension T {...}`, + // where inside the extension, `This` type is the target type + // of the extension, in this case this is a DeclRefType to + // a GenericTypeParamDecl. + // + return DeclRefType::create(m_astBuilder, declRef); + } + else if (auto extDeclRef = declRef.as()) + { + // In the body of an `extension`, the `This` + // type refers to the type being extended. // - // subscript(uint index) -> T { get; } - // property x : Y { get; } + // Note: we currently have this loop back + // around through `calcThisType` for the + // type being extended, rather than just + // using it directly. This makes a difference + // for polymorphic types like `interface`s, + // and there are reasonable arguments for + // the validity of either option. // + // Does `extension IFoo` mean extending + // exactly the type `IFoo` (an existential, + // which could at runtime be a value of + // any type conforming to `IFoo`), or does + // it implicitly extend every type that + // conforms to `IFoo`? The difference is + // significant, and we need to make a choice + // sooner or later. + // + ensureDecl(extDeclRef, DeclCheckState::CanUseExtensionTargetType); + auto targetType = getTargetType(m_astBuilder, extDeclRef); + return calcThisType(targetType); + } + else + { + return nullptr; + } +} - bool anyAccessors = decl->getMembersOfType().isNonEmpty(); - - if(!anyAccessors) - { - GetterDecl* getterDecl = m_astBuilder->create(); - getterDecl->loc = decl->loc; +Type* SemanticsVisitor::calcThisType(Type* type) +{ + if (auto declRefType = as(type)) + { + return calcThisType(declRefType->getDeclRef()); + } + else + { + return type; + } +} - getterDecl->parentDecl = decl; - decl->members.add(getterDecl); - } +Type* SemanticsVisitor::findResultTypeForConstructorDecl(ConstructorDecl* decl) +{ + // We want to look at the parent of the declaration, + // but if the declaration is generic, the parent will be + // the `GenericDecl` and we need to skip past that to + // the grandparent. + // + auto parent = decl->parentDecl; + auto genericParent = as(parent); + if (genericParent) + { + parent = genericParent->parentDecl; } - void SemanticsDeclHeaderVisitor::visitSubscriptDecl(SubscriptDecl* decl) + // The result type for a constructor is whatever `This` would + // refer to in the body of the outer declaration. + // + auto thisType = calcThisType(makeDeclRef(parent)); + if (!thisType) { - decl->returnType = CheckUsableType(decl->returnType, decl); + getSink()->diagnose(decl, Diagnostics::initializerNotInsideType); + thisType = m_astBuilder->getErrorType(); + } + return thisType; +} - visitAbstractStorageDeclCommon(decl); +void SemanticsDeclHeaderVisitor::visitConstructorDecl(ConstructorDecl* decl) +{ + // We need to compute the result tyep for this declaration, + // since it wasn't filled in for us. + decl->returnType.type = findResultTypeForConstructorDecl(decl); - checkCallableDeclCommon(decl); - } + checkCallableDeclCommon(decl); +} + +void SemanticsDeclHeaderVisitor::visitAbstractStorageDeclCommon(ContainerDecl* decl) +{ + // If we have a subscript or property declaration with no accessor declarations, + // then we should create a single `GetterDecl` to represent + // the implicit meaning of their declaration, so: + // + // subscript(uint index) -> T; + // property x : Y; + // + // becomes: + // + // subscript(uint index) -> T { get; } + // property x : Y { get; } + // + + bool anyAccessors = decl->getMembersOfType().isNonEmpty(); - void SemanticsDeclHeaderVisitor::visitPropertyDecl(PropertyDecl* decl) + if (!anyAccessors) { - SemanticsVisitor subVisitor(withDeclToExcludeFromLookup(decl)); - decl->type = subVisitor.CheckUsableType(decl->type, decl); - visitAbstractStorageDeclCommon(decl); - checkVisibility(decl); + GetterDecl* getterDecl = m_astBuilder->create(); + getterDecl->loc = decl->loc; + + getterDecl->parentDecl = decl; + decl->members.add(getterDecl); } +} - Type* SemanticsDeclHeaderVisitor::_getAccessorStorageType(AccessorDecl* decl) +void SemanticsDeclHeaderVisitor::visitSubscriptDecl(SubscriptDecl* decl) +{ + decl->returnType = CheckUsableType(decl->returnType, decl); + + visitAbstractStorageDeclCommon(decl); + + checkCallableDeclCommon(decl); +} + +void SemanticsDeclHeaderVisitor::visitPropertyDecl(PropertyDecl* decl) +{ + SemanticsVisitor subVisitor(withDeclToExcludeFromLookup(decl)); + decl->type = subVisitor.CheckUsableType(decl->type, decl); + visitAbstractStorageDeclCommon(decl); + checkVisibility(decl); +} + +Type* SemanticsDeclHeaderVisitor::_getAccessorStorageType(AccessorDecl* decl) +{ + auto parentDecl = decl->parentDecl; + if (auto parentSubscript = as(parentDecl)) { - auto parentDecl = decl->parentDecl; - if (auto parentSubscript = as(parentDecl)) - { - ensureDecl(parentSubscript, DeclCheckState::CanUseTypeOfValueDecl); - return parentSubscript->returnType; - } - else if (auto parentProperty = as(parentDecl)) - { - ensureDecl(parentProperty, DeclCheckState::CanUseTypeOfValueDecl); - return parentProperty->type.type; - } - else - { - return getASTBuilder()->getErrorType(); - } + ensureDecl(parentSubscript, DeclCheckState::CanUseTypeOfValueDecl); + return parentSubscript->returnType; } - - void SemanticsDeclHeaderVisitor::_visitAccessorDeclCommon(AccessorDecl* decl) + else if (auto parentProperty = as(parentDecl)) { - // An accessor must appear nested inside a subscript or property declaration. - // - auto parentDecl = decl->parentDecl; - if (as(parentDecl)) - {} - else if (as(parentDecl)) - {} - else - { - getSink()->diagnose(decl, Diagnostics::accessorMustBeInsideSubscriptOrProperty); - } + ensureDecl(parentProperty, DeclCheckState::CanUseTypeOfValueDecl); + return parentProperty->type.type; + } + else + { + return getASTBuilder()->getErrorType(); } +} - void SemanticsDeclHeaderVisitor::visitAccessorDecl(AccessorDecl* decl) +void SemanticsDeclHeaderVisitor::_visitAccessorDeclCommon(AccessorDecl* decl) +{ + // An accessor must appear nested inside a subscript or property declaration. + // + auto parentDecl = decl->parentDecl; + if (as(parentDecl)) + { + } + else if (as(parentDecl)) + { + } + else { - _visitAccessorDeclCommon(decl); + getSink()->diagnose(decl, Diagnostics::accessorMustBeInsideSubscriptOrProperty); + } +} - // Note: This subroutine is used by both `get` - // and `ref` accessors, but is bypassed by - // `set` accessors (which use `visitSetterDecl` - // intead). +void SemanticsDeclHeaderVisitor::visitAccessorDecl(AccessorDecl* decl) +{ + _visitAccessorDeclCommon(decl); + + // Note: This subroutine is used by both `get` + // and `ref` accessors, but is bypassed by + // `set` accessors (which use `visitSetterDecl` + // intead). + + // Accessors (other than setters) don't support + // parameters. + // + if (decl->getParameters().getCount() != 0) + { + getSink()->diagnose(decl, Diagnostics::nonSetAccessorMustNotHaveParams); + } + + // By default, the return type of an accessor is treated as + // the type of the abstract storage location being accessed. + // + // A `ref` accessor currently relies on this logic even though + // it isn't quite correct, because we don't have support + // for by-reference return values today. This is a non-issue + // for now because we don't support user-defined `ref` + // accessors yet. + // + // TODO: Once we can support the by-reference return value + // correctly *or* we can move to something like a coroutine-based + // `modify` accessor (a la Swift), we should split out + // handling of `RefAccessorDecl` and only use this routine + // for `GetterDecl`s. + // + decl->returnType.type = _getAccessorStorageType(decl); +} - // Accessors (other than setters) don't support - // parameters. - // - if( decl->getParameters().getCount() != 0 ) - { - getSink()->diagnose(decl, Diagnostics::nonSetAccessorMustNotHaveParams); - } +void SemanticsDeclHeaderVisitor::visitSetterDecl(SetterDecl* decl) +{ + // Make sure to invoke the common checking logic for all accessors. + _visitAccessorDeclCommon(decl); - // By default, the return type of an accessor is treated as - // the type of the abstract storage location being accessed. - // - // A `ref` accessor currently relies on this logic even though - // it isn't quite correct, because we don't have support - // for by-reference return values today. This is a non-issue - // for now because we don't support user-defined `ref` - // accessors yet. - // - // TODO: Once we can support the by-reference return value - // correctly *or* we can move to something like a coroutine-based - // `modify` accessor (a la Swift), we should split out - // handling of `RefAccessorDecl` and only use this routine - // for `GetterDecl`s. - // - decl->returnType.type = _getAccessorStorageType(decl); - } + // A `set` accessor always returns `void`. + // + decl->returnType.type = getASTBuilder()->getVoidType(); - void SemanticsDeclHeaderVisitor::visitSetterDecl(SetterDecl* decl) + // A setter always receives a single value representing + // the new value to set into the storage. + // + // The user may declare that parameter explicitly and + // thereby control its name, or they can declare no + // parmaeters and allow the compiler to synthesize one + // names `newValue`. + // + ParamDecl* newValueParam = nullptr; + auto params = decl->getParameters(); + if (params.getCount() >= 1) { - // Make sure to invoke the common checking logic for all accessors. - _visitAccessorDeclCommon(decl); - - // A `set` accessor always returns `void`. + // If the user declared an explicit parameter + // then that is the one that will represent + // the new value. // - decl->returnType.type = getASTBuilder()->getVoidType(); + newValueParam = params.getFirst(); - // A setter always receives a single value representing - // the new value to set into the storage. - // - // The user may declare that parameter explicitly and - // thereby control its name, or they can declare no - // parmaeters and allow the compiler to synthesize one - // names `newValue`. - // - ParamDecl* newValueParam = nullptr; - auto params = decl->getParameters(); - if( params.getCount() >= 1 ) + if (params.getCount() > 1) { - // If the user declared an explicit parameter - // then that is the one that will represent - // the new value. + // If the user declared more than one explicit + // parameter, then that is an error. // - newValueParam = params.getFirst(); - - if( params.getCount() > 1 ) - { - // If the user declared more than one explicit - // parameter, then that is an error. - // - getSink()->diagnose(params[1], Diagnostics::setAccessorMayNotHaveMoreThanOneParam); - } + getSink()->diagnose(params[1], Diagnostics::setAccessorMayNotHaveMoreThanOneParam); } - else - { - // If the user didn't declare any explicit parameters, - // then we create an implicit one and add it into - // the AST. - // - newValueParam = m_astBuilder->create(); - newValueParam->nameAndLoc.name = getName("newValue"); - newValueParam->nameAndLoc.loc = decl->loc; + } + else + { + // If the user didn't declare any explicit parameters, + // then we create an implicit one and add it into + // the AST. + // + newValueParam = m_astBuilder->create(); + newValueParam->nameAndLoc.name = getName("newValue"); + newValueParam->nameAndLoc.loc = decl->loc; - newValueParam->parentDecl = decl; - decl->members.add(newValueParam); - } + newValueParam->parentDecl = decl; + decl->members.add(newValueParam); + } - // The new-value parameter is expected to have the - // same type as the abstract storage that the - // accessor is setting. - // - auto newValueType = _getAccessorStorageType(decl); + // The new-value parameter is expected to have the + // same type as the abstract storage that the + // accessor is setting. + // + auto newValueType = _getAccessorStorageType(decl); - // It is allowed and encouraged for the programmer - // to leave off the type on the new-value parameter, - // in which case we will set it to the expected - // type automatically. + // It is allowed and encouraged for the programmer + // to leave off the type on the new-value parameter, + // in which case we will set it to the expected + // type automatically. + // + if (!newValueParam->type.exp) + { + newValueParam->type.type = newValueType; + } + else + { + // If the user *did* give the new-value parameter + // an explicit type, then we need to check it + // and then enforce that it matches what we expect. // - if( !newValueParam->type.exp ) + auto actualType = CheckProperType(newValueParam->type); + + if (as(actualType)) + { + } + else if (actualType->equals(newValueType)) { - newValueParam->type.type = newValueType; } else { - // If the user *did* give the new-value parameter - // an explicit type, then we need to check it - // and then enforce that it matches what we expect. - // - auto actualType = CheckProperType(newValueParam->type); - - if(as(actualType)) - {} - else if(actualType->equals(newValueType)) - {} - else - { - getSink()->diagnose(newValueParam, Diagnostics::setAccessorParamWrongType, newValueParam, actualType, newValueType); - } + getSink()->diagnose( + newValueParam, + Diagnostics::setAccessorParamWrongType, + newValueParam, + actualType, + newValueType); } } +} - GenericDecl* SemanticsVisitor::GetOuterGeneric(Decl* decl) - { - auto parentDecl = decl->parentDecl; - if (!parentDecl) return nullptr; - auto parentGeneric = as(parentDecl); - return parentGeneric; - } +GenericDecl* SemanticsVisitor::GetOuterGeneric(Decl* decl) +{ + auto parentDecl = decl->parentDecl; + if (!parentDecl) + return nullptr; + auto parentGeneric = as(parentDecl); + return parentGeneric; +} + +Decl* SemanticsVisitor::getOuterGenericOrSelf(Decl* decl) +{ + auto parentDecl = decl->parentDecl; + if (!parentDecl) + return decl; + auto parentGeneric = as(parentDecl); + if (!parentGeneric) + return decl; + return parentGeneric; +} - Decl* SemanticsVisitor::getOuterGenericOrSelf(Decl* decl) +GenericDecl* SemanticsVisitor::findNextOuterGeneric(Decl* decl) +{ + for (auto p = decl->parentDecl; p; p = p->parentDecl) { - auto parentDecl = decl->parentDecl; - if (!parentDecl) return decl; - auto parentGeneric = as(parentDecl); - if (!parentGeneric) return decl; - return parentGeneric; + if (auto genDecl = as(p)) + return genDecl; } + return nullptr; +} - GenericDecl* SemanticsVisitor::findNextOuterGeneric(Decl* decl) +DeclRef SemanticsVisitor::applyExtensionToType( + ExtensionDecl* extDecl, + Type* type, + Dictionary* additionalSubtypeWitnessesForType) +{ + DeclRef extDeclRef = makeDeclRef(extDecl); + + // If the extension is a generic extension, then we + // need to infer type arguments that will give + // us a target type that matches `type`. + // + if (auto extGenericDecl = GetOuterGeneric(extDecl)) { - for (auto p = decl->parentDecl; p; p = p->parentDecl) + ConstraintSystem constraints; + constraints.loc = extDecl->loc; + constraints.genericDecl = extGenericDecl; + if (additionalSubtypeWitnessesForType) { - if (auto genDecl = as(p)) - return genDecl; + constraints.subTypeForAdditionalWitnesses = type; + constraints.additionalSubtypeWitnesses = additionalSubtypeWitnessesForType; } - return nullptr; - } - DeclRef SemanticsVisitor::applyExtensionToType( - ExtensionDecl* extDecl, - Type* type, - Dictionary* additionalSubtypeWitnessesForType) - { - DeclRef extDeclRef = makeDeclRef(extDecl); - - // If the extension is a generic extension, then we - // need to infer type arguments that will give - // us a target type that matches `type`. - // - if (auto extGenericDecl = GetOuterGeneric(extDecl)) + // Inside the body of an extension declaration, we may end up trying to apply that + // extension to its own target type. + // If we see that we are in that case, we can apply the extension declaration as - is, + // without any additional substitutions. + if (extDecl->targetType->equals(type)) { - ConstraintSystem constraints; - constraints.loc = extDecl->loc; - constraints.genericDecl = extGenericDecl; - if (additionalSubtypeWitnessesForType) - { - constraints.subTypeForAdditionalWitnesses = type; - constraints.additionalSubtypeWitnesses = additionalSubtypeWitnessesForType; - } - - // Inside the body of an extension declaration, we may end up trying to apply that - // extension to its own target type. - // If we see that we are in that case, we can apply the extension declaration as - is, - // without any additional substitutions. - if (extDecl->targetType->equals(type)) - { - return createDefaultSubstitutionsIfNeeded(m_astBuilder, this, extDeclRef).as(); - } - - if (!TryUnifyTypes(constraints, ValUnificationContext(), extDecl->targetType.Ptr(), type)) - return DeclRef(); + return createDefaultSubstitutionsIfNeeded(m_astBuilder, this, extDeclRef) + .as(); + } - ConversionCost baseCost; - auto solvedDeclRef = trySolveConstraintSystem(&constraints, makeDeclRef(extGenericDecl), ArrayView(), baseCost); - if (!solvedDeclRef) - { - return DeclRef(); - } + if (!TryUnifyTypes(constraints, ValUnificationContext(), extDecl->targetType.Ptr(), type)) + return DeclRef(); - // Construct a reference to the extension with our constraint variables - // set as they were found by solving the constraint system. - extDeclRef = solvedDeclRef.as(); + ConversionCost baseCost; + auto solvedDeclRef = trySolveConstraintSystem( + &constraints, + makeDeclRef(extGenericDecl), + ArrayView(), + baseCost); + if (!solvedDeclRef) + { + return DeclRef(); } - // Now extract the target type from our (possibly specialized) extension decl-ref. - Type* targetType = getTargetType(m_astBuilder, extDeclRef); + // Construct a reference to the extension with our constraint variables + // set as they were found by solving the constraint system. + extDeclRef = solvedDeclRef.as(); + } + + // Now extract the target type from our (possibly specialized) extension decl-ref. + Type* targetType = getTargetType(m_astBuilder, extDeclRef); - // As a bit of a kludge here, if the target type of the extension is - // an interface, and the `type` we are trying to match up has a this-type - // substitution for that interface, then we want to attach a matching - // substitution to the extension decl-ref. - if(auto targetDeclRefType = as(targetType)) + // As a bit of a kludge here, if the target type of the extension is + // an interface, and the `type` we are trying to match up has a this-type + // substitution for that interface, then we want to attach a matching + // substitution to the extension decl-ref. + if (auto targetDeclRefType = as(targetType)) + { + if (auto targetInterfaceDeclRef = targetDeclRefType->getDeclRef().as()) { - if(auto targetInterfaceDeclRef = targetDeclRefType->getDeclRef().as()) + // Okay, the target type is an interface. + // + // Is the type we want to apply to a ThisType? + if (auto appDeclRefType = as(type)) { - // Okay, the target type is an interface. - // - // Is the type we want to apply to a ThisType? - if(auto appDeclRefType = as(type)) + if (auto thisTypeLookupDeclRef = + SubstitutionSet(appDeclRefType->getDeclRef()).findLookupDeclRef()) { - if(auto thisTypeLookupDeclRef = SubstitutionSet(appDeclRefType->getDeclRef()).findLookupDeclRef()) + if (thisTypeLookupDeclRef->getDecl() == targetInterfaceDeclRef.getDecl()) { - if(thisTypeLookupDeclRef->getDecl() == targetInterfaceDeclRef.getDecl()) - { - // Looks like we have a match in the types, - // now let's see if `type`'s declref starts with a Lookup. - targetType = type; - extDeclRef = m_astBuilder->getLookupDeclRef( - thisTypeLookupDeclRef->getWitness(), extDeclRef.getDecl()) - .as(); - } + // Looks like we have a match in the types, + // now let's see if `type`'s declref starts with a Lookup. + targetType = type; + extDeclRef = m_astBuilder + ->getLookupDeclRef( + thisTypeLookupDeclRef->getWitness(), + extDeclRef.getDecl()) + .as(); } } } } - - // In order for this extension to apply to the given type, we - // need to have a match on the target types. - if (!type->equals(targetType)) - return DeclRef(); - - - return extDeclRef; - } - - QualType SemanticsVisitor::GetTypeForDeclRef(DeclRef declRef, SourceLoc loc) - { - Type* typeResult = nullptr; - return getTypeForDeclRef( - m_astBuilder, - this, - getSink(), - declRef, - &typeResult, - loc); } - void SemanticsVisitor::importFileDeclIntoScope(Scope* scope, FileDecl* fileDecl) - { - // Create a new sub-scope to wire the module - // into our lookup chain. - if (!fileDecl) - return; - addSiblingScopeForContainerDecl(getASTBuilder(), scope, fileDecl); - } + // In order for this extension to apply to the given type, we + // need to have a match on the target types. + if (!type->equals(targetType)) + return DeclRef(); - void SemanticsVisitor::importModuleIntoScope(Scope* scope, ModuleDecl* moduleDecl) - { - if (!moduleDecl) - return; - // If we've imported this one already, then - // skip the step where we modify the current scope. - auto& importedModulesList = getShared()->importedModulesList; - auto& importedModulesSet = getShared()->importedModulesSet; - if (importedModulesSet.contains(moduleDecl)) - { - return; - } - importedModulesList.add(moduleDecl); - importedModulesSet.add(moduleDecl); + return extDeclRef; +} - // Create a new sub-scope to wire the module's scope and its nested FileDecl's scopes - // into our lookup chain. - for (auto moduleScope = moduleDecl->ownedScope; moduleScope; moduleScope = moduleScope->nextSibling) - { - if (moduleScope->containerDecl != moduleDecl && moduleScope->containerDecl->parentDecl != moduleDecl) - continue; +QualType SemanticsVisitor::GetTypeForDeclRef(DeclRef declRef, SourceLoc loc) +{ + Type* typeResult = nullptr; + return getTypeForDeclRef(m_astBuilder, this, getSink(), declRef, &typeResult, loc); +} - addSiblingScopeForContainerDecl(getASTBuilder(), scope, moduleScope->containerDecl); - } +void SemanticsVisitor::importFileDeclIntoScope(Scope* scope, FileDecl* fileDecl) +{ + // Create a new sub-scope to wire the module + // into our lookup chain. + if (!fileDecl) + return; + addSiblingScopeForContainerDecl(getASTBuilder(), scope, fileDecl); +} - // Also import any modules from nested `import` declarations - // with the `__exported` modifier - for (auto importDecl : moduleDecl->getMembersOfType()) - { - if (!importDecl->hasModifier()) - continue; +void SemanticsVisitor::importModuleIntoScope(Scope* scope, ModuleDecl* moduleDecl) +{ + if (!moduleDecl) + return; - importModuleIntoScope(scope, importDecl->importedModuleDecl); - } + // If we've imported this one already, then + // skip the step where we modify the current scope. + auto& importedModulesList = getShared()->importedModulesList; + auto& importedModulesSet = getShared()->importedModulesSet; + if (importedModulesSet.contains(moduleDecl)) + { + return; } + importedModulesList.add(moduleDecl); + importedModulesSet.add(moduleDecl); - void SemanticsDeclHeaderVisitor::visitImportDecl(ImportDecl* decl) + // Create a new sub-scope to wire the module's scope and its nested FileDecl's scopes + // into our lookup chain. + for (auto moduleScope = moduleDecl->ownedScope; moduleScope; + moduleScope = moduleScope->nextSibling) { - // We need to look for a module with the specified name - // (whether it has already been loaded, or needs to - // be loaded), and then put its declarations into - // the module's scope. - - auto name = decl->moduleNameAndLoc.name; - auto scope = getModuleDecl(decl)->ownedScope; - - // Try to load a module matching the name - auto importedModule = findOrImportModule( - getLinkage(), - name, - decl->moduleNameAndLoc.loc, - getSink(), - getShared()->m_environmentModules); - - // If we didn't find a matching module, then bail out - if (!importedModule) - return; + if (moduleScope->containerDecl != moduleDecl && + moduleScope->containerDecl->parentDecl != moduleDecl) + continue; - // Record the module that was imported, so that we can use - // it later during code generation. - auto importedModuleDecl = importedModule->getModuleDecl(); - decl->importedModuleDecl = importedModuleDecl; + addSiblingScopeForContainerDecl(getASTBuilder(), scope, moduleScope->containerDecl); + } - // Add the declarations from the imported module into the scope - // that the `import` declaration is set to extend. - // - importModuleIntoScope(scope, importedModuleDecl); + // Also import any modules from nested `import` declarations + // with the `__exported` modifier + for (auto importDecl : moduleDecl->getMembersOfType()) + { + if (!importDecl->hasModifier()) + continue; - // Record the `import`ed module (and everything it depends on) - // as a dependency of the module we are compiling. - if(auto module = getModule(decl)) - { - module->addModuleDependency(importedModule); - } + importModuleIntoScope(scope, importDecl->importedModuleDecl); } +} - String getSimpleModuleName(Name* name) - { - auto text = getText(name); - auto dirPos = Math::Max(text.indexOf('/'), text.indexOf('\\')); - if (dirPos < 0) - return text; - auto slice = text.getUnownedSlice().tail(dirPos + 1); - auto dotPos = slice.indexOf('.'); - if (dotPos < 0) - return slice; - return String(slice.head(dotPos)); +void SemanticsDeclHeaderVisitor::visitImportDecl(ImportDecl* decl) +{ + // We need to look for a module with the specified name + // (whether it has already been loaded, or needs to + // be loaded), and then put its declarations into + // the module's scope. + + auto name = decl->moduleNameAndLoc.name; + auto scope = getModuleDecl(decl)->ownedScope; + + // Try to load a module matching the name + auto importedModule = findOrImportModule( + getLinkage(), + name, + decl->moduleNameAndLoc.loc, + getSink(), + getShared()->m_environmentModules); + + // If we didn't find a matching module, then bail out + if (!importedModule) + return; + + // Record the module that was imported, so that we can use + // it later during code generation. + auto importedModuleDecl = importedModule->getModuleDecl(); + decl->importedModuleDecl = importedModuleDecl; + + // Add the declarations from the imported module into the scope + // that the `import` declaration is set to extend. + // + importModuleIntoScope(scope, importedModuleDecl); + + // Record the `import`ed module (and everything it depends on) + // as a dependency of the module we are compiling. + if (auto module = getModule(decl)) + { + module->addModuleDependency(importedModule); } +} + +String getSimpleModuleName(Name* name) +{ + auto text = getText(name); + auto dirPos = Math::Max(text.indexOf('/'), text.indexOf('\\')); + if (dirPos < 0) + return text; + auto slice = text.getUnownedSlice().tail(dirPos + 1); + auto dotPos = slice.indexOf('.'); + if (dotPos < 0) + return slice; + return String(slice.head(dotPos)); +} - ModuleDeclarationDecl* findExistingModuleDeclarationDecl(ModuleDecl* decl) +ModuleDeclarationDecl* findExistingModuleDeclarationDecl(ModuleDecl* decl) +{ + if (decl->members.getCount() == 0) + return nullptr; + if (auto rs = as(decl->members[0])) + return rs; + for (auto fileDecl : decl->getMembersOfType()) { - if (decl->members.getCount() == 0) - return nullptr; - if (auto rs = as(decl->members[0])) + if (fileDecl->members.getCount() == 0) + continue; + if (auto rs = as(fileDecl->members[0])) return rs; - for (auto fileDecl : decl->getMembersOfType()) - { - if (fileDecl->members.getCount() == 0) - continue; - if (auto rs = as(fileDecl->members[0])) - return rs; - } - return nullptr; } + return nullptr; +} - void SemanticsDeclHeaderVisitor::visitIncludeDecl(IncludeDecl* decl) - { - auto name = decl->moduleNameAndLoc.name; +void SemanticsDeclHeaderVisitor::visitIncludeDecl(IncludeDecl* decl) +{ + auto name = decl->moduleNameAndLoc.name; - if (!getShared()->getTranslationUnitRequest()) - getSink()->diagnose(decl->moduleNameAndLoc.loc, Diagnostics::cannotProcessInclude); + if (!getShared()->getTranslationUnitRequest()) + getSink()->diagnose(decl->moduleNameAndLoc.loc, Diagnostics::cannotProcessInclude); - auto parentModule = getModule(decl); - auto moduleDecl = parentModule->getModuleDecl(); + auto parentModule = getModule(decl); + auto moduleDecl = parentModule->getModuleDecl(); - auto [fileDecl, isNew] = getLinkage()->findAndIncludeFile(getModule(decl), getShared()->getTranslationUnitRequest(), name, decl->moduleNameAndLoc.loc, getSink()); + auto [fileDecl, isNew] = getLinkage()->findAndIncludeFile( + getModule(decl), + getShared()->getTranslationUnitRequest(), + name, + decl->moduleNameAndLoc.loc, + getSink()); - if (!fileDecl) - return; + if (!fileDecl) + return; - decl->fileDecl = fileDecl; + decl->fileDecl = fileDecl; - if (!isNew) - return; + if (!isNew) + return; - if (fileDecl->members.getCount() == 0) - return; - auto firstMember = fileDecl->members[0]; - if (auto moduleDeclaration = as(firstMember)) - { - // We are trying to include a file that defines a module, the user could mean "import" instead. - getSink()->diagnose(decl->moduleNameAndLoc.loc, Diagnostics::includedFileMissingImplementingDoYouMeanImport, name, moduleDeclaration->getName()); - return; - } + if (fileDecl->members.getCount() == 0) + return; + auto firstMember = fileDecl->members[0]; + if (auto moduleDeclaration = as(firstMember)) + { + // We are trying to include a file that defines a module, the user could mean "import" + // instead. + getSink()->diagnose( + decl->moduleNameAndLoc.loc, + Diagnostics::includedFileMissingImplementingDoYouMeanImport, + name, + moduleDeclaration->getName()); + return; + } - importFileDeclIntoScope(moduleDecl->ownedScope, fileDecl); + importFileDeclIntoScope(moduleDecl->ownedScope, fileDecl); - if (auto implementing = as(firstMember)) + if (auto implementing = as(firstMember)) + { + // The file we are including must be implementing the current module. + auto moduleName = getSimpleModuleName(implementing->moduleNameAndLoc.name); + auto expectedModuleName = moduleDecl->getName(); + bool shouldSkipDiagnostic = false; + if (moduleDecl->members.getCount()) { - // The file we are including must be implementing the current module. - auto moduleName = getSimpleModuleName(implementing->moduleNameAndLoc.name); - auto expectedModuleName = moduleDecl->getName(); - bool shouldSkipDiagnostic = false; - if (moduleDecl->members.getCount()) + if (auto moduleDeclarationDecl = as(moduleDecl->members[0])) { - if (auto moduleDeclarationDecl = as(moduleDecl->members[0])) + expectedModuleName = moduleDeclarationDecl->getName(); + } + else if (getShared()->isInLanguageServer()) + { + auto moduleDeclarationDecls = findExistingModuleDeclarationDecl(moduleDecl); + if (moduleDeclarationDecls) { - expectedModuleName = moduleDeclarationDecl->getName(); + expectedModuleName = moduleDeclarationDecls->getName(); } - else if (getShared()->isInLanguageServer()) + else { - auto moduleDeclarationDecls = findExistingModuleDeclarationDecl(moduleDecl); - if (moduleDeclarationDecls) - { - expectedModuleName = moduleDeclarationDecls->getName(); - } - else - { - shouldSkipDiagnostic = true; - } + shouldSkipDiagnostic = true; } } - if (!shouldSkipDiagnostic && !moduleName.getUnownedSlice().caseInsensitiveEquals(getText(expectedModuleName).getUnownedSlice())) - { - getSink()->diagnose(decl->moduleNameAndLoc.loc, Diagnostics::includedFileDoesNotImplementCurrentModule, expectedModuleName, moduleName); - return; - } + } + if (!shouldSkipDiagnostic && !moduleName.getUnownedSlice().caseInsensitiveEquals( + getText(expectedModuleName).getUnownedSlice())) + { + getSink()->diagnose( + decl->moduleNameAndLoc.loc, + Diagnostics::includedFileDoesNotImplementCurrentModule, + expectedModuleName, + moduleName); return; } - - getSink()->diagnose(decl->moduleNameAndLoc.loc, Diagnostics::includedFileMissingImplementing, name); + return; } - void SemanticsDeclScopeWiringVisitor::visitImplementingDecl(ImplementingDecl* decl) - { - // Don't need to do anything unless we are in a language server context. - if (!getShared()->isInLanguageServer()) - return; - - // Treat an `implementing` declaration as an `include` declaration when - // we are in a language server context. + getSink()->diagnose( + decl->moduleNameAndLoc.loc, + Diagnostics::includedFileMissingImplementing, + name); +} - auto name = decl->moduleNameAndLoc.name; +void SemanticsDeclScopeWiringVisitor::visitImplementingDecl(ImplementingDecl* decl) +{ + // Don't need to do anything unless we are in a language server context. + if (!getShared()->isInLanguageServer()) + return; - if (!getShared()->getTranslationUnitRequest()) - getSink()->diagnose(decl->moduleNameAndLoc.loc, Diagnostics::cannotProcessInclude); + // Treat an `implementing` declaration as an `include` declaration when + // we are in a language server context. - auto [fileDecl, isNew] = getLinkage()->findAndIncludeFile(getModule(decl), getShared()->getTranslationUnitRequest(), name, decl->moduleNameAndLoc.loc, getSink()); + auto name = decl->moduleNameAndLoc.name; - decl->fileDecl = fileDecl; + if (!getShared()->getTranslationUnitRequest()) + getSink()->diagnose(decl->moduleNameAndLoc.loc, Diagnostics::cannotProcessInclude); - if (!isNew) - return; + auto [fileDecl, isNew] = getLinkage()->findAndIncludeFile( + getModule(decl), + getShared()->getTranslationUnitRequest(), + name, + decl->moduleNameAndLoc.loc, + getSink()); - if (!fileDecl || fileDecl->members.getCount() == 0) - { - return; - } + decl->fileDecl = fileDecl; - auto firstMember = fileDecl->members[0]; - if (as(firstMember)) - { - // We are trying to implement a file that defines a module, this is expected. - } - else if (as(firstMember)) - { - getSink()->diagnose(decl->moduleNameAndLoc.loc, Diagnostics::implementingMustReferencePrimaryModuleFile); - return; - } + if (!isNew) + return; - if (auto moduleDecl = getModuleDecl(decl)) - importFileDeclIntoScope(moduleDecl->ownedScope, fileDecl); + if (!fileDecl || fileDecl->members.getCount() == 0) + { + return; } - void SemanticsDeclScopeWiringVisitor::visitUsingDecl(UsingDecl* decl) + auto firstMember = fileDecl->members[0]; + if (as(firstMember)) { - // First, we need to look up whatever the argument of the `using` - // declaration names. - // - decl->arg = CheckTerm(decl->arg); + // We are trying to implement a file that defines a module, this is expected. + } + else if (as(firstMember)) + { + getSink()->diagnose( + decl->moduleNameAndLoc.loc, + Diagnostics::implementingMustReferencePrimaryModuleFile); + return; + } - // Next, we want to ensure that whatever is being named by `decl->arg` - // is a namespace (or a module, since modules are namespace-like). - // - // If a user `import`s multiple modules that all have namespaces - // of the same name, it would be possible for `decl->arg` to be overloaded. - // To handle that case, we will iterate over all the entities that are - // named and import any that are namespace-like. - // - bool scopesAdded = false; - bool hasValidNamespace = false; + if (auto moduleDecl = getModuleDecl(decl)) + importFileDeclIntoScope(moduleDecl->ownedScope, fileDecl); +} - // TODO: consider caching the scope set in NamespaceDecl. - HashSet addedScopes; - for (auto s = decl->scope; s; s = s->nextSibling) - addedScopes.add(s->containerDecl); +void SemanticsDeclScopeWiringVisitor::visitUsingDecl(UsingDecl* decl) +{ + // First, we need to look up whatever the argument of the `using` + // declaration names. + // + decl->arg = CheckTerm(decl->arg); - auto addAllSiblingScopesFromDecl = [&](Scope* scope, ContainerDecl* containerDecl) - { - for (auto s = containerDecl->ownedScope; s; s = s->nextSibling) - { - if (addedScopes.add(s->containerDecl)) - { - scopesAdded = true; - addSiblingScopeForContainerDecl(getASTBuilder(), scope, s->containerDecl); - } - } - }; + // Next, we want to ensure that whatever is being named by `decl->arg` + // is a namespace (or a module, since modules are namespace-like). + // + // If a user `import`s multiple modules that all have namespaces + // of the same name, it would be possible for `decl->arg` to be overloaded. + // To handle that case, we will iterate over all the entities that are + // named and import any that are namespace-like. + // + bool scopesAdded = false; + bool hasValidNamespace = false; + + // TODO: consider caching the scope set in NamespaceDecl. + HashSet addedScopes; + for (auto s = decl->scope; s; s = s->nextSibling) + addedScopes.add(s->containerDecl); - if (auto declRefExpr = as(decl->arg)) + auto addAllSiblingScopesFromDecl = [&](Scope* scope, ContainerDecl* containerDecl) + { + for (auto s = containerDecl->ownedScope; s; s = s->nextSibling) { - if (auto namespaceDeclRef = declRefExpr->declRef.as()) + if (addedScopes.add(s->containerDecl)) { - auto namespaceDecl = namespaceDeclRef.getDecl(); - addAllSiblingScopesFromDecl(decl->scope, namespaceDecl); - hasValidNamespace = true; + scopesAdded = true; + addSiblingScopeForContainerDecl(getASTBuilder(), scope, s->containerDecl); } } - else if (auto overloadedExpr = as(decl->arg)) + }; + + if (auto declRefExpr = as(decl->arg)) + { + if (auto namespaceDeclRef = declRefExpr->declRef.as()) { - for (auto item : overloadedExpr->lookupResult2) - { - if (auto namespaceDeclRef = item.declRef.as()) - { - addAllSiblingScopesFromDecl(decl->scope, namespaceDeclRef.getDecl()); - hasValidNamespace = true; - } - } + auto namespaceDecl = namespaceDeclRef.getDecl(); + addAllSiblingScopesFromDecl(decl->scope, namespaceDecl); + hasValidNamespace = true; } - - if (!scopesAdded) + } + else if (auto overloadedExpr = as(decl->arg)) + { + for (auto item : overloadedExpr->lookupResult2) { - if (!hasValidNamespace) - getSink()->diagnose(decl->arg, Diagnostics::expectedANamespace, decl->arg->type); - return; + if (auto namespaceDeclRef = item.declRef.as()) + { + addAllSiblingScopesFromDecl(decl->scope, namespaceDeclRef.getDecl()); + hasValidNamespace = true; + } } } - void SemanticsDeclScopeWiringVisitor::visitNamespaceDecl(NamespaceDecl* decl) + if (!scopesAdded) { - // We need to wire up the scope of namespaces with other namespace decls of the same name - // that is accessible from the current context. - auto parent = as(getParentDecl(decl)); - if (!parent) - return; - for (auto parentScope = parent->ownedScope; parentScope; parentScope = parentScope->parent) - { - for (auto scope = parentScope; scope; scope = scope->nextSibling) + if (!hasValidNamespace) + getSink()->diagnose(decl->arg, Diagnostics::expectedANamespace, decl->arg->type); + return; + } +} + +void SemanticsDeclScopeWiringVisitor::visitNamespaceDecl(NamespaceDecl* decl) +{ + // We need to wire up the scope of namespaces with other namespace decls of the same name + // that is accessible from the current context. + auto parent = as(getParentDecl(decl)); + if (!parent) + return; + for (auto parentScope = parent->ownedScope; parentScope; parentScope = parentScope->parent) + { + for (auto scope = parentScope; scope; scope = scope->nextSibling) + { + auto container = scope->containerDecl; + auto nsDeclPtr = container->getMemberDictionary().tryGetValue(decl->getName()); + if (!nsDeclPtr) + continue; + auto nsDecl = *nsDeclPtr; + for (auto ns = nsDecl; ns; ns = ns->nextInContainerWithSameName) { - auto container = scope->containerDecl; - auto nsDeclPtr = container->getMemberDictionary().tryGetValue(decl->getName()); - if (!nsDeclPtr) continue; - auto nsDecl = *nsDeclPtr; - for (auto ns = nsDecl; ns; ns = ns->nextInContainerWithSameName) - { - if (ns == decl) - continue; - auto otherNamespace = as(ns); - if (!otherNamespace) - continue; + if (ns == decl) + continue; + auto otherNamespace = as(ns); + if (!otherNamespace) + continue; - if (!ns->checkState.isBeingChecked()) - { - ensureDecl(ns, DeclCheckState::ScopesWired); - } - addSiblingScopeForContainerDecl(getASTBuilder(), decl, otherNamespace); + if (!ns->checkState.isBeingChecked()) + { + ensureDecl(ns, DeclCheckState::ScopesWired); } + addSiblingScopeForContainerDecl(getASTBuilder(), decl, otherNamespace); } - // For file decls, we need to continue searching up in the parent module scope. - if (!as(parentScope->containerDecl)) - break; - } - for (auto usingDecl : decl->getMembersOfType()) - { - ensureDecl(usingDecl, DeclCheckState::ScopesWired); } + // For file decls, we need to continue searching up in the parent module scope. + if (!as(parentScope->containerDecl)) + break; + } + for (auto usingDecl : decl->getMembersOfType()) + { + ensureDecl(usingDecl, DeclCheckState::ScopesWired); } +} - /// Get a reference to the candidate extension list for `typeDecl` in the given dictionary - /// - /// Note: this function creates an empty list of candidates for the given type if - /// a matching entry doesn't exist already. - /// - static List& _getCandidateExtensionList( - AggTypeDecl* typeDecl, - Dictionary>& mapTypeToCandidateExtensions) +/// Get a reference to the candidate extension list for `typeDecl` in the given dictionary +/// +/// Note: this function creates an empty list of candidates for the given type if +/// a matching entry doesn't exist already. +/// +static List& _getCandidateExtensionList( + AggTypeDecl* typeDecl, + Dictionary>& mapTypeToCandidateExtensions) +{ + RefPtr entry; + if (!mapTypeToCandidateExtensions.tryGetValue(typeDecl, entry)) { - RefPtr entry; - if( !mapTypeToCandidateExtensions.tryGetValue(typeDecl, entry) ) - { - entry = new CandidateExtensionList(); - mapTypeToCandidateExtensions.add(typeDecl, entry); - } - return entry->candidateExtensions; + entry = new CandidateExtensionList(); + mapTypeToCandidateExtensions.add(typeDecl, entry); } + return entry->candidateExtensions; +} - List const& SharedSemanticsContext::getCandidateExtensionsForTypeDecl(AggTypeDecl* decl) +List const& SharedSemanticsContext::getCandidateExtensionsForTypeDecl( + AggTypeDecl* decl) +{ + // We are caching the lists of candidate extensions on the shared + // context, so we will only build the lists if they either have + // not been built before, or if some code caused the lists to + // be invalidated. + // + // TODO: Similar to the rebuilding of lookup tables in `ContainerDecl`s, + // we probably want to optimize this logic to gracefully handle new + // extensions encountered during checking instead of tearing the whole + // thing down. For now this potentially-quadratic behavior is acceptable + // because there just aren't that many extension declarations being used. + // + if (!m_candidateExtensionListsBuilt) { - // We are caching the lists of candidate extensions on the shared - // context, so we will only build the lists if they either have - // not been built before, or if some code caused the lists to - // be invalidated. - // - // TODO: Similar to the rebuilding of lookup tables in `ContainerDecl`s, - // we probably want to optimize this logic to gracefully handle new - // extensions encountered during checking instead of tearing the whole - // thing down. For now this potentially-quadratic behavior is acceptable - // because there just aren't that many extension declarations being used. + m_candidateExtensionListsBuilt = true; + + // We need to make sure that all extensions that were declared + // as parts of our core module are always visible, + // even if they are not explicit `import`ed into user code. // - if( !m_candidateExtensionListsBuilt ) + for (auto module : getSession()->coreModules) { - m_candidateExtensionListsBuilt = true; - - // We need to make sure that all extensions that were declared - // as parts of our core module are always visible, - // even if they are not explicit `import`ed into user code. - // - for( auto module : getSession()->coreModules ) - { - _addCandidateExtensionsFromModule(module->getModuleDecl()); - } + _addCandidateExtensionsFromModule(module->getModuleDecl()); + } - // There are two primary modes in which the `SharedSemanticsContext` - // gets used. - // - // In the first mode, we are checking an entire `ModuelDecl`, and we - // need to always check things from the "point of view" of that module - // (so that the extensions that should be visible are based on what - // that module can access via `import`s). + // There are two primary modes in which the `SharedSemanticsContext` + // gets used. + // + // In the first mode, we are checking an entire `ModuelDecl`, and we + // need to always check things from the "point of view" of that module + // (so that the extensions that should be visible are based on what + // that module can access via `import`s). + // + // In the second mode, we are checking code related to API interactions + // by the user (e.g., parsing a type from a string, specializing an + // entry point to type arguments, etc.). In these cases there is no + // clear module that should determine the point of view for looking + // up extensions, and we instead need/want to consider any extensions + // from all modules loaded into the linkage. + // + // We differentiate these cases based on whether a "primary" module + // was set at the time the `SharedSemanticsContext` was constructed. + // + if (m_module) + { + // We have a "primary" module that is being checked, and we should + // look up extensions based on what would be visible to that + // module. // - // In the second mode, we are checking code related to API interactions - // by the user (e.g., parsing a type from a string, specializing an - // entry point to type arguments, etc.). In these cases there is no - // clear module that should determine the point of view for looking - // up extensions, and we instead need/want to consider any extensions - // from all modules loaded into the linkage. + // Extensions declared in the module itself should have already + // been registered when we check them, but we still need to bring + // along with everything the module imported. // - // We differentiate these cases based on whether a "primary" module - // was set at the time the `SharedSemanticsContext` was constructed. + // Note: there is an implicit assumption here that the `importedModules` + // member on the `SharedSemanticsContext` is accurate in this case. // - if( m_module ) + for (auto moduleDecl : this->importedModulesList) { - // We have a "primary" module that is being checked, and we should - // look up extensions based on what would be visible to that - // module. - // - // Extensions declared in the module itself should have already - // been registered when we check them, but we still need to bring - // along with everything the module imported. - // - // Note: there is an implicit assumption here that the `importedModules` - // member on the `SharedSemanticsContext` is accurate in this case. - // - for( auto moduleDecl : this->importedModulesList ) - { - _addCandidateExtensionsFromModule(moduleDecl); - } + _addCandidateExtensionsFromModule(moduleDecl); } - else + } + else + { + // We are in one of the many ad hoc checking modes where we really + // want to resolve things based on the totality of what is + // available/defined within the current linkage. + // + for (auto module : m_linkage->loadedModulesList) { - // We are in one of the many ad hoc checking modes where we really - // want to resolve things based on the totality of what is - // available/defined within the current linkage. - // - for( auto module : m_linkage->loadedModulesList ) - { - _addCandidateExtensionsFromModule(module->getModuleDecl()); - } + _addCandidateExtensionsFromModule(module->getModuleDecl()); } } - - // Once we are sure that the dictionary-of-arrays of extensions - // has been populated, we return to the user the entry they - // asked for. - // - return _getCandidateExtensionList(decl, m_mapTypeDeclToCandidateExtensions); } - void SharedSemanticsContext::registerCandidateExtension(AggTypeDecl* typeDecl, ExtensionDecl* extDecl) - { - // The primary cache of extension declarations is on the `ModuleDecl`. - // We will add the `extDecl` to the cache for the module it belongs to. - // - // We can be sure that the resulting cache won't have lifetime issues, - // because all the extensions it contains are owned by the module itself, - // and the types used as keys had to be reachable/referenceable from the - // code inside the module for the given `extDecl` to extend them. - // - auto moduleDecl = getModuleDecl(extDecl); - _getCandidateExtensionList(typeDecl, moduleDecl->mapTypeToCandidateExtensions).add(extDecl); + // Once we are sure that the dictionary-of-arrays of extensions + // has been populated, we return to the user the entry they + // asked for. + // + return _getCandidateExtensionList(decl, m_mapTypeDeclToCandidateExtensions); +} - // Because we've loaded a new extension, we need to invalidate whatever - // information the `SharedSemanticsContext` had cached about loaded - // extensions, and force it to rebuild its cache to include the - // new extension we just added. - // - _getCandidateExtensionList(typeDecl, m_mapTypeDeclToCandidateExtensions).add(extDecl); +void SharedSemanticsContext::registerCandidateExtension( + AggTypeDecl* typeDecl, + ExtensionDecl* extDecl) +{ + // The primary cache of extension declarations is on the `ModuleDecl`. + // We will add the `extDecl` to the cache for the module it belongs to. + // + // We can be sure that the resulting cache won't have lifetime issues, + // because all the extensions it contains are owned by the module itself, + // and the types used as keys had to be reachable/referenceable from the + // code inside the module for the given `extDecl` to extend them. + // + auto moduleDecl = getModuleDecl(extDecl); + _getCandidateExtensionList(typeDecl, moduleDecl->mapTypeToCandidateExtensions).add(extDecl); + + // Because we've loaded a new extension, we need to invalidate whatever + // information the `SharedSemanticsContext` had cached about loaded + // extensions, and force it to rebuild its cache to include the + // new extension we just added. + // + _getCandidateExtensionList(typeDecl, m_mapTypeDeclToCandidateExtensions).add(extDecl); - // Remove the cached inheritanceInfo about typeDecl, if `extDecl` inherits new types. - bool invalidateSubtypes = false; - if (as(typeDecl)) + // Remove the cached inheritanceInfo about typeDecl, if `extDecl` inherits new types. + bool invalidateSubtypes = false; + if (as(typeDecl)) + { + // If we are extending an interface, we are effectively extending all types + // that inherits the interface. So we need to remove all inheritance info + // that is related to the interface. + invalidateSubtypes = true; + } + bool hasInheritanceMember = false; + bool hasImplicitCastMember = false; + for (auto member : extDecl->members) + { + if (as(member)) { - // If we are extending an interface, we are effectively extending all types - // that inherits the interface. So we need to remove all inheritance info - // that is related to the interface. - invalidateSubtypes = true; + hasInheritanceMember = true; } - bool hasInheritanceMember = false; - bool hasImplicitCastMember = false; - for (auto member : extDecl->members) + else if (auto ctorDecl = as(member)) { - if (as(member)) - { - hasInheritanceMember = true; - } - else if (auto ctorDecl = as(member)) - { - if (ctorDecl->hasModifier()) - hasImplicitCastMember = true; - } + if (ctorDecl->hasModifier()) + hasImplicitCastMember = true; } - auto isTypeUpToDate = [this](Type* type) - { - if (auto declRefType = as(type)) - { - return m_mapDeclRefToInheritanceInfo.containsKey(declRefType->getDeclRef()); - } - return m_mapTypeToInheritanceInfo.containsKey(type); - }; - auto isInheritanceInfoAffected = [typeDecl](InheritanceInfo& info) - { - for (auto f : info.facets) - if (f.getImpl()->getDeclRef().getDecl() == typeDecl) - { - return true; - } - return false; - }; - if (invalidateSubtypes) + } + auto isTypeUpToDate = [this](Type* type) + { + if (auto declRefType = as(type)) { - decltype(m_mapTypeToInheritanceInfo) newMapTypeToInheritanceInfo; - for (auto& kv : m_mapTypeToInheritanceInfo) + return m_mapDeclRefToInheritanceInfo.containsKey(declRefType->getDeclRef()); + } + return m_mapTypeToInheritanceInfo.containsKey(type); + }; + auto isInheritanceInfoAffected = [typeDecl](InheritanceInfo& info) + { + for (auto f : info.facets) + if (f.getImpl()->getDeclRef().getDecl() == typeDecl) { - if (!isInheritanceInfoAffected(kv.second)) - { - newMapTypeToInheritanceInfo.add(kv.first, kv.second); - } + return true; } - m_mapTypeToInheritanceInfo = _Move(newMapTypeToInheritanceInfo); - } - - ShortList, 16> keysToRemove; - for (auto& kv : m_mapDeclRefToInheritanceInfo) + return false; + }; + if (invalidateSubtypes) + { + decltype(m_mapTypeToInheritanceInfo) newMapTypeToInheritanceInfo; + for (auto& kv : m_mapTypeToInheritanceInfo) { - // We can confirm the type is affected by the new extension, - // if the declref type points to typeDecl. - if (kv.first.getDecl() == typeDecl) + if (!isInheritanceInfoAffected(kv.second)) { - keysToRemove.add(kv.first); - continue; + newMapTypeToInheritanceInfo.add(kv.first, kv.second); } + } + m_mapTypeToInheritanceInfo = _Move(newMapTypeToInheritanceInfo); + } - // If we are extending interface types (and in the future any struct type - // if we decide to have full inheritance support), - // we also need to account for conformant that implements the interface. - if (invalidateSubtypes && isInheritanceInfoAffected(kv.second)) - { - keysToRemove.add(kv.first); - } + ShortList, 16> keysToRemove; + for (auto& kv : m_mapDeclRefToInheritanceInfo) + { + // We can confirm the type is affected by the new extension, + // if the declref type points to typeDecl. + if (kv.first.getDecl() == typeDecl) + { + keysToRemove.add(kv.first); + continue; } - for (auto& key : keysToRemove) + + // If we are extending interface types (and in the future any struct type + // if we decide to have full inheritance support), + // we also need to account for conformant that implements the interface. + if (invalidateSubtypes && isInheritanceInfoAffected(kv.second)) { - m_mapDeclRefToInheritanceInfo.remove(key); + keysToRemove.add(kv.first); } + } + for (auto& key : keysToRemove) + { + m_mapDeclRefToInheritanceInfo.remove(key); + } - if (hasInheritanceMember || invalidateSubtypes) + if (hasInheritanceMember || invalidateSubtypes) + { + ShortList typePairsToRemove; + for (auto& kv : m_mapTypePairToSubtypeWitness) { - ShortList typePairsToRemove; - for (auto& kv : m_mapTypePairToSubtypeWitness) - { - if (!isTypeUpToDate(kv.first.type0) || !isTypeUpToDate(kv.first.type1)) - { - typePairsToRemove.add(kv.first); - } - } - for (auto& key : typePairsToRemove) + if (!isTypeUpToDate(kv.first.type0) || !isTypeUpToDate(kv.first.type1)) { - m_mapTypePairToSubtypeWitness.remove(key); + typePairsToRemove.add(kv.first); } } + for (auto& key : typePairsToRemove) + { + m_mapTypePairToSubtypeWitness.remove(key); + } + } - if (hasImplicitCastMember) + if (hasImplicitCastMember) + { + ShortList entriesToRemove; + for (auto& kv : m_mapTypePairToImplicitCastMethod) { - ShortList entriesToRemove; - for (auto& kv : m_mapTypePairToImplicitCastMethod) - { - // Since implicit casts are defined as constructors on the toType, - // we only need to check if the toType is affected by the new extension. - auto declRefType = as(kv.first.toType); + // Since implicit casts are defined as constructors on the toType, + // we only need to check if the toType is affected by the new extension. + auto declRefType = as(kv.first.toType); - if (!declRefType || declRefType->getDeclRef().getDecl() == typeDecl) - { - entriesToRemove.add(kv.first); - } - } - for (auto& key : entriesToRemove) + if (!declRefType || declRefType->getDeclRef().getDecl() == typeDecl) { - m_mapTypePairToImplicitCastMethod.remove(key); + entriesToRemove.add(kv.first); } } - } - - void SharedSemanticsContext::_addCandidateExtensionsFromModule(ModuleDecl* moduleDecl) - { - for( auto& [entryKey, entryValue] : moduleDecl->mapTypeToCandidateExtensions ) + for (auto& key : entriesToRemove) { - auto& list = _getCandidateExtensionList(entryKey, m_mapTypeDeclToCandidateExtensions); - list.addRange(entryValue->candidateExtensions); + m_mapTypePairToImplicitCastMethod.remove(key); } } +} - /// Get a reference to the associated decl list for `decl` in the given dictionary - /// - /// Note: this function creates an empty list of candidates for the given type if - /// a matching entry doesn't exist already. - /// - static List>& _getDeclAssociationList( - Decl* decl, - OrderedDictionary>& mapDeclToDeclarations) +void SharedSemanticsContext::_addCandidateExtensionsFromModule(ModuleDecl* moduleDecl) +{ + for (auto& [entryKey, entryValue] : moduleDecl->mapTypeToCandidateExtensions) { - RefPtr entry; - if (!mapDeclToDeclarations.tryGetValue(decl, entry)) - { - entry = new DeclAssociationList(); - mapDeclToDeclarations.add(decl, entry); - } - return entry->associations; + auto& list = _getCandidateExtensionList(entryKey, m_mapTypeDeclToCandidateExtensions); + list.addRange(entryValue->candidateExtensions); } +} - void SharedSemanticsContext::_addDeclAssociationsFromModule(ModuleDecl* moduleDecl) +/// Get a reference to the associated decl list for `decl` in the given dictionary +/// +/// Note: this function creates an empty list of candidates for the given type if +/// a matching entry doesn't exist already. +/// +static List>& _getDeclAssociationList( + Decl* decl, + OrderedDictionary>& mapDeclToDeclarations) +{ + RefPtr entry; + if (!mapDeclToDeclarations.tryGetValue(decl, entry)) { - for (auto& entry : moduleDecl->mapDeclToAssociatedDecls) - { - auto& list = _getDeclAssociationList(entry.key, m_mapDeclToAssociatedDecls); - list.addRange(entry.value->associations); - } + entry = new DeclAssociationList(); + mapDeclToDeclarations.add(decl, entry); } + return entry->associations; +} - void SharedSemanticsContext::registerAssociatedDecl(Decl* original, DeclAssociationKind kind, Decl* associated) +void SharedSemanticsContext::_addDeclAssociationsFromModule(ModuleDecl* moduleDecl) +{ + for (auto& entry : moduleDecl->mapDeclToAssociatedDecls) { - auto moduleDecl = getModuleDecl(associated); - RefPtr assoc = new DeclAssociation(); - assoc->kind = kind; - assoc->decl = associated; - _getDeclAssociationList(original, moduleDecl->mapDeclToAssociatedDecls).add(assoc); - - m_associatedDeclListsBuilt = false; - m_mapDeclToAssociatedDecls.clear(); + auto& list = _getDeclAssociationList(entry.key, m_mapDeclToAssociatedDecls); + list.addRange(entry.value->associations); } +} + +void SharedSemanticsContext::registerAssociatedDecl( + Decl* original, + DeclAssociationKind kind, + Decl* associated) +{ + auto moduleDecl = getModuleDecl(associated); + RefPtr assoc = new DeclAssociation(); + assoc->kind = kind; + assoc->decl = associated; + _getDeclAssociationList(original, moduleDecl->mapDeclToAssociatedDecls).add(assoc); + + m_associatedDeclListsBuilt = false; + m_mapDeclToAssociatedDecls.clear(); +} - List> const& SharedSemanticsContext::getAssociatedDeclsForDecl(Decl* decl) +List> const& SharedSemanticsContext::getAssociatedDeclsForDecl(Decl* decl) +{ + // This duplicates the exact same logic from `getCandidateExtensionsForTypeDecl`. + // Consider refactoring them into the same framework. + if (!m_associatedDeclListsBuilt) { - // This duplicates the exact same logic from `getCandidateExtensionsForTypeDecl`. - // Consider refactoring them into the same framework. - if (!m_associatedDeclListsBuilt) - { - m_associatedDeclListsBuilt = true; + m_associatedDeclListsBuilt = true; - for (auto module : getSession()->coreModules) - { - _addDeclAssociationsFromModule(module->getModuleDecl()); - } + for (auto module : getSession()->coreModules) + { + _addDeclAssociationsFromModule(module->getModuleDecl()); + } - if (m_module) + if (m_module) + { + _addDeclAssociationsFromModule(m_module->getModuleDecl()); + for (auto moduleDecl : this->importedModulesList) { - _addDeclAssociationsFromModule(m_module->getModuleDecl()); - for (auto moduleDecl : this->importedModulesList) - { - _addDeclAssociationsFromModule(moduleDecl); - } + _addDeclAssociationsFromModule(moduleDecl); } - else + } + else + { + for (auto module : m_linkage->loadedModulesList) { - for (auto module : m_linkage->loadedModulesList) - { - _addDeclAssociationsFromModule(module->getModuleDecl()); - } + _addDeclAssociationsFromModule(module->getModuleDecl()); } } - return _getDeclAssociationList(decl, m_mapDeclToAssociatedDecls); } + return _getDeclAssociationList(decl, m_mapDeclToAssociatedDecls); +} - bool SharedSemanticsContext::isDifferentiableFunc(FunctionDeclBase* func) - { - return getFuncDifferentiableLevel(func) != FunctionDifferentiableLevel::None; - } +bool SharedSemanticsContext::isDifferentiableFunc(FunctionDeclBase* func) +{ + return getFuncDifferentiableLevel(func) != FunctionDifferentiableLevel::None; +} - bool SharedSemanticsContext::isBackwardDifferentiableFunc(FunctionDeclBase* func) - { - return getFuncDifferentiableLevel(func) == FunctionDifferentiableLevel::Backward; - } +bool SharedSemanticsContext::isBackwardDifferentiableFunc(FunctionDeclBase* func) +{ + return getFuncDifferentiableLevel(func) == FunctionDifferentiableLevel::Backward; +} - FunctionDifferentiableLevel SharedSemanticsContext::getFuncDifferentiableLevel(FunctionDeclBase* func) - { - return _getFuncDifferentiableLevelImpl(func, 1); - } +FunctionDifferentiableLevel SharedSemanticsContext::getFuncDifferentiableLevel( + FunctionDeclBase* func) +{ + return _getFuncDifferentiableLevelImpl(func, 1); +} - FunctionDifferentiableLevel SharedSemanticsContext::_getFuncDifferentiableLevelImpl(FunctionDeclBase* func, int recurseLimit) - { - if (!func) - return FunctionDifferentiableLevel::None; +FunctionDifferentiableLevel SharedSemanticsContext::_getFuncDifferentiableLevelImpl( + FunctionDeclBase* func, + int recurseLimit) +{ + if (!func) + return FunctionDifferentiableLevel::None; - if (recurseLimit > 0) + if (recurseLimit > 0) + { + if (auto primalSubst = func->findModifier()) { - if (auto primalSubst = func->findModifier()) + if (auto declRefExpr = as(primalSubst->funcExpr)) { - if (auto declRefExpr = as(primalSubst->funcExpr)) - { - if (auto primalSubstFunc = declRefExpr->declRef.as()) - return _getFuncDifferentiableLevelImpl(primalSubstFunc.getDecl(), recurseLimit - 1); - } + if (auto primalSubstFunc = declRefExpr->declRef.as()) + return _getFuncDifferentiableLevelImpl( + primalSubstFunc.getDecl(), + recurseLimit - 1); } } + } - if (func->findModifier()) - return FunctionDifferentiableLevel::Backward; - if (func->findModifier()) - return FunctionDifferentiableLevel::Backward; + if (func->findModifier()) + return FunctionDifferentiableLevel::Backward; + if (func->findModifier()) + return FunctionDifferentiableLevel::Backward; - if (func->findModifier()) - return FunctionDifferentiableLevel::Backward; + if (func->findModifier()) + return FunctionDifferentiableLevel::Backward; - FunctionDifferentiableLevel diffLevel = FunctionDifferentiableLevel::None; - if (func->findModifier()) - diffLevel = FunctionDifferentiableLevel::Forward; + FunctionDifferentiableLevel diffLevel = FunctionDifferentiableLevel::None; + if (func->findModifier()) + diffLevel = FunctionDifferentiableLevel::Forward; - for (auto assocDecl : getAssociatedDeclsForDecl(func)) + for (auto assocDecl : getAssociatedDeclsForDecl(func)) + { + switch (assocDecl->kind) { - switch (assocDecl->kind) + case DeclAssociationKind::BackwardDerivativeFunc: + return FunctionDifferentiableLevel::Backward; + case DeclAssociationKind::ForwardDerivativeFunc: + diffLevel = FunctionDifferentiableLevel::Forward; + break; + case DeclAssociationKind::PrimalSubstituteFunc: + if (auto assocFunc = as(assocDecl->decl)) { - case DeclAssociationKind::BackwardDerivativeFunc: - return FunctionDifferentiableLevel::Backward; - case DeclAssociationKind::ForwardDerivativeFunc: - diffLevel = FunctionDifferentiableLevel::Forward; - break; - case DeclAssociationKind::PrimalSubstituteFunc: - if (auto assocFunc = as(assocDecl->decl)) - { - return _getFuncDifferentiableLevelImpl(assocFunc, recurseLimit - 1); - } - break; - default: - break; + return _getFuncDifferentiableLevelImpl(assocFunc, recurseLimit - 1); } + break; + default: break; } - if (auto builtinReq = func->findModifier()) + } + if (auto builtinReq = func->findModifier()) + { + switch (builtinReq->kind) { - switch (builtinReq->kind) - { - case BuiltinRequirementKind::DAddFunc: - case BuiltinRequirementKind::DMulFunc: - case BuiltinRequirementKind::DZeroFunc: - return FunctionDifferentiableLevel::Backward; - default: - break; - } + case BuiltinRequirementKind::DAddFunc: + case BuiltinRequirementKind::DMulFunc: + case BuiltinRequirementKind::DZeroFunc: return FunctionDifferentiableLevel::Backward; + default: break; } - return diffLevel; } + return diffLevel; +} - List const& getCandidateExtensions( - DeclRef const& declRef, - SemanticsVisitor* semantics) - { - auto decl = declRef.getDecl(); - auto shared = semantics->getShared(); - return shared->getCandidateExtensionsForTypeDecl(decl); - } +List const& getCandidateExtensions( + DeclRef const& declRef, + SemanticsVisitor* semantics) +{ + auto decl = declRef.getDecl(); + auto shared = semantics->getShared(); + return shared->getCandidateExtensionsForTypeDecl(decl); +} - void _foreachDirectOrExtensionMemberOfType( - SemanticsVisitor* semantics, - DeclRef const& containerDeclRef, - SyntaxClassBase const& syntaxClass, - void (*callback)(DeclRefBase*, void*), - void const* userData) +void _foreachDirectOrExtensionMemberOfType( + SemanticsVisitor* semantics, + DeclRef const& containerDeclRef, + SyntaxClassBase const& syntaxClass, + void (*callback)(DeclRefBase*, void*), + void const* userData) +{ + // We are being asked to invoke the given callback on + // each direct member of `containerDeclRef`, along with + // any members added via `extension` declarations, that + // have the correct AST node class (`syntaxClass`). + // + // We start with the direct members. + // + for (auto memberDeclRef : getMembers(semantics->getASTBuilder(), containerDeclRef)) { - // We are being asked to invoke the given callback on - // each direct member of `containerDeclRef`, along with - // any members added via `extension` declarations, that - // have the correct AST node class (`syntaxClass`). - // - // We start with the direct members. - // - for( auto memberDeclRef : getMembers(semantics->getASTBuilder(), containerDeclRef)) + if (memberDeclRef.getDecl()->getClass().isSubClassOfImpl(syntaxClass)) { - if( memberDeclRef.getDecl()->getClass().isSubClassOfImpl(syntaxClass)) - { - callback(memberDeclRef, (void*)userData); - } + callback(memberDeclRef, (void*)userData); } + } - // Next, in the case wher ethe type can be subject to extensions, - // we loop over the applicable extensions and their member.s - // - if(auto aggTypeDeclRef = containerDeclRef.as()) + // Next, in the case wher ethe type can be subject to extensions, + // we loop over the applicable extensions and their member.s + // + if (auto aggTypeDeclRef = containerDeclRef.as()) + { + auto aggType = DeclRefType::create(semantics->getASTBuilder(), aggTypeDeclRef); + for (auto extDecl : getCandidateExtensions(aggTypeDeclRef, semantics)) { - auto aggType = DeclRefType::create(semantics->getASTBuilder(), aggTypeDeclRef); - for(auto extDecl : getCandidateExtensions(aggTypeDeclRef, semantics)) - { - // Note that `extDecl` may have been declared for a type - // base on the declaration that `aggTypeDeclRef` refers - // to, but that does not guarantee that it applies to - // the type itself. E.g., we might have an extension of - // `vector` for any `N`, but the current type is - // `vector` so that the extension doesn't match. - // - // In order to make sure that we don't enumerate members - // that don't make sense in context, we must apply - // the extension to the type and see if we succeed in - // making a match. - // - auto extDeclRef = applyExtensionToType(semantics, extDecl, aggType); - if(!extDeclRef) - continue; + // Note that `extDecl` may have been declared for a type + // base on the declaration that `aggTypeDeclRef` refers + // to, but that does not guarantee that it applies to + // the type itself. E.g., we might have an extension of + // `vector` for any `N`, but the current type is + // `vector` so that the extension doesn't match. + // + // In order to make sure that we don't enumerate members + // that don't make sense in context, we must apply + // the extension to the type and see if we succeed in + // making a match. + // + auto extDeclRef = applyExtensionToType(semantics, extDecl, aggType); + if (!extDeclRef) + continue; - for( auto memberDeclRef : getMembers(semantics->getASTBuilder(), extDeclRef) ) + for (auto memberDeclRef : getMembers(semantics->getASTBuilder(), extDeclRef)) + { + if (memberDeclRef.getDecl()->getClass().isSubClassOfImpl(syntaxClass)) { - if( memberDeclRef.getDecl()->getClass().isSubClassOfImpl(syntaxClass)) - { - callback(memberDeclRef, (void*)userData); - } + callback(memberDeclRef, (void*)userData); } } } } +} - static void _dispatchDeclCheckingVisitor(Decl* decl, DeclCheckState state, SemanticsContext& shared) +static void _dispatchDeclCheckingVisitor(Decl* decl, DeclCheckState state, SemanticsContext& shared) +{ + switch (state) { - switch(state) - { - case DeclCheckState::ModifiersChecked: - SemanticsDeclModifiersVisitor(shared).dispatch(decl); - break; - case DeclCheckState::ScopesWired: - SemanticsDeclScopeWiringVisitor(shared).dispatch(decl); - break; + case DeclCheckState::ModifiersChecked: + SemanticsDeclModifiersVisitor(shared).dispatch(decl); + break; + case DeclCheckState::ScopesWired: SemanticsDeclScopeWiringVisitor(shared).dispatch(decl); break; - case DeclCheckState::SignatureChecked: - SemanticsDeclHeaderVisitor(shared).dispatch(decl); - break; + case DeclCheckState::SignatureChecked: SemanticsDeclHeaderVisitor(shared).dispatch(decl); break; - case DeclCheckState::ReadyForReference: - SemanticsDeclRedeclarationVisitor(shared).dispatch(decl); - break; + case DeclCheckState::ReadyForReference: + SemanticsDeclRedeclarationVisitor(shared).dispatch(decl); + break; - case DeclCheckState::ReadyForLookup: - SemanticsDeclBasesVisitor(shared).dispatch(decl); - break; + case DeclCheckState::ReadyForLookup: SemanticsDeclBasesVisitor(shared).dispatch(decl); break; - case DeclCheckState::ReadyForConformances: - SemanticsDeclConformancesVisitor(shared).dispatch(decl); - break; + case DeclCheckState::ReadyForConformances: + SemanticsDeclConformancesVisitor(shared).dispatch(decl); + break; - case DeclCheckState::TypesFullyResolved: - SemanticsDeclTypeResolutionVisitor(shared).dispatch(decl); - SemanticsDeclDifferentialConformanceVisitor(shared).dispatch(decl); - break; + case DeclCheckState::TypesFullyResolved: + SemanticsDeclTypeResolutionVisitor(shared).dispatch(decl); + SemanticsDeclDifferentialConformanceVisitor(shared).dispatch(decl); + break; - case DeclCheckState::AttributesChecked: - SemanticsDeclAttributesVisitor(shared).dispatch(decl); - break; + case DeclCheckState::AttributesChecked: + SemanticsDeclAttributesVisitor(shared).dispatch(decl); + break; - case DeclCheckState::DefinitionChecked: - SemanticsDeclBodyVisitor(shared).dispatch(decl); - break; + case DeclCheckState::DefinitionChecked: SemanticsDeclBodyVisitor(shared).dispatch(decl); break; - case DeclCheckState::CapabilityChecked: - if (!shared.getOptionSet().getBoolOption(CompilerOptionName::IgnoreCapabilities)) - { - SemanticsDeclCapabilityVisitor(shared).dispatch(decl); - } - break; + case DeclCheckState::CapabilityChecked: + if (!shared.getOptionSet().getBoolOption(CompilerOptionName::IgnoreCapabilities)) + { + SemanticsDeclCapabilityVisitor(shared).dispatch(decl); } + break; } +} - static void _getCanonicalConstraintTypes(List& outTypeList, Type* type) +static void _getCanonicalConstraintTypes(List& outTypeList, Type* type) +{ + if (auto andType = as(type)) { - if (auto andType = as(type)) - { - _getCanonicalConstraintTypes(outTypeList, andType->getLeft()); - _getCanonicalConstraintTypes(outTypeList, andType->getRight()); - } - else - { - outTypeList.add(type); - } + _getCanonicalConstraintTypes(outTypeList, andType->getLeft()); + _getCanonicalConstraintTypes(outTypeList, andType->getRight()); } - OrderedDictionary> getCanonicalGenericConstraints( - ASTBuilder* astBuilder, - DeclRef genericDecl) + else { - OrderedDictionary> genericConstraints; - for (auto mm : getMembersOfType(astBuilder, genericDecl)) - { - genericConstraints[mm.getDecl()] = List(); - } - for (auto genericTypeConstraintDecl : getMembersOfType(astBuilder, genericDecl)) - { - assert( - genericTypeConstraintDecl.getDecl()->sub.type->astNodeType == - ASTNodeType::DeclRefType); - auto typeParamDecl = as(genericTypeConstraintDecl.getDecl()->sub.type)->getDeclRef().getDecl(); - List* constraintTypes = genericConstraints.tryGetValue(typeParamDecl); - if (!constraintTypes) - continue; - constraintTypes->add(genericTypeConstraintDecl.getDecl()->getSup().type); - } + outTypeList.add(type); + } +} +OrderedDictionary> getCanonicalGenericConstraints( + ASTBuilder* astBuilder, + DeclRef genericDecl) +{ + OrderedDictionary> genericConstraints; + for (auto mm : getMembersOfType(astBuilder, genericDecl)) + { + genericConstraints[mm.getDecl()] = List(); + } + for (auto genericTypeConstraintDecl : + getMembersOfType(astBuilder, genericDecl)) + { + assert( + genericTypeConstraintDecl.getDecl()->sub.type->astNodeType == ASTNodeType::DeclRefType); + auto typeParamDecl = + as(genericTypeConstraintDecl.getDecl()->sub.type)->getDeclRef().getDecl(); + List* constraintTypes = genericConstraints.tryGetValue(typeParamDecl); + if (!constraintTypes) + continue; + constraintTypes->add(genericTypeConstraintDecl.getDecl()->getSup().type); + } - OrderedDictionary> result; - for (auto& constraints : genericConstraints) + OrderedDictionary> result; + for (auto& constraints : genericConstraints) + { + List typeList; + for (auto type : constraints.value) { - List typeList; - for (auto type : constraints.value) - { - _getCanonicalConstraintTypes(typeList, type); - } - // TODO: we also need to sort the types within the list for each generic type param. - result[constraints.key] = typeList; + _getCanonicalConstraintTypes(typeList, type); } - return result; + // TODO: we also need to sort the types within the list for each generic type param. + result[constraints.key] = typeList; } + return result; +} - bool areTypesCompatibile(SemanticsVisitor* visitor, Type* fst, Type* snd) - { - if (fst->equals(snd)) - return true; +bool areTypesCompatibile(SemanticsVisitor* visitor, Type* fst, Type* snd) +{ + if (fst->equals(snd)) + return true; - if (auto declRefType = as(fst)) + if (auto declRefType = as(fst)) + { + auto decl = declRefType->getDeclRef().getDecl(); + if (auto extGenericDecl = visitor->GetOuterGeneric(decl)) { - auto decl = declRefType->getDeclRef().getDecl(); - if (auto extGenericDecl = visitor->GetOuterGeneric(decl)) - { - SemanticsVisitor::ConstraintSystem constraints; - constraints.loc = decl->loc; - constraints.genericDecl = extGenericDecl; + SemanticsVisitor::ConstraintSystem constraints; + constraints.loc = decl->loc; + constraints.genericDecl = extGenericDecl; - if (!visitor->TryUnifyTypes(constraints, SemanticsVisitor::ValUnificationContext(), fst, snd)) - return false; - - ConversionCost baseCost; - if (!visitor->trySolveConstraintSystem(&constraints, makeDeclRef(extGenericDecl), ArrayView(), baseCost)) - return false; - - // If we reach here, it means we have a valid unification. - return true; - } + if (!visitor->TryUnifyTypes( + constraints, + SemanticsVisitor::ValUnificationContext(), + fst, + snd)) + return false; + + ConversionCost baseCost; + if (!visitor->trySolveConstraintSystem( + &constraints, + makeDeclRef(extGenericDecl), + ArrayView(), + baseCost)) + return false; + + // If we reach here, it means we have a valid unification. + return true; } - return false; } + return false; +} - Type* getTypeForThisExpr(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl) - { - ThisExpr* expr = visitor->getASTBuilder()->create(); - expr->scope = funcDecl->ownedScope; - expr->loc = funcDecl->loc; +Type* getTypeForThisExpr(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl) +{ + ThisExpr* expr = visitor->getASTBuilder()->create(); + expr->scope = funcDecl->ownedScope; + expr->loc = funcDecl->loc; - DiagnosticSink dummySink; - auto tempVisitor = SemanticsVisitor(visitor->withSink(&dummySink)); + DiagnosticSink dummySink; + auto tempVisitor = SemanticsVisitor(visitor->withSink(&dummySink)); - auto checkedExpr = tempVisitor.CheckTerm(expr); - - return !(as(checkedExpr->type.type)) ? (checkedExpr->type.type) : nullptr; - } + auto checkedExpr = tempVisitor.CheckTerm(expr); - Type* getTypeForThisExpr(SemanticsVisitor* visitor, DeclRef funcDeclRef) - { - auto type = getTypeForThisExpr(visitor, funcDeclRef.getDecl()); - if (type) - return substituteType( - SubstitutionSet(funcDeclRef.declRefBase), - visitor->getASTBuilder(), - type); - return nullptr; - } + return !(as(checkedExpr->type.type)) ? (checkedExpr->type.type) : nullptr; +} + +Type* getTypeForThisExpr(SemanticsVisitor* visitor, DeclRef funcDeclRef) +{ + auto type = getTypeForThisExpr(visitor, funcDeclRef.getDecl()); + if (type) + return substituteType( + SubstitutionSet(funcDeclRef.declRefBase), + visitor->getASTBuilder(), + type); + return nullptr; +} - struct ArgsWithDirectionInfo +struct ArgsWithDirectionInfo +{ + List args; + List directions; + + Expr* thisArg; + ParameterDirection thisArgDirection; +}; + +template +void checkDerivativeAttributeImpl( + SemanticsVisitor* visitor, + Decl* funcDecl, + TDerivativeAttr* attr, + const List& imaginaryArguments, + const List& expectedParamDirections, + Expr* expectedThisArg, + ParameterDirection expectedThisArgDirection) +{ + if (isInterfaceRequirement(funcDecl)) { - List args; - List directions; - - Expr* thisArg; - ParameterDirection thisArgDirection; - }; + visitor->getSink()->diagnose( + attr, + Diagnostics::cannotAssociateInterfaceRequirementWithDerivative); + return; + } - template - void checkDerivativeAttributeImpl( - SemanticsVisitor* visitor, - Decl* funcDecl, - TDerivativeAttr* attr, - const List& imaginaryArguments, - const List& expectedParamDirections, - Expr* expectedThisArg, - ParameterDirection expectedThisArgDirection) + SemanticsContext::ExprLocalScope scope; + auto ctx = visitor->withExprLocalScope(&scope); + auto subVisitor = SemanticsVisitor(ctx); + auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, ctx); + attr->funcExpr = checkedFuncExpr; + if (attr->args.getCount()) + attr->args[0] = attr->funcExpr; + if (auto declRefExpr = as(checkedFuncExpr)) { - if (isInterfaceRequirement(funcDecl)) + if (declRefExpr->declRef) + visitor->ensureDecl(declRefExpr->declRef, DeclCheckState::TypesFullyResolved); + else { - visitor->getSink()->diagnose(attr, Diagnostics::cannotAssociateInterfaceRequirementWithDerivative); + visitor->getSink()->diagnose(attr, Diagnostics::cannotResolveDerivativeFunction); return; } + } + else if (auto overloadedExpr = as(checkedFuncExpr)) + { + for (auto candidate : overloadedExpr->lookupResult2.items) + { + visitor->ensureDecl(candidate.declRef, DeclCheckState::TypesFullyResolved); + } + } + else + { + visitor->getSink()->diagnose(attr, Diagnostics::cannotResolveDerivativeFunction); + return; + } + + // If left value is true, then convert the + // inner type to an InOutType. + // + auto qualTypeToString = [&](QualType qualType) -> String + { + Type* type = qualType.type; + if (qualType.isLeftValue) + { + type = ctx.getASTBuilder()->getInOutType(type); + } + return type->toString(); + }; + + List argList = imaginaryArguments; + List paramDirections = expectedParamDirections; + bool expectStaticFunc = false; + + if (expectedThisArg) + { + argList.insert(0, expectedThisArg); + paramDirections.insert(0, expectedThisArgDirection); + expectStaticFunc = true; + } + + auto invokeExpr = subVisitor.constructUncheckedInvokeExpr(checkedFuncExpr, argList); + auto resolved = subVisitor.ResolveInvoke(invokeExpr); - SemanticsContext::ExprLocalScope scope; - auto ctx = visitor->withExprLocalScope(&scope); - auto subVisitor = SemanticsVisitor(ctx); - auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, ctx); - attr->funcExpr = checkedFuncExpr; - if (attr->args.getCount()) - attr->args[0] = attr->funcExpr; - if (auto declRefExpr = as(checkedFuncExpr)) + if (auto resolvedInvoke = as(resolved)) + { + if (auto calleeDeclRef = as(resolvedInvoke->functionExpr)) { - if (declRefExpr->declRef) - visitor->ensureDecl(declRefExpr->declRef, DeclCheckState::TypesFullyResolved); - else + // There are two ways to make it to this point.. a proper resolution, and a + // resolution that has failed due to type mismatch. + // Further, a proper resolution can still be invalid due to incorrect parameter + // directionality. + // We'll detect both these incorrect cases here and issue an appropriate diagnostic. + // + auto funcType = as(calleeDeclRef->type); + if (!funcType) { - visitor->getSink()->diagnose(attr, Diagnostics::cannotResolveDerivativeFunction); + // The best candidate does not have a function type. + // If we reach here, it means the function is a generic and we can't deduce the + // generic arguments from imaginary argument list. + // In this case we issue a diagnostic to ask the user to explicitly provide the + // arguments. + visitor->getSink()->diagnose( + attr, + Diagnostics::cannotResolveGenericArgumentForDerivativeFunction); return; } - } - else if (auto overloadedExpr = as(checkedFuncExpr)) - { - for (auto candidate : overloadedExpr->lookupResult2.items) + if (isInterfaceRequirement(calleeDeclRef->declRef.getDecl())) + { + visitor->getSink()->diagnose( + attr, + Diagnostics::cannotUseInterfaceRequirementAsDerivative); + return; + } + if (funcType->getParamCount() != argList.getCount()) { - visitor->ensureDecl(candidate.declRef, DeclCheckState::TypesFullyResolved); + goto error; } - } - else - { - visitor->getSink()->diagnose(attr, Diagnostics::cannotResolveDerivativeFunction); - return; - } - - // If left value is true, then convert the - // inner type to an InOutType. - // - auto qualTypeToString = [&](QualType qualType) -> String - { - Type* type = qualType.type; - if (qualType.isLeftValue) + for (Index ii = 0; ii < argList.getCount(); ++ii) { - type = ctx.getASTBuilder()->getInOutType(type); + // Check if the resolved invoke argument type is an error type. + // If so, then we have a type mismatch. + // + if (resolvedInvoke->arguments[ii]->type.type->equals( + ctx.getASTBuilder()->getErrorType()) || + funcType->getParamDirection(ii) != paramDirections[ii]) + { + visitor->getSink()->diagnose( + attr, + Diagnostics::customDerivativeSignatureMismatchAtPosition, + ii, + qualTypeToString(argList[ii]->type), + funcType->getParamType(ii)->toString()); + } } - return type->toString(); - }; - - List argList = imaginaryArguments; - List paramDirections = expectedParamDirections; - bool expectStaticFunc = false; + // The `imaginaryArguments` list does not include the `this` parameter. + // So we need to check that `this` type matches. + bool funcIsStatic = isEffectivelyStatic(funcDecl); + if (funcIsStatic) + expectStaticFunc = true; - if (expectedThisArg) - { - argList.insert(0, expectedThisArg); - paramDirections.insert(0, expectedThisArgDirection); - expectStaticFunc = true; - } + bool derivativeFuncIsStatic = isEffectivelyStatic(calleeDeclRef->declRef.getDecl()); - auto invokeExpr = subVisitor.constructUncheckedInvokeExpr(checkedFuncExpr, argList); - auto resolved = subVisitor.ResolveInvoke(invokeExpr); + if (expectStaticFunc && !derivativeFuncIsStatic) + { + visitor->getSink()->diagnose(attr, Diagnostics::customDerivativeExpectedStatic); + return; + } - if (auto resolvedInvoke = as(resolved)) - { - if (auto calleeDeclRef = as(resolvedInvoke->functionExpr)) + if (!derivativeFuncIsStatic) { - // There are two ways to make it to this point.. a proper resolution, and a - // resolution that has failed due to type mismatch. - // Further, a proper resolution can still be invalid due to incorrect parameter - // directionality. - // We'll detect both these incorrect cases here and issue an appropriate diagnostic. - // - auto funcType = as(calleeDeclRef->type); - if (!funcType) - { - // The best candidate does not have a function type. - // If we reach here, it means the function is a generic and we can't deduce the - // generic arguments from imaginary argument list. - // In this case we issue a diagnostic to ask the user to explicitly provide the arguments. - visitor->getSink()->diagnose(attr, Diagnostics::cannotResolveGenericArgumentForDerivativeFunction); - return; - } - if (isInterfaceRequirement(calleeDeclRef->declRef.getDecl())) - { - visitor->getSink()->diagnose(attr, Diagnostics::cannotUseInterfaceRequirementAsDerivative); - return; - } - if (funcType->getParamCount() != argList.getCount()) - { - goto error; - } - for (Index ii = 0; ii < argList.getCount(); ++ii) - { - // Check if the resolved invoke argument type is an error type. - // If so, then we have a type mismatch. - // - if (resolvedInvoke->arguments[ii]->type.type->equals(ctx.getASTBuilder()->getErrorType()) || - funcType->getParamDirection(ii) != paramDirections[ii]) - { - visitor->getSink()->diagnose( - attr, - Diagnostics::customDerivativeSignatureMismatchAtPosition, - ii, - qualTypeToString(argList[ii]->type), - funcType->getParamType(ii)->toString()); - } - } - // The `imaginaryArguments` list does not include the `this` parameter. - // So we need to check that `this` type matches. - bool funcIsStatic = isEffectivelyStatic(funcDecl); - if (funcIsStatic) - expectStaticFunc = true; + auto defaultFuncDeclRef = createDefaultSubstitutionsIfNeeded( + visitor->getASTBuilder(), + visitor, + makeDeclRef(funcDecl)); - bool derivativeFuncIsStatic = isEffectivelyStatic(calleeDeclRef->declRef.getDecl()); + DeclRef funcDeclRef = defaultFuncDeclRef.as(); + auto funcThisType = getTypeForThisExpr(visitor, funcDeclRef); + DeclRef calleeFuncDeclRef = + calleeDeclRef->declRef.template as(); + auto derivativeFuncThisType = getTypeForThisExpr(visitor, calleeFuncDeclRef); - if (expectStaticFunc && !derivativeFuncIsStatic) + // If the function is a member function, we need to check that the + // `this` type matches the expected type. This will ensure that after lowering to + // IR, the two functions are compatible. + // + if (!areTypesCompatibile(visitor, funcThisType, derivativeFuncThisType)) { visitor->getSink()->diagnose( attr, - Diagnostics::customDerivativeExpectedStatic); + Diagnostics::customDerivativeSignatureThisParamMismatch); return; } + } - if (!derivativeFuncIsStatic) - { - auto defaultFuncDeclRef = createDefaultSubstitutionsIfNeeded( - visitor->getASTBuilder(), - visitor, - makeDeclRef(funcDecl)); - - DeclRef funcDeclRef = defaultFuncDeclRef.as(); - auto funcThisType = getTypeForThisExpr(visitor, funcDeclRef); - DeclRef calleeFuncDeclRef = calleeDeclRef->declRef.template as(); - auto derivativeFuncThisType = getTypeForThisExpr(visitor, calleeFuncDeclRef); - - // If the function is a member function, we need to check that the - // `this` type matches the expected type. This will ensure that after lowering to IR, - // the two functions are compatible. - // - if (!areTypesCompatibile(visitor, funcThisType, derivativeFuncThisType)) - { - visitor->getSink()->diagnose( - attr, - Diagnostics::customDerivativeSignatureThisParamMismatch); - return; - } - } + // If the two decls are under different generic contexts, we'll need to check that + // they agree and specialize the attribute's decl-ref accordingly. + // + + auto originalNextGeneric = + visitor->findNextOuterGeneric(visitor->getOuterGenericOrSelf(funcDecl)); + auto derivativeNextGeneric = visitor->findNextOuterGeneric( + visitor->getOuterGenericOrSelf(calleeDeclRef->declRef.getDecl())); + + if ((!originalNextGeneric) != (!derivativeNextGeneric)) + { + // Diagnostic for when one is generic and the other is not. + visitor->getSink()->diagnose( + attr, + Diagnostics::cannotResolveGenericArgumentForDerivativeFunction); + return; + } - // If the two decls are under different generic contexts, we'll need to check that - // they agree and specialize the attribute's decl-ref accordingly. + if (originalNextGeneric != derivativeNextGeneric) + { + // If the two generic containers are not the same, but are compatible, we can + // unify them. // - - auto originalNextGeneric = visitor->findNextOuterGeneric(visitor->getOuterGenericOrSelf(funcDecl)); - auto derivativeNextGeneric = visitor->findNextOuterGeneric(visitor->getOuterGenericOrSelf(calleeDeclRef->declRef.getDecl())); - if ((!originalNextGeneric) != (!derivativeNextGeneric)) + DeclRef specializedDecl; + if (!visitor->doGenericSignaturesMatch( + originalNextGeneric, + derivativeNextGeneric, + &specializedDecl)) { - // Diagnostic for when one is generic and the other is not. - visitor->getSink()->diagnose(attr, Diagnostics::cannotResolveGenericArgumentForDerivativeFunction); + visitor->getSink()->diagnose( + attr, + Diagnostics::customDerivativeSignatureMismatch); return; } - if (originalNextGeneric != derivativeNextGeneric) - { - // If the two generic containers are not the same, but are compatible, we can - // unify them. - // - - DeclRef specializedDecl; - if (!visitor->doGenericSignaturesMatch(originalNextGeneric, derivativeNextGeneric, &specializedDecl)) - { - visitor->getSink()->diagnose(attr, Diagnostics::customDerivativeSignatureMismatch); - return; - } - - calleeDeclRef->declRef = substituteDeclRef( - SubstitutionSet(specializedDecl), - visitor->getASTBuilder(), - calleeDeclRef->declRef); - calleeDeclRef->type = substituteType( - SubstitutionSet(specializedDecl), - visitor->getASTBuilder(), - calleeDeclRef->type); - } - - attr->funcExpr = calleeDeclRef; - if (attr->args.getCount()) - attr->args[0] = attr->funcExpr; - return; + calleeDeclRef->declRef = substituteDeclRef( + SubstitutionSet(specializedDecl), + visitor->getASTBuilder(), + calleeDeclRef->declRef); + calleeDeclRef->type = substituteType( + SubstitutionSet(specializedDecl), + visitor->getASTBuilder(), + calleeDeclRef->type); } - } - error:; - // Build the expected signature from imaginary args to diagnose - // when no matching function is found (this excludes the case handled above) - // - StringBuilder builder; - builder << "("; - for (Index ii = 0; ii < argList.getCount(); ++ii) - { - if (ii != 0) - builder << ", "; - if (argList[ii]->type) - builder << qualTypeToString(argList[ii]->type); - else - builder << ""; - } - builder << ")"; - - visitor->getSink()->diagnose(attr, Diagnostics::customDerivativeSignatureMismatch, builder.produceString()); - } - - template - const char* getDerivativeAttrName() { SLANG_UNREACHABLE(""); } - template<> - const char* getDerivativeAttrName() - { - return "ForwardDerivative"; - } - template<> - const char* getDerivativeAttrName() - { - return "BackwardDerivative"; + attr->funcExpr = calleeDeclRef; + if (attr->args.getCount()) + attr->args[0] = attr->funcExpr; + return; + } } - template<> - const char* getDerivativeAttrName() +error:; + // Build the expected signature from imaginary args to diagnose + // when no matching function is found (this excludes the case handled above) + // + StringBuilder builder; + builder << "("; + for (Index ii = 0; ii < argList.getCount(); ++ii) { - return "PrimalSubstitute"; + if (ii != 0) + builder << ", "; + if (argList[ii]->type) + builder << qualTypeToString(argList[ii]->type); + else + builder << ""; } + builder << ")"; - ArgsWithDirectionInfo getImaginaryArgsToFunc(ASTBuilder* astBuilder, FunctionDeclBase* func, SourceLoc loc) - { - List imaginaryArguments; - List directions; - for (auto param : func->getParameters()) - { - auto arg = astBuilder->create(); - arg->declRef = makeDeclRef(param); - arg->type.isLeftValue = param->findModifier() ? true : false; - arg->type.type = param->getType(); - arg->loc = loc; - imaginaryArguments.add(arg); - directions.add(getParameterDirection(param)); - } - return { imaginaryArguments, directions, nullptr, ParameterDirection::kParameterDirection_In }; - } + visitor->getSink()->diagnose( + attr, + Diagnostics::customDerivativeSignatureMismatch, + builder.produceString()); +} - ArgsWithDirectionInfo getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) - { - Expr* thisArgExpr = nullptr; - if (auto thisType = getTypeForThisExpr(visitor, originalFuncDecl)) - { - thisArgExpr = visitor->getASTBuilder()->create(); - thisArgExpr->type = thisType; - thisArgExpr->loc = loc; +template +const char* getDerivativeAttrName() +{ + SLANG_UNREACHABLE(""); +} - if (visitor->isTypeDifferentiable(thisType) && - !originalFuncDecl->findModifier() && - !isEffectivelyStatic(originalFuncDecl)) - { - auto pairType = visitor->getDifferentialPairType(thisType); - thisArgExpr->type.type = pairType; - } - else - { - thisArgExpr = nullptr; - } - } +template<> +const char* getDerivativeAttrName() +{ + return "ForwardDerivative"; +} +template<> +const char* getDerivativeAttrName() +{ + return "BackwardDerivative"; +} +template<> +const char* getDerivativeAttrName() +{ + return "PrimalSubstitute"; +} + +ArgsWithDirectionInfo getImaginaryArgsToFunc( + ASTBuilder* astBuilder, + FunctionDeclBase* func, + SourceLoc loc) +{ + List imaginaryArguments; + List directions; + for (auto param : func->getParameters()) + { + auto arg = astBuilder->create(); + arg->declRef = makeDeclRef(param); + arg->type.isLeftValue = param->findModifier() ? true : false; + arg->type.type = param->getType(); + arg->loc = loc; + imaginaryArguments.add(arg); + directions.add(getParameterDirection(param)); + } + return {imaginaryArguments, directions, nullptr, ParameterDirection::kParameterDirection_In}; +} - ParameterDirection thisTypeDirection = - (thisArgExpr && !thisArgExpr->type.isLeftValue) ? - ParameterDirection::kParameterDirection_In : - ParameterDirection::kParameterDirection_InOut; +ArgsWithDirectionInfo getImaginaryArgsToForwardDerivative( + SemanticsVisitor* visitor, + FunctionDeclBase* originalFuncDecl, + SourceLoc loc) +{ + Expr* thisArgExpr = nullptr; + if (auto thisType = getTypeForThisExpr(visitor, originalFuncDecl)) + { + thisArgExpr = visitor->getASTBuilder()->create(); + thisArgExpr->type = thisType; + thisArgExpr->loc = loc; - List imaginaryArguments; - for (auto param : originalFuncDecl->getParameters()) + if (visitor->isTypeDifferentiable(thisType) && + !originalFuncDecl->findModifier() && + !isEffectivelyStatic(originalFuncDecl)) { - auto arg = visitor->getASTBuilder()->create(); - arg->declRef = makeDeclRef(param); - arg->type.isLeftValue = param->findModifier() ? true : false; - arg->type.type = param->getType(); - arg->loc = loc; - if (!param->findModifier()) - { - if (auto pairType = visitor->getDifferentialPairType(param->getType())) - { - arg->type.type = pairType; - } - } - imaginaryArguments.add(arg); + auto pairType = visitor->getDifferentialPairType(thisType); + thisArgExpr->type.type = pairType; } - - // Copy parameter directions as is. - List expectedParamDirections; - for (auto param : originalFuncDecl->getParameters()) + else { - expectedParamDirections.add(getParameterDirection(param)); + thisArgExpr = nullptr; } - - return { imaginaryArguments, expectedParamDirections, thisArgExpr, thisTypeDirection }; } - ArgsWithDirectionInfo getImaginaryArgsToBackwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) + ParameterDirection thisTypeDirection = (thisArgExpr && !thisArgExpr->type.isLeftValue) + ? ParameterDirection::kParameterDirection_In + : ParameterDirection::kParameterDirection_InOut; + + List imaginaryArguments; + for (auto param : originalFuncDecl->getParameters()) { - Expr* thisArgExpr = nullptr; - if (auto thisType = getTypeForThisExpr(visitor, originalFuncDecl)) + auto arg = visitor->getASTBuilder()->create(); + arg->declRef = makeDeclRef(param); + arg->type.isLeftValue = param->findModifier() ? true : false; + arg->type.type = param->getType(); + arg->loc = loc; + if (!param->findModifier()) { - thisArgExpr = visitor->getASTBuilder()->create(); - thisArgExpr->type = thisType; - thisArgExpr->loc = loc; - - if (visitor->isTypeDifferentiable(thisType) && - !originalFuncDecl->findModifier() && - !isEffectivelyStatic(originalFuncDecl)) - { - auto pairType = visitor->getDifferentialPairType(thisType); - thisArgExpr->type.type = pairType; - - // TODO: for ptr pair types, no need to set isLeftValue to true. - if (as(thisArgExpr->type.type)) - thisArgExpr->type.isLeftValue = true; - } - else + if (auto pairType = visitor->getDifferentialPairType(param->getType())) { - thisArgExpr = nullptr; + arg->type.type = pairType; } } + imaginaryArguments.add(arg); + } - ParameterDirection thisTypeDirection = - (thisArgExpr && !thisArgExpr->type.isLeftValue) ? - ParameterDirection::kParameterDirection_In : - ParameterDirection::kParameterDirection_InOut; + // Copy parameter directions as is. + List expectedParamDirections; + for (auto param : originalFuncDecl->getParameters()) + { + expectedParamDirections.add(getParameterDirection(param)); + } - List imaginaryArguments; - List expectedParamDirections; + return {imaginaryArguments, expectedParamDirections, thisArgExpr, thisTypeDirection}; +} - auto isOutParam = [&](ParamDecl* param) - { - return param->findModifier() != nullptr - && param->findModifier() == nullptr && param->findModifier() == nullptr; - }; +ArgsWithDirectionInfo getImaginaryArgsToBackwardDerivative( + SemanticsVisitor* visitor, + FunctionDeclBase* originalFuncDecl, + SourceLoc loc) +{ + Expr* thisArgExpr = nullptr; + if (auto thisType = getTypeForThisExpr(visitor, originalFuncDecl)) + { + thisArgExpr = visitor->getASTBuilder()->create(); + thisArgExpr->type = thisType; + thisArgExpr->loc = loc; - for (auto param : originalFuncDecl->getParameters()) + if (visitor->isTypeDifferentiable(thisType) && + !originalFuncDecl->findModifier() && + !isEffectivelyStatic(originalFuncDecl)) { - auto arg = visitor->getASTBuilder()->create(); - arg->declRef = makeDeclRef(param); - arg->type.isLeftValue = param->findModifier() ? true : false; - arg->type.type = param->getType(); - arg->loc = loc; - - ParameterDirection direction = getParameterDirection(param); - - bool isDiffParam = (!param->findModifier()); - if (isDiffParam) - { - auto diffPair = visitor->getDifferentialPairType(param->getType()); - if (auto pairType = as(diffPair)) - { - arg->type.type = pairType; - arg->type.isLeftValue = true; - - if (isOutParam(param)) - { - // out T : IDifferentiable -> in T.Differential - arg->type.isLeftValue = false; - arg->type.type = visitor->tryGetDifferentialType( - visitor->getASTBuilder(), pairType->getPrimalType()); - - direction = ParameterDirection::kParameterDirection_In; - } - else - { - // in T : IDifferentiable -> inout DifferentialPair - // inout T : IDifferentiable -> inout DifferentialPair - direction = ParameterDirection::kParameterDirection_InOut; - } - } - else if (auto refPairType = as(diffPair)) - { - // no need to change direction of ref-pairs. - arg->type.type = refPairType; - } - else - { - isDiffParam = false; - } - } - if (!isDiffParam) - { - if (isOutParam(param)) - { - // Skip non-differentiable out params. - continue; - } - - // no_diff inout T -> in T - // no_diff in T -> in T - // - direction = ParameterDirection::kParameterDirection_In; - } + auto pairType = visitor->getDifferentialPairType(thisType); + thisArgExpr->type.type = pairType; - imaginaryArguments.add(arg); - expectedParamDirections.add(direction); + // TODO: for ptr pair types, no need to set isLeftValue to true. + if (as(thisArgExpr->type.type)) + thisArgExpr->type.isLeftValue = true; } - if (auto diffReturnType = visitor->tryGetDifferentialType(visitor->getASTBuilder(), originalFuncDecl->returnType.type)) + else { - auto arg = visitor->getASTBuilder()->create(); - arg->type.isLeftValue = false; - arg->type.type = diffReturnType; - arg->loc = loc; - imaginaryArguments.add(arg); - expectedParamDirections.add(ParameterDirection::kParameterDirection_In); + thisArgExpr = nullptr; } - - return {imaginaryArguments, expectedParamDirections, thisArgExpr, thisTypeDirection}; } - // This helper function is needed to workaround a gcc bug. - // Remove when we upgrade to a newer version of gcc. - template - static T* _findModifier(Decl* decl) + ParameterDirection thisTypeDirection = (thisArgExpr && !thisArgExpr->type.isLeftValue) + ? ParameterDirection::kParameterDirection_In + : ParameterDirection::kParameterDirection_InOut; + + List imaginaryArguments; + List expectedParamDirections; + + auto isOutParam = [&](ParamDecl* param) { - return decl->findModifier(); - } + return param->findModifier() != nullptr && + param->findModifier() == nullptr && + param->findModifier() == nullptr; + }; - template - void checkDerivativeOfAttributeImpl( - SemanticsVisitor* visitor, - FunctionDeclBase* funcDecl, - TDerivativeOfAttr* derivativeOfAttr, - DeclAssociationKind assocKind) + for (auto param : originalFuncDecl->getParameters()) { - auto astBuilder = visitor->getASTBuilder(); - DeclRef calleeDeclRef; - DeclRefExpr* calleeDeclRefExpr = nullptr; - HigherOrderInvokeExpr* higherOrderFuncExpr = astBuilder->create(); - higherOrderFuncExpr->baseFunction = derivativeOfAttr->funcExpr; - if (derivativeOfAttr->args.getCount() > 0) - higherOrderFuncExpr->loc = derivativeOfAttr->args[0]->loc; + auto arg = visitor->getASTBuilder()->create(); + arg->declRef = makeDeclRef(param); + arg->type.isLeftValue = param->findModifier() ? true : false; + arg->type.type = param->getType(); + arg->loc = loc; - Expr* checkedHigherOrderFuncExpr = visitor->dispatchExpr( - higherOrderFuncExpr, - visitor->allowStaticReferenceToNonStaticMember()); + ParameterDirection direction = getParameterDirection(param); - if (!checkedHigherOrderFuncExpr) - { - visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); - return; - } - List imaginaryArgs = getImaginaryArgsToFunc(astBuilder, funcDecl, derivativeOfAttr->loc).args; - auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedHigherOrderFuncExpr, imaginaryArgs); - SemanticsContext::ExprLocalScope scope; - auto ctx = visitor->withExprLocalScope(&scope); - auto subVisitor = SemanticsVisitor(ctx); - auto resolved = subVisitor.ResolveInvoke(invokeExpr); - if (auto resolvedInvoke = as(resolved)) + bool isDiffParam = (!param->findModifier()); + if (isDiffParam) { - auto resolvedFuncExpr = as(resolvedInvoke->functionExpr); - if (resolvedFuncExpr) + auto diffPair = visitor->getDifferentialPairType(param->getType()); + if (auto pairType = as(diffPair)) { - calleeDeclRefExpr = as(resolvedFuncExpr->baseFunction); - if (!calleeDeclRef && as(resolvedFuncExpr->baseFunction)) + arg->type.type = pairType; + arg->type.isLeftValue = true; + + if (isOutParam(param)) { - visitor->getSink()->diagnose( - derivativeOfAttr, - Diagnostics::overloadedFuncUsedWithDerivativeOfAttributes); + // out T : IDifferentiable -> in T.Differential + arg->type.isLeftValue = false; + arg->type.type = visitor->tryGetDifferentialType( + visitor->getASTBuilder(), + pairType->getPrimalType()); + + direction = ParameterDirection::kParameterDirection_In; + } + else + { + // in T : IDifferentiable -> inout DifferentialPair + // inout T : IDifferentiable -> inout DifferentialPair + direction = ParameterDirection::kParameterDirection_InOut; } } + else if (auto refPairType = as(diffPair)) + { + // no need to change direction of ref-pairs. + arg->type.type = refPairType; + } + else + { + isDiffParam = false; + } } - - if (!calleeDeclRefExpr) - { - visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); - return; - } - - calleeDeclRefExpr->loc = higherOrderFuncExpr->loc; - if (derivativeOfAttr->args.getCount() > 0) - derivativeOfAttr->args[0] = calleeDeclRefExpr; - - calleeDeclRef = calleeDeclRefExpr->declRef; - - auto calleeFunc = as(calleeDeclRef.getDecl()); - if (!calleeFunc) - { - visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); - return; - } - - // For now, if calleeFunc or funcDecl is nested inside some generic aggregate, - // they must be the same generic decl. For example, using B.f() as the original function - // for C.derivative() is not allowed. - // We may relax this restriction in the future by solving the "inverse" generic arguments - // from the `calleeDeclRef`, and use them to create a declRef to funcDecl from the original - // func. - - if (isInterfaceRequirement(calleeFunc)) - { - visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotAssociateInterfaceRequirementWithDerivative); - return; - } - if (isInterfaceRequirement(funcDecl)) + if (!isDiffParam) { - visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotUseInterfaceRequirementAsDerivative); - return; - } + if (isOutParam(param)) + { + // Skip non-differentiable out params. + continue; + } - if (auto existingModifier = _findModifier(calleeFunc)) - { - // The primal function already has a `[*Derivative]` attribute, this is invalid. - visitor->getSink()->diagnose( - derivativeOfAttr, - Diagnostics::declAlreadyHasAttribute, - calleeDeclRef, - getDerivativeAttrName()); - visitor->getSink()->diagnose(existingModifier->loc, Diagnostics::seeDeclarationOf, calleeDeclRef.getDecl()); + // no_diff inout T -> in T + // no_diff in T -> in T + // + direction = ParameterDirection::kParameterDirection_In; } - derivativeOfAttr->funcExpr = calleeDeclRefExpr; - auto derivativeAttr = astBuilder->create(); - derivativeAttr->loc = derivativeOfAttr->loc; - auto outterGeneric = visitor->GetOuterGeneric(funcDecl); - auto declRef = makeDeclRef((outterGeneric ? (Decl*)outterGeneric : funcDecl)); - - // If both the derivative and the original function are defined in the same outer generic - // aggregate type, we want to form a full declref with default arguments. - declRef = createDefaultSubstitutionsIfNeeded(astBuilder, visitor, declRef); - - auto declRefExpr = visitor->ConstructDeclRefExpr(declRef, nullptr, declRef.getName(), derivativeOfAttr->loc, nullptr); - declRefExpr->type.type = nullptr; - derivativeAttr->args.add(declRefExpr); - derivativeAttr->funcExpr = declRefExpr; - checkDerivativeAttribute(visitor, calleeFunc, derivativeAttr); - derivativeOfAttr->backDeclRef = derivativeAttr->funcExpr; - derivativeAttr->funcExpr = nullptr; - visitor->getShared()->registerAssociatedDecl(calleeDeclRef.getDecl(), assocKind, funcDecl); + imaginaryArguments.add(arg); + expectedParamDirections.add(direction); } - - static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr) + if (auto diffReturnType = visitor->tryGetDifferentialType( + visitor->getASTBuilder(), + originalFuncDecl->returnType.type)) { - if (!attr->funcExpr) - return; - if (attr->funcExpr->type.type) - return; - - ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToForwardDerivative(visitor, funcDecl, attr->loc); - checkDerivativeAttributeImpl( - visitor, - funcDecl, - attr, - imaginaryArguments.args, - imaginaryArguments.directions, - imaginaryArguments.thisArg, - imaginaryArguments.thisArgDirection); + auto arg = visitor->getASTBuilder()->create(); + arg->type.isLeftValue = false; + arg->type.type = diffReturnType; + arg->loc = loc; + imaginaryArguments.add(arg); + expectedParamDirections.add(ParameterDirection::kParameterDirection_In); } - static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, BackwardDerivativeAttribute* attr) - { - if (!attr->funcExpr) - return; - if (attr->funcExpr->type.type) - return; + return {imaginaryArguments, expectedParamDirections, thisArgExpr, thisTypeDirection}; +} - ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToBackwardDerivative(visitor, funcDecl, attr->loc); - checkDerivativeAttributeImpl( - visitor, - funcDecl, - attr, - imaginaryArguments.args, - imaginaryArguments.directions, - imaginaryArguments.thisArg, - imaginaryArguments.thisArgDirection); - } +// This helper function is needed to workaround a gcc bug. +// Remove when we upgrade to a newer version of gcc. +template +static T* _findModifier(Decl* decl) +{ + return decl->findModifier(); +} - static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, PrimalSubstituteAttribute* attr) - { - if (!attr->funcExpr) - return; - if (attr->funcExpr->type.type) - return; +template +void checkDerivativeOfAttributeImpl( + SemanticsVisitor* visitor, + FunctionDeclBase* funcDecl, + TDerivativeOfAttr* derivativeOfAttr, + DeclAssociationKind assocKind) +{ + auto astBuilder = visitor->getASTBuilder(); + DeclRef calleeDeclRef; + DeclRefExpr* calleeDeclRefExpr = nullptr; + HigherOrderInvokeExpr* higherOrderFuncExpr = astBuilder->create(); + higherOrderFuncExpr->baseFunction = derivativeOfAttr->funcExpr; + if (derivativeOfAttr->args.getCount() > 0) + higherOrderFuncExpr->loc = derivativeOfAttr->args[0]->loc; - ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, attr->loc); - checkDerivativeAttributeImpl( - visitor, - funcDecl, - attr, - imaginaryArguments.args, - imaginaryArguments.directions, - imaginaryArguments.thisArg, - imaginaryArguments.thisArgDirection); - } + Expr* checkedHigherOrderFuncExpr = visitor->dispatchExpr( + higherOrderFuncExpr, + visitor->allowStaticReferenceToNonStaticMember()); - static void checkCudaKernelAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, CudaKernelAttribute*) + if (!checkedHigherOrderFuncExpr) { - // If the method is also marked differentiable, check that the data types are either non-differentiable - // or marked with no_diff. - // - // Note: This is a temporary restriction until we have a more complete story for differentiability. - // - if (funcDecl->findModifier()) + visitor->getSink()->diagnose( + derivativeOfAttr, + Diagnostics::cannotResolveOriginalFunctionForDerivative); + return; + } + List imaginaryArgs = + getImaginaryArgsToFunc(astBuilder, funcDecl, derivativeOfAttr->loc).args; + auto invokeExpr = + visitor->constructUncheckedInvokeExpr(checkedHigherOrderFuncExpr, imaginaryArgs); + SemanticsContext::ExprLocalScope scope; + auto ctx = visitor->withExprLocalScope(&scope); + auto subVisitor = SemanticsVisitor(ctx); + auto resolved = subVisitor.ResolveInvoke(invokeExpr); + if (auto resolvedInvoke = as(resolved)) + { + auto resolvedFuncExpr = as(resolvedInvoke->functionExpr); + if (resolvedFuncExpr) { - for (auto paramDecl : funcDecl->getParameters()) + calleeDeclRefExpr = as(resolvedFuncExpr->baseFunction); + if (!calleeDeclRef && as(resolvedFuncExpr->baseFunction)) { - auto paramType = paramDecl->type; - - if (visitor->isTypeDifferentiable(paramType)) - { - if (!paramDecl->hasModifier()) - { - visitor->getSink()->diagnose(paramDecl, Diagnostics::differentiableKernelEntryPointCannotHaveDifferentiableParams); - } - } + visitor->getSink()->diagnose( + derivativeOfAttr, + Diagnostics::overloadedFuncUsedWithDerivativeOfAttributes); } } } - template - bool tryCheckDerivativeOfAttributeImpl( - SemanticsVisitor* visitor, - FunctionDeclBase* funcDecl, - TDerivativeOfAttr* derivativeOfAttr, - DeclAssociationKind assocKind, - const List& imaginaryArgsToOriginal) - { - DiagnosticSink tempSink(visitor->getSourceManager(), nullptr); - SemanticsVisitor subVisitor(visitor->withSink(&tempSink)); - checkDerivativeOfAttributeImpl( - &subVisitor, - funcDecl, + if (!calleeDeclRefExpr) + { + visitor->getSink()->diagnose( derivativeOfAttr, - assocKind, - imaginaryArgsToOriginal); - return tempSink.getErrorCount() == 0; + Diagnostics::cannotResolveOriginalFunctionForDerivative); + return; } - void SemanticsDeclAttributesVisitor::checkForwardDerivativeOfAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeOfAttribute* attr) + calleeDeclRefExpr->loc = higherOrderFuncExpr->loc; + if (derivativeOfAttr->args.getCount() > 0) + derivativeOfAttr->args[0] = calleeDeclRefExpr; + + calleeDeclRef = calleeDeclRefExpr->declRef; + + auto calleeFunc = as(calleeDeclRef.getDecl()); + if (!calleeFunc) { - checkDerivativeOfAttributeImpl( - this, funcDecl, attr, DeclAssociationKind::ForwardDerivativeFunc); + visitor->getSink()->diagnose( + derivativeOfAttr, + Diagnostics::cannotResolveOriginalFunctionForDerivative); + return; } - void SemanticsDeclAttributesVisitor::checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl, BackwardDerivativeOfAttribute* attr) + // For now, if calleeFunc or funcDecl is nested inside some generic aggregate, + // they must be the same generic decl. For example, using B.f() as the original function + // for C.derivative() is not allowed. + // We may relax this restriction in the future by solving the "inverse" generic arguments + // from the `calleeDeclRef`, and use them to create a declRef to funcDecl from the original + // func. + + if (isInterfaceRequirement(calleeFunc)) { - checkDerivativeOfAttributeImpl( - this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc); + visitor->getSink()->diagnose( + derivativeOfAttr, + Diagnostics::cannotAssociateInterfaceRequirementWithDerivative); + return; } - - void SemanticsDeclAttributesVisitor::checkPrimalSubstituteOfAttribute(FunctionDeclBase* funcDecl, PrimalSubstituteOfAttribute* attr) + if (isInterfaceRequirement(funcDecl)) { - checkDerivativeOfAttributeImpl( - this, funcDecl, attr, DeclAssociationKind::PrimalSubstituteFunc); + visitor->getSink()->diagnose( + derivativeOfAttr, + Diagnostics::cannotUseInterfaceRequirementAsDerivative); + return; } - void SemanticsDeclAttributesVisitor::visitStructDecl(StructDecl* structDecl) + if (auto existingModifier = _findModifier(calleeFunc)) { - // add a empty deault CTor if missing; checking in attributes - // to avoid circular checking logic - auto defaultCtor = _getDefaultCtor(structDecl); - if (!defaultCtor) - _createCtor(this, m_astBuilder, structDecl); - - int backingWidth = 0; - [[maybe_unused]] - int totalWidth = 0; - struct BitFieldInfo - { - int memberIndex; - int bitWidth; - Type* memberType; - BitFieldModifier* bitFieldModifier; - }; - List groupInfo; - - int memberIndex = 0; - int backing_nonce = 0; - const auto dispatchSomeBitPackedMembers = [&](){ - SLANG_ASSERT(totalWidth <= backingWidth); - SLANG_ASSERT(backingWidth <= 64); - - // We're going to insert a backing member to be referenced in - // all the bitfield properties - if(groupInfo.getCount()) - { - const auto backingMemberBasicType - = backingWidth <= 8 ? BaseType::UInt8 - : backingWidth <= 16 ? BaseType::UInt16 - : backingWidth <= 32 ? BaseType::UInt - : BaseType::UInt64; - auto backingMember = m_astBuilder->create(); - backingMember->type.type = m_astBuilder->getBuiltinType(backingMemberBasicType); - backingMember->nameAndLoc.name = getName(String("$bit_field_backing_") + String(backing_nonce)); - backing_nonce++; - backingMember->initExpr = nullptr; - backingMember->parentDecl = structDecl; - const auto backingMemberDeclRef = DeclRef(backingMember->getDefaultDeclRef()); - - int bottomOfMember = 0; - for(const auto m : groupInfo) - { - SLANG_ASSERT(bottomOfMember <= backingWidth); + // The primal function already has a `[*Derivative]` attribute, this is invalid. + visitor->getSink()->diagnose( + derivativeOfAttr, + Diagnostics::declAlreadyHasAttribute, + calleeDeclRef, + getDerivativeAttrName()); + visitor->getSink()->diagnose( + existingModifier->loc, + Diagnostics::seeDeclarationOf, + calleeDeclRef.getDecl()); + } + + derivativeOfAttr->funcExpr = calleeDeclRefExpr; + auto derivativeAttr = astBuilder->create(); + derivativeAttr->loc = derivativeOfAttr->loc; + auto outterGeneric = visitor->GetOuterGeneric(funcDecl); + auto declRef = makeDeclRef((outterGeneric ? (Decl*)outterGeneric : funcDecl)); + + // If both the derivative and the original function are defined in the same outer generic + // aggregate type, we want to form a full declref with default arguments. + declRef = createDefaultSubstitutionsIfNeeded(astBuilder, visitor, declRef); + + auto declRefExpr = visitor->ConstructDeclRefExpr( + declRef, + nullptr, + declRef.getName(), + derivativeOfAttr->loc, + nullptr); + declRefExpr->type.type = nullptr; + derivativeAttr->args.add(declRefExpr); + derivativeAttr->funcExpr = declRefExpr; + checkDerivativeAttribute(visitor, calleeFunc, derivativeAttr); + derivativeOfAttr->backDeclRef = derivativeAttr->funcExpr; + derivativeAttr->funcExpr = nullptr; + visitor->getShared()->registerAssociatedDecl(calleeDeclRef.getDecl(), assocKind, funcDecl); +} - m.bitFieldModifier->backingDeclRef = backingMemberDeclRef; - m.bitFieldModifier->offset = bottomOfMember; +static void checkDerivativeAttribute( + SemanticsVisitor* visitor, + FunctionDeclBase* funcDecl, + ForwardDerivativeAttribute* attr) +{ + if (!attr->funcExpr) + return; + if (attr->funcExpr->type.type) + return; + + ArgsWithDirectionInfo imaginaryArguments = + getImaginaryArgsToForwardDerivative(visitor, funcDecl, attr->loc); + checkDerivativeAttributeImpl( + visitor, + funcDecl, + attr, + imaginaryArguments.args, + imaginaryArguments.directions, + imaginaryArguments.thisArg, + imaginaryArguments.thisArgDirection); +} - bottomOfMember += m.bitWidth; - } +static void checkDerivativeAttribute( + SemanticsVisitor* visitor, + FunctionDeclBase* funcDecl, + BackwardDerivativeAttribute* attr) +{ + if (!attr->funcExpr) + return; + if (attr->funcExpr->type.type) + return; + + ArgsWithDirectionInfo imaginaryArguments = + getImaginaryArgsToBackwardDerivative(visitor, funcDecl, attr->loc); + checkDerivativeAttributeImpl( + visitor, + funcDecl, + attr, + imaginaryArguments.args, + imaginaryArguments.directions, + imaginaryArguments.thisArg, + imaginaryArguments.thisArgDirection); +} - const auto backingMemberIndex = groupInfo[0].memberIndex; - structDecl->members.insert(backingMemberIndex, backingMember); - structDecl->invalidateMemberDictionary(); - ++memberIndex; - } - structDecl->buildMemberDictionary(); +static void checkDerivativeAttribute( + SemanticsVisitor* visitor, + FunctionDeclBase* funcDecl, + PrimalSubstituteAttribute* attr) +{ + if (!attr->funcExpr) + return; + if (attr->funcExpr->type.type) + return; + + ArgsWithDirectionInfo imaginaryArguments = + getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, attr->loc); + checkDerivativeAttributeImpl( + visitor, + funcDecl, + attr, + imaginaryArguments.args, + imaginaryArguments.directions, + imaginaryArguments.thisArg, + imaginaryArguments.thisArgDirection); +} - // Reset everything - backingWidth = 0; - totalWidth = 0; - groupInfo.clear(); - }; - for(; memberIndex < structDecl->members.getCount(); ++memberIndex) +static void checkCudaKernelAttribute( + SemanticsVisitor* visitor, + FunctionDeclBase* funcDecl, + CudaKernelAttribute*) +{ + // If the method is also marked differentiable, check that the data types are either + // non-differentiable or marked with no_diff. + // + // Note: This is a temporary restriction until we have a more complete story for + // differentiability. + // + if (funcDecl->findModifier()) + { + for (auto paramDecl : funcDecl->getParameters()) { - const auto& m = structDecl->members[memberIndex]; + auto paramType = paramDecl->type; - // We can trivially skip any non-property decls - const auto v = as(m); - if(!v) + if (visitor->isTypeDifferentiable(paramType)) { - // If this is a non-bitfield member then finish the current group - if(as(m)) - dispatchSomeBitPackedMembers(); - continue; - } - - const auto bfm = m->findModifier(); - // If there isn't a bit field modifier, then dispatch the - // current group and continue - if(!bfm) - { - dispatchSomeBitPackedMembers(); - continue; - } - - // Verify that this makes sense as a bitfield - const auto t = v->type.type->getCanonicalType(); - SLANG_ASSERT(t); - const auto b = as(t); - if(!b) - { - getSink()->diagnose(v->loc, Diagnostics::bitFieldNonIntegral, t); - continue; - } - const auto baseType = b->getBaseType(); - const bool isIntegerType = isIntegerBaseType(baseType); - if(!isIntegerType) - { - getSink()->diagnose(v->loc, Diagnostics::bitFieldNonIntegral, t); - continue; - } - - // The bit width of this member, and the member type width - const auto thisFieldWidth = bfm->width; - const auto thisFieldTypeWidth = getTypeBitSize(b); - SLANG_ASSERT(thisFieldTypeWidth != 0); - if(thisFieldWidth > thisFieldTypeWidth) - { - getSink()->diagnose( - v->loc, - Diagnostics::bitFieldTooWide, - thisFieldWidth, - t, - thisFieldTypeWidth - ); - // Not much we can do with this field, just ignore it - continue; + if (!paramDecl->hasModifier()) + { + visitor->getSink()->diagnose( + paramDecl, + Diagnostics::differentiableKernelEntryPointCannotHaveDifferentiableParams); + } } + } + } +} - // At this point we're sure that we have a bit field, - // update our bit packing state - - // If there's a 0 width type, dispatch the current group - if(thisFieldWidth == 0) - dispatchSomeBitPackedMembers(); +template +bool tryCheckDerivativeOfAttributeImpl( + SemanticsVisitor* visitor, + FunctionDeclBase* funcDecl, + TDerivativeOfAttr* derivativeOfAttr, + DeclAssociationKind assocKind, + const List& imaginaryArgsToOriginal) +{ + DiagnosticSink tempSink(visitor->getSourceManager(), nullptr); + SemanticsVisitor subVisitor(visitor->withSink(&tempSink)); + checkDerivativeOfAttributeImpl( + &subVisitor, + funcDecl, + derivativeOfAttr, + assocKind, + imaginaryArgsToOriginal); + return tempSink.getErrorCount() == 0; +} - // If this member wouldn't fit into the current group, dispatch - // everything so far; - if(totalWidth + thisFieldWidth > std::max(thisFieldTypeWidth, backingWidth)) - dispatchSomeBitPackedMembers(); +void SemanticsDeclAttributesVisitor::checkForwardDerivativeOfAttribute( + FunctionDeclBase* funcDecl, + ForwardDerivativeOfAttribute* attr) +{ + checkDerivativeOfAttributeImpl( + this, + funcDecl, + attr, + DeclAssociationKind::ForwardDerivativeFunc); +} - // Add this member to the group, - // Grow the backing width if necessary - backingWidth = std::max(thisFieldTypeWidth, backingWidth); - // Grow the total width - totalWidth += int(thisFieldWidth); - groupInfo.add({memberIndex, int(thisFieldWidth), t, bfm}); - } - // If the struct ended with a bitpacked member, then make sure we don't forget the last group - dispatchSomeBitPackedMembers(); - } +void SemanticsDeclAttributesVisitor::checkBackwardDerivativeOfAttribute( + FunctionDeclBase* funcDecl, + BackwardDerivativeOfAttribute* attr) +{ + checkDerivativeOfAttributeImpl( + this, + funcDecl, + attr, + DeclAssociationKind::BackwardDerivativeFunc); +} - void SemanticsDeclAttributesVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) - { - // Run checking on attributes that can't be fully checked in header checking stage. - for (auto attr : decl->modifiers) - { - if (auto fwdDerivativeOfAttr = as(attr)) - checkForwardDerivativeOfAttribute(decl, fwdDerivativeOfAttr); - else if (auto bwdDerivativeOfAttr = as(attr)) - checkBackwardDerivativeOfAttribute(decl, bwdDerivativeOfAttr); - else if (auto primalOfAttr = as(attr)) - checkPrimalSubstituteOfAttribute(decl, primalOfAttr); - else if (auto fwdDerivativeAttr = as(attr)) - checkDerivativeAttribute(this, decl, fwdDerivativeAttr); - else if (auto bwdDerivativeAttr = as(attr)) - checkDerivativeAttribute(this, decl, bwdDerivativeAttr); - else if (auto primalAttr = as(attr)) - checkDerivativeAttribute(this, decl, primalAttr); - else if (auto cudaKernelAttr = as(attr)) - checkCudaKernelAttribute(this, decl, cudaKernelAttr); - } - } +void SemanticsDeclAttributesVisitor::checkPrimalSubstituteOfAttribute( + FunctionDeclBase* funcDecl, + PrimalSubstituteOfAttribute* attr) +{ + checkDerivativeOfAttributeImpl( + this, + funcDecl, + attr, + DeclAssociationKind::PrimalSubstituteFunc); +} - static void _propagateSeeDefinitionOf(SemanticsVisitor* visitor, Decl* funcDecl, DiagnosticCategory diagnosticCategory) - { - maybeDiagnose(visitor->getSink(), visitor->getOptionSet(), diagnosticCategory, funcDecl, Diagnostics::seeDefinitionOf, funcDecl); - } +void SemanticsDeclAttributesVisitor::visitStructDecl(StructDecl* structDecl) +{ + // add a empty deault CTor if missing; checking in attributes + // to avoid circular checking logic + auto defaultCtor = _getDefaultCtor(structDecl); + if (!defaultCtor) + _createCtor(this, m_astBuilder, structDecl); + + int backingWidth = 0; + [[maybe_unused]] int totalWidth = 0; + struct BitFieldInfo + { + int memberIndex; + int bitWidth; + Type* memberType; + BitFieldModifier* bitFieldModifier; + }; + List groupInfo; - static void _propagateRequirement(SemanticsVisitor* visitor, CapabilitySet& resultCaps, SyntaxNode* userNode, SyntaxNode* referencedNode, const CapabilitySet& nodeCaps, SourceLoc referenceLoc) + int memberIndex = 0; + int backing_nonce = 0; + const auto dispatchSomeBitPackedMembers = [&]() { - auto referencedDecl = as(referencedNode); + SLANG_ASSERT(totalWidth <= backingWidth); + SLANG_ASSERT(backingWidth <= 64); - // Ignore cyclic references. - if (referencedDecl) + // We're going to insert a backing member to be referenced in + // all the bitfield properties + if (groupInfo.getCount()) { - if (referencedDecl->checkState.isBeingChecked()) - return; - - ensureDecl(visitor, referencedDecl, DeclCheckState::CapabilityChecked); - } - - if (resultCaps.implies(nodeCaps)) - return; - - auto oldCaps = resultCaps; - bool isAnyInvalid = resultCaps.isInvalid() || nodeCaps.isInvalid(); - resultCaps.join(nodeCaps); - - auto decl = as(userNode); + const auto backingMemberBasicType = backingWidth <= 8 ? BaseType::UInt8 + : backingWidth <= 16 ? BaseType::UInt16 + : backingWidth <= 32 ? BaseType::UInt + : BaseType::UInt64; + auto backingMember = m_astBuilder->create(); + backingMember->type.type = m_astBuilder->getBuiltinType(backingMemberBasicType); + backingMember->nameAndLoc.name = + getName(String("$bit_field_backing_") + String(backing_nonce)); + backing_nonce++; + backingMember->initExpr = nullptr; + backingMember->parentDecl = structDecl; + const auto backingMemberDeclRef = DeclRef(backingMember->getDefaultDeclRef()); - if (!isAnyInvalid && resultCaps.isInvalid()) - { - // If joining the referenced decl's requirements results an invalid capability set, - // then the decl is using things that require conflicting set of capabilities, and we should diagnose an error. - if (referencedDecl && decl) - { - maybeDiagnose( - visitor->getSink(), - visitor->getOptionSet(), - DiagnosticCategory::Capability, - referenceLoc, - Diagnostics::conflictingCapabilityDueToUseOfDecl, - referencedDecl, - nodeCaps, - decl, - oldCaps); - } - else if (decl) + int bottomOfMember = 0; + for (const auto m : groupInfo) { - maybeDiagnose( - visitor->getSink(), - visitor->getOptionSet(), - DiagnosticCategory::Capability, - referenceLoc, - Diagnostics::conflictingCapabilityDueToStatement, - nodeCaps, - decl, - oldCaps); - } - else - { - maybeDiagnose( - visitor->getSink(), - visitor->getOptionSet(), - DiagnosticCategory::Capability, - referenceLoc, - Diagnostics::conflictingCapabilityDueToStatementEnclosingFunc, - nodeCaps, - oldCaps); + SLANG_ASSERT(bottomOfMember <= backingWidth); + + m.bitFieldModifier->backingDeclRef = backingMemberDeclRef; + m.bitFieldModifier->offset = bottomOfMember; + + bottomOfMember += m.bitWidth; } - } - // if stmt inside parent, set the provenance tracker to the calling function - if(!decl) - decl = visitor->getParentFuncOfVisitor(); - if (referencedDecl && decl) - { - // Here we store a childDecl that added/removed capabilities from a parentDecl - decl->capabilityRequirementProvenance.add(DeclReferenceWithLoc{ referencedDecl, referenceLoc }); + const auto backingMemberIndex = groupInfo[0].memberIndex; + structDecl->members.insert(backingMemberIndex, backingMember); + structDecl->invalidateMemberDictionary(); + ++memberIndex; } - }; - - CapabilitySet getStatementCapabilityUsage(SemanticsVisitor* visitor, Stmt* stmt); + structDecl->buildMemberDictionary(); - template - struct CapabilityDeclReferenceVisitor - : public SemanticsDeclReferenceVisitor> + // Reset everything + backingWidth = 0; + totalWidth = 0; + groupInfo.clear(); + }; + for (; memberIndex < structDecl->members.getCount(); ++memberIndex) { - typedef SemanticsDeclReferenceVisitor> Base; + const auto& m = structDecl->members[memberIndex]; - const ProcessFunc handleProcessFunc; - const ParentDiagnosticFunc handleParentDiagnosticFunc; - RequireCapabilityAttribute* maybeRequireCapability; - SemanticsContext& outerContext; - CapabilityDeclReferenceVisitor(const ProcessFunc& processFunc, const ParentDiagnosticFunc& parentDiagnosticFunc, RequireCapabilityAttribute* maybeRequireCapability, SemanticsContext& outer) - : handleProcessFunc(processFunc) - , handleParentDiagnosticFunc(parentDiagnosticFunc) - , maybeRequireCapability(maybeRequireCapability) - , outerContext(outer) - , SemanticsDeclReferenceVisitor>(outer) - { - } - virtual void processReferencedDecl(Decl* decl) override + // We can trivially skip any non-property decls + const auto v = as(m); + if (!v) { - SourceLoc loc = SourceLoc(); - if (Base::sourceLocStack.getCount()) - loc = Base::sourceLocStack.getLast(); - handleProcessFunc(decl, decl->inferredCapabilityRequirements, loc); + // If this is a non-bitfield member then finish the current group + if (as(m)) + dispatchSomeBitPackedMembers(); + continue; } - virtual void processDeclModifiers(Decl* decl, SourceLoc refLoc) override + + const auto bfm = m->findModifier(); + // If there isn't a bit field modifier, then dispatch the + // current group and continue + if (!bfm) { - if (decl) - handleProcessFunc(decl, decl->inferredCapabilityRequirements, refLoc); + dispatchSomeBitPackedMembers(); + continue; } - void visitDiscardStmt(DiscardStmt* stmt) + + // Verify that this makes sense as a bitfield + const auto t = v->type.type->getCanonicalType(); + SLANG_ASSERT(t); + const auto b = as(t); + if (!b) { - handleProcessFunc(stmt, CapabilitySet(CapabilityName::fragment), stmt->loc); + getSink()->diagnose(v->loc, Diagnostics::bitFieldNonIntegral, t); + continue; } - void visitTargetSwitchStmt(TargetSwitchStmt* stmt) + const auto baseType = b->getBaseType(); + const bool isIntegerType = isIntegerBaseType(baseType); + if (!isIntegerType) { - CapabilitySet set; - auto targetCaseCount = stmt->targetCases.getCount(); - for (Index targetCaseIndex = 0; targetCaseIndex < targetCaseCount; targetCaseIndex++) - { - // We may recieve a `default:` case for a `__target_switch`. If this is the case, - // we must resolve the target capability for a non empty set of `calling_functions_targets`: - // ``` default_target = calling_functions_targets-{other_case_targets} ``` - // - // * `calling_functions_capability` = `requirement attribute` of the calling function; if missing - // we can assume it is `any_target` - // - // * `{other_case_targets}` = set of all capabilities all `case` statments target inside the `__target_switch` - - // If we do not handle `default:`, the codegen will fail when trying to find a specific - // codegen target not handled explicitly by a `case` statment. - // We must also ensure the `default` case is last so we have priority to hit `case` statments and can preprocess - // `case` statments before the `default` case. - CapabilitySet targetCap; - if (CapabilityName(stmt->targetCases[targetCaseIndex]->capability) == CapabilityName::Invalid) - { - if (targetCaseCount - 1 != targetCaseIndex) - { - for (Index i = targetCaseIndex; i < targetCaseCount - 1; i++) - std::swap(stmt->targetCases[i], stmt->targetCases[i + 1]); - continue; - } - - if (!maybeRequireCapability) - targetCap = (CapabilitySet(CapabilityName::any_target).getTargetsThisHasButOtherDoesNot(set)); - else - targetCap = (maybeRequireCapability->capabilitySet.getTargetsThisHasButOtherDoesNot(set)); - } - else - { - targetCap = CapabilitySet(CapabilityName(stmt->targetCases[targetCaseIndex]->capability)); - - if (maybeRequireCapability) - { - CapabilitySet testingForInvalid = maybeRequireCapability->capabilitySet; - // Ensure case statement is valid with parent `[require(...)]` - testingForInvalid.join(targetCap); - if (testingForInvalid.isInvalid()) - { - maybeDiagnose(Base::getSink(), outerContext.getOptionSet(), DiagnosticCategory::Capability, stmt->targetCases[targetCaseIndex]->loc, - Diagnostics::conflictingCapabilityDueToStatement, targetCap, maybeRequireCapability, maybeRequireCapability->capabilitySet); - handleParentDiagnosticFunc(DiagnosticCategory::Capability); - } - } - } - auto targetCase = stmt->targetCases[targetCaseIndex]; - auto oldCap = targetCap; - auto bodyCap = getStatementCapabilityUsage(this, targetCase->body); - targetCap.join(bodyCap); - if (targetCap.isInvalid()) - { - maybeDiagnose(Base::getSink(), outerContext.getOptionSet(), DiagnosticCategory::Capability, targetCase->body->loc, Diagnostics::conflictingCapabilityDueToStatement, bodyCap, "target_switch", oldCap); - handleParentDiagnosticFunc(DiagnosticCategory::Capability); - } - set.unionWith(targetCap); - } - handleProcessFunc(stmt, set, stmt->loc); + getSink()->diagnose(v->loc, Diagnostics::bitFieldNonIntegral, t); + continue; } - void visitRequireCapabilityDecl(RequireCapabilityDecl* decl) + // The bit width of this member, and the member type width + const auto thisFieldWidth = bfm->width; + const auto thisFieldTypeWidth = getTypeBitSize(b); + SLANG_ASSERT(thisFieldTypeWidth != 0); + if (thisFieldWidth > thisFieldTypeWidth) { - handleProcessFunc(decl, decl->inferredCapabilityRequirements, decl->loc); - } - }; + getSink()->diagnose( + v->loc, + Diagnostics::bitFieldTooWide, + thisFieldWidth, + t, + thisFieldTypeWidth); + // Not much we can do with this field, just ignore it + continue; + } + + // At this point we're sure that we have a bit field, + // update our bit packing state + + // If there's a 0 width type, dispatch the current group + if (thisFieldWidth == 0) + dispatchSomeBitPackedMembers(); + + // If this member wouldn't fit into the current group, dispatch + // everything so far; + if (totalWidth + thisFieldWidth > std::max(thisFieldTypeWidth, backingWidth)) + dispatchSomeBitPackedMembers(); + + // Add this member to the group, + // Grow the backing width if necessary + backingWidth = std::max(thisFieldTypeWidth, backingWidth); + // Grow the total width + totalWidth += int(thisFieldWidth); + groupInfo.add({memberIndex, int(thisFieldWidth), t, bfm}); + } + // If the struct ended with a bitpacked member, then make sure we don't forget the last group + dispatchSomeBitPackedMembers(); +} + +void SemanticsDeclAttributesVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) +{ + // Run checking on attributes that can't be fully checked in header checking stage. + for (auto attr : decl->modifiers) + { + if (auto fwdDerivativeOfAttr = as(attr)) + checkForwardDerivativeOfAttribute(decl, fwdDerivativeOfAttr); + else if (auto bwdDerivativeOfAttr = as(attr)) + checkBackwardDerivativeOfAttribute(decl, bwdDerivativeOfAttr); + else if (auto primalOfAttr = as(attr)) + checkPrimalSubstituteOfAttribute(decl, primalOfAttr); + else if (auto fwdDerivativeAttr = as(attr)) + checkDerivativeAttribute(this, decl, fwdDerivativeAttr); + else if (auto bwdDerivativeAttr = as(attr)) + checkDerivativeAttribute(this, decl, bwdDerivativeAttr); + else if (auto primalAttr = as(attr)) + checkDerivativeAttribute(this, decl, primalAttr); + else if (auto cudaKernelAttr = as(attr)) + checkCudaKernelAttribute(this, decl, cudaKernelAttr); + } +} - template - void visitReferencedDecls(SemanticsContext& context, NodeBase* node, SourceLoc initialLoc, RequireCapabilityAttribute* maybeRequireCapability, const ProcessFunc& processFunc, const ParentDiagnosticFunc& parentDiagnosticFunc) +static void _propagateSeeDefinitionOf( + SemanticsVisitor* visitor, + Decl* funcDecl, + DiagnosticCategory diagnosticCategory) +{ + maybeDiagnose( + visitor->getSink(), + visitor->getOptionSet(), + diagnosticCategory, + funcDecl, + Diagnostics::seeDefinitionOf, + funcDecl); +} + +static void _propagateRequirement( + SemanticsVisitor* visitor, + CapabilitySet& resultCaps, + SyntaxNode* userNode, + SyntaxNode* referencedNode, + const CapabilitySet& nodeCaps, + SourceLoc referenceLoc) +{ + auto referencedDecl = as(referencedNode); + + // Ignore cyclic references. + if (referencedDecl) { - CapabilityDeclReferenceVisitor visitor(processFunc, parentDiagnosticFunc, maybeRequireCapability, context); - visitor.sourceLocStack.add(initialLoc); + if (referencedDecl->checkState.isBeingChecked()) + return; - if (auto val = as(node)) - visitor.dispatchIfNotNull(val); - if (auto stmt = as(node)) - visitor.dispatchIfNotNull(stmt); - if (auto expr = as(node)) - visitor.dispatchIfNotNull(expr); - if (auto decl = as(node)) - visitor.dispatchIfNotNull(decl); + ensureDecl(visitor, referencedDecl, DeclCheckState::CapabilityChecked); } - CapabilitySet getStatementCapabilityUsage(SemanticsVisitor* visitor, Stmt* stmt) - { - if (stmt == nullptr) - return CapabilitySet(); + if (resultCaps.implies(nodeCaps)) + return; - CapabilitySet inferredRequirements; - visitReferencedDecls(*visitor, stmt, stmt->loc, nullptr, - [&](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc) - { - _propagateRequirement(visitor, inferredRequirements, stmt, node, nodeCaps, refLoc); - }, - [](DiagnosticCategory category) - { - SLANG_UNUSED(category); - } - ); - return inferredRequirements; + auto oldCaps = resultCaps; + bool isAnyInvalid = resultCaps.isInvalid() || nodeCaps.isInvalid(); + resultCaps.join(nodeCaps); + + auto decl = as(userNode); + + if (!isAnyInvalid && resultCaps.isInvalid()) + { + // If joining the referenced decl's requirements results an invalid capability set, + // then the decl is using things that require conflicting set of capabilities, and we should + // diagnose an error. + if (referencedDecl && decl) + { + maybeDiagnose( + visitor->getSink(), + visitor->getOptionSet(), + DiagnosticCategory::Capability, + referenceLoc, + Diagnostics::conflictingCapabilityDueToUseOfDecl, + referencedDecl, + nodeCaps, + decl, + oldCaps); + } + else if (decl) + { + maybeDiagnose( + visitor->getSink(), + visitor->getOptionSet(), + DiagnosticCategory::Capability, + referenceLoc, + Diagnostics::conflictingCapabilityDueToStatement, + nodeCaps, + decl, + oldCaps); + } + else + { + maybeDiagnose( + visitor->getSink(), + visitor->getOptionSet(), + DiagnosticCategory::Capability, + referenceLoc, + Diagnostics::conflictingCapabilityDueToStatementEnclosingFunc, + nodeCaps, + oldCaps); + } } - void SemanticsDeclCapabilityVisitor::checkVarDeclCommon(VarDeclBase* varDecl) + // if stmt inside parent, set the provenance tracker to the calling function + if (!decl) + decl = visitor->getParentFuncOfVisitor(); + if (referencedDecl && decl) { - visitReferencedDecls(*this, varDecl->type.type, varDecl->loc, varDecl->findModifier(), - [this, varDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc) - { - _propagateRequirement(this, varDecl->inferredCapabilityRequirements, varDecl, node, nodeCaps, refLoc); - }, - [this, varDecl](DiagnosticCategory category) - { - _propagateSeeDefinitionOf(this, varDecl, category); - } - ); + // Here we store a childDecl that added/removed capabilities from a parentDecl + decl->capabilityRequirementProvenance.add( + DeclReferenceWithLoc{referencedDecl, referenceLoc}); } +}; + +CapabilitySet getStatementCapabilityUsage(SemanticsVisitor* visitor, Stmt* stmt); + +template +struct CapabilityDeclReferenceVisitor + : public SemanticsDeclReferenceVisitor< + CapabilityDeclReferenceVisitor> +{ + typedef SemanticsDeclReferenceVisitor< + CapabilityDeclReferenceVisitor> + Base; - CapabilitySet SemanticsDeclCapabilityVisitor::getDeclaredCapabilitySet(Decl* decl) + const ProcessFunc handleProcessFunc; + const ParentDiagnosticFunc handleParentDiagnosticFunc; + RequireCapabilityAttribute* maybeRequireCapability; + SemanticsContext& outerContext; + CapabilityDeclReferenceVisitor( + const ProcessFunc& processFunc, + const ParentDiagnosticFunc& parentDiagnosticFunc, + RequireCapabilityAttribute* maybeRequireCapability, + SemanticsContext& outer) + : handleProcessFunc(processFunc) + , handleParentDiagnosticFunc(parentDiagnosticFunc) + , maybeRequireCapability(maybeRequireCapability) + , outerContext(outer) + , SemanticsDeclReferenceVisitor< + CapabilityDeclReferenceVisitor>(outer) { - // Merge a decls's declared capability set with all parent declarations. - // For every existing target, we want to join their requirements together. - // If the the parent defines additional targets, we want to add them to the disjunction set. - // For example: - // [require(glsl)] struct Parent { [require(glsl, glsl_ext_1)] [require(spirv)] void foo(); } - // The requirement for `foo` should be glsl+glsl_ext_1 | spirv. - // - CapabilitySet declaredCaps; - for (Decl* parent = decl; parent; parent = getParentDecl(parent)) + } + virtual void processReferencedDecl(Decl* decl) override + { + SourceLoc loc = SourceLoc(); + if (Base::sourceLocStack.getCount()) + loc = Base::sourceLocStack.getLast(); + handleProcessFunc(decl, decl->inferredCapabilityRequirements, loc); + } + virtual void processDeclModifiers(Decl* decl, SourceLoc refLoc) override + { + if (decl) + handleProcessFunc(decl, decl->inferredCapabilityRequirements, refLoc); + } + void visitDiscardStmt(DiscardStmt* stmt) + { + handleProcessFunc(stmt, CapabilitySet(CapabilityName::fragment), stmt->loc); + } + void visitTargetSwitchStmt(TargetSwitchStmt* stmt) + { + CapabilitySet set; + auto targetCaseCount = stmt->targetCases.getCount(); + for (Index targetCaseIndex = 0; targetCaseIndex < targetCaseCount; targetCaseIndex++) { - CapabilitySet localDeclaredCaps; - bool shouldBreak = false; - if (!as(parent) || parent->inferredCapabilityRequirements.isEmpty()) + // We may recieve a `default:` case for a `__target_switch`. If this is the case, + // we must resolve the target capability for a non empty set of + // `calling_functions_targets`: + // ``` default_target = calling_functions_targets-{other_case_targets} ``` + // + // * `calling_functions_capability` = `requirement attribute` of the calling function; + // if missing + // we can assume it is `any_target` + // + // * `{other_case_targets}` = set of all capabilities all `case` statments target inside + // the `__target_switch` + + // If we do not handle `default:`, the codegen will fail when trying to find a specific + // codegen target not handled explicitly by a `case` statment. + // We must also ensure the `default` case is last so we have priority to hit `case` + // statments and can preprocess `case` statments before the `default` case. + CapabilitySet targetCap; + if (CapabilityName(stmt->targetCases[targetCaseIndex]->capability) == + CapabilityName::Invalid) { - for (auto decoration : parent->getModifiersOfType()) + if (targetCaseCount - 1 != targetCaseIndex) { - localDeclaredCaps.unionWith(decoration->capabilitySet); + for (Index i = targetCaseIndex; i < targetCaseCount - 1; i++) + std::swap(stmt->targetCases[i], stmt->targetCases[i + 1]); + continue; } + + if (!maybeRequireCapability) + targetCap = (CapabilitySet(CapabilityName::any_target) + .getTargetsThisHasButOtherDoesNot(set)); + else + targetCap = + (maybeRequireCapability->capabilitySet.getTargetsThisHasButOtherDoesNot( + set)); } else { - localDeclaredCaps = parent->inferredCapabilityRequirements; - shouldBreak = true; - } - // Merge decl's capability declaration with the parent. - declaredCaps.nonDestructiveJoin(localDeclaredCaps); + targetCap = + CapabilitySet(CapabilityName(stmt->targetCases[targetCaseIndex]->capability)); - // If the parent already has inferred capability requirements, we should stop now - // since that already covers transitive parents. - if (shouldBreak) - break; + if (maybeRequireCapability) + { + CapabilitySet testingForInvalid = maybeRequireCapability->capabilitySet; + // Ensure case statement is valid with parent `[require(...)]` + testingForInvalid.join(targetCap); + if (testingForInvalid.isInvalid()) + { + maybeDiagnose( + Base::getSink(), + outerContext.getOptionSet(), + DiagnosticCategory::Capability, + stmt->targetCases[targetCaseIndex]->loc, + Diagnostics::conflictingCapabilityDueToStatement, + targetCap, + maybeRequireCapability, + maybeRequireCapability->capabilitySet); + handleParentDiagnosticFunc(DiagnosticCategory::Capability); + } + } + } + auto targetCase = stmt->targetCases[targetCaseIndex]; + auto oldCap = targetCap; + auto bodyCap = getStatementCapabilityUsage(this, targetCase->body); + targetCap.join(bodyCap); + if (targetCap.isInvalid()) + { + maybeDiagnose( + Base::getSink(), + outerContext.getOptionSet(), + DiagnosticCategory::Capability, + targetCase->body->loc, + Diagnostics::conflictingCapabilityDueToStatement, + bodyCap, + "target_switch", + oldCap); + handleParentDiagnosticFunc(DiagnosticCategory::Capability); + } + set.unionWith(targetCap); } - return declaredCaps; + handleProcessFunc(stmt, set, stmt->loc); } - void SemanticsDeclCapabilityVisitor::visitAggTypeDeclBase(AggTypeDeclBase* decl) + void visitRequireCapabilityDecl(RequireCapabilityDecl* decl) { - decl->inferredCapabilityRequirements = getDeclaredCapabilitySet(decl); + handleProcessFunc(decl, decl->inferredCapabilityRequirements, decl->loc); } +}; - void SemanticsDeclCapabilityVisitor::visitNamespaceDeclBase(NamespaceDeclBase* decl) - { - decl->inferredCapabilityRequirements = getDeclaredCapabilitySet(decl); - } - - template - static inline void _dispatchCapabilitiesVisitorOfFunctionDecl(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, const ProcessFunc& processFunc, const ParentDiagnosticFunc& parentDiagnosticFunc) - { - visitor->setParentFuncOfVisitor(funcDecl); +template +void visitReferencedDecls( + SemanticsContext& context, + NodeBase* node, + SourceLoc initialLoc, + RequireCapabilityAttribute* maybeRequireCapability, + const ProcessFunc& processFunc, + const ParentDiagnosticFunc& parentDiagnosticFunc) +{ + CapabilityDeclReferenceVisitor visitor( + processFunc, + parentDiagnosticFunc, + maybeRequireCapability, + context); + visitor.sourceLocStack.add(initialLoc); + + if (auto val = as(node)) + visitor.dispatchIfNotNull(val); + if (auto stmt = as(node)) + visitor.dispatchIfNotNull(stmt); + if (auto expr = as(node)) + visitor.dispatchIfNotNull(expr); + if (auto decl = as(node)) + visitor.dispatchIfNotNull(decl); +} - for (auto member : funcDecl->members) - { - visitor->ensureDecl(member, DeclCheckState::CapabilityChecked); - _propagateRequirement(visitor, funcDecl->inferredCapabilityRequirements, funcDecl, member, member->inferredCapabilityRequirements, member->loc); - } +CapabilitySet getStatementCapabilityUsage(SemanticsVisitor* visitor, Stmt* stmt) +{ + if (stmt == nullptr) + return CapabilitySet(); + + CapabilitySet inferredRequirements; + visitReferencedDecls( + *visitor, + stmt, + stmt->loc, + nullptr, + [&](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc) + { _propagateRequirement(visitor, inferredRequirements, stmt, node, nodeCaps, refLoc); }, + [](DiagnosticCategory category) { SLANG_UNUSED(category); }); + return inferredRequirements; +} - visitReferencedDecls(*visitor, funcDecl->body, funcDecl->loc, funcDecl->findModifier(), processFunc, parentDiagnosticFunc); +void SemanticsDeclCapabilityVisitor::checkVarDeclCommon(VarDeclBase* varDecl) +{ + visitReferencedDecls( + *this, + varDecl->type.type, + varDecl->loc, + varDecl->findModifier(), + [this, varDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc) + { + _propagateRequirement( + this, + varDecl->inferredCapabilityRequirements, + varDecl, + node, + nodeCaps, + refLoc); + }, + [this, varDecl](DiagnosticCategory category) + { _propagateSeeDefinitionOf(this, varDecl, category); }); +} - if (!isEffectivelyStatic(funcDecl)) +CapabilitySet SemanticsDeclCapabilityVisitor::getDeclaredCapabilitySet(Decl* decl) +{ + // Merge a decls's declared capability set with all parent declarations. + // For every existing target, we want to join their requirements together. + // If the the parent defines additional targets, we want to add them to the disjunction set. + // For example: + // [require(glsl)] struct Parent { [require(glsl, glsl_ext_1)] [require(spirv)] void foo(); } + // The requirement for `foo` should be glsl+glsl_ext_1 | spirv. + // + CapabilitySet declaredCaps; + for (Decl* parent = decl; parent; parent = getParentDecl(parent)) + { + CapabilitySet localDeclaredCaps; + bool shouldBreak = false; + if (!as(parent) || parent->inferredCapabilityRequirements.isEmpty()) { - auto parentAggTypeDecl = getParentAggTypeDecl(funcDecl); - if (parentAggTypeDecl) + for (auto decoration : parent->getModifiersOfType()) { - visitor->ensureDecl(parentAggTypeDecl, DeclCheckState::CapabilityChecked); - _propagateRequirement(visitor, funcDecl->inferredCapabilityRequirements, funcDecl, parentAggTypeDecl, parentAggTypeDecl->inferredCapabilityRequirements, funcDecl->loc); + localDeclaredCaps.unionWith(decoration->capabilitySet); } } + else + { + localDeclaredCaps = parent->inferredCapabilityRequirements; + shouldBreak = true; + } + // Merge decl's capability declaration with the parent. + declaredCaps.nonDestructiveJoin(localDeclaredCaps); + + // If the parent already has inferred capability requirements, we should stop now + // since that already covers transitive parents. + if (shouldBreak) + break; } + return declaredCaps; +} - void SemanticsDeclCapabilityVisitor::visitFunctionDeclBase(FunctionDeclBase* funcDecl) - { - // If the function is an entrypoint and specifies a target stage, add the capabilities to our function capabilities. - _dispatchCapabilitiesVisitorOfFunctionDecl(this, funcDecl, - [this, funcDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc) - { - _propagateRequirement(this, funcDecl->inferredCapabilityRequirements, funcDecl, node, nodeCaps, refLoc); - }, - [this, funcDecl](DiagnosticCategory category) - { - _propagateSeeDefinitionOf(this, funcDecl, category); - } - ); +void SemanticsDeclCapabilityVisitor::visitAggTypeDeclBase(AggTypeDeclBase* decl) +{ + decl->inferredCapabilityRequirements = getDeclaredCapabilitySet(decl); +} + +void SemanticsDeclCapabilityVisitor::visitNamespaceDeclBase(NamespaceDeclBase* decl) +{ + decl->inferredCapabilityRequirements = getDeclaredCapabilitySet(decl); +} + +template +static inline void _dispatchCapabilitiesVisitorOfFunctionDecl( + SemanticsVisitor* visitor, + FunctionDeclBase* funcDecl, + const ProcessFunc& processFunc, + const ParentDiagnosticFunc& parentDiagnosticFunc) +{ + visitor->setParentFuncOfVisitor(funcDecl); - auto declaredCaps = getDeclaredCapabilitySet(funcDecl); + for (auto member : funcDecl->members) + { + visitor->ensureDecl(member, DeclCheckState::CapabilityChecked); + _propagateRequirement( + visitor, + funcDecl->inferredCapabilityRequirements, + funcDecl, + member, + member->inferredCapabilityRequirements, + member->loc); + } - auto vis = getDeclVisibility(funcDecl); + visitReferencedDecls( + *visitor, + funcDecl->body, + funcDecl->loc, + funcDecl->findModifier(), + processFunc, + parentDiagnosticFunc); - // If 0 capabilities were annotated on a function, capabilities are inferred from the function body - if (declaredCaps.isEmpty()) + if (!isEffectivelyStatic(funcDecl)) + { + auto parentAggTypeDecl = getParentAggTypeDecl(funcDecl); + if (parentAggTypeDecl) { - declaredCaps = funcDecl->inferredCapabilityRequirements; + visitor->ensureDecl(parentAggTypeDecl, DeclCheckState::CapabilityChecked); + _propagateRequirement( + visitor, + funcDecl->inferredCapabilityRequirements, + funcDecl, + parentAggTypeDecl, + parentAggTypeDecl->inferredCapabilityRequirements, + funcDecl->loc); } - else + } +} + +void SemanticsDeclCapabilityVisitor::visitFunctionDeclBase(FunctionDeclBase* funcDecl) +{ + // If the function is an entrypoint and specifies a target stage, add the capabilities to our + // function capabilities. + _dispatchCapabilitiesVisitorOfFunctionDecl( + this, + funcDecl, + [this, funcDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc) + { + _propagateRequirement( + this, + funcDecl->inferredCapabilityRequirements, + funcDecl, + node, + nodeCaps, + refLoc); + }, + [this, funcDecl](DiagnosticCategory category) + { _propagateSeeDefinitionOf(this, funcDecl, category); }); + + auto declaredCaps = getDeclaredCapabilitySet(funcDecl); + + auto vis = getDeclVisibility(funcDecl); + + // If 0 capabilities were annotated on a function, capabilities are inferred from the function + // body + if (declaredCaps.isEmpty()) + { + declaredCaps = funcDecl->inferredCapabilityRequirements; + } + else + { + if (vis == DeclVisibility::Public) { - if (vis == DeclVisibility::Public) - { - // For public decls, we need to enforce that the function - // only uses capabilities that it declares. - // At a minimum we will propagate shader requirements to our - // function from calling children in all cases so the parent - // can enforce shader targets correctly and propagate to `main` - CapabilityAtomSet failedAvailableCapabilityConjunction; - if (!CapabilitySet::checkCapabilityRequirement( + // For public decls, we need to enforce that the function + // only uses capabilities that it declares. + // At a minimum we will propagate shader requirements to our + // function from calling children in all cases so the parent + // can enforce shader targets correctly and propagate to `main` + CapabilityAtomSet failedAvailableCapabilityConjunction; + if (!CapabilitySet::checkCapabilityRequirement( declaredCaps, funcDecl->inferredCapabilityRequirements, failedAvailableCapabilityConjunction)) - { - diagnoseUndeclaredCapability(funcDecl, Diagnostics::useOfUndeclaredCapability, failedAvailableCapabilityConjunction); - funcDecl->inferredCapabilityRequirements = declaredCaps; - } - else - funcDecl->inferredCapabilityRequirements.nonDestructiveJoin(declaredCaps); - } - else { - // For internal decls, their inferred capability should be joined - // with the declared capabilities. - funcDecl->inferredCapabilityRequirements.join(declaredCaps); + diagnoseUndeclaredCapability( + funcDecl, + Diagnostics::useOfUndeclaredCapability, + failedAvailableCapabilityConjunction); + funcDecl->inferredCapabilityRequirements = declaredCaps; } + else + funcDecl->inferredCapabilityRequirements.nonDestructiveJoin(declaredCaps); + } + else + { + // For internal decls, their inferred capability should be joined + // with the declared capabilities. + funcDecl->inferredCapabilityRequirements.join(declaredCaps); } } +} - void SemanticsDeclCapabilityVisitor::visitInheritanceDecl(InheritanceDecl* inheritanceDecl) +void SemanticsDeclCapabilityVisitor::visitInheritanceDecl(InheritanceDecl* inheritanceDecl) +{ + // Check that the implementation of an interface requirement is not using more capabilities + // than what's declared on the interface method. + if (inheritanceDecl->witnessTable) { - // Check that the implementation of an interface requirement is not using more capabilities - // than what's declared on the interface method. - if (inheritanceDecl->witnessTable) + for (auto& kv : inheritanceDecl->witnessTable->m_requirementDictionary) { - for (auto& kv : inheritanceDecl->witnessTable->m_requirementDictionary) - { - if (kv.value.getFlavor() != RequirementWitness::Flavor::declRef) - continue; - auto requirementDecl = kv.key; - auto implDecl = kv.value.getDeclRef(); - if (!implDecl) - continue; + if (kv.value.getFlavor() != RequirementWitness::Flavor::declRef) + continue; + auto requirementDecl = kv.key; + auto implDecl = kv.value.getDeclRef(); + if (!implDecl) + continue; - if (getModuleDecl(implDecl.getDecl())->isInLegacyLanguage) - break; + if (getModuleDecl(implDecl.getDecl())->isInLegacyLanguage) + break; - ensureDecl(requirementDecl, DeclCheckState::CapabilityChecked); - ensureDecl(implDecl.declRefBase, DeclCheckState::CapabilityChecked); - - CapabilityAtomSet failedAvailableCapabilityConjunction; - if (!CapabilitySet::checkCapabilityRequirement( + ensureDecl(requirementDecl, DeclCheckState::CapabilityChecked); + ensureDecl(implDecl.declRefBase, DeclCheckState::CapabilityChecked); + + CapabilityAtomSet failedAvailableCapabilityConjunction; + if (!CapabilitySet::checkCapabilityRequirement( requirementDecl->inferredCapabilityRequirements, implDecl.getDecl()->inferredCapabilityRequirements, failedAvailableCapabilityConjunction)) - { - diagnoseUndeclaredCapability(implDecl.getDecl(), Diagnostics::useOfUndeclaredCapabilityOfInterfaceRequirement, failedAvailableCapabilityConjunction); - } + { + diagnoseUndeclaredCapability( + implDecl.getDecl(), + Diagnostics::useOfUndeclaredCapabilityOfInterfaceRequirement, + failedAvailableCapabilityConjunction); } } } +} - DeclVisibility getDeclVisibility(Decl* decl) +DeclVisibility getDeclVisibility(Decl* decl) +{ + if (as(decl) || as(decl) || + as(decl)) { - if (as(decl) || as(decl) || as(decl)) - { - auto genericDecl = as(decl->parentDecl); - if (!genericDecl) - return DeclVisibility::Default; - if (genericDecl->inner) - return getDeclVisibility(genericDecl->inner); + auto genericDecl = as(decl->parentDecl); + if (!genericDecl) return DeclVisibility::Default; - } - if (auto genericDecl = as(decl)) - decl = genericDecl->inner; - for (; decl; decl = getParentDecl(decl)) - { - if (as(decl)) - continue; - if (as(decl)) - continue; - break; - } - if (!decl) - return DeclVisibility::Public; - - for (auto modifier : decl->modifiers) - { - if (as(modifier)) - return DeclVisibility::Public; - else if (as(modifier)) - return DeclVisibility::Internal; - else if (as(modifier)) - return DeclVisibility::Private; - } - // Interface members will always have the same visibility as the interface itself. - if (auto interfaceDecl = findParentInterfaceDecl(decl)) - { - return getDeclVisibility(interfaceDecl); - } - auto defaultVis = DeclVisibility::Default; - if (auto parentModule = getModuleDecl(decl)) - defaultVis = parentModule->isInLegacyLanguage ? DeclVisibility::Public : DeclVisibility::Internal; + if (genericDecl->inner) + return getDeclVisibility(genericDecl->inner); + return DeclVisibility::Default; + } + if (auto genericDecl = as(decl)) + decl = genericDecl->inner; + for (; decl; decl = getParentDecl(decl)) + { + if (as(decl)) + continue; + if (as(decl)) + continue; + break; + } + if (!decl) + return DeclVisibility::Public; - // Members of other agg type decls will have their default visibility capped to the parents'. - if (as(decl)) - { + for (auto modifier : decl->modifiers) + { + if (as(modifier)) return DeclVisibility::Public; - } - return defaultVis; + else if (as(modifier)) + return DeclVisibility::Internal; + else if (as(modifier)) + return DeclVisibility::Private; + } + // Interface members will always have the same visibility as the interface itself. + if (auto interfaceDecl = findParentInterfaceDecl(decl)) + { + return getDeclVisibility(interfaceDecl); + } + auto defaultVis = DeclVisibility::Default; + if (auto parentModule = getModuleDecl(decl)) + defaultVis = + parentModule->isInLegacyLanguage ? DeclVisibility::Public : DeclVisibility::Internal; + + // Members of other agg type decls will have their default visibility capped to the parents'. + if (as(decl)) + { + return DeclVisibility::Public; } + return defaultVis; +} - VarDeclBase* getTrailingUnsizedArrayElement(Type* type, VarDeclBase* parentVar, ArrayExpressionType*& outArrayType) +VarDeclBase* getTrailingUnsizedArrayElement( + Type* type, + VarDeclBase* parentVar, + ArrayExpressionType*& outArrayType) +{ + while (auto modifiedType = as(type)) + type = modifiedType->getBase(); + HashSet seenTypes; + for (;;) { - while (auto modifiedType = as(type)) - type = modifiedType->getBase(); - HashSet seenTypes; - for (;;) + if (auto arrayType = as(type)) { - if (auto arrayType = as(type)) + if (arrayType->isUnsized()) { - if (arrayType->isUnsized()) - { - outArrayType = arrayType; - return parentVar; - } - else - return nullptr; + outArrayType = arrayType; + return parentVar; } - else if (auto declRefType = as(type)) + else + return nullptr; + } + else if (auto declRefType = as(type)) + { + if (auto aggTypeDecl = declRefType->getDeclRef().as()) { - if (auto aggTypeDecl = declRefType->getDeclRef().as()) + auto varDecls = aggTypeDecl.getDecl()->getMembersOfType(); + if (varDecls.getCount() == 0) + return nullptr; + VarDeclBase* lastVarDecl = nullptr; + for (auto varDecl : varDecls) { - auto varDecls = aggTypeDecl.getDecl()->getMembersOfType(); - if (varDecls.getCount() == 0) - return nullptr; - VarDeclBase* lastVarDecl = nullptr; - for (auto varDecl : varDecls) - { - if (isEffectivelyStatic(varDecl)) - continue; - lastVarDecl = varDecl; - } - auto lastMember = _getMemberDeclRef( - getCurrentASTBuilder(), aggTypeDecl, lastVarDecl).as(); - auto varType = getType(getCurrentASTBuilder(), lastMember); - if (!varType) - return nullptr; - if (!seenTypes.add(type)) - return nullptr; - type = varType; - parentVar = lastMember.getDecl(); - continue; + if (isEffectivelyStatic(varDecl)) + continue; + lastVarDecl = varDecl; } + auto lastMember = + _getMemberDeclRef(getCurrentASTBuilder(), aggTypeDecl, lastVarDecl) + .as(); + auto varType = getType(getCurrentASTBuilder(), lastMember); + if (!varType) + return nullptr; + if (!seenTypes.add(type)) + return nullptr; + type = varType; + parentVar = lastMember.getDecl(); + continue; } } - return nullptr; } + return nullptr; +} - bool isOpaqueHandleType(Type* type) - { - while (auto modifiedType = as(type)) - type = modifiedType->getBase(); - if (as(type)) - return true; - if (as(type)) - return true; - if (as(type)) - return true; - if (as(type)) - return true; - if (as(type)) - return true; - if (as(type)) - return true; - if (as(type)) - return true; - if (as(type)) - return true; - if (as(type)) - return true; - if (as(type)) - return true; - return false; - } +bool isOpaqueHandleType(Type* type) +{ + while (auto modifiedType = as(type)) + type = modifiedType->getBase(); + if (as(type)) + return true; + if (as(type)) + return true; + if (as(type)) + return true; + if (as(type)) + return true; + if (as(type)) + return true; + if (as(type)) + return true; + if (as(type)) + return true; + if (as(type)) + return true; + if (as(type)) + return true; + if (as(type)) + return true; + return false; +} - void diagnoseMissingCapabilityProvenance(CompilerOptionSet& optionSet, DiagnosticSink* sink, Decl* decl, CapabilitySet& setToFind) - { - HashSet checkedDecls; - DeclReferenceWithLoc declWithRef; - declWithRef.referencedDecl = decl; - declWithRef.referenceLoc = (decl) ? decl->loc : SourceLoc(); - bool bottomOfProvenanceStack = false; - // Find the bottom of the atom provenance stack which fails to contain `setToFind` - while(!bottomOfProvenanceStack && declWithRef.referencedDecl) - { - bottomOfProvenanceStack = true; - for(auto& i : declWithRef.referencedDecl->capabilityRequirementProvenance) +void diagnoseMissingCapabilityProvenance( + CompilerOptionSet& optionSet, + DiagnosticSink* sink, + Decl* decl, + CapabilitySet& setToFind) +{ + HashSet checkedDecls; + DeclReferenceWithLoc declWithRef; + declWithRef.referencedDecl = decl; + declWithRef.referenceLoc = (decl) ? decl->loc : SourceLoc(); + bool bottomOfProvenanceStack = false; + // Find the bottom of the atom provenance stack which fails to contain `setToFind` + while (!bottomOfProvenanceStack && declWithRef.referencedDecl) + { + bottomOfProvenanceStack = true; + for (auto& i : declWithRef.referencedDecl->capabilityRequirementProvenance) + { + if (checkedDecls.contains(i.referencedDecl)) + continue; + checkedDecls.add(i.referencedDecl); + + if (!i.referencedDecl->inferredCapabilityRequirements.implies(setToFind)) { - if (checkedDecls.contains(i.referencedDecl)) - continue; - checkedDecls.add(i.referencedDecl); - - if(!i.referencedDecl->inferredCapabilityRequirements.implies(setToFind)) - { - // We found a source of the incompatible capability, follow this - // element inside the provenance stack until we are at the bottom - declWithRef = i; - bottomOfProvenanceStack = false; - break; - } + // We found a source of the incompatible capability, follow this + // element inside the provenance stack until we are at the bottom + declWithRef = i; + bottomOfProvenanceStack = false; + break; } } - - if (!declWithRef.referencedDecl) - return; - - // Diagnose the use-site - maybeDiagnose(sink, optionSet, DiagnosticCategory::Capability, declWithRef.referenceLoc, Diagnostics::seeUsingOf, declWithRef.referencedDecl); - // Diagnose the definition as the problem - maybeDiagnose(sink, optionSet, DiagnosticCategory::Capability, declWithRef.referencedDecl->loc, Diagnostics::seeDefinitionOf, declWithRef.referencedDecl); - - // If we find a 'require' modifier, this is contributing to the overall capability incompatibility. - // We should hint to the user that this declaration is problematic. - if (auto requireCapabilityAttribute = declWithRef.referencedDecl->findModifier()) - maybeDiagnose(sink, optionSet, DiagnosticCategory::Capability, requireCapabilityAttribute->loc, Diagnostics::seeDeclarationOf, requireCapabilityAttribute); } - void diagnoseCapabilityProvenance(CompilerOptionSet& optionSet, DiagnosticSink* sink, Decl* decl, CapabilityAtom atomToFind, HashSet& printedDecls) + if (!declWithRef.referencedDecl) + return; + + // Diagnose the use-site + maybeDiagnose( + sink, + optionSet, + DiagnosticCategory::Capability, + declWithRef.referenceLoc, + Diagnostics::seeUsingOf, + declWithRef.referencedDecl); + // Diagnose the definition as the problem + maybeDiagnose( + sink, + optionSet, + DiagnosticCategory::Capability, + declWithRef.referencedDecl->loc, + Diagnostics::seeDefinitionOf, + declWithRef.referencedDecl); + + // If we find a 'require' modifier, this is contributing to the overall capability + // incompatibility. We should hint to the user that this declaration is problematic. + if (auto requireCapabilityAttribute = + declWithRef.referencedDecl->findModifier()) + maybeDiagnose( + sink, + optionSet, + DiagnosticCategory::Capability, + requireCapabilityAttribute->loc, + Diagnostics::seeDeclarationOf, + requireCapabilityAttribute); +} + +void diagnoseCapabilityProvenance( + CompilerOptionSet& optionSet, + DiagnosticSink* sink, + Decl* decl, + CapabilityAtom atomToFind, + HashSet& printedDecls) +{ + auto thisModule = getModuleDecl(decl); + Decl* declToPrint = decl; + while (declToPrint) { - auto thisModule = getModuleDecl(decl); - Decl* declToPrint = decl; - while (declToPrint) + Decl* previousDecl = declToPrint; + printedDecls.add(declToPrint); + for (auto& provenance : declToPrint->capabilityRequirementProvenance) { - Decl* previousDecl = declToPrint; - printedDecls.add(declToPrint); - for(auto& provenance : declToPrint->capabilityRequirementProvenance) - { - if (!provenance.referencedDecl->inferredCapabilityRequirements.implies(atomToFind)) - continue; - maybeDiagnose(sink, optionSet, DiagnosticCategory::Capability, provenance.referenceLoc, Diagnostics::seeUsingOf, provenance.referencedDecl); - declToPrint = provenance.referencedDecl; - if (printedDecls.contains(declToPrint)) - break; - if (declToPrint->findModifier()) - break; - auto moduleDecl = getModuleDecl(declToPrint); - if (thisModule != moduleDecl) - break; - if (moduleDecl && moduleDecl->isInLegacyLanguage) - continue; - if (getDeclVisibility(declToPrint) == DeclVisibility::Public) - break; - } - if (previousDecl == declToPrint) + if (!provenance.referencedDecl->inferredCapabilityRequirements.implies(atomToFind)) + continue; + maybeDiagnose( + sink, + optionSet, + DiagnosticCategory::Capability, + provenance.referenceLoc, + Diagnostics::seeUsingOf, + provenance.referencedDecl); + declToPrint = provenance.referencedDecl; + if (printedDecls.contains(declToPrint)) + break; + if (declToPrint->findModifier()) + break; + auto moduleDecl = getModuleDecl(declToPrint); + if (thisModule != moduleDecl) + break; + if (moduleDecl && moduleDecl->isInLegacyLanguage) + continue; + if (getDeclVisibility(declToPrint) == DeclVisibility::Public) break; } - if (declToPrint) - { - maybeDiagnose(sink, optionSet, DiagnosticCategory::Capability, declToPrint->loc, Diagnostics::seeDefinitionOf, declToPrint); - } + if (previousDecl == declToPrint) + break; } - - void SemanticsDeclCapabilityVisitor::diagnoseUndeclaredCapability(Decl* decl, const DiagnosticInfo& diagnosticInfo, const CapabilityAtomSet& failedAtomsInsideAvailableSet) + if (declToPrint) { - if (decl->inferredCapabilityRequirements.isEmpty()) - return; - if(failedAtomsInsideAvailableSet.isEmpty() || failedAtomsInsideAvailableSet.contains((UInt)CapabilityAtom::Invalid)) - return; + maybeDiagnose( + sink, + optionSet, + DiagnosticCategory::Capability, + declToPrint->loc, + Diagnostics::seeDefinitionOf, + declToPrint); + } +} - // There are two causes for why type checking failed on failedAvailableSet. - // The first scenario is that failedAvailableSet defines a set of capabilities on a - // compilation target (e.g. hlsl) that isn't defined by some callees, for example, if we have - // a function: - // [require(hlsl)] // <-- failedAvailableSet - // [require(cpp)] - // void caller() - // { - // printf(); // assume this is defined for (cpp | cuda). - // } - // In this case we should diagnose error reporting printf isn't defined on a required target. - // - // Now, we detect if we are case 1. - - { - CapabilityAtom outFailedAtom{}; - if (hasTargetAtom(failedAtomsInsideAvailableSet, outFailedAtom)) - { - maybeDiagnose(getSink(), this->getOptionSet(), DiagnosticCategory::Capability, decl->loc, Diagnostics::declHasDependenciesNotCompatibleOnTarget, decl, outFailedAtom); - - // Anything defined on a non-failed target atom may be the culprit to why we fail having a target capability. - // Print out all possible culprits. - CapabilityAtomSet failedAtomSet; - failedAtomSet.add((UInt)outFailedAtom); - CapabilityAtomSet targetsNotUsedSet; - CapabilityAtomSet::calcSubtract(targetsNotUsedSet, getAtomSetOfTargets(), failedAtomSet); - - HashSet printedDecls; - for (auto atom : targetsNotUsedSet) - { - CapabilityAtom formattedAtom = asAtom(atom); - diagnoseCapabilityProvenance(this->getOptionSet(), getSink(), decl, formattedAtom, printedDecls); - } - return; +void SemanticsDeclCapabilityVisitor::diagnoseUndeclaredCapability( + Decl* decl, + const DiagnosticInfo& diagnosticInfo, + const CapabilityAtomSet& failedAtomsInsideAvailableSet) +{ + if (decl->inferredCapabilityRequirements.isEmpty()) + return; + if (failedAtomsInsideAvailableSet.isEmpty() || + failedAtomsInsideAvailableSet.contains((UInt)CapabilityAtom::Invalid)) + return; + + // There are two causes for why type checking failed on failedAvailableSet. + // The first scenario is that failedAvailableSet defines a set of capabilities on a + // compilation target (e.g. hlsl) that isn't defined by some callees, for example, if we have + // a function: + // [require(hlsl)] // <-- failedAvailableSet + // [require(cpp)] + // void caller() + // { + // printf(); // assume this is defined for (cpp | cuda). + // } + // In this case we should diagnose error reporting printf isn't defined on a required target. + // + // Now, we detect if we are case 1. + + { + CapabilityAtom outFailedAtom{}; + if (hasTargetAtom(failedAtomsInsideAvailableSet, outFailedAtom)) + { + maybeDiagnose( + getSink(), + this->getOptionSet(), + DiagnosticCategory::Capability, + decl->loc, + Diagnostics::declHasDependenciesNotCompatibleOnTarget, + decl, + outFailedAtom); + + // Anything defined on a non-failed target atom may be the culprit to why we fail having + // a target capability. Print out all possible culprits. + CapabilityAtomSet failedAtomSet; + failedAtomSet.add((UInt)outFailedAtom); + CapabilityAtomSet targetsNotUsedSet; + CapabilityAtomSet::calcSubtract( + targetsNotUsedSet, + getAtomSetOfTargets(), + failedAtomSet); + + HashSet printedDecls; + for (auto atom : targetsNotUsedSet) + { + CapabilityAtom formattedAtom = asAtom(atom); + diagnoseCapabilityProvenance( + this->getOptionSet(), + getSink(), + decl, + formattedAtom, + printedDecls); } + return; } + } - //// The second scenario is when the callee is using a capability that is not provided by the requirement. - //// For example: - //// [require(hlsl,b,c)] - //// void caller() - //// { - //// useD(); // require capability (hlsl,d) - //// } - //// In this case we should report that useD() is using a capability that is not declared by caller. - //// + //// The second scenario is when the callee is using a capability that is not provided by the + /// requirement. / For example: / [require(hlsl,b,c)] / void caller() / { / useD(); + ///// require capability (hlsl,d) / } / In this case we should report that useD() is using a + /// capability that is not declared by caller. + //// - //// If we reach here, we are case 2. + //// If we reach here, we are case 2. - // We will produce all failed atoms. This is important since provenance of multiple atoms - // can come from multiple referenced items in a function body. - HashSet printedDecls; - auto simplifiedFailedAtomsSet = failedAtomsInsideAvailableSet.newSetWithoutImpliedAtoms(); - for (auto i : simplifiedFailedAtomsSet) - { - CapabilityAtom formattedAtom = asAtom(i); - maybeDiagnose(getSink(), this->getOptionSet(), DiagnosticCategory::Capability, decl->loc, diagnosticInfo, decl, formattedAtom); - // Print provenances. - diagnoseCapabilityProvenance(this->getOptionSet(), getSink(), decl, formattedAtom, printedDecls); - } + // We will produce all failed atoms. This is important since provenance of multiple atoms + // can come from multiple referenced items in a function body. + HashSet printedDecls; + auto simplifiedFailedAtomsSet = failedAtomsInsideAvailableSet.newSetWithoutImpliedAtoms(); + for (auto i : simplifiedFailedAtomsSet) + { + CapabilityAtom formattedAtom = asAtom(i); + maybeDiagnose( + getSink(), + this->getOptionSet(), + DiagnosticCategory::Capability, + decl->loc, + diagnosticInfo, + decl, + formattedAtom); + // Print provenances. + diagnoseCapabilityProvenance( + this->getOptionSet(), + getSink(), + decl, + formattedAtom, + printedDecls); } - } + +} // namespace Slang diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 41cbf689b..076e310c0 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -11,157 +11,159 @@ // // * `slang-check-conversion.cpp` is responsible for the logic of handling type conversion/coercion +#include "core/slang-char-util.h" #include "slang-ast-natural-layout.h" - -#include "slang-lookup.h" -#include "slang-lookup-spirv.h" #include "slang-ast-print.h" -#include "core/slang-char-util.h" +#include "slang-lookup-spirv.h" +#include "slang-lookup.h" namespace Slang { - DeclRefType* SemanticsVisitor::getExprDeclRefType(Expr * expr) - { - if (auto typetype = as(expr->type)) - return dynamicCast(typetype->getType()); - else - return as(expr->type); - } +DeclRefType* SemanticsVisitor::getExprDeclRefType(Expr* expr) +{ + if (auto typetype = as(expr->type)) + return dynamicCast(typetype->getType()); + else + return as(expr->type); +} - void SemanticsContext::ExprLocalScope::addBinding(LetExpr* binding) +void SemanticsContext::ExprLocalScope::addBinding(LetExpr* binding) +{ + if (!m_innerMostBinding) { - if (!m_innerMostBinding) - { - SLANG_ASSERT(!m_outerMostBinding); - - // If we haven't added any bindings, then `binding` - // becomes both the inner-most and outer most. - // - m_innerMostBinding = binding; - m_outerMostBinding = binding; - } - else - { - SLANG_ASSERT(m_outerMostBinding); + SLANG_ASSERT(!m_outerMostBinding); - // If we already have bindings, then `binding` - // will become the new inner-most binding. - // - m_innerMostBinding->body = binding; - m_innerMostBinding = binding; - } + // If we haven't added any bindings, then `binding` + // becomes both the inner-most and outer most. + // + m_innerMostBinding = binding; + m_outerMostBinding = binding; } - - - /// Move `expr` into a temporary variable and execute `func` on that variable. - /// - /// Returns an expression that wraps both the creation and initialization of - /// the temporary, and the computation created by `func`. - /// - template - Expr* SemanticsVisitor::moveTemp(Expr* const& expr, F const& func) + else { - VarDecl* varDecl = m_astBuilder->create(); - varDecl->parentDecl = nullptr; - if (m_outerScope && m_outerScope->containerDecl) - m_outerScope->containerDecl->addMember(varDecl); - addModifier(varDecl, m_astBuilder->create()); - varDecl->checkState = DeclCheckState::DefinitionChecked; - varDecl->nameAndLoc.loc = expr->loc; - varDecl->initExpr = expr; - varDecl->type.type = expr->type.type; + SLANG_ASSERT(m_outerMostBinding); - auto varDeclRef = makeDeclRef(varDecl); + // If we already have bindings, then `binding` + // will become the new inner-most binding. + // + m_innerMostBinding->body = binding; + m_innerMostBinding = binding; + } +} - LetExpr* letExpr = m_astBuilder->create(); - letExpr->decl = varDecl; - auto body = func(varDeclRef); - Expr* result = body; - if (auto exprLocalScope = getExprLocalScope()) - { - // We want to add the `LetExpr` to the set of such expressions - // in the local scope, so that it can be emitted properly. - // - exprLocalScope->addBinding(letExpr); - } - else - { - // If we somehow got in here and there wasn't an expression-local - // scope established yet, it almost certainly represents an error. - // - SLANG_ASSERT(exprLocalScope); +/// Move `expr` into a temporary variable and execute `func` on that variable. +/// +/// Returns an expression that wraps both the creation and initialization of +/// the temporary, and the computation created by `func`. +/// +template +Expr* SemanticsVisitor::moveTemp(Expr* const& expr, F const& func) +{ + VarDecl* varDecl = m_astBuilder->create(); + varDecl->parentDecl = nullptr; + if (m_outerScope && m_outerScope->containerDecl) + m_outerScope->containerDecl->addMember(varDecl); + addModifier(varDecl, m_astBuilder->create()); + varDecl->checkState = DeclCheckState::DefinitionChecked; + varDecl->nameAndLoc.loc = expr->loc; + varDecl->initExpr = expr; + varDecl->type.type = expr->type.type; + + auto varDeclRef = makeDeclRef(varDecl); + + LetExpr* letExpr = m_astBuilder->create(); + letExpr->decl = varDecl; + + auto body = func(varDeclRef); + Expr* result = body; + if (auto exprLocalScope = getExprLocalScope()) + { + // We want to add the `LetExpr` to the set of such expressions + // in the local scope, so that it can be emitted properly. + // + exprLocalScope->addBinding(letExpr); + } + else + { + // If we somehow got in here and there wasn't an expression-local + // scope established yet, it almost certainly represents an error. + // + SLANG_ASSERT(exprLocalScope); - // As a fallback, though, we will try to wire up the `letExpr` - // to surround the body directly and return that. - // - letExpr->body = body; - letExpr->type = body->type; + // As a fallback, though, we will try to wire up the `letExpr` + // to surround the body directly and return that. + // + letExpr->body = body; + letExpr->type = body->type; - result = letExpr; - } - return result; + result = letExpr; } + return result; +} + +/// Execute `func` on a variable with the value of `expr`. +/// +/// If `expr` is just a reference to an immutable (e.g., `let`) variable +/// then this might use the existing variable. Otherwise it will create +/// a new variable to hold `expr`, using `moveTemp()`. +/// +template +Expr* SemanticsVisitor::maybeMoveTemp(Expr* const& expr, F const& func) +{ + // TODO: Eventually this operation could consider any case where the + // input `expr` names an immutable "path": one that starts at an + // immutable binding and follows a (possibly empty) chain of accesses + // to immutable members. - /// Execute `func` on a variable with the value of `expr`. - /// - /// If `expr` is just a reference to an immutable (e.g., `let`) variable - /// then this might use the existing variable. Otherwise it will create - /// a new variable to hold `expr`, using `moveTemp()`. - /// - template - Expr* SemanticsVisitor::maybeMoveTemp(Expr* const& expr, F const& func) + if (auto varExpr = as(expr)) { - // TODO: Eventually this operation could consider any case where the - // input `expr` names an immutable "path": one that starts at an - // immutable binding and follows a (possibly empty) chain of accesses - // to immutable members. + auto declRef = varExpr->declRef; + if (auto varDeclRef = declRef.as()) + return func(varDeclRef); + } - if(auto varExpr = as(expr)) - { - auto declRef = varExpr->declRef; - if(auto varDeclRef = declRef.as()) - return func(varDeclRef); - } + return moveTemp(expr, func); +} - return moveTemp(expr, func); - } - - /// Return an expression that represents "opening" the existential `expr`. - /// - /// The type of `expr` must be an interface type, matching `interfaceDeclRef`. - /// - /// If we scope down the PL theory to just the case that Slang cares about, - /// a value of an existential type like `IMover` is a tuple of: - /// - /// * a concrete type `X` - /// * a witness `w` of the fact that `X` implements `IMover` - /// * a value `v` of type `X` - /// - /// "Opening" an existential value is the process of decomposing a single - /// value `e : IMover` into the pieces `X`, `w`, and `v`. - /// - /// Rather than return all those pieces individually, this operation - /// returns an expression that logically corresponds to `v`: an expression - /// of type `X`, where the type carries the knowledge that `X` implements `IMover`. - /// - Expr* SemanticsVisitor::openExistential( - Expr* expr, - DeclRef interfaceDeclRef) - { - // If `expr` refers to an immutable binding, - // then we can use it directly. If it refers - // to an arbitrary expression or a mutable - // binding, we will move its value into an - // immutable temporary so that we can use - // it directly. - // - return maybeMoveTemp(expr, [&](DeclRef varDeclRef) +/// Return an expression that represents "opening" the existential `expr`. +/// +/// The type of `expr` must be an interface type, matching `interfaceDeclRef`. +/// +/// If we scope down the PL theory to just the case that Slang cares about, +/// a value of an existential type like `IMover` is a tuple of: +/// +/// * a concrete type `X` +/// * a witness `w` of the fact that `X` implements `IMover` +/// * a value `v` of type `X` +/// +/// "Opening" an existential value is the process of decomposing a single +/// value `e : IMover` into the pieces `X`, `w`, and `v`. +/// +/// Rather than return all those pieces individually, this operation +/// returns an expression that logically corresponds to `v`: an expression +/// of type `X`, where the type carries the knowledge that `X` implements `IMover`. +/// +Expr* SemanticsVisitor::openExistential(Expr* expr, DeclRef interfaceDeclRef) +{ + // If `expr` refers to an immutable binding, + // then we can use it directly. If it refers + // to an arbitrary expression or a mutable + // binding, we will move its value into an + // immutable temporary so that we can use + // it directly. + // + return maybeMoveTemp( + expr, + [&](DeclRef varDeclRef) { ExtractExistentialType* openedType = m_astBuilder->getOrCreate( - varDeclRef, expr->type.type, interfaceDeclRef); + varDeclRef, + expr->type.type, + interfaceDeclRef); - ExtractExistentialValueExpr* openedValue = m_astBuilder->create(); + ExtractExistentialValueExpr* openedValue = + m_astBuilder->create(); openedValue->declRef = varDeclRef; openedValue->type = QualType(openedType); openedValue->originalExpr = expr; @@ -169,7 +171,7 @@ namespace Slang // The result of opening an existential is an l-value // if the original existential is an l-value. // - if(expr->type.isLeftValue) + if (expr->type.isLeftValue) { // Marking the opened value as an l-value is the easy part. // @@ -188,394 +190,399 @@ namespace Slang return openedValue; }); - } +} - /// If `expr` has existential type, then open it. - /// - /// Returns an expression that opens `expr` if it had existential type. - /// Otherwise returns `expr` itself. - /// - /// See `openExistential` for a discussion of what "opening" an - /// existential-type value means. - /// - Expr* SemanticsVisitor::maybeOpenExistential(Expr* expr) - { - auto exprType = expr->type.type; +/// If `expr` has existential type, then open it. +/// +/// Returns an expression that opens `expr` if it had existential type. +/// Otherwise returns `expr` itself. +/// +/// See `openExistential` for a discussion of what "opening" an +/// existential-type value means. +/// +Expr* SemanticsVisitor::maybeOpenExistential(Expr* expr) +{ + auto exprType = expr->type.type; - if(auto declRefType = as(exprType)) + if (auto declRefType = as(exprType)) + { + if (auto interfaceDeclRef = declRefType->getDeclRef().as()) { - if(auto interfaceDeclRef = declRefType->getDeclRef().as()) - { - return openExistential(expr, interfaceDeclRef); - } + return openExistential(expr, interfaceDeclRef); } - - // Default: apply the callback to the original expression; - return expr; } - Expr* SemanticsVisitor::maybeOpenRef(Expr* expr) - { - auto exprType = expr->type.type; + // Default: apply the callback to the original expression; + return expr; +} - if (auto refType = as(exprType)) - { - auto openRef = m_astBuilder->create(); - openRef->innerExpr = expr; - openRef->type.isLeftValue = (as(exprType) != nullptr); - openRef->type.type = refType->getValueType(); - return openRef; - } - return expr; +Expr* SemanticsVisitor::maybeOpenRef(Expr* expr) +{ + auto exprType = expr->type.type; + + if (auto refType = as(exprType)) + { + auto openRef = m_astBuilder->create(); + openRef->innerExpr = expr; + openRef->type.isLeftValue = (as(exprType) != nullptr); + openRef->type.type = refType->getValueType(); + return openRef; } + return expr; +} - Scope* SemanticsVisitor::getScope(SyntaxNode* node) +Scope* SemanticsVisitor::getScope(SyntaxNode* node) +{ + while (auto declBase = as(node)) { - while (auto declBase = as(node)) + if (auto container = as(node)) { - if (auto container = as(node)) - { - if (container->ownedScope) - return container->ownedScope; - } - node = declBase->parentDecl; + if (container->ownedScope) + return container->ownedScope; } - return nullptr; + node = declBase->parentDecl; } + return nullptr; +} - static SourceLoc _getMemberOpLoc(Expr* expr) - { - if (auto m = as(expr)) - return m->memberOperatorLoc; - if (auto m = as(expr)) - return m->memberOperatorLoc; - return SourceLoc(); - } +static SourceLoc _getMemberOpLoc(Expr* expr) +{ + if (auto m = as(expr)) + return m->memberOperatorLoc; + if (auto m = as(expr)) + return m->memberOperatorLoc; + return SourceLoc(); +} - void addSiblingScopeForContainerDecl(ASTBuilder* builder, ContainerDecl* dest, ContainerDecl* source) - { - addSiblingScopeForContainerDecl(builder, dest->ownedScope, source); - } +void addSiblingScopeForContainerDecl( + ASTBuilder* builder, + ContainerDecl* dest, + ContainerDecl* source) +{ + addSiblingScopeForContainerDecl(builder, dest->ownedScope, source); +} - void addSiblingScopeForContainerDecl(ASTBuilder* builder, Scope* destScope, ContainerDecl* source) - { - auto subScope = builder->create(); - subScope->containerDecl = source; +void addSiblingScopeForContainerDecl(ASTBuilder* builder, Scope* destScope, ContainerDecl* source) +{ + auto subScope = builder->create(); + subScope->containerDecl = source; - subScope->nextSibling = destScope->nextSibling; - destScope->nextSibling = subScope; - } + subScope->nextSibling = destScope->nextSibling; + destScope->nextSibling = subScope; +} - void SemanticsVisitor::diagnoseDeprecatedDeclRefUsage( - DeclRef declRef, - SourceLoc loc, - Expr* originalExpr) +void SemanticsVisitor::diagnoseDeprecatedDeclRefUsage( + DeclRef declRef, + SourceLoc loc, + Expr* originalExpr) +{ + // This is slightly subtle, because we don't want to warn more than + // once for the same occurrence, however in some cases this function is + // called more than once for the same declref (specifically in the case + // of a non-overloaded function, once when the function is identified at + // first, and again when it's checked from + // CheckInvokeExprWithCheckedOperands). + // + // The correct fix is probably to make + // CheckInvokeExprWithCheckedOperands reuse the original declref, + // however that doesn't appear to be a simple change. + // + // What we do instead is see if there's already been a declRef + // constructed for this expression and rest assured that it's already + // had a diagnostic emitted. + auto originalAppExpr = as(originalExpr); + auto originalAppFunDecl = + originalAppExpr ? as(originalAppExpr->functionExpr) : nullptr; + if (originalAppFunDecl && originalAppFunDecl->declRef) + { + return; + } + if (auto deprecatedAttr = declRef.getDecl()->findModifier()) { - // This is slightly subtle, because we don't want to warn more than - // once for the same occurrence, however in some cases this function is - // called more than once for the same declref (specifically in the case - // of a non-overloaded function, once when the function is identified at - // first, and again when it's checked from - // CheckInvokeExprWithCheckedOperands). - // - // The correct fix is probably to make - // CheckInvokeExprWithCheckedOperands reuse the original declref, - // however that doesn't appear to be a simple change. - // - // What we do instead is see if there's already been a declRef - // constructed for this expression and rest assured that it's already - // had a diagnostic emitted. - auto originalAppExpr = as(originalExpr); - auto originalAppFunDecl = originalAppExpr ? as(originalAppExpr->functionExpr) : nullptr; - if(originalAppFunDecl && originalAppFunDecl->declRef) - { - return; - } - if (auto deprecatedAttr = declRef.getDecl()->findModifier()) - { - getSink()->diagnose( - loc, - Diagnostics::deprecatedUsage, - declRef.getName(), - deprecatedAttr->message); - } + getSink()->diagnose( + loc, + Diagnostics::deprecatedUsage, + declRef.getName(), + deprecatedAttr->message); } +} - static bool isMutableGLSLBufferBlockVarExpr(Expr* expr) - { - const auto derefExpr = as(expr); - if(!derefExpr) - return false; - const auto varExpr = as(derefExpr->base); - // Check the declaration type - if(!varExpr) - return false; +static bool isMutableGLSLBufferBlockVarExpr(Expr* expr) +{ + const auto derefExpr = as(expr); + if (!derefExpr) + return false; + const auto varExpr = as(derefExpr->base); + // Check the declaration type + if (!varExpr) + return false; - const auto varExprType = varExpr->type->getCanonicalType(); - const auto ssbt = as(varExprType); - if(!ssbt) - return false; + const auto varExprType = varExpr->type->getCanonicalType(); + const auto ssbt = as(varExprType); + if (!ssbt) + return false; - // Check the modifiers on the declaration - const auto d = varExpr->declRef.getDecl(); - auto collection = d->findModifier(); - if(collection && collection->getMemoryQualifierBit() & MemoryQualifierSetModifier::Flags::kReadOnly) - return false; + // Check the modifiers on the declaration + const auto d = varExpr->declRef.getDecl(); + auto collection = d->findModifier(); + if (collection && + collection->getMemoryQualifierBit() & MemoryQualifierSetModifier::Flags::kReadOnly) + return false; - return true; - } + return true; +} - DeclRefExpr* SemanticsVisitor::ConstructDeclRefExpr( - DeclRef declRef, - Expr* baseExpr, - Name* name, - SourceLoc loc, - Expr* originalExpr) - { - // Compute the type that this declaration reference will have in context. - // - auto type = GetTypeForDeclRef(declRef, loc); +DeclRefExpr* SemanticsVisitor::ConstructDeclRefExpr( + DeclRef declRef, + Expr* baseExpr, + Name* name, + SourceLoc loc, + Expr* originalExpr) +{ + // Compute the type that this declaration reference will have in context. + // + auto type = GetTypeForDeclRef(declRef, loc); - // This is the bottleneck for using declarations which might be - // deprecated, diagnose here. - diagnoseDeprecatedDeclRefUsage(declRef, loc, originalExpr); + // This is the bottleneck for using declarations which might be + // deprecated, diagnose here. + diagnoseDeprecatedDeclRefUsage(declRef, loc, originalExpr); + + // Construct an appropriate expression based on the structured of + // the declaration reference. + // + if (baseExpr) + { + // If there was a base expression, we will have some kind of + // member expression. - // Construct an appropriate expression based on the structured of - // the declaration reference. + // We want to check for the case where the base "expression" + // actually names a type, because in that case we are doing + // a static member reference. // - if (baseExpr) + if (auto typeType = as(baseExpr->type->getCanonicalType())) { - // If there was a base expression, we will have some kind of - // member expression. - - // We want to check for the case where the base "expression" - // actually names a type, because in that case we are doing - // a static member reference. + // Before forming the reference, we will check if the + // member being referenced can even be used as a static + // member, and if not we will diagnose an error. // - if (auto typeType = as(baseExpr->type->getCanonicalType())) - { - // Before forming the reference, we will check if the - // member being referenced can even be used as a static - // member, and if not we will diagnose an error. - // - // TODO: It is conceptually possible to allow static - // references to many instance members, provided we - // change the exposed type/signature. - // - // E.g., if we have: - // - // struct Test { float getVal() { ... } } - // - // Then a reference to `Test.getVal` could be allowed, - // and given a type of `(Test) -> float` to indicate - // that it is an "unbound" instance method. - // - if( !isDeclUsableAsStaticMember(declRef.getDecl()) ) - { - getSink()->diagnose( - loc, - Diagnostics::staticRefToNonStaticMember, - typeType->getType(), - declRef.getName()); - } - - auto expr = m_astBuilder->create(); - expr->loc = loc; - expr->type = type; - expr->baseExpression = baseExpr; - expr->name = name; - expr->declRef = declRef; - expr->memberOperatorLoc = _getMemberOpLoc(originalExpr); - return expr; - } - else if(isEffectivelyStatic(declRef.getDecl())) + // TODO: It is conceptually possible to allow static + // references to many instance members, provided we + // change the exposed type/signature. + // + // E.g., if we have: + // + // struct Test { float getVal() { ... } } + // + // Then a reference to `Test.getVal` could be allowed, + // and given a type of `(Test) -> float` to indicate + // that it is an "unbound" instance method. + // + if (!isDeclUsableAsStaticMember(declRef.getDecl())) { - // Extract the type of the baseExpr - auto baseExprType = baseExpr->type.type; - SharedTypeExpr* baseTypeExpr = m_astBuilder->create(); - baseTypeExpr->base.type = baseExprType; - baseTypeExpr->type.type = m_astBuilder->getTypeType(baseExprType); - - auto expr = m_astBuilder->create(); - expr->loc = loc; - expr->type = type; - expr->baseExpression = baseTypeExpr; - expr->name = name; - expr->declRef = declRef; - expr->memberOperatorLoc = _getMemberOpLoc(originalExpr); - return expr; + getSink()->diagnose( + loc, + Diagnostics::staticRefToNonStaticMember, + typeType->getType(), + declRef.getName()); } - else - { - // If the base expression wasn't a type, then this - // is a normal member expression. - // - auto expr = m_astBuilder->create(); - expr->loc = loc; - expr->type = type; - expr->baseExpression = baseExpr; - expr->name = name; - expr->declRef = declRef; - expr->memberOperatorLoc = _getMemberOpLoc(originalExpr); - - // If any member declares the following value is a - // write only, we must declare the parent as a write - // only to avoid modifying the child - expr->type.isWriteOnly = baseExpr->type.isWriteOnly || expr->type.isWriteOnly; - - // When referring to a member through an expression, - // the result is only an l-value if both the base - // expression and the member agree that it should be. - // - // We have already used the `QualType` from the member - // above (that is `type`), so we need to take the - // l-value status of the base expression into account now. - if(!baseExpr->type.isLeftValue) + + auto expr = m_astBuilder->create(); + expr->loc = loc; + expr->type = type; + expr->baseExpression = baseExpr; + expr->name = name; + expr->declRef = declRef; + expr->memberOperatorLoc = _getMemberOpLoc(originalExpr); + return expr; + } + else if (isEffectivelyStatic(declRef.getDecl())) + { + // Extract the type of the baseExpr + auto baseExprType = baseExpr->type.type; + SharedTypeExpr* baseTypeExpr = m_astBuilder->create(); + baseTypeExpr->base.type = baseExprType; + baseTypeExpr->type.type = m_astBuilder->getTypeType(baseExprType); + + auto expr = m_astBuilder->create(); + expr->loc = loc; + expr->type = type; + expr->baseExpression = baseTypeExpr; + expr->name = name; + expr->declRef = declRef; + expr->memberOperatorLoc = _getMemberOpLoc(originalExpr); + return expr; + } + else + { + // If the base expression wasn't a type, then this + // is a normal member expression. + // + auto expr = m_astBuilder->create(); + expr->loc = loc; + expr->type = type; + expr->baseExpression = baseExpr; + expr->name = name; + expr->declRef = declRef; + expr->memberOperatorLoc = _getMemberOpLoc(originalExpr); + + // If any member declares the following value is a + // write only, we must declare the parent as a write + // only to avoid modifying the child + expr->type.isWriteOnly = baseExpr->type.isWriteOnly || expr->type.isWriteOnly; + + // When referring to a member through an expression, + // the result is only an l-value if both the base + // expression and the member agree that it should be. + // + // We have already used the `QualType` from the member + // above (that is `type`), so we need to take the + // l-value status of the base expression into account now. + if (!baseExpr->type.isLeftValue) + { + // One exception to this is if we're reading the contents + // of a GLSL buffer interface block which isn't marked as + // read_only + expr->type.isLeftValue = isMutableGLSLBufferBlockVarExpr(baseExpr) && + (expr->type.hasReadOnlyOnTarget == false); + + // Another exception is if we are accessing a property + // that provides a [nonmutating] setter. + if (!expr->type.isLeftValue && as(declRef.getDecl())) { - // One exception to this is if we're reading the contents - // of a GLSL buffer interface block which isn't marked as - // read_only - expr->type.isLeftValue = isMutableGLSLBufferBlockVarExpr(baseExpr) && (expr->type.hasReadOnlyOnTarget == false); - - // Another exception is if we are accessing a property - // that provides a [nonmutating] setter. - if (!expr->type.isLeftValue && - as(declRef.getDecl())) + bool isLValue = false; + for (auto member : as(declRef.getDecl())->members) { - bool isLValue = false; - for (auto member : as(declRef.getDecl())->members) + if (as(member) || as(member)) { - if (as(member) || as< RefAccessorDecl>(member)) + if (member->findModifier()) { - if (member->findModifier()) - { - isLValue = true; - } - break; + isLValue = true; } + break; } - expr->type.isLeftValue = isLValue; } + expr->type.isLeftValue = isLValue; } - else + } + else + { + // If we are accessing a readonly property, then the result + // is not an l-value. + if (auto propertyDecl = as(declRef.getDecl())) { - // If we are accessing a readonly property, then the result - // is not an l-value. - if (auto propertyDecl = as(declRef.getDecl())) + bool isLValue = false; + for (auto member : propertyDecl->members) { - bool isLValue = false; - for (auto member : propertyDecl->members) + if (as(member) || as(member)) { - if (as(member) || as< RefAccessorDecl>(member)) - { - isLValue = true; - break; - } + isLValue = true; + break; } - expr->type.isLeftValue = isLValue; } + expr->type.isLeftValue = isLValue; } - return expr; } + return expr; } - else + } + else + { + // If there is no base expression, then the result must + // be an ordinary variable expression. + // + auto expr = m_astBuilder->create(); + expr->loc = loc; + expr->name = name; + expr->type = type; + expr->declRef = declRef; + // Keep a reference to the original expr if it was a genericApp/member. + // This is needed by the language server to locate the original tokens. + if (as(originalExpr) || as(originalExpr) || + as(originalExpr)) { - // If there is no base expression, then the result must - // be an ordinary variable expression. - // - auto expr = m_astBuilder->create(); - expr->loc = loc; - expr->name = name; - expr->type = type; - expr->declRef = declRef; - // Keep a reference to the original expr if it was a genericApp/member. - // This is needed by the language server to locate the original tokens. - if (as(originalExpr) || as(originalExpr) || as(originalExpr)) - { - expr->originalExpr = originalExpr; - } - return expr; + expr->originalExpr = originalExpr; } + return expr; } +} - Expr* SemanticsVisitor::ConstructDerefExpr( - Expr* base, - SourceLoc loc) - { - auto elementType = getPointedToTypeIfCanImplicitDeref(base->type); - SLANG_ASSERT(elementType); +Expr* SemanticsVisitor::ConstructDerefExpr(Expr* base, SourceLoc loc) +{ + auto elementType = getPointedToTypeIfCanImplicitDeref(base->type); + SLANG_ASSERT(elementType); - auto derefExpr = m_astBuilder->create(); - derefExpr->loc = loc; - derefExpr->base = base; - derefExpr->type = QualType(elementType); + auto derefExpr = m_astBuilder->create(); + derefExpr->loc = loc; + derefExpr->base = base; + derefExpr->type = QualType(elementType); - if (as(base->type)) - derefExpr->type.isLeftValue = true; - else - derefExpr->type.isLeftValue = base->type.isLeftValue; + if (as(base->type)) + derefExpr->type.isLeftValue = true; + else + derefExpr->type.isLeftValue = base->type.isLeftValue; - return derefExpr; - } + return derefExpr; +} - InvokeExpr* SemanticsVisitor::constructUncheckedInvokeExpr(Expr* callee, const List& arguments) - { - auto result = m_astBuilder->create(); - result->loc = callee->loc; - result->functionExpr = callee; - result->arguments.addRange(arguments); - return result; - } +InvokeExpr* SemanticsVisitor::constructUncheckedInvokeExpr( + Expr* callee, + const List& arguments) +{ + auto result = m_astBuilder->create(); + result->loc = callee->loc; + result->functionExpr = callee; + result->arguments.addRange(arguments); + return result; +} - Expr* SemanticsVisitor::maybeUseSynthesizedDeclForLookupResult( - LookupResultItem const& item, - Expr* originalExpr) - { - // If the only result from lookup is an entry in an interface decl, it could be that - // the user is leaving out an explicit definition for the requirement and depending on - // the compiler to synthesis the definition. - // In this case, if the lookup is triggered from a location such that the satisfying - // definition should be returned should it existed, we should create a placeholder decl for - // the definition and return a reference to to newly created decl instead of the requirement - // decl in the interface. - switch (item.declRef.getDecl()->astNodeType) - { - case ASTNodeType::AssocTypeDecl: - break; - case ASTNodeType::FuncDecl: - // We don't need to intercept lookup results with synthesized decls for methods, - // because function lookups will only take place when we are checking the decl bodies. - // At that point conformance check and synthesis is already done so they will always resolve - // to the synthesized method. - return nullptr; - default: - return nullptr; - } +Expr* SemanticsVisitor::maybeUseSynthesizedDeclForLookupResult( + LookupResultItem const& item, + Expr* originalExpr) +{ + // If the only result from lookup is an entry in an interface decl, it could be that + // the user is leaving out an explicit definition for the requirement and depending on + // the compiler to synthesis the definition. + // In this case, if the lookup is triggered from a location such that the satisfying + // definition should be returned should it existed, we should create a placeholder decl for + // the definition and return a reference to to newly created decl instead of the requirement + // decl in the interface. + switch (item.declRef.getDecl()->astNodeType) + { + case ASTNodeType::AssocTypeDecl: break; + case ASTNodeType::FuncDecl: + // We don't need to intercept lookup results with synthesized decls for methods, + // because function lookups will only take place when we are checking the decl bodies. + // At that point conformance check and synthesis is already done so they will always + // resolve to the synthesized method. + return nullptr; + default: return nullptr; + } - // We need to check if the lookup should resolve to a definition in an implementation type - // if it existed. - // This will be the case when the lookup is initiated from the concrete implementation type instead of - // directly from the Interface decl. The breadcrumbs of the lookup should provide this information. + // We need to check if the lookup should resolve to a definition in an implementation type + // if it existed. + // This will be the case when the lookup is initiated from the concrete implementation type + // instead of directly from the Interface decl. The breadcrumbs of the lookup should provide + // this information. - // If no breadcrumbs existed, then the lookup should just resolve to the interface requirement. + // If no breadcrumbs existed, then the lookup should just resolve to the interface requirement. - if (!item.breadcrumbs) - return nullptr; + if (!item.breadcrumbs) + return nullptr; - // We will only ever need to synthesis a type to satisfy an associatedtype requirement. - // In this case the lookup should have resolved to a known associatedtype decl. - auto builtinAssocTypeAttr = item.declRef.getDecl()->findModifier(); - if (!builtinAssocTypeAttr) - return nullptr; + // We will only ever need to synthesis a type to satisfy an associatedtype requirement. + // In this case the lookup should have resolved to a known associatedtype decl. + auto builtinAssocTypeAttr = item.declRef.getDecl()->findModifier(); + if (!builtinAssocTypeAttr) + return nullptr; - DeclRefType* subType = nullptr; + DeclRefType* subType = nullptr; - // Check if we are reaching the associated type decl through inheritance from a concrete type. - for (auto breadcrumb = item.breadcrumbs; breadcrumb; breadcrumb = breadcrumb->next) + // Check if we are reaching the associated type decl through inheritance from a concrete type. + for (auto breadcrumb = item.breadcrumbs; breadcrumb; breadcrumb = breadcrumb->next) + { + switch (breadcrumb->kind) { - switch (breadcrumb->kind) - { - case LookupResultItem::Breadcrumb::Kind::SuperType: + case LookupResultItem::Breadcrumb::Kind::SuperType: { auto witness = as(breadcrumb->val); if (auto subDeclRefType = as(witness->getSub())) @@ -588,4201 +595,4373 @@ namespace Slang } } break; - default: - break; - } + default: break; } - if (!subType) - return nullptr; + } + if (!subType) + return nullptr; - subType = as(subType->getCanonicalType()); - if (!subType) - return nullptr; + subType = as(subType->getCanonicalType()); + if (!subType) + return nullptr; - // Don't synthesize for generic parameters. - auto parent = as(subType->getDeclRef().getDecl()); - if (!parent) - return nullptr; + // Don't synthesize for generic parameters. + auto parent = as(subType->getDeclRef().getDecl()); + if (!parent) + return nullptr; - // Don't synthesize for ThisType. - if (as(subType->getDeclRef().getDecl())) - return nullptr; - - // If the inner most subtype is itself an associated type, then we're dealing - // with an abstract type. There's not need to synthesize anythin at this point. - // - if (as(subType->getDeclRef().getDecl())) - return nullptr; + // Don't synthesize for ThisType. + if (as(subType->getDeclRef().getDecl())) + return nullptr; + + // If the inner most subtype is itself an associated type, then we're dealing + // with an abstract type. There's not need to synthesize anythin at this point. + // + if (as(subType->getDeclRef().getDecl())) + return nullptr; - // If we reach here, we are expecting a synthesized decl defined in `subType`. - // Instead of returning a DeclRefExpr to the requirement decl, we synthesize a placeholder decl - // in `subType` and return a DeclRefExpr to the synthesized decl. + // If we reach here, we are expecting a synthesized decl defined in `subType`. + // Instead of returning a DeclRefExpr to the requirement decl, we synthesize a placeholder decl + // in `subType` and return a DeclRefExpr to the synthesized decl. - Decl* synthesizedDecl = nullptr; - switch (builtinAssocTypeAttr->kind) + Decl* synthesizedDecl = nullptr; + switch (builtinAssocTypeAttr->kind) + { + case BuiltinRequirementKind::DifferentialType: { - case BuiltinRequirementKind::DifferentialType: + if (!canStructBeUsedAsSelfDifferentialType(parent)) { - if (!canStructBeUsedAsSelfDifferentialType(parent)) - { - // Need to create a new struct type for the differential. - // - auto structDecl = m_astBuilder->create(); - auto conformanceDecl = m_astBuilder->create(); - conformanceDecl->base.type = m_astBuilder->getDiffInterfaceType(); - conformanceDecl->parentDecl = structDecl; - structDecl->members.add(conformanceDecl); - structDecl->parentDecl = parent; - - synthesizedDecl = structDecl; - auto typeDef = m_astBuilder->create(); - typeDef->nameAndLoc.name = getName("Differential"); - typeDef->parentDecl = structDecl; - - auto synthDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl)); - - typeDef->type.type = DeclRefType::create(m_astBuilder, synthDeclRef); - structDecl->members.add(typeDef); - - synthesizedDecl->parentDecl = parent; - synthesizedDecl->nameAndLoc.name = item.declRef.getName(); - synthesizedDecl->loc = parent->loc; - parent->members.add(synthesizedDecl); - parent->invalidateMemberDictionary(); - - // Mark the newly synthesized decl as `ToBeSynthesized` so future checking can differentiate it - // from user-provided definitions, and proceed to fill in its definition. - auto toBeSynthesized = m_astBuilder->create(); - addModifier(synthesizedDecl, toBeSynthesized); - } - else - { - // There's no need for a new struct decl. - // We can simply add a typealias to the existing concrete type. - // - auto typeDef = m_astBuilder->create(); - typeDef->nameAndLoc.name = item.declRef.getName(); - typeDef->parentDecl = parent; - - // Compute the decl's type as if it is referred to from itself. This is important because - // subType may have substitutions from the context it is used in, while this synthesis step - // is local to the decl. - // - typeDef->type.type = calcThisType(subType->getDeclRef().getDecl()->getDefaultDeclRef()); - - synthesizedDecl = parent; - - parent->members.add(typeDef); - parent->invalidateMemberDictionary(); - - markSelfDifferentialMembersOfType(parent, subType); - } + // Need to create a new struct type for the differential. + // + auto structDecl = m_astBuilder->create(); + auto conformanceDecl = m_astBuilder->create(); + conformanceDecl->base.type = m_astBuilder->getDiffInterfaceType(); + conformanceDecl->parentDecl = structDecl; + structDecl->members.add(conformanceDecl); + structDecl->parentDecl = parent; + + synthesizedDecl = structDecl; + auto typeDef = m_astBuilder->create(); + typeDef->nameAndLoc.name = getName("Differential"); + typeDef->parentDecl = structDecl; + + auto synthDeclRef = + createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl)); + + typeDef->type.type = DeclRefType::create(m_astBuilder, synthDeclRef); + structDecl->members.add(typeDef); + + synthesizedDecl->parentDecl = parent; + synthesizedDecl->nameAndLoc.name = item.declRef.getName(); + synthesizedDecl->loc = parent->loc; + parent->members.add(synthesizedDecl); + parent->invalidateMemberDictionary(); + + // Mark the newly synthesized decl as `ToBeSynthesized` so future checking can + // differentiate it from user-provided definitions, and proceed to fill in its + // definition. + auto toBeSynthesized = m_astBuilder->create(); + addModifier(synthesizedDecl, toBeSynthesized); } - break; - default: - return nullptr; - } + else + { + // There's no need for a new struct decl. + // We can simply add a typealias to the existing concrete type. + // + auto typeDef = m_astBuilder->create(); + typeDef->nameAndLoc.name = item.declRef.getName(); + typeDef->parentDecl = parent; - auto synthDeclMemberRef = m_astBuilder->getMemberDeclRef(subType->getDeclRef(), synthesizedDecl); - return ConstructDeclRefExpr( - synthDeclMemberRef, - nullptr, - item.declRef.getName(), - originalExpr ? originalExpr->loc : SourceLoc(), - originalExpr); - } + // Compute the decl's type as if it is referred to from itself. This is important + // because subType may have substitutions from the context it is used in, while this + // synthesis step is local to the decl. + // + typeDef->type.type = + calcThisType(subType->getDeclRef().getDecl()->getDefaultDeclRef()); - Expr* SemanticsVisitor::ConstructLookupResultExpr( - LookupResultItem const& item, - Expr* baseExpr, - Name* name, - SourceLoc loc, - Expr* originalExpr) - { - if (!item.declRef) - { - originalExpr->type = QualType(m_astBuilder->getErrorType()); - return originalExpr; + synthesizedDecl = parent; + + parent->members.add(typeDef); + parent->invalidateMemberDictionary(); + + markSelfDifferentialMembersOfType(parent, subType); + } } + break; + default: return nullptr; + } - // We could be referencing a decl that will be synthesized. If so create a placeholder - // and return a DeclRefExpr to it. - if (auto lookupResultExpr = maybeUseSynthesizedDeclForLookupResult(item, originalExpr)) - return lookupResultExpr; + auto synthDeclMemberRef = + m_astBuilder->getMemberDeclRef(subType->getDeclRef(), synthesizedDecl); + return ConstructDeclRefExpr( + synthDeclMemberRef, + nullptr, + item.declRef.getName(), + originalExpr ? originalExpr->loc : SourceLoc(), + originalExpr); +} - // If we collected any breadcrumbs, then these represent - // additional segments of the lookup path that we need - // to expand here. - auto bb = baseExpr; - for (auto breadcrumb = item.breadcrumbs; breadcrumb; breadcrumb = breadcrumb->next) +Expr* SemanticsVisitor::ConstructLookupResultExpr( + LookupResultItem const& item, + Expr* baseExpr, + Name* name, + SourceLoc loc, + Expr* originalExpr) +{ + if (!item.declRef) + { + originalExpr->type = QualType(m_astBuilder->getErrorType()); + return originalExpr; + } + + // We could be referencing a decl that will be synthesized. If so create a placeholder + // and return a DeclRefExpr to it. + if (auto lookupResultExpr = maybeUseSynthesizedDeclForLookupResult(item, originalExpr)) + return lookupResultExpr; + + // If we collected any breadcrumbs, then these represent + // additional segments of the lookup path that we need + // to expand here. + auto bb = baseExpr; + for (auto breadcrumb = item.breadcrumbs; breadcrumb; breadcrumb = breadcrumb->next) + { + switch (breadcrumb->kind) { - switch (breadcrumb->kind) - { - case LookupResultItem::Breadcrumb::Kind::Member: - bb = ConstructDeclRefExpr(breadcrumb->declRef, bb, name, loc, originalExpr); - break; + case LookupResultItem::Breadcrumb::Kind::Member: + bb = ConstructDeclRefExpr(breadcrumb->declRef, bb, name, loc, originalExpr); + break; - case LookupResultItem::Breadcrumb::Kind::Deref: - bb = ConstructDerefExpr(bb, loc); - break; + case LookupResultItem::Breadcrumb::Kind::Deref: bb = ConstructDerefExpr(bb, loc); break; - case LookupResultItem::Breadcrumb::Kind::SuperType: + case LookupResultItem::Breadcrumb::Kind::SuperType: + { + // Note: a lookup through a super-type can + // occur even in the case of a `static` member, + // so we only modify the base expression here + // if there is one. + // + if (bb) { - // Note: a lookup through a super-type can - // occur even in the case of a `static` member, - // so we only modify the base expression here - // if there is one. + // We know that the breadcrumb reprsents a + // cast of the base expression to a super type, + // so we construct that cast explicitly here. + // + auto witness = as(breadcrumb->val); + SLANG_ASSERT(witness); + auto expr = createCastToSuperTypeExpr(witness->getSup(), bb, witness); + + // Note that we allow a cast of an l-value to + // be used as an l-value here because it enables + // `[mutating]` methods to be called, and + // mutable properties to be modified, but this + // is probably not *technically* correct, since + // treating an l-value of type `Derived` as + // an l-value of type `Base` implies that we + // can assign an arbitrary value of type `Base` + // to that l-value (which would be an error). // - if( bb ) + // TODO: make sure we believe there are no + // issues here. + // + if (bb && bb->type.isLeftValue) { - // We know that the breadcrumb reprsents a - // cast of the base expression to a super type, - // so we construct that cast explicitly here. - // - auto witness = as(breadcrumb->val); - SLANG_ASSERT(witness); - auto expr = createCastToSuperTypeExpr(witness->getSup(), bb, witness); - - // Note that we allow a cast of an l-value to - // be used as an l-value here because it enables - // `[mutating]` methods to be called, and - // mutable properties to be modified, but this - // is probably not *technically* correct, since - // treating an l-value of type `Derived` as - // an l-value of type `Base` implies that we - // can assign an arbitrary value of type `Base` - // to that l-value (which would be an error). - // - // TODO: make sure we believe there are no - // issues here. - // - if(bb && bb->type.isLeftValue) - { - expr->type.isLeftValue = true; - } - - bb = expr; + expr->type.isLeftValue = true; } + + bb = expr; } - break; + } + break; - case LookupResultItem::Breadcrumb::Kind::This: - { - // We expect a `this` to always come - // at the start of a chain. - SLANG_ASSERT(bb == nullptr); + case LookupResultItem::Breadcrumb::Kind::This: + { + // We expect a `this` to always come + // at the start of a chain. + SLANG_ASSERT(bb == nullptr); - // We will compute the type to use for `This` using - // the same logic that a direct reference to `This` - // uses. - // - auto thisType = calcThisType(breadcrumb->declRef); + // We will compute the type to use for `This` using + // the same logic that a direct reference to `This` + // uses. + // + auto thisType = calcThisType(breadcrumb->declRef); - // Next we construct an appropriate expression to - // stand in for the implicit `this` or `This` reference. + // Next we construct an appropriate expression to + // stand in for the implicit `this` or `This` reference. + // + // The lookup process will have computed the appropriate + // "mode" to use for the implicit `this` or `This`. + // + auto thisParameterMode = breadcrumb->thisParameterMode; + if (thisParameterMode == LookupResultItem::Breadcrumb::ThisParameterMode::Type) + { + // If we are in a static context, then we do not + // have implicit `this` expression, and the expression + // we construct will need to start with the `This` + // type. // - // The lookup process will have computed the appropriate - // "mode" to use for the implicit `this` or `This`. + // Because we are constrained to yield an expression + // here, we must construct an expression that + // references `This`, and the *type* of that expression + // will be `typeof(This)`, which conceptually + // `typeof(typeof(this))` // - auto thisParameterMode = breadcrumb->thisParameterMode; - if(thisParameterMode == LookupResultItem::Breadcrumb::ThisParameterMode::Type) - { - // If we are in a static context, then we do not - // have implicit `this` expression, and the expression - // we construct will need to start with the `This` - // type. - // - // Because we are constrained to yield an expression - // here, we must construct an expression that - // references `This`, and the *type* of that expression - // will be `typeof(This)`, which conceptually - // `typeof(typeof(this))` - // - auto thisTypeType = m_astBuilder->getTypeType(thisType); + auto thisTypeType = m_astBuilder->getTypeType(thisType); - auto typeExpr = m_astBuilder->create(); - typeExpr->type.type = thisTypeType; - typeExpr->base.type = thisType; + auto typeExpr = m_astBuilder->create(); + typeExpr->type.type = thisTypeType; + typeExpr->base.type = thisType; - bb = typeExpr; - } - else + bb = typeExpr; + } + else + { + // In a context where both static and instance members can + // be referenced, we will construct a reference to `this`, + // and then rely on downstream logic to ensure that a + // refernece to `this.someStaticMember` will be translated + // over to `This.someStaticMember`. + // + ThisExpr* expr = m_astBuilder->create(); + expr->type.type = thisType; + expr->loc = loc; + if (auto declRefExpr = as(originalExpr)) + expr->scope = declRefExpr->scope; + else if (auto invokeExpr = as(originalExpr)) { - // In a context where both static and instance members can - // be referenced, we will construct a reference to `this`, - // and then rely on downstream logic to ensure that a - // refernece to `this.someStaticMember` will be translated - // over to `This.someStaticMember`. - // - ThisExpr* expr = m_astBuilder->create(); - expr->type.type = thisType; - expr->loc = loc; - if (auto declRefExpr = as(originalExpr)) - expr->scope = declRefExpr->scope; - else if (auto invokeExpr = as(originalExpr)) - { - if (auto calleeDeclRefExpr = as(invokeExpr->originalFunctionExpr)) - expr->scope = calleeDeclRefExpr->scope; - } - // Whether or not the implicit `this` is mutable depends - // on the context in which it is used, and the lookup - // logic will have computed an appropriate "mode" based - // on the context during lookup. - // - expr->type.isLeftValue = thisParameterMode == LookupResultItem::Breadcrumb::ThisParameterMode::MutableValue; - - bb = expr; + if (auto calleeDeclRefExpr = + as(invokeExpr->originalFunctionExpr)) + expr->scope = calleeDeclRefExpr->scope; } - } - break; + // Whether or not the implicit `this` is mutable depends + // on the context in which it is used, and the lookup + // logic will have computed an appropriate "mode" based + // on the context during lookup. + // + expr->type.isLeftValue = + thisParameterMode == + LookupResultItem::Breadcrumb::ThisParameterMode::MutableValue; - default: - SLANG_UNREACHABLE("all cases handle"); + bb = expr; + } } - if (getShared()->isInLanguageServer()) + break; + + default: SLANG_UNREACHABLE("all cases handle"); + } + if (getShared()->isInLanguageServer()) + { + // Don't make breadcrumb nodes carry any source loc info, + // as they may confuse language server functionalities. + if (bb) { - // Don't make breadcrumb nodes carry any source loc info, - // as they may confuse language server functionalities. - if (bb) - { - bb->loc = SourceLoc(); - } + bb->loc = SourceLoc(); } } - - return ConstructDeclRefExpr(item.declRef, bb, name, loc, originalExpr); } - void SemanticsVisitor::suggestCompletionItems( - CompletionSuggestions::ScopeKind scopeKind, LookupResult const& lookupResult) + return ConstructDeclRefExpr(item.declRef, bb, name, loc, originalExpr); +} + +void SemanticsVisitor::suggestCompletionItems( + CompletionSuggestions::ScopeKind scopeKind, + LookupResult const& lookupResult) +{ + auto& suggestions = getLinkage()->contentAssistInfo.completionSuggestions; + suggestions.clear(); + suggestions.scopeKind = scopeKind; + for (auto item : lookupResult) { - auto& suggestions = getLinkage()->contentAssistInfo.completionSuggestions; - suggestions.clear(); - suggestions.scopeKind = scopeKind; - for (auto item : lookupResult) - { - suggestions.candidateItems.add(item); - } + suggestions.candidateItems.add(item); } +} - Expr* SemanticsVisitor::createLookupResultExpr( - Name* name, - LookupResult const& lookupResult, - Expr* baseExpr, - SourceLoc loc, - Expr* originalExpr) +Expr* SemanticsVisitor::createLookupResultExpr( + Name* name, + LookupResult const& lookupResult, + Expr* baseExpr, + SourceLoc loc, + Expr* originalExpr) +{ + if (lookupResult.isOverloaded()) + { + auto overloadedExpr = m_astBuilder->create(); + overloadedExpr->name = name; + overloadedExpr->loc = loc; + overloadedExpr->type = QualType(m_astBuilder->getOverloadedType()); + overloadedExpr->base = baseExpr; + overloadedExpr->lookupResult2 = lookupResult; + overloadedExpr->originalExpr = originalExpr; + return overloadedExpr; + } + else { - if (lookupResult.isOverloaded()) - { - auto overloadedExpr = m_astBuilder->create(); - overloadedExpr->name = name; - overloadedExpr->loc = loc; - overloadedExpr->type = QualType( - m_astBuilder->getOverloadedType()); - overloadedExpr->base = baseExpr; - overloadedExpr->lookupResult2 = lookupResult; - overloadedExpr->originalExpr = originalExpr; - return overloadedExpr; - } - else - { - return ConstructLookupResultExpr(lookupResult.item, baseExpr, name, loc, originalExpr); - } + return ConstructLookupResultExpr(lookupResult.item, baseExpr, name, loc, originalExpr); } +} - DeclVisibility SemanticsVisitor::getTypeVisibility(Type* type) +DeclVisibility SemanticsVisitor::getTypeVisibility(Type* type) +{ + if (auto declRefType = as(type)) { - if (auto declRefType = as(type)) + auto v = getDeclVisibility(declRefType->getDeclRef().getDecl()); + auto args = findInnerMostGenericArgs(SubstitutionSet(declRefType->getDeclRef())); + for (auto arg : args) { - auto v = getDeclVisibility(declRefType->getDeclRef().getDecl()); - auto args = findInnerMostGenericArgs(SubstitutionSet(declRefType->getDeclRef())); - for (auto arg : args) - { - if (auto typeArg = as(arg)) - v = Math::Min(v, getTypeVisibility(typeArg)); - } - return v; + if (auto typeArg = as(arg)) + v = Math::Min(v, getTypeVisibility(typeArg)); } - return DeclVisibility::Public; + return v; } + return DeclVisibility::Public; +} - bool SemanticsVisitor::isDeclVisibleFromScope(DeclRef declRef, Scope* scope) +bool SemanticsVisitor::isDeclVisibleFromScope(DeclRef declRef, Scope* scope) +{ + auto visibility = getDeclVisibility(declRef.getDecl()); + if (visibility == DeclVisibility::Public) + return true; + if (visibility == DeclVisibility::Internal) { - auto visibility = getDeclVisibility(declRef.getDecl()); - if (visibility == DeclVisibility::Public) + // Check that the decl is in the same module as the scope. + auto declModule = getModuleDecl(declRef.getDecl()); + if (declModule == getModuleDecl(scope)) return true; - if (visibility == DeclVisibility::Internal) + } + if (visibility == DeclVisibility::Private) + { + // Check that the decl is in the same or parent container decl as scope. + Decl* parentContainer = declRef.getDecl(); + for (; parentContainer; parentContainer = parentContainer->parentDecl) { - // Check that the decl is in the same module as the scope. - auto declModule = getModuleDecl(declRef.getDecl()); - if (declModule == getModuleDecl(scope)) - return true; + if (as(parentContainer)) + break; + if (as(parentContainer)) + break; } - if (visibility == DeclVisibility::Private) - { - // Check that the decl is in the same or parent container decl as scope. - Decl* parentContainer = declRef.getDecl(); - for (;parentContainer; parentContainer = parentContainer->parentDecl) - { - if (as(parentContainer)) - break; - if (as(parentContainer)) - break; - } - for (auto s = scope; s; s = s->parent) - { - if (s->containerDecl == parentContainer) - return true; - } - return false; + for (auto s = scope; s; s = s->parent) + { + if (s->containerDecl == parentContainer) + return true; } return false; } + return false; +} - LookupResult SemanticsVisitor::filterLookupResultByVisibility(const LookupResult& lookupResult) +LookupResult SemanticsVisitor::filterLookupResultByVisibility(const LookupResult& lookupResult) +{ + if (!m_outerScope) + return lookupResult; + LookupResult filteredResult; + for (auto item : lookupResult) { - if (!m_outerScope) - return lookupResult; - LookupResult filteredResult; - for (auto item : lookupResult) - { - if (isDeclVisibleFromScope(item.declRef, m_outerScope)) - AddToLookupResult(filteredResult, item); - } - return filteredResult; + if (isDeclVisibleFromScope(item.declRef, m_outerScope)) + AddToLookupResult(filteredResult, item); } + return filteredResult; +} - LookupResult SemanticsVisitor::filterLookupResultByVisibilityAndDiagnose(const LookupResult& lookupResult, SourceLoc loc, bool& outDiagnosed) +LookupResult SemanticsVisitor::filterLookupResultByVisibilityAndDiagnose( + const LookupResult& lookupResult, + SourceLoc loc, + bool& outDiagnosed) +{ + outDiagnosed = false; + auto result = filterLookupResultByVisibility(lookupResult); + if (lookupResult.isValid() && !result.isValid()) { - outDiagnosed = false; - auto result = filterLookupResultByVisibility(lookupResult); - if (lookupResult.isValid() && !result.isValid()) - { - getSink()->diagnose(loc, Diagnostics::declIsNotVisible, lookupResult.item.declRef); - outDiagnosed = true; + getSink()->diagnose(loc, Diagnostics::declIsNotVisible, lookupResult.item.declRef); + outDiagnosed = true; - if (getShared()->isInLanguageServer()) - { - // When in language server mode, return the unfiltered result so we can still - // provide language service around it. - return lookupResult; - } + if (getShared()->isInLanguageServer()) + { + // When in language server mode, return the unfiltered result so we can still + // provide language service around it. + return lookupResult; } - return result; } + return result; +} - LookupResult SemanticsVisitor::resolveOverloadedLookup(LookupResult const& inResult) - { - // If the result isn't actually overloaded, it is fine as-is - if (!inResult.isValid()) return inResult; - if (!inResult.isOverloaded()) return inResult; - - // We are going to build up a list of items to return. - List items; - for( auto item : inResult.items ) +LookupResult SemanticsVisitor::resolveOverloadedLookup(LookupResult const& inResult) +{ + // If the result isn't actually overloaded, it is fine as-is + if (!inResult.isValid()) + return inResult; + if (!inResult.isOverloaded()) + return inResult; + + // We are going to build up a list of items to return. + List items; + for (auto item : inResult.items) + { + // For each item we consider adding, we will compare it + // to those items we've already added. + // + // If any of the existing items is "better" than `item`, + // then we will skip adding `item`. + // + // If `item` is "better" than any of the existing items, + // we will remove those from `items`. + // + bool shouldAdd = true; + for (Index ii = 0; ii < items.getCount(); ++ii) { - // For each item we consider adding, we will compare it - // to those items we've already added. - // - // If any of the existing items is "better" than `item`, - // then we will skip adding `item`. - // - // If `item` is "better" than any of the existing items, - // we will remove those from `items`. - // - bool shouldAdd = true; - for( Index ii = 0; ii < items.getCount(); ++ii ) + int cmp = CompareLookupResultItems(item, items[ii]); + if (cmp < 0) { - int cmp = CompareLookupResultItems(item, items[ii]); - if( cmp < 0 ) - { - // The new `item` is strictly better - items.fastRemoveAt(ii); - --ii; - } - else if( cmp > 0 ) - { - // The existing item is strictly better - shouldAdd = false; - } + // The new `item` is strictly better + items.fastRemoveAt(ii); + --ii; } - if( shouldAdd ) + else if (cmp > 0) { - items.add(item); + // The existing item is strictly better + shouldAdd = false; } } - - // The resulting `items` list should be all those items - // that were neither better nor worse than one another. - // - // There should always be at least one such item. - // - SLANG_ASSERT(items.getCount() != 0); - - LookupResult result; - for( auto item : items ) + if (shouldAdd) { - AddToLookupResult(result, item); + items.add(item); } - return result; } - void SemanticsVisitor::diagnoseAmbiguousReference(OverloadedExpr* overloadedExpr, LookupResult const& lookupResult) - { - getSink()->diagnose(overloadedExpr, Diagnostics::ambiguousReference, lookupResult.items[0].declRef.getName()); + // The resulting `items` list should be all those items + // that were neither better nor worse than one another. + // + // There should always be at least one such item. + // + SLANG_ASSERT(items.getCount() != 0); - for(auto item : lookupResult.items) - { - String declString = ASTPrinter::getDeclSignatureString(item, m_astBuilder); - getSink()->diagnose(item.declRef, Diagnostics::overloadCandidate, declString); - } + LookupResult result; + for (auto item : items) + { + AddToLookupResult(result, item); } + return result; +} + +void SemanticsVisitor::diagnoseAmbiguousReference( + OverloadedExpr* overloadedExpr, + LookupResult const& lookupResult) +{ + getSink()->diagnose( + overloadedExpr, + Diagnostics::ambiguousReference, + lookupResult.items[0].declRef.getName()); - void SemanticsVisitor::diagnoseAmbiguousReference(Expr* expr) + for (auto item : lookupResult.items) { - if( auto overloadedExpr = as(expr) ) - { - diagnoseAmbiguousReference(overloadedExpr, overloadedExpr->lookupResult2); - } - else - { - getSink()->diagnose(expr, Diagnostics::ambiguousExpression); - } + String declString = ASTPrinter::getDeclSignatureString(item, m_astBuilder); + getSink()->diagnose(item.declRef, Diagnostics::overloadCandidate, declString); } +} - Expr* SemanticsVisitor::_resolveOverloadedExprImpl(OverloadedExpr* overloadedExpr, LookupMask mask, DiagnosticSink* diagSink) +void SemanticsVisitor::diagnoseAmbiguousReference(Expr* expr) +{ + if (auto overloadedExpr = as(expr)) + { + diagnoseAmbiguousReference(overloadedExpr, overloadedExpr->lookupResult2); + } + else { - auto lookupResult = overloadedExpr->lookupResult2; - SLANG_RELEASE_ASSERT(lookupResult.isValid() && lookupResult.isOverloaded()); + getSink()->diagnose(expr, Diagnostics::ambiguousExpression); + } +} - // Take the lookup result we had, and refine it based on what is expected in context. - // - // E.g., if there is both a type and a variable named `Foo`, but in context we know - // that a type is expected, then we can disambiguate by assuming the type is intended. - // - lookupResult = refineLookup(lookupResult, mask); +Expr* SemanticsVisitor::_resolveOverloadedExprImpl( + OverloadedExpr* overloadedExpr, + LookupMask mask, + DiagnosticSink* diagSink) +{ + auto lookupResult = overloadedExpr->lookupResult2; + SLANG_RELEASE_ASSERT(lookupResult.isValid() && lookupResult.isOverloaded()); - // Try to filter out overload candidates based on which ones are "better" than one another. - lookupResult = resolveOverloadedLookup(lookupResult); + // Take the lookup result we had, and refine it based on what is expected in context. + // + // E.g., if there is both a type and a variable named `Foo`, but in context we know + // that a type is expected, then we can disambiguate by assuming the type is intended. + // + lookupResult = refineLookup(lookupResult, mask); - if (!lookupResult.isValid()) - { - // If we didn't find any symbols after filtering, then just - // use the original and report errors that way - return overloadedExpr; - } + // Try to filter out overload candidates based on which ones are "better" than one another. + lookupResult = resolveOverloadedLookup(lookupResult); - if(!lookupResult.isOverloaded()) - { - // If there is only a single item left in the lookup result, - // then we can proceed to use that item alone as the resolved - // expression. - // - return ConstructLookupResultExpr( - lookupResult.item, overloadedExpr->base, overloadedExpr->name, overloadedExpr->loc, overloadedExpr); - } + if (!lookupResult.isValid()) + { + // If we didn't find any symbols after filtering, then just + // use the original and report errors that way + return overloadedExpr; + } - // Otherwise, we weren't able to resolve the overloading given - // the information available in context. - // - // If the client is asking for us to emit diagnostics about - // this fact, we should do so here: + if (!lookupResult.isOverloaded()) + { + // If there is only a single item left in the lookup result, + // then we can proceed to use that item alone as the resolved + // expression. // - if( diagSink ) - { - diagnoseAmbiguousReference(overloadedExpr, lookupResult); - - // TODO(tfoley): should we construct a new ErrorExpr here? - return CreateErrorExpr(overloadedExpr); - } - else - { - // If the client isn't trying to *force* overload resolution - // to complete just yet (e.g., they are just trying out one - // candidate for an overloaded call site), then we return - // the input expression as-is. - // - return overloadedExpr; - } + return ConstructLookupResultExpr( + lookupResult.item, + overloadedExpr->base, + overloadedExpr->name, + overloadedExpr->loc, + overloadedExpr); } - Expr* SemanticsVisitor::maybeResolveOverloadedExpr(Expr* expr, LookupMask mask, DiagnosticSink* diagSink) + // Otherwise, we weren't able to resolve the overloading given + // the information available in context. + // + // If the client is asking for us to emit diagnostics about + // this fact, we should do so here: + // + if (diagSink) { - if (IsErrorExpr(expr)) - return expr; + diagnoseAmbiguousReference(overloadedExpr, lookupResult); - if( auto overloadedExpr = as(expr) ) - { - return _resolveOverloadedExprImpl(overloadedExpr, mask, diagSink); - } - else - { - return expr; - } + // TODO(tfoley): should we construct a new ErrorExpr here? + return CreateErrorExpr(overloadedExpr); + } + else + { + // If the client isn't trying to *force* overload resolution + // to complete just yet (e.g., they are just trying out one + // candidate for an overloaded call site), then we return + // the input expression as-is. + // + return overloadedExpr; + } +} + +Expr* SemanticsVisitor::maybeResolveOverloadedExpr( + Expr* expr, + LookupMask mask, + DiagnosticSink* diagSink) +{ + if (IsErrorExpr(expr)) + return expr; + + if (auto overloadedExpr = as(expr)) + { + return _resolveOverloadedExprImpl(overloadedExpr, mask, diagSink); } + else + { + return expr; + } +} - Expr* SemanticsVisitor::resolveOverloadedExpr(OverloadedExpr* overloadedExpr, LookupMask mask) +Expr* SemanticsVisitor::resolveOverloadedExpr(OverloadedExpr* overloadedExpr, LookupMask mask) +{ + return _resolveOverloadedExprImpl(overloadedExpr, mask, getSink()); +} + +Type* SemanticsVisitor::tryGetDifferentialType(ASTBuilder* builder, Type* type) +{ + if (auto ptrType = as(type)) { - return _resolveOverloadedExprImpl(overloadedExpr, mask, getSink()); + auto baseDiffType = tryGetDifferentialType(builder, ptrType->getValueType()); + if (!baseDiffType) + return nullptr; + return builder->getPtrType(baseDiffType, ptrType->getClassInfo().m_name); + } + else if (auto arrayType = as(type)) + { + auto baseDiffType = tryGetDifferentialType(builder, arrayType->getElementType()); + if (!baseDiffType) + return nullptr; + return builder->getArrayType(baseDiffType, arrayType->getElementCount()); } - Type* SemanticsVisitor::tryGetDifferentialType(ASTBuilder* builder, Type* type) + if (auto declRefType = as(type)) { - if (auto ptrType = as(type)) + if (auto builtinRequirement = + declRefType->getDeclRef().getDecl()->findModifier()) { - auto baseDiffType = tryGetDifferentialType(builder, ptrType->getValueType()); - if (!baseDiffType) return nullptr; - return builder->getPtrType( - baseDiffType, - ptrType->getClassInfo().m_name); + if (builtinRequirement->kind == BuiltinRequirementKind::DifferentialType || + builtinRequirement->kind == BuiltinRequirementKind::DifferentialPtrType) + { + // We are trying to get differential type from a differential type. + // The result is itself. + return type; + } } - else if (auto arrayType = as(type)) + type = resolveType(type); + auto witness = as( + tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterfaceType())); + if (!witness) + witness = as(tryGetInterfaceConformanceWitness( + type, + builder->getDifferentiableRefInterfaceType())); + if (witness) { - auto baseDiffType = tryGetDifferentialType(builder, arrayType->getElementType()); - if (!baseDiffType) return nullptr; - return builder->getArrayType( - baseDiffType, - arrayType->getElementCount()); - } + auto diffTypeLookupResult = lookUpMember( + getASTBuilder(), + this, + getName("Differential"), + type, + nullptr, + Slang::LookupMask::type, + Slang::LookupOptions::None); - if (auto declRefType = as(type)) - { - if (auto builtinRequirement = declRefType->getDeclRef().getDecl()->findModifier()) + diffTypeLookupResult = resolveOverloadedLookup(diffTypeLookupResult); + + if (!diffTypeLookupResult.isValid()) { - if (builtinRequirement->kind == BuiltinRequirementKind::DifferentialType - || builtinRequirement->kind == BuiltinRequirementKind::DifferentialPtrType) - { - // We are trying to get differential type from a differential type. - // The result is itself. - return type; - } + return nullptr; } - type = resolveType(type); - auto witness = as(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterfaceType())); - if (!witness) - witness = as(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableRefInterfaceType())); - if (witness) + else if (diffTypeLookupResult.isOverloaded()) { - auto diffTypeLookupResult = lookUpMember( - getASTBuilder(), - this, - getName("Differential"), - type, - nullptr, - Slang::LookupMask::type, - Slang::LookupOptions::None); - - diffTypeLookupResult = resolveOverloadedLookup(diffTypeLookupResult); - - if (!diffTypeLookupResult.isValid()) - { - return nullptr; - } - else if (diffTypeLookupResult.isOverloaded()) - { - return nullptr; - } - else - { - SharedTypeExpr* baseTypeExpr = m_astBuilder->create(); - baseTypeExpr->base.type = type; - baseTypeExpr->type.type = m_astBuilder->getTypeType(type); - - auto diffTypeExpr = ConstructLookupResultExpr( - diffTypeLookupResult.item, - baseTypeExpr, - declRefType->getDeclRef().getName(), - declRefType->getDeclRef().getLoc(), - baseTypeExpr); - - return resolveType(ExtractTypeFromTypeRepr(diffTypeExpr)); - } + return nullptr; + } + else + { + SharedTypeExpr* baseTypeExpr = m_astBuilder->create(); + baseTypeExpr->base.type = type; + baseTypeExpr->type.type = m_astBuilder->getTypeType(type); + + auto diffTypeExpr = ConstructLookupResultExpr( + diffTypeLookupResult.item, + baseTypeExpr, + declRefType->getDeclRef().getName(), + declRefType->getDeclRef().getLoc(), + baseTypeExpr); + + return resolveType(ExtractTypeFromTypeRepr(diffTypeExpr)); } } + } - if (auto typePack = as(type)) + if (auto typePack = as(type)) + { + bool anyDifferentiableElement = false; + List diffTypes; + for (Index i = 0; i < typePack->getTypeCount(); i++) { - bool anyDifferentiableElement = false; - List diffTypes; - for (Index i = 0; i < typePack->getTypeCount(); i++) - { - auto t = typePack->getElementType(i); - auto diffType = tryGetDifferentialType(builder, t); - if (!diffType) - diffType = m_astBuilder->getVoidType(); - else - anyDifferentiableElement = true; - diffTypes.add(diffType); - } - if (anyDifferentiableElement) - return builder->getTypePack(diffTypes.getArrayView()); + auto t = typePack->getElementType(i); + auto diffType = tryGetDifferentialType(builder, t); + if (!diffType) + diffType = m_astBuilder->getVoidType(); + else + anyDifferentiableElement = true; + diffTypes.add(diffType); } - return nullptr; + if (anyDifferentiableElement) + return builder->getTypePack(diffTypes.getArrayView()); } + return nullptr; +} - bool SemanticsVisitor::canStructBeUsedAsSelfDifferentialType(AggTypeDecl *aggTypeDecl) +bool SemanticsVisitor::canStructBeUsedAsSelfDifferentialType(AggTypeDecl* aggTypeDecl) +{ + // A struct can be used as its own differential type if all its members are differentiable + // and their differential types are the same as the original types. + // + bool canBeUsed = true; + for (auto member : aggTypeDecl->members) { - // A struct can be used as its own differential type if all its members are differentiable - // and their differential types are the same as the original types. - // - bool canBeUsed = true; - for (auto member : aggTypeDecl->members) + if (auto varDecl = as(member)) { - if (auto varDecl = as(member)) + // Try to get the differential type of the member. + Type* diffType = tryGetDifferentialType(getASTBuilder(), varDecl->getType()); + if (!diffType || !diffType->equals(varDecl->getType())) { - // Try to get the differential type of the member. - Type* diffType = tryGetDifferentialType(getASTBuilder(), varDecl->getType()); - if (!diffType || !diffType->equals(varDecl->getType())) - { - canBeUsed = false; - break; - } + canBeUsed = false; + break; } } - return canBeUsed; } + return canBeUsed; +} - void SemanticsVisitor::markSelfDifferentialMembersOfType(AggTypeDecl *parent, Type* type) +void SemanticsVisitor::markSelfDifferentialMembersOfType(AggTypeDecl* parent, Type* type) +{ + // TODO: Handle extensions. + // Add derivative member attributes to all the fields pointing to themselves. + for (auto member : parent->getMembersOfType()) { - // TODO: Handle extensions. - // Add derivative member attributes to all the fields pointing to themselves. - for (auto member : parent->getMembersOfType()) - { - auto derivativeMemberModifier = m_astBuilder->create(); - auto fieldLookupExpr = m_astBuilder->create(); - fieldLookupExpr->type.type = member->getType(); + auto derivativeMemberModifier = m_astBuilder->create(); + auto fieldLookupExpr = m_astBuilder->create(); + fieldLookupExpr->type.type = member->getType(); - auto baseTypeExpr = m_astBuilder->create(); - baseTypeExpr->base.type = type; - auto baseTypeType = m_astBuilder->getOrCreate(type); - baseTypeExpr->type.type = baseTypeType; - fieldLookupExpr->baseExpression = baseTypeExpr; + auto baseTypeExpr = m_astBuilder->create(); + baseTypeExpr->base.type = type; + auto baseTypeType = m_astBuilder->getOrCreate(type); + baseTypeExpr->type.type = baseTypeType; + fieldLookupExpr->baseExpression = baseTypeExpr; - fieldLookupExpr->declRef = makeDeclRef(member); + fieldLookupExpr->declRef = makeDeclRef(member); - derivativeMemberModifier->memberDeclRef = fieldLookupExpr; - addModifier(member, derivativeMemberModifier); - } + derivativeMemberModifier->memberDeclRef = fieldLookupExpr; + addModifier(member, derivativeMemberModifier); + } +} + +void SemanticsVisitor::checkDerivativeMemberAttributeReferences( + VarDeclBase* varDecl, + DerivativeMemberAttribute* derivativeMemberAttr) +{ + if (derivativeMemberAttr->memberDeclRef) + { + // Already checked! This usually happens if this attribute is synthesized by the compiler. + return; } - void SemanticsVisitor::checkDerivativeMemberAttributeReferences( - VarDeclBase* varDecl, DerivativeMemberAttribute* derivativeMemberAttr) + SLANG_ASSERT(derivativeMemberAttr->args.getCount() == 1); + auto checkedExpr = + dispatchExpr(derivativeMemberAttr->args[0], allowStaticReferenceToNonStaticMember()); + + auto memberType = varDecl->type.type; // All types must be fully checked by now. + auto diffType = getDifferentialType(m_astBuilder, memberType, varDecl->loc); + auto thisType = calcThisType(makeDeclRef(varDecl->parentDecl)); + if (!thisType) + return; // Diagnostic should have been emitted previously. + + auto diffThisType = getDifferentialType(m_astBuilder, thisType, derivativeMemberAttr->loc); + if (!diffThisType) + return; // Diagnostic should have been emitted previously. + + if (auto declRefExpr = as(checkedExpr)) { - if (derivativeMemberAttr->memberDeclRef) + derivativeMemberAttr->memberDeclRef = declRefExpr; + if (!diffType->equals(declRefExpr->type)) { - // Already checked! This usually happens if this attribute is synthesized by the compiler. - return; + getSink()->diagnose( + derivativeMemberAttr, + Diagnostics::typeMismatch, + diffType, + declRefExpr->type); } - - SLANG_ASSERT(derivativeMemberAttr->args.getCount() == 1); - auto checkedExpr = dispatchExpr(derivativeMemberAttr->args[0], allowStaticReferenceToNonStaticMember()); - - auto memberType = varDecl->type.type; // All types must be fully checked by now. - auto diffType = getDifferentialType(m_astBuilder, memberType, varDecl->loc); - auto thisType = calcThisType(makeDeclRef(varDecl->parentDecl)); - if (!thisType) return; // Diagnostic should have been emitted previously. - - auto diffThisType = getDifferentialType(m_astBuilder, thisType, derivativeMemberAttr->loc); - if (!diffThisType) return; // Diagnostic should have been emitted previously. - - if (auto declRefExpr = as(checkedExpr)) + if (!varDecl->parentDecl) { - derivativeMemberAttr->memberDeclRef = declRefExpr; - if (!diffType->equals(declRefExpr->type)) - { - getSink()->diagnose(derivativeMemberAttr, Diagnostics::typeMismatch, diffType, declRefExpr->type); - } - if (!varDecl->parentDecl) - { - getSink()->diagnose(derivativeMemberAttr, Diagnostics::attributeNotApplicable, diffType, declRefExpr->type); - } - if (auto memberExpr = as(declRefExpr)) + getSink()->diagnose( + derivativeMemberAttr, + Diagnostics::attributeNotApplicable, + diffType, + declRefExpr->type); + } + if (auto memberExpr = as(declRefExpr)) + { + auto baseExprType = memberExpr->baseExpression->type.type; + if (auto typeType = as(baseExprType)) { - auto baseExprType = memberExpr->baseExpression->type.type; - if (auto typeType = as(baseExprType)) + if (diffThisType->equals(typeType->getType())) { - if (diffThisType->equals(typeType->getType())) - { - return; - } + return; } - } } + } + getSink()->diagnose( + derivativeMemberAttr, + Diagnostics::derivativeMemberAttributeMustNameAMemberInExpectedDifferentialType, + diffThisType); +} + +Type* SemanticsVisitor::getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc) +{ + auto result = tryGetDifferentialType(builder, type); + if (!result) + { getSink()->diagnose( - derivativeMemberAttr, - Diagnostics:: - derivativeMemberAttributeMustNameAMemberInExpectedDifferentialType, - diffThisType); + loc, + Diagnostics::typeDoesntImplementInterfaceRequirement, + type, + getName("Differential")); + return m_astBuilder->getErrorType(); } + return result; +} - Type* SemanticsVisitor::getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc) +void SemanticsVisitor::addDifferentiableTypeToDiffTypeRegistry( + DeclRefType* type, + SubtypeWitness* witness) +{ + SLANG_RELEASE_ASSERT(m_parentDifferentiableAttr); + if (witness) { - auto result = tryGetDifferentialType(builder, type); - if (!result) - { - getSink()->diagnose(loc, Diagnostics::typeDoesntImplementInterfaceRequirement, type, getName("Differential")); - return m_astBuilder->getErrorType(); - } - return result; + m_parentDifferentiableAttr->addType(type->getDeclRef(), witness); } +} - void SemanticsVisitor::addDifferentiableTypeToDiffTypeRegistry(DeclRefType* type, SubtypeWitness* witness) +void SemanticsVisitor::maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type) +{ + if (!builder->isDifferentiableInterfaceAvailable()) { - SLANG_RELEASE_ASSERT(m_parentDifferentiableAttr); - if (witness) - { - m_parentDifferentiableAttr->addType(type->getDeclRef(), witness); - } + return; } - void SemanticsVisitor::maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type) + if (!m_parentDifferentiableAttr) { - if (!builder->isDifferentiableInterfaceAvailable()) - { - return; - } + return; + } - if (!m_parentDifferentiableAttr) - { - return; - } + maybeRegisterDifferentiableTypeImplRecursive(builder, type); +} - maybeRegisterDifferentiableTypeImplRecursive(builder, type); - } +void SemanticsVisitor::maybeRegisterDifferentiableTypeImplRecursive(ASTBuilder* builder, Type* type) +{ + // Recursively visit the tree of type and register all differentiable types along the way. - void SemanticsVisitor::maybeRegisterDifferentiableTypeImplRecursive(ASTBuilder* builder, Type* type) - { - // Recursively visit the tree of type and register all differentiable types along the way. - - if (as(type)) - return; - if (!type) - return; + if (as(type)) + return; + if (!type) + return; - // Have we already registered this type? If so we can exit now. - if (m_parentDifferentiableAttr->m_typeRegistrationWorkingSet.contains(type)) - return; + // Have we already registered this type? If so we can exit now. + if (m_parentDifferentiableAttr->m_typeRegistrationWorkingSet.contains(type)) + return; - m_parentDifferentiableAttr->m_typeRegistrationWorkingSet.add(type); + m_parentDifferentiableAttr->m_typeRegistrationWorkingSet.add(type); - // Check for special cases such as PtrTypeBase or Array - // This could potentially be handled later by simply defining extensions - // for Ptr etc.. - // - if (auto ptrType = as(type)) + // Check for special cases such as PtrTypeBase or Array + // This could potentially be handled later by simply defining extensions + // for Ptr etc.. + // + if (auto ptrType = as(type)) + { + maybeRegisterDifferentiableTypeImplRecursive(builder, ptrType->getValueType()); + return; + } + + if (auto arrayType = as(type)) + { + maybeRegisterDifferentiableTypeImplRecursive(builder, arrayType->getElementType()); + // Fall through to register the array type itself. + } + + if (auto declRefType = as(type)) + { + if (auto subtypeWitness = as(tryGetInterfaceConformanceWitness( + type, + getASTBuilder()->getDifferentiableInterfaceType()))) { - maybeRegisterDifferentiableTypeImplRecursive(builder, ptrType->getValueType()); - return; + addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness); } - if (auto arrayType = as(type)) + if (auto subtypeWitness = as(tryGetInterfaceConformanceWitness( + type, + getASTBuilder()->getDifferentiableRefInterfaceType()))) { - maybeRegisterDifferentiableTypeImplRecursive(builder, arrayType->getElementType()); - // Fall through to register the array type itself. + addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness); } - if (auto declRefType = as(type)) + if (auto aggTypeDeclRef = declRefType->getDeclRef().as()) { - if (auto subtypeWitness = as( - tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableInterfaceType()))) - { - addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness); - } - - if (auto subtypeWitness = as( - tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableRefInterfaceType()))) - { - addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness); - } - - if (auto aggTypeDeclRef = declRefType->getDeclRef().as()) - { - foreachDirectOrExtensionMemberOfType(this, aggTypeDeclRef, [&](DeclRef member) - { - auto subType = DeclRefType::create(m_astBuilder, member); - maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, subType); - }); - foreachDirectOrExtensionMemberOfType(this, aggTypeDeclRef, [&](DeclRef member) - { - auto fieldType = getType(m_astBuilder, member); - maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, fieldType); - }); - } - SubstitutionSet(declRefType->getDeclRef()).forEachSubstitutionArg([&](Val* arg) + foreachDirectOrExtensionMemberOfType( + this, + aggTypeDeclRef, + [&](DeclRef member) + { + auto subType = DeclRefType::create(m_astBuilder, member); + maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, subType); + }); + foreachDirectOrExtensionMemberOfType( + this, + aggTypeDeclRef, + [&](DeclRef member) + { + auto fieldType = getType(m_astBuilder, member); + maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, fieldType); + }); + } + SubstitutionSet(declRefType->getDeclRef()) + .forEachSubstitutionArg( + [&](Val* arg) { if (auto typeArg = as(arg)) { maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, typeArg); } }); - return; - } - - if (auto typePack = as(type)) - { - for (Index i = 0; i < typePack->getTypeCount(); i++) - maybeRegisterDifferentiableTypeImplRecursive(builder, typePack->getElementType(i)); - return; - } + return; } - - Expr* SemanticsVisitor::CheckTerm(Expr* term) + if (auto typePack = as(type)) { - // If we have already checked the expr, don't check again. - if (term->checked) - { - return term; - } - - auto checkedTerm = _CheckTerm(term); - checkedTerm->checked = true; - - // Differentiable type checking. - // TODO: This can be super slow. - if (this->m_parentFunc && - this->m_parentFunc->findModifier()) - { - maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type); - } - return checkedTerm; + for (Index i = 0; i < typePack->getTypeCount(); i++) + maybeRegisterDifferentiableTypeImplRecursive(builder, typePack->getElementType(i)); + return; } +} + - Expr* SemanticsVisitor::_CheckTerm(Expr* term) +Expr* SemanticsVisitor::CheckTerm(Expr* term) +{ + // If we have already checked the expr, don't check again. + if (term->checked) { - if (!term) return nullptr; + return term; + } - // The process of checking a term/expression can end up introducing - // temporaries that need to be added to an outer scope. When jumping - // into expression checking, we want to check if we already have such - // a scope in place. If we do, we will re-use it for any sub-expressions. - // If not, we need to create one. - // - if (getExprLocalScope()) - { - return dispatchExpr(term, *this); - } + auto checkedTerm = _CheckTerm(term); + checkedTerm->checked = true; - ExprLocalScope exprLocalScope; + // Differentiable type checking. + // TODO: This can be super slow. + if (this->m_parentFunc && this->m_parentFunc->findModifier()) + { + maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type); + } + return checkedTerm; +} - Expr* checkedTerm = dispatchExpr(term, withExprLocalScope(&exprLocalScope)); +Expr* SemanticsVisitor::_CheckTerm(Expr* term) +{ + if (!term) + return nullptr; - if (IsErrorExpr(checkedTerm)) - return checkedTerm; + // The process of checking a term/expression can end up introducing + // temporaries that need to be added to an outer scope. When jumping + // into expression checking, we want to check if we already have such + // a scope in place. If we do, we will re-use it for any sub-expressions. + // If not, we need to create one. + // + if (getExprLocalScope()) + { + return dispatchExpr(term, *this); + } - LetExpr* outerMostBinding = exprLocalScope.getOuterMostBinding(); - if(!outerMostBinding) - { - return checkedTerm; - } + ExprLocalScope exprLocalScope; - LetExpr* binding = outerMostBinding; - auto type = checkedTerm->type; - while (binding) - { - binding->type = type; + Expr* checkedTerm = dispatchExpr(term, withExprLocalScope(&exprLocalScope)); - if (const auto body = binding->body) - { - binding = as(binding->body); - SLANG_ASSERT(binding); - continue; - } - else - { - binding->body = checkedTerm; - break; - } - } + if (IsErrorExpr(checkedTerm)) + return checkedTerm; - return outerMostBinding; + LetExpr* outerMostBinding = exprLocalScope.getOuterMostBinding(); + if (!outerMostBinding) + { + return checkedTerm; } - Expr* SemanticsVisitor::CreateErrorExpr(Expr* expr) + LetExpr* binding = outerMostBinding; + auto type = checkedTerm->type; + while (binding) { - if (!expr) + binding->type = type; + + if (const auto body = binding->body) { - expr = m_astBuilder->create(); + binding = as(binding->body); + SLANG_ASSERT(binding); + continue; + } + else + { + binding->body = checkedTerm; + break; } - expr->type = QualType(m_astBuilder->getErrorType()); - return expr; } - bool SemanticsVisitor::IsErrorExpr(Expr* expr) + return outerMostBinding; +} + +Expr* SemanticsVisitor::CreateErrorExpr(Expr* expr) +{ + if (!expr) { - // TODO: we may want other cases here... + expr = m_astBuilder->create(); + } + expr->type = QualType(m_astBuilder->getErrorType()); + return expr; +} - if (const auto errorType = as(expr->type)) - return true; +bool SemanticsVisitor::IsErrorExpr(Expr* expr) +{ + // TODO: we may want other cases here... - return false; - } + if (const auto errorType = as(expr->type)) + return true; - Expr* SemanticsVisitor::GetBaseExpr(Expr* expr) + return false; +} + +Expr* SemanticsVisitor::GetBaseExpr(Expr* expr) +{ + if (auto memberExpr = as(expr)) { - if (auto memberExpr = as(expr)) - { - return memberExpr->baseExpression; - } - else if(auto overloadedExpr = as(expr)) - { - return overloadedExpr->base; - } - else if (auto overloadedExpr2 = as(expr)) - { - return overloadedExpr2->base; - } - else if (auto genApp = as(expr)) - { - return GetBaseExpr(genApp->functionExpr); - } - else if (auto partiallyApplied = as(expr)) - { - return GetBaseExpr(partiallyApplied->originalExpr); - } - return nullptr; + return memberExpr->baseExpression; } - - Expr* SemanticsExprVisitor::visitIncompleteExpr(IncompleteExpr* expr) + else if (auto overloadedExpr = as(expr)) { - expr->type = m_astBuilder->getErrorType(); - return expr; + return overloadedExpr->base; } - - Expr* SemanticsExprVisitor::visitBoolLiteralExpr(BoolLiteralExpr* expr) + else if (auto overloadedExpr2 = as(expr)) { - expr->type = m_astBuilder->getBoolType(); - return expr; + return overloadedExpr2->base; } - - Expr* SemanticsExprVisitor::visitNullPtrLiteralExpr(NullPtrLiteralExpr* expr) + else if (auto genApp = as(expr)) { - expr->type = m_astBuilder->getNullPtrType(); - return expr; + return GetBaseExpr(genApp->functionExpr); } - - Expr* SemanticsExprVisitor::visitNoneLiteralExpr(NoneLiteralExpr* expr) + else if (auto partiallyApplied = as(expr)) { - expr->type = m_astBuilder->getNoneType(); - return expr; + return GetBaseExpr(partiallyApplied->originalExpr); } + return nullptr; +} - Expr* SemanticsExprVisitor::visitIntegerLiteralExpr(IntegerLiteralExpr* expr) - { - // The expression might already have a type, determined by its suffix. - // It it doesn't, we will give it a default type. - // - // TODO: We should be careful to pick a "big enough" type - // based on the size of the value (e.g., don't try to stuff - // a constant in an `int` if it requires 64 or more bits). - // - // The long-term solution here is to give a type to a literal - // based on the context where it is used, but that requires - // a more sophisticated type system than we have today. - // - if(!expr->type.type) - { - expr->type = m_astBuilder->getBuiltinType(expr->suffixType); - } - return expr; +Expr* SemanticsExprVisitor::visitIncompleteExpr(IncompleteExpr* expr) +{ + expr->type = m_astBuilder->getErrorType(); + return expr; +} + +Expr* SemanticsExprVisitor::visitBoolLiteralExpr(BoolLiteralExpr* expr) +{ + expr->type = m_astBuilder->getBoolType(); + return expr; +} + +Expr* SemanticsExprVisitor::visitNullPtrLiteralExpr(NullPtrLiteralExpr* expr) +{ + expr->type = m_astBuilder->getNullPtrType(); + return expr; +} + +Expr* SemanticsExprVisitor::visitNoneLiteralExpr(NoneLiteralExpr* expr) +{ + expr->type = m_astBuilder->getNoneType(); + return expr; +} + +Expr* SemanticsExprVisitor::visitIntegerLiteralExpr(IntegerLiteralExpr* expr) +{ + // The expression might already have a type, determined by its suffix. + // It it doesn't, we will give it a default type. + // + // TODO: We should be careful to pick a "big enough" type + // based on the size of the value (e.g., don't try to stuff + // a constant in an `int` if it requires 64 or more bits). + // + // The long-term solution here is to give a type to a literal + // based on the context where it is used, but that requires + // a more sophisticated type system than we have today. + // + if (!expr->type.type) + { + expr->type = m_astBuilder->getBuiltinType(expr->suffixType); } + return expr; +} - Expr* SemanticsExprVisitor::visitFloatingPointLiteralExpr(FloatingPointLiteralExpr* expr) +Expr* SemanticsExprVisitor::visitFloatingPointLiteralExpr(FloatingPointLiteralExpr* expr) +{ + if (!expr->type.type) { - if(!expr->type.type) - { - expr->type = m_astBuilder->getBuiltinType(expr->suffixType); - } - return expr; + expr->type = m_astBuilder->getBuiltinType(expr->suffixType); } + return expr; +} + +Expr* SemanticsExprVisitor::visitStringLiteralExpr(StringLiteralExpr* expr) +{ + expr->type = m_astBuilder->getStringType(); + return expr; +} + +IntVal* SemanticsVisitor::getIntVal(IntegerLiteralExpr* expr) +{ + return m_astBuilder->getIntVal(expr->type.type, expr->value); +} - Expr* SemanticsExprVisitor::visitStringLiteralExpr(StringLiteralExpr* expr) +IntVal* SemanticsVisitor::tryConstantFoldExpr( + SubstExpr invokeExpr, + ConstantFoldingKind kind, + ConstantFoldingCircularityInfo* circularityInfo) +{ + // We need all the operands to the expression + + // Check if the callee is an operation that is amenable to constant-folding. + // + // For right now we will look for calls to intrinsic functions, and then inspect + // their names (this is bad and slow). + auto funcDeclRefExpr = getBaseExpr(invokeExpr).as(); + if (!funcDeclRefExpr) + return nullptr; + + auto funcDeclRef = getDeclRef(m_astBuilder, funcDeclRefExpr); + if (!funcDeclRef) + return nullptr; + auto intrinsicMod = funcDeclRef.getDecl()->findModifier(); + auto implicitCast = funcDeclRef.getDecl()->findModifier(); + if (!intrinsicMod && !implicitCast) { - expr->type = m_astBuilder->getStringType(); - return expr; + // We can't constant fold anything that doesn't map to a builtin + // operation right now. + // + // TODO: we should really allow constant-folding for anything + // that can be lowered to our bytecode... + return nullptr; } - IntVal* SemanticsVisitor::getIntVal(IntegerLiteralExpr* expr) + // Let's not constant-fold operations with more than a certain number of arguments, for + // simplicity + static const int kMaxArgs = 8; + auto argCount = getArgCount(invokeExpr); + if (argCount > kMaxArgs) + return nullptr; + + // Before checking the operation name, let's look at the arguments + IntVal* argVals[kMaxArgs]; + IntegerLiteralValue constArgVals[kMaxArgs]; + bool allConst = true; + for (Index a = 0; a < argCount; ++a) { - return m_astBuilder->getIntVal(expr->type.type, expr->value); + auto argExpr = getArg(invokeExpr, a); + auto argVal = tryFoldIntegerConstantExpression(argExpr, kind, circularityInfo); + if (!argVal) + return nullptr; + + argVals[a] = argVal; + + if (auto constArgVal = as(argVal)) + { + constArgVals[a] = constArgVal->getValue(); + } + else + { + allConst = false; + } } - IntVal* SemanticsVisitor::tryConstantFoldExpr( - SubstExpr invokeExpr, - ConstantFoldingKind kind, - ConstantFoldingCircularityInfo* circularityInfo) + if (!allConst) { - // We need all the operands to the expression - - // Check if the callee is an operation that is amenable to constant-folding. + // We support a very limited number of operations + // on "constants" that aren't actually known, to be able to handle a generic + // that takes an integer `N` but then constructs a vector of size `N+1`. // - // For right now we will look for calls to intrinsic functions, and then inspect - // their names (this is bad and slow). - auto funcDeclRefExpr = getBaseExpr(invokeExpr).as(); - if (!funcDeclRefExpr) return nullptr; - - auto funcDeclRef = getDeclRef(m_astBuilder, funcDeclRefExpr); - if (!funcDeclRef) - return nullptr; - auto intrinsicMod = funcDeclRef.getDecl()->findModifier(); - auto implicitCast = funcDeclRef.getDecl()->findModifier(); - if (!intrinsicMod && !implicitCast) + // The hard part there is implementing the rules for value unification in the + // presence of more complicated `IntVal` subclasses, like `SumIntVal`. You'd + // need inference to be smart enough to know that `2 + N` and `N + 2` are the + // same value, as are `N + M + 1 + 1` and `M + 2 + N`. + // + // This is done by constructing a 'PolynomialIntVal' and rely on its + // `canonicalize` operation. + if (implicitCast) { - // We can't constant fold anything that doesn't map to a builtin - // operation right now. - // - // TODO: we should really allow constant-folding for anything - // that can be lowered to our bytecode... + // We cannot support casting in this case. return nullptr; } - // Let's not constant-fold operations with more than a certain number of arguments, for simplicity - static const int kMaxArgs = 8; - auto argCount = getArgCount(invokeExpr); - if (argCount > kMaxArgs) - return nullptr; + auto opName = funcDeclRef.getName(); - // Before checking the operation name, let's look at the arguments - IntVal* argVals[kMaxArgs]; - IntegerLiteralValue constArgVals[kMaxArgs]; - bool allConst = true; - for(Index a = 0; a < argCount; ++a) + // handle binary operators + if (opName == getName("-")) { - auto argExpr = getArg(invokeExpr, a); - auto argVal = tryFoldIntegerConstantExpression(argExpr, kind, circularityInfo); - if (!argVal) - return nullptr; - - argVals[a] = argVal; - - if (auto constArgVal = as(argVal)) + if (argCount == 1) { - constArgVals[a] = constArgVal->getValue(); + return PolynomialIntVal::neg(m_astBuilder, argVals[0]); } - else + else if (argCount == 2) { - allConst = false; + return PolynomialIntVal::sub(m_astBuilder, argVals[0], argVals[1]); } } - - if (!allConst) + else if (opName == getName("+")) { - // We support a very limited number of operations - // on "constants" that aren't actually known, to be able to handle a generic - // that takes an integer `N` but then constructs a vector of size `N+1`. - // - // The hard part there is implementing the rules for value unification in the - // presence of more complicated `IntVal` subclasses, like `SumIntVal`. You'd - // need inference to be smart enough to know that `2 + N` and `N + 2` are the - // same value, as are `N + M + 1 + 1` and `M + 2 + N`. - // - // This is done by constructing a 'PolynomialIntVal' and rely on its - // `canonicalize` operation. - if (implicitCast) - { - // We cannot support casting in this case. - return nullptr; - } - - auto opName = funcDeclRef.getName(); - - // handle binary operators - if (opName == getName("-")) - { - if (argCount == 1) - { - return PolynomialIntVal::neg(m_astBuilder, argVals[0]); - } - else if (argCount == 2) - { - return PolynomialIntVal::sub(m_astBuilder, argVals[0], argVals[1]); - } - } - else if (opName == getName("+")) + if (argCount == 1) { - if (argCount == 1) - { - return argVals[0]; - } - else if (argCount == 2) - { - return PolynomialIntVal::add(m_astBuilder, argVals[0], argVals[1]); - } + return argVals[0]; } - else if (opName == getName("*")) + else if (argCount == 2) { - if (argCount == 2) - { - return PolynomialIntVal::mul(m_astBuilder, argVals[0], argVals[1]); - } + return PolynomialIntVal::add(m_astBuilder, argVals[0], argVals[1]); } - else if (opName == getName("/") || opName == getName("==") || opName == getName(">=") || opName == getName("<=") || opName == getName("!=") - || opName == getName(">") || opName == getName("<") || opName == getName("&&") || opName == getName("||") || opName == getName("!") - || opName == getName("|") || opName == getName("&") || opName == getName("^") || opName == getName("~") || opName == getName("%") || - opName == getName("?:") || opName == getName("<<") || opName == getName(">>")) + } + else if (opName == getName("*")) + { + if (argCount == 2) { - auto result = m_astBuilder->getOrCreate( - invokeExpr.getExpr()->type.type, - funcDeclRef, - as(funcDeclRefExpr.getExpr()->type->substitute( - m_astBuilder, funcDeclRefExpr.getSubsts())), - makeArrayView(argVals, argCount)); - SLANG_RELEASE_ASSERT(result->getFuncType()); - return result; + return PolynomialIntVal::mul(m_astBuilder, argVals[0], argVals[1]); } - return nullptr; } - - // At this point, all the operands had simple integer values, so we are golden. - IntegerLiteralValue resultValue = 0; - // If this is an implicit cast, we can try to fold. - if (implicitCast) - { - auto targetBasicType = as(invokeExpr.getExpr()->type.type); - if (!targetBasicType) - return nullptr; - auto foldVal = as( - TypeCastIntVal::tryFoldImpl(m_astBuilder, targetBasicType, argVals[0], getSink())); - if (foldVal) - return foldVal; - auto result = m_astBuilder->getTypeCastIntVal(targetBasicType, argVals[0]); + else if ( + opName == getName("/") || opName == getName("==") || opName == getName(">=") || + opName == getName("<=") || opName == getName("!=") || opName == getName(">") || + opName == getName("<") || opName == getName("&&") || opName == getName("||") || + opName == getName("!") || opName == getName("|") || opName == getName("&") || + opName == getName("^") || opName == getName("~") || opName == getName("%") || + opName == getName("?:") || opName == getName("<<") || opName == getName(">>")) + { + auto result = m_astBuilder->getOrCreate( + invokeExpr.getExpr()->type.type, + funcDeclRef, + as(funcDeclRefExpr.getExpr()->type->substitute( + m_astBuilder, + funcDeclRefExpr.getSubsts())), + makeArrayView(argVals, argCount)); + SLANG_RELEASE_ASSERT(result->getFuncType()); return result; } - else - { - auto opName = funcDeclRef.getName(); + return nullptr; + } - // handle binary operators - if (opName == getName("-")) - { - if (argCount == 1) - { - resultValue = -constArgVals[0]; - } - else if (argCount == 2) - { - resultValue = constArgVals[0] - constArgVals[1]; - } - } - else if (opName == getName("!")) + // At this point, all the operands had simple integer values, so we are golden. + IntegerLiteralValue resultValue = 0; + // If this is an implicit cast, we can try to fold. + if (implicitCast) + { + auto targetBasicType = as(invokeExpr.getExpr()->type.type); + if (!targetBasicType) + return nullptr; + auto foldVal = as( + TypeCastIntVal::tryFoldImpl(m_astBuilder, targetBasicType, argVals[0], getSink())); + if (foldVal) + return foldVal; + auto result = m_astBuilder->getTypeCastIntVal(targetBasicType, argVals[0]); + return result; + } + else + { + auto opName = funcDeclRef.getName(); + + // handle binary operators + if (opName == getName("-")) + { + if (argCount == 1) { - resultValue = constArgVals[0] != 0; + resultValue = -constArgVals[0]; } - else if (opName == getName("~")) + else if (argCount == 2) { - resultValue = ~constArgVals[0]; + resultValue = constArgVals[0] - constArgVals[1]; } - - // simple binary operators -#define CASE(OP) \ - else if(opName == getName(#OP)) do { \ - if(argCount != 2) return nullptr; \ - resultValue = constArgVals[0] OP constArgVals[1]; \ - } while(0) - - CASE(+); // TODO: this can also be unary... - CASE(*); - CASE(<<); - CASE(>>); - CASE(&); - CASE(|); - CASE(^); - CASE(!=); - CASE(==); - CASE(>=); - CASE(<=); - CASE(<); - CASE(>); + } + else if (opName == getName("!")) + { + resultValue = constArgVals[0] != 0; + } + else if (opName == getName("~")) + { + resultValue = ~constArgVals[0]; + } + + // simple binary operators +#define CASE(OP) \ + else if (opName == getName(#OP)) do \ + { \ + if (argCount != 2) \ + return nullptr; \ + resultValue = constArgVals[0] OP constArgVals[1]; \ + } \ + while (0) + + CASE(+); // TODO: this can also be unary... + CASE(*); + CASE(<<); + CASE(>>); + CASE(&); + CASE(|); + CASE(^); + CASE(!=); + CASE(==); + CASE(>=); + CASE(<=); + CASE(<); + CASE(>); #undef CASE - // binary operators with chance of divide-by-zero - // TODO: issue a suitable error in that case -#define CASE(OP) \ - else if(opName == getName(#OP)) do { \ - if(argCount != 2) return nullptr; \ - if(!constArgVals[1]) return nullptr; \ - resultValue = constArgVals[0] OP constArgVals[1]; \ - } while(0) - CASE(/); - CASE(%); + // binary operators with chance of divide-by-zero + // TODO: issue a suitable error in that case +#define CASE(OP) \ + else if (opName == getName(#OP)) do \ + { \ + if (argCount != 2) \ + return nullptr; \ + if (!constArgVals[1]) \ + return nullptr; \ + resultValue = constArgVals[0] OP constArgVals[1]; \ + } \ + while (0) + CASE(/); + CASE(%); #undef CASE - else if (opName == getName("?:")) - { - if (argCount != 3) - return nullptr; - if (constArgVals[0] != 0) - resultValue = constArgVals[1]; - else - resultValue = constArgVals[2]; - } - // TODO(tfoley): more cases - else - { + else if (opName == getName("?:")) + { + if (argCount != 3) return nullptr; - } + if (constArgVals[0] != 0) + resultValue = constArgVals[1]; + else + resultValue = constArgVals[2]; } - - IntVal* result = m_astBuilder->getIntVal(invokeExpr.getExpr()->type.type, resultValue); - return result; - } - - bool SemanticsVisitor::_checkForCircularityInConstantFolding( - Decl* decl, - ConstantFoldingCircularityInfo* circularityInfo) - { - // TODO: If the `decl` is already on the chain of `circularityInfo`, - // then we know that we are trying to recursively fold the - // same declaration as part of its own definition, and we need - // to diagnose that as an error. - // - for( auto info = circularityInfo; info; info = info->next ) + // TODO(tfoley): more cases + else { - if(decl == info->decl) - { - getSink()->diagnose(decl, Diagnostics::variableUsedInItsOwnDefinition, decl); - return true; - } + return nullptr; } - - return false; } - IntVal* SemanticsVisitor::tryConstantFoldDeclRef( - DeclRef const& declRef, - ConstantFoldingKind kind, - ConstantFoldingCircularityInfo* circularityInfo) - { - auto decl = declRef.getDecl(); - - if(_checkForCircularityInConstantFolding(decl, circularityInfo)) - return nullptr; + IntVal* result = m_astBuilder->getIntVal(invokeExpr.getExpr()->type.type, resultValue); + return result; +} - // In HLSL, `const` is used to mark compile-time constant expressions. - if(!decl->hasModifier()) - return nullptr; - if (decl->hasModifier()) +bool SemanticsVisitor::_checkForCircularityInConstantFolding( + Decl* decl, + ConstantFoldingCircularityInfo* circularityInfo) +{ + // TODO: If the `decl` is already on the chain of `circularityInfo`, + // then we know that we are trying to recursively fold the + // same declaration as part of its own definition, and we need + // to diagnose that as an error. + // + for (auto info = circularityInfo; info; info = info->next) + { + if (decl == info->decl) { - // Extern const is not considered compile-time constant by the front-end. - if (kind == ConstantFoldingKind::CompileTime) - return nullptr; - // But if we are OK with link-time constants, we can still fold it into a val. - auto rs = m_astBuilder->getOrCreate( - declRef.substitute(m_astBuilder, declRef.getDecl()->getType()), - declRef); - return rs; + getSink()->diagnose(decl, Diagnostics::variableUsedInItsOwnDefinition, decl); + return true; } + } - if (isInterfaceRequirement(decl)) - { - auto witness = findThisTypeWitness(SubstitutionSet(declRef), as(decl->parentDecl)); + return false; +} - auto val = WitnessLookupIntVal::tryFold( - m_astBuilder, - witness, - decl, - declRef.substitute(m_astBuilder, decl->type.type)); - return as(val); - } +IntVal* SemanticsVisitor::tryConstantFoldDeclRef( + DeclRef const& declRef, + ConstantFoldingKind kind, + ConstantFoldingCircularityInfo* circularityInfo) +{ + auto decl = declRef.getDecl(); + + if (_checkForCircularityInConstantFolding(decl, circularityInfo)) + return nullptr; - if (!getInitExpr(m_astBuilder, declRef)) + // In HLSL, `const` is used to mark compile-time constant expressions. + if (!decl->hasModifier()) + return nullptr; + if (decl->hasModifier()) + { + // Extern const is not considered compile-time constant by the front-end. + if (kind == ConstantFoldingKind::CompileTime) return nullptr; + // But if we are OK with link-time constants, we can still fold it into a val. + auto rs = m_astBuilder->getOrCreate( + declRef.substitute(m_astBuilder, declRef.getDecl()->getType()), + declRef); + return rs; + } + + if (isInterfaceRequirement(decl)) + { + auto witness = + findThisTypeWitness(SubstitutionSet(declRef), as(decl->parentDecl)); - ensureDecl(declRef.getDecl(), DeclCheckState::DefinitionChecked); - ConstantFoldingCircularityInfo newCircularityInfo(decl, circularityInfo); - return tryConstantFoldExpr(getInitExpr(m_astBuilder, declRef), kind, &newCircularityInfo); + auto val = WitnessLookupIntVal::tryFold( + m_astBuilder, + witness, + decl, + declRef.substitute(m_astBuilder, decl->type.type)); + return as(val); } - IntVal* SemanticsVisitor::tryConstantFoldExpr( - SubstExpr expr, - ConstantFoldingKind kind, - ConstantFoldingCircularityInfo* circularityInfo) + if (!getInitExpr(m_astBuilder, declRef)) + return nullptr; + + ensureDecl(declRef.getDecl(), DeclCheckState::DefinitionChecked); + ConstantFoldingCircularityInfo newCircularityInfo(decl, circularityInfo); + return tryConstantFoldExpr(getInitExpr(m_astBuilder, declRef), kind, &newCircularityInfo); +} + +IntVal* SemanticsVisitor::tryConstantFoldExpr( + SubstExpr expr, + ConstantFoldingKind kind, + ConstantFoldingCircularityInfo* circularityInfo) +{ + + // Unwrap any "identity" expressions + while (auto parenExpr = expr.as()) { - - // Unwrap any "identity" expressions - while (auto parenExpr = expr.as()) - { - expr = getBaseExpr(parenExpr); - } + expr = getBaseExpr(parenExpr); + } - if (auto intLitExpr = expr.as()) - { - return getIntVal(intLitExpr); - } + if (auto intLitExpr = expr.as()) + { + return getIntVal(intLitExpr); + } - if (auto boolLitExpr = expr.as()) - { - // If it's a boolean, we allow promotion to int. - const IntegerLiteralValue value = IntegerLiteralValue(boolLitExpr.getExpr()->value); - return m_astBuilder->getIntVal(m_astBuilder->getBoolType(), value); - } + if (auto boolLitExpr = expr.as()) + { + // If it's a boolean, we allow promotion to int. + const IntegerLiteralValue value = IntegerLiteralValue(boolLitExpr.getExpr()->value); + return m_astBuilder->getIntVal(m_astBuilder->getBoolType(), value); + } - if (auto arrayLengthExpr = expr.as()) + if (auto arrayLengthExpr = expr.as()) + { + if (arrayLengthExpr.getExpr()->arrayExpr && arrayLengthExpr.getExpr()->arrayExpr->type) { - if (arrayLengthExpr.getExpr()->arrayExpr && arrayLengthExpr.getExpr()->arrayExpr->type) + auto type = arrayLengthExpr.getExpr()->arrayExpr->type.type->substitute( + m_astBuilder, + expr.getSubsts()); + if (auto arrayType = as(type)) { - auto type = arrayLengthExpr.getExpr()->arrayExpr->type.type->substitute(m_astBuilder, expr.getSubsts()); - if (auto arrayType = as(type)) + if (!arrayType->isUnsized()) { - if (!arrayType->isUnsized()) - { - if (auto val = as(arrayType->getElementCount())) - return val; - } + if (auto val = as(arrayType->getElementCount())) + return val; } } } + } - if (auto countOfExpr = expr.as()) + if (auto countOfExpr = expr.as()) + { + auto type = + as(countOfExpr.getExpr()->sizedType->substitute(m_astBuilder, expr.getSubsts())); + if (type) + return as( + CountOfIntVal::tryFold(m_astBuilder, expr.getExpr()->type.type, type)); + } + + // it is possible that we are referring to a generic value param + if (auto declRefExpr = expr.as()) + { + auto declRef = getDeclRef(m_astBuilder, declRefExpr); + + if (auto genericValParamRef = declRef.as()) { - auto type = as(countOfExpr.getExpr()->sizedType->substitute(m_astBuilder, expr.getSubsts())); - if (type) - return as(CountOfIntVal::tryFold(m_astBuilder, expr.getExpr()->type.type, type)); + Val* valResult = m_astBuilder->getOrCreate( + declRef.substitute(m_astBuilder, genericValParamRef.getDecl()->getType()), + genericValParamRef); + valResult = valResult->substitute(m_astBuilder, expr.getSubsts()); + return as(valResult); } - // it is possible that we are referring to a generic value param - if (auto declRefExpr = expr.as()) + // We may also need to check for references to variables that + // are defined in a way that can be used as a constant expression: + if (auto varRef = declRef.as()) { - auto declRef = getDeclRef(m_astBuilder, declRefExpr); - - if (auto genericValParamRef = declRef.as()) - { - Val* valResult = m_astBuilder->getOrCreate( - declRef.substitute(m_astBuilder, genericValParamRef.getDecl()->getType()), - genericValParamRef); - valResult = valResult->substitute(m_astBuilder, expr.getSubsts()); - return as(valResult); - } - - // We may also need to check for references to variables that - // are defined in a way that can be used as a constant expression: - if(auto varRef = declRef.as()) - { - return tryConstantFoldDeclRef(varRef, kind, circularityInfo); - } - else if(auto enumRef = declRef.as()) - { - // The cases in an `enum` declaration can also be used as constant expressions, - if(auto tagExpr = getTagExpr(m_astBuilder, enumRef)) - { - auto enumCaseDecl = enumRef.getDecl(); - if(_checkForCircularityInConstantFolding(enumCaseDecl, circularityInfo)) - return nullptr; - - ConstantFoldingCircularityInfo newCircularityInfo(enumCaseDecl, circularityInfo); - auto intVal = as(tryConstantFoldExpr(tagExpr, kind, &newCircularityInfo)); - if (!intVal) - return nullptr; - return as(m_astBuilder->getTypeCastIntVal(enumCaseDecl->getType(), intVal)->resolve()); - } - } + return tryConstantFoldDeclRef(varRef, kind, circularityInfo); } - - SubstExpr typeCastOperand; - if (auto typeCastExpr = expr.as()) - typeCastOperand = getArg(typeCastExpr, 0); - else if (auto builtinCastExpr = expr.as()) - typeCastOperand = getBaseExpr(builtinCastExpr); - - if (typeCastOperand) + else if (auto enumRef = declRef.as()) { - auto substType = getType(m_astBuilder, expr); - if (!substType) - return nullptr; - if (!isValidCompileTimeConstantType(substType)) - return nullptr; - auto val = tryConstantFoldExpr(typeCastOperand, kind, circularityInfo); - if (val) + // The cases in an `enum` declaration can also be used as constant expressions, + if (auto tagExpr = getTagExpr(m_astBuilder, enumRef)) { - if (!expr.getExpr()->type) + auto enumCaseDecl = enumRef.getDecl(); + if (_checkForCircularityInConstantFolding(enumCaseDecl, circularityInfo)) + return nullptr; + + ConstantFoldingCircularityInfo newCircularityInfo(enumCaseDecl, circularityInfo); + auto intVal = as(tryConstantFoldExpr(tagExpr, kind, &newCircularityInfo)); + if (!intVal) return nullptr; - auto foldVal = as( - TypeCastIntVal::tryFoldImpl(m_astBuilder, substType, val, getSink())); - if (foldVal) - return foldVal; - auto result = m_astBuilder->getTypeCastIntVal(substType, val); - return result; + return as( + m_astBuilder->getTypeCastIntVal(enumCaseDecl->getType(), intVal)->resolve()); } } - else if (auto invokeExpr = expr.as()) - { - auto val = tryConstantFoldExpr(invokeExpr, kind, circularityInfo); - if (val) - return val; - } - else if (auto sizeOfLikeExpr = as(expr.getExpr())) + } + + SubstExpr typeCastOperand; + if (auto typeCastExpr = expr.as()) + typeCastOperand = getArg(typeCastExpr, 0); + else if (auto builtinCastExpr = expr.as()) + typeCastOperand = getBaseExpr(builtinCastExpr); + + if (typeCastOperand) + { + auto substType = getType(m_astBuilder, expr); + if (!substType) + return nullptr; + if (!isValidCompileTimeConstantType(substType)) + return nullptr; + auto val = tryConstantFoldExpr(typeCastOperand, kind, circularityInfo); + if (val) { - ASTNaturalLayoutContext context(getASTBuilder(), nullptr); - const auto size = context.calcSize(sizeOfLikeExpr->sizedType); - if (!size) - { + if (!expr.getExpr()->type) return nullptr; - } - - auto value = as(sizeOfLikeExpr) ? - size.alignment : - size.size; - - // We can return as an IntVal - return getASTBuilder()->getIntVal(expr.getExpr()->type, value); + auto foldVal = + as(TypeCastIntVal::tryFoldImpl(m_astBuilder, substType, val, getSink())); + if (foldVal) + return foldVal; + auto result = m_astBuilder->getTypeCastIntVal(substType, val); + return result; } - else if (auto indexExpr = expr.as()) + } + else if (auto invokeExpr = expr.as()) + { + auto val = tryConstantFoldExpr(invokeExpr, kind, circularityInfo); + if (val) + return val; + } + else if (auto sizeOfLikeExpr = as(expr.getExpr())) + { + ASTNaturalLayoutContext context(getASTBuilder(), nullptr); + const auto size = context.calcSize(sizeOfLikeExpr->sizedType); + if (!size) { - return tryFoldIndexExpr(indexExpr.getExpr(), kind, circularityInfo); + return nullptr; } - return nullptr; - } - IntVal* SemanticsVisitor::tryFoldIndexExpr( - SubstExpr expr, - ConstantFoldingKind kind, - ConstantFoldingCircularityInfo* circularityInfo) + auto value = as(sizeOfLikeExpr) ? size.alignment : size.size; + + // We can return as an IntVal + return getASTBuilder()->getIntVal(expr.getExpr()->type, value); + } + else if (auto indexExpr = expr.as()) { - // Ad-hoc constant folding for index expressions. - // TOOD: we should generalize this by extending `Val` to support compile-time constants that are - // not just integers, but also arrays and structs etc, so that we can independently fold - // the base expression and the index expression, and then form a ElementExtractVal() from an - // index expr. - // For now we just specialize case for array expression that is an initialization list. - // And this won't work if the array is a link-time constant. - // - auto declRefExpr = as(expr.getExpr()->baseExpression); - if (!declRefExpr) - return nullptr; - auto varDecl = as(declRefExpr->declRef.getDecl()); - if (!varDecl) - return nullptr; - auto type = varDecl->getType(); - if (!type) - return nullptr; - auto arrayType = as(type); - if (!arrayType) - return nullptr; - if (!varDecl->hasModifier()) - return nullptr; - if (isGlobalDecl(varDecl) && !varDecl->hasModifier()) - return nullptr; - if (!varDecl->initExpr) - return nullptr; - auto arrayContentExpr = as(varDecl->initExpr); - if (!arrayContentExpr) - return nullptr; - if (expr.getExpr()->indexExprs.getCount() != 1) - return nullptr; - auto indexVal = as(tryFoldIntegerConstantExpression( - expr.getExpr()->indexExprs[0], kind, circularityInfo)); - if (!indexVal) - return nullptr; - auto index = indexVal->getValue(); - if (index < 0 || index >= arrayContentExpr->args.getCount()) - return nullptr; - auto elementExpr = arrayContentExpr->args[Index(index)]; - return tryFoldIntegerConstantExpression(elementExpr, kind, circularityInfo); + return tryFoldIndexExpr(indexExpr.getExpr(), kind, circularityInfo); } + return nullptr; +} - IntVal* SemanticsVisitor::tryFoldIntegerConstantExpression( - SubstExpr expr, - ConstantFoldingKind kind, - ConstantFoldingCircularityInfo* circularityInfo) - { - // Check if type is acceptable for an integer constant expression - // - if(!isValidCompileTimeConstantType(getType(m_astBuilder, expr))) - return nullptr; +IntVal* SemanticsVisitor::tryFoldIndexExpr( + SubstExpr expr, + ConstantFoldingKind kind, + ConstantFoldingCircularityInfo* circularityInfo) +{ + // Ad-hoc constant folding for index expressions. + // TOOD: we should generalize this by extending `Val` to support compile-time constants that are + // not just integers, but also arrays and structs etc, so that we can independently fold + // the base expression and the index expression, and then form a ElementExtractVal() from an + // index expr. + // For now we just specialize case for array expression that is an initialization list. + // And this won't work if the array is a link-time constant. + // + auto declRefExpr = as(expr.getExpr()->baseExpression); + if (!declRefExpr) + return nullptr; + auto varDecl = as(declRefExpr->declRef.getDecl()); + if (!varDecl) + return nullptr; + auto type = varDecl->getType(); + if (!type) + return nullptr; + auto arrayType = as(type); + if (!arrayType) + return nullptr; + if (!varDecl->hasModifier()) + return nullptr; + if (isGlobalDecl(varDecl) && !varDecl->hasModifier()) + return nullptr; + if (!varDecl->initExpr) + return nullptr; + auto arrayContentExpr = as(varDecl->initExpr); + if (!arrayContentExpr) + return nullptr; + if (expr.getExpr()->indexExprs.getCount() != 1) + return nullptr; + auto indexVal = as( + tryFoldIntegerConstantExpression(expr.getExpr()->indexExprs[0], kind, circularityInfo)); + if (!indexVal) + return nullptr; + auto index = indexVal->getValue(); + if (index < 0 || index >= arrayContentExpr->args.getCount()) + return nullptr; + auto elementExpr = arrayContentExpr->args[Index(index)]; + return tryFoldIntegerConstantExpression(elementExpr, kind, circularityInfo); +} - // Consider operations that we might be able to constant-fold... - // - return tryConstantFoldExpr(expr, kind, circularityInfo); +IntVal* SemanticsVisitor::tryFoldIntegerConstantExpression( + SubstExpr expr, + ConstantFoldingKind kind, + ConstantFoldingCircularityInfo* circularityInfo) +{ + // Check if type is acceptable for an integer constant expression + // + if (!isValidCompileTimeConstantType(getType(m_astBuilder, expr))) + return nullptr; + + // Consider operations that we might be able to constant-fold... + // + return tryConstantFoldExpr(expr, kind, circularityInfo); +} + +IntVal* SemanticsVisitor::CheckIntegerConstantExpression( + Expr* inExpr, + IntegerConstantExpressionCoercionType coercionType, + Type* expectedType, + ConstantFoldingKind kind, + DiagnosticSink* sink) +{ + // No need to issue further errors if the expression didn't even type-check. + if (IsErrorExpr(inExpr)) + return nullptr; + + // First coerce the expression to the expected type + Expr* expr = nullptr; + switch (coercionType) + { + case IntegerConstantExpressionCoercionType::SpecificType: + expr = coerce(CoercionSite::General, expectedType, inExpr); + break; + case IntegerConstantExpressionCoercionType::AnyInteger: + if (isScalarIntegerType(inExpr->type)) + expr = inExpr; + else if (isEnumType(inExpr->type)) + expr = inExpr; + else + expr = coerce(CoercionSite::General, m_astBuilder->getIntType(), inExpr); + break; + default: break; } - IntVal* SemanticsVisitor::CheckIntegerConstantExpression(Expr* inExpr, IntegerConstantExpressionCoercionType coercionType, Type* expectedType, ConstantFoldingKind kind, DiagnosticSink* sink) + // No need to issue further errors if the type coercion failed. + if (IsErrorExpr(expr)) + return nullptr; + + auto result = tryFoldIntegerConstantExpression(expr, kind, nullptr); + if (!result && sink) { - // No need to issue further errors if the expression didn't even type-check. - if(IsErrorExpr(inExpr)) return nullptr; + sink->diagnose(expr, Diagnostics::expectedIntegerConstantNotConstant); + } + return result; +} - // First coerce the expression to the expected type - Expr* expr = nullptr; - switch (coercionType) - { - case IntegerConstantExpressionCoercionType::SpecificType: - expr = coerce(CoercionSite::General, expectedType, inExpr); - break; - case IntegerConstantExpressionCoercionType::AnyInteger: - if (isScalarIntegerType(inExpr->type)) - expr = inExpr; - else if (isEnumType(inExpr->type)) - expr = inExpr; - else - expr = coerce(CoercionSite::General, m_astBuilder->getIntType(), inExpr); - break; - default: - break; - } +IntVal* SemanticsVisitor::CheckIntegerConstantExpression( + Expr* inExpr, + IntegerConstantExpressionCoercionType coercionType, + Type* expectedType, + ConstantFoldingKind kind) +{ + return CheckIntegerConstantExpression(inExpr, coercionType, expectedType, kind, getSink()); +} + +IntVal* SemanticsVisitor::CheckEnumConstantExpression(Expr* expr, ConstantFoldingKind kind) +{ + // No need to issue further errors if the expression didn't even type-check. + if (IsErrorExpr(expr)) + return nullptr; - // No need to issue further errors if the type coercion failed. - if(IsErrorExpr(expr)) return nullptr; + // No need to issue further errors if the type coercion failed. + if (IsErrorExpr(expr)) + return nullptr; - auto result = tryFoldIntegerConstantExpression(expr, kind, nullptr); - if (!result && sink) - { - sink->diagnose(expr, Diagnostics::expectedIntegerConstantNotConstant); - } - return result; + auto result = tryConstantFoldExpr(expr, kind, nullptr); + if (!result) + { + getSink()->diagnose(expr, Diagnostics::expectedIntegerConstantNotConstant); } + return result; +} - IntVal* SemanticsVisitor::CheckIntegerConstantExpression(Expr* inExpr, IntegerConstantExpressionCoercionType coercionType, Type* expectedType, ConstantFoldingKind kind) +Expr* SemanticsVisitor::CheckSimpleSubscriptExpr(IndexExpr* subscriptExpr, Type* elementType) +{ + auto baseExpr = subscriptExpr->baseExpression; + if (subscriptExpr->indexExprs.getCount() < 1) + { + getSink()->diagnose( + subscriptExpr, + Diagnostics::notEnoughArguments, + subscriptExpr->indexExprs.getCount(), + 1); + return CreateErrorExpr(subscriptExpr); + } + else if (subscriptExpr->indexExprs.getCount() > 1) { - return CheckIntegerConstantExpression(inExpr, coercionType, expectedType, kind, getSink()); + getSink()->diagnose( + subscriptExpr, + Diagnostics::tooManyArguments, + subscriptExpr->indexExprs.getCount(), + 1); + return CreateErrorExpr(subscriptExpr); } - IntVal* SemanticsVisitor::CheckEnumConstantExpression(Expr* expr, ConstantFoldingKind kind) + auto indexExpr = subscriptExpr->indexExprs[0]; + + if (!indexExpr->type->equals(m_astBuilder->getIntType()) && + !indexExpr->type->equals(m_astBuilder->getUIntType())) { - // No need to issue further errors if the expression didn't even type-check. - if(IsErrorExpr(expr)) return nullptr; + getSink()->diagnose(indexExpr, Diagnostics::subscriptIndexNonInteger); + return CreateErrorExpr(subscriptExpr); + } - // No need to issue further errors if the type coercion failed. - if(IsErrorExpr(expr)) return nullptr; + subscriptExpr->type = QualType(elementType); - auto result = tryConstantFoldExpr(expr, kind, nullptr); - if (!result) - { - getSink()->diagnose(expr, Diagnostics::expectedIntegerConstantNotConstant); - } - return result; + // TODO(tfoley): need to be more careful about this stuff + subscriptExpr->type.isLeftValue = baseExpr->type.isLeftValue; + + return subscriptExpr; +} + +Expr* SemanticsExprVisitor::visitIndexExpr(IndexExpr* subscriptExpr) +{ + bool needDeref = false; + auto baseExpr = checkBaseForMemberExpr(subscriptExpr->baseExpression, needDeref); + + // If the base expression is a type, it means that this is an array declaration, + // then we should disable short-circuit in case there is logical expression in + // the subscript + auto baseType = baseExpr->type.Ptr(); + auto baseTypeType = as(baseType); + auto subVisitor = (baseTypeType && m_shouldShortCircuitLogicExpr) + ? SemanticsVisitor(disableShortCircuitLogicalExpr()) + : *this; + + for (auto& arg : subscriptExpr->indexExprs) + { + arg = subVisitor.CheckTerm(arg); } - Expr* SemanticsVisitor::CheckSimpleSubscriptExpr( - IndexExpr* subscriptExpr, - Type* elementType) + // If anything went wrong in the base expression, + // then just move along... + if (IsErrorExpr(baseExpr)) + return CreateErrorExpr(subscriptExpr); + + subscriptExpr->baseExpression = baseExpr; + + // Otherwise, we need to look at the type of the base expression, + // to figure out how subscripting should work. + if (baseTypeType) { - auto baseExpr = subscriptExpr->baseExpression; - if (subscriptExpr->indexExprs.getCount() < 1) - { - getSink()->diagnose(subscriptExpr, Diagnostics::notEnoughArguments, subscriptExpr->indexExprs.getCount(), 1); - return CreateErrorExpr(subscriptExpr); - } - else if (subscriptExpr->indexExprs.getCount() > 1) + // We are trying to "index" into a type, so we have an expression like `float[2]` + // which should be interpreted as resolving to an array type. + + IntVal* elementCount = nullptr; + if (subscriptExpr->indexExprs.getCount() == 1) { - getSink()->diagnose(subscriptExpr, Diagnostics::tooManyArguments, subscriptExpr->indexExprs.getCount(), 1); - return CreateErrorExpr(subscriptExpr); + elementCount = CheckIntegerConstantExpression( + subscriptExpr->indexExprs[0], + IntegerConstantExpressionCoercionType::AnyInteger, + nullptr, + ConstantFoldingKind::LinkTime); } - - auto indexExpr = subscriptExpr->indexExprs[0]; - - if (!indexExpr->type->equals(m_astBuilder->getIntType()) && - !indexExpr->type->equals(m_astBuilder->getUIntType())) + else if (subscriptExpr->indexExprs.getCount() != 0) { - getSink()->diagnose(indexExpr, Diagnostics::subscriptIndexNonInteger); - return CreateErrorExpr(subscriptExpr); + getSink()->diagnose(subscriptExpr, Diagnostics::multiDimensionalArrayNotSupported); } - subscriptExpr->type = QualType(elementType); - - // TODO(tfoley): need to be more careful about this stuff - subscriptExpr->type.isLeftValue = baseExpr->type.isLeftValue; + auto elementType = CoerceToUsableType(TypeExp(baseExpr, baseTypeType->getType()), nullptr); + auto arrayType = getArrayType(m_astBuilder, elementType, elementCount); + subscriptExpr->type = QualType(m_astBuilder->getTypeType(arrayType)); return subscriptExpr; } - - Expr* SemanticsExprVisitor::visitIndexExpr(IndexExpr* subscriptExpr) + else if (auto baseArrayType = as(baseType)) { - bool needDeref = false; - auto baseExpr = checkBaseForMemberExpr(subscriptExpr->baseExpression, needDeref); - - // If the base expression is a type, it means that this is an array declaration, - // then we should disable short-circuit in case there is logical expression in - // the subscript - auto baseType = baseExpr->type.Ptr(); - auto baseTypeType = as(baseType); - auto subVisitor = (baseTypeType && m_shouldShortCircuitLogicExpr)? - SemanticsVisitor(disableShortCircuitLogicalExpr()) : *this; + return CheckSimpleSubscriptExpr(subscriptExpr, baseArrayType->getElementType()); + } + else if (auto vecType = as(baseType)) + { + return CheckSimpleSubscriptExpr(subscriptExpr, vecType->getElementType()); + } + else if (auto matType = as(baseType)) + { + // TODO(tfoley): We shouldn't go and recompute + // row types over and over like this... :( + auto rowType = createVectorType(matType->getElementType(), matType->getColumnCount()); - for (auto& arg : subscriptExpr->indexExprs) - { - arg = subVisitor.CheckTerm(arg); - } + return CheckSimpleSubscriptExpr(subscriptExpr, rowType); + } - // If anything went wrong in the base expression, - // then just move along... - if (IsErrorExpr(baseExpr)) - return CreateErrorExpr(subscriptExpr); + // Default behavior is to look at all available `__subscript` + // declarations on the type and try to call one of them. - subscriptExpr->baseExpression = baseExpr; + auto operatorName = getName("operator[]"); - // Otherwise, we need to look at the type of the base expression, - // to figure out how subscripting should work. - if (baseTypeType) - { - // We are trying to "index" into a type, so we have an expression like `float[2]` - // which should be interpreted as resolving to an array type. + LookupResult lookupResult = lookUpMember( + m_astBuilder, + this, + operatorName, + baseType, + m_outerScope, + LookupMask::Default, + LookupOptions::NoDeref); + bool diagnosed = false; + lookupResult = + filterLookupResultByVisibilityAndDiagnose(lookupResult, subscriptExpr->loc, diagnosed); + if (!lookupResult.isValid()) + { + if (!diagnosed) + getSink()->diagnose(subscriptExpr, Diagnostics::subscriptNonArray, baseType); + return CreateErrorExpr(subscriptExpr); + } + auto subscriptFuncExpr = createLookupResultExpr( + operatorName, + lookupResult, + subscriptExpr->baseExpression, + subscriptExpr->loc, + subscriptExpr); + + InvokeExpr* subscriptCallExpr = m_astBuilder->create(); + subscriptCallExpr->loc = subscriptExpr->loc; + subscriptCallExpr->functionExpr = subscriptFuncExpr; + subscriptCallExpr->arguments.addRange(subscriptExpr->indexExprs); + subscriptCallExpr->argumentDelimeterLocs.addRange(subscriptExpr->argumentDelimeterLocs); + + return CheckInvokeExprWithCheckedOperands(subscriptCallExpr); +} - IntVal* elementCount = nullptr; - if (subscriptExpr->indexExprs.getCount() == 1) - { - elementCount = CheckIntegerConstantExpression(subscriptExpr->indexExprs[0], IntegerConstantExpressionCoercionType::AnyInteger, nullptr, ConstantFoldingKind::LinkTime); - } - else if (subscriptExpr->indexExprs.getCount() != 0) - { - getSink()->diagnose(subscriptExpr, Diagnostics::multiDimensionalArrayNotSupported); - } +Expr* SemanticsExprVisitor::visitParenExpr(ParenExpr* expr) +{ + auto base = expr->base; + base = CheckTerm(base); - auto elementType = CoerceToUsableType(TypeExp(baseExpr, baseTypeType->getType()), nullptr); - auto arrayType = getArrayType( - m_astBuilder, - elementType, - elementCount); + expr->base = base; + expr->type = base->type; + return expr; +} - subscriptExpr->type = QualType(m_astBuilder->getTypeType(arrayType)); - return subscriptExpr; - } - else if (auto baseArrayType = as(baseType)) +void SemanticsVisitor::maybeDiagnoseThisNotLValue(Expr* expr) +{ + // We will try to handle expressions of the form: + // + // e ::= "this" + // | e . name + // | e [ expr ] + // + // We will unwrap the `e.name` and `e[expr]` cases in a loop. + Expr* e = expr; + for (;;) + { + if (auto memberExpr = as(e)) { - return CheckSimpleSubscriptExpr( - subscriptExpr, - baseArrayType->getElementType()); + e = memberExpr->baseExpression; } - else if (auto vecType = as(baseType)) + else if (auto subscriptExpr = as(e)) { - return CheckSimpleSubscriptExpr( - subscriptExpr, - vecType->getElementType()); + e = subscriptExpr->baseExpression; } - else if (auto matType = as(baseType)) + else { - // TODO(tfoley): We shouldn't go and recompute - // row types over and over like this... :( - auto rowType = createVectorType( - matType->getElementType(), - matType->getColumnCount()); - - return CheckSimpleSubscriptExpr( - subscriptExpr, - rowType); + break; } - - // Default behavior is to look at all available `__subscript` - // declarations on the type and try to call one of them. - - auto operatorName = getName("operator[]"); - - LookupResult lookupResult = lookUpMember( - m_astBuilder, - this, - operatorName, - baseType, - m_outerScope, - LookupMask::Default, - LookupOptions::NoDeref); - bool diagnosed = false; - lookupResult = filterLookupResultByVisibilityAndDiagnose(lookupResult, subscriptExpr->loc, diagnosed); - if (!lookupResult.isValid()) + } + // + // Now we check to see if we have a `this` expression, + // and if it is immutable. + if (auto thisExpr = as(e)) + { + if (!thisExpr->type.isLeftValue) { - if (!diagnosed) - getSink()->diagnose(subscriptExpr, Diagnostics::subscriptNonArray, baseType); - return CreateErrorExpr(subscriptExpr); + getSink()->diagnoseWithoutSourceView(thisExpr, Diagnostics::thisIsImmutableByDefault); } - auto subscriptFuncExpr = createLookupResultExpr( - operatorName, - lookupResult, - subscriptExpr->baseExpression, - subscriptExpr->loc, - subscriptExpr); - - InvokeExpr* subscriptCallExpr = m_astBuilder->create(); - subscriptCallExpr->loc = subscriptExpr->loc; - subscriptCallExpr->functionExpr = subscriptFuncExpr; - subscriptCallExpr->arguments.addRange(subscriptExpr->indexExprs); - subscriptCallExpr->argumentDelimeterLocs.addRange(subscriptExpr->argumentDelimeterLocs); - - return CheckInvokeExprWithCheckedOperands(subscriptCallExpr); } +} - Expr* SemanticsExprVisitor::visitParenExpr(ParenExpr* expr) - { - auto base = expr->base; - base = CheckTerm(base); +Expr* SemanticsVisitor::checkAssignWithCheckedOperands(AssignExpr* expr) +{ + if (expr->right->type.isWriteOnly) + getSink()->diagnose(expr, Diagnostics::readingFromWriteOnly); - expr->base = base; - expr->type = base->type; - return expr; + expr->left = maybeOpenRef(expr->left); + auto type = expr->left->type; + if (auto atomicType = as(type)) + { + type = atomicType->getElementType(); } + auto right = maybeOpenRef(expr->right); + expr->right = coerce(CoercionSite::Assignment, type, right); - void SemanticsVisitor::maybeDiagnoseThisNotLValue(Expr* expr) + if (!expr->left->type.isLeftValue) { - // We will try to handle expressions of the form: - // - // e ::= "this" - // | e . name - // | e [ expr ] - // - // We will unwrap the `e.name` and `e[expr]` cases in a loop. - Expr* e = expr; - for(;;) + if (as(type)) { - if(auto memberExpr = as(e)) - { - e = memberExpr->baseExpression; - } - else if(auto subscriptExpr = as(e)) - { - e = subscriptExpr->baseExpression; - } - else - { - break; - } + // Don't report an l-value issue on an erroneous expression } - // - // Now we check to see if we have a `this` expression, - // and if it is immutable. - if(auto thisExpr = as(e)) + else { - if(!thisExpr->type.isLeftValue) - { - getSink()->diagnoseWithoutSourceView(thisExpr, Diagnostics::thisIsImmutableByDefault); - } + getSink()->diagnose(expr, Diagnostics::assignNonLValue); + + // As a special case, check if the LHS expression is derived + // from a `this` parameter (implicitly or explicitly), which + // is immutable. We can give the user a bit more context into + // what is going on. + // + maybeDiagnoseThisNotLValue(expr->left); } } + expr->type = type; + return expr; +} - Expr* SemanticsVisitor::checkAssignWithCheckedOperands(AssignExpr* expr) - { - if (expr->right->type.isWriteOnly) - getSink()->diagnose(expr, Diagnostics::readingFromWriteOnly); +Expr* SemanticsExprVisitor::visitAssignExpr(AssignExpr* expr) +{ + expr->left = CheckExpr(expr->left); + expr->right = CheckTerm(expr->right); - expr->left = maybeOpenRef(expr->left); - auto type = expr->left->type; - if (auto atomicType = as(type)) - { - type = atomicType->getElementType(); - } - auto right = maybeOpenRef(expr->right); - expr->right = coerce(CoercionSite::Assignment, type, right); - - if (!expr->left->type.isLeftValue) - { - if (as(type)) - { - // Don't report an l-value issue on an erroneous expression - } - else - { - getSink()->diagnose(expr, Diagnostics::assignNonLValue); - - // As a special case, check if the LHS expression is derived - // from a `this` parameter (implicitly or explicitly), which - // is immutable. We can give the user a bit more context into - // what is going on. - // - maybeDiagnoseThisNotLValue(expr->left); - } - } - expr->type = type; - return expr; - } - - Expr* SemanticsExprVisitor::visitAssignExpr(AssignExpr* expr) - { - expr->left = CheckExpr(expr->left); - expr->right = CheckTerm(expr->right); + return checkAssignWithCheckedOperands(expr); +} - return checkAssignWithCheckedOperands(expr); - } +Expr* SemanticsVisitor::CheckExpr(Expr* uncheckedExpr) +{ + auto checkedTerm = CheckTerm(uncheckedExpr); - Expr* SemanticsVisitor::CheckExpr(Expr* uncheckedExpr) - { - auto checkedTerm = CheckTerm(uncheckedExpr); + // First, we want to do any disambiguation that is needed in order + // to turn the `term` into an expression that names a single + // value (and not something overloaded). + // + auto checkedExpr = maybeResolveOverloadedExpr(checkedTerm, LookupMask::Default, getSink()); - // First, we want to do any disambiguation that is needed in order - // to turn the `term` into an expression that names a single - // value (and not something overloaded). - // - auto checkedExpr = maybeResolveOverloadedExpr(checkedTerm, LookupMask::Default, getSink()); + // Next, we want to ensure that the `expr` actually has a type + // that is allowable in an expression context (e.g., make sure + // that `expr` names a value and not a type). + // + // TODO: Implement this step. - // Next, we want to ensure that the `expr` actually has a type - // that is allowable in an expression context (e.g., make sure - // that `expr` names a value and not a type). - // - // TODO: Implement this step. + return checkedExpr; +} - return checkedExpr; - } +static bool _canLValueCoerceScalarType(Type* a, Type* b) +{ + auto basicTypeA = as(a); + auto basicTypeB = as(b); - static bool _canLValueCoerceScalarType(Type* a, Type* b) + if (basicTypeA && basicTypeB) { - auto basicTypeA = as(a); - auto basicTypeB = as(b); + const auto& infoA = BaseTypeInfo::getInfo(basicTypeA->getBaseType()); + const auto& infoB = BaseTypeInfo::getInfo(basicTypeB->getBaseType()); - if (basicTypeA && basicTypeB) + // TODO(JS): Initially this tries to limit where LValueImplict casts happen. + // We could in principal allow different sizes, as long as we converted to a temprorary + // and back again. + // + // For now we just stick with the simple case. + // // We only allow on integer types for now. In effect just allowing any size uint/int + // conversions + if (infoA.sizeInBytes == infoB.sizeInBytes && + (infoA.flags & infoB.flags & BaseTypeInfo::Flag::Integer)) { - const auto& infoA = BaseTypeInfo::getInfo(basicTypeA->getBaseType()); - const auto& infoB = BaseTypeInfo::getInfo(basicTypeB->getBaseType()); - - // TODO(JS): Initially this tries to limit where LValueImplict casts happen. - // We could in principal allow different sizes, as long as we converted to a temprorary - // and back again. - // - // For now we just stick with the simple case. - // // We only allow on integer types for now. In effect just allowing any size uint/int conversions - if (infoA.sizeInBytes == infoB.sizeInBytes && - (infoA.flags & infoB.flags & BaseTypeInfo::Flag::Integer)) - { - return true; - } - + return true; } - return false; } + return false; +} - static bool _canLValueCoerce(Type* a, Type* b) +static bool _canLValueCoerce(Type* a, Type* b) +{ + // We can *assume* here that if they are coercable, that dimensions of vectors + // and matrices match. We might want to assert to be sure... + SLANG_ASSERT(a != b); + if (a->astNodeType == b->astNodeType) { - // We can *assume* here that if they are coercable, that dimensions of vectors - // and matrices match. We might want to assert to be sure... - SLANG_ASSERT(a != b); - if (a->astNodeType == b->astNodeType) + if (auto matA = as(a)) { - if (auto matA = as(a)) - { - return _canLValueCoerceScalarType(matA->getElementType(), static_cast(b)->getElementType()); - } - else if (auto vecA = as(a)) - { - return _canLValueCoerceScalarType(vecA->getScalarType(), static_cast(b)->getScalarType()); - } + return _canLValueCoerceScalarType( + matA->getElementType(), + static_cast(b)->getElementType()); + } + else if (auto vecA = as(a)) + { + return _canLValueCoerceScalarType( + vecA->getScalarType(), + static_cast(b)->getScalarType()); } - return _canLValueCoerceScalarType(a, b); } + return _canLValueCoerceScalarType(a, b); +} - void SemanticsVisitor::compareMemoryQualifierOfParamToArgument( - ParamDecl* paramIn, - Expr* argIn) - { - auto arg = as(argIn); - if (!paramIn || !arg) - return; +void SemanticsVisitor::compareMemoryQualifierOfParamToArgument(ParamDecl* paramIn, Expr* argIn) +{ + auto arg = as(argIn); + if (!paramIn || !arg) + return; + + auto argDeclRef = arg->declRef; + if (!argDeclRef) + return; + auto argDecl = argDeclRef.getDecl(); + auto argMemMods = argDecl->findModifier(); + if (!argMemMods) + return; + uint32_t argQualifiers = argMemMods->getMemoryQualifierBit(); + + uint32_t paramQualifiers = 0; + auto paramMemMods = paramIn->findModifier(); + if (paramMemMods) + paramQualifiers = paramMemMods->getMemoryQualifierBit(); + + if (argQualifiers & MemoryQualifierSetModifier::Flags::kCoherent && + !(paramQualifiers & MemoryQualifierSetModifier::Flags::kCoherent)) + getSink()->diagnose(arg, Diagnostics::argumentHasMoreMemoryQualifiersThanParam, "coherent"); + if (argQualifiers & MemoryQualifierSetModifier::Flags::kReadOnly && + !(paramQualifiers & MemoryQualifierSetModifier::Flags::kReadOnly)) + getSink()->diagnose(arg, Diagnostics::argumentHasMoreMemoryQualifiersThanParam, "readonly"); + if (argQualifiers & MemoryQualifierSetModifier::Flags::kWriteOnly && + !(paramQualifiers & MemoryQualifierSetModifier::Flags::kWriteOnly)) + getSink()->diagnose( + arg, + Diagnostics::argumentHasMoreMemoryQualifiersThanParam, + "writeonly"); + if (argQualifiers & MemoryQualifierSetModifier::Flags::kVolatile && + !(paramQualifiers & MemoryQualifierSetModifier::Flags::kVolatile)) + getSink()->diagnose(arg, Diagnostics::argumentHasMoreMemoryQualifiersThanParam, "volatile"); + // dropping a `restrict` qualifier from arguments is allowed in GLSL with memory qualifiers +} - auto argDeclRef = arg->declRef; - if (!argDeclRef) - return; - auto argDecl = argDeclRef.getDecl(); - auto argMemMods = argDecl->findModifier(); - if(!argMemMods) - return; - uint32_t argQualifiers = argMemMods->getMemoryQualifierBit(); - - uint32_t paramQualifiers = 0; - auto paramMemMods = paramIn->findModifier(); - if(paramMemMods) - paramQualifiers = paramMemMods->getMemoryQualifierBit(); - - if(argQualifiers & MemoryQualifierSetModifier::Flags::kCoherent - && !(paramQualifiers & MemoryQualifierSetModifier::Flags::kCoherent)) - getSink()->diagnose(arg, Diagnostics::argumentHasMoreMemoryQualifiersThanParam, "coherent"); - if(argQualifiers & MemoryQualifierSetModifier::Flags::kReadOnly - && !(paramQualifiers & MemoryQualifierSetModifier::Flags::kReadOnly)) - getSink()->diagnose(arg, Diagnostics::argumentHasMoreMemoryQualifiersThanParam, "readonly"); - if(argQualifiers & MemoryQualifierSetModifier::Flags::kWriteOnly - && !(paramQualifiers & MemoryQualifierSetModifier::Flags::kWriteOnly)) - getSink()->diagnose(arg, Diagnostics::argumentHasMoreMemoryQualifiersThanParam, "writeonly"); - if(argQualifiers & MemoryQualifierSetModifier::Flags::kVolatile - && !(paramQualifiers & MemoryQualifierSetModifier::Flags::kVolatile)) - getSink()->diagnose(arg, Diagnostics::argumentHasMoreMemoryQualifiersThanParam, "volatile"); - // dropping a `restrict` qualifier from arguments is allowed in GLSL with memory qualifiers - } - - Expr* SemanticsVisitor::CheckInvokeExprWithCheckedOperands(InvokeExpr *expr) - { - auto rs = ResolveInvoke(expr); - if (auto invoke = as(rs)) +Expr* SemanticsVisitor::CheckInvokeExprWithCheckedOperands(InvokeExpr* expr) +{ + auto rs = ResolveInvoke(expr); + if (auto invoke = as(rs)) + { + // if this is still an invoke expression, test arguments passed to inout/out parameter are + // LValues + if (auto funcType = as(invoke->functionExpr->type)) { - // if this is still an invoke expression, test arguments passed to inout/out parameter are LValues - if(auto funcType = as(invoke->functionExpr->type)) + if (!funcType->getErrorType()->equals(m_astBuilder->getBottomType())) { - if (!funcType->getErrorType()->equals(m_astBuilder->getBottomType())) + // If the callee throws, make sure we are inside a try clause. + if (m_enclosingTryClauseType == TryClauseType::None) { - // If the callee throws, make sure we are inside a try clause. - if (m_enclosingTryClauseType == TryClauseType::None) - { - getSink()->diagnose(invoke, Diagnostics::mustUseTryClauseToCallAThrowFunc); - } + getSink()->diagnose(invoke, Diagnostics::mustUseTryClauseToCallAThrowFunc); } + } - auto funcDeclRefExpr = as(invoke->functionExpr); - FunctionDeclBase* funcDeclBase = nullptr; - if (funcDeclRefExpr) - funcDeclBase = as(funcDeclRefExpr->declRef.getDecl()); + auto funcDeclRefExpr = as(invoke->functionExpr); + FunctionDeclBase* funcDeclBase = nullptr; + if (funcDeclRefExpr) + funcDeclBase = as(funcDeclRefExpr->declRef.getDecl()); - Index paramCount = funcType->getParamCount(); - for (Index pp = 0; pp < paramCount; ++pp) + Index paramCount = funcType->getParamCount(); + for (Index pp = 0; pp < paramCount; ++pp) + { + auto paramType = funcType->getParamType(pp); + Expr* argExpr = nullptr; + ParamDecl* paramDecl = nullptr; + if (pp < expr->arguments.getCount()) { - auto paramType = funcType->getParamType(pp); - Expr* argExpr = nullptr; - ParamDecl* paramDecl = nullptr; - if (pp < expr->arguments.getCount()) - { - argExpr = expr->arguments[pp]; - if(funcDeclBase) - paramDecl = funcDeclBase->getParameters()[pp]; - } - compareMemoryQualifierOfParamToArgument(paramDecl, argExpr); + argExpr = expr->arguments[pp]; + if (funcDeclBase) + paramDecl = funcDeclBase->getParameters()[pp]; + } + compareMemoryQualifierOfParamToArgument(paramDecl, argExpr); - if (as(paramType) || as(paramType)) + if (as(paramType) || as(paramType)) + { + // `out`, `inout`, and `ref` parameters currently require + // an *exact* match on the type of the argument. + // + // TODO: relax this requirement by allowing an argument + // for an `inout` parameter to be converted in both + // directions. + // + if (argExpr) { - // `out`, `inout`, and `ref` parameters currently require - // an *exact* match on the type of the argument. - // - // TODO: relax this requirement by allowing an argument - // for an `inout` parameter to be converted in both - // directions. - // - if( argExpr ) + if (!argExpr->type.isLeftValue) { - if( !argExpr->type.isLeftValue) + auto implicitCastExpr = as(argExpr); + + // NOTE: + // This is currently only enabled for in/inout based scenarios. Ie NOT + // ref. + // + // Depending on the target there can be an issue around atomics. + // The fall back transformation with InOut/OutImplicitCast is to + // introduce a temporary, and do the work on that and copy back. + // + // This doesn't work with an atomic. So the work around is to not enable + // the transformation with ref types, which atomics are defined on. + // + // An argument can be made that transformation shouldn't apply to the + // ref scenario in general. + if (implicitCastExpr && as(paramType) && + _canLValueCoerce( + implicitCastExpr->arguments[0]->type, + implicitCastExpr->type)) + { + // This is to work around issues like + // + // ``` + // int a = 0; + // uint b = 1; + // a += b; + // ``` + // That strictly speaking it's not allowed, but we are going to + // allow it for now for situations were the types are uint/int and + // vector/matrix varieties of those types + // + // Then in lowering we are going to insert code to do something like + // ``` + // var OutType: tmp = arg; + // f(... tmp); + // arg = tmp; + // ``` + + TypeCastExpr* lValueImplicitCast; + + // We want to record if the cast is being used for `out` or + // `inout`/`ref` as if it's just `out` we won't need to convert + // before passing in. + if (as(paramType)) + { + lValueImplicitCast = + getASTBuilder()->create( + *implicitCastExpr); + } + else + { + lValueImplicitCast = + getASTBuilder()->create( + *implicitCastExpr); + } + + // Replace the expression. This should make this situation easier to + // detect. + expr->arguments[pp] = lValueImplicitCast; + } + else if (!as(argExpr->type)) { - auto implicitCastExpr = as(argExpr); - - // NOTE: - // This is currently only enabled for in/inout based scenarios. Ie NOT ref. - // - // Depending on the target there can be an issue around atomics. - // The fall back transformation with InOut/OutImplicitCast is to introduce - // a temporary, and do the work on that and copy back. - // - // This doesn't work with an atomic. So the work around is to not enable - // the transformation with ref types, which atomics are defined on. - // - // An argument can be made that transformation shouldn't apply to the ref scenario in general. - if (implicitCastExpr && - as(paramType) && - _canLValueCoerce(implicitCastExpr->arguments[0]->type, implicitCastExpr->type)) + getSink()->diagnose( + argExpr, + Diagnostics::argumentExpectedLValue, + pp); + + + if (implicitCastExpr) { - // This is to work around issues like - // - // ``` - // int a = 0; - // uint b = 1; - // a += b; - // ``` - // That strictly speaking it's not allowed, but we are going to allow it for now - // for situations were the types are uint/int and vector/matrix varieties of those types - // - // Then in lowering we are going to insert code to do something like - // ``` - // var OutType: tmp = arg; - // f(... tmp); - // arg = tmp; - // ``` - - TypeCastExpr* lValueImplicitCast; - - // We want to record if the cast is being used for `out` or `inout`/`ref` as - // if it's just `out` we won't need to convert before passing in. - if (as(paramType)) + const DiagnosticInfo* diagnostic = nullptr; + + // Try and determine reason for failure + if (as(paramType)) { - lValueImplicitCast = getASTBuilder()->create(*implicitCastExpr); + // Ref types are not allowed to use this mechanism because + // it breaks atomics + diagnostic = &Diagnostics::implicitCastUsedAsLValueRef; } - else + else if (!_canLValueCoerce( + implicitCastExpr->arguments[0]->type, + implicitCastExpr->type)) { - lValueImplicitCast = getASTBuilder()->create(*implicitCastExpr); + // We restict what types can use this mechanism - currently + // int/uint and same sized matrix/vectors of those types. + diagnostic = &Diagnostics::implicitCastUsedAsLValueType; } - - // Replace the expression. This should make this situation easier to detect. - expr->arguments[pp] = lValueImplicitCast; - } - else if (!as(argExpr->type)) - { - getSink()->diagnose( - argExpr, - Diagnostics::argumentExpectedLValue, - pp); - - - if(implicitCastExpr) + else { - const DiagnosticInfo* diagnostic = nullptr; - - // Try and determine reason for failure - if (as(paramType)) - { - // Ref types are not allowed to use this mechanism because it breaks atomics - diagnostic = &Diagnostics::implicitCastUsedAsLValueRef; - } - else if (!_canLValueCoerce(implicitCastExpr->arguments[0]->type, implicitCastExpr->type)) - { - // We restict what types can use this mechanism - currently int/uint and same sized matrix/vectors - // of those types. - diagnostic = &Diagnostics::implicitCastUsedAsLValueType; - } - else - { - // Fall back, in case there are other reasons... - diagnostic = &Diagnostics::implicitCastUsedAsLValue; - } - getSink()->diagnoseWithoutSourceView( - argExpr, - *diagnostic, - implicitCastExpr->arguments[0]->type, - implicitCastExpr->type); + // Fall back, in case there are other reasons... + diagnostic = &Diagnostics::implicitCastUsedAsLValue; } - - maybeDiagnoseThisNotLValue(argExpr); + getSink()->diagnoseWithoutSourceView( + argExpr, + *diagnostic, + implicitCastExpr->arguments[0]->type, + implicitCastExpr->type); } + + maybeDiagnoseThisNotLValue(argExpr); } } - else - { - // There are two ways we could get here, both involving - // a call where the number of argument expressions is - // less than the number of parameters on the callee: - // - // 1. There might be fewer arguments than parameters - // because the trailing parameters should be defaulted - // - // 2. There might be fewer arguments than parameters - // because the call is incorrect. - // - // In case (2) an error would have already been diagnosed, - // and we don't want to emit another cascading error here. - // - // In case (1) this implies the user declared an `out` - // or `inout` parameter with a default argument expression. - // That should be an error, but it should be detected - // on the declaration instead of here at the use site. - // - // Thus, it makes sense to ignore this case here. - } + } + else + { + // There are two ways we could get here, both involving + // a call where the number of argument expressions is + // less than the number of parameters on the callee: + // + // 1. There might be fewer arguments than parameters + // because the trailing parameters should be defaulted + // + // 2. There might be fewer arguments than parameters + // because the call is incorrect. + // + // In case (2) an error would have already been diagnosed, + // and we don't want to emit another cascading error here. + // + // In case (1) this implies the user declared an `out` + // or `inout` parameter with a default argument expression. + // That should be an error, but it should be detected + // on the declaration instead of here at the use site. + // + // Thus, it makes sense to ignore this case here. } } + } - if (auto higherOrderInvoke = as(invoke->functionExpr)) + if (auto higherOrderInvoke = as(invoke->functionExpr)) + { + FunctionDifferentiableLevel requiredLevel; + if (auto funcDeclExpr = as( + getInnerMostExprFromHigherOrderExpr(higherOrderInvoke, requiredLevel))) { - FunctionDifferentiableLevel requiredLevel; - if (auto funcDeclExpr = as( - getInnerMostExprFromHigherOrderExpr(higherOrderInvoke, requiredLevel))) + auto funcDecl = as(funcDeclExpr->declRef.getDecl()); + if (funcDecl) { - auto funcDecl = as(funcDeclExpr->declRef.getDecl()); - if (funcDecl) + if (requiredLevel == FunctionDifferentiableLevel::Forward && + !getShared()->isDifferentiableFunc(funcDecl)) { - if (requiredLevel == FunctionDifferentiableLevel::Forward && - !getShared()->isDifferentiableFunc(funcDecl)) - { - getSink()->diagnose(funcDeclExpr, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "forward"); - } - if (requiredLevel == FunctionDifferentiableLevel::Backward && - !getShared()->isBackwardDifferentiableFunc(funcDecl)) - { - getSink()->diagnose(funcDeclExpr, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "backward"); - } - if (!isEffectivelyStatic(funcDecl) && !isGlobalDecl(funcDecl)) - { - getSink()->diagnose(invoke->functionExpr, Diagnostics::nonStaticMemberFunctionNotAllowedAsDiffOperand, funcDecl); - } + getSink()->diagnose( + funcDeclExpr, + Diagnostics::functionNotMarkedAsDifferentiable, + funcDecl, + "forward"); + } + if (requiredLevel == FunctionDifferentiableLevel::Backward && + !getShared()->isBackwardDifferentiableFunc(funcDecl)) + { + getSink()->diagnose( + funcDeclExpr, + Diagnostics::functionNotMarkedAsDifferentiable, + funcDecl, + "backward"); + } + if (!isEffectivelyStatic(funcDecl) && !isGlobalDecl(funcDecl)) + { + getSink()->diagnose( + invoke->functionExpr, + Diagnostics::nonStaticMemberFunctionNotAllowedAsDiffOperand, + funcDecl); } } } } } - return rs; } + return rs; +} - Expr* SemanticsExprVisitor::visitSelectExpr(SelectExpr* expr) - { - auto result = visitInvokeExpr(expr); - if (as(result->type.type)) - return result; - auto invokeExpr = as(result); - if (!result) - return result; - if (invokeExpr->arguments.getCount() != 3) - return result; - - if (as(invokeExpr->arguments[0]->type.type)) - { - auto newArgs = invokeExpr->arguments; - expr->arguments.clear(); - expr->arguments = newArgs; - expr->type = invokeExpr->type; - return expr; - } - - if (getParentDifferentiableAttribute()) - { - // If we are in a differentiable func, issue - // a diagnostic on use of non short-circuiting select. - getSink()->diagnose(expr->loc, Diagnostics::useOfNonShortCircuitingOperatorInDiffFunc); - } - else - { - // For all other functions, we issue a warning for deprecation of vector-typed ?: operator. - getSink()->diagnose(expr->loc, Diagnostics::useOfNonShortCircuitingOperator); - } +Expr* SemanticsExprVisitor::visitSelectExpr(SelectExpr* expr) +{ + auto result = visitInvokeExpr(expr); + if (as(result->type.type)) + return result; + auto invokeExpr = as(result); + if (!result) + return result; + if (invokeExpr->arguments.getCount() != 3) return result; + + if (as(invokeExpr->arguments[0]->type.type)) + { + auto newArgs = invokeExpr->arguments; + expr->arguments.clear(); + expr->arguments = newArgs; + expr->type = invokeExpr->type; + return expr; } - Expr* SemanticsExprVisitor::convertToLogicOperatorExpr(InvokeExpr* expr) + if (getParentDifferentiableAttribute()) + { + // If we are in a differentiable func, issue + // a diagnostic on use of non short-circuiting select. + getSink()->diagnose(expr->loc, Diagnostics::useOfNonShortCircuitingOperatorInDiffFunc); + } + else { - LogicOperatorShortCircuitExpr* newExpr = nullptr; + // For all other functions, we issue a warning for deprecation of vector-typed ?: operator. + getSink()->diagnose(expr->loc, Diagnostics::useOfNonShortCircuitingOperator); + } + return result; +} - // If the logic expression is inside the generic parameter list, it cannot support short-circuit - // which will generate the ifelse branch. - if (!m_shouldShortCircuitLogicExpr) - { - return nullptr; - } +Expr* SemanticsExprVisitor::convertToLogicOperatorExpr(InvokeExpr* expr) +{ + LogicOperatorShortCircuitExpr* newExpr = nullptr; + + // If the logic expression is inside the generic parameter list, it cannot support short-circuit + // which will generate the ifelse branch. + if (!m_shouldShortCircuitLogicExpr) + { + return nullptr; + } - if (auto varExpr = as(expr->functionExpr)) + if (auto varExpr = as(expr->functionExpr)) + { + if ((varExpr->name->text == "&&") || (varExpr->name->text == "||")) { - if ((varExpr->name->text == "&&") || (varExpr->name->text == "||")) + // We only use short-circuiting in scalar input, will fall back + // to non-short-circuiting in vector input. + bool shortCircuitSupport = true; + for (auto& arg : expr->arguments) { - // We only use short-circuiting in scalar input, will fall back - // to non-short-circuiting in vector input. - bool shortCircuitSupport = true; - for (auto & arg : expr->arguments) + if (!as(arg->type.type)) { - if(!as(arg->type.type)) - { - shortCircuitSupport = false; - } + shortCircuitSupport = false; } + } - if (!shortCircuitSupport) - { - return nullptr; - } + if (!shortCircuitSupport) + { + return nullptr; + } - // We do the cast in the 2nd pass because we want to leave it for 'visitInvokeExpr' - // to handle if this expression doesn't support short-circuiting. - for (auto & arg : expr->arguments) - { - arg = coerce(CoercionSite::Argument, m_astBuilder->getBoolType(), arg); - } + // We do the cast in the 2nd pass because we want to leave it for 'visitInvokeExpr' + // to handle if this expression doesn't support short-circuiting. + for (auto& arg : expr->arguments) + { + arg = coerce(CoercionSite::Argument, m_astBuilder->getBoolType(), arg); + } - expr->functionExpr = CheckTerm(expr->functionExpr); - newExpr = m_astBuilder->create(); - if (varExpr->name->text == "&&") - { - newExpr->flavor = LogicOperatorShortCircuitExpr::Flavor::And; - } - else - { - newExpr->flavor = LogicOperatorShortCircuitExpr::Flavor::Or; - } - newExpr->loc = expr->loc; - newExpr->functionExpr = expr->functionExpr; - newExpr->type = m_astBuilder->getBoolType(); - newExpr->arguments = expr->arguments; + expr->functionExpr = CheckTerm(expr->functionExpr); + newExpr = m_astBuilder->create(); + if (varExpr->name->text == "&&") + { + newExpr->flavor = LogicOperatorShortCircuitExpr::Flavor::And; + } + else + { + newExpr->flavor = LogicOperatorShortCircuitExpr::Flavor::Or; } + newExpr->loc = expr->loc; + newExpr->functionExpr = expr->functionExpr; + newExpr->type = m_astBuilder->getBoolType(); + newExpr->arguments = expr->arguments; } + } - return newExpr; + return newExpr; +} + +Expr* SemanticsExprVisitor::visitInvokeExpr(InvokeExpr* expr) +{ + // check the base expression first + if (!expr->originalFunctionExpr) + expr->originalFunctionExpr = expr->functionExpr; + auto treatAsDifferentiableExpr = m_treatAsDifferentiableExpr; + m_treatAsDifferentiableExpr = nullptr; + // Next check the argument expressions + for (auto& arg : expr->arguments) + { + arg = CheckExpr(arg); } - Expr* SemanticsExprVisitor::visitInvokeExpr(InvokeExpr* expr) + // if the expression is '&&' or '||', we will convert it + // to use short-circuit evaluation. + if (auto newExpr = convertToLogicOperatorExpr(expr)) + return newExpr; + + expr->functionExpr = CheckTerm(expr->functionExpr); + + if (auto baseType = as(expr->functionExpr->type)) { - // check the base expression first - if (!expr->originalFunctionExpr) - expr->originalFunctionExpr = expr->functionExpr; - auto treatAsDifferentiableExpr = m_treatAsDifferentiableExpr; - m_treatAsDifferentiableExpr = nullptr; - // Next check the argument expressions - for (auto & arg : expr->arguments) + // If callee is a value of DeclRefType, then it is a functor. + // We need to look for `operator()` member within the type and + // call that instead. + auto operatorName = getName("()"); + + bool needDeref = false; + expr->functionExpr = maybeInsertImplicitOpForMemberBase(expr->functionExpr, needDeref); + + LookupResult lookupResult = lookUpMember( + m_astBuilder, + this, + operatorName, + expr->functionExpr->type, + m_outerScope, + LookupMask::Default, + LookupOptions::NoDeref); + bool diagnosed = false; + lookupResult = + filterLookupResultByVisibilityAndDiagnose(lookupResult, expr->loc, diagnosed); + if (!lookupResult.isValid()) { - arg = CheckExpr(arg); + if (!diagnosed) + getSink()->diagnose(expr, Diagnostics::callOperatorNotFound, baseType); + return CreateErrorExpr(expr); } + auto callFuncExpr = createLookupResultExpr( + operatorName, + lookupResult, + expr->functionExpr, + expr->loc, + expr->functionExpr); + expr->functionExpr = callFuncExpr; + } - // if the expression is '&&' or '||', we will convert it - // to use short-circuit evaluation. - if (auto newExpr = convertToLogicOperatorExpr(expr)) - return newExpr; - - expr->functionExpr = CheckTerm(expr->functionExpr); + m_treatAsDifferentiableExpr = treatAsDifferentiableExpr; - if (auto baseType = as(expr->functionExpr->type)) + // If we are in a differentiable function, register differential witness tables involved in + // this call. + if (m_parentFunc && m_parentFunc->hasModifier()) + { + for (auto& arg : expr->arguments) { - // If callee is a value of DeclRefType, then it is a functor. - // We need to look for `operator()` member within the type and - // call that instead. - auto operatorName = getName("()"); + maybeRegisterDifferentiableType(m_astBuilder, arg->type.type); + } + } - bool needDeref = false; - expr->functionExpr = maybeInsertImplicitOpForMemberBase(expr->functionExpr, needDeref); + auto checkedExpr = CheckInvokeExprWithCheckedOperands(expr); - LookupResult lookupResult = lookUpMember( - m_astBuilder, - this, - operatorName, - expr->functionExpr->type, - m_outerScope, - LookupMask::Default, - LookupOptions::NoDeref); - bool diagnosed = false; - lookupResult = filterLookupResultByVisibilityAndDiagnose(lookupResult, expr->loc, diagnosed); - if (!lookupResult.isValid()) - { - if (!diagnosed) - getSink()->diagnose(expr, Diagnostics::callOperatorNotFound, baseType); - return CreateErrorExpr(expr); - } - auto callFuncExpr = createLookupResultExpr( - operatorName, - lookupResult, - expr->functionExpr, - expr->loc, - expr->functionExpr); - expr->functionExpr = callFuncExpr; - } + // Perform additional validation for known built-in functions. + maybeCheckKnownBuiltinInvocation(checkedExpr); - m_treatAsDifferentiableExpr = treatAsDifferentiableExpr; + if (m_parentDifferentiableAttr) + { + FunctionDifferentiableLevel callerDiffLevel = FunctionDifferentiableLevel::None; + if (m_parentFunc) + callerDiffLevel = getShared()->getFuncDifferentiableLevel(m_parentFunc); - // If we are in a differentiable function, register differential witness tables involved in - // this call. - if (m_parentFunc && m_parentFunc->hasModifier()) + if (auto checkedInvokeExpr = as(checkedExpr)) { + // Register types for final resolved invoke arguments again. for (auto& arg : expr->arguments) { maybeRegisterDifferentiableType(m_astBuilder, arg->type.type); } - } - - auto checkedExpr = CheckInvokeExprWithCheckedOperands(expr); - - // Perform additional validation for known built-in functions. - maybeCheckKnownBuiltinInvocation(checkedExpr); - - if (m_parentDifferentiableAttr) - { - FunctionDifferentiableLevel callerDiffLevel = FunctionDifferentiableLevel::None; - if (m_parentFunc) - callerDiffLevel = getShared()->getFuncDifferentiableLevel(m_parentFunc); - if (auto checkedInvokeExpr = as(checkedExpr)) + if (auto calleeExpr = as(checkedInvokeExpr->functionExpr)) { - // Register types for final resolved invoke arguments again. - for (auto& arg : expr->arguments) - { - maybeRegisterDifferentiableType(m_astBuilder, arg->type.type); - } - - if (auto calleeExpr = as(checkedInvokeExpr->functionExpr)) + if (auto calleeDecl = as(calleeExpr->declRef.getDecl())) { - if (auto calleeDecl = as(calleeExpr->declRef.getDecl())) + auto calleeDiffLevel = getShared()->getFuncDifferentiableLevel(calleeDecl); + if (calleeDiffLevel >= callerDiffLevel) { - auto calleeDiffLevel = getShared()->getFuncDifferentiableLevel(calleeDecl); - if (calleeDiffLevel >= callerDiffLevel) + if (!m_treatAsDifferentiableExpr) { - if (!m_treatAsDifferentiableExpr) - { - auto newFuncExpr = - getASTBuilder()->create(); - newFuncExpr->type = checkedInvokeExpr->type; - newFuncExpr->innerExpr = checkedInvokeExpr; - newFuncExpr->loc = checkedInvokeExpr->loc; - newFuncExpr->flavor = TreatAsDifferentiableExpr::Flavor::Differentiable; - checkedExpr = newFuncExpr; - } - else - { - getSink()->diagnose( - m_treatAsDifferentiableExpr, - Diagnostics::useOfNoDiffOnDifferentiableFunc); - } + auto newFuncExpr = getASTBuilder()->create(); + newFuncExpr->type = checkedInvokeExpr->type; + newFuncExpr->innerExpr = checkedInvokeExpr; + newFuncExpr->loc = checkedInvokeExpr->loc; + newFuncExpr->flavor = TreatAsDifferentiableExpr::Flavor::Differentiable; + checkedExpr = newFuncExpr; + } + else + { + getSink()->diagnose( + m_treatAsDifferentiableExpr, + Diagnostics::useOfNoDiffOnDifferentiableFunc); } } } } - maybeRegisterDifferentiableType(m_astBuilder, checkedExpr->type.type); } - return checkedExpr; + maybeRegisterDifferentiableType(m_astBuilder, checkedExpr->type.type); } + return checkedExpr; +} - Expr* SemanticsExprVisitor::visitVarExpr(VarExpr *expr) +Expr* SemanticsExprVisitor::visitVarExpr(VarExpr* expr) +{ + // If we've already resolved this expression, don't try again. + if (expr->declRef) { - // If we've already resolved this expression, don't try again. - if (expr->declRef) - { - if (!expr->type) - expr->type = GetTypeForDeclRef(expr->declRef, expr->loc); - return expr; - } - expr->type = QualType(m_astBuilder->getErrorType()); - auto lookupResult = lookUp( - m_astBuilder, this, expr->name, expr->scope, LookupMask::Default, false, getDeclToExcludeFromLookup()); - - bool diagnosed = false; - lookupResult = filterLookupResultByVisibilityAndDiagnose(lookupResult, expr->loc, diagnosed); + if (!expr->type) + expr->type = GetTypeForDeclRef(expr->declRef, expr->loc); + return expr; + } + expr->type = QualType(m_astBuilder->getErrorType()); + auto lookupResult = lookUp( + m_astBuilder, + this, + expr->name, + expr->scope, + LookupMask::Default, + false, + getDeclToExcludeFromLookup()); + + bool diagnosed = false; + lookupResult = filterLookupResultByVisibilityAndDiagnose(lookupResult, expr->loc, diagnosed); + + if (expr->name == getSession()->getCompletionRequestTokenName()) + { + auto scopeKind = CompletionSuggestions::ScopeKind::Expr; + if (!m_parentFunc) + scopeKind = CompletionSuggestions::ScopeKind::Decl; + suggestCompletionItems(scopeKind, lookupResult); + return expr; + } - if (expr->name == getSession()->getCompletionRequestTokenName()) - { - auto scopeKind = CompletionSuggestions::ScopeKind::Expr; - if (!m_parentFunc) - scopeKind = CompletionSuggestions::ScopeKind::Decl; - suggestCompletionItems(scopeKind, lookupResult); - return expr; - } + if (lookupResult.isValid()) + { + return createLookupResultExpr(expr->name, lookupResult, nullptr, expr->loc, expr); + } - if (lookupResult.isValid()) - { - return createLookupResultExpr( - expr->name, - lookupResult, - nullptr, - expr->loc, - expr); - } + if (!diagnosed) + getSink()->diagnose(expr, Diagnostics::undefinedIdentifier2, expr->name); - if (!diagnosed) - getSink()->diagnose(expr, Diagnostics::undefinedIdentifier2, expr->name); + return expr; +} - return expr; +Type* SemanticsVisitor::_toDifferentialParamType(Type* primalType) +{ + // Check for type modifiers like 'out' and 'inout'. We need to differentiate the + // nested type. + // + if (auto primalOutType = as(primalType)) + { + return m_astBuilder->getOutType(_toDifferentialParamType(primalOutType->getValueType())); } + else if (auto primalInOutType = as(primalType)) + { + return m_astBuilder->getInOutType( + _toDifferentialParamType(primalInOutType->getValueType())); + } + return getDifferentialPairType(primalType); +} - Type* SemanticsVisitor::_toDifferentialParamType(Type* primalType) +Type* SemanticsVisitor::getDifferentialPairType(Type* primalType) +{ + if (auto modifiedType = as(primalType)) { - // Check for type modifiers like 'out' and 'inout'. We need to differentiate the - // nested type. - // - if (auto primalOutType = as(primalType)) - { - return m_astBuilder->getOutType(_toDifferentialParamType(primalOutType->getValueType())); - } - else if (auto primalInOutType = as(primalType)) - { - return m_astBuilder->getInOutType(_toDifferentialParamType(primalInOutType->getValueType())); - } - return getDifferentialPairType(primalType); + if (modifiedType->findModifier()) + return modifiedType->getBase(); } - Type* SemanticsVisitor::getDifferentialPairType(Type* primalType) + if (auto typePack = as(primalType)) { - if (auto modifiedType = as(primalType)) + // The differential pair of a type pack should be a type pack of differential pairs. + List diffTypes; + for (Index i = 0; i < typePack->getTypeCount(); i++) { - if (modifiedType->findModifier()) - return modifiedType->getBase(); + auto t = typePack->getElementType(i); + diffTypes.add(getDifferentialPairType(t)); } - - if (auto typePack = as(primalType)) + return m_astBuilder->getTypePack(diffTypes.getArrayView()); + } + else if (isAbstractTypePack(primalType)) + { + // The differential pair of an abstract type pack P should be `expand DifferentialPair`. + auto eachType = m_astBuilder->getEachType(primalType); + auto diffPairEachType = getDifferentialPairType(eachType); + if (auto expandType = as(primalType)) { - // The differential pair of a type pack should be a type pack of differential pairs. - List diffTypes; - for (Index i = 0; i < typePack->getTypeCount(); i++) + List capturedTypePacks; + for (Index i = 0; i < expandType->getCapturedTypePackCount(); i++) { - auto t = typePack->getElementType(i); - diffTypes.add(getDifferentialPairType(t)); + capturedTypePacks.add(expandType->getCapturedTypePack(i)); } - return m_astBuilder->getTypePack(diffTypes.getArrayView()); + return m_astBuilder->getExpandType(diffPairEachType, capturedTypePacks.getArrayView()); } - else if (isAbstractTypePack(primalType)) + else { - // The differential pair of an abstract type pack P should be `expand DifferentialPair`. - auto eachType = m_astBuilder->getEachType(primalType); - auto diffPairEachType = getDifferentialPairType(eachType); - if (auto expandType = as(primalType)) - { - List capturedTypePacks; - for (Index i = 0; i < expandType->getCapturedTypePackCount(); i++) - { - capturedTypePacks.add(expandType->getCapturedTypePack(i)); - } - return m_astBuilder->getExpandType(diffPairEachType, capturedTypePacks.getArrayView()); - } - else - { - return m_astBuilder->getExpandType(diffPairEachType, makeArrayViewSingle(primalType)); - } + return m_astBuilder->getExpandType(diffPairEachType, makeArrayViewSingle(primalType)); } + } - // Get a reference to the builtin 'IDifferentiable' interface - auto differentiableInterface = getASTBuilder()->getDifferentiableInterfaceType(); - auto differentiableRefInterface = getASTBuilder()->getDifferentiableRefInterfaceType(); + // Get a reference to the builtin 'IDifferentiable' interface + auto differentiableInterface = getASTBuilder()->getDifferentiableInterfaceType(); + auto differentiableRefInterface = getASTBuilder()->getDifferentiableRefInterfaceType(); - // Check if the provided type inherits from IDifferentiable. - // If not, return the original type. - if (auto conformanceWitness = isTypeDifferentiable(primalType)) + // Check if the provided type inherits from IDifferentiable. + // If not, return the original type. + if (auto conformanceWitness = isTypeDifferentiable(primalType)) + { + if (conformanceWitness->getSup() == differentiableInterface) { - if (conformanceWitness->getSup() == differentiableInterface) - { - return m_astBuilder->getDifferentialPairType(primalType, conformanceWitness); - } - else if (conformanceWitness->getSup() == differentiableRefInterface) - { - return m_astBuilder->getDifferentialPtrPairType(primalType, conformanceWitness); - } + return m_astBuilder->getDifferentialPairType(primalType, conformanceWitness); + } + else if (conformanceWitness->getSup() == differentiableRefInterface) + { + return m_astBuilder->getDifferentialPtrPairType(primalType, conformanceWitness); } - return primalType; } + return primalType; +} - Type* SemanticsVisitor::getForwardDiffFuncType(FuncType* originalType) - { - // Resolve JVP type here. - // Note that this type checking needs to be in sync with - // the auto-generation logic in slang-ir-jvp-diff.cpp - List paramTypes; +Type* SemanticsVisitor::getForwardDiffFuncType(FuncType* originalType) +{ + // Resolve JVP type here. + // Note that this type checking needs to be in sync with + // the auto-generation logic in slang-ir-jvp-diff.cpp + List paramTypes; - // The JVP return type is float if primal return type is float - // void otherwise. - // - auto resultType = getDifferentialPairType(originalType->getResultType()); - - // No support for differentiating function that throw errors, for now. - SLANG_ASSERT(originalType->getErrorType()->equals(m_astBuilder->getBottomType())); - auto errorType = originalType->getErrorType(); + // The JVP return type is float if primal return type is float + // void otherwise. + // + auto resultType = getDifferentialPairType(originalType->getResultType()); - for (Index i = 0; i < originalType->getParamCount(); i++) - { - if(auto jvpParamType = _toDifferentialParamType(originalType->getParamType(i))) - paramTypes.add(jvpParamType); - } - FuncType* jvpType = m_astBuilder->getOrCreate(paramTypes.getArrayView(), resultType, errorType); + // No support for differentiating function that throw errors, for now. + SLANG_ASSERT(originalType->getErrorType()->equals(m_astBuilder->getBottomType())); + auto errorType = originalType->getErrorType(); - return jvpType; + for (Index i = 0; i < originalType->getParamCount(); i++) + { + if (auto jvpParamType = _toDifferentialParamType(originalType->getParamType(i))) + paramTypes.add(jvpParamType); } + FuncType* jvpType = + m_astBuilder->getOrCreate(paramTypes.getArrayView(), resultType, errorType); - Type* SemanticsVisitor::getBackwardDiffFuncType(FuncType* originalType) - { - // Resolve backward diff type here. - // Note that this type checking needs to be in sync with - // the auto-generation logic in slang-ir-jvp-diff.cpp - List paramTypes; + return jvpType; +} - // The backward diff return type is void - // - auto resultType = m_astBuilder->getVoidType(); +Type* SemanticsVisitor::getBackwardDiffFuncType(FuncType* originalType) +{ + // Resolve backward diff type here. + // Note that this type checking needs to be in sync with + // the auto-generation logic in slang-ir-jvp-diff.cpp + List paramTypes; + + // The backward diff return type is void + // + auto resultType = m_astBuilder->getVoidType(); - // No support for differentiating function that throw errors, for now. - SLANG_ASSERT(originalType->getErrorType()->equals(m_astBuilder->getBottomType())); - auto errorType = originalType->getErrorType(); + // No support for differentiating function that throw errors, for now. + SLANG_ASSERT(originalType->getErrorType()->equals(m_astBuilder->getBottomType())); + auto errorType = originalType->getErrorType(); - for (Index i = 0; i < originalType->getParamCount(); i++) + for (Index i = 0; i < originalType->getParamCount(); i++) + { + if (auto outType = as(originalType->getParamType(i))) { - if (auto outType = as(originalType->getParamType(i))) + auto diffElementType = tryGetDifferentialType(m_astBuilder, outType->getValueType()); + if (diffElementType) { - auto diffElementType = - tryGetDifferentialType(m_astBuilder, outType->getValueType()); - if (diffElementType) - { - paramTypes.add(diffElementType); - } - else - { - continue; - } + paramTypes.add(diffElementType); + } + else + { + continue; } - else if (auto derivType = _toDifferentialParamType(originalType->getParamType(i))) + } + else if (auto derivType = _toDifferentialParamType(originalType->getParamType(i))) + { + if (as(derivType)) { - if (as(derivType)) - { - // An `in` differentiable parameter becomes an `inout` parameter. - derivType = m_astBuilder->getInOutType(derivType); - } - else if (auto inoutType = as(derivType)) + // An `in` differentiable parameter becomes an `inout` parameter. + derivType = m_astBuilder->getInOutType(derivType); + } + else if (auto inoutType = as(derivType)) + { + if (!as(inoutType->getValueType())) { - if (!as(inoutType->getValueType())) - { - // An `inout` non differentiable parameter becomes an `in` parameter - // (removing `out`). - derivType = inoutType->getValueType(); - } + // An `inout` non differentiable parameter becomes an `in` parameter + // (removing `out`). + derivType = inoutType->getValueType(); } - paramTypes.add(derivType); } + paramTypes.add(derivType); } - - // Last parameter is the initial derivative of the original return type - auto dOutType = tryGetDifferentialType(m_astBuilder, originalType->getResultType()); - if (dOutType) - paramTypes.add(dOutType); - - return m_astBuilder->getOrCreate(paramTypes.getArrayView(), resultType, errorType); } - struct HigherOrderInvokeExprCheckingActions + // Last parameter is the initial derivative of the original return type + auto dOutType = tryGetDifferentialType(m_astBuilder, originalType->getResultType()); + if (dOutType) + paramTypes.add(dOutType); + + return m_astBuilder->getOrCreate(paramTypes.getArrayView(), resultType, errorType); +} + +struct HigherOrderInvokeExprCheckingActions +{ + virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) = 0; + virtual void fillHigherOrderInvokeExpr( + HigherOrderInvokeExpr* resultDiffExpr, + SemanticsVisitor* semantics, + Expr* funcExpr) = 0; + FuncType* getBaseFunctionType(SemanticsVisitor* semantics, Expr* funcExpr) { - virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) = 0; - virtual void fillHigherOrderInvokeExpr(HigherOrderInvokeExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) = 0; - FuncType* getBaseFunctionType(SemanticsVisitor* semantics, Expr* funcExpr) - { - if (auto funcType = as(funcExpr->type.type)) + if (auto funcType = as(funcExpr->type.type)) + return funcType; + auto astBuilder = semantics->getASTBuilder(); + if (auto declRefExpr = as(funcExpr)) + { + if (auto baseFuncGenericDeclRef = declRefExpr->declRef.as()) + { + // Get inner function + DeclRef unspecializedInnerRef = createDefaultSubstitutionsIfNeeded( + astBuilder, + semantics, + astBuilder->getMemberDeclRef( + baseFuncGenericDeclRef, + getInner(baseFuncGenericDeclRef))); + auto callableDeclRef = unspecializedInnerRef.as(); + if (!callableDeclRef) + return nullptr; + auto funcType = getFuncType(astBuilder, callableDeclRef); return funcType; - auto astBuilder = semantics->getASTBuilder(); - if (auto declRefExpr = as(funcExpr)) - { - if (auto baseFuncGenericDeclRef = declRefExpr->declRef.as()) - { - // Get inner function - DeclRef unspecializedInnerRef = createDefaultSubstitutionsIfNeeded(astBuilder, semantics, - astBuilder->getMemberDeclRef(baseFuncGenericDeclRef, getInner(baseFuncGenericDeclRef))); - auto callableDeclRef = unspecializedInnerRef.as(); - if (!callableDeclRef) - return nullptr; - auto funcType = getFuncType(astBuilder, callableDeclRef); - return funcType; - } } - return nullptr; } - }; + return nullptr; + } +}; - struct ForwardDifferentiateExprCheckingActions : HigherOrderInvokeExprCheckingActions +struct ForwardDifferentiateExprCheckingActions : HigherOrderInvokeExprCheckingActions +{ + virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) override { - virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) override + return semantics->getASTBuilder()->create(); + } + void fillHigherOrderInvokeExpr( + HigherOrderInvokeExpr* resultDiffExpr, + SemanticsVisitor* semantics, + Expr* funcExpr) override + { + resultDiffExpr->baseFunction = funcExpr; + auto baseFuncType = getBaseFunctionType(semantics, funcExpr); + if (!baseFuncType) { - return semantics->getASTBuilder()->create(); + resultDiffExpr->type = semantics->getASTBuilder()->getErrorType(); + semantics->getSink()->diagnose( + funcExpr, + Diagnostics::expectedFunction, + funcExpr->type.type); + return; } - void fillHigherOrderInvokeExpr(HigherOrderInvokeExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override + resultDiffExpr->type = semantics->getForwardDiffFuncType(baseFuncType); + if (auto declRefExpr = as(getInnerMostExprFromHigherOrderExpr(funcExpr))) { - resultDiffExpr->baseFunction = funcExpr; - auto baseFuncType = getBaseFunctionType(semantics, funcExpr); - if (!baseFuncType) + auto funcDecl = declRefExpr->declRef.as().getDecl(); + if (auto genDecl = as(declRefExpr->declRef.getDecl())) { - resultDiffExpr->type = semantics->getASTBuilder()->getErrorType(); - semantics->getSink()->diagnose(funcExpr, Diagnostics::expectedFunction, funcExpr->type.type); - return; + funcDecl = as(genDecl->inner); } - resultDiffExpr->type = semantics->getForwardDiffFuncType(baseFuncType); - if (auto declRefExpr = as(getInnerMostExprFromHigherOrderExpr(funcExpr))) + if (funcDecl) { - auto funcDecl = declRefExpr->declRef.as().getDecl(); - if (auto genDecl = as(declRefExpr->declRef.getDecl())) - { - funcDecl = as(genDecl->inner); - } - if (funcDecl) + for (auto param : funcDecl->getParameters()) { - for (auto param : funcDecl->getParameters()) - { - resultDiffExpr->newParameterNames.add(param->getName()); - } + resultDiffExpr->newParameterNames.add(param->getName()); } } } - }; + } +}; - struct BackwardDifferentiateExprCheckingActions : HigherOrderInvokeExprCheckingActions +struct BackwardDifferentiateExprCheckingActions : HigherOrderInvokeExprCheckingActions +{ + virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) override { - virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) override + return semantics->getASTBuilder()->create(); + } + void fillHigherOrderInvokeExpr( + HigherOrderInvokeExpr* resultDiffExpr, + SemanticsVisitor* semantics, + Expr* funcExpr) override + { + resultDiffExpr->baseFunction = funcExpr; + auto baseFuncType = getBaseFunctionType(semantics, funcExpr); + if (!baseFuncType) { - return semantics->getASTBuilder()->create(); + resultDiffExpr->type = semantics->getASTBuilder()->getErrorType(); + semantics->getSink()->diagnose( + funcExpr, + Diagnostics::expectedFunction, + funcExpr->type.type); + return; } - void fillHigherOrderInvokeExpr(HigherOrderInvokeExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override + resultDiffExpr->type = semantics->getBackwardDiffFuncType(baseFuncType); + if (auto declRefExpr = as(getInnerMostExprFromHigherOrderExpr(funcExpr))) { - resultDiffExpr->baseFunction = funcExpr; - auto baseFuncType = getBaseFunctionType(semantics, funcExpr); - if (!baseFuncType) + auto funcDecl = declRefExpr->declRef.as().getDecl(); + if (auto genDecl = as(declRefExpr->declRef.getDecl())) { - resultDiffExpr->type = semantics->getASTBuilder()->getErrorType(); - semantics->getSink()->diagnose(funcExpr, Diagnostics::expectedFunction, funcExpr->type.type); - return; + funcDecl = as(genDecl->inner); } - resultDiffExpr->type = semantics->getBackwardDiffFuncType(baseFuncType); - if (auto declRefExpr = as(getInnerMostExprFromHigherOrderExpr(funcExpr))) + if (funcDecl) { - auto funcDecl = declRefExpr->declRef.as().getDecl(); - if (auto genDecl = as(declRefExpr->declRef.getDecl())) - { - funcDecl = as(genDecl->inner); - } - if (funcDecl) + for (auto param : funcDecl->getParameters()) { - for (auto param : funcDecl->getParameters()) + if (param->findModifier()) { - if (param->findModifier()) - { - if (param->findModifier() && - !param->findModifier() && - !param->findModifier()) - continue; - } - resultDiffExpr->newParameterNames.add(param->getName()); + if (param->findModifier() && + !param->findModifier() && + !param->findModifier()) + continue; } - resultDiffExpr->newParameterNames.add(semantics->getName("resultGradient")); + resultDiffExpr->newParameterNames.add(param->getName()); } + resultDiffExpr->newParameterNames.add(semantics->getName("resultGradient")); } } - }; + } +}; - template - struct PassthroughHighOrderExprCheckingActionsBase : HigherOrderInvokeExprCheckingActions +template +struct PassthroughHighOrderExprCheckingActionsBase : HigherOrderInvokeExprCheckingActions +{ + virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) override + { + return semantics->getASTBuilder()->create(); + } + void fillHigherOrderInvokeExpr( + HigherOrderInvokeExpr* resultDiffExpr, + SemanticsVisitor* semantics, + Expr* funcExpr) override { - virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) override + resultDiffExpr->baseFunction = funcExpr; + auto baseFuncType = getBaseFunctionType(semantics, funcExpr); + if (!baseFuncType) { - return semantics->getASTBuilder()->create(); + resultDiffExpr->type = semantics->getASTBuilder()->getErrorType(); + semantics->getSink()->diagnose( + funcExpr, + Diagnostics::expectedFunction, + funcExpr->type.type); + return; } - void fillHigherOrderInvokeExpr(HigherOrderInvokeExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override + resultDiffExpr->type = baseFuncType; + if (auto declRefExpr = as(getInnerMostExprFromHigherOrderExpr(funcExpr))) { - resultDiffExpr->baseFunction = funcExpr; - auto baseFuncType = getBaseFunctionType(semantics, funcExpr); - if (!baseFuncType) + auto funcDecl = declRefExpr->declRef.as().getDecl(); + if (auto genDecl = as(declRefExpr->declRef.getDecl())) { - resultDiffExpr->type = semantics->getASTBuilder()->getErrorType(); - semantics->getSink()->diagnose(funcExpr, Diagnostics::expectedFunction, funcExpr->type.type); - return; + funcDecl = as(genDecl->inner); } - resultDiffExpr->type = baseFuncType; - if (auto declRefExpr = as(getInnerMostExprFromHigherOrderExpr(funcExpr))) + if (funcDecl) { - auto funcDecl = declRefExpr->declRef.as().getDecl(); - if (auto genDecl = as(declRefExpr->declRef.getDecl())) - { - funcDecl = as(genDecl->inner); - } - if (funcDecl) + for (auto param : funcDecl->getParameters()) { - for (auto param : funcDecl->getParameters()) - { - resultDiffExpr->newParameterNames.add(param->getName()); - } + resultDiffExpr->newParameterNames.add(param->getName()); } } } - }; + } +}; - static Expr* _checkHigherOrderInvokeExpr( - SemanticsVisitor* semantics, - HigherOrderInvokeExpr* expr, - HigherOrderInvokeExprCheckingActions* actions) - { - // Check/Resolve inner function declaration. - expr->baseFunction = semantics->CheckTerm(expr->baseFunction); +static Expr* _checkHigherOrderInvokeExpr( + SemanticsVisitor* semantics, + HigherOrderInvokeExpr* expr, + HigherOrderInvokeExprCheckingActions* actions) +{ + // Check/Resolve inner function declaration. + expr->baseFunction = semantics->CheckTerm(expr->baseFunction); - auto astBuilder = semantics->getASTBuilder(); + auto astBuilder = semantics->getASTBuilder(); - // If base is overloaded expr, we want to return an overloaded expr as check result. - // This is done by pushing the `differentiate` operator to each item in the overloaded expr. - if (auto overloadedExpr = as(expr->baseFunction)) + // If base is overloaded expr, we want to return an overloaded expr as check result. + // This is done by pushing the `differentiate` operator to each item in the overloaded expr. + if (auto overloadedExpr = as(expr->baseFunction)) + { + OverloadedExpr2* result = astBuilder->create(); + for (auto item : overloadedExpr->lookupResult2) { - OverloadedExpr2* result = astBuilder->create(); - for (auto item : overloadedExpr->lookupResult2) - { - auto lookupResultExpr = semantics->ConstructLookupResultExpr(item, - nullptr, - overloadedExpr->name, - overloadedExpr->loc, - nullptr); - auto candidateExpr = actions->createHigherOrderInvokeExpr(semantics); - actions->fillHigherOrderInvokeExpr(candidateExpr, semantics, lookupResultExpr); - candidateExpr->loc = expr->loc; - result->candidiateExprs.add(candidateExpr); - } - result->type.type = astBuilder->getOverloadedType(); - result->loc = expr->loc; - return result; - } - else if (auto overloadedExpr2 = as(expr->baseFunction)) + auto lookupResultExpr = semantics->ConstructLookupResultExpr( + item, + nullptr, + overloadedExpr->name, + overloadedExpr->loc, + nullptr); + auto candidateExpr = actions->createHigherOrderInvokeExpr(semantics); + actions->fillHigherOrderInvokeExpr(candidateExpr, semantics, lookupResultExpr); + candidateExpr->loc = expr->loc; + result->candidiateExprs.add(candidateExpr); + } + result->type.type = astBuilder->getOverloadedType(); + result->loc = expr->loc; + return result; + } + else if (auto overloadedExpr2 = as(expr->baseFunction)) + { + OverloadedExpr2* result = astBuilder->create(); + for (auto item : overloadedExpr2->candidiateExprs) { - OverloadedExpr2* result = astBuilder->create(); - for (auto item : overloadedExpr2->candidiateExprs) - { - auto candidateExpr = actions->createHigherOrderInvokeExpr(semantics); - actions->fillHigherOrderInvokeExpr(candidateExpr, semantics, item); - candidateExpr->loc = expr->loc; - result->candidiateExprs.add(candidateExpr); - } - result->type.type = astBuilder->getOverloadedType(); - result->loc = expr->loc; - return result; + auto candidateExpr = actions->createHigherOrderInvokeExpr(semantics); + actions->fillHigherOrderInvokeExpr(candidateExpr, semantics, item); + candidateExpr->loc = expr->loc; + result->candidiateExprs.add(candidateExpr); } - - actions->fillHigherOrderInvokeExpr(expr, semantics, expr->baseFunction); - return expr; + result->type.type = astBuilder->getOverloadedType(); + result->loc = expr->loc; + return result; } - Expr* SemanticsExprVisitor::visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr) + actions->fillHigherOrderInvokeExpr(expr, semantics, expr->baseFunction); + return expr; +} + +Expr* SemanticsExprVisitor::visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr) +{ + ForwardDifferentiateExprCheckingActions actions; + return _checkHigherOrderInvokeExpr(this, expr, &actions); +} + +Expr* SemanticsExprVisitor::visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr) +{ + BackwardDifferentiateExprCheckingActions actions; + return _checkHigherOrderInvokeExpr(this, expr, &actions); +} + +Expr* SemanticsExprVisitor::visitPrimalSubstituteExpr(PrimalSubstituteExpr* expr) +{ + PassthroughHighOrderExprCheckingActionsBase actions; + return _checkHigherOrderInvokeExpr(this, expr, &actions); +} + +Expr* SemanticsExprVisitor::visitDispatchKernelExpr(DispatchKernelExpr* expr) +{ + auto isInt3Type = [this](Type* type) + { + auto vectorType = as(type); + if (!vectorType) + return false; + if (!isIntegerBaseType(getVectorBaseType(vectorType))) + return false; + auto constElementCount = as(vectorType->getElementCount()); + if (!constElementCount) + return false; + return constElementCount->getValue() == 3; + }; + expr->threadGroupSize = dispatchExpr(expr->threadGroupSize, *this); + if (!isInt3Type(expr->threadGroupSize->type.type)) { - ForwardDifferentiateExprCheckingActions actions; - return _checkHigherOrderInvokeExpr(this, expr, &actions); + getSink()->diagnose( + expr->threadGroupSize, + Diagnostics::typeMismatch, + "uint3", + expr->threadGroupSize->type); } - - Expr* SemanticsExprVisitor::visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr) + expr->dispatchSize = dispatchExpr(expr->dispatchSize, *this); + if (!isInt3Type(expr->dispatchSize->type.type)) { - BackwardDifferentiateExprCheckingActions actions; - return _checkHigherOrderInvokeExpr(this, expr, &actions); + getSink()->diagnose( + expr->dispatchSize, + Diagnostics::typeMismatch, + "uint3", + expr->dispatchSize->type); } + PassthroughHighOrderExprCheckingActionsBase actions; + return _checkHigherOrderInvokeExpr(this, expr, &actions); +} - Expr* SemanticsExprVisitor::visitPrimalSubstituteExpr(PrimalSubstituteExpr* expr) +Expr* SemanticsExprVisitor::visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr) +{ + auto subContext = withTreatAsDifferentiable(expr); + expr->innerExpr = dispatchExpr(expr->innerExpr, subContext); + expr->type = expr->innerExpr->type; + auto innerExpr = expr->innerExpr; + while (auto parenExpr = as(innerExpr)) { - PassthroughHighOrderExprCheckingActionsBase actions; - return _checkHigherOrderInvokeExpr(this, expr, &actions); + innerExpr = parenExpr->base; } - - Expr* SemanticsExprVisitor::visitDispatchKernelExpr(DispatchKernelExpr* expr) + if (!as(innerExpr) && !as(innerExpr)) { - auto isInt3Type = [this](Type* type) - { - auto vectorType = as(type); - if (!vectorType) - return false; - if (!isIntegerBaseType(getVectorBaseType(vectorType))) - return false; - auto constElementCount = as(vectorType->getElementCount()); - if (!constElementCount) - return false; - return constElementCount->getValue() == 3; - }; - expr->threadGroupSize = dispatchExpr(expr->threadGroupSize, *this); - if (!isInt3Type(expr->threadGroupSize->type.type)) - { - getSink()->diagnose( - expr->threadGroupSize, - Diagnostics::typeMismatch, - "uint3", - expr->threadGroupSize->type); - } - expr->dispatchSize = dispatchExpr(expr->dispatchSize, *this); - if (!isInt3Type(expr->dispatchSize->type.type)) - { - getSink()->diagnose( - expr->dispatchSize, - Diagnostics::typeMismatch, - "uint3", - expr->dispatchSize->type); - } - PassthroughHighOrderExprCheckingActionsBase actions; - return _checkHigherOrderInvokeExpr(this, expr, &actions); + getSink()->diagnose(expr, Diagnostics::invalidUseOfNoDiff); } - - Expr* SemanticsExprVisitor::visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr) + else if (!m_parentDifferentiableAttr) { - auto subContext = withTreatAsDifferentiable(expr); - expr->innerExpr = dispatchExpr(expr->innerExpr, subContext); - expr->type = expr->innerExpr->type; - auto innerExpr = expr->innerExpr; - while (auto parenExpr = as(innerExpr)) - { - innerExpr = parenExpr->base; - } - if (!as(innerExpr) && !as(innerExpr)) - { - getSink()->diagnose(expr, Diagnostics::invalidUseOfNoDiff); - } - else if (!m_parentDifferentiableAttr) - { - getSink()->diagnose(expr, Diagnostics::cannotUseNoDiffInNonDifferentiableFunc); - } - return expr; + getSink()->diagnose(expr, Diagnostics::cannotUseNoDiffInNonDifferentiableFunc); } + return expr; +} - Expr* SemanticsExprVisitor::visitGetArrayLengthExpr(GetArrayLengthExpr* expr) +Expr* SemanticsExprVisitor::visitGetArrayLengthExpr(GetArrayLengthExpr* expr) +{ + expr->arrayExpr = CheckTerm(expr->arrayExpr); + if (auto arrType = as(expr->arrayExpr->type)) { - expr->arrayExpr = CheckTerm(expr->arrayExpr); - if (auto arrType = as(expr->arrayExpr->type)) + expr->type = m_astBuilder->getIntType(); + if (arrType->isUnsized()) { - expr->type = m_astBuilder->getIntType(); - if (arrType->isUnsized()) - { - getSink()->diagnose(expr, Diagnostics::invalidArraySize); - } + getSink()->diagnose(expr, Diagnostics::invalidArraySize); } - else + } + else + { + if (!as(expr->arrayExpr->type)) { - if (!as(expr->arrayExpr->type)) - { - getSink()->diagnose( - expr, Diagnostics::typeMismatch, "array", expr->arrayExpr->type); - } - expr->type = m_astBuilder->getErrorType(); + getSink()->diagnose(expr, Diagnostics::typeMismatch, "array", expr->arrayExpr->type); } - return expr; + expr->type = m_astBuilder->getErrorType(); } + return expr; +} + +Expr* SemanticsExprVisitor::visitDefaultConstructExpr(DefaultConstructExpr* expr) +{ + return expr; +} + +Expr* SemanticsExprVisitor::visitDetachExpr(DetachExpr* expr) +{ + expr->inner = CheckTerm(expr->inner); + expr->type = expr->inner->type; + return expr; +} - Expr* SemanticsExprVisitor::visitDefaultConstructExpr(DefaultConstructExpr* expr) + +static bool _isSizeOfType(Type* type) +{ + if (!type) { - return expr; + return false; } - Expr* SemanticsExprVisitor::visitDetachExpr(DetachExpr* expr) + if (as(type) || as(type) || + as(type) || as(type) || as(type)) { - expr->inner = CheckTerm(expr->inner); - expr->type = expr->inner->type; - return expr; + return true; } - - static bool _isSizeOfType(Type* type) + if (as(type)) { - if (!type) - { - return false; - } + return true; + } - if (as(type) || - as(type) || - as(type) || - as(type) || - as(type)) - { - return true; - } + return false; +} - if (as(type)) - { - return true; - } - +static bool _isCountOfType(Type* type) +{ + if (!type) + { return false; } - static bool _isCountOfType(Type* type) + if (isTypePack(type)) { - if (!type) - { - return false; - } + return true; + } - if (isTypePack(type)) - { - return true; - } + if (as(type)) + { + return true; + } - if (as(type)) - { - return true; - } + if (as(type)) + { + return true; + } - if (as(type)) - { - return true; - } + return false; +} - return false; - } +Expr* SemanticsExprVisitor::visitSizeOfLikeExpr(SizeOfLikeExpr* sizeOfLikeExpr) +{ + auto valueExpr = dispatch(sizeOfLikeExpr->value); + sizeOfLikeExpr->type = m_astBuilder->getIntType(); - Expr* SemanticsExprVisitor::visitSizeOfLikeExpr(SizeOfLikeExpr* sizeOfLikeExpr) - { - auto valueExpr = dispatch(sizeOfLikeExpr->value); - sizeOfLikeExpr->type = m_astBuilder->getIntType(); - - Type* type = nullptr; + Type* type = nullptr; - if (as(valueExpr->type)) - { - TypeExp typeExp; - typeExp.exp = valueExpr; + if (as(valueExpr->type)) + { + TypeExp typeExp; + typeExp.exp = valueExpr; - auto properTypeExpr = CoerceToProperType(typeExp); + auto properTypeExpr = CoerceToProperType(typeExp); - type = properTypeExpr.type; - } - else - { - // Is this a proper type? - TypeExp typeExp(valueExpr->type); - TypeExp properType = tryCoerceToProperType(typeExp); + type = properTypeExpr.type; + } + else + { + // Is this a proper type? + TypeExp typeExp(valueExpr->type); + TypeExp properType = tryCoerceToProperType(typeExp); - type = properType.type; - } + type = properType.type; + } - if (as(sizeOfLikeExpr)) + if (as(sizeOfLikeExpr)) + { + if (!_isCountOfType(type)) { - if (!_isCountOfType(type)) - { - getSink()->diagnose(sizeOfLikeExpr, Diagnostics::countOfArgumentIsInvalid); + getSink()->diagnose(sizeOfLikeExpr, Diagnostics::countOfArgumentIsInvalid); - sizeOfLikeExpr->type = m_astBuilder->getErrorType(); - return sizeOfLikeExpr; - } + sizeOfLikeExpr->type = m_astBuilder->getErrorType(); + return sizeOfLikeExpr; } - else + } + else + { + if (!_isSizeOfType(type)) { - if (!_isSizeOfType(type)) - { - getSink()->diagnose(sizeOfLikeExpr, Diagnostics::sizeOfArgumentIsInvalid); + getSink()->diagnose(sizeOfLikeExpr, Diagnostics::sizeOfArgumentIsInvalid); - sizeOfLikeExpr->type = m_astBuilder->getErrorType(); - return sizeOfLikeExpr; - } + sizeOfLikeExpr->type = m_astBuilder->getErrorType(); + return sizeOfLikeExpr; } + } - sizeOfLikeExpr->sizedType = type; + sizeOfLikeExpr->sizedType = type; - return sizeOfLikeExpr; - } + return sizeOfLikeExpr; +} - Expr* SemanticsExprVisitor::visitBuiltinCastExpr(BuiltinCastExpr* expr) - { - // All builtin cast exprs should already be checked. - return expr; - } +Expr* SemanticsExprVisitor::visitBuiltinCastExpr(BuiltinCastExpr* expr) +{ + // All builtin cast exprs should already be checked. + return expr; +} - Expr* SemanticsExprVisitor::visitTypeCastExpr(TypeCastExpr * expr) - { - if (expr->type) - return expr; +Expr* SemanticsExprVisitor::visitTypeCastExpr(TypeCastExpr* expr) +{ + if (expr->type) + return expr; - // Check the term we are applying first - auto funcExpr = expr->functionExpr; - funcExpr = CheckTerm(funcExpr); + // Check the term we are applying first + auto funcExpr = expr->functionExpr; + funcExpr = CheckTerm(funcExpr); - // Now ensure that the term represents a (proper) type. - TypeExp typeExp; - typeExp.exp = funcExpr; - typeExp = CheckProperType(typeExp); + // Now ensure that the term represents a (proper) type. + TypeExp typeExp; + typeExp.exp = funcExpr; + typeExp = CheckProperType(typeExp); - expr->functionExpr = typeExp.exp; - expr->type.type = typeExp.type; + expr->functionExpr = typeExp.exp; + expr->type.type = typeExp.type; - // Next check the argument expression (there should be only one) - for (auto & arg : expr->arguments) - { - arg = CheckTerm(arg); - } + // Next check the argument expression (there should be only one) + for (auto& arg : expr->arguments) + { + arg = CheckTerm(arg); + } - // LEGACY FEATURE: As a backwards-compatibility feature - // for HLSL, we will allow for a cast to a `struct` type - // from a literal zero, with the semantics of default - // initialization. - // - if( auto declRefType = as(typeExp.type) ) + // LEGACY FEATURE: As a backwards-compatibility feature + // for HLSL, we will allow for a cast to a `struct` type + // from a literal zero, with the semantics of default + // initialization. + // + if (auto declRefType = as(typeExp.type)) + { + if (const auto structDeclRef = as(declRefType->getDeclRef())) { - if(const auto structDeclRef = as(declRefType->getDeclRef())) + if (expr->arguments.getCount() == 1) { - if( expr->arguments.getCount() == 1 ) + auto arg = expr->arguments[0]; + if (auto intLitArg = as(arg)) { - auto arg = expr->arguments[0]; - if( auto intLitArg = as(arg) ) + if (getIntegerLiteralValue(intLitArg->token) == 0) { - if(getIntegerLiteralValue(intLitArg->token) == 0) - { - // At this point we have confirmed that the cast - // has the right form, so we want to apply our special case. - // - // TODO: If/when we allow for user-defined initializer/constructor - // definitions we would have to be careful here because it is - // possible that the target type has defined an initializer/constructor - // that takes a single `int` parmaeter and means to call that instead. - // - // For now that should be a non-issue, and in a pinch such a user - // could use `T(0)` instead of `(T) 0` to get around this special - // HLSL legacy feature. + // At this point we have confirmed that the cast + // has the right form, so we want to apply our special case. + // + // TODO: If/when we allow for user-defined initializer/constructor + // definitions we would have to be careful here because it is + // possible that the target type has defined an initializer/constructor + // that takes a single `int` parmaeter and means to call that instead. + // + // For now that should be a non-issue, and in a pinch such a user + // could use `T(0)` instead of `(T) 0` to get around this special + // HLSL legacy feature. - // We will type-check code like: - // - // MyStruct s = (MyStruct) 0; - // - // the same as: - // - // MyStruct s = {}; - // - // That is, we construct an empty initializer list, and then coerce - // that initializer list expression to the desired type (letting - // the code for handling initializer lists work out all of the - // details of what is/isn't valid). This choice means we get - // to benefit from the existing codegen support for initializer - // lists, rather than needing the `(MyStruct) 0` idiom to be - // special-cased in later stages of the compiler. - // - // Note: we use an empty initializer list `{}` instead of an - // initializer list with a single zero `{0}`, which is semantically - // significant if the first field of `MyStruct` had its own - // default initializer defined as part of the `struct` definition. - // Basically we have chosen to interpret the "cast from zero" syntax - // as sugar for default initialization, and *not* specifically - // for zero-initialization. That choice could be revisited if - // users express displeasure. For now there isn't enough usage - // of explicit default initializers for `struct` fields to - // make this a major concern (since they aren't supported in HLSL). - // - InitializerListExpr* initListExpr = m_astBuilder->create(); - initListExpr->loc = expr->loc; - auto checkedInitListExpr = visitInitializerListExpr(initListExpr); + // We will type-check code like: + // + // MyStruct s = (MyStruct) 0; + // + // the same as: + // + // MyStruct s = {}; + // + // That is, we construct an empty initializer list, and then coerce + // that initializer list expression to the desired type (letting + // the code for handling initializer lists work out all of the + // details of what is/isn't valid). This choice means we get + // to benefit from the existing codegen support for initializer + // lists, rather than needing the `(MyStruct) 0` idiom to be + // special-cased in later stages of the compiler. + // + // Note: we use an empty initializer list `{}` instead of an + // initializer list with a single zero `{0}`, which is semantically + // significant if the first field of `MyStruct` had its own + // default initializer defined as part of the `struct` definition. + // Basically we have chosen to interpret the "cast from zero" syntax + // as sugar for default initialization, and *not* specifically + // for zero-initialization. That choice could be revisited if + // users express displeasure. For now there isn't enough usage + // of explicit default initializers for `struct` fields to + // make this a major concern (since they aren't supported in HLSL). + // + InitializerListExpr* initListExpr = + m_astBuilder->create(); + initListExpr->loc = expr->loc; + auto checkedInitListExpr = visitInitializerListExpr(initListExpr); - return coerce(CoercionSite::General, typeExp.type, checkedInitListExpr); - } + return coerce(CoercionSite::General, typeExp.type, checkedInitListExpr); } } } } + } - // Now process this like any other explicit call (so casts - // and constructor calls are semantically equivalent). - return CheckInvokeExprWithCheckedOperands(expr); - } + // Now process this like any other explicit call (so casts + // and constructor calls are semantically equivalent). + return CheckInvokeExprWithCheckedOperands(expr); +} - Expr* SemanticsExprVisitor::visitTryExpr(TryExpr* expr) +Expr* SemanticsExprVisitor::visitTryExpr(TryExpr* expr) +{ + auto prevTryClauseType = m_enclosingTryClauseType; + m_enclosingTryClauseType = expr->tryClauseType; + expr->base = CheckTerm(expr->base); + m_enclosingTryClauseType = prevTryClauseType; + expr->type = expr->base->type; + if (as(expr->type)) + return expr; + + auto parentFunc = this->m_parentFunc; + // TODO: check if the try clause is caught. + // For now we assume all `try`s are not caught (because we don't have catch yet). + if (!parentFunc) { - auto prevTryClauseType = m_enclosingTryClauseType; - m_enclosingTryClauseType = expr->tryClauseType; - expr->base = CheckTerm(expr->base); - m_enclosingTryClauseType = prevTryClauseType; - expr->type = expr->base->type; - if (as(expr->type)) - return expr; - - auto parentFunc = this->m_parentFunc; - // TODO: check if the try clause is caught. - // For now we assume all `try`s are not caught (because we don't have catch yet). - if (!parentFunc) - { - getSink()->diagnose(expr, Diagnostics::uncaughtTryCallInNonThrowFunc); - return expr; - } - if (parentFunc->errorType->equals(m_astBuilder->getBottomType())) - { - getSink()->diagnose(expr, Diagnostics::uncaughtTryCallInNonThrowFunc); - return expr; - } - if (!as(expr->base)) - { - getSink()->diagnose(expr, Diagnostics::tryClauseMustApplyToInvokeExpr); - return expr; - } - auto base = as(expr->base); - if (auto callee = as(base->functionExpr)) + getSink()->diagnose(expr, Diagnostics::uncaughtTryCallInNonThrowFunc); + return expr; + } + if (parentFunc->errorType->equals(m_astBuilder->getBottomType())) + { + getSink()->diagnose(expr, Diagnostics::uncaughtTryCallInNonThrowFunc); + return expr; + } + if (!as(expr->base)) + { + getSink()->diagnose(expr, Diagnostics::tryClauseMustApplyToInvokeExpr); + return expr; + } + auto base = as(expr->base); + if (auto callee = as(base->functionExpr)) + { + if (auto funcCallee = as(callee->declRef.getDecl())) { - if (auto funcCallee = as(callee->declRef.getDecl())) + if (funcCallee->errorType->equals(m_astBuilder->getBottomType())) { - if (funcCallee->errorType->equals(m_astBuilder->getBottomType())) - { - getSink()->diagnose(expr, Diagnostics::tryInvokeCalleeShouldThrow, callee->declRef); - } - if (!parentFunc->errorType->equals(funcCallee->errorType)) - { - getSink()->diagnose( - expr, - Diagnostics::errorTypeOfCalleeIncompatibleWithCaller, - callee->declRef, - funcCallee->errorType, - parentFunc->errorType); - } - return expr; + getSink()->diagnose(expr, Diagnostics::tryInvokeCalleeShouldThrow, callee->declRef); + } + if (!parentFunc->errorType->equals(funcCallee->errorType)) + { + getSink()->diagnose( + expr, + Diagnostics::errorTypeOfCalleeIncompatibleWithCaller, + callee->declRef, + funcCallee->errorType, + parentFunc->errorType); } + return expr; } - getSink()->diagnose(expr, Diagnostics::calleeOfTryCallMustBeFunc); - return expr; } + getSink()->diagnose(expr, Diagnostics::calleeOfTryCallMustBeFunc); + return expr; +} - Expr* SemanticsExprVisitor::visitIsTypeExpr(IsTypeExpr* expr) - { - expr->typeExpr = CheckProperType(expr->typeExpr); - auto originalVal = CheckTerm(expr->value); - expr->type = m_astBuilder->getBoolType(); - expr->value = originalVal; +Expr* SemanticsExprVisitor::visitIsTypeExpr(IsTypeExpr* expr) +{ + expr->typeExpr = CheckProperType(expr->typeExpr); + auto originalVal = CheckTerm(expr->value); + expr->type = m_astBuilder->getBoolType(); + expr->value = originalVal; + + auto valueType = expr->value->type.type; + if (auto typeType = as(valueType)) + valueType = typeType->getType(); + + // If value is a subtype of `type`, then this expr is always true. + if (isSubtype(valueType, expr->typeExpr.type, IsSubTypeOptions::None)) + { + // Instead of returning a BoolLiteralExpr, we use a field to indicate this scenario, + // so that the language server can still see the original syntax tree. + expr->constantVal = m_astBuilder->create(); + expr->constantVal->type = m_astBuilder->getBoolType(); + expr->constantVal->value = true; + expr->constantVal->loc = expr->loc; + return expr; + } - auto valueType = expr->value->type.type; - if (auto typeType = as(valueType)) - valueType = typeType->getType(); + // Otherwise, if the target type is a subtype of value->type, we need to grab the + // subtype witness for runtime checks. - // If value is a subtype of `type`, then this expr is always true. - if(isSubtype(valueType, expr->typeExpr.type, IsSubTypeOptions::None)) + expr->value = maybeOpenExistential(originalVal); + expr->witnessArg = tryGetSubtypeWitness(expr->typeExpr.type, valueType); + if (expr->witnessArg) + { + // For now we can only support the scenario where `expr->value` is an interface type. + if (!isInterfaceType(originalVal->type)) { - // Instead of returning a BoolLiteralExpr, we use a field to indicate this scenario, - // so that the language server can still see the original syntax tree. - expr->constantVal = m_astBuilder->create(); - expr->constantVal->type = m_astBuilder->getBoolType(); - expr->constantVal->value = true; - expr->constantVal->loc = expr->loc; - return expr; + getSink()->diagnose(expr, Diagnostics::isOperatorValueMustBeInterfaceType); } + return expr; + } + return expr; +} - // Otherwise, if the target type is a subtype of value->type, we need to grab the - // subtype witness for runtime checks. +Expr* SemanticsExprVisitor::visitAsTypeExpr(AsTypeExpr* expr) +{ + TypeExp typeExpr; + typeExpr.exp = expr->typeExpr; + typeExpr = CheckProperType(typeExpr); + expr->value = CheckTerm(expr->value); + auto optType = m_astBuilder->getOptionalType(typeExpr.type); + expr->type = optType; + + // If value is a subtype of `type`, then this expr is equivalent to a CastToSuperTypeExpr. + if (auto witness = tryGetSubtypeWitness(expr->value->type.type, typeExpr.type)) + { + auto castToSuperType = createCastToSuperTypeExpr(typeExpr.type, expr->value, witness); + auto makeOptional = m_astBuilder->create(); + makeOptional->loc = expr->loc; + makeOptional->type = optType; + makeOptional->value = castToSuperType; + makeOptional->typeExpr = typeExpr.exp; + return makeOptional; + } - expr->value = maybeOpenExistential(originalVal); - expr->witnessArg = tryGetSubtypeWitness(expr->typeExpr.type, valueType); - if (expr->witnessArg) + // If target type is an interface type, we will obtain the witness here for + // runtime casting. + expr->witnessArg = tryGetSubtypeWitness(typeExpr.type, expr->value->type.type); + if (expr->witnessArg) + { + // For now we can only support the scenario where `expr->value` is an interface type. + if (!isInterfaceType(expr->value->type.type)) { - // For now we can only support the scenario where `expr->value` is an interface type. - if (!isInterfaceType(originalVal->type)) - { - getSink()->diagnose(expr, Diagnostics::isOperatorValueMustBeInterfaceType); - } - return expr; + getSink()->diagnose(expr, Diagnostics::isOperatorValueMustBeInterfaceType); } + expr->value = maybeOpenExistential(expr->value); return expr; } - Expr* SemanticsExprVisitor::visitAsTypeExpr(AsTypeExpr* expr) - { - TypeExp typeExpr; - typeExpr.exp = expr->typeExpr; - typeExpr = CheckProperType(typeExpr); - expr->value = CheckTerm(expr->value); - auto optType = m_astBuilder->getOptionalType(typeExpr.type); - expr->type = optType; + expr->typeExpr = typeExpr.exp; + return expr; +} - // If value is a subtype of `type`, then this expr is equivalent to a CastToSuperTypeExpr. - if (auto witness = tryGetSubtypeWitness(expr->value->type.type, typeExpr.type)) - { - auto castToSuperType = createCastToSuperTypeExpr(typeExpr.type, expr->value, witness); - auto makeOptional = m_astBuilder->create(); - makeOptional->loc = expr->loc; - makeOptional->type = optType; - makeOptional->value = castToSuperType; - makeOptional->typeExpr = typeExpr.exp; - return makeOptional; - } - // If target type is an interface type, we will obtain the witness here for - // runtime casting. - expr->witnessArg = tryGetSubtypeWitness(typeExpr.type, expr->value->type.type); - if (expr->witnessArg) - { - // For now we can only support the scenario where `expr->value` is an interface type. - if (!isInterfaceType(expr->value->type.type)) - { - getSink()->diagnose(expr, Diagnostics::isOperatorValueMustBeInterfaceType); - } - expr->value = maybeOpenExistential(expr->value); - return expr; - } +Expr* SemanticsExprVisitor::visitExpandExpr(ExpandExpr* expr) +{ + OrderedHashSet capturedTypePackSet; + auto subContext = this->withParentExpandExpr(expr, &capturedTypePackSet); + expr->baseExpr = dispatchExpr(expr->baseExpr, subContext); - expr->typeExpr = typeExpr.exp; + Type* patternType = nullptr; + bool isTypeExpr = false; + if (auto typeType = as(expr->baseExpr->type)) + { + patternType = typeType->getType(); + isTypeExpr = true; + } + else + { + patternType = expr->baseExpr->type; + } + if (as(patternType)) + { + expr->type = m_astBuilder->getErrorType(); return expr; } - - - Expr* SemanticsExprVisitor::visitExpandExpr(ExpandExpr* expr) + if (subContext.getCapturedTypePacks()->getCount() == 0) + { + getSink()->diagnose(expr, Diagnostics::expandTermCapturesNoTypePacks); + } + List capturedTypePacks; + for (auto capturedType : capturedTypePackSet) { - OrderedHashSet capturedTypePackSet; - auto subContext = this->withParentExpandExpr(expr, &capturedTypePackSet); - expr->baseExpr = dispatchExpr(expr->baseExpr, subContext); + capturedTypePacks.add(capturedType); + } + auto expandType = m_astBuilder->getExpandType(patternType, capturedTypePacks.getArrayView()); + if (isTypeExpr) + expr->type = m_astBuilder->getTypeType(expandType); + else + expr->type = QualType(expandType); + return expr; +} - Type* patternType = nullptr; - bool isTypeExpr = false; - if (auto typeType = as(expr->baseExpr->type)) - { - patternType = typeType->getType(); - isTypeExpr = true; - } - else - { - patternType = expr->baseExpr->type; - } - if (as(patternType)) - { - expr->type = m_astBuilder->getErrorType(); - return expr; - } - if (subContext.getCapturedTypePacks()->getCount() == 0) - { - getSink()->diagnose(expr, Diagnostics::expandTermCapturesNoTypePacks); - } - List capturedTypePacks; - for (auto capturedType : capturedTypePackSet) - { - capturedTypePacks.add(capturedType); - } - auto expandType = m_astBuilder->getExpandType(patternType, capturedTypePacks.getArrayView()); - if (isTypeExpr) - expr->type = m_astBuilder->getTypeType(expandType); - else - expr->type = QualType(expandType); +Expr* SemanticsExprVisitor::visitEachExpr(EachExpr* expr) +{ + if (!m_parentExpandExpr) + { + getSink()->diagnose(expr, Diagnostics::eachExprMustBeInsideExpandExpr); + expr->type = m_astBuilder->getErrorType(); return expr; } - Expr* SemanticsExprVisitor::visitEachExpr(EachExpr* expr) + expr->baseExpr = CheckTerm(expr->baseExpr); + bool isTypeNode = false; + Type* baseType = nullptr; + if (auto typeType = as(expr->baseExpr->type)) { - if (!m_parentExpandExpr) - { - getSink()->diagnose(expr, Diagnostics::eachExprMustBeInsideExpandExpr); - expr->type = m_astBuilder->getErrorType(); - return expr; - } - - expr->baseExpr = CheckTerm(expr->baseExpr); - bool isTypeNode = false; - Type* baseType = nullptr; - if (auto typeType = as(expr->baseExpr->type)) - { - isTypeNode = true; - baseType = typeType->getType(); - } - else - { - baseType = expr->baseExpr->type; - } - if (as(baseType)) - { - expr->type = m_astBuilder->getErrorType(); - return expr; - } - if (isTypeNode) + isTypeNode = true; + baseType = typeType->getType(); + } + else + { + baseType = expr->baseExpr->type; + } + if (as(baseType)) + { + expr->type = m_astBuilder->getErrorType(); + return expr; + } + if (isTypeNode) + { + auto declRefType = as(baseType); + if (!declRefType) { - auto declRefType = as(baseType); - if (!declRefType) - { - goto error; - } - if (!declRefType->getDeclRef().as()) - { - goto error; - } + goto error; } - else + if (!declRefType->getDeclRef().as()) { - if (!isTypePack(baseType) && !as(baseType)) - goto error; + goto error; } - - if (auto tupleType = as(baseType)) - baseType = tupleType->getTypePack(); + } + else + { + if (!isTypePack(baseType) && !as(baseType)) + goto error; + } + + if (auto tupleType = as(baseType)) + baseType = tupleType->getTypePack(); + { + SLANG_ASSERT(m_capturedTypePacks); + if (auto baseExpandType = as(baseType)) { - SLANG_ASSERT(m_capturedTypePacks); - if (auto baseExpandType = as(baseType)) - { - for (Index i = 0; i < baseExpandType->getCapturedTypePackCount(); i++) - { - auto capturedType = baseExpandType->getCapturedTypePack(i); - m_capturedTypePacks->add(capturedType); - } - } - else + for (Index i = 0; i < baseExpandType->getCapturedTypePackCount(); i++) { - m_capturedTypePacks->add(baseType); + auto capturedType = baseExpandType->getCapturedTypePack(i); + m_capturedTypePacks->add(capturedType); } - auto eachType = m_astBuilder->getEachType(baseType); - if (isTypeNode) - expr->type = m_astBuilder->getTypeType(eachType); - else - expr->type = QualType(eachType); - return expr; } - error:; - expr->type = m_astBuilder->getErrorType(); - if (!as(baseType)) + else { - getSink()->diagnose(expr, Diagnostics::expectTypePackAfterEach); + m_capturedTypePacks->add(baseType); } + auto eachType = m_astBuilder->getEachType(baseType); + if (isTypeNode) + expr->type = m_astBuilder->getTypeType(eachType); + else + expr->type = QualType(eachType); return expr; } - - void SemanticsExprVisitor::maybeCheckKnownBuiltinInvocation(Expr* invokeExpr) +error:; + expr->type = m_astBuilder->getErrorType(); + if (!as(baseType)) { - auto checkedInvokeExpr = as(invokeExpr); - if (!checkedInvokeExpr) - return; - auto declRefFuncExpr = as(checkedInvokeExpr->functionExpr); - if (!declRefFuncExpr) + getSink()->diagnose(expr, Diagnostics::expectTypePackAfterEach); + } + return expr; +} + +void SemanticsExprVisitor::maybeCheckKnownBuiltinInvocation(Expr* invokeExpr) +{ + auto checkedInvokeExpr = as(invokeExpr); + if (!checkedInvokeExpr) + return; + auto declRefFuncExpr = as(checkedInvokeExpr->functionExpr); + if (!declRefFuncExpr) + return; + auto callee = declRefFuncExpr->declRef.getDecl(); + if (!callee) + return; + auto knownBuiltinAttr = callee->findModifier(); + if (!knownBuiltinAttr) + return; + if (knownBuiltinAttr->name == "GetAttributeAtVertex") + { + if (checkedInvokeExpr->arguments.getCount() != 2) return; - auto callee = declRefFuncExpr->declRef.getDecl(); - if (!callee) + auto vertexAttributeArg = checkedInvokeExpr->arguments[0]; + auto vertexAttributeArgDeclRefExpr = as(vertexAttributeArg); + if (!vertexAttributeArgDeclRefExpr) + { + getSink()->diagnose( + invokeExpr, + Diagnostics::getAttributeAtVertexMustReferToPerVertexInput); return; - auto knownBuiltinAttr = callee->findModifier(); - if (!knownBuiltinAttr) + } + auto vertexAttributeArgDecl = vertexAttributeArgDeclRefExpr->declRef.getDecl(); + if (!vertexAttributeArgDecl) return; - if (knownBuiltinAttr->name == "GetAttributeAtVertex") + if (!vertexAttributeArgDecl->findModifier() && + !vertexAttributeArgDecl->findModifier()) { - if (checkedInvokeExpr->arguments.getCount() != 2) - return; - auto vertexAttributeArg = checkedInvokeExpr->arguments[0]; - auto vertexAttributeArgDeclRefExpr = as(vertexAttributeArg); - if (!vertexAttributeArgDeclRefExpr) - { - getSink()->diagnose(invokeExpr, Diagnostics::getAttributeAtVertexMustReferToPerVertexInput); - return; - } - auto vertexAttributeArgDecl = vertexAttributeArgDeclRefExpr->declRef.getDecl(); - if (!vertexAttributeArgDecl) - return; - if (!vertexAttributeArgDecl->findModifier() && - !vertexAttributeArgDecl->findModifier()) - { - getSink()->diagnose(vertexAttributeArgDeclRefExpr, Diagnostics::getAttributeAtVertexMustReferToPerVertexInput); - return; - } + getSink()->diagnose( + vertexAttributeArgDeclRefExpr, + Diagnostics::getAttributeAtVertexMustReferToPerVertexInput); + return; } } +} - Expr* SemanticsVisitor::MaybeDereference(Expr* inExpr) +Expr* SemanticsVisitor::MaybeDereference(Expr* inExpr) +{ + Expr* expr = inExpr; + for (;;) { - Expr* expr = inExpr; - for (;;) + auto baseType = expr->type; + if (auto pointerLikeType = as(baseType)) { - auto baseType = expr->type; - if (auto pointerLikeType = as(baseType)) - { - auto elementType = QualType(pointerLikeType->getElementType()); - elementType.isLeftValue = baseType.isLeftValue; - elementType.hasReadOnlyOnTarget = baseType.hasReadOnlyOnTarget; - elementType.isWriteOnly = baseType.isWriteOnly; + auto elementType = QualType(pointerLikeType->getElementType()); + elementType.isLeftValue = baseType.isLeftValue; + elementType.hasReadOnlyOnTarget = baseType.hasReadOnlyOnTarget; + elementType.isWriteOnly = baseType.isWriteOnly; - auto derefExpr = m_astBuilder->create(); - derefExpr->base = expr; - derefExpr->type = elementType; - - expr = derefExpr; - continue; - } + auto derefExpr = m_astBuilder->create(); + derefExpr->base = expr; + derefExpr->type = elementType; - // Default case: just use the expression as-is - return expr; + expr = derefExpr; + continue; } + + // Default case: just use the expression as-is + return expr; } +} - Expr* SemanticsVisitor::CheckMatrixSwizzleExpr( - MemberExpr* memberRefExpr, - Type* baseElementType, - IntegerLiteralValue baseElementRowCount, - IntegerLiteralValue baseElementColCount) - { - MatrixSwizzleExpr* swizExpr = m_astBuilder->create(); - swizExpr->loc = memberRefExpr->loc; - swizExpr->base = memberRefExpr->baseExpression; - swizExpr->memberOpLoc = memberRefExpr->memberOperatorLoc; +Expr* SemanticsVisitor::CheckMatrixSwizzleExpr( + MemberExpr* memberRefExpr, + Type* baseElementType, + IntegerLiteralValue baseElementRowCount, + IntegerLiteralValue baseElementColCount) +{ + MatrixSwizzleExpr* swizExpr = m_astBuilder->create(); + swizExpr->loc = memberRefExpr->loc; + swizExpr->base = memberRefExpr->baseExpression; + swizExpr->memberOpLoc = memberRefExpr->memberOperatorLoc; + + // We can have up to 4 swizzles of two elements each + MatrixCoord elementCoords[4]; + int elementCount = 0; - // We can have up to 4 swizzles of two elements each - MatrixCoord elementCoords[4]; - int elementCount = 0; + bool anyDuplicates = false; + int zeroIndexOffset = -1; - bool anyDuplicates = false; - int zeroIndexOffset = -1; + if (memberRefExpr->name == getSession()->getCompletionRequestTokenName()) + { + auto& suggestions = getLinkage()->contentAssistInfo.completionSuggestions; + suggestions.clear(); + suggestions.scopeKind = CompletionSuggestions::ScopeKind::Swizzle; + suggestions.swizzleBaseType = + memberRefExpr->baseExpression ? memberRefExpr->baseExpression->type : nullptr; + suggestions.elementCount[0] = baseElementRowCount; + suggestions.elementCount[1] = baseElementColCount; + } - if (memberRefExpr->name == getSession()->getCompletionRequestTokenName()) + String swizzleText = getText(memberRefExpr->name); + auto cursor = swizzleText.begin(); + + // The contents of the string are 0-terminated + // Every update to cursor corresponds to a check against 0-termination + while (*cursor) + { + // Throw out swizzling with more than 4 output elements + if (elementCount >= 4) { - auto& suggestions = getLinkage()->contentAssistInfo.completionSuggestions; - suggestions.clear(); - suggestions.scopeKind = CompletionSuggestions::ScopeKind::Swizzle; - suggestions.swizzleBaseType = - memberRefExpr->baseExpression ? memberRefExpr->baseExpression->type : nullptr; - suggestions.elementCount[0] = baseElementRowCount; - suggestions.elementCount[1] = baseElementColCount; + getSink()->diagnose( + swizExpr, + Diagnostics::invalidSwizzleExpr, + swizzleText, + baseElementType->toString()); + return CreateErrorExpr(memberRefExpr); } + MatrixCoord elementCoord = {0, 0}; - String swizzleText = getText(memberRefExpr->name); - auto cursor = swizzleText.begin(); + // Check for the preceding underscore + if (*cursor++ != '_') + { + getSink()->diagnose( + swizExpr, + Diagnostics::invalidSwizzleExpr, + swizzleText, + baseElementType->toString()); + return CreateErrorExpr(memberRefExpr); + } - // The contents of the string are 0-terminated - // Every update to cursor corresponds to a check against 0-termination - while (*cursor) + // Check for one or zero indexing + if (*cursor == 'm') { - // Throw out swizzling with more than 4 output elements - if (elementCount >= 4) + // Can't mix one and zero indexing + if (zeroIndexOffset == 1) { - getSink()->diagnose(swizExpr, Diagnostics::invalidSwizzleExpr, swizzleText, baseElementType->toString()); + getSink()->diagnose( + swizExpr, + Diagnostics::invalidSwizzleExpr, + swizzleText, + baseElementType->toString()); return CreateErrorExpr(memberRefExpr); } - MatrixCoord elementCoord = { 0, 0 }; - - // Check for the preceding underscore - if (*cursor++ != '_') + zeroIndexOffset = 0; + // Increment the index since we saw 'm' + cursor++; + } + else + { + // Can't mix one and zero indexing + if (zeroIndexOffset == 0) { - getSink()->diagnose(swizExpr, Diagnostics::invalidSwizzleExpr, swizzleText, baseElementType->toString()); + getSink()->diagnose( + swizExpr, + Diagnostics::invalidSwizzleExpr, + swizzleText, + baseElementType->toString()); return CreateErrorExpr(memberRefExpr); } + zeroIndexOffset = 1; + } + + // Check for the ij components + for (Index j = 0; j < 2; j++) + { + auto ch = *cursor++; - // Check for one or zero indexing - if (*cursor == 'm') + if (ch < '0' || ch > '4') { - // Can't mix one and zero indexing - if (zeroIndexOffset == 1) - { - getSink()->diagnose(swizExpr, Diagnostics::invalidSwizzleExpr, swizzleText, baseElementType->toString()); - return CreateErrorExpr(memberRefExpr); - } - zeroIndexOffset = 0; - // Increment the index since we saw 'm' - cursor++; + // An invalid character in the swizzle is an error + getSink()->diagnose( + swizExpr, + Diagnostics::invalidSwizzleExpr, + swizzleText, + baseElementType->toString()); + return CreateErrorExpr(memberRefExpr); } - else + const int subIndex = ch - '0' - zeroIndexOffset; + + // Check the limit for either the row or column, depending on the step + IntegerLiteralValue elementLimit; + if (j == 0) { - // Can't mix one and zero indexing - if (zeroIndexOffset == 0) - { - getSink()->diagnose(swizExpr, Diagnostics::invalidSwizzleExpr, swizzleText, baseElementType->toString()); - return CreateErrorExpr(memberRefExpr); - } - zeroIndexOffset = 1; + elementLimit = baseElementRowCount; + elementCoord.row = subIndex; } - - // Check for the ij components - for (Index j = 0; j < 2; j++) + else { - auto ch = *cursor++; - - if (ch < '0' || ch > '4') - { - // An invalid character in the swizzle is an error - getSink()->diagnose(swizExpr, Diagnostics::invalidSwizzleExpr, swizzleText, baseElementType->toString()); - return CreateErrorExpr(memberRefExpr); - } - const int subIndex = ch - '0' - zeroIndexOffset; - - // Check the limit for either the row or column, depending on the step - IntegerLiteralValue elementLimit; - if (j == 0) - { - elementLimit = baseElementRowCount; - elementCoord.row = subIndex; - } - else - { - elementLimit = baseElementColCount; - elementCoord.col = subIndex; - } - // Make sure the index is in range for the source type - // Account for off-by-one and reject 0 if oneIndexed - if (subIndex >= elementLimit || subIndex < 0) - { - getSink()->diagnose(swizExpr, Diagnostics::invalidSwizzleExpr, swizzleText, baseElementType->toString()); - return CreateErrorExpr(memberRefExpr); - } + elementLimit = baseElementColCount; + elementCoord.col = subIndex; } - // Check if we've seen this index before - for (int ee = 0; ee < elementCount; ee++) + // Make sure the index is in range for the source type + // Account for off-by-one and reject 0 if oneIndexed + if (subIndex >= elementLimit || subIndex < 0) { - if (elementCoords[ee] == elementCoord) - anyDuplicates = true; + getSink()->diagnose( + swizExpr, + Diagnostics::invalidSwizzleExpr, + swizzleText, + baseElementType->toString()); + return CreateErrorExpr(memberRefExpr); } - - // add to our list... - elementCoords[elementCount] = elementCoord; - elementCount++; - } - - // Store our list in the actual AST node - for (int ee = 0; ee < elementCount; ++ee) - { - swizExpr->elementCoords[ee] = elementCoords[ee]; - } - swizExpr->elementCount = elementCount; - - if (elementCount == 1) - { - // single-component swizzle produces a scalar - // - // Note(tfoley): the official HLSL rules seem to be that it produces - // a one-component vector, which is then implicitly convertible to - // a scalar, but that seems like it just adds complexity. - swizExpr->type = QualType(baseElementType); } - else + // Check if we've seen this index before + for (int ee = 0; ee < elementCount; ee++) { - // TODO(tfoley): would be nice to "re-sugar" type - // here if the input type had a sugared name... - swizExpr->type = QualType(createVectorType( - baseElementType, - m_astBuilder->getIntVal(m_astBuilder->getIntType(), elementCount))); + if (elementCoords[ee] == elementCoord) + anyDuplicates = true; } - // A swizzle can be used as an l-value as long as there - // were no duplicates in the list of components - swizExpr->type.isLeftValue = !anyDuplicates; - - return swizExpr; + // add to our list... + elementCoords[elementCount] = elementCoord; + elementCount++; } - Expr* SemanticsVisitor::CheckMatrixSwizzleExpr( - MemberExpr* memberRefExpr, - Type* baseElementType, - IntVal* baseRowCount, - IntVal* baseColCount) + // Store our list in the actual AST node + for (int ee = 0; ee < elementCount; ++ee) { - if (auto constantRowCount = as(baseRowCount)) + swizExpr->elementCoords[ee] = elementCoords[ee]; + } + swizExpr->elementCount = elementCount; + + if (elementCount == 1) + { + // single-component swizzle produces a scalar + // + // Note(tfoley): the official HLSL rules seem to be that it produces + // a one-component vector, which is then implicitly convertible to + // a scalar, but that seems like it just adds complexity. + swizExpr->type = QualType(baseElementType); + } + else + { + // TODO(tfoley): would be nice to "re-sugar" type + // here if the input type had a sugared name... + swizExpr->type = QualType(createVectorType( + baseElementType, + m_astBuilder->getIntVal(m_astBuilder->getIntType(), elementCount))); + } + + // A swizzle can be used as an l-value as long as there + // were no duplicates in the list of components + swizExpr->type.isLeftValue = !anyDuplicates; + + return swizExpr; +} + +Expr* SemanticsVisitor::CheckMatrixSwizzleExpr( + MemberExpr* memberRefExpr, + Type* baseElementType, + IntVal* baseRowCount, + IntVal* baseColCount) +{ + if (auto constantRowCount = as(baseRowCount)) + { + if (auto constantColCount = as(baseColCount)) { - if (auto constantColCount = as(baseColCount)) - { - return CheckMatrixSwizzleExpr(memberRefExpr, baseElementType, - constantRowCount->getValue(), constantColCount->getValue()); - } + return CheckMatrixSwizzleExpr( + memberRefExpr, + baseElementType, + constantRowCount->getValue(), + constantColCount->getValue()); } - getSink()->diagnose(memberRefExpr, Diagnostics::unimplemented, "swizzle on matrix of unknown size"); - return CreateErrorExpr(memberRefExpr); } + getSink()->diagnose( + memberRefExpr, + Diagnostics::unimplemented, + "swizzle on matrix of unknown size"); + return CreateErrorExpr(memberRefExpr); +} + +Expr* SemanticsVisitor::checkTupleSwizzleExpr(MemberExpr* memberExpr, TupleType* baseTupleType) +{ + UInt tupleElementCount = (UInt)baseTupleType->getMemberCount(); + if (tupleElementCount == 0) + return checkGeneralMemberLookupExpr(memberExpr, baseTupleType); + + if (memberExpr->name == getSession()->getCompletionRequestTokenName()) + { + auto& suggestions = getLinkage()->contentAssistInfo.completionSuggestions; + suggestions.clear(); + suggestions.scopeKind = CompletionSuggestions::ScopeKind::Swizzle; + suggestions.swizzleBaseType = + memberExpr->baseExpression ? memberExpr->baseExpression->type : nullptr; + suggestions.elementCount[0] = (Index)tupleElementCount; + suggestions.elementCount[1] = 0; + return memberExpr; + } + + String swizzleText = getText(memberExpr->name); + auto span = swizzleText.getUnownedSlice(); + Index pos = 0; - Expr* SemanticsVisitor::checkTupleSwizzleExpr(MemberExpr* memberExpr, TupleType* baseTupleType) + ShortList elementCoords; + + bool anyDuplicates = false; + + // The contents of the string are 0-terminated + // Every update to cursor corresponds to a check against 0-termination + while (pos < span.getLength()) { - UInt tupleElementCount = (UInt)baseTupleType->getMemberCount(); - if (tupleElementCount == 0) + UInt elementCoord; + + // Check for the preceding underscore + if (span[pos] != '_') + { return checkGeneralMemberLookupExpr(memberExpr, baseTupleType); - - if (memberExpr->name == getSession()->getCompletionRequestTokenName()) + } + pos++; + + // Parse index. + if (pos >= span.getLength()) { - auto& suggestions = getLinkage()->contentAssistInfo.completionSuggestions; - suggestions.clear(); - suggestions.scopeKind = CompletionSuggestions::ScopeKind::Swizzle; - suggestions.swizzleBaseType = - memberExpr->baseExpression ? memberExpr->baseExpression->type : nullptr; - suggestions.elementCount[0] = (Index)tupleElementCount; - suggestions.elementCount[1] = 0; - return memberExpr; + // Unexpected end of swizzle string, fallback to + // member lookup. + return checkGeneralMemberLookupExpr(memberExpr, baseTupleType); } - String swizzleText = getText(memberExpr->name); - auto span = swizzleText.getUnownedSlice(); - Index pos = 0; + auto ch = span[pos]; - ShortList elementCoords; + if (!CharUtil::isDigit(ch)) + { + // An invalid character in the swizzle is an error, fallback to + // member lookup. + return checkGeneralMemberLookupExpr(memberExpr, baseTupleType); + } + elementCoord = (UInt)StringUtil::parseIntAndAdvancePos(span, pos); - bool anyDuplicates = false; + if (elementCoord >= tupleElementCount) + { + getSink() + ->diagnose(memberExpr, Diagnostics::invalidSwizzleExpr, swizzleText, baseTupleType); + return CreateErrorExpr(memberExpr); + } - // The contents of the string are 0-terminated - // Every update to cursor corresponds to a check against 0-termination - while (pos < span.getLength()) + // Check if we've seen this index before + for (int ee = 0; ee < elementCoords.getCount(); ee++) { - UInt elementCoord; + if (elementCoords[ee] == elementCoord) + anyDuplicates = true; + } - // Check for the preceding underscore - if (span[pos] != '_') - { - return checkGeneralMemberLookupExpr(memberExpr, baseTupleType); - } - pos++; + // add to our list... + elementCoords.add(elementCoord); + } - // Parse index. - if (pos >= span.getLength()) - { - // Unexpected end of swizzle string, fallback to - // member lookup. - return checkGeneralMemberLookupExpr(memberExpr, baseTupleType); - } + SwizzleExpr* swizExpr = m_astBuilder->create(); + swizExpr->loc = memberExpr->loc; + swizExpr->base = memberExpr->baseExpression; + swizExpr->elementIndices = _Move(elementCoords); + swizExpr->memberOpLoc = memberExpr->memberOperatorLoc; - auto ch = span[pos]; + if (swizExpr->elementIndices.getCount() == 1) + { + // single-component swizzle produces a scalar + // + swizExpr->type = QualType(baseTupleType->getMember(swizExpr->elementIndices[0])); + } + else + { + List types; + for (auto index : swizExpr->elementIndices) + { + types.add(baseTupleType->getMember(index)); + } + swizExpr->type = QualType(m_astBuilder->getTupleType(types.getArrayView())); + } - if (!CharUtil::isDigit(ch)) - { - // An invalid character in the swizzle is an error, fallback to - // member lookup. - return checkGeneralMemberLookupExpr(memberExpr, baseTupleType); - } - elementCoord = (UInt)StringUtil::parseIntAndAdvancePos(span, pos); + // A swizzle can be used as an l-value as long as there + // were no duplicates in the list of components + swizExpr->type.isLeftValue = !anyDuplicates; + return swizExpr; +} - if (elementCoord >= tupleElementCount) - { - getSink()->diagnose(memberExpr, Diagnostics::invalidSwizzleExpr, swizzleText, baseTupleType); - return CreateErrorExpr(memberExpr); - } +Expr* SemanticsVisitor::CheckSwizzleExpr( + MemberExpr* memberRefExpr, + Type* baseElementType, + IntegerLiteralValue baseElementCount) +{ + SwizzleExpr* swizExpr = m_astBuilder->create(); + swizExpr->loc = memberRefExpr->loc; + swizExpr->base = memberRefExpr->baseExpression; + swizExpr->memberOpLoc = memberRefExpr->memberOperatorLoc; + IntegerLiteralValue limitElement = baseElementCount; - // Check if we've seen this index before - for (int ee = 0; ee < elementCoords.getCount(); ee++) - { - if (elementCoords[ee] == elementCoord) - anyDuplicates = true; - } + ShortList elementIndices; - // add to our list... - elementCoords.add(elementCoord); + bool anyDuplicates = false; + bool anyError = false; + if (memberRefExpr->name == getSession()->getCompletionRequestTokenName()) + { + auto& suggestions = getLinkage()->contentAssistInfo.completionSuggestions; + suggestions.clear(); + suggestions.scopeKind = CompletionSuggestions::ScopeKind::Swizzle; + suggestions.swizzleBaseType = + memberRefExpr->baseExpression ? memberRefExpr->baseExpression->type : nullptr; + suggestions.elementCount[0] = baseElementCount; + suggestions.elementCount[1] = 0; + } + auto swizzleText = getText(memberRefExpr->name); + + for (Index i = 0; i < swizzleText.getLength(); i++) + { + auto ch = swizzleText[i]; + int elementIndex = -1; + switch (ch) + { + case 'x': + case 'r': elementIndex = 0; break; + case 'y': + case 'g': elementIndex = 1; break; + case 'z': + case 'b': elementIndex = 2; break; + case 'w': + case 'a': elementIndex = 3; break; + default: + // An invalid character in the swizzle is an error + anyError = true; + break; } - SwizzleExpr* swizExpr = m_astBuilder->create(); - swizExpr->loc = memberExpr->loc; - swizExpr->base = memberExpr->baseExpression; - swizExpr->elementIndices = _Move(elementCoords); - swizExpr->memberOpLoc = memberExpr->memberOperatorLoc; + // TODO(tfoley): GLSL requires that all component names + // come from the same "family"... - if (swizExpr->elementIndices.getCount() == 1) + // Make sure the index is in range for the source type + if (elementIndex >= limitElement) { - // single-component swizzle produces a scalar - // - swizExpr->type = QualType(baseTupleType->getMember(swizExpr->elementIndices[0])); + anyError = true; + break; } - else + + // If elementCount is already at 4 stop trying to assign a swizzle element and send an + // error, we cannot have more valid swizzle elements than 4. + if (elementIndices.getCount() >= 4) { - List types; - for (auto index : swizExpr->elementIndices) - { - types.add(baseTupleType->getMember(index)); - } - swizExpr->type = QualType(m_astBuilder->getTupleType(types.getArrayView())); + anyError = true; + break; + } + + // Check if we've seen this index before + for (int ee = 0; ee < elementIndices.getCount(); ee++) + { + if (elementIndices[ee] == (UInt)elementIndex) + anyDuplicates = true; } - // A swizzle can be used as an l-value as long as there - // were no duplicates in the list of components - swizExpr->type.isLeftValue = !anyDuplicates; - return swizExpr; + // add to our list... + elementIndices.add(elementIndex); } - Expr* SemanticsVisitor::CheckSwizzleExpr( - MemberExpr* memberRefExpr, - Type* baseElementType, - IntegerLiteralValue baseElementCount) + swizExpr->elementIndices = _Move(elementIndices); + + if (anyError) + { + getSink()->diagnose( + swizExpr, + Diagnostics::invalidSwizzleExpr, + swizzleText, + baseElementType->toString()); + return CreateErrorExpr(memberRefExpr); + } + else if (swizExpr->elementIndices.getCount() == 1) { - SwizzleExpr* swizExpr = m_astBuilder->create(); - swizExpr->loc = memberRefExpr->loc; - swizExpr->base = memberRefExpr->baseExpression; - swizExpr->memberOpLoc = memberRefExpr->memberOperatorLoc; - IntegerLiteralValue limitElement = baseElementCount; + // single-component swizzle produces a scalar + // + // Note(tfoley): the official HLSL rules seem to be that it produces + // a one-component vector, which is then implicitly convertible to + // a scalar, but that seems like it just adds complexity. + swizExpr->type = QualType(baseElementType); + } + else + { + // TODO(tfoley): would be nice to "re-sugar" type + // here if the input type had a sugared name... + swizExpr->type = QualType(createVectorType( + baseElementType, + m_astBuilder->getIntVal( + m_astBuilder->getIntType(), + swizExpr->elementIndices.getCount()))); + } - ShortList elementIndices; + // A swizzle can be used as an l-value as long as there + // were no duplicates in the list of components + swizExpr->type.isLeftValue = !anyDuplicates && swizExpr->base && swizExpr->base->type && + swizExpr->base->type.isLeftValue; - bool anyDuplicates = false; - bool anyError = false; - if (memberRefExpr->name == getSession()->getCompletionRequestTokenName()) - { - auto& suggestions = getLinkage()->contentAssistInfo.completionSuggestions; - suggestions.clear(); - suggestions.scopeKind = CompletionSuggestions::ScopeKind::Swizzle; - suggestions.swizzleBaseType = - memberRefExpr->baseExpression ? memberRefExpr->baseExpression->type : nullptr; - suggestions.elementCount[0] = baseElementCount; - suggestions.elementCount[1] = 0; - } - auto swizzleText = getText(memberRefExpr->name); + return swizExpr; +} - for (Index i = 0; i < swizzleText.getLength(); i++) - { - auto ch = swizzleText[i]; - int elementIndex = -1; - switch (ch) - { - case 'x': case 'r': elementIndex = 0; break; - case 'y': case 'g': elementIndex = 1; break; - case 'z': case 'b': elementIndex = 2; break; - case 'w': case 'a': elementIndex = 3; break; - default: - // An invalid character in the swizzle is an error - anyError = true; - break; - } +Expr* SemanticsVisitor::CheckSwizzleExpr( + MemberExpr* memberRefExpr, + Type* baseElementType, + IntVal* baseElementCount) +{ + if (auto constantElementCount = as(baseElementCount)) + { + return CheckSwizzleExpr(memberRefExpr, baseElementType, constantElementCount->getValue()); + } + else + { + getSink()->diagnose( + memberRefExpr, + Diagnostics::unimplemented, + "swizzle on vector of unknown size"); + return CreateErrorExpr(memberRefExpr); + } +} - // TODO(tfoley): GLSL requires that all component names - // come from the same "family"... +Expr* SemanticsVisitor::_lookupStaticMember(DeclRefExpr* expr, Expr* baseExpression) +{ + LookupResult globalLookupResult; + bool hasErrors = false; + Expr* base = nullptr; - // Make sure the index is in range for the source type - if (elementIndex >= limitElement) - { - anyError = true; - break; - } + // Keep track of namespace scopes we've already looked up in to avoid producing + // duplicates. + HashSet processedNamespaceScopes; + + auto handleLeafCase = [&](DeclRef baseDeclRef, Type* type) + { + auto aggTypeDeclRef = as(baseDeclRef); - // If elementCount is already at 4 stop trying to assign a swizzle element and send an error, - // we cannot have more valid swizzle elements than 4. - if (elementIndices.getCount() >= 4) + if (auto namespaceDeclRef = as(baseDeclRef)) + { + // We are looking up a namespace member. + // + // We should lookup in all sibling scopes of the namespace. + // Another detail here is that we need to skip scopes that are transitively imported. + // For example, given: + // ``` + // module a; + // namespace ns { int f_a(); } + // + // module b; + // namespace ns { int f_b(); } // will have a sibling scope that refers to a::ns. + // + // module c; + // import b; + // void test() {ns.f_a(); // should not be valid, because c does not import a. } + // ``` + // Note that this logic doesn't work nicely with __exported import, but we should + // consider deprecate this feature anyway. + // + auto namespaceModule = getModuleDecl(namespaceDeclRef.getDecl()); + auto thisModule = + m_outerScope ? getModuleDecl(m_outerScope->containerDecl) : namespaceModule; + + for (auto scope = namespaceDeclRef.getDecl()->ownedScope; scope; + scope = scope->nextSibling) { - anyError = true; - break; + auto namespaceDecl = as(scope->containerDecl); + if (!namespaceDecl) + continue; + if (thisModule != namespaceModule && + namespaceModule != getModuleDecl(namespaceDecl)) + continue; + if (processedNamespaceScopes.add(scope->containerDecl)) + { + LookupResult nsLookupResult = lookUpDirectAndTransparentMembers( + m_astBuilder, + this, + expr->name, + namespaceDecl, + DeclRef(namespaceDecl), + LookupMask::Default, + getDeclToExcludeFromLookup()); + AddToLookupResult(globalLookupResult, nsLookupResult); + } } + } + else if (aggTypeDeclRef || type) + { + // We are looking up a member inside a type. + // We want to be careful here because we should only find members + // that are implicitly or explicitly `static`. + // + if (type == nullptr) + type = DeclRefType::create(m_astBuilder, aggTypeDeclRef); - // Check if we've seen this index before - for (int ee = 0; ee < elementIndices.getCount(); ee++) + if (as(type)) { - if (elementIndices[ee] == (UInt)elementIndex) - anyDuplicates = true; + return; } - // add to our list... - elementIndices.add(elementIndex); - } + LookupResult lookupResult = lookUpMember( + m_astBuilder, + this, + expr->name, + type, + m_outerScope, + LookupMask::Default, + LookupOptions::NoDeref); - swizExpr->elementIndices = _Move(elementIndices); + // We need to confirm that whatever member we + // are trying to refer to is usable via static reference. + // + // TODO: eventually we might allow a non-static + // member to be adapted by turning it into something + // like a closure that takes the missing `this` parameter. + // + // E.g., a static reference to a method could be treated + // as a value with a function type, where the first parameter + // is `type`. + // + // The biggest challenge there is that we'd need to arrange + // to generate "dispatcher" functions that could be used + // to implement that function, in the case where we are + // making a static reference to some kind of polymorphic declaration. + // + // (Also, static references to fields/properties would get even + // harder, because you'd have to know whether a getter/setter/ref-er + // is needed). + // + // For now let's just be expedient and disallow all of that, because + // we can always add it back in later. - if (anyError) - { - getSink()->diagnose(swizExpr, Diagnostics::invalidSwizzleExpr, swizzleText, baseElementType->toString()); - return CreateErrorExpr(memberRefExpr); - } - else if (swizExpr->elementIndices.getCount() == 1) - { - // single-component swizzle produces a scalar + // If the lookup result is valid, then we want to filter + // it to just those candidates that can be referenced statically, + // and ignore any that would only be allowed as instance members. // - // Note(tfoley): the official HLSL rules seem to be that it produces - // a one-component vector, which is then implicitly convertible to - // a scalar, but that seems like it just adds complexity. - swizExpr->type = QualType(baseElementType); - } - else - { - // TODO(tfoley): would be nice to "re-sugar" type - // here if the input type had a sugared name... - swizExpr->type = QualType(createVectorType( - baseElementType, - m_astBuilder->getIntVal(m_astBuilder->getIntType(), swizExpr->elementIndices.getCount()))); + if (lookupResult.isValid()) + { + // We track both the usable items, and whether or + // not there were any non-static items that need + // to be ignored. + // + bool anyNonStatic = false; + List staticItems; + for (auto item : lookupResult) + { + // Is this item usable as a static member? + if (isUsableAsStaticMember(item)) + { + // If yes, then it will be part of the output. + staticItems.add(item); + } + else + { + // If no, then we might need to output an error. + anyNonStatic = true; + } + } + + // Was there anything non-static in the list? + if (anyNonStatic) + { + // If we had some static items, then that's okay, + // we just want to use our newly-filtered list. + if (staticItems.getCount()) + { + lookupResult.items = staticItems; + lookupResult.item = staticItems[0]; + } + else + { + // Otherwise, it is time to report an error. + getSink()->diagnose( + expr->loc, + Diagnostics::staticRefToNonStaticMember, + type, + expr->name); + hasErrors = true; + return; + } + } + // If there were no non-static items, then the `items` + // array already represents what we'd get by filtering... + + AddToLookupResult(globalLookupResult, lookupResult); + base = baseExpression; + } } + }; - // A swizzle can be used as an l-value as long as there - // were no duplicates in the list of components - swizExpr->type.isLeftValue = !anyDuplicates && - swizExpr->base && - swizExpr->base->type && - swizExpr->base->type.isLeftValue; + auto handleLeafExpr = [&](Expr* e) + { + if (auto nsType = as(e->type)) + handleLeafCase(nsType->getDeclRef(), nsType); + else if (auto aggType = as(e->type)) + handleLeafCase(aggType->getDeclRef(), aggType); + else if (auto typetype = as(e->type)) + handleLeafCase(DeclRef(), typetype->getType()); + }; - return swizExpr; + auto& baseType = baseExpression->type; + if (as(baseType)) + { + return CreateErrorExpr(expr); } - Expr* SemanticsVisitor::CheckSwizzleExpr( - MemberExpr* memberRefExpr, - Type* baseElementType, - IntVal* baseElementCount) + if (auto overloaded = as(baseExpression)) { - if (auto constantElementCount = as(baseElementCount)) - { - return CheckSwizzleExpr(memberRefExpr, baseElementType, constantElementCount->getValue()); - } - else + for (auto candidate : overloaded->lookupResult2.items) + handleLeafCase(candidate.declRef, nullptr); + } + else if (auto overloaded2 = as(baseExpression)) + { + for (auto candidate : overloaded2->candidiateExprs) { - getSink()->diagnose(memberRefExpr, Diagnostics::unimplemented, "swizzle on vector of unknown size"); - return CreateErrorExpr(memberRefExpr); + handleLeafExpr(candidate); } } - - Expr* SemanticsVisitor::_lookupStaticMember(DeclRefExpr* expr, Expr* baseExpression) + else { - LookupResult globalLookupResult; - bool hasErrors = false; - Expr* base = nullptr; + handleLeafExpr(baseExpression); + } - // Keep track of namespace scopes we've already looked up in to avoid producing - // duplicates. - HashSet processedNamespaceScopes; + bool diagnosed = false; + globalLookupResult = + filterLookupResultByVisibilityAndDiagnose(globalLookupResult, expr->loc, diagnosed); + diagnosed |= hasErrors; + if (!globalLookupResult.isValid()) + { + return lookupMemberResultFailure(expr, baseType, diagnosed); + } - auto handleLeafCase = [&](DeclRef baseDeclRef, Type* type) - { - auto aggTypeDeclRef = as(baseDeclRef); + if (expr->name == getSession()->getCompletionRequestTokenName()) + { + suggestCompletionItems(CompletionSuggestions::ScopeKind::Member, globalLookupResult); + } + return createLookupResultExpr(expr->name, globalLookupResult, base, expr->loc, expr); +} - if (auto namespaceDeclRef = as(baseDeclRef)) - { - // We are looking up a namespace member. - // - // We should lookup in all sibling scopes of the namespace. - // Another detail here is that we need to skip scopes that are transitively imported. - // For example, given: - // ``` - // module a; - // namespace ns { int f_a(); } - // - // module b; - // namespace ns { int f_b(); } // will have a sibling scope that refers to a::ns. - // - // module c; - // import b; - // void test() {ns.f_a(); // should not be valid, because c does not import a. } - // ``` - // Note that this logic doesn't work nicely with __exported import, but we should consider - // deprecate this feature anyway. - // - auto namespaceModule = getModuleDecl(namespaceDeclRef.getDecl()); - auto thisModule = m_outerScope ? getModuleDecl(m_outerScope->containerDecl) : namespaceModule; +Expr* SemanticsExprVisitor::visitStaticMemberExpr(StaticMemberExpr* expr) +{ + expr->baseExpression = CheckTerm(expr->baseExpression); - for (auto scope = namespaceDeclRef.getDecl()->ownedScope; scope; scope = scope->nextSibling) - { - auto namespaceDecl = as(scope->containerDecl); - if (!namespaceDecl) - continue; - if (thisModule != namespaceModule && namespaceModule != getModuleDecl(namespaceDecl)) - continue; - if (processedNamespaceScopes.add(scope->containerDecl)) - { - LookupResult nsLookupResult = lookUpDirectAndTransparentMembers( - m_astBuilder, - this, - expr->name, - namespaceDecl, - DeclRef(namespaceDecl), - LookupMask::Default, - getDeclToExcludeFromLookup()); - AddToLookupResult(globalLookupResult, nsLookupResult); - } - } - } - else if (aggTypeDeclRef || type) - { - // We are looking up a member inside a type. - // We want to be careful here because we should only find members - // that are implicitly or explicitly `static`. - // - if (type == nullptr) - type = DeclRefType::create(m_astBuilder, aggTypeDeclRef); + // Not sure this is needed -> but guess someone could do + expr->baseExpression = MaybeDereference(expr->baseExpression); - if (as(type)) - { - return; - } + // If the base of the member lookup has an interface type + // *without* a suitable this-type substitution, then we are + // trying to perform lookup on a value of existential type, + // and we should "open" the existential here so that we + // can expose its structure. + // - LookupResult lookupResult = lookUpMember( - m_astBuilder, - this, - expr->name, - type, - m_outerScope, - LookupMask::Default, - LookupOptions::NoDeref); + expr->baseExpression = maybeOpenExistential(expr->baseExpression); + // Do a static lookup + return _lookupStaticMember(expr, expr->baseExpression); +} - // We need to confirm that whatever member we - // are trying to refer to is usable via static reference. - // - // TODO: eventually we might allow a non-static - // member to be adapted by turning it into something - // like a closure that takes the missing `this` parameter. - // - // E.g., a static reference to a method could be treated - // as a value with a function type, where the first parameter - // is `type`. - // - // The biggest challenge there is that we'd need to arrange - // to generate "dispatcher" functions that could be used - // to implement that function, in the case where we are - // making a static reference to some kind of polymorphic declaration. - // - // (Also, static references to fields/properties would get even - // harder, because you'd have to know whether a getter/setter/ref-er - // is needed). - // - // For now let's just be expedient and disallow all of that, because - // we can always add it back in later. +Expr* SemanticsVisitor::lookupMemberResultFailure( + DeclRefExpr* expr, + QualType const& baseType, + bool supressDiagnostic) +{ + // Check it's a member expression + SLANG_ASSERT(as(expr) || as(expr)); - // If the lookup result is valid, then we want to filter - // it to just those candidates that can be referenced statically, - // and ignore any that would only be allowed as instance members. - // - if (lookupResult.isValid()) - { - // We track both the usable items, and whether or - // not there were any non-static items that need - // to be ignored. - // - bool anyNonStatic = false; - List staticItems; - for (auto item : lookupResult) - { - // Is this item usable as a static member? - if (isUsableAsStaticMember(item)) - { - // If yes, then it will be part of the output. - staticItems.add(item); - } - else - { - // If no, then we might need to output an error. - anyNonStatic = true; - } - } + if (!supressDiagnostic) + getSink()->diagnose(expr, Diagnostics::noMemberOfNameInType, expr->name, baseType); + expr->type = QualType(m_astBuilder->getErrorType()); + return expr; +} - // Was there anything non-static in the list? - if (anyNonStatic) - { - // If we had some static items, then that's okay, - // we just want to use our newly-filtered list. - if (staticItems.getCount()) - { - lookupResult.items = staticItems; - lookupResult.item = staticItems[0]; - } - else - { - // Otherwise, it is time to report an error. - getSink()->diagnose( - expr->loc, - Diagnostics::staticRefToNonStaticMember, - type, - expr->name); - hasErrors = true; - return; - } - } - // If there were no non-static items, then the `items` - // array already represents what we'd get by filtering... +Expr* SemanticsVisitor::maybeInsertImplicitOpForMemberBase(Expr* baseExpr, bool& outNeedDeref) +{ + auto derefExpr = MaybeDereference(baseExpr); - AddToLookupResult(globalLookupResult, lookupResult); - base = baseExpression; - } - } - }; + if (derefExpr != baseExpr) + outNeedDeref = true; - auto handleLeafExpr = [&](Expr* e) - { - if (auto nsType = as(e->type)) - handleLeafCase(nsType->getDeclRef(), nsType); - else if (auto aggType = as(e->type)) - handleLeafCase(aggType->getDeclRef(), aggType); - else if (auto typetype = as(e->type)) - handleLeafCase(DeclRef(), typetype->getType()); - }; + baseExpr = derefExpr; - auto& baseType = baseExpression->type; - if (as(baseType)) - { - return CreateErrorExpr(expr); - } + // If the base of the member lookup has an interface type + // *without* a suitable this-type substitution, then we are + // trying to perform lookup on a value of existential type, + // and we should "open" the existential here so that we + // can expose its structure. + // + baseExpr = maybeOpenExistential(baseExpr); - if (auto overloaded = as(baseExpression)) - { - for (auto candidate : overloaded->lookupResult2.items) - handleLeafCase(candidate.declRef, nullptr); - } - else if (auto overloaded2 = as(baseExpression)) + // Handle the case of an overloaded base expression + // here, in case we can use the name of the member to + // disambiguate which of the candidates is meant, or if + // we can return an overloaded result. + if (auto overloadedExpr = as(baseExpr)) + { + // If a member (dynamic or static) lookup result contains both the actual definition + // and the interface definition obtained from inheritance, we want to filter out + // the interface definitions. + LookupResult filteredLookupResult; + for (auto lookupResult : overloadedExpr->lookupResult2) { - for (auto candidate : overloaded2->candidiateExprs) + bool shouldRemove = false; + if (lookupResult.declRef.getParent().as()) { - handleLeafExpr(candidate); + shouldRemove = true; + } + if (lookupResult.declRef.getDecl()->hasModifier()) + shouldRemove = true; + if (!shouldRemove) + { + filteredLookupResult.items.add(lookupResult); } } - else - { - handleLeafExpr(baseExpression); - } + if (filteredLookupResult.items.getCount() == 1) + filteredLookupResult.item = filteredLookupResult.items.getFirst(); + baseExpr = createLookupResultExpr( + overloadedExpr->name, + filteredLookupResult, + overloadedExpr->base, + overloadedExpr->loc, + overloadedExpr); + // TODO: handle other cases of OverloadedExpr that need filtering. + } - bool diagnosed = false; - globalLookupResult = filterLookupResultByVisibilityAndDiagnose(globalLookupResult, expr->loc, diagnosed); - diagnosed |= hasErrors; - if (!globalLookupResult.isValid()) - { - return lookupMemberResultFailure(expr, baseType, diagnosed); - } + return baseExpr; +} - if (expr->name == getSession()->getCompletionRequestTokenName()) - { - suggestCompletionItems(CompletionSuggestions::ScopeKind::Member, globalLookupResult); - } - return createLookupResultExpr( - expr->name, - globalLookupResult, - base, - expr->loc, - expr); - } +Expr* SemanticsVisitor::checkBaseForMemberExpr(Expr* inBaseExpr, bool& outNeedDeref) +{ + auto baseExpr = inBaseExpr; + baseExpr = CheckTerm(baseExpr); + return maybeInsertImplicitOpForMemberBase(baseExpr, outNeedDeref); +} - Expr* SemanticsExprVisitor::visitStaticMemberExpr(StaticMemberExpr* expr) +Expr* SemanticsVisitor::checkGeneralMemberLookupExpr(MemberExpr* expr, Type* baseType) +{ + LookupResult lookupResult = + lookUpMember(m_astBuilder, this, expr->name, baseType, m_outerScope); + bool diagnosed = false; + lookupResult = filterLookupResultByVisibilityAndDiagnose(lookupResult, expr->loc, diagnosed); + if (!lookupResult.isValid()) + { + return lookupMemberResultFailure(expr, baseType, diagnosed); + } + if (expr->name == getSession()->getCompletionRequestTokenName()) { - expr->baseExpression = CheckTerm(expr->baseExpression); + suggestCompletionItems(CompletionSuggestions::ScopeKind::Member, lookupResult); + } + return createLookupResultExpr(expr->name, lookupResult, expr->baseExpression, expr->loc, expr); +} - // Not sure this is needed -> but guess someone could do - expr->baseExpression = MaybeDereference(expr->baseExpression); +Expr* SemanticsExprVisitor::visitMemberExpr(MemberExpr* expr) +{ + bool needDeref = false; + expr->baseExpression = checkBaseForMemberExpr(expr->baseExpression, needDeref); - // If the base of the member lookup has an interface type - // *without* a suitable this-type substitution, then we are - // trying to perform lookup on a value of existential type, - // and we should "open" the existential here so that we - // can expose its structure. - // + if (!needDeref && as(expr) && !as(expr->baseExpression->type)) + { + // The user is trying to use the `->` operator on something that can't be + // dereferenced, so we should diagnose that. + if (!as(expr->baseExpression->type)) + getSink()->diagnose( + expr->memberOperatorLoc, + Diagnostics::cannotDereferenceType, + expr->baseExpression->type); + } - expr->baseExpression = maybeOpenExistential(expr->baseExpression); - // Do a static lookup + auto baseType = expr->baseExpression->type; + + // If we are looking up through a modified type, just pass straight + // through the inner type. + if (auto modifiedType = as(baseType)) + baseType = modifiedType->getBase(); + + // Note: Checking for vector types before declaration-reference types, + // because vectors are also declaration reference types... + // + // Also note: the way this is done right now means that the ability + // to swizzle vectors interferes with any chance of looking up + // members via extension, for vector or scalar types. + // + if (auto baseMatrixType = as(baseType)) + { + return CheckMatrixSwizzleExpr( + expr, + baseMatrixType->getElementType(), + baseMatrixType->getRowCount(), + baseMatrixType->getColumnCount()); + } + if (auto baseVecType = as(baseType)) + { + return CheckSwizzleExpr( + expr, + baseVecType->getElementType(), + baseVecType->getElementCount()); + } + else if (auto baseScalarType = as(baseType)) + { + // Treat scalar like a 1-element vector when swizzling + return CheckSwizzleExpr(expr, baseScalarType, 1); + } + else if (as(baseType)) + { return _lookupStaticMember(expr, expr->baseExpression); } - - Expr* SemanticsVisitor::lookupMemberResultFailure( - DeclRefExpr* expr, - QualType const& baseType, - bool supressDiagnostic) + else if (const auto typeType = as(baseType)) + { + return _lookupStaticMember(expr, expr->baseExpression); + } + else if (as(expr->baseExpression)) + { + return _lookupStaticMember(expr, expr->baseExpression); + } + else if (as(expr->baseExpression)) + { + return _lookupStaticMember(expr, expr->baseExpression); + } + else if (auto baseTupleType = as(baseType)) { - // Check it's a member expression - SLANG_ASSERT(as(expr) || as(expr)); - - if (!supressDiagnostic) - getSink()->diagnose(expr, Diagnostics::noMemberOfNameInType, expr->name, baseType); - expr->type = QualType(m_astBuilder->getErrorType()); - return expr; + return checkTupleSwizzleExpr(expr, baseTupleType); } - - Expr* SemanticsVisitor::maybeInsertImplicitOpForMemberBase(Expr* baseExpr, bool& outNeedDeref) + else if (as(baseType)) { - auto derefExpr = MaybeDereference(baseExpr); - - if (derefExpr != baseExpr) - outNeedDeref = true; - - baseExpr = derefExpr; - - // If the base of the member lookup has an interface type - // *without* a suitable this-type substitution, then we are - // trying to perform lookup on a value of existential type, - // and we should "open" the existential here so that we - // can expose its structure. - // - baseExpr = maybeOpenExistential(baseExpr); - - // Handle the case of an overloaded base expression - // here, in case we can use the name of the member to - // disambiguate which of the candidates is meant, or if - // we can return an overloaded result. - if (auto overloadedExpr = as(baseExpr)) - { - // If a member (dynamic or static) lookup result contains both the actual definition - // and the interface definition obtained from inheritance, we want to filter out - // the interface definitions. - LookupResult filteredLookupResult; - for (auto lookupResult : overloadedExpr->lookupResult2) - { - bool shouldRemove = false; - if (lookupResult.declRef.getParent().as()) - { - shouldRemove = true; - } - if (lookupResult.declRef.getDecl()->hasModifier()) - shouldRemove = true; - if (!shouldRemove) - { - filteredLookupResult.items.add(lookupResult); - } - } - if (filteredLookupResult.items.getCount() == 1) - filteredLookupResult.item = filteredLookupResult.items.getFirst(); - baseExpr = createLookupResultExpr( - overloadedExpr->name, - filteredLookupResult, - overloadedExpr->base, - overloadedExpr->loc, - overloadedExpr); - // TODO: handle other cases of OverloadedExpr that need filtering. - } - - return baseExpr; + return CreateErrorExpr(expr); } - - Expr* SemanticsVisitor::checkBaseForMemberExpr(Expr* inBaseExpr, bool& outNeedDeref) + else { - auto baseExpr = inBaseExpr; - baseExpr = CheckTerm(baseExpr); - return maybeInsertImplicitOpForMemberBase(baseExpr, outNeedDeref); + return checkGeneralMemberLookupExpr(expr, baseType); } +} + +Expr* SemanticsExprVisitor::visitInitializerListExpr(InitializerListExpr* expr) +{ + // If we are assigned a type, expr has already been legalized + if (expr->type) + return expr; + + // When faced with an initializer list, we first just check the sub-expressions blindly. + // Actually making them conform to a desired type will wait for when we know the desired + // type based on context. - Expr* SemanticsVisitor::checkGeneralMemberLookupExpr(MemberExpr* expr, Type* baseType) + for (auto& arg : expr->args) { - LookupResult lookupResult = lookUpMember( - m_astBuilder, - this, - expr->name, - baseType, - m_outerScope); - bool diagnosed = false; - lookupResult = filterLookupResultByVisibilityAndDiagnose(lookupResult, expr->loc, diagnosed); - if (!lookupResult.isValid()) - { - return lookupMemberResultFailure(expr, baseType, diagnosed); - } - if (expr->name == getSession()->getCompletionRequestTokenName()) - { - suggestCompletionItems(CompletionSuggestions::ScopeKind::Member, lookupResult); - } - return createLookupResultExpr( - expr->name, - lookupResult, - expr->baseExpression, - expr->loc, - expr); + arg = CheckTerm(arg); } - Expr* SemanticsExprVisitor::visitMemberExpr(MemberExpr * expr) - { - bool needDeref = false; - expr->baseExpression = checkBaseForMemberExpr(expr->baseExpression, needDeref); + expr->type = m_astBuilder->getInitializerListType(); - if (!needDeref && as(expr) && !as(expr->baseExpression->type)) - { - // The user is trying to use the `->` operator on something that can't be - // dereferenced, so we should diagnose that. - if (!as(expr->baseExpression->type)) - getSink()->diagnose(expr->memberOperatorLoc, Diagnostics::cannotDereferenceType, expr->baseExpression->type); - } + return expr; +} - auto baseType = expr->baseExpression->type; +// Perform semantic checking of an object-oriented `this` +// expression. +Expr* SemanticsExprVisitor::visitThisExpr(ThisExpr* expr) +{ + // A `this` expression will default to immutable. + expr->type.isLeftValue = false; - // If we are looking up through a modified type, just pass straight - // through the inner type. - if (auto modifiedType = as(baseType)) - baseType = modifiedType->getBase(); + // We will do an upwards search starting in the current + // scope, looking for a surrounding type (or `extension`) + // declaration that could be the referrant of the expression. + auto scope = expr->scope; + while (scope) + { + auto containerDecl = scope->containerDecl; - // Note: Checking for vector types before declaration-reference types, - // because vectors are also declaration reference types... - // - // Also note: the way this is done right now means that the ability - // to swizzle vectors interferes with any chance of looking up - // members via extension, for vector or scalar types. - // - if (auto baseMatrixType = as(baseType)) - { - return CheckMatrixSwizzleExpr( - expr, - baseMatrixType->getElementType(), - baseMatrixType->getRowCount(), - baseMatrixType->getColumnCount()); - } - if (auto baseVecType = as(baseType)) - { - return CheckSwizzleExpr( - expr, - baseVecType->getElementType(), - baseVecType->getElementCount()); - } - else if(auto baseScalarType = as(baseType)) - { - // Treat scalar like a 1-element vector when swizzling - return CheckSwizzleExpr( - expr, - baseScalarType, - 1); - } - else if( as(baseType) ) - { - return _lookupStaticMember(expr, expr->baseExpression); - } - else if(const auto typeType = as(baseType)) - { - return _lookupStaticMember(expr, expr->baseExpression); - } - else if (as(expr->baseExpression)) - { - return _lookupStaticMember(expr, expr->baseExpression); - } - else if (as(expr->baseExpression)) - { - return _lookupStaticMember(expr, expr->baseExpression); - } - else if (auto baseTupleType = as(baseType)) - { - return checkTupleSwizzleExpr(expr, baseTupleType); - } - else if (as(baseType)) - { - return CreateErrorExpr(expr); - } - else + if (const auto ctorDecl = as(containerDecl)) { - return checkGeneralMemberLookupExpr(expr, baseType); + expr->type.isLeftValue = true; } - } - - Expr* SemanticsExprVisitor::visitInitializerListExpr(InitializerListExpr* expr) - { - // If we are assigned a type, expr has already been legalized - if(expr->type) - return expr; - - // When faced with an initializer list, we first just check the sub-expressions blindly. - // Actually making them conform to a desired type will wait for when we know the desired - // type based on context. - - for( auto& arg : expr->args ) + else if (const auto setterDecl = as(containerDecl)) { - arg = CheckTerm(arg); + expr->type.isLeftValue = true; } - - expr->type = m_astBuilder->getInitializerListType(); - - return expr; - } - - // Perform semantic checking of an object-oriented `this` - // expression. - Expr* SemanticsExprVisitor::visitThisExpr(ThisExpr* expr) - { - // A `this` expression will default to immutable. - expr->type.isLeftValue = false; - - // We will do an upwards search starting in the current - // scope, looking for a surrounding type (or `extension`) - // declaration that could be the referrant of the expression. - auto scope = expr->scope; - while (scope) + else if (auto funcDeclBase = as(containerDecl)) { - auto containerDecl = scope->containerDecl; - - if( const auto ctorDecl = as(containerDecl) ) + if (funcDeclBase->hasModifier()) { expr->type.isLeftValue = true; } - else if( const auto setterDecl = as(containerDecl) ) + else if (funcDeclBase->hasModifier()) { expr->type.isLeftValue = true; } - else if( auto funcDeclBase = as(containerDecl) ) - { - if( funcDeclBase->hasModifier() ) - { - expr->type.isLeftValue = true; - } - else if (funcDeclBase->hasModifier()) - { - expr->type.isLeftValue = true; - } - } - else if( auto typeOrExtensionDecl = as(containerDecl) ) - { - expr->type.type = calcThisType(makeDeclRef(typeOrExtensionDecl)); - return expr; - } + } + else if (auto typeOrExtensionDecl = as(containerDecl)) + { + expr->type.type = calcThisType(makeDeclRef(typeOrExtensionDecl)); + return expr; + } #if 0 else if (auto aggTypeDecl = as(containerDecl)) { @@ -4815,388 +4994,422 @@ namespace Slang } #endif - scope = scope->parent; - } - - if (auto sink = getSink()) - sink->diagnose(expr, Diagnostics::thisExpressionOutsideOfTypeDecl); - - return CreateErrorExpr(expr); + scope = scope->parent; } - Expr* SemanticsExprVisitor::visitThisTypeExpr(ThisTypeExpr* expr) + if (auto sink = getSink()) + sink->diagnose(expr, Diagnostics::thisExpressionOutsideOfTypeDecl); + + return CreateErrorExpr(expr); +} + +Expr* SemanticsExprVisitor::visitThisTypeExpr(ThisTypeExpr* expr) +{ + auto scope = expr->scope; + while (scope) { - auto scope = expr->scope; - while (scope) + auto containerDecl = scope->containerDecl; + if (auto typeOrExtensionDecl = as(containerDecl)) { - auto containerDecl = scope->containerDecl; - if( auto typeOrExtensionDecl = as(containerDecl) ) - { - auto thisType = calcThisType(makeDeclRef(typeOrExtensionDecl)); - auto thisTypeType = m_astBuilder->getTypeType(thisType); - - expr->type.type = thisTypeType; - return expr; - } + auto thisType = calcThisType(makeDeclRef(typeOrExtensionDecl)); + auto thisTypeType = m_astBuilder->getTypeType(thisType); - scope = scope->parent; + expr->type.type = thisTypeType; + return expr; } - getSink()->diagnose(expr, Diagnostics::thisTypeOutsideOfTypeDecl); - return CreateErrorExpr(expr); + scope = scope->parent; } - Expr* SemanticsExprVisitor::visitCastToSuperTypeExpr(CastToSuperTypeExpr* expr) - { - // CastToSuperType is effectively a struct field. - // As long as the type is not readonly tagged we - // can use CastToSuperType as an L-value - if(!expr->type.hasReadOnlyOnTarget) - expr->type.isLeftValue = true; - return expr; - } + getSink()->diagnose(expr, Diagnostics::thisTypeOutsideOfTypeDecl); + return CreateErrorExpr(expr); +} + +Expr* SemanticsExprVisitor::visitCastToSuperTypeExpr(CastToSuperTypeExpr* expr) +{ + // CastToSuperType is effectively a struct field. + // As long as the type is not readonly tagged we + // can use CastToSuperType as an L-value + if (!expr->type.hasReadOnlyOnTarget) + expr->type.isLeftValue = true; + return expr; +} - Expr* SemanticsExprVisitor::visitReturnValExpr(ReturnValExpr* expr) +Expr* SemanticsExprVisitor::visitReturnValExpr(ReturnValExpr* expr) +{ + auto scope = expr->scope; + if (scope) { - auto scope = expr->scope; - if (scope) + auto parentFunc = as(getParentFunc(scope->containerDecl)); + if (parentFunc) { - auto parentFunc = as(getParentFunc(scope->containerDecl)); - if (parentFunc) + if (as(parentFunc->returnType.type)) { - if (as(parentFunc->returnType.type)) - { - expr->type = parentFunc->returnType.type; - return expr; - } - if (isNonCopyableType(parentFunc->returnType.type)) - { - expr->type.isLeftValue = true; - expr->type.type = parentFunc->returnType.type; - return expr; - } + expr->type = parentFunc->returnType.type; + return expr; + } + if (isNonCopyableType(parentFunc->returnType.type)) + { + expr->type.isLeftValue = true; + expr->type.type = parentFunc->returnType.type; + return expr; } } - getSink()->diagnose(expr, Diagnostics::returnValNotAvailable); - expr->type = getASTBuilder()->getErrorType(); - return expr; } + getSink()->diagnose(expr, Diagnostics::returnValNotAvailable); + expr->type = getASTBuilder()->getErrorType(); + return expr; +} - Expr* SemanticsExprVisitor::visitAndTypeExpr(AndTypeExpr* expr) - { - // The left and right sides of an `&` for types must both be types. - // - expr->left = CheckProperType(expr->left); - expr->right = CheckProperType(expr->right); - - // TODO: We should enforce some rules here about what is allowed - // for the `left` and `right` types. - // - // For now, the right rule is that they probably need to either - // be interfaces, or conjunctions thereof. - // - // Eventually it may be valuable to support more flexible - // types in conjunctions, especialy in cases where inheritance - // gets involved. - - // The result of this expression is an `AndType`, which we need - // to wrap in a `TypeType` to indicate that the result is the type - // itself and not a value of that type. - // - auto andType = m_astBuilder->getAndType(expr->left.type, expr->right.type); - expr->type = m_astBuilder->getTypeType(andType); +Expr* SemanticsExprVisitor::visitAndTypeExpr(AndTypeExpr* expr) +{ + // The left and right sides of an `&` for types must both be types. + // + expr->left = CheckProperType(expr->left); + expr->right = CheckProperType(expr->right); + + // TODO: We should enforce some rules here about what is allowed + // for the `left` and `right` types. + // + // For now, the right rule is that they probably need to either + // be interfaces, or conjunctions thereof. + // + // Eventually it may be valuable to support more flexible + // types in conjunctions, especialy in cases where inheritance + // gets involved. + + // The result of this expression is an `AndType`, which we need + // to wrap in a `TypeType` to indicate that the result is the type + // itself and not a value of that type. + // + auto andType = m_astBuilder->getAndType(expr->left.type, expr->right.type); + expr->type = m_astBuilder->getTypeType(andType); + + return expr; +} - return expr; - } +Expr* SemanticsExprVisitor::visitPointerTypeExpr(PointerTypeExpr* expr) +{ + expr->base = CheckProperType(expr->base); + if (as(expr->base.type)) + expr->type = expr->base.type; + auto ptrType = m_astBuilder->getPtrType(expr->base.type, AddressSpace::UserPointer); + expr->type = m_astBuilder->getTypeType(ptrType); + return expr; +} - Expr* SemanticsExprVisitor::visitPointerTypeExpr(PointerTypeExpr* expr) - { - expr->base = CheckProperType(expr->base); - if (as(expr->base.type)) - expr->type = expr->base.type; - auto ptrType = m_astBuilder->getPtrType(expr->base.type, AddressSpace::UserPointer); - expr->type = m_astBuilder->getTypeType(ptrType); - return expr; - } +Expr* SemanticsExprVisitor::visitModifiedTypeExpr(ModifiedTypeExpr* expr) +{ + // The base type should be a proper type (not an expression, generic, etc.) + // + expr->base = CheckProperType(expr->base); + auto baseType = expr->base.type; - Expr* SemanticsExprVisitor::visitModifiedTypeExpr(ModifiedTypeExpr* expr) + // We will check the modifiers that were applied to the type expression + // one by one, and collect a list of the ones that should modify the + // resulting `Type`. + // + List modifierVals; + for (auto modifier : expr->modifiers) { - // The base type should be a proper type (not an expression, generic, etc.) - // - expr->base = CheckProperType(expr->base); - auto baseType = expr->base.type; - - // We will check the modifiers that were applied to the type expression - // one by one, and collect a list of the ones that should modify the - // resulting `Type`. - // - List modifierVals; - for( auto modifier : expr->modifiers ) + if (auto matrixLayoutModifier = as(modifier)) { - if (auto matrixLayoutModifier = as(modifier)) + if (auto matrixType = as(baseType)) { - if (auto matrixType = as(baseType)) + if (as(matrixLayoutModifier)) { - if (as(matrixLayoutModifier)) - { - baseType = m_astBuilder->getMatrixType(matrixType->getElementType(), matrixType->getRowCount(), matrixType->getColumnCount(), - m_astBuilder->getIntVal(m_astBuilder->getIntType(), kMatrixLayoutMode_ColumnMajor)); - } - else - { - baseType = m_astBuilder->getMatrixType(matrixType->getElementType(), matrixType->getRowCount(), matrixType->getColumnCount(), - m_astBuilder->getIntVal(m_astBuilder->getIntType(), kMatrixLayoutMode_RowMajor)); - } - expr->type = m_astBuilder->getTypeType(baseType); + baseType = m_astBuilder->getMatrixType( + matrixType->getElementType(), + matrixType->getRowCount(), + matrixType->getColumnCount(), + m_astBuilder->getIntVal( + m_astBuilder->getIntType(), + kMatrixLayoutMode_ColumnMajor)); } else { - getSink()->diagnose(matrixLayoutModifier, Diagnostics::matrixLayoutModifierOnNonMatrixType, baseType); + baseType = m_astBuilder->getMatrixType( + matrixType->getElementType(), + matrixType->getRowCount(), + matrixType->getColumnCount(), + m_astBuilder->getIntVal( + m_astBuilder->getIntType(), + kMatrixLayoutMode_RowMajor)); } - continue; + expr->type = m_astBuilder->getTypeType(baseType); } - auto modifierVal = checkTypeModifier(modifier, baseType); - if(!modifierVal) - continue; - modifierVals.add(modifierVal); - } - - if (modifierVals.getCount()) - { - auto modifiedType = m_astBuilder->getModifiedType(baseType, modifierVals); - expr->type = m_astBuilder->getTypeType(modifiedType); + else + { + getSink()->diagnose( + matrixLayoutModifier, + Diagnostics::matrixLayoutModifierOnNonMatrixType, + baseType); + } + continue; } - return expr; + auto modifierVal = checkTypeModifier(modifier, baseType); + if (!modifierVal) + continue; + modifierVals.add(modifierVal); } - Val* SemanticsExprVisitor::checkTypeModifier(Modifier* modifier, Type* type) + if (modifierVals.getCount()) { - SLANG_UNUSED(type); + auto modifiedType = m_astBuilder->getModifiedType(baseType, modifierVals); + expr->type = m_astBuilder->getTypeType(modifiedType); + } + return expr; +} - if( const auto unormModifier = as(modifier) ) - { - // TODO: validate that `type` is either `float` or a vector of `float`s - return m_astBuilder->getUNormModifierVal(); +Val* SemanticsExprVisitor::checkTypeModifier(Modifier* modifier, Type* type) +{ + SLANG_UNUSED(type); - } - else if( const auto snormModifier = as(modifier) ) - { - // TODO: validate that `type` is either `float` or a vector of `float`s - return m_astBuilder->getSNormModifierVal(); - } - else if (const auto noDiffModifier = as(modifier)) - { - return m_astBuilder->getNoDiffModifierVal(); - } - else - { - // TODO: more complete error message here - getSink()->diagnose(modifier, Diagnostics::unexpected, "unknown type modifier in semantic checking"); - return nullptr; - } + if (const auto unormModifier = as(modifier)) + { + // TODO: validate that `type` is either `float` or a vector of `float`s + return m_astBuilder->getUNormModifierVal(); } - - Expr* SemanticsExprVisitor::visitFuncTypeExpr(FuncTypeExpr* expr) + else if (const auto snormModifier = as(modifier)) { - // The input and output to a function type must both be types - for(auto& t : expr->parameters) - t = CheckProperType(t); - expr->result = CheckProperType(expr->result); - - // TODO: Kind checking? Where are we stopping someone passing - // constraints around as value-inhabitable types - - // The result of this expression is a `FuncType`, which we need - // to wrap in a `TypeType` to indicate that the result is the type - // itself and not a value of that type. - List types; - types.reserve(expr->parameters.getCount()); - for(const auto& t : expr->parameters) - types.add(t.type); - auto funcType = m_astBuilder->getFuncType(types.getArrayView(), expr->result.type); - expr->type = m_astBuilder->getTypeType(funcType); - - return expr; + // TODO: validate that `type` is either `float` or a vector of `float`s + return m_astBuilder->getSNormModifierVal(); } - - Expr* SemanticsExprVisitor::visitTupleTypeExpr(TupleTypeExpr* expr) + else if (const auto noDiffModifier = as(modifier)) { - // All tuple members must be types - for(auto& t : expr->members) - t = CheckProperType(t); + return m_astBuilder->getNoDiffModifierVal(); + } + else + { + // TODO: more complete error message here + getSink()->diagnose( + modifier, + Diagnostics::unexpected, + "unknown type modifier in semantic checking"); + return nullptr; + } +} - // As in the other cases above, wrap in TypeType - List types; - types.reserve(expr->members.getCount()); - for(auto t : expr->members) - types.add(t.type); - auto tupleType = m_astBuilder->getTupleType(types.getArrayView()); - expr->type = m_astBuilder->getTypeType(tupleType); +Expr* SemanticsExprVisitor::visitFuncTypeExpr(FuncTypeExpr* expr) +{ + // The input and output to a function type must both be types + for (auto& t : expr->parameters) + t = CheckProperType(t); + expr->result = CheckProperType(expr->result); + + // TODO: Kind checking? Where are we stopping someone passing + // constraints around as value-inhabitable types + + // The result of this expression is a `FuncType`, which we need + // to wrap in a `TypeType` to indicate that the result is the type + // itself and not a value of that type. + List types; + types.reserve(expr->parameters.getCount()); + for (const auto& t : expr->parameters) + types.add(t.type); + auto funcType = m_astBuilder->getFuncType(types.getArrayView(), expr->result.type); + expr->type = m_astBuilder->getTypeType(funcType); + + return expr; +} - return expr; - } +Expr* SemanticsExprVisitor::visitTupleTypeExpr(TupleTypeExpr* expr) +{ + // All tuple members must be types + for (auto& t : expr->members) + t = CheckProperType(t); + + // As in the other cases above, wrap in TypeType + List types; + types.reserve(expr->members.getCount()); + for (auto t : expr->members) + types.add(t.type); + auto tupleType = m_astBuilder->getTupleType(types.getArrayView()); + expr->type = m_astBuilder->getTypeType(tupleType); + + return expr; +} + +Expr* SemanticsExprVisitor::visitSPIRVAsmExpr(SPIRVAsmExpr* expr) +{ + // + // Firstly, get the info for this op, the opcode has already been + // discovered by the parser + // + const auto& spirvInfo = getSession()->spirvCoreGrammarInfo; - Expr* SemanticsExprVisitor::visitSPIRVAsmExpr(SPIRVAsmExpr* expr) + // We will iterate over all the operands in all the insts and check + // them + bool failed = false; + for (auto& inst : expr->insts) { - // - // Firstly, get the info for this op, the opcode has already been - // discovered by the parser - // - const auto& spirvInfo = getSession()->spirvCoreGrammarInfo; + // It's not automatically a failure to not have info, we just won't + // be able to deduce types for operands + const auto opInfo = spirvInfo->opInfos.lookup(SpvOp(inst.opcode.knownValue)); - // We will iterate over all the operands in all the insts and check - // them - bool failed = false; - for(auto& inst : expr->insts) + if (opInfo && opInfo->numOperandTypes == 0 && inst.operands.getCount()) { - // It's not automatically a failure to not have info, we just won't - // be able to deduce types for operands - const auto opInfo = spirvInfo->opInfos.lookup(SpvOp(inst.opcode.knownValue)); - - if(opInfo && opInfo->numOperandTypes == 0 && inst.operands.getCount()) - { - failed = true; - getSink()->diagnose(inst.opcode.token, Diagnostics::spirvInstructionWithTooManyOperands, inst.opcode.token, 0); - continue; - } + failed = true; + getSink()->diagnose( + inst.opcode.token, + Diagnostics::spirvInstructionWithTooManyOperands, + inst.opcode.token, + 0); + continue; + } + + const bool isLast = &inst == &expr->insts.getLast(); + for (Index operandIndex = 0; operandIndex < inst.operands.getCount(); ++operandIndex) + { + // Clamp to the end of the type info array, because the last one will be any variable + // operands + const auto invalidOperandKind = SPIRVCoreGrammarInfo::OperandKind{0xff}; + const auto operandType = + opInfo.has_value() + ? opInfo + ->operandTypes[std::min(operandIndex, Index(opInfo->numOperandTypes) - 1)] + : invalidOperandKind; + const auto baseOperandType = + spirvInfo->operandKindUnderneathIds.lookup(operandType).value_or(operandType); + const auto needsIdWrapper = baseOperandType != operandType; - const bool isLast = &inst == &expr->insts.getLast(); - for(Index operandIndex = 0; operandIndex < inst.operands.getCount(); ++operandIndex) + const auto check = [&](const auto& go, auto& operand) -> void { - // Clamp to the end of the type info array, because the last one will be any variable operands - const auto invalidOperandKind = SPIRVCoreGrammarInfo::OperandKind{0xff}; - const auto operandType - = opInfo.has_value() - ? opInfo->operandTypes[std::min(operandIndex, Index(opInfo->numOperandTypes)-1)] - : invalidOperandKind; - const auto baseOperandType - = spirvInfo->operandKindUnderneathIds.lookup(operandType).value_or(operandType); - const auto needsIdWrapper = baseOperandType != operandType; + if (operand.flavor == SPIRVAsmOperand::SlangType || + operand.flavor == SPIRVAsmOperand::SampledType) + { + // This is a $$type operand or __sampledType(T) + // operand, fill in its TypeExp member. + TypeExp& typeExpr = operand.type; + typeExpr.exp = operand.expr; + typeExpr = CheckProperType(typeExpr); + operand.expr = typeExpr.exp; + } + else if ( + operand.flavor == SPIRVAsmOperand::SlangValue || + operand.flavor == SPIRVAsmOperand::SlangImmediateValue || + operand.flavor == SPIRVAsmOperand::SlangValueAddr || + operand.flavor == SPIRVAsmOperand::ImageType || + operand.flavor == SPIRVAsmOperand::SampledImageType || + operand.flavor == SPIRVAsmOperand::ConvertTexel || + operand.flavor == SPIRVAsmOperand::RayPayloadFromLocation || + operand.flavor == SPIRVAsmOperand::RayAttributeFromLocation || + operand.flavor == SPIRVAsmOperand::RayCallableFromLocation) + { + // This is a $expr operand, check the expr + operand.expr = dispatch(operand.expr); + } + else if (operand.flavor == SPIRVAsmOperand::ResultMarker) + { + // This is the marker, check that it only + // appears in the last instruction. - const auto check = [&](const auto& go, auto& operand) -> void { - if(operand.flavor == SPIRVAsmOperand::SlangType - || operand.flavor == SPIRVAsmOperand::SampledType) - { - // This is a $$type operand or __sampledType(T) - // operand, fill in its TypeExp member. - TypeExp& typeExpr = operand.type; - typeExpr.exp = operand.expr; - typeExpr = CheckProperType(typeExpr); - operand.expr = typeExpr.exp; - } - else if(operand.flavor == SPIRVAsmOperand::SlangValue - || operand.flavor == SPIRVAsmOperand::SlangImmediateValue - || operand.flavor == SPIRVAsmOperand::SlangValueAddr - || operand.flavor == SPIRVAsmOperand::ImageType - || operand.flavor == SPIRVAsmOperand::SampledImageType - || operand.flavor == SPIRVAsmOperand::ConvertTexel - || operand.flavor == SPIRVAsmOperand::RayPayloadFromLocation - || operand.flavor == SPIRVAsmOperand::RayAttributeFromLocation - || operand.flavor == SPIRVAsmOperand::RayCallableFromLocation) - { - // This is a $expr operand, check the expr - operand.expr = dispatch(operand.expr); - } - else if(operand.flavor == SPIRVAsmOperand::ResultMarker) + // TODO: We could consider relaxing this, because SPIR-V + // does have forward references for decorations and such + if (!isLast) { - // This is the marker, check that it only - // appears in the last instruction. - - // TODO: We could consider relaxing this, because SPIR-V - // does have forward references for decorations and such - if (!isLast) - { - getSink()->diagnose(operand.token, Diagnostics::misplacedResultIdMarker); - getSink()->diagnoseWithoutSourceView(expr, Diagnostics::considerOpCopyObject); - } + getSink()->diagnose(operand.token, Diagnostics::misplacedResultIdMarker); + getSink()->diagnoseWithoutSourceView( + expr, + Diagnostics::considerOpCopyObject); } - else if(operand.flavor == SPIRVAsmOperand::NamedValue) + } + else if (operand.flavor == SPIRVAsmOperand::NamedValue) + { + // First try and look it up with the knowledge of this operand's type + auto enumValue = + spirvInfo->allEnums.lookup({baseOperandType, operand.token.getContent()}); + // Then fall back to with the type prefix + if (!enumValue) + enumValue = + spirvInfo->allEnumsWithTypePrefix.lookup(operand.token.getContent()); + // Then see if it's an opcode (for OpSpecialize) + if (!enumValue) + enumValue = spirvInfo->opcodes.lookup(operand.token.getContent()); + if (inst.opcode.knownValue == SpvOpExtInst) { - // First try and look it up with the knowledge of this operand's type - auto enumValue - = spirvInfo->allEnums.lookup({baseOperandType, operand.token.getContent()}); - // Then fall back to with the type prefix - if(!enumValue) - enumValue = spirvInfo->allEnumsWithTypePrefix.lookup(operand.token.getContent()); - // Then see if it's an opcode (for OpSpecialize) - if(!enumValue) - enumValue = spirvInfo->opcodes.lookup(operand.token.getContent()); - if (inst.opcode.knownValue == SpvOpExtInst) + if (!enumValue) { - if (!enumValue) + GLSLstd450 val; + if (lookupGLSLstd450(operand.token.getContent(), val)) { - GLSLstd450 val; - if (lookupGLSLstd450(operand.token.getContent(), val)) - { - enumValue = (SpvWord)val; - } + enumValue = (SpvWord)val; } } - if(!enumValue) - { - failed = true; - getSink()->diagnose(operand.token, Diagnostics::spirvUnableToResolveName, operand.token.getContent()); - return; - } - - operand.knownValue = *enumValue; - operand.wrapInId = needsIdWrapper; } - else if (operand.flavor == SPIRVAsmOperand::BuiltinVar) + if (!enumValue) { - operand.type = CheckProperType(operand.type); - auto builtinVarKind = spirvInfo->allEnums.lookup( - SPIRVCoreGrammarInfo::QualifiedEnumName{spirvInfo->operandKinds.lookup(UnownedStringSlice("BuiltIn")).value(), operand.token.getContent()}); - if (!builtinVarKind) - { - failed = true; - getSink()->diagnose(operand.token, Diagnostics::spirvUnableToResolveName, operand.token.getContent()); - return; - } - operand.knownValue = builtinVarKind.value(); + failed = true; + getSink()->diagnose( + operand.token, + Diagnostics::spirvUnableToResolveName, + operand.token.getContent()); + return; } - if(operand.bitwiseOrWith.getCount() - && operand.flavor != SPIRVAsmOperand::Literal - && operand.flavor != SPIRVAsmOperand::NamedValue) + + operand.knownValue = *enumValue; + operand.wrapInId = needsIdWrapper; + } + else if (operand.flavor == SPIRVAsmOperand::BuiltinVar) + { + operand.type = CheckProperType(operand.type); + auto builtinVarKind = + spirvInfo->allEnums.lookup(SPIRVCoreGrammarInfo::QualifiedEnumName{ + spirvInfo->operandKinds.lookup(UnownedStringSlice("BuiltIn")).value(), + operand.token.getContent()}); + if (!builtinVarKind) { failed = true; - getSink()->diagnose(operand.token, Diagnostics::spirvNonConstantBitwiseOr); + getSink()->diagnose( + operand.token, + Diagnostics::spirvUnableToResolveName, + operand.token.getContent()); + return; } - for(auto& o : operand.bitwiseOrWith) + operand.knownValue = builtinVarKind.value(); + } + if (operand.bitwiseOrWith.getCount() && + operand.flavor != SPIRVAsmOperand::Literal && + operand.flavor != SPIRVAsmOperand::NamedValue) + { + failed = true; + getSink()->diagnose(operand.token, Diagnostics::spirvNonConstantBitwiseOr); + } + for (auto& o : operand.bitwiseOrWith) + { + if (o.flavor != SPIRVAsmOperand::Literal && + o.flavor != SPIRVAsmOperand::NamedValue) { - if(o.flavor != SPIRVAsmOperand::Literal && o.flavor != SPIRVAsmOperand::NamedValue) - { - failed = true; - getSink()->diagnose(operand.token, Diagnostics::spirvNonConstantBitwiseOr); - } - go(go, o); - operand.knownValue |= o.knownValue; + failed = true; + getSink()->diagnose(operand.token, Diagnostics::spirvNonConstantBitwiseOr); } - }; + go(go, o); + operand.knownValue |= o.knownValue; + } + }; - check(check, inst.operands[operandIndex]); - } + check(check, inst.operands[operandIndex]); } + } - if(failed) - return CreateErrorExpr(expr); + if (failed) + return CreateErrorExpr(expr); - // Assign the type of this expression from the type of the last - // instruction, otherwise void - if(expr->insts.getCount()) + // Assign the type of this expression from the type of the last + // instruction, otherwise void + if (expr->insts.getCount()) + { + // TODO: we trust that this is correct, but could should verify + const auto lastOperands = expr->insts.getLast().operands; + if (lastOperands.getCount() >= 2 && lastOperands[0].flavor == SPIRVAsmOperand::SlangType && + lastOperands[1].flavor == SPIRVAsmOperand::ResultMarker) { - // TODO: we trust that this is correct, but could should verify - const auto lastOperands = expr->insts.getLast().operands; - if(lastOperands.getCount() >= 2 - && lastOperands[0].flavor == SPIRVAsmOperand::SlangType - && lastOperands[1].flavor == SPIRVAsmOperand::ResultMarker) - { - expr->type = lastOperands[0].type.type; - } + expr->type = lastOperands[0].type.type; } - if(!expr->type) - expr->type = m_astBuilder->getVoidType(); - - return expr; } + if (!expr->type) + expr->type = m_astBuilder->getVoidType(); + + return expr; } +} // namespace Slang diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 499f41fa7..95a6d00cc 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -11,3010 +11,2971 @@ namespace Slang { - template - bool diagnoseCapabilityErrors(DiagnosticSink* sink, CompilerOptionSet& optionSet, P const& pos, DiagnosticInfo const& info, Args const&... args) - { - if (optionSet.getBoolOption(CompilerOptionName::IgnoreCapabilities)) - return false; - return sink->diagnose(pos, info, args...); - } +template +bool diagnoseCapabilityErrors( + DiagnosticSink* sink, + CompilerOptionSet& optionSet, + P const& pos, + DiagnosticInfo const& info, + Args const&... args) +{ + if (optionSet.getBoolOption(CompilerOptionName::IgnoreCapabilities)) + return false; + return sink->diagnose(pos, info, args...); +} - enum class IsSubTypeOptions - { - None = 0, +enum class IsSubTypeOptions +{ + None = 0, - /// A type may not be finished 'DeclCheckState::ReadyForLookup` while `isSubType` is called. - /// We should not cache any negative results when this flag is set. - NoCaching = 1 << 0, - }; + /// A type may not be finished 'DeclCheckState::ReadyForLookup` while `isSubType` is called. + /// We should not cache any negative results when this flag is set. + NoCaching = 1 << 0, +}; - /// Should the given `decl` be treated as a static rather than instance declaration? - bool isEffectivelyStatic( - Decl* decl); +/// Should the given `decl` be treated as a static rather than instance declaration? +bool isEffectivelyStatic(Decl* decl); - bool isGlobalDecl(Decl* decl); +bool isGlobalDecl(Decl* decl); - bool isUnsafeForceInlineFunc(FunctionDeclBase* funcDecl); +bool isUnsafeForceInlineFunc(FunctionDeclBase* funcDecl); - bool isUniformParameterType(Type* type); +bool isUniformParameterType(Type* type); - /// Create a new component type based on `inComponentType`, but with all its requiremetns filled. - RefPtr fillRequirements( - ComponentType* inComponentType); +/// Create a new component type based on `inComponentType`, but with all its requiremetns filled. +RefPtr fillRequirements(ComponentType* inComponentType); - Type* checkProperType( - Linkage* linkage, - TypeExp typeExp, - DiagnosticSink* sink); +Type* checkProperType(Linkage* linkage, TypeExp typeExp, DiagnosticSink* sink); - /// Get the element type if `type` is Ptr or PtrLike type, otherwise returns null. - /// Note: this currently does not include PtrTypeBase. - Type* getPointedToTypeIfCanImplicitDeref(Type* type); +/// Get the element type if `type` is Ptr or PtrLike type, otherwise returns null. +/// Note: this currently does not include PtrTypeBase. +Type* getPointedToTypeIfCanImplicitDeref(Type* type); - inline int getIntValueBitSize(IntegerLiteralValue val) +inline int getIntValueBitSize(IntegerLiteralValue val) +{ + uint64_t v = val > 0 ? (uint64_t)val : (uint64_t)-val; + int result = 1; + while (v >>= 1) { - uint64_t v = val > 0 ? (uint64_t)val : (uint64_t)-val; - int result = 1; - while (v >>= 1) - { - result++; - } - return result; + result++; } + return result; +} - int getTypeBitSize(Type* t); +int getTypeBitSize(Type* t); - // A flat representation of basic types (scalars, vectors and matrices) - // that can be used as lookup key in caches - struct BasicTypeKey - { - uint32_t baseType : 8; - uint32_t dim1 : 4; - uint32_t dim2 : 4; - uint32_t knownConstantBitCount : 8; - uint32_t knownNegative : 1; - uint32_t isLValue : 1; - uint32_t reserved : 6; - uint32_t getRaw() const - { - uint32_t val; - memcpy(&val, this, sizeof(uint32_t)); - return val; - } - bool operator==(BasicTypeKey other) const - { - return getRaw() == other.getRaw(); - } - static BasicTypeKey invalid() { return BasicTypeKey{ 0xff, 0, 0, 0, 0, 0, 0 }; } - }; - - SLANG_FORCE_INLINE BasicTypeKey makeBasicTypeKey(BaseType baseType, IntegerLiteralValue dim1 = 0, IntegerLiteralValue dim2 = 0, bool inIsLValue = false) +// A flat representation of basic types (scalars, vectors and matrices) +// that can be used as lookup key in caches +struct BasicTypeKey +{ + uint32_t baseType : 8; + uint32_t dim1 : 4; + uint32_t dim2 : 4; + uint32_t knownConstantBitCount : 8; + uint32_t knownNegative : 1; + uint32_t isLValue : 1; + uint32_t reserved : 6; + uint32_t getRaw() const { - SLANG_ASSERT(dim1 >= 0 && dim2 >= 0); - return BasicTypeKey{ uint8_t(baseType), uint8_t(dim1), uint8_t(dim2), 0, 0, (inIsLValue?1u:0u), 0 }; + uint32_t val; + memcpy(&val, this, sizeof(uint32_t)); + return val; } + bool operator==(BasicTypeKey other) const { return getRaw() == other.getRaw(); } + static BasicTypeKey invalid() { return BasicTypeKey{0xff, 0, 0, 0, 0, 0, 0}; } +}; + +SLANG_FORCE_INLINE BasicTypeKey makeBasicTypeKey( + BaseType baseType, + IntegerLiteralValue dim1 = 0, + IntegerLiteralValue dim2 = 0, + bool inIsLValue = false) +{ + SLANG_ASSERT(dim1 >= 0 && dim2 >= 0); + return BasicTypeKey{ + uint8_t(baseType), + uint8_t(dim1), + uint8_t(dim2), + 0, + 0, + (inIsLValue ? 1u : 0u), + 0}; +} - inline BasicTypeKey makeBasicTypeKey(QualType typeIn, Expr* exprIn = nullptr) +inline BasicTypeKey makeBasicTypeKey(QualType typeIn, Expr* exprIn = nullptr) +{ + if (auto basicType = as(typeIn)) { - if (auto basicType = as(typeIn)) + auto rs = makeBasicTypeKey(basicType->getBaseType()); + if (auto constInt = as(exprIn)) { - auto rs = makeBasicTypeKey(basicType->getBaseType()); - if (auto constInt = as(exprIn)) + if (constInt->value < 0) { - if (constInt->value < 0) - { - rs.knownNegative = 1; - } - rs.knownConstantBitCount = getIntValueBitSize(constInt->value); + rs.knownNegative = 1; } - rs.isLValue = typeIn.isLeftValue ? 1u : 0u; - return rs; + rs.knownConstantBitCount = getIntValueBitSize(constInt->value); } - else if (auto vectorType = as(typeIn)) + rs.isLValue = typeIn.isLeftValue ? 1u : 0u; + return rs; + } + else if (auto vectorType = as(typeIn)) + { + if (auto elemCount = as(vectorType->getElementCount())) { - if (auto elemCount = as(vectorType->getElementCount())) + if (auto elemBasicType = as(vectorType->getElementType())) { - if( auto elemBasicType = as(vectorType->getElementType()) ) - { - return makeBasicTypeKey(elemBasicType->getBaseType(), elemCount->getValue(), 0, typeIn.isLeftValue); - } + return makeBasicTypeKey( + elemBasicType->getBaseType(), + elemCount->getValue(), + 0, + typeIn.isLeftValue); } } - else if (auto matrixType = as(typeIn)) + } + else if (auto matrixType = as(typeIn)) + { + if (auto elemCount1 = as(matrixType->getRowCount())) { - if (auto elemCount1 = as(matrixType->getRowCount())) + if (auto elemCount2 = as(matrixType->getColumnCount())) { - if (auto elemCount2 = as(matrixType->getColumnCount())) + if (auto elemBasicType = as(matrixType->getElementType())) { - if( auto elemBasicType = as(matrixType->getElementType()) ) - { - return makeBasicTypeKey(elemBasicType->getBaseType(), elemCount1->getValue(), elemCount2->getValue(), typeIn.isLeftValue); - } + return makeBasicTypeKey( + elemBasicType->getBaseType(), + elemCount1->getValue(), + elemCount2->getValue(), + typeIn.isLeftValue); } } } - return BasicTypeKey::invalid(); } + return BasicTypeKey::invalid(); +} - struct BasicTypeKeyPair +struct BasicTypeKeyPair +{ + BasicTypeKey type1, type2; + bool operator==(const BasicTypeKeyPair& rhs) const { - BasicTypeKey type1, type2; - bool operator==(const BasicTypeKeyPair& rhs) const { return type1 == rhs.type1 && type2 == rhs.type2; } - bool operator!=(const BasicTypeKeyPair& rhs) const { return !(*this == rhs); } + return type1 == rhs.type1 && type2 == rhs.type2; + } + bool operator!=(const BasicTypeKeyPair& rhs) const { return !(*this == rhs); } - bool isValid() const { return type1.getRaw() != BasicTypeKey::invalid().getRaw() && type2.getRaw() != BasicTypeKey::invalid().getRaw(); } + bool isValid() const + { + return type1.getRaw() != BasicTypeKey::invalid().getRaw() && + type2.getRaw() != BasicTypeKey::invalid().getRaw(); + } - HashCode getHashCode() const - { - return combineHash(type1.getRaw(), type2.getRaw()); - } - }; + HashCode getHashCode() const { return combineHash(type1.getRaw(), type2.getRaw()); } +}; - struct OperatorOverloadCacheKey +struct OperatorOverloadCacheKey +{ + intptr_t operatorName; + BasicTypeKey args[2]; + bool operator==(OperatorOverloadCacheKey key) const { - intptr_t operatorName; - BasicTypeKey args[2]; - bool operator == (OperatorOverloadCacheKey key) const - { - return operatorName == key.operatorName && args[0] == key.args[0] && args[1] == key.args[1]; - } - HashCode getHashCode() const - { - return combineHash((int)(UInt64)(void*)(operatorName), args[0].getRaw(), args[1].getRaw()); - } - bool fromOperatorExpr(OperatorExpr* opExpr) - { - // First, lets see if the argument types are ones - // that we can encode in our space of keys. - args[0] = BasicTypeKey::invalid(); - args[1] = BasicTypeKey::invalid(); - if (opExpr->arguments.getCount() > 2) - return false; + return operatorName == key.operatorName && args[0] == key.args[0] && args[1] == key.args[1]; + } + HashCode getHashCode() const + { + return combineHash((int)(UInt64)(void*)(operatorName), args[0].getRaw(), args[1].getRaw()); + } + bool fromOperatorExpr(OperatorExpr* opExpr) + { + // First, lets see if the argument types are ones + // that we can encode in our space of keys. + args[0] = BasicTypeKey::invalid(); + args[1] = BasicTypeKey::invalid(); + if (opExpr->arguments.getCount() > 2) + return false; - for (Index i = 0; i < opExpr->arguments.getCount(); i++) + for (Index i = 0; i < opExpr->arguments.getCount(); i++) + { + auto key = makeBasicTypeKey(opExpr->arguments[i]->type, opExpr->arguments[i]); + if (key.getRaw() == BasicTypeKey::invalid().getRaw()) { - auto key = makeBasicTypeKey(opExpr->arguments[i]->type, opExpr->arguments[i]); - if (key.getRaw() == BasicTypeKey::invalid().getRaw()) - { - return false; - } - args[i] = key; + return false; } + args[i] = key; + } + + // Next, lets see if we can find an intrinsic opcode + // attached to an overloaded definition (filtered for + // definitions that could conceivably apply to us). + // + // TODO: This should really be parsed on the operator name + // plus fixity, rather than the intrinsic opcode... + // + // We will need to reject postfix definitions for prefix + // operators, and vice versa, to ensure things work. + // + auto prefixExpr = as(opExpr); + auto postfixExpr = as(opExpr); - // Next, lets see if we can find an intrinsic opcode - // attached to an overloaded definition (filtered for - // definitions that could conceivably apply to us). - // - // TODO: This should really be parsed on the operator name - // plus fixity, rather than the intrinsic opcode... - // - // We will need to reject postfix definitions for prefix - // operators, and vice versa, to ensure things work. - // - auto prefixExpr = as(opExpr); - auto postfixExpr = as(opExpr); - - if (auto overloadedBase = as(opExpr->functionExpr)) + if (auto overloadedBase = as(opExpr->functionExpr)) + { + for (auto item : overloadedBase->lookupResult2) { - for(auto item : overloadedBase->lookupResult2 ) + // Look at a candidate definition to be called and + // see if it gives us a key to work with. + // + Decl* funcDecl = item.declRef.getDecl(); + if (auto genDecl = as(funcDecl)) + funcDecl = genDecl->inner; + + // Reject definitions that have the wrong fixity. + // + if (prefixExpr && !funcDecl->findModifier()) + continue; + if (postfixExpr && !funcDecl->findModifier()) + continue; + + if (auto intrinsicOp = funcDecl->findModifier()) { - // Look at a candidate definition to be called and - // see if it gives us a key to work with. - // - Decl* funcDecl = item.declRef.getDecl(); - if (auto genDecl = as(funcDecl)) - funcDecl = genDecl->inner; - - // Reject definitions that have the wrong fixity. - // - if(prefixExpr && !funcDecl->findModifier()) - continue; - if(postfixExpr && !funcDecl->findModifier()) - continue; - - if (auto intrinsicOp = funcDecl->findModifier()) - { - operatorName = intrinsicOp->op; - return true; - } + operatorName = intrinsicOp->op; + return true; } } - return false; } - }; + return false; + } +}; - struct OverloadCandidate +struct OverloadCandidate +{ + enum class Flavor { - enum class Flavor - { - Func, - Generic, - UnspecializedGeneric, - Expr, - }; - Flavor flavor; - - enum class Status - { - GenericArgumentInferenceFailed, - Unchecked, - ArityChecked, - FixityChecked, - TypeChecked, - DirectionChecked, - VisibilityChecked, - Applicable, - }; - Status status = Status::Unchecked; - - typedef unsigned int Flags; - enum Flag : Flags - { - IsPartiallyAppliedGeneric = 1 << 0, - }; - Flags flags = 0; - - // Reference to the declaration being applied - LookupResultItem item; - - // The expression when flavor is Expr. - Expr* exprVal = nullptr; - - // Type of function being applied (for cases where `item` is not used) - FuncType* funcType = nullptr; - - // The type of the result expression if this candidate is selected - Type* resultType = nullptr; - - // A system for tracking constraints introduced on generic parameters - // ConstraintSystem constraintSystem; - - // How much conversion cost should be considered for this overload, - // when ranking candidates. - ConversionCost conversionCostSum = kConversionCost_None; - - // When required, a candidate can store a pre-checked list of - // arguments so that we don't have to repeat work across checking - // phases. Currently this is only needed for generics. - SubstitutionSet subst; + Func, + Generic, + UnspecializedGeneric, + Expr, }; + Flavor flavor; - struct TypeCheckingCache + enum class Status { - Dictionary resolvedOperatorOverloadCache; - Dictionary conversionCostCache; + GenericArgumentInferenceFailed, + Unchecked, + ArityChecked, + FixityChecked, + TypeChecked, + DirectionChecked, + VisibilityChecked, + Applicable, }; + Status status = Status::Unchecked; - enum class CoercionSite + typedef unsigned int Flags; + enum Flag : Flags { - General, - Assignment, - Argument, - Return, - Initializer, - ExplicitCoercion + IsPartiallyAppliedGeneric = 1 << 0, }; + Flags flags = 0; - struct FacetImpl; + // Reference to the declaration being applied + LookupResultItem item; - /// Information about one "facet" of a type or declaration - /// - /// In the simplest terms, a facet represents a grouping of - /// member declarations that were all originally declared - /// as part of the same `{}`-enclosed body. - /// - /// A given *entity* (a type, type declaration, or `extension` - /// declaration) may have multiple facets, depending on what it - /// declares, what it inherits from, or what `extension`s apply to it. - /// - /// Broadly, an entity will have: - /// - /// * A *self facet*, if it has a body, that contains the members - /// the entity directly declares. - /// - /// * An *inherited facet* for each base type that it (transitively) - /// inherits from. Inherited facets are either *direct*, if the - /// original entity stated the inheritance relationship, or - /// *indirect* if they arise from the transitive closure of the - /// inheritance relationship. Each inherited facet contains the - /// members of the entity that was inherited from. - /// - /// * An *extension facet* for each `extension` declaration that - /// is known to apply to the entity in the context where semantic - /// checking is being performed. Each extension facet contains the - /// members of the `extension` that applied. - /// - struct Facet - { - public: - /// Kinds of facets that can occur - enum class Kind - { - Type, - Extension, - }; + // The expression when flavor is Expr. + Expr* exprVal = nullptr; - /// How many indirections away from the self facet? - typedef unsigned int DirectnessVal; - enum class Directness : DirectnessVal - { - Self = 0, - Direct = 1, - }; + // Type of function being applied (for cases where `item` is not used) + FuncType* funcType = nullptr; - /// The *origin* of a facet is the type and/or declaration - /// that the facet's members belong to. - /// - struct Origin - { - /// A `DeclRef` to the declaration this facet corresponds to, if any. - /// - /// This might be a type declaration, an `extension` declaration, - /// or nothing. - /// - DeclRef declRef; - - /// The type that this facet corresponds to, if any - Type* type = nullptr; - - Origin() - {} - - explicit Origin(DeclRef declRef, Type* type = nullptr) - : declRef(declRef) - , type(type) - {} - }; + // The type of the result expression if this candidate is selected + Type* resultType = nullptr; - Facet() - {} + // A system for tracking constraints introduced on generic parameters + // ConstraintSystem constraintSystem; - typedef FacetImpl Impl; + // How much conversion cost should be considered for this overload, + // when ranking candidates. + ConversionCost conversionCostSum = kConversionCost_None; - Facet(Impl* impl) - : _impl(impl) - {} + // When required, a candidate can store a pre-checked list of + // arguments so that we don't have to repeat work across checking + // phases. Currently this is only needed for generics. + SubstitutionSet subst; +}; - Impl* getImpl() const { return _impl; } - Impl* operator->() const { return _impl; } +struct TypeCheckingCache +{ + Dictionary resolvedOperatorOverloadCache; + Dictionary conversionCostCache; +}; - private: - Impl* _impl = nullptr; +enum class CoercionSite +{ + General, + Assignment, + Argument, + Return, + Initializer, + ExplicitCoercion +}; + +struct FacetImpl; + +/// Information about one "facet" of a type or declaration +/// +/// In the simplest terms, a facet represents a grouping of +/// member declarations that were all originally declared +/// as part of the same `{}`-enclosed body. +/// +/// A given *entity* (a type, type declaration, or `extension` +/// declaration) may have multiple facets, depending on what it +/// declares, what it inherits from, or what `extension`s apply to it. +/// +/// Broadly, an entity will have: +/// +/// * A *self facet*, if it has a body, that contains the members +/// the entity directly declares. +/// +/// * An *inherited facet* for each base type that it (transitively) +/// inherits from. Inherited facets are either *direct*, if the +/// original entity stated the inheritance relationship, or +/// *indirect* if they arise from the transitive closure of the +/// inheritance relationship. Each inherited facet contains the +/// members of the entity that was inherited from. +/// +/// * An *extension facet* for each `extension` declaration that +/// is known to apply to the entity in the context where semantic +/// checking is being performed. Each extension facet contains the +/// members of the `extension` that applied. +/// +struct Facet +{ +public: + /// Kinds of facets that can occur + enum class Kind + { + Type, + Extension, }; - - /// Do the origins of `left` and `right` match, - /// such that they are both facets for the same - /// base type or `extension`? - /// - bool originsMatch(Facet left, Facet right); - - inline bool operator!(Facet facet) { return !facet.getImpl(); } - - bool operator==(Facet::Origin left, Facet::Origin right); - - inline bool operator!=(Facet::Origin left, Facet::Origin right) + /// How many indirections away from the self facet? + typedef unsigned int DirectnessVal; + enum class Directness : DirectnessVal { - return !(left == right); - } - - /// Heap-allocated implementation of a single facet. - struct FacetImpl - { - /// The kind of this facet - Facet::Kind kind = Facet::Kind::Type; - - /// How many indirections away from the self facet? - Facet::Directness directness = Facet::Directness::Self; - - /// The origin of this facet. - /// - /// This is the type or declaration that the facet - /// corresponds to. - /// - Facet::Origin origin; - - Type* getType() const { return origin.type; } - DeclRef getDeclRef() const { return origin.declRef; } - - /// A witness that the type this facet belongs to - /// is a subtype of `origin.type` (if both of those - /// types exist). - /// - SubtypeWitness* subtypeWitness = nullptr; - - /// The next facet in the linearized inheritance list of the entity. - Facet next; - - FacetImpl() - {} - - FacetImpl( - Facet::Kind kind, - Facet::Directness directness, - DeclRef declRef, - Type* type, - SubtypeWitness* subtypeWitness) - : kind(kind) - , directness(directness) - , origin(declRef, type) - , subtypeWitness(subtypeWitness) - {} + Self = 0, + Direct = 1, }; - struct FacetListBuilder; - - /// A singly linked list of facets. - struct FacetList + /// The *origin* of a facet is the type and/or declaration + /// that the facet's members belong to. + /// + struct Origin { - public: - FacetList() - {} - - explicit FacetList(Facet head) - : _head(head) - {} - - Facet getHead() const { return _head; } - Facet& getHead() { return _head; } + /// A `DeclRef` to the declaration this facet corresponds to, if any. + /// + /// This might be a type declaration, an `extension` declaration, + /// or nothing. + /// + DeclRef declRef; - Facet advanceHead() - { - SLANG_ASSERT(_head.getImpl()); - auto facet = _head; - _head = facet->next; - return facet; - } + /// The type that this facet corresponds to, if any + Type* type = nullptr; - Facet popHead() - { - auto facet = advanceHead(); - facet->next = nullptr; - return facet; - } + Origin() {} - FacetList getTail() const + explicit Origin(DeclRef declRef, Type* type = nullptr) + : declRef(declRef), type(type) { - SLANG_ASSERT(_head.getImpl()); - return FacetList(_head->next); } + }; - bool containsMatchFor(Facet facet) const; - - bool isEmpty() const { return _head.getImpl() == nullptr; } - - struct Iterator - { - public: - Iterator() - {} - - Iterator(Facet::Impl* cursor) - : _cursor(cursor) - {} + Facet() {} - bool operator!=(Iterator const& that) const - { - return this->_cursor != that._cursor; - } + typedef FacetImpl Impl; - void operator++() - { - SLANG_ASSERT(_cursor); - _cursor = _cursor->next.getImpl(); - } + Facet(Impl* impl) + : _impl(impl) + { + } - Facet operator*() const - { - return _cursor; - } + Impl* getImpl() const { return _impl; } + Impl* operator->() const { return _impl; } - private: - Facet::Impl* _cursor = nullptr; - }; +private: + Impl* _impl = nullptr; +}; - Iterator begin() const { return Iterator(_head.getImpl()); } - Iterator end() const { return Iterator(); } - struct Appender - { - public: - Appender(FacetList& list) - { - _link = &list._head; - } +/// Do the origins of `left` and `right` match, +/// such that they are both facets for the same +/// base type or `extension`? +/// +bool originsMatch(Facet left, Facet right); - void add(Facet facet) - { - *_link = facet; - _link = &facet->next; - } +inline bool operator!(Facet facet) +{ + return !facet.getImpl(); +} - protected: - Appender() - {} +bool operator==(Facet::Origin left, Facet::Origin right); - Facet* _link = nullptr; - }; +inline bool operator!=(Facet::Origin left, Facet::Origin right) +{ + return !(left == right); +} - typedef FacetListBuilder Builder; +/// Heap-allocated implementation of a single facet. +struct FacetImpl +{ + /// The kind of this facet + Facet::Kind kind = Facet::Kind::Type; + + /// How many indirections away from the self facet? + Facet::Directness directness = Facet::Directness::Self; + + /// The origin of this facet. + /// + /// This is the type or declaration that the facet + /// corresponds to. + /// + Facet::Origin origin; + + Type* getType() const { return origin.type; } + DeclRef getDeclRef() const { return origin.declRef; } + + /// A witness that the type this facet belongs to + /// is a subtype of `origin.type` (if both of those + /// types exist). + /// + SubtypeWitness* subtypeWitness = nullptr; + + /// The next facet in the linearized inheritance list of the entity. + Facet next; + + FacetImpl() {} + + FacetImpl( + Facet::Kind kind, + Facet::Directness directness, + DeclRef declRef, + Type* type, + SubtypeWitness* subtypeWitness) + : kind(kind), directness(directness), origin(declRef, type), subtypeWitness(subtypeWitness) + { + } +}; - protected: +struct FacetListBuilder; - Facet _head; - }; +/// A singly linked list of facets. +struct FacetList +{ +public: + FacetList() {} - struct FacetListBuilder : FacetList, FacetList::Appender + explicit FacetList(Facet head) + : _head(head) { - public: - FacetListBuilder() - { - _link = &_head; - } - }; + } - /// Information about the inheritance of an entity (type or declaration) - /// - /// Currently this is only used to store a linearized list of the - /// `Facet`s that the type/declaration transitively inherits. - /// - struct InheritanceInfo - { - FacetList facets; - }; + Facet getHead() const { return _head; } + Facet& getHead() { return _head; } - /// Cached information about how to convert between two types. - struct ImplicitCastMethod + Facet advanceHead() { - OverloadCandidate conversionFuncOverloadCandidate = OverloadCandidate(); - ConversionCost cost = kConversionCost_Impossible; - bool isAmbiguous = false; - }; + SLANG_ASSERT(_head.getImpl()); + auto facet = _head; + _head = facet->next; + return facet; + } - struct ImplicitCastMethodKey + Facet popHead() { - Type* fromType; // nullptr means default construct. - bool isLValue; - Type* toType; - uint64_t constantVal; - bool isConstant; - HashCode getHashCode() const - { - return combineHash(Slang::getHashCode(fromType), Slang::getHashCode(toType), Slang::getHashCode(constantVal), (HashCode32)isConstant, (HashCode32)isLValue); - } - bool operator == (const ImplicitCastMethodKey& other) const - { - return fromType == other.fromType && toType == other.toType && isConstant == other.isConstant && constantVal == other.constantVal && isLValue == other.isLValue; - } - ImplicitCastMethodKey() = default; - ImplicitCastMethodKey(QualType fromType, Type* toType, Expr* fromExpr) - : fromType(fromType) - , toType(toType) - , constantVal(0) - , isConstant(false) - , isLValue(fromType.isLeftValue) - { - if (auto constInt = as(fromExpr)) - { - constantVal = constInt->value; - isConstant = true; - } - } - }; + auto facet = advanceHead(); + facet->next = nullptr; + return facet; + } - /// Shared state for a semantics-checking session. - struct SharedSemanticsContext : public RefObject + FacetList getTail() const { - Linkage* m_linkage = nullptr; - - /// The (optional) "primary" module that is the parent to everything that will be checked. - Module* m_module = nullptr; - - DiagnosticSink* m_sink = nullptr; - - /// (optional) modules that comes from previously processed translation units in the - /// front-end request that are made visible to the module being checked. This allows - /// `import` to use them instead of trying to find the files in file system. - LoadedModuleDictionary* m_environmentModules = nullptr; + SLANG_ASSERT(_head.getImpl()); + return FacetList(_head->next); + } - /// (optional) The translation unit that is being checked. - /// Needed for handling `__include`s. - TranslationUnitRequest* m_translationUnitRequest = nullptr; + bool containsMatchFor(Facet facet) const; - DiagnosticSink* getSink() - { - return m_sink; - } + bool isEmpty() const { return _head.getImpl() == nullptr; } - CompilerOptionSet& getOptionSet() - { - return m_linkage->m_optionSet; - } - - // We need to track what has been `import`ed into - // the scope of this semantic checking session, - // and also to avoid importing the same thing more - // than once. - // - List importedModulesList; - HashSet importedModulesSet; + struct Iterator + { public: - SharedSemanticsContext( - Linkage* linkage, - Module* module, - DiagnosticSink* sink, - LoadedModuleDictionary* environmentModules = nullptr, - TranslationUnitRequest* translationUnit = nullptr) - : m_linkage(linkage) - , m_module(module) - , m_sink(sink) - , m_environmentModules(environmentModules) - , m_translationUnitRequest(translationUnit) - {} - - Session* getSession() - { - return m_linkage->getSessionImpl(); - } + Iterator() {} - Linkage* getLinkage() + Iterator(Facet::Impl* cursor) + : _cursor(cursor) { - return m_linkage; } - Module* getModule() - { - return m_module; - } + bool operator!=(Iterator const& that) const { return this->_cursor != that._cursor; } - TranslationUnitRequest* getTranslationUnitRequest() + void operator++() { - return m_translationUnitRequest; - } - - bool isInLanguageServer() - { - if (m_linkage) - return m_linkage->isInLanguageServer(); - return false; + SLANG_ASSERT(_cursor); + _cursor = _cursor->next.getImpl(); } - /// Get the list of extension declarations that appear to apply to `decl` in this context - List const& getCandidateExtensionsForTypeDecl(AggTypeDecl* decl); - /// Register a candidate extension `extDecl` for `typeDecl` encountered during checking. - void registerCandidateExtension(AggTypeDecl* typeDecl, ExtensionDecl* extDecl); + Facet operator*() const { return _cursor; } - void registerAssociatedDecl(Decl* original, DeclAssociationKind assoc, Decl* declaration); - - List> const& getAssociatedDeclsForDecl(Decl* decl); - - bool isDifferentiableFunc(FunctionDeclBase* func); - bool isBackwardDifferentiableFunc(FunctionDeclBase* func); - FunctionDifferentiableLevel _getFuncDifferentiableLevelImpl(FunctionDeclBase* func, int recurseLimit); - FunctionDifferentiableLevel getFuncDifferentiableLevel(FunctionDeclBase* func); - - struct InheritanceCircularityInfo - { - InheritanceCircularityInfo( - Decl* decl, - InheritanceCircularityInfo* next) - : decl(decl) - , next(next) - {} - - /// A declaration whose inheritance is being calculated - Decl* decl = nullptr; - - /// The rest of the links in the chain of declarations being processed - InheritanceCircularityInfo* next = nullptr; - }; - - /// Get the processed inheritance information for `type`, including all its facets - InheritanceInfo getInheritanceInfo(Type* type, InheritanceCircularityInfo* circularityInfo = nullptr); + private: + Facet::Impl* _cursor = nullptr; + }; - /// Get the processed inheritance information for `extension`, including all its facets - InheritanceInfo getInheritanceInfo(DeclRef const& extension, InheritanceCircularityInfo* circularityInfo = nullptr); + Iterator begin() const { return Iterator(_head.getImpl()); } + Iterator end() const { return Iterator(); } - /// Prevent an unsupported case of - /// ``` - /// extension : IBar{}; - /// extesnion : IFoo{}; - /// ``` - /// from causing infinite recursion. - bool _checkForCircularityInExtensionTargetType( - Decl* decl, - InheritanceCircularityInfo* circularityInfo); + struct Appender + { + public: + Appender(FacetList& list) { _link = &list._head; } - /// Try get subtype witness from cache, returns true if cache contains a result for the query. - bool tryGetSubtypeWitnessFromCache(Type* sub, Type* sup, SubtypeWitness*& outWitness) - { - auto pair = TypePair{ sub, sup }; - return m_mapTypePairToSubtypeWitness.tryGetValue(pair, outWitness); - } - void cacheSubtypeWitness(Type* sub, Type* sup, SubtypeWitness*& outWitness) - { - auto pair = TypePair{ sub, sup }; - m_mapTypePairToSubtypeWitness[pair] = outWitness; - } - ImplicitCastMethod* tryGetImplicitCastMethod(ImplicitCastMethodKey key) + void add(Facet facet) { - return m_mapTypePairToImplicitCastMethod.tryGetValue(key); - } - void cacheImplicitCastMethod(ImplicitCastMethodKey key, ImplicitCastMethod candidate) - { - m_mapTypePairToImplicitCastMethod[key] = candidate; + *_link = facet; + _link = &facet->next; } - // Get the inner most generic decl that a decl-ref is dependent on. - // For example, `Foo` depends on the generic decl that defines `T`. - // - DeclRef getDependentGenericParent(DeclRef declRef); - private: - /// Mapping from type declarations to the known extensiosn that apply to them - Dictionary> m_mapTypeDeclToCandidateExtensions; - - /// Is the `m_mapTypeDeclToCandidateExtensions` dictionary valid and up to date? - bool m_candidateExtensionListsBuilt = false; - - /// Add candidate extensions declared in `moduleDecl` to `m_mapTypeDeclToCandidateExtensions` - void _addCandidateExtensionsFromModule(ModuleDecl* moduleDecl); - - /// Mapping from a decl to additional declarations of the same decl. - /// The additional declarations provide a location to hold extra decorations. - OrderedDictionary> m_mapDeclToAssociatedDecls; - - /// Is the `m_mapDeclToAssociatedDecls` dictionary valid and up to date? - bool m_associatedDeclListsBuilt = false; - - /// Add associated decls declared in `moduleDecl` to `m_mapDeclToAssociatedDecls` - void _addDeclAssociationsFromModule(ModuleDecl* moduleDecl); - - ASTBuilder* _getASTBuilder() { return m_linkage->getASTBuilder(); } - - InheritanceInfo _getInheritanceInfo(DeclRef declRef, DeclRefType* correspondingType, InheritanceCircularityInfo* circularityInfo); - InheritanceInfo _calcInheritanceInfo(Type* type, InheritanceCircularityInfo* circularityInfo); - InheritanceInfo _calcInheritanceInfo(DeclRef declRef, DeclRefType* correspondingType, InheritanceCircularityInfo* circularityInfo); + protected: + Appender() {} - void getDependentGenericParentImpl(DeclRef& genericParent, DeclRef declRef); + Facet* _link = nullptr; + }; - struct DirectBaseInfo - { - FacetList facets; + typedef FacetListBuilder Builder; - Facet::Impl facetImpl; +protected: + Facet _head; +}; - DirectBaseInfo* next = nullptr; - }; +struct FacetListBuilder : FacetList, FacetList::Appender +{ +public: + FacetListBuilder() { _link = &_head; } +}; + +/// Information about the inheritance of an entity (type or declaration) +/// +/// Currently this is only used to store a linearized list of the +/// `Facet`s that the type/declaration transitively inherits. +/// +struct InheritanceInfo +{ + FacetList facets; +}; - struct DirectBaseListBuilder; +/// Cached information about how to convert between two types. +struct ImplicitCastMethod +{ + OverloadCandidate conversionFuncOverloadCandidate = OverloadCandidate(); + ConversionCost cost = kConversionCost_Impossible; + bool isAmbiguous = false; +}; - struct DirectBaseList +struct ImplicitCastMethodKey +{ + Type* fromType; // nullptr means default construct. + bool isLValue; + Type* toType; + uint64_t constantVal; + bool isConstant; + HashCode getHashCode() const + { + return combineHash( + Slang::getHashCode(fromType), + Slang::getHashCode(toType), + Slang::getHashCode(constantVal), + (HashCode32)isConstant, + (HashCode32)isLValue); + } + bool operator==(const ImplicitCastMethodKey& other) const + { + return fromType == other.fromType && toType == other.toType && + isConstant == other.isConstant && constantVal == other.constantVal && + isLValue == other.isLValue; + } + ImplicitCastMethodKey() = default; + ImplicitCastMethodKey(QualType fromType, Type* toType, Expr* fromExpr) + : fromType(fromType) + , toType(toType) + , constantVal(0) + , isConstant(false) + , isLValue(fromType.isLeftValue) + { + if (auto constInt = as(fromExpr)) { - public: - struct Iterator - { - public: - Iterator() - {} - - Iterator(DirectBaseInfo* cursor) - : _cursor(cursor) - {} + constantVal = constInt->value; + isConstant = true; + } + } +}; - bool operator!=(Iterator that) const - { - return _cursor != that._cursor; - } +/// Shared state for a semantics-checking session. +struct SharedSemanticsContext : public RefObject +{ + Linkage* m_linkage = nullptr; + + /// The (optional) "primary" module that is the parent to everything that will be checked. + Module* m_module = nullptr; + + DiagnosticSink* m_sink = nullptr; + + /// (optional) modules that comes from previously processed translation units in the + /// front-end request that are made visible to the module being checked. This allows + /// `import` to use them instead of trying to find the files in file system. + LoadedModuleDictionary* m_environmentModules = nullptr; + + /// (optional) The translation unit that is being checked. + /// Needed for handling `__include`s. + TranslationUnitRequest* m_translationUnitRequest = nullptr; + + DiagnosticSink* getSink() { return m_sink; } + + CompilerOptionSet& getOptionSet() { return m_linkage->m_optionSet; } + + // We need to track what has been `import`ed into + // the scope of this semantic checking session, + // and also to avoid importing the same thing more + // than once. + // + List importedModulesList; + HashSet importedModulesSet; + +public: + SharedSemanticsContext( + Linkage* linkage, + Module* module, + DiagnosticSink* sink, + LoadedModuleDictionary* environmentModules = nullptr, + TranslationUnitRequest* translationUnit = nullptr) + : m_linkage(linkage) + , m_module(module) + , m_sink(sink) + , m_environmentModules(environmentModules) + , m_translationUnitRequest(translationUnit) + { + } - void operator++() - { - SLANG_ASSERT(_cursor); - _cursor = _cursor->next; - } + Session* getSession() { return m_linkage->getSessionImpl(); } - DirectBaseInfo* operator*() - { - return _cursor; - } + Linkage* getLinkage() { return m_linkage; } - private: - DirectBaseInfo* _cursor = nullptr; - }; + Module* getModule() { return m_module; } - Iterator begin() const { return Iterator(_head); } - Iterator end() const { return Iterator(); } + TranslationUnitRequest* getTranslationUnitRequest() { return m_translationUnitRequest; } - bool isEmpty() const - { - return _head == nullptr; - } + bool isInLanguageServer() + { + if (m_linkage) + return m_linkage->isInLanguageServer(); + return false; + } + /// Get the list of extension declarations that appear to apply to `decl` in this context + List const& getCandidateExtensionsForTypeDecl(AggTypeDecl* decl); - bool doesAnyTailContainMatchFor(Facet facet) const; + /// Register a candidate extension `extDecl` for `typeDecl` encountered during checking. + void registerCandidateExtension(AggTypeDecl* typeDecl, ExtensionDecl* extDecl); - void removeEmptyLists(); + void registerAssociatedDecl(Decl* original, DeclAssociationKind assoc, Decl* declaration); - typedef DirectBaseListBuilder Builder; + List> const& getAssociatedDeclsForDecl(Decl* decl); - public: - DirectBaseInfo* _head = nullptr; - }; + bool isDifferentiableFunc(FunctionDeclBase* func); + bool isBackwardDifferentiableFunc(FunctionDeclBase* func); + FunctionDifferentiableLevel _getFuncDifferentiableLevelImpl( + FunctionDeclBase* func, + int recurseLimit); + FunctionDifferentiableLevel getFuncDifferentiableLevel(FunctionDeclBase* func); - struct DirectBaseListBuilder : DirectBaseList + struct InheritanceCircularityInfo + { + InheritanceCircularityInfo(Decl* decl, InheritanceCircularityInfo* next) + : decl(decl), next(next) { - public: - DirectBaseListBuilder() - { - _link = &_head; - } - - void add(DirectBaseInfo* base) - { - *_link = base; - _link = &base->next; - } - - private: - DirectBaseInfo** _link = nullptr; - }; + } - void _mergeFacetLists(DirectBaseList bases, FacetList baseFacets, FacetList::Builder& ioMergedFacets); + /// A declaration whose inheritance is being calculated + Decl* decl = nullptr; - struct TypePair - { - Type* type0; - Type* type1; - HashCode getHashCode() const { return combineHash(Slang::getHashCode(type0), Slang::getHashCode(type1)); } - bool operator == (const TypePair& other) const { return type0 == other.type0 && type1 == other.type1; } - }; - Dictionary m_mapTypeToInheritanceInfo; - Dictionary, InheritanceInfo> m_mapDeclRefToInheritanceInfo; - Dictionary m_mapTypePairToSubtypeWitness; - Dictionary m_mapTypePairToImplicitCastMethod; + /// The rest of the links in the chain of declarations being processed + InheritanceCircularityInfo* next = nullptr; }; - /// Local/scoped state of the semantic-checking system - /// - /// This type is kept distinct from `SharedSemanticsContext` so that we - /// can avoid unncessary mutable state being propagated through the - /// checking process. - /// - /// Semantic-checking code should make a new local `SemanticsContext` - /// in cases where it want to check a sub-entity (expression, statement, - /// declaration, etc.) in a modified or extended context. - /// - struct SemanticsContext - { - public: - friend struct OuterScopeContextRAII; - - explicit SemanticsContext( - SharedSemanticsContext* shared) - : m_shared(shared) - , m_sink(shared->getSink()) - , m_astBuilder(shared->getLinkage()->getASTBuilder()) - { - if (shared->getLinkage()->m_optionSet.hasOption(CompilerOptionName::DisableShortCircuit)) - { - m_shouldShortCircuitLogicExpr = - !shared->getLinkage()->m_optionSet.getBoolOption(CompilerOptionName::DisableShortCircuit); - } - } + /// Get the processed inheritance information for `type`, including all its facets + InheritanceInfo getInheritanceInfo( + Type* type, + InheritanceCircularityInfo* circularityInfo = nullptr); + + /// Get the processed inheritance information for `extension`, including all its facets + InheritanceInfo getInheritanceInfo( + DeclRef const& extension, + InheritanceCircularityInfo* circularityInfo = nullptr); + + /// Prevent an unsupported case of + /// ``` + /// extension : IBar{}; + /// extesnion : IFoo{}; + /// ``` + /// from causing infinite recursion. + bool _checkForCircularityInExtensionTargetType( + Decl* decl, + InheritanceCircularityInfo* circularityInfo); - SharedSemanticsContext* getShared() { return m_shared; } - CompilerOptionSet& getOptionSet() { return getShared()->getOptionSet(); } - ASTBuilder* getASTBuilder() { return m_astBuilder; } + /// Try get subtype witness from cache, returns true if cache contains a result for the query. + bool tryGetSubtypeWitnessFromCache(Type* sub, Type* sup, SubtypeWitness*& outWitness) + { + auto pair = TypePair{sub, sup}; + return m_mapTypePairToSubtypeWitness.tryGetValue(pair, outWitness); + } + void cacheSubtypeWitness(Type* sub, Type* sup, SubtypeWitness*& outWitness) + { + auto pair = TypePair{sub, sup}; + m_mapTypePairToSubtypeWitness[pair] = outWitness; + } + ImplicitCastMethod* tryGetImplicitCastMethod(ImplicitCastMethodKey key) + { + return m_mapTypePairToImplicitCastMethod.tryGetValue(key); + } + void cacheImplicitCastMethod(ImplicitCastMethodKey key, ImplicitCastMethod candidate) + { + m_mapTypePairToImplicitCastMethod[key] = candidate; + } - DiagnosticSink* getSink() { return m_sink; } + // Get the inner most generic decl that a decl-ref is dependent on. + // For example, `Foo` depends on the generic decl that defines `T`. + // + DeclRef getDependentGenericParent(DeclRef declRef); - Session* getSession() { return m_shared->getSession(); } +private: + /// Mapping from type declarations to the known extensiosn that apply to them + Dictionary> m_mapTypeDeclToCandidateExtensions; - Linkage* getLinkage() { return m_shared->m_linkage; } - NamePool* getNamePool() { return getLinkage()->getNamePool(); } - SourceManager* getSourceManager() { return getLinkage()->getSourceManager(); } + /// Is the `m_mapTypeDeclToCandidateExtensions` dictionary valid and up to date? + bool m_candidateExtensionListsBuilt = false; - SemanticsContext withSink(DiagnosticSink* sink) - { - SemanticsContext result(*this); - result.m_sink = sink; - return result; - } + /// Add candidate extensions declared in `moduleDecl` to `m_mapTypeDeclToCandidateExtensions` + void _addCandidateExtensionsFromModule(ModuleDecl* moduleDecl); - FunctionDeclBase* getParentFuncOfVisitor() { return m_parentFunc; } - void setParentFuncOfVisitor(FunctionDeclBase* funcDecl) { m_parentFunc = funcDecl; } + /// Mapping from a decl to additional declarations of the same decl. + /// The additional declarations provide a location to hold extra decorations. + OrderedDictionary> m_mapDeclToAssociatedDecls; - SemanticsContext withParentFunc(FunctionDeclBase* parentFunc) - { - SemanticsContext result(*this); - result.m_parentFunc = parentFunc; - result.m_outerStmts = nullptr; - result.m_parentDifferentiableAttr = parentFunc->findModifier(); - if (parentFunc->ownedScope) - result.m_outerScope = parentFunc->ownedScope; - return result; - } + /// Is the `m_mapDeclToAssociatedDecls` dictionary valid and up to date? + bool m_associatedDeclListsBuilt = false; - SemanticsContext withParentExpandExpr(ExpandExpr* expr, OrderedHashSet* capturedTypes) - { - SemanticsContext result(*this); - result.m_parentExpandExpr = expr; - result.m_capturedTypePacks = capturedTypes; - return result; - } + /// Add associated decls declared in `moduleDecl` to `m_mapDeclToAssociatedDecls` + void _addDeclAssociationsFromModule(ModuleDecl* moduleDecl); - /// Information for tracking one or more outer statements. - /// - /// During checking of statements, we need to track what - /// outer statements are in scope, so that we can resolve - /// the target for a `break` or `continue` statement (and - /// validate that such statements are only used in contexts - /// where such a target exists). - /// - /// We use a linked list of `OuterStmtInfo` threaded up - /// through the recursive call stack to track the statements - /// that are lexically surrounding the one we are checking. - /// - struct OuterStmtInfo - { - Stmt* stmt = nullptr; - OuterStmtInfo* next; - }; + ASTBuilder* _getASTBuilder() { return m_linkage->getASTBuilder(); } - OuterStmtInfo* getOuterStmts() { return m_outerStmts; } + InheritanceInfo _getInheritanceInfo( + DeclRef declRef, + DeclRefType* correspondingType, + InheritanceCircularityInfo* circularityInfo); + InheritanceInfo _calcInheritanceInfo(Type* type, InheritanceCircularityInfo* circularityInfo); + InheritanceInfo _calcInheritanceInfo( + DeclRef declRef, + DeclRefType* correspondingType, + InheritanceCircularityInfo* circularityInfo); - SemanticsContext withOuterStmts(OuterStmtInfo* outerStmts) - { - SemanticsContext result(*this); - result.m_outerStmts = outerStmts; - return result; - } + void getDependentGenericParentImpl(DeclRef& genericParent, DeclRef declRef); - // Setup the flag to indicate disabling the short-circuiting evaluation - // for the logical expressions associted with the subcontext - SemanticsContext disableShortCircuitLogicalExpr() - { - SemanticsContext result(*this); - result.m_shouldShortCircuitLogicExpr = false; - return result; - } + struct DirectBaseInfo + { + FacetList facets; - TryClauseType getEnclosingTryClauseType() { return m_enclosingTryClauseType; } + Facet::Impl facetImpl; - SemanticsContext withEnclosingTryClauseType(TryClauseType tryClauseType) - { - SemanticsContext result(*this); - result.m_enclosingTryClauseType = tryClauseType; - return result; - } + DirectBaseInfo* next = nullptr; + }; - DifferentiableAttribute* getParentDifferentiableAttribute() - { - return m_parentDifferentiableAttr; - } + struct DirectBaseListBuilder; - /// A scope that is local to a particular expression, and - /// that can be used to allocate temporary bindings that - /// might be needed by that expression or its sub-expressions. - /// - /// The scope is represented as a sequence of nested `LetExpr`s - /// that introduce the bindings needed in the scope. - /// - struct ExprLocalScope + struct DirectBaseList + { + public: + struct Iterator { public: - void addBinding(LetExpr* binding); + Iterator() {} + + Iterator(DirectBaseInfo* cursor) + : _cursor(cursor) + { + } + + bool operator!=(Iterator that) const { return _cursor != that._cursor; } + + void operator++() + { + SLANG_ASSERT(_cursor); + _cursor = _cursor->next; + } - LetExpr* getOuterMostBinding() const { return m_outerMostBinding; } + DirectBaseInfo* operator*() { return _cursor; } private: - LetExpr* m_outerMostBinding = nullptr; - LetExpr* m_innerMostBinding = nullptr; + DirectBaseInfo* _cursor = nullptr; }; - ExprLocalScope* getExprLocalScope() { return m_exprLocalScope; } - Scope* getOuterScope() { return m_outerScope; } + Iterator begin() const { return Iterator(_head); } + Iterator end() const { return Iterator(); } - SemanticsContext withExprLocalScope(ExprLocalScope* exprLocalScope) - { - SemanticsContext result(*this); - result.m_exprLocalScope = exprLocalScope; - return result; - } + bool isEmpty() const { return _head == nullptr; } + + bool doesAnyTailContainMatchFor(Facet facet) const; + + void removeEmptyLists(); + + typedef DirectBaseListBuilder Builder; + + public: + DirectBaseInfo* _head = nullptr; + }; + + struct DirectBaseListBuilder : DirectBaseList + { + public: + DirectBaseListBuilder() { _link = &_head; } - SemanticsContext withOuterScope(Scope* scope) + void add(DirectBaseInfo* base) { - SemanticsContext result(*this); - result.m_outerScope = scope; - return result; + *_link = base; + _link = &base->next; } - SemanticsContext withTreatAsDifferentiable(TreatAsDifferentiableExpr* expr) + private: + DirectBaseInfo** _link = nullptr; + }; + + void _mergeFacetLists( + DirectBaseList bases, + FacetList baseFacets, + FacetList::Builder& ioMergedFacets); + + struct TypePair + { + Type* type0; + Type* type1; + HashCode getHashCode() const { - SemanticsContext result(*this); - result.m_treatAsDifferentiableExpr = expr; - return result; + return combineHash(Slang::getHashCode(type0), Slang::getHashCode(type1)); } - - SemanticsContext allowStaticReferenceToNonStaticMember() + bool operator==(const TypePair& other) const { - SemanticsContext result(*this); - result.m_allowStaticReferenceToNonStaticMember = true; - return result; + return type0 == other.type0 && type1 == other.type1; } + }; + Dictionary m_mapTypeToInheritanceInfo; + Dictionary, InheritanceInfo> m_mapDeclRefToInheritanceInfo; + Dictionary m_mapTypePairToSubtypeWitness; + Dictionary m_mapTypePairToImplicitCastMethod; +}; + +/// Local/scoped state of the semantic-checking system +/// +/// This type is kept distinct from `SharedSemanticsContext` so that we +/// can avoid unncessary mutable state being propagated through the +/// checking process. +/// +/// Semantic-checking code should make a new local `SemanticsContext` +/// in cases where it want to check a sub-entity (expression, statement, +/// declaration, etc.) in a modified or extended context. +/// +struct SemanticsContext +{ +public: + friend struct OuterScopeContextRAII; - SemanticsContext withDeclToExcludeFromLookup(Decl* decl) + explicit SemanticsContext(SharedSemanticsContext* shared) + : m_shared(shared) + , m_sink(shared->getSink()) + , m_astBuilder(shared->getLinkage()->getASTBuilder()) + { + if (shared->getLinkage()->m_optionSet.hasOption(CompilerOptionName::DisableShortCircuit)) { - SemanticsContext result(*this); - result.m_declToExcludeFromLookup = decl; - return result; + m_shouldShortCircuitLogicExpr = !shared->getLinkage()->m_optionSet.getBoolOption( + CompilerOptionName::DisableShortCircuit); } + } - Decl* getDeclToExcludeFromLookup() { return m_declToExcludeFromLookup; } - - OrderedHashSet* getCapturedTypePacks() { return m_capturedTypePacks; } - - private: - SharedSemanticsContext* m_shared = nullptr; - - DiagnosticSink* m_sink = nullptr; + SharedSemanticsContext* getShared() { return m_shared; } + CompilerOptionSet& getOptionSet() { return getShared()->getOptionSet(); } + ASTBuilder* getASTBuilder() { return m_astBuilder; } - ExprLocalScope* m_exprLocalScope = nullptr; + DiagnosticSink* getSink() { return m_sink; } - Decl* m_declToExcludeFromLookup = nullptr; + Session* getSession() { return m_shared->getSession(); } - protected: - // TODO: consider making more of this state `private`... + Linkage* getLinkage() { return m_shared->m_linkage; } + NamePool* getNamePool() { return getLinkage()->getNamePool(); } + SourceManager* getSourceManager() { return getLinkage()->getSourceManager(); } - /// The parent function (if any) that surrounds the statement being checked. - FunctionDeclBase* m_parentFunc = nullptr; + SemanticsContext withSink(DiagnosticSink* sink) + { + SemanticsContext result(*this); + result.m_sink = sink; + return result; + } - DifferentiableAttribute* m_parentDifferentiableAttr = nullptr; + FunctionDeclBase* getParentFuncOfVisitor() { return m_parentFunc; } + void setParentFuncOfVisitor(FunctionDeclBase* funcDecl) { m_parentFunc = funcDecl; } - /// The linked list of lexically surrounding statements. - OuterStmtInfo* m_outerStmts = nullptr; + SemanticsContext withParentFunc(FunctionDeclBase* parentFunc) + { + SemanticsContext result(*this); + result.m_parentFunc = parentFunc; + result.m_outerStmts = nullptr; + result.m_parentDifferentiableAttr = parentFunc->findModifier(); + if (parentFunc->ownedScope) + result.m_outerScope = parentFunc->ownedScope; + return result; + } - /// The type of a try clause (if any) enclosing current expr. - TryClauseType m_enclosingTryClauseType = TryClauseType::None; + SemanticsContext withParentExpandExpr(ExpandExpr* expr, OrderedHashSet* capturedTypes) + { + SemanticsContext result(*this); + result.m_parentExpandExpr = expr; + result.m_capturedTypePacks = capturedTypes; + return result; + } - /// Whether an expr referencing to a non-static member in static style (e.g. `Type.member`) - /// is considered valid in the current context. - bool m_allowStaticReferenceToNonStaticMember = false; + /// Information for tracking one or more outer statements. + /// + /// During checking of statements, we need to track what + /// outer statements are in scope, so that we can resolve + /// the target for a `break` or `continue` statement (and + /// validate that such statements are only used in contexts + /// where such a target exists). + /// + /// We use a linked list of `OuterStmtInfo` threaded up + /// through the recursive call stack to track the statements + /// that are lexically surrounding the one we are checking. + /// + struct OuterStmtInfo + { + Stmt* stmt = nullptr; + OuterStmtInfo* next; + }; - /// Whether or not we are in a `no_diff` environment (and therefore should treat the call to - /// a non-differentiable function as differentiable and not issue a diagnostic). - TreatAsDifferentiableExpr* m_treatAsDifferentiableExpr = nullptr; + OuterStmtInfo* getOuterStmts() { return m_outerStmts; } - ASTBuilder* m_astBuilder = nullptr; + SemanticsContext withOuterStmts(OuterStmtInfo* outerStmts) + { + SemanticsContext result(*this); + result.m_outerStmts = outerStmts; + return result; + } - Scope* m_outerScope = nullptr; + // Setup the flag to indicate disabling the short-circuiting evaluation + // for the logical expressions associted with the subcontext + SemanticsContext disableShortCircuitLogicalExpr() + { + SemanticsContext result(*this); + result.m_shouldShortCircuitLogicExpr = false; + return result; + } - // By default, we will support short-circuit evaluation for the logic expression. - // However, there are few exceptions where we will disable it: - // 1. the logic expression is inside the generic parameter list. - // 2. the logic expression is in the init expression of a static const variable. - // 3. the logic expression is in an array size declaration. - bool m_shouldShortCircuitLogicExpr = true; + TryClauseType getEnclosingTryClauseType() { return m_enclosingTryClauseType; } - ExpandExpr* m_parentExpandExpr = nullptr; + SemanticsContext withEnclosingTryClauseType(TryClauseType tryClauseType) + { + SemanticsContext result(*this); + result.m_enclosingTryClauseType = tryClauseType; + return result; + } - OrderedHashSet* m_capturedTypePacks = nullptr; - }; + DifferentiableAttribute* getParentDifferentiableAttribute() + { + return m_parentDifferentiableAttr; + } - struct OuterScopeContextRAII + /// A scope that is local to a particular expression, and + /// that can be used to allocate temporary bindings that + /// might be needed by that expression or its sub-expressions. + /// + /// The scope is represented as a sequence of nested `LetExpr`s + /// that introduce the bindings needed in the scope. + /// + struct ExprLocalScope { - SemanticsContext* m_context; - Scope* m_oldOuterScope; + public: + void addBinding(LetExpr* binding); - OuterScopeContextRAII(SemanticsContext* context, Scope* outerScope) - : m_context(context) - , m_oldOuterScope(context->getOuterScope()) - { - context->m_outerScope = outerScope; - } + LetExpr* getOuterMostBinding() const { return m_outerMostBinding; } - ~OuterScopeContextRAII() - { - m_context->m_outerScope = m_oldOuterScope; - } + private: + LetExpr* m_outerMostBinding = nullptr; + LetExpr* m_innerMostBinding = nullptr; }; -#define SLANG_OUTER_SCOPE_CONTEXT_RAII(context, scope) OuterScopeContextRAII _outerScopeContextRAII(context, scope) -#define SLANG_OUTER_SCOPE_CONTEXT_DECL_RAII(context, decl) OuterScopeContextRAII _outerScopeContextRAII(context, decl->ownedScope?decl->ownedScope:context->getOuterScope()) + ExprLocalScope* getExprLocalScope() { return m_exprLocalScope; } + Scope* getOuterScope() { return m_outerScope; } - struct SemanticsVisitor : public SemanticsContext + SemanticsContext withExprLocalScope(ExprLocalScope* exprLocalScope) { - typedef SemanticsContext Super; + SemanticsContext result(*this); + result.m_exprLocalScope = exprLocalScope; + return result; + } - explicit SemanticsVisitor( - SharedSemanticsContext* shared) - : Super(shared) - {} + SemanticsContext withOuterScope(Scope* scope) + { + SemanticsContext result(*this); + result.m_outerScope = scope; + return result; + } - SemanticsVisitor( - SemanticsContext const& context) - : Super(context) - {} + SemanticsContext withTreatAsDifferentiable(TreatAsDifferentiableExpr* expr) + { + SemanticsContext result(*this); + result.m_treatAsDifferentiableExpr = expr; + return result; + } - CompilerOptionSet& getOptionSet() - { - return getShared()->getOptionSet(); - } - public: - // Translate Types - - - Expr* TranslateTypeNodeImpl(Expr* node); - Type* ExtractTypeFromTypeRepr(Expr* typeRepr); - Type* TranslateTypeNode(Expr* node); - TypeExp TranslateTypeNodeForced(TypeExp const& typeExp); - TypeExp TranslateTypeNode(TypeExp const& typeExp); - Type* getRemovedModifierType(ModifiedType* type, ModifierVal* modifier); - DeclRefType* getExprDeclRefType(Expr * expr); - - /// Is `decl` usable as a static member? - bool isDeclUsableAsStaticMember( - Decl* decl); - - /// Is `item` usable as a static member? - bool isUsableAsStaticMember( - LookupResultItem const& item); - - /// Move `expr` into a temporary variable and execute `func` on that variable. - /// - /// Returns an expression that wraps both the creation and initialization of - /// the temporary, and the computation created by `func`. - /// - template - Expr* moveTemp(Expr* const& expr, F const& func); - - /// Execute `func` on a variable with the value of `expr`. - /// - /// If `expr` is just a reference to an immutable (e.g., `let`) variable - /// then this might use the existing variable. Otherwise it will create - /// a new variable to hold `expr`, using `moveTemp()`. - /// - template - Expr* maybeMoveTemp(Expr* const& expr, F const& func); - - /// Return an expression that represents "opening" the existential `expr`. - /// - /// The type of `expr` must be an interface type, matching `interfaceDeclRef`. - /// - /// If we scope down the PL theory to just the case that Slang cares about, - /// a value of an existential type like `IMover` is a tuple of: - /// - /// * a concrete type `X` - /// * a witness `w` of the fact that `X` implements `IMover` - /// * a value `v` of type `X` - /// - /// "Opening" an existential value is the process of decomposing a single - /// value `e : IMover` into the pieces `X`, `w`, and `v`. - /// - /// Rather than return all those pieces individually, this operation - /// returns an expression that logically corresponds to `v`: an expression - /// of type `X`, where the type carries the knowledge that `X` implements `IMover`. - /// - Expr* openExistential( - Expr* expr, - DeclRef interfaceDeclRef); - - /// If `expr` has existential type, then open it. - /// - /// Returns an expression that opens `expr` if it had existential type. - /// Otherwise returns `expr` itself. - /// - /// See `openExistential` for a discussion of what "opening" an - /// existential-type value means. - /// - Expr* maybeOpenExistential(Expr* expr); - - /// If `expr` has Ref Type, convert it into an l-value expr that has T type. - Expr* maybeOpenRef(Expr* expr); - - Scope* getScope(SyntaxNode* node); - - void diagnoseDeprecatedDeclRefUsage(DeclRef declRef, SourceLoc loc, Expr* originalExpr); - - DeclRef getDefaultDeclRef(Decl* decl) - { - return createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(decl)); - } + SemanticsContext allowStaticReferenceToNonStaticMember() + { + SemanticsContext result(*this); + result.m_allowStaticReferenceToNonStaticMember = true; + return result; + } - DeclRef getSpecializedDeclRef(DeclRef declToSpecialize, DeclRef declRefWithSpecializationArgs) - { - return declRefWithSpecializationArgs.substitute(m_astBuilder, declToSpecialize); - } + SemanticsContext withDeclToExcludeFromLookup(Decl* decl) + { + SemanticsContext result(*this); + result.m_declToExcludeFromLookup = decl; + return result; + } - DeclRef getSpecializedDeclRef(Decl* declToSpecialize, DeclRef declRefWithSpecializationArgs) - { - return declRefWithSpecializationArgs.substitute(m_astBuilder, getDefaultDeclRef(declToSpecialize)); - } + Decl* getDeclToExcludeFromLookup() { return m_declToExcludeFromLookup; } - DeclRefExpr* ConstructDeclRefExpr( - DeclRef declRef, - Expr* baseExpr, - Name* name, - SourceLoc loc, - Expr* originalExpr); - - Expr* ConstructDerefExpr( - Expr* base, - SourceLoc loc); - - InvokeExpr* constructUncheckedInvokeExpr(Expr* callee, const List& arguments); - - Expr* maybeUseSynthesizedDeclForLookupResult( - LookupResultItem const& item, - Expr* orignalExpr); - - Expr* ConstructLookupResultExpr( - LookupResultItem const& item, - Expr* baseExpr, - Name* name, - SourceLoc loc, - Expr* originalExpr); - - Expr* createLookupResultExpr( - Name* name, - LookupResult const& lookupResult, - Expr* baseExpr, - SourceLoc loc, - Expr* originalExpr); - - DeclVisibility getTypeVisibility(Type* type); - bool isDeclVisibleFromScope(DeclRef declRef, Scope* scope); - LookupResult filterLookupResultByVisibility(const LookupResult& lookupResult); - LookupResult filterLookupResultByVisibilityAndDiagnose(const LookupResult& lookupResult, SourceLoc loc, bool& outDiagnosed); - - Val* resolveVal(Val* val) - { - if (!val) return nullptr; - return val->resolve(); - } - Type* resolveType(Type* type) - { - return (Type*)resolveVal(type); - } - DeclRef resolveDeclRef(DeclRef declRef); + OrderedHashSet* getCapturedTypePacks() { return m_capturedTypePacks; } - /// Attempt to "resolve" an overloaded `LookupResult` to only include the "best" results - LookupResult resolveOverloadedLookup(LookupResult const& lookupResult); +private: + SharedSemanticsContext* m_shared = nullptr; - /// Attempt to resolve `expr` into an expression that refers to a single declaration/value. - /// If `expr` isn't overloaded, then it will be returned as-is. - /// - /// The provided `mask` is used to filter items down to those that are applicable in a given context (e.g., just types). - /// - /// If the expression cannot be resolved to a single value then *if* `diagSink` is non-null an - /// appropriate "ambiguous reference" error will be reported, and an error expression will be returned. - /// Otherwise, the original expression is returned if resolution fails. - /// - Expr* maybeResolveOverloadedExpr(Expr* expr, LookupMask mask, DiagnosticSink* diagSink); + DiagnosticSink* m_sink = nullptr; - /// Attempt to resolve `overloadedExpr` into an expression that refers to a single declaration/value. - /// - /// Equivalent to `maybeResolveOverloadedExpr` with `diagSink` bound to the sink for the `SemanticsVisitor`. - Expr* resolveOverloadedExpr(OverloadedExpr* overloadedExpr, LookupMask mask); + ExprLocalScope* m_exprLocalScope = nullptr; - /// Worker reoutine for `maybeResolveOverloadedExpr` and `resolveOverloadedExpr`. - Expr* _resolveOverloadedExprImpl(OverloadedExpr* overloadedExpr, LookupMask mask, DiagnosticSink* diagSink); + Decl* m_declToExcludeFromLookup = nullptr; - void diagnoseAmbiguousReference(OverloadedExpr* overloadedExpr, LookupResult const& lookupResult); - void diagnoseAmbiguousReference(Expr* overloadedExpr); +protected: + // TODO: consider making more of this state `private`... + /// The parent function (if any) that surrounds the statement being checked. + FunctionDeclBase* m_parentFunc = nullptr; - Expr* ExpectATypeRepr(Expr* expr); + DifferentiableAttribute* m_parentDifferentiableAttr = nullptr; - Type* ExpectAType(Expr* expr); + /// The linked list of lexically surrounding statements. + OuterStmtInfo* m_outerStmts = nullptr; - Type* ExtractGenericArgType(Expr* exp); + /// The type of a try clause (if any) enclosing current expr. + TryClauseType m_enclosingTryClauseType = TryClauseType::None; - IntVal* ExtractGenericArgInteger(Expr* exp, Type* genericParamType, DiagnosticSink* sink); - IntVal* ExtractGenericArgInteger(Expr* exp, Type* genericParamType); + /// Whether an expr referencing to a non-static member in static style (e.g. `Type.member`) + /// is considered valid in the current context. + bool m_allowStaticReferenceToNonStaticMember = false; - Val* ExtractGenericArgVal(Expr* exp); + /// Whether or not we are in a `no_diff` environment (and therefore should treat the call to + /// a non-differentiable function as differentiable and not issue a diagnostic). + TreatAsDifferentiableExpr* m_treatAsDifferentiableExpr = nullptr; - // Construct a type representing the instantiation of - // the given generic declaration for the given arguments. - // The arguments should already be checked against - // the declaration. - Type* InstantiateGenericType( - DeclRef genericDeclRef, - List const& args); + ASTBuilder* m_astBuilder = nullptr; - // These routines are bottlenecks for semantic checking, - // so that we can add some quality-of-life features for users - // in cases where the compiler crashes - // - void dispatchStmt(Stmt* stmt, SemanticsContext const& context); - Expr* dispatchExpr(Expr* expr, SemanticsContext const& context); - - /// Ensure that a declaration has been checked up to some state - /// (aka, a phase of semantic checking) so that we can safely - /// perform certain operations on it. - /// - /// Calling `ensureDecl` may cause the type-checker to recursively - /// start checking `decl` on top of the stack that is already - /// doing other semantic checking. Care should be taken when relying - /// on this function to avoid blowing out the stack or (even worse - /// creating a circular dependency). - /// - void ensureDecl(Decl* decl, DeclCheckState state, SemanticsContext* baseContext = nullptr); - - /// Helper routine allowing `ensureDecl` to be called on a `DeclRef` - void ensureDecl(DeclRefBase* declRef, DeclCheckState state) - { - ensureDecl(declRef->getDecl(), state); - } + Scope* m_outerScope = nullptr; - void ensureAllDeclsRec( - Decl* decl, - DeclCheckState state); - - /// Helper routine allowing `ensureDecl` to be used on a `DeclBase` - /// - /// `DeclBase` is the base clas of `Decl` and `DeclGroup`. When - /// called on a `DeclGroup` this function just calls `ensureDecl()` - /// on each declaration in the group. - /// - void ensureDeclBase(DeclBase* decl, DeclCheckState state, SemanticsContext* baseContext); - - // A "proper" type is one that can be used as the type of an expression. - // Put simply, it can be a concrete type like `int`, or a generic - // type that is applied to arguments, like `Texture2D`. - // The type `void` is also a proper type, since we can have expressions - // that return a `void` result (e.g., many function calls). - // - // A "non-proper" type is any type that can't actually have values. - // A simple example of this in C++ is `std::vector` - you can't have - // a value of this type. - // - // Part of what this function does is give errors if somebody tries - // to use a non-proper type as the type of a variable (or anything - // else that needs a proper type). - // - // The other thing it handles is the fact that HLSL lets you use - // the name of a non-proper type, and then have the compiler fill - // in the default values for its type arguments (e.g., a variable - // given type `Texture2D` will actually have type `Texture2D`). - bool CoerceToProperTypeImpl( - TypeExp const& typeExp, - Type** outProperType, - DiagnosticSink* diagSink); - - TypeExp CoerceToProperType(TypeExp const& typeExp); - - TypeExp tryCoerceToProperType(TypeExp const& typeExp); - - // Check a type, and coerce it to be proper - TypeExp CheckProperType(TypeExp typeExp); - - // For our purposes, a "usable" type is one that can be - // used to declare a function parameter, variable, etc. - // These turn out to be all the proper types except - // `void`. - // - // TODO(tfoley): consider just allowing `void` as a - // simple example of a "unit" type, and get rid of - // this check. - TypeExp CoerceToUsableType(TypeExp const& typeExp, Decl* decl); + // By default, we will support short-circuit evaluation for the logic expression. + // However, there are few exceptions where we will disable it: + // 1. the logic expression is inside the generic parameter list. + // 2. the logic expression is in the init expression of a static const variable. + // 3. the logic expression is in an array size declaration. + bool m_shouldShortCircuitLogicExpr = true; - // Check a type, and coerce it to be usable - TypeExp CheckUsableType(TypeExp typeExp, Decl* decl); + ExpandExpr* m_parentExpandExpr = nullptr; - Expr* CheckTerm(Expr* term); + OrderedHashSet* m_capturedTypePacks = nullptr; +}; - Expr* _CheckTerm(Expr* term); +struct OuterScopeContextRAII +{ + SemanticsContext* m_context; + Scope* m_oldOuterScope; - Expr* CreateErrorExpr(Expr* expr); + OuterScopeContextRAII(SemanticsContext* context, Scope* outerScope) + : m_context(context), m_oldOuterScope(context->getOuterScope()) + { + context->m_outerScope = outerScope; + } - bool IsErrorExpr(Expr* expr); + ~OuterScopeContextRAII() { m_context->m_outerScope = m_oldOuterScope; } +}; - // Capture the "base" expression in case this is a member reference - Expr* GetBaseExpr(Expr* expr); +#define SLANG_OUTER_SCOPE_CONTEXT_RAII(context, scope) \ + OuterScopeContextRAII _outerScopeContextRAII(context, scope) +#define SLANG_OUTER_SCOPE_CONTEXT_DECL_RAII(context, decl) \ + OuterScopeContextRAII _outerScopeContextRAII( \ + context, \ + decl->ownedScope ? decl->ownedScope : context->getOuterScope()) - /// Validate a declaration to ensure that it doesn't introduce a circularly-defined constant - /// - /// Circular definition in a constant may lead to infinite looping or stack overflow in - /// the compiler, so it needs to be protected against. - /// - /// Note that this function does *not* protect against circular definitions in general, - /// and a program that indirectly initializes a global variable using its own value (e.g., - /// by calling a function that indirectly reads the variable) will be allowed and then - /// exhibit undefined behavior at runtime. - /// - void _validateCircularVarDefinition(VarDeclBase* varDecl); +struct SemanticsVisitor : public SemanticsContext +{ + typedef SemanticsContext Super; - bool shouldSkipChecking(Decl* decl, DeclCheckState state); + explicit SemanticsVisitor(SharedSemanticsContext* shared) + : Super(shared) + { + } - // Auto-diff convenience functions for translating primal types to differential types. - Type* _toDifferentialParamType(Type* primalType); + SemanticsVisitor(SemanticsContext const& context) + : Super(context) + { + } - Type* getDifferentialPairType(Type* primalType); + CompilerOptionSet& getOptionSet() { return getShared()->getOptionSet(); } + +public: + // Translate Types + + + Expr* TranslateTypeNodeImpl(Expr* node); + Type* ExtractTypeFromTypeRepr(Expr* typeRepr); + Type* TranslateTypeNode(Expr* node); + TypeExp TranslateTypeNodeForced(TypeExp const& typeExp); + TypeExp TranslateTypeNode(TypeExp const& typeExp); + Type* getRemovedModifierType(ModifiedType* type, ModifierVal* modifier); + DeclRefType* getExprDeclRefType(Expr* expr); + + /// Is `decl` usable as a static member? + bool isDeclUsableAsStaticMember(Decl* decl); + + /// Is `item` usable as a static member? + bool isUsableAsStaticMember(LookupResultItem const& item); + + /// Move `expr` into a temporary variable and execute `func` on that variable. + /// + /// Returns an expression that wraps both the creation and initialization of + /// the temporary, and the computation created by `func`. + /// + template + Expr* moveTemp(Expr* const& expr, F const& func); + + /// Execute `func` on a variable with the value of `expr`. + /// + /// If `expr` is just a reference to an immutable (e.g., `let`) variable + /// then this might use the existing variable. Otherwise it will create + /// a new variable to hold `expr`, using `moveTemp()`. + /// + template + Expr* maybeMoveTemp(Expr* const& expr, F const& func); + + /// Return an expression that represents "opening" the existential `expr`. + /// + /// The type of `expr` must be an interface type, matching `interfaceDeclRef`. + /// + /// If we scope down the PL theory to just the case that Slang cares about, + /// a value of an existential type like `IMover` is a tuple of: + /// + /// * a concrete type `X` + /// * a witness `w` of the fact that `X` implements `IMover` + /// * a value `v` of type `X` + /// + /// "Opening" an existential value is the process of decomposing a single + /// value `e : IMover` into the pieces `X`, `w`, and `v`. + /// + /// Rather than return all those pieces individually, this operation + /// returns an expression that logically corresponds to `v`: an expression + /// of type `X`, where the type carries the knowledge that `X` implements `IMover`. + /// + Expr* openExistential(Expr* expr, DeclRef interfaceDeclRef); + + /// If `expr` has existential type, then open it. + /// + /// Returns an expression that opens `expr` if it had existential type. + /// Otherwise returns `expr` itself. + /// + /// See `openExistential` for a discussion of what "opening" an + /// existential-type value means. + /// + Expr* maybeOpenExistential(Expr* expr); + + /// If `expr` has Ref Type, convert it into an l-value expr that has T type. + Expr* maybeOpenRef(Expr* expr); + + Scope* getScope(SyntaxNode* node); + + void diagnoseDeprecatedDeclRefUsage(DeclRef declRef, SourceLoc loc, Expr* originalExpr); + + DeclRef getDefaultDeclRef(Decl* decl) + { + return createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(decl)); + } - // Convert a function's original type to it's forward/backward diff'd type. - Type* getForwardDiffFuncType(FuncType* originalType); - Type* getBackwardDiffFuncType(FuncType* originalType); + DeclRef getSpecializedDeclRef( + DeclRef declToSpecialize, + DeclRef declRefWithSpecializationArgs) + { + return declRefWithSpecializationArgs.substitute(m_astBuilder, declToSpecialize); + } - /// Registers a type as conforming to IDifferentiable, along with a witness - /// describing the relationship. - /// - void addDifferentiableTypeToDiffTypeRegistry(DeclRefType* type, SubtypeWitness* witness); - void maybeRegisterDifferentiableTypeImplRecursive(ASTBuilder* builder, Type* type); + DeclRef getSpecializedDeclRef( + Decl* declToSpecialize, + DeclRef declRefWithSpecializationArgs) + { + return declRefWithSpecializationArgs.substitute( + m_astBuilder, + getDefaultDeclRef(declToSpecialize)); + } - // Construct the differential for 'type', if it exists. - Type* getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc); - Type* tryGetDifferentialType(ASTBuilder* builder, Type* type); + DeclRefExpr* ConstructDeclRefExpr( + DeclRef declRef, + Expr* baseExpr, + Name* name, + SourceLoc loc, + Expr* originalExpr); + + Expr* ConstructDerefExpr(Expr* base, SourceLoc loc); + + InvokeExpr* constructUncheckedInvokeExpr(Expr* callee, const List& arguments); + + Expr* maybeUseSynthesizedDeclForLookupResult(LookupResultItem const& item, Expr* orignalExpr); + + Expr* ConstructLookupResultExpr( + LookupResultItem const& item, + Expr* baseExpr, + Name* name, + SourceLoc loc, + Expr* originalExpr); + + Expr* createLookupResultExpr( + Name* name, + LookupResult const& lookupResult, + Expr* baseExpr, + SourceLoc loc, + Expr* originalExpr); + + DeclVisibility getTypeVisibility(Type* type); + bool isDeclVisibleFromScope(DeclRef declRef, Scope* scope); + LookupResult filterLookupResultByVisibility(const LookupResult& lookupResult); + LookupResult filterLookupResultByVisibilityAndDiagnose( + const LookupResult& lookupResult, + SourceLoc loc, + bool& outDiagnosed); + + Val* resolveVal(Val* val) + { + if (!val) + return nullptr; + return val->resolve(); + } + Type* resolveType(Type* type) { return (Type*)resolveVal(type); } + DeclRef resolveDeclRef(DeclRef declRef); + + /// Attempt to "resolve" an overloaded `LookupResult` to only include the "best" results + LookupResult resolveOverloadedLookup(LookupResult const& lookupResult); + + /// Attempt to resolve `expr` into an expression that refers to a single declaration/value. + /// If `expr` isn't overloaded, then it will be returned as-is. + /// + /// The provided `mask` is used to filter items down to those that are applicable in a given + /// context (e.g., just types). + /// + /// If the expression cannot be resolved to a single value then *if* `diagSink` is non-null an + /// appropriate "ambiguous reference" error will be reported, and an error expression will be + /// returned. Otherwise, the original expression is returned if resolution fails. + /// + Expr* maybeResolveOverloadedExpr(Expr* expr, LookupMask mask, DiagnosticSink* diagSink); + + /// Attempt to resolve `overloadedExpr` into an expression that refers to a single + /// declaration/value. + /// + /// Equivalent to `maybeResolveOverloadedExpr` with `diagSink` bound to the sink for the + /// `SemanticsVisitor`. + Expr* resolveOverloadedExpr(OverloadedExpr* overloadedExpr, LookupMask mask); + + /// Worker reoutine for `maybeResolveOverloadedExpr` and `resolveOverloadedExpr`. + Expr* _resolveOverloadedExprImpl( + OverloadedExpr* overloadedExpr, + LookupMask mask, + DiagnosticSink* diagSink); + + void diagnoseAmbiguousReference( + OverloadedExpr* overloadedExpr, + LookupResult const& lookupResult); + void diagnoseAmbiguousReference(Expr* overloadedExpr); + + + Expr* ExpectATypeRepr(Expr* expr); + + Type* ExpectAType(Expr* expr); + + Type* ExtractGenericArgType(Expr* exp); + + IntVal* ExtractGenericArgInteger(Expr* exp, Type* genericParamType, DiagnosticSink* sink); + IntVal* ExtractGenericArgInteger(Expr* exp, Type* genericParamType); + + Val* ExtractGenericArgVal(Expr* exp); + + // Construct a type representing the instantiation of + // the given generic declaration for the given arguments. + // The arguments should already be checked against + // the declaration. + Type* InstantiateGenericType(DeclRef genericDeclRef, List const& args); + + // These routines are bottlenecks for semantic checking, + // so that we can add some quality-of-life features for users + // in cases where the compiler crashes + // + void dispatchStmt(Stmt* stmt, SemanticsContext const& context); + Expr* dispatchExpr(Expr* expr, SemanticsContext const& context); + + /// Ensure that a declaration has been checked up to some state + /// (aka, a phase of semantic checking) so that we can safely + /// perform certain operations on it. + /// + /// Calling `ensureDecl` may cause the type-checker to recursively + /// start checking `decl` on top of the stack that is already + /// doing other semantic checking. Care should be taken when relying + /// on this function to avoid blowing out the stack or (even worse + /// creating a circular dependency). + /// + void ensureDecl(Decl* decl, DeclCheckState state, SemanticsContext* baseContext = nullptr); + + /// Helper routine allowing `ensureDecl` to be called on a `DeclRef` + void ensureDecl(DeclRefBase* declRef, DeclCheckState state) + { + ensureDecl(declRef->getDecl(), state); + } - // Helper function to check if a struct can be used as its own differential type. - bool canStructBeUsedAsSelfDifferentialType(AggTypeDecl *aggTypeDecl); - void markSelfDifferentialMembersOfType(AggTypeDecl *parent, Type* type); + void ensureAllDeclsRec(Decl* decl, DeclCheckState state); + + /// Helper routine allowing `ensureDecl` to be used on a `DeclBase` + /// + /// `DeclBase` is the base clas of `Decl` and `DeclGroup`. When + /// called on a `DeclGroup` this function just calls `ensureDecl()` + /// on each declaration in the group. + /// + void ensureDeclBase(DeclBase* decl, DeclCheckState state, SemanticsContext* baseContext); + + // A "proper" type is one that can be used as the type of an expression. + // Put simply, it can be a concrete type like `int`, or a generic + // type that is applied to arguments, like `Texture2D`. + // The type `void` is also a proper type, since we can have expressions + // that return a `void` result (e.g., many function calls). + // + // A "non-proper" type is any type that can't actually have values. + // A simple example of this in C++ is `std::vector` - you can't have + // a value of this type. + // + // Part of what this function does is give errors if somebody tries + // to use a non-proper type as the type of a variable (or anything + // else that needs a proper type). + // + // The other thing it handles is the fact that HLSL lets you use + // the name of a non-proper type, and then have the compiler fill + // in the default values for its type arguments (e.g., a variable + // given type `Texture2D` will actually have type `Texture2D`). + bool CoerceToProperTypeImpl( + TypeExp const& typeExp, + Type** outProperType, + DiagnosticSink* diagSink); + + TypeExp CoerceToProperType(TypeExp const& typeExp); + + TypeExp tryCoerceToProperType(TypeExp const& typeExp); + + // Check a type, and coerce it to be proper + TypeExp CheckProperType(TypeExp typeExp); + + // For our purposes, a "usable" type is one that can be + // used to declare a function parameter, variable, etc. + // These turn out to be all the proper types except + // `void`. + // + // TODO(tfoley): consider just allowing `void` as a + // simple example of a "unit" type, and get rid of + // this check. + TypeExp CoerceToUsableType(TypeExp const& typeExp, Decl* decl); + + // Check a type, and coerce it to be usable + TypeExp CheckUsableType(TypeExp typeExp, Decl* decl); + + Expr* CheckTerm(Expr* term); + + Expr* _CheckTerm(Expr* term); + + Expr* CreateErrorExpr(Expr* expr); + + bool IsErrorExpr(Expr* expr); + + // Capture the "base" expression in case this is a member reference + Expr* GetBaseExpr(Expr* expr); + + /// Validate a declaration to ensure that it doesn't introduce a circularly-defined constant + /// + /// Circular definition in a constant may lead to infinite looping or stack overflow in + /// the compiler, so it needs to be protected against. + /// + /// Note that this function does *not* protect against circular definitions in general, + /// and a program that indirectly initializes a global variable using its own value (e.g., + /// by calling a function that indirectly reads the variable) will be allowed and then + /// exhibit undefined behavior at runtime. + /// + void _validateCircularVarDefinition(VarDeclBase* varDecl); + + bool shouldSkipChecking(Decl* decl, DeclCheckState state); + + // Auto-diff convenience functions for translating primal types to differential types. + Type* _toDifferentialParamType(Type* primalType); + + Type* getDifferentialPairType(Type* primalType); + + // Convert a function's original type to it's forward/backward diff'd type. + Type* getForwardDiffFuncType(FuncType* originalType); + Type* getBackwardDiffFuncType(FuncType* originalType); + + /// Registers a type as conforming to IDifferentiable, along with a witness + /// describing the relationship. + /// + void addDifferentiableTypeToDiffTypeRegistry(DeclRefType* type, SubtypeWitness* witness); + void maybeRegisterDifferentiableTypeImplRecursive(ASTBuilder* builder, Type* type); + + // Construct the differential for 'type', if it exists. + Type* getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc); + Type* tryGetDifferentialType(ASTBuilder* builder, Type* type); + + // Helper function to check if a struct can be used as its own differential type. + bool canStructBeUsedAsSelfDifferentialType(AggTypeDecl* aggTypeDecl); + void markSelfDifferentialMembersOfType(AggTypeDecl* parent, Type* type); + + void checkDerivativeMemberAttributeReferences( + VarDeclBase* varDecl, + DerivativeMemberAttribute* derivativeMemberAttr); + +public: + bool ValuesAreEqual(IntVal* left, IntVal* right); + + // Compute the cost of using a particular declaration to + // perform implicit type conversion. + ConversionCost getImplicitConversionCost(Decl* decl); + + ConversionCost getImplicitConversionCostWithKnownArg(Decl* decl, Type* toType, Expr* arg); + + + BuiltinConversionKind getImplicitConversionBuiltinKind(Decl* decl); + + bool isEffectivelyScalarForInitializerLists(Type* type); + + /// Should the provided expression (from an initializer list) be used directly to initialize + /// `toType`? + bool shouldUseInitializerDirectly(Type* toType, Expr* fromExpr); + + /// Read a value from an initializer list expression. + /// + /// This reads one or more argument from the initializer list + /// given as `fromInitializerListExpr` to initialize a value + /// of type `toType`. This may involve reading one or + /// more arguments from the initializer list, depending + /// on whether `toType` is an aggregate or not, and on + /// whether the next argument in the initializer list is + /// itself an initializer list. + /// + /// This routine returns `true` if it was able to read + /// arguments that can form a value of type `toType`, + /// and `false` otherwise. + /// + /// If the routine succeeds and `outToExpr` is non-null, + /// then it will be filled in with an expression + /// representing the value (or type `toType`) that was read, + /// or it will be left null to indicate that a default + /// value should be used. + /// + /// If the routine fails and `outToExpr` is non-null, + /// then a suitable diagnostic will be emitted. + /// + bool _readValueFromInitializerList( + Type* toType, + Expr** outToExpr, + InitializerListExpr* fromInitializerListExpr, + UInt& ioInitArgIndex); + + /// Read an aggregate value from an initializer list expression. + /// + /// This reads one or more arguments from the initializer list + /// given as `fromInitializerListExpr` to initialize the + /// fields/elements of a value of type `toType`. + /// + /// This routine returns `true` if it was able to read + /// arguments that can form a value of type `toType`, + /// and `false` otherwise. + /// + /// If the routine succeeds and `outToExpr` is non-null, + /// then it will be filled in with an expression + /// representing the value (or type `toType`) that was read, + /// or it will be left null to indicate that a default + /// value should be used. + /// + /// If the routine fails and `outToExpr` is non-null, + /// then a suitable diagnostic will be emitted. + /// + bool _readAggregateValueFromInitializerList( + Type* inToType, + Expr** outToExpr, + InitializerListExpr* fromInitializerListExpr, + UInt& ioArgIndex); + + /// Coerce an initializer-list expression to a specific type. + /// + /// This reads one or more arguments from the initializer list + /// given as `fromInitializerListExpr` to initialize the + /// fields/elements of a value of type `toType`. + /// + /// This routine returns `true` if it was able to read + /// arguments that can form a value of type `toType`, + /// with no arguments left over, and `false` otherwise. + /// + /// If the routine succeeds and `outToExpr` is non-null, + /// then it will be filled in with an expression + /// representing the value (or type `toType`) that was read, + /// or it will be left null to indicate that a default + /// value should be used. + /// + /// If the routine fails and `outToExpr` is non-null, + /// then a suitable diagnostic will be emitted. + /// + bool _coerceInitializerList( + Type* toType, + Expr** outToExpr, + InitializerListExpr* fromInitializerListExpr); + + /// Report that implicit type coercion is not possible. + bool _failedCoercion(Type* toType, Expr** outToExpr, Expr* fromExpr); + + /// Central engine for implementing implicit coercion logic + /// + /// This function tries to find an implicit conversion path from + /// `fromType` to `toType`. It returns `true` if a conversion + /// is found, and `false` if not. + /// + /// If a conversion is found, then its cost will be written to `outCost`. + /// + /// If a `fromExpr` is provided, it must be of type `fromType`, + /// and represent a value to be converted. + /// + /// If `outToExpr` is non-null, and if a conversion is found, then + /// `*outToExpr` will be set to an expression that performs the + /// implicit conversion of `fromExpr` (which must be non-null + /// to `toType`). + /// + /// The case where `outToExpr` is non-null is used to identify + /// when a conversion is being done "for real" so that diagnostics + /// should be emitted on failure. + /// + bool _coerce( + CoercionSite site, + Type* toType, + Expr** outToExpr, + QualType fromType, + Expr* fromExpr, + ConversionCost* outCost); + + /// Check whether implicit type coercion from `fromType` to `toType` is possible. + /// + /// If conversion is possible, returns `true` and sets `outCost` to the cost + /// of the conversion found (if `outCost` is non-null). + /// + /// If conversion is not possible, returns `false`. + /// + bool canCoerce(Type* toType, QualType fromType, Expr* fromExpr, ConversionCost* outCost = 0); + + TypeCastExpr* createImplicitCastExpr(); + + Expr* CreateImplicitCastExpr(Type* toType, Expr* fromExpr); + + /// Create an "up-cast" from a value to an interface type + /// + /// This operation logically constructs an "existential" value, + /// which packages up the value, its type, and the witness + /// of its conformance to the interface. + /// + Expr* createCastToInterfaceExpr(Type* toType, Expr* fromExpr, Val* witness); + + /// Implicitly coerce `fromExpr` to `toType` and diagnose errors if it isn't possible + Expr* coerce(CoercionSite site, Type* toType, Expr* fromExpr); + + // Fill in default substitutions for the 'subtype' part of a type constraint decl + void CheckConstraintSubType(TypeExp& typeExp); + + void checkGenericDeclHeader(GenericDecl* genericDecl); + + IntVal* checkLinkTimeConstantIntVal(Expr* expr); + + ConstantIntVal* checkConstantIntVal(Expr* expr); + + ConstantIntVal* checkConstantEnumVal(Expr* expr); + + // Check an expression, coerce it to the `String` type, and then + // ensure that it has a literal (not just compile-time constant) value. + bool checkLiteralStringVal(Expr* expr, String* outVal); + + bool checkCapabilityName(Expr* expr, CapabilityName& outCapabilityName); + + void visitModifier(Modifier*); + + AttributeDecl* lookUpAttributeDecl(Name* attributeName, Scope* scope); + + bool hasIntArgs(Attribute* attr, int numArgs); + bool hasStringArgs(Attribute* attr, int numArgs); + + bool getAttributeTargetSyntaxClasses(SyntaxClass& cls, uint32_t typeFlags); + + // Check an attribute, and return a checked modifier that represents it. + // + Modifier* validateAttribute( + Attribute* attr, + AttributeDecl* attribClassDecl, + ModifiableSyntaxNode* attrTarget); + + AttributeBase* checkAttribute( + UncheckedAttribute* uncheckedAttr, + ModifiableSyntaxNode* attrTarget); + + Modifier* checkModifier( + Modifier* m, + ModifiableSyntaxNode* syntaxNode, + bool ignoreUnallowedModifier); + + void checkModifiers(ModifiableSyntaxNode* syntaxNode); + void checkVisibility(Decl* decl); + + bool doesSignatureMatchRequirement( + DeclRef satisfyingMemberDeclRef, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable); + + bool doesAccessorMatchRequirement( + DeclRef satisfyingMemberDeclRef, + DeclRef requiredMemberDeclRef); + + bool doesPropertyMatchRequirement( + DeclRef satisfyingMemberDeclRef, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable); + + bool doesSubscriptMatchRequirement( + DeclRef satisfyingMemberDeclRef, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable); + + bool doesVarMatchRequirement( + DeclRef satisfyingMemberDeclRef, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable); + + bool doesGenericSignatureMatchRequirement( + DeclRef genDecl, + DeclRef requirementGenDecl, + RefPtr witnessTable); + + bool doesTypeSatisfyAssociatedTypeConstraintRequirement( + Type* satisfyingType, + DeclRef requiredAssociatedTypeDeclRef, + RefPtr witnessTable); + + bool doesTypeSatisfyAssociatedTypeRequirement( + Type* satisfyingType, + DeclRef requiredAssociatedTypeDeclRef, + RefPtr witnessTable); - void checkDerivativeMemberAttributeReferences( - VarDeclBase* varDecl, DerivativeMemberAttribute* derivativeMemberAttr); - - public: + // Does the given `memberDecl` work as an implementation + // to satisfy the requirement `requiredMemberDeclRef` + // from an interface? + // + // If it does, then inserts a witness into `witnessTable` + // and returns `true`, otherwise returns `false` + bool doesMemberSatisfyRequirement( + DeclRef memberDeclRef, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable); - bool ValuesAreEqual( - IntVal* left, - IntVal* right); - - // Compute the cost of using a particular declaration to - // perform implicit type conversion. - ConversionCost getImplicitConversionCost( - Decl* decl); - - ConversionCost getImplicitConversionCostWithKnownArg(Decl* decl, Type* toType, Expr* arg); - - - BuiltinConversionKind getImplicitConversionBuiltinKind( - Decl* decl); - - bool isEffectivelyScalarForInitializerLists( - Type* type); - - /// Should the provided expression (from an initializer list) be used directly to initialize `toType`? - bool shouldUseInitializerDirectly( - Type* toType, - Expr* fromExpr); - - /// Read a value from an initializer list expression. - /// - /// This reads one or more argument from the initializer list - /// given as `fromInitializerListExpr` to initialize a value - /// of type `toType`. This may involve reading one or - /// more arguments from the initializer list, depending - /// on whether `toType` is an aggregate or not, and on - /// whether the next argument in the initializer list is - /// itself an initializer list. - /// - /// This routine returns `true` if it was able to read - /// arguments that can form a value of type `toType`, - /// and `false` otherwise. - /// - /// If the routine succeeds and `outToExpr` is non-null, - /// then it will be filled in with an expression - /// representing the value (or type `toType`) that was read, - /// or it will be left null to indicate that a default - /// value should be used. - /// - /// If the routine fails and `outToExpr` is non-null, - /// then a suitable diagnostic will be emitted. - /// - bool _readValueFromInitializerList( - Type* toType, - Expr** outToExpr, - InitializerListExpr* fromInitializerListExpr, - UInt &ioInitArgIndex); - - /// Read an aggregate value from an initializer list expression. - /// - /// This reads one or more arguments from the initializer list - /// given as `fromInitializerListExpr` to initialize the - /// fields/elements of a value of type `toType`. - /// - /// This routine returns `true` if it was able to read - /// arguments that can form a value of type `toType`, - /// and `false` otherwise. - /// - /// If the routine succeeds and `outToExpr` is non-null, - /// then it will be filled in with an expression - /// representing the value (or type `toType`) that was read, - /// or it will be left null to indicate that a default - /// value should be used. - /// - /// If the routine fails and `outToExpr` is non-null, - /// then a suitable diagnostic will be emitted. - /// - bool _readAggregateValueFromInitializerList( - Type* inToType, - Expr** outToExpr, - InitializerListExpr* fromInitializerListExpr, - UInt &ioArgIndex); - - /// Coerce an initializer-list expression to a specific type. - /// - /// This reads one or more arguments from the initializer list - /// given as `fromInitializerListExpr` to initialize the - /// fields/elements of a value of type `toType`. - /// - /// This routine returns `true` if it was able to read - /// arguments that can form a value of type `toType`, - /// with no arguments left over, and `false` otherwise. - /// - /// If the routine succeeds and `outToExpr` is non-null, - /// then it will be filled in with an expression - /// representing the value (or type `toType`) that was read, - /// or it will be left null to indicate that a default - /// value should be used. - /// - /// If the routine fails and `outToExpr` is non-null, - /// then a suitable diagnostic will be emitted. - /// - bool _coerceInitializerList( - Type* toType, - Expr** outToExpr, - InitializerListExpr* fromInitializerListExpr); - - /// Report that implicit type coercion is not possible. - bool _failedCoercion( - Type* toType, - Expr** outToExpr, - Expr* fromExpr); - - /// Central engine for implementing implicit coercion logic - /// - /// This function tries to find an implicit conversion path from - /// `fromType` to `toType`. It returns `true` if a conversion - /// is found, and `false` if not. - /// - /// If a conversion is found, then its cost will be written to `outCost`. - /// - /// If a `fromExpr` is provided, it must be of type `fromType`, - /// and represent a value to be converted. - /// - /// If `outToExpr` is non-null, and if a conversion is found, then - /// `*outToExpr` will be set to an expression that performs the - /// implicit conversion of `fromExpr` (which must be non-null - /// to `toType`). - /// - /// The case where `outToExpr` is non-null is used to identify - /// when a conversion is being done "for real" so that diagnostics - /// should be emitted on failure. - /// - bool _coerce( - CoercionSite site, - Type* toType, - Expr** outToExpr, - QualType fromType, - Expr* fromExpr, - ConversionCost* outCost); - - /// Check whether implicit type coercion from `fromType` to `toType` is possible. - /// - /// If conversion is possible, returns `true` and sets `outCost` to the cost - /// of the conversion found (if `outCost` is non-null). - /// - /// If conversion is not possible, returns `false`. - /// - bool canCoerce( - Type* toType, - QualType fromType, - Expr* fromExpr, - ConversionCost* outCost = 0); - - TypeCastExpr* createImplicitCastExpr(); - - Expr* CreateImplicitCastExpr( - Type* toType, - Expr* fromExpr); - - /// Create an "up-cast" from a value to an interface type - /// - /// This operation logically constructs an "existential" value, - /// which packages up the value, its type, and the witness - /// of its conformance to the interface. - /// - Expr* createCastToInterfaceExpr( - Type* toType, - Expr* fromExpr, - Val* witness); - - /// Implicitly coerce `fromExpr` to `toType` and diagnose errors if it isn't possible - Expr* coerce( - CoercionSite site, - Type* toType, - Expr* fromExpr); - - // Fill in default substitutions for the 'subtype' part of a type constraint decl - void CheckConstraintSubType(TypeExp& typeExp); - - void checkGenericDeclHeader(GenericDecl* genericDecl); - - IntVal* checkLinkTimeConstantIntVal( - Expr* expr); - - ConstantIntVal* checkConstantIntVal( - Expr* expr); - - ConstantIntVal* checkConstantEnumVal( - Expr* expr); - - // Check an expression, coerce it to the `String` type, and then - // ensure that it has a literal (not just compile-time constant) value. - bool checkLiteralStringVal( - Expr* expr, - String* outVal); - - bool checkCapabilityName(Expr* expr, CapabilityName& outCapabilityName); - - void visitModifier(Modifier*); - - AttributeDecl* lookUpAttributeDecl(Name* attributeName, Scope* scope); - - bool hasIntArgs(Attribute* attr, int numArgs); - bool hasStringArgs(Attribute* attr, int numArgs); - - bool getAttributeTargetSyntaxClasses(SyntaxClass & cls, uint32_t typeFlags); - - // Check an attribute, and return a checked modifier that represents it. - // - Modifier* validateAttribute(Attribute* attr, AttributeDecl* attribClassDecl, ModifiableSyntaxNode* attrTarget); - - AttributeBase* checkAttribute( - UncheckedAttribute* uncheckedAttr, - ModifiableSyntaxNode* attrTarget); - - Modifier* checkModifier( - Modifier* m, - ModifiableSyntaxNode* syntaxNode, - bool ignoreUnallowedModifier); - - void checkModifiers(ModifiableSyntaxNode* syntaxNode); - void checkVisibility(Decl* decl); - - bool doesSignatureMatchRequirement( - DeclRef satisfyingMemberDeclRef, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable); - - bool doesAccessorMatchRequirement( - DeclRef satisfyingMemberDeclRef, - DeclRef requiredMemberDeclRef); - - bool doesPropertyMatchRequirement( - DeclRef satisfyingMemberDeclRef, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable); - - bool doesSubscriptMatchRequirement( - DeclRef satisfyingMemberDeclRef, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable); - - bool doesVarMatchRequirement( - DeclRef satisfyingMemberDeclRef, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable); - - bool doesGenericSignatureMatchRequirement( - DeclRef genDecl, - DeclRef requirementGenDecl, - RefPtr witnessTable); - - bool doesTypeSatisfyAssociatedTypeConstraintRequirement( - Type* satisfyingType, - DeclRef requiredAssociatedTypeDeclRef, - RefPtr witnessTable); - - bool doesTypeSatisfyAssociatedTypeRequirement( - Type* satisfyingType, - DeclRef requiredAssociatedTypeDeclRef, - RefPtr witnessTable); - - // Does the given `memberDecl` work as an implementation - // to satisfy the requirement `requiredMemberDeclRef` - // from an interface? - // - // If it does, then inserts a witness into `witnessTable` - // and returns `true`, otherwise returns `false` - bool doesMemberSatisfyRequirement( - DeclRef memberDeclRef, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable); - - // State used while checking if a declaration (either a type declaration - // or an extension of that type) conforms to the interfaces it claims - // via its inheritance clauses. - // - struct ConformanceCheckingContext - { - /// The type for which conformances are being checked - Type* conformingType; + // State used while checking if a declaration (either a type declaration + // or an extension of that type) conforms to the interfaces it claims + // via its inheritance clauses. + // + struct ConformanceCheckingContext + { + /// The type for which conformances are being checked + Type* conformingType; - /// The outer declaration for the conformances being checked (either a type or `extension` declaration) - ContainerDecl* parentDecl; + /// The outer declaration for the conformances being checked (either a type or `extension` + /// declaration) + ContainerDecl* parentDecl; - Dictionary, RefPtr> mapInterfaceToWitnessTable; - }; + Dictionary, RefPtr> mapInterfaceToWitnessTable; + }; - void addModifiersToSynthesizedDecl( - ConformanceCheckingContext* context, - DeclRef requirement, - CallableDecl* synthesized, - ThisExpr* &synThis); - - void addRequiredParamsToSynthesizedDecl( - DeclRef requirement, - CallableDecl* synthesized, - List& synArgs); - - CallableDecl* synthesizeMethodSignatureForRequirementWitnessInner( - ConformanceCheckingContext* context, - DeclRef requiredMemberDeclRef, - List& synArgs, - ThisExpr*& synThis); - - CallableDecl* synthesizeMethodSignatureForRequirementWitness( - ConformanceCheckingContext* context, - DeclRef requiredMemberDeclRef, - List& synArgs, - ThisExpr*& synThis); - - GenericDecl* synthesizeGenericSignatureForRequirementWitness( - ConformanceCheckingContext* context, - DeclRef requiredMemberDeclRef, - List& synArgs, - List& synGenericArgs, - ThisExpr*& synThis); - - bool synthesizeAccessorRequirements( - ConformanceCheckingContext* context, - DeclRef requiredMemberDeclRef, - Type* resultType, - Expr* synBoundStorageExpr, - ContainerDecl* synAccesorContainer, - RefPtr witnessTable); - - void _addMethodWitness( - WitnessTable* witnessTable, - DeclRef requirement, - DeclRef method); - - /// Attempt to synthesize a method that can satisfy `requiredMemberDeclRef` using `lookupResult`. - /// - /// On success, installs the syntethesized method in `witnessTable` and returns `true`. - /// Otherwise, returns `false`. - bool trySynthesizeMethodRequirementWitness( - ConformanceCheckingContext* context, - LookupResult const& lookupResult, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable); - - bool trySynthesizeConstructorRequirementWitness( - ConformanceCheckingContext* context, - LookupResult const& lookupResult, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable); - - /// Attempt to synthesize a property that can satisfy `requiredMemberDeclRef` using `lookupResult`. - /// - /// On success, installs the syntethesized method in `witnessTable` and returns `true`. - /// Otherwise, returns `false`. - /// - bool trySynthesizePropertyRequirementWitness( - ConformanceCheckingContext* context, - LookupResult const& lookupResult, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable); - - bool trySynthesizeWrapperTypePropertyRequirementWitness( - ConformanceCheckingContext* context, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable); - - bool trySynthesizeSubscriptRequirementWitness( - ConformanceCheckingContext* context, - const LookupResult& lookupResult, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable); - - bool trySynthesizeWrapperTypeSubscriptRequirementWitness( - ConformanceCheckingContext* context, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable); - - bool trySynthesizeAssociatedTypeRequirementWitness( - ConformanceCheckingContext* context, - LookupResult const& lookupResult, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable); - - bool trySynthesizeAssociatedConstantRequirementWitness( - ConformanceCheckingContext* context, - LookupResult const& lookupResult, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable); - - /// Attempt to synthesize a declartion that can satisfy `requiredMemberDeclRef` using `lookupResult`. - /// - /// On success, installs the syntethesized declaration in `witnessTable` and returns `true`. - /// Otherwise, returns `false`. - bool trySynthesizeRequirementWitness( - ConformanceCheckingContext* context, - LookupResult const& lookupResult, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable); - - - enum SynthesisPattern - { - // Synthesized method inducts over all arguments. - // T fn(T x, T y, T z, ...) - // { - // typeof(T::member0)::fn(x.member0, y.member0, z.member0, ...); - // typeof(T::member1)::fn(x.member1, y.member1, z.member1, ...); - // ... - // } - // - AllInductive, - - // Synthesized method inducts over all arguments except the first. - // T fn(U x, T y, T z) - // { - // typeof(T::member0)::fn(x, y.member0, z.member0, ...); - // typeof(T::member1)::fn(x, y.member1, z.member1, ...); - // ... - // } - FixedFirstArg - }; + void addModifiersToSynthesizedDecl( + ConformanceCheckingContext* context, + DeclRef requirement, + CallableDecl* synthesized, + ThisExpr*& synThis); + + void addRequiredParamsToSynthesizedDecl( + DeclRef requirement, + CallableDecl* synthesized, + List& synArgs); + + CallableDecl* synthesizeMethodSignatureForRequirementWitnessInner( + ConformanceCheckingContext* context, + DeclRef requiredMemberDeclRef, + List& synArgs, + ThisExpr*& synThis); + + CallableDecl* synthesizeMethodSignatureForRequirementWitness( + ConformanceCheckingContext* context, + DeclRef requiredMemberDeclRef, + List& synArgs, + ThisExpr*& synThis); + + GenericDecl* synthesizeGenericSignatureForRequirementWitness( + ConformanceCheckingContext* context, + DeclRef requiredMemberDeclRef, + List& synArgs, + List& synGenericArgs, + ThisExpr*& synThis); + + bool synthesizeAccessorRequirements( + ConformanceCheckingContext* context, + DeclRef requiredMemberDeclRef, + Type* resultType, + Expr* synBoundStorageExpr, + ContainerDecl* synAccesorContainer, + RefPtr witnessTable); + + void _addMethodWitness( + WitnessTable* witnessTable, + DeclRef requirement, + DeclRef method); + + /// Attempt to synthesize a method that can satisfy `requiredMemberDeclRef` using + /// `lookupResult`. + /// + /// On success, installs the syntethesized method in `witnessTable` and returns `true`. + /// Otherwise, returns `false`. + bool trySynthesizeMethodRequirementWitness( + ConformanceCheckingContext* context, + LookupResult const& lookupResult, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable); + + bool trySynthesizeConstructorRequirementWitness( + ConformanceCheckingContext* context, + LookupResult const& lookupResult, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable); + + /// Attempt to synthesize a property that can satisfy `requiredMemberDeclRef` using + /// `lookupResult`. + /// + /// On success, installs the syntethesized method in `witnessTable` and returns `true`. + /// Otherwise, returns `false`. + /// + bool trySynthesizePropertyRequirementWitness( + ConformanceCheckingContext* context, + LookupResult const& lookupResult, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable); + + bool trySynthesizeWrapperTypePropertyRequirementWitness( + ConformanceCheckingContext* context, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable); + + bool trySynthesizeSubscriptRequirementWitness( + ConformanceCheckingContext* context, + const LookupResult& lookupResult, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable); + + bool trySynthesizeWrapperTypeSubscriptRequirementWitness( + ConformanceCheckingContext* context, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable); + + bool trySynthesizeAssociatedTypeRequirementWitness( + ConformanceCheckingContext* context, + LookupResult const& lookupResult, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable); + + bool trySynthesizeAssociatedConstantRequirementWitness( + ConformanceCheckingContext* context, + LookupResult const& lookupResult, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable); + + /// Attempt to synthesize a declartion that can satisfy `requiredMemberDeclRef` using + /// `lookupResult`. + /// + /// On success, installs the syntethesized declaration in `witnessTable` and returns `true`. + /// Otherwise, returns `false`. + bool trySynthesizeRequirementWitness( + ConformanceCheckingContext* context, + LookupResult const& lookupResult, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable); + + + enum SynthesisPattern + { + // Synthesized method inducts over all arguments. + // T fn(T x, T y, T z, ...) + // { + // typeof(T::member0)::fn(x.member0, y.member0, z.member0, ...); + // typeof(T::member1)::fn(x.member1, y.member1, z.member1, ...); + // ... + // } + // + AllInductive, + + // Synthesized method inducts over all arguments except the first. + // T fn(U x, T y, T z) + // { + // typeof(T::member0)::fn(x, y.member0, z.member0, ...); + // typeof(T::member1)::fn(x, y.member1, z.member1, ...); + // ... + // } + FixedFirstArg + }; - /// Attempt to synthesize `zero`, `dadd` & `dmul` methods for a type that conforms to - /// `IDifferentiable`. - /// On success, installs the syntethesized functions and returns `true`. - /// Otherwise, returns `false`. - bool trySynthesizeDifferentialMethodRequirementWitness( - ConformanceCheckingContext* context, - DeclRef requirementDeclRef, - RefPtr witnessTable, - SynthesisPattern pattern); - - /// Attempt to synthesize an associated `Differential` type for a type that conforms to - /// `IDifferentiable`. - /// - /// On success, installs the syntethesized type in `witnessTable`, injects `[DerivativeMember]` - /// modifiers on differentiable fields to point to the corresponding field in the synthesized - /// differential type, and returns `true`. - /// Otherwise, returns `false`. - bool trySynthesizeDifferentialAssociatedTypeRequirementWitness( - ConformanceCheckingContext* context, - DeclRef requirementDeclRef, - RefPtr witnessTable); - - /// Attempt to synthesize function requirements for enum types to make them conform to `ILogical`. - bool trySynthesizeEnumTypeMethodRequirementWitness(ConformanceCheckingContext* context, - DeclRef requirementDeclRef, - RefPtr witnessTable, - BuiltinRequirementKind requirementKind); - - /// Check references from`[DerivativeMember(...)]` attributes on members of the agg-decl. - /// this is typically deferred until after types are ready for reference. - void checkDifferentiableMembersInType(AggTypeDecl* decl); - - struct DifferentiableMemberInfo - { - Decl* memberDecl; - Type* diffType; - }; + /// Attempt to synthesize `zero`, `dadd` & `dmul` methods for a type that conforms to + /// `IDifferentiable`. + /// On success, installs the syntethesized functions and returns `true`. + /// Otherwise, returns `false`. + bool trySynthesizeDifferentialMethodRequirementWitness( + ConformanceCheckingContext* context, + DeclRef requirementDeclRef, + RefPtr witnessTable, + SynthesisPattern pattern); + + /// Attempt to synthesize an associated `Differential` type for a type that conforms to + /// `IDifferentiable`. + /// + /// On success, installs the syntethesized type in `witnessTable`, injects `[DerivativeMember]` + /// modifiers on differentiable fields to point to the corresponding field in the synthesized + /// differential type, and returns `true`. + /// Otherwise, returns `false`. + bool trySynthesizeDifferentialAssociatedTypeRequirementWitness( + ConformanceCheckingContext* context, + DeclRef requirementDeclRef, + RefPtr witnessTable); + + /// Attempt to synthesize function requirements for enum types to make them conform to + /// `ILogical`. + bool trySynthesizeEnumTypeMethodRequirementWitness( + ConformanceCheckingContext* context, + DeclRef requirementDeclRef, + RefPtr witnessTable, + BuiltinRequirementKind requirementKind); + + /// Check references from`[DerivativeMember(...)]` attributes on members of the agg-decl. + /// this is typically deferred until after types are ready for reference. + void checkDifferentiableMembersInType(AggTypeDecl* decl); + + struct DifferentiableMemberInfo + { + Decl* memberDecl; + Type* diffType; + }; - /// Gather differentiable members from decl. - List collectDifferentiableMemberInfo(ContainerDecl* decl); + /// Gather differentiable members from decl. + List collectDifferentiableMemberInfo(ContainerDecl* decl); + + // Check and register a type if it is differentiable. + void maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type); + + // Find the appropriate member of a declared type to + // satisfy a requirement of an interface the type + // claims to conform to. + // + // The type declaration `typeDecl` has declared that it + // conforms to the interface `interfaceDeclRef`, and + // `requiredMemberDeclRef` is a required member of + // the interface. + // + // If a satisfying value is found, registers it in + // `witnessTable` and returns `true`, otherwise + // returns `false`. + // + bool findWitnessForInterfaceRequirement( + ConformanceCheckingContext* context, + Type* subType, + Type* superInterfaceType, + InheritanceDecl* inheritanceDecl, + DeclRef superInterfaceDeclRef, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable, + SubtypeWitness* subTypeConformsToSuperInterfaceWitness); + + // Check that the type declaration `typeDecl`, which + // declares conformance to the interface `interfaceDeclRef`, + // (via the given `inheritanceDecl`) actually provides + // members to satisfy all the requirements in the interface. + bool checkInterfaceConformance( + ConformanceCheckingContext* context, + Type* subType, + Type* superInterfaceType, + InheritanceDecl* inheritanceDecl, + DeclRef superInterfaceDeclRef, + SubtypeWitness* subTypeConformsToSuperInterfaceWitness, + WitnessTable* witnessTable); + + RefPtr checkInterfaceConformance( + ConformanceCheckingContext* context, + Type* subType, + Type* superInterfaceType, + InheritanceDecl* inheritanceDecl, + DeclRef superInterfaceDeclRef, + SubtypeWitness* subTypeConformsToSuperInterfaceWitness); + + bool checkConformanceToType( + ConformanceCheckingContext* context, + Type* subType, + InheritanceDecl* inheritanceDecl, + Type* superType, + SubtypeWitness* subIsSuperWitness, + WitnessTable* witnessTable); + + /// Check that `type` which has declared that it inherits from (and/or implements) + /// another type via `inheritanceDecl` actually does what it needs to for that + /// inheritance to be valid. + bool checkConformance(Type* type, InheritanceDecl* inheritanceDecl, ContainerDecl* parentDecl); + + void checkExtensionConformance(ExtensionDecl* decl); + + void checkAggTypeConformance(AggTypeDecl* decl); + + bool isIntegerBaseType(BaseType baseType); + + /// Is `type` a scalar integer type. + bool isScalarIntegerType(Type* type); + + /// Is `type` something we allow as compile time constants, i.e. scalar integer and enum types. + bool isValidCompileTimeConstantType(Type* type); + + bool isIntValueInRangeOfType(IntegerLiteralValue value, Type* type); + + // Validate that `type` is a suitable type to use + // as the tag type for an `enum` + void validateEnumTagType(Type* type, SourceLoc const& loc); + + void checkStmt(Stmt* stmt, SemanticsContext const& context); + + void getGenericParams( + GenericDecl* decl, + List& outParams, + List& outConstraints); + + /// Determine if `left` and `right` have matching generic signatures. + /// If they do, then outputs a specialized declRef to `ioSubstRightToLeft` that + /// represents a reference to `right` with the parameters of `left`. + bool doGenericSignaturesMatch( + GenericDecl* left, + GenericDecl* right, + DeclRef* outSpecializedRightInner); + + // Check if two functions have the same signature for the purposes + // of overload resolution. + bool doFunctionSignaturesMatch(DeclRef fst, DeclRef snd); + + Result checkRedeclaration(Decl* newDecl, Decl* oldDecl); + Result checkFuncRedeclaration(FuncDecl* newDecl, FuncDecl* oldDecl); + void checkForRedeclaration(Decl* decl); + + Expr* checkPredicateExpr(Expr* expr); + + enum class ConstantFoldingKind + { + CompileTime, + LinkTime, + }; + Expr* checkExpressionAndExpectIntegerConstant( + Expr* expr, + IntVal** outIntVal, + ConstantFoldingKind kind); - // Check and register a type if it is differentiable. - void maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type); + IntegerLiteralValue GetMinBound(IntVal* val); - // Find the appropriate member of a declared type to - // satisfy a requirement of an interface the type - // claims to conform to. - // - // The type declaration `typeDecl` has declared that it - // conforms to the interface `interfaceDeclRef`, and - // `requiredMemberDeclRef` is a required member of - // the interface. - // - // If a satisfying value is found, registers it in - // `witnessTable` and returns `true`, otherwise - // returns `false`. - // - bool findWitnessForInterfaceRequirement( - ConformanceCheckingContext* context, - Type* subType, - Type* superInterfaceType, - InheritanceDecl* inheritanceDecl, - DeclRef superInterfaceDeclRef, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable, - SubtypeWitness* subTypeConformsToSuperInterfaceWitness); - - // Check that the type declaration `typeDecl`, which - // declares conformance to the interface `interfaceDeclRef`, - // (via the given `inheritanceDecl`) actually provides - // members to satisfy all the requirements in the interface. - bool checkInterfaceConformance( - ConformanceCheckingContext* context, - Type* subType, - Type* superInterfaceType, - InheritanceDecl* inheritanceDecl, - DeclRef superInterfaceDeclRef, - SubtypeWitness* subTypeConformsToSuperInterfaceWitness, - WitnessTable* witnessTable); - - RefPtr checkInterfaceConformance( - ConformanceCheckingContext* context, - Type* subType, - Type* superInterfaceType, - InheritanceDecl* inheritanceDecl, - DeclRef superInterfaceDeclRef, - SubtypeWitness* subTypeConformsToSuperInterfaceWitness); - - bool checkConformanceToType( - ConformanceCheckingContext* context, - Type* subType, - InheritanceDecl* inheritanceDecl, - Type* superType, - SubtypeWitness* subIsSuperWitness, - WitnessTable* witnessTable); - - /// Check that `type` which has declared that it inherits from (and/or implements) - /// another type via `inheritanceDecl` actually does what it needs to for that - /// inheritance to be valid. - bool checkConformance( - Type* type, - InheritanceDecl* inheritanceDecl, - ContainerDecl* parentDecl); - - void checkExtensionConformance(ExtensionDecl* decl); - - void checkAggTypeConformance(AggTypeDecl* decl); - - bool isIntegerBaseType(BaseType baseType); - - /// Is `type` a scalar integer type. - bool isScalarIntegerType(Type* type); - - /// Is `type` something we allow as compile time constants, i.e. scalar integer and enum types. - bool isValidCompileTimeConstantType(Type* type); - - bool isIntValueInRangeOfType(IntegerLiteralValue value, Type* type); - - // Validate that `type` is a suitable type to use - // as the tag type for an `enum` - void validateEnumTagType(Type* type, SourceLoc const& loc); - - void checkStmt(Stmt* stmt, SemanticsContext const& context); - - void getGenericParams( - GenericDecl* decl, - List& outParams, - List& outConstraints); - - /// Determine if `left` and `right` have matching generic signatures. - /// If they do, then outputs a specialized declRef to `ioSubstRightToLeft` that - /// represents a reference to `right` with the parameters of `left`. - bool doGenericSignaturesMatch( - GenericDecl* left, - GenericDecl* right, - DeclRef* outSpecializedRightInner); - - // Check if two functions have the same signature for the purposes - // of overload resolution. - bool doFunctionSignaturesMatch( - DeclRef fst, - DeclRef snd); - - Result checkRedeclaration(Decl* newDecl, Decl* oldDecl); - Result checkFuncRedeclaration(FuncDecl* newDecl, FuncDecl* oldDecl); - void checkForRedeclaration(Decl* decl); - - Expr* checkPredicateExpr(Expr* expr); - - enum class ConstantFoldingKind - { - CompileTime, - LinkTime, - }; - Expr* checkExpressionAndExpectIntegerConstant(Expr* expr, IntVal** outIntVal, ConstantFoldingKind kind); + void maybeInferArraySizeForVariable(VarDeclBase* varDecl); - IntegerLiteralValue GetMinBound(IntVal* val); + void validateArraySizeForVariable(VarDeclBase* varDecl); - void maybeInferArraySizeForVariable(VarDeclBase* varDecl); + IntVal* getIntVal(IntegerLiteralExpr* expr); - void validateArraySizeForVariable(VarDeclBase* varDecl); + inline IntVal* getIntVal(SubstExpr expr) + { + return getIntVal(expr.getExpr()); + } - IntVal* getIntVal(IntegerLiteralExpr* expr); + Name* getName(String const& text) { return getNamePool()->getName(text); } - inline IntVal* getIntVal(SubstExpr expr) + /// Helper type to detect and catch circular definitions when folding constants, + /// to prevent the compiler from going into infinite loops or overflowing the stack. + struct ConstantFoldingCircularityInfo + { + ConstantFoldingCircularityInfo(Decl* decl, ConstantFoldingCircularityInfo* next) + : decl(decl), next(next) { - return getIntVal(expr.getExpr()); } - Name* getName(String const& text) - { - return getNamePool()->getName(text); - } + /// A declaration whose value is contributing to the constant being folded + Decl* decl = nullptr; - /// Helper type to detect and catch circular definitions when folding constants, - /// to prevent the compiler from going into infinite loops or overflowing the stack. - struct ConstantFoldingCircularityInfo - { - ConstantFoldingCircularityInfo( - Decl* decl, - ConstantFoldingCircularityInfo* next) - : decl(decl) - , next(next) - {} - - /// A declaration whose value is contributing to the constant being folded - Decl* decl = nullptr; - - /// The rest of the links in the chain of declarations being folded - ConstantFoldingCircularityInfo* next = nullptr; - }; - /// Try to apply front-end constant folding to determine the value of `invokeExpr`. - IntVal* tryConstantFoldExpr( - SubstExpr invokeExpr, - ConstantFoldingKind kind, - ConstantFoldingCircularityInfo* circularityInfo); - - /// Try to apply front-end constant folding to determine the value of `expr`. - IntVal* tryConstantFoldExpr( - SubstExpr expr, - ConstantFoldingKind kind, - ConstantFoldingCircularityInfo* circularityInfo); - - bool _checkForCircularityInConstantFolding( - Decl* decl, - ConstantFoldingCircularityInfo* circularityInfo); - - /// Try to resolve a compile-time constant `IntVal` from the given `declRef`. - IntVal* tryConstantFoldDeclRef( - DeclRef const& declRef, - ConstantFoldingKind kind, - ConstantFoldingCircularityInfo* circularityInfo); - - /// Try to extract the value of an integer constant expression, either - /// returning the `IntVal` value, or null if the expression isn't recognized - /// as an integer constant. - /// - IntVal* tryFoldIntegerConstantExpression( - SubstExpr expr, - ConstantFoldingKind kind, - ConstantFoldingCircularityInfo* circularityInfo); - - IntVal* tryFoldIndexExpr( - SubstExpr expr, - ConstantFoldingKind kind, - ConstantFoldingCircularityInfo* circularityInfo); - - // Enforce that an expression resolves to an integer constant, and get its value - enum class IntegerConstantExpressionCoercionType - { - SpecificType, - AnyInteger - }; - IntVal* CheckIntegerConstantExpression(Expr* inExpr, IntegerConstantExpressionCoercionType coercionType, Type* expectedType, ConstantFoldingKind kind); - IntVal* CheckIntegerConstantExpression(Expr* inExpr, IntegerConstantExpressionCoercionType coercionType, Type* expectedType, ConstantFoldingKind kind, DiagnosticSink* sink); + /// The rest of the links in the chain of declarations being folded + ConstantFoldingCircularityInfo* next = nullptr; + }; + /// Try to apply front-end constant folding to determine the value of `invokeExpr`. + IntVal* tryConstantFoldExpr( + SubstExpr invokeExpr, + ConstantFoldingKind kind, + ConstantFoldingCircularityInfo* circularityInfo); + + /// Try to apply front-end constant folding to determine the value of `expr`. + IntVal* tryConstantFoldExpr( + SubstExpr expr, + ConstantFoldingKind kind, + ConstantFoldingCircularityInfo* circularityInfo); + + bool _checkForCircularityInConstantFolding( + Decl* decl, + ConstantFoldingCircularityInfo* circularityInfo); + + /// Try to resolve a compile-time constant `IntVal` from the given `declRef`. + IntVal* tryConstantFoldDeclRef( + DeclRef const& declRef, + ConstantFoldingKind kind, + ConstantFoldingCircularityInfo* circularityInfo); + + /// Try to extract the value of an integer constant expression, either + /// returning the `IntVal` value, or null if the expression isn't recognized + /// as an integer constant. + /// + IntVal* tryFoldIntegerConstantExpression( + SubstExpr expr, + ConstantFoldingKind kind, + ConstantFoldingCircularityInfo* circularityInfo); + + IntVal* tryFoldIndexExpr( + SubstExpr expr, + ConstantFoldingKind kind, + ConstantFoldingCircularityInfo* circularityInfo); + + // Enforce that an expression resolves to an integer constant, and get its value + enum class IntegerConstantExpressionCoercionType + { + SpecificType, + AnyInteger + }; + IntVal* CheckIntegerConstantExpression( + Expr* inExpr, + IntegerConstantExpressionCoercionType coercionType, + Type* expectedType, + ConstantFoldingKind kind); + IntVal* CheckIntegerConstantExpression( + Expr* inExpr, + IntegerConstantExpressionCoercionType coercionType, + Type* expectedType, + ConstantFoldingKind kind, + DiagnosticSink* sink); - IntVal* CheckEnumConstantExpression(Expr* expr, ConstantFoldingKind kind); + IntVal* CheckEnumConstantExpression(Expr* expr, ConstantFoldingKind kind); - Expr* CheckSimpleSubscriptExpr( - IndexExpr* subscriptExpr, - Type* elementType); + Expr* CheckSimpleSubscriptExpr(IndexExpr* subscriptExpr, Type* elementType); - // The way that we have designed out type system, pretyt much *every* - // type is a reference to some declaration in the core module. - // That means that when we construct a new type on the fly, we need - // to make sure that it is wired up to reference the appropriate - // declaration, or else it won't compare as equal to other types - // that *do* reference the declaration. - // - // This function is used to construct a `vector` type - // programmatically, so that it will work just like a type of - // that form constructed by the user. - VectorExpressionType* createVectorType( - Type* elementType, - IntVal* elementCount); + // The way that we have designed out type system, pretyt much *every* + // type is a reference to some declaration in the core module. + // That means that when we construct a new type on the fly, we need + // to make sure that it is wired up to reference the appropriate + // declaration, or else it won't compare as equal to other types + // that *do* reference the declaration. + // + // This function is used to construct a `vector` type + // programmatically, so that it will work just like a type of + // that form constructed by the user. + VectorExpressionType* createVectorType(Type* elementType, IntVal* elementCount); - // + // - /// Given an immutable `expr` used as an l-value emit a special diagnostic if it was derived from `this`. - void maybeDiagnoseThisNotLValue(Expr* expr); + /// Given an immutable `expr` used as an l-value emit a special diagnostic if it was derived + /// from `this`. + void maybeDiagnoseThisNotLValue(Expr* expr); - // Figure out what type an initializer/constructor declaration - // is supposed to return. In most cases this is just the type - // declaration that its declaration is nested inside. - Type* findResultTypeForConstructorDecl(ConstructorDecl* decl); + // Figure out what type an initializer/constructor declaration + // is supposed to return. In most cases this is just the type + // declaration that its declaration is nested inside. + Type* findResultTypeForConstructorDecl(ConstructorDecl* decl); - /// Determine what type `This` should refer to in the context of the given parent `decl`. - Type* calcThisType(DeclRef decl); + /// Determine what type `This` should refer to in the context of the given parent `decl`. + Type* calcThisType(DeclRef decl); - /// Determine what type `This` should refer to in an extension of `type`. - Type* calcThisType(Type* type); + /// Determine what type `This` should refer to in an extension of `type`. + Type* calcThisType(Type* type); - // + // - struct Constraint - { - Decl* decl = nullptr; // the declaration of the thing being constraints - Index indexInPack = 0; // If the constraint is for a type parameter pack, which index in the pack is this constraint for? + struct Constraint + { + Decl* decl = nullptr; // the declaration of the thing being constraints + Index indexInPack = 0; // If the constraint is for a type parameter pack, which index in the + // pack is this constraint for? + + Val* val = nullptr; // the value to which we are constraining it + bool isUsedAsLValue = false; // If this constraint is for a type parameter, is the type used + // in an l-value parameter? + bool satisfied = false; // Has this constraint been met? + + // Is this constraint optional? An optional constraint provides a hint value to a parameter + // if it is otherwise unconstrained, but doesn't take precedence over a constraint that is + // not optional. + bool isOptional = false; + }; - Val* val = nullptr; // the value to which we are constraining it - bool isUsedAsLValue = false; // If this constraint is for a type parameter, is the type used in an l-value parameter? - bool satisfied = false; // Has this constraint been met? + // A collection of constraints that will need to be satisfied (solved) + // in order for checking to succeed. + struct ConstraintSystem + { + // A source location to use in reporting any issues + SourceLoc loc; - // Is this constraint optional? An optional constraint provides a hint value to a parameter - // if it is otherwise unconstrained, but doesn't take precedence over a constraint that is not optional. - bool isOptional = false; - }; + // The generic declaration whose parameters we + // are trying to solve for. + GenericDecl* genericDecl = nullptr; - // A collection of constraints that will need to be satisfied (solved) - // in order for checking to succeed. - struct ConstraintSystem - { - // A source location to use in reporting any issues - SourceLoc loc; + // Constraints we have accumulated, which constrain + // the possible arguments for those parameters. + List constraints; + + // Additional subtype witnesses available to the currentt constraint solving context. + Type* subTypeForAdditionalWitnesses = nullptr; + Dictionary* additionalSubtypeWitnesses = nullptr; + }; + + Type* TryJoinVectorAndScalarType( + ConstraintSystem* constraints, + VectorExpressionType* vectorType, + BasicExpressionType* scalarType); + + /// Is the given interface one that a tagged-union type can conform to? + /// + /// If a tagged union type `__TaggedUnion(A,B)` is going to be + /// plugged in for a type parameter `T : IFoo` then we need to + /// be sure that the interface `IFoo` doesn't have anything + /// that could lead to unsafe/unsound behavior. This function + /// checks that all the requirements on the interfaceare safe ones. + /// + bool isInterfaceSafeForTaggedUnion(DeclRef interfaceDeclRef); + + /// Is the given interface requirement one that a tagged-union type can satisfy? + /// + /// Unsafe requirements include any `static` requirements, + /// any associated types, and also any requirements that make + /// use of the `This` type (once we support it). + /// + bool isInterfaceRequirementSafeForTaggedUnion( + DeclRef interfaceDeclRef, + DeclRef requirementDeclRef); + + /// Check whether `subType` is a subtype of `superType` + /// + /// If `subType` is a subtype of `superType`, returns + /// a witness value for the subtype relationship. + /// + /// If `subType` is *not* a subtype of `superType`, returns null. + /// + SubtypeWitness* isSubtype(Type* subType, Type* superType, IsSubTypeOptions isSubTypeOptions); + + SubtypeWitness* checkAndConstructSubtypeWitness( + Type* subType, + Type* superType, + IsSubTypeOptions isSubTypeOptions); + + bool isValidGenericConstraintType(Type* type); + + SubtypeWitness* isTypeDifferentiable(Type* type); + + bool doesTypeHaveTag(Type* type, TypeTag tag); + + TypeTag getTypeTags(Type* type); + + Type* getConstantBufferElementType(Type* type); + + /// Check whether `subType` is a sub-type of `superTypeDeclRef`, + /// and return a witness to the sub-type relationship if it holds + /// (return null otherwise). + /// + SubtypeWitness* tryGetSubtypeWitness(Type* subType, Type* superType) + { + return isSubtype(subType, superType, IsSubTypeOptions::None); + } - // The generic declaration whose parameters we - // are trying to solve for. - GenericDecl* genericDecl = nullptr; + /// Check whether `type` conforms to `interfaceDeclRef`, + /// and return a witness to the conformance if it holds + /// (return null otherwise). + /// + /// This function is equivalent to `tryGetSubtypeWitness()`. + /// + SubtypeWitness* tryGetInterfaceConformanceWitness(Type* type, Type* interfaceType); - // Constraints we have accumulated, which constrain - // the possible arguments for those parameters. - List constraints; + Expr* createCastToSuperTypeExpr(Type* toType, Expr* fromExpr, Val* witness); - // Additional subtype witnesses available to the currentt constraint solving context. - Type* subTypeForAdditionalWitnesses = nullptr; - Dictionary* additionalSubtypeWitnesses = nullptr; - }; + Expr* createModifierCastExpr(Type* toType, Expr* fromExpr); - Type* TryJoinVectorAndScalarType( - ConstraintSystem* constraints, - VectorExpressionType* vectorType, - BasicExpressionType* scalarType); - - /// Is the given interface one that a tagged-union type can conform to? - /// - /// If a tagged union type `__TaggedUnion(A,B)` is going to be - /// plugged in for a type parameter `T : IFoo` then we need to - /// be sure that the interface `IFoo` doesn't have anything - /// that could lead to unsafe/unsound behavior. This function - /// checks that all the requirements on the interfaceare safe ones. - /// - bool isInterfaceSafeForTaggedUnion( - DeclRef interfaceDeclRef); - - /// Is the given interface requirement one that a tagged-union type can satisfy? - /// - /// Unsafe requirements include any `static` requirements, - /// any associated types, and also any requirements that make - /// use of the `This` type (once we support it). - /// - bool isInterfaceRequirementSafeForTaggedUnion( - DeclRef interfaceDeclRef, - DeclRef requirementDeclRef); - - /// Check whether `subType` is a subtype of `superType` - /// - /// If `subType` is a subtype of `superType`, returns - /// a witness value for the subtype relationship. - /// - /// If `subType` is *not* a subtype of `superType`, returns null. - /// - SubtypeWitness* isSubtype( - Type* subType, - Type* superType, - IsSubTypeOptions isSubTypeOptions - ); - - SubtypeWitness* checkAndConstructSubtypeWitness( - Type* subType, - Type* superType, - IsSubTypeOptions isSubTypeOptions - ); - - bool isValidGenericConstraintType(Type* type); - - SubtypeWitness* isTypeDifferentiable(Type* type); - - bool doesTypeHaveTag(Type* type, TypeTag tag); - - TypeTag getTypeTags(Type* type); - - Type* getConstantBufferElementType(Type* type); - - /// Check whether `subType` is a sub-type of `superTypeDeclRef`, - /// and return a witness to the sub-type relationship if it holds - /// (return null otherwise). - /// - SubtypeWitness* tryGetSubtypeWitness( - Type* subType, - Type* superType) - { - return isSubtype(subType, superType, IsSubTypeOptions::None); - } + /// Does there exist an implicit conversion from `fromType` to `toType`? + bool canConvertImplicitly(Type* toType, QualType fromType); - /// Check whether `type` conforms to `interfaceDeclRef`, - /// and return a witness to the conformance if it holds - /// (return null otherwise). - /// - /// This function is equivalent to `tryGetSubtypeWitness()`. - /// - SubtypeWitness* tryGetInterfaceConformanceWitness( - Type* type, - Type* interfaceType); - - Expr* createCastToSuperTypeExpr( - Type* toType, - Expr* fromExpr, - Val* witness); - - Expr* createModifierCastExpr( - Type* toType, - Expr* fromExpr); - - /// Does there exist an implicit conversion from `fromType` to `toType`? - bool canConvertImplicitly( - Type* toType, - QualType fromType); - - bool canConvertImplicitly( - ConversionCost cost); - - ConversionCost getConversionCost(Type* toType, QualType fromType); - - Type* _tryJoinTypeWithInterface( - ConstraintSystem* constraints, - Type* type, - Type* interfaceType); - - // Try to compute the "join" between two types - Type* TryJoinTypes( - ConstraintSystem* constraints, - QualType left, - QualType right); - - // Try to solve a system of generic constraints. - // The `system` argument provides the constraints. - // The `varSubst` argument provides the list of constraint - // variables that were created for the system. - // - // Returns a new declref to the inner decl of `genericDeclRef`, - // representing the specialized generic with the values - // we solved for along the way. - DeclRef trySolveConstraintSystem( - ConstraintSystem* system, - DeclRef genericDeclRef, - ArrayView knownGenericArgs, - ConversionCost& outBaseCost); - - - // State related to overload resolution for a call - // to an overloaded symbol - struct OverloadResolveContext - { - enum class Mode - { - // We are just checking if a candidate works or not - JustTrying, + bool canConvertImplicitly(ConversionCost cost); - // We want to actually update the AST for a chosen candidate - ForReal, - }; + ConversionCost getConversionCost(Type* toType, QualType fromType); - // Location to use when reporting overload-resolution errors. - SourceLoc loc; + Type* _tryJoinTypeWithInterface(ConstraintSystem* constraints, Type* type, Type* interfaceType); - // The original expression (if any) that triggered things - AppExprBase* originalExpr = nullptr; + // Try to compute the "join" between two types + Type* TryJoinTypes(ConstraintSystem* constraints, QualType left, QualType right); - // Source location of the "function" part of the expression, if any - SourceLoc funcLoc; + // Try to solve a system of generic constraints. + // The `system` argument provides the constraints. + // The `varSubst` argument provides the list of constraint + // variables that were created for the system. + // + // Returns a new declref to the inner decl of `genericDeclRef`, + // representing the specialized generic with the values + // we solved for along the way. + DeclRef trySolveConstraintSystem( + ConstraintSystem* system, + DeclRef genericDeclRef, + ArrayView knownGenericArgs, + ConversionCost& outBaseCost); - // The source scope of the lookup for performing visibiliity tests. - Scope* sourceScope = nullptr; - // The original arguments to the call - Index argCount = 0; - List* args = nullptr; - Type** argTypes = nullptr; + // State related to overload resolution for a call + // to an overloaded symbol + struct OverloadResolveContext + { + enum class Mode + { + // We are just checking if a candidate works or not + JustTrying, - Index getArgCount() { return argCount; } - Expr*& getArg(Index index) { return (*args)[index]; } - Type* getArgType(Index index) - { - if(argTypes) - return argTypes[index]; - else - return getArg(index)->type.type; - } - Type* getArgTypeForInference(Index index, SemanticsVisitor* semantics) - { - if(argTypes) - return argTypes[index]; - else - return semantics->maybeResolveOverloadedExpr(getArg(index), LookupMask::Default, nullptr)->type; - } - struct MatchedArg - { - Expr* argExpr = nullptr; - Type* argType = nullptr; - }; - bool matchArgumentsToParams(SemanticsVisitor* semantics, const List& params, bool computeTypes, ShortList& outMatchedArgs); + // We want to actually update the AST for a chosen candidate + ForReal, + }; - bool disallowNestedConversions = false; + // Location to use when reporting overload-resolution errors. + SourceLoc loc; - Expr* baseExpr = nullptr; + // The original expression (if any) that triggered things + AppExprBase* originalExpr = nullptr; - // Are we still trying out candidates, or are we - // checking the chosen one for real? - Mode mode = Mode::JustTrying; + // Source location of the "function" part of the expression, if any + SourceLoc funcLoc; - // We store one candidate directly, so that we don't - // need to do dynamic allocation on the list every time - OverloadCandidate bestCandidateStorage; - OverloadCandidate* bestCandidate = nullptr; + // The source scope of the lookup for performing visibiliity tests. + Scope* sourceScope = nullptr; - // Full list of all candidates being considered, in the ambiguous case - List bestCandidates; - }; + // The original arguments to the call + Index argCount = 0; + List* args = nullptr; + Type** argTypes = nullptr; - struct ParamCounts + Index getArgCount() { return argCount; } + Expr*& getArg(Index index) { return (*args)[index]; } + Type* getArgType(Index index) { - Count required; - Count allowed; - }; - - // count the number of parameters required/allowed for a callable - ParamCounts CountParameters(FilteredMemberRefList params); - - // count the number of parameters required/allowed for a generic - ParamCounts CountParameters(DeclRef genericRef); - - bool TryCheckOverloadCandidateClassNewMatchUp( - OverloadResolveContext& context, - OverloadCandidate const& candidate); - - bool TryCheckOverloadCandidateArity( - OverloadResolveContext& context, - OverloadCandidate const& candidate); - - bool TryCheckOverloadCandidateFixity( - OverloadResolveContext& context, - OverloadCandidate const& candidate); - - bool TryCheckOverloadCandidateVisibility( - OverloadResolveContext& context, - OverloadCandidate const& candidate); - - bool TryCheckGenericOverloadCandidateTypes( - OverloadResolveContext& context, - OverloadCandidate& candidate); - - bool TryCheckOverloadCandidateTypes( - OverloadResolveContext& context, - OverloadCandidate& candidate); - - bool TryCheckOverloadCandidateDirections( - OverloadResolveContext& /*context*/, - OverloadCandidate const& /*candidate*/); - - /// Check if the given `expr` refers to an `in` function - /// parameter, or part of one (through field reference, etc.). - /// - /// If the expression refers into a parameter, returns - /// the declaration of the parameter. Otherwise returns - /// null. - /// - ParamDecl* isReferenceIntoFunctionInputParameter( - Expr* expr); - - // Create a witness that attests to the fact that `type` - // is equal to itself. - TypeEqualityWitness* createTypeEqualityWitness( - Type* type); - - // In the case where we are explicitly applying a generic - // to arguments (e.g., `G`) check that the constraints - // on those parameters are satisfied. - // - // Note: the constraints actually work as additional parameters/arguments - // of the generic, and so we need to reify them into the final - // argument list. - // - bool TryCheckOverloadCandidateConstraints( - OverloadResolveContext& context, - OverloadCandidate& candidate); - - // Try to check an overload candidate, but bail out - // if any step fails - void TryCheckOverloadCandidate( - OverloadResolveContext& context, - OverloadCandidate& candidate); - - // Create the representation of a given generic applied to some arguments - Expr* createGenericDeclRef( - Expr* baseExpr, - Expr* originalExpr, - SubstitutionSet substSet); - - // Take an overload candidate that previously got through - // `TryCheckOverloadCandidate` above, and try to finish - // up the work and turn it into a real expression. - // - // If the candidate isn't actually applicable, this is - // where we'd start reporting the issue(s). - Expr* CompleteOverloadCandidate( - OverloadResolveContext& context, - OverloadCandidate& candidate); - - // Implement a comparison operation between overload candidates, - // so that the better candidate compares as less-than the other - int CompareOverloadCandidates( - OverloadCandidate* left, - OverloadCandidate* right); - - /// If `declRef` representations a specialization of a generic, returns the number of specialized generic arguments. - /// Otherwise, returns zero. - /// - Int getSpecializedParamCount(DeclRef const& declRef); - - /// Compare items `left` and `right` produced by lookup, to see if one should be favored for overloading. - int CompareLookupResultItems( - LookupResultItem const& left, - LookupResultItem const& right); - - /// Compare items `left` and `right` being considered as overload candidates, and determine if one should be favored for structural reasons. - int compareOverloadCandidateSpecificity( - LookupResultItem const& left, - LookupResultItem const& right); - - void AddOverloadCandidateInner( - OverloadResolveContext& context, - OverloadCandidate& candidate); - - void AddOverloadCandidate( - OverloadResolveContext& context, - OverloadCandidate& candidate, - ConversionCost baseCost); - - void AddHigherOrderOverloadCandidates( - Expr* funcExpr, - OverloadResolveContext& context, - ConversionCost baseCost); - - void AddFuncOverloadCandidate( - LookupResultItem item, - DeclRef funcDeclRef, - OverloadResolveContext& context, - ConversionCost baseCost); - - void AddFuncOverloadCandidate( - FuncType* /*funcType*/, - OverloadResolveContext& /*context*/, - ConversionCost baseCost); - - void AddFuncExprOverloadCandidate( - FuncType* funcType, - OverloadResolveContext& context, - Expr* expr, - ConversionCost baseCost); - - // Add a candidate callee for overload resolution, based on - // calling a particular `ConstructorDecl`. - void AddCtorOverloadCandidate( - LookupResultItem typeItem, - Type* type, - DeclRef ctorDeclRef, - OverloadResolveContext& context, - Type* resultType, - ConversionCost baseCost); - - // If the given declaration has generic parameters, then - // return the corresponding `GenericDecl` that holds the - // parameters, etc. This returns the immediate generic parent - // of `decl`, e.g. the generic for f, and *not* any indirect - // generic parents, such as P.f(). - GenericDecl* GetOuterGeneric(Decl* decl); - - // If `decl` is inside a generic, return that outer generic, - // otherwise returns `decl`. - Decl* getOuterGenericOrSelf(Decl* decl); - - // Find the next outer generic parent of `decl`, including - // indirect parents. - GenericDecl* findNextOuterGeneric(Decl* decl); - - struct ValUnificationContext + if (argTypes) + return argTypes[index]; + else + return getArg(index)->type.type; + } + Type* getArgTypeForInference(Index index, SemanticsVisitor* semantics) + { + if (argTypes) + return argTypes[index]; + else + return semantics + ->maybeResolveOverloadedExpr(getArg(index), LookupMask::Default, nullptr) + ->type; + } + struct MatchedArg { - Index indexInTypePack = 0; + Expr* argExpr = nullptr; + Type* argType = nullptr; }; + bool matchArgumentsToParams( + SemanticsVisitor* semantics, + const List& params, + bool computeTypes, + ShortList& outMatchedArgs); - // Try to find a unification for two values - bool TryUnifyVals( - ConstraintSystem& constraints, - ValUnificationContext unificationContext, - Val* fst, - bool fstLVal, - Val* snd, - bool sndLVal); - - bool tryUnifyDeclRef( - ConstraintSystem& constraints, - ValUnificationContext unificationContext, - DeclRefBase* fst, - bool fstLVal, - DeclRefBase* snd, - bool sndLVal); - - bool tryUnifyGenericAppDeclRef( - ConstraintSystem& constraints, - ValUnificationContext unificationContext, - GenericAppDeclRef* fst, - bool fstLVal, - GenericAppDeclRef* snd, - bool sndLVal); - - bool TryUnifyTypeParam( - ConstraintSystem& constraints, - ValUnificationContext unificationContext, - GenericTypeParamDeclBase* typeParamDecl, - QualType type); - - bool TryUnifyIntParam( - ConstraintSystem& constraints, - ValUnificationContext unificationContext, - GenericValueParamDecl* paramDecl, - IntVal* val); - - bool TryUnifyIntParam( - ConstraintSystem& constraints, - ValUnificationContext unificationContext, - DeclRef const& varRef, - IntVal* val); - - bool TryUnifyTypesByStructuralMatch( - ConstraintSystem& constraints, - ValUnificationContext unificationContext, - QualType fst, - QualType snd); - - bool TryUnifyTypes( - ConstraintSystem& constraints, - ValUnificationContext unificationContext, - QualType fst, - QualType snd); - - bool TryUnifyConjunctionType( - ConstraintSystem& constraints, - ValUnificationContext unificationContext, - QualType fst, - QualType snd); - - void maybeUnifyUnconstraintIntParam( - ConstraintSystem& constraints, - ValUnificationContext unificationContext, - IntVal* param, - IntVal* arg, - bool paramIsLVal); - - // Is the candidate extension declaration actually applicable to the given type - DeclRef applyExtensionToType( - ExtensionDecl* extDecl, - Type* type, - Dictionary* additionalSubtypeWitnessesForType = nullptr); - - // Take a generic declaration that is being applied - // in a context and attempt to infer any missing generic - // arguments to form a `DeclRef` to the inner declaration - // that could be applicable in the context of the given - // overloaded call. - // Also computes a `baseCost` for the inferred arguments, - // so that we can prefer a more specialized generic candidate - // when there is ambiguity. For example, given - // ``` - // interface IBase; - // interface IDerived : IBase; - // struct Derived : IDerived {} - // void f1(T b) - // void f2(T b); - // ``` - // We will prefer f2 when seeing f(Derived()), because it takes - // less steps to upcast `Derived` to `IDerived` than it does - // to `IBase`. - // - DeclRef inferGenericArguments( - DeclRef genericDeclRef, - OverloadResolveContext& context, - ArrayView knownGenericArgs, - ConversionCost &outBaseCost, - List *innerParameterTypes = nullptr); - - void AddTypeOverloadCandidates( - Type* type, - OverloadResolveContext& context); - - void AddDeclRefOverloadCandidates( - LookupResultItem item, - OverloadResolveContext& context, - ConversionCost baseCost); - - void AddOverloadCandidates( - LookupResult const& result, - OverloadResolveContext& context); - - void AddOverloadCandidates( - Expr* funcExpr, - OverloadResolveContext& context); - - String getCallSignatureString( - OverloadResolveContext& context); - - Expr* ResolveInvoke(InvokeExpr * expr); - - void AddGenericOverloadCandidate( - LookupResultItem baseItem, - OverloadResolveContext& context); - - void AddGenericOverloadCandidates( - Expr* baseExpr, - OverloadResolveContext& context); - - template - void trySetGenericToRayTracingWithParamAttribute( - LookupResultItem genericItem, - DeclRef genericDeclRef, - OverloadResolveContext& context); - - // Add overload candidates based on use of `genericDeclRef` - // in an ordinary function-call context (that is, where it - // has been applied to arguments using `()` and not `<>`). - // - // If some or all of the generic arguments to `genericDeclRef` - // are known at the call site, they should be passed in via - // `substWithKnownGenericArgs`. - // - void addOverloadCandidatesForCallToGeneric( - LookupResultItem genericItem, - OverloadResolveContext& context, - ArrayView knownGenericArgs); - - /// Check a generic application where the operands have already been checked. - Expr* checkGenericAppWithCheckedArgs(GenericAppExpr* genericAppExpr); - - Expr* CheckExpr(Expr* expr); - - - void compareMemoryQualifierOfParamToArgument(ParamDecl* paramIn, Expr* argIn); - Expr* CheckInvokeExprWithCheckedOperands(InvokeExpr *expr); - // Get the type to use when referencing a declaration - QualType GetTypeForDeclRef(DeclRef declRef, SourceLoc loc); + bool disallowNestedConversions = false; - // - // - // + Expr* baseExpr = nullptr; + + // Are we still trying out candidates, or are we + // checking the chosen one for real? + Mode mode = Mode::JustTrying; - Expr* MaybeDereference(Expr* inExpr); + // We store one candidate directly, so that we don't + // need to do dynamic allocation on the list every time + OverloadCandidate bestCandidateStorage; + OverloadCandidate* bestCandidate = nullptr; - Expr* CheckMatrixSwizzleExpr( - MemberExpr* memberRefExpr, - Type* baseElementType, - IntegerLiteralValue baseElementRowCount, - IntegerLiteralValue baseElementColCount); + // Full list of all candidates being considered, in the ambiguous case + List bestCandidates; + }; - Expr* CheckMatrixSwizzleExpr( - MemberExpr* memberRefExpr, - Type* baseElementType, - IntVal* baseElementRowCount, - IntVal* baseElementColCount); + struct ParamCounts + { + Count required; + Count allowed; + }; - Expr* checkTupleSwizzleExpr(MemberExpr* memberExpr, TupleType* baseTupleType); + // count the number of parameters required/allowed for a callable + ParamCounts CountParameters(FilteredMemberRefList params); + + // count the number of parameters required/allowed for a generic + ParamCounts CountParameters(DeclRef genericRef); + + bool TryCheckOverloadCandidateClassNewMatchUp( + OverloadResolveContext& context, + OverloadCandidate const& candidate); + + bool TryCheckOverloadCandidateArity( + OverloadResolveContext& context, + OverloadCandidate const& candidate); + + bool TryCheckOverloadCandidateFixity( + OverloadResolveContext& context, + OverloadCandidate const& candidate); + + bool TryCheckOverloadCandidateVisibility( + OverloadResolveContext& context, + OverloadCandidate const& candidate); + + bool TryCheckGenericOverloadCandidateTypes( + OverloadResolveContext& context, + OverloadCandidate& candidate); + + bool TryCheckOverloadCandidateTypes( + OverloadResolveContext& context, + OverloadCandidate& candidate); + + bool TryCheckOverloadCandidateDirections( + OverloadResolveContext& /*context*/, + OverloadCandidate const& /*candidate*/ + ); + + /// Check if the given `expr` refers to an `in` function + /// parameter, or part of one (through field reference, etc.). + /// + /// If the expression refers into a parameter, returns + /// the declaration of the parameter. Otherwise returns + /// null. + /// + ParamDecl* isReferenceIntoFunctionInputParameter(Expr* expr); + + // Create a witness that attests to the fact that `type` + // is equal to itself. + TypeEqualityWitness* createTypeEqualityWitness(Type* type); + + // In the case where we are explicitly applying a generic + // to arguments (e.g., `G`) check that the constraints + // on those parameters are satisfied. + // + // Note: the constraints actually work as additional parameters/arguments + // of the generic, and so we need to reify them into the final + // argument list. + // + bool TryCheckOverloadCandidateConstraints( + OverloadResolveContext& context, + OverloadCandidate& candidate); + + // Try to check an overload candidate, but bail out + // if any step fails + void TryCheckOverloadCandidate(OverloadResolveContext& context, OverloadCandidate& candidate); + + // Create the representation of a given generic applied to some arguments + Expr* createGenericDeclRef(Expr* baseExpr, Expr* originalExpr, SubstitutionSet substSet); + + // Take an overload candidate that previously got through + // `TryCheckOverloadCandidate` above, and try to finish + // up the work and turn it into a real expression. + // + // If the candidate isn't actually applicable, this is + // where we'd start reporting the issue(s). + Expr* CompleteOverloadCandidate(OverloadResolveContext& context, OverloadCandidate& candidate); + + // Implement a comparison operation between overload candidates, + // so that the better candidate compares as less-than the other + int CompareOverloadCandidates(OverloadCandidate* left, OverloadCandidate* right); + + /// If `declRef` representations a specialization of a generic, returns the number of + /// specialized generic arguments. Otherwise, returns zero. + /// + Int getSpecializedParamCount(DeclRef const& declRef); + + /// Compare items `left` and `right` produced by lookup, to see if one should be favored for + /// overloading. + int CompareLookupResultItems(LookupResultItem const& left, LookupResultItem const& right); + + /// Compare items `left` and `right` being considered as overload candidates, and determine if + /// one should be favored for structural reasons. + int compareOverloadCandidateSpecificity( + LookupResultItem const& left, + LookupResultItem const& right); + + void AddOverloadCandidateInner(OverloadResolveContext& context, OverloadCandidate& candidate); + + void AddOverloadCandidate( + OverloadResolveContext& context, + OverloadCandidate& candidate, + ConversionCost baseCost); + + void AddHigherOrderOverloadCandidates( + Expr* funcExpr, + OverloadResolveContext& context, + ConversionCost baseCost); + + void AddFuncOverloadCandidate( + LookupResultItem item, + DeclRef funcDeclRef, + OverloadResolveContext& context, + ConversionCost baseCost); + + void AddFuncOverloadCandidate( + FuncType* /*funcType*/, + OverloadResolveContext& /*context*/, + ConversionCost baseCost); + + void AddFuncExprOverloadCandidate( + FuncType* funcType, + OverloadResolveContext& context, + Expr* expr, + ConversionCost baseCost); + + // Add a candidate callee for overload resolution, based on + // calling a particular `ConstructorDecl`. + void AddCtorOverloadCandidate( + LookupResultItem typeItem, + Type* type, + DeclRef ctorDeclRef, + OverloadResolveContext& context, + Type* resultType, + ConversionCost baseCost); + + // If the given declaration has generic parameters, then + // return the corresponding `GenericDecl` that holds the + // parameters, etc. This returns the immediate generic parent + // of `decl`, e.g. the generic for f, and *not* any indirect + // generic parents, such as P.f(). + GenericDecl* GetOuterGeneric(Decl* decl); + + // If `decl` is inside a generic, return that outer generic, + // otherwise returns `decl`. + Decl* getOuterGenericOrSelf(Decl* decl); + + // Find the next outer generic parent of `decl`, including + // indirect parents. + GenericDecl* findNextOuterGeneric(Decl* decl); + + struct ValUnificationContext + { + Index indexInTypePack = 0; + }; - Expr* CheckSwizzleExpr( - MemberExpr* memberRefExpr, - Type* baseElementType, - IntegerLiteralValue baseElementCount); + // Try to find a unification for two values + bool TryUnifyVals( + ConstraintSystem& constraints, + ValUnificationContext unificationContext, + Val* fst, + bool fstLVal, + Val* snd, + bool sndLVal); + + bool tryUnifyDeclRef( + ConstraintSystem& constraints, + ValUnificationContext unificationContext, + DeclRefBase* fst, + bool fstLVal, + DeclRefBase* snd, + bool sndLVal); + + bool tryUnifyGenericAppDeclRef( + ConstraintSystem& constraints, + ValUnificationContext unificationContext, + GenericAppDeclRef* fst, + bool fstLVal, + GenericAppDeclRef* snd, + bool sndLVal); + + bool TryUnifyTypeParam( + ConstraintSystem& constraints, + ValUnificationContext unificationContext, + GenericTypeParamDeclBase* typeParamDecl, + QualType type); + + bool TryUnifyIntParam( + ConstraintSystem& constraints, + ValUnificationContext unificationContext, + GenericValueParamDecl* paramDecl, + IntVal* val); + + bool TryUnifyIntParam( + ConstraintSystem& constraints, + ValUnificationContext unificationContext, + DeclRef const& varRef, + IntVal* val); + + bool TryUnifyTypesByStructuralMatch( + ConstraintSystem& constraints, + ValUnificationContext unificationContext, + QualType fst, + QualType snd); + + bool TryUnifyTypes( + ConstraintSystem& constraints, + ValUnificationContext unificationContext, + QualType fst, + QualType snd); + + bool TryUnifyConjunctionType( + ConstraintSystem& constraints, + ValUnificationContext unificationContext, + QualType fst, + QualType snd); + + void maybeUnifyUnconstraintIntParam( + ConstraintSystem& constraints, + ValUnificationContext unificationContext, + IntVal* param, + IntVal* arg, + bool paramIsLVal); + + // Is the candidate extension declaration actually applicable to the given type + DeclRef applyExtensionToType( + ExtensionDecl* extDecl, + Type* type, + Dictionary* additionalSubtypeWitnessesForType = nullptr); + + // Take a generic declaration that is being applied + // in a context and attempt to infer any missing generic + // arguments to form a `DeclRef` to the inner declaration + // that could be applicable in the context of the given + // overloaded call. + // Also computes a `baseCost` for the inferred arguments, + // so that we can prefer a more specialized generic candidate + // when there is ambiguity. For example, given + // ``` + // interface IBase; + // interface IDerived : IBase; + // struct Derived : IDerived {} + // void f1(T b) + // void f2(T b); + // ``` + // We will prefer f2 when seeing f(Derived()), because it takes + // less steps to upcast `Derived` to `IDerived` than it does + // to `IBase`. + // + DeclRef inferGenericArguments( + DeclRef genericDeclRef, + OverloadResolveContext& context, + ArrayView knownGenericArgs, + ConversionCost& outBaseCost, + List* innerParameterTypes = nullptr); + + void AddTypeOverloadCandidates(Type* type, OverloadResolveContext& context); + + void AddDeclRefOverloadCandidates( + LookupResultItem item, + OverloadResolveContext& context, + ConversionCost baseCost); + + void AddOverloadCandidates(LookupResult const& result, OverloadResolveContext& context); + + void AddOverloadCandidates(Expr* funcExpr, OverloadResolveContext& context); + + String getCallSignatureString(OverloadResolveContext& context); + + Expr* ResolveInvoke(InvokeExpr* expr); + + void AddGenericOverloadCandidate(LookupResultItem baseItem, OverloadResolveContext& context); + + void AddGenericOverloadCandidates(Expr* baseExpr, OverloadResolveContext& context); + + template + void trySetGenericToRayTracingWithParamAttribute( + LookupResultItem genericItem, + DeclRef genericDeclRef, + OverloadResolveContext& context); + + // Add overload candidates based on use of `genericDeclRef` + // in an ordinary function-call context (that is, where it + // has been applied to arguments using `()` and not `<>`). + // + // If some or all of the generic arguments to `genericDeclRef` + // are known at the call site, they should be passed in via + // `substWithKnownGenericArgs`. + // + void addOverloadCandidatesForCallToGeneric( + LookupResultItem genericItem, + OverloadResolveContext& context, + ArrayView knownGenericArgs); + + /// Check a generic application where the operands have already been checked. + Expr* checkGenericAppWithCheckedArgs(GenericAppExpr* genericAppExpr); + + Expr* CheckExpr(Expr* expr); + + + void compareMemoryQualifierOfParamToArgument(ParamDecl* paramIn, Expr* argIn); + Expr* CheckInvokeExprWithCheckedOperands(InvokeExpr* expr); + // Get the type to use when referencing a declaration + QualType GetTypeForDeclRef(DeclRef declRef, SourceLoc loc); + + // + // + // + + Expr* MaybeDereference(Expr* inExpr); + + Expr* CheckMatrixSwizzleExpr( + MemberExpr* memberRefExpr, + Type* baseElementType, + IntegerLiteralValue baseElementRowCount, + IntegerLiteralValue baseElementColCount); + + Expr* CheckMatrixSwizzleExpr( + MemberExpr* memberRefExpr, + Type* baseElementType, + IntVal* baseElementRowCount, + IntVal* baseElementColCount); + + Expr* checkTupleSwizzleExpr(MemberExpr* memberExpr, TupleType* baseTupleType); + + Expr* CheckSwizzleExpr( + MemberExpr* memberRefExpr, + Type* baseElementType, + IntegerLiteralValue baseElementCount); - Expr* CheckSwizzleExpr( - MemberExpr* memberRefExpr, - Type* baseElementType, - IntVal* baseElementCount); + Expr* CheckSwizzleExpr( + MemberExpr* memberRefExpr, + Type* baseElementType, + IntVal* baseElementCount); - // Check a member expr as a general member lookup. - // This is the default/fallback behavior if the base type isn't swizzlable. - Expr* checkGeneralMemberLookupExpr(MemberExpr* expr, Type* baseType); + // Check a member expr as a general member lookup. + // This is the default/fallback behavior if the base type isn't swizzlable. + Expr* checkGeneralMemberLookupExpr(MemberExpr* expr, Type* baseType); - /// Perform semantic checking of an assignment where the operands have already been checked. - Expr* checkAssignWithCheckedOperands(AssignExpr* expr); + /// Perform semantic checking of an assignment where the operands have already been checked. + Expr* checkAssignWithCheckedOperands(AssignExpr* expr); - // Look up a static member - // @param expr Can be StaticMemberExpr or MemberExpr - // @param baseExpression Is the underlying type expression determined from resolving expr - Expr* _lookupStaticMember(DeclRefExpr* expr, Expr* baseExpression); + // Look up a static member + // @param expr Can be StaticMemberExpr or MemberExpr + // @param baseExpression Is the underlying type expression determined from resolving expr + Expr* _lookupStaticMember(DeclRefExpr* expr, Expr* baseExpression); - Expr* visitStaticMemberExpr(StaticMemberExpr* expr); + Expr* visitStaticMemberExpr(StaticMemberExpr* expr); - /// Perform checking operations required for the "base" expression of a member-reference like `base.someField` - Expr* checkBaseForMemberExpr(Expr* baseExpr, bool& outNeedDeref); + /// Perform checking operations required for the "base" expression of a member-reference like + /// `base.someField` + Expr* checkBaseForMemberExpr(Expr* baseExpr, bool& outNeedDeref); - /// Prepare baseExpr for use as the base of a member expr. - /// This include inserting implicit open-existential operations as needed. - Expr* maybeInsertImplicitOpForMemberBase(Expr* baseExpr, bool& outNeedDeref); + /// Prepare baseExpr for use as the base of a member expr. + /// This include inserting implicit open-existential operations as needed. + Expr* maybeInsertImplicitOpForMemberBase(Expr* baseExpr, bool& outNeedDeref); - Expr* lookupMemberResultFailure( - DeclRefExpr* expr, - QualType const& baseType, - bool supressDiagnostic = false); + Expr* lookupMemberResultFailure( + DeclRefExpr* expr, + QualType const& baseType, + bool supressDiagnostic = false); - SharedSemanticsContext& operator=(const SharedSemanticsContext &) = delete; + SharedSemanticsContext& operator=(const SharedSemanticsContext&) = delete; - // + // - void importModuleIntoScope(Scope* scope, ModuleDecl* moduleDecl); - void importFileDeclIntoScope(Scope* scope, FileDecl* fileDecl); + void importModuleIntoScope(Scope* scope, ModuleDecl* moduleDecl); + void importFileDeclIntoScope(Scope* scope, FileDecl* fileDecl); - void suggestCompletionItems( - CompletionSuggestions::ScopeKind scopeKind, LookupResult const& lookupResult); - }; + void suggestCompletionItems( + CompletionSuggestions::ScopeKind scopeKind, + LookupResult const& lookupResult); +}; - inline void ensureDecl(SemanticsVisitor* visitor, Decl* decl, DeclCheckState state) - { - visitor->ensureDecl(decl, state); - } +inline void ensureDecl(SemanticsVisitor* visitor, Decl* decl, DeclCheckState state) +{ + visitor->ensureDecl(decl, state); +} - DeclRef applyExtensionToType( - SemanticsVisitor* semantics, - ExtensionDecl* extDecl, - Type* type, - Dictionary* additionalSubtypeWitness = nullptr); +DeclRef applyExtensionToType( + SemanticsVisitor* semantics, + ExtensionDecl* extDecl, + Type* type, + Dictionary* additionalSubtypeWitness = nullptr); - struct SemanticsExprVisitor - : public SemanticsVisitor - , ExprVisitor +struct SemanticsExprVisitor : public SemanticsVisitor, ExprVisitor +{ +public: + SemanticsExprVisitor(SemanticsContext const& outer) + : SemanticsVisitor(outer) { - public: - SemanticsExprVisitor(SemanticsContext const& outer) - : SemanticsVisitor(outer) - {} + } - Expr* visitSizeOfLikeExpr(SizeOfLikeExpr* expr); + Expr* visitSizeOfLikeExpr(SizeOfLikeExpr* expr); - Expr* visitIncompleteExpr(IncompleteExpr* expr); - Expr* visitBoolLiteralExpr(BoolLiteralExpr* expr); - Expr* visitNullPtrLiteralExpr(NullPtrLiteralExpr* expr); - Expr* visitNoneLiteralExpr(NoneLiteralExpr* expr); - Expr* visitIntegerLiteralExpr(IntegerLiteralExpr* expr); - Expr* visitFloatingPointLiteralExpr(FloatingPointLiteralExpr* expr); - Expr* visitStringLiteralExpr(StringLiteralExpr* expr); + Expr* visitIncompleteExpr(IncompleteExpr* expr); + Expr* visitBoolLiteralExpr(BoolLiteralExpr* expr); + Expr* visitNullPtrLiteralExpr(NullPtrLiteralExpr* expr); + Expr* visitNoneLiteralExpr(NoneLiteralExpr* expr); + Expr* visitIntegerLiteralExpr(IntegerLiteralExpr* expr); + Expr* visitFloatingPointLiteralExpr(FloatingPointLiteralExpr* expr); + Expr* visitStringLiteralExpr(StringLiteralExpr* expr); - Expr* visitIndexExpr(IndexExpr* subscriptExpr); + Expr* visitIndexExpr(IndexExpr* subscriptExpr); - Expr* visitParenExpr(ParenExpr* expr); + Expr* visitParenExpr(ParenExpr* expr); - Expr* visitAssignExpr(AssignExpr* expr); + Expr* visitAssignExpr(AssignExpr* expr); - Expr* visitGenericAppExpr(GenericAppExpr* genericAppExpr); + Expr* visitGenericAppExpr(GenericAppExpr* genericAppExpr); - Expr* visitSharedTypeExpr(SharedTypeExpr* expr); + Expr* visitSharedTypeExpr(SharedTypeExpr* expr); - Expr* visitInvokeExpr(InvokeExpr *expr); + Expr* visitInvokeExpr(InvokeExpr* expr); - Expr* visitSelectExpr(SelectExpr* expr); + Expr* visitSelectExpr(SelectExpr* expr); - Expr* visitVarExpr(VarExpr *expr); + Expr* visitVarExpr(VarExpr* expr); - Expr* visitTypeCastExpr(TypeCastExpr * expr); + Expr* visitTypeCastExpr(TypeCastExpr* expr); - Expr* visitBuiltinCastExpr(BuiltinCastExpr* expr); + Expr* visitBuiltinCastExpr(BuiltinCastExpr* expr); - Expr* visitTryExpr(TryExpr* expr); + Expr* visitTryExpr(TryExpr* expr); - Expr* visitIsTypeExpr(IsTypeExpr* expr); + Expr* visitIsTypeExpr(IsTypeExpr* expr); - Expr* visitAsTypeExpr(AsTypeExpr* expr); + Expr* visitAsTypeExpr(AsTypeExpr* expr); - Expr* visitExpandExpr(ExpandExpr* expr); + Expr* visitExpandExpr(ExpandExpr* expr); - Expr* visitEachExpr(EachExpr* expr); + Expr* visitEachExpr(EachExpr* expr); - void maybeCheckKnownBuiltinInvocation(Expr* invokeExpr); + void maybeCheckKnownBuiltinInvocation(Expr* invokeExpr); - // - // Some syntax nodes should not occur in the concrete input syntax, - // and will only appear *after* checking is complete. We need to - // deal with this cases here, even if they are no-ops. - // + // + // Some syntax nodes should not occur in the concrete input syntax, + // and will only appear *after* checking is complete. We need to + // deal with this cases here, even if they are no-ops. + // - #define CASE(NAME) \ - Expr* visit##NAME(NAME* expr) \ - { \ - if (!getShared()->isInLanguageServer()) \ - SLANG_DIAGNOSE_UNEXPECTED(getSink(), expr, "should not appear in input syntax"); \ - expr->type = m_astBuilder->getErrorType(); \ - return expr; \ - } +#define CASE(NAME) \ + Expr* visit##NAME(NAME* expr) \ + { \ + if (!getShared()->isInLanguageServer()) \ + SLANG_DIAGNOSE_UNEXPECTED(getSink(), expr, "should not appear in input syntax"); \ + expr->type = m_astBuilder->getErrorType(); \ + return expr; \ + } - CASE(DerefExpr) - CASE(MakeRefExpr) - CASE(MatrixSwizzleExpr) - CASE(SwizzleExpr) - CASE(OverloadedExpr) - CASE(OverloadedExpr2) - CASE(AggTypeCtorExpr) - CASE(ModifierCastExpr) - CASE(LetExpr) - CASE(ExtractExistentialValueExpr) - CASE(OpenRefExpr) - CASE(MakeOptionalExpr) - CASE(PartiallyAppliedGenericExpr) - CASE(PackExpr) - #undef CASE + CASE(DerefExpr) + CASE(MakeRefExpr) + CASE(MatrixSwizzleExpr) + CASE(SwizzleExpr) + CASE(OverloadedExpr) + CASE(OverloadedExpr2) + CASE(AggTypeCtorExpr) + CASE(ModifierCastExpr) + CASE(LetExpr) + CASE(ExtractExistentialValueExpr) + CASE(OpenRefExpr) + CASE(MakeOptionalExpr) + CASE(PartiallyAppliedGenericExpr) + CASE(PackExpr) +#undef CASE - Expr* visitStaticMemberExpr(StaticMemberExpr* expr); + Expr* visitStaticMemberExpr(StaticMemberExpr* expr); - Expr* visitMemberExpr(MemberExpr * expr); + Expr* visitMemberExpr(MemberExpr* expr); - Expr* visitInitializerListExpr(InitializerListExpr* expr); + Expr* visitInitializerListExpr(InitializerListExpr* expr); - Expr* visitThisExpr(ThisExpr* expr); - Expr* visitThisTypeExpr(ThisTypeExpr* expr); - Expr* visitCastToSuperTypeExpr(CastToSuperTypeExpr* expr); - Expr* visitReturnValExpr(ReturnValExpr* expr); - Expr* visitAndTypeExpr(AndTypeExpr* expr); - Expr* visitPointerTypeExpr(PointerTypeExpr* expr); - Expr* visitModifiedTypeExpr(ModifiedTypeExpr* expr); - Expr* visitFuncTypeExpr(FuncTypeExpr* expr); - Expr* visitTupleTypeExpr(TupleTypeExpr* expr); + Expr* visitThisExpr(ThisExpr* expr); + Expr* visitThisTypeExpr(ThisTypeExpr* expr); + Expr* visitCastToSuperTypeExpr(CastToSuperTypeExpr* expr); + Expr* visitReturnValExpr(ReturnValExpr* expr); + Expr* visitAndTypeExpr(AndTypeExpr* expr); + Expr* visitPointerTypeExpr(PointerTypeExpr* expr); + Expr* visitModifiedTypeExpr(ModifiedTypeExpr* expr); + Expr* visitFuncTypeExpr(FuncTypeExpr* expr); + Expr* visitTupleTypeExpr(TupleTypeExpr* expr); - Expr* visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr); - Expr* visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr); - Expr* visitPrimalSubstituteExpr(PrimalSubstituteExpr* expr); - Expr* visitDispatchKernelExpr(DispatchKernelExpr* expr); + Expr* visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr); + Expr* visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr); + Expr* visitPrimalSubstituteExpr(PrimalSubstituteExpr* expr); + Expr* visitDispatchKernelExpr(DispatchKernelExpr* expr); - Expr* visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr); + Expr* visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr); - Expr* visitGetArrayLengthExpr(GetArrayLengthExpr* expr); + Expr* visitGetArrayLengthExpr(GetArrayLengthExpr* expr); - Expr* visitDefaultConstructExpr(DefaultConstructExpr* expr); + Expr* visitDefaultConstructExpr(DefaultConstructExpr* expr); - Expr* visitDetachExpr(DetachExpr* expr); + Expr* visitDetachExpr(DetachExpr* expr); - Expr* visitSPIRVAsmExpr(SPIRVAsmExpr*); + Expr* visitSPIRVAsmExpr(SPIRVAsmExpr*); - /// Perform semantic checking on a `modifier` that is being applied to the given `type` - Val* checkTypeModifier(Modifier* modifier, Type* type); - private: - // Convert the logic operator expression to not use 'InvokeExpr' type - Expr* convertToLogicOperatorExpr(InvokeExpr* expr); + /// Perform semantic checking on a `modifier` that is being applied to the given `type` + Val* checkTypeModifier(Modifier* modifier, Type* type); - }; +private: + // Convert the logic operator expression to not use 'InvokeExpr' type + Expr* convertToLogicOperatorExpr(InvokeExpr* expr); +}; - struct SemanticsStmtVisitor - : public SemanticsVisitor - , StmtVisitor +struct SemanticsStmtVisitor : public SemanticsVisitor, StmtVisitor +{ + SemanticsStmtVisitor(SemanticsContext const& outer) + : SemanticsVisitor(outer) { - SemanticsStmtVisitor(SemanticsContext const& outer) - : SemanticsVisitor(outer) - {} + } - FunctionDeclBase* getParentFunc() { return m_parentFunc; } + FunctionDeclBase* getParentFunc() { return m_parentFunc; } - void checkStmt(Stmt* stmt); + void checkStmt(Stmt* stmt); - template - T* FindOuterStmt(); + template + T* FindOuterStmt(); - Stmt* findOuterStmtWithLabel(Name* label); + Stmt* findOuterStmtWithLabel(Name* label); - void visitDeclStmt(DeclStmt* stmt); + void visitDeclStmt(DeclStmt* stmt); - void visitBlockStmt(BlockStmt* stmt); + void visitBlockStmt(BlockStmt* stmt); - void visitSeqStmt(SeqStmt* stmt); + void visitSeqStmt(SeqStmt* stmt); - void visitLabelStmt(LabelStmt* stmt); + void visitLabelStmt(LabelStmt* stmt); - void visitBreakStmt(BreakStmt *stmt); + void visitBreakStmt(BreakStmt* stmt); - void visitContinueStmt(ContinueStmt *stmt); + void visitContinueStmt(ContinueStmt* stmt); - void visitDoWhileStmt(DoWhileStmt *stmt); + void visitDoWhileStmt(DoWhileStmt* stmt); - void visitForStmt(ForStmt *stmt); + void visitForStmt(ForStmt* stmt); - void visitCompileTimeForStmt(CompileTimeForStmt* stmt); + void visitCompileTimeForStmt(CompileTimeForStmt* stmt); - void visitSwitchStmt(SwitchStmt* stmt); + void visitSwitchStmt(SwitchStmt* stmt); - void visitCaseStmt(CaseStmt* stmt); + void visitCaseStmt(CaseStmt* stmt); - void visitTargetSwitchStmt(TargetSwitchStmt* stmt); + void visitTargetSwitchStmt(TargetSwitchStmt* stmt); - void visitTargetCaseStmt(TargetCaseStmt* stmt); + void visitTargetCaseStmt(TargetCaseStmt* stmt); - void visitIntrinsicAsmStmt(IntrinsicAsmStmt*); + void visitIntrinsicAsmStmt(IntrinsicAsmStmt*); - void visitDefaultStmt(DefaultStmt* stmt); + void visitDefaultStmt(DefaultStmt* stmt); - void visitIfStmt(IfStmt *stmt); + void visitIfStmt(IfStmt* stmt); - void visitUnparsedStmt(UnparsedStmt*); + void visitUnparsedStmt(UnparsedStmt*); - void visitEmptyStmt(EmptyStmt*); + void visitEmptyStmt(EmptyStmt*); - void visitDiscardStmt(DiscardStmt*); + void visitDiscardStmt(DiscardStmt*); - void visitReturnStmt(ReturnStmt *stmt); + void visitReturnStmt(ReturnStmt* stmt); - void visitWhileStmt(WhileStmt *stmt); - - void visitGpuForeachStmt(GpuForeachStmt *stmt); + void visitWhileStmt(WhileStmt* stmt); - void visitExpressionStmt(ExpressionStmt *stmt); + void visitGpuForeachStmt(GpuForeachStmt* stmt); - // Try to infer the max number of iterations the loop will run. - void tryInferLoopMaxIterations(ForStmt* stmt); + void visitExpressionStmt(ExpressionStmt* stmt); - void checkLoopInDifferentiableFunc(Stmt* stmt); + // Try to infer the max number of iterations the loop will run. + void tryInferLoopMaxIterations(ForStmt* stmt); - private: - void validateCaseStmts(SwitchStmt* stmt, DiagnosticSink* sink); - }; + void checkLoopInDifferentiableFunc(Stmt* stmt); - struct SemanticsDeclVisitorBase - : public SemanticsVisitor +private: + void validateCaseStmts(SwitchStmt* stmt, DiagnosticSink* sink); +}; + +struct SemanticsDeclVisitorBase : public SemanticsVisitor +{ + SemanticsDeclVisitorBase(SemanticsContext const& outer) + : SemanticsVisitor(outer) { - SemanticsDeclVisitorBase(SemanticsContext const& outer) - : SemanticsVisitor(outer) - {} + } - void checkBodyStmt(Stmt* stmt, FunctionDeclBase* parentDecl) - { - checkStmt(stmt, withParentFunc(parentDecl)); - } + void checkBodyStmt(Stmt* stmt, FunctionDeclBase* parentDecl) + { + checkStmt(stmt, withParentFunc(parentDecl)); + } - void checkModule(ModuleDecl* programNode); - }; + void checkModule(ModuleDecl* programNode); +}; - bool isUnsizedArrayType(Type* type); +bool isUnsizedArrayType(Type* type); - bool isInterfaceType(Type* type); +bool isInterfaceType(Type* type); - EnumDecl* isEnumType(Type* type); +EnumDecl* isEnumType(Type* type); - DeclVisibility getDeclVisibility(Decl* decl); +DeclVisibility getDeclVisibility(Decl* decl); - // If `type` is unsized, return the trailing unsized array field that makes it so. - VarDeclBase* getTrailingUnsizedArrayElement(Type* type, VarDeclBase* rootObject, ArrayExpressionType*& outArrayType); +// If `type` is unsized, return the trailing unsized array field that makes it so. +VarDeclBase* getTrailingUnsizedArrayElement( + Type* type, + VarDeclBase* rootObject, + ArrayExpressionType*& outArrayType); - // Test if `type` can be an opaque handle on certain targets, this includes - // texture, buffer, sampler, acceleration structure, etc. - bool isOpaqueHandleType(Type* type); +// Test if `type` can be an opaque handle on certain targets, this includes +// texture, buffer, sampler, acceleration structure, etc. +bool isOpaqueHandleType(Type* type); - void diagnoseMissingCapabilityProvenance(CompilerOptionSet& optionSet, DiagnosticSink* sink, Decl* decl, CapabilitySet& setToFind); - void diagnoseCapabilityProvenance(CompilerOptionSet& optionSet, DiagnosticSink* sink, Decl* decl, CapabilityAtom atomToFind, HashSet& printedDecls); +void diagnoseMissingCapabilityProvenance( + CompilerOptionSet& optionSet, + DiagnosticSink* sink, + Decl* decl, + CapabilitySet& setToFind); +void diagnoseCapabilityProvenance( + CompilerOptionSet& optionSet, + DiagnosticSink* sink, + Decl* decl, + CapabilityAtom atomToFind, + HashSet& printedDecls); - void _ensureAllDeclsRec( - SemanticsDeclVisitorBase* visitor, - Decl* decl, - DeclCheckState state); +void _ensureAllDeclsRec(SemanticsDeclVisitorBase* visitor, Decl* decl, DeclCheckState state); - RefPtr findAndValidateEntryPoint( - FrontEndEntryPointRequest* entryPointReq); +RefPtr findAndValidateEntryPoint(FrontEndEntryPointRequest* entryPointReq); - bool resolveStageOfProfileWithEntryPoint(Profile& entryPointProfile, CompilerOptionSet& optionSet, const List>& targets, FuncDecl* entryPointFuncDecl, DiagnosticSink* sink); -} +bool resolveStageOfProfileWithEntryPoint( + Profile& entryPointProfile, + CompilerOptionSet& optionSet, + const List>& targets, + FuncDecl* entryPointFuncDecl, + DiagnosticSink* sink); +} // namespace Slang diff --git a/source/slang/slang-check-inheritance.cpp b/source/slang/slang-check-inheritance.cpp index 360fc0d14..4b0ec0f55 100644 --- a/source/slang/slang-check-inheritance.cpp +++ b/source/slang/slang-check-inheritance.cpp @@ -7,1135 +7,1166 @@ namespace Slang { - InheritanceInfo SharedSemanticsContext::getInheritanceInfo(Type* type, InheritanceCircularityInfo* circularityInfo) - { - // We cache the computed inheritance information for types, - // and re-use that information whenever possible. - - // DeclRefTypes will have their inheritance info cached in m_mapDeclRefToInheritanceInfo. - if (auto declRefType = as(type)) - return _getInheritanceInfo(declRefType->getDeclRef(), declRefType, circularityInfo); - - // Non ordinary types are cached on m_mapTypeToInheritanceInfo. - if (auto found = m_mapTypeToInheritanceInfo.tryGetValue(type)) - return *found; - - // Note: we install a null pointer into the dictionary to act - // as a sentinel during the processing of calculating the inheritnace - // info. If we encounter this sentinel value during the calcuation, - // it means that there was some kind of circular dependency in the - // inheritance graph, and we need to avoid crashing or going - // into an infinite loop in such cases. - // - m_mapTypeToInheritanceInfo[type] = InheritanceInfo(); +InheritanceInfo SharedSemanticsContext::getInheritanceInfo( + Type* type, + InheritanceCircularityInfo* circularityInfo) +{ + // We cache the computed inheritance information for types, + // and re-use that information whenever possible. + + // DeclRefTypes will have their inheritance info cached in m_mapDeclRefToInheritanceInfo. + if (auto declRefType = as(type)) + return _getInheritanceInfo(declRefType->getDeclRef(), declRefType, circularityInfo); + + // Non ordinary types are cached on m_mapTypeToInheritanceInfo. + if (auto found = m_mapTypeToInheritanceInfo.tryGetValue(type)) + return *found; + + // Note: we install a null pointer into the dictionary to act + // as a sentinel during the processing of calculating the inheritnace + // info. If we encounter this sentinel value during the calcuation, + // it means that there was some kind of circular dependency in the + // inheritance graph, and we need to avoid crashing or going + // into an infinite loop in such cases. + // + m_mapTypeToInheritanceInfo[type] = InheritanceInfo(); - auto info = _calcInheritanceInfo(type, circularityInfo); - m_mapTypeToInheritanceInfo[type] = info; + auto info = _calcInheritanceInfo(type, circularityInfo); + m_mapTypeToInheritanceInfo[type] = info; - return info; - } + return info; +} - InheritanceInfo SharedSemanticsContext::getInheritanceInfo(DeclRef const& extension, InheritanceCircularityInfo* circularityInfo) +InheritanceInfo SharedSemanticsContext::getInheritanceInfo( + DeclRef const& extension, + InheritanceCircularityInfo* circularityInfo) +{ + if (_checkForCircularityInExtensionTargetType(extension.getDecl(), circularityInfo)) { - if (_checkForCircularityInExtensionTargetType(extension.getDecl(), circularityInfo)) - { - // If we detect a circularity in the inheritance graph, - // we will return an empty `InheritanceInfo` to avoid - // infinite recursion. - // - return InheritanceInfo(); - } - - // We bottleneck the calculation of inheritance information - // for type and `extension` `DeclRef`s through a single - // routine with an optional `Type` parameter. + // If we detect a circularity in the inheritance graph, + // we will return an empty `InheritanceInfo` to avoid + // infinite recursion. // - InheritanceCircularityInfo newCircularityInfo(extension.getDecl(), circularityInfo); - return _getInheritanceInfo(extension, nullptr, &newCircularityInfo); + return InheritanceInfo(); } - bool SharedSemanticsContext::_checkForCircularityInExtensionTargetType( - Decl* decl, - InheritanceCircularityInfo* circularityInfo) + // We bottleneck the calculation of inheritance information + // for type and `extension` `DeclRef`s through a single + // routine with an optional `Type` parameter. + // + InheritanceCircularityInfo newCircularityInfo(extension.getDecl(), circularityInfo); + return _getInheritanceInfo(extension, nullptr, &newCircularityInfo); +} + +bool SharedSemanticsContext::_checkForCircularityInExtensionTargetType( + Decl* decl, + InheritanceCircularityInfo* circularityInfo) +{ + for (auto info = circularityInfo; info; info = info->next) { - for (auto info = circularityInfo; info; info = info->next) + if (decl == info->decl) { - if (decl == info->decl) - { - getSink()->diagnose(decl, Diagnostics::circularityInExtension, decl); - return true; - } + getSink()->diagnose(decl, Diagnostics::circularityInExtension, decl); + return true; } - - return false; } - InheritanceInfo SharedSemanticsContext::_getInheritanceInfo(DeclRef declRef, DeclRefType* declRefType, InheritanceCircularityInfo* circularityInfo) - { - // Just as with `Type`s, we cache and re-use the inheritance - // information that has been computed for a `DeclRef` whenever - // possible. - - if (auto found = m_mapDeclRefToInheritanceInfo.tryGetValue(declRef)) - return *found; - - // Note: we install a null pointer into the dictionary to act - // as a sentinel during the processing of calculating the inheritnace - // info. If we encounter this sentinel value during the calcuation, - // it means that there was some kind of circular dependency in the - // inheritance graph, and we need to avoid crashing or going - // into an infinite loop in such cases. - // - m_mapDeclRefToInheritanceInfo[declRef] = InheritanceInfo(); + return false; +} - auto info = _calcInheritanceInfo(declRef, declRefType, circularityInfo); - m_mapDeclRefToInheritanceInfo[declRef] = info; +InheritanceInfo SharedSemanticsContext::_getInheritanceInfo( + DeclRef declRef, + DeclRefType* declRefType, + InheritanceCircularityInfo* circularityInfo) +{ + // Just as with `Type`s, we cache and re-use the inheritance + // information that has been computed for a `DeclRef` whenever + // possible. + + if (auto found = m_mapDeclRefToInheritanceInfo.tryGetValue(declRef)) + return *found; + + // Note: we install a null pointer into the dictionary to act + // as a sentinel during the processing of calculating the inheritnace + // info. If we encounter this sentinel value during the calcuation, + // it means that there was some kind of circular dependency in the + // inheritance graph, and we need to avoid crashing or going + // into an infinite loop in such cases. + // + m_mapDeclRefToInheritanceInfo[declRef] = InheritanceInfo(); - getSession()->m_typeDictionarySize = Math::Max( - getSession()->m_typeDictionarySize, (int)m_mapDeclRefToInheritanceInfo.getCount()); + auto info = _calcInheritanceInfo(declRef, declRefType, circularityInfo); + m_mapDeclRefToInheritanceInfo[declRef] = info; - return info; - } + getSession()->m_typeDictionarySize = Math::Max( + getSession()->m_typeDictionarySize, + (int)m_mapDeclRefToInheritanceInfo.getCount()); - void SharedSemanticsContext::getDependentGenericParentImpl(DeclRef& genericParent, DeclRef declRef) - { - auto mergeParent = [](DeclRef& currentParent, DeclRef newParent) - { - if (!currentParent) - { - currentParent = newParent; - return; - } - if (currentParent == newParent) - return; - if (newParent.getDecl()->isChildOf(currentParent.getDecl())) - currentParent = newParent; - }; + return info; +} - if (declRef.as()) +void SharedSemanticsContext::getDependentGenericParentImpl( + DeclRef& genericParent, + DeclRef declRef) +{ + auto mergeParent = [](DeclRef& currentParent, DeclRef newParent) + { + if (!currentParent) { - if (!genericParent) - mergeParent(genericParent, declRef.getParent().as()); + currentParent = newParent; return; } - else if (auto lookupDeclRef = as(declRef.declRefBase)) - { - if (auto lookupSourceDeclRef = isDeclRefTypeOf(lookupDeclRef->getLookupSource())) - getDependentGenericParentImpl(genericParent, lookupSourceDeclRef); - } - else if (auto genericAppDeclRef = as(declRef.declRefBase)) + if (currentParent == newParent) + return; + if (newParent.getDecl()->isChildOf(currentParent.getDecl())) + currentParent = newParent; + }; + + if (declRef.as()) + { + if (!genericParent) + mergeParent(genericParent, declRef.getParent().as()); + return; + } + else if (auto lookupDeclRef = as(declRef.declRefBase)) + { + if (auto lookupSourceDeclRef = isDeclRefTypeOf(lookupDeclRef->getLookupSource())) + getDependentGenericParentImpl(genericParent, lookupSourceDeclRef); + } + else if (auto genericAppDeclRef = as(declRef.declRefBase)) + { + for (Index i = 0; i < genericAppDeclRef->getArgCount(); i++) { - for (Index i = 0; i < genericAppDeclRef->getArgCount(); i++) + if (auto argDeclRef = isDeclRefTypeOf(genericAppDeclRef->getArg(i))) { - if (auto argDeclRef = isDeclRefTypeOf(genericAppDeclRef->getArg(i))) - { - getDependentGenericParentImpl(genericParent, argDeclRef); - } + getDependentGenericParentImpl(genericParent, argDeclRef); } } } +} + +DeclRef SharedSemanticsContext::getDependentGenericParent(DeclRef declRef) +{ + DeclRef genericParent; + getDependentGenericParentImpl(genericParent, declRef); + return genericParent; +} - DeclRef SharedSemanticsContext::getDependentGenericParent(DeclRef declRef) +InheritanceInfo SharedSemanticsContext::_calcInheritanceInfo( + DeclRef declRef, + DeclRefType* declRefType, + InheritanceCircularityInfo* circularityInfo) +{ + // This method is the main engine for computing linearized inheritance + // lists for types and `extension` declarations. + // + // The approach we use for linearization of an inheritance graph is based on + // what is the most broadly-accepted solution to the problem, presented in + // "A Monotonic Superclass Linearization for Dylan" by Barret et al. + // The algorithm recommended in that paper is also called the "C3 linearization + // algorithm." Many developers are most familiar with C3 linearization because + // it is used to compute the method resolution order (MRO) for a class in Python. + // + // The basic idea is that given a type declaration like: + // + // class A : B, C { ... } + // + // we can construct a linearization of the transitive bases of `A` + // by merging the linearizations for `B` and `C`. Any transitive + // base of `A` should appear in the linearization for `B` and/or `C`, + // so the main tasks are to remove duplicates (when a base type appears + // in both the linearization of `B` *and* `C`), and to ensure that + // the ordering is reasonable. + // + // What makes an ordering "reasonable" is a little subtle, especially + // in the context of Slang. In the original use case, the order of types + // in the linearization would determine which methods would override + // which other ones, so different orderings could have large semantic + // impact. Slang currently has less support for overriding, but is + // likely to add more over time. + // + // At the very least, we require that if `S <: T` for types `S` and `T`, + // then `S` should appear *before* `T` in the linearization. This, e.g., + // guarantees that a concrete type that implements an `interface` will + // be listed before that interface and thus during lookup the members + // of the concrete type will be found before those of the `interface`. + // + // We will revisit the question of "reasonable" orderings later, as + // we get more into the core of the algorithm. + + // Our linearization approach will construct a list of *facets* for + // the `declRef` in question, where each facet corresponds to a + // transitive base type, or an applicable `extension`. + // + FacetList::Builder allFacets; + + // It is possible that `declRef` is itself a type declaration, + // in which case `declRefType` will be the coresponding type. + // However, if `declRef` is an `extension` declaration, we + // will extract the type that the extension applies to, so + // that we can have a consistent "self type" to represent + // the type that is at the root of the inheritance list. + // + Type* selfType = declRefType; + Facet::Kind selfFacetKind = Facet::Kind::Type; + + auto astBuilder = _getASTBuilder(); + auto& arena = astBuilder->getArena(); + SemanticsVisitor visitor(this); + if (auto extensionDeclRef = declRef.as()) { - DeclRef genericParent; - getDependentGenericParentImpl(genericParent, declRef); - return genericParent; + auto extendedType = getTargetType(astBuilder, extensionDeclRef); + selfType = extendedType; + selfFacetKind = Facet::Kind::Extension; } - InheritanceInfo SharedSemanticsContext::_calcInheritanceInfo(DeclRef declRef, DeclRefType* declRefType, InheritanceCircularityInfo* circularityInfo) + // Because we are dealing with entities that have declarations, the + // first item in our linearization will always be a facet for + // the declaration itself. + // + TypeEqualityWitness* selfIsSelf = + selfType ? visitor.createTypeEqualityWitness(selfType) : nullptr; + Facet selfFacet = new (arena) + Facet::Impl(selfFacetKind, Facet::Directness::Self, declRef, selfType, selfIsSelf); + allFacets.add(selfFacet); + + // After the self facet will come a list of facets formed + // by merging the facet lists of each of the direct/declared + // bases of the type/declaration in question. + // + // We will first traverse the structure of `declRef` to + // accumulate the list of bases, and then perform the merge + // when we are done. + // + DirectBaseList::Builder directBases; + FacetList::Builder directBaseFacets; + + // We start with a simple operation to add an entry + // into the list of direct bases, for the case where + // we already have all of the relevant information + // about that base. + // + auto addDirectBaseFacet = [&](Facet::Kind kind, + Type* baseType, + SubtypeWitness* selfIsBaseWitness, + DeclRef const& baseDeclRef, + InheritanceInfo const& baseInheritanceInfo) { - // This method is the main engine for computing linearized inheritance - // lists for types and `extension` declarations. - // - // The approach we use for linearization of an inheritance graph is based on - // what is the most broadly-accepted solution to the problem, presented in - // "A Monotonic Superclass Linearization for Dylan" by Barret et al. - // The algorithm recommended in that paper is also called the "C3 linearization - // algorithm." Many developers are most familiar with C3 linearization because - // it is used to compute the method resolution order (MRO) for a class in Python. - // - // The basic idea is that given a type declaration like: - // - // class A : B, C { ... } + auto baseInfo = new (arena) DirectBaseInfo(); + + // The information we store for each direct + // base comprises two main things. // - // we can construct a linearization of the transitive bases of `A` - // by merging the linearizations for `B` and `C`. Any transitive - // base of `A` should appear in the linearization for `B` and/or `C`, - // so the main tasks are to remove duplicates (when a base type appears - // in both the linearization of `B` *and* `C`), and to ensure that - // the ordering is reasonable. + // First, we have a `Facet` that will represent + // the base in the linearized inheritance list + // we are building. // - // What makes an ordering "reasonable" is a little subtle, especially - // in the context of Slang. In the original use case, the order of types - // in the linearization would determine which methods would override - // which other ones, so different orderings could have large semantic - // impact. Slang currently has less support for overriding, but is - // likely to add more over time. + baseInfo->facetImpl = + FacetImpl(kind, Facet::Directness::Direct, baseDeclRef, baseType, selfIsBaseWitness); + Facet baseFacet(&baseInfo->facetImpl); // - // At the very least, we require that if `S <: T` for types `S` and `T`, - // then `S` should appear *before* `T` in the linearization. This, e.g., - // guarantees that a concrete type that implements an `interface` will - // be listed before that interface and thus during lookup the members - // of the concrete type will be found before those of the `interface`. + // Second, we have a list of the facets in the + // linearization of the base itself. // - // We will revisit the question of "reasonable" orderings later, as - // we get more into the core of the algorithm. + baseInfo->facets = baseInheritanceInfo.facets; - // Our linearization approach will construct a list of *facets* for - // the `declRef` in question, where each facet corresponds to a - // transitive base type, or an applicable `extension`. - // - FacetList::Builder allFacets; - - // It is possible that `declRef` is itself a type declaration, - // in which case `declRefType` will be the coresponding type. - // However, if `declRef` is an `extension` declaration, we - // will extract the type that the extension applies to, so - // that we can have a consistent "self type" to represent - // the type that is at the root of the inheritance list. + directBaseFacets.add(baseFacet); + directBases.add(baseInfo); + }; + + // In the case where we know that the base being added + // represents a direct base *type* (and not an `extension`) + // we can derive some of the information needed by + // `addDirectBaseFacet`. + // + auto addDirectBaseType = [&](Type* baseType, SubtypeWitness* selfIsBaseWitness) + { + // If we are representing inheritance from a type, + // then we should have a witness that the type + // in question (either the type being declared by + // `declRef`, or the type being *extended* by + // `declRef`) inherits from that base. // - Type* selfType = declRefType; - Facet::Kind selfFacetKind = Facet::Kind::Type; + SLANG_ASSERT(selfIsBaseWitness); - auto astBuilder = _getASTBuilder(); - auto& arena = astBuilder->getArena(); - SemanticsVisitor visitor(this); - if (auto extensionDeclRef = declRef.as()) + auto baseInheritanceInfo = getInheritanceInfo(baseType, circularityInfo); + + DeclRef baseDeclRef; + if (auto baseDeclRefType = as(baseType)) { - auto extendedType = getTargetType(astBuilder, extensionDeclRef); - selfType = extendedType; - selfFacetKind = Facet::Kind::Extension; + baseDeclRef = baseDeclRefType->getDeclRef(); } - // Because we are dealing with entities that have declarations, the - // first item in our linearization will always be a facet for - // the declaration itself. - // - TypeEqualityWitness* selfIsSelf = selfType ? visitor.createTypeEqualityWitness(selfType) : nullptr; - Facet selfFacet = new(arena) Facet::Impl( - selfFacetKind, - Facet::Directness::Self, - declRef, - selfType, - selfIsSelf); - allFacets.add(selfFacet); - - // After the self facet will come a list of facets formed - // by merging the facet lists of each of the direct/declared - // bases of the type/declaration in question. - // - // We will first traverse the structure of `declRef` to - // accumulate the list of bases, and then perform the merge - // when we are done. - // - DirectBaseList::Builder directBases; - FacetList::Builder directBaseFacets; + addDirectBaseFacet( + Facet::Kind::Type, + baseType, + selfIsBaseWitness, + baseDeclRef, + baseInheritanceInfo); + }; - // We start with a simple operation to add an entry - // into the list of direct bases, for the case where - // we already have all of the relevant information - // about that base. - // - auto addDirectBaseFacet = [&]( - Facet::Kind kind, - Type* baseType, - SubtypeWitness* selfIsBaseWitness, - DeclRef const& baseDeclRef, - InheritanceInfo const& baseInheritanceInfo) + // If we know the type has a facet represented by `extensionTargetDeclRef`, we can consider + // all extensions on this decl to see if they apply to the type. + // + auto considerExtension = [&](DeclRef extensionTargetDeclRef, + Dictionary* additionalSubtypeWitness) + { + bool result = false; + for (auto extDecl : getCandidateExtensions(extensionTargetDeclRef, &visitor)) { - auto baseInfo = new(arena) DirectBaseInfo(); - - // The information we store for each direct - // base comprises two main things. + // The list of *candidate* extensions is computed and + // cached based on the identity of the declaration alone, + // and does not take into account any generic arguments + // of either the type or the `extension`. // - // First, we have a `Facet` that will represent - // the base in the linearized inheritance list - // we are building. + // For example, we might have an `extension` that applies + // to `vector` for any `N`, but the `selfType` + // that we are working with could be `` so + // that the extension doesn't match. // - baseInfo->facetImpl = FacetImpl( - kind, - Facet::Directness::Direct, - baseDeclRef, - baseType, - selfIsBaseWitness); - Facet baseFacet(&baseInfo->facetImpl); + // In order to make sure that we don't enumerate members + // that don't make sense in context, we must apply + // the extension to the type and see if we succeed in + // making a match. // - // Second, we have a list of the facets in the - // linearization of the base itself. - // - baseInfo->facets = baseInheritanceInfo.facets; - - directBaseFacets.add(baseFacet); - directBases.add(baseInfo); - }; + auto extDeclRef = + applyExtensionToType(&visitor, extDecl, selfType, additionalSubtypeWitness); + if (!extDeclRef) + continue; - // In the case where we know that the base being added - // represents a direct base *type* (and not an `extension`) - // we can derive some of the information needed by - // `addDirectBaseFacet`. - // - auto addDirectBaseType = [&]( - Type* baseType, - SubtypeWitness* selfIsBaseWitness) - { - // If we are representing inheritance from a type, - // then we should have a witness that the type - // in question (either the type being declared by - // `declRef`, or the type being *extended* by - // `declRef`) inherits from that base. + // In the case where we *do* find an extension that + // applies to the type, we add a declared base to + // represent the `extension`, knowing that its + // own linearized inheritance list will include + // any transitive based declared on the `extension`. // - SLANG_ASSERT(selfIsBaseWitness); - - auto baseInheritanceInfo = getInheritanceInfo(baseType, circularityInfo); - - DeclRef baseDeclRef; - if (auto baseDeclRefType = as(baseType)) - { - baseDeclRef = baseDeclRefType->getDeclRef(); - } - + auto extInheritanceInfo = getInheritanceInfo(extDeclRef, circularityInfo); addDirectBaseFacet( - Facet::Kind::Type, - baseType, - selfIsBaseWitness, - baseDeclRef, - baseInheritanceInfo); - }; - - // If we know the type has a facet represented by `extensionTargetDeclRef`, we can consider - // all extensions on this decl to see if they apply to the type. - // - auto considerExtension = [&](DeclRef extensionTargetDeclRef, Dictionary* additionalSubtypeWitness) + Facet::Kind::Extension, + selfType, + selfIsSelf, + extDeclRef, + extInheritanceInfo); + result = true; + } + return result; + }; + + // We now look at the structure of the declaration itself + // to help us enumerate the direct bases. + // + auto currentDeclRef = declRef; + for (; currentDeclRef;) + { + if (auto aggTypeDeclBaseRef = currentDeclRef.as()) { - bool result = false; - for (auto extDecl : getCandidateExtensions(extensionTargetDeclRef, &visitor)) + // In the case where we have an aggregate type or `extension` + // declaration, we can use the explicit list of direct bases. + // + for (auto typeConstraintDeclRef : + getMembersOfType(_getASTBuilder(), aggTypeDeclBaseRef)) { - // The list of *candidate* extensions is computed and - // cached based on the identity of the declaration alone, - // and does not take into account any generic arguments - // of either the type or the `extension`. - // - // For example, we might have an `extension` that applies - // to `vector` for any `N`, but the `selfType` - // that we are working with could be `` so - // that the extension doesn't match. + // Note: In certain cases something takes the *syntactic* form of an inheritance + // clause, but it is not actually something that should be treated as implying + // a subtype relationship. For example, an `enum` declaration can use what looks + // like an inheritance clause to indicate its underlying "tag type." // - // In order to make sure that we don't enumerate members - // that don't make sense in context, we must apply - // the extension to the type and see if we succeed in - // making a match. + // We skip such pseudo-inheritance relationships for the purposes of determining + // the linearized list of bases. // - auto extDeclRef = applyExtensionToType(&visitor, extDecl, selfType, additionalSubtypeWitness); - if (!extDeclRef) + if (typeConstraintDeclRef.getDecl()->hasModifier()) continue; - // In the case where we *do* find an extension that - // applies to the type, we add a declared base to - // represent the `extension`, knowing that its - // own linearized inheritance list will include - // any transitive based declared on the `extension`. + // The only case we will ever see a GenericTypeConstraintDecl inside a AggTypeDecl + // is when AggTypeDecl is a associatedtype decl. In this case, we will only lookup + // the type constraint if the constraint is on the associated type itself. // - auto extInheritanceInfo = getInheritanceInfo(extDeclRef, circularityInfo); - addDirectBaseFacet( - Facet::Kind::Extension, - selfType, - selfIsSelf, - extDeclRef, - extInheritanceInfo); - result = true; - } - return result; - }; - - // We now look at the structure of the declaration itself - // to help us enumerate the direct bases. - // - auto currentDeclRef = declRef; - for (; currentDeclRef;) - { - if (auto aggTypeDeclBaseRef = currentDeclRef.as()) - { - // In the case where we have an aggregate type or `extension` - // declaration, we can use the explicit list of direct bases. - // - for (auto typeConstraintDeclRef : getMembersOfType(_getASTBuilder(), aggTypeDeclBaseRef)) + auto genericTypeConstraintDeclRef = + typeConstraintDeclRef.as(); + if (genericTypeConstraintDeclRef) { - // Note: In certain cases something takes the *syntactic* form of an inheritance - // clause, but it is not actually something that should be treated as implying - // a subtype relationship. For example, an `enum` declaration can use what looks - // like an inheritance clause to indicate its underlying "tag type." - // - // We skip such pseudo-inheritance relationships for the purposes of determining - // the linearized list of bases. - // - if (typeConstraintDeclRef.getDecl()->hasModifier()) + // If the base expr on the constraint isn't even a `VarExpr`, then it can't be + // referencing the associated type itself and we can skip this constraint. + if (!genericTypeConstraintDeclRef.getDecl()->sub.type && + !as(genericTypeConstraintDeclRef.getDecl()->sub.exp)) continue; + } - // The only case we will ever see a GenericTypeConstraintDecl inside a AggTypeDecl is when - // AggTypeDecl is a associatedtype decl. In this case, we will only lookup the type constraint - // if the constraint is on the associated type itself. - // - auto genericTypeConstraintDeclRef = typeConstraintDeclRef.as(); - if (genericTypeConstraintDeclRef) - { - // If the base expr on the constraint isn't even a `VarExpr`, then it can't be referencing - // the associated type itself and we can skip this constraint. - if (!genericTypeConstraintDeclRef.getDecl()->sub.type - && !as(genericTypeConstraintDeclRef.getDecl()->sub.exp)) - continue; - } - - visitor.ensureDecl(typeConstraintDeclRef, DeclCheckState::CanUseBaseOfInheritanceDecl); + visitor.ensureDecl( + typeConstraintDeclRef, + DeclCheckState::CanUseBaseOfInheritanceDecl); - // For generic type constraint decls, always make sure it is about the type being checked. - // - if (genericTypeConstraintDeclRef) - { - auto subType = getSub(astBuilder, genericTypeConstraintDeclRef); - if (subType != selfType) - continue; - } - else if (currentDeclRef != declRef) - { + // For generic type constraint decls, always make sure it is about the type being + // checked. + // + if (genericTypeConstraintDeclRef) + { + auto subType = getSub(astBuilder, genericTypeConstraintDeclRef); + if (subType != selfType) continue; - } - // The base type and subtype witness can easily be determined - // using the `InheritanceDecl`. - // - auto baseType = getSup(astBuilder, typeConstraintDeclRef); - auto satisfyingWitness = astBuilder->getDeclaredSubtypeWitness( - selfType, - baseType, - typeConstraintDeclRef); - - addDirectBaseType(baseType, satisfyingWitness); } - } - if (currentDeclRef.as()) - { - // If the current type is an associated type, continue inspecting the base/parent of the - // associatedtype to discover additional constraints defined on the parent associatedtype decls. - // - if (auto lookupDeclRef = as(currentDeclRef.declRefBase)) + else if (currentDeclRef != declRef) { - currentDeclRef = isDeclRefTypeOf(lookupDeclRef->getLookupSource()).as(); continue; } + // The base type and subtype witness can easily be determined + // using the `InheritanceDecl`. + // + auto baseType = getSup(astBuilder, typeConstraintDeclRef); + auto satisfyingWitness = astBuilder->getDeclaredSubtypeWitness( + selfType, + baseType, + typeConstraintDeclRef); + + addDirectBaseType(baseType, satisfyingWitness); } - break; } - - if (auto genericDeclRef = getDependentGenericParent(declRef)) + if (currentDeclRef.as()) { - // The constraints placed on a generic type parameter are siblings of that - // parameter in its parent `GenericDecl`, so we need to enumerate all of - // the constraints of the parent declaration to find those that constrain - // this parameter. + // If the current type is an associated type, continue inspecting the base/parent of the + // associatedtype to discover additional constraints defined on the parent + // associatedtype decls. // - // TODO(tfoley): We might consider adding a cached representation of the - // constraint information that is keyed on a per-parameter basis. Such a - // representation would need to take into account canonicalization of - // constraints. - - if (auto extensionDecl = as(genericDeclRef.getDecl()->inner)) + if (auto lookupDeclRef = as(currentDeclRef.declRefBase)) { - if (isDeclRefTypeOf(extensionDecl->targetType.type) == declRef) - { - // If `T` is a generic parameter where the same generic is an extension on `T`, - // then we need to add the extension itself as a facet. - // - auto extDeclRef = createDefaultSubstitutionsIfNeeded(astBuilder, &visitor, extensionDecl); - auto selfExtFacet = new(arena) Facet::Impl( - Facet::Kind::Extension, - Facet::Directness::Direct, - extDeclRef, - selfType, - astBuilder->getTypeEqualityWitness(selfType)); - allFacets.add(selfExtFacet); - } + currentDeclRef = + isDeclRefTypeOf(lookupDeclRef->getLookupSource()).as(); + continue; } + } + break; + } + + if (auto genericDeclRef = getDependentGenericParent(declRef)) + { + // The constraints placed on a generic type parameter are siblings of that + // parameter in its parent `GenericDecl`, so we need to enumerate all of + // the constraints of the parent declaration to find those that constrain + // this parameter. + // + // TODO(tfoley): We might consider adding a cached representation of the + // constraint information that is keyed on a per-parameter basis. Such a + // representation would need to take into account canonicalization of + // constraints. - for (auto constraintDeclRef : getMembersOfType(astBuilder, genericDeclRef)) + if (auto extensionDecl = as(genericDeclRef.getDecl()->inner)) + { + if (isDeclRefTypeOf(extensionDecl->targetType.type) == declRef) { - if (constraintDeclRef.getDecl()->checkState.isBeingChecked()) - continue; + // If `T` is a generic parameter where the same generic is an extension on `T`, + // then we need to add the extension itself as a facet. + // + auto extDeclRef = + createDefaultSubstitutionsIfNeeded(astBuilder, &visitor, extensionDecl); + auto selfExtFacet = new (arena) Facet::Impl( + Facet::Kind::Extension, + Facet::Directness::Direct, + extDeclRef, + selfType, + astBuilder->getTypeEqualityWitness(selfType)); + allFacets.add(selfExtFacet); + } + } + + for (auto constraintDeclRef : + getMembersOfType(astBuilder, genericDeclRef)) + { + if (constraintDeclRef.getDecl()->checkState.isBeingChecked()) + continue; - ensureDecl(&visitor, constraintDeclRef.getDecl(), DeclCheckState::CanSpecializeGeneric); + ensureDecl(&visitor, constraintDeclRef.getDecl(), DeclCheckState::CanSpecializeGeneric); - auto subType = getSub(astBuilder, constraintDeclRef); - auto superType = getSup(astBuilder, constraintDeclRef); + auto subType = getSub(astBuilder, constraintDeclRef); + auto superType = getSup(astBuilder, constraintDeclRef); - // We only consider constraints where the type represented - // by `declRef` is the subtype, since those - // constraints are the ones that give us information about - // the declared supertypes. - // - auto subDeclRefType = as(subType); - if (!subDeclRefType) + // We only consider constraints where the type represented + // by `declRef` is the subtype, since those + // constraints are the ones that give us information about + // the declared supertypes. + // + auto subDeclRefType = as(subType); + if (!subDeclRefType) + { + if (auto subEachType = as(subType)) { - if (auto subEachType = as(subType)) - { - subDeclRefType = as(subEachType->getElementType()); - } - if (!subDeclRefType) - continue; + subDeclRefType = as(subEachType->getElementType()); } - if (subDeclRefType->getDeclRef() != declRef) + if (!subDeclRefType) continue; - - // Because the constraint is a declared inheritance relationship, - // adding the base to our list of direct bases is as straightforward - // as in all the preceding cases. - // - auto satisfyingWitness = _getASTBuilder()->getDeclaredSubtypeWitness( - selfType, - superType, - constraintDeclRef); - addDirectBaseType(superType, satisfyingWitness); } - } + if (subDeclRefType->getDeclRef() != declRef) + continue; - // At this point we have enumerated all of the bases that can be - // gleaned by looking at the `declRef` itself. The next step is - // to consider any `extension` declarations that might apply to - // a type being delared. - // - // An `extension` may apply to our type, if it directly extends - // the type, or extends a generic `T` type that are constrained - // on one of the interfaces that our type conforms to. - // - if (auto directAggTypeDeclRef = declRef.as()) - { - considerExtension(directAggTypeDeclRef, nullptr); + // Because the constraint is a declared inheritance relationship, + // adding the base to our list of direct bases is as straightforward + // as in all the preceding cases. + // + auto satisfyingWitness = + _getASTBuilder()->getDeclaredSubtypeWitness(selfType, superType, constraintDeclRef); + addDirectBaseType(superType, satisfyingWitness); } - if (!declRef.as()) + } + + // At this point we have enumerated all of the bases that can be + // gleaned by looking at the `declRef` itself. The next step is + // to consider any `extension` declarations that might apply to + // a type being delared. + // + // An `extension` may apply to our type, if it directly extends + // the type, or extends a generic `T` type that are constrained + // on one of the interfaces that our type conforms to. + // + if (auto directAggTypeDeclRef = declRef.as()) + { + considerExtension(directAggTypeDeclRef, nullptr); + } + if (!declRef.as()) + { + HashSet supTypesConsideredForExtensionApplication; + Dictionary additionalSubtypeWitnesses; + for (;;) { - HashSet supTypesConsideredForExtensionApplication; - Dictionary additionalSubtypeWitnesses; - for (;;) + // After we flatten the list of bases, we may discover additional opportunities + // to apply extensions. + List> supTypeWorkList; + auto base = directBases.begin(); + for (auto baseFacet = directBaseFacets.getHead(); baseFacet.getImpl(); + baseFacet = baseFacet->next) { - // After we flatten the list of bases, we may discover additional opportunities - // to apply extensions. - List> supTypeWorkList; - auto base = directBases.begin(); - for (auto baseFacet = directBaseFacets.getHead(); baseFacet.getImpl(); baseFacet = baseFacet->next) + for (auto facet : (*base)->facets) { - for (auto facet : (*base)->facets) + if (auto interfaceDeclRef = facet->origin.declRef.as()) { - if (auto interfaceDeclRef = facet->origin.declRef.as()) + SubtypeWitness* transitiveWitness = baseFacet->subtypeWitness; + transitiveWitness = astBuilder->getTransitiveSubtypeWitness( + baseFacet->subtypeWitness, + facet->subtypeWitness); + additionalSubtypeWitnesses.addIfNotExists( + facet->origin.type, + transitiveWitness); + if (supTypesConsideredForExtensionApplication.add(facet->origin.type)) { - SubtypeWitness* transitiveWitness = baseFacet->subtypeWitness; - transitiveWitness = astBuilder->getTransitiveSubtypeWitness(baseFacet->subtypeWitness, facet->subtypeWitness); - additionalSubtypeWitnesses.addIfNotExists(facet->origin.type, transitiveWitness); - if (supTypesConsideredForExtensionApplication.add(facet->origin.type)) - { - supTypeWorkList.add(interfaceDeclRef); - } + supTypeWorkList.add(interfaceDeclRef); } } - ++base; } - bool canExit = true; - for (auto baseItem : supTypeWorkList) - { - if (considerExtension(baseItem, &additionalSubtypeWitnesses)) - canExit = false; - } - if (canExit) - break; + ++base; } + bool canExit = true; + for (auto baseItem : supTypeWorkList) + { + if (considerExtension(baseItem, &additionalSubtypeWitnesses)) + canExit = false; + } + if (canExit) + break; } + } - // At this point, the list of direct bases (each with its own linearization) - // is complete. - // - // At this point we could scan through the list of bases and perform - // consistency checks on it. For example, when two types in the list of direct - // bases have a subtype relationship between them, it is possible that the - // programmer made some kind of mistake, and we could emit a diagnostic - // about it. - // - // The published C3 algorithm actually considers the declared list of bases - // as one of the inputs to its merge, and is very strict about ordering. - // As such, it would be an error for strict C3 if direct bases were declared - // in an order that is inconsitent with the partial order determined by - // the subtype relationship. Our implementation of linearization is relaxed - // compared to C3 so that it is robust to such ordering issues. + // At this point, the list of direct bases (each with its own linearization) + // is complete. + // + // At this point we could scan through the list of bases and perform + // consistency checks on it. For example, when two types in the list of direct + // bases have a subtype relationship between them, it is possible that the + // programmer made some kind of mistake, and we could emit a diagnostic + // about it. + // + // The published C3 algorithm actually considers the declared list of bases + // as one of the inputs to its merge, and is very strict about ordering. + // As such, it would be an error for strict C3 if direct bases were declared + // in an order that is inconsitent with the partial order determined by + // the subtype relationship. Our implementation of linearization is relaxed + // compared to C3 so that it is robust to such ordering issues. + // + // Note: This step takes quadratic time in the number of direct bases, but + // there's really no other way to easily detect these issues. + // + for (auto leftBase = directBaseFacets.getHead(); leftBase.getImpl(); leftBase = leftBase->next) + { + // Note: all of the direct base facets with a `Type` kind will + // precede all of those with `Extension` kind, so we can bail + // out of the outer loop as soon as we find a non-`Type` + // facet. // - // Note: This step takes quadratic time in the number of direct bases, but - // there's really no other way to easily detect these issues. + if (leftBase->kind != Facet::Kind::Type) + break; + auto leftBaseType = leftBase->origin.type; + + // For the inner loop we scan only the facets that appear *after* + // the `leftBase` in the list of direct bases. // - for(auto leftBase = directBaseFacets.getHead(); leftBase.getImpl(); leftBase = leftBase->next) + for (auto rightBase = leftBase->next; rightBase.getImpl(); rightBase = rightBase->next) { - // Note: all of the direct base facets with a `Type` kind will - // precede all of those with `Extension` kind, so we can bail - // out of the outer loop as soon as we find a non-`Type` - // facet. - // - if(leftBase->kind != Facet::Kind::Type) + if (rightBase->kind != Facet::Kind::Type) break; - auto leftBaseType = leftBase->origin.type; + auto rightBaseType = rightBase->origin.type; - // For the inner loop we scan only the facets that appear *after* - // the `leftBase` in the list of direct bases. - // - for(auto rightBase = leftBase->next; rightBase.getImpl(); rightBase = rightBase->next) + if (visitor.isSubtype(leftBaseType, rightBaseType, IsSubTypeOptions::None)) { - if (rightBase->kind != Facet::Kind::Type) - break; - auto rightBaseType = rightBase->origin.type; - - if (visitor.isSubtype(leftBaseType, rightBaseType, IsSubTypeOptions::None)) - { - // If a type earlier in the list of bases is a subtype of - // one later in the list, then the ordering is consistent - // with the linearization that will be produced, but it - // might represent a mistake on the programmer's part, - // since they listed a base type that is redundant. - // - // TODO: decide whether to diagnose this case. - } - else if (visitor.isSubtype(rightBaseType, leftBaseType, IsSubTypeOptions::None)) - { - // If a type later in the list is a subtype of a type earlier - // in the list, then the declared list of bases is inconsistent - // with the ordering that will (indeed *must*) appear in the - // linearization we generate. - // - // If we end up implementing a strict version of the C3 algorithm, - // we would need to treat such situations as an error, or at least - // emit a warning and then remove the subtype from the list of - // bases. - // - // TODO: decide whether to diagnose this case. - } + // If a type earlier in the list of bases is a subtype of + // one later in the list, then the ordering is consistent + // with the linearization that will be produced, but it + // might represent a mistake on the programmer's part, + // since they listed a base type that is redundant. + // + // TODO: decide whether to diagnose this case. + } + else if (visitor.isSubtype(rightBaseType, leftBaseType, IsSubTypeOptions::None)) + { + // If a type later in the list is a subtype of a type earlier + // in the list, then the declared list of bases is inconsistent + // with the ordering that will (indeed *must*) appear in the + // linearization we generate. + // + // If we end up implementing a strict version of the C3 algorithm, + // we would need to treat such situations as an error, or at least + // emit a warning and then remove the subtype from the list of + // bases. + // + // TODO: decide whether to diagnose this case. } } + } - // Now that we've built up the list of direct bases and their - // respective linearizations, we can apply the core merge algorithm - // to those lists to produce the rest of the linearization for - // the declaration in question. - // - _mergeFacetLists(directBases, directBaseFacets, allFacets); + // Now that we've built up the list of direct bases and their + // respective linearizations, we can apply the core merge algorithm + // to those lists to produce the rest of the linearization for + // the declaration in question. + // + _mergeFacetLists(directBases, directBaseFacets, allFacets); - InheritanceInfo info; - info.facets = allFacets; - return info; - } + InheritanceInfo info; + info.facets = allFacets; + return info; +} - void SharedSemanticsContext::_mergeFacetLists(DirectBaseList bases, FacetList baseFacets, FacetList::Builder& ioMergedFacets) +void SharedSemanticsContext::_mergeFacetLists( + DirectBaseList bases, + FacetList baseFacets, + FacetList::Builder& ioMergedFacets) +{ + // Our task here is to take the list of direct/declared `bases`, + // each of which holds a linearized list of `Facet`s, and produce + // a single linearized list of facets in `ioMergedFacets`. + // + // The `Facet`s in the lists referenced by `bases` are always + // relative to the base type/extension itself, and not to + // the type or declaration for which we are computing + // a linearization. + // + // The `baseFacets` list provides one `Facet` for each direct + // base that are relative to the type/declaration we are + // computing a linearization for. These facets will be used + // directly, instead of those from `bases`, where possible. + // + auto astBuilder = _getASTBuilder(); + auto& arena = astBuilder->getArena(); + for (;;) { - // Our task here is to take the list of direct/declared `bases`, - // each of which holds a linearized list of `Facet`s, and produce - // a single linearized list of facets in `ioMergedFacets`. + // The basic logic here is that on each iteration we + // will look at the first item on each list in `bases` + // and pick one that we will append to the merged output + // (after removing it from the relevant input(s)). + + // If we have run out of lists that need merging, then we are done. // - // The `Facet`s in the lists referenced by `bases` are always - // relative to the base type/extension itself, and not to - // the type or declaration for which we are computing - // a linearization. + if (bases.isEmpty()) + break; + + // Otherwise, we will look at the remaining non-empty lists, + // and see if one of them starts with an facet that can + // be appended to our merged output. // - // The `baseFacets` list provides one `Facet` for each direct - // base that are relative to the type/declaration we are - // computing a linearization for. These facets will be used - // directly, instead of those from `bases`, where possible. + // If multiple such facets are viable, we will always take + // the one from the earliest list in `bases`. Doing so favors + // the types that appear earlier in a list of bases. // - auto astBuilder = _getASTBuilder(); - auto& arena = astBuilder->getArena(); - for(;;) + Facet foundFacet; + DirectBaseInfo* foundBase = nullptr; + for (auto base : bases) { - // The basic logic here is that on each iteration we - // will look at the first item on each list in `bases` - // and pick one that we will append to the merged output - // (after removing it from the relevant input(s)). + Facet headFacet = base->facets.getHead(); - // If we have run out of lists that need merging, then we are done. + // If the head facet of the `base` list appears at a non-head + // position in any of the other lists, we cannot append this + // element without risking inverting the order of some facets + // relative to those other lists. // - if (bases.isEmpty()) - break; + if (bases.doesAnyTailContainMatchFor(headFacet)) + continue; - // Otherwise, we will look at the remaining non-empty lists, - // and see if one of them starts with an facet that can - // be appended to our merged output. + // Otherwise, we are safe to add the `headFacet` to our + // merged list, because it only ever appears as the head + // of one or more of the lists in `bases`. // - // If multiple such facets are viable, we will always take - // the one from the earliest list in `bases`. Doing so favors - // the types that appear earlier in a list of bases. - // - Facet foundFacet; - DirectBaseInfo* foundBase = nullptr; - for (auto base : bases) - { - Facet headFacet = base->facets.getHead(); - - // If the head facet of the `base` list appears at a non-head - // position in any of the other lists, we cannot append this - // element without risking inverting the order of some facets - // relative to those other lists. - // - if (bases.doesAnyTailContainMatchFor(headFacet)) - continue; - - // Otherwise, we are safe to add the `headFacet` to our - // merged list, because it only ever appears as the head - // of one or more of the lists in `bases`. - // - foundFacet = headFacet; - foundBase = base; - break; - } - - if(!foundFacet) - { - // If we could not identify a facet that could be safely - // removed from any of the base lists, then it means that - // we must have a cycle in the ordering constraints implied - // by the `bases` lists. - // - // The simplest example of such a cycle would be if we - // had two lists, `A` and `B`, such that: - // - // A = { X, Y } - // B = { Y, X } - // - // In this case, producing output in the order `X, Y` *or* - // `Y, X` will always invalidate the ordering constraints - // implied by either `A` or `B`. - // - // In the C3 algorithm as published, such a situation is an - // error, and the algorithm fails to produce a linearization. - // The reason for this decision is that allowing this case - // means that a base type and a derived type could disagree - // on the relative priority of method overrides, and thus - // a subclass could possible break semantic assumptions of - // a superclass. - // - // In a more static language like Slang, it seems better to - // allow more flexible inheritance, *especially* when dealing - // with things like `interface`s and `extension`s, where the - // relative ordering of things will often be immaterial. - // - // In a case like this, we would like to arbitrarily pick - // one or the other of `X` and `Y`, and given our default - // policy to favor the earlier list in `bases` where possible, - // we would select `X` from `A`. - // - // One thing worth noting is that when a case like the above - // arises, it is not possible that `X <: Y` or `Y <: X`. - // If a subtype relationship existed between the two, then - // they would be consistently ordered in *both* lists. - // We thus do not have to worry about violating the most - // important requirement for a "reasonable" linearization. - // - foundBase = *bases.begin(); - foundFacet = foundBase->facets.getHead(); - - // Note: because we are grabbing a facet that might appear - // in a non-head position in one or more of our lists, - // we need to have a plan for what to do when we see - // that same facet come to the front of one of our lists - // later. - } - - // If we still cannot find a facet, then there is a true cycle in - // the inheritance graph, which is an error in the user code. - if (!foundFacet.getImpl()) - { - if (!bases.isEmpty()) - { - auto baseDecl = (*bases.begin())->facetImpl.origin.declRef.getDecl(); - getSink()->diagnose(baseDecl, Diagnostics::cyclicReferenceInInheritance, baseDecl); - } - return; - } + foundFacet = headFacet; + foundBase = base; + break; + } - // At this point we definitely have a facet we'd like to - // add to the output, whether it was found via the true - // C3 approach, or our relaxed rule above. + if (!foundFacet) + { + // If we could not identify a facet that could be safely + // removed from any of the base lists, then it means that + // we must have a cycle in the ordering constraints implied + // by the `bases` lists. // - SLANG_ASSERT(foundFacet.getImpl()); - - // If the facet we want to append to the output is the same as the front-most - // facet on the list of bases, then we want to use that facet as-is (since we - // have already allocated storage for it). + // The simplest example of such a cycle would be if we + // had two lists, `A` and `B`, such that: + // + // A = { X, Y } + // B = { Y, X } + // + // In this case, producing output in the order `X, Y` *or* + // `Y, X` will always invalidate the ordering constraints + // implied by either `A` or `B`. + // + // In the C3 algorithm as published, such a situation is an + // error, and the algorithm fails to produce a linearization. + // The reason for this decision is that allowing this case + // means that a base type and a derived type could disagree + // on the relative priority of method overrides, and thus + // a subclass could possible break semantic assumptions of + // a superclass. + // + // In a more static language like Slang, it seems better to + // allow more flexible inheritance, *especially* when dealing + // with things like `interface`s and `extension`s, where the + // relative ordering of things will often be immaterial. + // + // In a case like this, we would like to arbitrarily pick + // one or the other of `X` and `Y`, and given our default + // policy to favor the earlier list in `bases` where possible, + // we would select `X` from `A`. // - // TODO: in cases where the strict C3 algorithm would fail, and we choose a - // `foundFacet` that is at a non-head position in at least some lists, it - // might be possible that we have a facet that matches ones of the `baseFacets`, - // but not the head one. We should confirm what happens in that case. + // One thing worth noting is that when a case like the above + // arises, it is not possible that `X <: Y` or `Y <: X`. + // If a subtype relationship existed between the two, then + // they would be consistently ordered in *both* lists. + // We thus do not have to worry about violating the most + // important requirement for a "reasonable" linearization. // - if(originsMatch(foundFacet, baseFacets.getHead())) + foundBase = *bases.begin(); + foundFacet = foundBase->facets.getHead(); + + // Note: because we are grabbing a facet that might appear + // in a non-head position in one or more of our lists, + // we need to have a plan for what to do when we see + // that same facet come to the front of one of our lists + // later. + } + + // If we still cannot find a facet, then there is a true cycle in + // the inheritance graph, which is an error in the user code. + if (!foundFacet.getImpl()) + { + if (!bases.isEmpty()) { - auto directBaseFacet = baseFacets.popHead(); - ioMergedFacets.add(directBaseFacet); + auto baseDecl = (*bases.begin())->facetImpl.origin.declRef.getDecl(); + getSink()->diagnose(baseDecl, Diagnostics::cyclicReferenceInInheritance, baseDecl); } - else - { - // This facet is seemingly *not* a facet that represents one of the direct - // bases for the type/declaration being processed. - // - // As such, we need to allocate a fresh facet to represent it in the - // linearization we are creating, since the `foundFacet` already belongs - // to the linearization of one of the bases, and shouldn't be repurposed. - // - auto indirectFacet = new(arena) Facet::Impl(); - - // We will initialize the fresh facet to a copy of the state of the - // `foundFacet`, albeit with a higher level of indirection. - // - // TODO: In principle we could search through all of the lists to - // find the one with a facet matching `foundFacet` with minimum - // indirection, so that our measure of indirection is always - // as small as possible for any given facet. - // - *indirectFacet = *(foundFacet.getImpl()); - indirectFacet->next = nullptr; - indirectFacet->directness = - Facet::Directness(Facet::DirectnessVal(indirectFacet->directness) + 1); - - // When using this facet for subtype tests, or when looking - // up member through this facet, we will need a witness - // to show that the self type of the declaration being - // linearized (the type being declared or extended) is a - // subtype of the type for this facet. - // - // We can construct the appropriate witness transitively, - // by noting that: - // - // * The self type is known to be a subtype of the direct - // base represented by `foundBase`, and the facet for - // that base stores a witness to that fact. - // - SubtypeWitness* selfIsSubtypeOfBase = foundBase->facetImpl.subtypeWitness; - // - // * The direct base type must be a subtype of the type - // for any facet found in its own linearization, and - // the `foundFacet` that came from the relevant base - // stores a witness to that fact. - // - SubtypeWitness* baseIsSubtypeOfFacet = foundFacet->subtypeWitness; - - auto selfIsSubtypeOfFacet = _getASTBuilder()->getTransitiveSubtypeWitness( - selfIsSubtypeOfBase, - baseIsSubtypeOfFacet); + return; + } - indirectFacet->subtypeWitness = selfIsSubtypeOfFacet; + // At this point we definitely have a facet we'd like to + // add to the output, whether it was found via the true + // C3 approach, or our relaxed rule above. + // + SLANG_ASSERT(foundFacet.getImpl()); - ioMergedFacets.add(indirectFacet); - } + // If the facet we want to append to the output is the same as the front-most + // facet on the list of bases, then we want to use that facet as-is (since we + // have already allocated storage for it). + // + // TODO: in cases where the strict C3 algorithm would fail, and we choose a + // `foundFacet` that is at a non-head position in at least some lists, it + // might be possible that we have a facet that matches ones of the `baseFacets`, + // but not the head one. We should confirm what happens in that case. + // + if (originsMatch(foundFacet, baseFacets.getHead())) + { + auto directBaseFacet = baseFacets.popHead(); + ioMergedFacets.add(directBaseFacet); + } + else + { + // This facet is seemingly *not* a facet that represents one of the direct + // bases for the type/declaration being processed. + // + // As such, we need to allocate a fresh facet to represent it in the + // linearization we are creating, since the `foundFacet` already belongs + // to the linearization of one of the bases, and shouldn't be repurposed. + // + auto indirectFacet = new (arena) Facet::Impl(); - // We picked one `foundFacet` above to be added to the merged - // output list, and we now need to ensure that we won't ever - // emit a matching facet again. + // We will initialize the fresh facet to a copy of the state of the + // `foundFacet`, albeit with a higher level of indirection. // - // In the case of the strict/standard C3 algorithm, any facets - // matching `foundFacet` would need to appear at a head position - // in one of the base lists. As such, it is sufficient to run - // through the base lists, check for a match at the head of each, - // and remove any matching facets we find. + // TODO: In principle we could search through all of the lists to + // find the one with a facet matching `foundFacet` with minimum + // indirection, so that our measure of indirection is always + // as small as possible for any given facet. // - for (auto base : bases) - { - if (originsMatch(foundFacet, base->facets.getHead())) - { - base->facets.advanceHead(); - continue; - } - } + *indirectFacet = *(foundFacet.getImpl()); + indirectFacet->next = nullptr; + indirectFacet->directness = + Facet::Directness(Facet::DirectnessVal(indirectFacet->directness) + 1); + + // When using this facet for subtype tests, or when looking + // up member through this facet, we will need a witness + // to show that the self type of the declaration being + // linearized (the type being declared or extended) is a + // subtype of the type for this facet. // - // Because we are not implementing the C3 algorithm strictly, - // we need a solution for the case where `foundFacet` is - // in a non-head position in one or more of the base lists. + // We can construct the appropriate witness transitively, + // by noting that: // - // Proactively filtering `foundFacet` out of all of the lists - // is possible, but given that these are singly-linked lists - // we cannot easily filter them without either allocation - // or mutation. + // * The self type is known to be a subtype of the direct + // base represented by `foundBase`, and the facet for + // that base stores a witness to that fact. // - // Instead, we will filter out facets that have already been - // added to the merged list as needed, when such facets come - // to the front of the relevant list. + SubtypeWitness* selfIsSubtypeOfBase = foundBase->facetImpl.subtypeWitness; // - for (auto base : bases) - { - for(;;) - { - // For each base list, we will check if its - // head facet is one that has already been - // emitted to the output. - // - // If the head facet has not been emitted - // already, we don't need to perform any - // filtering on the base list at this time. - // - auto head = base->facets.getHead(); - if (!ioMergedFacets.containsMatchFor(head)) - break; - - // Otherwise, we remove the head facet from - // the given base list and test again, unless - // the list is now empty. - // - base->facets.advanceHead(); - if (base->facets.isEmpty()) - break; - } - } - - // The filtering step might have led to one or more - // of the `bases` lists becomming empty. Our merge - // algorithm really only wants to consider non-empty - // lists, so we go ahead and remove the empty lists - // here. + // * The direct base type must be a subtype of the type + // for any facet found in its own linearization, and + // the `foundFacet` that came from the relevant base + // stores a witness to that fact. // - bases.removeEmptyLists(); + SubtypeWitness* baseIsSubtypeOfFacet = foundFacet->subtypeWitness; - // At this point all of the lists have been appropriately filtered, - // and we are ready to circle back around again to the step - // where select a facet to add to the merged list. - } + auto selfIsSubtypeOfFacet = _getASTBuilder()->getTransitiveSubtypeWitness( + selfIsSubtypeOfBase, + baseIsSubtypeOfFacet); - // At this point, all of the input lists in `bases` should be empty, - // and all of the facets in those lists should have found their way - // over to `ioMergedFacets`. - } + indirectFacet->subtypeWitness = selfIsSubtypeOfFacet; - // The mering algorithm needs to be able to test if two potentially-distinct - // `Facet`s represent the same underlying type or declaration. - // - bool originsMatch(Facet left, Facet right) - { - if (left.getImpl() == right.getImpl()) - return true; - if (!left.getImpl() || !right.getImpl()) - return false; + ioMergedFacets.add(indirectFacet); + } - // If both of the facets are non-null, and not - // identical, we check if their origins match, - // meaning that they represent the same type - // or declaration. + // We picked one `foundFacet` above to be added to the merged + // output list, and we now need to ensure that we won't ever + // emit a matching facet again. // - return left->origin == right->origin; - } - - bool operator==(Facet::Origin left, Facet::Origin right) - { - // If either facet represents a declaration, then - // the origins only match if they both represent - // the *same* declaration. + // In the case of the strict/standard C3 algorithm, any facets + // matching `foundFacet` would need to appear at a head position + // in one of the base lists. As such, it is sufficient to run + // through the base lists, check for a match at the head of each, + // and remove any matching facets we find. // - if (left.declRef.getDecl() || right.declRef.getDecl()) + for (auto base : bases) { - return left.declRef.getDecl() - && right.declRef.getDecl() - && left.declRef.equals(right.declRef); + if (originsMatch(foundFacet, base->facets.getHead())) + { + base->facets.advanceHead(); + continue; + } } - - // Otherwise, if they both represent types, then the - // origins match if they are the same type. // - // Note: an `extension` facet will always have a non-null - // `declRef`, so there is no risk here of an `extension` - // and a type facet being matched by this step; they - // would always land in the case above. + // Because we are not implementing the C3 algorithm strictly, + // we need a solution for the case where `foundFacet` is + // in a non-head position in one or more of the base lists. // - if (left.type || right.type) + // Proactively filtering `foundFacet` out of all of the lists + // is possible, but given that these are singly-linked lists + // we cannot easily filter them without either allocation + // or mutation. + // + // Instead, we will filter out facets that have already been + // added to the merged list as needed, when such facets come + // to the front of the relevant list. + // + for (auto base : bases) { - return left.type - && right.type - && left.type->equals(right.type); + for (;;) + { + // For each base list, we will check if its + // head facet is one that has already been + // emitted to the output. + // + // If the head facet has not been emitted + // already, we don't need to perform any + // filtering on the base list at this time. + // + auto head = base->facets.getHead(); + if (!ioMergedFacets.containsMatchFor(head)) + break; + + // Otherwise, we remove the head facet from + // the given base list and test again, unless + // the list is now empty. + // + base->facets.advanceHead(); + if (base->facets.isEmpty()) + break; + } } - // TODO: The rules we are using for matching here - // would need to be revisited and overhauled significantly - // if we start supporting generic type declarations - // with covariant/contravariant type parameters. - // - // In such cases we would need to treat two facets as - // matching if their declarations or types are an exact - // matching modulo type arguments, and the relationship - // between pairwise type arguments is consistent with - // the variance of the corresponding parameter. + // The filtering step might have led to one or more + // of the `bases` lists becomming empty. Our merge + // algorithm really only wants to consider non-empty + // lists, so we go ahead and remove the empty lists + // here. // - // E.g., we would need to treat facets for `IEnumerable` - // and `IEnumerable` as matching, and ensure that a - // merged output list for a type/declaration could only - // include the more specific of the two (`IEnumerable`). + bases.removeEmptyLists(); - return false; + // At this point all of the lists have been appropriately filtered, + // and we are ready to circle back around again to the step + // where select a facet to add to the merged list. } - // The remaining list-related operations that relate to the merging - // process are relatively simple to follow once the definition of - // matching is clear. + // At this point, all of the input lists in `bases` should be empty, + // and all of the facets in those lists should have found their way + // over to `ioMergedFacets`. +} - bool SharedSemanticsContext::DirectBaseList::doesAnyTailContainMatchFor(Facet facet) const - { - for (auto base : *this) - { - if (base->facets.isEmpty()) - continue; - if (base->facets.getTail().containsMatchFor(facet)) - return true; - } +// The mering algorithm needs to be able to test if two potentially-distinct +// `Facet`s represent the same underlying type or declaration. +// +bool originsMatch(Facet left, Facet right) +{ + if (left.getImpl() == right.getImpl()) + return true; + if (!left.getImpl() || !right.getImpl()) return false; + + // If both of the facets are non-null, and not + // identical, we check if their origins match, + // meaning that they represent the same type + // or declaration. + // + return left->origin == right->origin; +} + +bool operator==(Facet::Origin left, Facet::Origin right) +{ + // If either facet represents a declaration, then + // the origins only match if they both represent + // the *same* declaration. + // + if (left.declRef.getDecl() || right.declRef.getDecl()) + { + return left.declRef.getDecl() && right.declRef.getDecl() && + left.declRef.equals(right.declRef); } - void SharedSemanticsContext::DirectBaseList::removeEmptyLists() + // Otherwise, if they both represent types, then the + // origins match if they are the same type. + // + // Note: an `extension` facet will always have a non-null + // `declRef`, so there is no risk here of an `extension` + // and a type facet being matched by this step; they + // would always land in the case above. + // + if (left.type || right.type) { - DirectBaseInfo** link = &_head; - while (auto base = *link) - { - if (base->facets.isEmpty()) - { - *link = base->next; - } - else - { - link = &base->next; - } - } + return left.type && right.type && left.type->equals(right.type); } - bool FacetList::containsMatchFor(Facet facet) const + // TODO: The rules we are using for matching here + // would need to be revisited and overhauled significantly + // if we start supporting generic type declarations + // with covariant/contravariant type parameters. + // + // In such cases we would need to treat two facets as + // matching if their declarations or types are an exact + // matching modulo type arguments, and the relationship + // between pairwise type arguments is consistent with + // the variance of the corresponding parameter. + // + // E.g., we would need to treat facets for `IEnumerable` + // and `IEnumerable` as matching, and ensure that a + // merged output list for a type/declaration could only + // include the more specific of the two (`IEnumerable`). + + return false; +} + +// The remaining list-related operations that relate to the merging +// process are relatively simple to follow once the definition of +// matching is clear. + +bool SharedSemanticsContext::DirectBaseList::doesAnyTailContainMatchFor(Facet facet) const +{ + for (auto base : *this) { - for (auto f : *this) - { - if (originsMatch(f, facet)) - return true; - } - return false; + if (base->facets.isEmpty()) + continue; + if (base->facets.getTail().containsMatchFor(facet)) + return true; } + return false; +} - InheritanceInfo SharedSemanticsContext::_calcInheritanceInfo(Type* type, InheritanceCircularityInfo* circularityInfo) +void SharedSemanticsContext::DirectBaseList::removeEmptyLists() +{ + DirectBaseInfo** link = &_head; + while (auto base = *link) { - // The majority of the interesting for for computing linearized - // inheritance information arises for `DeclRef`s, but we still - // need a way to compute the relevant information for types - // that might or might not be defined using `Decl`s. - - auto astBuilder = _getASTBuilder(); - auto& arena = astBuilder->getArena(); - if (auto declRefType = as(type)) + if (base->facets.isEmpty()) { - // The `DeclRef` case is the easy one, since we can - // bottleneck through the logic that gets shared between - // type and `extension` declarations. - // - return _getInheritanceInfo(declRefType->getDeclRef(), declRefType, circularityInfo); + *link = base->next; } - else if (auto conjunctionType = as(type)) + else { - // In this case, we have a type of the form `L & R`, - // such that it is a subtype of both `L` and `R`. - // - auto leftType = conjunctionType->getLeft(); - auto rightType = conjunctionType->getRight(); + link = &base->next; + } + } +} - // The linearized inheritance list for the conjunction - // must include all the facets from the lists for `L` - // and `R`, respectively. - // - auto leftInfo = getInheritanceInfo(leftType, circularityInfo); - auto rightInfo = getInheritanceInfo(rightType, circularityInfo); - - // We have a case of subtype witness that can show that - // `T : L` or `T : R` based on `T : L&R`. In this case, - // though, the type `T` is actually `L&R` itself, so - // we need to construct an identity witness for `L&R : L&R` - // to give it something to start from. - // - auto selfIsSelf = astBuilder->getTypeEqualityWitness(conjunctionType); - auto selfIsSubtypeOfLeft = _getASTBuilder()->getExtractFromConjunctionSubtypeWitness(type, leftType, selfIsSelf, 0); - auto selfIsSubtypeOfRight = _getASTBuilder()->getExtractFromConjunctionSubtypeWitness(type, rightType, selfIsSelf, 1); - - // We will set up to perform a merge between the facet - // lists for the two "bases" `L` and `R`. Note that the - // information we write into the `facetImpl` in each case - // is largely just for completeness and debugging, since - // we are *not* going to add those facets into a list - // of direct base facets to be merged. - // - DirectBaseInfo leftBaseInfo; - leftBaseInfo.facetImpl = FacetImpl( - Facet::Kind::Type, - Facet::Directness::Direct, - DeclRef(), - leftType, - selfIsSubtypeOfLeft); - leftBaseInfo.facets = leftInfo.facets; - - DirectBaseInfo rightBaseInfo; - rightBaseInfo.facetImpl = FacetImpl( - Facet::Kind::Type, - Facet::Directness::Direct, - DeclRef(), - rightType, - selfIsSubtypeOfRight); - rightBaseInfo.facets = rightInfo.facets; - - DirectBaseList::Builder directBases; - directBases.add(&leftBaseInfo); - directBases.add(&rightBaseInfo); - - // The merging step is then the same as for the more "standard" case, - // with the only detail that we are not passing in a list of facets - // to represent the directly-declared bases (since there are none; - // this is a structural rather than nominal type). - // - FacetList::Builder mergedFacets; - _mergeFacetLists(directBases, FacetList(), mergedFacets); +bool FacetList::containsMatchFor(Facet facet) const +{ + for (auto f : *this) + { + if (originsMatch(f, facet)) + return true; + } + return false; +} - InheritanceInfo info; - info.facets = mergedFacets; - return info; - } - else if (auto eachType = as(type)) +InheritanceInfo SharedSemanticsContext::_calcInheritanceInfo( + Type* type, + InheritanceCircularityInfo* circularityInfo) +{ + // The majority of the interesting for for computing linearized + // inheritance information arises for `DeclRef`s, but we still + // need a way to compute the relevant information for types + // that might or might not be defined using `Decl`s. + + auto astBuilder = _getASTBuilder(); + auto& arena = astBuilder->getArena(); + if (auto declRefType = as(type)) + { + // The `DeclRef` case is the easy one, since we can + // bottleneck through the logic that gets shared between + // type and `extension` declarations. + // + return _getInheritanceInfo(declRefType->getDeclRef(), declRefType, circularityInfo); + } + else if (auto conjunctionType = as(type)) + { + // In this case, we have a type of the form `L & R`, + // such that it is a subtype of both `L` and `R`. + // + auto leftType = conjunctionType->getLeft(); + auto rightType = conjunctionType->getRight(); + + // The linearized inheritance list for the conjunction + // must include all the facets from the lists for `L` + // and `R`, respectively. + // + auto leftInfo = getInheritanceInfo(leftType, circularityInfo); + auto rightInfo = getInheritanceInfo(rightType, circularityInfo); + + // We have a case of subtype witness that can show that + // `T : L` or `T : R` based on `T : L&R`. In this case, + // though, the type `T` is actually `L&R` itself, so + // we need to construct an identity witness for `L&R : L&R` + // to give it something to start from. + // + auto selfIsSelf = astBuilder->getTypeEqualityWitness(conjunctionType); + auto selfIsSubtypeOfLeft = _getASTBuilder()->getExtractFromConjunctionSubtypeWitness( + type, + leftType, + selfIsSelf, + 0); + auto selfIsSubtypeOfRight = _getASTBuilder()->getExtractFromConjunctionSubtypeWitness( + type, + rightType, + selfIsSelf, + 1); + + // We will set up to perform a merge between the facet + // lists for the two "bases" `L` and `R`. Note that the + // information we write into the `facetImpl` in each case + // is largely just for completeness and debugging, since + // we are *not* going to add those facets into a list + // of direct base facets to be merged. + // + DirectBaseInfo leftBaseInfo; + leftBaseInfo.facetImpl = FacetImpl( + Facet::Kind::Type, + Facet::Directness::Direct, + DeclRef(), + leftType, + selfIsSubtypeOfLeft); + leftBaseInfo.facets = leftInfo.facets; + + DirectBaseInfo rightBaseInfo; + rightBaseInfo.facetImpl = FacetImpl( + Facet::Kind::Type, + Facet::Directness::Direct, + DeclRef(), + rightType, + selfIsSubtypeOfRight); + rightBaseInfo.facets = rightInfo.facets; + + DirectBaseList::Builder directBases; + directBases.add(&leftBaseInfo); + directBases.add(&rightBaseInfo); + + // The merging step is then the same as for the more "standard" case, + // with the only detail that we are not passing in a list of facets + // to represent the directly-declared bases (since there are none; + // this is a structural rather than nominal type). + // + FacetList::Builder mergedFacets; + _mergeFacetLists(directBases, FacetList(), mergedFacets); + + InheritanceInfo info; + info.facets = mergedFacets; + return info; + } + else if (auto eachType = as(type)) + { + auto elementInheritanceInfo = + getInheritanceInfo(eachType->getElementType(), circularityInfo); + SemanticsVisitor visitor(this); + auto directFacet = new (arena) Facet::Impl( + Facet::Kind::Type, + Facet::Directness::Self, + DeclRef(), + type, + visitor.createTypeEqualityWitness(type)); + Facet tail = directFacet; + for (auto facet : elementInheritanceInfo.facets) { - auto elementInheritanceInfo = getInheritanceInfo(eachType->getElementType(), circularityInfo); - SemanticsVisitor visitor(this); - auto directFacet = new(arena) Facet::Impl( - Facet::Kind::Type, - Facet::Directness::Self, - DeclRef(), - type, - visitor.createTypeEqualityWitness(type)); - Facet tail = directFacet; - for (auto facet : elementInheritanceInfo.facets) + if (facet->directness == Facet::Directness::Direct) { - if (facet->directness == Facet::Directness::Direct) - { - auto eachFacet = new(arena) Facet::Impl( - Facet::Kind::Type, - Facet::Directness::Direct, - facet->origin.declRef, - facet->origin.type, - astBuilder->getEachSubtypeWitness(type, facet->subtypeWitness->getSup(), facet->subtypeWitness)); - tail->next = eachFacet; - tail = eachFacet; - } + auto eachFacet = new (arena) Facet::Impl( + Facet::Kind::Type, + Facet::Directness::Direct, + facet->origin.declRef, + facet->origin.type, + astBuilder->getEachSubtypeWitness( + type, + facet->subtypeWitness->getSup(), + facet->subtypeWitness)); + tail->next = eachFacet; + tail = eachFacet; } - InheritanceInfo info; - info.facets = FacetList(directFacet); - return info; - } - else - { - // As a fallback, any type not covered by the above cases will - // get a trivial linearization that consists of a single facet - // corresponding to that type itself. - // - SemanticsVisitor visitor(this); - auto directFacet = new(arena) Facet::Impl( - Facet::Kind::Type, - Facet::Directness::Self, - DeclRef(), - type, - visitor.createTypeEqualityWitness(type)); - - InheritanceInfo info; - info.facets = FacetList(directFacet); - return info; } + InheritanceInfo info; + info.facets = FacetList(directFacet); + return info; + } + else + { + // As a fallback, any type not covered by the above cases will + // get a trivial linearization that consists of a single facet + // corresponding to that type itself. + // + SemanticsVisitor visitor(this); + auto directFacet = new (arena) Facet::Impl( + Facet::Kind::Type, + Facet::Directness::Self, + DeclRef(), + type, + visitor.createTypeEqualityWitness(type)); + + InheritanceInfo info; + info.facets = FacetList(directFacet); + return info; } } +} // namespace Slang diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index f05b58c34..36a73b2e9 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -1,6 +1,6 @@ // slang-check-modifier.cpp -#include "slang-check-impl.h" #include "../core/slang-char-util.h" +#include "slang-check-impl.h" // This file implements semantic checking behavior for // modifiers. @@ -12,379 +12,350 @@ namespace Slang { - IntVal* SemanticsVisitor::checkLinkTimeConstantIntVal( - Expr* expr) - { - expr = CheckExpr(expr); - return CheckIntegerConstantExpression(expr, IntegerConstantExpressionCoercionType::AnyInteger, nullptr, ConstantFoldingKind::LinkTime); - } +IntVal* SemanticsVisitor::checkLinkTimeConstantIntVal(Expr* expr) +{ + expr = CheckExpr(expr); + return CheckIntegerConstantExpression( + expr, + IntegerConstantExpressionCoercionType::AnyInteger, + nullptr, + ConstantFoldingKind::LinkTime); +} - ConstantIntVal* SemanticsVisitor::checkConstantIntVal( - Expr* expr) +ConstantIntVal* SemanticsVisitor::checkConstantIntVal(Expr* expr) +{ + // First type-check the expression as normal + expr = CheckExpr(expr); + + auto intVal = CheckIntegerConstantExpression( + expr, + IntegerConstantExpressionCoercionType::AnyInteger, + nullptr, + ConstantFoldingKind::CompileTime); + if (!intVal) + return nullptr; + + auto constIntVal = as(intVal); + if (!constIntVal) { - // First type-check the expression as normal - expr = CheckExpr(expr); - - auto intVal = CheckIntegerConstantExpression(expr, IntegerConstantExpressionCoercionType::AnyInteger, nullptr, ConstantFoldingKind::CompileTime); - if(!intVal) - return nullptr; - - auto constIntVal = as(intVal); - if(!constIntVal) - { - getSink()->diagnose(expr->loc, Diagnostics::expectedIntegerConstantNotLiteral); - return nullptr; - } - return constIntVal; + getSink()->diagnose(expr->loc, Diagnostics::expectedIntegerConstantNotLiteral); + return nullptr; } + return constIntVal; +} - ConstantIntVal* SemanticsVisitor::checkConstantEnumVal( - Expr* expr) - { - // First type-check the expression as normal - expr = CheckExpr(expr); +ConstantIntVal* SemanticsVisitor::checkConstantEnumVal(Expr* expr) +{ + // First type-check the expression as normal + expr = CheckExpr(expr); - auto intVal = CheckEnumConstantExpression(expr, ConstantFoldingKind::CompileTime); - if(!intVal) - return nullptr; + auto intVal = CheckEnumConstantExpression(expr, ConstantFoldingKind::CompileTime); + if (!intVal) + return nullptr; - auto constIntVal = as(intVal); - if(!constIntVal) - { - getSink()->diagnose(expr->loc, Diagnostics::expectedIntegerConstantNotLiteral); - return nullptr; - } - return constIntVal; + auto constIntVal = as(intVal); + if (!constIntVal) + { + getSink()->diagnose(expr->loc, Diagnostics::expectedIntegerConstantNotLiteral); + return nullptr; } + return constIntVal; +} - // Check an expression, coerce it to the `String` type, and then - // ensure that it has a literal (not just compile-time constant) value. - bool SemanticsVisitor::checkLiteralStringVal( - Expr* expr, - String* outVal) - { - // TODO: This should actually perform semantic checking, etc., - // but for now we are just going to look for a direct string - // literal AST node. +// Check an expression, coerce it to the `String` type, and then +// ensure that it has a literal (not just compile-time constant) value. +bool SemanticsVisitor::checkLiteralStringVal(Expr* expr, String* outVal) +{ + // TODO: This should actually perform semantic checking, etc., + // but for now we are just going to look for a direct string + // literal AST node. - if(auto stringLitExpr = as(expr)) + if (auto stringLitExpr = as(expr)) + { + if (outVal) { - if(outVal) - { - *outVal = stringLitExpr->value; - } - return true; + *outVal = stringLitExpr->value; } + return true; + } - getSink()->diagnose(expr, Diagnostics::expectedAStringLiteral); + getSink()->diagnose(expr, Diagnostics::expectedAStringLiteral); - return false; - } + return false; +} - bool SemanticsVisitor::checkCapabilityName(Expr* expr, CapabilityName& outCapabilityName) +bool SemanticsVisitor::checkCapabilityName(Expr* expr, CapabilityName& outCapabilityName) +{ + if (auto varExpr = as(expr)) { - if (auto varExpr = as(expr)) + if (!varExpr->name) + return false; + if (varExpr->name == getSession()->getCompletionRequestTokenName()) { - if (!varExpr->name) - return false; - if (varExpr->name == getSession()->getCompletionRequestTokenName()) - { - auto& suggestions = getLinkage()->contentAssistInfo.completionSuggestions; - suggestions.clear(); - suggestions.scopeKind = CompletionSuggestions::ScopeKind::Capabilities; - } - outCapabilityName = findCapabilityName(varExpr->name->text.getUnownedSlice()); - if (outCapabilityName == CapabilityName::Invalid) - { - getSink()->diagnose(expr, Diagnostics::unknownCapability, varExpr->name); - return false; - } - return true; + auto& suggestions = getLinkage()->contentAssistInfo.completionSuggestions; + suggestions.clear(); + suggestions.scopeKind = CompletionSuggestions::ScopeKind::Capabilities; } - getSink()->diagnose(expr, Diagnostics::expectCapability); - return false; + outCapabilityName = findCapabilityName(varExpr->name->text.getUnownedSlice()); + if (outCapabilityName == CapabilityName::Invalid) + { + getSink()->diagnose(expr, Diagnostics::unknownCapability, varExpr->name); + return false; + } + return true; } + getSink()->diagnose(expr, Diagnostics::expectCapability); + return false; +} - void SemanticsVisitor::visitModifier(Modifier*) - { - // Do nothing with modifiers for now - } +void SemanticsVisitor::visitModifier(Modifier*) +{ + // Do nothing with modifiers for now +} - static bool _isDeclAllowedAsAttribute(DeclRef declRef) - { - if (as(declRef.getDecl())) - return true; - auto structDecl = as(declRef.getDecl()); - if (!structDecl) - return false; - auto attrUsageAttr = structDecl->findModifier(); - if (!attrUsageAttr) - return false; +static bool _isDeclAllowedAsAttribute(DeclRef declRef) +{ + if (as(declRef.getDecl())) return true; - } + auto structDecl = as(declRef.getDecl()); + if (!structDecl) + return false; + auto attrUsageAttr = structDecl->findModifier(); + if (!attrUsageAttr) + return false; + return true; +} - AttributeDecl* SemanticsVisitor::lookUpAttributeDecl(Name* attributeName, Scope* scope) +AttributeDecl* SemanticsVisitor::lookUpAttributeDecl(Name* attributeName, Scope* scope) +{ + if (!attributeName) + return nullptr; + // We start by looking for an existing attribute matching + // the name `attributeName`. + // { - if (!attributeName) - return nullptr; - // We start by looking for an existing attribute matching - // the name `attributeName`. + // Look up the name and see what attributes we find. // + LookupMask lookupMask = LookupMask::Attribute; + if (attributeName == getSession()->getCompletionRequestTokenName()) { - // Look up the name and see what attributes we find. - // - LookupMask lookupMask = LookupMask::Attribute; - if (attributeName == getSession()->getCompletionRequestTokenName()) - { - lookupMask = - LookupMask((uint32_t)LookupMask::Attribute | (uint32_t)LookupMask::type); - } - - auto lookupResult = lookUp(m_astBuilder, this, attributeName, scope, lookupMask); - - if (attributeName == getSession()->getCompletionRequestTokenName()) - { - // If this is a completion request, add the lookup result to linkage. - auto& suggestions = getLinkage()->contentAssistInfo.completionSuggestions; - suggestions.clear(); - suggestions.scopeKind = CompletionSuggestions::ScopeKind::Attribute; - for (auto& item : lookupResult) - { - if (_isDeclAllowedAsAttribute(item.declRef)) - { - suggestions.candidateItems.add(item); - } - } - } + lookupMask = LookupMask((uint32_t)LookupMask::Attribute | (uint32_t)LookupMask::type); + } - // If the result was overloaded, then that means there - // are multiple attributes matching the name, and we - // aren't going to be able to narrow it down. - // - if(lookupResult.isOverloaded()) - return nullptr; + auto lookupResult = lookUp(m_astBuilder, this, attributeName, scope, lookupMask); - // If there is a single valid result, and it names - // an existing attribute declaration, then we can - // use it as the result. - // - if (lookupResult.isValid()) + if (attributeName == getSession()->getCompletionRequestTokenName()) + { + // If this is a completion request, add the lookup result to linkage. + auto& suggestions = getLinkage()->contentAssistInfo.completionSuggestions; + suggestions.clear(); + suggestions.scopeKind = CompletionSuggestions::ScopeKind::Attribute; + for (auto& item : lookupResult) { - auto decl = lookupResult.item.declRef.getDecl(); - if (auto attributeDecl = as(decl)) + if (_isDeclAllowedAsAttribute(item.declRef)) { - return attributeDecl; + suggestions.candidateItems.add(item); } } } - // If there wasn't already an attribute matching the - // given name, then we will look for a `struct` type - // matching the name scheme for user-defined attributes. - // - // If the attribute was `[Something(...)]` then we will - // look for a `struct` named `SomethingAttribute`. - // - LookupResult lookupResult = lookUp(m_astBuilder, this, m_astBuilder->getGlobalSession()->getNameObj(attributeName->text + "Attribute"), scope, LookupMask::type); - // - // If we didn't find a matching type name, then we give up. - // - if (!lookupResult.isValid() || lookupResult.isOverloaded()) - return nullptr; - - - // We only allow a `struct` type to be used as an attribute - // if the type itself has an `[AttributeUsage(...)]` attribute - // attached to it. + // If the result was overloaded, then that means there + // are multiple attributes matching the name, and we + // aren't going to be able to narrow it down. // - auto structDecl = lookupResult.item.declRef.as().getDecl(); - if(!structDecl) - return nullptr; - auto attrUsageAttr = structDecl->findModifier(); - if (!attrUsageAttr) + if (lookupResult.isOverloaded()) return nullptr; - // We will now synthesize a new `AttributeDecl` to mirror - // what was declared on the `struct` type. + // If there is a single valid result, and it names + // an existing attribute declaration, then we can + // use it as the result. // - AttributeDecl* attrDecl = m_astBuilder->create(); - attrDecl->nameAndLoc.name = attributeName; - attrDecl->nameAndLoc.loc = structDecl->nameAndLoc.loc; - attrDecl->loc = structDecl->loc; - - while(attrUsageAttr) + if (lookupResult.isValid()) { - AttributeTargetModifier* targetModifier = m_astBuilder->create(); - targetModifier->syntaxClass = attrUsageAttr->targetSyntaxClass; - targetModifier->loc = attrUsageAttr->loc; - addModifier(attrDecl, targetModifier); - attrUsageAttr = as(attrUsageAttr->next); + auto decl = lookupResult.item.declRef.getDecl(); + if (auto attributeDecl = as(decl)) + { + return attributeDecl; + } } + } - // Every attribute declaration is associated with the type - // of syntax nodes it constructs (via reflection/RTTI). - // - // User-defined attributes create instances of - // `UserDefinedAttribute`. - // - attrDecl->syntaxClass = m_astBuilder->findSyntaxClass(UnownedStringSlice::fromLiteral("UserDefinedAttribute")); + // If there wasn't already an attribute matching the + // given name, then we will look for a `struct` type + // matching the name scheme for user-defined attributes. + // + // If the attribute was `[Something(...)]` then we will + // look for a `struct` named `SomethingAttribute`. + // + LookupResult lookupResult = lookUp( + m_astBuilder, + this, + m_astBuilder->getGlobalSession()->getNameObj(attributeName->text + "Attribute"), + scope, + LookupMask::type); + // + // If we didn't find a matching type name, then we give up. + // + if (!lookupResult.isValid() || lookupResult.isOverloaded()) + return nullptr; + + + // We only allow a `struct` type to be used as an attribute + // if the type itself has an `[AttributeUsage(...)]` attribute + // attached to it. + // + auto structDecl = lookupResult.item.declRef.as().getDecl(); + if (!structDecl) + return nullptr; + auto attrUsageAttr = structDecl->findModifier(); + if (!attrUsageAttr) + return nullptr; + + // We will now synthesize a new `AttributeDecl` to mirror + // what was declared on the `struct` type. + // + AttributeDecl* attrDecl = m_astBuilder->create(); + attrDecl->nameAndLoc.name = attributeName; + attrDecl->nameAndLoc.loc = structDecl->nameAndLoc.loc; + attrDecl->loc = structDecl->loc; + + while (attrUsageAttr) + { + AttributeTargetModifier* targetModifier = m_astBuilder->create(); + targetModifier->syntaxClass = attrUsageAttr->targetSyntaxClass; + targetModifier->loc = attrUsageAttr->loc; + addModifier(attrDecl, targetModifier); + attrUsageAttr = as(attrUsageAttr->next); + } - // The fields of the user-defined `struct` type become - // the parameters of the new attribute. - // - // TODO: This step should skip `static` fields. - // - for(auto member : structDecl->members) + // Every attribute declaration is associated with the type + // of syntax nodes it constructs (via reflection/RTTI). + // + // User-defined attributes create instances of + // `UserDefinedAttribute`. + // + attrDecl->syntaxClass = + m_astBuilder->findSyntaxClass(UnownedStringSlice::fromLiteral("UserDefinedAttribute")); + + // The fields of the user-defined `struct` type become + // the parameters of the new attribute. + // + // TODO: This step should skip `static` fields. + // + for (auto member : structDecl->members) + { + if (auto varMember = as(member)) { - if(auto varMember = as(member)) - { - ensureDecl(varMember, DeclCheckState::CanUseTypeOfValueDecl); + ensureDecl(varMember, DeclCheckState::CanUseTypeOfValueDecl); - ParamDecl* paramDecl = m_astBuilder->create(); - paramDecl->nameAndLoc = member->nameAndLoc; - paramDecl->type = varMember->type; - paramDecl->loc = member->loc; - paramDecl->setCheckState(DeclCheckState::DefinitionChecked); + ParamDecl* paramDecl = m_astBuilder->create(); + paramDecl->nameAndLoc = member->nameAndLoc; + paramDecl->type = varMember->type; + paramDecl->loc = member->loc; + paramDecl->setCheckState(DeclCheckState::DefinitionChecked); - paramDecl->parentDecl = attrDecl; - attrDecl->members.add(paramDecl); - } + paramDecl->parentDecl = attrDecl; + attrDecl->members.add(paramDecl); } + } - // We need to end by putting the new attribute declaration - // into the AST, so that it can be found via lookup. - // - auto parentDecl = structDecl->parentDecl; - // - // TODO: handle the case where `parentDecl` is generic? - // - attrDecl->parentDecl = parentDecl; - parentDecl->members.add(attrDecl); - - SLANG_ASSERT(!parentDecl->isMemberDictionaryValid()); - - // Finally, we perform any required semantic checks on - // the newly constructed attribute decl. - // - // TODO: what check state is relevant here? - // - ensureDecl(attrDecl, DeclCheckState::DefinitionChecked); + // We need to end by putting the new attribute declaration + // into the AST, so that it can be found via lookup. + // + auto parentDecl = structDecl->parentDecl; + // + // TODO: handle the case where `parentDecl` is generic? + // + attrDecl->parentDecl = parentDecl; + parentDecl->members.add(attrDecl); + + SLANG_ASSERT(!parentDecl->isMemberDictionaryValid()); + + // Finally, we perform any required semantic checks on + // the newly constructed attribute decl. + // + // TODO: what check state is relevant here? + // + ensureDecl(attrDecl, DeclCheckState::DefinitionChecked); + + return attrDecl; +} - return attrDecl; +bool SemanticsVisitor::hasIntArgs(Attribute* attr, int numArgs) +{ + if (int(attr->args.getCount()) != numArgs) + { + return false; } - - bool SemanticsVisitor::hasIntArgs(Attribute* attr, int numArgs) + for (int i = 0; i < numArgs; ++i) { - if (int(attr->args.getCount()) != numArgs) + if (!as(attr->args[i])) { return false; } - for (int i = 0; i < numArgs; ++i) - { - if (!as(attr->args[i])) - { - return false; - } - } - return true; } + return true; +} - bool SemanticsVisitor::hasStringArgs(Attribute* attr, int numArgs) +bool SemanticsVisitor::hasStringArgs(Attribute* attr, int numArgs) +{ + if (int(attr->args.getCount()) != numArgs) + { + return false; + } + for (int i = 0; i < numArgs; ++i) { - if (int(attr->args.getCount()) != numArgs) + if (!as(attr->args[i])) { return false; } - for (int i = 0; i < numArgs; ++i) - { - if (!as(attr->args[i])) - { - return false; - } - } - return true; } + return true; +} - bool SemanticsVisitor::getAttributeTargetSyntaxClasses(SyntaxClass & cls, uint32_t typeFlags) +bool SemanticsVisitor::getAttributeTargetSyntaxClasses( + SyntaxClass& cls, + uint32_t typeFlags) +{ + if (typeFlags == (int)UserDefinedAttributeTargets::Struct) { - if (typeFlags == (int)UserDefinedAttributeTargets::Struct) - { - cls = m_astBuilder->findSyntaxClass(UnownedStringSlice::fromLiteral("StructDecl")); - return true; - } - if (typeFlags == (int)UserDefinedAttributeTargets::Var) - { - cls = m_astBuilder->findSyntaxClass(UnownedStringSlice::fromLiteral("VarDecl")); - return true; - } - if (typeFlags == (int)UserDefinedAttributeTargets::Function) - { - cls = m_astBuilder->findSyntaxClass(UnownedStringSlice::fromLiteral("FuncDecl")); - return true; - } - if (typeFlags == (int)UserDefinedAttributeTargets::Param) - { - cls = m_astBuilder->findSyntaxClass(UnownedStringSlice::fromLiteral("ParamDecl")); - return true; - } - return false; + cls = m_astBuilder->findSyntaxClass(UnownedStringSlice::fromLiteral("StructDecl")); + return true; } - - Modifier* SemanticsVisitor::validateAttribute(Attribute* attr, AttributeDecl* attribClassDecl, ModifiableSyntaxNode* attrTarget) + if (typeFlags == (int)UserDefinedAttributeTargets::Var) { - if (auto numThreadsAttr = as(attr)) - { - SLANG_ASSERT(attr->args.getCount() == 3); - - IntVal* values[3]; + cls = m_astBuilder->findSyntaxClass(UnownedStringSlice::fromLiteral("VarDecl")); + return true; + } + if (typeFlags == (int)UserDefinedAttributeTargets::Function) + { + cls = m_astBuilder->findSyntaxClass(UnownedStringSlice::fromLiteral("FuncDecl")); + return true; + } + if (typeFlags == (int)UserDefinedAttributeTargets::Param) + { + cls = m_astBuilder->findSyntaxClass(UnownedStringSlice::fromLiteral("ParamDecl")); + return true; + } + return false; +} - for (int i = 0; i < 3; ++i) - { - IntVal* value = nullptr; +Modifier* SemanticsVisitor::validateAttribute( + Attribute* attr, + AttributeDecl* attribClassDecl, + ModifiableSyntaxNode* attrTarget) +{ + if (auto numThreadsAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 3); - auto arg = attr->args[i]; - if (arg) - { - auto intValue = checkLinkTimeConstantIntVal(arg); - if (!intValue) - { - return nullptr; - } - if (auto constIntVal = as(intValue)) - { - if (constIntVal->getValue() < 1) - { - getSink()->diagnose(attr, Diagnostics::nonPositiveNumThreads, constIntVal->getValue()); - return nullptr; - } - if (intValue->getType() != m_astBuilder->getIntType()) - { - intValue = m_astBuilder->getIntVal(m_astBuilder->getIntType(), constIntVal->getValue()); - } - } - // Make sure we always canonicalize the type to int. - value = intValue; - if (value->getType() != m_astBuilder->getIntType()) - value = m_astBuilder->getTypeCastIntVal(m_astBuilder->getIntType(), value); - } - else - { - value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); - } - values[i] = value; - } + IntVal* values[3]; - numThreadsAttr->x = values[0]; - numThreadsAttr->y = values[1]; - numThreadsAttr->z = values[2]; - } - else if (auto waveSizeAttr = as(attr)) + for (int i = 0; i < 3; ++i) { - SLANG_ASSERT(attr->args.getCount() == 1); - IntVal* value = nullptr; - auto arg = attr->args[0]; + auto arg = attr->args[i]; if (arg) { auto intValue = checkLinkTimeConstantIntVal(arg); @@ -394,1384 +365,1531 @@ namespace Slang } if (auto constIntVal = as(intValue)) { - bool isValidWaveSize = false; - const IntegerLiteralValue waveSize = constIntVal->getValue(); - for (int validWaveSize : { 4, 8, 16, 32, 64, 128 }) + if (constIntVal->getValue() < 1) { - if (validWaveSize == waveSize) - { - isValidWaveSize = true; - break; - } + getSink()->diagnose( + attr, + Diagnostics::nonPositiveNumThreads, + constIntVal->getValue()); + return nullptr; } - if (!isValidWaveSize) + if (intValue->getType() != m_astBuilder->getIntType()) { - getSink()->diagnose(attr, Diagnostics::invalidWaveSize, constIntVal->getValue()); - return nullptr; + intValue = m_astBuilder->getIntVal( + m_astBuilder->getIntType(), + constIntVal->getValue()); } } + // Make sure we always canonicalize the type to int. value = intValue; + if (value->getType() != m_astBuilder->getIntType()) + value = m_astBuilder->getTypeCastIntVal(m_astBuilder->getIntType(), value); } else { value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); } - - waveSizeAttr->numLanes = value; + values[i] = value; } - else if (auto anyValueSizeAttr = as(attr)) - { - // This case handles GLSL-oriented layout attributes - // that take a single integer argument. - if (attr->args.getCount() != 1) - { - return nullptr; - } - - auto value = checkConstantIntVal(attr->args[0]); - if (value == nullptr) - { - return nullptr; - } + numThreadsAttr->x = values[0]; + numThreadsAttr->y = values[1]; + numThreadsAttr->z = values[2]; + } + else if (auto waveSizeAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); - const IRIntegerValue kMaxAnyValueSize = 0x7FFF; - if (value->getValue() > kMaxAnyValueSize) - { - getSink()->diagnose(anyValueSizeAttr->loc, Diagnostics::anyValueSizeExceedsLimit, kMaxAnyValueSize); - return nullptr; - } + IntVal* value = nullptr; - anyValueSizeAttr->size = int32_t(value->getValue()); - } - else if (auto glslRequireShaderInputParameter = as(attr)) + auto arg = attr->args[0]; + if (arg) { - if (attr->args.getCount() != 1) - { - return nullptr; - } - auto value = checkConstantIntVal(attr->args[0]); - if (value == nullptr) + auto intValue = checkLinkTimeConstantIntVal(arg); + if (!intValue) { return nullptr; } - if (value->getValue() < 0) + if (auto constIntVal = as(intValue)) { - return nullptr; + bool isValidWaveSize = false; + const IntegerLiteralValue waveSize = constIntVal->getValue(); + for (int validWaveSize : {4, 8, 16, 32, 64, 128}) + { + if (validWaveSize == waveSize) + { + isValidWaveSize = true; + break; + } + } + if (!isValidWaveSize) + { + getSink()->diagnose( + attr, + Diagnostics::invalidWaveSize, + constIntVal->getValue()); + return nullptr; + } } - glslRequireShaderInputParameter->parameterNumber = int32_t(value->getValue()); + value = intValue; } - else if (auto overloadRankAttr = as(attr)) + else { - if (attr->args.getCount() != 1) - { - return nullptr; - } - auto rank = checkConstantIntVal(attr->args[0]); - if (rank == nullptr) - { - return nullptr; - } - overloadRankAttr->rank = int32_t(rank->getValue()); + value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); } - else if (auto inputAttachmentIndexLayoutAttribute = as(attr)) - { - if (attr->args.getCount() != 1) - return nullptr; - - auto location = checkConstantIntVal(attr->args[0]); - if(!location) - return nullptr; - inputAttachmentIndexLayoutAttribute->location = location->getValue(); - } - else if (auto bindingAttr = as(attr)) + waveSizeAttr->numLanes = value; + } + else if (auto anyValueSizeAttr = as(attr)) + { + // This case handles GLSL-oriented layout attributes + // that take a single integer argument. + + if (attr->args.getCount() != 1) { - // This must be vk::binding or gl::binding (as specified in core.meta.slang under vk_binding/gl_binding) - // Must have 2 int parameters. Ideally this would all be checked from the specification - // in core.meta.slang, but that's not completely implemented. So for now we check here. - if (attr->args.getCount() != 2) - { - return nullptr; - } + return nullptr; + } - // TODO(JS): Prior validation currently doesn't ensure both args are ints (as specified in core.meta.slang), so check here - // to make sure they both are - auto binding = checkConstantIntVal(attr->args[0]); - auto set = checkConstantIntVal(attr->args[1]); + auto value = checkConstantIntVal(attr->args[0]); + if (value == nullptr) + { + return nullptr; + } - if (binding == nullptr || set == nullptr) - { - return nullptr; - } + const IRIntegerValue kMaxAnyValueSize = 0x7FFF; + if (value->getValue() > kMaxAnyValueSize) + { + getSink()->diagnose( + anyValueSizeAttr->loc, + Diagnostics::anyValueSizeExceedsLimit, + kMaxAnyValueSize); + return nullptr; + } - bindingAttr->binding = int32_t(binding->getValue()); - bindingAttr->set = int32_t(set->getValue()); + anyValueSizeAttr->size = int32_t(value->getValue()); + } + else if ( + auto glslRequireShaderInputParameter = as(attr)) + { + if (attr->args.getCount() != 1) + { + return nullptr; } - else if (auto simpleLayoutAttr = as(attr)) + auto value = checkConstantIntVal(attr->args[0]); + if (value == nullptr) { - // This case handles GLSL-oriented layout attributes - // that take a single integer argument. - - if (attr->args.getCount() != 1) - { - return nullptr; - } - - auto value = checkConstantIntVal(attr->args[0]); - if (value == nullptr) - { - return nullptr; - } - - simpleLayoutAttr->value = int32_t(value->getValue()); + return nullptr; } - else if (auto maxVertexCountAttr = as(attr)) + if (value->getValue() < 0) { - SLANG_ASSERT(attr->args.getCount() == 1); - auto val = checkConstantIntVal(attr->args[0]); - - if (!val) return nullptr; - - maxVertexCountAttr->value = (int32_t)val->getValue(); + return nullptr; } - else if (auto instanceAttr = as(attr)) + glslRequireShaderInputParameter->parameterNumber = int32_t(value->getValue()); + } + else if (auto overloadRankAttr = as(attr)) + { + if (attr->args.getCount() != 1) { - SLANG_ASSERT(attr->args.getCount() == 1); - auto val = checkConstantIntVal(attr->args[0]); - - if (!val) return nullptr; - - instanceAttr->value = (int32_t)val->getValue(); + return nullptr; } - else if (auto entryPointAttr = as(attr)) + auto rank = checkConstantIntVal(attr->args[0]); + if (rank == nullptr) { - SLANG_ASSERT(attr->args.getCount() == 1); + return nullptr; + } + overloadRankAttr->rank = int32_t(rank->getValue()); + } + else if ( + auto inputAttachmentIndexLayoutAttribute = + as(attr)) + { + if (attr->args.getCount() != 1) + return nullptr; - String capNameString; - if (!checkLiteralStringVal(attr->args[0], &capNameString)) - { - return nullptr; - } + auto location = checkConstantIntVal(attr->args[0]); + if (!location) + return nullptr; - CapabilityName capName = findCapabilityName(capNameString.getUnownedSlice()); - if (capName != CapabilityName::Invalid) - { - if (isInternalCapabilityName(capName)) - maybeDiagnose(getSink(), this->getOptionSet(), DiagnosticCategory::Capability, attr, Diagnostics::usingInternalCapabilityName, attr, capName); - - // Ensure this capability only defines 1 stage per target, else diagnose an error. - // This is a fatal error, do not allow toggling this error off. - entryPointAttr->capabilitySet = CapabilitySet(capName); - HashSet stageToBeUsed; - for (auto& targetSet : entryPointAttr->capabilitySet.getCapabilityTargetSets()) - { - for(auto& stageSet : targetSet.second.shaderStageSets) - stageToBeUsed.add(stageSet.first); - } + inputAttachmentIndexLayoutAttribute->location = location->getValue(); + } + else if (auto bindingAttr = as(attr)) + { + // This must be vk::binding or gl::binding (as specified in core.meta.slang under + // vk_binding/gl_binding) Must have 2 int parameters. Ideally this would all be checked from + // the specification in core.meta.slang, but that's not completely implemented. So for now + // we check here. + if (attr->args.getCount() != 2) + { + return nullptr; + } - // TODO: Once profiles are removed in favor for `CapabilitySet`s we will beable to use more complex relationships, - // Until then we have an artificial limit that any capabilites used inside '[shader(...)]' must only specify 1 stage type - // uniformly across targets. - if (stageToBeUsed.getCount() > 1) - { - List atomsToPrint; - atomsToPrint.reserve(stageToBeUsed.getCount()); - for (auto i : stageToBeUsed) - atomsToPrint.add(i); - getSink()->diagnose(attr, Diagnostics::capabilityHasMultipleStages, capNameString, atomsToPrint); - } - return entryPointAttr; - } - else - { - // always diagnose this error since nothing can compile with an invalid capability - getSink()->diagnose(attr, Diagnostics::unknownCapability, capNameString); - return nullptr; - } + // TODO(JS): Prior validation currently doesn't ensure both args are ints (as specified in + // core.meta.slang), so check here to make sure they both are + auto binding = checkConstantIntVal(attr->args[0]); + auto set = checkConstantIntVal(attr->args[1]); + + if (binding == nullptr || set == nullptr) + { + return nullptr; } - else if ((as(attr)) || - (as(attr)) || - (as(attr)) || - (as(attr)) || - (as(attr))) + + bindingAttr->binding = int32_t(binding->getValue()); + bindingAttr->set = int32_t(set->getValue()); + } + else if (auto simpleLayoutAttr = as(attr)) + { + // This case handles GLSL-oriented layout attributes + // that take a single integer argument. + + if (attr->args.getCount() != 1) { - // Let it go thru iff single string attribute - if (!hasStringArgs(attr, 1)) - { - getSink()->diagnose(attr, Diagnostics::expectedSingleStringArg, attr->keywordName); - } + return nullptr; } - else if (auto opAttr = as(attr)) + + auto value = checkConstantIntVal(attr->args[0]); + if (value == nullptr) { - auto sink = getSink(); - const auto argsCount = opAttr->args.getCount(); - if (argsCount < 1 || argsCount > 2) - { - sink->diagnose(attr, Diagnostics::attributeArgumentCountMismatch, attr->keywordName, "1...2", argsCount); - } - else if (!as(opAttr->args[0])) + return nullptr; + } + + simpleLayoutAttr->value = int32_t(value->getValue()); + } + else if (auto maxVertexCountAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + auto val = checkConstantIntVal(attr->args[0]); + + if (!val) + return nullptr; + + maxVertexCountAttr->value = (int32_t)val->getValue(); + } + else if (auto instanceAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + auto val = checkConstantIntVal(attr->args[0]); + + if (!val) + return nullptr; + + instanceAttr->value = (int32_t)val->getValue(); + } + else if (auto entryPointAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + + String capNameString; + if (!checkLiteralStringVal(attr->args[0], &capNameString)) + { + return nullptr; + } + + CapabilityName capName = findCapabilityName(capNameString.getUnownedSlice()); + if (capName != CapabilityName::Invalid) + { + if (isInternalCapabilityName(capName)) + maybeDiagnose( + getSink(), + this->getOptionSet(), + DiagnosticCategory::Capability, + attr, + Diagnostics::usingInternalCapabilityName, + attr, + capName); + + // Ensure this capability only defines 1 stage per target, else diagnose an error. + // This is a fatal error, do not allow toggling this error off. + entryPointAttr->capabilitySet = CapabilitySet(capName); + HashSet stageToBeUsed; + for (auto& targetSet : entryPointAttr->capabilitySet.getCapabilityTargetSets()) { - sink->diagnose(attr, Diagnostics::attributeExpectedIntArg, attr->keywordName, 0); + for (auto& stageSet : targetSet.second.shaderStageSets) + stageToBeUsed.add(stageSet.first); } - else if (argsCount > 1 && !as(opAttr->args[1])) + + // TODO: Once profiles are removed in favor for `CapabilitySet`s we will beable to use + // more complex relationships, Until then we have an artificial limit that any + // capabilites used inside '[shader(...)]' must only specify 1 stage type uniformly + // across targets. + if (stageToBeUsed.getCount() > 1) { - sink->diagnose(attr, Diagnostics::attributeExpectedStringArg, attr->keywordName, 1); + List atomsToPrint; + atomsToPrint.reserve(stageToBeUsed.getCount()); + for (auto i : stageToBeUsed) + atomsToPrint.add(i); + getSink()->diagnose( + attr, + Diagnostics::capabilityHasMultipleStages, + capNameString, + atomsToPrint); } + return entryPointAttr; } - else if (as(attr)) + else { - // Let it go thru iff single integral attribute - if (!hasIntArgs(attr, 1)) - { - getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->keywordName); - } + // always diagnose this error since nothing can compile with an invalid capability + getSink()->diagnose(attr, Diagnostics::unknownCapability, capNameString); + return nullptr; + } + } + else if ( + (as(attr)) || (as(attr)) || + (as(attr)) || (as(attr)) || + (as(attr))) + { + // Let it go thru iff single string attribute + if (!hasStringArgs(attr, 1)) + { + getSink()->diagnose(attr, Diagnostics::expectedSingleStringArg, attr->keywordName); + } + } + else if (auto opAttr = as(attr)) + { + auto sink = getSink(); + const auto argsCount = opAttr->args.getCount(); + if (argsCount < 1 || argsCount > 2) + { + sink->diagnose( + attr, + Diagnostics::attributeArgumentCountMismatch, + attr->keywordName, + "1...2", + argsCount); + } + else if (!as(opAttr->args[0])) + { + sink->diagnose(attr, Diagnostics::attributeExpectedIntArg, attr->keywordName, 0); + } + else if (argsCount > 1 && !as(opAttr->args[1])) + { + sink->diagnose(attr, Diagnostics::attributeExpectedStringArg, attr->keywordName, 1); } - else if (auto attrUsageAttr = as(attr)) + } + else if (as(attr)) + { + // Let it go thru iff single integral attribute + if (!hasIntArgs(attr, 1)) { - uint32_t targetClassId = (uint32_t)UserDefinedAttributeTargets::None; - if (attr->args.getCount() == 1) + getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->keywordName); + } + } + else if (auto attrUsageAttr = as(attr)) + { + uint32_t targetClassId = (uint32_t)UserDefinedAttributeTargets::None; + if (attr->args.getCount() == 1) + { + // IntVal* outIntVal; + if (auto cInt = checkConstantEnumVal(attr->args[0])) { - //IntVal* outIntVal; - if (auto cInt = checkConstantEnumVal(attr->args[0])) - { - targetClassId = (uint32_t)(cInt->getValue()); - } - else - { - getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->keywordName); - return nullptr; - } + targetClassId = (uint32_t)(cInt->getValue()); } - if (!getAttributeTargetSyntaxClasses(attrUsageAttr->targetSyntaxClass, targetClassId)) + else { - getSink()->diagnose(attr, Diagnostics::invalidAttributeTarget); + getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->keywordName); return nullptr; } } - else if (const auto unrollAttr = as(attr)) + if (!getAttributeTargetSyntaxClasses(attrUsageAttr->targetSyntaxClass, targetClassId)) + { + getSink()->diagnose(attr, Diagnostics::invalidAttributeTarget); + return nullptr; + } + } + else if (const auto unrollAttr = as(attr)) + { + // Check has an argument. We need this because default behavior is to give an error + // if an attribute has arguments, but not handled explicitly (and the default param will + // come through as 1 arg if nothing is specified) + SLANG_ASSERT(attr->args.getCount() == 1); + } + else if (auto forceUnrollAttr = as(attr)) + { + if (forceUnrollAttr->args.getCount() < 1) + { + getSink()->diagnose(attr, Diagnostics::notEnoughArguments, attr->args.getCount(), 1); + } + auto cint = checkConstantIntVal(attr->args[0]); + if (cint) + forceUnrollAttr->maxIterations = (int32_t)cint->getValue(); + } + else if (auto maxItersAttrs = as(attr)) + { + if (attr->args.getCount() < 1) { - // Check has an argument. We need this because default behavior is to give an error - // if an attribute has arguments, but not handled explicitly (and the default param will come through - // as 1 arg if nothing is specified) - SLANG_ASSERT(attr->args.getCount() == 1); + getSink()->diagnose(attr, Diagnostics::notEnoughArguments, attr->args.getCount(), 1); } - else if (auto forceUnrollAttr = as(attr)) + else { - if (forceUnrollAttr->args.getCount() < 1) - { - getSink()->diagnose(attr, Diagnostics::notEnoughArguments, attr->args.getCount(), 1); - } auto cint = checkConstantIntVal(attr->args[0]); if (cint) - forceUnrollAttr->maxIterations = (int32_t)cint->getValue(); - } - else if (auto maxItersAttrs = as(attr)) - { - if (attr->args.getCount() < 1) { - getSink()->diagnose(attr, Diagnostics::notEnoughArguments, attr->args.getCount(), 1); - } - else - { - auto cint = checkConstantIntVal(attr->args[0]); - if (cint) - { - maxItersAttrs->value = (int32_t) cint->getValue(); - } + maxItersAttrs->value = (int32_t)cint->getValue(); } } - else if (const auto userDefAttr = as(attr)) + } + else if (const auto userDefAttr = as(attr)) + { + // check arguments against attribute parameters defined in attribClassDecl + Index paramIndex = 0; + auto params = attribClassDecl->getMembersOfType(); + for (auto paramDecl : params) { - // check arguments against attribute parameters defined in attribClassDecl - Index paramIndex = 0; - auto params = attribClassDecl->getMembersOfType(); - for (auto paramDecl : params) - { - ensureDecl(paramDecl, DeclCheckState::CanUseTypeOfValueDecl); + ensureDecl(paramDecl, DeclCheckState::CanUseTypeOfValueDecl); - if (paramIndex < attr->args.getCount()) + if (paramIndex < attr->args.getCount()) + { + auto& arg = attr->args[paramIndex]; + bool typeChecked = false; + if (auto basicType = as(paramDecl->getType())) { - auto & arg = attr->args[paramIndex]; - bool typeChecked = false; - if (auto basicType = as(paramDecl->getType())) + if (basicType->getBaseType() == BaseType::Int) { - if (basicType->getBaseType() == BaseType::Int) + if (auto cint = checkConstantIntVal(arg)) { - if (auto cint = checkConstantIntVal(arg)) - { - for (Index ci = attr->intArgVals.getCount(); ci < paramIndex + 1; ci++) - attr->intArgVals.add(nullptr); - attr->intArgVals[(uint32_t)paramIndex] = cint; - } - typeChecked = true; + for (Index ci = attr->intArgVals.getCount(); ci < paramIndex + 1; ci++) + attr->intArgVals.add(nullptr); + attr->intArgVals[(uint32_t)paramIndex] = cint; } - } - if (!typeChecked) - { - arg = CheckTerm(arg); - arg = coerce(CoercionSite::Argument, paramDecl->getType(), arg); + typeChecked = true; } } - paramIndex++; - } - if (params.getCount() < attr->args.getCount()) - { - getSink()->diagnose(attr, Diagnostics::tooManyArguments, attr->args.getCount(), params.getCount()); - } - else if (params.getCount() > attr->args.getCount()) - { - getSink()->diagnose(attr, Diagnostics::notEnoughArguments, attr->args.getCount(), params.getCount()); + if (!typeChecked) + { + arg = CheckTerm(arg); + arg = coerce(CoercionSite::Argument, paramDecl->getType(), arg); + } } + paramIndex++; } - else if (auto diffAttr = as(attr)) + if (params.getCount() < attr->args.getCount()) { - SLANG_ASSERT(attr->args.getCount() == 1); - auto cint = checkConstantIntVal(attr->args[0]); - if (cint) - diffAttr->maxOrder = (int32_t)cint->getValue(); + getSink()->diagnose( + attr, + Diagnostics::tooManyArguments, + attr->args.getCount(), + params.getCount()); } - else if (auto formatAttr = as(attr)) + else if (params.getCount() > attr->args.getCount()) { - SLANG_ASSERT(attr->args.getCount() == 1); - - String formatName; - if(!checkLiteralStringVal(attr->args[0], &formatName)) - { - return nullptr; - } - - ImageFormat format = ImageFormat::unknown; - - if (attr->keywordName->text.getUnownedSlice() == toSlice("image")) - { - if(!findImageFormatByName(formatName.getUnownedSlice(), &format)) - { - getSink()->diagnose(attr->args[0], Diagnostics::unknownImageFormatName, formatName); - } - } - else - { - if (!findVkImageFormatByName(formatName.getUnownedSlice(), &format)) - { - getSink()->diagnose(attr->args[0], Diagnostics::unknownImageFormatName, formatName); - } - } - - formatAttr->format = format; + getSink()->diagnose( + attr, + Diagnostics::notEnoughArguments, + attr->args.getCount(), + params.getCount()); } - else if (auto allowAttr = as(attr)) + } + else if (auto diffAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + auto cint = checkConstantIntVal(attr->args[0]); + if (cint) + diffAttr->maxOrder = (int32_t)cint->getValue(); + } + else if (auto formatAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + + String formatName; + if (!checkLiteralStringVal(attr->args[0], &formatName)) { - SLANG_ASSERT(attr->args.getCount() == 1); + return nullptr; + } - String diagnosticName; - if(!checkLiteralStringVal(attr->args[0], &diagnosticName)) - { - return nullptr; - } + ImageFormat format = ImageFormat::unknown; - auto diagnosticInfo = findDiagnosticByName(diagnosticName.getUnownedSlice()); - if(!diagnosticInfo) + if (attr->keywordName->text.getUnownedSlice() == toSlice("image")) + { + if (!findImageFormatByName(formatName.getUnownedSlice(), &format)) { - getSink()->diagnose(attr->args[0], Diagnostics::unknownDiagnosticName, diagnosticName); + getSink()->diagnose(attr->args[0], Diagnostics::unknownImageFormatName, formatName); } - - allowAttr->diagnostic = diagnosticInfo; } - else if (auto dllImportAttr = as(attr)) + else { - SLANG_ASSERT(attr->args.getCount() == 1 || attr->args.getCount() == 2); - - String libraryName; - if (!checkLiteralStringVal(dllImportAttr->args[0], &libraryName)) - { - return nullptr; - } - dllImportAttr->modulePath = libraryName; - - String functionName; - if (dllImportAttr->args.getCount() == 2 && !checkLiteralStringVal(dllImportAttr->args[1], &functionName)) + if (!findVkImageFormatByName(formatName.getUnownedSlice(), &format)) { - return nullptr; + getSink()->diagnose(attr->args[0], Diagnostics::unknownImageFormatName, formatName); } - dllImportAttr->functionName = functionName; } - else if (auto rayPayloadAttr = as(attr)) - { - SLANG_ASSERT(attr->args.getCount() == 1); - auto val = checkConstantIntVal(attr->args[0]); - if (!val) return nullptr; + formatAttr->format = format; + } + else if (auto allowAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); - rayPayloadAttr->location = (int32_t)val->getValue(); - } - else if (auto rayPayloadInAttr = as(attr)) + String diagnosticName; + if (!checkLiteralStringVal(attr->args[0], &diagnosticName)) { - SLANG_ASSERT(attr->args.getCount() == 1); - auto val = checkConstantIntVal(attr->args[0]); - if (!val) return nullptr; - rayPayloadInAttr->location = (int32_t)val->getValue(); + return nullptr; } - else if (auto callablePayloadAttr = as(attr)) - { - SLANG_ASSERT(attr->args.getCount() == 1); - auto val = checkConstantIntVal(attr->args[0]); - if (!val) return nullptr; - - callablePayloadAttr->location = (int32_t)val->getValue(); - } - else if (auto callablePayloadInAttr = as(attr)) + auto diagnosticInfo = findDiagnosticByName(diagnosticName.getUnownedSlice()); + if (!diagnosticInfo) { - SLANG_ASSERT(attr->args.getCount() == 1); - auto val = checkConstantIntVal(attr->args[0]); - if (!val) return nullptr; - callablePayloadInAttr->location = (int32_t)val->getValue(); + getSink()->diagnose(attr->args[0], Diagnostics::unknownDiagnosticName, diagnosticName); } - else if (auto hitObjectAttributesAttr = as(attr)) - { - SLANG_ASSERT(attr->args.getCount() == 1); - auto val = checkConstantIntVal(attr->args[0]); - if (!val) return nullptr; - - hitObjectAttributesAttr->location = (int32_t)val->getValue(); - } - else if (auto constantIdAttr = as(attr)) - { - SLANG_ASSERT(attr->args.getCount() == 1); - auto val = checkConstantIntVal(attr->args[0]); + allowAttr->diagnostic = diagnosticInfo; + } + else if (auto dllImportAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1 || attr->args.getCount() == 2); - if (!val) return nullptr; - constantIdAttr->location = (int32_t)val->getValue(); - } - else if (as(attr) || as(attr)) + String libraryName; + if (!checkLiteralStringVal(dllImportAttr->args[0], &libraryName)) { - SLANG_ASSERT(attr->args.getCount() == 1); - SLANG_ASSERT(as(attrTarget)); - if (auto derivativeAttr = as(attr)) - derivativeAttr->funcExpr = attr->args[0]; - else if (auto primalSubstAttr = as(attr)) - primalSubstAttr->funcExpr = attr->args[0]; + return nullptr; } - else if (as(attr) || as(attr)) + dllImportAttr->modulePath = libraryName; + + String functionName; + if (dllImportAttr->args.getCount() == 2 && + !checkLiteralStringVal(dllImportAttr->args[1], &functionName)) { - SLANG_ASSERT(attr->args.getCount() == 1); - SLANG_ASSERT(as(attrTarget)); - if (auto derivativeOfAttr = as(attr)) - derivativeOfAttr->funcExpr = attr->args[0]; - else if (auto primalOfAttr = as(attr)) - primalOfAttr->funcExpr = attr->args[0]; + return nullptr; } - else if (auto preferRecomputeAttr = as(attr)) - { - SLANG_ASSERT(attr->args.getCount() == 1); - SLANG_ASSERT(as(attrTarget)); + dllImportAttr->functionName = functionName; + } + else if (auto rayPayloadAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + auto val = checkConstantIntVal(attr->args[0]); + + if (!val) + return nullptr; + + rayPayloadAttr->location = (int32_t)val->getValue(); + } + else if (auto rayPayloadInAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + auto val = checkConstantIntVal(attr->args[0]); + if (!val) + return nullptr; + rayPayloadInAttr->location = (int32_t)val->getValue(); + } + else if (auto callablePayloadAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + auto val = checkConstantIntVal(attr->args[0]); + + if (!val) + return nullptr; + + callablePayloadAttr->location = (int32_t)val->getValue(); + } + else if (auto callablePayloadInAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + auto val = checkConstantIntVal(attr->args[0]); + if (!val) + return nullptr; + callablePayloadInAttr->location = (int32_t)val->getValue(); + } + else if (auto hitObjectAttributesAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + auto val = checkConstantIntVal(attr->args[0]); + + if (!val) + return nullptr; + + hitObjectAttributesAttr->location = (int32_t)val->getValue(); + } + else if (auto constantIdAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + auto val = checkConstantIntVal(attr->args[0]); - auto val = checkConstantIntVal(attr->args[0]); - if (!val) return nullptr; + if (!val) + return nullptr; + constantIdAttr->location = (int32_t)val->getValue(); + } + else if (as(attr) || as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + SLANG_ASSERT(as(attrTarget)); + if (auto derivativeAttr = as(attr)) + derivativeAttr->funcExpr = attr->args[0]; + else if (auto primalSubstAttr = as(attr)) + primalSubstAttr->funcExpr = attr->args[0]; + } + else if (as(attr) || as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + SLANG_ASSERT(as(attrTarget)); + if (auto derivativeOfAttr = as(attr)) + derivativeOfAttr->funcExpr = attr->args[0]; + else if (auto primalOfAttr = as(attr)) + primalOfAttr->funcExpr = attr->args[0]; + } + else if (auto preferRecomputeAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + SLANG_ASSERT(as(attrTarget)); + + auto val = checkConstantIntVal(attr->args[0]); + if (!val) + return nullptr; - preferRecomputeAttr->sideEffectBehavior = (PreferRecomputeAttribute::SideEffectBehavior) val->getValue(); + preferRecomputeAttr->sideEffectBehavior = + (PreferRecomputeAttribute::SideEffectBehavior)val->getValue(); + } + else if (auto comInterfaceAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + String guid; + if (!checkLiteralStringVal(comInterfaceAttr->args[0], &guid)) + { + return nullptr; } - else if (auto comInterfaceAttr = as(attr)) + StringBuilder resultGUID; + for (auto ch : guid) { - SLANG_ASSERT(attr->args.getCount() == 1); - String guid; - if (!checkLiteralStringVal(comInterfaceAttr->args[0], &guid)) + if (CharUtil::isHexDigit(ch)) { - return nullptr; + resultGUID.appendChar(ch); } - StringBuilder resultGUID; - for (auto ch : guid) + else if (ch == '-') { - if (CharUtil::isHexDigit(ch)) - { - resultGUID.appendChar(ch); - } - else if (ch == '-') - { - continue; - } - else - { - getSink()->diagnose(attr, Diagnostics::invalidGUID, guid); - return nullptr; - } + continue; } - comInterfaceAttr->guid = resultGUID.toString(); - if (comInterfaceAttr->guid.getLength() != 32) + else { getSink()->diagnose(attr, Diagnostics::invalidGUID, guid); return nullptr; } } - else if (const auto derivativeMemberAttr = as(attr)) + comInterfaceAttr->guid = resultGUID.toString(); + if (comInterfaceAttr->guid.getLength() != 32) { - auto varDecl = as(attrTarget); - if (!varDecl) - { - getSink()->diagnose(attr, Diagnostics::attributeNotApplicable, attr->getKeywordName()); - return nullptr; - } + getSink()->diagnose(attr, Diagnostics::invalidGUID, guid); + return nullptr; } - else if (auto deprecatedAttr = as(attr)) + } + else if (const auto derivativeMemberAttr = as(attr)) + { + auto varDecl = as(attrTarget); + if (!varDecl) { - SLANG_ASSERT(attr->args.getCount() == 1); - - String message; - if(!checkLiteralStringVal(attr->args[0], &message)) - { - return nullptr; - } - - deprecatedAttr->message = message; + getSink()->diagnose(attr, Diagnostics::attributeNotApplicable, attr->getKeywordName()); + return nullptr; } - else if (auto knownBuiltinAttr = as(attr)) + } + else if (auto deprecatedAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + + String message; + if (!checkLiteralStringVal(attr->args[0], &message)) { - SLANG_ASSERT(attr->args.getCount() == 1); + return nullptr; + } - String name; - if(!checkLiteralStringVal(attr->args[0], &name)) - { - return nullptr; - } + deprecatedAttr->message = message; + } + else if (auto knownBuiltinAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); - knownBuiltinAttr->name = name; - } - else if (auto pyExportAttr = as(attr)) + String name; + if (!checkLiteralStringVal(attr->args[0], &name)) { - // Check name string. - SLANG_ASSERT(attr->args.getCount() == 1); + return nullptr; + } - String name; - if(!checkLiteralStringVal(attr->args[0], &name)) - { - return nullptr; - } + knownBuiltinAttr->name = name; + } + else if (auto pyExportAttr = as(attr)) + { + // Check name string. + SLANG_ASSERT(attr->args.getCount() == 1); - pyExportAttr->name = name; - } - else if (auto requireCapAttr = as(attr)) + String name; + if (!checkLiteralStringVal(attr->args[0], &name)) { - List capabilityNames; - for (auto& arg : attr->args) - { - CapabilityName capName; - if (checkCapabilityName(arg, capName)) - { - capabilityNames.add(capName); - if(isInternalCapabilityName(capName)) - maybeDiagnose(getSink(), this->getOptionSet(), DiagnosticCategory::Capability, attr, Diagnostics::usingInternalCapabilityName, attr, capName); - } - } - requireCapAttr->capabilitySet = CapabilitySet(capabilityNames); - if (requireCapAttr->capabilitySet.isInvalid()) - maybeDiagnose(getSink(), this->getOptionSet(), DiagnosticCategory::Capability, attr, Diagnostics::unexpectedCapability, attr, CapabilityName::Invalid); + return nullptr; } - else if (auto requirePreludeAttr = as(attr)) + + pyExportAttr->name = name; + } + else if (auto requireCapAttr = as(attr)) + { + List capabilityNames; + for (auto& arg : attr->args) { - if (attr->args.getCount() > 2) - { - getSink()->diagnose(attr, Diagnostics::tooManyArguments, attr->args.getCount(), 0); - return nullptr; - } - else if (attr->args.getCount() < 2) - { - getSink()->diagnose(attr, Diagnostics::notEnoughArguments, attr->args.getCount(), 2); - return nullptr; - } CapabilityName capName; - if (!checkCapabilityName(attr->args[0], capName)) - { - return nullptr; - } - requirePreludeAttr->capabilitySet = CapabilitySet(capName); - if (auto stringLitExpr = as(attr->args[1])) + if (checkCapabilityName(arg, capName)) { - requirePreludeAttr->prelude = getStringLiteralTokenValue(stringLitExpr->token); - } - else - { - getSink()->diagnose(attr->args[1], Diagnostics::expectedAStringLiteral); - return nullptr; - } - return attr; + capabilityNames.add(capName); + if (isInternalCapabilityName(capName)) + maybeDiagnose( + getSink(), + this->getOptionSet(), + DiagnosticCategory::Capability, + attr, + Diagnostics::usingInternalCapabilityName, + attr, + capName); + } + } + requireCapAttr->capabilitySet = CapabilitySet(capabilityNames); + if (requireCapAttr->capabilitySet.isInvalid()) + maybeDiagnose( + getSink(), + this->getOptionSet(), + DiagnosticCategory::Capability, + attr, + Diagnostics::unexpectedCapability, + attr, + CapabilityName::Invalid); + } + else if (auto requirePreludeAttr = as(attr)) + { + if (attr->args.getCount() > 2) + { + getSink()->diagnose(attr, Diagnostics::tooManyArguments, attr->args.getCount(), 0); + return nullptr; + } + else if (attr->args.getCount() < 2) + { + getSink()->diagnose(attr, Diagnostics::notEnoughArguments, attr->args.getCount(), 2); + return nullptr; + } + CapabilityName capName; + if (!checkCapabilityName(attr->args[0], capName)) + { + return nullptr; + } + requirePreludeAttr->capabilitySet = CapabilitySet(capName); + if (auto stringLitExpr = as(attr->args[1])) + { + requirePreludeAttr->prelude = getStringLiteralTokenValue(stringLitExpr->token); } else { - if(attr->args.getCount() == 0) - { - // If the attribute took no arguments, then we will - // assume it is valid as written. - } - else - { - // We should be special-casing the checking of any attribute - // with a non-zero number of arguments. - getSink()->diagnose(attr, Diagnostics::tooManyArguments, attr->args.getCount(), 0); - return nullptr; - } + getSink()->diagnose(attr->args[1], Diagnostics::expectedAStringLiteral); + return nullptr; } - return attr; } - - AttributeBase* SemanticsVisitor::checkAttribute( - UncheckedAttribute* uncheckedAttr, - ModifiableSyntaxNode* attrTarget) + else { - auto attrName = uncheckedAttr->getKeywordName(); - auto attrDecl = lookUpAttributeDecl( - attrName, - uncheckedAttr->scope); - - if(!attrDecl) + if (attr->args.getCount() == 0) { - getSink()->diagnose(uncheckedAttr, Diagnostics::unknownAttributeName, attrName); - return uncheckedAttr; + // If the attribute took no arguments, then we will + // assume it is valid as written. } - - if (!attrDecl->syntaxClass.isSubClassOf()) + else { - SLANG_DIAGNOSE_UNEXPECTED(getSink(), attrDecl, "attribute declaration does not reference an attribute class"); - return uncheckedAttr; + // We should be special-casing the checking of any attribute + // with a non-zero number of arguments. + getSink()->diagnose(attr, Diagnostics::tooManyArguments, attr->args.getCount(), 0); + return nullptr; } + } - // Manage scope - NodeBase* attrInstance = attrDecl->syntaxClass.createInstance(m_astBuilder); - auto attr = as(attrInstance); - if(!attr) - { - SLANG_DIAGNOSE_UNEXPECTED(getSink(), attrDecl, "attribute class did not yield an attribute object"); - return uncheckedAttr; - } + return attr; +} - // We are going to replace the unchecked attribute with the checked one. +AttributeBase* SemanticsVisitor::checkAttribute( + UncheckedAttribute* uncheckedAttr, + ModifiableSyntaxNode* attrTarget) +{ + auto attrName = uncheckedAttr->getKeywordName(); + auto attrDecl = lookUpAttributeDecl(attrName, uncheckedAttr->scope); - // First copy all of the state over from the original attribute. - attr->keywordName = uncheckedAttr->keywordName; - attr->originalIdentifierToken = uncheckedAttr->originalIdentifierToken; - attr->args = uncheckedAttr->args; - attr->loc = uncheckedAttr->loc; - attr->attributeDecl = attrDecl; + if (!attrDecl) + { + getSink()->diagnose(uncheckedAttr, Diagnostics::unknownAttributeName, attrName); + return uncheckedAttr; + } - // We will start with checking steps that can be applied independent - // of the concrete attribute type that was selected. These only need - // us to look at the attribute declaration itself. - // - // Start by doing argument/parameter matching - UInt argCount = attr->args.getCount(); - UInt paramCounter = 0; - bool mismatch = false; - for(auto paramDecl : attrDecl->getMembersOfType()) - { - UInt paramIndex = paramCounter++; - if( paramIndex < argCount ) + if (!attrDecl->syntaxClass.isSubClassOf()) + { + SLANG_DIAGNOSE_UNEXPECTED( + getSink(), + attrDecl, + "attribute declaration does not reference an attribute class"); + return uncheckedAttr; + } + + // Manage scope + NodeBase* attrInstance = attrDecl->syntaxClass.createInstance(m_astBuilder); + auto attr = as(attrInstance); + if (!attr) + { + SLANG_DIAGNOSE_UNEXPECTED( + getSink(), + attrDecl, + "attribute class did not yield an attribute object"); + return uncheckedAttr; + } + + // We are going to replace the unchecked attribute with the checked one. + + // First copy all of the state over from the original attribute. + attr->keywordName = uncheckedAttr->keywordName; + attr->originalIdentifierToken = uncheckedAttr->originalIdentifierToken; + attr->args = uncheckedAttr->args; + attr->loc = uncheckedAttr->loc; + attr->attributeDecl = attrDecl; + + // We will start with checking steps that can be applied independent + // of the concrete attribute type that was selected. These only need + // us to look at the attribute declaration itself. + // + // Start by doing argument/parameter matching + UInt argCount = attr->args.getCount(); + UInt paramCounter = 0; + bool mismatch = false; + for (auto paramDecl : attrDecl->getMembersOfType()) + { + UInt paramIndex = paramCounter++; + if (paramIndex < argCount) + { + // TODO: support checking the argument against the declared + // type for the parameter. + } + else + { + // We didn't have enough arguments for the + // number of parameters declared. + if (const auto defaultArg = paramDecl->initExpr) { - // TODO: support checking the argument against the declared - // type for the parameter. + // The attribute declaration provided a default, + // so we should use that. + // + // TODO: we need to figure out how to hook up + // default arguments as needed. + // For now just copy the expression over. + + attr->args.add(paramDecl->initExpr); } else { - // We didn't have enough arguments for the - // number of parameters declared. - if(const auto defaultArg = paramDecl->initExpr) - { - // The attribute declaration provided a default, - // so we should use that. - // - // TODO: we need to figure out how to hook up - // default arguments as needed. - // For now just copy the expression over. - - attr->args.add(paramDecl->initExpr); - } - else - { - mismatch = true; - } + mismatch = true; } } - UInt paramCount = paramCounter; + } + UInt paramCount = paramCounter; - if(mismatch) - { - getSink()->diagnose(attr, Diagnostics::attributeArgumentCountMismatch, attrName, paramCount, argCount); - return uncheckedAttr; - } + if (mismatch) + { + getSink()->diagnose( + attr, + Diagnostics::attributeArgumentCountMismatch, + attrName, + paramCount, + argCount); + return uncheckedAttr; + } - // The next bit of validation that we can apply semi-generically - // is to validate that the target for this attribute is a valid - // one for the chosen attribute. - // - // The attribute declaration will have one or more `AttributeTargetModifier`s - // that each specify a syntax class that the attribute can be applied to. - // If any of these match `attrTarget`, then we are good. - // - bool validTarget = false; - for(auto attrTargetMod : attrDecl->getModifiersOfType()) - { - if(attrTarget->getClass().isSubClassOf(attrTargetMod->syntaxClass)) - { - validTarget = true; - break; - } - } - if(!validTarget) + // The next bit of validation that we can apply semi-generically + // is to validate that the target for this attribute is a valid + // one for the chosen attribute. + // + // The attribute declaration will have one or more `AttributeTargetModifier`s + // that each specify a syntax class that the attribute can be applied to. + // If any of these match `attrTarget`, then we are good. + // + bool validTarget = false; + for (auto attrTargetMod : attrDecl->getModifiersOfType()) + { + if (attrTarget->getClass().isSubClassOf(attrTargetMod->syntaxClass)) { - getSink()->diagnose(attr, Diagnostics::attributeNotApplicable, attrName); - return uncheckedAttr; + validTarget = true; + break; } + } + if (!validTarget) + { + getSink()->diagnose(attr, Diagnostics::attributeNotApplicable, attrName); + return uncheckedAttr; + } - // Now apply type-specific validation to the attribute. - if(!validateAttribute(attr, attrDecl, attrTarget)) - { - return uncheckedAttr; - } + // Now apply type-specific validation to the attribute. + if (!validateAttribute(attr, attrDecl, attrTarget)) + { + return uncheckedAttr; + } - return attr; + return attr; +} + +ASTNodeType getModifierConflictGroupKind(ASTNodeType modifierType) +{ + switch (modifierType) + { + // Allowed only on parameters and global variables. + case ASTNodeType::InModifier: return modifierType; + case ASTNodeType::OutModifier: + case ASTNodeType::RefModifier: + case ASTNodeType::ConstRefModifier: + case ASTNodeType::InOutModifier: + return ASTNodeType::OutModifier; + + // Modifiers that are their own exclusive group. + case ASTNodeType::GLSLLayoutModifier: + case ASTNodeType::GLSLParsedLayoutModifier: + case ASTNodeType::GLSLLocationLayoutModifier: + case ASTNodeType::GLSLInputAttachmentIndexLayoutAttribute: + case ASTNodeType::GLSLOffsetLayoutAttribute: + case ASTNodeType::GLSLUnparsedLayoutModifier: + case ASTNodeType::GLSLLayoutModifierGroupMarker: + case ASTNodeType::GLSLLayoutModifierGroupBegin: + case ASTNodeType::GLSLLayoutModifierGroupEnd: + case ASTNodeType::GLSLBufferModifier: + case ASTNodeType::MemoryQualifierSetModifier: + case ASTNodeType::GLSLWriteOnlyModifier: + case ASTNodeType::GLSLReadOnlyModifier: + case ASTNodeType::GLSLVolatileModifier: + case ASTNodeType::GLSLRestrictModifier: + case ASTNodeType::GLSLPatchModifier: + case ASTNodeType::RayPayloadAccessSemantic: + case ASTNodeType::RayPayloadReadSemantic: + case ASTNodeType::RayPayloadWriteSemantic: + case ASTNodeType::GloballyCoherentModifier: + case ASTNodeType::PreciseModifier: + case ASTNodeType::IntrinsicOpModifier: + case ASTNodeType::InlineModifier: + case ASTNodeType::HLSLExportModifier: + case ASTNodeType::ExternCppModifier: + case ASTNodeType::ExportedModifier: + case ASTNodeType::ConstModifier: + case ASTNodeType::ConstExprModifier: + case ASTNodeType::MatrixLayoutModifier: + case ASTNodeType::RowMajorLayoutModifier: + case ASTNodeType::HLSLRowMajorLayoutModifier: + case ASTNodeType::GLSLColumnMajorLayoutModifier: + case ASTNodeType::ColumnMajorLayoutModifier: + case ASTNodeType::HLSLColumnMajorLayoutModifier: + case ASTNodeType::GLSLRowMajorLayoutModifier: + case ASTNodeType::HLSLEffectSharedModifier: + case ASTNodeType::HLSLVolatileModifier: + case ASTNodeType::GLSLPrecisionModifier: + case ASTNodeType::HLSLGroupSharedModifier: return modifierType; + + case ASTNodeType::HLSLStaticModifier: + case ASTNodeType::ActualGlobalModifier: + case ASTNodeType::HLSLUniformModifier: return ASTNodeType::HLSLStaticModifier; + + case ASTNodeType::HLSLNoInterpolationModifier: + case ASTNodeType::HLSLNoPerspectiveModifier: + case ASTNodeType::HLSLLinearModifier: + case ASTNodeType::HLSLSampleModifier: + case ASTNodeType::HLSLCentroidModifier: + case ASTNodeType::PerVertexModifier: return ASTNodeType::InterpolationModeModifier; + + case ASTNodeType::PrefixModifier: + case ASTNodeType::PostfixModifier: return ASTNodeType::PrefixModifier; + + case ASTNodeType::BuiltinModifier: + case ASTNodeType::PublicModifier: + case ASTNodeType::PrivateModifier: + case ASTNodeType::InternalModifier: return ASTNodeType::VisibilityModifier; + + default: return ASTNodeType::NodeBase; } +} - ASTNodeType getModifierConflictGroupKind(ASTNodeType modifierType) - { - switch (modifierType) - { - // Allowed only on parameters and global variables. - case ASTNodeType::InModifier: - return modifierType; - case ASTNodeType::OutModifier: - case ASTNodeType::RefModifier: - case ASTNodeType::ConstRefModifier: - case ASTNodeType::InOutModifier: - return ASTNodeType::OutModifier; - - // Modifiers that are their own exclusive group. - case ASTNodeType::GLSLLayoutModifier: - case ASTNodeType::GLSLParsedLayoutModifier: - case ASTNodeType::GLSLLocationLayoutModifier: - case ASTNodeType::GLSLInputAttachmentIndexLayoutAttribute: - case ASTNodeType::GLSLOffsetLayoutAttribute: - case ASTNodeType::GLSLUnparsedLayoutModifier: - case ASTNodeType::GLSLLayoutModifierGroupMarker: - case ASTNodeType::GLSLLayoutModifierGroupBegin: - case ASTNodeType::GLSLLayoutModifierGroupEnd: - case ASTNodeType::GLSLBufferModifier: - case ASTNodeType::MemoryQualifierSetModifier: - case ASTNodeType::GLSLWriteOnlyModifier: - case ASTNodeType::GLSLReadOnlyModifier: - case ASTNodeType::GLSLVolatileModifier: - case ASTNodeType::GLSLRestrictModifier: - case ASTNodeType::GLSLPatchModifier: - case ASTNodeType::RayPayloadAccessSemantic: - case ASTNodeType::RayPayloadReadSemantic: - case ASTNodeType::RayPayloadWriteSemantic: - case ASTNodeType::GloballyCoherentModifier: - case ASTNodeType::PreciseModifier: - case ASTNodeType::IntrinsicOpModifier: - case ASTNodeType::InlineModifier: - case ASTNodeType::HLSLExportModifier: - case ASTNodeType::ExternCppModifier: - case ASTNodeType::ExportedModifier: - case ASTNodeType::ConstModifier: - case ASTNodeType::ConstExprModifier: - case ASTNodeType::MatrixLayoutModifier: - case ASTNodeType::RowMajorLayoutModifier: - case ASTNodeType::HLSLRowMajorLayoutModifier: - case ASTNodeType::GLSLColumnMajorLayoutModifier: - case ASTNodeType::ColumnMajorLayoutModifier: - case ASTNodeType::HLSLColumnMajorLayoutModifier: - case ASTNodeType::GLSLRowMajorLayoutModifier: - case ASTNodeType::HLSLEffectSharedModifier: - case ASTNodeType::HLSLVolatileModifier: - case ASTNodeType::GLSLPrecisionModifier: - case ASTNodeType::HLSLGroupSharedModifier: - return modifierType; - - case ASTNodeType::HLSLStaticModifier: - case ASTNodeType::ActualGlobalModifier: - case ASTNodeType::HLSLUniformModifier: - return ASTNodeType::HLSLStaticModifier; - - case ASTNodeType::HLSLNoInterpolationModifier: - case ASTNodeType::HLSLNoPerspectiveModifier: - case ASTNodeType::HLSLLinearModifier: - case ASTNodeType::HLSLSampleModifier: - case ASTNodeType::HLSLCentroidModifier: - case ASTNodeType::PerVertexModifier: - return ASTNodeType::InterpolationModeModifier; - - case ASTNodeType::PrefixModifier: - case ASTNodeType::PostfixModifier: - return ASTNodeType::PrefixModifier; - - case ASTNodeType::BuiltinModifier: - case ASTNodeType::PublicModifier: - case ASTNodeType::PrivateModifier: - case ASTNodeType::InternalModifier: - return ASTNodeType::VisibilityModifier; - - default: - return ASTNodeType::NodeBase; - } - } - - bool isModifierAllowedOnDecl(bool isGLSLInput, ASTNodeType modifierType, Decl* decl) - { - switch (modifierType) - { - // In addition to the above cases, these are also present on empty - // global declarations, for instance - // layout(local_size_x=1) in; - case ASTNodeType::InModifier: - case ASTNodeType::InOutModifier: - case ASTNodeType::OutModifier: - case ASTNodeType::GLSLLayoutModifier: - case ASTNodeType::GLSLParsedLayoutModifier: - case ASTNodeType::GLSLLocationLayoutModifier: - case ASTNodeType::GLSLInputAttachmentIndexLayoutAttribute: - case ASTNodeType::GLSLOffsetLayoutAttribute: - case ASTNodeType::GLSLUnparsedLayoutModifier: - case ASTNodeType::GLSLLayoutModifierGroupMarker: - case ASTNodeType::GLSLLayoutModifierGroupBegin: - case ASTNodeType::GLSLLayoutModifierGroupEnd: - // If we are in GLSL mode, also allow these but otherwise fall to - // the regular check - if(isGLSLInput && as(decl) && isGlobalDecl(decl)) - return true; - [[fallthrough]]; - - case ASTNodeType::RefModifier: - case ASTNodeType::ConstRefModifier: - case ASTNodeType::GLSLBufferModifier: - case ASTNodeType::GLSLPatchModifier: - case ASTNodeType::RayPayloadAccessSemantic: - case ASTNodeType::RayPayloadReadSemantic: - case ASTNodeType::RayPayloadWriteSemantic: - return (as(decl) && isGlobalDecl(decl)) || as(decl) || as(decl); - - case ASTNodeType::GLSLWriteOnlyModifier: - case ASTNodeType::GLSLReadOnlyModifier: - case ASTNodeType::GLSLVolatileModifier: - case ASTNodeType::GLSLRestrictModifier: - if(isGLSLInput) - return (as(decl) && (isGlobalDecl(decl)) || as(decl) || as(decl)) - || as(getParentDecl(decl)) && isGlobalDecl(getParentDecl(decl)); - return (as(decl) && (isGlobalDecl(decl)) || as(decl) || as(decl)); - - case ASTNodeType::GloballyCoherentModifier: - case ASTNodeType::HLSLVolatileModifier: - if(isGLSLInput) - return as(decl) && (isGlobalDecl(decl) || as(getParentDecl(decl)) || as(decl)) - || as(decl) && isGlobalDecl(decl) || as(decl) || (as(getParentDecl(decl)) && isGlobalDecl(getParentDecl(decl))); - return as(decl) && (isGlobalDecl(decl) || as(getParentDecl(decl)) || as(decl)); - - // Allowed only on parameters, struct fields and global variables. - case ASTNodeType::InterpolationModeModifier: - case ASTNodeType::HLSLNoInterpolationModifier: - case ASTNodeType::HLSLNoPerspectiveModifier: - case ASTNodeType::HLSLLinearModifier: - case ASTNodeType::HLSLSampleModifier: - case ASTNodeType::HLSLCentroidModifier: - case ASTNodeType::PerVertexModifier: - case ASTNodeType::HLSLUniformModifier: - case ASTNodeType::DynamicUniformModifier: - return (as(decl) && (isGlobalDecl(decl) || as(getParentDecl(decl)))) || as(decl); - - case ASTNodeType::HLSLSemantic: - case ASTNodeType::HLSLLayoutSemantic: - case ASTNodeType::HLSLRegisterSemantic: - case ASTNodeType::HLSLPackOffsetSemantic: - case ASTNodeType::HLSLSimpleSemantic: - return (as(decl) && (isGlobalDecl(decl) || as(getParentDecl(decl)))) || as(decl) || as(decl); - - // Allowed only on functions - case ASTNodeType::IntrinsicOpModifier: - case ASTNodeType::SpecializedForTargetModifier: - case ASTNodeType::InlineModifier: - case ASTNodeType::PrefixModifier: - case ASTNodeType::PostfixModifier: - return as(decl); - - case ASTNodeType::BuiltinModifier: - case ASTNodeType::PublicModifier: - case ASTNodeType::PrivateModifier: - case ASTNodeType::InternalModifier: - case ASTNodeType::ExternModifier: - case ASTNodeType::HLSLExportModifier: - case ASTNodeType::ExternCppModifier: - return as(decl) || as(decl) || as(decl) || as(decl) - || as(decl) || as(decl) || as(decl) || as(decl) - || as(decl); - - case ASTNodeType::ExportedModifier: - return as(decl); - - case ASTNodeType::ConstModifier: - case ASTNodeType::HLSLStaticModifier: - case ASTNodeType::ConstExprModifier: - case ASTNodeType::PreciseModifier: - return as(decl) || as(decl); - - case ASTNodeType::ActualGlobalModifier: - case ASTNodeType::MatrixLayoutModifier: - case ASTNodeType::RowMajorLayoutModifier: - case ASTNodeType::HLSLRowMajorLayoutModifier: - case ASTNodeType::GLSLColumnMajorLayoutModifier: - case ASTNodeType::ColumnMajorLayoutModifier: - case ASTNodeType::HLSLColumnMajorLayoutModifier: - case ASTNodeType::GLSLRowMajorLayoutModifier: - case ASTNodeType::HLSLEffectSharedModifier: - return as(decl) || as(decl); - - case ASTNodeType::GLSLPrecisionModifier: - return as(decl) || as(decl) || as(decl); - case ASTNodeType::HLSLGroupSharedModifier: - // groupshared must be global or static. - if (!as(decl)) - return false; - return isGlobalDecl(decl) || isEffectivelyStatic(decl); - default: +bool isModifierAllowedOnDecl(bool isGLSLInput, ASTNodeType modifierType, Decl* decl) +{ + switch (modifierType) + { + // In addition to the above cases, these are also present on empty + // global declarations, for instance + // layout(local_size_x=1) in; + case ASTNodeType::InModifier: + case ASTNodeType::InOutModifier: + case ASTNodeType::OutModifier: + case ASTNodeType::GLSLLayoutModifier: + case ASTNodeType::GLSLParsedLayoutModifier: + case ASTNodeType::GLSLLocationLayoutModifier: + case ASTNodeType::GLSLInputAttachmentIndexLayoutAttribute: + case ASTNodeType::GLSLOffsetLayoutAttribute: + case ASTNodeType::GLSLUnparsedLayoutModifier: + case ASTNodeType::GLSLLayoutModifierGroupMarker: + case ASTNodeType::GLSLLayoutModifierGroupBegin: + case ASTNodeType::GLSLLayoutModifierGroupEnd: + // If we are in GLSL mode, also allow these but otherwise fall to + // the regular check + if (isGLSLInput && as(decl) && isGlobalDecl(decl)) return true; - } + [[fallthrough]]; + + case ASTNodeType::RefModifier: + case ASTNodeType::ConstRefModifier: + case ASTNodeType::GLSLBufferModifier: + case ASTNodeType::GLSLPatchModifier: + case ASTNodeType::RayPayloadAccessSemantic: + case ASTNodeType::RayPayloadReadSemantic: + case ASTNodeType::RayPayloadWriteSemantic: + return (as(decl) && isGlobalDecl(decl)) || as(decl) || + as(decl); + + case ASTNodeType::GLSLWriteOnlyModifier: + case ASTNodeType::GLSLReadOnlyModifier: + case ASTNodeType::GLSLVolatileModifier: + case ASTNodeType::GLSLRestrictModifier: + if (isGLSLInput) + return (as(decl) && (isGlobalDecl(decl)) || as(decl) || + as(decl)) || + as(getParentDecl(decl)) && isGlobalDecl(getParentDecl(decl)); + return ( + as(decl) && (isGlobalDecl(decl)) || as(decl) || + as(decl)); + + case ASTNodeType::GloballyCoherentModifier: + case ASTNodeType::HLSLVolatileModifier: + if (isGLSLInput) + return as(decl) && + (isGlobalDecl(decl) || as(getParentDecl(decl)) || + as(decl)) || + as(decl) && isGlobalDecl(decl) || as(decl) || + (as(getParentDecl(decl)) && isGlobalDecl(getParentDecl(decl))); + return as(decl) && (isGlobalDecl(decl) || as(getParentDecl(decl)) || + as(decl)); + + // Allowed only on parameters, struct fields and global variables. + case ASTNodeType::InterpolationModeModifier: + case ASTNodeType::HLSLNoInterpolationModifier: + case ASTNodeType::HLSLNoPerspectiveModifier: + case ASTNodeType::HLSLLinearModifier: + case ASTNodeType::HLSLSampleModifier: + case ASTNodeType::HLSLCentroidModifier: + case ASTNodeType::PerVertexModifier: + case ASTNodeType::HLSLUniformModifier: + case ASTNodeType::DynamicUniformModifier: + return (as(decl) && + (isGlobalDecl(decl) || as(getParentDecl(decl)))) || + as(decl); + + case ASTNodeType::HLSLSemantic: + case ASTNodeType::HLSLLayoutSemantic: + case ASTNodeType::HLSLRegisterSemantic: + case ASTNodeType::HLSLPackOffsetSemantic: + case ASTNodeType::HLSLSimpleSemantic: + return (as(decl) && + (isGlobalDecl(decl) || as(getParentDecl(decl)))) || + as(decl) || as(decl); + + // Allowed only on functions + case ASTNodeType::IntrinsicOpModifier: + case ASTNodeType::SpecializedForTargetModifier: + case ASTNodeType::InlineModifier: + case ASTNodeType::PrefixModifier: + case ASTNodeType::PostfixModifier: return as(decl); + + case ASTNodeType::BuiltinModifier: + case ASTNodeType::PublicModifier: + case ASTNodeType::PrivateModifier: + case ASTNodeType::InternalModifier: + case ASTNodeType::ExternModifier: + case ASTNodeType::HLSLExportModifier: + case ASTNodeType::ExternCppModifier: + return as(decl) || as(decl) || as(decl) || + as(decl) || as(decl) || as(decl) || + as(decl) || as(decl) || as(decl); + + case ASTNodeType::ExportedModifier: return as(decl); + + case ASTNodeType::ConstModifier: + case ASTNodeType::HLSLStaticModifier: + case ASTNodeType::ConstExprModifier: + case ASTNodeType::PreciseModifier: return as(decl) || as(decl); + + case ASTNodeType::ActualGlobalModifier: + case ASTNodeType::MatrixLayoutModifier: + case ASTNodeType::RowMajorLayoutModifier: + case ASTNodeType::HLSLRowMajorLayoutModifier: + case ASTNodeType::GLSLColumnMajorLayoutModifier: + case ASTNodeType::ColumnMajorLayoutModifier: + case ASTNodeType::HLSLColumnMajorLayoutModifier: + case ASTNodeType::GLSLRowMajorLayoutModifier: + case ASTNodeType::HLSLEffectSharedModifier: + return as(decl) || as(decl); + + case ASTNodeType::GLSLPrecisionModifier: + return as(decl) || as(decl) || as(decl); + case ASTNodeType::HLSLGroupSharedModifier: + // groupshared must be global or static. + if (!as(decl)) + return false; + return isGlobalDecl(decl) || isEffectivelyStatic(decl); + default: return true; } +} - Modifier* SemanticsVisitor::checkModifier( - Modifier* m, - ModifiableSyntaxNode* syntaxNode, - bool ignoreUnallowedModifier) +Modifier* SemanticsVisitor::checkModifier( + Modifier* m, + ModifiableSyntaxNode* syntaxNode, + bool ignoreUnallowedModifier) +{ + if (auto hlslUncheckedAttribute = as(m)) { - if(auto hlslUncheckedAttribute = as(m)) - { - // We have an HLSL `[name(arg,...)]` attribute, and we'd like - // to check that it is provides all the expected arguments - // - // First, look up the attribute name in the current scope to find - // the right syntax class to instantiate. - // + // We have an HLSL `[name(arg,...)]` attribute, and we'd like + // to check that it is provides all the expected arguments + // + // First, look up the attribute name in the current scope to find + // the right syntax class to instantiate. + // - auto checkedAttr = checkAttribute(hlslUncheckedAttribute, syntaxNode); + auto checkedAttr = checkAttribute(hlslUncheckedAttribute, syntaxNode); - if (as(checkedAttr)) - { - if (auto parentDecl = as(getParentDecl(as(syntaxNode)))) - parentDecl->invalidateMemberDictionary(); - return getASTBuilder()->create(); - } - return checkedAttr; + if (as(checkedAttr)) + { + if (auto parentDecl = as(getParentDecl(as(syntaxNode)))) + parentDecl->invalidateMemberDictionary(); + return getASTBuilder()->create(); } + return checkedAttr; + } - if (auto decl = as(syntaxNode)) + if (auto decl = as(syntaxNode)) + { + auto moduleDecl = getModuleDecl(decl); + bool isGLSLInput = getOptionSet().getBoolOption(CompilerOptionName::AllowGLSL); + if (!isGLSLInput && moduleDecl && moduleDecl->findModifier()) + isGLSLInput = true; + if (!isModifierAllowedOnDecl(isGLSLInput, m->astNodeType, decl)) { - auto moduleDecl = getModuleDecl(decl); - bool isGLSLInput = getOptionSet().getBoolOption(CompilerOptionName::AllowGLSL); - if (!isGLSLInput && moduleDecl && moduleDecl->findModifier()) - isGLSLInput = true; - if (!isModifierAllowedOnDecl(isGLSLInput, m->astNodeType, decl)) + if (!ignoreUnallowedModifier) { - if (!ignoreUnallowedModifier) - { - getSink()->diagnose(m, Diagnostics::modifierNotAllowed, m); - return nullptr; - } - return m; + getSink()->diagnose(m, Diagnostics::modifierNotAllowed, m); + return nullptr; } + return m; } + } - MemoryQualifierSetModifier::Flags::MemoryQualifiersBit memoryQualifierBit = MemoryQualifierSetModifier::Flags::kNone; - if(as(m)) - memoryQualifierBit = MemoryQualifierSetModifier::Flags::kCoherent; - else if(as(m)) - memoryQualifierBit = MemoryQualifierSetModifier::Flags::kReadOnly; - else if(as(m)) - memoryQualifierBit = MemoryQualifierSetModifier::Flags::kWriteOnly; - else if(as(m)) - memoryQualifierBit = MemoryQualifierSetModifier::Flags::kVolatile; - else if(as(m)) - memoryQualifierBit = MemoryQualifierSetModifier::Flags::kRestrict; - if(memoryQualifierBit != MemoryQualifierSetModifier::Flags::kNone) - { - bool newModifier = false; - MemoryQualifierSetModifier* memoryQualifiers = syntaxNode->findModifier(); - if(!memoryQualifiers) - { - newModifier = true; - memoryQualifiers = getASTBuilder()->create(); - } - memoryQualifiers->addQualifier(m, - memoryQualifierBit); - if (newModifier) - { - m->next = memoryQualifiers; - return memoryQualifiers; - } - return nullptr; + MemoryQualifierSetModifier::Flags::MemoryQualifiersBit memoryQualifierBit = + MemoryQualifierSetModifier::Flags::kNone; + if (as(m)) + memoryQualifierBit = MemoryQualifierSetModifier::Flags::kCoherent; + else if (as(m)) + memoryQualifierBit = MemoryQualifierSetModifier::Flags::kReadOnly; + else if (as(m)) + memoryQualifierBit = MemoryQualifierSetModifier::Flags::kWriteOnly; + else if (as(m)) + memoryQualifierBit = MemoryQualifierSetModifier::Flags::kVolatile; + else if (as(m)) + memoryQualifierBit = MemoryQualifierSetModifier::Flags::kRestrict; + if (memoryQualifierBit != MemoryQualifierSetModifier::Flags::kNone) + { + bool newModifier = false; + MemoryQualifierSetModifier* memoryQualifiers = + syntaxNode->findModifier(); + if (!memoryQualifiers) + { + newModifier = true; + memoryQualifiers = getASTBuilder()->create(); } - - if (auto hlslSemantic = as(m)) + memoryQualifiers->addQualifier(m, memoryQualifierBit); + if (newModifier) { - if (hlslSemantic->name.getName() == getSession()->getCompletionRequestTokenName()) - { - getLinkage()->contentAssistInfo.completionSuggestions.scopeKind = - CompletionSuggestions::ScopeKind::HLSLSemantics; - } + m->next = memoryQualifiers; + return memoryQualifiers; } + return nullptr; + } - if (const auto externModifier = as(m)) + if (auto hlslSemantic = as(m)) + { + if (hlslSemantic->name.getName() == getSession()->getCompletionRequestTokenName()) { - if (auto varDecl = as(syntaxNode)) - { - if (auto parentExtension = as(varDecl->parentDecl)) - { - auto originalMemberLookup = lookUpMember(m_astBuilder, this, varDecl->getName(), parentExtension->targetType, parentExtension->ownedScope); - LookupResult filteredResult; - for (auto item : originalMemberLookup.items) - { - if (item.declRef.getDecl() != varDecl) - AddToLookupResult(filteredResult, item); - } - if (filteredResult.isValid() && !filteredResult.isOverloaded()) - { - auto extensionExternMemberModifier = m_astBuilder->create(); - extensionExternMemberModifier->originalDecl = filteredResult.item.declRef; - return extensionExternMemberModifier; - } - else if (filteredResult.isOverloaded()) - { - getSink()->diagnose(varDecl, Diagnostics::ambiguousOriginalDefintionOfExternDecl, varDecl); - } - else - { - getSink()->diagnose(varDecl, Diagnostics::missingOriginalDefintionOfExternDecl, varDecl); - } - } - // The next part of the check is to make sure the type defined here is consistent with the original definition. - // Since we haven't checked the type of this decl yet, we defer that until we have fully checked decl. - // See SemanticsDeclHeaderVisitor::checkExtensionExternVarAttribute. - } + getLinkage()->contentAssistInfo.completionSuggestions.scopeKind = + CompletionSuggestions::ScopeKind::HLSLSemantics; } + } - if (auto packOffsetModifier = as(m)) - { - if (!packOffsetModifier->registerName.getContent().startsWith("c")) - { - getSink()->diagnose(packOffsetModifier, Diagnostics::unknownRegisterClass, packOffsetModifier->registerName); - return m; - } - auto uniformOffset = stringToInt(packOffsetModifier->registerName.getContent().tail(1)) * 16; - if (packOffsetModifier->componentMask.getContentLength()) - { - switch (packOffsetModifier->componentMask.getContent()[0]) + if (const auto externModifier = as(m)) + { + if (auto varDecl = as(syntaxNode)) + { + if (auto parentExtension = as(varDecl->parentDecl)) + { + auto originalMemberLookup = lookUpMember( + m_astBuilder, + this, + varDecl->getName(), + parentExtension->targetType, + parentExtension->ownedScope); + LookupResult filteredResult; + for (auto item : originalMemberLookup.items) { - case 'x': - uniformOffset += 0; - break; - case 'y': - uniformOffset += 4; - break; - case 'z': - uniformOffset += 8; - break; - case 'w': - uniformOffset += 12; - break; - default: - getSink()->diagnose(packOffsetModifier, Diagnostics::invalidComponentMask, packOffsetModifier->componentMask); - break; + if (item.declRef.getDecl() != varDecl) + AddToLookupResult(filteredResult, item); } - } - packOffsetModifier->uniformOffset = uniformOffset; - return packOffsetModifier; - } - - if(auto targetIntrinsic = as(m)) - { - // TODO: verify that the predicate is one we understand - if(targetIntrinsic->scrutinee.name) - { - if(auto genDecl = as(syntaxNode)) + if (filteredResult.isValid() && !filteredResult.isOverloaded()) { - auto scrutineeResults = lookUp( - m_astBuilder, - this, - targetIntrinsic->scrutinee.name, - genDecl->ownedScope); - if(!scrutineeResults.isValid()) - { - getSink()->diagnose( - targetIntrinsic->scrutinee.loc, - Diagnostics::undefinedIdentifier2, - targetIntrinsic->scrutinee.name); - } - if(scrutineeResults.isOverloaded()) - { - getSink()->diagnose( - targetIntrinsic->scrutinee.loc, - Diagnostics::ambiguousReference, - targetIntrinsic->scrutinee.name); - } - targetIntrinsic->scrutineeDeclRef = scrutineeResults.item.declRef; + auto extensionExternMemberModifier = + m_astBuilder->create(); + extensionExternMemberModifier->originalDecl = filteredResult.item.declRef; + return extensionExternMemberModifier; } - } - } - - if (as(m)) - { - if (auto decl = as(syntaxNode)) - { - if (isGlobalDecl(decl)) + else if (filteredResult.isOverloaded()) { - getSink()->diagnose(m, Diagnostics::invalidUseOfPrivateVisibility, as(syntaxNode)); - return m; + getSink()->diagnose( + varDecl, + Diagnostics::ambiguousOriginalDefintionOfExternDecl, + varDecl); } - } - if (as(syntaxNode)) - { - getSink()->diagnose(m, Diagnostics::invalidVisibilityModifierOnTypeOfDecl, syntaxNode->astNodeType); - return m; - } - else if (auto decl = as(syntaxNode)) - { - // Interface requirements can't be private. - if (isInterfaceRequirement(decl)) + else { - getSink()->diagnose(m, Diagnostics::invalidUseOfPrivateVisibility, as(syntaxNode)); + getSink()->diagnose( + varDecl, + Diagnostics::missingOriginalDefintionOfExternDecl, + varDecl); } } + // The next part of the check is to make sure the type defined here is consistent with + // the original definition. Since we haven't checked the type of this decl yet, we defer + // that until we have fully checked decl. See + // SemanticsDeclHeaderVisitor::checkExtensionExternVarAttribute. } - else if (as(m)) - { - if (as(syntaxNode)) - { - getSink()->diagnose(m, Diagnostics::invalidVisibilityModifierOnTypeOfDecl, syntaxNode->astNodeType); - return m; + } + + if (auto packOffsetModifier = as(m)) + { + if (!packOffsetModifier->registerName.getContent().startsWith("c")) + { + getSink()->diagnose( + packOffsetModifier, + Diagnostics::unknownRegisterClass, + packOffsetModifier->registerName); + return m; + } + auto uniformOffset = + stringToInt(packOffsetModifier->registerName.getContent().tail(1)) * 16; + if (packOffsetModifier->componentMask.getContentLength()) + { + switch (packOffsetModifier->componentMask.getContent()[0]) + { + case 'x': uniformOffset += 0; break; + case 'y': uniformOffset += 4; break; + case 'z': uniformOffset += 8; break; + case 'w': uniformOffset += 12; break; + default: + getSink()->diagnose( + packOffsetModifier, + Diagnostics::invalidComponentMask, + packOffsetModifier->componentMask); + break; } } + packOffsetModifier->uniformOffset = uniformOffset; + return packOffsetModifier; + } - if (auto attr = as(m)) + if (auto targetIntrinsic = as(m)) + { + // TODO: verify that the predicate is one we understand + if (targetIntrinsic->scrutinee.name) { - SLANG_ASSERT(attr->args.getCount() == 3); - - IntVal* values[3]; - - for (int i = 0; i < 3; ++i) + if (auto genDecl = as(syntaxNode)) { - IntVal* value = nullptr; - - auto arg = attr->args[i]; - if (arg) + auto scrutineeResults = lookUp( + m_astBuilder, + this, + targetIntrinsic->scrutinee.name, + genDecl->ownedScope); + if (!scrutineeResults.isValid()) { - auto intValue = checkConstantIntVal(arg); - if (!intValue) - { - return nullptr; - } - if (auto cintVal = as(intValue)) - { - if (cintVal->getValue() < 1) - { - getSink()->diagnose(attr, Diagnostics::nonPositiveNumThreads, cintVal->getValue()); - return nullptr; - } - } - value = intValue; + getSink()->diagnose( + targetIntrinsic->scrutinee.loc, + Diagnostics::undefinedIdentifier2, + targetIntrinsic->scrutinee.name); } - else + if (scrutineeResults.isOverloaded()) { - value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); + getSink()->diagnose( + targetIntrinsic->scrutinee.loc, + Diagnostics::ambiguousReference, + targetIntrinsic->scrutinee.name); } - values[i] = value; + targetIntrinsic->scrutineeDeclRef = scrutineeResults.item.declRef; } - - attr->x = values[0]; - attr->y = values[1]; - attr->z = values[2]; } - - // Default behavior is to leave things as they are, - // and assume that modifiers are mostly already checked. - // - // TODO: This would be a good place to validate that - // a modifier is actually valid for the thing it is - // being applied to, and potentially to check that - // it isn't in conflict with any other modifiers - // on the same declaration. - - return m; } - void SemanticsVisitor::checkVisibility(Decl* decl) + if (as(m)) { - if (as(decl)) - { - return; - } - ShortList typesToVerify; - if (auto varDecl = as(decl)) - { - typesToVerify.add(varDecl->type); - } - else if (auto callable = as(decl)) + if (auto decl = as(syntaxNode)) { - typesToVerify.add(callable->returnType); - typesToVerify.add(callable->errorType); - for (auto param : callable->getParameters()) + if (isGlobalDecl(decl)) { - typesToVerify.add(param->type); + getSink()->diagnose( + m, + Diagnostics::invalidUseOfPrivateVisibility, + as(syntaxNode)); + return m; } } - else if (auto propertyDecl = as(decl)) - { - typesToVerify.add(propertyDecl->type); - } - else if (as(decl)) + if (as(syntaxNode)) { + getSink()->diagnose( + m, + Diagnostics::invalidVisibilityModifierOnTypeOfDecl, + syntaxNode->astNodeType); + return m; } - else if (auto typeDecl = as(decl)) + else if (auto decl = as(syntaxNode)) { - typesToVerify.add(typeDecl->type); + // Interface requirements can't be private. + if (isInterfaceRequirement(decl)) + { + getSink()->diagnose( + m, + Diagnostics::invalidUseOfPrivateVisibility, + as(syntaxNode)); + } } - else + } + else if (as(m)) + { + if (as(syntaxNode)) { - return; + getSink()->diagnose( + m, + Diagnostics::invalidVisibilityModifierOnTypeOfDecl, + syntaxNode->astNodeType); + return m; } - auto thisVisibility = getDeclVisibility(decl); + } + + if (auto attr = as(m)) + { + SLANG_ASSERT(attr->args.getCount() == 3); + + IntVal* values[3]; - // First, we check that the decl's type does not have lower visibility. - for (auto type : typesToVerify) + for (int i = 0; i < 3; ++i) { - if (!type) - continue; - DeclVisibility typeVisibility = getTypeVisibility(type); - if (typeVisibility < thisVisibility) + IntVal* value = nullptr; + + auto arg = attr->args[i]; + if (arg) { - getSink()->diagnose(decl, Diagnostics::useOfLessVisibleType, decl, type); - break; + auto intValue = checkConstantIntVal(arg); + if (!intValue) + { + return nullptr; + } + if (auto cintVal = as(intValue)) + { + if (cintVal->getValue() < 1) + { + getSink()->diagnose( + attr, + Diagnostics::nonPositiveNumThreads, + cintVal->getValue()); + return nullptr; + } + } + value = intValue; + } + else + { + value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); } + values[i] = value; } - // Next, we check that the decl does not have higher visiblity than its parent. - Decl* parentDecl = decl; - for (; parentDecl; parentDecl = parentDecl->parentDecl) + attr->x = values[0]; + attr->y = values[1]; + attr->z = values[2]; + } + + // Default behavior is to leave things as they are, + // and assume that modifiers are mostly already checked. + // + // TODO: This would be a good place to validate that + // a modifier is actually valid for the thing it is + // being applied to, and potentially to check that + // it isn't in conflict with any other modifiers + // on the same declaration. + + return m; +} + +void SemanticsVisitor::checkVisibility(Decl* decl) +{ + if (as(decl)) + { + return; + } + ShortList typesToVerify; + if (auto varDecl = as(decl)) + { + typesToVerify.add(varDecl->type); + } + else if (auto callable = as(decl)) + { + typesToVerify.add(callable->returnType); + typesToVerify.add(callable->errorType); + for (auto param : callable->getParameters()) { - if (as(parentDecl)) - break; + typesToVerify.add(param->type); } - if (!parentDecl) - return; - auto parentVisibility = getDeclVisibility(parentDecl); - if (thisVisibility > parentVisibility) + } + else if (auto propertyDecl = as(decl)) + { + typesToVerify.add(propertyDecl->type); + } + else if (as(decl)) + { + } + else if (auto typeDecl = as(decl)) + { + typesToVerify.add(typeDecl->type); + } + else + { + return; + } + auto thisVisibility = getDeclVisibility(decl); + + // First, we check that the decl's type does not have lower visibility. + for (auto type : typesToVerify) + { + if (!type) + continue; + DeclVisibility typeVisibility = getTypeVisibility(type); + if (typeVisibility < thisVisibility) { - getSink()->diagnose(decl, Diagnostics::declCannotHaveHigherVisibility, decl, parentDecl); + getSink()->diagnose(decl, Diagnostics::useOfLessVisibleType, decl, type); + break; } } - void postProcessingOnModifiers(Modifiers& modifiers) + // Next, we check that the decl does not have higher visiblity than its parent. + Decl* parentDecl = decl; + for (; parentDecl; parentDecl = parentDecl->parentDecl) { - // compress all `require` nodes into 1 `require` modifier - RequireCapabilityAttribute* firstRequire = nullptr; - Modifier* previous = nullptr; - Modifier* next = nullptr; - for (auto m = modifiers.first; m != nullptr; m = next) - { - next = m->next; - // + if (as(parentDecl)) + break; + } + if (!parentDecl) + return; + auto parentVisibility = getDeclVisibility(parentDecl); + if (thisVisibility > parentVisibility) + { + getSink()->diagnose(decl, Diagnostics::declCannotHaveHigherVisibility, decl, parentDecl); + } +} + +void postProcessingOnModifiers(Modifiers& modifiers) +{ + // compress all `require` nodes into 1 `require` modifier + RequireCapabilityAttribute* firstRequire = nullptr; + Modifier* previous = nullptr; + Modifier* next = nullptr; + for (auto m = modifiers.first; m != nullptr; m = next) + { + next = m->next; + // - if (auto req = as(m)) + if (auto req = as(m)) + { + if (!firstRequire) { - if (!firstRequire) - { - firstRequire = req; - previous = m; - continue; - } - firstRequire->capabilitySet.unionWith(req->capabilitySet); - if(previous) - previous->next = next; + firstRequire = req; + previous = m; continue; } - - // - previous = m; + firstRequire->capabilitySet.unionWith(req->capabilitySet); + if (previous) + previous->next = next; + continue; } + + // + previous = m; } +} - void SemanticsVisitor::checkModifiers(ModifiableSyntaxNode* syntaxNode) - { - // TODO(tfoley): need to make sure this only - // performs semantic checks on a `SharedModifier` once... +void SemanticsVisitor::checkModifiers(ModifiableSyntaxNode* syntaxNode) +{ + // TODO(tfoley): need to make sure this only + // performs semantic checks on a `SharedModifier` once... - // The process of checking a modifier may produce a new modifier in its place, - // so we will build up a new linked list of modifiers that will replace - // the old list. - Modifier* resultModifiers = nullptr; - Modifier** resultModifierLink = &resultModifiers; + // The process of checking a modifier may produce a new modifier in its place, + // so we will build up a new linked list of modifiers that will replace + // the old list. + Modifier* resultModifiers = nullptr; + Modifier** resultModifierLink = &resultModifiers; - // We will keep track of the modifiers for each conflict group. - Dictionary mapExclusiveGroupToModifier; + // We will keep track of the modifiers for each conflict group. + Dictionary mapExclusiveGroupToModifier; - Modifier* modifier = syntaxNode->modifiers.first; - bool ignoreUnallowedModifier = false; - while (modifier) + Modifier* modifier = syntaxNode->modifiers.first; + bool ignoreUnallowedModifier = false; + while (modifier) + { + // Check if a modifier belonging to the same conflict group is already + // defined. + Modifier* existingModifier = nullptr; + auto conflictGroup = getModifierConflictGroupKind(modifier->astNodeType); + if (conflictGroup != ASTNodeType::NodeBase) { - // Check if a modifier belonging to the same conflict group is already - // defined. - Modifier* existingModifier = nullptr; - auto conflictGroup = getModifierConflictGroupKind(modifier->astNodeType); - if (conflictGroup != ASTNodeType::NodeBase) + if (mapExclusiveGroupToModifier.tryGetValue(conflictGroup, existingModifier)) { - if (mapExclusiveGroupToModifier.tryGetValue(conflictGroup, existingModifier)) - { - getSink()->diagnose(modifier->loc, Diagnostics::duplicateModifier, modifier, existingModifier); - } - mapExclusiveGroupToModifier[conflictGroup] = modifier; + getSink()->diagnose( + modifier->loc, + Diagnostics::duplicateModifier, + modifier, + existingModifier); } + mapExclusiveGroupToModifier[conflictGroup] = modifier; + } - // Because we are rewriting the list in place, we need to extract - // the next modifier here (not at the end of the loop). - auto next = modifier->next; + // Because we are rewriting the list in place, we need to extract + // the next modifier here (not at the end of the loop). + auto next = modifier->next; - // We also go ahead and clobber the `next` field on the modifier - // itself, so that the default behavior of `checkModifier()` can - // be to return a single unlinked modifier. - modifier->next = nullptr; + // We also go ahead and clobber the `next` field on the modifier + // itself, so that the default behavior of `checkModifier()` can + // be to return a single unlinked modifier. + modifier->next = nullptr; - // For any modifiers appears after "SharedModifiers", we will not diagnose - // an error if the modifier is not allowed on the declaration. - if (as(modifier)) - ignoreUnallowedModifier = true; - - // may return a list of modifiers - auto checkedModifier = checkModifier(modifier, syntaxNode, ignoreUnallowedModifier); + // For any modifiers appears after "SharedModifiers", we will not diagnose + // an error if the modifier is not allowed on the declaration. + if (as(modifier)) + ignoreUnallowedModifier = true; - if(checkedModifier) - { - // If checking gave us a modifier to add, then we - // had better add it. - - // Just in case `checkModifier` ever returns multiple - // modifiers, lets advance to the end of the list we - // are building. - while(*resultModifierLink) - resultModifierLink = &(*resultModifierLink)->next; - - // attach the new modifier at the end of the list, - // and now set the "link" to point to its `next` field - *resultModifierLink = checkedModifier; - resultModifierLink = &checkedModifier->next; - } + // may return a list of modifiers + auto checkedModifier = checkModifier(modifier, syntaxNode, ignoreUnallowedModifier); - // Move along to the next modifier - modifier = next; - } + if (checkedModifier) + { + // If checking gave us a modifier to add, then we + // had better add it. - // Whether we actually re-wrote anything or note, lets - // install the new list of modifiers on the declaration - syntaxNode->modifiers.first = resultModifiers; + // Just in case `checkModifier` ever returns multiple + // modifiers, lets advance to the end of the list we + // are building. + while (*resultModifierLink) + resultModifierLink = &(*resultModifierLink)->next; - postProcessingOnModifiers(syntaxNode->modifiers); - } + // attach the new modifier at the end of the list, + // and now set the "link" to point to its `next` field + *resultModifierLink = checkedModifier; + resultModifierLink = &checkedModifier->next; + } + // Move along to the next modifier + modifier = next; + } + // Whether we actually re-wrote anything or note, lets + // install the new list of modifiers on the declaration + syntaxNode->modifiers.first = resultModifiers; + postProcessingOnModifiers(syntaxNode->modifiers); } + + +} // namespace Slang diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 0e0d64f67..609fa8635 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1,9 +1,8 @@ // slang-check-overload.cpp #include "slang-ast-base.h" +#include "slang-ast-print.h" #include "slang-check-impl.h" - #include "slang-lookup.h" -#include "slang-ast-print.h" // This file implements semantic checking logic related // to resolving overloading call operations, by checking @@ -11,1692 +10,1718 @@ namespace Slang { - SemanticsVisitor::ParamCounts SemanticsVisitor::CountParameters(FilteredMemberRefList params) +SemanticsVisitor::ParamCounts SemanticsVisitor::CountParameters( + FilteredMemberRefList params) +{ + ParamCounts counts = {0, 0}; + for (auto param : params) { - ParamCounts counts = { 0, 0 }; - for (auto param : params) + Index allowedArgCountToAdd = 1; + auto paramType = getParamType(m_astBuilder, param); + if (isTypePack(paramType)) { - Index allowedArgCountToAdd = 1; - auto paramType = getParamType(m_astBuilder, param); - if (isTypePack(paramType)) + if (auto typePack = as(paramType)) { - if (auto typePack = as(paramType)) - { - counts.required += typePack->getTypeCount(); - allowedArgCountToAdd = typePack->getTypeCount(); - } - else - { - counts.allowed = -1; - } + counts.required += typePack->getTypeCount(); + allowedArgCountToAdd = typePack->getTypeCount(); } - else if (!param.getDecl()->initExpr) + else { - // No initializer means no default value - // - // TODO(tfoley): The logic here is currently broken in two ways: - // - // 1. We are assuming that once one parameter has a default, then all do. - // This can/should be validated earlier, so that we can assume it here. - // - // 2. We are not handling the possibility of multiple declarations for - // a single function, where we'd need to merge default parameters across - // all the declarations. - counts.required++; + counts.allowed = -1; } - - if (counts.allowed >= 0) - counts.allowed += allowedArgCountToAdd; } - return counts; + else if (!param.getDecl()->initExpr) + { + // No initializer means no default value + // + // TODO(tfoley): The logic here is currently broken in two ways: + // + // 1. We are assuming that once one parameter has a default, then all do. + // This can/should be validated earlier, so that we can assume it here. + // + // 2. We are not handling the possibility of multiple declarations for + // a single function, where we'd need to merge default parameters across + // all the declarations. + counts.required++; + } + + if (counts.allowed >= 0) + counts.allowed += allowedArgCountToAdd; } + return counts; +} - SemanticsVisitor::ParamCounts SemanticsVisitor::CountParameters(DeclRef genericRef) +SemanticsVisitor::ParamCounts SemanticsVisitor::CountParameters(DeclRef genericRef) +{ + ParamCounts counts = {0, 0}; + for (auto m : genericRef.getDecl()->members) { - ParamCounts counts = { 0, 0 }; - for (auto m : genericRef.getDecl()->members) + if (auto typeParam = as(m)) { - if (auto typeParam = as(m)) - { - if (counts.allowed >= 0) - counts.allowed++; - if (!typeParam->initType.Ptr()) - { - counts.required++; - } - } - else if (auto valParam = as(m)) + if (counts.allowed >= 0) + counts.allowed++; + if (!typeParam->initType.Ptr()) { - if (counts.allowed >= 0) - counts.allowed++; - if (!valParam->initExpr) - { - counts.required++; - } + counts.required++; } - else if (as(m)) + } + else if (auto valParam = as(m)) + { + if (counts.allowed >= 0) + counts.allowed++; + if (!valParam->initExpr) { - counts.allowed = -1; + counts.required++; } } - return counts; + else if (as(m)) + { + counts.allowed = -1; + } } + return counts; +} - bool SemanticsVisitor::TryCheckOverloadCandidateClassNewMatchUp(OverloadResolveContext& context, OverloadCandidate const& candidate) +bool SemanticsVisitor::TryCheckOverloadCandidateClassNewMatchUp( + OverloadResolveContext& context, + OverloadCandidate const& candidate) +{ + // Check that a constructor call to a class type must be in a `new` expr, and a `new` expr + // is only used to construct a class. + bool isClassType = false; + bool isNewExpr = false; + if (auto ctorDeclRef = candidate.item.declRef.as()) { - // Check that a constructor call to a class type must be in a `new` expr, and a `new` expr - // is only used to construct a class. - bool isClassType = false; - bool isNewExpr = false; - if (auto ctorDeclRef = candidate.item.declRef.as()) + if (auto resultType = as(candidate.resultType)) { - if (auto resultType = as(candidate.resultType)) + if (resultType->getDeclRef().as()) { - if (resultType->getDeclRef().as()) - { - isClassType = true; - } + isClassType = true; } } - if (as(context.originalExpr)) + } + if (as(context.originalExpr)) + { + isNewExpr = true; + } + + if (isNewExpr && !isClassType) + { + getSink()->diagnose(context.originalExpr, Diagnostics::newCanOnlyBeUsedToInitializeAClass); + return false; + } + if (!isNewExpr && isClassType && context.originalExpr) + { + getSink()->diagnose(context.originalExpr, Diagnostics::classCanOnlyBeInitializedWithNew); + return false; + } + return true; +} + +bool SemanticsVisitor::TryCheckOverloadCandidateArity( + OverloadResolveContext& context, + OverloadCandidate const& candidate) +{ + Count argCount = context.getArgCount(); + ParamCounts paramCounts = {0, 0}; + switch (candidate.flavor) + { + case OverloadCandidate::Flavor::Func: + paramCounts = + CountParameters(getParameters(m_astBuilder, candidate.item.declRef.as())); + break; + + case OverloadCandidate::Flavor::Generic: + paramCounts = CountParameters(candidate.item.declRef.as()); + + // A generic can be applied to any number of arguments less + // than or equal to the number of explicitly declared parameters. + // When a program provides fewer arguments than their are parameters, + // the rest will be inferred. + // + paramCounts.required = 0; + break; + + case OverloadCandidate::Flavor::Expr: { - isNewExpr = true; + auto paramCount = candidate.funcType->getParamCount(); + paramCounts.allowed = paramCount; + paramCounts.required = paramCount; } + break; + + default: SLANG_UNEXPECTED("unknown flavor of overload candidate"); break; + } + + if (argCount >= paramCounts.required && + (paramCounts.allowed == -1 || argCount <= paramCounts.allowed)) + return true; - if (isNewExpr && !isClassType) + // Emit an error message if we are checking this call for real + if (context.mode != OverloadResolveContext::Mode::JustTrying) + { + if (argCount < paramCounts.required) { - getSink()->diagnose(context.originalExpr, Diagnostics::newCanOnlyBeUsedToInitializeAClass); - return false; + getSink()->diagnose( + context.loc, + Diagnostics::notEnoughArguments, + argCount, + paramCounts.required); } - if (!isNewExpr && isClassType && context.originalExpr) + else { - getSink()->diagnose(context.originalExpr, Diagnostics::classCanOnlyBeInitializedWithNew); - return false; + SLANG_ASSERT(argCount > paramCounts.allowed); + getSink()->diagnose( + context.loc, + Diagnostics::tooManyArguments, + argCount, + paramCounts.allowed); } - return true; } - bool SemanticsVisitor::TryCheckOverloadCandidateArity( - OverloadResolveContext& context, - OverloadCandidate const& candidate) - { - Count argCount = context.getArgCount(); - ParamCounts paramCounts = { 0, 0 }; - switch (candidate.flavor) - { - case OverloadCandidate::Flavor::Func: - paramCounts = CountParameters(getParameters(m_astBuilder, candidate.item.declRef.as())); - break; + return false; +} - case OverloadCandidate::Flavor::Generic: - paramCounts = CountParameters(candidate.item.declRef.as()); +bool SemanticsVisitor::TryCheckOverloadCandidateFixity( + OverloadResolveContext& context, + OverloadCandidate const& candidate) +{ + auto expr = context.originalExpr; - // A generic can be applied to any number of arguments less - // than or equal to the number of explicitly declared parameters. - // When a program provides fewer arguments than their are parameters, - // the rest will be inferred. - // - paramCounts.required = 0; - break; + auto decl = candidate.item.declRef.getDecl(); - case OverloadCandidate::Flavor::Expr: - { - auto paramCount = candidate.funcType->getParamCount(); - paramCounts.allowed = paramCount; - paramCounts.required = paramCount; - } - break; + if (const auto prefixExpr = as(expr)) + { + if (decl->hasModifier()) + return true; - default: - SLANG_UNEXPECTED("unknown flavor of overload candidate"); - break; + if (context.mode != OverloadResolveContext::Mode::JustTrying) + { + getSink()->diagnose(context.loc, Diagnostics::expectedPrefixOperator); + getSink()->diagnose(decl, Diagnostics::seeDefinitionOf, decl->getName()); } - if (argCount >= paramCounts.required && (paramCounts.allowed == -1 || argCount <= paramCounts.allowed)) + return false; + } + else if (const auto postfixExpr = as(expr)) + { + if (decl->hasModifier()) return true; - // Emit an error message if we are checking this call for real if (context.mode != OverloadResolveContext::Mode::JustTrying) { - if (argCount < paramCounts.required) - { - getSink()->diagnose(context.loc, Diagnostics::notEnoughArguments, argCount, paramCounts.required); - } - else - { - SLANG_ASSERT(argCount > paramCounts.allowed); - getSink()->diagnose(context.loc, Diagnostics::tooManyArguments, argCount, paramCounts.allowed); - } + getSink()->diagnose(context.loc, Diagnostics::expectedPostfixOperator); + getSink()->diagnose(decl, Diagnostics::seeDefinitionOf, decl->getName()); } return false; } + else + { + return true; + } +} - bool SemanticsVisitor::TryCheckOverloadCandidateFixity( - OverloadResolveContext& context, - OverloadCandidate const& candidate) +bool SemanticsVisitor::TryCheckOverloadCandidateVisibility( + OverloadResolveContext& context, + OverloadCandidate const& candidate) +{ + // Always succeeds when we are trying out constructors. + if (context.mode == OverloadResolveContext::Mode::JustTrying) { - auto expr = context.originalExpr; + if (as(candidate.item.declRef)) + return true; + } - auto decl = candidate.item.declRef.getDecl(); + if (!context.sourceScope) + return true; + + if (!candidate.item.declRef) + return true; - if(const auto prefixExpr = as(expr)) + if (!isDeclVisibleFromScope(candidate.item.declRef, context.sourceScope)) + { + if (context.mode == OverloadResolveContext::Mode::ForReal) { - if(decl->hasModifier()) - return true; + getSink()->diagnose(context.loc, Diagnostics::declIsNotVisible, candidate.item.declRef); + } + return false; + } - if (context.mode != OverloadResolveContext::Mode::JustTrying) - { - getSink()->diagnose(context.loc, Diagnostics::expectedPrefixOperator); - getSink()->diagnose(decl, Diagnostics::seeDefinitionOf, decl->getName()); - } + return true; +} - return false; - } - else if(const auto postfixExpr = as(expr)) - { - if(decl->hasModifier()) - return true; +bool SemanticsVisitor::TryCheckGenericOverloadCandidateTypes( + OverloadResolveContext& context, + OverloadCandidate& candidate) +{ + auto genericDeclRef = candidate.item.declRef.as(); - if (context.mode != OverloadResolveContext::Mode::JustTrying) - { - getSink()->diagnose(context.loc, Diagnostics::expectedPostfixOperator); - getSink()->diagnose(decl, Diagnostics::seeDefinitionOf, decl->getName()); - } + // Only allow constructing a PartialGenericAppExpr when referencing a callable decl. + // Other types of generic decls must be fully specified. + bool allowPartialGenericApp = false; + if (as(genericDeclRef.getDecl()->inner)) + { + allowPartialGenericApp = true; + } - return false; - } - else + // The basic idea here is that we need to check that the + // arguments to a generic application (e.g., `F`) + // have the right "type," which in this context means + // checking that: + // + // * The argument for any generic type parameter is a (proper) type. + // + // * The argument for any generic value parameter is a + // specialization-time constant value of the appropriate type. + // + // Some additional checks are *not* handled at this point: + // + // * We don't check that a type argument actually conforms to + // the constraints on the parameter. + // + // Along the way we will build up a `GenericSubstitution` + // to represent the arguments that have been coerced to + // appropriate forms. + // + List checkedArgs; + + // Rather than bail out as soon as we hit a problem, + // we are going to process *all* of the parameters of the + // generic and place suitable arguments into the `checkedArgs` + // array. This is important so that we don't cause crashes + // in cases where the arguments fail this step of checking, + // but we decide to proceed with subsequent steps (e.g., + // because the candidate we are trying here is the *only* + // candidate). + // + bool success = true; + + auto maybeReportGeneralError = [&]() + { + if (context.mode != OverloadResolveContext::Mode::JustTrying) { - return true; + getSink()->diagnose( + context.loc, + Diagnostics::cannotSpecializeGeneric, + candidate.item.declRef); } - } - - bool SemanticsVisitor::TryCheckOverloadCandidateVisibility(OverloadResolveContext& context, OverloadCandidate const& candidate) + }; + List paramTypes; + for (auto memberRef : getMembers(m_astBuilder, genericDeclRef)) { - // Always succeeds when we are trying out constructors. - if (context.mode == OverloadResolveContext::Mode::JustTrying) + if (auto typeParamRef = memberRef.as()) { - if (as(candidate.item.declRef)) - return true; + paramTypes.add(DeclRefType::create(m_astBuilder, typeParamRef)); } - - if (!context.sourceScope) - return true; - - if (!candidate.item.declRef) - return true; - - if (!isDeclVisibleFromScope(candidate.item.declRef, context.sourceScope)) + else if (auto valParamRef = memberRef.as()) { - if (context.mode == OverloadResolveContext::Mode::ForReal) - { - getSink()->diagnose(context.loc, Diagnostics::declIsNotVisible, candidate.item.declRef); - } - return false; + paramTypes.add(getType(m_astBuilder, valParamRef)); } - - return true; - } - - bool SemanticsVisitor::TryCheckGenericOverloadCandidateTypes( - OverloadResolveContext& context, - OverloadCandidate& candidate) - { - auto genericDeclRef = candidate.item.declRef.as(); - - // Only allow constructing a PartialGenericAppExpr when referencing a callable decl. - // Other types of generic decls must be fully specified. - bool allowPartialGenericApp = false; - if (as(genericDeclRef.getDecl()->inner)) + else if (auto typePackParam = memberRef.as()) { - allowPartialGenericApp = true; + paramTypes.add(DeclRefType::create(m_astBuilder, typePackParam)); } + } + ShortList matchedArgs; + if (!context.matchArgumentsToParams(this, paramTypes, false, matchedArgs)) + { + maybeReportGeneralError(); + return false; + } - // The basic idea here is that we need to check that the - // arguments to a generic application (e.g., `F`) - // have the right "type," which in this context means - // checking that: - // - // * The argument for any generic type parameter is a (proper) type. - // - // * The argument for any generic value parameter is a - // specialization-time constant value of the appropriate type. - // - // Some additional checks are *not* handled at this point: - // - // * We don't check that a type argument actually conforms to - // the constraints on the parameter. - // - // Along the way we will build up a `GenericSubstitution` - // to represent the arguments that have been coerced to - // appropriate forms. - // - List checkedArgs; - - // Rather than bail out as soon as we hit a problem, - // we are going to process *all* of the parameters of the - // generic and place suitable arguments into the `checkedArgs` - // array. This is important so that we don't cause crashes - // in cases where the arguments fail this step of checking, - // but we decide to proceed with subsequent steps (e.g., - // because the candidate we are trying here is the *only* - // candidate). - // - bool success = true; - - auto maybeReportGeneralError = [&]() + Index aa = 0; + for (auto memberRef : getMembers(m_astBuilder, genericDeclRef)) + { + if (auto typeParamRef = memberRef.as()) { - if (context.mode != OverloadResolveContext::Mode::JustTrying) + if (aa >= matchedArgs.getCount()) { - getSink()->diagnose(context.loc, Diagnostics::cannotSpecializeGeneric, candidate.item.declRef); + if (allowPartialGenericApp) + { + // If we have run out of arguments, and the referenced decl + // allows partially applied specialization (i.e. a callable + // decl) then we don't apply any more checks at this step. + // We will instead attempt to *infer* an argument at this + // position at a later stage. + // + candidate.flags |= OverloadCandidate::Flag::IsPartiallyAppliedGeneric; + break; + } + else + { + // Otherwise, the generic decl had better provide a default value + // or this reference is ill-formed. + auto substType = typeParamRef.substitute( + m_astBuilder, + typeParamRef.getDecl()->initType.type); + if (!substType) + { + maybeReportGeneralError(); + return false; + } + checkedArgs.add(substType); + continue; + } } - }; - List paramTypes; - for (auto memberRef : getMembers(m_astBuilder, genericDeclRef)) - { - if (auto typeParamRef = memberRef.as()) + + // We have a type parameter, and we expect to find + // a type argument. + // + TypeExp typeArg; + + // Per the earlier check, we have at least one + // argument left, so we will grab + // it and try to coerce it to a proper type. The + // manner in which we handle the coercion depends + // on whether we are "just trying" the candidate + // (so a failure would rule out the candidate, but + // shouldn't be reported to the user), or are doing + // the checking "for real" in which case any errors + // we run into need to be reported. + // + auto arg = matchedArgs[aa++]; + if (context.mode == OverloadResolveContext::Mode::JustTrying) { - paramTypes.add(DeclRefType::create(m_astBuilder, typeParamRef)); + typeArg = tryCoerceToProperType(TypeExp(arg.argExpr)); } - else if (auto valParamRef = memberRef.as()) + else { - paramTypes.add(getType(m_astBuilder, valParamRef)); + arg.argExpr = ExpectATypeRepr(arg.argExpr); + typeArg = CoerceToProperType(TypeExp(arg.argExpr)); } - else if (auto typePackParam = memberRef.as()) + + // If we failed to get a valid type (either because + // there was no matching argument, or because the + // "just trying" coercion failed), then we create + // an error type to stand in for the argument + // + if (!typeArg.type) { - paramTypes.add(DeclRefType::create(m_astBuilder, typePackParam)); + typeArg.type = m_astBuilder->getErrorType(); + success = false; } - } - ShortList matchedArgs; - if (!context.matchArgumentsToParams(this, paramTypes, false, matchedArgs)) - { - maybeReportGeneralError(); - return false; - } - Index aa = 0; - for (auto memberRef : getMembers(m_astBuilder, genericDeclRef)) + checkedArgs.add(typeArg.type); + } + else if (auto valParamRef = memberRef.as()) { - if (auto typeParamRef = memberRef.as()) + if (aa >= matchedArgs.getCount()) { - if (aa >= matchedArgs.getCount()) - { - if (allowPartialGenericApp) - { - // If we have run out of arguments, and the referenced decl - // allows partially applied specialization (i.e. a callable - // decl) then we don't apply any more checks at this step. - // We will instead attempt to *infer* an argument at this - // position at a later stage. - // - candidate.flags |= OverloadCandidate::Flag::IsPartiallyAppliedGeneric; - break; - } - else - { - // Otherwise, the generic decl had better provide a default value - // or this reference is ill-formed. - auto substType = typeParamRef.substitute(m_astBuilder, typeParamRef.getDecl()->initType.type); - if (!substType) - { - maybeReportGeneralError(); - return false; - } - checkedArgs.add(substType); - continue; - } - } - - // We have a type parameter, and we expect to find - // a type argument. - // - TypeExp typeArg; - - // Per the earlier check, we have at least one - // argument left, so we will grab - // it and try to coerce it to a proper type. The - // manner in which we handle the coercion depends - // on whether we are "just trying" the candidate - // (so a failure would rule out the candidate, but - // shouldn't be reported to the user), or are doing - // the checking "for real" in which case any errors - // we run into need to be reported. - // - auto arg = matchedArgs[aa++]; - if (context.mode == OverloadResolveContext::Mode::JustTrying) + if (allowPartialGenericApp) { - typeArg = tryCoerceToProperType(TypeExp(arg.argExpr)); + // If we have run out of arguments and the decl allows + // partial specialization, then we don't apply any more + // checks at this step. We will instead attempt to + // *infer* an argument at this position at a later + // stage. + // + candidate.flags |= OverloadCandidate::Flag::IsPartiallyAppliedGeneric; + break; } else { - arg.argExpr = ExpectATypeRepr(arg.argExpr); - typeArg = CoerceToProperType(TypeExp(arg.argExpr)); + // Otherwise, the generic decl had better provide a default value + // or this reference is ill-formed. + ensureDecl(valParamRef, DeclCheckState::DefinitionChecked); + ConstantFoldingCircularityInfo newCircularityInfo( + valParamRef.getDecl(), + nullptr); + auto defaultVal = tryConstantFoldExpr( + valParamRef.substitute(m_astBuilder, valParamRef.getDecl()->initExpr), + ConstantFoldingKind::CompileTime, + &newCircularityInfo); + if (!defaultVal) + { + maybeReportGeneralError(); + return false; + } + checkedArgs.add(defaultVal); + continue; } + } - // If we failed to get a valid type (either because - // there was no matching argument, or because the - // "just trying" coercion failed), then we create - // an error type to stand in for the argument - // - if( !typeArg.type ) + // The case for a generic value parameter is similar to that + // for a generic type parameter. + // + Expr* arg = nullptr; + + // If we have an argument then we need to coerce it + // to the type of the parameter (and fail if the + // coercion is not possible) + // + arg = matchedArgs[aa++].argExpr; + if (context.mode == OverloadResolveContext::Mode::JustTrying) + { + ConversionCost cost = kConversionCost_None; + if (!canCoerce(getType(m_astBuilder, valParamRef), arg->type, arg, &cost)) { - typeArg.type = m_astBuilder->getErrorType(); success = false; } - - checkedArgs.add(typeArg.type); + candidate.conversionCostSum += cost; } - else if (auto valParamRef = memberRef.as()) + else { - if (aa >= matchedArgs.getCount()) - { - if (allowPartialGenericApp) - { - // If we have run out of arguments and the decl allows - // partial specialization, then we don't apply any more - // checks at this step. We will instead attempt to - // *infer* an argument at this position at a later - // stage. - // - candidate.flags |= OverloadCandidate::Flag::IsPartiallyAppliedGeneric; - break; - } - else - { - // Otherwise, the generic decl had better provide a default value - // or this reference is ill-formed. - ensureDecl(valParamRef, DeclCheckState::DefinitionChecked); - ConstantFoldingCircularityInfo newCircularityInfo(valParamRef.getDecl(), nullptr); - auto defaultVal = tryConstantFoldExpr(valParamRef.substitute(m_astBuilder, valParamRef.getDecl()->initExpr), ConstantFoldingKind::CompileTime, &newCircularityInfo); - if (!defaultVal) - { - maybeReportGeneralError(); - return false; - } - checkedArgs.add(defaultVal); - continue; - } - } + arg = coerce(CoercionSite::Argument, getType(m_astBuilder, valParamRef), arg); + } - // The case for a generic value parameter is similar to that - // for a generic type parameter. - // - Expr* arg = nullptr; + // If we have an argument to work with, then we will + // try to extract its speicalization-time constant value. + // + Val* val = nullptr; + if (arg) + { + val = ExtractGenericArgInteger( + arg, + getType(m_astBuilder, valParamRef), + context.mode == OverloadResolveContext::Mode::JustTrying ? nullptr : getSink()); + } - // If we have an argument then we need to coerce it - // to the type of the parameter (and fail if the - // coercion is not possible) - // - arg = matchedArgs[aa++].argExpr; - if (context.mode == OverloadResolveContext::Mode::JustTrying) + // If any of the above checking steps fail and we don't + // have a value to work with here, we will instead + // use an "error" value to stand in for the argument. + // + if (!val) + { + val = m_astBuilder->getOrCreate(m_astBuilder->getIntType()); + } + checkedArgs.add(val); + } + else if (auto typePackParam = memberRef.as()) + { + Val* val = nullptr; + if (aa >= matchedArgs.getCount()) + { + if (allowPartialGenericApp) { - ConversionCost cost = kConversionCost_None; - if (!canCoerce(getType(m_astBuilder, valParamRef), arg->type, arg, &cost)) - { - success = false; - } - candidate.conversionCostSum += cost; + // If we have run out of arguments and the decl allows + // partial specialization, then we don't apply any more + // checks at this step. We will instead attempt to + // *infer* an argument at this position at a later + // stage. + // + candidate.flags |= OverloadCandidate::Flag::IsPartiallyAppliedGeneric; + break; } else { - arg = coerce(CoercionSite::Argument, getType(m_astBuilder, valParamRef), arg); + // Otherwise, we will just create an empty pack. + val = m_astBuilder->getTypePack(ArrayView()); } - - // If we have an argument to work with, then we will - // try to extract its speicalization-time constant value. - // - Val* val = nullptr; - if( arg ) - { - val = ExtractGenericArgInteger(arg, getType(m_astBuilder, valParamRef), context.mode == OverloadResolveContext::Mode::JustTrying ? nullptr : getSink()); - } - - // If any of the above checking steps fail and we don't - // have a value to work with here, we will instead - // use an "error" value to stand in for the argument. - // - if( !val ) - { - val = m_astBuilder->getOrCreate(m_astBuilder->getIntType()); - } - checkedArgs.add(val); } - else if (auto typePackParam = memberRef.as()) + else { - Val* val = nullptr; - if (aa >= matchedArgs.getCount()) - { - if (allowPartialGenericApp) - { - // If we have run out of arguments and the decl allows - // partial specialization, then we don't apply any more - // checks at this step. We will instead attempt to - // *infer* an argument at this position at a later - // stage. - // - candidate.flags |= OverloadCandidate::Flag::IsPartiallyAppliedGeneric; - break; - } - else - { - // Otherwise, we will just create an empty pack. - val = m_astBuilder->getTypePack(ArrayView()); - } - } - else + auto matchedArg = matchedArgs[aa++]; + if (auto packExpr = as(matchedArg.argExpr)) { - auto matchedArg = matchedArgs[aa++]; - if (auto packExpr = as(matchedArg.argExpr)) - { - // We are providing a concrete pack of types as arguments to a type pack parameter. - // We need to create a `TypePack` type to serve as the argument. - ShortList coercedProperTypes; + // We are providing a concrete pack of types as arguments to a type pack + // parameter. We need to create a `TypePack` type to serve as the argument. + ShortList coercedProperTypes; - // Coerce all types in the pack to proper types. - for (Index i = 0; i < packExpr->args.getCount(); i++) + // Coerce all types in the pack to proper types. + for (Index i = 0; i < packExpr->args.getCount(); i++) + { + TypeExp typeArg; + auto elementTypeExpr = packExpr->args[i]; + if (context.mode == OverloadResolveContext::Mode::JustTrying) { - TypeExp typeArg; - auto elementTypeExpr = packExpr->args[i]; - if (context.mode == OverloadResolveContext::Mode::JustTrying) - { - typeArg = tryCoerceToProperType(TypeExp(elementTypeExpr)); - if (!typeArg.type) - { - typeArg.type = m_astBuilder->getErrorType(); - success = false; - } - } - else - { - elementTypeExpr = ExpectATypeRepr(elementTypeExpr); - typeArg = CoerceToProperType(TypeExp(elementTypeExpr)); - } - // If we failed to get a valid type (either because - // there was no matching argument, or because the - // "just trying" coercion failed), then we create - // an error type to stand in for the argument - // + typeArg = tryCoerceToProperType(TypeExp(elementTypeExpr)); if (!typeArg.type) { typeArg.type = m_astBuilder->getErrorType(); success = false; } - coercedProperTypes.add(typeArg.type); } - val = m_astBuilder->getTypePack(coercedProperTypes.getArrayView().arrayView); - } - else if (auto expandExpr = as(matchedArg.argExpr)) - { - auto argType = expandExpr->type.type; - if (auto typeType = as(argType)) - argType = typeType->getType(); - val = argType; - } - else if (auto typeType = as(matchedArg.argType)) - { - if (isAbstractTypePack(typeType->getType())) + else + { + elementTypeExpr = ExpectATypeRepr(elementTypeExpr); + typeArg = CoerceToProperType(TypeExp(elementTypeExpr)); + } + // If we failed to get a valid type (either because + // there was no matching argument, or because the + // "just trying" coercion failed), then we create + // an error type to stand in for the argument + // + if (!typeArg.type) { - val = typeType->getType(); + typeArg.type = m_astBuilder->getErrorType(); + success = false; } + coercedProperTypes.add(typeArg.type); } + val = m_astBuilder->getTypePack(coercedProperTypes.getArrayView().arrayView); } - if (val == nullptr) + else if (auto expandExpr = as(matchedArg.argExpr)) { - maybeReportGeneralError(); - return false; + auto argType = expandExpr->type.type; + if (auto typeType = as(argType)) + argType = typeType->getType(); + val = argType; + } + else if (auto typeType = as(matchedArg.argType)) + { + if (isAbstractTypePack(typeType->getType())) + { + val = typeType->getType(); + } } - checkedArgs.add(val); } - else + if (val == nullptr) { - continue; + maybeReportGeneralError(); + return false; } + checkedArgs.add(val); } + else + { + continue; + } + } - auto genSubst = m_astBuilder->getGenericAppDeclRef(genericDeclRef, checkedArgs.getArrayView()); - candidate.subst = SubstitutionSet(genSubst); + auto genSubst = m_astBuilder->getGenericAppDeclRef(genericDeclRef, checkedArgs.getArrayView()); + candidate.subst = SubstitutionSet(genSubst); - // Once we are done processing the parameters of the generic, - // we will have build up a usable `checkedArgs` array and - // can return to the caller a report of whether we - // were successful or not. - // - return success; - } + // Once we are done processing the parameters of the generic, + // we will have build up a usable `checkedArgs` array and + // can return to the caller a report of whether we + // were successful or not. + // + return success; +} - static QualType getParamQualType(ASTBuilder* astBuilder, DeclRef param) +static QualType getParamQualType(ASTBuilder* astBuilder, DeclRef param) +{ + auto paramType = getType(astBuilder, param); + bool isLVal = false; + switch (getParameterDirection(param.getDecl())) { - auto paramType = getType(astBuilder, param); - bool isLVal = false; - switch (getParameterDirection(param.getDecl())) - { - case kParameterDirection_InOut: - case kParameterDirection_Out: - case kParameterDirection_Ref: - isLVal = true; - break; - } - return QualType(paramType, isLVal); + case kParameterDirection_InOut: + case kParameterDirection_Out: + case kParameterDirection_Ref: isLVal = true; break; } + return QualType(paramType, isLVal); +} - static QualType getParamQualType(Type* paramType) +static QualType getParamQualType(Type* paramType) +{ + if (auto paramDirType = as(paramType)) { - if (auto paramDirType = as(paramType)) - { - if (as(paramDirType) || as(paramDirType)) - return QualType(paramDirType->getValueType(), true); - } - return paramType; + if (as(paramDirType) || as(paramDirType)) + return QualType(paramDirType->getValueType(), true); } + return paramType; +} + +bool SemanticsVisitor::TryCheckOverloadCandidateTypes( + OverloadResolveContext& context, + OverloadCandidate& candidate) +{ + Index argCount = context.getArgCount(); - bool SemanticsVisitor::TryCheckOverloadCandidateTypes( - OverloadResolveContext& context, - OverloadCandidate& candidate) + List paramTypes; + switch (candidate.flavor) { - Index argCount = context.getArgCount(); + case OverloadCandidate::Flavor::Func: + for (auto param : getParameters(m_astBuilder, candidate.item.declRef.as())) + { + paramTypes.add(getParamQualType(m_astBuilder, param)); + } + break; - List paramTypes; - switch (candidate.flavor) + case OverloadCandidate::Flavor::Expr: { - case OverloadCandidate::Flavor::Func: - for (auto param : getParameters(m_astBuilder, candidate.item.declRef.as())) + auto funcType = candidate.funcType; + Count paramCount = funcType->getParamCount(); + for (Index i = 0; i < paramCount; ++i) { - paramTypes.add(getParamQualType(m_astBuilder, param)); + auto paramType = getParamQualType(funcType->getParamType(i)); + paramTypes.add(paramType); } - break; + } + break; - case OverloadCandidate::Flavor::Expr: + case OverloadCandidate::Flavor::Generic: + { + return TryCheckGenericOverloadCandidateTypes(context, candidate); + } + default: SLANG_UNEXPECTED("unknown flavor of overload candidate"); break; + } + + Index paramIndex = 0; + Index argIndex = 0; + struct Arg + { + Expr* argExpr; + Type* type; + }; + auto readArg = [&]() -> Arg + { + if (argIndex >= argCount) + return {nullptr, nullptr}; + auto arg = context.getArg(argIndex); + Arg result = {arg, context.getArgType(argIndex)}; + argIndex++; + return result; + }; + + auto coerceArgToParam = [&](Arg arg, QualType paramType) -> Arg + { + auto argType = QualType(arg.type, paramType.isLeftValue); + if (!paramType) + return {nullptr, nullptr}; + if (!argType) + return {nullptr, nullptr}; + if (context.mode == OverloadResolveContext::Mode::JustTrying) + { + ConversionCost cost = kConversionCost_None; + if (context.disallowNestedConversions) { - auto funcType = candidate.funcType; - Count paramCount = funcType->getParamCount(); - for (Index i = 0; i < paramCount; ++i) - { - auto paramType = getParamQualType(funcType->getParamType(i)); - paramTypes.add(paramType); - } + // We need an exact match in this case. + if (!paramType->equals(argType)) + return {nullptr, nullptr}; } - break; - - case OverloadCandidate::Flavor::Generic: + else if (!canCoerce(paramType, argType, arg.argExpr, &cost)) { - return TryCheckGenericOverloadCandidateTypes(context, candidate); + return {nullptr, nullptr}; } - default: - SLANG_UNEXPECTED("unknown flavor of overload candidate"); - break; + candidate.conversionCostSum += cost; } - - Index paramIndex = 0; - Index argIndex = 0; - struct Arg { Expr* argExpr; Type* type; }; - auto readArg = [&]() -> Arg + else { - if (argIndex >= argCount) - return { nullptr, nullptr }; - auto arg = context.getArg(argIndex); - Arg result = { arg, context.getArgType(argIndex) }; - argIndex++; - return result; - }; - - auto coerceArgToParam = [&](Arg arg, QualType paramType) -> Arg - { - auto argType = QualType(arg.type, paramType.isLeftValue); - if (!paramType) - return { nullptr, nullptr }; - if (!argType) - return { nullptr, nullptr }; - if (context.mode == OverloadResolveContext::Mode::JustTrying) - { - ConversionCost cost = kConversionCost_None; - if (context.disallowNestedConversions) - { - // We need an exact match in this case. - if (!paramType->equals(argType)) - return { nullptr, nullptr }; - } - else if (!canCoerce(paramType, argType, arg.argExpr, &cost)) - { - return { nullptr, nullptr }; - } - candidate.conversionCostSum += cost; - } - else - { - arg.argExpr = coerce(CoercionSite::Argument, paramType, arg.argExpr); - } - return arg; - }; - ShortList resultArgs; + arg.argExpr = coerce(CoercionSite::Argument, paramType, arg.argExpr); + } + return arg; + }; + ShortList resultArgs; - while (paramIndex < paramTypes.getCount()) + while (paramIndex < paramTypes.getCount()) + { + auto paramType = paramTypes[paramIndex]; + if (auto paramTypePack = as(paramType)) { - auto paramType = paramTypes[paramIndex]; - if (auto paramTypePack = as(paramType)) - { - ShortList innerArgs; - for (Index i = 0; i < paramTypePack->getTypeCount(); i++) - { - auto arg = readArg(); - auto coercedArg = coerceArgToParam(arg, QualType(paramTypePack->getElementType(i), paramType.isLeftValue)); - if (!coercedArg.type) - { - return false; - } - if (context.mode == OverloadResolveContext::Mode::ForReal) - innerArgs.add(coercedArg.argExpr); - } - if (context.mode == OverloadResolveContext::Mode::ForReal) - { - auto packArg = m_astBuilder->create(); - for (auto aa : innerArgs) - packArg->args.add(aa); - packArg->type = paramType; - resultArgs.add(packArg); - } - - // Always add a flat cost for using an argument pack, - // so that we prefer non-pack overloads when possible. - candidate.conversionCostSum += kConversionCost_ParameterPack; - } - else + ShortList innerArgs; + for (Index i = 0; i < paramTypePack->getTypeCount(); i++) { auto arg = readArg(); - if (!arg.type) - { - // If we run out of arguments, we can exit the loop now. - // Note that in this type we don't need to worry about - // default arguments, because we already checked that - // the number of arguments was correct in `TryCheckOverloadCandidateArity`. - break; - } - auto coercedArg = coerceArgToParam(arg, paramType); + auto coercedArg = coerceArgToParam( + arg, + QualType(paramTypePack->getElementType(i), paramType.isLeftValue)); if (!coercedArg.type) { return false; } if (context.mode == OverloadResolveContext::Mode::ForReal) - resultArgs.add(coercedArg.argExpr); + innerArgs.add(coercedArg.argExpr); + } + if (context.mode == OverloadResolveContext::Mode::ForReal) + { + auto packArg = m_astBuilder->create(); + for (auto aa : innerArgs) + packArg->args.add(aa); + packArg->type = paramType; + resultArgs.add(packArg); } - paramIndex++; + + // Always add a flat cost for using an argument pack, + // so that we prefer non-pack overloads when possible. + candidate.conversionCostSum += kConversionCost_ParameterPack; } - if (context.mode == OverloadResolveContext::Mode::ForReal) + else { - context.argCount = resultArgs.getCount(); - if (context.args) + auto arg = readArg(); + if (!arg.type) + { + // If we run out of arguments, we can exit the loop now. + // Note that in this type we don't need to worry about + // default arguments, because we already checked that + // the number of arguments was correct in `TryCheckOverloadCandidateArity`. + break; + } + auto coercedArg = coerceArgToParam(arg, paramType); + if (!coercedArg.type) { - context.args->setCount(context.argCount); - for (Index i = 0; i < context.argCount; i++) - (*context.args)[i] = resultArgs[i]; + return false; } + if (context.mode == OverloadResolveContext::Mode::ForReal) + resultArgs.add(coercedArg.argExpr); } - return true; + paramIndex++; } - - bool isEffectivelyMutating(CallableDecl* decl) + if (context.mode == OverloadResolveContext::Mode::ForReal) { - if(decl->hasModifier()) - return true; - if (decl->hasModifier()) - return true; - if(decl->hasModifier()) - return false; - - if(as(decl)) - return true; + context.argCount = resultArgs.getCount(); + if (context.args) + { + context.args->setCount(context.argCount); + for (Index i = 0; i < context.argCount; i++) + (*context.args)[i] = resultArgs[i]; + } + } + return true; +} +bool isEffectivelyMutating(CallableDecl* decl) +{ + if (decl->hasModifier()) + return true; + if (decl->hasModifier()) + return true; + if (decl->hasModifier()) return false; - } - ParamDecl* SemanticsVisitor::isReferenceIntoFunctionInputParameter( - Expr* inExpr) + if (as(decl)) + return true; + + return false; +} + +ParamDecl* SemanticsVisitor::isReferenceIntoFunctionInputParameter(Expr* inExpr) +{ + auto expr = inExpr; + for (;;) { - auto expr = inExpr; - for (;;) + if (auto declRefExpr = as(expr)) { - if (auto declRefExpr = as(expr)) + auto declRef = declRefExpr->declRef; + if (auto paramDeclRef = declRef.as()) { - auto declRef = declRefExpr->declRef; - if(auto paramDeclRef = declRef.as()) + if (paramDeclRef.as()) { - if (paramDeclRef.as()) - { - // functions declared in our "modern" style (using - // the `func` keyword) never have mutable `in` - // parameters. - // - return nullptr; - } - - if (paramDeclRef.getDecl()->findModifier() || - paramDeclRef.getDecl()->findModifier()) - { - // Function parameters marked with `out`, `inout`, - // `in out` or `ref` are all mutable in a way where - // the result of mutations will be visible to the - // caller. - // - return nullptr; - } + // functions declared in our "modern" style (using + // the `func` keyword) never have mutable `in` + // parameters. + // + return nullptr; + } - // At this point we have an l-value decl-ref to a - // function parameter that is (implicitly or - // explicitly) declared `in`. + if (paramDeclRef.getDecl()->findModifier() || + paramDeclRef.getDecl()->findModifier()) + { + // Function parameters marked with `out`, `inout`, + // `in out` or `ref` are all mutable in a way where + // the result of mutations will be visible to the + // caller. // - return paramDeclRef.getDecl(); + return nullptr; } - } - else if (auto memberExpr = as(expr)) - { - expr = memberExpr->baseExpression; - continue; - } - else if (auto indexExpr = as(expr)) - { - expr = indexExpr->baseExpression; - continue; - } - return nullptr; + // At this point we have an l-value decl-ref to a + // function parameter that is (implicitly or + // explicitly) declared `in`. + // + return paramDeclRef.getDecl(); + } + } + else if (auto memberExpr = as(expr)) + { + expr = memberExpr->baseExpression; + continue; + } + else if (auto indexExpr = as(expr)) + { + expr = indexExpr->baseExpression; + continue; } + + return nullptr; } +} - bool SemanticsVisitor::TryCheckOverloadCandidateDirections( - OverloadResolveContext& context, - OverloadCandidate const& candidate) - { - if(candidate.flavor != OverloadCandidate::Flavor::Func) - return true; +bool SemanticsVisitor::TryCheckOverloadCandidateDirections( + OverloadResolveContext& context, + OverloadCandidate const& candidate) +{ + if (candidate.flavor != OverloadCandidate::Flavor::Func) + return true; - auto funcDeclRef = candidate.item.declRef.as(); - SLANG_ASSERT(funcDeclRef); + auto funcDeclRef = candidate.item.declRef.as(); + SLANG_ASSERT(funcDeclRef); - // Note: This operation was originally introduced as - // a place to add checking around l-value-ness of arguments - // and parameters, but currently that checking is being - // done in other places. - // - // For now we will only use this step to check the - // mutability of the `this` parameter where necessary. - // - if(!isEffectivelyStatic(funcDeclRef.getDecl())) + // Note: This operation was originally introduced as + // a place to add checking around l-value-ness of arguments + // and parameters, but currently that checking is being + // done in other places. + // + // For now we will only use this step to check the + // mutability of the `this` parameter where necessary. + // + if (!isEffectivelyStatic(funcDeclRef.getDecl())) + { + if (isEffectivelyMutating(funcDeclRef.getDecl())) { - if(isEffectivelyMutating(funcDeclRef.getDecl())) + if (context.baseExpr && !context.baseExpr->type.isLeftValue) { - if(context.baseExpr && !context.baseExpr->type.isLeftValue) + if (context.mode == OverloadResolveContext::Mode::ForReal) { - if(context.mode == OverloadResolveContext::Mode::ForReal) - { - getSink()->diagnose(context.loc, Diagnostics::mutatingMethodOnImmutableValue, funcDeclRef.getName()); - maybeDiagnoseThisNotLValue(context.baseExpr); - } - return false; + getSink()->diagnose( + context.loc, + Diagnostics::mutatingMethodOnImmutableValue, + funcDeclRef.getName()); + maybeDiagnoseThisNotLValue(context.baseExpr); } + return false; + } - // The parameters of functions declared using traditional/legacy - // syntax are currently exposed as mutable locals within the body - // of the relevant function. As such, it is legal to call `[mutating]` - // methods on such a function parameter. However, doing so is typically - // indicative of an error on the programmer's part. - // - // We will detect such cases here and issue a diagnostic that explains - // the situation. - // - if(context.baseExpr && context.mode == OverloadResolveContext::Mode::ForReal) + // The parameters of functions declared using traditional/legacy + // syntax are currently exposed as mutable locals within the body + // of the relevant function. As such, it is legal to call `[mutating]` + // methods on such a function parameter. However, doing so is typically + // indicative of an error on the programmer's part. + // + // We will detect such cases here and issue a diagnostic that explains + // the situation. + // + if (context.baseExpr && context.mode == OverloadResolveContext::Mode::ForReal) + { + if (auto paramDecl = isReferenceIntoFunctionInputParameter(context.baseExpr)) { - if(auto paramDecl = isReferenceIntoFunctionInputParameter(context.baseExpr)) - { - const bool isNonCopyable = isNonCopyableType(paramDecl->getType()); + const bool isNonCopyable = isNonCopyableType(paramDecl->getType()); - const auto& diagnotic = isNonCopyable ? - Diagnostics::mutatingMethodOnFunctionInputParameterError : - Diagnostics::mutatingMethodOnFunctionInputParameterWarning; + const auto& diagnotic = + isNonCopyable ? Diagnostics::mutatingMethodOnFunctionInputParameterError + : Diagnostics::mutatingMethodOnFunctionInputParameterWarning; - getSink()->diagnose(context.loc, diagnotic, - funcDeclRef.getName(), - paramDecl->getName()); - } + getSink()->diagnose( + context.loc, + diagnotic, + funcDeclRef.getName(), + paramDecl->getName()); } } } + } + + return true; +} +bool SemanticsVisitor::TryCheckOverloadCandidateConstraints( + OverloadResolveContext& context, + OverloadCandidate& candidate) +{ + // We only need this step for generics, so always succeed on + // everything else. + if (candidate.flavor != OverloadCandidate::Flavor::Generic) return true; - } - bool SemanticsVisitor::TryCheckOverloadCandidateConstraints( - OverloadResolveContext& context, - OverloadCandidate& candidate) - { - // We only need this step for generics, so always succeed on - // everything else. - if(candidate.flavor != OverloadCandidate::Flavor::Generic) - return true; + // It is possible that the overload candidate was only partially + // applied (the number of arguments was not equal to the number + // of explicit parameters). In that case, we want to defer + // final checking of things like constraints until later, in + // case a subsequent pass of overload resolution (like applying + // an overloaded generic function to arguments) will give us + // the missing information to enable inference. + // + if (candidate.flags & OverloadCandidate::Flag::IsPartiallyAppliedGeneric) + return true; - // It is possible that the overload candidate was only partially - // applied (the number of arguments was not equal to the number - // of explicit parameters). In that case, we want to defer - // final checking of things like constraints until later, in - // case a subsequent pass of overload resolution (like applying - // an overloaded generic function to arguments) will give us - // the missing information to enable inference. - // - if(candidate.flags & OverloadCandidate::Flag::IsPartiallyAppliedGeneric) - return true; + auto genericDeclRef = candidate.item.declRef.as(); + SLANG_ASSERT(genericDeclRef); // otherwise we wouldn't be a generic candidate... - auto genericDeclRef = candidate.item.declRef.as(); - SLANG_ASSERT(genericDeclRef); // otherwise we wouldn't be a generic candidate... + // We should have the existing arguments to the generic + // handy, so that we can construct a substitution list. + auto substArgs = tryGetGenericArguments(candidate.subst, genericDeclRef.getDecl()); + SLANG_ASSERT(substArgs.getCount()); - // We should have the existing arguments to the generic - // handy, so that we can construct a substitution list. - auto substArgs = tryGetGenericArguments(candidate.subst, genericDeclRef.getDecl()); - SLANG_ASSERT(substArgs.getCount()); + List newArgs; + for (auto arg : substArgs) + newArgs.add(arg); - List newArgs; - for (auto arg : substArgs) - newArgs.add(arg); + for (auto constraintDecl : + genericDeclRef.getDecl()->getMembersOfType()) + { + DeclRef constraintDeclRef = + m_astBuilder + ->getGenericAppDeclRef(genericDeclRef, newArgs.getArrayView(), constraintDecl) + .as(); - for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType() ) - { - DeclRef constraintDeclRef = m_astBuilder->getGenericAppDeclRef( - genericDeclRef, newArgs.getArrayView(), constraintDecl).as(); - - auto sub = getSub(m_astBuilder, constraintDeclRef); - auto sup = getSup(m_astBuilder, constraintDeclRef); + auto sub = getSub(m_astBuilder, constraintDeclRef); + auto sup = getSup(m_astBuilder, constraintDeclRef); - auto subTypeWitness = tryGetSubtypeWitness(sub, sup); - if(subTypeWitness) - { - newArgs.add(subTypeWitness); - } - else + auto subTypeWitness = tryGetSubtypeWitness(sub, sup); + if (subTypeWitness) + { + newArgs.add(subTypeWitness); + } + else + { + if (context.mode != OverloadResolveContext::Mode::JustTrying) { - if(context.mode != OverloadResolveContext::Mode::JustTrying) - { - subTypeWitness = isSubtype(sub, sup, IsSubTypeOptions::None); - getSink()->diagnose(context.loc, Diagnostics::typeArgumentDoesNotConformToInterface, sub, sup); - } - return false; + subTypeWitness = isSubtype(sub, sup, IsSubTypeOptions::None); + getSink()->diagnose( + context.loc, + Diagnostics::typeArgumentDoesNotConformToInterface, + sub, + sup); } + return false; } - - candidate.subst = SubstitutionSet(m_astBuilder->getGenericAppDeclRef(genericDeclRef, newArgs.getArrayView())); - - // Done checking all the constraints, hooray. - return true; } - void SemanticsVisitor::TryCheckOverloadCandidate( - OverloadResolveContext& context, - OverloadCandidate& candidate) - { - if (!TryCheckOverloadCandidateArity(context, candidate)) - return; + candidate.subst = + SubstitutionSet(m_astBuilder->getGenericAppDeclRef(genericDeclRef, newArgs.getArrayView())); - candidate.status = OverloadCandidate::Status::ArityChecked; - if (!TryCheckOverloadCandidateFixity(context, candidate)) - return; + // Done checking all the constraints, hooray. + return true; +} - candidate.status = OverloadCandidate::Status::FixityChecked; - if (!TryCheckOverloadCandidateTypes(context, candidate)) - return; +void SemanticsVisitor::TryCheckOverloadCandidate( + OverloadResolveContext& context, + OverloadCandidate& candidate) +{ + if (!TryCheckOverloadCandidateArity(context, candidate)) + return; - candidate.status = OverloadCandidate::Status::TypeChecked; - if (!TryCheckOverloadCandidateDirections(context, candidate)) - return; + candidate.status = OverloadCandidate::Status::ArityChecked; + if (!TryCheckOverloadCandidateFixity(context, candidate)) + return; - candidate.status = OverloadCandidate::Status::DirectionChecked; - if (!TryCheckOverloadCandidateConstraints(context, candidate)) - return; + candidate.status = OverloadCandidate::Status::FixityChecked; + if (!TryCheckOverloadCandidateTypes(context, candidate)) + return; - candidate.status = OverloadCandidate::Status::VisibilityChecked; - if (!TryCheckOverloadCandidateVisibility(context, candidate)) - return; + candidate.status = OverloadCandidate::Status::TypeChecked; + if (!TryCheckOverloadCandidateDirections(context, candidate)) + return; - candidate.status = OverloadCandidate::Status::Applicable; - } + candidate.status = OverloadCandidate::Status::DirectionChecked; + if (!TryCheckOverloadCandidateConstraints(context, candidate)) + return; - Expr* SemanticsVisitor::createGenericDeclRef( - Expr* baseExpr, - Expr* originalExpr, - SubstitutionSet substArgs) - { - auto baseDeclRefExpr = as(baseExpr); - if (!baseDeclRefExpr) - { - SLANG_DIAGNOSE_UNEXPECTED(getSink(), baseExpr, "expected a reference to a generic declaration"); - return CreateErrorExpr(originalExpr); - } - auto baseGenericRef = baseDeclRefExpr->declRef.as(); - if (!baseGenericRef) - { - SLANG_DIAGNOSE_UNEXPECTED(getSink(), baseExpr, "expected a reference to a generic declaration"); - return CreateErrorExpr(originalExpr); - } - auto genSubst = substArgs.findGenericAppDeclRef(baseGenericRef.getDecl()); - SLANG_ASSERT(genSubst); - DeclRef innerDeclRef = m_astBuilder->getGenericAppDeclRef(baseGenericRef, genSubst->getArgs()); + candidate.status = OverloadCandidate::Status::VisibilityChecked; + if (!TryCheckOverloadCandidateVisibility(context, candidate)) + return; - Expr* base = nullptr; - if (auto mbrExpr = as(baseExpr)) - base = mbrExpr->baseExpression; + candidate.status = OverloadCandidate::Status::Applicable; +} - return ConstructDeclRefExpr( - innerDeclRef, - base, - innerDeclRef.getName(), - originalExpr->loc, - originalExpr); +Expr* SemanticsVisitor::createGenericDeclRef( + Expr* baseExpr, + Expr* originalExpr, + SubstitutionSet substArgs) +{ + auto baseDeclRefExpr = as(baseExpr); + if (!baseDeclRefExpr) + { + SLANG_DIAGNOSE_UNEXPECTED( + getSink(), + baseExpr, + "expected a reference to a generic declaration"); + return CreateErrorExpr(originalExpr); } + auto baseGenericRef = baseDeclRefExpr->declRef.as(); + if (!baseGenericRef) + { + SLANG_DIAGNOSE_UNEXPECTED( + getSink(), + baseExpr, + "expected a reference to a generic declaration"); + return CreateErrorExpr(originalExpr); + } + auto genSubst = substArgs.findGenericAppDeclRef(baseGenericRef.getDecl()); + SLANG_ASSERT(genSubst); + DeclRef innerDeclRef = + m_astBuilder->getGenericAppDeclRef(baseGenericRef, genSubst->getArgs()); + + Expr* base = nullptr; + if (auto mbrExpr = as(baseExpr)) + base = mbrExpr->baseExpression; + + return ConstructDeclRefExpr( + innerDeclRef, + base, + innerDeclRef.getName(), + originalExpr->loc, + originalExpr); +} - Expr* SemanticsVisitor::CompleteOverloadCandidate( - OverloadResolveContext& context, - OverloadCandidate& candidate) +Expr* SemanticsVisitor::CompleteOverloadCandidate( + OverloadResolveContext& context, + OverloadCandidate& candidate) +{ + // special case for generic argument inference failure + if (candidate.status == OverloadCandidate::Status::GenericArgumentInferenceFailed) { - // special case for generic argument inference failure - if (candidate.status == OverloadCandidate::Status::GenericArgumentInferenceFailed) - { - String callString = getCallSignatureString(context); - getSink()->diagnose( - context.loc, - Diagnostics::genericArgumentInferenceFailed, - callString); + String callString = getCallSignatureString(context); + getSink()->diagnose(context.loc, Diagnostics::genericArgumentInferenceFailed, callString); - String declString = ASTPrinter::getDeclSignatureString(candidate.item, m_astBuilder); - getSink()->diagnose(candidate.item.declRef, Diagnostics::genericSignatureTried, declString); - goto error; - } + String declString = ASTPrinter::getDeclSignatureString(candidate.item, m_astBuilder); + getSink()->diagnose(candidate.item.declRef, Diagnostics::genericSignatureTried, declString); + goto error; + } - context.mode = OverloadResolveContext::Mode::ForReal; + context.mode = OverloadResolveContext::Mode::ForReal; - if (!TryCheckOverloadCandidateClassNewMatchUp(context, candidate)) - goto error; + if (!TryCheckOverloadCandidateClassNewMatchUp(context, candidate)) + goto error; - if (!TryCheckOverloadCandidateArity(context, candidate)) - goto error; + if (!TryCheckOverloadCandidateArity(context, candidate)) + goto error; - if (!TryCheckOverloadCandidateFixity(context, candidate)) - goto error; + if (!TryCheckOverloadCandidateFixity(context, candidate)) + goto error; - if (!TryCheckOverloadCandidateTypes(context, candidate)) - goto error; + if (!TryCheckOverloadCandidateTypes(context, candidate)) + goto error; - if (!TryCheckOverloadCandidateDirections(context, candidate)) - goto error; + if (!TryCheckOverloadCandidateDirections(context, candidate)) + goto error; - if (!TryCheckOverloadCandidateConstraints(context, candidate)) - goto error; + if (!TryCheckOverloadCandidateConstraints(context, candidate)) + goto error; - if (!TryCheckOverloadCandidateVisibility(context, candidate)) - goto error; + if (!TryCheckOverloadCandidateVisibility(context, candidate)) + goto error; + { + Expr* baseExpr; + switch (candidate.flavor) { - Expr* baseExpr; - switch(candidate.flavor) - { - case OverloadCandidate::Flavor::Func: - case OverloadCandidate::Flavor::Generic: - baseExpr = ConstructLookupResultExpr( - candidate.item, - context.baseExpr, - candidate.item.declRef.getName(), - context.funcLoc, - context.originalExpr); - break; - case OverloadCandidate::Flavor::Expr: - default: - baseExpr = nullptr; - break; - } + case OverloadCandidate::Flavor::Func: + case OverloadCandidate::Flavor::Generic: + baseExpr = ConstructLookupResultExpr( + candidate.item, + context.baseExpr, + candidate.item.declRef.getName(), + context.funcLoc, + context.originalExpr); + break; + case OverloadCandidate::Flavor::Expr: + default: baseExpr = nullptr; break; + } - switch(candidate.flavor) + switch (candidate.flavor) + { + case OverloadCandidate::Flavor::Func: { - case OverloadCandidate::Flavor::Func: + AppExprBase* callExpr = as(context.originalExpr); + if (!callExpr) { - AppExprBase* callExpr = as(context.originalExpr); - if(!callExpr) - { - callExpr = m_astBuilder->create(); - callExpr->loc = context.loc; - for(Index aa = 0; aa < context.argCount; ++aa) - callExpr->arguments.add(context.getArg(aa)); - } + callExpr = m_astBuilder->create(); + callExpr->loc = context.loc; + for (Index aa = 0; aa < context.argCount; ++aa) + callExpr->arguments.add(context.getArg(aa)); + } - callExpr->originalFunctionExpr = callExpr->functionExpr; - callExpr->functionExpr = baseExpr; - callExpr->type = QualType(candidate.resultType); + callExpr->originalFunctionExpr = callExpr->functionExpr; + callExpr->functionExpr = baseExpr; + callExpr->type = QualType(candidate.resultType); - // A call may yield an l-value, and we should take a look at the candidate to be sure - if(auto subscriptDeclRef = candidate.item.declRef.as()) + // A call may yield an l-value, and we should take a look at the candidate to be + // sure + if (auto subscriptDeclRef = candidate.item.declRef.as()) + { + const auto& decl = subscriptDeclRef.getDecl(); + for (auto member : decl->members) { - const auto& decl = subscriptDeclRef.getDecl(); - for (auto member : decl->members) + if (as(member) || as(member)) { - if (as(member) || as(member)) + // If the subscript decl has a setter, + // then the call is an l-value if base is l-value. + if (auto base = GetBaseExpr(baseExpr)) { - // If the subscript decl has a setter, - // then the call is an l-value if base is l-value. - if (auto base = GetBaseExpr(baseExpr)) - { - if (base->type.isLeftValue) - { - callExpr->type.isLeftValue = true; - break; - } - } - // Otherwise, if the accessor is [nonmutating], we can - // also consider the result of the subscript call as l-value - // regardless of the base. - if (member->findModifier()) + if (base->type.isLeftValue) { callExpr->type.isLeftValue = true; break; } } + // Otherwise, if the accessor is [nonmutating], we can + // also consider the result of the subscript call as l-value + // regardless of the base. + if (member->findModifier()) + { + callExpr->type.isLeftValue = true; + break; + } } } - - // TODO: there may be other cases that confer l-value-ness - - return callExpr; } - break; - - case OverloadCandidate::Flavor::Expr: - { - AppExprBase* callExpr = as(context.originalExpr); - if (!callExpr) - { - callExpr = m_astBuilder->create(); - callExpr->loc = context.loc; - for (Index aa = 0; aa < context.argCount; ++aa) - callExpr->arguments.add(context.getArg(aa)); - } + // TODO: there may be other cases that confer l-value-ness - callExpr->originalFunctionExpr = callExpr->functionExpr; - callExpr->type = QualType(candidate.resultType); - callExpr->functionExpr = candidate.exprVal; - return callExpr; + return callExpr; + } - } - break; + break; - case OverloadCandidate::Flavor::Generic: - // We allow a generic to be applied to fewer arguments than its number - // of parameters, and defer the process of inferring the remaining - // arguments until later. - // - if(candidate.flags & OverloadCandidate::Flag::IsPartiallyAppliedGeneric) + case OverloadCandidate::Flavor::Expr: + { + AppExprBase* callExpr = as(context.originalExpr); + if (!callExpr) { - auto expr = m_astBuilder->create(); - expr->loc = context.loc; - expr->originalExpr = baseExpr; - expr->baseGenericDeclRef = as(baseExpr)->declRef.as(); - auto args = tryGetGenericArguments(candidate.subst, expr->baseGenericDeclRef.getDecl()); - for (auto arg : args) - expr->knownGenericArgs.add(arg); - return expr; + callExpr = m_astBuilder->create(); + callExpr->loc = context.loc; + for (Index aa = 0; aa < context.argCount; ++aa) + callExpr->arguments.add(context.getArg(aa)); } - return createGenericDeclRef( - baseExpr, - context.originalExpr, - candidate.subst); - break; - - default: - SLANG_DIAGNOSE_UNEXPECTED(getSink(), context.loc, "unknown overload candidate flavor"); - break; + callExpr->originalFunctionExpr = callExpr->functionExpr; + callExpr->type = QualType(candidate.resultType); + callExpr->functionExpr = candidate.exprVal; + return callExpr; } - } + break; + case OverloadCandidate::Flavor::Generic: + // We allow a generic to be applied to fewer arguments than its number + // of parameters, and defer the process of inferring the remaining + // arguments until later. + // + if (candidate.flags & OverloadCandidate::Flag::IsPartiallyAppliedGeneric) + { + auto expr = m_astBuilder->create(); + expr->loc = context.loc; + expr->originalExpr = baseExpr; + expr->baseGenericDeclRef = as(baseExpr)->declRef.as(); + auto args = + tryGetGenericArguments(candidate.subst, expr->baseGenericDeclRef.getDecl()); + for (auto arg : args) + expr->knownGenericArgs.add(arg); + return expr; + } - error: + return createGenericDeclRef(baseExpr, context.originalExpr, candidate.subst); + break; - if(context.originalExpr) - { - return CreateErrorExpr(context.originalExpr); - } - else - { - return nullptr; + default: + SLANG_DIAGNOSE_UNEXPECTED(getSink(), context.loc, "unknown overload candidate flavor"); + break; } } - /// Does the given `declRef` represent an interface requirement? - bool isInterfaceRequirement(ASTBuilder* builder, DeclRef const& declRef) - { - SLANG_UNUSED(builder); - if(!declRef) - return false; +error: - auto parent = declRef.getParent(); - if(parent.as()) - parent = parent.getParent(); + if (context.originalExpr) + { + return CreateErrorExpr(context.originalExpr); + } + else + { + return nullptr; + } +} - if(parent.as()) - return true; +/// Does the given `declRef` represent an interface requirement? +bool isInterfaceRequirement(ASTBuilder* builder, DeclRef const& declRef) +{ + SLANG_UNUSED(builder); + if (!declRef) return false; - } - /// If `declRef` representations a specialization of a generic, returns the number of specialized generic arguments. - /// Otherwise, returns zero. - /// - Int SemanticsVisitor::getSpecializedParamCount(DeclRef const& declRef) - { - if(!declRef) - return 0; + auto parent = declRef.getParent(); + if (parent.as()) + parent = parent.getParent(); - // A specialization of a generic must point at the - // "inner" declaration of a generic. That means that - // the parent of the decl ref must be a generic. - // - auto parentGeneric = declRef.getParent().as(); - if(!parentGeneric) - return 0; - // - // Furthermore, the declaration we are considering - // must be the single "inner" declaration of the - // parent generic (and not somthing like a generic - // parameter). - // - if( parentGeneric.getDecl()->inner != declRef.getDecl()) - return 0; + if (parent.as()) + return true; + + return false; +} - return CountParameters(parentGeneric).required; +/// If `declRef` representations a specialization of a generic, returns the number of specialized +/// generic arguments. Otherwise, returns zero. +/// +Int SemanticsVisitor::getSpecializedParamCount(DeclRef const& declRef) +{ + if (!declRef) + return 0; + + // A specialization of a generic must point at the + // "inner" declaration of a generic. That means that + // the parent of the decl ref must be a generic. + // + auto parentGeneric = declRef.getParent().as(); + if (!parentGeneric) + return 0; + // + // Furthermore, the declaration we are considering + // must be the single "inner" declaration of the + // parent generic (and not somthing like a generic + // parameter). + // + if (parentGeneric.getDecl()->inner != declRef.getDecl()) + return 0; + + return CountParameters(parentGeneric).required; +} + +DeclRef getParentDeclRef(DeclRef declRef) +{ + auto parent = declRef.getParent(); + while (parent.as()) + { + parent = parent.getParent(); } + return parent; +} - DeclRef getParentDeclRef(DeclRef declRef) +// Returns -1 if left is preferred, 1 if right is preferred, and 0 if they are equal. +// +int SemanticsVisitor::CompareLookupResultItems( + LookupResultItem const& left, + LookupResultItem const& right) +{ + // It is possible for lookup to return both an interface requirement + // and the concrete function that satisfies that requirement. + // We always want to favor a concrete method over an interface + // requirement it might override. + // + // TODO: This should turn into a more detailed check such that + // a candidate for declaration A is always better than a candidate + // for declaration B if A is an override of B. We can't + // easily make that check right now because we aren't tracking + // this kind of "is an override of ..." information on declarations + // directly (it is only visible through the requirement witness + // information for inheritance declarations). + // + auto leftDeclRefParent = getParentDeclRef(left.declRef); + auto rightDeclRefParent = getParentDeclRef(right.declRef); + bool leftIsInterfaceRequirement = isInterfaceRequirement(left.declRef.getDecl()); + bool rightIsInterfaceRequirement = isInterfaceRequirement(right.declRef.getDecl()); + if (leftIsInterfaceRequirement != rightIsInterfaceRequirement) + return int(leftIsInterfaceRequirement) - int(rightIsInterfaceRequirement); + + // Prefer non-extension declarations over extension declarations. + bool leftIsExtension = as(leftDeclRefParent.getDecl()) != nullptr; + bool rightIsExtension = as(rightDeclRefParent.getDecl()) != nullptr; + if (leftIsExtension != rightIsExtension) { - auto parent = declRef.getParent(); - while (parent.as()) + return int(leftIsExtension) - int(rightIsExtension); + } + else if (leftIsExtension) + { + // If both are declared in extensions, prefer the one that is least generic. + bool leftIsGeneric = leftDeclRefParent.getParent().as() != nullptr; + bool rightIsGeneric = rightDeclRefParent.getParent().as() != nullptr; + if (leftIsGeneric != rightIsGeneric) { - parent = parent.getParent(); + return int(leftIsGeneric) - int(rightIsGeneric); } - return parent; } - // Returns -1 if left is preferred, 1 if right is preferred, and 0 if they are equal. - // - int SemanticsVisitor::CompareLookupResultItems( - LookupResultItem const& left, - LookupResultItem const& right) - { - // It is possible for lookup to return both an interface requirement - // and the concrete function that satisfies that requirement. - // We always want to favor a concrete method over an interface - // requirement it might override. - // - // TODO: This should turn into a more detailed check such that - // a candidate for declaration A is always better than a candidate - // for declaration B if A is an override of B. We can't - // easily make that check right now because we aren't tracking - // this kind of "is an override of ..." information on declarations - // directly (it is only visible through the requirement witness - // information for inheritance declarations). - // - auto leftDeclRefParent = getParentDeclRef(left.declRef); - auto rightDeclRefParent = getParentDeclRef(right.declRef); - bool leftIsInterfaceRequirement = isInterfaceRequirement(left.declRef.getDecl()); - bool rightIsInterfaceRequirement = isInterfaceRequirement(right.declRef.getDecl()); - if(leftIsInterfaceRequirement != rightIsInterfaceRequirement) - return int(leftIsInterfaceRequirement) - int(rightIsInterfaceRequirement); - - // Prefer non-extension declarations over extension declarations. - bool leftIsExtension = as(leftDeclRefParent.getDecl()) != nullptr; - bool rightIsExtension = as(rightDeclRefParent.getDecl()) != nullptr; - if (leftIsExtension != rightIsExtension) - { - return int(leftIsExtension) - int(rightIsExtension); - } - else if (leftIsExtension) - { - // If both are declared in extensions, prefer the one that is least generic. - bool leftIsGeneric = leftDeclRefParent.getParent().as() != nullptr; - bool rightIsGeneric = rightDeclRefParent.getParent().as() != nullptr; - if (leftIsGeneric != rightIsGeneric) - { - return int(leftIsGeneric) - int(rightIsGeneric); - } - } + // Any decl is strictly better than a module decl. + bool leftIsModule = (as(left.declRef) != nullptr); + bool rightIsModule = (as(right.declRef) != nullptr); + if (leftIsModule != rightIsModule) + return int(rightIsModule) - int(leftIsModule); - // Any decl is strictly better than a module decl. - bool leftIsModule = (as(left.declRef) != nullptr); - bool rightIsModule = (as(right.declRef) != nullptr); - if(leftIsModule != rightIsModule) - return int(rightIsModule) - int(leftIsModule); + // If both are interface requirements, prefer the more derived interface. + if (leftIsInterfaceRequirement && rightIsInterfaceRequirement) + { + auto leftType = DeclRefType::create(m_astBuilder, leftDeclRefParent); + auto rightType = DeclRefType::create(m_astBuilder, rightDeclRefParent); - // If both are interface requirements, prefer the more derived interface. - if (leftIsInterfaceRequirement && rightIsInterfaceRequirement) + if (!leftType->equals(rightType)) { - auto leftType = DeclRefType::create(m_astBuilder, leftDeclRefParent); - auto rightType = DeclRefType::create(m_astBuilder, rightDeclRefParent); - - if (!leftType->equals(rightType)) - { - if (isSubtype(leftType, rightType, IsSubTypeOptions::None)) - return -1; - if (isSubtype(rightType, leftType, IsSubTypeOptions::None)) - return 1; - } + if (isSubtype(leftType, rightType, IsSubTypeOptions::None)) + return -1; + if (isSubtype(rightType, leftType, IsSubTypeOptions::None)) + return 1; } + } - // If both parents are the same we have ambiguity - if(left.declRef.getParent() == right.declRef.getParent()) - return 0; - - auto leftAggType = leftDeclRefParent.as(); - auto rightAggType = rightDeclRefParent.as(); - if (leftAggType && rightAggType) - { - auto leftType = DeclRefType::create(m_astBuilder, leftDeclRefParent); - auto rightType = DeclRefType::create(m_astBuilder, rightDeclRefParent); + // If both parents are the same we have ambiguity + if (left.declRef.getParent() == right.declRef.getParent()) + return 0; - auto inheritanceInfo = getShared()->getInheritanceInfo(rightType); - for (auto facet : inheritanceInfo.facets) - if (facet.getImpl()->getDeclRef().equals(leftDeclRefParent)) - return 1; - inheritanceInfo = getShared()->getInheritanceInfo(leftType); - for (auto facet : inheritanceInfo.facets) - if (facet.getImpl()->getDeclRef().equals(rightDeclRefParent)) - return -1; - } + auto leftAggType = leftDeclRefParent.as(); + auto rightAggType = rightDeclRefParent.as(); + if (leftAggType && rightAggType) + { + auto leftType = DeclRefType::create(m_astBuilder, leftDeclRefParent); + auto rightType = DeclRefType::create(m_astBuilder, rightDeclRefParent); + + auto inheritanceInfo = getShared()->getInheritanceInfo(rightType); + for (auto facet : inheritanceInfo.facets) + if (facet.getImpl()->getDeclRef().equals(leftDeclRefParent)) + return 1; + inheritanceInfo = getShared()->getInheritanceInfo(leftType); + for (auto facet : inheritanceInfo.facets) + if (facet.getImpl()->getDeclRef().equals(rightDeclRefParent)) + return -1; + } - // If both are subscript decls, prefer the one that provides more - // accessors. - if (auto leftSubscriptDecl = left.declRef.as()) + // If both are subscript decls, prefer the one that provides more + // accessors. + if (auto leftSubscriptDecl = left.declRef.as()) + { + if (auto rightSubscriptDecl = right.declRef.as()) { - if (auto rightSubscriptDecl = right.declRef.as()) + auto leftAccessorCount = + leftSubscriptDecl.getDecl()->getMembersOfType().getCount(); + auto rightAccessorCount = + rightSubscriptDecl.getDecl()->getMembersOfType().getCount(); + auto decl1IsSubsetOfDecl2 = [=](SubscriptDecl* decl1, SubscriptDecl* decl2) { - auto leftAccessorCount = leftSubscriptDecl.getDecl()->getMembersOfType().getCount(); - auto rightAccessorCount = rightSubscriptDecl.getDecl()->getMembersOfType().getCount(); - auto decl1IsSubsetOfDecl2 = [=](SubscriptDecl* decl1, SubscriptDecl* decl2) + for (auto accessorDecl1 : decl1->getMembersOfType()) + { + bool found = false; + for (auto accessorDecl2 : decl2->getMembersOfType()) { - for (auto accessorDecl1 : decl1->getMembersOfType()) + if (accessorDecl1->astNodeType == accessorDecl2->astNodeType) { - bool found = false; - for (auto accessorDecl2 : decl2->getMembersOfType()) - { - if (accessorDecl1->astNodeType == accessorDecl2->astNodeType) - { - found = true; - break; - } - } - if (!found) - return false; + found = true; + break; } - return true; - }; - if (leftAccessorCount > rightAccessorCount - && decl1IsSubsetOfDecl2(rightSubscriptDecl.getDecl(), leftSubscriptDecl.getDecl())) - { - return -1; - } - else if (rightAccessorCount > leftAccessorCount - && decl1IsSubsetOfDecl2(leftSubscriptDecl.getDecl(), rightSubscriptDecl.getDecl())) - { - return 1; + } + if (!found) + return false; } + return true; + }; + if (leftAccessorCount > rightAccessorCount && + decl1IsSubsetOfDecl2(rightSubscriptDecl.getDecl(), leftSubscriptDecl.getDecl())) + { + return -1; + } + else if ( + rightAccessorCount > leftAccessorCount && + decl1IsSubsetOfDecl2(leftSubscriptDecl.getDecl(), rightSubscriptDecl.getDecl())) + { + return 1; } } + } - // TODO: We should generalize above rules such that in a tie a declaration - // A::m is better than B::m when all other factors are equal and - // A inherits from B. + // TODO: We should generalize above rules such that in a tie a declaration + // A::m is better than B::m when all other factors are equal and + // A inherits from B. - // TODO: There are other cases like this we need to add in terms - // of ranking/prioritizing overloads, around things like - // "transparent" members, or when lookup proceeds from an "inner" - // to an "outer" scope. In many cases the right way to proceed - // could involve attaching a distance/cost/rank to things directly - // as part of lookup, and in other cases it might be best handled - // as a semantic check based on the actual declarations found. + // TODO: There are other cases like this we need to add in terms + // of ranking/prioritizing overloads, around things like + // "transparent" members, or when lookup proceeds from an "inner" + // to an "outer" scope. In many cases the right way to proceed + // could involve attaching a distance/cost/rank to things directly + // as part of lookup, and in other cases it might be best handled + // as a semantic check based on the actual declarations found. - return 0; - } + return 0; +} - int SemanticsVisitor::compareOverloadCandidateSpecificity( - LookupResultItem const& left, - LookupResultItem const& right) - { - // HACK: if both items refer to the same declaration, - // then arbitrarily pick one. - if(left.declRef.equals(right.declRef)) - return -1; +int SemanticsVisitor::compareOverloadCandidateSpecificity( + LookupResultItem const& left, + LookupResultItem const& right) +{ + // HACK: if both items refer to the same declaration, + // then arbitrarily pick one. + if (left.declRef.equals(right.declRef)) + return -1; - // There is a very general rule that we would like to enforce - // in principle: - // - // Given candidates A and B, if A being applicable to some - // arguments implies that B is also applicable, but not vice versa, - // then A is a more specific/specialized candidate than B. - // - // A number of conclusions follow from this general rule. - // For example, a non-generic declaration will always be - // more specific than a generic declaration that was specialized - // to matching types: - // - // int doThing(int a); - // T doThing(T a); - // - // It is clear that if the non-generic `doThing` is applicable - // to an argument `x`, then `doThing` is also applicable to - // `x`. However, knowing that the generic `doThing` was applicable - // to some `y` doesn't tell us that the non-generic `doThing` can - // be called on `y`, because `y` could have some type that can't - // convert to `int`. - // - // Similarly, a generic declaration with a subset of the parameters - // of another generic is always more specialized: - // - // int doThing(vector value); - // int doThing(vector value); - // - // Here we know that both overloads can apply to `float3`, but only - // one can apply to `float4`, so the first overload is more - // specialized/specific. - // - // As a final example, a generic which places more constraints - // on its generic parameters is more specific, all other things - // being equal: - // - // int doThing( T value ); - // int doThing(T value); - // - // In this case we know that the first overload is applicable - // to a strict subset of the types that the second overload can - // apply to. - // - // The above rules represent the idealized principles we want - // to implement, but actually implementing that full check here - // could make overload resolution far more expensive. - // - // For now we are going to do something far simpler and hackier, - // which is to say that a candidate with more generic parameters - // is always preferred over one with fewer. - // - // TODO: We could extend this definition to account for constraints - // on generic parameters in the count, which would handle the - // need to prefer a more-constrained generic when possible. - // - // TODO: In the long run we should clearly replace this with - // the more general "does A being applicable imply B being applicable" - // test. - // - // TODO: The principle stated here doesn't take the actual - // arguments or their types into account, and it might be that - // in some cases disambiguation of which declaration should be - // preferred will depend on knowing the actual arguments. - // - auto leftSpecCount = getSpecializedParamCount(left.declRef); - auto rightSpecCount = getSpecializedParamCount(right.declRef); - if(leftSpecCount != rightSpecCount) - return int(leftSpecCount - rightSpecCount); + // There is a very general rule that we would like to enforce + // in principle: + // + // Given candidates A and B, if A being applicable to some + // arguments implies that B is also applicable, but not vice versa, + // then A is a more specific/specialized candidate than B. + // + // A number of conclusions follow from this general rule. + // For example, a non-generic declaration will always be + // more specific than a generic declaration that was specialized + // to matching types: + // + // int doThing(int a); + // T doThing(T a); + // + // It is clear that if the non-generic `doThing` is applicable + // to an argument `x`, then `doThing` is also applicable to + // `x`. However, knowing that the generic `doThing` was applicable + // to some `y` doesn't tell us that the non-generic `doThing` can + // be called on `y`, because `y` could have some type that can't + // convert to `int`. + // + // Similarly, a generic declaration with a subset of the parameters + // of another generic is always more specialized: + // + // int doThing(vector value); + // int doThing(vector value); + // + // Here we know that both overloads can apply to `float3`, but only + // one can apply to `float4`, so the first overload is more + // specialized/specific. + // + // As a final example, a generic which places more constraints + // on its generic parameters is more specific, all other things + // being equal: + // + // int doThing( T value ); + // int doThing(T value); + // + // In this case we know that the first overload is applicable + // to a strict subset of the types that the second overload can + // apply to. + // + // The above rules represent the idealized principles we want + // to implement, but actually implementing that full check here + // could make overload resolution far more expensive. + // + // For now we are going to do something far simpler and hackier, + // which is to say that a candidate with more generic parameters + // is always preferred over one with fewer. + // + // TODO: We could extend this definition to account for constraints + // on generic parameters in the count, which would handle the + // need to prefer a more-constrained generic when possible. + // + // TODO: In the long run we should clearly replace this with + // the more general "does A being applicable imply B being applicable" + // test. + // + // TODO: The principle stated here doesn't take the actual + // arguments or their types into account, and it might be that + // in some cases disambiguation of which declaration should be + // preferred will depend on knowing the actual arguments. + // + auto leftSpecCount = getSpecializedParamCount(left.declRef); + auto rightSpecCount = getSpecializedParamCount(right.declRef); + if (leftSpecCount != rightSpecCount) + return int(leftSpecCount - rightSpecCount); + + return 0; +} +int getOverloadRank(DeclRef declRef) +{ + if (!declRef.getDecl()) return 0; - } + if (auto attr = declRef.getDecl()->findModifier()) + return attr->rank; + return 0; +} - int getOverloadRank(DeclRef declRef) +int getExportRank(DeclRef left, DeclRef right) +{ + if (left.getDecl() && left.getDecl()->hasModifier()) { - if (!declRef.getDecl()) - return 0; - if (auto attr = declRef.getDecl()->findModifier()) - return attr->rank; - return 0; + return (right.getDecl() && right.getDecl()->hasModifier()) ? -1 : 0; } + return 0; +} - int getExportRank(DeclRef left, DeclRef right) - { - if (left.getDecl() && left.getDecl()->hasModifier()) - { - return (right.getDecl() && right.getDecl()->hasModifier()) ? -1 : 0; - } +int getScopeRank( + DeclRef const& left, + DeclRef const& right, + Slang::Scope* referenceSiteScope) +{ + if (!referenceSiteScope) return 0; + + DeclRef prefixDecl = referenceSiteScope->containerDecl; + + // Hold the path from reference site to the root + // key: Decl node, value: distance from reference site + Dictionary refPath; + for (auto node = prefixDecl; node != nullptr; node = node.getParent()) + { + Decl* key = node.getDecl(); + uint32_t value = (uint32_t)refPath.getCount(); + refPath.add(key, value); } - int getScopeRank(DeclRef const& left, - DeclRef const& right, Slang::Scope* referenceSiteScope) + // find the common prefix decl of reference site and left + int leftDistance = 0; + int rightDistance = 0; + auto distanceToCommonPrefix = [](DeclRef const& candidate, + Dictionary refPath) -> int { - if (!referenceSiteScope) - return 0; + uint32_t distanceToReferenceSite = 0; + uint32_t distanceToCandidate = 0; - DeclRef prefixDecl = referenceSiteScope->containerDecl; + // Sanity check + if (candidate.getDecl() == nullptr) + return -1; - // Hold the path from reference site to the root - // key: Decl node, value: distance from reference site - Dictionary refPath; - for (auto node = prefixDecl; node != nullptr; node = node.getParent()) + // search from candidate to root, once we found the first node in the reference path, that + // is the first common prefix, and we can stop searching. + for (auto node = candidate; node != nullptr; node = node.getParent()) { Decl* key = node.getDecl(); - uint32_t value = (uint32_t)refPath.getCount(); - refPath.add(key, value); - } - - // find the common prefix decl of reference site and left - int leftDistance = 0; - int rightDistance = 0; - auto distanceToCommonPrefix = [](DeclRef const& candidate, Dictionary refPath) -> int - { - uint32_t distanceToReferenceSite = 0; - uint32_t distanceToCandidate = 0; - - // Sanity check - if (candidate.getDecl() == nullptr) - return -1; - - // search from candidate to root, once we found the first node in the reference path, that is the first - // common prefix, and we can stop searching. - for (auto node = candidate; node != nullptr; node = node.getParent()) + if (refPath.tryGetValue(key, distanceToReferenceSite)) { - Decl* key = node.getDecl(); - if (refPath.tryGetValue(key, distanceToReferenceSite)) - { - break; - } - distanceToCandidate++; + break; } + distanceToCandidate++; + } - // If we don't find the common prefix, there must be something wrong, return the max value. - if (distanceToReferenceSite == 0) - return -1; + // If we don't find the common prefix, there must be something wrong, return the max value. + if (distanceToReferenceSite == 0) + return -1; - return distanceToReferenceSite + distanceToCandidate; - }; + return distanceToReferenceSite + distanceToCandidate; + }; - leftDistance = distanceToCommonPrefix(left, refPath); - rightDistance = distanceToCommonPrefix(right, refPath); + leftDistance = distanceToCommonPrefix(left, refPath); + rightDistance = distanceToCommonPrefix(right, refPath); - if (leftDistance == rightDistance) - return 0; + if (leftDistance == rightDistance) + return 0; - if (leftDistance == -1) - return 1; + if (leftDistance == -1) + return 1; - if (rightDistance == -1) - return -1; + if (rightDistance == -1) + return -1; - return leftDistance < rightDistance ? -1 : 1; - } + return leftDistance < rightDistance ? -1 : 1; +} + +int SemanticsVisitor::CompareOverloadCandidates(OverloadCandidate* left, OverloadCandidate* right) +{ + // If one candidate got further along in validation, pick it + if (left->status != right->status) + return int(right->status) - int(left->status); - int SemanticsVisitor::CompareOverloadCandidates( - OverloadCandidate* left, - OverloadCandidate* right) + // If both candidates are applicable, then we need to compare + // the costs of their type conversion sequences + if (left->status == OverloadCandidate::Status::Applicable) { - // If one candidate got further along in validation, pick it - if (left->status != right->status) - return int(right->status) - int(left->status); + // If one candidate incurred less cost related to + // implicit conversion of arguments to matching + // parameter types, then we should prefer that + // candidate. + // + // TODO: This eventually should be refined into + // a test that checks conversion cost per-argument, + // and only considers a candidate "better" if it + // has lower cost for at least one argument, and + // does not have higher cost for any. + // + if (left->conversionCostSum != right->conversionCostSum) + return left->conversionCostSum - right->conversionCostSum; - // If both candidates are applicable, then we need to compare - // the costs of their type conversion sequences - if(left->status == OverloadCandidate::Status::Applicable) + // If both candidates appear to be equally good when it + // comes to the per-argument conversions required, + // then we have two other categories of criteria we + // can look at to disambiguate things: + // + // 1. We can look at how the lookup process found `left` and `right` + // do decide which is a better match based purely on how "far away" + // they are for lookup purposes. A canonincal example here would + // be if one declaration shadows or overrides the other. + // + // 2. We can look at parameter lists of `left` and `right`, their types, etc. + // do decide which is a better match based purely on structure. + // Canonical examples in this case would be preferring a non-generic + // candidate over a generic one, preferring a non-variadic candidate + // over a variadic one, and preferring a candidate with fewer + // default parameters over one with more. + // + // Deciding how to order/interleave these two categories of criteria + // is an important design decision. + // + // For example, consider: + // + // float f(float x); + // + // struct S + // { + // int f(T x); + // + // float g(float y) { return f(y); } + // } + // + // In terms of structural/type matching, the global `f` is a more specialized + // candidate at the call site, while in terms of lookup/lexical crieteria + // the `S.f` declaration is better. + // + // For now we are considering lookup/overriding concerns first (so + // we would bias in favor of selecting `S.f` in the above example), and then + // structural/type concerns, but a more nuanced approach may be + // required in the future to better match programmer intuition. + // + auto itemDiff = CompareLookupResultItems(left->item, right->item); + if (itemDiff) + return itemDiff; + + auto specificityDiff = compareOverloadCandidateSpecificity(left->item, right->item); + if (specificityDiff) + return specificityDiff; + + // `export` function is more flavored than `extern` function. But other modifiers are not + // considered. + auto externExportDiff = getExportRank(left->item.declRef, right->item.declRef); + if (externExportDiff) + return externExportDiff; + + // We need to consider the distance of the declarations to the global scope to resolve this + // case: + // float f(float x); + // struct S + // { + // float f(float x); + // float g(float y) { return f(y); } // will call S::f() instead of ::f() + // } + // we will count the distance from the reference site to the declaration in the scope tree. + + // NOTE: We CAN'T do this for the generic function, because generic lookup is little bit + // complicated. It will go through multiple passes of candidates compare. In the first + // pass, it will lookup all the generic candidates that matches the generic parameter only, + // e.g., the following generic functions are totally different, but they will be selected + // as candidates because the function name and the generic parameters are the same: void + // func(Z0 a, Z1 b); void func(Z0 a, Z1 b, Z0 c); void func(Z0 a, Z1 b, Z0 c, Z1 + // d); + // + // So in this case, we should not consider the scope rank and overload rank at all, because + // there is only one of above candidates is valid, and the rank calculation doesn't + // consider the correctness of the candidates, so it could select the wrong candidate. + // + // In the next pass, the lookup system will match the input parameters in those candidates + // to find out the valid match, the "flavor" field will become "Func" or "Expr". So the + // rank calculation can be applied. + if (left->flavor == OverloadCandidate::Flavor::Generic || + left->flavor == OverloadCandidate::Flavor::UnspecializedGeneric || + right->flavor == OverloadCandidate::Flavor::Generic || + right->flavor == OverloadCandidate::Flavor::UnspecializedGeneric) { - // If one candidate incurred less cost related to - // implicit conversion of arguments to matching - // parameter types, then we should prefer that - // candidate. - // - // TODO: This eventually should be refined into - // a test that checks conversion cost per-argument, - // and only considers a candidate "better" if it - // has lower cost for at least one argument, and - // does not have higher cost for any. - // - if (left->conversionCostSum != right->conversionCostSum) - return left->conversionCostSum - right->conversionCostSum; - - // If both candidates appear to be equally good when it - // comes to the per-argument conversions required, - // then we have two other categories of criteria we - // can look at to disambiguate things: - // - // 1. We can look at how the lookup process found `left` and `right` - // do decide which is a better match based purely on how "far away" - // they are for lookup purposes. A canonincal example here would - // be if one declaration shadows or overrides the other. - // - // 2. We can look at parameter lists of `left` and `right`, their types, etc. - // do decide which is a better match based purely on structure. - // Canonical examples in this case would be preferring a non-generic - // candidate over a generic one, preferring a non-variadic candidate - // over a variadic one, and preferring a candidate with fewer - // default parameters over one with more. - // - // Deciding how to order/interleave these two categories of criteria - // is an important design decision. - // - // For example, consider: - // - // float f(float x); - // - // struct S - // { - // int f(T x); - // - // float g(float y) { return f(y); } - // } - // - // In terms of structural/type matching, the global `f` is a more specialized - // candidate at the call site, while in terms of lookup/lexical crieteria - // the `S.f` declaration is better. - // - // For now we are considering lookup/overriding concerns first (so - // we would bias in favor of selecting `S.f` in the above example), and then - // structural/type concerns, but a more nuanced approach may be - // required in the future to better match programmer intuition. - // - auto itemDiff = CompareLookupResultItems(left->item, right->item); - if(itemDiff) - return itemDiff; - - auto specificityDiff = compareOverloadCandidateSpecificity(left->item, right->item); - if(specificityDiff) - return specificityDiff; - - // `export` function is more flavored than `extern` function. But other modifiers are not considered. - auto externExportDiff = getExportRank(left->item.declRef, right->item.declRef); - if (externExportDiff) - return externExportDiff; - - // We need to consider the distance of the declarations to the global scope to resolve this case: - // float f(float x); - // struct S - // { - // float f(float x); - // float g(float y) { return f(y); } // will call S::f() instead of ::f() - // } - // we will count the distance from the reference site to the declaration in the scope tree. - - // NOTE: We CAN'T do this for the generic function, because generic lookup is little bit complicated. - // It will go through multiple passes of candidates compare. - // In the first pass, it will lookup all the generic candidates that matches the generic parameter only, - // e.g., the following generic functions are totally different, but they will be selected as candidates - // because the function name and the generic parameters are the same: - // void func(Z0 a, Z1 b); - // void func(Z0 a, Z1 b, Z0 c); - // void func(Z0 a, Z1 b, Z0 c, Z1 d); - // - // So in this case, we should not consider the scope rank and overload rank at all, because there is only - // one of above candidates is valid, and the rank calculation doesn't consider the correctness of the - // candidates, so it could select the wrong candidate. - // - // In the next pass, the lookup system will match the input parameters in those candidates to find out the valid - // match, the "flavor" field will become "Func" or "Expr". So the rank calculation can be applied. - if (left->flavor == OverloadCandidate::Flavor::Generic || - left->flavor == OverloadCandidate::Flavor::UnspecializedGeneric || - right->flavor == OverloadCandidate::Flavor::Generic || - right->flavor == OverloadCandidate::Flavor::UnspecializedGeneric) - { - return 0; - } - - auto scopeRank = getScopeRank(left->item.declRef, right->item.declRef, this->m_outerScope); - if (scopeRank) - return scopeRank; - - // If we reach here, we will attempt to use overload rank to break the ties. - auto overloadRankDiff = getOverloadRank(right->item.declRef) - getOverloadRank(left->item.declRef); - if (overloadRankDiff) - return overloadRankDiff; + return 0; } - return 0; + auto scopeRank = getScopeRank(left->item.declRef, right->item.declRef, this->m_outerScope); + if (scopeRank) + return scopeRank; + + // If we reach here, we will attempt to use overload rank to break the ties. + auto overloadRankDiff = + getOverloadRank(right->item.declRef) - getOverloadRank(left->item.declRef); + if (overloadRankDiff) + return overloadRankDiff; } - void SemanticsVisitor::AddOverloadCandidateInner( - OverloadResolveContext& context, - OverloadCandidate& candidate) - { - // Filter our existing candidates, to remove any that are worse than our new one + return 0; +} + +void SemanticsVisitor::AddOverloadCandidateInner( + OverloadResolveContext& context, + OverloadCandidate& candidate) +{ + // Filter our existing candidates, to remove any that are worse than our new one - bool keepThisCandidate = true; // should this candidate be kept? + bool keepThisCandidate = true; // should this candidate be kept? - if (context.bestCandidates.getCount() != 0) + if (context.bestCandidates.getCount() != 0) + { + // We have multiple candidates right now, so filter them. + // This is only used in an assert in debug builds + [[maybe_unused]] bool anyFiltered = false; + // Note that we are querying the list length on every iteration, + // because we might remove things. + for (Index cc = 0; cc < context.bestCandidates.getCount(); ++cc) { - // We have multiple candidates right now, so filter them. - // This is only used in an assert in debug builds - [[maybe_unused]] bool anyFiltered = false; - // Note that we are querying the list length on every iteration, - // because we might remove things. - for (Index cc = 0; cc < context.bestCandidates.getCount(); ++cc) + int cmp = CompareOverloadCandidates(&candidate, &context.bestCandidates[cc]); + if (cmp < 0) { - int cmp = CompareOverloadCandidates(&candidate, &context.bestCandidates[cc]); - if (cmp < 0) - { - // our new candidate is better! + // our new candidate is better! - // remove it from the list (by swapping in a later one) - context.bestCandidates.fastRemoveAt(cc); - // and then reduce our index so that we re-visit the same index - --cc; + // remove it from the list (by swapping in a later one) + context.bestCandidates.fastRemoveAt(cc); + // and then reduce our index so that we re-visit the same index + --cc; - anyFiltered = true; - } - else if(cmp > 0) - { - // our candidate is worse! - keepThisCandidate = false; - } - } - // It should not be possible that we removed some existing candidate *and* - // chose not to keep this candidate (otherwise the better-ness relation - // isn't transitive). Therefore we confirm that we either chose to keep - // this candidate (in which case filtering is okay), or we didn't filter - // anything. - SLANG_ASSERT(keepThisCandidate || !anyFiltered); - } - else if(context.bestCandidate) - { - // There's only one candidate so far - int cmp = CompareOverloadCandidates(&candidate, context.bestCandidate); - if(cmp < 0) - { - // our new candidate is better! - context.bestCandidate = nullptr; + anyFiltered = true; } else if (cmp > 0) { @@ -1704,1134 +1729,1181 @@ namespace Slang keepThisCandidate = false; } } - - // If our candidate isn't good enough, then drop it - if (!keepThisCandidate) - return; - - // Otherwise we want to keep the candidate - if (context.bestCandidates.getCount() > 0) - { - // There were already multiple candidates, and we are adding one more - context.bestCandidates.add(candidate); - } - else if (context.bestCandidate) + // It should not be possible that we removed some existing candidate *and* + // chose not to keep this candidate (otherwise the better-ness relation + // isn't transitive). Therefore we confirm that we either chose to keep + // this candidate (in which case filtering is okay), or we didn't filter + // anything. + SLANG_ASSERT(keepThisCandidate || !anyFiltered); + } + else if (context.bestCandidate) + { + // There's only one candidate so far + int cmp = CompareOverloadCandidates(&candidate, context.bestCandidate); + if (cmp < 0) { - // There was a unique best candidate, but now we are ambiguous - context.bestCandidates.add(*context.bestCandidate); - context.bestCandidates.add(candidate); + // our new candidate is better! context.bestCandidate = nullptr; } - else + else if (cmp > 0) { - // This is the only candidate worth keeping track of right now - context.bestCandidateStorage = candidate; - context.bestCandidate = &context.bestCandidateStorage; + // our candidate is worse! + keepThisCandidate = false; } } - void SemanticsVisitor::AddOverloadCandidate( - OverloadResolveContext& context, - OverloadCandidate& candidate, - ConversionCost baseCost) - { - // Try the candidate out, to see if it is applicable at all. - TryCheckOverloadCandidate(context, candidate); + // If our candidate isn't good enough, then drop it + if (!keepThisCandidate) + return; - candidate.conversionCostSum += baseCost; - - // Now (potentially) add it to the set of candidate overloads to consider. - AddOverloadCandidateInner(context, candidate); + // Otherwise we want to keep the candidate + if (context.bestCandidates.getCount() > 0) + { + // There were already multiple candidates, and we are adding one more + context.bestCandidates.add(candidate); } - - void SemanticsVisitor::AddFuncOverloadCandidate( - LookupResultItem item, - DeclRef funcDeclRef, - OverloadResolveContext& context, - ConversionCost baseCost) + else if (context.bestCandidate) + { + // There was a unique best candidate, but now we are ambiguous + context.bestCandidates.add(*context.bestCandidate); + context.bestCandidates.add(candidate); + context.bestCandidate = nullptr; + } + else { - auto funcDecl = funcDeclRef.getDecl(); - ensureDecl(funcDecl, DeclCheckState::CanUseFuncSignature); + // This is the only candidate worth keeping track of right now + context.bestCandidateStorage = candidate; + context.bestCandidate = &context.bestCandidateStorage; + } +} - // If this function is a redeclaration, - // then we don't want to include it multiple times, - // and mistakenly think we have an ambiguous call. - // - // Instead, we will carefully consider only the - // "primary" declaration of any callable. - if (auto primaryDecl = funcDecl->primaryDecl) - { - if (funcDecl != primaryDecl) - { - // This is a redeclaration, so we don't - // want to consider it. The primary - // declaration should also get considered - // for the call site and it will match - // anything this declaration would have - // matched. - return; - } - } +void SemanticsVisitor::AddOverloadCandidate( + OverloadResolveContext& context, + OverloadCandidate& candidate, + ConversionCost baseCost) +{ + // Try the candidate out, to see if it is applicable at all. + TryCheckOverloadCandidate(context, candidate); - OverloadCandidate candidate; - candidate.flavor = OverloadCandidate::Flavor::Func; - candidate.item = item; - candidate.resultType = getResultType(m_astBuilder, funcDeclRef); + candidate.conversionCostSum += baseCost; - AddOverloadCandidate(context, candidate, baseCost); - } + // Now (potentially) add it to the set of candidate overloads to consider. + AddOverloadCandidateInner(context, candidate); +} - void SemanticsVisitor::AddFuncOverloadCandidate( - FuncType* funcType, - OverloadResolveContext& context, - ConversionCost baseCost) - { - OverloadCandidate candidate; - candidate.flavor = OverloadCandidate::Flavor::Expr; - candidate.funcType = funcType; - candidate.resultType = funcType->getResultType(); +void SemanticsVisitor::AddFuncOverloadCandidate( + LookupResultItem item, + DeclRef funcDeclRef, + OverloadResolveContext& context, + ConversionCost baseCost) +{ + auto funcDecl = funcDeclRef.getDecl(); + ensureDecl(funcDecl, DeclCheckState::CanUseFuncSignature); - AddOverloadCandidate(context, candidate, baseCost); + // If this function is a redeclaration, + // then we don't want to include it multiple times, + // and mistakenly think we have an ambiguous call. + // + // Instead, we will carefully consider only the + // "primary" declaration of any callable. + if (auto primaryDecl = funcDecl->primaryDecl) + { + if (funcDecl != primaryDecl) + { + // This is a redeclaration, so we don't + // want to consider it. The primary + // declaration should also get considered + // for the call site and it will match + // anything this declaration would have + // matched. + return; + } } - void SemanticsVisitor::AddFuncExprOverloadCandidate( - FuncType* funcType, - OverloadResolveContext& context, - Expr* expr, - ConversionCost baseCost) - { - SLANG_ASSERT(expr); - OverloadCandidate candidate; - candidate.flavor = OverloadCandidate::Flavor::Expr; - candidate.funcType = funcType; - candidate.resultType = funcType->getResultType(); - candidate.exprVal = expr; + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Func; + candidate.item = item; + candidate.resultType = getResultType(m_astBuilder, funcDeclRef); - AddOverloadCandidate(context, candidate, baseCost); - } + AddOverloadCandidate(context, candidate, baseCost); +} - void SemanticsVisitor::AddCtorOverloadCandidate( - LookupResultItem typeItem, - Type* type, - DeclRef ctorDeclRef, - OverloadResolveContext& context, - Type* resultType, - ConversionCost baseCost) - { - SLANG_UNUSED(type) +void SemanticsVisitor::AddFuncOverloadCandidate( + FuncType* funcType, + OverloadResolveContext& context, + ConversionCost baseCost) +{ + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Expr; + candidate.funcType = funcType; + candidate.resultType = funcType->getResultType(); + + AddOverloadCandidate(context, candidate, baseCost); +} + +void SemanticsVisitor::AddFuncExprOverloadCandidate( + FuncType* funcType, + OverloadResolveContext& context, + Expr* expr, + ConversionCost baseCost) +{ + SLANG_ASSERT(expr); + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Expr; + candidate.funcType = funcType; + candidate.resultType = funcType->getResultType(); + candidate.exprVal = expr; + + AddOverloadCandidate(context, candidate, baseCost); +} - ensureDecl(ctorDeclRef, DeclCheckState::CanUseFuncSignature); +void SemanticsVisitor::AddCtorOverloadCandidate( + LookupResultItem typeItem, + Type* type, + DeclRef ctorDeclRef, + OverloadResolveContext& context, + Type* resultType, + ConversionCost baseCost) +{ + SLANG_UNUSED(type) - // `typeItem` refers to the type being constructed (the thing - // that was applied as a function) so we need to construct - // a `LookupResultItem` that refers to the constructor instead + ensureDecl(ctorDeclRef, DeclCheckState::CanUseFuncSignature); - LookupResultItem ctorItem; - ctorItem.declRef = ctorDeclRef; - ctorItem.breadcrumbs = new LookupResultItem::Breadcrumb( - LookupResultItem::Breadcrumb::Kind::Member, - typeItem.declRef, - nullptr, - typeItem.breadcrumbs); + // `typeItem` refers to the type being constructed (the thing + // that was applied as a function) so we need to construct + // a `LookupResultItem` that refers to the constructor instead - OverloadCandidate candidate; - candidate.flavor = OverloadCandidate::Flavor::Func; - candidate.item = ctorItem; - candidate.resultType = resultType; + LookupResultItem ctorItem; + ctorItem.declRef = ctorDeclRef; + ctorItem.breadcrumbs = new LookupResultItem::Breadcrumb( + LookupResultItem::Breadcrumb::Kind::Member, + typeItem.declRef, + nullptr, + typeItem.breadcrumbs); - AddOverloadCandidate(context, candidate, baseCost); - } + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Func; + candidate.item = ctorItem; + candidate.resultType = resultType; - bool SemanticsVisitor::OverloadResolveContext::matchArgumentsToParams( - SemanticsVisitor* semantics, - const List& params, - bool computeTypes, - ShortList& outMatchedArgs) + AddOverloadCandidate(context, candidate, baseCost); +} + +bool SemanticsVisitor::OverloadResolveContext::matchArgumentsToParams( + SemanticsVisitor* semantics, + const List& params, + bool computeTypes, + ShortList& outMatchedArgs) +{ + // We allow params to end with one or more variadic packs. + // We will first find out how many type packs there are. + Index typePackCount = 0; + for (Index i = params.getCount() - 1; i >= 0; --i) { - // We allow params to end with one or more variadic packs. - // We will first find out how many type packs there are. - Index typePackCount = 0; - for (Index i = params.getCount() - 1; i >= 0; --i) - { - if (isTypePack(params[i].type)) - typePackCount++; - else - break; - } - auto fixedParamCount = params.getCount() - typePackCount; + if (isTypePack(params[i].type)) + typePackCount++; + else + break; + } + auto fixedParamCount = params.getCount() - typePackCount; - auto remainingArgCount = getArgCount() - fixedParamCount; + auto remainingArgCount = getArgCount() - fixedParamCount; - // If there are remaining arguments after matching all fixed parameters, - // we'd better have at least one type pack. - if (remainingArgCount > 0 && typePackCount == 0) - return false; + // If there are remaining arguments after matching all fixed parameters, + // we'd better have at least one type pack. + if (remainingArgCount > 0 && typePackCount == 0) + return false; - // Now we can match the arguments to the parameters. + // Now we can match the arguments to the parameters. - // The fixed part comes first. - for (Index i = 0; i < Math::Min(getArgCount(), fixedParamCount); ++i) - { - MatchedArg arg; - arg.argExpr = getArg(i); - arg.argType = getArgType(i); - outMatchedArgs.add(arg); - } + // The fixed part comes first. + for (Index i = 0; i < Math::Min(getArgCount(), fixedParamCount); ++i) + { + MatchedArg arg; + arg.argExpr = getArg(i); + arg.argType = getArgType(i); + outMatchedArgs.add(arg); + } - // Try to match the variadic part. - // Is the corresponding argument a expand expr? If so it will map 1:1 to the type pack param. - auto astBuilder = semantics->getASTBuilder(); + // Try to match the variadic part. + // Is the corresponding argument a expand expr? If so it will map 1:1 to the type pack param. + auto astBuilder = semantics->getASTBuilder(); - if (remainingArgCount <= 0) - return true; - if (typePackCount == 0) - return false; + if (remainingArgCount <= 0) + return true; + if (typePackCount == 0) + return false; - // If the number of type packs can't evenly divide the remaining arguments, - // there isn't a match. - if (remainingArgCount % typePackCount != 0) - return false; + // If the number of type packs can't evenly divide the remaining arguments, + // there isn't a match. + if (remainingArgCount % typePackCount != 0) + return false; - // The default case is to group the remaining arguments into evenly divided PackExprs. - Index typePackSize = remainingArgCount / typePackCount; - for (Index i = 0; i < typePackCount; ++i) + // The default case is to group the remaining arguments into evenly divided PackExprs. + Index typePackSize = remainingArgCount / typePackCount; + for (Index i = 0; i < typePackCount; ++i) + { + // If type pack size is 1, we may not need to wrap things in a PackExpr, + // if the argument is already a pack. + if (typePackSize == 1) { - // If type pack size is 1, we may not need to wrap things in a PackExpr, - // if the argument is already a pack. - if (typePackSize == 1) + auto argType = getArgType(fixedParamCount + i); + if (auto typeType = as(argType)) { - auto argType = getArgType(fixedParamCount + i); - if (auto typeType = as(argType)) - { - argType = typeType->getType(); - } - if (isTypePack(argType)) - { - MatchedArg arg; - arg.argExpr = getArg(fixedParamCount + i); - arg.argType = getArgType(fixedParamCount + i); - outMatchedArgs.add(arg); - continue; - } + argType = typeType->getType(); } - PackExpr* packExpr = nullptr; - if (mode == Mode::ForReal) + if (isTypePack(argType)) { - packExpr = astBuilder->create(); - packExpr->loc = loc; + MatchedArg arg; + arg.argExpr = getArg(fixedParamCount + i); + arg.argType = getArgType(fixedParamCount + i); + outMatchedArgs.add(arg); + continue; } - ShortList types; - for (Index j = 0; j < typePackSize; ++j) + } + PackExpr* packExpr = nullptr; + if (mode == Mode::ForReal) + { + packExpr = astBuilder->create(); + packExpr->loc = loc; + } + ShortList types; + for (Index j = 0; j < typePackSize; ++j) + { + if (packExpr) { - if (packExpr) - { - auto arg = getArg(fixedParamCount + i * typePackSize + j); - packExpr->args.add(arg); - } - if (computeTypes) - types.add(getArgTypeForInference(fixedParamCount + i * typePackSize + j, semantics)); + auto arg = getArg(fixedParamCount + i * typePackSize + j); + packExpr->args.add(arg); } - MatchedArg matchedArg; - matchedArg.argExpr = packExpr; if (computeTypes) - { - matchedArg.argType = astBuilder->getTypePack(types.getArrayView().arrayView); - if (packExpr) - packExpr->type = matchedArg.argType; - } - outMatchedArgs.add(matchedArg); + types.add( + getArgTypeForInference(fixedParamCount + i * typePackSize + j, semantics)); } - return true; + MatchedArg matchedArg; + matchedArg.argExpr = packExpr; + if (computeTypes) + { + matchedArg.argType = astBuilder->getTypePack(types.getArrayView().arrayView); + if (packExpr) + packExpr->type = matchedArg.argType; + } + outMatchedArgs.add(matchedArg); } + return true; +} - DeclRef SemanticsVisitor::inferGenericArguments( - DeclRef genericDeclRef, - OverloadResolveContext& context, - ArrayView knownGenericArgs, - ConversionCost& outBaseCost, - List *innerParameterTypes) - { - // We have been asked to infer zero or more arguments to - // `genericDeclRef`, in a context where it is being applied - // to value-level arguments in `context`. - // - // It is possible that the call site included one or more - // explicit arguments, in which case `substWithKnownGenericArgs` - // will have been filled in and contain those. Otherwise, - // that parameter will be null, and we are expected to - // infer all arguments. - - // The declaration of the generic must be checked up to a point - // where we can attempt to form specializations of it (which in - // practice means that the declarations of its parameters and - // their constraints must have been checked). - // - ensureDecl(genericDeclRef, DeclCheckState::CanSpecializeGeneric); +DeclRef SemanticsVisitor::inferGenericArguments( + DeclRef genericDeclRef, + OverloadResolveContext& context, + ArrayView knownGenericArgs, + ConversionCost& outBaseCost, + List* innerParameterTypes) +{ + // We have been asked to infer zero or more arguments to + // `genericDeclRef`, in a context where it is being applied + // to value-level arguments in `context`. + // + // It is possible that the call site included one or more + // explicit arguments, in which case `substWithKnownGenericArgs` + // will have been filled in and contain those. Otherwise, + // that parameter will be null, and we are expected to + // infer all arguments. + + // The declaration of the generic must be checked up to a point + // where we can attempt to form specializations of it (which in + // practice means that the declarations of its parameters and + // their constraints must have been checked). + // + ensureDecl(genericDeclRef, DeclCheckState::CanSpecializeGeneric); - // Conceptually, we are going to be trying to infer any unspecified - // generic arguments by forming a system of constraints on those arguments - // and then attempting to solve the constraint system. - // - // While the constraint solver we have implemented today is not especially - // clever, we follow a flow that should in principle allow us to plug in - // something more clever down the line. - // - ConstraintSystem constraints; - constraints.loc = context.loc; - constraints.genericDecl = genericDeclRef.getDecl(); - - // In order to perform matching between the types passed in at the - // call site represented by `context` and the parameters of the - // declaraiton being applied, we want to form a reference to - // the "inner" declaration of the generic (e.g., the `FuncitonDecl` - // under the `GenericDecl`). - // - // Check what type of declaration we are dealing with, and then try - // to match it up with the arguments accordingly... + // Conceptually, we are going to be trying to infer any unspecified + // generic arguments by forming a system of constraints on those arguments + // and then attempting to solve the constraint system. + // + // While the constraint solver we have implemented today is not especially + // clever, we follow a flow that should in principle allow us to plug in + // something more clever down the line. + // + ConstraintSystem constraints; + constraints.loc = context.loc; + constraints.genericDecl = genericDeclRef.getDecl(); + + // In order to perform matching between the types passed in at the + // call site represented by `context` and the parameters of the + // declaraiton being applied, we want to form a reference to + // the "inner" declaration of the generic (e.g., the `FuncitonDecl` + // under the `GenericDecl`). + // + // Check what type of declaration we are dealing with, and then try + // to match it up with the arguments accordingly... - if (auto funcDeclRef = as(genericDeclRef.getDecl()->inner)) + if (auto funcDeclRef = as(genericDeclRef.getDecl()->inner)) + { + List paramTypes; + if (!innerParameterTypes) { - List paramTypes; - if (!innerParameterTypes) + auto params = getParameters(m_astBuilder, funcDeclRef).toArray(); + for (auto param : params) { - auto params = getParameters(m_astBuilder, funcDeclRef).toArray(); - for (auto param : params) - { - paramTypes.add(getParamQualType(m_astBuilder, param)); - } - innerParameterTypes = ¶mTypes; + paramTypes.add(getParamQualType(m_astBuilder, param)); } + innerParameterTypes = ¶mTypes; + } - ShortList matchedArgs; + ShortList matchedArgs; - // We now try to match arguments to parameters. + // We now try to match arguments to parameters. + // + // Note that if there are *too few* arguments, we might still have + // a match, because the other arguments might have default values + // that can be used. + // + if (!context.matchArgumentsToParams(this, *innerParameterTypes, true, matchedArgs)) + { + return DeclRef(); + } + + // Perform type unification between arguments and parameters, so + // we can populate the resolve system with inital constraints. + // + for (Index aa = 0; aa < matchedArgs.getCount(); ++aa) + { + // The question here is whether failure to "unify" an argument + // and parameter should lead to immediate failure. // - // Note that if there are *too few* arguments, we might still have - // a match, because the other arguments might have default values - // that can be used. + // The case that is interesting is if we want to unify, say: + // `vector` and `vector` // - if (!context.matchArgumentsToParams(this, *innerParameterTypes, true, matchedArgs)) - { - return DeclRef(); - } - - // Perform type unification between arguments and parameters, so - // we can populate the resolve system with inital constraints. + // It is clear that we should solve with `N = 3`, and then + // a later step may find that the resulting types aren't + // actually a match. + // + // A more refined approach to "unification" could of course + // see that `int` can convert to `float` and use that fact. + // (and indeed we already use something like this to unify + // `float` and `vector`) // - for (Index aa = 0; aa < matchedArgs.getCount(); ++aa) + // So the question is then whether a mismatch during the + // unification step should be taken as an immediate failure... + auto argType = matchedArgs[aa].argType; + auto paramType = (*innerParameterTypes)[aa]; + auto canUnify = TryUnifyTypes( + constraints, + ValUnificationContext(), + QualType(argType, paramType.isLeftValue), + paramType); + + // It is an error if we can't unify the argument with a type pack parameter. + if (!canUnify && isTypePack(paramType)) { - // The question here is whether failure to "unify" an argument - // and parameter should lead to immediate failure. - // - // The case that is interesting is if we want to unify, say: - // `vector` and `vector` - // - // It is clear that we should solve with `N = 3`, and then - // a later step may find that the resulting types aren't - // actually a match. - // - // A more refined approach to "unification" could of course - // see that `int` can convert to `float` and use that fact. - // (and indeed we already use something like this to unify - // `float` and `vector`) - // - // So the question is then whether a mismatch during the - // unification step should be taken as an immediate failure... - auto argType = matchedArgs[aa].argType; - auto paramType = (*innerParameterTypes)[aa]; - auto canUnify = TryUnifyTypes( - constraints, - ValUnificationContext(), - QualType(argType, paramType.isLeftValue), - paramType); - - // It is an error if we can't unify the argument with a type pack parameter. - if (!canUnify && isTypePack(paramType)) - { - return DeclRef(); - } + return DeclRef(); } } - else - { - // TODO(tfoley): any other cases needed here? - return DeclRef(); - } - - // Once we have added all the appropriate constraints to the system, we - // will try to solve for a set of arguments to the generic that satisfy - // those constraints. - // - // Note that this step *also* attempts to infer arguments for all the - // implicit parameters of a generic. Notably, this means inferring - // witnesses for interface conformance constraints. - // - // TODO(tfoley): We probably need to pass along the explicit arguments here, - // so that the solver knows to accept those arguments as-is. - // - return trySolveConstraintSystem( - &constraints, genericDeclRef, knownGenericArgs, outBaseCost); } - - void SemanticsVisitor::AddTypeOverloadCandidates( - Type* type, - OverloadResolveContext& context) + else { - // The code being checked is trying to apply `type` like a function. - // Semantically, the operations `T(args...)` is equivalent to - // `T.__init(args...)` if we had a surface syntax that supported - // looking up `__init` declarations by that name. - // - // Internally, all `__init` declarations are stored with the name - // `$init`, to avoid potential conflicts if a user decided to name - // a field/method `__init`. - // - // We will look up all the initializers on `type` by looking up - // its members named `$init`, and then proceed to perform overload - // resolution with what we find. - // - // TODO: One wrinkle here is single-argument constructor syntax. - // An operation like `(T) oneArg` or `T(oneArg)` is currently - // treated as a call expression, but we might want such cases - // to go through the type coercion logic first/instead, because - // by doing so we could weed out cases where a type is "constructed" - // from a value of the same type. There is no need in Slang for - // "copy constructors" but the core module currently has to define - // some just to make code that does, e.g., `float(1.0f)` work.) + // TODO(tfoley): any other cases needed here? + return DeclRef(); + } - LookupResult initializers = lookUpMember( - m_astBuilder, - this, - getName("$init"), - type, - context.sourceScope, - LookupMask::Default, - LookupOptions::NoDeref); + // Once we have added all the appropriate constraints to the system, we + // will try to solve for a set of arguments to the generic that satisfy + // those constraints. + // + // Note that this step *also* attempts to infer arguments for all the + // implicit parameters of a generic. Notably, this means inferring + // witnesses for interface conformance constraints. + // + // TODO(tfoley): We probably need to pass along the explicit arguments here, + // so that the solver knows to accept those arguments as-is. + // + return trySolveConstraintSystem(&constraints, genericDeclRef, knownGenericArgs, outBaseCost); +} - AddOverloadCandidates(initializers, context); - } +void SemanticsVisitor::AddTypeOverloadCandidates(Type* type, OverloadResolveContext& context) +{ + // The code being checked is trying to apply `type` like a function. + // Semantically, the operations `T(args...)` is equivalent to + // `T.__init(args...)` if we had a surface syntax that supported + // looking up `__init` declarations by that name. + // + // Internally, all `__init` declarations are stored with the name + // `$init`, to avoid potential conflicts if a user decided to name + // a field/method `__init`. + // + // We will look up all the initializers on `type` by looking up + // its members named `$init`, and then proceed to perform overload + // resolution with what we find. + // + // TODO: One wrinkle here is single-argument constructor syntax. + // An operation like `(T) oneArg` or `T(oneArg)` is currently + // treated as a call expression, but we might want such cases + // to go through the type coercion logic first/instead, because + // by doing so we could weed out cases where a type is "constructed" + // from a value of the same type. There is no need in Slang for + // "copy constructors" but the core module currently has to define + // some just to make code that does, e.g., `float(1.0f)` work.) + + LookupResult initializers = lookUpMember( + m_astBuilder, + this, + getName("$init"), + type, + context.sourceScope, + LookupMask::Default, + LookupOptions::NoDeref); + + AddOverloadCandidates(initializers, context); +} - void SemanticsVisitor::addOverloadCandidatesForCallToGeneric( - LookupResultItem genericItem, - OverloadResolveContext& context, - ArrayView knownGenericArgs) - { - auto genericDeclRef = genericItem.declRef.as(); - SLANG_ASSERT(genericDeclRef); +void SemanticsVisitor::addOverloadCandidatesForCallToGeneric( + LookupResultItem genericItem, + OverloadResolveContext& context, + ArrayView knownGenericArgs) +{ + auto genericDeclRef = genericItem.declRef.as(); + SLANG_ASSERT(genericDeclRef); - ConversionCost baseCost = kConversionCost_None; + ConversionCost baseCost = kConversionCost_None; - // Try to infer generic arguments, based on the context - DeclRef innerRef = inferGenericArguments(genericDeclRef, context, knownGenericArgs, baseCost); + // Try to infer generic arguments, based on the context + DeclRef innerRef = + inferGenericArguments(genericDeclRef, context, knownGenericArgs, baseCost); - if (innerRef) - { - // If inference works, then we've now got a - // specialized declaration reference we can apply. + if (innerRef) + { + // If inference works, then we've now got a + // specialized declaration reference we can apply. - LookupResultItem innerItem; - innerItem.breadcrumbs = genericItem.breadcrumbs; - innerItem.declRef = innerRef; - AddDeclRefOverloadCandidates(innerItem, context, baseCost); - } - else - { - // If inference failed, then we need to create - // a candidate that can be used to reflect that fact - // (so we can report a good error) - OverloadCandidate candidate; - candidate.item = genericItem; - candidate.flavor = OverloadCandidate::Flavor::UnspecializedGeneric; - candidate.status = OverloadCandidate::Status::GenericArgumentInferenceFailed; + LookupResultItem innerItem; + innerItem.breadcrumbs = genericItem.breadcrumbs; + innerItem.declRef = innerRef; + AddDeclRefOverloadCandidates(innerItem, context, baseCost); + } + else + { + // If inference failed, then we need to create + // a candidate that can be used to reflect that fact + // (so we can report a good error) + OverloadCandidate candidate; + candidate.item = genericItem; + candidate.flavor = OverloadCandidate::Flavor::UnspecializedGeneric; + candidate.status = OverloadCandidate::Status::GenericArgumentInferenceFailed; - AddOverloadCandidateInner(context, candidate); - } + AddOverloadCandidateInner(context, candidate); } +} - void SemanticsVisitor::AddDeclRefOverloadCandidates( - LookupResultItem item, - OverloadResolveContext& context, - ConversionCost baseCost) +void SemanticsVisitor::AddDeclRefOverloadCandidates( + LookupResultItem item, + OverloadResolveContext& context, + ConversionCost baseCost) +{ + if (auto funcDeclRef = item.declRef.as()) { - if (auto funcDeclRef = item.declRef.as()) - { - AddFuncOverloadCandidate(item, funcDeclRef, context, baseCost); - } - else if (auto aggTypeDeclRef = item.declRef.as()) - { - auto type = DeclRefType::create(m_astBuilder, aggTypeDeclRef); - AddTypeOverloadCandidates(type, context); - } - else if (auto genericDeclRef = item.declRef.as()) - { - LookupResultItem innerItem; - innerItem.breadcrumbs = item.breadcrumbs; - innerItem.declRef = genericDeclRef; - addOverloadCandidatesForCallToGeneric(innerItem, context, ArrayView()); - } - else if( auto typeDefDeclRef = item.declRef.as() ) - { - auto type = getNamedType(m_astBuilder, typeDefDeclRef); - AddTypeOverloadCandidates(type, context); - } - else if( auto genericTypeParamDeclRef = item.declRef.as() ) - { - auto type = DeclRefType::create(m_astBuilder, genericTypeParamDeclRef); - AddTypeOverloadCandidates(type, context); - } - else if( auto localDeclRef = item.declRef.as() ) + AddFuncOverloadCandidate(item, funcDeclRef, context, baseCost); + } + else if (auto aggTypeDeclRef = item.declRef.as()) + { + auto type = DeclRefType::create(m_astBuilder, aggTypeDeclRef); + AddTypeOverloadCandidates(type, context); + } + else if (auto genericDeclRef = item.declRef.as()) + { + LookupResultItem innerItem; + innerItem.breadcrumbs = item.breadcrumbs; + innerItem.declRef = genericDeclRef; + addOverloadCandidatesForCallToGeneric(innerItem, context, ArrayView()); + } + else if (auto typeDefDeclRef = item.declRef.as()) + { + auto type = getNamedType(m_astBuilder, typeDefDeclRef); + AddTypeOverloadCandidates(type, context); + } + else if (auto genericTypeParamDeclRef = item.declRef.as()) + { + auto type = DeclRefType::create(m_astBuilder, genericTypeParamDeclRef); + AddTypeOverloadCandidates(type, context); + } + else if (auto localDeclRef = item.declRef.as()) + { + // We could probably be broader than just parameters here + // eventually. + // Limit it for now though to make the specialization easier + // TODO: why can't this use DeclCheckState::CanUseFuncSignature + ensureDecl(localDeclRef, DeclCheckState::TypesFullyResolved); + const auto type = localDeclRef.getDecl()->getType(); + // We can only add overload candidates if this is known to be a function + if (const auto funType = as(type)) + AddFuncExprOverloadCandidate( + funType, + context, + context.originalExpr->functionExpr, + baseCost); + else + return; + } + else + { + // TODO(tfoley): any other cases needed here? + return; + } +} + +void SemanticsVisitor::AddOverloadCandidates( + LookupResult const& result, + OverloadResolveContext& context) +{ + if (result.isOverloaded()) + { + for (auto item : result.items) { - // We could probably be broader than just parameters here - // eventually. - // Limit it for now though to make the specialization easier - // TODO: why can't this use DeclCheckState::CanUseFuncSignature - ensureDecl(localDeclRef, DeclCheckState::TypesFullyResolved); - const auto type = localDeclRef.getDecl()->getType(); - // We can only add overload candidates if this is known to be a function - if(const auto funType = as(type)) - AddFuncExprOverloadCandidate(funType, context, context.originalExpr->functionExpr, baseCost); - else - return; + AddDeclRefOverloadCandidates(item, context, kConversionCost_None); } - else + } + else + { + AddDeclRefOverloadCandidates(result.item, context, kConversionCost_None); + } +} + +void SemanticsVisitor::AddOverloadCandidates(Expr* funcExpr, OverloadResolveContext& context) +{ + // A call of the form `()()` should be + // resolved as if the user wrote `()`, + // so that we avoid introducing intermediate expressions + // of function type in cases where they are not needed. + // + while (auto parenExpr = as(funcExpr)) + { + funcExpr = parenExpr->base; + } + + auto funcExprType = funcExpr->type; + + if (auto declRefExpr = as(funcExpr)) + { + // The expression directly referenced a declaration, + // so we can use that declaration directly to look + // for anything applicable. + AddDeclRefOverloadCandidates( + LookupResultItem(declRefExpr->declRef), + context, + kConversionCost_None); + } + else if (auto higherOrderExpr = as(funcExpr)) + { + // The expression is the result of a higher order function application. + AddHigherOrderOverloadCandidates(higherOrderExpr, context, kConversionCost_None); + } + else if (auto funcType = as(funcExprType)) + { + // TODO(tfoley): deprecate this path... + AddFuncOverloadCandidate(funcType, context, kConversionCost_None); + } + else if (auto overloadedExpr = as(funcExpr)) + { + AddOverloadCandidates(overloadedExpr->lookupResult2, context); + } + else if (auto overloadedExpr2 = as(funcExpr)) + { + for (auto item : overloadedExpr2->candidiateExprs) { - // TODO(tfoley): any other cases needed here? - return; + AddOverloadCandidates(item, context); } } + else if (auto partiallyAppliedGenericExpr = as(funcExpr)) + { + // A partially-applied generic is allowed as an overload candidate, + // and carries along an (incomplete) substitution that can be used + // to carry the arguments known so far. + // + addOverloadCandidatesForCallToGeneric( + LookupResultItem(partiallyAppliedGenericExpr->baseGenericDeclRef), + context, + partiallyAppliedGenericExpr->knownGenericArgs.getArrayView()); + } + else if (auto typeType = as(funcExprType)) + { + // If none of the above cases matched, but we are + // looking at a type, then I suppose we have + // a constructor call on our hands. + // + // TODO(tfoley): are there any meaningful types left + // that aren't declaration references? + auto type = typeType->getType(); + AddTypeOverloadCandidates(type, context); + return; + } +} - void SemanticsVisitor::AddOverloadCandidates( - LookupResult const& result, - OverloadResolveContext& context) +void SemanticsVisitor::AddHigherOrderOverloadCandidates( + Expr* funcExpr, + OverloadResolveContext& context, + ConversionCost baseCost) +{ + // Lookup the higher order function and process types accordingly. In the future, + // if there are enough varieties, we can have dispatch logic instead of an + // if-else ladder. + if (auto expr = as(funcExpr)) { - if(result.isOverloaded()) + auto funcDeclRefExpr = + as(getInnerMostExprFromHigherOrderExpr(expr->baseFunction)); + if (!funcDeclRefExpr) + return; + if (auto baseFuncDeclRef = funcDeclRefExpr->declRef.as()) { - for(auto item : result.items) + // Base is a normal or fully specialized generic function. + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Expr; + if (auto diffExpr = as(expr)) { - AddDeclRefOverloadCandidates(item, context, kConversionCost_None); + candidate.funcType = as(diffExpr->type.type); } + candidate.resultType = candidate.funcType->getResultType(); + candidate.item = LookupResultItem(baseFuncDeclRef); + candidate.exprVal = expr; + AddOverloadCandidate(context, candidate, baseCost); } - else + else if (auto baseFuncGenericDeclRef = funcDeclRefExpr->declRef.as()) { - AddDeclRefOverloadCandidates(result.item, context, kConversionCost_None); - } - } + // Process func type to generate JVP func type. + auto diffFuncType = as(expr->type.type); + SLANG_ASSERT(diffFuncType); - void SemanticsVisitor::AddOverloadCandidates( - Expr* funcExpr, - OverloadResolveContext& context) - { - // A call of the form `()()` should be - // resolved as if the user wrote `()`, - // so that we avoid introducing intermediate expressions - // of function type in cases where they are not needed. - // - while(auto parenExpr = as(funcExpr)) - { - funcExpr = parenExpr->base; - } + // Extract parameter list from processed type. + List paramTypes; - auto funcExprType = funcExpr->type; + for (Index ii = 0; ii < diffFuncType->getParamCount(); ii++) + paramTypes.add(getParamQualType(diffFuncType->getParamType(ii))); - if (auto declRefExpr = as(funcExpr)) - { - // The expression directly referenced a declaration, - // so we can use that declaration directly to look - // for anything applicable. - AddDeclRefOverloadCandidates(LookupResultItem(declRefExpr->declRef), context, kConversionCost_None); - } - else if (auto higherOrderExpr = as(funcExpr)) - { - // The expression is the result of a higher order function application. - AddHigherOrderOverloadCandidates(higherOrderExpr, context, kConversionCost_None); - } - else if (auto funcType = as(funcExprType)) - { - // TODO(tfoley): deprecate this path... - AddFuncOverloadCandidate(funcType, context, kConversionCost_None); - } - else if (auto overloadedExpr = as(funcExpr)) - { - AddOverloadCandidates(overloadedExpr->lookupResult2, context); - } - else if (auto overloadedExpr2 = as(funcExpr)) - { - for (auto item : overloadedExpr2->candidiateExprs) + // Try to infer generic arguments, based on the updated context. + OverloadResolveContext subContext = context; + ConversionCost baseCost1 = kConversionCost_None; + DeclRef innerRef = inferGenericArguments( + baseFuncGenericDeclRef, + context, + ArrayView(), + baseCost1, + ¶mTypes); + + if (!innerRef) + return; + + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Expr; + if (innerRef) { - AddOverloadCandidates(item, context); + diffFuncType = as(innerRef.substitute(m_astBuilder, diffFuncType)); + candidate.item = LookupResultItem(innerRef); } + else + { + candidate.item = LookupResultItem(funcDeclRefExpr->declRef); + } + candidate.funcType = as(diffFuncType); + candidate.resultType = candidate.funcType->getResultType(); + + // Substitute all types in the high-order expression chain. + Expr* inner = expr; + HigherOrderInvokeExpr* lastInner = nullptr; + while (auto hoInner = as(inner)) + { + lastInner = hoInner; + if (innerRef) + hoInner->type = innerRef.substitute(m_astBuilder, hoInner->type.type); + inner = hoInner->baseFunction; + } + // Set inner expression to resolved declref expr. + if (lastInner) + { + auto baseExpr = GetBaseExpr(funcDeclRefExpr); + lastInner->baseFunction = ConstructLookupResultExpr( + candidate.item, + baseExpr, + funcDeclRefExpr->name, + funcDeclRefExpr->loc, + funcDeclRefExpr); + } + candidate.exprVal = expr; + expr->type.type = diffFuncType; + AddOverloadCandidate(context, candidate, baseCost + baseCost1); } - else if (auto partiallyAppliedGenericExpr = as(funcExpr)) - { - // A partially-applied generic is allowed as an overload candidate, - // and carries along an (incomplete) substitution that can be used - // to carry the arguments known so far. - // - addOverloadCandidatesForCallToGeneric( - LookupResultItem(partiallyAppliedGenericExpr->baseGenericDeclRef), - context, - partiallyAppliedGenericExpr->knownGenericArgs.getArrayView()); - } - else if (auto typeType = as(funcExprType)) + else { - // If none of the above cases matched, but we are - // looking at a type, then I suppose we have - // a constructor call on our hands. - // - // TODO(tfoley): are there any meaningful types left - // that aren't declaration references? - auto type = typeType->getType(); - AddTypeOverloadCandidates(type, context); - return; + // Unhandled case for the inner expr. + getSink()->diagnose(funcExpr->loc, Diagnostics::expectedFunction, funcExpr->type); + funcExpr->type = this->getASTBuilder()->getErrorType(); } } +} + +String SemanticsVisitor::getCallSignatureString(OverloadResolveContext& context) +{ + StringBuilder argsListBuilder; + argsListBuilder << "("; + + UInt argCount = context.getArgCount(); + for (UInt aa = 0; aa < argCount; ++aa) + { + if (aa != 0) + argsListBuilder << ", "; + auto argType = context.getArgType(aa); + if (argType) + context.getArgType(aa)->toText(argsListBuilder); + else + argsListBuilder << "error"; + } + argsListBuilder << ")"; + return argsListBuilder.produceString(); +} - void SemanticsVisitor::AddHigherOrderOverloadCandidates( - Expr* funcExpr, - OverloadResolveContext& context, - ConversionCost baseCost) +Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr* expr) +{ + OverloadResolveContext context; + // check if this is a core module operator call, if so we want to use cached results + // to speed up compilation + bool shouldAddToCache = false; + OperatorOverloadCacheKey key; + TypeCheckingCache* typeCheckingCache = getLinkage()->getTypeCheckingCache(); + if (auto opExpr = as(expr)) { - // Lookup the higher order function and process types accordingly. In the future, - // if there are enough varieties, we can have dispatch logic instead of an - // if-else ladder. - if (auto expr = as(funcExpr)) + if (key.fromOperatorExpr(opExpr)) { - auto funcDeclRefExpr = as(getInnerMostExprFromHigherOrderExpr(expr->baseFunction)); - if (!funcDeclRefExpr) - return; - if (auto baseFuncDeclRef = funcDeclRefExpr->declRef.as()) - { - // Base is a normal or fully specialized generic function. - OverloadCandidate candidate; - candidate.flavor = OverloadCandidate::Flavor::Expr; - if (auto diffExpr = as(expr)) - { - candidate.funcType = as(diffExpr->type.type); - } - candidate.resultType = candidate.funcType->getResultType(); - candidate.item = LookupResultItem(baseFuncDeclRef); - candidate.exprVal = expr; - AddOverloadCandidate(context, candidate, baseCost); - } - else if (auto baseFuncGenericDeclRef = funcDeclRefExpr->declRef.as()) + OverloadCandidate candidate; + if (typeCheckingCache->resolvedOperatorOverloadCache.tryGetValue(key, candidate)) { - // Process func type to generate JVP func type. - auto diffFuncType = as(expr->type.type); - SLANG_ASSERT(diffFuncType); - - // Extract parameter list from processed type. - List paramTypes; - - for (Index ii = 0; ii < diffFuncType->getParamCount(); ii++) - paramTypes.add(getParamQualType(diffFuncType->getParamType(ii))); - - // Try to infer generic arguments, based on the updated context. - OverloadResolveContext subContext = context; - ConversionCost baseCost1 = kConversionCost_None; - DeclRef innerRef = inferGenericArguments( - baseFuncGenericDeclRef, - context, - ArrayView(), - baseCost1, - ¶mTypes); - - if (!innerRef) - return; - - OverloadCandidate candidate; - candidate.flavor = OverloadCandidate::Flavor::Expr; - if (innerRef) - { - diffFuncType = as(innerRef.substitute(m_astBuilder, diffFuncType)); - candidate.item = LookupResultItem(innerRef); - } - else - { - candidate.item = LookupResultItem(funcDeclRefExpr->declRef); - } - candidate.funcType = as(diffFuncType); - candidate.resultType = candidate.funcType->getResultType(); - - // Substitute all types in the high-order expression chain. - Expr* inner = expr; - HigherOrderInvokeExpr* lastInner = nullptr; - while (auto hoInner = as(inner)) - { - lastInner = hoInner; - if (innerRef) - hoInner->type = innerRef.substitute(m_astBuilder, hoInner->type.type); - inner = hoInner->baseFunction; - } - // Set inner expression to resolved declref expr. - if (lastInner) - { - auto baseExpr = GetBaseExpr(funcDeclRefExpr); - lastInner->baseFunction = ConstructLookupResultExpr(candidate.item, baseExpr, funcDeclRefExpr->name, funcDeclRefExpr->loc, funcDeclRefExpr); - } - candidate.exprVal = expr; - expr->type.type = diffFuncType; - AddOverloadCandidate(context, candidate, baseCost + baseCost1); + context.bestCandidateStorage = candidate; + context.bestCandidate = &context.bestCandidateStorage; } else { - // Unhandled case for the inner expr. - getSink()->diagnose(funcExpr->loc, - Diagnostics::expectedFunction, - funcExpr->type); - funcExpr->type = this->getASTBuilder()->getErrorType(); + shouldAddToCache = true; } - } } - String SemanticsVisitor::getCallSignatureString( - OverloadResolveContext& context) + // Look at the base expression for the call, and figure out how to invoke it. + auto funcExpr = expr->functionExpr; + + // If we are trying to apply an erroneous expression, then just bail out now. + if (IsErrorExpr(funcExpr)) { - StringBuilder argsListBuilder; - argsListBuilder << "("; + return CreateErrorExpr(expr); + } + // If any of the arguments is an error, then we should bail out, to avoid + // cascading errors where we successfully pick an overload, but not the one + // the user meant. + for (auto arg : expr->arguments) + { + if (IsErrorExpr(arg)) + return CreateErrorExpr(expr); - UInt argCount = context.getArgCount(); - for( UInt aa = 0; aa < argCount; ++aa ) + // If this argument is itself an overloaded value without a type + // then we can't sensibly continue + if (!arg->type && (as(arg) || as(arg))) { - if(aa != 0) argsListBuilder << ", "; - auto argType = context.getArgType(aa); - if (argType) - context.getArgType(aa)->toText(argsListBuilder); - else - argsListBuilder << "error"; + getSink()->diagnose(expr->loc, Diagnostics::overloadedParameterToHigherOrderFunction); + return CreateErrorExpr(expr); } - argsListBuilder << ")"; - return argsListBuilder.produceString(); } - Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr* expr) + for (auto& arg : expr->arguments) { - OverloadResolveContext context; - // check if this is a core module operator call, if so we want to use cached results - // to speed up compilation - bool shouldAddToCache = false; - OperatorOverloadCacheKey key; - TypeCheckingCache* typeCheckingCache = getLinkage()->getTypeCheckingCache(); - if (auto opExpr = as(expr)) - { - if (key.fromOperatorExpr(opExpr)) - { - OverloadCandidate candidate; - if (typeCheckingCache->resolvedOperatorOverloadCache.tryGetValue(key, candidate)) - { - context.bestCandidateStorage = candidate; - context.bestCandidate = &context.bestCandidateStorage; - } - else - { - shouldAddToCache = true; - } + arg = maybeOpenRef(arg); + arg = maybeOpenExistential(arg); + } + + context.originalExpr = expr; + context.funcLoc = funcExpr->loc; + context.argCount = expr->arguments.getCount(); + context.args = &expr->arguments; + context.loc = expr->loc; + context.sourceScope = m_outerScope; + context.baseExpr = GetBaseExpr(funcExpr); + + // We run a special case here where an `InvokeExpr` + // with a single argument where the base/func expression names + // a type should always be treated as an explicit type coercion + // (and hence bottleneck through `coerce()`) instead of just + // as a constructor call. + // + // Such a special-case would help us handle cases of identity + // casts (casting an expression to the type it already has), + // without needing dummy initializer/constructor declarations. + // + // Handling that special casing here (rather than in, say, + // that `(T) expr` and `T(expr)` continue to be semantically + // `visitTypeCastExpr`) would allow us to continue to ensure + // equivalent in (almost) all cases. + // If callee is a type, and we are calling with one argument, then treat it as a + // type coercion. + bool typeOverloadChecked = false; + + if (expr->arguments.getCount() == 1) + { + if (const auto typeType = as(funcExpr->type)) + { + if (isDeclRefTypeOf(typeType->getType())) + { + Expr* resultExpr = nullptr; + DiagnosticSink tempSink(getSourceManager(), nullptr); + ConversionCost conversionCost = kConversionCost_None; + auto coerceResult = SemanticsVisitor(withSink(&tempSink)) + ._coerce( + CoercionSite::ExplicitCoercion, + typeType->getType(), + &resultExpr, + expr->arguments[0]->type, + expr->arguments[0], + &conversionCost); + if (coerceResult) + return resultExpr; + typeOverloadChecked = true; } } + } + if (!context.bestCandidate && !typeOverloadChecked) + { + AddOverloadCandidates(funcExpr, context); + } - // Look at the base expression for the call, and figure out how to invoke it. - auto funcExpr = expr->functionExpr; + if (context.bestCandidates.getCount() > 0) + { + // Things were ambiguous. - // If we are trying to apply an erroneous expression, then just bail out now. - if(IsErrorExpr(funcExpr)) - { - return CreateErrorExpr(expr); - } - // If any of the arguments is an error, then we should bail out, to avoid - // cascading errors where we successfully pick an overload, but not the one - // the user meant. + // It might be that things were only ambiguous because + // one of the argument expressions had an error, and + // so a bunch of candidates could match at that position. + // + // If any argument was an error, we skip out on printing + // another message, to avoid cascading errors. for (auto arg : expr->arguments) { if (IsErrorExpr(arg)) - return CreateErrorExpr(expr); - - // If this argument is itself an overloaded value without a type - // then we can't sensibly continue - if(!arg->type && (as(arg) || as(arg))) { - getSink()->diagnose( - expr->loc, - Diagnostics::overloadedParameterToHigherOrderFunction); return CreateErrorExpr(expr); } } - for (auto& arg : expr->arguments) + Name* funcName = nullptr; { - arg = maybeOpenRef(arg); - arg = maybeOpenExistential(arg); - } + Expr* baseExpr = funcExpr; - context.originalExpr = expr; - context.funcLoc = funcExpr->loc; - context.argCount = expr->arguments.getCount(); - context.args = &expr->arguments; - context.loc = expr->loc; - context.sourceScope = m_outerScope; - context.baseExpr = GetBaseExpr(funcExpr); + if (auto baseGenericApp = as(baseExpr)) + baseExpr = baseGenericApp->functionExpr; - // We run a special case here where an `InvokeExpr` - // with a single argument where the base/func expression names - // a type should always be treated as an explicit type coercion - // (and hence bottleneck through `coerce()`) instead of just - // as a constructor call. - // - // Such a special-case would help us handle cases of identity - // casts (casting an expression to the type it already has), - // without needing dummy initializer/constructor declarations. - // - // Handling that special casing here (rather than in, say, - // that `(T) expr` and `T(expr)` continue to be semantically - // `visitTypeCastExpr`) would allow us to continue to ensure - // equivalent in (almost) all cases. - // If callee is a type, and we are calling with one argument, then treat it as a - // type coercion. - bool typeOverloadChecked = false; + if (auto baseVar = as(baseExpr)) + funcName = baseVar->name; + else if (auto baseMemberRef = as(baseExpr)) + funcName = baseMemberRef->name; + else if (auto baseOverloaded = as(baseExpr)) + funcName = baseOverloaded->name; + } - if (expr->arguments.getCount() == 1) + String argsList = getCallSignatureString(context); + + if (context.bestCandidates[0].status != OverloadCandidate::Status::Applicable) { - if (const auto typeType = as(funcExpr->type)) + // There were multiple equally-good candidates, but none actually usable. + // We will construct a diagnostic message to help out. + + if (funcName) { - if (isDeclRefTypeOf(typeType->getType())) - { - Expr* resultExpr = nullptr; - DiagnosticSink tempSink(getSourceManager(), nullptr); - ConversionCost conversionCost = kConversionCost_None; - auto coerceResult = SemanticsVisitor(withSink(&tempSink))._coerce( - CoercionSite::ExplicitCoercion, - typeType->getType(), - &resultExpr, - expr->arguments[0]->type, - expr->arguments[0], - &conversionCost); - if (coerceResult) - return resultExpr; - typeOverloadChecked = true; - } + getSink()->diagnose( + expr, + Diagnostics::noApplicableOverloadForNameWithArgs, + funcName, + argsList); + } + else + { + getSink()->diagnose(expr, Diagnostics::noApplicableWithArgs, argsList); } } - if (!context.bestCandidate && !typeOverloadChecked) - { - AddOverloadCandidates(funcExpr, context); - } - - if (context.bestCandidates.getCount() > 0) + else { - // Things were ambiguous. + // There were multiple applicable candidates, so we need to report them. - // It might be that things were only ambiguous because - // one of the argument expressions had an error, and - // so a bunch of candidates could match at that position. - // - // If any argument was an error, we skip out on printing - // another message, to avoid cascading errors. - for (auto arg : expr->arguments) + if (funcName) { - if (IsErrorExpr(arg)) - { - return CreateErrorExpr(expr); - } + getSink()->diagnose( + expr, + Diagnostics::ambiguousOverloadForNameWithArgs, + funcName, + argsList); } - - Name* funcName = nullptr; + else { - Expr* baseExpr = funcExpr; - - if(auto baseGenericApp = as(baseExpr)) - baseExpr = baseGenericApp->functionExpr; - - if (auto baseVar = as(baseExpr)) - funcName = baseVar->name; - else if(auto baseMemberRef = as(baseExpr)) - funcName = baseMemberRef->name; - else if(auto baseOverloaded = as(baseExpr)) - funcName = baseOverloaded->name; + getSink()->diagnose(expr, Diagnostics::ambiguousOverloadWithArgs, argsList); } + } - String argsList = getCallSignatureString(context); + { + Index candidateCount = context.bestCandidates.getCount(); + Index maxCandidatesToPrint = 10; // don't show too many candidates at once... + Index candidateIndex = 0; + context.bestCandidates.sort([](const OverloadCandidate& c1, const OverloadCandidate& c2) + { return c1.status < c2.status; }); - if (context.bestCandidates[0].status != OverloadCandidate::Status::Applicable) + for (auto candidate : context.bestCandidates) { - // There were multiple equally-good candidates, but none actually usable. - // We will construct a diagnostic message to help out. + String declString = + ASTPrinter::getDeclSignatureString(candidate.item, m_astBuilder); - if (funcName) - { - getSink()->diagnose(expr, Diagnostics::noApplicableOverloadForNameWithArgs, funcName, argsList); - } + if (candidate.status == OverloadCandidate::Status::VisibilityChecked) + getSink()->diagnose( + candidate.item.declRef, + Diagnostics::invisibleOverloadCandidate, + declString); else - { - getSink()->diagnose(expr, Diagnostics::noApplicableWithArgs, argsList); - } - } - else - { - // There were multiple applicable candidates, so we need to report them. + getSink()->diagnose( + candidate.item.declRef, + Diagnostics::overloadCandidate, + declString); - if (funcName) - { - getSink()->diagnose(expr, Diagnostics::ambiguousOverloadForNameWithArgs, funcName, argsList); - } - else - { - getSink()->diagnose(expr, Diagnostics::ambiguousOverloadWithArgs, argsList); - } + candidateIndex++; + if (candidateIndex == maxCandidatesToPrint) + break; } - + if (candidateIndex != candidateCount) { - Index candidateCount = context.bestCandidates.getCount(); - Index maxCandidatesToPrint = 10; // don't show too many candidates at once... - Index candidateIndex = 0; - context.bestCandidates.sort([](const OverloadCandidate& c1, const OverloadCandidate& c2) { return c1.status < c2.status; }); - - for (auto candidate : context.bestCandidates) - { - String declString = ASTPrinter::getDeclSignatureString(candidate.item, m_astBuilder); - - if (candidate.status == OverloadCandidate::Status::VisibilityChecked) - getSink()->diagnose(candidate.item.declRef, Diagnostics::invisibleOverloadCandidate, declString); - else - getSink()->diagnose(candidate.item.declRef, Diagnostics::overloadCandidate, declString); - - candidateIndex++; - if (candidateIndex == maxCandidatesToPrint) - break; - } - if (candidateIndex != candidateCount) - { - getSink()->diagnose(expr, Diagnostics::moreOverloadCandidates, candidateCount - candidateIndex); - } + getSink()->diagnose( + expr, + Diagnostics::moreOverloadCandidates, + candidateCount - candidateIndex); } - - return CreateErrorExpr(expr); } - else if (context.bestCandidate) - { - // There was one best candidate, even if it might not have been - // applicable in the end. - // We will report errors for this one candidate, then, to give - // the user the most help we can. - if (shouldAddToCache) - typeCheckingCache->resolvedOperatorOverloadCache[key] = *context.bestCandidate; - // Now that we have resolved the overload candidate, we need to undo an `openExistential` - // operation that was applied to `out` arguments. - // - auto funcType = context.bestCandidate->funcType; - ShortList paramDirections; - if (funcType) + return CreateErrorExpr(expr); + } + else if (context.bestCandidate) + { + // There was one best candidate, even if it might not have been + // applicable in the end. + // We will report errors for this one candidate, then, to give + // the user the most help we can. + if (shouldAddToCache) + typeCheckingCache->resolvedOperatorOverloadCache[key] = *context.bestCandidate; + + // Now that we have resolved the overload candidate, we need to undo an `openExistential` + // operation that was applied to `out` arguments. + // + auto funcType = context.bestCandidate->funcType; + ShortList paramDirections; + if (funcType) + { + for (Index i = 0; i < funcType->getParamCount(); i++) { - for (Index i = 0; i < funcType->getParamCount(); i++) - { - paramDirections.add(funcType->getParamDirection(i)); - } + paramDirections.add(funcType->getParamDirection(i)); } - else if (auto callableDeclRef = context.bestCandidate->item.declRef.as()) + } + else if (auto callableDeclRef = context.bestCandidate->item.declRef.as()) + { + for (auto param : callableDeclRef.getDecl()->getParameters()) { - for (auto param : callableDeclRef.getDecl()->getParameters()) - { - paramDirections.add(getParameterDirection(param)); - } + paramDirections.add(getParameterDirection(param)); } - for (Index i = 0; i < expr->arguments.getCount(); i++) + } + for (Index i = 0; i < expr->arguments.getCount(); i++) + { + auto& arg = expr->arguments[i]; + if (i < paramDirections.getCount()) { - auto& arg = expr->arguments[i]; - if (i < paramDirections.getCount()) + switch (paramDirections[i]) { - switch (paramDirections[i]) - { - case kParameterDirection_Out: - case kParameterDirection_InOut: - case kParameterDirection_Ref: - case kParameterDirection_ConstRef: - break; - default: - continue; - } + case kParameterDirection_Out: + case kParameterDirection_InOut: + case kParameterDirection_Ref: + case kParameterDirection_ConstRef: break; + default: continue; } - if (auto extractExistentialExpr = as(arg)) - arg = extractExistentialExpr->originalExpr; } - return CompleteOverloadCandidate(context, *context.bestCandidate); + if (auto extractExistentialExpr = as(arg)) + arg = extractExistentialExpr->originalExpr; } + return CompleteOverloadCandidate(context, *context.bestCandidate); + } - // If absolutely no viable candidates were extracted from the overloaded expression, - // we may be dealing with a composite type or an overloaded expression with composite types. - // - - auto typeExpr = funcExpr; - if (auto overloadedExpr = as(funcExpr)) - { - if (overloadedExpr->lookupResult2.isValid() && overloadedExpr->lookupResult2.isOverloaded()) - { - typeExpr = maybeResolveOverloadedExpr(overloadedExpr, LookupMask::type, nullptr); - } - } + // If absolutely no viable candidates were extracted from the overloaded expression, + // we may be dealing with a composite type or an overloaded expression with composite types. + // - if (auto typetype = as(typeExpr->type)) + auto typeExpr = funcExpr; + if (auto overloadedExpr = as(funcExpr)) + { + if (overloadedExpr->lookupResult2.isValid() && overloadedExpr->lookupResult2.isOverloaded()) { - // We allow a special case when `funcExpr` represents a composite type, - // in which case we will try to construct the type via memberwise assignment from the arguments. - // - auto initListExpr = m_astBuilder->create(); - initListExpr->loc = expr->loc; - initListExpr->args.addRange(expr->arguments); - initListExpr->type = m_astBuilder->getInitializerListType(); - Expr* outExpr = nullptr; - if (_coerceInitializerList(typetype->getType(), &outExpr, initListExpr)) - return outExpr; + typeExpr = maybeResolveOverloadedExpr(overloadedExpr, LookupMask::type, nullptr); } + } - // Nothing at all was found that we could even consider invoking. - // In all other cases, this is an error. - getSink()->diagnose(expr->functionExpr, Diagnostics::expectedFunction, funcExpr->type); - expr->type = QualType(m_astBuilder->getErrorType()); - return expr; + if (auto typetype = as(typeExpr->type)) + { + // We allow a special case when `funcExpr` represents a composite type, + // in which case we will try to construct the type via memberwise assignment from the + // arguments. + // + auto initListExpr = m_astBuilder->create(); + initListExpr->loc = expr->loc; + initListExpr->args.addRange(expr->arguments); + initListExpr->type = m_astBuilder->getInitializerListType(); + Expr* outExpr = nullptr; + if (_coerceInitializerList(typetype->getType(), &outExpr, initListExpr)) + return outExpr; } - void SemanticsVisitor::AddGenericOverloadCandidate( - LookupResultItem baseItem, - OverloadResolveContext& context) + // Nothing at all was found that we could even consider invoking. + // In all other cases, this is an error. + getSink()->diagnose(expr->functionExpr, Diagnostics::expectedFunction, funcExpr->type); + expr->type = QualType(m_astBuilder->getErrorType()); + return expr; +} + +void SemanticsVisitor::AddGenericOverloadCandidate( + LookupResultItem baseItem, + OverloadResolveContext& context) +{ + if (auto genericDeclRef = baseItem.declRef.as()) { - if (auto genericDeclRef = baseItem.declRef.as()) - { - ensureDecl(genericDeclRef, DeclCheckState::CanSpecializeGeneric); + ensureDecl(genericDeclRef, DeclCheckState::CanSpecializeGeneric); - OverloadCandidate candidate; - candidate.flavor = OverloadCandidate::Flavor::Generic; - candidate.item = baseItem; - candidate.resultType = nullptr; + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Generic; + candidate.item = baseItem; + candidate.resultType = nullptr; - AddOverloadCandidate(context, candidate, kConversionCost_None); - } + AddOverloadCandidate(context, candidate, kConversionCost_None); } +} - void SemanticsVisitor::AddGenericOverloadCandidates( - Expr* baseExpr, - OverloadResolveContext& context) +void SemanticsVisitor::AddGenericOverloadCandidates(Expr* baseExpr, OverloadResolveContext& context) +{ + if (auto baseDeclRefExpr = as(baseExpr)) { - if(auto baseDeclRefExpr = as(baseExpr)) - { - auto declRef = baseDeclRefExpr->declRef; - AddGenericOverloadCandidate(LookupResultItem(declRef), context); - } - else if (auto overloadedExpr = as(baseExpr)) - { - // We are referring to a bunch of declarations, each of which might be generic - for (auto item : overloadedExpr->lookupResult2) - { - AddGenericOverloadCandidate(item, context); - } - } - else - { - // any other cases? - } + auto declRef = baseDeclRefExpr->declRef; + AddGenericOverloadCandidate(LookupResultItem(declRef), context); } - - Expr* SemanticsExprVisitor::visitGenericAppExpr(GenericAppExpr* genericAppExpr) + else if (auto overloadedExpr = as(baseExpr)) { - // Start by checking the base expression and arguments. - - // Disable the short-circuiting logic expression when the experssion is in - // the generic parameter. - if (this->m_shouldShortCircuitLogicExpr) + // We are referring to a bunch of declarations, each of which might be generic + for (auto item : overloadedExpr->lookupResult2) { - auto subContext = disableShortCircuitLogicalExpr(); - return dispatchExpr(genericAppExpr, subContext); + AddGenericOverloadCandidate(item, context); } + } + else + { + // any other cases? + } +} - auto& baseExpr = genericAppExpr->functionExpr; - baseExpr = CheckTerm(baseExpr); - auto& args = genericAppExpr->arguments; - for (auto& arg : args) - { - arg = CheckTerm(arg); - } +Expr* SemanticsExprVisitor::visitGenericAppExpr(GenericAppExpr* genericAppExpr) +{ + // Start by checking the base expression and arguments. - return checkGenericAppWithCheckedArgs(genericAppExpr); + // Disable the short-circuiting logic expression when the experssion is in + // the generic parameter. + if (this->m_shouldShortCircuitLogicExpr) + { + auto subContext = disableShortCircuitLogicalExpr(); + return dispatchExpr(genericAppExpr, subContext); } - /// Check a generic application where the operands have already been checked. - Expr* SemanticsVisitor::checkGenericAppWithCheckedArgs(GenericAppExpr* genericAppExpr) + auto& baseExpr = genericAppExpr->functionExpr; + baseExpr = CheckTerm(baseExpr); + auto& args = genericAppExpr->arguments; + for (auto& arg : args) { - // We are applying a generic to arguments, but there might be multiple generic - // declarations with the same name, so this becomes a specialized case of - // overload resolution. + arg = CheckTerm(arg); + } + + return checkGenericAppWithCheckedArgs(genericAppExpr); +} + +/// Check a generic application where the operands have already been checked. +Expr* SemanticsVisitor::checkGenericAppWithCheckedArgs(GenericAppExpr* genericAppExpr) +{ + // We are applying a generic to arguments, but there might be multiple generic + // declarations with the same name, so this becomes a specialized case of + // overload resolution. - auto& baseExpr = genericAppExpr->functionExpr; - auto& args = genericAppExpr->arguments; + auto& baseExpr = genericAppExpr->functionExpr; + auto& args = genericAppExpr->arguments; - // If there was an error in the base expression, or in any of - // the arguments, then just bail. - if (IsErrorExpr(baseExpr)) + // If there was an error in the base expression, or in any of + // the arguments, then just bail. + if (IsErrorExpr(baseExpr)) + { + return CreateErrorExpr(genericAppExpr); + } + for (auto argExpr : args) + { + if (IsErrorExpr(argExpr)) { return CreateErrorExpr(genericAppExpr); } - for (auto argExpr : args) - { - if (IsErrorExpr(argExpr)) - { - return CreateErrorExpr(genericAppExpr); - } - } + } - // Otherwise, let's start looking at how to find an overload... + // Otherwise, let's start looking at how to find an overload... - OverloadResolveContext context; - context.originalExpr = genericAppExpr; - context.funcLoc = baseExpr->loc; - context.argCount = args.getCount(); - context.args = &args; - context.loc = genericAppExpr->loc; - context.sourceScope = m_outerScope; - context.baseExpr = GetBaseExpr(baseExpr); + OverloadResolveContext context; + context.originalExpr = genericAppExpr; + context.funcLoc = baseExpr->loc; + context.argCount = args.getCount(); + context.args = &args; + context.loc = genericAppExpr->loc; + context.sourceScope = m_outerScope; + context.baseExpr = GetBaseExpr(baseExpr); - AddGenericOverloadCandidates(baseExpr, context); + AddGenericOverloadCandidates(baseExpr, context); - if (context.bestCandidates.getCount() > 0) + if (context.bestCandidates.getCount() > 0) + { + // Things were ambiguous. + if (context.bestCandidates[0].status != OverloadCandidate::Status::Applicable) { - // Things were ambiguous. - if (context.bestCandidates[0].status != OverloadCandidate::Status::Applicable) - { - // There were multiple equally-good candidates, but none actually usable. - // We will construct a diagnostic message to help out. - - // TODO(tfoley): print a reasonable message here... + // There were multiple equally-good candidates, but none actually usable. + // We will construct a diagnostic message to help out. - getSink()->diagnose(genericAppExpr, Diagnostics::unimplemented, "no applicable generic"); + // TODO(tfoley): print a reasonable message here... - return CreateErrorExpr(genericAppExpr); - } - else - { - // There were multiple viable candidates, but that isn't an error: we just need - // to complete all of them and create an overloaded expression as a result. + getSink()->diagnose( + genericAppExpr, + Diagnostics::unimplemented, + "no applicable generic"); - auto overloadedExpr = m_astBuilder->create(); - overloadedExpr->base = context.baseExpr; - for (auto candidate : context.bestCandidates) - { - auto candidateExpr = CompleteOverloadCandidate(context, candidate); - overloadedExpr->candidiateExprs.add(candidateExpr); - } - return overloadedExpr; - } - } - else if (context.bestCandidate) - { - // There was one best candidate, even if it might not have been - // applicable in the end. - // We will report errors for this one candidate, then, to give - // the user the most help we can. - return CompleteOverloadCandidate(context, *context.bestCandidate); + return CreateErrorExpr(genericAppExpr); } else { - // Nothing at all was found that we could even consider invoking - getSink()->diagnose(genericAppExpr, Diagnostics::expectedAGeneric, baseExpr->type); - return CreateErrorExpr(genericAppExpr); + // There were multiple viable candidates, but that isn't an error: we just need + // to complete all of them and create an overloaded expression as a result. + + auto overloadedExpr = m_astBuilder->create(); + overloadedExpr->base = context.baseExpr; + for (auto candidate : context.bestCandidates) + { + auto candidateExpr = CompleteOverloadCandidate(context, candidate); + overloadedExpr->candidiateExprs.add(candidateExpr); + } + return overloadedExpr; } } - + else if (context.bestCandidate) + { + // There was one best candidate, even if it might not have been + // applicable in the end. + // We will report errors for this one candidate, then, to give + // the user the most help we can. + return CompleteOverloadCandidate(context, *context.bestCandidate); + } + else + { + // Nothing at all was found that we could even consider invoking + getSink()->diagnose(genericAppExpr, Diagnostics::expectedAGeneric, baseExpr->type); + return CreateErrorExpr(genericAppExpr); + } } + +} // namespace Slang diff --git a/source/slang/slang-check-resolve-val.cpp b/source/slang/slang-check-resolve-val.cpp index 7cd78a1bf..92a9a9d6d 100644 --- a/source/slang/slang-check-resolve-val.cpp +++ b/source/slang/slang-check-resolve-val.cpp @@ -2,12 +2,11 @@ // Logic for resolving/simplifying Types and DeclRefs. +#include "slang-ast-reflect.h" +#include "slang-ast-synthesis.h" #include "slang-check-impl.h" - #include "slang-lookup.h" #include "slang-syntax.h" -#include "slang-ast-synthesis.h" -#include "slang-ast-reflect.h" namespace Slang { @@ -32,7 +31,8 @@ Type* DeclRefType::_createCanonicalTypeOverride() // A declaration reference is already canonical auto resolvedDeclRef = getDeclRef(); resolvedDeclRef = _resolveAsDeclRef(getDeclRef().declRefBase); - if (auto satisfyingVal = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(astBuilder, resolvedDeclRef)) + if (auto satisfyingVal = + _tryLookupConcreteAssociatedTypeFromThisTypeSubst(astBuilder, resolvedDeclRef)) return as(satisfyingVal); if (resolvedDeclRef != getDeclRef()) return DeclRefType::create(astBuilder, resolvedDeclRef); @@ -55,4 +55,4 @@ ConversionCost SubtypeWitness::getOverloadResolutionCost() SLANG_AST_NODE_VIRTUAL_CALL(SubtypeWitness, getOverloadResolutionCost, ()); } -} +} // namespace Slang diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index c365b987b..38c7fa5c1 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -10,1751 +10,1842 @@ namespace Slang { - static bool isValidThreadDispatchIDType(Type* type) +static bool isValidThreadDispatchIDType(Type* type) +{ + // Can accept a single int/unit { - // Can accept a single int/unit - { - auto basicType = as(type); - if (basicType) - { - return (basicType->getBaseType() == BaseType::Int || basicType->getBaseType() == BaseType::UInt); - } - } - // Can be an int/uint vector from size 1 to 3 + auto basicType = as(type); + if (basicType) { - auto vectorType = as(type); - if (!vectorType) - { - return false; - } - auto elemCount = as(vectorType->getElementCount()); - if (elemCount->getValue() < 1 || elemCount->getValue() > 3) - { - return false; - } - // Must be a basic type - auto basicType = as(vectorType->getElementType()); - if (!basicType) - { - return false; - } - - // Must be integral - auto baseType = basicType->getBaseType(); - return (baseType == BaseType::Int || baseType == BaseType::UInt); + return ( + basicType->getBaseType() == BaseType::Int || + basicType->getBaseType() == BaseType::UInt); } } - - /// Recursively walk `paramDeclRef` and add any existential/interface specialization parameters to `ioSpecializationParams`. - static void _collectExistentialSpecializationParamsRec( - ASTBuilder* astBuilder, - SpecializationParams& ioSpecializationParams, - DeclRef paramDeclRef); - - /// Recursively walk `type` and add any existential/interface specialization parameters to `ioSpecializationParams`. - static void _collectExistentialSpecializationParamsRec( - ASTBuilder* astBuilder, - SpecializationParams& ioSpecializationParams, - Type* type, - SourceLoc loc) + // Can be an int/uint vector from size 1 to 3 { - // Whether or not something is an array does not affect - // the number of existential slots it introduces. - // - while( auto arrayType = as(type) ) + auto vectorType = as(type); + if (!vectorType) { - type = arrayType->getElementType(); + return false; } - - if( auto parameterGroupType = as(type) ) + auto elemCount = as(vectorType->getElementCount()); + if (elemCount->getValue() < 1 || elemCount->getValue() > 3) { - _collectExistentialSpecializationParamsRec( - astBuilder, - ioSpecializationParams, - parameterGroupType->getElementType(), - loc); - return; + return false; } - else if (auto structuredBufferType = as(type)) + // Must be a basic type + auto basicType = as(vectorType->getElementType()); + if (!basicType) { - _collectExistentialSpecializationParamsRec( - astBuilder, ioSpecializationParams, structuredBufferType->getElementType(), loc); - return; + return false; } - if( auto declRefType = as(type) ) - { - auto typeDeclRef = declRefType->getDeclRef(); - if( auto interfaceDeclRef = typeDeclRef.as() ) - { - // Each leaf parameter of interface type adds a specialization - // parameter, which determines the concrete type(s) that may - // be provided as arguments for that parameter. - // - SpecializationParam specializationParam; - specializationParam.flavor = SpecializationParam::Flavor::ExistentialType; - specializationParam.loc = loc; - specializationParam.object = type; - ioSpecializationParams.add(specializationParam); - } - else if( auto structDeclRef = typeDeclRef.as() ) - { - // A structure type should recursively introduce - // existential slots for its fields. - // - for( auto fieldDeclRef : getFields(astBuilder, structDeclRef, MemberFilterStyle::Instance) ) - { - _collectExistentialSpecializationParamsRec( - astBuilder, - ioSpecializationParams, - fieldDeclRef); - } - } - } + // Must be integral + auto baseType = basicType->getBaseType(); + return (baseType == BaseType::Int || baseType == BaseType::UInt); + } +} - // TODO: We eventually need to handle cases like constant - // buffers and parameter blocks that may have existential - // element types. +/// Recursively walk `paramDeclRef` and add any existential/interface specialization parameters to +/// `ioSpecializationParams`. +static void _collectExistentialSpecializationParamsRec( + ASTBuilder* astBuilder, + SpecializationParams& ioSpecializationParams, + DeclRef paramDeclRef); + +/// Recursively walk `type` and add any existential/interface specialization parameters to +/// `ioSpecializationParams`. +static void _collectExistentialSpecializationParamsRec( + ASTBuilder* astBuilder, + SpecializationParams& ioSpecializationParams, + Type* type, + SourceLoc loc) +{ + // Whether or not something is an array does not affect + // the number of existential slots it introduces. + // + while (auto arrayType = as(type)) + { + type = arrayType->getElementType(); } - static void _collectExistentialSpecializationParamsRec( - ASTBuilder* astBuilder, - SpecializationParams& ioSpecializationParams, - DeclRef paramDeclRef) + if (auto parameterGroupType = as(type)) { _collectExistentialSpecializationParamsRec( astBuilder, ioSpecializationParams, - getType(astBuilder, paramDeclRef), - paramDeclRef.getLoc()); + parameterGroupType->getElementType(), + loc); + return; } - - - /// Collect any interface/existential specialization parameters for `paramDeclRef` into `ioParamInfo` and `ioSpecializationParams` - static void _collectExistentialSpecializationParamsForShaderParam( - ASTBuilder* astBuilder, - ShaderParamInfo& ioParamInfo, - SpecializationParams& ioSpecializationParams, - DeclRef paramDeclRef) + else if (auto structuredBufferType = as(type)) { - Index beginParamIndex = ioSpecializationParams.getCount(); - _collectExistentialSpecializationParamsRec(astBuilder, ioSpecializationParams, paramDeclRef); - Index endParamIndex = ioSpecializationParams.getCount(); - - ioParamInfo.firstSpecializationParamIndex = beginParamIndex; - ioParamInfo.specializationParamCount = endParamIndex - beginParamIndex; + _collectExistentialSpecializationParamsRec( + astBuilder, + ioSpecializationParams, + structuredBufferType->getElementType(), + loc); + return; } - void EntryPoint::_collectGenericSpecializationParamsRec(Decl* decl) + if (auto declRefType = as(type)) { - if(!decl) - return; - - _collectGenericSpecializationParamsRec(decl->parentDecl); - - auto genericDecl = as(decl); - if(!genericDecl) - return; - - for(auto m : genericDecl->members) + auto typeDeclRef = declRefType->getDeclRef(); + if (auto interfaceDeclRef = typeDeclRef.as()) { - if(auto genericTypeParam = as(m)) - { - SpecializationParam param; - param.flavor = SpecializationParam::Flavor::GenericType; - param.loc = genericTypeParam->loc; - param.object = genericTypeParam; - m_genericSpecializationParams.add(param); - } - else if(auto genericValParam = as(m)) + // Each leaf parameter of interface type adds a specialization + // parameter, which determines the concrete type(s) that may + // be provided as arguments for that parameter. + // + SpecializationParam specializationParam; + specializationParam.flavor = SpecializationParam::Flavor::ExistentialType; + specializationParam.loc = loc; + specializationParam.object = type; + ioSpecializationParams.add(specializationParam); + } + else if (auto structDeclRef = typeDeclRef.as()) + { + // A structure type should recursively introduce + // existential slots for its fields. + // + for (auto fieldDeclRef : + getFields(astBuilder, structDeclRef, MemberFilterStyle::Instance)) { - SpecializationParam param; - param.flavor = SpecializationParam::Flavor::GenericValue; - param.loc = genericValParam->loc; - param.object = genericValParam; - m_genericSpecializationParams.add(param); + _collectExistentialSpecializationParamsRec( + astBuilder, + ioSpecializationParams, + fieldDeclRef); } } } - /// Enumerate the existential-type parameters of an `EntryPoint`. - /// - /// Any parameters found will be added to the list of existential slots on `this`. - /// - void EntryPoint::_collectShaderParams() - { - // We don't currently treat an entry point as having any - // *global* shader parameters. - // - // TODO: We could probably clean up the code a bit by treating - // an entry point as introducing a global shader parameter - // that is based on the implicit "parameters struct" type - // of the entry point itself. + // TODO: We eventually need to handle cases like constant + // buffers and parameter blocks that may have existential + // element types. +} - // We collect the generic parameters of the entry point, - // along with those of any outer generics first. - // - _collectGenericSpecializationParamsRec(getFuncDecl()); +static void _collectExistentialSpecializationParamsRec( + ASTBuilder* astBuilder, + SpecializationParams& ioSpecializationParams, + DeclRef paramDeclRef) +{ + _collectExistentialSpecializationParamsRec( + astBuilder, + ioSpecializationParams, + getType(astBuilder, paramDeclRef), + paramDeclRef.getLoc()); +} - // After geneic specialization parameters have been collected, - // we look through the value parameters of the entry point - // function and see if any of them introduce existential/interface - // specialization parameters. - // - // Note: we defensively test whether there is a function decl-ref - // because this routine gets called from the constructor, and - // a "dummy" entry point will have a null pointer for the function. - // - if( auto funcDeclRef = getFuncDeclRef() ) - { - for( auto paramDeclRef : getParameters(getLinkage()->getASTBuilder(), funcDeclRef) ) - { - ShaderParamInfo shaderParamInfo; - shaderParamInfo.paramDeclRef = paramDeclRef; - _collectExistentialSpecializationParamsForShaderParam( - getLinkage()->getASTBuilder(), - shaderParamInfo, - m_existentialSpecializationParams, - paramDeclRef); +/// Collect any interface/existential specialization parameters for `paramDeclRef` into +/// `ioParamInfo` and `ioSpecializationParams` +static void _collectExistentialSpecializationParamsForShaderParam( + ASTBuilder* astBuilder, + ShaderParamInfo& ioParamInfo, + SpecializationParams& ioSpecializationParams, + DeclRef paramDeclRef) +{ + Index beginParamIndex = ioSpecializationParams.getCount(); + _collectExistentialSpecializationParamsRec(astBuilder, ioSpecializationParams, paramDeclRef); + Index endParamIndex = ioSpecializationParams.getCount(); - m_shaderParams.add(shaderParamInfo); - } - } - } + ioParamInfo.firstSpecializationParamIndex = beginParamIndex; + ioParamInfo.specializationParamCount = endParamIndex - beginParamIndex; +} + +void EntryPoint::_collectGenericSpecializationParamsRec(Decl* decl) +{ + if (!decl) + return; + + _collectGenericSpecializationParamsRec(decl->parentDecl); - bool isPrimaryDecl( - CallableDecl* decl) + auto genericDecl = as(decl); + if (!genericDecl) + return; + + for (auto m : genericDecl->members) { - SLANG_ASSERT(decl); - return (!decl->primaryDecl) || (decl == decl->primaryDecl); + if (auto genericTypeParam = as(m)) + { + SpecializationParam param; + param.flavor = SpecializationParam::Flavor::GenericType; + param.loc = genericTypeParam->loc; + param.object = genericTypeParam; + m_genericSpecializationParams.add(param); + } + else if (auto genericValParam = as(m)) + { + SpecializationParam param; + param.flavor = SpecializationParam::Flavor::GenericValue; + param.loc = genericValParam->loc; + param.object = genericValParam; + m_genericSpecializationParams.add(param); + } } +} - FuncDecl* findFunctionDeclByName( - Module* translationUnit, - Name* name, - DiagnosticSink* sink) - { - FuncDecl* entryPointFuncDecl = nullptr; +/// Enumerate the existential-type parameters of an `EntryPoint`. +/// +/// Any parameters found will be added to the list of existential slots on `this`. +/// +void EntryPoint::_collectShaderParams() +{ + // We don't currently treat an entry point as having any + // *global* shader parameters. + // + // TODO: We could probably clean up the code a bit by treating + // an entry point as introducing a global shader parameter + // that is based on the implicit "parameters struct" type + // of the entry point itself. + + // We collect the generic parameters of the entry point, + // along with those of any outer generics first. + // + _collectGenericSpecializationParamsRec(getFuncDecl()); - auto expr = translationUnit->findDeclFromString(getText(name), sink); - if (auto declRefExpr = as(expr)) + // After geneic specialization parameters have been collected, + // we look through the value parameters of the entry point + // function and see if any of them introduce existential/interface + // specialization parameters. + // + // Note: we defensively test whether there is a function decl-ref + // because this routine gets called from the constructor, and + // a "dummy" entry point will have a null pointer for the function. + // + if (auto funcDeclRef = getFuncDeclRef()) + { + for (auto paramDeclRef : getParameters(getLinkage()->getASTBuilder(), funcDeclRef)) { - auto declRef = declRefExpr->declRef; - entryPointFuncDecl = declRef.as().getDecl(); + ShaderParamInfo shaderParamInfo; + shaderParamInfo.paramDeclRef = paramDeclRef; - if (entryPointFuncDecl && getModule(entryPointFuncDecl) != translationUnit) - entryPointFuncDecl = nullptr; + _collectExistentialSpecializationParamsForShaderParam( + getLinkage()->getASTBuilder(), + shaderParamInfo, + m_existentialSpecializationParams, + paramDeclRef); + + m_shaderParams.add(shaderParamInfo); } + } +} + +bool isPrimaryDecl(CallableDecl* decl) +{ + SLANG_ASSERT(decl); + return (!decl->primaryDecl) || (decl == decl->primaryDecl); +} + +FuncDecl* findFunctionDeclByName(Module* translationUnit, Name* name, DiagnosticSink* sink) +{ + FuncDecl* entryPointFuncDecl = nullptr; + + auto expr = translationUnit->findDeclFromString(getText(name), sink); + if (auto declRefExpr = as(expr)) + { + auto declRef = declRefExpr->declRef; + entryPointFuncDecl = declRef.as().getDecl(); if (entryPointFuncDecl && getModule(entryPointFuncDecl) != translationUnit) entryPointFuncDecl = nullptr; - - if (!entryPointFuncDecl) - { - auto translationUnitSyntax = translationUnit->getModuleDecl(); - sink->diagnose(translationUnitSyntax, Diagnostics::entryPointFunctionNotFound, name); - } - return entryPointFuncDecl; } - // Is a entry pointer parmaeter of `type` always a uniform parameter? - bool isUniformParameterType(Type* type) + if (entryPointFuncDecl && getModule(entryPointFuncDecl) != translationUnit) + entryPointFuncDecl = nullptr; + + if (!entryPointFuncDecl) { - if (as(type)) - return true; - if (as(type)) - return true; - if (as(type)) - return true; - if (as(type)) - return true; - if (as(type)) - return true; - if (as< GLSLShaderStorageBufferType>(type)) - return true; - if (as(type)) - return true; - if (auto arrayType = as(type)) - return isUniformParameterType(arrayType->getElementType()); - if (auto modType = as(type)) - return isUniformParameterType(modType->getBase()); - return false; + auto translationUnitSyntax = translationUnit->getModuleDecl(); + sink->diagnose(translationUnitSyntax, Diagnostics::entryPointFunctionNotFound, name); } + return entryPointFuncDecl; +} - bool isBuiltinParameterType(Type* type) - { - if (!as(type)) - return false; - if (as(type)) - return false; - if (as(type)) - return false; - if (as(type)) - return false; - if (auto arrayType = as(type)) - return isBuiltinParameterType(arrayType->getElementType()); +// Is a entry pointer parmaeter of `type` always a uniform parameter? +bool isUniformParameterType(Type* type) +{ + if (as(type)) return true; - } + if (as(type)) + return true; + if (as(type)) + return true; + if (as(type)) + return true; + if (as(type)) + return true; + if (as(type)) + return true; + if (as(type)) + return true; + if (auto arrayType = as(type)) + return isUniformParameterType(arrayType->getElementType()); + if (auto modType = as(type)) + return isUniformParameterType(modType->getBase()); + return false; +} + +bool isBuiltinParameterType(Type* type) +{ + if (!as(type)) + return false; + if (as(type)) + return false; + if (as(type)) + return false; + if (as(type)) + return false; + if (auto arrayType = as(type)) + return isBuiltinParameterType(arrayType->getElementType()); + return true; +} - bool doStructFieldsHaveSemanticImpl(Type* type, HashSet& seenTypes) +bool doStructFieldsHaveSemanticImpl(Type* type, HashSet& seenTypes) +{ + auto declRefType = as(type); + if (!declRefType) + return false; + auto structDecl = as(declRefType->getDeclRef().getDecl()); + if (!structDecl) + return false; + seenTypes.add(type); + bool hasFields = false; + for (auto field : structDecl->getFields()) { - auto declRefType = as(type); - if (!declRefType) - return false; - auto structDecl = as(declRefType->getDeclRef().getDecl()); - if (!structDecl) - return false; - seenTypes.add(type); - bool hasFields = false; - for (auto field : structDecl->getFields()) + hasFields = true; + if (!field->findModifier()) { - hasFields = true; - if (!field->findModifier()) + if (!seenTypes.contains(field->getType())) { - if (!seenTypes.contains(field->getType())) - { - if (!doStructFieldsHaveSemanticImpl(field->getType(), seenTypes)) - return false; - } + if (!doStructFieldsHaveSemanticImpl(field->getType(), seenTypes)) + return false; } } - return hasFields; } + return hasFields; +} - bool doStructFieldsHaveSemantic(Type* type) - { - HashSet seenTypes; - return doStructFieldsHaveSemanticImpl(type, seenTypes); - } +bool doStructFieldsHaveSemantic(Type* type) +{ + HashSet seenTypes; + return doStructFieldsHaveSemanticImpl(type, seenTypes); +} - // Validate that an entry point function conforms to any additional - // constraints based on the stage (and profile?) it specifies. - void validateEntryPoint( - EntryPoint* entryPoint, - DiagnosticSink* sink) - { - auto entryPointFuncDecl = entryPoint->getFuncDecl(); - auto stage = entryPoint->getStage(); +// Validate that an entry point function conforms to any additional +// constraints based on the stage (and profile?) it specifies. +void validateEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink) +{ + auto entryPointFuncDecl = entryPoint->getFuncDecl(); + auto stage = entryPoint->getStage(); - // TODO: We currently do minimal checking here, but this is the - // right place to perform the following validation checks: - // + // TODO: We currently do minimal checking here, but this is the + // right place to perform the following validation checks: + // - // * Are the function input/output parameters and result type - // all valid for the chosen stage? (e.g., there shouldn't be - // an `OutputStream` type in a vertex shader signature) - // - // * For any varying input/output, are there semantics specified - // (Note: this potentially overlaps with layout logic...), and - // are the system-value semantics valid for the given stage? - // - // There's actually a lot of detail to semantic checking, in - // that the AST-level code should probably be validating the - // use of system-value semantics by linking them to explicit - // declarations in the core module. We should also be - // using profile information on those declarations to infer - // appropriate profile restrictions on the entry point. - // - // * Is the entry point actually usable on the given stage/profile? - // E.g., if we have a vertex shader that (transitively) calls - // `Texture2D.Sample`, then that should produce an error because - // that function is specific to the fragment profile/stage. - // + // * Are the function input/output parameters and result type + // all valid for the chosen stage? (e.g., there shouldn't be + // an `OutputStream` type in a vertex shader signature) + // + // * For any varying input/output, are there semantics specified + // (Note: this potentially overlaps with layout logic...), and + // are the system-value semantics valid for the given stage? + // + // There's actually a lot of detail to semantic checking, in + // that the AST-level code should probably be validating the + // use of system-value semantics by linking them to explicit + // declarations in the core module. We should also be + // using profile information on those declarations to infer + // appropriate profile restrictions on the entry point. + // + // * Is the entry point actually usable on the given stage/profile? + // E.g., if we have a vertex shader that (transitively) calls + // `Texture2D.Sample`, then that should produce an error because + // that function is specific to the fragment profile/stage. + // - auto entryPointName = entryPointFuncDecl->getName(); + auto entryPointName = entryPointFuncDecl->getName(); - auto module = getModule(entryPointFuncDecl); - auto linkage = module->getLinkage(); + auto module = getModule(entryPointFuncDecl); + auto linkage = module->getLinkage(); - // Every entry point needs to have a stage specified either via - // command-line/API options, or via an explicit `[shader("...")]` attribute. - // - if( stage == Stage::Unknown ) - { - sink->diagnose(entryPointFuncDecl, Diagnostics::entryPointHasNoStage, entryPointName); - } + // Every entry point needs to have a stage specified either via + // command-line/API options, or via an explicit `[shader("...")]` attribute. + // + if (stage == Stage::Unknown) + { + sink->diagnose(entryPointFuncDecl, Diagnostics::entryPointHasNoStage, entryPointName); + } - if( stage == Stage::Hull ) - { - // TODO: We could consider *always* checking any `[patchconsantfunc("...")]` - // attributes, so that they need to resolve to a function. + if (stage == Stage::Hull) + { + // TODO: We could consider *always* checking any `[patchconsantfunc("...")]` + // attributes, so that they need to resolve to a function. - auto attr = entryPointFuncDecl->findModifier(); + auto attr = entryPointFuncDecl->findModifier(); - if (attr) + if (attr) + { + if (attr->args.getCount() != 1) { - if (attr->args.getCount() != 1) - { - sink->diagnose(attr, Diagnostics::badlyDefinedPatchConstantFunc, entryPointName); - return; - } + sink->diagnose(attr, Diagnostics::badlyDefinedPatchConstantFunc, entryPointName); + return; + } - Expr* expr = attr->args[0]; - StringLiteralExpr* stringLit = as(expr); + Expr* expr = attr->args[0]; + StringLiteralExpr* stringLit = as(expr); - if (!stringLit) - { - sink->diagnose(expr, Diagnostics::badlyDefinedPatchConstantFunc, entryPointName); - return; - } + if (!stringLit) + { + sink->diagnose(expr, Diagnostics::badlyDefinedPatchConstantFunc, entryPointName); + return; + } - // We look up the patch-constant function by its name in the module - // scope of the translation unit that declared the HS entry point. - // - // TODO: Eventually we probably want to do the lookup in the scope - // of the parent declarations of the entry point. E.g., if the entry - // point is a member function of a `struct`, then its patch-constant - // function should be allowed to be another member function of - // the same `struct`. - // - // In the extremely long run we may want to support an alternative to - // this attribute-based linkage between the two functions that - // make up the entry point. - // - Name* name = linkage->getNamePool()->getName(stringLit->value); - FuncDecl* patchConstantFuncDecl = findFunctionDeclByName( - module, + // We look up the patch-constant function by its name in the module + // scope of the translation unit that declared the HS entry point. + // + // TODO: Eventually we probably want to do the lookup in the scope + // of the parent declarations of the entry point. E.g., if the entry + // point is a member function of a `struct`, then its patch-constant + // function should be allowed to be another member function of + // the same `struct`. + // + // In the extremely long run we may want to support an alternative to + // this attribute-based linkage between the two functions that + // make up the entry point. + // + Name* name = linkage->getNamePool()->getName(stringLit->value); + FuncDecl* patchConstantFuncDecl = findFunctionDeclByName(module, name, sink); + if (!patchConstantFuncDecl) + { + sink->diagnose( + expr, + Diagnostics::attributeFunctionNotFound, name, - sink); - if (!patchConstantFuncDecl) - { - sink->diagnose(expr, Diagnostics::attributeFunctionNotFound, name, "patchconstantfunc"); - return; - } - - attr->patchConstantFuncDecl = patchConstantFuncDecl; + "patchconstantfunc"); + return; } + + attr->patchConstantFuncDecl = patchConstantFuncDecl; } - else if(stage == Stage::Compute) + } + else if (stage == Stage::Compute) + { + for (const auto& param : entryPointFuncDecl->getParameters()) { - for(const auto& param : entryPointFuncDecl->getParameters()) + if (auto semantic = param->findModifier()) { - if(auto semantic = param->findModifier()) - { - const auto& semanticToken = semantic->name; + const auto& semanticToken = semantic->name; - String lowerName = String(semanticToken.getContent()).toLower(); + String lowerName = String(semanticToken.getContent()).toLower(); - if(lowerName == "sv_dispatchthreadid") - { - Type* paramType = param->getType(); + if (lowerName == "sv_dispatchthreadid") + { + Type* paramType = param->getType(); - if(!isValidThreadDispatchIDType(paramType)) - { - String typeString = paramType->toString(); - sink->diagnose(param->loc, Diagnostics::invalidDispatchThreadIDType, typeString); - return; - } + if (!isValidThreadDispatchIDType(paramType)) + { + String typeString = paramType->toString(); + sink->diagnose( + param->loc, + Diagnostics::invalidDispatchThreadIDType, + typeString); + return; } } } } + } - bool canHaveVaryingInput = false; - switch (stage) - { - case Stage::Vertex: - case Stage::Fragment: - case Stage::Miss: - case Stage::AnyHit: - case Stage::ClosestHit: - case Stage::Callable: - case Stage::Geometry: - case Stage::Mesh: - case Stage::Hull: - case Stage::Domain: - canHaveVaryingInput = true; - break; - default: - break; - } + bool canHaveVaryingInput = false; + switch (stage) + { + case Stage::Vertex: + case Stage::Fragment: + case Stage::Miss: + case Stage::AnyHit: + case Stage::ClosestHit: + case Stage::Callable: + case Stage::Geometry: + case Stage::Mesh: + case Stage::Hull: + case Stage::Domain: canHaveVaryingInput = true; break; + default: break; + } - for (const auto& param : entryPointFuncDecl->getParameters()) + for (const auto& param : entryPointFuncDecl->getParameters()) + { + if (isUniformParameterType(param->getType())) { - if (isUniformParameterType(param->getType())) + // Automatically add `uniform` modifier to entry point parameters. + if (!param->hasModifier()) { - // Automatically add `uniform` modifier to entry point parameters. - if (!param->hasModifier()) - { - addModifier(param, getCurrentASTBuilder()->create()); - continue; - } - } - - if (canHaveVaryingInput) - continue; - - // If the stage doesn't allow varying input/output, - // we require the parameter to be associated with a system value semantic. - if (param->hasModifier()) - continue; - if (param->findModifier()) - continue; - - bool isBuiltinType = isBuiltinParameterType(param->getType()); - if (isBuiltinType) + addModifier(param, getCurrentASTBuilder()->create()); continue; + } + } - if (doStructFieldsHaveSemantic(param->getType())) - continue; + if (canHaveVaryingInput) + continue; + + // If the stage doesn't allow varying input/output, + // we require the parameter to be associated with a system value semantic. + if (param->hasModifier()) + continue; + if (param->findModifier()) + continue; + + bool isBuiltinType = isBuiltinParameterType(param->getType()); + if (isBuiltinType) + continue; + + if (doStructFieldsHaveSemantic(param->getType())) + continue; + + // The user is defining a parameter with no 'uniform' modifier for a stage that doesn't + // support varying input/output. We will automatically convert it to a 'uniform' parameter, + // and diagnose a warning. + addModifier(param, getCurrentASTBuilder()->create()); + sink->diagnose( + param, + Diagnostics::nonUniformEntryPointParameterTreatedAsUniform, + param->getName()); + } - // The user is defining a parameter with no 'uniform' modifier for a stage that doesn't support - // varying input/output. We will automatically convert it to a 'uniform' parameter, and diagnose a warning. - addModifier(param, getCurrentASTBuilder()->create()); - sink->diagnose(param, Diagnostics::nonUniformEntryPointParameterTreatedAsUniform, param->getName()); + for (auto target : linkage->targets) + { + auto targetCaps = target->getTargetCaps(); + auto stageCapabilitySet = entryPoint->getProfile().getCapabilityName(); + targetCaps.join(stageCapabilitySet); + if (targetCaps.isIncompatibleWith(entryPointFuncDecl->inferredCapabilityRequirements)) + { + // Incompatable means we don't support a set of abstract atoms. + // Diagnose that we lack support for 'stage' and 'target' atoms with our provided + // entry-point + auto compileTarget = target->getTargetCaps().getCompileTarget(); + auto stageTarget = stageCapabilitySet.getTargetStage(); + maybeDiagnose( + sink, + linkage->m_optionSet, + DiagnosticCategory::Capability, + entryPointFuncDecl, + Diagnostics::entryPointUsesUnavailableCapability, + entryPointFuncDecl, + compileTarget, + stageTarget); + + // Find out what is incompatible (ancestor missing a super set of 'target+stage') + CapabilitySet failedSet({(CapabilityName)compileTarget, (CapabilityName)stageTarget}); + diagnoseMissingCapabilityProvenance( + linkage->m_optionSet, + sink, + entryPointFuncDecl, + failedSet); } - - for (auto target : linkage->targets) + else { - auto targetCaps = target->getTargetCaps(); - auto stageCapabilitySet = entryPoint->getProfile().getCapabilityName(); - targetCaps.join(stageCapabilitySet); - if (targetCaps.isIncompatibleWith(entryPointFuncDecl->inferredCapabilityRequirements)) + // Only attempt to error if a user adds to slangc either `-profile` or `-capability` + if ((target->getOptionSet().hasOption(CompilerOptionName::Capability) || + target->getOptionSet().hasOption(CompilerOptionName::Profile)) && + targetCaps.atLeastOneSetImpliedInOther( + entryPointFuncDecl->inferredCapabilityRequirements) == + CapabilitySet::ImpliesReturnFlags::NotImplied) { - // Incompatable means we don't support a set of abstract atoms. - // Diagnose that we lack support for 'stage' and 'target' atoms with our provided entry-point - auto compileTarget = target->getTargetCaps().getCompileTarget(); - auto stageTarget = stageCapabilitySet.getTargetStage(); - maybeDiagnose(sink, linkage->m_optionSet, DiagnosticCategory::Capability, entryPointFuncDecl, Diagnostics::entryPointUsesUnavailableCapability, entryPointFuncDecl, compileTarget, stageTarget); - - // Find out what is incompatible (ancestor missing a super set of 'target+stage') - CapabilitySet failedSet({ (CapabilityName)compileTarget, (CapabilityName)stageTarget }); - diagnoseMissingCapabilityProvenance(linkage->m_optionSet, sink, entryPointFuncDecl, failedSet); - } - else - { - // Only attempt to error if a user adds to slangc either `-profile` or `-capability` - if ( - ( - target->getOptionSet().hasOption(CompilerOptionName::Capability) - || - target->getOptionSet().hasOption(CompilerOptionName::Profile) - ) - && targetCaps.atLeastOneSetImpliedInOther(entryPointFuncDecl->inferredCapabilityRequirements) == CapabilitySet::ImpliesReturnFlags::NotImplied - ) + CapabilitySet combinedSets = targetCaps; + combinedSets.join(entryPointFuncDecl->inferredCapabilityRequirements); + CapabilityAtomSet addedAtoms{}; + if (auto targetCapSet = targetCaps.getAtomSets()) { - CapabilitySet combinedSets = targetCaps; - combinedSets.join(entryPointFuncDecl->inferredCapabilityRequirements); - CapabilityAtomSet addedAtoms{}; - if (auto targetCapSet = targetCaps.getAtomSets()) + if (auto combinedSet = combinedSets.getAtomSets()) { - if (auto combinedSet = combinedSets.getAtomSets()) - { - CapabilityAtomSet::calcSubtract(addedAtoms, (*combinedSet), (*targetCapSet)); - } + CapabilityAtomSet::calcSubtract( + addedAtoms, + (*combinedSet), + (*targetCapSet)); } - maybeDiagnoseWarningOrError( - sink, - target->getOptionSet(), - DiagnosticCategory::Capability, - entryPointFuncDecl->loc, - Diagnostics::profileImplicitlyUpgraded, - Diagnostics::profileImplicitlyUpgradedRestrictive, - entryPointFuncDecl, - target->getOptionSet().getProfile().getName(), - addedAtoms.getElements()); } + maybeDiagnoseWarningOrError( + sink, + target->getOptionSet(), + DiagnosticCategory::Capability, + entryPointFuncDecl->loc, + Diagnostics::profileImplicitlyUpgraded, + Diagnostics::profileImplicitlyUpgradedRestrictive, + entryPointFuncDecl, + target->getOptionSet().getProfile().getName(), + addedAtoms.getElements()); } } } +} - bool resolveStageOfProfileWithEntryPoint(Profile& entryPointProfile, CompilerOptionSet& optionSet, const List>& targets, FuncDecl* entryPointFuncDecl, DiagnosticSink* sink) +bool resolveStageOfProfileWithEntryPoint( + Profile& entryPointProfile, + CompilerOptionSet& optionSet, + const List>& targets, + FuncDecl* entryPointFuncDecl, + DiagnosticSink* sink) +{ + if (auto entryPointAttr = entryPointFuncDecl->findModifier()) { - if (auto entryPointAttr = entryPointFuncDecl->findModifier()) - { - auto entryPointProfileStage = entryPointProfile.getStage(); - auto entryPointStage = getStageFromAtom(entryPointAttr->capabilitySet.getTargetStage()); + auto entryPointProfileStage = entryPointProfile.getStage(); + auto entryPointStage = getStageFromAtom(entryPointAttr->capabilitySet.getTargetStage()); - // Ensure every target is specifying the same stage as an entry-point - // if a profile+stage was set, else user will not be aware that their - // code is requiring `fragment` on a `vertex` shader - for (auto target : targets) - { - auto targetProfile = target->getOptionSet().getProfile(); - auto profileStage = targetProfile.getStage(); - if (profileStage != Stage::Unknown && profileStage != entryPointStage) - maybeDiagnose(sink, optionSet, DiagnosticCategory::Capability, entryPointAttr, Diagnostics::entryPointAndProfileAreIncompatible, entryPointFuncDecl, entryPointStage, targetProfile.getName()); - } - if (entryPointProfileStage == Stage::Unknown) - entryPointProfile = Profile(entryPointStage); - else if (entryPointProfileStage != Stage::Unknown && entryPointProfileStage != entryPointStage) - maybeDiagnose(sink, optionSet, DiagnosticCategory::Capability, entryPointFuncDecl, Diagnostics::specifiedStageDoesntMatchAttribute, entryPointFuncDecl->getName(), entryPointProfileStage, entryPointStage); - entryPointProfile.additionalCapabilities.add(entryPointAttr->capabilitySet); - return true; + // Ensure every target is specifying the same stage as an entry-point + // if a profile+stage was set, else user will not be aware that their + // code is requiring `fragment` on a `vertex` shader + for (auto target : targets) + { + auto targetProfile = target->getOptionSet().getProfile(); + auto profileStage = targetProfile.getStage(); + if (profileStage != Stage::Unknown && profileStage != entryPointStage) + maybeDiagnose( + sink, + optionSet, + DiagnosticCategory::Capability, + entryPointAttr, + Diagnostics::entryPointAndProfileAreIncompatible, + entryPointFuncDecl, + entryPointStage, + targetProfile.getName()); } - return false; + if (entryPointProfileStage == Stage::Unknown) + entryPointProfile = Profile(entryPointStage); + else if ( + entryPointProfileStage != Stage::Unknown && entryPointProfileStage != entryPointStage) + maybeDiagnose( + sink, + optionSet, + DiagnosticCategory::Capability, + entryPointFuncDecl, + Diagnostics::specifiedStageDoesntMatchAttribute, + entryPointFuncDecl->getName(), + entryPointProfileStage, + entryPointStage); + entryPointProfile.additionalCapabilities.add(entryPointAttr->capabilitySet); + return true; } + return false; +} - // Given an entry point specified via API or command line options, - // attempt to find a matching AST declaration that implements the specified - // entry point. If such a function is found, then validate that it actually - // meets the requirements for the selected stage/profile. +// Given an entry point specified via API or command line options, +// attempt to find a matching AST declaration that implements the specified +// entry point. If such a function is found, then validate that it actually +// meets the requirements for the selected stage/profile. +// +// Returns an `EntryPoint` object representing the (unspecialized) +// entry point if it is found and validated, and null otherwise. +// +RefPtr findAndValidateEntryPoint(FrontEndEntryPointRequest* entryPointReq) +{ + // The first step in validating the entry point is to find + // the (unique) function declaration that matches its name. // - // Returns an `EntryPoint` object representing the (unspecialized) - // entry point if it is found and validated, and null otherwise. + // TODO: We may eventually want/need to extend this to + // account for nested names like `SomeStruct.vsMain`, or + // indeed even to handle generics. // - RefPtr findAndValidateEntryPoint( - FrontEndEntryPointRequest* entryPointReq) + auto compileRequest = entryPointReq->getCompileRequest(); + auto translationUnit = entryPointReq->getTranslationUnit(); + auto linkage = compileRequest->getLinkage(); + auto sink = compileRequest->getSink(); + + auto entryPointName = entryPointReq->getName(); + FuncDecl* entryPointFuncDecl = + findFunctionDeclByName(translationUnit->getModule(), entryPointName, sink); + + // Did we find a function declaration in our search? + if (!entryPointFuncDecl) { - // The first step in validating the entry point is to find - // the (unique) function declaration that matches its name. - // - // TODO: We may eventually want/need to extend this to - // account for nested names like `SomeStruct.vsMain`, or - // indeed even to handle generics. - // - auto compileRequest = entryPointReq->getCompileRequest(); - auto translationUnit = entryPointReq->getTranslationUnit(); - auto linkage = compileRequest->getLinkage(); - auto sink = compileRequest->getSink(); - - auto entryPointName = entryPointReq->getName(); - FuncDecl* entryPointFuncDecl = findFunctionDeclByName(translationUnit->getModule(), entryPointName, sink); - - // Did we find a function declaration in our search? - if(!entryPointFuncDecl) - { - return nullptr; - } + return nullptr; + } - // TODO: it is possible that the entry point was declared with - // profile or target overloading. Is there anything that we need - // to do at this point to filter out declarations that aren't - // relevant to the selected profile for the entry point? + // TODO: it is possible that the entry point was declared with + // profile or target overloading. Is there anything that we need + // to do at this point to filter out declarations that aren't + // relevant to the selected profile for the entry point? - // We found something, and can start doing some basic checking. - // - // If the entry point specifies a stage via a `[shader("...")]` attribute, - // then we might be able to infer a stage for the entry point request if - // it didn't have one, *or* issue a diagnostic if there is a mismatch with the profile. + // We found something, and can start doing some basic checking. + // + // If the entry point specifies a stage via a `[shader("...")]` attribute, + // then we might be able to infer a stage for the entry point request if + // it didn't have one, *or* issue a diagnostic if there is a mismatch with the profile. + + auto entryPointProfile = entryPointReq->getProfile(); + resolveStageOfProfileWithEntryPoint( + entryPointProfile, + linkage->m_optionSet, + linkage->targets, + entryPointFuncDecl, + sink); + // TODO: Should we attach a `[shader(...)]` attribute to an + // entry point that didn't have one, so that we can have + // a more uniform representation in the AST? + + RefPtr entryPoint = + EntryPoint::create(linkage, makeDeclRef(entryPointFuncDecl), entryPointProfile); + + // Now that we've *found* the entry point, it is time to validate + // that it actually meets the constraints for the chosen stage/profile. + // + validateEntryPoint(entryPoint, sink); - auto entryPointProfile = entryPointReq->getProfile(); - resolveStageOfProfileWithEntryPoint(entryPointProfile, linkage->m_optionSet, linkage->targets, entryPointFuncDecl, sink); - // TODO: Should we attach a `[shader(...)]` attribute to an - // entry point that didn't have one, so that we can have - // a more uniform representation in the AST? + return entryPoint; +} - RefPtr entryPoint = EntryPoint::create( - linkage, - makeDeclRef(entryPointFuncDecl), - entryPointProfile); +/// Get the name a variable will use for reflection purposes +Name* getReflectionName(VarDeclBase* varDecl) +{ + if (auto reflectionNameModifier = varDecl->findModifier()) + return reflectionNameModifier->nameAndLoc.name; - // Now that we've *found* the entry point, it is time to validate - // that it actually meets the constraints for the chosen stage/profile. - // - validateEntryPoint(entryPoint, sink); + return varDecl->getName(); +} - return entryPoint; +Type* getParamType(ASTBuilder* astBuilder, DeclRef paramDeclRef) +{ + auto paramType = getType(astBuilder, paramDeclRef); + if (paramDeclRef.getDecl()->findModifier()) + { + auto modifierVal = static_cast(astBuilder->getOrCreate()); + paramType = astBuilder->getModifiedType(paramType, 1, &modifierVal); } + return paramType; +} - /// Get the name a variable will use for reflection purposes - Name* getReflectionName(VarDeclBase* varDecl) - { - if (auto reflectionNameModifier = varDecl->findModifier()) - return reflectionNameModifier->nameAndLoc.name; +void Module::_collectShaderParams() +{ + // We are going to walk the global declarations in the body of the + // module, and use those to build up our lists of: + // + // * Global shader parameters + // * Specialization parameters (both generic and interface/existential) + // * Requirements (`import`ed modules) + // + // For requirements, we want to be careful to only + // add each required module once (in case the same + // module got `import`ed multiple times), so we + // will keep a set of the modules we've already + // seen and processed. + // - return varDecl->getName(); - } + // We need to use a work list to traverse through all global scopes, + // including the top level `moduleDecl` and all the included `FileDecl`s. - Type* getParamType(ASTBuilder* astBuilder, DeclRef paramDeclRef) + List workList; + workList.add(m_moduleDecl); + + HashSet requiredModuleSet; + for (Index i = 0; i < workList.getCount(); i++) { - auto paramType = getType(astBuilder, paramDeclRef); - if (paramDeclRef.getDecl()->findModifier()) + auto moduleDecl = workList[i]; + for (auto globalDecl : moduleDecl->members) { - auto modifierVal = static_cast(astBuilder->getOrCreate()); - paramType = astBuilder->getModifiedType(paramType, 1, &modifierVal); - } - return paramType; - } + if (auto globalVar = as(globalDecl)) + { + // We do not want to consider global variable declarations + // that don't represents shader parameters. This includes + // things like `static` globals and `groupshared` variables. + // + if (!isGlobalShaderParameter(globalVar)) + continue; - void Module::_collectShaderParams() - { - // We are going to walk the global declarations in the body of the - // module, and use those to build up our lists of: - // - // * Global shader parameters - // * Specialization parameters (both generic and interface/existential) - // * Requirements (`import`ed modules) - // - // For requirements, we want to be careful to only - // add each required module once (in case the same - // module got `import`ed multiple times), so we - // will keep a set of the modules we've already - // seen and processed. - // + // At this point we know we have a global shader parameter. - // We need to use a work list to traverse through all global scopes, - // including the top level `moduleDecl` and all the included `FileDecl`s. + ShaderParamInfo shaderParamInfo; + shaderParamInfo.paramDeclRef = makeDeclRef(globalVar); - List workList; - workList.add(m_moduleDecl); + // We need to consider what specialization parameters + // are introduced by this shader parameter. This step + // fills in fields on `shaderParamInfo` so that we + // can assocaite specialization arguments supplied later + // with the correct parameter. + // + _collectExistentialSpecializationParamsForShaderParam( + getLinkage()->getASTBuilder(), + shaderParamInfo, + m_specializationParams, + makeDeclRef(globalVar)); - HashSet requiredModuleSet; - for (Index i = 0; i < workList.getCount(); i++) - { - auto moduleDecl = workList[i]; - for (auto globalDecl : moduleDecl->members) + m_shaderParams.add(shaderParamInfo); + } + else if (auto globalGenericParam = as(globalDecl)) { - if (auto globalVar = as(globalDecl)) - { - // We do not want to consider global variable declarations - // that don't represents shader parameters. This includes - // things like `static` globals and `groupshared` variables. - // - if (!isGlobalShaderParameter(globalVar)) - continue; - - // At this point we know we have a global shader parameter. - - ShaderParamInfo shaderParamInfo; - shaderParamInfo.paramDeclRef = makeDeclRef(globalVar); - - // We need to consider what specialization parameters - // are introduced by this shader parameter. This step - // fills in fields on `shaderParamInfo` so that we - // can assocaite specialization arguments supplied later - // with the correct parameter. - // - _collectExistentialSpecializationParamsForShaderParam( - getLinkage()->getASTBuilder(), - shaderParamInfo, - m_specializationParams, - makeDeclRef(globalVar)); - - m_shaderParams.add(shaderParamInfo); - } - else if (auto globalGenericParam = as(globalDecl)) - { - // A global generic type parameter declaration introduces - // a suitable specialization parameter. - // - SpecializationParam specializationParam; - specializationParam.flavor = SpecializationParam::Flavor::GenericType; - specializationParam.loc = globalGenericParam->loc; - specializationParam.object = globalGenericParam; - m_specializationParams.add(specializationParam); - } - else if (auto globalGenericValueParam = as(globalDecl)) - { - // A global generic type parameter declaration introduces - // a suitable specialization parameter. - // - SpecializationParam specializationParam; - specializationParam.flavor = SpecializationParam::Flavor::GenericValue; - specializationParam.loc = globalGenericValueParam->loc; - specializationParam.object = globalGenericValueParam; - m_specializationParams.add(specializationParam); - } - else if (auto importDecl = as(globalDecl)) - { - // An `import` declaration creates a requirement dependency - // from this module to another module. - // - auto importedModule = getModule(importDecl->importedModuleDecl); - if (!requiredModuleSet.contains(importedModule)) - { - requiredModuleSet.add(importedModule); - m_requirements.add(importedModule); - } - } - else if (auto fileDecl = as(globalDecl)) - { - // If we see a `FileDecl`, we need to recursively look into its - // scope. - workList.add(fileDecl); - } - else if (auto namespaceDecl = as(globalDecl)) + // A global generic type parameter declaration introduces + // a suitable specialization parameter. + // + SpecializationParam specializationParam; + specializationParam.flavor = SpecializationParam::Flavor::GenericType; + specializationParam.loc = globalGenericParam->loc; + specializationParam.object = globalGenericParam; + m_specializationParams.add(specializationParam); + } + else if (auto globalGenericValueParam = as(globalDecl)) + { + // A global generic type parameter declaration introduces + // a suitable specialization parameter. + // + SpecializationParam specializationParam; + specializationParam.flavor = SpecializationParam::Flavor::GenericValue; + specializationParam.loc = globalGenericValueParam->loc; + specializationParam.object = globalGenericValueParam; + m_specializationParams.add(specializationParam); + } + else if (auto importDecl = as(globalDecl)) + { + // An `import` declaration creates a requirement dependency + // from this module to another module. + // + auto importedModule = getModule(importDecl->importedModuleDecl); + if (!requiredModuleSet.contains(importedModule)) { - workList.add(namespaceDecl); + requiredModuleSet.add(importedModule); + m_requirements.add(importedModule); } } + else if (auto fileDecl = as(globalDecl)) + { + // If we see a `FileDecl`, we need to recursively look into its + // scope. + workList.add(fileDecl); + } + else if (auto namespaceDecl = as(globalDecl)) + { + workList.add(namespaceDecl); + } } } +} - Index Module::getRequirementCount() - { - return m_requirements.getCount(); - } - - RefPtr Module::getRequirement(Index index) - { - return m_requirements[index]; - } +Index Module::getRequirementCount() +{ + return m_requirements.getCount(); +} - void Module::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) - { - visitor->visitModule(this, as(specializationInfo)); - } +RefPtr Module::getRequirement(Index index) +{ + return m_requirements[index]; +} +void Module::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) +{ + visitor->visitModule(this, as(specializationInfo)); +} - /// Create a new component type based on `inComponentType`, but with all its requiremetns filled. - RefPtr fillRequirements( - ComponentType* inComponentType) - { - auto linkage = inComponentType->getLinkage(); - // We are going to simplify things by solving the problem iteratively. - // If the current `componentType` has requirements for `A`, `B`, ... etc. - // then we will create a composite of `componentType`, `A`, `B`, ... - // and then see if the resulting composite has any requirements. - // - // This avoids the problem of trying to compute teh transitive closure - // of the requirements relationship (while dealing with deduplication, - // etc.) +/// Create a new component type based on `inComponentType`, but with all its requiremetns filled. +RefPtr fillRequirements(ComponentType* inComponentType) +{ + auto linkage = inComponentType->getLinkage(); - RefPtr componentType = inComponentType; - for(;;) - { - auto requirementCount = componentType->getRequirementCount(); - if(requirementCount == 0) - break; + // We are going to simplify things by solving the problem iteratively. + // If the current `componentType` has requirements for `A`, `B`, ... etc. + // then we will create a composite of `componentType`, `A`, `B`, ... + // and then see if the resulting composite has any requirements. + // + // This avoids the problem of trying to compute teh transitive closure + // of the requirements relationship (while dealing with deduplication, + // etc.) - List> allComponents; - allComponents.add(componentType); + RefPtr componentType = inComponentType; + for (;;) + { + auto requirementCount = componentType->getRequirementCount(); + if (requirementCount == 0) + break; - for(Index rr = 0; rr < requirementCount; ++rr) - { - auto requirement = componentType->getRequirement(rr); - allComponents.add(requirement); - } + List> allComponents; + allComponents.add(componentType); - componentType = CompositeComponentType::create( - linkage, - allComponents); + for (Index rr = 0; rr < requirementCount; ++rr) + { + auto requirement = componentType->getRequirement(rr); + allComponents.add(requirement); } - return componentType; + + componentType = CompositeComponentType::create(linkage, allComponents); } + return componentType; +} + +/// Create a component type to represent the "global scope" of a compile request. +/// +/// This component type will include all the modules and their global +/// parameters from the compile request, but not anything specific +/// to any entry point functions. +/// +/// The layout for this component type will thus represent the things that +/// a user is likely to want to have stay the same across all compiled +/// entry points. +/// +/// The component type that this function creates is unspecialized, in +/// that it doesn't take into account any specialization arguments +/// that might have been supplied as part of the compile request. +/// +RefPtr createUnspecializedGlobalComponentType(FrontEndCompileRequest* compileRequest) +{ + // We want our resulting program to depend on + // all the translation units the user specified, + // even if some of them don't contain entry points + // (this is important for parameter layout/binding). + // + // We also want to ensure that the modules for the + // translation units comes first in the enumerated + // order for dependencies, to match the pre-existing + // compiler behavior (at least for now). + // + auto linkage = compileRequest->getLinkage(); - /// Create a component type to represent the "global scope" of a compile request. - /// - /// This component type will include all the modules and their global - /// parameters from the compile request, but not anything specific - /// to any entry point functions. - /// - /// The layout for this component type will thus represent the things that - /// a user is likely to want to have stay the same across all compiled - /// entry points. - /// - /// The component type that this function creates is unspecialized, in - /// that it doesn't take into account any specialization arguments - /// that might have been supplied as part of the compile request. - /// - RefPtr createUnspecializedGlobalComponentType( - FrontEndCompileRequest* compileRequest) + RefPtr globalComponentType; + if (compileRequest->translationUnits.getCount() == 1) { - // We want our resulting program to depend on - // all the translation units the user specified, - // even if some of them don't contain entry points - // (this is important for parameter layout/binding). - // - // We also want to ensure that the modules for the - // translation units comes first in the enumerated - // order for dependencies, to match the pre-existing - // compiler behavior (at least for now). + // The common case is that a compilation only uses + // a single translation unit, and thus results in + // a single `Module`. We can then use that module + // as the component type that represents the global scope. // - auto linkage = compileRequest->getLinkage(); - - RefPtr globalComponentType; - if(compileRequest->translationUnits.getCount() == 1) - { - // The common case is that a compilation only uses - // a single translation unit, and thus results in - // a single `Module`. We can then use that module - // as the component type that represents the global scope. - // - globalComponentType = compileRequest->translationUnits[0]->getModule(); - } - else + globalComponentType = compileRequest->translationUnits[0]->getModule(); + } + else + { + List> translationUnitComponentTypes; + for (auto tu : compileRequest->translationUnits) { - List> translationUnitComponentTypes; - for( auto tu : compileRequest->translationUnits ) - { - translationUnitComponentTypes.add(tu->getModule()); - } - - globalComponentType = CompositeComponentType::create( - linkage, - translationUnitComponentTypes); + translationUnitComponentTypes.add(tu->getModule()); } - return fillRequirements(globalComponentType); + globalComponentType = + CompositeComponentType::create(linkage, translationUnitComponentTypes); } - void FrontEndCompileRequest::checkEntryPoints() - { - auto linkage = getLinkage(); - SLANG_AST_BUILDER_RAII(linkage->getASTBuilder()); - - auto sink = getSink(); + return fillRequirements(globalComponentType); +} - // The validation of entry points here will be modal, and controlled - // by whether the user specified any entry points directly via - // API or command-line options. - // - // TODO: We may want to make this choice explicit rather than implicit. - // - // First, check if the user requested any entry points explicitly via - // the API or command line. - // - bool anyExplicitEntryPoints = getEntryPointReqCount() != 0; +void FrontEndCompileRequest::checkEntryPoints() +{ + auto linkage = getLinkage(); + SLANG_AST_BUILDER_RAII(linkage->getASTBuilder()); - if( anyExplicitEntryPoints ) - { - // If there were any explicit requests for entry points to be - // checked, then we will *only* check those. - // - for(auto entryPointReq : getEntryPointReqs()) - { - auto entryPoint = findAndValidateEntryPoint( - entryPointReq); - if( entryPoint ) - { - // TODO: We need to implement an explicit policy - // for what should happen if the user specified - // entry points via the command-line (or API), - // but didn't specify any groups (since the current - // compilation API doesn't allow for grouping). - // - entryPointReq->getTranslationUnit()->module->_addEntryPoint(entryPoint); - } - } + auto sink = getSink(); - // TODO: We should consider always processing both categories, - // and just making sure to only check each entry point function - // declaration once... - } - else - { - // Otherwise, scan for any `[shader(...)]` attributes in - // the user's code, and construct `EntryPoint`s to - // represent them. - // - // This ensures that downstream code only has to consider - // the central list of entry point requests, and doesn't - // have to know where they came from. + // The validation of entry points here will be modal, and controlled + // by whether the user specified any entry points directly via + // API or command-line options. + // + // TODO: We may want to make this choice explicit rather than implicit. + // + // First, check if the user requested any entry points explicitly via + // the API or command line. + // + bool anyExplicitEntryPoints = getEntryPointReqCount() != 0; - // TODO: A comprehensive approach here would need to search - // recursively for entry points, because they might appear - // as, e.g., member function of a `struct` type. - // - // For now we'll start with an extremely basic approach that - // should work for typical HLSL code. - // - Index translationUnitCount = translationUnits.getCount(); - for (Index tt = 0; tt < translationUnitCount; ++tt) + if (anyExplicitEntryPoints) + { + // If there were any explicit requests for entry points to be + // checked, then we will *only* check those. + // + for (auto entryPointReq : getEntryPointReqs()) + { + auto entryPoint = findAndValidateEntryPoint(entryPointReq); + if (entryPoint) { - auto translationUnit = translationUnits[tt]; - translationUnit->getModule()->_discoverEntryPoints(sink, this->getLinkage()->targets); + // TODO: We need to implement an explicit policy + // for what should happen if the user specified + // entry points via the command-line (or API), + // but didn't specify any groups (since the current + // compilation API doesn't allow for grouping). + // + entryPointReq->getTranslationUnit()->module->_addEntryPoint(entryPoint); } } + + // TODO: We should consider always processing both categories, + // and just making sure to only check each entry point function + // declaration once... } + else + { + // Otherwise, scan for any `[shader(...)]` attributes in + // the user's code, and construct `EntryPoint`s to + // represent them. + // + // This ensures that downstream code only has to consider + // the central list of entry point requests, and doesn't + // have to know where they came from. + // TODO: A comprehensive approach here would need to search + // recursively for entry points, because they might appear + // as, e.g., member function of a `struct` type. + // + // For now we'll start with an extremely basic approach that + // should work for typical HLSL code. + // + Index translationUnitCount = translationUnits.getCount(); + for (Index tt = 0; tt < translationUnitCount; ++tt) + { + auto translationUnit = translationUnits[tt]; + translationUnit->getModule()->_discoverEntryPoints(sink, this->getLinkage()->targets); + } + } +} - /// Create a component type that represents the global scope for a compile request, - /// along with any entry point functions. - /// - /// The resulting component type will include the global-scope information - /// first, so its layout will be compatible with the result of - /// `createUnspecializedGlobalComponentType`. - /// - /// The new component type will also add on any entry-point functions - /// that were requested and will thus include space for their `uniform` parameters. - /// If multiple entry points were requested then they will be given non-overlapping - /// parameter bindings, consistent with them being used together in - /// a single pipeline state, hit group, etc. - /// - /// The result of this function is unspecialized and doesn't take into - /// account any specialization arguments the user might have supplied. - /// - RefPtr createUnspecializedGlobalAndEntryPointsComponentType( - FrontEndCompileRequest* compileRequest, - List>& outUnspecializedEntryPoints) - { - auto linkage = compileRequest->getLinkage(); - auto globalComponentType = compileRequest->getGlobalComponentType(); +/// Create a component type that represents the global scope for a compile request, +/// along with any entry point functions. +/// +/// The resulting component type will include the global-scope information +/// first, so its layout will be compatible with the result of +/// `createUnspecializedGlobalComponentType`. +/// +/// The new component type will also add on any entry-point functions +/// that were requested and will thus include space for their `uniform` parameters. +/// If multiple entry points were requested then they will be given non-overlapping +/// parameter bindings, consistent with them being used together in +/// a single pipeline state, hit group, etc. +/// +/// The result of this function is unspecialized and doesn't take into +/// account any specialization arguments the user might have supplied. +/// +RefPtr createUnspecializedGlobalAndEntryPointsComponentType( + FrontEndCompileRequest* compileRequest, + List>& outUnspecializedEntryPoints) +{ + auto linkage = compileRequest->getLinkage(); - List> allComponentTypes; - allComponentTypes.add(globalComponentType); + auto globalComponentType = compileRequest->getGlobalComponentType(); - Index translationUnitCount = compileRequest->translationUnits.getCount(); - for(Index tt = 0; tt < translationUnitCount; ++tt) - { - auto translationUnit = compileRequest->translationUnits[tt]; - auto module = translationUnit->getModule(); + List> allComponentTypes; + allComponentTypes.add(globalComponentType); - for(auto entryPoint : module->getEntryPoints() ) - { - outUnspecializedEntryPoints.add(entryPoint); - allComponentTypes.add(entryPoint); - } - } + Index translationUnitCount = compileRequest->translationUnits.getCount(); + for (Index tt = 0; tt < translationUnitCount; ++tt) + { + auto translationUnit = compileRequest->translationUnits[tt]; + auto module = translationUnit->getModule(); - // Also consider entry points that were introduced via adding - // a library reference... - // - for( auto extraEntryPoint : compileRequest->m_extraEntryPoints ) + for (auto entryPoint : module->getEntryPoints()) { - auto entryPoint = EntryPoint::createDummyForDeserialize( - linkage, - extraEntryPoint.name, - extraEntryPoint.profile, - extraEntryPoint.mangledName); + outUnspecializedEntryPoints.add(entryPoint); allComponentTypes.add(entryPoint); } + } - if(allComponentTypes.getCount() > 1) - { - auto composite = CompositeComponentType::create( - linkage, - allComponentTypes); - return composite; - } - else - { - return globalComponentType; - } + // Also consider entry points that were introduced via adding + // a library reference... + // + for (auto extraEntryPoint : compileRequest->m_extraEntryPoints) + { + auto entryPoint = EntryPoint::createDummyForDeserialize( + linkage, + extraEntryPoint.name, + extraEntryPoint.profile, + extraEntryPoint.mangledName); + allComponentTypes.add(entryPoint); } - RefPtr Module::_validateSpecializationArgsImpl( - SpecializationArg const* args, - Index argCount, - DiagnosticSink* sink) + if (allComponentTypes.getCount() > 1) + { + auto composite = CompositeComponentType::create(linkage, allComponentTypes); + return composite; + } + else { - SLANG_ASSERT(argCount == getSpecializationParamCount()); + return globalComponentType; + } +} + +RefPtr Module::_validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) +{ + SLANG_ASSERT(argCount == getSpecializationParamCount()); - SharedSemanticsContext semanticsContext(getLinkage(), this, sink); - SemanticsVisitor visitor(&semanticsContext); + SharedSemanticsContext semanticsContext(getLinkage(), this, sink); + SemanticsVisitor visitor(&semanticsContext); - RefPtr specializationInfo = new Module::ModuleSpecializationInfo(); + RefPtr specializationInfo = + new Module::ModuleSpecializationInfo(); - for( Index ii = 0; ii < argCount; ++ii ) - { - auto& arg = args[ii]; - auto& param = m_specializationParams[ii]; + for (Index ii = 0; ii < argCount; ++ii) + { + auto& arg = args[ii]; + auto& param = m_specializationParams[ii]; - switch( param.flavor ) + switch (param.flavor) + { + case SpecializationParam::Flavor::GenericType: { - case SpecializationParam::Flavor::GenericType: + auto genericTypeParamDecl = as(param.object); + SLANG_ASSERT(genericTypeParamDecl); + + Type* argType = as(arg.val); + if (!argType) { - auto genericTypeParamDecl = as(param.object); - SLANG_ASSERT(genericTypeParamDecl); + sink->diagnose( + param.loc, + Diagnostics::expectedTypeForSpecializationArg, + genericTypeParamDecl); + argType = getLinkage()->getASTBuilder()->getErrorType(); + } - Type* argType = as(arg.val); - if(!argType) - { - sink->diagnose(param.loc, Diagnostics::expectedTypeForSpecializationArg, genericTypeParamDecl); - argType = getLinkage()->getASTBuilder()->getErrorType(); - } + // TODO: There is a serious flaw to this checking logic if we ever have cases where + // the constraints on one `type_param` can depend on another `type_param`, e.g.: + // + // type_param A; + // type_param B : ISidekick; + // + // In that case, if a user tries to set `B` to `Robin` and `Robin` conforms to + // `ISidekick`, then the compiler needs to know whether `A` is being + // set to `Batman` to know whether the setting for `B` is valid. In this limit + // the constraints can be mutually recursive (so `A : IMentor`). + // + // The only way to check things correctly is to validate each conformance under + // a set of assumptions (substitutions) that includes all the type substitutions, + // and possibly also all the other constraints *except* the one to be validated. + // + // We will punt on this for now, and just check each constraint in isolation. - // TODO: There is a serious flaw to this checking logic if we ever have cases where - // the constraints on one `type_param` can depend on another `type_param`, e.g.: - // - // type_param A; - // type_param B : ISidekick; - // - // In that case, if a user tries to set `B` to `Robin` and `Robin` conforms to - // `ISidekick`, then the compiler needs to know whether `A` is being - // set to `Batman` to know whether the setting for `B` is valid. In this limit - // the constraints can be mutually recursive (so `A : IMentor`). - // - // The only way to check things correctly is to validate each conformance under - // a set of assumptions (substitutions) that includes all the type substitutions, - // and possibly also all the other constraints *except* the one to be validated. - // - // We will punt on this for now, and just check each constraint in isolation. - - // As a quick sanity check, see if the argument that is being supplied for a - // global generic type parameter is a reference to *another* global generic - // type parameter, since that should always be an error. - // - if( auto argDeclRefType = as(argType) ) + // As a quick sanity check, see if the argument that is being supplied for a + // global generic type parameter is a reference to *another* global generic + // type parameter, since that should always be an error. + // + if (auto argDeclRefType = as(argType)) + { + auto argDeclRef = argDeclRefType->getDeclRef(); + if (auto argGenericParamDeclRef = argDeclRef.as()) { - auto argDeclRef = argDeclRefType->getDeclRef(); - if(auto argGenericParamDeclRef = argDeclRef.as()) + if (argGenericParamDeclRef.getDecl() == genericTypeParamDecl) { - if(argGenericParamDeclRef.getDecl() == genericTypeParamDecl) - { - // We are trying to specialize a generic parameter using itself. - sink->diagnose(genericTypeParamDecl, - Diagnostics::cannotSpecializeGlobalGenericToItself, - genericTypeParamDecl->getName()); - continue; - } - else - { - // We are trying to specialize a generic parameter using a *different* - // global generic type parameter. - sink->diagnose(genericTypeParamDecl, - Diagnostics::cannotSpecializeGlobalGenericToAnotherGenericParam, - genericTypeParamDecl->getName(), - argGenericParamDeclRef.getName()); - continue; - } + // We are trying to specialize a generic parameter using itself. + sink->diagnose( + genericTypeParamDecl, + Diagnostics::cannotSpecializeGlobalGenericToItself, + genericTypeParamDecl->getName()); + continue; } - } - - ModuleSpecializationInfo::GenericArgInfo genericArgInfo; - genericArgInfo.paramDecl = genericTypeParamDecl; - genericArgInfo.argVal = argType; - specializationInfo->genericArgs.add(genericArgInfo); - - // Walk through the declared constraints for the parameter, - // and check that the argument actually satisfies them. - for(auto constraintDecl : genericTypeParamDecl->getMembersOfType()) - { - // Get the type that the constraint is enforcing conformance to - auto interfaceType = getSup(getLinkage()->getASTBuilder(), DeclRef(constraintDecl)); - - // Use our semantic-checking logic to search for a witness to the required conformance - auto witness = visitor.isSubtype(argType, interfaceType, IsSubTypeOptions::None); - if(!witness) + else { - // If no witness was found, then we will be unable to satisfy - // the conformances required. - sink->diagnose(genericTypeParamDecl, - Diagnostics::typeArgumentForGenericParameterDoesNotConformToInterface, - argType, - genericTypeParamDecl->nameAndLoc.name, - interfaceType); + // We are trying to specialize a generic parameter using a *different* + // global generic type parameter. + sink->diagnose( + genericTypeParamDecl, + Diagnostics::cannotSpecializeGlobalGenericToAnotherGenericParam, + genericTypeParamDecl->getName(), + argGenericParamDeclRef.getName()); + continue; } - - ModuleSpecializationInfo::GenericArgInfo constraintArgInfo; - constraintArgInfo.paramDecl = constraintDecl; - constraintArgInfo.argVal = witness; - specializationInfo->genericArgs.add(constraintArgInfo); } } - break; - case SpecializationParam::Flavor::ExistentialType: - { - auto interfaceType = as(param.object); - SLANG_ASSERT(interfaceType); + ModuleSpecializationInfo::GenericArgInfo genericArgInfo; + genericArgInfo.paramDecl = genericTypeParamDecl; + genericArgInfo.argVal = argType; + specializationInfo->genericArgs.add(genericArgInfo); - Type* argType = as(arg.val); - if(!argType) - { - sink->diagnose(param.loc, Diagnostics::expectedTypeForSpecializationArg, interfaceType); - argType = getLinkage()->getASTBuilder()->getErrorType(); - } + // Walk through the declared constraints for the parameter, + // and check that the argument actually satisfies them. + for (auto constraintDecl : + genericTypeParamDecl->getMembersOfType()) + { + // Get the type that the constraint is enforcing conformance to + auto interfaceType = getSup( + getLinkage()->getASTBuilder(), + DeclRef(constraintDecl)); - auto witness = visitor.isSubtype(argType, interfaceType, IsSubTypeOptions::None); + // Use our semantic-checking logic to search for a witness to the required + // conformance + auto witness = + visitor.isSubtype(argType, interfaceType, IsSubTypeOptions::None); if (!witness) { - // If no witness was found, then we will be unable to satisfy - // the conformances required. - sink->diagnose(SourceLoc(), - Diagnostics::typeArgumentDoesNotConformToInterface, - argType, - interfaceType); + // If no witness was found, then we will be unable to satisfy + // the conformances required. + sink->diagnose( + genericTypeParamDecl, + Diagnostics::typeArgumentForGenericParameterDoesNotConformToInterface, + argType, + genericTypeParamDecl->nameAndLoc.name, + interfaceType); } - ExpandedSpecializationArg expandedArg; - expandedArg.val = argType; - expandedArg.witness = witness; + ModuleSpecializationInfo::GenericArgInfo constraintArgInfo; + constraintArgInfo.paramDecl = constraintDecl; + constraintArgInfo.argVal = witness; + specializationInfo->genericArgs.add(constraintArgInfo); + } + } + break; + + case SpecializationParam::Flavor::ExistentialType: + { + auto interfaceType = as(param.object); + SLANG_ASSERT(interfaceType); - specializationInfo->existentialArgs.add(expandedArg); + Type* argType = as(arg.val); + if (!argType) + { + sink->diagnose( + param.loc, + Diagnostics::expectedTypeForSpecializationArg, + interfaceType); + argType = getLinkage()->getASTBuilder()->getErrorType(); } - break; - case SpecializationParam::Flavor::GenericValue: + auto witness = visitor.isSubtype(argType, interfaceType, IsSubTypeOptions::None); + if (!witness) { - auto paramDecl = as(param.object); - SLANG_ASSERT(paramDecl); + // If no witness was found, then we will be unable to satisfy + // the conformances required. + sink->diagnose( + SourceLoc(), + Diagnostics::typeArgumentDoesNotConformToInterface, + argType, + interfaceType); + } - // Now we need to check that the argument `Val` has the - // appropriate type expected by the parameter. + ExpandedSpecializationArg expandedArg; + expandedArg.val = argType; + expandedArg.witness = witness; - IntVal* intVal = as(arg.val); - if(!intVal) - { - sink->diagnose(param.loc, Diagnostics::expectedValueOfTypeForSpecializationArg, paramDecl->getType(), paramDecl); - intVal = getLinkage()->getASTBuilder()->getIntVal(m_astBuilder->getIntType(), 0); - } + specializationInfo->existentialArgs.add(expandedArg); + } + break; + + case SpecializationParam::Flavor::GenericValue: + { + auto paramDecl = as(param.object); + SLANG_ASSERT(paramDecl); - ModuleSpecializationInfo::GenericArgInfo expandedArg; - expandedArg.paramDecl = paramDecl; - expandedArg.argVal = intVal; + // Now we need to check that the argument `Val` has the + // appropriate type expected by the parameter. - specializationInfo->genericArgs.add(expandedArg); + IntVal* intVal = as(arg.val); + if (!intVal) + { + sink->diagnose( + param.loc, + Diagnostics::expectedValueOfTypeForSpecializationArg, + paramDecl->getType(), + paramDecl); + intVal = + getLinkage()->getASTBuilder()->getIntVal(m_astBuilder->getIntType(), 0); } - break; - default: - SLANG_UNEXPECTED("unhandled specialization parameter flavor"); + ModuleSpecializationInfo::GenericArgInfo expandedArg; + expandedArg.paramDecl = paramDecl; + expandedArg.argVal = intVal; + + specializationInfo->genericArgs.add(expandedArg); } - } + break; - return specializationInfo; + default: SLANG_UNEXPECTED("unhandled specialization parameter flavor"); + } } + return specializationInfo; +} + - static void _extractSpecializationArgs( - ComponentType* componentType, - List const& argExprs, - List& outArgs, - DiagnosticSink* sink) - { - auto linkage = componentType->getLinkage(); +static void _extractSpecializationArgs( + ComponentType* componentType, + List const& argExprs, + List& outArgs, + DiagnosticSink* sink) +{ + auto linkage = componentType->getLinkage(); - SharedSemanticsContext semanticsContext(linkage, nullptr, sink); - SemanticsVisitor semanticsVisitor(&semanticsContext); + SharedSemanticsContext semanticsContext(linkage, nullptr, sink); + SemanticsVisitor semanticsVisitor(&semanticsContext); - auto argCount = argExprs.getCount(); - for(Index ii = 0; ii < argCount; ++ii ) - { - auto argExpr = argExprs[ii]; + auto argCount = argExprs.getCount(); + for (Index ii = 0; ii < argCount; ++ii) + { + auto argExpr = argExprs[ii]; - SpecializationArg arg; - arg.val = semanticsVisitor.ExtractGenericArgVal(argExpr); - outArgs.add(arg); - } + SpecializationArg arg; + arg.val = semanticsVisitor.ExtractGenericArgVal(argExpr); + outArgs.add(arg); } +} - RefPtr EntryPoint::_validateSpecializationArgsImpl( - SpecializationArg const* inArgs, - Index inArgCount, - DiagnosticSink* sink) - { - auto args = inArgs; - auto argCount = inArgCount; +RefPtr EntryPoint::_validateSpecializationArgsImpl( + SpecializationArg const* inArgs, + Index inArgCount, + DiagnosticSink* sink) +{ + auto args = inArgs; + auto argCount = inArgCount; - SharedSemanticsContext sharedSemanticsContext(getLinkage(), nullptr, sink); - SemanticsVisitor visitor(&sharedSemanticsContext); + SharedSemanticsContext sharedSemanticsContext(getLinkage(), nullptr, sink); + SemanticsVisitor visitor(&sharedSemanticsContext); - // The first N arguments will be for the explicit generic parameters - // of the entry point (if it has any). - // - auto genericSpecializationParamCount = getGenericSpecializationParamCount(); - SLANG_ASSERT(argCount >= genericSpecializationParamCount); + // The first N arguments will be for the explicit generic parameters + // of the entry point (if it has any). + // + auto genericSpecializationParamCount = getGenericSpecializationParamCount(); + SLANG_ASSERT(argCount >= genericSpecializationParamCount); - RefPtr info = new EntryPointSpecializationInfo(); + RefPtr info = new EntryPointSpecializationInfo(); - DeclRef specializedFuncDeclRef = m_funcDeclRef; - if(genericSpecializationParamCount) - { - // We need to construct a generic application and use - // the semantic checking machinery to expand out - // the rest of the arguments via inference... + DeclRef specializedFuncDeclRef = m_funcDeclRef; + if (genericSpecializationParamCount) + { + // We need to construct a generic application and use + // the semantic checking machinery to expand out + // the rest of the arguments via inference... - auto genericDeclRef = m_funcDeclRef.getParent().as(); - SLANG_ASSERT(genericDeclRef); // otherwise we wouldn't have generic parameters + auto genericDeclRef = m_funcDeclRef.getParent().as(); + SLANG_ASSERT(genericDeclRef); // otherwise we wouldn't have generic parameters - List genericArgs; + List genericArgs; - for(Index ii = 0; ii < genericSpecializationParamCount; ++ii) + for (Index ii = 0; ii < genericSpecializationParamCount; ++ii) + { + auto specializationArg = args[ii]; + genericArgs.add(specializationArg.val); + } + auto astBuilder = getLinkage()->getASTBuilder(); + for (auto constraintDecl : getMembersOfType( + getLinkage()->getASTBuilder(), + DeclRef(genericDeclRef))) + { + DeclRef constraintDeclRef = + astBuilder->getDirectDeclRef(constraintDecl.getDecl()); + int argIndex = -1; + int ii = 0; + + // Find the generic parameter type (T) that this constraint (T:IFoo) is applying to. + auto genericParamType = getSub(astBuilder, constraintDeclRef); + auto genParamDeclRefType = as(genericParamType); + if (!genParamDeclRefType) { - auto specializationArg = args[ii]; - genericArgs.add(specializationArg.val); + continue; } - auto astBuilder = getLinkage()->getASTBuilder(); - for (auto constraintDecl : getMembersOfType( - getLinkage()->getASTBuilder(), DeclRef(genericDeclRef))) - { - DeclRef constraintDeclRef = astBuilder->getDirectDeclRef(constraintDecl.getDecl()); - int argIndex = -1; - int ii = 0; - - // Find the generic parameter type (T) that this constraint (T:IFoo) is applying to. - auto genericParamType = getSub(astBuilder, constraintDeclRef); - auto genParamDeclRefType = as(genericParamType); - if (!genParamDeclRefType) - { - continue; - } - auto genParamDeclRef = genParamDeclRefType->getDeclRef(); - - // Find the generic argument index of the corresponding generic parameter type in the - // generic parameter set. - // - for (auto member : genericDeclRef.getDecl()->getMembersOfType()) - { - if (member == genParamDeclRef.getDecl()) - { - argIndex = ii; - break; - } - ii++; - } - if (argIndex == -1) - { - SLANG_ASSERT(!"generic parameter not found in generic decl"); - continue; - } - auto sub = as(args[argIndex].val); - if (!sub) - { - sink->diagnose(constraintDecl, Diagnostics::expectedTypeForSpecializationArg, argIndex); - continue; - } + auto genParamDeclRef = genParamDeclRefType->getDeclRef(); - auto sup = getSup(astBuilder, constraintDeclRef); - auto subTypeWitness = visitor.isSubtype(sub, sup, IsSubTypeOptions::None); - if(subTypeWitness) - { - genericArgs.add(subTypeWitness); - } - else + // Find the generic argument index of the corresponding generic parameter type in the + // generic parameter set. + // + for (auto member : genericDeclRef.getDecl()->getMembersOfType()) + { + if (member == genParamDeclRef.getDecl()) { - // TODO: diagnose a problem here - sink->diagnose(constraintDecl, Diagnostics::typeArgumentDoesNotConformToInterface, sub, sup); - continue; + argIndex = ii; + break; } + ii++; } - - specializedFuncDeclRef = getLinkage()->getASTBuilder()->getGenericAppDeclRef(genericDeclRef, genericArgs.getArrayView()).as(); - SLANG_ASSERT(specializedFuncDeclRef); - } - - info->specializedFuncDeclRef = specializedFuncDeclRef; - - // Once the generic parameters (if any) have been dealt with, - // any remaining specialization arguments are for existential/interface - // specialization parameters, attached to the value parameters - // of the entry point. - // - args += genericSpecializationParamCount; - argCount -= genericSpecializationParamCount; - - auto existentialSpecializationParamCount = getExistentialSpecializationParamCount(); - SLANG_ASSERT(argCount == existentialSpecializationParamCount); - - for( Index ii = 0; ii < existentialSpecializationParamCount; ++ii ) - { - auto& param = m_existentialSpecializationParams[ii]; - auto& specializationArg = args[ii]; - - // TODO: We need to handle all the cases of "flavor" for the `param`s (not just types) - - auto paramType = as(param.object); - auto argType = as(specializationArg.val); - - auto witness = visitor.isSubtype(argType, paramType, IsSubTypeOptions::None); - if (!witness) + if (argIndex == -1) + { + SLANG_ASSERT(!"generic parameter not found in generic decl"); + continue; + } + auto sub = as(args[argIndex].val); + if (!sub) { - // If no witness was found, then we will be unable to satisfy - // the conformances required. - sink->diagnose(SourceLoc(), Diagnostics::typeArgumentDoesNotConformToInterface, argType, paramType); + sink->diagnose( + constraintDecl, + Diagnostics::expectedTypeForSpecializationArg, + argIndex); continue; } - ExpandedSpecializationArg expandedArg; - expandedArg.val = specializationArg.val; - expandedArg.witness = witness; - info->existentialSpecializationArgs.add(expandedArg); + auto sup = getSup(astBuilder, constraintDeclRef); + auto subTypeWitness = visitor.isSubtype(sub, sup, IsSubTypeOptions::None); + if (subTypeWitness) + { + genericArgs.add(subTypeWitness); + } + else + { + // TODO: diagnose a problem here + sink->diagnose( + constraintDecl, + Diagnostics::typeArgumentDoesNotConformToInterface, + sub, + sup); + continue; + } } - return info; + specializedFuncDeclRef = + getLinkage() + ->getASTBuilder() + ->getGenericAppDeclRef(genericDeclRef, genericArgs.getArrayView()) + .as(); + SLANG_ASSERT(specializedFuncDeclRef); } - /// Create a specialization an existing entry point based on specialization argument expressions. - RefPtr createSpecializedEntryPoint( - EntryPoint* unspecializedEntryPoint, - List const& argExprs, - DiagnosticSink* sink) - { - // We need to convert all of the `Expr` arguments - // into `SpecializationArg`s, so that we can bottleneck - // through the shared logic. - // - List args; - _extractSpecializationArgs(unspecializedEntryPoint, argExprs, args, sink); - if(sink->getErrorCount()) - return nullptr; - - return ((ComponentType*) unspecializedEntryPoint)->specialize( - args.getBuffer(), - args.getCount(), - sink); - } + info->specializedFuncDeclRef = specializedFuncDeclRef; + + // Once the generic parameters (if any) have been dealt with, + // any remaining specialization arguments are for existential/interface + // specialization parameters, attached to the value parameters + // of the entry point. + // + args += genericSpecializationParamCount; + argCount -= genericSpecializationParamCount; + + auto existentialSpecializationParamCount = getExistentialSpecializationParamCount(); + SLANG_ASSERT(argCount == existentialSpecializationParamCount); - Scope* ComponentType::_getOrCreateScopeForLegacyLookup(ASTBuilder* astBuilder) + for (Index ii = 0; ii < existentialSpecializationParamCount; ++ii) { - // The shape of this logic is dictated by the legacy - // behavior for name-based lookup/parsing of types - // specified via the API or command line. - // - // We begin with a dummy scope that has as its parent - // the scope that provides the "base" langauge - // definitions (that scope is necessary because - // it defines keywords like `true` and `false`). - // - if (m_lookupScope) - return m_lookupScope; + auto& param = m_existentialSpecializationParams[ii]; + auto& specializationArg = args[ii]; - Scope* scope = astBuilder->create(); - scope->parent = getLinkage()->getSessionImpl()->slangLanguageScope; - // - // Next, the scope needs to include all of the - // modules in the program as peers, as if they - // were `import`ed into the scope. - // - for( auto module : getModuleDependencies() ) - { - for (auto srcScope = module->getModuleDecl()->ownedScope; srcScope; srcScope = srcScope->nextSibling) - { - if (srcScope->containerDecl != module->getModuleDecl() && srcScope->containerDecl->parentDecl != module->getModuleDecl()) - continue; // Skip scopes that is not part of current module. + // TODO: We need to handle all the cases of "flavor" for the `param`s (not just types) - Scope* moduleScope = astBuilder->create(); - moduleScope->containerDecl = srcScope->containerDecl; + auto paramType = as(param.object); + auto argType = as(specializationArg.val); - moduleScope->nextSibling = scope->nextSibling; - scope->nextSibling = moduleScope; - } + auto witness = visitor.isSubtype(argType, paramType, IsSubTypeOptions::None); + if (!witness) + { + // If no witness was found, then we will be unable to satisfy + // the conformances required. + sink->diagnose( + SourceLoc(), + Diagnostics::typeArgumentDoesNotConformToInterface, + argType, + paramType); + continue; } - m_lookupScope = scope; - return scope; + + ExpandedSpecializationArg expandedArg; + expandedArg.val = specializationArg.val; + expandedArg.witness = witness; + info->existentialSpecializationArgs.add(expandedArg); } - /// Parse an array of strings as specialization arguments. - /// - /// Names in the strings will be parsed in the context of - /// the code loaded into the given compile request. - /// - void parseSpecializationArgStrings( - EndToEndCompileRequest* endToEndReq, - List const& genericArgStrings, - List& outGenericArgs) - { - auto unspecialiedProgram = endToEndReq->getUnspecializedGlobalComponentType(); + return info; +} - // TODO(JS): - // - // We create the scopes on the linkages ASTBuilder. We might want to create a temporary ASTBuilder, - // and let that memory get freed, but is like this because it's not clear if the scopes in ASTNode members - // will dangle if we do. - Scope* scope = unspecialiedProgram->_getOrCreateScopeForLegacyLookup(endToEndReq->getLinkage()->getASTBuilder()); +/// Create a specialization an existing entry point based on specialization argument expressions. +RefPtr createSpecializedEntryPoint( + EntryPoint* unspecializedEntryPoint, + List const& argExprs, + DiagnosticSink* sink) +{ + // We need to convert all of the `Expr` arguments + // into `SpecializationArg`s, so that we can bottleneck + // through the shared logic. + // + List args; + _extractSpecializationArgs(unspecializedEntryPoint, argExprs, args, sink); + if (sink->getErrorCount()) + return nullptr; - // We are going to do some semantic checking, so we need to - // set up a `SemanticsVistitor` that we can use. - // - auto linkage = endToEndReq->getLinkage(); - auto sink = endToEndReq->getSink(); + return ((ComponentType*)unspecializedEntryPoint) + ->specialize(args.getBuffer(), args.getCount(), sink); +} - SharedSemanticsContext sharedSemanticsContext( - linkage, - nullptr, - sink); - SemanticsVisitor semantics(&sharedSemanticsContext); +Scope* ComponentType::_getOrCreateScopeForLegacyLookup(ASTBuilder* astBuilder) +{ + // The shape of this logic is dictated by the legacy + // behavior for name-based lookup/parsing of types + // specified via the API or command line. + // + // We begin with a dummy scope that has as its parent + // the scope that provides the "base" langauge + // definitions (that scope is necessary because + // it defines keywords like `true` and `false`). + // + if (m_lookupScope) + return m_lookupScope; - // We will be looping over the generic argument strings - // that the user provided via the API (or command line), - // and parsing+checking each into an `Expr`. - // - // This loop will *not* handle coercing the arguments - // to be types. - // - for(auto name : genericArgStrings) + Scope* scope = astBuilder->create(); + scope->parent = getLinkage()->getSessionImpl()->slangLanguageScope; + // + // Next, the scope needs to include all of the + // modules in the program as peers, as if they + // were `import`ed into the scope. + // + for (auto module : getModuleDependencies()) + { + for (auto srcScope = module->getModuleDecl()->ownedScope; srcScope; + srcScope = srcScope->nextSibling) { - Expr* argExpr = linkage->parseTermString(name, scope); - argExpr = semantics.CheckTerm(argExpr); + if (srcScope->containerDecl != module->getModuleDecl() && + srcScope->containerDecl->parentDecl != module->getModuleDecl()) + continue; // Skip scopes that is not part of current module. - if(!argExpr) - { - sink->diagnose(SourceLoc(), Diagnostics::internalCompilerError, "couldn't parse specialization argument"); - return; - } + Scope* moduleScope = astBuilder->create(); + moduleScope->containerDecl = srcScope->containerDecl; - outGenericArgs.add(argExpr); + moduleScope->nextSibling = scope->nextSibling; + scope->nextSibling = moduleScope; } } + m_lookupScope = scope; + return scope; +} - Type* Linkage::specializeType( - Type* unspecializedType, - Int argCount, - Type* const* args, - DiagnosticSink* sink) - { - SLANG_ASSERT(unspecializedType); - - // TODO: We should cache and re-use specialized types - // when the exact same arguments are provided again later. +/// Parse an array of strings as specialization arguments. +/// +/// Names in the strings will be parsed in the context of +/// the code loaded into the given compile request. +/// +void parseSpecializationArgStrings( + EndToEndCompileRequest* endToEndReq, + List const& genericArgStrings, + List& outGenericArgs) +{ + auto unspecialiedProgram = endToEndReq->getUnspecializedGlobalComponentType(); - SharedSemanticsContext sharedSemanticsContext(this, nullptr, sink); - SemanticsVisitor visitor(&sharedSemanticsContext); + // TODO(JS): + // + // We create the scopes on the linkages ASTBuilder. We might want to create a temporary + // ASTBuilder, and let that memory get freed, but is like this because it's not clear if the + // scopes in ASTNode members will dangle if we do. + Scope* scope = unspecialiedProgram->_getOrCreateScopeForLegacyLookup( + endToEndReq->getLinkage()->getASTBuilder()); + + // We are going to do some semantic checking, so we need to + // set up a `SemanticsVistitor` that we can use. + // + auto linkage = endToEndReq->getLinkage(); + auto sink = endToEndReq->getSink(); - SpecializationParams specializationParams; - _collectExistentialSpecializationParamsRec(getASTBuilder(), specializationParams, unspecializedType, SourceLoc()); + SharedSemanticsContext sharedSemanticsContext(linkage, nullptr, sink); + SemanticsVisitor semantics(&sharedSemanticsContext); - assert(specializationParams.getCount() == argCount); + // We will be looping over the generic argument strings + // that the user provided via the API (or command line), + // and parsing+checking each into an `Expr`. + // + // This loop will *not* handle coercing the arguments + // to be types. + // + for (auto name : genericArgStrings) + { + Expr* argExpr = linkage->parseTermString(name, scope); + argExpr = semantics.CheckTerm(argExpr); - ExpandedSpecializationArgs specializationArgs; - for( Int aa = 0; aa < argCount; ++aa ) + if (!argExpr) { - auto paramType = as(specializationParams[aa].object); - auto argType = args[aa]; - - ExpandedSpecializationArg arg; - arg.val = argType; - arg.witness = visitor.isSubtype(argType, paramType, IsSubTypeOptions::None); - specializationArgs.add(arg); + sink->diagnose( + SourceLoc(), + Diagnostics::internalCompilerError, + "couldn't parse specialization argument"); + return; } - ExistentialSpecializedType* specializedType = m_astBuilder->getOrCreate( - unspecializedType, specializationArgs); + outGenericArgs.add(argExpr); + } +} + +Type* Linkage::specializeType( + Type* unspecializedType, + Int argCount, + Type* const* args, + DiagnosticSink* sink) +{ + SLANG_ASSERT(unspecializedType); - m_specializedTypes.add(specializedType); + // TODO: We should cache and re-use specialized types + // when the exact same arguments are provided again later. - return specializedType; - } + SharedSemanticsContext sharedSemanticsContext(this, nullptr, sink); + SemanticsVisitor visitor(&sharedSemanticsContext); + + SpecializationParams specializationParams; + _collectExistentialSpecializationParamsRec( + getASTBuilder(), + specializationParams, + unspecializedType, + SourceLoc()); + + assert(specializationParams.getCount() == argCount); - /// Shared implementation logic for the `_createSpecializedProgram*` entry points. - static RefPtr _createSpecializedProgramImpl( - Linkage* linkage, - ComponentType* unspecializedProgram, - List const& specializationArgExprs, - DiagnosticSink* sink) + ExpandedSpecializationArgs specializationArgs; + for (Int aa = 0; aa < argCount; ++aa) { - // If there are no specialization arguments, - // then the the result of specialization should - // be the same as the input. - // - auto specializationArgCount = specializationArgExprs.getCount(); - if( specializationArgCount == 0 ) - { - return unspecializedProgram; - } + auto paramType = as(specializationParams[aa].object); + auto argType = args[aa]; - auto specializationParamCount = unspecializedProgram->getSpecializationParamCount(); - if(specializationArgCount != specializationParamCount ) - { - sink->diagnose(SourceLoc(), Diagnostics::mismatchSpecializationArguments, - specializationParamCount, - specializationArgCount); - return nullptr; - } + ExpandedSpecializationArg arg; + arg.val = argType; + arg.witness = visitor.isSubtype(argType, paramType, IsSubTypeOptions::None); + specializationArgs.add(arg); + } - // We have an appropriate number of arguments for the global specialization parameters, - // and now we need to check that the arguments conform to the declared constraints. - // - SharedSemanticsContext visitor(linkage, nullptr, sink); + ExistentialSpecializedType* specializedType = + m_astBuilder->getOrCreate( + unspecializedType, + specializationArgs); - List specializationArgs; - _extractSpecializationArgs(unspecializedProgram, specializationArgExprs, specializationArgs, sink); - if(sink->getErrorCount()) - return nullptr; + m_specializedTypes.add(specializedType); - auto specializedProgram = unspecializedProgram->specialize( - specializationArgs.getBuffer(), - specializationArgs.getCount(), - sink); + return specializedType; +} - return specializedProgram; +/// Shared implementation logic for the `_createSpecializedProgram*` entry points. +static RefPtr _createSpecializedProgramImpl( + Linkage* linkage, + ComponentType* unspecializedProgram, + List const& specializationArgExprs, + DiagnosticSink* sink) +{ + // If there are no specialization arguments, + // then the the result of specialization should + // be the same as the input. + // + auto specializationArgCount = specializationArgExprs.getCount(); + if (specializationArgCount == 0) + { + return unspecializedProgram; } - /// Specialize an entry point that was checked by the front-end, based on specialization arguments. - /// - /// If the end-to-end compile request included specialization argument strings - /// for this entry point, then they will be parsed, checked, and used - /// as arguments to the generic entry point. - /// - /// Returns a specialized entry point if everything worked as expected. - /// Returns null and diagnoses errors if anything goes wrong. - /// - RefPtr createSpecializedEntryPoint( - EndToEndCompileRequest* endToEndReq, - EntryPoint* unspecializedEntryPoint, - EndToEndCompileRequest::EntryPointInfo const& entryPointInfo) + auto specializationParamCount = unspecializedProgram->getSpecializationParamCount(); + if (specializationArgCount != specializationParamCount) { - auto sink = endToEndReq->getSink(); + sink->diagnose( + SourceLoc(), + Diagnostics::mismatchSpecializationArguments, + specializationParamCount, + specializationArgCount); + return nullptr; + } - // If the user specified generic arguments for the entry point, - // then we will need to parse the arguments first. - // - List specializationArgExprs; - parseSpecializationArgStrings( - endToEndReq, - entryPointInfo.specializationArgStrings, - specializationArgExprs); - - // Next we specialize the entry point function given the parsed - // generic argument expressions. - // - auto entryPoint = createSpecializedEntryPoint( - unspecializedEntryPoint, - specializationArgExprs, - sink); + // We have an appropriate number of arguments for the global specialization parameters, + // and now we need to check that the arguments conform to the declared constraints. + // + SharedSemanticsContext visitor(linkage, nullptr, sink); + + List specializationArgs; + _extractSpecializationArgs( + unspecializedProgram, + specializationArgExprs, + specializationArgs, + sink); + if (sink->getErrorCount()) + return nullptr; + + auto specializedProgram = unspecializedProgram->specialize( + specializationArgs.getBuffer(), + specializationArgs.getCount(), + sink); + + return specializedProgram; +} - return entryPoint; - } +/// Specialize an entry point that was checked by the front-end, based on specialization arguments. +/// +/// If the end-to-end compile request included specialization argument strings +/// for this entry point, then they will be parsed, checked, and used +/// as arguments to the generic entry point. +/// +/// Returns a specialized entry point if everything worked as expected. +/// Returns null and diagnoses errors if anything goes wrong. +/// +RefPtr createSpecializedEntryPoint( + EndToEndCompileRequest* endToEndReq, + EntryPoint* unspecializedEntryPoint, + EndToEndCompileRequest::EntryPointInfo const& entryPointInfo) +{ + auto sink = endToEndReq->getSink(); - /// Create a specialized component type for the global scope of the given compile request. - /// - /// The specialized program will be consistent with that created by - /// `createUnspecializedGlobalComponentType`, and will simply fill in - /// its specialization parameters with the arguments (if any) supllied - /// as part fo the end-to-end compile request. - /// - /// The layout of the new component type will be consistent with that - /// of the original *if* there are no global generic type parameters - /// (only interface/existential parameters). - /// - RefPtr createSpecializedGlobalComponentType( - EndToEndCompileRequest* endToEndReq) - { - // The compile request must have already completed front-end processing, - // so that we have an unspecialized program available, and now only need - // to parse and check any generic arguments that are being supplied for - // global or entry-point generic parameters. - // - auto unspecializedProgram = endToEndReq->getUnspecializedGlobalComponentType(); - auto linkage = endToEndReq->getLinkage(); - auto sink = endToEndReq->getSink(); + // If the user specified generic arguments for the entry point, + // then we will need to parse the arguments first. + // + List specializationArgExprs; + parseSpecializationArgStrings( + endToEndReq, + entryPointInfo.specializationArgStrings, + specializationArgExprs); + + // Next we specialize the entry point function given the parsed + // generic argument expressions. + // + auto entryPoint = + createSpecializedEntryPoint(unspecializedEntryPoint, specializationArgExprs, sink); - // First, let's parse the specialization argument strings that were - // provided via the API, so that we can match them - // against what was declared in the program. - // - List globalSpecializationArgs; - parseSpecializationArgStrings( - endToEndReq, - endToEndReq->m_globalSpecializationArgStrings, - globalSpecializationArgs); - - // Don't proceed further if anything failed to parse. - if(sink->getErrorCount()) - return nullptr; - - // Now we create the initial specialized program by - // applying the global generic arguments (if any) to the - // unspecialized program. - // - auto specializedProgram = _createSpecializedProgramImpl( - linkage, - unspecializedProgram, - globalSpecializationArgs, - sink); + return entryPoint; +} - // If anything went wrong with the global generic - // arguments, then bail out now. - // - if(!specializedProgram) - return nullptr; +/// Create a specialized component type for the global scope of the given compile request. +/// +/// The specialized program will be consistent with that created by +/// `createUnspecializedGlobalComponentType`, and will simply fill in +/// its specialization parameters with the arguments (if any) supllied +/// as part fo the end-to-end compile request. +/// +/// The layout of the new component type will be consistent with that +/// of the original *if* there are no global generic type parameters +/// (only interface/existential parameters). +/// +RefPtr createSpecializedGlobalComponentType(EndToEndCompileRequest* endToEndReq) +{ + // The compile request must have already completed front-end processing, + // so that we have an unspecialized program available, and now only need + // to parse and check any generic arguments that are being supplied for + // global or entry-point generic parameters. + // + auto unspecializedProgram = endToEndReq->getUnspecializedGlobalComponentType(); + auto linkage = endToEndReq->getLinkage(); + auto sink = endToEndReq->getSink(); - // Next we will deal with the entry points for the - // new specialized program. - // - // If the user specified explicit entry points as part of the - // end-to-end request, then we only want to process those (and - // ignore any other `[shader(...)]`-attributed entry points). - // - // However, if the user specified *no* entry points as part - // of the end-to-end request, then we would like to go - // ahead and consider all the entry points that were found - // by the front-end. - // - Index entryPointCount = endToEndReq->m_entryPoints.getCount(); - if( entryPointCount == 0 ) - { - entryPointCount = unspecializedProgram->getEntryPointCount(); - endToEndReq->m_entryPoints.setCount(entryPointCount); - } + // First, let's parse the specialization argument strings that were + // provided via the API, so that we can match them + // against what was declared in the program. + // + List globalSpecializationArgs; + parseSpecializationArgStrings( + endToEndReq, + endToEndReq->m_globalSpecializationArgStrings, + globalSpecializationArgs); + + // Don't proceed further if anything failed to parse. + if (sink->getErrorCount()) + return nullptr; + + // Now we create the initial specialized program by + // applying the global generic arguments (if any) to the + // unspecialized program. + // + auto specializedProgram = _createSpecializedProgramImpl( + linkage, + unspecializedProgram, + globalSpecializationArgs, + sink); + + // If anything went wrong with the global generic + // arguments, then bail out now. + // + if (!specializedProgram) + return nullptr; - return specializedProgram; + // Next we will deal with the entry points for the + // new specialized program. + // + // If the user specified explicit entry points as part of the + // end-to-end request, then we only want to process those (and + // ignore any other `[shader(...)]`-attributed entry points). + // + // However, if the user specified *no* entry points as part + // of the end-to-end request, then we would like to go + // ahead and consider all the entry points that were found + // by the front-end. + // + Index entryPointCount = endToEndReq->m_entryPoints.getCount(); + if (entryPointCount == 0) + { + entryPointCount = unspecializedProgram->getEntryPointCount(); + endToEndReq->m_entryPoints.setCount(entryPointCount); } - /// Create a specialized program based on the given compile request. - /// - /// The specialized program created here includes both the global - /// scope for all the translation units involved and all the entry - /// points, and it also includes any specialization arguments - /// that were supplied. - /// - /// It is important to note that this function specializes - /// the global scope and the entry points in isolation and then - /// composes them, and that this can lead to different layout - /// from the result of `createUnspecializedGlobalAndEntryPointsComponentType`. - /// - /// If we have a module `M` with entry point `E`, and each has one - /// specialization parameter, then `createUnspecialized...` will yield: - /// - /// compose(M,E) - /// - /// That composed type will have two specialization parameters (the one - /// from `M` plus the one from `E`) and so we might specialize it to get: - /// - /// specialize(compose(M,E), X, Y) - /// - /// while if we use `createSpecialized...` we will get: - /// - /// compose(specialize(M,X), specialize(E,Y)) - /// - /// While these options are semantically equivalent, they would not lay - /// out the same way in memory. - /// - /// There are many reasons why an application might prefer one over the - /// other, and an application that cares should use the more explicit - /// APIs to construct what they want. The behavior of this function - /// is just to provide a reasonable default for use by end-to-end - /// compilation (e.g., from the command line). - /// - RefPtr createSpecializedGlobalAndEntryPointsComponentType( - EndToEndCompileRequest* endToEndReq, - List>& outSpecializedEntryPoints) - { - auto specializedGlobalComponentType = endToEndReq->getSpecializedGlobalComponentType(); + return specializedProgram; +} - List> allComponentTypes; - allComponentTypes.add(specializedGlobalComponentType); +/// Create a specialized program based on the given compile request. +/// +/// The specialized program created here includes both the global +/// scope for all the translation units involved and all the entry +/// points, and it also includes any specialization arguments +/// that were supplied. +/// +/// It is important to note that this function specializes +/// the global scope and the entry points in isolation and then +/// composes them, and that this can lead to different layout +/// from the result of `createUnspecializedGlobalAndEntryPointsComponentType`. +/// +/// If we have a module `M` with entry point `E`, and each has one +/// specialization parameter, then `createUnspecialized...` will yield: +/// +/// compose(M,E) +/// +/// That composed type will have two specialization parameters (the one +/// from `M` plus the one from `E`) and so we might specialize it to get: +/// +/// specialize(compose(M,E), X, Y) +/// +/// while if we use `createSpecialized...` we will get: +/// +/// compose(specialize(M,X), specialize(E,Y)) +/// +/// While these options are semantically equivalent, they would not lay +/// out the same way in memory. +/// +/// There are many reasons why an application might prefer one over the +/// other, and an application that cares should use the more explicit +/// APIs to construct what they want. The behavior of this function +/// is just to provide a reasonable default for use by end-to-end +/// compilation (e.g., from the command line). +/// +RefPtr createSpecializedGlobalAndEntryPointsComponentType( + EndToEndCompileRequest* endToEndReq, + List>& outSpecializedEntryPoints) +{ + auto specializedGlobalComponentType = endToEndReq->getSpecializedGlobalComponentType(); - auto unspecializedGlobalAndEntryPointsComponentType = endToEndReq->getUnspecializedGlobalAndEntryPointsComponentType(); + List> allComponentTypes; + allComponentTypes.add(specializedGlobalComponentType); - // It is possible that there were entry points other than those specified - // vai the original end-to-end compile request. In particular: - // - // * It is possible to compile with *no* entry points specified, in which - // case the current compiler behavior is to use any entry points marked - // via `[shader(...)]` attributes in the AST. - // - // * It is possible for entry points to come into play via serialized libraries - // loaded with `-r` on the command line (or the equivalent API). - // - // We will thus draw a distinction between the "specified" entry points, - // and the "found" entry points. - // - auto specifiedEntryPointCount = endToEndReq->m_entryPoints.getCount(); - auto foundEntryPointCount = unspecializedGlobalAndEntryPointsComponentType->getEntryPointCount(); + auto unspecializedGlobalAndEntryPointsComponentType = + endToEndReq->getUnspecializedGlobalAndEntryPointsComponentType(); - SLANG_ASSERT(foundEntryPointCount >= specifiedEntryPointCount); + // It is possible that there were entry points other than those specified + // vai the original end-to-end compile request. In particular: + // + // * It is possible to compile with *no* entry points specified, in which + // case the current compiler behavior is to use any entry points marked + // via `[shader(...)]` attributes in the AST. + // + // * It is possible for entry points to come into play via serialized libraries + // loaded with `-r` on the command line (or the equivalent API). + // + // We will thus draw a distinction between the "specified" entry points, + // and the "found" entry points. + // + auto specifiedEntryPointCount = endToEndReq->m_entryPoints.getCount(); + auto foundEntryPointCount = + unspecializedGlobalAndEntryPointsComponentType->getEntryPointCount(); - // For any entry points that were specified, we can use the specialization - // argument information provided via API or command line. - // - for(Index ii = 0; ii < specifiedEntryPointCount; ++ii) - { - auto& entryPointInfo = endToEndReq->m_entryPoints[ii]; - auto unspecializedEntryPoint = unspecializedGlobalAndEntryPointsComponentType->getEntryPoint(ii); + SLANG_ASSERT(foundEntryPointCount >= specifiedEntryPointCount); - auto specializedEntryPoint = createSpecializedEntryPoint(endToEndReq, unspecializedEntryPoint, entryPointInfo); - allComponentTypes.add(specializedEntryPoint); + // For any entry points that were specified, we can use the specialization + // argument information provided via API or command line. + // + for (Index ii = 0; ii < specifiedEntryPointCount; ++ii) + { + auto& entryPointInfo = endToEndReq->m_entryPoints[ii]; + auto unspecializedEntryPoint = + unspecializedGlobalAndEntryPointsComponentType->getEntryPoint(ii); - outSpecializedEntryPoints.add(specializedEntryPoint); - } + auto specializedEntryPoint = + createSpecializedEntryPoint(endToEndReq, unspecializedEntryPoint, entryPointInfo); + allComponentTypes.add(specializedEntryPoint); - // There might have been errors during the specialization above, - // so we will bail out early if anything went wrong, rather - // then try to create a composite where some of the constituent - // component types might be null. - // - if(endToEndReq->getSink()->getErrorCount() != 0) - return nullptr; + outSpecializedEntryPoints.add(specializedEntryPoint); + } - // Any entry points beyond those that were specified up front will be - // assumed to not need/want specialization. - // - for( Index ii = specifiedEntryPointCount; ii < foundEntryPointCount; ++ii ) - { - auto unspecializedEntryPoint = unspecializedGlobalAndEntryPointsComponentType->getEntryPoint(ii); - allComponentTypes.add(unspecializedEntryPoint); - outSpecializedEntryPoints.add(unspecializedEntryPoint); - } + // There might have been errors during the specialization above, + // so we will bail out early if anything went wrong, rather + // then try to create a composite where some of the constituent + // component types might be null. + // + if (endToEndReq->getSink()->getErrorCount() != 0) + return nullptr; - RefPtr composed = CompositeComponentType::create(endToEndReq->getLinkage(), allComponentTypes); - return composed; + // Any entry points beyond those that were specified up front will be + // assumed to not need/want specialization. + // + for (Index ii = specifiedEntryPointCount; ii < foundEntryPointCount; ++ii) + { + auto unspecializedEntryPoint = + unspecializedGlobalAndEntryPointsComponentType->getEntryPoint(ii); + allComponentTypes.add(unspecializedEntryPoint); + outSpecializedEntryPoints.add(unspecializedEntryPoint); } - + RefPtr composed = + CompositeComponentType::create(endToEndReq->getLinkage(), allComponentTypes); + return composed; } + + +} // namespace Slang diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp index 8b0e0b284..d02140d70 100644 --- a/source/slang/slang-check-stmt.cpp +++ b/source/slang/slang-check-stmt.cpp @@ -6,722 +6,774 @@ namespace Slang { - namespace - { - /// RAII-like type for establishing an "outer" statement during nested checks. - /// - /// The `SemanticsStmtVisitor` maintains a linked list of outer statements - /// using `OuterStmtInfo` records stored on the recursive call stack during - /// checking. This type creates a sub-`SemanticsStmtVisitor` that has one - /// additional outer statement added to the stack of outer statements. - /// - /// The outer statements are used to validate and resolve things like - /// the target of `break` or `continue` statements. - /// - struct WithOuterStmt : public SemanticsStmtVisitor - { - public: - WithOuterStmt(SemanticsStmtVisitor* visitor, Stmt* outerStmt) - : SemanticsStmtVisitor(visitor->withOuterStmts(&m_outerStmt)) - { - m_outerStmt.next = visitor->getOuterStmts(); - m_outerStmt.stmt = outerStmt; - } - - private: - OuterStmtInfo m_outerStmt; - }; - } - - void SemanticsVisitor::checkStmt(Stmt* stmt, SemanticsContext const& context) +namespace +{ +/// RAII-like type for establishing an "outer" statement during nested checks. +/// +/// The `SemanticsStmtVisitor` maintains a linked list of outer statements +/// using `OuterStmtInfo` records stored on the recursive call stack during +/// checking. This type creates a sub-`SemanticsStmtVisitor` that has one +/// additional outer statement added to the stack of outer statements. +/// +/// The outer statements are used to validate and resolve things like +/// the target of `break` or `continue` statements. +/// +struct WithOuterStmt : public SemanticsStmtVisitor +{ +public: + WithOuterStmt(SemanticsStmtVisitor* visitor, Stmt* outerStmt) + : SemanticsStmtVisitor(visitor->withOuterStmts(&m_outerStmt)) { - if (!stmt) return; - dispatchStmt(stmt, context); - checkModifiers(stmt); + m_outerStmt.next = visitor->getOuterStmts(); + m_outerStmt.stmt = outerStmt; } - void SemanticsStmtVisitor::visitDeclStmt(DeclStmt* stmt) - { - // When we encounter a declaration during statement checking, - // we expect that it hasn't been checked yet (because otherwise - // it would be referenced before its declaration point), but - // we will bottleneck through the `ensureDecl()` path anyway, - // to unify with the rest of semantic checking. - // - // TODO: This logic might not suffice for something like a - // local `struct` declaration, where it would have members - // that need to be recursively checked. - // - ensureDeclBase(stmt->decl, DeclCheckState::DefinitionChecked, this); - } +private: + OuterStmtInfo m_outerStmt; +}; +} // namespace - void SemanticsStmtVisitor::visitBlockStmt(BlockStmt* stmt) - { - // Make sure to fully check all nested agg type decls first. - if (stmt->scopeDecl) - { - for (auto decl : stmt->scopeDecl->members) - { - if (as(decl)) - ensureAllDeclsRec(decl, DeclCheckState::DefinitionChecked); - } - } - checkStmt(stmt->body); - } +void SemanticsVisitor::checkStmt(Stmt* stmt, SemanticsContext const& context) +{ + if (!stmt) + return; + dispatchStmt(stmt, context); + checkModifiers(stmt); +} - void SemanticsStmtVisitor::visitSeqStmt(SeqStmt* stmt) +void SemanticsStmtVisitor::visitDeclStmt(DeclStmt* stmt) +{ + // When we encounter a declaration during statement checking, + // we expect that it hasn't been checked yet (because otherwise + // it would be referenced before its declaration point), but + // we will bottleneck through the `ensureDecl()` path anyway, + // to unify with the rest of semantic checking. + // + // TODO: This logic might not suffice for something like a + // local `struct` declaration, where it would have members + // that need to be recursively checked. + // + ensureDeclBase(stmt->decl, DeclCheckState::DefinitionChecked, this); +} + +void SemanticsStmtVisitor::visitBlockStmt(BlockStmt* stmt) +{ + // Make sure to fully check all nested agg type decls first. + if (stmt->scopeDecl) { - for(auto ss : stmt->stmts) + for (auto decl : stmt->scopeDecl->members) { - checkStmt(ss); + if (as(decl)) + ensureAllDeclsRec(decl, DeclCheckState::DefinitionChecked); } } + checkStmt(stmt->body); +} - void SemanticsStmtVisitor::visitLabelStmt(LabelStmt* stmt) +void SemanticsStmtVisitor::visitSeqStmt(SeqStmt* stmt) +{ + for (auto ss : stmt->stmts) { - WithOuterStmt subContext(this, stmt); - subContext.checkStmt(stmt->innerStmt); + checkStmt(ss); } +} - void SemanticsStmtVisitor::checkStmt(Stmt* stmt) - { - SemanticsVisitor::checkStmt(stmt, *this); - } +void SemanticsStmtVisitor::visitLabelStmt(LabelStmt* stmt) +{ + WithOuterStmt subContext(this, stmt); + subContext.checkStmt(stmt->innerStmt); +} - template - T* SemanticsStmtVisitor::FindOuterStmt() +void SemanticsStmtVisitor::checkStmt(Stmt* stmt) +{ + SemanticsVisitor::checkStmt(stmt, *this); +} + +template +T* SemanticsStmtVisitor::FindOuterStmt() +{ + for (auto outerStmtInfo = m_outerStmts; outerStmtInfo; outerStmtInfo = outerStmtInfo->next) { - for(auto outerStmtInfo = m_outerStmts; outerStmtInfo; outerStmtInfo = outerStmtInfo->next) - { - auto outerStmt = outerStmtInfo->stmt; - auto found = as(outerStmt); - if (found) - return found; - } - return nullptr; + auto outerStmt = outerStmtInfo->stmt; + auto found = as(outerStmt); + if (found) + return found; } + return nullptr; +} - Stmt* SemanticsStmtVisitor::findOuterStmtWithLabel(Name* label) +Stmt* SemanticsStmtVisitor::findOuterStmtWithLabel(Name* label) +{ + for (auto outerStmtInfo = m_outerStmts; outerStmtInfo; outerStmtInfo = outerStmtInfo->next) { - for (auto outerStmtInfo = m_outerStmts; outerStmtInfo; outerStmtInfo = outerStmtInfo->next) + auto outerStmt = outerStmtInfo->stmt; + auto found = as(outerStmt); + if (found) { - auto outerStmt = outerStmtInfo->stmt; - auto found = as(outerStmt); - if (found) + if (found->label.getName() == label) { - if (found->label.getName() == label) - { - return found->innerStmt; - } + return found->innerStmt; } } - return nullptr; } + return nullptr; +} - void SemanticsStmtVisitor::visitBreakStmt(BreakStmt *stmt) +void SemanticsStmtVisitor::visitBreakStmt(BreakStmt* stmt) +{ + Stmt* targetStmt = nullptr; + if (stmt->targetLabel.type == TokenType::Identifier) { - Stmt* targetStmt = nullptr; - if (stmt->targetLabel.type == TokenType::Identifier) + // This is a break statement with an explicit target label. + // Try to find the outer stmt with the label. + targetStmt = findOuterStmtWithLabel(stmt->targetLabel.getName()); + if (!targetStmt) { - // This is a break statement with an explicit target label. - // Try to find the outer stmt with the label. - targetStmt = findOuterStmtWithLabel(stmt->targetLabel.getName()); - if (!targetStmt) - { - getSink()->diagnose(stmt, Diagnostics::breakLabelNotFound, stmt->targetLabel.getName()); - } - if (!as(targetStmt)) - { - getSink()->diagnose(stmt, Diagnostics::targetLabelDoesNotMarkBreakableStmt, stmt->targetLabel.getName()); - } + getSink()->diagnose(stmt, Diagnostics::breakLabelNotFound, stmt->targetLabel.getName()); } - else + if (!as(targetStmt)) { - // For `break` statements without an explicit target, - // find the inner most breakable stmt. - targetStmt = FindOuterStmt(); - if (!targetStmt) - { - getSink()->diagnose(stmt, Diagnostics::breakOutsideLoop); - } + getSink()->diagnose( + stmt, + Diagnostics::targetLabelDoesNotMarkBreakableStmt, + stmt->targetLabel.getName()); } - stmt->parentStmt = targetStmt; } - - void SemanticsStmtVisitor::visitContinueStmt(ContinueStmt *stmt) + else { - auto outer = FindOuterStmt(); - if (!outer) + // For `break` statements without an explicit target, + // find the inner most breakable stmt. + targetStmt = FindOuterStmt(); + if (!targetStmt) { - getSink()->diagnose(stmt, Diagnostics::continueOutsideLoop); + getSink()->diagnose(stmt, Diagnostics::breakOutsideLoop); } - stmt->parentStmt = outer; } + stmt->parentStmt = targetStmt; +} - Expr* SemanticsVisitor::checkPredicateExpr(Expr* expr) +void SemanticsStmtVisitor::visitContinueStmt(ContinueStmt* stmt) +{ + auto outer = FindOuterStmt(); + if (!outer) { - if (as(expr)) - { - getSink()->diagnose(expr, Diagnostics::assignmentInPredicateExpr); - } - Expr* e = expr; - e = CheckTerm(e); - e = coerce(CoercionSite::General, m_astBuilder->getBoolType(), e); - return e; + getSink()->diagnose(stmt, Diagnostics::continueOutsideLoop); } + stmt->parentStmt = outer; +} - void SemanticsStmtVisitor::visitDoWhileStmt(DoWhileStmt *stmt) +Expr* SemanticsVisitor::checkPredicateExpr(Expr* expr) +{ + if (as(expr)) { - checkModifiers(stmt); - WithOuterStmt subContext(this, stmt); - - stmt->predicate = checkPredicateExpr(stmt->predicate); - subContext.checkStmt(stmt->statement); - checkLoopInDifferentiableFunc(stmt); + getSink()->diagnose(expr, Diagnostics::assignmentInPredicateExpr); } + Expr* e = expr; + e = CheckTerm(e); + e = coerce(CoercionSite::General, m_astBuilder->getBoolType(), e); + return e; +} - void SemanticsStmtVisitor::visitForStmt(ForStmt *stmt) - { - WithOuterStmt subContext(this, stmt); - checkModifiers(stmt); - checkStmt(stmt->initialStatement); +void SemanticsStmtVisitor::visitDoWhileStmt(DoWhileStmt* stmt) +{ + checkModifiers(stmt); + WithOuterStmt subContext(this, stmt); - if (stmt->predicateExpression) - { - stmt->predicateExpression = checkPredicateExpr(stmt->predicateExpression); - } - if (stmt->sideEffectExpression) - { - stmt->sideEffectExpression = CheckExpr(stmt->sideEffectExpression); - } - subContext.checkStmt(stmt->statement); - - tryInferLoopMaxIterations(stmt); + stmt->predicate = checkPredicateExpr(stmt->predicate); + subContext.checkStmt(stmt->statement); + checkLoopInDifferentiableFunc(stmt); +} - checkLoopInDifferentiableFunc(stmt); - } +void SemanticsStmtVisitor::visitForStmt(ForStmt* stmt) +{ + WithOuterStmt subContext(this, stmt); + checkModifiers(stmt); + checkStmt(stmt->initialStatement); - Expr* SemanticsVisitor::checkExpressionAndExpectIntegerConstant(Expr* expr, IntVal** outIntVal, ConstantFoldingKind kind) + if (stmt->predicateExpression) { - expr = CheckExpr(expr); - auto intVal = CheckIntegerConstantExpression(expr, IntegerConstantExpressionCoercionType::AnyInteger, nullptr, kind); - if (outIntVal) - *outIntVal = intVal; - return expr; + stmt->predicateExpression = checkPredicateExpr(stmt->predicateExpression); } - - void SemanticsStmtVisitor::visitCompileTimeForStmt(CompileTimeForStmt* stmt) + if (stmt->sideEffectExpression) { - WithOuterStmt subContext(this, stmt); + stmt->sideEffectExpression = CheckExpr(stmt->sideEffectExpression); + } + subContext.checkStmt(stmt->statement); - stmt->varDecl->type.type = m_astBuilder->getIntType(); - addModifier(stmt->varDecl, m_astBuilder->create()); - stmt->varDecl->setCheckState(DeclCheckState::DefinitionChecked); + tryInferLoopMaxIterations(stmt); - IntVal* rangeBeginVal = nullptr; - IntVal* rangeEndVal = nullptr; + checkLoopInDifferentiableFunc(stmt); +} - if (stmt->rangeBeginExpr) - { - stmt->rangeBeginExpr = checkExpressionAndExpectIntegerConstant(stmt->rangeBeginExpr, &rangeBeginVal, ConstantFoldingKind::LinkTime); - } - else - { - ConstantIntVal* rangeBeginConst = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 0); - rangeBeginVal = rangeBeginConst; - } +Expr* SemanticsVisitor::checkExpressionAndExpectIntegerConstant( + Expr* expr, + IntVal** outIntVal, + ConstantFoldingKind kind) +{ + expr = CheckExpr(expr); + auto intVal = CheckIntegerConstantExpression( + expr, + IntegerConstantExpressionCoercionType::AnyInteger, + nullptr, + kind); + if (outIntVal) + *outIntVal = intVal; + return expr; +} - stmt->rangeEndExpr = checkExpressionAndExpectIntegerConstant(stmt->rangeEndExpr, &rangeEndVal, ConstantFoldingKind::LinkTime); +void SemanticsStmtVisitor::visitCompileTimeForStmt(CompileTimeForStmt* stmt) +{ + WithOuterStmt subContext(this, stmt); - stmt->rangeBeginVal = rangeBeginVal; - stmt->rangeEndVal = rangeEndVal; + stmt->varDecl->type.type = m_astBuilder->getIntType(); + addModifier(stmt->varDecl, m_astBuilder->create()); + stmt->varDecl->setCheckState(DeclCheckState::DefinitionChecked); - subContext.checkStmt(stmt->body); - } + IntVal* rangeBeginVal = nullptr; + IntVal* rangeEndVal = nullptr; - void SemanticsStmtVisitor::validateCaseStmts(SwitchStmt* stmt, DiagnosticSink* sink) + if (stmt->rangeBeginExpr) { - auto blockStmt = as(stmt->body); - if (!blockStmt) - return; + stmt->rangeBeginExpr = checkExpressionAndExpectIntegerConstant( + stmt->rangeBeginExpr, + &rangeBeginVal, + ConstantFoldingKind::LinkTime); + } + else + { + ConstantIntVal* rangeBeginConst = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 0); + rangeBeginVal = rangeBeginConst; + } - auto seqStmt = as(blockStmt->body); - if (!seqStmt) - return; + stmt->rangeEndExpr = checkExpressionAndExpectIntegerConstant( + stmt->rangeEndExpr, + &rangeEndVal, + ConstantFoldingKind::LinkTime); + + stmt->rangeBeginVal = rangeBeginVal; + stmt->rangeEndVal = rangeEndVal; + + subContext.checkStmt(stmt->body); +} - bool hasDefaultStmt = false; - HashSet caseStmtVals; - for (auto& sStmt : seqStmt->stmts) +void SemanticsStmtVisitor::validateCaseStmts(SwitchStmt* stmt, DiagnosticSink* sink) +{ + auto blockStmt = as(stmt->body); + if (!blockStmt) + return; + + auto seqStmt = as(blockStmt->body); + if (!seqStmt) + return; + + bool hasDefaultStmt = false; + HashSet caseStmtVals; + for (auto& sStmt : seqStmt->stmts) + { + if (auto caseStmt = as(sStmt)) { - if (auto caseStmt = as(sStmt)) + // check that all case tags are unique + if (caseStmt->exprVal) { - // check that all case tags are unique - if (caseStmt->exprVal) + // exprVal contains the constant folded expr, that is checked for + // uniqueness within the scope of the switch statement. + if (!caseStmtVals.add(caseStmt->exprVal)) { - // exprVal contains the constant folded expr, that is checked for - // uniqueness within the scope of the switch statement. - if (!caseStmtVals.add(caseStmt->exprVal)) - { - sink->diagnose(sStmt, Diagnostics::switchDuplicateCases); - return; - } + sink->diagnose(sStmt, Diagnostics::switchDuplicateCases); + return; } } - else if (as(sStmt)) + } + else if (as(sStmt)) + { + // check that there is at most one `default` clause + if (hasDefaultStmt) { - // check that there is at most one `default` clause - if (hasDefaultStmt) - { - sink->diagnose(sStmt, Diagnostics::switchMultipleDefault); - return; - } - hasDefaultStmt = true; + sink->diagnose(sStmt, Diagnostics::switchMultipleDefault); + return; } + hasDefaultStmt = true; } } +} - void SemanticsStmtVisitor::visitSwitchStmt(SwitchStmt* stmt) - { - WithOuterStmt subContext(this, stmt); +void SemanticsStmtVisitor::visitSwitchStmt(SwitchStmt* stmt) +{ + WithOuterStmt subContext(this, stmt); - // TODO(tfoley): need to coerce condition to an integral type... - stmt->condition = CheckExpr(stmt->condition); - subContext.checkStmt(stmt->body); + // TODO(tfoley): need to coerce condition to an integral type... + stmt->condition = CheckExpr(stmt->condition); + subContext.checkStmt(stmt->body); - // check the case value exits within the switch - validateCaseStmts(stmt, getSink()); - } + // check the case value exits within the switch + validateCaseStmts(stmt, getSink()); +} - void SemanticsStmtVisitor::visitCaseStmt(CaseStmt* stmt) +void SemanticsStmtVisitor::visitCaseStmt(CaseStmt* stmt) +{ + auto switchStmt = FindOuterStmt(); + if (!switchStmt) { - auto switchStmt = FindOuterStmt(); - if (!switchStmt) - { - getSink()->diagnose(stmt, Diagnostics::caseOutsideSwitch); - return; - } + getSink()->diagnose(stmt, Diagnostics::caseOutsideSwitch); + return; + } - // Check that the type for the `case` is consistent with the type for the `switch`. - auto expr = CheckExpr(stmt->expr); - expr = coerce(CoercionSite::Argument, switchStmt->condition->type, expr); + // Check that the type for the `case` is consistent with the type for the `switch`. + auto expr = CheckExpr(stmt->expr); + expr = coerce(CoercionSite::Argument, switchStmt->condition->type, expr); - // coerce to type being switch on, and ensure that value is a compile-time constant - // The Vals in the AST are pointer-unique, making them easy to check for duplicates - // by addeing them to a HashSet. - auto exprVal = checkConstantIntVal(expr); + // coerce to type being switch on, and ensure that value is a compile-time constant + // The Vals in the AST are pointer-unique, making them easy to check for duplicates + // by addeing them to a HashSet. + auto exprVal = checkConstantIntVal(expr); - stmt->expr = expr; - stmt->exprVal = exprVal; - stmt->parentStmt = switchStmt; - } + stmt->expr = expr; + stmt->exprVal = exprVal; + stmt->parentStmt = switchStmt; +} - void SemanticsStmtVisitor::visitTargetSwitchStmt(TargetSwitchStmt* stmt) +void SemanticsStmtVisitor::visitTargetSwitchStmt(TargetSwitchStmt* stmt) +{ + WithOuterStmt subContext(this, stmt); + HashSet checkedStmt; + for (auto caseStmt : stmt->targetCases) { - WithOuterStmt subContext(this, stmt); - HashSet checkedStmt; - for (auto caseStmt : stmt->targetCases) - { - if (checkedStmt.contains(caseStmt->body)) - continue; - subContext.checkStmt(caseStmt); - checkedStmt.add(caseStmt->body); - } + if (checkedStmt.contains(caseStmt->body)) + continue; + subContext.checkStmt(caseStmt); + checkedStmt.add(caseStmt->body); } +} - void SemanticsStmtVisitor::visitTargetCaseStmt(TargetCaseStmt* stmt) +void SemanticsStmtVisitor::visitTargetCaseStmt(TargetCaseStmt* stmt) +{ + auto switchStmt = FindOuterStmt(); + CapabilitySet set((CapabilityName)stmt->capability); + if (getShared()->isInLanguageServer() && + getShared()->getSession()->getCompletionRequestTokenName() == + stmt->capabilityToken.getName()) { - auto switchStmt = FindOuterStmt(); - CapabilitySet set((CapabilityName)stmt->capability); - if (getShared()->isInLanguageServer() && getShared()->getSession()->getCompletionRequestTokenName() == stmt->capabilityToken.getName()) - { - getShared()->getLinkage()->contentAssistInfo.completionSuggestions.scopeKind = CompletionSuggestions::ScopeKind::Capabilities; - } - - if (stmt->capabilityToken.getContentLength() != 0 && - (set.getCapabilityTargetSets().getCount() != 1 || set.isInvalid() || set.isEmpty())) - { - getSink()->diagnose( - stmt->capabilityToken.loc, - Diagnostics::invalidTargetSwitchCase, - capabilityNameToString((CapabilityName)stmt->capability)); - } - if (!switchStmt) - { - getSink()->diagnose(stmt, Diagnostics::caseOutsideSwitch); - } - WithOuterStmt subContext(this, stmt); - subContext.checkStmt(stmt->body); + getShared()->getLinkage()->contentAssistInfo.completionSuggestions.scopeKind = + CompletionSuggestions::ScopeKind::Capabilities; } - void SemanticsStmtVisitor::visitIntrinsicAsmStmt(IntrinsicAsmStmt* stmt) + if (stmt->capabilityToken.getContentLength() != 0 && + (set.getCapabilityTargetSets().getCount() != 1 || set.isInvalid() || set.isEmpty())) { - WithOuterStmt subContext(this, stmt); - for (auto& arg : stmt->args) - arg = subContext.CheckExpr(arg); + getSink()->diagnose( + stmt->capabilityToken.loc, + Diagnostics::invalidTargetSwitchCase, + capabilityNameToString((CapabilityName)stmt->capability)); } - - void SemanticsStmtVisitor::visitDefaultStmt(DefaultStmt* stmt) + if (!switchStmt) { - auto switchStmt = FindOuterStmt(); - if (!switchStmt) - { - getSink()->diagnose(stmt, Diagnostics::defaultOutsideSwitch); - } - stmt->parentStmt = switchStmt; + getSink()->diagnose(stmt, Diagnostics::caseOutsideSwitch); } + WithOuterStmt subContext(this, stmt); + subContext.checkStmt(stmt->body); +} - void SemanticsStmtVisitor::visitIfStmt(IfStmt *stmt) - { - stmt->predicate = checkPredicateExpr(stmt->predicate); - checkStmt(stmt->positiveStatement); - checkStmt(stmt->negativeStatement); - } +void SemanticsStmtVisitor::visitIntrinsicAsmStmt(IntrinsicAsmStmt* stmt) +{ + WithOuterStmt subContext(this, stmt); + for (auto& arg : stmt->args) + arg = subContext.CheckExpr(arg); +} - void SemanticsStmtVisitor::visitUnparsedStmt(UnparsedStmt*) +void SemanticsStmtVisitor::visitDefaultStmt(DefaultStmt* stmt) +{ + auto switchStmt = FindOuterStmt(); + if (!switchStmt) { - // Nothing to do + getSink()->diagnose(stmt, Diagnostics::defaultOutsideSwitch); } + stmt->parentStmt = switchStmt; +} - void SemanticsStmtVisitor::visitEmptyStmt(EmptyStmt*) - { - // Nothing to do - } +void SemanticsStmtVisitor::visitIfStmt(IfStmt* stmt) +{ + stmt->predicate = checkPredicateExpr(stmt->predicate); + checkStmt(stmt->positiveStatement); + checkStmt(stmt->negativeStatement); +} - void SemanticsStmtVisitor::visitDiscardStmt(DiscardStmt*) +void SemanticsStmtVisitor::visitUnparsedStmt(UnparsedStmt*) +{ + // Nothing to do +} + +void SemanticsStmtVisitor::visitEmptyStmt(EmptyStmt*) +{ + // Nothing to do +} + +void SemanticsStmtVisitor::visitDiscardStmt(DiscardStmt*) +{ + // Nothing to do +} + +void SemanticsStmtVisitor::visitReturnStmt(ReturnStmt* stmt) +{ + auto function = getParentFunc(); + if (!stmt->expression) { - // Nothing to do + if (function && !function->returnType.equals(m_astBuilder->getVoidType()) && + !as(function)) + { + getSink()->diagnose(stmt, Diagnostics::returnNeedsExpression); + } } - - void SemanticsStmtVisitor::visitReturnStmt(ReturnStmt *stmt) + else { - auto function = getParentFunc(); - if (!stmt->expression) + stmt->expression = CheckTerm(stmt->expression); + if (!stmt->expression->type->equals(m_astBuilder->getErrorType())) { - if (function && !function->returnType.equals(m_astBuilder->getVoidType()) && !as(function)) + if (function) { - getSink()->diagnose(stmt, Diagnostics::returnNeedsExpression); + stmt->expression = + coerce(CoercionSite::Return, function->returnType.Ptr(), stmt->expression); } - } - else - { - stmt->expression = CheckTerm(stmt->expression); - if (!stmt->expression->type->equals(m_astBuilder->getErrorType())) + else { - if (function) - { - stmt->expression = coerce(CoercionSite::Return, function->returnType.Ptr(), stmt->expression); - } - else - { - // TODO(tfoley): this case currently gets triggered for member functions, - // which aren't being checked consistently (because of the whole symbol - // table idea getting in the way). + // TODO(tfoley): this case currently gets triggered for member functions, + // which aren't being checked consistently (because of the whole symbol + // table idea getting in the way). -// getSink()->diagnose(stmt, Diagnostics::unimplemented, "case for return stmt"); - } + // getSink()->diagnose(stmt, + // Diagnostics::unimplemented, "case for return stmt"); } } } +} - void SemanticsStmtVisitor::visitWhileStmt(WhileStmt *stmt) - { - checkModifiers(stmt); - WithOuterStmt subContext(this, stmt); - stmt->predicate = checkPredicateExpr(stmt->predicate); - subContext.checkStmt(stmt->statement); - checkLoopInDifferentiableFunc(stmt); - } +void SemanticsStmtVisitor::visitWhileStmt(WhileStmt* stmt) +{ + checkModifiers(stmt); + WithOuterStmt subContext(this, stmt); + stmt->predicate = checkPredicateExpr(stmt->predicate); + subContext.checkStmt(stmt->statement); + checkLoopInDifferentiableFunc(stmt); +} - void SemanticsStmtVisitor::visitExpressionStmt(ExpressionStmt *stmt) +void SemanticsStmtVisitor::visitExpressionStmt(ExpressionStmt* stmt) +{ + stmt->expression = CheckExpr(stmt->expression); + if (auto operatorExpr = as(stmt->expression)) { - stmt->expression = CheckExpr(stmt->expression); - if (auto operatorExpr = as(stmt->expression)) + if (auto func = as(operatorExpr->functionExpr)) { - if (auto func = as(operatorExpr->functionExpr)) + if (func->name && func->name->text == "==") { - if (func->name && func->name->text == "==") - { - getSink()->diagnose(operatorExpr, Diagnostics::danglingEqualityExpr); - } + getSink()->diagnose(operatorExpr, Diagnostics::danglingEqualityExpr); } } } +} - void SemanticsStmtVisitor::tryInferLoopMaxIterations(ForStmt* stmt) - { - // If a for loop is in the form of `for (var = initialVal; var $compareOp otherVal; var sideEffectOp operand)` - // we will try to constant fold the operands and see if we can statically determine the maximum number of - // iterations this loop will run, and insert the inferred result as a `[MaxIters]` attribute on the stmt. - // - // ++, --, +=, -= are supported in side effect expressions. - // >, <, >=, <= are supported in predicate expressions. - // induction variable can appear in either side of the expressions. - // - // Other forms like for (var1 = .., var2 = ..; ) will not be recognized here. - // If we see suspicious code like `for (int i = 0; i < 5; j++)`, we will produce a warning along the way. - // - DeclRef predicateVar = {}; - Expr* initialVal = nullptr; - DeclRef initialVar = {}; - if (auto varStmt = as(stmt->initialStatement)) - { - auto varDecl = as(varStmt->decl); - if (!varDecl) - return; - initialVar = makeDeclRef(varDecl); - initialVal = varDecl->initExpr; - } - else if (auto exprStmt = as(stmt->initialStatement)) - { - auto assignExpr = as(exprStmt->expression); - if (!assignExpr) - return; - auto varExpr = as(assignExpr->left); - if (!varExpr) - return; - initialVar = varExpr->declRef; - initialVal = assignExpr->right; - } - else - return; - - auto initialLitVal = - as(tryFoldIntegerConstantExpression(initialVal, ConstantFoldingKind::CompileTime, nullptr)); - - ConstantIntVal* finalVal = nullptr; - auto binaryExpr = as(stmt->predicateExpression); - if (!binaryExpr) - return; - auto compareFuncExpr = as(binaryExpr->functionExpr); - if (!compareFuncExpr) - return; - if (!compareFuncExpr->declRef.getDecl()) +void SemanticsStmtVisitor::tryInferLoopMaxIterations(ForStmt* stmt) +{ + // If a for loop is in the form of `for (var = initialVal; var $compareOp otherVal; var + // sideEffectOp operand)` we will try to constant fold the operands and see if we can statically + // determine the maximum number of iterations this loop will run, and insert the inferred result + // as a `[MaxIters]` attribute on the stmt. + // + // ++, --, +=, -= are supported in side effect expressions. + // >, <, >=, <= are supported in predicate expressions. + // induction variable can appear in either side of the expressions. + // + // Other forms like for (var1 = .., var2 = ..; ) will not be recognized here. + // If we see suspicious code like `for (int i = 0; i < 5; j++)`, we will produce a warning along + // the way. + // + DeclRef predicateVar = {}; + Expr* initialVal = nullptr; + DeclRef initialVar = {}; + if (auto varStmt = as(stmt->initialStatement)) + { + auto varDecl = as(varStmt->decl); + if (!varDecl) return; - IROp compareOp = kIROp_Nop; - if (auto intrinsicOpModifier = compareFuncExpr->declRef.getDecl()->findModifier()) - { - compareOp = (IROp)intrinsicOpModifier->op; - } - else - { + initialVar = makeDeclRef(varDecl); + initialVal = varDecl->initExpr; + } + else if (auto exprStmt = as(stmt->initialStatement)) + { + auto assignExpr = as(exprStmt->expression); + if (!assignExpr) return; - } - if (binaryExpr->arguments.getCount() != 2) + auto varExpr = as(assignExpr->left); + if (!varExpr) return; - auto leftCompareOperand = binaryExpr->arguments[0]; - auto rightCompareOperand = binaryExpr->arguments[1]; - if (!leftCompareOperand) + initialVar = varExpr->declRef; + initialVal = assignExpr->right; + } + else + return; + + auto initialLitVal = as( + tryFoldIntegerConstantExpression(initialVal, ConstantFoldingKind::CompileTime, nullptr)); + + ConstantIntVal* finalVal = nullptr; + auto binaryExpr = as(stmt->predicateExpression); + if (!binaryExpr) + return; + auto compareFuncExpr = as(binaryExpr->functionExpr); + if (!compareFuncExpr) + return; + if (!compareFuncExpr->declRef.getDecl()) + return; + IROp compareOp = kIROp_Nop; + if (auto intrinsicOpModifier = + compareFuncExpr->declRef.getDecl()->findModifier()) + { + compareOp = (IROp)intrinsicOpModifier->op; + } + else + { + return; + } + if (binaryExpr->arguments.getCount() != 2) + return; + auto leftCompareOperand = binaryExpr->arguments[0]; + auto rightCompareOperand = binaryExpr->arguments[1]; + if (!leftCompareOperand) + return; + if (!rightCompareOperand) + return; + if (auto rightVal = tryFoldIntegerConstantExpression( + binaryExpr->arguments[1], + ConstantFoldingKind::CompileTime, + nullptr)) + { + auto leftVar = as(leftCompareOperand); + if (!leftVar) return; - if (!rightCompareOperand) + predicateVar = leftVar->declRef; + finalVal = as(rightVal); + } + else if ( + auto leftVal = tryFoldIntegerConstantExpression( + binaryExpr->arguments[0], + ConstantFoldingKind::CompileTime, + nullptr)) + { + auto rightVar = as(rightCompareOperand); + if (!rightVar) return; - if (auto rightVal = tryFoldIntegerConstantExpression(binaryExpr->arguments[1], ConstantFoldingKind::CompileTime, nullptr)) + predicateVar = rightVar->declRef; + finalVal = as(leftVal); + compareOp = getSwapSideComparisonOp(compareOp); + } + else + { + // If neither left or right is constant, we assume left is variable and continue checking. + if (auto leftVar = as(leftCompareOperand)) { - auto leftVar = as(leftCompareOperand); - if (!leftVar) - return; predicateVar = leftVar->declRef; - finalVal = as(rightVal); } - else if (auto leftVal = tryFoldIntegerConstantExpression(binaryExpr->arguments[0], ConstantFoldingKind::CompileTime, nullptr)) + if (auto rightVar = as(rightCompareOperand)) { - auto rightVar = as(rightCompareOperand); - if (!rightVar) - return; - predicateVar = rightVar->declRef; - finalVal = as(leftVal); - compareOp = getSwapSideComparisonOp(compareOp); - } - else - { - // If neither left or right is constant, we assume left is variable and continue checking. - if (auto leftVar = as(leftCompareOperand)) - { - predicateVar = leftVar->declRef; - } - if (auto rightVar = as(rightCompareOperand)) + if (rightVar->declRef == initialVar) { - if (rightVar->declRef == initialVar) - { - predicateVar = rightVar->declRef; - compareOp = getSwapSideComparisonOp(compareOp); - } + predicateVar = rightVar->declRef; + compareOp = getSwapSideComparisonOp(compareOp); } } + } - switch (compareOp) - { - case kIROp_Less: - case kIROp_Leq: - case kIROp_Greater: - case kIROp_Geq: - break; - default: - return; - } + switch (compareOp) + { + case kIROp_Less: + case kIROp_Leq: + case kIROp_Greater: + case kIROp_Geq: break; + default: return; + } - ConstantIntVal* stepSize = nullptr; - IROp sideEffectFuncOp = kIROp_Nop; - auto opSideEffectExpr = as(stmt->sideEffectExpression); - if (!opSideEffectExpr) + ConstantIntVal* stepSize = nullptr; + IROp sideEffectFuncOp = kIROp_Nop; + auto opSideEffectExpr = as(stmt->sideEffectExpression); + if (!opSideEffectExpr) + return; + auto sideEffectFuncExpr = as(opSideEffectExpr->functionExpr); + if (!sideEffectFuncExpr) + return; + auto sideEffectFuncDecl = sideEffectFuncExpr->declRef.getDecl(); + if (!sideEffectFuncDecl) + return; + if (auto opName = sideEffectFuncDecl->getName()) + { + if (opName->text == "++") + sideEffectFuncOp = kIROp_Add; + else if (opName->text == "--") + sideEffectFuncOp = kIROp_Sub; + else if (opName->text == "+=") + sideEffectFuncOp = kIROp_Add; + else if (opName->text == "-=") + sideEffectFuncOp = kIROp_Sub; + else return; - auto sideEffectFuncExpr = as(opSideEffectExpr->functionExpr); - if (!sideEffectFuncExpr) + } + if (opSideEffectExpr->arguments.getCount()) + { + auto varExpr = as(opSideEffectExpr->arguments[0]); + if (!varExpr) return; - auto sideEffectFuncDecl = sideEffectFuncExpr->declRef.getDecl(); - if (!sideEffectFuncDecl) + if (varExpr->declRef.getDecl() != initialVar.getDecl()) + { + // If the user writes something like `for (int i = 0; i < 5; j++)`, + // it is most likely a bug, so we issue a warning. + if (predicateVar == initialVar) + getSink()->diagnose( + varExpr, + Diagnostics::forLoopSideEffectChangingDifferentVar, + initialVar, + varExpr->declRef); return; - if (auto opName = sideEffectFuncDecl->getName()) - { - if (opName->text == "++") - sideEffectFuncOp = kIROp_Add; - else if (opName->text == "--") - sideEffectFuncOp = kIROp_Sub; - else if (opName->text == "+=") - sideEffectFuncOp = kIROp_Add; - else if (opName->text == "-=") - sideEffectFuncOp = kIROp_Sub; - else - return; - } - if (opSideEffectExpr->arguments.getCount()) - { - auto varExpr = as(opSideEffectExpr->arguments[0]); - if (!varExpr) - return; - if (varExpr->declRef.getDecl() != initialVar.getDecl()) - { - // If the user writes something like `for (int i = 0; i < 5; j++)`, - // it is most likely a bug, so we issue a warning. - if (predicateVar == initialVar) - getSink()->diagnose(varExpr, Diagnostics::forLoopSideEffectChangingDifferentVar, initialVar, varExpr->declRef); - return; - } } - else + } + else + return; + if (opSideEffectExpr->arguments.getCount() == 2) + { + auto stepVal = tryFoldIntegerConstantExpression( + opSideEffectExpr->arguments[1], + ConstantFoldingKind::CompileTime, + nullptr); + if (!stepVal) return; - if (opSideEffectExpr->arguments.getCount() == 2) - { - auto stepVal = tryFoldIntegerConstantExpression(opSideEffectExpr->arguments[1], ConstantFoldingKind::CompileTime, nullptr); - if (!stepVal) - return; - if (auto constantIntVal = as(stepVal)) - { - stepSize = constantIntVal; - } - } - else + if (auto constantIntVal = as(stepVal)) { - stepSize = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); + stepSize = constantIntVal; } + } + else + { + stepSize = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); + } - if (predicateVar.getDecl() != initialVar.getDecl()) + if (predicateVar.getDecl() != initialVar.getDecl()) + { + if (predicateVar) + getSink()->diagnose( + stmt->predicateExpression, + Diagnostics::forLoopPredicateCheckingDifferentVar, + initialVar, + predicateVar); + return; + } + if (!stepSize) + return; + if (stepSize->getValue() > 0) + { + if (sideEffectFuncOp == kIROp_Add && compareOp == kIROp_Greater || + sideEffectFuncOp == kIROp_Sub && compareOp == kIROp_Less) { - if (predicateVar) - getSink()->diagnose(stmt->predicateExpression, Diagnostics::forLoopPredicateCheckingDifferentVar, initialVar, predicateVar); - return; - } - if (!stepSize) + getSink()->diagnose( + stmt->sideEffectExpression, + Diagnostics::forLoopChangingIterationVariableInOppsoiteDirection, + initialVar); return; - if (stepSize->getValue() > 0) - { - if (sideEffectFuncOp == kIROp_Add && compareOp == kIROp_Greater || - sideEffectFuncOp == kIROp_Sub && compareOp == kIROp_Less) - { - getSink()->diagnose(stmt->sideEffectExpression, Diagnostics::forLoopChangingIterationVariableInOppsoiteDirection, initialVar); - return; - } } - else if (stepSize->getValue() < 0) - { - if (sideEffectFuncOp == kIROp_Add && compareOp == kIROp_Less || - sideEffectFuncOp == kIROp_Sub && compareOp == kIROp_Greater) - { - getSink()->diagnose(stmt->sideEffectExpression, Diagnostics::forLoopChangingIterationVariableInOppsoiteDirection, initialVar); - return; - } - } - else + } + else if (stepSize->getValue() < 0) + { + if (sideEffectFuncOp == kIROp_Add && compareOp == kIROp_Less || + sideEffectFuncOp == kIROp_Sub && compareOp == kIROp_Greater) { - getSink()->diagnose(stmt->sideEffectExpression, Diagnostics::forLoopNotModifyingIterationVariable, initialVar); + getSink()->diagnose( + stmt->sideEffectExpression, + Diagnostics::forLoopChangingIterationVariableInOppsoiteDirection, + initialVar); return; } - - if (!initialLitVal || !finalVal) - return; - - auto absStepSize = abs(stepSize->getValue()); - int adjustment = 0; - if (compareOp == kIROp_Geq || compareOp == kIROp_Leq) - adjustment = 1; - - auto iterations = (Math::Max(finalVal->getValue(), initialLitVal->getValue()) - - Math::Min(finalVal->getValue(), initialLitVal->getValue()) + absStepSize - 1 + adjustment) / - absStepSize; - switch (compareOp) - { - case kIROp_Geq: - case kIROp_Greater: - // Expect final value to be less than initial value. - if (finalVal->getValue() > initialLitVal->getValue()) - iterations = 0; - break; - case kIROp_Leq: - case kIROp_Less: - if (finalVal->getValue() < initialLitVal->getValue()) - iterations = 0; - break; - } - if (iterations == 0) - { - getSink()->diagnose(stmt, Diagnostics::loopRunsForZeroIterations); - } - - // Note: the inferred max iterations may not be valid if the loop body - // also modifies the induction variable. - // We detect this case during lower-to-ir and will remove the `InferredMaxItersAttribute` - // if the loop body modifies the induction variable. - // - auto maxItersAttr = m_astBuilder->create(); - auto litExpr = m_astBuilder->create(); - litExpr->type.type = m_astBuilder->getIntType(); - litExpr->token.setName(getNamePool()->getName(String(iterations))); - maxItersAttr->args.add(litExpr); - maxItersAttr->intArgVals.add(m_astBuilder->getIntVal(m_astBuilder->getIntType(), iterations)); - maxItersAttr->value = (int32_t)iterations; - maxItersAttr->inductionVar = initialVar; - addModifier(stmt, maxItersAttr); + } + else + { + getSink()->diagnose( + stmt->sideEffectExpression, + Diagnostics::forLoopNotModifyingIterationVariable, + initialVar); return; } - void SemanticsStmtVisitor::checkLoopInDifferentiableFunc(Stmt* stmt) - { - SLANG_UNUSED(stmt); - if (getParentDifferentiableAttribute()) - { - if (!getParentFunc()) - return; - - // If the function is itself a derivative, or has a user defined derivative, - // then we don't require anything. + if (!initialLitVal || !finalVal) + return; - if (getParentFunc()->findModifier()) - return; - if (getParentFunc()->findModifier()) - return; - if (getParentFunc()->findModifier()) - return; - if (getParentFunc()->findModifier()) - return; - } - } + auto absStepSize = abs(stepSize->getValue()); + int adjustment = 0; + if (compareOp == kIROp_Geq || compareOp == kIROp_Leq) + adjustment = 1; + + auto iterations = (Math::Max(finalVal->getValue(), initialLitVal->getValue()) - + Math::Min(finalVal->getValue(), initialLitVal->getValue()) + absStepSize - + 1 + adjustment) / + absStepSize; + switch (compareOp) + { + case kIROp_Geq: + case kIROp_Greater: + // Expect final value to be less than initial value. + if (finalVal->getValue() > initialLitVal->getValue()) + iterations = 0; + break; + case kIROp_Leq: + case kIROp_Less: + if (finalVal->getValue() < initialLitVal->getValue()) + iterations = 0; + break; + } + if (iterations == 0) + { + getSink()->diagnose(stmt, Diagnostics::loopRunsForZeroIterations); + } + + // Note: the inferred max iterations may not be valid if the loop body + // also modifies the induction variable. + // We detect this case during lower-to-ir and will remove the `InferredMaxItersAttribute` + // if the loop body modifies the induction variable. + // + auto maxItersAttr = m_astBuilder->create(); + auto litExpr = m_astBuilder->create(); + litExpr->type.type = m_astBuilder->getIntType(); + litExpr->token.setName(getNamePool()->getName(String(iterations))); + maxItersAttr->args.add(litExpr); + maxItersAttr->intArgVals.add(m_astBuilder->getIntVal(m_astBuilder->getIntType(), iterations)); + maxItersAttr->value = (int32_t)iterations; + maxItersAttr->inductionVar = initialVar; + addModifier(stmt, maxItersAttr); + return; +} - void SemanticsStmtVisitor::visitGpuForeachStmt(GpuForeachStmt*stmt) +void SemanticsStmtVisitor::checkLoopInDifferentiableFunc(Stmt* stmt) +{ + SLANG_UNUSED(stmt); + if (getParentDifferentiableAttribute()) { - stmt->device = CheckExpr(stmt->device); - stmt->gridDims = CheckExpr(stmt->gridDims); - ensureDeclBase(stmt->dispatchThreadID, DeclCheckState::DefinitionChecked, this); - WithOuterStmt subContext(this, stmt); - stmt->kernelCall = subContext.CheckExpr(stmt->kernelCall); - return; + if (!getParentFunc()) + return; + + // If the function is itself a derivative, or has a user defined derivative, + // then we don't require anything. + + if (getParentFunc()->findModifier()) + return; + if (getParentFunc()->findModifier()) + return; + if (getParentFunc()->findModifier()) + return; + if (getParentFunc()->findModifier()) + return; } } + +void SemanticsStmtVisitor::visitGpuForeachStmt(GpuForeachStmt* stmt) +{ + stmt->device = CheckExpr(stmt->device); + stmt->gridDims = CheckExpr(stmt->gridDims); + ensureDeclBase(stmt->dispatchThreadID, DeclCheckState::DefinitionChecked, this); + WithOuterStmt subContext(this, stmt); + stmt->kernelCall = subContext.CheckExpr(stmt->kernelCall); + return; +} +} // namespace Slang diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp index eeee13561..34f16751b 100644 --- a/source/slang/slang-check-type.cpp +++ b/source/slang/slang-check-type.cpp @@ -6,467 +6,473 @@ namespace Slang { - Type* checkProperType( - Linkage* linkage, - TypeExp typeExp, - DiagnosticSink* sink) - { - SharedSemanticsContext sharedSemanticsContext( - linkage, - nullptr, - sink); - SemanticsVisitor visitor(&sharedSemanticsContext); +Type* checkProperType(Linkage* linkage, TypeExp typeExp, DiagnosticSink* sink) +{ + SharedSemanticsContext sharedSemanticsContext(linkage, nullptr, sink); + SemanticsVisitor visitor(&sharedSemanticsContext); - SLANG_AST_BUILDER_RAII(linkage->getASTBuilder()); + SLANG_AST_BUILDER_RAII(linkage->getASTBuilder()); - auto typeOut = visitor.CheckProperType(typeExp); - return typeOut.type; - } + auto typeOut = visitor.CheckProperType(typeExp); + return typeOut.type; +} - Type* getPointedToTypeIfCanImplicitDeref(Type* type) +Type* getPointedToTypeIfCanImplicitDeref(Type* type) +{ + if (auto ptrLike = as(type)) { - if (auto ptrLike = as(type)) - { - return ptrLike->getElementType(); - } - else if (auto ptrType = as(type)) - { - return ptrType->getValueType(); - } - else if (auto refType = as(type)) - { - return refType->getValueType(); - } - return nullptr; + return ptrLike->getElementType(); } - - Expr* SemanticsVisitor::TranslateTypeNodeImpl(Expr* node) + else if (auto ptrType = as(type)) { - if (!node) return nullptr; - - auto expr = CheckTerm(node); - expr = ExpectATypeRepr(expr); - return expr; + return ptrType->getValueType(); } - - Type* SemanticsVisitor::ExtractTypeFromTypeRepr(Expr* typeRepr) + else if (auto refType = as(type)) { - if (!typeRepr) return nullptr; - if (auto typeType = as(typeRepr->type)) - { - return typeType->getType(); - } - return m_astBuilder->getErrorType(); + return refType->getValueType(); } + return nullptr; +} + +Expr* SemanticsVisitor::TranslateTypeNodeImpl(Expr* node) +{ + if (!node) + return nullptr; + + auto expr = CheckTerm(node); + expr = ExpectATypeRepr(expr); + return expr; +} - Type* SemanticsVisitor::TranslateTypeNode(Expr* node) +Type* SemanticsVisitor::ExtractTypeFromTypeRepr(Expr* typeRepr) +{ + if (!typeRepr) + return nullptr; + if (auto typeType = as(typeRepr->type)) { - if (!node) return nullptr; - auto typeRepr = TranslateTypeNodeImpl(node); - return ExtractTypeFromTypeRepr(typeRepr); + return typeType->getType(); } + return m_astBuilder->getErrorType(); +} - TypeExp SemanticsVisitor::TranslateTypeNodeForced(TypeExp const& typeExp) - { - auto typeRepr = TranslateTypeNodeImpl(typeExp.exp); +Type* SemanticsVisitor::TranslateTypeNode(Expr* node) +{ + if (!node) + return nullptr; + auto typeRepr = TranslateTypeNodeImpl(node); + return ExtractTypeFromTypeRepr(typeRepr); +} - TypeExp result; - result.exp = typeRepr; - result.type = ExtractTypeFromTypeRepr(typeRepr); - return result; - } +TypeExp SemanticsVisitor::TranslateTypeNodeForced(TypeExp const& typeExp) +{ + auto typeRepr = TranslateTypeNodeImpl(typeExp.exp); + + TypeExp result; + result.exp = typeRepr; + result.type = ExtractTypeFromTypeRepr(typeRepr); + return result; +} - TypeExp SemanticsVisitor::TranslateTypeNode(TypeExp const& typeExp) +TypeExp SemanticsVisitor::TranslateTypeNode(TypeExp const& typeExp) +{ + // HACK(tfoley): It seems that in some cases we end up re-checking + // syntax that we've already checked. We need to root-cause that + // issue, but for now a quick fix in this case is to early + // exist if we've already got a type associated here: + if (typeExp.type) { - // HACK(tfoley): It seems that in some cases we end up re-checking - // syntax that we've already checked. We need to root-cause that - // issue, but for now a quick fix in this case is to early - // exist if we've already got a type associated here: - if (typeExp.type) - { - return typeExp; - } - return TranslateTypeNodeForced(typeExp); + return typeExp; } + return TranslateTypeNodeForced(typeExp); +} - Type* SemanticsVisitor::getRemovedModifierType(ModifiedType* modifiedType, ModifierVal* modifier) +Type* SemanticsVisitor::getRemovedModifierType(ModifiedType* modifiedType, ModifierVal* modifier) +{ + if (modifiedType->getModifierCount() == 1) + return modifiedType->getBase(); + List newModifiers; + for (Index i = 0; i < modifiedType->getModifierCount(); i++) { - if (modifiedType->getModifierCount() == 1) - return modifiedType->getBase(); - List newModifiers; - for (Index i = 0; i < modifiedType->getModifierCount(); i++) - { - auto m = modifiedType->getModifier(i); - if (m == modifier) - continue; - newModifiers.add(m); - } - return m_astBuilder->getModifiedType(modifiedType->getBase(), newModifiers); + auto m = modifiedType->getModifier(i); + if (m == modifier) + continue; + newModifiers.add(m); } + return m_astBuilder->getModifiedType(modifiedType->getBase(), newModifiers); +} - Expr* SemanticsVisitor::ExpectATypeRepr(Expr* expr) +Expr* SemanticsVisitor::ExpectATypeRepr(Expr* expr) +{ + if (auto overloadedExpr = as(expr)) { - if (auto overloadedExpr = as(expr)) - { - expr = resolveOverloadedExpr(overloadedExpr, LookupMask::type); - } - - if (const auto typeType = as(expr->type)) - { - return expr; - } - else if (const auto errorType = as(expr->type)) - { - return expr; - } - - getSink()->diagnose(expr, Diagnostics::expectedAType, expr->type); - return CreateErrorExpr(expr); + expr = resolveOverloadedExpr(overloadedExpr, LookupMask::type); } - Type* SemanticsVisitor::ExpectAType(Expr* expr) + if (const auto typeType = as(expr->type)) { - auto typeRepr = ExpectATypeRepr(expr); - if (auto typeType = as(typeRepr->type)) - { - return typeType->getType(); - } - return m_astBuilder->getErrorType(); + return expr; } - - Type* SemanticsVisitor::ExtractGenericArgType(Expr* exp) + else if (const auto errorType = as(expr->type)) { - return ExpectAType(exp); + return expr; } - IntVal* SemanticsVisitor::ExtractGenericArgInteger(Expr* exp, Type* genericParamType, DiagnosticSink* sink) + getSink()->diagnose(expr, Diagnostics::expectedAType, expr->type); + return CreateErrorExpr(expr); +} + +Type* SemanticsVisitor::ExpectAType(Expr* expr) +{ + auto typeRepr = ExpectATypeRepr(expr); + if (auto typeType = as(typeRepr->type)) { - IntVal* val = CheckIntegerConstantExpression( - exp, - genericParamType ? IntegerConstantExpressionCoercionType::SpecificType - : IntegerConstantExpressionCoercionType::AnyInteger, - genericParamType, - ConstantFoldingKind::LinkTime, - sink); - if(val) return val; - - // If the argument expression could not be coerced to an integer - // constant expression in context, then we will instead construct - // a dummy "error" value to represent the result. - // - val = m_astBuilder->getOrCreate(m_astBuilder->getIntType()); - return val; + return typeType->getType(); } + return m_astBuilder->getErrorType(); +} + +Type* SemanticsVisitor::ExtractGenericArgType(Expr* exp) +{ + return ExpectAType(exp); +} + +IntVal* SemanticsVisitor::ExtractGenericArgInteger( + Expr* exp, + Type* genericParamType, + DiagnosticSink* sink) +{ + IntVal* val = CheckIntegerConstantExpression( + exp, + genericParamType ? IntegerConstantExpressionCoercionType::SpecificType + : IntegerConstantExpressionCoercionType::AnyInteger, + genericParamType, + ConstantFoldingKind::LinkTime, + sink); + if (val) + return val; + + // If the argument expression could not be coerced to an integer + // constant expression in context, then we will instead construct + // a dummy "error" value to represent the result. + // + val = m_astBuilder->getOrCreate(m_astBuilder->getIntType()); + return val; +} - IntVal* SemanticsVisitor::ExtractGenericArgInteger(Expr* exp, Type* genericParamType) +IntVal* SemanticsVisitor::ExtractGenericArgInteger(Expr* exp, Type* genericParamType) +{ + return ExtractGenericArgInteger(exp, genericParamType, getSink()); +} + +Val* SemanticsVisitor::ExtractGenericArgVal(Expr* exp) +{ + if (auto overloadedExpr = as(exp)) { - return ExtractGenericArgInteger(exp, genericParamType, getSink()); + // assume that if it is overloaded, we want a type + exp = resolveOverloadedExpr(overloadedExpr, LookupMask::type); } - - Val* SemanticsVisitor::ExtractGenericArgVal(Expr* exp) + if (auto typeType = as(exp->type)) { - if (auto overloadedExpr = as(exp)) - { - // assume that if it is overloaded, we want a type - exp = resolveOverloadedExpr(overloadedExpr, LookupMask::type); - } - if (auto typeType = as(exp->type)) - { - return typeType->getType(); - } - else if (const auto errorType = as(exp->type)) - { - return exp->type.type; - } - else + return typeType->getType(); + } + else if (const auto errorType = as(exp->type)) + { + return exp->type.type; + } + else + { + if (!exp->type.type) { - if (!exp->type.type) - { - CheckExpr(exp); - } - return ExtractGenericArgInteger(exp, nullptr); + CheckExpr(exp); } + return ExtractGenericArgInteger(exp, nullptr); } +} - Type* SemanticsVisitor::InstantiateGenericType( - DeclRef genericDeclRef, - List const& args) +Type* SemanticsVisitor::InstantiateGenericType( + DeclRef genericDeclRef, + List const& args) +{ + List evaledArgs; + + for (auto argExpr : args) { - List evaledArgs; + evaledArgs.add(ExtractGenericArgVal(argExpr)); + } - for (auto argExpr : args) - { - evaledArgs.add(ExtractGenericArgVal(argExpr)); - } + DeclRef innerDeclRef = + m_astBuilder->getGenericAppDeclRef(genericDeclRef, evaledArgs.getArrayView()); + return DeclRefType::create(m_astBuilder, innerDeclRef); +} - DeclRef innerDeclRef = m_astBuilder->getGenericAppDeclRef(genericDeclRef, evaledArgs.getArrayView()); - return DeclRefType::create(m_astBuilder, innerDeclRef); +bool isManagedType(Type* type) +{ + if (auto declRefValueType = as(type)) + { + if (as(declRefValueType->getDeclRef().getDecl())) + return true; + if (as(declRefValueType->getDeclRef().getDecl())) + return true; } + return false; +} - bool isManagedType(Type* type) +bool SemanticsVisitor::CoerceToProperTypeImpl( + TypeExp const& typeExp, + Type** outProperType, + DiagnosticSink* diagSink) +{ + Type* result = nullptr; + Type* type = typeExp.type; + auto originalExpr = typeExp.exp; + auto expr = originalExpr; + if (!type && expr) { - if (auto declRefValueType = as(type)) + expr = maybeResolveOverloadedExpr(expr, LookupMask::type, diagSink); + + if (auto typeType = as(expr->type)) { - if (as(declRefValueType->getDeclRef().getDecl())) - return true; - if (as(declRefValueType->getDeclRef().getDecl())) - return true; + type = typeType->getType(); } - return false; } - bool SemanticsVisitor::CoerceToProperTypeImpl( - TypeExp const& typeExp, - Type** outProperType, - DiagnosticSink* diagSink) + if (!type) { - Type* result = nullptr; - Type* type = typeExp.type; - auto originalExpr = typeExp.exp; - auto expr = originalExpr; - if(!type && expr) + // Only output diagnostic if we have a sink. + if (diagSink) { - expr = maybeResolveOverloadedExpr(expr, LookupMask::type, diagSink); - - if(auto typeType = as(expr->type)) + // This function *can* be called with typeExp with both exp and type = nullptr. + // Previous behavior didn't output a diagnostic if originalExpr was null, so this keeps + // that behavior. + // + // Additional we check for ErrorType on expr, because if it's set a diagnostic has + // already been output via previous code or via maybeResolveOverloadedExpr. + if (originalExpr && (expr == nullptr || as(expr->type) == nullptr)) { - type = typeType->getType(); + // The diagnostic for expectedAType wants to say what it 'got'. + // The solution given here, currently is to just use the node name. + // How useful that might be could depend, and perhaps some other mechanism + // that catagorized 'what' the wrong thing was is. For now this seems sufficient. + // + // Note that use originalExpr (not expr) because we want original expr for + // diagnostic. + + // Get the AST node type info, so we can output a 'got' name + auto info = ASTClassInfo::getInfo(originalExpr->astNodeType); + diagSink->diagnose(originalExpr, Diagnostics::expectedAType, info->m_name); } } - if (!type) + if (outProperType) { - // Only output diagnostic if we have a sink. - if (diagSink) - { - // This function *can* be called with typeExp with both exp and type = nullptr. - // Previous behavior didn't output a diagnostic if originalExpr was null, so this keeps that behavior. - // - // Additional we check for ErrorType on expr, because if it's set a diagnostic has already been output via - // previous code or via maybeResolveOverloadedExpr. - if (originalExpr && (expr == nullptr || as(expr->type) == nullptr)) - { - // The diagnostic for expectedAType wants to say what it 'got'. - // The solution given here, currently is to just use the node name. - // How useful that might be could depend, and perhaps some other mechanism - // that catagorized 'what' the wrong thing was is. For now this seems sufficient. - // - // Note that use originalExpr (not expr) because we want original expr for diagnostic. - - // Get the AST node type info, so we can output a 'got' name - auto info = ASTClassInfo::getInfo(originalExpr->astNodeType); - diagSink->diagnose(originalExpr, Diagnostics::expectedAType, info->m_name); - } - } - - if (outProperType) - { - *outProperType = nullptr; - } - return false; + *outProperType = nullptr; } + return false; + } - if (auto genericDeclRefType = as(type)) - { - // We are using a reference to a generic declaration as a concrete - // type. This means we should substitute in any default parameter values - // if they are available. - // - // TODO(tfoley): A more expressive type system would substitute in - // "fresh" variables and then solve for their values... - // + if (auto genericDeclRefType = as(type)) + { + // We are using a reference to a generic declaration as a concrete + // type. This means we should substitute in any default parameter values + // if they are available. + // + // TODO(tfoley): A more expressive type system would substitute in + // "fresh" variables and then solve for their values... + // - auto genericDeclRef = genericDeclRefType->getDeclRef(); - ensureDecl(genericDeclRef, DeclCheckState::CanSpecializeGeneric); - List args; - List witnessArgs; - for (Decl* member : genericDeclRef.getDecl()->members) + auto genericDeclRef = genericDeclRefType->getDeclRef(); + ensureDecl(genericDeclRef, DeclCheckState::CanSpecializeGeneric); + List args; + List witnessArgs; + for (Decl* member : genericDeclRef.getDecl()->members) + { + if (auto typeParam = as(member)) { - if (auto typeParam = as(member)) - { - if (!typeParam->initType.exp) - { - if (diagSink) - { - diagSink->diagnose(typeExp.exp, Diagnostics::genericTypeNeedsArgs, typeExp); - *outProperType = m_astBuilder->getErrorType(); - } - return false; - } - - // TODO: this is one place where syntax should get cloned! - if (outProperType) - args.add(ExtractGenericArgVal(typeParam->initType.exp)); - } - else if (auto valParam = as(member)) + if (!typeParam->initType.exp) { - if (!valParam->initExpr) + if (diagSink) { - if (diagSink) - { - diagSink->diagnose(typeExp.exp, Diagnostics::unimplemented, "can't fill in default for generic type parameter"); - *outProperType = m_astBuilder->getErrorType(); - } - return false; + diagSink->diagnose(typeExp.exp, Diagnostics::genericTypeNeedsArgs, typeExp); + *outProperType = m_astBuilder->getErrorType(); } - // TODO: this is one place where syntax should get cloned! - if (outProperType) - args.add(ExtractGenericArgVal(valParam->initExpr)); + return false; } - else if (auto constraintParam = as(member)) + + // TODO: this is one place where syntax should get cloned! + if (outProperType) + args.add(ExtractGenericArgVal(typeParam->initType.exp)); + } + else if (auto valParam = as(member)) + { + if (!valParam->initExpr) { - auto genericParam = as(constraintParam->sub.type)->getDeclRef(); - if (!genericParam) - return false; - auto genericTypeParamDecl = as(genericParam.getDecl()); - if (!genericTypeParamDecl) - return false; - auto defaultType = CheckProperType(genericTypeParamDecl->initType); - if (!defaultType) - return false; - auto witness = tryGetSubtypeWitness(defaultType, CheckProperType(constraintParam->sup)); - if (!witness) + if (diagSink) { - // diagnose - getSink()->diagnose( - genericTypeParamDecl->initType.exp, - Diagnostics::typeArgumentDoesNotConformToInterface, - defaultType, - constraintParam->sup); - return false; + diagSink->diagnose( + typeExp.exp, + Diagnostics::unimplemented, + "can't fill in default for generic type parameter"); + *outProperType = m_astBuilder->getErrorType(); } - witnessArgs.add(witness); + return false; } - else + // TODO: this is one place where syntax should get cloned! + if (outProperType) + args.add(ExtractGenericArgVal(valParam->initExpr)); + } + else if (auto constraintParam = as(member)) + { + auto genericParam = as(constraintParam->sub.type)->getDeclRef(); + if (!genericParam) + return false; + auto genericTypeParamDecl = as(genericParam.getDecl()); + if (!genericTypeParamDecl) + return false; + auto defaultType = CheckProperType(genericTypeParamDecl->initType); + if (!defaultType) + return false; + auto witness = + tryGetSubtypeWitness(defaultType, CheckProperType(constraintParam->sup)); + if (!witness) { - // ignore non-parameter members + // diagnose + getSink()->diagnose( + genericTypeParamDecl->initType.exp, + Diagnostics::typeArgumentDoesNotConformToInterface, + defaultType, + constraintParam->sup); + return false; } + witnessArgs.add(witness); } - // Combine args and witnessArgs - args.addRange(witnessArgs); - - result = DeclRefType::create(getASTBuilder(), - getASTBuilder()->getGenericAppDeclRef(genericDeclRef, args.getArrayView())); - } - - // default case: we expect this to already be a proper type - if (!result) - { - result = type; - } - - // Check for invalid types. - // We don't allow pointers to managed types. - if (auto ptrType = as(result)) - { - if (isManagedType(ptrType->getValueType())) + else { - getSink()->diagnose(typeExp.exp, Diagnostics::cannotDefinePtrTypeToManagedResource); + // ignore non-parameter members } } + // Combine args and witnessArgs + args.addRange(witnessArgs); - *outProperType = result; - return true; + result = DeclRefType::create( + getASTBuilder(), + getASTBuilder()->getGenericAppDeclRef(genericDeclRef, args.getArrayView())); } - TypeExp SemanticsVisitor::CoerceToProperType(TypeExp const& typeExp) + // default case: we expect this to already be a proper type + if (!result) { - TypeExp result = typeExp; - CoerceToProperTypeImpl(typeExp, &result.type, getSink()); - return result; + result = type; } - TypeExp SemanticsVisitor::tryCoerceToProperType(TypeExp const& typeExp) + // Check for invalid types. + // We don't allow pointers to managed types. + if (auto ptrType = as(result)) { - TypeExp result = typeExp; - if(!CoerceToProperTypeImpl(typeExp, &result.type, nullptr)) - return TypeExp(); - return result; + if (isManagedType(ptrType->getValueType())) + { + getSink()->diagnose(typeExp.exp, Diagnostics::cannotDefinePtrTypeToManagedResource); + } } - TypeExp SemanticsVisitor::CheckProperType(TypeExp typeExp) - { - return CoerceToProperType(TranslateTypeNode(typeExp)); - } + *outProperType = result; + return true; +} - TypeExp SemanticsVisitor::CoerceToUsableType(TypeExp const& typeExp, Decl* decl) - { - TypeExp result = CoerceToProperType(typeExp); - Type* type = result.type; - if (auto basicType = as(type)) - { - // TODO: `void` shouldn't be a basic type, to make this easier to avoid - if (basicType->getBaseType() == BaseType::Void) - { - // TODO(tfoley): pick the right diagnostic message - getSink()->diagnose(result.exp, Diagnostics::invalidTypeVoid); - result.type = m_astBuilder->getErrorType(); - return result; - } - } +TypeExp SemanticsVisitor::CoerceToProperType(TypeExp const& typeExp) +{ + TypeExp result = typeExp; + CoerceToProperTypeImpl(typeExp, &result.type, getSink()); + return result; +} + +TypeExp SemanticsVisitor::tryCoerceToProperType(TypeExp const& typeExp) +{ + TypeExp result = typeExp; + if (!CoerceToProperTypeImpl(typeExp, &result.type, nullptr)) + return TypeExp(); + return result; +} - // A type pack is not a usable type other than for defining parameters. - if (!as(decl) && isTypePack(type)) +TypeExp SemanticsVisitor::CheckProperType(TypeExp typeExp) +{ + return CoerceToProperType(TranslateTypeNode(typeExp)); +} + +TypeExp SemanticsVisitor::CoerceToUsableType(TypeExp const& typeExp, Decl* decl) +{ + TypeExp result = CoerceToProperType(typeExp); + Type* type = result.type; + if (auto basicType = as(type)) + { + // TODO: `void` shouldn't be a basic type, to make this easier to avoid + if (basicType->getBaseType() == BaseType::Void) { - getSink()->diagnose(typeExp.exp, Diagnostics::improperUseOfType, typeExp.type); + // TODO(tfoley): pick the right diagnostic message + getSink()->diagnose(result.exp, Diagnostics::invalidTypeVoid); result.type = m_astBuilder->getErrorType(); return result; } - return result; } - TypeExp SemanticsVisitor::CheckUsableType(TypeExp typeExp, Decl* decl) + // A type pack is not a usable type other than for defining parameters. + if (!as(decl) && isTypePack(type)) { - return CoerceToUsableType(TranslateTypeNode(typeExp), decl); + getSink()->diagnose(typeExp.exp, Diagnostics::improperUseOfType, typeExp.type); + result.type = m_astBuilder->getErrorType(); + return result; } + return result; +} - bool SemanticsVisitor::ValuesAreEqual( - IntVal* left, - IntVal* right) - { - if(left == right) return true; +TypeExp SemanticsVisitor::CheckUsableType(TypeExp typeExp, Decl* decl) +{ + return CoerceToUsableType(TranslateTypeNode(typeExp), decl); +} + +bool SemanticsVisitor::ValuesAreEqual(IntVal* left, IntVal* right) +{ + if (left == right) + return true; - if(auto leftConst = as(left)) + if (auto leftConst = as(left)) + { + if (auto rightConst = as(right)) { - if(auto rightConst = as(right)) - { - return leftConst->getValue() == rightConst->getValue(); - } + return leftConst->getValue() == rightConst->getValue(); } + } - if(auto leftVar = as(left)) + if (auto leftVar = as(left)) + { + if (auto rightVar = as(right)) { - if(auto rightVar = as(right)) - { - return leftVar->getDeclRef().equals(rightVar->getDeclRef()); - } - else if (const auto rightPoly = as(right)) - { - return right->equals(leftVar); - } + return leftVar->getDeclRef().equals(rightVar->getDeclRef()); } - if (auto leftVar = as(left)) + else if (const auto rightPoly = as(right)) { - return leftVar->equals(right); + return right->equals(leftVar); } - return false; } - - VectorExpressionType* SemanticsVisitor::createVectorType( - Type* elementType, - IntVal* elementCount) + if (auto leftVar = as(left)) { - return m_astBuilder->getVectorType(elementType, elementCount); + return leftVar->equals(right); } + return false; +} + +VectorExpressionType* SemanticsVisitor::createVectorType(Type* elementType, IntVal* elementCount) +{ + return m_astBuilder->getVectorType(elementType, elementCount); +} - Expr* SemanticsExprVisitor::visitSharedTypeExpr(SharedTypeExpr* expr) +Expr* SemanticsExprVisitor::visitSharedTypeExpr(SharedTypeExpr* expr) +{ + if (!expr->type.Ptr()) { - if (!expr->type.Ptr()) - { - expr->base = CheckProperType(expr->base); - expr->type = expr->base.exp->type; - } - return expr; + expr->base = CheckProperType(expr->base); + expr->type = expr->base.exp->type; } - + return expr; } + +} // namespace Slang diff --git a/source/slang/slang-check.cpp b/source/slang/slang-check.cpp index 3f79b7f41..1fbef899b 100644 --- a/source/slang/slang-check.cpp +++ b/source/slang/slang-check.cpp @@ -5,218 +5,233 @@ // checking that don't cleanly land in one of the more // specialized `slang-check-*` files. -#include "slang-check-impl.h" - #include "../core/slang-type-text-util.h" +#include "slang-check-impl.h" namespace Slang { - namespace { // anonymous - - class SinkSharedLibraryLoader : public RefObject, public ISlangSharedLibraryLoader +namespace +{ // anonymous + +class SinkSharedLibraryLoader : public RefObject, public ISlangSharedLibraryLoader +{ +public: + SLANG_REF_OBJECT_IUNKNOWN_ALL + + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + loadSharedLibrary(const char* path, ISlangSharedLibrary** outSharedLibrary) SLANG_OVERRIDE { - public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SlangResult res = m_loader->loadSharedLibrary(path, outSharedLibrary); - virtual SLANG_NO_THROW SlangResult SLANG_MCALL loadSharedLibrary( - const char* path, - ISlangSharedLibrary** outSharedLibrary) SLANG_OVERRIDE + // Special handling for failure... + if (SLANG_FAILED(res) && m_sink) { - SlangResult res = m_loader->loadSharedLibrary(path, outSharedLibrary); - - // Special handling for failure... - if (SLANG_FAILED(res) && m_sink) + String filename = Path::getFileNameWithoutExt(path); + if (filename == "dxil") { - String filename = Path::getFileNameWithoutExt(path); - if (filename == "dxil") - { - m_sink->diagnose(SourceLoc(), Diagnostics::dxilNotFound); - } - else - { - m_sink->diagnose(SourceLoc(), Diagnostics::noteFailedToLoadDynamicLibrary, path); - } + m_sink->diagnose(SourceLoc(), Diagnostics::dxilNotFound); + } + else + { + m_sink->diagnose(SourceLoc(), Diagnostics::noteFailedToLoadDynamicLibrary, path); } - return res; } + return res; + } - SinkSharedLibraryLoader(ISlangSharedLibraryLoader* loader, DiagnosticSink* sink) : - m_loader(loader), - m_sink(sink) - { - } + SinkSharedLibraryLoader(ISlangSharedLibraryLoader* loader, DiagnosticSink* sink) + : m_loader(loader), m_sink(sink) + { + } - protected: - ISlangUnknown* getInterface(const Guid& guid) - { - return (guid == ISlangUnknown::getTypeGuid() || guid == ISlangSharedLibraryLoader::getTypeGuid()) ? static_cast(this) : nullptr; - } - ISlangSharedLibraryLoader* m_loader; - DiagnosticSink* m_sink; - }; +protected: + ISlangUnknown* getInterface(const Guid& guid) + { + return (guid == ISlangUnknown::getTypeGuid() || + guid == ISlangSharedLibraryLoader::getTypeGuid()) + ? static_cast(this) + : nullptr; + } + ISlangSharedLibraryLoader* m_loader; + DiagnosticSink* m_sink; +}; - } // anonymous +} // namespace - void Session::_setSharedLibraryLoader(ISlangSharedLibraryLoader* loader) +void Session::_setSharedLibraryLoader(ISlangSharedLibraryLoader* loader) +{ + if (m_sharedLibraryLoader != loader) { - if (m_sharedLibraryLoader != loader) - { - // Need to clear all of the libraries - m_downstreamCompilerSet->clear(); - m_downstreamCompilerInitialized = 0; + // Need to clear all of the libraries + m_downstreamCompilerSet->clear(); + m_downstreamCompilerInitialized = 0; - for (Index i = 0; i < Index(SLANG_PASS_THROUGH_COUNT_OF); ++i) - { - m_downstreamCompilers[i].setNull(); - } - - // Set the loader - m_sharedLibraryLoader = loader; + for (Index i = 0; i < Index(SLANG_PASS_THROUGH_COUNT_OF); ++i) + { + m_downstreamCompilers[i].setNull(); } + + // Set the loader + m_sharedLibraryLoader = loader; } +} + +void Session::resetDownstreamCompiler(PassThroughMode type) +{ + // Mark as initialized + m_downstreamCompilerInitialized &= ~(1 << int(type)); + m_downstreamCompilers[int(type)].setNull(); +} - void Session::resetDownstreamCompiler(PassThroughMode type) +IDownstreamCompiler* Session::getOrLoadDownstreamCompiler( + PassThroughMode type, + DiagnosticSink* sink) +{ + if (m_downstreamCompilerInitialized & (1 << int(type))) { - // Mark as initialized - m_downstreamCompilerInitialized &= ~(1 << int(type)); - m_downstreamCompilers[int(type)].setNull(); + return m_downstreamCompilers[int(type)]; } - IDownstreamCompiler* Session::getOrLoadDownstreamCompiler(PassThroughMode type, DiagnosticSink* sink) + if (type == PassThroughMode::GenericCCpp) { - if (m_downstreamCompilerInitialized & (1 << int(type))) - { - return m_downstreamCompilers[int(type)]; - } - - if (type == PassThroughMode::GenericCCpp) - { - // try testing for availability on all C/C++ compilers - getOrLoadDownstreamCompiler(PassThroughMode::Clang, nullptr); - getOrLoadDownstreamCompiler(PassThroughMode::Gcc, nullptr); - getOrLoadDownstreamCompiler(PassThroughMode::VisualStudio, nullptr); - getOrLoadDownstreamCompiler(PassThroughMode::LLVM, nullptr); - } + // try testing for availability on all C/C++ compilers + getOrLoadDownstreamCompiler(PassThroughMode::Clang, nullptr); + getOrLoadDownstreamCompiler(PassThroughMode::Gcc, nullptr); + getOrLoadDownstreamCompiler(PassThroughMode::VisualStudio, nullptr); + getOrLoadDownstreamCompiler(PassThroughMode::LLVM, nullptr); + } - // Mark that we have tried to load it - m_downstreamCompilerInitialized |= (1 << int(type)); - m_downstreamCompilers[int(type)].setNull(); + // Mark that we have tried to load it + m_downstreamCompilerInitialized |= (1 << int(type)); + m_downstreamCompilers[int(type)].setNull(); - // Do we have a locator - auto locator = m_downstreamCompilerLocators[int(type)]; - if (locator) + // Do we have a locator + auto locator = m_downstreamCompilerLocators[int(type)]; + if (locator) + { + m_downstreamCompilerSet->remove(SlangPassThrough(type)); + + // We want to be able to report a diagnostic to the user if a loader + // was unable to locate the desired downstream compiler, but we + // also need to deal with the fact that the locator might "probe" + // multiple possible library versions/names, and failing to load + // one library should not be taken as a hard error. + // + // The approach we use here is to first apply the `locator` directly + // with our `m_sharedLibraryLoader` and see if it succeeds. If + // it does, then we will move along. + // + if (SLANG_FAILED(locator( + m_downstreamCompilerPaths[int(type)], + m_sharedLibraryLoader, + m_downstreamCompilerSet))) { - m_downstreamCompilerSet->remove(SlangPassThrough(type)); - - // We want to be able to report a diagnostic to the user if a loader - // was unable to locate the desired downstream compiler, but we - // also need to deal with the fact that the locator might "probe" - // multiple possible library versions/names, and failing to load - // one library should not be taken as a hard error. + // If the locator reported a failure the first time we invoked + // it, then we will invoke it against with a wrapper shared library + // loader that reported library load failures to our diagnost `sink`. // - // The approach we use here is to first apply the `locator` directly - // with our `m_sharedLibraryLoader` and see if it succeeds. If - // it does, then we will move along. + // This means that in the case of failure the user will see a listing + // of all the libraries that the locator attempted to load but failed + // to find. The user will know that making one or more of these libraries + // available could fix the issue, but we cannot communicate precise + // information to them with this approach (e.g., the difference between + // "I need all of these libraries" vs. "I need at least one of these + // libraries"). // - if (SLANG_FAILED(locator(m_downstreamCompilerPaths[int(type)], m_sharedLibraryLoader, m_downstreamCompilerSet))) + if (sink) { - // If the locator reported a failure the first time we invoked - // it, then we will invoke it against with a wrapper shared library - // loader that reported library load failures to our diagnost `sink`. - // - // This means that in the case of failure the user will see a listing - // of all the libraries that the locator attempted to load but failed - // to find. The user will know that making one or more of these libraries - // available could fix the issue, but we cannot communicate precise - // information to them with this approach (e.g., the difference between - // "I need all of these libraries" vs. "I need at least one of these - // libraries"). - // - if( sink ) - { - sink->diagnose(SourceLoc(), Diagnostics::failedToLoadDownstreamCompiler, type); - } - SinkSharedLibraryLoader loader(m_sharedLibraryLoader, sink); - locator(m_downstreamCompilerPaths[int(type)], &loader, m_downstreamCompilerSet); + sink->diagnose(SourceLoc(), Diagnostics::failedToLoadDownstreamCompiler, type); } - - DownstreamCompilerUtil::updateDefaults(m_downstreamCompilerSet); + SinkSharedLibraryLoader loader(m_sharedLibraryLoader, sink); + locator(m_downstreamCompilerPaths[int(type)], &loader, m_downstreamCompilerSet); } - IDownstreamCompiler* compiler = nullptr; - - if (type == PassThroughMode::GenericCCpp) - { - compiler = m_downstreamCompilerSet->getDefaultCompiler(SLANG_SOURCE_LANGUAGE_CPP); - } - else - { - DownstreamCompilerDesc desc; - desc.type = SlangPassThrough(type); - compiler = DownstreamCompilerUtil::findCompiler(m_downstreamCompilerSet, DownstreamCompilerUtil::MatchType::Newest, desc); - } - m_downstreamCompilers[int(type)] = compiler; - return compiler; + DownstreamCompilerUtil::updateDefaults(m_downstreamCompilerSet); } - void checkTranslationUnit( - TranslationUnitRequest* translationUnit, - LoadedModuleDictionary& loadedModules) + IDownstreamCompiler* compiler = nullptr; + + if (type == PassThroughMode::GenericCCpp) + { + compiler = m_downstreamCompilerSet->getDefaultCompiler(SLANG_SOURCE_LANGUAGE_CPP); + } + else { - SLANG_AST_BUILDER_RAII(translationUnit->compileRequest->getLinkage()->getASTBuilder()); + DownstreamCompilerDesc desc; + desc.type = SlangPassThrough(type); + compiler = DownstreamCompilerUtil::findCompiler( + m_downstreamCompilerSet, + DownstreamCompilerUtil::MatchType::Newest, + desc); + } + m_downstreamCompilers[int(type)] = compiler; + return compiler; +} - SharedSemanticsContext sharedSemanticsContext( - translationUnit->compileRequest->getLinkage(), - translationUnit->getModule(), - translationUnit->compileRequest->getSink(), - &loadedModules, - translationUnit); +void checkTranslationUnit( + TranslationUnitRequest* translationUnit, + LoadedModuleDictionary& loadedModules) +{ + SLANG_AST_BUILDER_RAII(translationUnit->compileRequest->getLinkage()->getASTBuilder()); - SemanticsDeclVisitorBase visitor( (SemanticsContext(&sharedSemanticsContext)) ); + SharedSemanticsContext sharedSemanticsContext( + translationUnit->compileRequest->getLinkage(), + translationUnit->getModule(), + translationUnit->compileRequest->getSink(), + &loadedModules, + translationUnit); - // Apply the visitor to do the main semantic - // checking that is required on all declarations - // in the translation unit. + SemanticsDeclVisitorBase visitor((SemanticsContext(&sharedSemanticsContext))); - visitor.checkModule(translationUnit->getModuleDecl()); + // Apply the visitor to do the main semantic + // checking that is required on all declarations + // in the translation unit. - translationUnit->getModule()->_collectShaderParams(); - } + visitor.checkModule(translationUnit->getModuleDecl()); - void SemanticsVisitor::dispatchStmt(Stmt* stmt, SemanticsContext const& context) + translationUnit->getModule()->_collectShaderParams(); +} + +void SemanticsVisitor::dispatchStmt(Stmt* stmt, SemanticsContext const& context) +{ + SemanticsStmtVisitor visitor(context); + try { - SemanticsStmtVisitor visitor(context); - try - { - visitor.dispatch(stmt); - } - catch(const AbortCompilationException&) { throw; } - catch(...) - { - getSink()->noteInternalErrorLoc(stmt->loc); - throw; - } + visitor.dispatch(stmt); } - - Expr* SemanticsVisitor::dispatchExpr(Expr* expr, SemanticsContext const& context) + catch (const AbortCompilationException&) { - SemanticsExprVisitor visitor(context); - try - { - return visitor.dispatch(expr); - } - catch(const AbortCompilationException&) { throw; } - catch(...) - { - getSink()->noteInternalErrorLoc(expr->loc); - throw; - } + throw; + } + catch (...) + { + getSink()->noteInternalErrorLoc(stmt->loc); + throw; } +} - ASTBuilder* semanticsVisitorGetASTBuilder(SemanticsVisitor* sv) +Expr* SemanticsVisitor::dispatchExpr(Expr* expr, SemanticsContext const& context) +{ + SemanticsExprVisitor visitor(context); + try { - return sv->getASTBuilder(); + return visitor.dispatch(expr); } + catch (const AbortCompilationException&) + { + throw; + } + catch (...) + { + getSink()->noteInternalErrorLoc(expr->loc); + throw; + } +} +ASTBuilder* semanticsVisitorGetASTBuilder(SemanticsVisitor* sv) +{ + return sv->getASTBuilder(); } + +} // namespace Slang diff --git a/source/slang/slang-check.h b/source/slang/slang-check.h index 3a3e23817..bd2bdce41 100644 --- a/source/slang/slang-check.h +++ b/source/slang/slang-check.h @@ -8,23 +8,24 @@ namespace Slang { - class DiagnosticSink; - class EntryPoint; - class Linkage; - class Module; - class ShaderCompiler; - class ShaderLinkInfo; - class ShaderSymbol; +class DiagnosticSink; +class EntryPoint; +class Linkage; +class Module; +class ShaderCompiler; +class ShaderLinkInfo; +class ShaderSymbol; - class TranslationUnitRequest; +class TranslationUnitRequest; - bool isGlobalShaderParameter(VarDeclBase* decl); - bool isFromCoreModule(Decl* decl); +bool isGlobalShaderParameter(VarDeclBase* decl); +bool isFromCoreModule(Decl* decl); - void registerBuiltinDecls(Session* session, Decl* decl); +void registerBuiltinDecls(Session* session, Decl* decl); - Type* unwrapArrayType(Type* type); +Type* unwrapArrayType(Type* type); - OrderedDictionary> getCanonicalGenericConstraints( - ASTBuilder* builder, DeclRef genericDecl); -} +OrderedDictionary> getCanonicalGenericConstraints( + ASTBuilder* builder, + DeclRef genericDecl); +} // namespace Slang diff --git a/source/slang/slang-compiler-options.cpp b/source/slang/slang-compiler-options.cpp index 3325a313a..c01d3fb9c 100644 --- a/source/slang/slang-compiler-options.cpp +++ b/source/slang/slang-compiler-options.cpp @@ -1,339 +1,358 @@ #include "slang-compiler-options.h" + #include "slang-compiler.h" namespace Slang { - void CompilerOptionSet::load(uint32_t count, slang::CompilerOptionEntry* entries) +void CompilerOptionSet::load(uint32_t count, slang::CompilerOptionEntry* entries) +{ + for (uint32_t i = 0; i < count; i++) { - for (uint32_t i = 0; i < count; i++) + CompilerOptionValue value; + value.kind = entries[i].value.kind; + value.intValue = entries[i].value.intValue0; + value.intValue2 = entries[i].value.intValue1; + if (value.kind == CompilerOptionValueKind::String) { - CompilerOptionValue value; - value.kind = entries[i].value.kind; - value.intValue = entries[i].value.intValue0; - value.intValue2 = entries[i].value.intValue1; - if (value.kind == CompilerOptionValueKind::String) - { - value.stringValue = entries[i].value.stringValue0; - value.stringValue2 = entries[i].value.stringValue1; - } - add(entries[i].name, value); + value.stringValue = entries[i].value.stringValue0; + value.stringValue2 = entries[i].value.stringValue1; } + add(entries[i].name, value); } +} - void CompilerOptionSet::writeCommandLineArgs(Session* globalSession, StringBuilder& sb) +void CompilerOptionSet::writeCommandLineArgs(Session* globalSession, StringBuilder& sb) +{ + for (auto& option : options) { - for (auto& option : options) + auto optionInfoIndex = globalSession->m_commandOptions.findOptionByUserValue( + CommandOptions::UserValue(option.key)); + if (optionInfoIndex == -1) + continue; + auto optionInfo = globalSession->m_commandOptions.getOptionAt(optionInfoIndex); + auto nameCommaIndex = optionInfo.names.indexOf(','); + if (nameCommaIndex == -1) + nameCommaIndex = optionInfo.names.getLength(); + auto name = optionInfo.names.head(nameCommaIndex); + switch (option.key) { - auto optionInfoIndex = globalSession->m_commandOptions.findOptionByUserValue(CommandOptions::UserValue(option.key)); - if (optionInfoIndex == -1) - continue; - auto optionInfo = globalSession->m_commandOptions.getOptionAt(optionInfoIndex); - auto nameCommaIndex = optionInfo.names.indexOf(','); - if (nameCommaIndex == -1) nameCommaIndex = optionInfo.names.getLength(); - auto name = optionInfo.names.head(nameCommaIndex); - switch (option.key) + case CompilerOptionName::Capability: + for (auto v : option.value) { - case CompilerOptionName::Capability: - for (auto v : option.value) - { - sb << " " << optionInfo.names << " " << v.stringValue; - } - break; - case CompilerOptionName::Include: - for (auto v : option.value) - { - sb << " -I \"" << v.stringValue << "\""; - } - break; - case CompilerOptionName::MacroDefine: - for (auto v : option.value) - { - sb << " -D" << v.stringValue; - if (v.stringValue2.getLength()) - sb << "=" << v.stringValue2; - } - break; - case CompilerOptionName::VulkanBindShift: // intValue0 (higher 8 bits): kind; intValue0(higher bits): set; intValue1: shift - for (auto v : option.value) - { - uint8_t kind; - int set, shift; - v.unpackInt3(kind, set, shift); - switch ((HLSLToVulkanLayoutOptions::Kind)(kind)) - { - case HLSLToVulkanLayoutOptions::Kind::UnorderedAccess: - sb << " -fvk-u-shift"; - break; - case HLSLToVulkanLayoutOptions::Kind::Sampler: - sb << " -fvk-s-shift"; - break; - case HLSLToVulkanLayoutOptions::Kind::ShaderResource: - sb << " -fvk-t-shift"; - break; - case HLSLToVulkanLayoutOptions::Kind::ConstantBuffer: - sb << " -fvk-b-shift"; - break; - default: - continue; - } - sb << " " << shift << " " << set; - } - break; - case CompilerOptionName::VulkanBindShiftAll: // intValue0: set; intValue1: shift - for (auto v : option.value) - { - sb << " -fvk-all-shift " << v.intValue2 << " " << v.intValue; - } - break; - case CompilerOptionName::VulkanBindGlobals: // intValue0: index; intValue1: set - for (auto v : option.value) - { - sb << " " << name << v.intValue << " " << v.intValue2; - } - break; - case CompilerOptionName::Optimization: - for (auto v : option.value) - { - sb << " -O" << v.intValue; - } - break; - case CompilerOptionName::DownstreamArgs: - for (auto v : option.value) - { - List lines; - StringUtil::split(v.stringValue2.getUnownedSlice(), '\n', lines); - for (auto l : lines) - { - sb << " -x" << v.stringValue << " " << l.trim(); - } - } - break; - case CompilerOptionName::EmitSpirvDirectly: - case CompilerOptionName::GLSLForceScalarLayout: - case CompilerOptionName::ForceDXLayout: - case CompilerOptionName::MatrixLayoutRow: - case CompilerOptionName::MatrixLayoutColumn: - case CompilerOptionName::VulkanInvertY: - case CompilerOptionName::VulkanUseDxPositionW: - case CompilerOptionName::VulkanUseEntryPointName: - case CompilerOptionName::VulkanUseGLLayout: - case CompilerOptionName::VulkanEmitReflection: - case CompilerOptionName::EnableEffectAnnotations: - case CompilerOptionName::DefaultImageFormatUnknown: - case CompilerOptionName::DisableDynamicDispatch: - case CompilerOptionName::DisableSpecialization: - case CompilerOptionName::DumpIntermediates: - if (option.value.getCount() && option.value[0].intValue != 0) - sb << " " << name; - break; + sb << " " << optionInfo.names << " " << v.stringValue; } - } - } - - void CompilerOptionSet::buildHash(DigestBuilder& builder) - { - for (auto& kv : options) - { - builder.append(kv.key); - builder.append(kv.value.getCount()); - for (auto& v : kv.value) + break; + case CompilerOptionName::Include: + for (auto v : option.value) { - if (v.kind == CompilerOptionValueKind::Int) + sb << " -I \"" << v.stringValue << "\""; + } + break; + case CompilerOptionName::MacroDefine: + for (auto v : option.value) + { + sb << " -D" << v.stringValue; + if (v.stringValue2.getLength()) + sb << "=" << v.stringValue2; + } + break; + case CompilerOptionName::VulkanBindShift: // intValue0 (higher 8 bits): kind; + // intValue0(higher bits): set; intValue1: + // shift + for (auto v : option.value) + { + uint8_t kind; + int set, shift; + v.unpackInt3(kind, set, shift); + switch ((HLSLToVulkanLayoutOptions::Kind)(kind)) { - builder.append(v.intValue); + case HLSLToVulkanLayoutOptions::Kind::UnorderedAccess: sb << " -fvk-u-shift"; break; + case HLSLToVulkanLayoutOptions::Kind::Sampler: sb << " -fvk-s-shift"; break; + case HLSLToVulkanLayoutOptions::Kind::ShaderResource: sb << " -fvk-t-shift"; break; + case HLSLToVulkanLayoutOptions::Kind::ConstantBuffer: sb << " -fvk-b-shift"; break; + default: continue; } - else + sb << " " << shift << " " << set; + } + break; + case CompilerOptionName::VulkanBindShiftAll: // intValue0: set; intValue1: shift + for (auto v : option.value) + { + sb << " -fvk-all-shift " << v.intValue2 << " " << v.intValue; + } + break; + case CompilerOptionName::VulkanBindGlobals: // intValue0: index; intValue1: set + for (auto v : option.value) + { + sb << " " << name << v.intValue << " " << v.intValue2; + } + break; + case CompilerOptionName::Optimization: + for (auto v : option.value) + { + sb << " -O" << v.intValue; + } + break; + case CompilerOptionName::DownstreamArgs: + for (auto v : option.value) + { + List lines; + StringUtil::split(v.stringValue2.getUnownedSlice(), '\n', lines); + for (auto l : lines) { - builder.append(v.stringValue); - builder.append(v.stringValue2); + sb << " -x" << v.stringValue << " " << l.trim(); } } + break; + case CompilerOptionName::EmitSpirvDirectly: + case CompilerOptionName::GLSLForceScalarLayout: + case CompilerOptionName::ForceDXLayout: + case CompilerOptionName::MatrixLayoutRow: + case CompilerOptionName::MatrixLayoutColumn: + case CompilerOptionName::VulkanInvertY: + case CompilerOptionName::VulkanUseDxPositionW: + case CompilerOptionName::VulkanUseEntryPointName: + case CompilerOptionName::VulkanUseGLLayout: + case CompilerOptionName::VulkanEmitReflection: + case CompilerOptionName::EnableEffectAnnotations: + case CompilerOptionName::DefaultImageFormatUnknown: + case CompilerOptionName::DisableDynamicDispatch: + case CompilerOptionName::DisableSpecialization: + case CompilerOptionName::DumpIntermediates: + if (option.value.getCount() && option.value[0].intValue != 0) + sb << " " << name; + break; } } +} - bool CompilerOptionSet::allowDuplicate(CompilerOptionName name) - { - switch (name) - { - case CompilerOptionName::Include: - case CompilerOptionName::MacroDefine: - case CompilerOptionName::WarningsAsErrors: - case CompilerOptionName::DisableWarning: - case CompilerOptionName::DisableWarnings: - case CompilerOptionName::EnableWarning: - case CompilerOptionName::Capability: - case CompilerOptionName::DownstreamArgs: - case CompilerOptionName::VulkanBindShift: - case CompilerOptionName::VulkanBindShiftAll: - return true; - } - return false; - } - CompilerOptionValue Slang::CompilerOptionSet::getDefault(CompilerOptionName name) +void CompilerOptionSet::buildHash(DigestBuilder& builder) +{ + for (auto& kv : options) { - switch (name) + builder.append(kv.key); + builder.append(kv.value.getCount()); + for (auto& v : kv.value) { - case CompilerOptionName::Optimization: - return CompilerOptionValue::fromEnum(OptimizationLevel::Default); - default: - return CompilerOptionValue(); + if (v.kind == CompilerOptionValueKind::Int) + { + builder.append(v.intValue); + } + else + { + builder.append(v.stringValue); + builder.append(v.stringValue2); + } } } +} - SlangTargetFlags CompilerOptionSet::getTargetFlags() +bool CompilerOptionSet::allowDuplicate(CompilerOptionName name) +{ + switch (name) { - SlangTargetFlags result = 0; - if (getBoolOption(CompilerOptionName::DumpIr)) - result |= SLANG_TARGET_FLAG_DUMP_IR; - if (getBoolOption(CompilerOptionName::GenerateWholeProgram)) - result |= SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM; - if (!getBoolOption(CompilerOptionName::EmitSpirvViaGLSL)) - result |= SLANG_TARGET_FLAG_GENERATE_SPIRV_DIRECTLY; - if (getBoolOption(CompilerOptionName::ParameterBlocksUseRegisterSpaces)) - result |= SLANG_TARGET_FLAG_PARAMETER_BLOCKS_USE_REGISTER_SPACES; - return result; + case CompilerOptionName::Include: + case CompilerOptionName::MacroDefine: + case CompilerOptionName::WarningsAsErrors: + case CompilerOptionName::DisableWarning: + case CompilerOptionName::DisableWarnings: + case CompilerOptionName::EnableWarning: + case CompilerOptionName::Capability: + case CompilerOptionName::DownstreamArgs: + case CompilerOptionName::VulkanBindShift: + case CompilerOptionName::VulkanBindShiftAll: return true; } - - void CompilerOptionSet::setTargetFlags(SlangTargetFlags flags) + return false; +} +CompilerOptionValue Slang::CompilerOptionSet::getDefault(CompilerOptionName name) +{ + switch (name) { - set(CompilerOptionName::DumpIr, (flags & SLANG_TARGET_FLAG_DUMP_IR) != 0); - set(CompilerOptionName::GenerateWholeProgram, (flags & SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM) != 0); - if ((flags & SLANG_TARGET_FLAG_GENERATE_SPIRV_DIRECTLY) != 0) - set(CompilerOptionName::EmitSpirvViaGLSL, false); - else - set(CompilerOptionName::EmitSpirvViaGLSL, true); - set(CompilerOptionName::ParameterBlocksUseRegisterSpaces, (flags & SLANG_TARGET_FLAG_PARAMETER_BLOCKS_USE_REGISTER_SPACES) != 0); + case CompilerOptionName::Optimization: + return CompilerOptionValue::fromEnum(OptimizationLevel::Default); + default: return CompilerOptionValue(); } +} - void CompilerOptionSet::addTargetFlags(SlangTargetFlags flags) - { - if ((flags & SLANG_TARGET_FLAG_DUMP_IR)) - set(CompilerOptionName::DumpIr, true); +SlangTargetFlags CompilerOptionSet::getTargetFlags() +{ + SlangTargetFlags result = 0; + if (getBoolOption(CompilerOptionName::DumpIr)) + result |= SLANG_TARGET_FLAG_DUMP_IR; + if (getBoolOption(CompilerOptionName::GenerateWholeProgram)) + result |= SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM; + if (!getBoolOption(CompilerOptionName::EmitSpirvViaGLSL)) + result |= SLANG_TARGET_FLAG_GENERATE_SPIRV_DIRECTLY; + if (getBoolOption(CompilerOptionName::ParameterBlocksUseRegisterSpaces)) + result |= SLANG_TARGET_FLAG_PARAMETER_BLOCKS_USE_REGISTER_SPACES; + return result; +} + +void CompilerOptionSet::setTargetFlags(SlangTargetFlags flags) +{ + set(CompilerOptionName::DumpIr, (flags & SLANG_TARGET_FLAG_DUMP_IR) != 0); + set(CompilerOptionName::GenerateWholeProgram, + (flags & SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM) != 0); + if ((flags & SLANG_TARGET_FLAG_GENERATE_SPIRV_DIRECTLY) != 0) + set(CompilerOptionName::EmitSpirvViaGLSL, false); + else + set(CompilerOptionName::EmitSpirvViaGLSL, true); + set(CompilerOptionName::ParameterBlocksUseRegisterSpaces, + (flags & SLANG_TARGET_FLAG_PARAMETER_BLOCKS_USE_REGISTER_SPACES) != 0); +} - if ((flags & SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM) != 0) - set(CompilerOptionName::GenerateWholeProgram, true); +void CompilerOptionSet::addTargetFlags(SlangTargetFlags flags) +{ + if ((flags & SLANG_TARGET_FLAG_DUMP_IR)) + set(CompilerOptionName::DumpIr, true); - if ((flags & SLANG_TARGET_FLAG_GENERATE_SPIRV_DIRECTLY) != 0) - set(CompilerOptionName::EmitSpirvDirectly, true); + if ((flags & SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM) != 0) + set(CompilerOptionName::GenerateWholeProgram, true); - if ((flags & SLANG_TARGET_FLAG_PARAMETER_BLOCKS_USE_REGISTER_SPACES) != 0) - set(CompilerOptionName::ParameterBlocksUseRegisterSpaces, true); - } - MatrixLayoutMode CompilerOptionSet::getMatrixLayoutMode() - { - if (getBoolOption(CompilerOptionName::MatrixLayoutRow)) - return kMatrixLayoutMode_RowMajor; - if (getBoolOption(CompilerOptionName::MatrixLayoutColumn)) - return kMatrixLayoutMode_ColumnMajor; + if ((flags & SLANG_TARGET_FLAG_GENERATE_SPIRV_DIRECTLY) != 0) + set(CompilerOptionName::EmitSpirvDirectly, true); - return (MatrixLayoutMode)kMatrixLayoutMode_RowMajor; - } + if ((flags & SLANG_TARGET_FLAG_PARAMETER_BLOCKS_USE_REGISTER_SPACES) != 0) + set(CompilerOptionName::ParameterBlocksUseRegisterSpaces, true); +} +MatrixLayoutMode CompilerOptionSet::getMatrixLayoutMode() +{ + if (getBoolOption(CompilerOptionName::MatrixLayoutRow)) + return kMatrixLayoutMode_RowMajor; + if (getBoolOption(CompilerOptionName::MatrixLayoutColumn)) + return kMatrixLayoutMode_ColumnMajor; - void CompilerOptionSet::setMatrixLayoutMode(MatrixLayoutMode mode) - { - options.remove(CompilerOptionName::MatrixLayoutColumn); - options.remove(CompilerOptionName::MatrixLayoutRow); - if (mode == kMatrixLayoutMode_ColumnMajor) - set(CompilerOptionName::MatrixLayoutColumn, true); - if (mode == kMatrixLayoutMode_RowMajor) - set(CompilerOptionName::MatrixLayoutRow, true); - } + return (MatrixLayoutMode)kMatrixLayoutMode_RowMajor; +} - Profile CompilerOptionSet::getProfile() - { - if (auto profileRaw = getEnumOption(CompilerOptionName::Profile)) - return Profile(profileRaw); - return Profile(); - } +void CompilerOptionSet::setMatrixLayoutMode(MatrixLayoutMode mode) +{ + options.remove(CompilerOptionName::MatrixLayoutColumn); + options.remove(CompilerOptionName::MatrixLayoutRow); + if (mode == kMatrixLayoutMode_ColumnMajor) + set(CompilerOptionName::MatrixLayoutColumn, true); + if (mode == kMatrixLayoutMode_RowMajor) + set(CompilerOptionName::MatrixLayoutRow, true); +} - void CompilerOptionSet::setProfile(Profile profile) - { - set(CompilerOptionName::Profile, (int)profile.raw); - } +Profile CompilerOptionSet::getProfile() +{ + if (auto profileRaw = getEnumOption(CompilerOptionName::Profile)) + return Profile(profileRaw); + return Profile(); +} - ProfileVersion CompilerOptionSet::getProfileVersion() - { - if (auto profileRaw = getEnumOption(CompilerOptionName::Profile)) - return Profile(profileRaw).getVersion(); - return ProfileVersion::Unknown; - } +void CompilerOptionSet::setProfile(Profile profile) +{ + set(CompilerOptionName::Profile, (int)profile.raw); +} - void CompilerOptionSet::setProfileVersion(ProfileVersion version) - { - Profile profile; - if (auto profileRaw = getEnumOption(CompilerOptionName::Profile)) - profile = Profile(profileRaw); - profile.setVersion(version); - set(CompilerOptionName::Profile, (int)profile.raw); - } +ProfileVersion CompilerOptionSet::getProfileVersion() +{ + if (auto profileRaw = getEnumOption(CompilerOptionName::Profile)) + return Profile(profileRaw).getVersion(); + return ProfileVersion::Unknown; +} - void CompilerOptionSet::addCapabilityAtom(CapabilityName cap) - { - add(CompilerOptionName::Capability, cap); - } +void CompilerOptionSet::setProfileVersion(ProfileVersion version) +{ + Profile profile; + if (auto profileRaw = getEnumOption(CompilerOptionName::Profile)) + profile = Profile(profileRaw); + profile.setVersion(version); + set(CompilerOptionName::Profile, (int)profile.raw); +} - List CompilerOptionSet::getDownstreamArgs(String downstreamToolName) +void CompilerOptionSet::addCapabilityAtom(CapabilityName cap) +{ + add(CompilerOptionName::Capability, cap); +} + +List CompilerOptionSet::getDownstreamArgs(String downstreamToolName) +{ + List result; + auto downstreamArgsArray = getArray(CompilerOptionName::DownstreamArgs); + for (auto& argSet : downstreamArgsArray) { - List result; - auto downstreamArgsArray = getArray(CompilerOptionName::DownstreamArgs); - for (auto& argSet : downstreamArgsArray) + if (argSet.stringValue == downstreamToolName) { - if (argSet.stringValue == downstreamToolName) - { - CommandLineArgs args; - args.deserialize(argSet.stringValue2); - for (auto arg : args.m_args) - result.add(arg.value); - break; - } + CommandLineArgs args; + args.deserialize(argSet.stringValue2); + for (auto arg : args.m_args) + result.add(arg.value); + break; } - return result; } + return result; +} - void CompilerOptionSet::serialize(SerializedOptionsData* outData) +void CompilerOptionSet::serialize(SerializedOptionsData* outData) +{ + for (auto& option : options) { - for (auto& option : options) + for (auto val : option.value) { - for (auto val : option.value) - { - slang::CompilerOptionEntry entry = {}; - entry.name = option.key; - entry.value.kind = val.kind; - entry.value.intValue0 = val.intValue; - entry.value.intValue1 = val.intValue2; - outData->stringPool.add(val.stringValue); - entry.value.stringValue0 = val.stringValue.getBuffer(); - outData->stringPool.add(val.stringValue2); - entry.value.stringValue1 = val.stringValue.getBuffer(); - outData->entries.add(entry); - } + slang::CompilerOptionEntry entry = {}; + entry.name = option.key; + entry.value.kind = val.kind; + entry.value.intValue0 = val.intValue; + entry.value.intValue1 = val.intValue2; + outData->stringPool.add(val.stringValue); + entry.value.stringValue0 = val.stringValue.getBuffer(); + outData->stringPool.add(val.stringValue2); + entry.value.stringValue1 = val.stringValue.getBuffer(); + outData->entries.add(entry); } } +} - void applySettingsToDiagnosticSink(DiagnosticSink* targetSink, DiagnosticSink* outputSink, CompilerOptionSet& options) +void applySettingsToDiagnosticSink( + DiagnosticSink* targetSink, + DiagnosticSink* outputSink, + CompilerOptionSet& options) +{ + auto disableArray = options.getArray(CompilerOptionName::DisableWarning); + for (auto& element : disableArray) { - auto disableArray = options.getArray(CompilerOptionName::DisableWarning); - for (auto& element : disableArray) - { - overrideDiagnostic(targetSink, outputSink, element.stringValue.getUnownedSlice(), Severity::Warning, Severity::Disable); - } - disableArray = options.getArray(CompilerOptionName::DisableWarnings); - for (auto& element : disableArray) - { - overrideDiagnostics(targetSink, outputSink, element.stringValue.getUnownedSlice(), Severity::Warning, Severity::Disable); - } - auto enableArray = options.getArray(CompilerOptionName::EnableWarning); - for (auto& element : enableArray) - { - overrideDiagnostics(targetSink, outputSink, element.stringValue.getUnownedSlice(), Severity::Warning, Severity::Warning); - } - auto warningsAsErrorsArray = options.getArray(CompilerOptionName::WarningsAsErrors); - for (auto& element : warningsAsErrorsArray) - { - if (element.stringValue == "all") - targetSink->setFlag(DiagnosticSink::Flag::TreatWarningsAsErrors); - else - overrideDiagnostics(targetSink, outputSink, element.stringValue.getUnownedSlice(), Severity::Warning, Severity::Error); - } + overrideDiagnostic( + targetSink, + outputSink, + element.stringValue.getUnownedSlice(), + Severity::Warning, + Severity::Disable); + } + disableArray = options.getArray(CompilerOptionName::DisableWarnings); + for (auto& element : disableArray) + { + overrideDiagnostics( + targetSink, + outputSink, + element.stringValue.getUnownedSlice(), + Severity::Warning, + Severity::Disable); + } + auto enableArray = options.getArray(CompilerOptionName::EnableWarning); + for (auto& element : enableArray) + { + overrideDiagnostics( + targetSink, + outputSink, + element.stringValue.getUnownedSlice(), + Severity::Warning, + Severity::Warning); + } + auto warningsAsErrorsArray = options.getArray(CompilerOptionName::WarningsAsErrors); + for (auto& element : warningsAsErrorsArray) + { + if (element.stringValue == "all") + targetSink->setFlag(DiagnosticSink::Flag::TreatWarningsAsErrors); + else + overrideDiagnostics( + targetSink, + outputSink, + element.stringValue.getUnownedSlice(), + Severity::Warning, + Severity::Error); } } +} // namespace Slang diff --git a/source/slang/slang-compiler-options.h b/source/slang/slang-compiler-options.h index f2bf467e2..3c1c76816 100644 --- a/source/slang/slang-compiler-options.h +++ b/source/slang/slang-compiler-options.h @@ -1,414 +1,400 @@ #ifndef SLANG_COMPILER_OPTIONS_H #define SLANG_COMPILER_OPTIONS_H -#include "slang.h" #include "../core/slang-basic.h" #include "../core/slang-crypto.h" #include "slang-generated-capability-defs.h" #include "slang-profile.h" +#include "slang.h" namespace Slang { - using slang::CompilerOptionName; - using slang::CompilerOptionValueKind; - enum MatrixLayoutMode : SlangMatrixLayoutModeIntegral; - enum class LineDirectiveMode : SlangLineDirectiveModeIntegral; - enum class FloatingPointMode : SlangFloatingPointModeIntegral; - enum class OptimizationLevel : SlangOptimizationLevelIntegral; - enum class DebugInfoLevel : SlangDebugInfoLevelIntegral; - enum class CodeGenTarget : SlangCompileTargetIntegral; - - struct CompilerOptionValue - { - CompilerOptionValueKind kind = CompilerOptionValueKind::Int; - int intValue = 0; - int intValue2 = 0; - String stringValue; - String stringValue2; - - template - static CompilerOptionValue fromEnum(T val) - { - static_assert(std::is_enum::value); - CompilerOptionValue value; - value.intValue = (int)val; - value.kind = CompilerOptionValueKind::Int; - return value; - } - - static CompilerOptionValue fromInt(int val) - { - CompilerOptionValue value; - value.intValue = val; - value.kind = CompilerOptionValueKind::Int; - return value; - } - - static CompilerOptionValue fromInt2(int val, int val2) - { - CompilerOptionValue value; - value.intValue = val; - value.intValue2 = val2; - value.kind = CompilerOptionValueKind::Int; - return value; - } +using slang::CompilerOptionName; +using slang::CompilerOptionValueKind; +enum MatrixLayoutMode : SlangMatrixLayoutModeIntegral; +enum class LineDirectiveMode : SlangLineDirectiveModeIntegral; +enum class FloatingPointMode : SlangFloatingPointModeIntegral; +enum class OptimizationLevel : SlangOptimizationLevelIntegral; +enum class DebugInfoLevel : SlangDebugInfoLevelIntegral; +enum class CodeGenTarget : SlangCompileTargetIntegral; + +struct CompilerOptionValue +{ + CompilerOptionValueKind kind = CompilerOptionValueKind::Int; + int intValue = 0; + int intValue2 = 0; + String stringValue; + String stringValue2; + + template + static CompilerOptionValue fromEnum(T val) + { + static_assert(std::is_enum::value); + CompilerOptionValue value; + value.intValue = (int)val; + value.kind = CompilerOptionValueKind::Int; + return value; + } + + static CompilerOptionValue fromInt(int val) + { + CompilerOptionValue value; + value.intValue = val; + value.kind = CompilerOptionValueKind::Int; + return value; + } - void unpackInt3(uint8_t& v0, int& v1, int& v2) - { - v0 = intValue >> 24; - v1 = intValue & 0xFFFFFF; - v2 = intValue2; - } + static CompilerOptionValue fromInt2(int val, int val2) + { + CompilerOptionValue value; + value.intValue = val; + value.intValue2 = val2; + value.kind = CompilerOptionValueKind::Int; + return value; + } + + void unpackInt3(uint8_t& v0, int& v1, int& v2) + { + v0 = intValue >> 24; + v1 = intValue & 0xFFFFFF; + v2 = intValue2; + } - static CompilerOptionValue fromInt3(uint8_t v0, int v1, int v2) - { - CompilerOptionValue value; - value.intValue = (v0 << 24) + (v1 & 0xFFFFFF); - value.intValue2 = v2; - value.kind = CompilerOptionValueKind::Int; - return value; - } + static CompilerOptionValue fromInt3(uint8_t v0, int v1, int v2) + { + CompilerOptionValue value; + value.intValue = (v0 << 24) + (v1 & 0xFFFFFF); + value.intValue2 = v2; + value.kind = CompilerOptionValueKind::Int; + return value; + } + + static CompilerOptionValue fromString(String val) + { + CompilerOptionValue value; + value.stringValue = val; + value.kind = CompilerOptionValueKind::String; + return value; + } +}; + +struct SerializedOptionsData +{ + List entries; + List stringPool; +}; - static CompilerOptionValue fromString(String val) - { - CompilerOptionValue value; - value.stringValue = val; - value.kind = CompilerOptionValueKind::String; - return value; - } - }; +class Session; - struct SerializedOptionsData - { - List entries; - List stringPool; - }; - - class Session; +struct CompilerOptionSet +{ + void load(uint32_t count, slang::CompilerOptionEntry* entries); - struct CompilerOptionSet - { - void load(uint32_t count, slang::CompilerOptionEntry* entries); + void buildHash(DigestBuilder& builder); - void buildHash(DigestBuilder& builder); + static bool allowDuplicate(CompilerOptionName name); - static bool allowDuplicate(CompilerOptionName name); + void writeCommandLineArgs(Session* globalSession, StringBuilder& sb); - void writeCommandLineArgs(Session* globalSession, StringBuilder& sb); + OrderedDictionary> options; - OrderedDictionary> options; + bool hasOption(CompilerOptionName name) { return options.containsKey(name); } - bool hasOption(CompilerOptionName name) + void set(CompilerOptionName name, CompilerOptionValue value) + { + if (auto v = options.tryGetValue(name)) { - return options.containsKey(name); + v->clear(); + v->add(value); + return; } + options[name] = List{value}; + } - void set(CompilerOptionName name, CompilerOptionValue value) + void set(CompilerOptionName name, const List& value) + { + if (auto v = options.tryGetValue(name)) { - if (auto v = options.tryGetValue(name)) - { - v->clear(); - v->add(value); - return; - } - options[name] = List{ value }; + v->clear(); + v->addRange(value); + return; } + options[name] = List{value}; + } - void set(CompilerOptionName name, const List& value) + void add(CompilerOptionName name, CompilerOptionValue value) + { + if (auto v = options.tryGetValue(name)) { - if (auto v = options.tryGetValue(name)) - { - v->clear(); - v->addRange(value); - return; - } - options[name] = List{ value }; + v->add(value); + return; } + options[name] = List{value}; + } - void add(CompilerOptionName name, CompilerOptionValue value) - { - if (auto v = options.tryGetValue(name)) - { - v->add(value); - return; - } - options[name] = List{ value }; - } - - void add(CompilerOptionName name, const List& value, bool replaceDuplicate = true) + void add( + CompilerOptionName name, + const List& value, + bool replaceDuplicate = true) + { + if (auto v = options.tryGetValue(name)) { - if (auto v = options.tryGetValue(name)) + for (auto element : value) { - for (auto element : value) - { - Index index = v->findFirstIndex([&](const CompilerOptionValue& existingVal) - { - if (existingVal.kind == CompilerOptionValueKind::Int) - return existingVal.intValue == element.intValue; - else - return existingVal.stringValue == element.stringValue; - }); - if (index != -1) + Index index = v->findFirstIndex( + [&](const CompilerOptionValue& existingVal) { - if (replaceDuplicate) - { - (*v)[index].intValue2 = element.intValue; - (*v)[index].stringValue2 = element.stringValue2; - } - } - else + if (existingVal.kind == CompilerOptionValueKind::Int) + return existingVal.intValue == element.intValue; + else + return existingVal.stringValue == element.stringValue; + }); + if (index != -1) + { + if (replaceDuplicate) { - v->add(element); + (*v)[index].intValue2 = element.intValue; + (*v)[index].stringValue2 = element.stringValue2; } } - return; + else + { + v->add(element); + } } - options[name] = List{ value }; + return; } + options[name] = List{value}; + } - // Copy settings from other, and replace the current setting. - void overrideWith(const CompilerOptionSet& other) + // Copy settings from other, and replace the current setting. + void overrideWith(const CompilerOptionSet& other) + { + for (auto& kv : other.options) { - for (auto& kv : other.options) - { - if (allowDuplicate(kv.key)) - add(kv.key, kv.value, true); - else - set(kv.key, kv.value); - } + if (allowDuplicate(kv.key)) + add(kv.key, kv.value, true); + else + set(kv.key, kv.value); } + } - // Copy settings from other, but do not replace the current setting - void inheritFrom(const CompilerOptionSet& other) + // Copy settings from other, but do not replace the current setting + void inheritFrom(const CompilerOptionSet& other) + { + for (auto& kv : other.options) { - for (auto& kv : other.options) + if (allowDuplicate(kv.key)) + add(kv.key, kv.value, false); + else { - if (allowDuplicate(kv.key)) - add(kv.key, kv.value, false); - else - { - if (options.containsKey(kv.key)) - continue; - set(kv.key, kv.value); - } + if (options.containsKey(kv.key)) + continue; + set(kv.key, kv.value); } } + } - void add(CompilerOptionName name, int intVal) - { - add(name, CompilerOptionValue::fromInt(intVal)); - } - void add(CompilerOptionName name, int intVal, int intVal2) - { - add(name, CompilerOptionValue::fromInt2(intVal, intVal2)); - } - void add(CompilerOptionName name, uint8_t intVal, int intVal2, int intVal3) - { - add(name, CompilerOptionValue::fromInt3(intVal, intVal2, intVal3)); - } - void add(CompilerOptionName name, String stringVal) - { - add(name, CompilerOptionValue::fromString(stringVal)); - } - void add(CompilerOptionName name, UnownedStringSlice stringVal) - { - add(name, CompilerOptionValue::fromString(stringVal)); - } - void add(CompilerOptionName name, bool boolVal) - { - add(name, CompilerOptionValue::fromInt(boolVal ? 1 : 0)); - } + void add(CompilerOptionName name, int intVal) + { + add(name, CompilerOptionValue::fromInt(intVal)); + } + void add(CompilerOptionName name, int intVal, int intVal2) + { + add(name, CompilerOptionValue::fromInt2(intVal, intVal2)); + } + void add(CompilerOptionName name, uint8_t intVal, int intVal2, int intVal3) + { + add(name, CompilerOptionValue::fromInt3(intVal, intVal2, intVal3)); + } + void add(CompilerOptionName name, String stringVal) + { + add(name, CompilerOptionValue::fromString(stringVal)); + } + void add(CompilerOptionName name, UnownedStringSlice stringVal) + { + add(name, CompilerOptionValue::fromString(stringVal)); + } + void add(CompilerOptionName name, bool boolVal) + { + add(name, CompilerOptionValue::fromInt(boolVal ? 1 : 0)); + } - template - void add(CompilerOptionName name, EnumType enumVal) - { - static_assert(std::is_enum::value); - add(name, (int)enumVal); - } + template + void add(CompilerOptionName name, EnumType enumVal) + { + static_assert(std::is_enum::value); + add(name, (int)enumVal); + } - void set(CompilerOptionName name, int intVal) - { - set(name, CompilerOptionValue::fromInt(intVal)); - } - void set(CompilerOptionName name, int intVal1, int intVal2) - { - set(name, CompilerOptionValue::fromInt2(intVal1, intVal2)); - } - void set(CompilerOptionName name, uint8_t intVal1, int intVal2, int intVal3) - { - set(name, CompilerOptionValue::fromInt3(intVal1, intVal2, intVal3)); - } - void set(CompilerOptionName name, String stringVal) - { - set(name, CompilerOptionValue::fromString(stringVal)); - } - void set(CompilerOptionName name, bool boolVal) - { - set(name, CompilerOptionValue::fromInt(boolVal ? 1 : 0)); - } + void set(CompilerOptionName name, int intVal) + { + set(name, CompilerOptionValue::fromInt(intVal)); + } + void set(CompilerOptionName name, int intVal1, int intVal2) + { + set(name, CompilerOptionValue::fromInt2(intVal1, intVal2)); + } + void set(CompilerOptionName name, uint8_t intVal1, int intVal2, int intVal3) + { + set(name, CompilerOptionValue::fromInt3(intVal1, intVal2, intVal3)); + } + void set(CompilerOptionName name, String stringVal) + { + set(name, CompilerOptionValue::fromString(stringVal)); + } + void set(CompilerOptionName name, bool boolVal) + { + set(name, CompilerOptionValue::fromInt(boolVal ? 1 : 0)); + } - template - void set(CompilerOptionName name, EnumType enumVal) - { - static_assert(std::is_enum::value); - set(name, (int)enumVal); - } + template + void set(CompilerOptionName name, EnumType enumVal) + { + static_assert(std::is_enum::value); + set(name, (int)enumVal); + } - static CompilerOptionValue getDefault(CompilerOptionName name); - bool getBoolOption(CompilerOptionName name) + static CompilerOptionValue getDefault(CompilerOptionName name); + bool getBoolOption(CompilerOptionName name) + { + if (auto result = options.tryGetValue(name)) { - if (auto result = options.tryGetValue(name)) - { - SLANG_ASSERT(result->getCount() != 0 && (*result)[0].kind == CompilerOptionValueKind::Int); - return result->getCount() != 0 && (*result)[0].intValue != 0; - } - return getDefault(name).intValue != 0; + SLANG_ASSERT( + result->getCount() != 0 && (*result)[0].kind == CompilerOptionValueKind::Int); + return result->getCount() != 0 && (*result)[0].intValue != 0; } - int getIntOption(CompilerOptionName name) + return getDefault(name).intValue != 0; + } + int getIntOption(CompilerOptionName name) + { + if (auto result = options.tryGetValue(name)) { - if (auto result = options.tryGetValue(name)) - { - SLANG_ASSERT(result->getCount() != 0 && (*result)[0].kind == CompilerOptionValueKind::Int); - return (*result)[0].intValue; - } - return getDefault(name).intValue != 0; + SLANG_ASSERT( + result->getCount() != 0 && (*result)[0].kind == CompilerOptionValueKind::Int); + return (*result)[0].intValue; } - String getStringOption(CompilerOptionName name) + return getDefault(name).intValue != 0; + } + String getStringOption(CompilerOptionName name) + { + if (auto result = options.tryGetValue(name)) { - if (auto result = options.tryGetValue(name)) - { - SLANG_ASSERT(result->getCount() != 0 && (*result)[0].kind == CompilerOptionValueKind::String); - return (*result)[0].stringValue; - } - return getDefault(name).stringValue; + SLANG_ASSERT( + result->getCount() != 0 && (*result)[0].kind == CompilerOptionValueKind::String); + return (*result)[0].stringValue; } + return getDefault(name).stringValue; + } - template - EnumType getEnumOption(CompilerOptionName name) - { - static_assert(std::is_enum::value); - return (EnumType)getIntOption(name); - } - ArrayView getArray(CompilerOptionName name) + template + EnumType getEnumOption(CompilerOptionName name) + { + static_assert(std::is_enum::value); + return (EnumType)getIntOption(name); + } + ArrayView getArray(CompilerOptionName name) + { + if (auto result = options.tryGetValue(name)) { - if (auto result = options.tryGetValue(name)) - { - return result->getArrayView(); - } - return ArrayView(); + return result->getArrayView(); } + return ArrayView(); + } - CodeGenTarget getTarget() - { - return getEnumOption(CompilerOptionName::Target); - } + CodeGenTarget getTarget() { return getEnumOption(CompilerOptionName::Target); } - SlangTargetFlags getTargetFlags(); - void setTargetFlags(SlangTargetFlags flags); - void addTargetFlags(SlangTargetFlags flags); + SlangTargetFlags getTargetFlags(); + void setTargetFlags(SlangTargetFlags flags); + void addTargetFlags(SlangTargetFlags flags); - MatrixLayoutMode getMatrixLayoutMode(); + MatrixLayoutMode getMatrixLayoutMode(); - void setMatrixLayoutMode(MatrixLayoutMode mode); + void setMatrixLayoutMode(MatrixLayoutMode mode); - ProfileVersion getProfileVersion(); + ProfileVersion getProfileVersion(); - Profile getProfile(); - void setProfile(Profile profile); + Profile getProfile(); + void setProfile(Profile profile); - void setProfileVersion(ProfileVersion version); + void setProfileVersion(ProfileVersion version); - void addCapabilityAtom(CapabilityName cap); + void addCapabilityAtom(CapabilityName cap); - void addPreprocessorDefine(String name, String value) - { - CompilerOptionValue v; - v.stringValue = name; - v.stringValue2 = value; - v.kind = CompilerOptionValueKind::String; - add(CompilerOptionName::MacroDefine, v); - } + void addPreprocessorDefine(String name, String value) + { + CompilerOptionValue v; + v.stringValue = name; + v.stringValue2 = value; + v.kind = CompilerOptionValueKind::String; + add(CompilerOptionName::MacroDefine, v); + } - void addSearchPath(String path) - { - add(CompilerOptionName::Include, String(path)); - } + void addSearchPath(String path) { add(CompilerOptionName::Include, String(path)); } - bool shouldEmitSPIRVDirectly() - { - if (getBoolOption(CompilerOptionName::EmitSpirvViaGLSL)) - return false; - return true; - } + bool shouldEmitSPIRVDirectly() + { + if (getBoolOption(CompilerOptionName::EmitSpirvViaGLSL)) + return false; + return true; + } - bool shouldUseScalarLayout() - { - return getBoolOption(CompilerOptionName::GLSLForceScalarLayout); - } + bool shouldUseScalarLayout() + { + return getBoolOption(CompilerOptionName::GLSLForceScalarLayout); + } - bool shouldUseDXLayout() - { - return getBoolOption(CompilerOptionName::ForceDXLayout); - } + bool shouldUseDXLayout() { return getBoolOption(CompilerOptionName::ForceDXLayout); } - bool shouldDumpIntermediates() - { - return getBoolOption(CompilerOptionName::DumpIntermediates); - } + bool shouldDumpIntermediates() { return getBoolOption(CompilerOptionName::DumpIntermediates); } - bool shouldDumpIR() - { - return getBoolOption(CompilerOptionName::DumpIr); - } + bool shouldDumpIR() { return getBoolOption(CompilerOptionName::DumpIr); } - bool shouldObfuscateCode() - { - return getBoolOption(CompilerOptionName::Obfuscate); - } + bool shouldObfuscateCode() { return getBoolOption(CompilerOptionName::Obfuscate); } - bool shouldPerformMinimumOptimizations() - { - return getBoolOption(CompilerOptionName::MinimumSlangOptimization); - } + bool shouldPerformMinimumOptimizations() + { + return getBoolOption(CompilerOptionName::MinimumSlangOptimization); + } - bool shouldRunNonEssentialValidation() - { - return !getBoolOption(CompilerOptionName::DisableNonEssentialValidations); - } + bool shouldRunNonEssentialValidation() + { + return !getBoolOption(CompilerOptionName::DisableNonEssentialValidations); + } - bool shouldHaveSourceMap() - { - return !getBoolOption(CompilerOptionName::DisableSourceMap); - } + bool shouldHaveSourceMap() { return !getBoolOption(CompilerOptionName::DisableSourceMap); } - FloatingPointMode getFloatingPointMode() - { - return getEnumOption(CompilerOptionName::FloatingPointMode); - } + FloatingPointMode getFloatingPointMode() + { + return getEnumOption(CompilerOptionName::FloatingPointMode); + } - LineDirectiveMode getLineDirectiveMode() - { - return getEnumOption(CompilerOptionName::LineDirectiveMode); - } + LineDirectiveMode getLineDirectiveMode() + { + return getEnumOption(CompilerOptionName::LineDirectiveMode); + } - OptimizationLevel getOptimizationLevel() - { - return getEnumOption(CompilerOptionName::Optimization); - } + OptimizationLevel getOptimizationLevel() + { + return getEnumOption(CompilerOptionName::Optimization); + } - DebugInfoLevel getDebugInfoLevel() - { - return getEnumOption(CompilerOptionName::DebugInformation); - } + DebugInfoLevel getDebugInfoLevel() + { + return getEnumOption(CompilerOptionName::DebugInformation); + } - List getDownstreamArgs(String downstreamToolName); + List getDownstreamArgs(String downstreamToolName); - void serialize(SerializedOptionsData* outData); - }; + void serialize(SerializedOptionsData* outData); +}; - class DiagnosticSink; - void applySettingsToDiagnosticSink(DiagnosticSink* targetSink, DiagnosticSink* outputSink, CompilerOptionSet& options); +class DiagnosticSink; +void applySettingsToDiagnosticSink( + DiagnosticSink* targetSink, + DiagnosticSink* outputSink, + CompilerOptionSet& options); -} +} // namespace Slang #endif diff --git a/source/slang/slang-compiler-tu.cpp b/source/slang/slang-compiler-tu.cpp index 5f88e871d..4a74ca64c 100644 --- a/source/slang/slang-compiler-tu.cpp +++ b/source/slang/slang-compiler-tu.cpp @@ -2,298 +2,294 @@ // and emit precompiled blobs into IR #include "../core/slang-basic.h" +#include "slang-capability.h" +#include "slang-check-impl.h" #include "slang-compiler.h" #include "slang-ir-insts.h" #include "slang-ir-util.h" -#include "slang-capability.h" -#include "slang-check-impl.h" namespace Slang { - // Only attempt to precompile functions: - // 1) With function bodies (not just empty decls) - // 2) Not marked with unsafeForceInlineDecoration - // 3) Have a simple HLSL data type as the return or parameter type - static bool attemptPrecompiledExport(IRInst* inst) +// Only attempt to precompile functions: +// 1) With function bodies (not just empty decls) +// 2) Not marked with unsafeForceInlineDecoration +// 3) Have a simple HLSL data type as the return or parameter type +static bool attemptPrecompiledExport(IRInst* inst) +{ + if (inst->getOp() != kIROp_Func) { - if (inst->getOp() != kIROp_Func) - { - return false; - } - - // Skip functions with no body - bool hasBody = false; - for (auto child : inst->getChildren()) - { - if (child->getOp() == kIROp_Block) - { - hasBody = true; - break; - } - } - if (!hasBody) - { - return false; - } - - // Skip functions marked with unsafeForceInlineDecoration - if (inst->findDecoration()) - { - return false; - } + return false; + } - // Skip non-simple HLSL data types, filters out generics - if (!isSimpleHLSLDataType(inst)) + // Skip functions with no body + bool hasBody = false; + for (auto child : inst->getChildren()) + { + if (child->getOp() == kIROp_Block) { - return false; + hasBody = true; + break; } + } + if (!hasBody) + { + return false; + } - return true; + // Skip functions marked with unsafeForceInlineDecoration + if (inst->findDecoration()) + { + return false; } - /* - * Precompile the module for the given target. - * - * This function creates a target program and emits the precompiled blob as - * an embedded blob in the module IR, e.g. DXIL, SPIR-V. - * Because the IR for the Slang Module may violate the restrictions of the - * target language, the emitted target blob may not be able to include the - * full module, but rather only the subset that can be precompiled. For - * example, DXIL libraries do not allow resources like structured buffers - * to appear in the library interface. Also, no target languages allow - * generics to be precompiled. - * - * Some restrictions can be enforced up front before linking, but some are - * done during target generation in between IR linking+legalization and - * target source emission. - * - * Functions which can be rejected up front: - * - Functions with no body - * - Functions marked with unsafeForceInlineDecoration - * - Functions that define or use generics - * - * The functions not rejected up front are marked with - * DownstreamModuleExportDecoration which indicates functions we're trying to - * export for precompilation, and this also helps to identify the functions - * in the linked IR which survived the additional pruning. - * - * Functions that are rejected after linking+legalization (inside - * emitPrecompiledDownstreamIR): - * - (DXIL) Functions that return or take a HLSLStructuredBufferType - * - (DXIL) Functions that return or take a Matrix type - * - * emitPrecompiled* produces the output artifact containing target language - * blob, and as metadata, the list of functions which survived the second - * phase of filtering. - * - * The original module IR functions matching those are then marked with - * "AvailableInDownstreamIRDecoration" to indicate to future - * module users which functions are present in the precompiled blob. - */ - SLANG_NO_THROW SlangResult SLANG_MCALL Module::precompileForTarget( - SlangCompileTarget target, - slang::IBlob** outDiagnostics) + // Skip non-simple HLSL data types, filters out generics + if (!isSimpleHLSLDataType(inst)) { - CodeGenTarget targetEnum = CodeGenTarget(target); + return false; + } + + return true; +} - auto module = getIRModule(); - auto linkage = getLinkage(); - auto builder = IRBuilder(module); +/* + * Precompile the module for the given target. + * + * This function creates a target program and emits the precompiled blob as + * an embedded blob in the module IR, e.g. DXIL, SPIR-V. + * Because the IR for the Slang Module may violate the restrictions of the + * target language, the emitted target blob may not be able to include the + * full module, but rather only the subset that can be precompiled. For + * example, DXIL libraries do not allow resources like structured buffers + * to appear in the library interface. Also, no target languages allow + * generics to be precompiled. + * + * Some restrictions can be enforced up front before linking, but some are + * done during target generation in between IR linking+legalization and + * target source emission. + * + * Functions which can be rejected up front: + * - Functions with no body + * - Functions marked with unsafeForceInlineDecoration + * - Functions that define or use generics + * + * The functions not rejected up front are marked with + * DownstreamModuleExportDecoration which indicates functions we're trying to + * export for precompilation, and this also helps to identify the functions + * in the linked IR which survived the additional pruning. + * + * Functions that are rejected after linking+legalization (inside + * emitPrecompiledDownstreamIR): + * - (DXIL) Functions that return or take a HLSLStructuredBufferType + * - (DXIL) Functions that return or take a Matrix type + * + * emitPrecompiled* produces the output artifact containing target language + * blob, and as metadata, the list of functions which survived the second + * phase of filtering. + * + * The original module IR functions matching those are then marked with + * "AvailableInDownstreamIRDecoration" to indicate to future + * module users which functions are present in the precompiled blob. + */ +SLANG_NO_THROW SlangResult SLANG_MCALL +Module::precompileForTarget(SlangCompileTarget target, slang::IBlob** outDiagnostics) +{ + CodeGenTarget targetEnum = CodeGenTarget(target); - DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); - applySettingsToDiagnosticSink(&sink, &sink, linkage->m_optionSet); - applySettingsToDiagnosticSink(&sink, &sink, m_optionSet); + auto module = getIRModule(); + auto linkage = getLinkage(); + auto builder = IRBuilder(module); - RefPtr targetReq = new TargetRequest(linkage, targetEnum); + DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); + applySettingsToDiagnosticSink(&sink, &sink, linkage->m_optionSet); + applySettingsToDiagnosticSink(&sink, &sink, m_optionSet); - List> allComponentTypes; - allComponentTypes.add(this); // Add Module as a component type + RefPtr targetReq = new TargetRequest(linkage, targetEnum); - for (auto entryPoint : this->getEntryPoints()) - { - allComponentTypes.add(entryPoint); // Add the entry point as a component type - } + List> allComponentTypes; + allComponentTypes.add(this); // Add Module as a component type - auto composite = CompositeComponentType::create( - linkage, - allComponentTypes); + for (auto entryPoint : this->getEntryPoints()) + { + allComponentTypes.add(entryPoint); // Add the entry point as a component type + } - composite = fillRequirements(composite); + auto composite = CompositeComponentType::create(linkage, allComponentTypes); - TargetProgram tp(composite, targetReq); - tp.getOrCreateLayout(&sink); - Slang::Index const entryPointCount = m_entryPoints.getCount(); - tp.getOptionSet().add(CompilerOptionName::GenerateWholeProgram, true); + composite = fillRequirements(composite); - switch (targetReq->getTarget()) - { - case CodeGenTarget::DXIL: - tp.getOptionSet().add(CompilerOptionName::Profile, Profile::RawEnum::DX_Lib_6_6); - break; - case CodeGenTarget::SPIRV: - break; - default: - return SLANG_FAIL; - } + TargetProgram tp(composite, targetReq); + tp.getOrCreateLayout(&sink); + Slang::Index const entryPointCount = m_entryPoints.getCount(); + tp.getOptionSet().add(CompilerOptionName::GenerateWholeProgram, true); - tp.getOptionSet().add(CompilerOptionName::EmbedDownstreamIR, true); + switch (targetReq->getTarget()) + { + case CodeGenTarget::DXIL: + tp.getOptionSet().add(CompilerOptionName::Profile, Profile::RawEnum::DX_Lib_6_6); + break; + case CodeGenTarget::SPIRV: break; + default: return SLANG_FAIL; + } - CodeGenContext::EntryPointIndices entryPointIndices; + tp.getOptionSet().add(CompilerOptionName::EmbedDownstreamIR, true); - entryPointIndices.setCount(entryPointCount); - for (Index i = 0; i < entryPointCount; i++) - entryPointIndices[i] = i; - CodeGenContext::Shared sharedCodeGenContext(&tp, entryPointIndices, &sink, nullptr); - CodeGenContext codeGenContext(&sharedCodeGenContext); + CodeGenContext::EntryPointIndices entryPointIndices; - // Mark all public functions as exported, ensure there's at least one. Store a mapping - // of function name to IRInst* for later reference. After linking is done, we'll scan - // the linked result to see which functions survived the pruning and are included in the - // precompiled blob. - Dictionary nameToFunction; - bool hasAtLeastOneFunction = false; - for (auto inst : module->getGlobalInsts()) - { - if (attemptPrecompiledExport(inst)) - { - hasAtLeastOneFunction = true; - builder.addDecoration(inst, kIROp_DownstreamModuleExportDecoration); - nameToFunction[inst->findDecoration()->getMangledName()] = inst; - } - } + entryPointIndices.setCount(entryPointCount); + for (Index i = 0; i < entryPointCount; i++) + entryPointIndices[i] = i; + CodeGenContext::Shared sharedCodeGenContext(&tp, entryPointIndices, &sink, nullptr); + CodeGenContext codeGenContext(&sharedCodeGenContext); - // Bail if there are no functions to export. That's not treated as an error - // because it's possible that the module just doesn't have any simple HLSL. - if (!hasAtLeastOneFunction) + // Mark all public functions as exported, ensure there's at least one. Store a mapping + // of function name to IRInst* for later reference. After linking is done, we'll scan + // the linked result to see which functions survived the pruning and are included in the + // precompiled blob. + Dictionary nameToFunction; + bool hasAtLeastOneFunction = false; + for (auto inst : module->getGlobalInsts()) + { + if (attemptPrecompiledExport(inst)) { - return SLANG_OK; + hasAtLeastOneFunction = true; + builder.addDecoration(inst, kIROp_DownstreamModuleExportDecoration); + nameToFunction[inst->findDecoration()->getMangledName()] = inst; } + } - ComPtr outArtifact; - SlangResult res = codeGenContext.emitPrecompiledDownstreamIR(outArtifact); + // Bail if there are no functions to export. That's not treated as an error + // because it's possible that the module just doesn't have any simple HLSL. + if (!hasAtLeastOneFunction) + { + return SLANG_OK; + } - sink.getBlobIfNeeded(outDiagnostics); - if (res != SLANG_OK) - { - return res; - } + ComPtr outArtifact; + SlangResult res = codeGenContext.emitPrecompiledDownstreamIR(outArtifact); - auto metadata = findAssociatedRepresentation(outArtifact); - if (!metadata) - { - return SLANG_E_NOT_AVAILABLE; - } + sink.getBlobIfNeeded(outDiagnostics); + if (res != SLANG_OK) + { + return res; + } - for (const auto& mangledName : metadata->getExportedFunctionMangledNames()) - { - auto moduleInst = nameToFunction[mangledName]; - builder.addDecoration(moduleInst, kIROp_AvailableInDownstreamIRDecoration, - builder.getIntValue(builder.getIntType(), (int)targetReq->getTarget())); - auto moduleDec = moduleInst->findDecoration(); - moduleDec->removeAndDeallocate(); - } + auto metadata = findAssociatedRepresentation(outArtifact); + if (!metadata) + { + return SLANG_E_NOT_AVAILABLE; + } + + for (const auto& mangledName : metadata->getExportedFunctionMangledNames()) + { + auto moduleInst = nameToFunction[mangledName]; + builder.addDecoration( + moduleInst, + kIROp_AvailableInDownstreamIRDecoration, + builder.getIntValue(builder.getIntType(), (int)targetReq->getTarget())); + auto moduleDec = moduleInst->findDecoration(); + moduleDec->removeAndDeallocate(); + } - // Finally, clean up the transient export decorations left over in the module. These are - // represent functions that were pruned from the IR after linking, before target generation. - for (auto moduleInst : module->getGlobalInsts()) + // Finally, clean up the transient export decorations left over in the module. These are + // represent functions that were pruned from the IR after linking, before target generation. + for (auto moduleInst : module->getGlobalInsts()) + { + if (moduleInst->getOp() == kIROp_Func) { - if (moduleInst->getOp() == kIROp_Func) + if (auto dec = moduleInst->findDecoration()) { - if (auto dec = moduleInst->findDecoration()) - { - dec->removeAndDeallocate(); - } + dec->removeAndDeallocate(); } } + } - ComPtr blob; - outArtifact->loadBlob(ArtifactKeep::Yes, blob.writeRef()); + ComPtr blob; + outArtifact->loadBlob(ArtifactKeep::Yes, blob.writeRef()); - // Add the precompiled blob to the module - builder.setInsertInto(module); + // Add the precompiled blob to the module + builder.setInsertInto(module); - builder.emitEmbeddedDownstreamIR(targetReq->getTarget(), blob); - return SLANG_OK; - } + builder.emitEmbeddedDownstreamIR(targetReq->getTarget(), blob); + return SLANG_OK; +} - SLANG_NO_THROW SlangResult SLANG_MCALL Module::getPrecompiledTargetCode( - SlangCompileTarget target, - slang::IBlob** outCode, - slang::IBlob** outDiagnostics) +SLANG_NO_THROW SlangResult SLANG_MCALL Module::getPrecompiledTargetCode( + SlangCompileTarget target, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) +{ + SLANG_UNUSED(outDiagnostics); + for (auto globalInst : getIRModule()->getModuleInst()->getChildren()) { - SLANG_UNUSED(outDiagnostics); - for (auto globalInst : getIRModule()->getModuleInst()->getChildren()) - { - if (auto inst = as(globalInst)) + if (auto inst = as(globalInst)) + { + static_assert(CodeGenTarget::DXIL == static_cast(SLANG_DXIL)); + static_assert(CodeGenTarget::SPIRV == static_cast(SLANG_SPIRV)); + if (inst->getTarget() == static_cast(target)) { - static_assert(CodeGenTarget::DXIL == static_cast(SLANG_DXIL)); - static_assert(CodeGenTarget::SPIRV == static_cast(SLANG_SPIRV)); - if (inst->getTarget() == static_cast(target)) - { - auto slice = inst->getBlob()->getStringSlice(); - auto blob = StringBlob::create(slice); - *outCode = blob.detach(); - return SLANG_OK; - } - } - } - return SLANG_FAIL; + auto slice = inst->getBlob()->getStringSlice(); + auto blob = StringBlob::create(slice); + *outCode = blob.detach(); + return SLANG_OK; + } + } } + return SLANG_FAIL; +} - SLANG_NO_THROW SlangInt SLANG_MCALL Module::getModuleDependencyCount() - { - return 0; - } +SLANG_NO_THROW SlangInt SLANG_MCALL Module::getModuleDependencyCount() +{ + return 0; +} - SLANG_NO_THROW SlangResult SLANG_MCALL Module::getModuleDependency( - SlangInt dependencyIndex, - IModule** outModule, - slang::IBlob** outDiagnostics) - { - SLANG_UNUSED(dependencyIndex); - SLANG_UNUSED(outModule); - SLANG_UNUSED(outDiagnostics); - return SLANG_OK; - } +SLANG_NO_THROW SlangResult SLANG_MCALL Module::getModuleDependency( + SlangInt dependencyIndex, + IModule** outModule, + slang::IBlob** outDiagnostics) +{ + SLANG_UNUSED(dependencyIndex); + SLANG_UNUSED(outModule); + SLANG_UNUSED(outDiagnostics); + return SLANG_OK; +} - // ComponentType +// ComponentType - SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::precompileForTarget( - SlangCompileTarget target, - slang::IBlob** outDiagnostics) - { - SLANG_UNUSED(target); - SLANG_UNUSED(outDiagnostics); - return SLANG_FAIL; - } +SLANG_NO_THROW SlangResult SLANG_MCALL +ComponentType::precompileForTarget(SlangCompileTarget target, slang::IBlob** outDiagnostics) +{ + SLANG_UNUSED(target); + SLANG_UNUSED(outDiagnostics); + return SLANG_FAIL; +} - SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getPrecompiledTargetCode( - SlangCompileTarget target, - slang::IBlob** outCode, - slang::IBlob** outDiagnostics) - { - SLANG_UNUSED(target); - SLANG_UNUSED(outCode); - SLANG_UNUSED(outDiagnostics); - return SLANG_FAIL; - } +SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getPrecompiledTargetCode( + SlangCompileTarget target, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) +{ + SLANG_UNUSED(target); + SLANG_UNUSED(outCode); + SLANG_UNUSED(outDiagnostics); + return SLANG_FAIL; +} - SLANG_NO_THROW SlangInt SLANG_MCALL ComponentType::getModuleDependencyCount() - { - return getModuleDependencies().getCount(); - } +SLANG_NO_THROW SlangInt SLANG_MCALL ComponentType::getModuleDependencyCount() +{ + return getModuleDependencies().getCount(); +} - SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getModuleDependency( - SlangInt dependencyIndex, - slang::IModule** outModule, - slang::IBlob** outDiagnostics) +SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getModuleDependency( + SlangInt dependencyIndex, + slang::IModule** outModule, + slang::IBlob** outDiagnostics) +{ + SLANG_UNUSED(outDiagnostics); + if (dependencyIndex < 0 || dependencyIndex >= getModuleDependencies().getCount()) { - SLANG_UNUSED(outDiagnostics); - if (dependencyIndex < 0 || dependencyIndex >= getModuleDependencies().getCount()) - { - return SLANG_E_INVALID_ARG; - } - *outModule = getModuleDependencies()[dependencyIndex]; - return SLANG_OK; + return SLANG_E_INVALID_ARG; } + *outModule = getModuleDependencies()[dependencyIndex]; + return SLANG_OK; } +} // namespace Slang diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 40e1903e5..63952db3d 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -1,1572 +1,1651 @@ // Compiler.cpp : Defines the entry point for the console application. // +#include "slang-compiler.h" + +#include "../compiler-core/slang-lexer.h" #include "../core/slang-basic.h" -#include "../core/slang-platform.h" +#include "../core/slang-castable.h" +#include "../core/slang-hex-dump-util.h" #include "../core/slang-io.h" #include "../core/slang-performance-profiler.h" -#include "../core/slang-string-util.h" -#include "../core/slang-hex-dump-util.h" +#include "../core/slang-platform.h" #include "../core/slang-riff.h" -#include "../core/slang-type-text-util.h" +#include "../core/slang-string-util.h" #include "../core/slang-type-convert-util.h" -#include "../core/slang-castable.h" - -#include "slang-check.h" +#include "../core/slang-type-text-util.h" #include "slang-check-impl.h" -#include "slang-compiler.h" - -#include "../compiler-core/slang-lexer.h" +#include "slang-check.h" // Artifact +#include "../compiler-core/slang-artifact-associated.h" +#include "../compiler-core/slang-artifact-container-util.h" #include "../compiler-core/slang-artifact-desc-util.h" -#include "../compiler-core/slang-artifact-representation-impl.h" +#include "../compiler-core/slang-artifact-diagnostic-util.h" #include "../compiler-core/slang-artifact-impl.h" +#include "../compiler-core/slang-artifact-representation-impl.h" #include "../compiler-core/slang-artifact-util.h" -#include "../compiler-core/slang-artifact-associated.h" -#include "../compiler-core/slang-artifact-diagnostic-util.h" -#include "../compiler-core/slang-artifact-container-util.h" // Artifact output #include "slang-artifact-output-util.h" - +#include "slang-emit-cuda.h" +#include "slang-glsl-extension-tracker.h" #include "slang-lower-to-ir.h" #include "slang-mangle.h" #include "slang-parameter-binding.h" #include "slang-parser.h" #include "slang-preprocessor.h" -#include "slang-type-layout.h" - -#include "slang-glsl-extension-tracker.h" -#include "slang-emit-cuda.h" - #include "slang-serialize-ast.h" #include "slang-serialize-container.h" +#include "slang-type-layout.h" namespace Slang { // !!!!!!!!!!!!!!!!!!!!!! free functions for DiagnosicSink !!!!!!!!!!!!!!!!!!!!!!!!!!!!! - bool isHeterogeneousTarget(CodeGenTarget target) - { - return ArtifactDescUtil::makeDescForCompileTarget(asExternal(target)).style == ArtifactStyle::Host; - } +bool isHeterogeneousTarget(CodeGenTarget target) +{ + return ArtifactDescUtil::makeDescForCompileTarget(asExternal(target)).style == + ArtifactStyle::Host; +} - void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) - { - UnownedStringSlice name = TypeTextUtil::getCompileTargetName(asExternal(val)); - name = name.getLength() ? name : toSlice(""); - sb << name; - } +void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) +{ + UnownedStringSlice name = TypeTextUtil::getCompileTargetName(asExternal(val)); + name = name.getLength() ? name : toSlice(""); + sb << name; +} - void printDiagnosticArg(StringBuilder& sb, PassThroughMode val) - { - sb << TypeTextUtil::getPassThroughName(SlangPassThrough(val)); - } +void printDiagnosticArg(StringBuilder& sb, PassThroughMode val) +{ + sb << TypeTextUtil::getPassThroughName(SlangPassThrough(val)); +} - // - // FrontEndEntryPointRequest - // +// +// FrontEndEntryPointRequest +// - FrontEndEntryPointRequest::FrontEndEntryPointRequest( - FrontEndCompileRequest* compileRequest, - int translationUnitIndex, - Name* name, - Profile profile) - : m_compileRequest(compileRequest) - , m_translationUnitIndex(translationUnitIndex) - , m_name(name) - , m_profile(profile) - {} +FrontEndEntryPointRequest::FrontEndEntryPointRequest( + FrontEndCompileRequest* compileRequest, + int translationUnitIndex, + Name* name, + Profile profile) + : m_compileRequest(compileRequest) + , m_translationUnitIndex(translationUnitIndex) + , m_name(name) + , m_profile(profile) +{ +} - TranslationUnitRequest* FrontEndEntryPointRequest::getTranslationUnit() - { - return getCompileRequest()->translationUnits[m_translationUnitIndex]; - } +TranslationUnitRequest* FrontEndEntryPointRequest::getTranslationUnit() +{ + return getCompileRequest()->translationUnits[m_translationUnitIndex]; +} - // - // EntryPoint - // +// +// EntryPoint +// - ISlangUnknown* EntryPoint::getInterface(const Guid& guid) - { - if(guid == slang::IEntryPoint::getTypeGuid()) - return static_cast(this); +ISlangUnknown* EntryPoint::getInterface(const Guid& guid) +{ + if (guid == slang::IEntryPoint::getTypeGuid()) + return static_cast(this); - return Super::getInterface(guid); - } + return Super::getInterface(guid); +} - RefPtr EntryPoint::create( - Linkage* linkage, - DeclRef funcDeclRef, - Profile profile) - { - RefPtr entryPoint = new EntryPoint( - linkage, - funcDeclRef.getName(), - profile, - funcDeclRef); - entryPoint->m_mangledName = getMangledName(linkage->getASTBuilder(), funcDeclRef); - return entryPoint; - } +RefPtr EntryPoint::create( + Linkage* linkage, + DeclRef funcDeclRef, + Profile profile) +{ + RefPtr entryPoint = + new EntryPoint(linkage, funcDeclRef.getName(), profile, funcDeclRef); + entryPoint->m_mangledName = getMangledName(linkage->getASTBuilder(), funcDeclRef); + return entryPoint; +} - RefPtr EntryPoint::createDummyForPassThrough( - Linkage* linkage, - Name* name, - Profile profile) - { - RefPtr entryPoint = new EntryPoint( - linkage, - name, - profile, - DeclRef()); - return entryPoint; - } +RefPtr EntryPoint::createDummyForPassThrough( + Linkage* linkage, + Name* name, + Profile profile) +{ + RefPtr entryPoint = new EntryPoint(linkage, name, profile, DeclRef()); + return entryPoint; +} - RefPtr EntryPoint::createDummyForDeserialize( - Linkage* linkage, - Name* name, - Profile profile, - String mangledName) - { - RefPtr entryPoint = new EntryPoint( - linkage, - name, - profile, - DeclRef()); - entryPoint->m_mangledName = mangledName; - return entryPoint; - } - - EntryPoint::EntryPoint( - Linkage* linkage, - Name* name, - Profile profile, - DeclRef funcDeclRef) - : ComponentType(linkage) - , m_name(name) - , m_profile(profile) - , m_funcDeclRef(funcDeclRef) - { - // Collect any specialization parameters used by the entry point - // - _collectShaderParams(); - } +RefPtr EntryPoint::createDummyForDeserialize( + Linkage* linkage, + Name* name, + Profile profile, + String mangledName) +{ + RefPtr entryPoint = new EntryPoint(linkage, name, profile, DeclRef()); + entryPoint->m_mangledName = mangledName; + return entryPoint; +} - Module* EntryPoint::getModule() - { - return Slang::getModule(getFuncDecl()); - } +EntryPoint::EntryPoint(Linkage* linkage, Name* name, Profile profile, DeclRef funcDeclRef) + : ComponentType(linkage), m_name(name), m_profile(profile), m_funcDeclRef(funcDeclRef) +{ + // Collect any specialization parameters used by the entry point + // + _collectShaderParams(); +} - Index EntryPoint::getSpecializationParamCount() - { - return m_genericSpecializationParams.getCount() + m_existentialSpecializationParams.getCount(); - } +Module* EntryPoint::getModule() +{ + return Slang::getModule(getFuncDecl()); +} + +Index EntryPoint::getSpecializationParamCount() +{ + return m_genericSpecializationParams.getCount() + m_existentialSpecializationParams.getCount(); +} - SpecializationParam const& EntryPoint::getSpecializationParam(Index index) +SpecializationParam const& EntryPoint::getSpecializationParam(Index index) +{ + auto genericParamCount = m_genericSpecializationParams.getCount(); + if (index < genericParamCount) { - auto genericParamCount = m_genericSpecializationParams.getCount(); - if(index < genericParamCount) - { - return m_genericSpecializationParams[index]; - } - else - { - return m_existentialSpecializationParams[index - genericParamCount]; - } + return m_genericSpecializationParams[index]; } - - Index EntryPoint::getRequirementCount() + else { - // The only requirement of an entry point is the module that contains it. - // - // TODO: We will eventually want to support the case of an entry - // point nested in a `struct` type, in which case there should be - // a single requirement representing that outer type (so that multiple - // entry points nested under the same type can share the storage - // for parameters at that scope). - - // Note: the defensive coding is here because the - // "dummy" entry points we create for pass-through - // compilation will not have an associated module. - // - if( const auto module = getModule() ) - { - return 1; - } - return 0; + return m_existentialSpecializationParams[index - genericParamCount]; } +} - RefPtr EntryPoint::getRequirement(Index index) +Index EntryPoint::getRequirementCount() +{ + // The only requirement of an entry point is the module that contains it. + // + // TODO: We will eventually want to support the case of an entry + // point nested in a `struct` type, in which case there should be + // a single requirement representing that outer type (so that multiple + // entry points nested under the same type can share the storage + // for parameters at that scope). + + // Note: the defensive coding is here because the + // "dummy" entry points we create for pass-through + // compilation will not have an associated module. + // + if (const auto module = getModule()) { - SLANG_UNUSED(index); - SLANG_ASSERT(index == 0); - SLANG_ASSERT(getModule()); - return getModule(); + return 1; } + return 0; +} - String EntryPoint::getEntryPointMangledName(Index index) - { - SLANG_UNUSED(index); - SLANG_ASSERT(index == 0); +RefPtr EntryPoint::getRequirement(Index index) +{ + SLANG_UNUSED(index); + SLANG_ASSERT(index == 0); + SLANG_ASSERT(getModule()); + return getModule(); +} - return m_mangledName; - } +String EntryPoint::getEntryPointMangledName(Index index) +{ + SLANG_UNUSED(index); + SLANG_ASSERT(index == 0); - String EntryPoint::getEntryPointNameOverride(Index index) - { - SLANG_UNUSED(index); - SLANG_ASSERT(index == 0); + return m_mangledName; +} - return m_name ? m_name->text : ""; - } +String EntryPoint::getEntryPointNameOverride(Index index) +{ + SLANG_UNUSED(index); + SLANG_ASSERT(index == 0); - void EntryPoint::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) - { - visitor->visitEntryPoint(this, as(specializationInfo)); - } + return m_name ? m_name->text : ""; +} - void EntryPoint::buildHash(DigestBuilder& builder) - { - SLANG_UNUSED(builder); - } +void EntryPoint::acceptVisitor( + ComponentTypeVisitor* visitor, + SpecializationInfo* specializationInfo) +{ + visitor->visitEntryPoint(this, as(specializationInfo)); +} - List const& EntryPoint::getModuleDependencies() - { - if(auto module = getModule()) - return module->getModuleDependencies(); +void EntryPoint::buildHash(DigestBuilder& builder) +{ + SLANG_UNUSED(builder); +} - static List empty; - return empty; - } +List const& EntryPoint::getModuleDependencies() +{ + if (auto module = getModule()) + return module->getModuleDependencies(); - List const& EntryPoint::getFileDependencies() - { - if(const auto module = getModule()) - return getModule()->getFileDependencies(); - - static List empty; - return empty; - } + static List empty; + return empty; +} - TypeConformance::TypeConformance( - Linkage* linkage, - SubtypeWitness* witness, - Int confomrmanceIdOverride, - DiagnosticSink* sink) - : ComponentType(linkage) - , m_subtypeWitness(witness) - , m_conformanceIdOverride(confomrmanceIdOverride) - { - addDepedencyFromWitness(witness); - m_irModule = generateIRForTypeConformance(this, m_conformanceIdOverride, sink); - } +List const& EntryPoint::getFileDependencies() +{ + if (const auto module = getModule()) + return getModule()->getFileDependencies(); - void TypeConformance::addDepedencyFromWitness(SubtypeWitness* witness) + static List empty; + return empty; +} + +TypeConformance::TypeConformance( + Linkage* linkage, + SubtypeWitness* witness, + Int confomrmanceIdOverride, + DiagnosticSink* sink) + : ComponentType(linkage) + , m_subtypeWitness(witness) + , m_conformanceIdOverride(confomrmanceIdOverride) +{ + addDepedencyFromWitness(witness); + m_irModule = generateIRForTypeConformance(this, m_conformanceIdOverride, sink); +} + +void TypeConformance::addDepedencyFromWitness(SubtypeWitness* witness) +{ + if (auto declaredWitness = as(witness)) { - if (auto declaredWitness = as(witness)) - { - auto declModule = getModule(declaredWitness->getDeclRef().getDecl()); - m_moduleDependencyList.addDependency(declModule); - m_fileDependencyList.addDependency(declModule); - if (m_requirementSet.add(declModule)) - { - m_requirements.add(declModule); - } - // TODO: handle the specialization arguments in declaredWitness->declRef.substitutions. - } - else if (auto transitiveWitness = as(witness)) + auto declModule = getModule(declaredWitness->getDeclRef().getDecl()); + m_moduleDependencyList.addDependency(declModule); + m_fileDependencyList.addDependency(declModule); + if (m_requirementSet.add(declModule)) { - addDepedencyFromWitness(transitiveWitness->getMidToSup()); - addDepedencyFromWitness(transitiveWitness->getSubToMid()); + m_requirements.add(declModule); } - else if (auto conjunctionWitness = as(witness)) + // TODO: handle the specialization arguments in declaredWitness->declRef.substitutions. + } + else if (auto transitiveWitness = as(witness)) + { + addDepedencyFromWitness(transitiveWitness->getMidToSup()); + addDepedencyFromWitness(transitiveWitness->getSubToMid()); + } + else if (auto conjunctionWitness = as(witness)) + { + auto componentCount = conjunctionWitness->getComponentCount(); + for (Index i = 0; i < componentCount; ++i) { - auto componentCount = conjunctionWitness->getComponentCount(); - for (Index i = 0; i < componentCount; ++i) - { - auto w = as(conjunctionWitness->getComponentWitness(i)); - if (w) addDepedencyFromWitness(w); - } + auto w = as(conjunctionWitness->getComponentWitness(i)); + if (w) + addDepedencyFromWitness(w); } } +} - ISlangUnknown* TypeConformance::getInterface(const Guid& guid) - { - if (guid == slang::ITypeConformance::getTypeGuid()) - return static_cast(this); +ISlangUnknown* TypeConformance::getInterface(const Guid& guid) +{ + if (guid == slang::ITypeConformance::getTypeGuid()) + return static_cast(this); - return Super::getInterface(guid); - } + return Super::getInterface(guid); +} - void TypeConformance::buildHash(DigestBuilder& builder) - { - //TODO: Implement some kind of hashInto for Val then replace this - auto subtypeWitness = m_subtypeWitness->toString(); +void TypeConformance::buildHash(DigestBuilder& builder) +{ + // TODO: Implement some kind of hashInto for Val then replace this + auto subtypeWitness = m_subtypeWitness->toString(); - builder.append(subtypeWitness); - builder.append(m_conformanceIdOverride); - } + builder.append(subtypeWitness); + builder.append(m_conformanceIdOverride); +} - List const& TypeConformance::getModuleDependencies() - { - return m_moduleDependencyList.getModuleList(); - } +List const& TypeConformance::getModuleDependencies() +{ + return m_moduleDependencyList.getModuleList(); +} - List const& TypeConformance::getFileDependencies() - { - return m_fileDependencyList.getFileList(); - } +List const& TypeConformance::getFileDependencies() +{ + return m_fileDependencyList.getFileList(); +} - Index TypeConformance::getRequirementCount() { return m_requirements.getCount(); } +Index TypeConformance::getRequirementCount() +{ + return m_requirements.getCount(); +} - RefPtr TypeConformance::getRequirement(Index index) - { - return m_requirements[index]; - } +RefPtr TypeConformance::getRequirement(Index index) +{ + return m_requirements[index]; +} - void TypeConformance::acceptVisitor( - ComponentTypeVisitor* visitor, - ComponentType::SpecializationInfo* specializationInfo) - { - SLANG_UNUSED(specializationInfo); - visitor->visitTypeConformance(this); - } +void TypeConformance::acceptVisitor( + ComponentTypeVisitor* visitor, + ComponentType::SpecializationInfo* specializationInfo) +{ + SLANG_UNUSED(specializationInfo); + visitor->visitTypeConformance(this); +} - RefPtr TypeConformance::_validateSpecializationArgsImpl( - SpecializationArg const* args, - Index argCount, - DiagnosticSink* sink) - { - SLANG_UNUSED(args); - SLANG_UNUSED(argCount); - SLANG_UNUSED(sink); - return nullptr; - } +RefPtr TypeConformance::_validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) +{ + SLANG_UNUSED(args); + SLANG_UNUSED(argCount); + SLANG_UNUSED(sink); + return nullptr; +} - // +// - Profile Profile::lookUp(UnownedStringSlice const& name) - { - #define PROFILE(TAG, NAME, STAGE, VERSION) if(name == UnownedTerminatedStringSlice(#NAME)) return Profile::TAG; - #define PROFILE_ALIAS(TAG, DEF, NAME) if(name == UnownedTerminatedStringSlice(#NAME)) return Profile::TAG; - #include "slang-profile-defs.h" +Profile Profile::lookUp(UnownedStringSlice const& name) +{ +#define PROFILE(TAG, NAME, STAGE, VERSION) \ + if (name == UnownedTerminatedStringSlice(#NAME)) \ + return Profile::TAG; +#define PROFILE_ALIAS(TAG, DEF, NAME) \ + if (name == UnownedTerminatedStringSlice(#NAME)) \ + return Profile::TAG; +#include "slang-profile-defs.h" - return Profile::Unknown; - } + return Profile::Unknown; +} - Profile Profile::lookUp(char const* name) +Profile Profile::lookUp(char const* name) +{ + return lookUp(UnownedTerminatedStringSlice(name)); +} + +CapabilitySet Profile::getCapabilityName() +{ + List result; + switch (getVersion()) { - return lookUp(UnownedTerminatedStringSlice(name)); +#define PROFILE_VERSION(TAG, NAME) \ + case ProfileVersion::TAG: result.add(CapabilityName::TAG); break; +#include "slang-profile-defs.h" + default: break; } - - CapabilitySet Profile::getCapabilityName() + switch (getStage()) { - List result; - switch (getVersion()) - { - #define PROFILE_VERSION(TAG, NAME) case ProfileVersion::TAG: result.add(CapabilityName::TAG); break; - #include "slang-profile-defs.h" - default: - break; - } - switch (getStage()) - { -#define PROFILE_STAGE(TAG, NAME, VAL) case Stage::TAG: result.add(CapabilityName::NAME); break; +#define PROFILE_STAGE(TAG, NAME, VAL) \ + case Stage::TAG: result.add(CapabilityName::NAME); break; #include "slang-profile-defs.h" - default: - break; - } - - CapabilitySet resultSet = CapabilitySet(result); - for(auto i : this->additionalCapabilities) - resultSet.join(i); - return resultSet; + default: break; } - char const* Profile::getName() + CapabilitySet resultSet = CapabilitySet(result); + for (auto i : this->additionalCapabilities) + resultSet.join(i); + return resultSet; +} + +char const* Profile::getName() +{ + switch (raw) { - switch( raw ) - { - default: - return "unknown"; + default: return "unknown"; - #define PROFILE(TAG, NAME, STAGE, VERSION) case Profile::TAG: return #NAME; - #define PROFILE_ALIAS(TAG, DEF, NAME) /* empty */ - #include "slang-profile-defs.h" - } +#define PROFILE(TAG, NAME, STAGE, VERSION) \ + case Profile::TAG: return #NAME; +#define PROFILE_ALIAS(TAG, DEF, NAME) /* empty */ +#include "slang-profile-defs.h" } +} - static const StageInfo kStages[] = - { - #define PROFILE_STAGE(ID, NAME, ENUM) \ - { #NAME, Stage::ID }, +static const StageInfo kStages[] = { +#define PROFILE_STAGE(ID, NAME, ENUM) {#NAME, Stage::ID}, - #define PROFILE_STAGE_ALIAS(ID, NAME, VAL) \ - { #NAME, Stage::ID }, +#define PROFILE_STAGE_ALIAS(ID, NAME, VAL) {#NAME, Stage::ID}, - #include "slang-profile-defs.h" - }; +#include "slang-profile-defs.h" +}; - ConstArrayView getStageInfos() - { - return makeConstArrayView(kStages); - } +ConstArrayView getStageInfos() +{ + return makeConstArrayView(kStages); +} - Stage findStageByName(String const& name) +Stage findStageByName(String const& name) +{ + for (auto entry : kStages) { - for(auto entry : kStages) + if (name == entry.name) { - if(name == entry.name) - { - return entry.stage; - } + return entry.stage; } - - return Stage::Unknown; } - UnownedStringSlice getStageText(Stage stage) + return Stage::Unknown; +} + +UnownedStringSlice getStageText(Stage stage) +{ + for (auto entry : kStages) { - for (auto entry : kStages) + if (stage == entry.stage) { - if (stage == entry.stage) - { - return UnownedStringSlice(entry.name); - } - } - return UnownedStringSlice(); - } - - Stage getStageFromAtom(CapabilityAtom atom) - { - switch (atom) - { - case CapabilityAtom::vertex: - return Stage::Vertex; - case CapabilityAtom::hull: - return Stage::Hull; - case CapabilityAtom::domain: - return Stage::Domain; - case CapabilityAtom::geometry: - return Stage::Geometry; - case CapabilityAtom::fragment: - return Stage::Fragment; - case CapabilityAtom::compute: - return Stage::Compute; - case CapabilityAtom::_mesh: - return Stage::Mesh; - case CapabilityAtom::_amplification: - return Stage::Amplification; - case CapabilityAtom::_anyhit: - return Stage::AnyHit; - case CapabilityAtom::_closesthit: - return Stage::ClosestHit; - case CapabilityAtom::_intersection: - return Stage::Intersection; - case CapabilityAtom::_raygen: - return Stage::RayGeneration; - case CapabilityAtom::_miss: - return Stage::Miss; - case CapabilityAtom::_callable: - return Stage::Callable; - default: - SLANG_UNEXPECTED("unknown stage atom"); - UNREACHABLE_RETURN(Stage::Unknown); - } - } - - SlangResult checkExternalCompilerSupport(Session* session, PassThroughMode passThrough) - { - // Check if the type is supported on this compile - if (passThrough == PassThroughMode::None) - { - // If no pass through -> that will always work! - return SLANG_OK; + return UnownedStringSlice(entry.name); } + } + return UnownedStringSlice(); +} - return session->getOrLoadDownstreamCompiler(passThrough, nullptr) ? SLANG_OK: SLANG_E_NOT_FOUND; +Stage getStageFromAtom(CapabilityAtom atom) +{ + switch (atom) + { + case CapabilityAtom::vertex: return Stage::Vertex; + case CapabilityAtom::hull: return Stage::Hull; + case CapabilityAtom::domain: return Stage::Domain; + case CapabilityAtom::geometry: return Stage::Geometry; + case CapabilityAtom::fragment: return Stage::Fragment; + case CapabilityAtom::compute: return Stage::Compute; + case CapabilityAtom::_mesh: return Stage::Mesh; + case CapabilityAtom::_amplification: return Stage::Amplification; + case CapabilityAtom::_anyhit: return Stage::AnyHit; + case CapabilityAtom::_closesthit: return Stage::ClosestHit; + case CapabilityAtom::_intersection: return Stage::Intersection; + case CapabilityAtom::_raygen: return Stage::RayGeneration; + case CapabilityAtom::_miss: return Stage::Miss; + case CapabilityAtom::_callable: return Stage::Callable; + default: SLANG_UNEXPECTED("unknown stage atom"); UNREACHABLE_RETURN(Stage::Unknown); } +} + +SlangResult checkExternalCompilerSupport(Session* session, PassThroughMode passThrough) +{ + // Check if the type is supported on this compile + if (passThrough == PassThroughMode::None) + { + // If no pass through -> that will always work! + return SLANG_OK; + } + + return session->getOrLoadDownstreamCompiler(passThrough, nullptr) ? SLANG_OK + : SLANG_E_NOT_FOUND; +} - SourceLanguage getDefaultSourceLanguageForDownstreamCompiler(PassThroughMode compiler) +SourceLanguage getDefaultSourceLanguageForDownstreamCompiler(PassThroughMode compiler) +{ + switch (compiler) { - switch (compiler) + case PassThroughMode::None: { - case PassThroughMode::None: - { - return SourceLanguage::Unknown; - } - case PassThroughMode::Fxc: - case PassThroughMode::Dxc: - { - return SourceLanguage::HLSL; - } - case PassThroughMode::Glslang: - { - return SourceLanguage::GLSL; - } - case PassThroughMode::LLVM: - case PassThroughMode::Clang: - case PassThroughMode::VisualStudio: - case PassThroughMode::Gcc: - case PassThroughMode::GenericCCpp: - { - // These could ingest C, but we only have this function to work out a - // 'default' language to ingest. - return SourceLanguage::CPP; - } - case PassThroughMode::NVRTC: - { - return SourceLanguage::CUDA; - } - case PassThroughMode::Tint: - { - return SourceLanguage::WGSL; - } - case PassThroughMode::SpirvDis: - { - return SourceLanguage::SPIRV; - } - case PassThroughMode::MetalC: - { - return SourceLanguage::Metal; - } - default: break; + return SourceLanguage::Unknown; } - SLANG_ASSERT(!"Unknown compiler"); - return SourceLanguage::Unknown; + case PassThroughMode::Fxc: + case PassThroughMode::Dxc: + { + return SourceLanguage::HLSL; + } + case PassThroughMode::Glslang: + { + return SourceLanguage::GLSL; + } + case PassThroughMode::LLVM: + case PassThroughMode::Clang: + case PassThroughMode::VisualStudio: + case PassThroughMode::Gcc: + case PassThroughMode::GenericCCpp: + { + // These could ingest C, but we only have this function to work out a + // 'default' language to ingest. + return SourceLanguage::CPP; + } + case PassThroughMode::NVRTC: + { + return SourceLanguage::CUDA; + } + case PassThroughMode::Tint: + { + return SourceLanguage::WGSL; + } + case PassThroughMode::SpirvDis: + { + return SourceLanguage::SPIRV; + } + case PassThroughMode::MetalC: + { + return SourceLanguage::Metal; + } + default: break; } + SLANG_ASSERT(!"Unknown compiler"); + return SourceLanguage::Unknown; +} - PassThroughMode getDownstreamCompilerRequiredForTarget(CodeGenTarget target) +PassThroughMode getDownstreamCompilerRequiredForTarget(CodeGenTarget target) +{ + switch (target) { - switch (target) + // Don't *require* a downstream compiler for source output + case CodeGenTarget::GLSL: + case CodeGenTarget::HLSL: + case CodeGenTarget::CUDASource: + case CodeGenTarget::CPPSource: + case CodeGenTarget::HostCPPSource: + case CodeGenTarget::PyTorchCppBinding: + case CodeGenTarget::CSource: + case CodeGenTarget::Metal: + case CodeGenTarget::WGSL: { - // Don't *require* a downstream compiler for source output - case CodeGenTarget::GLSL: - case CodeGenTarget::HLSL: - case CodeGenTarget::CUDASource: - case CodeGenTarget::CPPSource: - case CodeGenTarget::HostCPPSource: - case CodeGenTarget::PyTorchCppBinding: - case CodeGenTarget::CSource: - case CodeGenTarget::Metal: - case CodeGenTarget::WGSL: - { - return PassThroughMode::None; - } - case CodeGenTarget::None: - { - return PassThroughMode::None; - } - case CodeGenTarget::WGSLSPIRVAssembly: - case CodeGenTarget::SPIRVAssembly: - case CodeGenTarget::SPIRV: - { - return PassThroughMode::SpirvDis; - } - case CodeGenTarget::DXBytecode: - case CodeGenTarget::DXBytecodeAssembly: - { - return PassThroughMode::Fxc; - } - case CodeGenTarget::DXIL: - case CodeGenTarget::DXILAssembly: - { - return PassThroughMode::Dxc; - } - case CodeGenTarget::MetalLib: - case CodeGenTarget::MetalLibAssembly: - { - return PassThroughMode::MetalC; - } - case CodeGenTarget::ShaderHostCallable: - case CodeGenTarget::ShaderSharedLibrary: - case CodeGenTarget::HostExecutable: - case CodeGenTarget::HostHostCallable: - case CodeGenTarget::HostSharedLibrary: - { - // We need some C/C++ compiler - return PassThroughMode::GenericCCpp; - } - case CodeGenTarget::PTX: - { - return PassThroughMode::NVRTC; - } - case CodeGenTarget::WGSLSPIRV: - { - return PassThroughMode::Tint; - } - default: break; + return PassThroughMode::None; } - - SLANG_ASSERT(!"Unhandled target"); - return PassThroughMode::None; + case CodeGenTarget::None: + { + return PassThroughMode::None; + } + case CodeGenTarget::WGSLSPIRVAssembly: + case CodeGenTarget::SPIRVAssembly: + case CodeGenTarget::SPIRV: + { + return PassThroughMode::SpirvDis; + } + case CodeGenTarget::DXBytecode: + case CodeGenTarget::DXBytecodeAssembly: + { + return PassThroughMode::Fxc; + } + case CodeGenTarget::DXIL: + case CodeGenTarget::DXILAssembly: + { + return PassThroughMode::Dxc; + } + case CodeGenTarget::MetalLib: + case CodeGenTarget::MetalLibAssembly: + { + return PassThroughMode::MetalC; + } + case CodeGenTarget::ShaderHostCallable: + case CodeGenTarget::ShaderSharedLibrary: + case CodeGenTarget::HostExecutable: + case CodeGenTarget::HostHostCallable: + case CodeGenTarget::HostSharedLibrary: + { + // We need some C/C++ compiler + return PassThroughMode::GenericCCpp; + } + case CodeGenTarget::PTX: + { + return PassThroughMode::NVRTC; + } + case CodeGenTarget::WGSLSPIRV: + { + return PassThroughMode::Tint; + } + default: break; } - EndToEndCompileRequest* CodeGenContext::isPassThroughEnabled() - { - auto endToEndReq = isEndToEndCompile(); + SLANG_ASSERT(!"Unhandled target"); + return PassThroughMode::None; +} - // If there isn't an end-to-end compile going on, - // there can be no pass-through. - // - if (!endToEndReq) - return nullptr; +EndToEndCompileRequest* CodeGenContext::isPassThroughEnabled() +{ + auto endToEndReq = isEndToEndCompile(); - // And if pass-through isn't set on that end-to-end compile, - // then we clearly areb't doing a pass-through compile. - // - if(endToEndReq->m_passThrough == PassThroughMode::None) - return nullptr; + // If there isn't an end-to-end compile going on, + // there can be no pass-through. + // + if (!endToEndReq) + return nullptr; - // If we have confirmed that pass-through compilation is going on, - // we return the end-to-end request, because it has all the - // relevant state that we need to implement pass-through mode. - // - return endToEndReq; - } + // And if pass-through isn't set on that end-to-end compile, + // then we clearly areb't doing a pass-through compile. + // + if (endToEndReq->m_passThrough == PassThroughMode::None) + return nullptr; - /// If there is a pass-through compile going on, find the translation unit for the given entry point. - /// Assumes isPassThroughEnabled has already been called - TranslationUnitRequest* getPassThroughTranslationUnit( - EndToEndCompileRequest* endToEndReq, - Int entryPointIndex) - { - SLANG_ASSERT(endToEndReq); - SLANG_ASSERT(endToEndReq->m_passThrough != PassThroughMode::None); - auto frontEndReq = endToEndReq->getFrontEndReq(); - auto entryPointReq = frontEndReq->getEntryPointReq(entryPointIndex); - auto translationUnit = entryPointReq->getTranslationUnit(); - return translationUnit; - } + // If we have confirmed that pass-through compilation is going on, + // we return the end-to-end request, because it has all the + // relevant state that we need to implement pass-through mode. + // + return endToEndReq; +} - TranslationUnitRequest* CodeGenContext::findPassThroughTranslationUnit( - Int entryPointIndex) - { - if (auto endToEndReq = isPassThroughEnabled()) - return getPassThroughTranslationUnit(endToEndReq, entryPointIndex); - return nullptr; - } +/// If there is a pass-through compile going on, find the translation unit for the given entry +/// point. Assumes isPassThroughEnabled has already been called +TranslationUnitRequest* getPassThroughTranslationUnit( + EndToEndCompileRequest* endToEndReq, + Int entryPointIndex) +{ + SLANG_ASSERT(endToEndReq); + SLANG_ASSERT(endToEndReq->m_passThrough != PassThroughMode::None); + auto frontEndReq = endToEndReq->getFrontEndReq(); + auto entryPointReq = frontEndReq->getEntryPointReq(entryPointIndex); + auto translationUnit = entryPointReq->getTranslationUnit(); + return translationUnit; +} - static void _appendCodeWithPath(const UnownedStringSlice& filePath, const UnownedStringSlice& fileContent, StringBuilder& outCodeBuilder) - { - outCodeBuilder << "#line 1 \""; - auto handler = StringEscapeUtil::getHandler(StringEscapeUtil::Style::Cpp); - handler->appendEscaped(filePath, outCodeBuilder); - outCodeBuilder << "\"\n"; - outCodeBuilder << fileContent << "\n"; - } +TranslationUnitRequest* CodeGenContext::findPassThroughTranslationUnit(Int entryPointIndex) +{ + if (auto endToEndReq = isPassThroughEnabled()) + return getPassThroughTranslationUnit(endToEndReq, entryPointIndex); + return nullptr; +} - void trackGLSLTargetCaps( - GLSLExtensionTracker* extensionTracker, - CapabilitySet const& caps) +static void _appendCodeWithPath( + const UnownedStringSlice& filePath, + const UnownedStringSlice& fileContent, + StringBuilder& outCodeBuilder) +{ + outCodeBuilder << "#line 1 \""; + auto handler = StringEscapeUtil::getHandler(StringEscapeUtil::Style::Cpp); + handler->appendEscaped(filePath, outCodeBuilder); + outCodeBuilder << "\"\n"; + outCodeBuilder << fileContent << "\n"; +} + +void trackGLSLTargetCaps(GLSLExtensionTracker* extensionTracker, CapabilitySet const& caps) +{ + for (auto& conjunctions : caps.getAtomSets()) { - for(auto& conjunctions : caps.getAtomSets() ) + for (auto atom : conjunctions) { - for (auto atom : conjunctions) + switch (asAtom(atom)) { - switch (asAtom(atom)) - { - default: - break; + default: break; - case CapabilityAtom::glsl_spirv_1_0: extensionTracker->requireSPIRVVersion(SemanticVersion(1, 0)); break; - case CapabilityAtom::glsl_spirv_1_1: extensionTracker->requireSPIRVVersion(SemanticVersion(1, 1)); break; - case CapabilityAtom::glsl_spirv_1_2: extensionTracker->requireSPIRVVersion(SemanticVersion(1, 2)); break; - case CapabilityAtom::glsl_spirv_1_3: extensionTracker->requireSPIRVVersion(SemanticVersion(1, 3)); break; - case CapabilityAtom::glsl_spirv_1_4: extensionTracker->requireSPIRVVersion(SemanticVersion(1, 4)); break; - case CapabilityAtom::glsl_spirv_1_5: extensionTracker->requireSPIRVVersion(SemanticVersion(1, 5)); break; - case CapabilityAtom::glsl_spirv_1_6: extensionTracker->requireSPIRVVersion(SemanticVersion(1, 6)); break; - } + case CapabilityAtom::glsl_spirv_1_0: + extensionTracker->requireSPIRVVersion(SemanticVersion(1, 0)); + break; + case CapabilityAtom::glsl_spirv_1_1: + extensionTracker->requireSPIRVVersion(SemanticVersion(1, 1)); + break; + case CapabilityAtom::glsl_spirv_1_2: + extensionTracker->requireSPIRVVersion(SemanticVersion(1, 2)); + break; + case CapabilityAtom::glsl_spirv_1_3: + extensionTracker->requireSPIRVVersion(SemanticVersion(1, 3)); + break; + case CapabilityAtom::glsl_spirv_1_4: + extensionTracker->requireSPIRVVersion(SemanticVersion(1, 4)); + break; + case CapabilityAtom::glsl_spirv_1_5: + extensionTracker->requireSPIRVVersion(SemanticVersion(1, 5)); + break; + case CapabilityAtom::glsl_spirv_1_6: + extensionTracker->requireSPIRVVersion(SemanticVersion(1, 6)); + break; } } } +} - SlangResult CodeGenContext::requireTranslationUnitSourceFiles() +SlangResult CodeGenContext::requireTranslationUnitSourceFiles() +{ + if (auto endToEndReq = isPassThroughEnabled()) { - if (auto endToEndReq = isPassThroughEnabled()) + for (auto entryPointIndex : getEntryPointIndices()) { - for (auto entryPointIndex : getEntryPointIndices()) - { - auto translationUnit = getPassThroughTranslationUnit(endToEndReq, entryPointIndex); - SLANG_ASSERT(translationUnit); - /// Make sure we have the source files - SLANG_RETURN_ON_FAIL(translationUnit->requireSourceFiles()); - } + auto translationUnit = getPassThroughTranslationUnit(endToEndReq, entryPointIndex); + SLANG_ASSERT(translationUnit); + /// Make sure we have the source files + SLANG_RETURN_ON_FAIL(translationUnit->requireSourceFiles()); } - - return SLANG_OK; } + return SLANG_OK; +} + #if SLANG_VC -// TODO(JS): This is a workaround +// TODO(JS): This is a workaround // In debug VS builds there is a warning on line about it being unreachable. // for (auto entryPointIndex : getEntryPointIndices()) // It's not clear how that could possibly be unreachable -# pragma warning(push) -# pragma warning(disable:4702) +#pragma warning(push) +#pragma warning(disable : 4702) #endif - SlangResult CodeGenContext::emitEntryPointsSource(ComPtr& outArtifact) - { - outArtifact.setNull(); +SlangResult CodeGenContext::emitEntryPointsSource(ComPtr& outArtifact) +{ + outArtifact.setNull(); - SLANG_RETURN_ON_FAIL(requireTranslationUnitSourceFiles()); + SLANG_RETURN_ON_FAIL(requireTranslationUnitSourceFiles()); - auto endToEndReq = isPassThroughEnabled(); - if(endToEndReq) + auto endToEndReq = isPassThroughEnabled(); + if (endToEndReq) + { + for (auto entryPointIndex : getEntryPointIndices()) { - for (auto entryPointIndex : getEntryPointIndices()) - { - auto translationUnit = getPassThroughTranslationUnit(endToEndReq, entryPointIndex); - SLANG_ASSERT(translationUnit); + auto translationUnit = getPassThroughTranslationUnit(endToEndReq, entryPointIndex); + SLANG_ASSERT(translationUnit); - /// Make sure we have the source files - SLANG_RETURN_ON_FAIL(translationUnit->requireSourceFiles()); + /// Make sure we have the source files + SLANG_RETURN_ON_FAIL(translationUnit->requireSourceFiles()); - // Generate a string that includes the content of - // the source file(s), along with a line directive - // to ensure that we get reasonable messages - // from the downstream compiler when in pass-through - // mode. + // Generate a string that includes the content of + // the source file(s), along with a line directive + // to ensure that we get reasonable messages + // from the downstream compiler when in pass-through + // mode. - StringBuilder codeBuilder; - if (getTargetFormat() == CodeGenTarget::GLSL) + StringBuilder codeBuilder; + if (getTargetFormat() == CodeGenTarget::GLSL) + { + // Special case GLSL + int translationUnitCounter = 0; + for (auto sourceFile : translationUnit->getSourceFiles()) { - // Special case GLSL - int translationUnitCounter = 0; - for (auto sourceFile : translationUnit->getSourceFiles()) + int translationUnitIndex = translationUnitCounter++; + + // We want to output `#line` directives, but we need + // to skip this for the first file, since otherwise + // some GLSL implementations will get tripped up by + // not having the `#version` directive be the first + // thing in the file. + if (translationUnitIndex != 0) { - int translationUnitIndex = translationUnitCounter++; - - // We want to output `#line` directives, but we need - // to skip this for the first file, since otherwise - // some GLSL implementations will get tripped up by - // not having the `#version` directive be the first - // thing in the file. - if (translationUnitIndex != 0) - { - codeBuilder << "#line 1 " << translationUnitIndex << "\n"; - } - codeBuilder << sourceFile->getContent() << "\n"; + codeBuilder << "#line 1 " << translationUnitIndex << "\n"; } + codeBuilder << sourceFile->getContent() << "\n"; } - else + } + else + { + for (auto sourceFile : translationUnit->getSourceFiles()) { - for (auto sourceFile : translationUnit->getSourceFiles()) - { - _appendCodeWithPath(sourceFile->getPathInfo().foundPath.getUnownedSlice(), sourceFile->getContent(), codeBuilder); - } + _appendCodeWithPath( + sourceFile->getPathInfo().foundPath.getUnownedSlice(), + sourceFile->getContent(), + codeBuilder); } + } - auto artifact = ArtifactUtil::createArtifactForCompileTarget(asExternal(getTargetFormat())); - artifact->addRepresentationUnknown(StringBlob::moveCreate(codeBuilder)); + auto artifact = + ArtifactUtil::createArtifactForCompileTarget(asExternal(getTargetFormat())); + artifact->addRepresentationUnknown(StringBlob::moveCreate(codeBuilder)); - outArtifact.swap(artifact); - return SLANG_OK; - } + outArtifact.swap(artifact); return SLANG_OK; } - else - { - return emitEntryPointsSourceFromIR(outArtifact); - } + return SLANG_OK; + } + else + { + return emitEntryPointsSourceFromIR(outArtifact); } +} #if SLANG_VC -# pragma warning(pop) +#pragma warning(pop) #endif - SlangResult CodeGenContext::emitPrecompiledDownstreamIR(ComPtr& outArtifact) - { - return _emitEntryPoints(outArtifact); - } +SlangResult CodeGenContext::emitPrecompiledDownstreamIR(ComPtr& outArtifact) +{ + return _emitEntryPoints(outArtifact); +} - String GetHLSLProfileName(Profile profile) +String GetHLSLProfileName(Profile profile) +{ + switch (profile.getFamily()) { - switch( profile.getFamily() ) - { - case ProfileFamily::DX: - // Profile version is a DX one, so stick with it. - break; + case ProfileFamily::DX: + // Profile version is a DX one, so stick with it. + break; - default: - // Profile is a non-DX profile family, so we need to try - // to clobber it with something to get a default. - // - // TODO: This is a huge hack... - profile.setVersion(ProfileVersion::DX_5_1); - break; - } - - char const* stagePrefix = nullptr; - switch( profile.getStage() ) - { - // Note: All of the raytracing-related stages require - // compiling for a `lib_*` profile, even when only a - // single entry point is present. - // - // We also go ahead and use this target in any case - // where we don't know the actual stage to compiel for, - // as a fallback option. - // - // TODO: We also want to use this option when compiling - // multiple entry points to a DXIL library. - // - default: - stagePrefix = "lib"; - break; - - // The traditional rasterization pipeline and compute - // shaders all have custom profile names that identify - // both the stage and shader model, which need to be - // used when compiling a single entry point. - // - #define CASE(NAME, PREFIX) case Stage::NAME: stagePrefix = #PREFIX; break - CASE(Vertex, vs); - CASE(Hull, hs); - CASE(Domain, ds); - CASE(Geometry, gs); - CASE(Fragment, ps); - CASE(Compute, cs); - CASE(Amplification, as); - CASE(Mesh, ms); - #undef CASE - } - - char const* versionSuffix = nullptr; - switch(profile.getVersion()) - { - #define CASE(TAG, SUFFIX) case ProfileVersion::TAG: versionSuffix = #SUFFIX; break - CASE(DX_4_0, _4_0); - CASE(DX_4_1, _4_1); - CASE(DX_5_0, _5_0); - CASE(DX_5_1, _5_1); - CASE(DX_6_0, _6_0); - CASE(DX_6_1, _6_1); - CASE(DX_6_2, _6_2); - CASE(DX_6_3, _6_3); - CASE(DX_6_4, _6_4); - CASE(DX_6_5, _6_5); - CASE(DX_6_6, _6_6); - CASE(DX_6_7, _6_7); - #undef CASE - - default: - return "unknown"; - } - - String result; - result.append(stagePrefix); - result.append(versionSuffix); - return result; + default: + // Profile is a non-DX profile family, so we need to try + // to clobber it with something to get a default. + // + // TODO: This is a huge hack... + profile.setVersion(ProfileVersion::DX_5_1); + break; } - void reportExternalCompileError(const char* compilerName, Severity severity, SlangResult res, const UnownedStringSlice& diagnostic, DiagnosticSink* sink) + char const* stagePrefix = nullptr; + switch (profile.getStage()) { - StringBuilder builder; - if (compilerName) - { - builder << compilerName << ": "; - } - - if (SLANG_FAILED(res) && res != SLANG_FAIL) - { - { - char tmp[17]; - sprintf_s(tmp, SLANG_COUNT_OF(tmp), "0x%08x", uint32_t(res)); - builder << "Result(" << tmp << ") "; - } - - PlatformUtil::appendResult(res, builder); - } - - if (diagnostic.getLength() > 0) - { - builder.append(diagnostic); - if (!diagnostic.endsWith("\n")) - { - builder.append("\n"); - } - } - - sink->diagnoseRaw(severity, builder.getUnownedSlice()); - } + // Note: All of the raytracing-related stages require + // compiling for a `lib_*` profile, even when only a + // single entry point is present. + // + // We also go ahead and use this target in any case + // where we don't know the actual stage to compiel for, + // as a fallback option. + // + // TODO: We also want to use this option when compiling + // multiple entry points to a DXIL library. + // + default: + stagePrefix = "lib"; + break; + + // The traditional rasterization pipeline and compute + // shaders all have custom profile names that identify + // both the stage and shader model, which need to be + // used when compiling a single entry point. + // +#define CASE(NAME, PREFIX) \ + case Stage::NAME: stagePrefix = #PREFIX; break + CASE(Vertex, vs); + CASE(Hull, hs); + CASE(Domain, ds); + CASE(Geometry, gs); + CASE(Fragment, ps); + CASE(Compute, cs); + CASE(Amplification, as); + CASE(Mesh, ms); +#undef CASE + } + + char const* versionSuffix = nullptr; + switch (profile.getVersion()) + { +#define CASE(TAG, SUFFIX) \ + case ProfileVersion::TAG: versionSuffix = #SUFFIX; break + CASE(DX_4_0, _4_0); + CASE(DX_4_1, _4_1); + CASE(DX_5_0, _5_0); + CASE(DX_5_1, _5_1); + CASE(DX_6_0, _6_0); + CASE(DX_6_1, _6_1); + CASE(DX_6_2, _6_2); + CASE(DX_6_3, _6_3); + CASE(DX_6_4, _6_4); + CASE(DX_6_5, _6_5); + CASE(DX_6_6, _6_6); + CASE(DX_6_7, _6_7); +#undef CASE + + default: return "unknown"; + } + + String result; + result.append(stagePrefix); + result.append(versionSuffix); + return result; +} - void reportExternalCompileError(const char* compilerName, SlangResult res, const UnownedStringSlice& diagnostic, DiagnosticSink* sink) +void reportExternalCompileError( + const char* compilerName, + Severity severity, + SlangResult res, + const UnownedStringSlice& diagnostic, + DiagnosticSink* sink) +{ + StringBuilder builder; + if (compilerName) { - // TODO(tfoley): need a better policy for how we translate diagnostics - // back into the Slang world (although we should always try to generate - // HLSL that doesn't produce any diagnostics...) - reportExternalCompileError(compilerName, SLANG_FAILED(res) ? Severity::Error : Severity::Warning, res, diagnostic, sink); + builder << compilerName << ": "; } - static String _getDisplayPath(DiagnosticSink* sink, SourceFile* sourceFile) + if (SLANG_FAILED(res) && res != SLANG_FAIL) { - if (sink->isFlagSet(DiagnosticSink::Flag::VerbosePath)) { - return sourceFile->calcVerbosePath(); + char tmp[17]; + sprintf_s(tmp, SLANG_COUNT_OF(tmp), "0x%08x", uint32_t(res)); + builder << "Result(" << tmp << ") "; } - else + + PlatformUtil::appendResult(res, builder); + } + + if (diagnostic.getLength() > 0) + { + builder.append(diagnostic); + if (!diagnostic.endsWith("\n")) { - return sourceFile->getPathInfo().foundPath; + builder.append("\n"); } } - String CodeGenContext::calcSourcePathForEntryPoints() + sink->diagnoseRaw(severity, builder.getUnownedSlice()); +} + +void reportExternalCompileError( + const char* compilerName, + SlangResult res, + const UnownedStringSlice& diagnostic, + DiagnosticSink* sink) +{ + // TODO(tfoley): need a better policy for how we translate diagnostics + // back into the Slang world (although we should always try to generate + // HLSL that doesn't produce any diagnostics...) + reportExternalCompileError( + compilerName, + SLANG_FAILED(res) ? Severity::Error : Severity::Warning, + res, + diagnostic, + sink); +} + +static String _getDisplayPath(DiagnosticSink* sink, SourceFile* sourceFile) +{ + if (sink->isFlagSet(DiagnosticSink::Flag::VerbosePath)) { - String failureMode = "slang-generated"; - if (getEntryPointCount() != 1) - return failureMode; - auto entryPointIndex = getSingleEntryPointIndex(); - auto translationUnitRequest = findPassThroughTranslationUnit(entryPointIndex); - if (!translationUnitRequest) - return failureMode; + return sourceFile->calcVerbosePath(); + } + else + { + return sourceFile->getPathInfo().foundPath; + } +} - const auto& sourceFiles = translationUnitRequest->getSourceFiles(); +String CodeGenContext::calcSourcePathForEntryPoints() +{ + String failureMode = "slang-generated"; + if (getEntryPointCount() != 1) + return failureMode; + auto entryPointIndex = getSingleEntryPointIndex(); + auto translationUnitRequest = findPassThroughTranslationUnit(entryPointIndex); + if (!translationUnitRequest) + return failureMode; - auto sink = getSink(); + const auto& sourceFiles = translationUnitRequest->getSourceFiles(); - const Index numSourceFiles = sourceFiles.getCount(); + auto sink = getSink(); - switch (numSourceFiles) + const Index numSourceFiles = sourceFiles.getCount(); + + switch (numSourceFiles) + { + case 0: return "unknown"; + case 1: return _getDisplayPath(sink, sourceFiles[0]); + default: { - case 0: return "unknown"; - case 1: return _getDisplayPath(sink, sourceFiles[0]); - default: + StringBuilder builder; + builder << _getDisplayPath(sink, sourceFiles[0]); + for (int i = 1; i < numSourceFiles; ++i) { - StringBuilder builder; - builder << _getDisplayPath(sink, sourceFiles[0]); - for (int i = 1; i < numSourceFiles; ++i) - { - builder << ";" << _getDisplayPath(sink, sourceFiles[i]); - } - return builder; + builder << ";" << _getDisplayPath(sink, sourceFiles[i]); } + return builder; } } +} - // Helper function for cases where we can assume a single entry point - Int assertSingleEntryPoint(List const& entryPointIndices) { - SLANG_ASSERT(entryPointIndices.getCount() == 1); - return *entryPointIndices.begin(); - } +// Helper function for cases where we can assume a single entry point +Int assertSingleEntryPoint(List const& entryPointIndices) +{ + SLANG_ASSERT(entryPointIndices.getCount() == 1); + return *entryPointIndices.begin(); +} - // True if it's best to use 'emitted' source for complication. For a downstream compiler - // that is not file based, this is always ok. - /// - /// If the downstream compiler is file system based, we may want to just use the file that was passed to be compiled. - /// That the downstream compiler can determine if it will then save the file or not based on if it's a match - - /// and generally there will not be a match with emitted source. - /// - /// This test is only used for pass through mode. - static bool _useEmittedSource(IDownstreamCompiler* compiler, TranslationUnitRequest* translationUnit) +// True if it's best to use 'emitted' source for complication. For a downstream compiler +// that is not file based, this is always ok. +/// +/// If the downstream compiler is file system based, we may want to just use the file that was +/// passed to be compiled. That the downstream compiler can determine if it will then save the file +/// or not based on if it's a match - and generally there will not be a match with emitted source. +/// +/// This test is only used for pass through mode. +static bool _useEmittedSource( + IDownstreamCompiler* compiler, + TranslationUnitRequest* translationUnit) +{ + // We only bother if it's a file based compiler. + if (compiler->isFileBased()) { - // We only bother if it's a file based compiler. - if (compiler->isFileBased()) - { - // It can only have *one* source file as otherwise we have to combine to make a new source file anyway - return translationUnit->getSourceArtifacts().getCount() != 1; - } - return true; + // It can only have *one* source file as otherwise we have to combine to make a new source + // file anyway + return translationUnit->getSourceArtifacts().getCount() != 1; } + return true; +} - static Severity _getDiagnosticSeverity(ArtifactDiagnostic::Severity severity) +static Severity _getDiagnosticSeverity(ArtifactDiagnostic::Severity severity) +{ + switch (severity) { - switch (severity) - { - case ArtifactDiagnostic::Severity::Warning: return Severity::Warning; - case ArtifactDiagnostic::Severity::Info: return Severity::Note; - default: return Severity::Error; - } + case ArtifactDiagnostic::Severity::Warning: return Severity::Warning; + case ArtifactDiagnostic::Severity::Info: return Severity::Note; + default: return Severity::Error; } +} - static RefPtr _newExtensionTracker(CodeGenTarget target) +static RefPtr _newExtensionTracker(CodeGenTarget target) +{ + switch (target) { - switch (target) + case CodeGenTarget::PTX: + case CodeGenTarget::CUDASource: { - case CodeGenTarget::PTX: - case CodeGenTarget::CUDASource: - { - return new CUDAExtensionTracker; - } - case CodeGenTarget::SPIRV: - case CodeGenTarget::GLSL: - { - return new GLSLExtensionTracker; - } - default: return nullptr; + return new CUDAExtensionTracker; + } + case CodeGenTarget::SPIRV: + case CodeGenTarget::GLSL: + { + return new GLSLExtensionTracker; } + default: return nullptr; } +} - static CodeGenTarget _getDefaultSourceForTarget(CodeGenTarget target) +static CodeGenTarget _getDefaultSourceForTarget(CodeGenTarget target) +{ + switch (target) { - switch (target) + case CodeGenTarget::ShaderHostCallable: + case CodeGenTarget::ShaderSharedLibrary: { - case CodeGenTarget::ShaderHostCallable: - case CodeGenTarget::ShaderSharedLibrary: - { - return CodeGenTarget::CPPSource; - } - case CodeGenTarget::HostHostCallable: - case CodeGenTarget::HostExecutable: - case CodeGenTarget::HostSharedLibrary: - { - return CodeGenTarget::HostCPPSource; - } - case CodeGenTarget::PTX: return CodeGenTarget::CUDASource; - case CodeGenTarget::DXBytecode: return CodeGenTarget::HLSL; - case CodeGenTarget::DXIL: return CodeGenTarget::HLSL; - case CodeGenTarget::SPIRV: return CodeGenTarget::GLSL; - case CodeGenTarget::MetalLib: return CodeGenTarget::Metal; - case CodeGenTarget::WGSLSPIRV: return CodeGenTarget::WGSL; - default: break; + return CodeGenTarget::CPPSource; + } + case CodeGenTarget::HostHostCallable: + case CodeGenTarget::HostExecutable: + case CodeGenTarget::HostSharedLibrary: + { + return CodeGenTarget::HostCPPSource; } - return CodeGenTarget::Unknown; + case CodeGenTarget::PTX: return CodeGenTarget::CUDASource; + case CodeGenTarget::DXBytecode: return CodeGenTarget::HLSL; + case CodeGenTarget::DXIL: return CodeGenTarget::HLSL; + case CodeGenTarget::SPIRV: return CodeGenTarget::GLSL; + case CodeGenTarget::MetalLib: return CodeGenTarget::Metal; + case CodeGenTarget::WGSLSPIRV: return CodeGenTarget::WGSL; + default: break; } + return CodeGenTarget::Unknown; +} - static bool _isCPUHostTarget(CodeGenTarget target) - { - auto desc = ArtifactDescUtil::makeDescForCompileTarget(asExternal(target)); - return desc.style == ArtifactStyle::Host; - } +static bool _isCPUHostTarget(CodeGenTarget target) +{ + auto desc = ArtifactDescUtil::makeDescForCompileTarget(asExternal(target)); + return desc.style == ArtifactStyle::Host; +} - static bool _shouldSetEntryPointName(TargetProgram* targetProgram) - { - if (!isKhronosTarget(targetProgram->getTargetReq())) - return true; - if (targetProgram->getOptionSet().getBoolOption(CompilerOptionName::VulkanUseEntryPointName)) - return true; - return false; - } +static bool _shouldSetEntryPointName(TargetProgram* targetProgram) +{ + if (!isKhronosTarget(targetProgram->getTargetReq())) + return true; + if (targetProgram->getOptionSet().getBoolOption(CompilerOptionName::VulkanUseEntryPointName)) + return true; + return false; +} - SlangResult passthroughDownstreamDiagnostics(DiagnosticSink* sink, IDownstreamCompiler* compiler, IArtifact* artifact) - { - auto diagnostics = findAssociatedRepresentation(artifact); +SlangResult passthroughDownstreamDiagnostics( + DiagnosticSink* sink, + IDownstreamCompiler* compiler, + IArtifact* artifact) +{ + auto diagnostics = findAssociatedRepresentation(artifact); - if (!diagnostics) - return SLANG_OK; + if (!diagnostics) + return SLANG_OK; - if (diagnostics->getCount()) - { - StringBuilder compilerText; - DownstreamCompilerUtil::appendAsText(compiler->getDesc(), compilerText); + if (diagnostics->getCount()) + { + StringBuilder compilerText; + DownstreamCompilerUtil::appendAsText(compiler->getDesc(), compilerText); - StringBuilder builder; + StringBuilder builder; - auto const diagnosticCount = diagnostics->getCount(); - for (Index i = 0; i < diagnosticCount; ++i) - { - const auto& diagnostic = *diagnostics->getAt(i); + auto const diagnosticCount = diagnostics->getCount(); + for (Index i = 0; i < diagnosticCount; ++i) + { + const auto& diagnostic = *diagnostics->getAt(i); - builder.clear(); + builder.clear(); - const Severity severity = _getDiagnosticSeverity(diagnostic.severity); + const Severity severity = _getDiagnosticSeverity(diagnostic.severity); - if (diagnostic.filePath.count == 0 && diagnostic.location.line == 0 && severity == Severity::Note) + if (diagnostic.filePath.count == 0 && diagnostic.location.line == 0 && + severity == Severity::Note) + { + // If theres no filePath line number and it's info, output severity and text alone + builder << getSeverityName(severity) << " : "; + } + else + { + if (diagnostic.filePath.count) { - // If theres no filePath line number and it's info, output severity and text alone - builder << getSeverityName(severity) << " : "; + builder << asStringSlice(diagnostic.filePath); } - else - { - if (diagnostic.filePath.count) - { - builder << asStringSlice(diagnostic.filePath); - } - - if (diagnostic.location.line) - { - builder << "(" << diagnostic.location.line << ")"; - } - builder << ": "; + if (diagnostic.location.line) + { + builder << "(" << diagnostic.location.line << ")"; + } - if (diagnostic.stage == ArtifactDiagnostic::Stage::Link) - { - builder << "link "; - } + builder << ": "; - builder << getSeverityName(severity); - builder << " " << asStringSlice(diagnostic.code) << ": "; + if (diagnostic.stage == ArtifactDiagnostic::Stage::Link) + { + builder << "link "; } - builder << asStringSlice(diagnostic.text); - reportExternalCompileError(compilerText.getBuffer(), severity, SLANG_OK, builder.getUnownedSlice(), sink); + builder << getSeverityName(severity); + builder << " " << asStringSlice(diagnostic.code) << ": "; } - } - // If any errors are emitted, then we are done - if (diagnostics->hasOfAtLeastSeverity(ArtifactDiagnostic::Severity::Error)) - { - return SLANG_FAIL; + builder << asStringSlice(diagnostic.text); + reportExternalCompileError( + compilerText.getBuffer(), + severity, + SLANG_OK, + builder.getUnownedSlice(), + sink); } - - return SLANG_OK; } - SlangResult CodeGenContext::emitWithDownstreamForEntryPoints(ComPtr& outArtifact) + // If any errors are emitted, then we are done + if (diagnostics->hasOfAtLeastSeverity(ArtifactDiagnostic::Severity::Error)) { - outArtifact.setNull(); + return SLANG_FAIL; + } - auto sink = getSink(); - auto session = getSession(); + return SLANG_OK; +} - CodeGenTarget sourceTarget = CodeGenTarget::None; - SourceLanguage sourceLanguage = SourceLanguage::Unknown; +SlangResult CodeGenContext::emitWithDownstreamForEntryPoints(ComPtr& outArtifact) +{ + outArtifact.setNull(); - auto target = getTargetFormat(); - RefPtr extensionTracker = _newExtensionTracker(target); - PassThroughMode compilerType; + auto sink = getSink(); + auto session = getSession(); - SliceAllocator allocator; - - if (auto endToEndReq = isPassThroughEnabled()) - { - compilerType = endToEndReq->m_passThrough; - } - else - { - // If we are not in pass through, lookup the default compiler for the emitted source type + CodeGenTarget sourceTarget = CodeGenTarget::None; + SourceLanguage sourceLanguage = SourceLanguage::Unknown; - // Get the default source codegen type for a given target - sourceTarget = _getDefaultSourceForTarget(target); - compilerType = (PassThroughMode)session->getDownstreamCompilerForTransition((SlangCompileTarget)sourceTarget, (SlangCompileTarget)target); - // We should have a downstream compiler set at this point - if (compilerType == PassThroughMode::None) - { - auto sourceName = TypeTextUtil::getCompileTargetName(SlangCompileTarget(sourceTarget)); - auto targetName = TypeTextUtil::getCompileTargetName(SlangCompileTarget(target)); + auto target = getTargetFormat(); + RefPtr extensionTracker = _newExtensionTracker(target); + PassThroughMode compilerType; - sink->diagnose(SourceLoc(), Diagnostics::compilerNotDefinedForTransition, sourceName, targetName); - return SLANG_FAIL; - } - } + SliceAllocator allocator; - SLANG_ASSERT(compilerType != PassThroughMode::None); + if (auto endToEndReq = isPassThroughEnabled()) + { + compilerType = endToEndReq->m_passThrough; + } + else + { + // If we are not in pass through, lookup the default compiler for the emitted source type - // Get the required downstream compiler - IDownstreamCompiler* compiler = session->getOrLoadDownstreamCompiler(compilerType, sink); - if (!compiler) + // Get the default source codegen type for a given target + sourceTarget = _getDefaultSourceForTarget(target); + compilerType = (PassThroughMode)session->getDownstreamCompilerForTransition( + (SlangCompileTarget)sourceTarget, + (SlangCompileTarget)target); + // We should have a downstream compiler set at this point + if (compilerType == PassThroughMode::None) { - auto compilerName = TypeTextUtil::getPassThroughAsHumanText((SlangPassThrough)compilerType); - sink->diagnose(SourceLoc(), Diagnostics::passThroughCompilerNotFound, compilerName); + auto sourceName = TypeTextUtil::getCompileTargetName(SlangCompileTarget(sourceTarget)); + auto targetName = TypeTextUtil::getCompileTargetName(SlangCompileTarget(target)); + + sink->diagnose( + SourceLoc(), + Diagnostics::compilerNotDefinedForTransition, + sourceName, + targetName); return SLANG_FAIL; } + } + + SLANG_ASSERT(compilerType != PassThroughMode::None); + + // Get the required downstream compiler + IDownstreamCompiler* compiler = session->getOrLoadDownstreamCompiler(compilerType, sink); + if (!compiler) + { + auto compilerName = TypeTextUtil::getPassThroughAsHumanText((SlangPassThrough)compilerType); + sink->diagnose(SourceLoc(), Diagnostics::passThroughCompilerNotFound, compilerName); + return SLANG_FAIL; + } - Dictionary preprocessorDefinitions; - List includePaths; + Dictionary preprocessorDefinitions; + List includePaths; - typedef DownstreamCompileOptions CompileOptions; - CompileOptions options; + typedef DownstreamCompileOptions CompileOptions; + CompileOptions options; - List requiredCapabilityVersions; - List compilerSpecificArguments; - List> libraries; - List libraryPaths; + List requiredCapabilityVersions; + List compilerSpecificArguments; + List> libraries; + List libraryPaths; - // Set compiler specific args + // Set compiler specific args + { + auto name = TypeTextUtil::getPassThroughName((SlangPassThrough)compilerType); + List downstreamArgs = getTargetProgram()->getOptionSet().getDownstreamArgs(name); + for (const auto& arg : downstreamArgs) { - auto name = TypeTextUtil::getPassThroughName((SlangPassThrough)compilerType); - List downstreamArgs = getTargetProgram()->getOptionSet().getDownstreamArgs(name); - for (const auto& arg : downstreamArgs) + // We special case some kinds of args, that can be handled directly + if (arg.startsWith("-I")) { - // We special case some kinds of args, that can be handled directly - if (arg.startsWith("-I")) - { - // We handle the -I option, by just adding to the include paths - includePaths.add(arg.getUnownedSlice().tail(2)); - } - else - { - compilerSpecificArguments.add(arg); - } + // We handle the -I option, by just adding to the include paths + includePaths.add(arg.getUnownedSlice().tail(2)); + } + else + { + compilerSpecificArguments.add(arg); } } + } - ComPtr sourceArtifact; + ComPtr sourceArtifact; - /* This is more convoluted than the other scenarios, because when we invoke C/C++ compiler we would ideally like - to use the original file. We want to do this because we want includes relative to the source file to work, and - for that to work most easily we want to use the original file, if there is one */ - if (auto endToEndReq = isPassThroughEnabled()) + /* This is more convoluted than the other scenarios, because when we invoke C/C++ compiler we + would ideally like to use the original file. We want to do this because we want includes + relative to the source file to work, and for that to work most easily we want to use the + original file, if there is one */ + if (auto endToEndReq = isPassThroughEnabled()) + { + // If we are pass through, we may need to set extension tracker state. + if (GLSLExtensionTracker* glslTracker = as(extensionTracker)) { - // If we are pass through, we may need to set extension tracker state. - if (GLSLExtensionTracker* glslTracker = as(extensionTracker)) - { - trackGLSLTargetCaps(glslTracker, getTargetCaps()); - } + trackGLSLTargetCaps(glslTracker, getTargetCaps()); + } - auto translationUnit = getPassThroughTranslationUnit(endToEndReq, getSingleEntryPointIndex()); + auto translationUnit = + getPassThroughTranslationUnit(endToEndReq, getSingleEntryPointIndex()); - // We are just passing thru, so it's whatever it originally was - sourceLanguage = translationUnit->sourceLanguage; + // We are just passing thru, so it's whatever it originally was + sourceLanguage = translationUnit->sourceLanguage; - // TODO(JS): This seems like a bit of a hack - // That if a pass-through is being performed and the source language is Slang - // no downstream compiler knows how to deal with that, so probably means 'HLSL' - sourceLanguage = (sourceLanguage == SourceLanguage::Slang) ? SourceLanguage::HLSL : sourceLanguage; - sourceTarget = CodeGenTarget(TypeConvertUtil::getCompileTargetFromSourceLanguage((SlangSourceLanguage)sourceLanguage)); + // TODO(JS): This seems like a bit of a hack + // That if a pass-through is being performed and the source language is Slang + // no downstream compiler knows how to deal with that, so probably means 'HLSL' + sourceLanguage = + (sourceLanguage == SourceLanguage::Slang) ? SourceLanguage::HLSL : sourceLanguage; + sourceTarget = CodeGenTarget(TypeConvertUtil::getCompileTargetFromSourceLanguage( + (SlangSourceLanguage)sourceLanguage)); - // If it's pass through we accumulate the preprocessor definitions. - for (const auto& define : endToEndReq->getOptionSet().getArray(CompilerOptionName::MacroDefine)) - preprocessorDefinitions.add(define.stringValue, define.stringValue2); - for (const auto& define : translationUnit->preprocessorDefinitions) - preprocessorDefinitions.add(define); - - { - /* TODO(JS): Not totally clear what options should be set here. If we are using the pass through - then using say the defines/includes - all makes total sense. If we are generating C++ code from slang, then should we really be using these values -> aren't they what is - being set for the *slang* source, not for the C++ generated code. That being the case it implies that there needs to be a mechanism - (if there isn't already) to specify such information on a particular pass/pass through etc. + // If it's pass through we accumulate the preprocessor definitions. + for (const auto& define : + endToEndReq->getOptionSet().getArray(CompilerOptionName::MacroDefine)) + preprocessorDefinitions.add(define.stringValue, define.stringValue2); + for (const auto& define : translationUnit->preprocessorDefinitions) + preprocessorDefinitions.add(define); - On invoking DXC for example include paths do not appear to be set at all (even with pass-through). - */ + { + /* TODO(JS): Not totally clear what options should be set here. If we are using the pass + through - then using say the defines/includes all makes total sense. If we are + generating C++ code from slang, then should we really be using these values -> aren't + they what is being set for the *slang* source, not for the C++ generated code. That + being the case it implies that there needs to be a mechanism (if there isn't already) to + specify such information on a particular pass/pass through etc. + + On invoking DXC for example include paths do not appear to be set at all (even with + pass-through). + */ + + auto linkage = getLinkage(); - auto linkage = getLinkage(); + // Add all the search paths - // Add all the search paths - - const auto searchDirectories = linkage->getSearchDirectories(); - const SearchDirectoryList* searchList = &searchDirectories; - while (searchList) + const auto searchDirectories = linkage->getSearchDirectories(); + const SearchDirectoryList* searchList = &searchDirectories; + while (searchList) + { + for (const auto& searchDirectory : searchList->searchDirectories) { - for (const auto& searchDirectory : searchList->searchDirectories) - { - includePaths.add(searchDirectory.path); - } - searchList = searchList->parent; + includePaths.add(searchDirectory.path); } + searchList = searchList->parent; } + } - // If emitted source is required, emit and set the path - if (_useEmittedSource(compiler, translationUnit)) - { - CodeGenContext sourceCodeGenContext(this, sourceTarget, extensionTracker); - - SLANG_RETURN_ON_FAIL(sourceCodeGenContext.emitEntryPointsSource(sourceArtifact)); + // If emitted source is required, emit and set the path + if (_useEmittedSource(compiler, translationUnit)) + { + CodeGenContext sourceCodeGenContext(this, sourceTarget, extensionTracker); - // If it's not file based we can set an appropriate path name, and it doesn't matter if it doesn't - // exist on the file system. - // We set the name to the path as this will be used for downstream reporting. - auto sourcePath = calcSourcePathForEntryPoints(); - sourceArtifact->setName(sourcePath.getBuffer()); + SLANG_RETURN_ON_FAIL(sourceCodeGenContext.emitEntryPointsSource(sourceArtifact)); - sourceCodeGenContext.maybeDumpIntermediate(sourceArtifact); - } - else - { - // Special case if we have a single file, so that we pass the path, and the contents as is. - const auto& sourceArtifacts = translationUnit->getSourceArtifacts(); - SLANG_ASSERT(sourceArtifacts.getCount() == 1); + // If it's not file based we can set an appropriate path name, and it doesn't matter if + // it doesn't exist on the file system. We set the name to the path as this will be used + // for downstream reporting. + auto sourcePath = calcSourcePathForEntryPoints(); + sourceArtifact->setName(sourcePath.getBuffer()); - sourceArtifact = sourceArtifacts[0]; - SLANG_ASSERT(sourceArtifact); - } + sourceCodeGenContext.maybeDumpIntermediate(sourceArtifact); } else { - CodeGenContext sourceCodeGenContext(this, sourceTarget, extensionTracker); + // Special case if we have a single file, so that we pass the path, and the contents as + // is. + const auto& sourceArtifacts = translationUnit->getSourceArtifacts(); + SLANG_ASSERT(sourceArtifacts.getCount() == 1); - sourceCodeGenContext.removeAvailableInDownstreamIR = true; + sourceArtifact = sourceArtifacts[0]; + SLANG_ASSERT(sourceArtifact); + } + } + else + { + CodeGenContext sourceCodeGenContext(this, sourceTarget, extensionTracker); - SLANG_RETURN_ON_FAIL(sourceCodeGenContext.emitEntryPointsSource(sourceArtifact)); - sourceCodeGenContext.maybeDumpIntermediate(sourceArtifact); + sourceCodeGenContext.removeAvailableInDownstreamIR = true; - sourceLanguage = (SourceLanguage)TypeConvertUtil::getSourceLanguageFromTarget((SlangCompileTarget)sourceTarget); - } + SLANG_RETURN_ON_FAIL(sourceCodeGenContext.emitEntryPointsSource(sourceArtifact)); + sourceCodeGenContext.maybeDumpIntermediate(sourceArtifact); - if (sourceArtifact) - { - // Set the source artifacts - options.sourceArtifacts = makeSlice(sourceArtifact.readRef(), 1); - } + sourceLanguage = (SourceLanguage)TypeConvertUtil::getSourceLanguageFromTarget( + (SlangCompileTarget)sourceTarget); + } - // Add any preprocessor definitions associated with the linkage - { - // TODO(JS): This is somewhat arguable - should defines passed to Slang really be - // passed to downstream compilers? It does appear consistent with the behavior if - // there is an endToEndReq. - // - // That said it's very convenient and provides way to control aspects - // of downstream compilation. + if (sourceArtifact) + { + // Set the source artifacts + options.sourceArtifacts = makeSlice(sourceArtifact.readRef(), 1); + } - for (const auto& define : getTargetProgram()->getOptionSet().getArray(CompilerOptionName::MacroDefine)) - { - preprocessorDefinitions.addIfNotExists(define.stringValue, define.stringValue2); - } - } + // Add any preprocessor definitions associated with the linkage + { + // TODO(JS): This is somewhat arguable - should defines passed to Slang really be + // passed to downstream compilers? It does appear consistent with the behavior if + // there is an endToEndReq. + // + // That said it's very convenient and provides way to control aspects + // of downstream compilation. - - // If we have an extension tracker, we may need to set options such as SPIR-V version - // and CUDA Shader Model. - if (extensionTracker) + for (const auto& define : + getTargetProgram()->getOptionSet().getArray(CompilerOptionName::MacroDefine)) { - // Look for the version - if (auto cudaTracker = as(extensionTracker)) - { - cudaTracker->finalize(); + preprocessorDefinitions.addIfNotExists(define.stringValue, define.stringValue2); + } + } - if (cudaTracker->m_smVersion.isSet()) - { - DownstreamCompileOptions::CapabilityVersion version; - version.kind = DownstreamCompileOptions::CapabilityVersion::Kind::CUDASM; - version.version = cudaTracker->m_smVersion; - requiredCapabilityVersions.add(version); - } + // If we have an extension tracker, we may need to set options such as SPIR-V version + // and CUDA Shader Model. + if (extensionTracker) + { + // Look for the version + if (auto cudaTracker = as(extensionTracker)) + { + cudaTracker->finalize(); - if (cudaTracker->isBaseTypeRequired(BaseType::Half)) - { - options.flags |= CompileOptions::Flag::EnableFloat16; - } - } - else if (GLSLExtensionTracker* glslTracker = as(extensionTracker)) + if (cudaTracker->m_smVersion.isSet()) { DownstreamCompileOptions::CapabilityVersion version; - version.kind = DownstreamCompileOptions::CapabilityVersion::Kind::SPIRV; - version.version = glslTracker->getSPIRVVersion(); + version.kind = DownstreamCompileOptions::CapabilityVersion::Kind::CUDASM; + version.version = cudaTracker->m_smVersion; requiredCapabilityVersions.add(version); } - } - - // Set the file sytem and source manager, as *may* be used by downstream compiler - options.fileSystemExt = getFileSystemExt(); - options.sourceManager = getSourceManager(); - // Set the source type - options.sourceLanguage = SlangSourceLanguage(sourceLanguage); - - switch (target) + if (cudaTracker->isBaseTypeRequired(BaseType::Half)) + { + options.flags |= CompileOptions::Flag::EnableFloat16; + } + } + else if (GLSLExtensionTracker* glslTracker = as(extensionTracker)) { - case CodeGenTarget::ShaderHostCallable: - case CodeGenTarget::ShaderSharedLibrary: - // Disable exceptions and security checks - options.flags &= ~(CompileOptions::Flag::EnableExceptionHandling | CompileOptions::Flag::EnableSecurityChecks); - break; + DownstreamCompileOptions::CapabilityVersion version; + version.kind = DownstreamCompileOptions::CapabilityVersion::Kind::SPIRV; + version.version = glslTracker->getSPIRVVersion(); + + requiredCapabilityVersions.add(version); } + } - Profile profile; + // Set the file sytem and source manager, as *may* be used by downstream compiler + options.fileSystemExt = getFileSystemExt(); + options.sourceManager = getSourceManager(); - if (compilerType == PassThroughMode::Fxc || - compilerType == PassThroughMode::Dxc || - compilerType == PassThroughMode::Glslang) - { - const auto entryPointIndices = getEntryPointIndices(); - auto targetReq = getTargetReq(); + // Set the source type + options.sourceLanguage = SlangSourceLanguage(sourceLanguage); - const auto entryPointIndicesCount = entryPointIndices.getCount(); + switch (target) + { + case CodeGenTarget::ShaderHostCallable: + case CodeGenTarget::ShaderSharedLibrary: + // Disable exceptions and security checks + options.flags &= + ~(CompileOptions::Flag::EnableExceptionHandling | + CompileOptions::Flag::EnableSecurityChecks); + break; + } - // Whole program means - // * can have 0-N entry points - // * 'doesn't build into an executable/kernel' - // - // So in some sense it is a library - if (getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::GenerateWholeProgram)) - { - if (compilerType == PassThroughMode::Dxc) - { - // Can support no entry points on DXC because we can build libraries - profile = Profile(getTargetProgram()->getOptionSet().getEnumOption(CompilerOptionName::Profile)); - } - else - { - auto downstreamCompilerName = TypeTextUtil::getPassThroughName((SlangPassThrough)compilerType); + Profile profile; - sink->diagnose(SourceLoc(), Diagnostics::downstreamCompilerDoesntSupportWholeProgramCompilation, downstreamCompilerName); - return SLANG_FAIL; - } - } - else if (entryPointIndicesCount == 1) - { - // All support a single entry point - const Index entryPointIndex = entryPointIndices[0]; + if (compilerType == PassThroughMode::Fxc || compilerType == PassThroughMode::Dxc || + compilerType == PassThroughMode::Glslang) + { + const auto entryPointIndices = getEntryPointIndices(); + auto targetReq = getTargetReq(); - auto entryPoint = getEntryPoint(entryPointIndex); - profile = getEffectiveProfile(entryPoint, targetReq); + const auto entryPointIndicesCount = entryPointIndices.getCount(); - if (_shouldSetEntryPointName(getTargetProgram())) - { - options.entryPointName = allocator.allocate(getText(entryPoint->getName())); - auto entryPointNameOverride = getProgram()->getEntryPointNameOverride(entryPointIndex); - if (entryPointNameOverride.getLength() != 0) - { - options.entryPointName = allocator.allocate(entryPointNameOverride); - } - } + // Whole program means + // * can have 0-N entry points + // * 'doesn't build into an executable/kernel' + // + // So in some sense it is a library + if (getTargetProgram()->getOptionSet().getBoolOption( + CompilerOptionName::GenerateWholeProgram)) + { + if (compilerType == PassThroughMode::Dxc) + { + // Can support no entry points on DXC because we can build libraries + profile = + Profile(getTargetProgram()->getOptionSet().getEnumOption( + CompilerOptionName::Profile)); } - else + else { - // We only support a single entry point on this target - SLANG_ASSERT(!"Can only compile with a single entry point on this target"); + auto downstreamCompilerName = + TypeTextUtil::getPassThroughName((SlangPassThrough)compilerType); + + sink->diagnose( + SourceLoc(), + Diagnostics::downstreamCompilerDoesntSupportWholeProgramCompilation, + downstreamCompilerName); return SLANG_FAIL; } - - options.stage = SlangStage(profile.getStage()); + } + else if (entryPointIndicesCount == 1) + { + // All support a single entry point + const Index entryPointIndex = entryPointIndices[0]; - if (compilerType == PassThroughMode::Dxc) + auto entryPoint = getEntryPoint(entryPointIndex); + profile = getEffectiveProfile(entryPoint, targetReq); + + if (_shouldSetEntryPointName(getTargetProgram())) { - // We will enable the flag to generate proper code for 16 - bit types - // by default, as long as the user is requesting a sufficiently - // high shader model. - // - // TODO: Need to check that this is safe to enable in all cases, - // or if it will make a shader demand hardware features that - // aren't always present. - // - // TODO: Ideally the dxc back-end should be passed some information - // on the "capabilities" that were used and/or requested in the code. - // - if (profile.getVersion() >= ProfileVersion::DX_6_2) + options.entryPointName = allocator.allocate(getText(entryPoint->getName())); + auto entryPointNameOverride = + getProgram()->getEntryPointNameOverride(entryPointIndex); + if (entryPointNameOverride.getLength() != 0) { - options.flags |= CompileOptions::Flag::EnableFloat16; + options.entryPointName = allocator.allocate(entryPointNameOverride); } + } + } + else + { + // We only support a single entry point on this target + SLANG_ASSERT(!"Can only compile with a single entry point on this target"); + return SLANG_FAIL; + } - // Set the matrix layout - options.matrixLayout = (SlangMatrixLayoutMode)getTargetProgram()->getOptionSet().getMatrixLayoutMode(); + options.stage = SlangStage(profile.getStage()); + + if (compilerType == PassThroughMode::Dxc) + { + // We will enable the flag to generate proper code for 16 - bit types + // by default, as long as the user is requesting a sufficiently + // high shader model. + // + // TODO: Need to check that this is safe to enable in all cases, + // or if it will make a shader demand hardware features that + // aren't always present. + // + // TODO: Ideally the dxc back-end should be passed some information + // on the "capabilities" that were used and/or requested in the code. + // + if (profile.getVersion() >= ProfileVersion::DX_6_2) + { + options.flags |= CompileOptions::Flag::EnableFloat16; } - // Set the profile - options.profileName = allocator.allocate(GetHLSLProfileName(profile)); + // Set the matrix layout + options.matrixLayout = + (SlangMatrixLayoutMode)getTargetProgram()->getOptionSet().getMatrixLayoutMode(); + } + + // Set the profile + options.profileName = allocator.allocate(GetHLSLProfileName(profile)); + } + + // If we aren't using LLVM 'host callable', we want downstream compile to produce a shared + // library + if (compilerType != PassThroughMode::LLVM && + ArtifactDescUtil::makeDescForCompileTarget(asExternal(target)).kind == + ArtifactKind::HostCallable) + { + target = CodeGenTarget::ShaderSharedLibrary; + } + + if (!isPassThroughEnabled()) + { + if (_isCPUHostTarget(target)) + { + libraryPaths.add(Path::getParentDirectory(Path::getExecutablePath())); + libraryPaths.add( + Path::combine(Path::getParentDirectory(Path::getExecutablePath()), "../lib")); + + // Set up the library artifact + auto artifact = Artifact::create( + ArtifactDesc::make(ArtifactKind::Library, Artifact::Payload::HostCPU), + toSlice("slang-rt")); + + ComPtr fileRep(new OSFileArtifactRepresentation( + IOSFileArtifactRepresentation::Kind::NameOnly, + toSlice("slang-rt"), + nullptr)); + artifact->addRepresentation(fileRep); + + libraries.add(artifact); + } + } + + options.targetType = (SlangCompileTarget)target; + + // Need to configure for the compilation + + { + auto linkage = getLinkage(); + + switch (getTargetProgram()->getOptionSet().getEnumOption( + CompilerOptionName::Optimization)) + { + case OptimizationLevel::None: + options.optimizationLevel = DownstreamCompileOptions::OptimizationLevel::None; + break; + case OptimizationLevel::Default: + options.optimizationLevel = DownstreamCompileOptions::OptimizationLevel::Default; + break; + case OptimizationLevel::High: + options.optimizationLevel = DownstreamCompileOptions::OptimizationLevel::High; + break; + case OptimizationLevel::Maximal: + options.optimizationLevel = DownstreamCompileOptions::OptimizationLevel::Maximal; + break; + default: SLANG_ASSERT(!"Unhandled optimization level"); break; + } + + switch (getTargetProgram()->getOptionSet().getEnumOption( + CompilerOptionName::DebugInformation)) + { + case DebugInfoLevel::None: + options.debugInfoType = DownstreamCompileOptions::DebugInfoType::None; + break; + case DebugInfoLevel::Minimal: + options.debugInfoType = DownstreamCompileOptions::DebugInfoType::Minimal; + break; + + case DebugInfoLevel::Standard: + options.debugInfoType = DownstreamCompileOptions::DebugInfoType::Standard; + break; + case DebugInfoLevel::Maximal: + options.debugInfoType = DownstreamCompileOptions::DebugInfoType::Maximal; + break; + default: SLANG_ASSERT(!"Unhandled debug level"); break; } - // If we aren't using LLVM 'host callable', we want downstream compile to produce a shared library - if (compilerType != PassThroughMode::LLVM && - ArtifactDescUtil::makeDescForCompileTarget(asExternal(target)).kind == ArtifactKind::HostCallable) + switch (getTargetProgram()->getOptionSet().getEnumOption( + CompilerOptionName::FloatingPointMode)) { - target = CodeGenTarget::ShaderSharedLibrary; + case FloatingPointMode::Default: + options.floatingPointMode = DownstreamCompileOptions::FloatingPointMode::Default; + break; + case FloatingPointMode::Precise: + options.floatingPointMode = DownstreamCompileOptions::FloatingPointMode::Precise; + break; + case FloatingPointMode::Fast: + options.floatingPointMode = DownstreamCompileOptions::FloatingPointMode::Fast; + break; + default: SLANG_ASSERT(!"Unhandled floating point mode"); } - if (!isPassThroughEnabled()) { - if (_isCPUHostTarget(target)) + // We need to look at the stage of the entry point(s) we are + // being asked to compile, since this will determine the + // "pipeline" that the result should be compiled for (e.g., + // compute vs. ray tracing). + // + // TODO: This logic is kind of messy in that it assumes + // a program to be compiled will only contain kernels for + // a single pipeline type, but that invariant isn't expressed + // at all in the front-end today. It also has no error + // checking for the case where there are conflicts. + // + // HACK: Right now none of the above concerns matter + // because we always perform code generation on a single + // entry point at a time. + // + Index entryPointCount = getEntryPointCount(); + for (Index ee = 0; ee < entryPointCount; ++ee) { - libraryPaths.add(Path::getParentDirectory(Path::getExecutablePath())); - libraryPaths.add(Path::combine(Path::getParentDirectory(Path::getExecutablePath()), "../lib")); + auto stage = getEntryPoint(ee)->getStage(); + switch (stage) + { + default: break; - // Set up the library artifact - auto artifact = Artifact::create(ArtifactDesc::make(ArtifactKind::Library, Artifact::Payload::HostCPU), toSlice("slang-rt")); + case Stage::Compute: + options.pipelineType = DownstreamCompileOptions::PipelineType::Compute; + break; - ComPtr fileRep(new OSFileArtifactRepresentation(IOSFileArtifactRepresentation::Kind::NameOnly, toSlice("slang-rt"), nullptr)); - artifact->addRepresentation(fileRep); + case Stage::Vertex: + case Stage::Hull: + case Stage::Domain: + case Stage::Geometry: + case Stage::Fragment: + options.pipelineType = DownstreamCompileOptions::PipelineType::Rasterization; + break; - libraries.add(artifact); + case Stage::RayGeneration: + case Stage::Intersection: + case Stage::AnyHit: + case Stage::ClosestHit: + case Stage::Miss: + case Stage::Callable: + options.pipelineType = DownstreamCompileOptions::PipelineType::RayTracing; + break; + } } } - options.targetType = (SlangCompileTarget)target; - - // Need to configure for the compilation + // Add all the search paths (as calculated earlier - they will only be set if this is a pass + // through else will be empty) + options.includePaths = allocator.allocate(includePaths); + // Add the specified defines (as calculated earlier - they will only be set if this is a + // pass through else will be empty) { - auto linkage = getLinkage(); - - switch (getTargetProgram()->getOptionSet().getEnumOption(CompilerOptionName::Optimization)) - { - case OptimizationLevel::None: options.optimizationLevel = DownstreamCompileOptions::OptimizationLevel::None; break; - case OptimizationLevel::Default: options.optimizationLevel = DownstreamCompileOptions::OptimizationLevel::Default; break; - case OptimizationLevel::High: options.optimizationLevel = DownstreamCompileOptions::OptimizationLevel::High; break; - case OptimizationLevel::Maximal: options.optimizationLevel = DownstreamCompileOptions::OptimizationLevel::Maximal; break; - default: SLANG_ASSERT(!"Unhandled optimization level"); break; - } - - switch (getTargetProgram()->getOptionSet().getEnumOption(CompilerOptionName::DebugInformation)) - { - case DebugInfoLevel::None: options.debugInfoType = DownstreamCompileOptions::DebugInfoType::None; break; - case DebugInfoLevel::Minimal: options.debugInfoType = DownstreamCompileOptions::DebugInfoType::Minimal; break; - - case DebugInfoLevel::Standard: options.debugInfoType = DownstreamCompileOptions::DebugInfoType::Standard; break; - case DebugInfoLevel::Maximal: options.debugInfoType = DownstreamCompileOptions::DebugInfoType::Maximal; break; - default: SLANG_ASSERT(!"Unhandled debug level"); break; - } - - switch (getTargetProgram()->getOptionSet().getEnumOption(CompilerOptionName::FloatingPointMode)) - { - case FloatingPointMode::Default: options.floatingPointMode = DownstreamCompileOptions::FloatingPointMode::Default; break; - case FloatingPointMode::Precise: options.floatingPointMode = DownstreamCompileOptions::FloatingPointMode::Precise; break; - case FloatingPointMode::Fast: options.floatingPointMode = DownstreamCompileOptions::FloatingPointMode::Fast; break; - default: SLANG_ASSERT(!"Unhandled floating point mode"); - } - - { - // We need to look at the stage of the entry point(s) we are - // being asked to compile, since this will determine the - // "pipeline" that the result should be compiled for (e.g., - // compute vs. ray tracing). - // - // TODO: This logic is kind of messy in that it assumes - // a program to be compiled will only contain kernels for - // a single pipeline type, but that invariant isn't expressed - // at all in the front-end today. It also has no error - // checking for the case where there are conflicts. - // - // HACK: Right now none of the above concerns matter - // because we always perform code generation on a single - // entry point at a time. - // - Index entryPointCount = getEntryPointCount(); - for(Index ee = 0; ee < entryPointCount; ++ee) - { - auto stage = getEntryPoint(ee)->getStage(); - switch(stage) - { - default: - break; - - case Stage::Compute: - options.pipelineType = DownstreamCompileOptions::PipelineType::Compute; - break; - - case Stage::Vertex: - case Stage::Hull: - case Stage::Domain: - case Stage::Geometry: - case Stage::Fragment: - options.pipelineType = DownstreamCompileOptions::PipelineType::Rasterization; - break; - - case Stage::RayGeneration: - case Stage::Intersection: - case Stage::AnyHit: - case Stage::ClosestHit: - case Stage::Miss: - case Stage::Callable: - options.pipelineType = DownstreamCompileOptions::PipelineType::RayTracing; - break; - } - } - } + const auto count = preprocessorDefinitions.getCount(); + auto dst = allocator.getArena().allocateArray(count); - // Add all the search paths (as calculated earlier - they will only be set if this is a pass through else will be empty) - options.includePaths = allocator.allocate(includePaths); + Index i = 0; - // Add the specified defines (as calculated earlier - they will only be set if this is a pass through else will be empty) + for (const auto& [defKey, defValue] : preprocessorDefinitions) { - const auto count = preprocessorDefinitions.getCount(); - auto dst = allocator.getArena().allocateArray(count); - - Index i = 0; + auto& define = dst[i]; - for(const auto& [defKey, defValue] : preprocessorDefinitions) - { - auto& define = dst[i]; - - define.nameWithSig = allocator.allocate(defKey); - define.value = allocator.allocate(defValue); + define.nameWithSig = allocator.allocate(defKey); + define.value = allocator.allocate(defValue); - ++i; - } - options.defines = makeSlice(dst, count); + ++i; } - - // Add all of the module libraries - libraries.addRange(linkage->m_libModules.getBuffer(), linkage->m_libModules.getCount()); + options.defines = makeSlice(dst, count); } - auto program = getProgram(); + // Add all of the module libraries + libraries.addRange(linkage->m_libModules.getBuffer(), linkage->m_libModules.getCount()); + } + + auto program = getProgram(); - // Load embedded precompiled libraries from IR into library artifacts - program->enumerateIRModules([&](IRModule* irModule) + // Load embedded precompiled libraries from IR into library artifacts + program->enumerateIRModules( + [&](IRModule* irModule) { for (auto globalInst : irModule->getModuleInst()->getChildren()) { @@ -1577,7 +1656,8 @@ namespace Slang if (inst->getTarget() == CodeGenTarget::DXIL) { auto slice = inst->getBlob()->getStringSlice(); - ArtifactDesc desc = ArtifactDescUtil::makeDescForCompileTarget(SLANG_DXIL); + ArtifactDesc desc = + ArtifactDescUtil::makeDescForCompileTarget(SLANG_DXIL); desc.kind = ArtifactKind::Library; auto library = ArtifactUtil::createArtifact(desc); @@ -1590,1060 +1670,1097 @@ namespace Slang } }); - options.compilerSpecificArguments = allocator.allocate(compilerSpecificArguments); - options.requiredCapabilityVersions = SliceUtil::asSlice(requiredCapabilityVersions); - options.libraries = SliceUtil::asSlice(libraries); - options.libraryPaths = allocator.allocate(libraryPaths); - - // Compile - ComPtr artifact; - auto downstreamStartTime = std::chrono::high_resolution_clock::now(); - SLANG_RETURN_ON_FAIL(compiler->compile(options, artifact.writeRef())); - auto downstreamElapsedTime = - (std::chrono::high_resolution_clock::now() - downstreamStartTime).count() * 0.000000001; - getSession()->addDownstreamCompileTime(downstreamElapsedTime); - - SLANG_RETURN_ON_FAIL(passthroughDownstreamDiagnostics(getSink(), compiler, artifact)); - - // Copy over all of the information associated with the source into the output - if (sourceArtifact) - { - for (auto associatedArtifact : sourceArtifact->getAssociated()) - { - artifact->addAssociated(associatedArtifact); - } - } + options.compilerSpecificArguments = allocator.allocate(compilerSpecificArguments); + options.requiredCapabilityVersions = SliceUtil::asSlice(requiredCapabilityVersions); + options.libraries = SliceUtil::asSlice(libraries); + options.libraryPaths = allocator.allocate(libraryPaths); - // Set the artifact - outArtifact.swap(artifact); - return SLANG_OK; - } + // Compile + ComPtr artifact; + auto downstreamStartTime = std::chrono::high_resolution_clock::now(); + SLANG_RETURN_ON_FAIL(compiler->compile(options, artifact.writeRef())); + auto downstreamElapsedTime = + (std::chrono::high_resolution_clock::now() - downstreamStartTime).count() * 0.000000001; + getSession()->addDownstreamCompileTime(downstreamElapsedTime); - SlangResult emitSPIRVForEntryPointsDirectly( - CodeGenContext* codeGenContext, - ComPtr& outArtifact); + SLANG_RETURN_ON_FAIL(passthroughDownstreamDiagnostics(getSink(), compiler, artifact)); - static CodeGenTarget _getIntermediateTarget(CodeGenTarget target) + // Copy over all of the information associated with the source into the output + if (sourceArtifact) { - switch (target) + for (auto associatedArtifact : sourceArtifact->getAssociated()) { - case CodeGenTarget::DXBytecodeAssembly: return CodeGenTarget::DXBytecode; - case CodeGenTarget::DXILAssembly: return CodeGenTarget::DXIL; - case CodeGenTarget::SPIRVAssembly: return CodeGenTarget::SPIRV; - case CodeGenTarget::WGSLSPIRVAssembly: return CodeGenTarget::WGSLSPIRV; - default: return CodeGenTarget::None; + artifact->addAssociated(associatedArtifact); } } - /// Function to simplify the logic around emitting, and dissassembling - SlangResult CodeGenContext::_emitEntryPoints(ComPtr& outArtifact) - { - auto target = getTargetFormat(); - switch (target) - { - case CodeGenTarget::SPIRVAssembly: - case CodeGenTarget::DXBytecodeAssembly: - case CodeGenTarget::DXILAssembly: - case CodeGenTarget::MetalLibAssembly: - case CodeGenTarget::WGSLSPIRVAssembly: - { - // First compile to an intermediate target for the corresponding binary format. - const CodeGenTarget intermediateTarget = _getIntermediateTarget(target); - CodeGenContext intermediateContext(this, intermediateTarget); - - ComPtr intermediateArtifact; - - SLANG_RETURN_ON_FAIL(intermediateContext._emitEntryPoints(intermediateArtifact)); - intermediateContext.maybeDumpIntermediate(intermediateArtifact); + // Set the artifact + outArtifact.swap(artifact); + return SLANG_OK; +} - // Then disassemble the intermediate binary result to get the desired output - // Output the disassemble - ComPtr disassemblyArtifact; - SLANG_RETURN_ON_FAIL(ArtifactOutputUtil::dissassembleWithDownstream(getSession(), intermediateArtifact, getSink(), disassemblyArtifact.writeRef())); +SlangResult emitSPIRVForEntryPointsDirectly( + CodeGenContext* codeGenContext, + ComPtr& outArtifact); - outArtifact.swap(disassemblyArtifact); - return SLANG_OK; - } - case CodeGenTarget::SPIRV: - if (getTargetProgram()->getOptionSet().shouldEmitSPIRVDirectly()) - { - SLANG_RETURN_ON_FAIL(emitSPIRVForEntryPointsDirectly(this, outArtifact)); - return SLANG_OK; - } - [[fallthrough]]; - case CodeGenTarget::DXIL: - case CodeGenTarget::DXBytecode: - case CodeGenTarget::MetalLib: - case CodeGenTarget::PTX: - case CodeGenTarget::ShaderHostCallable: - case CodeGenTarget::ShaderSharedLibrary: - case CodeGenTarget::HostExecutable: - case CodeGenTarget::HostHostCallable: - case CodeGenTarget::HostSharedLibrary: - case CodeGenTarget::WGSLSPIRV: - SLANG_RETURN_ON_FAIL(emitWithDownstreamForEntryPoints(outArtifact)); - return SLANG_OK; +static CodeGenTarget _getIntermediateTarget(CodeGenTarget target) +{ + switch (target) + { + case CodeGenTarget::DXBytecodeAssembly: return CodeGenTarget::DXBytecode; + case CodeGenTarget::DXILAssembly: return CodeGenTarget::DXIL; + case CodeGenTarget::SPIRVAssembly: return CodeGenTarget::SPIRV; + case CodeGenTarget::WGSLSPIRVAssembly: return CodeGenTarget::WGSLSPIRV; + default: return CodeGenTarget::None; + } +} - default: break; +/// Function to simplify the logic around emitting, and dissassembling +SlangResult CodeGenContext::_emitEntryPoints(ComPtr& outArtifact) +{ + auto target = getTargetFormat(); + switch (target) + { + case CodeGenTarget::SPIRVAssembly: + case CodeGenTarget::DXBytecodeAssembly: + case CodeGenTarget::DXILAssembly: + case CodeGenTarget::MetalLibAssembly: + case CodeGenTarget::WGSLSPIRVAssembly: + { + // First compile to an intermediate target for the corresponding binary format. + const CodeGenTarget intermediateTarget = _getIntermediateTarget(target); + CodeGenContext intermediateContext(this, intermediateTarget); + + ComPtr intermediateArtifact; + + SLANG_RETURN_ON_FAIL(intermediateContext._emitEntryPoints(intermediateArtifact)); + intermediateContext.maybeDumpIntermediate(intermediateArtifact); + + // Then disassemble the intermediate binary result to get the desired output + // Output the disassemble + ComPtr disassemblyArtifact; + SLANG_RETURN_ON_FAIL(ArtifactOutputUtil::dissassembleWithDownstream( + getSession(), + intermediateArtifact, + getSink(), + disassemblyArtifact.writeRef())); + + outArtifact.swap(disassemblyArtifact); + return SLANG_OK; } + case CodeGenTarget::SPIRV: + if (getTargetProgram()->getOptionSet().shouldEmitSPIRVDirectly()) + { + SLANG_RETURN_ON_FAIL(emitSPIRVForEntryPointsDirectly(this, outArtifact)); + return SLANG_OK; + } + [[fallthrough]]; + case CodeGenTarget::DXIL: + case CodeGenTarget::DXBytecode: + case CodeGenTarget::MetalLib: + case CodeGenTarget::PTX: + case CodeGenTarget::ShaderHostCallable: + case CodeGenTarget::ShaderSharedLibrary: + case CodeGenTarget::HostExecutable: + case CodeGenTarget::HostHostCallable: + case CodeGenTarget::HostSharedLibrary: + case CodeGenTarget::WGSLSPIRV: + SLANG_RETURN_ON_FAIL(emitWithDownstreamForEntryPoints(outArtifact)); + return SLANG_OK; - return SLANG_FAIL; + default: break; } - // Do emit logic for a zero or more entry points - SlangResult CodeGenContext::emitEntryPoints(ComPtr& outArtifact) - { - CompileTimerRAII recordCompileTime(getSession()); - - auto target = getTargetFormat(); + return SLANG_FAIL; +} - switch (target) +// Do emit logic for a zero or more entry points +SlangResult CodeGenContext::emitEntryPoints(ComPtr& outArtifact) +{ + CompileTimerRAII recordCompileTime(getSession()); + + auto target = getTargetFormat(); + + switch (target) + { + case CodeGenTarget::SPIRVAssembly: + case CodeGenTarget::DXBytecodeAssembly: + case CodeGenTarget::DXILAssembly: + case CodeGenTarget::SPIRV: + case CodeGenTarget::DXIL: + case CodeGenTarget::DXBytecode: + case CodeGenTarget::MetalLib: + case CodeGenTarget::MetalLibAssembly: + case CodeGenTarget::PTX: + case CodeGenTarget::HostHostCallable: + case CodeGenTarget::ShaderHostCallable: + case CodeGenTarget::ShaderSharedLibrary: + case CodeGenTarget::HostExecutable: + case CodeGenTarget::HostSharedLibrary: + case CodeGenTarget::WGSLSPIRVAssembly: + { + SLANG_RETURN_ON_FAIL(_emitEntryPoints(outArtifact)); + + maybeDumpIntermediate(outArtifact); + return SLANG_OK; + } + break; + case CodeGenTarget::GLSL: + case CodeGenTarget::HLSL: + case CodeGenTarget::CUDASource: + case CodeGenTarget::CPPSource: + case CodeGenTarget::HostCPPSource: + case CodeGenTarget::PyTorchCppBinding: + case CodeGenTarget::CSource: + case CodeGenTarget::Metal: + case CodeGenTarget::WGSL: { - case CodeGenTarget::SPIRVAssembly: - case CodeGenTarget::DXBytecodeAssembly: - case CodeGenTarget::DXILAssembly: - case CodeGenTarget::SPIRV: - case CodeGenTarget::DXIL: - case CodeGenTarget::DXBytecode: - case CodeGenTarget::MetalLib: - case CodeGenTarget::MetalLibAssembly: - case CodeGenTarget::PTX: - case CodeGenTarget::HostHostCallable: - case CodeGenTarget::ShaderHostCallable: - case CodeGenTarget::ShaderSharedLibrary: - case CodeGenTarget::HostExecutable: - case CodeGenTarget::HostSharedLibrary: - case CodeGenTarget::WGSLSPIRVAssembly: - { - SLANG_RETURN_ON_FAIL(_emitEntryPoints(outArtifact)); - - maybeDumpIntermediate(outArtifact); - return SLANG_OK; - } - break; - case CodeGenTarget::GLSL: - case CodeGenTarget::HLSL: - case CodeGenTarget::CUDASource: - case CodeGenTarget::CPPSource: - case CodeGenTarget::HostCPPSource: - case CodeGenTarget::PyTorchCppBinding: - case CodeGenTarget::CSource: - case CodeGenTarget::Metal: - case CodeGenTarget::WGSL: - { - RefPtr extensionTracker = _newExtensionTracker(target); - - CodeGenContext subContext(this, target, extensionTracker); + RefPtr extensionTracker = _newExtensionTracker(target); - ComPtr sourceArtifact; + CodeGenContext subContext(this, target, extensionTracker); - SLANG_RETURN_ON_FAIL(subContext.emitEntryPointsSource(sourceArtifact)); + ComPtr sourceArtifact; - subContext.maybeDumpIntermediate(sourceArtifact); - outArtifact = sourceArtifact; - return SLANG_OK; - } - break; + SLANG_RETURN_ON_FAIL(subContext.emitEntryPointsSource(sourceArtifact)); - case CodeGenTarget::None: - // The user requested no output + subContext.maybeDumpIntermediate(sourceArtifact); + outArtifact = sourceArtifact; return SLANG_OK; + } + break; - // Note(tfoley): We currently hit this case when compiling the core module - case CodeGenTarget::Unknown: - return SLANG_OK; + case CodeGenTarget::None: + // The user requested no output + return SLANG_OK; - default: - SLANG_UNEXPECTED("unhandled code generation target"); - break; - } - return SLANG_FAIL; + // Note(tfoley): We currently hit this case when compiling the core module + case CodeGenTarget::Unknown: return SLANG_OK; + + default: SLANG_UNEXPECTED("unhandled code generation target"); break; } + return SLANG_FAIL; +} - void EndToEndCompileRequest::writeArtifactToStandardOutput(IArtifact* artifact, DiagnosticSink* sink) +void EndToEndCompileRequest::writeArtifactToStandardOutput( + IArtifact* artifact, + DiagnosticSink* sink) +{ + // If it's host callable it's not available to write to output + if (isDerivedFrom(artifact->getDesc().kind, ArtifactKind::HostCallable)) { - // If it's host callable it's not available to write to output - if (isDerivedFrom(artifact->getDesc().kind, ArtifactKind::HostCallable)) - { - return; - } - - auto session = getSession(); - ArtifactOutputUtil::maybeConvertAndWrite(session, artifact, sink, toSlice("stdout"), getWriter(WriterChannel::StdOutput)); + return; } - String EndToEndCompileRequest::_getWholeProgramPath(TargetRequest* targetReq) + auto session = getSession(); + ArtifactOutputUtil::maybeConvertAndWrite( + session, + artifact, + sink, + toSlice("stdout"), + getWriter(WriterChannel::StdOutput)); +} + +String EndToEndCompileRequest::_getWholeProgramPath(TargetRequest* targetReq) +{ + RefPtr targetInfo; + if (m_targetInfos.tryGetValue(targetReq, targetInfo)) { - RefPtr targetInfo; - if (m_targetInfos.tryGetValue(targetReq, targetInfo)) - { - return targetInfo->wholeTargetOutputPath; - } - return String(); + return targetInfo->wholeTargetOutputPath; } + return String(); +} - String EndToEndCompileRequest::_getEntryPointPath(TargetRequest* targetReq, Index entryPointIndex) +String EndToEndCompileRequest::_getEntryPointPath(TargetRequest* targetReq, Index entryPointIndex) +{ + // It is possible that we are dynamically discovering entry + // points (using `[shader(...)]` attributes), so that there + // might be entry points added to the program that did not + // get paths specified via command-line options. + // + RefPtr targetInfo; + if (m_targetInfos.tryGetValue(targetReq, targetInfo)) { - // It is possible that we are dynamically discovering entry - // points (using `[shader(...)]` attributes), so that there - // might be entry points added to the program that did not - // get paths specified via command-line options. - // - RefPtr targetInfo; - if (m_targetInfos.tryGetValue(targetReq, targetInfo)) + String outputPath; + if (targetInfo->entryPointOutputPaths.tryGetValue(entryPointIndex, outputPath)) { - String outputPath; - if (targetInfo->entryPointOutputPaths.tryGetValue(entryPointIndex, outputPath)) - { - return outputPath; - } + return outputPath; } + } + + return String(); +} - return String(); +SlangResult EndToEndCompileRequest::_writeArtifact(const String& path, IArtifact* artifact) +{ + if (path.getLength() > 0) + { + SLANG_RETURN_ON_FAIL(ArtifactOutputUtil::writeToFile(artifact, getSink(), path)); + } + else if (m_containerFormat == ContainerFormat::None) + { + // If we aren't writing to a container and we didn't write to a file, we can output to + // standard output + writeArtifactToStandardOutput(artifact, getSink()); } + return SLANG_OK; +} - SlangResult EndToEndCompileRequest::_writeArtifact(const String& path, IArtifact* artifact) +SlangResult EndToEndCompileRequest::_maybeWriteArtifact(const String& path, IArtifact* artifact) +{ + // We don't have to do anything if there is no artifact + if (!artifact) { - if (path.getLength() > 0) - { - SLANG_RETURN_ON_FAIL(ArtifactOutputUtil::writeToFile(artifact, getSink(), path)); - } - else if (m_containerFormat == ContainerFormat::None) - { - // If we aren't writing to a container and we didn't write to a file, we can output to standard output - writeArtifactToStandardOutput(artifact, getSink()); - } return SLANG_OK; } - SlangResult EndToEndCompileRequest::_maybeWriteArtifact(const String& path, IArtifact* artifact) + // If embedding is enabled... + if (m_sourceEmbedStyle != SourceEmbedUtil::Style::None) { - // We don't have to do anything if there is no artifact - if (!artifact) - { - return SLANG_OK; - } - - // If embedding is enabled... - if (m_sourceEmbedStyle != SourceEmbedUtil::Style::None) - { - SourceEmbedUtil::Options options; + SourceEmbedUtil::Options options; - options.style = m_sourceEmbedStyle; - options.variableName = m_sourceEmbedName; - options.language = (SlangSourceLanguage)m_sourceEmbedLanguage; + options.style = m_sourceEmbedStyle; + options.variableName = m_sourceEmbedName; + options.language = (SlangSourceLanguage)m_sourceEmbedLanguage; - ComPtr embeddedArtifact; - SLANG_RETURN_ON_FAIL(SourceEmbedUtil::createEmbedded(artifact, options, embeddedArtifact)); + ComPtr embeddedArtifact; + SLANG_RETURN_ON_FAIL(SourceEmbedUtil::createEmbedded(artifact, options, embeddedArtifact)); - if (!embeddedArtifact) - { - return SLANG_FAIL; - } - SLANG_RETURN_ON_FAIL(_writeArtifact(SourceEmbedUtil::getPath(path, options), embeddedArtifact)); - return SLANG_OK; - } - else + if (!embeddedArtifact) { - SLANG_RETURN_ON_FAIL(_writeArtifact(path, artifact)); + return SLANG_FAIL; } - + SLANG_RETURN_ON_FAIL( + _writeArtifact(SourceEmbedUtil::getPath(path, options), embeddedArtifact)); return SLANG_OK; } - - IArtifact* TargetProgram::_createWholeProgramResult( - DiagnosticSink* sink, - EndToEndCompileRequest* endToEndReq) + else { - // We want to call `emitEntryPoints` function to generate code that contains - // all the entrypoints defined in `m_program`. - // The current logic of `emitEntryPoints` takes a list of entry-point indices to - // emit code for, so we construct such a list first. - List entryPointIndices; + SLANG_RETURN_ON_FAIL(_writeArtifact(path, artifact)); + } - m_entryPointResults.setCount(m_program->getEntryPointCount()); - entryPointIndices.setCount(m_program->getEntryPointCount()); - for (Index i = 0; i < entryPointIndices.getCount(); i++) - entryPointIndices[i] = i; - - CodeGenContext::Shared sharedCodeGenContext(this, entryPointIndices, sink, endToEndReq); - CodeGenContext codeGenContext(&sharedCodeGenContext); + return SLANG_OK; +} - if (SLANG_FAILED(codeGenContext.emitEntryPoints(m_wholeProgramResult))) - { - return nullptr; - } - - return m_wholeProgramResult; - } +IArtifact* TargetProgram::_createWholeProgramResult( + DiagnosticSink* sink, + EndToEndCompileRequest* endToEndReq) +{ + // We want to call `emitEntryPoints` function to generate code that contains + // all the entrypoints defined in `m_program`. + // The current logic of `emitEntryPoints` takes a list of entry-point indices to + // emit code for, so we construct such a list first. + List entryPointIndices; + + m_entryPointResults.setCount(m_program->getEntryPointCount()); + entryPointIndices.setCount(m_program->getEntryPointCount()); + for (Index i = 0; i < entryPointIndices.getCount(); i++) + entryPointIndices[i] = i; + + CodeGenContext::Shared sharedCodeGenContext(this, entryPointIndices, sink, endToEndReq); + CodeGenContext codeGenContext(&sharedCodeGenContext); - IArtifact* TargetProgram::_createEntryPointResult( - Int entryPointIndex, - DiagnosticSink* sink, - EndToEndCompileRequest* endToEndReq) + if (SLANG_FAILED(codeGenContext.emitEntryPoints(m_wholeProgramResult))) { - // It is possible that entry points got added to the `Program` - // *after* we created this `TargetProgram`, so there might be - // a request for an entry point that we didn't allocate space for. - // - // TODO: Change the construction logic so that a `Program` is - // constructed all at once rather than incrementally, to avoid - // this problem. - // - if(entryPointIndex >= m_entryPointResults.getCount()) - m_entryPointResults.setCount(entryPointIndex + 1); + return nullptr; + } - - CodeGenContext::EntryPointIndices entryPointIndices; - entryPointIndices.add(entryPointIndex); + return m_wholeProgramResult; +} - CodeGenContext::Shared sharedCodeGenContext(this, entryPointIndices, sink, endToEndReq); - CodeGenContext codeGenContext(&sharedCodeGenContext); +IArtifact* TargetProgram::_createEntryPointResult( + Int entryPointIndex, + DiagnosticSink* sink, + EndToEndCompileRequest* endToEndReq) +{ + // It is possible that entry points got added to the `Program` + // *after* we created this `TargetProgram`, so there might be + // a request for an entry point that we didn't allocate space for. + // + // TODO: Change the construction logic so that a `Program` is + // constructed all at once rather than incrementally, to avoid + // this problem. + // + if (entryPointIndex >= m_entryPointResults.getCount()) + m_entryPointResults.setCount(entryPointIndex + 1); - codeGenContext.emitEntryPoints(m_entryPointResults[entryPointIndex]); - return m_entryPointResults[entryPointIndex]; - } + CodeGenContext::EntryPointIndices entryPointIndices; + entryPointIndices.add(entryPointIndex); - IArtifact* TargetProgram::getOrCreateWholeProgramResult( - DiagnosticSink* sink) - { - if (m_wholeProgramResult) - return m_wholeProgramResult; + CodeGenContext::Shared sharedCodeGenContext(this, entryPointIndices, sink, endToEndReq); + CodeGenContext codeGenContext(&sharedCodeGenContext); - // If we haven't yet computed a layout for this target - // program, we need to make sure that is done before - // code generation. - // - if (!getOrCreateIRModuleForLayout(sink)) - { - return nullptr; - } + codeGenContext.emitEntryPoints(m_entryPointResults[entryPointIndex]); - return _createWholeProgramResult(sink); - } + return m_entryPointResults[entryPointIndex]; +} + +IArtifact* TargetProgram::getOrCreateWholeProgramResult(DiagnosticSink* sink) +{ + if (m_wholeProgramResult) + return m_wholeProgramResult; - IArtifact* TargetProgram::getOrCreateEntryPointResult( - Int entryPointIndex, - DiagnosticSink* sink) + // If we haven't yet computed a layout for this target + // program, we need to make sure that is done before + // code generation. + // + if (!getOrCreateIRModuleForLayout(sink)) { - if(entryPointIndex >= m_entryPointResults.getCount()) - m_entryPointResults.setCount(entryPointIndex + 1); + return nullptr; + } - if(IArtifact* artifact = m_entryPointResults[entryPointIndex]) - return artifact; + return _createWholeProgramResult(sink); +} - // If we haven't yet computed a layout for this target - // program, we need to make sure that is done before - // code generation. - // - if( !getOrCreateIRModuleForLayout(sink) ) - { - return nullptr; - } +IArtifact* TargetProgram::getOrCreateEntryPointResult(Int entryPointIndex, DiagnosticSink* sink) +{ + if (entryPointIndex >= m_entryPointResults.getCount()) + m_entryPointResults.setCount(entryPointIndex + 1); - return _createEntryPointResult( - entryPointIndex, - sink); - } + if (IArtifact* artifact = m_entryPointResults[entryPointIndex]) + return artifact; - void EndToEndCompileRequest::generateOutput( - TargetProgram* targetProgram) + // If we haven't yet computed a layout for this target + // program, we need to make sure that is done before + // code generation. + // + if (!getOrCreateIRModuleForLayout(sink)) { - auto program = targetProgram->getProgram(); - - // Generate target code any entry points that - // have been requested for compilation. - auto entryPointCount = program->getEntryPointCount(); - if (targetProgram->getOptionSet().getBoolOption(CompilerOptionName::GenerateWholeProgram)) - { - targetProgram->_createWholeProgramResult(getSink(), this); - } - else - { - for (Index ii = 0; ii < entryPointCount; ++ii) - { - targetProgram->_createEntryPointResult( - ii, - getSink(), - this); - } - } + return nullptr; } - - bool _shouldWriteSourceLocs(Linkage* linkage) + return _createEntryPointResult(entryPointIndex, sink); +} + +void EndToEndCompileRequest::generateOutput(TargetProgram* targetProgram) +{ + auto program = targetProgram->getProgram(); + + // Generate target code any entry points that + // have been requested for compilation. + auto entryPointCount = program->getEntryPointCount(); + if (targetProgram->getOptionSet().getBoolOption(CompilerOptionName::GenerateWholeProgram)) + { + targetProgram->_createWholeProgramResult(getSink(), this); + } + else { - // If debug information or source manager are not avaiable we can't/shouldn't write out locs - if (linkage->m_optionSet.getEnumOption(CompilerOptionName::DebugInformation) == DebugInfoLevel::None || - linkage->getSourceManager() == nullptr) + for (Index ii = 0; ii < entryPointCount; ++ii) { - return false; + targetProgram->_createEntryPointResult(ii, getSink(), this); } - - // Otherwise we do want to write out the locs - return true; } +} + - SlangResult EndToEndCompileRequest::writeContainerToStream(Stream* stream) +bool _shouldWriteSourceLocs(Linkage* linkage) +{ + // If debug information or source manager are not avaiable we can't/shouldn't write out locs + if (linkage->m_optionSet.getEnumOption(CompilerOptionName::DebugInformation) == + DebugInfoLevel::None || + linkage->getSourceManager() == nullptr) { - auto linkage = getLinkage(); + return false; + } - // Set up options - SerialContainerUtil::WriteOptions options; + // Otherwise we do want to write out the locs + return true; +} - options.compressionType = linkage->m_optionSet.getEnumOption(CompilerOptionName::IrCompression); +SlangResult EndToEndCompileRequest::writeContainerToStream(Stream* stream) +{ + auto linkage = getLinkage(); - // If debug information is enabled, enable writing out source locs - if (_shouldWriteSourceLocs(linkage)) - { - options.optionFlags |= SerialOptionFlag::SourceLocation; - options.sourceManager = linkage->getSourceManager(); - } + // Set up options + SerialContainerUtil::WriteOptions options; - { - RiffContainer container; - { - SerialContainerData data; - SLANG_RETURN_ON_FAIL(SerialContainerUtil::addEndToEndRequestToData(this, options, data)); - SLANG_RETURN_ON_FAIL(SerialContainerUtil::write(data, options, &container)); - } - // We now write the RiffContainer to the stream - SLANG_RETURN_ON_FAIL(RiffUtil::write(container.getRoot(), true, stream)); - } + options.compressionType = linkage->m_optionSet.getEnumOption( + CompilerOptionName::IrCompression); - return SLANG_OK; + // If debug information is enabled, enable writing out source locs + if (_shouldWriteSourceLocs(linkage)) + { + options.optionFlags |= SerialOptionFlag::SourceLocation; + options.sourceManager = linkage->getSourceManager(); } - static IBoxValue* _getObfuscatedSourceMap(TranslationUnitRequest* translationUnit) { - if (auto module = translationUnit->getModule()) + RiffContainer container; { - if (auto irModule = module->getIRModule()) - { - return irModule->getObfuscatedSourceMap(); - } + SerialContainerData data; + SLANG_RETURN_ON_FAIL( + SerialContainerUtil::addEndToEndRequestToData(this, options, data)); + SLANG_RETURN_ON_FAIL(SerialContainerUtil::write(data, options, &container)); } - return nullptr; + // We now write the RiffContainer to the stream + SLANG_RETURN_ON_FAIL(RiffUtil::write(container.getRoot(), true, stream)); } - SlangResult EndToEndCompileRequest::maybeCreateContainer() + return SLANG_OK; +} + +static IBoxValue* _getObfuscatedSourceMap(TranslationUnitRequest* translationUnit) +{ + if (auto module = translationUnit->getModule()) { - m_containerArtifact.setNull(); + if (auto irModule = module->getIRModule()) + { + return irModule->getObfuscatedSourceMap(); + } + } + return nullptr; +} - List> artifacts; +SlangResult EndToEndCompileRequest::maybeCreateContainer() +{ + m_containerArtifact.setNull(); - auto linkage = getLinkage(); - - auto program = getSpecializedGlobalAndEntryPointsComponentType(); + List> artifacts; - for (auto targetReq : linkage->targets) - { - auto targetProgram = program->getTargetProgram(targetReq); + auto linkage = getLinkage(); + + auto program = getSpecializedGlobalAndEntryPointsComponentType(); + + for (auto targetReq : linkage->targets) + { + auto targetProgram = program->getTargetProgram(targetReq); - if (targetProgram->getOptionSet().getBoolOption(CompilerOptionName::GenerateWholeProgram)) - { - if (auto artifact = targetProgram->getExistingWholeProgramResult()) + if (targetProgram->getOptionSet().getBoolOption(CompilerOptionName::GenerateWholeProgram)) + { + if (auto artifact = targetProgram->getExistingWholeProgramResult()) + { + if (!targetProgram->getOptionSet().getBoolOption( + CompilerOptionName::EmbedDownstreamIR)) { - if (!targetProgram->getOptionSet().getBoolOption(CompilerOptionName::EmbedDownstreamIR)) - { - artifacts.add(ComPtr(artifact)); - } + artifacts.add(ComPtr(artifact)); } } - else + } + else + { + Index entryPointCount = program->getEntryPointCount(); + for (Index ee = 0; ee < entryPointCount; ++ee) { - Index entryPointCount = program->getEntryPointCount(); - for (Index ee = 0; ee < entryPointCount; ++ee) + if (auto artifact = targetProgram->getExistingEntryPointResult(ee)) { - if (auto artifact = targetProgram->getExistingEntryPointResult(ee)) - { - artifacts.add(ComPtr(artifact)); - } + artifacts.add(ComPtr(artifact)); } } } + } - // If IR emitting is enabled, add IR to the artifacts - if (m_emitIr && (m_containerFormat == ContainerFormat::SlangModule)) + // If IR emitting is enabled, add IR to the artifacts + if (m_emitIr && (m_containerFormat == ContainerFormat::SlangModule)) + { + OwnedMemoryStream stream(FileAccess::Write); + SlangResult res = writeContainerToStream(&stream); + if (SLANG_FAILED(res)) { - OwnedMemoryStream stream(FileAccess::Write); - SlangResult res = writeContainerToStream(&stream); - if (SLANG_FAILED(res)) - { - getSink()->diagnose(SourceLoc(), Diagnostics::unableToCreateModuleContainer); - return res; - } + getSink()->diagnose(SourceLoc(), Diagnostics::unableToCreateModuleContainer); + return res; + } - // Need to turn into a blob - List blobData; - stream.swapContents(blobData); + // Need to turn into a blob + List blobData; + stream.swapContents(blobData); - auto containerBlob = ListBlob::moveCreate(blobData); + auto containerBlob = ListBlob::moveCreate(blobData); - auto irArtifact = Artifact::create(ArtifactDesc::make(Artifact::Kind::CompileBinary, ArtifactPayload::SlangIR, ArtifactStyle::Unknown)); - irArtifact->addRepresentationUnknown(containerBlob); - - // Add the IR artifact - artifacts.add(irArtifact); - } + auto irArtifact = Artifact::create(ArtifactDesc::make( + Artifact::Kind::CompileBinary, + ArtifactPayload::SlangIR, + ArtifactStyle::Unknown)); + irArtifact->addRepresentationUnknown(containerBlob); - // If there is only one artifact we can use that as the container - if (artifacts.getCount() == 1) + // Add the IR artifact + artifacts.add(irArtifact); + } + + // If there is only one artifact we can use that as the container + if (artifacts.getCount() == 1) + { + m_containerArtifact = artifacts[0]; + } + else + { + m_containerArtifact = ArtifactUtil::createArtifact( + ArtifactDesc::make(ArtifactKind::Container, ArtifactPayload::CompileResults)); + + for (IArtifact* childArtifact : artifacts) { - m_containerArtifact = artifacts[0]; + m_containerArtifact->addChild(childArtifact); } - else - { - m_containerArtifact = ArtifactUtil::createArtifact(ArtifactDesc::make(ArtifactKind::Container, ArtifactPayload::CompileResults)); + } - for (IArtifact* childArtifact : artifacts) - { - m_containerArtifact->addChild(childArtifact); - } - } + // Get all of the source obfuscated source maps and add those + if (m_containerArtifact) + { + auto frontEndReq = getFrontEndReq(); - // Get all of the source obfuscated source maps and add those - if (m_containerArtifact) + for (auto translationUnit : frontEndReq->translationUnits) { - auto frontEndReq = getFrontEndReq(); - - for (auto translationUnit : frontEndReq->translationUnits) + // Hmmm do I have to therefore add a map for all translation units(!) + // I guess this is okay in so far as an association can always be looked up by name + if (auto sourceMap = _getObfuscatedSourceMap(translationUnit)) { - // Hmmm do I have to therefore add a map for all translation units(!) - // I guess this is okay in so far as an association can always be looked up by name - if (auto sourceMap = _getObfuscatedSourceMap(translationUnit)) - { - auto artifactDesc = ArtifactDesc::make(ArtifactKind::Json, ArtifactPayload::SourceMap, ArtifactStyle::Obfuscated); + auto artifactDesc = ArtifactDesc::make( + ArtifactKind::Json, + ArtifactPayload::SourceMap, + ArtifactStyle::Obfuscated); + + // Create the source map artifact + auto sourceMapArtifact = + Artifact::create(artifactDesc, sourceMap->get().m_file.getUnownedSlice()); - // Create the source map artifact - auto sourceMapArtifact = Artifact::create(artifactDesc, sourceMap->get().m_file.getUnownedSlice()); + // Add the repesentation + sourceMapArtifact->addRepresentation(sourceMap); - // Add the repesentation - sourceMapArtifact->addRepresentation(sourceMap); + // Associate with the container + m_containerArtifact->addAssociated(sourceMapArtifact); + } + } + } + + return SLANG_OK; +} + +CompilerOptionSet& EndToEndCompileRequest::getTargetOptionSet(TargetRequest* req) +{ + return req->getOptionSet(); +} - // Associate with the container - m_containerArtifact->addAssociated(sourceMapArtifact); - } - } - } +CompilerOptionSet& EndToEndCompileRequest::getTargetOptionSet(Index targetIndex) +{ + return m_linkage->targets[targetIndex]->getOptionSet(); +} +SlangResult EndToEndCompileRequest::maybeWriteContainer(const String& fileName) +{ + // If there is no container, or filename, don't write anything + if (fileName.getLength() == 0 || !m_containerArtifact) + { return SLANG_OK; } - CompilerOptionSet& EndToEndCompileRequest::getTargetOptionSet(TargetRequest* req) - { - return req->getOptionSet(); - } + // Filter the containerArtifact into things that can be written + ComPtr writeArtifact; + SLANG_RETURN_ON_FAIL(ArtifactContainerUtil::filter(m_containerArtifact, writeArtifact)); - CompilerOptionSet& EndToEndCompileRequest::getTargetOptionSet(Index targetIndex) + // Only write if there is something to write + if (writeArtifact) { - return m_linkage->targets[targetIndex]->getOptionSet(); + SLANG_RETURN_ON_FAIL(ArtifactContainerUtil::writeContainer(writeArtifact, fileName)); } - SlangResult EndToEndCompileRequest::maybeWriteContainer(const String& fileName) - { - // If there is no container, or filename, don't write anything - if (fileName.getLength() == 0 || !m_containerArtifact) - { - return SLANG_OK; - } + return SLANG_OK; +} - // Filter the containerArtifact into things that can be written - ComPtr writeArtifact; - SLANG_RETURN_ON_FAIL(ArtifactContainerUtil::filter(m_containerArtifact, writeArtifact)); +static void _writeString(Stream& stream, const char* string) +{ + stream.write(string, strlen(string)); +} - // Only write if there is something to write - if (writeArtifact) +static void _escapeDependencyString(const char* string, StringBuilder& outBuilder) +{ + // make has unusual escaping rules, but we only care about characters that are acceptable in a + // path + for (const char* p = string; *p; ++p) + { + char c = *p; + switch (c) { - SLANG_RETURN_ON_FAIL(ArtifactContainerUtil::writeContainer(writeArtifact, fileName)); + case ' ': + case ':': + case '#': + case '[': + case ']': + case '\\': outBuilder.appendChar('\\'); break; + + case '$': outBuilder.appendChar('$'); break; } - - return SLANG_OK; - } - static void _writeString(Stream& stream, const char* string) - { - stream.write(string, strlen(string)); + outBuilder.appendChar(c); } +} - static void _escapeDependencyString(const char* string, StringBuilder& outBuilder) - { - // make has unusual escaping rules, but we only care about characters that are acceptable in a path - for (const char* p = string; *p; ++p) - { - char c = *p; - switch(c) - { - case ' ': - case ':': - case '#': - case '[': - case ']': - case '\\': - outBuilder.appendChar('\\'); - break; - - case '$': - outBuilder.appendChar('$'); - break; - } +// Writes a line to the file stream, formatted like this: +// : +static void _writeDependencyStatement( + Stream& stream, + EndToEndCompileRequest* compileRequest, + const String& outputPath) +{ + if (outputPath.getLength() == 0) + return; - outBuilder.appendChar(c); - } - } + StringBuilder builder; + _escapeDependencyString(outputPath.begin(), builder); + _writeString(stream, builder.begin()); + _writeString(stream, ": "); - // Writes a line to the file stream, formatted like this: - // : - static void _writeDependencyStatement(Stream& stream, EndToEndCompileRequest* compileRequest, const String& outputPath) + int dependencyCount = compileRequest->getDependencyFileCount(); + for (int dependencyIndex = 0; dependencyIndex < dependencyCount; ++dependencyIndex) { - if (outputPath.getLength() == 0) - return; - - StringBuilder builder; - _escapeDependencyString(outputPath.begin(), builder); + builder.clear(); + _escapeDependencyString(compileRequest->getDependencyFilePath(dependencyIndex), builder); _writeString(stream, builder.begin()); - _writeString(stream, ": "); - - int dependencyCount = compileRequest->getDependencyFileCount(); - for (int dependencyIndex = 0; dependencyIndex < dependencyCount; ++dependencyIndex) - { - builder.clear(); - _escapeDependencyString(compileRequest->getDependencyFilePath(dependencyIndex), builder); - _writeString(stream, builder.begin()); - _writeString(stream, (dependencyIndex + 1 < dependencyCount) ? " " : "\n"); - } + _writeString(stream, (dependencyIndex + 1 < dependencyCount) ? " " : "\n"); } +} - // Writes a file with dependency info, with one line in the output file per compile product. - static SlangResult _writeDependencyFile(EndToEndCompileRequest* compileRequest) - { - if (compileRequest->m_dependencyOutputPath.getLength() == 0) - return SLANG_OK; +// Writes a file with dependency info, with one line in the output file per compile product. +static SlangResult _writeDependencyFile(EndToEndCompileRequest* compileRequest) +{ + if (compileRequest->m_dependencyOutputPath.getLength() == 0) + return SLANG_OK; - FileStream stream; - SLANG_RETURN_ON_FAIL(stream.init(compileRequest->m_dependencyOutputPath, FileMode::Create, FileAccess::Write, FileShare::ReadWrite)); + FileStream stream; + SLANG_RETURN_ON_FAIL(stream.init( + compileRequest->m_dependencyOutputPath, + FileMode::Create, + FileAccess::Write, + FileShare::ReadWrite)); - auto linkage = compileRequest->getLinkage(); - auto program = compileRequest->getSpecializedGlobalAndEntryPointsComponentType(); + auto linkage = compileRequest->getLinkage(); + auto program = compileRequest->getSpecializedGlobalAndEntryPointsComponentType(); - // Iterate over all the targets and their outputs - for (const auto& targetReq : linkage->targets) + // Iterate over all the targets and their outputs + for (const auto& targetReq : linkage->targets) + { + if (compileRequest->getTargetOptionSet(targetReq).getBoolOption( + CompilerOptionName::GenerateWholeProgram)) { - if (compileRequest->getTargetOptionSet(targetReq).getBoolOption(CompilerOptionName::GenerateWholeProgram)) + RefPtr targetInfo; + if (compileRequest->m_targetInfos.tryGetValue(targetReq, targetInfo)) { - RefPtr targetInfo; - if (compileRequest->m_targetInfos.tryGetValue(targetReq, targetInfo)) - { - _writeDependencyStatement(stream, compileRequest, targetInfo->wholeTargetOutputPath); - } + _writeDependencyStatement( + stream, + compileRequest, + targetInfo->wholeTargetOutputPath); } - else + } + else + { + Index entryPointCount = program->getEntryPointCount(); + for (Index entryPointIndex = 0; entryPointIndex < entryPointCount; ++entryPointIndex) { - Index entryPointCount = program->getEntryPointCount(); - for (Index entryPointIndex = 0; entryPointIndex < entryPointCount; ++entryPointIndex) + RefPtr targetInfo; + if (compileRequest->m_targetInfos.tryGetValue(targetReq, targetInfo)) { - RefPtr targetInfo; - if (compileRequest->m_targetInfos.tryGetValue(targetReq, targetInfo)) + String outputPath; + if (targetInfo->entryPointOutputPaths.tryGetValue(entryPointIndex, outputPath)) { - String outputPath; - if (targetInfo->entryPointOutputPaths.tryGetValue(entryPointIndex, outputPath)) - { - _writeDependencyStatement(stream, compileRequest, outputPath); - } + _writeDependencyStatement(stream, compileRequest, outputPath); } } } } - - return SLANG_OK; } + return SLANG_OK; +} + - void EndToEndCompileRequest::generateOutput( - ComponentType* program) +void EndToEndCompileRequest::generateOutput(ComponentType* program) +{ + // When dynamic dispatch is disabled, the program must + // be fully specialized by now. So we check if we still + // have unspecialized generic/existential parameters, + // and report them as an error. + // + auto specializationParamCount = program->getSpecializationParamCount(); + if (getOptionSet().getBoolOption(CompilerOptionName::DisableDynamicDispatch) && + specializationParamCount != 0) { - // When dynamic dispatch is disabled, the program must - // be fully specialized by now. So we check if we still - // have unspecialized generic/existential parameters, - // and report them as an error. - // - auto specializationParamCount = program->getSpecializationParamCount(); - if (getOptionSet().getBoolOption(CompilerOptionName::DisableDynamicDispatch) && specializationParamCount != 0) + auto sink = getSink(); + + for (Index ii = 0; ii < specializationParamCount; ++ii) { - auto sink = getSink(); - - for( Index ii = 0; ii < specializationParamCount; ++ii ) + auto specializationParam = program->getSpecializationParam(ii); + if (auto decl = as(specializationParam.object)) { - auto specializationParam = program->getSpecializationParam(ii); - if( auto decl = as(specializationParam.object) ) - { - sink->diagnose(specializationParam.loc, Diagnostics::specializationParameterOfNameNotSpecialized, decl); - } - else if( auto type = as(specializationParam.object) ) - { - sink->diagnose(specializationParam.loc, Diagnostics::specializationParameterOfNameNotSpecialized, type); - } - else - { - sink->diagnose(specializationParam.loc, Diagnostics::specializationParameterNotSpecialized); - } + sink->diagnose( + specializationParam.loc, + Diagnostics::specializationParameterOfNameNotSpecialized, + decl); + } + else if (auto type = as(specializationParam.object)) + { + sink->diagnose( + specializationParam.loc, + Diagnostics::specializationParameterOfNameNotSpecialized, + type); + } + else + { + sink->diagnose( + specializationParam.loc, + Diagnostics::specializationParameterNotSpecialized); } - - return; } + return; + } + - // Go through the code-generation targets that the user - // has specified, and generate code for each of them. - // - auto linkage = getLinkage(); - for (auto targetReq : linkage->targets) - { - if (targetReq->getOptionSet().getBoolOption(CompilerOptionName::EmbedDownstreamIR)) - continue; + // Go through the code-generation targets that the user + // has specified, and generate code for each of them. + // + auto linkage = getLinkage(); + for (auto targetReq : linkage->targets) + { + if (targetReq->getOptionSet().getBoolOption(CompilerOptionName::EmbedDownstreamIR)) + continue; - auto targetProgram = program->getTargetProgram(targetReq); - generateOutput(targetProgram); - } + auto targetProgram = program->getTargetProgram(targetReq); + generateOutput(targetProgram); } +} - void EndToEndCompileRequest::generateOutput() - { - SLANG_PROFILE; - generateOutput(getSpecializedGlobalAndEntryPointsComponentType()); +void EndToEndCompileRequest::generateOutput() +{ + SLANG_PROFILE; + generateOutput(getSpecializedGlobalAndEntryPointsComponentType()); + + // If we are in command-line mode, we might be expected to actually + // write output to one or more files here. - // If we are in command-line mode, we might be expected to actually - // write output to one or more files here. + if (m_isCommandLineCompile && m_containerFormat == ContainerFormat::None) + { + auto linkage = getLinkage(); + auto program = getSpecializedGlobalAndEntryPointsComponentType(); - if (m_isCommandLineCompile && - m_containerFormat == ContainerFormat::None) + for (auto targetReq : linkage->targets) { - auto linkage = getLinkage(); - auto program = getSpecializedGlobalAndEntryPointsComponentType(); + auto targetProgram = program->getTargetProgram(targetReq); - for (auto targetReq : linkage->targets) + if (targetProgram->getOptionSet().getBoolOption( + CompilerOptionName::GenerateWholeProgram)) { - auto targetProgram = program->getTargetProgram(targetReq); - - if (targetProgram->getOptionSet().getBoolOption(CompilerOptionName::GenerateWholeProgram)) + if (const auto artifact = targetProgram->getExistingWholeProgramResult()) { - if (const auto artifact = targetProgram->getExistingWholeProgramResult()) - { - const auto path = _getWholeProgramPath(targetReq); + const auto path = _getWholeProgramPath(targetReq); - _maybeWriteArtifact(path, artifact); - } + _maybeWriteArtifact(path, artifact); } - else + } + else + { + Index entryPointCount = program->getEntryPointCount(); + for (Index ee = 0; ee < entryPointCount; ++ee) { - Index entryPointCount = program->getEntryPointCount(); - for (Index ee = 0; ee < entryPointCount; ++ee) - { - if (const auto artifact = targetProgram->getExistingEntryPointResult(ee)) - { - const auto path = _getEntryPointPath(targetReq, ee); + if (const auto artifact = targetProgram->getExistingEntryPointResult(ee)) + { + const auto path = _getEntryPointPath(targetReq, ee); - _maybeWriteArtifact(path, artifact); - } + _maybeWriteArtifact(path, artifact); } } } } + } - // Maybe create the container - maybeCreateContainer(); + // Maybe create the container + maybeCreateContainer(); - // If it's a command line compile we may need to write the container to a file - if (m_isCommandLineCompile) - { - // TODO(JS): - // We could write the container into a source embedded format potentially + // If it's a command line compile we may need to write the container to a file + if (m_isCommandLineCompile) + { + // TODO(JS): + // We could write the container into a source embedded format potentially - maybeWriteContainer(m_containerOutputPath); + maybeWriteContainer(m_containerOutputPath); - _writeDependencyFile(this); - } + _writeDependencyFile(this); } +} - // Debug logic for dumping intermediate outputs +// Debug logic for dumping intermediate outputs - - void CodeGenContext::_dumpIntermediateMaybeWithAssembly(IArtifact* artifact) - { - _dumpIntermediate(artifact); - ComPtr assembly; - ArtifactOutputUtil::maybeDisassemble(getSession(), artifact, nullptr, assembly); +void CodeGenContext::_dumpIntermediateMaybeWithAssembly(IArtifact* artifact) +{ + _dumpIntermediate(artifact); - if (assembly) - { - _dumpIntermediate(assembly); - } - } + ComPtr assembly; + ArtifactOutputUtil::maybeDisassemble(getSession(), artifact, nullptr, assembly); - void CodeGenContext::_dumpIntermediate(IArtifact* artifact) + if (assembly) { - ComPtr blob; - if (SLANG_FAILED(artifact->loadBlob(ArtifactKeep::No, blob.writeRef()))) - { - return; - } - _dumpIntermediate(artifact->getDesc(), blob->getBufferPointer(), blob->getBufferSize()); + _dumpIntermediate(assembly); } +} - void CodeGenContext::_dumpIntermediate( - const ArtifactDesc& desc, - void const* data, - size_t size) +void CodeGenContext::_dumpIntermediate(IArtifact* artifact) +{ + ComPtr blob; + if (SLANG_FAILED(artifact->loadBlob(ArtifactKeep::No, blob.writeRef()))) { - // Try to generate a unique ID for the file to dump, - // even in cases where there might be multiple threads - // doing compilation. - // - // This is primarily a debugging aid, so we don't - // really need/want to do anything too elaborate + return; + } + _dumpIntermediate(artifact->getDesc(), blob->getBufferPointer(), blob->getBufferSize()); +} - static std::atomic counter(0); +void CodeGenContext::_dumpIntermediate(const ArtifactDesc& desc, void const* data, size_t size) +{ + // Try to generate a unique ID for the file to dump, + // even in cases where there might be multiple threads + // doing compilation. + // + // This is primarily a debugging aid, so we don't + // really need/want to do anything too elaborate - const uint32_t id = ++counter; + static std::atomic counter(0); - // Just use the counter for the 'base name' - StringBuilder basename; + const uint32_t id = ++counter; - // Add the prefix - basename << getIntermediateDumpPrefix(); + // Just use the counter for the 'base name' + StringBuilder basename; - // Add the id - basename << int(id); + // Add the prefix + basename << getIntermediateDumpPrefix(); - // Work out the filename based on the desc and the basename - StringBuilder filename; - ArtifactDescUtil::calcNameForDesc(desc, basename.getUnownedSlice(), filename); + // Add the id + basename << int(id); - // If didn't produce a filename, use basename with .unknown extension - if (filename.getLength() == 0) - { - filename = basename; - filename << ".unknown"; - } - - // Write to a file - ArtifactOutputUtil::writeToFile(desc, data, size, filename); - } + // Work out the filename based on the desc and the basename + StringBuilder filename; + ArtifactDescUtil::calcNameForDesc(desc, basename.getUnownedSlice(), filename); - void CodeGenContext::maybeDumpIntermediate(IArtifact* artifact) + // If didn't produce a filename, use basename with .unknown extension + if (filename.getLength() == 0) { - if (!shouldDumpIntermediates()) - return; - - - _dumpIntermediateMaybeWithAssembly(artifact); + filename = basename; + filename << ".unknown"; } - IRDumpOptions CodeGenContext::getIRDumpOptions() - { - if (auto endToEndReq = isEndToEndCompile()) - { - return endToEndReq->getFrontEndReq()->m_irDumpOptions; - } - return IRDumpOptions(); - } + // Write to a file + ArtifactOutputUtil::writeToFile(desc, data, size, filename); +} - bool CodeGenContext::shouldValidateIR() - { - return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::ValidateIr); - } +void CodeGenContext::maybeDumpIntermediate(IArtifact* artifact) +{ + if (!shouldDumpIntermediates()) + return; - bool CodeGenContext::shouldSkipSPIRVValidation() - { - return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::SkipSPIRVValidation); - } - bool CodeGenContext::shouldDumpIR() - { - return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::DumpIr); - } + _dumpIntermediateMaybeWithAssembly(artifact); +} - bool CodeGenContext::shouldReportCheckpointIntermediates() +IRDumpOptions CodeGenContext::getIRDumpOptions() +{ + if (auto endToEndReq = isEndToEndCompile()) { - return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::ReportCheckpointIntermediates); + return endToEndReq->getFrontEndReq()->m_irDumpOptions; } + return IRDumpOptions(); +} - bool CodeGenContext::shouldDumpIntermediates() - { - return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::DumpIntermediates); - } +bool CodeGenContext::shouldValidateIR() +{ + return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::ValidateIr); +} - bool CodeGenContext::shouldTrackLiveness() - { - return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::TrackLiveness); - } +bool CodeGenContext::shouldSkipSPIRVValidation() +{ + return getTargetProgram()->getOptionSet().getBoolOption( + CompilerOptionName::SkipSPIRVValidation); +} - String CodeGenContext::getIntermediateDumpPrefix() - { - return getTargetProgram()->getOptionSet().getStringOption(CompilerOptionName::DumpIntermediatePrefix); - } +bool CodeGenContext::shouldDumpIR() +{ + return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::DumpIr); +} - bool CodeGenContext::getUseUnknownImageFormatAsDefault() - { - return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::DefaultImageFormatUnknown); - } +bool CodeGenContext::shouldReportCheckpointIntermediates() +{ + return getTargetProgram()->getOptionSet().getBoolOption( + CompilerOptionName::ReportCheckpointIntermediates); +} - bool CodeGenContext::isSpecializationDisabled() - { - return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::DisableSpecialization); - } +bool CodeGenContext::shouldDumpIntermediates() +{ + return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::DumpIntermediates); +} - SLANG_NO_THROW SlangResult SLANG_MCALL Module::serialize(ISlangBlob** outSerializedBlob) - { - SerialContainerUtil::WriteOptions writeOptions; - writeOptions.sourceManager = getLinkage()->getSourceManager(); - OwnedMemoryStream memoryStream(FileAccess::Write); - SLANG_RETURN_ON_FAIL(SerialContainerUtil::write(this, writeOptions, &memoryStream)); - *outSerializedBlob = RawBlob::create( - memoryStream.getContents().getBuffer(), - (size_t)memoryStream.getContents().getCount()).detach(); - return SLANG_OK; - } +bool CodeGenContext::shouldTrackLiveness() +{ + return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::TrackLiveness); +} - SLANG_NO_THROW SlangResult SLANG_MCALL Module::writeToFile(char const* fileName) - { - SerialContainerUtil::WriteOptions writeOptions; - writeOptions.sourceManager = getLinkage()->getSourceManager(); - FileStream fileStream; - SLANG_RETURN_ON_FAIL(fileStream.init(fileName, FileMode::Create)); - return SerialContainerUtil::write(this, writeOptions, &fileStream); - } +String CodeGenContext::getIntermediateDumpPrefix() +{ + return getTargetProgram()->getOptionSet().getStringOption( + CompilerOptionName::DumpIntermediatePrefix); +} - SLANG_NO_THROW const char* SLANG_MCALL Module::getName() - { - if (m_name) - return m_name->text.getBuffer(); - return nullptr; - } +bool CodeGenContext::getUseUnknownImageFormatAsDefault() +{ + return getTargetProgram()->getOptionSet().getBoolOption( + CompilerOptionName::DefaultImageFormatUnknown); +} - SLANG_NO_THROW const char* SLANG_MCALL Module::getFilePath() - { - if (m_pathInfo.hasFoundPath()) - return m_pathInfo.foundPath.getBuffer(); - return nullptr; - } +bool CodeGenContext::isSpecializationDisabled() +{ + return getTargetProgram()->getOptionSet().getBoolOption( + CompilerOptionName::DisableSpecialization); +} - SLANG_NO_THROW const char* SLANG_MCALL Module::getUniqueIdentity() - { - if (m_pathInfo.hasUniqueIdentity()) - return m_pathInfo.getMostUniqueIdentity().getBuffer(); - return nullptr; - } +SLANG_NO_THROW SlangResult SLANG_MCALL Module::serialize(ISlangBlob** outSerializedBlob) +{ + SerialContainerUtil::WriteOptions writeOptions; + writeOptions.sourceManager = getLinkage()->getSourceManager(); + OwnedMemoryStream memoryStream(FileAccess::Write); + SLANG_RETURN_ON_FAIL(SerialContainerUtil::write(this, writeOptions, &memoryStream)); + *outSerializedBlob = RawBlob::create( + memoryStream.getContents().getBuffer(), + (size_t)memoryStream.getContents().getCount()) + .detach(); + return SLANG_OK; +} - SLANG_NO_THROW SlangInt32 SLANG_MCALL Module::getDependencyFileCount() - { - return (SlangInt32)getFileDependencies().getCount(); - } +SLANG_NO_THROW SlangResult SLANG_MCALL Module::writeToFile(char const* fileName) +{ + SerialContainerUtil::WriteOptions writeOptions; + writeOptions.sourceManager = getLinkage()->getSourceManager(); + FileStream fileStream; + SLANG_RETURN_ON_FAIL(fileStream.init(fileName, FileMode::Create)); + return SerialContainerUtil::write(this, writeOptions, &fileStream); +} - SLANG_NO_THROW char const* SLANG_MCALL Module::getDependencyFilePath( - SlangInt32 index) - { - SourceFile* sourceFile = getFileDependencies()[index]; - return sourceFile->getPathInfo().hasFoundPath() ? sourceFile->getPathInfo().foundPath.getBuffer() : nullptr; - } +SLANG_NO_THROW const char* SLANG_MCALL Module::getName() +{ + if (m_name) + return m_name->text.getBuffer(); + return nullptr; +} - void validateEntryPoint( - EntryPoint* entryPoint, - DiagnosticSink* sink); +SLANG_NO_THROW const char* SLANG_MCALL Module::getFilePath() +{ + if (m_pathInfo.hasFoundPath()) + return m_pathInfo.foundPath.getBuffer(); + return nullptr; +} - void Module::_discoverEntryPoints(DiagnosticSink* sink, const List>& targets) - { - if (m_entryPoints.getCount() > 0) - return; - _discoverEntryPointsImpl(m_moduleDecl, sink, targets); - } - void Module::_discoverEntryPointsImpl(ContainerDecl* containerDecl, DiagnosticSink* sink, const List>& targets) +SLANG_NO_THROW const char* SLANG_MCALL Module::getUniqueIdentity() +{ + if (m_pathInfo.hasUniqueIdentity()) + return m_pathInfo.getMostUniqueIdentity().getBuffer(); + return nullptr; +} + +SLANG_NO_THROW SlangInt32 SLANG_MCALL Module::getDependencyFileCount() +{ + return (SlangInt32)getFileDependencies().getCount(); +} + +SLANG_NO_THROW char const* SLANG_MCALL Module::getDependencyFilePath(SlangInt32 index) +{ + SourceFile* sourceFile = getFileDependencies()[index]; + return sourceFile->getPathInfo().hasFoundPath() + ? sourceFile->getPathInfo().foundPath.getBuffer() + : nullptr; +} + +void validateEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink); + +void Module::_discoverEntryPoints(DiagnosticSink* sink, const List>& targets) +{ + if (m_entryPoints.getCount() > 0) + return; + _discoverEntryPointsImpl(m_moduleDecl, sink, targets); +} +void Module::_discoverEntryPointsImpl( + ContainerDecl* containerDecl, + DiagnosticSink* sink, + const List>& targets) +{ + for (auto globalDecl : containerDecl->members) { - for (auto globalDecl : containerDecl->members) + auto maybeFuncDecl = globalDecl; + if (auto genericDecl = as(maybeFuncDecl)) { - auto maybeFuncDecl = globalDecl; - if (auto genericDecl = as(maybeFuncDecl)) - { - maybeFuncDecl = genericDecl->inner; - } + maybeFuncDecl = genericDecl->inner; + } + + if (as(globalDecl) || as(globalDecl) || + as(globalDecl)) + { + _discoverEntryPointsImpl(as(globalDecl), sink, targets); + continue; + } + + auto funcDecl = as(maybeFuncDecl); + if (!funcDecl) + continue; + + Profile profile; + bool resolvedStageOfProfileWithEntryPoint = resolveStageOfProfileWithEntryPoint( + profile, + getLinkage()->m_optionSet, + targets, + funcDecl, + sink); + if (!resolvedStageOfProfileWithEntryPoint) + { + // If there isn't a [shader] attribute, look for a [numthreads] attribute + // since that implicitly means a compute shader. We'll not do this when compiling for + // CUDA/Torch since [numthreads] attributes are utilized differently for those targets. + // - if (as(globalDecl) || as(globalDecl) || as(globalDecl)) + bool allTargetsCUDARelated = true; + for (auto target : targets) { - _discoverEntryPointsImpl(as(globalDecl), sink, targets); - continue; + if (!isCUDATarget(target) && + target->getTarget() != CodeGenTarget::PyTorchCppBinding) + { + allTargetsCUDARelated = false; + break; + } } - auto funcDecl = as(maybeFuncDecl); - if (!funcDecl) + if (allTargetsCUDARelated && targets.getCount() > 0) continue; - Profile profile; - bool resolvedStageOfProfileWithEntryPoint = resolveStageOfProfileWithEntryPoint(profile, getLinkage()->m_optionSet, targets, funcDecl, sink); - if (!resolvedStageOfProfileWithEntryPoint) + bool canDetermineStage = false; + for (auto modifier : funcDecl->modifiers) { - // If there isn't a [shader] attribute, look for a [numthreads] attribute - // since that implicitly means a compute shader. We'll not do this when compiling for - // CUDA/Torch since [numthreads] attributes are utilized differently for those targets. - // - - bool allTargetsCUDARelated = true; - for (auto target : targets) + if (as(modifier)) { - if (!isCUDATarget(target) && - target->getTarget() != CodeGenTarget::PyTorchCppBinding) - { - allTargetsCUDARelated = false; - break; - } + if (funcDecl->findModifier()) + profile.setStage(Stage::Mesh); + else + profile.setStage(Stage::Compute); + canDetermineStage = true; + break; } - - if (allTargetsCUDARelated && targets.getCount() > 0) - continue; - - bool canDetermineStage = false; - for (auto modifier : funcDecl->modifiers) + else if (as(modifier)) { - if (as(modifier)) - { - if (funcDecl->findModifier()) - profile.setStage(Stage::Mesh); - else - profile.setStage(Stage::Compute); - canDetermineStage = true; - break; - } - else if (as(modifier)) - { - profile.setStage(Stage::Hull); - canDetermineStage = true; - break; - } + profile.setStage(Stage::Hull); + canDetermineStage = true; + break; } - if (!canDetermineStage) - continue; } + if (!canDetermineStage) + continue; + } - RefPtr entryPoint = EntryPoint::create( - getLinkage(), - makeDeclRef(funcDecl), - profile); + RefPtr entryPoint = + EntryPoint::create(getLinkage(), makeDeclRef(funcDecl), profile); - validateEntryPoint(entryPoint, sink); + validateEntryPoint(entryPoint, sink); - // Note: in the case that the user didn't explicitly - // specify entry points and we are instead compiling - // a shader "library," then we do not want to automatically - // combine the entry points into groups in the generated - // `Program`, since that would be slightly too magical. - // - // Instead, each entry point will end up in a singleton - // group, so that its entry-point parameters lay out - // independent of the others. - // - _addEntryPoint(entryPoint); - } + // Note: in the case that the user didn't explicitly + // specify entry points and we are instead compiling + // a shader "library," then we do not want to automatically + // combine the entry points into groups in the generated + // `Program`, since that would be slightly too magical. + // + // Instead, each entry point will end up in a singleton + // group, so that its entry-point parameters lay out + // independent of the others. + // + _addEntryPoint(entryPoint); } } - +} // namespace Slang diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 67c931ac8..624e6ec3b 100644 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -1,3448 +1,3613 @@ #ifndef SLANG_COMPILER_H_INCLUDED #define SLANG_COMPILER_H_INCLUDED -#include "../core/slang-basic.h" -#include "../core/slang-shared-library.h" -#include "../core/slang-crypto.h" - -#include "../compiler-core/slang-downstream-compiler.h" +#include "../compiler-core/slang-artifact-representation-impl.h" +#include "../compiler-core/slang-command-line-args.h" #include "../compiler-core/slang-downstream-compiler-util.h" - -#include "../compiler-core/slang-name.h" +#include "../compiler-core/slang-downstream-compiler.h" #include "../compiler-core/slang-include-system.h" -#include "../compiler-core/slang-command-line-args.h" - +#include "../compiler-core/slang-name.h" #include "../compiler-core/slang-source-embed-util.h" - #include "../compiler-core/slang-spirv-core-grammar.h" - -#include "../core/slang-std-writers.h" +#include "../core/slang-basic.h" #include "../core/slang-command-options.h" - +#include "../core/slang-crypto.h" #include "../core/slang-file-system.h" - -#include "slang-com-ptr.h" - +#include "../core/slang-shared-library.h" +#include "../core/slang-std-writers.h" #include "slang-capability.h" +#include "slang-com-ptr.h" +#include "slang-compiler-options.h" +#include "slang-content-assist-info.h" #include "slang-diagnostics.h" +#include "slang-hlsl-to-vulkan-layout-options.h" #include "slang-preprocessor.h" #include "slang-profile.h" -#include "slang-syntax.h" -#include "slang-content-assist-info.h" -#include "slang-hlsl-to-vulkan-layout-options.h" -#include "slang-compiler-options.h" #include "slang-serialize-ir-types.h" - -#include "../compiler-core/slang-artifact-representation-impl.h" - +#include "slang-syntax.h" #include "slang.h" namespace Slang { - struct PathInfo; - struct IncludeHandler; - struct SharedSemanticsContext; - - class ProgramLayout; - class PtrType; - class TargetProgram; - class TargetRequest; - class TypeLayout; - class Artifact; - - enum class CompilerMode - { - ProduceLibrary, - ProduceShader, - GenerateChoice - }; +struct PathInfo; +struct IncludeHandler; +struct SharedSemanticsContext; + +class ProgramLayout; +class PtrType; +class TargetProgram; +class TargetRequest; +class TypeLayout; +class Artifact; + +enum class CompilerMode +{ + ProduceLibrary, + ProduceShader, + GenerateChoice +}; + +enum class StageTarget +{ + Unknown, + VertexShader, + HullShader, + DomainShader, + GeometryShader, + FragmentShader, + ComputeShader, +}; + +enum class CodeGenTarget : SlangCompileTargetIntegral +{ + Unknown = SLANG_TARGET_UNKNOWN, + None = SLANG_TARGET_NONE, + GLSL = SLANG_GLSL, + HLSL = SLANG_HLSL, + SPIRV = SLANG_SPIRV, + SPIRVAssembly = SLANG_SPIRV_ASM, + DXBytecode = SLANG_DXBC, + DXBytecodeAssembly = SLANG_DXBC_ASM, + DXIL = SLANG_DXIL, + DXILAssembly = SLANG_DXIL_ASM, + CSource = SLANG_C_SOURCE, + CPPSource = SLANG_CPP_SOURCE, + PyTorchCppBinding = SLANG_CPP_PYTORCH_BINDING, + HostCPPSource = SLANG_HOST_CPP_SOURCE, + HostExecutable = SLANG_HOST_EXECUTABLE, + HostSharedLibrary = SLANG_HOST_SHARED_LIBRARY, + ShaderSharedLibrary = SLANG_SHADER_SHARED_LIBRARY, + ShaderHostCallable = SLANG_SHADER_HOST_CALLABLE, + CUDASource = SLANG_CUDA_SOURCE, + PTX = SLANG_PTX, + CUDAObjectCode = SLANG_CUDA_OBJECT_CODE, + ObjectCode = SLANG_OBJECT_CODE, + HostHostCallable = SLANG_HOST_HOST_CALLABLE, + Metal = SLANG_METAL, + MetalLib = SLANG_METAL_LIB, + MetalLibAssembly = SLANG_METAL_LIB_ASM, + WGSL = SLANG_WGSL, + WGSLSPIRVAssembly = SLANG_WGSL_SPIRV_ASM, + WGSLSPIRV = SLANG_WGSL_SPIRV, + CountOf = SLANG_TARGET_COUNT_OF, +}; + +bool isHeterogeneousTarget(CodeGenTarget target); + +void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val); + +enum class ContainerFormat : SlangContainerFormatIntegral +{ + None = SLANG_CONTAINER_FORMAT_NONE, + SlangModule = SLANG_CONTAINER_FORMAT_SLANG_MODULE, +}; + +enum class LineDirectiveMode : SlangLineDirectiveModeIntegral +{ + Default = SLANG_LINE_DIRECTIVE_MODE_DEFAULT, + None = SLANG_LINE_DIRECTIVE_MODE_NONE, + Standard = SLANG_LINE_DIRECTIVE_MODE_STANDARD, + GLSL = SLANG_LINE_DIRECTIVE_MODE_GLSL, + SourceMap = SLANG_LINE_DIRECTIVE_MODE_SOURCE_MAP, +}; + +enum class ResultFormat +{ + None, + Text, + Binary, +}; + +// When storing the layout for a matrix-type +// value, we need to know whether it has been +// laid out with row-major or column-major +// storage. +// +enum MatrixLayoutMode : SlangMatrixLayoutModeIntegral +{ + kMatrixLayoutMode_RowMajor = SLANG_MATRIX_LAYOUT_ROW_MAJOR, + kMatrixLayoutMode_ColumnMajor = SLANG_MATRIX_LAYOUT_COLUMN_MAJOR, +}; + +enum class DebugInfoLevel : SlangDebugInfoLevelIntegral +{ + None = SLANG_DEBUG_INFO_LEVEL_NONE, + Minimal = SLANG_DEBUG_INFO_LEVEL_MINIMAL, + Standard = SLANG_DEBUG_INFO_LEVEL_STANDARD, + Maximal = SLANG_DEBUG_INFO_LEVEL_MAXIMAL, +}; + +enum class DebugInfoFormat : SlangDebugInfoFormatIntegral +{ + Default = SLANG_DEBUG_INFO_FORMAT_DEFAULT, + C7 = SLANG_DEBUG_INFO_FORMAT_C7, + Pdb = SLANG_DEBUG_INFO_FORMAT_PDB, + + Stabs = SLANG_DEBUG_INFO_FORMAT_STABS, + Coff = SLANG_DEBUG_INFO_FORMAT_COFF, + Dwarf = SLANG_DEBUG_INFO_FORMAT_DWARF, + + CountOf = SLANG_DEBUG_INFO_FORMAT_COUNT_OF, +}; + +enum class OptimizationLevel : SlangOptimizationLevelIntegral +{ + None = SLANG_OPTIMIZATION_LEVEL_NONE, + Default = SLANG_OPTIMIZATION_LEVEL_DEFAULT, + High = SLANG_OPTIMIZATION_LEVEL_HIGH, + Maximal = SLANG_OPTIMIZATION_LEVEL_MAXIMAL, +}; + +struct CodeGenContext; +class EndToEndCompileRequest; +class FrontEndCompileRequest; +class Linkage; +class Module; +class TranslationUnitRequest; + +/// Information collected about global or entry-point shader parameters +struct ShaderParamInfo +{ + DeclRef paramDeclRef; + Int firstSpecializationParamIndex = 0; + Int specializationParamCount = 0; +}; + +/// A request for the front-end to find and validate an entry-point function +struct FrontEndEntryPointRequest : RefObject +{ +public: + /// Create a request for an entry point. + FrontEndEntryPointRequest( + FrontEndCompileRequest* compileRequest, + int translationUnitIndex, + Name* name, + Profile profile); + + /// Get the parent front-end compile request. + FrontEndCompileRequest* getCompileRequest() { return m_compileRequest; } + + /// Get the translation unit that contains the entry point. + TranslationUnitRequest* getTranslationUnit(); + + /// Get the name of the entry point to find. + Name* getName() { return m_name; } + + /// Get the stage that the entry point is to be compiled for + Stage getStage() { return m_profile.getStage(); } + + /// Get the profile that the entry point is to be compiled for + Profile getProfile() { return m_profile; } + + /// Get the index to the translation unit + int getTranslationUnitIndex() const { return m_translationUnitIndex; } + +private: + // The parent compile request + FrontEndCompileRequest* m_compileRequest; + + // The index of the translation unit that will hold the entry point + int m_translationUnitIndex; + + // The name of the entry point function to look for + Name* m_name; + + // The profile to compile for (including stage) + Profile m_profile; +}; + +/// Tracks an ordered list of modules that something depends on. +/// TODO: Shader caching currently relies on this being in well defined order. +struct ModuleDependencyList +{ +public: + /// Get the list of modules that are depended on. + List const& getModuleList() { return m_moduleList; } + + /// Add a module and everything it depends on to the list. + void addDependency(Module* module); + + /// Add a module to the list, but not the modules it depends on. + void addLeafDependency(Module* module); + +private: + void _addDependency(Module* module); + + List m_moduleList; + HashSet m_moduleSet; +}; + +/// Tracks an unordered list of source files that something depends on +/// TODO: Shader caching currently relies on this being in well defined order. +struct FileDependencyList +{ +public: + /// Get the list of files that are depended on. + List const& getFileList() { return m_fileList; } + + /// Add a file to the list, if it is not already present + void addDependency(SourceFile* sourceFile); - enum class StageTarget + /// Add all of the paths that `module` depends on to the list + void addDependency(Module* module); + + void clear() { - Unknown, - VertexShader, - HullShader, - DomainShader, - GeometryShader, - FragmentShader, - ComputeShader, - }; + m_fileList.clear(); + m_fileSet.clear(); + } - enum class CodeGenTarget : SlangCompileTargetIntegral - { - Unknown = SLANG_TARGET_UNKNOWN, - None = SLANG_TARGET_NONE, - GLSL = SLANG_GLSL, - HLSL = SLANG_HLSL, - SPIRV = SLANG_SPIRV, - SPIRVAssembly = SLANG_SPIRV_ASM, - DXBytecode = SLANG_DXBC, - DXBytecodeAssembly = SLANG_DXBC_ASM, - DXIL = SLANG_DXIL, - DXILAssembly = SLANG_DXIL_ASM, - CSource = SLANG_C_SOURCE, - CPPSource = SLANG_CPP_SOURCE, - PyTorchCppBinding = SLANG_CPP_PYTORCH_BINDING, - HostCPPSource = SLANG_HOST_CPP_SOURCE, - HostExecutable = SLANG_HOST_EXECUTABLE, - HostSharedLibrary = SLANG_HOST_SHARED_LIBRARY, - ShaderSharedLibrary = SLANG_SHADER_SHARED_LIBRARY, - ShaderHostCallable = SLANG_SHADER_HOST_CALLABLE, - CUDASource = SLANG_CUDA_SOURCE, - PTX = SLANG_PTX, - CUDAObjectCode = SLANG_CUDA_OBJECT_CODE, - ObjectCode = SLANG_OBJECT_CODE, - HostHostCallable = SLANG_HOST_HOST_CALLABLE, - Metal = SLANG_METAL, - MetalLib = SLANG_METAL_LIB, - MetalLibAssembly = SLANG_METAL_LIB_ASM, - WGSL = SLANG_WGSL, - WGSLSPIRVAssembly = SLANG_WGSL_SPIRV_ASM, - WGSLSPIRV = SLANG_WGSL_SPIRV, - CountOf = SLANG_TARGET_COUNT_OF, - }; +private: + // TODO: We are using a `HashSet` here to deduplicate + // the paths so that we don't return the same path + // multiple times from `getFilePathList`, but because + // order isn't important, we could potentially do better + // in terms of memory (at some cost in performance) by + // just sorting the `m_fileList` every once in + // a while and then deduplicating. + + List m_fileList; + HashSet m_fileSet; +}; + + +class EntryPoint; + +class ComponentType; +class ComponentTypeVisitor; + +/// Base class for "component types" that represent the pieces a final +/// shader program gets linked together from. +/// +class ComponentType : public RefObject, + public slang::IComponentType, + public slang::IModulePrecompileService_Experimental +{ +public: + // + // ISlangUnknown interface + // + + SLANG_REF_OBJECT_IUNKNOWN_ALL; + ISlangUnknown* getInterface(Guid const& guid); + + // + // slang::IComponentType interface + // + + SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE; + SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL + getLayout(SlangInt targetIndex, slang::IBlob** outDiagnostics) SLANG_OVERRIDE; + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE; + + IArtifact* getTargetArtifact(SlangInt targetIndex, slang::IBlob** outDiagnostics); + + SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCode( + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics = nullptr) SLANG_OVERRIDE; + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE; + SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata( + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics = nullptr) SLANG_OVERRIDE; + + SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem( + SlangInt entryPointIndex, + SlangInt targetIndex, + ISlangMutableFileSystem** outFileSystem) SLANG_OVERRIDE; + + SLANG_NO_THROW SlangResult SLANG_MCALL specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + slang::IComponentType** outSpecializedComponentType, + ISlangBlob** outDiagnostics) SLANG_OVERRIDE; + SLANG_NO_THROW SlangResult SLANG_MCALL + renameEntryPoint(const char* newName, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE; + SLANG_NO_THROW SlangResult SLANG_MCALL link( + slang::IComponentType** outLinkedComponentType, + ISlangBlob** outDiagnostics) SLANG_OVERRIDE; + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( + int entryPointIndex, + int targetIndex, + ISlangSharedLibrary** outSharedLibrary, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE; + + /// ComponentType is the only class inheriting from IComponentType that provides a + /// meaningful implementation for this function. All others should forward these and + /// implement `buildHash`. + SLANG_NO_THROW void SLANG_MCALL getEntryPointHash( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outHash) SLANG_OVERRIDE; + + SLANG_NO_THROW SlangResult SLANG_MCALL linkWithOptions( + slang::IComponentType** outLinkedComponentType, + uint32_t count, + slang::CompilerOptionEntry* entries, + ISlangBlob** outDiagnostics) override; + + + // + // slang::IModulePrecompileService interface + // + SLANG_NO_THROW SlangResult SLANG_MCALL + precompileForTarget(SlangCompileTarget target, slang::IBlob** outDiagnostics) SLANG_OVERRIDE; + + SLANG_NO_THROW SlangResult SLANG_MCALL getPrecompiledTargetCode( + SlangCompileTarget target, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics = nullptr) SLANG_OVERRIDE; + + SLANG_NO_THROW SlangInt SLANG_MCALL getModuleDependencyCount() SLANG_OVERRIDE; + + SLANG_NO_THROW SlangResult SLANG_MCALL getModuleDependency( + SlangInt dependencyIndex, + slang::IModule** outModule, + slang::IBlob** outDiagnostics = nullptr) SLANG_OVERRIDE; + + CompilerOptionSet& getOptionSet() { return m_optionSet; } + + /// Get the linkage (aka "session" in the public API) for this component type. + Linkage* getLinkage() { return m_linkage; } + + /// Get the target-specific version of this program for the given `target`. + /// + /// The `target` must be a target on the `Linkage` that was used to create this program. + TargetProgram* getTargetProgram(TargetRequest* target); + + /// Update the hash builder with the dependencies for this component type. + virtual void buildHash(DigestBuilder& builder) = 0; + + /// Get the number of entry points linked into this component type. + virtual Index getEntryPointCount() = 0; + + /// Get one of the entry points linked into this component type. + virtual RefPtr getEntryPoint(Index index) = 0; + + /// Get the mangled name of one of the entry points linked into this component type. + virtual String getEntryPointMangledName(Index index) = 0; + + /// Get the name override of one of the entry points linked into this component type. + virtual String getEntryPointNameOverride(Index index) = 0; + + /// Get the number of global shader parameters linked into this component type. + virtual Index getShaderParamCount() = 0; - bool isHeterogeneousTarget(CodeGenTarget target); + /// Get one of the global shader parametesr linked into this component type. + virtual ShaderParamInfo getShaderParam(Index index) = 0; - void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val); + /// Get the specialization parameter at `index`. + virtual SpecializationParam const& getSpecializationParam(Index index) = 0; - enum class ContainerFormat : SlangContainerFormatIntegral + /// Get the number of "requirements" that this component type has. + /// + /// A requirement represents another component type that this component + /// needs in order to function correctly. For example, the dependency + /// of one module on another module that it `import`s is represented + /// as a requirement, as is the dependency of an entry point on the + /// module that defines it. + /// + virtual Index getRequirementCount() = 0; + + /// Get the requirement at `index`. + virtual RefPtr getRequirement(Index index) = 0; + + /// Parse a type from a string, in the context of this component type. + /// + /// Any names in the string will be resolved using the modules + /// referenced by the program. + /// + /// On an error, returns null and reports diagnostic messages + /// to the provided `sink`. + /// + /// TODO: This function shouldn't be on the base class, since + /// it only really makes sense on `Module`. + /// + Type* getTypeFromString(String const& typeStr, DiagnosticSink* sink); + + Expr* findDeclFromString(String const& name, DiagnosticSink* sink); + + Expr* findDeclFromStringInType( + Type* type, + String const& name, + LookupMask mask, + DiagnosticSink* sink); + + bool isSubType(Type* subType, Type* superType); + + Dictionary& getMangledNameToIntValMap(); + ConstantIntVal* tryFoldIntVal(IntVal* intVal); + + /// Get a list of modules that this component type depends on. + /// + virtual List const& getModuleDependencies() = 0; + + /// Get the full list of source files this component type depends on. + /// + virtual List const& getFileDependencies() = 0; + + /// Callback for use with `enumerateIRModules` + typedef void (*EnumerateIRModulesCallback)(IRModule* irModule, void* userData); + + /// Invoke `callback` on all the IR modules that are (transitively) linked into this component + /// type. + void enumerateIRModules(EnumerateIRModulesCallback callback, void* userData); + + /// Invoke `callback` on all the IR modules that are (transitively) linked into this component + /// type. + template + void enumerateIRModules(F const& callback) { - None = SLANG_CONTAINER_FORMAT_NONE, - SlangModule = SLANG_CONTAINER_FORMAT_SLANG_MODULE, - }; + struct Helper + { + static void helper(IRModule* irModule, void* userData) { (*(F*)userData)(irModule); } + }; + enumerateIRModules(&Helper::helper, (void*)&callback); + } + + /// Callback for use with `enumerateModules` + typedef void (*EnumerateModulesCallback)(Module* module, void* userData); + + /// Invoke `callback` on all the modules that are (transitively) linked into this component + /// type. + void enumerateModules(EnumerateModulesCallback callback, void* userData); - enum class LineDirectiveMode : SlangLineDirectiveModeIntegral + /// Invoke `callback` on all the modules that are (transitively) linked into this component + /// type. + template + void enumerateModules(F const& callback) { - Default = SLANG_LINE_DIRECTIVE_MODE_DEFAULT, - None = SLANG_LINE_DIRECTIVE_MODE_NONE, - Standard = SLANG_LINE_DIRECTIVE_MODE_STANDARD, - GLSL = SLANG_LINE_DIRECTIVE_MODE_GLSL, - SourceMap = SLANG_LINE_DIRECTIVE_MODE_SOURCE_MAP, - }; + struct Helper + { + static void helper(Module* module, void* userData) { (*(F*)userData)(module); } + }; + enumerateModules(&Helper::helper, (void*)&callback); + } - enum class ResultFormat + /// Side-band information generated when specializing this component type. + /// + /// Difference subclasses of `ComponentType` are expected to create their + /// own subclass of `SpecializationInfo` as the output of `_validateSpecializationArgs`. + /// Later, whenever we want to use a specialized component type we will + /// also have the `SpecializationInfo` available and will expect it to + /// have the correct (subclass-specific) type. + /// + class SpecializationInfo : public RefObject { - None, - Text, - Binary, }; - // When storing the layout for a matrix-type - // value, we need to know whether it has been - // laid out with row-major or column-major - // storage. - // - enum MatrixLayoutMode : SlangMatrixLayoutModeIntegral + /// Validate the given specialization `args` and compute any side-band specialization info. + /// + /// Any errors will be reported to `sink`, which can thus be used to test + /// if the operation was successful. + /// + /// A null return value is allowed, since not all subclasses require + /// custom side-band specialization information. + /// + /// This function is an implementation detail of `specialize()`. + /// + virtual RefPtr _validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) = 0; + + /// Validate the given specialization `args` and compute any side-band specialization info. + /// + /// Any errors will be reported to `sink`, which can thus be used to test + /// if the operation was successful. + /// + /// A null return value is allowed, since not all subclasses require + /// custom side-band specialization information. + /// + /// This function is an implementation detail of `specialize()`. + /// + RefPtr _validateSpecializationArgs( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) { - kMatrixLayoutMode_RowMajor = SLANG_MATRIX_LAYOUT_ROW_MAJOR, - kMatrixLayoutMode_ColumnMajor = SLANG_MATRIX_LAYOUT_COLUMN_MAJOR, - }; + if (argCount == 0) + return nullptr; + return _validateSpecializationArgsImpl(args, argCount, sink); + } + + /// Specialize this component type given `specializationArgs` + /// + /// Any diagnostics will be reported to `sink`, which can be used + /// to determine if the operation was successful. It is allowed + /// for this operation to have a non-null return even when an + /// error is ecnountered. + /// + RefPtr specialize( + SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + DiagnosticSink* sink); + + /// Invoke `visitor` on this component type, using the appropriate dynamic type. + /// + /// This function implements the "visitor pattern" for `ComponentType`. + /// + /// If the `specializationInfo` argument is non-null, it must be specialization + /// information generated for this specific component type by `_validateSpecializationArgs`. + /// In that case, appropriately-typed specialization information will be passed + /// when invoking the `visitor`. + /// + virtual void acceptVisitor( + ComponentTypeVisitor* visitor, + SpecializationInfo* specializationInfo) = 0; + + /// Create a scope suitable for looking up names or parsing specialization arguments. + /// + /// This facility is only needed to support legacy APIs for string-based lookup + /// and parsing via Slang reflection, and is not recommended for future APIs to use. + /// + Scope* _getOrCreateScopeForLegacyLookup(ASTBuilder* astBuilder); + +protected: + ComponentType(Linkage* linkage); + +protected: + Linkage* m_linkage; + + CompilerOptionSet m_optionSet; + + // Cache of target-specific programs for each target. + Dictionary> m_targetPrograms; + + // Any types looked up dynamically using `getTypeFromString` + // + // TODO: Remove this. Type lookup should only be supported on `Module`s. + // + Dictionary m_types; + + // Any decls looked up dynamically using `findDeclFromString`. + Dictionary m_decls; + + Scope* m_lookupScope = nullptr; + std::unique_ptr> m_mapMangledNameToIntVal; + + Dictionary> m_targetArtifacts; +}; + +/// A component type built up from other component types. +class CompositeComponentType : public ComponentType +{ +public: + static RefPtr create( + Linkage* linkage, + List> const& childComponents); + + virtual void buildHash(DigestBuilder& builder) SLANG_OVERRIDE; + + List> const& getChildComponents() { return m_childComponents; }; + Index getChildComponentCount() { return m_childComponents.getCount(); } + RefPtr getChildComponent(Index index) { return m_childComponents[index]; } + + Index getEntryPointCount() SLANG_OVERRIDE; + RefPtr getEntryPoint(Index index) SLANG_OVERRIDE; + String getEntryPointMangledName(Index index) SLANG_OVERRIDE; + String getEntryPointNameOverride(Index index) SLANG_OVERRIDE; + + Index getShaderParamCount() SLANG_OVERRIDE; + ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE; - enum class DebugInfoLevel : SlangDebugInfoLevelIntegral + SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE; + SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE; + + Index getRequirementCount() SLANG_OVERRIDE; + RefPtr getRequirement(Index index) SLANG_OVERRIDE; + + List const& getModuleDependencies() SLANG_OVERRIDE; + List const& getFileDependencies() SLANG_OVERRIDE; + + class CompositeSpecializationInfo : public SpecializationInfo { - None = SLANG_DEBUG_INFO_LEVEL_NONE, - Minimal = SLANG_DEBUG_INFO_LEVEL_MINIMAL, - Standard = SLANG_DEBUG_INFO_LEVEL_STANDARD, - Maximal = SLANG_DEBUG_INFO_LEVEL_MAXIMAL, + public: + List> childInfos; }; - enum class DebugInfoFormat : SlangDebugInfoFormatIntegral - { - Default = SLANG_DEBUG_INFO_FORMAT_DEFAULT, - C7 = SLANG_DEBUG_INFO_FORMAT_C7, - Pdb = SLANG_DEBUG_INFO_FORMAT_PDB, +protected: + void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) + SLANG_OVERRIDE; - Stabs = SLANG_DEBUG_INFO_FORMAT_STABS, - Coff = SLANG_DEBUG_INFO_FORMAT_COFF, - Dwarf = SLANG_DEBUG_INFO_FORMAT_DWARF, - CountOf = SLANG_DEBUG_INFO_FORMAT_COUNT_OF, - }; + RefPtr _validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) SLANG_OVERRIDE; + +public: + CompositeComponentType(Linkage* linkage, List> const& childComponents); + +private: + List> m_childComponents; + + // The following arrays hold the concatenated entry points, parameters, + // etc. from the child components. This approach allows for reasonably + // fast (constant time) access through operations like `getShaderParam`, + // but means that the memory usage of a composite is proportional to + // the sum of the memory usage of the children, rather than being fixed + // by the number of children (as it would be if we just stored + // `m_childComponents`). + // + // TODO: We could conceivably build some O(numChildren) arrays that + // support binary-search to provide logarithmic-time access to entry + // points, parameters, etc. while giving a better overall memory usage. + // + List m_entryPoints; + List m_entryPointMangledNames; + List m_entryPointNameOverrides; + List m_shaderParams; + List m_specializationParams; + List m_requirements; + + ModuleDependencyList m_moduleDependencyList; + FileDependencyList m_fileDependencyList; +}; + +/// A component type created by specializing another component type. +class SpecializedComponentType : public ComponentType +{ +public: + SpecializedComponentType( + ComponentType* base, + SpecializationInfo* specializationInfo, + List const& specializationArgs, + DiagnosticSink* sink); + + virtual void buildHash(DigestBuilder& builer) SLANG_OVERRIDE; + + /// Get the base (unspecialized) component type that is being specialized. + RefPtr getBaseComponentType() { return m_base; } - enum class OptimizationLevel : SlangOptimizationLevelIntegral + RefPtr getSpecializationInfo() { return m_specializationInfo; } + + /// Get the number of arguments supplied for existential type parameters. + /// + /// Note that the number of arguments may not match the number of parameters. + /// In particular, an unspecialized entry point may have many parameters, but zero arguments. + Index getSpecializationArgCount() { return m_specializationArgs.getCount(); } + + /// Get the existential type argument (type and witness table) at `index`. + SpecializationArg const& getSpecializationArg(Index index) { - None = SLANG_OPTIMIZATION_LEVEL_NONE, - Default = SLANG_OPTIMIZATION_LEVEL_DEFAULT, - High = SLANG_OPTIMIZATION_LEVEL_HIGH, - Maximal = SLANG_OPTIMIZATION_LEVEL_MAXIMAL, - }; + return m_specializationArgs[index]; + } - struct CodeGenContext; - class EndToEndCompileRequest; - class FrontEndCompileRequest; - class Linkage; - class Module; - class TranslationUnitRequest; + /// Get an array of all existential type arguments. + SpecializationArg const* getSpecializationArgs() { return m_specializationArgs.getBuffer(); } - /// Information collected about global or entry-point shader parameters - struct ShaderParamInfo + Index getEntryPointCount() SLANG_OVERRIDE { return m_base->getEntryPointCount(); } + RefPtr getEntryPoint(Index index) SLANG_OVERRIDE { - DeclRef paramDeclRef; - Int firstSpecializationParamIndex = 0; - Int specializationParamCount = 0; - }; + return m_base->getEntryPoint(index); + } + String getEntryPointMangledName(Index index) SLANG_OVERRIDE; + String getEntryPointNameOverride(Index index) SLANG_OVERRIDE; - /// A request for the front-end to find and validate an entry-point function - struct FrontEndEntryPointRequest : RefObject + Index getShaderParamCount() SLANG_OVERRIDE { return m_base->getShaderParamCount(); } + ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { - public: - /// Create a request for an entry point. - FrontEndEntryPointRequest( - FrontEndCompileRequest* compileRequest, - int translationUnitIndex, - Name* name, - Profile profile); + return m_base->getShaderParam(index); + } - /// Get the parent front-end compile request. - FrontEndCompileRequest* getCompileRequest() { return m_compileRequest; } + SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE { return 0; } + SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE + { + SLANG_UNUSED(index); + static SpecializationParam dummy; + return dummy; + } - /// Get the translation unit that contains the entry point. - TranslationUnitRequest* getTranslationUnit(); + Index getRequirementCount() SLANG_OVERRIDE; + RefPtr getRequirement(Index index) SLANG_OVERRIDE; - /// Get the name of the entry point to find. - Name* getName() { return m_name; } + List const& getModuleDependencies() SLANG_OVERRIDE { return m_moduleDependencies; } + List const& getFileDependencies() SLANG_OVERRIDE { return m_fileDependencies; } - /// Get the stage that the entry point is to be compiled for - Stage getStage() - { - return m_profile.getStage(); - } + RefPtr getIRModule() { return m_irModule; } - /// Get the profile that the entry point is to be compiled for - Profile getProfile() - { - return m_profile; - } + void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) + SLANG_OVERRIDE; - /// Get the index to the translation unit - int getTranslationUnitIndex() const { return m_translationUnitIndex; } +protected: + RefPtr _validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) SLANG_OVERRIDE + { + SLANG_UNUSED(args); + SLANG_UNUSED(argCount); + SLANG_UNUSED(sink); + return nullptr; + } - private: - // The parent compile request - FrontEndCompileRequest* m_compileRequest; +private: + RefPtr m_base; + RefPtr m_specializationInfo; + SpecializationArgs m_specializationArgs; + RefPtr m_irModule; - // The index of the translation unit that will hold the entry point - int m_translationUnitIndex; + List m_entryPointMangledNames; + List m_entryPointNameOverrides; - // The name of the entry point function to look for - Name* m_name; + List m_moduleDependencies; + List m_fileDependencies; + List> m_requirements; +}; - // The profile to compile for (including stage) - Profile m_profile; - }; +class RenamedEntryPointComponentType : public ComponentType +{ +public: + using Super = ComponentType; + + RenamedEntryPointComponentType(ComponentType* base, String newName); + + ComponentType* getBase() { return m_base.Ptr(); } + + // Forward `IComponentType` methods - /// Tracks an ordered list of modules that something depends on. - /// TODO: Shader caching currently relies on this being in well defined order. - struct ModuleDependencyList + SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE { - public: - /// Get the list of modules that are depended on. - List const& getModuleList() { return m_moduleList; } + return Super::getSession(); + } - /// Add a module and everything it depends on to the list. - void addDependency(Module* module); + SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL + getLayout(SlangInt targetIndex, slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getLayout(targetIndex, outDiagnostics); + } - /// Add a module to the list, but not the modules it depends on. - void addLeafDependency(Module* module); + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); + } - private: - void _addDependency(Module* module); + SLANG_NO_THROW SlangResult SLANG_MCALL specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + slang::IComponentType** outSpecializedComponentType, + ISlangBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::specialize( + specializationArgs, + specializationArgCount, + outSpecializedComponentType, + outDiagnostics); + } - List m_moduleList; - HashSet m_moduleSet; - }; + SLANG_NO_THROW SlangResult SLANG_MCALL + renameEntryPoint(const char* newName, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE + { + return Super::renameEntryPoint(newName, outEntryPoint); + } - /// Tracks an unordered list of source files that something depends on - /// TODO: Shader caching currently relies on this being in well defined order. - struct FileDependencyList + SLANG_NO_THROW SlangResult SLANG_MCALL + link(slang::IComponentType** outLinkedComponentType, ISlangBlob** outDiagnostics) SLANG_OVERRIDE { - public: - /// Get the list of files that are depended on. - List const& getFileList() { return m_fileList; } + return Super::link(outLinkedComponentType, outDiagnostics); + } - /// Add a file to the list, if it is not already present - void addDependency(SourceFile* sourceFile); + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( + int entryPointIndex, + int targetIndex, + ISlangSharedLibrary** outSharedLibrary, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointHostCallable( + entryPointIndex, + targetIndex, + outSharedLibrary, + outDiagnostics); + } - /// Add all of the paths that `module` depends on to the list - void addDependency(Module* module); + List const& getModuleDependencies() SLANG_OVERRIDE + { + return m_base->getModuleDependencies(); + } + List const& getFileDependencies() SLANG_OVERRIDE + { + return m_base->getFileDependencies(); + } - void clear() - { - m_fileList.clear(); - m_fileSet.clear(); - } + SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE + { + return m_base->getSpecializationParamCount(); + } - private: + SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE + { + return m_base->getSpecializationParam(index); + } - // TODO: We are using a `HashSet` here to deduplicate - // the paths so that we don't return the same path - // multiple times from `getFilePathList`, but because - // order isn't important, we could potentially do better - // in terms of memory (at some cost in performance) by - // just sorting the `m_fileList` every once in - // a while and then deduplicating. + Index getRequirementCount() SLANG_OVERRIDE { return m_base->getRequirementCount(); } + RefPtr getRequirement(Index index) SLANG_OVERRIDE + { + return m_base->getRequirement(index); + } + Index getEntryPointCount() SLANG_OVERRIDE { return m_base->getEntryPointCount(); } + RefPtr getEntryPoint(Index index) SLANG_OVERRIDE + { + return m_base->getEntryPoint(index); + } + String getEntryPointMangledName(Index index) SLANG_OVERRIDE + { + return m_base->getEntryPointMangledName(index); + } + String getEntryPointNameOverride(Index index) SLANG_OVERRIDE + { + SLANG_UNUSED(index); + SLANG_ASSERT(index == 0); + return m_entryPointNameOverride; + } - List m_fileList; - HashSet m_fileSet; - }; + Index getShaderParamCount() SLANG_OVERRIDE { return m_base->getShaderParamCount(); } + ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE + { + return m_base->getShaderParam(index); + } + void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) + SLANG_OVERRIDE; - class EntryPoint; + virtual void buildHash(DigestBuilder& builder) SLANG_OVERRIDE; - class ComponentType; - class ComponentTypeVisitor; +private: + RefPtr m_base; + String m_entryPointNameOverride; - /// Base class for "component types" that represent the pieces a final - /// shader program gets linked together from. - /// - class ComponentType : public RefObject, public slang::IComponentType, public slang::IModulePrecompileService_Experimental +protected: + RefPtr _validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) SLANG_OVERRIDE { - public: - // - // ISlangUnknown interface - // - - SLANG_REF_OBJECT_IUNKNOWN_ALL; - ISlangUnknown* getInterface(Guid const& guid); - - // - // slang::IComponentType interface - // - - SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE; - SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL getLayout( - SlangInt targetIndex, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE; - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IBlob** outCode, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE; - - IArtifact* getTargetArtifact(SlangInt targetIndex, slang::IBlob** outDiagnostics); - - SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCode( - SlangInt targetIndex, - slang::IBlob** outCode, - slang::IBlob** outDiagnostics = nullptr) SLANG_OVERRIDE; - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IMetadata** outMetadata, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE; - SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata( - SlangInt targetIndex, - slang::IMetadata** outMetadata, - slang::IBlob** outDiagnostics = nullptr) SLANG_OVERRIDE; - - SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem( - SlangInt entryPointIndex, - SlangInt targetIndex, - ISlangMutableFileSystem** outFileSystem) SLANG_OVERRIDE; - - SLANG_NO_THROW SlangResult SLANG_MCALL specialize( - slang::SpecializationArg const* specializationArgs, - SlangInt specializationArgCount, - slang::IComponentType** outSpecializedComponentType, - ISlangBlob** outDiagnostics) SLANG_OVERRIDE; - SLANG_NO_THROW SlangResult SLANG_MCALL renameEntryPoint( - const char* newName, - slang::IComponentType** outEntryPoint) SLANG_OVERRIDE; - SLANG_NO_THROW SlangResult SLANG_MCALL link( - slang::IComponentType** outLinkedComponentType, - ISlangBlob** outDiagnostics) SLANG_OVERRIDE; - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( - int entryPointIndex, - int targetIndex, - ISlangSharedLibrary** outSharedLibrary, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE; - - /// ComponentType is the only class inheriting from IComponentType that provides a - /// meaningful implementation for this function. All others should forward these and - /// implement `buildHash`. - SLANG_NO_THROW void SLANG_MCALL getEntryPointHash( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IBlob** outHash) SLANG_OVERRIDE; - - SLANG_NO_THROW SlangResult SLANG_MCALL linkWithOptions( - slang::IComponentType** outLinkedComponentType, - uint32_t count, - slang::CompilerOptionEntry* entries, - ISlangBlob** outDiagnostics) override; - - - // - // slang::IModulePrecompileService interface - // - SLANG_NO_THROW SlangResult SLANG_MCALL precompileForTarget( - SlangCompileTarget target, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE; - - SLANG_NO_THROW SlangResult SLANG_MCALL getPrecompiledTargetCode( - SlangCompileTarget target, - slang::IBlob** outCode, - slang::IBlob** outDiagnostics = nullptr) SLANG_OVERRIDE; - - SLANG_NO_THROW SlangInt SLANG_MCALL getModuleDependencyCount() - SLANG_OVERRIDE; - - SLANG_NO_THROW SlangResult SLANG_MCALL getModuleDependency( - SlangInt dependencyIndex, - slang::IModule** outModule, - slang::IBlob** outDiagnostics = nullptr) SLANG_OVERRIDE; - - CompilerOptionSet& getOptionSet() { return m_optionSet; } - - /// Get the linkage (aka "session" in the public API) for this component type. - Linkage* getLinkage() { return m_linkage; } - - /// Get the target-specific version of this program for the given `target`. - /// - /// The `target` must be a target on the `Linkage` that was used to create this program. - TargetProgram* getTargetProgram(TargetRequest* target); - - /// Update the hash builder with the dependencies for this component type. - virtual void buildHash(DigestBuilder& builder) = 0; - - /// Get the number of entry points linked into this component type. - virtual Index getEntryPointCount() = 0; - - /// Get one of the entry points linked into this component type. - virtual RefPtr getEntryPoint(Index index) = 0; - - /// Get the mangled name of one of the entry points linked into this component type. - virtual String getEntryPointMangledName(Index index) = 0; - - /// Get the name override of one of the entry points linked into this component type. - virtual String getEntryPointNameOverride(Index index) = 0; - - /// Get the number of global shader parameters linked into this component type. - virtual Index getShaderParamCount() = 0; - - /// Get one of the global shader parametesr linked into this component type. - virtual ShaderParamInfo getShaderParam(Index index) = 0; - - /// Get the specialization parameter at `index`. - virtual SpecializationParam const& getSpecializationParam(Index index) = 0; - - /// Get the number of "requirements" that this component type has. - /// - /// A requirement represents another component type that this component - /// needs in order to function correctly. For example, the dependency - /// of one module on another module that it `import`s is represented - /// as a requirement, as is the dependency of an entry point on the - /// module that defines it. - /// - virtual Index getRequirementCount() = 0; - - /// Get the requirement at `index`. - virtual RefPtr getRequirement(Index index) = 0; - - /// Parse a type from a string, in the context of this component type. - /// - /// Any names in the string will be resolved using the modules - /// referenced by the program. - /// - /// On an error, returns null and reports diagnostic messages - /// to the provided `sink`. - /// - /// TODO: This function shouldn't be on the base class, since - /// it only really makes sense on `Module`. - /// - Type* getTypeFromString( - String const& typeStr, - DiagnosticSink* sink); - - Expr* findDeclFromString( - String const& name, - DiagnosticSink* sink); - - Expr* findDeclFromStringInType( - Type* type, - String const& name, - LookupMask mask, - DiagnosticSink* sink); - - bool isSubType(Type* subType, Type* superType); - - Dictionary& getMangledNameToIntValMap(); - ConstantIntVal* tryFoldIntVal(IntVal* intVal); - - /// Get a list of modules that this component type depends on. - /// - virtual List const& getModuleDependencies() = 0; - - /// Get the full list of source files this component type depends on. - /// - virtual List const& getFileDependencies() = 0; - - /// Callback for use with `enumerateIRModules` - typedef void (*EnumerateIRModulesCallback)(IRModule* irModule, void* userData); - - /// Invoke `callback` on all the IR modules that are (transitively) linked into this component type. - void enumerateIRModules(EnumerateIRModulesCallback callback, void* userData); - - /// Invoke `callback` on all the IR modules that are (transitively) linked into this component type. - template - void enumerateIRModules(F const& callback) - { - struct Helper - { - static void helper(IRModule* irModule, void* userData) - { - (*(F*)userData)(irModule); - } - }; - enumerateIRModules(&Helper::helper, (void*)&callback); - } + return m_base->_validateSpecializationArgsImpl(args, argCount, sink); + } +}; - /// Callback for use with `enumerateModules` - typedef void (*EnumerateModulesCallback)(Module* module, void* userData); +/// Describes an entry point for the purposes of layout and code generation. +/// +/// This class also tracks any generic arguments to the entry point, +/// in the case that it is a specialization of a generic entry point. +/// +/// There is also a provision for creating a "dummy" entry point for +/// the purposes of pass-through compilation modes. Only the +/// `getName()` and `getProfile()` methods should be expected to +/// return useful data on pass-through entry points. +/// +class EntryPoint : public ComponentType, public slang::IEntryPoint +{ + typedef ComponentType Super; - /// Invoke `callback` on all the modules that are (transitively) linked into this component type. - void enumerateModules(EnumerateModulesCallback callback, void* userData); +public: + SLANG_REF_OBJECT_IUNKNOWN_ALL - /// Invoke `callback` on all the modules that are (transitively) linked into this component type. - template - void enumerateModules(F const& callback) - { - struct Helper - { - static void helper(Module* module, void* userData) - { - (*(F*)userData)(module); - } - }; - enumerateModules(&Helper::helper, (void*)&callback); - } + ISlangUnknown* getInterface(const Guid& guid); - /// Side-band information generated when specializing this component type. - /// - /// Difference subclasses of `ComponentType` are expected to create their - /// own subclass of `SpecializationInfo` as the output of `_validateSpecializationArgs`. - /// Later, whenever we want to use a specialized component type we will - /// also have the `SpecializationInfo` available and will expect it to - /// have the correct (subclass-specific) type. - /// - class SpecializationInfo : public RefObject - { - }; - /// Validate the given specialization `args` and compute any side-band specialization info. - /// - /// Any errors will be reported to `sink`, which can thus be used to test - /// if the operation was successful. - /// - /// A null return value is allowed, since not all subclasses require - /// custom side-band specialization information. - /// - /// This function is an implementation detail of `specialize()`. - /// - virtual RefPtr _validateSpecializationArgsImpl( - SpecializationArg const* args, - Index argCount, - DiagnosticSink* sink) = 0; - - /// Validate the given specialization `args` and compute any side-band specialization info. - /// - /// Any errors will be reported to `sink`, which can thus be used to test - /// if the operation was successful. - /// - /// A null return value is allowed, since not all subclasses require - /// custom side-band specialization information. - /// - /// This function is an implementation detail of `specialize()`. - /// - RefPtr _validateSpecializationArgs( - SpecializationArg const* args, - Index argCount, - DiagnosticSink* sink) - { - if(argCount == 0) return nullptr; - return _validateSpecializationArgsImpl(args, argCount, sink); - } + // Forward `IComponentType` methods - /// Specialize this component type given `specializationArgs` - /// - /// Any diagnostics will be reported to `sink`, which can be used - /// to determine if the operation was successful. It is allowed - /// for this operation to have a non-null return even when an - /// error is ecnountered. - /// - RefPtr specialize( - SpecializationArg const* specializationArgs, - SlangInt specializationArgCount, - DiagnosticSink* sink); - - /// Invoke `visitor` on this component type, using the appropriate dynamic type. - /// - /// This function implements the "visitor pattern" for `ComponentType`. - /// - /// If the `specializationInfo` argument is non-null, it must be specialization - /// information generated for this specific component type by `_validateSpecializationArgs`. - /// In that case, appropriately-typed specialization information will be passed - /// when invoking the `visitor`. - /// - virtual void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) = 0; - - /// Create a scope suitable for looking up names or parsing specialization arguments. - /// - /// This facility is only needed to support legacy APIs for string-based lookup - /// and parsing via Slang reflection, and is not recommended for future APIs to use. - /// - Scope* _getOrCreateScopeForLegacyLookup(ASTBuilder* astBuilder); - protected: - ComponentType(Linkage* linkage); - - protected: - Linkage* m_linkage; - - CompilerOptionSet m_optionSet; - - // Cache of target-specific programs for each target. - Dictionary> m_targetPrograms; - - // Any types looked up dynamically using `getTypeFromString` - // - // TODO: Remove this. Type lookup should only be supported on `Module`s. - // - Dictionary m_types; - - // Any decls looked up dynamically using `findDeclFromString`. - Dictionary m_decls; - - Scope* m_lookupScope = nullptr; - std::unique_ptr> m_mapMangledNameToIntVal; - - Dictionary> m_targetArtifacts; - }; + SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE + { + return Super::getSession(); + } - /// A component type built up from other component types. - class CompositeComponentType : public ComponentType + SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL + getLayout(SlangInt targetIndex, slang::IBlob** outDiagnostics) SLANG_OVERRIDE { - public: - static RefPtr create( - Linkage* linkage, - List> const& childComponents); + return Super::getLayout(targetIndex, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); + } - virtual void buildHash(DigestBuilder& builder) SLANG_OVERRIDE; + SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCode( + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getTargetCode(targetIndex, outCode, outDiagnostics); + } - List> const& getChildComponents() { return m_childComponents; }; - Index getChildComponentCount() { return m_childComponents.getCount(); } - RefPtr getChildComponent(Index index) { return m_childComponents[index]; } + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointMetadata( + entryPointIndex, + targetIndex, + outMetadata, + outDiagnostics); + } - Index getEntryPointCount() SLANG_OVERRIDE; - RefPtr getEntryPoint(Index index) SLANG_OVERRIDE; - String getEntryPointMangledName(Index index) SLANG_OVERRIDE; - String getEntryPointNameOverride(Index index) SLANG_OVERRIDE; + SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata( + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getTargetMetadata(targetIndex, outMetadata, outDiagnostics); + } - Index getShaderParamCount() SLANG_OVERRIDE; - ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE; + SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem( + SlangInt entryPointIndex, + SlangInt targetIndex, + ISlangMutableFileSystem** outFileSystem) SLANG_OVERRIDE + { + return Super::getResultAsFileSystem(entryPointIndex, targetIndex, outFileSystem); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + slang::IComponentType** outSpecializedComponentType, + ISlangBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::specialize( + specializationArgs, + specializationArgCount, + outSpecializedComponentType, + outDiagnostics); + } - SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE; - SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE; + SLANG_NO_THROW SlangResult SLANG_MCALL + renameEntryPoint(const char* newName, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE + { + return Super::renameEntryPoint(newName, outEntryPoint); + } - Index getRequirementCount() SLANG_OVERRIDE; - RefPtr getRequirement(Index index) SLANG_OVERRIDE; + SLANG_NO_THROW SlangResult SLANG_MCALL + link(slang::IComponentType** outLinkedComponentType, ISlangBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::link(outLinkedComponentType, outDiagnostics); + } - List const& getModuleDependencies() SLANG_OVERRIDE; - List const& getFileDependencies() SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL linkWithOptions( + slang::IComponentType** outLinkedComponentType, + uint32_t count, + slang::CompilerOptionEntry* entries, + ISlangBlob** outDiagnostics) override + { + return Super::linkWithOptions(outLinkedComponentType, count, entries, outDiagnostics); + } - class CompositeSpecializationInfo : public SpecializationInfo - { - public: - List> childInfos; - }; + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( + int entryPointIndex, + int targetIndex, + ISlangSharedLibrary** outSharedLibrary, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointHostCallable( + entryPointIndex, + targetIndex, + outSharedLibrary, + outDiagnostics); + } - protected: - void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) SLANG_OVERRIDE; + SLANG_NO_THROW void SLANG_MCALL getEntryPointHash( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outHash) SLANG_OVERRIDE + { + return Super::getEntryPointHash(entryPointIndex, targetIndex, outHash); + } + + virtual void buildHash(DigestBuilder& builder) SLANG_OVERRIDE; + + /// Create an entry point that refers to the given function. + static RefPtr create( + Linkage* linkage, + DeclRef funcDeclRef, + Profile profile); + + /// Get the function decl-ref, including any generic arguments. + DeclRef getFuncDeclRef() { return m_funcDeclRef; } + + /// Get the function declaration (without generic arguments). + FuncDecl* getFuncDecl() { return m_funcDeclRef.getDecl(); } + + /// Get the name of the entry point + Name* getName() { return m_name; } + + /// Get the profile associated with the entry point + /// + /// Note: only the stage part of the profile is expected + /// to contain useful data, but certain legacy code paths + /// allow for "shader model" information to come via this path. + /// + Profile getProfile() { return m_profile; } + + /// Get the stage that the entry point is for. + Stage getStage() { return m_profile.getStage(); } + + /// Get the module that contains the entry point. + Module* getModule(); + + /// Get a list of modules that this entry point depends on. + /// + /// This will include the module that defines the entry point (see `getModule()`), + /// but may also include modules that are required by its generic type arguments. + /// + List const& getModuleDependencies() + SLANG_OVERRIDE; // { return getModule()->getModuleDependencies(); } + List const& getFileDependencies() + SLANG_OVERRIDE; // { return getModule()->getFileDependencies(); } + + /// Create a dummy `EntryPoint` that is only usable for pass-through compilation. + static RefPtr createDummyForPassThrough( + Linkage* linkage, + Name* name, + Profile profile); + + /// Create a dummy `EntryPoint` that stands in for a serialized entry point + static RefPtr createDummyForDeserialize( + Linkage* linkage, + Name* name, + Profile profile, + String mangledName); + + /// Get the number of existential type parameters for the entry point. + SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE; + /// Get the existential type parameter at `index`. + SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE; + + Index getRequirementCount() SLANG_OVERRIDE; + RefPtr getRequirement(Index index) SLANG_OVERRIDE; + + SpecializationParams const& getExistentialSpecializationParams() + { + return m_existentialSpecializationParams; + } + + Index getGenericSpecializationParamCount() { return m_genericSpecializationParams.getCount(); } + Index getExistentialSpecializationParamCount() + { + return m_existentialSpecializationParams.getCount(); + } + + /// Get an array of all entry-point shader parameters. + List const& getShaderParams() { return m_shaderParams; } + + Index getEntryPointCount() SLANG_OVERRIDE { return 1; }; + RefPtr getEntryPoint(Index index) SLANG_OVERRIDE + { + SLANG_UNUSED(index); + return this; + } + String getEntryPointMangledName(Index index) SLANG_OVERRIDE; + String getEntryPointNameOverride(Index index) SLANG_OVERRIDE; - RefPtr _validateSpecializationArgsImpl( - SpecializationArg const* args, - Index argCount, - DiagnosticSink* sink) SLANG_OVERRIDE; + Index getShaderParamCount() SLANG_OVERRIDE { return 0; } + ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE + { + SLANG_UNUSED(index); + return ShaderParamInfo(); + } + class EntryPointSpecializationInfo : public SpecializationInfo + { public: - CompositeComponentType( - Linkage* linkage, - List> const& childComponents); - - private: - List> m_childComponents; - - // The following arrays hold the concatenated entry points, parameters, - // etc. from the child components. This approach allows for reasonably - // fast (constant time) access through operations like `getShaderParam`, - // but means that the memory usage of a composite is proportional to - // the sum of the memory usage of the children, rather than being fixed - // by the number of children (as it would be if we just stored - // `m_childComponents`). - // - // TODO: We could conceivably build some O(numChildren) arrays that - // support binary-search to provide logarithmic-time access to entry - // points, parameters, etc. while giving a better overall memory usage. - // - List m_entryPoints; - List m_entryPointMangledNames; - List m_entryPointNameOverrides; - List m_shaderParams; - List m_specializationParams; - List m_requirements; - - ModuleDependencyList m_moduleDependencyList; - FileDependencyList m_fileDependencyList; + DeclRef specializedFuncDeclRef; + List existentialSpecializationArgs; }; - /// A component type created by specializing another component type. - class SpecializedComponentType : public ComponentType + SLANG_NO_THROW slang::FunctionReflection* SLANG_MCALL getFunctionReflection() SLANG_OVERRIDE { - public: - SpecializedComponentType( - ComponentType* base, - SpecializationInfo* specializationInfo, - List const& specializationArgs, - DiagnosticSink* sink); + return (slang::FunctionReflection*)m_funcDeclRef.declRefBase; + } - virtual void buildHash(DigestBuilder& builer) SLANG_OVERRIDE; +protected: + void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) + SLANG_OVERRIDE; - /// Get the base (unspecialized) component type that is being specialized. - RefPtr getBaseComponentType() { return m_base; } + RefPtr _validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) SLANG_OVERRIDE; - RefPtr getSpecializationInfo() { return m_specializationInfo; } +private: + EntryPoint(Linkage* linkage, Name* name, Profile profile, DeclRef funcDeclRef); - /// Get the number of arguments supplied for existential type parameters. - /// - /// Note that the number of arguments may not match the number of parameters. - /// In particular, an unspecialized entry point may have many parameters, but zero arguments. - Index getSpecializationArgCount() { return m_specializationArgs.getCount(); } + void _collectGenericSpecializationParamsRec(Decl* decl); + void _collectShaderParams(); - /// Get the existential type argument (type and witness table) at `index`. - SpecializationArg const& getSpecializationArg(Index index) { return m_specializationArgs[index]; } + // The name of the entry point function (e.g., `main`) + // + Name* m_name = nullptr; - /// Get an array of all existential type arguments. - SpecializationArg const* getSpecializationArgs() { return m_specializationArgs.getBuffer(); } + // The declaration of the entry-point function itself. + // + DeclRef m_funcDeclRef; - Index getEntryPointCount() SLANG_OVERRIDE { return m_base->getEntryPointCount(); } - RefPtr getEntryPoint(Index index) SLANG_OVERRIDE { return m_base->getEntryPoint(index); } - String getEntryPointMangledName(Index index) SLANG_OVERRIDE; - String getEntryPointNameOverride(Index index) SLANG_OVERRIDE; + /// The mangled name of the entry point function + String m_mangledName; - Index getShaderParamCount() SLANG_OVERRIDE { return m_base->getShaderParamCount(); } - ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { return m_base->getShaderParam(index); } + SpecializationParams m_genericSpecializationParams; + SpecializationParams m_existentialSpecializationParams; - SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE { return 0; } - SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); static SpecializationParam dummy; return dummy; } + /// Information about entry-point parameters + List m_shaderParams; - Index getRequirementCount() SLANG_OVERRIDE; - RefPtr getRequirement(Index index) SLANG_OVERRIDE; + // The profile that the entry point will be compiled for + // (this is a combination of the target stage, and also + // a feature level that sets capabilities) + // + // Note: the profile-version part of this should probably + // be moving towards deprecation, in favor of the version + // information (e.g., "Shader Model 5.1") always coming + // from the target, while the stage part is all that is + // intrinsic to the entry point. + // + Profile m_profile; +}; - List const& getModuleDependencies() SLANG_OVERRIDE { return m_moduleDependencies; } - List const& getFileDependencies() SLANG_OVERRIDE { return m_fileDependencies; } +class TypeConformance : public ComponentType, public slang::ITypeConformance +{ + typedef ComponentType Super; - RefPtr getIRModule() { return m_irModule; } +public: + SLANG_REF_OBJECT_IUNKNOWN_ALL - void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) SLANG_OVERRIDE; + ISlangUnknown* getInterface(const Guid& guid); - protected: + TypeConformance( + Linkage* linkage, + SubtypeWitness* witness, + Int confomrmanceIdOverride, + DiagnosticSink* sink); - RefPtr _validateSpecializationArgsImpl( - SpecializationArg const* args, - Index argCount, - DiagnosticSink* sink) SLANG_OVERRIDE - { - SLANG_UNUSED(args); - SLANG_UNUSED(argCount); - SLANG_UNUSED(sink); - return nullptr; - } + // Forward `IComponentType` methods - private: - RefPtr m_base; - RefPtr m_specializationInfo; - SpecializationArgs m_specializationArgs; - RefPtr m_irModule; + SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE + { + return Super::getSession(); + } - List m_entryPointMangledNames; - List m_entryPointNameOverrides; + SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL + getLayout(SlangInt targetIndex, slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getLayout(targetIndex, outDiagnostics); + } - List m_moduleDependencies; - List m_fileDependencies; - List> m_requirements; - }; + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); + } - class RenamedEntryPointComponentType : public ComponentType + SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCode( + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE { - public: - using Super = ComponentType; + return Super::getTargetCode(targetIndex, outCode, outDiagnostics); + } - RenamedEntryPointComponentType(ComponentType* base, String newName); + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointMetadata( + entryPointIndex, + targetIndex, + outMetadata, + outDiagnostics); + } - ComponentType* getBase() { return m_base.Ptr(); } + SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata( + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getTargetMetadata(targetIndex, outMetadata, outDiagnostics); + } - // Forward `IComponentType` methods + SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem( + SlangInt entryPointIndex, + SlangInt targetIndex, + ISlangMutableFileSystem** outFileSystem) SLANG_OVERRIDE + { + return Super::getResultAsFileSystem(entryPointIndex, targetIndex, outFileSystem); + } - SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE - { - return Super::getSession(); - } + SLANG_NO_THROW SlangResult SLANG_MCALL specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + slang::IComponentType** outSpecializedComponentType, + ISlangBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::specialize( + specializationArgs, + specializationArgCount, + outSpecializedComponentType, + outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL + renameEntryPoint(const char* newName, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE + { + return Super::renameEntryPoint(newName, outEntryPoint); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL + link(slang::IComponentType** outLinkedComponentType, ISlangBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::link(outLinkedComponentType, outDiagnostics); + } + + virtual SLANG_NO_THROW SlangResult SLANG_MCALL linkWithOptions( + slang::IComponentType** outLinkedComponentType, + uint32_t count, + slang::CompilerOptionEntry* entries, + ISlangBlob** outDiagnostics) override + { + return Super::linkWithOptions(outLinkedComponentType, count, entries, outDiagnostics); + } + + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( + int entryPointIndex, + int targetIndex, + ISlangSharedLibrary** outSharedLibrary, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointHostCallable( + entryPointIndex, + targetIndex, + outSharedLibrary, + outDiagnostics); + } - SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL - getLayout(SlangInt targetIndex, slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getLayout(targetIndex, outDiagnostics); - } + SLANG_NO_THROW void SLANG_MCALL getEntryPointHash( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outHash) SLANG_OVERRIDE + { + return Super::getEntryPointHash(entryPointIndex, targetIndex, outHash); + } - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IBlob** outCode, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); - } + virtual void buildHash(DigestBuilder& builder) SLANG_OVERRIDE; - SLANG_NO_THROW SlangResult SLANG_MCALL specialize( - slang::SpecializationArg const* specializationArgs, - SlangInt specializationArgCount, - slang::IComponentType** outSpecializedComponentType, - ISlangBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::specialize( - specializationArgs, - specializationArgCount, - outSpecializedComponentType, - outDiagnostics); - } + List const& getModuleDependencies() SLANG_OVERRIDE; + List const& getFileDependencies() SLANG_OVERRIDE; - SLANG_NO_THROW SlangResult SLANG_MCALL renameEntryPoint( - const char* newName, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE - { - return Super::renameEntryPoint(newName, outEntryPoint); - } + SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE { return 0; } - SLANG_NO_THROW SlangResult SLANG_MCALL link( - slang::IComponentType** outLinkedComponentType, - ISlangBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::link(outLinkedComponentType, outDiagnostics); - } + /// Get the existential type parameter at `index`. + SpecializationParam const& getSpecializationParam(Index /*index*/) SLANG_OVERRIDE + { + static SpecializationParam emptyParam; + return emptyParam; + } - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( - int entryPointIndex, - int targetIndex, - ISlangSharedLibrary** outSharedLibrary, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getEntryPointHostCallable( - entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics); - } + Index getRequirementCount() SLANG_OVERRIDE; + RefPtr getRequirement(Index index) SLANG_OVERRIDE; + Index getEntryPointCount() SLANG_OVERRIDE { return 0; }; + RefPtr getEntryPoint(Index index) SLANG_OVERRIDE + { + SLANG_UNUSED(index); + return nullptr; + } + String getEntryPointMangledName(Index /*index*/) SLANG_OVERRIDE { return ""; } + String getEntryPointNameOverride(Index /*index*/) SLANG_OVERRIDE { return ""; } - List const& getModuleDependencies() SLANG_OVERRIDE - { - return m_base->getModuleDependencies(); - } - List const& getFileDependencies() SLANG_OVERRIDE - { - return m_base->getFileDependencies(); - } + Index getShaderParamCount() SLANG_OVERRIDE { return 0; } + ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE + { + SLANG_UNUSED(index); + return ShaderParamInfo(); + } - SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE - { - return m_base->getSpecializationParamCount(); - } + SubtypeWitness* getSubtypeWitness() { return m_subtypeWitness; } + IRModule* getIRModule() { return m_irModule.Ptr(); } + +protected: + void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) + SLANG_OVERRIDE; + + RefPtr _validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) SLANG_OVERRIDE; + +private: + SubtypeWitness* m_subtypeWitness; + ModuleDependencyList m_moduleDependencyList; + FileDependencyList m_fileDependencyList; + List> m_requirements; + HashSet m_requirementSet; + RefPtr m_irModule; + Int m_conformanceIdOverride; + void addDepedencyFromWitness(SubtypeWitness* witness); +}; - SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE - { - return m_base->getSpecializationParam(index); - } +enum class PassThroughMode : SlangPassThroughIntegral +{ + None = SLANG_PASS_THROUGH_NONE, ///< don't pass through: use Slang compiler + Fxc = SLANG_PASS_THROUGH_FXC, ///< pass through HLSL to `D3DCompile` API + Dxc = SLANG_PASS_THROUGH_DXC, ///< pass through HLSL to `IDxcCompiler` API + Glslang = SLANG_PASS_THROUGH_GLSLANG, ///< pass through GLSL to `glslang` library + SpirvDis = SLANG_PASS_THROUGH_SPIRV_DIS, ///< pass through spirv-dis + Clang = SLANG_PASS_THROUGH_CLANG, ///< Pass through clang compiler + VisualStudio = SLANG_PASS_THROUGH_VISUAL_STUDIO, ///< Visual studio compiler + Gcc = SLANG_PASS_THROUGH_GCC, ///< Gcc compiler + GenericCCpp = SLANG_PASS_THROUGH_GENERIC_C_CPP, ///< Generic C/C++ compiler + NVRTC = SLANG_PASS_THROUGH_NVRTC, ///< NVRTC CUDA compiler + LLVM = SLANG_PASS_THROUGH_LLVM, ///< LLVM 'compiler' + SpirvOpt = SLANG_PASS_THROUGH_SPIRV_OPT, ///< pass thorugh spirv to spirv-opt + MetalC = SLANG_PASS_THROUGH_METAL, + Tint = SLANG_PASS_THROUGH_TINT, ///< pass through spirv to Tint API + CountOf = SLANG_PASS_THROUGH_COUNT_OF, +}; +void printDiagnosticArg(StringBuilder& sb, PassThroughMode val); - Index getRequirementCount() SLANG_OVERRIDE { return m_base->getRequirementCount(); } - RefPtr getRequirement(Index index) SLANG_OVERRIDE - { - return m_base->getRequirement(index); - } - Index getEntryPointCount() SLANG_OVERRIDE { return m_base->getEntryPointCount(); } - RefPtr getEntryPoint(Index index) SLANG_OVERRIDE - { - return m_base->getEntryPoint(index); - } - String getEntryPointMangledName(Index index) SLANG_OVERRIDE { return m_base->getEntryPointMangledName(index); } - String getEntryPointNameOverride(Index index) SLANG_OVERRIDE - { - SLANG_UNUSED(index); - SLANG_ASSERT(index == 0); - return m_entryPointNameOverride; - } +class SourceFile; - Index getShaderParamCount() SLANG_OVERRIDE { return m_base->getShaderParamCount(); } - ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE - { - return m_base->getShaderParam(index); - } +/// A module of code that has been compiled through the front-end +/// +/// A module comprises all the code from one translation unit (which +/// may span multiple Slang source files), and provides access +/// to both the AST and IR representations of that code. +/// +class Module : public ComponentType, public slang::IModule +{ + typedef ComponentType Super; - void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) - SLANG_OVERRIDE; +public: + SLANG_REF_OBJECT_IUNKNOWN_ALL - virtual void buildHash(DigestBuilder& builder) SLANG_OVERRIDE; + ISlangUnknown* getInterface(const Guid& guid); - private: - RefPtr m_base; - String m_entryPointNameOverride; - protected: - RefPtr _validateSpecializationArgsImpl( - SpecializationArg const* args, Index argCount, DiagnosticSink* sink) SLANG_OVERRIDE - { - return m_base->_validateSpecializationArgsImpl(args, argCount, sink); - } - }; + // Forward `IComponentType` methods - /// Describes an entry point for the purposes of layout and code generation. - /// - /// This class also tracks any generic arguments to the entry point, - /// in the case that it is a specialization of a generic entry point. - /// - /// There is also a provision for creating a "dummy" entry point for - /// the purposes of pass-through compilation modes. Only the - /// `getName()` and `getProfile()` methods should be expected to - /// return useful data on pass-through entry points. - /// - class EntryPoint : public ComponentType, public slang::IEntryPoint + SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE { - typedef ComponentType Super; + return Super::getSession(); + } - public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL + getLayout(SlangInt targetIndex, slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getLayout(targetIndex, outDiagnostics); + } - ISlangUnknown* getInterface(const Guid& guid); + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); + } + SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCode( + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getTargetCode(targetIndex, outCode, outDiagnostics); + } - // Forward `IComponentType` methods + SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem( + SlangInt entryPointIndex, + SlangInt targetIndex, + ISlangMutableFileSystem** outFileSystem) SLANG_OVERRIDE + { + return Super::getResultAsFileSystem(entryPointIndex, targetIndex, outFileSystem); + } - SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE - { - return Super::getSession(); - } + SLANG_NO_THROW SlangResult SLANG_MCALL specialize( + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + slang::IComponentType** outSpecializedComponentType, + ISlangBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::specialize( + specializationArgs, + specializationArgCount, + outSpecializedComponentType, + outDiagnostics); + } - SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL getLayout( - SlangInt targetIndex, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getLayout(targetIndex, outDiagnostics); - } + SLANG_NO_THROW SlangResult SLANG_MCALL + renameEntryPoint(const char* newName, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE + { + return Super::renameEntryPoint(newName, outEntryPoint); + } - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IBlob** outCode, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); - } + SLANG_NO_THROW SlangResult SLANG_MCALL + link(slang::IComponentType** outLinkedComponentType, ISlangBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::link(outLinkedComponentType, outDiagnostics); + } - SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCode( - SlangInt targetIndex, - slang::IBlob** outCode, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getTargetCode(targetIndex, outCode, outDiagnostics); - } + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( + int entryPointIndex, + int targetIndex, + ISlangSharedLibrary** outSharedLibrary, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointHostCallable( + entryPointIndex, + targetIndex, + outSharedLibrary, + outDiagnostics); + } - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IMetadata** outMetadata, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE + SLANG_NO_THROW SlangResult SLANG_MCALL + findEntryPointByName(char const* name, slang::IEntryPoint** outEntryPoint) SLANG_OVERRIDE + { + if (outEntryPoint == nullptr) { - return Super::getEntryPointMetadata(entryPointIndex, targetIndex, outMetadata, outDiagnostics); + return SLANG_E_INVALID_ARG; } - SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata( - SlangInt targetIndex, - slang::IMetadata** outMetadata, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getTargetMetadata(targetIndex, outMetadata, outDiagnostics); - } + ComPtr entryPoint(findEntryPointByName(UnownedStringSlice(name))); + if ((!entryPoint)) + return SLANG_FAIL; - SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem( - SlangInt entryPointIndex, - SlangInt targetIndex, - ISlangMutableFileSystem** outFileSystem) SLANG_OVERRIDE - { - return Super::getResultAsFileSystem(entryPointIndex, targetIndex, outFileSystem); - } + *outEntryPoint = entryPoint.detach(); + return SLANG_OK; + } - SLANG_NO_THROW SlangResult SLANG_MCALL specialize( - slang::SpecializationArg const* specializationArgs, - SlangInt specializationArgCount, - slang::IComponentType** outSpecializedComponentType, - ISlangBlob** outDiagnostics) SLANG_OVERRIDE + virtual SLANG_NO_THROW SlangResult SLANG_MCALL findAndCheckEntryPoint( + char const* name, + SlangStage stage, + slang::IEntryPoint** outEntryPoint, + ISlangBlob** outDiagnostics) override + { + if (outEntryPoint == nullptr) { - return Super::specialize( - specializationArgs, - specializationArgCount, - outSpecializedComponentType, - outDiagnostics); + return SLANG_E_INVALID_ARG; } - SLANG_NO_THROW SlangResult SLANG_MCALL renameEntryPoint( - const char* newName, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE - { - return Super::renameEntryPoint(newName, outEntryPoint); - } + ComPtr entryPoint( + findAndCheckEntryPoint(UnownedStringSlice(name), stage, outDiagnostics)); + if ((!entryPoint)) + return SLANG_FAIL; - SLANG_NO_THROW SlangResult SLANG_MCALL link( - slang::IComponentType** outLinkedComponentType, - ISlangBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::link( - outLinkedComponentType, - outDiagnostics); - } + *outEntryPoint = entryPoint.detach(); + return SLANG_OK; + } - virtual SLANG_NO_THROW SlangResult SLANG_MCALL linkWithOptions( - slang::IComponentType** outLinkedComponentType, - uint32_t count, - slang::CompilerOptionEntry* entries, - ISlangBlob** outDiagnostics) override - { - return Super::linkWithOptions(outLinkedComponentType, count, entries, outDiagnostics); - } + virtual SlangInt32 SLANG_MCALL getDefinedEntryPointCount() override + { + return (SlangInt32)m_entryPoints.getCount(); + } - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( - int entryPointIndex, - int targetIndex, - ISlangSharedLibrary** outSharedLibrary, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getEntryPointHostCallable(entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics); - } + virtual SlangResult SLANG_MCALL + getDefinedEntryPoint(SlangInt32 index, slang::IEntryPoint** outEntryPoint) override + { + if (index < 0 || index >= m_entryPoints.getCount()) + return SLANG_E_INVALID_ARG; - SLANG_NO_THROW void SLANG_MCALL getEntryPointHash( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IBlob** outHash) SLANG_OVERRIDE + if (outEntryPoint == nullptr) { - return Super::getEntryPointHash(entryPointIndex, targetIndex, outHash); + return SLANG_E_INVALID_ARG; } - virtual void buildHash(DigestBuilder& builder) SLANG_OVERRIDE; + ComPtr entryPoint(m_entryPoints[index].Ptr()); + *outEntryPoint = entryPoint.detach(); + return SLANG_OK; + } - /// Create an entry point that refers to the given function. - static RefPtr create( - Linkage* linkage, - DeclRef funcDeclRef, - Profile profile); + virtual SLANG_NO_THROW SlangResult SLANG_MCALL linkWithOptions( + slang::IComponentType** outLinkedComponentType, + uint32_t count, + slang::CompilerOptionEntry* entries, + ISlangBlob** outDiagnostics) override + { + return Super::linkWithOptions(outLinkedComponentType, count, entries, outDiagnostics); + } + // - /// Get the function decl-ref, including any generic arguments. - DeclRef getFuncDeclRef() { return m_funcDeclRef; } + SLANG_NO_THROW void SLANG_MCALL getEntryPointHash( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outHash) SLANG_OVERRIDE + { + return Super::getEntryPointHash(entryPointIndex, targetIndex, outHash); + } - /// Get the function declaration (without generic arguments). - FuncDecl* getFuncDecl() { return m_funcDeclRef.getDecl(); } + SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata( + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getEntryPointMetadata( + entryPointIndex, + targetIndex, + outMetadata, + outDiagnostics); + } - /// Get the name of the entry point - Name* getName() { return m_name; } + SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata( + SlangInt targetIndex, + slang::IMetadata** outMetadata, + slang::IBlob** outDiagnostics) SLANG_OVERRIDE + { + return Super::getTargetMetadata(targetIndex, outMetadata, outDiagnostics); + } - /// Get the profile associated with the entry point - /// - /// Note: only the stage part of the profile is expected - /// to contain useful data, but certain legacy code paths - /// allow for "shader model" information to come via this path. - /// - Profile getProfile() { return m_profile; } + /// Get a serialized representation of the checked module. + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + serialize(ISlangBlob** outSerializedBlob) override; - /// Get the stage that the entry point is for. - Stage getStage() { return m_profile.getStage(); } + /// Write the serialized representation of this module to a file. + virtual SLANG_NO_THROW SlangResult SLANG_MCALL writeToFile(char const* fileName) override; - /// Get the module that contains the entry point. - Module* getModule(); + /// Get the name of the module. + virtual SLANG_NO_THROW const char* SLANG_MCALL getName() override; - /// Get a list of modules that this entry point depends on. - /// - /// This will include the module that defines the entry point (see `getModule()`), - /// but may also include modules that are required by its generic type arguments. - /// - List const& getModuleDependencies() SLANG_OVERRIDE; // { return getModule()->getModuleDependencies(); } - List const& getFileDependencies() SLANG_OVERRIDE; // { return getModule()->getFileDependencies(); } + /// Get the path of the module. + virtual SLANG_NO_THROW const char* SLANG_MCALL getFilePath() override; - /// Create a dummy `EntryPoint` that is only usable for pass-through compilation. - static RefPtr createDummyForPassThrough( - Linkage* linkage, - Name* name, - Profile profile); + /// Get the unique identity of the module. + virtual SLANG_NO_THROW const char* SLANG_MCALL getUniqueIdentity() override; - /// Create a dummy `EntryPoint` that stands in for a serialized entry point - static RefPtr createDummyForDeserialize( - Linkage* linkage, - Name* name, - Profile profile, - String mangledName); + /// Get the number of dependency files that this module depends on. + /// This includes both the explicit source files, as well as any + /// additional files that were transitively referenced (e.g., via + /// a `#include` directive). + virtual SLANG_NO_THROW SlangInt32 SLANG_MCALL getDependencyFileCount() override; - /// Get the number of existential type parameters for the entry point. - SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE; + /// Get the path to a file this module depends on. + virtual SLANG_NO_THROW char const* SLANG_MCALL getDependencyFilePath(SlangInt32 index) override; - /// Get the existential type parameter at `index`. - SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE; - Index getRequirementCount() SLANG_OVERRIDE; - RefPtr getRequirement(Index index) SLANG_OVERRIDE; + // IModulePrecompileService_Experimental + /// Precompile TU to target language + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + precompileForTarget(SlangCompileTarget target, slang::IBlob** outDiagnostics) override; - SpecializationParams const& getExistentialSpecializationParams() { return m_existentialSpecializationParams; } + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getPrecompiledTargetCode( + SlangCompileTarget target, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics = nullptr) override; - Index getGenericSpecializationParamCount() { return m_genericSpecializationParams.getCount(); } - Index getExistentialSpecializationParamCount() { return m_existentialSpecializationParams.getCount(); } + virtual SLANG_NO_THROW SlangInt SLANG_MCALL getModuleDependencyCount() SLANG_OVERRIDE; - /// Get an array of all entry-point shader parameters. - List const& getShaderParams() { return m_shaderParams; } + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getModuleDependency( + SlangInt dependencyIndex, + slang::IModule** outModule, + slang::IBlob** outDiagnostics = nullptr) SLANG_OVERRIDE; - Index getEntryPointCount() SLANG_OVERRIDE { return 1; }; - RefPtr getEntryPoint(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return this; } - String getEntryPointMangledName(Index index) SLANG_OVERRIDE; - String getEntryPointNameOverride(Index index) SLANG_OVERRIDE; + virtual void buildHash(DigestBuilder& builder) SLANG_OVERRIDE; - Index getShaderParamCount() SLANG_OVERRIDE { return 0; } - ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return ShaderParamInfo(); } - - class EntryPointSpecializationInfo : public SpecializationInfo - { - public: - DeclRef specializedFuncDeclRef; - List existentialSpecializationArgs; - }; + virtual slang::DeclReflection* getModuleReflection() SLANG_OVERRIDE; - SLANG_NO_THROW slang::FunctionReflection* SLANG_MCALL getFunctionReflection() SLANG_OVERRIDE - { - return (slang::FunctionReflection*)m_funcDeclRef.declRefBase; - } - protected: - void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) SLANG_OVERRIDE; - - RefPtr _validateSpecializationArgsImpl( - SpecializationArg const* args, - Index argCount, - DiagnosticSink* sink) SLANG_OVERRIDE; - - private: - EntryPoint( - Linkage* linkage, - Name* name, - Profile profile, - DeclRef funcDeclRef); - - void _collectGenericSpecializationParamsRec(Decl* decl); - void _collectShaderParams(); - - // The name of the entry point function (e.g., `main`) - // - Name* m_name = nullptr; - - // The declaration of the entry-point function itself. - // - DeclRef m_funcDeclRef; - - /// The mangled name of the entry point function - String m_mangledName; - - SpecializationParams m_genericSpecializationParams; - SpecializationParams m_existentialSpecializationParams; - - /// Information about entry-point parameters - List m_shaderParams; - - // The profile that the entry point will be compiled for - // (this is a combination of the target stage, and also - // a feature level that sets capabilities) - // - // Note: the profile-version part of this should probably - // be moving towards deprecation, in favor of the version - // information (e.g., "Shader Model 5.1") always coming - // from the target, while the stage part is all that is - // intrinsic to the entry point. - // - Profile m_profile; - }; + void setDigest(SHA1::Digest const& digest) { m_digest = digest; } + SHA1::Digest computeDigest(); - class TypeConformance - : public ComponentType - , public slang::ITypeConformance - { - typedef ComponentType Super; + /// Create a module (initially empty). + Module(Linkage* linkage, ASTBuilder* astBuilder = nullptr); - public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + /// Get the AST for the module (if it has been parsed) + ModuleDecl* getModuleDecl() { return m_moduleDecl; } - ISlangUnknown* getInterface(const Guid& guid); + /// The the IR for the module (if it has been generated) + IRModule* getIRModule() { return m_irModule; } - TypeConformance( - Linkage* linkage, - SubtypeWitness* witness, - Int confomrmanceIdOverride, - DiagnosticSink* sink); + /// Get the list of other modules this module depends on + List const& getModuleDependencyList() + { + return m_moduleDependencyList.getModuleList(); + } - // Forward `IComponentType` methods + /// Get the list of files this module depends on + List const& getFileDependencyList() { return m_fileDependencyList.getFileList(); } - SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE - { - return Super::getSession(); - } + /// Register a module that this module depends on + void addModuleDependency(Module* module); - SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL - getLayout(SlangInt targetIndex, slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getLayout(targetIndex, outDiagnostics); - } + /// Register a source file that this module depends on + void addFileDependency(SourceFile* sourceFile); - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IBlob** outCode, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); - } + void clearFileDependency() { m_fileDependencyList.clear(); } + /// Set the AST for this module. + /// + /// This should only be called once, during creation of the module. + /// + void setModuleDecl(ModuleDecl* moduleDecl); // { m_moduleDecl = moduleDecl; } - SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCode( - SlangInt targetIndex, - slang::IBlob** outCode, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getTargetCode(targetIndex, outCode, outDiagnostics); - } + void setName(String name); + void setName(Name* name) { m_name = name; } + void setPathInfo(PathInfo pathInfo) { m_pathInfo = pathInfo; } - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IMetadata** outMetadata, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getEntryPointMetadata(entryPointIndex, targetIndex, outMetadata, outDiagnostics); - } + /// Set the IR for this module. + /// + /// This should only be called once, during creation of the module. + /// + void setIRModule(IRModule* irModule) { m_irModule = irModule; } - SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata( - SlangInt targetIndex, - slang::IMetadata** outMetadata, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getTargetMetadata(targetIndex, outMetadata, outDiagnostics); - } + Index getEntryPointCount() SLANG_OVERRIDE { return 0; } + RefPtr getEntryPoint(Index index) SLANG_OVERRIDE + { + SLANG_UNUSED(index); + return nullptr; + } + String getEntryPointMangledName(Index index) SLANG_OVERRIDE + { + SLANG_UNUSED(index); + return String(); + } + String getEntryPointNameOverride(Index index) SLANG_OVERRIDE + { + SLANG_UNUSED(index); + return String(); + } - SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem( - SlangInt entryPointIndex, - SlangInt targetIndex, - ISlangMutableFileSystem** outFileSystem) SLANG_OVERRIDE - { - return Super::getResultAsFileSystem(entryPointIndex, targetIndex, outFileSystem); - } + Index getShaderParamCount() SLANG_OVERRIDE { return m_shaderParams.getCount(); } + ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { return m_shaderParams[index]; } - SLANG_NO_THROW SlangResult SLANG_MCALL specialize( - slang::SpecializationArg const* specializationArgs, - SlangInt specializationArgCount, - slang::IComponentType** outSpecializedComponentType, - ISlangBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::specialize( - specializationArgs, - specializationArgCount, - outSpecializedComponentType, - outDiagnostics); - } + SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE + { + return m_specializationParams.getCount(); + } + SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE + { + return m_specializationParams[index]; + } - SLANG_NO_THROW SlangResult SLANG_MCALL renameEntryPoint( - const char* newName, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE - { - return Super::renameEntryPoint(newName, outEntryPoint); - } + Index getRequirementCount() SLANG_OVERRIDE; + RefPtr getRequirement(Index index) SLANG_OVERRIDE; - SLANG_NO_THROW SlangResult SLANG_MCALL link( - slang::IComponentType** outLinkedComponentType, - ISlangBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::link(outLinkedComponentType, outDiagnostics); - } + List const& getModuleDependencies() SLANG_OVERRIDE + { + return m_moduleDependencyList.getModuleList(); + } + List const& getFileDependencies() SLANG_OVERRIDE + { + return m_fileDependencyList.getFileList(); + } - virtual SLANG_NO_THROW SlangResult SLANG_MCALL linkWithOptions( - slang::IComponentType** outLinkedComponentType, - uint32_t count, - slang::CompilerOptionEntry* entries, - ISlangBlob** outDiagnostics) override - { - return Super::linkWithOptions(outLinkedComponentType, count, entries, outDiagnostics); - } + /// Given a mangled name finds the exported NodeBase associated with this module. + /// If not found returns nullptr. + NodeBase* findExportFromMangledName(const UnownedStringSlice& slice); + + /// Get the ASTBuilder + ASTBuilder* getASTBuilder() { return m_astBuilder; } + + /// Collect information on the shader parameters of the module. + /// + /// This method should only be called once, after the core + /// structured of the module (its AST and IR) have been created, + /// and before any of the `ComponentType` APIs are used. + /// + /// TODO: We might eventually consider a non-stateful approach + /// to constructing a `Module`. + /// + void _collectShaderParams(); + + void _discoverEntryPoints(DiagnosticSink* sink, const List>& targets); + void _discoverEntryPointsImpl( + ContainerDecl* containerDecl, + DiagnosticSink* sink, + const List>& targets); - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( - int entryPointIndex, - int targetIndex, - ISlangSharedLibrary** outSharedLibrary, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getEntryPointHostCallable( - entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics); - } - SLANG_NO_THROW void SLANG_MCALL getEntryPointHash( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IBlob** outHash) SLANG_OVERRIDE + class ModuleSpecializationInfo : public SpecializationInfo + { + public: + struct GenericArgInfo { - return Super::getEntryPointHash(entryPointIndex, targetIndex, outHash); - } - - virtual void buildHash(DigestBuilder& builder) SLANG_OVERRIDE; + Decl* paramDecl = nullptr; + Val* argVal = nullptr; + }; - List const& getModuleDependencies() SLANG_OVERRIDE; - List const& getFileDependencies() SLANG_OVERRIDE; + List genericArgs; + List existentialArgs; + }; - SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE { return 0; } + RefPtr findEntryPointByName(UnownedStringSlice const& name); + RefPtr findAndCheckEntryPoint( + UnownedStringSlice const& name, + SlangStage stage, + ISlangBlob** outDiagnostics); - /// Get the existential type parameter at `index`. - SpecializationParam const& getSpecializationParam(Index /*index*/) SLANG_OVERRIDE - { - static SpecializationParam emptyParam; - return emptyParam; - } + List>& getEntryPoints() { return m_entryPoints; } + void _addEntryPoint(EntryPoint* entryPoint); + void _processFindDeclsExportSymbolsRec(Decl* decl); - Index getRequirementCount() SLANG_OVERRIDE; - RefPtr getRequirement(Index index) SLANG_OVERRIDE; - Index getEntryPointCount() SLANG_OVERRIDE { return 0; }; - RefPtr getEntryPoint(Index index) SLANG_OVERRIDE - { - SLANG_UNUSED(index); - return nullptr; - } - String getEntryPointMangledName(Index /*index*/) SLANG_OVERRIDE { return ""; } - String getEntryPointNameOverride(Index /*index*/) SLANG_OVERRIDE { return ""; } + // Gets the files that has been included into the module. + Dictionary& getIncludedSourceFileMap() + { + return m_mapSourceFileToFileDecl; + } - Index getShaderParamCount() SLANG_OVERRIDE { return 0; } - ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE - { - SLANG_UNUSED(index); - return ShaderParamInfo(); - } +protected: + void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) + SLANG_OVERRIDE; - SubtypeWitness* getSubtypeWitness() { return m_subtypeWitness; } - IRModule* getIRModule() { return m_irModule.Ptr(); } - protected: - void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) - SLANG_OVERRIDE; - - RefPtr _validateSpecializationArgsImpl( - SpecializationArg const* args, - Index argCount, - DiagnosticSink* sink) SLANG_OVERRIDE; - private: - SubtypeWitness* m_subtypeWitness; - ModuleDependencyList m_moduleDependencyList; - FileDependencyList m_fileDependencyList; - List> m_requirements; - HashSet m_requirementSet; - RefPtr m_irModule; - Int m_conformanceIdOverride; - void addDepedencyFromWitness(SubtypeWitness* witness); - }; + RefPtr _validateSpecializationArgsImpl( + SpecializationArg const* args, + Index argCount, + DiagnosticSink* sink) SLANG_OVERRIDE; - enum class PassThroughMode : SlangPassThroughIntegral - { - None = SLANG_PASS_THROUGH_NONE, ///< don't pass through: use Slang compiler - Fxc = SLANG_PASS_THROUGH_FXC, ///< pass through HLSL to `D3DCompile` API - Dxc = SLANG_PASS_THROUGH_DXC, ///< pass through HLSL to `IDxcCompiler` API - Glslang = SLANG_PASS_THROUGH_GLSLANG, ///< pass through GLSL to `glslang` library - SpirvDis = SLANG_PASS_THROUGH_SPIRV_DIS, ///< pass through spirv-dis - Clang = SLANG_PASS_THROUGH_CLANG, ///< Pass through clang compiler - VisualStudio = SLANG_PASS_THROUGH_VISUAL_STUDIO, ///< Visual studio compiler - Gcc = SLANG_PASS_THROUGH_GCC, ///< Gcc compiler - GenericCCpp = SLANG_PASS_THROUGH_GENERIC_C_CPP, ///< Generic C/C++ compiler - NVRTC = SLANG_PASS_THROUGH_NVRTC, ///< NVRTC CUDA compiler - LLVM = SLANG_PASS_THROUGH_LLVM, ///< LLVM 'compiler' - SpirvOpt = SLANG_PASS_THROUGH_SPIRV_OPT, ///< pass thorugh spirv to spirv-opt - MetalC = SLANG_PASS_THROUGH_METAL, - Tint = SLANG_PASS_THROUGH_TINT, ///< pass through spirv to Tint API - CountOf = SLANG_PASS_THROUGH_COUNT_OF, - }; - void printDiagnosticArg(StringBuilder& sb, PassThroughMode val); +private: + Name* m_name = nullptr; + PathInfo m_pathInfo; - class SourceFile; + // The AST for the module + ModuleDecl* m_moduleDecl = nullptr; - /// A module of code that has been compiled through the front-end - /// - /// A module comprises all the code from one translation unit (which - /// may span multiple Slang source files), and provides access - /// to both the AST and IR representations of that code. - /// - class Module : public ComponentType, public slang::IModule - { - typedef ComponentType Super; + // The IR for the module + RefPtr m_irModule = nullptr; - public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + List m_shaderParams; + SpecializationParams m_specializationParams; - ISlangUnknown* getInterface(const Guid& guid); + List m_requirements; + // A digest that uniquely identifies the contents of the module. + SHA1::Digest m_digest; - // Forward `IComponentType` methods + // List of modules this module depends on + ModuleDependencyList m_moduleDependencyList; - SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() SLANG_OVERRIDE - { - return Super::getSession(); - } + // List of source files this module depends on + FileDependencyList m_fileDependencyList; - SLANG_NO_THROW slang::ProgramLayout* SLANG_MCALL getLayout( - SlangInt targetIndex, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getLayout(targetIndex, outDiagnostics); - } + // Entry points that were defined in this module + // + // Note: the entry point defined in the module are *not* + // part of the memory image/layout of the module when + // it is considered as an IComponentType. This can be + // a bit confusing, but if all the entry points in the + // module were automatically linked into the component + // type, we'd need a way to access just the global + // scope of the module without the entry points, in + // case we wanted to link a single entry point against + // the global scope. The `Module` type provides exactly + // that "module without its entry points" unit of + // granularity for linking. + // + // This list only exists for lookup purposes, so that + // the user can find an existing entry-point function + // that was defined as part of the module. + // + List> m_entryPoints; - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCode( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IBlob** outCode, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); - } + // The builder that owns all of the AST nodes from parsing the source of + // this module. + RefPtr m_astBuilder; - SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCode( - SlangInt targetIndex, - slang::IBlob** outCode, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getTargetCode(targetIndex, outCode, outDiagnostics); - } + // Holds map of exported mangled names to symbols. m_mangledExportPool maps names to indices, + // and m_mangledExportSymbols holds the NodeBase* values for each index. + StringSlicePool m_mangledExportPool; + List m_mangledExportSymbols; - SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem( - SlangInt entryPointIndex, - SlangInt targetIndex, - ISlangMutableFileSystem** outFileSystem) SLANG_OVERRIDE - { - return Super::getResultAsFileSystem(entryPointIndex, targetIndex, outFileSystem); - } + // Source files that have been pulled into the module with `__include`. + Dictionary m_mapSourceFileToFileDecl; +}; +typedef Module LoadedModule; - SLANG_NO_THROW SlangResult SLANG_MCALL specialize( - slang::SpecializationArg const* specializationArgs, - SlangInt specializationArgCount, - slang::IComponentType** outSpecializedComponentType, - ISlangBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::specialize( - specializationArgs, - specializationArgCount, - outSpecializedComponentType, - outDiagnostics); - } +/// A request for the front-end to compile a translation unit. +class TranslationUnitRequest : public RefObject +{ +public: + TranslationUnitRequest(FrontEndCompileRequest* compileRequest); + TranslationUnitRequest(FrontEndCompileRequest* compileRequest, Module* m); - SLANG_NO_THROW SlangResult SLANG_MCALL renameEntryPoint( - const char* newName, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE - { - return Super::renameEntryPoint(newName, outEntryPoint); - } + // The parent compile request + FrontEndCompileRequest* compileRequest = nullptr; - SLANG_NO_THROW SlangResult SLANG_MCALL link( - slang::IComponentType** outLinkedComponentType, - ISlangBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::link( - outLinkedComponentType, - outDiagnostics); - } + // The language in which the source file(s) + // are assumed to be written + SourceLanguage sourceLanguage = SourceLanguage::Unknown; - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( - int entryPointIndex, - int targetIndex, - ISlangSharedLibrary** outSharedLibrary, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getEntryPointHostCallable(entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics); - } + /// Makes any source artifact available as a SourceFile. + /// If successful any of the source artifacts will be represented by the same index + /// of sourceArtifacts + SlangResult requireSourceFiles(); - SLANG_NO_THROW SlangResult SLANG_MCALL findEntryPointByName( - char const* name, - slang::IEntryPoint** outEntryPoint) SLANG_OVERRIDE - { - if (outEntryPoint == nullptr) - { - return SLANG_E_INVALID_ARG; - } + /// Get the source files. + /// Since lazily evaluated requires calling requireSourceFiles to know it's in sync + /// with sourceArtifacts. + List const& getSourceFiles(); - ComPtr entryPoint(findEntryPointByName(UnownedStringSlice(name))); - if((!entryPoint)) - return SLANG_FAIL; + /// Get the source artifacts associated + const List>& getSourceArtifacts() const { return m_sourceArtifacts; } - *outEntryPoint = entryPoint.detach(); - return SLANG_OK; - } + /// Clear all of the source + void clearSource() + { + m_sourceArtifacts.clear(); + m_sourceFiles.clear(); + } - virtual SLANG_NO_THROW SlangResult SLANG_MCALL findAndCheckEntryPoint( - char const* name, - SlangStage stage, - slang::IEntryPoint** outEntryPoint, - ISlangBlob** outDiagnostics) override - { - if (outEntryPoint == nullptr) - { - return SLANG_E_INVALID_ARG; - } + /// Add a source artifact + void addSourceArtifact(IArtifact* sourceArtifact); - ComPtr entryPoint(findAndCheckEntryPoint(UnownedStringSlice(name), stage, outDiagnostics)); - if ((!entryPoint)) - return SLANG_FAIL; + /// Add both the artifact and the sourceFile. + void addSource(IArtifact* sourceArtifact, SourceFile* sourceFile); - *outEntryPoint = entryPoint.detach(); - return SLANG_OK; - } + // The entry points associated with this translation unit + List> const& getEntryPoints() { return module->getEntryPoints(); } - virtual SlangInt32 SLANG_MCALL getDefinedEntryPointCount() override - { - return (SlangInt32)m_entryPoints.getCount(); - } + void _addEntryPoint(EntryPoint* entryPoint) { module->_addEntryPoint(entryPoint); } - virtual SlangResult SLANG_MCALL getDefinedEntryPoint(SlangInt32 index, slang::IEntryPoint** outEntryPoint) override - { - if (index < 0 || index >= m_entryPoints.getCount()) - return SLANG_E_INVALID_ARG; + // Preprocessor definitions to use for this translation unit only + // (whereas the ones on `compileRequest` will be shared) + Dictionary preprocessorDefinitions; - if (outEntryPoint == nullptr) - { - return SLANG_E_INVALID_ARG; - } + /// The name that will be used for the module this translation unit produces. + Name* moduleName = nullptr; - ComPtr entryPoint(m_entryPoints[index].Ptr()); - *outEntryPoint = entryPoint.detach(); - return SLANG_OK; - } + /// Result of compiling this translation unit (a module) + RefPtr module; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL linkWithOptions( - slang::IComponentType** outLinkedComponentType, - uint32_t count, - slang::CompilerOptionEntry* entries, - ISlangBlob** outDiagnostics) override - { - return Super::linkWithOptions(outLinkedComponentType, count, entries, outDiagnostics); - } - // + bool isChecked = false; - SLANG_NO_THROW void SLANG_MCALL getEntryPointHash( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IBlob** outHash) SLANG_OVERRIDE - { - return Super::getEntryPointHash(entryPointIndex, targetIndex, outHash); - } + Module* getModule() { return module; } + ModuleDecl* getModuleDecl() { return module->getModuleDecl(); } - SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata( - SlangInt entryPointIndex, - SlangInt targetIndex, - slang::IMetadata** outMetadata, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getEntryPointMetadata(entryPointIndex, targetIndex, outMetadata, outDiagnostics); - } + Session* getSession(); + NamePool* getNamePool(); + SourceManager* getSourceManager(); - SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata( - SlangInt targetIndex, - slang::IMetadata** outMetadata, - slang::IBlob** outDiagnostics) SLANG_OVERRIDE - { - return Super::getTargetMetadata(targetIndex, outMetadata, outDiagnostics); - } + Scope* getLanguageScope(); - /// Get a serialized representation of the checked module. - virtual SLANG_NO_THROW SlangResult SLANG_MCALL serialize(ISlangBlob** outSerializedBlob) override; + Dictionary getCombinedPreprocessorDefinitions(); - /// Write the serialized representation of this module to a file. - virtual SLANG_NO_THROW SlangResult SLANG_MCALL writeToFile(char const* fileName) override; + void setModuleName(Name* name) + { + moduleName = name; + if (module) + module->setName(name); + } - /// Get the name of the module. - virtual SLANG_NO_THROW const char* SLANG_MCALL getName() override; +protected: + void _addSourceFile(SourceFile* sourceFile); + /* Given an artifact, find a PathInfo. + If no PathInfo can be found will return an unknown PathInfo */ + PathInfo _findSourcePathInfo(IArtifact* artifact); - /// Get the path of the module. - virtual SLANG_NO_THROW const char* SLANG_MCALL getFilePath() override; + List> m_sourceArtifacts; + // The source file(s) that will be compiled to form this translation unit + // + // Usually, for HLSL or GLSL there will be only one file. + // NOTE! This member is generated lazily from m_sourceArtifacts + // it is *necessary* to call requireSourceFiles to ensure it's in sync. + List m_sourceFiles; +}; - /// Get the unique identity of the module. - virtual SLANG_NO_THROW const char* SLANG_MCALL getUniqueIdentity() override; +enum class FloatingPointMode : SlangFloatingPointModeIntegral +{ + Default = SLANG_FLOATING_POINT_MODE_DEFAULT, + Fast = SLANG_FLOATING_POINT_MODE_FAST, + Precise = SLANG_FLOATING_POINT_MODE_PRECISE, +}; - /// Get the number of dependency files that this module depends on. - /// This includes both the explicit source files, as well as any - /// additional files that were transitively referenced (e.g., via - /// a `#include` directive). - virtual SLANG_NO_THROW SlangInt32 SLANG_MCALL getDependencyFileCount() override; +enum class WriterChannel : SlangWriterChannelIntegral +{ + Diagnostic = SLANG_WRITER_CHANNEL_DIAGNOSTIC, + StdOutput = SLANG_WRITER_CHANNEL_STD_OUTPUT, + StdError = SLANG_WRITER_CHANNEL_STD_ERROR, + CountOf = SLANG_WRITER_CHANNEL_COUNT_OF, +}; - /// Get the path to a file this module depends on. - virtual SLANG_NO_THROW char const* SLANG_MCALL getDependencyFilePath( - SlangInt32 index) override; +enum class WriterMode : SlangWriterModeIntegral +{ + Text = SLANG_WRITER_MODE_TEXT, + Binary = SLANG_WRITER_MODE_BINARY, +}; +class TargetRequest; - // IModulePrecompileService_Experimental - /// Precompile TU to target language - virtual SLANG_NO_THROW SlangResult SLANG_MCALL precompileForTarget( - SlangCompileTarget target, - slang::IBlob** outDiagnostics) override; +/// Are we generating code for a D3D API? +bool isD3DTarget(TargetRequest* targetReq); - virtual SLANG_NO_THROW SlangResult SLANG_MCALL getPrecompiledTargetCode( - SlangCompileTarget target, - slang::IBlob** outCode, - slang::IBlob** outDiagnostics = nullptr) override; +// Are we generating code for Metal? +bool isMetalTarget(TargetRequest* targetReq); - virtual SLANG_NO_THROW SlangInt SLANG_MCALL getModuleDependencyCount() - SLANG_OVERRIDE; +/// Are we generating code for a Khronos API (OpenGL or Vulkan)? +bool isKhronosTarget(TargetRequest* targetReq); - virtual SLANG_NO_THROW SlangResult SLANG_MCALL getModuleDependency( - SlangInt dependencyIndex, - slang::IModule** outModule, - slang::IBlob** outDiagnostics = nullptr) SLANG_OVERRIDE; +/// Are we generating code for a CUDA API (CUDA / OptiX)? +bool isCUDATarget(TargetRequest* targetReq); - virtual void buildHash(DigestBuilder& builder) SLANG_OVERRIDE; +// Are we generating code for a CPU target +bool isCPUTarget(TargetRequest* targetReq); - virtual slang::DeclReflection* getModuleReflection() SLANG_OVERRIDE; +/// A request to generate output in some target format. +class TargetRequest : public RefObject +{ +public: + TargetRequest(Linkage* linkage, CodeGenTarget format); - void setDigest(SHA1::Digest const& digest) { m_digest = digest; } - SHA1::Digest computeDigest(); + TargetRequest(const TargetRequest& other); - /// Create a module (initially empty). - Module(Linkage* linkage, ASTBuilder* astBuilder = nullptr); + Linkage* getLinkage() { return linkage; } - /// Get the AST for the module (if it has been parsed) - ModuleDecl* getModuleDecl() { return m_moduleDecl; } + Session* getSession(); - /// The the IR for the module (if it has been generated) - IRModule* getIRModule() { return m_irModule; } + CodeGenTarget getTarget() + { + return optionSet.getEnumOption(CompilerOptionName::Target); + } - /// Get the list of other modules this module depends on - List const& getModuleDependencyList() { return m_moduleDependencyList.getModuleList(); } + // TypeLayouts created on the fly by reflection API + struct TypeLayoutKey + { + Type* type; + slang::LayoutRules rules; + HashCode getHashCode() const + { + Hasher hasher; + hasher.hashValue(type); + hasher.hashValue(rules); + return hasher.getResult(); + } + bool operator==(TypeLayoutKey other) const + { + return type == other.type && rules == other.rules; + } + }; + Dictionary> typeLayouts; - /// Get the list of files this module depends on - List const& getFileDependencyList() { return m_fileDependencyList.getFileList(); } + Dictionary>& getTypeLayouts() { return typeLayouts; } - /// Register a module that this module depends on - void addModuleDependency(Module* module); + TypeLayout* getTypeLayout(Type* type, slang::LayoutRules rules); - /// Register a source file that this module depends on - void addFileDependency(SourceFile* sourceFile); + CompilerOptionSet& getOptionSet() { return optionSet; } - void clearFileDependency() { m_fileDependencyList.clear(); } - /// Set the AST for this module. - /// - /// This should only be called once, during creation of the module. - /// - void setModuleDecl(ModuleDecl* moduleDecl);// { m_moduleDecl = moduleDecl; } + CapabilitySet getTargetCaps(); - void setName(String name); - void setName(Name* name) { m_name = name; } - void setPathInfo(PathInfo pathInfo) { m_pathInfo = pathInfo; } + void setTargetCaps(CapabilitySet capSet); - /// Set the IR for this module. - /// - /// This should only be called once, during creation of the module. - /// - void setIRModule(IRModule* irModule) { m_irModule = irModule; } + HLSLToVulkanLayoutOptions* getHLSLToVulkanLayoutOptions(); - Index getEntryPointCount() SLANG_OVERRIDE { return 0; } - RefPtr getEntryPoint(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return nullptr; } - String getEntryPointMangledName(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return String(); } - String getEntryPointNameOverride(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); return String(); } +private: + Linkage* linkage = nullptr; + CompilerOptionSet optionSet; + CapabilitySet cookedCapabilities; + RefPtr hlslToVulkanOptions; +}; - Index getShaderParamCount() SLANG_OVERRIDE { return m_shaderParams.getCount(); } - ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { return m_shaderParams[index]; } +/// Given a target request returns which (if any) intermediate source language is required +/// to produce it. +/// +/// If no intermediate source language is required, will return SourceLanguage::Unknown +SourceLanguage getIntermediateSourceLanguageForTarget(TargetProgram* req); - SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE { return m_specializationParams.getCount(); } - SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE { return m_specializationParams[index]; } +/// Are resource types "bindless" (implemented as ordinary data) on the given `target`? +bool areResourceTypesBindlessOnTarget(TargetRequest* target); - Index getRequirementCount() SLANG_OVERRIDE; - RefPtr getRequirement(Index index) SLANG_OVERRIDE; +// Compute the "effective" profile to use when outputting the given entry point +// for the chosen code-generation target. +// +// The stage of the effective profile will always come from the entry point, while +// the profile version (aka "shader model") will be computed as follows: +// +// - If the entry point and target belong to the same profile family, then take +// the latest version between the two (e.g., if the entry point specified `ps_5_1` +// and the target specifies `sm_5_0` then use `sm_5_1` as the version). +// +// - If the entry point and target disagree on the profile family, always use the +// profile family and version from the target. +// +Profile getEffectiveProfile(EntryPoint* entryPoint, TargetRequest* target); - List const& getModuleDependencies() SLANG_OVERRIDE { return m_moduleDependencyList.getModuleList(); } - List const& getFileDependencies() SLANG_OVERRIDE { return m_fileDependencyList.getFileList(); } - /// Given a mangled name finds the exported NodeBase associated with this module. - /// If not found returns nullptr. - NodeBase* findExportFromMangledName(const UnownedStringSlice& slice); +/// Given a target returns the required downstream compiler +PassThroughMode getDownstreamCompilerRequiredForTarget(CodeGenTarget target); +/// Given a target returns a downstream compiler the prelude should be taken from. +SourceLanguage getDefaultSourceLanguageForDownstreamCompiler(PassThroughMode compiler); - /// Get the ASTBuilder - ASTBuilder* getASTBuilder() { return m_astBuilder; } +/// Get the build tag string +const char* getBuildTagString(); - /// Collect information on the shader parameters of the module. - /// - /// This method should only be called once, after the core - /// structured of the module (its AST and IR) have been created, - /// and before any of the `ComponentType` APIs are used. - /// - /// TODO: We might eventually consider a non-stateful approach - /// to constructing a `Module`. - /// - void _collectShaderParams(); +struct TypeCheckingCache; - void _discoverEntryPoints(DiagnosticSink* sink, const List>& targets); - void _discoverEntryPointsImpl(ContainerDecl* containerDecl, DiagnosticSink* sink, const List>& targets); +struct ContainerTypeKey +{ + slang::TypeReflection* elementType; + slang::ContainerType containerType; + bool operator==(ContainerTypeKey other) const + { + return elementType == other.elementType && containerType == other.containerType; + } + Slang::HashCode getHashCode() const + { + return Slang::combineHash( + Slang::getHashCode(elementType), + Slang::getHashCode(containerType)); + } +}; +/// A dictionary of currently loaded modules. Used by `findOrImportModule` to +/// lookup additional loaded modules. +typedef Dictionary LoadedModuleDictionary; - class ModuleSpecializationInfo : public SpecializationInfo - { - public: - struct GenericArgInfo - { - Decl* paramDecl = nullptr; - Val* argVal = nullptr; - }; - - List genericArgs; - List existentialArgs; - }; +enum ModuleBlobType +{ + Source, + IR +}; - RefPtr findEntryPointByName(UnownedStringSlice const& name); - RefPtr findAndCheckEntryPoint(UnownedStringSlice const& name, SlangStage stage, ISlangBlob** outDiagnostics); - - List>& getEntryPoints() { return m_entryPoints; } - void _addEntryPoint(EntryPoint* entryPoint); - void _processFindDeclsExportSymbolsRec(Decl* decl); - - // Gets the files that has been included into the module. - Dictionary& getIncludedSourceFileMap() { return m_mapSourceFileToFileDecl; } - - protected: - void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) SLANG_OVERRIDE; - - RefPtr _validateSpecializationArgsImpl( - SpecializationArg const* args, - Index argCount, - DiagnosticSink* sink) SLANG_OVERRIDE; - - private: - Name* m_name = nullptr; - PathInfo m_pathInfo; - - // The AST for the module - ModuleDecl* m_moduleDecl = nullptr; - - // The IR for the module - RefPtr m_irModule = nullptr; - - List m_shaderParams; - SpecializationParams m_specializationParams; - - List m_requirements; - - // A digest that uniquely identifies the contents of the module. - SHA1::Digest m_digest; - - // List of modules this module depends on - ModuleDependencyList m_moduleDependencyList; - - // List of source files this module depends on - FileDependencyList m_fileDependencyList; - - // Entry points that were defined in this module - // - // Note: the entry point defined in the module are *not* - // part of the memory image/layout of the module when - // it is considered as an IComponentType. This can be - // a bit confusing, but if all the entry points in the - // module were automatically linked into the component - // type, we'd need a way to access just the global - // scope of the module without the entry points, in - // case we wanted to link a single entry point against - // the global scope. The `Module` type provides exactly - // that "module without its entry points" unit of - // granularity for linking. - // - // This list only exists for lookup purposes, so that - // the user can find an existing entry-point function - // that was defined as part of the module. - // - List> m_entryPoints; - - // The builder that owns all of the AST nodes from parsing the source of - // this module. - RefPtr m_astBuilder; - - // Holds map of exported mangled names to symbols. m_mangledExportPool maps names to indices, - // and m_mangledExportSymbols holds the NodeBase* values for each index. - StringSlicePool m_mangledExportPool; - List m_mangledExportSymbols; - - // Source files that have been pulled into the module with `__include`. - Dictionary m_mapSourceFileToFileDecl; - }; - typedef Module LoadedModule; +struct SerialContainerDataModule; - /// A request for the front-end to compile a translation unit. - class TranslationUnitRequest : public RefObject +/// A context for loading and re-using code modules. +class Linkage : public RefObject, public slang::ISession +{ +public: + SLANG_REF_OBJECT_IUNKNOWN_ALL + + CompilerOptionSet m_optionSet; + + ISlangUnknown* getInterface(const Guid& guid); + + SLANG_NO_THROW slang::IGlobalSession* SLANG_MCALL getGlobalSession() override; + SLANG_NO_THROW slang::IModule* SLANG_MCALL + loadModule(const char* moduleName, slang::IBlob** outDiagnostics = nullptr) override; + slang::IModule* loadModuleFromBlob( + const char* moduleName, + const char* path, + slang::IBlob* source, + ModuleBlobType blobType, + slang::IBlob** outDiagnostics = nullptr); + SLANG_NO_THROW slang::IModule* SLANG_MCALL loadModuleFromIRBlob( + const char* moduleName, + const char* path, + slang::IBlob* source, + slang::IBlob** outDiagnostics = nullptr) override; + SLANG_NO_THROW slang::IModule* SLANG_MCALL loadModuleFromSource( + const char* moduleName, + const char* path, + slang::IBlob* source, + slang::IBlob** outDiagnostics = nullptr) override; + SLANG_NO_THROW slang::IModule* SLANG_MCALL loadModuleFromSourceString( + const char* moduleName, + const char* path, + const char* string, + slang::IBlob** outDiagnostics = nullptr) override; + SLANG_NO_THROW SlangResult SLANG_MCALL createCompositeComponentType( + slang::IComponentType* const* componentTypes, + SlangInt componentTypeCount, + slang::IComponentType** outCompositeComponentType, + ISlangBlob** outDiagnostics = nullptr) override; + SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL specializeType( + slang::TypeReflection* type, + slang::SpecializationArg const* specializationArgs, + SlangInt specializationArgCount, + ISlangBlob** outDiagnostics = nullptr) override; + SLANG_NO_THROW slang::TypeLayoutReflection* SLANG_MCALL getTypeLayout( + slang::TypeReflection* type, + SlangInt targetIndex = 0, + slang::LayoutRules rules = slang::LayoutRules::Default, + ISlangBlob** outDiagnostics = nullptr) override; + SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL getContainerType( + slang::TypeReflection* elementType, + slang::ContainerType containerType, + ISlangBlob** outDiagnostics = nullptr) override; + SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL getDynamicType() override; + SLANG_NO_THROW SlangResult SLANG_MCALL + getTypeRTTIMangledName(slang::TypeReflection* type, ISlangBlob** outNameBlob) override; + SLANG_NO_THROW SlangResult SLANG_MCALL getTypeConformanceWitnessMangledName( + slang::TypeReflection* type, + slang::TypeReflection* interfaceType, + ISlangBlob** outNameBlob) override; + SLANG_NO_THROW SlangResult SLANG_MCALL getTypeConformanceWitnessSequentialID( + slang::TypeReflection* type, + slang::TypeReflection* interfaceType, + uint32_t* outId) override; + SLANG_NO_THROW SlangResult SLANG_MCALL createTypeConformanceComponentType( + slang::TypeReflection* type, + slang::TypeReflection* interfaceType, + slang::ITypeConformance** outConformance, + SlangInt conformanceIdOverride, + ISlangBlob** outDiagnostics) override; + SLANG_NO_THROW SlangResult SLANG_MCALL + createCompileRequest(SlangCompileRequest** outCompileRequest) override; + virtual SLANG_NO_THROW SlangInt SLANG_MCALL getLoadedModuleCount() override; + virtual SLANG_NO_THROW slang::IModule* SLANG_MCALL getLoadedModule(SlangInt index) override; + virtual SLANG_NO_THROW bool SLANG_MCALL + isBinaryModuleUpToDate(const char* modulePath, slang::IBlob* binaryModuleBlob) override; + + // Updates the supplied builder with linkage-related information, which includes preprocessor + // defines, the compiler version, and other compiler options. This is then merged with the hash + // produced for the program to produce a key that can be used with the shader cache. + void buildHash(DigestBuilder& builder, SlangInt targetIndex = -1); + + void addTarget(slang::TargetDesc const& desc); + SlangResult addSearchPath(char const* path); + SlangResult addPreprocessorDefine(char const* name, char const* value); + SlangResult setMatrixLayoutMode(SlangMatrixLayoutMode mode); + /// Create an initially-empty linkage + Linkage(Session* session, ASTBuilder* astBuilder, Linkage* builtinLinkage); + + /// Dtor + ~Linkage(); + + bool isInLanguageServer() { - public: - TranslationUnitRequest( - FrontEndCompileRequest* compileRequest); - TranslationUnitRequest( - FrontEndCompileRequest* compileRequest, Module* m); + return contentAssistInfo.checkingMode != ContentAssistCheckingMode::None; + } + + /// Get the parent session for this linkage + Session* getSessionImpl() { return m_session; } - // The parent compile request - FrontEndCompileRequest* compileRequest = nullptr; + // Information on the targets we are being asked to + // generate code for. + List> targets; - // The language in which the source file(s) - // are assumed to be written - SourceLanguage sourceLanguage = SourceLanguage::Unknown; + // Directories to search for `#include` files or `import`ed modules + SearchDirectoryList& getSearchDirectories(); - /// Makes any source artifact available as a SourceFile. - /// If successful any of the source artifacts will be represented by the same index - /// of sourceArtifacts - SlangResult requireSourceFiles(); + // Source manager to help track files loaded + SourceManager m_defaultSourceManager; + SourceManager* m_sourceManager = nullptr; + RefPtr m_cmdLineContext; - /// Get the source files. - /// Since lazily evaluated requires calling requireSourceFiles to know it's in sync - /// with sourceArtifacts. - List const& getSourceFiles(); - - /// Get the source artifacts associated - const List>& getSourceArtifacts() const { return m_sourceArtifacts; } + // Name pool for looking up names + NamePool namePool; - /// Clear all of the source - void clearSource() { m_sourceArtifacts.clear(); m_sourceFiles.clear(); } + NamePool* getNamePool() { return &namePool; } - /// Add a source artifact - void addSourceArtifact(IArtifact* sourceArtifact); + ASTBuilder* getASTBuilder() { return m_astBuilder; } - /// Add both the artifact and the sourceFile. - void addSource(IArtifact* sourceArtifact, SourceFile* sourceFile); + RefPtr m_astBuilder; - // The entry points associated with this translation unit - List> const& getEntryPoints() { return module->getEntryPoints(); } + // Cache for container types. + Dictionary m_containerTypes; - void _addEntryPoint(EntryPoint* entryPoint) { module->_addEntryPoint(entryPoint); } + // cache used by type checking, implemented in check.cpp + TypeCheckingCache* getTypeCheckingCache(); + void destroyTypeCheckingCache(); - // Preprocessor definitions to use for this translation unit only - // (whereas the ones on `compileRequest` will be shared) - Dictionary preprocessorDefinitions; + TypeCheckingCache* m_typeCheckingCache = nullptr; - /// The name that will be used for the module this translation unit produces. - Name* moduleName = nullptr; + // Modules that have been dynamically loaded via `import` + // + // This is a list of unique modules loaded, in the order they were encountered. + List> loadedModulesList; - /// Result of compiling this translation unit (a module) - RefPtr module; + // Map from the path (or uniqueIdentity if available) of a module file to its definition + Dictionary> mapPathToLoadedModule; - bool isChecked = false; + // Map from the logical name of a module to its definition + Dictionary> mapNameToLoadedModules; - Module* getModule() { return module; } - ModuleDecl* getModuleDecl() { return module->getModuleDecl(); } + // Map from the mangled name of RTTI objects to sequential IDs + // used by `switch`-based dynamic dispatch. + Dictionary mapMangledNameToRTTIObjectIndex; - Session* getSession(); - NamePool* getNamePool(); - SourceManager* getSourceManager(); + // Counters for allocating sequential IDs to witness tables conforming to each interface type. + Dictionary mapInterfaceMangledNameToSequentialIDCounters; - Scope* getLanguageScope(); + SearchDirectoryList searchDirectoryCache; - Dictionary getCombinedPreprocessorDefinitions(); + // The resulting specialized IR module for each entry point request + List> compiledModules; - void setModuleName(Name* name) - { - moduleName = name; - if (module) - module->setName(name); - } + ContentAssistInfo contentAssistInfo; - protected: - void _addSourceFile(SourceFile* sourceFile); - /* Given an artifact, find a PathInfo. - If no PathInfo can be found will return an unknown PathInfo */ - PathInfo _findSourcePathInfo(IArtifact* artifact); - - List> m_sourceArtifacts; - // The source file(s) that will be compiled to form this translation unit - // - // Usually, for HLSL or GLSL there will be only one file. - // NOTE! This member is generated lazily from m_sourceArtifacts - // it is *necessary* to call requireSourceFiles to ensure it's in sync. - List m_sourceFiles; - }; + /// File system implementation to use when loading files from disk. + /// + /// If this member is `null`, a default implementation that tries + /// to use the native OS filesystem will be used instead. + /// + ComPtr m_fileSystem; - enum class FloatingPointMode : SlangFloatingPointModeIntegral - { - Default = SLANG_FLOATING_POINT_MODE_DEFAULT, - Fast = SLANG_FLOATING_POINT_MODE_FAST, - Precise = SLANG_FLOATING_POINT_MODE_PRECISE, - }; + /// The extended file system implementation. Will be set to a default implementation + /// if fileSystem is nullptr. Otherwise it will either be fileSystem's interface, + /// or a wrapped impl that makes fileSystem operate as fileSystemExt + ComPtr m_fileSystemExt; - enum class WriterChannel : SlangWriterChannelIntegral - { - Diagnostic = SLANG_WRITER_CHANNEL_DIAGNOSTIC, - StdOutput = SLANG_WRITER_CHANNEL_STD_OUTPUT, - StdError = SLANG_WRITER_CHANNEL_STD_ERROR, - CountOf = SLANG_WRITER_CHANNEL_COUNT_OF, - }; + /// Get the currenly set file system + ISlangFileSystemExt* getFileSystemExt() { return m_fileSystemExt; } - enum class WriterMode : SlangWriterModeIntegral - { - Text = SLANG_WRITER_MODE_TEXT, - Binary = SLANG_WRITER_MODE_BINARY, - }; + /// Load a file into memory using the configured file system. + /// + /// @param path The path to attempt to load from + /// @param outBlob A destination pointer to receive the loaded blob + /// @returns A `SlangResult` to indicate success or failure. + /// + SlangResult loadFile(String const& path, PathInfo& outPathInfo, ISlangBlob** outBlob); - class TargetRequest; + Expr* parseTermString(String str, Scope* scope); - /// Are we generating code for a D3D API? - bool isD3DTarget(TargetRequest* targetReq); + Type* specializeType( + Type* unspecializedType, + Int argCount, + Type* const* args, + DiagnosticSink* sink); - // Are we generating code for Metal? - bool isMetalTarget(TargetRequest* targetReq); + /// Add a new target and return its index. + UInt addTarget(CodeGenTarget target); + + RefPtr loadModule( + Name* name, + const PathInfo& filePathInfo, + ISlangBlob* fileContentsBlob, + SourceLoc const& loc, + DiagnosticSink* sink, + const LoadedModuleDictionary* additionalLoadedModules, + ModuleBlobType blobType); - /// Are we generating code for a Khronos API (OpenGL or Vulkan)? - bool isKhronosTarget(TargetRequest* targetReq); + RefPtr loadModuleFromIRBlobImpl( + Name* name, + const PathInfo& filePathInfo, + ISlangBlob* fileContentsBlob, + SourceLoc const& loc, + DiagnosticSink* sink, + const LoadedModuleDictionary* additionalLoadedModules); + RefPtr loadDeserializedModule( + Name* name, + const PathInfo& filePathInfo, + SerialContainerDataModule& m, + DiagnosticSink* sink); - /// Are we generating code for a CUDA API (CUDA / OptiX)? - bool isCUDATarget(TargetRequest* targetReq); + SourceFile* loadSourceFile(String pathFrom, String path); - // Are we generating code for a CPU target - bool isCPUTarget(TargetRequest* targetReq); + void loadParsedModule( + RefPtr compileRequest, + RefPtr translationUnit, + Name* name, + PathInfo const& pathInfo); - /// A request to generate output in some target format. - class TargetRequest : public RefObject - { - public: - TargetRequest(Linkage* linkage, CodeGenTarget format); + /// Load a module of the given name. + Module* loadModule(String const& name); - TargetRequest(const TargetRequest& other); + bool isBinaryModuleUpToDate(String fromPath, RiffContainer* container); - Linkage* getLinkage() { return linkage; } - - Session* getSession(); + RefPtr findOrImportModule( + Name* name, + SourceLoc const& loc, + DiagnosticSink* sink, + const LoadedModuleDictionary* loadedModules = nullptr); - CodeGenTarget getTarget() { return optionSet.getEnumOption(CompilerOptionName::Target); } + void prepareDeserializedModule( + SerialContainerDataModule& moduleEntry, + const PathInfo& pathInfo, + Module* module, + DiagnosticSink* sink); - // TypeLayouts created on the fly by reflection API - struct TypeLayoutKey - { - Type* type; - slang::LayoutRules rules; - HashCode getHashCode() const - { - Hasher hasher; - hasher.hashValue(type); - hasher.hashValue(rules); - return hasher.getResult(); - } - bool operator==(TypeLayoutKey other) const - { - return type == other.type && rules == other.rules; - } - }; - Dictionary> typeLayouts; + SourceFile* findFile(Name* name, SourceLoc loc, IncludeSystem& outIncludeSystem); + struct IncludeResult + { + FileDecl* fileDecl; + bool isNew; + }; + IncludeResult findAndIncludeFile( + Module* module, + TranslationUnitRequest* translationUnit, + Name* name, + SourceLoc const& loc, + DiagnosticSink* sink); - Dictionary>& getTypeLayouts() { return typeLayouts; } + SourceManager* getSourceManager() { return m_sourceManager; } - TypeLayout* getTypeLayout(Type* type, slang::LayoutRules rules); + /// Override the source manager for the linkage. + /// + /// This is only used to install a temporary override when + /// parsing stuff from strings (where we don't want to retain + /// full source files for the parsed result). + /// + /// TODO: We should remove the need for this hack. + /// + void setSourceManager(SourceManager* sourceManager) { m_sourceManager = sourceManager; } - CompilerOptionSet& getOptionSet() { return optionSet; } + void setRequireCacheFileSystem(bool requireCacheFileSystem); - CapabilitySet getTargetCaps(); + void setFileSystem(ISlangFileSystem* fileSystem); - void setTargetCaps(CapabilitySet capSet); + DeclRef specializeGeneric( + DeclRef declRef, + List argExprs, + DiagnosticSink* sink); - HLSLToVulkanLayoutOptions* getHLSLToVulkanLayoutOptions(); + DeclRef specializeWithArgTypes( + Expr* funcExpr, + List argTypes, + DiagnosticSink* sink); - private: - Linkage* linkage = nullptr; - CompilerOptionSet optionSet; - CapabilitySet cookedCapabilities; - RefPtr hlslToVulkanOptions; - }; + bool isSpecialized(DeclRef declRef); - /// Given a target request returns which (if any) intermediate source language is required - /// to produce it. - /// - /// If no intermediate source language is required, will return SourceLanguage::Unknown - SourceLanguage getIntermediateSourceLanguageForTarget(TargetProgram* req); + DiagnosticSink::Flags diagnosticSinkFlags = 0; - /// Are resource types "bindless" (implemented as ordinary data) on the given `target`? - bool areResourceTypesBindlessOnTarget(TargetRequest* target); + bool m_requireCacheFileSystem = false; - // Compute the "effective" profile to use when outputting the given entry point - // for the chosen code-generation target. - // - // The stage of the effective profile will always come from the entry point, while - // the profile version (aka "shader model") will be computed as follows: - // - // - If the entry point and target belong to the same profile family, then take - // the latest version between the two (e.g., if the entry point specified `ps_5_1` - // and the target specifies `sm_5_0` then use `sm_5_1` as the version). - // - // - If the entry point and target disagree on the profile family, always use the - // profile family and version from the target. - // - Profile getEffectiveProfile(EntryPoint* entryPoint, TargetRequest* target); + // Modules that have been read in with the -r option + List> m_libModules; + void _stopRetainingParentSession() { m_retainedSession = nullptr; } - /// Given a target returns the required downstream compiler - PassThroughMode getDownstreamCompilerRequiredForTarget(CodeGenTarget target); - /// Given a target returns a downstream compiler the prelude should be taken from. - SourceLanguage getDefaultSourceLanguageForDownstreamCompiler(PassThroughMode compiler); + // Get shared semantics information for reflection purposes. + SharedSemanticsContext* getSemanticsForReflection(); - /// Get the build tag string - const char* getBuildTagString(); +private: + /// The global Slang library session that this linkage is a child of + Session* m_session = nullptr; - struct TypeCheckingCache; + RefPtr m_retainedSession; - struct ContainerTypeKey + /// Tracks state of modules currently being loaded. + /// + /// This information is used to diagnose cases where + /// a user tries to recursively import the same module + /// (possibly along a transitive chain of `import`s). + /// + struct ModuleBeingImportedRAII { - slang::TypeReflection* elementType; - slang::ContainerType containerType; - bool operator==(ContainerTypeKey other) const - { - return elementType == other.elementType && containerType == other.containerType; - } - Slang::HashCode getHashCode() const + public: + ModuleBeingImportedRAII( + Linkage* linkage, + Module* module, + Name* name, + SourceLoc const& importLoc) + : linkage(linkage), module(module), name(name), importLoc(importLoc) { - return Slang::combineHash( - Slang::getHashCode(elementType), Slang::getHashCode(containerType)); + next = linkage->m_modulesBeingImported; + linkage->m_modulesBeingImported = this; } - }; - /// A dictionary of currently loaded modules. Used by `findOrImportModule` to - /// lookup additional loaded modules. - typedef Dictionary LoadedModuleDictionary; + ~ModuleBeingImportedRAII() { linkage->m_modulesBeingImported = next; } - enum ModuleBlobType - { - Source, IR + Linkage* linkage; + Module* module; + Name* name; + SourceLoc importLoc; + ModuleBeingImportedRAII* next; }; - struct SerialContainerDataModule; + // Any modules currently being imported will be listed here + ModuleBeingImportedRAII* m_modulesBeingImported = nullptr; - /// A context for loading and re-using code modules. - class Linkage : public RefObject, public slang::ISession - { - public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - - CompilerOptionSet m_optionSet; - - ISlangUnknown* getInterface(const Guid& guid); - - SLANG_NO_THROW slang::IGlobalSession* SLANG_MCALL getGlobalSession() override; - SLANG_NO_THROW slang::IModule* SLANG_MCALL loadModule( - const char* moduleName, - slang::IBlob** outDiagnostics = nullptr) override; - slang::IModule* loadModuleFromBlob( - const char* moduleName, - const char* path, - slang::IBlob* source, - ModuleBlobType blobType, - slang::IBlob** outDiagnostics = nullptr); - SLANG_NO_THROW slang::IModule* SLANG_MCALL loadModuleFromIRBlob( - const char* moduleName, - const char* path, - slang::IBlob* source, - slang::IBlob** outDiagnostics = nullptr) override; - SLANG_NO_THROW slang::IModule* SLANG_MCALL loadModuleFromSource( - const char* moduleName, - const char* path, - slang::IBlob* source, - slang::IBlob** outDiagnostics = nullptr) override; - SLANG_NO_THROW slang::IModule* SLANG_MCALL loadModuleFromSourceString( - const char* moduleName, - const char* path, - const char* string, - slang::IBlob** outDiagnostics = nullptr) override; - SLANG_NO_THROW SlangResult SLANG_MCALL createCompositeComponentType( - slang::IComponentType* const* componentTypes, - SlangInt componentTypeCount, - slang::IComponentType** outCompositeComponentType, - ISlangBlob** outDiagnostics = nullptr) override; - SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL specializeType( - slang::TypeReflection* type, - slang::SpecializationArg const* specializationArgs, - SlangInt specializationArgCount, - ISlangBlob** outDiagnostics = nullptr) override; - SLANG_NO_THROW slang::TypeLayoutReflection* SLANG_MCALL getTypeLayout( - slang::TypeReflection* type, - SlangInt targetIndex = 0, - slang::LayoutRules rules = slang::LayoutRules::Default, - ISlangBlob** outDiagnostics = nullptr) override; - SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL getContainerType( - slang::TypeReflection* elementType, - slang::ContainerType containerType, - ISlangBlob** outDiagnostics = nullptr) override; - SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL getDynamicType() override; - SLANG_NO_THROW SlangResult SLANG_MCALL getTypeRTTIMangledName( - slang::TypeReflection* type, - ISlangBlob** outNameBlob) override; - SLANG_NO_THROW SlangResult SLANG_MCALL getTypeConformanceWitnessMangledName( - slang::TypeReflection* type, - slang::TypeReflection* interfaceType, - ISlangBlob** outNameBlob) override; - SLANG_NO_THROW SlangResult SLANG_MCALL getTypeConformanceWitnessSequentialID( - slang::TypeReflection* type, - slang::TypeReflection* interfaceType, - uint32_t* outId) override; - SLANG_NO_THROW SlangResult SLANG_MCALL createTypeConformanceComponentType( - slang::TypeReflection* type, - slang::TypeReflection* interfaceType, - slang::ITypeConformance** outConformance, - SlangInt conformanceIdOverride, - ISlangBlob** outDiagnostics) override; - SLANG_NO_THROW SlangResult SLANG_MCALL createCompileRequest( - SlangCompileRequest** outCompileRequest) override; - virtual SLANG_NO_THROW SlangInt SLANG_MCALL getLoadedModuleCount() override; - virtual SLANG_NO_THROW slang::IModule* SLANG_MCALL getLoadedModule(SlangInt index) override; - virtual SLANG_NO_THROW bool SLANG_MCALL isBinaryModuleUpToDate(const char* modulePath, slang::IBlob* binaryModuleBlob) override; - - // Updates the supplied builder with linkage-related information, which includes preprocessor - // defines, the compiler version, and other compiler options. This is then merged with the hash - // produced for the program to produce a key that can be used with the shader cache. - void buildHash(DigestBuilder& builder, SlangInt targetIndex = -1); - - void addTarget( - slang::TargetDesc const& desc); - SlangResult addSearchPath( - char const* path); - SlangResult addPreprocessorDefine( - char const* name, - char const* value); - SlangResult setMatrixLayoutMode( - SlangMatrixLayoutMode mode); - /// Create an initially-empty linkage - Linkage(Session* session, ASTBuilder* astBuilder, Linkage* builtinLinkage); - - /// Dtor - ~Linkage(); - - bool isInLanguageServer() { return contentAssistInfo.checkingMode != ContentAssistCheckingMode::None; } - - /// Get the parent session for this linkage - Session* getSessionImpl() { return m_session; } - - // Information on the targets we are being asked to - // generate code for. - List> targets; - - // Directories to search for `#include` files or `import`ed modules - SearchDirectoryList& getSearchDirectories(); - - // Source manager to help track files loaded - SourceManager m_defaultSourceManager; - SourceManager* m_sourceManager = nullptr; - RefPtr m_cmdLineContext; - - // Name pool for looking up names - NamePool namePool; - - NamePool* getNamePool() { return &namePool; } - - ASTBuilder* getASTBuilder() { return m_astBuilder; } - - RefPtr m_astBuilder; - - // Cache for container types. - Dictionary m_containerTypes; - - // cache used by type checking, implemented in check.cpp - TypeCheckingCache* getTypeCheckingCache(); - void destroyTypeCheckingCache(); - - TypeCheckingCache* m_typeCheckingCache = nullptr; - - // Modules that have been dynamically loaded via `import` - // - // This is a list of unique modules loaded, in the order they were encountered. - List > loadedModulesList; - - // Map from the path (or uniqueIdentity if available) of a module file to its definition - Dictionary> mapPathToLoadedModule; - - // Map from the logical name of a module to its definition - Dictionary> mapNameToLoadedModules; - - // Map from the mangled name of RTTI objects to sequential IDs - // used by `switch`-based dynamic dispatch. - Dictionary mapMangledNameToRTTIObjectIndex; - - // Counters for allocating sequential IDs to witness tables conforming to each interface type. - Dictionary mapInterfaceMangledNameToSequentialIDCounters; - - SearchDirectoryList searchDirectoryCache; - - // The resulting specialized IR module for each entry point request - List> compiledModules; - - ContentAssistInfo contentAssistInfo; - - /// File system implementation to use when loading files from disk. - /// - /// If this member is `null`, a default implementation that tries - /// to use the native OS filesystem will be used instead. - /// - ComPtr m_fileSystem; - - /// The extended file system implementation. Will be set to a default implementation - /// if fileSystem is nullptr. Otherwise it will either be fileSystem's interface, - /// or a wrapped impl that makes fileSystem operate as fileSystemExt - ComPtr m_fileSystemExt; - - /// Get the currenly set file system - ISlangFileSystemExt* getFileSystemExt() { return m_fileSystemExt; } - - /// Load a file into memory using the configured file system. - /// - /// @param path The path to attempt to load from - /// @param outBlob A destination pointer to receive the loaded blob - /// @returns A `SlangResult` to indicate success or failure. - /// - SlangResult loadFile(String const& path, PathInfo& outPathInfo, ISlangBlob** outBlob); - - Expr* parseTermString(String str, Scope* scope); - - Type* specializeType( - Type* unspecializedType, - Int argCount, - Type* const* args, - DiagnosticSink* sink); - - /// Add a new target and return its index. - UInt addTarget( - CodeGenTarget target); - - RefPtr loadModule( - Name* name, - const PathInfo& filePathInfo, - ISlangBlob* fileContentsBlob, - SourceLoc const& loc, - DiagnosticSink* sink, - const LoadedModuleDictionary* additionalLoadedModules, - ModuleBlobType blobType); - - RefPtr loadModuleFromIRBlobImpl( - Name* name, - const PathInfo& filePathInfo, - ISlangBlob* fileContentsBlob, - SourceLoc const& loc, - DiagnosticSink* sink, - const LoadedModuleDictionary* additionalLoadedModules); - RefPtr loadDeserializedModule( - Name* name, - const PathInfo& filePathInfo, - SerialContainerDataModule& m, - DiagnosticSink* sink); + /// Is the given module in the middle of being imported? + bool isBeingImported(Module* module); - SourceFile* loadSourceFile(String pathFrom, String path); + /// Diagnose that an error occured in the process of importing a module + void _diagnoseErrorInImportedModule(DiagnosticSink* sink); - void loadParsedModule( - RefPtr compileRequest, - RefPtr translationUnit, - Name* name, - PathInfo const& pathInfo); + List m_specializedTypes; - /// Load a module of the given name. - Module* loadModule(String const& name); + RefPtr m_semanticsForReflection; +}; - bool isBinaryModuleUpToDate(String fromPath, RiffContainer* container); +/// Shared functionality between front- and back-end compile requests. +/// +/// This is the base class for both `FrontEndCompileRequest` and +/// `BackEndCompileRequest`, and allows a small number of parts of +/// the compiler to be easily invocable from either front-end or +/// back-end work. +/// +class CompileRequestBase : public RefObject +{ + // TODO: We really shouldn't need this type in the long run. + // The few places that rely on it should be refactored to just + // depend on the underlying information (a linkage and a diagnostic + // sink) directly. + // + // The flags to control dumping and validation of IR should be + // moved to some kind of shared settings/options `struct` that + // both front-end and back-end requests can store. + +public: + Session* getSession(); + Linkage* getLinkage() { return m_linkage; } + DiagnosticSink* getSink() { return m_sink; } + SourceManager* getSourceManager() { return getLinkage()->getSourceManager(); } + NamePool* getNamePool() { return getLinkage()->getNamePool(); } + ISlangFileSystemExt* getFileSystemExt() { return getLinkage()->getFileSystemExt(); } + SlangResult loadFile(String const& path, PathInfo& outPathInfo, ISlangBlob** outBlob) + { + return getLinkage()->loadFile(path, outPathInfo, outBlob); + } - RefPtr findOrImportModule( - Name* name, - SourceLoc const& loc, - DiagnosticSink* sink, - const LoadedModuleDictionary* loadedModules = nullptr); +protected: + CompileRequestBase(Linkage* linkage, DiagnosticSink* sink); - void prepareDeserializedModule(SerialContainerDataModule& moduleEntry, const PathInfo& pathInfo, Module* module, DiagnosticSink* sink); +private: + Linkage* m_linkage = nullptr; + DiagnosticSink* m_sink = nullptr; +}; - SourceFile* findFile(Name* name, SourceLoc loc, IncludeSystem& outIncludeSystem); - struct IncludeResult - { - FileDecl* fileDecl; - bool isNew; - }; - IncludeResult findAndIncludeFile(Module* module, TranslationUnitRequest* translationUnit, Name* name, SourceLoc const& loc, DiagnosticSink* sink); +/// A request to compile source code to an AST + IR. +class FrontEndCompileRequest : public CompileRequestBase +{ +public: + /// Note that writers can be parsed as nullptr to disable output, + /// and individual channels set to null to disable them + FrontEndCompileRequest(Linkage* linkage, StdWriters* writers, DiagnosticSink* sink); - SourceManager* getSourceManager() - { - return m_sourceManager; - } + int addEntryPoint(int translationUnitIndex, String const& name, Profile entryPointProfile); - /// Override the source manager for the linkage. - /// - /// This is only used to install a temporary override when - /// parsing stuff from strings (where we don't want to retain - /// full source files for the parsed result). - /// - /// TODO: We should remove the need for this hack. - /// - void setSourceManager(SourceManager* sourceManager) - { - m_sourceManager = sourceManager; - } + // Translation units we are being asked to compile + List> translationUnits; - void setRequireCacheFileSystem(bool requireCacheFileSystem); + // Additional modules that needs to be made visible to `import` while checking. + const LoadedModuleDictionary* additionalLoadedModules = nullptr; - void setFileSystem(ISlangFileSystem* fileSystem); + RefPtr getTranslationUnit(UInt index) + { + return translationUnits[index]; + } - DeclRef specializeGeneric( - DeclRef declRef, - List argExprs, - DiagnosticSink* sink); - - DeclRef specializeWithArgTypes( - Expr* funcExpr, - List argTypes, - DiagnosticSink* sink); - - bool isSpecialized(DeclRef declRef); + // If true then generateIR will serialize out IR, and serialize back in again. Making + // serialization a bottleneck or firewall between the front end and the backend + bool useSerialIRBottleneck = false; - DiagnosticSink::Flags diagnosticSinkFlags = 0; + // If true will serialize and de-serialize with debug information + bool verifyDebugSerialization = false; - bool m_requireCacheFileSystem = false; + CompilerOptionSet optionSet; - // Modules that have been read in with the -r option - List> m_libModules; + List> m_entryPointReqs; - void _stopRetainingParentSession() - { - m_retainedSession = nullptr; - } + List> const& getEntryPointReqs() { return m_entryPointReqs; } + UInt getEntryPointReqCount() { return m_entryPointReqs.getCount(); } + FrontEndEntryPointRequest* getEntryPointReq(UInt index) { return m_entryPointReqs[index]; } - // Get shared semantics information for reflection purposes. - SharedSemanticsContext* getSemanticsForReflection(); + void parseTranslationUnit(TranslationUnitRequest* translationUnit); - private: - /// The global Slang library session that this linkage is a child of - Session* m_session = nullptr; + // Perform primary semantic checking on all + // of the translation units in the program + void checkAllTranslationUnits(); - RefPtr m_retainedSession; + void checkEntryPoints(); - /// Tracks state of modules currently being loaded. - /// - /// This information is used to diagnose cases where - /// a user tries to recursively import the same module - /// (possibly along a transitive chain of `import`s). - /// - struct ModuleBeingImportedRAII - { - public: - ModuleBeingImportedRAII( - Linkage* linkage, - Module* module, - Name* name, - SourceLoc const& importLoc) - : linkage(linkage) - , module(module) - , name(name) - , importLoc(importLoc) - { - next = linkage->m_modulesBeingImported; - linkage->m_modulesBeingImported = this; - } - - ~ModuleBeingImportedRAII() - { - linkage->m_modulesBeingImported = next; - } - - Linkage* linkage; - Module* module; - Name* name; - SourceLoc importLoc; - ModuleBeingImportedRAII* next; - }; + void generateIR(); - // Any modules currently being imported will be listed here - ModuleBeingImportedRAII*m_modulesBeingImported = nullptr; + SlangResult executeActionsInner(); - /// Is the given module in the middle of being imported? - bool isBeingImported(Module* module); + /// Add a translation unit to be compiled. + /// + /// @param language The source language that the translation unit will use (e.g., + /// `SourceLanguage::Slang` + /// @param moduleName The name that will be used for the module compile from the translation + /// unit. + /// + /// If moduleName is passed as nullptr a module name is generated. + /// If all translation units in a compile request use automatically generated + /// module names, then they are guaranteed not to conflict with one another. + /// + /// @return The zero-based index of the translation unit in this compile request. + int addTranslationUnit(SourceLanguage language, Name* moduleName); - /// Diagnose that an error occured in the process of importing a module - void _diagnoseErrorInImportedModule( - DiagnosticSink* sink); + int addTranslationUnit(TranslationUnitRequest* translationUnit); - List m_specializedTypes; + void addTranslationUnitSourceArtifact(int translationUnitIndex, IArtifact* sourceArtifact); - RefPtr m_semanticsForReflection; + void addTranslationUnitSourceBlob( + int translationUnitIndex, + String const& path, + ISlangBlob* sourceBlob); - }; + void addTranslationUnitSourceFile(int translationUnitIndex, String const& path); - /// Shared functionality between front- and back-end compile requests. - /// - /// This is the base class for both `FrontEndCompileRequest` and - /// `BackEndCompileRequest`, and allows a small number of parts of - /// the compiler to be easily invocable from either front-end or - /// back-end work. - /// - class CompileRequestBase : public RefObject - { - // TODO: We really shouldn't need this type in the long run. - // The few places that rely on it should be refactored to just - // depend on the underlying information (a linkage and a diagnostic - // sink) directly. - // - // The flags to control dumping and validation of IR should be - // moved to some kind of shared settings/options `struct` that - // both front-end and back-end requests can store. + /// Get a component type that represents the global scope of the compile request. + ComponentType* getGlobalComponentType() { return m_globalComponentType; } - public: - Session* getSession(); - Linkage* getLinkage() { return m_linkage; } - DiagnosticSink* getSink() { return m_sink; } - SourceManager* getSourceManager() { return getLinkage()->getSourceManager(); } - NamePool* getNamePool() { return getLinkage()->getNamePool(); } - ISlangFileSystemExt* getFileSystemExt() { return getLinkage()->getFileSystemExt(); } - SlangResult loadFile(String const& path, PathInfo& outPathInfo, ISlangBlob** outBlob) { return getLinkage()->loadFile(path, outPathInfo, outBlob); } - - protected: - CompileRequestBase( - Linkage* linkage, - DiagnosticSink* sink); - - private: - Linkage* m_linkage = nullptr; - DiagnosticSink* m_sink = nullptr; - }; + /// Get a component type that represents the global scope of the compile request, plus the + /// requested entry points. + ComponentType* getGlobalAndEntryPointsComponentType() + { + return m_globalAndEntryPointsComponentType; + } - /// A request to compile source code to an AST + IR. - class FrontEndCompileRequest : public CompileRequestBase + List> const& getUnspecializedEntryPoints() { - public: - /// Note that writers can be parsed as nullptr to disable output, - /// and individual channels set to null to disable them - FrontEndCompileRequest( - Linkage* linkage, - StdWriters* writers, - DiagnosticSink* sink); + return m_unspecializedEntryPoints; + } - int addEntryPoint( - int translationUnitIndex, - String const& name, - Profile entryPointProfile); + /// Does the code we are compiling represent part of the Slang core module? + bool m_isCoreModuleCode = false; - // Translation units we are being asked to compile - List > translationUnits; + Name* m_defaultModuleName = nullptr; - // Additional modules that needs to be made visible to `import` while checking. - const LoadedModuleDictionary* additionalLoadedModules = nullptr; + /// The irDumpOptions + IRDumpOptions m_irDumpOptions; - RefPtr getTranslationUnit(UInt index) { return translationUnits[index]; } + /// An "extra" entry point that was added via a library reference + struct ExtraEntryPointInfo + { + Name* name; + Profile profile; + String mangledName; + }; - // If true then generateIR will serialize out IR, and serialize back in again. Making - // serialization a bottleneck or firewall between the front end and the backend - bool useSerialIRBottleneck = false; + /// A list of "extra" entry points added via a library reference + List m_extraEntryPoints; - // If true will serialize and de-serialize with debug information - bool verifyDebugSerialization = false; +private: + /// A component type that includes only the global scopes of the translation unit(s) that were + /// compiled. + RefPtr m_globalComponentType; - CompilerOptionSet optionSet; + /// A component type that extends the global scopes with all of the entry points that were + /// specified. + RefPtr m_globalAndEntryPointsComponentType; - List> m_entryPointReqs; + List> m_unspecializedEntryPoints; - List> const& getEntryPointReqs() { return m_entryPointReqs; } - UInt getEntryPointReqCount() { return m_entryPointReqs.getCount(); } - FrontEndEntryPointRequest* getEntryPointReq(UInt index) { return m_entryPointReqs[index]; } + RefPtr m_writers; +}; - void parseTranslationUnit( - TranslationUnitRequest* translationUnit); +/// A visitor for use with `ComponentType`s, allowing dispatch over the concrete subclasses. +class ComponentTypeVisitor +{ +public: + // The following methods should be overriden in a concrete subclass + // to customize how it acts on each of the concrete types of component. + // + // In cases where the application wants to simply "recurse" on a + // composite, specialized, or legacy component type it can use + // the `visitChildren` methods below. + // + virtual void visitEntryPoint( + EntryPoint* entryPoint, + EntryPoint::EntryPointSpecializationInfo* specializationInfo) = 0; + virtual void visitModule( + Module* module, + Module::ModuleSpecializationInfo* specializationInfo) = 0; + virtual void visitComposite( + CompositeComponentType* composite, + CompositeComponentType::CompositeSpecializationInfo* specializationInfo) = 0; + virtual void visitSpecialized(SpecializedComponentType* specialized) = 0; + virtual void visitTypeConformance(TypeConformance* conformance) = 0; + virtual void visitRenamedEntryPoint( + RenamedEntryPointComponentType* renamedEntryPoint, + EntryPoint::EntryPointSpecializationInfo* specializationInfo) = 0; + +protected: + // These helpers can be used to recurse into the logical children of a + // component type, and are useful for the common case where a visitor + // only cares about a few leaf cases. + // + void visitChildren( + CompositeComponentType* composite, + CompositeComponentType::CompositeSpecializationInfo* specializationInfo); + void visitChildren(SpecializedComponentType* specialized); +}; - // Perform primary semantic checking on all - // of the translation units in the program - void checkAllTranslationUnits(); +/// A `TargetProgram` represents a `ComponentType` specialized for a particular `TargetRequest` +/// +/// TODO: This should probably be renamed to `TargetComponentType`. +/// +/// By binding a component type to a specific target, a `TargetProgram` allows +/// for things like layout to be computed, that fundamentally depend on +/// the choice of target. +/// +/// A `TargetProgram` handles request for compiled kernel code for +/// entry point functions. In practice, kernel code can only be +/// correctly generated when the underlying `ComponentType` is "fully linked" +/// (has no remaining unsatisfied requirements). +/// +class TargetProgram : public RefObject +{ +public: + TargetProgram(ComponentType* componentType, TargetRequest* targetReq); + + /// Get the underlying program + ComponentType* getProgram() { return m_program; } + + /// Get the underlying target + TargetRequest* getTargetReq() { return m_targetReq; } + + /// Get the layout for the program on the target. + /// + /// If this is the first time the layout has been + /// requested, report any errors that arise during + /// layout to the given `sink`. + /// + ProgramLayout* getOrCreateLayout(DiagnosticSink* sink); + + /// Get the layout for the program on the target. + /// + /// This routine assumes that `getOrCreateLayout` + /// has already been called previously. + /// + ProgramLayout* getExistingLayout() + { + SLANG_ASSERT(m_layout); + return m_layout; + } - void checkEntryPoints(); + /// Get the compiled code for an entry point on the target. + /// + /// If this is the first time that code generation has + /// been requested, report any errors that arise during + /// code generation to the given `sink`. + /// + IArtifact* getOrCreateEntryPointResult(Int entryPointIndex, DiagnosticSink* sink); + IArtifact* getOrCreateWholeProgramResult(DiagnosticSink* sink); + + IArtifact* getExistingWholeProgramResult() { return m_wholeProgramResult; } + /// Get the compiled code for an entry point on the target. + /// + /// This routine assumes that `getOrCreateEntryPointResult` + /// has already been called previously. + /// + IArtifact* getExistingEntryPointResult(Int entryPointIndex) + { + return m_entryPointResults[entryPointIndex]; + } - void generateIR(); + IArtifact* _createWholeProgramResult( + DiagnosticSink* sink, + EndToEndCompileRequest* endToEndReq = nullptr); + + /// Internal helper for `getOrCreateEntryPointResult`. + /// + /// This is used so that command-line and API-based + /// requests for code can bottleneck through the same place. + /// + /// Shouldn't be called directly by most code. + /// + IArtifact* _createEntryPointResult( + Int entryPointIndex, + DiagnosticSink* sink, + EndToEndCompileRequest* endToEndReq = nullptr); - SlangResult executeActionsInner(); + RefPtr getOrCreateIRModuleForLayout(DiagnosticSink* sink); - /// Add a translation unit to be compiled. - /// - /// @param language The source language that the translation unit will use (e.g., `SourceLanguage::Slang` - /// @param moduleName The name that will be used for the module compile from the translation unit. - /// - /// If moduleName is passed as nullptr a module name is generated. - /// If all translation units in a compile request use automatically generated - /// module names, then they are guaranteed not to conflict with one another. - /// - /// @return The zero-based index of the translation unit in this compile request. - int addTranslationUnit(SourceLanguage language, Name* moduleName); + RefPtr getExistingIRModuleForLayout() { return m_irModuleForLayout; } - int addTranslationUnit(TranslationUnitRequest* translationUnit); + CompilerOptionSet& getOptionSet() { return m_optionSet; } - void addTranslationUnitSourceArtifact( - int translationUnitIndex, - IArtifact* sourceArtifact); + HLSLToVulkanLayoutOptions* getHLSLToVulkanLayoutOptions() + { + return m_targetReq->getHLSLToVulkanLayoutOptions(); + } - void addTranslationUnitSourceBlob( - int translationUnitIndex, - String const& path, - ISlangBlob* sourceBlob); + bool shouldEmitSPIRVDirectly() + { + return isKhronosTarget(m_targetReq) && getOptionSet().shouldEmitSPIRVDirectly(); + } - void addTranslationUnitSourceFile( - int translationUnitIndex, - String const& path); +private: + RefPtr createIRModuleForLayout(DiagnosticSink* sink); - /// Get a component type that represents the global scope of the compile request. - ComponentType* getGlobalComponentType() { return m_globalComponentType; } + // The program being compiled or laid out + ComponentType* m_program; - /// Get a component type that represents the global scope of the compile request, plus the requested entry points. - ComponentType* getGlobalAndEntryPointsComponentType() { return m_globalAndEntryPointsComponentType; } + // The target that code/layout will be generated for + TargetRequest* m_targetReq; - List> const& getUnspecializedEntryPoints() { return m_unspecializedEntryPoints; } + // The computed layout, if it has been generated yet + RefPtr m_layout; - /// Does the code we are compiling represent part of the Slang core module? - bool m_isCoreModuleCode = false; + CompilerOptionSet m_optionSet; - Name* m_defaultModuleName = nullptr; + // Generated compile results for each entry point + // in the parent `Program` (indexing matches + // the order they are given in the `Program`) + ComPtr m_wholeProgramResult; + List> m_entryPointResults; - /// The irDumpOptions - IRDumpOptions m_irDumpOptions; + RefPtr m_irModuleForLayout; +}; - /// An "extra" entry point that was added via a library reference - struct ExtraEntryPointInfo - { - Name* name; - Profile profile; - String mangledName; - }; +/// A back-end-specific object to track optional feaures/capabilities/extensions +/// that are discovered to be used by a program/kernel as part of code generation. +class ExtensionTracker : public RefObject +{ + // TODO: The existence of this type is evidence of a design/architecture problem. + // + // A better formulation of things requires a few key changes: + // + // 1. All optional capabilities need to be enumerated as part of the `CapabilitySet` + // system, so that they can be reasoned about uniformly across different targets + // and different layers of the compiler. + // + // 2. The front-end should be responsible for either or both of: + // + // * Checking that `public` or otherwise externally-visible items (declarations/definitions) + // explicitly declare the capabilities they require, and that they only ever + // make use of items that are comatible with those required capabilities. + // + // * Inferring the capabilities required by items that are not externally visible, + // and attaching those capabilities explicit as a modifier or other synthesized AST node. + // + // 3. The capabilities required by a given `ComponentType` and its entry points should be + // explicitly know-able, and they should be something we can compare to the capabilities + // of a code generation target *before* back-end code generation is started. We should be + // able to issue error messages around lacking capabilities in a way the user can understand, + // in terms of the high-level-language entities. - /// A list of "extra" entry points added via a library reference - List m_extraEntryPoints; +public: +}; - private: - /// A component type that includes only the global scopes of the translation unit(s) that were compiled. - RefPtr m_globalComponentType; +/// A context for code generation in the compiler back-end +struct CodeGenContext +{ +public: + typedef List EntryPointIndices; - /// A component type that extends the global scopes with all of the entry points that were specified. - RefPtr m_globalAndEntryPointsComponentType; + struct Shared + { + public: + Shared( + TargetProgram* targetProgram, + EntryPointIndices const& entryPointIndices, + DiagnosticSink* sink, + EndToEndCompileRequest* endToEndReq) + : targetProgram(targetProgram) + , entryPointIndices(entryPointIndices) + , sink(sink) + , endToEndReq(endToEndReq) + { + } - List> m_unspecializedEntryPoints; + // Shared( + // TargetProgram* targetProgram, + // EndToEndCompileRequest* endToEndReq); - RefPtr m_writers; + TargetProgram* targetProgram = nullptr; + EntryPointIndices entryPointIndices; + DiagnosticSink* sink = nullptr; + EndToEndCompileRequest* endToEndReq = nullptr; }; - /// A visitor for use with `ComponentType`s, allowing dispatch over the concrete subclasses. - class ComponentTypeVisitor + CodeGenContext(Shared* shared) + : m_shared(shared), m_targetFormat(shared->targetProgram->getTargetReq()->getTarget()) { - public: - // The following methods should be overriden in a concrete subclass - // to customize how it acts on each of the concrete types of component. - // - // In cases where the application wants to simply "recurse" on a - // composite, specialized, or legacy component type it can use - // the `visitChildren` methods below. - // - virtual void visitEntryPoint(EntryPoint* entryPoint, EntryPoint::EntryPointSpecializationInfo* specializationInfo) = 0; - virtual void visitModule(Module* module, Module::ModuleSpecializationInfo* specializationInfo) = 0; - virtual void visitComposite(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo) = 0; - virtual void visitSpecialized(SpecializedComponentType* specialized) = 0; - virtual void visitTypeConformance(TypeConformance* conformance) = 0; - virtual void visitRenamedEntryPoint( - RenamedEntryPointComponentType* renamedEntryPoint, - EntryPoint::EntryPointSpecializationInfo* specializationInfo) = 0; - - protected: - // These helpers can be used to recurse into the logical children of a - // component type, and are useful for the common case where a visitor - // only cares about a few leaf cases. - // - void visitChildren(CompositeComponentType* composite, CompositeComponentType::CompositeSpecializationInfo* specializationInfo); - void visitChildren(SpecializedComponentType* specialized); - }; + } - /// A `TargetProgram` represents a `ComponentType` specialized for a particular `TargetRequest` - /// - /// TODO: This should probably be renamed to `TargetComponentType`. - /// - /// By binding a component type to a specific target, a `TargetProgram` allows - /// for things like layout to be computed, that fundamentally depend on - /// the choice of target. - /// - /// A `TargetProgram` handles request for compiled kernel code for - /// entry point functions. In practice, kernel code can only be - /// correctly generated when the underlying `ComponentType` is "fully linked" - /// (has no remaining unsatisfied requirements). - /// - class TargetProgram : public RefObject + CodeGenContext( + CodeGenContext* base, + CodeGenTarget targetFormat, + ExtensionTracker* extensionTracker = nullptr) + : m_shared(base->m_shared) + , m_targetFormat(targetFormat) + , m_extensionTracker(extensionTracker) { - public: - TargetProgram( - ComponentType* componentType, - TargetRequest* targetReq); - - /// Get the underlying program - ComponentType* getProgram() { return m_program; } - - /// Get the underlying target - TargetRequest* getTargetReq() { return m_targetReq; } - - /// Get the layout for the program on the target. - /// - /// If this is the first time the layout has been - /// requested, report any errors that arise during - /// layout to the given `sink`. - /// - ProgramLayout* getOrCreateLayout(DiagnosticSink* sink); - - /// Get the layout for the program on the target. - /// - /// This routine assumes that `getOrCreateLayout` - /// has already been called previously. - /// - ProgramLayout* getExistingLayout() - { - SLANG_ASSERT(m_layout); - return m_layout; - } - - /// Get the compiled code for an entry point on the target. - /// - /// If this is the first time that code generation has - /// been requested, report any errors that arise during - /// code generation to the given `sink`. - /// - IArtifact* getOrCreateEntryPointResult(Int entryPointIndex, DiagnosticSink* sink); - IArtifact* getOrCreateWholeProgramResult(DiagnosticSink* sink); + } - IArtifact* getExistingWholeProgramResult() - { - return m_wholeProgramResult; - } - /// Get the compiled code for an entry point on the target. - /// - /// This routine assumes that `getOrCreateEntryPointResult` - /// has already been called previously. - /// - IArtifact* getExistingEntryPointResult(Int entryPointIndex) - { - return m_entryPointResults[entryPointIndex]; - } + /// Get the diagnostic sink + DiagnosticSink* getSink() { return m_shared->sink; } - IArtifact* _createWholeProgramResult( - DiagnosticSink* sink, - EndToEndCompileRequest* endToEndReq = nullptr); - - /// Internal helper for `getOrCreateEntryPointResult`. - /// - /// This is used so that command-line and API-based - /// requests for code can bottleneck through the same place. - /// - /// Shouldn't be called directly by most code. - /// - IArtifact* _createEntryPointResult( - Int entryPointIndex, - DiagnosticSink* sink, - EndToEndCompileRequest* endToEndReq = nullptr); - - RefPtr getOrCreateIRModuleForLayout(DiagnosticSink* sink); - - RefPtr getExistingIRModuleForLayout() - { - return m_irModuleForLayout; - } + TargetProgram* getTargetProgram() { return m_shared->targetProgram; } - CompilerOptionSet& getOptionSet() { return m_optionSet; } + EntryPointIndices const& getEntryPointIndices() { return m_shared->entryPointIndices; } - HLSLToVulkanLayoutOptions* getHLSLToVulkanLayoutOptions() { return m_targetReq->getHLSLToVulkanLayoutOptions(); } + CodeGenTarget getTargetFormat() { return m_targetFormat; } - bool shouldEmitSPIRVDirectly() - { - return isKhronosTarget(m_targetReq) && getOptionSet().shouldEmitSPIRVDirectly(); - } + ExtensionTracker* getExtensionTracker() { return m_extensionTracker; } - private: - RefPtr createIRModuleForLayout(DiagnosticSink* sink); + TargetRequest* getTargetReq() { return getTargetProgram()->getTargetReq(); } - // The program being compiled or laid out - ComponentType* m_program; + CapabilitySet getTargetCaps() { return getTargetReq()->getTargetCaps(); } - // The target that code/layout will be generated for - TargetRequest* m_targetReq; + CodeGenTarget getFinalTargetFormat() { return getTargetReq()->getTarget(); } - // The computed layout, if it has been generated yet - RefPtr m_layout; + ComponentType* getProgram() { return getTargetProgram()->getProgram(); } - CompilerOptionSet m_optionSet; + Linkage* getLinkage() { return getProgram()->getLinkage(); } - // Generated compile results for each entry point - // in the parent `Program` (indexing matches - // the order they are given in the `Program`) - ComPtr m_wholeProgramResult; - List> m_entryPointResults; + Session* getSession() { return getLinkage()->getSessionImpl(); } - RefPtr m_irModuleForLayout; - }; + /// Get the source manager + SourceManager* getSourceManager() { return getLinkage()->getSourceManager(); } - /// A back-end-specific object to track optional feaures/capabilities/extensions - /// that are discovered to be used by a program/kernel as part of code generation. - class ExtensionTracker : public RefObject - { - // TODO: The existence of this type is evidence of a design/architecture problem. - // - // A better formulation of things requires a few key changes: - // - // 1. All optional capabilities need to be enumerated as part of the `CapabilitySet` - // system, so that they can be reasoned about uniformly across different targets - // and different layers of the compiler. - // - // 2. The front-end should be responsible for either or both of: - // - // * Checking that `public` or otherwise externally-visible items (declarations/definitions) - // explicitly declare the capabilities they require, and that they only ever - // make use of items that are comatible with those required capabilities. - // - // * Inferring the capabilities required by items that are not externally visible, - // and attaching those capabilities explicit as a modifier or other synthesized AST node. - // - // 3. The capabilities required by a given `ComponentType` and its entry points should be - // explicitly know-able, and they should be something we can compare to the capabilities - // of a code generation target *before* back-end code generation is started. We should be - // able to issue error messages around lacking capabilities in a way the user can understand, - // in terms of the high-level-language entities. + ISlangFileSystemExt* getFileSystemExt() { return getLinkage()->getFileSystemExt(); } - public: - }; + EndToEndCompileRequest* isEndToEndCompile() { return m_shared->endToEndReq; } - /// A context for code generation in the compiler back-end - struct CodeGenContext - { - public: - typedef List EntryPointIndices; + EndToEndCompileRequest* isPassThroughEnabled(); - struct Shared - { - public: - Shared( - TargetProgram* targetProgram, - EntryPointIndices const& entryPointIndices, - DiagnosticSink* sink, - EndToEndCompileRequest* endToEndReq) - : targetProgram(targetProgram) - , entryPointIndices(entryPointIndices) - , sink(sink) - , endToEndReq(endToEndReq) - {} - -// Shared( -// TargetProgram* targetProgram, -// EndToEndCompileRequest* endToEndReq); - - TargetProgram* targetProgram = nullptr; - EntryPointIndices entryPointIndices; - DiagnosticSink* sink = nullptr; - EndToEndCompileRequest* endToEndReq = nullptr; - }; + Count getEntryPointCount() { return getEntryPointIndices().getCount(); } - CodeGenContext( - Shared* shared) - : m_shared(shared) - , m_targetFormat(shared->targetProgram->getTargetReq()->getTarget()) - {} - - CodeGenContext( - CodeGenContext* base, - CodeGenTarget targetFormat, - ExtensionTracker* extensionTracker = nullptr) - : m_shared(base->m_shared) - , m_targetFormat(targetFormat) - , m_extensionTracker(extensionTracker) - {} - - /// Get the diagnostic sink - DiagnosticSink* getSink() - { - return m_shared->sink; - } + EntryPoint* getEntryPoint(Index index) { return getProgram()->getEntryPoint(index); } - TargetProgram* getTargetProgram() - { - return m_shared->targetProgram; - } + Index getSingleEntryPointIndex() + { + SLANG_ASSERT(getEntryPointCount() == 1); + return getEntryPointIndices()[0]; + } - EntryPointIndices const& getEntryPointIndices() - { - return m_shared->entryPointIndices; - } + // - CodeGenTarget getTargetFormat() - { - return m_targetFormat; - } + IRDumpOptions getIRDumpOptions(); - ExtensionTracker* getExtensionTracker() - { - return m_extensionTracker; - } + bool shouldValidateIR(); + bool shouldDumpIR(); + bool shouldReportCheckpointIntermediates(); - TargetRequest* getTargetReq() - { - return getTargetProgram()->getTargetReq(); - } + bool shouldTrackLiveness(); - CapabilitySet getTargetCaps() - { - return getTargetReq()->getTargetCaps(); - } + bool shouldDumpIntermediates(); + String getIntermediateDumpPrefix(); - CodeGenTarget getFinalTargetFormat() - { - return getTargetReq()->getTarget(); - } + bool getUseUnknownImageFormatAsDefault(); - ComponentType* getProgram() - { - return getTargetProgram()->getProgram(); - } + bool isSpecializationDisabled(); - Linkage* getLinkage() - { - return getProgram()->getLinkage(); - } + bool shouldSkipSPIRVValidation(); - Session* getSession() - { - return getLinkage()->getSessionImpl(); - } + SlangResult requireTranslationUnitSourceFiles(); - /// Get the source manager - SourceManager* getSourceManager() - { - return getLinkage()->getSourceManager(); - } + // - ISlangFileSystemExt* getFileSystemExt() - { - return getLinkage()->getFileSystemExt(); - } + SlangResult emitEntryPoints(ComPtr& outArtifact); - EndToEndCompileRequest* isEndToEndCompile() - { - return m_shared->endToEndReq; - } + SlangResult emitPrecompiledDownstreamIR(ComPtr& outArtifact); - EndToEndCompileRequest* isPassThroughEnabled(); + void maybeDumpIntermediate(IArtifact* artifact); - Count getEntryPointCount() - { - return getEntryPointIndices().getCount(); - } + // Used to cause instructions available in precompiled blobs to be + // removed between IR linking and target source generation. + bool removeAvailableInDownstreamIR = false; - EntryPoint* getEntryPoint(Index index) - { - return getProgram()->getEntryPoint(index); - } +protected: + CodeGenTarget m_targetFormat = CodeGenTarget::Unknown; + ExtensionTracker* m_extensionTracker = nullptr; - Index getSingleEntryPointIndex() - { - SLANG_ASSERT(getEntryPointCount() == 1); - return getEntryPointIndices()[0]; - } + /// Will output assembly as well as the artifact if appropriate for the artifact type for + /// assembly output and conversion is possible + void _dumpIntermediateMaybeWithAssembly(IArtifact* artifact); - // + void _dumpIntermediate(IArtifact* artifact); + void _dumpIntermediate(const ArtifactDesc& desc, void const* data, size_t size); - IRDumpOptions getIRDumpOptions(); + /* Emits entry point source taking into account if a pass-through or not. Uses 'targetFormat' to + determine the target (not targetReq) */ + SlangResult emitEntryPointsSource(ComPtr& outArtifact); - bool shouldValidateIR(); - bool shouldDumpIR(); - bool shouldReportCheckpointIntermediates(); + SlangResult emitEntryPointsSourceFromIR(ComPtr& outArtifact); - bool shouldTrackLiveness(); + SlangResult emitWithDownstreamForEntryPoints(ComPtr& outArtifact); - bool shouldDumpIntermediates(); - String getIntermediateDumpPrefix(); + /* Determines a suitable filename to identify the input for a given entry point being compiled. + If the end-to-end compile is a pass-through case, will attempt to find the (unique) source file + pathname for the translation unit containing the entry point at `entryPointIndex. + If the compilation is not in a pass-through case, then always returns `"slang-generated"`. + @param endToEndReq The end-to-end compile request which might be using pass-through compilation + @param entryPointIndex The index of the entry point to compute a filename for. + @return the appropriate source filename */ + String calcSourcePathForEntryPoints(); - bool getUseUnknownImageFormatAsDefault(); + TranslationUnitRequest* findPassThroughTranslationUnit(Int entryPointIndex); - bool isSpecializationDisabled(); - bool shouldSkipSPIRVValidation(); + SlangResult _emitEntryPoints(ComPtr& outArtifact); - SlangResult requireTranslationUnitSourceFiles(); +private: + Shared* m_shared = nullptr; +}; - // +/// A compile request that spans the front and back ends of the compiler +/// +/// This is what the command-line `slangc` uses, as well as the legacy +/// C API. It ties together the functionality of `Linkage`, +/// `FrontEndCompileRequest`, and `BackEndCompileRequest`, plus a small +/// number of additional features that primarily make sense for +/// command-line usage. +/// +class EndToEndCompileRequest : public RefObject, public slang::ICompileRequest +{ +public: + SLANG_CLASS_GUID(0xce6d2383, 0xee1b, 0x4fd7, {0xa0, 0xf, 0xb8, 0xb6, 0x33, 0x12, 0x95, 0xc8}) + + // ISlangUnknown + SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) + SLANG_OVERRIDE; + SLANG_REF_OBJECT_IUNKNOWN_ADD_REF + SLANG_REF_OBJECT_IUNKNOWN_RELEASE + + // slang::ICompileRequest + virtual SLANG_NO_THROW void SLANG_MCALL setFileSystem(ISlangFileSystem* fileSystem) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setCompileFlags(SlangCompileFlags flags) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangCompileFlags SLANG_MCALL getCompileFlags() SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setDumpIntermediates(int enable) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setDumpIntermediatePrefix(const char* prefix) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setEnableEffectAnnotations(bool value) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setLineDirectiveMode(SlangLineDirectiveMode mode) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setCodeGenTarget(SlangCompileTarget target) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW int SLANG_MCALL addCodeGenTarget(SlangCompileTarget target) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL + setTargetProfile(int targetIndex, SlangProfileID profile) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setTargetFlags(int targetIndex, SlangTargetFlags flags) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL + setTargetFloatingPointMode(int targetIndex, SlangFloatingPointMode mode) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL + setTargetMatrixLayoutMode(int targetIndex, SlangMatrixLayoutMode mode) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL + setTargetForceGLSLScalarBufferLayout(int targetIndex, bool value) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setTargetForceDXLayout(int targetIndex, bool value) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL + setTargetGenerateWholeProgram(int targetIndex, bool value) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setTargetEmbedDownstreamIR(int targetIndex, bool value) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setMatrixLayoutMode(SlangMatrixLayoutMode mode) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setDebugInfoLevel(SlangDebugInfoLevel level) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setOptimizationLevel(SlangOptimizationLevel level) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setOutputContainerFormat(SlangContainerFormat format) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setPassThrough(SlangPassThrough passThrough) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL + setDiagnosticCallback(SlangDiagnosticCallback callback, void const* userData) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL + setWriter(SlangWriterChannel channel, ISlangWriter* writer) SLANG_OVERRIDE; + virtual SLANG_NO_THROW ISlangWriter* SLANG_MCALL getWriter(SlangWriterChannel channel) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL addSearchPath(const char* searchDir) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL + addPreprocessorDefine(const char* key, const char* value) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + processCommandLineArguments(char const* const* args, int argCount) SLANG_OVERRIDE; + virtual SLANG_NO_THROW int SLANG_MCALL + addTranslationUnit(SlangSourceLanguage language, char const* name) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setDefaultModuleName(const char* defaultModuleName) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL addTranslationUnitPreprocessorDefine( + int translationUnitIndex, + const char* key, + const char* value) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL + addTranslationUnitSourceFile(int translationUnitIndex, char const* path) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL addTranslationUnitSourceString( + int translationUnitIndex, + char const* path, + char const* source) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL addLibraryReference( + const char* basePath, + const void* libData, + size_t libDataSize) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL addTranslationUnitSourceStringSpan( + int translationUnitIndex, + char const* path, + char const* sourceBegin, + char const* sourceEnd) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL addTranslationUnitSourceBlob( + int translationUnitIndex, + char const* path, + ISlangBlob* sourceBlob) SLANG_OVERRIDE; + virtual SLANG_NO_THROW int SLANG_MCALL + addEntryPoint(int translationUnitIndex, char const* name, SlangStage stage) SLANG_OVERRIDE; + virtual SLANG_NO_THROW int SLANG_MCALL addEntryPointEx( + int translationUnitIndex, + char const* name, + SlangStage stage, + int genericArgCount, + char const** genericArgs) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + setGlobalGenericArgs(int genericArgCount, char const** genericArgs) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + setTypeNameForGlobalExistentialTypeParam(int slotIndex, char const* typeName) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL setTypeNameForEntryPointExistentialTypeParam( + int entryPointIndex, + int slotIndex, + char const* typeName) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setAllowGLSLInput(bool value) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL compile() SLANG_OVERRIDE; + virtual SLANG_NO_THROW char const* SLANG_MCALL getDiagnosticOutput() SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getDiagnosticOutputBlob(ISlangBlob** outBlob) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW int SLANG_MCALL getDependencyFileCount() SLANG_OVERRIDE; + virtual SLANG_NO_THROW char const* SLANG_MCALL getDependencyFilePath(int index) SLANG_OVERRIDE; + virtual SLANG_NO_THROW int SLANG_MCALL getTranslationUnitCount() SLANG_OVERRIDE; + virtual SLANG_NO_THROW char const* SLANG_MCALL getEntryPointSource(int entryPointIndex) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void const* SLANG_MCALL + getEntryPointCode(int entryPointIndex, size_t* outSize) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCodeBlob( + int entryPointIndex, + int targetIndex, + ISlangBlob** outBlob) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable( + int entryPointIndex, + int targetIndex, + ISlangSharedLibrary** outSharedLibrary) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + getTargetCodeBlob(int targetIndex, ISlangBlob** outBlob) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + getTargetHostCallable(int targetIndex, ISlangSharedLibrary** outSharedLibrary) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void const* SLANG_MCALL getCompileRequestCode(size_t* outSize) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW ISlangMutableFileSystem* SLANG_MCALL + getCompileRequestResultAsFileSystem() SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getContainerCode(ISlangBlob** outBlob) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + loadRepro(ISlangFileSystem* fileSystem, const void* data, size_t size) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL saveRepro(ISlangBlob** outBlob) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL enableReproCapture() SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getProgram(slang::IComponentType** outProgram) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + getEntryPoint(SlangInt entryPointIndex, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + getModule(SlangInt translationUnitIndex, slang::IModule** outModule) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL getSession(slang::ISession** outSession) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangReflection* SLANG_MCALL getReflection() SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setCommandLineCompilerMode() SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + addTargetCapability(SlangInt targetIndex, SlangCapabilityID capability) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + getProgramWithEntryPoints(slang::IComponentType** outProgram) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL isParameterLocationUsed( + SlangInt entryPointIndex, + SlangInt targetIndex, + SlangParameterCategory category, + SlangUInt spaceIndex, + SlangUInt registerIndex, + bool& outUsed) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL + setTargetLineDirectiveMode(SlangInt targetIndex, SlangLineDirectiveMode mode) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL + overrideDiagnosticSeverity(SlangInt messageID, SlangSeverity overrideSeverity) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangDiagnosticFlags SLANG_MCALL getDiagnosticFlags() SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setDiagnosticFlags(SlangDiagnosticFlags flags) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setDebugInfoFormat(SlangDebugInfoFormat format) + SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setReportDownstreamTime(bool value) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setReportPerfBenchmark(bool value) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setSkipSPIRVValidation(bool value) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL + setTargetUseMinimumSlangOptimization(int targetIndex, bool value) SLANG_OVERRIDE; + virtual SLANG_NO_THROW void SLANG_MCALL setIgnoreCapabilityCheck(bool value) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + getCompileTimeProfile(ISlangProfiler** compileTimeProfile, bool isClear) SLANG_OVERRIDE; + + void setTrackLiveness(bool v); + + EndToEndCompileRequest(Session* session); + + EndToEndCompileRequest(Linkage* linkage); + + ~EndToEndCompileRequest(); + + // If enabled will emit IR + bool m_emitIr = false; + + // What container format are we being asked to generate? + // If it's set to a format, the container blob will be calculated during compile + ContainerFormat m_containerFormat = ContainerFormat::None; + + /// Where the container is stored. This is calculated as part of compile if m_containerFormat is + /// set to a supported format. + ComPtr m_containerArtifact; + /// Holds the container as a file system + ComPtr m_containerFileSystem; + + /// File system used by repro system if a file couldn't be found within the repro (or associated + /// directory) + ComPtr m_reproFallbackFileSystem = + ComPtr(OSFileSystem::getExtSingleton()); + + // Path to output container to + String m_containerOutputPath; + + // Should we just pass the input to another compiler? + PassThroughMode m_passThrough = PassThroughMode::None; + + /// If output should be source embedded, define the style of the embedding + SourceEmbedUtil::Style m_sourceEmbedStyle = SourceEmbedUtil::Style::None; + /// The language to be used for source embedding + SourceLanguage m_sourceEmbedLanguage = SourceLanguage::C; + /// Source embed variable name. Note may be used as a basis for names if multiple items written + String m_sourceEmbedName; + + /// Source code for the specialization arguments to use for the global specialization parameters + /// of the program. + List m_globalSpecializationArgStrings; + + // Are we being driven by the command-line `slangc`, and should act accordingly? + bool m_isCommandLineCompile = false; + + String m_diagnosticOutput; + + /// A blob holding the diagnostic output + ComPtr m_diagnosticOutputBlob; + + /// Per-entry-point information not tracked by other compile requests + class EntryPointInfo : public RefObject + { + public: + /// Source code for the specialization arguments to use for the specialization parameters of + /// the entry point. + List specializationArgStrings; + }; + List m_entryPoints; - SlangResult emitEntryPoints(ComPtr& outArtifact); + /// Per-target information only needed for command-line compiles + class TargetInfo : public RefObject + { + public: + // Requested output paths for each entry point. + // An empty string indices no output desired for + // the given entry point. + Dictionary entryPointOutputPaths; + String wholeTargetOutputPath; + CompilerOptionSet targetOptions; + }; + Dictionary> m_targetInfos; - SlangResult emitPrecompiledDownstreamIR(ComPtr& outArtifact); + CompilerOptionSet m_optionSetForDefaultTarget; - void maybeDumpIntermediate(IArtifact* artifact); + CompilerOptionSet& getTargetOptionSet(TargetRequest* req); - // Used to cause instructions available in precompiled blobs to be - // removed between IR linking and target source generation. - bool removeAvailableInDownstreamIR = false; + CompilerOptionSet& getTargetOptionSet(Index targetIndex); - protected: - CodeGenTarget m_targetFormat = CodeGenTarget::Unknown; - ExtensionTracker* m_extensionTracker = nullptr; + String m_dependencyOutputPath; - /// Will output assembly as well as the artifact if appropriate for the artifact type for assembly output - /// and conversion is possible - void _dumpIntermediateMaybeWithAssembly(IArtifact* artifact); + /// Writes the modules in a container to the stream + SlangResult writeContainerToStream(Stream* stream); - void _dumpIntermediate(IArtifact* artifact); - void _dumpIntermediate( - const ArtifactDesc& desc, - void const* data, - size_t size); + /// If a container format has been specified produce a container (stored in m_containerBlob) + SlangResult maybeCreateContainer(); + /// If a container has been constructed and the filename/path has contents will try to write + /// the container contents to the file + SlangResult maybeWriteContainer(const String& fileName); - /* Emits entry point source taking into account if a pass-through or not. Uses 'targetFormat' to determine - the target (not targetReq) */ - SlangResult emitEntryPointsSource(ComPtr& outArtifact); + Linkage* getLinkage() { return m_linkage; } - SlangResult emitEntryPointsSourceFromIR(ComPtr& outArtifact); - - SlangResult emitWithDownstreamForEntryPoints(ComPtr& outArtifact); + int addEntryPoint( + int translationUnitIndex, + String const& name, + Profile profile, + List const& genericTypeNames); - /* Determines a suitable filename to identify the input for a given entry point being compiled. - If the end-to-end compile is a pass-through case, will attempt to find the (unique) source file - pathname for the translation unit containing the entry point at `entryPointIndex. - If the compilation is not in a pass-through case, then always returns `"slang-generated"`. - @param endToEndReq The end-to-end compile request which might be using pass-through compilation - @param entryPointIndex The index of the entry point to compute a filename for. - @return the appropriate source filename */ - String calcSourcePathForEntryPoints(); + void setWriter(WriterChannel chan, ISlangWriter* writer); + ISlangWriter* getWriter(WriterChannel chan) const + { + return m_writers->getWriter(SlangWriterChannel(chan)); + } - TranslationUnitRequest* findPassThroughTranslationUnit( - Int entryPointIndex); + /// The end to end request can be passed as nullptr, if not driven by one + SlangResult executeActionsInner(); + SlangResult executeActions(); + Session* getSession() { return m_session; } + DiagnosticSink* getSink() { return &m_sink; } + NamePool* getNamePool() { return getLinkage()->getNamePool(); } - SlangResult _emitEntryPoints(ComPtr& outArtifact); - private: - Shared* m_shared = nullptr; - }; + FrontEndCompileRequest* getFrontEndReq() { return m_frontEndReq; } - /// A compile request that spans the front and back ends of the compiler - /// - /// This is what the command-line `slangc` uses, as well as the legacy - /// C API. It ties together the functionality of `Linkage`, - /// `FrontEndCompileRequest`, and `BackEndCompileRequest`, plus a small - /// number of additional features that primarily make sense for - /// command-line usage. - /// - class EndToEndCompileRequest : public RefObject, public slang::ICompileRequest + ComponentType* getUnspecializedGlobalComponentType() { - public: - SLANG_CLASS_GUID(0xce6d2383, 0xee1b, 0x4fd7, { 0xa0, 0xf, 0xb8, 0xb6, 0x33, 0x12, 0x95, 0xc8 }) - - // ISlangUnknown - SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) SLANG_OVERRIDE; - SLANG_REF_OBJECT_IUNKNOWN_ADD_REF - SLANG_REF_OBJECT_IUNKNOWN_RELEASE - - // slang::ICompileRequest - virtual SLANG_NO_THROW void SLANG_MCALL setFileSystem(ISlangFileSystem* fileSystem) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setCompileFlags(SlangCompileFlags flags) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangCompileFlags SLANG_MCALL getCompileFlags() SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setDumpIntermediates(int enable) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setDumpIntermediatePrefix(const char* prefix) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setEnableEffectAnnotations(bool value) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setLineDirectiveMode(SlangLineDirectiveMode mode) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setCodeGenTarget(SlangCompileTarget target) SLANG_OVERRIDE; - virtual SLANG_NO_THROW int SLANG_MCALL addCodeGenTarget(SlangCompileTarget target) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setTargetProfile(int targetIndex, SlangProfileID profile) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setTargetFlags(int targetIndex, SlangTargetFlags flags) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setTargetFloatingPointMode(int targetIndex, SlangFloatingPointMode mode) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setTargetMatrixLayoutMode(int targetIndex, SlangMatrixLayoutMode mode) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setTargetForceGLSLScalarBufferLayout(int targetIndex, bool value) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setTargetForceDXLayout(int targetIndex, bool value) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setTargetGenerateWholeProgram(int targetIndex, bool value) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setTargetEmbedDownstreamIR(int targetIndex, bool value) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setMatrixLayoutMode(SlangMatrixLayoutMode mode) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setDebugInfoLevel(SlangDebugInfoLevel level) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setOptimizationLevel(SlangOptimizationLevel level) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setOutputContainerFormat(SlangContainerFormat format) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setPassThrough(SlangPassThrough passThrough) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setDiagnosticCallback(SlangDiagnosticCallback callback, void const* userData) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setWriter(SlangWriterChannel channel, ISlangWriter* writer) SLANG_OVERRIDE; - virtual SLANG_NO_THROW ISlangWriter* SLANG_MCALL getWriter(SlangWriterChannel channel) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL addSearchPath(const char* searchDir) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL addPreprocessorDefine(const char* key, const char* value) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL processCommandLineArguments(char const* const* args, int argCount) SLANG_OVERRIDE; - virtual SLANG_NO_THROW int SLANG_MCALL addTranslationUnit(SlangSourceLanguage language, char const* name) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setDefaultModuleName(const char* defaultModuleName) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL addTranslationUnitPreprocessorDefine(int translationUnitIndex, const char* key, const char* value) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL addTranslationUnitSourceFile(int translationUnitIndex, char const* path) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL addTranslationUnitSourceString(int translationUnitIndex, char const* path, char const* source) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL addLibraryReference(const char* basePath, const void* libData, size_t libDataSize) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL addTranslationUnitSourceStringSpan(int translationUnitIndex, char const* path, char const* sourceBegin, char const* sourceEnd) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL addTranslationUnitSourceBlob(int translationUnitIndex, char const* path, ISlangBlob* sourceBlob) SLANG_OVERRIDE; - virtual SLANG_NO_THROW int SLANG_MCALL addEntryPoint(int translationUnitIndex, char const* name, SlangStage stage) SLANG_OVERRIDE; - virtual SLANG_NO_THROW int SLANG_MCALL addEntryPointEx(int translationUnitIndex, char const* name, SlangStage stage, int genericArgCount, char const** genericArgs) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL setGlobalGenericArgs(int genericArgCount, char const** genericArgs) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL setTypeNameForGlobalExistentialTypeParam(int slotIndex, char const* typeName) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL setTypeNameForEntryPointExistentialTypeParam(int entryPointIndex, int slotIndex, char const* typeName) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setAllowGLSLInput(bool value) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL compile() SLANG_OVERRIDE; - virtual SLANG_NO_THROW char const* SLANG_MCALL getDiagnosticOutput() SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL getDiagnosticOutputBlob(ISlangBlob** outBlob) SLANG_OVERRIDE; - virtual SLANG_NO_THROW int SLANG_MCALL getDependencyFileCount() SLANG_OVERRIDE; - virtual SLANG_NO_THROW char const* SLANG_MCALL getDependencyFilePath(int index) SLANG_OVERRIDE; - virtual SLANG_NO_THROW int SLANG_MCALL getTranslationUnitCount() SLANG_OVERRIDE; - virtual SLANG_NO_THROW char const* SLANG_MCALL getEntryPointSource(int entryPointIndex) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void const* SLANG_MCALL getEntryPointCode(int entryPointIndex, size_t* outSize) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointCodeBlob(int entryPointIndex, int targetIndex, ISlangBlob** outBlob) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointHostCallable(int entryPointIndex, int targetIndex, ISlangSharedLibrary** outSharedLibrary) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCodeBlob(int targetIndex, ISlangBlob** outBlob) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL getTargetHostCallable(int targetIndex, ISlangSharedLibrary** outSharedLibrary) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void const* SLANG_MCALL getCompileRequestCode(size_t* outSize) SLANG_OVERRIDE; - virtual SLANG_NO_THROW ISlangMutableFileSystem* SLANG_MCALL getCompileRequestResultAsFileSystem() SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL getContainerCode(ISlangBlob** outBlob) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL loadRepro(ISlangFileSystem* fileSystem, const void* data, size_t size) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL saveRepro(ISlangBlob** outBlob) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL enableReproCapture() SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL getProgram(slang::IComponentType** outProgram) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPoint(SlangInt entryPointIndex, slang::IComponentType** outEntryPoint) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL getModule(SlangInt translationUnitIndex, slang::IModule** outModule) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL getSession(slang::ISession** outSession) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangReflection* SLANG_MCALL getReflection() SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setCommandLineCompilerMode() SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL addTargetCapability(SlangInt targetIndex, SlangCapabilityID capability) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL getProgramWithEntryPoints(slang::IComponentType** outProgram) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL isParameterLocationUsed(SlangInt entryPointIndex, SlangInt targetIndex, SlangParameterCategory category, SlangUInt spaceIndex, SlangUInt registerIndex, bool& outUsed) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setTargetLineDirectiveMode( - SlangInt targetIndex, - SlangLineDirectiveMode mode) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL overrideDiagnosticSeverity( - SlangInt messageID, - SlangSeverity overrideSeverity) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangDiagnosticFlags SLANG_MCALL getDiagnosticFlags() SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setDiagnosticFlags(SlangDiagnosticFlags flags) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setDebugInfoFormat(SlangDebugInfoFormat format) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setReportDownstreamTime(bool value) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setReportPerfBenchmark(bool value) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setSkipSPIRVValidation(bool value) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setTargetUseMinimumSlangOptimization(int targetIndex, bool value) SLANG_OVERRIDE; - virtual SLANG_NO_THROW void SLANG_MCALL setIgnoreCapabilityCheck(bool value) SLANG_OVERRIDE; - virtual SLANG_NO_THROW SlangResult SLANG_MCALL getCompileTimeProfile(ISlangProfiler** compileTimeProfile, bool isClear) SLANG_OVERRIDE; - - void setTrackLiveness(bool v); - - EndToEndCompileRequest( - Session* session); - - EndToEndCompileRequest( - Linkage* linkage); - - ~EndToEndCompileRequest(); - - // If enabled will emit IR - bool m_emitIr = false; - - // What container format are we being asked to generate? - // If it's set to a format, the container blob will be calculated during compile - ContainerFormat m_containerFormat = ContainerFormat::None; - - /// Where the container is stored. This is calculated as part of compile if m_containerFormat is set to - /// a supported format. - ComPtr m_containerArtifact; - /// Holds the container as a file system - ComPtr m_containerFileSystem; - - /// File system used by repro system if a file couldn't be found within the repro (or associated directory) - ComPtr m_reproFallbackFileSystem = ComPtr(OSFileSystem::getExtSingleton()); - - // Path to output container to - String m_containerOutputPath; - - // Should we just pass the input to another compiler? - PassThroughMode m_passThrough = PassThroughMode::None; - - /// If output should be source embedded, define the style of the embedding - SourceEmbedUtil::Style m_sourceEmbedStyle = SourceEmbedUtil::Style::None; - /// The language to be used for source embedding - SourceLanguage m_sourceEmbedLanguage = SourceLanguage::C; - /// Source embed variable name. Note may be used as a basis for names if multiple items written - String m_sourceEmbedName; - - /// Source code for the specialization arguments to use for the global specialization parameters of the program. - List m_globalSpecializationArgStrings; - - // Are we being driven by the command-line `slangc`, and should act accordingly? - bool m_isCommandLineCompile = false; - - String m_diagnosticOutput; - - /// A blob holding the diagnostic output - ComPtr m_diagnosticOutputBlob; - - /// Per-entry-point information not tracked by other compile requests - class EntryPointInfo : public RefObject - { - public: - /// Source code for the specialization arguments to use for the specialization parameters of the entry point. - List specializationArgStrings; - }; - List m_entryPoints; - - /// Per-target information only needed for command-line compiles - class TargetInfo : public RefObject - { - public: - // Requested output paths for each entry point. - // An empty string indices no output desired for - // the given entry point. - Dictionary entryPointOutputPaths; - String wholeTargetOutputPath; - CompilerOptionSet targetOptions; - }; - Dictionary> m_targetInfos; + return getFrontEndReq()->getGlobalComponentType(); + } + ComponentType* getUnspecializedGlobalAndEntryPointsComponentType() + { + return getFrontEndReq()->getGlobalAndEntryPointsComponentType(); + } - CompilerOptionSet m_optionSetForDefaultTarget; + ComponentType* getSpecializedGlobalComponentType() { return m_specializedGlobalComponentType; } + ComponentType* getSpecializedGlobalAndEntryPointsComponentType() + { + return m_specializedGlobalAndEntryPointsComponentType; + } - CompilerOptionSet& getTargetOptionSet(TargetRequest* req); + ComponentType* getSpecializedEntryPointComponentType(Index index) + { + return m_specializedEntryPoints[index]; + } - CompilerOptionSet& getTargetOptionSet(Index targetIndex); + void writeArtifactToStandardOutput(IArtifact* artifact, DiagnosticSink* sink); - String m_dependencyOutputPath; + void generateOutput(); - /// Writes the modules in a container to the stream - SlangResult writeContainerToStream(Stream* stream); - - /// If a container format has been specified produce a container (stored in m_containerBlob) - SlangResult maybeCreateContainer(); - /// If a container has been constructed and the filename/path has contents will try to write - /// the container contents to the file - SlangResult maybeWriteContainer(const String& fileName); + CompilerOptionSet& getOptionSet() { return m_linkage->m_optionSet; } - Linkage* getLinkage() { return m_linkage; } +private: + String _getWholeProgramPath(TargetRequest* targetReq); + String _getEntryPointPath(TargetRequest* targetReq, Index entryPointIndex); - int addEntryPoint( - int translationUnitIndex, - String const& name, - Profile profile, - List const & genericTypeNames); + /// Maybe write the artifact to the path (if set), or stdout (if there is no container or path) + SlangResult _maybeWriteArtifact(const String& path, IArtifact* artifact); + SlangResult _writeArtifact(const String& path, IArtifact* artifact); - void setWriter(WriterChannel chan, ISlangWriter* writer); - ISlangWriter* getWriter(WriterChannel chan) const { return m_writers->getWriter(SlangWriterChannel(chan)); } + /// Adds any extra settings to complete a targetRequest + void _completeTargetRequest(UInt targetIndex); - /// The end to end request can be passed as nullptr, if not driven by one - SlangResult executeActionsInner(); - SlangResult executeActions(); + ISlangUnknown* getInterface(const Guid& guid); - Session* getSession() { return m_session; } - DiagnosticSink* getSink() { return &m_sink; } - NamePool* getNamePool() { return getLinkage()->getNamePool(); } + void generateOutput(ComponentType* program); + void generateOutput(TargetProgram* targetProgram); - FrontEndCompileRequest* getFrontEndReq() { return m_frontEndReq; } + void init(); - ComponentType* getUnspecializedGlobalComponentType() { return getFrontEndReq()->getGlobalComponentType(); } - ComponentType* getUnspecializedGlobalAndEntryPointsComponentType() - { - return getFrontEndReq()->getGlobalAndEntryPointsComponentType(); - } + Session* m_session = nullptr; + RefPtr m_linkage; + DiagnosticSink m_sink; + RefPtr m_frontEndReq; + RefPtr m_specializedGlobalComponentType; + RefPtr m_specializedGlobalAndEntryPointsComponentType; + List> m_specializedEntryPoints; - ComponentType* getSpecializedGlobalComponentType() { return m_specializedGlobalComponentType; } - ComponentType* getSpecializedGlobalAndEntryPointsComponentType() { return m_specializedGlobalAndEntryPointsComponentType; } + // For output - ComponentType* getSpecializedEntryPointComponentType(Index index) - { - return m_specializedEntryPoints[index]; - } - - void writeArtifactToStandardOutput(IArtifact* artifact, DiagnosticSink* sink); - - void generateOutput(); - - CompilerOptionSet& getOptionSet() { return m_linkage->m_optionSet; } - private: - - String _getWholeProgramPath(TargetRequest* targetReq); - String _getEntryPointPath(TargetRequest* targetReq, Index entryPointIndex); - - /// Maybe write the artifact to the path (if set), or stdout (if there is no container or path) - SlangResult _maybeWriteArtifact(const String& path, IArtifact* artifact); - SlangResult _writeArtifact(const String& path, IArtifact* artifact); - - /// Adds any extra settings to complete a targetRequest - void _completeTargetRequest(UInt targetIndex); - - ISlangUnknown* getInterface(const Guid& guid); - - void generateOutput(ComponentType* program); - void generateOutput(TargetProgram* targetProgram); - - void init(); - - Session* m_session = nullptr; - RefPtr m_linkage; - DiagnosticSink m_sink; - RefPtr m_frontEndReq; - RefPtr m_specializedGlobalComponentType; - RefPtr m_specializedGlobalAndEntryPointsComponentType; - List> m_specializedEntryPoints; - - // For output - - RefPtr m_writers; - }; + RefPtr m_writers; +}; - /* Returns SLANG_OK if pass through support is available */ - SlangResult checkExternalCompilerSupport(Session* session, PassThroughMode passThrough); - /* Report an error appearing from external compiler to the diagnostic sink error to the diagnostic sink. - @param compilerName The name of the compiler the error came for (or nullptr if not known) - @param res Result associated with the error. The error code will be reported. (Can take HRESULT - and will expand to string if known) - @param diagnostic The diagnostic string associated with the compile failure - @param sink The diagnostic sink to report to */ - void reportExternalCompileError(const char* compilerName, SlangResult res, const UnownedStringSlice& diagnostic, DiagnosticSink* sink); +/* Returns SLANG_OK if pass through support is available */ +SlangResult checkExternalCompilerSupport(Session* session, PassThroughMode passThrough); +/* Report an error appearing from external compiler to the diagnostic sink error to the diagnostic +sink. +@param compilerName The name of the compiler the error came for (or nullptr if not known) +@param res Result associated with the error. The error code will be reported. (Can take HRESULT - +and will expand to string if known) +@param diagnostic The diagnostic string associated with the compile failure +@param sink The diagnostic sink to report to */ +void reportExternalCompileError( + const char* compilerName, + SlangResult res, + const UnownedStringSlice& diagnostic, + DiagnosticSink* sink); - // +// - // Information about BaseType that's useful for checking literals - struct BaseTypeInfo +// Information about BaseType that's useful for checking literals +struct BaseTypeInfo +{ + typedef uint8_t Flags; + struct Flag { - typedef uint8_t Flags; - struct Flag + enum Enum : Flags { - enum Enum : Flags - { - Signed = 0x1, - FloatingPoint = 0x2, - Integer = 0x4, - }; + Signed = 0x1, + FloatingPoint = 0x2, + Integer = 0x4, }; + }; - SLANG_FORCE_INLINE static const BaseTypeInfo& getInfo(BaseType baseType) { return s_info[Index(baseType)]; } + SLANG_FORCE_INLINE static const BaseTypeInfo& getInfo(BaseType baseType) + { + return s_info[Index(baseType)]; + } - static UnownedStringSlice asText(BaseType baseType); + static UnownedStringSlice asText(BaseType baseType); - uint8_t sizeInBytes; ///< Size of type in bytes - Flags flags; - uint8_t baseType; + uint8_t sizeInBytes; ///< Size of type in bytes + Flags flags; + uint8_t baseType; - static bool check(); + static bool check(); - private: - static const BaseTypeInfo s_info[Index(BaseType::CountOf)]; - }; +private: + static const BaseTypeInfo s_info[Index(BaseType::CountOf)]; +}; - class CodeGenTransitionMap +class CodeGenTransitionMap +{ +public: + struct Pair { - public: - struct Pair - { - typedef Pair ThisType; - SLANG_FORCE_INLINE bool operator==(const ThisType& rhs) const { return source == rhs.source && target == rhs.target; } - SLANG_FORCE_INLINE bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } - - SLANG_FORCE_INLINE HashCode getHashCode() const { return combineHash(HashCode(source), HashCode(target)); } - - CodeGenTarget source; - CodeGenTarget target; - }; - - void removeTransition(CodeGenTarget source, CodeGenTarget target) - { - m_map.remove(Pair{ source, target }); - } - void addTransition(CodeGenTarget source, CodeGenTarget target, PassThroughMode compiler) + typedef Pair ThisType; + SLANG_FORCE_INLINE bool operator==(const ThisType& rhs) const { - SLANG_ASSERT(source != target); - m_map.set(Pair{ source, target }, compiler); + return source == rhs.source && target == rhs.target; } - bool hasTransition(CodeGenTarget source, CodeGenTarget target) const - { - return m_map.containsKey(Pair{ source, target }); - } - PassThroughMode getTransition(CodeGenTarget source, CodeGenTarget target) const + SLANG_FORCE_INLINE bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } + + SLANG_FORCE_INLINE HashCode getHashCode() const { - const Pair pair{ source, target }; - auto value = m_map.tryGetValue(pair); - return value ? *value : PassThroughMode::None; + return combineHash(HashCode(source), HashCode(target)); } - protected: - Dictionary m_map; + CodeGenTarget source; + CodeGenTarget target; }; - class Session : public RefObject, public slang::IGlobalSession + void removeTransition(CodeGenTarget source, CodeGenTarget target) { - public: - SLANG_COM_INTERFACE(0xd6b767eb, 0xd786, 0x4343, { 0x2a, 0x8c, 0x6d, 0xa0, 0x3d, 0x5a, 0xb4, 0x4a }) - - SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) SLANG_OVERRIDE; - SLANG_REF_OBJECT_IUNKNOWN_ADD_REF - SLANG_REF_OBJECT_IUNKNOWN_RELEASE - - // slang::IGlobalSession - SLANG_NO_THROW SlangResult SLANG_MCALL createSession(slang::SessionDesc const& desc, slang::ISession** outSession) override; - SLANG_NO_THROW SlangProfileID SLANG_MCALL findProfile(char const* name) override; - SLANG_NO_THROW void SLANG_MCALL setDownstreamCompilerPath(SlangPassThrough passThrough, char const* path) override; - SLANG_NO_THROW void SLANG_MCALL setDownstreamCompilerPrelude(SlangPassThrough inPassThrough, char const* prelude) override; - SLANG_NO_THROW void SLANG_MCALL getDownstreamCompilerPrelude(SlangPassThrough inPassThrough, ISlangBlob** outPrelude) override; - SLANG_NO_THROW const char* SLANG_MCALL getBuildTagString() override; - SLANG_NO_THROW SlangResult SLANG_MCALL setDefaultDownstreamCompiler(SlangSourceLanguage sourceLanguage, SlangPassThrough defaultCompiler) override; - SLANG_NO_THROW SlangPassThrough SLANG_MCALL getDefaultDownstreamCompiler(SlangSourceLanguage sourceLanguage) override; - - SLANG_NO_THROW void SLANG_MCALL setLanguagePrelude(SlangSourceLanguage inSourceLanguage, char const* prelude) override; - SLANG_NO_THROW void SLANG_MCALL getLanguagePrelude(SlangSourceLanguage inSourceLanguage, ISlangBlob** outPrelude) override; - - SLANG_NO_THROW SlangResult SLANG_MCALL createCompileRequest(slang::ICompileRequest** outCompileRequest) override; - - SLANG_NO_THROW void SLANG_MCALL addBuiltins(char const* sourcePath, char const* sourceString) override; - SLANG_NO_THROW void SLANG_MCALL setSharedLibraryLoader(ISlangSharedLibraryLoader* loader) override; - SLANG_NO_THROW ISlangSharedLibraryLoader* SLANG_MCALL getSharedLibraryLoader() override; - SLANG_NO_THROW SlangResult SLANG_MCALL checkCompileTargetSupport(SlangCompileTarget target) override; - SLANG_NO_THROW SlangResult SLANG_MCALL checkPassThroughSupport(SlangPassThrough passThrough) override; - - void writeCoreModuleDoc(String config); - SLANG_NO_THROW SlangResult SLANG_MCALL compileCoreModule(slang::CompileCoreModuleFlags flags) override; - SLANG_NO_THROW SlangResult SLANG_MCALL loadCoreModule(const void* coreModule, size_t coreModuleSizeInBytes) override; - SLANG_NO_THROW SlangResult SLANG_MCALL saveCoreModule(SlangArchiveType archiveType, ISlangBlob** outBlob) override; - - SLANG_NO_THROW SlangCapabilityID SLANG_MCALL findCapability(char const* name) override; - - SLANG_NO_THROW void SLANG_MCALL setDownstreamCompilerForTransition(SlangCompileTarget source, SlangCompileTarget target, SlangPassThrough compiler) override; - SLANG_NO_THROW SlangPassThrough SLANG_MCALL getDownstreamCompilerForTransition(SlangCompileTarget source, SlangCompileTarget target) override; - SLANG_NO_THROW void SLANG_MCALL getCompilerElapsedTime(double* outTotalTime, double* outDownstreamTime) override - { - *outDownstreamTime = m_downstreamCompileTime; - *outTotalTime = m_totalCompileTime; - } + m_map.remove(Pair{source, target}); + } + void addTransition(CodeGenTarget source, CodeGenTarget target, PassThroughMode compiler) + { + SLANG_ASSERT(source != target); + m_map.set(Pair{source, target}, compiler); + } + bool hasTransition(CodeGenTarget source, CodeGenTarget target) const + { + return m_map.containsKey(Pair{source, target}); + } + PassThroughMode getTransition(CodeGenTarget source, CodeGenTarget target) const + { + const Pair pair{source, target}; + auto value = m_map.tryGetValue(pair); + return value ? *value : PassThroughMode::None; + } - SLANG_NO_THROW SlangResult SLANG_MCALL setSPIRVCoreGrammar(char const* jsonPath) override; +protected: + Dictionary m_map; +}; - SLANG_NO_THROW SlangResult SLANG_MCALL parseCommandLineArguments( - int argc, const char* const* argv, slang::SessionDesc* outSessionDesc, ISlangUnknown** outAllocation) override; +class Session : public RefObject, public slang::IGlobalSession +{ +public: + SLANG_COM_INTERFACE( + 0xd6b767eb, + 0xd786, + 0x4343, + {0x2a, 0x8c, 0x6d, 0xa0, 0x3d, 0x5a, 0xb4, 0x4a}) + + SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) + SLANG_OVERRIDE; + SLANG_REF_OBJECT_IUNKNOWN_ADD_REF + SLANG_REF_OBJECT_IUNKNOWN_RELEASE + + // slang::IGlobalSession + SLANG_NO_THROW SlangResult SLANG_MCALL + createSession(slang::SessionDesc const& desc, slang::ISession** outSession) override; + SLANG_NO_THROW SlangProfileID SLANG_MCALL findProfile(char const* name) override; + SLANG_NO_THROW void SLANG_MCALL + setDownstreamCompilerPath(SlangPassThrough passThrough, char const* path) override; + SLANG_NO_THROW void SLANG_MCALL + setDownstreamCompilerPrelude(SlangPassThrough inPassThrough, char const* prelude) override; + SLANG_NO_THROW void SLANG_MCALL + getDownstreamCompilerPrelude(SlangPassThrough inPassThrough, ISlangBlob** outPrelude) override; + SLANG_NO_THROW const char* SLANG_MCALL getBuildTagString() override; + SLANG_NO_THROW SlangResult SLANG_MCALL setDefaultDownstreamCompiler( + SlangSourceLanguage sourceLanguage, + SlangPassThrough defaultCompiler) override; + SLANG_NO_THROW SlangPassThrough SLANG_MCALL + getDefaultDownstreamCompiler(SlangSourceLanguage sourceLanguage) override; + + SLANG_NO_THROW void SLANG_MCALL + setLanguagePrelude(SlangSourceLanguage inSourceLanguage, char const* prelude) override; + SLANG_NO_THROW void SLANG_MCALL + getLanguagePrelude(SlangSourceLanguage inSourceLanguage, ISlangBlob** outPrelude) override; + + SLANG_NO_THROW SlangResult SLANG_MCALL + createCompileRequest(slang::ICompileRequest** outCompileRequest) override; + + SLANG_NO_THROW void SLANG_MCALL + addBuiltins(char const* sourcePath, char const* sourceString) override; + SLANG_NO_THROW void SLANG_MCALL + setSharedLibraryLoader(ISlangSharedLibraryLoader* loader) override; + SLANG_NO_THROW ISlangSharedLibraryLoader* SLANG_MCALL getSharedLibraryLoader() override; + SLANG_NO_THROW SlangResult SLANG_MCALL + checkCompileTargetSupport(SlangCompileTarget target) override; + SLANG_NO_THROW SlangResult SLANG_MCALL + checkPassThroughSupport(SlangPassThrough passThrough) override; + + void writeCoreModuleDoc(String config); + SLANG_NO_THROW SlangResult SLANG_MCALL + compileCoreModule(slang::CompileCoreModuleFlags flags) override; + SLANG_NO_THROW SlangResult SLANG_MCALL + loadCoreModule(const void* coreModule, size_t coreModuleSizeInBytes) override; + SLANG_NO_THROW SlangResult SLANG_MCALL + saveCoreModule(SlangArchiveType archiveType, ISlangBlob** outBlob) override; + + SLANG_NO_THROW SlangCapabilityID SLANG_MCALL findCapability(char const* name) override; + + SLANG_NO_THROW void SLANG_MCALL setDownstreamCompilerForTransition( + SlangCompileTarget source, + SlangCompileTarget target, + SlangPassThrough compiler) override; + SLANG_NO_THROW SlangPassThrough SLANG_MCALL getDownstreamCompilerForTransition( + SlangCompileTarget source, + SlangCompileTarget target) override; + SLANG_NO_THROW void SLANG_MCALL + getCompilerElapsedTime(double* outTotalTime, double* outDownstreamTime) override + { + *outDownstreamTime = m_downstreamCompileTime; + *outTotalTime = m_totalCompileTime; + } - SLANG_NO_THROW SlangResult SLANG_MCALL getSessionDescDigest(slang::SessionDesc* sessionDesc, ISlangBlob** outBlob) override; + SLANG_NO_THROW SlangResult SLANG_MCALL setSPIRVCoreGrammar(char const* jsonPath) override; - /// Get the downstream compiler for a transition - IDownstreamCompiler* getDownstreamCompiler(CodeGenTarget source, CodeGenTarget target); - - // This needs to be atomic not because of contention between threads as `Session` is - // *not* multithreaded, but can be used exclusively on one thread at a time. - // The need for atomic is purely for visibility. If the session is used on a different - // thread we need to be sure any changes to m_epochId are visible to this thread. - std::atomic m_epochId = 1; + SLANG_NO_THROW SlangResult SLANG_MCALL parseCommandLineArguments( + int argc, + const char* const* argv, + slang::SessionDesc* outSessionDesc, + ISlangUnknown** outAllocation) override; - Scope* baseLanguageScope = nullptr; - Scope* coreLanguageScope = nullptr; - Scope* hlslLanguageScope = nullptr; - Scope* slangLanguageScope = nullptr; - Scope* autodiffLanguageScope = nullptr; + SLANG_NO_THROW SlangResult SLANG_MCALL + getSessionDescDigest(slang::SessionDesc* sessionDesc, ISlangBlob** outBlob) override; - ModuleDecl* baseModuleDecl = nullptr; - List> coreModules; + /// Get the downstream compiler for a transition + IDownstreamCompiler* getDownstreamCompiler(CodeGenTarget source, CodeGenTarget target); - SourceManager builtinSourceManager; + // This needs to be atomic not because of contention between threads as `Session` is + // *not* multithreaded, but can be used exclusively on one thread at a time. + // The need for atomic is purely for visibility. If the session is used on a different + // thread we need to be sure any changes to m_epochId are visible to this thread. + std::atomic m_epochId = 1; - SourceManager* getBuiltinSourceManager() { return &builtinSourceManager; } + Scope* baseLanguageScope = nullptr; + Scope* coreLanguageScope = nullptr; + Scope* hlslLanguageScope = nullptr; + Scope* slangLanguageScope = nullptr; + Scope* autodiffLanguageScope = nullptr; - // Name pool stuff for unique-ing identifiers + ModuleDecl* baseModuleDecl = nullptr; + List> coreModules; - RootNamePool rootNamePool; - NamePool namePool; + SourceManager builtinSourceManager; - RootNamePool* getRootNamePool() { return &rootNamePool; } - NamePool* getNamePool() { return &namePool; } - Name* getNameObj(String name) { return namePool.getName(name); } - Name* tryGetNameObj(String name) { return namePool.tryGetName(name); } - // + SourceManager* getBuiltinSourceManager() { return &builtinSourceManager; } - /// This AST Builder should only be used for creating AST nodes that are global across requests - /// not doing so could lead to memory being consumed but not used. - ASTBuilder* getGlobalASTBuilder() { return globalAstBuilder; } - void finalizeSharedASTBuilder(); + // Name pool stuff for unique-ing identifiers - RefPtr globalAstBuilder; + RootNamePool rootNamePool; + NamePool namePool; - // Generated code for core module, etc. - String coreModulePath; + RootNamePool* getRootNamePool() { return &rootNamePool; } + NamePool* getNamePool() { return &namePool; } + Name* getNameObj(String name) { return namePool.getName(name); } + Name* tryGetNameObj(String name) { return namePool.tryGetName(name); } + // - ComPtr coreLibraryCode; - //ComPtr slangLibraryCode; - ComPtr hlslLibraryCode; - ComPtr glslLibraryCode; - ComPtr autodiffLibraryCode; + /// This AST Builder should only be used for creating AST nodes that are global across requests + /// not doing so could lead to memory being consumed but not used. + ASTBuilder* getGlobalASTBuilder() { return globalAstBuilder; } + void finalizeSharedASTBuilder(); - String getCoreModulePath(); + RefPtr globalAstBuilder; - ComPtr getCoreLibraryCode(); - ComPtr getHLSLLibraryCode(); - ComPtr getAutodiffLibraryCode(); - ComPtr getGLSLLibraryCode(); + // Generated code for core module, etc. + String coreModulePath; - RefPtr m_sharedASTBuilder; + ComPtr coreLibraryCode; + // ComPtr slangLibraryCode; + ComPtr hlslLibraryCode; + ComPtr glslLibraryCode; + ComPtr autodiffLibraryCode; - SPIRVCoreGrammarInfo& getSPIRVCoreGrammarInfo() - { - if(!spirvCoreGrammarInfo) - setSPIRVCoreGrammar(nullptr); - SLANG_ASSERT(spirvCoreGrammarInfo); - return *spirvCoreGrammarInfo; - } - RefPtr spirvCoreGrammarInfo; + String getCoreModulePath(); - // + ComPtr getCoreLibraryCode(); + ComPtr getHLSLLibraryCode(); + ComPtr getAutodiffLibraryCode(); + ComPtr getGLSLLibraryCode(); - void _setSharedLibraryLoader(ISlangSharedLibraryLoader* loader); + RefPtr m_sharedASTBuilder; - /// Will try to load the library by specified name (using the set loader), if not one already available. - IDownstreamCompiler* getOrLoadDownstreamCompiler(PassThroughMode type, DiagnosticSink* sink); - /// Will unload the specified shared library if it's currently loaded - void resetDownstreamCompiler(PassThroughMode type); + SPIRVCoreGrammarInfo& getSPIRVCoreGrammarInfo() + { + if (!spirvCoreGrammarInfo) + setSPIRVCoreGrammar(nullptr); + SLANG_ASSERT(spirvCoreGrammarInfo); + return *spirvCoreGrammarInfo; + } + RefPtr spirvCoreGrammarInfo; - /// Get the prelude associated with the language - const String& getPreludeForLanguage(SourceLanguage language) { return m_languagePreludes[int(language)]; } + // - /// Get the built in linkage -> handy to get the core module from - Linkage* getBuiltinLinkage() const { return m_builtinLinkage; } + void _setSharedLibraryLoader(ISlangSharedLibraryLoader* loader); - Name* getCompletionRequestTokenName() const { return m_completionTokenName; } + /// Will try to load the library by specified name (using the set loader), if not one already + /// available. + IDownstreamCompiler* getOrLoadDownstreamCompiler(PassThroughMode type, DiagnosticSink* sink); + /// Will unload the specified shared library if it's currently loaded + void resetDownstreamCompiler(PassThroughMode type); - void init(); + /// Get the prelude associated with the language + const String& getPreludeForLanguage(SourceLanguage language) + { + return m_languagePreludes[int(language)]; + } - void addBuiltinSource( - Scope* scope, - String const& path, - ISlangBlob* sourceBlob); - ~Session(); + /// Get the built in linkage -> handy to get the core module from + Linkage* getBuiltinLinkage() const { return m_builtinLinkage; } - void addDownstreamCompileTime(double time) { m_downstreamCompileTime += time; } - void addTotalCompileTime(double time) { m_totalCompileTime += time; } + Name* getCompletionRequestTokenName() const { return m_completionTokenName; } - ComPtr m_sharedLibraryLoader; ///< The shared library loader (never null) + void init(); - int m_downstreamCompilerInitialized = 0; + void addBuiltinSource(Scope* scope, String const& path, ISlangBlob* sourceBlob); + ~Session(); - RefPtr m_downstreamCompilerSet; ///< Information about all available downstream compilers. - ComPtr m_downstreamCompilers[int(PassThroughMode::CountOf)]; ///< A downstream compiler for a pass through - DownstreamCompilerLocatorFunc m_downstreamCompilerLocators[int(PassThroughMode::CountOf)]; - Name* m_completionTokenName = nullptr; ///< The name of a completion request token. + void addDownstreamCompileTime(double time) { m_downstreamCompileTime += time; } + void addTotalCompileTime(double time) { m_totalCompileTime += time; } - /// For parsing command line options - CommandOptions m_commandOptions; + ComPtr + m_sharedLibraryLoader; ///< The shared library loader (never null) - int m_typeDictionarySize = 0; - private: + int m_downstreamCompilerInitialized = 0; - void _initCodeGenTransitionMap(); + RefPtr + m_downstreamCompilerSet; ///< Information about all available downstream compilers. + ComPtr m_downstreamCompilers[int( + PassThroughMode::CountOf)]; ///< A downstream compiler for a pass through + DownstreamCompilerLocatorFunc m_downstreamCompilerLocators[int(PassThroughMode::CountOf)]; + Name* m_completionTokenName = nullptr; ///< The name of a completion request token. - SlangResult _readBuiltinModule(ISlangFileSystem* fileSystem, Scope* scope, String moduleName); + /// For parsing command line options + CommandOptions m_commandOptions; - SlangResult _loadRequest(EndToEndCompileRequest* request, const void* data, size_t size); + int m_typeDictionarySize = 0; - /// Linkage used for all built-in (core module) code. - RefPtr m_builtinLinkage; +private: + void _initCodeGenTransitionMap(); - String m_downstreamCompilerPaths[int(PassThroughMode::CountOf)]; ///< Paths for each pass through - String m_languagePreludes[int(SourceLanguage::CountOf)]; ///< Prelude for each source language - PassThroughMode m_defaultDownstreamCompilers[int(SourceLanguage::CountOf)]; + SlangResult _readBuiltinModule(ISlangFileSystem* fileSystem, Scope* scope, String moduleName); - // Describes a conversion from one code gen target (source) to another (target) - CodeGenTransitionMap m_codeGenTransitionMap; + SlangResult _loadRequest(EndToEndCompileRequest* request, const void* data, size_t size); - double m_downstreamCompileTime = 0.0; - double m_totalCompileTime = 0.0; - }; + /// Linkage used for all built-in (core module) code. + RefPtr m_builtinLinkage; - void checkTranslationUnit( - TranslationUnitRequest* translationUnit, LoadedModuleDictionary& loadedModules); + String + m_downstreamCompilerPaths[int(PassThroughMode::CountOf)]; ///< Paths for each pass through + String m_languagePreludes[int(SourceLanguage::CountOf)]; ///< Prelude for each source language + PassThroughMode m_defaultDownstreamCompilers[int(SourceLanguage::CountOf)]; - // Look for a module that matches the given name: - // either one we've loaded already, or one we - // can find vai the search paths available to us. - // - // Needed by import declaration checking. - // - RefPtr findOrImportModule( - Linkage* linkage, - Name* name, - SourceLoc const& loc, - DiagnosticSink* sink, - const LoadedModuleDictionary* additionalLoadedModules); + // Describes a conversion from one code gen target (source) to another (target) + CodeGenTransitionMap m_codeGenTransitionMap; + + double m_downstreamCompileTime = 0.0; + double m_totalCompileTime = 0.0; +}; - SlangResult passthroughDownstreamDiagnostics(DiagnosticSink* sink, IDownstreamCompiler* compiler, IArtifact* artifact); +void checkTranslationUnit( + TranslationUnitRequest* translationUnit, + LoadedModuleDictionary& loadedModules); + +// Look for a module that matches the given name: +// either one we've loaded already, or one we +// can find vai the search paths available to us. +// +// Needed by import declaration checking. +// +RefPtr findOrImportModule( + Linkage* linkage, + Name* name, + SourceLoc const& loc, + DiagnosticSink* sink, + const LoadedModuleDictionary* additionalLoadedModules); + +SlangResult passthroughDownstreamDiagnostics( + DiagnosticSink* sink, + IDownstreamCompiler* compiler, + IArtifact* artifact); // // The following functions are utilties to convert between @@ -3489,7 +3654,7 @@ SLANG_FORCE_INLINE slang::IComponentType* asExternal(ComponentType* componentTyp SLANG_FORCE_INLINE slang::ProgramLayout* asExternal(ProgramLayout* programLayout) { - return (slang::ProgramLayout*) programLayout; + return (slang::ProgramLayout*)programLayout; } SLANG_FORCE_INLINE Type* asInternal(slang::TypeReflection* type) @@ -3538,9 +3703,15 @@ SLANG_FORCE_INLINE EndToEndCompileRequest* asInternal(SlangCompileRequest* reque return endToEndRequest; } -SLANG_FORCE_INLINE SlangCompileTarget asExternal(CodeGenTarget target) { return (SlangCompileTarget)target; } +SLANG_FORCE_INLINE SlangCompileTarget asExternal(CodeGenTarget target) +{ + return (SlangCompileTarget)target; +} -SLANG_FORCE_INLINE SlangSourceLanguage asExternal(SourceLanguage sourceLanguage) { return (SlangSourceLanguage)sourceLanguage; } +SLANG_FORCE_INLINE SlangSourceLanguage asExternal(SourceLanguage sourceLanguage) +{ + return (SlangSourceLanguage)sourceLanguage; +} // Helper class for recording compile time. struct CompileTimerRAII @@ -3555,8 +3726,9 @@ struct CompileTimerRAII ~CompileTimerRAII() { double elapsedTime = std::chrono::duration_cast( - std::chrono::high_resolution_clock::now() - startTime) - .count() / 1e6; + std::chrono::high_resolution_clock::now() - startTime) + .count() / + 1e6; session->addTotalCompileTime(elapsedTime); } }; @@ -3568,17 +3740,32 @@ enum class DiagnosticCategory Capability = 1 << 0, }; template -bool maybeDiagnose(DiagnosticSink* sink, CompilerOptionSet& optionSet, DiagnosticCategory errorType, P const& pos, DiagnosticInfo const& info, Args const&... args) +bool maybeDiagnose( + DiagnosticSink* sink, + CompilerOptionSet& optionSet, + DiagnosticCategory errorType, + P const& pos, + DiagnosticInfo const& info, + Args const&... args) { - if ((int)errorType & (int)DiagnosticCategory::Capability && optionSet.getBoolOption(CompilerOptionName::IgnoreCapabilities)) + if ((int)errorType & (int)DiagnosticCategory::Capability && + optionSet.getBoolOption(CompilerOptionName::IgnoreCapabilities)) return false; return sink->diagnose(pos, info, args...); } template -bool maybeDiagnoseWarningOrError(DiagnosticSink* sink, CompilerOptionSet& optionSet, DiagnosticCategory errorType, P const& pos, DiagnosticInfo const& warningInfo, DiagnosticInfo const& errorInfo, Args const&... args) +bool maybeDiagnoseWarningOrError( + DiagnosticSink* sink, + CompilerOptionSet& optionSet, + DiagnosticCategory errorType, + P const& pos, + DiagnosticInfo const& warningInfo, + DiagnosticInfo const& errorInfo, + Args const&... args) { - if ((int)errorType & (int)DiagnosticCategory::Capability && optionSet.getBoolOption(CompilerOptionName::RestrictiveCapabilityCheck)) + if ((int)errorType & (int)DiagnosticCategory::Capability && + optionSet.getBoolOption(CompilerOptionName::RestrictiveCapabilityCheck)) { return maybeDiagnose(sink, optionSet, errorType, pos, errorInfo, args...); } @@ -3588,6 +3775,6 @@ bool maybeDiagnoseWarningOrError(DiagnosticSink* sink, CompilerOptionSet& option } } -} +} // namespace Slang #endif diff --git a/source/slang/slang-container-pool.h b/source/slang/slang-container-pool.h index 0b7a56694..57ae9c270 100644 --- a/source/slang/slang-container-pool.h +++ b/source/slang/slang-container-pool.h @@ -1,8 +1,8 @@ #ifndef SLANG_CONTAINER_POOL_H #define SLANG_CONTAINER_POOL_H -#include "../core/slang-list.h" #include "../core/slang-dictionary.h" +#include "../core/slang-list.h" #include "../core/slang-virtual-object-pool.h" // A pool to allow reuse of common types of containers to avoid @@ -46,8 +46,11 @@ struct ContainerPool ObjectPool> m_hashSetPool; ContainerPool() - :m_listPool(kContainerPoolSize), m_dictionaryPool(kContainerPoolSize), m_hashSetPool(kContainerPoolSize) - {} + : m_listPool(kContainerPoolSize) + , m_dictionaryPool(kContainerPoolSize) + , m_hashSetPool(kContainerPoolSize) + { + } template List* getList() @@ -88,6 +91,6 @@ struct ContainerPool m_hashSetPool.freeObject((HashSet*)set); } }; -} +} // namespace Slang #endif diff --git a/source/slang/slang-content-assist-info.h b/source/slang/slang-content-assist-info.h index 5487a1f69..6d4503cc5 100644 --- a/source/slang/slang-content-assist-info.h +++ b/source/slang/slang-content-assist-info.h @@ -94,8 +94,8 @@ struct ContentAssistInfo // The primary module from which the current content assist request is made. Provided by the // language server. Name* primaryModuleName = nullptr; - // The primary module path from which the current content assist request is made. Provided by the - // language server. + // The primary module path from which the current content assist request is made. Provided by + // the language server. String primaryModulePath; // The cursor location at which a completion request is made. Provided by the language server. Index cursorLine = 0; @@ -110,4 +110,4 @@ struct ContentAssistInfo PreprocessorContentAssistInfo preprocessorInfo; }; -} +} // namespace Slang diff --git a/source/slang/slang-core-module-textures.cpp b/source/slang/slang-core-module-textures.cpp index 064529ff4..568c6f37d 100644 --- a/source/slang/slang-core-module-textures.cpp +++ b/source/slang/slang-core-module-textures.cpp @@ -1,7 +1,9 @@ #include "slang-core-module-textures.h" + #include -#define EMIT_LINE_DIRECTIVE() sb << "#line " << (__LINE__+1) << " \"slang-core-module-textures.cpp\"\n" +#define EMIT_LINE_DIRECTIVE() \ + sb << "#line " << (__LINE__ + 1) << " \"slang-core-module-textures.cpp\"\n" namespace Slang { @@ -24,7 +26,7 @@ static_assert(SLANG_COUNT_OF(spaces) % indentWidth == 1); struct BraceScope { BraceScope(const char*& i, StringBuilder& sb, const char* end = "\n") - :i(i), sb(sb), end(end) + : i(i), sb(sb), end(end) { // If we hit this assert, it means that we are indenting too deep and // need more spaces in 'spaces' above. @@ -79,7 +81,7 @@ void TextureTypeInfo::writeFuncBody( sb << i << "case cpp:\n"; sb << i << "case hlsl:\n"; sb << i << "__intrinsic_asm \"." << funcName << "\";\n"; - if(glsl.getLength()) + if (glsl.getLength()) { sb << i << "case glsl:\n"; if (glsl.startsWith("if")) @@ -87,7 +89,7 @@ void TextureTypeInfo::writeFuncBody( else sb << i << "__intrinsic_asm \"" << glsl << "\";\n"; } - if(cuda.getLength()) + if (cuda.getLength()) { sb << i << "case cuda:\n"; sb << i << "__intrinsic_asm \"" << cuda << "\";\n"; @@ -103,7 +105,7 @@ void TextureTypeInfo::writeFuncBody( sb << i << "if (access == " << kCoreModule_ResourceAccessReadWrite << ")\n"; sb << i << "return spirv_asm\n"; { - BraceScope spirvRWScope{ i, sb, ";\n" }; + BraceScope spirvRWScope{i, sb, ";\n"}; sb << spirvRWDefault << "\n"; } sb << i << "else if (isCombined != 0)\n"; @@ -174,19 +176,18 @@ void TextureTypeInfo::writeFunc( cuda, metal, wgsl, - readNoneMode - ); + readNoneMode); } void TextureTypeInfo::writeGetDimensionFunctions() { - static const char* kComponentNames[]{ "x", "y", "z", "w" }; + static const char* kComponentNames[]{"x", "y", "z", "w"}; SlangResourceShape baseShape = base.baseShape; // `GetDimensions` - const char* dimParamTypes[] = { "out float ", "out int ", "out uint " }; - const char* dimParamTypesInner[] = { "float", "int", "uint" }; + const char* dimParamTypes[] = {"out float ", "out int ", "out uint "}; + const char* dimParamTypesInner[] = {"float", "int", "uint"}; for (int tid = 0; tid < 3; tid++) { auto t = dimParamTypes[tid]; @@ -223,8 +224,10 @@ void TextureTypeInfo::writeGetDimensionFunctions() case SLANG_TEXTURE_1D: ++paramCount; params << t << "width"; - metal << "(*($" << String(paramCount) << ") = $0.get_width(" << String(metalMipLevel) << ")),"; - wgsl << "($" << String(paramCount) << ") = textureDimensions($0" << (includeMipInfo ? ", $1" : "") << ");"; + metal << "(*($" << String(paramCount) << ") = $0.get_width(" + << String(metalMipLevel) << ")),"; + wgsl << "($" << String(paramCount) << ") = textureDimensions($0" + << (includeMipInfo ? ", $1" : "") << ");"; sizeDimCount = 1; break; @@ -233,13 +236,15 @@ void TextureTypeInfo::writeGetDimensionFunctions() case SLANG_TEXTURE_CUBE: ++paramCount; params << t << "width,"; - metal << "(*($" << String(paramCount) << ") = $0.get_width(" << String(metalMipLevel) << ")),"; + metal << "(*($" << String(paramCount) << ") = $0.get_width(" + << String(metalMipLevel) << ")),"; wgsl << "var dim = textureDimensions($0" << (includeMipInfo ? ", $1" : "") << ");"; wgsl << "($" << String(paramCount) << ") = dim.x;"; ++paramCount; params << t << "height"; - metal << "(*($" << String(paramCount) << ") = $0.get_height(" << String(metalMipLevel) << ")),"; + metal << "(*($" << String(paramCount) << ") = $0.get_height(" + << String(metalMipLevel) << ")),"; wgsl << "($" << String(paramCount) << ") = dim.y;"; sizeDimCount = 2; @@ -248,26 +253,27 @@ void TextureTypeInfo::writeGetDimensionFunctions() case SLANG_TEXTURE_3D: ++paramCount; params << t << "width,"; - metal << "(*($" << String(paramCount) << ") = $0.get_width(" << String(metalMipLevel) << ")),"; + metal << "(*($" << String(paramCount) << ") = $0.get_width(" + << String(metalMipLevel) << ")),"; wgsl << "var dim = textureDimensions($0" << (includeMipInfo ? ", $1" : "") << ");"; wgsl << "($" << String(paramCount) << ") = dim.x;"; ++paramCount; params << t << "height,"; - metal << "(*($" << String(paramCount) << ") = $0.get_height(" << String(metalMipLevel) << ")),"; + metal << "(*($" << String(paramCount) << ") = $0.get_height(" + << String(metalMipLevel) << ")),"; wgsl << "($" << String(paramCount) << ") = dim.y;"; ++paramCount; params << t << "depth"; - metal << "(*($" << String(paramCount) << ") = $0.get_depth(" << String(metalMipLevel) << ")),"; + metal << "(*($" << String(paramCount) << ") = $0.get_depth(" + << String(metalMipLevel) << ")),"; wgsl << "($" << String(paramCount) << ") = dim.z;"; sizeDimCount = 3; break; - default: - assert(!"unexpected"); - break; + default: assert(!"unexpected"); break; } if (isArray) @@ -301,71 +307,70 @@ void TextureTypeInfo::writeGetDimensionFunctions() StringBuilder glsl; { auto emitIntrinsic = [&](UnownedStringSlice funcName, bool useLodStr) + { + int aa = 1; + StringBuilder opStrSB; + opStrSB << " = " << funcName << "($0"; + if (useLodStr) { - int aa = 1; - StringBuilder opStrSB; - opStrSB << " = " << funcName << "($0"; - if (useLodStr) - { - String lodStr = ", 0"; - if (includeMipInfo) - { - int mipLevelArg = aa++; - lodStr = ", int($"; - lodStr.append(mipLevelArg); - lodStr.append(")"); - } - opStrSB << lodStr; - } - auto opStr = opStrSB.produceString(); - int cc = 0; - switch (baseShape) + String lodStr = ", 0"; + if (includeMipInfo) { - case SLANG_TEXTURE_1D: - glsl << "($" << aa++ << opStr << ")"; - if (isArray) - { - glsl << ".x"; - } - glsl << ")"; - cc = 1; - break; - - case SLANG_TEXTURE_2D: - case SLANG_TEXTURE_CUBE: - glsl << "($" << aa++ << opStr << ").x)"; - glsl << ", ($" << aa++ << opStr << ").y)"; - cc = 2; - break; - - case SLANG_TEXTURE_3D: - glsl << "($" << aa++ << opStr << ").x)"; - glsl << ", ($" << aa++ << opStr << ").y)"; - glsl << ", ($" << aa++ << opStr << ").z)"; - cc = 3; - break; - - default: - SLANG_UNEXPECTED("unhandled resource shape"); - break; + int mipLevelArg = aa++; + lodStr = ", int($"; + lodStr.append(mipLevelArg); + lodStr.append(")"); } - + opStrSB << lodStr; + } + auto opStr = opStrSB.produceString(); + int cc = 0; + switch (baseShape) + { + case SLANG_TEXTURE_1D: + glsl << "($" << aa++ << opStr << ")"; if (isArray) { - glsl << ", ($" << aa++ << opStr << ")." << kComponentNames[cc] << ")"; + glsl << ".x"; } + glsl << ")"; + cc = 1; + break; + + case SLANG_TEXTURE_2D: + case SLANG_TEXTURE_CUBE: + glsl << "($" << aa++ << opStr << ").x)"; + glsl << ", ($" << aa++ << opStr << ").y)"; + cc = 2; + break; + + case SLANG_TEXTURE_3D: + glsl << "($" << aa++ << opStr << ").x)"; + glsl << ", ($" << aa++ << opStr << ").y)"; + glsl << ", ($" << aa++ << opStr << ").z)"; + cc = 3; + break; + + default: SLANG_UNEXPECTED("unhandled resource shape"); break; + } + + if (isArray) + { + glsl << ", ($" << aa++ << opStr << ")." << kComponentNames[cc] << ")"; + } - if (isMultisample) - { - glsl << ", ($" << aa++ << " = textureSamples($0))"; - } + if (isMultisample) + { + glsl << ", ($" << aa++ << " = textureSamples($0))"; + } - if (includeMipInfo) - { - glsl << ", ($" << aa++ << " = textureQueryLevels($0))"; - } - }; - glsl << "if (access == " << kCoreModule_ResourceAccessReadOnly << ") __intrinsic_asm \""; + if (includeMipInfo) + { + glsl << ", ($" << aa++ << " = textureQueryLevels($0))"; + } + }; + glsl << "if (access == " << kCoreModule_ResourceAccessReadOnly + << ") __intrinsic_asm \""; emitIntrinsic(toSlice("textureSize"), !isMultisample); glsl << "\";\n"; glsl << "__intrinsic_asm \""; @@ -374,53 +379,56 @@ void TextureTypeInfo::writeGetDimensionFunctions() } // SPIRV ASM generation - auto generateSpirvAsm = [&](StringBuilder& spirv, bool isRW, UnownedStringSlice imageVar) + auto generateSpirvAsm = + [&](StringBuilder& spirv, bool isRW, UnownedStringSlice imageVar) { spirv << "%vecSize:$$uint"; - if (sizeDimCount > 1) spirv << sizeDimCount; + if (sizeDimCount > 1) + spirv << sizeDimCount; spirv << " = "; if (isMultisample || isRW) spirv << "OpImageQuerySize " << imageVar << ";"; else - spirv << "OpImageQuerySizeLod " << imageVar <<" $0;"; + spirv << "OpImageQuerySizeLod " << imageVar << " $0;"; auto convertAndStore = [&](UnownedStringSlice uintSourceVal, const char* destParam) + { + if (UnownedStringSlice(rawT) == "uint") { - if (UnownedStringSlice(rawT) == "uint") + spirv << "OpStore &" << destParam << " %" << uintSourceVal << ";"; + } + else + { + if (UnownedStringSlice(rawT) == "int") { - spirv << "OpStore &" << destParam << " %" << uintSourceVal << ";"; + spirv << "%c_" << uintSourceVal << " : $$" << rawT << " = OpBitcast %" + << uintSourceVal << "; "; } else { - if (UnownedStringSlice(rawT) == "int") - { - spirv << "%c_" << uintSourceVal << " : $$" << rawT << " = OpBitcast %" << uintSourceVal << "; "; - } - else - { - spirv << "%c_" << uintSourceVal << " : $$" << rawT << " = OpConvertUToF %" << uintSourceVal << "; "; - } - spirv << "OpStore &" << destParam << "%c_" << uintSourceVal << ";"; + spirv << "%c_" << uintSourceVal << " : $$" << rawT + << " = OpConvertUToF %" << uintSourceVal << "; "; } - }; + spirv << "OpStore &" << destParam << "%c_" << uintSourceVal << ";"; + } + }; auto extractSizeComponent = [&](int componentId, const char* destParam) + { + String elementVal = String("_") + destParam; + if (sizeDimCount == 1) { - String elementVal = String("_") + destParam; - if (sizeDimCount == 1) - { - spirv << "%" << elementVal << " : $$uint = OpCopyObject %vecSize; "; - } - else - { - spirv << "%" << elementVal << " : $$uint = OpCompositeExtract %vecSize " << componentId << "; "; - } - convertAndStore(elementVal.getUnownedSlice(), destParam); - }; + spirv << "%" << elementVal << " : $$uint = OpCopyObject %vecSize; "; + } + else + { + spirv << "%" << elementVal << " : $$uint = OpCompositeExtract %vecSize " + << componentId << "; "; + } + convertAndStore(elementVal.getUnownedSlice(), destParam); + }; switch (baseShape) { - case SLANG_TEXTURE_1D: - extractSizeComponent(0, "width"); - break; + case SLANG_TEXTURE_1D: extractSizeComponent(0, "width"); break; case SLANG_TEXTURE_2D: case SLANG_TEXTURE_CUBE: @@ -434,9 +442,7 @@ void TextureTypeInfo::writeGetDimensionFunctions() extractSizeComponent(2, "depth"); break; - default: - assert(!"unexpected"); - break; + default: assert(!"unexpected"); break; } if (isArray) @@ -479,11 +485,15 @@ void TextureTypeInfo::writeGetDimensionFunctions() sb << " __glsl_extension(GL_EXT_samplerless_texture_functions)\n"; sb << " [require(cpp"; - if (glsl.getLength()) sb << "_glsl"; + if (glsl.getLength()) + sb << "_glsl"; sb << "_hlsl"; - if (metal.getLength()) sb << "_metal"; - if (spirvDefault.getLength() && spirvCombined.getLength()) sb << "_spirv"; - if (wgsl.getLength()) sb << "_wgsl"; + if (metal.getLength()) + sb << "_metal"; + if (spirvDefault.getLength() && spirvCombined.getLength()) + sb << "_spirv"; + if (wgsl.getLength()) + sb << "_wgsl"; sb << ", texture_sm_4_1)]\n"; writeFunc( @@ -502,4 +512,4 @@ void TextureTypeInfo::writeGetDimensionFunctions() } } -} +} // namespace Slang diff --git a/source/slang/slang-core-module-textures.h b/source/slang/slang-core-module-textures.h index a521a44d3..7262327cb 100644 --- a/source/slang/slang-core-module-textures.h +++ b/source/slang/slang-core-module-textures.h @@ -1,31 +1,33 @@ #pragma once +#include "../core/slang-string.h" #include "slang-ir.h" #include "slang-type-system-shared.h" -#include "../core/slang-string.h" namespace Slang { -static const struct BaseTextureShapeInfo { - char const* shapeName; - SlangResourceShape baseShape; - int coordCount; +static const struct BaseTextureShapeInfo +{ + char const* shapeName; + SlangResourceShape baseShape; + int coordCount; } kBaseTextureShapes[] = { - { "1D", SLANG_TEXTURE_1D, 1 }, - { "2D", SLANG_TEXTURE_2D, 2 }, - { "3D", SLANG_TEXTURE_3D, 3 }, - { "Cube", SLANG_TEXTURE_CUBE, 3 }, + {"1D", SLANG_TEXTURE_1D, 1}, + {"2D", SLANG_TEXTURE_2D, 2}, + {"3D", SLANG_TEXTURE_3D, 3}, + {"Cube", SLANG_TEXTURE_CUBE, 3}, }; -static const struct BaseTextureAccessInfo { +static const struct BaseTextureAccessInfo +{ char const* name; SlangResourceAccess access; } kBaseTextureAccessLevels[] = { - { "", SLANG_RESOURCE_ACCESS_READ }, - { "RW", SLANG_RESOURCE_ACCESS_READ_WRITE }, - { "RasterizerOrdered", SLANG_RESOURCE_ACCESS_RASTER_ORDERED }, - { "Feedback", SLANG_RESOURCE_ACCESS_FEEDBACK }, + {"", SLANG_RESOURCE_ACCESS_READ}, + {"RW", SLANG_RESOURCE_ACCESS_READ_WRITE}, + {"RasterizerOrdered", SLANG_RESOURCE_ACCESS_RASTER_ORDERED}, + {"Feedback", SLANG_RESOURCE_ACCESS_FEEDBACK}, }; struct TextureTypeInfo @@ -70,8 +72,7 @@ public: const String& spirvRWDefault, const String& spirvCombined, const String& metal, - const String& wgsl - ); + const String& wgsl); void writeFuncWithSig( const char* funcName, const String& sig, @@ -82,8 +83,7 @@ public: const String& cuda = String{}, const String& metal = String{}, const String& wgsl = String{}, - const ReadNoneMode readNoneMode = ReadNoneMode::Never - ); + const ReadNoneMode readNoneMode = ReadNoneMode::Never); void writeFunc( const char* returnType, const char* funcName, @@ -95,11 +95,10 @@ public: const String& cuda = String{}, const String& metal = String{}, const String& wgsl = String{}, - const ReadNoneMode readNoneMode = ReadNoneMode::Never - ); + const ReadNoneMode readNoneMode = ReadNoneMode::Never); // A pointer to a string representing the current level of indentation const char* i; }; -} +} // namespace Slang diff --git a/source/slang/slang-core-module.cpp b/source/slang/slang-core-module.cpp index 5b163e594..df53bbf42 100644 --- a/source/slang/slang-core-module.cpp +++ b/source/slang/slang-core-module.cpp @@ -1,6 +1,6 @@ +#include "../core/slang-string-util.h" #include "slang-compiler.h" #include "slang-ir.h" -#include "../core/slang-string-util.h" #define STRINGIZE(x) STRINGIZE2(x) #define STRINGIZE2(x) #x @@ -8,18 +8,19 @@ namespace Slang { - String Session::getCoreModulePath() +String Session::getCoreModulePath() +{ + if (coreModulePath.getLength() == 0) { - if(coreModulePath.getLength() == 0) - { - // Make sure we have a line of text from __FILE__, that we'll extract the filename from - List lines; - StringUtil::calcLines(UnownedStringSlice::fromLiteral(__FILE__), lines); - SLANG_ASSERT(lines.getCount() > 0 && lines[0].getLength() > 0); + // Make sure we have a line of text from __FILE__, that we'll extract the filename from + List lines; + StringUtil::calcLines(UnownedStringSlice::fromLiteral(__FILE__), lines); + SLANG_ASSERT(lines.getCount() > 0 && lines[0].getLength() > 0); - // Make the path just the filename to remove issues around path being included on different targets - coreModulePath = Path::getFileName(lines[0]); - } - return coreModulePath; + // Make the path just the filename to remove issues around path being included on different + // targets + coreModulePath = Path::getFileName(lines[0]); } + return coreModulePath; } +} // namespace Slang diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 48b296ce3..f37d6a8b6 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -23,9 +23,24 @@ // DIAGNOSTIC(-1, Note, alsoSeePipelineDefinition, "also see pipeline definition") -DIAGNOSTIC(-1, Note, implicitParameterMatchingFailedBecauseNameNotAccessible, "implicit parameter matching failed because the component of the same name is not accessible from '$0'.\ncheck if you have declared necessary requirements and properly used the 'public' qualifier.") -DIAGNOSTIC(-1, Note, implicitParameterMatchingFailedBecauseShaderDoesNotDefineComponent, "implicit parameter matching failed because shader '$0' does not define component '$1'.") -DIAGNOSTIC(-1, Note, implicitParameterMatchingFailedBecauseTypeMismatch, "implicit parameter matching failed because the component of the same name does not match parameter type '$0'.") +DIAGNOSTIC( + -1, + Note, + implicitParameterMatchingFailedBecauseNameNotAccessible, + "implicit parameter matching failed because the component of the same name is not accessible " + "from '$0'.\ncheck if you have declared necessary requirements and properly used the 'public' " + "qualifier.") +DIAGNOSTIC( + -1, + Note, + implicitParameterMatchingFailedBecauseShaderDoesNotDefineComponent, + "implicit parameter matching failed because shader '$0' does not define component '$1'.") +DIAGNOSTIC( + -1, + Note, + implicitParameterMatchingFailedBecauseTypeMismatch, + "implicit parameter matching failed because the component of the same name does not match " + "parameter type '$0'.") DIAGNOSTIC(-1, Note, noteShaderIsTargetingPipeine, "shader '$0' is targeting pipeline '$1'") DIAGNOSTIC(-1, Note, seeDefinitionOf, "see definition of '$0'") DIAGNOSTIC(-1, Note, seeConstantBufferDefinition, "see constant buffer definition.") @@ -35,15 +50,28 @@ DIAGNOSTIC(-1, Note, seeDefinitionOfShader, "see definition of shader '$0'") DIAGNOSTIC(-1, Note, seeInclusionOf, "see inclusion of '$0'") DIAGNOSTIC(-1, Note, seeModuleBeingUsedIn, "see module '$0' being used in '$1'") DIAGNOSTIC(-1, Note, seePipelineRequirementDefinition, "see pipeline requirement definition") -DIAGNOSTIC(-1, Note, seePotentialDefinitionOfComponent, "see potential definition of component '$0'") +DIAGNOSTIC( + -1, + Note, + seePotentialDefinitionOfComponent, + "see potential definition of component '$0'") DIAGNOSTIC(-1, Note, seePreviousDefinition, "see previous definition") DIAGNOSTIC(-1, Note, seePreviousDefinitionOf, "see previous definition of '$0'") DIAGNOSTIC(-1, Note, seeRequirementDeclaration, "see requirement declaration") -DIAGNOSTIC(-1, Note, doYouForgetToMakeComponentAccessible, "do you forget to make component '$0' acessible from '$1' (missing public qualifier)?") +DIAGNOSTIC( + -1, + Note, + doYouForgetToMakeComponentAccessible, + "do you forget to make component '$0' acessible from '$1' (missing public qualifier)?") DIAGNOSTIC(-1, Note, seeDeclarationOf, "see declaration of '$0'") -DIAGNOSTIC(-1, Note, seeDeclarationOfInterfaceRequirement, "see interface requirement declaration of '$0'") -// An alternate wording of the above note, emphasing the position rather than content of the declaration. +DIAGNOSTIC( + -1, + Note, + seeDeclarationOfInterfaceRequirement, + "see interface requirement declaration of '$0'") +// An alternate wording of the above note, emphasing the position rather than content of the +// declaration. DIAGNOSTIC(-1, Note, declaredHere, "declared here") DIAGNOSTIC(-1, Note, seeOtherDeclarationOf, "see other declaration of '$0'") DIAGNOSTIC(-1, Note, seePreviousDeclarationOf, "see previous declaration of '$0'") @@ -55,98 +83,214 @@ DIAGNOSTIC(-1, Note, entryPointCandidate, "see candidate declaration for entry p // 0xxxx - Command line and interaction with host platform APIs. // -DIAGNOSTIC( 1, Error, cannotOpenFile, "cannot open file '$0'.") -DIAGNOSTIC( 2, Error, cannotFindFile, "cannot find file '$0'.") -DIAGNOSTIC( 2, Error, unsupportedCompilerMode, "unsupported compiler mode.") -DIAGNOSTIC( 4, Error, cannotWriteOutputFile, "cannot write output file '$0'.") -DIAGNOSTIC( 5, Error, failedToLoadDynamicLibrary, "failed to load dynamic library '$0'") -DIAGNOSTIC( 6, Error, tooManyOutputPathsSpecified, "$0 output paths specified, but only $1 entry points given") - -DIAGNOSTIC( 7, Error, noOutputPathSpecifiedForEntryPoint, - "no output path specified for entry point '$0' (the '-o' option for an entry point must precede the corresponding '-entry')") - -DIAGNOSTIC( 8, Error, outputPathsImplyDifferentFormats, +DIAGNOSTIC(1, Error, cannotOpenFile, "cannot open file '$0'.") +DIAGNOSTIC(2, Error, cannotFindFile, "cannot find file '$0'.") +DIAGNOSTIC(2, Error, unsupportedCompilerMode, "unsupported compiler mode.") +DIAGNOSTIC(4, Error, cannotWriteOutputFile, "cannot write output file '$0'.") +DIAGNOSTIC(5, Error, failedToLoadDynamicLibrary, "failed to load dynamic library '$0'") +DIAGNOSTIC( + 6, + Error, + tooManyOutputPathsSpecified, + "$0 output paths specified, but only $1 entry points given") + +DIAGNOSTIC( + 7, + Error, + noOutputPathSpecifiedForEntryPoint, + "no output path specified for entry point '$0' (the '-o' option for an entry point must " + "precede the corresponding '-entry')") + +DIAGNOSTIC( + 8, + Error, + outputPathsImplyDifferentFormats, "the output paths '$0' and '$1' require different code-generation targets") -DIAGNOSTIC( 10, Error, explicitOutputPathsAndMultipleTargets, "canot use both explicit output paths ('-o') and multiple targets ('-target')") -DIAGNOSTIC( 12, Error, cannotDeduceSourceLanguage, "can't deduce language for input file '$0'") -DIAGNOSTIC( 13, Error, unknownCodeGenerationTarget, "unknown code generation target '$0'") -DIAGNOSTIC( 14, Error, unknownProfile, "unknown profile '$0'") -DIAGNOSTIC( 15, Error, unknownStage, "unknown stage '$0'") -DIAGNOSTIC( 16, Error, unknownPassThroughTarget, "unknown pass-through target '$0'") -DIAGNOSTIC( 17, Error, unknownCommandLineOption, "unknown command-line option '$0'") -DIAGNOSTIC( 19, Error, unknownSourceLanguage, "unknown source language '$0'") - -DIAGNOSTIC( 20, Error, entryPointsNeedToBeAssociatedWithTranslationUnits, "when using multiple source files, entry points must be specified after their corresponding source file(s)") -DIAGNOSTIC( 22, Error, unknownDownstreamCompiler, "unknown downstream compiler '$0'") - -DIAGNOSTIC( 26, Error, unknownOptimiziationLevel, "unknown optimization level '$0'") - -DIAGNOSTIC( 28, Error, unableToGenerateCodeForTarget, "unable to generate code for target '$0'") - -DIAGNOSTIC( 30, Warning, sameStageSpecifiedMoreThanOnce, "the stage '$0' was specified more than once for entry point '$1'") -DIAGNOSTIC( 31, Error, conflictingStagesForEntryPoint, "conflicting stages have been specified for entry point '$0'") -DIAGNOSTIC( 32, Warning, explicitStageDoesntMatchImpliedStage, "the stage specified for entry point '$0' ('$1') does not match the stage implied by the source file name ('$2')") -DIAGNOSTIC( 33, Error, stageSpecificationIgnoredBecauseNoEntryPoints, "one or more stages were specified, but no entry points were specified with '-entry'") -DIAGNOSTIC( 34, Error, stageSpecificationIgnoredBecauseBeforeAllEntryPoints, "when compiling multiple entry points, any '-stage' options must follow the '-entry' option that they apply to") -DIAGNOSTIC( 35, Error, noStageSpecifiedInPassThroughMode, "no stage was specified for entry point '$0'; when using the '-pass-through' option, stages must be fully specified on the command line") -DIAGNOSTIC( 36, Error, expectingAnInteger, "expecting an integer value") - -DIAGNOSTIC( 40, Warning, sameProfileSpecifiedMoreThanOnce, "the '$0' was specified more than once for target '$0'") -DIAGNOSTIC( 41, Error, conflictingProfilesSpecifiedForTarget, "conflicting profiles have been specified for target '$0'") - -DIAGNOSTIC( 42, Error, profileSpecificationIgnoredBecauseNoTargets, "a '-profile' option was specified, but no target was specified with '-target'") -DIAGNOSTIC( 43, Error, profileSpecificationIgnoredBecauseBeforeAllTargets, "when using multiple targets, any '-profile' option must follow the '-target' it applies to") - -DIAGNOSTIC( 42, Error, targetFlagsIgnoredBecauseNoTargets, "target options were specified, but no target was specified with '-target'") -DIAGNOSTIC( 43, Error, targetFlagsIgnoredBecauseBeforeAllTargets, "when using multiple targets, any target options must follow the '-target' they apply to") - -DIAGNOSTIC( 50, Error, duplicateTargets, "the target '$0' has been specified more than once") - -DIAGNOSTIC( 51, Error, unhandledLanguageForSourceEmbedding, "unhandled source language for source embedding") - -DIAGNOSTIC( 60, Error, cannotDeduceOutputFormatFromPath, "cannot infer an output format from the output path '$0'") -DIAGNOSTIC( 61, Error, cannotMatchOutputFileToTarget, "no specified '-target' option matches the output path '$0', which implies the '$1' format") - -DIAGNOSTIC( 62, Error, unknownCommandLineValue, "unknown value for option. Valid values are '$0'") -DIAGNOSTIC( 63, Error, unknownHelpCategory, "unknown help category") - -DIAGNOSTIC( 70, Error, cannotMatchOutputFileToEntryPoint, "the output path '$0' is not associated with any entry point; a '-o' option for a compiled kernel must follow the '-entry' option for its corresponding entry point") - -DIAGNOSTIC( 80, Error, duplicateOutputPathsForEntryPointAndTarget, "multiple output paths have been specified entry point '$0' on target '$1'") -DIAGNOSTIC( 81, Error, duplicateOutputPathsForTarget, "multiple output paths have been specified for target '$0'") -DIAGNOSTIC( 82, Error, duplicateDependencyOutputPaths, "the -dep argument can only be specified once") - -DIAGNOSTIC( 82, Error, unableToWriteReproFile, "unable to write repro file '%0'") -DIAGNOSTIC( 83, Error, unableToWriteModuleContainer, "unable to write module container '%0'") -DIAGNOSTIC( 84, Error, unableToReadModuleContainer, "unable to read module container '%0'") -DIAGNOSTIC( 85, Error, unableToAddReferenceToModuleContainer, "unable to add a reference to a module container") -DIAGNOSTIC( 86, Error, unableToCreateModuleContainer, "unable to create module container") - -DIAGNOSTIC( 87, Error, unableToSetDefaultDownstreamCompiler, "unable to set default downstream compiler for source language '%0' to '%1'") - -DIAGNOSTIC( 88, Error, unknownArchiveType, "archive type '%0' is unknown") -DIAGNOSTIC( 89, Error, expectingSlangRiffContainer, "expecting a slang riff container") -DIAGNOSTIC( 90, Error, incompatibleRiffSemanticVersion, "incompatible riff semantic version %0 expecting %1") -DIAGNOSTIC( 91, Error, riffHashMismatch, "riff hash mismatch - incompatible riff") -DIAGNOSTIC( 92, Error, unableToCreateDirectory, "unable to create directory '$0'") -DIAGNOSTIC( 93, Error, unableExtractReproToDirectory, "unable to extract repro to directory '$0'") -DIAGNOSTIC( 94, Error, unableToReadRiff, "unable to read as 'riff'/not a 'riff' file") - -DIAGNOSTIC( 95, Error, unknownLibraryKind, "unknown library kind '$0'") -DIAGNOSTIC( 96, Error, kindNotLinkable, "not a known linkable kind '$0'") -DIAGNOSTIC( 97, Error, libraryDoesNotExist, "library '$0' does not exist") -DIAGNOSTIC( 98, Error, cannotAccessAsBlob, "cannot access as a blob") -DIAGNOSTIC( 99, Error, unknownDebugOption, "unknown debug option, known options are ($0)") +DIAGNOSTIC( + 10, + Error, + explicitOutputPathsAndMultipleTargets, + "canot use both explicit output paths ('-o') and multiple targets ('-target')") +DIAGNOSTIC(12, Error, cannotDeduceSourceLanguage, "can't deduce language for input file '$0'") +DIAGNOSTIC(13, Error, unknownCodeGenerationTarget, "unknown code generation target '$0'") +DIAGNOSTIC(14, Error, unknownProfile, "unknown profile '$0'") +DIAGNOSTIC(15, Error, unknownStage, "unknown stage '$0'") +DIAGNOSTIC(16, Error, unknownPassThroughTarget, "unknown pass-through target '$0'") +DIAGNOSTIC(17, Error, unknownCommandLineOption, "unknown command-line option '$0'") +DIAGNOSTIC(19, Error, unknownSourceLanguage, "unknown source language '$0'") + +DIAGNOSTIC( + 20, + Error, + entryPointsNeedToBeAssociatedWithTranslationUnits, + "when using multiple source files, entry points must be specified after their corresponding " + "source file(s)") +DIAGNOSTIC(22, Error, unknownDownstreamCompiler, "unknown downstream compiler '$0'") + +DIAGNOSTIC(26, Error, unknownOptimiziationLevel, "unknown optimization level '$0'") + +DIAGNOSTIC(28, Error, unableToGenerateCodeForTarget, "unable to generate code for target '$0'") + +DIAGNOSTIC( + 30, + Warning, + sameStageSpecifiedMoreThanOnce, + "the stage '$0' was specified more than once for entry point '$1'") +DIAGNOSTIC( + 31, + Error, + conflictingStagesForEntryPoint, + "conflicting stages have been specified for entry point '$0'") +DIAGNOSTIC( + 32, + Warning, + explicitStageDoesntMatchImpliedStage, + "the stage specified for entry point '$0' ('$1') does not match the stage implied by the " + "source file name ('$2')") +DIAGNOSTIC( + 33, + Error, + stageSpecificationIgnoredBecauseNoEntryPoints, + "one or more stages were specified, but no entry points were specified with '-entry'") +DIAGNOSTIC( + 34, + Error, + stageSpecificationIgnoredBecauseBeforeAllEntryPoints, + "when compiling multiple entry points, any '-stage' options must follow the '-entry' option " + "that they apply to") +DIAGNOSTIC( + 35, + Error, + noStageSpecifiedInPassThroughMode, + "no stage was specified for entry point '$0'; when using the '-pass-through' option, stages " + "must be fully specified on the command line") +DIAGNOSTIC(36, Error, expectingAnInteger, "expecting an integer value") + +DIAGNOSTIC( + 40, + Warning, + sameProfileSpecifiedMoreThanOnce, + "the '$0' was specified more than once for target '$0'") +DIAGNOSTIC( + 41, + Error, + conflictingProfilesSpecifiedForTarget, + "conflicting profiles have been specified for target '$0'") + +DIAGNOSTIC( + 42, + Error, + profileSpecificationIgnoredBecauseNoTargets, + "a '-profile' option was specified, but no target was specified with '-target'") +DIAGNOSTIC( + 43, + Error, + profileSpecificationIgnoredBecauseBeforeAllTargets, + "when using multiple targets, any '-profile' option must follow the '-target' it applies to") + +DIAGNOSTIC( + 42, + Error, + targetFlagsIgnoredBecauseNoTargets, + "target options were specified, but no target was specified with '-target'") +DIAGNOSTIC( + 43, + Error, + targetFlagsIgnoredBecauseBeforeAllTargets, + "when using multiple targets, any target options must follow the '-target' they apply to") + +DIAGNOSTIC(50, Error, duplicateTargets, "the target '$0' has been specified more than once") + +DIAGNOSTIC( + 51, + Error, + unhandledLanguageForSourceEmbedding, + "unhandled source language for source embedding") + +DIAGNOSTIC( + 60, + Error, + cannotDeduceOutputFormatFromPath, + "cannot infer an output format from the output path '$0'") +DIAGNOSTIC( + 61, + Error, + cannotMatchOutputFileToTarget, + "no specified '-target' option matches the output path '$0', which implies the '$1' format") + +DIAGNOSTIC(62, Error, unknownCommandLineValue, "unknown value for option. Valid values are '$0'") +DIAGNOSTIC(63, Error, unknownHelpCategory, "unknown help category") + +DIAGNOSTIC( + 70, + Error, + cannotMatchOutputFileToEntryPoint, + "the output path '$0' is not associated with any entry point; a '-o' option for a compiled " + "kernel must follow the '-entry' option for its corresponding entry point") + +DIAGNOSTIC( + 80, + Error, + duplicateOutputPathsForEntryPointAndTarget, + "multiple output paths have been specified entry point '$0' on target '$1'") +DIAGNOSTIC( + 81, + Error, + duplicateOutputPathsForTarget, + "multiple output paths have been specified for target '$0'") +DIAGNOSTIC( + 82, + Error, + duplicateDependencyOutputPaths, + "the -dep argument can only be specified once") + +DIAGNOSTIC(82, Error, unableToWriteReproFile, "unable to write repro file '%0'") +DIAGNOSTIC(83, Error, unableToWriteModuleContainer, "unable to write module container '%0'") +DIAGNOSTIC(84, Error, unableToReadModuleContainer, "unable to read module container '%0'") +DIAGNOSTIC( + 85, + Error, + unableToAddReferenceToModuleContainer, + "unable to add a reference to a module container") +DIAGNOSTIC(86, Error, unableToCreateModuleContainer, "unable to create module container") + +DIAGNOSTIC( + 87, + Error, + unableToSetDefaultDownstreamCompiler, + "unable to set default downstream compiler for source language '%0' to '%1'") + +DIAGNOSTIC(88, Error, unknownArchiveType, "archive type '%0' is unknown") +DIAGNOSTIC(89, Error, expectingSlangRiffContainer, "expecting a slang riff container") +DIAGNOSTIC( + 90, + Error, + incompatibleRiffSemanticVersion, + "incompatible riff semantic version %0 expecting %1") +DIAGNOSTIC(91, Error, riffHashMismatch, "riff hash mismatch - incompatible riff") +DIAGNOSTIC(92, Error, unableToCreateDirectory, "unable to create directory '$0'") +DIAGNOSTIC(93, Error, unableExtractReproToDirectory, "unable to extract repro to directory '$0'") +DIAGNOSTIC(94, Error, unableToReadRiff, "unable to read as 'riff'/not a 'riff' file") + +DIAGNOSTIC(95, Error, unknownLibraryKind, "unknown library kind '$0'") +DIAGNOSTIC(96, Error, kindNotLinkable, "not a known linkable kind '$0'") +DIAGNOSTIC(97, Error, libraryDoesNotExist, "library '$0' does not exist") +DIAGNOSTIC(98, Error, cannotAccessAsBlob, "cannot access as a blob") +DIAGNOSTIC(99, Error, unknownDebugOption, "unknown debug option, known options are ($0)") // // 001xx - Downstream Compilers // -DIAGNOSTIC( 100, Error, failedToLoadDownstreamCompiler, "failed to load downstream compiler '$0'") -DIAGNOSTIC( 101, Error, downstreamCompilerDoesntSupportWholeProgramCompilation, "downstream compiler '$0' doesn't support whole program compilation") -DIAGNOSTIC( 102, Note, downstreamCompileTime, "downstream compile time: $0s") -DIAGNOSTIC( 103, Note, performanceBenchmarkResult, "compiler performance benchmark:\n$0") +DIAGNOSTIC(100, Error, failedToLoadDownstreamCompiler, "failed to load downstream compiler '$0'") +DIAGNOSTIC( + 101, + Error, + downstreamCompilerDoesntSupportWholeProgramCompilation, + "downstream compiler '$0' doesn't support whole program compilation") +DIAGNOSTIC(102, Note, downstreamCompileTime, "downstream compile time: $0s") +DIAGNOSTIC(103, Note, performanceBenchmarkResult, "compiler performance benchmark:\n$0") DIAGNOSTIC(99999, Note, noteFailedToLoadDynamicLibrary, "failed to load dynamic library '$0'") // @@ -154,9 +298,13 @@ DIAGNOSTIC(99999, Note, noteFailedToLoadDynamicLibrary, "failed to load dynamic // // 150xx - conditionals -DIAGNOSTIC(15000, Error, endOfFileInPreprocessorConditional, "end of file encountered during preprocessor conditional") +DIAGNOSTIC( + 15000, + Error, + endOfFileInPreprocessorConditional, + "end of file encountered during preprocessor conditional") DIAGNOSTIC(15001, Error, directiveWithoutIf, "'$0' directive without '#if'") -DIAGNOSTIC(15002, Error, directiveAfterElse , "'$0' directive without '#if'") +DIAGNOSTIC(15002, Error, directiveAfterElse, "'$0' directive without '#if'") DIAGNOSTIC(-1, Note, seeDirective, "see '$0' directive") @@ -164,17 +312,41 @@ DIAGNOSTIC(-1, Note, seeDirective, "see '$0' directive") DIAGNOSTIC(15100, Error, expectedPreprocessorDirectiveName, "expected preprocessor directive name") DIAGNOSTIC(15101, Error, unknownPreprocessorDirective, "unknown preprocessor directive '$0'") DIAGNOSTIC(15102, Error, expectedTokenInPreprocessorDirective, "expected '$0' in '$1' directive") -DIAGNOSTIC(15102, Error, expected2TokensInPreprocessorDirective, "expected '$0' or '$1' in '$2' directive") -DIAGNOSTIC(15103, Error, unexpectedTokensAfterDirective, "unexpected tokens following '$0' directive") +DIAGNOSTIC( + 15102, + Error, + expected2TokensInPreprocessorDirective, + "expected '$0' or '$1' in '$2' directive") +DIAGNOSTIC( + 15103, + Error, + unexpectedTokensAfterDirective, + "unexpected tokens following '$0' directive") // 152xx - preprocessor expressions -DIAGNOSTIC(15200, Error, expectedTokenInPreprocessorExpression, "expected '$0' in preprocessor expression") -DIAGNOSTIC(15201, Error, syntaxErrorInPreprocessorExpression, "syntax error in preprocessor expression") -DIAGNOSTIC(15202, Error, divideByZeroInPreprocessorExpression, "division by zero in preprocessor expression") +DIAGNOSTIC( + 15200, + Error, + expectedTokenInPreprocessorExpression, + "expected '$0' in preprocessor expression") +DIAGNOSTIC( + 15201, + Error, + syntaxErrorInPreprocessorExpression, + "syntax error in preprocessor expression") +DIAGNOSTIC( + 15202, + Error, + divideByZeroInPreprocessorExpression, + "division by zero in preprocessor expression") DIAGNOSTIC(15203, Error, expectedTokenInDefinedExpression, "expected '$0' in 'defined' expression") DIAGNOSTIC(15204, Warning, directiveExpectsExpression, "'$0' directive requires an expression") -DIAGNOSTIC(15205, Warning, undefinedIdentifierInPreprocessorExpression, "undefined identifier '$0' in preprocessor expression will evaluate to zero") +DIAGNOSTIC( + 15205, + Warning, + undefinedIdentifierInPreprocessorExpression, + "undefined identifier '$0' in preprocessor expression will evaluate to zero") DIAGNOSTIC(15206, Error, expectedIntegralVersionNumber, "Expected integer for #version number") DIAGNOSTIC(-1, Note, seeOpeningToken, "see opening '$0'") @@ -183,7 +355,11 @@ DIAGNOSTIC(-1, Note, seeOpeningToken, "see opening '$0'") DIAGNOSTIC(15300, Error, includeFailed, "failed to find include file '$0'") DIAGNOSTIC(15301, Error, importFailed, "failed to find imported file '$0'") DIAGNOSTIC(-1, Error, noIncludeHandlerSpecified, "no `#include` handler was specified") -DIAGNOSTIC(15302, Error, noUniqueIdentity, "`#include` handler didn't generate a unique identity for file '$0'") +DIAGNOSTIC( + 15302, + Error, + noUniqueIdentity, + "`#include` handler didn't generate a unique identity for file '$0'") // 154xx - macro definition @@ -194,25 +370,49 @@ DIAGNOSTIC(15404, Warning, builtinMacroRedefinition, "Redefinition of builtin ma DIAGNOSTIC(15405, Error, tokenPasteAtStart, "'##' is not allowed at the start of a macro body") DIAGNOSTIC(15406, Error, tokenPasteAtEnd, "'##' is not allowed at the end of a macro body") -DIAGNOSTIC(15407, Error, expectedMacroParameterAfterStringize, "'#' in macro body must be followed by the name of a macro parameter") +DIAGNOSTIC( + 15407, + Error, + expectedMacroParameterAfterStringize, + "'#' in macro body must be followed by the name of a macro parameter") DIAGNOSTIC(15408, Error, duplicateMacroParameterName, "redefinition of macro parameter '$0'") -DIAGNOSTIC(15409, Error, variadicMacroParameterMustBeLast, "a variadic macro parameter is only allowed at the end of the parameter list") +DIAGNOSTIC( + 15409, + Error, + variadicMacroParameterMustBeLast, + "a variadic macro parameter is only allowed at the end of the parameter list") // 155xx - macro expansion DIAGNOSTIC(15500, Warning, expectedTokenInMacroArguments, "expected '$0' in macro invocation") -DIAGNOSTIC(15501, Error, wrongNumberOfArgumentsToMacro, "wrong number of arguments to macro (expected $0, got $1)") -DIAGNOSTIC(15502, Error, errorParsingToMacroInvocationArgument, "error parsing macro '$0' invocation argument to '$1'") - -DIAGNOSTIC(15503, Warning, invalidTokenPasteResult, "toking pasting with '##' resulted in the invalid token '$0'") +DIAGNOSTIC( + 15501, + Error, + wrongNumberOfArgumentsToMacro, + "wrong number of arguments to macro (expected $0, got $1)") +DIAGNOSTIC( + 15502, + Error, + errorParsingToMacroInvocationArgument, + "error parsing macro '$0' invocation argument to '$1'") + +DIAGNOSTIC( + 15503, + Warning, + invalidTokenPasteResult, + "toking pasting with '##' resulted in the invalid token '$0'") // 156xx - pragmas DIAGNOSTIC(15600, Error, expectedPragmaDirectiveName, "expected a name after '#pragma'") DIAGNOSTIC(15601, Warning, unknownPragmaDirectiveIgnored, "ignoring unknown directive '#pragma $0'") -DIAGNOSTIC(15602, Warning, pragmaOnceIgnored, "pragma once was ignored - this is typically because is not placed in an include") +DIAGNOSTIC( + 15602, + Warning, + pragmaOnceIgnored, + "pragma once was ignored - this is typically because is not placed in an include") // 159xx - user-defined error/warning -DIAGNOSTIC(15900, Error, userDefinedError, "#error: $0") -DIAGNOSTIC(15901, Warning, userDefinedWarning, "#warning: $0") +DIAGNOSTIC(15900, Error, userDefinedError, "#error: $0") +DIAGNOSTIC(15901, Warning, userDefinedWarning, "#warning: $0") // // 2xxxx - Parsing @@ -232,16 +432,42 @@ DIAGNOSTIC(20001, Error, typeNameExpectedBut, "unexpected $0, expected type name DIAGNOSTIC(20001, Error, typeNameExpectedButEOF, "type name expected but end of file encountered.") DIAGNOSTIC(20001, Error, unexpectedEOF, " Unexpected end of file.") DIAGNOSTIC(20002, Error, syntaxError, "syntax error.") -DIAGNOSTIC(20004, Error, unexpectedTokenExpectedComponentDefinition, "unexpected token '$0', only component definitions are allowed in a shader scope.") +DIAGNOSTIC( + 20004, + Error, + unexpectedTokenExpectedComponentDefinition, + "unexpected token '$0', only component definitions are allowed in a shader scope.") DIAGNOSTIC(20008, Error, invalidOperator, "invalid operator '$0'.") DIAGNOSTIC(20011, Error, unexpectedColon, "unexpected ':'.") -DIAGNOSTIC(20012, Error, invalidSPIRVVersion, "Expecting SPIR-V version as either 'major.minor', or quoted if has patch (eg for SPIR-V 1.2, '1.2' or \"1.2\"')") -DIAGNOSTIC(20013, Error, invalidCUDASMVersion, "Expecting CUDA SM version as either 'major.minor', or quoted if has patch (eg for '7.0' or \"7.0\"')") -DIAGNOSTIC(20014, Error, classIsReservedKeyword, "'class' is a reserved keyword in this context; use 'struct' instead.") +DIAGNOSTIC( + 20012, + Error, + invalidSPIRVVersion, + "Expecting SPIR-V version as either 'major.minor', or quoted if has patch (eg for SPIR-V 1.2, " + "'1.2' or \"1.2\"')") +DIAGNOSTIC( + 20013, + Error, + invalidCUDASMVersion, + "Expecting CUDA SM version as either 'major.minor', or quoted if has patch (eg for '7.0' or " + "\"7.0\"')") +DIAGNOSTIC( + 20014, + Error, + classIsReservedKeyword, + "'class' is a reserved keyword in this context; use 'struct' instead.") DIAGNOSTIC(20015, Error, unknownSPIRVCapability, "unknown SPIR-V capability '$0'.") -DIAGNOSTIC(20016, Error, missingLayoutBindingModifier, "Expecting 'binding' modifier in the layout qualifier here") - -DIAGNOSTIC(20101, Warning, unintendedEmptyStatement, "potentially unintended empty statement at this location; use {} instead.") +DIAGNOSTIC( + 20016, + Error, + missingLayoutBindingModifier, + "Expecting 'binding' modifier in the layout qualifier here") + +DIAGNOSTIC( + 20101, + Warning, + unintendedEmptyStatement, + "potentially unintended empty statement at this location; use {} instead.") DIAGNOSTIC(30102, Error, declNotAllowed, "$0 is not allowed here.") @@ -249,20 +475,53 @@ DIAGNOSTIC(30102, Error, declNotAllowed, "$0 is not allowed here.") DIAGNOSTIC(29000, Error, snippetParsingFailed, "unable to parse target intrinsic snippet: $0") DIAGNOSTIC(29100, Error, unrecognizedSPIRVOpcode, "unrecognized spirv opcode: $0") -DIAGNOSTIC(29101, Error, misplacedResultIdMarker, "the result-id marker must only be used in the last instruction of a spriv_asm expression") -DIAGNOSTIC(29102, Note, considerOpCopyObject, "consider adding an OpCopyObject instruction to the end of the spirv_asm expression") -DIAGNOSTIC(29103, Note, noSuchAddress, "unable to take the address of this address-of asm operand") -DIAGNOSTIC(29104, Error, spirvInstructionWithoutResultId, "cannot use this 'x = $0...' syntax because $0 does not have a operand") -DIAGNOSTIC(29105, Error, spirvInstructionWithoutResultTypeId, "cannot use this 'x : = $0...' syntax because $0 does not have a operand") +DIAGNOSTIC( + 29101, + Error, + misplacedResultIdMarker, + "the result-id marker must only be used in the last instruction of a spriv_asm expression") +DIAGNOSTIC( + 29102, + Note, + considerOpCopyObject, + "consider adding an OpCopyObject instruction to the end of the spirv_asm expression") +DIAGNOSTIC(29103, Note, noSuchAddress, "unable to take the address of this address-of asm operand") +DIAGNOSTIC( + 29104, + Error, + spirvInstructionWithoutResultId, + "cannot use this 'x = $0...' syntax because $0 does not have a operand") +DIAGNOSTIC( + 29105, + Error, + spirvInstructionWithoutResultTypeId, + "cannot use this 'x : = $0...' syntax because $0 does not have a " + "operand") // This is a warning because we trust that people using the spirv_asm block know what they're doing -DIAGNOSTIC(29106, Warning, spirvInstructionWithTooManyOperands, "too many operands for $0 (expected max $1), did you forget a semicolon?") -DIAGNOSTIC(29107, Error, spirvUnableToResolveName, "unknown SPIR-V identifier $0, it's not a known enumerator or opcode") -DIAGNOSTIC(29108, Error, spirvNonConstantBitwiseOr, "only integer literals and enum names can appear in a bitwise or expression") +DIAGNOSTIC( + 29106, + Warning, + spirvInstructionWithTooManyOperands, + "too many operands for $0 (expected max $1), did you forget a semicolon?") +DIAGNOSTIC( + 29107, + Error, + spirvUnableToResolveName, + "unknown SPIR-V identifier $0, it's not a known enumerator or opcode") +DIAGNOSTIC( + 29108, + Error, + spirvNonConstantBitwiseOr, + "only integer literals and enum names can appear in a bitwise or expression") DIAGNOSTIC(29109, Error, spirvOperandRange, "Literal ints must be in the range 0 to 0xffffffff") DIAGNOSTIC(29110, Error, unknownTargetName, "unknown target name '$0'") -DIAGNOSTIC(29111, Error, spirvInvalidTruncate, "__truncate has been given a source smaller than its target") +DIAGNOSTIC( + 29111, + Error, + spirvInvalidTruncate, + "__truncate has been given a source smaller than its target") // // 3xxxx - Semantic analysis @@ -271,16 +530,28 @@ DIAGNOSTIC(30002, Error, divideByZero, "divide by zero") DIAGNOSTIC(30003, Error, breakOutsideLoop, "'break' must appear inside loop or switch constructs.") DIAGNOSTIC(30004, Error, continueOutsideLoop, "'continue' must appear inside loop constructs.") DIAGNOSTIC(30005, Error, whilePredicateTypeError, "'while': expression must evaluate to int.") -DIAGNOSTIC(30006, Error, ifPredicateTypeError, "'if': expression must evaluate to int.") +DIAGNOSTIC(30006, Error, ifPredicateTypeError, "'if': expression must evaluate to int.") DIAGNOSTIC(30006, Error, returnNeedsExpression, "'return' should have an expression.") -DIAGNOSTIC(30007, Error, componentReturnTypeMismatch, "expression type '$0' does not match component's type '$1'") -DIAGNOSTIC(30007, Error, functionReturnTypeMismatch, "expression type '$0' does not match function's return type '$1'") +DIAGNOSTIC( + 30007, + Error, + componentReturnTypeMismatch, + "expression type '$0' does not match component's type '$1'") +DIAGNOSTIC( + 30007, + Error, + functionReturnTypeMismatch, + "expression type '$0' does not match function's return type '$1'") DIAGNOSTIC(30008, Error, variableNameAlreadyDefined, "variable $0 already defined.") DIAGNOSTIC(30009, Error, invalidTypeVoid, "invalid type 'void'.") DIAGNOSTIC(30010, Error, whilePredicateTypeError2, "'while': expression must evaluate to int.") DIAGNOSTIC(30011, Error, assignNonLValue, "left of '=' is not an l-value.") DIAGNOSTIC(30012, Error, noApplicationUnaryOperator, "no overload found for operator $0 ($1).") -DIAGNOSTIC(30012, Error, noOverloadFoundForBinOperatorOnTypes, "no overload found for operator $0 ($1, $2).") +DIAGNOSTIC( + 30012, + Error, + noOverloadFoundForBinOperatorOnTypes, + "no overload found for operator $0 ($1, $2).") DIAGNOSTIC(30013, Error, subscriptNonArray, "no subscript operation found for type '$0'") DIAGNOSTIC(30014, Error, subscriptIndexNonInteger, "index expression must evaluate to int.") DIAGNOSTIC(30016, Error, callOperatorNotFound, "no call operation found for type '$0'") @@ -290,306 +561,973 @@ DIAGNOSTIC(30021, Error, noApplicationFunction, "$0: no overload takes arguments DIAGNOSTIC(30022, Error, invalidTypeCast, "invalid type cast between \"$0\" and \"$1\".") DIAGNOSTIC(30023, Error, typeHasNoPublicMemberOfName, "\"$0\" does not have public member \"$1\".") DIAGNOSTIC(30025, Error, invalidArraySize, "array size must be larger than zero.") -DIAGNOSTIC(30026, Error, returnInComponentMustComeLast, "'return' can only appear as the last statement in component definition.") +DIAGNOSTIC( + 30026, + Error, + returnInComponentMustComeLast, + "'return' can only appear as the last statement in component definition.") DIAGNOSTIC(30027, Error, noMemberOfNameInType, "'$0' is not a member of '$1'.") -DIAGNOSTIC(30028, Error, forPredicateTypeError, "'for': predicate expression must evaluate to bool.") -DIAGNOSTIC(30030, Error, projectionOutsideImportOperator, "'project': invalid use outside import operator.") -DIAGNOSTIC(30031, Error, projectTypeMismatch, "'project': expression must evaluate to record type '$0'.") -DIAGNOSTIC(30033, Error, invalidTypeForLocalVariable, "cannot declare a local variable of this type.") -DIAGNOSTIC(30035, Error, componentOverloadTypeMismatch, "'$0': type of overloaded component mismatches previous definition.") +DIAGNOSTIC( + 30028, + Error, + forPredicateTypeError, + "'for': predicate expression must evaluate to bool.") +DIAGNOSTIC( + 30030, + Error, + projectionOutsideImportOperator, + "'project': invalid use outside import operator.") +DIAGNOSTIC( + 30031, + Error, + projectTypeMismatch, + "'project': expression must evaluate to record type '$0'.") +DIAGNOSTIC( + 30033, + Error, + invalidTypeForLocalVariable, + "cannot declare a local variable of this type.") +DIAGNOSTIC( + 30035, + Error, + componentOverloadTypeMismatch, + "'$0': type of overloaded component mismatches previous definition.") DIAGNOSTIC(30041, Error, bitOperationNonIntegral, "bit operation: operand must be integral type.") -DIAGNOSTIC(30043, Error, getStringHashRequiresStringLiteral, "getStringHash parameter can only accept a string literal") -DIAGNOSTIC(30047, Error, argumentExpectedLValue, "argument passed to parameter '$0' must be l-value.") -DIAGNOSTIC(30048, Error, argumentHasMoreMemoryQualifiersThanParam, "argument passed in to parameter has a memory qualifier the parameter type is missing: '$0'") - -DIAGNOSTIC(30049, Note, thisIsImmutableByDefault, "a 'this' parameter is an immutable parameter by default in Slang; apply the `[mutating]` attribute to the function declaration to opt in to a mutable `this`") -DIAGNOSTIC(30050, Error, mutatingMethodOnImmutableValue, "mutating method '$0' cannot be called on an immutable value") +DIAGNOSTIC( + 30043, + Error, + getStringHashRequiresStringLiteral, + "getStringHash parameter can only accept a string literal") +DIAGNOSTIC( + 30047, + Error, + argumentExpectedLValue, + "argument passed to parameter '$0' must be l-value.") +DIAGNOSTIC( + 30048, + Error, + argumentHasMoreMemoryQualifiersThanParam, + "argument passed in to parameter has a memory qualifier the parameter type is missing: '$0'") + +DIAGNOSTIC( + 30049, + Note, + thisIsImmutableByDefault, + "a 'this' parameter is an immutable parameter by default in Slang; apply the `[mutating]` " + "attribute to the function declaration to opt in to a mutable `this`") +DIAGNOSTIC( + 30050, + Error, + mutatingMethodOnImmutableValue, + "mutating method '$0' cannot be called on an immutable value") DIAGNOSTIC(30051, Error, invalidValueForArgument, "invalid value for argument '$0'") DIAGNOSTIC(30052, Error, invalidSwizzleExpr, "invalid swizzle pattern '$0' on type '$1'") DIAGNOSTIC(30053, Error, breakLabelNotFound, "label '$0' used as break target is not found.") -DIAGNOSTIC(30054, Error, targetLabelDoesNotMarkBreakableStmt, "invalid break target: statement labeled '$0' is not breakable.") -DIAGNOSTIC(30055, Error, useOfNonShortCircuitingOperatorInDiffFunc, "non-short-circuiting `?:` operator is not allowed in a differentiable function, use `select` instead.") -DIAGNOSTIC(30056, Warning, useOfNonShortCircuitingOperator, "non-short-circuiting `?:` operator is deprecated, use 'select' instead.") -DIAGNOSTIC(30057, Error, assignmentInPredicateExpr, "use an assignment operation as predicate expression is not allowed, wrap the assignment with '()' to clarify the intent.") +DIAGNOSTIC( + 30054, + Error, + targetLabelDoesNotMarkBreakableStmt, + "invalid break target: statement labeled '$0' is not breakable.") +DIAGNOSTIC( + 30055, + Error, + useOfNonShortCircuitingOperatorInDiffFunc, + "non-short-circuiting `?:` operator is not allowed in a differentiable function, use `select` " + "instead.") +DIAGNOSTIC( + 30056, + Warning, + useOfNonShortCircuitingOperator, + "non-short-circuiting `?:` operator is deprecated, use 'select' instead.") +DIAGNOSTIC( + 30057, + Error, + assignmentInPredicateExpr, + "use an assignment operation as predicate expression is not allowed, wrap the assignment with " + "'()' to clarify the intent.") DIAGNOSTIC(30058, Warning, danglingEqualityExpr, "result of '==' not used, did you intend '='?") DIAGNOSTIC(30060, Error, expectedAType, "expected a type, got a '$0'") DIAGNOSTIC(30061, Error, expectedANamespace, "expected a namespace, got a '$0'") -DIAGNOSTIC(30062, Note, implicitCastUsedAsLValueRef, "argument was implicitly cast from '$0' to '$1', and Slang does not support using an implicit cast as an l-value with a reference") -DIAGNOSTIC(30063, Note, implicitCastUsedAsLValueType, "argument was implicitly cast from '$0' to '$1', and Slang does not support using an implicit cast as an l-value with this type") -DIAGNOSTIC(30064, Note, implicitCastUsedAsLValue, "argument was implicitly cast from '$0' to '$1', and Slang does not support using an implicit cast as an l-value for this usage") - -DIAGNOSTIC(30065, Error, newCanOnlyBeUsedToInitializeAClass, "`new` can only be used to initialize a class") -DIAGNOSTIC(30066, Error, classCanOnlyBeInitializedWithNew, "a class can only be initialized by a `new` clause") - -DIAGNOSTIC(30067, Error, mutatingMethodOnFunctionInputParameterError, "mutating method '$0' called on `in` parameter '$1'; changes will not be visible to caller. copy the parameter into a local variable if this behavior is intended") -DIAGNOSTIC(30068, Warning, mutatingMethodOnFunctionInputParameterWarning, "mutating method '$0' called on `in` parameter '$1'; changes will not be visible to caller. copy the parameter into a local variable if this behavior is intended") - -DIAGNOSTIC(30070, Error, unsizedMemberMustAppearLast, "member with unknown size at compile time can only appear as the last member in a composite type.") +DIAGNOSTIC( + 30062, + Note, + implicitCastUsedAsLValueRef, + "argument was implicitly cast from '$0' to '$1', and Slang does not support using an implicit " + "cast as an l-value with a reference") +DIAGNOSTIC( + 30063, + Note, + implicitCastUsedAsLValueType, + "argument was implicitly cast from '$0' to '$1', and Slang does not support using an implicit " + "cast as an l-value with this type") +DIAGNOSTIC( + 30064, + Note, + implicitCastUsedAsLValue, + "argument was implicitly cast from '$0' to '$1', and Slang does not support using an implicit " + "cast as an l-value for this usage") + +DIAGNOSTIC( + 30065, + Error, + newCanOnlyBeUsedToInitializeAClass, + "`new` can only be used to initialize a class") +DIAGNOSTIC( + 30066, + Error, + classCanOnlyBeInitializedWithNew, + "a class can only be initialized by a `new` clause") + +DIAGNOSTIC( + 30067, + Error, + mutatingMethodOnFunctionInputParameterError, + "mutating method '$0' called on `in` parameter '$1'; changes will not be visible to caller. " + "copy the parameter into a local variable if this behavior is intended") +DIAGNOSTIC( + 30068, + Warning, + mutatingMethodOnFunctionInputParameterWarning, + "mutating method '$0' called on `in` parameter '$1'; changes will not be visible to caller. " + "copy the parameter into a local variable if this behavior is intended") + +DIAGNOSTIC( + 30070, + Error, + unsizedMemberMustAppearLast, + "member with unknown size at compile time can only appear as the last member in a composite " + "type.") DIAGNOSTIC(30071, Error, varCannotBeUnsized, "cannot instantiate a variable of unsized type.") DIAGNOSTIC(30072, Error, paramCannotBeUnsized, "function parameter cannot be unsized.") -DIAGNOSTIC(30075, Error, cannotSpecializeGeneric, "cannot specialize generic '$0' with the provided arguments.") - -DIAGNOSTIC(30100, Error, staticRefToNonStaticMember, "type '$0' cannot be used to refer to non-static member '$1'") -DIAGNOSTIC(30101, Error, cannotDereferenceType, "cannot dereference type '$0', do you mean to use '.'?") +DIAGNOSTIC( + 30075, + Error, + cannotSpecializeGeneric, + "cannot specialize generic '$0' with the provided arguments.") + +DIAGNOSTIC( + 30100, + Error, + staticRefToNonStaticMember, + "type '$0' cannot be used to refer to non-static member '$1'") +DIAGNOSTIC( + 30101, + Error, + cannotDereferenceType, + "cannot dereference type '$0', do you mean to use '.'?") DIAGNOSTIC(30200, Error, redeclaration, "declaration of '$0' conflicts with existing declaration") DIAGNOSTIC(30201, Error, functionRedefinition, "function '$0' already has a body") -DIAGNOSTIC(30202, Error, functionRedeclarationWithDifferentReturnType, "function '$0' declared to return '$1' was previously declared to return '$2'") - -DIAGNOSTIC(30300, Error, isOperatorValueMustBeInterfaceType, "'is'/'as' operator requires an interface-typed expression.") +DIAGNOSTIC( + 30202, + Error, + functionRedeclarationWithDifferentReturnType, + "function '$0' declared to return '$1' was previously declared to return '$2'") + +DIAGNOSTIC( + 30300, + Error, + isOperatorValueMustBeInterfaceType, + "'is'/'as' operator requires an interface-typed expression.") DIAGNOSTIC(33070, Error, expectedFunction, "expected a function, got '$0'") DIAGNOSTIC(33071, Error, expectedAStringLiteral, "expected a string literal") -DIAGNOSTIC(-1, Note, noteExplicitConversionPossible, "explicit conversion from '$0' to '$1' is possible") -DIAGNOSTIC(30080, Error, ambiguousConversion, "more than one implicit conversion exists from '$0' to '$1'") -DIAGNOSTIC(30081, Warning, unrecommendedImplicitConversion, "implicit conversion from '$0' to '$1' is not recommended") -DIAGNOSTIC(30082, Warning, implicitConversionToDouble, " implicit float-to-double conversion may cause unexpected performance issues, use explicit cast if intended.") -DIAGNOSTIC(30090, Error, tryClauseMustApplyToInvokeExpr, "expression in a 'try' clause must be a call to a function or operator overload.") -DIAGNOSTIC(30091, Error, tryInvokeCalleeShouldThrow, "'$0' called from a 'try' clause does not throw an error, make sure the callee is marked as 'throws'") +DIAGNOSTIC( + -1, + Note, + noteExplicitConversionPossible, + "explicit conversion from '$0' to '$1' is possible") +DIAGNOSTIC( + 30080, + Error, + ambiguousConversion, + "more than one implicit conversion exists from '$0' to '$1'") +DIAGNOSTIC( + 30081, + Warning, + unrecommendedImplicitConversion, + "implicit conversion from '$0' to '$1' is not recommended") +DIAGNOSTIC( + 30082, + Warning, + implicitConversionToDouble, + " implicit float-to-double conversion may cause unexpected performance issues, use explicit " + "cast if intended.") +DIAGNOSTIC( + 30090, + Error, + tryClauseMustApplyToInvokeExpr, + "expression in a 'try' clause must be a call to a function or operator overload.") +DIAGNOSTIC( + 30091, + Error, + tryInvokeCalleeShouldThrow, + "'$0' called from a 'try' clause does not throw an error, make sure the callee is marked as " + "'throws'") DIAGNOSTIC(30092, Error, calleeOfTryCallMustBeFunc, "callee in a 'try' clause must be a function") -DIAGNOSTIC(30093, Error, uncaughtTryCallInNonThrowFunc, "the current function or environment is not declared to throw any errors, but the 'try' clause is not caught") -DIAGNOSTIC(30094, Error, mustUseTryClauseToCallAThrowFunc, "the callee may throw an error, and therefore must be called within a 'try' clause") -DIAGNOSTIC(30095, Error, errorTypeOfCalleeIncompatibleWithCaller, "the error type `$1` of callee `$0` is not compatible with the caller's error type `$2`.") - -DIAGNOSTIC(30096, Error, differentialTypeShouldServeAsItsOwnDifferentialType, "cannot use type '$0' a `Differential` type. A differential type's differential must be itself. However, '$0.Differential' is '$1'.") -DIAGNOSTIC(30097, Error, functionNotMarkedAsDifferentiable, "function '$0' is not marked as $1-differentiable.") -DIAGNOSTIC(30098, Error, nonStaticMemberFunctionNotAllowedAsDiffOperand, "non-static function reference '$0' is not allowed here.") +DIAGNOSTIC( + 30093, + Error, + uncaughtTryCallInNonThrowFunc, + "the current function or environment is not declared to throw any errors, but the 'try' clause " + "is not caught") +DIAGNOSTIC( + 30094, + Error, + mustUseTryClauseToCallAThrowFunc, + "the callee may throw an error, and therefore must be called within a 'try' clause") +DIAGNOSTIC( + 30095, + Error, + errorTypeOfCalleeIncompatibleWithCaller, + "the error type `$1` of callee `$0` is not compatible with the caller's error type `$2`.") + +DIAGNOSTIC( + 30096, + Error, + differentialTypeShouldServeAsItsOwnDifferentialType, + "cannot use type '$0' a `Differential` type. A differential type's differential must be " + "itself. However, '$0.Differential' is '$1'.") +DIAGNOSTIC( + 30097, + Error, + functionNotMarkedAsDifferentiable, + "function '$0' is not marked as $1-differentiable.") +DIAGNOSTIC( + 30098, + Error, + nonStaticMemberFunctionNotAllowedAsDiffOperand, + "non-static function reference '$0' is not allowed here.") DIAGNOSTIC(30099, Error, sizeOfArgumentIsInvalid, "argument to sizeof is invalid") -DIAGNOSTIC(30083, Error, countOfArgumentIsInvalid, "argument to countof can only be a type pack or tuple") +DIAGNOSTIC( + 30083, + Error, + countOfArgumentIsInvalid, + "argument to countof can only be a type pack or tuple") DIAGNOSTIC(30101, Error, readingFromWriteOnly, "cannot read from writeonly, check modifiers.") -DIAGNOSTIC(30102, Error, differentiableMemberShouldHaveCorrespondingFieldInDiffType, "differentiable member '$0' should have a corresponding field in '$1'. Use [DerivativeMember($1.)] or mark as no_diff") +DIAGNOSTIC( + 30102, + Error, + differentiableMemberShouldHaveCorrespondingFieldInDiffType, + "differentiable member '$0' should have a corresponding field in '$1'. Use " + "[DerivativeMember($1.)] or mark as no_diff") DIAGNOSTIC(30103, Error, expectTypePackAfterEach, "expected a type pack or a tuple after 'each'.") -DIAGNOSTIC(30104, Error, eachExprMustBeInsideExpandExpr, "'each' expression must be inside 'expand' expression.") -DIAGNOSTIC(30105, Error, expandTermCapturesNoTypePacks, "'expand' term captures no type packs. At least one type pack must be referenced via an 'each' term inside an 'expand' term.") +DIAGNOSTIC( + 30104, + Error, + eachExprMustBeInsideExpandExpr, + "'each' expression must be inside 'expand' expression.") +DIAGNOSTIC( + 30105, + Error, + expandTermCapturesNoTypePacks, + "'expand' term captures no type packs. At least one type pack must be referenced via an 'each' " + "term inside an 'expand' term.") DIAGNOSTIC(30106, Error, improperUseOfType, "type '$0' cannot be used in this context.") DIAGNOSTIC(30107, Error, parameterPackMustBeConst, "a parameter pack must be declared as 'const'.") // Include -DIAGNOSTIC(30500, Error, includedFileMissingImplementing, "missing 'implementing' declaration in the included source file '$0'.") -DIAGNOSTIC(30501, Error, includedFileMissingImplementingDoYouMeanImport, "missing 'implementing' declaration in the included source file '$0'. The file declares that it defines module '$1', do you mean 'import' instead?") -DIAGNOSTIC(30502, Error, includedFileDoesNotImplementCurrentModule, "the included source file is expected to implement module '$0', but it is implementing '$1' instead.") -DIAGNOSTIC(30503, Error, primaryModuleFileCannotStartWithImplementingDecl, "a primary source file for a module cannot start with 'implementing'.") -DIAGNOSTIC(30504, Warning, primaryModuleFileMustStartWithModuleDecl, "a primary source file for a module should start with 'module'.") -DIAGNOSTIC(30505, Error, implementingMustReferencePrimaryModuleFile, "the source file referenced by 'implementing' must be a primary module file starting with a 'module' declaration.") +DIAGNOSTIC( + 30500, + Error, + includedFileMissingImplementing, + "missing 'implementing' declaration in the included source file '$0'.") +DIAGNOSTIC( + 30501, + Error, + includedFileMissingImplementingDoYouMeanImport, + "missing 'implementing' declaration in the included source file '$0'. The file declares that " + "it defines module '$1', do you mean 'import' instead?") +DIAGNOSTIC( + 30502, + Error, + includedFileDoesNotImplementCurrentModule, + "the included source file is expected to implement module '$0', but it is implementing '$1' " + "instead.") +DIAGNOSTIC( + 30503, + Error, + primaryModuleFileCannotStartWithImplementingDecl, + "a primary source file for a module cannot start with 'implementing'.") +DIAGNOSTIC( + 30504, + Warning, + primaryModuleFileMustStartWithModuleDecl, + "a primary source file for a module should start with 'module'.") +DIAGNOSTIC( + 30505, + Error, + implementingMustReferencePrimaryModuleFile, + "the source file referenced by 'implementing' must be a primary module file starting with a " + "'module' declaration.") // Visibilty DIAGNOSTIC(30600, Error, declIsNotVisible, "'$0' is not accessible from the current context.") -DIAGNOSTIC(30601, Error, declCannotHaveHigherVisibility, "'$0' cannot have a higher visibility than '$1'.") -DIAGNOSTIC(30602, Error, satisfyingDeclCannotHaveLowerVisibility, "'$0' is less visible than the interface requirement it satisfies.") -DIAGNOSTIC(30603, Error, invalidUseOfPrivateVisibility, "'$0' cannot have private visibility because it is not a member of a type.") +DIAGNOSTIC( + 30601, + Error, + declCannotHaveHigherVisibility, + "'$0' cannot have a higher visibility than '$1'.") +DIAGNOSTIC( + 30602, + Error, + satisfyingDeclCannotHaveLowerVisibility, + "'$0' is less visible than the interface requirement it satisfies.") +DIAGNOSTIC( + 30603, + Error, + invalidUseOfPrivateVisibility, + "'$0' cannot have private visibility because it is not a member of a type.") DIAGNOSTIC(30604, Error, useOfLessVisibleType, "'$0' references less visible type '$1'.") -DIAGNOSTIC(36005, Error, invalidVisibilityModifierOnTypeOfDecl, "visibility modifier is not allowed on '$0'.") +DIAGNOSTIC( + 36005, + Error, + invalidVisibilityModifierOnTypeOfDecl, + "visibility modifier is not allowed on '$0'.") // Capability -DIAGNOSTIC(36100, Error, conflictingCapabilityDueToUseOfDecl, "'$0' requires capability '$1' that is conflicting with the '$2's current capability requirement '$3'.") -DIAGNOSTIC(36101, Error, conflictingCapabilityDueToStatement, "statement requires capability '$0' that is conflicting with the '$1's current capability requirement '$2'.") -DIAGNOSTIC(36102, Error, conflictingCapabilityDueToStatementEnclosingFunc, "statement requires capability '$0' that is conflicting with the current function's capability requirement '$1'.") -DIAGNOSTIC(36103, Warning, missingCapabilityRequirementOnPublicDecl, "public symbol '$0' is missing capability requirement declaration, the symbol is assumed to require inferred capabilities '$1'.") +DIAGNOSTIC( + 36100, + Error, + conflictingCapabilityDueToUseOfDecl, + "'$0' requires capability '$1' that is conflicting with the '$2's current capability " + "requirement '$3'.") +DIAGNOSTIC( + 36101, + Error, + conflictingCapabilityDueToStatement, + "statement requires capability '$0' that is conflicting with the '$1's current capability " + "requirement '$2'.") +DIAGNOSTIC( + 36102, + Error, + conflictingCapabilityDueToStatementEnclosingFunc, + "statement requires capability '$0' that is conflicting with the current function's capability " + "requirement '$1'.") +DIAGNOSTIC( + 36103, + Warning, + missingCapabilityRequirementOnPublicDecl, + "public symbol '$0' is missing capability requirement declaration, the symbol is assumed to " + "require inferred capabilities '$1'.") DIAGNOSTIC(36104, Error, useOfUndeclaredCapability, "'$0' uses undeclared capability '$1'.") -DIAGNOSTIC(36104, Error, useOfUndeclaredCapabilityOfInterfaceRequirement, "'$0' uses capability '$1' that is missing from the interface requirement.") +DIAGNOSTIC( + 36104, + Error, + useOfUndeclaredCapabilityOfInterfaceRequirement, + "'$0' uses capability '$1' that is missing from the interface requirement.") DIAGNOSTIC(36105, Error, unknownCapability, "unknown capability name '$0'.") DIAGNOSTIC(36106, Error, expectCapability, "expect a capability name.") -DIAGNOSTIC(36107, Error, entryPointUsesUnavailableCapability, "entrypoint '$0' does not support compilation target '$1' with stage '$2'") -DIAGNOSTIC(36108, Error, declHasDependenciesNotCompatibleOnTarget, "'$0' has dependencies that are not compatible on the required target '$1'.") +DIAGNOSTIC( + 36107, + Error, + entryPointUsesUnavailableCapability, + "entrypoint '$0' does not support compilation target '$1' with stage '$2'") +DIAGNOSTIC( + 36108, + Error, + declHasDependenciesNotCompatibleOnTarget, + "'$0' has dependencies that are not compatible on the required target '$1'.") DIAGNOSTIC(36109, Error, invalidTargetSwitchCase, "'$0' cannot be used as a target_switch case.") -DIAGNOSTIC(36110, Error, stageIsIncompatibleWithCapabilityDefinition, "'$0' is defined for stage '$1', which is incompatible with the declared capability set '$2'.") +DIAGNOSTIC( + 36110, + Error, + stageIsIncompatibleWithCapabilityDefinition, + "'$0' is defined for stage '$1', which is incompatible with the declared capability set '$2'.") DIAGNOSTIC(36111, Error, unexpectedCapability, "'$0' resolves into a disallowed `$1` Capability.") -DIAGNOSTIC(36112, Warning, entryPointAndProfileAreIncompatible, "'$0' is defined for stage '$1', which is incompatible with the declared profile '$2'.") -DIAGNOSTIC(36113, Warning, usingInternalCapabilityName, "'$0' resolves into a '_Internal' '_$1' Capability, use '$1' instead.") -DIAGNOSTIC(36114, Warning, incompatibleWithPrecompileLib, "Precompiled library requires '$0', has `$1`, implicitly upgrading capabilities.") -DIAGNOSTIC(36115, Error, incompatibleWithPrecompileLibRestrictive, "Precompiled library requires '$0', has `$1`.") -DIAGNOSTIC(36116, Error, capabilityHasMultipleStages, "Capability '$0' is targeting stages '$1', only allowed to use 1 unique stage here.") +DIAGNOSTIC( + 36112, + Warning, + entryPointAndProfileAreIncompatible, + "'$0' is defined for stage '$1', which is incompatible with the declared profile '$2'.") +DIAGNOSTIC( + 36113, + Warning, + usingInternalCapabilityName, + "'$0' resolves into a '_Internal' '_$1' Capability, use '$1' instead.") +DIAGNOSTIC( + 36114, + Warning, + incompatibleWithPrecompileLib, + "Precompiled library requires '$0', has `$1`, implicitly upgrading capabilities.") +DIAGNOSTIC( + 36115, + Error, + incompatibleWithPrecompileLibRestrictive, + "Precompiled library requires '$0', has `$1`.") +DIAGNOSTIC( + 36116, + Error, + capabilityHasMultipleStages, + "Capability '$0' is targeting stages '$1', only allowed to use 1 unique stage here.") // Attributes DIAGNOSTIC(31000, Warning, unknownAttributeName, "unknown attribute '$0'") -DIAGNOSTIC(31001, Error, attributeArgumentCountMismatch, "attribute '$0' expects $1 arguments ($2 provided)") +DIAGNOSTIC( + 31001, + Error, + attributeArgumentCountMismatch, + "attribute '$0' expects $1 arguments ($2 provided)") DIAGNOSTIC(31002, Error, attributeNotApplicable, "attribute '$0' is not valid here") -DIAGNOSTIC(31003, Error, badlyDefinedPatchConstantFunc, "hull shader '$0' has has badly defined 'patchconstantfunc' attribute.") +DIAGNOSTIC( + 31003, + Error, + badlyDefinedPatchConstantFunc, + "hull shader '$0' has has badly defined 'patchconstantfunc' attribute.") DIAGNOSTIC(31004, Error, expectedSingleIntArg, "attribute '$0' expects a single int argument") DIAGNOSTIC(31005, Error, expectedSingleStringArg, "attribute '$0' expects a single string argument") -DIAGNOSTIC(31006, Error, attributeFunctionNotFound, "Could not find function '$0' for attribute'$1'") +DIAGNOSTIC( + 31006, + Error, + attributeFunctionNotFound, + "Could not find function '$0' for attribute'$1'") DIAGNOSTIC(31007, Error, attributeExpectedIntArg, "attribute '$0' expects argument $1 to be int") -DIAGNOSTIC(31008, Error, attributeExpectedStringArg, "attribute '$0' expects argument $1 to be string") +DIAGNOSTIC( + 31008, + Error, + attributeExpectedStringArg, + "attribute '$0' expects argument $1 to be string") DIAGNOSTIC(31100, Error, unknownStageName, "unknown stage name '$0'") DIAGNOSTIC(31101, Error, unknownImageFormatName, "unknown image format '$0'") DIAGNOSTIC(31101, Error, unknownDiagnosticName, "unknown diagnostic '$0'") -DIAGNOSTIC(31102, Error, nonPositiveNumThreads, "expected a positive integer in 'numthreads' attribute, got '$0'") -DIAGNOSTIC(31103, Error, invalidWaveSize, "expected a power of 2 between 4 and 128, inclusive, in 'WaveSize' attribute, got '$0'") -DIAGNOSTIC(31104, Warning, explicitUniformLocation, "Explicit binding of uniform locations is discouraged. Prefer 'ConstantBuffer<$0>' over 'uniform $0'") +DIAGNOSTIC( + 31102, + Error, + nonPositiveNumThreads, + "expected a positive integer in 'numthreads' attribute, got '$0'") +DIAGNOSTIC( + 31103, + Error, + invalidWaveSize, + "expected a power of 2 between 4 and 128, inclusive, in 'WaveSize' attribute, got '$0'") +DIAGNOSTIC( + 31104, + Warning, + explicitUniformLocation, + "Explicit binding of uniform locations is discouraged. Prefer 'ConstantBuffer<$0>' over " + "'uniform $0'") DIAGNOSTIC(31120, Error, invalidAttributeTarget, "invalid syntax target for user defined attribute") DIAGNOSTIC(31121, Error, anyValueSizeExceedsLimit, "'anyValueSize' cannot exceed $0") -DIAGNOSTIC(31122, Error, associatedTypeNotAllowInComInterface, "associatedtype not allowed in a [COM] interface") +DIAGNOSTIC( + 31122, + Error, + associatedTypeNotAllowInComInterface, + "associatedtype not allowed in a [COM] interface") DIAGNOSTIC(31123, Error, invalidGUID, "'$0' is not a valid GUID") -DIAGNOSTIC(31124, Error, structCannotImplementComInterface, "a struct type cannot implement a [COM] interface") -DIAGNOSTIC(31124, Error, interfaceInheritingComMustBeCom, "an interface type that inherits from a [COM] interface must itself be a [COM] interface") - -DIAGNOSTIC(31130, Error, derivativeMemberAttributeMustNameAMemberInExpectedDifferentialType, "[DerivativeMember] must reference to a member in the associated differential type '$0'.") -DIAGNOSTIC(31131, Error, invalidUseOfDerivativeMemberAttributeParentTypeIsNotDifferentiable, "invalid use of [DerivativeMember], parent type is not differentiable.") -DIAGNOSTIC(31132, Error, derivativeMemberAttributeCanOnlyBeUsedOnMembers, "[DerivativeMember] is allowed on members only.") - -DIAGNOSTIC(31140, Error, typeOfExternDeclMismatchesOriginalDefinition, "type of `extern` decl '$0' differs from its original definition. expected '$1'.") -DIAGNOSTIC(31141, Error, definitionOfExternDeclMismatchesOriginalDefinition, "`extern` decl '$0' is not consistent with its original definition.") -DIAGNOSTIC(31142, Error, ambiguousOriginalDefintionOfExternDecl, "`extern` decl '$0' has ambiguous original definitions.") -DIAGNOSTIC(31143, Error, missingOriginalDefintionOfExternDecl, "no original definition found for `extern` decl '$0'.") +DIAGNOSTIC( + 31124, + Error, + structCannotImplementComInterface, + "a struct type cannot implement a [COM] interface") +DIAGNOSTIC( + 31124, + Error, + interfaceInheritingComMustBeCom, + "an interface type that inherits from a [COM] interface must itself be a [COM] interface") + +DIAGNOSTIC( + 31130, + Error, + derivativeMemberAttributeMustNameAMemberInExpectedDifferentialType, + "[DerivativeMember] must reference to a member in the associated differential type '$0'.") +DIAGNOSTIC( + 31131, + Error, + invalidUseOfDerivativeMemberAttributeParentTypeIsNotDifferentiable, + "invalid use of [DerivativeMember], parent type is not differentiable.") +DIAGNOSTIC( + 31132, + Error, + derivativeMemberAttributeCanOnlyBeUsedOnMembers, + "[DerivativeMember] is allowed on members only.") + +DIAGNOSTIC( + 31140, + Error, + typeOfExternDeclMismatchesOriginalDefinition, + "type of `extern` decl '$0' differs from its original definition. expected '$1'.") +DIAGNOSTIC( + 31141, + Error, + definitionOfExternDeclMismatchesOriginalDefinition, + "`extern` decl '$0' is not consistent with its original definition.") +DIAGNOSTIC( + 31142, + Error, + ambiguousOriginalDefintionOfExternDecl, + "`extern` decl '$0' has ambiguous original definitions.") +DIAGNOSTIC( + 31143, + Error, + missingOriginalDefintionOfExternDecl, + "no original definition found for `extern` decl '$0'.") DIAGNOSTIC(31145, Error, invalidCustomDerivative, "invalid custom derivative attribute.") DIAGNOSTIC(31146, Error, declAlreadyHasAttribute, "'$0' already has attribute '[$1]'.") -DIAGNOSTIC(31147, Error, cannotResolveOriginalFunctionForDerivative, "cannot resolve the original function for the the custom derivative.") -DIAGNOSTIC(31148, Error, cannotResolveDerivativeFunction, "cannot resolve the custom derivative function") -DIAGNOSTIC(31149, Error, customDerivativeSignatureMismatchAtPosition, "invalid custom derivative. parameter type mismatch at position $0. expected '$1', got '$2'") -DIAGNOSTIC(31150, Error, customDerivativeSignatureMismatch, "invalid custom derivative. could not resolve function with expected signature '$0'") -DIAGNOSTIC(31151, Error, cannotResolveGenericArgumentForDerivativeFunction, - "The generic arguments to the derivative function cannot be deduced from the parameter list of the original function. " - "Consider using [ForwardDerivative], [BackwardDerivative] or [PrimalSubstitute] attributes on the primal function" - " with explicit generic arguments to associate it with a generic derivative function. Note that [ForwardDerivativeOf], " - "[BackwardDerivativeOf], and [PrimalSubstituteOf] attributes are not supported when the generic arguments to the derivatives cannot be automatically deduced.") -DIAGNOSTIC(31152, Error, cannotAssociateInterfaceRequirementWithDerivative, "cannot associate an interface requirement with a derivative.") -DIAGNOSTIC(31153, Error, cannotUseInterfaceRequirementAsDerivative, "cannot use an interface requirement as a derivative.") -DIAGNOSTIC(31154, Error, customDerivativeSignatureThisParamMismatch, "custom derivative does not match expected signature on `this`. Both original and derivative function must have the same `this` type.") -DIAGNOSTIC(31155, Error, customDerivativeNotAllowedForMemberFunctionsOfDifferentiableType, "custom derivative is not allowed for non-static member functions of a differentiable type.") -DIAGNOSTIC(31156, Error, customDerivativeExpectedStatic, "expected a static definition for the custom derivative.") -DIAGNOSTIC(31157, Error, overloadedFuncUsedWithDerivativeOfAttributes, "cannot resolve overloaded functions for derivative-of attributes.") +DIAGNOSTIC( + 31147, + Error, + cannotResolveOriginalFunctionForDerivative, + "cannot resolve the original function for the the custom derivative.") +DIAGNOSTIC( + 31148, + Error, + cannotResolveDerivativeFunction, + "cannot resolve the custom derivative function") +DIAGNOSTIC( + 31149, + Error, + customDerivativeSignatureMismatchAtPosition, + "invalid custom derivative. parameter type mismatch at position $0. expected '$1', got '$2'") +DIAGNOSTIC( + 31150, + Error, + customDerivativeSignatureMismatch, + "invalid custom derivative. could not resolve function with expected signature '$0'") +DIAGNOSTIC( + 31151, + Error, + cannotResolveGenericArgumentForDerivativeFunction, + "The generic arguments to the derivative function cannot be deduced from the parameter list of " + "the original function. " + "Consider using [ForwardDerivative], [BackwardDerivative] or [PrimalSubstitute] attributes on " + "the primal function" + " with explicit generic arguments to associate it with a generic derivative function. Note " + "that [ForwardDerivativeOf], " + "[BackwardDerivativeOf], and [PrimalSubstituteOf] attributes are not supported when the " + "generic arguments to the derivatives cannot be automatically deduced.") +DIAGNOSTIC( + 31152, + Error, + cannotAssociateInterfaceRequirementWithDerivative, + "cannot associate an interface requirement with a derivative.") +DIAGNOSTIC( + 31153, + Error, + cannotUseInterfaceRequirementAsDerivative, + "cannot use an interface requirement as a derivative.") +DIAGNOSTIC( + 31154, + Error, + customDerivativeSignatureThisParamMismatch, + "custom derivative does not match expected signature on `this`. Both original and derivative " + "function must have the same `this` type.") +DIAGNOSTIC( + 31155, + Error, + customDerivativeNotAllowedForMemberFunctionsOfDifferentiableType, + "custom derivative is not allowed for non-static member functions of a differentiable type.") +DIAGNOSTIC( + 31156, + Error, + customDerivativeExpectedStatic, + "expected a static definition for the custom derivative.") +DIAGNOSTIC( + 31157, + Error, + overloadedFuncUsedWithDerivativeOfAttributes, + "cannot resolve overloaded functions for derivative-of attributes.") DIAGNOSTIC(31200, Warning, deprecatedUsage, "$0 has been deprecated: $1") DIAGNOSTIC(31201, Error, modifierNotAllowed, "modifier '$0' is not allowed here.") -DIAGNOSTIC(31202, Error, duplicateModifier, "modifier '$0' is redundant or conflicting with existing modifier '$1'") +DIAGNOSTIC( + 31202, + Error, + duplicateModifier, + "modifier '$0' is redundant or conflicting with existing modifier '$1'") DIAGNOSTIC(31203, Error, cannotExportIncompleteType, "cannot export incomplete type '$0'") -DIAGNOSTIC(31204, Error, incompleteTypeCannotBeUsedInBuffer, "incomplete type '$0' cannot be used in a buffer") -DIAGNOSTIC(31205, Error, incompleteTypeCannotBeUsedInUniformParameter, "incomplete type '$0' cannot be used in a uniform parameter") -DIAGNOSTIC(31206, Error, memoryQualifierNotAllowedOnANonImageTypeParameter, "modifier $0 is not allowed on a non image type parameter.") -DIAGNOSTIC(31208, Error, requireInputDecoratedVarForParameter, "$0 expects for argument $1 a type which is a shader input (`in`) variable.") -DIAGNOSTIC(31210, Error, derivativeGroupQuadMustBeMultiple2ForXYThreads, "compute derivative group quad requires thread dispatch count of X and Y to each be at a multiple of 2") -DIAGNOSTIC(31211, Error, derivativeGroupLinearMustBeMultiple4ForTotalThreadCount, "compute derivative group linear requires total thread dispatch count to be at a multiple of 4") -DIAGNOSTIC(31212, Error, onlyOneOfDerivativeGroupLinearOrQuadCanBeSet, "cannot set compute derivative group linear and compute derivative group quad at the same time") -DIAGNOSTIC(31213, Error, cudaKernelMustReturnVoid, "return type of a CUDA kernel function cannot be non-void.") -DIAGNOSTIC(31214, Error, differentiableKernelEntryPointCannotHaveDifferentiableParams, "differentiable kernel entry point cannot have differentiable parameters. Consider using DiffTensorView to pass differentiable data, or marking this parameter with 'no_diff'") -DIAGNOSTIC(31215, Error, cannotUseUnsizedTypeInConstantBuffer, "cannot use unsized type '$0' in a constant buffer.") +DIAGNOSTIC( + 31204, + Error, + incompleteTypeCannotBeUsedInBuffer, + "incomplete type '$0' cannot be used in a buffer") +DIAGNOSTIC( + 31205, + Error, + incompleteTypeCannotBeUsedInUniformParameter, + "incomplete type '$0' cannot be used in a uniform parameter") +DIAGNOSTIC( + 31206, + Error, + memoryQualifierNotAllowedOnANonImageTypeParameter, + "modifier $0 is not allowed on a non image type parameter.") +DIAGNOSTIC( + 31208, + Error, + requireInputDecoratedVarForParameter, + "$0 expects for argument $1 a type which is a shader input (`in`) variable.") +DIAGNOSTIC( + 31210, + Error, + derivativeGroupQuadMustBeMultiple2ForXYThreads, + "compute derivative group quad requires thread dispatch count of X and Y to each be at a " + "multiple of 2") +DIAGNOSTIC( + 31211, + Error, + derivativeGroupLinearMustBeMultiple4ForTotalThreadCount, + "compute derivative group linear requires total thread dispatch count to be at a multiple of 4") +DIAGNOSTIC( + 31212, + Error, + onlyOneOfDerivativeGroupLinearOrQuadCanBeSet, + "cannot set compute derivative group linear and compute derivative group quad at the same time") +DIAGNOSTIC( + 31213, + Error, + cudaKernelMustReturnVoid, + "return type of a CUDA kernel function cannot be non-void.") +DIAGNOSTIC( + 31214, + Error, + differentiableKernelEntryPointCannotHaveDifferentiableParams, + "differentiable kernel entry point cannot have differentiable parameters. Consider using " + "DiffTensorView to pass differentiable data, or marking this parameter with 'no_diff'") +DIAGNOSTIC( + 31215, + Error, + cannotUseUnsizedTypeInConstantBuffer, + "cannot use unsized type '$0' in a constant buffer.") // Enums -DIAGNOSTIC(32000, Error, invalidEnumTagType, "invalid tag type for 'enum': '$0'") -DIAGNOSTIC(32003, Error, unexpectedEnumTagExpr, "unexpected form for 'enum' tag value expression") +DIAGNOSTIC(32000, Error, invalidEnumTagType, "invalid tag type for 'enum': '$0'") +DIAGNOSTIC(32003, Error, unexpectedEnumTagExpr, "unexpected form for 'enum' tag value expression") // 303xx: interfaces and associated types -DIAGNOSTIC(30300, Error, assocTypeInInterfaceOnly, "'associatedtype' can only be defined in an 'interface'.") -DIAGNOSTIC(30301, Error, globalGenParamInGlobalScopeOnly, "'type_param' can only be defined global scope.") -DIAGNOSTIC(30302, Error, staticConstRequirementMustBeIntOrBool, "'static const' requirement can only have int or bool type.") -DIAGNOSTIC(30303, Error, valueRequirementMustBeCompileTimeConst, "requirement in the form of a simple value must be declared as 'static const'.") +DIAGNOSTIC( + 30300, + Error, + assocTypeInInterfaceOnly, + "'associatedtype' can only be defined in an 'interface'.") +DIAGNOSTIC( + 30301, + Error, + globalGenParamInGlobalScopeOnly, + "'type_param' can only be defined global scope.") +DIAGNOSTIC( + 30302, + Error, + staticConstRequirementMustBeIntOrBool, + "'static const' requirement can only have int or bool type.") +DIAGNOSTIC( + 30303, + Error, + valueRequirementMustBeCompileTimeConst, + "requirement in the form of a simple value must be declared as 'static const'.") DIAGNOSTIC(30310, Error, typeIsNotDifferentiable, "type '$0' is not differentiable.") // Interop -DIAGNOSTIC(30400, Error, cannotDefinePtrTypeToManagedResource, "pointer to a managed resource is invalid, use `NativeRef` instead") +DIAGNOSTIC( + 30400, + Error, + cannotDefinePtrTypeToManagedResource, + "pointer to a managed resource is invalid, use `NativeRef` instead") // Control flow -DIAGNOSTIC(30500, Warning, forLoopSideEffectChangingDifferentVar, "the for loop initializes and checks variable '$0' but the side effect expression is modifying '$1'.") -DIAGNOSTIC(30501, Warning, forLoopPredicateCheckingDifferentVar, "the for loop initializes and modifies variable '$0' but the predicate expression is checking '$1'.") -DIAGNOSTIC(30502, Warning, forLoopChangingIterationVariableInOppsoiteDirection, "the for loop is modifiying variable '$0' in the opposite direction from loop exit condition.") -DIAGNOSTIC(30503, Warning, forLoopNotModifyingIterationVariable, "the for loop is not modifiying variable '$0' because the step size evaluates to 0.") -DIAGNOSTIC(30504, Warning, forLoopTerminatesInFewerIterationsThanMaxIters, "the for loop is statically determined to terminate within $0 iterations, which is less than what [MaxIters] specifies.") -DIAGNOSTIC(30505, Warning, loopRunsForZeroIterations, "the loop runs for 0 iterations and will be removed.") -DIAGNOSTIC(30510, Error, loopInDiffFuncRequireUnrollOrMaxIters, "loops inside a differentiable function need to provide either '[MaxIters(n)]' or '[ForceUnroll]' attribute.") +DIAGNOSTIC( + 30500, + Warning, + forLoopSideEffectChangingDifferentVar, + "the for loop initializes and checks variable '$0' but the side effect expression is modifying " + "'$1'.") +DIAGNOSTIC( + 30501, + Warning, + forLoopPredicateCheckingDifferentVar, + "the for loop initializes and modifies variable '$0' but the predicate expression is checking " + "'$1'.") +DIAGNOSTIC( + 30502, + Warning, + forLoopChangingIterationVariableInOppsoiteDirection, + "the for loop is modifiying variable '$0' in the opposite direction from loop exit condition.") +DIAGNOSTIC( + 30503, + Warning, + forLoopNotModifyingIterationVariable, + "the for loop is not modifiying variable '$0' because the step size evaluates to 0.") +DIAGNOSTIC( + 30504, + Warning, + forLoopTerminatesInFewerIterationsThanMaxIters, + "the for loop is statically determined to terminate within $0 iterations, which is less than " + "what [MaxIters] specifies.") +DIAGNOSTIC( + 30505, + Warning, + loopRunsForZeroIterations, + "the loop runs for 0 iterations and will be removed.") +DIAGNOSTIC( + 30510, + Error, + loopInDiffFuncRequireUnrollOrMaxIters, + "loops inside a differentiable function need to provide either '[MaxIters(n)]' or " + "'[ForceUnroll]' attribute.") // Switch -DIAGNOSTIC(30600, Error, switchMultipleDefault, "multiple 'default' cases not allowed within a 'switch' statement") -DIAGNOSTIC(30601, Error, switchDuplicateCases, "duplicate cases not allowed within a 'switch' statement") +DIAGNOSTIC( + 30600, + Error, + switchMultipleDefault, + "multiple 'default' cases not allowed within a 'switch' statement") +DIAGNOSTIC( + 30601, + Error, + switchDuplicateCases, + "duplicate cases not allowed within a 'switch' statement") // TODO: need to assign numbers to all these extra diagnostics... DIAGNOSTIC(39999, Fatal, cyclicReference, "cyclic reference '$0'.") -DIAGNOSTIC(39999, Error, cyclicReferenceInInheritance, "cyclic reference in inheritance graph '$0'.") - -DIAGNOSTIC(39999, Error, localVariableUsedBeforeDeclared, "local variable '$0' is being used before its declaration.") -DIAGNOSTIC(39999, Error, variableUsedInItsOwnDefinition, "the initial-value expression for variable '$0' depends on the value of the variable itself") -DIAGNOSTIC(39901, Fatal , cannotProcessInclude, "internal compiler error: cannot process '__include' in the current semantic checking context.") +DIAGNOSTIC( + 39999, + Error, + cyclicReferenceInInheritance, + "cyclic reference in inheritance graph '$0'.") + +DIAGNOSTIC( + 39999, + Error, + localVariableUsedBeforeDeclared, + "local variable '$0' is being used before its declaration.") +DIAGNOSTIC( + 39999, + Error, + variableUsedInItsOwnDefinition, + "the initial-value expression for variable '$0' depends on the value of the variable itself") +DIAGNOSTIC( + 39901, + Fatal, + cannotProcessInclude, + "internal compiler error: cannot process '__include' in the current semantic checking context.") // 304xx: generics DIAGNOSTIC(30400, Error, genericTypeNeedsArgs, "generic type '$0' used without argument") DIAGNOSTIC(30401, Error, invalidTypeForConstraint, "type '$0' cannot be used as a constraint.") -DIAGNOSTIC(30402, Error, invalidConstraintSubType, "type '$0' is not a valid left hand side of a type constraint.") +DIAGNOSTIC( + 30402, + Error, + invalidConstraintSubType, + "type '$0' is not a valid left hand side of a type constraint.") // 305xx: initializer lists DIAGNOSTIC(30500, Error, tooManyInitializers, "too many initializers (expected $0, got $1)") -DIAGNOSTIC(30501, Error, cannotUseInitializerListForArrayOfUnknownSize, "cannot use initializer list for array of statically unknown size '$0'") -DIAGNOSTIC(30502, Error, cannotUseInitializerListForVectorOfUnknownSize, "cannot use initializer list for vector of statically unknown size '$0'") -DIAGNOSTIC(30503, Error, cannotUseInitializerListForMatrixOfUnknownSize, "cannot use initializer list for matrix of statically unknown size '$0' rows") -DIAGNOSTIC(30504, Error, cannotUseInitializerListForType, "cannot use initializer list for type '$0'") +DIAGNOSTIC( + 30501, + Error, + cannotUseInitializerListForArrayOfUnknownSize, + "cannot use initializer list for array of statically unknown size '$0'") +DIAGNOSTIC( + 30502, + Error, + cannotUseInitializerListForVectorOfUnknownSize, + "cannot use initializer list for vector of statically unknown size '$0'") +DIAGNOSTIC( + 30503, + Error, + cannotUseInitializerListForMatrixOfUnknownSize, + "cannot use initializer list for matrix of statically unknown size '$0' rows") +DIAGNOSTIC( + 30504, + Error, + cannotUseInitializerListForType, + "cannot use initializer list for type '$0'") // 3062x: variables -DIAGNOSTIC(30620, Error, varWithoutTypeMustHaveInitializer, "a variable declaration without an initial-value expression must be given an explicit type") -DIAGNOSTIC(30622, Error, ambiguousDefaultInitializerForType, "more than one default initializer was found for type '$0'") +DIAGNOSTIC( + 30620, + Error, + varWithoutTypeMustHaveInitializer, + "a variable declaration without an initial-value expression must be given an explicit type") +DIAGNOSTIC( + 30622, + Error, + ambiguousDefaultInitializerForType, + "more than one default initializer was found for type '$0'") DIAGNOSTIC(30623, Error, cannotHaveInitializer, "'$0' cannot have an initializer because it is $1") // 307xx: parameters -DIAGNOSTIC(30700, Error, outputParameterCannotHaveDefaultValue, "an 'out' or 'inout' parameter cannot have a default-value expression") +DIAGNOSTIC( + 30700, + Error, + outputParameterCannotHaveDefaultValue, + "an 'out' or 'inout' parameter cannot have a default-value expression") // 308xx: inheritance -DIAGNOSTIC(30810, Error, baseOfInterfaceMustBeInterface, "interface '$0' cannot inherit from non-interface type '$1'") -DIAGNOSTIC(30811, Error, baseOfStructMustBeStructOrInterface, "struct '$0' cannot inherit from type '$1' that is neither a struct nor an interface") -DIAGNOSTIC(30812, Error, baseOfEnumMustBeIntegerOrInterface, "enum '$0' cannot inherit from type '$1' that is neither an interface not a builtin integer type") -DIAGNOSTIC(30813, Error, baseOfExtensionMustBeInterface, "extension cannot inherit from non-interface type '$1'") -DIAGNOSTIC(30814, Error, baseOfClassMustBeClassOrInterface, "class '$0' cannot inherit from type '$1' that is neither a class nor an interface") +DIAGNOSTIC( + 30810, + Error, + baseOfInterfaceMustBeInterface, + "interface '$0' cannot inherit from non-interface type '$1'") +DIAGNOSTIC( + 30811, + Error, + baseOfStructMustBeStructOrInterface, + "struct '$0' cannot inherit from type '$1' that is neither a struct nor an interface") +DIAGNOSTIC( + 30812, + Error, + baseOfEnumMustBeIntegerOrInterface, + "enum '$0' cannot inherit from type '$1' that is neither an interface not a builtin integer " + "type") +DIAGNOSTIC( + 30813, + Error, + baseOfExtensionMustBeInterface, + "extension cannot inherit from non-interface type '$1'") +DIAGNOSTIC( + 30814, + Error, + baseOfClassMustBeClassOrInterface, + "class '$0' cannot inherit from type '$1' that is neither a class nor an interface") DIAGNOSTIC(30815, Error, circularityInExtension, "circular extension is not allowed.") -DIAGNOSTIC(30820, Error, baseStructMustBeListedFirst, "a struct type may only inherit from one other struct type, and that type must appear first in the list of bases") -DIAGNOSTIC(30821, Error, tagTypeMustBeListedFirst, "an unum type may only have a single tag type, and that type must be listed first in the list of bases") -DIAGNOSTIC(30822, Error, baseClassMustBeListedFirst, "a class type may only inherit from one other class type, and that type must appear first in the list of bases") - -DIAGNOSTIC(30830, Error, cannotInheritFromExplicitlySealedDeclarationInAnotherModule, "cannot inherit from type '$0' marked 'sealed' in module '$1'") -DIAGNOSTIC(30831, Error, cannotInheritFromImplicitlySealedDeclarationInAnotherModule, "cannot inherit from type '$0' in module '$1' because it is implicitly 'sealed'; mark the base type 'open' to allow inheritance across modules") +DIAGNOSTIC( + 30820, + Error, + baseStructMustBeListedFirst, + "a struct type may only inherit from one other struct type, and that type must appear first in " + "the list of bases") +DIAGNOSTIC( + 30821, + Error, + tagTypeMustBeListedFirst, + "an unum type may only have a single tag type, and that type must be listed first in the list " + "of bases") +DIAGNOSTIC( + 30822, + Error, + baseClassMustBeListedFirst, + "a class type may only inherit from one other class type, and that type must appear first in " + "the list of bases") + +DIAGNOSTIC( + 30830, + Error, + cannotInheritFromExplicitlySealedDeclarationInAnotherModule, + "cannot inherit from type '$0' marked 'sealed' in module '$1'") +DIAGNOSTIC( + 30831, + Error, + cannotInheritFromImplicitlySealedDeclarationInAnotherModule, + "cannot inherit from type '$0' in module '$1' because it is implicitly 'sealed'; mark the base " + "type 'open' to allow inheritance across modules") DIAGNOSTIC(30832, Error, invalidTypeForInheritance, "type '$0' cannot be used for inheritance") -DIAGNOSTIC(30850, Error, invalidExtensionOnType, "type '$0' cannot be extended. `extension` can only be used to extend a nominal type.") +DIAGNOSTIC( + 30850, + Error, + invalidExtensionOnType, + "type '$0' cannot be extended. `extension` can only be used to extend a nominal type.") DIAGNOSTIC(30851, Error, invalidMemberTypeInExtension, "$0 cannot be a part of an `extension`") -DIAGNOSTIC(30852, Error, invalidExtensionOnInterface, "cannot extend interface type '$0'. consider using a generic extension: `extension T {...}`.") +DIAGNOSTIC( + 30852, + Error, + invalidExtensionOnInterface, + "cannot extend interface type '$0'. consider using a generic extension: `extension T " + "{...}`.") // 309xx: subscripts -DIAGNOSTIC(30900, Error, multiDimensionalArrayNotSupported, "multi-dimensional array is not supported.") +DIAGNOSTIC( + 30900, + Error, + multiDimensionalArrayNotSupported, + "multi-dimensional array is not supported.") // 310xx: properties // 311xx: accessors -DIAGNOSTIC(31100, Error, accessorMustBeInsideSubscriptOrProperty, "an accessor declaration is only allowed inside a subscript or property declaration") - -DIAGNOSTIC(31101, Error, nonSetAccessorMustNotHaveParams, "accessors other than 'set' must not have parameters") -DIAGNOSTIC(31102, Error, setAccessorMayNotHaveMoreThanOneParam, "a 'set' accessor may not have more than one parameter") -DIAGNOSTIC(31102, Error, setAccessorParamWrongType, "'set' parameter '$0' has type '$1' which does not match the expected type '$2'") +DIAGNOSTIC( + 31100, + Error, + accessorMustBeInsideSubscriptOrProperty, + "an accessor declaration is only allowed inside a subscript or property declaration") + +DIAGNOSTIC( + 31101, + Error, + nonSetAccessorMustNotHaveParams, + "accessors other than 'set' must not have parameters") +DIAGNOSTIC( + 31102, + Error, + setAccessorMayNotHaveMoreThanOneParam, + "a 'set' accessor may not have more than one parameter") +DIAGNOSTIC( + 31102, + Error, + setAccessorParamWrongType, + "'set' parameter '$0' has type '$1' which does not match the expected type '$2'") // 313xx: bit fields -DIAGNOSTIC(31300, Error, bitFieldTooWide, "bit-field size ($0) exceeds the width of its type $1 ($2)") +DIAGNOSTIC( + 31300, + Error, + bitFieldTooWide, + "bit-field size ($0) exceeds the width of its type $1 ($2)") DIAGNOSTIC(31301, Error, bitFieldNonIntegral, "bit-field type ($0) must be an integral type") // 39999 waiting to be placed in the right range -DIAGNOSTIC(39999, Error, expectedIntegerConstantWrongType, "expected integer constant (found: '$0')") -DIAGNOSTIC(39999, Error, expectedIntegerConstantNotConstant, "expression does not evaluate to a compile-time constant") -DIAGNOSTIC(39999, Error, expectedIntegerConstantNotLiteral, "could not extract value from integer constant") - -DIAGNOSTIC(39999, Error, expectedRayTracingPayloadObjectAtLocationButMissing, "raytracing payload expected at location $0 but it is missing") - -DIAGNOSTIC(39999, Error, noApplicableOverloadForNameWithArgs, "no overload for '$0' applicable to arguments of type $1") +DIAGNOSTIC( + 39999, + Error, + expectedIntegerConstantWrongType, + "expected integer constant (found: '$0')") +DIAGNOSTIC( + 39999, + Error, + expectedIntegerConstantNotConstant, + "expression does not evaluate to a compile-time constant") +DIAGNOSTIC( + 39999, + Error, + expectedIntegerConstantNotLiteral, + "could not extract value from integer constant") + +DIAGNOSTIC( + 39999, + Error, + expectedRayTracingPayloadObjectAtLocationButMissing, + "raytracing payload expected at location $0 but it is missing") + +DIAGNOSTIC( + 39999, + Error, + noApplicableOverloadForNameWithArgs, + "no overload for '$0' applicable to arguments of type $1") DIAGNOSTIC(39999, Error, noApplicableWithArgs, "no overload applicable to arguments of type $0") -DIAGNOSTIC(39999, Error, ambiguousOverloadForNameWithArgs, "ambiguous call to '$0' with arguments of type $1") -DIAGNOSTIC(39999, Error, ambiguousOverloadWithArgs, "ambiguous call to overloaded operation with arguments of type $0") +DIAGNOSTIC( + 39999, + Error, + ambiguousOverloadForNameWithArgs, + "ambiguous call to '$0' with arguments of type $1") +DIAGNOSTIC( + 39999, + Error, + ambiguousOverloadWithArgs, + "ambiguous call to overloaded operation with arguments of type $0") DIAGNOSTIC(39999, Note, overloadCandidate, "candidate: $0") DIAGNOSTIC(39999, Note, invisibleOverloadCandidate, "candidate (invisible): $0") @@ -597,98 +1535,295 @@ DIAGNOSTIC(39999, Note, invisibleOverloadCandidate, "candidate (invisible): $0") DIAGNOSTIC(39999, Note, moreOverloadCandidates, "$0 more overload candidates") DIAGNOSTIC(39999, Error, caseOutsideSwitch, "'case' not allowed outside of a 'switch' statement") -DIAGNOSTIC(39999, Error, defaultOutsideSwitch, "'default' not allowed outside of a 'switch' statement") +DIAGNOSTIC( + 39999, + Error, + defaultOutsideSwitch, + "'default' not allowed outside of a 'switch' statement") DIAGNOSTIC(39999, Error, expectedAGeneric, "expected a generic when using '<...>' (found: '$0')") -DIAGNOSTIC(39999, Error, genericArgumentInferenceFailed, "could not specialize generic for arguments of type $0") +DIAGNOSTIC( + 39999, + Error, + genericArgumentInferenceFailed, + "could not specialize generic for arguments of type $0") DIAGNOSTIC(39999, Error, ambiguousReference, "ambiguous reference to '$0'") DIAGNOSTIC(39999, Error, ambiguousExpression, "ambiguous reference") DIAGNOSTIC(39999, Error, declarationDidntDeclareAnything, "declaration does not declare anything") -DIAGNOSTIC(39999, Error, expectedPrefixOperator, "function called as prefix operator was not declared `__prefix`") -DIAGNOSTIC(39999, Error, expectedPostfixOperator, "function called as postfix operator was not declared `__postfix`") +DIAGNOSTIC( + 39999, + Error, + expectedPrefixOperator, + "function called as prefix operator was not declared `__prefix`") +DIAGNOSTIC( + 39999, + Error, + expectedPostfixOperator, + "function called as postfix operator was not declared `__postfix`") DIAGNOSTIC(39999, Error, notEnoughArguments, "not enough arguments to call (got $0, expected $1)") DIAGNOSTIC(39999, Error, tooManyArguments, "too many arguments to call (got $0, expected $1)") DIAGNOSTIC(39999, Error, invalidIntegerLiteralSuffix, "invalid suffix '$0' on integer literal") -DIAGNOSTIC(39999, Error, invalidFloatingPointLiteralSuffix, "invalid suffix '$0' on floating-point literal") - -DIAGNOSTIC(39999, Warning, integerLiteralTruncated, "integer literal '$0' too large for type '$1' truncated to '$2'") -DIAGNOSTIC(39999, Warning, floatLiteralUnrepresentable, "$0 literal '$1' unrepresentable, converted to '$2'") -DIAGNOSTIC(39999, Warning, floatLiteralTooSmall, "'$1' is smaller than the smallest representable value for type $0, converted to '$2'") - -DIAGNOSTIC(39999, Error, unableToFindSymbolInModule, "unable to find the mangled symbol '$0' in module '$1'") - -DIAGNOSTIC(39999, Error, overloadedParameterToHigherOrderFunction, "passing overloaded functions to higher order functions is not supported") +DIAGNOSTIC( + 39999, + Error, + invalidFloatingPointLiteralSuffix, + "invalid suffix '$0' on floating-point literal") + +DIAGNOSTIC( + 39999, + Warning, + integerLiteralTruncated, + "integer literal '$0' too large for type '$1' truncated to '$2'") +DIAGNOSTIC( + 39999, + Warning, + floatLiteralUnrepresentable, + "$0 literal '$1' unrepresentable, converted to '$2'") +DIAGNOSTIC( + 39999, + Warning, + floatLiteralTooSmall, + "'$1' is smaller than the smallest representable value for type $0, converted to '$2'") + +DIAGNOSTIC( + 39999, + Error, + unableToFindSymbolInModule, + "unable to find the mangled symbol '$0' in module '$1'") + +DIAGNOSTIC( + 39999, + Error, + overloadedParameterToHigherOrderFunction, + "passing overloaded functions to higher order functions is not supported") // 38xxx -DIAGNOSTIC(38000, Error, entryPointFunctionNotFound, "no function found matching entry point name '$0'") -DIAGNOSTIC(38001, Error, ambiguousEntryPoint, "more than one function matches entry point name '$0'") -DIAGNOSTIC(38003, Error, entryPointSymbolNotAFunction, "entry point '$0' must be declared as a function") - -DIAGNOSTIC(38004, Error, entryPointTypeParameterNotFound, "no type found matching entry-point type parameter name '$0'") -DIAGNOSTIC(38005, Error, expectedTypeForSpecializationArg, "expected a type as argument for specialization parameter '$0'") - -DIAGNOSTIC(38006, Warning, specifiedStageDoesntMatchAttribute, "entry point '$0' being compiled for the '$1' stage has a '[shader(...)]' attribute that specifies the '$2' stage") -DIAGNOSTIC(38007, Error, entryPointHasNoStage, "no stage specified for entry point '$0'; use either a '[shader(\"name\")]' function attribute or the '-stage ' command-line option to specify a stage") - -DIAGNOSTIC(38008, Error, specializationParameterOfNameNotSpecialized, "no specialization argument was provided for specialization parameter '$0'") -DIAGNOSTIC(38008, Error, specializationParameterNotSpecialized, "no specialization argument was provided for specialization parameter") - -DIAGNOSTIC(38009, Error, expectedValueOfTypeForSpecializationArg, "expected a constant value of type '$0' as argument for specialization parameter '$1'") - -DIAGNOSTIC(38100, Error, typeDoesntImplementInterfaceRequirement, "type '$0' does not provide required interface member '$1'") -DIAGNOSTIC(38105, Error, memberDoesNotMatchRequirementSignature, "member '$0' does not match interface requirement.") -DIAGNOSTIC(38101, Error, thisExpressionOutsideOfTypeDecl, "'this' expression can only be used in members of an aggregate type") -DIAGNOSTIC(38102, Error, initializerNotInsideType, "an 'init' declaration is only allowed inside a type or 'extension' declaration") -DIAGNOSTIC(38103, Error, thisTypeOutsideOfTypeDecl, "'This' type can only be used inside of an aggregate type") -DIAGNOSTIC(38104, Error, returnValNotAvailable, "cannot use '__return_val' here. '__return_val' is defined only in functions that return a non-copyable value.") -DIAGNOSTIC(38020, Error, mismatchEntryPointTypeArgument, "expecting $0 entry-point type arguments, provided $1.") -DIAGNOSTIC(38021, Error, typeArgumentForGenericParameterDoesNotConformToInterface, "type argument `$0` for generic parameter `$1` does not conform to interface `$2`.") - -DIAGNOSTIC(38022, Error, cannotSpecializeGlobalGenericToItself, "the global type parameter '$0' cannot be specialized to itself") -DIAGNOSTIC(38023, Error, cannotSpecializeGlobalGenericToAnotherGenericParam, "the global type parameter '$0' cannot be specialized using another global type parameter ('$1')") - - -DIAGNOSTIC(38024, Error, invalidDispatchThreadIDType, "parameter with SV_DispatchThreadID must be either scalar or vector (1 to 3) of uint/int but is $0") +DIAGNOSTIC( + 38000, + Error, + entryPointFunctionNotFound, + "no function found matching entry point name '$0'") +DIAGNOSTIC( + 38001, + Error, + ambiguousEntryPoint, + "more than one function matches entry point name '$0'") +DIAGNOSTIC( + 38003, + Error, + entryPointSymbolNotAFunction, + "entry point '$0' must be declared as a function") + +DIAGNOSTIC( + 38004, + Error, + entryPointTypeParameterNotFound, + "no type found matching entry-point type parameter name '$0'") +DIAGNOSTIC( + 38005, + Error, + expectedTypeForSpecializationArg, + "expected a type as argument for specialization parameter '$0'") + +DIAGNOSTIC( + 38006, + Warning, + specifiedStageDoesntMatchAttribute, + "entry point '$0' being compiled for the '$1' stage has a '[shader(...)]' attribute that " + "specifies the '$2' stage") +DIAGNOSTIC( + 38007, + Error, + entryPointHasNoStage, + "no stage specified for entry point '$0'; use either a '[shader(\"name\")]' function attribute " + "or the '-stage ' command-line option to specify a stage") + +DIAGNOSTIC( + 38008, + Error, + specializationParameterOfNameNotSpecialized, + "no specialization argument was provided for specialization parameter '$0'") +DIAGNOSTIC( + 38008, + Error, + specializationParameterNotSpecialized, + "no specialization argument was provided for specialization parameter") + +DIAGNOSTIC( + 38009, + Error, + expectedValueOfTypeForSpecializationArg, + "expected a constant value of type '$0' as argument for specialization parameter '$1'") + +DIAGNOSTIC( + 38100, + Error, + typeDoesntImplementInterfaceRequirement, + "type '$0' does not provide required interface member '$1'") +DIAGNOSTIC( + 38105, + Error, + memberDoesNotMatchRequirementSignature, + "member '$0' does not match interface requirement.") +DIAGNOSTIC( + 38101, + Error, + thisExpressionOutsideOfTypeDecl, + "'this' expression can only be used in members of an aggregate type") +DIAGNOSTIC( + 38102, + Error, + initializerNotInsideType, + "an 'init' declaration is only allowed inside a type or 'extension' declaration") +DIAGNOSTIC( + 38103, + Error, + thisTypeOutsideOfTypeDecl, + "'This' type can only be used inside of an aggregate type") +DIAGNOSTIC( + 38104, + Error, + returnValNotAvailable, + "cannot use '__return_val' here. '__return_val' is defined only in functions that return a " + "non-copyable value.") +DIAGNOSTIC( + 38020, + Error, + mismatchEntryPointTypeArgument, + "expecting $0 entry-point type arguments, provided $1.") +DIAGNOSTIC( + 38021, + Error, + typeArgumentForGenericParameterDoesNotConformToInterface, + "type argument `$0` for generic parameter `$1` does not conform to interface `$2`.") + +DIAGNOSTIC( + 38022, + Error, + cannotSpecializeGlobalGenericToItself, + "the global type parameter '$0' cannot be specialized to itself") +DIAGNOSTIC( + 38023, + Error, + cannotSpecializeGlobalGenericToAnotherGenericParam, + "the global type parameter '$0' cannot be specialized using another global type parameter " + "('$1')") + + +DIAGNOSTIC( + 38024, + Error, + invalidDispatchThreadIDType, + "parameter with SV_DispatchThreadID must be either scalar or vector (1 to 3) of uint/int but " + "is $0") DIAGNOSTIC(-1, Note, noteWhenCompilingEntryPoint, "when compiling entry point '$0'") -DIAGNOSTIC(38025, Error, mismatchSpecializationArguments, "expected $0 specialization arguments ($1 provided)") -DIAGNOSTIC(38026, Error, globalTypeArgumentDoesNotConformToInterface, "type argument `$1` for global generic parameter `$0` does not conform to interface `$2`.") - -DIAGNOSTIC(38027, Error, mismatchExistentialSlotArgCount, "expected $0 existential slot arguments ($1 provided)") -DIAGNOSTIC(38029, Error, typeArgumentDoesNotConformToInterface, "type argument '$0' does not conform to the required interface '$1'") - -DIAGNOSTIC(38031, Error, invalidUseOfNoDiff, "'no_diff' can only be used to decorate a call or a subscript operation") -DIAGNOSTIC(38032, Error, useOfNoDiffOnDifferentiableFunc, "use 'no_diff' on a call to a differentiable function has no meaning.") -DIAGNOSTIC(38033, Error, cannotUseNoDiffInNonDifferentiableFunc, "cannot use 'no_diff' in a non-differentiable function.") -DIAGNOSTIC(38034, Error, cannotUseConstRefOnDifferentiableParameter, "cannot use '__constref' on a differentiable parameter.") -DIAGNOSTIC(38034, Error, cannotUseConstRefOnDifferentiableMemberMethod, "cannot use '[constref]' on a differentiable member method of a differentiable type.") - -DIAGNOSTIC(38040, Warning, nonUniformEntryPointParameterTreatedAsUniform, "parameter '$0' is treated as 'uniform' because it does not have a system-value semantic.") +DIAGNOSTIC( + 38025, + Error, + mismatchSpecializationArguments, + "expected $0 specialization arguments ($1 provided)") +DIAGNOSTIC( + 38026, + Error, + globalTypeArgumentDoesNotConformToInterface, + "type argument `$1` for global generic parameter `$0` does not conform to interface `$2`.") + +DIAGNOSTIC( + 38027, + Error, + mismatchExistentialSlotArgCount, + "expected $0 existential slot arguments ($1 provided)") +DIAGNOSTIC( + 38029, + Error, + typeArgumentDoesNotConformToInterface, + "type argument '$0' does not conform to the required interface '$1'") + +DIAGNOSTIC( + 38031, + Error, + invalidUseOfNoDiff, + "'no_diff' can only be used to decorate a call or a subscript operation") +DIAGNOSTIC( + 38032, + Error, + useOfNoDiffOnDifferentiableFunc, + "use 'no_diff' on a call to a differentiable function has no meaning.") +DIAGNOSTIC( + 38033, + Error, + cannotUseNoDiffInNonDifferentiableFunc, + "cannot use 'no_diff' in a non-differentiable function.") +DIAGNOSTIC( + 38034, + Error, + cannotUseConstRefOnDifferentiableParameter, + "cannot use '__constref' on a differentiable parameter.") +DIAGNOSTIC( + 38034, + Error, + cannotUseConstRefOnDifferentiableMemberMethod, + "cannot use '[constref]' on a differentiable member method of a differentiable type.") + +DIAGNOSTIC( + 38040, + Warning, + nonUniformEntryPointParameterTreatedAsUniform, + "parameter '$0' is treated as 'uniform' because it does not have a system-value semantic.") DIAGNOSTIC(38200, Error, recursiveModuleImport, "module `$0` recursively imports itself") -DIAGNOSTIC(39999, Error, errorInImportedModule, "import of module '$0' failed because of a compilation error") +DIAGNOSTIC( + 39999, + Error, + errorInImportedModule, + "import of module '$0' failed because of a compilation error") DIAGNOSTIC(39999, Fatal, complationCeased, "compilation ceased") // 39xxx - Type layout and parameter binding. -DIAGNOSTIC(39000, Error, conflictingExplicitBindingsForParameter, "conflicting explicit bindings for parameter '$0'") -DIAGNOSTIC(39001, Warning, parameterBindingsOverlap, "explicit binding for parameter '$0' overlaps with parameter '$1'") - - -DIAGNOSTIC(39002, Error, shaderParameterDeclarationsDontMatch, "declarations of shader parameter '$0' in different translation units don't match") - -DIAGNOSTIC(39003, Note, shaderParameterTypeMismatch, "type is declared as '$0' in one translation unit, and '$0' in another") -DIAGNOSTIC(39004, Note, fieldTypeMisMatch, "type of field '$0' is declared as '$1' in one translation unit, and '$2' in another") -DIAGNOSTIC(39005, Note, fieldDeclarationsDontMatch, "type '$0' is declared with different fields in each translation unit") +DIAGNOSTIC( + 39000, + Error, + conflictingExplicitBindingsForParameter, + "conflicting explicit bindings for parameter '$0'") +DIAGNOSTIC( + 39001, + Warning, + parameterBindingsOverlap, + "explicit binding for parameter '$0' overlaps with parameter '$1'") + + +DIAGNOSTIC( + 39002, + Error, + shaderParameterDeclarationsDontMatch, + "declarations of shader parameter '$0' in different translation units don't match") + +DIAGNOSTIC( + 39003, + Note, + shaderParameterTypeMismatch, + "type is declared as '$0' in one translation unit, and '$0' in another") +DIAGNOSTIC( + 39004, + Note, + fieldTypeMisMatch, + "type of field '$0' is declared as '$1' in one translation unit, and '$2' in another") +DIAGNOSTIC( + 39005, + Note, + fieldDeclarationsDontMatch, + "type '$0' is declared with different fields in each translation unit") DIAGNOSTIC(39006, Note, usedInDeclarationOf, "used in declaration of '$0'") DIAGNOSTIC(39007, Error, unknownRegisterClass, "unknown register class: '$0'") @@ -697,43 +1832,132 @@ DIAGNOSTIC(39009, Error, expectedSpace, "expected 'space', got '$0'") DIAGNOSTIC(39010, Error, expectedSpaceIndex, "expected a register space index after 'space'") DIAGNOSTIC(39011, Error, invalidComponentMask, "invalid register component mask '$0'.") -DIAGNOSTIC(39013, Warning, registerModifierButNoVulkanLayout, "shader parameter '$0' has a 'register' specified for D3D, but no '[[vk::binding(...)]]` specified for Vulkan") -DIAGNOSTIC(39014, Error, unexpectedSpecifierAfterSpace, "unexpected specifier after register space: '$0'") -DIAGNOSTIC(39015, Error, wholeSpaceParameterRequiresZeroBinding, "shader parameter '$0' consumes whole descriptor sets, so the binding must be in the form '[[vk::binding(0, ...)]]'; the non-zero binding '$1' is not allowed") - -DIAGNOSTIC(39016, Warning, hlslToVulkanMappingNotFound, "unable to infer Vulkan binding for '$0', automatic layout will be used") - -DIAGNOSTIC(39017, Error, dontExpectOutParametersForStage, "the '$0' stage does not support `out` or `inout` entry point parameters") -DIAGNOSTIC(39018, Error, dontExpectInParametersForStage, "the '$0' stage does not support `in` entry point parameters") - -DIAGNOSTIC(39019, Warning, globalUniformNotExpected, "'$0' is implicitly a global shader parameter, not a global variable. If a global variable is intended, add the 'static' modifier. If a uniform shader parameter is intended, add the 'uniform' modifier to silence this warning.") - -DIAGNOSTIC(39020, Error, tooManyShaderRecordConstantBuffers, "can have at most one 'shader record' attributed constant buffer; found $0.") - -DIAGNOSTIC(39021, Error, typeParametersNotAllowedOnEntryPointGlobal, "local-root-signature shader parameter '$0' at global scope must not include existential/interface types") - -DIAGNOSTIC(39022, Warning, vkIndexWithoutVkLocation, "ignoring '[[vk::index(...)]]` attribute without a corresponding '[[vk::location(...)]]' attribute") -DIAGNOSTIC(39023, Error, mixingImplicitAndExplicitBindingForVaryingParams, "mixing explicit and implicit bindings for varying parameters is not supported (see '$0' and '$1')") - -DIAGNOSTIC(39024, Warning, cannotInferVulkanBindingWithoutRegisterModifier, "shader parameter '$0' doesn't have a 'register' specified, automatic layout will be used") - -DIAGNOSTIC(39025, Error, conflictingVulkanInferredBindingForParameter, "conflicting vulkan inferred binding for parameter '$0' overlap is $1 and $2") - -DIAGNOSTIC(39026, Error, matrixLayoutModifierOnNonMatrixType, "matrix layout modifier cannot be used on non-matrix type '$0'.") - -DIAGNOSTIC(39027, Error, getAttributeAtVertexMustReferToPerVertexInput, "'GetAttributeAtVertex' must reference a vertex input directly, and the vertex input must be decorated with 'pervertex' or 'nointerpolation'.") - -DIAGNOSTIC(39028, Error, notValidVaryingParameter, "parameter '$0' is not a valid varying parameter.") +DIAGNOSTIC( + 39013, + Warning, + registerModifierButNoVulkanLayout, + "shader parameter '$0' has a 'register' specified for D3D, but no '[[vk::binding(...)]]` " + "specified for Vulkan") +DIAGNOSTIC( + 39014, + Error, + unexpectedSpecifierAfterSpace, + "unexpected specifier after register space: '$0'") +DIAGNOSTIC( + 39015, + Error, + wholeSpaceParameterRequiresZeroBinding, + "shader parameter '$0' consumes whole descriptor sets, so the binding must be in the form " + "'[[vk::binding(0, ...)]]'; the non-zero binding '$1' is not allowed") + +DIAGNOSTIC( + 39016, + Warning, + hlslToVulkanMappingNotFound, + "unable to infer Vulkan binding for '$0', automatic layout will be used") + +DIAGNOSTIC( + 39017, + Error, + dontExpectOutParametersForStage, + "the '$0' stage does not support `out` or `inout` entry point parameters") +DIAGNOSTIC( + 39018, + Error, + dontExpectInParametersForStage, + "the '$0' stage does not support `in` entry point parameters") + +DIAGNOSTIC( + 39019, + Warning, + globalUniformNotExpected, + "'$0' is implicitly a global shader parameter, not a global variable. If a global variable is " + "intended, add the 'static' modifier. If a uniform shader parameter is intended, add the " + "'uniform' modifier to silence this warning.") + +DIAGNOSTIC( + 39020, + Error, + tooManyShaderRecordConstantBuffers, + "can have at most one 'shader record' attributed constant buffer; found $0.") + +DIAGNOSTIC( + 39021, + Error, + typeParametersNotAllowedOnEntryPointGlobal, + "local-root-signature shader parameter '$0' at global scope must not include " + "existential/interface types") + +DIAGNOSTIC( + 39022, + Warning, + vkIndexWithoutVkLocation, + "ignoring '[[vk::index(...)]]` attribute without a corresponding '[[vk::location(...)]]' " + "attribute") +DIAGNOSTIC( + 39023, + Error, + mixingImplicitAndExplicitBindingForVaryingParams, + "mixing explicit and implicit bindings for varying parameters is not supported (see '$0' and " + "'$1')") + +DIAGNOSTIC( + 39024, + Warning, + cannotInferVulkanBindingWithoutRegisterModifier, + "shader parameter '$0' doesn't have a 'register' specified, automatic layout will be used") + +DIAGNOSTIC( + 39025, + Error, + conflictingVulkanInferredBindingForParameter, + "conflicting vulkan inferred binding for parameter '$0' overlap is $1 and $2") + +DIAGNOSTIC( + 39026, + Error, + matrixLayoutModifierOnNonMatrixType, + "matrix layout modifier cannot be used on non-matrix type '$0'.") + +DIAGNOSTIC( + 39027, + Error, + getAttributeAtVertexMustReferToPerVertexInput, + "'GetAttributeAtVertex' must reference a vertex input directly, and the vertex input must be " + "decorated with 'pervertex' or 'nointerpolation'.") + +DIAGNOSTIC( + 39028, + Error, + notValidVaryingParameter, + "parameter '$0' is not a valid varying parameter.") // // 4xxxx - IL code generation. // -DIAGNOSTIC(40001, Error, bindingAlreadyOccupiedByComponent, "resource binding location '$0' is already occupied by component '$1'.") +DIAGNOSTIC( + 40001, + Error, + bindingAlreadyOccupiedByComponent, + "resource binding location '$0' is already occupied by component '$1'.") DIAGNOSTIC(40002, Error, invalidBindingValue, "binding location '$0' is out of valid range.") -DIAGNOSTIC(40003, Error, bindingExceedsLimit, "binding location '$0' assigned to component '$1' exceeds maximum limit.") -DIAGNOSTIC(40004, Error, bindingAlreadyOccupiedByModule, "DescriptorSet ID '$0' is already occupied by module instance '$1'.") -DIAGNOSTIC(40005, Error, topLevelModuleUsedWithoutSpecifyingBinding, "top level module '$0' is being used without specifying binding location. Use [Binding: \"index\"] attribute to provide a binding location.") +DIAGNOSTIC( + 40003, + Error, + bindingExceedsLimit, + "binding location '$0' assigned to component '$1' exceeds maximum limit.") +DIAGNOSTIC( + 40004, + Error, + bindingAlreadyOccupiedByModule, + "DescriptorSet ID '$0' is already occupied by module instance '$1'.") +DIAGNOSTIC( + 40005, + Error, + topLevelModuleUsedWithoutSpecifyingBinding, + "top level module '$0' is being used without specifying binding location. Use [Binding: " + "\"index\"] attribute to provide a binding location.") DIAGNOSTIC(40006, Error, unimplementedSystemValueSemantic, "unknown system-value semantic '$0'") @@ -743,17 +1967,38 @@ DIAGNOSTIC(40006, Error, needCompileTimeConstant, "expected a compile-time const DIAGNOSTIC(40007, Internal, irValidationFailed, "IR validation failed: $0") -DIAGNOSTIC(40008, Error, invalidLValueForRefParameter, "the form of this l-value argument is not valid for a `ref` parameter") - -DIAGNOSTIC(40009, Error, dynamicInterfaceLacksAnyValueSizeAttribute, "interface '$0' is being used in dynamic dispatch code but has no [anyValueSize] attribute defined.") +DIAGNOSTIC( + 40008, + Error, + invalidLValueForRefParameter, + "the form of this l-value argument is not valid for a `ref` parameter") + +DIAGNOSTIC( + 40009, + Error, + dynamicInterfaceLacksAnyValueSizeAttribute, + "interface '$0' is being used in dynamic dispatch code but has no [anyValueSize] attribute " + "defined.") DIAGNOSTIC(40010, Note, seeInterfaceUsage, "see usage of interface '$0'.") -DIAGNOSTIC(40011, Error, unconstrainedGenericParameterNotAllowedInDynamicFunction, "unconstrained generic paramter '$0' is not allowed in a dynamic function.") +DIAGNOSTIC( + 40011, + Error, + unconstrainedGenericParameterNotAllowedInDynamicFunction, + "unconstrained generic paramter '$0' is not allowed in a dynamic function.") -DIAGNOSTIC(40020, Error, cannotUnrollLoop, "loop does not terminate within the limited number of iterations, unrolling is aborted.") +DIAGNOSTIC( + 40020, + Error, + cannotUnrollLoop, + "loop does not terminate within the limited number of iterations, unrolling is aborted.") -DIAGNOSTIC(40030, Fatal, functionNeverReturnsFatal, "function '$0' never returns, compilation ceased.") +DIAGNOSTIC( + 40030, + Fatal, + functionNeverReturnsFatal, + "function '$0' never returns, compilation ceased.") // 41000 - IR-level validation issues @@ -761,65 +2006,201 @@ DIAGNOSTIC(41000, Warning, unreachableCode, "unreachable code detected") DIAGNOSTIC(41001, Error, recursiveType, "type '$0' contains cyclic reference to itself.") DIAGNOSTIC(41010, Warning, missingReturn, "control flow may reach end of non-'void' function") -DIAGNOSTIC(41011, Error, profileIncompatibleWithTargetSwitch, "__target_switch has no compatable target with current profile '$0'") -DIAGNOSTIC(41012, Warning, profileImplicitlyUpgraded, "entry point '$0' uses additional capabilities that are not part of the specified profile '$1'. The profile setting is automatically updated to include these capabilities: '$2'") -DIAGNOSTIC(41012, Error, profileImplicitlyUpgradedRestrictive, "entry point '$0' uses capabilities that are not part of the specified profile '$1'. Missing capabilities are: '$2'") +DIAGNOSTIC( + 41011, + Error, + profileIncompatibleWithTargetSwitch, + "__target_switch has no compatable target with current profile '$0'") +DIAGNOSTIC( + 41012, + Warning, + profileImplicitlyUpgraded, + "entry point '$0' uses additional capabilities that are not part of the specified profile " + "'$1'. The profile setting is automatically updated to include these capabilities: '$2'") +DIAGNOSTIC( + 41012, + Error, + profileImplicitlyUpgradedRestrictive, + "entry point '$0' uses capabilities that are not part of the specified profile '$1'. Missing " + "capabilities are: '$2'") DIAGNOSTIC(41015, Warning, usingUninitializedOut, "use of uninitialized out parameter '$0'") DIAGNOSTIC(41016, Warning, usingUninitializedVariable, "use of uninitialized variable '$0'") -DIAGNOSTIC(41017, Warning, usingUninitializedGlobalVariable, "use of uninitialized global variable '$0'") -DIAGNOSTIC(41018, Warning, returningWithUninitializedOut, "returning without initializing out parameter '$0'") -DIAGNOSTIC(41019, Warning, returningWithPartiallyUninitializedOut, "returning without fully initializing out parameter '$0'") -DIAGNOSTIC(41020, Warning, constructorUninitializedField, "exiting constructor without initializing field '$0'") -DIAGNOSTIC(41021, Warning, fieldNotDefaultInitialized, "default initializer for '$0' will not initialize field '$1'") +DIAGNOSTIC( + 41017, + Warning, + usingUninitializedGlobalVariable, + "use of uninitialized global variable '$0'") +DIAGNOSTIC( + 41018, + Warning, + returningWithUninitializedOut, + "returning without initializing out parameter '$0'") +DIAGNOSTIC( + 41019, + Warning, + returningWithPartiallyUninitializedOut, + "returning without fully initializing out parameter '$0'") +DIAGNOSTIC( + 41020, + Warning, + constructorUninitializedField, + "exiting constructor without initializing field '$0'") +DIAGNOSTIC( + 41021, + Warning, + fieldNotDefaultInitialized, + "default initializer for '$0' will not initialize field '$1'") DIAGNOSTIC(41022, Warning, inOutNeverStoredInto, "inout parameter '$0' is never written to") -DIAGNOSTIC(41023, Warning, methodNeverMutates, "method marked `[mutable]` but never modifies `this`") - -DIAGNOSTIC(41011, Error, typeDoesNotFitAnyValueSize, "type '$0' does not fit in the size required by its conforming interface.") +DIAGNOSTIC( + 41023, + Warning, + methodNeverMutates, + "method marked `[mutable]` but never modifies `this`") + +DIAGNOSTIC( + 41011, + Error, + typeDoesNotFitAnyValueSize, + "type '$0' does not fit in the size required by its conforming interface.") DIAGNOSTIC(41012, Note, typeAndLimit, "sizeof($0) is $1, limit is $2") -DIAGNOSTIC(41012, Error, typeCannotBePackedIntoAnyValue, "type '$0' contains fields that cannot be packed into an AnyValue.") -DIAGNOSTIC(41020, Error, lossOfDerivativeDueToCallOfNonDifferentiableFunction, "derivative cannot be propagated through call to non-$1-differentiable function `$0`, use 'no_diff' to clarify intention.") -DIAGNOSTIC(41024, Error, lossOfDerivativeAssigningToNonDifferentiableLocation, "derivative is lost during assignment to non-differentiable location, use 'detach()' to clarify intention.") -DIAGNOSTIC(41025, Error, lossOfDerivativeUsingNonDifferentiableLocationAsOutArg, "derivative is lost when passing a non-differentiable location to an `out` or `inout` parameter, consider passing a temporary variable instead.") -DIAGNOSTIC(41021, Error, differentiableFuncMustHaveOutput, "a differentiable function must have at least one differentiable output.") -DIAGNOSTIC(41022, Error, differentiableFuncMustHaveInput, "a differentiable function must have at least one differentiable input.") -DIAGNOSTIC(41023, Error, getStringHashMustBeOnStringLiteral, "getStringHash can only be called when argument is statically resolvable to a string literal") - -DIAGNOSTIC(41030, Warning, operatorShiftLeftOverflow, "left shift amount exceeds the number of bits and the result will be always zero, (`$0` << `$1`).") - -DIAGNOSTIC(41901, Error, unsupportedUseOfLValueForAutoDiff, "unsupported use of L-value for auto differentiation.") -DIAGNOSTIC(41902, Error, cannotDifferentiateDynamicallyIndexedData, "cannot auto-differentiate mixed read/write access to dynamically indexed data in '$0'.") +DIAGNOSTIC( + 41012, + Error, + typeCannotBePackedIntoAnyValue, + "type '$0' contains fields that cannot be packed into an AnyValue.") +DIAGNOSTIC( + 41020, + Error, + lossOfDerivativeDueToCallOfNonDifferentiableFunction, + "derivative cannot be propagated through call to non-$1-differentiable function `$0`, use " + "'no_diff' to clarify intention.") +DIAGNOSTIC( + 41024, + Error, + lossOfDerivativeAssigningToNonDifferentiableLocation, + "derivative is lost during assignment to non-differentiable location, use 'detach()' to " + "clarify intention.") +DIAGNOSTIC( + 41025, + Error, + lossOfDerivativeUsingNonDifferentiableLocationAsOutArg, + "derivative is lost when passing a non-differentiable location to an `out` or `inout` " + "parameter, consider passing a temporary variable instead.") +DIAGNOSTIC( + 41021, + Error, + differentiableFuncMustHaveOutput, + "a differentiable function must have at least one differentiable output.") +DIAGNOSTIC( + 41022, + Error, + differentiableFuncMustHaveInput, + "a differentiable function must have at least one differentiable input.") +DIAGNOSTIC( + 41023, + Error, + getStringHashMustBeOnStringLiteral, + "getStringHash can only be called when argument is statically resolvable to a string literal") + +DIAGNOSTIC( + 41030, + Warning, + operatorShiftLeftOverflow, + "left shift amount exceeds the number of bits and the result will be always zero, (`$0` << " + "`$1`).") + +DIAGNOSTIC( + 41901, + Error, + unsupportedUseOfLValueForAutoDiff, + "unsupported use of L-value for auto differentiation.") +DIAGNOSTIC( + 41902, + Error, + cannotDifferentiateDynamicallyIndexedData, + "cannot auto-differentiate mixed read/write access to dynamically indexed data in '$0'.") DIAGNOSTIC(41903, Error, unableToSizeOf, "sizeof could not be performed for type '$0'.") DIAGNOSTIC(41904, Error, unableToAlignOf, "alignof could not be performed for type '$0'.") -DIAGNOSTIC(42001, Error, invalidUseOfTorchTensorTypeInDeviceFunc, "invalid use of TorchTensor type in device/kernel functions. use `TensorView` instead.") +DIAGNOSTIC( + 42001, + Error, + invalidUseOfTorchTensorTypeInDeviceFunc, + "invalid use of TorchTensor type in device/kernel functions. use `TensorView` instead.") -DIAGNOSTIC(42050, Warning, potentialIssuesWithPreferRecomputeOnSideEffectMethod, "$0 has [PreferRecompute] and may have side effects. side effects may execute multiple times. use [PreferRecompute(SideEffectBehavior.Allow)], or mark function with [__NoSideEffect]") +DIAGNOSTIC( + 42050, + Warning, + potentialIssuesWithPreferRecomputeOnSideEffectMethod, + "$0 has [PreferRecompute] and may have side effects. side effects may execute multiple times. " + "use [PreferRecompute(SideEffectBehavior.Allow)], or mark function with [__NoSideEffect]") DIAGNOSTIC(45001, Error, unresolvedSymbol, "unresolved external symbol '$0'.") -DIAGNOSTIC(41201, Warning, expectDynamicUniformArgument, "argument for '$0' might not be a dynamic uniform, use `asDynamicUniform()` to silence this warning.") -DIAGNOSTIC(41201, Warning, expectDynamicUniformValue, "value stored at this location must be dynamic uniform, use `asDynamicUniform()` to silence this warning.") - - -DIAGNOSTIC(41202, Error, notEqualBitCastSize, "invalid to bit_cast differently sized types: '$0' with size '$1' casted into '$2' with size '$3'") -DIAGNOSTIC(41203, Warning, notEqualReinterpretCastSize, "reinterpret<> into not equally sized types: '$0' with size '$1' casted into '$2' with size '$3'") - -DIAGNOSTIC(41300, Error, byteAddressBufferUnaligned, "invalid alignment `$0` specified for the byte address buffer resource with the element size of `$1`") +DIAGNOSTIC( + 41201, + Warning, + expectDynamicUniformArgument, + "argument for '$0' might not be a dynamic uniform, use `asDynamicUniform()` to silence this " + "warning.") +DIAGNOSTIC( + 41201, + Warning, + expectDynamicUniformValue, + "value stored at this location must be dynamic uniform, use `asDynamicUniform()` to silence " + "this warning.") + + +DIAGNOSTIC( + 41202, + Error, + notEqualBitCastSize, + "invalid to bit_cast differently sized types: '$0' with size '$1' casted into '$2' with size " + "'$3'") +DIAGNOSTIC( + 41203, + Warning, + notEqualReinterpretCastSize, + "reinterpret<> into not equally sized types: '$0' with size '$1' casted into '$2' with size " + "'$3'") + +DIAGNOSTIC( + 41300, + Error, + byteAddressBufferUnaligned, + "invalid alignment `$0` specified for the byte address buffer resource with the element size " + "of `$1`") DIAGNOSTIC(41400, Error, staticAssertionFailure, "static assertion failed, $0") DIAGNOSTIC(41401, Error, staticAssertionFailureWithoutMessage, "static assertion failed.") -DIAGNOSTIC(41402, Error, staticAssertionConditionNotConstant, "condition for static assertion cannot be evaluated at the compile-time.") - -DIAGNOSTIC(41402, Error, multiSampledTextureDoesNotAllowWrites, "cannot write to a multisampled texture with target '$0'.") +DIAGNOSTIC( + 41402, + Error, + staticAssertionConditionNotConstant, + "condition for static assertion cannot be evaluated at the compile-time.") + +DIAGNOSTIC( + 41402, + Error, + multiSampledTextureDoesNotAllowWrites, + "cannot write to a multisampled texture with target '$0'.") // // 5xxxx - Target code generation. // -DIAGNOSTIC(50010, Internal, missingExistentialBindingsForParameter, "missing argument for existential parameter slot") -DIAGNOSTIC(50011, Warning, spirvVersionNotSupported, "Slang's SPIR-V backend only supports SPIR-V version 1.3 and later." - " Use `-emit-spirv-via-glsl` option to produce SPIR-V 1.0 through 1.2.") +DIAGNOSTIC( + 50010, + Internal, + missingExistentialBindingsForParameter, + "missing argument for existential parameter slot") +DIAGNOSTIC( + 50011, + Warning, + spirvVersionNotSupported, + "Slang's SPIR-V backend only supports SPIR-V version 1.3 and later." + " Use `-emit-spirv-via-glsl` option to produce SPIR-V 1.0 through 1.2.") DIAGNOSTIC(50020, Error, invalidTessCoordType, "TessCoord must have vec2 or vec3 type.") DIAGNOSTIC(50020, Error, invalidFragCoordType, "FragCoord must be a vec4.") DIAGNOSTIC(50020, Error, invalidInvocationIdType, "InvocationId must have int type.") @@ -828,86 +2209,263 @@ DIAGNOSTIC(50020, Error, invalidPrimitiveIdType, "PrimitiveId must have int type DIAGNOSTIC(50020, Error, invalidPatchVertexCountType, "PatchVertexCount must have int type.") DIAGNOSTIC(50022, Error, worldIsNotDefined, "world '$0' is not defined.") DIAGNOSTIC(50023, Error, stageShouldProvideWorldAttribute, "'$0' should provide 'World' attribute.") -DIAGNOSTIC(50040, Error, componentHasInvalidTypeForPositionOutput, "'$0': component used as 'loc' output must be of vec4 type.") +DIAGNOSTIC( + 50040, + Error, + componentHasInvalidTypeForPositionOutput, + "'$0': component used as 'loc' output must be of vec4 type.") DIAGNOSTIC(50041, Error, componentNotDefined, "'$0': component not defined.") -DIAGNOSTIC(50052, Error, domainShaderRequiresControlPointCount, "'DomainShader' requires attribute 'ControlPointCount'.") -DIAGNOSTIC(50052, Error, hullShaderRequiresControlPointCount, "'HullShader' requires attribute 'ControlPointCount'.") -DIAGNOSTIC(50052, Error, hullShaderRequiresControlPointWorld, "'HullShader' requires attribute 'ControlPointWorld'.") -DIAGNOSTIC(50052, Error, hullShaderRequiresCornerPointWorld, "'HullShader' requires attribute 'CornerPointWorld'.") +DIAGNOSTIC( + 50052, + Error, + domainShaderRequiresControlPointCount, + "'DomainShader' requires attribute 'ControlPointCount'.") +DIAGNOSTIC( + 50052, + Error, + hullShaderRequiresControlPointCount, + "'HullShader' requires attribute 'ControlPointCount'.") +DIAGNOSTIC( + 50052, + Error, + hullShaderRequiresControlPointWorld, + "'HullShader' requires attribute 'ControlPointWorld'.") +DIAGNOSTIC( + 50052, + Error, + hullShaderRequiresCornerPointWorld, + "'HullShader' requires attribute 'CornerPointWorld'.") DIAGNOSTIC(50052, Error, hullShaderRequiresDomain, "'HullShader' requires attribute 'Domain'.") -DIAGNOSTIC(50052, Error, hullShaderRequiresInputControlPointCount, "'HullShader' requires attribute 'InputControlPointCount'.") -DIAGNOSTIC(50052, Error, hullShaderRequiresOutputTopology, "'HullShader' requires attribute 'OutputTopology'.") -DIAGNOSTIC(50052, Error, hullShaderRequiresPartitioning, "'HullShader' requires attribute 'Partitioning'.") -DIAGNOSTIC(50052, Error, hullShaderRequiresPatchWorld, "'HullShader' requires attribute 'PatchWorld'.") -DIAGNOSTIC(50052, Error, hullShaderRequiresTessLevelInner, "'HullShader' requires attribute 'TessLevelInner'.") -DIAGNOSTIC(50052, Error, hullShaderRequiresTessLevelOuter, "'HullShader' requires attribute 'TessLevelOuter'.") - -DIAGNOSTIC(50053, Error, invalidTessellationDomian, "'Domain' should be either 'triangles' or 'quads'.") -DIAGNOSTIC(50053, Error, invalidTessellationOutputTopology, "'OutputTopology' must be one of: 'point', 'line', 'triangle_cw', or 'triangle_ccw'.") -DIAGNOSTIC(50053, Error, invalidTessellationPartitioning, "'Partitioning' must be one of: 'integer', 'pow2', 'fractional_even', or 'fractional_odd'.") -DIAGNOSTIC(50053, Error, invalidTessellationDomain, "'Domain' should be either 'triangles' or 'quads'.") - -DIAGNOSTIC(50082, Error, importingFromPackedBufferUnsupported, "importing type '$0' from PackedBuffer is not supported by the GLSL backend.") -DIAGNOSTIC(51090, Error, cannotGenerateCodeForExternComponentType, "cannot generate code for extern component type '$0'.") +DIAGNOSTIC( + 50052, + Error, + hullShaderRequiresInputControlPointCount, + "'HullShader' requires attribute 'InputControlPointCount'.") +DIAGNOSTIC( + 50052, + Error, + hullShaderRequiresOutputTopology, + "'HullShader' requires attribute 'OutputTopology'.") +DIAGNOSTIC( + 50052, + Error, + hullShaderRequiresPartitioning, + "'HullShader' requires attribute 'Partitioning'.") +DIAGNOSTIC( + 50052, + Error, + hullShaderRequiresPatchWorld, + "'HullShader' requires attribute 'PatchWorld'.") +DIAGNOSTIC( + 50052, + Error, + hullShaderRequiresTessLevelInner, + "'HullShader' requires attribute 'TessLevelInner'.") +DIAGNOSTIC( + 50052, + Error, + hullShaderRequiresTessLevelOuter, + "'HullShader' requires attribute 'TessLevelOuter'.") + +DIAGNOSTIC( + 50053, + Error, + invalidTessellationDomian, + "'Domain' should be either 'triangles' or 'quads'.") +DIAGNOSTIC( + 50053, + Error, + invalidTessellationOutputTopology, + "'OutputTopology' must be one of: 'point', 'line', 'triangle_cw', or 'triangle_ccw'.") +DIAGNOSTIC( + 50053, + Error, + invalidTessellationPartitioning, + "'Partitioning' must be one of: 'integer', 'pow2', 'fractional_even', or 'fractional_odd'.") +DIAGNOSTIC( + 50053, + Error, + invalidTessellationDomain, + "'Domain' should be either 'triangles' or 'quads'.") + +DIAGNOSTIC( + 50082, + Error, + importingFromPackedBufferUnsupported, + "importing type '$0' from PackedBuffer is not supported by the GLSL backend.") +DIAGNOSTIC( + 51090, + Error, + cannotGenerateCodeForExternComponentType, + "cannot generate code for extern component type '$0'.") DIAGNOSTIC(51091, Error, typeCannotBePlacedInATexture, "type '$0' cannot be placed in a texture.") DIAGNOSTIC(51092, Error, stageDoesntHaveInputWorld, "'$0' doesn't appear to have any input world") -DIAGNOSTIC(50100, Error, noTypeConformancesFoundForInterface, "No type conformances are found for interface '$0'. Code generation for current target requires at least one implementation type present in the linkage.") - -DIAGNOSTIC(52000, Error, multiLevelBreakUnsupported, "control flow appears to require multi-level `break`, which Slang does not yet support") - -DIAGNOSTIC(52001, Warning, dxilNotFound, "dxil shared library not found, so 'dxc' output cannot be signed! Shader code will not be runnable in non-development environments.") - -DIAGNOSTIC(52002, Error, passThroughCompilerNotFound, "could not find a suitable pass-through compiler for '$0'.") +DIAGNOSTIC( + 50100, + Error, + noTypeConformancesFoundForInterface, + "No type conformances are found for interface '$0'. Code generation for current target " + "requires at least one implementation type present in the linkage.") + +DIAGNOSTIC( + 52000, + Error, + multiLevelBreakUnsupported, + "control flow appears to require multi-level `break`, which Slang does not yet support") + +DIAGNOSTIC( + 52001, + Warning, + dxilNotFound, + "dxil shared library not found, so 'dxc' output cannot be signed! Shader code will not be " + "runnable in non-development environments.") + +DIAGNOSTIC( + 52002, + Error, + passThroughCompilerNotFound, + "could not find a suitable pass-through compiler for '$0'.") DIAGNOSTIC(52003, Error, cannotDisassemble, "cannot disassemble '$0'.") DIAGNOSTIC(52004, Error, unableToWriteFile, "unable to write file '$0'") DIAGNOSTIC(52005, Error, unableToReadFile, "unable to read file '$0'") -DIAGNOSTIC(52006, Error, compilerNotDefinedForTransition, "compiler not defined for transition '$0' to '$1'.") - -DIAGNOSTIC(52007, Error, typeCannotBeUsedInDynamicDispatch, "failed to generate dynamic dispatch code for type '$0'.") -DIAGNOSTIC(52008, Error, dynamicDispatchOnSpecializeOnlyInterface, "type '$0' is marked for specialization only, but dynamic dispatch is needed for the call.") -DIAGNOSTIC(53001, Error, invalidTypeMarshallingForImportedDLLSymbol, "invalid type marshalling in imported func $0.") +DIAGNOSTIC( + 52006, + Error, + compilerNotDefinedForTransition, + "compiler not defined for transition '$0' to '$1'.") + +DIAGNOSTIC( + 52007, + Error, + typeCannotBeUsedInDynamicDispatch, + "failed to generate dynamic dispatch code for type '$0'.") +DIAGNOSTIC( + 52008, + Error, + dynamicDispatchOnSpecializeOnlyInterface, + "type '$0' is marked for specialization only, but dynamic dispatch is needed for the call.") +DIAGNOSTIC( + 53001, + Error, + invalidTypeMarshallingForImportedDLLSymbol, + "invalid type marshalling in imported func $0.") DIAGNOSTIC(54001, Warning, meshOutputMustBeOut, "Mesh shader outputs must be declared with 'out'.") DIAGNOSTIC(54002, Error, meshOutputMustBeArray, "HLSL style mesh shader outputs must be arrays") -DIAGNOSTIC(54003, Error, meshOutputArrayMustHaveSize, "HLSL style mesh shader output arrays must have a length specified") -DIAGNOSTIC(54004, Warning, unnecessaryHLSLMeshOutputModifier, "Unnecessary HLSL style mesh shader output modifier") - -DIAGNOSTIC(55101, Error, invalidTorchKernelReturnType, "'$0' is not a valid return type for a pytorch kernel function.") -DIAGNOSTIC(55102, Error, invalidTorchKernelParamType, "'$0' is not a valid parameter type for a pytorch kernel function.") - -DIAGNOSTIC(55200, Error, unsupportedBuiltinType, "'$0' is not a supported builtin type for the target.") -DIAGNOSTIC(55201, Error, unsupportedRecursion, "recursion detected in call to '$0', but the current code generation target does not allow recursion.") -DIAGNOSTIC(55202, Error, systemValueAttributeNotSupported, "system value semantic '$0' is not supported for the current target.") -DIAGNOSTIC(55203, Error, systemValueTypeIncompatible, "system value semantic '$0' should have type '$1' or be convertible to type '$1'.") -DIAGNOSTIC(55204, Error, unsupportedTargetIntrinsic, "intrinsic operation '$0' is not supported for the current target.") -DIAGNOSTIC(56001, Error, unableToAutoMapCUDATypeToHostType, "Could not automatically map '$0' to a host type. Automatic binding generation failed for '$1'") -DIAGNOSTIC(56002, Error, attemptToQuerySizeOfUnsizedArray, "cannot obtain the size of an unsized array.") +DIAGNOSTIC( + 54003, + Error, + meshOutputArrayMustHaveSize, + "HLSL style mesh shader output arrays must have a length specified") +DIAGNOSTIC( + 54004, + Warning, + unnecessaryHLSLMeshOutputModifier, + "Unnecessary HLSL style mesh shader output modifier") + +DIAGNOSTIC( + 55101, + Error, + invalidTorchKernelReturnType, + "'$0' is not a valid return type for a pytorch kernel function.") +DIAGNOSTIC( + 55102, + Error, + invalidTorchKernelParamType, + "'$0' is not a valid parameter type for a pytorch kernel function.") + +DIAGNOSTIC( + 55200, + Error, + unsupportedBuiltinType, + "'$0' is not a supported builtin type for the target.") +DIAGNOSTIC( + 55201, + Error, + unsupportedRecursion, + "recursion detected in call to '$0', but the current code generation target does not allow " + "recursion.") +DIAGNOSTIC( + 55202, + Error, + systemValueAttributeNotSupported, + "system value semantic '$0' is not supported for the current target.") +DIAGNOSTIC( + 55203, + Error, + systemValueTypeIncompatible, + "system value semantic '$0' should have type '$1' or be convertible to type '$1'.") +DIAGNOSTIC( + 55204, + Error, + unsupportedTargetIntrinsic, + "intrinsic operation '$0' is not supported for the current target.") +DIAGNOSTIC( + 56001, + Error, + unableToAutoMapCUDATypeToHostType, + "Could not automatically map '$0' to a host type. Automatic binding generation failed for '$1'") +DIAGNOSTIC( + 56002, + Error, + attemptToQuerySizeOfUnsizedArray, + "cannot obtain the size of an unsized array.") // Metal -DIAGNOSTIC(56100, Error, constantBufferInParameterBlockNotAllowedOnMetal, "nested 'ConstantBuffer' inside a 'ParameterBlock' is not supported on Metal, use 'ParameterBlock' instead.") +DIAGNOSTIC( + 56100, + Error, + constantBufferInParameterBlockNotAllowedOnMetal, + "nested 'ConstantBuffer' inside a 'ParameterBlock' is not supported on Metal, use " + "'ParameterBlock' instead.") DIAGNOSTIC(57001, Warning, spirvOptFailed, "spirv-opt failed. $0") DIAGNOSTIC(57002, Error, unknownPatchConstantParameter, "unknown patch constant parameter '$0'.") DIAGNOSTIC(57003, Error, unknownTessPartitioning, "unknown tessellation partitioning '$0'.") // GLSL Compatibility -DIAGNOSTIC(58001, Error, entryPointMustReturnVoidWhenGlobalOutputPresent, "entry point must return 'void' when global output variables are present.") -DIAGNOSTIC(58002, Error, unhandledGLSLSSBOType, "Unhandled GLSL Shader Storage Buffer Object contents, unsized arrays as a final parameter must be the only parameter") - -DIAGNOSTIC(58003, Error, inconsistentPointerAddressSpace, "'$0': use of pointer with inconsistent address space.") +DIAGNOSTIC( + 58001, + Error, + entryPointMustReturnVoidWhenGlobalOutputPresent, + "entry point must return 'void' when global output variables are present.") +DIAGNOSTIC( + 58002, + Error, + unhandledGLSLSSBOType, + "Unhandled GLSL Shader Storage Buffer Object contents, unsized arrays as a final parameter " + "must be the only parameter") + +DIAGNOSTIC( + 58003, + Error, + inconsistentPointerAddressSpace, + "'$0': use of pointer with inconsistent address space.") // Autodiff checkpoint reporting -DIAGNOSTIC(-1, Note, reportCheckpointIntermediates, "checkpointing context of $1 bytes associated with function: '$0'") -DIAGNOSTIC(-1, Note, reportCheckpointVariable, "$0 bytes ($1) used to checkpoint the following item:") +DIAGNOSTIC( + -1, + Note, + reportCheckpointIntermediates, + "checkpointing context of $1 bytes associated with function: '$0'") +DIAGNOSTIC( + -1, + Note, + reportCheckpointVariable, + "$0 bytes ($1) used to checkpoint the following item:") DIAGNOSTIC(-1, Note, reportCheckpointCounter, "$0 bytes ($1) used for a loop counter here:") DIAGNOSTIC(-1, Note, reportCheckpointNone, "no checkpoint contexts to report") // 9xxxx - Documentation generation -DIAGNOSTIC(90001, Warning, ignoredDocumentationOnOverloadCandidate, "documentation comment on overload candidate '$0' is ignored") +DIAGNOSTIC( + 90001, + Warning, + ignoredDocumentationOnOverloadCandidate, + "documentation comment on overload candidate '$0' is ignored") // // 8xxxx - Issues specific to a particular library/technology/platform/etc. @@ -915,9 +2473,17 @@ DIAGNOSTIC(90001, Warning, ignoredDocumentationOnOverloadCandidate, "documentati // 811xx - NVAPI -DIAGNOSTIC(81110, Error, nvapiMacroMismatch, "conflicting definitions for NVAPI macro '$0': '$1' and '$2'") +DIAGNOSTIC( + 81110, + Error, + nvapiMacroMismatch, + "conflicting definitions for NVAPI macro '$0': '$1' and '$2'") -DIAGNOSTIC(81111, Error, opaqueReferenceMustResolveToGlobal, "could not determine register/space for a resource or sampler used with NVAPI") +DIAGNOSTIC( + 81111, + Error, + opaqueReferenceMustResolveToGlobal, + "could not determine register/space for a resource or sampler used with NVAPI") // 99999 - Internal compiler errors, and not-yet-classified diagnostics. @@ -925,10 +2491,26 @@ DIAGNOSTIC(99999, Internal, unimplemented, "unimplemented feature in Slang compi DIAGNOSTIC(99999, Internal, unexpected, "unexpected condition encountered in Slang compiler: $0") DIAGNOSTIC(99999, Internal, internalCompilerError, "Slang internal compiler error") DIAGNOSTIC(99999, Error, compilationAborted, "Slang compilation aborted due to internal error") -DIAGNOSTIC(99999, Error, compilationAbortedDueToException, "Slang compilation aborted due to an exception of $0: $1") -DIAGNOSTIC(99999, Internal, serialDebugVerificationFailed, "Verification of serial debug information failed.") -DIAGNOSTIC(99999, Internal, spirvValidationFailed, "Validation of generated SPIR-V failed. SPIRV generated: \n$0") - -DIAGNOSTIC(99999, Internal, noBlocksOrIntrinsic, "no blocks found for function definition, is there a '$0' intrinsic missing?") +DIAGNOSTIC( + 99999, + Error, + compilationAbortedDueToException, + "Slang compilation aborted due to an exception of $0: $1") +DIAGNOSTIC( + 99999, + Internal, + serialDebugVerificationFailed, + "Verification of serial debug information failed.") +DIAGNOSTIC( + 99999, + Internal, + spirvValidationFailed, + "Validation of generated SPIR-V failed. SPIRV generated: \n$0") + +DIAGNOSTIC( + 99999, + Internal, + noBlocksOrIntrinsic, + "no blocks found for function definition, is there a '$0' intrinsic missing?") #undef DIAGNOSTIC diff --git a/source/slang/slang-diagnostics.cpp b/source/slang/slang-diagnostics.cpp index df8c4c0d2..53d5392a2 100644 --- a/source/slang/slang-diagnostics.cpp +++ b/source/slang/slang-diagnostics.cpp @@ -1,33 +1,33 @@ // slang-diagnostics.cpp #include "slang-diagnostics.h" -#include "../core/slang-memory-arena.h" +#include "../compiler-core/slang-core-diagnostics.h" +#include "../compiler-core/slang-name.h" +#include "../core/slang-char-util.h" #include "../core/slang-dictionary.h" +#include "../core/slang-memory-arena.h" #include "../core/slang-string-util.h" -#include "../core/slang-char-util.h" - -#include "../compiler-core/slang-name.h" -#include "../compiler-core/slang-core-diagnostics.h" namespace Slang { namespace Diagnostics { -#define DIAGNOSTIC(id, severity, name, messageFormat) const DiagnosticInfo name = { id, Severity::severity, #name, messageFormat }; +#define DIAGNOSTIC(id, severity, name, messageFormat) \ + const DiagnosticInfo name = {id, Severity::severity, #name, messageFormat}; #include "slang-diagnostic-defs.h" #undef DIAGNOSTIC -} +} // namespace Diagnostics -static const DiagnosticInfo* const kCompilerDiagnostics[] = -{ -#define DIAGNOSTIC(id, severity, name, messageFormat) &Diagnostics::name, +static const DiagnosticInfo* const kCompilerDiagnostics[] = { +#define DIAGNOSTIC(id, severity, name, messageFormat) &Diagnostics::name, #include "slang-diagnostic-defs.h" #undef DIAGNOSTIC }; static DiagnosticsLookup* _newDiagnosticsLookup() { - DiagnosticsLookup* lookup = new DiagnosticsLookup(kCompilerDiagnostics, SLANG_COUNT_OF(kCompilerDiagnostics)); + DiagnosticsLookup* lookup = + new DiagnosticsLookup(kCompilerDiagnostics, SLANG_COUNT_OF(kCompilerDiagnostics)); // Add all the diagnostics in 'core' DiagnosticsLookup* coreLookup = getCoreDiagnosticsLookup(); @@ -61,14 +61,19 @@ const DiagnosticsLookup* getDiagnosticsLookup() } -SlangResult overrideDiagnostic(DiagnosticSink* sink, DiagnosticSink* outDiagnostic, const UnownedStringSlice& identifier, Severity originalSeverity, Severity overrideSeverity) +SlangResult overrideDiagnostic( + DiagnosticSink* sink, + DiagnosticSink* outDiagnostic, + const UnownedStringSlice& identifier, + Severity originalSeverity, + Severity overrideSeverity) { auto diagnosticsLookup = getDiagnosticsLookup(); const DiagnosticInfo* diagnostic = nullptr; Int diagnosticId = -1; - // If it starts with a digit we assume it a number + // If it starts with a digit we assume it a number if (identifier.getLength() > 0 && (CharUtil::isDigit(identifier[0]) || identifier[0] == '-')) { if (SLANG_FAILED(StringUtil::parseInt(identifier, diagnosticId))) @@ -95,7 +100,8 @@ SlangResult overrideDiagnostic(DiagnosticSink* sink, DiagnosticSink* outDiagnost } // If we are only allowing certain original severities check it's the right type - if (diagnostic && originalSeverity != Severity::Disable && diagnostic->severity != originalSeverity) + if (diagnostic && originalSeverity != Severity::Disable && + diagnostic->severity != originalSeverity) { // Strictly speaking the diagnostic name is known, but it's not the right severity // to be converted from, so it is an 'unknown name' in the context of severity... @@ -110,14 +116,20 @@ SlangResult overrideDiagnostic(DiagnosticSink* sink, DiagnosticSink* outDiagnost return SLANG_OK; } -SlangResult overrideDiagnostics(DiagnosticSink* sink, DiagnosticSink* outDiagnostic, const UnownedStringSlice& identifierList, Severity originalSeverity, Severity overrideSeverity) +SlangResult overrideDiagnostics( + DiagnosticSink* sink, + DiagnosticSink* outDiagnostic, + const UnownedStringSlice& identifierList, + Severity originalSeverity, + Severity overrideSeverity) { List slices; StringUtil::split(identifierList, ',', slices); for (const auto& slice : slices) { - SLANG_RETURN_ON_FAIL(overrideDiagnostic(sink, outDiagnostic, slice, originalSeverity, overrideSeverity)); + SLANG_RETURN_ON_FAIL( + overrideDiagnostic(sink, outDiagnostic, slice, originalSeverity, overrideSeverity)); } return SLANG_OK; } diff --git a/source/slang/slang-diagnostics.h b/source/slang/slang-diagnostics.h index 202ee0cf4..4d0d0212a 100644 --- a/source/slang/slang-diagnostics.h +++ b/source/slang/slang-diagnostics.h @@ -1,34 +1,47 @@ #ifndef SLANG_DIAGNOSTICS_H #define SLANG_DIAGNOSTICS_H -#include "../core/slang-basic.h" -#include "../core/slang-writer.h" - -#include "../compiler-core/slang-source-loc.h" #include "../compiler-core/slang-diagnostic-sink.h" +#include "../compiler-core/slang-source-loc.h" #include "../compiler-core/slang-token.h" - +#include "../core/slang-basic.h" +#include "../core/slang-writer.h" #include "slang.h" namespace Slang { - DiagnosticInfo const* findDiagnosticByName(UnownedStringSlice const& name); - const DiagnosticsLookup* getDiagnosticsLookup(); - SlangResult overrideDiagnostic(DiagnosticSink* sink, DiagnosticSink* outDiagnostic, const UnownedStringSlice& identifier, Severity originalSeverity, Severity overrideSeverity); - SlangResult overrideDiagnostics(DiagnosticSink* sink, DiagnosticSink* outDiagnostic, const UnownedStringSlice& identifierList, Severity originalSeverity, Severity overrideSeverity); - - namespace Diagnostics - { +DiagnosticInfo const* findDiagnosticByName(UnownedStringSlice const& name); +const DiagnosticsLookup* getDiagnosticsLookup(); +SlangResult overrideDiagnostic( + DiagnosticSink* sink, + DiagnosticSink* outDiagnostic, + const UnownedStringSlice& identifier, + Severity originalSeverity, + Severity overrideSeverity); +SlangResult overrideDiagnostics( + DiagnosticSink* sink, + DiagnosticSink* outDiagnostic, + const UnownedStringSlice& identifierList, + Severity originalSeverity, + Severity overrideSeverity); + +namespace Diagnostics +{ #define DIAGNOSTIC(id, severity, name, messageFormat) extern const DiagnosticInfo name; #include "slang-diagnostic-defs.h" - } -} +} // namespace Diagnostics +} // namespace Slang #ifdef _DEBUG -#define SLANG_INTERNAL_ERROR(sink, pos) \ - (sink)->diagnose(Slang::SourceLoc(__LINE__, 0, 0, __FILE__), Slang::Diagnostics::internalCompilerError) -#define SLANG_UNIMPLEMENTED(sink, pos, what) \ - (sink)->diagnose(Slang::SourceLoc(__LINE__, 0, 0, __FILE__), Slang::Diagnostics::unimplemented, what) +#define SLANG_INTERNAL_ERROR(sink, pos) \ + (sink)->diagnose( \ + Slang::SourceLoc(__LINE__, 0, 0, __FILE__), \ + Slang::Diagnostics::internalCompilerError) +#define SLANG_UNIMPLEMENTED(sink, pos, what) \ + (sink)->diagnose( \ + Slang::SourceLoc(__LINE__, 0, 0, __FILE__), \ + Slang::Diagnostics::unimplemented, \ + what) #else #define SLANG_INTERNAL_ERROR(sink, pos) \ diff --git a/source/slang/slang-doc-ast.cpp b/source/slang/slang-doc-ast.cpp index 4d35f32be..0d4b69895 100644 --- a/source/slang/slang-doc-ast.cpp +++ b/source/slang/slang-doc-ast.cpp @@ -2,14 +2,14 @@ #include "slang-doc-ast.h" #include "../core/slang-string-util.h" - #include "slang/slang-ast-support-types.h" -//#include "slang-ast-builder.h" -//#include "slang-ast-print.h" +// #include "slang-ast-builder.h" +// #include "slang-ast-print.h" -namespace Slang { +namespace Slang +{ -/* static */DocMarkupExtractor::SearchStyle ASTMarkupUtil::getSearchStyle(Decl* decl) +/* static */ DocMarkupExtractor::SearchStyle ASTMarkupUtil::getSearchStyle(Decl* decl) { typedef Extractor::SearchStyle SearchStyle; @@ -50,7 +50,8 @@ namespace Slang { bool shouldDocumentDecl(Decl* decl) { - return !getText(decl->getName()).startsWith("$__syn") && !decl->hasModifier(); + return !getText(decl->getName()).startsWith("$__syn") && + !decl->hasModifier(); } static void _addDeclRec(Decl* decl, List& outDecls) @@ -73,8 +74,8 @@ static void _addDeclRec(Decl* decl, List& outDecls) if (ContainerDecl* containerDecl = as(decl)) { - // Add the container - which could be a class, struct, enum, namespace, extension, generic etc. - // Now add what the container contains + // Add the container - which could be a class, struct, enum, namespace, extension, generic + // etc. Now add what the container contains for (Decl* childDecl : containerDecl->members) { _addDeclRec(childDecl, outDecls); @@ -82,7 +83,7 @@ static void _addDeclRec(Decl* decl, List& outDecls) } } -/* static */void ASTMarkupUtil::findDecls(ModuleDecl* moduleDecl, List& outDecls) +/* static */ void ASTMarkupUtil::findDecls(ModuleDecl* moduleDecl, List& outDecls) { for (Decl* decl : moduleDecl->members) { @@ -90,7 +91,12 @@ static void _addDeclRec(Decl* decl, List& outDecls) } } -SlangResult ASTMarkupUtil::extract(ModuleDecl* moduleDecl, SourceManager* sourceManager, DiagnosticSink* sink, ASTMarkup* outDoc, bool searchOrindaryComments) +SlangResult ASTMarkupUtil::extract( + ModuleDecl* moduleDecl, + SourceManager* sourceManager, + DiagnosticSink* sink, + ASTMarkup* outDoc, + bool searchOrindaryComments) { List decls; findDecls(moduleDecl, decls); @@ -123,7 +129,9 @@ SlangResult ASTMarkupUtil::extract(ModuleDecl* moduleDecl, SourceManager* source extractor.setSearchInOrdinaryComments(searchOrindaryComments); List views; - SLANG_RETURN_ON_FAIL(extractor.extract(inputItems.getBuffer(), declsCount, sourceManager, sink, views, outItems)); + SLANG_RETURN_ON_FAIL( + extractor + .extract(inputItems.getBuffer(), declsCount, sourceManager, sink, views, outItems)); } // Set back @@ -136,14 +144,14 @@ SlangResult ASTMarkupUtil::extract(ModuleDecl* moduleDecl, SourceManager* source if (inputItem.searchStyle != Extractor::SearchStyle::None) { Decl* decl = decls[outputItem.inputIndex]; - + // Add to the documentation ASTMarkup::Entry& docEntry = outDoc->addEntry(decl); docEntry.m_markup = outputItem.text; docEntry.m_visibility = outputItem.visibilty; } } - + return SLANG_OK; } diff --git a/source/slang/slang-doc-ast.h b/source/slang/slang-doc-ast.h index fb5a08986..3339e2a77 100644 --- a/source/slang/slang-doc-ast.h +++ b/source/slang/slang-doc-ast.h @@ -2,39 +2,36 @@ #ifndef SLANG_DOC_AST_H #define SLANG_DOC_AST_H -#include "../core/slang-basic.h" - #include "../compiler-core/slang-doc-extractor.h" - +#include "../core/slang-basic.h" #include "slang-ast-all.h" - #include "slang-syntax.h" -namespace Slang { +namespace Slang +{ -/* Holds the documentation markup that is associated with each node (typically a decl) from a module */ +/* Holds the documentation markup that is associated with each node (typically a decl) from a module + */ class ASTMarkup : public RefObject { public: - typedef MarkupEntry Entry; - /// Adds an entry, returns the reference to pre-existing node if there is one + /// Adds an entry, returns the reference to pre-existing node if there is one Entry& addEntry(NodeBase* base); - /// Gets an entry for a node. Returns nullptr if there is no markup. + /// Gets an entry for a node. Returns nullptr if there is no markup. Entry* getEntry(NodeBase* base); - /// Get list of all of the entries in source order + /// Get list of all of the entries in source order List& getEntries() { return m_entries; } - /// Attaches the markup to the AST nodes. + /// Attaches the markup to the AST nodes. void attachToAST(); protected: - - /// Map from AST nodes to documentation entries + /// Map from AST nodes to documentation entries Dictionary m_entryMap; - /// All of the documentation entries in source order + /// All of the documentation entries in source order List m_entries; }; @@ -71,22 +68,28 @@ SLANG_INLINE void ASTMarkup::attachToAST() } } -/* Extracts documentation markup from source. +/* Extracts documentation markup from source. The comments are extracted and associated in declarations. The association is held in DocMarkup type. The comment style follows the doxygen style */ struct ASTMarkupUtil { typedef DocMarkupExtractor Extractor; - /// Given a module finds all the decls, and places in outDecls + /// Given a module finds all the decls, and places in outDecls static void findDecls(ModuleDecl* moduleDecl, List& outDecls); - /// Given a decl determines the search style that is appropriate. Returns None if can't determine a suitable style + /// Given a decl determines the search style that is appropriate. Returns None if can't + /// determine a suitable style static Extractor::SearchStyle getSearchStyle(Decl* decl); - /// Extracts documentation from the nodes held in the module using the source manager. Found documentation is placed - /// in outMarkup - static SlangResult extract(ModuleDecl* moduleDecl, SourceManager* sourceManager, DiagnosticSink* sink, ASTMarkup* outMarkup, bool searchOrindaryComments = false); + /// Extracts documentation from the nodes held in the module using the source manager. Found + /// documentation is placed in outMarkup + static SlangResult extract( + ModuleDecl* moduleDecl, + SourceManager* sourceManager, + DiagnosticSink* sink, + ASTMarkup* outMarkup, + bool searchOrindaryComments = false); }; bool shouldDocumentDecl(Decl* decl); diff --git a/source/slang/slang-doc-markdown-writer.cpp b/source/slang/slang-doc-markdown-writer.cpp index 79bec1402..2c0e6a040 100644 --- a/source/slang/slang-doc-markdown-writer.cpp +++ b/source/slang/slang-doc-markdown-writer.cpp @@ -1,20 +1,20 @@ // slang-doc-markdown-writer.cpp #include "slang-doc-markdown-writer.h" -#include "../core/slang-string-util.h" -#include "../core/slang-type-text-util.h" #include "../core/slang-char-util.h" +#include "../core/slang-string-util.h" #include "../core/slang-token-reader.h" - +#include "../core/slang-type-text-util.h" #include "slang-ast-builder.h" #include "slang-lookup.h" -namespace Slang { +namespace Slang +{ /* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! DocMarkDownWriter !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! */ -template +template static void _getDecls(ContainerDecl* containerDecl, List& out) { for (Decl* decl : containerDecl->members) @@ -26,8 +26,11 @@ static void _getDecls(ContainerDecl* containerDecl, List& out) } } -template -static void _getDeclsOfType(DocMarkdownWriter* writer, ContainerDecl* containerDecl, List& out) +template +static void _getDeclsOfType( + DocMarkdownWriter* writer, + ContainerDecl* containerDecl, + List& out) { for (Decl* decl : containerDecl->members) { @@ -49,7 +52,7 @@ static void _getDeclsOfType(DocMarkdownWriter* writer, ContainerDecl* containerD } } -template +template static void _getDeclsOfType(DocMarkdownWriter* writer, DocumentPage* page, List& out) { // Collect all decls of type T from all entries for the page. @@ -70,10 +73,11 @@ static void _getDeclsOfType(DocMarkdownWriter* writer, DocumentPage* page, List< if (pair.first) out.add(pair.second); } - out.sort([](Decl* a, Decl* b) -> bool { return getText(a->getName()) < getText(b->getName()); }); + out.sort( + [](Decl* a, Decl* b) -> bool { return getText(a->getName()) < getText(b->getName()); }); } -template +template static void _toList(FilteredMemberList& list, List& out) { for (Decl* decl : list) @@ -96,7 +100,10 @@ String getDocPath(const DocumentationConfig& config, String path) return config.rootDir + Path::getPathWithoutExt(path); } -void DocMarkdownWriter::_appendAsBullets(const List& values, bool insertLinkForName, char wrapChar) +void DocMarkdownWriter::_appendAsBullets( + const List& values, + bool insertLinkForName, + char wrapChar) { auto& out = *m_builder; for (const auto& value : values) @@ -114,7 +121,7 @@ void DocMarkdownWriter::_appendAsBullets(const List& values, bool i } out.appendChar(wrapChar); out << name; - out.appendChar(wrapChar); + out.appendChar(wrapChar); if (path.getLength()) { out << "](" << getDocPath(m_config, path) << ")"; @@ -340,7 +347,9 @@ String DocMarkdownWriter::_getName(InheritanceDecl* decl) return buf.produceString(); } -DocMarkdownWriter::NameAndText DocMarkdownWriter::_getNameAndText(ASTMarkup::Entry* entry, Decl* decl) +DocMarkdownWriter::NameAndText DocMarkdownWriter::_getNameAndText( + ASTMarkup::Entry* entry, + Decl* decl) { NameAndText nameAndText; @@ -459,7 +468,7 @@ void DocMarkdownWriter::_appendCommaList(const List& strings, char wrapC } } -/* static */void DocMarkdownWriter::getSignature(const List& parts, Signature& outSig) +/* static */ void DocMarkdownWriter::getSignature(const List& parts, Signature& outSig) { const Index count = parts.getCount(); for (Index i = 0; i < count; ++i) @@ -467,7 +476,7 @@ void DocMarkdownWriter::_appendCommaList(const List& strings, char wrapC const auto& part = parts[i]; switch (part.type) { - case Part::Type::ParamType: + case Part::Type::ParamType: { PartPair pair; pair.first = part; @@ -479,22 +488,22 @@ void DocMarkdownWriter::_appendCommaList(const List& strings, char wrapC outSig.params.add(pair); break; } - case Part::Type::ReturnType: + case Part::Type::ReturnType: { outSig.returnType = part; break; } - case Part::Type::DeclPath: + case Part::Type::DeclPath: { outSig.name = part; break; } - case Part::Type::GenericParamValue: - case Part::Type::GenericParamType: + case Part::Type::GenericParamValue: + case Part::Type::GenericParamType: { Signature::GenericParam genericParam; genericParam.name = part; - + if ((i + 1) < count && parts[i + 1].type == Part::Type::GenericParamValueType) { genericParam.type = parts[i + 1]; @@ -505,7 +514,7 @@ void DocMarkdownWriter::_appendCommaList(const List& strings, char wrapC break; } - default: break; + default: break; } } } @@ -520,7 +529,7 @@ void escapeHTMLContent(StringBuilder& sb, UnownedStringSlice str) case '>': sb << ">"; break; case '&': sb << "&"; break; case '"': sb << """; break; - default: sb.appendChar(ch); break; + default: sb.appendChar(ch); break; } } } @@ -528,7 +537,7 @@ void escapeHTMLContent(StringBuilder& sb, UnownedStringSlice str) void DocMarkdownWriter::writeVar(const ASTMarkup::Entry& entry, VarDecl* varDecl) { auto& out = *m_builder; - + ASTPrinter printer(m_astBuilder); printer.addDeclPath(DeclRef(varDecl)); @@ -618,7 +627,7 @@ void DocMarkdownWriter::writeProperty(const ASTMarkup::Entry& entry, PropertyDec } } out << "}\n\n\n"; - + declDoc.writeSection(out, this, propertyDecl, DocPageSection::ReturnInfo); declDoc.writeSection(out, this, propertyDecl, DocPageSection::Remarks); declDoc.writeSection(out, this, propertyDecl, DocPageSection::Example); @@ -633,7 +642,7 @@ void DocMarkdownWriter::writeTypeDef(const ASTMarkup::Entry& entry, TypeDefDecl* ASTMarkup::Entry newEntry = entry; _appendAggTypeName(newEntry, typeDefDecl); out << toSlice("\n\n"); - + DeclDocumentation declDoc; declDoc.parse(entry.m_markup.getUnownedSlice()); registerCategory(m_currentPage, declDoc); @@ -728,7 +737,11 @@ void DocMarkdownWriter::writeAttribute(const ASTMarkup::Entry& entry, AttributeD declDoc.writeSection(out, this, attributeDecl, DocPageSection::SeeAlso); } -void DocMarkdownWriter::writeExtensionConditions(StringBuilder& out, ExtensionDecl* extensionDecl, const char* prefix, bool isHtml) +void DocMarkdownWriter::writeExtensionConditions( + StringBuilder& out, + ExtensionDecl* extensionDecl, + const char* prefix, + bool isHtml) { // Synthesize `where` clause for things defined in an extension. auto targetTypeDeclRef = isDeclRefTypeOf(extensionDecl->targetType); @@ -753,7 +766,8 @@ void DocMarkdownWriter::writeExtensionConditions(StringBuilder& out, ExtensionDe // Locate the original generic parameter defined on the type being extended. Decl* originalParamDecl = nullptr; - if (auto targetTypeParentGenericDecl = as(targetTypeDeclRef.getDecl()->parentDecl)) + if (auto targetTypeParentGenericDecl = + as(targetTypeDeclRef.getDecl()->parentDecl)) { for (auto member : targetTypeParentGenericDecl->members) { @@ -784,16 +798,18 @@ void DocMarkdownWriter::writeExtensionConditions(StringBuilder& out, ExtensionDe Val* constraintVal = nullptr; if (genericParamDecl) { - // If we have `TargetType` the member belongs to `extension TargetType`, - // We want to print a synthesized `where T : C` clause. - // Here `extTypeParamDecl` is a reference to `X`, so we need to find the corresponding `T`. + // If we have `TargetType` the member belongs to `extension + // TargetType`, We want to print a synthesized `where T : C` clause. Here + // `extTypeParamDecl` is a reference to `X`, so we need to find the corresponding + // `T`. // Find constraints on the originalParamDecl. for (auto member : genericParamDecl->parentDecl->members) { if (auto typeConstraint = as(member)) { - if (isDeclRefTypeOf(typeConstraint->sub.type).getDecl() == genericParamDecl) + if (isDeclRefTypeOf(typeConstraint->sub.type).getDecl() == + genericParamDecl) { if (typeConstraint->isEqualityConstraint) { @@ -807,9 +823,9 @@ void DocMarkdownWriter::writeExtensionConditions(StringBuilder& out, ExtensionDe } else { - // If we have `extension TargetType` where `Y` does not name a generic parameter defined - // on the extension itself, we want to print a synthesized `where T == Y` clause, where - // `T` is the original generic parameter on the target type. + // If we have `extension TargetType` where `Y` does not name a generic parameter + // defined on the extension itself, we want to print a synthesized `where T == Y` + // clause, where `T` is the original generic parameter on the target type. isEqualityConstraint = true; constraintVal = arg; } @@ -817,7 +833,9 @@ void DocMarkdownWriter::writeExtensionConditions(StringBuilder& out, ExtensionDe { out << prefix; if (isHtml) - out << translateToHTMLWithLinks(originalParamDecl, originalParamDecl->getName()->text); + out << translateToHTMLWithLinks( + originalParamDecl, + originalParamDecl->getName()->text); else out << translateToMarkdownWithLinks(originalParamDecl->getName()->text); if (isEqualityConstraint) @@ -843,8 +861,11 @@ void DocMarkdownWriter::writeSignature(CallableDecl* callableDecl) } List parts; - - ASTPrinter printer(m_astBuilder, ASTPrinter::OptionFlag::ParamNames | ASTPrinter::OptionFlag::NoSpecializedExtensionTypeName, &parts); + + ASTPrinter printer( + m_astBuilder, + ASTPrinter::OptionFlag::ParamNames | ASTPrinter::OptionFlag::NoSpecializedExtensionTypeName, + &parts); printer.addDeclSignature(makeDeclRef(callableDecl)); Signature signature; @@ -865,20 +886,21 @@ void DocMarkdownWriter::writeSignature(CallableDecl* callableDecl) switch (paramCount) { - case 0: + case 0: { // Has no parameters out << toSlice("()"); break; } - case 1: + case 1: { if (signature.name.end - signature.name.start < 40) { // Place all on single line out.appendChar('('); const auto& param = signature.params[0]; - out << translateToHTMLWithLinks(callableDecl, printer.getPartSlice(param.first)) << toSlice(" "); + out << translateToHTMLWithLinks(callableDecl, printer.getPartSlice(param.first)) + << toSlice(" "); out << translateToHTMLWithLinks(callableDecl, printer.getPartSlice(param.second)); out << ")"; break; @@ -886,7 +908,7 @@ void DocMarkdownWriter::writeSignature(CallableDecl* callableDecl) // If the name is already long, fall through to default. [[fallthrough]]; } - default: + default: { // Put each parameter on a line on it's own out << toSlice("(\n"); @@ -896,7 +918,8 @@ void DocMarkdownWriter::writeSignature(CallableDecl* callableDecl) const auto& param = signature.params[i]; line.clear(); - line << " " << translateToHTMLWithLinks(callableDecl, printer.getPartSlice(param.first)); + line << " " + << translateToHTMLWithLinks(callableDecl, printer.getPartSlice(param.first)); line.appendChar(' '); line << translateToHTMLWithLinks(callableDecl, printer.getPartSlice(param.second)); @@ -922,8 +945,13 @@ void DocMarkdownWriter::writeSignature(CallableDecl* callableDecl) // Synthesize `where` clause for things defined in an extension. if (auto targetTypeDeclRef = isDeclRefTypeOf(extensionDecl->targetType)) { - writeExtensionConditions(out, extensionDecl, "\n where ", true); - // We need to follow the parent of the target type instead of the parent of the extension decl. + writeExtensionConditions( + out, + extensionDecl, + "\n where ", + true); + // We need to follow the parent of the target type instead of the parent of the + // extension decl. parentDecl = getParentDecl(targetTypeDeclRef.getDecl()); continue; } @@ -936,12 +964,16 @@ void DocMarkdownWriter::writeSignature(CallableDecl* callableDecl) if (auto typeConstraint = as(member)) { out << toSlice("\n where "); - out << translateToHTMLWithLinks(parentDecl, getSub(m_astBuilder, typeConstraint)->toString()); + out << translateToHTMLWithLinks( + parentDecl, + getSub(m_astBuilder, typeConstraint)->toString()); if (typeConstraint->isEqualityConstraint) out << " == "; else out << toSlice(" : "); - out << translateToHTMLWithLinks(parentDecl, getSup(m_astBuilder, typeConstraint)->toString()); + out << translateToHTMLWithLinks( + parentDecl, + getSup(m_astBuilder, typeConstraint)->toString()); } } } @@ -950,7 +982,9 @@ void DocMarkdownWriter::writeSignature(CallableDecl* callableDecl) out << ";\n"; } -List DocMarkdownWriter::_getUniqueParams(const List& decls, DeclDocumentation* funcDoc) +List DocMarkdownWriter::_getUniqueParams( + const List& decls, + DeclDocumentation* funcDoc) { List out; @@ -1000,7 +1034,9 @@ List DocMarkdownWriter::_getUniqueParams(const L return out; } -static Index _addRequirement(const DocMarkdownWriter::Requirement& req, List& ioReqs) +static Index _addRequirement( + const DocMarkdownWriter::Requirement& req, + List& ioReqs) { auto index = ioReqs.indexOf(req); if (index < 0) @@ -1013,7 +1049,7 @@ static Index _addRequirement(const DocMarkdownWriter::Requirement& req, List& ioReqs) { - return _addRequirement(DocMarkdownWriter::Requirement{ set }, ioReqs); + return _addRequirement(DocMarkdownWriter::Requirement{set}, ioReqs); } static Index _addRequirements(Decl* decl, List& ioReqs) @@ -1102,15 +1138,17 @@ void DocMarkdownWriter::_appendRequirements(const Requirement& requirement) m_builder->append("\n"); - // TODO: We should probably print the capabilities for each stage set if the requirements differ between - // different stages, but for now we'll just print the first one, assuming the rest are the same. - // This is currently true for most if not all of our core module decls. + // TODO: We should probably print the capabilities for each stage set if the requirements + // differ between different stages, but for now we'll just print the first one, assuming the + // rest are the same. This is currently true for most if not all of our core module decls. // if (targetSet.second.shaderStageSets.getCount() > 0 && targetSet.second.shaderStageSets.begin()->second.atomSet.has_value()) { List capabilities; - auto atomSet = targetSet.second.shaderStageSets.begin()->second.atomSet.value().newSetWithoutImpliedAtoms(); + auto atomSet = targetSet.second.shaderStageSets.begin() + ->second.atomSet.value() + .newSetWithoutImpliedAtoms(); for (auto atom : atomSet) { // If the requirement atom is the target or stage atom, don't repeat ourselves. @@ -1141,28 +1179,30 @@ void DocMarkdownWriter::_appendRequirements(const Requirement& requirement) } } -void DocMarkdownWriter::_maybeAppendRequirements(const UnownedStringSlice& title, const List& uniqueRequirements) +void DocMarkdownWriter::_maybeAppendRequirements( + const UnownedStringSlice& title, + const List& uniqueRequirements) { - auto& out = *m_builder; - const Index uniqueCount = uniqueRequirements.getCount(); + auto& out = *m_builder; + const Index uniqueCount = uniqueRequirements.getCount(); - if (uniqueCount <= 0) - { - return; - } + if (uniqueCount <= 0) + { + return; + } - if (uniqueCount == 1) - { - const auto& reqs = uniqueRequirements[0]; + if (uniqueCount == 1) + { + const auto& reqs = uniqueRequirements[0]; - out << title; + out << title; - _appendRequirements(reqs); - out << toSlice("\n"); - } - else - { - out << title; + _appendRequirements(reqs); + out << toSlice("\n"); + } + else + { + out << title; for (Index i = 0; i < uniqueCount; ++i) { @@ -1170,9 +1210,9 @@ void DocMarkdownWriter::_maybeAppendRequirements(const UnownedStringSlice& title _appendRequirements(uniqueRequirements[i]); out << toSlice("\n"); } - } + } - out << toSlice("\n"); + out << toSlice("\n"); } static Decl* _getSameNameDecl(ContainerDecl* parentDecl, Decl* decl) @@ -1204,14 +1244,15 @@ void ParsedDescription::write(DocMarkdownWriter* writer, Decl* decl, StringBuild { switch (span.kind) { - case DocumentationSpanKind::OrdinaryText: + case DocumentationSpanKind::OrdinaryText: { out << span.text; break; } - case DocumentationSpanKind::InlineCode: + case DocumentationSpanKind::InlineCode: { - out << "" << writer->translateToHTMLWithLinks(decl, span.text) << ""; + out << "" << writer->translateToHTMLWithLinks(decl, span.text) + << ""; break; } } @@ -1233,8 +1274,8 @@ void ParsedDescription::parse(UnownedStringSlice text) if (line.startsWith("```")) { isInCodeBlock = !isInCodeBlock; - spans.add({ line, DocumentationSpanKind::OrdinaryText}); - spans.add({ toSlice("\n"), DocumentationSpanKind::OrdinaryText }); + spans.add({line, DocumentationSpanKind::OrdinaryText}); + spans.add({toSlice("\n"), DocumentationSpanKind::OrdinaryText}); codeBlockIndent = originalLine.indexOf('`'); continue; } @@ -1250,9 +1291,10 @@ void ParsedDescription::parse(UnownedStringSlice text) { if (currentSpanEnd > currentSpanStart) { - spans.add({ - UnownedStringSlice(currentSpanStart, line.begin() + i), - isInCode ? DocumentationSpanKind::InlineCode : DocumentationSpanKind::OrdinaryText }); + spans.add( + {UnownedStringSlice(currentSpanStart, line.begin() + i), + isInCode ? DocumentationSpanKind::InlineCode + : DocumentationSpanKind::OrdinaryText}); currentSpanEnd = currentSpanStart = line.begin() + i + 1; } isInCode = !isInCode; @@ -1264,10 +1306,11 @@ void ParsedDescription::parse(UnownedStringSlice text) } if (currentSpanEnd > currentSpanStart) { - spans.add({ UnownedStringSlice(currentSpanStart, currentSpanEnd), - DocumentationSpanKind::OrdinaryText }); + spans.add( + {UnownedStringSlice(currentSpanStart, currentSpanEnd), + DocumentationSpanKind::OrdinaryText}); } - spans.add({ toSlice("\n"), DocumentationSpanKind::OrdinaryText }); + spans.add({toSlice("\n"), DocumentationSpanKind::OrdinaryText}); } else { @@ -1283,8 +1326,8 @@ void ParsedDescription::parse(UnownedStringSlice text) break; } } - spans.add({ line, DocumentationSpanKind::OrdinaryText }); - spans.add({ toSlice("\n"), DocumentationSpanKind::OrdinaryText }); + spans.add({line, DocumentationSpanKind::OrdinaryText}); + spans.add({toSlice("\n"), DocumentationSpanKind::OrdinaryText}); } } } @@ -1407,9 +1450,7 @@ void DeclDocumentation::parse(const UnownedStringSlice& text) { case DocPageSection::ExperimentalCallout: case DocPageSection::InternalCallout: - case DocPageSection::DeprecatedCallout: - currentSection = DocPageSection::Description; - break; + case DocPageSection::DeprecatedCallout: currentSection = DocPageSection::Description; break; } } for (auto& kv : sectionBuilders) @@ -1418,7 +1459,10 @@ void DeclDocumentation::parse(const UnownedStringSlice& text) } } -void DocMarkdownWriter::writeCallableOverridable(DocumentPage* page, const ASTMarkup::Entry& primaryEntry, CallableDecl* callableDecl) +void DocMarkdownWriter::writeCallableOverridable( + DocumentPage* page, + const ASTMarkup::Entry& primaryEntry, + CallableDecl* callableDecl) { SLANG_UNUSED(primaryEntry); @@ -1449,7 +1493,10 @@ void DocMarkdownWriter::writeCallableOverridable(DocumentPage* page, const ASTMa if (!descriptionSB.toString().startsWith(markup)) { auto decl = as(entry->m_node); - m_sink->diagnose(decl->loc, Diagnostics::ignoredDocumentationOnOverloadCandidate, decl); + m_sink->diagnose( + decl->loc, + Diagnostics::ignoredDocumentationOnOverloadCandidate, + decl); } } else @@ -1478,9 +1525,12 @@ void DocMarkdownWriter::writeCallableOverridable(DocumentPage* page, const ASTMa { for (auto& entry : page->entries) { - Decl* sameNameDecl = _getSameNameDecl(as(getParentDecl((Decl*)entry->m_node)), callableDecl); + Decl* sameNameDecl = _getSameNameDecl( + as(getParentDecl((Decl*)entry->m_node)), + callableDecl); - for (Decl* curDecl = sameNameDecl; curDecl; curDecl = curDecl->nextInContainerWithSameName) + for (Decl* curDecl = sameNameDecl; curDecl; + curDecl = curDecl->nextInContainerWithSameName) { CallableDecl* sig = nullptr; if (GenericDecl* genericDecl = as(curDecl)) @@ -1507,7 +1557,9 @@ void DocMarkdownWriter::writeCallableOverridable(DocumentPage* page, const ASTMa } // Lets put back into source order - sigs.sort([](CallableDecl* a, CallableDecl* b) -> bool { return a->loc.getRaw() < b->loc.getRaw(); }); + sigs.sort( + [](CallableDecl* a, CallableDecl* b) -> bool + { return a->loc.getRaw() < b->loc.getRaw(); }); } // Maps a sig index to a unique requirements set @@ -1525,13 +1577,13 @@ void DocMarkdownWriter::writeCallableOverridable(DocumentPage* page, const ASTMa } // Output the signature - { + { out << toSlice("## Signature \n\n"); out << toSlice("
\n");
 
         const Int sigCount = sigs.getCount();
         for (Index i = 0; i < sigCount; ++i)
-        {            
+        {
             auto sig = sigs[i];
             // Get the requirements index for this sig
             const Index requirementsIndex = requirementsMap[i];
@@ -1539,7 +1591,8 @@ void DocMarkdownWriter::writeCallableOverridable(DocumentPage* page, const ASTMa
             // Output if needs unique requirements
             if (requirements.getCount() > 1 && requirementsIndex != -1)
             {
-                out << toSlice("/// Requires Capability Set ") << (requirementsIndex + 1) << toSlice(":\n");
+                out << toSlice("/// Requires Capability Set ") << (requirementsIndex + 1)
+                    << toSlice(":\n");
             }
 
             writeSignature(sig);
@@ -1559,16 +1612,15 @@ void DocMarkdownWriter::writeCallableOverridable(DocumentPage* page, const ASTMa
                 GenericDecl* genericDecl = as(sig->parentDecl);
 
                 // NOTE!
-                // Here we assume the names of generic parameters are such that they are 
+                // Here we assume the names of generic parameters are such that they are
 
-                // We list generic parameters, as types of parameters, if they are directly associated with this
-                // callable.
+                // We list generic parameters, as types of parameters, if they are directly
+                // associated with this callable.
                 if (genericDecl)
                 {
                     for (Decl* decl : genericDecl->members)
                     {
-                        if (as(decl) ||
-                            as(decl))
+                        if (as(decl) || as(decl))
                         {
                             genericDecls.add(decl);
                         }
@@ -1654,7 +1706,7 @@ void DocMarkdownWriter::writeEnum(const ASTMarkup::Entry& entry, EnumDecl* enumD
     out << toSlice("## Values \n\n");
 
     _appendAsBullets(_getAsNameAndTextList(enumDecl->getMembersOfType()), false, '_');
-    
+
     declDoc.writeSection(out, this, enumDecl, DocPageSection::Remarks);
     declDoc.writeSection(out, this, enumDecl, DocPageSection::Example);
     declDoc.writeSection(out, this, enumDecl, DocPageSection::SeeAlso);
@@ -1666,7 +1718,7 @@ void DocMarkdownWriter::_appendEscaped(const UnownedStringSlice& text)
 
     const char* start = text.begin();
     const char* cur = start;
-    const char*const end = text.end();
+    const char* const end = text.end();
 
     for (; cur < end; ++cur)
     {
@@ -1674,25 +1726,25 @@ void DocMarkdownWriter::_appendEscaped(const UnownedStringSlice& text)
 
         switch (c)
         {
-            case '<':
-            case '>':
-            case '&':
-            case '"':
-            case '_':
+        case '<':
+        case '>':
+        case '&':
+        case '"':
+        case '_':
             {
                 // Flush if any before
                 if (cur > start)
                 {
                     out.append(start, cur);
                 }
-                // Prefix with the 
+                // Prefix with the
                 out.appendChar('\\');
 
                 // Start will still include the char, for later flushing
                 start = cur;
                 break;
             }
-            default: break;
+        default: break;
         }
     }
 
@@ -1704,7 +1756,9 @@ void DocMarkdownWriter::_appendEscaped(const UnownedStringSlice& text)
 }
 
 
-void DocMarkdownWriter::_appendDerivedFrom(const UnownedStringSlice& prefix, AggTypeDeclBase* aggTypeDecl)
+void DocMarkdownWriter::_appendDerivedFrom(
+    const UnownedStringSlice& prefix,
+    AggTypeDeclBase* aggTypeDecl)
 {
     auto& out = *m_builder;
 
@@ -1774,7 +1828,8 @@ void DocMarkdownWriter::_appendAggTypeName(const ASTMarkup::Entry& entry, Decl*
     }
     else if (as(aggTypeDecl))
     {
-        out << toSlice("interface ") << escapeMarkdownText(printer.getStringBuilder().produceString());
+        out << toSlice("interface ")
+            << escapeMarkdownText(printer.getStringBuilder().produceString());
     }
     else if (ExtensionDecl* extensionDecl = as(aggTypeDecl))
     {
@@ -1783,7 +1838,8 @@ void DocMarkdownWriter::_appendAggTypeName(const ASTMarkup::Entry& entry, Decl*
     }
     else if (as(aggTypeDecl))
     {
-        out << toSlice("typealias ") << escapeMarkdownText(printer.getStringBuilder().produceString());
+        out << toSlice("typealias ")
+            << escapeMarkdownText(printer.getStringBuilder().produceString());
     }
     else
     {
@@ -1791,7 +1847,10 @@ void DocMarkdownWriter::_appendAggTypeName(const ASTMarkup::Entry& entry, Decl*
     }
 }
 
-void DocMarkdownWriter::writeAggType(DocumentPage* page, const ASTMarkup::Entry& primaryEntry, AggTypeDeclBase* aggTypeDecl)
+void DocMarkdownWriter::writeAggType(
+    DocumentPage* page,
+    const ASTMarkup::Entry& primaryEntry,
+    AggTypeDeclBase* aggTypeDecl)
 {
     auto& out = *m_builder;
 
@@ -1875,9 +1934,11 @@ void DocMarkdownWriter::writeAggType(DocumentPage* page, const ASTMarkup::Entry&
                     for (auto inheritanceDecl : inheritanceDecls)
                     {
                         out << "  - ";
-                        out << escapeMarkdownText(getSub(m_astBuilder, inheritanceDecl)->toString());
+                        out << escapeMarkdownText(
+                            getSub(m_astBuilder, inheritanceDecl)->toString());
                         out << " : ";
-                        out << escapeMarkdownText(getSup(m_astBuilder, inheritanceDecl)->toString());
+                        out << escapeMarkdownText(
+                            getSup(m_astBuilder, inheritanceDecl)->toString());
                         out << toSlice("\n");
                     }
                 }
@@ -1915,7 +1976,8 @@ void DocMarkdownWriter::writeAggType(DocumentPage* page, const ASTMarkup::Entry&
         if (uniqueMethods.getCount())
         {
             // Put in source definition order
-            uniqueMethods.sort([](Decl* a, Decl* b) -> bool { return a->loc.getRaw() < b->loc.getRaw(); });
+            uniqueMethods.sort(
+                [](Decl* a, Decl* b) -> bool { return a->loc.getRaw() < b->loc.getRaw(); });
 
             out << "## Methods\n\n";
             _appendAsBullets(_getAsStringList(uniqueMethods), 0);
@@ -1951,7 +2013,7 @@ void DocMarkdownWriter::writeAggType(DocumentPage* page, const ASTMarkup::Entry&
                 out << escapeMarkdownText(inheritanceDecl->base.type->toString());
                 if (nonEmptyLines.getCount() != 0)
                 {
-                    out << "` when the following conditions are met:\n\n";    
+                    out << "` when the following conditions are met:\n\n";
                     for (auto condition : nonEmptyLines)
                     {
                         out << "  * " << condition << "\n";
@@ -1990,9 +2052,7 @@ String DocMarkdownWriter::escapeMarkdownText(String text)
             sb << '\\';
             sb.appendChar(c);
             break;
-        default:
-            sb.appendChar(c);
-            break;
+        default: sb.appendChar(c); break;
         }
     }
     return sb.produceString();
@@ -2011,10 +2071,14 @@ Slang::Misc::Token treatLiteralsAsIdentifier(Slang::Misc::Token token)
     {
         token.Type = Slang::Misc::TokenType::Identifier;
         StringBuilder stringSB;
-        StringEscapeUtil::appendQuoted(StringEscapeUtil::getHandler(StringEscapeUtil::Style::Cpp), token.Content.getUnownedSlice(), stringSB);
+        StringEscapeUtil::appendQuoted(
+            StringEscapeUtil::getHandler(StringEscapeUtil::Style::Cpp),
+            token.Content.getUnownedSlice(),
+            stringSB);
         token.Content = stringSB.produceString();
     }
-    else if (token.Type == Slang::Misc::TokenType::IntLiteral ||
+    else if (
+        token.Type == Slang::Misc::TokenType::IntLiteral ||
         token.Type == Slang::Misc::TokenType::DoubleLiteral)
     {
         token.Type = Slang::Misc::TokenType::Identifier;
@@ -2030,7 +2094,7 @@ String DocMarkdownWriter::translateToMarkdownWithLinks(String text, bool strictC
     Slang::Misc::TokenReader reader(text);
     bool requireSpaceBeforeNextToken = false;
     bool isFirstToken = true;
-    for (; !reader.IsEnd(); )
+    for (; !reader.IsEnd();)
     {
         auto token = treatLiteralsAsIdentifier(reader.ReadToken());
 
@@ -2055,12 +2119,14 @@ String DocMarkdownWriter::translateToMarkdownWithLinks(String text, bool strictC
             }
             String sectionName;
             Decl* referencedDecl = nullptr;
-            auto page = findPageForToken(currentPage.getLast(), tokenContent, sectionName, referencedDecl);
+            auto page =
+                findPageForToken(currentPage.getLast(), tokenContent, sectionName, referencedDecl);
 
             if (isFirstToken && strictChildLookup && page && page->parentPage != m_currentPage)
             {
-                // If we are performing a strict child lookup (for displaying the member list of an agg type),
-                // then we want to ignore any lookup results that refer to a different parent page.
+                // If we are performing a strict child lookup (for displaying the member list of an
+                // agg type), then we want to ignore any lookup results that refer to a different
+                // parent page.
                 page = nullptr;
             }
 
@@ -2088,9 +2154,7 @@ String DocMarkdownWriter::translateToMarkdownWithLinks(String text, bool strictC
             case Slang::Misc::TokenType::Comma:
             case Slang::Misc::TokenType::Dot:
             case Slang::Misc::TokenType::IntLiteral:
-            case Slang::Misc::TokenType::Semicolon:
-                requireSpaceBeforeNextToken = false;
-                break;
+            case Slang::Misc::TokenType::Semicolon:  requireSpaceBeforeNextToken = false; break;
             default:
                 requireSpaceBeforeNextToken = true;
                 sb.appendChar(' ');
@@ -2124,7 +2188,8 @@ bool isKeyword(const UnownedStringSlice& slice)
 {
     if (isDeclKeyword(slice))
         return true;
-    static const char* knownTypeNames[] = { "int", "float", "half", "double", "bool", "void", "uint" };
+    static const char* knownTypeNames[] =
+        {"int", "float", "half", "double", "bool", "void", "uint"};
     for (auto typeName : knownTypeNames)
     {
         if (slice == typeName)
@@ -2141,7 +2206,7 @@ String DocMarkdownWriter::translateToHTMLWithLinks(Decl* decl, String text)
     currentPage.add(m_currentPage);
     Slang::Misc::TokenReader reader(text);
     bool prevIsIdentifier = false;
-    for (; !reader.IsEnd(); )
+    for (; !reader.IsEnd();)
     {
         auto token = treatLiteralsAsIdentifier(reader.ReadToken());
 
@@ -2151,7 +2216,8 @@ String DocMarkdownWriter::translateToHTMLWithLinks(Decl* decl, String text)
                 sb.append(' ');
             String sectionName;
             Decl* referencedDecl = nullptr;
-            auto page = findPageForToken(currentPage.getLast(), token.Content, sectionName, referencedDecl);
+            auto page =
+                findPageForToken(currentPage.getLast(), token.Content, sectionName, referencedDecl);
             if (page)
             {
                 sb.append("(referencedDecl) ||
-                    as(referencedDecl))
+                else if (as(referencedDecl) || as(referencedDecl))
                     sb.append(" class=\"code_type\"");
                 else if (as(referencedDecl))
                     sb.append(" class=\"code_param\"");
@@ -2223,12 +2288,12 @@ const char* getSectionTitle(DocPageSection section)
     switch (section)
     {
     case DocPageSection::Description: return "Description";
-    case DocPageSection::Parameter: return "Parameters";
-    case DocPageSection::ReturnInfo: return "Return value";
-    case DocPageSection::Remarks: return "Remarks";
-    case DocPageSection::Example: return "Example";
-    case DocPageSection::SeeAlso: return "See also";
-    default: return "";
+    case DocPageSection::Parameter:   return "Parameters";
+    case DocPageSection::ReturnInfo:  return "Return value";
+    case DocPageSection::Remarks:     return "Remarks";
+    case DocPageSection::Example:     return "Example";
+    case DocPageSection::SeeAlso:     return "See also";
+    default:                          return "";
     }
 }
 
@@ -2243,7 +2308,10 @@ void DeclDocumentation::writeDescription(StringBuilder& out, DocMarkdownWriter*
     writeSection(out, writer, decl, DocPageSection::Description);
 }
 
-void DeclDocumentation::writeGenericParameters(StringBuilder& out, DocMarkdownWriter* writer, Decl* decl)
+void DeclDocumentation::writeGenericParameters(
+    StringBuilder& out,
+    DocMarkdownWriter* writer,
+    Decl* decl)
 {
     GenericDecl* genericDecl = as(decl->parentDecl);
     if (!genericDecl)
@@ -2253,8 +2321,7 @@ void DeclDocumentation::writeGenericParameters(StringBuilder& out, DocMarkdownWr
     List params;
     for (Decl* member : genericDecl->members)
     {
-        if (as(member) ||
-            as(member))
+        if (as(member) || as(member))
         {
             params.add(member);
         }
@@ -2285,7 +2352,11 @@ void DeclDocumentation::writeGenericParameters(StringBuilder& out, DocMarkdownWr
     }
 }
 
-void DeclDocumentation::writeSection(StringBuilder& out, DocMarkdownWriter* writer, Decl* decl, DocPageSection section)
+void DeclDocumentation::writeSection(
+    StringBuilder& out,
+    DocMarkdownWriter* writer,
+    Decl* decl,
+    DocPageSection section)
 {
     SLANG_UNUSED(decl);
     ParsedDescription* sectionDoc = sections.tryGetValue(section);
@@ -2296,20 +2367,26 @@ void DeclDocumentation::writeSection(StringBuilder& out, DocMarkdownWriter* writ
     {
     case DocPageSection::DeprecatedCallout:
         out << "> #### Deprecated Feature\n";
-        out << "> The feature described in this page is marked as deprecated, and may be removed in a future release.\n";
-        out << "> Users are advised to avoid using this feature, and to migrate to a newer alternative.\n";
+        out << "> The feature described in this page is marked as deprecated, and may be "
+               "removed in a future release.\n";
+        out << "> Users are advised to avoid using this feature, and to migrate to a newer "
+               "alternative.\n";
         out << "\n";
         return;
     case DocPageSection::ExperimentalCallout:
         out << "> #### Experimental Feature\n";
-        out << "> The feature described in this page is marked as experimental, and may be subject to change in future releases.\n";
-        out << "> Users are advised that any code that depend on this feature may not be compilable by future versions of the compiler.\n";
+        out << "> The feature described in this page is marked as experimental, and may be "
+               "subject to change in future releases.\n";
+        out << "> Users are advised that any code that depend on this feature may not be "
+               "compilable by future versions of the compiler.\n";
         out << "\n";
         return;
     case DocPageSection::InternalCallout:
         out << "> #### Internal Feature\n";
-        out << "> The feature described in this page is marked as an internal implementation detail, and is not intended for use by end-users.\n";
-        out << "> Users are advised to avoid using this declaration directly, as it may be removed or changed in future releases.\n";
+        out << "> The feature described in this page is marked as an internal implementation "
+               "detail, and is not intended for use by end-users.\n";
+        out << "> Users are advised to avoid using this declaration directly, as it may be "
+               "removed or changed in future releases.\n";
         out << "\n";
         return;
     }
@@ -2323,14 +2400,10 @@ void DeclDocumentation::writeSection(StringBuilder& out, DocMarkdownWriter* writ
 void DocMarkdownWriter::createPage(ASTMarkup::Entry& entry, Decl* decl)
 {
     // Skip these they will be output as part of their respective 'containers'
-    if (as(decl) ||
-        as(decl) ||
-        as(decl) ||
-        as(decl) ||
-        as(decl) ||
-        as(decl))
+    if (as(decl) || as(decl) || as(decl) ||
+        as(decl) || as(decl) || as(decl))
     {
-        return; 
+        return;
     }
 
     if (CallableDecl* callableDecl = as(decl))
@@ -2387,11 +2460,15 @@ void DocMarkdownWriter::registerCategory(DocumentPage* page, DeclDocumentation&
 
 bool DocMarkdownWriter::isVisible(const Name* name)
 {
-    return name == nullptr || !name->text.startsWith(toSlice("__"))
-        || m_config.visibleDeclNames.contains(getText((Name*)name));
+    return name == nullptr || !name->text.startsWith(toSlice("__")) ||
+           m_config.visibleDeclNames.contains(getText((Name*)name));
 }
 
-DocumentPage* DocMarkdownWriter::findPageForToken(DocumentPage* currentPage, String token, String& outSectionName, Decl*& outDecl)
+DocumentPage* DocMarkdownWriter::findPageForToken(
+    DocumentPage* currentPage,
+    String token,
+    String& outSectionName,
+    Decl*& outDecl)
 {
     while (currentPage)
     {
@@ -2532,7 +2609,7 @@ void DocumentationConfig::parse(UnownedStringSlice config)
     List lines;
     StringUtil::calcLines(config, lines);
     Index ptr = 0;
-    for (;ptr < lines.getCount(); ptr++)
+    for (; ptr < lines.getCount(); ptr++)
     {
         auto line = lines[ptr];
         if (line.startsWith(toSlice("@preamble:")))
@@ -2578,7 +2655,8 @@ void DocumentationConfig::parse(UnownedStringSlice config)
 
 void sortPages(DocumentPage* page)
 {
-    page->children.sort([](DocumentPage* a, DocumentPage* b) -> bool { return a->shortName < b->shortName; });
+    page->children.sort(
+        [](DocumentPage* a, DocumentPage* b) -> bool { return a->shortName < b->shortName; });
 }
 
 void DocMarkdownWriter::generateSectionIndexPage(DocumentPage* page)
@@ -2594,7 +2672,8 @@ void DocMarkdownWriter::generateSectionIndexPage(DocumentPage* page)
 
     for (auto child : page->children)
     {
-        sb << "- [" << escapeMarkdownText(child->shortName) << "](" << getDocPath(m_config, child->path) << ")\n";
+        sb << "- [" << escapeMarkdownText(child->shortName) << "]("
+           << getDocPath(m_config, child->path) << ")\n";
     }
 }
 
@@ -2602,34 +2681,57 @@ DocumentPage* DocMarkdownWriter::writeAll(UnownedStringSlice configStr)
 {
     m_config.parse(configStr);
 
-    auto addBuiltinPage = [&](DocumentPage* parent, UnownedStringSlice path, UnownedStringSlice title, UnownedStringSlice shortTitle)
+    auto addBuiltinPage = [&](DocumentPage* parent,
+                              UnownedStringSlice path,
+                              UnownedStringSlice title,
+                              UnownedStringSlice shortTitle)
+    {
+        RefPtr page = new DocumentPage();
+        page->title = title;
+        page->path = path;
+        page->shortName = shortTitle;
+        page->decl = nullptr;
+        if (parent)
         {
-            RefPtr page = new DocumentPage();
-            page->title = title;
-            page->path = path;
-            page->shortName = shortTitle;
-            page->decl = nullptr;
-            if (parent)
-            {
-                parent->children.add(page);
-            }
-            m_output[page->path] = page;
-            return page.get();
-        };
-    m_rootPage = addBuiltinPage(nullptr, toSlice("index.md"), m_config.title.getUnownedSlice(), toSlice("Core Module Reference"));
+            parent->children.add(page);
+        }
+        m_output[page->path] = page;
+        return page.get();
+    };
+    m_rootPage = addBuiltinPage(
+        nullptr,
+        toSlice("index.md"),
+        m_config.title.getUnownedSlice(),
+        toSlice("Core Module Reference"));
     m_rootPage->skipWrite = true;
 
-    m_interfacesPage = addBuiltinPage(m_rootPage.get(), toSlice("interfaces/index.md"), toSlice("Interfaces"), toSlice("Interfaces"));
-    m_typesPage = addBuiltinPage(m_rootPage.get(), toSlice("types/index.md"), toSlice("Types"), toSlice("Types"));
-    m_attributesPage = addBuiltinPage(m_rootPage.get(), toSlice("attributes/index.md"), toSlice("Attributes"), toSlice("Attributes"));
-    m_globalDeclsPage = addBuiltinPage(m_rootPage.get(), toSlice("global-decls/index.md"), toSlice("Global Declarations"), toSlice("Global Declarations"));
+    m_interfacesPage = addBuiltinPage(
+        m_rootPage.get(),
+        toSlice("interfaces/index.md"),
+        toSlice("Interfaces"),
+        toSlice("Interfaces"));
+    m_typesPage = addBuiltinPage(
+        m_rootPage.get(),
+        toSlice("types/index.md"),
+        toSlice("Types"),
+        toSlice("Types"));
+    m_attributesPage = addBuiltinPage(
+        m_rootPage.get(),
+        toSlice("attributes/index.md"),
+        toSlice("Attributes"),
+        toSlice("Attributes"));
+    m_globalDeclsPage = addBuiltinPage(
+        m_rootPage.get(),
+        toSlice("global-decls/index.md"),
+        toSlice("Global Declarations"),
+        toSlice("Global Declarations"));
 
     // In the first pass, we create all the pages so we can reference them
     // when writing the content.
     for (auto& entry : m_markup->getEntries())
     {
         Decl* decl = as(entry.m_node);
-    
+
         if (decl && isVisible(entry))
         {
             createPage(entry, decl);
@@ -2652,7 +2754,7 @@ void DocMarkdownWriter::writePage(DocumentPage* page)
         return;
     if (page->entries.getCount() == 0)
         return;
-    
+
     m_currentPage = page;
     m_builder = &(page->get());
 
@@ -2710,14 +2812,23 @@ void DocMarkdownWriter::writePageRecursive(DocumentPage* page)
     }
 }
 
-void writeTOCImpl(StringBuilder& sb, DocMarkdownWriter* writer, DocumentationConfig& config, DocumentPage* page);
+void writeTOCImpl(
+    StringBuilder& sb,
+    DocMarkdownWriter* writer,
+    DocumentationConfig& config,
+    DocumentPage* page);
 
-void writeTOCChildren(StringBuilder& sb, DocMarkdownWriter* writer, DocumentationConfig& config, DocumentPage* page)
+void writeTOCChildren(
+    StringBuilder& sb,
+    DocMarkdownWriter* writer,
+    DocumentationConfig& config,
+    DocumentPage* page)
 {
     if (page->children.getCount() == 0)
         return;
 
-    sb << R"(