From 6cbc3929a54d37bd23cb5efa8e3320ba02f78b2f Mon Sep 17 00:00:00 2001 From: jsmall-nvidia Date: Fri, 31 May 2019 17:20:37 -0400 Subject: Use slang- prefix on slang compiler and core source (#973) * Prefixing source files in source/slang with slang- * Prefix source in source/slang with slang- prefix. * Rename core source files with slang- prefix. * Update project files. * Fix problems from automatic merge. --- source/core/allocator.h | 64 - source/core/array-view.h | 112 - source/core/array.h | 135 - source/core/basic.h | 13 - source/core/common.h | 95 - source/core/core.vcxproj | 40 +- source/core/core.vcxproj.filters | 78 +- source/core/dictionary.h | 619 -- source/core/exception.h | 137 - source/core/hash.h | 153 - source/core/list.h | 631 -- source/core/platform.cpp | 172 - source/core/platform.h | 67 - source/core/secure-crt.h | 88 - source/core/slang-allocator.h | 64 + source/core/slang-array-view.h | 112 + source/core/slang-array.h | 135 + source/core/slang-basic.h | 13 + source/core/slang-byte-encode-util.cpp | 2 - source/core/slang-byte-encode-util.h | 6 +- source/core/slang-common.h | 95 + source/core/slang-dictionary.h | 620 ++ source/core/slang-exception.h | 137 + source/core/slang-free-list.h | 6 +- source/core/slang-hash.h | 153 + source/core/slang-io.cpp | 2 +- source/core/slang-io.h | 10 +- source/core/slang-list.h | 631 ++ source/core/slang-math.h | 4 +- source/core/slang-memory-arena.h | 4 +- source/core/slang-object-scope-manager.h | 8 +- source/core/slang-platform.cpp | 172 + source/core/slang-platform.h | 67 + source/core/slang-random-generator.h | 6 +- source/core/slang-render-api-util.cpp | 6 +- source/core/slang-render-api-util.h | 4 +- source/core/slang-secure-crt.h | 88 + source/core/slang-shared-library.cpp | 5 +- source/core/slang-shared-library.h | 12 +- source/core/slang-smart-pointer.h | 250 + source/core/slang-std-writers.h | 4 +- source/core/slang-stream.cpp | 294 + source/core/slang-stream.h | 113 + source/core/slang-string-slice-pool.h | 8 +- source/core/slang-string-util.h | 6 +- source/core/slang-string.cpp | 2 +- source/core/slang-string.h | 12 +- source/core/slang-test-tool-util.h | 4 +- source/core/slang-text-io.cpp | 343 + source/core/slang-text-io.h | 316 + source/core/slang-token-reader.cpp | 768 ++ source/core/slang-token-reader.h | 260 + source/core/slang-type-traits.h | 46 + source/core/slang-uint-set.h | 8 +- source/core/slang-writer.cpp | 2 +- source/core/slang-writer.h | 6 +- source/core/smart-pointer.h | 250 - source/core/stream.cpp | 294 - source/core/stream.h | 113 - source/core/text-io.cpp | 343 - source/core/text-io.h | 316 - source/core/token-reader.cpp | 768 -- source/core/token-reader.h | 258 - source/core/type-traits.h | 46 - source/slang/check.cpp | 11334 ----------------------- source/slang/check.h | 7 - source/slang/compiler.cpp | 1645 ---- source/slang/compiler.h | 1423 --- source/slang/decl-defs.h | 325 - source/slang/diagnostic-defs.h | 487 - source/slang/diagnostics.cpp | 350 - source/slang/diagnostics.h | 280 - source/slang/dxc-support.cpp | 302 - source/slang/emit.cpp | 510 - source/slang/emit.h | 27 - source/slang/expr-defs.h | 206 - source/slang/image-format-defs.h | 47 - source/slang/ir-bind-existentials.cpp | 352 - source/slang/ir-bind-existentials.h | 15 - source/slang/ir-clone.cpp | 295 - source/slang/ir-clone.h | 183 - source/slang/ir-constexpr.cpp | 553 -- source/slang/ir-constexpr.h | 12 - source/slang/ir-dce.cpp | 325 - source/slang/ir-dce.h | 19 - source/slang/ir-dominators.cpp | 720 -- source/slang/ir-dominators.h | 162 - source/slang/ir-entry-point-uniforms.cpp | 425 - source/slang/ir-entry-point-uniforms.h | 12 - source/slang/ir-glsl-legalize.cpp | 1687 ---- source/slang/ir-glsl-legalize.h | 22 - source/slang/ir-inst-defs.h | 480 - source/slang/ir-insts.h | 1343 --- source/slang/ir-legalize-types.cpp | 2626 ------ source/slang/ir-link.cpp | 1361 --- source/slang/ir-link.h | 27 - source/slang/ir-missing-return.cpp | 43 - source/slang/ir-missing-return.h | 12 - source/slang/ir-restructure-scoping.cpp | 434 - source/slang/ir-restructure-scoping.h | 24 - source/slang/ir-restructure.cpp | 663 -- source/slang/ir-restructure.h | 261 - source/slang/ir-sccp.cpp | 950 -- source/slang/ir-sccp.h | 18 - source/slang/ir-serialize.cpp | 2125 ----- source/slang/ir-serialize.h | 549 -- source/slang/ir-specialize-resources.cpp | 865 -- source/slang/ir-specialize-resources.h | 24 - source/slang/ir-specialize.cpp | 1864 ---- source/slang/ir-specialize.h | 12 - source/slang/ir-ssa.cpp | 1159 --- source/slang/ir-ssa.h | 9 - source/slang/ir-union.cpp | 776 -- source/slang/ir-union.h | 18 - source/slang/ir-validate.cpp | 207 - source/slang/ir-validate.h | 35 - source/slang/ir.cpp | 4511 --------- source/slang/ir.h | 1202 --- source/slang/legalize-types.cpp | 1486 --- source/slang/legalize-types.h | 678 -- source/slang/lexer.cpp | 1334 --- source/slang/lexer.h | 136 - source/slang/lookup.cpp | 713 -- source/slang/lookup.h | 60 - source/slang/lower-to-ir.cpp | 6498 ------------- source/slang/lower-to-ir.h | 28 - source/slang/mangle.cpp | 478 - source/slang/mangle.h | 29 - source/slang/modifier-defs.h | 463 - source/slang/name.cpp | 37 - source/slang/name.h | 86 - source/slang/object-meta-begin.h | 43 - source/slang/object-meta-end.h | 17 - source/slang/options.cpp | 1356 --- source/slang/parameter-binding.cpp | 2583 ------ source/slang/parameter-binding.h | 34 - source/slang/parser.cpp | 4725 ---------- source/slang/parser.h | 30 - source/slang/preprocessor.cpp | 2302 ----- source/slang/preprocessor.h | 38 - source/slang/profile-defs.h | 305 - source/slang/profile.cpp | 34 - source/slang/profile.h | 106 - source/slang/reflection.cpp | 1451 --- source/slang/reflection.h | 26 - source/slang/slang-c-like-source-emitter.cpp | 38 +- source/slang/slang-c-like-source-emitter.h | 10 +- source/slang/slang-check.cpp | 11334 +++++++++++++++++++++++ source/slang/slang-check.h | 7 + source/slang/slang-compiler.cpp | 1645 ++++ source/slang/slang-compiler.h | 1423 +++ source/slang/slang-decl-defs.h | 325 + source/slang/slang-diagnostic-defs.h | 487 + source/slang/slang-diagnostics.cpp | 350 + source/slang/slang-diagnostics.h | 280 + source/slang/slang-dxc-support.cpp | 302 + source/slang/slang-emit-context.h | 6 +- source/slang/slang-emit-precedence.h | 2 +- source/slang/slang-emit.cpp | 510 + source/slang/slang-emit.h | 27 + source/slang/slang-expr-defs.h | 206 + source/slang/slang-extension-usage-tracker.h | 4 +- source/slang/slang-file-system.cpp | 2 +- source/slang/slang-file-system.h | 2 +- source/slang/slang-image-format-defs.h | 47 + source/slang/slang-ir-bind-existentials.cpp | 352 + source/slang/slang-ir-bind-existentials.h | 15 + source/slang/slang-ir-clone.cpp | 295 + source/slang/slang-ir-clone.h | 183 + source/slang/slang-ir-constexpr.cpp | 553 ++ source/slang/slang-ir-constexpr.h | 12 + source/slang/slang-ir-dce.cpp | 325 + source/slang/slang-ir-dce.h | 19 + source/slang/slang-ir-dominators.cpp | 720 ++ source/slang/slang-ir-dominators.h | 162 + source/slang/slang-ir-entry-point-uniforms.cpp | 425 + source/slang/slang-ir-entry-point-uniforms.h | 12 + source/slang/slang-ir-glsl-legalize.cpp | 1687 ++++ source/slang/slang-ir-glsl-legalize.h | 22 + source/slang/slang-ir-inst-defs.h | 480 + source/slang/slang-ir-insts.h | 1343 +++ source/slang/slang-ir-legalize-types.cpp | 2626 ++++++ source/slang/slang-ir-link.cpp | 1361 +++ source/slang/slang-ir-link.h | 27 + source/slang/slang-ir-missing-return.cpp | 43 + source/slang/slang-ir-missing-return.h | 12 + source/slang/slang-ir-restructure-scoping.cpp | 434 + source/slang/slang-ir-restructure-scoping.h | 24 + source/slang/slang-ir-restructure.cpp | 663 ++ source/slang/slang-ir-restructure.h | 261 + source/slang/slang-ir-sccp.cpp | 950 ++ source/slang/slang-ir-sccp.h | 18 + source/slang/slang-ir-serialize.cpp | 2125 +++++ source/slang/slang-ir-serialize.h | 549 ++ source/slang/slang-ir-specialize-resources.cpp | 865 ++ source/slang/slang-ir-specialize-resources.h | 24 + source/slang/slang-ir-specialize.cpp | 1864 ++++ source/slang/slang-ir-specialize.h | 12 + source/slang/slang-ir-ssa.cpp | 1159 +++ source/slang/slang-ir-ssa.h | 9 + source/slang/slang-ir-union.cpp | 776 ++ source/slang/slang-ir-union.h | 18 + source/slang/slang-ir-validate.cpp | 207 + source/slang/slang-ir-validate.h | 35 + source/slang/slang-ir.cpp | 4511 +++++++++ source/slang/slang-ir.h | 1202 +++ source/slang/slang-legalize-types.cpp | 1486 +++ source/slang/slang-legalize-types.h | 678 ++ source/slang/slang-lexer.cpp | 1334 +++ source/slang/slang-lexer.h | 136 + source/slang/slang-lookup.cpp | 713 ++ source/slang/slang-lookup.h | 60 + source/slang/slang-lower-to-ir.cpp | 6498 +++++++++++++ source/slang/slang-lower-to-ir.h | 28 + source/slang/slang-mangle.cpp | 478 + source/slang/slang-mangle.h | 29 + source/slang/slang-mangled-lexer.h | 4 +- source/slang/slang-modifier-defs.h | 463 + source/slang/slang-name.cpp | 37 + source/slang/slang-name.h | 86 + source/slang/slang-object-meta-begin.h | 43 + source/slang/slang-object-meta-end.h | 17 + source/slang/slang-options.cpp | 1356 +++ source/slang/slang-parameter-binding.cpp | 2583 ++++++ source/slang/slang-parameter-binding.h | 34 + source/slang/slang-parser.cpp | 4725 ++++++++++ source/slang/slang-parser.h | 30 + source/slang/slang-preprocessor.cpp | 2302 +++++ source/slang/slang-preprocessor.h | 38 + source/slang/slang-profile-defs.h | 305 + source/slang/slang-profile.cpp | 34 + source/slang/slang-profile.h | 106 + source/slang/slang-reflection.cpp | 1451 +++ source/slang/slang-reflection.h | 26 + source/slang/slang-source-loc.cpp | 591 ++ source/slang/slang-source-loc.h | 412 + source/slang/slang-source-stream.h | 4 +- source/slang/slang-stdlib.cpp | 6 +- source/slang/slang-stmt-defs.h | 124 + source/slang/slang-syntax-base-defs.h | 307 + source/slang/slang-syntax-defs.h | 10 + source/slang/slang-syntax-visitors.h | 36 + source/slang/slang-syntax.cpp | 2865 ++++++ source/slang/slang-syntax.h | 1419 +++ source/slang/slang-token-defs.h | 96 + source/slang/slang-token.cpp | 39 + source/slang/slang-token.h | 67 + source/slang/slang-type-defs.h | 490 + source/slang/slang-type-layout.cpp | 3209 +++++++ source/slang/slang-type-layout.h | 1118 +++ source/slang/slang-type-system-shared.cpp | 11 + source/slang/slang-type-system-shared.h | 102 + source/slang/slang-val-defs.h | 155 + source/slang/slang-visitor.h | 535 ++ source/slang/slang.cpp | 37 +- source/slang/slang.vcxproj | 198 +- source/slang/slang.vcxproj.filters | 236 +- source/slang/source-loc.cpp | 591 -- source/slang/source-loc.h | 412 - source/slang/stmt-defs.h | 124 - source/slang/syntax-base-defs.h | 307 - source/slang/syntax-defs.h | 10 - source/slang/syntax-visitors.h | 36 - source/slang/syntax.cpp | 2865 ------ source/slang/syntax.h | 1419 --- source/slang/token-defs.h | 96 - source/slang/token.cpp | 39 - source/slang/token.h | 67 - source/slang/type-defs.h | 490 - source/slang/type-layout.cpp | 3209 ------- source/slang/type-layout.h | 1118 --- source/slang/type-system-shared.cpp | 11 - source/slang/type-system-shared.h | 102 - source/slang/val-defs.h | 155 - source/slang/visitor.h | 535 -- tools/gfx/circular-resource-heap-d3d12.h | 4 +- tools/gfx/d3d-util.h | 4 +- tools/gfx/descriptor-heap-d3d12.h | 2 +- tools/gfx/flag-combiner.h | 2 +- tools/gfx/render-gl.cpp | 4 +- tools/gfx/render-vk.cpp | 2 +- tools/gfx/render.h | 6 +- tools/gfx/surface.cpp | 2 +- tools/gfx/vk-api.cpp | 2 +- tools/gfx/vk-swap-chain.cpp | 2 +- tools/gfx/vk-swap-chain.h | 2 +- tools/render-test/options.cpp | 2 +- tools/render-test/shader-input-layout.cpp | 2 +- tools/render-test/shader-input-layout.h | 2 +- tools/slang-generate/main.cpp | 4 +- tools/slang-test/options.h | 4 +- tools/slang-test/slang-test-main.cpp | 2 +- tools/slang-test/slangc-tool.cpp | 2 +- tools/slang-test/test-context.h | 4 +- tools/slang-test/test-reporter.h | 4 +- tools/slang-test/unit-test-byte-encode.cpp | 2 +- tools/slang-test/unit-test-free-list.cpp | 2 +- tools/slang-test/unit-test-memory-arena.cpp | 2 +- 298 files changed, 85091 insertions(+), 85088 deletions(-) delete mode 100644 source/core/allocator.h delete mode 100644 source/core/array-view.h delete mode 100644 source/core/array.h delete mode 100644 source/core/basic.h delete mode 100644 source/core/common.h delete mode 100644 source/core/dictionary.h delete mode 100644 source/core/exception.h delete mode 100644 source/core/hash.h delete mode 100644 source/core/list.h delete mode 100644 source/core/platform.cpp delete mode 100644 source/core/platform.h delete mode 100644 source/core/secure-crt.h create mode 100644 source/core/slang-allocator.h create mode 100644 source/core/slang-array-view.h create mode 100644 source/core/slang-array.h create mode 100644 source/core/slang-basic.h create mode 100644 source/core/slang-common.h create mode 100644 source/core/slang-dictionary.h create mode 100644 source/core/slang-exception.h create mode 100644 source/core/slang-hash.h create mode 100644 source/core/slang-list.h create mode 100644 source/core/slang-platform.cpp create mode 100644 source/core/slang-platform.h create mode 100644 source/core/slang-secure-crt.h create mode 100644 source/core/slang-smart-pointer.h create mode 100644 source/core/slang-stream.cpp create mode 100644 source/core/slang-stream.h create mode 100644 source/core/slang-text-io.cpp create mode 100644 source/core/slang-text-io.h create mode 100644 source/core/slang-token-reader.cpp create mode 100644 source/core/slang-token-reader.h create mode 100644 source/core/slang-type-traits.h delete mode 100644 source/core/smart-pointer.h delete mode 100644 source/core/stream.cpp delete mode 100644 source/core/stream.h delete mode 100644 source/core/text-io.cpp delete mode 100644 source/core/text-io.h delete mode 100644 source/core/token-reader.cpp delete mode 100644 source/core/token-reader.h delete mode 100644 source/core/type-traits.h delete mode 100644 source/slang/check.cpp delete mode 100644 source/slang/check.h delete mode 100644 source/slang/compiler.cpp delete mode 100644 source/slang/compiler.h delete mode 100644 source/slang/decl-defs.h delete mode 100644 source/slang/diagnostic-defs.h delete mode 100644 source/slang/diagnostics.cpp delete mode 100644 source/slang/diagnostics.h delete mode 100644 source/slang/dxc-support.cpp delete mode 100644 source/slang/emit.cpp delete mode 100644 source/slang/emit.h delete mode 100644 source/slang/expr-defs.h delete mode 100644 source/slang/image-format-defs.h delete mode 100644 source/slang/ir-bind-existentials.cpp delete mode 100644 source/slang/ir-bind-existentials.h delete mode 100644 source/slang/ir-clone.cpp delete mode 100644 source/slang/ir-clone.h delete mode 100644 source/slang/ir-constexpr.cpp delete mode 100644 source/slang/ir-constexpr.h delete mode 100644 source/slang/ir-dce.cpp delete mode 100644 source/slang/ir-dce.h delete mode 100644 source/slang/ir-dominators.cpp delete mode 100644 source/slang/ir-dominators.h delete mode 100644 source/slang/ir-entry-point-uniforms.cpp delete mode 100644 source/slang/ir-entry-point-uniforms.h delete mode 100644 source/slang/ir-glsl-legalize.cpp delete mode 100644 source/slang/ir-glsl-legalize.h delete mode 100644 source/slang/ir-inst-defs.h delete mode 100644 source/slang/ir-insts.h delete mode 100644 source/slang/ir-legalize-types.cpp delete mode 100644 source/slang/ir-link.cpp delete mode 100644 source/slang/ir-link.h delete mode 100644 source/slang/ir-missing-return.cpp delete mode 100644 source/slang/ir-missing-return.h delete mode 100644 source/slang/ir-restructure-scoping.cpp delete mode 100644 source/slang/ir-restructure-scoping.h delete mode 100644 source/slang/ir-restructure.cpp delete mode 100644 source/slang/ir-restructure.h delete mode 100644 source/slang/ir-sccp.cpp delete mode 100644 source/slang/ir-sccp.h delete mode 100644 source/slang/ir-serialize.cpp delete mode 100644 source/slang/ir-serialize.h delete mode 100644 source/slang/ir-specialize-resources.cpp delete mode 100644 source/slang/ir-specialize-resources.h delete mode 100644 source/slang/ir-specialize.cpp delete mode 100644 source/slang/ir-specialize.h delete mode 100644 source/slang/ir-ssa.cpp delete mode 100644 source/slang/ir-ssa.h delete mode 100644 source/slang/ir-union.cpp delete mode 100644 source/slang/ir-union.h delete mode 100644 source/slang/ir-validate.cpp delete mode 100644 source/slang/ir-validate.h delete mode 100644 source/slang/ir.cpp delete mode 100644 source/slang/ir.h delete mode 100644 source/slang/legalize-types.cpp delete mode 100644 source/slang/legalize-types.h delete mode 100644 source/slang/lexer.cpp delete mode 100644 source/slang/lexer.h delete mode 100644 source/slang/lookup.cpp delete mode 100644 source/slang/lookup.h delete mode 100644 source/slang/lower-to-ir.cpp delete mode 100644 source/slang/lower-to-ir.h delete mode 100644 source/slang/mangle.cpp delete mode 100644 source/slang/mangle.h delete mode 100644 source/slang/modifier-defs.h delete mode 100644 source/slang/name.cpp delete mode 100644 source/slang/name.h delete mode 100644 source/slang/object-meta-begin.h delete mode 100644 source/slang/object-meta-end.h delete mode 100644 source/slang/options.cpp delete mode 100644 source/slang/parameter-binding.cpp delete mode 100644 source/slang/parameter-binding.h delete mode 100644 source/slang/parser.cpp delete mode 100644 source/slang/parser.h delete mode 100644 source/slang/preprocessor.cpp delete mode 100644 source/slang/preprocessor.h delete mode 100644 source/slang/profile-defs.h delete mode 100644 source/slang/profile.cpp delete mode 100644 source/slang/profile.h delete mode 100644 source/slang/reflection.cpp delete mode 100644 source/slang/reflection.h create mode 100644 source/slang/slang-check.cpp create mode 100644 source/slang/slang-check.h create mode 100644 source/slang/slang-compiler.cpp create mode 100644 source/slang/slang-compiler.h create mode 100644 source/slang/slang-decl-defs.h create mode 100644 source/slang/slang-diagnostic-defs.h create mode 100644 source/slang/slang-diagnostics.cpp create mode 100644 source/slang/slang-diagnostics.h create mode 100644 source/slang/slang-dxc-support.cpp create mode 100644 source/slang/slang-emit.cpp create mode 100644 source/slang/slang-emit.h create mode 100644 source/slang/slang-expr-defs.h create mode 100644 source/slang/slang-image-format-defs.h create mode 100644 source/slang/slang-ir-bind-existentials.cpp create mode 100644 source/slang/slang-ir-bind-existentials.h create mode 100644 source/slang/slang-ir-clone.cpp create mode 100644 source/slang/slang-ir-clone.h create mode 100644 source/slang/slang-ir-constexpr.cpp create mode 100644 source/slang/slang-ir-constexpr.h create mode 100644 source/slang/slang-ir-dce.cpp create mode 100644 source/slang/slang-ir-dce.h create mode 100644 source/slang/slang-ir-dominators.cpp create mode 100644 source/slang/slang-ir-dominators.h create mode 100644 source/slang/slang-ir-entry-point-uniforms.cpp create mode 100644 source/slang/slang-ir-entry-point-uniforms.h create mode 100644 source/slang/slang-ir-glsl-legalize.cpp create mode 100644 source/slang/slang-ir-glsl-legalize.h create mode 100644 source/slang/slang-ir-inst-defs.h create mode 100644 source/slang/slang-ir-insts.h create mode 100644 source/slang/slang-ir-legalize-types.cpp create mode 100644 source/slang/slang-ir-link.cpp create mode 100644 source/slang/slang-ir-link.h create mode 100644 source/slang/slang-ir-missing-return.cpp create mode 100644 source/slang/slang-ir-missing-return.h create mode 100644 source/slang/slang-ir-restructure-scoping.cpp create mode 100644 source/slang/slang-ir-restructure-scoping.h create mode 100644 source/slang/slang-ir-restructure.cpp create mode 100644 source/slang/slang-ir-restructure.h create mode 100644 source/slang/slang-ir-sccp.cpp create mode 100644 source/slang/slang-ir-sccp.h create mode 100644 source/slang/slang-ir-serialize.cpp create mode 100644 source/slang/slang-ir-serialize.h create mode 100644 source/slang/slang-ir-specialize-resources.cpp create mode 100644 source/slang/slang-ir-specialize-resources.h create mode 100644 source/slang/slang-ir-specialize.cpp create mode 100644 source/slang/slang-ir-specialize.h create mode 100644 source/slang/slang-ir-ssa.cpp create mode 100644 source/slang/slang-ir-ssa.h create mode 100644 source/slang/slang-ir-union.cpp create mode 100644 source/slang/slang-ir-union.h create mode 100644 source/slang/slang-ir-validate.cpp create mode 100644 source/slang/slang-ir-validate.h create mode 100644 source/slang/slang-ir.cpp create mode 100644 source/slang/slang-ir.h create mode 100644 source/slang/slang-legalize-types.cpp create mode 100644 source/slang/slang-legalize-types.h create mode 100644 source/slang/slang-lexer.cpp create mode 100644 source/slang/slang-lexer.h create mode 100644 source/slang/slang-lookup.cpp create mode 100644 source/slang/slang-lookup.h create mode 100644 source/slang/slang-lower-to-ir.cpp create mode 100644 source/slang/slang-lower-to-ir.h create mode 100644 source/slang/slang-mangle.cpp create mode 100644 source/slang/slang-mangle.h create mode 100644 source/slang/slang-modifier-defs.h create mode 100644 source/slang/slang-name.cpp create mode 100644 source/slang/slang-name.h create mode 100644 source/slang/slang-object-meta-begin.h create mode 100644 source/slang/slang-object-meta-end.h create mode 100644 source/slang/slang-options.cpp create mode 100644 source/slang/slang-parameter-binding.cpp create mode 100644 source/slang/slang-parameter-binding.h create mode 100644 source/slang/slang-parser.cpp create mode 100644 source/slang/slang-parser.h create mode 100644 source/slang/slang-preprocessor.cpp create mode 100644 source/slang/slang-preprocessor.h create mode 100644 source/slang/slang-profile-defs.h create mode 100644 source/slang/slang-profile.cpp create mode 100644 source/slang/slang-profile.h create mode 100644 source/slang/slang-reflection.cpp create mode 100644 source/slang/slang-reflection.h create mode 100644 source/slang/slang-source-loc.cpp create mode 100644 source/slang/slang-source-loc.h create mode 100644 source/slang/slang-stmt-defs.h create mode 100644 source/slang/slang-syntax-base-defs.h create mode 100644 source/slang/slang-syntax-defs.h create mode 100644 source/slang/slang-syntax-visitors.h create mode 100644 source/slang/slang-syntax.cpp create mode 100644 source/slang/slang-syntax.h create mode 100644 source/slang/slang-token-defs.h create mode 100644 source/slang/slang-token.cpp create mode 100644 source/slang/slang-token.h create mode 100644 source/slang/slang-type-defs.h create mode 100644 source/slang/slang-type-layout.cpp create mode 100644 source/slang/slang-type-layout.h create mode 100644 source/slang/slang-type-system-shared.cpp create mode 100644 source/slang/slang-type-system-shared.h create mode 100644 source/slang/slang-val-defs.h create mode 100644 source/slang/slang-visitor.h delete mode 100644 source/slang/source-loc.cpp delete mode 100644 source/slang/source-loc.h delete mode 100644 source/slang/stmt-defs.h delete mode 100644 source/slang/syntax-base-defs.h delete mode 100644 source/slang/syntax-defs.h delete mode 100644 source/slang/syntax-visitors.h delete mode 100644 source/slang/syntax.cpp delete mode 100644 source/slang/syntax.h delete mode 100644 source/slang/token-defs.h delete mode 100644 source/slang/token.cpp delete mode 100644 source/slang/token.h delete mode 100644 source/slang/type-defs.h delete mode 100644 source/slang/type-layout.cpp delete mode 100644 source/slang/type-layout.h delete mode 100644 source/slang/type-system-shared.cpp delete mode 100644 source/slang/type-system-shared.h delete mode 100644 source/slang/val-defs.h delete mode 100644 source/slang/visitor.h diff --git a/source/core/allocator.h b/source/core/allocator.h deleted file mode 100644 index 5832d0b84..000000000 --- a/source/core/allocator.h +++ /dev/null @@ -1,64 +0,0 @@ -#ifndef CORE_LIB_ALLOCATOR_H -#define CORE_LIB_ALLOCATOR_H - -#include -#ifdef _MSC_VER -# include -#endif - -namespace Slang -{ - inline void* alignedAllocate(size_t size, size_t alignment) - { -#ifdef _MSC_VER - return _aligned_malloc(size, alignment); -#elif defined(__CYGWIN__) - return aligned_alloc(alignment, size); -#else - void * rs = 0; - int succ = posix_memalign(&rs, alignment, size); - if (succ!=0) - rs = 0; - return rs; -#endif - } - - inline void alignedDeallocate(void* ptr) - { -#ifdef _MSC_VER - _aligned_free(ptr); -#else - free(ptr); -#endif - } - - class StandardAllocator - { - public: - // not really called - void* allocate(size_t size) - { - return ::malloc(size); - } - void deallocate(void * ptr) - { - return ::free(ptr); - } - }; - - template - class AlignedAllocator - { - public: - void* allocate(size_t size) - { - return alignedAllocate(size, ALIGNMENT); - } - void deallocate(void * ptr) - { - return alignedDeallocate(ptr); - } - }; -} - -#endif diff --git a/source/core/array-view.h b/source/core/array-view.h deleted file mode 100644 index ad9673e2e..000000000 --- a/source/core/array-view.h +++ /dev/null @@ -1,112 +0,0 @@ -#ifndef CORE_LIB_ARRAY_VIEW_H -#define CORE_LIB_ARRAY_VIEW_H - -#include "common.h" - -namespace Slang -{ - template - class ArrayView - { - private: - T* m_buffer; - int m_count; - public: - const T* begin() const { return m_buffer; } - T* begin() { return m_buffer; } - - const T* end() const { return m_buffer + m_count; } - T* end() { return m_buffer + m_count; } - - public: - ArrayView(): - m_buffer(nullptr), - m_count(0) - { - } - ArrayView(T& singleObj): - m_buffer(&singleObj), - m_count(1) - { - } - ArrayView(T* buffer, int size): - m_buffer(buffer), - m_count(size) - { - } - - inline int getCount() const { return m_count; } - - inline const T& operator [](int idx) const - { - SLANG_ASSERT(idx >= 0 && idx <= m_count); - return m_buffer[idx]; - } - inline T& operator [](int idx) - { - SLANG_ASSERT(idx >= 0 && idx <= m_count); - return m_buffer[idx]; - } - - inline const T* getBuffer() const { return m_buffer; } - inline T* getBuffer() { return m_buffer; } - - template - int indexOf(const T2 & val) const - { - for (int i = 0; i < m_count; i++) - { - if (m_buffer[i] == val) - return i; - } - return -1; - } - - template - int lastIndexOf(const T2 & val) const - { - for (int i = m_count - 1; i >= 0; i--) - { - if (m_buffer[i] == val) - return i; - } - return -1; - } - - template - int findFirstIndex(const Func& predicate) const - { - for (int i = 0; i < m_count; i++) - { - if (predicate(m_buffer[i])) - return i; - } - return -1; - } - - template - int findLastIndex(const Func& predicate) const - { - for (int i = m_count - 1; i >= 0; i--) - { - if (predicate(m_buffer[i])) - return i; - } - return -1; - } - }; - - template - ArrayView makeArrayView(T& obj) - { - return ArrayView(obj); - } - - template - ArrayView makeArrayView(T* buffer, int count) - { - return ArrayView(buffer, count); - } -} - -#endif diff --git a/source/core/array.h b/source/core/array.h deleted file mode 100644 index 2a5fa0aa7..000000000 --- a/source/core/array.h +++ /dev/null @@ -1,135 +0,0 @@ -#ifndef CORE_LIB_ARRAY_H -#define CORE_LIB_ARRAY_H - -#include "exception.h" -#include "array-view.h" - -namespace Slang -{ - template - class Array - { - private: - T m_buffer[COUNT]; - int m_count = 0; - public: - T* begin() { return m_buffer; } - const T* begin() const { return m_buffer; } - - const T* end() const { return m_buffer + m_count; } - T* end() { return m_buffer + m_count; } - - public: - inline int getCapacity() const { return COUNT; } - inline int getCount() const { return m_count; } - inline const T& getFirst() const - { - SLANG_ASSERT(m_count > 0); - return m_buffer[0]; - } - inline T& getFirst() - { - SLANG_ASSERT(m_count > 0); - return m_buffer[0]; - } - inline const T& getLast() const - { - SLANG_ASSERT(m_count > 0); - return m_buffer[m_count - 1]; - } - inline T& getLast() - { - SLANG_ASSERT(m_count > 0); - return m_buffer[m_count - 1]; - } - inline void setCount(int newCount) - { - SLANG_ASSERT(newCount >= 0 && newCount <= COUNT); - m_count = newCount; - } - inline void add(const T & item) - { - SLANG_ASSERT(m_count < COUNT); - m_buffer[m_count++] = item; - } - inline void add(T && item) - { - SLANG_ASSERT(m_count < COUNT); - m_buffer[m_count++] = _Move(item); - } - - inline const T& operator [](int idx) const - { - SLANG_ASSERT(idx >= 0 && idx < m_count); - return m_buffer[idx]; - } - inline T& operator [](int idx) - { - SLANG_ASSERT(idx >= 0 && idx < m_count); - return m_buffer[idx]; - } - - inline const T* getBuffer() const { return m_buffer; } - inline T* getBuffer() { return m_buffer; } - - inline void clear() { m_count = 0; } - - template - int indexOf(const T2& val) const - { - for (int i = 0; i < m_count; i++) - { - if (m_buffer[i] == val) - return i; - } - return -1; - } - - template - int lastIndexOf(const T2& val) const - { - for (int i = m_count - 1; i >= 0; i--) - { - if (m_buffer[i] == val) - return i; - } - return -1; - } - - inline ArrayView getArrayView() const - { - return ArrayView((T*)m_buffer, m_count); - } - inline ArrayView getArrayView(int start, int count) const - { - return ArrayView((T*)m_buffer + start, count); - } - }; - - template - struct FirstType - { - typedef T Type; - }; - - - template - void insertArray(Array&) {} - - template - void insertArray(Array& arr, const T& val, TArgs... args) - { - arr.add(val); - insertArray(arr, args...); - } - - template - auto makeArray(TArgs ...args) -> Array::Type, sizeof...(args)> - { - Array::Type, sizeof...(args)> rs; - insertArray(rs, args...); - return rs; - } -} - -#endif diff --git a/source/core/basic.h b/source/core/basic.h deleted file mode 100644 index e89d740bf..000000000 --- a/source/core/basic.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef CORE_LIB_BASIC_H -#define CORE_LIB_BASIC_H - -#include "common.h" -#include "slang-math.h" -#include "slang-string.h" -#include "array.h" -#include "list.h" -#include "smart-pointer.h" -#include "exception.h" -#include "dictionary.h" - -#endif \ No newline at end of file diff --git a/source/core/common.h b/source/core/common.h deleted file mode 100644 index 0e5396caf..000000000 --- a/source/core/common.h +++ /dev/null @@ -1,95 +0,0 @@ -#ifndef CORE_LIB_COMMON_H -#define CORE_LIB_COMMON_H - -#include "../../slang.h" - -#include - -#include - -#ifdef __GNUC__ -#define CORE_LIB_ALIGN_16(x) x __attribute__((aligned(16))) -#else -#define CORE_LIB_ALIGN_16(x) __declspec(align(16)) x -#endif - -#define VARIADIC_TEMPLATE - -namespace Slang -{ - typedef int32_t Int32; - typedef uint32_t UInt32; - - typedef int64_t Int64; - typedef uint64_t UInt64; - - // Define - typedef SlangUInt UInt; - typedef SlangInt Int; - -// typedef unsigned short Word; - - typedef intptr_t PtrInt; - - // Type used for indexing, in arrays/views etc - typedef Int Index; - - template - inline T&& _Move(T & obj) - { - return static_cast(obj); - } - - template - inline void Swap(T & v0, T & v1) - { - T tmp = _Move(v0); - v0 = _Move(v1); - v1 = _Move(tmp); - } - -#ifdef _MSC_VER -# define SLANG_RETURN_NEVER __declspec(noreturn) -//#elif SLANG_CLANG -//# define SLANG_RETURN_NEVER [[noreturn]] -#else -# define SLANG_RETURN_NEVER [[noreturn]] -//# define SLANG_RETURN_NEVER /* empty */ -#endif - -#ifdef _MSC_VER -#define UNREACHABLE_RETURN(x) -#define UNREACHABLE(x) -#else -#define UNREACHABLE_RETURN(x) return x; -#define UNREACHABLE(x) x; -#endif - - SLANG_RETURN_NEVER void signalUnexpectedError(char const* message); -} - -#define SLANG_UNEXPECTED(reason) \ - Slang::signalUnexpectedError("unexpected: " reason) - -#define SLANG_UNIMPLEMENTED_X(what) \ - Slang::signalUnexpectedError("unimplemented: " what) - -#define SLANG_UNREACHABLE(msg) \ - Slang::signalUnexpectedError("unreachable code executed: " msg) - -#ifdef _DEBUG -#define SLANG_EXPECT(VALUE, MSG) if(VALUE) {} else Slang::signalUnexpectedError("assertion failed: '" MSG "'") -#define SLANG_ASSERT(VALUE) SLANG_EXPECT(VALUE, #VALUE) -#else -#define SLANG_EXPECT(VALUE, MSG) do {} while(0) -#define SLANG_ASSERT(VALUE) do {} while(0) -#endif - -#define SLANG_RELEASE_ASSERT(VALUE) if(VALUE) {} else Slang::signalUnexpectedError("assertion failed") -#define SLANG_RELEASE_EXPECT(VALUE, WHAT) if(VALUE) {} else SLANG_UNEXPECTED(WHAT) - -template void slang_use_obj(T&) {} - -#define SLANG_UNREFERENCED_PARAMETER(P) slang_use_obj(P) -#define SLANG_UNREFERENCED_VARIABLE(P) slang_use_obj(P) -#endif diff --git a/source/core/core.vcxproj b/source/core/core.vcxproj index 0416eaed6..a8e92949f 100644 --- a/source/core/core.vcxproj +++ b/source/core/core.vcxproj @@ -170,59 +170,59 @@ - - - - - - - - - - - + + + + + + + + + + + + + + + + - - - - - - + + + + - - - diff --git a/source/core/core.vcxproj.filters b/source/core/core.vcxproj.filters index 0a0ea93fa..8656ab49b 100644 --- a/source/core/core.vcxproj.filters +++ b/source/core/core.vcxproj.filters @@ -9,46 +9,40 @@ - + Header Files - + Header Files - + Header Files - + Header Files - - Header Files - - - Header Files - - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files @@ -60,18 +54,30 @@ Header Files + + Header Files + Header Files Header Files + + Header Files + Header Files + + Header Files + Header Files + + Header Files + Header Files @@ -84,32 +90,23 @@ Header Files - - Header Files - - - Header Files - - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - - Source Files - Source Files @@ -125,6 +122,9 @@ Source Files + + Source Files + Source Files @@ -137,6 +137,9 @@ Source Files + + Source Files + Source Files @@ -149,19 +152,16 @@ Source Files - + Source Files - + Source Files - - Source Files - - + Source Files - + Source Files diff --git a/source/core/dictionary.h b/source/core/dictionary.h deleted file mode 100644 index 1b3525756..000000000 --- a/source/core/dictionary.h +++ /dev/null @@ -1,619 +0,0 @@ -#ifndef CORE_LIB_DICTIONARY_H -#define CORE_LIB_DICTIONARY_H -#include "list.h" -#include "common.h" -#include "slang-uint-set.h" -#include "exception.h" -#include "slang-math.h" -#include "hash.h" - -namespace Slang -{ - template - class KeyValuePair - { - public: - TKey Key; - TValue Value; - KeyValuePair() - {} - KeyValuePair(const TKey & key, const TValue & value) - { - Key = key; - Value = value; - } - KeyValuePair(TKey && key, TValue && value) - { - Key = _Move(key); - Value = _Move(value); - } - KeyValuePair(TKey && key, const TValue & value) - { - Key = _Move(key); - Value = value; - } - KeyValuePair(const KeyValuePair & _that) - { - Key = _that.Key; - Value = _that.Value; - } - KeyValuePair(KeyValuePair && _that) - { - operator=(_Move(_that)); - } - KeyValuePair & operator=(KeyValuePair && that) - { - Key = _Move(that.Key); - Value = _Move(that.Value); - return *this; - } - KeyValuePair & operator=(const KeyValuePair & that) - { - Key = that.Key; - Value = that.Value; - return *this; - } - int GetHashCode() - { - return GetHashCode(Key); - } - }; - - template - inline KeyValuePair KVPair(const TKey & k, const TValue & v) - { - return KeyValuePair(k, v); - } - - const float MaxLoadFactor = 0.7f; - - template - class Dictionary - { - friend class Iterator; - friend class ItemProxy; - private: - inline int GetProbeOffset(int /*probeId*/) const - { - // quadratic probing - return 1; - } - private: - int bucketSizeMinusOne; - int _count; - UIntSet marks; - KeyValuePair* hashMap; - void Free() - { - if (hashMap) - delete[] hashMap; - hashMap = 0; - } - inline bool IsDeleted(int pos) const - { - return marks.contains((pos << 1) + 1); - } - inline bool IsEmpty(int pos) const - { - return !marks.contains((pos << 1)); - } - inline void SetDeleted(int pos, bool val) - { - if (val) - marks.add((pos << 1) + 1); - else - marks.remove((pos << 1) + 1); - } - inline void SetEmpty(int pos, bool val) - { - if (val) - marks.remove((pos << 1)); - else - marks.add((pos << 1)); - } - struct FindPositionResult - { - int ObjectPosition; - int InsertionPosition; - FindPositionResult() - { - ObjectPosition = -1; - InsertionPosition = -1; - } - FindPositionResult(int objPos, int insertPos) - { - ObjectPosition = objPos; - InsertionPosition = insertPos; - } - - }; - inline int GetHashPos(TKey& key) const - { - return ((unsigned int)(GetHashCode(key) * 2654435761)) % bucketSizeMinusOne; - } - FindPositionResult FindPosition(const TKey& key) const - { - int hashPos = GetHashPos(const_cast(key)); - int insertPos = -1; - int numProbes = 0; - while (numProbes <= bucketSizeMinusOne) - { - if (IsEmpty(hashPos)) - { - if (insertPos == -1) - return FindPositionResult(-1, hashPos); - else - return FindPositionResult(-1, insertPos); - } - else if (IsDeleted(hashPos)) - { - if (insertPos == -1) - insertPos = hashPos; - } - else if (hashMap[hashPos].Key == key) - { - return FindPositionResult(hashPos, -1); - } - numProbes++; - hashPos = (hashPos + GetProbeOffset(numProbes)) & bucketSizeMinusOne; - } - if (insertPos != -1) - return FindPositionResult(-1, insertPos); - throw InvalidOperationException("Hash map is full. This indicates an error in Key::Equal or Key::GetHashCode."); - } - TValue & _Insert(KeyValuePair&& kvPair, int pos) - { - hashMap[pos] = _Move(kvPair); - SetEmpty(pos, false); - SetDeleted(pos, false); - return hashMap[pos].Value; - } - void Rehash() - { - if (bucketSizeMinusOne == -1 || _count >= int(MaxLoadFactor * bucketSizeMinusOne)) - { - int newSize = (bucketSizeMinusOne + 1) * 2; - if (newSize == 0) - { - newSize = 16; - } - Dictionary newDict; - newDict.bucketSizeMinusOne = newSize - 1; - newDict.hashMap = new KeyValuePair[newSize]; - newDict.marks.resizeAndClear(newSize * 2); - if (hashMap) - { - for (auto & kvPair : *this) - { - newDict.Add(_Move(kvPair)); - } - } - *this = _Move(newDict); - } - } - - bool AddIfNotExists(KeyValuePair&& kvPair) - { - Rehash(); - auto pos = FindPosition(kvPair.Key); - if (pos.ObjectPosition != -1) - return false; - else if (pos.InsertionPosition != -1) - { - _count++; - _Insert(_Move(kvPair), pos.InsertionPosition); - return true; - } - else - throw InvalidOperationException("Inconsistent find result returned. This is a bug in Dictionary implementation."); - } - void Add(KeyValuePair&& kvPair) - { - if (!AddIfNotExists(_Move(kvPair))) - throw KeyExistsException("The key already exists in Dictionary."); - } - TValue& Set(KeyValuePair&& kvPair) - { - Rehash(); - auto pos = FindPosition(kvPair.Key); - if (pos.ObjectPosition != -1) - return _Insert(_Move(kvPair), pos.ObjectPosition); - else if (pos.InsertionPosition != -1) - { - _count++; - return _Insert(_Move(kvPair), pos.InsertionPosition); - } - else - throw InvalidOperationException("Inconsistent find result returned. This is a bug in Dictionary implementation."); - } - public: - class Iterator - { - private: - const Dictionary * dict; - int pos; - public: - KeyValuePair & operator *() const - { - return dict->hashMap[pos]; - } - KeyValuePair * operator ->() const - { - return dict->hashMap + pos; - } - Iterator & operator ++() - { - if (pos > dict->bucketSizeMinusOne) - return *this; - pos++; - while (pos <= dict->bucketSizeMinusOne && (dict->IsDeleted(pos) || dict->IsEmpty(pos))) - { - pos++; - } - return *this; - } - Iterator operator ++(int) - { - Iterator rs = *this; - operator++(); - return rs; - } - bool operator != (const Iterator & _that) const - { - return pos != _that.pos || dict != _that.dict; - } - bool operator == (const Iterator & _that) const - { - return pos == _that.pos && dict == _that.dict; - } - Iterator(const Dictionary * _dict, int _pos) - { - this->dict = _dict; - this->pos = _pos; - } - Iterator() - { - this->dict = 0; - this->pos = 0; - } - }; - - Iterator begin() const - { - int pos = 0; - while (pos < bucketSizeMinusOne + 1) - { - if (IsEmpty(pos) || IsDeleted(pos)) - pos++; - else - break; - } - return Iterator(this, pos); - } - Iterator end() const - { - return Iterator(this, bucketSizeMinusOne + 1); - } - public: - void Add(const TKey & key, const TValue & value) - { - Add(KeyValuePair(key, value)); - } - void Add(TKey && key, TValue && value) - { - Add(KeyValuePair(_Move(key), _Move(value))); - } - bool AddIfNotExists(const TKey & key, const TValue & value) - { - return AddIfNotExists(KeyValuePair(key, value)); - } - bool AddIfNotExists(TKey && key, TValue && value) - { - return AddIfNotExists(KeyValuePair(_Move(key), _Move(value))); - } - void Remove(const TKey & key) - { - if (_count == 0) - return; - auto pos = FindPosition(key); - if (pos.ObjectPosition != -1) - { - SetDeleted(pos.ObjectPosition, true); - _count--; - } - } - void Clear() - { - _count = 0; - - marks.clear(); - } - - TValue* TryGetValueOrAdd(const TKey& key, const TValue& value) - { - Rehash(); - auto pos = FindPosition(key); - if (pos.ObjectPosition != -1) - { - return &hashMap[pos.ObjectPosition].Value; - } - else if (pos.InsertionPosition != -1) - { - // Make pair - KeyValuePair kvPair(_Move(key), _Move(value)); - _count++; - _Insert(_Move(kvPair), pos.InsertionPosition); - return nullptr; - } - else - throw InvalidOperationException("Inconsistent find result returned. This is a bug in Dictionary implementation."); - } - - bool ContainsKey(const TKey& key) const - { - if (bucketSizeMinusOne == -1) - return false; - auto pos = FindPosition(key); - return pos.ObjectPosition != -1; - } - bool TryGetValue(const TKey& key, TValue& value) const - { - if (bucketSizeMinusOne == -1) - return false; - auto pos = FindPosition(key); - if (pos.ObjectPosition != -1) - { - value = hashMap[pos.ObjectPosition].Value; - return true; - } - return false; - } - TValue* TryGetValue(const TKey& key) const - { - if (bucketSizeMinusOne == -1) - return nullptr; - auto pos = FindPosition(key); - if (pos.ObjectPosition != -1) - { - return &hashMap[pos.ObjectPosition].Value; - } - return nullptr; - } - - class ItemProxy - { - private: - const Dictionary * dict; - TKey key; - public: - ItemProxy(const TKey& _key, const Dictionary* _dict) - { - this->dict = _dict; - this->key = _key; - } - ItemProxy(TKey&& _key, const Dictionary* _dict) - { - this->dict = _dict; - this->key = _Move(_key); - } - TValue & GetValue() const - { - auto pos = dict->FindPosition(key); - if (pos.ObjectPosition != -1) - { - return dict->hashMap[pos.ObjectPosition].Value; - } - else - throw KeyNotFoundException("The key does not exists in dictionary."); - } - inline TValue & operator()() const - { - return GetValue(); - } - operator TValue&() const - { - return GetValue(); - } - TValue & operator = (const TValue & val) const - { - return ((Dictionary*)dict)->Set(KeyValuePair(_Move(key), val)); - } - TValue & operator = (TValue && val) const - { - return ((Dictionary*)dict)->Set(KeyValuePair(_Move(key), _Move(val))); - } - }; - ItemProxy operator [](const TKey & key) const - { - return ItemProxy(key, this); - } - ItemProxy operator [](TKey && key) const - { - return ItemProxy(_Move(key), this); - } - int Count() const - { - return _count; - } - private: - template - void Init(const KeyValuePair & kvPair, Args... args) - { - Add(kvPair); - Init(args...); - } - public: - Dictionary() - { - bucketSizeMinusOne = -1; - _count = 0; - hashMap = nullptr; - } - template - Dictionary(Arg arg, Args... args) - { - Init(arg, args...); - } - Dictionary(const Dictionary& other) - : bucketSizeMinusOne(-1), _count(0), hashMap(nullptr) - { - *this = other; - } - Dictionary(Dictionary&& other) - : bucketSizeMinusOne(-1), _count(0), hashMap(nullptr) - { - *this = (_Move(other)); - } - Dictionary& operator = (const Dictionary& other) - { - if (this == &other) - return *this; - Free(); - bucketSizeMinusOne = other.bucketSizeMinusOne; - _count = other._count; - hashMap = new KeyValuePair[other.bucketSizeMinusOne + 1]; - marks = other.marks; - for (int i = 0; i <= bucketSizeMinusOne; i++) - hashMap[i] = other.hashMap[i]; - return *this; - } - Dictionary & operator = (Dictionary&& other) - { - if (this == &other) - return *this; - Free(); - bucketSizeMinusOne = other.bucketSizeMinusOne; - _count = other._count; - hashMap = other.hashMap; - marks = _Move(other.marks); - other.hashMap = 0; - other._count = 0; - other.bucketSizeMinusOne = -1; - return *this; - } - ~Dictionary() - { - Free(); - } - }; - - class _DummyClass - {}; - - template - class HashSetBase - { - protected: - DictionaryType dict; - private: - template - void Init(const T & v, Args... args) - { - Add(v); - Init(args...); - } - public: - HashSetBase() - {} - template - HashSetBase(Arg arg, Args... args) - { - Init(arg, args...); - } - HashSetBase(const HashSetBase & set) - { - operator=(set); - } - HashSetBase(HashSetBase && set) - { - operator=(_Move(set)); - } - HashSetBase & operator = (const HashSetBase & set) - { - dict = set.dict; - return *this; - } - HashSetBase & operator = (HashSetBase && set) - { - dict = _Move(set.dict); - return *this; - } - public: - class Iterator - { - private: - typename DictionaryType::Iterator iter; - public: - Iterator() = default; - T & operator *() const - { - return (*iter).Key; - } - T * operator ->() const - { - return &(*iter).Key; - } - Iterator & operator ++() - { - ++iter; - return *this; - } - Iterator operator ++(int) - { - Iterator rs = *this; - operator++(); - return rs; - } - bool operator != (const Iterator & _that) const - { - return iter != _that.iter; - } - bool operator == (const Iterator & _that) const - { - return iter == _that.iter; - } - Iterator(const typename DictionaryType::Iterator & _iter) - { - this->iter = _iter; - } - }; - Iterator begin() const - { - return Iterator(dict.begin()); - } - Iterator end() const - { - return Iterator(dict.end()); - } - public: - int Count() const - { - return dict.Count(); - } - void Clear() - { - dict.Clear(); - } - bool Add(const T& obj) - { - return dict.AddIfNotExists(obj, _DummyClass()); - } - bool Add(T && obj) - { - return dict.AddIfNotExists(_Move(obj), _DummyClass()); - } - void Remove(const T & obj) - { - dict.Remove(obj); - } - bool Contains(const T & obj) const - { - return dict.ContainsKey(obj); - } - }; - template - class HashSet : public HashSetBase> - {}; -} - -#endif diff --git a/source/core/exception.h b/source/core/exception.h deleted file mode 100644 index fc7aa48e2..000000000 --- a/source/core/exception.h +++ /dev/null @@ -1,137 +0,0 @@ -#ifndef CORE_LIB_EXCEPTION_H -#define CORE_LIB_EXCEPTION_H - -#include "common.h" -#include "slang-string.h" - -namespace Slang -{ - class Exception - { - public: - String Message; - Exception() - {} - Exception(const String & message) - : Message(message) - { - } - - virtual ~Exception() - {} - }; - - class IndexOutofRangeException : public Exception - { - public: - IndexOutofRangeException() - {} - IndexOutofRangeException(const String & message) - : Exception(message) - { - } - - }; - - class InvalidOperationException : public Exception - { - public: - InvalidOperationException() - {} - InvalidOperationException(const String & message) - : Exception(message) - { - } - - }; - - class ArgumentException : public Exception - { - public: - ArgumentException() - {} - ArgumentException(const String & message) - : Exception(message) - { - } - - }; - - class KeyNotFoundException : public Exception - { - public: - KeyNotFoundException() - {} - KeyNotFoundException(const String & message) - : Exception(message) - { - } - }; - class KeyExistsException : public Exception - { - public: - KeyExistsException() - {} - KeyExistsException(const String & message) - : Exception(message) - { - } - }; - - class NotSupportedException : public Exception - { - public: - NotSupportedException() - {} - NotSupportedException(const String & message) - : Exception(message) - { - } - }; - - class NotImplementedException : public Exception - { - public: - NotImplementedException() - {} - NotImplementedException(const String & message) - : Exception(message) - { - } - }; - - class InvalidProgramException : public Exception - { - public: - InvalidProgramException() - {} - InvalidProgramException(const String & message) - : Exception(message) - { - } - }; - - class InternalError : public Exception - { - public: - InternalError() - {} - InternalError(const String & message) - : Exception(message) - { - } - }; - - class AbortCompilationException : public Exception - { - public: - AbortCompilationException() - {} - AbortCompilationException(const String & message) - : Exception(message) - { - } - }; -} - -#endif \ No newline at end of file diff --git a/source/core/hash.h b/source/core/hash.h deleted file mode 100644 index 83e99179b..000000000 --- a/source/core/hash.h +++ /dev/null @@ -1,153 +0,0 @@ -#ifndef CORELIB_HASH_H -#define CORELIB_HASH_H - -#include "slang-math.h" -#include -#include - -namespace Slang -{ - typedef int HashCode; - - inline int GetHashCode(double key) - { - return FloatAsInt((float)key); - } - inline int GetHashCode(float key) - { - return FloatAsInt(key); - } - inline int GetHashCode(const char * buffer) - { - if (!buffer) - return 0; - int hash = 0; - int c; - auto str = buffer; - c = *str++; - while (c) - { - hash = c + (hash << 6) + (hash << 16) - hash; - c = *str++; - } - return hash; - } - inline int GetHashCode(char * buffer) - { - return GetHashCode(const_cast(buffer)); - } - inline int GetHashCode(const char * buffer, size_t numChars) - { - int hash = 0; - for (size_t i = 0; i < numChars; ++i) - { - hash = int(buffer[i]) + (hash << 6) + (hash << 16) - hash; - } - return hash; - } - - inline uint64_t GetHashCode64(const char * buffer, size_t numChars) - { - // Use uints because hash requires wrap around behavior and int is undefined on over/underflows - uint64_t hash = 0; - for (size_t i = 0; i < numChars; ++i) - { - hash = uint64_t(int64_t(buffer[i])) + (hash << 6) + (hash << 16) - hash; - } - return hash; - } - - template - class Hash - { - public: - }; - template<> - class Hash<1> - { - public: - template - static int GetHashCode(TKey & key) - { - return (int)key; - } - }; - template<> - class Hash<0> - { - public: - template - static int GetHashCode(TKey & key) - { - return int(key.GetHashCode()); - } - }; - template - class PointerHash - {}; - template<> - class PointerHash<1> - { - public: - template - static int GetHashCode(TKey const& key) - { - return (int)((PtrInt)key) / 16; // sizeof(typename std::remove_pointer::type); - } - }; - template<> - class PointerHash<0> - { - public: - template - static int GetHashCode(TKey & key) - { - return Hash::value || std::is_enum::value>::GetHashCode(key); - } - }; - - template - int GetHashCode(const TKey & key) - { - return PointerHash::value>::GetHashCode(key); - } - - template - int GetHashCode(TKey & key) - { - return PointerHash::value>::GetHashCode(key); - } - - inline int combineHash(int left, int right) - { - return (left * 16777619) ^ right; - } - - struct Hasher - { - public: - Hasher() {} - - template - void hashValue(T const& value) - { - m_hashCode = combineHash(m_hashCode, GetHashCode(value)); - } - - template - void hashObject(T const& object) - { - m_hashCode = combineHash(m_hashCode, object->GetHashCode()); - } - - HashCode getResult() const - { - return m_hashCode; - } - - private: - HashCode m_hashCode = 0; - }; -} - -#endif diff --git a/source/core/list.h b/source/core/list.h deleted file mode 100644 index 7ba313305..000000000 --- a/source/core/list.h +++ /dev/null @@ -1,631 +0,0 @@ -#ifndef FUNDAMENTAL_LIB_LIST_H -#define FUNDAMENTAL_LIB_LIST_H - -#include "../../slang.h" - -#include "allocator.h" -#include "slang-math.h" -#include "array-view.h" - -#include -#include -#include - - -namespace Slang -{ - - template - class Initializer - { - - }; - - template - class Initializer - { - public: - static void initialize(T* buffer, int size) - { - for (int i = 0; i - class Initializer - { - public: - static void initialize(T* buffer, int size) - { - // It's pod so no initialization required - //for (int i = 0; i < size; i++) - // new (buffer + i) T; - } - }; - - template - class AllocateMethod - { - public: - static inline T* allocateArray(Index count) - { - TAllocator allocator; - T * rs = (T*)allocator.allocate(count * sizeof(T)); - Initializer::value>::initialize(rs, count); - return rs; - } - static inline void deallocateArray(T* ptr, Index count) - { - TAllocator allocator; - if (!std::is_trivially_destructible::value) - { - for (Index i = 0; i < count; i++) - ptr[i].~T(); - } - allocator.deallocate(ptr); - } - }; - - template - class AllocateMethod - { - public: - static inline T* allocateArray(Index count) - { - return new T[count]; - } - static inline void deallocateArray(T* ptr, Index /*bufferSize*/) - { - delete [] ptr; - } - }; - - - template - class List - { - private: - static const Index kInitialCount = 16; - - public: - List() - : m_buffer(nullptr), m_count(0), m_capacity(0) - { - } - template - List(const T& val, Args... args) - { - _init(val, args...); - } - List(const List& list) - : m_buffer(nullptr), m_count(0), m_capacity(0) - { - this->operator=(list); - } - List(List&& list) - : m_buffer(nullptr), m_count(0), m_capacity(0) - { - this->operator=(static_cast&&>(list)); - } - static List makeRepeated(const T& val, Index count) - { - List rs; - rs.setCount(count); - for (Index i = 0; i < count; i++) - rs[i] = val; - return rs; - } - ~List() - { - _deallocateBuffer(); - } - List& operator=(const List& list) - { - clearAndDeallocate(); - addRange(list); - return *this; - } - - List& operator=(List&& list) - { - // Could just do a swap here, and memory would be freed on rhs dtor - - _deallocateBuffer(); - m_count = list.m_count; - m_capacity = list.m_capacity; - m_buffer = list.m_buffer; - - list.m_buffer = nullptr; - list.m_count = 0; - list.m_capacity = 0; - return *this; - } - - // TODO(JS): These should be made const safe but some other code depends on this behavior for now. - T* begin() const { return m_buffer; } - T* end() const { return m_buffer + m_count; } - - const T& getFirst() const - { - SLANG_ASSERT(m_count > 0); - return m_buffer[0]; - } - - const T& getLast() const - { - SLANG_ASSERT(m_count > 0); - return m_buffer[m_count-1]; - } - - T& getFirst() - { - SLANG_ASSERT(m_count > 0); - return m_buffer[0]; - } - - T& getLast() - { - SLANG_ASSERT(m_count > 0); - return m_buffer[m_count - 1]; - } - - void removeLast() - { - SLANG_ASSERT(m_count > 0); - m_count--; - } - - inline void swapWith(List& other) - { - T* buffer = m_buffer; - m_buffer = other.m_buffer; - other.m_buffer = buffer; - - auto bufferSize = m_capacity; - m_capacity = other.m_capacity; - other.m_capacity = bufferSize; - - auto count = m_count; - m_count = other.m_count; - other.m_count = count; - } - - T* detachBuffer() - { - T* rs = m_buffer; - m_buffer = nullptr; - m_count = 0; - m_capacity = 0; - return rs; - } - - inline ArrayView getArrayView() const - { - return ArrayView(m_buffer, m_count); - } - - inline ArrayView getArrayView(Index start, Index count) const - { - SLANG_ASSERT(start >= 0 && count >= 0 && start + count <= m_count); - return ArrayView(m_buffer + start, count); - } - - void add(T&& obj) - { - if (m_capacity < m_count + 1) - { - Index newBufferSize = kInitialCount; - if (m_capacity) - newBufferSize = (m_capacity << 1); - - reserve(newBufferSize); - } - m_buffer[m_count++] = static_cast(obj); - } - - void add(const T& obj) - { - if (m_capacity < m_count + 1) - { - Index newBufferSize = kInitialCount; - if (m_capacity) - newBufferSize = (m_capacity << 1); - - reserve(newBufferSize); - } - m_buffer[m_count++] = obj; - - } - - Index getCount() const { return m_count; } - Index getCapacity() const { return m_capacity; } - - const T* getBuffer() const { return m_buffer; } - T* getBuffer() { return m_buffer; } - - void insert(Index idx, const T& val) { insertRange(idx, &val, 1); } - - void insertRange(Index idx, const T* vals, Index n) - { - if (m_capacity < m_count + n) - { - Index newBufferCount = kInitialCount; - while (newBufferCount < m_count + n) - newBufferCount = newBufferCount << 1; - - T* newBuffer = _allocate(newBufferCount); - if (m_capacity) - { - /*if (std::has_trivial_copy_assign::value && std::has_trivial_destructor::value) - { - memcpy(newBuffer, buffer, sizeof(T) * id); - memcpy(newBuffer + id + n, buffer + id, sizeof(T) * (_count - id)); - } - else*/ - { - for (Index i = 0; i < idx; i++) - newBuffer[i] = m_buffer[i]; - for (Index i = idx; i < m_count; i++) - newBuffer[i + n] = T(static_cast(m_buffer[i])); - } - _deallocateBuffer(); - } - m_buffer = newBuffer; - m_capacity = newBufferCount; - } - else - { - /*if (std::has_trivial_copy_assign::value && std::has_trivial_destructor::value) - memmove(buffer + id + n, buffer + id, sizeof(T) * (_count - id)); - else*/ - { - for (Index i = m_count; i > idx; i--) - m_buffer[i + n - 1] = static_cast(m_buffer[i - 1]); - } - } - /*if (std::has_trivial_copy_assign::value && std::has_trivial_destructor::value) - memcpy(buffer + id, vals, sizeof(T) * n); - else*/ - for (Index i = 0; i < n; i++) - m_buffer[idx + i] = vals[i]; - - m_count += n; - } - - //slower than original edition - //void Add(const T & val) - //{ - // InsertRange(_count, &val, 1); - //} - - void insertRange(Index id, const List& list) { insertRange(id, list.m_buffer, list.m_count); } - - void addRange(ArrayView list) { insertRange(m_count, list.getBuffer(), list.Count()); } - - void addRange(const T* vals, Index n) { insertRange(m_count, vals, n); } - - void addRange(const List& list) { insertRange(m_count, list.m_buffer, list.m_count); } - - void removeRange(Index idx, Index count) - { - SLANG_ASSERT(idx >= 0 && idx <= m_count); - - const Index actualDeleteCount = ((idx + count) >= m_count)? (m_count - idx) : count; - for (Index i = idx + actualDeleteCount; i < m_count; i++) - m_buffer[i - actualDeleteCount] = static_cast(m_buffer[i]); - m_count -= actualDeleteCount; - } - - void removeAt(Index id) { removeRange(id, 1); } - - void remove(const T& val) - { - Index idx = indexOf(val); - if (idx != -1) - removeAt(idx); - } - - void reverse() - { - for (Index i = 0; i < (m_count >> 1); i++) - { - swapElements(m_buffer, i, m_count - i - 1); - } - } - - void fastRemove(const T& val) - { - Index idx = indexOf(val); - fastRemoveAt(idx); - } - - void fastRemoveAt(Index idx) - { - if (idx != -1 && m_count - 1 != idx) - { - m_buffer[idx] = _Move(m_buffer[m_count - 1]); - } - m_count--; - } - - void clear() { m_count = 0; } - - void clearAndDeallocate() - { - _deallocateBuffer(); - m_count = m_capacity = 0; - } - - void reserve(Index size) - { - if(size > m_capacity) - { - T* newBuffer = _allocate(size); - if (m_capacity) - { - /*if (std::has_trivial_copy_assign::value && std::has_trivial_destructor::value) - memcpy(newBuffer, buffer, _count * sizeof(T)); - else*/ - { - for (Index i = 0; i < m_count; i++) - newBuffer[i] = static_cast(m_buffer[i]); - - // Default-initialize the remaining elements - for(Index i = m_count; i < size; i++) - { - new(newBuffer + i) T(); - } - } - _deallocateBuffer(); - } - m_buffer = newBuffer; - m_capacity = size; - } - } - - void growToCount(Index count) - { - Index newBufferCount = Index(1) << Math::Log2Ceil(count); - if (m_capacity < newBufferCount) - { - reserve(newBufferCount); - } - m_count = count; - } - - void setCount(Index count) - { - reserve(count); - m_count = count; - } - - void unsafeShrinkToCount(Index count) { m_count = count; } - - void compress() - { - if (m_capacity > m_count && m_count > 0) - { - T* newBuffer = _allocate(m_count); - for (Index i = 0; i < m_count; i++) - newBuffer[i] = static_cast(m_buffer[i]); - - _deallocateBuffer(); - m_buffer = newBuffer; - m_capacity = m_count; - } - } - - SLANG_FORCE_INLINE T& operator [](Index idx) const - { - SLANG_ASSERT(idx >= 0 && idx <= m_count); - return m_buffer[idx]; - } - - template - Index findFirstIndex(const Func& predicate) const - { - for (Index i = 0; i < m_count; i++) - { - if (predicate(m_buffer[i])) - return i; - } - return (Index)-1; - } - - template - Index findLastIndex(const Func& predicate) const - { - for (Index i = m_count; i > 0; i--) - { - if (predicate(m_buffer[i-1])) - return i-1; - } - return (Index)-1; - } - - template - Index indexOf(const T2& val) const - { - for (Index i = 0; i < m_count; i++) - { - if (m_buffer[i] == val) - return i; - } - return (Index)-1; - } - - template - Index lastIndexOf(const T2& val) const - { - for (Index i = m_count; i > 0; i--) - { - if(m_buffer[i-1] == val) - return i-1; - } - return (Index)-1; - } - - void sort() - { - sort([](const T& t1, const T& t2){return t1 < t2;}); - } - - bool contains(const T& val) const { return indexOf(val) != Index(-1); } - - template - void sort(Comparer compare) - { - //insertionSort(buffer, 0, _count - 1); - //quickSort(buffer, 0, _count - 1, compare); - std::sort(m_buffer, m_buffer + m_count, compare); - } - - template - void forEach(IterateFunc f) const - { - for (Index i = 0; i< m_count; i++) - f(m_buffer[i]); - } - - template - void quickSort(T* vals, Index startIndex, Index endIndex, Comparer comparer) - { - static const Index kMinQSortSize = 32; - - if(startIndex < endIndex) - { - if (endIndex - startIndex < kMinQSortSize) - insertionSort(vals, startIndex, endIndex, comparer); - else - { - Index pivotIndex = (startIndex + endIndex) >> 1; - Index pivotNewIndex = partition(vals, startIndex, endIndex, pivotIndex, comparer); - quickSort(vals, startIndex, pivotNewIndex - 1, comparer); - quickSort(vals, pivotNewIndex + 1, endIndex, comparer); - } - } - - } - template - Index partition(T* vals, Index left, Index right, Index pivotIndex, Comparer comparer) - { - T pivotValue = vals[pivotIndex]; - swapElements(vals, right, pivotIndex); - Index storeIndex = left; - for (Index i = left; i < right; i++) - { - if (comparer(vals[i], pivotValue)) - { - swapElements(vals, i, storeIndex); - storeIndex++; - } - } - swapElements(vals, storeIndex, right); - return storeIndex; - } - template - void insertionSort(T* vals, Index startIndex, Index endIndex, Comparer comparer) - { - for (Index i = startIndex + 1; i <= endIndex; i++) - { - T insertValue = static_cast(vals[i]); - Index insertIndex = i - 1; - while (insertIndex >= startIndex && comparer(insertValue, vals[insertIndex])) - { - vals[insertIndex + 1] = static_cast(vals[insertIndex]); - insertIndex--; - } - vals[insertIndex + 1] = static_cast(insertValue); - } - } - - inline void swapElements(T* vals, Index index1, Index index2) - { - if (index1 != index2) - { - T tmp = static_cast(vals[index1]); - vals[index1] = static_cast(vals[index2]); - vals[index2] = static_cast(tmp); - } - } - - template - Index binarySearch(const T2& obj, Comparer comparer) - { - Index imin = 0, imax = m_count - 1; - while (imax >= imin) - { - Index imid = (imin + imax) >> 1; - int compareResult = comparer(m_buffer[imid], obj); - if (compareResult == 0) - return imid; - else if (compareResult < 0) - imin = imid + 1; - else - imax = imid - 1; - } - return -1; - } - - template - int binarySearch(const T2& obj) - { - return binarySearch(obj, - [](T & curObj, const T2 & thatObj)->int - { - if (curObj < thatObj) - return -1; - else if (curObj == thatObj) - return 0; - else - return 1; - }); - } - private: - T* m_buffer; ///< A new T[N] allocated buffer. NOTE! All elements up to capacity are in some valid form for T. - Index m_capacity; ///< The total capacity of elements - Index m_count; ///< The amount of elements - - void _deallocateBuffer() - { - if (m_buffer) - { - AllocateMethod::deallocateArray(m_buffer, m_capacity); - m_buffer = nullptr; - } - } - static inline T* _allocate(Index count) - { - return AllocateMethod::allocateArray(count); - } - - template - void _init(const T& val, Args... args) - { - add(val); - _init(args...); - } - }; - - template - T calcMin(const List& list) - { - T minVal = list.getFirst(); - for (Index i = 1; i < list.getCount(); i++) - if (list[i] < minVal) - minVal = list[i]; - return minVal; - } - - template - T calcMax(const List& list) - { - T maxVal = list.getFirst(); - for (Index i = 1; i< list.getCount(); i++) - if (list[i] > maxVal) - maxVal = list[i]; - return maxVal; - } -} - -#endif diff --git a/source/core/platform.cpp b/source/core/platform.cpp deleted file mode 100644 index 0deec8ed6..000000000 --- a/source/core/platform.cpp +++ /dev/null @@ -1,172 +0,0 @@ -// platform.cpp -#include "platform.h" - -#include "common.h" - -#ifdef _WIN32 - #define WIN32_LEAN_AND_MEAN - #define NOMINMAX - #include - #undef WIN32_LEAN_AND_MEAN - #undef NOMINMAX -#else - #include "slang-string.h" - #include -#endif - -namespace Slang -{ - // SharedLibrary - -/* static */SlangResult SharedLibrary::load(const char* filename, SharedLibrary::Handle& handleOut) -{ - StringBuilder builder; - appendPlatformFileName(UnownedStringSlice(filename), builder); - return loadWithPlatformFilename(builder.begin(), handleOut); -} - -#ifdef _WIN32 - -// Make sure SlangResult match for common standard window HRESULT -SLANG_COMPILE_TIME_ASSERT(E_FAIL == SLANG_FAIL); -SLANG_COMPILE_TIME_ASSERT(E_NOINTERFACE == SLANG_E_NO_INTERFACE); -SLANG_COMPILE_TIME_ASSERT(E_HANDLE == SLANG_E_INVALID_HANDLE); -SLANG_COMPILE_TIME_ASSERT(E_NOTIMPL == SLANG_E_NOT_IMPLEMENTED); -SLANG_COMPILE_TIME_ASSERT(E_INVALIDARG == SLANG_E_INVALID_ARG); -SLANG_COMPILE_TIME_ASSERT(E_OUTOFMEMORY == SLANG_E_OUT_OF_MEMORY); - -/* static */SlangResult PlatformUtil::appendResult(SlangResult res, StringBuilder& builderOut) -{ - if (SLANG_FAILED(res) && res != SLANG_FAIL) - { - LPWSTR buffer = nullptr; - FormatMessage(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_ALLOCATE_BUFFER, - nullptr, - res, - MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), // Default language - (LPWSTR)&buffer, - 0, - nullptr); - - if (buffer) - { - builderOut << " "; - // Convert to string - builderOut.Append(String::fromWString(buffer)); - LocalFree(buffer); - return SLANG_OK; - } - } - return SLANG_FAIL; -} - -/* static */SlangResult SharedLibrary::loadWithPlatformFilename(char const* platformFileName, SharedLibrary::Handle& handleOut) -{ - handleOut = nullptr; - // https://docs.microsoft.com/en-us/windows/desktop/api/libloaderapi/nf-libloaderapi-loadlibrarya - const HMODULE h = LoadLibraryA(platformFileName); - if (!h) - { - const DWORD lastError = GetLastError(); - switch (lastError) - { - case ERROR_MOD_NOT_FOUND: - case ERROR_PATH_NOT_FOUND: - case ERROR_FILE_NOT_FOUND: - { - return SLANG_E_NOT_FOUND; - } - case ERROR_INVALID_ACCESS: - case ERROR_ACCESS_DENIED: - case ERROR_INVALID_DATA: - { - return SLANG_E_CANNOT_OPEN; - } - default: break; - } - // Turn to Result, if not one of the well known errors - return HRESULT_FROM_WIN32(lastError); - } - handleOut = (Handle)h; - return SLANG_OK; -} - -/* static */void SharedLibrary::unload(Handle handle) -{ - SLANG_ASSERT(handle); - ::FreeLibrary((HMODULE)handle); -} - -/* static */SharedLibrary::FuncPtr SharedLibrary::findFuncByName(Handle handle, char const* name) -{ - SLANG_ASSERT(handle); - return (FuncPtr)GetProcAddress((HMODULE)handle, name); -} - -/* static */void SharedLibrary::appendPlatformFileName(const UnownedStringSlice& name, StringBuilder& dst) -{ - // Windows doesn't need the extension or any prefix to work - dst.Append(name); -} - -#else // _WIN32 - -/* static */SlangResult PlatformUtil::appendResult(SlangResult res, StringBuilder& builderOut) -{ - return SLANG_E_NOT_IMPLEMENTED; -} - -/* static */SlangResult SharedLibrary::loadWithPlatformFilename(char const* platformFileName, Handle& handleOut) -{ - handleOut = nullptr; - - void* h = dlopen(platformFileName, RTLD_NOW | RTLD_LOCAL); - if(!h) - { -#if 0 - // We can't output the error message here, because it will cause output when testing what code gen is available - if(auto msg = dlerror()) - { - fprintf(stderr, "error: %s\n", msg); - } -#endif - return SLANG_FAIL; - } - handleOut = (Handle)h; - return SLANG_OK; -} - -/* static */void SharedLibrary::unload(Handle handle) -{ - SLANG_ASSERT(handle); - dlclose(handle); -} - -/* static */SharedLibrary::FuncPtr SharedLibrary::findFuncByName(Handle handle, char const* name) -{ - SLANG_ASSERT(handle); - return (FuncPtr)dlsym((void*)handle, name); -} - -/* static */void SharedLibrary::appendPlatformFileName(const UnownedStringSlice& name, StringBuilder& dst) -{ -#if __CYGWIN__ - dst.Append(name); - dst.Append(".dll"); -#elif SLANG_APPLE_FAMILY - dst.Append("lib"); - dst.Append(name); - dst.Append(".dylib"); -#elif SLANG_LINUX_FAMILY - dst.Append("lib"); - dst.Append(name); - dst.Append(".so"); -#else - // Just guess we can do with the name on it's own - dst.Append(name); -#endif -} - -#endif // _WIN32 - -} diff --git a/source/core/platform.h b/source/core/platform.h deleted file mode 100644 index 544ae8c16..000000000 --- a/source/core/platform.h +++ /dev/null @@ -1,67 +0,0 @@ -// platform.h -#ifndef SLANG_CORE_PLATFORM_H_INCLUDED -#define SLANG_CORE_PLATFORM_H_INCLUDED - -#include "../../slang.h" -#include "../core/slang-string.h" - -namespace Slang -{ - // Interface for working with shared libraries - // in a platform-independent fashion. - struct SharedLibrary - { - typedef struct SharedLibraryImpl* Handle; - - typedef void(*FuncPtr)(void); - - /// Load via an unadorned filename - /// - /// @param the unadorned filename - /// @return Returns a non null handle for the shared library on success. nullptr indicated failure - static SlangResult load(const char* filename, Handle& handleOut); - - /// Attempt to load a shared library for - /// the current platform. Returns null handle on failure - /// The platform specific filename can be generated from a call to appendPlatformFileName - /// - /// @param platformFileName the platform specific file name. - /// @return Returns a non null handle for the shared library on success. nullptr indicated failure - static SlangResult loadWithPlatformFilename(char const* platformFileName, Handle& handleOut); - - /// Unload the library that was returned from load as handle - /// @param The valid handle returned from load - static void unload(Handle handle); - - /// Given a shared library handle and a name, return the associated function - /// Return nullptr if function is not found - /// @param The shared library handle as returned by loadPlatformLibrary - static FuncPtr findFuncByName(Handle handle, char const* name); - - /// Append to the end of dst, the name, with any platform specific additions - /// The input name should be unadorned with any 'lib' prefix or extension - static void appendPlatformFileName(const UnownedStringSlice& name, StringBuilder& dst); - - private: - /// Not constructible! - SharedLibrary(); - }; - - struct PlatformUtil - { - /// Appends a text interpretation of a result (as defined by supporting OS) - /// @param res Result to produce a string for - /// @param builderOut Append the string produced to builderOut - /// @return SLANG_OK if string is found and appended. Fail otherwise. SLANG_E_NOT_IMPLEMENTED if there is no impl for this platform. - static SlangResult appendResult(SlangResult res, StringBuilder& builderOut); - }; - -#ifndef _MSC_VER - #define _fileno fileno - #define _isatty isatty - #define _setmode setmode - #define _O_BINARY O_BINARY -#endif -} - -#endif diff --git a/source/core/secure-crt.h b/source/core/secure-crt.h deleted file mode 100644 index 52a0d4870..000000000 --- a/source/core/secure-crt.h +++ /dev/null @@ -1,88 +0,0 @@ -#ifndef _MSC_VER -#ifndef CORE_LIB_SECURE_CRT_H -#define CORE_LIB_SECURE_CRT_H -#include -#include -#include -#include -#include - -#include - -inline void memcpy_s(void *dest, size_t numberOfElements, const void * src, size_t count) -{ - memcpy(dest, src, count); -} - -#define _TRUNCATE ((size_t)-1) -#define _stricmp strcasecmp - -inline void fopen_s(FILE**f, const char * fileName, const char * mode) -{ - *f = fopen(fileName, mode); -} - -inline size_t fread_s(void * buffer, size_t bufferSize, size_t elementSize, size_t count, FILE * stream) -{ - return fread(buffer, elementSize, count, stream); -} - -inline size_t wcsnlen_s(const wchar_t * str, size_t /*numberofElements*/) -{ - return wcslen(str); -} - -inline size_t strnlen_s(const char * str, size_t numberOfElements) -{ -#if defined( __CYGWIN__ ) - const char* cur = str; - if (str) - { - const char*const end = str + numberOfElements; - while (*cur && cur < end) cur++; - } - return size_t(cur - str); -#else - return strnlen(str, numberOfElements); -#endif -} - -inline int sprintf_s(char * buffer, size_t sizeOfBuffer, const char * format, ...) -{ - va_list argptr; - va_start(argptr, format); - int rs = vsnprintf(buffer, sizeOfBuffer, format, argptr); - va_end(argptr); - return rs; -} - -inline int swprintf_s(wchar_t * buffer, size_t sizeOfBuffer, const wchar_t * format, ...) -{ - va_list argptr; - va_start(argptr, format); - int rs = vswprintf(buffer, sizeOfBuffer, format, argptr); - va_end(argptr); - return rs; -} - -inline void wcscpy_s(wchar_t * strDestination, size_t /*numberOfElements*/, const wchar_t * strSource) -{ - wcscpy(strDestination, strSource); -} -inline void strcpy_s(char * strDestination, size_t /*numberOfElements*/, const char * strSource) -{ - strcpy(strDestination, strSource); -} - -inline void wcsncpy_s(wchar_t * strDestination, size_t /*numberOfElements*/, const wchar_t * strSource, size_t count) -{ - wcscpy(strDestination, strSource); - //wcsncpy(strDestination, strSource, count); -} -inline void strncpy_s(char * strDestination, size_t /*numberOfElements*/, const char * strSource, size_t count) -{ - strncpy(strDestination, strSource, count); - //wcsncpy(strDestination, strSource, count); -} -#endif -#endif diff --git a/source/core/slang-allocator.h b/source/core/slang-allocator.h new file mode 100644 index 000000000..481f8810f --- /dev/null +++ b/source/core/slang-allocator.h @@ -0,0 +1,64 @@ +#ifndef SLANG_CORE_ALLOCATOR_H +#define SLANG_CORE_ALLOCATOR_H + +#include +#ifdef _MSC_VER +# include +#endif + +namespace Slang +{ + inline void* alignedAllocate(size_t size, size_t alignment) + { +#ifdef _MSC_VER + return _aligned_malloc(size, alignment); +#elif defined(__CYGWIN__) + return aligned_alloc(alignment, size); +#else + void * rs = 0; + int succ = posix_memalign(&rs, alignment, size); + if (succ!=0) + rs = 0; + return rs; +#endif + } + + inline void alignedDeallocate(void* ptr) + { +#ifdef _MSC_VER + _aligned_free(ptr); +#else + free(ptr); +#endif + } + + class StandardAllocator + { + public: + // not really called + void* allocate(size_t size) + { + return ::malloc(size); + } + void deallocate(void * ptr) + { + return ::free(ptr); + } + }; + + template + class AlignedAllocator + { + public: + void* allocate(size_t size) + { + return alignedAllocate(size, ALIGNMENT); + } + void deallocate(void * ptr) + { + return alignedDeallocate(ptr); + } + }; +} + +#endif diff --git a/source/core/slang-array-view.h b/source/core/slang-array-view.h new file mode 100644 index 000000000..8b653f4c7 --- /dev/null +++ b/source/core/slang-array-view.h @@ -0,0 +1,112 @@ +#ifndef SLANG_CORE_ARRAY_VIEW_H +#define SLANG_CORE_ARRAY_VIEW_H + +#include "slang-common.h" + +namespace Slang +{ + template + class ArrayView + { + private: + T* m_buffer; + int m_count; + public: + const T* begin() const { return m_buffer; } + T* begin() { return m_buffer; } + + const T* end() const { return m_buffer + m_count; } + T* end() { return m_buffer + m_count; } + + public: + ArrayView(): + m_buffer(nullptr), + m_count(0) + { + } + ArrayView(T& singleObj): + m_buffer(&singleObj), + m_count(1) + { + } + ArrayView(T* buffer, int size): + m_buffer(buffer), + m_count(size) + { + } + + inline int getCount() const { return m_count; } + + inline const T& operator [](int idx) const + { + SLANG_ASSERT(idx >= 0 && idx <= m_count); + return m_buffer[idx]; + } + inline T& operator [](int idx) + { + SLANG_ASSERT(idx >= 0 && idx <= m_count); + return m_buffer[idx]; + } + + inline const T* getBuffer() const { return m_buffer; } + inline T* getBuffer() { return m_buffer; } + + template + int indexOf(const T2 & val) const + { + for (int i = 0; i < m_count; i++) + { + if (m_buffer[i] == val) + return i; + } + return -1; + } + + template + int lastIndexOf(const T2 & val) const + { + for (int i = m_count - 1; i >= 0; i--) + { + if (m_buffer[i] == val) + return i; + } + return -1; + } + + template + int findFirstIndex(const Func& predicate) const + { + for (int i = 0; i < m_count; i++) + { + if (predicate(m_buffer[i])) + return i; + } + return -1; + } + + template + int findLastIndex(const Func& predicate) const + { + for (int i = m_count - 1; i >= 0; i--) + { + if (predicate(m_buffer[i])) + return i; + } + return -1; + } + }; + + template + ArrayView makeArrayView(T& obj) + { + return ArrayView(obj); + } + + template + ArrayView makeArrayView(T* buffer, int count) + { + return ArrayView(buffer, count); + } +} + +#endif diff --git a/source/core/slang-array.h b/source/core/slang-array.h new file mode 100644 index 000000000..d4bb7386f --- /dev/null +++ b/source/core/slang-array.h @@ -0,0 +1,135 @@ +#ifndef SLANG_CORE_ARRAY_H +#define SLANG_CORE_ARRAY_H + +#include "slang-exception.h" +#include "slang-array-view.h" + +namespace Slang +{ + template + class Array + { + private: + T m_buffer[COUNT]; + int m_count = 0; + public: + T* begin() { return m_buffer; } + const T* begin() const { return m_buffer; } + + const T* end() const { return m_buffer + m_count; } + T* end() { return m_buffer + m_count; } + + public: + inline int getCapacity() const { return COUNT; } + inline int getCount() const { return m_count; } + inline const T& getFirst() const + { + SLANG_ASSERT(m_count > 0); + return m_buffer[0]; + } + inline T& getFirst() + { + SLANG_ASSERT(m_count > 0); + return m_buffer[0]; + } + inline const T& getLast() const + { + SLANG_ASSERT(m_count > 0); + return m_buffer[m_count - 1]; + } + inline T& getLast() + { + SLANG_ASSERT(m_count > 0); + return m_buffer[m_count - 1]; + } + inline void setCount(int newCount) + { + SLANG_ASSERT(newCount >= 0 && newCount <= COUNT); + m_count = newCount; + } + inline void add(const T & item) + { + SLANG_ASSERT(m_count < COUNT); + m_buffer[m_count++] = item; + } + inline void add(T && item) + { + SLANG_ASSERT(m_count < COUNT); + m_buffer[m_count++] = _Move(item); + } + + inline const T& operator [](int idx) const + { + SLANG_ASSERT(idx >= 0 && idx < m_count); + return m_buffer[idx]; + } + inline T& operator [](int idx) + { + SLANG_ASSERT(idx >= 0 && idx < m_count); + return m_buffer[idx]; + } + + inline const T* getBuffer() const { return m_buffer; } + inline T* getBuffer() { return m_buffer; } + + inline void clear() { m_count = 0; } + + template + int indexOf(const T2& val) const + { + for (int i = 0; i < m_count; i++) + { + if (m_buffer[i] == val) + return i; + } + return -1; + } + + template + int lastIndexOf(const T2& val) const + { + for (int i = m_count - 1; i >= 0; i--) + { + if (m_buffer[i] == val) + return i; + } + return -1; + } + + inline ArrayView getArrayView() const + { + return ArrayView((T*)m_buffer, m_count); + } + inline ArrayView getArrayView(int start, int count) const + { + return ArrayView((T*)m_buffer + start, count); + } + }; + + template + struct FirstType + { + typedef T Type; + }; + + + template + void insertArray(Array&) {} + + template + void insertArray(Array& arr, const T& val, TArgs... args) + { + arr.add(val); + insertArray(arr, args...); + } + + template + auto makeArray(TArgs ...args) -> Array::Type, sizeof...(args)> + { + Array::Type, sizeof...(args)> rs; + insertArray(rs, args...); + return rs; + } +} + +#endif diff --git a/source/core/slang-basic.h b/source/core/slang-basic.h new file mode 100644 index 000000000..7931749f4 --- /dev/null +++ b/source/core/slang-basic.h @@ -0,0 +1,13 @@ +#ifndef SLANG_CORE_BASIC_H +#define SLANG_CORE_BASIC_H + +#include "slang-common.h" +#include "slang-math.h" +#include "slang-string.h" +#include "slang-array.h" +#include "slang-list.h" +#include "slang-smart-pointer.h" +#include "slang-exception.h" +#include "slang-dictionary.h" + +#endif diff --git a/source/core/slang-byte-encode-util.cpp b/source/core/slang-byte-encode-util.cpp index 47ab824a4..32eb96a29 100644 --- a/source/core/slang-byte-encode-util.cpp +++ b/source/core/slang-byte-encode-util.cpp @@ -1,7 +1,5 @@ #include "slang-byte-encode-util.h" - - namespace Slang { // Descriptions of algorithms here... diff --git a/source/core/slang-byte-encode-util.h b/source/core/slang-byte-encode-util.h index 5936cae60..cb601d522 100644 --- a/source/core/slang-byte-encode-util.h +++ b/source/core/slang-byte-encode-util.h @@ -1,7 +1,7 @@ -#ifndef SLANG_BYTE_ENCODE_UTIL_H -#define SLANG_BYTE_ENCODE_UTIL_H +#ifndef SLANG_CORE_BYTE_ENCODE_UTIL_H +#define SLANG_CORE_BYTE_ENCODE_UTIL_H -#include "list.h" +#include "slang-list.h" namespace Slang { diff --git a/source/core/slang-common.h b/source/core/slang-common.h new file mode 100644 index 000000000..7d8568642 --- /dev/null +++ b/source/core/slang-common.h @@ -0,0 +1,95 @@ +#ifndef SLANG_CORE_COMMON_H +#define SLANG_CORE_COMMON_H + +#include "../../slang.h" + +#include + +#include + +#ifdef __GNUC__ +#define CORE_LIB_ALIGN_16(x) x __attribute__((aligned(16))) +#else +#define CORE_LIB_ALIGN_16(x) __declspec(align(16)) x +#endif + +#define VARIADIC_TEMPLATE + +namespace Slang +{ + typedef int32_t Int32; + typedef uint32_t UInt32; + + typedef int64_t Int64; + typedef uint64_t UInt64; + + // Define + typedef SlangUInt UInt; + typedef SlangInt Int; + +// typedef unsigned short Word; + + typedef intptr_t PtrInt; + + // Type used for indexing, in arrays/views etc + typedef Int Index; + + template + inline T&& _Move(T & obj) + { + return static_cast(obj); + } + + template + inline void Swap(T & v0, T & v1) + { + T tmp = _Move(v0); + v0 = _Move(v1); + v1 = _Move(tmp); + } + +#ifdef _MSC_VER +# define SLANG_RETURN_NEVER __declspec(noreturn) +//#elif SLANG_CLANG +//# define SLANG_RETURN_NEVER [[noreturn]] +#else +# define SLANG_RETURN_NEVER [[noreturn]] +//# define SLANG_RETURN_NEVER /* empty */ +#endif + +#ifdef _MSC_VER +#define UNREACHABLE_RETURN(x) +#define UNREACHABLE(x) +#else +#define UNREACHABLE_RETURN(x) return x; +#define UNREACHABLE(x) x; +#endif + + SLANG_RETURN_NEVER void signalUnexpectedError(char const* message); +} + +#define SLANG_UNEXPECTED(reason) \ + Slang::signalUnexpectedError("unexpected: " reason) + +#define SLANG_UNIMPLEMENTED_X(what) \ + Slang::signalUnexpectedError("unimplemented: " what) + +#define SLANG_UNREACHABLE(msg) \ + Slang::signalUnexpectedError("unreachable code executed: " msg) + +#ifdef _DEBUG +#define SLANG_EXPECT(VALUE, MSG) if(VALUE) {} else Slang::signalUnexpectedError("assertion failed: '" MSG "'") +#define SLANG_ASSERT(VALUE) SLANG_EXPECT(VALUE, #VALUE) +#else +#define SLANG_EXPECT(VALUE, MSG) do {} while(0) +#define SLANG_ASSERT(VALUE) do {} while(0) +#endif + +#define SLANG_RELEASE_ASSERT(VALUE) if(VALUE) {} else Slang::signalUnexpectedError("assertion failed") +#define SLANG_RELEASE_EXPECT(VALUE, WHAT) if(VALUE) {} else SLANG_UNEXPECTED(WHAT) + +template void slang_use_obj(T&) {} + +#define SLANG_UNREFERENCED_PARAMETER(P) slang_use_obj(P) +#define SLANG_UNREFERENCED_VARIABLE(P) slang_use_obj(P) +#endif diff --git a/source/core/slang-dictionary.h b/source/core/slang-dictionary.h new file mode 100644 index 000000000..69e8022dd --- /dev/null +++ b/source/core/slang-dictionary.h @@ -0,0 +1,620 @@ +#ifndef SLANG_CORE_DICTIONARY_H +#define SLANG_CORE_DICTIONARY_H + +#include "slang-list.h" +#include "slang-common.h" +#include "slang-uint-set.h" +#include "slang-exception.h" +#include "slang-math.h" +#include "slang-hash.h" + +namespace Slang +{ + template + class KeyValuePair + { + public: + TKey Key; + TValue Value; + KeyValuePair() + {} + KeyValuePair(const TKey & key, const TValue & value) + { + Key = key; + Value = value; + } + KeyValuePair(TKey && key, TValue && value) + { + Key = _Move(key); + Value = _Move(value); + } + KeyValuePair(TKey && key, const TValue & value) + { + Key = _Move(key); + Value = value; + } + KeyValuePair(const KeyValuePair & _that) + { + Key = _that.Key; + Value = _that.Value; + } + KeyValuePair(KeyValuePair && _that) + { + operator=(_Move(_that)); + } + KeyValuePair & operator=(KeyValuePair && that) + { + Key = _Move(that.Key); + Value = _Move(that.Value); + return *this; + } + KeyValuePair & operator=(const KeyValuePair & that) + { + Key = that.Key; + Value = that.Value; + return *this; + } + int GetHashCode() + { + return GetHashCode(Key); + } + }; + + template + inline KeyValuePair KVPair(const TKey & k, const TValue & v) + { + return KeyValuePair(k, v); + } + + const float MaxLoadFactor = 0.7f; + + template + class Dictionary + { + friend class Iterator; + friend class ItemProxy; + private: + inline int GetProbeOffset(int /*probeId*/) const + { + // quadratic probing + return 1; + } + private: + int bucketSizeMinusOne; + int _count; + UIntSet marks; + KeyValuePair* hashMap; + void Free() + { + if (hashMap) + delete[] hashMap; + hashMap = 0; + } + inline bool IsDeleted(int pos) const + { + return marks.contains((pos << 1) + 1); + } + inline bool IsEmpty(int pos) const + { + return !marks.contains((pos << 1)); + } + inline void SetDeleted(int pos, bool val) + { + if (val) + marks.add((pos << 1) + 1); + else + marks.remove((pos << 1) + 1); + } + inline void SetEmpty(int pos, bool val) + { + if (val) + marks.remove((pos << 1)); + else + marks.add((pos << 1)); + } + struct FindPositionResult + { + int ObjectPosition; + int InsertionPosition; + FindPositionResult() + { + ObjectPosition = -1; + InsertionPosition = -1; + } + FindPositionResult(int objPos, int insertPos) + { + ObjectPosition = objPos; + InsertionPosition = insertPos; + } + + }; + inline int GetHashPos(TKey& key) const + { + return ((unsigned int)(GetHashCode(key) * 2654435761)) % bucketSizeMinusOne; + } + FindPositionResult FindPosition(const TKey& key) const + { + int hashPos = GetHashPos(const_cast(key)); + int insertPos = -1; + int numProbes = 0; + while (numProbes <= bucketSizeMinusOne) + { + if (IsEmpty(hashPos)) + { + if (insertPos == -1) + return FindPositionResult(-1, hashPos); + else + return FindPositionResult(-1, insertPos); + } + else if (IsDeleted(hashPos)) + { + if (insertPos == -1) + insertPos = hashPos; + } + else if (hashMap[hashPos].Key == key) + { + return FindPositionResult(hashPos, -1); + } + numProbes++; + hashPos = (hashPos + GetProbeOffset(numProbes)) & bucketSizeMinusOne; + } + if (insertPos != -1) + return FindPositionResult(-1, insertPos); + throw InvalidOperationException("Hash map is full. This indicates an error in Key::Equal or Key::GetHashCode."); + } + TValue & _Insert(KeyValuePair&& kvPair, int pos) + { + hashMap[pos] = _Move(kvPair); + SetEmpty(pos, false); + SetDeleted(pos, false); + return hashMap[pos].Value; + } + void Rehash() + { + if (bucketSizeMinusOne == -1 || _count >= int(MaxLoadFactor * bucketSizeMinusOne)) + { + int newSize = (bucketSizeMinusOne + 1) * 2; + if (newSize == 0) + { + newSize = 16; + } + Dictionary newDict; + newDict.bucketSizeMinusOne = newSize - 1; + newDict.hashMap = new KeyValuePair[newSize]; + newDict.marks.resizeAndClear(newSize * 2); + if (hashMap) + { + for (auto & kvPair : *this) + { + newDict.Add(_Move(kvPair)); + } + } + *this = _Move(newDict); + } + } + + bool AddIfNotExists(KeyValuePair&& kvPair) + { + Rehash(); + auto pos = FindPosition(kvPair.Key); + if (pos.ObjectPosition != -1) + return false; + else if (pos.InsertionPosition != -1) + { + _count++; + _Insert(_Move(kvPair), pos.InsertionPosition); + return true; + } + else + throw InvalidOperationException("Inconsistent find result returned. This is a bug in Dictionary implementation."); + } + void Add(KeyValuePair&& kvPair) + { + if (!AddIfNotExists(_Move(kvPair))) + throw KeyExistsException("The key already exists in Dictionary."); + } + TValue& Set(KeyValuePair&& kvPair) + { + Rehash(); + auto pos = FindPosition(kvPair.Key); + if (pos.ObjectPosition != -1) + return _Insert(_Move(kvPair), pos.ObjectPosition); + else if (pos.InsertionPosition != -1) + { + _count++; + return _Insert(_Move(kvPair), pos.InsertionPosition); + } + else + throw InvalidOperationException("Inconsistent find result returned. This is a bug in Dictionary implementation."); + } + public: + class Iterator + { + private: + const Dictionary * dict; + int pos; + public: + KeyValuePair & operator *() const + { + return dict->hashMap[pos]; + } + KeyValuePair * operator ->() const + { + return dict->hashMap + pos; + } + Iterator & operator ++() + { + if (pos > dict->bucketSizeMinusOne) + return *this; + pos++; + while (pos <= dict->bucketSizeMinusOne && (dict->IsDeleted(pos) || dict->IsEmpty(pos))) + { + pos++; + } + return *this; + } + Iterator operator ++(int) + { + Iterator rs = *this; + operator++(); + return rs; + } + bool operator != (const Iterator & _that) const + { + return pos != _that.pos || dict != _that.dict; + } + bool operator == (const Iterator & _that) const + { + return pos == _that.pos && dict == _that.dict; + } + Iterator(const Dictionary * _dict, int _pos) + { + this->dict = _dict; + this->pos = _pos; + } + Iterator() + { + this->dict = 0; + this->pos = 0; + } + }; + + Iterator begin() const + { + int pos = 0; + while (pos < bucketSizeMinusOne + 1) + { + if (IsEmpty(pos) || IsDeleted(pos)) + pos++; + else + break; + } + return Iterator(this, pos); + } + Iterator end() const + { + return Iterator(this, bucketSizeMinusOne + 1); + } + public: + void Add(const TKey & key, const TValue & value) + { + Add(KeyValuePair(key, value)); + } + void Add(TKey && key, TValue && value) + { + Add(KeyValuePair(_Move(key), _Move(value))); + } + bool AddIfNotExists(const TKey & key, const TValue & value) + { + return AddIfNotExists(KeyValuePair(key, value)); + } + bool AddIfNotExists(TKey && key, TValue && value) + { + return AddIfNotExists(KeyValuePair(_Move(key), _Move(value))); + } + void Remove(const TKey & key) + { + if (_count == 0) + return; + auto pos = FindPosition(key); + if (pos.ObjectPosition != -1) + { + SetDeleted(pos.ObjectPosition, true); + _count--; + } + } + void Clear() + { + _count = 0; + + marks.clear(); + } + + TValue* TryGetValueOrAdd(const TKey& key, const TValue& value) + { + Rehash(); + auto pos = FindPosition(key); + if (pos.ObjectPosition != -1) + { + return &hashMap[pos.ObjectPosition].Value; + } + else if (pos.InsertionPosition != -1) + { + // Make pair + KeyValuePair kvPair(_Move(key), _Move(value)); + _count++; + _Insert(_Move(kvPair), pos.InsertionPosition); + return nullptr; + } + else + throw InvalidOperationException("Inconsistent find result returned. This is a bug in Dictionary implementation."); + } + + bool ContainsKey(const TKey& key) const + { + if (bucketSizeMinusOne == -1) + return false; + auto pos = FindPosition(key); + return pos.ObjectPosition != -1; + } + bool TryGetValue(const TKey& key, TValue& value) const + { + if (bucketSizeMinusOne == -1) + return false; + auto pos = FindPosition(key); + if (pos.ObjectPosition != -1) + { + value = hashMap[pos.ObjectPosition].Value; + return true; + } + return false; + } + TValue* TryGetValue(const TKey& key) const + { + if (bucketSizeMinusOne == -1) + return nullptr; + auto pos = FindPosition(key); + if (pos.ObjectPosition != -1) + { + return &hashMap[pos.ObjectPosition].Value; + } + return nullptr; + } + + class ItemProxy + { + private: + const Dictionary * dict; + TKey key; + public: + ItemProxy(const TKey& _key, const Dictionary* _dict) + { + this->dict = _dict; + this->key = _key; + } + ItemProxy(TKey&& _key, const Dictionary* _dict) + { + this->dict = _dict; + this->key = _Move(_key); + } + TValue & GetValue() const + { + auto pos = dict->FindPosition(key); + if (pos.ObjectPosition != -1) + { + return dict->hashMap[pos.ObjectPosition].Value; + } + else + throw KeyNotFoundException("The key does not exists in dictionary."); + } + inline TValue & operator()() const + { + return GetValue(); + } + operator TValue&() const + { + return GetValue(); + } + TValue & operator = (const TValue & val) const + { + return ((Dictionary*)dict)->Set(KeyValuePair(_Move(key), val)); + } + TValue & operator = (TValue && val) const + { + return ((Dictionary*)dict)->Set(KeyValuePair(_Move(key), _Move(val))); + } + }; + ItemProxy operator [](const TKey & key) const + { + return ItemProxy(key, this); + } + ItemProxy operator [](TKey && key) const + { + return ItemProxy(_Move(key), this); + } + int Count() const + { + return _count; + } + private: + template + void Init(const KeyValuePair & kvPair, Args... args) + { + Add(kvPair); + Init(args...); + } + public: + Dictionary() + { + bucketSizeMinusOne = -1; + _count = 0; + hashMap = nullptr; + } + template + Dictionary(Arg arg, Args... args) + { + Init(arg, args...); + } + Dictionary(const Dictionary& other) + : bucketSizeMinusOne(-1), _count(0), hashMap(nullptr) + { + *this = other; + } + Dictionary(Dictionary&& other) + : bucketSizeMinusOne(-1), _count(0), hashMap(nullptr) + { + *this = (_Move(other)); + } + Dictionary& operator = (const Dictionary& other) + { + if (this == &other) + return *this; + Free(); + bucketSizeMinusOne = other.bucketSizeMinusOne; + _count = other._count; + hashMap = new KeyValuePair[other.bucketSizeMinusOne + 1]; + marks = other.marks; + for (int i = 0; i <= bucketSizeMinusOne; i++) + hashMap[i] = other.hashMap[i]; + return *this; + } + Dictionary & operator = (Dictionary&& other) + { + if (this == &other) + return *this; + Free(); + bucketSizeMinusOne = other.bucketSizeMinusOne; + _count = other._count; + hashMap = other.hashMap; + marks = _Move(other.marks); + other.hashMap = 0; + other._count = 0; + other.bucketSizeMinusOne = -1; + return *this; + } + ~Dictionary() + { + Free(); + } + }; + + class _DummyClass + {}; + + template + class HashSetBase + { + protected: + DictionaryType dict; + private: + template + void Init(const T & v, Args... args) + { + Add(v); + Init(args...); + } + public: + HashSetBase() + {} + template + HashSetBase(Arg arg, Args... args) + { + Init(arg, args...); + } + HashSetBase(const HashSetBase & set) + { + operator=(set); + } + HashSetBase(HashSetBase && set) + { + operator=(_Move(set)); + } + HashSetBase & operator = (const HashSetBase & set) + { + dict = set.dict; + return *this; + } + HashSetBase & operator = (HashSetBase && set) + { + dict = _Move(set.dict); + return *this; + } + public: + class Iterator + { + private: + typename DictionaryType::Iterator iter; + public: + Iterator() = default; + T & operator *() const + { + return (*iter).Key; + } + T * operator ->() const + { + return &(*iter).Key; + } + Iterator & operator ++() + { + ++iter; + return *this; + } + Iterator operator ++(int) + { + Iterator rs = *this; + operator++(); + return rs; + } + bool operator != (const Iterator & _that) const + { + return iter != _that.iter; + } + bool operator == (const Iterator & _that) const + { + return iter == _that.iter; + } + Iterator(const typename DictionaryType::Iterator & _iter) + { + this->iter = _iter; + } + }; + Iterator begin() const + { + return Iterator(dict.begin()); + } + Iterator end() const + { + return Iterator(dict.end()); + } + public: + int Count() const + { + return dict.Count(); + } + void Clear() + { + dict.Clear(); + } + bool Add(const T& obj) + { + return dict.AddIfNotExists(obj, _DummyClass()); + } + bool Add(T && obj) + { + return dict.AddIfNotExists(_Move(obj), _DummyClass()); + } + void Remove(const T & obj) + { + dict.Remove(obj); + } + bool Contains(const T & obj) const + { + return dict.ContainsKey(obj); + } + }; + template + class HashSet : public HashSetBase> + {}; +} + +#endif diff --git a/source/core/slang-exception.h b/source/core/slang-exception.h new file mode 100644 index 000000000..91139e298 --- /dev/null +++ b/source/core/slang-exception.h @@ -0,0 +1,137 @@ +#ifndef SLANG_CORE_EXCEPTION_H +#define SLANG_CORE_EXCEPTION_H + +#include "slang-common.h" +#include "slang-string.h" + +namespace Slang +{ + class Exception + { + public: + String Message; + Exception() + {} + Exception(const String & message) + : Message(message) + { + } + + virtual ~Exception() + {} + }; + + class IndexOutofRangeException : public Exception + { + public: + IndexOutofRangeException() + {} + IndexOutofRangeException(const String & message) + : Exception(message) + { + } + + }; + + class InvalidOperationException : public Exception + { + public: + InvalidOperationException() + {} + InvalidOperationException(const String & message) + : Exception(message) + { + } + + }; + + class ArgumentException : public Exception + { + public: + ArgumentException() + {} + ArgumentException(const String & message) + : Exception(message) + { + } + + }; + + class KeyNotFoundException : public Exception + { + public: + KeyNotFoundException() + {} + KeyNotFoundException(const String & message) + : Exception(message) + { + } + }; + class KeyExistsException : public Exception + { + public: + KeyExistsException() + {} + KeyExistsException(const String & message) + : Exception(message) + { + } + }; + + class NotSupportedException : public Exception + { + public: + NotSupportedException() + {} + NotSupportedException(const String & message) + : Exception(message) + { + } + }; + + class NotImplementedException : public Exception + { + public: + NotImplementedException() + {} + NotImplementedException(const String & message) + : Exception(message) + { + } + }; + + class InvalidProgramException : public Exception + { + public: + InvalidProgramException() + {} + InvalidProgramException(const String & message) + : Exception(message) + { + } + }; + + class InternalError : public Exception + { + public: + InternalError() + {} + InternalError(const String & message) + : Exception(message) + { + } + }; + + class AbortCompilationException : public Exception + { + public: + AbortCompilationException() + {} + AbortCompilationException(const String & message) + : Exception(message) + { + } + }; +} + +#endif diff --git a/source/core/slang-free-list.h b/source/core/slang-free-list.h index 62c2b9d93..ee0158279 100644 --- a/source/core/slang-free-list.h +++ b/source/core/slang-free-list.h @@ -1,9 +1,9 @@ -#ifndef SLANG_FREE_LIST_H -#define SLANG_FREE_LIST_H +#ifndef SLANG_CORE_FREE_LIST_H +#define SLANG_CORE_FREE_LIST_H #include "../../slang.h" -#include "common.h" +#include "slang-common.h" #include #include diff --git a/source/core/slang-hash.h b/source/core/slang-hash.h new file mode 100644 index 000000000..08a40491c --- /dev/null +++ b/source/core/slang-hash.h @@ -0,0 +1,153 @@ +#ifndef SLANG_CORE_HASH_H +#define SLANG_CORE_HASH_H + +#include "slang-math.h" +#include +#include + +namespace Slang +{ + typedef int HashCode; + + inline int GetHashCode(double key) + { + return FloatAsInt((float)key); + } + inline int GetHashCode(float key) + { + return FloatAsInt(key); + } + inline int GetHashCode(const char * buffer) + { + if (!buffer) + return 0; + int hash = 0; + int c; + auto str = buffer; + c = *str++; + while (c) + { + hash = c + (hash << 6) + (hash << 16) - hash; + c = *str++; + } + return hash; + } + inline int GetHashCode(char * buffer) + { + return GetHashCode(const_cast(buffer)); + } + inline int GetHashCode(const char * buffer, size_t numChars) + { + int hash = 0; + for (size_t i = 0; i < numChars; ++i) + { + hash = int(buffer[i]) + (hash << 6) + (hash << 16) - hash; + } + return hash; + } + + inline uint64_t GetHashCode64(const char * buffer, size_t numChars) + { + // Use uints because hash requires wrap around behavior and int is undefined on over/underflows + uint64_t hash = 0; + for (size_t i = 0; i < numChars; ++i) + { + hash = uint64_t(int64_t(buffer[i])) + (hash << 6) + (hash << 16) - hash; + } + return hash; + } + + template + class Hash + { + public: + }; + template<> + class Hash<1> + { + public: + template + static int GetHashCode(TKey & key) + { + return (int)key; + } + }; + template<> + class Hash<0> + { + public: + template + static int GetHashCode(TKey & key) + { + return int(key.GetHashCode()); + } + }; + template + class PointerHash + {}; + template<> + class PointerHash<1> + { + public: + template + static int GetHashCode(TKey const& key) + { + return (int)((PtrInt)key) / 16; // sizeof(typename std::remove_pointer::type); + } + }; + template<> + class PointerHash<0> + { + public: + template + static int GetHashCode(TKey & key) + { + return Hash::value || std::is_enum::value>::GetHashCode(key); + } + }; + + template + int GetHashCode(const TKey & key) + { + return PointerHash::value>::GetHashCode(key); + } + + template + int GetHashCode(TKey & key) + { + return PointerHash::value>::GetHashCode(key); + } + + inline int combineHash(int left, int right) + { + return (left * 16777619) ^ right; + } + + struct Hasher + { + public: + Hasher() {} + + template + void hashValue(T const& value) + { + m_hashCode = combineHash(m_hashCode, GetHashCode(value)); + } + + template + void hashObject(T const& object) + { + m_hashCode = combineHash(m_hashCode, object->GetHashCode()); + } + + HashCode getResult() const + { + return m_hashCode; + } + + private: + HashCode m_hashCode = 0; + }; +} + +#endif diff --git a/source/core/slang-io.cpp b/source/core/slang-io.cpp index 82a2e5a35..ae2520a78 100644 --- a/source/core/slang-io.cpp +++ b/source/core/slang-io.cpp @@ -1,5 +1,5 @@ #include "slang-io.h" -#include "exception.h" +#include "slang-exception.h" #include "../../slang-com-helper.h" diff --git a/source/core/slang-io.h b/source/core/slang-io.h index 9df8d8d57..64ec11c5a 100644 --- a/source/core/slang-io.h +++ b/source/core/slang-io.h @@ -1,10 +1,10 @@ -#ifndef CORE_LIB_IO_H -#define CORE_LIB_IO_H +#ifndef SLANG_CORE_IO_H +#define SLANG_CORE_IO_H #include "slang-string.h" -#include "stream.h" -#include "text-io.h" -#include "secure-crt.h" +#include "slang-stream.h" +#include "slang-text-io.h" +#include "slang-secure-crt.h" namespace Slang { diff --git a/source/core/slang-list.h b/source/core/slang-list.h new file mode 100644 index 000000000..0affbfd66 --- /dev/null +++ b/source/core/slang-list.h @@ -0,0 +1,631 @@ +#ifndef SLANG_CORE_LIST_H +#define SLANG_CORE_LIST_H + +#include "../../slang.h" + +#include "slang-allocator.h" +#include "slang-math.h" +#include "slang-array-view.h" + +#include +#include +#include + + +namespace Slang +{ + + template + class Initializer + { + + }; + + template + class Initializer + { + public: + static void initialize(T* buffer, int size) + { + for (int i = 0; i + class Initializer + { + public: + static void initialize(T* buffer, int size) + { + // It's pod so no initialization required + //for (int i = 0; i < size; i++) + // new (buffer + i) T; + } + }; + + template + class AllocateMethod + { + public: + static inline T* allocateArray(Index count) + { + TAllocator allocator; + T * rs = (T*)allocator.allocate(count * sizeof(T)); + Initializer::value>::initialize(rs, count); + return rs; + } + static inline void deallocateArray(T* ptr, Index count) + { + TAllocator allocator; + if (!std::is_trivially_destructible::value) + { + for (Index i = 0; i < count; i++) + ptr[i].~T(); + } + allocator.deallocate(ptr); + } + }; + + template + class AllocateMethod + { + public: + static inline T* allocateArray(Index count) + { + return new T[count]; + } + static inline void deallocateArray(T* ptr, Index /*bufferSize*/) + { + delete [] ptr; + } + }; + + + template + class List + { + private: + static const Index kInitialCount = 16; + + public: + List() + : m_buffer(nullptr), m_count(0), m_capacity(0) + { + } + template + List(const T& val, Args... args) + { + _init(val, args...); + } + List(const List& list) + : m_buffer(nullptr), m_count(0), m_capacity(0) + { + this->operator=(list); + } + List(List&& list) + : m_buffer(nullptr), m_count(0), m_capacity(0) + { + this->operator=(static_cast&&>(list)); + } + static List makeRepeated(const T& val, Index count) + { + List rs; + rs.setCount(count); + for (Index i = 0; i < count; i++) + rs[i] = val; + return rs; + } + ~List() + { + _deallocateBuffer(); + } + List& operator=(const List& list) + { + clearAndDeallocate(); + addRange(list); + return *this; + } + + List& operator=(List&& list) + { + // Could just do a swap here, and memory would be freed on rhs dtor + + _deallocateBuffer(); + m_count = list.m_count; + m_capacity = list.m_capacity; + m_buffer = list.m_buffer; + + list.m_buffer = nullptr; + list.m_count = 0; + list.m_capacity = 0; + return *this; + } + + // TODO(JS): These should be made const safe but some other code depends on this behavior for now. + T* begin() const { return m_buffer; } + T* end() const { return m_buffer + m_count; } + + const T& getFirst() const + { + SLANG_ASSERT(m_count > 0); + return m_buffer[0]; + } + + const T& getLast() const + { + SLANG_ASSERT(m_count > 0); + return m_buffer[m_count-1]; + } + + T& getFirst() + { + SLANG_ASSERT(m_count > 0); + return m_buffer[0]; + } + + T& getLast() + { + SLANG_ASSERT(m_count > 0); + return m_buffer[m_count - 1]; + } + + void removeLast() + { + SLANG_ASSERT(m_count > 0); + m_count--; + } + + inline void swapWith(List& other) + { + T* buffer = m_buffer; + m_buffer = other.m_buffer; + other.m_buffer = buffer; + + auto bufferSize = m_capacity; + m_capacity = other.m_capacity; + other.m_capacity = bufferSize; + + auto count = m_count; + m_count = other.m_count; + other.m_count = count; + } + + T* detachBuffer() + { + T* rs = m_buffer; + m_buffer = nullptr; + m_count = 0; + m_capacity = 0; + return rs; + } + + inline ArrayView getArrayView() const + { + return ArrayView(m_buffer, m_count); + } + + inline ArrayView getArrayView(Index start, Index count) const + { + SLANG_ASSERT(start >= 0 && count >= 0 && start + count <= m_count); + return ArrayView(m_buffer + start, count); + } + + void add(T&& obj) + { + if (m_capacity < m_count + 1) + { + Index newBufferSize = kInitialCount; + if (m_capacity) + newBufferSize = (m_capacity << 1); + + reserve(newBufferSize); + } + m_buffer[m_count++] = static_cast(obj); + } + + void add(const T& obj) + { + if (m_capacity < m_count + 1) + { + Index newBufferSize = kInitialCount; + if (m_capacity) + newBufferSize = (m_capacity << 1); + + reserve(newBufferSize); + } + m_buffer[m_count++] = obj; + + } + + Index getCount() const { return m_count; } + Index getCapacity() const { return m_capacity; } + + const T* getBuffer() const { return m_buffer; } + T* getBuffer() { return m_buffer; } + + void insert(Index idx, const T& val) { insertRange(idx, &val, 1); } + + void insertRange(Index idx, const T* vals, Index n) + { + if (m_capacity < m_count + n) + { + Index newBufferCount = kInitialCount; + while (newBufferCount < m_count + n) + newBufferCount = newBufferCount << 1; + + T* newBuffer = _allocate(newBufferCount); + if (m_capacity) + { + /*if (std::has_trivial_copy_assign::value && std::has_trivial_destructor::value) + { + memcpy(newBuffer, buffer, sizeof(T) * id); + memcpy(newBuffer + id + n, buffer + id, sizeof(T) * (_count - id)); + } + else*/ + { + for (Index i = 0; i < idx; i++) + newBuffer[i] = m_buffer[i]; + for (Index i = idx; i < m_count; i++) + newBuffer[i + n] = T(static_cast(m_buffer[i])); + } + _deallocateBuffer(); + } + m_buffer = newBuffer; + m_capacity = newBufferCount; + } + else + { + /*if (std::has_trivial_copy_assign::value && std::has_trivial_destructor::value) + memmove(buffer + id + n, buffer + id, sizeof(T) * (_count - id)); + else*/ + { + for (Index i = m_count; i > idx; i--) + m_buffer[i + n - 1] = static_cast(m_buffer[i - 1]); + } + } + /*if (std::has_trivial_copy_assign::value && std::has_trivial_destructor::value) + memcpy(buffer + id, vals, sizeof(T) * n); + else*/ + for (Index i = 0; i < n; i++) + m_buffer[idx + i] = vals[i]; + + m_count += n; + } + + //slower than original edition + //void Add(const T & val) + //{ + // InsertRange(_count, &val, 1); + //} + + void insertRange(Index id, const List& list) { insertRange(id, list.m_buffer, list.m_count); } + + void addRange(ArrayView list) { insertRange(m_count, list.getBuffer(), list.Count()); } + + void addRange(const T* vals, Index n) { insertRange(m_count, vals, n); } + + void addRange(const List& list) { insertRange(m_count, list.m_buffer, list.m_count); } + + void removeRange(Index idx, Index count) + { + SLANG_ASSERT(idx >= 0 && idx <= m_count); + + const Index actualDeleteCount = ((idx + count) >= m_count)? (m_count - idx) : count; + for (Index i = idx + actualDeleteCount; i < m_count; i++) + m_buffer[i - actualDeleteCount] = static_cast(m_buffer[i]); + m_count -= actualDeleteCount; + } + + void removeAt(Index id) { removeRange(id, 1); } + + void remove(const T& val) + { + Index idx = indexOf(val); + if (idx != -1) + removeAt(idx); + } + + void reverse() + { + for (Index i = 0; i < (m_count >> 1); i++) + { + swapElements(m_buffer, i, m_count - i - 1); + } + } + + void fastRemove(const T& val) + { + Index idx = indexOf(val); + fastRemoveAt(idx); + } + + void fastRemoveAt(Index idx) + { + if (idx != -1 && m_count - 1 != idx) + { + m_buffer[idx] = _Move(m_buffer[m_count - 1]); + } + m_count--; + } + + void clear() { m_count = 0; } + + void clearAndDeallocate() + { + _deallocateBuffer(); + m_count = m_capacity = 0; + } + + void reserve(Index size) + { + if(size > m_capacity) + { + T* newBuffer = _allocate(size); + if (m_capacity) + { + /*if (std::has_trivial_copy_assign::value && std::has_trivial_destructor::value) + memcpy(newBuffer, buffer, _count * sizeof(T)); + else*/ + { + for (Index i = 0; i < m_count; i++) + newBuffer[i] = static_cast(m_buffer[i]); + + // Default-initialize the remaining elements + for(Index i = m_count; i < size; i++) + { + new(newBuffer + i) T(); + } + } + _deallocateBuffer(); + } + m_buffer = newBuffer; + m_capacity = size; + } + } + + void growToCount(Index count) + { + Index newBufferCount = Index(1) << Math::Log2Ceil(count); + if (m_capacity < newBufferCount) + { + reserve(newBufferCount); + } + m_count = count; + } + + void setCount(Index count) + { + reserve(count); + m_count = count; + } + + void unsafeShrinkToCount(Index count) { m_count = count; } + + void compress() + { + if (m_capacity > m_count && m_count > 0) + { + T* newBuffer = _allocate(m_count); + for (Index i = 0; i < m_count; i++) + newBuffer[i] = static_cast(m_buffer[i]); + + _deallocateBuffer(); + m_buffer = newBuffer; + m_capacity = m_count; + } + } + + SLANG_FORCE_INLINE T& operator [](Index idx) const + { + SLANG_ASSERT(idx >= 0 && idx <= m_count); + return m_buffer[idx]; + } + + template + Index findFirstIndex(const Func& predicate) const + { + for (Index i = 0; i < m_count; i++) + { + if (predicate(m_buffer[i])) + return i; + } + return (Index)-1; + } + + template + Index findLastIndex(const Func& predicate) const + { + for (Index i = m_count; i > 0; i--) + { + if (predicate(m_buffer[i-1])) + return i-1; + } + return (Index)-1; + } + + template + Index indexOf(const T2& val) const + { + for (Index i = 0; i < m_count; i++) + { + if (m_buffer[i] == val) + return i; + } + return (Index)-1; + } + + template + Index lastIndexOf(const T2& val) const + { + for (Index i = m_count; i > 0; i--) + { + if(m_buffer[i-1] == val) + return i-1; + } + return (Index)-1; + } + + void sort() + { + sort([](const T& t1, const T& t2){return t1 < t2;}); + } + + bool contains(const T& val) const { return indexOf(val) != Index(-1); } + + template + void sort(Comparer compare) + { + //insertionSort(buffer, 0, _count - 1); + //quickSort(buffer, 0, _count - 1, compare); + std::sort(m_buffer, m_buffer + m_count, compare); + } + + template + void forEach(IterateFunc f) const + { + for (Index i = 0; i< m_count; i++) + f(m_buffer[i]); + } + + template + void quickSort(T* vals, Index startIndex, Index endIndex, Comparer comparer) + { + static const Index kMinQSortSize = 32; + + if(startIndex < endIndex) + { + if (endIndex - startIndex < kMinQSortSize) + insertionSort(vals, startIndex, endIndex, comparer); + else + { + Index pivotIndex = (startIndex + endIndex) >> 1; + Index pivotNewIndex = partition(vals, startIndex, endIndex, pivotIndex, comparer); + quickSort(vals, startIndex, pivotNewIndex - 1, comparer); + quickSort(vals, pivotNewIndex + 1, endIndex, comparer); + } + } + + } + template + Index partition(T* vals, Index left, Index right, Index pivotIndex, Comparer comparer) + { + T pivotValue = vals[pivotIndex]; + swapElements(vals, right, pivotIndex); + Index storeIndex = left; + for (Index i = left; i < right; i++) + { + if (comparer(vals[i], pivotValue)) + { + swapElements(vals, i, storeIndex); + storeIndex++; + } + } + swapElements(vals, storeIndex, right); + return storeIndex; + } + template + void insertionSort(T* vals, Index startIndex, Index endIndex, Comparer comparer) + { + for (Index i = startIndex + 1; i <= endIndex; i++) + { + T insertValue = static_cast(vals[i]); + Index insertIndex = i - 1; + while (insertIndex >= startIndex && comparer(insertValue, vals[insertIndex])) + { + vals[insertIndex + 1] = static_cast(vals[insertIndex]); + insertIndex--; + } + vals[insertIndex + 1] = static_cast(insertValue); + } + } + + inline void swapElements(T* vals, Index index1, Index index2) + { + if (index1 != index2) + { + T tmp = static_cast(vals[index1]); + vals[index1] = static_cast(vals[index2]); + vals[index2] = static_cast(tmp); + } + } + + template + Index binarySearch(const T2& obj, Comparer comparer) + { + Index imin = 0, imax = m_count - 1; + while (imax >= imin) + { + Index imid = (imin + imax) >> 1; + int compareResult = comparer(m_buffer[imid], obj); + if (compareResult == 0) + return imid; + else if (compareResult < 0) + imin = imid + 1; + else + imax = imid - 1; + } + return -1; + } + + template + int binarySearch(const T2& obj) + { + return binarySearch(obj, + [](T & curObj, const T2 & thatObj)->int + { + if (curObj < thatObj) + return -1; + else if (curObj == thatObj) + return 0; + else + return 1; + }); + } + private: + T* m_buffer; ///< A new T[N] allocated buffer. NOTE! All elements up to capacity are in some valid form for T. + Index m_capacity; ///< The total capacity of elements + Index m_count; ///< The amount of elements + + void _deallocateBuffer() + { + if (m_buffer) + { + AllocateMethod::deallocateArray(m_buffer, m_capacity); + m_buffer = nullptr; + } + } + static inline T* _allocate(Index count) + { + return AllocateMethod::allocateArray(count); + } + + template + void _init(const T& val, Args... args) + { + add(val); + _init(args...); + } + }; + + template + T calcMin(const List& list) + { + T minVal = list.getFirst(); + for (Index i = 1; i < list.getCount(); i++) + if (list[i] < minVal) + minVal = list[i]; + return minVal; + } + + template + T calcMax(const List& list) + { + T maxVal = list.getFirst(); + for (Index i = 1; i< list.getCount(); i++) + if (list[i] > maxVal) + maxVal = list[i]; + return maxVal; + } +} + +#endif diff --git a/source/core/slang-math.h b/source/core/slang-math.h index a245e2d2c..0daad0d5a 100644 --- a/source/core/slang-math.h +++ b/source/core/slang-math.h @@ -1,5 +1,5 @@ -#ifndef CORE_LIB_MATH_H -#define CORE_LIB_MATH_H +#ifndef SLANG_CORE_MATH_H +#define SLANG_CORE_MATH_H #include diff --git a/source/core/slang-memory-arena.h b/source/core/slang-memory-arena.h index b9066e198..75744710f 100644 --- a/source/core/slang-memory-arena.h +++ b/source/core/slang-memory-arena.h @@ -1,5 +1,5 @@ -#ifndef SLANG_MEMORY_ARENA_H -#define SLANG_MEMORY_ARENA_H +#ifndef SLANG_CORE_MEMORY_ARENA_H +#define SLANG_CORE_MEMORY_ARENA_H #include "../../slang.h" diff --git a/source/core/slang-object-scope-manager.h b/source/core/slang-object-scope-manager.h index 660a7cace..9930e46ea 100644 --- a/source/core/slang-object-scope-manager.h +++ b/source/core/slang-object-scope-manager.h @@ -1,8 +1,8 @@ -#ifndef SLANG_OBJECT_SCOPE_MANAGER_H -#define SLANG_OBJECT_SCOPE_MANAGER_H +#ifndef SLANG_CORE_OBJECT_SCOPE_MANAGER_H +#define SLANG_CORE_OBJECT_SCOPE_MANAGER_H -#include "smart-pointer.h" -#include "list.h" +#include "slang-smart-pointer.h" +#include "slang-list.h" namespace Slang { diff --git a/source/core/slang-platform.cpp b/source/core/slang-platform.cpp new file mode 100644 index 000000000..1cb2bc56e --- /dev/null +++ b/source/core/slang-platform.cpp @@ -0,0 +1,172 @@ +// slang-platform.cpp +#include "slang-platform.h" + +#include "slang-common.h" + +#ifdef _WIN32 + #define WIN32_LEAN_AND_MEAN + #define NOMINMAX + #include + #undef WIN32_LEAN_AND_MEAN + #undef NOMINMAX +#else + #include "slang-string.h" + #include +#endif + +namespace Slang +{ + // SharedLibrary + +/* static */SlangResult SharedLibrary::load(const char* filename, SharedLibrary::Handle& handleOut) +{ + StringBuilder builder; + appendPlatformFileName(UnownedStringSlice(filename), builder); + return loadWithPlatformFilename(builder.begin(), handleOut); +} + +#ifdef _WIN32 + +// Make sure SlangResult match for common standard window HRESULT +SLANG_COMPILE_TIME_ASSERT(E_FAIL == SLANG_FAIL); +SLANG_COMPILE_TIME_ASSERT(E_NOINTERFACE == SLANG_E_NO_INTERFACE); +SLANG_COMPILE_TIME_ASSERT(E_HANDLE == SLANG_E_INVALID_HANDLE); +SLANG_COMPILE_TIME_ASSERT(E_NOTIMPL == SLANG_E_NOT_IMPLEMENTED); +SLANG_COMPILE_TIME_ASSERT(E_INVALIDARG == SLANG_E_INVALID_ARG); +SLANG_COMPILE_TIME_ASSERT(E_OUTOFMEMORY == SLANG_E_OUT_OF_MEMORY); + +/* static */SlangResult PlatformUtil::appendResult(SlangResult res, StringBuilder& builderOut) +{ + if (SLANG_FAILED(res) && res != SLANG_FAIL) + { + LPWSTR buffer = nullptr; + FormatMessage(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_ALLOCATE_BUFFER, + nullptr, + res, + MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), // Default language + (LPWSTR)&buffer, + 0, + nullptr); + + if (buffer) + { + builderOut << " "; + // Convert to string + builderOut.Append(String::fromWString(buffer)); + LocalFree(buffer); + return SLANG_OK; + } + } + return SLANG_FAIL; +} + +/* static */SlangResult SharedLibrary::loadWithPlatformFilename(char const* platformFileName, SharedLibrary::Handle& handleOut) +{ + handleOut = nullptr; + // https://docs.microsoft.com/en-us/windows/desktop/api/libloaderapi/nf-libloaderapi-loadlibrarya + const HMODULE h = LoadLibraryA(platformFileName); + if (!h) + { + const DWORD lastError = GetLastError(); + switch (lastError) + { + case ERROR_MOD_NOT_FOUND: + case ERROR_PATH_NOT_FOUND: + case ERROR_FILE_NOT_FOUND: + { + return SLANG_E_NOT_FOUND; + } + case ERROR_INVALID_ACCESS: + case ERROR_ACCESS_DENIED: + case ERROR_INVALID_DATA: + { + return SLANG_E_CANNOT_OPEN; + } + default: break; + } + // Turn to Result, if not one of the well known errors + return HRESULT_FROM_WIN32(lastError); + } + handleOut = (Handle)h; + return SLANG_OK; +} + +/* static */void SharedLibrary::unload(Handle handle) +{ + SLANG_ASSERT(handle); + ::FreeLibrary((HMODULE)handle); +} + +/* static */SharedLibrary::FuncPtr SharedLibrary::findFuncByName(Handle handle, char const* name) +{ + SLANG_ASSERT(handle); + return (FuncPtr)GetProcAddress((HMODULE)handle, name); +} + +/* static */void SharedLibrary::appendPlatformFileName(const UnownedStringSlice& name, StringBuilder& dst) +{ + // Windows doesn't need the extension or any prefix to work + dst.Append(name); +} + +#else // _WIN32 + +/* static */SlangResult PlatformUtil::appendResult(SlangResult res, StringBuilder& builderOut) +{ + return SLANG_E_NOT_IMPLEMENTED; +} + +/* static */SlangResult SharedLibrary::loadWithPlatformFilename(char const* platformFileName, Handle& handleOut) +{ + handleOut = nullptr; + + void* h = dlopen(platformFileName, RTLD_NOW | RTLD_LOCAL); + if(!h) + { +#if 0 + // We can't output the error message here, because it will cause output when testing what code gen is available + if(auto msg = dlerror()) + { + fprintf(stderr, "error: %s\n", msg); + } +#endif + return SLANG_FAIL; + } + handleOut = (Handle)h; + return SLANG_OK; +} + +/* static */void SharedLibrary::unload(Handle handle) +{ + SLANG_ASSERT(handle); + dlclose(handle); +} + +/* static */SharedLibrary::FuncPtr SharedLibrary::findFuncByName(Handle handle, char const* name) +{ + SLANG_ASSERT(handle); + return (FuncPtr)dlsym((void*)handle, name); +} + +/* static */void SharedLibrary::appendPlatformFileName(const UnownedStringSlice& name, StringBuilder& dst) +{ +#if __CYGWIN__ + dst.Append(name); + dst.Append(".dll"); +#elif SLANG_APPLE_FAMILY + dst.Append("lib"); + dst.Append(name); + dst.Append(".dylib"); +#elif SLANG_LINUX_FAMILY + dst.Append("lib"); + dst.Append(name); + dst.Append(".so"); +#else + // Just guess we can do with the name on it's own + dst.Append(name); +#endif +} + +#endif // _WIN32 + +} diff --git a/source/core/slang-platform.h b/source/core/slang-platform.h new file mode 100644 index 000000000..e33c5599d --- /dev/null +++ b/source/core/slang-platform.h @@ -0,0 +1,67 @@ +// slang-platform.h +#ifndef SLANG_CORE_PLATFORM_H +#define SLANG_CORE_PLATFORM_H + +#include "../../slang.h" +#include "../core/slang-string.h" + +namespace Slang +{ + // Interface for working with shared libraries + // in a platform-independent fashion. + struct SharedLibrary + { + typedef struct SharedLibraryImpl* Handle; + + typedef void(*FuncPtr)(void); + + /// Load via an unadorned filename + /// + /// @param the unadorned filename + /// @return Returns a non null handle for the shared library on success. nullptr indicated failure + static SlangResult load(const char* filename, Handle& handleOut); + + /// Attempt to load a shared library for + /// the current platform. Returns null handle on failure + /// The platform specific filename can be generated from a call to appendPlatformFileName + /// + /// @param platformFileName the platform specific file name. + /// @return Returns a non null handle for the shared library on success. nullptr indicated failure + static SlangResult loadWithPlatformFilename(char const* platformFileName, Handle& handleOut); + + /// Unload the library that was returned from load as handle + /// @param The valid handle returned from load + static void unload(Handle handle); + + /// Given a shared library handle and a name, return the associated function + /// Return nullptr if function is not found + /// @param The shared library handle as returned by loadPlatformLibrary + static FuncPtr findFuncByName(Handle handle, char const* name); + + /// Append to the end of dst, the name, with any platform specific additions + /// The input name should be unadorned with any 'lib' prefix or extension + static void appendPlatformFileName(const UnownedStringSlice& name, StringBuilder& dst); + + private: + /// Not constructible! + SharedLibrary(); + }; + + struct PlatformUtil + { + /// Appends a text interpretation of a result (as defined by supporting OS) + /// @param res Result to produce a string for + /// @param builderOut Append the string produced to builderOut + /// @return SLANG_OK if string is found and appended. Fail otherwise. SLANG_E_NOT_IMPLEMENTED if there is no impl for this platform. + static SlangResult appendResult(SlangResult res, StringBuilder& builderOut); + }; + +#ifndef _MSC_VER + #define _fileno fileno + #define _isatty isatty + #define _setmode setmode + #define _O_BINARY O_BINARY +#endif +} + +#endif diff --git a/source/core/slang-random-generator.h b/source/core/slang-random-generator.h index cc25aadf3..8b4d1759b 100644 --- a/source/core/slang-random-generator.h +++ b/source/core/slang-random-generator.h @@ -1,12 +1,12 @@ -#ifndef SLANG_RANDOM_GENERATOR_H -#define SLANG_RANDOM_GENERATOR_H +#ifndef SLANG_CORE_RANDOM_GENERATOR_H +#define SLANG_CORE_RANDOM_GENERATOR_H #include "../../slang.h" #include #include -#include "smart-pointer.h" +#include "slang-smart-pointer.h" namespace Slang { diff --git a/source/core/slang-render-api-util.cpp b/source/core/slang-render-api-util.cpp index 05def0fe3..3df971219 100644 --- a/source/core/slang-render-api-util.cpp +++ b/source/core/slang-render-api-util.cpp @@ -3,10 +3,10 @@ #include "../../slang.h" -#include "../../source/core/list.h" -#include "../../source/core/slang-string-util.h" +#include "slang-list.h" +#include "slang-string-util.h" -#include "platform.h" +#include "slang-platform.h" namespace Slang { diff --git a/source/core/slang-render-api-util.h b/source/core/slang-render-api-util.h index 42e88a6ac..fbdd3930c 100644 --- a/source/core/slang-render-api-util.h +++ b/source/core/slang-render-api-util.h @@ -1,5 +1,5 @@ -#ifndef SLANG_RENDER_API_UTIL_H -#define SLANG_RENDER_API_UTIL_H +#ifndef SLANG_CORE_RENDER_API_UTIL_H +#define SLANG_CORE_RENDER_API_UTIL_H #include "../../source/core/slang-string.h" diff --git a/source/core/slang-secure-crt.h b/source/core/slang-secure-crt.h new file mode 100644 index 000000000..991fe939e --- /dev/null +++ b/source/core/slang-secure-crt.h @@ -0,0 +1,88 @@ +#ifndef _MSC_VER +#ifndef SLANG_CORE_SECURE_CRT_H +#define SLANG_CORE_SECURE_CRT_H +#include +#include +#include +#include +#include + +#include + +inline void memcpy_s(void *dest, size_t numberOfElements, const void * src, size_t count) +{ + memcpy(dest, src, count); +} + +#define _TRUNCATE ((size_t)-1) +#define _stricmp strcasecmp + +inline void fopen_s(FILE**f, const char * fileName, const char * mode) +{ + *f = fopen(fileName, mode); +} + +inline size_t fread_s(void * buffer, size_t bufferSize, size_t elementSize, size_t count, FILE * stream) +{ + return fread(buffer, elementSize, count, stream); +} + +inline size_t wcsnlen_s(const wchar_t * str, size_t /*numberofElements*/) +{ + return wcslen(str); +} + +inline size_t strnlen_s(const char * str, size_t numberOfElements) +{ +#if defined( __CYGWIN__ ) + const char* cur = str; + if (str) + { + const char*const end = str + numberOfElements; + while (*cur && cur < end) cur++; + } + return size_t(cur - str); +#else + return strnlen(str, numberOfElements); +#endif +} + +inline int sprintf_s(char * buffer, size_t sizeOfBuffer, const char * format, ...) +{ + va_list argptr; + va_start(argptr, format); + int rs = vsnprintf(buffer, sizeOfBuffer, format, argptr); + va_end(argptr); + return rs; +} + +inline int swprintf_s(wchar_t * buffer, size_t sizeOfBuffer, const wchar_t * format, ...) +{ + va_list argptr; + va_start(argptr, format); + int rs = vswprintf(buffer, sizeOfBuffer, format, argptr); + va_end(argptr); + return rs; +} + +inline void wcscpy_s(wchar_t * strDestination, size_t /*numberOfElements*/, const wchar_t * strSource) +{ + wcscpy(strDestination, strSource); +} +inline void strcpy_s(char * strDestination, size_t /*numberOfElements*/, const char * strSource) +{ + strcpy(strDestination, strSource); +} + +inline void wcsncpy_s(wchar_t * strDestination, size_t /*numberOfElements*/, const wchar_t * strSource, size_t count) +{ + wcscpy(strDestination, strSource); + //wcsncpy(strDestination, strSource, count); +} +inline void strncpy_s(char * strDestination, size_t /*numberOfElements*/, const char * strSource, size_t count) +{ + strncpy(strDestination, strSource, count); + //wcsncpy(strDestination, strSource, count); +} +#endif +#endif diff --git a/source/core/slang-shared-library.cpp b/source/core/slang-shared-library.cpp index 20d457840..009abf921 100644 --- a/source/core/slang-shared-library.cpp +++ b/source/core/slang-shared-library.cpp @@ -1,8 +1,9 @@ #include "slang-shared-library.h" #include "../../slang-com-ptr.h" -#include "../core/slang-io.h" -#include "../core/slang-string-util.h" + +#include "slang-io.h" +#include "slang-string-util.h" namespace Slang { diff --git a/source/core/slang-shared-library.h b/source/core/slang-shared-library.h index 62d15b6b4..5a4eb7229 100644 --- a/source/core/slang-shared-library.h +++ b/source/core/slang-shared-library.h @@ -1,13 +1,13 @@ -#ifndef SLANG_SHARED_LIBRARY_H_INCLUDED -#define SLANG_SHARED_LIBRARY_H_INCLUDED +#ifndef SLANG_CORE_SHARED_LIBRARY_H +#define SLANG_CORE_SHARED_LIBRARY_H #include "../../slang.h" #include "../../slang-com-helper.h" #include "../../slang-com-ptr.h" -#include "../core/platform.h" -#include "../core/common.h" -#include "../core/dictionary.h" +#include "../core/slang-platform.h" +#include "../core/slang-common.h" +#include "../core/slang-dictionary.h" namespace Slang { @@ -119,4 +119,4 @@ public: } -#endif // SLANG_SHARED_LIBRARY_H_INCLUDED \ No newline at end of file +#endif // SLANG_SHARED_LIBRARY_H_INCLUDED diff --git a/source/core/slang-smart-pointer.h b/source/core/slang-smart-pointer.h new file mode 100644 index 000000000..bae30de37 --- /dev/null +++ b/source/core/slang-smart-pointer.h @@ -0,0 +1,250 @@ +#ifndef SLANG_CORE_SMART_POINTER_H +#define SLANG_CORE_SMART_POINTER_H + +#include "slang-common.h" +#include "slang-hash.h" +#include "slang-type-traits.h" + +#include "../../slang.h" + +namespace Slang +{ + // Base class for all reference-counted objects + class RefObject + { + private: + UInt referenceCount; + + public: + RefObject() + : referenceCount(0) + {} + + RefObject(const RefObject &) + : referenceCount(0) + {} + + virtual ~RefObject() + {} + + UInt addReference() + { + return ++referenceCount; + } + + UInt decreaseReference() + { + return --referenceCount; + } + + UInt releaseReference() + { + SLANG_ASSERT(referenceCount != 0); + if(--referenceCount == 0) + { + delete this; + return 0; + } + return referenceCount; + } + + bool isUniquelyReferenced() + { + SLANG_ASSERT(referenceCount != 0); + return referenceCount == 1; + } + + UInt debugGetReferenceCount() + { + return referenceCount; + } + }; + + SLANG_FORCE_INLINE void addReference(RefObject* obj) + { + if(obj) obj->addReference(); + } + + SLANG_FORCE_INLINE void releaseReference(RefObject* obj) + { + if(obj) obj->releaseReference(); + } + + // For straight dynamic cast. + // Use instead of dynamic_cast as it allows for replacement without using Rtti in the future + template + SLANG_FORCE_INLINE T* dynamicCast(RefObject* obj) { return dynamic_cast(obj); } + template + SLANG_FORCE_INLINE const T* dynamicCast(const RefObject* obj) { return dynamic_cast(obj); } + + // Like a dynamicCast, but allows a type to implement a specific implementation that is suitable for it + template + SLANG_FORCE_INLINE T* as(RefObject* obj) { return dynamicCast(obj); } + template + SLANG_FORCE_INLINE const T* as(const RefObject* obj) { return dynamicCast(obj); } + + // "Smart" pointer to a reference-counted object + template + struct RefPtr + { + RefPtr() + : pointer(nullptr) + {} + + RefPtr(T* p) + : pointer(p) + { + addReference(p); + } + + RefPtr(RefPtr const& p) + : pointer(p.pointer) + { + addReference(p.pointer); + } + + RefPtr(RefPtr&& p) + : pointer(p.pointer) + { + p.pointer = nullptr; + } + + template + RefPtr(RefPtr const& p, + typename EnableIf::Value, void>::type * = 0) + : pointer((U*) p) + { + addReference((U*) p); + } + +#if 0 + void operator=(T* p) + { + T* old = pointer; + addReference(p); + pointer = p; + releaseReference(old); + } +#endif + + void operator=(RefPtr const& p) + { + T* old = pointer; + addReference(p.pointer); + pointer = p.pointer; + releaseReference(old); + } + + void operator=(RefPtr&& p) + { + T* old = pointer; + pointer = p.pointer; + p.pointer = old; + } + + template + typename EnableIf::value, void>::type + operator=(RefPtr const& p) + { + T* old = pointer; + addReference(p.pointer); + pointer = p.pointer; + releaseReference(old); + } + + int GetHashCode() + { + // Note: We need a `RefPtr` to hash the same as a `T*`, + // so that a `T*` can be used as a key in a dictionary with + // `RefPtr` keys, and vice versa. + // + return Slang::GetHashCode(pointer); + } + + bool operator==(const T * ptr) const + { + return pointer == ptr; + } + + bool operator!=(const T * ptr) const + { + return pointer != ptr; + } + + bool operator==(RefPtr const& ptr) const + { + return pointer == ptr.pointer; + } + + bool operator!=(RefPtr const& ptr) const + { + return pointer != ptr.pointer; + } + + template + RefPtr dynamicCast() const + { + return RefPtr(Slang::dynamicCast(pointer)); + } + + template + RefPtr as() const + { + return RefPtr(Slang::as(pointer)); + } + + template + bool is() const { return Slang::as(pointer) != nullptr; } + + ~RefPtr() + { + releaseReference((Slang::RefObject*) pointer); + } + + T& operator*() const + { + return *pointer; + } + + T* operator->() const + { + return pointer; + } + + T * Ptr() const + { + return pointer; + } + + operator T*() const + { + return pointer; + } + + void attach(T* p) + { + T* old = pointer; + pointer = p; + releaseReference(old); + } + + T* detach() + { + auto rs = pointer; + pointer = nullptr; + return rs; + } + + /// Get ready for writing (nulls contents) + SLANG_FORCE_INLINE T** writeRef() { *this = nullptr; return &pointer; } + + /// Get for read access + SLANG_FORCE_INLINE T*const* readRef() const { return &pointer; } + + private: + T* pointer; + + }; +} + +#endif diff --git a/source/core/slang-std-writers.h b/source/core/slang-std-writers.h index b35d3d037..8ecb89227 100644 --- a/source/core/slang-std-writers.h +++ b/source/core/slang-std-writers.h @@ -1,5 +1,5 @@ -#ifndef SLANG_STD_WRITERS_H -#define SLANG_STD_WRITERS_H +#ifndef SLANG_CORE_STD_WRITERS_H +#define SLANG_CORE_STD_WRITERS_H #include "slang-writer.h" #include "../../slang-com-ptr.h" diff --git a/source/core/slang-stream.cpp b/source/core/slang-stream.cpp new file mode 100644 index 000000000..ee194c451 --- /dev/null +++ b/source/core/slang-stream.cpp @@ -0,0 +1,294 @@ +#include "slang-stream.h" +#ifdef _WIN32 +#include +#endif +#include "slang-io.h" + +namespace Slang +{ + FileStream::FileStream(const Slang::String & fileName, FileMode fileMode) + { + Init(fileName, fileMode, fileMode==FileMode::Open?FileAccess::Read:FileAccess::Write, FileShare::None); + } + FileStream::FileStream(const Slang::String & fileName, FileMode fileMode, FileAccess access, FileShare share) + { + Init(fileName, fileMode, access, share); + } + void FileStream::Init(const Slang::String & fileName, FileMode fileMode, FileAccess access, FileShare share) + { + const wchar_t * mode = L"rt"; + const char* modeMBCS = "rt"; + switch (fileMode) + { + case Slang::FileMode::Create: + if (access == FileAccess::Read) + throw ArgumentException("Read-only access is incompatible with Create mode."); + else if (access == FileAccess::ReadWrite) + { + mode = L"w+b"; + modeMBCS = "w+b"; + this->fileAccess = FileAccess::ReadWrite; + } + else + { + mode = L"wb"; + modeMBCS = "wb"; + this->fileAccess = FileAccess::Write; + } + break; + case Slang::FileMode::Open: + if (access == FileAccess::Read) + { + mode = L"rb"; + modeMBCS = "rb"; + this->fileAccess = FileAccess::Read; + } + else if (access == FileAccess::ReadWrite) + { + mode = L"r+b"; + modeMBCS = "r+b"; + this->fileAccess = FileAccess::ReadWrite; + } + else + { + mode = L"wb"; + modeMBCS = "wb"; + this->fileAccess = FileAccess::Write; + } + break; + case Slang::FileMode::CreateNew: + if (File::exists(fileName)) + { + throw IOException("Failed opening '" + fileName + "', file already exists."); + } + if (access == FileAccess::Read) + throw ArgumentException("Read-only access is incompatible with Create mode."); + else if (access == FileAccess::ReadWrite) + { + mode = L"w+b"; + this->fileAccess = FileAccess::ReadWrite; + } + else + { + mode = L"wb"; + this->fileAccess = FileAccess::Write; + } + break; + case Slang::FileMode::Append: + if (access == FileAccess::Read) + throw ArgumentException("Read-only access is incompatible with Append mode."); + else if (access == FileAccess::ReadWrite) + { + mode = L"a+b"; + this->fileAccess = FileAccess::ReadWrite; + } + else + { + mode = L"ab"; + this->fileAccess = FileAccess::Write; + } + break; + default: + break; + } +#ifdef _WIN32 + int shFlag = _SH_DENYRW; + switch (share) + { + case Slang::FileShare::None: + shFlag = _SH_DENYRW; + break; + case Slang::FileShare::ReadOnly: + shFlag = _SH_DENYWR; + break; + case Slang::FileShare::WriteOnly: + shFlag = _SH_DENYRD; + break; + case Slang::FileShare::ReadWrite: + shFlag = _SH_DENYNO; + break; + default: + throw ArgumentException("Invalid file share mode."); + break; + } + if (share == Slang::FileShare::None) +#pragma warning(suppress:4996) + handle = _wfopen(fileName.toWString(), mode); + else + handle = _wfsopen(fileName.toWString(), mode, shFlag); +#else + handle = fopen(fileName.getBuffer(), modeMBCS); +#endif + if (!handle) + { + throw IOException("Cannot open file '" + fileName + "'"); + } + } + FileStream::~FileStream() + { + Close(); + } + Int64 FileStream::GetPosition() + { +#if defined(_WIN32) || defined(__CYGWIN__) + fpos_t pos; + fgetpos(handle, &pos); + return pos; +#elif defined(__APPLE__) + return ftell(handle); +#else + fpos64_t pos; + fgetpos64(handle, &pos); + return *(Int64*)(&pos); +#endif + } + void FileStream::Seek(SeekOrigin origin, Int64 offset) + { + int _origin; + switch (origin) + { + case Slang::SeekOrigin::Start: + _origin = SEEK_SET; + endReached = false; + break; + case Slang::SeekOrigin::End: + _origin = SEEK_END; + // JS TODO: This doesn't seem right, the offset can mean it's not at the end + endReached = true; + break; + case Slang::SeekOrigin::Current: + _origin = SEEK_CUR; + endReached = false; + break; + default: + throw NotSupportedException("Unsupported seek origin."); + break; + } +#ifdef _WIN32 + int rs = _fseeki64(handle, offset, _origin); +#else + int rs = fseek(handle, (int)offset, _origin); +#endif + if (rs != 0) + { + throw IOException("FileStream seek failed."); + } + } + Int64 FileStream::Read(void * buffer, Int64 length) + { + auto bytes = fread_s(buffer, (size_t)length, 1, (size_t)length, handle); + if (bytes == 0 && length > 0) + { + if (!feof(handle)) + throw IOException("FileStream read failed."); + else if (endReached) + throw EndOfStreamException("End of file is reached."); + endReached = true; + } + return (int)bytes; + } + Int64 FileStream::Write(const void * buffer, Int64 length) + { + auto bytes = (Int64)fwrite(buffer, 1, (size_t)length, handle); + if (bytes < length) + { + throw IOException("FileStream write failed."); + } + return bytes; + } + bool FileStream::CanRead() + { + return ((int)fileAccess & (int)FileAccess::Read) != 0; + } + bool FileStream::CanWrite() + { + return ((int)fileAccess & (int)FileAccess::Write) != 0; + } + void FileStream::Close() + { + if (handle) + { + fclose(handle); + handle = 0; + } + } + bool FileStream::IsEnd() + { + return endReached; + } + + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! MemoryStream !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + + void MemoryStream::Seek(SeekOrigin origin, Int64 offset) + { + Int64 pos = 0; + switch (origin) + { + case Slang::SeekOrigin::Start: + pos = offset; + break; + case Slang::SeekOrigin::End: + pos = Int64(m_contents.getCount()) + offset; + break; + case Slang::SeekOrigin::Current: + pos = Int64(m_position) + offset; + break; + default: + throw NotSupportedException("Unsupported seek origin."); + break; + } + + m_atEnd = false; + + // Clamp to the valid range + pos = (pos < 0) ? 0 : pos; + pos = (pos > Int64(m_contents.getCount())) ? Int64(m_contents.getCount()) : pos; + + m_position = UInt(pos); + } + + Int64 MemoryStream::Read(void * buffer, Int64 length) + { + if (!CanRead()) + { + throw IOException("Cannot read this stream."); + } + + const Int64 maxRead = Int64(m_contents.getCount() - m_position); + + if (maxRead == 0 && length > 0) + { + m_atEnd = true; + throw EndOfStreamException("End of file is reached."); + } + + length = length > maxRead ? maxRead : length; + + ::memcpy(buffer, m_contents.begin() + m_position, size_t(length)); + m_position += UInt(length); + return maxRead; + } + + Int64 MemoryStream::Write(const void * buffer, Int64 length) + { + if (!CanWrite()) + { + throw IOException("Cannot write this stream."); + } + + if (m_position == m_contents.getCount()) + { + m_contents.addRange((const uint8_t*)buffer, UInt(length)); + } + else + { + m_contents.insertRange(m_position, (const uint8_t*)buffer, UInt(length)); + } + + m_atEnd = false; + + m_position += UInt(length); + return length; + } + +} diff --git a/source/core/slang-stream.h b/source/core/slang-stream.h new file mode 100644 index 000000000..67e04fa6a --- /dev/null +++ b/source/core/slang-stream.h @@ -0,0 +1,113 @@ +#ifndef SLANG_CORE_STREAM_H +#define SLANG_CORE_STREAM_H + +#include "slang-basic.h" + +namespace Slang +{ + class IOException : public Exception + { + public: + IOException() + {} + IOException(const String & message) + : Slang::Exception(message) + { + } + }; + + class EndOfStreamException : public IOException + { + public: + EndOfStreamException() + {} + EndOfStreamException(const String & message) + : IOException(message) + { + } + }; + + enum class SeekOrigin + { + Start, End, Current + }; + + class Stream : public RefObject + { + public: + virtual ~Stream() {} + virtual Int64 GetPosition()=0; + virtual void Seek(SeekOrigin origin, Int64 offset)=0; + virtual Int64 Read(void * buffer, Int64 length) = 0; + virtual Int64 Write(const void * buffer, Int64 length) = 0; + virtual bool IsEnd() = 0; + virtual bool CanRead() = 0; + virtual bool CanWrite() = 0; + virtual void Close() = 0; + }; + + enum class FileMode + { + Create, Open, CreateNew, Append + }; + + enum class FileAccess + { + None = 0, Read = 1, Write = 2, ReadWrite = 3 + }; + + enum class FileShare + { + None, ReadOnly, WriteOnly, ReadWrite + }; + + class MemoryStream : public Stream + { + public: + virtual Int64 GetPosition() SLANG_OVERRIDE { return m_position; } + virtual void Seek(SeekOrigin origin, Int64 offset) SLANG_OVERRIDE; + virtual Int64 Read(void * buffer, Int64 length) SLANG_OVERRIDE; + virtual Int64 Write(const void * buffer, Int64 length) SLANG_OVERRIDE; + virtual bool IsEnd() SLANG_OVERRIDE { return m_atEnd; } + virtual bool CanRead() SLANG_OVERRIDE { return (int(m_access) & int(FileAccess::Read)) != 0; } + virtual bool CanWrite() SLANG_OVERRIDE { return (int(m_access) & int(FileAccess::Write)) != 0; } + virtual void Close() SLANG_OVERRIDE { m_access = FileAccess::None; } + + MemoryStream(FileAccess access) : + m_access(access), + m_position(0), + m_atEnd(false) + {} + + Index m_position; + + bool m_atEnd; ///< Happens when a read is done and nothing can be returned because already at end + + FileAccess m_access; + List m_contents; + }; + + class FileStream : public Stream + { + private: + FILE * handle; + FileAccess fileAccess; + bool endReached = false; + void Init(const Slang::String & fileName, FileMode fileMode, FileAccess access, FileShare share); + public: + FileStream(const Slang::String & fileName, FileMode fileMode = FileMode::Open); + FileStream(const Slang::String & fileName, FileMode fileMode, FileAccess access, FileShare share); + ~FileStream(); + public: + virtual Int64 GetPosition(); + virtual void Seek(SeekOrigin origin, Int64 offset); + virtual Int64 Read(void * buffer, Int64 length); + virtual Int64 Write(const void * buffer, Int64 length); + virtual bool CanRead(); + virtual bool CanWrite(); + virtual void Close(); + virtual bool IsEnd(); + }; +} + +#endif diff --git a/source/core/slang-string-slice-pool.h b/source/core/slang-string-slice-pool.h index cf8f63c81..4d5f91e37 100644 --- a/source/core/slang-string-slice-pool.h +++ b/source/core/slang-string-slice-pool.h @@ -1,11 +1,11 @@ -#ifndef SLANG_STRING_SLICE_POOL_H -#define SLANG_STRING_SLICE_POOL_H +#ifndef SLANG_CORE_STRING_SLICE_POOL_H +#define SLANG_CORE_STRING_SLICE_POOL_H #include "slang-string.h" -#include "list.h" +#include "slang-list.h" #include "slang-memory-arena.h" -#include "dictionary.h" +#include "slang-dictionary.h" namespace Slang { diff --git a/source/core/slang-string-util.h b/source/core/slang-string-util.h index 40fda31c4..fcae23bb3 100644 --- a/source/core/slang-string-util.h +++ b/source/core/slang-string-util.h @@ -1,8 +1,8 @@ -#ifndef SLANG_STRING_UTIL_H -#define SLANG_STRING_UTIL_H +#ifndef SLANG_CORE_STRING_UTIL_H +#define SLANG_CORE_STRING_UTIL_H #include "slang-string.h" -#include "list.h" +#include "slang-list.h" #include diff --git a/source/core/slang-string.cpp b/source/core/slang-string.cpp index 9a908c93e..64b8e4dc1 100644 --- a/source/core/slang-string.cpp +++ b/source/core/slang-string.cpp @@ -1,5 +1,5 @@ #include "slang-string.h" -#include "text-io.h" +#include "slang-text-io.h" namespace Slang { diff --git a/source/core/slang-string.h b/source/core/slang-string.h index 82eda74ac..1cd9e5413 100644 --- a/source/core/slang-string.h +++ b/source/core/slang-string.h @@ -1,14 +1,14 @@ -#ifndef FUNDAMENTAL_LIB_STRING_H -#define FUNDAMENTAL_LIB_STRING_H +#ifndef SLANG_CORE_STRING_H +#define SLANG_CORE_STRING_H #include #include #include -#include "smart-pointer.h" -#include "common.h" -#include "hash.h" -#include "secure-crt.h" +#include "slang-smart-pointer.h" +#include "slang-common.h" +#include "slang-hash.h" +#include "slang-secure-crt.h" #include diff --git a/source/core/slang-test-tool-util.h b/source/core/slang-test-tool-util.h index 3ec655cad..a5d7541ec 100644 --- a/source/core/slang-test-tool-util.h +++ b/source/core/slang-test-tool-util.h @@ -1,5 +1,5 @@ -#ifndef SLANG_TEST_TOOL_UTIL_H -#define SLANG_TEST_TOOL_UTIL_H +#ifndef SLANG_CORE_TEST_TOOL_UTIL_H +#define SLANG_CORE_TEST_TOOL_UTIL_H #include "slang-std-writers.h" diff --git a/source/core/slang-text-io.cpp b/source/core/slang-text-io.cpp new file mode 100644 index 000000000..18039e41b --- /dev/null +++ b/source/core/slang-text-io.cpp @@ -0,0 +1,343 @@ +#include "slang-text-io.h" +#ifdef _WIN32 +#define WIN32_LEAN_AND_MEAN +#define NOMINMAX +#include +#undef WIN32_LEAN_AND_MEAN +#undef NOMINMAX +#define CONVERT_END_OF_LINE +#endif + +namespace Slang +{ + class Utf8Encoding : public Encoding + { + public: + virtual void GetBytes(List & result, const String & str) override + { + result.addRange(str.getBuffer(), str.getLength()); + } + virtual String ToString(const char * bytes, int /*length*/) override + { + return String(bytes); + } + }; + + class Utf32Encoding : public Encoding + { + public: + virtual void GetBytes(List & result, const String & str) override + { + Index ptr = 0; + while (ptr < str.getLength()) + { + int codePoint = GetUnicodePointFromUTF8([&](int) + { + if (ptr < str.getLength()) + return str[ptr++]; + else + return '\0'; + }); + result.addRange((char*)&codePoint, 4); + } + } + virtual String ToString(const char * bytes, int length) override + { + StringBuilder sb; + int * content = (int*)bytes; + for (int i = 0; i < (length >> 2); i++) + { + char buf[5]; + int count = EncodeUnicodePointToUTF8(buf, content[i]); + for (int j = 0; j < count; j++) + sb.Append(buf[j]); + } + return sb.ProduceString(); + } + }; + + class Utf16Encoding : public Encoding //UTF16 + { + private: + bool reverseOrder = false; + public: + Utf16Encoding(bool pReverseOrder) + : reverseOrder(pReverseOrder) + {} + virtual void GetBytes(List & result, const String & str) override + { + Index ptr = 0; + while (ptr < str.getLength()) + { + int codePoint = GetUnicodePointFromUTF8([&](int) + { + if (ptr < str.getLength()) + return str[ptr++]; + else + return '\0'; + }); + unsigned short buffer[2]; + int count; + if (!reverseOrder) + count = EncodeUnicodePointToUTF16(buffer, codePoint); + else + count = EncodeUnicodePointToUTF16Reversed(buffer, codePoint); + result.addRange((char*)buffer, count * 2); + } + } + virtual String ToString(const char * bytes, int length) override + { + int ptr = 0; + StringBuilder sb; + while (ptr < length) + { + int codePoint = GetUnicodePointFromUTF16([&](int) + { + if (ptr < length) + return bytes[ptr++]; + else + return '\0'; + }); + char buf[5]; + int count = EncodeUnicodePointToUTF8(buf, codePoint); + for (int i = 0; i < count; i++) + sb.Append(buf[i]); + } + return sb.ProduceString(); + } + }; + + Utf8Encoding __utf8Encoding; + Utf16Encoding __utf16Encoding(false); + Utf16Encoding __utf16EncodingReversed(true); + Utf32Encoding __utf32Encoding; + + Encoding * Encoding::UTF8 = &__utf8Encoding; + Encoding * Encoding::UTF16 = &__utf16Encoding; + Encoding * Encoding::UTF16Reversed = &__utf16EncodingReversed; + Encoding * Encoding::UTF32 = &__utf32Encoding; + + const unsigned short Utf16Header = 0xFEFF; + const unsigned short Utf16ReversedHeader = 0xFFFE; + + StreamWriter::StreamWriter(const String & path, Encoding * encoding) + { + this->stream = new FileStream(path, FileMode::Create); + this->encoding = encoding; + if (encoding == Encoding::UTF16) + { + this->stream->Write(&Utf16Header, 2); + } + else if (encoding == Encoding::UTF16Reversed) + { + this->stream->Write(&Utf16ReversedHeader, 2); + } + } + StreamWriter::StreamWriter(RefPtr stream, Encoding * encoding) + { + this->stream = stream; + this->encoding = encoding; + if (encoding == Encoding::UTF16) + { + this->stream->Write(&Utf16Header, 2); + } + else if (encoding == Encoding::UTF16Reversed) + { + this->stream->Write(&Utf16ReversedHeader, 2); + } + } + void StreamWriter::Write(const String & str) + { + encodingBuffer.clear(); + StringBuilder sb; + String newLine; +#ifdef _WIN32 + newLine = "\r\n"; +#else + newLine = "\n"; +#endif + for (Index i = 0; i < str.getLength(); i++) + { + if (str[i] == '\r') + sb << newLine; + else if (str[i] == '\n') + { + if (i > 0 && str[i - 1] != '\r') + sb << newLine; + } + else + sb << str[i]; + } + encoding->GetBytes(encodingBuffer, sb.ProduceString()); + stream->Write(encodingBuffer.getBuffer(), encodingBuffer.getCount()); + } + void StreamWriter::Write(const char * str) + { + Write(String(str)); + } + + StreamReader::StreamReader(const String & path) + { + stream = new FileStream(path, FileMode::Open); + ReadBuffer(); + encoding = DetermineEncoding(); + if (encoding == 0) + encoding = Encoding::UTF8; + } + StreamReader::StreamReader(RefPtr stream, Encoding * encoding) + { + this->stream = stream; + this->encoding = encoding; + ReadBuffer(); + auto determinedEncoding = DetermineEncoding(); + if (this->encoding == nullptr) + this->encoding = determinedEncoding; + } + + bool HasNullBytes(char * str, int len) + { + bool hasSeenNull = false; + for (int i = 0; i < len - 1; i++) + if (str[i] == 0) + hasSeenNull = true; + else if (hasSeenNull) + return true; + return false; + } + + Encoding * StreamReader::DetermineEncoding() + { + if (buffer.getCount() >= 3 && (unsigned char)(buffer[0]) == 0xEF && (unsigned char)(buffer[1]) == 0xBB && (unsigned char)(buffer[2]) == 0xBF) + { + ptr += 3; + return Encoding::UTF8; + } + else if (*((unsigned short*)(buffer.getBuffer())) == 0xFEFF) + { + ptr += 2; + return Encoding::UTF16; + } + else if (*((unsigned short*)(buffer.getBuffer())) == 0xFFFE) + { + ptr += 2; + return Encoding::UTF16Reversed; + } + else + { + // find null bytes + if (HasNullBytes(buffer.getBuffer(), (int)buffer.getCount())) + { + return Encoding::UTF16; + } + return Encoding::UTF8; + } + } + + void StreamReader::ReadBuffer() + { + buffer.setCount(4096); + memset(buffer.getBuffer(), 0, buffer.getCount() * sizeof(buffer[0])); + auto len = stream->Read(buffer.getBuffer(), buffer.getCount()); + buffer.setCount((int)len); + ptr = 0; + } + + char StreamReader::ReadBufferChar() + { + if (ptrIsEnd()) + ReadBuffer(); + if (ptr + TextWriter & operator << (const T& val) + { + Write(val.ToString()); + return *this; + } + TextWriter & operator << (int value) + { + Write(String(value)); + return *this; + } + TextWriter & operator << (float value) + { + Write(String(value)); + return *this; + } + TextWriter & operator << (double value) + { + Write(String(value)); + return *this; + } + TextWriter & operator << (const char* value) + { + Write(value); + return *this; + } + TextWriter & operator << (const String & val) + { + Write(val); + return *this; + } + TextWriter & operator << (const _EndLine &) + { +#ifdef _WIN32 + Write("\r\n"); +#else + Write("\n"); +#endif + return *this; + } + }; + + template + int GetUnicodePointFromUTF8(const ReadCharFunc & get) + { + int codePoint = 0; + int leading = get(0); + int mask = 0x80; + int count = 0; + while (leading & mask) + { + count++; + mask >>= 1; + } + codePoint = (leading & (mask - 1)); + for (int i = 1; i <= count - 1; i++) + { + codePoint <<= 6; + codePoint += (get(i) & 0x3F); + } + return codePoint; + } + + template + int GetUnicodePointFromUTF16(const ReadCharFunc & get) + { + int byte0 = (unsigned char)get(0); + int byte1 = (unsigned char)get(1); + int word0 = byte0 + (byte1 << 8); + if (word0 >= 0xD800 && word0 <= 0xDFFF) + { + int byte2 = (unsigned char)get(2); + int byte3 = (unsigned char)get(3); + int word1 = byte2 + (byte3 << 8); + return ((word0 & 0x3FF) << 10) + (word1 & 0x3FF) + 0x10000; + } + else + return word0; + } + + template + int GetUnicodePointFromUTF16Reversed(const ReadCharFunc & get) + { + int byte0 = (unsigned char)get(0); + int byte1 = (unsigned char)get(1); + int word0 = (byte0 << 8) + byte1; + if (word0 >= 0xD800 && word0 <= 0xDFFF) + { + int byte2 = (unsigned char)get(2); + int byte3 = (unsigned char)get(3); + int word1 = (byte2 << 8) + byte3; + return ((word0 & 0x3FF) << 10) + (word1 & 0x3FF); + } + else + return word0; + } + + template + int GetUnicodePointFromUTF32(const ReadCharFunc & get) + { + int byte0 = (unsigned char)get(0); + int byte1 = (unsigned char)get(1); + int byte2 = (unsigned char)get(2); + int byte3 = (unsigned char)get(3); + return byte0 + (byte1 << 8) + (byte2 << 16) + (byte3 << 24); + } + + inline int EncodeUnicodePointToUTF8(char * buffer, int codePoint) + { + int count = 0; + if (codePoint <= 0x7F) + buffer[count++] = ((char)codePoint); + else if (codePoint <= 0x7FF) + { + unsigned char byte = (unsigned char)(0xC0 + (codePoint >> 6)); + buffer[count++] = ((char)byte); + byte = 0x80 + (codePoint & 0x3F); + buffer[count++] = ((char)byte); + } + else if (codePoint <= 0xFFFF) + { + unsigned char byte = (unsigned char)(0xE0 + (codePoint >> 12)); + buffer[count++] = ((char)byte); + byte = (unsigned char)(0x80 + ((codePoint >> 6) & (0x3F))); + buffer[count++] = ((char)byte); + byte = (unsigned char)(0x80 + (codePoint & 0x3F)); + buffer[count++] = ((char)byte); + } + else + { + unsigned char byte = (unsigned char)(0xF0 + (codePoint >> 18)); + buffer[count++] = ((char)byte); + byte = (unsigned char)(0x80 + ((codePoint >> 12) & 0x3F)); + buffer[count++] = ((char)byte); + byte = (unsigned char)(0x80 + ((codePoint >> 6) & 0x3F)); + buffer[count++] = ((char)byte); + byte = (unsigned char)(0x80 + (codePoint & 0x3F)); + buffer[count++] = ((char)byte); + } + return count; + } + + inline int EncodeUnicodePointToUTF16(unsigned short * buffer, int codePoint) + { + int count = 0; + if (codePoint <= 0xD7FF || (codePoint >= 0xE000 && codePoint <= 0xFFFF)) + buffer[count++] = (unsigned short)codePoint; + else + { + int sub = codePoint - 0x10000; + int high = (sub >> 10) + 0xD800; + int low = (sub & 0x3FF) + 0xDC00; + buffer[count++] = (unsigned short)high; + buffer[count++] = (unsigned short)low; + } + return count; + } + + inline unsigned short ReverseBitOrder(unsigned short val) + { + int byte0 = val & 0xFF; + int byte1 = val >> 8; + return (unsigned short)(byte1 + (byte0 << 8)); + } + + inline int EncodeUnicodePointToUTF16Reversed(unsigned short * buffer, int codePoint) + { + int count = 0; + if (codePoint <= 0xD7FF || (codePoint >= 0xE000 && codePoint <= 0xFFFF)) + buffer[count++] = ReverseBitOrder((unsigned short)codePoint); + else + { + int sub = codePoint - 0x10000; + int high = (sub >> 10) + 0xD800; + int low = (sub & 0x3FF) + 0xDC00; + buffer[count++] = ReverseBitOrder((unsigned short)high); + buffer[count++] = ReverseBitOrder((unsigned short)low); + } + return count; + } + + class Encoding + { + public: + static Encoding * UTF8, * UTF16, *UTF16Reversed, * UTF32; + virtual void GetBytes(List& buffer, const String & str) = 0; + virtual String ToString(const char * buffer, int length) = 0; + virtual ~Encoding() + {} + }; + + class StreamWriter : public TextWriter + { + private: + List encodingBuffer; + RefPtr stream; + Encoding * encoding; + public: + StreamWriter(const String & path, Encoding * encoding = Encoding::UTF8); + StreamWriter(RefPtr stream, Encoding * encoding = Encoding::UTF8); + virtual void Write(const String & str); + virtual void Write(const char * str); + virtual void Close() + { + stream->Close(); + } + void ReleaseStream() + { + stream = 0; + } + }; + + class StreamReader : public TextReader + { + private: + RefPtr stream; + List buffer; + Encoding * encoding; + Index ptr; + char ReadBufferChar(); + void ReadBuffer(); + + Encoding * DetermineEncoding(); + protected: + virtual void ReadChar() + { + decodedCharPtr = 0; + int codePoint = 0; + if (encoding == Encoding::UTF8) + codePoint = GetUnicodePointFromUTF8([&](int) {return ReadBufferChar(); }); + else if (encoding == Encoding::UTF16) + codePoint = GetUnicodePointFromUTF16([&](int) {return ReadBufferChar(); }); + else if (encoding == Encoding::UTF16Reversed) + codePoint = GetUnicodePointFromUTF16Reversed([&](int) {return ReadBufferChar(); }); + else if (encoding == Encoding::UTF32) + codePoint = GetUnicodePointFromUTF32([&](int) {return ReadBufferChar(); }); + decodedCharSize = EncodeUnicodePointToUTF8(decodedChar, codePoint); + } + public: + StreamReader(const String & path); + StreamReader(RefPtr stream, Encoding * encoding = nullptr); + virtual String ReadLine(); + virtual String ReadToEnd(); + virtual bool IsEnd() + { + return ptr == buffer.getCount() && stream->IsEnd(); + } + virtual void Close() + { + stream->Close(); + } + void ReleaseStream() + { + stream = 0; + } + }; +} + +#endif diff --git a/source/core/slang-token-reader.cpp b/source/core/slang-token-reader.cpp new file mode 100644 index 000000000..a15dcda9c --- /dev/null +++ b/source/core/slang-token-reader.cpp @@ -0,0 +1,768 @@ +#include "slang-token-reader.h" + +namespace Slang +{ + enum class TokenizeErrorType + { + InvalidCharacter, InvalidEscapeSequence + }; + + enum class State + { + Start, Identifier, Operator, Int, Hex, Fixed, Double, Char, String, MultiComment, SingleComment + }; + + enum class LexDerivative + { + None, Line, File + }; + + inline bool IsLetter(char ch) + { + return ((ch >= 'a' && ch <= 'z') || + (ch >= 'A' && ch <= 'Z') || ch == '_'); + } + + inline bool IsDigit(char ch) + { + return ch >= '0' && ch <= '9'; + } + + inline bool IsPunctuation(char ch) + { + return ch == '+' || ch == '-' || ch == '*' || ch == '/' || ch == '%' || + ch == '!' || ch == '^' || ch == '&' || ch == '(' || ch == ')' || + ch == '=' || ch == '{' || ch == '}' || ch == '[' || ch == ']' || + ch == '|' || ch == ';' || ch == ',' || ch == '.' || ch == '<' || + ch == '>' || ch == '~' || ch == '@' || ch == ':' || ch == '?' || ch == '#'; + } + + inline bool IsWhiteSpace(char ch) + { + return (ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' || ch == '\v'); + } + + void ParseOperators(const String & str, List & tokens, TokenFlags& tokenFlags, int line, int col, int startPos, String fileName) + { + Index pos = 0; + while (pos < str.getLength()) + { + wchar_t curChar = str[pos]; + wchar_t nextChar = (pos < str.getLength() - 1) ? str[pos + 1] : '\0'; + wchar_t nextNextChar = (pos < str.getLength() - 2) ? str[pos + 2] : '\0'; + auto InsertToken = [&](TokenType type, const String & ct) + { + tokens.add(Token(type, ct, line, int(col + pos), int(pos + startPos), fileName, tokenFlags)); + tokenFlags = 0; + }; + switch (curChar) + { + case '+': + if (nextChar == '+') + { + InsertToken(TokenType::OpInc, "++"); + pos += 2; + } + else if (nextChar == '=') + { + InsertToken(TokenType::OpAddAssign, "+="); + pos += 2; + } + else + { + InsertToken(TokenType::OpAdd, "+"); + pos++; + } + break; + case '-': + if (nextChar == '-') + { + InsertToken(TokenType::OpDec, "--"); + pos += 2; + } + else if (nextChar == '=') + { + InsertToken(TokenType::OpSubAssign, "-="); + pos += 2; + } + else if (nextChar == '>') + { + InsertToken(TokenType::RightArrow, "->"); + pos += 2; + } + else + { + InsertToken(TokenType::OpSub, "-"); + pos++; + } + break; + case '*': + if (nextChar == '=') + { + InsertToken(TokenType::OpMulAssign, "*="); + pos += 2; + } + else + { + InsertToken(TokenType::OpMul, "*"); + pos++; + } + break; + case '/': + if (nextChar == '=') + { + InsertToken(TokenType::OpDivAssign, "/="); + pos += 2; + } + else + { + InsertToken(TokenType::OpDiv, "/"); + pos++; + } + break; + case '%': + if (nextChar == '=') + { + InsertToken(TokenType::OpModAssign, "%="); + pos += 2; + } + else + { + InsertToken(TokenType::OpMod, "%"); + pos++; + } + break; + case '|': + if (nextChar == '|') + { + InsertToken(TokenType::OpOr, "||"); + pos += 2; + } + else if (nextChar == '=') + { + InsertToken(TokenType::OpOrAssign, "|="); + pos += 2; + } + else + { + InsertToken(TokenType::OpBitOr, "|"); + pos++; + } + break; + case '&': + if (nextChar == '&') + { + InsertToken(TokenType::OpAnd, "&&"); + pos += 2; + } + else if (nextChar == '=') + { + InsertToken(TokenType::OpAndAssign, "&="); + pos += 2; + } + else + { + InsertToken(TokenType::OpBitAnd, "&"); + pos++; + } + break; + case '^': + if (nextChar == '=') + { + InsertToken(TokenType::OpXorAssign, "^="); + pos += 2; + } + else + { + InsertToken(TokenType::OpBitXor, "^"); + pos++; + } + break; + case '>': + if (nextChar == '>') + { + if (nextNextChar == '=') + { + InsertToken(TokenType::OpShrAssign, ">>="); + pos += 3; + } + else + { + InsertToken(TokenType::OpRsh, ">>"); + pos += 2; + } + } + else if (nextChar == '=') + { + InsertToken(TokenType::OpGeq, ">="); + pos += 2; + } + else + { + InsertToken(TokenType::OpGreater, ">"); + pos++; + } + break; + case '<': + if (nextChar == '<') + { + if (nextNextChar == '=') + { + InsertToken(TokenType::OpShlAssign, "<<="); + pos += 3; + } + else + { + InsertToken(TokenType::OpLsh, "<<"); + pos += 2; + } + } + else if (nextChar == '=') + { + InsertToken(TokenType::OpLeq, "<="); + pos += 2; + } + else + { + InsertToken(TokenType::OpLess, "<"); + pos++; + } + break; + case '=': + if (nextChar == '=') + { + InsertToken(TokenType::OpEql, "=="); + pos += 2; + } + else + { + InsertToken(TokenType::OpAssign, "="); + pos++; + } + break; + case '!': + if (nextChar == '=') + { + InsertToken(TokenType::OpNeq, "!="); + pos += 2; + } + else + { + InsertToken(TokenType::OpNot, "!"); + pos++; + } + break; + case '?': + InsertToken(TokenType::QuestionMark, "?"); + pos++; + break; + case '@': + InsertToken(TokenType::At, "@"); + pos++; + break; + case '#': + if (nextChar == '#') + { + InsertToken(TokenType::PoundPound, "##"); + pos += 2; + } + else + { + InsertToken(TokenType::Pound, "#"); + pos++; + } + pos++; + break; + case ':': + InsertToken(TokenType::Colon, ":"); + pos++; + break; + case '~': + InsertToken(TokenType::OpBitNot, "~"); + pos++; + break; + case ';': + InsertToken(TokenType::Semicolon, ";"); + pos++; + break; + case ',': + InsertToken(TokenType::Comma, ","); + pos++; + break; + case '.': + InsertToken(TokenType::Dot, "."); + pos++; + break; + case '{': + InsertToken(TokenType::LBrace, "{"); + pos++; + break; + case '}': + InsertToken(TokenType::RBrace, "}"); + pos++; + break; + case '[': + InsertToken(TokenType::LBracket, "["); + pos++; + break; + case ']': + InsertToken(TokenType::RBracket, "]"); + pos++; + break; + case '(': + InsertToken(TokenType::LParent, "("); + pos++; + break; + case ')': + InsertToken(TokenType::RParent, ")"); + pos++; + break; + } + } + } + + List TokenizeText(const String & fileName, const String & text) + { + Index lastPos = 0, pos = 0; + int line = 1, col = 0; + String file = fileName; + State state = State::Start; + StringBuilder tokenBuilder; + int tokenLine, tokenCol; + List tokenList; + LexDerivative derivative = LexDerivative::None; + TokenFlags tokenFlags = TokenFlag::AtStartOfLine; + auto InsertToken = [&](TokenType type) + { + derivative = LexDerivative::None; + tokenList.add(Token(type, tokenBuilder.ToString(), tokenLine, tokenCol, int(pos), file, tokenFlags)); + tokenFlags = 0; + tokenBuilder.Clear(); + }; + auto ProcessTransferChar = [&](char nextChar) + { + switch (nextChar) + { + case '\\': + case '\"': + case '\'': + tokenBuilder.Append(nextChar); + break; + case 't': + tokenBuilder.Append('\t'); + break; + case 's': + tokenBuilder.Append(' '); + break; + case 'n': + tokenBuilder.Append('\n'); + break; + case 'r': + tokenBuilder.Append('\r'); + break; + case 'b': + tokenBuilder.Append('\b'); + break; + } + }; + while (pos <= text.getLength()) + { + char curChar = (pos < text.getLength() ? text[pos] : ' '); + char nextChar = (pos < text.getLength() - 1) ? text[pos + 1] : '\0'; + if (lastPos != pos) + { + if (curChar == '\n') + { + line++; + col = 0; + } + else + col++; + lastPos = pos; + } + + switch (state) + { + case State::Start: + if (IsLetter(curChar)) + { + state = State::Identifier; + tokenLine = line; + tokenCol = col; + } + else if (IsDigit(curChar)) + { + state = State::Int; + tokenLine = line; + tokenCol = col; + } + else if (curChar == '\'') + { + state = State::Char; + pos++; + tokenLine = line; + tokenCol = col; + } + else if (curChar == '"') + { + state = State::String; + pos++; + tokenLine = line; + tokenCol = col; + } + else if (curChar == '\r' || curChar == '\n') + { + tokenFlags |= TokenFlag::AtStartOfLine | TokenFlag::AfterWhitespace; + pos++; + } + else if (curChar == ' ' || curChar == '\t' || curChar == -62 || curChar == -96) // -62/-96:non-break space + { + tokenFlags |= TokenFlag::AfterWhitespace; + pos++; + } + else if (curChar == '/' && nextChar == '/') + { + state = State::SingleComment; + pos += 2; + } + else if (curChar == '/' && nextChar == '*') + { + pos += 2; + state = State::MultiComment; + } + else if (curChar == '.' && IsDigit(nextChar)) + { + tokenBuilder.Append("0."); + state = State::Fixed; + pos++; + } + else if (IsPunctuation(curChar)) + { + state = State::Operator; + tokenLine = line; + tokenCol = col; + } + else + { + pos++; + } + break; + case State::Identifier: + if (IsLetter(curChar) || IsDigit(curChar)) + { + tokenBuilder.Append(curChar); + pos++; + } + else + { + auto tokenStr = tokenBuilder.ToString(); +#if 0 + if (tokenStr == "#line_reset#") + { + line = 0; + col = 0; + tokenBuilder.Clear(); + } + else if (tokenStr == "#line") + { + derivative = LexDerivative::Line; + tokenBuilder.Clear(); + } + else if (tokenStr == "#file") + { + derivative = LexDerivative::File; + tokenBuilder.Clear(); + line = 0; + col = 0; + } + else +#endif + InsertToken(TokenType::Identifier); + state = State::Start; + } + break; + case State::Operator: + if (IsPunctuation(curChar) && !((curChar == '/' && nextChar == '/') || (curChar == '/' && nextChar == '*'))) + { + tokenBuilder.Append(curChar); + pos++; + } + else + { + //do token analyze + ParseOperators(tokenBuilder.ToString(), tokenList, tokenFlags, tokenLine, tokenCol, (int)(pos - tokenBuilder.getLength()), file); + tokenBuilder.Clear(); + state = State::Start; + } + break; + case State::Int: + if (IsDigit(curChar)) + { + tokenBuilder.Append(curChar); + pos++; + } + else if (curChar == '.') + { + state = State::Fixed; + tokenBuilder.Append(curChar); + pos++; + } + else if (curChar == 'e' || curChar == 'E') + { + state = State::Double; + tokenBuilder.Append(curChar); + if (nextChar == '-' || nextChar == '+') + { + tokenBuilder.Append(nextChar); + pos++; + } + pos++; + } + else if (curChar == 'x') + { + state = State::Hex; + tokenBuilder.Append(curChar); + pos++; + } + else if (curChar == 'u') + { + pos++; + tokenBuilder.Append(curChar); + InsertToken(TokenType::IntLiteral); + state = State::Start; + } + else + { + if (derivative == LexDerivative::Line) + { + derivative = LexDerivative::None; + line = StringToInt(tokenBuilder.ToString()) - 1; + col = 0; + tokenBuilder.Clear(); + } + else + { + InsertToken(TokenType::IntLiteral); + } + state = State::Start; + } + break; + case State::Hex: + if (IsDigit(curChar) || (curChar >= 'a' && curChar <= 'f') || (curChar >= 'A' && curChar <= 'F')) + { + tokenBuilder.Append(curChar); + pos++; + } + else + { + InsertToken(TokenType::IntLiteral); + state = State::Start; + } + break; + case State::Fixed: + if (IsDigit(curChar)) + { + tokenBuilder.Append(curChar); + pos++; + } + else if (curChar == 'e' || curChar == 'E') + { + state = State::Double; + tokenBuilder.Append(curChar); + if (nextChar == '-' || nextChar == '+') + { + tokenBuilder.Append(nextChar); + pos++; + } + pos++; + } + else + { + if (curChar == 'f') + pos++; + InsertToken(TokenType::DoubleLiteral); + state = State::Start; + } + break; + case State::Double: + if (IsDigit(curChar)) + { + tokenBuilder.Append(curChar); + pos++; + } + else + { + if (curChar == 'f') + pos++; + InsertToken(TokenType::DoubleLiteral); + state = State::Start; + } + break; + case State::String: + if (curChar != '"') + { + if (curChar == '\\') + { + ProcessTransferChar(nextChar); + pos++; + } + else + tokenBuilder.Append(curChar); + } + else + { + if (derivative == LexDerivative::File) + { + derivative = LexDerivative::None; + file = tokenBuilder.ToString(); + tokenBuilder.Clear(); + } + else + { + InsertToken(TokenType::StringLiteral); + } + state = State::Start; + } + pos++; + break; + case State::Char: + if (curChar != '\'') + { + if (curChar == '\\') + { + ProcessTransferChar(nextChar); + pos++; + } + else + tokenBuilder.Append(curChar); + } + else + { + InsertToken(TokenType::CharLiteral); + state = State::Start; + } + pos++; + break; + case State::SingleComment: + if (curChar == '\n') + { + state = State::Start; + tokenFlags |= TokenFlag::AtStartOfLine | TokenFlag::AfterWhitespace; + } + pos++; + break; + case State::MultiComment: + if (curChar == '*' && nextChar == '/') + { + state = State::Start; + tokenFlags |= TokenFlag::AfterWhitespace; + pos += 2; + } + else + pos++; + break; + } + } + return tokenList; + } + List TokenizeText(const String & text) + { + return TokenizeText("", text); + } + + String EscapeStringLiteral(String str) + { + StringBuilder sb; + sb << "\""; + const Index length = str.getLength(); + const char*const data = str.getBuffer(); + for (Index i = 0; i < length; i++) + { + switch (data[i]) + { + case ' ': + sb << "\\s"; + break; + case '\n': + sb << "\\n"; + break; + case '\r': + sb << "\\r"; + break; + case '\t': + sb << "\\t"; + break; + case '\v': + sb << "\\v"; + break; + case '\'': + sb << "\\\'"; + break; + case '\"': + sb << "\\\""; + break; + case '\\': + sb << "\\\\"; + break; + default: + sb << data[i]; + break; + } + } + sb << "\""; + return sb.ProduceString(); + } + + String UnescapeStringLiteral(String str) + { + StringBuilder sb; + const Index length = str.getLength(); + const char*const data = str.getBuffer(); + for (Index i = 0; i < length; i++) + { + if (data[i] == '\\' && i < length - 1) + { + switch (data[i + 1]) + { + case 's': + sb << " "; + break; + case 't': + sb << '\t'; + break; + case 'n': + sb << '\n'; + break; + case 'r': + sb << '\r'; + break; + case 'v': + sb << '\v'; + break; + case '\'': + sb << '\''; + break; + case '\"': + sb << "\""; + break; + case '\\': + sb << "\\"; + break; + default: + i = i - 1; + sb << data[i]; + } + i++; + } + else + sb << data[i]; + } + return sb.ProduceString(); + } + + TokenReader::TokenReader(String text) + { + this->tokens = TokenizeText("", text); + tokenPtr = 0; + } +} diff --git a/source/core/slang-token-reader.h b/source/core/slang-token-reader.h new file mode 100644 index 000000000..f8a455452 --- /dev/null +++ b/source/core/slang-token-reader.h @@ -0,0 +1,260 @@ +#ifndef SLANG_CORE_TOKEN_READER_H +#define SLANG_CORE_TOKEN_READER_H + +#include "slang-basic.h" + +namespace Slang +{ + /* NOTE! This TokenReader is NOT used by the main slang compiler !*/ + + enum class TokenType + { + EndOfFile = -1, + // illegal + Unknown, + // identifier + Identifier, + // constant + IntLiteral, DoubleLiteral, StringLiteral, CharLiteral, + // operators + Semicolon, Comma, Dot, LBrace, RBrace, LBracket, RBracket, LParent, RParent, + OpAssign, OpAdd, OpSub, OpMul, OpDiv, OpMod, OpNot, OpBitNot, OpLsh, OpRsh, + OpEql, OpNeq, OpGreater, OpLess, OpGeq, OpLeq, + OpAnd, OpOr, OpBitXor, OpBitAnd, OpBitOr, + OpInc, OpDec, OpAddAssign, OpSubAssign, OpMulAssign, OpDivAssign, OpModAssign, + OpShlAssign, OpShrAssign, OpOrAssign, OpAndAssign, OpXorAssign, + + QuestionMark, Colon, RightArrow, At, Pound, PoundPound, Scope, + }; + + class CodePosition + { + public: + int Line = -1, Col = -1, Pos = -1; + String FileName; + String ToString() + { + StringBuilder sb(100); + sb << FileName; + if (Line != -1) + sb << "(" << Line << ")"; + return sb.ProduceString(); + } + CodePosition() = default; + CodePosition(int line, int col, int pos, String fileName) + { + Line = line; + Col = col; + Pos = pos; + this->FileName = fileName; + } + bool operator < (const CodePosition & pos) const + { + return FileName < pos.FileName || (FileName == pos.FileName && Line < pos.Line) || + (FileName == pos.FileName && Line == pos.Line && Col < pos.Col); + } + bool operator == (const CodePosition & pos) const + { + return FileName == pos.FileName && Line == pos.Line && Col == pos.Col; + } + }; + + enum TokenFlag : unsigned int + { + AtStartOfLine = 1 << 0, + AfterWhitespace = 1 << 1, + }; + typedef unsigned int TokenFlags; + + class Token + { + public: + TokenType Type = TokenType::Unknown; + String Content; + CodePosition Position; + TokenFlags flags; + Token() = default; + Token(TokenType type, const String & content, int line, int col, int pos, String fileName, TokenFlags flags = 0) + : flags(flags) + { + Type = type; + Content = content; + Position = CodePosition(line, col, pos, fileName); + } + }; + + class TextFormatException : public Exception + { + public: + TextFormatException(String message) + : Exception(message) + {} + }; + + class TokenReader + { + private: + bool legal; + List tokens; + int tokenPtr; + public: + TokenReader(String text); + int ReadInt() + { + auto token = ReadToken(); + bool neg = false; + if (token.Content == '-') + { + neg = true; + token = ReadToken(); + } + if (token.Type == TokenType::IntLiteral) + { + if (neg) + return -StringToInt(token.Content); + else + return StringToInt(token.Content); + } + throw TextFormatException("Text parsing error: int expected."); + } + unsigned int ReadUInt() + { + auto token = ReadToken(); + if (token.Type == TokenType::IntLiteral) + { + return StringToUInt(token.Content); + } + throw TextFormatException("Text parsing error: int expected."); + } + double ReadDouble() + { + auto token = ReadToken(); + bool neg = false; + if (token.Content == '-') + { + neg = true; + token = ReadToken(); + } + if (token.Type == TokenType::DoubleLiteral || token.Type == TokenType::IntLiteral) + { + if (neg) + return -StringToDouble(token.Content); + else + return StringToDouble(token.Content); + } + throw TextFormatException("Text parsing error: floating point value expected."); + } + float ReadFloat() + { + return (float)ReadDouble(); + } + String ReadWord() + { + auto token = ReadToken(); + if (token.Type == TokenType::Identifier) + { + return token.Content; + } + throw TextFormatException("Text parsing error: identifier expected."); + } + String Read(const char * expectedStr) + { + auto token = ReadToken(); + if (token.Content == expectedStr) + { + return token.Content; + } + throw TextFormatException("Text parsing error: \'" + String(expectedStr) + "\' expected."); + } + String Read(String expectedStr) + { + auto token = ReadToken(); + if (token.Content == expectedStr) + { + return token.Content; + } + throw TextFormatException("Text parsing error: \'" + expectedStr + "\' expected."); + } + + String ReadStringLiteral() + { + auto token = ReadToken(); + if (token.Type == TokenType::StringLiteral) + { + return token.Content; + } + throw TextFormatException("Text parsing error: string literal expected."); + } + void Back(int count) + { + tokenPtr -= count; + } + Token ReadToken() + { + if (tokenPtr < (int)tokens.getCount()) + { + auto &rs = tokens[tokenPtr]; + tokenPtr++; + return rs; + } + throw TextFormatException("Unexpected ending."); + } + Token NextToken(int offset = 0) + { + if (tokenPtr + offset < (int)tokens.getCount()) + return tokens[tokenPtr + offset]; + else + { + Token rs; + rs.Type = TokenType::Unknown; + return rs; + } + } + bool LookAhead(String token) + { + if (tokenPtr < (int)tokens.getCount()) + { + auto next = NextToken(); + return next.Content == token; + } + else + { + return false; + } + } + bool IsEnd() + { + return tokenPtr == (int)tokens.getCount(); + } + public: + bool IsLegalText() + { + return legal; + } + }; + + inline List Split(String text, char c) + { + List result; + StringBuilder sb; + for (Index i = 0; i < text.getLength(); i++) + { + if (text[i] == c) + { + auto str = sb.ToString(); + if (str.getLength() != 0) + result.add(str); + sb.Clear(); + } + else + sb << text[i]; + } + auto lastStr = sb.ToString(); + if (lastStr.getLength()) + result.add(lastStr); + return result; + } +} + + +#endif diff --git a/source/core/slang-type-traits.h b/source/core/slang-type-traits.h new file mode 100644 index 000000000..ccd1fb29c --- /dev/null +++ b/source/core/slang-type-traits.h @@ -0,0 +1,46 @@ +#ifndef SLANG_CORE_TYPE_TRAITS_H +#define SLANG_CORE_TYPE_TRAITS_H + +namespace Slang +{ + struct TraitResultYes + { + char x; + }; + struct TraitResultNo + { + char x[2]; + }; + + template + struct IsBaseOfTraitHost + { + operator B*() const { return nullptr; } + operator D*() { return nullptr; } + }; + + template + struct IsBaseOf + { + template + static TraitResultYes Check(D*, T) { return TraitResultYes(); } + static TraitResultNo Check(B*, int) { return TraitResultNo(); } + enum { Value = sizeof(Check(IsBaseOfTraitHost(), int())) == sizeof(TraitResultYes) }; + }; + + template + struct EnableIf {}; + + template + struct EnableIf { typedef T type; }; + + template + struct IsConvertible + { + static TraitResultYes Use(B) { return TraitResultYes(); }; + static TraitResultNo Use(...) { return TraitResultNo(); }; + enum { Value = sizeof(Use(*(D*)(nullptr))) == sizeof(TraitResultYes) }; + }; +} + +#endif diff --git a/source/core/slang-uint-set.h b/source/core/slang-uint-set.h index 25f0c9269..77930ba0d 100644 --- a/source/core/slang-uint-set.h +++ b/source/core/slang-uint-set.h @@ -1,9 +1,9 @@ -#ifndef SLANG_UINT_SET_H -#define SLANG_UINT_SET_H +#ifndef SLANG_CORE_UINT_SET_H +#define SLANG_CORE_UINT_SET_H -#include "list.h" +#include "slang-list.h" #include "slang-math.h" -#include "common.h" +#include "slang-common.h" #include diff --git a/source/core/slang-writer.cpp b/source/core/slang-writer.cpp index 2c6f99bf9..5b643fff8 100644 --- a/source/core/slang-writer.cpp +++ b/source/core/slang-writer.cpp @@ -1,6 +1,6 @@ #include "slang-writer.h" -#include "platform.h" +#include "slang-platform.h" #include "slang-string-util.h" // Includes to allow us to control console diff --git a/source/core/slang-writer.h b/source/core/slang-writer.h index 463450ac9..6e26d6750 100644 --- a/source/core/slang-writer.h +++ b/source/core/slang-writer.h @@ -1,10 +1,10 @@ -#ifndef SLANG_WRITER_H -#define SLANG_WRITER_H +#ifndef SLANG_CORE_WRITER_H +#define SLANG_CORE_WRITER_H #include "slang-string.h" #include "../../slang-com-helper.h" -#include "../../source/core/list.h" +#include "slang-list.h" namespace Slang { diff --git a/source/core/smart-pointer.h b/source/core/smart-pointer.h deleted file mode 100644 index aa5c06e02..000000000 --- a/source/core/smart-pointer.h +++ /dev/null @@ -1,250 +0,0 @@ -#ifndef FUNDAMENTAL_LIB_SMART_POINTER_H -#define FUNDAMENTAL_LIB_SMART_POINTER_H - -#include "common.h" -#include "hash.h" -#include "type-traits.h" - -#include "../../slang.h" - -namespace Slang -{ - // Base class for all reference-counted objects - class RefObject - { - private: - UInt referenceCount; - - public: - RefObject() - : referenceCount(0) - {} - - RefObject(const RefObject &) - : referenceCount(0) - {} - - virtual ~RefObject() - {} - - UInt addReference() - { - return ++referenceCount; - } - - UInt decreaseReference() - { - return --referenceCount; - } - - UInt releaseReference() - { - SLANG_ASSERT(referenceCount != 0); - if(--referenceCount == 0) - { - delete this; - return 0; - } - return referenceCount; - } - - bool isUniquelyReferenced() - { - SLANG_ASSERT(referenceCount != 0); - return referenceCount == 1; - } - - UInt debugGetReferenceCount() - { - return referenceCount; - } - }; - - SLANG_FORCE_INLINE void addReference(RefObject* obj) - { - if(obj) obj->addReference(); - } - - SLANG_FORCE_INLINE void releaseReference(RefObject* obj) - { - if(obj) obj->releaseReference(); - } - - // For straight dynamic cast. - // Use instead of dynamic_cast as it allows for replacement without using Rtti in the future - template - SLANG_FORCE_INLINE T* dynamicCast(RefObject* obj) { return dynamic_cast(obj); } - template - SLANG_FORCE_INLINE const T* dynamicCast(const RefObject* obj) { return dynamic_cast(obj); } - - // Like a dynamicCast, but allows a type to implement a specific implementation that is suitable for it - template - SLANG_FORCE_INLINE T* as(RefObject* obj) { return dynamicCast(obj); } - template - SLANG_FORCE_INLINE const T* as(const RefObject* obj) { return dynamicCast(obj); } - - // "Smart" pointer to a reference-counted object - template - struct RefPtr - { - RefPtr() - : pointer(nullptr) - {} - - RefPtr(T* p) - : pointer(p) - { - addReference(p); - } - - RefPtr(RefPtr const& p) - : pointer(p.pointer) - { - addReference(p.pointer); - } - - RefPtr(RefPtr&& p) - : pointer(p.pointer) - { - p.pointer = nullptr; - } - - template - RefPtr(RefPtr const& p, - typename EnableIf::Value, void>::type * = 0) - : pointer((U*) p) - { - addReference((U*) p); - } - -#if 0 - void operator=(T* p) - { - T* old = pointer; - addReference(p); - pointer = p; - releaseReference(old); - } -#endif - - void operator=(RefPtr const& p) - { - T* old = pointer; - addReference(p.pointer); - pointer = p.pointer; - releaseReference(old); - } - - void operator=(RefPtr&& p) - { - T* old = pointer; - pointer = p.pointer; - p.pointer = old; - } - - template - typename EnableIf::value, void>::type - operator=(RefPtr const& p) - { - T* old = pointer; - addReference(p.pointer); - pointer = p.pointer; - releaseReference(old); - } - - int GetHashCode() - { - // Note: We need a `RefPtr` to hash the same as a `T*`, - // so that a `T*` can be used as a key in a dictionary with - // `RefPtr` keys, and vice versa. - // - return Slang::GetHashCode(pointer); - } - - bool operator==(const T * ptr) const - { - return pointer == ptr; - } - - bool operator!=(const T * ptr) const - { - return pointer != ptr; - } - - bool operator==(RefPtr const& ptr) const - { - return pointer == ptr.pointer; - } - - bool operator!=(RefPtr const& ptr) const - { - return pointer != ptr.pointer; - } - - template - RefPtr dynamicCast() const - { - return RefPtr(Slang::dynamicCast(pointer)); - } - - template - RefPtr as() const - { - return RefPtr(Slang::as(pointer)); - } - - template - bool is() const { return Slang::as(pointer) != nullptr; } - - ~RefPtr() - { - releaseReference((Slang::RefObject*) pointer); - } - - T& operator*() const - { - return *pointer; - } - - T* operator->() const - { - return pointer; - } - - T * Ptr() const - { - return pointer; - } - - operator T*() const - { - return pointer; - } - - void attach(T* p) - { - T* old = pointer; - pointer = p; - releaseReference(old); - } - - T* detach() - { - auto rs = pointer; - pointer = nullptr; - return rs; - } - - /// Get ready for writing (nulls contents) - SLANG_FORCE_INLINE T** writeRef() { *this = nullptr; return &pointer; } - - /// Get for read access - SLANG_FORCE_INLINE T*const* readRef() const { return &pointer; } - - private: - T* pointer; - - }; -} - -#endif diff --git a/source/core/stream.cpp b/source/core/stream.cpp deleted file mode 100644 index de0a8b8f3..000000000 --- a/source/core/stream.cpp +++ /dev/null @@ -1,294 +0,0 @@ -#include "stream.h" -#ifdef _WIN32 -#include -#endif -#include "slang-io.h" - -namespace Slang -{ - FileStream::FileStream(const Slang::String & fileName, FileMode fileMode) - { - Init(fileName, fileMode, fileMode==FileMode::Open?FileAccess::Read:FileAccess::Write, FileShare::None); - } - FileStream::FileStream(const Slang::String & fileName, FileMode fileMode, FileAccess access, FileShare share) - { - Init(fileName, fileMode, access, share); - } - void FileStream::Init(const Slang::String & fileName, FileMode fileMode, FileAccess access, FileShare share) - { - const wchar_t * mode = L"rt"; - const char* modeMBCS = "rt"; - switch (fileMode) - { - case Slang::FileMode::Create: - if (access == FileAccess::Read) - throw ArgumentException("Read-only access is incompatible with Create mode."); - else if (access == FileAccess::ReadWrite) - { - mode = L"w+b"; - modeMBCS = "w+b"; - this->fileAccess = FileAccess::ReadWrite; - } - else - { - mode = L"wb"; - modeMBCS = "wb"; - this->fileAccess = FileAccess::Write; - } - break; - case Slang::FileMode::Open: - if (access == FileAccess::Read) - { - mode = L"rb"; - modeMBCS = "rb"; - this->fileAccess = FileAccess::Read; - } - else if (access == FileAccess::ReadWrite) - { - mode = L"r+b"; - modeMBCS = "r+b"; - this->fileAccess = FileAccess::ReadWrite; - } - else - { - mode = L"wb"; - modeMBCS = "wb"; - this->fileAccess = FileAccess::Write; - } - break; - case Slang::FileMode::CreateNew: - if (File::exists(fileName)) - { - throw IOException("Failed opening '" + fileName + "', file already exists."); - } - if (access == FileAccess::Read) - throw ArgumentException("Read-only access is incompatible with Create mode."); - else if (access == FileAccess::ReadWrite) - { - mode = L"w+b"; - this->fileAccess = FileAccess::ReadWrite; - } - else - { - mode = L"wb"; - this->fileAccess = FileAccess::Write; - } - break; - case Slang::FileMode::Append: - if (access == FileAccess::Read) - throw ArgumentException("Read-only access is incompatible with Append mode."); - else if (access == FileAccess::ReadWrite) - { - mode = L"a+b"; - this->fileAccess = FileAccess::ReadWrite; - } - else - { - mode = L"ab"; - this->fileAccess = FileAccess::Write; - } - break; - default: - break; - } -#ifdef _WIN32 - int shFlag = _SH_DENYRW; - switch (share) - { - case Slang::FileShare::None: - shFlag = _SH_DENYRW; - break; - case Slang::FileShare::ReadOnly: - shFlag = _SH_DENYWR; - break; - case Slang::FileShare::WriteOnly: - shFlag = _SH_DENYRD; - break; - case Slang::FileShare::ReadWrite: - shFlag = _SH_DENYNO; - break; - default: - throw ArgumentException("Invalid file share mode."); - break; - } - if (share == Slang::FileShare::None) -#pragma warning(suppress:4996) - handle = _wfopen(fileName.toWString(), mode); - else - handle = _wfsopen(fileName.toWString(), mode, shFlag); -#else - handle = fopen(fileName.getBuffer(), modeMBCS); -#endif - if (!handle) - { - throw IOException("Cannot open file '" + fileName + "'"); - } - } - FileStream::~FileStream() - { - Close(); - } - Int64 FileStream::GetPosition() - { -#if defined(_WIN32) || defined(__CYGWIN__) - fpos_t pos; - fgetpos(handle, &pos); - return pos; -#elif defined(__APPLE__) - return ftell(handle); -#else - fpos64_t pos; - fgetpos64(handle, &pos); - return *(Int64*)(&pos); -#endif - } - void FileStream::Seek(SeekOrigin origin, Int64 offset) - { - int _origin; - switch (origin) - { - case Slang::SeekOrigin::Start: - _origin = SEEK_SET; - endReached = false; - break; - case Slang::SeekOrigin::End: - _origin = SEEK_END; - // JS TODO: This doesn't seem right, the offset can mean it's not at the end - endReached = true; - break; - case Slang::SeekOrigin::Current: - _origin = SEEK_CUR; - endReached = false; - break; - default: - throw NotSupportedException("Unsupported seek origin."); - break; - } -#ifdef _WIN32 - int rs = _fseeki64(handle, offset, _origin); -#else - int rs = fseek(handle, (int)offset, _origin); -#endif - if (rs != 0) - { - throw IOException("FileStream seek failed."); - } - } - Int64 FileStream::Read(void * buffer, Int64 length) - { - auto bytes = fread_s(buffer, (size_t)length, 1, (size_t)length, handle); - if (bytes == 0 && length > 0) - { - if (!feof(handle)) - throw IOException("FileStream read failed."); - else if (endReached) - throw EndOfStreamException("End of file is reached."); - endReached = true; - } - return (int)bytes; - } - Int64 FileStream::Write(const void * buffer, Int64 length) - { - auto bytes = (Int64)fwrite(buffer, 1, (size_t)length, handle); - if (bytes < length) - { - throw IOException("FileStream write failed."); - } - return bytes; - } - bool FileStream::CanRead() - { - return ((int)fileAccess & (int)FileAccess::Read) != 0; - } - bool FileStream::CanWrite() - { - return ((int)fileAccess & (int)FileAccess::Write) != 0; - } - void FileStream::Close() - { - if (handle) - { - fclose(handle); - handle = 0; - } - } - bool FileStream::IsEnd() - { - return endReached; - } - - // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! MemoryStream !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - - void MemoryStream::Seek(SeekOrigin origin, Int64 offset) - { - Int64 pos = 0; - switch (origin) - { - case Slang::SeekOrigin::Start: - pos = offset; - break; - case Slang::SeekOrigin::End: - pos = Int64(m_contents.getCount()) + offset; - break; - case Slang::SeekOrigin::Current: - pos = Int64(m_position) + offset; - break; - default: - throw NotSupportedException("Unsupported seek origin."); - break; - } - - m_atEnd = false; - - // Clamp to the valid range - pos = (pos < 0) ? 0 : pos; - pos = (pos > Int64(m_contents.getCount())) ? Int64(m_contents.getCount()) : pos; - - m_position = UInt(pos); - } - - Int64 MemoryStream::Read(void * buffer, Int64 length) - { - if (!CanRead()) - { - throw IOException("Cannot read this stream."); - } - - const Int64 maxRead = Int64(m_contents.getCount() - m_position); - - if (maxRead == 0 && length > 0) - { - m_atEnd = true; - throw EndOfStreamException("End of file is reached."); - } - - length = length > maxRead ? maxRead : length; - - ::memcpy(buffer, m_contents.begin() + m_position, size_t(length)); - m_position += UInt(length); - return maxRead; - } - - Int64 MemoryStream::Write(const void * buffer, Int64 length) - { - if (!CanWrite()) - { - throw IOException("Cannot write this stream."); - } - - if (m_position == m_contents.getCount()) - { - m_contents.addRange((const uint8_t*)buffer, UInt(length)); - } - else - { - m_contents.insertRange(m_position, (const uint8_t*)buffer, UInt(length)); - } - - m_atEnd = false; - - m_position += UInt(length); - return length; - } - -} diff --git a/source/core/stream.h b/source/core/stream.h deleted file mode 100644 index 618aadbd4..000000000 --- a/source/core/stream.h +++ /dev/null @@ -1,113 +0,0 @@ -#ifndef CORE_LIB_STREAM_H -#define CORE_LIB_STREAM_H - -#include "basic.h" - -namespace Slang -{ - class IOException : public Exception - { - public: - IOException() - {} - IOException(const String & message) - : Slang::Exception(message) - { - } - }; - - class EndOfStreamException : public IOException - { - public: - EndOfStreamException() - {} - EndOfStreamException(const String & message) - : IOException(message) - { - } - }; - - enum class SeekOrigin - { - Start, End, Current - }; - - class Stream : public RefObject - { - public: - virtual ~Stream() {} - virtual Int64 GetPosition()=0; - virtual void Seek(SeekOrigin origin, Int64 offset)=0; - virtual Int64 Read(void * buffer, Int64 length) = 0; - virtual Int64 Write(const void * buffer, Int64 length) = 0; - virtual bool IsEnd() = 0; - virtual bool CanRead() = 0; - virtual bool CanWrite() = 0; - virtual void Close() = 0; - }; - - enum class FileMode - { - Create, Open, CreateNew, Append - }; - - enum class FileAccess - { - None = 0, Read = 1, Write = 2, ReadWrite = 3 - }; - - enum class FileShare - { - None, ReadOnly, WriteOnly, ReadWrite - }; - - class MemoryStream : public Stream - { - public: - virtual Int64 GetPosition() SLANG_OVERRIDE { return m_position; } - virtual void Seek(SeekOrigin origin, Int64 offset) SLANG_OVERRIDE; - virtual Int64 Read(void * buffer, Int64 length) SLANG_OVERRIDE; - virtual Int64 Write(const void * buffer, Int64 length) SLANG_OVERRIDE; - virtual bool IsEnd() SLANG_OVERRIDE { return m_atEnd; } - virtual bool CanRead() SLANG_OVERRIDE { return (int(m_access) & int(FileAccess::Read)) != 0; } - virtual bool CanWrite() SLANG_OVERRIDE { return (int(m_access) & int(FileAccess::Write)) != 0; } - virtual void Close() SLANG_OVERRIDE { m_access = FileAccess::None; } - - MemoryStream(FileAccess access) : - m_access(access), - m_position(0), - m_atEnd(false) - {} - - Index m_position; - - bool m_atEnd; ///< Happens when a read is done and nothing can be returned because already at end - - FileAccess m_access; - List m_contents; - }; - - class FileStream : public Stream - { - private: - FILE * handle; - FileAccess fileAccess; - bool endReached = false; - void Init(const Slang::String & fileName, FileMode fileMode, FileAccess access, FileShare share); - public: - FileStream(const Slang::String & fileName, FileMode fileMode = FileMode::Open); - FileStream(const Slang::String & fileName, FileMode fileMode, FileAccess access, FileShare share); - ~FileStream(); - public: - virtual Int64 GetPosition(); - virtual void Seek(SeekOrigin origin, Int64 offset); - virtual Int64 Read(void * buffer, Int64 length); - virtual Int64 Write(const void * buffer, Int64 length); - virtual bool CanRead(); - virtual bool CanWrite(); - virtual void Close(); - virtual bool IsEnd(); - }; -} - -#endif diff --git a/source/core/text-io.cpp b/source/core/text-io.cpp deleted file mode 100644 index 1f6b44c92..000000000 --- a/source/core/text-io.cpp +++ /dev/null @@ -1,343 +0,0 @@ -#include "text-io.h" -#ifdef _WIN32 -#define WIN32_LEAN_AND_MEAN -#define NOMINMAX -#include -#undef WIN32_LEAN_AND_MEAN -#undef NOMINMAX -#define CONVERT_END_OF_LINE -#endif - -namespace Slang -{ - class Utf8Encoding : public Encoding - { - public: - virtual void GetBytes(List & result, const String & str) override - { - result.addRange(str.getBuffer(), str.getLength()); - } - virtual String ToString(const char * bytes, int /*length*/) override - { - return String(bytes); - } - }; - - class Utf32Encoding : public Encoding - { - public: - virtual void GetBytes(List & result, const String & str) override - { - Index ptr = 0; - while (ptr < str.getLength()) - { - int codePoint = GetUnicodePointFromUTF8([&](int) - { - if (ptr < str.getLength()) - return str[ptr++]; - else - return '\0'; - }); - result.addRange((char*)&codePoint, 4); - } - } - virtual String ToString(const char * bytes, int length) override - { - StringBuilder sb; - int * content = (int*)bytes; - for (int i = 0; i < (length >> 2); i++) - { - char buf[5]; - int count = EncodeUnicodePointToUTF8(buf, content[i]); - for (int j = 0; j < count; j++) - sb.Append(buf[j]); - } - return sb.ProduceString(); - } - }; - - class Utf16Encoding : public Encoding //UTF16 - { - private: - bool reverseOrder = false; - public: - Utf16Encoding(bool pReverseOrder) - : reverseOrder(pReverseOrder) - {} - virtual void GetBytes(List & result, const String & str) override - { - Index ptr = 0; - while (ptr < str.getLength()) - { - int codePoint = GetUnicodePointFromUTF8([&](int) - { - if (ptr < str.getLength()) - return str[ptr++]; - else - return '\0'; - }); - unsigned short buffer[2]; - int count; - if (!reverseOrder) - count = EncodeUnicodePointToUTF16(buffer, codePoint); - else - count = EncodeUnicodePointToUTF16Reversed(buffer, codePoint); - result.addRange((char*)buffer, count * 2); - } - } - virtual String ToString(const char * bytes, int length) override - { - int ptr = 0; - StringBuilder sb; - while (ptr < length) - { - int codePoint = GetUnicodePointFromUTF16([&](int) - { - if (ptr < length) - return bytes[ptr++]; - else - return '\0'; - }); - char buf[5]; - int count = EncodeUnicodePointToUTF8(buf, codePoint); - for (int i = 0; i < count; i++) - sb.Append(buf[i]); - } - return sb.ProduceString(); - } - }; - - Utf8Encoding __utf8Encoding; - Utf16Encoding __utf16Encoding(false); - Utf16Encoding __utf16EncodingReversed(true); - Utf32Encoding __utf32Encoding; - - Encoding * Encoding::UTF8 = &__utf8Encoding; - Encoding * Encoding::UTF16 = &__utf16Encoding; - Encoding * Encoding::UTF16Reversed = &__utf16EncodingReversed; - Encoding * Encoding::UTF32 = &__utf32Encoding; - - const unsigned short Utf16Header = 0xFEFF; - const unsigned short Utf16ReversedHeader = 0xFFFE; - - StreamWriter::StreamWriter(const String & path, Encoding * encoding) - { - this->stream = new FileStream(path, FileMode::Create); - this->encoding = encoding; - if (encoding == Encoding::UTF16) - { - this->stream->Write(&Utf16Header, 2); - } - else if (encoding == Encoding::UTF16Reversed) - { - this->stream->Write(&Utf16ReversedHeader, 2); - } - } - StreamWriter::StreamWriter(RefPtr stream, Encoding * encoding) - { - this->stream = stream; - this->encoding = encoding; - if (encoding == Encoding::UTF16) - { - this->stream->Write(&Utf16Header, 2); - } - else if (encoding == Encoding::UTF16Reversed) - { - this->stream->Write(&Utf16ReversedHeader, 2); - } - } - void StreamWriter::Write(const String & str) - { - encodingBuffer.clear(); - StringBuilder sb; - String newLine; -#ifdef _WIN32 - newLine = "\r\n"; -#else - newLine = "\n"; -#endif - for (Index i = 0; i < str.getLength(); i++) - { - if (str[i] == '\r') - sb << newLine; - else if (str[i] == '\n') - { - if (i > 0 && str[i - 1] != '\r') - sb << newLine; - } - else - sb << str[i]; - } - encoding->GetBytes(encodingBuffer, sb.ProduceString()); - stream->Write(encodingBuffer.getBuffer(), encodingBuffer.getCount()); - } - void StreamWriter::Write(const char * str) - { - Write(String(str)); - } - - StreamReader::StreamReader(const String & path) - { - stream = new FileStream(path, FileMode::Open); - ReadBuffer(); - encoding = DetermineEncoding(); - if (encoding == 0) - encoding = Encoding::UTF8; - } - StreamReader::StreamReader(RefPtr stream, Encoding * encoding) - { - this->stream = stream; - this->encoding = encoding; - ReadBuffer(); - auto determinedEncoding = DetermineEncoding(); - if (this->encoding == nullptr) - this->encoding = determinedEncoding; - } - - bool HasNullBytes(char * str, int len) - { - bool hasSeenNull = false; - for (int i = 0; i < len - 1; i++) - if (str[i] == 0) - hasSeenNull = true; - else if (hasSeenNull) - return true; - return false; - } - - Encoding * StreamReader::DetermineEncoding() - { - if (buffer.getCount() >= 3 && (unsigned char)(buffer[0]) == 0xEF && (unsigned char)(buffer[1]) == 0xBB && (unsigned char)(buffer[2]) == 0xBF) - { - ptr += 3; - return Encoding::UTF8; - } - else if (*((unsigned short*)(buffer.getBuffer())) == 0xFEFF) - { - ptr += 2; - return Encoding::UTF16; - } - else if (*((unsigned short*)(buffer.getBuffer())) == 0xFFFE) - { - ptr += 2; - return Encoding::UTF16Reversed; - } - else - { - // find null bytes - if (HasNullBytes(buffer.getBuffer(), (int)buffer.getCount())) - { - return Encoding::UTF16; - } - return Encoding::UTF8; - } - } - - void StreamReader::ReadBuffer() - { - buffer.setCount(4096); - memset(buffer.getBuffer(), 0, buffer.getCount() * sizeof(buffer[0])); - auto len = stream->Read(buffer.getBuffer(), buffer.getCount()); - buffer.setCount((int)len); - ptr = 0; - } - - char StreamReader::ReadBufferChar() - { - if (ptrIsEnd()) - ReadBuffer(); - if (ptr - TextWriter & operator << (const T& val) - { - Write(val.ToString()); - return *this; - } - TextWriter & operator << (int value) - { - Write(String(value)); - return *this; - } - TextWriter & operator << (float value) - { - Write(String(value)); - return *this; - } - TextWriter & operator << (double value) - { - Write(String(value)); - return *this; - } - TextWriter & operator << (const char* value) - { - Write(value); - return *this; - } - TextWriter & operator << (const String & val) - { - Write(val); - return *this; - } - TextWriter & operator << (const _EndLine &) - { -#ifdef _WIN32 - Write("\r\n"); -#else - Write("\n"); -#endif - return *this; - } - }; - - template - int GetUnicodePointFromUTF8(const ReadCharFunc & get) - { - int codePoint = 0; - int leading = get(0); - int mask = 0x80; - int count = 0; - while (leading & mask) - { - count++; - mask >>= 1; - } - codePoint = (leading & (mask - 1)); - for (int i = 1; i <= count - 1; i++) - { - codePoint <<= 6; - codePoint += (get(i) & 0x3F); - } - return codePoint; - } - - template - int GetUnicodePointFromUTF16(const ReadCharFunc & get) - { - int byte0 = (unsigned char)get(0); - int byte1 = (unsigned char)get(1); - int word0 = byte0 + (byte1 << 8); - if (word0 >= 0xD800 && word0 <= 0xDFFF) - { - int byte2 = (unsigned char)get(2); - int byte3 = (unsigned char)get(3); - int word1 = byte2 + (byte3 << 8); - return ((word0 & 0x3FF) << 10) + (word1 & 0x3FF) + 0x10000; - } - else - return word0; - } - - template - int GetUnicodePointFromUTF16Reversed(const ReadCharFunc & get) - { - int byte0 = (unsigned char)get(0); - int byte1 = (unsigned char)get(1); - int word0 = (byte0 << 8) + byte1; - if (word0 >= 0xD800 && word0 <= 0xDFFF) - { - int byte2 = (unsigned char)get(2); - int byte3 = (unsigned char)get(3); - int word1 = (byte2 << 8) + byte3; - return ((word0 & 0x3FF) << 10) + (word1 & 0x3FF); - } - else - return word0; - } - - template - int GetUnicodePointFromUTF32(const ReadCharFunc & get) - { - int byte0 = (unsigned char)get(0); - int byte1 = (unsigned char)get(1); - int byte2 = (unsigned char)get(2); - int byte3 = (unsigned char)get(3); - return byte0 + (byte1 << 8) + (byte2 << 16) + (byte3 << 24); - } - - inline int EncodeUnicodePointToUTF8(char * buffer, int codePoint) - { - int count = 0; - if (codePoint <= 0x7F) - buffer[count++] = ((char)codePoint); - else if (codePoint <= 0x7FF) - { - unsigned char byte = (unsigned char)(0xC0 + (codePoint >> 6)); - buffer[count++] = ((char)byte); - byte = 0x80 + (codePoint & 0x3F); - buffer[count++] = ((char)byte); - } - else if (codePoint <= 0xFFFF) - { - unsigned char byte = (unsigned char)(0xE0 + (codePoint >> 12)); - buffer[count++] = ((char)byte); - byte = (unsigned char)(0x80 + ((codePoint >> 6) & (0x3F))); - buffer[count++] = ((char)byte); - byte = (unsigned char)(0x80 + (codePoint & 0x3F)); - buffer[count++] = ((char)byte); - } - else - { - unsigned char byte = (unsigned char)(0xF0 + (codePoint >> 18)); - buffer[count++] = ((char)byte); - byte = (unsigned char)(0x80 + ((codePoint >> 12) & 0x3F)); - buffer[count++] = ((char)byte); - byte = (unsigned char)(0x80 + ((codePoint >> 6) & 0x3F)); - buffer[count++] = ((char)byte); - byte = (unsigned char)(0x80 + (codePoint & 0x3F)); - buffer[count++] = ((char)byte); - } - return count; - } - - inline int EncodeUnicodePointToUTF16(unsigned short * buffer, int codePoint) - { - int count = 0; - if (codePoint <= 0xD7FF || (codePoint >= 0xE000 && codePoint <= 0xFFFF)) - buffer[count++] = (unsigned short)codePoint; - else - { - int sub = codePoint - 0x10000; - int high = (sub >> 10) + 0xD800; - int low = (sub & 0x3FF) + 0xDC00; - buffer[count++] = (unsigned short)high; - buffer[count++] = (unsigned short)low; - } - return count; - } - - inline unsigned short ReverseBitOrder(unsigned short val) - { - int byte0 = val & 0xFF; - int byte1 = val >> 8; - return (unsigned short)(byte1 + (byte0 << 8)); - } - - inline int EncodeUnicodePointToUTF16Reversed(unsigned short * buffer, int codePoint) - { - int count = 0; - if (codePoint <= 0xD7FF || (codePoint >= 0xE000 && codePoint <= 0xFFFF)) - buffer[count++] = ReverseBitOrder((unsigned short)codePoint); - else - { - int sub = codePoint - 0x10000; - int high = (sub >> 10) + 0xD800; - int low = (sub & 0x3FF) + 0xDC00; - buffer[count++] = ReverseBitOrder((unsigned short)high); - buffer[count++] = ReverseBitOrder((unsigned short)low); - } - return count; - } - - class Encoding - { - public: - static Encoding * UTF8, * UTF16, *UTF16Reversed, * UTF32; - virtual void GetBytes(List& buffer, const String & str) = 0; - virtual String ToString(const char * buffer, int length) = 0; - virtual ~Encoding() - {} - }; - - class StreamWriter : public TextWriter - { - private: - List encodingBuffer; - RefPtr stream; - Encoding * encoding; - public: - StreamWriter(const String & path, Encoding * encoding = Encoding::UTF8); - StreamWriter(RefPtr stream, Encoding * encoding = Encoding::UTF8); - virtual void Write(const String & str); - virtual void Write(const char * str); - virtual void Close() - { - stream->Close(); - } - void ReleaseStream() - { - stream = 0; - } - }; - - class StreamReader : public TextReader - { - private: - RefPtr stream; - List buffer; - Encoding * encoding; - Index ptr; - char ReadBufferChar(); - void ReadBuffer(); - - Encoding * DetermineEncoding(); - protected: - virtual void ReadChar() - { - decodedCharPtr = 0; - int codePoint = 0; - if (encoding == Encoding::UTF8) - codePoint = GetUnicodePointFromUTF8([&](int) {return ReadBufferChar(); }); - else if (encoding == Encoding::UTF16) - codePoint = GetUnicodePointFromUTF16([&](int) {return ReadBufferChar(); }); - else if (encoding == Encoding::UTF16Reversed) - codePoint = GetUnicodePointFromUTF16Reversed([&](int) {return ReadBufferChar(); }); - else if (encoding == Encoding::UTF32) - codePoint = GetUnicodePointFromUTF32([&](int) {return ReadBufferChar(); }); - decodedCharSize = EncodeUnicodePointToUTF8(decodedChar, codePoint); - } - public: - StreamReader(const String & path); - StreamReader(RefPtr stream, Encoding * encoding = nullptr); - virtual String ReadLine(); - virtual String ReadToEnd(); - virtual bool IsEnd() - { - return ptr == buffer.getCount() && stream->IsEnd(); - } - virtual void Close() - { - stream->Close(); - } - void ReleaseStream() - { - stream = 0; - } - }; -} - -#endif diff --git a/source/core/token-reader.cpp b/source/core/token-reader.cpp deleted file mode 100644 index ea40c9ed9..000000000 --- a/source/core/token-reader.cpp +++ /dev/null @@ -1,768 +0,0 @@ -#include "token-reader.h" - -namespace Slang -{ - enum class TokenizeErrorType - { - InvalidCharacter, InvalidEscapeSequence - }; - - enum class State - { - Start, Identifier, Operator, Int, Hex, Fixed, Double, Char, String, MultiComment, SingleComment - }; - - enum class LexDerivative - { - None, Line, File - }; - - inline bool IsLetter(char ch) - { - return ((ch >= 'a' && ch <= 'z') || - (ch >= 'A' && ch <= 'Z') || ch == '_'); - } - - inline bool IsDigit(char ch) - { - return ch >= '0' && ch <= '9'; - } - - inline bool IsPunctuation(char ch) - { - return ch == '+' || ch == '-' || ch == '*' || ch == '/' || ch == '%' || - ch == '!' || ch == '^' || ch == '&' || ch == '(' || ch == ')' || - ch == '=' || ch == '{' || ch == '}' || ch == '[' || ch == ']' || - ch == '|' || ch == ';' || ch == ',' || ch == '.' || ch == '<' || - ch == '>' || ch == '~' || ch == '@' || ch == ':' || ch == '?' || ch == '#'; - } - - inline bool IsWhiteSpace(char ch) - { - return (ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' || ch == '\v'); - } - - void ParseOperators(const String & str, List & tokens, TokenFlags& tokenFlags, int line, int col, int startPos, String fileName) - { - Index pos = 0; - while (pos < str.getLength()) - { - wchar_t curChar = str[pos]; - wchar_t nextChar = (pos < str.getLength() - 1) ? str[pos + 1] : '\0'; - wchar_t nextNextChar = (pos < str.getLength() - 2) ? str[pos + 2] : '\0'; - auto InsertToken = [&](TokenType type, const String & ct) - { - tokens.add(Token(type, ct, line, int(col + pos), int(pos + startPos), fileName, tokenFlags)); - tokenFlags = 0; - }; - switch (curChar) - { - case '+': - if (nextChar == '+') - { - InsertToken(TokenType::OpInc, "++"); - pos += 2; - } - else if (nextChar == '=') - { - InsertToken(TokenType::OpAddAssign, "+="); - pos += 2; - } - else - { - InsertToken(TokenType::OpAdd, "+"); - pos++; - } - break; - case '-': - if (nextChar == '-') - { - InsertToken(TokenType::OpDec, "--"); - pos += 2; - } - else if (nextChar == '=') - { - InsertToken(TokenType::OpSubAssign, "-="); - pos += 2; - } - else if (nextChar == '>') - { - InsertToken(TokenType::RightArrow, "->"); - pos += 2; - } - else - { - InsertToken(TokenType::OpSub, "-"); - pos++; - } - break; - case '*': - if (nextChar == '=') - { - InsertToken(TokenType::OpMulAssign, "*="); - pos += 2; - } - else - { - InsertToken(TokenType::OpMul, "*"); - pos++; - } - break; - case '/': - if (nextChar == '=') - { - InsertToken(TokenType::OpDivAssign, "/="); - pos += 2; - } - else - { - InsertToken(TokenType::OpDiv, "/"); - pos++; - } - break; - case '%': - if (nextChar == '=') - { - InsertToken(TokenType::OpModAssign, "%="); - pos += 2; - } - else - { - InsertToken(TokenType::OpMod, "%"); - pos++; - } - break; - case '|': - if (nextChar == '|') - { - InsertToken(TokenType::OpOr, "||"); - pos += 2; - } - else if (nextChar == '=') - { - InsertToken(TokenType::OpOrAssign, "|="); - pos += 2; - } - else - { - InsertToken(TokenType::OpBitOr, "|"); - pos++; - } - break; - case '&': - if (nextChar == '&') - { - InsertToken(TokenType::OpAnd, "&&"); - pos += 2; - } - else if (nextChar == '=') - { - InsertToken(TokenType::OpAndAssign, "&="); - pos += 2; - } - else - { - InsertToken(TokenType::OpBitAnd, "&"); - pos++; - } - break; - case '^': - if (nextChar == '=') - { - InsertToken(TokenType::OpXorAssign, "^="); - pos += 2; - } - else - { - InsertToken(TokenType::OpBitXor, "^"); - pos++; - } - break; - case '>': - if (nextChar == '>') - { - if (nextNextChar == '=') - { - InsertToken(TokenType::OpShrAssign, ">>="); - pos += 3; - } - else - { - InsertToken(TokenType::OpRsh, ">>"); - pos += 2; - } - } - else if (nextChar == '=') - { - InsertToken(TokenType::OpGeq, ">="); - pos += 2; - } - else - { - InsertToken(TokenType::OpGreater, ">"); - pos++; - } - break; - case '<': - if (nextChar == '<') - { - if (nextNextChar == '=') - { - InsertToken(TokenType::OpShlAssign, "<<="); - pos += 3; - } - else - { - InsertToken(TokenType::OpLsh, "<<"); - pos += 2; - } - } - else if (nextChar == '=') - { - InsertToken(TokenType::OpLeq, "<="); - pos += 2; - } - else - { - InsertToken(TokenType::OpLess, "<"); - pos++; - } - break; - case '=': - if (nextChar == '=') - { - InsertToken(TokenType::OpEql, "=="); - pos += 2; - } - else - { - InsertToken(TokenType::OpAssign, "="); - pos++; - } - break; - case '!': - if (nextChar == '=') - { - InsertToken(TokenType::OpNeq, "!="); - pos += 2; - } - else - { - InsertToken(TokenType::OpNot, "!"); - pos++; - } - break; - case '?': - InsertToken(TokenType::QuestionMark, "?"); - pos++; - break; - case '@': - InsertToken(TokenType::At, "@"); - pos++; - break; - case '#': - if (nextChar == '#') - { - InsertToken(TokenType::PoundPound, "##"); - pos += 2; - } - else - { - InsertToken(TokenType::Pound, "#"); - pos++; - } - pos++; - break; - case ':': - InsertToken(TokenType::Colon, ":"); - pos++; - break; - case '~': - InsertToken(TokenType::OpBitNot, "~"); - pos++; - break; - case ';': - InsertToken(TokenType::Semicolon, ";"); - pos++; - break; - case ',': - InsertToken(TokenType::Comma, ","); - pos++; - break; - case '.': - InsertToken(TokenType::Dot, "."); - pos++; - break; - case '{': - InsertToken(TokenType::LBrace, "{"); - pos++; - break; - case '}': - InsertToken(TokenType::RBrace, "}"); - pos++; - break; - case '[': - InsertToken(TokenType::LBracket, "["); - pos++; - break; - case ']': - InsertToken(TokenType::RBracket, "]"); - pos++; - break; - case '(': - InsertToken(TokenType::LParent, "("); - pos++; - break; - case ')': - InsertToken(TokenType::RParent, ")"); - pos++; - break; - } - } - } - - List TokenizeText(const String & fileName, const String & text) - { - Index lastPos = 0, pos = 0; - int line = 1, col = 0; - String file = fileName; - State state = State::Start; - StringBuilder tokenBuilder; - int tokenLine, tokenCol; - List tokenList; - LexDerivative derivative = LexDerivative::None; - TokenFlags tokenFlags = TokenFlag::AtStartOfLine; - auto InsertToken = [&](TokenType type) - { - derivative = LexDerivative::None; - tokenList.add(Token(type, tokenBuilder.ToString(), tokenLine, tokenCol, int(pos), file, tokenFlags)); - tokenFlags = 0; - tokenBuilder.Clear(); - }; - auto ProcessTransferChar = [&](char nextChar) - { - switch (nextChar) - { - case '\\': - case '\"': - case '\'': - tokenBuilder.Append(nextChar); - break; - case 't': - tokenBuilder.Append('\t'); - break; - case 's': - tokenBuilder.Append(' '); - break; - case 'n': - tokenBuilder.Append('\n'); - break; - case 'r': - tokenBuilder.Append('\r'); - break; - case 'b': - tokenBuilder.Append('\b'); - break; - } - }; - while (pos <= text.getLength()) - { - char curChar = (pos < text.getLength() ? text[pos] : ' '); - char nextChar = (pos < text.getLength() - 1) ? text[pos + 1] : '\0'; - if (lastPos != pos) - { - if (curChar == '\n') - { - line++; - col = 0; - } - else - col++; - lastPos = pos; - } - - switch (state) - { - case State::Start: - if (IsLetter(curChar)) - { - state = State::Identifier; - tokenLine = line; - tokenCol = col; - } - else if (IsDigit(curChar)) - { - state = State::Int; - tokenLine = line; - tokenCol = col; - } - else if (curChar == '\'') - { - state = State::Char; - pos++; - tokenLine = line; - tokenCol = col; - } - else if (curChar == '"') - { - state = State::String; - pos++; - tokenLine = line; - tokenCol = col; - } - else if (curChar == '\r' || curChar == '\n') - { - tokenFlags |= TokenFlag::AtStartOfLine | TokenFlag::AfterWhitespace; - pos++; - } - else if (curChar == ' ' || curChar == '\t' || curChar == -62 || curChar == -96) // -62/-96:non-break space - { - tokenFlags |= TokenFlag::AfterWhitespace; - pos++; - } - else if (curChar == '/' && nextChar == '/') - { - state = State::SingleComment; - pos += 2; - } - else if (curChar == '/' && nextChar == '*') - { - pos += 2; - state = State::MultiComment; - } - else if (curChar == '.' && IsDigit(nextChar)) - { - tokenBuilder.Append("0."); - state = State::Fixed; - pos++; - } - else if (IsPunctuation(curChar)) - { - state = State::Operator; - tokenLine = line; - tokenCol = col; - } - else - { - pos++; - } - break; - case State::Identifier: - if (IsLetter(curChar) || IsDigit(curChar)) - { - tokenBuilder.Append(curChar); - pos++; - } - else - { - auto tokenStr = tokenBuilder.ToString(); -#if 0 - if (tokenStr == "#line_reset#") - { - line = 0; - col = 0; - tokenBuilder.Clear(); - } - else if (tokenStr == "#line") - { - derivative = LexDerivative::Line; - tokenBuilder.Clear(); - } - else if (tokenStr == "#file") - { - derivative = LexDerivative::File; - tokenBuilder.Clear(); - line = 0; - col = 0; - } - else -#endif - InsertToken(TokenType::Identifier); - state = State::Start; - } - break; - case State::Operator: - if (IsPunctuation(curChar) && !((curChar == '/' && nextChar == '/') || (curChar == '/' && nextChar == '*'))) - { - tokenBuilder.Append(curChar); - pos++; - } - else - { - //do token analyze - ParseOperators(tokenBuilder.ToString(), tokenList, tokenFlags, tokenLine, tokenCol, (int)(pos - tokenBuilder.getLength()), file); - tokenBuilder.Clear(); - state = State::Start; - } - break; - case State::Int: - if (IsDigit(curChar)) - { - tokenBuilder.Append(curChar); - pos++; - } - else if (curChar == '.') - { - state = State::Fixed; - tokenBuilder.Append(curChar); - pos++; - } - else if (curChar == 'e' || curChar == 'E') - { - state = State::Double; - tokenBuilder.Append(curChar); - if (nextChar == '-' || nextChar == '+') - { - tokenBuilder.Append(nextChar); - pos++; - } - pos++; - } - else if (curChar == 'x') - { - state = State::Hex; - tokenBuilder.Append(curChar); - pos++; - } - else if (curChar == 'u') - { - pos++; - tokenBuilder.Append(curChar); - InsertToken(TokenType::IntLiteral); - state = State::Start; - } - else - { - if (derivative == LexDerivative::Line) - { - derivative = LexDerivative::None; - line = StringToInt(tokenBuilder.ToString()) - 1; - col = 0; - tokenBuilder.Clear(); - } - else - { - InsertToken(TokenType::IntLiteral); - } - state = State::Start; - } - break; - case State::Hex: - if (IsDigit(curChar) || (curChar >= 'a' && curChar <= 'f') || (curChar >= 'A' && curChar <= 'F')) - { - tokenBuilder.Append(curChar); - pos++; - } - else - { - InsertToken(TokenType::IntLiteral); - state = State::Start; - } - break; - case State::Fixed: - if (IsDigit(curChar)) - { - tokenBuilder.Append(curChar); - pos++; - } - else if (curChar == 'e' || curChar == 'E') - { - state = State::Double; - tokenBuilder.Append(curChar); - if (nextChar == '-' || nextChar == '+') - { - tokenBuilder.Append(nextChar); - pos++; - } - pos++; - } - else - { - if (curChar == 'f') - pos++; - InsertToken(TokenType::DoubleLiteral); - state = State::Start; - } - break; - case State::Double: - if (IsDigit(curChar)) - { - tokenBuilder.Append(curChar); - pos++; - } - else - { - if (curChar == 'f') - pos++; - InsertToken(TokenType::DoubleLiteral); - state = State::Start; - } - break; - case State::String: - if (curChar != '"') - { - if (curChar == '\\') - { - ProcessTransferChar(nextChar); - pos++; - } - else - tokenBuilder.Append(curChar); - } - else - { - if (derivative == LexDerivative::File) - { - derivative = LexDerivative::None; - file = tokenBuilder.ToString(); - tokenBuilder.Clear(); - } - else - { - InsertToken(TokenType::StringLiteral); - } - state = State::Start; - } - pos++; - break; - case State::Char: - if (curChar != '\'') - { - if (curChar == '\\') - { - ProcessTransferChar(nextChar); - pos++; - } - else - tokenBuilder.Append(curChar); - } - else - { - InsertToken(TokenType::CharLiteral); - state = State::Start; - } - pos++; - break; - case State::SingleComment: - if (curChar == '\n') - { - state = State::Start; - tokenFlags |= TokenFlag::AtStartOfLine | TokenFlag::AfterWhitespace; - } - pos++; - break; - case State::MultiComment: - if (curChar == '*' && nextChar == '/') - { - state = State::Start; - tokenFlags |= TokenFlag::AfterWhitespace; - pos += 2; - } - else - pos++; - break; - } - } - return tokenList; - } - List TokenizeText(const String & text) - { - return TokenizeText("", text); - } - - String EscapeStringLiteral(String str) - { - StringBuilder sb; - sb << "\""; - const Index length = str.getLength(); - const char*const data = str.getBuffer(); - for (Index i = 0; i < length; i++) - { - switch (data[i]) - { - case ' ': - sb << "\\s"; - break; - case '\n': - sb << "\\n"; - break; - case '\r': - sb << "\\r"; - break; - case '\t': - sb << "\\t"; - break; - case '\v': - sb << "\\v"; - break; - case '\'': - sb << "\\\'"; - break; - case '\"': - sb << "\\\""; - break; - case '\\': - sb << "\\\\"; - break; - default: - sb << data[i]; - break; - } - } - sb << "\""; - return sb.ProduceString(); - } - - String UnescapeStringLiteral(String str) - { - StringBuilder sb; - const Index length = str.getLength(); - const char*const data = str.getBuffer(); - for (Index i = 0; i < length; i++) - { - if (data[i] == '\\' && i < length - 1) - { - switch (data[i + 1]) - { - case 's': - sb << " "; - break; - case 't': - sb << '\t'; - break; - case 'n': - sb << '\n'; - break; - case 'r': - sb << '\r'; - break; - case 'v': - sb << '\v'; - break; - case '\'': - sb << '\''; - break; - case '\"': - sb << "\""; - break; - case '\\': - sb << "\\"; - break; - default: - i = i - 1; - sb << data[i]; - } - i++; - } - else - sb << data[i]; - } - return sb.ProduceString(); - } - - TokenReader::TokenReader(String text) - { - this->tokens = TokenizeText("", text); - tokenPtr = 0; - } -} diff --git a/source/core/token-reader.h b/source/core/token-reader.h deleted file mode 100644 index a5b9b3694..000000000 --- a/source/core/token-reader.h +++ /dev/null @@ -1,258 +0,0 @@ -#ifndef CORE_TOKEN_READER_H -#define CORE_TOKEN_READER_H - -#include "basic.h" - -namespace Slang -{ - enum class TokenType - { - EndOfFile = -1, - // illegal - Unknown, - // identifier - Identifier, - // constant - IntLiteral, DoubleLiteral, StringLiteral, CharLiteral, - // operators - Semicolon, Comma, Dot, LBrace, RBrace, LBracket, RBracket, LParent, RParent, - OpAssign, OpAdd, OpSub, OpMul, OpDiv, OpMod, OpNot, OpBitNot, OpLsh, OpRsh, - OpEql, OpNeq, OpGreater, OpLess, OpGeq, OpLeq, - OpAnd, OpOr, OpBitXor, OpBitAnd, OpBitOr, - OpInc, OpDec, OpAddAssign, OpSubAssign, OpMulAssign, OpDivAssign, OpModAssign, - OpShlAssign, OpShrAssign, OpOrAssign, OpAndAssign, OpXorAssign, - - QuestionMark, Colon, RightArrow, At, Pound, PoundPound, Scope, - }; - - class CodePosition - { - public: - int Line = -1, Col = -1, Pos = -1; - String FileName; - String ToString() - { - StringBuilder sb(100); - sb << FileName; - if (Line != -1) - sb << "(" << Line << ")"; - return sb.ProduceString(); - } - CodePosition() = default; - CodePosition(int line, int col, int pos, String fileName) - { - Line = line; - Col = col; - Pos = pos; - this->FileName = fileName; - } - bool operator < (const CodePosition & pos) const - { - return FileName < pos.FileName || (FileName == pos.FileName && Line < pos.Line) || - (FileName == pos.FileName && Line == pos.Line && Col < pos.Col); - } - bool operator == (const CodePosition & pos) const - { - return FileName == pos.FileName && Line == pos.Line && Col == pos.Col; - } - }; - - enum TokenFlag : unsigned int - { - AtStartOfLine = 1 << 0, - AfterWhitespace = 1 << 1, - }; - typedef unsigned int TokenFlags; - - class Token - { - public: - TokenType Type = TokenType::Unknown; - String Content; - CodePosition Position; - TokenFlags flags; - Token() = default; - Token(TokenType type, const String & content, int line, int col, int pos, String fileName, TokenFlags flags = 0) - : flags(flags) - { - Type = type; - Content = content; - Position = CodePosition(line, col, pos, fileName); - } - }; - - class TextFormatException : public Exception - { - public: - TextFormatException(String message) - : Exception(message) - {} - }; - - class TokenReader - { - private: - bool legal; - List tokens; - int tokenPtr; - public: - TokenReader(String text); - int ReadInt() - { - auto token = ReadToken(); - bool neg = false; - if (token.Content == '-') - { - neg = true; - token = ReadToken(); - } - if (token.Type == TokenType::IntLiteral) - { - if (neg) - return -StringToInt(token.Content); - else - return StringToInt(token.Content); - } - throw TextFormatException("Text parsing error: int expected."); - } - unsigned int ReadUInt() - { - auto token = ReadToken(); - if (token.Type == TokenType::IntLiteral) - { - return StringToUInt(token.Content); - } - throw TextFormatException("Text parsing error: int expected."); - } - double ReadDouble() - { - auto token = ReadToken(); - bool neg = false; - if (token.Content == '-') - { - neg = true; - token = ReadToken(); - } - if (token.Type == TokenType::DoubleLiteral || token.Type == TokenType::IntLiteral) - { - if (neg) - return -StringToDouble(token.Content); - else - return StringToDouble(token.Content); - } - throw TextFormatException("Text parsing error: floating point value expected."); - } - float ReadFloat() - { - return (float)ReadDouble(); - } - String ReadWord() - { - auto token = ReadToken(); - if (token.Type == TokenType::Identifier) - { - return token.Content; - } - throw TextFormatException("Text parsing error: identifier expected."); - } - String Read(const char * expectedStr) - { - auto token = ReadToken(); - if (token.Content == expectedStr) - { - return token.Content; - } - throw TextFormatException("Text parsing error: \'" + String(expectedStr) + "\' expected."); - } - String Read(String expectedStr) - { - auto token = ReadToken(); - if (token.Content == expectedStr) - { - return token.Content; - } - throw TextFormatException("Text parsing error: \'" + expectedStr + "\' expected."); - } - - String ReadStringLiteral() - { - auto token = ReadToken(); - if (token.Type == TokenType::StringLiteral) - { - return token.Content; - } - throw TextFormatException("Text parsing error: string literal expected."); - } - void Back(int count) - { - tokenPtr -= count; - } - Token ReadToken() - { - if (tokenPtr < (int)tokens.getCount()) - { - auto &rs = tokens[tokenPtr]; - tokenPtr++; - return rs; - } - throw TextFormatException("Unexpected ending."); - } - Token NextToken(int offset = 0) - { - if (tokenPtr + offset < (int)tokens.getCount()) - return tokens[tokenPtr + offset]; - else - { - Token rs; - rs.Type = TokenType::Unknown; - return rs; - } - } - bool LookAhead(String token) - { - if (tokenPtr < (int)tokens.getCount()) - { - auto next = NextToken(); - return next.Content == token; - } - else - { - return false; - } - } - bool IsEnd() - { - return tokenPtr == (int)tokens.getCount(); - } - public: - bool IsLegalText() - { - return legal; - } - }; - - inline List Split(String text, char c) - { - List result; - StringBuilder sb; - for (Index i = 0; i < text.getLength(); i++) - { - if (text[i] == c) - { - auto str = sb.ToString(); - if (str.getLength() != 0) - result.add(str); - sb.Clear(); - } - else - sb << text[i]; - } - auto lastStr = sb.ToString(); - if (lastStr.getLength()) - result.add(lastStr); - return result; - } -} - - -#endif diff --git a/source/core/type-traits.h b/source/core/type-traits.h deleted file mode 100644 index 804b4d3fe..000000000 --- a/source/core/type-traits.h +++ /dev/null @@ -1,46 +0,0 @@ -#ifndef CORELIB_TYPETRAITS_H -#define CORELIB_TYPETRAITS_H - -namespace Slang -{ - struct TraitResultYes - { - char x; - }; - struct TraitResultNo - { - char x[2]; - }; - - template - struct IsBaseOfTraitHost - { - operator B*() const { return nullptr; } - operator D*() { return nullptr; } - }; - - template - struct IsBaseOf - { - template - static TraitResultYes Check(D*, T) { return TraitResultYes(); } - static TraitResultNo Check(B*, int) { return TraitResultNo(); } - enum { Value = sizeof(Check(IsBaseOfTraitHost(), int())) == sizeof(TraitResultYes) }; - }; - - template - struct EnableIf {}; - - template - struct EnableIf { typedef T type; }; - - template - struct IsConvertible - { - static TraitResultYes Use(B) { return TraitResultYes(); }; - static TraitResultNo Use(...) { return TraitResultNo(); }; - enum { Value = sizeof(Use(*(D*)(nullptr))) == sizeof(TraitResultYes) }; - }; -} - -#endif diff --git a/source/slang/check.cpp b/source/slang/check.cpp deleted file mode 100644 index d51785112..000000000 --- a/source/slang/check.cpp +++ /dev/null @@ -1,11334 +0,0 @@ -#include "syntax-visitors.h" - -#include "lookup.h" -#include "compiler.h" -#include "visitor.h" - -#include "../core/secure-crt.h" -#include - -namespace Slang -{ - RefPtr getTypeType( - Type* type); - - /// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration? - bool isEffectivelyStatic( - Decl* decl, - ContainerDecl* parentDecl) - { - // Things at the global scope are always "members" of their module. - // - if(as(parentDecl)) - return false; - - // Anything explicitly marked `static` and not at module scope - // counts as a static rather than instance declaration. - // - if(decl->HasModifier()) - return true; - - // Next we need to deal with cases where a declaration is - // effectively `static` even if the language doesn't make - // the user say so. Most languages make the default assumption - // that nested types are `static` even if they don't say - // so (Java is an exception here, perhaps due to some - // influence from the Scandanavian OOP tradition of Beta/gbeta). - // - if(as(decl)) - return true; - if(as(decl)) - return true; - - // Things nested inside functions may have dependencies - // on values from the enclosing scope, but this needs to - // be dealt with via "capture" so they are also effectively - // `static` - // - if(as(parentDecl)) - return true; - - // Type constraint declarations are used in member-reference - // context as a form of casting operation, so we treat them - // as if they are instance members. This is a bit of a hack, - // but it achieves the result we want until we have an - // explicit representation of up-cast operations in the - // AST. - // - if(as(decl)) - return false; - - return false; - } - - /// Should the given `decl` be treated as a static rather than instance declaration? - bool isEffectivelyStatic( - Decl* decl) - { - // For the purposes of an ordinary declaration, when determining if - // it is static or per-instance, the "parent" declaration we really - // care about is the next outer non-generic declaration. - // - // TODO: This idiom of getting the "next outer non-generic declaration" - // comes up just enough that we should probably have a convenience - // function for it. - - auto parentDecl = decl->ParentDecl; - if(auto genericDecl = as(parentDecl)) - parentDecl = genericDecl->ParentDecl; - - return isEffectivelyStatic(decl, parentDecl); - } - - /// Is `decl` a global shader parameter declaration? - bool isGlobalShaderParameter(VarDeclBase* decl) - { - // A global shader parameter must be declared at global (module) scope. - // - if(!as(decl->ParentDecl)) return false; - - // A global variable marked `static` indicates a traditional - // global variable (albeit one that is implicitly local to - // the translation unit) - // - if(decl->HasModifier()) return false; - - // The `groupshared` modifier indicates that a variable cannot - // be a shader parameters, but is instead transient storage - // allocated for the duration of a thread-group's execution. - // - if(decl->HasModifier()) return false; - - return true; - } - - // A flat representation of basic types (scalars, vectors and matrices) - // that can be used as lookup key in caches - struct BasicTypeKey - { - union - { - struct - { - unsigned char type : 4; - unsigned char dim1 : 2; - unsigned char dim2 : 2; - } data; - unsigned char aggVal; - }; - bool fromType(Type* typeIn) - { - aggVal = 0; - if (auto basicType = as(typeIn)) - { - data.type = (unsigned char)basicType->baseType; - data.dim1 = data.dim2 = 0; - } - else if (auto vectorType = as(typeIn)) - { - if (auto elemCount = as(vectorType->elementCount)) - { - data.dim1 = elemCount->value - 1; - auto elementBasicType = as(vectorType->elementType); - data.type = (unsigned char)elementBasicType->baseType; - data.dim2 = 0; - } - else - return false; - } - else if (auto matrixType = as(typeIn)) - { - if (auto elemCount1 = as(matrixType->getRowCount())) - { - if (auto elemCount2 = as(matrixType->getColumnCount())) - { - auto elemBasicType = as(matrixType->getElementType()); - data.type = (unsigned char)elemBasicType->baseType; - data.dim1 = elemCount1->value - 1; - data.dim2 = elemCount2->value - 1; - } - } - else - return false; - } - else - return false; - return true; - } - }; - - struct BasicTypeKeyPair - { - BasicTypeKey type1, type2; - bool operator == (BasicTypeKeyPair p) - { - return type1.aggVal == p.type1.aggVal && type2.aggVal == p.type2.aggVal; - } - int GetHashCode() - { - return combineHash(type1.aggVal, type2.aggVal); - } - }; - - struct OverloadCandidate - { - enum class Flavor - { - Func, - Generic, - UnspecializedGeneric, - }; - Flavor flavor; - - enum class Status - { - GenericArgumentInferenceFailed, - Unchecked, - ArityChecked, - FixityChecked, - TypeChecked, - DirectionChecked, - Applicable, - }; - Status status = Status::Unchecked; - - // Reference to the declaration being applied - LookupResultItem item; - - // The type of the result expression if this candidate is selected - RefPtr resultType; - - // A system for tracking constraints introduced on generic parameters - // ConstraintSystem constraintSystem; - - // How much conversion cost should be considered for this overload, - // when ranking candidates. - ConversionCost conversionCostSum = kConversionCost_None; - - // When required, a candidate can store a pre-checked list of - // arguments so that we don't have to repeat work across checking - // phases. Currently this is only needed for generics. - RefPtr subst; - }; - - struct OperatorOverloadCacheKey - { - IROp operatorName; - BasicTypeKey args[2]; - bool operator == (OperatorOverloadCacheKey key) - { - return operatorName == key.operatorName && args[0].aggVal == key.args[0].aggVal - && args[1].aggVal == key.args[1].aggVal; - } - int GetHashCode() - { - return ((int)(UInt64)(void*)(operatorName) << 16) ^ (args[0].aggVal << 8) ^ (args[1].aggVal); - } - bool fromOperatorExpr(OperatorExpr* opExpr) - { - // First, lets see if the argument types are ones - // that we can encode in our space of keys. - args[0].aggVal = 0; - args[1].aggVal = 0; - if (opExpr->Arguments.getCount() > 2) - return false; - - for (Index i = 0; i < opExpr->Arguments.getCount(); i++) - { - if (!args[i].fromType(opExpr->Arguments[i]->type.Ptr())) - return false; - } - - // Next, lets see if we can find an intrinsic opcode - // attached to an overloaded definition (filtered for - // definitions that could conceivably apply to us). - // - // TODO: This should really be parsed on the operator name - // plus fixity, rather than the intrinsic opcode... - // - // We will need to reject postfix definitions for prefix - // operators, and vice versa, to ensure things work. - // - auto prefixExpr = as(opExpr); - auto postfixExpr = as(opExpr); - - if (auto overloadedBase = as(opExpr->FunctionExpr)) - { - for(auto item : overloadedBase->lookupResult2 ) - { - // Look at a candidate definition to be called and - // see if it gives us a key to work with. - // - Decl* funcDecl = overloadedBase->lookupResult2.item.declRef.decl; - if (auto genDecl = as(funcDecl)) - funcDecl = genDecl->inner.Ptr(); - - // Reject definitions that have the wrong fixity. - // - if(prefixExpr && !funcDecl->FindModifier()) - continue; - if(postfixExpr && !funcDecl->FindModifier()) - continue; - - if (auto intrinsicOp = funcDecl->FindModifier()) - { - operatorName = intrinsicOp->op; - return true; - } - } - } - return false; - } - }; - - struct TypeCheckingCache - { - Dictionary resolvedOperatorOverloadCache; - Dictionary conversionCostCache; - }; - - TypeCheckingCache* Session::getTypeCheckingCache() - { - if (!typeCheckingCache) - typeCheckingCache = new TypeCheckingCache(); - return typeCheckingCache; - } - - void Session::destroyTypeCheckingCache() - { - delete typeCheckingCache; - typeCheckingCache = nullptr; - } - - namespace { // anonymous - struct FunctionInfo - { - const char* name; - SharedLibraryType libraryType; - }; - } // anonymous - - static FunctionInfo _getFunctionInfo(Session::SharedLibraryFuncType funcType) - { - typedef Session::SharedLibraryFuncType FuncType; - typedef SharedLibraryType LibType; - - switch (funcType) - { - case FuncType::Glslang_Compile: return { "glslang_compile", LibType::Glslang } ; - case FuncType::Fxc_D3DCompile: return { "D3DCompile", LibType::Fxc }; - case FuncType::Fxc_D3DDisassemble: return { "D3DDisassemble", LibType::Fxc }; - case FuncType::Dxc_DxcCreateInstance: return { "DxcCreateInstance", LibType::Dxc }; - default: return { nullptr, LibType::Unknown }; - } - } - - ISlangSharedLibrary* Session::getOrLoadSharedLibrary(SharedLibraryType type, DiagnosticSink* sink) - { - // If not loaded, try loading it - if (!sharedLibraries[int(type)]) - { - // Try to preload dxil first, if loading dxc - if (type == SharedLibraryType::Dxc) - { - // Pass nullptr as the sink, because if it fails we don't want to report as error - getOrLoadSharedLibrary(SharedLibraryType::Dxil, nullptr); - } - - const char* libName = DefaultSharedLibraryLoader::getSharedLibraryNameFromType(type); - if (SLANG_FAILED(sharedLibraryLoader->loadSharedLibrary(libName, sharedLibraries[int(type)].writeRef()))) - { - if (sink) - { - sink->diagnose(SourceLoc(), Diagnostics::failedToLoadDynamicLibrary, libName); - } - return nullptr; - } - } - return sharedLibraries[int(type)]; - } - - SlangFuncPtr Session::getSharedLibraryFunc(SharedLibraryFuncType type, DiagnosticSink* sink) - { - if (sharedLibraryFunctions[int(type)]) - { - return sharedLibraryFunctions[int(type)]; - } - // do we have the library - FunctionInfo info = _getFunctionInfo(type); - if (info.name == nullptr) - { - return nullptr; - } - // Try loading the library - ISlangSharedLibrary* sharedLib = getOrLoadSharedLibrary(info.libraryType, sink); - if (!sharedLib) - { - return nullptr; - } - - // Okay now access the func - SlangFuncPtr func = sharedLib->findFuncByName(info.name); - if (!func) - { - const char* libName = DefaultSharedLibraryLoader::getSharedLibraryNameFromType(info.libraryType); - sink->diagnose(SourceLoc(), Diagnostics::failedToFindFunctionInSharedLibrary, info.name, libName); - return nullptr; - } - - // Store in the function cache - sharedLibraryFunctions[int(type)] = func; - return func; - } - - - enum class CheckingPhase - { - Header, Body - }; - - struct SemanticsVisitor - : ExprVisitor> - , StmtVisitor - , DeclVisitor - { - CheckingPhase checkingPhase = CheckingPhase::Header; - DeclCheckState getCheckedState() - { - if (checkingPhase == CheckingPhase::Body) - return DeclCheckState::Checked; - else - return DeclCheckState::CheckedHeader; - } - - Linkage* m_linkage = nullptr; - DiagnosticSink* m_sink = nullptr; - - DiagnosticSink* getSink() - { - return m_sink; - } - -// ModuleDecl * program = nullptr; - FuncDecl * function = nullptr; - - - // lexical outer statements - List outerStmts; - - // We need to track what has been `import`ed, - // to avoid importing the same thing more than once - // - // TODO: a smarter approach might be to filter - // out duplicate references during lookup. - HashSet importedModules; - - public: - SemanticsVisitor( - Linkage* linkage, - DiagnosticSink* sink) - : m_linkage(linkage) - , m_sink(sink) - {} - - Session* getSession() - { - return m_linkage->getSession(); - } - - public: - // Translate Types - RefPtr typeResult; - RefPtr TranslateTypeNodeImpl(const RefPtr & node) - { - if (!node) return nullptr; - - auto expr = CheckTerm(node); - expr = ExpectATypeRepr(expr); - return expr; - } - RefPtr ExtractTypeFromTypeRepr(const RefPtr& typeRepr) - { - if (!typeRepr) return nullptr; - if (auto typeType = as(typeRepr->type)) - { - return typeType->type; - } - return getSession()->getErrorType(); - } - RefPtr TranslateTypeNode(const RefPtr & node) - { - if (!node) return nullptr; - auto typeRepr = TranslateTypeNodeImpl(node); - return ExtractTypeFromTypeRepr(typeRepr); - } - TypeExp TranslateTypeNodeForced(TypeExp const& typeExp) - { - auto typeRepr = TranslateTypeNodeImpl(typeExp.exp); - - TypeExp result; - result.exp = typeRepr; - result.type = ExtractTypeFromTypeRepr(typeRepr); - return result; - } - TypeExp TranslateTypeNode(TypeExp const& typeExp) - { - // HACK(tfoley): It seems that in some cases we end up re-checking - // syntax that we've already checked. We need to root-cause that - // issue, but for now a quick fix in this case is to early - // exist if we've already got a type associated here: - if (typeExp.type) - { - return typeExp; - } - return TranslateTypeNodeForced(typeExp); - } - - RefPtr getExprDeclRefType(Expr * expr) - { - if (auto typetype = as(expr->type)) - return typetype->type.dynamicCast(); - else - return as(expr->type); - } - - /// Is `decl` usable as a static member? - bool isDeclUsableAsStaticMember( - Decl* decl) - { - if(decl->HasModifier()) - return true; - - if(as(decl)) - return true; - - if(as(decl)) - return true; - - if(as(decl)) - return true; - - if(as(decl)) - return true; - - return false; - } - - /// Is `item` usable as a static member? - bool isUsableAsStaticMember( - LookupResultItem const& item) - { - // There's a bit of a gotcha here, because a lookup result - // item might include "breadcrumbs" that indicate more steps - // along the lookup path. As a result it isn't always - // valid to just check whether the final decl is usable - // as a static member, because it might not even be a - // member of the thing we are trying to work with. - // - - Decl* decl = item.declRef.getDecl(); - for(auto bb = item.breadcrumbs; bb; bb = bb->next) - { - switch(bb->kind) - { - // In case lookup went through a `__transparent` member, - // we are interested in the static-ness of that transparent - // member, and *not* the static-ness of whatever was inside - // of it. - // - // TODO: This would need some work if we ever had - // transparent *type* members. - // - case LookupResultItem::Breadcrumb::Kind::Member: - decl = bb->declRef.getDecl(); - break; - - // TODO: Are there any other cases that need special-case - // handling here? - - default: - break; - } - } - - // Okay, we've found the declaration we should actually - // be checking, so lets validate that. - - return isDeclUsableAsStaticMember(decl); - } - - /// Move `expr` into a temporary variable and execute `func` on that variable. - /// - /// Returns an expression that wraps both the creation and initialization of - /// the temporary, and the computation created by `func`. - /// - template - RefPtr moveTemp(RefPtr const& expr, F const& func) - { - RefPtr varDecl = new VarDecl(); - varDecl->ParentDecl = nullptr; // TODO: need to fill this in somehow! - varDecl->checkState = DeclCheckState::Checked; - varDecl->nameAndLoc.loc = expr->loc; - varDecl->initExpr = expr; - varDecl->type.type = expr->type.type; - - auto varDeclRef = makeDeclRef(varDecl.Ptr()); - - RefPtr letExpr = new LetExpr(); - letExpr->decl = varDecl; - - auto body = func(varDeclRef); - - letExpr->body = body; - letExpr->type = body->type; - - return letExpr; - } - - /// Execute `func` on a variable with the value of `expr`. - /// - /// If `expr` is just a reference to an immutable (e.g., `let`) variable - /// then this might use the existing variable. Otherwise it will create - /// a new variable to hold `expr`, using `moveTemp()`. - /// - template - RefPtr maybeMoveTemp(RefPtr const& expr, F const& func) - { - if(auto varExpr = as(expr)) - { - auto declRef = varExpr->declRef; - if(auto varDeclRef = declRef.as()) - return func(varDeclRef); - } - - return moveTemp(expr, func); - } - - /// Return an expression that represents "opening" the existential `expr`. - /// - /// The type of `expr` must be an interface type, matching `interfaceDeclRef`. - /// - /// If we scope down the PL theory to just the case that Slang cares about, - /// a value of an existential type like `IMover` is a tuple of: - /// - /// * a concrete type `X` - /// * a witness `w` of the fact that `X` implements `IMover` - /// * a value `v` of type `X` - /// - /// "Opening" an existential value is the process of decomposing a single - /// value `e : IMover` into the pieces `X`, `w`, and `v`. - /// - /// Rather than return all those pieces individually, this operation - /// returns an expression that logically corresponds to `v`: an expression - /// of type `X`, where the type carries the knowledge that `X` implements `IMover`. - /// - RefPtr openExistential( - RefPtr expr, - DeclRef interfaceDeclRef) - { - // If `expr` refers to an immutable binding, - // then we can use it directly. If it refers - // to an arbitrary expression or a mutable - // binding, we will move its value into an - // immutable temporary so that we can use - // it directly. - // - auto interfaceDecl = interfaceDeclRef.getDecl(); - return maybeMoveTemp(expr, [&](DeclRef varDeclRef) - { - RefPtr openedType = new ExtractExistentialType(); - openedType->declRef = varDeclRef; - - RefPtr openedWitness = new ExtractExistentialSubtypeWitness(); - openedWitness->sub = openedType; - openedWitness->sup = expr->type.type; - openedWitness->declRef = varDeclRef; - - RefPtr openedThisType = new ThisTypeSubstitution(); - openedThisType->outer = interfaceDeclRef.substitutions.substitutions; - openedThisType->interfaceDecl = interfaceDecl; - openedThisType->witness = openedWitness; - - DeclRef substDeclRef = DeclRef(interfaceDecl, openedThisType); - auto substDeclRefType = DeclRefType::Create(getSession(), substDeclRef); - - RefPtr openedValue = new ExtractExistentialValueExpr(); - openedValue->declRef = varDeclRef; - openedValue->type = QualType(substDeclRefType); - - return openedValue; - }); - } - - /// If `expr` has existential type, then open it. - /// - /// Returns an expression that opens `expr` if it had existential type. - /// Otherwise returns `expr` itself. - /// - /// See `openExistential` for a discussion of what "opening" an - /// existential-type value means. - /// - RefPtr maybeOpenExistential(RefPtr expr) - { - auto exprType = expr->type.type; - - if(auto declRefType = as(exprType)) - { - if(auto interfaceDeclRef = declRefType->declRef.as()) - { - // Is there an this-type substitution being applied, so that - // we are referencing the interface type through a concrete - // type (e.g., a type parameter constrained to this interface)? - // - // Because of the way that substitutions need to mirror the nesting - // hierarchy of declarations, any this-type substitution pertaining - // to the chosen interface decl must be the first substitution on - // the list (which is a linked list from the "inside" out). - // - auto thisTypeSubst = interfaceDeclRef.substitutions.substitutions.as(); - if(thisTypeSubst && thisTypeSubst->interfaceDecl == interfaceDeclRef.decl) - { - // This isn't really an existential type, because somebody - // has already filled in a this-type substitution. - } - else - { - // Okay, here is the case that matters. - // - return openExistential(expr, interfaceDeclRef); - } - } - } - - // Default: apply the callback to the original expression; - return expr; - } - - RefPtr ConstructDeclRefExpr( - DeclRef declRef, - RefPtr baseExpr, - SourceLoc loc) - { - // Compute the type that this declaration reference will have in context. - // - auto type = GetTypeForDeclRef(declRef); - - // Construct an appropriate expression based on the structured of - // the declaration reference. - // - if (baseExpr) - { - // If there was a base expression, we will have some kind of - // member expression. - - // We want to check for the case where the base "expression" - // actually names a type, because in that case we are doing - // a static member reference. - // - // TODO: Should we be checking if the member is static here? - // If it isn't, should we be automatically producing a "curried" - // form (e.g., for a member function, return a value usable - // for referencing it as a free function). - // - if (as(baseExpr->type)) - { - auto expr = new StaticMemberExpr(); - expr->loc = loc; - expr->type = type; - expr->BaseExpression = baseExpr; - expr->name = declRef.GetName(); - expr->declRef = declRef; - return expr; - } - else if(isEffectivelyStatic(declRef.getDecl())) - { - // Extract the type of the baseExpr - auto baseExprType = baseExpr->type.type; - RefPtr baseTypeExpr = new SharedTypeExpr(); - baseTypeExpr->base.type = baseExprType; - baseTypeExpr->type.type = getTypeType(baseExprType); - - auto expr = new StaticMemberExpr(); - expr->loc = loc; - expr->type = type; - expr->BaseExpression = baseTypeExpr; - expr->name = declRef.GetName(); - expr->declRef = declRef; - return expr; - } - else - { - // If the base expression wasn't a type, then this - // is a normal member expression. - // - auto expr = new MemberExpr(); - expr->loc = loc; - expr->type = type; - expr->BaseExpression = baseExpr; - expr->name = declRef.GetName(); - expr->declRef = declRef; - - // When referring to a member through an expression, - // the result is only an l-value if both the base - // expression and the member agree that it should be. - // - // We have already used the `QualType` from the member - // above (that is `type`), so we need to take the - // l-value status of the base expression into account now. - if(!baseExpr->type.IsLeftValue) - { - expr->type.IsLeftValue = false; - } - - return expr; - } - } - else - { - // If there is no base expression, then the result must - // be an ordinary variable expression. - // - auto expr = new VarExpr(); - expr->loc = loc; - expr->name = declRef.GetName(); - expr->type = type; - expr->declRef = declRef; - return expr; - } - } - - RefPtr ConstructDerefExpr( - RefPtr base, - SourceLoc loc) - { - auto ptrLikeType = as(base->type); - SLANG_ASSERT(ptrLikeType); - - auto derefExpr = new DerefExpr(); - derefExpr->loc = loc; - derefExpr->base = base; - derefExpr->type = QualType(ptrLikeType->elementType); - - // TODO(tfoley): handle l-value status here - - return derefExpr; - } - - RefPtr createImplicitThisMemberExpr( - Type* type, - SourceLoc loc, - LookupResultItem::Breadcrumb::ThisParameterMode thisParameterMode) - { - RefPtr expr = new ThisExpr(); - expr->type = type; - expr->type.IsLeftValue = thisParameterMode == LookupResultItem::Breadcrumb::ThisParameterMode::Mutating; - expr->loc = loc; - return expr; - } - - RefPtr ConstructLookupResultExpr( - LookupResultItem const& item, - RefPtr baseExpr, - SourceLoc loc) - { - // If we collected any breadcrumbs, then these represent - // additional segments of the lookup path that we need - // to expand here. - auto bb = baseExpr; - for (auto breadcrumb = item.breadcrumbs; breadcrumb; breadcrumb = breadcrumb->next) - { - switch (breadcrumb->kind) - { - case LookupResultItem::Breadcrumb::Kind::Member: - bb = ConstructDeclRefExpr(breadcrumb->declRef, bb, loc); - break; - - case LookupResultItem::Breadcrumb::Kind::Deref: - bb = ConstructDerefExpr(bb, loc); - break; - - case LookupResultItem::Breadcrumb::Kind::Constraint: - { - // TODO: do we need to make something more - // explicit here? - bb = ConstructDeclRefExpr( - breadcrumb->declRef, - bb, - loc); - } - break; - - case LookupResultItem::Breadcrumb::Kind::This: - { - // We expect a `this` to always come - // at the start of a chain. - SLANG_ASSERT(bb == nullptr); - - // The member was looked up via a `this` expression, - // so we need to create one here. - if (auto extensionDeclRef = breadcrumb->declRef.as()) - { - bb = createImplicitThisMemberExpr( - GetTargetType(extensionDeclRef), - loc, - breadcrumb->thisParameterMode); - } - else - { - auto type = DeclRefType::Create(getSession(), breadcrumb->declRef); - bb = createImplicitThisMemberExpr( - type, - loc, - breadcrumb->thisParameterMode); - } - } - break; - - default: - SLANG_UNREACHABLE("all cases handle"); - } - } - - return ConstructDeclRefExpr(item.declRef, bb, loc); - } - - RefPtr createLookupResultExpr( - LookupResult const& lookupResult, - RefPtr baseExpr, - SourceLoc loc) - { - if (lookupResult.isOverloaded()) - { - auto overloadedExpr = new OverloadedExpr(); - overloadedExpr->loc = loc; - overloadedExpr->type = QualType( - getSession()->getOverloadedType()); - overloadedExpr->base = baseExpr; - overloadedExpr->lookupResult2 = lookupResult; - return overloadedExpr; - } - else - { - return ConstructLookupResultExpr(lookupResult.item, baseExpr, loc); - } - } - - RefPtr ResolveOverloadedExpr(RefPtr overloadedExpr, LookupMask mask) - { - auto lookupResult = overloadedExpr->lookupResult2; - SLANG_RELEASE_ASSERT(lookupResult.isValid() && lookupResult.isOverloaded()); - - // Take the lookup result we had, and refine it based on what is expected in context. - lookupResult = refineLookup(lookupResult, mask); - - if (!lookupResult.isValid()) - { - // If we didn't find any symbols after filtering, then just - // use the original and report errors that way - return overloadedExpr; - } - - if (lookupResult.isOverloaded()) - { - // We had an ambiguity anyway, so report it. - getSink()->diagnose(overloadedExpr, Diagnostics::ambiguousReference, lookupResult.items[0].declRef.GetName()); - - for(auto item : lookupResult.items) - { - String declString = getDeclSignatureString(item); - getSink()->diagnose(item.declRef, Diagnostics::overloadCandidate, declString); - } - - // TODO(tfoley): should we construct a new ErrorExpr here? - return CreateErrorExpr(overloadedExpr); - } - - // otherwise, we had a single decl and it was valid, hooray! - return ConstructLookupResultExpr(lookupResult.item, overloadedExpr->base, overloadedExpr->loc); - } - - RefPtr ExpectATypeRepr(RefPtr expr) - { - if (auto overloadedExpr = as(expr)) - { - expr = ResolveOverloadedExpr(overloadedExpr, LookupMask::type); - } - - if (auto typeType = as(expr->type)) - { - return expr; - } - else if (auto errorType = as(expr->type)) - { - return expr; - } - - getSink()->diagnose(expr, Diagnostics::unimplemented, "expected a type"); - return CreateErrorExpr(expr); - } - - RefPtr ExpectAType(RefPtr expr) - { - auto typeRepr = ExpectATypeRepr(expr); - if (auto typeType = as(typeRepr->type)) - { - return typeType->type; - } - return getSession()->getErrorType(); - } - - RefPtr ExtractGenericArgType(RefPtr exp) - { - return ExpectAType(exp); - } - - RefPtr ExtractGenericArgInteger(RefPtr exp) - { - return CheckIntegerConstantExpression(exp.Ptr()); - } - - RefPtr ExtractGenericArgVal(RefPtr exp) - { - if (auto overloadedExpr = as(exp)) - { - // assume that if it is overloaded, we want a type - exp = ResolveOverloadedExpr(overloadedExpr, LookupMask::type); - } - - if (auto typeType = as(exp->type)) - { - return typeType->type; - } - else if (auto errorType = as(exp->type)) - { - return exp->type.type; - } - else - { - return ExtractGenericArgInteger(exp); - } - } - - // Construct a type representing the instantiation of - // the given generic declaration for the given arguments. - // The arguments should already be checked against - // the declaration. - RefPtr InstantiateGenericType( - DeclRef genericDeclRef, - List> const& args) - { - RefPtr subst = new GenericSubstitution(); - subst->genericDecl = genericDeclRef.getDecl(); - subst->outer = genericDeclRef.substitutions.substitutions; - - for (auto argExpr : args) - { - subst->args.add(ExtractGenericArgVal(argExpr)); - } - - DeclRef innerDeclRef; - innerDeclRef.decl = GetInner(genericDeclRef); - innerDeclRef.substitutions = SubstitutionSet(subst); - - return DeclRefType::Create( - getSession(), - innerDeclRef); - } - - // This routine is a bottleneck for all declaration checking, - // so that we can add some quality-of-life features for users - // in cases where the compiler crashes - void dispatchDecl(DeclBase* decl) - { - try - { - DeclVisitor::dispatch(decl); - } - // Don't emit any context message for an explicit `AbortCompilationException` - // because it should only happen when an error is already emitted. - catch(AbortCompilationException&) { throw; } - catch(...) - { - getSink()->noteInternalErrorLoc(decl->loc); - throw; - } - } - void dispatchStmt(Stmt* stmt) - { - try - { - StmtVisitor::dispatch(stmt); - } - catch(AbortCompilationException&) { throw; } - catch(...) - { - getSink()->noteInternalErrorLoc(stmt->loc); - throw; - } - } - void dispatchExpr(Expr* expr) - { - try - { - ExprVisitor::dispatch(expr); - } - catch(AbortCompilationException&) { throw; } - catch(...) - { - getSink()->noteInternalErrorLoc(expr->loc); - throw; - } - } - - // Make sure a declaration has been checked, so we can refer to it. - // Note that this may lead to us recursively invoking checking, - // so this may not be the best way to handle things. - void EnsureDecl(RefPtr decl, DeclCheckState state) - { - if (decl->IsChecked(state)) return; - if (decl->checkState == DeclCheckState::CheckingHeader) - { - // We tried to reference the same declaration while checking it! - // - // TODO: we should ideally be tracking a "chain" of declarations - // being checked on the stack, so that we can report the full - // chain that leads from this declaration back to itself. - // - getSink()->diagnose(decl, Diagnostics::cyclicReference, decl); - return; - } - - // Hack: if we are somehow referencing a local variable declaration - // before the line of code that defines it, then we need to diagnose - // an error. - // - // TODO: The right answer is that lookup should have been performed in - // the scope that was in place *before* the variable was declared, but - // this is a quick fix that at least alerts the user to how we are - // interpreting their code. - // - if (auto varDecl = as(decl)) - { - if (auto parenScope = as(varDecl->ParentDecl)) - { - // TODO: This diagnostic should be emitted on the line that is referencing - // the declaration. That requires `EnsureDecl` to take the requesting - // location as a parameter. - getSink()->diagnose(decl, Diagnostics::localVariableUsedBeforeDeclared, decl); - return; - } - } - - if (DeclCheckState::CheckingHeader > decl->checkState) - { - decl->SetCheckState(DeclCheckState::CheckingHeader); - } - - // Check the modifiers on the declaration first, in case - // semantics of the body itself will depend on them. - checkModifiers(decl); - - // Use visitor pattern to dispatch to correct case - dispatchDecl(decl); - - if(state > decl->checkState) - { - decl->SetCheckState(state); - } - } - - void EnusreAllDeclsRec(RefPtr decl) - { - checkDecl(decl); - if (auto containerDecl = as(decl)) - { - for (auto m : containerDecl->Members) - { - EnusreAllDeclsRec(m); - } - } - } - - // A "proper" type is one that can be used as the type of an expression. - // Put simply, it can be a concrete type like `int`, or a generic - // type that is applied to arguments, like `Texture2D`. - // The type `void` is also a proper type, since we can have expressions - // that return a `void` result (e.g., many function calls). - // - // A "non-proper" type is any type that can't actually have values. - // A simple example of this in C++ is `std::vector` - you can't have - // a value of this type. - // - // Part of what this function does is give errors if somebody tries - // to use a non-proper type as the type of a variable (or anything - // else that needs a proper type). - // - // The other thing it handles is the fact that HLSL lets you use - // the name of a non-proper type, and then have the compiler fill - // in the default values for its type arguments (e.g., a variable - // given type `Texture2D` will actually have type `Texture2D`). - bool CoerceToProperTypeImpl( - TypeExp const& typeExp, - RefPtr* outProperType, - DiagnosticSink* diagSink) - { - Type* type = typeExp.type.Ptr(); - if(!type && typeExp.exp) - { - if(auto typeType = as(typeExp.exp->type)) - { - type = typeType->type; - } - } - - if (!type) - { - if (outProperType) - { - *outProperType = nullptr; - } - return false; - } - - if (auto genericDeclRefType = as(type)) - { - // We are using a reference to a generic declaration as a concrete - // type. This means we should substitute in any default parameter values - // if they are available. - // - // TODO(tfoley): A more expressive type system would substitute in - // "fresh" variables and then solve for their values... - // - - auto genericDeclRef = genericDeclRefType->GetDeclRef(); - checkDecl(genericDeclRef.decl); - List> args; - for (RefPtr member : genericDeclRef.getDecl()->Members) - { - if (auto typeParam = as(member)) - { - if (!typeParam->initType.exp) - { - if (diagSink) - { - diagSink->diagnose(typeExp.exp.Ptr(), Diagnostics::genericTypeNeedsArgs, typeExp); - *outProperType = getSession()->getErrorType(); - } - return false; - } - - // TODO: this is one place where syntax should get cloned! - if (outProperType) - args.add(typeParam->initType.exp); - } - else if (auto valParam = as(member)) - { - if (!valParam->initExpr) - { - if (diagSink) - { - diagSink->diagnose(typeExp.exp.Ptr(), Diagnostics::unimplemented, "can't fill in default for generic type parameter"); - *outProperType = getSession()->getErrorType(); - } - return false; - } - - // TODO: this is one place where syntax should get cloned! - if (outProperType) - args.add(valParam->initExpr); - } - else - { - // ignore non-parameter members - } - } - - if (outProperType) - { - *outProperType = InstantiateGenericType(genericDeclRef, args); - } - return true; - } - - // default case: we expect this to already be a proper type - if (outProperType) - { - *outProperType = type; - } - return true; - } - - - - TypeExp CoerceToProperType(TypeExp const& typeExp) - { - TypeExp result = typeExp; - CoerceToProperTypeImpl(typeExp, &result.type, getSink()); - return result; - } - - TypeExp tryCoerceToProperType(TypeExp const& typeExp) - { - TypeExp result = typeExp; - if(!CoerceToProperTypeImpl(typeExp, &result.type, nullptr)) - return TypeExp(); - return result; - } - - // Check a type, and coerce it to be proper - TypeExp CheckProperType(TypeExp typeExp) - { - return CoerceToProperType(TranslateTypeNode(typeExp)); - } - - // For our purposes, a "usable" type is one that can be - // used to declare a function parameter, variable, etc. - // These turn out to be all the proper types except - // `void`. - // - // TODO(tfoley): consider just allowing `void` as a - // simple example of a "unit" type, and get rid of - // this check. - TypeExp CoerceToUsableType(TypeExp const& typeExp) - { - TypeExp result = CoerceToProperType(typeExp); - Type* type = result.type.Ptr(); - if (auto basicType = as(type)) - { - // TODO: `void` shouldn't be a basic type, to make this easier to avoid - if (basicType->baseType == BaseType::Void) - { - // TODO(tfoley): pick the right diagnostic message - getSink()->diagnose(result.exp.Ptr(), Diagnostics::invalidTypeVoid); - result.type = getSession()->getErrorType(); - return result; - } - } - return result; - } - - // Check a type, and coerce it to be usable - TypeExp CheckUsableType(TypeExp typeExp) - { - return CoerceToUsableType(TranslateTypeNode(typeExp)); - } - - RefPtr CheckTerm(RefPtr term) - { - if (!term) return nullptr; - return ExprVisitor::dispatch(term); - } - - RefPtr CreateErrorExpr(Expr* expr) - { - expr->type = QualType(getSession()->getErrorType()); - return expr; - } - - bool IsErrorExpr(RefPtr expr) - { - // TODO: we may want other cases here... - - if (auto errorType = as(expr->type)) - return true; - - return false; - } - - // Capture the "base" expression in case this is a member reference - RefPtr GetBaseExpr(RefPtr expr) - { - if (auto memberExpr = as(expr)) - { - return memberExpr->BaseExpression; - } - else if(auto overloadedExpr = as(expr)) - { - return overloadedExpr->base; - } - return nullptr; - } - - public: - - bool ValuesAreEqual( - RefPtr left, - RefPtr right) - { - if(left == right) return true; - - if(auto leftConst = as(left)) - { - if(auto rightConst = as(right)) - { - return leftConst->value == rightConst->value; - } - } - - if(auto leftVar = as(left)) - { - if(auto rightVar = as(right)) - { - return leftVar->declRef.Equals(rightVar->declRef); - } - } - - return false; - } - - // Compute the cost of using a particular declaration to - // perform implicit type conversion. - ConversionCost getImplicitConversionCost( - Decl* decl) - { - if(auto modifier = decl->FindModifier()) - { - return modifier->cost; - } - - return kConversionCost_Explicit; - } - - bool isEffectivelyScalarForInitializerLists( - RefPtr type) - { - if(as(type)) return false; - if(as(type)) return false; - if(as(type)) return false; - - if(as(type)) - { - return true; - } - - if(as(type)) - { - return true; - } - if(as(type)) - { - return true; - } - if(as(type)) - { - return true; - } - - if(auto declRefType = as(type)) - { - if(as(declRefType->declRef)) - return false; - } - - return true; - } - - /// Should the provided expression (from an initializer list) be used directly to initialize `toType`? - bool shouldUseInitializerDirectly( - RefPtr toType, - RefPtr fromExpr) - { - // A nested initializer list should always be used directly. - // - if(as(fromExpr)) - { - return true; - } - - // If the desired type is a scalar, then we should always initialize - // directly, since it isn't an aggregate. - // - if(isEffectivelyScalarForInitializerLists(toType)) - return true; - - // If the type we are initializing isn't effectively scalar, - // but the initialization expression *is*, then it doesn't - // seem like direct initialization is intended. - // - if(isEffectivelyScalarForInitializerLists(fromExpr->type)) - return false; - - // Once the above cases are handled, the main thing - // we want to check for is whether a direct initialization - // is possible (a type conversion exists). - // - return canCoerce(toType, fromExpr->type); - } - - /// Read a value from an initializer list expression. - /// - /// This reads one or more argument from the initializer list - /// given as `fromInitializerListExpr` to initialize a value - /// of type `toType`. This may involve reading one or - /// more arguments from the initializer list, depending - /// on whether `toType` is an aggregate or not, and on - /// whether the next argument in the initializer list is - /// itself an initializer list. - /// - /// This routine returns `true` if it was able to read - /// arguments that can form a value of type `toType`, - /// and `false` otherwise. - /// - /// If the routine succeeds and `outToExpr` is non-null, - /// then it will be filled in with an expression - /// representing the value (or type `toType`) that was read, - /// or it will be left null to indicate that a default - /// value should be used. - /// - /// If the routine fails and `outToExpr` is non-null, - /// then a suitable diagnostic will be emitted. - /// - bool _readValueFromInitializerList( - RefPtr toType, - RefPtr* outToExpr, - RefPtr fromInitializerListExpr, - UInt &ioInitArgIndex) - { - // First, we will check if we have run out of arguments - // on the initializer list. - // - UInt initArgCount = fromInitializerListExpr->args.getCount(); - if(ioInitArgIndex >= initArgCount) - { - // If we are at the end of the initializer list, - // then our ability to read an argument depends - // on whether the type we are trying to read - // is default-initializable. - // - // For now, we will just pretend like everything - // is default-initializable and move along. - return true; - } - - // Okay, we have at least one initializer list expression, - // so we will look at the next expression and decide - // whether to use it to initialize the desired type - // directly (possibly via casts), or as the first sub-expression - // for aggregate initialization. - // - auto firstInitExpr = fromInitializerListExpr->args[ioInitArgIndex]; - if(shouldUseInitializerDirectly(toType, firstInitExpr)) - { - ioInitArgIndex++; - return _coerce( - toType, - outToExpr, - firstInitExpr->type, - firstInitExpr, - nullptr); - } - - // If there is somehow an error in one of the initialization - // expressions, then everything could be thrown off and we - // shouldn't keep trying to read arguments. - // - if( IsErrorExpr(firstInitExpr) ) - { - // Stop reading arguments, as if we'd reached - // the end of the list. - // - ioInitArgIndex = initArgCount; - return true; - } - - // The fallback case is to recursively read the - // type from the same list as an aggregate. - // - return _readAggregateValueFromInitializerList( - toType, - outToExpr, - fromInitializerListExpr, - ioInitArgIndex); - } - - /// Read an aggregate value from an initializer list expression. - /// - /// This reads one or more arguments from the initializer list - /// given as `fromInitializerListExpr` to initialize the - /// fields/elements of a value of type `toType`. - /// - /// This routine returns `true` if it was able to read - /// arguments that can form a value of type `toType`, - /// and `false` otherwise. - /// - /// If the routine succeeds and `outToExpr` is non-null, - /// then it will be filled in with an expression - /// representing the value (or type `toType`) that was read, - /// or it will be left null to indicate that a default - /// value should be used. - /// - /// If the routine fails and `outToExpr` is non-null, - /// then a suitable diagnostic will be emitted. - /// - bool _readAggregateValueFromInitializerList( - RefPtr inToType, - RefPtr* outToExpr, - RefPtr fromInitializerListExpr, - UInt &ioArgIndex) - { - auto toType = inToType; - UInt argCount = fromInitializerListExpr->args.getCount(); - - // In the case where we need to build a result expression, - // we will collect the new arguments here - List> coercedArgs; - - if(isEffectivelyScalarForInitializerLists(toType)) - { - // For any type that is effectively a non-aggregate, - // we expect to read a single value from the initializer list - // - if(ioArgIndex < argCount) - { - auto arg = fromInitializerListExpr->args[ioArgIndex++]; - return _coerce( - toType, - outToExpr, - arg->type, - arg, - nullptr); - } - else - { - // If there wasn't an initialization - // expression to be found, then we need - // to perform default initialization here. - // - // We will let this case come through the front-end - // as an `InitializerListExpr` with zero arguments, - // and then have the IR generation logic deal with - // synthesizing default values. - } - } - else if (auto toVecType = as(toType)) - { - auto toElementCount = toVecType->elementCount; - auto toElementType = toVecType->elementType; - - UInt elementCount = 0; - if (auto constElementCount = as(toElementCount)) - { - elementCount = (UInt) constElementCount->value; - } - else - { - // We don't know the element count statically, - // so what are we supposed to be doing? - // - if(outToExpr) - { - getSink()->diagnose(fromInitializerListExpr, Diagnostics::cannotUseInitializerListForVectorOfUnknownSize, toElementCount); - } - return false; - } - - for(UInt ee = 0; ee < elementCount; ++ee) - { - RefPtr coercedArg; - bool argResult = _readValueFromInitializerList( - toElementType, - outToExpr ? &coercedArg : nullptr, - fromInitializerListExpr, - ioArgIndex); - - // No point in trying further if any argument fails - if(!argResult) - return false; - - if( coercedArg ) - { - coercedArgs.add(coercedArg); - } - } - } - else if(auto toArrayType = as(toType)) - { - // TODO(tfoley): If we can compute the size of the array statically, - // then we want to check that there aren't too many initializers present - - auto toElementType = toArrayType->baseType; - - if(auto toElementCount = toArrayType->ArrayLength) - { - // In the case of a sized array, we need to check that the number - // of elements being initialized matches what was declared. - // - UInt elementCount = 0; - if (auto constElementCount = as(toElementCount)) - { - elementCount = (UInt) constElementCount->value; - } - else - { - // We don't know the element count statically, - // so what are we supposed to be doing? - // - if(outToExpr) - { - getSink()->diagnose(fromInitializerListExpr, Diagnostics::cannotUseInitializerListForArrayOfUnknownSize, toElementCount); - } - return false; - } - - for(UInt ee = 0; ee < elementCount; ++ee) - { - RefPtr coercedArg; - bool argResult = _readValueFromInitializerList( - toElementType, - outToExpr ? &coercedArg : nullptr, - fromInitializerListExpr, - ioArgIndex); - - // No point in trying further if any argument fails - if(!argResult) - return false; - - if( coercedArg ) - { - coercedArgs.add(coercedArg); - } - } - } - else - { - // In the case of an unsized array type, we will use the - // number of arguments to the initializer to determine - // the element count. - // - UInt elementCount = 0; - while(ioArgIndex < argCount) - { - RefPtr coercedArg; - bool argResult = _readValueFromInitializerList( - toElementType, - outToExpr ? &coercedArg : nullptr, - fromInitializerListExpr, - ioArgIndex); - - // No point in trying further if any argument fails - if(!argResult) - return false; - - elementCount++; - - if( coercedArg ) - { - coercedArgs.add(coercedArg); - } - } - - // We have a new type for the conversion, based on what - // we learned. - toType = getSession()->getArrayType( - toElementType, - new ConstantIntVal(elementCount)); - } - } - else if(auto toMatrixType = as(toType)) - { - // In the general case, the initializer list might comprise - // both vectors and scalars. - // - // The traditional HLSL compilers treat any vectors in - // the initializer list exactly equivalent to their sequence - // of scalar elements, and don't care how this might, or - // might not, align with the rows of the matrix. - // - // We will draw a line in the sand and say that an initializer - // list for a matrix will act as if the matrix type were an - // array of vectors for the rows. - - - UInt rowCount = 0; - auto toRowType = createVectorType( - toMatrixType->getElementType(), - toMatrixType->getColumnCount()); - - if (auto constRowCount = as(toMatrixType->getRowCount())) - { - rowCount = (UInt) constRowCount->value; - } - else - { - // We don't know the element count statically, - // so what are we supposed to be doing? - // - if(outToExpr) - { - getSink()->diagnose(fromInitializerListExpr, Diagnostics::cannotUseInitializerListForMatrixOfUnknownSize, toMatrixType->getRowCount()); - } - return false; - } - - for(UInt rr = 0; rr < rowCount; ++rr) - { - RefPtr coercedArg; - bool argResult = _readValueFromInitializerList( - toRowType, - outToExpr ? &coercedArg : nullptr, - fromInitializerListExpr, - ioArgIndex); - - // No point in trying further if any argument fails - if(!argResult) - return false; - - if( coercedArg ) - { - coercedArgs.add(coercedArg); - } - } - } - else if(auto toDeclRefType = as(toType)) - { - auto toTypeDeclRef = toDeclRefType->declRef; - if(auto toStructDeclRef = toTypeDeclRef.as()) - { - // Trying to initialize a `struct` type given an initializer list. - // We will go through the fields in order and try to match them - // up with initializer arguments. - // - for(auto fieldDeclRef : getMembersOfType(toStructDeclRef)) - { - RefPtr coercedArg; - bool argResult = _readValueFromInitializerList( - GetType(fieldDeclRef), - outToExpr ? &coercedArg : nullptr, - fromInitializerListExpr, - ioArgIndex); - - // No point in trying further if any argument fails - if(!argResult) - return false; - - if( coercedArg ) - { - coercedArgs.add(coercedArg); - } - } - } - } - else - { - // We shouldn't get to this case in practice, - // but just in case we'll consider an initializer - // list invalid if we are trying to read something - // off of it that wasn't handled by the cases above. - // - if(outToExpr) - { - getSink()->diagnose(fromInitializerListExpr, Diagnostics::cannotUseInitializerListForType, inToType); - } - return false; - } - - // We were able to coerce all the arguments given, and so - // we need to construct a suitable expression to remember the result - // - if(outToExpr) - { - auto toInitializerListExpr = new InitializerListExpr(); - toInitializerListExpr->loc = fromInitializerListExpr->loc; - toInitializerListExpr->type = QualType(toType); - toInitializerListExpr->args = coercedArgs; - - *outToExpr = toInitializerListExpr; - } - - return true; - } - - /// Coerce an initializer-list expression to a specific type. - /// - /// This reads one or more arguments from the initializer list - /// given as `fromInitializerListExpr` to initialize the - /// fields/elements of a value of type `toType`. - /// - /// This routine returns `true` if it was able to read - /// arguments that can form a value of type `toType`, - /// with no arguments left over, and `false` otherwise. - /// - /// If the routine succeeds and `outToExpr` is non-null, - /// then it will be filled in with an expression - /// representing the value (or type `toType`) that was read, - /// or it will be left null to indicate that a default - /// value should be used. - /// - /// If the routine fails and `outToExpr` is non-null, - /// then a suitable diagnostic will be emitted. - /// - bool _coerceInitializerList( - RefPtr toType, - RefPtr* outToExpr, - RefPtr fromInitializerListExpr) - { - UInt argCount = fromInitializerListExpr->args.getCount(); - UInt argIndex = 0; - - // TODO: we should handle the special case of `{0}` as an initializer - // for arbitrary `struct` types here. - - if(!_readAggregateValueFromInitializerList(toType, outToExpr, fromInitializerListExpr, argIndex)) - return false; - - if(argIndex != argCount) - { - if( outToExpr ) - { - getSink()->diagnose(fromInitializerListExpr, Diagnostics::tooManyInitializers, argIndex, argCount); - } - } - - return true; - } - - /// Report that implicit type coercion is not possible. - bool _failedCoercion( - RefPtr toType, - RefPtr* outToExpr, - RefPtr fromExpr) - { - if(outToExpr) - { - getSink()->diagnose(fromExpr->loc, Diagnostics::typeMismatch, toType, fromExpr->type); - } - return false; - } - - /// Central engine for implementing implicit coercion logic - /// - /// This function tries to find an implicit conversion path from - /// `fromType` to `toType`. It returns `true` if a conversion - /// is found, and `false` if not. - /// - /// If a conversion is found, then its cost will be written to `outCost`. - /// - /// If a `fromExpr` is provided, it must be of type `fromType`, - /// and represent a value to be converted. - /// - /// If `outToExpr` is non-null, and if a conversion is found, then - /// `*outToExpr` will be set to an expression that performs the - /// implicit conversion of `fromExpr` (which must be non-null - /// to `toType`). - /// - /// The case where `outToExpr` is non-null is used to identify - /// when a conversion is being done "for real" so that diagnostics - /// should be emitted on failure. - /// - bool _coerce( - RefPtr toType, - RefPtr* outToExpr, - RefPtr fromType, - RefPtr fromExpr, - ConversionCost* outCost) - { - // An important and easy case is when the "to" and "from" types are equal. - // - if( toType->Equals(fromType) ) - { - if(outToExpr) - *outToExpr = fromExpr; - if(outCost) - *outCost = kConversionCost_None; - return true; - } - - // Another important case is when either the "to" or "from" type - // represents an error. In such a case we must have already - // reporeted the error, so it is better to allow the conversion - // to pass than to report a "cascading" error that might not - // make any sense. - // - if(as(toType) || as(fromType)) - { - if(outToExpr) - *outToExpr = CreateImplicitCastExpr(toType, fromExpr); - if(outCost) - *outCost = kConversionCost_None; - return true; - } - - // Coercion from an initializer list is allowed for many types, - // so we will farm that out to its own subroutine. - // - if( auto fromInitializerListExpr = as(fromExpr)) - { - if( !_coerceInitializerList( - toType, - outToExpr, - fromInitializerListExpr) ) - { - return false; - } - - // For now, we treat coercion from an initializer list - // as having no cost, so that all conversions from initializer - // lists are equally valid. This is fine given where initializer - // lists are allowed to appear now, but might need to be made - // more strict if we allow for initializer lists in more - // places in the language (e.g., as function arguments). - // - if(outCost) - { - *outCost = kConversionCost_None; - } - - return true; - } - - // If we are casting to an interface type, then that will succeed - // if the "from" type conforms to the interface. - // - if (auto toDeclRefType = as(toType)) - { - auto toTypeDeclRef = toDeclRefType->declRef; - if (auto interfaceDeclRef = toTypeDeclRef.as()) - { - if(auto witness = tryGetInterfaceConformanceWitness(fromType, interfaceDeclRef)) - { - if (outToExpr) - *outToExpr = createCastToInterfaceExpr(toType, fromExpr, witness); - if (outCost) - *outCost = kConversionCost_CastToInterface; - return true; - } - } - } - - // We allow implicit conversion of a parameter group type like - // `ConstantBuffer` or `ParameterBlock` to its element - // type `X`. - // - if(auto fromParameterGroupType = as(fromType)) - { - auto fromElementType = fromParameterGroupType->getElementType(); - - // If we convert, e.g., `ConstantBuffer to `A`, we will allow - // subsequent conversion of `A` to `B` if such a conversion - // is possible. - // - ConversionCost subCost = kConversionCost_None; - - RefPtr derefExpr; - if(outToExpr) - { - derefExpr = new DerefExpr(); - derefExpr->base = fromExpr; - derefExpr->type = QualType(fromElementType); - } - - if(!_coerce( - toType, - outToExpr, - fromElementType, - derefExpr, - &subCost)) - { - return false; - } - - if(outCost) - *outCost = subCost + kConversionCost_ImplicitDereference; - return true; - } - - // The main general-purpose approach for conversion is - // using suitable marked initializer ("constructor") - // declarations on the target type. - // - // This is treated as a form of overload resolution, - // since we are effectively forming an overloaded - // call to one of the initializers in the target type. - - OverloadResolveContext overloadContext; - overloadContext.disallowNestedConversions = true; - overloadContext.argCount = 1; - overloadContext.argTypes = &fromType; - - overloadContext.originalExpr = nullptr; - if(fromExpr) - { - overloadContext.loc = fromExpr->loc; - overloadContext.funcLoc = fromExpr->loc; - overloadContext.args = &fromExpr; - } - - overloadContext.baseExpr = nullptr; - overloadContext.mode = OverloadResolveContext::Mode::JustTrying; - - AddTypeOverloadCandidates(toType, overloadContext, toType); - - // After all of the overload candidates have been added - // to the context and processed, we need to see whether - // there was one best overload or not. - // - if(overloadContext.bestCandidates.getCount() != 0) - { - // In this case there were multiple equally-good candidates to call. - // - // We will start by checking if the candidates are - // even applicable, because if not, then we shouldn't - // consider the conversion as possible. - // - if(overloadContext.bestCandidates[0].status != OverloadCandidate::Status::Applicable) - return _failedCoercion(toType, outToExpr, fromExpr); - - // If all of the candidates in `bestCandidates` are applicable, - // then we have an ambiguity. - // - // We will compute a nominal conversion cost as the minimum over - // all the conversions available. - // - ConversionCost bestCost = kConversionCost_Explicit; - for(auto candidate : overloadContext.bestCandidates) - { - ConversionCost candidateCost = getImplicitConversionCost( - candidate.item.declRef.getDecl()); - - if(candidateCost < bestCost) - bestCost = candidateCost; - } - - // Conceptually, we want to treat the conversion as - // possible, but report it as ambiguous if we actually - // need to reify the result as an expression. - // - if(outToExpr) - { - getSink()->diagnose(fromExpr, Diagnostics::ambiguousConversion, fromType, toType); - - *outToExpr = CreateErrorExpr(fromExpr); - } - - if(outCost) - *outCost = bestCost; - - return true; - } - else if(overloadContext.bestCandidate) - { - // If there is a single best candidate for conversion, - // then we want to use it. - // - // It is possible that there was a single best candidate, - // but it wasn't actually usable, so we will check for - // that case first. - // - if(overloadContext.bestCandidate->status != OverloadCandidate::Status::Applicable) - return _failedCoercion(toType, outToExpr, fromExpr); - - // Next, we need to look at the implicit conversion - // cost associated with the initializer we are invoking. - // - ConversionCost cost = getImplicitConversionCost( - overloadContext.bestCandidate->item.declRef.getDecl());; - - // If the cost is too high to be usable as an - // implicit conversion, then we will report the - // conversion as possible (so that an overload involving - // this conversion will be selected over one without), - // but then emit a diagnostic when actually reifying - // the result expression. - // - if( cost >= kConversionCost_Explicit ) - { - if( outToExpr ) - { - getSink()->diagnose(fromExpr, Diagnostics::typeMismatch, toType, fromType); - getSink()->diagnose(fromExpr, Diagnostics::noteExplicitConversionPossible, fromType, toType); - } - } - - if(outCost) - *outCost = cost; - - if(outToExpr) - { - // The logic here is a bit ugly, to deal with the fact that - // `CompleteOverloadCandidate` will, left to its own devices, - // construct a vanilla `InvokeExpr` to represent the call - // to the initializer we found, while we *want* it to - // create some variety of `ImplicitCastExpr`. - // - // Now, it just so happens that `CompleteOverloadCandidate` - // will use the "original" expression if one is available, - // so we'll create one and initialize it here. - // We fill in the location and arguments, but not the - // base expression (the callee), since that will come - // from the selected overload candidate. - // - auto castExpr = createImplicitCastExpr(); - castExpr->loc = fromExpr->loc; - castExpr->Arguments.add(fromExpr); - // - // Next we need to set our cast expression as the "original" - // expression and then complete the overload process. - // - overloadContext.originalExpr = castExpr; - *outToExpr = CompleteOverloadCandidate(overloadContext, *overloadContext.bestCandidate); - // - // However, the above isn't *quite* enough, because - // the process of completing the overload candidate - // might overwrite the argument list that was passed - // in to overload resolution, and in this case that - // "argument list" was just a pointer to `fromExpr`. - // - // That means we need to clear the argument list and - // reload it from `fromExpr` to make sure that we - // got the arguments *after* any transformations - // were applied. - // For right now this probably doesn't matter, - // because we don't allow nested implicit conversions, - // but I'd rather play it safe. - // - castExpr->Arguments.clear(); - castExpr->Arguments.add(fromExpr); - } - - return true; - } - - return _failedCoercion(toType, outToExpr, fromExpr); - } - - /// Check whether implicit type coercion from `fromType` to `toType` is possible. - /// - /// If conversion is possible, returns `true` and sets `outCost` to the cost - /// of the conversion found (if `outCost` is non-null). - /// - /// If conversion is not possible, returns `false`. - /// - bool canCoerce( - RefPtr toType, - RefPtr fromType, - ConversionCost* outCost = 0) - { - // As an optimization, we will maintain a cache of conversion results - // for basic types such as scalars and vectors. - // - BasicTypeKey key1, key2; - BasicTypeKeyPair cacheKey; - bool shouldAddToCache = false; - ConversionCost cost; - TypeCheckingCache* typeCheckingCache = getSession()->getTypeCheckingCache(); - if( key1.fromType(toType.Ptr()) && key2.fromType(fromType.Ptr()) ) - { - cacheKey.type1 = key1; - cacheKey.type2 = key2; - - if (typeCheckingCache->conversionCostCache.TryGetValue(cacheKey, cost)) - { - if (outCost) - *outCost = cost; - return cost != kConversionCost_Impossible; - } - else - shouldAddToCache = true; - } - - // If there was no suitable entry in the cache, - // then we fall back to the general-purpose - // conversion checking logic. - // - // Note that we are passing in `nullptr` as - // the output expression to be constructed, - // which suppresses emission of any diagnostics - // during the coercion process. - // - bool rs = _coerce( - toType, - nullptr, - fromType, - nullptr, - &cost); - - if (outCost) - *outCost = cost; - - if (shouldAddToCache) - { - if (!rs) - cost = kConversionCost_Impossible; - typeCheckingCache->conversionCostCache[cacheKey] = cost; - } - - return rs; - } - - RefPtr createImplicitCastExpr() - { - return new ImplicitCastExpr(); - } - - RefPtr CreateImplicitCastExpr( - RefPtr toType, - RefPtr fromExpr) - { - RefPtr castExpr = createImplicitCastExpr(); - - auto typeType = getTypeType(toType); - - auto typeExpr = new SharedTypeExpr(); - typeExpr->type.type = typeType; - typeExpr->base.type = toType; - - castExpr->loc = fromExpr->loc; - castExpr->FunctionExpr = typeExpr; - castExpr->type = QualType(toType); - castExpr->Arguments.add(fromExpr); - return castExpr; - } - - /// Create an "up-cast" from a value to an interface type - /// - /// This operation logically constructs an "existential" value, - /// which packages up the value, its type, and the witness - /// of its conformance to the interface. - /// - RefPtr createCastToInterfaceExpr( - RefPtr toType, - RefPtr fromExpr, - RefPtr witness) - { - RefPtr expr = new CastToInterfaceExpr(); - expr->loc = fromExpr->loc; - expr->type = QualType(toType); - expr->valueArg = fromExpr; - expr->witnessArg = witness; - return expr; - } - - /// Implicitly coerce `fromExpr` to `toType` and diagnose errors if it isn't possible - RefPtr coerce( - RefPtr toType, - RefPtr fromExpr) - { - RefPtr expr; - if (!_coerce( - toType, - &expr, - fromExpr->type.Ptr(), - fromExpr.Ptr(), - nullptr)) - { - // Note(tfoley): We don't call `CreateErrorExpr` here, because that would - // clobber the type on `fromExpr`, and an invariant here is that coercion - // really shouldn't *change* the expression that is passed in, but should - // introduce new AST nodes to coerce its value to a different type... - return CreateImplicitCastExpr( - getSession()->getErrorType(), - fromExpr); - } - return expr; - } - - void CheckVarDeclCommon(RefPtr varDecl) - { - // A variable that didn't have an explicit type written must - // have its type inferred from the initial-value expression. - // - if(!varDecl->type.exp) - { - // In this case we need to perform all checking of the - // variable (including semantic checking of the initial-value - // expression) during the first phase of checking. - - auto initExpr = varDecl->initExpr; - if(!initExpr) - { - getSink()->diagnose(varDecl, Diagnostics::varWithoutTypeMustHaveInitializer); - varDecl->type.type = getSession()->getErrorType(); - } - else - { - initExpr = CheckExpr(initExpr); - - // TODO: We might need some additional steps here to ensure - // that the type of the expression is one we are okay with - // inferring. E.g., if we ever decide that integer and floating-point - // literals have a distinct type from the standard int/float types, - // then we would need to "decay" a literal to an explicit type here. - - varDecl->initExpr = initExpr; - varDecl->type.type = initExpr->type; - } - - varDecl->SetCheckState(DeclCheckState::Checked); - } - else - { - if (function || checkingPhase == CheckingPhase::Header) - { - TypeExp typeExp = CheckUsableType(varDecl->type); - varDecl->type = typeExp; - if (varDecl->type.Equals(getSession()->getVoidType())) - { - getSink()->diagnose(varDecl, Diagnostics::invalidTypeVoid); - } - } - - if (checkingPhase == CheckingPhase::Body) - { - if (auto initExpr = varDecl->initExpr) - { - initExpr = CheckTerm(initExpr); - initExpr = coerce(varDecl->type.Ptr(), initExpr); - varDecl->initExpr = initExpr; - - // If this is an array variable, then we first want to give - // it a chance to infer an array size from its initializer - // - // TODO(tfoley): May need to extend this to handle the - // multi-dimensional case... - // - maybeInferArraySizeForVariable(varDecl); - // - // Next we want to make sure that the declared (or inferred) - // size for the array meets whatever language-specific - // constraints we want to enforce (e.g., disallow empty - // arrays in specific cases) - // - validateArraySizeForVariable(varDecl); - } - } - - } - varDecl->SetCheckState(getCheckedState()); - } - - // Fill in default substitutions for the 'subtype' part of a type constraint decl - void CheckConstraintSubType(TypeExp& typeExp) - { - if (auto sharedTypeExpr = as(typeExp.exp)) - { - if (auto declRefType = as(sharedTypeExpr->base)) - { - declRefType->declRef.substitutions = createDefaultSubstitutions(getSession(), declRefType->declRef.getDecl()); - - if (auto typetype = as(typeExp.exp->type)) - typetype->type = declRefType; - } - } - } - - void CheckGenericConstraintDecl(GenericTypeConstraintDecl* decl) - { - // TODO: are there any other validations we can do at this point? - // - // There probably needs to be a kind of "occurs check" to make - // sure that the constraint actually applies to at least one - // of the parameters of the generic. - if (decl->checkState == DeclCheckState::Unchecked) - { - decl->checkState = getCheckedState(); - CheckConstraintSubType(decl->sub); - decl->sub = TranslateTypeNodeForced(decl->sub); - decl->sup = TranslateTypeNodeForced(decl->sup); - } - } - - void checkDecl(Decl* decl) - { - EnsureDecl(decl, checkingPhase == CheckingPhase::Header ? DeclCheckState::CheckedHeader : DeclCheckState::Checked); - } - - void checkGenericDeclHeader(GenericDecl* genericDecl) - { - if (genericDecl->IsChecked(DeclCheckState::CheckedHeader)) - return; - // check the parameters - for (auto m : genericDecl->Members) - { - if (auto typeParam = as(m)) - { - typeParam->initType = CheckProperType(typeParam->initType); - } - else if (auto valParam = as(m)) - { - // TODO: some real checking here... - CheckVarDeclCommon(valParam); - } - else if (auto constraint = as(m)) - { - CheckGenericConstraintDecl(constraint); - } - } - - genericDecl->SetCheckState(DeclCheckState::CheckedHeader); - } - - void visitGenericDecl(GenericDecl* genericDecl) - { - checkGenericDeclHeader(genericDecl); - - // check the nested declaration - // TODO: this needs to be done in an appropriate environment... - checkDecl(genericDecl->inner); - genericDecl->SetCheckState(getCheckedState()); - } - - void visitGenericTypeConstraintDecl(GenericTypeConstraintDecl * genericConstraintDecl) - { - if (genericConstraintDecl->IsChecked(DeclCheckState::CheckedHeader)) - return; - // check the type being inherited from - auto base = genericConstraintDecl->sup; - base = TranslateTypeNode(base); - genericConstraintDecl->sup = base; - } - - void visitInheritanceDecl(InheritanceDecl* inheritanceDecl) - { - if (inheritanceDecl->IsChecked(DeclCheckState::CheckedHeader)) - return; - // check the type being inherited from - auto base = inheritanceDecl->base; - CheckConstraintSubType(base); - base = TranslateTypeNode(base); - inheritanceDecl->base = base; - - // For now we only allow inheritance from interfaces, so - // we will validate that the type expression names an interface - - if(auto declRefType = as(base.type)) - { - if(auto interfaceDeclRef = declRefType->declRef.as()) - { - return; - } - } - else if(base.type.is()) - { - // If an error was already produced, don't emit a cascading error. - return; - } - - // If type expression didn't name an interface, we'll emit an error here - // TODO: deal with the case of an error in the type expression (don't cascade) - getSink()->diagnose( base.exp, Diagnostics::expectedAnInterfaceGot, base.type); - } - - RefPtr checkConstantIntVal( - RefPtr expr) - { - // First type-check the expression as normal - expr = CheckExpr(expr); - - auto intVal = CheckIntegerConstantExpression(expr.Ptr()); - if(!intVal) - return nullptr; - - auto constIntVal = as(intVal); - if(!constIntVal) - { - getSink()->diagnose(expr->loc, Diagnostics::expectedIntegerConstantNotLiteral); - return nullptr; - } - return constIntVal; - } - - RefPtr checkConstantEnumVal( - RefPtr expr) - { - // First type-check the expression as normal - expr = CheckExpr(expr); - - auto intVal = CheckEnumConstantExpression(expr.Ptr()); - if(!intVal) - return nullptr; - - auto constIntVal = as(intVal); - if(!constIntVal) - { - getSink()->diagnose(expr->loc, Diagnostics::expectedIntegerConstantNotLiteral); - return nullptr; - } - return constIntVal; - } - - // Check an expression, coerce it to the `String` type, and then - // ensure that it has a literal (not just compile-time constant) value. - bool checkLiteralStringVal( - RefPtr expr, - String* outVal) - { - // TODO: This should actually perform semantic checking, etc., - // but for now we are just going to look for a direct string - // literal AST node. - - if(auto stringLitExpr = as(expr)) - { - if(outVal) - { - *outVal = stringLitExpr->value; - } - return true; - } - - getSink()->diagnose(expr, Diagnostics::expectedAStringLiteral); - - return false; - } - - void visitSyntaxDecl(SyntaxDecl*) - { - // These are only used in the stdlib, so no checking is needed - } - - void visitAttributeDecl(AttributeDecl*) - { - // These are only used in the stdlib, so no checking is needed - } - - void visitGenericTypeParamDecl(GenericTypeParamDecl*) - { - // These are only used in the stdlib, so no checking is needed for now - } - - void visitGenericValueParamDecl(GenericValueParamDecl*) - { - // These are only used in the stdlib, so no checking is needed for now - } - - void visitModifier(Modifier*) - { - // Do nothing with modifiers for now - } - - AttributeDecl* lookUpAttributeDecl(Name* attributeName, Scope* scope) - { - // Look up the name and see what we find. - // - // TODO: This needs to have some special filtering or naming - // rules to keep us from seeing shadowing variable declarations. - auto lookupResult = lookUp(getSession(), this, attributeName, scope, LookupMask::Attribute); - - // If the result was overloaded, - // then we aren't going to be able to extract a single decl. - if(lookupResult.isOverloaded()) - return nullptr; - - if (lookupResult.isValid()) - { - auto decl = lookupResult.item.declRef.getDecl(); - if (auto attributeDecl = as(decl)) - { - return attributeDecl; - } - else - { - return nullptr; - } - } - - // If we couldn't find a system attribute, try looking up as a user defined attribute - // A user defined attribute class is defined as a struct type with a "UserDefinedAttributeAttribute" modifier - lookupResult = lookUp(getSession(), this, getSession()->getNameObj(attributeName->text + "Attribute"), scope, LookupMask::type); - if (lookupResult.isOverloaded()) - { - // see if we have already created an AttributeDecl for this attribute struct - for (auto alt : lookupResult.items) - { - if (auto adecl = alt.declRef.as()) - return adecl.getDecl(); - } - } - // If we still cannot find any thing, quit - if (!lookupResult.isValid() || lookupResult.isOverloaded()) - return nullptr; - // Now construct an AttributeDecl for this user defined attribute class for future lookup - auto userDefAttribAttrib = lookupResult.item.declRef.decl->FindModifier(); - if (!userDefAttribAttrib) - return nullptr; - // create an AttributeDecl for the user defined attribute - auto structAttribDef = lookupResult.item.declRef.as().getDecl(); - RefPtr attribDecl = new AttributeDecl(); - attribDecl->nameAndLoc = structAttribDef->nameAndLoc; - attribDecl->loc = structAttribDef->loc; - attribDecl->nextInContainerWithSameName = structAttribDef->nextInContainerWithSameName; - // create a __attributeTarget modifier for the attribute class definition - RefPtr targetModifier = new AttributeTargetModifier(); - targetModifier->syntaxClass = userDefAttribAttrib->targetSyntaxClass; - targetModifier->loc = structAttribDef->loc; - targetModifier->next = attribDecl->modifiers.first; - attribDecl->modifiers.first = targetModifier; - structAttribDef->nextInContainerWithSameName = attribDecl.Ptr(); - // we should create UserDefinedAttribute nodes for all user defined attribute instances - attribDecl->syntaxClass = getSession()->findSyntaxClass(getSession()->getNameObj("UserDefinedAttribute")); - for (auto member : structAttribDef->Members) - { - if (auto varMember = as(member)) - { - RefPtr param = new ParamDecl(); - param->nameAndLoc = member->nameAndLoc; - param->type = varMember->type; - param->loc = member->loc; - attribDecl->Members.add(param); - } - } - // add the attribute class definition to the syntax tree, so it can be found - structAttribDef->ParentDecl->Members.add(attribDecl.Ptr()); - structAttribDef->ParentDecl->memberDictionaryIsValid = false; - // do necessary checks on this newly constructed node - checkDecl(attribDecl.Ptr()); - return attribDecl.Ptr(); - } - - bool hasIntArgs(Attribute* attr, int numArgs) - { - if (int(attr->args.getCount()) != numArgs) - { - return false; - } - for (int i = 0; i < numArgs; ++i) - { - if (!as(attr->args[i])) - { - return false; - } - } - return true; - } - - bool hasStringArgs(Attribute* attr, int numArgs) - { - if (int(attr->args.getCount()) != numArgs) - { - return false; - } - for (int i = 0; i < numArgs; ++i) - { - if (!as(attr->args[i])) - { - return false; - } - } - return true; - } - - bool getAttributeTargetSyntaxClasses(SyntaxClass & cls, uint32_t typeFlags) - { - if (typeFlags == (int)UserDefinedAttributeTargets::Struct) - { - cls = getSession()->findSyntaxClass(getSession()->getNameObj("StructDecl")); - return true; - } - if (typeFlags == (int)UserDefinedAttributeTargets::Var) - { - cls = getSession()->findSyntaxClass(getSession()->getNameObj("VarDecl")); - return true; - } - if (typeFlags == (int)UserDefinedAttributeTargets::Function) - { - cls = getSession()->findSyntaxClass(getSession()->getNameObj("FuncDecl")); - return true; - } - return false; - } - - bool validateAttribute(RefPtr attr, AttributeDecl* attribClassDecl) - { - if(auto numThreadsAttr = as(attr)) - { - SLANG_ASSERT(attr->args.getCount() == 3); - auto xVal = checkConstantIntVal(attr->args[0]); - auto yVal = checkConstantIntVal(attr->args[1]); - auto zVal = checkConstantIntVal(attr->args[2]); - - if(!xVal) return false; - if(!yVal) return false; - if(!zVal) return false; - - numThreadsAttr->x = (int32_t) xVal->value; - numThreadsAttr->y = (int32_t) yVal->value; - numThreadsAttr->z = (int32_t) zVal->value; - } - else if (auto bindingAttr = as(attr)) - { - // This must be vk::binding or gl::binding (as specified in core.meta.slang under vk_binding/gl_binding) - // Must have 2 int parameters. Ideally this would all be checked from the specification - // in core.meta.slang, but that's not completely implemented. So for now we check here. - if (attr->args.getCount() != 2) - { - return false; - } - - // TODO(JS): Prior validation currently doesn't ensure both args are ints (as specified in core.meta.slang), so check here - // to make sure they both are - auto binding = checkConstantIntVal(attr->args[0]); - auto set = checkConstantIntVal(attr->args[1]); - - if (binding == nullptr || set == nullptr) - { - return false; - } - - bindingAttr->binding = int32_t(binding->value); - bindingAttr->set = int32_t(set->value); - } - else if (auto maxVertexCountAttr = as(attr)) - { - SLANG_ASSERT(attr->args.getCount() == 1); - auto val = checkConstantIntVal(attr->args[0]); - - if(!val) return false; - - maxVertexCountAttr->value = (int32_t)val->value; - } - else if(auto instanceAttr = as(attr)) - { - SLANG_ASSERT(attr->args.getCount() == 1); - auto val = checkConstantIntVal(attr->args[0]); - - if(!val) return false; - - instanceAttr->value = (int32_t)val->value; - } - else if(auto entryPointAttr = as(attr)) - { - SLANG_ASSERT(attr->args.getCount() == 1); - - String stageName; - if(!checkLiteralStringVal(attr->args[0], &stageName)) - { - return false; - } - - auto stage = findStageByName(stageName); - if(stage == Stage::Unknown) - { - getSink()->diagnose(attr->args[0], Diagnostics::unknownStageName, stageName); - } - - entryPointAttr->stage = stage; - } - else if ((as(attr)) || - (as(attr)) || - (as(attr)) || - (as(attr)) || - (as(attr))) - { - // Let it go thru iff single string attribute - if (!hasStringArgs(attr, 1)) - { - getSink()->diagnose(attr, Diagnostics::expectedSingleStringArg, attr->name); - } - } - else if (as(attr)) - { - // Let it go thru iff single integral attribute - if (!hasIntArgs(attr, 1)) - { - getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->name); - } - } - else if (as(attr)) - { - // Has no args - SLANG_ASSERT(attr->args.getCount() == 0); - } - else if (as(attr)) - { - // Has no args - SLANG_ASSERT(attr->args.getCount() == 0); - } - else if (as(attr)) - { - // Has no args - SLANG_ASSERT(attr->args.getCount() == 0); - } - else if (auto attrUsageAttr = as(attr)) - { - uint32_t targetClassId = (uint32_t)UserDefinedAttributeTargets::None; - if (attr->args.getCount() == 1) - { - RefPtr outIntVal; - if (auto cInt = checkConstantEnumVal(attr->args[0])) - { - targetClassId = (uint32_t)(cInt->value); - } - else - { - getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->name); - return false; - } - } - if (!getAttributeTargetSyntaxClasses(attrUsageAttr->targetSyntaxClass, targetClassId)) - { - getSink()->diagnose(attr, Diagnostics::invalidAttributeTarget); - return false; - } - } - else if (auto unrollAttr = as(attr)) - { - // Check has an argument. We need this because default behavior is to give an error - // if an attribute has arguments, but not handled explicitly (and the default param will come through - // as 1 arg if nothing is specified) - SLANG_ASSERT(attr->args.getCount() == 1); - } - else if (auto userDefAttr = as(attr)) - { - // check arguments against attribute parameters defined in attribClassDecl - Index paramIndex = 0; - auto params = attribClassDecl->getMembersOfType(); - for (auto paramDecl : params) - { - if (paramIndex < attr->args.getCount()) - { - auto & arg = attr->args[paramIndex]; - bool typeChecked = false; - if (auto basicType = as(paramDecl->getType())) - { - if (basicType->baseType == BaseType::Int) - { - if (auto cint = checkConstantIntVal(arg)) - { - attr->intArgVals[(uint32_t)paramIndex] = cint; - } - typeChecked = true; - } - } - if (!typeChecked) - { - arg = CheckExpr(arg); - arg = coerce(paramDecl->getType(), arg); - } - } - paramIndex++; - } - if (params.getCount() < attr->args.getCount()) - { - getSink()->diagnose(attr, Diagnostics::tooManyArguments, attr->args.getCount(), params.getCount()); - } - else if (params.getCount() > attr->args.getCount()) - { - getSink()->diagnose(attr, Diagnostics::notEnoughArguments, attr->args.getCount(), params.getCount()); - } - } - else if (auto formatAttr = as(attr)) - { - SLANG_ASSERT(attr->args.getCount() == 1); - - String formatName; - if(!checkLiteralStringVal(attr->args[0], &formatName)) - { - return false; - } - - ImageFormat format = ImageFormat::unknown; - if(!findImageFormatByName(formatName.getBuffer(), &format)) - { - getSink()->diagnose(attr->args[0], Diagnostics::unknownImageFormatName, formatName); - } - - formatAttr->format = format; - } - else - { - if(attr->args.getCount() == 0) - { - // If the attribute took no arguments, then we will - // assume it is valid as written. - } - else - { - // We should be special-casing the checking of any attribute - // with a non-zero number of arguments. - SLANG_DIAGNOSE_UNEXPECTED(getSink(), attr, "unhandled attribute"); - return false; - } - } - - return true; - } - - RefPtr checkAttribute( - UncheckedAttribute* uncheckedAttr, - ModifiableSyntaxNode* attrTarget) - { - auto attrName = uncheckedAttr->getName(); - auto attrDecl = lookUpAttributeDecl( - attrName, - uncheckedAttr->scope); - - if(!attrDecl) - { - getSink()->diagnose(uncheckedAttr, Diagnostics::unknownAttributeName, attrName); - return uncheckedAttr; - } - - if(!attrDecl->syntaxClass.isSubClassOf()) - { - SLANG_DIAGNOSE_UNEXPECTED(getSink(), attrDecl, "attribute declaration does not reference an attribute class"); - return uncheckedAttr; - } - - // Manage scope - RefPtr attrInstance = attrDecl->syntaxClass.createInstance(); - auto attr = attrInstance.as(); - if(!attr) - { - SLANG_DIAGNOSE_UNEXPECTED(getSink(), attrDecl, "attribute class did not yield an attribute object"); - return uncheckedAttr; - } - - // We are going to replace the unchecked attribute with the checked one. - - // First copy all of the state over from the original attribute. - attr->name = uncheckedAttr->name; - attr->args = uncheckedAttr->args; - attr->loc = uncheckedAttr->loc; - - // We will start with checking steps that can be applied independent - // of the concrete attribute type that was selected. These only need - // us to look at the attribute declaration itself. - // - // Start by doing argument/parameter matching - UInt argCount = attr->args.getCount(); - UInt paramCounter = 0; - bool mismatch = false; - for(auto paramDecl : attrDecl->getMembersOfType()) - { - UInt paramIndex = paramCounter++; - if( paramIndex < argCount ) - { - // TODO: support checking the argument against the declared - // type for the parameter. - } - else - { - // We didn't have enough arguments for the - // number of parameters declared. - if(auto defaultArg = paramDecl->initExpr) - { - // The attribute declaration provided a default, - // so we should use that. - // - // TODO: we need to figure out how to hook up - // default arguments as needed. - // For now just copy the expression over. - - attr->args.add(paramDecl->initExpr); - } - else - { - mismatch = true; - } - } - } - UInt paramCount = paramCounter; - - if(mismatch) - { - getSink()->diagnose(attr, Diagnostics::attributeArgumentCountMismatch, attrName, paramCount, argCount); - return uncheckedAttr; - } - - // The next bit of validation that we can apply semi-generically - // is to validate that the target for this attribute is a valid - // one for the chosen attribute. - // - // The attribute declaration will have one or more `AttributeTargetModifier`s - // that each specify a syntax class that the attribute can be applied to. - // If any of these match `attrTarget`, then we are good. - // - bool validTarget = false; - for(auto attrTargetMod : attrDecl->GetModifiersOfType()) - { - if(attrTarget->getClass().isSubClassOf(attrTargetMod->syntaxClass)) - { - validTarget = true; - break; - } - } - if(!validTarget) - { - getSink()->diagnose(attr, Diagnostics::attributeNotApplicable, attrName); - return uncheckedAttr; - } - - // Now apply type-specific validation to the attribute. - if(!validateAttribute(attr, attrDecl)) - { - return uncheckedAttr; - } - - - return attr; - } - - RefPtr checkModifier( - RefPtr m, - ModifiableSyntaxNode* syntaxNode) - { - if(auto hlslUncheckedAttribute = as(m)) - { - // We have an HLSL `[name(arg,...)]` attribute, and we'd like - // to check that it is provides all the expected arguments - // - // First, look up the attribute name in the current scope to find - // the right syntax class to instantiate. - // - - return checkAttribute(hlslUncheckedAttribute, syntaxNode); - } - // Default behavior is to leave things as they are, - // and assume that modifiers are mostly already checked. - // - // TODO: This would be a good place to validate that - // a modifier is actually valid for the thing it is - // being applied to, and potentially to check that - // it isn't in conflict with any other modifiers - // on the same declaration. - - return m; - } - - - void checkModifiers(ModifiableSyntaxNode* syntaxNode) - { - // TODO(tfoley): need to make sure this only - // performs semantic checks on a `SharedModifier` once... - - // The process of checking a modifier may produce a new modifier in its place, - // so we will build up a new linked list of modifiers that will replace - // the old list. - RefPtr resultModifiers; - RefPtr* resultModifierLink = &resultModifiers; - - RefPtr modifier = syntaxNode->modifiers.first; - while(modifier) - { - // Because we are rewriting the list in place, we need to extract - // the next modifier here (not at the end of the loop). - auto next = modifier->next; - - // We also go ahead and clobber the `next` field on the modifier - // itself, so that the default behavior of `checkModifier()` can - // be to return a single unlinked modifier. - modifier->next = nullptr; - - auto checkedModifier = checkModifier(modifier, syntaxNode); - if(checkedModifier) - { - // If checking gave us a modifier to add, then we - // had better add it. - - // Just in case `checkModifier` ever returns multiple - // modifiers, lets advance to the end of the list we - // are building. - while(*resultModifierLink) - resultModifierLink = &(*resultModifierLink)->next; - - // attach the new modifier at the end of the list, - // and now set the "link" to point to its `next` field - *resultModifierLink = checkedModifier; - resultModifierLink = &checkedModifier->next; - } - - // Move along to the next modifier - modifier = next; - } - - // Whether we actually re-wrote anything or note, lets - // install the new list of modifiers on the declaration - syntaxNode->modifiers.first = resultModifiers; - } - - /// Perform checking of interface conformaces for this decl and all its children - void checkInterfaceConformancesRec(Decl* decl) - { - // Any user-defined type may have declared interface conformances, - // which we should check. - // - if( auto aggTypeDecl = as(decl) ) - { - checkAggTypeConformance(aggTypeDecl); - } - // Conformances can also come via `extension` declarations, and - // we should check them against the type(s) being extended. - // - else if(auto extensionDecl = as(decl)) - { - checkExtensionConformance(extensionDecl); - } - - // We need to handle the recursive cases here, the first - // of which is a generic decl, where we want to recurivsely - // check the inner declaration. - // - if(auto genericDecl = as(decl)) - { - checkInterfaceConformancesRec(genericDecl->inner); - } - // For any other kind of container declaration, we will - // recurse into all of its member declarations, so that - // we can handle, e.g., nested `struct` types. - // - else if(auto containerDecl = as(decl)) - { - for(auto member : containerDecl->Members) - { - checkInterfaceConformancesRec(member); - } - } - } - - void visitModuleDecl(ModuleDecl* programNode) - { - // Try to register all the builtin decls - for (auto decl : programNode->Members) - { - auto inner = decl; - if (auto genericDecl = as(decl)) - { - inner = genericDecl->inner; - } - - if (auto builtinMod = inner->FindModifier()) - { - registerBuiltinDecl(getSession(), decl, builtinMod); - } - if (auto magicMod = inner->FindModifier()) - { - registerMagicDecl(getSession(), decl, magicMod); - } - } - - // We need/want to visit any `import` declarations before - // anything else, to make sure that scoping works. - for(auto& importDecl : programNode->getMembersOfType()) - { - checkDecl(importDecl); - } - // register all extensions - for (auto & s : programNode->getMembersOfType()) - registerExtension(s); - for (auto & g : programNode->getMembersOfType()) - { - if (auto extDecl = as(g->inner)) - { - checkGenericDeclHeader(g); - registerExtension(extDecl); - } - } - // check user defined attribute classes first - for (auto decl : programNode->Members) - { - if (auto typeMember = as(decl)) - { - bool isTypeAttributeClass = false; - for (auto attrib : typeMember->GetModifiersOfType()) - { - if (attrib->name == getSession()->getNameObj("AttributeUsageAttribute")) - { - isTypeAttributeClass = true; - break; - } - } - if (isTypeAttributeClass) - checkDecl(decl); - } - } - // check types - for (auto & s : programNode->getMembersOfType()) - checkDecl(s.Ptr()); - - for (int pass = 0; pass < 2; pass++) - { - checkingPhase = pass == 0 ? CheckingPhase::Header : CheckingPhase::Body; - - for (auto & s : programNode->getMembersOfType()) - { - checkDecl(s.Ptr()); - } - // HACK(tfoley): Visiting all generic declarations here, - // because otherwise they won't get visited. - for (auto & g : programNode->getMembersOfType()) - { - checkDecl(g.Ptr()); - } - - // before checking conformance, make sure we check all the extension bodies - // generic extension decls are already checked by the loop above - for (auto & s : programNode->getMembersOfType()) - checkDecl(s); - - for (auto & func : programNode->getMembersOfType()) - { - if (!func->IsChecked(getCheckedState())) - { - VisitFunctionDeclaration(func.Ptr()); - } - } - for (auto & func : programNode->getMembersOfType()) - { - checkDecl(func); - } - - if (getSink()->GetErrorCount() != 0) - return; - - // Force everything to be fully checked, just in case - // Note that we don't just call this on the program, - // because we'd end up recursing into this very code path... - for (auto d : programNode->Members) - { - EnusreAllDeclsRec(d); - } - - if (pass == 0) - { - checkInterfaceConformancesRec(programNode); - } - } - } - - bool doesSignatureMatchRequirement( - DeclRef satisfyingMemberDeclRef, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable) - { - if(satisfyingMemberDeclRef.getDecl()->HasModifier() - && !requiredMemberDeclRef.getDecl()->HasModifier()) - { - // A `[mutating]` method can't satisfy a non-`[mutating]` requirement, - // but vice-versa is okay. - return false; - } - - if(satisfyingMemberDeclRef.getDecl()->HasModifier() - != requiredMemberDeclRef.getDecl()->HasModifier()) - { - // A `static` method can't satisfy a non-`static` requirement and vice versa. - return false; - } - - // TODO: actually implement matching here. For now we'll - // just pretend that things are satisfied in order to make progress.. - witnessTable->requirementDictionary.Add( - requiredMemberDeclRef.getDecl(), - RequirementWitness(satisfyingMemberDeclRef)); - return true; - } - - bool doesGenericSignatureMatchRequirement( - DeclRef genDecl, - DeclRef requirementGenDecl, - RefPtr witnessTable) - { - if (genDecl.getDecl()->Members.getCount() != requirementGenDecl.getDecl()->Members.getCount()) - return false; - for (Index i = 0; i < genDecl.getDecl()->Members.getCount(); i++) - { - auto genMbr = genDecl.getDecl()->Members[i]; - auto requiredGenMbr = genDecl.getDecl()->Members[i]; - if (auto genTypeMbr = as(genMbr)) - { - if (auto requiredGenTypeMbr = as(requiredGenMbr)) - { - } - else - return false; - } - else if (auto genValMbr = as(genMbr)) - { - if (auto requiredGenValMbr = as(requiredGenMbr)) - { - if (!genValMbr->type->Equals(requiredGenValMbr->type)) - return false; - } - else - return false; - } - else if (auto genTypeConstraintMbr = as(genMbr)) - { - if (auto requiredTypeConstraintMbr = as(requiredGenMbr)) - { - if (!genTypeConstraintMbr->sup->Equals(requiredTypeConstraintMbr->sup)) - { - return false; - } - } - else - return false; - } - } - - // TODO: this isn't right, because we need to specialize the - // declarations of the generics to a common set of substitutions, - // so that their types are comparable (e.g., foo and foo - // need to have substitutions applies so that they are both foo, - // after which uses of the type X in their parameter lists can - // be compared). - - return doesMemberSatisfyRequirement( - DeclRef(genDecl.getDecl()->inner.Ptr(), genDecl.substitutions), - DeclRef(requirementGenDecl.getDecl()->inner.Ptr(), requirementGenDecl.substitutions), - witnessTable); - } - - bool doesTypeSatisfyAssociatedTypeRequirement( - RefPtr satisfyingType, - DeclRef requiredAssociatedTypeDeclRef, - RefPtr witnessTable) - { - // We need to confirm that the chosen type `satisfyingType`, - // meets all the constraints placed on the associated type - // requirement `requiredAssociatedTypeDeclRef`. - // - // We will enumerate the type constraints placed on the - // associated type and see if they can be satisfied. - // - bool conformance = true; - for (auto requiredConstraintDeclRef : getMembersOfType(requiredAssociatedTypeDeclRef)) - { - // Grab the type we expect to conform to from the constraint. - auto requiredSuperType = GetSup(requiredConstraintDeclRef); - - // Perform a search for a witness to the subtype relationship. - auto witness = tryGetSubtypeWitness(satisfyingType, requiredSuperType); - if(witness) - { - // If a subtype witness was found, then the conformance - // appears to hold, and we can satisfy that requirement. - witnessTable->requirementDictionary.Add(requiredConstraintDeclRef, RequirementWitness(witness)); - } - else - { - // If a witness couldn't be found, then the conformance - // seems like it will fail. - conformance = false; - } - } - - // TODO: if any conformance check failed, we should probably include - // that in an error message produced about not satisfying the requirement. - - if(conformance) - { - // If all the constraints were satisfied, then the chosen - // type can indeed satisfy the interface requirement. - witnessTable->requirementDictionary.Add( - requiredAssociatedTypeDeclRef.getDecl(), - RequirementWitness(satisfyingType)); - } - - return conformance; - } - - // Does the given `memberDecl` work as an implementation - // to satisfy the requirement `requiredMemberDeclRef` - // from an interface? - // - // If it does, then inserts a witness into `witnessTable` - // and returns `true`, otherwise returns `false` - bool doesMemberSatisfyRequirement( - DeclRef memberDeclRef, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable) - { - // At a high level, we want to check that the - // `memberDecl` and the `requiredMemberDeclRef` - // have the same AST node class, and then also - // check that their signatures match. - // - // There are a bunch of detailed decisions that - // have to be made, though, because we might, e.g., - // allow a function with more general parameter - // types to satisfy a requirement with more - // specific parameter types. - // - // If we ever allow for "property" declarations, - // then we would probably need to allow an - // ordinary field to satisfy a property requirement. - // - // An associated type requirement should be allowed - // to be satisfied by any type declaration: - // a typedef, a `struct`, etc. - // - if (auto memberFuncDecl = memberDeclRef.as()) - { - if (auto requiredFuncDeclRef = requiredMemberDeclRef.as()) - { - // Check signature match. - return doesSignatureMatchRequirement( - memberFuncDecl, - requiredFuncDeclRef, - witnessTable); - } - } - else if (auto memberInitDecl = memberDeclRef.as()) - { - if (auto requiredInitDecl = requiredMemberDeclRef.as()) - { - // Check signature match. - return doesSignatureMatchRequirement( - memberInitDecl, - requiredInitDecl, - witnessTable); - } - } - else if (auto genDecl = memberDeclRef.as()) - { - // For a generic member, we will check if it can satisfy - // a generic requirement in the interface. - // - // TODO: we could also conceivably check that the generic - // could be *specialized* to satisfy the requirement, - // and then install a specialization of the generic into - // the witness table. Actually doing this would seem - // to require performing something akin to overload - // resolution as part of requirement satisfaction. - // - if (auto requiredGenDeclRef = requiredMemberDeclRef.as()) - { - return doesGenericSignatureMatchRequirement(genDecl, requiredGenDeclRef, witnessTable); - } - } - else if (auto subAggTypeDeclRef = memberDeclRef.as()) - { - if(auto requiredTypeDeclRef = requiredMemberDeclRef.as()) - { - checkDecl(subAggTypeDeclRef.getDecl()); - - auto satisfyingType = DeclRefType::Create(getSession(), subAggTypeDeclRef); - return doesTypeSatisfyAssociatedTypeRequirement(satisfyingType, requiredTypeDeclRef, witnessTable); - } - } - else if (auto typedefDeclRef = memberDeclRef.as()) - { - // this is a type-def decl in an aggregate type - // check if the specified type satisfies the constraints defined by the associated type - if (auto requiredTypeDeclRef = requiredMemberDeclRef.as()) - { - checkDecl(typedefDeclRef.getDecl()); - - auto satisfyingType = getNamedType(getSession(), typedefDeclRef); - return doesTypeSatisfyAssociatedTypeRequirement(satisfyingType, requiredTypeDeclRef, witnessTable); - } - } - // Default: just assume that thing aren't being satisfied. - return false; - } - - // State used while checking if a declaration (either a type declaration - // or an extension of that type) conforms to the interfaces it claims - // via its inheritance clauses. - // - struct ConformanceCheckingContext - { - Dictionary, RefPtr> mapInterfaceToWitnessTable; - }; - - // Find the appropriate member of a declared type to - // satisfy a requirement of an interface the type - // claims to conform to. - // - // The type declaration `typeDecl` has declared that it - // conforms to the interface `interfaceDeclRef`, and - // `requiredMemberDeclRef` is a required member of - // the interface. - // - // If a satisfying value is found, registers it in - // `witnessTable` and returns `true`, otherwise - // returns `false`. - // - bool findWitnessForInterfaceRequirement( - ConformanceCheckingContext* context, - DeclRef typeDeclRef, - InheritanceDecl* inheritanceDecl, - DeclRef interfaceDeclRef, - DeclRef requiredMemberDeclRef, - RefPtr witnessTable) - { - // The goal of this function is to find a suitable - // value to satisfy the requirement. - // - // The 99% case is that the requirement is a named member - // of the interface, and we need to search for a member - // with the same name in the type declaration and - // its (known) extensions. - - // As a first pass, lets check if we already have a - // witness in the table for the requirement, so - // that we can bail out early. - // - if(witnessTable->requirementDictionary.ContainsKey(requiredMemberDeclRef.getDecl())) - { - return true; - } - - - // An important exception to the above is that an - // inheritance declaration in the interface is not going - // to be satisfied by an inheritance declaration in the - // conforming type, but rather by a full "witness table" - // full of the satisfying values for each requirement - // in the inherited-from interface. - // - if( auto requiredInheritanceDeclRef = requiredMemberDeclRef.as() ) - { - // Recursively check that the type conforms - // to the inherited interface. - // - // TODO: we *really* need a linearization step here!!!! - - RefPtr satisfyingWitnessTable = checkConformanceToType( - context, - typeDeclRef, - requiredInheritanceDeclRef.getDecl(), - getBaseType(requiredInheritanceDeclRef)); - - if(!satisfyingWitnessTable) - return false; - - witnessTable->requirementDictionary.Add( - requiredInheritanceDeclRef.getDecl(), - RequirementWitness(satisfyingWitnessTable)); - return true; - } - - // We will look up members with the same name, - // since only same-name members will be able to - // satisfy the requirement. - // - // TODO: this won't work right now for members that - // don't have names, which right now includes - // initializers/constructors. - Name* name = requiredMemberDeclRef.GetName(); - - // We are basically looking up members of the - // given type, but we need to be a bit careful. - // We *cannot* perfom lookup "through" inheritance - // declarations for this or other interfaces, - // since that would let us satisfy a requirement - // with itself. - // - // There's also an interesting question of whether - // we can/should support innterface requirements - // being satisfied via `__transparent` members. - // This seems like a "clever" idea rather than - // a useful one, and IR generation would - // need to construct real IR to trampoline over - // to the implementation. - // - // The final case that can't be reduced to just - // "a directly declared member with the same name" - // is the case where the type inherits a member - // that can satisfy the requirement from a base type. - // We are ignoring implementation inheritance for - // now, so we won't worry about this. - - // Make sure that by-name lookup is possible. - buildMemberDictionary(typeDeclRef.getDecl()); - auto lookupResult = lookUpLocal(getSession(), this, name, typeDeclRef); - - if (!lookupResult.isValid()) - { - getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, typeDeclRef, requiredMemberDeclRef); - return false; - } - - // Iterate over the members and look for one that matches - // the expected signature for the requirement. - for (auto member : lookupResult) - { - if (doesMemberSatisfyRequirement(member.declRef, requiredMemberDeclRef, witnessTable)) - return true; - } - - // No suitable member found, although there were candidates. - // - // TODO: Eventually we might want something akin to the current - // overload resolution logic, where we keep track of a list - // of "candidates" for satisfaction of the requirement, - // and if nothing is found we print the candidates - - getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, typeDeclRef, requiredMemberDeclRef); - return false; - } - - // Check that the type declaration `typeDecl`, which - // declares conformance to the interface `interfaceDeclRef`, - // (via the given `inheritanceDecl`) actually provides - // members to satisfy all the requirements in the interface. - RefPtr checkInterfaceConformance( - ConformanceCheckingContext* context, - DeclRef typeDeclRef, - InheritanceDecl* inheritanceDecl, - DeclRef interfaceDeclRef) - { - // Has somebody already checked this conformance, - // and/or is in the middle of checking it? - RefPtr witnessTable; - if(context->mapInterfaceToWitnessTable.TryGetValue(interfaceDeclRef, witnessTable)) - return witnessTable; - - // We need to check the declaration of the interface - // before we can check that we conform to it. - checkDecl(interfaceDeclRef.getDecl()); - - // We will construct the witness table, and register it - // *before* we go about checking fine-grained requirements, - // in order to short-circuit any potential for infinite recursion. - - // Note: we will re-use the witnes table attached to the inheritance decl, - // if there is one. This catches cases where semantic checking might - // have synthesized some of the conformance witnesses for us. - // - witnessTable = inheritanceDecl->witnessTable; - if(!witnessTable) - { - witnessTable = new WitnessTable(); - } - context->mapInterfaceToWitnessTable.Add(interfaceDeclRef, witnessTable); - - bool result = true; - - // TODO: If we ever allow for implementation inheritance, - // then we will need to consider the case where a type - // declares that it conforms to an interface, but one of - // its (non-interface) base types already conforms to - // that interface, so that all of the requirements are - // already satisfied with inherited implementations... - for(auto requiredMemberDeclRef : getMembers(interfaceDeclRef)) - { - auto requirementSatisfied = findWitnessForInterfaceRequirement( - context, - typeDeclRef, - inheritanceDecl, - interfaceDeclRef, - requiredMemberDeclRef, - witnessTable); - - result = result && requirementSatisfied; - } - - // Extensions that apply to the interface type can create new conformances - // for the concrete types that inherit from the interface. - // - // These new conformances should not be able to introduce new *requirements* - // for an implementing interface (although they currently can), but we - // still need to go through this logic to find the appropriate value - // that will satisfy the requirement in these cases, and also to put - // the required entry into the witness table for the interface itself. - // - // TODO: This logic is a bit slippery, and we need to figure out what - // it means in the context of separate compilation. If module A defines - // an interface IA, module B defines a type C that conforms to IA, and then - // module C defines an extension that makes IA conform to IC, then it is - // unreasonable to expect the {B:IA} witness table to contain an entry - // corresponding to {IA:IC}. - // - // The simple answer then would be that the {IA:IC} conformance should be - // fixed, with a single witness table for {IA:IC}, but then what should - // happen in B explicitly conformed to IC already? - // - // For now we will just walk through the extensions that are known at - // the time we are compiling and handle those, and punt on the larger issue - // for abit longer. - for(auto candidateExt = interfaceDeclRef.getDecl()->candidateExtensions; candidateExt; candidateExt = candidateExt->nextCandidateExtension) - { - // We need to apply the extension to the interface type that our - // concrete type is inheriting from. - // - // TODO: need to decide if a this-type substitution is needed here. - // It probably it. - RefPtr targetType = DeclRefType::Create( - getSession(), - interfaceDeclRef); - auto extDeclRef = ApplyExtensionToType(candidateExt, targetType); - if(!extDeclRef) - continue; - - // Only inheritance clauses from the extension matter right now. - for(auto requiredInheritanceDeclRef : getMembersOfType(extDeclRef)) - { - auto requirementSatisfied = findWitnessForInterfaceRequirement( - context, - typeDeclRef, - inheritanceDecl, - interfaceDeclRef, - requiredInheritanceDeclRef, - witnessTable); - - result = result && requirementSatisfied; - } - } - - // If we failed to satisfy any requirements along the way, - // then we don't actually want to keep the witness table - // we've been constructing, because the whole thing was a failure. - if(!result) - { - return nullptr; - } - - return witnessTable; - } - - RefPtr checkConformanceToType( - ConformanceCheckingContext* context, - DeclRef typeDeclRef, - InheritanceDecl* inheritanceDecl, - Type* baseType) - { - if (auto baseDeclRefType = as(baseType)) - { - auto baseTypeDeclRef = baseDeclRefType->declRef; - if (auto baseInterfaceDeclRef = baseTypeDeclRef.as()) - { - // The type is stating that it conforms to an interface. - // We need to check that it provides all of the members - // required by that interface. - return checkInterfaceConformance( - context, - typeDeclRef, - inheritanceDecl, - baseInterfaceDeclRef); - } - } - - getSink()->diagnose(inheritanceDecl, Diagnostics::unimplemented, "type not supported for inheritance"); - return nullptr; - } - - // Check that the type (or extension) declaration `declRef`, - // which declares that it inherits from another type via - // `inheritanceDecl` actually does what it needs to - // for that inheritance to be valid. - bool checkConformance( - DeclRef declRef, - InheritanceDecl* inheritanceDecl) - { - declRef = createDefaultSubstitutionsIfNeeded(getSession(), declRef).as(); - - // Don't check conformances for abstract types that - // are being used to express *required* conformances. - if (auto assocTypeDeclRef = declRef.as()) - { - // An associated type declaration represents a requirement - // in an outer interface declaration, and its members - // (type constraints) represent additional requirements. - return true; - } - else if (auto interfaceDeclRef = declRef.as()) - { - // HACK: Our semantics as they stand today are that an - // `extension` of an interface that adds a new inheritance - // clause acts *as if* that inheritnace clause had been - // attached to the original `interface` decl: that is, - // it adds additional requirements. - // - // This is *not* a reasonable semantic to keep long-term, - // but it is required for some of our current example - // code to work. - return true; - } - - - // Look at the type being inherited from, and validate - // appropriately. - auto baseType = inheritanceDecl->base.type; - - ConformanceCheckingContext context; - RefPtr witnessTable = checkConformanceToType(&context, declRef, inheritanceDecl, baseType); - if(!witnessTable) - return false; - - inheritanceDecl->witnessTable = witnessTable; - return true; - } - - void checkExtensionConformance(ExtensionDecl* decl) - { - if (auto targetDeclRefType = as(decl->targetType)) - { - if (auto aggTypeDeclRef = targetDeclRefType->declRef.as()) - { - for (auto inheritanceDecl : decl->getMembersOfType()) - { - checkConformance(aggTypeDeclRef, inheritanceDecl); - } - } - } - } - - void checkAggTypeConformance(AggTypeDecl* decl) - { - // After we've checked members, we need to go through - // any inheritance clauses on the type itself, and - // confirm that the type actually provides whatever - // those clauses require. - - if (auto interfaceDecl = as(decl)) - { - // Don't check that an interface conforms to the - // things it inherits from. - } - else if (auto assocTypeDecl = as(decl)) - { - // Don't check that an associated type decl conforms to the - // things it inherits from. - } - else - { - // For non-interface types we need to check conformance. - // - // TODO: Need to figure out what this should do for - // `abstract` types if we ever add them. Should they - // be required to implement all interface requirements, - // just with `abstract` methods that replicate things? - // (That's what C# does). - for (auto inheritanceDecl : decl->getMembersOfType()) - { - checkConformance(makeDeclRef(decl), inheritanceDecl); - } - } - } - - void visitAggTypeDecl(AggTypeDecl* decl) - { - if (decl->IsChecked(getCheckedState())) - return; - - // TODO: we should check inheritance declarations - // first, since they need to be validated before - // we can make use of the type (e.g., you need - // to know that `A` inherits from `B` in order - // to check an expression like `aValue.bMethod()` - // where `aValue` is of type `A` but `bMethod` - // is defined in type `B`. - // - // TODO: We should also add a pass that takes - // all the stated inheritance relationships, - // expands them to include implicit inheritance, - // and then linearizes them. This would allow - // later passes that need to know everything - // a type inherits from to proceed linearly - // through the list, rather than having to - // recurse (and potentially see the same interface - // more than once). - - decl->SetCheckState(DeclCheckState::CheckedHeader); - - // Now check all of the member declarations. - for (auto member : decl->Members) - { - checkDecl(member); - } - decl->SetCheckState(getCheckedState()); - } - - bool isIntegerBaseType(BaseType baseType) - { - switch(baseType) - { - default: - return false; - - case BaseType::Int8: - case BaseType::Int16: - case BaseType::Int: - case BaseType::Int64: - case BaseType::UInt8: - case BaseType::UInt16: - case BaseType::UInt: - case BaseType::UInt64: - return true; - } - } - - // Validate that `type` is a suitable type to use - // as the tag type for an `enum` - void validateEnumTagType(Type* type, SourceLoc const& loc) - { - if(auto basicType = as(type)) - { - // Allow the built-in integer types. - if(isIntegerBaseType(basicType->baseType)) - return; - - // By default, don't allow other types to be used - // as an `enum` tag type. - } - - getSink()->diagnose(loc, Diagnostics::invalidEnumTagType, type); - } - - void visitEnumDecl(EnumDecl* decl) - { - if (decl->IsChecked(getCheckedState())) - return; - - // We need to be careful to avoid recursion in the - // type-checking logic. We will do the minimal work - // to make the type usable in the first phase, and - // then check the actual cases in the second phase. - // - if(this->checkingPhase == CheckingPhase::Header) - { - // Look at inheritance clauses, and - // see if one of them is making the enum - // "inherit" from a concrete type. - // This will become the "tag" type - // of the enum. - RefPtr tagType; - InheritanceDecl* tagTypeInheritanceDecl = nullptr; - for(auto inheritanceDecl : decl->getMembersOfType()) - { - checkDecl(inheritanceDecl); - - // Look at the type being inherited from. - auto superType = inheritanceDecl->base.type; - - if(auto errorType = as(superType)) - { - // Ignore any erroneous inheritance clauses. - continue; - } - else if(auto declRefType = as(superType)) - { - if(auto interfaceDeclRef = declRefType->declRef.as()) - { - // Don't consider interface bases as candidates for - // the tag type. - continue; - } - } - - if(tagType) - { - // We already found a tag type. - getSink()->diagnose(inheritanceDecl, Diagnostics::enumTypeAlreadyHasTagType); - getSink()->diagnose(tagTypeInheritanceDecl, Diagnostics::seePreviousTagType); - break; - } - else - { - tagType = superType; - tagTypeInheritanceDecl = inheritanceDecl; - } - } - - // If a tag type has not been set, then we - // default it to the built-in `int` type. - // - // TODO: In the far-flung future we may want to distinguish - // `enum` types that have a "raw representation" like this from - // ones that are purely abstract and don't expose their - // type of their tag. - if(!tagType) - { - tagType = getSession()->getIntType(); - } - else - { - // TODO: Need to establish that the tag - // type is suitable. (e.g., if we are going - // to allow raw values for case tags to be - // derived automatically, then the tag - // type needs to be some kind of integer type...) - // - // For now we will just be harsh and require it - // to be one of a few builtin types. - validateEnumTagType(tagType, tagTypeInheritanceDecl->loc); - } - decl->tagType = tagType; - - - // An `enum` type should automatically conform to the `__EnumType` interface. - // The compiler needs to insert this conformance behind the scenes, and this - // seems like the best place to do it. - { - // First, look up the type of the `__EnumType` interface. - RefPtr enumTypeType = getSession()->getEnumTypeType(); - - RefPtr enumConformanceDecl = new InheritanceDecl(); - enumConformanceDecl->ParentDecl = decl; - enumConformanceDecl->loc = decl->loc; - enumConformanceDecl->base.type = getSession()->getEnumTypeType(); - decl->Members.add(enumConformanceDecl); - - // The `__EnumType` interface has one required member, the `__Tag` type. - // We need to satisfy this requirement automatically, rather than require - // the user to actually declare a member with this name (otherwise we wouldn't - // let them define a tag value with the name `__Tag`). - // - RefPtr witnessTable = new WitnessTable(); - enumConformanceDecl->witnessTable = witnessTable; - - Name* tagAssociatedTypeName = getSession()->getNameObj("__Tag"); - Decl* tagAssociatedTypeDecl = nullptr; - if(auto enumTypeTypeDeclRefType = enumTypeType.dynamicCast()) - { - if(auto enumTypeTypeInterfaceDecl = as(enumTypeTypeDeclRefType->declRef.getDecl())) - { - for(auto memberDecl : enumTypeTypeInterfaceDecl->Members) - { - if(memberDecl->getName() == tagAssociatedTypeName) - { - tagAssociatedTypeDecl = memberDecl; - break; - } - } - } - } - if(!tagAssociatedTypeDecl) - { - SLANG_DIAGNOSE_UNEXPECTED(getSink(), decl, "failed to find built-in declaration '__Tag'"); - } - - // Okay, add the conformance witness for `__Tag` being satisfied by `tagType` - witnessTable->requirementDictionary.Add(tagAssociatedTypeDecl, RequirementWitness(tagType)); - - // TODO: we actually also need to synthesize a witness for the conformance of `tagType` - // to the `__BuiltinIntegerType` interface, because that is a constraint on the - // associated type `__Tag`. - - // TODO: eventually we should consider synthesizing other requirements for - // the min/max tag values, or the total number of tags, so that people don't - // have to declare these as additional cases. - - enumConformanceDecl->SetCheckState(DeclCheckState::Checked); - } - } - else if( checkingPhase == CheckingPhase::Body ) - { - auto enumType = DeclRefType::Create( - getSession(), - makeDeclRef(decl)); - - auto tagType = decl->tagType; - - // Check the enum cases in order. - for(auto caseDecl : decl->getMembersOfType()) - { - // Each case defines a value of the enum's type. - // - // TODO: If we ever support enum cases with payloads, - // then they would probably have a type that is a - // `FunctionType` from the payload types to the - // enum type. - // - caseDecl->type.type = enumType; - - checkDecl(caseDecl); - } - - // For any enum case that didn't provide an explicit - // tag value, derived an appropriate tag value. - IntegerLiteralValue defaultTag = 0; - for(auto caseDecl : decl->getMembersOfType()) - { - if(auto explicitTagValExpr = caseDecl->tagExpr) - { - // This tag has an initializer, so it should establish - // the tag value for a successor case that doesn't - // provide an explicit tag. - - RefPtr explicitTagVal = TryConstantFoldExpr(explicitTagValExpr); - if(explicitTagVal) - { - if(auto constIntVal = as(explicitTagVal)) - { - defaultTag = constIntVal->value; - } - else - { - // TODO: need to handle other possibilities here - getSink()->diagnose(explicitTagValExpr, Diagnostics::unexpectedEnumTagExpr); - } - } - else - { - // If this happens, then the explicit tag value expression - // doesn't seem to be a constant after all. In this case - // we expect the checking logic to have applied already. - } - } - else - { - // This tag has no initializer, so it should use - // the default tag value we are tracking. - RefPtr tagValExpr = new IntegerLiteralExpr(); - tagValExpr->loc = caseDecl->loc; - tagValExpr->type = QualType(tagType); - tagValExpr->value = defaultTag; - - caseDecl->tagExpr = tagValExpr; - } - - // Default tag for the next case will be one more than - // for the most recent case. - // - // TODO: We might consider adding a `[flags]` attribute - // that modifies this behavior to be `defaultTagForCase <<= 1`. - // - defaultTag++; - } - - // Now check any other member declarations. - for(auto memberDecl : decl->Members) - { - // Already checked inheritance declarations above. - if(auto inheritanceDecl = as(memberDecl)) - continue; - - // Already checked enum case declarations above. - if(auto caseDecl = as(memberDecl)) - continue; - - // TODO: Right now we don't support other kinds of - // member declarations on an `enum`, but that is - // something we may want to allow in the long run. - // - checkDecl(memberDecl); - } - } - decl->SetCheckState(getCheckedState()); - } - - void visitEnumCaseDecl(EnumCaseDecl* decl) - { - if (decl->IsChecked(getCheckedState())) - return; - - if(checkingPhase == CheckingPhase::Body) - { - // An enum case had better appear inside an enum! - // - // TODO: Do we need/want to support generic cases some day? - auto parentEnumDecl = as(decl->ParentDecl); - SLANG_ASSERT(parentEnumDecl); - - // The tag type should have already been set by - // the surrounding `enum` declaration. - auto tagType = parentEnumDecl->tagType; - SLANG_ASSERT(tagType); - - // Need to check the init expression, if present, since - // that represents the explicit tag for this case. - if(auto initExpr = decl->tagExpr) - { - initExpr = CheckExpr(initExpr); - initExpr = coerce(tagType, initExpr); - - // We want to enforce that this is an integer constant - // expression, but we don't actually care to retain - // the value. - CheckIntegerConstantExpression(initExpr); - - decl->tagExpr = initExpr; - } - } - - decl->SetCheckState(getCheckedState()); - } - - void visitDeclGroup(DeclGroup* declGroup) - { - for (auto decl : declGroup->decls) - { - dispatchDecl(decl); - } - } - - void visitTypeDefDecl(TypeDefDecl* decl) - { - if (decl->IsChecked(getCheckedState())) return; - if (checkingPhase == CheckingPhase::Header) - { - decl->type = CheckProperType(decl->type); - } - decl->SetCheckState(getCheckedState()); - } - - void visitGlobalGenericParamDecl(GlobalGenericParamDecl* decl) - { - if (decl->IsChecked(getCheckedState())) return; - if (checkingPhase == CheckingPhase::Header) - { - decl->SetCheckState(DeclCheckState::CheckedHeader); - // global generic param only allowed in global scope - auto program = as(decl->ParentDecl); - if (!program) - getSink()->diagnose(decl, Slang::Diagnostics::globalGenParamInGlobalScopeOnly); - // Now check all of the member declarations. - for (auto member : decl->Members) - { - checkDecl(member); - } - } - decl->SetCheckState(getCheckedState()); - } - - void visitAssocTypeDecl(AssocTypeDecl* decl) - { - if (decl->IsChecked(getCheckedState())) return; - if (checkingPhase == CheckingPhase::Header) - { - decl->SetCheckState(DeclCheckState::CheckedHeader); - - // assoctype only allowed in an interface - auto interfaceDecl = as(decl->ParentDecl); - if (!interfaceDecl) - getSink()->diagnose(decl, Slang::Diagnostics::assocTypeInInterfaceOnly); - - // Now check all of the member declarations. - for (auto member : decl->Members) - { - checkDecl(member); - } - } - decl->SetCheckState(getCheckedState()); - } - - void checkStmt(Stmt* stmt) - { - if (!stmt) return; - dispatchStmt(stmt); - checkModifiers(stmt); - } - - void visitFuncDecl(FuncDecl* functionNode) - { - if (functionNode->IsChecked(getCheckedState())) - return; - - if (checkingPhase == CheckingPhase::Header) - { - VisitFunctionDeclaration(functionNode); - } - // TODO: This should really only set "checked header" - functionNode->SetCheckState(getCheckedState()); - - if (checkingPhase == CheckingPhase::Body) - { - // TODO: should put the checking of the body onto a "work list" - // to avoid recursion here. - if (functionNode->Body) - { - auto oldFunc = function; - this->function = functionNode; - checkStmt(functionNode->Body); - this->function = oldFunc; - } - } - } - - void getGenericParams( - GenericDecl* decl, - List& outParams, - List outConstraints) - { - for (auto dd : decl->Members) - { - if (dd == decl->inner) - continue; - - if (auto typeParamDecl = as(dd)) - outParams.add(typeParamDecl); - else if (auto valueParamDecl = as(dd)) - outParams.add(valueParamDecl); - else if (auto constraintDecl = as(dd)) - outConstraints.add(constraintDecl); - } - } - - bool doGenericSignaturesMatch( - GenericDecl* fst, - GenericDecl* snd) - { - // First we'll extract the parameters and constraints - // in each generic signature. We will consider parameters - // and constraints separately so that we are independent - // of the order in which constraints are given (that is, - // a constraint like `` would be considered - // the same as `` with a later `where T : IFoo`. - - List fstParams; - List fstConstraints; - getGenericParams(fst, fstParams, fstConstraints); - - List sndParams; - List sndConstraints; - getGenericParams(snd, sndParams, sndConstraints); - - // For there to be any hope of a match, the - // two need to have the same number of parameters. - Index paramCount = fstParams.getCount(); - if (paramCount != sndParams.getCount()) - return false; - - // Now we'll walk through the parameters. - for (Index pp = 0; pp < paramCount; ++pp) - { - Decl* fstParam = fstParams[pp]; - Decl* sndParam = sndParams[pp]; - - if (auto fstTypeParam = as(fstParam)) - { - if (auto sndTypeParam = as(sndParam)) - { - // TODO: is there any validation that needs to be performed here? - } - else - { - // Type and non-type parameters can't match. - return false; - } - } - else if (auto fstValueParam = as(fstParam)) - { - if (auto sndValueParam = as(sndParam)) - { - // Need to check that the parameters have the same type. - // - // Note: We are assuming here that the type of a value - // parameter cannot be dependent on any of the type - // parameters in the same signature. This is a reasonable - // assumption for now, but could get thorny down the road. - if (!fstValueParam->getType()->Equals(sndValueParam->getType())) - { - // Type mismatch. - return false; - } - - // TODO: This is not the right place to check on default - // values for the parameter, because they won't affect - // the signature, but we should make sure to do validation - // later on (e.g., that only one declaration can/should - // be allowed to provide a default). - } - else - { - // Value and non-value parameters can't match. - return false; - } - } - } - - // If we got this far, then it means the parameter signatures *seem* - // to match up all right, but now we need to check that the constraints - // placed on those parameters are also consistent. - // - // For now I'm going to assume/require that all declarations must - // declare the signature in a way that matches exactly. - Index constraintCount = fstConstraints.getCount(); - if(constraintCount != sndConstraints.getCount()) - return false; - - for (Index cc = 0; cc < constraintCount; ++cc) - { - //auto fstConstraint = fstConstraints[cc]; - //auto sndConstraint = sndConstraints[cc]; - - // TODO: the challenge here is that the - // constraints are going to be expressed - // in terms of the parameters, which means - // we need to be doing substitution here. - } - - // HACK: okay, we'll just assume things match for now. - return true; - } - - // Check if two functions have the same signature for the purposes - // of overload resolution. - bool doFunctionSignaturesMatch( - DeclRef fst, - DeclRef snd) - { - - // TODO(tfoley): This copies the parameter array, which is bad for performance. - auto fstParams = GetParameters(fst).ToArray(); - auto sndParams = GetParameters(snd).ToArray(); - - // If the functions have different numbers of parameters, then - // their signatures trivially don't match. - auto fstParamCount = fstParams.getCount(); - auto sndParamCount = sndParams.getCount(); - if (fstParamCount != sndParamCount) - return false; - - for (Index ii = 0; ii < fstParamCount; ++ii) - { - auto fstParam = fstParams[ii]; - auto sndParam = sndParams[ii]; - - // If a given parameter type doesn't match, then signatures don't match - if (!GetType(fstParam)->Equals(GetType(sndParam))) - return false; - - // If one parameter is `out` and the other isn't, then they don't match - // - // Note(tfoley): we don't consider `out` and `inout` as distinct here, - // because there is no way for overload resolution to pick between them. - if (fstParam.getDecl()->HasModifier() != sndParam.getDecl()->HasModifier()) - return false; - - // If one parameter is `ref` and the other isn't, then they don't match. - // - if(fstParam.getDecl()->HasModifier() != sndParam.getDecl()->HasModifier()) - return false; - } - - // Note(tfoley): return type doesn't enter into it, because we can't take - // calling context into account during overload resolution. - - return true; - } - - RefPtr createDummySubstitutions( - GenericDecl* genericDecl) - { - RefPtr subst = new GenericSubstitution(); - subst->genericDecl = genericDecl; - for (auto dd : genericDecl->Members) - { - if (dd == genericDecl->inner) - continue; - - if (auto typeParam = as(dd)) - { - auto type = DeclRefType::Create(getSession(), - makeDeclRef(typeParam)); - subst->args.add(type); - } - else if (auto valueParam = as(dd)) - { - auto val = new GenericParamIntVal( - makeDeclRef(valueParam)); - subst->args.add(val); - } - // TODO: need to handle constraints here? - } - return subst; - } - - void ValidateFunctionRedeclaration(FuncDecl* funcDecl) - { - auto parentDecl = funcDecl->ParentDecl; - SLANG_ASSERT(parentDecl); - if (!parentDecl) return; - - Decl* childDecl = funcDecl; - - // If this is a generic function (that is, its parent - // declaration is a generic), then we need to look - // for sibling declarations of the parent. - auto genericDecl = as(parentDecl); - if (genericDecl) - { - parentDecl = genericDecl->ParentDecl; - childDecl = genericDecl; - } - - // Look at previously-declared functions with the same name, - // in the same container - // - // Note: there is an assumption here that declarations that - // occur earlier in the program text will be *later* in - // the linked list of declarations with the same name. - // We are also assuming/requiring that the check here is - // symmetric, in that it is okay to test (A,B) or (B,A), - // and there is no need to test both. - // - buildMemberDictionary(parentDecl); - for (auto pp = childDecl->nextInContainerWithSameName; pp; pp = pp->nextInContainerWithSameName) - { - auto prevDecl = pp; - - // Look through generics to the declaration underneath - auto prevGenericDecl = as(prevDecl); - if (prevGenericDecl) - prevDecl = prevGenericDecl->inner.Ptr(); - - // We only care about previously-declared functions - // Note(tfoley): although we should really error out if the - // name is already in use for something else, like a variable... - auto prevFuncDecl = as(prevDecl); - if (!prevFuncDecl) - continue; - - // If one declaration is a prefix/postfix operator, and the - // other is not a matching operator, then don't consider these - // to be re-declarations. - // - // Note(tfoley): Any attempt to call such an operator using - // ordinary function-call syntax (if we decided to allow it) - // would be ambiguous in such a case, of course. - // - if (funcDecl->HasModifier() != prevDecl->HasModifier()) - continue; - if (funcDecl->HasModifier() != prevDecl->HasModifier()) - continue; - - // If one is generic and the other isn't, then there is no match. - if ((genericDecl != nullptr) != (prevGenericDecl != nullptr)) - continue; - - // We are going to be comparing the signatures of the - // two functions, but if they are *generic* functions - // then we will need to compare them with consistent - // specializations in place. - // - // We'll go ahead and create some (unspecialized) declaration - // references here, just to be prepared. - DeclRef funcDeclRef(funcDecl, nullptr); - DeclRef prevFuncDeclRef(prevFuncDecl, nullptr); - - // If we are working with generic functions, then we need to - // consider if their generic signatures match. - if (genericDecl) - { - SLANG_ASSERT(prevGenericDecl); // already checked above - if (!doGenericSignaturesMatch(genericDecl, prevGenericDecl)) - continue; - - // Now we need specialize the declaration references - // consistently, so that we can compare. - // - // First we create a "dummy" set of substitutions that - // just reference the parameters of the first generic. - auto subst = createDummySubstitutions(genericDecl); - // - // Then we use those parameters to specialize the *other* - // generic. - // - subst->genericDecl = prevGenericDecl; - prevFuncDeclRef.substitutions.substitutions = subst; - // - // One way to think about it is that if we have these - // declarations (ignore the name differences...): - // - // // prevFuncDecl: - // void foo1(T x); - // - // // funcDecl: - // void foo2(U x); - // - // Then we will compare `foo2` against `foo1`. - } - - // If the parameter signatures don't match, then don't worry - if (!doFunctionSignaturesMatch(funcDeclRef, prevFuncDeclRef)) - continue; - - // If we get this far, then we've got two declarations in the same - // scope, with the same name and signature, so they appear - // to be redeclarations. - // - // We will track that redeclaration occured, so that we can - // take it into account for overload resolution. - // - // A huge complication that we'll need to deal with is that - // multiple declarations might introduce default values for - // (different) parameters, and we might need to merge across - // all of them (which could get complicated if defaults for - // parameters can reference earlier parameters). - - // If the previous declaration wasn't already recorded - // as being part of a redeclaration family, then make - // it the primary declaration of a new family. - if (!prevFuncDecl->primaryDecl) - { - prevFuncDecl->primaryDecl = prevFuncDecl; - } - - // The new declaration will belong to the family of - // the previous one, and so it will share the same - // primary declaration. - funcDecl->primaryDecl = prevFuncDecl->primaryDecl; - funcDecl->nextDecl = nullptr; - - // Next we want to chain the new declaration onto - // the linked list of redeclarations. - auto link = &prevFuncDecl->nextDecl; - while (*link) - link = &(*link)->nextDecl; - *link = funcDecl; - - // Now that we've added things to a group of redeclarations, - // we can do some additional validation. - - // First, we will ensure that the return types match - // between the declarations, so that they are truly - // interchangeable. - // - // Note(tfoley): If we ever decide to add a beefier type - // system to Slang, we might allow overloads like this, - // so long as the desired result type can be disambiguated - // based on context at the call type. In that case we would - // consider result types earlier, as part of the signature - // matching step. - // - auto resultType = GetResultType(funcDeclRef); - auto prevResultType = GetResultType(prevFuncDeclRef); - if (!resultType->Equals(prevResultType)) - { - // Bad redeclaration - getSink()->diagnose(funcDecl, Diagnostics::functionRedeclarationWithDifferentReturnType, funcDecl->getName(), resultType, prevResultType); - getSink()->diagnose(prevFuncDecl, Diagnostics::seePreviousDeclarationOf, funcDecl->getName()); - - // Don't bother emitting other errors at this point - break; - } - - // Note(tfoley): several of the following checks should - // really be looping over all the previous declarations - // in the same group, and not just the one previous - // declaration we found just now. - - // TODO: Enforce that the new declaration had better - // not specify a default value for any parameter that - // already had a default value in a prior declaration. - - // We are going to want to enforce that we cannot have - // two declarations of a function both specify bodies. - // Before we make that check, however, we need to deal - // with the case where the two function declarations - // might represent different target-specific versions - // of a function. - // - // TODO: if the two declarations are specialized for - // different targets, then skip the body checks below. - - // If both of the declarations have a body, then there - // is trouble, because we wouldn't know which one to - // use during code generation. - if (funcDecl->Body && prevFuncDecl->Body) - { - // Redefinition - getSink()->diagnose(funcDecl, Diagnostics::functionRedefinition, funcDecl->getName()); - getSink()->diagnose(prevFuncDecl, Diagnostics::seePreviousDefinitionOf, funcDecl->getName()); - - // Don't bother emitting other errors - break; - } - - // At this point we've processed the redeclaration and - // put it into a group, so there is no reason to keep - // looping and looking at prior declarations. - return; - } - } - - void visitScopeDecl(ScopeDecl*) - { - // Nothing to do - } - - void visitParamDecl(ParamDecl* paramDecl) - { - // TODO: This logic should be shared with the other cases of - // variable declarations. The main reason I am not doing it - // yet is that we use a `ParamDecl` with a null type as a - // special case in attribute declarations, and that could - // trip up the ordinary variable checks. - - auto typeExpr = paramDecl->type; - if(typeExpr.exp) - { - typeExpr = CheckUsableType(typeExpr); - paramDecl->type = typeExpr; - } - - paramDecl->SetCheckState(DeclCheckState::CheckedHeader); - - // The "initializer" expression for a parameter represents - // a default argument value to use if an explicit one is - // not supplied. - if(auto initExpr = paramDecl->initExpr) - { - // We must check the expression and coerce it to the - // actual type of the parameter. - // - initExpr = CheckExpr(initExpr); - initExpr = coerce(typeExpr.type, initExpr); - paramDecl->initExpr = initExpr; - - // TODO: a default argument expression needs to - // conform to other constraints to be valid. - // For example, it should not be allowed to refer - // to other parameters of the same function (or maybe - // only the parameters to its left...). - - // A default argument value should not be allowed on an - // `out` or `inout` parameter. - // - // TODO: we could relax this by requiring the expression - // to yield an lvalue, but that seems like a feature - // with limited practical utility (and an easy source - // of confusing behavior). - // - // Note: the `InOutModifier` class inherits from `OutModifier`, - // so we only need to check for the base case. - // - if(paramDecl->FindModifier()) - { - getSink()->diagnose(initExpr, Diagnostics::outputParameterCannotHaveDefaultValue); - } - } - - paramDecl->SetCheckState(DeclCheckState::Checked); - } - - void VisitFunctionDeclaration(FuncDecl *functionNode) - { - if (functionNode->IsChecked(DeclCheckState::CheckedHeader)) return; - functionNode->SetCheckState(DeclCheckState::CheckingHeader); - auto oldFunc = this->function; - this->function = functionNode; - - auto resultType = functionNode->ReturnType; - if(resultType.exp) - { - resultType = CheckProperType(functionNode->ReturnType); - } - else - { - resultType = TypeExp(getSession()->getVoidType()); - } - functionNode->ReturnType = resultType; - - - HashSet paraNames; - for (auto & para : functionNode->GetParameters()) - { - EnsureDecl(para, DeclCheckState::CheckedHeader); - - if (paraNames.Contains(para->getName())) - { - getSink()->diagnose(para, Diagnostics::parameterAlreadyDefined, para->getName()); - } - else - paraNames.Add(para->getName()); - } - this->function = oldFunc; - functionNode->SetCheckState(DeclCheckState::CheckedHeader); - - // One last bit of validation: check if we are redeclaring an existing function - ValidateFunctionRedeclaration(functionNode); - } - - void visitDeclStmt(DeclStmt* stmt) - { - // We directly dispatch here instead of using `EnsureDecl()` for two - // reasons: - // - // 1. We expect that a local declaration won't have been referenced - // before it is declared, so that we can just check things in-order - // - // 2. `EnsureDecl()` is specialized for `Decl*` instead of `DeclBase*` - // and trying to special case `DeclGroup*` here feels silly. - // - dispatchDecl(stmt->decl); - checkModifiers(stmt->decl); - } - - void visitBlockStmt(BlockStmt* stmt) - { - checkStmt(stmt->body); - } - - void visitSeqStmt(SeqStmt* stmt) - { - for(auto ss : stmt->stmts) - { - checkStmt(ss); - } - } - - template - T* FindOuterStmt() - { - const Index outerStmtCount = outerStmts.getCount(); - for (Index ii = outerStmtCount; ii > 0; --ii) - { - auto outerStmt = outerStmts[ii-1]; - auto found = as(outerStmt); - if (found) - return found; - } - return nullptr; - } - - void visitBreakStmt(BreakStmt *stmt) - { - auto outer = FindOuterStmt(); - if (!outer) - { - getSink()->diagnose(stmt, Diagnostics::breakOutsideLoop); - } - stmt->parentStmt = outer; - } - void visitContinueStmt(ContinueStmt *stmt) - { - auto outer = FindOuterStmt(); - if (!outer) - { - getSink()->diagnose(stmt, Diagnostics::continueOutsideLoop); - } - stmt->parentStmt = outer; - } - - void PushOuterStmt(Stmt* stmt) - { - outerStmts.add(stmt); - } - - void PopOuterStmt(Stmt* /*stmt*/) - { - outerStmts.removeAt(outerStmts.getCount() - 1); - } - - RefPtr checkPredicateExpr(Expr* expr) - { - RefPtr e = expr; - e = CheckTerm(e); - e = coerce(getSession()->getBoolType(), e); - return e; - } - - void visitDoWhileStmt(DoWhileStmt *stmt) - { - PushOuterStmt(stmt); - stmt->Predicate = checkPredicateExpr(stmt->Predicate); - checkStmt(stmt->Statement); - - PopOuterStmt(stmt); - } - void visitForStmt(ForStmt *stmt) - { - PushOuterStmt(stmt); - checkStmt(stmt->InitialStatement); - if (stmt->PredicateExpression) - { - stmt->PredicateExpression = checkPredicateExpr(stmt->PredicateExpression); - } - if (stmt->SideEffectExpression) - { - stmt->SideEffectExpression = CheckExpr(stmt->SideEffectExpression); - } - checkStmt(stmt->Statement); - - PopOuterStmt(stmt); - } - - RefPtr checkExpressionAndExpectIntegerConstant(RefPtr expr, RefPtr* outIntVal) - { - expr = CheckExpr(expr); - auto intVal = CheckIntegerConstantExpression(expr); - if (outIntVal) - *outIntVal = intVal; - return expr; - } - - void visitCompileTimeForStmt(CompileTimeForStmt* stmt) - { - PushOuterStmt(stmt); - - stmt->varDecl->type.type = getSession()->getIntType(); - addModifier(stmt->varDecl, new ConstModifier()); - stmt->varDecl->SetCheckState(DeclCheckState::Checked); - - RefPtr rangeBeginVal; - RefPtr rangeEndVal; - - if (stmt->rangeBeginExpr) - { - stmt->rangeBeginExpr = checkExpressionAndExpectIntegerConstant(stmt->rangeBeginExpr, &rangeBeginVal); - } - else - { - RefPtr rangeBeginConst = new ConstantIntVal(); - rangeBeginConst->value = 0; - rangeBeginVal = rangeBeginConst; - } - - stmt->rangeEndExpr = checkExpressionAndExpectIntegerConstant(stmt->rangeEndExpr, &rangeEndVal); - - stmt->rangeBeginVal = rangeBeginVal; - stmt->rangeEndVal = rangeEndVal; - - checkStmt(stmt->body); - - - PopOuterStmt(stmt); - } - - void visitSwitchStmt(SwitchStmt* stmt) - { - PushOuterStmt(stmt); - // TODO(tfoley): need to coerce condition to an integral type... - stmt->condition = CheckExpr(stmt->condition); - checkStmt(stmt->body); - - // TODO(tfoley): need to check that all case tags are unique - - // TODO(tfoley): check that there is at most one `default` clause - - PopOuterStmt(stmt); - } - void visitCaseStmt(CaseStmt* stmt) - { - // TODO(tfoley): Need to coerce to type being switch on, - // and ensure that value is a compile-time constant - auto expr = CheckExpr(stmt->expr); - auto switchStmt = FindOuterStmt(); - - if (!switchStmt) - { - getSink()->diagnose(stmt, Diagnostics::caseOutsideSwitch); - } - else - { - // TODO: need to do some basic matching to ensure the type - // for the `case` is consistent with the type for the `switch`... - } - - stmt->expr = expr; - stmt->parentStmt = switchStmt; - } - void visitDefaultStmt(DefaultStmt* stmt) - { - auto switchStmt = FindOuterStmt(); - if (!switchStmt) - { - getSink()->diagnose(stmt, Diagnostics::defaultOutsideSwitch); - } - stmt->parentStmt = switchStmt; - } - void visitIfStmt(IfStmt *stmt) - { - stmt->Predicate = checkPredicateExpr(stmt->Predicate); - checkStmt(stmt->PositiveStatement); - checkStmt(stmt->NegativeStatement); - } - - void visitUnparsedStmt(UnparsedStmt*) - { - // Nothing to do - } - - void visitEmptyStmt(EmptyStmt*) - { - // Nothing to do - } - - void visitDiscardStmt(DiscardStmt*) - { - // Nothing to do - } - - void visitReturnStmt(ReturnStmt *stmt) - { - if (!stmt->Expression) - { - if (function && !function->ReturnType.Equals(getSession()->getVoidType())) - { - getSink()->diagnose(stmt, Diagnostics::returnNeedsExpression); - } - } - else - { - stmt->Expression = CheckTerm(stmt->Expression); - if (!stmt->Expression->type->Equals(getSession()->getErrorType())) - { - if (function) - { - stmt->Expression = coerce(function->ReturnType.Ptr(), stmt->Expression); - } - else - { - // TODO(tfoley): this case currently gets triggered for member functions, - // which aren't being checked consistently (because of the whole symbol - // table idea getting in the way). - -// getSink()->diagnose(stmt, Diagnostics::unimplemented, "case for return stmt"); - } - } - } - } - - IntegerLiteralValue GetMinBound(RefPtr val) - { - if (auto constantVal = as(val)) - return constantVal->value; - - // TODO(tfoley): Need to track intervals so that this isn't just a lie... - return 1; - } - - void maybeInferArraySizeForVariable(VarDeclBase* varDecl) - { - // Not an array? - auto arrayType = as(varDecl->type); - if (!arrayType) return; - - // Explicit element count given? - auto elementCount = arrayType->ArrayLength; - if (elementCount) return; - - // No initializer? - auto initExpr = varDecl->initExpr; - if(!initExpr) return; - - // Is the type of the initializer an array type? - if(auto arrayInitType = as(initExpr->type)) - { - elementCount = arrayInitType->ArrayLength; - } - else - { - // Nothing to do: we couldn't infer a size - return; - } - - // Create a new array type based on the size we found, - // and install it into our type. - varDecl->type.type = getArrayType( - arrayType->baseType, - elementCount); - } - - void validateArraySizeForVariable(VarDeclBase* varDecl) - { - auto arrayType = as(varDecl->type); - if (!arrayType) return; - - auto elementCount = arrayType->ArrayLength; - if (!elementCount) - { - // Note(tfoley): For now we allow arrays of unspecified size - // everywhere, because some source languages (e.g., GLSL) - // allow them in specific cases. -#if 0 - getSink()->diagnose(varDecl, Diagnostics::invalidArraySize); -#endif - return; - } - - // TODO(tfoley): How to handle the case where bound isn't known? - if (GetMinBound(elementCount) <= 0) - { - getSink()->diagnose(varDecl, Diagnostics::invalidArraySize); - return; - } - } - - void visitVarDecl(VarDecl* varDecl) - { - CheckVarDeclCommon(varDecl); - } - - void visitWhileStmt(WhileStmt *stmt) - { - PushOuterStmt(stmt); - stmt->Predicate = checkPredicateExpr(stmt->Predicate); - checkStmt(stmt->Statement); - PopOuterStmt(stmt); - } - void visitExpressionStmt(ExpressionStmt *stmt) - { - stmt->Expression = CheckExpr(stmt->Expression); - } - - RefPtr visitBoolLiteralExpr(BoolLiteralExpr* expr) - { - expr->type = getSession()->getBoolType(); - return expr; - } - - RefPtr visitIntegerLiteralExpr(IntegerLiteralExpr* expr) - { - // The expression might already have a type, determined by its suffix. - // It it doesn't, we will give it a default type. - // - // TODO: We should be careful to pick a "big enough" type - // based on the size of the value (e.g., don't try to stuff - // a constant in an `int` if it requires 64 or more bits). - // - // The long-term solution here is to give a type to a literal - // based on the context where it is used, but that requires - // a more sophisticated type system than we have today. - // - if(!expr->type.type) - { - expr->type = getSession()->getIntType(); - } - return expr; - } - - RefPtr visitFloatingPointLiteralExpr(FloatingPointLiteralExpr* expr) - { - if(!expr->type.type) - { - expr->type = getSession()->getFloatType(); - } - return expr; - } - - RefPtr visitStringLiteralExpr(StringLiteralExpr* expr) - { - expr->type = getSession()->getStringType(); - return expr; - } - - IntVal* GetIntVal(IntegerLiteralExpr* expr) - { - // TODO(tfoley): don't keep allocating here! - return new ConstantIntVal(expr->value); - } - - Linkage* getLinkage() { return m_linkage; } - NamePool* getNamePool() { return getLinkage()->getNamePool(); } - - Name* getName(String const& text) - { - return getNamePool()->getName(text); - } - - RefPtr TryConstantFoldExpr( - InvokeExpr* invokeExpr) - { - // We need all the operands to the expression - - // Check if the callee is an operation that is amenable to constant-folding. - // - // For right now we will look for calls to intrinsic functions, and then inspect - // their names (this is bad and slow). - auto funcDeclRefExpr = invokeExpr->FunctionExpr.as(); - if (!funcDeclRefExpr) return nullptr; - - auto funcDeclRef = funcDeclRefExpr->declRef; - auto intrinsicMod = funcDeclRef.getDecl()->FindModifier(); - if (!intrinsicMod) - { - // We can't constant fold anything that doesn't map to a builtin - // operation right now. - // - // TODO: we should really allow constant-folding for anything - // that can be lowered to our bytecode... - return nullptr; - } - - - - // Let's not constant-fold operations with more than a certain number of arguments, for simplicity - static const int kMaxArgs = 8; - if (invokeExpr->Arguments.getCount() > kMaxArgs) - return nullptr; - - // Before checking the operation name, let's look at the arguments - RefPtr argVals[kMaxArgs]; - IntegerLiteralValue constArgVals[kMaxArgs]; - int argCount = 0; - bool allConst = true; - for (auto argExpr : invokeExpr->Arguments) - { - auto argVal = TryCheckIntegerConstantExpression(argExpr.Ptr()); - if (!argVal) - return nullptr; - - argVals[argCount] = argVal; - - if (auto constArgVal = as(argVal)) - { - constArgVals[argCount] = constArgVal->value; - } - else - { - allConst = false; - } - argCount++; - } - - if (!allConst) - { - // TODO(tfoley): We probably want to support a very limited number of operations - // on "constants" that aren't actually known, to be able to handle a generic - // that takes an integer `N` but then constructs a vector of size `N+1`. - // - // The hard part there is implementing the rules for value unification in the - // presence of more complicated `IntVal` subclasses, like `SumIntVal`. You'd - // need inference to be smart enough to know that `2 + N` and `N + 2` are the - // same value, as are `N + M + 1 + 1` and `M + 2 + N`. - // - // For now we can just bail in this case. - return nullptr; - } - - // At this point, all the operands had simple integer values, so we are golden. - IntegerLiteralValue resultValue = 0; - auto opName = funcDeclRef.GetName(); - - // handle binary operators - if (opName == getName("-")) - { - if (argCount == 1) - { - resultValue = -constArgVals[0]; - } - else if (argCount == 2) - { - resultValue = constArgVals[0] - constArgVals[1]; - } - } - - // simple binary operators -#define CASE(OP) \ - else if(opName == getName(#OP)) do { \ - if(argCount != 2) return nullptr; \ - resultValue = constArgVals[0] OP constArgVals[1]; \ - } while(0) - - CASE(+); // TODO: this can also be unary... - CASE(*); -#undef CASE - - // binary operators with chance of divide-by-zero - // TODO: issue a suitable error in that case -#define CASE(OP) \ - else if(opName == getName(#OP)) do { \ - if(argCount != 2) return nullptr; \ - if(!constArgVals[1]) return nullptr; \ - resultValue = constArgVals[0] OP constArgVals[1]; \ - } while(0) - - CASE(/); - CASE(%); -#undef CASE - - // TODO(tfoley): more cases - else - { - return nullptr; - } - - RefPtr result = new ConstantIntVal(resultValue); - return result; - } - - RefPtr TryConstantFoldExpr( - Expr* expr) - { - // Unwrap any "identity" expressions - while (auto parenExpr = as(expr)) - { - expr = parenExpr->base; - } - - // TODO(tfoley): more serious constant folding here - if (auto intLitExpr = as(expr)) - { - return GetIntVal(intLitExpr); - } - - // it is possible that we are referring to a generic value param - if (auto declRefExpr = as(expr)) - { - auto declRef = declRefExpr->declRef; - - if (auto genericValParamRef = declRef.as()) - { - // TODO(tfoley): handle the case of non-`int` value parameters... - return new GenericParamIntVal(genericValParamRef); - } - - // We may also need to check for references to variables that - // are defined in a way that can be used as a constant expression: - if(auto varRef = declRef.as()) - { - auto varDecl = varRef.getDecl(); - - // In HLSL, `static const` is used to mark compile-time constant expressions - if(auto staticAttr = varDecl->FindModifier()) - { - if(auto constAttr = varDecl->FindModifier()) - { - // HLSL `static const` can be used as a constant expression - if(auto initExpr = getInitExpr(varRef)) - { - return TryConstantFoldExpr(initExpr.Ptr()); - } - } - } - } - else if(auto enumRef = declRef.as()) - { - // The cases in an `enum` declaration can also be used as constant expressions, - if(auto tagExpr = getTagExpr(enumRef)) - { - return TryConstantFoldExpr(tagExpr.Ptr()); - } - } - } - - if(auto castExpr = as(expr)) - { - auto val = TryConstantFoldExpr(castExpr->Arguments[0].Ptr()); - if(val) - return val; - } - else if (auto invokeExpr = as(expr)) - { - auto val = TryConstantFoldExpr(invokeExpr); - if (val) - return val; - } - - return nullptr; - } - - // Try to check an integer constant expression, either returning the value, - // or NULL if the expression isn't recognized as a constant. - RefPtr TryCheckIntegerConstantExpression(Expr* exp) - { - // Check if type is acceptable for an integer constant expression - if(auto basicType = as(exp->type.type)) - { - if(!isIntegerBaseType(basicType->baseType)) - return nullptr; - } - else - { - return nullptr; - } - - // Consider operations that we might be able to constant-fold... - return TryConstantFoldExpr(exp); - } - - // Enforce that an expression resolves to an integer constant, and get its value - RefPtr CheckIntegerConstantExpression(Expr* inExpr) - { - // No need to issue further errors if the expression didn't even type-check. - if(IsErrorExpr(inExpr)) return nullptr; - - // First coerce the expression to the expected type - auto expr = coerce(getSession()->getIntType(),inExpr); - - // No need to issue further errors if the type coercion failed. - if(IsErrorExpr(expr)) return nullptr; - - auto result = TryCheckIntegerConstantExpression(expr.Ptr()); - if (!result) - { - getSink()->diagnose(expr, Diagnostics::expectedIntegerConstantNotConstant); - } - return result; - } - - RefPtr CheckEnumConstantExpression(Expr* expr) - { - // No need to issue further errors if the expression didn't even type-check. - if(IsErrorExpr(expr)) return nullptr; - - // No need to issue further errors if the type coercion failed. - if(IsErrorExpr(expr)) return nullptr; - - auto result = TryConstantFoldExpr(expr); - if (!result) - { - getSink()->diagnose(expr, Diagnostics::expectedIntegerConstantNotConstant); - } - return result; - } - - - RefPtr CheckSimpleSubscriptExpr( - RefPtr subscriptExpr, - RefPtr elementType) - { - auto baseExpr = subscriptExpr->BaseExpression; - auto indexExpr = subscriptExpr->IndexExpression; - - if (!indexExpr->type->Equals(getSession()->getIntType()) && - !indexExpr->type->Equals(getSession()->getUIntType())) - { - getSink()->diagnose(indexExpr, Diagnostics::subscriptIndexNonInteger); - return CreateErrorExpr(subscriptExpr.Ptr()); - } - - subscriptExpr->type = QualType(elementType); - - // TODO(tfoley): need to be more careful about this stuff - subscriptExpr->type.IsLeftValue = baseExpr->type.IsLeftValue; - - return subscriptExpr; - } - - // The way that we have designed out type system, pretyt much *every* - // type is a reference to some declaration in the standard library. - // That means that when we construct a new type on the fly, we need - // to make sure that it is wired up to reference the appropriate - // declaration, or else it won't compare as equal to other types - // that *do* reference the declaration. - // - // This function is used to construct a `vector` type - // programmatically, so that it will work just like a type of - // that form constructed by the user. - RefPtr createVectorType( - RefPtr elementType, - RefPtr elementCount) - { - auto session = getSession(); - auto vectorGenericDecl = findMagicDecl( - session, "Vector").as(); - auto vectorTypeDecl = vectorGenericDecl->inner; - - auto substitutions = new GenericSubstitution(); - substitutions->genericDecl = vectorGenericDecl.Ptr(); - substitutions->args.add(elementType); - substitutions->args.add(elementCount); - - auto declRef = DeclRef(vectorTypeDecl.Ptr(), substitutions); - - return DeclRefType::Create( - session, - declRef).as(); - } - - RefPtr visitIndexExpr(IndexExpr* subscriptExpr) - { - auto baseExpr = subscriptExpr->BaseExpression; - baseExpr = CheckExpr(baseExpr); - - RefPtr indexExpr = subscriptExpr->IndexExpression; - if (indexExpr) - { - indexExpr = CheckExpr(indexExpr); - } - - subscriptExpr->BaseExpression = baseExpr; - subscriptExpr->IndexExpression = indexExpr; - - // If anything went wrong in the base expression, - // then just move along... - if (IsErrorExpr(baseExpr)) - return CreateErrorExpr(subscriptExpr); - - // Otherwise, we need to look at the type of the base expression, - // to figure out how subscripting should work. - auto baseType = baseExpr->type.Ptr(); - if (auto baseTypeType = as(baseType)) - { - // We are trying to "index" into a type, so we have an expression like `float[2]` - // which should be interpreted as resolving to an array type. - - RefPtr elementCount = nullptr; - if (indexExpr) - { - elementCount = CheckIntegerConstantExpression(indexExpr.Ptr()); - } - - auto elementType = CoerceToUsableType(TypeExp(baseExpr, baseTypeType->type)); - auto arrayType = getArrayType( - elementType, - elementCount); - - typeResult = arrayType; - subscriptExpr->type = QualType(getTypeType(arrayType)); - return subscriptExpr; - } - else if (auto baseArrayType = as(baseType)) - { - return CheckSimpleSubscriptExpr( - subscriptExpr, - baseArrayType->baseType); - } - else if (auto vecType = as(baseType)) - { - return CheckSimpleSubscriptExpr( - subscriptExpr, - vecType->elementType); - } - else if (auto matType = as(baseType)) - { - // TODO(tfoley): We shouldn't go and recompute - // row types over and over like this... :( - auto rowType = createVectorType( - matType->getElementType(), - matType->getColumnCount()); - - return CheckSimpleSubscriptExpr( - subscriptExpr, - rowType); - } - - // Default behavior is to look at all available `__subscript` - // declarations on the type and try to call one of them. - - { - LookupResult lookupResult = lookUpMember( - getSession(), - this, - getName("operator[]"), - baseType); - if (!lookupResult.isValid()) - { - goto fail; - } - - // Now that we know there is at least one subscript member, - // we will construct a reference to it and try to call it. - // - // Note: the expression may be an `OverloadedExpr`, in which - // case the attempt to call it will trigger overload - // resolution. - RefPtr subscriptFuncExpr = createLookupResultExpr( - lookupResult, subscriptExpr->BaseExpression, subscriptExpr->loc); - - RefPtr subscriptCallExpr = new InvokeExpr(); - subscriptCallExpr->loc = subscriptExpr->loc; - subscriptCallExpr->FunctionExpr = subscriptFuncExpr; - - // TODO(tfoley): This path can support multiple arguments easily - subscriptCallExpr->Arguments.add(subscriptExpr->IndexExpression); - - return CheckInvokeExprWithCheckedOperands(subscriptCallExpr.Ptr()); - } - - fail: - { - getSink()->diagnose(subscriptExpr, Diagnostics::subscriptNonArray, baseType); - return CreateErrorExpr(subscriptExpr); - } - } - - bool MatchArguments(FuncDecl * functionNode, List > &args) - { - if (functionNode->GetParameters().getCount() != args.getCount()) - return false; - Index i = 0; - for (auto param : functionNode->GetParameters()) - { - if (!param->type.Equals(args[i]->type.Ptr())) - return false; - i++; - } - return true; - } - - RefPtr visitParenExpr(ParenExpr* expr) - { - auto base = expr->base; - base = CheckTerm(base); - - expr->base = base; - expr->type = base->type; - return expr; - } - - // - - /// Given an immutable `expr` used as an l-value emit a special diagnostic if it was derived from `this`. - void maybeDiagnoseThisNotLValue(Expr* expr) - { - // We will try to handle expressions of the form: - // - // e ::= "this" - // | e . name - // | e [ expr ] - // - // We will unwrap the `e.name` and `e[expr]` cases in a loop. - RefPtr e = expr; - for(;;) - { - if(auto memberExpr = as(e)) - { - e = memberExpr->BaseExpression; - } - else if(auto subscriptExpr = as(e)) - { - e = subscriptExpr->BaseExpression; - } - else - { - break; - } - } - // - // Now we check to see if we have a `this` expression, - // and if it is immutable. - if(auto thisExpr = as(e)) - { - if(!thisExpr->type.IsLeftValue) - { - getSink()->diagnose(thisExpr, Diagnostics::thisIsImmutableByDefault); - } - } - } - - RefPtr visitAssignExpr(AssignExpr* expr) - { - expr->left = CheckExpr(expr->left); - - auto type = expr->left->type; - - expr->right = coerce(type, CheckTerm(expr->right)); - - if (!type.IsLeftValue) - { - if (as(type)) - { - // Don't report an l-value issue on an erroneous expression - } - else - { - getSink()->diagnose(expr, Diagnostics::assignNonLValue); - - // As a special case, check if the LHS expression is derived - // from a `this` parameter (implicitly or explicitly), which - // is immutable. We can give the user a bit more context into - // what is going on. - // - maybeDiagnoseThisNotLValue(expr->left); - } - } - expr->type = type; - return expr; - } - - void registerExtension(ExtensionDecl* decl) - { - if (decl->IsChecked(DeclCheckState::CheckedHeader)) - return; - - decl->SetCheckState(DeclCheckState::CheckingHeader); - decl->targetType = CheckProperType(decl->targetType); - decl->SetCheckState(DeclCheckState::CheckedHeader); - - // TODO: need to check that the target type names a declaration... - - if (auto targetDeclRefType = as(decl->targetType)) - { - // Attach our extension to that type as a candidate... - if (auto aggTypeDeclRef = targetDeclRefType->declRef.as()) - { - auto aggTypeDecl = aggTypeDeclRef.getDecl(); - decl->nextCandidateExtension = aggTypeDecl->candidateExtensions; - aggTypeDecl->candidateExtensions = decl; - return; - } - } - getSink()->diagnose(decl->targetType.exp, Diagnostics::unimplemented, "expected a nominal type here"); - } - - void visitExtensionDecl(ExtensionDecl* decl) - { - if (decl->IsChecked(getCheckedState())) return; - - if (!as(decl->targetType)) - { - getSink()->diagnose(decl->targetType.exp, Diagnostics::unimplemented, "expected a nominal type here"); - } - // now check the members of the extension - for (auto m : decl->Members) - { - checkDecl(m); - } - decl->SetCheckState(getCheckedState()); - } - - // Figure out what type an initializer/constructor declaration - // is supposed to return. In most cases this is just the type - // declaration that its declaration is nested inside. - RefPtr findResultTypeForConstructorDecl(ConstructorDecl* decl) - { - // We want to look at the parent of the declaration, - // but if the declaration is generic, the parent will be - // the `GenericDecl` and we need to skip past that to - // the grandparent. - // - auto parent = decl->ParentDecl; - auto genericParent = as(parent); - if (genericParent) - { - parent = genericParent->ParentDecl; - } - - // Now look at the type of the parent (or grandparent). - if (auto aggTypeDecl = as(parent)) - { - // We are nested in an aggregate type declaration, - // so the result type of the initializer will just - // be the surrounding type. - return DeclRefType::Create( - getSession(), - makeDeclRef(aggTypeDecl)); - } - else if (auto extDecl = as(parent)) - { - // We are nested inside an extension, so the result - // type needs to be the type being extended. - return extDecl->targetType.type; - } - else - { - getSink()->diagnose(decl, Diagnostics::initializerNotInsideType); - return nullptr; - } - } - - void visitConstructorDecl(ConstructorDecl* decl) - { - if (decl->IsChecked(getCheckedState())) return; - if (checkingPhase == CheckingPhase::Header) - { - decl->SetCheckState(DeclCheckState::CheckingHeader); - - for (auto& paramDecl : decl->GetParameters()) - { - paramDecl->type = CheckUsableType(paramDecl->type); - } - - // We need to compute the result tyep for this declaration, - // since it wasn't filled in for us. - decl->ReturnType.type = findResultTypeForConstructorDecl(decl); - } - else - { - // TODO(tfoley): check body - } - decl->SetCheckState(getCheckedState()); - } - - - void visitSubscriptDecl(SubscriptDecl* decl) - { - if (decl->IsChecked(getCheckedState())) return; - for (auto& paramDecl : decl->GetParameters()) - { - paramDecl->type = CheckUsableType(paramDecl->type); - } - - decl->ReturnType = CheckUsableType(decl->ReturnType); - - // If we have a subscript declaration with no accessor declarations, - // then we should create a single `GetterDecl` to represent - // the implicit meaning of their declaration, so: - // - // subscript(uint index) -> T; - // - // becomes: - // - // subscript(uint index) -> T { get; } - // - - bool anyAccessors = false; - for(auto accessorDecl : decl->getMembersOfType()) - { - anyAccessors = true; - } - - if(!anyAccessors) - { - RefPtr getterDecl = new GetterDecl(); - getterDecl->loc = decl->loc; - - getterDecl->ParentDecl = decl; - decl->Members.add(getterDecl); - } - - for(auto mm : decl->Members) - { - checkDecl(mm); - } - - decl->SetCheckState(getCheckedState()); - } - - void visitAccessorDecl(AccessorDecl* decl) - { - if (checkingPhase == CheckingPhase::Header) - { - // An accessor must appear nested inside a subscript declaration (today), - // or a property declaration (when we add them). It will derive - // its return type from the outer declaration, so we handle both - // of these checks at the same place. - auto parent = decl->ParentDecl; - if (auto parentSubscript = as(parent)) - { - decl->ReturnType = parentSubscript->ReturnType; - } - // TODO: when we add "property" declarations, check for them here - else - { - getSink()->diagnose(decl, Diagnostics::accessorMustBeInsideSubscriptOrProperty); - } - - } - else - { - // TODO: check the body! - } - decl->SetCheckState(getCheckedState()); - } - - - // - - struct Constraint - { - Decl* decl; // the declaration of the thing being constraints - RefPtr val; // the value to which we are constraining it - bool satisfied = false; // Has this constraint been met? - }; - - // A collection of constraints that will need to be satisfied (solved) - // in order for checking to succeed. - struct ConstraintSystem - { - // A source location to use in reporting any issues - SourceLoc loc; - - // The generic declaration whose parameters we - // are trying to solve for. - RefPtr genericDecl; - - // Constraints we have accumulated, which constrain - // the possible arguments for those parameters. - List constraints; - }; - - RefPtr TryJoinVectorAndScalarType( - RefPtr vectorType, - RefPtr scalarType) - { - // Join( vector, S ) -> vetor - // - // That is, the join of a vector and a scalar type is - // a vector type with a joined element type. - auto joinElementType = TryJoinTypes( - vectorType->elementType, - scalarType); - if(!joinElementType) - return nullptr; - - return createVectorType( - joinElementType, - vectorType->elementCount); - } - - struct TypeWitnessBreadcrumb - { - TypeWitnessBreadcrumb* prev; - - RefPtr sub; - RefPtr sup; - DeclRef declRef; - }; - - // Crete a subtype witness based on the declared relationship - // found in a single breadcrumb - RefPtr createSimpleSubtypeWitness( - TypeWitnessBreadcrumb* breadcrumb) - { - RefPtr witness = new DeclaredSubtypeWitness(); - witness->sub = breadcrumb->sub; - witness->sup = breadcrumb->sup; - witness->declRef = breadcrumb->declRef; - return witness; - } - - RefPtr createTypeWitness( - RefPtr type, - DeclRef interfaceDeclRef, - TypeWitnessBreadcrumb* inBreadcrumbs) - { - if(!inBreadcrumbs) - { - // We need to construct a witness to the fact - // that `type` has been proven to be *equal* - // to `interfaceDeclRef`. - // - SLANG_UNEXPECTED("reflexive type witness"); - UNREACHABLE_RETURN(nullptr); - } - - // We might have one or more steps in the breadcrumb trail, e.g.: - // - // {A : B} {B : C} {C : D} - // - // The chain is stored as a reversed linked list, so that - // the first entry would be the `(C : D)` relationship - // above. - // - // We need to walk the list and build up a suitable witness, - // which in the above case would look like: - // - // Transitive( - // Transitive( - // Declared({A : B}), - // {B : C}), - // {C : D}) - // - // Because of the ordering of the breadcrumb trail, along - // with the way the `Transitive` case nests, we will be - // building these objects outside-in, and keeping - // track of the "hole" where the next step goes. - // - auto bb = inBreadcrumbs; - - // `witness` here will hold the first (outer-most) object - // we create, which is the overall result. - RefPtr witness; - - // `link` will point at the remaining "hole" in the - // data structure, to be filled in. - RefPtr* link = &witness; - - // As long as there is more than one breadcrumb, we - // need to be creating transitive witnesses. - while(bb->prev) - { - // On the first iteration when processing the list - // above, the breadcrumb would be for `{ C : D }`, - // and so we'd create: - // - // Transitive( - // [...], - // { C : D}) - // - // where `[...]` represents the "hole" we leave - // open to fill in next. - // - RefPtr transitiveWitness = new TransitiveSubtypeWitness(); - transitiveWitness->sub = bb->sub; - transitiveWitness->sup = bb->sup; - transitiveWitness->midToSup = bb->declRef; - - // Fill in the current hole, and then set the - // hole to point into the node we just created. - *link = transitiveWitness; - link = &transitiveWitness->subToMid; - - // Move on with the list. - bb = bb->prev; - } - - // If we exit the loop, then there is only one breadcrumb left. - // In our running example this would be `{ A : B }`. We create - // a simple (declared) subtype witness for it, and plug the - // final hole, after which there shouldn't be a hole to deal with. - RefPtr declaredWitness = createSimpleSubtypeWitness(bb); - *link = declaredWitness; - - // We now know that our original `witness` variable has been - // filled in, and there are no other holes. - return witness; - } - - /// Is the given interface one that a tagged-union type can conform to? - /// - /// If a tagged union type `__TaggedUnion(A,B)` is going to be - /// plugged in for a type parameter `T : IFoo` then we need to - /// be sure that the interface `IFoo` doesn't have anything - /// that could lead to unsafe/unsound behavior. This function - /// checks that all the requirements on the interfaceare safe ones. - /// - bool isInterfaceSafeForTaggedUnion( - DeclRef interfaceDeclRef) - { - for( auto memberDeclRef : getMembers(interfaceDeclRef) ) - { - if(!isInterfaceRequirementSafeForTaggedUnion(interfaceDeclRef, memberDeclRef)) - return false; - } - - return true; - } - - /// Is the given interface requirement one that a tagged-union type can satisfy? - /// - /// Unsafe requirements include any `static` requirements, - /// any associated types, and also any requirements that make - /// use of the `This` type (once we support it). - /// - bool isInterfaceRequirementSafeForTaggedUnion( - DeclRef interfaceDeclRef, - DeclRef requirementDeclRef) - { - if(auto callableDeclRef = requirementDeclRef.as()) - { - // A `static` method requirement can't be satisfied by a - // tagged union, because there is no tag to dispatch on. - // - if(requirementDeclRef.getDecl()->HasModifier()) - return false; - - // TODO: We will eventually want to check that any callable - // requirements do not use the `This` type or any associated - // types in ways that could lead to errors. - // - // For now we are disallowing interfaces that have associated - // types completely, and we haven't implemented the `This` - // type, so we should be safe. - - return true; - } - else - { - return false; - } - } - - bool doesTypeConformToInterfaceImpl( - RefPtr originalType, - RefPtr type, - DeclRef interfaceDeclRef, - RefPtr* outWitness, - TypeWitnessBreadcrumb* inBreadcrumbs) - { - // for now look up a conformance member... - if(auto declRefType = as(type)) - { - auto declRef = declRefType->declRef; - - // Easy case: a type conforms to itself. - // - // TODO: This is actually a bit more complicated, as - // the interface needs to be "object-safe" for us to - // really make this determination... - if(declRef == interfaceDeclRef) - { - if(outWitness) - { - *outWitness = createTypeWitness(originalType, interfaceDeclRef, inBreadcrumbs); - } - return true; - } - - if( auto aggTypeDeclRef = declRef.as() ) - { - checkDecl(aggTypeDeclRef.getDecl()); - - for( auto inheritanceDeclRef : getMembersOfTypeWithExt(aggTypeDeclRef)) - { - checkDecl(inheritanceDeclRef.getDecl()); - - // Here we will recursively look up conformance on the type - // that is being inherited from. This is dangerous because - // it might lead to infinite loops. - // - // TODO: A better approach would be to create a linearized list - // of all the interfaces that a given type directly or indirectly - // inherits, and store it with the type, so that we don't have - // to recurse in places like this (and can maybe catch infinite - // loops better). This would also help avoid checking multiply-inherited - // conformances multiple times. - - auto inheritedType = getBaseType(inheritanceDeclRef); - - // We need to ensure that the witness that gets created - // is a composite one, reflecting lookup through - // the inheritance declaration. - TypeWitnessBreadcrumb breadcrumb; - breadcrumb.prev = inBreadcrumbs; - - breadcrumb.sub = type; - breadcrumb.sup = inheritedType; - breadcrumb.declRef = inheritanceDeclRef; - - if(doesTypeConformToInterfaceImpl(originalType, inheritedType, interfaceDeclRef, outWitness, &breadcrumb)) - { - return true; - } - } - // if an inheritance decl is not found, try to find a GenericTypeConstraintDecl - for (auto genConstraintDeclRef : getMembersOfType(aggTypeDeclRef)) - { - checkDecl(genConstraintDeclRef.getDecl()); - auto inheritedType = GetSup(genConstraintDeclRef); - TypeWitnessBreadcrumb breadcrumb; - breadcrumb.prev = inBreadcrumbs; - breadcrumb.sub = type; - breadcrumb.sup = inheritedType; - breadcrumb.declRef = genConstraintDeclRef; - if (doesTypeConformToInterfaceImpl(originalType, inheritedType, interfaceDeclRef, outWitness, &breadcrumb)) - { - return true; - } - } - } - else if( auto genericTypeParamDeclRef = declRef.as() ) - { - // We need to enumerate the constraints placed on this type by its outer - // generic declaration, and see if any of them guarantees that we - // satisfy the given interface.. - auto genericDeclRef = genericTypeParamDeclRef.GetParent().as(); - SLANG_ASSERT(genericDeclRef); - - for( auto constraintDeclRef : getMembersOfType(genericDeclRef) ) - { - auto sub = GetSub(constraintDeclRef); - auto sup = GetSup(constraintDeclRef); - - auto subDeclRef = as(sub); - if(!subDeclRef) - continue; - if(subDeclRef->declRef != genericTypeParamDeclRef) - continue; - - // The witness that we create needs to reflect that - // it found the needed conformance by lookup through - // a generic type constraint. - - TypeWitnessBreadcrumb breadcrumb; - breadcrumb.prev = inBreadcrumbs; - breadcrumb.sub = sub; - breadcrumb.sup = sup; - breadcrumb.declRef = constraintDeclRef; - - if(doesTypeConformToInterfaceImpl(originalType, sup, interfaceDeclRef, outWitness, &breadcrumb)) - { - return true; - } - } - } - } - else if(auto taggedUnionType = as(type)) - { - // A tagged union type conforms to an interface if all of - // the constituent types in the tagged union conform. - // - // We will iterate over the "case" types in the tagged - // union, and check if they conform to the interface. - // Along the way we will collect the conformance witness - // values *if* we are being asked to produce a witness - // value for the tagged union itself (that is, if - // `outWitness` is non-null). - // - List> caseWitnesses; - for(auto caseType : taggedUnionType->caseTypes) - { - RefPtr caseWitness; - - if(!doesTypeConformToInterfaceImpl( - caseType, - caseType, - interfaceDeclRef, - outWitness ? &caseWitness : nullptr, - nullptr)) - { - return false; - } - - if(outWitness) - { - caseWitnesses.add(caseWitness); - } - } - - // We also need to validate the requirements on - // the interface to make sure that they are suitable for - // use with a tagged-union type. - // - // For example, if the interface includes a `static` method - // (which can therefore be called without a particular instance), - // then we wouldn't know what implementation of that method - // to use because there is no tag value to dispatch on. - // - // We will start out being conservative about what we accept - // here, just to keep things simple. - // - if(!isInterfaceSafeForTaggedUnion(interfaceDeclRef)) - return false; - - // If we reach this point then we have a concrete - // witness for each of the case types, and that is - // enough to build a witness for the tagged union. - // - if(outWitness) - { - RefPtr taggedUnionWitness = new TaggedUnionSubtypeWitness(); - taggedUnionWitness->sub = taggedUnionType; - taggedUnionWitness->sup = DeclRefType::Create(getSession(), interfaceDeclRef); - taggedUnionWitness->caseWitnesses.swapWith(caseWitnesses); - - *outWitness = taggedUnionWitness; - } - return true; - } - - // default is failure - return false; - } - - bool DoesTypeConformToInterface( - RefPtr type, - DeclRef interfaceDeclRef) - { - return doesTypeConformToInterfaceImpl(type, type, interfaceDeclRef, nullptr, nullptr); - } - - RefPtr tryGetInterfaceConformanceWitness( - RefPtr type, - DeclRef interfaceDeclRef) - { - RefPtr result; - doesTypeConformToInterfaceImpl(type, type, interfaceDeclRef, &result, nullptr); - return result; - } - - /// Does there exist an implicit conversion from `fromType` to `toType`? - bool canConvertImplicitly( - RefPtr toType, - RefPtr fromType) - { - // Can we convert at all? - ConversionCost conversionCost; - if(!canCoerce(toType, fromType, &conversionCost)) - return false; - - // Is the conversion cheap enough to be done implicitly? - if(conversionCost >= kConversionCost_GeneralConversion) - return false; - - return true; - } - - RefPtr TryJoinTypeWithInterface( - RefPtr type, - DeclRef interfaceDeclRef) - { - // The most basic test here should be: does the type declare conformance to the trait. - if(DoesTypeConformToInterface(type, interfaceDeclRef)) - return type; - - // Just because `type` doesn't conform to the given `interfaceDeclRef`, that - // doesn't necessarily indicate a failure. It is possible that we have a call - // like `sqrt(2)` so that `type` is `int` and `interfaceDeclRef` is - // `__BuiltinFloatingPointType`. The "obvious" answer is that we should infer - // the type `float`, but it seems like the compiler would have to synthesize - // that answer from thin air. - // - // A robsut/correct solution here might be to enumerate set of types types `S` - // such that for each type `X` in `S`: - // - // * `type` is implicitly convertible to `X` - // * `X` conforms to the interface named by `interfaceDeclRef` - // - // If the set `S` is non-empty then we would try to pick the "best" type from `S`. - // The "best" type would be a type `Y` such that `Y` is implicitly convertible to - // every other type in `S`. - // - // We are going to implement a much simpler strategy for now, where we only apply - // the search process if `type` is a builtin scalar type, and then we only search - // through types `X` that are also builtin scalar types. - // - RefPtr bestType; - if(auto basicType = type.dynamicCast()) - { - for(Int baseTypeFlavorIndex = 0; baseTypeFlavorIndex < Int(BaseType::CountOf); baseTypeFlavorIndex++) - { - // Don't consider `type`, since we already know it doesn't work. - if(baseTypeFlavorIndex == Int(basicType->baseType)) - continue; - - // Look up the type in our session. - auto candidateType = type->getSession()->getBuiltinType(BaseType(baseTypeFlavorIndex)); - if(!candidateType) - continue; - - // We only want to consider types that implement the target interface. - if(!DoesTypeConformToInterface(candidateType, interfaceDeclRef)) - continue; - - // We only want to consider types where we can implicitly convert from `type` - if(!canConvertImplicitly(candidateType, type)) - continue; - - // At this point, we have a candidate type that is usable. - // - // If this is our first viable candidate, then it is our best one: - // - if(!bestType) - { - bestType = candidateType; - } - else - { - // Otherwise, we want to pick the "better" type between `candidateType` - // and `bestType`. - // - // We are going to be a bit loose here, and not worry about the - // case where conversion is allowed in both directions. - // - // TODO: make this completely robust. - // - if(canConvertImplicitly(bestType, candidateType)) - { - // Our candidate can convert to the current "best" type, so - // it is logically a more specific type that satisfies our - // constraints, therefore we should keep it. - // - bestType = candidateType; - } - } - } - if(bestType) - return bestType; - } - - // For all other cases, we will just bail out for now. - // - // TODO: In the future we should build some kind of side data structure - // to accelerate either one or both of these queries: - // - // * Given a type `T`, what types `U` can it convert to implicitly? - // - // * Given an interface `I`, what types `U` conform to it? - // - // The intersection of the sets returned by these two queries is - // the set of candidates we would like to consider here. - - return nullptr; - } - - // Try to compute the "join" between two types - RefPtr TryJoinTypes( - RefPtr left, - RefPtr right) - { - // Easy case: they are the same type! - if (left->Equals(right)) - return left; - - // We can join two basic types by picking the "better" of the two - if (auto leftBasic = as(left)) - { - if (auto rightBasic = as(right)) - { - auto leftFlavor = leftBasic->baseType; - auto rightFlavor = rightBasic->baseType; - - // TODO(tfoley): Need a special-case rule here that if - // either operand is of type `half`, then we promote - // to at least `float` - - // Return the one that had higher rank... - if (leftFlavor > rightFlavor) - return left; - else - { - SLANG_ASSERT(rightFlavor > leftFlavor); // equality was handles at the top of this function - return right; - } - } - - // We can also join a vector and a scalar - if(auto rightVector = as(right)) - { - return TryJoinVectorAndScalarType(rightVector, leftBasic); - } - } - - // We can join two vector types by joining their element types - // (and also their sizes...) - if( auto leftVector = as(left)) - { - if(auto rightVector = as(right)) - { - // Check if the vector sizes match - if(!leftVector->elementCount->EqualsVal(rightVector->elementCount.Ptr())) - return nullptr; - - // Try to join the element types - auto joinElementType = TryJoinTypes( - leftVector->elementType, - rightVector->elementType); - if(!joinElementType) - return nullptr; - - return createVectorType( - joinElementType, - leftVector->elementCount); - } - - // We can also join a vector and a scalar - if(auto rightBasic = as(right)) - { - return TryJoinVectorAndScalarType(leftVector, rightBasic); - } - } - - // HACK: trying to work trait types in here... - if(auto leftDeclRefType = as(left)) - { - if( auto leftInterfaceRef = leftDeclRefType->declRef.as() ) - { - // - return TryJoinTypeWithInterface(right, leftInterfaceRef); - } - } - if(auto rightDeclRefType = as(right)) - { - if( auto rightInterfaceRef = rightDeclRefType->declRef.as() ) - { - // - return TryJoinTypeWithInterface(left, rightInterfaceRef); - } - } - - // TODO: all the cases for vectors apply to matrices too! - - // Default case is that we just fail. - return nullptr; - } - - // Try to solve a system of generic constraints. - // The `system` argument provides the constraints. - // The `varSubst` argument provides the list of constraint - // variables that were created for the system. - // - // Returns a new substitution representing the values that - // we solved for along the way. - SubstitutionSet TrySolveConstraintSystem( - ConstraintSystem* system, - DeclRef genericDeclRef) - { - // For now the "solver" is going to be ridiculously simplistic. - - // The generic itself will have some constraints, and for now we add these - // to the system of constrains we will use for solving for the type variables. - // - // TODO: we need to decide whether constraints are used like this to influence - // how we solve for type/value variables, or whether constraints in the parameter - // list just work as a validation step *after* we've solved for the types. - // - // That is, should we allow `` to be written, and cause us to "infer" - // that `T` should be the type `Int`? That seems a little silly. - // - // Eventually, though, we may want to support type identity constraints, especially - // on associated types, like `` - // These seem more reasonable to have influence constraint solving, since it could - // conceivably let us specialize a `X : IContainer` to `X` if we find - // that `X.IndexType == T`. - for( auto constraintDeclRef : getMembersOfType(genericDeclRef) ) - { - if(!TryUnifyTypes(*system, GetSub(constraintDeclRef), GetSup(constraintDeclRef))) - return SubstitutionSet(); - } - SubstitutionSet resultSubst = genericDeclRef.substitutions; - // We will loop over the generic parameters, and for - // each we will try to find a way to satisfy all - // the constraints for that parameter - List> args; - for (auto m : getMembers(genericDeclRef)) - { - if (auto typeParam = m.as()) - { - RefPtr type = nullptr; - for (auto& c : system->constraints) - { - if (c.decl != typeParam.getDecl()) - continue; - - auto cType = as(c.val); - SLANG_RELEASE_ASSERT(cType); - - if (!type) - { - type = cType; - } - else - { - auto joinType = TryJoinTypes(type, cType); - if (!joinType) - { - // failure! - return SubstitutionSet(); - } - type = joinType; - } - - c.satisfied = true; - } - - if (!type) - { - // failure! - return SubstitutionSet(); - } - args.add(type); - } - else if (auto valParam = m.as()) - { - // TODO(tfoley): maybe support more than integers some day? - // TODO(tfoley): figure out how this needs to interact with - // compile-time integers that aren't just constants... - RefPtr val = nullptr; - for (auto& c : system->constraints) - { - if (c.decl != valParam.getDecl()) - continue; - - auto cVal = as(c.val); - SLANG_RELEASE_ASSERT(cVal); - - if (!val) - { - val = cVal; - } - else - { - if(!val->EqualsVal(cVal)) - { - // failure! - return SubstitutionSet(); - } - } - - c.satisfied = true; - } - - if (!val) - { - // failure! - return SubstitutionSet(); - } - args.add(val); - } - else - { - // ignore anything that isn't a generic parameter - } - } - - // After we've solved for the explicit arguments, we need to - // make a second pass and consider the implicit arguments, - // based on what we've already determined to be the values - // for the explicit arguments. - - // Before we begin, we are going to go ahead and create the - // "solved" substitution that we will return if everything works. - // This is because we are going to use this substitution, - // partially filled in with the results we know so far, - // in order to specialize any constraints on the generic. - // - // E.g., if the generic parameters were ``, and - // we've already decided that `T` is `Robin`, then we want to - // search for a conformance `Robin : ISidekick`, which involved - // apply the substitutions we already know... - - RefPtr solvedSubst = new GenericSubstitution(); - solvedSubst->genericDecl = genericDeclRef.getDecl(); - solvedSubst->outer = genericDeclRef.substitutions.substitutions; - solvedSubst->args = args; - resultSubst.substitutions = solvedSubst; - - for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType() ) - { - DeclRef constraintDeclRef( - constraintDecl, - solvedSubst); - - // Extract the (substituted) sub- and super-type from the constraint. - auto sub = GetSub(constraintDeclRef); - auto sup = GetSup(constraintDeclRef); - - // Search for a witness that shows the constraint is satisfied. - auto subTypeWitness = tryGetSubtypeWitness(sub, sup); - if(subTypeWitness) - { - // We found a witness, so it will become an (implicit) argument. - solvedSubst->args.add(subTypeWitness); - } - else - { - // No witness was found, so the inference will now fail. - // - // TODO: Ideally we should print an error message in - // this case, to let the user know why things failed. - return SubstitutionSet(); - } - - // TODO: We may need to mark some constrains in our constraint - // system as being solved now, as a result of the witness we found. - } - - // Make sure we haven't constructed any spurious constraints - // that we aren't able to satisfy: - for (auto c : system->constraints) - { - if (!c.satisfied) - { - return SubstitutionSet(); - } - } - - return resultSubst; - } - - - // State related to overload resolution for a call - // to an overloaded symbol - struct OverloadResolveContext - { - enum class Mode - { - // We are just checking if a candidate works or not - JustTrying, - - // We want to actually update the AST for a chosen candidate - ForReal, - }; - - // Location to use when reporting overload-resolution errors. - SourceLoc loc; - - // The original expression (if any) that triggered things - RefPtr originalExpr; - - // Source location of the "function" part of the expression, if any - SourceLoc funcLoc; - - // The original arguments to the call - Index argCount = 0; - RefPtr* args = nullptr; - RefPtr* argTypes = nullptr; - - Index getArgCount() { return argCount; } - RefPtr& getArg(Index index) { return args[index]; } - RefPtr& getArgType(Index index) - { - if(argTypes) - return argTypes[index]; - else - return getArg(index)->type.type; - } - - bool disallowNestedConversions = false; - - RefPtr baseExpr; - - // Are we still trying out candidates, or are we - // checking the chosen one for real? - Mode mode = Mode::JustTrying; - - // We store one candidate directly, so that we don't - // need to do dynamic allocation on the list every time - OverloadCandidate bestCandidateStorage; - OverloadCandidate* bestCandidate = nullptr; - - // Full list of all candidates being considered, in the ambiguous case - List bestCandidates; - }; - - struct ParamCounts - { - UInt required; - UInt allowed; - }; - - // count the number of parameters required/allowed for a callable - ParamCounts CountParameters(FilteredMemberRefList params) - { - ParamCounts counts = { 0, 0 }; - for (auto param : params) - { - counts.allowed++; - - // No initializer means no default value - // - // TODO(tfoley): The logic here is currently broken in two ways: - // - // 1. We are assuming that once one parameter has a default, then all do. - // This can/should be validated earlier, so that we can assume it here. - // - // 2. We are not handling the possibility of multiple declarations for - // a single function, where we'd need to merge default parameters across - // all the declarations. - if (!param.getDecl()->initExpr) - { - counts.required++; - } - } - return counts; - } - - // count the number of parameters required/allowed for a generic - ParamCounts CountParameters(DeclRef genericRef) - { - ParamCounts counts = { 0, 0 }; - for (auto m : genericRef.getDecl()->Members) - { - if (auto typeParam = as(m)) - { - counts.allowed++; - if (!typeParam->initType.Ptr()) - { - counts.required++; - } - } - else if (auto valParam = as(m)) - { - counts.allowed++; - if (!valParam->initExpr) - { - counts.required++; - } - } - } - return counts; - } - - bool TryCheckOverloadCandidateArity( - OverloadResolveContext& context, - OverloadCandidate const& candidate) - { - UInt argCount = context.getArgCount(); - ParamCounts paramCounts = { 0, 0 }; - switch (candidate.flavor) - { - case OverloadCandidate::Flavor::Func: - paramCounts = CountParameters(GetParameters(candidate.item.declRef.as())); - break; - - case OverloadCandidate::Flavor::Generic: - paramCounts = CountParameters(candidate.item.declRef.as()); - break; - - default: - SLANG_UNEXPECTED("unknown flavor of overload candidate"); - break; - } - - if (argCount >= paramCounts.required && argCount <= paramCounts.allowed) - return true; - - // Emit an error message if we are checking this call for real - if (context.mode != OverloadResolveContext::Mode::JustTrying) - { - if (argCount < paramCounts.required) - { - getSink()->diagnose(context.loc, Diagnostics::notEnoughArguments, argCount, paramCounts.required); - } - else - { - SLANG_ASSERT(argCount > paramCounts.allowed); - getSink()->diagnose(context.loc, Diagnostics::tooManyArguments, argCount, paramCounts.allowed); - } - } - - return false; - } - - bool TryCheckOverloadCandidateFixity( - OverloadResolveContext& context, - OverloadCandidate const& candidate) - { - auto expr = context.originalExpr; - - auto decl = candidate.item.declRef.decl; - - if(auto prefixExpr = as(expr)) - { - if(decl->HasModifier()) - return true; - - if (context.mode != OverloadResolveContext::Mode::JustTrying) - { - getSink()->diagnose(context.loc, Diagnostics::expectedPrefixOperator); - getSink()->diagnose(decl, Diagnostics::seeDefinitionOf, decl->getName()); - } - - return false; - } - else if(auto postfixExpr = as(expr)) - { - if(decl->HasModifier()) - return true; - - if (context.mode != OverloadResolveContext::Mode::JustTrying) - { - getSink()->diagnose(context.loc, Diagnostics::expectedPostfixOperator); - getSink()->diagnose(decl, Diagnostics::seeDefinitionOf, decl->getName()); - } - - return false; - } - else - { - return true; - } - - return false; - } - - bool TryCheckGenericOverloadCandidateTypes( - OverloadResolveContext& context, - OverloadCandidate& candidate) - { - auto genericDeclRef = candidate.item.declRef.as(); - - // We will go ahead and hang onto the arguments that we've - // already checked, since downstream validation might need - // them. - auto genSubst = new GenericSubstitution(); - candidate.subst = genSubst; - auto& checkedArgs = genSubst->args; - - Index aa = 0; - for (auto memberRef : getMembers(genericDeclRef)) - { - if (auto typeParamRef = memberRef.as()) - { - if (aa >= context.argCount) - { - return false; - } - auto arg = context.getArg(aa++); - - TypeExp typeExp; - if (context.mode == OverloadResolveContext::Mode::JustTrying) - { - typeExp = tryCoerceToProperType(TypeExp(arg)); - if(!typeExp.type) - { - return false; - } - } - else - { - typeExp = CoerceToProperType(TypeExp(arg)); - } - checkedArgs.add(typeExp.type); - } - else if (auto valParamRef = memberRef.as()) - { - auto arg = context.getArg(aa++); - - if (context.mode == OverloadResolveContext::Mode::JustTrying) - { - ConversionCost cost = kConversionCost_None; - if (!canCoerce(GetType(valParamRef), arg->type, &cost)) - { - return false; - } - candidate.conversionCostSum += cost; - } - - arg = coerce(GetType(valParamRef), arg); - auto val = ExtractGenericArgInteger(arg); - checkedArgs.add(val); - } - else - { - continue; - } - } - - // Okay, we've made it! - return true; - } - - bool TryCheckOverloadCandidateTypes( - OverloadResolveContext& context, - OverloadCandidate& candidate) - { - Index argCount = context.getArgCount(); - - List> params; - switch (candidate.flavor) - { - case OverloadCandidate::Flavor::Func: - params = GetParameters(candidate.item.declRef.as()).ToArray(); - break; - - case OverloadCandidate::Flavor::Generic: - return TryCheckGenericOverloadCandidateTypes(context, candidate); - - default: - SLANG_UNEXPECTED("unknown flavor of overload candidate"); - break; - } - - // Note(tfoley): We might have fewer arguments than parameters in the - // case where one or more parameters had defaults. - SLANG_RELEASE_ASSERT(argCount <= params.getCount()); - - for (Index ii = 0; ii < argCount; ++ii) - { - auto& arg = context.getArg(ii); - auto argType = context.getArgType(ii); - auto param = params[ii]; - - if (context.mode == OverloadResolveContext::Mode::JustTrying) - { - ConversionCost cost = kConversionCost_None; - if( context.disallowNestedConversions ) - { - // We need an exact match in this case. - if(!GetType(param)->Equals(argType)) - return false; - } - else if (!canCoerce(GetType(param), argType, &cost)) - { - return false; - } - candidate.conversionCostSum += cost; - } - else - { - arg = coerce(GetType(param), arg); - } - } - return true; - } - - bool TryCheckOverloadCandidateDirections( - OverloadResolveContext& /*context*/, - OverloadCandidate const& /*candidate*/) - { - // TODO(tfoley): check `in` and `out` markers, as needed. - return true; - } - - // Create a witness that attests to the fact that `type` - // is equal to itself. - RefPtr createTypeEqualityWitness( - Type* type) - { - RefPtr rs = new TypeEqualityWitness(); - rs->sub = type; - rs->sup = type; - return rs; - } - - // If `sub` is a subtype of `sup`, then return a value that - // can serve as a "witness" for that fact. - RefPtr tryGetSubtypeWitness( - RefPtr sub, - RefPtr sup) - { - if(sub->Equals(sup)) - { - // They are the same type, so we just need a witness - // for type equality. - return createTypeEqualityWitness(sub); - } - - if(auto supDeclRefType = as(sup)) - { - auto supDeclRef = supDeclRefType->declRef; - if(auto supInterfaceDeclRef = supDeclRef.as()) - { - if(auto witness = tryGetInterfaceConformanceWitness(sub, supInterfaceDeclRef)) - { - return witness; - } - } - } - - return nullptr; - } - - // In the case where we are explicitly applying a generic - // to arguments (e.g., `G`) check that the constraints - // on those parameters are satisfied. - // - // Note: the constraints actually work as additional parameters/arguments - // of the generic, and so we need to reify them into the final - // argument list. - // - bool TryCheckOverloadCandidateConstraints( - OverloadResolveContext& context, - OverloadCandidate const& candidate) - { - // We only need this step for generics, so always succeed on - // everything else. - if(candidate.flavor != OverloadCandidate::Flavor::Generic) - return true; - - auto genericDeclRef = candidate.item.declRef.as(); - SLANG_ASSERT(genericDeclRef); // otherwise we wouldn't be a generic candidate... - - // We should have the existing arguments to the generic - // handy, so that we can construct a substitution list. - - auto subst = candidate.subst.as(); - SLANG_ASSERT(subst); - - subst->genericDecl = genericDeclRef.getDecl(); - subst->outer = genericDeclRef.substitutions.substitutions; - - for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType() ) - { - auto subset = genericDeclRef.substitutions; - subset.substitutions = subst; - DeclRef constraintDeclRef( - constraintDecl, subset); - - auto sub = GetSub(constraintDeclRef); - auto sup = GetSup(constraintDeclRef); - - auto subTypeWitness = tryGetSubtypeWitness(sub, sup); - if(subTypeWitness) - { - subst->args.add(subTypeWitness); - } - else - { - if(context.mode != OverloadResolveContext::Mode::JustTrying) - { - // TODO: diagnose a problem here - getSink()->diagnose(context.loc, Diagnostics::unimplemented, "generic constraint not satisfied"); - } - return false; - } - } - - // Done checking all the constraints, hooray. - return true; - } - - // Try to check an overload candidate, but bail out - // if any step fails - void TryCheckOverloadCandidate( - OverloadResolveContext& context, - OverloadCandidate& candidate) - { - if (!TryCheckOverloadCandidateArity(context, candidate)) - return; - - candidate.status = OverloadCandidate::Status::ArityChecked; - if (!TryCheckOverloadCandidateFixity(context, candidate)) - return; - - candidate.status = OverloadCandidate::Status::FixityChecked; - if (!TryCheckOverloadCandidateTypes(context, candidate)) - return; - - candidate.status = OverloadCandidate::Status::TypeChecked; - if (!TryCheckOverloadCandidateDirections(context, candidate)) - return; - - candidate.status = OverloadCandidate::Status::DirectionChecked; - if (!TryCheckOverloadCandidateConstraints(context, candidate)) - return; - - candidate.status = OverloadCandidate::Status::Applicable; - } - - // Create the representation of a given generic applied to some arguments - RefPtr createGenericDeclRef( - RefPtr baseExpr, - RefPtr originalExpr, - RefPtr subst) - { - auto baseDeclRefExpr = as(baseExpr); - if (!baseDeclRefExpr) - { - SLANG_DIAGNOSE_UNEXPECTED(getSink(), baseExpr, "expected a reference to a generic declaration"); - return CreateErrorExpr(originalExpr); - } - auto baseGenericRef = baseDeclRefExpr->declRef.as(); - if (!baseGenericRef) - { - SLANG_DIAGNOSE_UNEXPECTED(getSink(), baseExpr, "expected a reference to a generic declaration"); - return CreateErrorExpr(originalExpr); - } - - subst->genericDecl = baseGenericRef.getDecl(); - subst->outer = baseGenericRef.substitutions.substitutions; - - DeclRef innerDeclRef(GetInner(baseGenericRef), subst); - - RefPtr base; - if (auto mbrExpr = as(baseExpr)) - base = mbrExpr->BaseExpression; - - return ConstructDeclRefExpr( - innerDeclRef, - base, - originalExpr->loc); - } - - // Take an overload candidate that previously got through - // `TryCheckOverloadCandidate` above, and try to finish - // up the work and turn it into a real expression. - // - // If the candidate isn't actually applicable, this is - // where we'd start reporting the issue(s). - RefPtr CompleteOverloadCandidate( - OverloadResolveContext& context, - OverloadCandidate& candidate) - { - // special case for generic argument inference failure - if (candidate.status == OverloadCandidate::Status::GenericArgumentInferenceFailed) - { - String callString = getCallSignatureString(context); - getSink()->diagnose( - context.loc, - Diagnostics::genericArgumentInferenceFailed, - callString); - - String declString = getDeclSignatureString(candidate.item); - getSink()->diagnose(candidate.item.declRef, Diagnostics::genericSignatureTried, declString); - goto error; - } - - context.mode = OverloadResolveContext::Mode::ForReal; - - if (!TryCheckOverloadCandidateArity(context, candidate)) - goto error; - - if (!TryCheckOverloadCandidateFixity(context, candidate)) - goto error; - - if (!TryCheckOverloadCandidateTypes(context, candidate)) - goto error; - - if (!TryCheckOverloadCandidateDirections(context, candidate)) - goto error; - - if (!TryCheckOverloadCandidateConstraints(context, candidate)) - goto error; - - { - auto baseExpr = ConstructLookupResultExpr( - candidate.item, context.baseExpr, context.funcLoc); - - switch(candidate.flavor) - { - case OverloadCandidate::Flavor::Func: - { - RefPtr callExpr = as(context.originalExpr); - if(!callExpr) - { - callExpr = new InvokeExpr(); - callExpr->loc = context.loc; - - for(Index aa = 0; aa < context.argCount; ++aa) - callExpr->Arguments.add(context.getArg(aa)); - } - - - callExpr->FunctionExpr = baseExpr; - callExpr->type = QualType(candidate.resultType); - - // A call may yield an l-value, and we should take a look at the candidate to be sure - if(auto subscriptDeclRef = candidate.item.declRef.as()) - { - for(auto setter : subscriptDeclRef.getDecl()->getMembersOfType()) - { - callExpr->type.IsLeftValue = true; - } - for(auto refAccessor : subscriptDeclRef.getDecl()->getMembersOfType()) - { - callExpr->type.IsLeftValue = true; - } - } - - // TODO: there may be other cases that confer l-value-ness - - return callExpr; - } - - break; - - case OverloadCandidate::Flavor::Generic: - return createGenericDeclRef( - baseExpr, - context.originalExpr, - candidate.subst.as()); - break; - - default: - SLANG_DIAGNOSE_UNEXPECTED(getSink(), context.loc, "unknown overload candidate flavor"); - break; - } - } - - - error: - - if(context.originalExpr) - { - return CreateErrorExpr(context.originalExpr.Ptr()); - } - else - { - SLANG_DIAGNOSE_UNEXPECTED(getSink(), context.loc, "no original expression for overload result"); - return nullptr; - } - } - - // Implement a comparison operation between overload candidates, - // so that the better candidate compares as less-than the other - int CompareOverloadCandidates( - OverloadCandidate* left, - OverloadCandidate* right) - { - // If one candidate got further along in validation, pick it - if (left->status != right->status) - return int(right->status) - int(left->status); - - // If both candidates are applicable, then we need to compare - // the costs of their type conversion sequences - if(left->status == OverloadCandidate::Status::Applicable) - { - if (left->conversionCostSum != right->conversionCostSum) - return left->conversionCostSum - right->conversionCostSum; - } - - return 0; - } - - void AddOverloadCandidateInner( - OverloadResolveContext& context, - OverloadCandidate& candidate) - { - // Filter our existing candidates, to remove any that are worse than our new one - - bool keepThisCandidate = true; // should this candidate be kept? - - if (context.bestCandidates.getCount() != 0) - { - // We have multiple candidates right now, so filter them. - bool anyFiltered = false; - // Note that we are querying the list length on every iteration, - // because we might remove things. - for (Index cc = 0; cc < context.bestCandidates.getCount(); ++cc) - { - int cmp = CompareOverloadCandidates(&candidate, &context.bestCandidates[cc]); - if (cmp < 0) - { - // our new candidate is better! - - // remove it from the list (by swapping in a later one) - context.bestCandidates.fastRemoveAt(cc); - // and then reduce our index so that we re-visit the same index - --cc; - - anyFiltered = true; - } - else if(cmp > 0) - { - // our candidate is worse! - keepThisCandidate = false; - } - } - // It should not be possible that we removed some existing candidate *and* - // chose not to keep this candidate (otherwise the better-ness relation - // isn't transitive). Therefore we confirm that we either chose to keep - // this candidate (in which case filtering is okay), or we didn't filter - // anything. - SLANG_ASSERT(keepThisCandidate || !anyFiltered); - } - else if(context.bestCandidate) - { - // There's only one candidate so far - int cmp = CompareOverloadCandidates(&candidate, context.bestCandidate); - if(cmp < 0) - { - // our new candidate is better! - context.bestCandidate = nullptr; - } - else if (cmp > 0) - { - // our candidate is worse! - keepThisCandidate = false; - } - } - - // If our candidate isn't good enough, then drop it - if (!keepThisCandidate) - return; - - // Otherwise we want to keep the candidate - if (context.bestCandidates.getCount() > 0) - { - // There were already multiple candidates, and we are adding one more - context.bestCandidates.add(candidate); - } - else if (context.bestCandidate) - { - // There was a unique best candidate, but now we are ambiguous - context.bestCandidates.add(*context.bestCandidate); - context.bestCandidates.add(candidate); - context.bestCandidate = nullptr; - } - else - { - // This is the only candidate worth keeping track of right now - context.bestCandidateStorage = candidate; - context.bestCandidate = &context.bestCandidateStorage; - } - } - - void AddOverloadCandidate( - OverloadResolveContext& context, - OverloadCandidate& candidate) - { - // Try the candidate out, to see if it is applicable at all. - TryCheckOverloadCandidate(context, candidate); - - // Now (potentially) add it to the set of candidate overloads to consider. - AddOverloadCandidateInner(context, candidate); - } - - void AddFuncOverloadCandidate( - LookupResultItem item, - DeclRef funcDeclRef, - OverloadResolveContext& context) - { - auto funcDecl = funcDeclRef.getDecl(); - checkDecl(funcDecl); - - // If this function is a redeclaration, - // then we don't want to include it multiple times, - // and mistakenly think we have an ambiguous call. - // - // Instead, we will carefully consider only the - // "primary" declaration of any callable. - if (auto primaryDecl = funcDecl->primaryDecl) - { - if (funcDecl != primaryDecl) - { - // This is a redeclaration, so we don't - // want to consider it. The primary - // declaration should also get considered - // for the call site and it will match - // anything this declaration would have - // matched. - return; - } - } - - OverloadCandidate candidate; - candidate.flavor = OverloadCandidate::Flavor::Func; - candidate.item = item; - candidate.resultType = GetResultType(funcDeclRef); - - AddOverloadCandidate(context, candidate); - } - - void AddFuncOverloadCandidate( - RefPtr /*funcType*/, - OverloadResolveContext& /*context*/) - { -#if 0 - if (funcType->decl) - { - AddFuncOverloadCandidate(funcType->decl, context); - } - else if (funcType->Func) - { - AddFuncOverloadCandidate(funcType->Func->SyntaxNode, context); - } - else if (funcType->Component) - { - AddComponentFuncOverloadCandidate(funcType->Component, context); - } -#else - throw "unimplemented"; -#endif - } - - // Add a candidate callee for overload resolution, based on - // calling a particular `ConstructorDecl`. - void AddCtorOverloadCandidate( - LookupResultItem typeItem, - RefPtr type, - DeclRef ctorDeclRef, - OverloadResolveContext& context, - RefPtr resultType) - { - checkDecl(ctorDeclRef.getDecl()); - - // `typeItem` refers to the type being constructed (the thing - // that was applied as a function) so we need to construct - // a `LookupResultItem` that refers to the constructor instead - - LookupResultItem ctorItem; - ctorItem.declRef = ctorDeclRef; - ctorItem.breadcrumbs = new LookupResultItem::Breadcrumb( - LookupResultItem::Breadcrumb::Kind::Member, - typeItem.declRef, - typeItem.breadcrumbs); - - OverloadCandidate candidate; - candidate.flavor = OverloadCandidate::Flavor::Func; - candidate.item = ctorItem; - candidate.resultType = resultType; - - AddOverloadCandidate(context, candidate); - } - - // If the given declaration has generic parameters, then - // return the corresponding `GenericDecl` that holds the - // parameters, etc. - GenericDecl* GetOuterGeneric(Decl* decl) - { - auto parentDecl = decl->ParentDecl; - if (!parentDecl) return nullptr; - auto parentGeneric = as(parentDecl); - return parentGeneric; - } - - // Try to find a unification for two values - bool TryUnifyVals( - ConstraintSystem& constraints, - RefPtr fst, - RefPtr snd) - { - // if both values are types, then unify types - if (auto fstType = as(fst)) - { - if (auto sndType = as(snd)) - { - return TryUnifyTypes(constraints, fstType, sndType); - } - } - - // if both values are constant integers, then compare them - if (auto fstIntVal = as(fst)) - { - if (auto sndIntVal = as(snd)) - { - return fstIntVal->value == sndIntVal->value; - } - } - - // Check if both are integer values in general - if (auto fstInt = as(fst)) - { - if (auto sndInt = as(snd)) - { - auto fstParam = as(fstInt); - auto sndParam = as(sndInt); - - bool okay = false; - if (fstParam) - { - if(TryUnifyIntParam(constraints, fstParam->declRef, sndInt)) - okay = true; - } - if (sndParam) - { - if(TryUnifyIntParam(constraints, sndParam->declRef, fstInt)) - okay = true; - } - return okay; - } - } - - if (auto fstWit = as(fst)) - { - if (auto sndWit = as(snd)) - { - auto constraintDecl1 = fstWit->declRef.as(); - auto constraintDecl2 = sndWit->declRef.as(); - SLANG_ASSERT(constraintDecl1); - SLANG_ASSERT(constraintDecl2); - return TryUnifyTypes(constraints, - constraintDecl1.getDecl()->getSup().type, - constraintDecl2.getDecl()->getSup().type); - } - } - - SLANG_UNIMPLEMENTED_X("value unification case"); - - // default: fail - return false; - } - - bool tryUnifySubstitutions( - ConstraintSystem& constraints, - RefPtr fst, - RefPtr snd) - { - // They must both be NULL or non-NULL - if (!fst || !snd) - return !fst && !snd; - - if(auto fstGeneric = as(fst)) - { - if(auto sndGeneric = as(snd)) - { - return tryUnifyGenericSubstitutions( - constraints, - fstGeneric, - sndGeneric); - } - } - - // TODO: need to handle other cases here - - return false; - } - - bool tryUnifyGenericSubstitutions( - ConstraintSystem& constraints, - RefPtr fst, - RefPtr snd) - { - SLANG_ASSERT(fst); - SLANG_ASSERT(snd); - - auto fstGen = fst; - auto sndGen = snd; - // They must be specializing the same generic - if (fstGen->genericDecl != sndGen->genericDecl) - return false; - - // Their arguments must unify - SLANG_RELEASE_ASSERT(fstGen->args.getCount() == sndGen->args.getCount()); - Index argCount = fstGen->args.getCount(); - bool okay = true; - for (Index aa = 0; aa < argCount; ++aa) - { - if (!TryUnifyVals(constraints, fstGen->args[aa], sndGen->args[aa])) - { - okay = false; - } - } - - // Their "base" specializations must unify - if (!tryUnifySubstitutions(constraints, fstGen->outer, sndGen->outer)) - { - okay = false; - } - - return okay; - } - - bool TryUnifyTypeParam( - ConstraintSystem& constraints, - RefPtr typeParamDecl, - RefPtr type) - { - // We want to constrain the given type parameter - // to equal the given type. - Constraint constraint; - constraint.decl = typeParamDecl.Ptr(); - constraint.val = type; - - constraints.constraints.add(constraint); - - return true; - } - - bool TryUnifyIntParam( - ConstraintSystem& constraints, - RefPtr paramDecl, - RefPtr val) - { - // We only want to accumulate constraints on - // the parameters of the declarations being - // specialized (don't accidentially constrain - // parameters of a generic function based on - // calls in its body). - if(paramDecl->ParentDecl != constraints.genericDecl) - return false; - - // We want to constrain the given parameter to equal the given value. - Constraint constraint; - constraint.decl = paramDecl.Ptr(); - constraint.val = val; - - constraints.constraints.add(constraint); - - return true; - } - - bool TryUnifyIntParam( - ConstraintSystem& constraints, - DeclRef const& varRef, - RefPtr val) - { - if(auto genericValueParamRef = varRef.as()) - { - return TryUnifyIntParam(constraints, RefPtr(genericValueParamRef.getDecl()), val); - } - else - { - return false; - } - } - - bool TryUnifyTypesByStructuralMatch( - ConstraintSystem& constraints, - RefPtr fst, - RefPtr snd) - { - if (auto fstDeclRefType = as(fst)) - { - auto fstDeclRef = fstDeclRefType->declRef; - - if (auto typeParamDecl = as(fstDeclRef.getDecl())) - return TryUnifyTypeParam(constraints, typeParamDecl, snd); - - if (auto sndDeclRefType = as(snd)) - { - auto sndDeclRef = sndDeclRefType->declRef; - - if (auto typeParamDecl = as(sndDeclRef.getDecl())) - return TryUnifyTypeParam(constraints, typeParamDecl, fst); - - // can't be unified if they refer to different declarations. - if (fstDeclRef.getDecl() != sndDeclRef.getDecl()) return false; - - // next we need to unify the substitutions applied - // to each declaration reference. - if (!tryUnifySubstitutions( - constraints, - fstDeclRef.substitutions.substitutions, - sndDeclRef.substitutions.substitutions)) - { - return false; - } - - return true; - } - } - - return false; - } - - bool TryUnifyTypes( - ConstraintSystem& constraints, - RefPtr fst, - RefPtr snd) - { - if (fst->Equals(snd)) return true; - - // An error type can unify with anything, just so we avoid cascading errors. - - if (auto fstErrorType = as(fst)) - return true; - - if (auto sndErrorType = as(snd)) - return true; - - // A generic parameter type can unify with anything. - // TODO: there actually needs to be some kind of "occurs check" sort - // of thing here... - - if (auto fstDeclRefType = as(fst)) - { - auto fstDeclRef = fstDeclRefType->declRef; - - if (auto typeParamDecl = as(fstDeclRef.getDecl())) - { - if(typeParamDecl->ParentDecl == constraints.genericDecl ) - return TryUnifyTypeParam(constraints, typeParamDecl, snd); - } - } - - if (auto sndDeclRefType = as(snd)) - { - auto sndDeclRef = sndDeclRefType->declRef; - - if (auto typeParamDecl = as(sndDeclRef.getDecl())) - { - if(typeParamDecl->ParentDecl == constraints.genericDecl ) - return TryUnifyTypeParam(constraints, typeParamDecl, fst); - } - } - - // If we can unify the types structurally, then we are golden - if(TryUnifyTypesByStructuralMatch(constraints, fst, snd)) - return true; - - // Now we need to consider cases where coercion might - // need to be applied. For now we can try to do this - // in a completely ad hoc fashion, but eventually we'd - // want to do it more formally. - - if(auto fstVectorType = as(fst)) - { - if(auto sndScalarType = as(snd)) - { - return TryUnifyTypes( - constraints, - fstVectorType->elementType, - sndScalarType); - } - } - - if(auto fstScalarType = as(fst)) - { - if(auto sndVectorType = as(snd)) - { - return TryUnifyTypes( - constraints, - fstScalarType, - sndVectorType->elementType); - } - } - - // TODO: the same thing for vectors... - - return false; - } - - // Is the candidate extension declaration actually applicable to the given type - DeclRef ApplyExtensionToType( - ExtensionDecl* extDecl, - RefPtr type) - { - DeclRef extDeclRef = makeDeclRef(extDecl); - - // If the extension is a generic extension, then we - // need to infer type arguments that will give - // us a target type that matches `type`. - // - if (auto extGenericDecl = GetOuterGeneric(extDecl)) - { - ConstraintSystem constraints; - constraints.loc = extDecl->loc; - constraints.genericDecl = extGenericDecl; - - if (!TryUnifyTypes(constraints, extDecl->targetType.Ptr(), type)) - return DeclRef(); - - auto constraintSubst = TrySolveConstraintSystem(&constraints, DeclRef(extGenericDecl, nullptr).as()); - if (!constraintSubst) - { - return DeclRef(); - } - - // Construct a reference to the extension with our constraint variables - // set as they were found by solving the constraint system. - extDeclRef = DeclRef(extDecl, constraintSubst).as(); - } - - // Now extract the target type from our (possibly specialized) extension decl-ref. - RefPtr targetType = GetTargetType(extDeclRef); - - // As a bit of a kludge here, if the target type of the extension is - // an interface, and the `type` we are trying to match up has a this-type - // substitution for that interface, then we want to attach a matching - // substitution to the extension decl-ref. - if(auto targetDeclRefType = as(targetType)) - { - if(auto targetInterfaceDeclRef = targetDeclRefType->declRef.as()) - { - // Okay, the target type is an interface. - // - // Is the type we want to apply to also an interface? - if(auto appDeclRefType = as(type)) - { - if(auto appInterfaceDeclRef = appDeclRefType->declRef.as()) - { - if(appInterfaceDeclRef.getDecl() == targetInterfaceDeclRef.getDecl()) - { - // Looks like we have a match in the types, - // now let's see if we have a this-type substitution. - if(auto appThisTypeSubst = appInterfaceDeclRef.substitutions.substitutions.as()) - { - if(appThisTypeSubst->interfaceDecl == appInterfaceDeclRef.getDecl()) - { - // The type we want to apply to has a this-type substitution, - // and (by construction) the target type currently does not. - // - SLANG_ASSERT(!targetInterfaceDeclRef.substitutions.substitutions.as()); - - // We will create a new substitution to apply to the target type. - RefPtr newTargetSubst = new ThisTypeSubstitution(); - newTargetSubst->interfaceDecl = appThisTypeSubst->interfaceDecl; - newTargetSubst->witness = appThisTypeSubst->witness; - newTargetSubst->outer = targetInterfaceDeclRef.substitutions.substitutions; - - targetType = DeclRefType::Create(getSession(), - DeclRef(targetInterfaceDeclRef.getDecl(), newTargetSubst)); - - // Note: we are constructing a this-type substitution that - // we will apply to the extension declaration as well. - // This is not strictly allowed by our current representation - // choices, but we need it in order to make sure that - // references to the target type of the extension - // declaration have a chance to resolve the way we want them to. - - RefPtr newExtSubst = new ThisTypeSubstitution(); - newExtSubst->interfaceDecl = appThisTypeSubst->interfaceDecl; - newExtSubst->witness = appThisTypeSubst->witness; - newExtSubst->outer = extDeclRef.substitutions.substitutions; - - extDeclRef = DeclRef( - extDeclRef.getDecl(), - newExtSubst); - - // TODO: Ideally we should also apply the chosen specialization to - // the decl-ref for the extension, so that subsequent lookup through - // the members of this extension will retain that substitution and - // be able to apply it. - // - // E.g., if an extension method returns a value of an associated - // type, then we'd want that to become specialized to a concrete - // type when using the extension method on a value of concrete type. - // - // The challenge here that makes me reluctant to just staple on - // such a substitution is that it wouldn't follow our implicit - // rules about where `ThisTypeSubstitution`s can appear. - } - } - } - } - } - } - } - - // In order for this extension to apply to the given type, we - // need to have a match on the target types. - if (!type->Equals(targetType)) - return DeclRef(); - - - return extDeclRef; - } - -#if 0 - bool TryUnifyArgAndParamTypes( - ConstraintSystem& system, - RefPtr argExpr, - DeclRef paramDeclRef) - { - // TODO(tfoley): potentially need a bit more - // nuance in case where argument might be - // an overload group... - return TryUnifyTypes(system, argExpr->type, GetType(paramDeclRef)); - } -#endif - - // Take a generic declaration and try to specialize its parameters - // so that the resulting inner declaration can be applicable in - // a particular context... - DeclRef SpecializeGenericForOverload( - DeclRef genericDeclRef, - OverloadResolveContext& context) - { - checkDecl(genericDeclRef.getDecl()); - - ConstraintSystem constraints; - constraints.loc = context.loc; - constraints.genericDecl = genericDeclRef.getDecl(); - - // Construct a reference to the inner declaration that has any generic - // parameter substitutions in place already, but *not* any substutions - // for the generic declaration we are currently trying to infer. - auto innerDecl = GetInner(genericDeclRef); - DeclRef unspecializedInnerRef = DeclRef(innerDecl, genericDeclRef.substitutions); - - // Check what type of declaration we are dealing with, and then try - // to match it up with the arguments accordingly... - if (auto funcDeclRef = unspecializedInnerRef.as()) - { - auto params = GetParameters(funcDeclRef).ToArray(); - - Index argCount = context.getArgCount(); - Index paramCount = params.getCount(); - - // Bail out on mismatch. - // TODO(tfoley): need more nuance here - if (argCount != paramCount) - { - return DeclRef(nullptr, nullptr); - } - - for (Index aa = 0; aa < argCount; ++aa) - { -#if 0 - if (!TryUnifyArgAndParamTypes(constraints, args[aa], params[aa])) - return DeclRef(nullptr, nullptr); -#else - // The question here is whether failure to "unify" an argument - // and parameter should lead to immediate failure. - // - // The case that is interesting is if we want to unify, say: - // `vector` and `vector` - // - // It is clear that we should solve with `N = 3`, and then - // a later step may find that the resulting types aren't - // actually a match. - // - // A more refined approach to "unification" could of course - // see that `int` can convert to `float` and use that fact. - // (and indeed we already use something like this to unify - // `float` and `vector`) - // - // So the question is then whether a mismatch during the - // unification step should be taken as an immediate failure... - - TryUnifyTypes(constraints, context.getArgType(aa), GetType(params[aa])); -#endif - } - } - else - { - // TODO(tfoley): any other cases needed here? - return DeclRef(nullptr, nullptr); - } - - auto constraintSubst = TrySolveConstraintSystem(&constraints, genericDeclRef); - if (!constraintSubst) - { - // constraint solving failed - return DeclRef(nullptr, nullptr); - } - - // We can now construct a reference to the inner declaration using - // the solution to our constraints. - return DeclRef(innerDecl, constraintSubst); - } - - void AddAggTypeOverloadCandidates( - LookupResultItem typeItem, - RefPtr type, - DeclRef aggTypeDeclRef, - OverloadResolveContext& context, - RefPtr resultType) - { - for (auto ctorDeclRef : getMembersOfType(aggTypeDeclRef)) - { - // now work through this candidate... - AddCtorOverloadCandidate(typeItem, type, ctorDeclRef, context, resultType); - } - - // Also check for generic constructors. - // - // TODO: There is way too much duplication between this case and the extension - // handling below, and all of this is *also* duplicative with the ordinary - // overload resolution logic for function. - // - // The right solution is to handle a "constructor" call expression by - // first doing member lookup in the type (for initializer members, which - // should all share a common name), and then to do overload resolution using - // the (possibly overloaded) result of that lookup. - // - for (auto genericDeclRef : getMembersOfType(aggTypeDeclRef)) - { - if (auto ctorDecl = as(genericDeclRef.getDecl()->inner)) - { - DeclRef innerRef = SpecializeGenericForOverload(genericDeclRef, context); - if (!innerRef) - continue; - - DeclRef innerCtorRef = innerRef.as(); - AddCtorOverloadCandidate(typeItem, type, innerCtorRef, context, resultType); - } - } - - // Now walk through any extensions we can find for this types - for (auto ext = GetCandidateExtensions(aggTypeDeclRef); ext; ext = ext->nextCandidateExtension) - { - auto extDeclRef = ApplyExtensionToType(ext, type); - if (!extDeclRef) - continue; - - for (auto ctorDeclRef : getMembersOfType(extDeclRef)) - { - // TODO(tfoley): `typeItem` here should really reference the extension... - - // now work through this candidate... - AddCtorOverloadCandidate(typeItem, type, ctorDeclRef, context, resultType); - } - - // Also check for generic constructors - for (auto genericDeclRef : getMembersOfType(extDeclRef)) - { - if (auto ctorDecl = genericDeclRef.getDecl()->inner.as()) - { - DeclRef innerRef = SpecializeGenericForOverload(genericDeclRef, context); - if (!innerRef) - continue; - - DeclRef innerCtorRef = innerRef.as(); - - AddCtorOverloadCandidate(typeItem, type, innerCtorRef, context, resultType); - - // TODO(tfoley): need a way to do the solving step for the constraint system - } - } - } - } - - void addGenericTypeParamOverloadCandidates( - DeclRef typeDeclRef, - OverloadResolveContext& context, - RefPtr resultType) - { - // We need to look for any constraints placed on the generic - // type parameter, since they will give us information on - // interfaces that the type must conform to. - - // We expect the parent of the generic type parameter to be a generic... - auto genericDeclRef = typeDeclRef.GetParent().as(); - SLANG_ASSERT(genericDeclRef); - - for(auto constraintDeclRef : getMembersOfType(genericDeclRef)) - { - // Does this constraint pertain to the type we are working on? - // - // We want constraints of the form `T : Foo` where `T` is the - // generic parameter in question, and `Foo` is whatever we are - // constraining it to. - auto subType = GetSub(constraintDeclRef); - auto subDeclRefType = as(subType); - if(!subDeclRefType) - continue; - if(!subDeclRefType->declRef.Equals(typeDeclRef)) - continue; - - // The super-type in the constraint (e.g., `Foo` in `T : Foo`) - // will tell us a type we should use for lookup. - auto bound = GetSup(constraintDeclRef); - - // Go ahead and use the target type: - // - // TODO: Need to consider case where this might recurse infinitely. - AddTypeOverloadCandidates(bound, context, resultType); - } - } - - void AddTypeOverloadCandidates( - RefPtr type, - OverloadResolveContext& context, - RefPtr resultType) - { - if (auto declRefType = as(type)) - { - auto declRef = declRefType->declRef; - if (auto aggTypeDeclRef = declRef.as()) - { - AddAggTypeOverloadCandidates(LookupResultItem(aggTypeDeclRef), type, aggTypeDeclRef, context, resultType); - } - else if(auto genericTypeParamDeclRef = declRef.as()) - { - addGenericTypeParamOverloadCandidates( - genericTypeParamDeclRef, - context, - resultType); - } - } - } - - void AddDeclRefOverloadCandidates( - LookupResultItem item, - OverloadResolveContext& context) - { - auto declRef = item.declRef; - - if (auto funcDeclRef = item.declRef.as()) - { - AddFuncOverloadCandidate(item, funcDeclRef, context); - } - else if (auto aggTypeDeclRef = item.declRef.as()) - { - auto type = DeclRefType::Create( - getSession(), - aggTypeDeclRef); - AddAggTypeOverloadCandidates(item, type, aggTypeDeclRef, context, type); - } - else if (auto genericDeclRef = item.declRef.as()) - { - // Try to infer generic arguments, based on the context - DeclRef innerRef = SpecializeGenericForOverload(genericDeclRef, context); - - if (innerRef) - { - // If inference works, then we've now got a - // specialized declaration reference we can apply. - - LookupResultItem innerItem; - innerItem.breadcrumbs = item.breadcrumbs; - innerItem.declRef = innerRef; - - AddDeclRefOverloadCandidates(innerItem, context); - } - else - { - // If inference failed, then we need to create - // a candidate that can be used to reflect that fact - // (so we can report a good error) - OverloadCandidate candidate; - candidate.item = item; - candidate.flavor = OverloadCandidate::Flavor::UnspecializedGeneric; - candidate.status = OverloadCandidate::Status::GenericArgumentInferenceFailed; - - AddOverloadCandidateInner(context, candidate); - } - } - else if( auto typeDefDeclRef = item.declRef.as() ) - { - auto type = getNamedType(getSession(), typeDefDeclRef); - AddTypeOverloadCandidates(GetType(typeDefDeclRef), context, type); - } - else if( auto genericTypeParamDeclRef = item.declRef.as() ) - { - auto type = DeclRefType::Create( - getSession(), - genericTypeParamDeclRef); - addGenericTypeParamOverloadCandidates(genericTypeParamDeclRef, context, type); - } - else - { - // TODO(tfoley): any other cases needed here? - } - } - - void AddOverloadCandidates( - RefPtr funcExpr, - OverloadResolveContext& context) - { - auto funcExprType = funcExpr->type; - - if (auto declRefExpr = as(funcExpr)) - { - // The expression directly referenced a declaration, - // so we can use that declaration directly to look - // for anything applicable. - AddDeclRefOverloadCandidates(LookupResultItem(declRefExpr->declRef), context); - } - else if (auto funcType = as(funcExprType)) - { - // TODO(tfoley): deprecate this path... - AddFuncOverloadCandidate(funcType, context); - } - else if (auto overloadedExpr = as(funcExpr)) - { - auto lookupResult = overloadedExpr->lookupResult2; - SLANG_RELEASE_ASSERT(lookupResult.isOverloaded()); - for(auto item : lookupResult.items) - { - AddDeclRefOverloadCandidates(item, context); - } - } - else if (auto overloadedExpr2 = as(funcExpr)) - { - for (auto item : overloadedExpr2->candidiateExprs) - { - AddOverloadCandidates(item, context); - } - } - else if (auto typeType = as(funcExprType)) - { - // If none of the above cases matched, but we are - // looking at a type, then I suppose we have - // a constructor call on our hands. - // - // TODO(tfoley): are there any meaningful types left - // that aren't declaration references? - auto type = typeType->type; - AddTypeOverloadCandidates(type, context, type); - return; - } - } - - void formatType(StringBuilder& sb, RefPtr type) - { - sb << type->ToString(); - } - - void formatVal(StringBuilder& sb, RefPtr val) - { - sb << val->ToString(); - } - - void formatDeclPath(StringBuilder& sb, DeclRef declRef) - { - // Find the parent declaration - auto parentDeclRef = declRef.GetParent(); - - // If the immediate parent is a generic, then we probably - // want the declaration above that... - auto parentGenericDeclRef = parentDeclRef.as(); - if(parentGenericDeclRef) - { - parentDeclRef = parentGenericDeclRef.GetParent(); - } - - // Depending on what the parent is, we may want to format things specially - if(auto aggTypeDeclRef = parentDeclRef.as()) - { - formatDeclPath(sb, aggTypeDeclRef); - sb << "."; - } - - sb << getText(declRef.GetName()); - - // If the parent declaration is a generic, then we need to print out its - // signature - if( parentGenericDeclRef ) - { - auto genSubst = declRef.substitutions.substitutions.as(); - SLANG_RELEASE_ASSERT(genSubst); - SLANG_RELEASE_ASSERT(genSubst->genericDecl == parentGenericDeclRef.getDecl()); - - sb << "<"; - bool first = true; - for(auto arg : genSubst->args) - { - if(!first) sb << ", "; - formatVal(sb, arg); - first = false; - } - sb << ">"; - } - } - - void formatDeclParams(StringBuilder& sb, DeclRef declRef) - { - if (auto funcDeclRef = declRef.as()) - { - - // This is something callable, so we need to also print parameter types for overloading - sb << "("; - - bool first = true; - for (auto paramDeclRef : GetParameters(funcDeclRef)) - { - if (!first) sb << ", "; - - formatType(sb, GetType(paramDeclRef)); - - first = false; - - } - - sb << ")"; - } - else if(auto genericDeclRef = declRef.as()) - { - sb << "<"; - bool first = true; - for (auto paramDeclRef : getMembers(genericDeclRef)) - { - if(auto genericTypeParam = paramDeclRef.as()) - { - if (!first) sb << ", "; - first = false; - - sb << getText(genericTypeParam.GetName()); - } - else if(auto genericValParam = paramDeclRef.as()) - { - if (!first) sb << ", "; - first = false; - - formatType(sb, GetType(genericValParam)); - sb << " "; - sb << getText(genericValParam.GetName()); - } - else - {} - } - sb << ">"; - - formatDeclParams(sb, DeclRef(GetInner(genericDeclRef), genericDeclRef.substitutions)); - } - else - { - } - } - - void formatDeclSignature(StringBuilder& sb, DeclRef declRef) - { - formatDeclPath(sb, declRef); - formatDeclParams(sb, declRef); - } - - String getDeclSignatureString(DeclRef declRef) - { - StringBuilder sb; - formatDeclSignature(sb, declRef); - return sb.ProduceString(); - } - - String getDeclSignatureString(LookupResultItem item) - { - return getDeclSignatureString(item.declRef); - } - - String getCallSignatureString( - OverloadResolveContext& context) - { - StringBuilder argsListBuilder; - argsListBuilder << "("; - - UInt argCount = context.getArgCount(); - for( UInt aa = 0; aa < argCount; ++aa ) - { - if(aa != 0) argsListBuilder << ", "; - argsListBuilder << context.getArgType(aa)->ToString(); - } - argsListBuilder << ")"; - return argsListBuilder.ProduceString(); - } - -#if 0 - String GetCallSignatureString(RefPtr expr) - { - return getCallSignatureString(expr->Arguments); - } -#endif - - RefPtr ResolveInvoke(InvokeExpr * expr) - { - OverloadResolveContext context; - // check if this is a stdlib operator call, if so we want to use cached results - // to speed up compilation - bool shouldAddToCache = false; - OperatorOverloadCacheKey key; - TypeCheckingCache* typeCheckingCache = getSession()->getTypeCheckingCache(); - if (auto opExpr = as(expr)) - { - if (key.fromOperatorExpr(opExpr)) - { - OverloadCandidate candidate; - if (typeCheckingCache->resolvedOperatorOverloadCache.TryGetValue(key, candidate)) - { - context.bestCandidateStorage = candidate; - context.bestCandidate = &context.bestCandidateStorage; - } - else - { - shouldAddToCache = true; - } - } - } - - // Look at the base expression for the call, and figure out how to invoke it. - auto funcExpr = expr->FunctionExpr; - auto funcExprType = funcExpr->type; - - // If we are trying to apply an erroneous expression, then just bail out now. - if(IsErrorExpr(funcExpr)) - { - return CreateErrorExpr(expr); - } - // If any of the arguments is an error, then we should bail out, to avoid - // cascading errors where we successfully pick an overload, but not the one - // the user meant. - for (auto arg : expr->Arguments) - { - if (IsErrorExpr(arg)) - return CreateErrorExpr(expr); - } - - context.originalExpr = expr; - context.funcLoc = funcExpr->loc; - - context.argCount = expr->Arguments.getCount(); - context.args = expr->Arguments.getBuffer(); - context.loc = expr->loc; - - if (auto funcMemberExpr = as(funcExpr)) - { - context.baseExpr = funcMemberExpr->BaseExpression; - } - else if (auto funcOverloadExpr = as(funcExpr)) - { - context.baseExpr = funcOverloadExpr->base; - } - else if (auto funcOverloadExpr2 = as(funcExpr)) - { - context.baseExpr = funcOverloadExpr2->base; - } - - if (!context.bestCandidate) - { - AddOverloadCandidates(funcExpr, context); - } - - if (context.bestCandidates.getCount() > 0) - { - // Things were ambiguous. - - // It might be that things were only ambiguous because - // one of the argument expressions had an error, and - // so a bunch of candidates could match at that position. - // - // If any argument was an error, we skip out on printing - // another message, to avoid cascading errors. - for (auto arg : expr->Arguments) - { - if (IsErrorExpr(arg)) - { - return CreateErrorExpr(expr); - } - } - - Name* funcName = nullptr; - if (auto baseVar = as(funcExpr)) - funcName = baseVar->name; - else if(auto baseMemberRef = as(funcExpr)) - funcName = baseMemberRef->name; - - String argsList = getCallSignatureString(context); - - if (context.bestCandidates[0].status != OverloadCandidate::Status::Applicable) - { - // There were multiple equally-good candidates, but none actually usable. - // We will construct a diagnostic message to help out. - - if (funcName) - { - getSink()->diagnose(expr, Diagnostics::noApplicableOverloadForNameWithArgs, funcName, argsList); - } - else - { - getSink()->diagnose(expr, Diagnostics::noApplicableWithArgs, argsList); - } - } - else - { - // There were multiple applicable candidates, so we need to report them. - - if (funcName) - { - getSink()->diagnose(expr, Diagnostics::ambiguousOverloadForNameWithArgs, funcName, argsList); - } - else - { - getSink()->diagnose(expr, Diagnostics::ambiguousOverloadWithArgs, argsList); - } - } - - { - Index candidateCount = context.bestCandidates.getCount(); - Index maxCandidatesToPrint = 10; // don't show too many candidates at once... - Index candidateIndex = 0; - for (auto candidate : context.bestCandidates) - { - String declString = getDeclSignatureString(candidate.item); - -// declString = declString + "[" + String(candidate.conversionCostSum) + "]"; - -#if 0 - // Debugging: ensure that we don't consider multiple declarations of the same operation - if (auto decl = as(candidate.item.declRef.decl)) - { - char buffer[1024]; - sprintf_s(buffer, sizeof(buffer), "[this:%p, primary:%p, next:%p]", - decl, - decl->primaryDecl, - decl->nextDecl); - declString.append(buffer); - } -#endif - - getSink()->diagnose(candidate.item.declRef, Diagnostics::overloadCandidate, declString); - - candidateIndex++; - if (candidateIndex == maxCandidatesToPrint) - break; - } - if (candidateIndex != candidateCount) - { - getSink()->diagnose(expr, Diagnostics::moreOverloadCandidates, candidateCount - candidateIndex); - } - } - - return CreateErrorExpr(expr); - } - else if (context.bestCandidate) - { - // There was one best candidate, even if it might not have been - // applicable in the end. - // We will report errors for this one candidate, then, to give - // the user the most help we can. - if (shouldAddToCache) - typeCheckingCache->resolvedOperatorOverloadCache[key] = *context.bestCandidate; - return CompleteOverloadCandidate(context, *context.bestCandidate); - } - else - { - // Nothing at all was found that we could even consider invoking - getSink()->diagnose(expr->FunctionExpr, Diagnostics::expectedFunction, funcExprType); - expr->type = QualType(getSession()->getErrorType()); - return expr; - } - } - - void AddGenericOverloadCandidate( - LookupResultItem baseItem, - OverloadResolveContext& context) - { - if (auto genericDeclRef = baseItem.declRef.as()) - { - checkDecl(genericDeclRef.getDecl()); - - OverloadCandidate candidate; - candidate.flavor = OverloadCandidate::Flavor::Generic; - candidate.item = baseItem; - candidate.resultType = nullptr; - - AddOverloadCandidate(context, candidate); - } - } - - void AddGenericOverloadCandidates( - RefPtr baseExpr, - OverloadResolveContext& context) - { - if(auto baseDeclRefExpr = as(baseExpr)) - { - auto declRef = baseDeclRefExpr->declRef; - AddGenericOverloadCandidate(LookupResultItem(declRef), context); - } - else if (auto overloadedExpr = as(baseExpr)) - { - // We are referring to a bunch of declarations, each of which might be generic - LookupResult result; - for (auto item : overloadedExpr->lookupResult2.items) - { - AddGenericOverloadCandidate(item, context); - } - } - else - { - // any other cases? - } - } - - RefPtr visitGenericAppExpr(GenericAppExpr* genericAppExpr) - { - // Start by checking the base expression and arguments. - auto& baseExpr = genericAppExpr->FunctionExpr; - baseExpr = CheckTerm(baseExpr); - auto& args = genericAppExpr->Arguments; - for (auto& arg : args) - { - arg = CheckTerm(arg); - } - - return checkGenericAppWithCheckedArgs(genericAppExpr); - } - - /// Check a generic application where the operands have already been checked. - RefPtr checkGenericAppWithCheckedArgs(GenericAppExpr* genericAppExpr) - { - // We are applying a generic to arguments, but there might be multiple generic - // declarations with the same name, so this becomes a specialized case of - // overload resolution. - - auto& baseExpr = genericAppExpr->FunctionExpr; - auto& args = genericAppExpr->Arguments; - - // If there was an error in the base expression, or in any of - // the arguments, then just bail. - if (IsErrorExpr(baseExpr)) - { - return CreateErrorExpr(genericAppExpr); - } - for (auto argExpr : args) - { - if (IsErrorExpr(argExpr)) - { - return CreateErrorExpr(genericAppExpr); - } - } - - // Otherwise, let's start looking at how to find an overload... - - OverloadResolveContext context; - context.originalExpr = genericAppExpr; - context.funcLoc = baseExpr->loc; - context.argCount = args.getCount(); - context.args = args.getBuffer(); - context.loc = genericAppExpr->loc; - - context.baseExpr = GetBaseExpr(baseExpr); - - AddGenericOverloadCandidates(baseExpr, context); - - if (context.bestCandidates.getCount() > 0) - { - // Things were ambiguous. - if (context.bestCandidates[0].status != OverloadCandidate::Status::Applicable) - { - // There were multiple equally-good candidates, but none actually usable. - // We will construct a diagnostic message to help out. - - // TODO(tfoley): print a reasonable message here... - - getSink()->diagnose(genericAppExpr, Diagnostics::unimplemented, "no applicable generic"); - - return CreateErrorExpr(genericAppExpr); - } - else - { - // There were multiple viable candidates, but that isn't an error: we just need - // to complete all of them and create an overloaded expression as a result. - - auto overloadedExpr = new OverloadedExpr2(); - overloadedExpr->base = context.baseExpr; - for (auto candidate : context.bestCandidates) - { - auto candidateExpr = CompleteOverloadCandidate(context, candidate); - overloadedExpr->candidiateExprs.add(candidateExpr); - } - return overloadedExpr; - } - } - else if (context.bestCandidate) - { - // There was one best candidate, even if it might not have been - // applicable in the end. - // We will report errors for this one candidate, then, to give - // the user the most help we can. - return CompleteOverloadCandidate(context, *context.bestCandidate); - } - else - { - // Nothing at all was found that we could even consider invoking - getSink()->diagnose(genericAppExpr, Diagnostics::unimplemented, "expected a generic"); - return CreateErrorExpr(genericAppExpr); - } - } - - RefPtr visitSharedTypeExpr(SharedTypeExpr* expr) - { - if (!expr->type.Ptr()) - { - expr->base = CheckProperType(expr->base); - expr->type = expr->base.exp->type; - } - return expr; - } - - RefPtr visitTaggedUnionTypeExpr(TaggedUnionTypeExpr* expr) - { - // We have an expression of the form `__TaggedUnion(A, B, ...)` - // which will evaluate to a tagged-union type over `A`, `B`, etc. - // - RefPtr type = new TaggedUnionType(); - expr->type = QualType(getTypeType(type)); - - for( auto& caseTypeExpr : expr->caseTypes ) - { - caseTypeExpr = CheckProperType(caseTypeExpr); - type->caseTypes.add(caseTypeExpr.type); - } - - return expr; - } - - - - - RefPtr CheckExpr(RefPtr expr) - { - auto term = CheckTerm(expr); - - // TODO(tfoley): Need a step here to ensure that the term actually - // resolves to a (single) expression with a real type. - - return term; - } - - RefPtr CheckInvokeExprWithCheckedOperands(InvokeExpr *expr) - { - auto rs = ResolveInvoke(expr); - if (auto invoke = as(rs.Ptr())) - { - // if this is still an invoke expression, test arguments passed to inout/out parameter are LValues - if(auto funcType = as(invoke->FunctionExpr->type)) - { - Index paramCount = funcType->getParamCount(); - for (Index pp = 0; pp < paramCount; ++pp) - { - auto paramType = funcType->getParamType(pp); - if (as(paramType) || as(paramType)) - { - // `out`, `inout`, and `ref` parameters currently require - // an *exact* match on the type of the argument. - // - // TODO: relax this requirement by allowing an argument - // for an `inout` parameter to be converted in both - // directions. - // - if( pp < expr->Arguments.getCount() ) - { - auto argExpr = expr->Arguments[pp]; - if( !argExpr->type.IsLeftValue ) - { - getSink()->diagnose( - argExpr, - Diagnostics::argumentExpectedLValue, - pp); - - if( auto implicitCastExpr = as(argExpr) ) - { - getSink()->diagnose( - argExpr, - Diagnostics::implicitCastUsedAsLValue, - implicitCastExpr->Arguments[0]->type, - implicitCastExpr->type); - } - - maybeDiagnoseThisNotLValue(argExpr); - } - } - else - { - // There are two ways we could get here, both involving - // a call where the number of argument expressions is - // less than the number of parameters on the callee: - // - // 1. There might be fewer arguments than parameters - // because the trailing parameters should be defaulted - // - // 2. There might be fewer arguments than parameters - // because the call is incorrect. - // - // In case (2) an error would have already been diagnosed, - // and we don't want to emit another cascading error here. - // - // In case (1) this implies the user declared an `out` - // or `inout` parameter with a default argument expression. - // That should be an error, but it should be detected - // on the declaration instead of here at the use site. - // - // Thus, it makes sense to ignore this case here. - } - } - } - } - } - return rs; - } - - RefPtr visitInvokeExpr(InvokeExpr *expr) - { - // check the base expression first - expr->FunctionExpr = CheckExpr(expr->FunctionExpr); - // Next check the argument expressions - for (auto & arg : expr->Arguments) - { - arg = CheckExpr(arg); - } - - return CheckInvokeExprWithCheckedOperands(expr); - } - - - RefPtr visitVarExpr(VarExpr *expr) - { - // If we've already resolved this expression, don't try again. - if (expr->declRef) - return expr; - - expr->type = QualType(getSession()->getErrorType()); - auto lookupResult = lookUp( - getSession(), - this, expr->name, expr->scope); - if (lookupResult.isValid()) - { - return createLookupResultExpr( - lookupResult, - nullptr, - expr->loc); - } - - getSink()->diagnose(expr, Diagnostics::undefinedIdentifier2, expr->name); - - return expr; - } - - RefPtr visitTypeCastExpr(TypeCastExpr * expr) - { - // Check the term we are applying first - auto funcExpr = expr->FunctionExpr; - funcExpr = CheckTerm(funcExpr); - - // Now ensure that the term represnets a (proper) type. - TypeExp typeExp; - typeExp.exp = funcExpr; - typeExp = CheckProperType(typeExp); - - expr->FunctionExpr = typeExp.exp; - expr->type.type = typeExp.type; - - // Next check the argument expression (there should be only one) - for (auto & arg : expr->Arguments) - { - arg = CheckExpr(arg); - } - - // Now process this like any other explicit call (so casts - // and constructor calls are semantically equivalent). - return CheckInvokeExprWithCheckedOperands(expr); - } - - // Get the type to use when referencing a declaration - QualType GetTypeForDeclRef(DeclRef declRef) - { - return getTypeForDeclRef( - getSession(), - this, - getSink(), - declRef, - &typeResult); - } - - // - // Some syntax nodes should not occur in the concrete input syntax, - // and will only appear *after* checking is complete. We need to - // deal with this cases here, even if they are no-ops. - // - - #define CASE(NAME) \ - RefPtr visit##NAME(NAME* expr) \ - { \ - SLANG_DIAGNOSE_UNEXPECTED(getSink(), expr, \ - "should not appear in input syntax"); \ - return expr; \ - } - - CASE(DerefExpr) - CASE(SwizzleExpr) - CASE(OverloadedExpr) - CASE(OverloadedExpr2) - CASE(AggTypeCtorExpr) - CASE(CastToInterfaceExpr) - CASE(LetExpr) - CASE(ExtractExistentialValueExpr) - - #undef CASE - - // - // - // - - RefPtr MaybeDereference(RefPtr inExpr) - { - RefPtr expr = inExpr; - for (;;) - { - auto baseType = expr->type; - if (auto pointerLikeType = as(baseType)) - { - auto elementType = QualType(pointerLikeType->elementType); - elementType.IsLeftValue = baseType.IsLeftValue; - - auto derefExpr = new DerefExpr(); - derefExpr->base = expr; - derefExpr->type = elementType; - - expr = derefExpr; - continue; - } - - // Default case: just use the expression as-is - return expr; - } - } - - RefPtr CheckSwizzleExpr( - MemberExpr* memberRefExpr, - RefPtr baseElementType, - IntegerLiteralValue baseElementCount) - { - RefPtr swizExpr = new SwizzleExpr(); - swizExpr->loc = memberRefExpr->loc; - swizExpr->base = memberRefExpr->BaseExpression; - - IntegerLiteralValue limitElement = baseElementCount; - - int elementIndices[4]; - int elementCount = 0; - - bool elementUsed[4] = { false, false, false, false }; - bool anyDuplicates = false; - bool anyError = false; - - auto swizzleText = getText(memberRefExpr->name); - - for (Index i = 0; i < swizzleText.getLength(); i++) - { - auto ch = swizzleText[i]; - int elementIndex = -1; - switch (ch) - { - case 'x': case 'r': elementIndex = 0; break; - case 'y': case 'g': elementIndex = 1; break; - case 'z': case 'b': elementIndex = 2; break; - case 'w': case 'a': elementIndex = 3; break; - default: - // An invalid character in the swizzle is an error - getSink()->diagnose(swizExpr, Diagnostics::invalidSwizzleExpr, swizzleText, baseElementType->ToString()); - anyError = true; - continue; - } - - // TODO(tfoley): GLSL requires that all component names - // come from the same "family"... - - // Make sure the index is in range for the source type - if (elementIndex >= limitElement) - { - getSink()->diagnose(swizExpr, Diagnostics::invalidSwizzleExpr, swizzleText, baseElementType->ToString()); - anyError = true; - continue; - } - - // Check if we've seen this index before - for (int ee = 0; ee < elementCount; ee++) - { - if (elementIndices[ee] == elementIndex) - anyDuplicates = true; - } - - // add to our list... - elementIndices[elementCount++] = elementIndex; - } - - for (int ee = 0; ee < elementCount; ++ee) - { - swizExpr->elementIndices[ee] = elementIndices[ee]; - } - swizExpr->elementCount = elementCount; - - if (anyError) - { - return CreateErrorExpr(memberRefExpr); - } - else if (elementCount == 1) - { - // single-component swizzle produces a scalar - // - // Note(tfoley): the official HLSL rules seem to be that it produces - // a one-component vector, which is then implicitly convertible to - // a scalar, but that seems like it just adds complexity. - swizExpr->type = QualType(baseElementType); - } - else - { - // TODO(tfoley): would be nice to "re-sugar" type - // here if the input type had a sugared name... - swizExpr->type = QualType(createVectorType( - baseElementType, - new ConstantIntVal(elementCount))); - } - - // A swizzle can be used as an l-value as long as there - // were no duplicates in the list of components - swizExpr->type.IsLeftValue = !anyDuplicates; - - return swizExpr; - } - - RefPtr CheckSwizzleExpr( - MemberExpr* memberRefExpr, - RefPtr baseElementType, - RefPtr baseElementCount) - { - if (auto constantElementCount = as(baseElementCount)) - { - return CheckSwizzleExpr(memberRefExpr, baseElementType, constantElementCount->value); - } - else - { - getSink()->diagnose(memberRefExpr, Diagnostics::unimplemented, "swizzle on vector of unknown size"); - return CreateErrorExpr(memberRefExpr); - } - } - - // Look up a static member - // @param expr Can be StaticMemberExpr or MemberExpr - // @param baseExpression Is the underlying type expression determined from resolving expr - RefPtr _lookupStaticMember(RefPtr expr, RefPtr baseExpression) - { - auto& baseType = baseExpression->type; - - if (auto typeType = as(baseType)) - { - // We are looking up a member inside a type. - // We want to be careful here because we should only find members - // that are implicitly or explicitly `static`. - // - // TODO: this duplicates a *lot* of logic with the case below. - // We need to fix that. - auto type = typeType->type; - - if (as(type)) - { - return CreateErrorExpr(expr); - } - - LookupResult lookupResult = lookUpMember( - getSession(), - this, - expr->name, - type); - if (!lookupResult.isValid()) - { - return lookupMemberResultFailure(expr, baseType); - } - - // We need to confirm that whatever member we - // are trying to refer to is usable via static reference. - // - // TODO: eventually we might allow a non-static - // member to be adapted by turning it into something - // like a closure that takes the missing `this` parameter. - // - // E.g., a static reference to a method could be treated - // as a value with a function type, where the first parameter - // is `type`. - // - // The biggest challenge there is that we'd need to arrange - // to generate "dispatcher" functions that could be used - // to implement that function, in the case where we are - // making a static reference to some kind of polymorphic declaration. - // - // (Also, static references to fields/properties would get even - // harder, because you'd have to know whether a getter/setter/ref-er - // is needed). - // - // For now let's just be expedient and disallow all of that, because - // we can always add it back in later. - - if (!lookupResult.isOverloaded()) - { - // The non-overloaded case is relatively easy. We just want - // to look at the member being referenced, and check if - // it is allowed in a `static` context: - // - if (!isUsableAsStaticMember(lookupResult.item)) - { - getSink()->diagnose( - expr->loc, - Diagnostics::staticRefToNonStaticMember, - type, - expr->name); - } - } - else - { - // The overloaded case is trickier, because we should first - // filter the list of candidates, because if there is anything - // that *is* usable in a static context, then we should assume - // the user just wants to reference that. We should only - // issue an error if *all* of the items that were discovered - // are non-static. - bool anyNonStatic = false; - List staticItems; - for (auto item : lookupResult.items) - { - // Is this item usable as a static member? - if (isUsableAsStaticMember(item)) - { - // If yes, then it will be part of the output. - staticItems.add(item); - } - else - { - // If no, then we might need to output an error. - anyNonStatic = true; - } - } - - // Was there anything non-static in the list? - if (anyNonStatic) - { - // If we had some static items, then that's okay, - // we just want to use our newly-filtered list. - if (staticItems.getCount()) - { - lookupResult.items = staticItems; - } - else - { - // Otherwise, it is time to report an error. - getSink()->diagnose( - expr->loc, - Diagnostics::staticRefToNonStaticMember, - type, - expr->name); - } - } - // If there were no non-static items, then the `items` - // array already represents what we'd get by filtering... - } - - return createLookupResultExpr( - lookupResult, - baseExpression, - expr->loc); - } - else if (as(baseType)) - { - return CreateErrorExpr(expr); - } - - // Failure - return lookupMemberResultFailure(expr, baseType); - } - - RefPtr visitStaticMemberExpr(StaticMemberExpr* expr) - { - expr->BaseExpression = CheckExpr(expr->BaseExpression); - - // Not sure this is needed -> but guess someone could do - expr->BaseExpression = MaybeDereference(expr->BaseExpression); - - // If the base of the member lookup has an interface type - // *without* a suitable this-type substitution, then we are - // trying to perform lookup on a value of existential type, - // and we should "open" the existential here so that we - // can expose its structure. - // - - expr->BaseExpression = maybeOpenExistential(expr->BaseExpression); - // Do a static lookup - return _lookupStaticMember(expr, expr->BaseExpression); - } - - RefPtr lookupMemberResultFailure( - DeclRefExpr* expr, - QualType const& baseType) - { - // Check it's a member expression - SLANG_ASSERT(as(expr) || as(expr)); - - getSink()->diagnose(expr, Diagnostics::noMemberOfNameInType, expr->name, baseType); - expr->type = QualType(getSession()->getErrorType()); - return expr; - } - - RefPtr visitMemberExpr(MemberExpr * expr) - { - expr->BaseExpression = CheckExpr(expr->BaseExpression); - - expr->BaseExpression = MaybeDereference(expr->BaseExpression); - - // If the base of the member lookup has an interface type - // *without* a suitable this-type substitution, then we are - // trying to perform lookup on a value of existential type, - // and we should "open" the existential here so that we - // can expose its structure. - // - expr->BaseExpression = maybeOpenExistential(expr->BaseExpression); - - auto & baseType = expr->BaseExpression->type; - - // Note: Checking for vector types before declaration-reference types, - // because vectors are also declaration reference types... - // - // Also note: the way this is done right now means that the ability - // to swizzle vectors interferes with any chance of looking up - // members via extension, for vector or scalar types. - // - // TODO: Matrix swizzles probably need to be handled at some point. - if (auto baseVecType = as(baseType)) - { - return CheckSwizzleExpr( - expr, - baseVecType->elementType, - baseVecType->elementCount); - } - else if(auto baseScalarType = as(baseType)) - { - // Treat scalar like a 1-element vector when swizzling - return CheckSwizzleExpr( - expr, - baseScalarType, - 1); - } - else if(auto typeType = as(baseType)) - { - return _lookupStaticMember(expr, expr->BaseExpression); - } - else if (as(baseType)) - { - return CreateErrorExpr(expr); - } - else - { - LookupResult lookupResult = lookUpMember( - getSession(), - this, - expr->name, - baseType.Ptr()); - if (!lookupResult.isValid()) - { - return lookupMemberResultFailure(expr, baseType); - } - - // TODO: need to filter for declarations that are valid to refer - // to in this context... - - return createLookupResultExpr( - lookupResult, - expr->BaseExpression, - expr->loc); - } - } - SemanticsVisitor & operator = (const SemanticsVisitor &) = delete; - - - // - - RefPtr visitInitializerListExpr(InitializerListExpr* expr) - { - // When faced with an initializer list, we first just check the sub-expressions blindly. - // Actually making them conform to a desired type will wait for when we know the desired - // type based on context. - - for( auto& arg : expr->args ) - { - arg = CheckTerm(arg); - } - - expr->type = getSession()->getInitializerListType(); - - return expr; - } - - void importModuleIntoScope(Scope* scope, ModuleDecl* moduleDecl) - { - // If we've imported this one already, then - // skip the step where we modify the current scope. - if (importedModules.Contains(moduleDecl)) - { - return; - } - importedModules.Add(moduleDecl); - - - // Create a new sub-scope to wire the module - // into our lookup chain. - auto subScope = new Scope(); - subScope->containerDecl = moduleDecl; - - subScope->nextSibling = scope->nextSibling; - scope->nextSibling = subScope; - - // Also import any modules from nested `import` declarations - // with the `__exported` modifier - for (auto importDecl : moduleDecl->getMembersOfType()) - { - if (!importDecl->HasModifier()) - continue; - - importModuleIntoScope(scope, importDecl->importedModuleDecl.Ptr()); - } - } - - void visitEmptyDecl(EmptyDecl* /*decl*/) - { - // nothing to do - } - - void visitImportDecl(ImportDecl* decl) - { - if(decl->IsChecked(DeclCheckState::CheckedHeader)) - return; - - // We need to look for a module with the specified name - // (whether it has already been loaded, or needs to - // be loaded), and then put its declarations into - // the current scope. - - auto name = decl->moduleNameAndLoc.name; - auto scope = decl->scope; - - // Try to load a module matching the name - auto importedModule = findOrImportModule( - getLinkage(), - name, - decl->moduleNameAndLoc.loc, - getSink()); - - // If we didn't find a matching module, then bail out - if (!importedModule) - return; - - // Record the module that was imported, so that we can use - // it later during code generation. - auto importedModuleDecl = importedModule->getModuleDecl(); - decl->importedModuleDecl = importedModuleDecl; - - // Add the declarations from the imported module into the scope - // that the `import` declaration is set to extend. - // - importModuleIntoScope(scope.Ptr(), importedModuleDecl); - - // Record the `import`ed module (and everything it depends on) - // as a dependency of the module we are compiling. - if(auto module = getModule(decl)) - { - module->addModuleDependency(importedModule); - } - - decl->SetCheckState(getCheckedState()); - } - - // Perform semantic checking of an object-oriented `this` - // expression. - RefPtr visitThisExpr(ThisExpr* expr) - { - // A `this` expression will default to immutable. - expr->type.IsLeftValue = false; - - // We will do an upwards search starting in the current - // scope, looking for a surrounding type (or `extension`) - // declaration that could be the referrant of the expression. - auto scope = expr->scope; - while (scope) - { - auto containerDecl = scope->containerDecl; - - if( auto funcDeclBase = as(containerDecl) ) - { - if( funcDeclBase->HasModifier() ) - { - expr->type.IsLeftValue = true; - } - } - else if (auto aggTypeDecl = as(containerDecl)) - { - checkDecl(aggTypeDecl); - - // Okay, we are using `this` in the context of an - // aggregate type, so the expression should be - // of the corresponding type. - expr->type.type = DeclRefType::Create( - getSession(), - makeDeclRef(aggTypeDecl)); - return expr; - } - else if (auto extensionDecl = as(containerDecl)) - { - checkDecl(extensionDecl); - - // When `this` is used in the context of an `extension` - // declaration, then it should refer to an instance of - // the type being extended. - // - // TODO: There is potentially a small gotcha here that - // lookup through such a `this` expression should probably - // prioritize members declared in the current extension - // if there are multiple extensions in scope that add - // members with the same name... - // - expr->type.type = extensionDecl->targetType.type; - return expr; - } - - scope = scope->parent; - } - - getSink()->diagnose(expr, Diagnostics::thisExpressionOutsideOfTypeDecl); - return CreateErrorExpr(expr); - } - }; - - bool isPrimaryDecl( - CallableDecl* decl) - { - SLANG_ASSERT(decl); - return (!decl->primaryDecl) || (decl == decl->primaryDecl); - } - - RefPtr checkProperType( - Linkage* linkage, - TypeExp typeExp, - DiagnosticSink* sink) - { - SemanticsVisitor visitor( - linkage, - sink); - auto typeOut = visitor.CheckProperType(typeExp); - return typeOut.type; - } - - - FuncDecl* findFunctionDeclByName( - Module* translationUnit, - Name* name, - DiagnosticSink* sink) - { - auto translationUnitSyntax = translationUnit->getModuleDecl(); - - // Make sure we've got a query-able member dictionary - buildMemberDictionary(translationUnitSyntax); - - // We will look up any global-scope declarations in the translation - // unit that match the name of our entry point. - Decl* firstDeclWithName = nullptr; - if (!translationUnitSyntax->memberDictionary.TryGetValue(name, firstDeclWithName)) - { - // If there doesn't appear to be any such declaration, then we are done. - - sink->diagnose(translationUnitSyntax, Diagnostics::entryPointFunctionNotFound, name); - - return nullptr; - } - - // We found at least one global-scope declaration with the right name, - // but (1) it might not be a function, and (2) there might be - // more than one function. - // - // We'll walk the linked list of declarations with the same name, - // to see what we find. Along the way we'll keep track of the - // first function declaration we find, if any: - FuncDecl* entryPointFuncDecl = nullptr; - for (auto ee = firstDeclWithName; ee; ee = ee->nextInContainerWithSameName) - { - // Is this declaration a function? - if (auto funcDecl = as(ee)) - { - // Skip non-primary declarations, so that - // we don't give an error when an entry - // point is forward-declared. - if (!isPrimaryDecl(funcDecl)) - continue; - - // is this the first one we've seen? - if (!entryPointFuncDecl) - { - // If so, this is a candidate to be - // the entry point function. - entryPointFuncDecl = funcDecl; - } - else - { - // Uh-oh! We've already seen a function declaration with this - // name before, so the whole thing is ambiguous. We need - // to diagnose and bail out. - - sink->diagnose(translationUnitSyntax, Diagnostics::ambiguousEntryPoint, name); - - // List all of the declarations that the user *might* mean - for (auto ff = firstDeclWithName; ff; ff = ff->nextInContainerWithSameName) - { - if (auto candidate = as(ff)) - { - sink->diagnose(candidate, Diagnostics::entryPointCandidate, candidate->getName()); - } - } - - // Bail out. - return nullptr; - } - } - } - - return entryPointFuncDecl; - } - - static bool isValidThreadDispatchIDType(Type* type) - { - // Can accept a single int/unit - { - auto basicType = as(type); - if (basicType) - { - return (basicType->baseType == BaseType::Int || basicType->baseType == BaseType::UInt); - } - } - // Can be an int/uint vector from size 1 to 3 - { - auto vectorType = as(type); - if (!vectorType) - { - return false; - } - auto elemCount = as(vectorType->elementCount); - if (elemCount->value < 1 || elemCount->value > 3) - { - return false; - } - // Must be a basic type - auto basicType = as(vectorType->elementType); - if (!basicType) - { - return false; - } - - // Must be integral - return (basicType->baseType == BaseType::Int || basicType->baseType == BaseType::UInt); - } - } - - /// Recursively walk `paramDeclRef` and add any required existential slots to `ioSlots`. - static void _collectExistentialTypeParamsRec( - ExistentialTypeSlots& ioSlots, - DeclRef paramDeclRef); - - /// Recursively walk `type` and discover any required existential type parameters. - static void _collectExistentialTypeParamsRec( - ExistentialTypeSlots& ioSlots, - Type* type) - { - // Whether or not something is an array does not affect - // the number of existential slots it introduces. - // - while( auto arrayType = as(type) ) - { - type = arrayType->baseType; - } - - if( auto parameterGroupType = as(type) ) - { - _collectExistentialTypeParamsRec(ioSlots, parameterGroupType->getElementType()); - return; - } - - if( auto declRefType = as(type) ) - { - auto typeDeclRef = declRefType->declRef; - if( auto interfaceDeclRef = typeDeclRef.as() ) - { - // Each leaf parameter of interface type adds one slot. - // - ioSlots.paramTypes.add(type); - } - else if( auto structDeclRef = typeDeclRef.as() ) - { - // A structure type should recursively introduce - // existential slots for its fields. - // - for( auto fieldDeclRef : GetFields(structDeclRef) ) - { - if(fieldDeclRef.getDecl()->HasModifier()) - continue; - - _collectExistentialTypeParamsRec(ioSlots, fieldDeclRef); - } - } - } - - // TODO: We eventually need to handle cases like constant - // buffers and parameter blocks that may have existential - // element types. - } - - static void _collectExistentialTypeParamsRec( - ExistentialTypeSlots& ioSlots, - DeclRef paramDeclRef) - { - _collectExistentialTypeParamsRec(ioSlots, GetType(paramDeclRef)); - } - - - /// Add information about a shader parameter to `ioParams` and `ioSlots` - static void _collectExistentialSlotsForShaderParam( - ShaderParamInfo& ioParamInfo, - ExistentialTypeSlots& ioSlots, - DeclRef paramDeclRef) - { - Index startSlot = ioSlots.paramTypes.getCount(); - _collectExistentialTypeParamsRec(ioSlots, paramDeclRef); - Index endSlot = ioSlots.paramTypes.getCount(); - - ioParamInfo.firstExistentialTypeSlot = UInt(startSlot); - ioParamInfo.existentialTypeSlotCount = UInt(endSlot - startSlot);; - } - - /// Enumerate the existential-type parameters of an `EntryPoint`. - /// - /// Any parameters found will be added to the list of existential slots on `this`. - /// - void EntryPoint::_collectShaderParams() - { - // Note: we defensively test whether there is a function decl-ref - // because this routine gets called from the constructor, and - // a "dummy" entry point will have a null pointer for the function. - // - if( auto funcDeclRef = getFuncDeclRef() ) - { - for( auto paramDeclRef : GetParameters(funcDeclRef) ) - { - ShaderParamInfo shaderParamInfo; - shaderParamInfo.paramDeclRef = paramDeclRef; - - _collectExistentialSlotsForShaderParam( - shaderParamInfo, - m_existentialSlots, - paramDeclRef); - - m_shaderParams.add(shaderParamInfo); - } - } - } - - // Validate that an entry point function conforms to any additional - // constraints based on the stage (and profile?) it specifies. - void validateEntryPoint( - EntryPoint* entryPoint, - DiagnosticSink* sink) - { - auto entryPointFuncDecl = entryPoint->getFuncDecl(); - auto stage = entryPoint->getStage(); - - // TODO: We currently do minimal checking here, but this is the - // right place to perform the following validation checks: - // - - // * Are the function input/output parameters and result type - // all valid for the chosen stage? (e.g., there shouldn't be - // an `OutputStream` type in a vertex shader signature) - // - // * For any varying input/output, are there semantics specified - // (Note: this potentially overlaps with layout logic...), and - // are the system-value semantics valid for the given stage? - // - // There's actually a lot of detail to semantic checking, in - // that the AST-level code should probably be validating the - // use of system-value semantics by linking them to explicit - // declarations in the standard library. We should also be - // using profile information on those declarations to infer - // appropriate profile restrictions on the entry point. - // - // * Is the entry point actually usable on the given stage/profile? - // E.g., if we have a vertex shader that (transitively) calls - // `Texture2D.Sample`, then that should produce an error because - // that function is specific to the fragment profile/stage. - // - - auto entryPointName = entryPointFuncDecl->getName(); - - auto module = getModule(entryPointFuncDecl); - auto linkage = module->getLinkage(); - - - // Every entry point needs to have a stage specified either via - // command-line/API options, or via an explicit `[shader("...")]` attribute. - // - if( stage == Stage::Unknown ) - { - sink->diagnose(entryPointFuncDecl, Diagnostics::entryPointHasNoStage, entryPointName); - } - - if( stage == Stage::Hull ) - { - // TODO: We could consider *always* checking any `[patchconsantfunc("...")]` - // attributes, so that they need to resolve to a function. - - auto attr = entryPointFuncDecl->FindModifier(); - - if (attr) - { - if (attr->args.getCount() != 1) - { - sink->diagnose(attr, Diagnostics::badlyDefinedPatchConstantFunc, entryPointName); - return; - } - - Expr* expr = attr->args[0]; - StringLiteralExpr* stringLit = as(expr); - - if (!stringLit) - { - sink->diagnose(expr, Diagnostics::badlyDefinedPatchConstantFunc, entryPointName); - return; - } - - // We look up the patch-constant function by its name in the module - // scope of the translation unit that declared the HS entry point. - // - // TODO: Eventually we probably want to do the lookup in the scope - // of the parent declarations of the entry point. E.g., if the entry - // point is a member function of a `struct`, then its patch-constant - // function should be allowed to be another member function of - // the same `struct`. - // - // In the extremely long run we may want to support an alternative to - // this attribute-based linkage between the two functions that - // make up the entry point. - // - Name* name = linkage->getNamePool()->getName(stringLit->value); - FuncDecl* patchConstantFuncDecl = findFunctionDeclByName( - module, - name, - sink); - if (!patchConstantFuncDecl) - { - sink->diagnose(expr, Diagnostics::attributeFunctionNotFound, name, "patchconstantfunc"); - return; - } - - attr->patchConstantFuncDecl = patchConstantFuncDecl; - } - } - else if(stage == Stage::Compute) - { - for(const auto& param : entryPointFuncDecl->GetParameters()) - { - if(auto semantic = param->FindModifier()) - { - const auto& semanticToken = semantic->name; - - String lowerName = String(semanticToken.Content).toLower(); - - if(lowerName == "sv_dispatchthreadid") - { - Type* paramType = param->getType(); - - if(!isValidThreadDispatchIDType(paramType)) - { - String typeString = paramType->ToString(); - sink->diagnose(param->loc, Diagnostics::invalidDispatchThreadIDType, typeString); - return; - } - } - } - } - } - } - - // Given an entry point specified via API or command line options, - // attempt to find a matching AST declaration that implements the specified - // entry point. If such a function is found, then validate that it actually - // meets the requirements for the selected stage/profile. - // - // Returns an `EntryPoint` object representing the (unspecialized) - // entry point if it is found and validated, and null otherwise. - // - RefPtr findAndValidateEntryPoint( - FrontEndEntryPointRequest* entryPointReq) - { - // The first step in validating the entry point is to find - // the (unique) function declaration that matches its name. - // - // TODO: We may eventually want/need to extend this to - // account for nested names like `SomeStruct.vsMain`, or - // indeed even to handle generics. - // - auto compileRequest = entryPointReq->getCompileRequest(); - auto translationUnit = entryPointReq->getTranslationUnit(); - auto sink = compileRequest->getSink(); - auto translationUnitSyntax = translationUnit->getModuleDecl(); - - auto entryPointName = entryPointReq->getName(); - - // Make sure we've got a query-able member dictionary - buildMemberDictionary(translationUnitSyntax); - - // We will look up any global-scope declarations in the translation - // unit that match the name of our entry point. - Decl* firstDeclWithName = nullptr; - if( !translationUnitSyntax->memberDictionary.TryGetValue(entryPointName, firstDeclWithName) ) - { - // If there doesn't appear to be any such declaration, then - // we need to diagnose it as an error, and then bail out. - sink->diagnose(translationUnitSyntax, Diagnostics::entryPointFunctionNotFound, entryPointName); - return nullptr; - } - - // We found at least one global-scope declaration with the right name, - // but (1) it might not be a function, and (2) there might be - // more than one function. - // - // We'll walk the linked list of declarations with the same name, - // to see what we find. Along the way we'll keep track of the - // first function declaration we find, if any: - // - FuncDecl* entryPointFuncDecl = nullptr; - for(auto ee = firstDeclWithName; ee; ee = ee->nextInContainerWithSameName) - { - // We want to support the case where the declaration is - // a generic function, so we will automatically - // unwrap any outer `GenericDecl` we find here. - // - auto decl = ee; - if(auto genericDecl = as(decl)) - decl = genericDecl->inner; - - // Is this declaration a function? - if (auto funcDecl = as(decl)) - { - // Skip non-primary declarations, so that - // we don't give an error when an entry - // point is forward-declared. - if (!isPrimaryDecl(funcDecl)) - continue; - - // is this the first one we've seen? - if (!entryPointFuncDecl) - { - // If so, this is a candidate to be - // the entry point function. - entryPointFuncDecl = funcDecl; - } - else - { - // Uh-oh! We've already seen a function declaration with this - // name before, so the whole thing is ambiguous. We need - // to diagnose and bail out. - - sink->diagnose(translationUnitSyntax, Diagnostics::ambiguousEntryPoint, entryPointName); - - // List all of the declarations that the user *might* mean - for (auto ff = firstDeclWithName; ff; ff = ff->nextInContainerWithSameName) - { - if (auto candidate = as(ff)) - { - sink->diagnose(candidate, Diagnostics::entryPointCandidate, candidate->getName()); - } - } - - // Bail out. - return nullptr; - } - } - } - - // Did we find a function declaration in our search? - if(!entryPointFuncDecl) - { - // If not, then we need to diagnose the error. - // For convenience, we will point to the first - // declaration with the right name, that wasn't a function. - sink->diagnose(firstDeclWithName, Diagnostics::entryPointSymbolNotAFunction, entryPointName); - return nullptr; - } - - // TODO: it is possible that the entry point was declared with - // profile or target overloading. Is there anything that we need - // to do at this point to filter out declarations that aren't - // relevant to the selected profile for the entry point? - - // We found something, and can start doing some basic checking. - // - // If the entry point specifies a stage via a `[shader("...")]` attribute, - // then we might be able to infer a stage for the entry point request if - // it didn't have one, *or* issue a diagnostic if there is a mismatch. - // - auto entryPointProfile = entryPointReq->getProfile(); - if( auto entryPointAttribute = entryPointFuncDecl->FindModifier() ) - { - auto entryPointStage = entryPointProfile.GetStage(); - if( entryPointStage == Stage::Unknown ) - { - entryPointProfile.setStage(entryPointAttribute->stage); - } - else if( entryPointAttribute->stage != entryPointStage ) - { - sink->diagnose(entryPointFuncDecl, Diagnostics::specifiedStageDoesntMatchAttribute, entryPointName, entryPointStage, entryPointAttribute->stage); - } - } - else - { - // TODO: Should we attach a `[shader(...)]` attribute to an - // entry point that didn't have one, so that we can have - // a more uniform representation in the AST? - } - - - RefPtr entryPoint = EntryPoint::create( - makeDeclRef(entryPointFuncDecl), - entryPointProfile); - - // Now that we've *found* the entry point, it is time to validate - // that it actually meets the constraints for the chosen stage/profile. - // - validateEntryPoint(entryPoint, sink); - - return entryPoint; - } - - /// Get the name a variable will use for reflection purposes -Name* getReflectionName(VarDeclBase* varDecl) -{ - if (auto reflectionNameModifier = varDecl->FindModifier()) - return reflectionNameModifier->nameAndLoc.name; - - return varDecl->getName(); -} - -// Information tracked when doing a structural -// match of types. -struct StructuralTypeMatchStack -{ - DeclRef leftDecl; - DeclRef rightDecl; - StructuralTypeMatchStack* parent; -}; - -static void diagnoseParameterTypeMismatch( - DiagnosticSink* sink, - StructuralTypeMatchStack* inStack) -{ - SLANG_ASSERT(inStack); - - // The bottom-most entry in the stack should represent - // the shader parameters that kicked things off - auto stack = inStack; - while(stack->parent) - stack = stack->parent; - - sink->diagnose(stack->leftDecl, Diagnostics::shaderParameterDeclarationsDontMatch, getReflectionName(stack->leftDecl)); - sink->diagnose(stack->rightDecl, Diagnostics::seeOtherDeclarationOf, getReflectionName(stack->rightDecl)); -} - -// Two types that were expected to match did not. -// Inform the user with a suitable message. -static void diagnoseTypeMismatch( - DiagnosticSink* sink, - StructuralTypeMatchStack* inStack) -{ - auto stack = inStack; - SLANG_ASSERT(stack); - diagnoseParameterTypeMismatch(sink, stack); - - auto leftType = GetType(stack->leftDecl); - auto rightType = GetType(stack->rightDecl); - - if( stack->parent ) - { - sink->diagnose(stack->leftDecl, Diagnostics::fieldTypeMisMatch, getReflectionName(stack->leftDecl), leftType, rightType); - sink->diagnose(stack->rightDecl, Diagnostics::seeOtherDeclarationOf, getReflectionName(stack->rightDecl)); - - stack = stack->parent; - if( stack ) - { - while( stack->parent ) - { - sink->diagnose(stack->leftDecl, Diagnostics::usedInDeclarationOf, getReflectionName(stack->leftDecl)); - stack = stack->parent; - } - } - } - else - { - sink->diagnose(stack->leftDecl, Diagnostics::shaderParameterTypeMismatch, leftType, rightType); - } -} - -// Two types that were expected to match did not. -// Inform the user with a suitable message. -static void diagnoseTypeFieldsMismatch( - DiagnosticSink* sink, - DeclRef const& left, - DeclRef const& right, - StructuralTypeMatchStack* stack) -{ - diagnoseParameterTypeMismatch(sink, stack); - - sink->diagnose(left, Diagnostics::fieldDeclarationsDontMatch, left.GetName()); - sink->diagnose(right, Diagnostics::seeOtherDeclarationOf, right.GetName()); - - if( stack ) - { - while( stack->parent ) - { - sink->diagnose(stack->leftDecl, Diagnostics::usedInDeclarationOf, getReflectionName(stack->leftDecl)); - stack = stack->parent; - } - } -} - -static void collectFields( - DeclRef declRef, - List>& outFields) -{ - for( auto fieldDeclRef : getMembersOfType(declRef) ) - { - if(fieldDeclRef.getDecl()->HasModifier()) - continue; - - outFields.add(fieldDeclRef); - } -} - -static bool validateTypesMatch( - DiagnosticSink* sink, - Type* left, - Type* right, - StructuralTypeMatchStack* stack); - -static bool validateIntValuesMatch( - DiagnosticSink* sink, - IntVal* left, - IntVal* right, - StructuralTypeMatchStack* stack) -{ - if(left->EqualsVal(right)) - return true; - - // TODO: are there other cases we need to handle here? - - diagnoseTypeMismatch(sink, stack); - return false; -} - - -static bool validateValuesMatch( - DiagnosticSink* sink, - Val* left, - Val* right, - StructuralTypeMatchStack* stack) -{ - if( auto leftType = dynamicCast(left) ) - { - if( auto rightType = dynamicCast(right) ) - { - return validateTypesMatch(sink, leftType, rightType, stack); - } - } - - if( auto leftInt = dynamicCast(left) ) - { - if( auto rightInt = dynamicCast(right) ) - { - return validateIntValuesMatch(sink, leftInt, rightInt, stack); - } - } - - if( auto leftWitness = dynamicCast(left) ) - { - if( auto rightWitness = dynamicCast(right) ) - { - return true; - } - } - - diagnoseTypeMismatch(sink, stack); - return false; -} - -static bool validateGenericSubstitutionsMatch( - DiagnosticSink* sink, - GenericSubstitution* left, - GenericSubstitution* right, - StructuralTypeMatchStack* stack) -{ - if( !left ) - { - if( !right ) - { - return true; - } - - diagnoseTypeMismatch(sink, stack); - return false; - } - - - - Index argCount = left->args.getCount(); - if( argCount != right->args.getCount() ) - { - diagnoseTypeMismatch(sink, stack); - return false; - } - - for( Index aa = 0; aa < argCount; ++aa ) - { - auto leftArg = left->args[aa]; - auto rightArg = right->args[aa]; - - if(!validateValuesMatch(sink, leftArg, rightArg, stack)) - return false; - } - - return true; -} - -static bool validateThisTypeSubstitutionsMatch( - DiagnosticSink* /*sink*/, - ThisTypeSubstitution* /*left*/, - ThisTypeSubstitution* /*right*/, - StructuralTypeMatchStack* /*stack*/) -{ - // TODO: actual checking. - return true; -} - -static bool validateSpecializationsMatch( - DiagnosticSink* sink, - SubstitutionSet left, - SubstitutionSet right, - StructuralTypeMatchStack* stack) -{ - auto ll = left.substitutions; - auto rr = right.substitutions; - for(;;) - { - // Skip any global generic substitutions. - if(auto leftGlobalGeneric = as(ll)) - { - ll = leftGlobalGeneric->outer; - continue; - } - if(auto rightGlobalGeneric = as(rr)) - { - rr = rightGlobalGeneric->outer; - continue; - } - - // If either ran out, then we expect both to have run out. - if(!ll || !rr) - return !ll && !rr; - - auto leftSubst = ll; - auto rightSubst = rr; - - ll = ll->outer; - rr = rr->outer; - - if(auto leftGeneric = as(leftSubst)) - { - if(auto rightGeneric = as(rightSubst)) - { - if(validateGenericSubstitutionsMatch(sink, leftGeneric, rightGeneric, stack)) - { - continue; - } - } - } - else if(auto leftThisType = as(leftSubst)) - { - if(auto rightThisType = as(rightSubst)) - { - if(validateThisTypeSubstitutionsMatch(sink, leftThisType, rightThisType, stack)) - { - continue; - } - } - } - - return false; - } - - return true; -} - -// Determine if two types "match" for the purposes of `cbuffer` layout rules. -// -static bool validateTypesMatch( - DiagnosticSink* sink, - Type* left, - Type* right, - StructuralTypeMatchStack* stack) -{ - if(left->Equals(right)) - return true; - - // It is possible that the types don't match exactly, but - // they *do* match structurally. - - // Note: the following code will lead to infinite recursion if there - // are ever recursive types. We'd need a more refined system to - // cache the matches we've already found. - - if( auto leftDeclRefType = as(left) ) - { - if( auto rightDeclRefType = as(right) ) - { - // Are they references to matching decl refs? - auto leftDeclRef = leftDeclRefType->declRef; - auto rightDeclRef = rightDeclRefType->declRef; - - // Do the reference the same declaration? Or declarations - // with the same name? - // - // TODO: we should only consider the same-name case if the - // declarations come from translation units being compiled - // (and not an imported module). - if( leftDeclRef.getDecl() == rightDeclRef.getDecl() - || leftDeclRef.GetName() == rightDeclRef.GetName() ) - { - // Check that any generic arguments match - if( !validateSpecializationsMatch( - sink, - leftDeclRef.substitutions, - rightDeclRef.substitutions, - stack) ) - { - return false; - } - - // Check that any declared fields match too. - if( auto leftStructDeclRef = leftDeclRef.as() ) - { - if( auto rightStructDeclRef = rightDeclRef.as() ) - { - List> leftFields; - List> rightFields; - - collectFields(leftStructDeclRef, leftFields); - collectFields(rightStructDeclRef, rightFields); - - Index leftFieldCount = leftFields.getCount(); - Index rightFieldCount = rightFields.getCount(); - - if( leftFieldCount != rightFieldCount ) - { - diagnoseTypeFieldsMismatch(sink, leftDeclRef, rightDeclRef, stack); - return false; - } - - for( Index ii = 0; ii < leftFieldCount; ++ii ) - { - auto leftField = leftFields[ii]; - auto rightField = rightFields[ii]; - - if( leftField.GetName() != rightField.GetName() ) - { - diagnoseTypeFieldsMismatch(sink, leftDeclRef, rightDeclRef, stack); - return false; - } - - auto leftFieldType = GetType(leftField); - auto rightFieldType = GetType(rightField); - - StructuralTypeMatchStack subStack; - subStack.parent = stack; - subStack.leftDecl = leftField; - subStack.rightDecl = rightField; - - if(!validateTypesMatch(sink, leftFieldType,rightFieldType, &subStack)) - return false; - } - } - } - - // Everything seemed to match recursively. - return true; - } - } - } - - // If we are looking at `T[N]` and `U[M]` we want to check that - // `T` is structurally equivalent to `U` and `N` is the same as `M`. - else if( auto leftArrayType = as(left) ) - { - if( auto rightArrayType = as(right) ) - { - if(!validateTypesMatch(sink, leftArrayType->baseType, rightArrayType->baseType, stack) ) - return false; - - if(!validateValuesMatch(sink, leftArrayType->ArrayLength, rightArrayType->ArrayLength, stack)) - return false; - - return true; - } - } - - diagnoseTypeMismatch(sink, stack); - return false; -} - -// This function is supposed to determine if two global shader -// parameter declarations represent the same logical parameter -// (so that they should get the exact same binding(s) allocated). -// -static bool doesParameterMatch( - DiagnosticSink* sink, - DeclRef varDeclRef, - DeclRef existingVarDeclRef) -{ - StructuralTypeMatchStack stack; - stack.parent = nullptr; - stack.leftDecl = varDeclRef; - stack.rightDecl = existingVarDeclRef; - - validateTypesMatch(sink, GetType(varDeclRef), GetType(existingVarDeclRef), &stack); - - return true; -} - - - - - /// Enumerate the existential-type parameters of a `Program`. - /// - /// Any parameters found will be added to the list of existential slots on `this`. - /// - void Program::_collectShaderParams(DiagnosticSink* sink) - { - // We need to collect all of the global shader parameters - // referenced by the compile request, and for each we - // need to do a few things: - // - // * We need to determine if the parameter is a duplicate/redeclaration - // of the "same" parameter in another translation unit, and collapse - // those into one logical shader parameter if so. - // - // * We need to determine what existential type slots are introduced - // by the parameter, and associate that information with the parameter. - // - // To deal with the first issue, we will maintain a map from a parameter - // name to the index of an existing parameter with that name. - // - Dictionary mapNameToParamIndex; - - for( auto module : getModuleDependencies() ) - { - auto moduleDecl = module->getModuleDecl(); - for( auto globalVar : moduleDecl->getMembersOfType() ) - { - if(!isGlobalShaderParameter(globalVar)) - continue; - - // This declaration may represent the same logical parameter - // as a declaration that came from a different translation unit. - // If that is the case, we want to re-use the same `ShaderParamInfo` - // across both parameters. - // - // TODO: This logic currently detects *any* global-scope parameters - // with matching names, but it should eventually be narrowly - // scoped so that it only applies to parameters from unnamed modules - // (that is, modules that represent directly-compiled shader files - // and not `import`ed code). - // - // First we look for an existing entry matching the name - // of this parameter: - // - auto paramName = getReflectionName(globalVar); - Int existingParamIndex = -1; - if( mapNameToParamIndex.TryGetValue(paramName, existingParamIndex) ) - { - // If the parameters have the same name, but don't "match" according to some reasonable rules, - // then we will treat them as distinct global parameters. - // - // Note: all of the mismatch cases currently report errors, so that - // compilation will fail on a mismatch. - // - auto& existingParam = m_shaderParams[existingParamIndex]; - if( doesParameterMatch(sink, makeDeclRef(globalVar.Ptr()), existingParam.paramDeclRef) ) - { - // If we hit this case, then we had a match, and we should - // consider the new variable to be a redclaration of - // the existing one. - - existingParam.additionalParamDeclRefs.add( - makeDeclRef(globalVar.Ptr())); - continue; - } - } - - Int newParamIndex = Int(m_shaderParams.getCount()); - mapNameToParamIndex.Add(paramName, newParamIndex); - - GlobalShaderParamInfo shaderParamInfo; - shaderParamInfo.paramDeclRef = makeDeclRef(globalVar.Ptr()); - - _collectExistentialSlotsForShaderParam( - shaderParamInfo, - m_globalExistentialSlots, - makeDeclRef(globalVar.Ptr())); - - m_shaderParams.add(shaderParamInfo); - } - } - } - - /// Create a `Program` to represent the compiled code. - /// - /// The created program will comprise all of the translation - /// units that were compiled as part of the request, as - /// well as any entry points in those translation units. - /// - RefPtr createUnspecializedProgram( - FrontEndCompileRequest* compileRequest) - { - // We want our resulting program to depend on - // all the translation units the user specified, - // even if some of them don't contain entry points - // (this is important for parameter layout/binding). - // - // We also want to ensure that the modules for the - // translation units comes first in the enumerated - // order for dependencies, to match the pre-existing - // compiler behavior (at least for now). - // - auto linkage = compileRequest->getLinkage(); - auto sink = compileRequest->getSink(); - auto program = new Program(linkage); - for(auto translationUnit : compileRequest->translationUnits ) - { - program->addReferencedLeafModule(translationUnit->getModule()); - } - for(auto translationUnit : compileRequest->translationUnits ) - { - program->addReferencedModule(translationUnit->getModule()); - } - - - // The validation of entry points here will be modal, and controlled - // by whether the user specified any entry points directly via - // API or command-line options. - // - // TODO: We may want to make this choice explicit rather than implicit. - // - // First, check if the user requested any entry points explicitly via - // the API or command line. - // - bool anyExplicitEntryPoints = compileRequest->getEntryPointReqCount() != 0; - - if( anyExplicitEntryPoints ) - { - // If there were any explicit requests for entry points to be - // checked, then we will *only* check those. - // - for(auto entryPointReq : compileRequest->getEntryPointReqs()) - { - auto entryPoint = findAndValidateEntryPoint( - entryPointReq); - if( entryPoint ) - { - program->addEntryPoint(entryPoint); - entryPointReq->getTranslationUnit()->entryPoints.add(entryPoint); - } - } - - // TODO: We should consider always processing both categories, - // and just making sure to only check each entry point function - // declaration once... - } - else - { - // Otherwise, scan for any `[shader(...)]` attributes in - // the user's code, and construct `EntryPoint`s to - // represent them. - // - // This ensures that downstream code only has to consider - // the central list of entry point requests, and doesn't - // have to know where they came from. - - // TODO: A comprehensive approach here would need to search - // recursively for entry points, because they might appear - // as, e.g., member function of a `struct` type. - // - // For now we'll start with an extremely basic approach that - // should work for typical HLSL code. - // - Index translationUnitCount = compileRequest->translationUnits.getCount(); - for(Index tt = 0; tt < translationUnitCount; ++tt) - { - auto translationUnit = compileRequest->translationUnits[tt]; - for( auto globalDecl : translationUnit->getModuleDecl()->Members ) - { - auto maybeFuncDecl = globalDecl; - if( auto genericDecl = as(maybeFuncDecl) ) - { - maybeFuncDecl = genericDecl->inner; - } - - auto funcDecl = as(maybeFuncDecl); - if(!funcDecl) - continue; - - auto entryPointAttr = funcDecl->FindModifier(); - if(!entryPointAttr) - continue; - - // We've discovered a valid entry point. It is a function (possibly - // generic) that has a `[shader(...)]` attribute to mark it as an - // entry point. - // - // We will now register that entry point as an `EntryPoint` - // with an appropriately chosen profile. - // - // The profile will only include a stage, so that the profile "family" - // and "version" are left unspecified. Downstream code will need - // to be able to handle this case. - // - Profile profile; - profile.setStage(entryPointAttr->stage); - - RefPtr entryPoint = EntryPoint::create( - makeDeclRef(funcDecl), - profile); - - validateEntryPoint(entryPoint, sink); - - program->addEntryPoint(entryPoint); - translationUnit->entryPoints.add(entryPoint); - } - } - } - - program->_collectShaderParams(sink); - - return program; - } - - static void _specializeExistentialTypeParams( - Linkage* linkage, - ExistentialTypeSlots& ioSlots, - List> const& args, - DiagnosticSink* sink) - { - Index slotCount = ioSlots.paramTypes.getCount(); - Index argCount = args.getCount(); - - if( slotCount != argCount ) - { - sink->diagnose(SourceLoc(), Diagnostics::mismatchExistentialSlotArgCount, slotCount, argCount); - return; - } - - SemanticsVisitor visitor(linkage, sink); - - for( Index ii = 0; ii < slotCount; ++ii ) - { - auto slotType = ioSlots.paramTypes[ii]; - auto argExpr = args[ii]; - - auto argType = checkProperType(linkage, TypeExp(argExpr), sink); - if(!argType) - { - // TODO: Each slot should track a source location and/or a `VarDeclBase` - // that names the parameter that the slot corresponds to. - - sink->diagnose(SourceLoc(), Diagnostics::existentialSlotArgNotAType, ii); - return; - } - - - auto witness = visitor.tryGetSubtypeWitness(argType, slotType); - if (!witness) - { - // If no witness was found, then we will be unable to satisfy - // the conformances required. - sink->diagnose(SourceLoc(), Diagnostics::existentialSlotArgDoesNotConform, ii, slotType); - return; - } - - ExistentialTypeSlots::Arg arg; - arg.type = argType; - arg.witness = witness; - ioSlots.args.add(arg); - } - } - - void EntryPoint::_specializeExistentialTypeParams( - List> const& args, - DiagnosticSink* sink) - { - Slang::_specializeExistentialTypeParams(getLinkage(), m_existentialSlots, args, sink); - } - - /// Create a specialization an existing entry point based on generic arguments. - RefPtr createSpecializedEntryPoint( - EntryPoint* unspecializedEntryPoint, - List> const& genericArgs, - List> const& existentialArgs, - DiagnosticSink* sink) - { - auto linkage = unspecializedEntryPoint->getLinkage(); - - // TODO: Need to be careful in case entry point already has a decl-ref, - // pertaining to outer specializations (e.g., when entry point was - // nested in a generic type. - // - auto entryPointFuncDecl = unspecializedEntryPoint->getFuncDecl(); - - SemanticsVisitor semantics( - linkage, - sink); - - DeclRef entryPointFuncDeclRef = makeDeclRef(entryPointFuncDecl.Ptr()); - if( auto genericDecl = as(entryPointFuncDecl->ParentDecl) ) - { - // We will construct a suitable `GenericAppExpr` to represent - // the user-specified `genericDecl` being applied to the - // supplied `genericArgs`, and then use the existing - // semantic checking logic that would apply to an explicit - // generic application like `F` if it were - // encountered in the source code. - - auto session = linkage->getSession(); - auto genericDeclRef = makeDeclRef(genericDecl); - - // The first pieces is a `VarExpr` that refers to `genericDecl`. - // - // TODO: This would not be needed if we instead parsed - // the supplied entry-point name into an expression - // earlier in this function. - // - RefPtr genericExpr = new VarExpr(); - genericExpr->declRef = genericDeclRef; - genericExpr->type.type = getTypeForDeclRef(session, genericDeclRef); - - // Next we construct the actual `GenericAppExpr` - // - RefPtr genericAppExpr = new GenericAppExpr(); - genericAppExpr->FunctionExpr = genericExpr; - genericAppExpr->Arguments = genericArgs; - - // We use the semantics visitor to perform the - // actual checking logic (this might report - // errors) - // - auto checkedExpr = semantics.checkGenericAppWithCheckedArgs(genericAppExpr); - - // Now we need to extract an appropriate decl-ref for the entry - // point from the `checkedExpr`. - // - if( auto declRefExpr = checkedExpr.as() ) - { - // TODO: We should eventually check for the case - // where we have a `MemberExpr` or another case of - // `DeclRefExpr` that cannot be summarized as just - // its decl-ref. - // - // The basic `VarExpr` and `StaticMemberExpr` cases - // should be allow-able. - - entryPointFuncDeclRef = declRefExpr->declRef.as(); - } - else if( semantics.IsErrorExpr(checkedExpr) ) - { - // Any semantic error that occured should have been - // reported already. - return nullptr; - } - else - { - // The result of specializing a reference to a generic - // function should always be a `DeclRefExpr` - // - SLANG_UNEXPECTED("reference to generic decl wasn't a `DeclRefExpr`"); - UNREACHABLE_RETURN(nullptr); - } - } - - RefPtr specializedEntryPoint = EntryPoint::create( - entryPointFuncDeclRef, - unspecializedEntryPoint->getProfile()); - - // Next we need to validate the existential arguments. - specializedEntryPoint->_specializeExistentialTypeParams(existentialArgs, sink); - - return specializedEntryPoint; - } - - /// Parse an array of strings as generic arguments. - /// - /// Names in the strings will be parsed in the context of - /// the code loaded into the given compile request. - /// - void parseGenericArgStrings( - EndToEndCompileRequest* endToEndReq, - List const& genericArgStrings, - List>& outGenericArgs) - { - auto unspecialiedProgram = endToEndReq->getUnspecializedProgram(); - - // TODO: Building a list of `scopesToTry` here shouldn't - // be required, since the `Scope` type itself has the ability - // for form chains for lookup purposes (e.g., the way that - // `import` is handled by modifying a scope). - // - List> scopesToTry; - for( auto module : unspecialiedProgram->getModuleDependencies() ) - scopesToTry.add(module->getModuleDecl()->scope); - - // We are going to do some semantic checking, so we need to - // set up a `SemanticsVistitor` that we can use. - // - auto linkage = endToEndReq->getLinkage(); - auto sink = endToEndReq->getSink(); - SemanticsVisitor semantics( - linkage, - sink); - - // We will be looping over the generic argument strings - // that the user provided via the API (or command line), - // and parsing+checking each into an `Expr`. - // - // This loop will *not* handle coercing the arguments - // to be types. - // - for(auto name : genericArgStrings) - { - RefPtr argExpr; - for (auto & s : scopesToTry) - { - argExpr = linkage->parseTypeString(name, s); - argExpr = semantics.CheckTerm(argExpr); - if( argExpr ) - { - break; - } - } - - outGenericArgs.add(argExpr); - } - } - - void Program::_specializeExistentialTypeParams( - List> const& args, - DiagnosticSink* sink) - { - Slang::_specializeExistentialTypeParams(getLinkage(), m_globalExistentialSlots, args, sink); - } - - Type* Linkage::specializeType( - Type* unspecializedType, - Int argCount, - Type* const* args, - DiagnosticSink* sink) - { - // TODO: We should cache and re-use specialized types - // when the exact same arguments are provided again later. - - SemanticsVisitor visitor(this, sink); - - - ExistentialTypeSlots slots; - _collectExistentialTypeParamsRec(slots, unspecializedType); - - assert(slots.paramTypes.getCount() == argCount); - - for( Int aa = 0; aa < argCount; ++aa ) - { - auto argType = args[aa]; - - ExistentialTypeSlots::Arg arg; - arg.type = argType; - arg.witness = visitor.tryGetSubtypeWitness(argType, slots.paramTypes[aa]); - slots.args.add(arg); - } - - RefPtr specializedType = new ExistentialSpecializedType(); - specializedType->baseType = unspecializedType; - specializedType->slots = slots; - - m_specializedTypes.add(specializedType); - - return specializedType; - } - - /// Specialize a program to global generic arguments - RefPtr createSpecializedProgram( - Linkage* linkage, - Program* unspecializedProgram, - List> const& globalGenericArgs, - List> const& globalExistentialArgs, - DiagnosticSink* sink) - { - // The given `unspecializedProgram` should be one that - // was checked through the front-end, so that now we - // only need to check if the given arguments can satisfy - // the requirements of the global generic parameters. - // - // The new program needs to start off with the same - // module dependency list as the original. - // - RefPtr specializedProgram = new Program(linkage); - for(auto module : unspecializedProgram->getModuleDependencies()) - { - specializedProgram->addReferencedLeafModule(module); - } - - - // We will collect all the global generic parameters - // defined in the modules being referenced, to find - // the global generic parameter signature of the - // program. - // - // TODO: Note that this doesn't handle the case where one - // or more of the type *arguments* that we are specifying - // ends up requiring additional modules to be referenced, - // which might in turn introduce new global generic parameters. - // - List> globalGenericParams; - for(auto module : unspecializedProgram->getModuleDependencies()) - { - for(auto param : module->getModuleDecl()->getMembersOfType()) - globalGenericParams.add(param); - } - - // Next, we will check whether the supplied arguments can - // satisfy those parameters. - // - // An easy early-out case will be if the number of - // arguments isn't correct. - // - if (globalGenericParams.getCount() != globalGenericArgs.getCount()) - { - sink->diagnose(SourceLoc(), Diagnostics::mismatchGlobalGenericArguments, - globalGenericParams.getCount(), - globalGenericArgs.getCount()); - return nullptr; - } - - // We have an appropriate number of arguments for the global generic parameters, - // and now we need to check that the arguments conform to the declared constraints. - // - SemanticsVisitor visitor(linkage, sink); - - // Along the way, we will build up an appropriate set of substitutions to represent - // the generic arguments and their conformances. - // - RefPtr globalGenericSubsts; - auto globalGenericSubstLink = &globalGenericSubsts; - // - // TODO: There is a serious flaw to this checking logic if we ever have cases where - // the constraints on one `type_param` can depend on another `type_param`, e.g.: - // - // type_param A; - // type_param B : ISidekick; - // - // In that case, if a user tries to set `B` to `Robin` and `Robin` conforms to - // `ISidekick`, then the compiler needs to know whether `A` is being - // set to `Batman` to know whether the setting for `B` is valid. In this limit - // the constraints can be mutually recursive (so `A : IMentor`). - // - // The only way to check things correctly is to validate each conformance under - // a set of assumptions (substitutions) that includes all the type substitutions, - // and possibly also all the other constraints *except* the one to be validated. - // - // We will punt on this for now, and just check each constraint in isolation. - // - Index argCounter = 0; - for(auto& globalGenericParam : globalGenericParams) - { - // Get the argument that matches this parameter. - Index argIndex = argCounter++; - SLANG_ASSERT(argIndex < globalGenericArgs.getCount()); - auto globalGenericArg = checkProperType(linkage, TypeExp(globalGenericArgs[argIndex]), sink); - if (!globalGenericArg) - { - sink->diagnose(globalGenericParam, Diagnostics::globalGenericArgumentNotAType, globalGenericParam->getName()); - return nullptr; - } - - // As a quick sanity check, see if the argument that is being supplied for a parameter - // is just the parameter itself, because this should always be an error: - // - if( auto argDeclRefType = globalGenericArg.as() ) - { - auto argDeclRef = argDeclRefType->declRef; - if(auto argGenericParamDeclRef = argDeclRef.as()) - { - if(argGenericParamDeclRef.getDecl() == globalGenericParam) - { - // We are trying to specialize a generic parameter using itself. - sink->diagnose(globalGenericParam, - Diagnostics::cannotSpecializeGlobalGenericToItself, - globalGenericParam->getName()); - continue; - } - else - { - // We are trying to specialize a generic parameter using a *different* - // global generic type parameter. - sink->diagnose(globalGenericParam, - Diagnostics::cannotSpecializeGlobalGenericToAnotherGenericParam, - globalGenericParam->getName(), - argGenericParamDeclRef.GetName()); - continue; - } - } - } - - // Create a substitution for this parameter/argument. - RefPtr subst = new GlobalGenericParamSubstitution(); - subst->paramDecl = globalGenericParam; - subst->actualType = globalGenericArg; - - // Walk through the declared constraints for the parameter, - // and check that the argument actually satisfies them. - for(auto constraint : globalGenericParam->getMembersOfType()) - { - // Get the type that the constraint is enforcing conformance to - auto interfaceType = GetSup(DeclRef(constraint, nullptr)); - - // Use our semantic-checking logic to search for a witness to the required conformance - auto witness = visitor.tryGetSubtypeWitness(globalGenericArg, interfaceType); - if (!witness) - { - // If no witness was found, then we will be unable to satisfy - // the conformances required. - sink->diagnose(globalGenericParam, - Diagnostics::typeArgumentDoesNotConformToInterface, - globalGenericParam->nameAndLoc.name, - globalGenericArg, - interfaceType); - } - - // Attach the concrete witness for this conformance to the - // substutiton - GlobalGenericParamSubstitution::ConstraintArg constraintArg; - constraintArg.decl = constraint; - constraintArg.val = witness; - subst->constraintArgs.add(constraintArg); - } - - // Add the substitution for this parameter to the global substitution - // set that we are building. - - *globalGenericSubstLink = subst; - globalGenericSubstLink = &subst->outer; - } - if(sink->GetErrorCount()) - return nullptr; - - specializedProgram->setGlobalGenericSubsitution(globalGenericSubsts); - - // Now deal with the shader parameters and existential arguments - // - // Note: We should in theory be able to just copy over the shader - // parameters and existential slot information from the unspecialized - // program. This could save some time, but it would also mean that - // the only way to create a specialized program is by creating an - // unspecialized on first, which is maybe not always desirable. - // - specializedProgram->_collectShaderParams(sink); - specializedProgram->_specializeExistentialTypeParams(globalExistentialArgs, sink); - - return specializedProgram; - } - - /// Specialize an entry point that was checked by the front-end, based on generic arguments. - /// - /// If the end-to-end compile request included generic argument strings - /// for this entry point, then they will be parsed, checked, and used - /// as arguments to the generic entry point. - /// - /// Returns a specialized entry point if everything worked as expected. - /// Returns null and diagnoses errors if anything goes wrong. - /// - RefPtr createSpecializedEntryPoint( - EndToEndCompileRequest* endToEndReq, - EntryPoint* unspecializedEntryPoint, - EndToEndCompileRequest::EntryPointInfo const& entryPointInfo) - { - auto sink = endToEndReq->getSink(); - auto entryPointFuncDecl = unspecializedEntryPoint->getFuncDecl(); - - // If the user specified generic arguments for the entry point, - // then we will need to parse the arguments first. - // - List> genericArgs; - parseGenericArgStrings( - endToEndReq, - entryPointInfo.genericArgStrings, - genericArgs); - - List> existentialArgs; - parseGenericArgStrings( - endToEndReq, - entryPointInfo.existentialArgStrings, - existentialArgs); - - // Next we specialize the entry point function given the parsed - // generic argument expressions. - // - auto entryPoint = createSpecializedEntryPoint( - unspecializedEntryPoint, - genericArgs, - existentialArgs, - sink); - - return entryPoint; - } - - /// Create a specialized program based on the given compile request. - /// - RefPtr createSpecializedProgram( - EndToEndCompileRequest* endToEndReq) - { - // The compile request must have already completed front-end processing, - // so that we have an unspecialized program available, and now only need - // to parse and check any generic arguments that are being supplied for - // global or entry-point generic parameters. - // - auto unspecializedProgram = endToEndReq->getUnspecializedProgram(); - - // First, let's parse the generic argument strings that were - // provided via the API, so taht we can match them - // against what was declared in the program. - // - List> globalGenericArgs; - parseGenericArgStrings( - endToEndReq, - endToEndReq->globalGenericArgStrings, - globalGenericArgs); - - // Also handle global existential type arguments. - List> globalExistentialArgs; - parseGenericArgStrings( - endToEndReq, - endToEndReq->globalExistentialSlotArgStrings, - globalExistentialArgs); - - // Now we create the initial specialized program by - // applying the global generic arguments (if any) to the - // unspecialized program. - // - auto specializedProgram = createSpecializedProgram( - endToEndReq->getLinkage(), - unspecializedProgram, - globalGenericArgs, - globalExistentialArgs, - endToEndReq->getSink()); - - // If anything went wrong with the global generic - // arguments, then bail out now. - // - if(!specializedProgram) - return nullptr; - - // Next we will deal with the entry points for the - // new specialized program. - // - // If the user specified explicit entry points as part of the - // end-to-end request, then we only want to process those (and - // ignore any other `[shader(...)]`-attributed entry points). - // - // However, if the user specified *no* entry points as part - // of the end-to-end request, then we would like to go - // ahead and consider all the entry points that were found - // by the front-end. - // - Index entryPointCount = endToEndReq->entryPoints.getCount(); - if( entryPointCount == 0 ) - { - entryPointCount = unspecializedProgram->getEntryPointCount(); - endToEndReq->entryPoints.setCount(entryPointCount); - } - - for( Index ii = 0; ii < entryPointCount; ++ii ) - { - auto unspecializedEntryPoint = unspecializedProgram->getEntryPoint(ii); - auto& entryPointInfo = endToEndReq->entryPoints[ii]; - - auto specializedEntryPoint = createSpecializedEntryPoint(endToEndReq, unspecializedEntryPoint, entryPointInfo); - specializedProgram->addEntryPoint(specializedEntryPoint); - } - - return specializedProgram; - } - - void checkTranslationUnit( - TranslationUnitRequest* translationUnit) - { - SemanticsVisitor visitor( - translationUnit->compileRequest->getLinkage(), - translationUnit->compileRequest->getSink()); - - // Apply the visitor to do the main semantic - // checking that is required on all declarations - // in the translation unit. - visitor.checkDecl(translationUnit->getModuleDecl()); - } - - - // - - // Get the type to use when referencing a declaration - QualType getTypeForDeclRef( - Session* session, - SemanticsVisitor* sema, - DiagnosticSink* sink, - DeclRef declRef, - RefPtr* outTypeResult) - { - if( sema ) - { - sema->checkDecl(declRef.getDecl()); - } - - // We need to insert an appropriate type for the expression, based on - // what we found. - if (auto varDeclRef = declRef.as()) - { - QualType qualType; - qualType.type = GetType(varDeclRef); - - bool isLValue = true; - if(varDeclRef.getDecl()->FindModifier()) - isLValue = false; - - // Global-scope shader parameters should not be writable, - // since they are effectively program inputs. - // - // TODO: We could eventually treat a mutable global shader - // parameter as a shorthand for an immutable parameter and - // a global variable that gets initialized from that parameter, - // but in order to do so we'd need to support global variables - // with resource types better in the back-end. - // - if(isGlobalShaderParameter(varDeclRef.getDecl())) - isLValue = false; - - // Variables declared with `let` are always immutable. - if(varDeclRef.is()) - isLValue = false; - - // Generic value parameters are always immutable - if(varDeclRef.is()) - isLValue = false; - - // Function parameters declared in the "modern" style - // are immutable unless they have an `out` or `inout` modifier. - if(varDeclRef.is()) - { - // Note: the `inout` modifier AST class inherits from - // the class for the `out` modifier so that we can - // make simple checks like this. - // - if( !varDeclRef.getDecl()->HasModifier() ) - { - isLValue = false; - } - } - - qualType.IsLeftValue = isLValue; - return qualType; - } - else if( auto enumCaseDeclRef = declRef.as() ) - { - QualType qualType; - qualType.type = getType(enumCaseDeclRef); - qualType.IsLeftValue = false; - return qualType; - } - else if (auto typeAliasDeclRef = declRef.as()) - { - auto type = getNamedType(session, typeAliasDeclRef); - *outTypeResult = type; - return QualType(getTypeType(type)); - } - else if (auto aggTypeDeclRef = declRef.as()) - { - auto type = DeclRefType::Create(session, aggTypeDeclRef); - *outTypeResult = type; - return QualType(getTypeType(type)); - } - else if (auto simpleTypeDeclRef = declRef.as()) - { - auto type = DeclRefType::Create(session, simpleTypeDeclRef); - *outTypeResult = type; - return QualType(getTypeType(type)); - } - else if (auto genericDeclRef = declRef.as()) - { - auto type = getGenericDeclRefType(session, genericDeclRef); - *outTypeResult = type; - return QualType(getTypeType(type)); - } - else if (auto funcDeclRef = declRef.as()) - { - auto type = getFuncType(session, funcDeclRef); - return QualType(type); - } - else if (auto constraintDeclRef = declRef.as()) - { - // When we access a constraint or an inheritance decl (as a member), - // we are conceptually performing a "cast" to the given super-type, - // with the declaration showing that such a cast is legal. - auto type = GetSup(constraintDeclRef); - return QualType(type); - } - if( sink ) - { - sink->diagnose(declRef, Diagnostics::unimplemented, "cannot form reference to this kind of declaration"); - } - return QualType(session->getErrorType()); - } - - QualType getTypeForDeclRef( - Session* session, - DeclRef declRef) - { - RefPtr typeResult; - return getTypeForDeclRef(session, nullptr, nullptr, declRef, &typeResult); - } - - DeclRef ApplyExtensionToType( - SemanticsVisitor* semantics, - ExtensionDecl* extDecl, - RefPtr type) - { - if(!semantics) - return DeclRef(); - - return semantics->ApplyExtensionToType(extDecl, type); - } - - RefPtr createDefaultSubsitutionsForGeneric( - Session* session, - GenericDecl* genericDecl, - RefPtr outerSubst) - { - RefPtr genericSubst = new GenericSubstitution(); - genericSubst->genericDecl = genericDecl; - genericSubst->outer = outerSubst; - - for( auto mm : genericDecl->Members ) - { - if( auto genericTypeParamDecl = as(mm) ) - { - genericSubst->args.add(DeclRefType::Create(session, DeclRef(genericTypeParamDecl, outerSubst))); - } - else if( auto genericValueParamDecl = as(mm) ) - { - genericSubst->args.add(new GenericParamIntVal(DeclRef(genericValueParamDecl, outerSubst))); - } - } - - // create default substitution arguments for constraints - for (auto mm : genericDecl->Members) - { - if (auto genericTypeConstraintDecl = as(mm)) - { - RefPtr witness = new DeclaredSubtypeWitness(); - witness->declRef = DeclRef(genericTypeConstraintDecl, outerSubst); - witness->sub = genericTypeConstraintDecl->sub.type; - witness->sup = genericTypeConstraintDecl->sup.type; - genericSubst->args.add(witness); - } - } - - return genericSubst; - } - - // Sometimes we need to refer to a declaration the way that it would be specialized - // inside the context where it is declared (e.g., with generic parameters filled in - // using their archetypes). - // - SubstitutionSet createDefaultSubstitutions( - Session* session, - Decl* decl, - SubstitutionSet outerSubstSet) - { - auto dd = decl->ParentDecl; - if( auto genericDecl = as(dd) ) - { - // We don't want to specialize references to anything - // other than the "inner" declaration itself. - if(decl != genericDecl->inner) - return outerSubstSet; - - RefPtr genericSubst = createDefaultSubsitutionsForGeneric( - session, - genericDecl, - outerSubstSet.substitutions); - - return SubstitutionSet(genericSubst); - } - - return outerSubstSet; - } - - SubstitutionSet createDefaultSubstitutions( - Session* session, - Decl* decl) - { - SubstitutionSet subst; - if( auto parentDecl = decl->ParentDecl ) - { - subst = createDefaultSubstitutions(session, parentDecl); - } - subst = createDefaultSubstitutions(session, decl, subst); - return subst; - } - - void checkDecl(SemanticsVisitor* visitor, Decl* decl) - { - visitor->checkDecl(decl); - } -} diff --git a/source/slang/check.h b/source/slang/check.h deleted file mode 100644 index 1f378ec7b..000000000 --- a/source/slang/check.h +++ /dev/null @@ -1,7 +0,0 @@ -// check.h -#pragma once - -namespace Slang -{ - bool isGlobalShaderParameter(VarDeclBase* decl); -} \ No newline at end of file diff --git a/source/slang/compiler.cpp b/source/slang/compiler.cpp deleted file mode 100644 index 22c1d4cd8..000000000 --- a/source/slang/compiler.cpp +++ /dev/null @@ -1,1645 +0,0 @@ -// Compiler.cpp : Defines the entry point for the console application. -// -#include "../core/basic.h" -#include "../core/platform.h" -#include "../core/slang-io.h" -#include "../core/slang-string-util.h" - -#include "compiler.h" -#include "lexer.h" -#include "lower-to-ir.h" -#include "parameter-binding.h" -#include "parser.h" -#include "preprocessor.h" -#include "syntax-visitors.h" -#include "type-layout.h" -#include "reflection.h" -#include "emit.h" - -// Enable calling through to `fxc` or `dxc` to -// generate code on Windows. -#ifdef _WIN32 - #define WIN32_LEAN_AND_MEAN - #define NOMINMAX - #include - #undef WIN32_LEAN_AND_MEAN - #undef NOMINMAX - #include - #ifndef SLANG_ENABLE_DXBC_SUPPORT - #define SLANG_ENABLE_DXBC_SUPPORT 1 - #endif - #ifndef SLANG_ENABLE_DXIL_SUPPORT - #define SLANG_ENABLE_DXIL_SUPPORT 1 - #endif -#endif -// -// Otherwise, don't enable DXBC/DXIL by default: -#ifndef SLANG_ENABLE_DXBC_SUPPORT - #define SLANG_ENABLE_DXBC_SUPPORT 0 -#endif -#ifndef SLANG_ENABLE_DXIL_SUPPORT - #define SLANG_ENABLE_DXIL_SUPPORT 0 -#endif - -// Enable calling through to `glslang` on -// all platforms. -#ifndef SLANG_ENABLE_GLSLANG_SUPPORT - #define SLANG_ENABLE_GLSLANG_SUPPORT 1 -#endif - -#if SLANG_ENABLE_GLSLANG_SUPPORT -#include "../slang-glslang/slang-glslang.h" -#endif - -// Includes to allow us to control console -// output when writing assembly dumps. -#include -#ifdef _WIN32 -#include -#else -#include -#endif - -#ifdef _MSC_VER -#pragma warning(disable: 4996) -#endif - -#ifdef CreateDirectory -#undef CreateDirectory -#endif - -namespace Slang -{ - - // CompileResult - - void CompileResult::append(CompileResult const& result) - { - // Find which to append to - ResultFormat appendTo = ResultFormat::None; - - if (format == ResultFormat::None) - { - format = result.format; - appendTo = result.format; - } - else if (format == result.format) - { - appendTo = format; - } - - if (appendTo == ResultFormat::Text) - { - outputString.append(result.outputString.getBuffer()); - } - else if (appendTo == ResultFormat::Binary) - { - outputBinary.addRange(result.outputBinary.getBuffer(), result.outputBinary.getCount()); - } - } - - ComPtr CompileResult::getBlob() - { - if(!blob) - { - switch(format) - { - case ResultFormat::None: - default: - break; - - case ResultFormat::Text: - blob = StringUtil::createStringBlob(outputString); - break; - - case ResultFormat::Binary: - blob = createRawBlob(outputBinary.getBuffer(), outputBinary.getCount()); - break; - } - } - return blob; - } - - // - // FrontEndEntryPointRequest - // - - FrontEndEntryPointRequest::FrontEndEntryPointRequest( - FrontEndCompileRequest* compileRequest, - int translationUnitIndex, - Name* name, - Profile profile) - : m_compileRequest(compileRequest) - , m_translationUnitIndex(translationUnitIndex) - , m_name(name) - , m_profile(profile) - {} - - - TranslationUnitRequest* FrontEndEntryPointRequest::getTranslationUnit() - { - return getCompileRequest()->translationUnits[m_translationUnitIndex]; - } - - // - // EntryPoint - // - - RefPtr EntryPoint::create( - DeclRef funcDeclRef, - Profile profile) - { - RefPtr entryPoint = new EntryPoint( - funcDeclRef.GetName(), - profile, - funcDeclRef); - return entryPoint; - } - - RefPtr EntryPoint::createDummyForPassThrough( - Name* name, - Profile profile) - { - RefPtr entryPoint = new EntryPoint( - name, - profile, - DeclRef()); - return entryPoint; - } - - EntryPoint::EntryPoint( - Name* name, - Profile profile, - DeclRef funcDeclRef) - : m_name(name) - , m_profile(profile) - , m_funcDeclRef(funcDeclRef) - { - // In order for later code generation to work, we need to track what - // modules each entry point depends on. We will build up the dependency - // list here when an `EntryPoint` gets created. - // - // We know an entry point depends on the module that declared the - // entry-point function itself. - // - // Note: we are carefully handling the case where `module` could - // be null, becase of "dummy" entry points created for pass-through - // compilation. - // - if(auto module = getModule()) - { - m_dependencyList.addDependency(module); - } - // - // TODO: We also need to include the modules needed by any generic - // arguments in the dependency list, since in the general case they - // might come from modules other than the one defining the entry point. - - // The following is a bit of a hack. - // - // Back-end code generation relies on us having computed layouts for all tagged - // unions that end up being used in the code, which means we need a way to find - // all such types that get used in a program (and the stuff it imports). - // - // For now we are assuming a tagged union type only comes into existence - // as a (top-level) argument for a generic type parameter, so that we - // can check for them here and cache them on the entry point. - // - // A longer-term strategy might need to consider any (tagged or untagged) - // union types that get used inside of a module, and also take - // those lists into account. - // - // An even longer-term strategy would be to allow type layout to - // be performed on IR types, so taht we don't need to have front-end - // code worrying about this stuff. - // - for( auto subst = funcDeclRef.substitutions.substitutions; subst; subst = subst->outer ) - { - if( auto genericSubst = as(subst) ) - { - for( auto arg : genericSubst->args ) - { - if( auto taggedUnionType = as(arg) ) - { - m_taggedUnionTypes.add(taggedUnionType); - } - } - } - } - - // Collect any existential-type parameters used by the entry point - // - _collectShaderParams(); - } - - Module* EntryPoint::getModule() - { - return Slang::getModule(getFuncDecl()); - } - - Linkage* EntryPoint::getLinkage() - { - return getModule()->getLinkage(); - } - - // - - Profile Profile::LookUp(char const* name) - { - #define PROFILE(TAG, NAME, STAGE, VERSION) if(strcmp(name, #NAME) == 0) return Profile::TAG; - #define PROFILE_ALIAS(TAG, DEF, NAME) if(strcmp(name, #NAME) == 0) return Profile::TAG; - #include "profile-defs.h" - - return Profile::Unknown; - } - - char const* Profile::getName() - { - switch( raw ) - { - default: - return "unknown"; - - #define PROFILE(TAG, NAME, STAGE, VERSION) case Profile::TAG: return #NAME; - #define PROFILE_ALIAS(TAG, DEF, NAME) /* empty */ - #include "profile-defs.h" - } - } - - Stage findStageByName(String const& name) - { - static const struct - { - char const* name; - Stage stage; - } kStages[] = - { - #define PROFILE_STAGE(ID, NAME, ENUM) \ - { #NAME, Stage::ID }, - - #define PROFILE_STAGE_ALIAS(ID, NAME, VAL) \ - { #NAME, Stage::ID }, - - #include "profile-defs.h" - }; - - for(auto entry : kStages) - { - if(name == entry.name) - { - return entry.stage; - } - } - - return Stage::Unknown; - } - - SlangResult checkExternalCompilerSupport(Session* session, PassThroughMode passThrough) - { - switch (passThrough) - { - case PassThroughMode::None: - { - // If no pass through -> that will always work! - return SLANG_OK; - } - case PassThroughMode::dxc: - { -#if SLANG_ENABLE_DXIL_SUPPORT - // Must have dxc - return session->getOrLoadSharedLibrary(SharedLibraryType::Dxc, nullptr) ? SLANG_OK : SLANG_E_NOT_FOUND; -#endif - break; - } - case PassThroughMode::fxc: - { -#if SLANG_ENABLE_DXBC_SUPPORT - // Must have fxc - return session->getOrLoadSharedLibrary(SharedLibraryType::Fxc, nullptr) ? SLANG_OK : SLANG_E_NOT_FOUND; -#endif - break; - } - case PassThroughMode::glslang: - { -#if SLANG_ENABLE_GLSLANG_SUPPORT - return session->getOrLoadSharedLibrary(Slang::SharedLibraryType::Glslang, nullptr) ? SLANG_OK : SLANG_E_NOT_FOUND; -#endif - break; - } - } - return SLANG_E_NOT_IMPLEMENTED; - } - - static PassThroughMode _getExternalCompilerRequiredForTarget(CodeGenTarget target) - { - switch (target) - { - case CodeGenTarget::None: - { - return PassThroughMode::None; - } - case CodeGenTarget::GLSL: - case CodeGenTarget::GLSL_Vulkan: - case CodeGenTarget::GLSL_Vulkan_OneDesc: - { - // Can always output GLSL - return PassThroughMode::None; - } - case CodeGenTarget::HLSL: - { - // Can always output HLSL - return PassThroughMode::None; - } - case CodeGenTarget::SPIRVAssembly: - case CodeGenTarget::SPIRV: - { - return PassThroughMode::glslang; - } - case CodeGenTarget::DXBytecode: - case CodeGenTarget::DXBytecodeAssembly: - { - return PassThroughMode::fxc; - } - case CodeGenTarget::DXIL: - case CodeGenTarget::DXILAssembly: - { - return PassThroughMode::dxc; - } - case CodeGenTarget::CPPSource: - case CodeGenTarget::CSource: - { - // Don't need an external compiler to output C and C++ code - return PassThroughMode::None; - } - - default: break; - } - - SLANG_ASSERT(!"Unhandled target"); - return PassThroughMode::None; - } - - SlangResult checkCompileTargetSupport(Session* session, CodeGenTarget target) - { - const PassThroughMode mode = _getExternalCompilerRequiredForTarget(target); - return (mode != PassThroughMode::None) ? - checkExternalCompilerSupport(session, mode) : - SLANG_OK; - } - - // - - /// If there is a pass-through compile going on, find the translation unit for the given entry point. - TranslationUnitRequest* findPassThroughTranslationUnit( - EndToEndCompileRequest* endToEndReq, - Int entryPointIndex) - { - // If there isn't an end-to-end compile going on, - // there can be no pass-through. - // - if(!endToEndReq) return nullptr; - - // And if pass-through isn't set, we don't need - // access to the translation unit. - // - if(endToEndReq->passThrough == PassThroughMode::None) return nullptr; - - auto frontEndReq = endToEndReq->getFrontEndReq(); - auto entryPointReq = frontEndReq->getEntryPointReq(entryPointIndex); - auto translationUnit = entryPointReq->getTranslationUnit(); - return translationUnit; - } - - String emitHLSLForEntryPoint( - BackEndCompileRequest* compileRequest, - EntryPoint* entryPoint, - Int entryPointIndex, - TargetRequest* targetReq, - EndToEndCompileRequest* endToEndReq) - { - if(auto translationUnit = findPassThroughTranslationUnit(endToEndReq, entryPointIndex)) - { - // Generate a string that includes the content of - // the source file(s), along with a line directive - // to ensure that we get reasonable messages - // from the downstream compiler when in pass-through - // mode. - - StringBuilder codeBuilder; - for(auto sourceFile : translationUnit->getSourceFiles()) - { - codeBuilder << "#line 1 \""; - - const String& path = sourceFile->getPathInfo().foundPath; - - for(auto c : path) - { - char buffer[] = { c, 0 }; - switch(c) - { - default: - codeBuilder << buffer; - break; - - case '\\': - codeBuilder << "\\\\"; - } - } - codeBuilder << "\"\n"; - - codeBuilder << sourceFile->getContent() << "\n"; - } - - return codeBuilder.ProduceString(); - } - else - { - return emitEntryPoint( - compileRequest, - entryPoint, - CodeGenTarget::HLSL, - targetReq); - } - } - - String emitGLSLForEntryPoint( - BackEndCompileRequest* compileRequest, - EntryPoint* entryPoint, - Int entryPointIndex, - TargetRequest* targetReq, - EndToEndCompileRequest* endToEndReq) - { - if(auto translationUnit = findPassThroughTranslationUnit(endToEndReq, entryPointIndex)) - { - // Generate a string that includes the content of - // the source file(s), along with a line directive - // to ensure that we get reasonable messages - // from the downstream compiler when in pass-through - // mode. - - StringBuilder codeBuilder; - int translationUnitCounter = 0; - for(auto sourceFile : translationUnit->getSourceFiles()) - { - int translationUnitIndex = translationUnitCounter++; - - // We want to output `#line` directives, but we need - // to skip this for the first file, since otherwise - // some GLSL implementations will get tripped up by - // not having the `#version` directive be the first - // thing in the file. - if(translationUnitIndex != 0) - { - codeBuilder << "#line 1 " << translationUnitIndex << "\n"; - } - codeBuilder << sourceFile->getContent() << "\n"; - } - - return codeBuilder.ProduceString(); - } - else - { - // TODO(tfoley): need to pass along the entry point - // so that we properly emit it as the `main` function. - return emitEntryPoint( - compileRequest, - entryPoint, - CodeGenTarget::GLSL, - targetReq); - } - } - - String GetHLSLProfileName(Profile profile) - { - switch( profile.getFamily() ) - { - case ProfileFamily::DX: - // Profile version is a DX one, so stick with it. - break; - - default: - // Profile is a non-DX profile family, so we need to try - // to clobber it with something to get a default. - // - // TODO: This is a huge hack... - profile.setVersion(ProfileVersion::DX_5_0); - break; - } - - char const* stagePrefix = nullptr; - switch( profile.GetStage() ) - { - // Note: All of the raytracing-related stages require - // compiling for a `lib_*` profile, even when only a - // single entry point is present. - // - // We also go ahead and use this target in any case - // where we don't know the actual stage to compiel for, - // as a fallback option. - // - // TODO: We also want to use this option when compiling - // multiple entry points to a DXIL library. - // - default: - stagePrefix = "lib"; - break; - - // The traditional rasterization pipeline and compute - // shaders all have custom profile names that identify - // both the stage and shader model, which need to be - // used when compiling a single entry point. - // - #define CASE(NAME, PREFIX) case Stage::NAME: stagePrefix = #PREFIX; break - CASE(Vertex, vs); - CASE(Hull, hs); - CASE(Domain, ds); - CASE(Geometry, gs); - CASE(Fragment, ps); - CASE(Compute, cs); - #undef CASE - } - - char const* versionSuffix = nullptr; - switch(profile.GetVersion()) - { - #define CASE(TAG, SUFFIX) case ProfileVersion::TAG: versionSuffix = #SUFFIX; break - CASE(DX_4_0, _4_0); - CASE(DX_4_0_Level_9_0, _4_0_level_9_0); - CASE(DX_4_0_Level_9_1, _4_0_level_9_1); - CASE(DX_4_0_Level_9_3, _4_0_level_9_3); - CASE(DX_4_1, _4_1); - CASE(DX_5_0, _5_0); - CASE(DX_5_1, _5_1); - CASE(DX_6_0, _6_0); - CASE(DX_6_1, _6_1); - CASE(DX_6_2, _6_2); - CASE(DX_6_3, _6_3); - #undef CASE - - default: - return "unknown"; - } - - String result; - result.append(stagePrefix); - result.append(versionSuffix); - return result; - } - - void reportExternalCompileError(const char* compilerName, SlangResult res, const UnownedStringSlice& diagnostic, DiagnosticSink* sink) - { - StringBuilder builder; - if (compilerName) - { - builder << compilerName << ": "; - } - - if (diagnostic.size() > 0) - { - builder.Append(diagnostic); - } - - if (SLANG_FAILED(res) && res != SLANG_FAIL) - { - { - char tmp[17]; - sprintf_s(tmp, SLANG_COUNT_OF(tmp), "0x%08x", uint32_t(res)); - builder << "Result(" << tmp << ") "; - } - - PlatformUtil::appendResult(res, builder); - } - - // TODO(tfoley): need a better policy for how we translate diagnostics - // back into the Slang world (although we should always try to generate - // HLSL that doesn't produce any diagnostics...) - sink->diagnoseRaw(SLANG_FAILED(res) ? Severity::Error : Severity::Warning, builder.getUnownedSlice()); - } - - static String _getDisplayPath(DiagnosticSink* sink, SourceFile* sourceFile) - { - if (sink->flags & DiagnosticSink::Flag::VerbosePath) - { - return sourceFile->calcVerbosePath(); - } - else - { - return sourceFile->getPathInfo().foundPath; - } - } - - String calcSourcePathForEntryPoint( - EndToEndCompileRequest* endToEndReq, - UInt entryPointIndex) - { - auto translationUnitRequest = findPassThroughTranslationUnit(endToEndReq, entryPointIndex); - if(!translationUnitRequest) - return "slang-generated"; - - auto sink = endToEndReq->getSink(); - - const auto& sourceFiles = translationUnitRequest->getSourceFiles(); - - const Index numSourceFiles = sourceFiles.getCount(); - - switch (numSourceFiles) - { - case 0: return "unknown"; - case 1: return _getDisplayPath(sink, sourceFiles[0]); - default: - { - StringBuilder builder; - builder << _getDisplayPath(sink, sourceFiles[0]); - for (int i = 1; i < numSourceFiles; ++i) - { - builder << ";" << _getDisplayPath(sink, sourceFiles[i]); - } - return builder; - } - } - } - -#if SLANG_ENABLE_DXBC_SUPPORT - - static UnownedStringSlice _getSlice(ID3DBlob* blob) - { - if (blob) - { - const char* chars = (const char*)blob->GetBufferPointer(); - size_t len = blob->GetBufferSize(); - len -= size_t(len > 0 && chars[len - 1] == 0); - return UnownedStringSlice(chars, len); - } - return UnownedStringSlice(); - } - - SlangResult emitDXBytecodeForEntryPoint( - BackEndCompileRequest* compileRequest, - EntryPoint* entryPoint, - Int entryPointIndex, - TargetRequest* targetReq, - EndToEndCompileRequest* endToEndReq, - List& byteCodeOut) - { - byteCodeOut.clear(); - - auto session = compileRequest->getSession(); - auto sink = compileRequest->getSink(); - - auto compileFunc = (pD3DCompile)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Fxc_D3DCompile, sink); - if (!compileFunc) - { - return SLANG_FAIL; - } - - auto hlslCode = emitHLSLForEntryPoint(compileRequest, entryPoint, entryPointIndex, targetReq, endToEndReq); - maybeDumpIntermediate(compileRequest, hlslCode.getBuffer(), CodeGenTarget::HLSL); - - auto profile = getEffectiveProfile(entryPoint, targetReq); - - // If we have been invoked in a pass-through mode, then we need to make sure - // that the downstream compiler sees whatever options were passed to Slang - // via the command line or API. - // - // TODO: more pieces of information should be added here as needed. - // - List dxMacrosStorage; - D3D_SHADER_MACRO const* dxMacros = nullptr; - if(auto translationUnit = findPassThroughTranslationUnit(endToEndReq, entryPointIndex)) - { - for( auto& define : translationUnit->compileRequest->preprocessorDefinitions ) - { - D3D_SHADER_MACRO dxMacro; - dxMacro.Name = define.Key.getBuffer(); - dxMacro.Definition = define.Value.getBuffer(); - dxMacrosStorage.add(dxMacro); - } - for( auto& define : translationUnit->preprocessorDefinitions ) - { - D3D_SHADER_MACRO dxMacro; - dxMacro.Name = define.Key.getBuffer(); - dxMacro.Definition = define.Value.getBuffer(); - dxMacrosStorage.add(dxMacro); - } - D3D_SHADER_MACRO nullTerminator = { 0, 0 }; - dxMacrosStorage.add(nullTerminator); - - dxMacros = dxMacrosStorage.getBuffer(); - } - - DWORD flags = 0; - - switch( targetReq->floatingPointMode ) - { - default: - break; - - case FloatingPointMode::Precise: - flags |= D3DCOMPILE_IEEE_STRICTNESS; - break; - } - - // Some of the `D3DCOMPILE_*` constants aren't available in all - // versions of `d3dcompiler.h`, so we define them here just in case - #ifndef D3DCOMPILE_ENABLE_UNBOUNDED_DESCRIPTOR_TABLES - #define D3DCOMPILE_ENABLE_UNBOUNDED_DESCRIPTOR_TABLES (1 << 20) - #endif - - #ifndef D3DCOMPILE_ALL_RESOURCES_BOUND - #define D3DCOMPILE_ALL_RESOURCES_BOUND (1 << 21) - #endif - - flags |= D3DCOMPILE_ENABLE_STRICTNESS; - flags |= D3DCOMPILE_ENABLE_UNBOUNDED_DESCRIPTOR_TABLES; - - auto linkage = compileRequest->getLinkage(); - switch( linkage->optimizationLevel ) - { - default: - break; - - case OptimizationLevel::None: flags |= D3DCOMPILE_OPTIMIZATION_LEVEL0; break; - case OptimizationLevel::Default: flags |= D3DCOMPILE_OPTIMIZATION_LEVEL1; break; - case OptimizationLevel::High: flags |= D3DCOMPILE_OPTIMIZATION_LEVEL2; break; - case OptimizationLevel::Maximal: flags |= D3DCOMPILE_OPTIMIZATION_LEVEL3; break; - } - - switch( linkage->debugInfoLevel ) - { - case DebugInfoLevel::None: - break; - - default: - flags |= D3DCOMPILE_DEBUG; - break; - } - - const String sourcePath = calcSourcePathForEntryPoint(endToEndReq, entryPointIndex); - - ComPtr codeBlob; - ComPtr diagnosticsBlob; - HRESULT hr = compileFunc( - hlslCode.begin(), - hlslCode.getLength(), - sourcePath.getBuffer(), - dxMacros, - nullptr, - getText(entryPoint->getName()).begin(), - GetHLSLProfileName(profile).getBuffer(), - flags, - 0, // unused: effect flags - codeBlob.writeRef(), - diagnosticsBlob.writeRef()); - - if (codeBlob && SLANG_SUCCEEDED(hr)) - { - byteCodeOut.addRange((uint8_t const*)codeBlob->GetBufferPointer(), (int)codeBlob->GetBufferSize()); - } - - if (FAILED(hr)) - { - reportExternalCompileError("fxc", hr, _getSlice(diagnosticsBlob), sink); - } - - return hr; - } - - SlangResult dissassembleDXBC( - BackEndCompileRequest* compileRequest, - void const* data, - size_t size, - String& assemOut) - { - assemOut = String(); - - auto session = compileRequest->getSession(); - auto sink = compileRequest->getSink(); - - auto disassembleFunc = (pD3DDisassemble)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Fxc_D3DDisassemble, sink); - if (!disassembleFunc) - { - return SLANG_E_NOT_FOUND; - } - - if (!data || !size) - { - return SLANG_FAIL; - } - - ComPtr codeBlob; - SlangResult res = disassembleFunc(data, size, 0, nullptr, codeBlob.writeRef()); - - if (codeBlob) - { - assemOut = _getSlice(codeBlob); - } - if (FAILED(res)) - { - // TODO(tfoley): need to figure out what to diagnose here... - reportExternalCompileError("fxc", res, UnownedStringSlice(), sink); - } - - return res; - } - - SlangResult emitDXBytecodeAssemblyForEntryPoint( - BackEndCompileRequest* compileRequest, - EntryPoint* entryPoint, - Int entryPointIndex, - TargetRequest* targetReq, - EndToEndCompileRequest* endToEndReq, - String& assemOut) - { - - List dxbc; - SLANG_RETURN_ON_FAIL(emitDXBytecodeForEntryPoint( - compileRequest, - entryPoint, - entryPointIndex, - targetReq, - endToEndReq, - dxbc)); - if (!dxbc.getCount()) - { - return SLANG_FAIL; - } - return dissassembleDXBC(compileRequest, dxbc.getBuffer(), dxbc.getCount(), assemOut); - } -#endif - -#if SLANG_ENABLE_DXIL_SUPPORT - -// Implementations in `dxc-support.cpp` - -SlangResult emitDXILForEntryPointUsingDXC( - BackEndCompileRequest* compileRequest, - EntryPoint* entryPoint, - Int entryPointIndex, - TargetRequest* targetReq, - EndToEndCompileRequest* endToEndReq, - List& outCode); - -SlangResult dissassembleDXILUsingDXC( - BackEndCompileRequest* compileRequest, - void const* data, - size_t size, - String& stringOut); - -#endif - -#if SLANG_ENABLE_GLSLANG_SUPPORT - SlangResult invokeGLSLCompiler( - BackEndCompileRequest* slangCompileRequest, - glslang_CompileRequest& request) - { - Session* session = slangCompileRequest->getSession(); - auto sink = slangCompileRequest->getSink(); - - auto glslang_compile = (glslang_CompileFunc)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Glslang_Compile, sink); - if (!glslang_compile) - { - return SLANG_FAIL; - } - - StringBuilder diagnosticOutput; - - auto diagnosticOutputFunc = [](void const* data, size_t size, void* userData) - { - (*(StringBuilder*)userData).append((char const*)data, (char const*)data + size); - }; - - request.diagnosticFunc = diagnosticOutputFunc; - request.diagnosticUserData = &diagnosticOutput; - - int err = glslang_compile(&request); - - if (err) - { - reportExternalCompileError("glslang", SLANG_FAIL, diagnosticOutput.getUnownedSlice(), sink); - return SLANG_FAIL; - } - - return SLANG_OK; - } - - SlangResult dissassembleSPIRV( - BackEndCompileRequest* slangRequest, - void const* data, - size_t size, - String& stringOut) - { - stringOut = String(); - - String output; - auto outputFunc = [](void const* data, size_t size, void* userData) - { - (*(String*)userData).append((char const*)data, (char const*)data + size); - }; - - glslang_CompileRequest request; - request.action = GLSLANG_ACTION_DISSASSEMBLE_SPIRV; - - request.sourcePath = nullptr; - - request.inputBegin = data; - request.inputEnd = (char*)data + size; - - request.outputFunc = outputFunc; - request.outputUserData = &output; - - SLANG_RETURN_ON_FAIL(invokeGLSLCompiler(slangRequest, request)); - - stringOut = output; - return SLANG_OK; - } - - SlangResult emitSPIRVForEntryPoint( - BackEndCompileRequest* slangRequest, - EntryPoint* entryPoint, - Int entryPointIndex, - TargetRequest* targetReq, - EndToEndCompileRequest* endToEndReq, - List& spirvOut) - { - spirvOut.clear(); - - String rawGLSL = emitGLSLForEntryPoint( - slangRequest, - entryPoint, - entryPointIndex, - targetReq, - endToEndReq); - maybeDumpIntermediate(slangRequest, rawGLSL.getBuffer(), CodeGenTarget::GLSL); - - auto outputFunc = [](void const* data, size_t size, void* userData) - { - ((List*)userData)->addRange((uint8_t*)data, size); - }; - - const String sourcePath = calcSourcePathForEntryPoint(endToEndReq, entryPointIndex); - - glslang_CompileRequest request; - request.action = GLSLANG_ACTION_COMPILE_GLSL_TO_SPIRV; - request.sourcePath = sourcePath.getBuffer(); - request.slangStage = (SlangStage)entryPoint->getStage(); - - request.inputBegin = rawGLSL.begin(); - request.inputEnd = rawGLSL.end(); - - request.outputFunc = outputFunc; - request.outputUserData = &spirvOut; - - SLANG_RETURN_ON_FAIL(invokeGLSLCompiler(slangRequest, request)); - return SLANG_OK; - } - - SlangResult emitSPIRVAssemblyForEntryPoint( - BackEndCompileRequest* slangRequest, - EntryPoint* entryPoint, - Int entryPointIndex, - TargetRequest* targetReq, - EndToEndCompileRequest* endToEndReq, - String& assemblyOut) - { - List spirv; - SLANG_RETURN_ON_FAIL(emitSPIRVForEntryPoint( - slangRequest, - entryPoint, - entryPointIndex, - targetReq, - endToEndReq, - spirv)); - - if (spirv.getCount() == 0) - return SLANG_FAIL; - - return dissassembleSPIRV(slangRequest, spirv.begin(), spirv.getCount(), assemblyOut); - } -#endif - - // Do emit logic for a single entry point - CompileResult emitEntryPoint( - BackEndCompileRequest* compileRequest, - EntryPoint* entryPoint, - Int entryPointIndex, - TargetRequest* targetReq, - EndToEndCompileRequest* endToEndReq) - { - CompileResult result; - - auto target = targetReq->target; - - switch (target) - { - case CodeGenTarget::HLSL: - { - String code = emitHLSLForEntryPoint( - compileRequest, - entryPoint, - entryPointIndex, - targetReq, - endToEndReq); - maybeDumpIntermediate(compileRequest, code.getBuffer(), target); - result = CompileResult(code); - } - break; - - case CodeGenTarget::GLSL: - { - String code = emitGLSLForEntryPoint( - compileRequest, - entryPoint, - entryPointIndex, - targetReq, - endToEndReq); - maybeDumpIntermediate(compileRequest, code.getBuffer(), target); - result = CompileResult(code); - } - break; - - case CodeGenTarget::CPPSource: - case CodeGenTarget::CSource: - { - return emitEntryPoint( - compileRequest, - entryPoint, - target, - targetReq); - } - break; - -#if SLANG_ENABLE_DXBC_SUPPORT - case CodeGenTarget::DXBytecode: - { - List code; - if (SLANG_SUCCEEDED(emitDXBytecodeForEntryPoint( - compileRequest, - entryPoint, - entryPointIndex, - targetReq, - endToEndReq, - code))) - { - maybeDumpIntermediate(compileRequest, code.getBuffer(), code.getCount(), target); - result = CompileResult(code); - } - } - break; - - case CodeGenTarget::DXBytecodeAssembly: - { - String code; - if (SLANG_SUCCEEDED(emitDXBytecodeAssemblyForEntryPoint( - compileRequest, - entryPoint, - entryPointIndex, - targetReq, - endToEndReq, - code))) - { - maybeDumpIntermediate(compileRequest, code.getBuffer(), target); - result = CompileResult(code); - } - } - break; -#endif - -#if SLANG_ENABLE_DXIL_SUPPORT - case CodeGenTarget::DXIL: - { - List code; - if (SLANG_SUCCEEDED(emitDXILForEntryPointUsingDXC( - compileRequest, - entryPoint, - entryPointIndex, - targetReq, - endToEndReq, - code))) - { - maybeDumpIntermediate(compileRequest, code.getBuffer(), code.getCount(), target); - result = CompileResult(code); - } - } - break; - - case CodeGenTarget::DXILAssembly: - { - List code; - if (SLANG_SUCCEEDED(emitDXILForEntryPointUsingDXC( - compileRequest, - entryPoint, - entryPointIndex, - targetReq, - endToEndReq, - code))) - { - String assembly; - dissassembleDXILUsingDXC( - compileRequest, - code.getBuffer(), - code.getCount(), - assembly); - - maybeDumpIntermediate(compileRequest, assembly.getBuffer(), target); - - result = CompileResult(assembly); - } - } - break; -#endif - - case CodeGenTarget::SPIRV: - { - List code; - if (SLANG_SUCCEEDED(emitSPIRVForEntryPoint( - compileRequest, - entryPoint, - entryPointIndex, - targetReq, - endToEndReq, - code))) - { - maybeDumpIntermediate(compileRequest, code.getBuffer(), code.getCount(), target); - result = CompileResult(code); - } - } - break; - - case CodeGenTarget::SPIRVAssembly: - { - String code; - if (SLANG_SUCCEEDED(emitSPIRVAssemblyForEntryPoint( - compileRequest, - entryPoint, - entryPointIndex, - targetReq, - endToEndReq, - code))) - { - maybeDumpIntermediate(compileRequest, code.getBuffer(), target); - result = CompileResult(code); - } - } - break; - - case CodeGenTarget::None: - // The user requested no output - break; - - // Note(tfoley): We currently hit this case when compiling the stdlib - case CodeGenTarget::Unknown: - break; - - default: - SLANG_UNEXPECTED("unhandled code generation target"); - break; - } - - return result; - } - - enum class OutputFileKind - { - Text, - Binary, - }; - - static void writeOutputFile( - BackEndCompileRequest* compileRequest, - FILE* file, - String const& path, - void const* data, - size_t size) - { - size_t count = fwrite(data, size, 1, file); - if (count != 1) - { - compileRequest->getSink()->diagnose( - SourceLoc(), - Diagnostics::cannotWriteOutputFile, - path); - } - } - - static void writeOutputFile( - BackEndCompileRequest* compileRequest, - ISlangWriter* writer, - String const& path, - void const* data, - size_t size) - { - - if (SLANG_FAILED(writer->write((const char*)data, size))) - { - compileRequest->getSink()->diagnose( - SourceLoc(), - Diagnostics::cannotWriteOutputFile, - path); - } - } - - static void writeOutputFile( - BackEndCompileRequest* compileRequest, - String const& path, - void const* data, - size_t size, - OutputFileKind kind) - { - FILE* file = fopen( - path.getBuffer(), - kind == OutputFileKind::Binary ? "wb" : "w"); - if (!file) - { - compileRequest->getSink()->diagnose( - SourceLoc(), - Diagnostics::cannotWriteOutputFile, - path); - return; - } - - writeOutputFile(compileRequest, file, path, data, size); - fclose(file); - } - - static void writeEntryPointResultToFile( - BackEndCompileRequest* compileRequest, - EntryPoint* entryPoint, - String const& outputPath, - CompileResult const& result) - { - SLANG_UNUSED(entryPoint); - - switch (result.format) - { - case ResultFormat::Text: - { - auto text = result.outputString; - writeOutputFile(compileRequest, - outputPath, - text.begin(), - text.end() - text.begin(), - OutputFileKind::Text); - } - break; - - case ResultFormat::Binary: - { - auto& data = result.outputBinary; - writeOutputFile(compileRequest, - outputPath, - data.begin(), - data.end() - data.begin(), - OutputFileKind::Binary); - } - break; - - default: - SLANG_UNEXPECTED("unhandled output format"); - break; - } - - } - - static void writeOutputToConsole( - ISlangWriter* writer, - String const& text) - { - writer->write(text.getBuffer(), text.getLength()); - } - - static void writeEntryPointResultToStandardOutput( - EndToEndCompileRequest* compileRequest, - EntryPoint* entryPoint, - TargetRequest* targetReq, - CompileResult const& result) - { - SLANG_UNUSED(entryPoint); - - ISlangWriter* writer = compileRequest->getWriter(WriterChannel::StdOutput); - auto backEndReq = compileRequest->getBackEndReq(); - - switch (result.format) - { - case ResultFormat::Text: - writeOutputToConsole(writer, result.outputString); - break; - - case ResultFormat::Binary: - { - auto& data = result.outputBinary; - - if (writer->isConsole()) - { - // Writing to console, so we need to generate text output. - - switch (targetReq->target) - { - #if SLANG_ENABLE_DXBC_SUPPORT - case CodeGenTarget::DXBytecode: - { - String assembly; - dissassembleDXBC(backEndReq, - data.begin(), - data.end() - data.begin(), assembly); - writeOutputToConsole(writer, assembly); - } - break; - #endif - - #if SLANG_ENABLE_DXIL_SUPPORT - case CodeGenTarget::DXIL: - { - String assembly; - dissassembleDXILUsingDXC(backEndReq, - data.begin(), - data.end() - data.begin(), - assembly); - writeOutputToConsole(writer, assembly); - } - break; - #endif - - case CodeGenTarget::SPIRV: - { - String assembly; - dissassembleSPIRV(backEndReq, - data.begin(), - data.end() - data.begin(), assembly); - writeOutputToConsole(writer, assembly); - } - break; - - default: - SLANG_UNEXPECTED("unhandled output format"); - return; - } - } - else - { - // Redirecting stdout to a file, so do the usual thing - writer->setMode(SLANG_WRITER_MODE_BINARY); - - writeOutputFile( - backEndReq, - writer, - "stdout", - data.begin(), - data.end() - data.begin()); - } - } - break; - - default: - SLANG_UNEXPECTED("unhandled output format"); - break; - } - - } - - static void writeEntryPointResult( - EndToEndCompileRequest* compileRequest, - EntryPoint* entryPoint, - TargetRequest* targetReq, - Int entryPointIndex) - { - auto program = compileRequest->getSpecializedProgram(); - auto targetProgram = program->getTargetProgram(targetReq); - auto backEndReq = compileRequest->getBackEndReq(); - - auto& result = targetProgram->getExistingEntryPointResult(entryPointIndex); - - // Skip the case with no output - if (result.format == ResultFormat::None) - return; - - // It is possible that we are dynamically discovering entry - // points (using `[shader(...)]` attributes), so that there - // might be entry points added to the program that did not - // get paths specified via command-line options. - // - RefPtr targetInfo; - if(compileRequest->targetInfos.TryGetValue(targetReq, targetInfo)) - { - String outputPath; - if(targetInfo->entryPointOutputPaths.TryGetValue(entryPointIndex, outputPath)) - { - writeEntryPointResultToFile(backEndReq, entryPoint, outputPath, result); - return; - } - } - - writeEntryPointResultToStandardOutput(compileRequest, entryPoint, targetReq, result); - } - - void generateOutputForTarget( - BackEndCompileRequest* compileReq, - TargetRequest* targetReq, - EndToEndCompileRequest* endToEndReq) - { - auto program = compileReq->getProgram(); - auto targetProgram = program->getTargetProgram(targetReq); - - // Generate target code any entry points that - // have been requested for compilation. - auto entryPointCount = program->getEntryPointCount(); - for(Index ii = 0; ii < entryPointCount; ++ii) - { - auto entryPoint = program->getEntryPoint(ii); - CompileResult entryPointResult = emitEntryPoint( - compileReq, - entryPoint, - ii, - targetReq, - endToEndReq); - targetProgram->setEntryPointResult(ii, entryPointResult); - } - } - - static void _generateOutput( - BackEndCompileRequest* compileRequest, - EndToEndCompileRequest* endToEndReq) - { - // Go through the code-generation targets that the user - // has specified, and generate code for each of them. - // - auto linkage = compileRequest->getLinkage(); - for (auto targetReq : linkage->targets) - { - generateOutputForTarget(compileRequest, targetReq, endToEndReq); - } - } - - void generateOutput( - BackEndCompileRequest* compileRequest) - { - _generateOutput(compileRequest, nullptr); - } - - void generateOutput( - EndToEndCompileRequest* compileRequest) - { - _generateOutput(compileRequest->getBackEndReq(), compileRequest); - - // If we are in command-line mode, we might be expected to actually - // write output to one or more files here. - - if (compileRequest->isCommandLineCompile) - { - auto linkage = compileRequest->getLinkage(); - auto program = compileRequest->getSpecializedProgram(); - for (auto targetReq : linkage->targets) - { - Index entryPointCount = program->getEntryPointCount(); - for (Index ee = 0; ee < entryPointCount; ++ee) - { - writeEntryPointResult( - compileRequest, - program->getEntryPoint(ee), - targetReq, - ee); - } - } - } - } - - // Debug logic for dumping intermediate outputs - - // - - void dumpIntermediate( - BackEndCompileRequest*, - void const* data, - size_t size, - char const* ext, - bool isBinary) - { - // Try to generate a unique ID for the file to dump, - // even in cases where there might be multiple threads - // doing compilation. - // - // This is primarily a debugging aid, so we don't - // really need/want to do anything too elaborate - - static uint32_t counter = 0; -#ifdef WIN32 - uint32_t id = InterlockedIncrement(&counter); -#else - // TODO: actually implement the case for other platforms - uint32_t id = counter++; -#endif - - String path; - path.append("slang-dump-"); - path.append(id); - path.append(ext); - - FILE* file = fopen(path.getBuffer(), isBinary ? "wb" : "w"); - if (!file) return; - - fwrite(data, size, 1, file); - fclose(file); - } - - void dumpIntermediateText( - BackEndCompileRequest* compileRequest, - void const* data, - size_t size, - char const* ext) - { - dumpIntermediate(compileRequest, data, size, ext, false); - } - - void dumpIntermediateBinary( - BackEndCompileRequest* compileRequest, - void const* data, - size_t size, - char const* ext) - { - dumpIntermediate(compileRequest, data, size, ext, true); - } - - void maybeDumpIntermediate( - BackEndCompileRequest* compileRequest, - void const* data, - size_t size, - CodeGenTarget target) - { - if (!compileRequest->shouldDumpIntermediates) - return; - - switch (target) - { - default: - break; - - case CodeGenTarget::HLSL: - dumpIntermediateText(compileRequest, data, size, ".hlsl"); - break; - - case CodeGenTarget::GLSL: - dumpIntermediateText(compileRequest, data, size, ".glsl"); - break; - - case CodeGenTarget::SPIRVAssembly: - dumpIntermediateText(compileRequest, data, size, ".spv.asm"); - break; - -#if 0 - case CodeGenTarget::SlangIRAssembly: - dumpIntermediateText(compileRequest, data, size, ".slang-ir.asm"); - break; -#endif - - case CodeGenTarget::SPIRV: - dumpIntermediateBinary(compileRequest, data, size, ".spv"); - { - String spirvAssembly; - dissassembleSPIRV(compileRequest, data, size, spirvAssembly); - dumpIntermediateText(compileRequest, spirvAssembly.begin(), spirvAssembly.getLength(), ".spv.asm"); - } - break; - - #if SLANG_ENABLE_DXBC_SUPPORT - case CodeGenTarget::DXBytecodeAssembly: - dumpIntermediateText(compileRequest, data, size, ".dxbc.asm"); - break; - - case CodeGenTarget::DXBytecode: - dumpIntermediateBinary(compileRequest, data, size, ".dxbc"); - { - String dxbcAssembly; - dissassembleDXBC(compileRequest, data, size, dxbcAssembly); - dumpIntermediateText(compileRequest, dxbcAssembly.begin(), dxbcAssembly.getLength(), ".dxbc.asm"); - } - break; - #endif - - #if SLANG_ENABLE_DXIL_SUPPORT - case CodeGenTarget::DXILAssembly: - dumpIntermediateText(compileRequest, data, size, ".dxil.asm"); - break; - - case CodeGenTarget::DXIL: - dumpIntermediateBinary(compileRequest, data, size, ".dxil"); - { - String dxilAssembly; - dissassembleDXILUsingDXC(compileRequest, data, size, dxilAssembly); - dumpIntermediateText(compileRequest, dxilAssembly.begin(), dxilAssembly.getLength(), ".dxil.asm"); - } - break; - #endif - } - } - - void maybeDumpIntermediate( - BackEndCompileRequest* compileRequest, - char const* text, - CodeGenTarget target) - { - if (!compileRequest->shouldDumpIntermediates) - return; - - maybeDumpIntermediate(compileRequest, text, strlen(text), target); - } - -} diff --git a/source/slang/compiler.h b/source/slang/compiler.h deleted file mode 100644 index 5d9e47aee..000000000 --- a/source/slang/compiler.h +++ /dev/null @@ -1,1423 +0,0 @@ -#ifndef SLANG_COMPILER_H_INCLUDED -#define SLANG_COMPILER_H_INCLUDED - -#include "../core/basic.h" -#include "../core/slang-shared-library.h" - -#include "../../slang-com-ptr.h" - -#include "diagnostics.h" -#include "name.h" -#include "profile.h" -#include "syntax.h" - -#include "../../slang.h" - -namespace Slang -{ - struct PathInfo; - struct IncludeHandler; - class ProgramLayout; - class PtrType; - class TargetProgram; - class TargetRequest; - class TypeLayout; - - enum class CompilerMode - { - ProduceLibrary, - ProduceShader, - GenerateChoice - }; - - enum class StageTarget - { - Unknown, - VertexShader, - HullShader, - DomainShader, - GeometryShader, - FragmentShader, - ComputeShader, - }; - - enum class CodeGenTarget - { - Unknown = SLANG_TARGET_UNKNOWN, - None = SLANG_TARGET_NONE, - GLSL = SLANG_GLSL, - GLSL_Vulkan = SLANG_GLSL_VULKAN, - GLSL_Vulkan_OneDesc = SLANG_GLSL_VULKAN_ONE_DESC, - HLSL = SLANG_HLSL, - SPIRV = SLANG_SPIRV, - SPIRVAssembly = SLANG_SPIRV_ASM, - DXBytecode = SLANG_DXBC, - DXBytecodeAssembly = SLANG_DXBC_ASM, - DXIL = SLANG_DXIL, - DXILAssembly = SLANG_DXIL_ASM, - CSource = SLANG_C_SOURCE, - CPPSource = SLANG_CPP_SOURCE, - }; - - enum class ContainerFormat - { - None = SLANG_CONTAINER_FORMAT_NONE, - SlangModule = SLANG_CONTAINER_FORMAT_SLANG_MODULE, - }; - - enum class LineDirectiveMode : SlangLineDirectiveMode - { - Default = SLANG_LINE_DIRECTIVE_MODE_DEFAULT, - None = SLANG_LINE_DIRECTIVE_MODE_NONE, - Standard = SLANG_LINE_DIRECTIVE_MODE_STANDARD, - GLSL = SLANG_LINE_DIRECTIVE_MODE_GLSL, - }; - - enum class ResultFormat - { - None, - Text, - Binary - }; - - // When storing the layout for a matrix-type - // value, we need to know whether it has been - // laid out with row-major or column-major - // storage. - // - enum MatrixLayoutMode - { - kMatrixLayoutMode_RowMajor = SLANG_MATRIX_LAYOUT_ROW_MAJOR, - kMatrixLayoutMode_ColumnMajor = SLANG_MATRIX_LAYOUT_COLUMN_MAJOR, - }; - - enum class DebugInfoLevel : SlangDebugInfoLevel - { - None = SLANG_DEBUG_INFO_LEVEL_NONE, - Minimal = SLANG_DEBUG_INFO_LEVEL_MINIMAL, - Standard = SLANG_DEBUG_INFO_LEVEL_STANDARD, - Maximal = SLANG_DEBUG_INFO_LEVEL_MAXIMAL, - }; - - enum class OptimizationLevel : SlangOptimizationLevel - { - None = SLANG_OPTIMIZATION_LEVEL_NONE, - Default = SLANG_OPTIMIZATION_LEVEL_DEFAULT, - High = SLANG_OPTIMIZATION_LEVEL_HIGH, - Maximal = SLANG_OPTIMIZATION_LEVEL_MAXIMAL, - }; - - class Linkage; - class Module; - class Program; - class FrontEndCompileRequest; - class BackEndCompileRequest; - class EndToEndCompileRequest; - class TranslationUnitRequest; - - // Result of compiling an entry point. - // Should only ever be string OR binary. - class CompileResult - { - public: - CompileResult() = default; - CompileResult(String const& str) : format(ResultFormat::Text), outputString(str) {} - CompileResult(List const& buffer) : format(ResultFormat::Binary), outputBinary(buffer) {} - - void append(CompileResult const& result); - - ComPtr getBlob(); - - ResultFormat format = ResultFormat::None; - String outputString; - List outputBinary; - - ComPtr blob; - }; - - /// Information collected about global or entry-point shader parameters - struct ShaderParamInfo - { - DeclRef paramDeclRef; - UInt firstExistentialTypeSlot = 0; - UInt existentialTypeSlotCount = 0; - }; - - /// Extended information specific to global shader parameters - struct GlobalShaderParamInfo : ShaderParamInfo - { - // Additional global-scope declarations that are conceptually - // declaring the "same" parameter as the `paramDeclRef`. - List> additionalParamDeclRefs; - }; - - /// A request for the front-end to find and validate an entry-point function - struct FrontEndEntryPointRequest : RefObject - { - public: - /// Create a request for an entry point. - FrontEndEntryPointRequest( - FrontEndCompileRequest* compileRequest, - int translationUnitIndex, - Name* name, - Profile profile); - - /// Get the parent front-end compile request. - FrontEndCompileRequest* getCompileRequest() { return m_compileRequest; } - - /// Get the translation unit that contains the entry point. - TranslationUnitRequest* getTranslationUnit(); - - /// Get the name of the entry point to find. - Name* getName() { return m_name; } - - /// Get the stage that the entry point is to be compiled for - Stage getStage() { return m_profile.GetStage(); } - - /// Get the profile that the entry point is to be compiled for - Profile getProfile() { return m_profile; } - - private: - // The parent compile request - FrontEndCompileRequest* m_compileRequest; - - // The index of the translation unit that will hold the entry point - int m_translationUnitIndex; - - // The name of the entry point function to look for - Name* m_name; - - // The profile to compile for (including stage) - Profile m_profile; - }; - - /// Tracks an ordered list of modules that something depends on. - struct ModuleDependencyList - { - public: - /// Get the list of modules that are depended on. - List> const& getModuleList() { return m_moduleList; } - - /// Add a module and everything it depends on to the list. - void addDependency(Module* module); - - /// Add a module to the list, but not the modules it depends on. - void addLeafDependency(Module* module); - - private: - void _addDependency(Module* module); - - List> m_moduleList; - HashSet m_moduleSet; - }; - - /// Tracks an unordered list of filesystem paths that something depends on - struct FilePathDependencyList - { - public: - /// Get the list of paths that are depended on. - List const& getFilePathList() { return m_filePathList; } - - /// Add a path to the list, if it is not already present - void addDependency(String const& path); - - /// Add all of the paths that `module` depends on to the list - void addDependency(Module* module); - - private: - - // TODO: We are using a `HashSet` here to deduplicate - // the paths so that we don't return the same path - // multiple times from `getFilePathList`, but because - // order isn't important, we could potentially do better - // in terms of memory (at some cost in performance) by - // just sorting the `m_filePathList` every once in - // a while and then deduplicating. - - List m_filePathList; - HashSet m_filePathSet; - }; - - /// Describes an entry point for the purposes of layout and code generation. - /// - /// This class also tracks any generic arguments to the entry point, - /// in the case that it is a specialization of a generic entry point. - /// - /// There is also a provision for creating a "dummy" entry point for - /// the purposes of pass-through compilation modes. Only the - /// `getName()` and `getProfile()` methods should be expected to - /// return useful data on pass-through entry points. - /// - class EntryPoint : public RefObject - { - public: - /// Create an entry point that refers to the given function. - static RefPtr create( - DeclRef funcDeclRef, - Profile profile); - - /// Get the function decl-ref, including any generic arguments. - DeclRef getFuncDeclRef() { return m_funcDeclRef; } - - /// Get the function declaration (without generic arguments). - RefPtr getFuncDecl() { return m_funcDeclRef.getDecl(); } - - /// Get the name of the entry point - Name* getName() { return m_name; } - - /// Get the profile associated with the entry point - /// - /// Note: only the stage part of the profile is expected - /// to contain useful data, but certain legacy code paths - /// allow for "shader model" information to come via this path. - /// - Profile getProfile() { return m_profile; } - - /// Get the stage that the entry point is for. - Stage getStage() { return m_profile.GetStage(); } - - /// Get the module that contains the entry point. - Module* getModule(); - - /// Get the linkage that contains the module for this entry point. - Linkage* getLinkage(); - - /// Get a list of modules that this entry point depends on. - /// - /// This will include the module that defines the entry point (see `getModule()`), - /// but may also include modules that are required by its generic type arguments. - /// - List> getModuleDependencies() { return m_dependencyList.getModuleList(); } - - /// Get a list of tagged-union types referenced by the entry point's generic parameters. - List> const& getTaggedUnionTypes() { return m_taggedUnionTypes; } - - /// Create a dummy `EntryPoint` that is only usable for pass-through compilation. - static RefPtr createDummyForPassThrough( - Name* name, - Profile profile); - - /// Get the number of existential type parameters for the entry point. - Index getExistentialTypeParamCount() { return m_existentialSlots.paramTypes.getCount(); } - - /// Get the existential type parameter at `index`. - Type* getExistentialTypeParam(Index index) { return m_existentialSlots.paramTypes[index]; } - - /// Get the number of arguments supplied for existential type parameters. - /// - /// Note that the number of arguments may not match the number of parameters. - /// In particular, an unspecialized entry point may have many parameters, but zero arguments. - Index getExistentialTypeArgCount() { return m_existentialSlots.args.getCount(); } - - /// Get the existential type argument (type and witness table) at `index`. - ExistentialTypeSlots::Arg getExistentialTypeArg(Index index) { return m_existentialSlots.args[index]; } - - /// Get an array of all existential type arguments. - ExistentialTypeSlots::Arg const* getExistentialTypeArgs() { return m_existentialSlots.args.getBuffer(); } - - /// Get an array of all entry-point shader parameters. - List const& getShaderParams() { return m_shaderParams; } - - void _specializeExistentialTypeParams( - List> const& args, - DiagnosticSink* sink); - - private: - EntryPoint( - Name* name, - Profile profile, - DeclRef funcDeclRef); - - void _collectShaderParams(); - - // The name of the entry point function (e.g., `main`) - // - Name* m_name = nullptr; - - // The declaration of the entry-point function itself. - // - DeclRef m_funcDeclRef; - - /// The existential/interface slots associated with the entry point parameter scope. - ExistentialTypeSlots m_existentialSlots; - - /// Information about entry-point parameters - List m_shaderParams; - - // The profile that the entry point will be compiled for - // (this is a combination of the target stage, and also - // a feature level that sets capabilities) - // - // Note: the profile-version part of this should probably - // be moving towards deprecation, in favor of the version - // information (e.g., "Shader Model 5.1") always coming - // from the target, while the stage part is all that is - // intrinsic to the entry point. - // - Profile m_profile; - - // Any tagged union types that were referenced by the generic arguments of the entry point. - List> m_taggedUnionTypes; - - // Modules the entry point depends on. - ModuleDependencyList m_dependencyList; - }; - - enum class PassThroughMode : SlangPassThrough - { - None = SLANG_PASS_THROUGH_NONE, // don't pass through: use Slang compiler - fxc = SLANG_PASS_THROUGH_FXC, // pass through HLSL to `D3DCompile` API - dxc = SLANG_PASS_THROUGH_DXC, // pass through HLSL to `IDxcCompiler` API - glslang = SLANG_PASS_THROUGH_GLSLANG, // pass through GLSL to `glslang` library - }; - - class SourceFile; - - /// A module of code that has been compiled through the front-end - /// - /// A module comprises all the code from one translation unit (which - /// may span multiple Slang source files), and provides access - /// to both the AST and IR representations of that code. - /// - class Module : public RefObject - { - public: - /// Create a module (initially empty). - Module(Linkage* linkage); - - /// Get the parent linkage of this module. - Linkage* getLinkage() { return m_linkage; } - - /// Get the AST for the module (if it has been parsed) - ModuleDecl* getModuleDecl() { return m_moduleDecl; } - - /// The the IR for the module (if it has been generated) - IRModule* getIRModule() { return m_irModule; } - - /// Get the list of other modules this module depends on - List> const& getModuleDependencyList() { return m_moduleDependencyList.getModuleList(); } - - /// Get the list of filesystem paths this module depends on - List const& getFilePathDependencyList() { return m_filePathDependencyList.getFilePathList(); } - - /// Register a module that this module depends on - void addModuleDependency(Module* module); - - /// Register a filesystem path that this module depends on - void addFilePathDependency(String const& path); - - /// Set the AST for this module. - /// - /// This should only be called once, during creation of the module. - /// - void setModuleDecl(ModuleDecl* moduleDecl) { m_moduleDecl = moduleDecl; } - - /// Set the IR for this module. - /// - /// This should only be called once, during creation of the module. - /// - void setIRModule(IRModule* irModule) { m_irModule = irModule; } - - private: - // The parent linkage - Linkage* m_linkage = nullptr; - - // The AST for the module - RefPtr m_moduleDecl; - - // The IR for the module - RefPtr m_irModule = nullptr; - - // List of modules this module depends on - ModuleDependencyList m_moduleDependencyList; - - // List of filesystem paths this module depends on - FilePathDependencyList m_filePathDependencyList; - }; - typedef Module LoadedModule; - - /// A request for the front-end to compile a translation unit. - class TranslationUnitRequest : public RefObject - { - public: - TranslationUnitRequest( - FrontEndCompileRequest* compileRequest); - - // The parent compile request - FrontEndCompileRequest* compileRequest = nullptr; - - // The language in which the source file(s) - // are assumed to be written - SourceLanguage sourceLanguage = SourceLanguage::Unknown; - - // The source file(s) that will be compiled to form this translation unit - // - // Usually, for HLSL or GLSL there will be only one file. - List m_sourceFiles; - - List const& getSourceFiles() { return m_sourceFiles; } - void addSourceFile(SourceFile* sourceFile); - - // The entry points associated with this translation unit - List> entryPoints; - - // Preprocessor definitions to use for this translation unit only - // (whereas the ones on `compileRequest` will be shared) - Dictionary preprocessorDefinitions; - - /// The name that will be used for the module this translation unit produces. - Name* moduleName = nullptr; - - /// Result of compiling this translation unit (a module) - RefPtr module; - - Module* getModule() { return module; } - RefPtr getModuleDecl() { return module->getModuleDecl(); } - - Session* getSession(); - NamePool* getNamePool(); - SourceManager* getSourceManager(); - }; - - enum class FloatingPointMode : SlangFloatingPointMode - { - Default = SLANG_FLOATING_POINT_MODE_DEFAULT, - Fast = SLANG_FLOATING_POINT_MODE_FAST, - Precise = SLANG_FLOATING_POINT_MODE_PRECISE, - }; - - enum class WriterChannel : SlangWriterChannel - { - Diagnostic = SLANG_WRITER_CHANNEL_DIAGNOSTIC, - StdOutput = SLANG_WRITER_CHANNEL_STD_OUTPUT, - StdError = SLANG_WRITER_CHANNEL_STD_ERROR, - CountOf = SLANG_WRITER_CHANNEL_COUNT_OF, - }; - - enum class WriterMode : SlangWriterMode - { - Text = SLANG_WRITER_MODE_TEXT, - Binary = SLANG_WRITER_MODE_BINARY, - }; - - /// A request to generate output in some target format. - class TargetRequest : public RefObject - { - public: - Linkage* linkage; - CodeGenTarget target; - SlangTargetFlags targetFlags = 0; - Slang::Profile targetProfile = Slang::Profile(); - FloatingPointMode floatingPointMode = FloatingPointMode::Default; - - Linkage* getLinkage() { return linkage; } - CodeGenTarget getTarget() { return target; } - Profile getTargetProfile() { return targetProfile; } - FloatingPointMode getFloatingPointMode() { return floatingPointMode; } - - Session* getSession(); - MatrixLayoutMode getDefaultMatrixLayoutMode(); - - // TypeLayouts created on the fly by reflection API - Dictionary> typeLayouts; - - Dictionary>& getTypeLayouts() { return typeLayouts; } - }; - - /// Are we generating code for a D3D API? - bool isD3DTarget(TargetRequest* targetReq); - - /// Are we generating code for a Khronos API (OpenGL or Vulkan)? - bool isKhronosTarget(TargetRequest* targetReq); - - // Compute the "effective" profile to use when outputting the given entry point - // for the chosen code-generation target. - // - // The stage of the effective profile will always come from the entry point, while - // the profile version (aka "shader model") will be computed as follows: - // - // - If the entry point and target belong to the same profile family, then take - // the latest version between the two (e.g., if the entry point specified `ps_5_1` - // and the target specifies `sm_5_0` then use `sm_5_1` as the version). - // - // - If the entry point and target disagree on the profile family, always use the - // profile family and version from the target. - // - Profile getEffectiveProfile(EntryPoint* entryPoint, TargetRequest* target); - - - // A directory to be searched when looking for files (e.g., `#include`) - struct SearchDirectory - { - SearchDirectory() = default; - SearchDirectory(SearchDirectory const& other) = default; - SearchDirectory(String const& path) - : path(path) - {} - - String path; - }; - - /// A list of directories to search for files (e.g., `#include`) - struct SearchDirectoryList - { - // A parent list that should also be searched - SearchDirectoryList* parent = nullptr; - - // Directories to be searched - List searchDirectories; - }; - - /// Create a blob that will retain (a copy of) raw data. - /// - ComPtr createRawBlob(void const* data, size_t size); - - /// A context for loading and re-using code modules. - class Linkage : public RefObject - { - public: - /// Create an initially-empty linkage - Linkage(Session* session); - - /// Get the parent session for this linkage - Session* getSession() { return m_session; } - - // Information on the targets we are being asked to - // generate code for. - List> targets; - - // Directories to search for `#include` files or `import`ed modules - SearchDirectoryList searchDirectories; - - SearchDirectoryList const& getSearchDirectories() { return searchDirectories; } - - // Definitions to provide during preprocessing - Dictionary preprocessorDefinitions; - - // Source manager to help track files loaded - SourceManager m_defaultSourceManager; - SourceManager* m_sourceManager = nullptr; - - // Name pool for looking up names - NamePool namePool; - - NamePool* getNamePool() { return &namePool; } - - // Modules that have been dynamically loaded via `import` - // - // This is a list of unique modules loaded, in the order they were encountered. - List > loadedModulesList; - - // Map from the path (or uniqueIdentity if available) of a module file to its definition - Dictionary> mapPathToLoadedModule; - - // Map from the logical name of a module to its definition - Dictionary> mapNameToLoadedModules; - - // The resulting specialized IR module for each entry point request - List> compiledModules; - - /// File system implementation to use when loading files from disk. - /// - /// If this member is `null`, a default implementation that tries - /// to use the native OS filesystem will be used instead. - /// - ComPtr fileSystem; - - /// The extended file system implementation. Will be set to a default implementation - /// if fileSystem is nullptr. Otherwise it will either be fileSystem's interface, - /// or a wrapped impl that makes fileSystem operate as fileSystemExt - ComPtr fileSystemExt; - - ISlangFileSystemExt* getFileSystemExt() { return fileSystemExt; } - - /// Load a file into memory using the configured file system. - /// - /// @param path The path to attempt to load from - /// @param outBlob A destination pointer to receive the loaded blob - /// @returns A `SlangResult` to indicate success or failure. - /// - SlangResult loadFile(String const& path, ISlangBlob** outBlob); - - - RefPtr parseTypeString(String typeStr, RefPtr scope); - - Type* specializeType( - Type* unspecializedType, - Int argCount, - Type* const* args, - DiagnosticSink* sink); - - /// Add a mew target amd return its index. - UInt addTarget( - CodeGenTarget target); - - RefPtr loadModule( - Name* name, - const PathInfo& filePathInfo, - ISlangBlob* fileContentsBlob, - SourceLoc const& loc, - DiagnosticSink* sink); - - void loadParsedModule( - RefPtr translationUnit, - Name* name, - PathInfo const& pathInfo); - - /// Load a module of the given name. - Module* loadModule(String const& name); - - RefPtr findOrImportModule( - Name* name, - SourceLoc const& loc, - DiagnosticSink* sink); - - SourceManager* getSourceManager() - { - return m_sourceManager; - } - - /// Override the source manager for the linakge. - /// - /// This is only used to install a temporary override when - /// parsing stuff from strings (where we don't want to retain - /// full source files for the parsed result). - /// - /// TODO: We should remove the need for this hack. - /// - void setSourceManager(SourceManager* sourceManager) - { - m_sourceManager = sourceManager; - } - - void setFileSystem(ISlangFileSystem* fileSystem); - - /// The layout to use for matrices by default (row/column major) - MatrixLayoutMode defaultMatrixLayoutMode = kMatrixLayoutMode_ColumnMajor; - MatrixLayoutMode getDefaultMatrixLayoutMode() { return defaultMatrixLayoutMode; } - - DebugInfoLevel debugInfoLevel = DebugInfoLevel::None; - - OptimizationLevel optimizationLevel = OptimizationLevel::Default; - - private: - Session* m_session = nullptr; - - /// Tracks state of modules currently being loaded. - /// - /// This information is used to diagnose cases where - /// a user tries to recursively import the same module - /// (possibly along a transitive chain of `import`s). - /// - struct ModuleBeingImportedRAII - { - public: - ModuleBeingImportedRAII( - Linkage* linkage, - Module* module) - : linkage(linkage) - , module(module) - { - next = linkage->m_modulesBeingImported; - linkage->m_modulesBeingImported = this; - } - - ~ModuleBeingImportedRAII() - { - linkage->m_modulesBeingImported = next; - } - - Linkage* linkage; - Module* module; - ModuleBeingImportedRAII* next; - }; - - // Any modules currently being imported will be listed here - ModuleBeingImportedRAII* m_modulesBeingImported = nullptr; - - /// Is the given module in the middle of being imported? - bool isBeingImported(Module* module); - - List> m_specializedTypes; - }; - - /// Shared functionality between front- and back-end compile requests. - /// - /// This is the base class for both `FrontEndCompileRequest` and - /// `BackEndCompileRequest`, and allows a small number of parts of - /// the compiler to be easily invocable from either front-end or - /// back-end work. - /// - class CompileRequestBase : public RefObject - { - // TODO: We really shouldn't need this type in the long run. - // The few places that rely on it should be refactored to just - // depend on the underlying information (a linkage and a diagnostic - // sink) directly. - // - // The flags to control dumping and validation of IR should be - // moved to some kind of shared settings/options `struct` that - // both front-end and back-end requests can store. - - public: - Session* getSession(); - Linkage* getLinkage() { return m_linkage; } - DiagnosticSink* getSink() { return m_sink; } - SourceManager* getSourceManager() { return getLinkage()->getSourceManager(); } - NamePool* getNamePool() { return getLinkage()->getNamePool(); } - ISlangFileSystemExt* getFileSystemExt() { return getLinkage()->getFileSystemExt(); } - SlangResult loadFile(String const& path, ISlangBlob** outBlob) { return getLinkage()->loadFile(path, outBlob); } - - bool shouldDumpIR = false; - bool shouldValidateIR = false; - - protected: - CompileRequestBase( - Linkage* linkage, - DiagnosticSink* sink); - - private: - Linkage* m_linkage = nullptr; - DiagnosticSink* m_sink = nullptr; - }; - - /// A request to compile source code to an AST + IR. - class FrontEndCompileRequest : public CompileRequestBase - { - public: - FrontEndCompileRequest( - Linkage* linkage, - DiagnosticSink* sink); - - int addEntryPoint( - int translationUnitIndex, - String const& name, - Profile entryPointProfile); - - // Translation units we are being asked to compile - List > translationUnits; - - RefPtr getTranslationUnit(UInt index) { return translationUnits[index]; } - - // Compile flags to be shared by all translation units - SlangCompileFlags compileFlags = 0; - - // If true then generateIR will serialize out IR, and serialize back in again. Making - // serialization a bottleneck or firewall between the front end and the backend - bool useSerialIRBottleneck = false; - - // If true will serialize and de-serialize with debug information - bool verifyDebugSerialization = false; - - List> m_entryPointReqs; - - List> const& getEntryPointReqs() { return m_entryPointReqs; } - UInt getEntryPointReqCount() { return m_entryPointReqs.getCount(); } - FrontEndEntryPointRequest* getEntryPointReq(UInt index) { return m_entryPointReqs[index]; } - - // Directories to search for `#include` files or `import`ed modules - // NOTE! That for now these search directories are not settable via the API - // so the search directories on Linkage is used for #include as well as for modules. - SearchDirectoryList searchDirectories; - - SearchDirectoryList const& getSearchDirectories() { return searchDirectories; } - - // Definitions to provide during preprocessing - Dictionary preprocessorDefinitions; - - void parseTranslationUnit( - TranslationUnitRequest* translationUnit); - - // Perform primary semantic checking on all - // of the translation units in the program - void checkAllTranslationUnits(); - - void generateIR(); - - SlangResult executeActionsInner(); - - /// Add a translation unit to be compiled. - /// - /// @param language The source language that the translation unit will use (e.g., `SourceLanguage::Slang` - /// @param moduleName The name that will be used for the module compile from the translation unit. - /// @return The zero-based index of the translation unit in this compile request. - int addTranslationUnit(SourceLanguage language, Name* moduleName); - - /// Add a translation unit to be compiled. - /// - /// @param language The source language that the translation unit will use (e.g., `SourceLanguage::Slang` - /// @return The zero-based index of the translation unit in this compile request. - /// - /// The module name for the translation unit will be automatically generated. - /// If all translation units in a compile request use automatically generated - /// module names, then they are guaranteed not to conflict with one another. - /// - int addTranslationUnit(SourceLanguage language); - - void addTranslationUnitSourceFile( - int translationUnitIndex, - SourceFile* sourceFile); - - void addTranslationUnitSourceBlob( - int translationUnitIndex, - String const& path, - ISlangBlob* sourceBlob); - - void addTranslationUnitSourceString( - int translationUnitIndex, - String const& path, - String const& source); - - void addTranslationUnitSourceFile( - int translationUnitIndex, - String const& path); - - Program* getProgram() { return m_program; } - - private: - RefPtr m_program; - }; - - /// A collection of code modules and entry points that are intended to be used together. - /// - /// A `Program` establishes that certain pieces of code are intended - /// to be used togehter so that, e.g., layout can make sure to allocate - /// space for the global shader parameters in all referenced modules. - /// - class Program : public RefObject - { - public: - /// Create a new program, initially empty. - /// - /// All code loaded into the program must come - /// from the given `linkage`. - Program( - Linkage* linkage); - - /// Get the linkage that this program uses. - Linkage* getLinkage() { return m_linkage; } - - /// Get the number of entry points added to the program - Index getEntryPointCount() { return m_entryPoints.getCount(); } - - /// Get the entry point at the given `index`. - RefPtr getEntryPoint(Index index) { return m_entryPoints[index]; } - - /// Get the full ist of entry points on the program. - List> const& getEntryPoints() { return m_entryPoints; } - - /// Get the substitution (if any) that represents how global generics are specialized. - RefPtr getGlobalGenericSubstitution() { return m_globalGenericSubst; } - - /// Get the full list of modules this program depends on - List> getModuleDependencies() { return m_moduleDependencyList.getModuleList(); } - - /// Get the full list of filesystem paths this program depends on - List getFilePathDependencies() { return m_filePathDependencyList.getFilePathList(); } - - /// Get the target-specific version of this program for the given `target`. - /// - /// The `target` must be a target on the `Linkage` that was used to create this program. - TargetProgram* getTargetProgram(TargetRequest* target); - - /// Add a module (and everything it depends on) to the list of references - void addReferencedModule(Module* module); - - /// Add a module (but not the things it depends on) to the list of references - /// - /// This is a compatiblity hack for legacy compiler behavior. - void addReferencedLeafModule(Module* module); - - - /// Add an entry point to the program - /// - /// This also adds everything the entry point depends on to the list of references. - /// - void addEntryPoint(EntryPoint* entryPoint); - - /// Set the global generic argument substitution to use. - void setGlobalGenericSubsitution(RefPtr subst) - { - m_globalGenericSubst = subst; - } - - /// Parse a type from a string, in the context of this program. - /// - /// Any names in the string will be resolved using the modules - /// referenced by the program. - /// - /// On an error, returns null and reports diagnostic messages - /// to the provided `sink`. - /// - Type* getTypeFromString(String typeStr, DiagnosticSink* sink); - - /// Get the IR module that represents this program and its entry points. - /// - /// The IR module for a program tries to be minimal, and in the - /// common case will only include symbols with `[import]` declarations - /// for the entry point(s) of the program, and any types they - /// depend on. - /// - /// This IR module is intended to be linked against the IR modules - /// for all of the dependencies (see `getModuleDependencies()`) to - /// provide complete code. - /// - RefPtr getOrCreateIRModule(DiagnosticSink* sink); - - /// Get the number of existential type parameters for the program. - Index getExistentialTypeParamCount() { return m_globalExistentialSlots.paramTypes.getCount(); } - - /// Get the existential type parameter at `index`. - Type* getExistentialTypeParam(Index index) { return m_globalExistentialSlots.paramTypes[index]; } - - /// Get the number of arguments supplied for existential type parameters. - /// - /// Note that the number of arguments may not match the number of parameters. - /// In particular, an unspecialized program may have many parameters, but zero arguments. - Index getExistentialTypeArgCount() { return m_globalExistentialSlots.args.getCount(); } - - /// Get the existential type argument (type and witness table) at `index`. - ExistentialTypeSlots::Arg getExistentialTypeArg(Index index) { return m_globalExistentialSlots.args[index]; } - - /// Get an array of all existential type arguments. - ExistentialTypeSlots::Arg const* getExistentialTypeArgs() { return m_globalExistentialSlots.args.getBuffer(); } - - /// Get an array of all global shader parameters. - List const& getShaderParams() { return m_shaderParams; } - - void _collectShaderParams(DiagnosticSink* sink); - void _specializeExistentialTypeParams( - List> const& args, - DiagnosticSink* sink); - - private: - - // The linakge this program is associated with. - // - // Note that a `Program` keeps its associated linkage alive, - // and not vice versa. - // - RefPtr m_linkage; - - // Tracking data for the list of modules dependend on - ModuleDependencyList m_moduleDependencyList; - - // Tracking data for the list of filesystem paths dependend on - FilePathDependencyList m_filePathDependencyList; - - // Entry points that are part of the program. - List > m_entryPoints; - - // Specializations for global generic parameters (if any) - RefPtr m_globalGenericSubst; - - // The existential/interface slots associated with the global scope. - ExistentialTypeSlots m_globalExistentialSlots; - - /// Information about global shader parameters - List m_shaderParams; - - // Generated IR for this program. - RefPtr m_irModule; - - // Cache of target-specific programs for each target. - Dictionary> m_targetPrograms; - - // Any types looked up dynamically using `getTypeFromString` - Dictionary> m_types; - }; - - /// A `Program` specialized for a particular `TargetRequest` - class TargetProgram : public RefObject - { - public: - TargetProgram( - Program* program, - TargetRequest* targetReq); - - /// Get the underlying program - Program* getProgram() { return m_program; } - - /// Get the underlying target - TargetRequest* getTargetReq() { return m_targetReq; } - - /// Get the layout for the program on the target. - /// - /// If this is the first time the layout has been - /// requested, report any errors that arise during - /// layout to the given `sink`. - /// - ProgramLayout* getOrCreateLayout(DiagnosticSink* sink); - - /// Get the layout for the program on the taarget. - /// - /// This routine assumes that `getOrCreateLayout` - /// has already been called previously. - /// - ProgramLayout* getExistingLayout() - { - SLANG_ASSERT(m_layout); - return m_layout; - } - - /// Get the compiled code for an entry point on the target. - /// - /// This routine assumes code generation has already been - /// performed and called `setEntryPointResult`. - /// - CompileResult& getExistingEntryPointResult(Int entryPointIndex) - { - return m_entryPointResults[entryPointIndex]; - } - - // TODO: Need a lazy `getOrCreateEntryPointResult` - - /// Set the compiled code for an entry point. - /// - /// Should only be called by code generation. - void setEntryPointResult(Int entryPointIndex, CompileResult const& result) - { - m_entryPointResults[entryPointIndex] = result; - } - - private: - // The program being compiled or laid out - Program* m_program; - - // The target that code/layout will be generated for - TargetRequest* m_targetReq; - - // The computed layout, if it has been generated yet - RefPtr m_layout; - - // Generated compile results for each entry point - // in the parent `Program` (indexing matches - // the order they are given in the `Program`) - List m_entryPointResults; - }; - - /// A request to generate code for a program - class BackEndCompileRequest : public CompileRequestBase - { - public: - BackEndCompileRequest( - Linkage* linkage, - DiagnosticSink* sink, - Program* program = nullptr); - - // Should we dump intermediate results along the way, for debugging? - bool shouldDumpIntermediates = false; - - // How should `#line` directives be emitted (if at all)? - LineDirectiveMode lineDirectiveMode = LineDirectiveMode::Default; - - LineDirectiveMode getLineDirectiveMode() { return lineDirectiveMode; } - - Program* getProgram() { return m_program; } - void setProgram(Program* program) { m_program = program; } - - // Should R/W images without explicit formats be assumed to have "unknown" format? - // - // The default behavior is to make a best-effort guess as to what format is intended. - // - bool useUnknownImageFormatAsDefault = false; - - private: - RefPtr m_program; - }; - - /// A compile request that spans the front and back ends of the compiler - /// - /// This is what the command-line `slangc` uses, as well as the legacy - /// C API. It ties together the functionality of `Linkage`, - /// `FrontEndCompileRequest`, and `BackEndCompileRequest`, plus a small - /// number of additional features that primarily make sense for - /// command-line usage. - /// - class EndToEndCompileRequest : public RefObject - { - public: - EndToEndCompileRequest( - Session* session); - - // What container format are we being asked to generate? - // - // Note: This field is unused except by the options-parsing - // logic; it exists to support wriiting out binary modules - // once that feature is ready. - // - ContainerFormat containerFormat = ContainerFormat::None; - - // Path to output container to - // - // Note: This field exists to support wriiting out binary modules - // once that feature is ready. - // - String containerOutputPath; - - // Should we just pass the input to another compiler? - PassThroughMode passThrough = PassThroughMode::None; - - /// Source code for the generic arguments to use for the global generic parameters of the program. - List globalGenericArgStrings; - - /// Types to use to fill global existential "slots" - List globalExistentialSlotArgStrings; - - bool shouldSkipCodegen = false; - - // Are we being driven by the command-line `slangc`, and should act accordingly? - bool isCommandLineCompile = false; - - String mDiagnosticOutput; - - /// A blob holding the diagnostic output - ComPtr diagnosticOutputBlob; - - /// Per-entry-point information not tracked by other compile requests - class EntryPointInfo : public RefObject - { - public: - /// Source code for the generic arguments to use for the generic parameters of the entry point. - List genericArgStrings; - - /// Source code for the type arguments to plug into the existential type "slots" of the entry point - List existentialArgStrings; - }; - List entryPoints; - - /// Per-target information only needed for command-line compiles - class TargetInfo : public RefObject - { - public: - // Requested output paths for each entry point. - // An empty string indices no output desired for - // the given entry point. - Dictionary entryPointOutputPaths; - }; - Dictionary> targetInfos; - - Linkage* getLinkage() { return m_linkage; } - - int addEntryPoint( - int translationUnitIndex, - String const& name, - Profile profile, - List const & genericTypeNames); - - void setWriter(WriterChannel chan, ISlangWriter* writer); - ISlangWriter* getWriter(WriterChannel chan) const { return m_writers[int(chan)]; } - - SlangResult executeActionsInner(); - SlangResult executeActions(); - - Session* getSession() { return m_session; } - DiagnosticSink* getSink() { return &m_sink; } - NamePool* getNamePool() { return getLinkage()->getNamePool(); } - - FrontEndCompileRequest* getFrontEndReq() { return m_frontEndReq; } - BackEndCompileRequest* getBackEndReq() { return m_backEndReq; } - Program* getUnspecializedProgram() { return getFrontEndReq()->getProgram(); } - Program* getSpecializedProgram() { return m_specializedProgram; } - - private: - Session* m_session = nullptr; - RefPtr m_linkage; - DiagnosticSink m_sink; - RefPtr m_frontEndReq; - RefPtr m_unspecializedProgram; - RefPtr m_specializedProgram; - RefPtr m_backEndReq; - - // For output - ComPtr m_writers[SLANG_WRITER_CHANNEL_COUNT_OF]; - }; - - void generateOutput( - BackEndCompileRequest* compileRequest); - - void generateOutput( - EndToEndCompileRequest* compileRequest); - - // Helper to dump intermediate output when debugging - void maybeDumpIntermediate( - BackEndCompileRequest* compileRequest, - void const* data, - size_t size, - CodeGenTarget target); - void maybeDumpIntermediate( - BackEndCompileRequest* compileRequest, - char const* text, - CodeGenTarget target); - - /* Returns SLANG_OK if a codeGen target is available. */ - SlangResult checkCompileTargetSupport(Session* session, CodeGenTarget target); - /* Returns SLANG_OK if pass through support is available */ - SlangResult checkExternalCompilerSupport(Session* session, PassThroughMode passThrough); - - /* Report an error appearing from external compiler to the diagnostic sink error to the diagnostic sink. - @param compilerName The name of the compiler the error came for (or nullptr if not known) - @param res Result associated with the error. The error code will be reported. (Can take HRESULT - and will expand to string if known) - @param diagnostic The diagnostic string associated with the compile failure - @param sink The diagnostic sink to report to */ - void reportExternalCompileError(const char* compilerName, SlangResult res, const UnownedStringSlice& diagnostic, DiagnosticSink* sink); - - /* Determines a suitable filename to identify the input for a given entry point being compiled. - If the end-to-end compile is a pass-through case, will attempt to find the (unique) source file - pathname for the translation unit containing the entry point at `entryPointIndex. - If the compilation is not in a pass-through case, then always returns `"slang-generated"`. - @param endToEndReq The end-to-end compile request which might be using pass-through copmilation - @param entryPointIndex The index of the entry point to compute a filename for. - @return the appropriate source filename */ - String calcSourcePathForEntryPoint(EndToEndCompileRequest* endToEndReq, UInt entryPointIndex); - - struct TypeCheckingCache; - // - - class Session - { - public: - enum class SharedLibraryFuncType - { - Glslang_Compile, - Fxc_D3DCompile, - Fxc_D3DDisassemble, - Dxc_DxcCreateInstance, - CountOf, - }; - - // - - RefPtr baseLanguageScope; - RefPtr coreLanguageScope; - RefPtr hlslLanguageScope; - RefPtr slangLanguageScope; - - List> loadedModuleCode; - - SourceManager builtinSourceManager; - - SourceManager* getBuiltinSourceManager() { return &builtinSourceManager; } - - // Name pool stuff for unique-ing identifiers - - RootNamePool rootNamePool; - NamePool namePool; - - RootNamePool* getRootNamePool() { return &rootNamePool; } - NamePool* getNamePool() { return &namePool; } - Name* getNameObj(String name) { return namePool.getName(name); } - Name* tryGetNameObj(String name) { return namePool.tryGetName(name); } - // - - // Generated code for stdlib, etc. - String stdlibPath; - String coreLibraryCode; - String slangLibraryCode; - String hlslLibraryCode; - String glslLibraryCode; - - String getStdlibPath(); - String getCoreLibraryCode(); - String getHLSLLibraryCode(); - - // Basic types that we don't want to re-create all the time - RefPtr errorType; - RefPtr initializerListType; - RefPtr overloadedType; - RefPtr constExprRate; - RefPtr irBasicBlockType; - - RefPtr stringType; - RefPtr enumTypeType; - - ComPtr sharedLibraryLoader; ///< The shared library loader (never null) - ComPtr sharedLibraries[int(SharedLibraryType::CountOf)]; ///< The loaded shared libraries - SlangFuncPtr sharedLibraryFunctions[int(SharedLibraryFuncType::CountOf)]; - - Dictionary> builtinTypes; - Dictionary magicDecls; - - void initializeTypes(); - - Type* getBoolType(); - Type* getHalfType(); - Type* getFloatType(); - Type* getDoubleType(); - Type* getIntType(); - Type* getInt64Type(); - Type* getUIntType(); - Type* getUInt64Type(); - Type* getVoidType(); - Type* getBuiltinType(BaseType flavor); - - Type* getInitializerListType(); - Type* getOverloadedType(); - Type* getErrorType(); - Type* getStringType(); - - Type* getEnumTypeType(); - - // Construct the type `Ptr`, where `Ptr` - // is looked up as a builtin type. - RefPtr getPtrType(RefPtr valueType); - - // Construct the type `Out` - RefPtr getOutType(RefPtr valueType); - - // Construct the type `InOut` - RefPtr getInOutType(RefPtr valueType); - - // Construct the type `Ref` - RefPtr getRefType(RefPtr valueType); - - // Construct a pointer type like `Ptr`, but where - // the actual type name for the pointer type is given by `ptrTypeName` - RefPtr getPtrType(RefPtr valueType, char const* ptrTypeName); - - // Construct a pointer type like `Ptr`, but where - // the generic declaration for the pointer type is `genericDecl` - RefPtr getPtrType(RefPtr valueType, GenericDecl* genericDecl); - - RefPtr getArrayType( - Type* elementType, - IntVal* elementCount); - - RefPtr getVectorType( - RefPtr elementType, - RefPtr elementCount); - - SyntaxClass findSyntaxClass(Name* name); - - Dictionary > mapNameToSyntaxClass; - - // cache used by type checking, implemented in check.cpp - TypeCheckingCache* typeCheckingCache = nullptr; - TypeCheckingCache* getTypeCheckingCache(); - void destroyTypeCheckingCache(); - // - - /// Will try to load the library by specified name (using the set loader), if not one already available. - ISlangSharedLibrary* getOrLoadSharedLibrary(SharedLibraryType type, DiagnosticSink* sink); - - /// Gets a shared library by type, or null if not loaded - ISlangSharedLibrary* getSharedLibrary(SharedLibraryType type) const { return sharedLibraries[int(type)]; } - - SlangFuncPtr getSharedLibraryFunc(SharedLibraryFuncType type, DiagnosticSink* sink); - - Session(); - - void addBuiltinSource( - RefPtr const& scope, - String const& path, - String const& source); - ~Session(); - - private: - /// Linkage used for all built-in (stdlib) code. - RefPtr m_builtinLinkage; - }; - -} - -#endif diff --git a/source/slang/decl-defs.h b/source/slang/decl-defs.h deleted file mode 100644 index 10dcefe19..000000000 --- a/source/slang/decl-defs.h +++ /dev/null @@ -1,325 +0,0 @@ -// decl-defs.h - -// Syntax class definitions for declarations. - -// A group of declarations that should be treated as a unit -SYNTAX_CLASS(DeclGroup, DeclBase) - SYNTAX_FIELD(List>, decls) -END_SYNTAX_CLASS() - -// A "container" decl is a parent to other declarations -ABSTRACT_SYNTAX_CLASS(ContainerDecl, Decl) - SYNTAX_FIELD(List>, Members) - - RAW( - template - FilteredMemberList getMembersOfType() - { - return FilteredMemberList(Members); - } - - - // Dictionary for looking up members by name. - // This is built on demand before performing lookup. - Dictionary memberDictionary; - - // Whether the `memberDictionary` is valid. - // Should be set to `false` if any members get added/remoed. - bool memberDictionaryIsValid = false; - - // A list of transparent members, to be used in lookup - // Note: this is only valid if `memberDictionaryIsValid` is true - List transparentMembers; - ) -END_SYNTAX_CLASS() - -// Base class for all variable declarations -ABSTRACT_SYNTAX_CLASS(VarDeclBase, Decl) - - // type of the variable - SYNTAX_FIELD(TypeExp, type) - - RAW( - Type* getType() { return type.type.Ptr(); } - ) - - // Initializer expression (optional) - SYNTAX_FIELD(RefPtr, initExpr) -END_SYNTAX_CLASS() - -// Ordinary potentially-mutable variables (locals, globals, and member variables) -SYNTAX_CLASS(VarDecl, VarDeclBase) -END_SYNTAX_CLASS() - -// A variable declaration that is always immutable (whether local, global, or member variable) -SYNTAX_CLASS(LetDecl, VarDecl) -END_SYNTAX_CLASS() - -// An `AggTypeDeclBase` captures the shared functionality -// between true aggregate type declarations and extension -// declarations: -// -// - Both can container members (they are `ContainerDecl`s) -// - Both can have declared bases -// - Both expose a `this` variable in their body -// -ABSTRACT_SYNTAX_CLASS(AggTypeDeclBase, ContainerDecl) -END_SYNTAX_CLASS() - -// An extension to apply to an existing type -SYNTAX_CLASS(ExtensionDecl, AggTypeDeclBase) - SYNTAX_FIELD(TypeExp, targetType) - - // next extension attached to the same nominal type - DECL_FIELD(ExtensionDecl*, nextCandidateExtension RAW(= nullptr)) -END_SYNTAX_CLASS() - -// Declaration of a type that represents some sort of aggregate -ABSTRACT_SYNTAX_CLASS(AggTypeDecl, AggTypeDeclBase) - -RAW( - // extensions that might apply to this declaration - ExtensionDecl* candidateExtensions = nullptr; - FilteredMemberList GetFields() - { - return getMembersOfType(); - } - ) -END_SYNTAX_CLASS() - -SIMPLE_SYNTAX_CLASS(StructDecl, AggTypeDecl) - -SIMPLE_SYNTAX_CLASS(ClassDecl, AggTypeDecl) - -// TODO: Is it appropriate to treat an `enum` as an aggregate type? -// Most code that looks for, e.g., conformances assumes user-defined -// types are all `AggTypeDecl`, so this is the right choice for now -// if we want `enum` types to be able to implement interfaces, etc. -// -SYNTAX_CLASS(EnumDecl, AggTypeDecl) -RAW( - RefPtr tagType; -) -END_SYNTAX_CLASS() - -// A single case in an enum. -// -// E.g., in a declaration like: -// -// enum Color { Red = 0, Green, Blue }; -// -// The `Red = 0` is the declaration of the `Red` -// case, with `0` as an explicit expression for its -// _tag value_. -// -SYNTAX_CLASS(EnumCaseDecl, Decl) - - // type of the parent `enum` - SYNTAX_FIELD(TypeExp, type) - - RAW( - Type* getType() { return type.type.Ptr(); } - ) - - // Tag value - SYNTAX_FIELD(RefPtr, tagExpr) -END_SYNTAX_CLASS() - -// An interface which other types can conform to -SIMPLE_SYNTAX_CLASS(InterfaceDecl, AggTypeDecl) - -ABSTRACT_SYNTAX_CLASS(TypeConstraintDecl, Decl) - RAW( - virtual TypeExp& getSup() = 0; - ) -END_SYNTAX_CLASS() - -// A kind of pseudo-member that represents an explicit -// or implicit inheritance relationship. -// -SYNTAX_CLASS(InheritanceDecl, TypeConstraintDecl) -// The type expression as written - SYNTAX_FIELD(TypeExp, base) - - RAW( - // After checking, this dictionary will map members - // required by the base type to their concrete - // implementations in the type that contains - // this inheritance declaration. - RefPtr witnessTable; - virtual TypeExp& getSup() override - { - return base; - } - ) -END_SYNTAX_CLASS() - -// TODO: may eventually need sub-classes for explicit/direct vs. implicit/indirect inheritance - - -// A declaration that represents a simple (non-aggregate) type -// -// TODO: probably all types will be aggregate decls eventually, -// so that we can easily store conformances/constraints on type variables -ABSTRACT_SYNTAX_CLASS(SimpleTypeDecl, Decl) -END_SYNTAX_CLASS() - -// A `typedef` declaration -SYNTAX_CLASS(TypeDefDecl, SimpleTypeDecl) - SYNTAX_FIELD(TypeExp, type) -END_SYNTAX_CLASS() - -SIMPLE_SYNTAX_CLASS(TypeAliasDecl, TypeDefDecl) - -// An 'assoctype' declaration, it is a container of inheritance clauses -SYNTAX_CLASS(AssocTypeDecl, AggTypeDecl) -END_SYNTAX_CLASS() - -// A 'type_param' declaration, which defines a generic -// entry-point parameter. Is a container of GenericTypeConstraintDecl -SYNTAX_CLASS(GlobalGenericParamDecl, AggTypeDecl) -END_SYNTAX_CLASS() - -// A scope for local declarations (e.g., as part of a statement) -SIMPLE_SYNTAX_CLASS(ScopeDecl, ContainerDecl) - -// A function/initializer/subscript parameter (potentially mutable) -SIMPLE_SYNTAX_CLASS(ParamDecl, VarDeclBase) - -// A parameter of a function declared in "modern" types (immutable unless explicitly `out` or `inout`) -SIMPLE_SYNTAX_CLASS(ModernParamDecl, ParamDecl) - -// Base class for things that have parameter lists and can thus be applied to arguments ("called") -ABSTRACT_SYNTAX_CLASS(CallableDecl, ContainerDecl) - RAW( - FilteredMemberList GetParameters() - { - return getMembersOfType(); - }) - - SYNTAX_FIELD(TypeExp, ReturnType) - - // Fields related to redeclaration, so that we - // can support multiple specialized varaitions - // of the "same" logical function. - // - // This should also help us to support redeclaration - // of functions when handling HLSL/GLSL. - - // The "primary" declaration of the function, which will - // be used whenever we need to unique things. - FIELD_INIT(CallableDecl*, primaryDecl, nullptr) - - // The next declaration of the "same" function (that is, - // with the same `primaryDecl`). - FIELD_INIT(CallableDecl*, nextDecl, nullptr); - -END_SYNTAX_CLASS() - -// Base class for callable things that may also have a body that is evaluated to produce their result -ABSTRACT_SYNTAX_CLASS(FunctionDeclBase, CallableDecl) - SYNTAX_FIELD(RefPtr, Body) -END_SYNTAX_CLASS() - -// A constructor/initializer to create instances of a type -SIMPLE_SYNTAX_CLASS(ConstructorDecl, FunctionDeclBase) - -// A subscript operation used to index instances of a type -SIMPLE_SYNTAX_CLASS(SubscriptDecl, CallableDecl) - -// An "accessor" for a subscript or property -SIMPLE_SYNTAX_CLASS(AccessorDecl, FunctionDeclBase) - -SIMPLE_SYNTAX_CLASS(GetterDecl, AccessorDecl) -SIMPLE_SYNTAX_CLASS(SetterDecl, AccessorDecl) -SIMPLE_SYNTAX_CLASS(RefAccessorDecl, AccessorDecl) - -SIMPLE_SYNTAX_CLASS(FuncDecl, FunctionDeclBase) - -// A "module" of code (essentiately, a single translation unit) -// that provides a scope for some number of declarations. -SYNTAX_CLASS(ModuleDecl, ContainerDecl) - FIELD(RefPtr, scope) - - // The API-level module that this declaration belong to. - // - // This field allows lookup of the `Module` based on a - // declaration nested under a `ModuleDecl` by following - // its chain of parents. - // - RAW(Module* module = nullptr;) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(ImportDecl, Decl) - // The name of the module we are trying to import - FIELD(NameLoc, moduleNameAndLoc) - - // The scope that we want to import into - FIELD(RefPtr, scope) - - // The module that actually got imported - DECL_FIELD(RefPtr, importedModuleDecl) -END_SYNTAX_CLASS() - -// A generic declaration, parameterized on types/values -SYNTAX_CLASS(GenericDecl, ContainerDecl) - // The decl that is genericized... - SYNTAX_FIELD(RefPtr, inner) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(GenericTypeParamDecl, SimpleTypeDecl) - // The bound for the type parameter represents a trait that any - // type used as this parameter must conform to -// TypeExp bound; - - // The "initializer" for the parameter represents a default value - SYNTAX_FIELD(TypeExp, initType) -END_SYNTAX_CLASS() - -// A constraint placed as part of a generic declaration -SYNTAX_CLASS(GenericTypeConstraintDecl, TypeConstraintDecl) - // A type constraint like `T : U` is constraining `T` to be "below" `U` - // on a lattice of types. This may not be a subtyping relationship - // per se, but it makes sense to use that terminology here, so we - // think of these fields as the sub-type and sup-ertype, respectively. - SYNTAX_FIELD(TypeExp, sub) - SYNTAX_FIELD(TypeExp, sup) - RAW( - virtual TypeExp& getSup() override - { - return sup; - } - ) -END_SYNTAX_CLASS() - -SIMPLE_SYNTAX_CLASS(GenericValueParamDecl, VarDeclBase) - -// An empty declaration (which might still have modifiers attached). -// -// An empty declaration is uncommon in HLSL, but -// in GLSL it is often used at the global scope -// to declare metadata that logically belongs -// to the entry point, e.g.: -// -// layout(local_size_x = 16) in; -// -SIMPLE_SYNTAX_CLASS(EmptyDecl, Decl) - -// A declaration used by the implementation to put syntax keywords -// into the current scope. -// -SYNTAX_CLASS(SyntaxDecl, Decl) - // What type of syntax node will be produced when parsing with this keyword? - FIELD(SyntaxClass, syntaxClass) - - // Callback to invoke in order to parse syntax with this keyword. - FIELD(SyntaxParseCallback, parseCallback) - FIELD(void*, parseUserData) -END_SYNTAX_CLASS() - -// A declaration of an attribute to be used with `[name(...)]` syntax. -// -SYNTAX_CLASS(AttributeDecl, ContainerDecl) - // What type of syntax node will be produced to represent this attribute. - FIELD(SyntaxClass, syntaxClass) -END_SYNTAX_CLASS() diff --git a/source/slang/diagnostic-defs.h b/source/slang/diagnostic-defs.h deleted file mode 100644 index 59d840997..000000000 --- a/source/slang/diagnostic-defs.h +++ /dev/null @@ -1,487 +0,0 @@ -// - -// The file is meant to be included multiple times, to produce different -// pieces of declaration/definition code related to diagnostic messages -// -// Each diagnostic is declared here with: -// -// DIAGNOSTIC(id, severity, name, messageFormat) -// -// Where `id` is the unique diagnostic ID, `severity` is the default -// severity (from the `Severity` enum), `name` is a name used to refer -// to this diagnostic from code, and `messageFormat` is the default -// (non-localized) message for the diagnostic, with placeholders -// for any arguments. - -#ifndef DIAGNOSTIC -#error Need to #define DIAGNOSTIC(...) before including "DiagnosticDefs.h" -#define DIAGNOSTIC(id, severity, name, messageFormat) /* */ -#endif - -// -// -1 - Notes that decorate another diagnostic. -// - -DIAGNOSTIC(-1, Note, alsoSeePipelineDefinition, "also see pipeline definition"); -DIAGNOSTIC(-1, Note, implicitParameterMatchingFailedBecauseNameNotAccessible, "implicit parameter matching failed because the component of the same name is not accessible from '$0'.\ncheck if you have declared necessary requirements and properly used the 'public' qualifier.") -DIAGNOSTIC(-1, Note, implicitParameterMatchingFailedBecauseShaderDoesNotDefineComponent, "implicit parameter matching failed because shader '$0' does not define component '$1'.") -DIAGNOSTIC(-1, Note, implicitParameterMatchingFailedBecauseTypeMismatch, "implicit parameter matching failed because the component of the same name does not match parameter type '$0'.") -DIAGNOSTIC(-1, Note, noteShaderIsTargetingPipeine, "shader '$0' is targeting pipeline '$1'") -DIAGNOSTIC(-1, Note, seeDefinitionOf, "see definition of '$0'") -DIAGNOSTIC(-1, Note, seeInterfaceDefinitionOf, "see interface definition of '$0'") -DIAGNOSTIC(-1, Note, seeUsingOf, "see using of '$0'") -DIAGNOSTIC(-1, Note, seeDefinitionOfShader, "see definition of shader '$0'") -DIAGNOSTIC(-1, Note, seeInclusionOf, "see inclusion of '$0'") -DIAGNOSTIC(-1, Note, seeModuleBeingUsedIn, "see module '$0' being used in '$1'") -DIAGNOSTIC(-1, Note, seePipelineRequirementDefinition, "see pipeline requirement definition") -DIAGNOSTIC(-1, Note, seePotentialDefinitionOfComponent, "see potential definition of component '$0'") -DIAGNOSTIC(-1, Note, seePreviousDefinition, "see previous definition") -DIAGNOSTIC(-1, Note, seePreviousDefinitionOf, "see previous definition of '$0'") -DIAGNOSTIC(-1, Note, seeRequirementDeclaration, "see requirement declaration") -DIAGNOSTIC(-1, Note, doYouForgetToMakeComponentAccessible, "do you forget to make component '$0' acessible from '$1' (missing public qualifier)?") - -DIAGNOSTIC(-1, Note, seeDeclarationOf, "see declaration of '$0'") -DIAGNOSTIC(-1, Note, seeOtherDeclarationOf, "see other declaration of '$0'") -DIAGNOSTIC(-1, Note, seePreviousDeclarationOf, "see previous declaration of '$0'") - -// -// 0xxxx - Command line and interaction with host platform APIs. -// - -DIAGNOSTIC( 1, Error, cannotOpenFile, "cannot open file '$0'.") -DIAGNOSTIC( 2, Error, cannotFindFile, "cannot find file '$0'.") -DIAGNOSTIC( 2, Error, unsupportedCompilerMode, "unsupported compiler mode.") -DIAGNOSTIC( 4, Error, cannotWriteOutputFile, "cannot write output file '$0'.") -DIAGNOSTIC( 5, Error, failedToLoadDynamicLibrary, "failed to load dynamic library '$0'") -DIAGNOSTIC( 6, Error, tooManyOutputPathsSpecified, "$0 output paths specified, but only $1 entry points given") - -DIAGNOSTIC( 7, Error, noOutputPathSpecifiedForEntryPoint, - "no output path specified for entry point '$0' (the '-o' option for an entry point must precede the corresponding '-entry')") - -DIAGNOSTIC( 8, Error, outputPathsImplyDifferentFormats, - "the output paths '$0' and '$1' require different code-generation targets") - -DIAGNOSTIC( 10, Error, explicitOutputPathsAndMultipleTargets, "canot use both explicit output paths ('-o') and multiple targets ('-target')") -DIAGNOSTIC( 11, Error, glslIsNotSupported, "the Slang compiler does not support GLSL as a source language"); -DIAGNOSTIC( 12, Error, cannotDeduceSourceLanguage, "can't deduce language for input file '$0'"); -DIAGNOSTIC( 13, Error, unknownCodeGenerationTarget, "unknown code generation target '$0'"); -DIAGNOSTIC( 14, Error, unknownProfile, "unknown profile '$0'"); -DIAGNOSTIC( 15, Error, unknownStage, "unknown stage '$0'"); -DIAGNOSTIC( 16, Error, unknownPassThroughTarget, "unknown pass-through target '$0'"); -DIAGNOSTIC( 17, Error, unknownCommandLineOption, "unknown command-line option '$0'"); -DIAGNOSTIC( 20, Error, entryPointsNeedToBeAssociatedWithTranslationUnits, "when using multiple source files, entry points must be specified after their corresponding source file(s)"); -DIAGNOSTIC( 21, Error, expectedArgumentForOption, "expected an argument for command-line option '$0'"); - -DIAGNOSTIC( 24, Error, unknownLineDirectiveMode, "unknown '#line' directive mode '$0'"); -DIAGNOSTIC( 25, Error, unknownFloatingPointMode, "unknown floating-point mode '$0'"); -DIAGNOSTIC( 26, Error, unknownOptimiziationLevel, "unknown optimization level '$0'"); -DIAGNOSTIC( 27, Error, uknownDebugInfoLevel, "unknown debug info level '$0'"); - -DIAGNOSTIC( 30, Warning, sameStageSpecifiedMoreThanOnce, "the stage '$0' was specified more than once for entry point '$1'") -DIAGNOSTIC( 31, Error, conflictingStagesForEntryPoint, "conflicting stages have been specified for entry point '$0'") -DIAGNOSTIC( 32, Warning, explicitStageDoesntMatchImpliedStage, "the stage specified for entry point '$0' ('$1') does not match the stage implied by the source file name ('$2')") -DIAGNOSTIC( 33, Error, stageSpecificationIgnoredBecauseNoEntryPoints, "one or more stages were specified, but no entry points were specified with '-entry'") -DIAGNOSTIC( 34, Error, stageSpecificationIgnoredBecauseBeforeAllEntryPoints, "when compiling multiple entry points, any '-stage' options must follow the '-entry' option that they apply to") -DIAGNOSTIC( 35, Error, noStageSpecifiedInPassThroughMode, "no stage was specified for entry point '$0'; when using the '-pass-through' option, stages must be fully specified on the command line") - -DIAGNOSTIC( 40, Warning, sameProfileSpecifiedMoreThanOnce, "the '$0' was specified more than once for target '$0'") -DIAGNOSTIC( 41, Error, conflictingProfilesSpecifiedForTarget, "conflicting profiles have been specified for target '$0'") - -DIAGNOSTIC( 42, Error, profileSpecificationIgnoredBecauseNoTargets, "a '-profile' option was specified, but no target was specified with '-target'") -DIAGNOSTIC( 43, Error, profileSpecificationIgnoredBecauseBeforeAllTargets, "when using multiple targets, any '-profile' option must follow the '-target' it applies to") - -DIAGNOSTIC( 42, Error, targetFlagsIgnoredBecauseNoTargets, "target options were specified, but no target was specified with '-target'") -DIAGNOSTIC( 43, Error, targetFlagsIgnoredBecauseBeforeAllTargets, "when using multiple targets, any target options must follow the '-target' they apply to") - -DIAGNOSTIC( 50, Error, duplicateTargets, "the target '$0' has been specified more than once") - -DIAGNOSTIC( 60, Error, cannotDeduceOutputFormatFromPath, "cannot infer an output format from the output path '$0'") -DIAGNOSTIC( 61, Error, cannotMatchOutputFileToTarget, "no specified '-target' option matches the output path '$0', which implies the '$1' format") - -DIAGNOSTIC( 62, Error, failedToFindFunctionInSharedLibrary, "failed to find function '$0' in shared/dynamic library '$1'") - -DIAGNOSTIC( 70, Error, cannotMatchOutputFileToEntryPoint, "the output path '$0' is not associated with any entry point; a '-o' option for a compiled kernel must follow the '-entry' option for its corresponding entry point") - -DIAGNOSTIC( 80, Error, duplicateOutputPathsForEntryPointAndTarget, "multiple output paths have been specified entry point '$0' on target '$1'") - -// -// 1xxxx - Lexical anaylsis -// - -DIAGNOSTIC(10000, Error, illegalCharacterPrint, "illegal character '$0'"); -DIAGNOSTIC(10000, Error, illegalCharacterHex, "illegal character (0x$0)"); -DIAGNOSTIC(10001, Error, illegalCharacterLiteral, "illegal character literal"); - -DIAGNOSTIC(10002, Warning, octalLiteral, "'0' prefix indicates octal literal") -DIAGNOSTIC(10003, Error, invalidDigitForBase, "invalid digit for base-$1 literal: '$0'") - -DIAGNOSTIC(10004, Error, endOfFileInLiteral, "end of file in literal"); -DIAGNOSTIC(10005, Error, newlineInLiteral, "newline in literal"); - -// -// 15xxx - Preprocessing -// - -// 150xx - conditionals -DIAGNOSTIC(15000, Error, endOfFileInPreprocessorConditional, "end of file encountered during preprocessor conditional") -DIAGNOSTIC(15001, Error, directiveWithoutIf, "'$0' directive without '#if'") -DIAGNOSTIC(15002, Error, directiveAfterElse , "'$0' directive without '#if'") - -DIAGNOSTIC(-1, Note, seeDirective, "see '$0' directive") - -// 151xx - directive parsing -DIAGNOSTIC(15100, Error, expectedPreprocessorDirectiveName, "expected preprocessor directive name") -DIAGNOSTIC(15101, Error, unknownPreprocessorDirective, "unknown preprocessor directive '$0'") -DIAGNOSTIC(15102, Error, expectedTokenInPreprocessorDirective, "expected '$0' in '$1' directive") -DIAGNOSTIC(15102, Error, expected2TokensInPreprocessorDirective, "expected '$0' or '$1' in '$2' directive") -DIAGNOSTIC(15103, Error, unexpectedTokensAfterDirective, "unexpected tokens following '$0' directive") - - -// 152xx - preprocessor expressions -DIAGNOSTIC(15200, Error, expectedTokenInPreprocessorExpression, "expected '$0' in preprocessor expression"); -DIAGNOSTIC(15201, Error, syntaxErrorInPreprocessorExpression, "syntax error in preprocessor expression"); -DIAGNOSTIC(15202, Error, divideByZeroInPreprocessorExpression, "division by zero in preprocessor expression"); -DIAGNOSTIC(15203, Error, expectedTokenInDefinedExpression, "expected '$0' in 'defined' expression"); -DIAGNOSTIC(15204, Warning, directiveExpectsExpression, "'$0' directive requires an expression"); -DIAGNOSTIC(15205, Warning, undefinedIdentifierInPreprocessorExpression, "undefined identifier '$0' in preprocessor expression will evaluate to zero") - -DIAGNOSTIC(-1, Note, seeOpeningToken, "see opening '$0'") - -// 153xx - #include -DIAGNOSTIC(15300, Error, includeFailed, "failed to find include file '$0'") -DIAGNOSTIC(15301, Error, importFailed, "failed to find imported file '$0'") -DIAGNOSTIC(-1, Error, noIncludeHandlerSpecified, "no `#include` handler was specified") -DIAGNOSTIC(15302, Error, noUniqueIdentity, "`#include` handler didn't generate a unique identity for file '$0'") - - -// 154xx - macro definition -DIAGNOSTIC(15400, Warning, macroRedefinition, "redefinition of macro '$0'") -DIAGNOSTIC(15401, Warning, macroNotDefined, "macro '$0' is not defined") -DIAGNOSTIC(15403, Error, expectedTokenInMacroParameters, "expected '$0' in macro parameters") - -// 155xx - macro expansion -DIAGNOSTIC(15500, Warning, expectedTokenInMacroArguments, "expected '$0' in macro invocation") -DIAGNOSTIC(15501, Error, wrongNumberOfArgumentsToMacro, "wrong number of arguments to macro (expected $0, got $1)") - -// 156xx - pragmas -DIAGNOSTIC(15600, Error, expectedPragmaDirectiveName, "expected a name after '#pragma'") -DIAGNOSTIC(15601, Warning, unknownPragmaDirectiveIgnored, "ignoring unknown directive '#pragma $0'") -DIAGNOSTIC(15602, Warning, pragmaOnceIgnored, "pragma once was ignored - this is typically because is not placed in an include") - -// 159xx - user-defined error/warning -DIAGNOSTIC(15900, Error, userDefinedError, "#error: $0") -DIAGNOSTIC(15901, Warning, userDefinedWarning, "#warning: $0") - -// -// 2xxxx - Parsing -// - -DIAGNOSTIC(20003, Error, unexpectedToken, "unexpected $0"); -DIAGNOSTIC(20001, Error, unexpectedTokenExpectedTokenType, "unexpected $0, expected $1"); -DIAGNOSTIC(20001, Error, unexpectedTokenExpectedTokenName, "unexpected $0, expected '$1'"); - -DIAGNOSTIC(0, Error, tokenNameExpectedButEOF, "\"$0\" expected but end of file encountered."); -DIAGNOSTIC(0, Error, tokenTypeExpectedButEOF, "$0 expected but end of file encountered."); -DIAGNOSTIC(20001, Error, tokenNameExpected, "\"$0\" expected"); -DIAGNOSTIC(20001, Error, tokenNameExpectedButEOF2, "\"$0\" expected but end of file encountered."); -DIAGNOSTIC(20001, Error, tokenTypeExpected, "$0 expected"); -DIAGNOSTIC(20001, Error, tokenTypeExpectedButEOF2, "$0 expected but end of file encountered."); -DIAGNOSTIC(20001, Error, typeNameExpectedBut, "unexpected $0, expected type name"); -DIAGNOSTIC(20001, Error, typeNameExpectedButEOF, "type name expected but end of file encountered."); -DIAGNOSTIC(20001, Error, unexpectedEOF, " Unexpected end of file."); -DIAGNOSTIC(20002, Error, syntaxError, "syntax error."); -DIAGNOSTIC(20004, Error, unexpectedTokenExpectedComponentDefinition, "unexpected token '$0', only component definitions are allowed in a shader scope.") -DIAGNOSTIC(20008, Error, invalidOperator, "invalid operator '$0'."); -DIAGNOSTIC(20011, Error, unexpectedColon, "unexpected ':'.") - -// -// 3xxxx - Semantic analysis -// - -DIAGNOSTIC(30002, Error, parameterAlreadyDefined, "parameter '$0' already defined.") -DIAGNOSTIC(30003, Error, breakOutsideLoop, "'break' must appear inside loop constructs.") -DIAGNOSTIC(30004, Error, continueOutsideLoop, "'continue' must appear inside loop constructs.") -DIAGNOSTIC(30005, Error, whilePredicateTypeError, "'while': expression must evaluate to int.") -DIAGNOSTIC(30006, Error, ifPredicateTypeError, "'if': expression must evaluate to int.") -DIAGNOSTIC(30006, Error, returnNeedsExpression, "'return' should have an expression.") -DIAGNOSTIC(30007, Error, componentReturnTypeMismatch, "expression type '$0' does not match component's type '$1'") -DIAGNOSTIC(30007, Error, functionReturnTypeMismatch, "expression type '$0' does not match function's return type '$1'") -DIAGNOSTIC(30008, Error, variableNameAlreadyDefined, "variable $0 already defined.") -DIAGNOSTIC(30009, Error, invalidTypeVoid, "invalid type 'void'.") -DIAGNOSTIC(30010, Error, whilePredicateTypeError2, "'while': expression must evaluate to int.") -DIAGNOSTIC(30011, Error, assignNonLValue, "left of '=' is not an l-value.") -DIAGNOSTIC(30012, Error, noApplicationUnaryOperator, "no overload found for operator $0 ($1).") -DIAGNOSTIC(30012, Error, noOverloadFoundForBinOperatorOnTypes, "no overload found for operator $0 ($1, $2).") -DIAGNOSTIC(30013, Error, subscriptNonArray, "no subscript operation found for type '$0'") -DIAGNOSTIC(30014, Error, subscriptIndexNonInteger, "index expression must evaluate to int.") -DIAGNOSTIC(30015, Error, undefinedIdentifier, "'$0': undefined identifier.") -DIAGNOSTIC(30015, Error, undefinedIdentifier2, "undefined identifier '$0'.") -DIAGNOSTIC(30017, Error, componentNotAccessibleFromShader, "component '$0' is not accessible from shader '$1'.") -DIAGNOSTIC(30019, Error, typeMismatch, "expected an expression of type '$0', got '$1'") -DIAGNOSTIC(30020, Error, importOperatorReturnTypeMismatch, "import operator should return '$1', but the expression has type '$0''. do you forget 'project'?") -DIAGNOSTIC(30021, Error, noApplicationFunction, "$0: no overload takes arguments ($1)") -DIAGNOSTIC(30022, Error, invalidTypeCast, "invalid type cast between \"$0\" and \"$1\".") -DIAGNOSTIC(30023, Error, typeHasNoPublicMemberOfName, "\"$0\" does not have public member \"$1\"."); -DIAGNOSTIC(30025, Error, invalidArraySize, "array size must be larger than zero.") -DIAGNOSTIC(30026, Error, returnInComponentMustComeLast, "'return' can only appear as the last statement in component definition.") -DIAGNOSTIC(30027, Error, noMemberOfNameInType, "'$0' is not a member of '$1'."); -DIAGNOSTIC(30028, Error, forPredicateTypeError, "'for': predicate expression must evaluate to bool.") -DIAGNOSTIC(30030, Error, projectionOutsideImportOperator, "'project': invalid use outside import operator.") -DIAGNOSTIC(30031, Error, projectTypeMismatch, "'project': expression must evaluate to record type '$0'.") -DIAGNOSTIC(30033, Error, invalidTypeForLocalVariable, "cannot declare a local variable of this type.") -DIAGNOSTIC(30035, Error, componentOverloadTypeMismatch, "'$0': type of overloaded component mismatches previous definition.") -DIAGNOSTIC(30041, Error, bitOperationNonIntegral, "bit operation: operand must be integral type.") -DIAGNOSTIC(30047, Error, argumentExpectedLValue, "argument passed to parameter '$0' must be l-value.") -DIAGNOSTIC(30048, Note, implicitCastUsedAsLValue, "argument was implicitly cast from '$0' to '$1', and Slang does not support using an implicit cast as an l-value") -DIAGNOSTIC(30049, Note, thisIsImmutableByDefault, "a 'this' parameter is an immutable parameter by default in Slang; apply the `[mutating]` attribute to the function declaration to opt in to a mutable `this`") - -DIAGNOSTIC(30051, Error, invalidValueForArgument, "invalid value for argument '$0'") -DIAGNOSTIC(30052, Error, invalidSwizzleExpr, "invalid swizzle pattern '$0' on type '$1'") - -DIAGNOSTIC(30100, Error, staticRefToNonStaticMember, "type '$0' cannot be used to refer to non-static member '$1'") - -DIAGNOSTIC(30201, Error, functionRedefinition, "function '$0' already has a body") -DIAGNOSTIC(30202, Error, functionRedeclarationWithDifferentReturnType, "function '$0' declared to return '$1' was previously declared to return '$2'") - - -DIAGNOSTIC(33070, Error, expectedFunction, "expected a function, got '$0'") -DIAGNOSTIC(33071, Error, expectedAStringLiteral, "expected a string literal") - -DIAGNOSTIC( -1, Note, noteExplicitConversionPossible, "explicit conversion from '$0' to '$1' is possible") -DIAGNOSTIC(30080, Error, ambiguousConversion, "more than one implicit conversion exists from '$0' to '$1'") - - - - -// Attributes -DIAGNOSTIC(31000, Error, unknownAttributeName, "unknown attribute '$0'") -DIAGNOSTIC(31001, Error, attributeArgumentCountMismatch, "attribute '$0' expects $1 arguments ($2 provided)") -DIAGNOSTIC(31002, Error, attributeNotApplicable, "attribute '$0' is not valid here") - -DIAGNOSTIC(31003, Error, badlyDefinedPatchConstantFunc, "hull shader '$0' has has badly defined 'patchconstantfunc' attribute."); - -DIAGNOSTIC(31004, Error, expectedSingleIntArg, "attribute '$0' expects a single int argument") -DIAGNOSTIC(31005, Error, expectedSingleStringArg, "attribute '$0' expects a single string argument") - -DIAGNOSTIC(31006, Error, attributeFunctionNotFound, "Could not find function '$0' for attribute'$1'") - -DIAGNOSTIC(31100, Error, unknownStageName, "unknown stage name '$0'") -DIAGNOSTIC(31101, Error, unknownImageFormatName, "unknown image format '$0'") - -DIAGNOSTIC(31120, Error, invalidAttributeTarget, "invalid syntax target for user defined attribute") - -// Enums - -DIAGNOSTIC(32000, Error, invalidEnumTagType, "invalid tag type for 'enum': '$0'") -DIAGNOSTIC(32001, Error, enumTypeAlreadyHasTagType, "'enum' type has already declared a tag type") -DIAGNOSTIC(32002, Note, seePreviousTagType, "see previous tag type declaration") -DIAGNOSTIC(32003, Error, unexpectedEnumTagExpr, "unexpected form for 'enum' tag value expression") - - - -// 303xx: interfaces and associated types -DIAGNOSTIC(30300, Error, assocTypeInInterfaceOnly, "'associatedtype' can only be defined in an 'interface'.") -DIAGNOSTIC(30301, Error, globalGenParamInGlobalScopeOnly, "'type_param' can only be defined global scope.") -// TODO: need to assign numbers to all these extra diagnostics... -DIAGNOSTIC(39999, Fatal, cyclicReference, "cyclic reference '$0'.") -DIAGNOSTIC(39999, Fatal, localVariableUsedBeforeDeclared, "local variable '$0' is being used before its declaration.") - -// 304xx: generics -DIAGNOSTIC(30400, Error, genericTypeNeedsArgs, "generic type '$0' used without argument") - -// 305xx: initializer lists -DIAGNOSTIC(30500, Error, tooManyInitializers, "too many initializers (expected $0, got $1)"); -DIAGNOSTIC(30501, Error, cannotUseInitializerListForArrayOfUnknownSize, "cannot use initializer list for array of statically unknown size '$0'"); -DIAGNOSTIC(30502, Error, cannotUseInitializerListForVectorOfUnknownSize, "cannot use initializer list for vector of statically unknown size '$0'"); -DIAGNOSTIC(30503, Error, cannotUseInitializerListForMatrixOfUnknownSize, "cannot use initializer list for matrix of statically unknown size '$0' rows"); -DIAGNOSTIC(30504, Error, cannotUseInitializerListForType, "cannot use initializer list for type '$0'") - -// 306xx: variables -DIAGNOSTIC(30600, Error, varWithoutTypeMustHaveInitializer, "a variable declaration without an initial-value expression must be given an explicit type"); - -// 307xx: parameters -DIAGNOSTIC(30700, Error, outputParameterCannotHaveDefaultValue, "an 'out' or 'inout' parameter cannot have a default-value expression"); - -DIAGNOSTIC(39999, Error, expectedIntegerConstantWrongType, "expected integer constant (found: '$0')") -DIAGNOSTIC(39999, Error, expectedIntegerConstantNotConstant, "expression does not evaluate to a compile-time constant") -DIAGNOSTIC(39999, Error, expectedIntegerConstantNotLiteral, "could not extract value from integer constant") - -DIAGNOSTIC(39999, Error, noApplicableOverloadForNameWithArgs, "no overload for '$0' applicable to arguments of type $1") -DIAGNOSTIC(39999, Error, noApplicableWithArgs, "no overload applicable to arguments of type $0") - -DIAGNOSTIC(39999, Error, ambiguousOverloadForNameWithArgs, "ambiguous call to '$0' operation with arguments of type $1") -DIAGNOSTIC(39999, Error, ambiguousOverloadWithArgs, "ambiguous call to overloaded operation with arguments of type $0") - -DIAGNOSTIC(39999, Note, overloadCandidate, "candidate: $0") -DIAGNOSTIC(39999, Note, moreOverloadCandidates, "$0 more overload candidates") - -DIAGNOSTIC(39999, Error, caseOutsideSwitch, "'case' not allowed outside of a 'switch' statement") -DIAGNOSTIC(39999, Error, defaultOutsideSwitch, "'default' not allowed outside of a 'switch' statement") - -DIAGNOSTIC(39999, Error, expectedAGeneric, "expected a generic when using '<...>' (found: '$0')") - -DIAGNOSTIC(39999, Error, genericArgumentInferenceFailed, "could not specialize generic for arguments of type $0") -DIAGNOSTIC(39999, Note, genericSignatureTried, "see declaration of $0") - -DIAGNOSTIC(39999, Error, expectedAnInterfaceGot, "expected an interface, got '$0'") - -DIAGNOSTIC(39999, Error, ambiguousReference, "amiguous reference to '$0'"); - -DIAGNOSTIC(39999, Error, declarationDidntDeclareAnything, "declaration does not declare anything"); - - -DIAGNOSTIC(39999, Error, expectedPrefixOperator, "function called as prefix operator was not declared `__prefix`") -DIAGNOSTIC(39999, Error, expectedPostfixOperator, "function called as postfix operator was not declared `__postfix`") - -DIAGNOSTIC(39999, Error, notEnoughArguments, "not enough arguments to call (got $0, expected $1)") -DIAGNOSTIC(39999, Error, tooManyArguments, "too many arguments to call (got $0, expected $1)") - -DIAGNOSTIC(39999, Error, invalidIntegerLiteralSuffix, "invalid suffix '$0' on integer literal") -DIAGNOSTIC(39999, Error, invalidFloatingPointLiteralSuffix, "invalid suffix '$0' on floating-point literal") - - - -DIAGNOSTIC(38000, Error, entryPointFunctionNotFound, "no function found matching entry point name '$0'") -DIAGNOSTIC(38001, Error, ambiguousEntryPoint, "more than one function matches entry point name '$0'") -DIAGNOSTIC(38002, Note, entryPointCandidate, "see candidate declaration for entry point '$0'") -DIAGNOSTIC(38003, Error, entryPointSymbolNotAFunction, "entry point '$0' must be declared as a function") - -DIAGNOSTIC(38004, Error, entryPointTypeParameterNotFound, "no type found matching entry-point type parameter name '$0'") -DIAGNOSTIC(38005, Error, globalGenericArgumentNotAType, "argument for global generic parameter '$0' must be a type") - -DIAGNOSTIC(38006, Warning, specifiedStageDoesntMatchAttribute, "entry point '$0' being compiled for the '$1' stage has a '[shader(...)]' attribute that specifies the '$2' stage") -DIAGNOSTIC(38007, Error, entryPointHasNoStage, "no stage specified for entry point '$0'; use either a '[shader(\"name\")]' function attribute or the '-stage ' command-line option to specify a stage") - -DIAGNOSTIC(38100, Error, typeDoesntImplementInterfaceRequirement, "type '$0' does not provide required interface member '$1'") -DIAGNOSTIC(38101, Error, thisExpressionOutsideOfTypeDecl, "'this' expression can only be used in members of an aggregate type") -DIAGNOSTIC(38102, Error, initializerNotInsideType, "an 'init' declaration is only allowed inside a type or 'extension' declaration") -DIAGNOSTIC(38102, Error, accessorMustBeInsideSubscriptOrProperty, "an accessor declaration is only allowed inside a subscript or property declaration") - -DIAGNOSTIC(38020, Error, mismatchEntryPointTypeArgument, "expecting $0 entry-point type arguments, provided $1.") -DIAGNOSTIC(38021, Error, typeArgumentDoesNotConformToInterface, "type argument `$1` for generic parameter `$0` does not conform to interface `$2`.") - -DIAGNOSTIC(38022, Error, cannotSpecializeGlobalGenericToItself, "the global type parameter '$0' cannot be specialized to itself") -DIAGNOSTIC(38023, Error, cannotSpecializeGlobalGenericToAnotherGenericParam, "the global type parameter '$0' cannot be specialized using another global type parameter ('$1')") - -DIAGNOSTIC(38024, Error, invalidDispatchThreadIDType, "parameter with SV_DispatchThreadID must be either scalar or vector (1 to 3) of uint/int but is $0"); - -DIAGNOSTIC(-1, Note, noteWhenCompilingEntryPoint, "when compiling entry point '$0'") - -DIAGNOSTIC(38025, Error, mismatchGlobalGenericArguments, "expected $0 global generic arguments ($1 provided)") -DIAGNOSTIC(38026, Error, globalTypeArgumentDoesNotConformToInterface, "type argument `$1` for global generic parameter `$0` does not conform to interface `$2`.") - -DIAGNOSTIC(38027, Error, mismatchExistentialSlotArgCount, "expected $0 existential slot arguments ($1 provided)") -DIAGNOSTIC(38028, Error, existentialSlotArgNotAType, "existential slot argument $0 was not a type") -DIAGNOSTIC(38029, Error, existentialSlotArgDoesNotConform, "existential slot argument $0 does not conform to the required interface '$1'") - -DIAGNOSTIC(38200, Error, recursiveModuleImport, "module `$0` recursively imports itself") -DIAGNOSTIC(39999, Fatal, errorInImportedModule, "error in imported module, compilation ceased.") - -// 39xxx - Type layout and parameter binding. - -DIAGNOSTIC(39000, Error, conflictingExplicitBindingsForParameter, "conflicting explicit bindings for parameter '$0'") -DIAGNOSTIC(39001, Warning, parameterBindingsOverlap, "explicit binding for parameter '$0' overlaps with parameter '$1'") - - -DIAGNOSTIC(39002, Error, shaderParameterDeclarationsDontMatch, "declarations of shader parameter '$0' in different translation units don't match") - -DIAGNOSTIC(39003, Note, shaderParameterTypeMismatch, "type is declared as '$0' in one translation unit, and '$0' in another") -DIAGNOSTIC(39004, Note, fieldTypeMisMatch, "type of field '$0' is declared as '$1' in one translation unit, and '$2' in another") -DIAGNOSTIC(39005, Note, fieldDeclarationsDontMatch, "type '$0' is declared with different fields in each translation unit") -DIAGNOSTIC(39006, Note, usedInDeclarationOf, "used in declaration of '$0'") - -DIAGNOSTIC(39007, Error, unknownRegisterClass, "unknown register class: '$0'") -DIAGNOSTIC(39008, Error, expectedARegisterIndex, "expected a register index after '$0'") -DIAGNOSTIC(39009, Error, expectedSpace, "expected 'space', got '$0'") -DIAGNOSTIC(39010, Error, expectedSpaceIndex, "expected a register space index after 'space'") -DIAGNOSTIC(39011, Error, componentMaskNotSupported, "explicit register component masks are not yet supported in Slang") -DIAGNOSTIC(39012, Error, packOffsetNotSupported, "explicit 'packoffset' bindings are not yet supported in Slang") -DIAGNOSTIC(39013, Warning, registerModifierButNoVulkanLayout, "shader parameter '$0' has a 'register' specified for D3D, but no '[[vk::binding(...)]]` specified for Vulkan") -DIAGNOSTIC(39014, Error, unexpectedSpecifierAfterSpace, "unexpected specifier after register space: '$0'") -DIAGNOSTIC(39015, Error, wholeSpaceParameterRequiresZeroBinding, "shader parameter '$0' consumes whole descriptor sets, so the binding must be in the form '[[vk::binding(0, ...)]]'; the non-zero binding '$1' is not allowed") - -DIAGNOSTIC(39013, Error, dontExpectOutParametersForStage, "the '$0' stage does not support `out` or `inout` entry point parameters") -DIAGNOSTIC(39014, Error, dontExpectInParametersForStage, "the '$0' stage does not support `in` entry point parameters") - -DIAGNOSTIC(39016, Error, globalUniformsNotSupported, "'$0' is implicitly a global uniform shader parameter, which is currently unsupported by Slang. If a uniform parameter is intended, use a constant buffer or parameter block. If a global is intended, use the 'static' modifier.") - -DIAGNOSTIC(39017, Error, tooManyShaderRecordConstantBuffers, "can have at most one 'shader record' attributed constant buffer; found $0.") - -// -// 4xxxx - IL code generation. -// -DIAGNOSTIC(40001, Error, bindingAlreadyOccupiedByComponent, "resource binding location '$0' is already occupied by component '$1'.") -DIAGNOSTIC(40002, Error, invalidBindingValue, "binding location '$0' is out of valid range.") -DIAGNOSTIC(40003, Error, bindingExceedsLimit, "binding location '$0' assigned to component '$1' exceeds maximum limit.") -DIAGNOSTIC(40004, Error, bindingAlreadyOccupiedByModule, "DescriptorSet ID '$0' is already occupied by module instance '$1'.") -DIAGNOSTIC(40005, Error, topLevelModuleUsedWithoutSpecifyingBinding, "top level module '$0' is being used without specifying binding location. Use [Binding: \"index\"] attribute to provide a binding location.") - - -DIAGNOSTIC(49999, Error, unknownSystemValueSemantic, "unknown system-value semantic '$0'") - -DIAGNOSTIC(40006, Error, needCompileTimeConstant, "expected a compile-time constant") - -DIAGNOSTIC(40007, Internal, irValidationFailed, "IR validation failed: $0") - -DIAGNOSTIC(40008, Error, invalidLValueForRefParameter, "the form of this l-value argument is not valid for a `ref` parameter") - -// 41000 - IR-level validation issues - -DIAGNOSTIC(41000, Warning, unreachableCode, "unreachable code detected") - -DIAGNOSTIC(41010, Warning, missingReturn, "control flow may reach end of non-'void' function") - -// -// 5xxxx - Target code generation. -// - -DIAGNOSTIC(50010, Internal, missingExistentialBindingsForParameter, "missing argument for existential parameter slot"); - -DIAGNOSTIC(50020, Error, invalidTessCoordType, "TessCoord must have vec2 or vec3 type.") -DIAGNOSTIC(50020, Error, invalidFragCoordType, "FragCoord must be a vec4.") -DIAGNOSTIC(50020, Error, invalidInvocationIdType, "InvocationId must have int type.") -DIAGNOSTIC(50020, Error, invalidThreadIdType, "ThreadId must have int type.") -DIAGNOSTIC(50020, Error, invalidPrimitiveIdType, "PrimitiveId must have int type.") -DIAGNOSTIC(50020, Error, invalidPatchVertexCountType, "PatchVertexCount must have int type.") -DIAGNOSTIC(50022, Error, worldIsNotDefined, "world '$0' is not defined."); -DIAGNOSTIC(50023, Error, stageShouldProvideWorldAttribute, "'$0' should provide 'World' attribute."); -DIAGNOSTIC(50040, Error, componentHasInvalidTypeForPositionOutput, "'$0': component used as 'loc' output must be of vec4 type.") -DIAGNOSTIC(50041, Error, componentNotDefined, "'$0': component not defined.") - -DIAGNOSTIC(50052, Error, domainShaderRequiresControlPointCount, "'DomainShader' requires attribute 'ControlPointCount'."); -DIAGNOSTIC(50052, Error, hullShaderRequiresControlPointCount, "'HullShader' requires attribute 'ControlPointCount'.") -DIAGNOSTIC(50052, Error, hullShaderRequiresControlPointWorld, "'HullShader' requires attribute 'ControlPointWorld'."); -DIAGNOSTIC(50052, Error, hullShaderRequiresCornerPointWorld, "'HullShader' requires attribute 'CornerPointWorld'."); -DIAGNOSTIC(50052, Error, hullShaderRequiresDomain, "'HullShader' requires attribute 'Domain'."); -DIAGNOSTIC(50052, Error, hullShaderRequiresInputControlPointCount, "'HullShader' requires attribute 'InputControlPointCount'.") -DIAGNOSTIC(50052, Error, hullShaderRequiresOutputTopology, "'HullShader' requires attribute 'OutputTopology'.") -DIAGNOSTIC(50052, Error, hullShaderRequiresPartitioning, "'HullShader' requires attribute 'Partitioning'.") -DIAGNOSTIC(50052, Error, hullShaderRequiresPatchWorld, "'HullShader' requires attribute 'PatchWorld'."); -DIAGNOSTIC(50052, Error, hullShaderRequiresTessLevelInner, "'HullShader' requires attribute 'TessLevelInner'.") -DIAGNOSTIC(50052, Error, hullShaderRequiresTessLevelOuter, "'HullShader' requires attribute 'TessLevelOuter'.") - -DIAGNOSTIC(50053, Error, invalidTessellationDomian, "'Domain' should be either 'triangles' or 'quads'."); -DIAGNOSTIC(50053, Error, invalidTessellationOutputTopology, "'OutputTopology' must be one of: 'point', 'line', 'triangle_cw', or 'triangle_ccw'."); -DIAGNOSTIC(50053, Error, invalidTessellationPartitioning, "'Partitioning' must be one of: 'integer', 'pow2', 'fractional_even', or 'fractional_odd'.") -DIAGNOSTIC(50053, Error, invalidTessellationDomain, "'Domain' should be either 'triangles' or 'quads'.") - -DIAGNOSTIC(50082, Error, importingFromPackedBufferUnsupported, "importing type '$0' from PackedBuffer is not supported by the GLSL backend.") -DIAGNOSTIC(51090, Error, cannotGenerateCodeForExternComponentType, "cannot generate code for extern component type '$0'.") -DIAGNOSTIC(51091, Error, typeCannotBePlacedInATexture, "type '$0' cannot be placed in a texture.") -DIAGNOSTIC(51092, Error, stageDoesntHaveInputWorld, "'$0' doesn't appear to have any input world"); - -DIAGNOSTIC(52000, Error, multiLevelBreakUnsupported, "control flow appears to require multi-level `break`, which Slang does not yet support"); - -DIAGNOSTIC(52001, Warning, dxilNotFound, "dxil shared library not found, so 'dxc' output cannot be signed! Shader code will not be runnable in non-development environments."); - -// 99999 - Internal compiler errors, and not-yet-classified diagnostics. - -DIAGNOSTIC(99999, Internal, unimplemented, "unimplemented feature in Slang compiler: $0") -DIAGNOSTIC(99999, Internal, unexpected, "unexpected condition encountered in Slang compiler: $0") -DIAGNOSTIC(99999, Internal, internalCompilerError, "Slang internal compiler error") -DIAGNOSTIC(99999, Error, compilationAborted, "Slang compilation aborted due to internal error"); -DIAGNOSTIC(99999, Error, compilationAbortedDueToException, "Slang compilation aborted due to an exception of $0: $1"); -DIAGNOSTIC(99999, Note, noteLocationOfInternalError, "the Slang compiler threw an exception while working on code near this location"); -DIAGNOSTIC(99999, Internal, serialDebugVerificationFailed, "Verification of serial debug information failed."); - -#undef DIAGNOSTIC diff --git a/source/slang/diagnostics.cpp b/source/slang/diagnostics.cpp deleted file mode 100644 index a947e7369..000000000 --- a/source/slang/diagnostics.cpp +++ /dev/null @@ -1,350 +0,0 @@ -// diagnostics.cpp -#include "diagnostics.h" - -#include "compiler.h" -#include "name.h" -#include "syntax.h" - -#include - -#ifdef _WIN32 -#define WIN32_LEAN_AND_MEAN -#define NOMINMAX -#include -#undef WIN32_LEAN_AND_MEAN -#undef NOMINMAX -#include -#endif - -namespace Slang { - -void printDiagnosticArg(StringBuilder& sb, char const* str) -{ - sb << str; -} - -void printDiagnosticArg(StringBuilder& sb, int32_t val) -{ - sb << val; -} - -void printDiagnosticArg(StringBuilder& sb, uint32_t val) -{ - sb << val; -} - -void printDiagnosticArg(StringBuilder& sb, int64_t val) -{ - sb << val; -} - -void printDiagnosticArg(StringBuilder& sb, uint64_t val) -{ - sb << val; -} - -void printDiagnosticArg(StringBuilder& sb, Slang::String const& str) -{ - sb << str; -} - -void printDiagnosticArg(StringBuilder& sb, Slang::UnownedStringSlice const& str) -{ - sb.append(str); -} - - -void printDiagnosticArg(StringBuilder& sb, Name* name) -{ - sb << getText(name); -} - - -void printDiagnosticArg(StringBuilder& sb, Decl* decl) -{ - sb << getText(decl->getName()); -} - -void printDiagnosticArg(StringBuilder& sb, Type* type) -{ - sb << type->ToString(); -} - -void printDiagnosticArg(StringBuilder& sb, Val* val) -{ - sb << val->ToString(); -} - -void printDiagnosticArg(StringBuilder& sb, TypeExp const& type) -{ - sb << type.type->ToString(); -} - -void printDiagnosticArg(StringBuilder& sb, QualType const& type) -{ - if (type.type) - sb << type.type->ToString(); - else - sb << ""; -} - -void printDiagnosticArg(StringBuilder& sb, TokenType tokenType) -{ - sb << TokenTypeToString(tokenType); -} - -void printDiagnosticArg(StringBuilder& sb, Token const& token) -{ - sb << token.Content; -} - -void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) -{ - switch( val ) - { - default: - sb << ""; - break; - -#define CASE(TAG, STR) case CodeGenTarget::TAG: sb << STR; break - CASE(GLSL, "glsl"); - CASE(HLSL, "hlsl"); - CASE(SPIRV, "spirv"); - CASE(SPIRVAssembly, "spriv-assembly"); - CASE(DXBytecode, "dxbc"); - CASE(DXBytecodeAssembly, "dxbc-assembly"); - CASE(DXIL, "dxil"); - CASE(DXILAssembly, "dxil-assembly"); -#undef CASE - } -} - -void printDiagnosticArg(StringBuilder& sb, Stage val) -{ - sb << getStageName(val); -} - -void printDiagnosticArg(StringBuilder& sb, ProfileVersion val) -{ - sb << Profile(val).getName(); -} - - -SourceLoc const& getDiagnosticPos(SyntaxNode const* syntax) -{ - return syntax->loc; -} - -SourceLoc const& getDiagnosticPos(Token const& token) -{ - return token.loc; -} - -SourceLoc const& getDiagnosticPos(TypeExp const& typeExp) -{ - return typeExp.exp->loc; -} - -SourceLoc const& getDiagnosticPos(IRInst* inst) -{ - return inst->sourceLoc; -} - - -// Take the format string for a diagnostic message, along with its arguments, and turn it into a -static void formatDiagnosticMessage(StringBuilder& sb, char const* format, int argCount, DiagnosticArg const* const* args) -{ - char const* spanBegin = format; - for(;;) - { - char const* spanEnd = spanBegin; - while (int c = *spanEnd) - { - if (c == '$') - break; - spanEnd++; - } - - sb.Append(spanBegin, int(spanEnd - spanBegin)); - if (!*spanEnd) - return; - - SLANG_ASSERT(*spanEnd == '$'); - spanEnd++; - int d = *spanEnd++; - switch (d) - { - // A double dollar sign `$$` is used to emit a single `$` - case '$': - sb.Append('$'); - break; - - // A single digit means to emit the corresponding argument. - // TODO: support more than 10 arguments, and add options - // to control formatting, etc. - case '0': case '1': case '2': case '3': case '4': - case '5': case '6': case '7': case '8': case '9': - { - int index = d - '0'; - if (index >= argCount) - { - // TODO(tfoley): figure out what a good policy will be for "panic" situations like this - throw InvalidOperationException("too few arguments for diagnostic message"); - } - else - { - DiagnosticArg const* arg = args[index]; - arg->printFunc(sb, arg->data); - } - } - break; - - default: - throw InvalidOperationException("invalid diagnostic message format"); - break; - } - - spanBegin = spanEnd; - } -} - -static void formatDiagnostic(const HumaneSourceLoc& humaneLoc, Diagnostic const& diagnostic, StringBuilder& outBuilder) -{ - outBuilder << humaneLoc.pathInfo.foundPath; - outBuilder << "("; - outBuilder << Int32(humaneLoc.line); - outBuilder << "): "; - - outBuilder << getSeverityName(diagnostic.severity); - - if (diagnostic.ErrorID >= 0) - { - outBuilder << " "; - outBuilder << diagnostic.ErrorID; - } - - outBuilder << ": "; - outBuilder << diagnostic.Message; - outBuilder << "\n"; -} - -static void formatDiagnostic( - DiagnosticSink* sink, - Diagnostic const& diagnostic, - StringBuilder& sb) -{ - auto sourceManager = sink->sourceManager; - - SourceView* sourceView = nullptr; - HumaneSourceLoc humaneLoc; - const auto sourceLoc = diagnostic.loc; - { - sourceView = sourceManager->findSourceViewRecursively(sourceLoc); - if (sourceView) - { - humaneLoc = sourceView->getHumaneLoc(sourceLoc); - } - formatDiagnostic(humaneLoc, diagnostic, sb); - } - - if (sourceView && (sink->flags & DiagnosticSink::Flag::VerbosePath)) - { - auto actualHumaneLoc = sourceView->getHumaneLoc(diagnostic.loc, SourceLocType::Actual); - - // Look up the path verbosely (will get the canonical path if necessary) - actualHumaneLoc.pathInfo.foundPath = sourceView->getSourceFile()->calcVerbosePath(); - - // Only output if it's actually different - if (actualHumaneLoc.pathInfo.foundPath != humaneLoc.pathInfo.foundPath || - actualHumaneLoc.line != humaneLoc.line || - actualHumaneLoc.column != humaneLoc.column) - { - formatDiagnostic(actualHumaneLoc, diagnostic, sb); - } - } -} - -void DiagnosticSink::diagnoseImpl(SourceLoc const& pos, DiagnosticInfo const& info, int argCount, DiagnosticArg const* const* args) -{ - StringBuilder sb; - formatDiagnosticMessage(sb, info.messageFormat, argCount, args); - - Diagnostic diagnostic; - diagnostic.ErrorID = info.id; - diagnostic.Message = sb.ProduceString(); - diagnostic.loc = pos; - diagnostic.severity = info.severity; - - if (diagnostic.severity >= Severity::Error) - { - errorCount++; - } - - // Did the client supply a callback for us to use? - if( writer ) - { - // If so, pass the error string along to them - StringBuilder messageBuilder; - formatDiagnostic(this, diagnostic, messageBuilder); - - writer->write(messageBuilder.getBuffer(), messageBuilder.getLength()); - } - else - { - // If the user doesn't have a callback, then just - // collect our diagnostic messages into a buffer - formatDiagnostic(this, diagnostic, outputBuffer); - } - - if (diagnostic.severity >= Severity::Fatal) - { - // TODO: figure out a better policy for aborting compilation - throw AbortCompilationException(); - } -} - -void DiagnosticSink::diagnoseRaw( - Severity severity, - char const* message) -{ - return diagnoseRaw(severity, UnownedStringSlice(message)); -} - -void DiagnosticSink::diagnoseRaw( - Severity severity, - const UnownedStringSlice& message) -{ - if (severity >= Severity::Error) - { - errorCount++; - } - - // Did the client supply a callback for us to use? - if(writer) - { - // If so, pass the error string along to them - writer->write(message.begin(), message.size()); - } - else - { - // If the user doesn't have a callback, then just - // collect our diagnostic messages into a buffer - outputBuffer.append(message); - } - - if (severity >= Severity::Fatal) - { - // TODO: figure out a better policy for aborting compilation - throw InvalidOperationException(); - } -} - - -namespace Diagnostics -{ -#define DIAGNOSTIC(id, severity, name, messageFormat) const DiagnosticInfo name = { id, Severity::severity, messageFormat }; -#include "diagnostic-defs.h" -} - - -} // namespace Slang diff --git a/source/slang/diagnostics.h b/source/slang/diagnostics.h deleted file mode 100644 index 8e5ba809b..000000000 --- a/source/slang/diagnostics.h +++ /dev/null @@ -1,280 +0,0 @@ -#ifndef RASTER_RENDERER_COMPILE_ERROR_H -#define RASTER_RENDERER_COMPILE_ERROR_H - -#include "../core/basic.h" -#include "../core/slang-writer.h" - -#include "source-loc.h" -#include "token.h" - -#include "../../slang.h" - -namespace Slang -{ - enum class Severity - { - Note, - Warning, - Error, - Fatal, - Internal, - }; - - // TODO(tfoley): move this into a source file... - inline const char* getSeverityName(Severity severity) - { - switch (severity) - { - case Severity::Note: return "note"; - case Severity::Warning: return "warning"; - case Severity::Error: return "error"; - case Severity::Fatal: return "fatal error"; - case Severity::Internal: return "internal error"; - default: return "unknown error"; - } - } - - // A structure to be used in static data describing different - // diagnostic messages. - struct DiagnosticInfo - { - int id; - Severity severity; - char const* messageFormat; - }; - - class Diagnostic - { - public: - String Message; - SourceLoc loc; - int ErrorID; - Severity severity; - - Diagnostic() - { - ErrorID = -1; - } - Diagnostic( - const String & msg, - int id, - const SourceLoc & pos, - Severity severity) - : severity(severity) - { - Message = msg; - ErrorID = id; - loc = pos; - } - }; - - class Name; - class Decl; - struct QualType; - class Type; - struct TypeExp; - class Val; - - enum class CodeGenTarget; - enum class Stage : SlangStage; - enum class ProfileVersion; - - void printDiagnosticArg(StringBuilder& sb, char const* str); - - void printDiagnosticArg(StringBuilder& sb, int32_t val); - void printDiagnosticArg(StringBuilder& sb, uint32_t val); - - void printDiagnosticArg(StringBuilder& sb, int64_t val); - void printDiagnosticArg(StringBuilder& sb, uint64_t val); - - void printDiagnosticArg(StringBuilder& sb, Slang::String const& str); - void printDiagnosticArg(StringBuilder& sb, Slang::UnownedStringSlice const& str); - void printDiagnosticArg(StringBuilder& sb, Name* name); - void printDiagnosticArg(StringBuilder& sb, Decl* decl); - void printDiagnosticArg(StringBuilder& sb, Type* type); - void printDiagnosticArg(StringBuilder& sb, TypeExp const& type); - void printDiagnosticArg(StringBuilder& sb, QualType const& type); - void printDiagnosticArg(StringBuilder& sb, TokenType tokenType); - void printDiagnosticArg(StringBuilder& sb, Token const& token); - void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val); - void printDiagnosticArg(StringBuilder& sb, Stage val); - void printDiagnosticArg(StringBuilder& sb, ProfileVersion val); - void printDiagnosticArg(StringBuilder& sb, Val* val); - - template - void printDiagnosticArg(StringBuilder& sb, RefPtr ptr) - { - printDiagnosticArg(sb, ptr.Ptr()); - } - - inline SourceLoc const& getDiagnosticPos(SourceLoc const& pos) { return pos; } - - class SyntaxNode; - SourceLoc const& getDiagnosticPos(SyntaxNode const* syntax); - SourceLoc const& getDiagnosticPos(Token const& token); - SourceLoc const& getDiagnosticPos(TypeExp const& typeExp); - - struct IRInst; - SourceLoc const& getDiagnosticPos(IRInst* inst); - - template - SourceLoc getDiagnosticPos(RefPtr const& ptr) - { - return getDiagnosticPos(ptr.Ptr()); - } - - struct DiagnosticArg - { - void* data; - void (*printFunc)(StringBuilder&, void*); - - template - struct Helper - { - static void printFunc(StringBuilder& sb, void* data) { printDiagnosticArg(sb, *(T*)data); } - }; - - template - DiagnosticArg(T const& arg) - : data((void*)&arg) - , printFunc(&Helper::printFunc) - {} - }; - - class DiagnosticSink - { - public: - struct Flag - { - enum Enum: uint32_t - { - VerbosePath = 0x1, ///< Will display a more verbose path (if available) - such as a canonical or absolute path - }; - }; - typedef uint32_t Flags; - - StringBuilder outputBuffer; -// List diagnostics; - int errorCount = 0; - int internalErrorLocsNoted = 0; - - ISlangWriter* writer = nullptr; - Flags flags = 0; - - // The source manager to use when mapping source locations to file+line info - SourceManager* sourceManager = nullptr; - -/* - void Error(int id, const String & msg, const SourceLoc & pos) - { - diagnostics.Add(Diagnostic(msg, id, pos, Severity::Error)); - errorCount++; - } - - void Warning(int id, const String & msg, const SourceLoc & pos) - { - diagnostics.Add(Diagnostic(msg, id, pos, Severity::Warning)); - } -*/ - int GetErrorCount() { return errorCount; } - - void diagnoseDispatch(SourceLoc const& pos, DiagnosticInfo const& info) - { - diagnoseImpl(pos, info, 0, nullptr); - } - - void diagnoseDispatch(SourceLoc const& pos, DiagnosticInfo const& info, DiagnosticArg const& arg0) - { - DiagnosticArg const* args[] = { &arg0 }; - diagnoseImpl(pos, info, 1, args); - } - - void diagnoseDispatch(SourceLoc const& pos, DiagnosticInfo const& info, DiagnosticArg const& arg0, DiagnosticArg const& arg1) - { - DiagnosticArg const* args[] = { &arg0, &arg1 }; - diagnoseImpl(pos, info, 2, args); - } - - void diagnoseDispatch(SourceLoc const& pos, DiagnosticInfo const& info, DiagnosticArg const& arg0, DiagnosticArg const& arg1, DiagnosticArg const& arg2) - { - DiagnosticArg const* args[] = { &arg0, &arg1, &arg2 }; - diagnoseImpl(pos, info, 3, args); - } - - void diagnoseDispatch(SourceLoc const& pos, DiagnosticInfo const& info, DiagnosticArg const& arg0, DiagnosticArg const& arg1, DiagnosticArg const& arg2, DiagnosticArg const& arg3) - { - DiagnosticArg const* args[] = { &arg0, &arg1, &arg2, &arg3 }; - diagnoseImpl(pos, info, 4, args); - } - - template - void diagnose(P const& pos, DiagnosticInfo const& info, Args const&... args ) - { - diagnoseDispatch(getDiagnosticPos(pos), info, args...); - } - - void diagnoseImpl(SourceLoc const& pos, DiagnosticInfo const& info, int argCount, DiagnosticArg const* const* args); - - // Add a diagnostic with raw text - // (used when we get errors from a downstream compiler) - void diagnoseRaw( - Severity severity, - char const* message); - void diagnoseRaw( - Severity severity, - const UnownedStringSlice& message); - - /// During propagation of an exception for an internal - /// error, note that this source location was involved - void noteInternalErrorLoc(SourceLoc const& loc); - - SlangResult getBlobIfNeeded(ISlangBlob** outBlob); - }; - - /// An `ISlangWriter` that writes directly to a diagnostic sink. - class DiagnosticSinkWriter : public AppendBufferWriter - { - public: - typedef AppendBufferWriter Super; - - DiagnosticSinkWriter(DiagnosticSink* sink) - : Super(WriterFlag::IsStatic) - , m_sink(sink) - {} - - // ISlangWriter - SLANG_NO_THROW virtual SlangResult SLANG_MCALL write(const char* chars, size_t numChars) SLANG_OVERRIDE - { - m_sink->diagnoseRaw(Severity::Note, UnownedStringSlice(chars, chars+numChars)); - return SLANG_OK; - } - - private: - DiagnosticSink* m_sink = nullptr; - }; - - namespace Diagnostics - { -#define DIAGNOSTIC(id, severity, name, messageFormat) extern const DiagnosticInfo name; -#include "diagnostic-defs.h" - } -} - -#ifdef _DEBUG -#define SLANG_INTERNAL_ERROR(sink, pos) \ - (sink)->diagnose(Slang::SourceLoc(__LINE__, 0, 0, __FILE__), Slang::Diagnostics::internalCompilerError) -#define SLANG_UNIMPLEMENTED(sink, pos, what) \ - (sink)->diagnose(Slang::SourceLoc(__LINE__, 0, 0, __FILE__), Slang::Diagnostics::unimplemented, what) - -#else -#define SLANG_INTERNAL_ERROR(sink, pos) \ - (sink)->diagnose(pos, Slang::Diagnostics::internalCompilerError) -#define SLANG_UNIMPLEMENTED(sink, pos, what) \ - (sink)->diagnose(pos, Slang::Diagnostics::unimplemented, what) - -#endif - -#define SLANG_DIAGNOSE_UNEXPECTED(sink, pos, message) \ - (sink)->diagnose(pos, Slang::Diagnostics::unexpected, message) - -#endif diff --git a/source/slang/dxc-support.cpp b/source/slang/dxc-support.cpp deleted file mode 100644 index b6eaf8aa9..000000000 --- a/source/slang/dxc-support.cpp +++ /dev/null @@ -1,302 +0,0 @@ -// dxc-support.cpp -#include "compiler.h" - -// This file implements support for invoking the `dxcompiler` -// library to translate HLSL to DXIL. - -#if defined(_WIN32) -# if !defined(SLANG_ENABLE_DXIL_SUPPORT) -# define SLANG_ENABLE_DXIL_SUPPORT 1 -# endif -#endif - -#if !defined(SLANG_ENABLE_DXIL_SUPPORT) -# define SLANG_ENABLE_DXIL_SUPPORT 0 -#endif - -#if SLANG_ENABLE_DXIL_SUPPORT - -#define WIN32_LEAN_AND_MEAN -#define NOMINMAX -#include -#include -#include "../../external/dxc/dxcapi.h" -#undef WIN32_LEAN_AND_MEAN -#undef NOMINMAX - -#include "../core/platform.h" - -namespace Slang -{ - String GetHLSLProfileName(Profile profile); - String emitHLSLForEntryPoint( - BackEndCompileRequest* compileRequest, - EntryPoint* entryPoint, - Int entryPointIndex, - TargetRequest* targetReq, - EndToEndCompileRequest* endToEndReq); - - static UnownedStringSlice _getSlice(IDxcBlob* blob) - { - if (blob) - { - const char* chars = (const char*)blob->GetBufferPointer(); - size_t len = blob->GetBufferSize(); - len -= size_t(len > 0 && chars[len - 1] == 0); - return UnownedStringSlice(chars, len); - } - return UnownedStringSlice(); - } - - SlangResult emitDXILForEntryPointUsingDXC( - BackEndCompileRequest* compileRequest, - EntryPoint* entryPoint, - Int entryPointIndex, - TargetRequest* targetReq, - EndToEndCompileRequest* endToEndReq, - List& outCode) - { - auto session = compileRequest->getSession(); - auto sink = compileRequest->getSink(); - - // First deal with all the rigamarole of loading - // the `dxcompiler` library, and creating the - // top-level COM objects that will be used to - // compile things. - - auto dxcCreateInstance = (DxcCreateInstanceProc)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Dxc_DxcCreateInstance, sink); - if (!dxcCreateInstance) - { - return SLANG_FAIL; - } - - { - if (!session->getSharedLibrary(SharedLibraryType::Dxil)) - { - // If can't load dxil - dxc will not be able to sign output - // Output a suitable warning to the user - sink->diagnose(SourceLoc(), Diagnostics::dxilNotFound); - } - } - - ComPtr dxcCompiler; - SLANG_RETURN_ON_FAIL(dxcCreateInstance( - CLSID_DxcCompiler, - __uuidof(dxcCompiler), - (LPVOID*)dxcCompiler.writeRef())); - - ComPtr dxcLibrary; - SLANG_RETURN_ON_FAIL(dxcCreateInstance( - CLSID_DxcLibrary, - __uuidof(dxcLibrary), - (LPVOID*)dxcLibrary.writeRef())); - - // Now let's go ahead and generate HLSL for the entry - // point, since we'll need that to feed into dxc. - auto hlslCode = emitHLSLForEntryPoint( - compileRequest, - entryPoint, - entryPointIndex, - targetReq, - endToEndReq); - maybeDumpIntermediate(compileRequest, hlslCode.getBuffer(), CodeGenTarget::HLSL); - - // Wrap the - - // Create blob from the string - ComPtr dxcSourceBlob; - SLANG_RETURN_ON_FAIL(dxcLibrary->CreateBlobWithEncodingFromPinned( - (LPBYTE)hlslCode.getBuffer(), - (UINT32)hlslCode.getLength(), - 0, - dxcSourceBlob.writeRef())); - - WCHAR const* args[16]; - UINT32 argCount = 0; - - // TODO: deal with - bool treatWarningsAsErrors = false; - if (treatWarningsAsErrors) - { - args[argCount++] = L"-WX"; - } - - switch( targetReq->getDefaultMatrixLayoutMode() ) - { - default: - break; - - case kMatrixLayoutMode_RowMajor: - args[argCount++] = L"-Zpr"; - break; - } - - switch( targetReq->getFloatingPointMode() ) - { - default: - break; - - case FloatingPointMode::Precise: - args[argCount++] = L"-Gis"; // "force IEEE strictness" - break; - } - - auto linkage = compileRequest->getLinkage(); - switch( linkage->optimizationLevel ) - { - default: - break; - - case OptimizationLevel::None: args[argCount++] = L"-Od"; break; - case OptimizationLevel::Default: args[argCount++] = L"-O1"; break; - case OptimizationLevel::High: args[argCount++] = L"-O2"; break; - case OptimizationLevel::Maximal: args[argCount++] = L"-O3"; break; - } - - switch( linkage->debugInfoLevel ) - { - case DebugInfoLevel::None: - break; - - default: - args[argCount++] = L"-Zi"; - break; - } - - // Slang strives to produce correct code, and by default - // we do not show the user warnings produced by a downstream - // compiler. When the downstream compiler *does* produce an - // error, then we dump its entire diagnostic log, which can - // include many distracting spurious warnings that have nothing - // to do with the user's code, and just relate to the idiomatic - // way that Slang outputs HLSL. - // - // It would be nice to use fine-grained flags to disable specific - // warnings here, so that we keep ourselves honest (e.g., only - // use `-Wno-parentheses` to eliminate that class of false positives), - // but alas dxc doesn't support these options even though they - // work on mainline Clang. Thus the only option we have available - // is the big hammer of turning off *all* warnings coming from dxc. - // - args[argCount++] = L"-no-warnings"; - - String entryPointName = getText(entryPoint->getName()); - OSString wideEntryPointName = entryPointName.toWString(); - - auto profile = getEffectiveProfile(entryPoint, targetReq); - String profileName = GetHLSLProfileName(profile); - OSString wideProfileName = profileName.toWString(); - - // We will enable the flag to generate proper code for 16-bit types - // by default, as long as the user is requesting a sufficiently - // high shader model. - // - // TODO: Need to check that this is safe to enable in all cases, - // or if it will make a shader demand hardware features that - // aren't always present. - // - // TODO: Ideally the dxc back-end should be passed some information - // on the "capabilities" that were used and/or requested in the code. - // - if( profile.GetVersion() >= ProfileVersion::DX_6_2 ) - { - args[argCount++] = L"-enable-16bit-types"; - } - - const String sourcePath = calcSourcePathForEntryPoint(endToEndReq, entryPointIndex); - - ComPtr dxcResult; - SLANG_RETURN_ON_FAIL(dxcCompiler->Compile(dxcSourceBlob, - sourcePath.toWString().begin(), - profile.GetStage() == Stage::Unknown ? L"" : wideEntryPointName.begin(), - wideProfileName.begin(), - args, - argCount, - nullptr, // `#define`s - 0, // `#define` count - nullptr, // `#include` handler - dxcResult.writeRef())); - - // Retrieve result. - HRESULT resultCode = S_OK; - SLANG_RETURN_ON_FAIL(dxcResult->GetStatus(&resultCode)); - - // Note: it seems like the dxcompiler interface - // doesn't support querying diagnostic output - // *unless* the compile failed (no way to get - // warnings out!?). - - // Verify compile result - if (SLANG_FAILED(resultCode)) - { - // Compilation failed. - // Try to read any diagnostic output. - ComPtr dxcErrorBlob; - SLANG_RETURN_ON_FAIL(dxcResult->GetErrorBuffer(dxcErrorBlob.writeRef())); - - // Note: the error blob returned by dxc doesn't always seem - // to be nul-terminated, so we should be careful and turn it - // into a string for safety. - // - - reportExternalCompileError("dxc", resultCode, _getSlice(dxcErrorBlob), compileRequest->getSink()); - return resultCode; - } - - // Okay, the compile supposedly succeeded, so we - // just need to grab the buffer with the output DXIL. - ComPtr dxcResultBlob; - SLANG_RETURN_ON_FAIL(dxcResult->GetResult(dxcResultBlob.writeRef())); - - outCode.addRange( - (uint8_t const*)dxcResultBlob->GetBufferPointer(), - (int) dxcResultBlob->GetBufferSize()); - - return SLANG_OK; - } - - SlangResult dissassembleDXILUsingDXC( - BackEndCompileRequest* compileRequest, - void const* data, - size_t size, - String& stringOut) - { - stringOut = String(); - auto session = compileRequest->getSession(); - auto sink = compileRequest->getSink(); - - // First deal with all the rigamarole of loading - // the `dxcompiler` library, and creating the - // top-level COM objects that will be used to - // compile things. - - auto dxcCreateInstance = (DxcCreateInstanceProc)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Dxc_DxcCreateInstance, sink); - if (!dxcCreateInstance) - { - return SLANG_FAIL; - } - - ComPtr dxcCompiler; - SLANG_RETURN_ON_FAIL(dxcCreateInstance(CLSID_DxcCompiler, __uuidof(dxcCompiler), (LPVOID*) dxcCompiler.writeRef())); - ComPtr dxcLibrary; - SLANG_RETURN_ON_FAIL(dxcCreateInstance(CLSID_DxcLibrary, __uuidof(dxcLibrary), (LPVOID*) dxcLibrary.writeRef())); - - // Create blob from the input data - ComPtr dxcSourceBlob; - SLANG_RETURN_ON_FAIL(dxcLibrary->CreateBlobWithEncodingFromPinned((LPBYTE) data, (UINT32) size, 0, dxcSourceBlob.writeRef())); - - ComPtr dxcResultBlob; - SLANG_RETURN_ON_FAIL(dxcCompiler->Disassemble(dxcSourceBlob, dxcResultBlob.writeRef())); - - stringOut = _getSlice(dxcResultBlob); - - return SLANG_OK; - } - - -} // namespace Slang - -#endif - - - diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp deleted file mode 100644 index d90415d89..000000000 --- a/source/slang/emit.cpp +++ /dev/null @@ -1,510 +0,0 @@ -// emit.cpp -#include "emit.h" - -#include "../core/slang-writer.h" -#include "ir-bind-existentials.h" -#include "ir-dce.h" -#include "ir-entry-point-uniforms.h" -#include "ir-glsl-legalize.h" -#include "ir-insts.h" -#include "ir-link.h" -#include "ir-restructure.h" -#include "ir-restructure-scoping.h" -#include "ir-specialize.h" -#include "ir-specialize-resources.h" -#include "ir-ssa.h" -#include "ir-union.h" -#include "ir-validate.h" -#include "legalize-types.h" -#include "lower-to-ir.h" -#include "mangle.h" -#include "name.h" -#include "syntax.h" -#include "type-layout.h" -#include "visitor.h" - -#include "slang-source-stream.h" -#include "slang-emit-context.h" - -#include "slang-c-like-source-emitter.h" - -#include - -namespace Slang { - -enum class BuiltInCOp -{ - Splat, //< Splat a single value to all values of a vector or matrix type - Init, //< Initialize with parameters (must match the type) -}; - - -// - - -// - -EntryPointLayout* findEntryPointLayout( - ProgramLayout* programLayout, - EntryPoint* entryPoint) -{ - for( auto entryPointLayout : programLayout->entryPoints ) - { - if(entryPointLayout->entryPoint->getName() != entryPoint->getName()) - continue; - - // TODO: We need to be careful about this check, since it relies on - // the profile information in the layout matching that in the request. - // - // What we really seem to want here is some dictionary mapping the - // `EntryPoint` directly to the `EntryPointLayout`, and maybe - // that is precisely what we should build... - // - if(entryPointLayout->profile != entryPoint->getProfile()) - continue; - - // TODO: can't easily filter on translation unit here... - // Ideally the `EntryPoint` should get filled in with a pointer - // the specific function declaration that represents the entry point. - - return entryPointLayout.Ptr(); - } - - return nullptr; -} - - /// Given a layout computed for a scope, get the layout to use when lookup up variables. - /// - /// A scope (such as the global scope of a program) groups its - /// parameters into a pseudo-`struct` type for layout purposes, - /// and in some cases that type will in turn be wrapped in a - /// `ConstantBuffer` type to indicate that the parameters needed - /// an implicit constant buffer to be allocated. - /// - /// This function "unwraps" the type layout to find the structure - /// type layout that must be stored inside. - /// -StructTypeLayout* getScopeStructLayout( - ScopeLayout* scopeLayout) -{ - auto scopeTypeLayout = scopeLayout->parametersLayout->typeLayout; - - if( auto constantBufferTypeLayout = as(scopeTypeLayout) ) - { - scopeTypeLayout = constantBufferTypeLayout->offsetElementTypeLayout; - } - - if( auto structTypeLayout = as(scopeTypeLayout) ) - { - return structTypeLayout; - } - - SLANG_UNEXPECTED("uhandled global-scope binding layout"); - return nullptr; -} - - /// Given a layout computed for a program, get the layout to use when lookup up variables. - /// - /// This is just an alias of `getScopeStructLayout`. - /// -StructTypeLayout* getGlobalStructLayout( - ProgramLayout* programLayout) -{ - return getScopeStructLayout(programLayout); -} - -static void dumpIR( - BackEndCompileRequest* compileRequest, - IRModule* irModule, - char const* label) -{ - DiagnosticSinkWriter writerImpl(compileRequest->getSink()); - WriterHelper writer(&writerImpl); - - if(label) - { - writer.put("### "); - writer.put(label); - writer.put(":\n"); - } - - dumpIR(irModule, writer.getWriter()); - - if( label ) - { - writer.put("###\n"); - } -} - -static void dumpIRIfEnabled( - BackEndCompileRequest* compileRequest, - IRModule* irModule, - char const* label = nullptr) -{ - if(compileRequest->shouldDumpIR) - { - dumpIR(compileRequest, irModule, label); - } -} - -String emitEntryPoint( - BackEndCompileRequest* compileRequest, - EntryPoint* entryPoint, - CodeGenTarget target, - TargetRequest* targetRequest) -{ - auto sink = compileRequest->getSink(); - auto program = compileRequest->getProgram(); - auto targetProgram = program->getTargetProgram(targetRequest); - auto programLayout = targetProgram->getOrCreateLayout(sink); - -// auto translationUnit = entryPoint->getTranslationUnit(); - - auto lineDirectiveMode = compileRequest->getLineDirectiveMode(); - // To try to make the default behavior reasonable, we will - // always use C-style line directives (to give the user - // good source locations on error messages from downstream - // compilers) *unless* they requested raw GLSL as the - // output (in which case we want to maximize compatibility - // with downstream tools). - if (lineDirectiveMode == LineDirectiveMode::Default && targetRequest->getTarget() == CodeGenTarget::GLSL) - { - lineDirectiveMode = LineDirectiveMode::GLSL; - } - - SourceStream sourceStream(compileRequest->getSourceManager(), lineDirectiveMode ); - - EmitContext emitContext; - emitContext.compileRequest = compileRequest; - emitContext.target = target; - emitContext.entryPoint = entryPoint; - emitContext.effectiveProfile = getEffectiveProfile(entryPoint, targetRequest); - emitContext.stream = &sourceStream; - - if (entryPoint && programLayout) - { - emitContext.entryPointLayout = findEntryPointLayout( - programLayout, - entryPoint); - } - - emitContext.programLayout = programLayout; - - // Layout information for the global scope is either an ordinary - // `struct` in the common case, or a constant buffer in the case - // where there were global-scope uniforms. - - StructTypeLayout* globalStructLayout = programLayout ? getGlobalStructLayout(programLayout) : nullptr; - emitContext.globalStructLayout = globalStructLayout; - - CLikeSourceEmitter sourceEmitter(&emitContext); - - { - auto session = targetRequest->getSession(); - - // We start out by performing "linking" at the level of the IR. - // This step will create a fresh IR module to be used for - // code generation, and will copy in any IR definitions that - // the desired entry point requires. Along the way it will - // resolve references to imported/exported symbols across - // modules, and also select between the definitions of - // any "profile-overloaded" symbols. - // - auto linkedIR = linkIR( - compileRequest, - entryPoint, - programLayout, - target, - targetRequest); - auto irModule = linkedIR.module; - auto irEntryPoint = linkedIR.entryPoint; - -#if 0 - dumpIRIfEnabled(compileRequest, irModule, "LINKED"); -#endif - - validateIRModuleIfEnabled(compileRequest, irModule); - - // If the user specified the flag that they want us to dump - // IR, then do it here, for the target-specific, but - // un-specialized IR. - dumpIRIfEnabled(compileRequest, irModule); - - // When there are top-level existential-type parameters - // to the shader, we need to take the side-band information - // on how the existential "slots" were bound to concrete - // types, and use it to introduce additional explicit - // shader parameters for those slots, to be wired up to - // use sites. - // - bindExistentialSlots(irModule, sink); -#if 0 - dumpIRIfEnabled(compileRequest, irModule, "EXISTENTIALS BOUND"); -#endif - validateIRModuleIfEnabled(compileRequest, irModule); - - - - - - // Now that we've linked the IR code, any layout/binding - // information has been attached to shader parameters - // and entry points. Now we are safe to make transformations - // that might move code without worrying about losing - // the connection between a parameter and its layout. - // - // An easy transformation of this kind is to take uniform - // parameters of a shader entry point and move them into - // the global scope instead. - // - moveEntryPointUniformParamsToGlobalScope(irModule); -#if 0 - dumpIRIfEnabled(compileRequest, irModule, "ENTRY POINT UNIFORMS MOVED"); -#endif - validateIRModuleIfEnabled(compileRequest, irModule); - - // Desguar any union types, since these will be illegal on - // various targets. - // - desugarUnionTypes(irModule); -#if 0 - dumpIRIfEnabled(compileRequest, irModule, "UNIONS DESUGARED"); -#endif - validateIRModuleIfEnabled(compileRequest, irModule); - - // Next, we need to ensure that the code we emit for - // the target doesn't contain any operations that would - // be illegal on the target platform. For example, - // none of our target supports generics, or interfaces, - // so we need to specialize those away. - // - // Simplification of existential-based and generics-based - // code may each open up opportunities for the other, so - // the relevant specialization transformations are handled in a - // single pass that looks for all simplification opportunities. - // - // TODO: We also need to extend this pass so that it will "expose" - // existential values that are nested inside of other types, - // so that the simplifications can be applied. - // - // TODO: This pass is *also* likely to be the place where we - // perform specialization of functions based on parameter - // values that need to be compile-time constants. - // - specializeModule(irModule); - - // Debugging code for IR transformations... -#if 0 - dumpIRIfEnabled(compileRequest, irModule, "SPECIALIZED"); -#endif - validateIRModuleIfEnabled(compileRequest, irModule); - - - // Specialization can introduce dead code that could trip - // up downstream passes like type legalization, so we - // will run a DCE pass to clean up after the specialization. - // - // TODO: Are there other cleanup optimizations we should - // apply at this point? - // - eliminateDeadCode(compileRequest, irModule); -#if 0 - dumpIRIfEnabled(compileRequest, irModule, "AFTER DCE"); -#endif - validateIRModuleIfEnabled(compileRequest, irModule); - - // The Slang language allows interfaces to be used like - // ordinary types (including placing them in constant - // buffers and entry-point parameter lists), but then - // getting them to lay out in a reasonable way requires - // us to treat fields/variables with interface type - // *as if* they were pointers to heap-allocated "objects." - // - // Specialization will have replaced fields/variables - // with interface types like `IFoo` with fields/variables - // with pointer-like types like `ExistentialBox`. - // - // We need to legalize these pointer-like types away, - // which involves two main changes: - // - // 1. Any `ExistentialBox<...>` fields need to be moved - // out of their enclosing `struct` type, so that the layout - // of the enclosing type is computed as if the field had - // zero size. - // - // 2. Once an `ExistentialBox` has been floated out - // of its parent and landed somwhere permanent (e.g., either - // a dedicated variable, or a field of constant buffer), - // we need to replace it with just an `X`, after which we - // will have (more) legal shader code. - // - legalizeExistentialTypeLayout( - irModule, - sink); - eliminateDeadCode(compileRequest, irModule); - -#if 0 - dumpIRIfEnabled(compileRequest, irModule, "EXISTENTIALS LEGALIZED"); -#endif - validateIRModuleIfEnabled(compileRequest, irModule); - - // Many of our target languages and/or downstream compilers - // don't support `struct` types that have resource-type fields. - // In order to work around this limitation, we will rewrite the - // IR so that any structure types with resource-type fields get - // split into a "tuple" that comprises the ordinary fields (still - // bundles up as a `struct`) and one element for each resource-type - // field (recursively). - // - // What used to be individual variables/parameters/arguments/etc. - // then become multiple variables/parameters/arguments/etc. - // - legalizeResourceTypes( - irModule, - sink); - eliminateDeadCode(compileRequest, irModule); - - // Debugging output of legalization -#if 0 - dumpIRIfEnabled(compileRequest, irModule, "LEGALIZED"); -#endif - validateIRModuleIfEnabled(compileRequest, irModule); - - // Once specialization and type legalization have been performed, - // we should perform some of our basic optimization steps again, - // to see if we can clean up any temporaries created by legalization. - // (e.g., things that used to be aggregated might now be split up, - // so that we can work with the individual fields). - constructSSA(irModule); - -#if 0 - dumpIRIfEnabled(compileRequest, irModule, "AFTER SSA"); -#endif - validateIRModuleIfEnabled(compileRequest, irModule); - - // After type legalization and subsequent SSA cleanup we expect - // that any resource types passed to functions are exposed - // as their own top-level parameters (which might have - // resource or array-of-...-resource types). - // - // Many of our targets place restrictions on how certain - // resource types can be used, so that having them as - // function parameters is invalid. To clean this up, - // we will try to specialize called functions based - // on the actual resources that are being passed to them - // at specific call sites. - // - // Because the legalization may depend on what target - // we are compiling for (certain things might be okay - // for D3D targets that are not okay for Vulkan), we - // pass down the target request along with the IR. - // - specializeResourceParameters(compileRequest, targetRequest, irModule); - -#if 0 - dumpIRIfEnabled(compileRequest, irModule, "AFTER RESOURCE SPECIALIZATION"); -#endif - validateIRModuleIfEnabled(compileRequest, irModule); - - - // For GLSL only, we will need to perform "legalization" of - // the entry point and any entry-point parameters. - // - // TODO: We should consider moving this legalization work - // as late as possible, so that it doesn't affect how other - // optimization passes need to work. - // - switch (target) - { - case CodeGenTarget::GLSL: - { - legalizeEntryPointForGLSL( - session, - irModule, - irEntryPoint, - compileRequest->getSink(), - &emitContext.extensionUsageTracker); - -#if 0 - dumpIRIfEnabled(compileRequest, irModule, "GLSL LEGALIZED"); -#endif - validateIRModuleIfEnabled(compileRequest, irModule); - } - break; - - default: - break; - } - - // The resource-based specialization pass above - // may create specialized versions of functions, but - // it does not try to completely eliminate the original - // functions, so there might still be invalid code in - // our IR module. - // - // To clean up the code, we will apply a fairly general - // dead-code-elimination (DCE) pass that only retains - // whatever code is "live." - // - eliminateDeadCode(compileRequest, irModule); -#if 0 - dumpIRIfEnabled(compileRequest, irModule, "AFTER DCE"); -#endif - validateIRModuleIfEnabled(compileRequest, irModule); - - // After all of the required optimization and legalization - // passes have been performed, we can emit target code from - // the IR module. - // - // TODO: do we want to emit directly from IR, or translate the - // IR back into AST for emission? - sourceEmitter.emitIRModule(irModule); - } - - // Deal with cases where a particular stage requires certain GLSL versions - // and/or extensions. - switch( entryPoint->getStage() ) - { - default: - break; - - case Stage::AnyHit: - case Stage::Callable: - case Stage::ClosestHit: - case Stage::Intersection: - case Stage::Miss: - case Stage::RayGeneration: - if( target == CodeGenTarget::GLSL ) - { - emitContext.extensionUsageTracker.requireGLSLExtension("GL_NV_ray_tracing"); - emitContext.extensionUsageTracker.requireGLSLVersion(ProfileVersion::GLSL_460); - } - break; - } - - String code = sourceStream.getContent(); - sourceStream.clearContent(); - - // Now that we've emitted the code for all the declarations in the file, - // it is time to stitch together the final output. - - // There may be global-scope modifiers that we should emit now - sourceEmitter.emitGLSLPreprocessorDirectives(); - - sourceEmitter.emitLayoutDirectives(targetRequest); - - String prefix = sourceStream.getContent(); - - StringBuilder finalResultBuilder; - finalResultBuilder << prefix; - - finalResultBuilder << emitContext.extensionUsageTracker.getGLSLExtensionRequireLines(); - - finalResultBuilder << code; - - String finalResult = finalResultBuilder.ProduceString(); - - return finalResult; -} - -} // namespace Slang diff --git a/source/slang/emit.h b/source/slang/emit.h deleted file mode 100644 index dc9300025..000000000 --- a/source/slang/emit.h +++ /dev/null @@ -1,27 +0,0 @@ -// Emit.h -#ifndef SLANG_EMIT_H_INCLUDED -#define SLANG_EMIT_H_INCLUDED - -#include "../core/basic.h" - -#include "compiler.h" - -namespace Slang -{ - class EntryPoint; - class ProgramLayout; - class TranslationUnitRequest; - - // Emit code for a single entry point, based on - // the input translation unit. - String emitEntryPoint( - BackEndCompileRequest* compileRequest, - EntryPoint* entryPoint, - - // The target language to generate code in (e.g., HLSL/GLSL) - CodeGenTarget target, - - // The full target request - TargetRequest* targetRequest); -} -#endif diff --git a/source/slang/expr-defs.h b/source/slang/expr-defs.h deleted file mode 100644 index 8afa93fbd..000000000 --- a/source/slang/expr-defs.h +++ /dev/null @@ -1,206 +0,0 @@ -// expr-defs.h - -// Syntax class definitions for expressions. - - -// Base class for expressions that will reference declarations -ABSTRACT_SYNTAX_CLASS(DeclRefExpr, Expr) - -// The scope in which to perform lookup - FIELD(RefPtr, scope) - - // The declaration of the symbol being referenced - DECL_FIELD(DeclRef, declRef) - - // The name of the symbol being referenced - FIELD(Name*, name) -END_SYNTAX_CLASS() - -SIMPLE_SYNTAX_CLASS(VarExpr, DeclRefExpr) - -// An expression that references an overloaded set of declarations -// having the same name. -SYNTAX_CLASS(OverloadedExpr, Expr) - - // Optional: the base expression is this overloaded result - // arose from a member-reference expression. - SYNTAX_FIELD(RefPtr, base) - - // The lookup result that was ambiguous - FIELD(LookupResult, lookupResult2) -END_SYNTAX_CLASS() - -// An expression that references an overloaded set of declarations -// having the same name. -SYNTAX_CLASS(OverloadedExpr2, Expr) - - // Optional: the base expression is this overloaded result - // arose from a member-reference expression. - SYNTAX_FIELD(RefPtr, base) - - // The lookup result that was ambiguous - FIELD(List>, candidiateExprs) -END_SYNTAX_CLASS() - -ABSTRACT_SYNTAX_CLASS(LiteralExpr, Expr) - // The token that was used to express the literal. This can be - // used to get the raw text of the literal, including any suffix. - FIELD(Token, token) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(IntegerLiteralExpr, LiteralExpr) - FIELD(IntegerLiteralValue, value) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(FloatingPointLiteralExpr, LiteralExpr) - FIELD(FloatingPointLiteralValue, value) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(BoolLiteralExpr, LiteralExpr) - FIELD(bool, value) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(StringLiteralExpr, LiteralExpr) - // TODO: consider storing the "segments" of the string - // literal, in the case where multiple literals were - //lined up at the lexer level, e.g.: - // - // "first" "second" "third" - // - FIELD(String, value) -END_SYNTAX_CLASS() - -// An initializer list, e.g. `{ 1, 2, 3 }` -SYNTAX_CLASS(InitializerListExpr, Expr) - SYNTAX_FIELD(List>, args) -END_SYNTAX_CLASS() - -// A base class for expressions with arguments -ABSTRACT_SYNTAX_CLASS(ExprWithArgsBase, Expr) - SYNTAX_FIELD(List>, Arguments) -END_SYNTAX_CLASS() - -// An aggregate type constructor -SYNTAX_CLASS(AggTypeCtorExpr, ExprWithArgsBase) - SYNTAX_FIELD(TypeExp, base); -END_SYNTAX_CLASS() - - -// A base expression being applied to arguments: covers -// both ordinary `()` function calls and `<>` generic application -ABSTRACT_SYNTAX_CLASS(AppExprBase, ExprWithArgsBase) - SYNTAX_FIELD(RefPtr, FunctionExpr) -END_SYNTAX_CLASS() - -SIMPLE_SYNTAX_CLASS(InvokeExpr, AppExprBase) - -SIMPLE_SYNTAX_CLASS(OperatorExpr, InvokeExpr) - -SIMPLE_SYNTAX_CLASS(InfixExpr , OperatorExpr) -SIMPLE_SYNTAX_CLASS(PrefixExpr , OperatorExpr) -SIMPLE_SYNTAX_CLASS(PostfixExpr, OperatorExpr) - -SYNTAX_CLASS(IndexExpr, Expr) - SYNTAX_FIELD(RefPtr, BaseExpression) - SYNTAX_FIELD(RefPtr, IndexExpression) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(MemberExpr, DeclRefExpr) - SYNTAX_FIELD(RefPtr, BaseExpression) -END_SYNTAX_CLASS() - -// Member looked up on a type, rather than a value -SYNTAX_CLASS(StaticMemberExpr, DeclRefExpr) - SYNTAX_FIELD(RefPtr, BaseExpression) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(SwizzleExpr, Expr) - SYNTAX_FIELD(RefPtr, base) - FIELD(int, elementCount) - FIELD(int, elementIndices[4]) -END_SYNTAX_CLASS() - -// A dereference of a pointer or pointer-like type -SYNTAX_CLASS(DerefExpr, Expr) - SYNTAX_FIELD(RefPtr, base) -END_SYNTAX_CLASS() - -// Any operation that performs type-casting -SYNTAX_CLASS(TypeCastExpr, InvokeExpr) -// SYNTAX_FIELD(TypeExp, TargetType) -// SYNTAX_FIELD(RefPtr, Expression) -END_SYNTAX_CLASS() - -// An explicit type-cast that appear in the user's code with `(type) expr` syntax -SYNTAX_CLASS(ExplicitCastExpr, TypeCastExpr) -END_SYNTAX_CLASS() - -// An implicit type-cast inserted during semantic checking -SYNTAX_CLASS(ImplicitCastExpr, TypeCastExpr) -END_SYNTAX_CLASS() - - /// A cast from a value to an interface ("existential") type. -SYNTAX_CLASS(CastToInterfaceExpr, Expr) -RAW( - /// The value being cast to an interface type - RefPtr valueArg; - - /// A witness showing that `valueArg` conforms to the chosen interface - RefPtr witnessArg; -) -END_SYNTAX_CLASS() - -SIMPLE_SYNTAX_CLASS(SelectExpr, OperatorExpr) - -SIMPLE_SYNTAX_CLASS(GenericAppExpr, AppExprBase) - -// An expression representing re-use of the syntax for a type in more -// than once conceptually-distinct declaration -SYNTAX_CLASS(SharedTypeExpr, Expr) - // The underlying type expression that we want to share - SYNTAX_FIELD(TypeExp, base) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(AssignExpr, Expr) - SYNTAX_FIELD(RefPtr, left); - SYNTAX_FIELD(RefPtr, right); -END_SYNTAX_CLASS() - -// Just an expression inside parentheses `(exp)` -// -// We keep this around explicitly to be sure we don't lose any structure -// when we do rewriter stuff. -SYNTAX_CLASS(ParenExpr, Expr) - SYNTAX_FIELD(RefPtr, base); -END_SYNTAX_CLASS() - -// An object-oriented `this` expression, used to -// refer to the current instance of an enclosing type. -SYNTAX_CLASS(ThisExpr, Expr) - FIELD(RefPtr, scope); -END_SYNTAX_CLASS() - -// An expression that binds a temporary variable in a local expression context -SYNTAX_CLASS(LetExpr, Expr) -RAW( - RefPtr decl; - RefPtr body; -) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(ExtractExistentialValueExpr, Expr) -RAW( - DeclRef declRef; -) -END_SYNTAX_CLASS() - - /// A type expression of the form `__TaggedUnion(A, ...)`. - /// - /// An expression of this form will resolve to a `TaggedUnionType` - /// when checked. - /// -SYNTAX_CLASS(TaggedUnionTypeExpr, Expr) -RAW( - List caseTypes; -) -END_SYNTAX_CLASS() \ No newline at end of file diff --git a/source/slang/image-format-defs.h b/source/slang/image-format-defs.h deleted file mode 100644 index bbd60456d..000000000 --- a/source/slang/image-format-defs.h +++ /dev/null @@ -1,47 +0,0 @@ -// image-format-defs.h -#ifndef FORMAT -#error Must define FORMAT macro before including image-format-defs.h -#endif - -FORMAT(unknown) -FORMAT(rgba32f) -FORMAT(rgba16f) -FORMAT(rg32f) -FORMAT(rg16f) -FORMAT(r11f_g11f_b10f) -FORMAT(r32f) -FORMAT(r16f) -FORMAT(rgba16) -FORMAT(rgb10_a2) -FORMAT(rgba8) -FORMAT(rg16) -FORMAT(rg8) -FORMAT(r16) -FORMAT(r8) -FORMAT(rgba16_snorm) -FORMAT(rgba8_snorm) -FORMAT(rg16_snorm) -FORMAT(rg8_snorm) -FORMAT(r16_snorm) -FORMAT(r8_snorm) -FORMAT(rgba32i) -FORMAT(rgba16i) -FORMAT(rgba8i) -FORMAT(rg32i) -FORMAT(rg16i) -FORMAT(rg8i) -FORMAT(r32i) -FORMAT(r16i) -FORMAT(r8i) -FORMAT(rgba32ui) -FORMAT(rgba16ui) -FORMAT(rgb10_a2ui) -FORMAT(rgba8ui) -FORMAT(rg32ui) -FORMAT(rg16ui) -FORMAT(rg8ui) -FORMAT(r32ui) -FORMAT(r16ui) -FORMAT(r8ui) - -#undef FORMAT diff --git a/source/slang/ir-bind-existentials.cpp b/source/slang/ir-bind-existentials.cpp deleted file mode 100644 index f0c02dd67..000000000 --- a/source/slang/ir-bind-existentials.cpp +++ /dev/null @@ -1,352 +0,0 @@ -// ir-bind-existentials.cpp -#include "ir-bind-existentials.h" - -#include "ir.h" -#include "ir-insts.h" - -namespace Slang -{ - -// The code that comes out of the linking step will have instructions added -// that indicate how parameters with existential (interface) types are supposed -// to be specialized to concrete types. -// -// If there are any global existential-type parameters there should be a -// `bindGlobalExistentialSlots(...)` instruction at module scope. -// -// For each entry point with entry-point existential parameters, there should -// be a `[bindExistentialSlots(...)]` decoration attached to the entry -// point itself. -// -// In each case, the operands of the instruction should be a sequence of -// pairs. The number of pairs should match the number of existential "slots" -// at global or entry-point scope. Each pair should comprise a type `T` -// to plug into the slot, and a witness table `w` for the conformance of -// `T` to the interface type in that slot. -// -// In the simplest case, if we have a global shader parameter of interface -// type: -// -// IFoo p; -// -// Then this will lower to the IR as: -// -// global_param p : IFoo; -// -// And if the user tries to specialie `p` to type `Bar`, and a witness -// table `bar_is_ifoo`, we've have: -// -// bindGlobalExistentialSlots(Bar, bar_is_ifoo); -// -// The goal of this pass is to replace the parameter of interface type -// with one of concrete type: -// -// global_param p_new : Bar; -// -// and replace any reference to the old `p` parameter with -// a `makeExistential(p_new, bar_is_ifoo)`. That preserves the -// fact that a reference to `p` is conceptually of type `IFoo`, -// but allows downstream optimization passes to start specializing -// code based on the concrete knowledge that the value "backing" -// the parameter is actaully of type `Bar`. - -// As is typically for IR passes, we will encapsulate all the -// logic in a `struct` type. -// -struct BindExistentialSlots -{ - IRModule* module = nullptr; - DiagnosticSink* sink = nullptr; - - void processModule() - { - // We will start by dealing with the global existential slots. - processGlobalExistentialSlots(); - - // Then we will process the per-entry-point existential slots. - processEntryPointExistentialSlots(); - } - - void processGlobalExistentialSlots() - { - // If there are any global existential slots, we will expect - // to find a `bindGlobalExistentialSlots` instruction at module scope. - // - // We will start out by finding that instruction, if it exists. - // - IRInst* bindGlobalExistentialSlotsInst = nullptr; - for( auto inst : module->getGlobalInsts() ) - { - if( inst->op == kIROp_BindGlobalExistentialSlots ) - { - bindGlobalExistentialSlotsInst = inst; - break; - } - } - - // Now we will start looking for global shader parameters that make - // use of existential slots (we can determine this from their - // layout). - // - for( auto inst : module->getGlobalInsts() ) - { - // We only care about global shader parameters. - // - auto globalParam = as(inst); - if(!globalParam) - continue; - - // We will delegate to a subroutine for the meat - // of the work, since much of it can be shared - // with the case for entry-point existential - // parameters. - // - processParameter(globalParam, bindGlobalExistentialSlotsInst); - } - - // Once we are done looping over global shader parameters, - // all of the relevant information from the - // `bindGlobalExistentialSlots` instruction will have - // been moved to the parameters themselves, so we - // can eliminate the binding instruction. - // - if( bindGlobalExistentialSlotsInst ) - { - bindGlobalExistentialSlotsInst->removeAndDeallocate(); - } - } - - void processEntryPointExistentialSlots() - { - // The overall flow for the entry-point case is similar - // to the global case. - // - // We start by iterating over all the functions at - // global scope and look for entry points. - // - for( auto inst : module->getGlobalInsts() ) - { - auto func = as(inst); - if(!func) - continue; - - if(!func->findDecorationImpl(kIROp_EntryPointDecoration)) - continue; - - // We then process each entry point we find. - // - processEntryPointExistentialSlots(func); - } - } - - void processEntryPointExistentialSlots(IRFunc* func) - { - // When looking at a single `func`, we need - // to find the `[bindExistentialSlots(...)]` decoration, - // if it has one. - // - auto bindEntryPointExistentialSlotsInst = func->findDecorationImpl(kIROp_BindExistentialSlotsDecoration); - - // We then need to process each of the entry-point - // parameters just like we did for global parameters. - // - for( auto param : func->getParams() ) - { - processParameter(param, bindEntryPointExistentialSlotsInst); - } - - // TODO: We would need to consider what to do if - // we had an existential return type for `func`. - // - // In general, it probably doesn't make sense to - // have existential types in varying input/output - // at all, so the front-end should probably be - // validating that. - - // Once we've processed all the parameters, the information - // in the `[bindExistentialSlots(...)]` decoration is - // no longer needed, and we can remove it. - // - if( bindEntryPointExistentialSlotsInst ) - { - bindEntryPointExistentialSlotsInst->removeAndDeallocate(); - } - } - - // When processing a single parameter we need to have access - // to the corresponding instruction that will bind its slots. - // - // We don't care whether we have a `global_param` and a - // `bindGlobalExistentialSlots` instruction, or an entry-point - // function `param` and a `[bindExistentialSlots(...)]` - // decoration; both use the same subroutine. - // - void processParameter( - IRInst* param, - IRInst* bindSlotsInst) - { - // We expect all shader parameters to have layout information, - // but to be defensive we will skip any that don't. - // - auto layoutDecoration = param->findDecoration(); - if(!layoutDecoration) - return; - auto varLayout = as(layoutDecoration->getLayout()); - if(!varLayout) - return; - - // We only care about parameters that are associated - // with one or more existential slots. - // - auto resInfo = varLayout->FindResourceInfo(LayoutResourceKind::ExistentialTypeParam); - if(!resInfo) - return; - - // We will use the layout information on the variable to - // find out the stating slot, and the information on - // the type to find out the number of slots. - // - UInt firstSlot = resInfo->index; - UInt slotCount = 0; - if(auto typeResInfo = varLayout->getTypeLayout()->FindResourceInfo(LayoutResourceKind::ExistentialTypeParam)) - slotCount = UInt(typeResInfo->count.getFiniteValue()); - - // At this point we know that the parameter consumes - // some number of slots, so it would be an error - // if we don't have an instruction to bind the slots. - // - if( !bindSlotsInst ) - { - // Note: This error is considered an internal error because - // we should be detecting and diagnosing this problem before - // we make it to back-end code generation. - // - sink->diagnose(param->sourceLoc, Diagnostics::missingExistentialBindingsForParameter); - return; - } - - // Each existential slot corresponds to *two* arguments - // on the binding instruction: one for the type, and - // another for the witness table. - // - // We will check to make sure we have enough operands to cover - // this parameter. - // - UInt bindOperandCount = bindSlotsInst->getOperandCount(); - if( 2*(firstSlot + slotCount) > bindOperandCount ) - { - sink->diagnose(param->sourceLoc, Diagnostics::missingExistentialBindingsForParameter); - return; - } - // - // If there are enough operands, then we will offset to - // get to the starting point for the current parameter, - // keeping in mind that each slot accounts for two - // operands. - // - auto operandsForInst = bindSlotsInst->getOperands() + firstSlot; - - // Once we've found the operands that are relevent to - // the slots used by `param`, we will defer to a routine - // that replaces the type of `param` based on the - // information in the slots. - // - replaceTypeUsingExistentialSlots( - param, - slotCount, - operandsForInst); - } - - void replaceTypeUsingExistentialSlots( - IRInst* inst, - UInt slotCount, - IRUse const* slotArgs) - { - SLANG_UNUSED(slotCount); - - // We are going to alter the type of the - // given `inst` based on information in - // the `slotArgs`. - - auto fullType = inst->getFullType(); - - SharedIRBuilder sharedBuilder; - sharedBuilder.session = module->getSession(); - sharedBuilder.module = module; - - IRBuilder builder; - builder.sharedBuilder = &sharedBuilder; - - // Every argument that is filling an existential - // type param/slot comprises both a type and - // a witness table, so the total number of operands - // is twice the number of slots we are filling. - // - UInt slotOperandCount = slotCount*2; - List slotOperands; - for(UInt ii = 0; ii < slotOperandCount; ++ii) - slotOperands.add(slotArgs[ii].get()); - - // We are going to create a proxy type that represents - // the results of plugging all the information - // from the existential slots into the original type. - // - auto newType = builder.getBindExistentialsType( - fullType, - slotOperandCount, - slotOperands.getBuffer()); - - // We will replace the type of the original parameter - // with the new proxy type. - // - builder.setDataType(inst, newType); - - // Next we want to replace all uses of `inst` (which - // expect a value of its old type) with a fresh - // `wrapExistential(...)` instruction that refers to - // `inst` with its new type. - // - // Note: we make a copy of the list of uses for `inst` - // before going through and replacing them, because - // during the replacement we make *more* uses of `inst`, - // as an operand to the `makeExistential` instructions. - // We only want to replace the old uses, and not the - // new ones we'll be making. - // - List usesToReplace; - for(auto use = inst->firstUse; use; use = use->nextUse ) - usesToReplace.add(use); - - // Now we can loop over our list of uses and replace each. - // - for(auto use : usesToReplace) - { - // First we emit a `makeExisential` right before the - // use site. - // - builder.setInsertBefore(use->getUser()); - auto newVal = builder.emitWrapExistential( - fullType, - inst, - slotOperandCount, - slotOperands.getBuffer()); - - // Second we make the use site point at the new - // value instead. - // - use->set(newVal); - } - } -}; - -void bindExistentialSlots( - IRModule* module, - DiagnosticSink* sink) -{ - BindExistentialSlots context; - context.module = module; - context.sink = sink; - context.processModule(); -} - -} diff --git a/source/slang/ir-bind-existentials.h b/source/slang/ir-bind-existentials.h deleted file mode 100644 index c7fca3bb3..000000000 --- a/source/slang/ir-bind-existentials.h +++ /dev/null @@ -1,15 +0,0 @@ -// ir-bind-existentials.h -#pragma once - -namespace Slang -{ - -class DiagnosticSink; -struct IRModule; - - /// Bind concrete types to paameters that use existential slots. -void bindExistentialSlots( - IRModule* module, - DiagnosticSink* sink); - -} diff --git a/source/slang/ir-clone.cpp b/source/slang/ir-clone.cpp deleted file mode 100644 index d26b470d6..000000000 --- a/source/slang/ir-clone.cpp +++ /dev/null @@ -1,295 +0,0 @@ -// ir-clone.cpp -#include "ir-clone.h" - -#include "ir.h" -#include "ir-insts.h" - -namespace Slang -{ - -IRInst* lookUp(IRCloneEnv* env, IRInst* oldVal) -{ - for( auto ee = env; ee; ee = ee->parent ) - { - IRInst* newVal = nullptr; - if(ee->mapOldValToNew.TryGetValue(oldVal, newVal)) - return newVal; - } - return nullptr; -} - -IRInst* findCloneForOperand( - IRCloneEnv* env, - IRInst* oldOperand) -{ - if(!oldOperand) return nullptr; - - // If there is a registered replacement for - // the existing operand, then use it. - // - if( IRInst* newVal = lookUp(env, oldOperand) ) - return newVal; - - // Otherwise, we assume that the caller wants - // to default to using existing values wherever - // an explicit replacement hasn't been registered. - // - // This is, notably, the right default whenever - // `oldOperand` is a global value or constant - // and our cloned code will sit in the same - // module as the original. - // - // TODO: We could make this a customization point - // down the road, if we ever had a case where - // we want to clone things with a different policy. - // - return oldOperand; -} - -IRInst* cloneInstAndOperands( - IRCloneEnv* env, - IRBuilder* builder, - IRInst* oldInst) -{ - SLANG_ASSERT(env); - SLANG_ASSERT(builder); - SLANG_ASSERT(oldInst); - - // This logic will not handle any instructions - // with special-case data attached, but that only - // applies to `IRConstant`s at this point, and those - // should only appear at the global scope rather than - // in function bodies. - // - // TODO: It would be easy enough to extend this logic - // to handle constants gracefully, if it ever comes up. - // - SLANG_ASSERT(!as(oldInst)); - - // We start by mapping the type of the orignal instruction - // to its replacement value, if any. - // - auto oldType = oldInst->getFullType(); - auto newType = (IRType*) findCloneForOperand(env, oldType); - - // Next we will create an empty shell of the instruction, - // with space for the operands, but no actual operand - // values attached. - // - UInt operandCount = oldInst->getOperandCount(); - auto newInst = builder->emitIntrinsicInst( - newType, - oldInst->op, - operandCount, - nullptr); - - // Finally we will iterate over the operands of `oldInst` - // to find their replacements and install them as - // the operands of `newInst`. - // - for(UInt ii = 0; ii < operandCount; ++ii) - { - auto oldOperand = oldInst->getOperand(ii); - auto newOperand = findCloneForOperand(env, oldOperand); - - newInst->getOperands()[ii].init(newInst, newOperand); - } - - return newInst; -} - -// The complexity of the second phase of cloning (the -// one that deals with decorations and children) comes -// from the fact that it needs to sequence the two phases -// of cloning for any child instructions. We will do this -// by performing the first phase of cloning, and building -// up a list of children that require the second phase of processing. -// Each entry in that list will be a pair of an old instruction -// and its new clone. -// -struct IRCloningOldNewPair -{ - IRInst* oldInst; - IRInst* newInst; -}; - -// We will use an internal variant of `cloneInstDecorationsAndChildren` -// that modifies the provided `env` as it goes as the main -// workhorse, since we need to make sure that instructions in -// earlier blocks are visible to those in other, later, blocks -// when cloning a function, so that strict scoping along the -// lines of the nesting of instructions isn't sufficient. -// -static void _cloneInstDecorationsAndChildren( - IRCloneEnv* env, - SharedIRBuilder* sharedBuilder, - IRInst* oldInst, - IRInst* newInst) -{ - SLANG_ASSERT(env); - SLANG_ASSERT(sharedBuilder); - SLANG_ASSERT(oldInst); - SLANG_ASSERT(newInst); - - // We will set up an IR builder that inserts - // into the new parent instruction. - // - IRBuilder builderStorage; - auto builder = &builderStorage; - builder->sharedBuilder = sharedBuilder; - builder->setInsertInto(newInst); - - // When applying the first phase of cloning to - // children, we will keep track of those that - // require the second phase. - // - List pairs; - - for( auto oldChild : oldInst->getDecorationsAndChildren() ) - { - // As a very subtle special case, if one of the children - // of our `oldInst` already has a registered replacement, - // then we don't want to clone it (not least because - // the `Dictionary::Add` method would give us an error - // when we try to insert a new value for the same key). - // - // This arises for entries in `mapOldValToNew` that were - // seeded before cloning begain (e.g., function - // parameters that are to be replaced). - // - if(lookUp(env, oldChild)) - continue; - - // Now we can perform the first phase of cloning - // on the child, and register it in our map from - // old to new values. - // - auto newChild = cloneInstAndOperands(env, builder, oldChild); - env->mapOldValToNew.Add(oldChild, newChild); - - // If and only if the old child had decorations - // or children, we will register it into our - // list for processing in the second phase. - // - if( oldChild->getFirstDecorationOrChild() ) - { - IRCloningOldNewPair pair; - pair.oldInst = oldChild; - pair.newInst = newChild; - pairs.add(pair); - } - } - - // Once we have done first-phase processing for - // all child instructions, we scan through those - // in the list that required second-phase processing, - // and clone their decorations and/or children recursively. - // - for( auto pair : pairs ) - { - auto oldChild = pair.oldInst; - auto newChild = pair.newInst; - - _cloneInstDecorationsAndChildren(env, sharedBuilder, oldChild, newChild); - } -} - -// The public version of `cloneInstDecorationsAndChildren` is then -// just a wrapper over the internal one that sets up a temporary -// environment to use for the cloning process, so that we do -// not leave any lasting changes in the user-provided `env`. -// -void cloneInstDecorationsAndChildren( - IRCloneEnv* env, - SharedIRBuilder* sharedBuilder, - IRInst* oldInst, - IRInst* newInst) -{ - SLANG_ASSERT(sharedBuilder); - SLANG_ASSERT(oldInst); - SLANG_ASSERT(newInst); - - IRCloneEnv subEnvStorage; - auto subEnv = &subEnvStorage; - subEnv->parent = env; - - _cloneInstDecorationsAndChildren(subEnv, sharedBuilder, oldInst, newInst); -} - -// The convenience function `cloneInst` just sequences the -// operations that have already been defined. -// -IRInst* cloneInst( - IRCloneEnv* env, - IRBuilder* builder, - IRInst* oldInst) -{ - SLANG_ASSERT(env); - SLANG_ASSERT(builder); - SLANG_ASSERT(oldInst); - - auto newInst = cloneInstAndOperands( - env, builder, oldInst); - - env->mapOldValToNew.Add(oldInst, newInst); - - cloneInstDecorationsAndChildren( - env, builder->sharedBuilder, oldInst, newInst); - - return newInst; -} - -void cloneDecoration( - IRDecoration* oldDecoration, - IRInst* newParent, - IRModule* module) -{ - SharedIRBuilder sharedBuilder; - sharedBuilder.module = module; - - IRBuilder builder; - builder.sharedBuilder = &sharedBuilder; - - if(auto first = newParent->getFirstDecorationOrChild()) - builder.setInsertBefore(first); - else - builder.setInsertInto(newParent); - - IRCloneEnv env; - cloneInst(&env, &builder, oldDecoration); -} - -void cloneDecoration( - IRDecoration* oldDecoration, - IRInst* newParent) -{ - cloneDecoration( - oldDecoration, - newParent, - newParent->getModule()); -} - -bool IRSimpleSpecializationKey::operator==(IRSimpleSpecializationKey const& other) const -{ - auto valCount = vals.getCount(); - if(valCount != other.vals.getCount()) return false; - for( Index ii = 0; ii < valCount; ++ii ) - { - if(vals[ii] != other.vals[ii]) return false; - } - return true; -} - -int IRSimpleSpecializationKey::GetHashCode() const -{ - auto valCount = vals.getCount(); - int hash = Slang::GetHashCode(valCount); - for( Index ii = 0; ii < valCount; ++ii ) - { - hash = combineHash(hash, Slang::GetHashCode(vals[ii])); - } - return hash; -} - - -} // namespace Slang diff --git a/source/slang/ir-clone.h b/source/slang/ir-clone.h deleted file mode 100644 index bafbfa69d..000000000 --- a/source/slang/ir-clone.h +++ /dev/null @@ -1,183 +0,0 @@ -// ir-clone.h -#pragma once - -#include "../core/dictionary.h" - -#include "ir.h" - -namespace Slang -{ -struct IRBuilder; -struct IRInst; -struct SharedIRBuilder; - -// This file provides an interface to simplify the task of -// correctling "cloning" IR code, whether individual -// instructions, or whole functions. - - /// An environment for mapping existing values to their cloned replacements. - /// - /// This type serves two main roles in the process of IR cloning: - /// - /// * Before cloning begins, a client will usually - /// register the mapping from things that are to be - /// replaced entirely (like function parameters to - /// be specialized away) to their replacements (e.g., - /// a constant value). - /// - /// * During the process of cloning, env environment - /// will be maintained and updated so that when, e.g., - /// an instruction later in a function refers to - /// something from earlier, we can look up the - /// replacement. - /// -struct IRCloneEnv -{ - /// A mapping from old values to their replacements. - Dictionary mapOldValToNew; - - /// A parent environment to fall back to if `mapOldValToNew` doesn't contain a key. - IRCloneEnv* parent = nullptr; -}; - - /// Look up the replacement for `oldVal`, if any, registered in `env`. - /// - /// Returns `nullptr` if `oldVal` has no registered replacement. - /// -IRInst* lookUp(IRCloneEnv* env, IRInst* oldVal); - -// The SSA property and the way we have structured -// our "phi nodes" (block parameters) means that -// just going through the children of a function, -// and then the children of a block will generally -// do the Right Thing and always visit an instruction -// before its uses. -// -// The big exception to this is that branch instructions -// can refer to blocks later in the same function. -// -// We work around this sort of problem in a fairly -// general fashion, by splitting the cloning of -// an instruction into two steps. -// -// The first step is just to clone the instruction -// and its direct operands, but not any decorations -// or children. - - /// Clone `oldInst` and its direct operands. - /// - /// The "direct operands" include the type of the instruction. - /// The type and operands of `oldInst` will be mapped to now - /// values using `findOrCloneOperand` with the given `env`. - /// - /// Any new instruction that gets emitted will be output to - /// the provided `builder`, which must be non-null. - /// - /// This operation does *not* clone any children or decorations on `oldInst`. - /// This operation does *not* register its result as a replacement - /// for `oldInst` in the given `env`. - /// -IRInst* cloneInstAndOperands( - IRCloneEnv* env, - IRBuilder* builder, - IRInst* oldInst); - -// The second phase of cloning an instruction is to clone -// its decorations and children. This step only needs to -// be performed on those instructions that *have* decorations -// and/or children. - - /// Clone any decorations and/or children of `oldInst` onto `newInst` - /// - /// Any new instructions that get emitted will use the - /// provided `sharedBuilder`, which must be non-null. - /// - /// During the process of cloning decorations/children, operand values - /// will be looked up in the provided `env`, which should provide - /// replacement values for instructions that should have a different - /// identity in the clone. - /// The provided `env` will *not* be updated/modified during the - /// process of cloding decorations/children. - /// - /// If any child or decoration on `oldInst` already has a replacement - /// registered in `env`, it will *not* be cloned into `newInst`. - /// -void cloneInstDecorationsAndChildren( - IRCloneEnv* env, - SharedIRBuilder* sharedBuilder, - IRInst* oldInst, - IRInst* newInst); - -// For the case where the user knows the sequencing constraints -// on cloning operands before uses can be satisfied, we provide -// a convenience wrapper around the two phases of cloning: - - /// Clone `oldInst` and return the cloned value. - /// - /// This function is a convenience wrapper around - /// `cloneInstAndOperands` and `cloneInstDecorationsAndChildren`. - /// It also registers the resultint instruction as - /// the replacement value for `oldInst` in the given `env` - /// which must therefore be non-null. - /// -IRInst* cloneInst( - IRCloneEnv* env, - IRBuilder* builder, - IRInst* oldInst); - - /// Clone `oldDecoration` and attach the clone to `newParent`. - /// - /// Uses `module` to allocate any new instructions. - /// -void cloneDecoration( - IRDecoration* oldDecoration, - IRInst* newParent, - IRModule* module); - - /// Clone `oldDecoration` and attach the clone to `newParent`. - /// - /// Uses the module of `newParent` to allocate any new instructions, - /// so that `newParent` must already be installed somewhere - /// in the ownership hierarchy of an existing module. - /// -void cloneDecoration( - IRDecoration* oldDecoration, - IRInst* newParent); - - - /// Find the "cloned" value to use for an operand. - /// - /// This either returns the value registered for `oldOperand` - /// in `env`, or else `oldOperand` itself. -IRInst* findCloneForOperand( - IRCloneEnv* env, - IRInst* oldOperand); - -// It isn't technically part of the cloning infrastructure, -// but when make specialized copies of IR instructions via -// cloning we often need a simple kind of key suitable -// for caching existing specializations, so we'll define -// it here so that is is easily accessible to code that -// needs it. - -struct IRSimpleSpecializationKey -{ - // The structure of a specialization key will be a list - // of instructions, typically starting with the function, - // generic, or other object to be specialized, and then - // having one or more entries to represent the specialization - // arguments. - // - List vals; - - // In order to use this type as a `Dictionary` key we - // need it to support equality and hashing. - // - // TODO: honestly we might consider having `GetHashCode` - // and `operator==` defined for `List`. - - bool operator==(IRSimpleSpecializationKey const& other) const; - int GetHashCode() const; -}; - -} diff --git a/source/slang/ir-constexpr.cpp b/source/slang/ir-constexpr.cpp deleted file mode 100644 index c56a15663..000000000 --- a/source/slang/ir-constexpr.cpp +++ /dev/null @@ -1,553 +0,0 @@ -// ir-constexpr.cpp -#include "ir-constexpr.h" - -#include "ir.h" -#include "ir-insts.h" - -namespace Slang { - -struct PropagateConstExprContext -{ - IRModule* module; - IRModule* getModule() { return module; } - - DiagnosticSink* sink; - - SharedIRBuilder sharedBuilder; - IRBuilder builder; - - List workList; - HashSet onWorkList; - - IRBuilder* getBuilder() { return &builder; } - - Session* getSession() { return sharedBuilder.session; } - - DiagnosticSink* getSink() { return sink; } -}; - -bool isConstExpr(IRType* fullType) -{ - if( auto rateQualifiedType = as(fullType)) - { - auto rate = rateQualifiedType->getRate(); - if(auto constExprRate = as(rate)) - return true; - } - - return false; -} - -bool isConstExpr(IRInst* value) -{ - // Certain IR value ops are implicitly `constexpr` - // - // TODO: should we just go ahead and make that explicit - // in the type system? - switch(value->op) - { - case kIROp_IntLit: - case kIROp_FloatLit: - case kIROp_BoolLit: - case kIROp_Func: - return true; - - default: - break; - } - - if(isConstExpr(value->getFullType())) - return true; - - return false; -} - -bool opCanBeConstExpr(IROp op) -{ - switch( op ) - { - case kIROp_IntLit: - case kIROp_FloatLit: - case kIROp_BoolLit: - case kIROp_Add: - case kIROp_Sub: - case kIROp_Mul: - case kIROp_Div: - case kIROp_Mod: - case kIROp_Neg: - case kIROp_Construct: - case kIROp_makeVector: - case kIROp_makeArray: - case kIROp_MakeMatrix: - // TODO: more cases - return true; - - default: - return false; - } -} - -bool opCanBeConstExpr(IRInst* value) -{ - // TODO: realistically need to special-case `call` - // operations here, so that we check whether the - // callee function is fixed/known, and if it is - // whether it has been decoared as constant-foldable - - return opCanBeConstExpr(value->op); -} - -void markConstExpr( - PropagateConstExprContext* context, - IRInst* value) -{ - Slang::markConstExpr(context->getBuilder(), value); -} - - -// Propagate `constexpr`-ness in a forward direction, from the -// operands of an instruction to the instruction itself. -bool propagateConstExprForward( - PropagateConstExprContext* context, - IRGlobalValueWithCode* code) -{ - bool anyChanges = false; - for(;;) - { - bool changedThisIteration = false; - for( auto bb = code->getFirstBlock(); bb; bb = bb->getNextBlock() ) - { - for( auto ii = bb->getFirstInst(); ii; ii = ii->getNextInst() ) - { - // Instruction already `constexpr`? Then skip it. - if(isConstExpr(ii)) - continue; - - // Is the operation one that we can actually make be constexpr? - if(!opCanBeConstExpr(ii)) - continue; - - // Are all arguments `constexpr`? - bool allArgsConstExpr = true; - UInt argCount = ii->getOperandCount(); - for( UInt aa = 0; aa < argCount; ++aa ) - { - auto arg = ii->getOperand(aa); - - if( !isConstExpr(arg) ) - { - allArgsConstExpr = false; - break; - } - } - if(!allArgsConstExpr) - continue; - - // Seems like this operation can/should be made constexpr - markConstExpr(context, ii); - changedThisIteration = true; - } - } - - if( !changedThisIteration ) - return anyChanges; - - anyChanges = true; - } -} - -void maybeAddToWorkList( - PropagateConstExprContext* context, - IRInst* gv) -{ - if( !context->onWorkList.Contains(gv) ) - { - context->workList.add(gv); - context->onWorkList.Add(gv); - } -} - -bool maybeMarkConstExpr( - PropagateConstExprContext* context, - IRInst* value) -{ - if(isConstExpr(value)) - return false; - - if(!opCanBeConstExpr(value)) - return false; - - markConstExpr(context, value); - - // TODO: we should only allow function parameters to be - // changed to be `constexpr` when we are compiling "application" - // code, and not library code. - // (Or eventually we'd have a rule that only non-`public` symbols - // can have this kind of propagation applied). - - if(value->op == kIROp_Param) - { - auto param = (IRParam*) value; - auto block = (IRBlock*) param->parent; - auto code = block->getParent(); - - if(block == code->getFirstBlock()) - { - // We've just changed a function parameter to - // be `constexpr`. We need to remember that - // fact so taht we can mark callers of this - // function as `constexpr` themselves. - - for( auto u = code->firstUse; u; u = u->nextUse ) - { - auto user = u->getUser(); - - switch( user->op ) - { - case kIROp_Call: - { - auto inst = (IRCall*) user; - auto caller = as(inst->getParent()->getParent()); - maybeAddToWorkList(context, caller); - } - break; - - default: - break; - } - } - } - } - - return true; -} - -// Propagate `constexpr`-ness in a backward direction, from an instruction -// to its operands. -bool propagateConstExprBackward( - PropagateConstExprContext* context, - IRGlobalValueWithCode* code) -{ - SharedIRBuilder sharedBuilder; - sharedBuilder.module = context->getModule(); - sharedBuilder.session = sharedBuilder.module->session; - - IRBuilder builder; - builder.sharedBuilder = &sharedBuilder; - builder.setInsertInto(code); - - bool anyChanges = false; - for(;;) - { - // Note: we are walking the list of blocks and the instructions - // in each block in reverse order, to maximize the chances that - // we propagate multiple changes in a each pass. - // - // TODO: this should probably all be done with a work list instead, - // but that requires being able to detect instructions vs. other - // values. - - bool changedThisIteration = false; - for( auto bb = code->getLastBlock(); bb; bb = bb->getPrevBlock() ) - { - for( auto ii = bb->getLastInst(); ii; ii = ii->getPrevInst() ) - { - if( isConstExpr(ii) ) - { - // If this instruction is `constexpr`, then its operands should be too. - UInt argCount = ii->getOperandCount(); - for( UInt aa = 0; aa < argCount; ++aa ) - { - auto arg = ii->getOperand(aa); - if(isConstExpr(arg)) - continue; - - if(!opCanBeConstExpr(arg)) - continue; - - if( maybeMarkConstExpr(context, arg) ) - { - changedThisIteration = true; - } - } - } - else if( ii->op == kIROp_Call ) - { - // A non-constexpr call might be calling a function with one or - // more constexpr parameters. We should check if we can resolve - // the callee for this call statically, and if so try to propagate - // constexpr from the parameters back to the arguments. - auto callInst = (IRCall*) ii; - - UInt operandCount = callInst->getOperandCount(); - - UInt firstCallArg = 1; - UInt callArgCount = operandCount - firstCallArg; - - auto callee = callInst->getOperand(0); - - // If we are calling a generic operation, then - // try to follow through the `specialize` chain - // and find the callee. - // - // TODO: This probably shouldn't be required, - // since we can hopefully use the type of the - // callee in all cases. - // - while(auto specInst = as(callee)) - { - auto genericInst = as(specInst->getBase()); - if(!genericInst) - break; - - auto returnVal = findGenericReturnVal(genericInst); - if(!returnVal) - break; - - callee = returnVal; - } - - auto calleeFunc = as(callee); - if(calleeFunc && isDefinition(calleeFunc)) - { - // We have an IR-level function definition we are calling, - // and thus we can propagate `constexpr` information - // through its `IRParam`s. - - auto calleeFuncType = calleeFunc->getDataType(); - - UInt callParamCount = calleeFuncType->getParamCount(); - SLANG_RELEASE_ASSERT(callParamCount == callArgCount); - - // If the callee has a definition, then we can read `constexpr` - // information off of the parameters of its first IR block. - if(auto calleeFirstBlock = calleeFunc->getFirstBlock()) - { - UInt paramCounter = 0; - for(auto pp = calleeFirstBlock->getFirstParam(); pp; pp = pp->getNextParam()) - { - UInt paramIndex = paramCounter++; - - auto param = pp; - auto arg = callInst->getOperand(firstCallArg + paramIndex); - - if(isConstExpr(param)) - { - if(maybeMarkConstExpr(context, arg)) - { - changedThisIteration = true; - } - } - } - } - } - else - { - // If we don't have a concrete callee function - // definition, then we need to extract the - // type of the callee instruction, and try to work - // with that. - // - // Note that this does not allow us to propagate - // `constexpr` information from the body of a callee - // back to call sites. - auto calleeType = callee->getDataType(); - if(auto caleeFuncType = as(calleeType)) - { - auto paramCount = caleeFuncType->getParamCount(); - for( UInt pp = 0; pp < paramCount; ++pp ) - { - auto paramType = caleeFuncType->getParamType(pp); - auto arg = callInst->getOperand(firstCallArg + pp); - if( isConstExpr(paramType) ) - { - if( maybeMarkConstExpr(context, arg) ) - { - changedThisIteration = true; - } - } - } - } - } - } - } - - if( bb != code->getFirstBlock() ) - { - // A parameter in anything butr the first block is - // conceptually a phi node, which means its operands - // are the corresponding values from the terminating - // branch in a predecessor block. - - UInt paramCounter = 0; - for( auto pp = bb->getFirstParam(); pp; pp = pp->getNextParam() ) - { - UInt paramIndex = paramCounter++; - - if(!isConstExpr(pp)) - continue; - - for(auto pred : bb->getPredecessors()) - { - auto terminator = pred->getLastInst(); - if(terminator->op != kIROp_unconditionalBranch) - continue; - - UInt operandIndex = paramIndex + 1; - SLANG_RELEASE_ASSERT(operandIndex < terminator->getOperandCount()); - - auto operand = terminator->getOperand(operandIndex); - if( maybeMarkConstExpr(context, operand) ) - { - changedThisIteration = true; - } - } - } - } - - } - - if( !changedThisIteration ) - return anyChanges; - - anyChanges = true; - } -} - -// Validate use of `constexpr` within a function (in particular, -// diagnose places where a value that must be contexpr depends -// on a value that cannot be) -void validateConstExpr( - PropagateConstExprContext* context, - IRGlobalValueWithCode* code) -{ - for( auto bb = code->getFirstBlock(); bb; bb = bb->getNextBlock() ) - { - for( auto ii = bb->getFirstInst(); ii; ii = ii->getNextInst() ) - { - if(isConstExpr(ii)) - { - // For an instruction that must be `constexpr`, we need - // to ensure that its argumenst are all `constexpr` - - UInt argCount = ii->getOperandCount(); - for( UInt aa = 0; aa < argCount; ++aa ) - { - auto arg = ii->getOperand(aa); - - if( !isConstExpr(arg) ) - { - // Diagnose the failure. - - context->getSink()->diagnose(ii->sourceLoc, Diagnostics::needCompileTimeConstant); - - break; - } - } - } - } - } -} - -void propagateConstExpr( - IRModule* module, - DiagnosticSink* sink) -{ - auto session = module->session; - - PropagateConstExprContext context; - context.module = module; - context.sink = sink; - context.sharedBuilder.module = module; - context.sharedBuilder.session = session; - context.builder.sharedBuilder = &context.sharedBuilder; - - - // We need to propagate information both forward and backward. - // - // In the forward direction we need to check if all of the operands - // to an instruction are `constexpr` *and* if the operation is - // one that can conceptually be "promoted" to the constexpr rate. - // - // In the backward direction, if an instruction has already been - // marked as needing to be `constexpr`, then its operands had - // better be too. - // - // The backward direction needs to be interprocedural, because - // a parameter to a function might be `constexpr`, so that callers - // of that function would need to be marked too. If backwards - // propagation in any of the callers leads to some of their - // parameters being marked constexpr, then we would need to - // revisit their callers. - - // We will build an initial work list with all of the global values in it. - - for( auto ii : module->getGlobalInsts() ) - { - maybeAddToWorkList(&context, ii); - } - - // We will iterate applying propagation to one global value at a time - // until we run out. - while( context.workList.getCount() ) - { - auto gv = context.workList[0]; - context.workList.fastRemoveAt(0); - context.onWorkList.Remove(gv); - - switch( gv->op ) - { - default: - break; - - case kIROp_Func: - case kIROp_GlobalVar: - case kIROp_GlobalConstant: - { - IRGlobalValueWithCode* code = (IRGlobalValueWithCode*) gv; - - for( ;;) - { - bool anyChange = false; - if( propagateConstExprForward(&context, code) ) - { - anyChange = true; - } - if( propagateConstExprBackward(&context, code) ) - { - anyChange = true; - } - if(!anyChange) - break; - } - } - break; - } - } - - // Okay, we've processed all our functions and found a steady state. - // Now we need to try and issue diagnostics for any IR values where - // we find that they are *required* to be `constexpr`, but *cannot* - // be, for some reason. - - for(auto ii : module->getGlobalInsts()) - { - switch( ii->op ) - { - default: - break; - - case kIROp_Func: - case kIROp_GlobalVar: - case kIROp_GlobalConstant: - { - IRGlobalValueWithCode* code = (IRGlobalValueWithCode*) ii; - validateConstExpr(&context, code); - } - break; - } - } - -} - -} diff --git a/source/slang/ir-constexpr.h b/source/slang/ir-constexpr.h deleted file mode 100644 index 04f2e59ec..000000000 --- a/source/slang/ir-constexpr.h +++ /dev/null @@ -1,12 +0,0 @@ -// ir-constexpr.h -#pragma once - -namespace Slang -{ - class DiagnosticSink; - struct IRModule; - - void propagateConstExpr( - IRModule* module, - DiagnosticSink* sink); -} diff --git a/source/slang/ir-dce.cpp b/source/slang/ir-dce.cpp deleted file mode 100644 index f1a34bedf..000000000 --- a/source/slang/ir-dce.cpp +++ /dev/null @@ -1,325 +0,0 @@ -// ir-dce.cpp -#include "ir-dce.h" - -#include "ir.h" -#include "ir-insts.h" - -namespace Slang -{ - -struct DeadCodeEliminationContext -{ - // This type implements a simple global DCE pass over - // an entire module. - // - // We start with member variables to stand in for - // the parameters that were passed to the top-level - // `eliminateDeadCode` function. - // - BackEndCompileRequest* compileRequest; - IRModule* module; - - // Our overall process is going to be to determine - // which instructions in the module are "live" - // and then eliminate anything that wasn't found to - // be live. - // - // We will track the liveness state by keeping - // a set of all instructions we have so far determined - // to be live. - // - HashSet liveInsts; - - // Querying whether an instruction has been - // determined to be live is easy. - // - bool isInstLive(IRInst* inst) - { - // The only wrinkle is that we want to safeguard - // against a null instruction (there are some - // corner cases where we still construct IR - // instructions with a null type). - // - if(!inst) return false; - - return liveInsts.Contains(inst); - } - - // We are going to do an iterative analysis - // where we mark instructions we know are - // live, and then see if that can help us - // identify any other instructions that - // must also be live. - // - // For this, we will use a work list of - // instructions that have been marked - // as live, but for which we haven't - // looked at their impact on other - // instructions. - // - List workList; - - // When we discover that an instruction seems - // to be live, we will add it to our set, - // and also the work list, but only if we - // haven't done so previously. - // - void markInstAsLive(IRInst* inst) - { - // Again, we safeguard against null instructions - // just in case. - // - if(!inst) return; - - if(liveInsts.Contains(inst)) - return; - liveInsts.Add(inst); - workList.add(inst); - } - - // Given the basic infrastructrure above, let's - // dive into the task of actually finding all - // the live code in a module. - // - void processModule() - { - // First of all, we know that the root module instruction - // should be considered as live, because otherwise - // we'd end up eliminating it, so that is a - // good place to start. - // - markInstAsLive(module->getModuleInst()); - - // Marking the module as live should have - // seeded our work list, so we can now start - // processing entries off of our work list - // until it goes dry. - // - while( workList.getCount() ) - { - auto inst = workList.getLast(); - workList.removeLast(); - - // At this point we know that `inst` is live, - // and we want to start considering which other - // instructions must be live because of that - // knowlege. - // - // A first easy case is that the parent (if any) - // of a live instruction had better be live, or - // else we might delete the parent, and - // the child with it. - // - markInstAsLive(inst->getParent()); - - // Next the type of a live instruction, and all - // of its operands must also be live, or else - // we won't be able to compute its value. - // - markInstAsLive(inst->getFullType()); - UInt operandCount = inst->getOperandCount(); - for( UInt ii = 0; ii < operandCount; ++ii ) - { - markInstAsLive(inst->getOperand(ii)); - } - - // Finally, we need to consider the children - // and decorations of the instruction. - // - // Note that just because an instruction is - // live doesn't mean its children must be, or - // else we'd never eliminate *anything* (we - // marked the whole module as live, and everything - // is a transitive child of the module). - // - // Decorations, in contrast, are always live if their - // parents are (because we don't want to silently drop - // decorations). It is still important to *mark* - // decorations as live, because they have operands, - // and those operands need to be marked as live. - // We will fold decorations into the same loop - // as children for simplicity. - // - // To keep the code here simple, we'll defer the - // decision of whether a child (or decoration) - // should be live when its parent is to a subroutine. - // - for( auto child : inst->getDecorationsAndChildren() ) - { - if(shouldInstBeLiveIfParentIsLive(child)) - { - // In this case, we know `inst` is live and - // its `child` should be live if its parent is, - // so the `child` must be live too. - // - markInstAsLive(child); - } - } - } - - // If our work list runs dry, that means we've reached a steady - // state where everything that is transitively relevant to - // the "outputs" of the module has been marked as live. - // - // Now we can simply walk through all of our instructions - // recursively and eliminate those that are "dead" by - // virtue of not having been found live. - // - eliminateDeadInstsRec(module->getModuleInst()); - } - - void eliminateDeadInstsRec(IRInst* inst) - { - // Given the instruction `inst` we need to eliminate - // any dead code at, or under it. - // - // The easy case is if `inst` is dead (that is, not live). - // - if( !isInstLive(inst) ) - { - // We can simply remove and deallocate `inst` because it is - // dead, and not worry about any of its descendents, - // because they must have been dead too (since we always - // mark the parent of a live instruction as live). - // - inst->removeAndDeallocate(); - } - else - { - // If `inst` is live, then we need to deal with the possibility - // that its children/decorations (or descendents in general) - // might still be dead. - // - // The biggest wrinkle is that we walk the linked list of - // children/decorations a bit carefully, using a temporary - // to hold the next node, in case we eliminate one of - // the children as we go. - // - IRInst* next = nullptr; - for( IRInst* child = inst->getFirstDecorationOrChild(); child; child = next ) - { - next = child->getNextInst(); - eliminateDeadInstsRec(child); - } - } - } - - // Now we come to the decision procedure we put off before: - // should a given `inst` be live if its parent is? - // - bool shouldInstBeLiveIfParentIsLive(IRInst* inst) - { - // The main source of confusion/complexity here is that - // we are using the same routine to decide: - // - // * Should some ordinary instruction in a basic block be kept around? - // * Should a basic block in some function be kept around? - // * Should a function/type/variable in a module be kept around? - // - // Still, there are a few basic patterns we can observe. - // First, if `inst` is an instruction that might have some effects - // when it is executed, then we should keep it around. - // - if(inst->mightHaveSideEffects()) - return true; - // - // The `mightHaveSideEffects` query is conservative, and will - // return `true` as its default mode, so once we are past that - // query we know that `inst` is either something "structural" - // (that makes up the program) rather than executable, or it - // is executable but was on a white list of things that are - // safe to eliminate. - - // Most top-level objects (functions, types, etc.) obviously - // do *not* have side effects. That creates the risk that - // we'll just go ahead and eliminate every single function/type - // in a module. There needs to be a way to identify the - // functions we want to keep around, and for right now - // that is handled with the `[keepAlive]` decoration. - // - if(inst->findDecorationImpl(kIROp_KeepAliveDecoration)) - return true; - // - // TODO: Eventually it would make sense to consider everything - // with an `[export(...)]` decoration as live, but our current - // approach to linking for back-end compilation leaves many - // linkage decorations in place that we seemingly don't need/want. - - // A basic block is an interesting case. Knowing that a function - // is live means that its entry block is live, but the liveness - // of any other blocks is determined by whether they are referenced - // by other instructions (e.g., a branch from one block to - // another). - // - if( auto block = as(inst) ) - { - // To determine whether this is the first block in its - // parent function (or what-have-you) we can simply - // check if there is a previous block before it. - // - auto prevBlock = block->getPrevBlock(); - return prevBlock == nullptr; - } - - // There are a few special cases of "structural" instructions - // that we don't want to eliminate, so we'll check for those next. - // - switch( inst->op ) - { - // Function parameters obviously shouldn't get eliminated, - // even if nothing references them, and block parameters - // (phi nodes) will be considered live when their block is, - // just so that we don't have to deal with any complications - // around re-writing the relevant inter-block argument passing. - // - // TODO: A smarter DCE pass could deal with this case more - // carefully, or we could improve the interprocedural SCCP - // pass to deal with block parameters instead. - // - case kIROp_Param: - return true; - - // IR struct types and witness tables are currently kludged - // so that they have child instructions that represent their - // entries (effectively `(key,value)` pairs), and those child - // instructions are never directly referenced (e.g., an access - // to a struct field references the *key* but not the `(key,value)` - // pair that is the `IRField` instruction. - // - // TODO: at some point the IR should use a different representation - // for struct types and witness tables that does away with - // this problem. - // - case kIROp_StructField: - case kIROp_WitnessTableEntry: - return true; - - default: - break; - } - - // If none of the explicit cases above matched, then we will consider - // the instruction to not be live just because its parent is. Further - // analysis could still lead to a change in the status of `inst`, if - // an instruction that uses it as an operand is marked live. - // - return false; - } -}; - -// The top-level function for invoking the DCE pass -// is straighforward. We set up the context object -// and then defer to it for the real work. -// -void eliminateDeadCode( - BackEndCompileRequest* compileRequest, - IRModule* module) -{ - DeadCodeEliminationContext context; - context.compileRequest = compileRequest; - context.module = module; - - context.processModule(); -} - -} diff --git a/source/slang/ir-dce.h b/source/slang/ir-dce.h deleted file mode 100644 index 6089b404a..000000000 --- a/source/slang/ir-dce.h +++ /dev/null @@ -1,19 +0,0 @@ -// ir-dce.h -#pragma once - -namespace Slang -{ - class BackEndCompileRequest; - struct IRModule; - - /// Eliminate "dead" code from the given IR module. - /// - /// This pass is primarily designed for flow-insensitive - /// "global" dead code elimination (DCE), such as removing - /// types that are unused, functions that are never called, - /// etc. - /// - void eliminateDeadCode( - BackEndCompileRequest* compileRequest, - IRModule* module); -} diff --git a/source/slang/ir-dominators.cpp b/source/slang/ir-dominators.cpp deleted file mode 100644 index 488e67724..000000000 --- a/source/slang/ir-dominators.cpp +++ /dev/null @@ -1,720 +0,0 @@ -// ir-dominators.cpp -#include "ir-dominators.h" - -// -// This file implements the public interface of the `IRDominatorTree` type, -// to enable queries on dominance relationships in a control-flow graph. -// -// It also implements computation of the dominator tree for a CFG using -// the algorithm presented in "A Simple, Fast Dominance Algorithm" by -// Keith D. Cooper, Timothy J. Harvey, and Ken Kennedy. -// -// The algorithm is *not* the most efficinet one, asymptotically, but -// it is one that is easy to implement and explain, and so we favor it -// in order to get something up and running with a reasonable level of -// confidence that the results are correct. -// - -#include "ir.h" - -namespace Slang { - -// -// Let's start with the implementation of the public API for `IRDominatorTree` -// - -// IRDominatorTree - -bool IRDominatorTree::immediatelyDominates(IRBlock* dominator, IRBlock* dominated) -{ - // To test if block A immediately dominates block B, we just - // check if A is the (one and only) immediate dominator of B. - return dominator == getImmediateDominator(dominated); -} - -bool IRDominatorTree::properlyDominates(IRBlock* dominator, IRBlock* dominated) -{ - // Because of how we laid out the tree, we can test if one node - // properly dominates another in constant time. - // - // We simply need to test if the node index for `dominated` falls - // in the range of indices for the descendents of `dominator`. - // - - Int dominatorIndex = getBlockIndex(dominator); - Int dominatedIndex = getBlockIndex(dominated); - Node& dominatorNode = nodes[dominatorIndex]; - - return (dominatedIndex >= dominatorNode.beginDescendents) - && (dominatedIndex < dominatorNode.endDescendents); -} - -bool IRDominatorTree::dominates(IRBlock* dominator, IRBlock* dominated) -{ - // We need to check two cases here. - // - // First, a node always dominated itself, so if the blocks are - // the the same, then we are done: - // - if(dominator == dominated) - return true; - // - // Otherwise, for distinct blocks we just check for - // proper dominance: - // - return properlyDominates(dominator, dominated); -} - -IRBlock* IRDominatorTree::getImmediateDominator(IRBlock* block) -{ - // The immediate dominator of a block is its parent - // in the dominator tree. Looking this up is straightforward, - // and we just need to be a bit careful to deal with - // invalid node indices. - - Int blockIndex = getBlockIndex(block); - if(blockIndex == kInvalidIndex) return nullptr; - - Int parentIndex = nodes[blockIndex].parent; - if(parentIndex == kInvalidIndex) return nullptr; - - return nodes[parentIndex].block; -} - -IRDominatorTree::DominatedList IRDominatorTree::getImmediatelyDominatedBlocks(IRBlock* block) -{ - // Because of our representation, the immediately dominated blocks - // for a node are contiguous, and we store their range in the - // node already. - - Int blockIndex = getBlockIndex(block); - if(blockIndex == kInvalidIndex) return DominatedList(); - - Node& node = nodes[blockIndex]; - return DominatedList( - this, - node.beginDescendents, - node.endChildren); -} - -IRDominatorTree::DominatedList IRDominatorTree::getProperlyDominatedBlocks(IRBlock* block) -{ - // Because of our representation, the properly dominated blocks - // for a node are contiguous, and we store their range in the - // node already. - - Int blockIndex = getBlockIndex(block); - if(blockIndex == kInvalidIndex) return DominatedList(); - - Node& node = nodes[blockIndex]; - return DominatedList( - this, - node.beginDescendents, - node.endDescendents); -} - -Int IRDominatorTree::getBlockIndex(IRBlock* block) -{ - Int index = kInvalidIndex; - if(!mapBlockToIndex.TryGetValue(block, index)) - { - SLANG_UNEXPECTED("block was not present in dominator tree"); - } - return index; -} - -// IRDominatorTree::DominatedList - -IRDominatorTree::DominatedList::DominatedList() - : mTree(nullptr) - , mBegin(0) - , mEnd(0) -{} - -IRDominatorTree::DominatedList::DominatedList( - IRDominatorTree* tree, - Int begin, - Int end) - : mTree(tree) - , mBegin(begin) - , mEnd(end) -{} - -IRDominatorTree::DominatedList::Iterator IRDominatorTree::DominatedList::begin() const -{ - return Iterator(mTree, mBegin); -} - -IRDominatorTree::DominatedList::Iterator IRDominatorTree::DominatedList::end() const -{ - return Iterator(mTree, mEnd); -} - - -// IRDominatorTree::DominatedList::Iterator - -IRDominatorTree::DominatedList::Iterator::Iterator() - : mTree(nullptr) - , mIndex(0) -{} - -IRDominatorTree::DominatedList::Iterator::Iterator( - IRDominatorTree* tree, - Int index) - : mTree(tree) - , mIndex(index) -{} - -IRBlock* IRDominatorTree::DominatedList::Iterator::operator*() const -{ - return mTree->nodes[mIndex].block; -} - -void IRDominatorTree::DominatedList::Iterator::operator++() -{ - mIndex++; -} - -bool IRDominatorTree::DominatedList::Iterator::operator==(Iterator const& that) const -{ - SLANG_ASSERT(mTree == that.mTree); - return mIndex == that.mIndex; -} - -// -// The dominance computation algorithm we are using relies on being able to compute -// a reverse postorder traversal of the nodes in the CFG, which is done using a depth-first -// search (DFS). We don't currently have infrastructure for DFS in the compiler, so -// we will implement it here for now, and plan to move it into its own file once -// we have a second use case. -// - -/// A base "visitor" class for use in depth-first search algorithms on an IR CFG. -struct DepthFirstSearchContext -{ - /// The blocks in the CFG that we've already visited. - HashSet visited; - - /// Walk a (previously unvisited) block. - /// - /// This will perform any pre-order actions on the block, - /// then recursively visit its (unvisited) successors, and - /// then perform any post-actions. - /// - void walk(IRBlock* block) - { - visited.Add(block); - preVisit(block); - for(auto succ : block->getSuccessors()) - { - if(!visited.Contains(succ)) - { - walk(succ); - } - } - postVisit(block); - } - - /// Walk the blocks in a function (or other code-bearing value). - void walk(IRGlobalValueWithCode* code) - { - auto root = code->getFirstBlock(); - if(!root) - return; - walk(root); - } - - /// Overridable action to perform on first entering a CFG node. - virtual void preVisit(IRBlock* /*block*/) {} - - /// Overridable action to perform on exiting a CFG node - virtual void postVisit(IRBlock* /*block*/) {} -}; - -// -// With DFS traversal factored out, computing a post-order walk -// of the CFG is a simple matter of defining a visitor that appends -// to an order as a post-action: -// - -/// A visitor that computes a postorder traversal for a CFG. -struct PostorderComputationContext : public DepthFirstSearchContext -{ - /// List to append the computed order onto - List* order; - - virtual void postVisit(IRBlock* block) SLANG_OVERRIDE - { - order->add(block); - } -}; - -/// Compute a postorder traversal of the blocks in `code`, writing the resulting order to `outOrder`. -void computePostorder(IRGlobalValueWithCode* code, List& outOrder) -{ - PostorderComputationContext context; - context.order = &outOrder; - context.walk(code); -} - -// -// With the preliminaries out of the way, we are ready to implement -// the dominator tree construction algorithm as described by Cooper, Harvey, and Kennedy. -// The actual code for the algorithm is given in Figure 3 of the paper. -// -// We will wrap the subroutines of their algorithm in a `struct` type -// to allow the temporary structures to be shared. -// -struct DominatorTreeComputationContext -{ - // We will use signed integers to represent the "name" of a block. - // The integers will reflect the a postorder traversal, and this - // property will be exploited in the `intersect()` function. - // - typedef Int BlockName; - // - // An invalid/undefined block name will be represented as -1. - // - static const BlockName kUndefined = BlockName(-1); - // - // We will explicitly store the blocks visited in the postorder - // traversal, so that we can look up a block based on its "name" - // - List postorder; - - // - // We need a way to map our actual IR blocks to their names for - // the purpose of this algorithm. This mapping step adds overhead, - // but it seems unavoidable unless we also translate the CFG itself - // to an index-based representation. - // - Dictionary mapBlockToName; - BlockName getBlockName(IRBlock* block) - { - return mapBlockToName[block]; - } - - // - // The algorithm iteratively builds up an array `doms` that upon - // completion will directly encode the immediate dominator for each - // node. During the iterative steps it is used to implicitly encode - // a representation of the set of dominators for each node. - // - List doms; - - - // - // Here we get to the meat of the algorithm presented in Cooper et al. - // Figure 3: - // - void iterativelyComputeImmediateDominators(IRGlobalValueWithCode* code) - { - // First we compute the postorder traversal order for the blocks in the CFG. - computePostorder(code, postorder); - - // We will initialize our map from the block objects to their "name" - // (index in the traversal order), before moving on. - BlockName blockCount = BlockName(postorder.getCount()); - for(BlockName bb = 0; bb < blockCount; ++bb) - { - mapBlockToName[postorder[bb]] = bb; - } - - // Next we initialize the `doms` array that we will iteratively turn - // into an encoding of the dominator tree. - doms.setCount(blockCount); - for(BlockName bb = 0; bb < blockCount; ++bb) - { - doms[bb] = kUndefined; - } - - // The start node is special, since it is the root of the dominator tree. - // Technically it doesn't have an immediate dominator, but we will set - // its entry in `doms` to refer to itself, to indicate that we are done - // processing the given node. - // - BlockName startNode = getBlockName(code->getFirstBlock()); - doms[startNode] = startNode; - - // Given that we computed a postorder traversal of the graph, we know - // that the start node should be the last one in the computed order. - // - SLANG_ASSERT(startNode == blockCount - 1); - - // We are using an iterative algorithm, so we will detect that we - // have reached a fixed point when we hit an iteration where nothing - // changes. - // - bool changed = true; - while(changed) - { - changed = false; - - // The algorithm specifies that we should walk through the blocks - // in *reverse* postorder, since this speeds up convergence. - // Because we've numbered the blocks in postorder, walking them - // in reverse numerical order will do the trick. - // - // We don't want to include the start node in our iteration - // (since we already know its dominators), and because we know - // that the start node is always the last in the order (`blockCount - 1`) - // we can just start at the next node after it (`blockCount - 2`). - // - // Note: it is important that we are using signed integers for - // block numbers here, since we will drop below zero before exiting - // the loop, and if the CFG had only a single block, then our *starting* - // block index would be `-1`. - // - for(auto b = blockCount - 2; b >= 0; --b) - { - // We are walking through block indices, but the predecessor - // lists are encoded in the IR blocks themselves. - // - IRBlock* block = postorder[b]; - - // The algorithm description in the paper says to pick the - // initial value for the `new_idom` variable from the "first - // (processed) predecessor of b (pick one)". - // After that step, the algorithm walks over the remaining - // predecessors, and for the ones that have a valid entry - // in the `doms` array, performs an intersection of their - // implicitly-represented dominator sets. - // - // The paper doesn't precisely clarify what they mean by - // a "processed" predecessor, but it seems to mean one that - // has a valid value in the `doms` array, which is what - // the subsequent loop is already checking. - // - // We are going to fold this logic together into a single loop. - // We will start with an invalid/undefined value for - // `new_idom`, which represents our best guess at the - // immediate dominator for block `b`: - // - BlockName new_idom = kUndefined; - - // Now we will loop over *all* of the predecessors, ... - for(auto pred : block->getPredecessors()) - { - // ... and skip those that haven't been "processed". - BlockName p = getBlockName(pred); - BlockName dominatorOfPredecessor = doms[p]; - if(dominatorOfPredecessor == kUndefined) - continue; - - // When we encounter the first "processed" predecessor, - // we can initialize the variable tracking our best - // guess at the immediate dominator. - // - if(new_idom == kUndefined) - { - new_idom = p; - } - // - // Otherwise, we need to merge information between - // the predecessor `p` and our best-guess immediate - // dominator `new_idom`. We need a node that dominates - // both of them to be the immediate dominator of `b`. - // - else - { - new_idom = intersect(p, new_idom); - } - } - - // After we've computed a new best guess at the immediate - // dominator for `b`, we need to see if the computed - // value differs from what we'd previously stored in the - // `doms` array. If anything changed, then we haven't - // converged yet, and we need to keep going. - // - BlockName oldDominator = doms[b]; - if(oldDominator != new_idom) - { - doms[b] = new_idom; - changed = true; - } - } - } - - // Upon exiting the loop, things should have converged with - // the `doms` array being an explicit encoding of the immediate - // dominator for each node, with one small error: there is no - // immediate dominator for the start node: - doms[startNode] = kUndefined; - } - - // - // The algorithm above relied on a utility routine `intersect()` that - // is implicitly used to compute intersections between sets of nodes, - // but explicitly takes the form of a routine that computes a common - // parent in the dominator tree for two nodes. - // - // We present that subroutine here, almost identical to how it - // is presented in Cooper et al. Figure 3: - // - BlockName intersect(BlockName b1, BlockName b2) - { - // We need to find a common ancestor of both `b1` and `b2`, - // and will do this by tracking two "fingers," each initially - // pointing at one node, and then iteratively move the finger - // that is furthest to the "left" (earlier in the postorder - // traversal to the left until) to the "right" (by moving - // the immediate dominator of the node we are pointing at), - // until the two fingers are pointing at the same place. - // - // Termination is guaranteed because we are always moving the - // fingers from a node to its immediate dominator, and the - // entry node is guaranteed to be at the root of the dominator - // tree. - // - // The use of the postorder here relies on the (subtle) fact - // that the immediate dominator of a node must come later - // in a postorder traversal. - // - BlockName finger1 = b1; - BlockName finger2 = b2; - - while(finger1 != finger2) - { - while(finger1 < finger2) - finger1 = doms[finger1]; - while(finger2 < finger1) - finger2 = doms[finger2]; - } - return finger1; - } - - // - // Now that we've implemented Cooper et al. fairly close to how - // it was presented, we can build an array encoding the immediate - // dominator relationship. We still need to expand that array - // into an encoding that lets us efficiently answer queries - // about dominance. - // - // In order to do that, we need to expand the information we - // have built on each block (currently just an immediate dominator) - // into a bit more detail: - // - struct BlockInfo - { - // How many children does this node/block have in the dominator tree? - Int childCount = 0; - - // How many indirect (non-child) descendents? - Int indirectDescendentCount = 0; - - // What is the 0-based offset of this node among all the children of its parent? - Int childOffsetInParent = 0; - - // What is the 0-based offset for this node's descendent list, - // among all the children in its parent? - Int descendentOffsetInParent = 0; - - Int nodeIndex = 0; - Int firstDescendentIndex = 0; - }; - // - - RefPtr createDominatorTree(IRGlobalValueWithCode* code) - { - // We first run the Cooper et al. algorithm to compute the `doms` array - // which encodes immediate dominators. - // - iterativelyComputeImmediateDominators(code); - - // We will build some intermediate information on each - // block to help us fill out the tree. - BlockName blockCount = BlockName(doms.getCount()); - List blockInfos; - for(BlockName bb = 0; bb < blockCount; ++bb) - { - blockInfos.add(BlockInfo()); - } - - // We will propagate layout information in two passes over the tree. - // - // First we will perform a "bottom up" pass that will accumulate - // the number of children and the total number of descendents for - // each node, and also assign each child its relative offsets within - // the storage for its parent. - // - // Because our blocks are ordered in postorder, we can do this - // bottom-up walk just by iterating over them in the given order. - // - for(BlockName bb = 0; bb < blockCount; ++bb) - { - BlockName parent = doms[bb]; - if(parent == kUndefined) - continue; - - // For our iteration order to make sense, we need to be certain - // that parent nodes come after their child nodes in the postorder traversal. - SLANG_ASSERT(parent > bb); - - // Compute the 0-based index of this child among all the children - // with the same parent, and increment its child count. - blockInfos[bb].childOffsetInParent = blockInfos[parent].childCount; - blockInfos[parent].childCount++; - - // Our layout for the descendents of a node will put all the immediate - // child nodes contiguously first, followed by their descendents (in contiguous blocks). - // - // We need to compute an offset for where the descendents of this node will - // be stored, within the overall space carved out for the "indirect" descendents - // of the parent node. - // - blockInfos[bb].descendentOffsetInParent = blockInfos[parent].indirectDescendentCount; - // - // When adding up the indirect descendents of `parent`, we need to include both - // the direct and indirect descendents of our node `bb`. - blockInfos[parent].indirectDescendentCount += blockInfos[bb].childCount - + blockInfos[bb].indirectDescendentCount; - } - // - // The next pass is a top-down pass that uses the accumulated - // information to assign absolute indices to each node. - // - // For each node, we want to compute its absolute index in - // the overall array of nodes, and then we also want to compute - // the index where its first descendent node will be placed - // (which can then be used by child nodes to compute their - // index). - // - // The start node in the CFG is special, and will always get - // index zero, with its first desecendent at index 1. - // - BlockName startBlock = getBlockName(code->getFirstBlock()); - blockInfos[startBlock].nodeIndex = 0; - blockInfos[startBlock].firstDescendentIndex = 1; - // - // For the remaining nodes, we'll compute them in a top-down - // pass (using reverse postorder). - // - for(BlockName bb = blockCount-1; bb >= 0; --bb) - { - // We will skip nodes without a parent in the dominator tree. - // This should really only be the start node, but it might - // happen that we have some unreachable nodes that shouldn't - // appear in the dominator tree at all. - // - // TODO: make sure we either handle those correctly, or - // else add a pass to eliminate unreachable blocks first. - // - BlockName parent = doms[bb]; - if(parent == kUndefined) - continue; - - // The absolute index of a node is the absolute index for its - // parent's descendent list, plus the relative offset of this - // child node in its parent. - // - blockInfos[bb].nodeIndex = blockInfos[parent].firstDescendentIndex - + blockInfos[bb].childOffsetInParent; - - // The other descendents of a node are always laid out in the space - // after its immediate children. Thus, the index for where this node - // will place its descendents (direct + indirect) must come after - // the storage for the children of the parent. - // - blockInfos[bb].firstDescendentIndex = blockInfos[parent].firstDescendentIndex - + blockInfos[parent].childCount - + blockInfos[bb].descendentOffsetInParent; - } - - // We now have all the information we need, and can start to fill in - // the actual `IRDominatorTree` structure with the encoded information. - // - RefPtr dominatorTree = new IRDominatorTree(); - dominatorTree->code = code; - dominatorTree->nodes.setCount(blockCount); - - // We will iterate over all of the blocks, and fill in the corresponding - // dominator tree node for each. - // - // Note that the number of the blocks (in postorder) and the numbering - // of the nodes (in breadth-first order) will not match, so we have - // to be careful around whehter we are working with a block index/name, - // or a node index. - // - for(BlockName bb = 0; bb < blockCount; ++bb) - { - // Find the IR block, look up our pre-computed information, - // and find the corresponding node in the dominator tree. - // - IRBlock* block = postorder[bb]; - BlockInfo const& blockInfo = blockInfos[bb]; - Int nodeIndex = blockInfo.nodeIndex; - IRDominatorTree::Node& node = dominatorTree->nodes[nodeIndex]; - - // We will now start filling in the node. Filling in the block is - // trial, and while we are at it we can add an entry to the mapping - // from the block to the node index. - // - node.block = block; - dominatorTree->mapBlockToIndex.Add(block, nodeIndex); - - // Filling in the parent is easy enough, just with the detail that - // we need to handle the invalid case explicitly (for a node with - // no parent), and need to carefully map the block index `parent` - // over to its corresponding node index. - // - BlockName parent = doms[bb]; - node.parent = parent == kUndefined ? IRDominatorTree::kInvalidIndex : blockInfos[parent].nodeIndex; - - // Finally we need to compute the range information to use for the - // descendents (both immediate children and indirect descendents). - // - // All of the relevant information was computed in our two passes - // above, so all that has to happen here is adding together the - // absolute start index for the descendent range with the counts - // we accumulated. - // - Int beginDescendents = blockInfo.firstDescendentIndex; - Int endChildren = beginDescendents + blockInfo.childCount; - // - // The indirect descendents of a node will always come after - // its direct descenents. - // - Int endDescendents = endChildren + blockInfo.indirectDescendentCount; - node.beginDescendents = beginDescendents; - node.endChildren = endChildren; - node.endDescendents = endDescendents; - } - -#if 0 - // Let's do some ad hoc validation here, just to be sure we built the - // data structure reasonably. - for(BlockName ii = 0; ii < blockCount; ++ii) - { - for(BlockName jj = 0; jj < blockCount; ++jj) - { - IRBlock* i = postorder[ii]; - IRBlock* j = postorder[jj]; - - SLANG_RELEASE_ASSERT(dominatorTree->immediatelyDominates(i, j) == (ii == doms[jj])); - - Int dd = jj; - while(dd != kUndefined) - { - if(dd == ii) - break; - dd = doms[dd]; - } - SLANG_RELEASE_ASSERT(dominatorTree->dominates(i, j) == (dd != kUndefined)); - - } - } -#endif - - return dominatorTree; - } -}; - - -RefPtr computeDominatorTree(IRGlobalValueWithCode* code) -{ - DominatorTreeComputationContext context; - return context.createDominatorTree(code); -} - -} diff --git a/source/slang/ir-dominators.h b/source/slang/ir-dominators.h deleted file mode 100644 index 936e9780a..000000000 --- a/source/slang/ir-dominators.h +++ /dev/null @@ -1,162 +0,0 @@ -// ir-dominators.h -#pragma once - -#include "../core/basic.h" - -namespace Slang -{ - struct IRBlock; - struct IRGlobalValueWithCode; - - /// The computed dominator tree for an IR control flow graph. - struct IRDominatorTree : public RefObject - { - /// The function or other code-bearing value for which the dominator tree was computed. - IRGlobalValueWithCode* code; - - /// Does the first block dominate the second? - /// - /// A block A dominates block B iff every control-flow path - /// that starts at the entry block of the CFG and passes - /// through B must first pass through A. - /// - bool dominates(IRBlock* dominator, IRBlock* dominated); - - /// Does the first block properly dominate the second? - /// - /// Block A properly dominates block B iff A dominates B - /// and A != B. - /// - bool properlyDominates(IRBlock* dominator, IRBlock* dominated); - - /// Does the first block immediately dominate the second? - /// - /// Block A immediately dominates block B iff A dominates B - /// and for any block X that dominates B, X also dominates A. - /// - bool immediatelyDominates(IRBlock* dominator, IRBlock* dominated); - - /// Get the immediate dominator (idom) of a block. - /// - /// This is the parent of `block` in the dominator tree. - IRBlock* getImmediateDominator(IRBlock* block); - - /// An iterable collection of the blocks dominated by a specific block - struct DominatedList; - - /// Get the blocks that a block immediately dominates. - /// - /// These are the children of the block in the dominator tree. - DominatedList getImmediatelyDominatedBlocks(IRBlock* block); - - /// Get the blocks that a block properly dominates. - /// - /// These are the descendents of the block in the dominator tree. - DominatedList getProperlyDominatedBlocks(IRBlock* block); - - struct DominatedList - { - public: - DominatedList(); - - struct Iterator - { - public: - Iterator(); - - IRBlock* operator*() const; - void operator++(); - bool operator==(Iterator const& that) const; - - private: - friend struct DominatedList; - Iterator( - IRDominatorTree* tree, - Int index); - - IRDominatorTree* mTree; - Int mIndex; - }; - - Iterator begin() const; - Iterator end() const; - - private: - friend struct IRDominatorTree; - DominatedList( - IRDominatorTree* tree, - Int begin, - Int end); - - IRDominatorTree* mTree; - Int mBegin; - Int mEnd; - }; - - private: - // - // The layout of an `IRDominatorTree` uses a dense array for all of the nodes in the CFG. - // We therefore need a way to map an `IRBlock*` pointer over to an index in this array: - // - - /// Map a block to its index in the `nodes` array - Int getBlockIndex(IRBlock* block); - - /// Dictionary used to accelerate `getBlockIndex` - Dictionary mapBlockToIndex; - - // - // In order to accelerate queries on the tree structure, we will order the tree nodes - // carefully, so that all of the descendants of a node are contiguous, with all of - // the immediate children coming first. - // - // Each node thus needs to remember its parent (immediate dominator), and the range - // of indices that represent children and descendents (respectively), with the knowledge - // that the first child and first descendent share the same index. - // - - /// Information about one node in the dominator tree - struct Node - { - /// The block associated with this tree node - IRBlock* block; - - /// Index of the parent node or -1 if no parent - Int parent; - - /// Index of first descendent - Int beginDescendents; - - /// "One after the end" value for range of child node indices. - Int endChildren; - - /// "One after the end" value for range of descendent node indices. - Int endDescendents; - }; - - /// Storage for the dominator tree itself - List nodes; - - /// Value to use for invalid node indices (e.g., - /// when a node has no parent). - static const Int kInvalidIndex = -1; - - // - // The `DominatedList` type needs direct access to all of this - // data in order to provide iteration. - // - friend struct DominatedList; - friend struct DominatedList::Iterator; - // - // The context type we will use to compute the dominator tree - // also needs to be able to access all the fields to initialze - // an `IRDominatorTree` - // - friend struct DominatorTreeComputationContext; - - // TODO: we should probably build/store a postdominator - // tree in the same structure, just to make life simpler. - }; - - RefPtr computeDominatorTree(IRGlobalValueWithCode* code); -} diff --git a/source/slang/ir-entry-point-uniforms.cpp b/source/slang/ir-entry-point-uniforms.cpp deleted file mode 100644 index 5c7cdb5b4..000000000 --- a/source/slang/ir-entry-point-uniforms.cpp +++ /dev/null @@ -1,425 +0,0 @@ -// ir-entry-point-uniforms.cpp -#include "ir-entry-point-uniforms.h" - -#include "ir.h" -#include "ir-insts.h" - -#include "mangle.h" - -namespace Slang -{ - - -// The transformation in this file will solve the problem of taking -// code like the following: -// -// float4 fragmentMain( -// uniform Texture2D t, -// uniform SamplerState s; -// uniform float4 c, -// float2 uv : UV) : SV_Target -// { -// return t.Sample(s, uv) + c; -// } -// -// and transforming into code like this: -// -// struct Params -// { -// Texture2D t; -// SamplerState s; -// float4 c; -// } -// ConstantBuffer params; -// -// float4 fragmentMain( -// float2 uv : UV) : SV_Target -// { -// return params.t.Sample(params.s, uv) + params.c; -// } -// -// As can be seen in this example, the `uniform` parameters -// declared as entry point parameters have been moved into -// a `struct` declaration that we then use to declare a global -// shader parameter that is a `ConstantBuffer`. We then -// rewrite references to those parameters to refer to the -// contents of the new constant buffer instead. -// -// We perform this transformation after the target-specific -// linking step, because that will have attached layout information -// to the entry point and its parameters. We need that layout -// information so that we can: -// -// * Identify which parameters are uniform vs. varying. -// * Have an appropriate layout to attached to the synthesized -// global shader parameter `params`. -// -// One additional wrinkle this pass has to deal with is that -// in the case where the shader doesn't have any "ordinary" -// uniform parameters like `c` (e.g., it only has resource/object -// parameters), we do *not* wrap the parameter `struct` in -// a `ConstantBuffer`. For example, suppose we have: -// -// float4 fragmentMain( -// uniform Texture2D t, -// uniform SamplerState s; -// float2 uv : UV) : SV_Target -// { -// return t.Sample(s, uv); -// } -// -// In this case the output of the transformation should be: -// -// struct Params -// { -// Texture2D t; -// SamplerState s; -// } -// Params params; -// -// float4 fragmentMain( -// float2 uv : UV) : SV_Target -// { -// return params.t.Sample(params.s, uv) + params.c; -// } -// -// Note that this pass should always come before type legalization, -// which will take responsibility for turning a variable like -// `params` above into individual variables for the `t` and -// `s` fields. - -// The overall structure here is similar to many other IR passes. -// We define a "context" structure to encapsulate the pass. -// -struct MoveEntryPointUniformParametersToGlobalScope -{ - // We'll hang on to the module we are processing, - // so that we can refer to it when setting up `IRBuilder`s. - // - IRModule* module; - - // We will process a whole module by visiting all - // its global functions, looking for entry points. - // - void processModule() - { - // Note that we are only looking at true global-scope - // functions and not functions nested inside of - // IR generics. When using generic entry points, this - // pass should be run after the entry point(s) have - // been specialized to their generic type parameters. - - for( auto inst : module->getGlobalInsts() ) - { - // We are only interested in entry points. - // - // Every entry point must be a function. - // - auto func = as(inst); - if( !func ) - continue; - - // Entry points will always have the `[entryPoint]` - // decoration to differentiate them from ordinary - // functions. - // - // TODO: we could make `IREntryPoint` a subclass of - // `IRFunc` if desired, to avoid having to attach - // an explicit decoration to identify them. - // - if( !func->findDecorationImpl(kIROp_EntryPointDecoration) ) - continue; - - // If we fine a candidate entry point, then we - // will process it. - // - processEntryPoint(func); - } - } - - void processEntryPoint(IRFunc* func) - { - // We expect all entry points to have explicit layout information attached. - // - // We will assert that we have the information we need, but try to be - // defensive and bail out in the failure case in release builds. - // - auto funcLayoutDecoration = func->findDecoration(); - SLANG_ASSERT(funcLayoutDecoration); - if(!funcLayoutDecoration) - return; - - auto entryPointLayout = as(funcLayoutDecoration->getLayout()); - SLANG_ASSERT(entryPointLayout); - if(!entryPointLayout) - return; - - // The parameter layout for an entry point will either be a structure - // type layout, or a constant buffer (a case of parameter group) - // wrapped around such a structure. - // - // If we are in the latter case we will need to make sure to allocate - // an explicit IR constant buffer for that wrapper, - // - auto entryPointParamsLayout = entryPointLayout->parametersLayout; - bool needConstantBuffer = entryPointParamsLayout->typeLayout.is(); - - // We will set up an IR builder so that we are ready to generate code. - // - SharedIRBuilder sharedBuilderStorage; - auto sharedBuilder = &sharedBuilderStorage; - sharedBuilder->module = module; - sharedBuilder->session = module->getSession(); - - IRBuilder builderStorage; - auto builder = &builderStorage; - builder->sharedBuilder = sharedBuilder; - - // *If* the entry point has any uniform parameter then we want to create a - // structure type to house them, and a global shader parameter (either - // an instance of that type or a constant buffer). - // - // We only want to create these if actually needed, so we will declare - // them here and then initialize them on-demand. - // - IRStructType* paramStructType = nullptr; - IRGlobalParam* globalParam = nullptr; - - // We will be removing any uniform parameters we run into, so we - // need to iterate the parameter list carefully to deal with - // us modifying it along the way. - // - IRParam* nextParam = nullptr; - for( IRParam* param = func->getFirstParam(); param; param = nextParam ) - { - nextParam = param->getNextParam(); - - // We expect all entry-point parameters to have layout information, - // but we will be defensive and skip parameters without the required - // information when we are in a release build. - // - auto layoutDecoration = param->findDecoration(); - SLANG_ASSERT(layoutDecoration); - if(!layoutDecoration) - continue; - auto paramLayout = as(layoutDecoration->getLayout()); - SLANG_ASSERT(paramLayout); - if(!paramLayout) - continue; - - // A parameter that has varying input/output behavior should be left alone, - // since this pass is only supposed to apply to uniform (non-varying) - // parameters. - // - if(isVaryingParameter(paramLayout)) - continue; - - // At this point we know that `param` is not a varying shader parameter, - // so that we want to turn it into an equivalent global shader parameter. - // - // If this is the first parameter we are running into, then we need - // to deal with creating the structure type and global shader - // parameter that our transformed entry point will use. - // - if( !paramStructType ) - { - // First we create the structure to hold the parameters. - // - builder->setInsertBefore(func); - paramStructType = builder->createStructType(); - - if( needConstantBuffer ) - { - // If we need a constant buffer, then the global - // shader parameter will be a `ConstantBuffer` - // - auto constantBufferType = builder->getConstantBufferType(paramStructType); - globalParam = builder->createGlobalParam(constantBufferType); - } - else - { - // Otherwise, the global shader parameter is just - // an instance of `paramStructType`. - // - globalParam = builder->createGlobalParam(paramStructType); - } - - // No matter what, the global shader parameter should have the layout - // information from the entry point attached to it, so that the - // contained parameters will end up in the right place(s). - // - builder->addLayoutDecoration(globalParam, entryPointParamsLayout); - } - - // Now that we've ensured the global `struct` type and shader paramter - // exist, we need to add a field to the `struct` to represent the - // current parameter. - // - - auto paramType = param->getFullType(); - - builder->setInsertBefore(paramStructType); - auto paramFieldKey = builder->createStructKey(); - auto paramField = builder->createStructField(paramStructType, paramFieldKey, paramType); - SLANG_UNUSED(paramField); - - // We will transfer all decorations on the parameter over to the key - // so that they can affect downstream emit logic. - // - // TODO: We should double-check whether any of the decorations should - // be moved to the *field* instead. - // - param->transferDecorationsTo(paramFieldKey); - - // There is a bit of a hacky issue, where downstream passes (notably - // type legalization) require the field keys for `struct` types to - // have mangled names, because those mangled names will be used to - // lookup field layout information inside of the layout information - // for the `struct` type. - // - // TODO: We should fix that design choice in how layout information - // is stored, to avoid the reliance on name strings. - // - builder->addExportDecoration(paramFieldKey, getMangledName(paramLayout->varDecl).getUnownedSlice()); - - // At this point we want to eliminate the original entry point - // parameter, in favor of the `struct` field we declared. - // That required replacing any uses of the parameter with - // appropriate code to pull out the field. - // - // We *could* extract the field at the start of the shader - // and then do a `replaceAllUsesWith` to propragate it - // down, but in practice we expect that it is better for - // performance to "rematerialize" the value of a shader - // parameter as close to where it is used as possible. - // - // We are therefore going to replace the uses one at a time. - // - while(auto use = param->firstUse ) - { - // Given a `use` of the paramter, we will insert - // the replacement code right before the instruction - // that is doing the using. - // - builder->setInsertBefore(use->getUser()); - - // The way to extract the field that corresponds - // to the parameter depends on whether or not - // we generated a constant buffer. - // - IRInst* fieldVal = nullptr; - if( needConstantBuffer ) - { - // A constant buffer behaves like a pointer - // at the IR level, so we first do a pointer - // offset operation to compute what amounts - // to `&cb->field`, and then load from that address. - // - auto fieldAddress = builder->emitFieldAddress( - builder->getPtrType(paramType), - globalParam, - paramFieldKey); - fieldVal = builder->emitLoad(fieldAddress); - } - else - { - // In the ordinary struct case, the parameter - // has an ordinary `struct` type (not a pointer), - // so we just extract the field directly. - // - fieldVal = builder->emitFieldExtract( - paramType, - globalParam, - paramFieldKey); - } - - // We replace the value used at this use site, which - // will have a side effect of making `use` no longer - // be on the list of uses for `param`, so that when - // we get back to the top of the loop the list of - // uses will be shorter. - // - use->set(fieldVal); - } - - // Once we've replaced all the uses of `param`, we - // can go ahead and remove it completely. - // - param->removeAndDeallocate(); - } - - fixUpFuncType(func); - } - - // We need to be able to determine if a parameter is logically - // a "varying" parameter based on its layout. - // - bool isVaryingParameter(VarLayout* layout) - { - // If *any* of the resources consumed by the parameter - // is a varying resource kind (e.g., varying input) then - // we consider the whole parameter to be varying. - // - // This is reasonable because there is no way to declare - // a parameter that mixes varying and non-varying fields. - // - for( auto resInfo : layout->resourceInfos ) - { - if(isVaryingResourceKind(resInfo.kind)) - return true; - } - - // Varying parameters with "system value" semantics currently show up as - // consuming no resources, so we need to special-case that here. - // - // Note: an empty `struct` parameter would also show up the same way, but - // we should eliminate any such parameters later on during type legalization. - // - if(layout->resourceInfos.getCount() == 0) - return true; - - // if none of the above tests determined that the - // parameter was varying, then we can safely consider - // it to be non-varying (uniform): - return false; - } - - // In order to determine whether a parameter is varying based on its - // layout, we need to know which resource kinds represent varying - // shader parameters. - // - bool isVaryingResourceKind(LayoutResourceKind kind) - { - switch( kind ) - { - default: - return false; - - // Note: The set of cases that are considered - // varying here would need to be extended if we - // add more fine-grained resource kinds (e.g., - // if we ever add an explicit resource kind - // for geometry shader output streams). - // - // Ordinary varying input/output: - case LayoutResourceKind::VaryingInput: - case LayoutResourceKind::VaryingOutput: - // - // Ray-tracing shader input/output: - case LayoutResourceKind::CallablePayload: - case LayoutResourceKind::HitAttributes: - case LayoutResourceKind::RayPayload: - return true; - } - } -}; - -void moveEntryPointUniformParamsToGlobalScope( - IRModule* module) -{ - MoveEntryPointUniformParametersToGlobalScope context; - context.module = module; - context.processModule(); -} - -} diff --git a/source/slang/ir-entry-point-uniforms.h b/source/slang/ir-entry-point-uniforms.h deleted file mode 100644 index 5fcfab167..000000000 --- a/source/slang/ir-entry-point-uniforms.h +++ /dev/null @@ -1,12 +0,0 @@ -// ir-entry-point-uniform.h -#pragma once - -namespace Slang -{ -struct IRModule; - - /// Move any uniform parameters of entry points to the global scope instead. -void moveEntryPointUniformParamsToGlobalScope( - IRModule* module); - -} diff --git a/source/slang/ir-glsl-legalize.cpp b/source/slang/ir-glsl-legalize.cpp deleted file mode 100644 index b63225729..000000000 --- a/source/slang/ir-glsl-legalize.cpp +++ /dev/null @@ -1,1687 +0,0 @@ -// ir-glsl-legalize.cpp -#include "ir-glsl-legalize.h" - -#include "ir.h" -#include "ir-insts.h" - -#include "slang-extension-usage-tracker.h" - -namespace Slang -{ - -// -// Legalization of entry points for GLSL: -// - -IRGlobalParam* addGlobalParam( - IRModule* module, - IRType* valueType) -{ - auto session = module->session; - - SharedIRBuilder shared; - shared.module = module; - shared.session = session; - - IRBuilder builder; - builder.sharedBuilder = &shared; - return builder.createGlobalParam(valueType); -} - -void moveValueBefore( - IRInst* valueToMove, - IRInst* placeBefore) -{ - valueToMove->removeFromParent(); - valueToMove->insertBefore(placeBefore); -} - -IRType* getFieldType( - IRType* baseType, - IRStructKey* fieldKey) -{ - if(auto structType = as(baseType)) - { - for(auto ff : structType->getFields()) - { - if(ff->getKey() == fieldKey) - return ff->getFieldType(); - } - } - - SLANG_UNEXPECTED("no such field"); - UNREACHABLE_RETURN(nullptr); -} - - - -// When scalarizing shader inputs/outputs for GLSL, we need a way -// to refer to a conceptual "value" that might comprise multiple -// IR-level values. We could in principle introduce tuple types -// into the IR so that everything stays at the IR level, but -// it seems easier to just layer it over the top for now. -// -// The `ScalarizedVal` type deals with the "tuple or single value?" -// question, and also the "l-value or r-value?" question. -struct ScalarizedValImpl : RefObject -{}; -struct ScalarizedTupleValImpl; -struct ScalarizedTypeAdapterValImpl; -struct ScalarizedVal -{ - enum class Flavor - { - // no value (null pointer) - none, - - // A simple `IRInst*` that represents the actual value - value, - - // An `IRInst*` that represents the address of the actual value - address, - - // A `TupleValImpl` that represents zero or more `ScalarizedVal`s - tuple, - - // A `TypeAdapterValImpl` that wraps a single `ScalarizedVal` and - // represents an implicit type conversion applied to it on read - // or write. - typeAdapter, - }; - - // Create a value representing a simple value - static ScalarizedVal value(IRInst* irValue) - { - ScalarizedVal result; - result.flavor = Flavor::value; - result.irValue = irValue; - return result; - } - - - // Create a value representing an address - static ScalarizedVal address(IRInst* irValue) - { - ScalarizedVal result; - result.flavor = Flavor::address; - result.irValue = irValue; - return result; - } - - static ScalarizedVal tuple(ScalarizedTupleValImpl* impl) - { - ScalarizedVal result; - result.flavor = Flavor::tuple; - result.impl = (ScalarizedValImpl*)impl; - return result; - } - - static ScalarizedVal typeAdapter(ScalarizedTypeAdapterValImpl* impl) - { - ScalarizedVal result; - result.flavor = Flavor::typeAdapter; - result.impl = (ScalarizedValImpl*)impl; - return result; - } - - Flavor flavor = Flavor::none; - IRInst* irValue = nullptr; - RefPtr impl; -}; - -// This is the case for a value that is a "tuple" of other values -struct ScalarizedTupleValImpl : ScalarizedValImpl -{ - struct Element - { - IRStructKey* key; - ScalarizedVal val; - }; - - IRType* type; - List elements; -}; - -// This is the case for a value that is stored with one type, -// but needs to present itself as having a different type -struct ScalarizedTypeAdapterValImpl : ScalarizedValImpl -{ - ScalarizedVal val; - IRType* actualType; // the actual type of `val` - IRType* pretendType; // the type this value pretends to have -}; - -struct GlobalVaryingDeclarator -{ - enum class Flavor - { - array, - }; - - Flavor flavor; - IRInst* elementCount; - GlobalVaryingDeclarator* next; -}; - -struct GLSLSystemValueInfo -{ - // The name of the built-in GLSL variable - char const* name; - - // The name of an outer array that wraps - // the variable, in the case of a GS input - char const* outerArrayName; - - // The required type of the built-in variable - IRType* requiredType; -}; - -struct GLSLLegalizationContext -{ - Session* session; - ExtensionUsageTracker* extensionUsageTracker; - DiagnosticSink* sink; - Stage stage; - - void requireGLSLExtension(String const& name) - { - extensionUsageTracker->requireGLSLExtension(name); - } - - void requireGLSLVersion(ProfileVersion version) - { - extensionUsageTracker->requireGLSLVersion(version); - } - - Stage getStage() - { - return stage; - } - - DiagnosticSink* getSink() - { - return sink; - } - - IRBuilder* builder; - IRBuilder* getBuilder() { return builder; } -}; - -GLSLSystemValueInfo* getGLSLSystemValueInfo( - GLSLLegalizationContext* context, - VarLayout* varLayout, - LayoutResourceKind kind, - Stage stage, - GLSLSystemValueInfo* inStorage) -{ - char const* name = nullptr; - char const* outerArrayName = nullptr; - - auto semanticNameSpelling = varLayout->systemValueSemantic; - if(semanticNameSpelling.getLength() == 0) - return nullptr; - - auto semanticName = semanticNameSpelling.toLower(); - - // HLSL semantic types can be found here - // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/dx-graphics-hlsl-semantics - /// NOTE! While there might be an "official" type for most of these in HLSL, in practice the user is allowed to declare almost anything - /// that the HLSL compiler can implicitly convert to/from the correct type - - auto builder = context->getBuilder(); - IRType* requiredType = nullptr; - - if(semanticName == "sv_position") - { - // float4 in hlsl & glsl - // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/gl_FragCoord.xhtml - // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/gl_Position.xhtml - - // This semantic can either work like `gl_FragCoord` - // when it is used as a fragment shader input, or - // like `gl_Position` when used in other stages. - // - // Note: This isn't as simple as testing input-vs-output, - // because a user might have a VS output `SV_Position`, - // and then pass it along to a GS that reads it as input. - // - if( stage == Stage::Fragment - && kind == LayoutResourceKind::VaryingInput ) - { - name = "gl_FragCoord"; - } - else if( stage == Stage::Geometry - && kind == LayoutResourceKind::VaryingInput ) - { - // As a GS input, the correct syntax is `gl_in[...].gl_Position`, - // but that is not compatible with picking the array dimension later, - // of course. - outerArrayName = "gl_in"; - name = "gl_Position"; - } - else - { - name = "gl_Position"; - } - - requiredType = builder->getVectorType(builder->getBasicType(BaseType::Float), builder->getIntValue(builder->getIntType(), 4)); - } - else if(semanticName == "sv_target") - { - // Note: we do *not* need to generate some kind of `gl_` - // builtin for fragment-shader outputs: they are just - // ordinary `out` variables, with ordinary `location`s, - // as far as GLSL is concerned. - return nullptr; - } - else if(semanticName == "sv_clipdistance") - { - // TODO: type conversion is required here. - - // float in hlsl & glsl. - // "Clip distance data. SV_ClipDistance values are each assumed to be a float32 signed distance to a plane." - // In glsl clipping value meaning is probably different - // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/gl_ClipDistance.xhtml - - name = "gl_ClipDistance"; - requiredType = builder->getBasicType(BaseType::Float); - } - else if(semanticName == "sv_culldistance") - { - // float in hlsl & glsl. - // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/gl_CullDistance.xhtml - - context->requireGLSLExtension("ARB_cull_distance"); - - // TODO: type conversion is required here. - name = "gl_CullDistance"; - requiredType = builder->getBasicType(BaseType::Float); - } - else if(semanticName == "sv_coverage") - { - // TODO: deal with `gl_SampleMaskIn` when used as an input. - - // TODO: type conversion is required here. - - // uint in hlsl, int in glsl - // https://www.opengl.org/sdk/docs/manglsl/docbook4/xhtml/gl_SampleMask.xml - - requiredType = builder->getBasicType(BaseType::Int); - - name = "gl_SampleMask"; - } - else if(semanticName == "sv_depth") - { - // Float in hlsl & glsl - // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/gl_FragDepth.xhtml - name = "gl_FragDepth"; - requiredType = builder->getBasicType(BaseType::Float); - } - else if(semanticName == "sv_depthgreaterequal") - { - // TODO: layout(depth_greater) out float gl_FragDepth; - - // Type is 'unknown' in hlsl - name = "gl_FragDepth"; - requiredType = builder->getBasicType(BaseType::Float); - } - else if(semanticName == "sv_depthlessequal") - { - // TODO: layout(depth_greater) out float gl_FragDepth; - - // 'unknown' in hlsl, float in glsl - name = "gl_FragDepth"; - requiredType = builder->getBasicType(BaseType::Float); - } - else if(semanticName == "sv_dispatchthreadid") - { - // uint3 in hlsl, uvec3 in glsl - // https://www.opengl.org/sdk/docs/manglsl/docbook4/xhtml/gl_GlobalInvocationID.xml - name = "gl_GlobalInvocationID"; - - requiredType = builder->getVectorType(builder->getBasicType(BaseType::UInt), builder->getIntValue(builder->getIntType(), 3)); - } - else if(semanticName == "sv_domainlocation") - { - // float2|3 in hlsl, vec3 in glsl - // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/gl_TessCoord.xhtml - - requiredType = builder->getVectorType(builder->getBasicType(BaseType::Float), builder->getIntValue(builder->getIntType(), 3)); - - name = "gl_TessCoord"; - } - else if(semanticName == "sv_groupid") - { - // uint3 in hlsl, uvec3 in glsl - // https://www.opengl.org/sdk/docs/manglsl/docbook4/xhtml/gl_WorkGroupID.xml - name = "gl_WorkGroupID"; - - requiredType = builder->getVectorType(builder->getBasicType(BaseType::UInt), builder->getIntValue(builder->getIntType(), 3)); - } - else if(semanticName == "sv_groupindex") - { - // uint in hlsl & in glsl - name = "gl_LocalInvocationIndex"; - requiredType = builder->getBasicType(BaseType::UInt); - } - else if(semanticName == "sv_groupthreadid") - { - // uint3 in hlsl, uvec3 in glsl - name = "gl_LocalInvocationID"; - - requiredType = builder->getVectorType(builder->getBasicType(BaseType::UInt), builder->getIntValue(builder->getIntType(), 3)); - } - else if(semanticName == "sv_gsinstanceid") - { - // uint in hlsl, int in glsl - // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/gl_InvocationID.xhtml - - requiredType = builder->getBasicType(BaseType::Int); - name = "gl_InvocationID"; - } - else if(semanticName == "sv_instanceid") - { - // https://docs.microsoft.com/en-us/windows/desktop/direct3d11/d3d10-graphics-programming-guide-input-assembler-stage-using#instanceid - // uint in hlsl, int in glsl - - requiredType = builder->getBasicType(BaseType::Int); - name = "gl_InstanceIndex"; - } - else if(semanticName == "sv_isfrontface") - { - // bool in hlsl & glsl - // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/gl_FrontFacing.xhtml - name = "gl_FrontFacing"; - requiredType = builder->getBasicType(BaseType::Bool); - } - else if(semanticName == "sv_outputcontrolpointid") - { - // uint in hlsl, int in glsl - // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/gl_InvocationID.xhtml - - name = "gl_InvocationID"; - - requiredType = builder->getBasicType(BaseType::Int); - } - else if (semanticName == "sv_pointsize") - { - // float in hlsl & glsl - name = "gl_PointSize"; - requiredType = builder->getBasicType(BaseType::Float); - } - else if(semanticName == "sv_primitiveid") - { - // uint in hlsl, int in glsl - // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/gl_PrimitiveID.xhtml - name = "gl_PrimitiveID"; - - requiredType = builder->getBasicType(BaseType::Int); - } - else if (semanticName == "sv_rendertargetarrayindex") - { - // uint on hlsl, int on glsl - // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/gl_Layer.xhtml - - switch (context->getStage()) - { - case Stage::Geometry: - context->requireGLSLVersion(ProfileVersion::GLSL_150); - break; - - case Stage::Fragment: - context->requireGLSLVersion(ProfileVersion::GLSL_430); - break; - - default: - context->requireGLSLVersion(ProfileVersion::GLSL_450); - context->requireGLSLExtension("GL_ARB_shader_viewport_layer_array"); - break; - } - - name = "gl_Layer"; - requiredType = builder->getBasicType(BaseType::Int); - } - else if (semanticName == "sv_sampleindex") - { - // uint in hlsl, int in glsl - // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/gl_SampleID.xhtml - - requiredType = builder->getBasicType(BaseType::Int); - name = "gl_SampleID"; - } - else if (semanticName == "sv_stencilref") - { - // uint in hlsl, int in glsl - // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_shader_stencil_export.txt - - requiredType = builder->getBasicType(BaseType::Int); - - context->requireGLSLExtension("ARB_shader_stencil_export"); - name = "gl_FragStencilRef"; - } - else if (semanticName == "sv_tessfactor") - { - // TODO(JS): Adjust type does *not* handle the conversion correctly. More specifically a float array hlsl - // parameter goes through code to make SOA in createGLSLGlobalVaryingsImpl. - // - // Can be input and output. - // - // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/sv-tessfactor - // "Tessellation factors must be declared as an array; they cannot be packed into a single vector." - // - // float[2|3|4] in hlsl, float[4] on glsl (ie both are arrays but might be different size) - // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/gl_TessLevelOuter.xhtml - - name = "gl_TessLevelOuter"; - - // float[4] on glsl - requiredType = builder->getArrayType(builder->getBasicType(BaseType::Float), builder->getIntValue(builder->getIntType(), 4)); - } - else if (semanticName == "sv_vertexid") - { - // uint in hlsl, int in glsl (https://www.khronos.org/opengl/wiki/Built-in_Variable_(GLSL)) - requiredType = builder->getBasicType(BaseType::Int); - name = "gl_VertexIndex"; - } - else if (semanticName == "sv_viewportarrayindex") - { - // uint on hlsl, int on glsl - // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/gl_ViewportIndex.xhtml - - requiredType = builder->getBasicType(BaseType::Int); - name = "gl_ViewportIndex"; - } - else if (semanticName == "nv_x_right") - { - context->requireGLSLVersion(ProfileVersion::GLSL_450); - context->requireGLSLExtension("GL_NVX_multiview_per_view_attributes"); - - // The actual output in GLSL is: - // - // vec4 gl_PositionPerViewNV[]; - // - // and is meant to support an arbitrary number of views, - // while the HLSL case just defines a second position - // output. - // - // For now we will hack this by: - // 1. Mapping an `NV_X_Right` output to `gl_PositionPerViewNV[1]` - // (that is, just one element of the output array) - // 2. Adding logic to copy the traditional `gl_Position` output - // over to `gl_PositionPerViewNV[0]` - // - - name = "gl_PositionPerViewNV[1]"; - -// shared->requiresCopyGLPositionToPositionPerView = true; - } - else if (semanticName == "nv_viewport_mask") - { - // TODO: This doesn't seem to work correctly on it's own between hlsl/glsl - - // Indeed on slang issue 109 claims this remains a problem - // https://github.com/shader-slang/slang/issues/109 - - // On hlsl it's UINT related. "higher 16 bits for the right view, lower 16 bits for the left view." - // There is use in hlsl shader code as uint4 - not clear if that varies - // https://github.com/KhronosGroup/GLSL/blob/master/extensions/nvx/GL_NVX_multiview_per_view_attributes.txt - // On glsl its highp int gl_ViewportMaskPerViewNV[]; - - context->requireGLSLVersion(ProfileVersion::GLSL_450); - context->requireGLSLExtension("GL_NVX_multiview_per_view_attributes"); - - name = "gl_ViewportMaskPerViewNV"; -// globalVarExpr = createGLSLBuiltinRef("gl_ViewportMaskPerViewNV", -// getUnsizedArrayType(getIntType())); - } - - if( name ) - { - inStorage->name = name; - inStorage->outerArrayName = outerArrayName; - inStorage->requiredType = requiredType; - return inStorage; - } - - context->getSink()->diagnose(varLayout->varDecl.getDecl()->loc, Diagnostics::unknownSystemValueSemantic, semanticNameSpelling); - return nullptr; -} - -ScalarizedVal createSimpleGLSLGlobalVarying( - GLSLLegalizationContext* context, - IRBuilder* builder, - IRType* inType, - VarLayout* inVarLayout, - TypeLayout* inTypeLayout, - LayoutResourceKind kind, - Stage stage, - UInt bindingIndex, - GlobalVaryingDeclarator* declarator) -{ - // Check if we have a system value on our hands. - GLSLSystemValueInfo systemValueInfoStorage; - auto systemValueInfo = getGLSLSystemValueInfo( - context, - inVarLayout, - kind, - stage, - &systemValueInfoStorage); - - IRType* type = inType; - - // A system-value semantic might end up needing to override the type - // that the user specified. - if( systemValueInfo && systemValueInfo->requiredType ) - { - type = systemValueInfo->requiredType; - } - - // Construct the actual type and type-layout for the global variable - // - RefPtr typeLayout = inTypeLayout; - for( auto dd = declarator; dd; dd = dd->next ) - { - // We only have one declarator case right now... - SLANG_ASSERT(dd->flavor == GlobalVaryingDeclarator::Flavor::array); - - auto arrayType = builder->getArrayType( - type, - dd->elementCount); - - RefPtr arrayTypeLayout = new ArrayTypeLayout(); -// arrayTypeLayout->type = arrayType; - arrayTypeLayout->rules = typeLayout->rules; - arrayTypeLayout->originalElementTypeLayout = typeLayout; - arrayTypeLayout->elementTypeLayout = typeLayout; - arrayTypeLayout->uniformStride = 0; - - if( auto resInfo = inTypeLayout->FindResourceInfo(kind) ) - { - // TODO: it is kind of gross to be re-running some - // of the type layout logic here. - - UInt elementCount = (UInt) GetIntVal(dd->elementCount); - arrayTypeLayout->addResourceUsage( - kind, - resInfo->count * elementCount); - } - - type = arrayType; - typeLayout = arrayTypeLayout; - } - - // We need to construct a fresh layout for the variable, even - // if the original had its own layout, because it might be - // an `inout` parameter, and we only want to deal with the case - // described by our `kind` parameter. - RefPtr varLayout = new VarLayout(); - varLayout->varDecl = inVarLayout->varDecl; - varLayout->typeLayout = typeLayout; - varLayout->flags = inVarLayout->flags; - varLayout->systemValueSemantic = inVarLayout->systemValueSemantic; - varLayout->systemValueSemanticIndex = inVarLayout->systemValueSemanticIndex; - varLayout->semanticName = inVarLayout->semanticName; - varLayout->semanticIndex = inVarLayout->semanticIndex; - varLayout->stage = inVarLayout->stage; - varLayout->AddResourceInfo(kind)->index = bindingIndex; - - // We are going to be creating a global parameter to replace - // the function parameter, but we need to handle the case - // where the parameter represents a varying *output* and not - // just an input. - // - // Our IR global shader parameters are read-only, just - // like our IR function parameters, and need a wrapper - // `Out<...>` type to represent outputs. - // - bool isOutput = kind == LayoutResourceKind::VaryingOutput; - IRType* paramType = isOutput ? builder->getOutType(type) : type; - - auto globalParam = addGlobalParam(builder->getModule(), paramType); - moveValueBefore(globalParam, builder->getFunc()); - - ScalarizedVal val = isOutput ? ScalarizedVal::address(globalParam) : ScalarizedVal::value(globalParam); - - if( systemValueInfo ) - { - builder->addImportDecoration(globalParam, UnownedTerminatedStringSlice(systemValueInfo->name)); - - if( auto fromType = systemValueInfo->requiredType ) - { - // We may need to adapt from the declared type to/from - // the actual type of the GLSL global. - auto toType = inType; - - if( !isTypeEqual(fromType, toType )) - { - RefPtr typeAdapter = new ScalarizedTypeAdapterValImpl; - typeAdapter->actualType = systemValueInfo->requiredType; - typeAdapter->pretendType = inType; - typeAdapter->val = val; - - val = ScalarizedVal::typeAdapter(typeAdapter); - } - } - - if(auto outerArrayName = systemValueInfo->outerArrayName) - { - builder->addGLSLOuterArrayDecoration(globalParam, UnownedTerminatedStringSlice(outerArrayName)); - } - } - - builder->addLayoutDecoration(globalParam, varLayout); - - return val; -} - -ScalarizedVal createGLSLGlobalVaryingsImpl( - GLSLLegalizationContext* context, - IRBuilder* builder, - IRType* type, - VarLayout* varLayout, - TypeLayout* typeLayout, - LayoutResourceKind kind, - Stage stage, - UInt bindingIndex, - GlobalVaryingDeclarator* declarator) -{ - if (as(type)) - { - return ScalarizedVal(); - } - else if( as(type) ) - { - return createSimpleGLSLGlobalVarying( - context, - builder, type, varLayout, typeLayout, kind, stage, bindingIndex, declarator); - } - else if( as(type) ) - { - return createSimpleGLSLGlobalVarying( - context, - builder, type, varLayout, typeLayout, kind, stage, bindingIndex, declarator); - } - else if( as(type) ) - { - // TODO: a matrix-type varying should probably be handled like an array of rows - return createSimpleGLSLGlobalVarying( - context, - builder, type, varLayout, typeLayout, kind, stage, bindingIndex, declarator); - } - else if( auto arrayType = as(type) ) - { - // We will need to SOA-ize any nested types. - - auto elementType = arrayType->getElementType(); - auto elementCount = arrayType->getElementCount(); - auto arrayLayout = as(typeLayout); - SLANG_ASSERT(arrayLayout); - auto elementTypeLayout = arrayLayout->elementTypeLayout; - - GlobalVaryingDeclarator arrayDeclarator; - arrayDeclarator.flavor = GlobalVaryingDeclarator::Flavor::array; - arrayDeclarator.elementCount = elementCount; - arrayDeclarator.next = declarator; - - return createGLSLGlobalVaryingsImpl( - context, - builder, - elementType, - varLayout, - elementTypeLayout, - kind, - stage, - bindingIndex, - &arrayDeclarator); - } - else if( auto streamType = as(type)) - { - auto elementType = streamType->getElementType(); - auto streamLayout = as(typeLayout); - SLANG_ASSERT(streamLayout); - auto elementTypeLayout = streamLayout->elementTypeLayout; - - return createGLSLGlobalVaryingsImpl( - context, - builder, - elementType, - varLayout, - elementTypeLayout, - kind, - stage, - bindingIndex, - declarator); - } - else if(auto structType = as(type)) - { - // We need to recurse down into the individual fields, - // and generate a variable for each of them. - - auto structTypeLayout = as(typeLayout); - SLANG_ASSERT(structTypeLayout); - RefPtr tupleValImpl = new ScalarizedTupleValImpl(); - - - // Construct the actual type for the tuple (including any outer arrays) - IRType* fullType = type; - for( auto dd = declarator; dd; dd = dd->next ) - { - SLANG_ASSERT(dd->flavor == GlobalVaryingDeclarator::Flavor::array); - fullType = builder->getArrayType( - fullType, - dd->elementCount); - } - - tupleValImpl->type = fullType; - - // Okay, we want to walk through the fields here, and - // generate one variable for each. - UInt fieldCounter = 0; - for(auto field : structType->getFields()) - { - UInt fieldIndex = fieldCounter++; - - auto fieldLayout = structTypeLayout->fields[fieldIndex]; - - UInt fieldBindingIndex = bindingIndex; - if(auto fieldResInfo = fieldLayout->FindResourceInfo(kind)) - fieldBindingIndex += fieldResInfo->index; - - auto fieldVal = createGLSLGlobalVaryingsImpl( - context, - builder, - field->getFieldType(), - fieldLayout, - fieldLayout->typeLayout, - kind, - stage, - fieldBindingIndex, - declarator); - if (fieldVal.flavor != ScalarizedVal::Flavor::none) - { - ScalarizedTupleValImpl::Element element; - element.val = fieldVal; - element.key = field->getKey(); - - tupleValImpl->elements.add(element); - } - } - - return ScalarizedVal::tuple(tupleValImpl); - } - - // Default case is to fall back on the simple behavior - return createSimpleGLSLGlobalVarying( - context, - builder, type, varLayout, typeLayout, kind, stage, bindingIndex, declarator); -} - -ScalarizedVal createGLSLGlobalVaryings( - GLSLLegalizationContext* context, - IRBuilder* builder, - IRType* type, - VarLayout* layout, - LayoutResourceKind kind, - Stage stage) -{ - UInt bindingIndex = 0; - if(auto rr = layout->FindResourceInfo(kind)) - bindingIndex = rr->index; - return createGLSLGlobalVaryingsImpl( - context, - builder, type, layout, layout->typeLayout, kind, stage, bindingIndex, nullptr); -} - -ScalarizedVal extractField( - IRBuilder* builder, - ScalarizedVal const& val, - UInt fieldIndex, - IRStructKey* fieldKey) -{ - switch( val.flavor ) - { - case ScalarizedVal::Flavor::value: - return ScalarizedVal::value( - builder->emitFieldExtract( - getFieldType(val.irValue->getDataType(), fieldKey), - val.irValue, - fieldKey)); - - case ScalarizedVal::Flavor::address: - { - auto ptrType = as(val.irValue->getDataType()); - auto valType = ptrType->getValueType(); - auto fieldType = getFieldType(valType, fieldKey); - auto fieldPtrType = builder->getPtrType(ptrType->op, fieldType); - return ScalarizedVal::address( - builder->emitFieldAddress( - fieldPtrType, - val.irValue, - fieldKey)); - } - - case ScalarizedVal::Flavor::tuple: - { - auto tupleVal = as(val.impl); - return tupleVal->elements[fieldIndex].val; - } - - default: - SLANG_UNEXPECTED("unimplemented"); - UNREACHABLE_RETURN(ScalarizedVal()); - } - -} - -ScalarizedVal adaptType( - IRBuilder* builder, - IRInst* val, - IRType* toType, - IRType* /*fromType*/) -{ - // TODO: actually consider what needs to go on here... - return ScalarizedVal::value(builder->emitConstructorInst( - toType, - 1, - &val)); -} - -ScalarizedVal adaptType( - IRBuilder* builder, - ScalarizedVal const& val, - IRType* toType, - IRType* fromType) -{ - switch( val.flavor ) - { - case ScalarizedVal::Flavor::value: - return adaptType(builder, val.irValue, toType, fromType); - break; - - case ScalarizedVal::Flavor::address: - { - auto loaded = builder->emitLoad(val.irValue); - return adaptType(builder, loaded, toType, fromType); - } - break; - - default: - SLANG_UNEXPECTED("unimplemented"); - UNREACHABLE_RETURN(ScalarizedVal()); - } -} - -void assign( - IRBuilder* builder, - ScalarizedVal const& left, - ScalarizedVal const& right) -{ - switch( left.flavor ) - { - case ScalarizedVal::Flavor::address: - switch( right.flavor ) - { - case ScalarizedVal::Flavor::value: - { - builder->emitStore(left.irValue, right.irValue); - } - break; - - case ScalarizedVal::Flavor::address: - { - auto val = builder->emitLoad(right.irValue); - builder->emitStore(left.irValue, val); - } - break; - - case ScalarizedVal::Flavor::tuple: - { - // We are assigning from a tuple to a destination - // that is not a tuple. We will perform assignment - // element-by-element. - auto rightTupleVal = as(right.impl); - Index elementCount = rightTupleVal->elements.getCount(); - - for( Index ee = 0; ee < elementCount; ++ee ) - { - auto rightElement = rightTupleVal->elements[ee]; - auto leftElementVal = extractField( - builder, - left, - ee, - rightElement.key); - assign(builder, leftElementVal, rightElement.val); - } - } - break; - - default: - SLANG_UNEXPECTED("unimplemented"); - break; - } - break; - - case ScalarizedVal::Flavor::tuple: - { - // We have a tuple, so we are going to need to try and assign - // to each of its constituent fields. - auto leftTupleVal = as(left.impl); - Index elementCount = leftTupleVal->elements.getCount(); - - for( Index ee = 0; ee < elementCount; ++ee ) - { - auto rightElementVal = extractField( - builder, - right, - ee, - leftTupleVal->elements[ee].key); - assign(builder, leftTupleVal->elements[ee].val, rightElementVal); - } - } - break; - - case ScalarizedVal::Flavor::typeAdapter: - { - // We are trying to assign to something that had its type adjusted, - // so we will need to adjust the type of the right-hand side first. - // - // In this case we are converting to the actual type of the GLSL variable, - // from the "pretend" type that it had in the IR before. - auto typeAdapter = as(left.impl); - auto adaptedRight = adaptType(builder, right, typeAdapter->actualType, typeAdapter->pretendType); - assign(builder, typeAdapter->val, adaptedRight); - } - break; - - default: - SLANG_UNEXPECTED("unimplemented"); - break; - } -} - -ScalarizedVal getSubscriptVal( - IRBuilder* builder, - IRType* elementType, - ScalarizedVal val, - IRInst* indexVal) -{ - switch( val.flavor ) - { - case ScalarizedVal::Flavor::value: - return ScalarizedVal::value( - builder->emitElementExtract( - elementType, - val.irValue, - indexVal)); - - case ScalarizedVal::Flavor::address: - return ScalarizedVal::address( - builder->emitElementAddress( - builder->getPtrType(elementType), - val.irValue, - indexVal)); - - case ScalarizedVal::Flavor::tuple: - { - auto inputTuple = val.impl.as(); - - RefPtr resultTuple = new ScalarizedTupleValImpl(); - resultTuple->type = elementType; - - Index elementCount = inputTuple->elements.getCount(); - Index elementCounter = 0; - - auto structType = as(elementType); - for(auto field : structType->getFields()) - { - auto tupleElementType = field->getFieldType(); - - Index elementIndex = elementCounter++; - - SLANG_RELEASE_ASSERT(elementIndex < elementCount); - auto inputElement = inputTuple->elements[elementIndex]; - - ScalarizedTupleValImpl::Element resultElement; - resultElement.key = inputElement.key; - resultElement.val = getSubscriptVal( - builder, - tupleElementType, - inputElement.val, - indexVal); - - resultTuple->elements.add(resultElement); - } - SLANG_RELEASE_ASSERT(elementCounter == elementCount); - - return ScalarizedVal::tuple(resultTuple); - } - - default: - SLANG_UNEXPECTED("unimplemented"); - UNREACHABLE_RETURN(ScalarizedVal()); - } -} - -ScalarizedVal getSubscriptVal( - IRBuilder* builder, - IRType* elementType, - ScalarizedVal val, - UInt index) -{ - return getSubscriptVal( - builder, - elementType, - val, - builder->getIntValue( - builder->getIntType(), - index)); -} - -IRInst* materializeValue( - IRBuilder* builder, - ScalarizedVal const& val); - -IRInst* materializeTupleValue( - IRBuilder* builder, - ScalarizedVal val) -{ - auto tupleVal = val.impl.as(); - SLANG_ASSERT(tupleVal); - - Index elementCount = tupleVal->elements.getCount(); - auto type = tupleVal->type; - - if( auto arrayType = as(type)) - { - // The tuple represent an array, which means that the - // individual elements are expected to yield arrays as well. - // - // We will extract a value for each array element, and - // then use these to construct our result. - - List arrayElementVals; - UInt arrayElementCount = (UInt) GetIntVal(arrayType->getElementCount()); - - for( UInt ii = 0; ii < arrayElementCount; ++ii ) - { - auto arrayElementPseudoVal = getSubscriptVal( - builder, - arrayType->getElementType(), - val, - ii); - - auto arrayElementVal = materializeValue( - builder, - arrayElementPseudoVal); - - arrayElementVals.add(arrayElementVal); - } - - return builder->emitMakeArray( - arrayType, - arrayElementVals.getCount(), - arrayElementVals.getBuffer()); - } - else - { - // The tuple represents a value of some aggregate type, - // so we can simply materialize the elements and then - // construct a value of that type. - // - // TODO: this should be using a `makeStruct` instruction. - - List elementVals; - for( Index ee = 0; ee < elementCount; ++ee ) - { - auto elementVal = materializeValue(builder, tupleVal->elements[ee].val); - elementVals.add(elementVal); - } - - return builder->emitConstructorInst( - tupleVal->type, - elementVals.getCount(), - elementVals.getBuffer()); - } -} - -IRInst* materializeValue( - IRBuilder* builder, - ScalarizedVal const& val) -{ - switch( val.flavor ) - { - case ScalarizedVal::Flavor::value: - return val.irValue; - - case ScalarizedVal::Flavor::address: - { - auto loadInst = builder->emitLoad(val.irValue); - return loadInst; - } - break; - - case ScalarizedVal::Flavor::tuple: - { - //auto tupleVal = as(val.impl); - return materializeTupleValue(builder, val); - } - break; - - case ScalarizedVal::Flavor::typeAdapter: - { - // Somebody is trying to use a value where its actual type - // doesn't match the type it pretends to have. To make this - // work we need to adapt the type from its actual type over - // to its pretend type. - auto typeAdapter = as(val.impl); - auto adapted = adaptType(builder, typeAdapter->val, typeAdapter->pretendType, typeAdapter->actualType); - return materializeValue(builder, adapted); - } - break; - - default: - SLANG_UNEXPECTED("unimplemented"); - break; - } -} - -void legalizeRayTracingEntryPointParameterForGLSL( - GLSLLegalizationContext* context, - IRFunc* func, - IRParam* pp, - VarLayout* paramLayout) -{ - auto builder = context->getBuilder(); - auto paramType = pp->getDataType(); - - // The parameter might be either an `in` parameter, - // or an `out` or `in out` parameter, and in those - // latter cases its IR-level type will include a - // wrapping "pointer-like" type (e.g., `Out` - // instead of just `Float`). - // - // Because global shader parameters are read-only - // in the same way function types are, we can take - // care of that detail here just by allocating a - // global shader parameter with exactly the type - // of the original function parameter. - // - auto globalParam = addGlobalParam(builder->getModule(), paramType); - builder->addLayoutDecoration(globalParam, paramLayout); - moveValueBefore(globalParam, builder->getFunc()); - pp->replaceUsesWith(globalParam); - - // Because linkage between ray-tracing shaders is - // based on the type of incoming/outgoing payload - // and attribute parameters, it would be an error to - // eliminate the global parameter *even if* it is - // not actually used inside the entry point. - // - // We attach a decoration to the entry point that - // makes note of the dependency, so that steps - // like dead code elimination cannot get rid of - // the parameter. - // - // TODO: We could consider using a structure like - // this for *all* of the entry point parameters - // that get moved to the global scope, since SPIR-V - // ends up requiring such information on an `OpEntryPoint`. - // - // As a further alternative, we could decide to - // keep entry point varying input/outtput attached - // to the parameter list through all of the Slang IR - // steps, and only declare it as global variables at - // the last minute when emitting a GLSL `main` or - // SPIR-V for an entry point. - // - builder->addDependsOnDecoration(func, globalParam); -} - -void legalizeEntryPointParameterForGLSL( - GLSLLegalizationContext* context, - IRFunc* func, - IRParam* pp, - VarLayout* paramLayout) -{ - auto builder = context->getBuilder(); - auto stage = context->getStage(); - - // We need to create a global variable that will replace the parameter. - // It seems superficially obvious that the variable should have - // the same type as the parameter. - // However, if the parameter was a pointer, in order to - // support `out` or `in out` parameter passing, we need - // to be sure to allocate a variable of the pointed-to - // type instead. - // - // We also need to replace uses of the parameter with - // uses of the variable, and the exact logic there - // will differ a bit between the pointer and non-pointer - // cases. - auto paramType = pp->getDataType(); - - // First we will special-case stage input/outputs that - // don't fit into the standard varying model. - // For right now we are only doing special-case handling - // of geometry shader output streams. - if( auto paramPtrType = as(paramType) ) - { - auto valueType = paramPtrType->getValueType(); - if( auto gsStreamType = as(valueType) ) - { - // An output stream type like `TriangleStream` should - // more or less translate into `out Foo` (plus scalarization). - - auto globalOutputVal = createGLSLGlobalVaryings( - context, - builder, - valueType, - paramLayout, - LayoutResourceKind::VaryingOutput, - stage); - - // TODO: a GS output stream might be passed into other - // functions, so that we should really be modifying - // any function that has one of these in its parameter - // list (and in the limit we should be leagalizing any - // type that nests these...). - // - // For now we will just try to deal with `Append` calls - // directly in this function. - - - - for( auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock() ) - { - for( auto ii = bb->getFirstInst(); ii; ii = ii->getNextInst() ) - { - // Is it a call? - if(ii->op != kIROp_Call) - continue; - - // Is it calling the append operation? - auto callee = ii->getOperand(0); - for(;;) - { - // If the instruction is `specialize(X,...)` then - // we want to look at `X`, and if it is `generic { ... return R; }` - // then we want to look at `R`. We handle this - // iteratively here. - // - // TODO: This idiom seems to come up enough that we - // should probably have a dedicated convenience routine - // for this. - // - // Alternatively, we could switch the IR encoding so - // that decorations are added to the generic instead of the - // value it returns. - // - switch(callee->op) - { - case kIROp_Specialize: - { - callee = cast(callee)->getOperand(0); - continue; - } - - case kIROp_Generic: - { - auto genericResult = findGenericReturnVal(cast(callee)); - if(genericResult) - { - callee = genericResult; - continue; - } - } - - default: - break; - } - break; - } - if(callee->op != kIROp_Func) - continue; - - // HACK: we will identify the operation based - // on the target-intrinsic definition that was - // given to it. - auto decoration = findTargetIntrinsicDecoration(callee, "glsl"); - if(!decoration) - continue; - - if(decoration->getDefinition() != UnownedStringSlice::fromLiteral("EmitVertex()")) - { - continue; - } - - // Okay, we have a declaration, and we want to modify it! - - builder->setInsertBefore(ii); - - assign(builder, globalOutputVal, ScalarizedVal::value(ii->getOperand(2))); - } - } - - // We will still have references to the parameter coming - // from the `EmitVertex` calls, so we need to replace it - // with something. There isn't anything reasonable to - // replace it with that would have the right type, so - // we will replace it with an undefined value, knowing - // that the emitted code will not actually reference it. - // - // TODO: This approach to generating geometry shader code - // is not ideal, and we should strive to find a better - // approach that involes coding the `EmitVertex` operation - // directly in the stdlib, similar to how ray-tracing - // operations like `TraceRay` are handled. - // - builder->setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); - auto undefinedVal = builder->emitUndefined(pp->getFullType()); - pp->replaceUsesWith(undefinedVal); - - return; - } - } - - // When we have an HLSL ray tracing shader entry point, - // we don't want to translate the inputs/outputs for GLSL/SPIR-V - // according to our default rules, for two reasons: - // - // 1. The input and output for these stages are expected to - // be packaged into `struct` types rather than be scalarized, - // so the usual scalarization approach we take here should - // not be applied. - // - // 2. An `in out` parameter isn't just sugar for a combination - // of an `in` and an `out` parameter, and instead represents the - // read/write "payload" that was passed in. It should legalize - // to a single variable, and we can lower reads/writes of it - // directly, rather than introduce an intermediate temporary. - // - switch( stage ) - { - default: - break; - - case Stage::AnyHit: - case Stage::Callable: - case Stage::ClosestHit: - case Stage::Intersection: - case Stage::Miss: - case Stage::RayGeneration: - legalizeRayTracingEntryPointParameterForGLSL(context, func, pp, paramLayout); - return; - } - - // Is the parameter type a special pointer type - // that indicates the parameter is used for `out` - // or `inout` access? - if(auto paramPtrType = as(paramType) ) - { - // Okay, we have the more interesting case here, - // where the parameter was being passed by reference. - // We are going to create a local variable of the appropriate - // type, which will replace the parameter, along with - // one or more global variables for the actual input/output. - - auto valueType = paramPtrType->getValueType(); - - auto localVariable = builder->emitVar(valueType); - auto localVal = ScalarizedVal::address(localVariable); - - if( auto inOutType = as(paramPtrType) ) - { - // In the `in out` case we need to declare two - // sets of global variables: one for the `in` - // side and one for the `out` side. - auto globalInputVal = createGLSLGlobalVaryings( - context, - builder, valueType, paramLayout, LayoutResourceKind::VaryingInput, stage); - - assign(builder, localVal, globalInputVal); - } - - // Any places where the original parameter was used inside - // the function body should instead use the new local variable. - // Since the parameter was a pointer, we use the variable instruction - // itself (which is an `alloca`d pointer) directly: - pp->replaceUsesWith(localVariable); - - // We also need one or more global variables to write the output to - // when the function is done. We create them here. - auto globalOutputVal = createGLSLGlobalVaryings( - context, - builder, valueType, paramLayout, LayoutResourceKind::VaryingOutput, stage); - - // Now we need to iterate over all the blocks in the function looking - // for any `return*` instructions, so that we can write to the output variable - for( auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock() ) - { - auto terminatorInst = bb->getLastInst(); - if(!terminatorInst) - continue; - - switch( terminatorInst->op ) - { - default: - continue; - - case kIROp_ReturnVal: - case kIROp_ReturnVoid: - break; - } - - // We dont' re-use `builder` here because we don't want to - // disrupt the source location it is using for inserting - // temporary variables at the top of the function. - // - IRBuilder terminatorBuilder; - terminatorBuilder.sharedBuilder = builder->sharedBuilder; - terminatorBuilder.setInsertBefore(terminatorInst); - - // Assign from the local variabel to the global output - // variable before the actual `return` takes place. - assign(&terminatorBuilder, globalOutputVal, localVal); - } - } - else - { - // This is the "easy" case where the parameter wasn't - // being passed by reference. We start by just creating - // one or more global variables to represent the parameter, - // and attach the required layout information to it along - // the way. - - auto globalValue = createGLSLGlobalVaryings( - context, - builder, paramType, paramLayout, LayoutResourceKind::VaryingInput, stage); - - // Next we need to replace uses of the parameter with - // references to the variable(s). We are going to do that - // somewhat naively, by simply materializing the - // variables at the start. - IRInst* materialized = materializeValue(builder, globalValue); - - pp->replaceUsesWith(materialized); - } -} - -void legalizeEntryPointForGLSL( - Session* session, - IRModule* module, - IRFunc* func, - DiagnosticSink* sink, - ExtensionUsageTracker* extensionUsageTracker) -{ - auto layoutDecoration = func->findDecoration(); - SLANG_ASSERT(layoutDecoration); - - auto entryPointLayout = as(layoutDecoration->getLayout()); - SLANG_ASSERT(entryPointLayout); - - GLSLLegalizationContext context; - context.session = session; - context.stage = entryPointLayout->profile.GetStage(); - context.sink = sink; - context.extensionUsageTracker = extensionUsageTracker; - - Stage stage = entryPointLayout->profile.GetStage(); - - // We require that the entry-point function has no uses, - // because otherwise we'd invalidate the signature - // at all existing call sites. - // - // TODO: the right thing to do here is to split any - // function that both gets called as an entry point - // and as an ordinary function. - SLANG_ASSERT(!func->firstUse); - - // We create a dummy IR builder, since some of - // the functions require it. - // - // TODO: make some of these free functions... - // - SharedIRBuilder shared; - shared.module = module; - shared.session = session; - IRBuilder builder; - builder.sharedBuilder = &shared; - builder.setInsertInto(func); - - context.builder = &builder; - - // We will start by looking at the return type of the - // function, because that will enable us to do an - // early-out check to avoid more work. - // - // Specifically, we need to check if the function has - // a `void` return type, because there is no work - // to be done on its return value in that case. - auto resultType = func->getResultType(); - if(as(resultType)) - { - // In this case, the function doesn't return a value - // so we don't need to transform its `return` sites. - // - // We can also use this opportunity to quickly - // check if the function has any parameters, and if - // it doesn't use the chance to bail out immediately. - if( func->getParamCount() == 0 ) - { - // This function is already legal for GLSL - // (at least in terms of parameter/result signature), - // so we won't bother doing anything at all. - return; - } - - // If the function does have parameters, then we need - // to let the logic later in this function handle them. - } - else - { - // Function returns a value, so we need - // to introduce a new global variable - // to hold that value, and then replace - // any `returnVal` instructions with - // code to write to that variable. - - auto resultGlobal = createGLSLGlobalVaryings( - &context, - &builder, - resultType, - entryPointLayout->resultLayout, - LayoutResourceKind::VaryingOutput, - stage); - - for( auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock() ) - { - // TODO: This is silly, because we are looking at every instruction, - // when we know that a `returnVal` should only ever appear as a - // terminator... - for( auto ii = bb->getFirstInst(); ii; ii = ii->getNextInst() ) - { - if(ii->op != kIROp_ReturnVal) - continue; - - IRReturnVal* returnInst = (IRReturnVal*) ii; - IRInst* returnValue = returnInst->getVal(); - - // Make sure we add these instructions to the right block - builder.setInsertInto(bb); - - // Write to our global variable(s) from the value being returned. - assign(&builder, resultGlobal, ScalarizedVal::value(returnValue)); - - // Emit a `returnVoid` to end the block - auto returnVoid = builder.emitReturn(); - - // Remove the old `returnVal` instruction. - returnInst->removeAndDeallocate(); - - // Make sure to resume our iteration at an - // appropriate instruciton, since we deleted - // the one we had been using. - ii = returnVoid; - } - } - } - - // Next we will walk through any parameters of the entry-point function, - // and turn them into global variables. - if( auto firstBlock = func->getFirstBlock() ) - { - // Any initialization code we insert for parameters needs - // to be at the start of the "ordinary" instructions in the block: - builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); - - for( auto pp = firstBlock->getFirstParam(); pp; pp = pp->getNextParam() ) - { - // We assume that the entry-point parameters will all have - // layout information attached to them, which is kept up-to-date - // by any transformations affecting the parameter list. - // - auto paramLayoutDecoration = pp->findDecoration(); - SLANG_ASSERT(paramLayoutDecoration); - auto paramLayout = as(paramLayoutDecoration->getLayout()); - SLANG_ASSERT(paramLayout); - - legalizeEntryPointParameterForGLSL( - &context, - func, - pp, - paramLayout); - } - - // At this point we should have eliminated all uses of the - // parameters of the entry block. Also, our control-flow - // rules mean that the entry block cannot be the target - // of any branches in the code, so there can't be - // any control-flow ops that try to match the parameter - // list. - // - // We can safely go through and destroy the parameters - // themselves, and then clear out the parameter list. - - for( auto pp = firstBlock->getFirstParam(); pp; ) - { - auto next = pp->getNextParam(); - pp->removeAndDeallocate(); - pp = next; - } - } - - // Finally, we need to patch up the type of the entry point, - // because it is no longer accurate. - - IRFuncType* voidFuncType = builder.getFuncType( - 0, - nullptr, - builder.getVoidType()); - func->setFullType(voidFuncType); - - // TODO: we should technically be constructing - // a new `EntryPointLayout` here to reflect - // the way that things have been moved around. -} - -} // namespace Slang diff --git a/source/slang/ir-glsl-legalize.h b/source/slang/ir-glsl-legalize.h deleted file mode 100644 index 994a68247..000000000 --- a/source/slang/ir-glsl-legalize.h +++ /dev/null @@ -1,22 +0,0 @@ -// ir-glsl-legalize.h -#pragma once - -namespace Slang -{ - -class DiagnosticSink; -class Session; - -class ExtensionUsageTracker; - -struct IRFunc; -struct IRModule; - -void legalizeEntryPointForGLSL( - Session* session, - IRModule* module, - IRFunc* func, - DiagnosticSink* sink, - ExtensionUsageTracker* extensionUsageTracker); - -} diff --git a/source/slang/ir-inst-defs.h b/source/slang/ir-inst-defs.h deleted file mode 100644 index f2393b2b3..000000000 --- a/source/slang/ir-inst-defs.h +++ /dev/null @@ -1,480 +0,0 @@ -// ir-inst-defs.h - -#ifndef INST -#error Must #define `INST` before including `ir-inst-defs.h` -#endif - -#ifndef INST_RANGE -#define INST_RANGE(BASE, FIRST, LAST) /* empty */ -#endif - -#ifndef PSEUDO_INST -#define PSEUDO_INST(ID) /* empty */ -#endif - -#define PARENT kIROpFlag_Parent -#define USE_OTHER kIROpFlag_UseOther - -INST(Nop, nop, 0, 0) - -/* Types */ - - /* Basic Types */ - - #define DEFINE_BASE_TYPE_INST(NAME) INST(NAME ## Type, NAME, 0, 0) - FOREACH_BASE_TYPE(DEFINE_BASE_TYPE_INST) - #undef DEFINE_BASE_TYPE_INST - INST(AfterBaseType, afterBaseType, 0, 0) - - INST_RANGE(BasicType, VoidType, AfterBaseType) - - INST(StringType, String, 0, 0) - - /* ArrayTypeBase */ - INST(ArrayType, Array, 2, 0) - INST(UnsizedArrayType, UnsizedArray, 1, 0) - INST_RANGE(ArrayTypeBase, ArrayType, UnsizedArrayType) - - INST(FuncType, Func, 0, 0) - INST(BasicBlockType, BasicBlock, 0, 0) - - INST(VectorType, Vec, 2, 0) - INST(MatrixType, Mat, 3, 0) - - INST(TaggedUnionType, TaggedUnion, 0, 0) - - // A `BindExistentials` represents - // taking type `B` and binding each of its existential type - // parameters, recursively, with the specified arguments, - // where each `Ti, wi` pair represents the concrete type - // and witness table to plug in for parameter `i`. - // - INST(BindExistentialsType, BindExistentials, 1, 0) - - /* Rate */ - INST(ConstExprRate, ConstExpr, 0, 0) - INST(GroupSharedRate, GroupShared, 0, 0) - INST_RANGE(Rate, ConstExprRate, GroupSharedRate) - - INST(RateQualifiedType, RateQualified, 2, 0) - - // Kinds represent the "types of types." - // They should not really be nested under `IRType` - // in the overall hierarchy, but we can fix that later. - // - /* Kind */ - INST(TypeKind, Type, 0, 0) - INST(RateKind, Rate, 0, 0) - INST(GenericKind, Generic, 0, 0) - INST_RANGE(Kind, TypeKind, GenericKind) - - /* PtrTypeBase */ - INST(PtrType, Ptr, 1, 0) - INST(RefType, Ref, 1, 0) - - // An `ExistentialBox` represents a logical pointer to a value of type `T`. - // On targets that support pointers this might lower to a pointer, but on - // current targets it will lower to zero bytes, with a value of type `T` - // being stored "out of line" somewhere. - // - INST(ExistentialBoxType, ExistentialBox, 1, 0) - - /* OutTypeBase */ - INST(OutType, Out, 1, 0) - INST(InOutType, InOut, 1, 0) - INST_RANGE(OutTypeBase, OutType, InOutType) - INST_RANGE(PtrTypeBase, PtrType, InOutType) - - /* SamplerStateTypeBase */ - INST(SamplerStateType, SamplerState, 0, 0) - INST(SamplerComparisonStateType, SamplerComparisonState, 0, 0) - INST_RANGE(SamplerStateTypeBase, SamplerStateType, SamplerComparisonStateType) - - // TODO: Why do we have all this hierarchy here, when everything - // that actually matters is currently nested under `TextureTypeBase`? - /* ResourceTypeBase */ - /* ResourceType */ - /* TextureTypeBase */ - // NOTE! TextureFlavor::Flavor is stored in 'other' bits for these types. - /* TextureType */ - INST(TextureType, TextureType, 0, USE_OTHER) - /* TextureSamplerType */ - INST(TextureSamplerType, TextureSamplerType, 0, USE_OTHER) - /* GLSLImageType */ - INST(GLSLImageType, GLSLImageType, 0, USE_OTHER) - INST_RANGE(TextureTypeBase, TextureType, GLSLImageType) - INST_RANGE(ResourceType, TextureType, GLSLImageType) - INST_RANGE(ResourceTypeBase, TextureType, GLSLImageType) - - - /* UntypedBufferResourceType */ - /* ByteAddressBufferTypeBase */ - INST(HLSLByteAddressBufferType, ByteAddressBuffer, 0, 0) - INST(HLSLRWByteAddressBufferType, RWByteAddressBuffer, 0, 0) - INST(HLSLRasterizerOrderedByteAddressBufferType, RasterizerOrderedByteAddressBuffer, 0, 0) - INST_RANGE(ByteAddressBufferTypeBase, HLSLByteAddressBufferType, HLSLRasterizerOrderedByteAddressBufferType) - INST(RaytracingAccelerationStructureType, RaytracingAccelerationStructure, 0, 0) - INST_RANGE(UntypedBufferResourceType, HLSLByteAddressBufferType, RaytracingAccelerationStructureType) - - /* HLSLPatchType */ - INST(HLSLInputPatchType, InputPatch, 2, 0) - INST(HLSLOutputPatchType, OutputPatch, 2, 0) - INST_RANGE(HLSLPatchType, HLSLInputPatchType, HLSLOutputPatchType) - - INST(GLSLInputAttachmentType, GLSLInputAttachment, 0, 0) - - /* BuiltinGenericType */ - /* HLSLStreamOutputType */ - INST(HLSLPointStreamType, PointStream, 1, 0) - INST(HLSLLineStreamType, LineStream, 1, 0) - INST(HLSLTriangleStreamType, TriangleStream, 1, 0) - INST_RANGE(HLSLStreamOutputType, HLSLPointStreamType, HLSLTriangleStreamType) - - /* HLSLStructuredBufferTypeBase */ - INST(HLSLStructuredBufferType, StructuredBuffer, 0, 0) - INST(HLSLRWStructuredBufferType, RWStructuredBuffer, 0, 0) - INST(HLSLRasterizerOrderedStructuredBufferType, RasterizerOrderedStructuredBuffer, 0, 0) - INST(HLSLAppendStructuredBufferType, AppendStructuredBuffer, 0, 0) - INST(HLSLConsumeStructuredBufferType, ConsumeStructuredBuffer, 0, 0) - INST_RANGE(HLSLStructuredBufferTypeBase, HLSLStructuredBufferType, HLSLConsumeStructuredBufferType) - - /* PointerLikeType */ - /* ParameterGroupType */ - /* UniformParameterGroupType */ - INST(ConstantBufferType, ConstantBuffer, 1, 0) - INST(TextureBufferType, TextureBuffer, 1, 0) - INST(ParameterBlockType, ParameterBlock, 1, 0) - INST(GLSLShaderStorageBufferType, GLSLShaderStorageBuffer, 0, 0) - INST_RANGE(UniformParameterGroupType, ConstantBufferType, GLSLShaderStorageBufferType) - - /* VaryingParameterGroupType */ - INST(GLSLInputParameterGroupType, GLSLInputParameterGroup, 0, 0) - INST(GLSLOutputParameterGroupType, GLSLOutputParameterGroup, 0, 0) - INST_RANGE(VaryingParameterGroupType, GLSLInputParameterGroupType, GLSLOutputParameterGroupType) - INST_RANGE(ParameterGroupType, ConstantBufferType, GLSLOutputParameterGroupType) - INST_RANGE(PointerLikeType, ConstantBufferType, GLSLOutputParameterGroupType) - INST_RANGE(BuiltinGenericType, HLSLPointStreamType, GLSLOutputParameterGroupType) - - - - -// A user-defined structure declaration at the IR level. -// Unlike in the AST where there is a distinction between -// a `StructDecl` and a `DeclRefType` that refers to it, -// at the IR level the struct declaration and the type -// are the same IR instruction. -// -// This is a parent instruction that holds zero or more -// `field` instructions. -// -INST(StructType, struct, 0, PARENT) -INST(InterfaceType, interface, 0, PARENT) - -INST_RANGE(Type, VoidType, InterfaceType) - -/*IRGlobalValueWithCode*/ - /* IRGlobalValueWIthParams*/ - INST(Func, func, 0, PARENT) - INST(Generic, generic, 0, PARENT) - INST_RANGE(GlobalValueWithParams, Func, Generic) - - INST(GlobalVar, global_var, 0, 0) - INST(GlobalConstant, global_constant, 0, 0) -INST_RANGE(GlobalValueWithCode, Func, GlobalConstant) - -INST(GlobalParam, global_param, 0, 0) - -INST(StructKey, key, 0, 0) -INST(GlobalGenericParam, global_generic_param, 0, 0) -INST(WitnessTable, witness_table, 0, 0) - -INST(Module, module, 0, PARENT) - -INST(Block, block, 0, PARENT) - -/* IRConstant */ - INST(BoolLit, boolConst, 0, 0) - INST(IntLit, integer_constant, 0, 0) - INST(FloatLit, float_constant, 0, 0) - INST(PtrLit, ptr_constant, 0, 0) - INST(StringLit, string_constant, 0, 0) -INST_RANGE(Constant, BoolLit, StringLit) - -INST(undefined, undefined, 0, 0) - -INST(Specialize, specialize, 2, 0) -INST(lookup_interface_method, lookup_interface_method, 2, 0) -INST(lookup_witness_table, lookup_witness_table, 2, 0) -INST(BindGlobalGenericParam, bind_global_generic_param, 2, 0) -INST(BindGlobalExistentialSlots, bindGlobalExistentialSlots, 0, 0) - -INST(Construct, construct, 0, 0) - -INST(makeVector, makeVector, 0, 0) -INST(MakeMatrix, makeMatrix, 0, 0) -INST(makeArray, makeArray, 0, 0) -INST(makeStruct, makeStruct, 0, 0) - -INST(Call, call, 1, 0) - - -INST(WitnessTableEntry, witness_table_entry, 2, 0) - -INST(Param, param, 0, 0) -INST(StructField, field, 2, 0) -INST(Var, var, 0, 0) - -INST(Load, load, 1, 0) -INST(Store, store, 2, 0) - -INST(FieldExtract, get_field, 2, 0) -INST(FieldAddress, get_field_addr, 2, 0) - -INST(getElement, getElement, 2, 0) -INST(getElementPtr, getElementPtr, 2, 0) - -// "Subscript" an image at a pixel coordinate to get pointer -INST(ImageSubscript, imageSubscript, 2, 0) - -// Construct a vector from a scalar -// -// %dst = constructVectorFromScalar %T %N %val -// -// where -// - `T` is a `Type` -// - `N` is a (compile-time) `Int` -// - `val` is a `T` -// - dst is a `Vec` -// -INST(constructVectorFromScalar, constructVectorFromScalar, 3, 0) - -// A swizzle of a vector: -// -// %dst = swizzle %src %idx0 %idx1 ... -// -// where: -// - `src` is a vector -// - `dst` is a vector -// - `idx0` through `idx[M-1]` are literal integers -// -INST(swizzle, swizzle, 1, 0) - -// Setting a vector via swizzle -// -// %dst = swizzle %base %src %idx0 %idx1 ... -// -// where: -// - `base` is a vector -// - `dst` is a vector -// - `src` is a vector -// - `idx0` through `idx[M-1]` are literal integers -// -// The semantics of the op is: -// -// dst = base; -// for(ii : 0 ... M-1 ) -// dst[ii] = src[idx[ii]]; -// -INST(swizzleSet, swizzleSet, 2, 0) - -// Store to memory with a swizzle -// -// TODO: eventually this should be reduced to just -// a write mask by moving the actual swizzle to the RHS. -// -// swizzleStore %dst %src %idx0 %idx1 ... -// -// where: -// - `dst` is a vector -// - `src` is a vector -// - `idx0` through `idx[M-1]` are literal integers -// -// The semantics of the op is: -// -// for(ii : 0 ... M-1 ) -// dst[ii] = src[idx[ii]]; -// -INST(SwizzledStore, swizzledStore, 2, 0) - - -/* IRTerminatorInst */ - - INST(ReturnVal, return_val, 1, 0) - INST(ReturnVoid, return_void, 1, 0) - - /* IRUnconditionalBranch */ - // unconditionalBranch - INST(unconditionalBranch, unconditionalBranch, 1, 0) - - // loop - INST(loop, loop, 3, 0) - INST_RANGE(UnconditionalBranch, unconditionalBranch, loop) - - /* IRConditionalbranch */ - - // conditionalBranch - INST(conditionalBranch, conditionalBranch, 3, 0) - - // ifElse - INST(ifElse, ifElse, 4, 0) - INST_RANGE(ConditionalBranch, conditionalBranch, ifElse) - - // switch ... - INST(Switch, switch, 3, 0) - - INST(discard, discard, 0, 0) - - /* IRUnreachable */ - INST(MissingReturn, missingReturn, 0, 0) - INST(Unreachable, unreachable, 0, 0) - INST_RANGE(Unreachable, MissingReturn, Unreachable) - -INST_RANGE(TerminatorInst, ReturnVal, Unreachable) - -INST(Add, add, 2, 0) -INST(Sub, sub, 2, 0) -INST(Mul, mul, 2, 0) -INST(Div, div, 2, 0) -INST(Mod, mod, 2, 0) - -INST(Lsh, shl, 2, 0) -INST(Rsh, shr, 2, 0) - -INST(Eql, cmpEQ, 2, 0) -INST(Neq, cmpNE, 2, 0) -INST(Greater, cmpGT, 2, 0) -INST(Less, cmpLT, 2, 0) -INST(Geq, cmpGE, 2, 0) -INST(Leq, cmpLE, 2, 0) - -INST(BitAnd, and, 2, 0) -INST(BitXor, xor, 2, 0) -INST(BitOr, or , 2, 0) - -INST(And, logicalAnd, 2, 0) -INST(Or, logicalOr, 2, 0) - -INST(Neg, neg, 1, 0) -INST(Not, not, 1, 0) -INST(BitNot, bitnot, 1, 0) - -INST(Select, select, 3, 0) - -INST(Dot, dot, 2, 0) - -INST(Mul_Vector_Matrix, mulVectorMatrix, 2, 0) -INST(Mul_Matrix_Vector, mulMatrixVector, 2, 0) -INST(Mul_Matrix_Matrix, mulMatrixMatrix, 2, 0) - -// Texture sampling operation of the form `t.Sample(s,u)` -INST(Sample, sample, 3, 0) - -INST(SampleGrad, sampleGrad, 4, 0) - -INST(GroupMemoryBarrierWithGroupSync, GroupMemoryBarrierWithGroupSync, 0, 0) - -/* Decoration */ - -INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) - INST(LayoutDecoration, layout, 1, 0) - INST(LoopControlDecoration, loopControl, 1, 0) - /* TargetSpecificDecoration */ - INST(TargetDecoration, target, 1, 0) - INST(TargetIntrinsicDecoration, targetIntrinsic, 2, 0) - INST_RANGE(TargetSpecificDecoration, TargetDecoration, TargetIntrinsicDecoration) - INST(GLSLOuterArrayDecoration, glslOuterArray, 1, 0) - INST(SemanticDecoration, semantic, 1, 0) - INST(InterpolationModeDecoration, interpolationMode, 1, 0) - INST(NameHintDecoration, nameHint, 1, 0) - - /** The decorated _instruction_ is transitory. Such a decoration should NEVER be found on an output instruction a module. - Typically used mark an instruction so can be specially handled - say when creating a IRConstant literal, and the payload of - needs to be special cased for lookup. */ - INST(TransitoryDecoration, transitory, 0, 0) - - INST(VulkanRayPayloadDecoration, vulkanRayPayload, 0, 0) - INST(VulkanHitAttributesDecoration, vulkanHitAttributes, 0, 0) - INST(RequireGLSLVersionDecoration, requireGLSLVersion, 1, 0) - INST(RequireGLSLExtensionDecoration, requireGLSLExtension, 1, 0) - INST(ReadNoneDecoration, readNone, 0, 0) - INST(VulkanCallablePayloadDecoration, vulkanCallablePayload, 0, 0) - INST(EarlyDepthStencilDecoration, earlyDepthStencil, 0, 0) - INST(GloballyCoherentDecoration, globallyCoherent, 0, 0) - INST(PreciseDecoration, precise, 0, 0) - INST(PatchConstantFuncDecoration, patchConstantFunc, 1, 0) - - /// An `[entryPoint]` decoration marks a function that represents a shader entry point. - INST(EntryPointDecoration, entryPoint, 0, 0) - - /// A `[dependsOn(x)]` decoration indicates that the parent instruction depends on `x` - /// even if it does not otherwise reference it. - INST(DependsOnDecoration, dependsOn, 1, 0) - - /// A `[keepAlive]` decoration marks an instruction that should not be eliminated. - INST(KeepAliveDecoration, keepAlive, 0, 0) - - INST(BindExistentialSlotsDecoration, bindExistentialSlots, 0, 0) - - /// A `[format(f)]` decoration specifies that the format of an image should be `f` - INST(FormatDecoration, format, 1, 0) - - /* LinkageDecoration */ - INST(ImportDecoration, import, 1, 0) - INST(ExportDecoration, export, 1, 0) - INST_RANGE(LinkageDecoration, ImportDecoration, ExportDecoration) - -INST_RANGE(Decoration, HighLevelDeclDecoration, ExportDecoration) - - -// - -// A `makeExistential(v : C, w) : I` instruction takes a value `v` of type `C` -// and produces a value of interface type `I` by using the witness `w` which -// shows that `C` conforms to `I`. -// -INST(MakeExistential, makeExistential, 2, 0) - -// A `wrapExistential(v, T0,w0, T1,w0) : T` instruction is similar to `makeExistential`. -// but applies to a value `v` that is of type `BindExistentials(T, T0,w0, ...)`. The -// result of the `wrapExistentials` operation is a value of type `T`, allowing us to -// "smuggle" a value of specialized type into computations that expect an unspecialized type. -// -INST(WrapExistential, wrapExistential, 2, 0) - -INST(ExtractExistentialValue, extractExistentialValue, 1, 0) -INST(ExtractExistentialType, extractExistentialType, 1, 0) -INST(ExtractExistentialWitnessTable, extractExistentialWitnessTable, 1, 0) - -INST(ExtractTaggedUnionTag, extractTaggedUnionTag, 1, 0) -INST(ExtractTaggedUnionPayload, extractTaggedUnionPayload, 1, 0) - -INST(BitCast, bitCast, 1, 0) - -PSEUDO_INST(Pos) -PSEUDO_INST(PreInc) - -PSEUDO_INST(PreDec) -PSEUDO_INST(PostInc) -PSEUDO_INST(PostDec) -PSEUDO_INST(Sequence) -PSEUDO_INST(AddAssign) -PSEUDO_INST(SubAssign) -PSEUDO_INST(MulAssign) -PSEUDO_INST(DivAssign) -PSEUDO_INST(ModAssign) -PSEUDO_INST(AndAssign) -PSEUDO_INST(OrAssign) -PSEUDO_INST(XorAssign ) -PSEUDO_INST(LshAssign) -PSEUDO_INST(RshAssign) -PSEUDO_INST(Assign) -PSEUDO_INST(And) -PSEUDO_INST(Or) - - -#undef PSEUDO_INST -#undef PARENT -#undef USE_OTHER -#undef INST_RANGE -#undef INST - diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h deleted file mode 100644 index 20279e3aa..000000000 --- a/source/slang/ir-insts.h +++ /dev/null @@ -1,1343 +0,0 @@ -// ir-insts.h -#ifndef SLANG_IR_INSTS_H_INCLUDED -#define SLANG_IR_INSTS_H_INCLUDED - -// This file extends the core definitions in `ir.h` -// with a wider variety of concrete instructions, -// and a "builder" abstraction. -// -// TODO: the builder probably needs its own file. - -#include "compiler.h" -#include "ir.h" -#include "syntax.h" -#include "type-layout.h" - -namespace Slang { - -class Decl; - -struct IRDecoration : IRInst -{ - IR_PARENT_ISA(Decoration) - - IRDecoration* getNextDecoration() - { - return as(getNextInst()); - } -}; - -// Associates an IR-level decoration with a source declaration -// in the high-level AST, that can be used to extract -// additional information that informs code emission. -struct IRHighLevelDeclDecoration : IRDecoration -{ - enum { kOp = kIROp_HighLevelDeclDecoration }; - IR_LEAF_ISA(HighLevelDeclDecoration) - - IRPtrLit* getDeclOperand() { return cast(getOperand(0)); } - Decl* getDecl() { return (Decl*) getDeclOperand()->getValue(); } -}; - -// Associates an IR-level decoration with a source layout -struct IRLayoutDecoration : IRDecoration -{ - enum { kOp = kIROp_LayoutDecoration }; - IR_LEAF_ISA(LayoutDecoration) - - IRPtrLit* getLayoutOperand() { return cast(getOperand(0)); } - Layout* getLayout() { return (Layout*) getLayoutOperand()->getValue(); } -}; - -enum IRLoopControl -{ - kIRLoopControl_Unroll, -}; - -struct IRLoopControlDecoration : IRDecoration -{ - enum { kOp = kIROp_LoopControlDecoration }; - IR_LEAF_ISA(LoopControlDecoration) - - IRConstant* getModeOperand() { return cast(getOperand(0)); } - - IRLoopControl getMode() - { - return IRLoopControl(getModeOperand()->value.intVal); - } -}; - - -struct IRTargetSpecificDecoration : IRDecoration -{ - IR_PARENT_ISA(TargetSpecificDecoration) - - IRStringLit* getTargetNameOperand() { return cast(getOperand(0)); } - - UnownedStringSlice getTargetName() - { - return getTargetNameOperand()->getStringSlice(); - } -}; - -struct IRTargetDecoration : IRTargetSpecificDecoration -{ - enum { kOp = kIROp_TargetDecoration }; - IR_LEAF_ISA(TargetDecoration) -}; - -struct IRTargetIntrinsicDecoration : IRTargetSpecificDecoration -{ - enum { kOp = kIROp_TargetIntrinsicDecoration }; - IR_LEAF_ISA(TargetIntrinsicDecoration) - - IRStringLit* getDefinitionOperand() { return cast(getOperand(1)); } - - UnownedStringSlice getDefinition() - { - return getDefinitionOperand()->getStringSlice(); - } -}; - -struct IRGLSLOuterArrayDecoration : IRDecoration -{ - enum { kOp = kIROp_GLSLOuterArrayDecoration }; - IR_LEAF_ISA(GLSLOuterArrayDecoration) - - IRStringLit* getOuterArraynameOperand() { return cast(getOperand(0)); } - - UnownedStringSlice getOuterArrayName() - { - return getOuterArraynameOperand()->getStringSlice(); - } -}; - -// A decoration that marks a field key as having been associated -// with a particular simple semantic (e.g., `COLOR` or `SV_Position`, -// but not a `register` semantic). -// -// This is currently needed so that we can round-trip HLSL `struct` -// types that get used for varying input/output. This is an unfortunate -// case where some amount of "layout" information can't just come -// in via the `TypeLayout` part of things. -// -struct IRSemanticDecoration : IRDecoration -{ - enum { kOp = kIROp_SemanticDecoration }; - IR_LEAF_ISA(SemanticDecoration) - - IRStringLit* getSemanticNameOperand() { return cast(getOperand(0)); } - - UnownedStringSlice getSemanticName() - { - return getSemanticNameOperand()->getStringSlice(); - } -}; - -enum class IRInterpolationMode -{ - Linear, - NoPerspective, - NoInterpolation, - - Centroid, - Sample, -}; - -struct IRInterpolationModeDecoration : IRDecoration -{ - enum { kOp = kIROp_InterpolationModeDecoration }; - IR_LEAF_ISA(InterpolationModeDecoration) - - IRConstant* getModeOperand() { return cast(getOperand(0)); } - - IRInterpolationMode getMode() - { - return IRInterpolationMode(getModeOperand()->value.intVal); - } -}; - -/// A decoration that provides a desired name to be used -/// in conjunction with the given instruction. Back-end -/// code generation may use this to help derive symbol -/// names, emit debug information, etc. -struct IRNameHintDecoration : IRDecoration -{ - enum { kOp = kIROp_NameHintDecoration }; - IR_LEAF_ISA(NameHintDecoration) - - IRStringLit* getNameOperand() { return cast(getOperand(0)); } - - UnownedStringSlice getName() - { - return getNameOperand()->getStringSlice(); - } -}; - -#define IR_SIMPLE_DECORATION(NAME) \ - struct IR##NAME : IRDecoration \ - { \ - enum { kOp = kIROp_##NAME }; \ - IR_LEAF_ISA(NAME) \ - }; \ - /**/ - -/// A decoration that indicates that a variable represents -/// a vulkan ray payload, and should have a location assigned -/// to it. -IR_SIMPLE_DECORATION(VulkanRayPayloadDecoration) - -/// A decoration that indicates that a variable represents -/// a vulkan callable shader payload, and should have a location assigned -/// to it. -IR_SIMPLE_DECORATION(VulkanCallablePayloadDecoration) - -/// A decoration that indicates that a variable represents -/// vulkan hit attributes, and should have a location assigned -/// to it. -IR_SIMPLE_DECORATION(VulkanHitAttributesDecoration) - -struct IRRequireGLSLVersionDecoration : IRDecoration -{ - enum { kOp = kIROp_RequireGLSLVersionDecoration }; - IR_LEAF_ISA(RequireGLSLVersionDecoration) - - IRConstant* getLanguageVersionOperand() { return cast(getOperand(0)); } - - Int getLanguageVersion() - { - return Int(getLanguageVersionOperand()->value.intVal); - } -}; - -struct IRRequireGLSLExtensionDecoration : IRDecoration -{ - enum { kOp = kIROp_RequireGLSLExtensionDecoration }; - IR_LEAF_ISA(RequireGLSLExtensionDecoration) - - IRStringLit* getExtensionNameOperand() { return cast(getOperand(0)); } - - UnownedStringSlice getExtensionName() - { - return getExtensionNameOperand()->getStringSlice(); - } -}; - -IR_SIMPLE_DECORATION(ReadNoneDecoration) -IR_SIMPLE_DECORATION(EarlyDepthStencilDecoration) -IR_SIMPLE_DECORATION(GloballyCoherentDecoration) -IR_SIMPLE_DECORATION(PreciseDecoration) - - /// A decoration that marks a value as having linkage. - /// - /// A value with linkage is either exported from its module, - /// or will have a definition imported from another module. - /// In either case, it requires a mangled name to use when - /// matching imports and exports. - /// -struct IRLinkageDecoration : IRDecoration -{ - IR_PARENT_ISA(LinkageDecoration) - - IRStringLit* getMangledNameOperand() { return cast(getOperand(0)); } - - UnownedStringSlice getMangledName() - { - return getMangledNameOperand()->getStringSlice(); - } -}; - -struct IRImportDecoration : IRLinkageDecoration -{ - enum { kOp = kIROp_ImportDecoration }; - IR_LEAF_ISA(ImportDecoration) -}; - -struct IRExportDecoration : IRLinkageDecoration -{ - enum { kOp = kIROp_ExportDecoration }; - IR_LEAF_ISA(ExportDecoration) -}; - -struct IRFormatDecoration : IRDecoration -{ - enum { kOp = kIROp_FormatDecoration }; - IR_LEAF_ISA(FormatDecoration) - - IRConstant* getFormatOperand() { return cast(getOperand(0)); } - - ImageFormat getFormat() - { - return ImageFormat(getFormatOperand()->value.intVal); - } -}; - -// An instruction that specializes another IR value -// (representing a generic) to a particular set of generic arguments -// (instructions representing types, witness tables, etc.) -struct IRSpecialize : IRInst -{ - // The "base" for the call is the generic to be specialized - IRUse base; - IRInst* getBase() { return getOperand(0); } - - // after the generic value come the arguments - UInt getArgCount() { return getOperandCount() - 1; } - IRInst* getArg(UInt index) { return getOperand(index + 1); } - - IR_LEAF_ISA(Specialize) -}; - -// An instruction that looks up the implementation -// of an interface operation identified by `requirementDeclRef` -// in the witness table `witnessTable` which should -// hold the conformance information for a specific type. -struct IRLookupWitnessMethod : IRInst -{ - IRUse witnessTable; - IRUse requirementKey; - - IRInst* getWitnessTable() { return witnessTable.get(); } - IRInst* getRequirementKey() { return requirementKey.get(); } -}; - -struct IRLookupWitnessTable : IRInst -{ - IRUse sourceType; - IRUse interfaceType; -}; - -// - -struct IRCall : IRInst -{ - IR_LEAF_ISA(Call) - - IRInst* getCallee() { return getOperand(0); } - - UInt getArgCount() { return getOperandCount() - 1; } - IRInst* getArg(UInt index) { return getOperand(index + 1); } -}; - -struct IRLoad : IRInst -{ - IRUse ptr; -}; - -struct IRStore : IRInst -{ - IRUse ptr; - IRUse val; -}; - -struct IRFieldExtract : IRInst -{ - IRUse base; - IRUse field; - - IRInst* getBase() { return base.get(); } - IRInst* getField() { return field.get(); } -}; - -struct IRFieldAddress : IRInst -{ - IRUse base; - IRUse field; - - IRInst* getBase() { return base.get(); } - IRInst* getField() { return field.get(); } -}; - -// Terminators - -struct IRReturn : IRTerminatorInst -{}; - -struct IRReturnVal : IRReturn -{ - IRUse val; - - IRInst* getVal() { return val.get(); } -}; - -struct IRReturnVoid : IRReturn -{}; - -struct IRDiscard : IRTerminatorInst -{}; - -// Signals that this point in the code should be unreachable. -// We can/should emit a dataflow error if we can ever determine -// that a block ending in one of these can actually be -// executed. -struct IRUnreachable : IRTerminatorInst -{ - IR_PARENT_ISA(Unreachable); -}; - -struct IRMissingReturn : IRUnreachable -{ - IR_LEAF_ISA(MissingReturn); -}; - -struct IRBlock; - -struct IRUnconditionalBranch : IRTerminatorInst -{ - IRUse block; - - IRBlock* getTargetBlock() { return (IRBlock*)block.get(); } - - UInt getArgCount(); - IRUse* getArgs(); - IRInst* getArg(UInt index); - - IR_PARENT_ISA(UnconditionalBranch); -}; - -// Special cases of unconditional branch, to handle -// structured control flow: -struct IRBreak : IRUnconditionalBranch {}; -struct IRContinue : IRUnconditionalBranch {}; - -// The start of a loop is a special control-flow -// instruction, that records relevant information -// about the loop structure: -struct IRLoop : IRUnconditionalBranch -{ - // The next block after the loop, which - // is where we expect control flow to - // re-converge, and also where a - // `break` will target. - IRUse breakBlock; - - // The block where control flow will go - // on a `continue`. - IRUse continueBlock; - - IRBlock* getBreakBlock() { return (IRBlock*)breakBlock.get(); } - IRBlock* getContinueBlock() { return (IRBlock*)continueBlock.get(); } -}; - -struct IRConditionalBranch : IRTerminatorInst -{ - IR_PARENT_ISA(ConditionalBranch) - - IRUse condition; - IRUse trueBlock; - IRUse falseBlock; - - IRInst* getCondition() { return condition.get(); } - IRBlock* getTrueBlock() { return (IRBlock*)trueBlock.get(); } - IRBlock* getFalseBlock() { return (IRBlock*)falseBlock.get(); } -}; - -// A conditional branch that represent the test inside a loop -struct IRLoopTest : IRConditionalBranch -{ -}; - -// A conditional branch that represents a one-sided `if`: -// -// if( ) { } -// -struct IRIf : IRConditionalBranch -{ - IRBlock* getAfterBlock() { return getFalseBlock(); } -}; - -// A conditional branch that represents a two-sided `if`: -// -// if( ) { } -// else { } -// -// -struct IRIfElse : IRConditionalBranch -{ - IRUse afterBlock; - - IRBlock* getAfterBlock() { return (IRBlock*)afterBlock.get(); } -}; - -// A multi-way branch that represents a source-level `switch` -struct IRSwitch : IRTerminatorInst -{ - IR_LEAF_ISA(Switch); - - IRUse condition; - IRUse breakLabel; - IRUse defaultLabel; - - IRInst* getCondition() { return condition.get(); } - IRBlock* getBreakLabel() { return (IRBlock*) breakLabel.get(); } - IRBlock* getDefaultLabel() { return (IRBlock*) defaultLabel.get(); } - - // remaining args are: caseVal, caseLabel, ... - - UInt getCaseCount() { return (getOperandCount() - 3) / 2; } - IRInst* getCaseValue(UInt index) { return getOperand(3 + index*2 + 0); } - IRBlock* getCaseLabel(UInt index) { return (IRBlock*) getOperand(3 + index*2 + 1); } -}; - -struct IRSwizzle : IRInst -{ - IRUse base; - - IRInst* getBase() { return base.get(); } - UInt getElementCount() - { - return getOperandCount() - 1; - } - IRInst* getElementIndex(UInt index) - { - return getOperand(index + 1); - } -}; - -struct IRSwizzleSet : IRInst -{ - IRUse base; - IRUse source; - - IRInst* getBase() { return base.get(); } - IRInst* getSource() { return source.get(); } - UInt getElementCount() - { - return getOperandCount() - 2; - } - IRInst* getElementIndex(UInt index) - { - return getOperand(index + 2); - } -}; - -struct IRSwizzledStore : IRInst -{ - IRInst* getDest() { return getOperand(0); } - IRInst* getSource() { return getOperand(1); } - UInt getElementCount() - { - return getOperandCount() - 2; - } - IRInst* getElementIndex(UInt index) - { - return getOperand(index + 2); - } - - IR_LEAF_ISA(SwizzledStore) -}; - - -struct IRPatchConstantFuncDecoration : IRDecoration -{ - enum { kOp = kIROp_PatchConstantFuncDecoration }; - IR_LEAF_ISA(PatchConstantFuncDecoration) - - IRInst* getFunc() { return getOperand(0); } -}; - -// An IR `var` instruction conceptually represents -// a stack allocation of some memory. -struct IRVar : IRInst -{ - IRPtrType* getDataType() - { - return cast(IRInst::getDataType()); - } - - static bool isaImpl(IROp op) { return op == kIROp_Var; } -}; - -/// @brief A global variable. -/// -/// Represents a global variable in the IR. -/// If the variable has an initializer, then -/// it is represented by the code in the basic -/// blocks nested inside this value. -struct IRGlobalVar : IRGlobalValueWithCode -{ - IRPtrType* getDataType() - { - return cast(IRInst::getDataType()); - } -}; - -/// @brief A global constant. -/// -/// Represents a global-scope constant value in the IR. -/// The initializer for the constant is represented by -/// the code in the basic block(s) nested in this value. -struct IRGlobalConstant : IRGlobalValueWithCode -{ - IR_LEAF_ISA(GlobalConstant) -}; - -struct IRGlobalParam : IRInst -{ - IR_LEAF_ISA(GlobalParam) -}; - - -// An entry in a witness table (see below) -struct IRWitnessTableEntry : IRInst -{ - // The AST-level requirement - IRUse requirementKey; - - // The IR-level value that satisfies the requirement - IRUse satisfyingVal; - - IRInst* getRequirementKey() { return getOperand(0); } - IRInst* getSatisfyingVal() { return getOperand(1); } - - IR_LEAF_ISA(WitnessTableEntry) -}; - -// A witness table is a global value that stores -// information about how a type conforms to some -// interface. It basically takes the form of a -// map from the required members of the interface -// to the IR values that satisfy those requirements. -struct IRWitnessTable : IRInst -{ - IRInstList getEntries() - { - return IRInstList(getChildren()); - } - - IR_LEAF_ISA(WitnessTable) -}; - -// An instruction that yields an undefined value. -// -// Note that we make this an instruction rather than a value, -// so that we will be able to identify a variable that is -// used when undefined. -struct IRUndefined : IRInst -{ -}; - -// A global-scope generic parameter (a type parameter, a -// constraint parameter, etc.) -struct IRGlobalGenericParam : IRInst -{ - IR_LEAF_ISA(GlobalGenericParam) -}; - -// An instruction that binds a global generic parameter -// to a particular value. -struct IRBindGlobalGenericParam : IRInst -{ - IRGlobalGenericParam* getParam() { return cast(getOperand(0)); } - IRInst* getVal() { return getOperand(1); } - - IR_LEAF_ISA(BindGlobalGenericParam) -}; - - - /// An instruction that packs a concrete value into an existential-type "box" -struct IRMakeExistential : IRInst -{ - IRInst* getWrappedValue() { return getOperand(0); } - IRInst* getWitnessTable() { return getOperand(1); } - - IR_LEAF_ISA(MakeExistential) -}; - - /// Generalizes `IRMakeExistential` by allowing a type with existential sub-fields to be boxed -struct IRWrapExistential : IRInst -{ - IRInst* getWrappedValue() { return getOperand(0); } - - UInt getSlotOperandCount() { return getOperandCount() - 1; } - IRInst* getSlotOperand(UInt index) { return getOperand(index + 1); } - IRUse* getSlotOperands() { return getOperands() + 1; } - - IR_LEAF_ISA(WrapExistential) -}; - - -// Description of an instruction to be used for global value numbering -struct IRInstKey -{ - IRInst* inst; - - int GetHashCode(); -}; - -bool operator==(IRInstKey const& left, IRInstKey const& right); - -struct IRConstantKey -{ - IRConstant* inst; - - bool operator==(const IRConstantKey& rhs) const { return inst->equal(rhs.inst); } - int GetHashCode() const { return inst->getHashCode(); } -}; - -struct SharedIRBuilder -{ - // The parent compilation session - Session* session; - Session* getSession() - { - return session; - } - - // The module that will own all of the IR - IRModule* module; - - Dictionary globalValueNumberingMap; - Dictionary constantMap; -}; - -struct IRBuilderSourceLocRAII; - -struct IRBuilder -{ - // Shared state for all IR builders working on the same module - SharedIRBuilder* sharedBuilder; - - Session* getSession() - { - return sharedBuilder->getSession(); - } - - IRModule* getModule() { return sharedBuilder->module; } - - // The current parent being inserted into (this might - // be the global scope, a function, a block inside - // a function, etc.) - IRInst* insertIntoParent = nullptr; - // - // An instruction in the current parent that we should insert before - IRInst* insertBeforeInst = nullptr; - - // Get the current basic block we are inserting into (if any) - IRBlock* getBlock(); - - // Get the current function (or other value with code) - // that we are inserting into (if any). - IRGlobalValueWithCode* getFunc(); - - void setInsertInto(IRInst* insertInto); - void setInsertBefore(IRInst* insertBefore); - - IRBuilderSourceLocRAII* sourceLocInfo = nullptr; - - void addInst(IRInst* inst); - - IRInst* getBoolValue(bool value); - IRInst* getIntValue(IRType* type, IRIntegerValue value); - IRInst* getFloatValue(IRType* type, IRFloatingPointValue value); - IRStringLit* getStringValue(const UnownedStringSlice& slice); - IRPtrLit* getPtrValue(void* value); - - IRBasicType* getBasicType(BaseType baseType); - IRBasicType* getVoidType(); - IRBasicType* getBoolType(); - IRBasicType* getIntType(); - IRStringType* getStringType(); - - IRBasicBlockType* getBasicBlockType(); - IRType* getWitnessTableType() { return nullptr; } - IRType* getKeyType() { return nullptr; } - - IRTypeKind* getTypeKind(); - IRGenericKind* getGenericKind(); - - IRPtrType* getPtrType(IRType* valueType); - IROutType* getOutType(IRType* valueType); - IRInOutType* getInOutType(IRType* valueType); - IRRefType* getRefType(IRType* valueType); - IRPtrTypeBase* getPtrType(IROp op, IRType* valueType); - - IRArrayTypeBase* getArrayTypeBase( - IROp op, - IRType* elementType, - IRInst* elementCount); - - IRArrayType* getArrayType( - IRType* elementType, - IRInst* elementCount); - - IRUnsizedArrayType* getUnsizedArrayType( - IRType* elementType); - - IRVectorType* getVectorType( - IRType* elementType, - IRInst* elementCount); - - IRMatrixType* getMatrixType( - IRType* elementType, - IRInst* rowCount, - IRInst* columnCount); - - IRFuncType* getFuncType( - UInt paramCount, - IRType* const* paramTypes, - IRType* resultType); - - IRFuncType* getFuncType( - List const& paramTypes, - IRType* resultType) - { - return getFuncType(paramTypes.getCount(), paramTypes.getBuffer(), resultType); - } - - IRConstantBufferType* getConstantBufferType( - IRType* elementType); - - IRConstExprRate* getConstExprRate(); - IRGroupSharedRate* getGroupSharedRate(); - - IRRateQualifiedType* getRateQualifiedType( - IRRate* rate, - IRType* dataType); - - IRType* getTaggedUnionType( - UInt caseCount, - IRType* const* caseTypes); - - IRType* getTaggedUnionType( - List const& caseTypes) - { - return getTaggedUnionType(caseTypes.getCount(), caseTypes.getBuffer()); - } - - IRType* getBindExistentialsType( - IRInst* baseType, - UInt slotArgCount, - IRInst* const* slotArgs); - - IRType* getBindExistentialsType( - IRInst* baseType, - UInt slotArgCount, - IRUse const* slotArgs); - - // Set the data type of an instruction, while preserving - // its rate, if any. - void setDataType(IRInst* inst, IRType* dataType); - - /// Given an existential value, extract the underlying "real" value - IRInst* emitExtractExistentialValue( - IRType* type, - IRInst* existentialValue); - - /// Given an existential value, extract the underlying "real" type - IRType* emitExtractExistentialType( - IRInst* existentialValue); - - /// Given an existential value, extract the witness table showing how the value conforms to the existential type. - IRInst* emitExtractExistentialWitnessTable( - IRInst* existentialValue); - - IRInst* emitSpecializeInst( - IRType* type, - IRInst* genericVal, - UInt argCount, - IRInst* const* args); - - IRInst* emitLookupInterfaceMethodInst( - IRType* type, - IRInst* witnessTableVal, - IRInst* interfaceMethodVal); - - IRInst* emitCallInst( - IRType* type, - IRInst* func, - UInt argCount, - IRInst* const* args); - - IRInst* emitCallInst( - IRType* type, - IRInst* func, - List const& args) - { - return emitCallInst(type, func, args.getCount(), args.getBuffer()); - } - - IRInst* createIntrinsicInst( - IRType* type, - IROp op, - UInt argCount, - IRInst* const* args); - - IRInst* emitIntrinsicInst( - IRType* type, - IROp op, - UInt argCount, - IRInst* const* args); - - IRInst* emitConstructorInst( - IRType* type, - UInt argCount, - IRInst* const* args); - - IRInst* emitMakeVector( - IRType* type, - UInt argCount, - IRInst* const* args); - - IRInst* emitMakeVector( - IRType* type, - List const& args) - { - return emitMakeVector(type, args.getCount(), args.getBuffer()); - } - - IRInst* emitMakeMatrix( - IRType* type, - UInt argCount, - IRInst* const* args); - - IRInst* emitMakeArray( - IRType* type, - UInt argCount, - IRInst* const* args); - - IRInst* emitMakeStruct( - IRType* type, - UInt argCount, - IRInst* const* args); - - IRInst* emitMakeStruct( - IRType* type, - List const& args) - { - return emitMakeStruct(type, args.getCount(), args.getBuffer()); - } - - IRInst* emitMakeExistential( - IRType* type, - IRInst* value, - IRInst* witnessTable); - - IRInst* emitWrapExistential( - IRType* type, - IRInst* value, - UInt slotArgCount, - IRInst* const* slotArgs); - - IRInst* emitWrapExistential( - IRType* type, - IRInst* value, - UInt slotArgCount, - IRUse const* slotArgs) - { - List slotArgVals; - for(UInt ii = 0; ii < slotArgCount; ++ii) - slotArgVals.add(slotArgs[ii].get()); - - return emitWrapExistential(type, value, slotArgCount, slotArgVals.getBuffer()); - } - - IRUndefined* emitUndefined(IRType* type); - - - - IRModule* createModule(); - - IRFunc* createFunc(); - IRGlobalVar* createGlobalVar( - IRType* valueType); - IRGlobalConstant* createGlobalConstant( - IRType* valueType); - IRGlobalParam* createGlobalParam( - IRType* valueType); - IRWitnessTable* createWitnessTable(); - IRWitnessTableEntry* createWitnessTableEntry( - IRWitnessTable* witnessTable, - IRInst* requirementKey, - IRInst* satisfyingVal); - - // Create an initially empty `struct` type. - IRStructType* createStructType(); - - // Create an empty `interface` type. - IRInterfaceType* createInterfaceType(); - - // Create a global "key" to use for indexing into a `struct` type. - IRStructKey* createStructKey(); - - // Create a field nested in a struct type, declaring that - // the specified field key maps to a field with the specified type. - IRStructField* createStructField( - IRStructType* structType, - IRStructKey* fieldKey, - IRType* fieldType); - - IRGeneric* createGeneric(); - IRGeneric* emitGeneric(); - - // Low-level operation for creating a type. - IRType* getType( - IROp op, - UInt operandCount, - IRInst* const* operands); - IRType* getType( - IROp op); - - /// Create an empty basic block. - /// - /// The created block will not be inserted into the current - /// function; call `insertBlock()` to attach the block - /// at an appropriate point. - /// - IRBlock* createBlock(); - - /// Insert a block into the current function. - /// - /// This attaches the given `block` to the current function, - /// and makes it the current block for - /// new instructions that get emitted. - /// - void insertBlock(IRBlock* block); - - /// Emit a new block into the current function. - /// - /// This function is equivalent to using `createBlock()` - /// and then `insertBlock()`. - /// - IRBlock* emitBlock(); - - - - IRParam* createParam( - IRType* type); - IRParam* emitParam( - IRType* type); - - IRVar* emitVar( - IRType* type); - - IRInst* emitLoad( - IRType* type, - IRInst* ptr); - - IRInst* emitLoad( - IRInst* ptr); - - IRInst* emitStore( - IRInst* dstPtr, - IRInst* srcVal); - - IRInst* emitFieldExtract( - IRType* type, - IRInst* base, - IRInst* field); - - IRInst* emitFieldAddress( - IRType* type, - IRInst* basePtr, - IRInst* field); - - IRInst* emitElementExtract( - IRType* type, - IRInst* base, - IRInst* index); - - IRInst* emitElementAddress( - IRType* type, - IRInst* basePtr, - IRInst* index); - - IRInst* emitSwizzle( - IRType* type, - IRInst* base, - UInt elementCount, - IRInst* const* elementIndices); - - IRInst* emitSwizzle( - IRType* type, - IRInst* base, - UInt elementCount, - UInt const* elementIndices); - - IRInst* emitSwizzleSet( - IRType* type, - IRInst* base, - IRInst* source, - UInt elementCount, - IRInst* const* elementIndices); - - IRInst* emitSwizzleSet( - IRType* type, - IRInst* base, - IRInst* source, - UInt elementCount, - UInt const* elementIndices); - - IRInst* emitSwizzledStore( - IRInst* dest, - IRInst* source, - UInt elementCount, - IRInst* const* elementIndices); - - IRInst* emitSwizzledStore( - IRInst* dest, - IRInst* source, - UInt elementCount, - UInt const* elementIndices); - - - - IRInst* emitReturn( - IRInst* val); - - IRInst* emitReturn(); - - IRInst* emitDiscard(); - - IRInst* emitUnreachable(); - IRInst* emitMissingReturn(); - - IRInst* emitBranch( - IRBlock* block); - - IRInst* emitBreak( - IRBlock* target); - - IRInst* emitContinue( - IRBlock* target); - - IRInst* emitLoop( - IRBlock* target, - IRBlock* breakBlock, - IRBlock* continueBlock); - - IRInst* emitBranch( - IRInst* val, - IRBlock* trueBlock, - IRBlock* falseBlock); - - IRInst* emitIf( - IRInst* val, - IRBlock* trueBlock, - IRBlock* afterBlock); - - IRInst* emitIfElse( - IRInst* val, - IRBlock* trueBlock, - IRBlock* falseBlock, - IRBlock* afterBlock); - - IRInst* emitLoopTest( - IRInst* val, - IRBlock* bodyBlock, - IRBlock* breakBlock); - - IRInst* emitSwitch( - IRInst* val, - IRBlock* breakLabel, - IRBlock* defaultLabel, - UInt caseArgCount, - IRInst* const* caseArgs); - - IRGlobalGenericParam* emitGlobalGenericParam(); - - IRBindGlobalGenericParam* emitBindGlobalGenericParam( - IRInst* param, - IRInst* val); - - IRInst* emitBindGlobalExistentialSlots( - UInt argCount, - IRInst* const* args); - - IRDecoration* addBindExistentialSlotsDecoration( - IRInst* value, - UInt argCount, - IRInst* const* args); - - IRInst* emitExtractTaggedUnionTag( - IRInst* val); - - IRInst* emitExtractTaggedUnionPayload( - IRType* type, - IRInst* val, - IRInst* tag); - - IRInst* emitBitCast( - IRType* type, - IRInst* val); - - // - // Decorations - // - - IRDecoration* addDecoration(IRInst* value, IROp op, IRInst* const* operands, Int operandCount); - - IRDecoration* addDecoration(IRInst* value, IROp op) - { - return addDecoration(value, op, (IRInst* const*) nullptr, 0); - } - - IRDecoration* addDecoration(IRInst* value, IROp op, IRInst* operand) - { - return addDecoration(value, op, &operand, 1); - } - - IRDecoration* addDecoration(IRInst* value, IROp op, IRInst* operand0, IRInst* operand1) - { - IRInst* operands[] = { operand0, operand1 }; - return addDecoration(value, op, operands, SLANG_COUNT_OF(operands)); - } - - template - T* addRefObjectToFree(T* ptr) - { - getModule()->getObjectScopeManager()->addMaybeNull(ptr); - return ptr; - } - - template - void addSimpleDecoration(IRInst* value) - { - addDecoration(value, IROp(T::kOp), (IRInst* const*) nullptr, 0); - } - - void addHighLevelDeclDecoration(IRInst* value, Decl* decl); - void addLayoutDecoration(IRInst* value, Layout* layout); - - void addNameHintDecoration(IRInst* value, IRStringLit* name) - { - addDecoration(value, kIROp_NameHintDecoration, name); - } - - void addNameHintDecoration(IRInst* value, UnownedStringSlice const& text) - { - addNameHintDecoration(value, getStringValue(text)); - } - - void addGLSLOuterArrayDecoration(IRInst* value, UnownedStringSlice const& text) - { - addDecoration(value, kIROp_GLSLOuterArrayDecoration, getStringValue(text)); - } - - void addInterpolationModeDecoration(IRInst* value, IRInterpolationMode mode) - { - addDecoration(value, kIROp_InterpolationModeDecoration, getIntValue(getIntType(), IRIntegerValue(mode))); - } - - void addLoopControlDecoration(IRInst* value, IRLoopControl mode) - { - addDecoration(value, kIROp_LoopControlDecoration, getIntValue(getIntType(), IRIntegerValue(mode))); - } - - void addSemanticDecoration(IRInst* value, UnownedStringSlice const& text) - { - addDecoration(value, kIROp_SemanticDecoration, getStringValue(text)); - } - - void addTargetIntrinsicDecoration(IRInst* value, UnownedStringSlice const& target, UnownedStringSlice const& definition) - { - addDecoration(value, kIROp_TargetIntrinsicDecoration, getStringValue(target), getStringValue(definition)); - } - - void addTargetDecoration(IRInst* value, UnownedStringSlice const& target) - { - addDecoration(value, kIROp_TargetDecoration, getStringValue(target)); - } - - void addRequireGLSLExtensionDecoration(IRInst* value, UnownedStringSlice const& extensionName) - { - addDecoration(value, kIROp_RequireGLSLExtensionDecoration, getStringValue(extensionName)); - } - - void addRequireGLSLVersionDecoration(IRInst* value, Int version) - { - addDecoration(value, kIROp_RequireGLSLVersionDecoration, getIntValue(getIntType(), IRIntegerValue(version))); - } - - void addPatchConstantFuncDecoration(IRInst* value, IRInst* patchConstantFunc) - { - addDecoration(value, kIROp_PatchConstantFuncDecoration, patchConstantFunc); - } - - void addImportDecoration(IRInst* value, UnownedStringSlice const& mangledName) - { - addDecoration(value, kIROp_ImportDecoration, getStringValue(mangledName)); - } - - void addExportDecoration(IRInst* value, UnownedStringSlice const& mangledName) - { - addDecoration(value, kIROp_ExportDecoration, getStringValue(mangledName)); - } - - void addEntryPointDecoration(IRInst* value) - { - addDecoration(value, kIROp_EntryPointDecoration); - } - - void addKeepAliveDecoration(IRInst* value) - { - addDecoration(value, kIROp_KeepAliveDecoration); - } - - /// Add a decoration that indicates that the given `inst` depends on the given `dependency`. - /// - /// This decoration can be used to ensure that a value that an instruction - /// implicitly depends on cannot be eliminated so long as the instruction - /// itself is kept alive. - /// - void addDependsOnDecoration(IRInst* inst, IRInst* dependency) - { - addDecoration(inst, kIROp_DependsOnDecoration, dependency); - } - - void addFormatDecoration(IRInst* inst, ImageFormat format) - { - addFormatDecoration(inst, getIntValue(getIntType(), IRIntegerValue(format))); - } - - void addFormatDecoration(IRInst* inst, IRInst* format) - { - addDecoration(inst, kIROp_FormatDecoration, format); - } -}; - -void addHoistableInst( - IRBuilder* builder, - IRInst* inst); - -// Helper to establish the source location that will be used -// by an IRBuilder. -struct IRBuilderSourceLocRAII -{ - IRBuilder* builder; - SourceLoc sourceLoc; - IRBuilderSourceLocRAII* next; - - IRBuilderSourceLocRAII( - IRBuilder* builder, - SourceLoc sourceLoc) - : builder(builder) - , sourceLoc(sourceLoc) - , next(nullptr) - { - next = builder->sourceLocInfo; - builder->sourceLocInfo = this; - } - - ~IRBuilderSourceLocRAII() - { - SLANG_ASSERT(builder->sourceLocInfo == this); - builder->sourceLocInfo = next; - } -}; - -// - -void markConstExpr( - IRBuilder* builder, - IRInst* irValue); - -// - -IRTargetIntrinsicDecoration* findTargetIntrinsicDecoration( - IRInst* val, - String const& targetName); - -} - -#endif diff --git a/source/slang/ir-legalize-types.cpp b/source/slang/ir-legalize-types.cpp deleted file mode 100644 index 18039315e..000000000 --- a/source/slang/ir-legalize-types.cpp +++ /dev/null @@ -1,2626 +0,0 @@ -// ir-legalize-types.cpp - -// This file implements type legalization for the IR. -// It uses the core legalization logic in -// `legalize-types.{h,cpp}` to decide what to do with -// the types, while this file handles the actual -// rewriting of the IR to use the new types. -// -// This pass should only be applied to IR that has been -// fully specialized (no more generics/interfaces), so -// that the concrete type of everything is known. - -#include "ir.h" -#include "ir-clone.h" -#include "ir-insts.h" -#include "legalize-types.h" -#include "mangle.h" -#include "name.h" - -namespace Slang -{ - -LegalVal LegalVal::tuple(RefPtr tupleVal) -{ - SLANG_ASSERT(tupleVal->elements.getCount()); - - LegalVal result; - result.flavor = LegalVal::Flavor::tuple; - result.obj = tupleVal; - return result; -} - -LegalVal LegalVal::pair(RefPtr pairInfo) -{ - LegalVal result; - result.flavor = LegalVal::Flavor::pair; - result.obj = pairInfo; - return result; -} - -LegalVal LegalVal::pair( - LegalVal const& ordinaryVal, - LegalVal const& specialVal, - RefPtr pairInfo) -{ - if (ordinaryVal.flavor == LegalVal::Flavor::none) - return specialVal; - - if (specialVal.flavor == LegalVal::Flavor::none) - return ordinaryVal; - - - RefPtr obj = new PairPseudoVal(); - obj->ordinaryVal = ordinaryVal; - obj->specialVal = specialVal; - obj->pairInfo = pairInfo; - - return LegalVal::pair(obj); -} - -LegalVal LegalVal::implicitDeref(LegalVal const& val) -{ - RefPtr implicitDerefVal = new ImplicitDerefVal(); - implicitDerefVal->val = val; - - LegalVal result; - result.flavor = LegalVal::Flavor::implicitDeref; - result.obj = implicitDerefVal; - return result; -} - -LegalVal LegalVal::getImplicitDeref() -{ - SLANG_ASSERT(flavor == Flavor::implicitDeref); - return as(obj)->val; -} - -LegalVal LegalVal::wrappedBuffer( - LegalVal const& baseVal, - LegalElementWrapping const& elementInfo) -{ - RefPtr obj = new WrappedBufferPseudoVal(); - obj->base = baseVal; - obj->elementInfo = elementInfo; - - LegalVal result; - result.flavor = LegalVal::Flavor::wrappedBuffer; - result.obj = obj; - return result; -} - -// - -IRTypeLegalizationContext::IRTypeLegalizationContext( - IRModule* inModule) -{ - session = inModule->getSession(); - module = inModule; - - auto sharedBuilder = &sharedBuilderStorage; - sharedBuilder->session = session; - sharedBuilder->module = module; - - builder = &builderStorage; - builder->sharedBuilder = sharedBuilder; -} - -static void registerLegalizedValue( - IRTypeLegalizationContext* context, - IRInst* irValue, - LegalVal const& legalVal) -{ - context->mapValToLegalVal[irValue] = legalVal; -} - -struct IRGlobalNameInfo -{ - IRInst* globalVar; - UInt counter; -}; - -static LegalVal declareVars( - IRTypeLegalizationContext* context, - IROp op, - LegalType type, - TypeLayout* typeLayout, - LegalVarChain const& varChain, - UnownedStringSlice nameHint, - IRInst* leafVar, - IRGlobalNameInfo* globalNameInfo, - bool isSpecial); - - /// Unwrap a value with flavor `wrappedBuffer` - /// - /// The original `legalPtrOperand` has a wrapped-buffer type - /// which encodes the way that, e.g., a `ConstantBuffer` - /// where `Foo` includes interface types, got legalized - /// into a buffer that stores a `Foo` value plus addition - /// fields for the concrete types that got plugged in. - /// - /// The `elementInfo` is the layout information for the - /// modified ("wrapped") buffer type, and specifies how - /// the logical element type was expanded into actual fields. - /// - /// This function returns a new value that undoes all of - /// the wrapping and produces a new `LegalVal` that matches - /// the nominal type of the original buffer. - /// -static LegalVal unwrapBufferValue( - IRTypeLegalizationContext* context, - LegalVal legalPtrOperand, - LegalElementWrapping const& elementInfo); - - /// Perform any actions required to materialize `val` into a usable value. - /// - /// Certain case of `LegalVal` (currently just the `wrappedBuffer` case) are - /// suitable for use to represent a variable, but cannot be used directly - /// in computations, because their structured needs to be "unwrapped." - /// - /// This function unwraps any `val` that needs it, which may involve - /// emitting additional IR instructions, and returns the unmodified - /// `val` otherwise. - /// -static LegalVal maybeMaterializeWrappedValue( - IRTypeLegalizationContext* context, - LegalVal val) -{ - if(val.flavor != LegalVal::Flavor::wrappedBuffer) - return val; - - auto wrappedBufferVal = val.getWrappedBuffer(); - return unwrapBufferValue( - context, - wrappedBufferVal->base, - wrappedBufferVal->elementInfo); -} - -// Take a value that is being used as an operand, -// and turn it into the equivalent legalized value. -static LegalVal legalizeOperand( - IRTypeLegalizationContext* context, - IRInst* irValue) -{ - LegalVal legalVal; - if( context->mapValToLegalVal.TryGetValue(irValue, legalVal) ) - { - return maybeMaterializeWrappedValue(context, legalVal); - } - - // For now, assume that anything not covered - // by the mapping is legal as-is. - - return LegalVal::simple(irValue); -} - -static void getArgumentValues( - List & instArgs, - LegalVal val) -{ - switch (val.flavor) - { - case LegalVal::Flavor::none: - break; - - case LegalVal::Flavor::simple: - instArgs.add(val.getSimple()); - break; - - case LegalVal::Flavor::implicitDeref: - getArgumentValues(instArgs, val.getImplicitDeref()); - break; - - case LegalVal::Flavor::pair: - { - auto pairVal = val.getPair(); - getArgumentValues(instArgs, pairVal->ordinaryVal); - getArgumentValues(instArgs, pairVal->specialVal); - } - break; - - case LegalVal::Flavor::tuple: - { - auto tuplePsuedoVal = val.getTuple(); - for (auto elem : val.getTuple()->elements) - { - getArgumentValues(instArgs, elem.val); - } - } - break; - - default: - SLANG_UNEXPECTED("uhandled val flavor"); - break; - } -} - -static LegalVal legalizeCall( - IRTypeLegalizationContext* context, - IRCall* callInst) -{ - auto retType = legalizeType(context, callInst->getFullType()); - IRType* retIRType = nullptr; - switch (retType.flavor) - { - case LegalType::Flavor::simple: - retIRType = retType.getSimple(); - break; - case LegalType::Flavor::none: - retIRType = context->builder->getVoidType(); - break; - default: - // TODO: implement legalization of non-simple return types - SLANG_UNEXPECTED("unimplemented legalized return type for IRInstCall."); - } - - List instArgs; - for (auto i = 1u; i < callInst->getOperandCount(); i++) - getArgumentValues(instArgs, legalizeOperand(context, callInst->getOperand(i))); - - return LegalVal::simple(context->builder->emitCallInst( - retIRType, - callInst->getCallee(), - instArgs.getCount(), - instArgs.getBuffer())); -} - -static LegalVal legalizeRetVal(IRTypeLegalizationContext* context, - LegalVal retVal) -{ - switch (retVal.flavor) - { - case LegalVal::Flavor::simple: - return LegalVal::simple(context->builder->emitReturn(retVal.getSimple())); - case LegalVal::Flavor::none: - return LegalVal::simple(context->builder->emitReturn()); - default: - // TODO: implement legalization of non-simple return types - SLANG_UNEXPECTED("unimplemented legalized return type for IRReturnVal."); - } -} - -static LegalVal legalizeLoad( - IRTypeLegalizationContext* context, - LegalVal legalPtrVal) -{ - switch (legalPtrVal.flavor) - { - case LegalVal::Flavor::none: - return LegalVal(); - - case LegalVal::Flavor::simple: - { - return LegalVal::simple( - context->builder->emitLoad(legalPtrVal.getSimple())); - } - break; - - case LegalVal::Flavor::implicitDeref: - // We have turne a pointer(-like) type into its pointed-to (value) - // type, and so the operation of loading goes away; we just use - // the underlying value. - return legalPtrVal.getImplicitDeref(); - - case LegalVal::Flavor::pair: - { - auto ptrPairVal = legalPtrVal.getPair(); - - auto ordinaryVal = legalizeLoad(context, ptrPairVal->ordinaryVal); - auto specialVal = legalizeLoad(context, ptrPairVal->specialVal); - return LegalVal::pair(ordinaryVal, specialVal, ptrPairVal->pairInfo); - } - - case LegalVal::Flavor::tuple: - { - // We need to emit a load for each element of - // the tuple. - auto ptrTupleVal = legalPtrVal.getTuple(); - RefPtr tupleVal = new TuplePseudoVal(); - - for (auto ee : legalPtrVal.getTuple()->elements) - { - TuplePseudoVal::Element element; - element.key = ee.key; - element.val = legalizeLoad(context, ee.val); - - tupleVal->elements.add(element); - } - return LegalVal::tuple(tupleVal); - } - break; - - default: - SLANG_UNEXPECTED("unhandled case"); - break; - } -} - -static LegalVal legalizeStore( - IRTypeLegalizationContext* context, - LegalVal legalPtrVal, - LegalVal legalVal) -{ - switch (legalPtrVal.flavor) - { - case LegalVal::Flavor::none: - return LegalVal(); - - case LegalVal::Flavor::simple: - { - context->builder->emitStore(legalPtrVal.getSimple(), legalVal.getSimple()); - return legalVal; - } - break; - - case LegalVal::Flavor::implicitDeref: - // TODO: what is the right behavior here? - // - // The crux of the problem is that we may legalize a pointer-to-pointer - // type in cases where one of the two needs to become an implicit-deref, - // so that we have `PtrA>` become, say, `PtrA` with - // an `implicitDeref` wrapper. When we encounter a store to that - // wrapped value, we seemingly need to know whether the original code - // meant to store to `*ptrPtr` or `**ptrPtr`, and need to legalize - // the result accordingly... - // - if( legalVal.flavor == LegalVal::Flavor::implicitDeref ) - return legalizeStore(context, legalPtrVal.getImplicitDeref(), legalVal.getImplicitDeref()); - else - return legalizeStore(context, legalPtrVal.getImplicitDeref(), legalVal); - - case LegalVal::Flavor::pair: - { - auto destPair = legalPtrVal.getPair(); - auto valPair = legalVal.getPair(); - legalizeStore(context, destPair->ordinaryVal, valPair->ordinaryVal); - legalizeStore(context, destPair->specialVal, valPair->specialVal); - return LegalVal(); - } - - case LegalVal::Flavor::tuple: - { - // We need to emit a store for each element of - // the tuple. - auto destTuple = legalPtrVal.getTuple(); - auto valTuple = legalVal.getTuple(); - SLANG_ASSERT(destTuple->elements.getCount() == valTuple->elements.getCount()); - - for (Index i = 0; i < valTuple->elements.getCount(); i++) - { - legalizeStore(context, destTuple->elements[i].val, valTuple->elements[i].val); - } - return legalVal; - } - break; - - default: - SLANG_UNEXPECTED("unhandled case"); - break; - } -} - -static LegalVal legalizeFieldExtract( - IRTypeLegalizationContext* context, - LegalType type, - LegalVal legalStructOperand, - IRStructKey* fieldKey) -{ - auto builder = context->builder; - - if (type.flavor == LegalType::Flavor::none) - return LegalVal(); - - switch (legalStructOperand.flavor) - { - case LegalVal::Flavor::none: - return LegalVal(); - - case LegalVal::Flavor::simple: - return LegalVal::simple( - builder->emitFieldExtract( - type.getSimple(), - legalStructOperand.getSimple(), - fieldKey)); - - case LegalVal::Flavor::pair: - { - // There are two sides, the ordinary and the special, - // and we basically just dispatch to both of them. - auto pairVal = legalStructOperand.getPair(); - auto pairInfo = pairVal->pairInfo; - auto pairElement = pairInfo->findElement(fieldKey); - if (!pairElement) - { - SLANG_UNEXPECTED("didn't find tuple element"); - UNREACHABLE_RETURN(LegalVal()); - } - - // If the field we are extracting has a pair type, - // that means it exists on both the ordinary and - // special sides. - RefPtr fieldPairInfo; - LegalType ordinaryType = type; - LegalType specialType = type; - if (type.flavor == LegalType::Flavor::pair) - { - auto fieldPairType = type.getPair(); - fieldPairInfo = fieldPairType->pairInfo; - ordinaryType = fieldPairType->ordinaryType; - specialType = fieldPairType->specialType; - } - - LegalVal ordinaryVal; - LegalVal specialVal; - - if (pairElement->flags & PairInfo::kFlag_hasOrdinary) - { - ordinaryVal = legalizeFieldExtract( - context, - ordinaryType, - pairVal->ordinaryVal, - fieldKey); - } - - if (pairElement->flags & PairInfo::kFlag_hasSpecial) - { - specialVal = legalizeFieldExtract( - context, - specialType, - pairVal->specialVal, - fieldKey); - } - return LegalVal::pair(ordinaryVal, specialVal, fieldPairInfo); - } - break; - - case LegalVal::Flavor::tuple: - { - // The operand is a tuple of pointer-like - // values, we want to extract the element - // corresponding to a field. We will handle - // this by simply returning the corresponding - // element from the operand. - auto ptrTupleInfo = legalStructOperand.getTuple(); - for (auto ee : ptrTupleInfo->elements) - { - if (ee.key == fieldKey) - { - return ee.val; - } - } - - // TODO: we can legally reach this case now - // when the field is "ordinary". - - SLANG_UNEXPECTED("didn't find tuple element"); - UNREACHABLE_RETURN(LegalVal()); - } - - default: - SLANG_UNEXPECTED("unhandled"); - UNREACHABLE_RETURN(LegalVal()); - } -} - -static LegalVal legalizeFieldExtract( - IRTypeLegalizationContext* context, - LegalType type, - LegalVal legalPtrOperand, - LegalVal legalFieldOperand) -{ - // We don't expect any legalization to affect - // the "field" argument. - auto fieldKey = legalFieldOperand.getSimple(); - - return legalizeFieldExtract( - context, - type, - legalPtrOperand, - (IRStructKey*) fieldKey); -} - - /// Take a value of some buffer/pointer type and unwrap it according to provided info. -static LegalVal unwrapBufferValue( - IRTypeLegalizationContext* context, - LegalVal legalPtrOperand, - LegalElementWrapping const& elementInfo) -{ - // The `elementInfo` tells us how a non-simple element - // type was wrapped up into a new structure types used - // as the element type of the buffer. - // - // This function will recurse through the structure of - // `elementInfo` to pull out all the required data from - // the buffer represented by `legalPtrOperand`. - - switch( elementInfo.flavor ) - { - default: - SLANG_UNEXPECTED("unhandled"); - UNREACHABLE_RETURN(LegalVal()); - break; - - case LegalElementWrapping::Flavor::none: - return LegalVal(); - - case LegalElementWrapping::Flavor::simple: - { - // In the leaf case, we just had to store some - // data of a simple type in the buffer. We can - // produce a valid result by computing the - // address of the field used to represent the - // element, and then returning *that* as if - // it were the buffer type itself. - // - // (Basically instead of `someBuffer` we will - // end up with `&(someBuffer->field)`. - // - auto builder = context->getBuilder(); - - auto simpleElementInfo = elementInfo.getSimple(); - auto valPtr = builder->emitFieldAddress( - builder->getPtrType(simpleElementInfo->type), - legalPtrOperand.getSimple(), - simpleElementInfo->key); - - return LegalVal::simple(valPtr); - } - - case LegalElementWrapping::Flavor::implicitDeref: - { - // If the element type was logically `ImplicitDeref`, - // then we declared actual fields based on `T`, and - // we need to extract references to those fields and - // wrap them up in an `implicitDeref` value. - // - auto derefField = elementInfo.getImplicitDeref(); - auto baseVal = unwrapBufferValue(context, legalPtrOperand, derefField->field); - return LegalVal::implicitDeref(baseVal); - } - - case LegalElementWrapping::Flavor::pair: - { - // If the element type was logically a `Pair` - // then we encoded fields for both `O` and `S` into - // the actual element type, and now we need to - // extract references to both and pair them up. - // - auto pairField = elementInfo.getPair(); - auto pairInfo = pairField->pairInfo; - - auto ordinaryVal = unwrapBufferValue(context, legalPtrOperand, pairField->ordinary); - auto specialVal = unwrapBufferValue(context, legalPtrOperand, pairField->special); - return LegalVal::pair(ordinaryVal, specialVal, pairInfo); - } - - case LegalElementWrapping::Flavor::tuple: - { - // If the element type was logically a `Tuple` - // then we encoded fields for each of the `Ei` and - // need to extract references to all of them and - // encode them as a tuple. - // - auto tupleField = elementInfo.getTuple(); - - RefPtr obj = new TuplePseudoVal(); - for( auto ee : tupleField->elements ) - { - auto elementVal = unwrapBufferValue( - context, - legalPtrOperand, - ee.field); - - TuplePseudoVal::Element element; - element.key = ee.key; - element.val = unwrapBufferValue( - context, - legalPtrOperand, - ee.field); - obj->elements.add(element); - } - - return LegalVal::tuple(obj); - } - } -} - -static IRType* getPointedToType( - IRTypeLegalizationContext* context, - IRType* ptrType) -{ - auto valueType = tryGetPointedToType(context->builder, ptrType); - if( !valueType ) - { - SLANG_UNEXPECTED("expected a pointer type during type legalization"); - } - return valueType; -} - -static LegalType getPointedToType( - IRTypeLegalizationContext* context, - LegalType type) -{ - switch( type.flavor ) - { - case LegalType::Flavor::none: - return LegalType(); - - case LegalType::Flavor::simple: - return LegalType::simple(getPointedToType(context, type.getSimple())); - - case LegalType::Flavor::implicitDeref: - return type.getImplicitDeref()->valueType; - - case LegalType::Flavor::pair: - { - auto pairType = type.getPair(); - auto ordinary = getPointedToType(context, pairType->ordinaryType); - auto special = getPointedToType(context, pairType->specialType); - return LegalType::pair(ordinary, special, pairType->pairInfo); - } - - case LegalType::Flavor::tuple: - { - auto tupleType = type.getTuple(); - RefPtr resultTuple = new TuplePseudoType(); - for( auto ee : tupleType->elements ) - { - TuplePseudoType::Element resultElement; - resultElement.key = ee.key; - resultElement.type = getPointedToType(context, ee.type); - resultTuple->elements.add(resultElement); - } - return LegalType::tuple(resultTuple); - } - - default: - SLANG_UNEXPECTED("unhandled case in type legalization"); - UNREACHABLE_RETURN(LegalType()); - } -} - -static LegalVal legalizeFieldAddress( - IRTypeLegalizationContext* context, - LegalType type, - LegalVal legalPtrOperand, - IRStructKey* fieldKey) -{ - auto builder = context->builder; - if (type.flavor == LegalType::Flavor::none) - return LegalVal(); - - switch (legalPtrOperand.flavor) - { - case LegalVal::Flavor::none: - return LegalVal(); - - case LegalVal::Flavor::simple: - switch( type.flavor ) - { - case LegalType::Flavor::implicitDeref: - // TODO: Should this case be needed? - return legalizeFieldAddress( - context, - type.getImplicitDeref()->valueType, - legalPtrOperand, - fieldKey); - - default: - return LegalVal::simple( - builder->emitFieldAddress( - type.getSimple(), - legalPtrOperand.getSimple(), - fieldKey)); - } - - case LegalVal::Flavor::pair: - { - // There are two sides, the ordinary and the special, - // and we basically just dispatch to both of them. - auto pairVal = legalPtrOperand.getPair(); - auto pairInfo = pairVal->pairInfo; - auto pairElement = pairInfo->findElement(fieldKey); - if (!pairElement) - { - SLANG_UNEXPECTED("didn't find tuple element"); - UNREACHABLE_RETURN(LegalVal()); - } - - // If the field we are extracting has a pair type, - // that means it exists on both the ordinary and - // special sides. - RefPtr fieldPairInfo; - LegalType ordinaryType = type; - LegalType specialType = type; - if (type.flavor == LegalType::Flavor::pair) - { - auto fieldPairType = type.getPair(); - fieldPairInfo = fieldPairType->pairInfo; - ordinaryType = fieldPairType->ordinaryType; - specialType = fieldPairType->specialType; - } - - LegalVal ordinaryVal; - LegalVal specialVal; - - if (pairElement->flags & PairInfo::kFlag_hasOrdinary) - { - ordinaryVal = legalizeFieldAddress( - context, - ordinaryType, - pairVal->ordinaryVal, - fieldKey); - } - - if (pairElement->flags & PairInfo::kFlag_hasSpecial) - { - specialVal = legalizeFieldAddress( - context, - specialType, - pairVal->specialVal, - fieldKey); - } - return LegalVal::pair(ordinaryVal, specialVal, fieldPairInfo); - } - break; - - case LegalVal::Flavor::tuple: - { - // The operand is a tuple of pointer-like - // values, we want to extract the element - // corresponding to a field. We will handle - // this by simply returning the corresponding - // element from the operand. - auto ptrTupleInfo = legalPtrOperand.getTuple(); - for (auto ee : ptrTupleInfo->elements) - { - if (ee.key == fieldKey) - { - return ee.val; - } - } - - // TODO: we can legally reach this case now - // when the field is "ordinary". - - SLANG_UNEXPECTED("didn't find tuple element"); - UNREACHABLE_RETURN(LegalVal()); - } - - case LegalVal::Flavor::implicitDeref: - { - // The original value had a level of indirection - // that is now being removed, so should not be - // able to get at the *address* of the field any - // more, and need to resign ourselves to just - // getting at the field *value* and then - // adding an implicit dereference on top of that. - // - auto implicitDerefVal = legalPtrOperand.getImplicitDeref(); - auto valueType = getPointedToType(context, type); - return LegalVal::implicitDeref(legalizeFieldExtract(context, valueType, implicitDerefVal, fieldKey)); - } - - default: - SLANG_UNEXPECTED("unhandled"); - UNREACHABLE_RETURN(LegalVal()); - } -} - -static LegalVal legalizeFieldAddress( - IRTypeLegalizationContext* context, - LegalType type, - LegalVal legalPtrOperand, - LegalVal legalFieldOperand) -{ - // We don't expect any legalization to affect - // the "field" argument. - auto fieldKey = legalFieldOperand.getSimple(); - - return legalizeFieldAddress( - context, - type, - legalPtrOperand, - (IRStructKey*) fieldKey); -} - -static LegalVal legalizeGetElement( - IRTypeLegalizationContext* context, - LegalType type, - LegalVal legalPtrOperand, - IRInst* indexOperand) -{ - auto builder = context->builder; - - switch (legalPtrOperand.flavor) - { - case LegalVal::Flavor::none: - return LegalVal(); - - case LegalVal::Flavor::simple: - return LegalVal::simple( - builder->emitElementExtract( - type.getSimple(), - legalPtrOperand.getSimple(), - indexOperand)); - - case LegalVal::Flavor::pair: - { - // There are two sides, the ordinary and the special, - // and we basically just dispatch to both of them. - auto pairVal = legalPtrOperand.getPair(); - auto pairInfo = pairVal->pairInfo; - - LegalType ordinaryType = type; - LegalType specialType = type; - if (type.flavor == LegalType::Flavor::pair) - { - auto pairType = type.getPair(); - ordinaryType = pairType->ordinaryType; - specialType = pairType->specialType; - } - - LegalVal ordinaryVal = legalizeGetElement( - context, - ordinaryType, - pairVal->ordinaryVal, - indexOperand); - - LegalVal specialVal = legalizeGetElement( - context, - specialType, - pairVal->specialVal, - indexOperand); - - return LegalVal::pair(ordinaryVal, specialVal, pairInfo); - } - break; - - case LegalVal::Flavor::tuple: - { - // The operand is a tuple of pointer-like - // values, we want to extract the element - // corresponding to a field. We will handle - // this by simply returning the corresponding - // element from the operand. - auto ptrTupleInfo = legalPtrOperand.getTuple(); - - RefPtr resTupleInfo = new TuplePseudoVal(); - - auto tupleType = type.getTuple(); - SLANG_ASSERT(tupleType); - - auto elemCount = ptrTupleInfo->elements.getCount(); - SLANG_ASSERT(elemCount == tupleType->elements.getCount()); - - for(Index ee = 0; ee < elemCount; ++ee) - { - auto ptrElem = ptrTupleInfo->elements[ee]; - auto elemType = tupleType->elements[ee].type; - - TuplePseudoVal::Element resElem; - resElem.key = ptrElem.key; - resElem.val = legalizeGetElement( - context, - elemType, - ptrElem.val, - indexOperand); - - resTupleInfo->elements.add(resElem); - } - - return LegalVal::tuple(resTupleInfo); - } - - default: - SLANG_UNEXPECTED("unhandled"); - UNREACHABLE_RETURN(LegalVal()); - } -} - -static LegalVal legalizeGetElement( - IRTypeLegalizationContext* context, - LegalType type, - LegalVal legalPtrOperand, - LegalVal legalIndexOperand) -{ - // We don't expect any legalization to affect - // the "index" argument. - auto indexOperand = legalIndexOperand.getSimple(); - - return legalizeGetElement( - context, - type, - legalPtrOperand, - indexOperand); -} - -static LegalVal legalizeGetElementPtr( - IRTypeLegalizationContext* context, - LegalType type, - LegalVal legalPtrOperand, - IRInst* indexOperand) -{ - auto builder = context->builder; - - switch (legalPtrOperand.flavor) - { - case LegalVal::Flavor::none: - return LegalVal(); - - case LegalVal::Flavor::simple: - return LegalVal::simple( - builder->emitElementAddress( - type.getSimple(), - legalPtrOperand.getSimple(), - indexOperand)); - - case LegalVal::Flavor::pair: - { - // There are two sides, the ordinary and the special, - // and we basically just dispatch to both of them. - auto pairVal = legalPtrOperand.getPair(); - auto pairInfo = pairVal->pairInfo; - - LegalType ordinaryType = type; - LegalType specialType = type; - if (type.flavor == LegalType::Flavor::pair) - { - auto pairType = type.getPair(); - ordinaryType = pairType->ordinaryType; - specialType = pairType->specialType; - } - - LegalVal ordinaryVal = legalizeGetElementPtr( - context, - ordinaryType, - pairVal->ordinaryVal, - indexOperand); - - LegalVal specialVal = legalizeGetElementPtr( - context, - specialType, - pairVal->specialVal, - indexOperand); - - return LegalVal::pair(ordinaryVal, specialVal, pairInfo); - } - break; - - case LegalVal::Flavor::tuple: - { - // The operand is a tuple of pointer-like - // values, we want to extract the element - // corresponding to a field. We will handle - // this by simply returning the corresponding - // element from the operand. - auto ptrTupleInfo = legalPtrOperand.getTuple(); - - RefPtr resTupleInfo = new TuplePseudoVal(); - - auto tupleType = type.getTuple(); - SLANG_ASSERT(tupleType); - - auto elemCount = ptrTupleInfo->elements.getCount(); - SLANG_ASSERT(elemCount == tupleType->elements.getCount()); - - for(Index ee = 0; ee < elemCount; ++ee) - { - auto ptrElem = ptrTupleInfo->elements[ee]; - auto elemType = tupleType->elements[ee].type; - - TuplePseudoVal::Element resElem; - resElem.key = ptrElem.key; - resElem.val = legalizeGetElementPtr( - context, - elemType, - ptrElem.val, - indexOperand); - - resTupleInfo->elements.add(resElem); - } - - return LegalVal::tuple(resTupleInfo); - } - - case LegalVal::Flavor::implicitDeref: - { - // The original value used to be a pointer to an array, - // and somebody is trying to get at an element pointer. - // Now we just have an array (wrapped with an implicit - // dereference) and need to just fetch the chosen element - // instead (and then wrap the element value with an - // implicit dereference). - // - // The result type for our `getElement` instruction needs - // to be the type *pointed to* by `type`, and not `type. - // - auto valueType = getPointedToType(context, type); - - auto implicitDerefVal = legalPtrOperand.getImplicitDeref(); - return LegalVal::implicitDeref(legalizeGetElement( - context, - valueType, - implicitDerefVal, - indexOperand)); - } - - default: - SLANG_UNEXPECTED("unhandled"); - UNREACHABLE_RETURN(LegalVal()); - } -} - -static LegalVal legalizeGetElementPtr( - IRTypeLegalizationContext* context, - LegalType type, - LegalVal legalPtrOperand, - LegalVal legalIndexOperand) -{ - // We don't expect any legalization to affect - // the "index" argument. - auto indexOperand = legalIndexOperand.getSimple(); - - return legalizeGetElementPtr( - context, - type, - legalPtrOperand, - indexOperand); -} - -static LegalVal legalizeMakeStruct( - IRTypeLegalizationContext* context, - LegalType legalType, - LegalVal const* legalArgs, - UInt argCount) -{ - auto builder = context->builder; - - switch(legalType.flavor) - { - case LegalType::Flavor::none: - return LegalVal(); - - case LegalType::Flavor::simple: - { - List args; - for(UInt aa = 0; aa < argCount; ++aa) - { - // Note: we assume that all the arguments - // must be simple here, because otherwise - // the `struct` type with them as fields - // would not be simple... - // - args.add(legalArgs[aa].getSimple()); - } - return LegalVal::simple( - builder->emitMakeStruct( - legalType.getSimple(), - argCount, - args.getBuffer())); - } - - case LegalType::Flavor::pair: - { - // There are two sides, the ordinary and the special, - // and we basically just dispatch to both of them. - auto pairType = legalType.getPair(); - auto pairInfo = pairType->pairInfo; - LegalType ordinaryType = pairType->ordinaryType; - LegalType specialType = pairType->specialType; - - List ordinaryArgs; - List specialArgs; - UInt argCounter = 0; - for(auto ee : pairInfo->elements) - { - UInt argIndex = argCounter++; - LegalVal arg = legalArgs[argIndex]; - - if( arg.flavor == LegalVal::Flavor::pair ) - { - // The argument is itself a pair - auto argPair = arg.getPair(); - ordinaryArgs.add(argPair->ordinaryVal); - specialArgs.add(argPair->specialVal); - } - else if(ee.flags & Slang::PairInfo::kFlag_hasOrdinary) - { - ordinaryArgs.add(arg); - } - else if(ee.flags & Slang::PairInfo::kFlag_hasSpecial) - { - specialArgs.add(arg); - } - } - - LegalVal ordinaryVal = legalizeMakeStruct( - context, - ordinaryType, - ordinaryArgs.getBuffer(), - ordinaryArgs.getCount()); - - LegalVal specialVal = legalizeMakeStruct( - context, - specialType, - specialArgs.getBuffer(), - specialArgs.getCount()); - - return LegalVal::pair(ordinaryVal, specialVal, pairInfo); - } - break; - - case LegalType::Flavor::tuple: - { - // We are constructing a tuple of values from - // the individual fields. We need to identify - // for each tuple element what field it uses, - // and then extract that field's value. - - auto tupleType = legalType.getTuple(); - - RefPtr resTupleInfo = new TuplePseudoVal(); - UInt argCounter = 0; - for(auto typeElem : tupleType->elements) - { - auto elemKey = typeElem.key; - UInt argIndex = argCounter++; - SLANG_ASSERT(argIndex < argCount); - - LegalVal argVal = legalArgs[argIndex]; - - TuplePseudoVal::Element resElem; - resElem.key = elemKey; - resElem.val = argVal; - - resTupleInfo->elements.add(resElem); - } - return LegalVal::tuple(resTupleInfo); - } - - default: - SLANG_UNEXPECTED("unhandled"); - UNREACHABLE_RETURN(LegalVal()); - } -} - -static LegalVal legalizeConstruct(IRTypeLegalizationContext* context, - LegalType type) -{ - switch (type.flavor) - { - case LegalType::Flavor::none: - return LegalVal(); - case LegalType::Flavor::simple: - return LegalVal::simple(context->builder->emitConstructorInst(type.getSimple(), 0, nullptr)); - default: - SLANG_UNEXPECTED("unhandled legalization case for construct inst."); - UNREACHABLE_RETURN(LegalVal()); - } -} - -static LegalVal legalizeInst( - IRTypeLegalizationContext* context, - IRInst* inst, - LegalType type, - LegalVal const* args) -{ - switch (inst->op) - { - case kIROp_Load: - return legalizeLoad(context, args[0]); - - case kIROp_FieldAddress: - return legalizeFieldAddress(context, type, args[0], args[1]); - - case kIROp_FieldExtract: - return legalizeFieldExtract(context, type, args[0], args[1]); - - case kIROp_getElement: - return legalizeGetElement(context, type, args[0], args[1]); - - case kIROp_getElementPtr: - return legalizeGetElementPtr(context, type, args[0], args[1]); - - case kIROp_Store: - return legalizeStore(context, args[0], args[1]); - - case kIROp_Call: - return legalizeCall(context, (IRCall*)inst); - case kIROp_ReturnVal: - return legalizeRetVal(context, args[0]); - case kIROp_makeStruct: - return legalizeMakeStruct( - context, - type, - args, - inst->getOperandCount()); - case kIROp_Construct: - return legalizeConstruct(context, type); - case kIROp_undefined: - return LegalVal(); - default: - // TODO: produce a user-visible diagnostic here - SLANG_UNEXPECTED("non-simple operand(s)!"); - break; - } -} - -RefPtr findVarLayout(IRInst* value) -{ - if (auto layoutDecoration = value->findDecoration()) - return as(layoutDecoration->getLayout()); - return nullptr; -} - -static UnownedStringSlice findNameHint(IRInst* inst) -{ - if( auto nameHintDecoration = inst->findDecoration() ) - { - return nameHintDecoration->getName(); - } - return UnownedStringSlice(); -} - -static LegalVal legalizeLocalVar( - IRTypeLegalizationContext* context, - IRVar* irLocalVar) -{ - // Legalize the type for the variable's value - auto originalValueType = irLocalVar->getDataType()->getValueType(); - auto legalValueType = legalizeType( - context, - originalValueType); - - auto originalRate = irLocalVar->getRate(); - - RefPtr varLayout = findVarLayout(irLocalVar); - RefPtr typeLayout = varLayout ? varLayout->typeLayout : nullptr; - - // If we've decided to do implicit deref on the type, - // then go ahead and declare a value of the pointed-to type. - LegalType maybeSimpleType = legalValueType; - while (maybeSimpleType.flavor == LegalType::Flavor::implicitDeref) - { - maybeSimpleType = maybeSimpleType.getImplicitDeref()->valueType; - } - - switch (maybeSimpleType.flavor) - { - case LegalType::Flavor::simple: - { - // Easy case: the type is usable as-is, and we - // should just do that. - auto type = maybeSimpleType.getSimple(); - type = context->builder->getPtrType(type); - if( originalRate ) - { - type = context->builder->getRateQualifiedType( - originalRate, - type); - } - irLocalVar->setFullType(type); - return LegalVal::simple(irLocalVar); - } - - default: - { - // TODO: We don't handle rates in this path. - - context->insertBeforeLocalVar = irLocalVar; - - LegalVarChainLink varChain(LegalVarChain(), varLayout); - - UnownedStringSlice nameHint = findNameHint(irLocalVar); - context->builder->setInsertBefore(irLocalVar); - LegalVal newVal = declareVars(context, kIROp_Var, legalValueType, typeLayout, varChain, nameHint, irLocalVar, nullptr, context->isSpecialType(originalValueType)); - - // Remove the old local var. - irLocalVar->removeFromParent(); - // add old local var to list - context->replacedInstructions.add(irLocalVar); - return newVal; - } - break; - } -} - -static LegalVal legalizeParam( - IRTypeLegalizationContext* context, - IRParam* originalParam) -{ - auto legalParamType = legalizeType(context, originalParam->getFullType()); - if (legalParamType.flavor == LegalType::Flavor::simple) - { - // Simple case: things were legalized to a simple type, - // so we can just use the original parameter as-is. - originalParam->setFullType(legalParamType.getSimple()); - return LegalVal::simple(originalParam); - } - else - { - // Complex case: we need to insert zero or more new parameters, - // which will replace the old ones. - - context->insertBeforeParam = originalParam; - - UnownedStringSlice nameHint = findNameHint(originalParam); - - context->builder->setInsertBefore(originalParam); - auto newVal = declareVars(context, kIROp_Param, legalParamType, nullptr, LegalVarChain(), nameHint, originalParam, nullptr, context->isSpecialType(originalParam->getDataType())); - - originalParam->removeFromParent(); - context->replacedInstructions.add(originalParam); - return newVal; - } -} - -static LegalVal legalizeFunc( - IRTypeLegalizationContext* context, - IRFunc* irFunc); - -static LegalVal legalizeGlobalVar( - IRTypeLegalizationContext* context, - IRGlobalVar* irGlobalVar); - -static LegalVal legalizeGlobalConstant( - IRTypeLegalizationContext* context, - IRGlobalConstant* irGlobalConstant); - -static LegalVal legalizeGlobalParam( - IRTypeLegalizationContext* context, - IRGlobalParam* irGlobalParam); - -static LegalVal legalizeInst( - IRTypeLegalizationContext* context, - IRInst* inst) -{ - // Any additional instructions we need to emit - // in the process of legalizing `inst` should - // by default be insertied right before `inst`. - // - context->builder->setInsertBefore(inst); - - // Special-case certain operations - switch (inst->op) - { - case kIROp_Var: - return legalizeLocalVar(context, cast(inst)); - - case kIROp_Param: - return legalizeParam(context, cast(inst)); - - case kIROp_WitnessTable: - // Just skip these. - break; - - case kIROp_Func: - return legalizeFunc(context, cast(inst)); - - case kIROp_GlobalVar: - return legalizeGlobalVar(context, cast(inst)); - - case kIROp_GlobalConstant: - return legalizeGlobalConstant(context, cast(inst)); - - case kIROp_GlobalParam: - return legalizeGlobalParam(context, cast(inst)); - - default: - break; - } - - // We will iterate over all the operands, extract the legalized - // value of each, and collect them in an array for subsequent use. - // - auto argCount = inst->getOperandCount(); - List legalArgs; - // - // Along the way we will also note whether there were any operands - // with non-simple legalized values. - // - bool anyComplex = false; - for (UInt aa = 0; aa < argCount; ++aa) - { - auto oldArg = inst->getOperand(aa); - auto legalArg = legalizeOperand(context, oldArg); - legalArgs.add(legalArg); - - if (legalArg.flavor != LegalVal::Flavor::simple) - anyComplex = true; - } - - // We must also legalize the type of the instruction, since that - // is implicitly one of its operands. - // - LegalType legalType = legalizeType(context, inst->getFullType()); - - // If there was nothing interesting that occured for the operands - // then we can re-use this instruction as-is. - // - if (!anyComplex && legalType.flavor == LegalType::Flavor::simple) - { - // While the operands are all "simple," they might not necessarily - // be equal to the operands we started with. - // - for (UInt aa = 0; aa < argCount; ++aa) - { - auto legalArg = legalArgs[aa]; - inst->setOperand(aa, legalArg.getSimple()); - } - - inst->setFullType(legalType.getSimple()); - - return LegalVal::simple(inst); - } - - // We have at least one "complex" operand, and we - // need to figure out what to do with it. The anwer - // will, in general, depend on what we are doing. - - // We will set up the IR builder so that any new - // instructions generated will be placed before - // the location of the original instruction. - auto builder = context->builder; - builder->setInsertBefore(inst); - - LegalVal legalVal = legalizeInst( - context, - inst, - legalType, - legalArgs.getBuffer()); - - // After we are done, we will eliminate the - // original instruction by removing it from - // the IR. - // - inst->removeFromParent(); - context->replacedInstructions.add(inst); - - // The value to be used when referencing - // the original instruction will now be - // whatever value(s) we created to replace it. - return legalVal; -} - -static void addParamType(List& ioParamTypes, LegalType t) -{ - switch (t.flavor) - { - case LegalType::Flavor::none: - break; - - case LegalType::Flavor::simple: - ioParamTypes.add(t.getSimple()); - break; - - case LegalType::Flavor::implicitDeref: - { - auto imp = t.getImplicitDeref(); - addParamType(ioParamTypes, imp->valueType); - break; - } - case LegalType::Flavor::pair: - { - auto pairInfo = t.getPair(); - addParamType(ioParamTypes, pairInfo->ordinaryType); - addParamType(ioParamTypes, pairInfo->specialType); - } - break; - case LegalType::Flavor::tuple: - { - auto tup = t.getTuple(); - for (auto & elem : tup->elements) - addParamType(ioParamTypes, elem.type); - } - break; - default: - SLANG_UNEXPECTED("unknown legalized type flavor"); - } -} - -static void legalizeInstsInParent( - IRTypeLegalizationContext* context, - IRInst* parent) -{ - IRInst* nextChild = nullptr; - for(auto child = parent->getFirstDecorationOrChild(); child; child = nextChild) - { - nextChild = child->getNextInst(); - - if (auto block = as(child)) - { - legalizeInstsInParent(context, block); - } - else - { - LegalVal legalVal = legalizeInst(context, child); - registerLegalizedValue(context, child, legalVal); - } - } -} - -static LegalVal legalizeFunc( - IRTypeLegalizationContext* context, - IRFunc* irFunc) -{ - // Overwrite the function's type with the result of legalization. - - IRFuncType* oldFuncType = irFunc->getDataType(); - UInt oldParamCount = oldFuncType->getParamCount(); - - // TODO: we should give an error message when the result type of a function - // can't be legalized (e.g., trying to return a texture, or a structue that - // contains one). - auto legalReturnType = legalizeType(context, oldFuncType->getResultType()); - IRType* newResultType = nullptr; - switch (legalReturnType.flavor) - { - case LegalType::Flavor::simple: - newResultType = legalReturnType.getSimple(); - break; - case LegalType::Flavor::none: - newResultType = context->builder->getVoidType(); - break; - default: - SLANG_UNEXPECTED("unknown legalized function return type."); - } - List newParamTypes; - for (UInt pp = 0; pp < oldParamCount; ++pp) - { - auto legalParamType = legalizeType(context, oldFuncType->getParamType(pp)); - addParamType(newParamTypes, legalParamType); - } - - auto newFuncType = context->builder->getFuncType( - newParamTypes.getCount(), - newParamTypes.getBuffer(), - newResultType); - - context->builder->setDataType(irFunc, newFuncType); - - legalizeInstsInParent(context, irFunc); - return LegalVal::simple(irFunc); -} - -static LegalVal declareSimpleVar( - IRTypeLegalizationContext* context, - IROp op, - IRType* type, - TypeLayout* typeLayout, - LegalVarChain const& varChain, - UnownedStringSlice nameHint, - IRInst* leafVar, - IRGlobalNameInfo* globalNameInfo) -{ - SLANG_UNUSED(globalNameInfo); - - RefPtr varLayout = createVarLayout(varChain, typeLayout); - - DeclRef varDeclRef = varChain.getLeafVarDeclRef(); - - IRBuilder* builder = context->builder; - - IRInst* irVar = nullptr; - LegalVal legalVarVal; - - switch (op) - { - case kIROp_GlobalVar: - { - auto globalVar = builder->createGlobalVar(type); - globalVar->removeFromParent(); - globalVar->insertBefore(context->insertBeforeGlobal); - - irVar = globalVar; - legalVarVal = LegalVal::simple(irVar); - } - break; - - case kIROp_GlobalConstant: - { - auto globalConst = builder->createGlobalConstant(type); - globalConst->removeFromParent(); - globalConst->insertBefore(context->insertBeforeGlobal); - - irVar = globalConst; - legalVarVal = LegalVal::simple(globalConst); - } - break; - - case kIROp_GlobalParam: - { - auto globalParam = builder->createGlobalParam(type); - globalParam->removeFromParent(); - globalParam->insertBefore(context->insertBeforeGlobal); - - irVar = globalParam; - legalVarVal = LegalVal::simple(globalParam); - } - break; - - case kIROp_Var: - { - builder->setInsertBefore(context->insertBeforeLocalVar); - auto localVar = builder->emitVar(type); - - irVar = localVar; - legalVarVal = LegalVal::simple(irVar); - - } - break; - - case kIROp_Param: - { - auto param = builder->emitParam(type); - param->insertBefore(context->insertBeforeParam); - - irVar = param; - legalVarVal = LegalVal::simple(irVar); - } - break; - - default: - SLANG_UNEXPECTED("unexpected IR opcode"); - break; - } - - if (irVar) - { - if (varLayout) - { - builder->addLayoutDecoration(irVar, varLayout); - } - - if (varDeclRef) - { - builder->addHighLevelDeclDecoration(irVar, varDeclRef.getDecl()); - } - - if( nameHint.size() ) - { - context->builder->addNameHintDecoration(irVar, nameHint); - } - - if( leafVar ) - { - for( auto decoration : leafVar->getDecorations() ) - { - switch( decoration->op ) - { - case kIROp_FormatDecoration: - cloneDecoration(decoration, irVar); - break; - - default: - break; - } - } - } - - } - - return legalVarVal; -} - - /// Add layout information for the fields of a wrapped buffer type. - /// - /// A wrapped buffer type encodes a buffer like `ConstantBuffer` - /// where `Foo` might have interface-type fields that have been - /// specialized to a concrete type. E.g.: - /// - /// struct Car { IDriver driver; int mph; }; - /// ConstantBuffer machOne; - /// - /// In a case where the `machOne.driver` field has been specialized - /// to the type `SpeedRacer`, we need to generate a legalized - /// buffer layout something like: - /// - /// struct Car_0 { int mph; } - /// struct Wrapped { Car_0 car; SpeedRacer card_d; } - /// ConstantBuffer machOne; - /// - /// The layout information for the existing `machOne` clearly - /// can't apply because we have a new element type with new fields. - /// - /// This function is used to recursively fill in the layout for - /// the fields of the `Wrapped` type, using information recorded - /// when the legal wrapped buffer type was created. - /// -static void _addFieldsToWrappedBufferElementTypeLayout( - TypeLayout* elementTypeLayout, // layout of the original field type - StructTypeLayout* newTypeLayout, // layout we are filling in - LegalElementWrapping const& elementInfo, // information on how the original type got wrapped - LegalVarChain const& varChain, // chain of variables that is leading to this field - bool isSpecial) // should we assume a leaf field is a special (interface) type? -{ - // The way we handle things depends primary on the - // `elementInfo`, because that tells us how things - // were wrapped up when the type was legalized. - - switch( elementInfo.flavor ) - { - case LegalElementWrapping::Flavor::none: - // A leaf `none` value meant there was nothing - // to encode for a particular field (probably - // had a `void` or empty structure type). - break; - - case LegalElementWrapping::Flavor::simple: - { - auto simpleInfo = elementInfo.getSimple(); - - // A `simple` wrapping means we hit a leaf - // field that can be encoded directly. - // What we do here depends on whether we've - // reached an ordinary field of the original - // data type, or if we've reached a leaf - // field of interface type. - // - // We've been tracking a `varChain` that - // remembers all the parent `struct` fields - // we've navigated through to get here, and - // that information has been tracking two - // different pieces of layout: - // - // * The "primary" layout represents the storage - // of the buffer element type as we usually - // think of its (e.g., the bytes starting at offset zero). - // - // * The "pending" layout tells us where all the - // fields representing concrete types plugged in - // for interface-type slots got placed. - // - // We have tunneled down info to tell us which case - // we should use (`isSpecial`). - // - // Most of the logic is the same between the two - // cases. We will be computing layout information - // for a field of the new/wrapped buffer element type. - // - RefPtr newFieldLayout; - if(isSpecial) - { - // In the special case, that field will be laid out - // based on the "pending" var chain, and the type - // of the pending data for the element. - // - newFieldLayout = createSimpleVarLayout(varChain.pendingChain, elementTypeLayout->pendingDataTypeLayout); - } - else - { - // The ordinary case just uses the primary layout - // information and the primary/nominal type of - // the field. - // - newFieldLayout = createSimpleVarLayout(varChain.primaryChain, elementTypeLayout); - } - - // Either way, we add the new field to the struct type - // layout we are building, and also update the mapping - // information so that we can find the field layout - // based on the IR key for the struct field. - // - newTypeLayout->fields.add(newFieldLayout); - newTypeLayout->mapKeyToLayout.Add(simpleInfo->key, newFieldLayout); - } - break; - - case LegalElementWrapping::Flavor::implicitDeref: - { - // This is the case where a field in the element type - // has been legalized from `SomePtrLikeType` to - // `T`, so there is a different in levels of indirection. - // - // We need to recurse and see how the type `T` - // got laid out to know what field(s) it might comprise. - // - auto implicitDerefInfo = elementInfo.getImplicitDeref(); - _addFieldsToWrappedBufferElementTypeLayout( - elementTypeLayout, - newTypeLayout, - implicitDerefInfo->field, - varChain, - isSpecial); - return; - } - break; - - case LegalElementWrapping::Flavor::pair: - { - // The pair case is the first main workhorse where - // if we had a type that mixed ordinary and interface-type - // fields, it would get split into an ordinary part - // and a "special" part, each of which might comprise - // zero or more fields. - // - // Here we recurse on both the ordinary and special - // sides, and the only interesting tidbit is that - // we pass along appropriate values for the `isSpecial` - // flag so that we act appropriately upon running - // into a leaf field. - // - auto pairElementInfo = elementInfo.getPair(); - _addFieldsToWrappedBufferElementTypeLayout( - elementTypeLayout, - newTypeLayout, - pairElementInfo->ordinary, - varChain, - false); - _addFieldsToWrappedBufferElementTypeLayout( - elementTypeLayout, - newTypeLayout, - pairElementInfo->special, - varChain, - true); - } - break; - - case LegalElementWrapping::Flavor::tuple: - { - // A tuple comes up when we've turned an aggregate - // with one or more interface-type fields into - // distinct fields at the top level. - // - // For the most part we just recurse on each field, - // but note that we set the `isSpecial` flag on - // the recursive calls, since we never use tuples - // to store anything that isn't special. - - auto tupleInfo = elementInfo.getTuple(); - for( auto ee : tupleInfo->elements ) - { - auto oldFieldLayout = getFieldLayout(elementTypeLayout, ee.key); - SLANG_ASSERT(oldFieldLayout); - - LegalVarChainLink fieldChain(varChain, oldFieldLayout); - - _addFieldsToWrappedBufferElementTypeLayout( - oldFieldLayout->typeLayout, - newTypeLayout, - ee.field, - fieldChain, - true); - } - } - break; - - default: - SLANG_UNEXPECTED("unhandled element wrapping flavor"); - break; - } -} - - /// Add offset information for `kind` to `resultVarLayout`, - /// if it doesn't already exist, and adjust the offset so - /// that it will represent an offset relative to the - /// "primary" data for the surrounding type, rather than - /// being relative to the "pending" data. - /// -static void _addOffsetVarLayoutEntry( - VarLayout* resultVarLayout, - LegalVarChain const& varChain, - LayoutResourceKind kind) -{ - // If the target already has an offset for this kind, bail out. - // - if(resultVarLayout->FindResourceInfo(kind)) - return; - - // Add the `ResourceInfo` that will represent the offset for - // this resource kind (it will be initialized to zero by default) - // - auto resultResInfo = resultVarLayout->findOrAddResourceInfo(kind); - - // Add in any contributions from the "pending" var chain, since - // that chain of offsets will accumulate to get the leaf offset - // within the pending data, which in this case we assume amounts - // to an *absolute* offset. - // - for(auto vv = varChain.pendingChain; vv; vv = vv->next ) - { - if( auto chainResInfo = vv->varLayout->FindResourceInfo(kind) ) - { - resultResInfo->index += chainResInfo->index; - resultResInfo->space += chainResInfo->space; - } - } - - // Subtract any contributions from the primary var chain, since - // we want the resulting offset to be relative to the same - // base as that chain. - // - for(auto vv = varChain.primaryChain; vv; vv = vv->next ) - { - if( auto chainResInfo = vv->varLayout->FindResourceInfo(kind) ) - { - resultResInfo->index -= chainResInfo->index; - resultResInfo->space -= chainResInfo->space; - } - } -} - - /// Create a variable layout for an field with "pending" type. - /// - /// The given `typeLayout` should represent the type of a field - /// that is being stored in "pending" data, but that now needs - /// to be made relative to the "primary" data, because we are - /// legalizing the pending data out of the code. - /// -static RefPtr _createOffsetVarLayout( - LegalVarChain const& varChain, - TypeLayout* typeLayout) -{ - RefPtr resultVarLayout = new VarLayout(); - - // For every resource kind the type consumes, we will - // compute an adjusted offset for the variable that - // encodes the (absolute) offset of the pending data - // in `varChain` relative to its primary data. - // - for( auto resInfo : typeLayout->resourceInfos ) - { - _addOffsetVarLayoutEntry(resultVarLayout, varChain, resInfo.kind); - } - - return resultVarLayout; -} - - /// Place offset information from `srcResInfo` onto `dstLayout`, - /// offset by whatever is in `offsetVarLayout` -static void addOffsetResInfo( - VarLayout* dstLayout, - VarLayout::ResourceInfo const& srcResInfo, - VarLayout* offsetVarLayout) -{ - auto kind = srcResInfo.kind; - auto dstResInfo = dstLayout->findOrAddResourceInfo(kind); - - dstResInfo->index = srcResInfo.index; - dstResInfo->space = srcResInfo.space; - - if( auto offsetResInfo = offsetVarLayout->findOrAddResourceInfo(kind) ) - { - dstResInfo->index += offsetResInfo->index; - dstResInfo->space += offsetResInfo->space; - } -} - - /// Create layout information for a wrapped buffer type. - /// - /// A wrapped buffer type encodes a buffer like `ConstantBuffer` - /// where `Foo` might have interface-type fields that have been - /// specialized to a concrete type. - /// - /// Consider: - /// - /// struct Car { IDriver driver; int mph; }; - /// ConstantBuffer machOne; - /// - /// In a case where the `machOne.driver` field has been specialized - /// to the type `SpeedRacer`, we need to generate a legalized - /// buffer layout something like: - /// - /// struct Car_0 { int mph; } - /// struct Wrapped { Car_0 car; SpeedRacer card_d; } - /// ConstantBuffer machOne; - /// - /// The layout information for the existing `machOne` clearly - /// can't apply because we have a new element type with new fields. - /// - /// This function is used to create a layout for a legalized - /// buffer type that requires wrapping, based on the original - /// type layout information and the variable layout information - /// of the surrounding context (e.g., the global shader parameter - /// that has this type). - /// -static RefPtr _createWrappedBufferTypeLayout( - TypeLayout* oldTypeLayout, - WrappedBufferPseudoType* wrappedBufferTypeInfo, - LegalVarChain const& outerVarChain) -{ - // We shouldn't get invoked unless there was a parameter group type, - // so we will sanity check for that just to be sure. - // - auto oldParameterGroupTypeLayout = as(oldTypeLayout); - SLANG_ASSERT(oldParameterGroupTypeLayout); - if(!oldParameterGroupTypeLayout) - return oldTypeLayout; - - // The original type must have been split between the direct/primary - // data and some amount of "pending" data to deal with interface-type - // data in the element type of the parameter group. - // - // The legalization step will have already flattened the data inside of - // the group to a single `struct` type, which places the primary data first, - // and then any pending data into additional fields. - // - // Our job is to compute a type layout that we can apply to that new - // element type, and to a parameter group surrounding it, that will - // re-create the original intention of the split layout (both primary - // and pending data) for a type that now only has the "primary" data. - // - RefPtr newTypeLayout = new ParameterGroupTypeLayout(); - newTypeLayout->type = oldTypeLayout->type; - newTypeLayout->rules = oldTypeLayout->rules; - newTypeLayout->uniformAlignment = oldTypeLayout->uniformAlignment; - for(auto resInfo : oldTypeLayout->resourceInfos) - newTypeLayout->addResourceUsage(resInfo); - - // Any fields in the "pending" data will have offset information - // that is relative to the pending data for their parent, and so on. - // We need to compute layout information that only includes primary - // data, so any offset information that is relative to the pending data - // needs to instead be relative to the primary data. That amounts to - // computing the absolute offset of each pending field, and then - // subtracting off the absolute offset of the primary data. - // - // We will compute the offset that needs to be added up front, - // and store it in the form of a `VarLayout`. The offsets we need - // can be computed from the `outerVarChain`, and we only need to - // store offset information for resource kinds actually consumed - // by the pending data type for the buffer as a whole (e.g., we - // don't need to apply offsetting to uniform bytes, because - // those don't show up in the resource usage of a constant buffer - // itself, and so the offsets already *are* relative to the start - // of the buffer). - // - auto offsetVarLayout = _createOffsetVarLayout(outerVarChain, oldTypeLayout->pendingDataTypeLayout); - LegalVarChainLink offsetVarChain(LegalVarChain(), offsetVarLayout); - - // We will start our construction of the pieces of the output - // type layout by looking at the "container" type/variable. - // - // A parameter block or constant buffer in Slang needs to - // distinguish between the resource usage of the thing in - // the block/buffer, vs. the resource usage of the block/buffer - // itself. Consider: - // - // struct Material { float4 color; Texture2D tex; } - // ConstantBuffer gMat; - // - // When compiling for Vulkan, the `gMat` constant buffer needs - // a `binding`, and the `tex` field does too, so the overall - // resource usage of `gMat` is two bindings, but we need a - // way to encode which of those bindings goes to `gMat.tex` - // and which to the constant buffer for `gMat` itself. - // - { - // We will start by extracting the "primary" part of the old - // container type/var layout, and constructing new objects - // that will represent the layout for our wrapped buffer. - // - auto oldPrimaryContainerVarLayout = oldParameterGroupTypeLayout->containerVarLayout; - auto oldPrimaryContainerTypeLayout = oldPrimaryContainerVarLayout->typeLayout; - - RefPtr newContainerTypeLayout = new TypeLayout(); - newContainerTypeLayout->type = oldPrimaryContainerTypeLayout->type; - - RefPtr newContainerVarLayout = new VarLayout(); - newContainerVarLayout->typeLayout = newContainerTypeLayout; - - newTypeLayout->containerVarLayout = newContainerVarLayout; - - // Whatever got allocated for the primary container should get copied - // over to the new layout (e.g., if we allocated a constant buffer - // for `gMat` then we need to retain that information). - // - newContainerTypeLayout->addResourceUsageFrom(oldPrimaryContainerTypeLayout); - for( auto resInfo : oldPrimaryContainerVarLayout->resourceInfos ) - { - auto newResInfo = newContainerVarLayout->findOrAddResourceInfo(resInfo.kind); - newResInfo->index = resInfo.index; - newResInfo->space = resInfo.space; - } - - // It is possible that a constant buffer and/or space didn't get - // allocated for the "primary" data, but ended up being required for - // the "pending" data (this would happen if, e.g., a constant buffer - // didn't appear to have any uniform data in it, but then once we - // plugged in concrete types for interface fields it did...), so - // we need to account for that case and copy over the relevant - // resource usage from the pending data, if there is any. - // - if( auto oldPendingContainerVarLayout = oldPrimaryContainerVarLayout->pendingVarLayout ) - { - // Whatever resources were allocated for the pending data type, - // our new combined container type needs to account for them - // (e.g., if we didn't have a constant buffer in the primary - // data, but one got allocated in the pending data, we need - // to end up with type layout information that includes a - // constnat buffer). - // - auto oldPendingContainerTypeLayout = oldPendingContainerVarLayout->typeLayout; - newContainerTypeLayout->addResourceUsageFrom(oldPendingContainerTypeLayout); - - // We also need to add offset information based on the "pending" - // var layout, but we need to deal with the fact that this information - // is currently stored relative to the pending var layout for the surrounding - // context (passed in as `outerVarChain.pendingChain`), but we need it to be - // relative to the primary layout for the surrounding context (`outerVarChain.primaryChain`). - // This is where the `offsetVarLayout` we computed above comes - // in handy, because it represents the value(s) we need to - // add to each of the per-resource-kind offsets. - // - for( auto resInfo : oldPendingContainerVarLayout->resourceInfos ) - { - addOffsetResInfo(newContainerVarLayout, resInfo, offsetVarLayout); - } - } - } - - // Now that we've dealt with the container variable, we can turn - // our attention to the element type. This is the part that - // actually got legalized and required us to create a "wrapped" - // buffer type in the first place, so we know that it will - // have both primary and "pending" parts. - // - // Let's start by extracting the fields we care about from - // the original element type/var layout, and constructing - // the objects we'll use to represent the type/var layout for - // the new element type. - // - auto oldElementVarLayout = oldParameterGroupTypeLayout->elementVarLayout; - auto oldElementTypeLayout = oldElementVarLayout->typeLayout; - - // Now matter what, the element type of a wrapped buffer - // will always have a structure type. - // - RefPtr newElementTypeLayout = new StructTypeLayout(); - newElementTypeLayout->type = oldElementTypeLayout->type; - - // The `wrappedBufferTypeInfo` that was passed in tells - // us how the fields of the original type got turned into - // zero or more fields in the new element type, so we - // need to follow its recursive structure to build - // layout information for each of the new fields. - // - // We will track a "chain" of parent variables that - // determines how we got to each leaf field, and is - // used to add up the offsets that will be stored - // in the new `VarLayout`s that get created. - // We know we need to add in some offsets (usually - // negative) to any fields that were pending data, - // so we will account for that in the initial - // chain of outer variables that we pass in. - // - LegalVarChain varChainForElementType; - varChainForElementType.primaryChain = nullptr; - varChainForElementType.pendingChain = offsetVarChain.primaryChain; - - _addFieldsToWrappedBufferElementTypeLayout( - oldElementTypeLayout, - newElementTypeLayout, - wrappedBufferTypeInfo->elementInfo, - varChainForElementType, - true); - - // A parameter group type layout holds a `VarLayout` for the element type, - // which encodes the offset of the element type with respect to the - // start of the parameter group as a whole (e.g., to handle the case - // where a constant buffer needs a `binding`, and so does its - // element type, so the offset to the first `binding` for the element - // type is one, not zero. - // - LegalVarChainLink elementVarChain(LegalVarChain(), oldParameterGroupTypeLayout->elementVarLayout); - auto newElementVarLayout = createVarLayout(elementVarChain, newElementTypeLayout); - newTypeLayout->elementVarLayout = newElementVarLayout; - - // For legacy/API reasons, we also need to compute a version of the - // element type where the offset stored in the `elementVarLayout` - // gets "baked in" to the fields of the element type. - // - newTypeLayout->offsetElementTypeLayout = applyOffsetToTypeLayout( - newElementTypeLayout, - newElementVarLayout); - - return newTypeLayout; -} - -static LegalVal declareVars( - IRTypeLegalizationContext* context, - IROp op, - LegalType type, - TypeLayout* inTypeLayout, - LegalVarChain const& inVarChain, - UnownedStringSlice nameHint, - IRInst* leafVar, - IRGlobalNameInfo* globalNameInfo, - bool isSpecial) -{ - LegalVarChain varChain = inVarChain; - TypeLayout* typeLayout = inTypeLayout; - if( isSpecial ) - { - if( varChain.pendingChain ) - { - varChain.primaryChain = varChain.pendingChain; - varChain.pendingChain = nullptr; - } - if( typeLayout ) - { - if( auto pendingTypeLayout = typeLayout->pendingDataTypeLayout ) - { - typeLayout = pendingTypeLayout; - } - } - } - - switch (type.flavor) - { - case LegalType::Flavor::none: - return LegalVal(); - - case LegalType::Flavor::simple: - return declareSimpleVar(context, op, type.getSimple(), typeLayout, varChain, nameHint, leafVar, globalNameInfo); - break; - - case LegalType::Flavor::implicitDeref: - { - // Just declare a variable of the pointed-to type, - // since we are removing the indirection. - - auto val = declareVars( - context, - op, - type.getImplicitDeref()->valueType, - typeLayout, - varChain, - nameHint, - leafVar, - globalNameInfo, - isSpecial); - return LegalVal::implicitDeref(val); - } - break; - - case LegalType::Flavor::pair: - { - auto pairType = type.getPair(); - auto ordinaryVal = declareVars(context, op, pairType->ordinaryType, typeLayout, varChain, nameHint, leafVar, globalNameInfo, false); - auto specialVal = declareVars(context, op, pairType->specialType, typeLayout, varChain, nameHint, leafVar, globalNameInfo, true); - return LegalVal::pair(ordinaryVal, specialVal, pairType->pairInfo); - } - - case LegalType::Flavor::tuple: - { - // Declare one variable for each element of the tuple - auto tupleType = type.getTuple(); - - RefPtr tupleVal = new TuplePseudoVal(); - - for (auto ee : tupleType->elements) - { - auto fieldLayout = getFieldLayout(typeLayout, ee.key); - RefPtr fieldTypeLayout = fieldLayout ? fieldLayout->typeLayout : nullptr; - - // If we have a type layout coming in, we really expect to have a layout for each field. - SLANG_ASSERT(fieldLayout || !typeLayout); - - // If we are processing layout information, then - // we need to create a new link in the chain - // of variables that will determine offsets - // for the eventual leaf fields... - // - LegalVarChainLink newVarChain(varChain, fieldLayout); - - UnownedStringSlice fieldNameHint; - String joinedNameHintStorage; - if( nameHint.size() ) - { - if( auto fieldNameHintDecoration = ee.key->findDecoration() ) - { - joinedNameHintStorage.append(nameHint); - joinedNameHintStorage.append("."); - joinedNameHintStorage.append(fieldNameHintDecoration->getName()); - - fieldNameHint = joinedNameHintStorage.getUnownedSlice(); - } - - } - - LegalVal fieldVal = declareVars( - context, - op, - ee.type, - fieldTypeLayout, - newVarChain, - fieldNameHint, - ee.key, - globalNameInfo, - true); - - TuplePseudoVal::Element element; - element.key = ee.key; - element.val = fieldVal; - tupleVal->elements.add(element); - } - - return LegalVal::tuple(tupleVal); - } - break; - - case LegalType::Flavor::wrappedBuffer: - { - auto wrappedBuffer = type.getWrappedBuffer(); - - auto wrappedTypeLayout = _createWrappedBufferTypeLayout(typeLayout, wrappedBuffer, varChain); - - auto innerVal = declareSimpleVar( - context, - op, - wrappedBuffer->simpleType, - wrappedTypeLayout, - varChain, - nameHint, - leafVar, - globalNameInfo); - - return LegalVal::wrappedBuffer(innerVal, wrappedBuffer->elementInfo); - } - - default: - SLANG_UNEXPECTED("unhandled"); - UNREACHABLE_RETURN(LegalVal()); - break; - } -} - -static LegalVal legalizeGlobalVar( - IRTypeLegalizationContext* context, - IRGlobalVar* irGlobalVar) -{ - // Legalize the type for the variable's value - auto originalValueType = irGlobalVar->getDataType()->getValueType(); - auto legalValueType = legalizeType( - context, - originalValueType); - - switch (legalValueType.flavor) - { - case LegalType::Flavor::simple: - // Easy case: the type is usable as-is, and we - // should just do that. - context->builder->setDataType( - irGlobalVar, - context->builder->getPtrType( - legalValueType.getSimple())); - return LegalVal::simple(irGlobalVar); - - default: - { - context->insertBeforeGlobal = irGlobalVar->getNextInst(); - - IRGlobalNameInfo globalNameInfo; - globalNameInfo.globalVar = irGlobalVar; - globalNameInfo.counter = 0; - - UnownedStringSlice nameHint = findNameHint(irGlobalVar); - context->builder->setInsertBefore(irGlobalVar); - LegalVal newVal = declareVars(context, kIROp_GlobalVar, legalValueType, nullptr, LegalVarChain(), nameHint, irGlobalVar, &globalNameInfo, context->isSpecialType(originalValueType)); - - // Register the new value as the replacement for the old - registerLegalizedValue(context, irGlobalVar, newVal); - - // Remove the old global from the module. - irGlobalVar->removeFromParent(); - context->replacedInstructions.add(irGlobalVar); - - return newVal; - } - break; - } -} - -static LegalVal legalizeGlobalConstant( - IRTypeLegalizationContext* context, - IRGlobalConstant* irGlobalConstant) -{ - // Legalize the type for the variable's value - auto legalValueType = legalizeType( - context, - irGlobalConstant->getFullType()); - - switch (legalValueType.flavor) - { - case LegalType::Flavor::simple: - // Easy case: the type is usable as-is, and we - // should just do that. - irGlobalConstant->setFullType(legalValueType.getSimple()); - return LegalVal::simple(irGlobalConstant); - - default: - { - context->insertBeforeGlobal = irGlobalConstant->getNextInst(); - - IRGlobalNameInfo globalNameInfo; - globalNameInfo.globalVar = irGlobalConstant; - globalNameInfo.counter = 0; - - // TODO: need to handle initializer here! - - UnownedStringSlice nameHint = findNameHint(irGlobalConstant); - context->builder->setInsertBefore(irGlobalConstant); - LegalVal newVal = declareVars(context, kIROp_GlobalConstant, legalValueType, nullptr, LegalVarChain(), nameHint, irGlobalConstant, &globalNameInfo, context->isSpecialType(irGlobalConstant->getDataType())); - - // Register the new value as the replacement for the old - registerLegalizedValue(context, irGlobalConstant, newVal); - - // Remove the old global from the module. - irGlobalConstant->removeFromParent(); - context->replacedInstructions.add(irGlobalConstant); - - return newVal; - } - break; - } -} - -static LegalVal legalizeGlobalParam( - IRTypeLegalizationContext* context, - IRGlobalParam* irGlobalParam) -{ - // Legalize the type for the variable's value - auto legalValueType = legalizeType( - context, - irGlobalParam->getFullType()); - - RefPtr varLayout = findVarLayout(irGlobalParam); - RefPtr typeLayout = varLayout ? varLayout->typeLayout : nullptr; - - switch (legalValueType.flavor) - { - case LegalType::Flavor::simple: - // Easy case: the type is usable as-is, and we - // should just do that. - irGlobalParam->setFullType(legalValueType.getSimple()); - return LegalVal::simple(irGlobalParam); - - default: - { - context->insertBeforeGlobal = irGlobalParam->getNextInst(); - - LegalVarChainLink varChain(LegalVarChain(), varLayout); - - IRGlobalNameInfo globalNameInfo; - globalNameInfo.globalVar = irGlobalParam; - globalNameInfo.counter = 0; - - // TODO: need to handle initializer here! - - UnownedStringSlice nameHint = findNameHint(irGlobalParam); - context->builder->setInsertBefore(irGlobalParam); - LegalVal newVal = declareVars(context, kIROp_GlobalParam, legalValueType, typeLayout, varChain, nameHint, irGlobalParam, &globalNameInfo, context->isSpecialType(irGlobalParam->getDataType())); - - // Register the new value as the replacement for the old - registerLegalizedValue(context, irGlobalParam, newVal); - - // Remove the old global from the module. - irGlobalParam->removeFromParent(); - context->replacedInstructions.add(irGlobalParam); - - return newVal; - } - break; - } -} - - -static void legalizeTypes( - IRTypeLegalizationContext* context) -{ - // Legalize all the top-level instructions in the module - auto module = context->module; - legalizeInstsInParent(context, module->moduleInst); - - // Clean up after any instructions we replaced along the way. - for (auto& lv : context->replacedInstructions) - { - lv->removeAndDeallocate(); - } -} - -// We use the same basic type legalization machinery for both simplifying -// away resource-type fields nested in `struct`s and for shuffling around -// exisential-box fields to get the layout right. -// -// The differences between the two passes come down to some very small -// distinctions about what types each pass considers "special" (e.g., -// resources in one case and existential boxes in the other), along -// with what they want to do when a uniform/constant buffer needs to -// be made where the element type is non-simple (that is, includes -// some fields of "special" type). -// -// The resource case is then the simpler one: -// -struct IRResourceTypeLegalizationContext : IRTypeLegalizationContext -{ - IRResourceTypeLegalizationContext(IRModule* module) - : IRTypeLegalizationContext(module) - {} - - bool isSpecialType(IRType* type) override - { - // For resource type legalization, the "special" types - // we are working with are resource types. - // - return isResourceType(type); - } - - LegalType createLegalUniformBufferType( - IROp op, - LegalType legalElementType) override - { - // The appropriate strategy for legalizing uniform buffers - // with resources inside already exists, so we can delegate to it. - // - return createLegalUniformBufferTypeForResources( - this, - op, - legalElementType); - } -}; - -// The case for legalizing existential box types is then similar. -// -struct IRExistentialTypeLegalizationContext : IRTypeLegalizationContext -{ - IRExistentialTypeLegalizationContext(IRModule* module) - : IRTypeLegalizationContext(module) - {} - - bool isSpecialType(IRType* inType) override - { - // The "special" types for our purposes are existential - // boxes, or arrays thereof. - // - auto type = unwrapArray(inType); - return as(type) != nullptr; - } - - LegalType createLegalUniformBufferType( - IROp op, - LegalType legalElementType) override - { - // We'll delegate the logic for creating uniform buffers - // over a mix of ordinary and existential-box types to - // a subroutine so it can live near the resource case. - // - // TODO: We should eventually try to refactor this code - // so that related functionality is grouped together. - // - return createLegalUniformBufferTypeForExistentials( - this, - op, - legalElementType); - } -}; - -// The main entry points that are used when transforming IR code -// to get it ready for lower-level codegen are then simple -// wrappers around `legalizeTypes()` that pick an appropriately -// specialized context type to use to get the job done. - -void legalizeResourceTypes( - IRModule* module, - DiagnosticSink* sink) -{ - SLANG_UNUSED(sink); - - IRResourceTypeLegalizationContext context(module); - legalizeTypes(&context); -} - -void legalizeExistentialTypeLayout( - IRModule* module, - DiagnosticSink* sink) -{ - SLANG_UNUSED(module); - SLANG_UNUSED(sink); - - IRExistentialTypeLegalizationContext context(module); - legalizeTypes(&context); -} - - -} diff --git a/source/slang/ir-link.cpp b/source/slang/ir-link.cpp deleted file mode 100644 index dc433663a..000000000 --- a/source/slang/ir-link.cpp +++ /dev/null @@ -1,1361 +0,0 @@ -// ir-link.cpp -#include "ir-link.h" - -#include "ir.h" -#include "ir-insts.h" -#include "mangle.h" - -namespace Slang -{ - -// Needed for lookup up entry-point layouts. -// -// TODO: maybe arrange so that codegen is driven from the layout layer -// instead of the input/request layer. -EntryPointLayout* findEntryPointLayout( - ProgramLayout* programLayout, - EntryPoint* EntryPoint); - -struct IRSpecSymbol : RefObject -{ - IRInst* irGlobalValue; - RefPtr nextWithSameName; -}; - -struct IRSpecEnv -{ - IRSpecEnv* parent = nullptr; - - // A map from original values to their cloned equivalents. - typedef Dictionary ClonedValueDictionary; - ClonedValueDictionary clonedValues; -}; - -struct IRSharedSpecContext -{ - // The code-generation target in use - CodeGenTarget target; - - // The specialized module we are building - RefPtr module; - - // A map from mangled symbol names to zero or - // more global IR values that have that name, - // in the *original* module. - typedef Dictionary> SymbolDictionary; - SymbolDictionary symbols; - - SharedIRBuilder sharedBuilderStorage; - IRBuilder builderStorage; - - // The "global" specialization environment. - IRSpecEnv globalEnv; -}; - -struct IRSpecContextBase -{ - // A map from the mangled name of a global variable - // to the layout to use for it. - Dictionary globalVarLayouts; - - IRSharedSpecContext* shared; - - IRSharedSpecContext* getShared() { return shared; } - - IRModule* getModule() { return getShared()->module; } - - IRSharedSpecContext::SymbolDictionary& getSymbols() { return getShared()->symbols; } - - // The current specialization environment to use. - IRSpecEnv* env = nullptr; - IRSpecEnv* getEnv() - { - // TODO: need to actually establish environments on contexts we create. - // - // Or more realistically we need to change the whole approach - // to specialization and cloning so that we don't try to share - // logic between two very different cases. - - - return env; - } - - // The IR builder to use for creating nodes - IRBuilder* builder; - - // A callback to be used when a value that is not registerd in `clonedValues` - // is needed during cloning. This gives the subtype a chance to intercept - // the operation and clone (or not) as needed. - virtual IRInst* maybeCloneValue(IRInst* originalVal) - { - return originalVal; - } -}; - -void registerClonedValue( - IRSpecContextBase* context, - IRInst* clonedValue, - IRInst* originalValue) -{ - if(!originalValue) - return; - - // TODO: now that things are scoped using environments, we - // shouldn't be running into the cases where a value with - // the same key already exists. This should be changed to - // an `Add()` call. - // - context->getEnv()->clonedValues[originalValue] = clonedValue; -} - -// Information on values to use when registering a cloned value -struct IROriginalValuesForClone -{ - IRInst* originalVal = nullptr; - IRSpecSymbol* sym = nullptr; - - IROriginalValuesForClone() {} - - IROriginalValuesForClone(IRInst* originalValue) - : originalVal(originalValue) - {} - - IROriginalValuesForClone(IRSpecSymbol* symbol) - : sym(symbol) - {} -}; - -void registerClonedValue( - IRSpecContextBase* context, - IRInst* clonedValue, - IROriginalValuesForClone const& originalValues) -{ - registerClonedValue(context, clonedValue, originalValues.originalVal); - for( auto s = originalValues.sym; s; s = s->nextWithSameName ) - { - registerClonedValue(context, clonedValue, s->irGlobalValue); - } -} - -IRInst* cloneInst( - IRSpecContextBase* context, - IRBuilder* builder, - IRInst* originalInst, - IROriginalValuesForClone const& originalValues); - -IRInst* cloneInst( - IRSpecContextBase* context, - IRBuilder* builder, - IRInst* originalInst) -{ - return cloneInst(context, builder, originalInst, originalInst); -} - - /// Clone any decorations from `originalValue` onto `clonedValue` -void cloneDecorations( - IRSpecContextBase* context, - IRInst* clonedValue, - IRInst* originalValue) -{ - // TODO: In many cases we might be able to use this as a general-purpose - // place to do cloning of *all* the children of an instruction, and - // not just its decorations. We should look to refactor this code - // later. - - IRBuilder builderStorage = *context->builder; - IRBuilder* builder = &builderStorage; - builder->setInsertInto(clonedValue); - - - SLANG_UNUSED(context); - for(auto originalDecoration : originalValue->getDecorations()) - { - cloneInst(context, builder, originalDecoration); - } - - // We will also clone the location here, just because this is a convenient bottleneck - clonedValue->sourceLoc = originalValue->sourceLoc; -} - - /// Clone any decorations and children from `originalValue` onto `clonedValue` -void cloneDecorationsAndChildren( - IRSpecContextBase* context, - IRInst* clonedValue, - IRInst* originalValue) -{ - IRBuilder builderStorage = *context->builder; - IRBuilder* builder = &builderStorage; - builder->setInsertInto(clonedValue); - - SLANG_UNUSED(context); - for(auto originalItem : originalValue->getDecorationsAndChildren()) - { - cloneInst(context, builder, originalItem); - } - - // We will also clone the location here, just because this is a convenient bottleneck - clonedValue->sourceLoc = originalValue->sourceLoc; -} - -// We use an `IRSpecContext` for the case where we are cloning -// code from one or more input modules to create a "linked" output -// module. Along the way, we will resolve profile-specific functions -// to the best definition for a given target. -// -struct IRSpecContext : IRSpecContextBase -{ - // Override the "maybe clone" logic so that we always clone - virtual IRInst* maybeCloneValue(IRInst* originalVal) override; -}; - - -IRInst* cloneGlobalValue(IRSpecContext* context, IRInst* originalVal); - -IRInst* cloneValue( - IRSpecContextBase* context, - IRInst* originalValue); - -IRType* cloneType( - IRSpecContextBase* context, - IRType* originalType); - -IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue) -{ - switch (originalValue->op) - { - case kIROp_StructType: - case kIROp_Func: - case kIROp_Generic: - case kIROp_GlobalVar: - case kIROp_GlobalConstant: - case kIROp_GlobalParam: - case kIROp_StructKey: - case kIROp_GlobalGenericParam: - case kIROp_WitnessTable: - return cloneGlobalValue(this, originalValue); - - case kIROp_BoolLit: - { - IRConstant* c = (IRConstant*)originalValue; - return builder->getBoolValue(c->value.intVal != 0); - } - break; - - - case kIROp_IntLit: - { - IRConstant* c = (IRConstant*)originalValue; - return builder->getIntValue(cloneType(this, c->getDataType()), c->value.intVal); - } - break; - - case kIROp_FloatLit: - { - IRConstant* c = (IRConstant*)originalValue; - return builder->getFloatValue(cloneType(this, c->getDataType()), c->value.floatVal); - } - break; - - case kIROp_StringLit: - { - IRConstant* c = (IRConstant*)originalValue; - return builder->getStringValue(c->getStringSlice()); - } - break; - - case kIROp_PtrLit: - { - IRConstant* c = (IRConstant*)originalValue; - return builder->getPtrValue(c->value.ptrVal); - } - break; - - default: - { - // In the deafult case, assume that we have some sort of "hoistable" - // instruction that requires us to create a clone of it. - - UInt argCount = originalValue->getOperandCount(); - IRInst* clonedValue = builder->createIntrinsicInst( - cloneType(this, originalValue->getFullType()), - originalValue->op, - argCount, nullptr); - registerClonedValue(this, clonedValue, originalValue); - for (UInt aa = 0; aa < argCount; ++aa) - { - IRInst* originalArg = originalValue->getOperand(aa); - IRInst* clonedArg = cloneValue(this, originalArg); - clonedValue->getOperands()[aa].init(clonedValue, clonedArg); - } - cloneDecorationsAndChildren(this, clonedValue, originalValue); - - addHoistableInst(builder, clonedValue); - - return clonedValue; - } - break; - } -} - -IRInst* cloneValue( - IRSpecContextBase* context, - IRInst* originalValue); - -// Find a pre-existing cloned value, or return null if none is available. -IRInst* findClonedValue( - IRSpecContextBase* context, - IRInst* originalValue) -{ - IRInst* clonedValue = nullptr; - for (auto env = context->getEnv(); env; env = env->parent) - { - if (env->clonedValues.TryGetValue(originalValue, clonedValue)) - { - return clonedValue; - } - } - - return nullptr; -} - -IRInst* cloneValue( - IRSpecContextBase* context, - IRInst* originalValue) -{ - if (!originalValue) - return nullptr; - - if (IRInst* clonedValue = findClonedValue(context, originalValue)) - return clonedValue; - - return context->maybeCloneValue(originalValue); -} - -IRType* cloneType( - IRSpecContextBase* context, - IRType* originalType) -{ - return (IRType*)cloneValue(context, originalType); -} - -void cloneGlobalValueWithCodeCommon( - IRSpecContextBase* context, - IRGlobalValueWithCode* clonedValue, - IRGlobalValueWithCode* originalValue); - -IRRate* cloneRate( - IRSpecContextBase* context, - IRRate* rate) -{ - return (IRRate*) cloneType(context, rate); -} - -void maybeSetClonedRate( - IRSpecContextBase* context, - IRBuilder* builder, - IRInst* clonedValue, - IRInst* originalValue) -{ - if(auto rate = originalValue->getRate() ) - { - clonedValue->setFullType(builder->getRateQualifiedType( - cloneRate(context, rate), - clonedValue->getFullType())); - } -} - -IRGlobalVar* cloneGlobalVarImpl( - IRSpecContextBase* context, - IRBuilder* builder, - IRGlobalVar* originalVar, - IROriginalValuesForClone const& originalValues) -{ - auto clonedVar = builder->createGlobalVar( - cloneType(context, originalVar->getDataType()->getValueType())); - - maybeSetClonedRate(context, builder, clonedVar, originalVar); - - registerClonedValue(context, clonedVar, originalValues); - - // Clone any code in the body of the variable, since this - // represents the initializer. - cloneGlobalValueWithCodeCommon( - context, - clonedVar, - originalVar); - - return clonedVar; -} - -IRGlobalConstant* cloneGlobalConstantImpl( - IRSpecContextBase* context, - IRBuilder* builder, - IRGlobalConstant* originalVal, - IROriginalValuesForClone const& originalValues) -{ - auto clonedVal = builder->createGlobalConstant( - cloneType(context, originalVal->getFullType())); - registerClonedValue(context, clonedVal, originalValues); - - // Clone any code in the body of the constant, since this - // represents the initializer. - cloneGlobalValueWithCodeCommon( - context, - clonedVal, - originalVal); - - return clonedVal; -} - -void cloneSimpleGlobalValueImpl( - IRSpecContextBase* context, - IRInst* originalInst, - IROriginalValuesForClone const& originalValues, - IRInst* clonedInst, - bool registerValue = true) -{ - if (registerValue) - registerClonedValue(context, clonedInst, originalValues); - - // Set up an IR builder for inserting into the inst - IRBuilder builderStorage = *context->builder; - IRBuilder* builder = &builderStorage; - builder->setInsertInto(clonedInst); - - // Clone any children of the instruction - for (auto child : originalInst->getDecorationsAndChildren()) - { - cloneInst(context, builder, child); - } -} - -IRGlobalParam* cloneGlobalParamImpl( - IRSpecContextBase* context, - IRBuilder* builder, - IRGlobalParam* originalVal, - IROriginalValuesForClone const& originalValues) -{ - auto clonedVal = builder->createGlobalParam( - cloneType(context, originalVal->getFullType())); - cloneSimpleGlobalValueImpl(context, originalVal, originalValues, clonedVal); - - if(auto linkage = originalVal->findDecoration()) - { - auto mangledName = String(linkage->getMangledName()); - VarLayout* layout = nullptr; - if (context->globalVarLayouts.TryGetValue(mangledName, layout)) - { - builder->addLayoutDecoration(clonedVal, layout); - } - } - - return clonedVal; -} - -IRGeneric* cloneGenericImpl( - IRSpecContextBase* context, - IRBuilder* builder, - IRGeneric* originalVal, - IROriginalValuesForClone const& originalValues) -{ - auto clonedVal = builder->emitGeneric(); - registerClonedValue(context, clonedVal, originalValues); - - // Clone any code in the body of the generic, since this - // computes its result value. - cloneGlobalValueWithCodeCommon( - context, - clonedVal, - originalVal); - - return clonedVal; -} - -IRStructKey* cloneStructKeyImpl( - IRSpecContextBase* context, - IRBuilder* builder, - IRStructKey* originalVal, - IROriginalValuesForClone const& originalValues) -{ - auto clonedVal = builder->createStructKey(); - cloneSimpleGlobalValueImpl(context, originalVal, originalValues, clonedVal); - return clonedVal; -} - -IRGlobalGenericParam* cloneGlobalGenericParamImpl( - IRSpecContextBase* context, - IRBuilder* builder, - IRGlobalGenericParam* originalVal, - IROriginalValuesForClone const& originalValues) -{ - auto clonedVal = builder->emitGlobalGenericParam(); - cloneSimpleGlobalValueImpl(context, originalVal, originalValues, clonedVal); - return clonedVal; -} - - -IRWitnessTable* cloneWitnessTableImpl( - IRSpecContextBase* context, - IRBuilder* builder, - IRWitnessTable* originalTable, - IROriginalValuesForClone const& originalValues, - IRWitnessTable* dstTable = nullptr, - bool registerValue = true) -{ - auto clonedTable = dstTable ? dstTable : builder->createWitnessTable(); - cloneSimpleGlobalValueImpl(context, originalTable, originalValues, clonedTable, registerValue); - return clonedTable; -} - -IRWitnessTable* cloneWitnessTableWithoutRegistering( - IRSpecContextBase* context, - IRBuilder* builder, - IRWitnessTable* originalTable, - IRWitnessTable* dstTable = nullptr) -{ - return cloneWitnessTableImpl(context, builder, originalTable, IROriginalValuesForClone(), dstTable, false); -} - -IRStructType* cloneStructTypeImpl( - IRSpecContextBase* context, - IRBuilder* builder, - IRStructType* originalStruct, - IROriginalValuesForClone const& originalValues) -{ - auto clonedStruct = builder->createStructType(); - cloneSimpleGlobalValueImpl(context, originalStruct, originalValues, clonedStruct); - return clonedStruct; -} - - -IRInterfaceType* cloneInterfaceTypeImpl( - IRSpecContextBase* context, - IRBuilder* builder, - IRInterfaceType* originalInterface, - IROriginalValuesForClone const& originalValues) -{ - auto clonedInterface = builder->createInterfaceType(); - cloneSimpleGlobalValueImpl(context, originalInterface, originalValues, clonedInterface); - return clonedInterface; -} - -void cloneGlobalValueWithCodeCommon( - IRSpecContextBase* context, - IRGlobalValueWithCode* clonedValue, - IRGlobalValueWithCode* originalValue) -{ - // Next we are going to clone the actual code. - IRBuilder builderStorage = *context->builder; - IRBuilder* builder = &builderStorage; - builder->setInsertInto(clonedValue); - - cloneDecorations(context, clonedValue, originalValue); - - // We will walk through the blocks of the function, and clone each of them. - // - // We need to create the cloned blocks first, and then walk through them, - // because blocks might be forward referenced (this is not possible - // for other cases of instructions). - for (auto originalBlock = originalValue->getFirstBlock(); - originalBlock; - originalBlock = originalBlock->getNextBlock()) - { - IRBlock* clonedBlock = builder->createBlock(); - clonedValue->addBlock(clonedBlock); - registerClonedValue(context, clonedBlock, originalBlock); - -#if 0 - // We can go ahead and clone parameters here, while we are at it. - builder->curBlock = clonedBlock; - for (auto originalParam = originalBlock->getFirstParam(); - originalParam; - originalParam = originalParam->getNextParam()) - { - IRParam* clonedParam = builder->emitParam( - context->maybeCloneType( - originalParam->getFullType())); - cloneDecorations(context, clonedParam, originalParam); - registerClonedValue(context, clonedParam, originalParam); - } -#endif - } - - // Okay, now we are in a good position to start cloning - // the instructions inside the blocks. - { - IRBlock* ob = originalValue->getFirstBlock(); - IRBlock* cb = clonedValue->getFirstBlock(); - while (ob) - { - SLANG_ASSERT(cb); - - builder->setInsertInto(cb); - for (auto oi = ob->getFirstInst(); oi; oi = oi->getNextInst()) - { - cloneInst(context, builder, oi); - } - - ob = ob->getNextBlock(); - cb = cb->getNextBlock(); - } - } - -} - -void checkIRDuplicate(IRInst* inst, IRInst* moduleInst, UnownedStringSlice const& mangledName) -{ -#ifdef _DEBUG - for (auto child : moduleInst->getDecorationsAndChildren()) - { - if (child == inst) - continue; - - if(auto childLinkage = child->findDecoration()) - { - if(mangledName == childLinkage->getMangledName()) - { - SLANG_UNEXPECTED("duplicate global instruction"); - } - } - } -#else - SLANG_UNREFERENCED_PARAMETER(inst); - SLANG_UNREFERENCED_PARAMETER(moduleInst); - SLANG_UNREFERENCED_PARAMETER(mangledName); -#endif -} - -void cloneFunctionCommon( - IRSpecContextBase* context, - IRFunc* clonedFunc, - IRFunc* originalFunc, - bool checkDuplicate = true) -{ - // First clone all the simple properties. - clonedFunc->setFullType(cloneType(context, originalFunc->getFullType())); - - cloneGlobalValueWithCodeCommon( - context, - clonedFunc, - originalFunc); - - // Shuffle the function to the end of the list, because - // it needs to follow its dependencies. - // - // TODO: This isn't really a good requirement to place on the IR... - clonedFunc->moveToEnd(); - - if( checkDuplicate ) - { - if( auto linkage = clonedFunc->findDecoration() ) - { - checkIRDuplicate(clonedFunc, context->getModule()->getModuleInst(), linkage->getMangledName()); - } - } -} - -// We will forward-declare the subroutine for eagerly specializing -// an IR-level generic to argument values, because `specializeIRForEntryPoint` -// needs to perform this operation even though it is logically part of -// the later generic specialization pass. -// -IRInst* specializeGeneric( - IRSpecialize* specializeInst); - -IRFunc* specializeIRForEntryPoint( - IRSpecContext* context, - EntryPoint* entryPoint, - EntryPointLayout* entryPointLayout) -{ - // We start by looking up the IR symbol that - // matches the mangled name given to the - // function we want to emit. - // - // Note: the function decl-ref may refer to - // a specialization of a generic function, - // so that the mangled name of the decl-ref is - // not the same as the mangled name of the decl. - // - auto mangledName = getMangledName(entryPoint->getFuncDeclRef()); - RefPtr sym; - if (!context->getSymbols().TryGetValue(mangledName, sym)) - { - SLANG_UNEXPECTED("no matching IR symbol"); - return nullptr; - } - - // TODO: deal with the case where we might - // have multiple (profile-overloaded) versions... - // - auto originalVal = sym->irGlobalValue; - - // We will start by cloning the entry point reference - // like any other global value. - // - auto clonedVal = cloneGlobalValue(context, originalVal); - - // In the case where the user is requesting a specialization - // of a generic entry point, we have a bit of a problem. - // - // This function is expected to return an `IRFunc` and - // subsequent passes expect to find, e.g., layout information - // attached to the parameters of such a func. - // - // In the generic case, the `clonedValue` won't be an - // `IRFunc`, but instead an `IRSpecialize`. - // - if(auto clonedSpec = as(clonedVal)) - { - // The Right Thing to do here is to perform some - // amount of generic specialization, at least - // until we get back an `IRFunc`. - // - // The dangerous thing is that the generic specialization - // pass can, in principle, change the signature of - // functions, so that attaching parameter layout - // information *after* specialization might not work. - // - // The compromise we make here is to directly - // invoke the logic for specializing a generic. - // - // In theory this isn't valid, because there is no - // way we can register the specialized function we - // create so that it would be re-used by other instantiations - // with the same arguments (because we cannot be - // sure the generic arguments are themselves fully specialized) - // - // In practice this isn't really a problem, because - // we don't want to share the definition between - // an entry point and an ordinary function anyway. - // - clonedVal = specializeGeneric(clonedSpec); - } - - // TODO: If there is an existential-related decoration - // on the entry point, we need to transfer it over - // to the specialized function. - if( auto bindExistentialSlots = originalVal->findDecorationImpl(kIROp_BindExistentialSlotsDecoration) ) - { - if( !clonedVal->findDecorationImpl(kIROp_BindExistentialSlotsDecoration) ) - { - IRBuilder builderStorage = *context->builder; - IRBuilder* builder = &builderStorage; - builder->setInsertInto(clonedVal); - - auto clonedBind = cloneInst(context, builder, bindExistentialSlots); - clonedBind->moveToStart(); - } - } - - - auto clonedFunc = as(clonedVal); - if(!clonedFunc) - { - SLANG_UNEXPECTED("expected entry point to be a function"); - return nullptr; - } - - if( !clonedFunc->findDecorationImpl(kIROp_KeepAliveDecoration) ) - { - context->builder->addKeepAliveDecoration(clonedFunc); - } - - // We need to attach the layout information for - // the entry point to this declaration, so that - // we can use it to inform downstream code emit. - // - context->builder->addLayoutDecoration( - clonedFunc, - entryPointLayout); - - // We will also go on and attach layout information - // to the function parameters, so that we have it - // available directly on the parameters, rather - // than having to look it up on the original entry-point layout. - if( auto firstBlock = clonedFunc->getFirstBlock() ) - { - auto paramsStructLayout = getScopeStructLayout(entryPointLayout); - Index paramLayoutCount = paramsStructLayout->fields.getCount(); - Index paramCounter = 0; - for( auto pp = firstBlock->getFirstParam(); pp; pp = pp->getNextParam() ) - { - Index paramIndex = paramCounter++; - if( paramIndex < paramLayoutCount ) - { - auto paramLayout = paramsStructLayout->fields[paramIndex]; - context->builder->addLayoutDecoration( - pp, - paramLayout); - } - else - { - SLANG_UNEXPECTED("too many parameters"); - } - } - } - - return clonedFunc; -} - -// Get a string form of the target so that we can -// use it to match against target-specialization modifiers -// -// TODO: We shouldn't be using strings for this. -String getTargetName(IRSpecContext* context) -{ - switch( context->shared->target ) - { - case CodeGenTarget::HLSL: - return "hlsl"; - - case CodeGenTarget::GLSL: - return "glsl"; - - case CodeGenTarget::CSource: - return "c"; - - case CodeGenTarget::CPPSource: - return "cpp"; - - default: - SLANG_UNEXPECTED("unhandled case"); - UNREACHABLE_RETURN("unknown"); - } -} - -// How specialized is a given declaration for the chosen target? -enum class TargetSpecializationLevel -{ - specializedForOtherTarget = 0, - notSpecialized, - specializedForTarget, -}; - -TargetSpecializationLevel getTargetSpecialiationLevel( - IRInst* inVal, - String const& targetName) -{ - // HACK: Currently the front-end is placing modifiers related - // to target specialization on nodes like functions, even when - // those functions are being returned by a generic. This - // means that we need to try and inspect the value being - // returned by the generic if we are looking at a generic. - IRInst* val = inVal; - while( auto genericVal = as(val) ) - { - auto firstBlock = genericVal->getFirstBlock(); - if(!firstBlock) break; - - auto returnInst = as(firstBlock->getLastInst()); - if(!returnInst) break; - - val = returnInst->getVal(); - } - - TargetSpecializationLevel result = TargetSpecializationLevel::notSpecialized; - for(auto dd : val->getDecorations()) - { - if(dd->op != kIROp_TargetDecoration) - continue; - - auto decoration = (IRTargetDecoration*) dd; - if(String(decoration->getTargetName()) == targetName) - return TargetSpecializationLevel::specializedForTarget; - - result = TargetSpecializationLevel::specializedForOtherTarget; - } - - return result; -} - -// Is `newVal` marked as being a better match for our -// chosen code-generation target? -// -// TODO: there is a missing step here where we need -// to check if things are even available in the first place... -bool isBetterForTarget( - IRSpecContext* context, - IRInst* newVal, - IRInst* oldVal) -{ - String targetName = getTargetName(context); - - // For right now every declaration might have zero or more - // modifiers, representing the targets for which it is specialized. - // Each modifier has a single string "tag" to represent a target. - // We thus decide that a declaration is "more specialized" by: - // - // - Does it have a modifier with a tag with the string for the current target? - // If yes, it is the most specialized it can be. - // - // - Does it have a no tags? Then it is "unspecialized" and that is okay. - // - // - Does it have a modifier with a tag for a *different* target? - // If yes, then it shouldn't even be usable on this target. - // - // Longer term a better approach is to think of this in terms - // of a "disjunction of conjunctions" that is: - // - // (A and B and C) or (A and D) or (E) or (F and G) ... - // - // A code generation target would then consist of a - // conjunction of invidual tags: - // - // (HLSL and SM_4_0 and Vertex and ...) - // - // A declaration is *applicable* on a target if one of - // its conjunctions of tags is a subset of the target's. - // - // One declaration is *better* than another on a target - // if it is applicable and its tags are a superset - // of the other's. - - auto newLevel = getTargetSpecialiationLevel(newVal, targetName); - auto oldLevel = getTargetSpecialiationLevel(oldVal, targetName); - if(newLevel != oldLevel) - return UInt(newLevel) > UInt(oldLevel); - - // All preceding factors being equal, an `[export]` is better - // than an `[import]`. - // - bool newIsExport = newVal->findDecoration() != nullptr; - bool oldIsExport = oldVal->findDecoration() != nullptr; - if(newIsExport != oldIsExport) - return newIsExport; - - // All preceding factors being equal, a definition is - // better than a declaration. - auto newIsDef = isDefinition(newVal); - auto oldIsDef = isDefinition(oldVal); - if (newIsDef != oldIsDef) - return newIsDef; - - return false; -} - -IRFunc* cloneFuncImpl( - IRSpecContextBase* context, - IRBuilder* builder, - IRFunc* originalFunc, - IROriginalValuesForClone const& originalValues) -{ - auto clonedFunc = builder->createFunc(); - registerClonedValue(context, clonedFunc, originalValues); - cloneFunctionCommon(context, clonedFunc, originalFunc); - return clonedFunc; -} - - -IRInst* cloneInst( - IRSpecContextBase* context, - IRBuilder* builder, - IRInst* originalInst, - IROriginalValuesForClone const& originalValues) -{ - switch (originalInst->op) - { - // We need to special-case any instruction that is not - // allocated like an ordinary `IRInst` with trailing args. - case kIROp_Func: - return cloneFuncImpl(context, builder, cast(originalInst), originalValues); - - case kIROp_GlobalVar: - return cloneGlobalVarImpl(context, builder, cast(originalInst), originalValues); - - case kIROp_GlobalConstant: - return cloneGlobalConstantImpl(context, builder, cast(originalInst), originalValues); - - case kIROp_GlobalParam: - return cloneGlobalParamImpl(context, builder, cast(originalInst), originalValues); - - case kIROp_WitnessTable: - return cloneWitnessTableImpl(context, builder, cast(originalInst), originalValues); - - case kIROp_StructType: - return cloneStructTypeImpl(context, builder, cast(originalInst), originalValues); - - case kIROp_InterfaceType: - return cloneInterfaceTypeImpl(context, builder, cast(originalInst), originalValues); - - case kIROp_Generic: - return cloneGenericImpl(context, builder, cast(originalInst), originalValues); - - case kIROp_StructKey: - return cloneStructKeyImpl(context, builder, cast(originalInst), originalValues); - - case kIROp_GlobalGenericParam: - return cloneGlobalGenericParamImpl(context, builder, cast(originalInst), originalValues); - - default: - break; - } - - // The common case is that we just need to construct a cloned - // instruction with the right number of operands, intialize - // it, and then add it to the sequence. - UInt argCount = originalInst->getOperandCount(); - IRInst* clonedInst = builder->createIntrinsicInst( - cloneType(context, originalInst->getFullType()), - originalInst->op, - argCount, nullptr); - registerClonedValue(context, clonedInst, originalValues); - auto oldBuilder = context->builder; - context->builder = builder; - for (UInt aa = 0; aa < argCount; ++aa) - { - IRInst* originalArg = originalInst->getOperand(aa); - IRInst* clonedArg = cloneValue(context, originalArg); - clonedInst->getOperands()[aa].init(clonedInst, clonedArg); - } - builder->addInst(clonedInst); - context->builder = oldBuilder; - cloneDecorations(context, clonedInst, originalInst); - - return clonedInst; -} - -IRInst* cloneGlobalValueImpl( - IRSpecContext* context, - IRInst* originalInst, - IROriginalValuesForClone const& originalValues) -{ - auto clonedValue = cloneInst(context, &context->shared->builderStorage, originalInst, originalValues); - clonedValue->moveToEnd(); - return clonedValue; -} - - - /// Clone a global value, which has the given `originalLinkage`. - /// - /// The `originalVal` is a known global IR value with that linkage, if one is available. - /// (It is okay for this parameter to be null). - /// -IRInst* cloneGlobalValueWithLinkage( - IRSpecContext* context, - IRInst* originalVal, - IRLinkageDecoration* originalLinkage) -{ - // If the global value being cloned is already in target module, don't clone - // Why checking this? - // When specializing a generic function G (which is already in target module), - // where G calls a normal function F (which is already in target module), - // then when we are making a copy of G via cloneFuncCommom(), it will recursively clone F, - // however we don't want to make a duplicate of F in the target module. - if (originalVal->getParent() == context->getModule()->getModuleInst()) - return originalVal; - - // Check if we've already cloned this value, for the case where - // an original value has already been established. - if (originalVal) - { - if (IRInst* clonedVal = findClonedValue(context, originalVal)) - { - return clonedVal; - } - } - - if(!originalLinkage) - { - // If there is no mangled name, then we assume this is a local symbol, - // and it can't possibly have multiple declarations. - return cloneGlobalValueImpl(context, originalVal, IROriginalValuesForClone(originalVal)); - } - - // - // We will scan through all of the available declarations - // with the same mangled name as `originalVal` and try - // to pick the "best" one for our target. - - auto mangledName = String(originalLinkage->getMangledName()); - RefPtr sym; - if( !context->getSymbols().TryGetValue(mangledName, sym) ) - { - if(!originalVal) - return nullptr; - - // This shouldn't happen! - SLANG_UNEXPECTED("no matching values registered"); - UNREACHABLE_RETURN(cloneGlobalValueImpl(context, originalVal, IROriginalValuesForClone())); - } - - // We will try to track the "best" declaration we can find. - // - // Generally, one declaration wil lbe better than another if it is - // more specialized for the chosen target. Otherwise, we simply favor - // definitions over declarations. - // - IRInst* bestVal = sym->irGlobalValue; - for( auto ss = sym->nextWithSameName; ss; ss = ss->nextWithSameName ) - { - IRInst* newVal = ss->irGlobalValue; - if(isBetterForTarget(context, newVal, bestVal)) - bestVal = newVal; - } - - // Check if we've already cloned this value, for the case where - // we didn't have an original value (just a name), but we've - // now found a representative value. - if (!originalVal) - { - if (IRInst* clonedVal = findClonedValue(context, bestVal)) - { - return clonedVal; - } - } - - return cloneGlobalValueImpl(context, bestVal, IROriginalValuesForClone(sym)); -} - -// Clone a global value, where `originalVal` is one declaration/definition, but we might -// have to consider others, in order to find the "best" version of the symbol. -IRInst* cloneGlobalValue(IRSpecContext* context, IRInst* originalVal) -{ - // We are being asked to clone a particular global value, but in - // the IR that comes out of the front-end there could still - // be multiple, target-specific, declarations of any given - // global value, all of which share the same mangled name. - return cloneGlobalValueWithLinkage( - context, - originalVal, - originalVal->findDecoration()); -} - -void insertGlobalValueSymbol( - IRSharedSpecContext* sharedContext, - IRInst* gv) -{ - auto linkage = gv->findDecoration(); - - // Don't try to register a symbol for global values - // that don't have linkage. - // - if (!linkage) - return; - - auto mangledName = String(linkage->getMangledName()); - - RefPtr sym = new IRSpecSymbol(); - sym->irGlobalValue = gv; - - RefPtr prev; - if (sharedContext->symbols.TryGetValue(mangledName, prev)) - { - sym->nextWithSameName = prev->nextWithSameName; - prev->nextWithSameName = sym; - } - else - { - sharedContext->symbols.Add(mangledName, sym); - } -} - -void insertGlobalValueSymbols( - IRSharedSpecContext* sharedContext, - IRModule* originalModule) -{ - if (!originalModule) - return; - - for(auto ii : originalModule->getGlobalInsts()) - { - insertGlobalValueSymbol(sharedContext, ii); - } -} - -void initializeSharedSpecContext( - IRSharedSpecContext* sharedContext, - Session* session, - IRModule* module, - CodeGenTarget target) -{ - - SharedIRBuilder* sharedBuilder = &sharedContext->sharedBuilderStorage; - sharedBuilder->module = nullptr; - sharedBuilder->session = session; - - IRBuilder* builder = &sharedContext->builderStorage; - builder->sharedBuilder = sharedBuilder; - - if( !module ) - { - module = builder->createModule(); - } - - sharedBuilder->module = module; - sharedContext->module = module; - sharedContext->target = target; -} - -struct IRSpecializationState -{ - ProgramLayout* programLayout; - CodeGenTarget target; - TargetRequest* targetReq; - - IRModule* irModule = nullptr; - - IRSharedSpecContext sharedContextStorage; - IRSpecContext contextStorage; - - IRSpecEnv globalEnv; - - IRSharedSpecContext* getSharedContext() { return &sharedContextStorage; } - IRSpecContext* getContext() { return &contextStorage; } - - IRSpecializationState() - { - contextStorage.env = &globalEnv; - } - - ~IRSpecializationState() - { - contextStorage = IRSpecContext(); - sharedContextStorage = IRSharedSpecContext(); - } -}; - -LinkedIR linkIR( - BackEndCompileRequest* compileRequest, - EntryPoint* entryPoint, - ProgramLayout* programLayout, - CodeGenTarget target, - TargetRequest* targetReq) -{ - auto sink = compileRequest->getSink(); - - IRSpecializationState stateStorage; - auto state = &stateStorage; - - state->programLayout = programLayout; - state->target = target; - state->targetReq = targetReq; - - auto program = compileRequest->getProgram(); - - auto sharedContext = state->getSharedContext(); - initializeSharedSpecContext( - sharedContext, - compileRequest->getSession(), - nullptr, - target); - - state->irModule = sharedContext->module; - - // We need to be able to look up IR definitions for any symbols in - // modules that the program depends on (transitively). To - // accelerate lookup, we will create a symbol table for looking - // up IR definitions by their mangled name. - // - auto originalProgramIRModule = program->getOrCreateIRModule(sink); - insertGlobalValueSymbols(sharedContext, originalProgramIRModule); - for (auto module : program->getModuleDependencies()) - { - insertGlobalValueSymbols(sharedContext, module->getIRModule()); - } - - auto context = state->getContext(); - context->shared = sharedContext; - context->builder = &sharedContext->builderStorage; - - // Next, we want to optimize lookup for layout information - // associated with global declarations, so that we can - // look things up based on the IR values (using mangled names) - // - // Note: We are scanning over all the key-value pairs for - // entries in the global scope, to account for the fact - // that the "same" shader parameter could be declared in - // multiple translation units, and thus end up with - // multiple mangled names (when the unique translation - // unit name gets involved). - // - auto globalStructLayout = getScopeStructLayout(programLayout); - for(auto entry : globalStructLayout->mapVarToLayout) - { - auto mangledName = getMangledName(entry.Key); - auto globalVarLayout = entry.Value; - context->globalVarLayouts.AddIfNotExists(mangledName, globalVarLayout); - } - - context->builder->setInsertInto(context->getModule()->getModuleInst()); - - // for now, clone all unreferenced witness tables - // - // TODO: This step should *not* be needed with the current IR - // specialization approach, so we should consider removing it. - // - for (auto sym :context->getSymbols()) - { - if (sym.Value->irGlobalValue->op == kIROp_WitnessTable) - cloneGlobalValue(context, (IRWitnessTable*)sym.Value->irGlobalValue); - } - - auto entryPointLayout = findEntryPointLayout(programLayout, entryPoint); - - // Next, we make sure to clone the global value for - // the entry point function itself, and rely on - // this step to recursively copy over anything else - // it might reference. - auto irEntryPoint = specializeIRForEntryPoint(context, entryPoint, entryPointLayout); - - // HACK: right now the bindings for global generic parameters are coming in - // as part of the original IR module, and we need to make sure these get - // copied over, even if they aren't referenced. - // - for(auto inst : originalProgramIRModule->getGlobalInsts()) - { - auto bindInst = as(inst); - if(!bindInst) - continue; - - cloneValue(context, bindInst); - } - - for(auto inst : originalProgramIRModule->getGlobalInsts()) - { - if(inst->op != kIROp_BindGlobalExistentialSlots) - continue; - - cloneValue(context, inst); - } - - // HACK: we need to ensure that any tagged union types - // in the IR module have layout information copied over to them. - // - // Note that we do this *after* cloning the `bindGlobalGenericParam` - // instructions, since we expected the tagged union type(s) to - // be referenced by them. - // - for( auto taggedUnionTypeLayout : entryPointLayout->taggedUnionTypeLayouts ) - { - auto taggedUnionType = taggedUnionTypeLayout->getType(); - auto mangledName = getMangledTypeName(taggedUnionType); - - RefPtr sym; - if(!context->getSymbols().TryGetValue(mangledName, sym)) - continue; - - IRInst* clonedType = findClonedValue(context, sym->irGlobalValue); - if(!clonedType) - continue; - - context->builder->addLayoutDecoration(clonedType, taggedUnionTypeLayout); - } - - // TODO: *technically* we should consider the case where - // we have global variables with initializers, since - // these should get run whether or not the entry point - // references them. - - // Now that we've cloned the entry point and everything - // it refers to, we can package up the data we return - // to the caller. - // - LinkedIR linkedIR; - linkedIR.module = state->irModule; - linkedIR.entryPoint = irEntryPoint; - return linkedIR; -} - - - -} // namespace Slang diff --git a/source/slang/ir-link.h b/source/slang/ir-link.h deleted file mode 100644 index dba3ccc97..000000000 --- a/source/slang/ir-link.h +++ /dev/null @@ -1,27 +0,0 @@ -// ir-link.h -#pragma once - -#include "compiler.h" - -namespace Slang -{ - struct LinkedIR - { - RefPtr module; - IRFunc* entryPoint; - }; - - - // Clone the IR values reachable from the given entry point - // into the IR module associated with the specialization state. - // When multiple definitions of a symbol are found, the one - // that is best specialized for the given `targetReq` will be - // used. - // - LinkedIR linkIR( - BackEndCompileRequest* compileRequest, - EntryPoint* entryPoint, - ProgramLayout* programLayout, - CodeGenTarget target, - TargetRequest* targetReq); -} diff --git a/source/slang/ir-missing-return.cpp b/source/slang/ir-missing-return.cpp deleted file mode 100644 index c32b71ab6..000000000 --- a/source/slang/ir-missing-return.cpp +++ /dev/null @@ -1,43 +0,0 @@ -// ir-missing-return.cpp -#include "ir-missing-return.h" - -#include "ir.h" -#include "ir-insts.h" - -namespace Slang { - -class DiagnosticSink; -struct IRModule; - -void checkForMissingReturnsRec( - IRInst* inst, - DiagnosticSink* sink) -{ - if( auto code = as(inst) ) - { - for( auto block : code->getBlocks() ) - { - auto terminator = block->getTerminator(); - - if( auto missingReturn = as(terminator) ) - { - sink->diagnose(missingReturn, Diagnostics::missingReturn); - } - } - } - - for( auto childInst : inst->getDecorationsAndChildren() ) - { - checkForMissingReturnsRec(childInst, sink); - } -} - -void checkForMissingReturns( - IRModule* module, - DiagnosticSink* sink) -{ - // Look for any `missingReturn` instructions - checkForMissingReturnsRec(module->getModuleInst(), sink); -} - -} diff --git a/source/slang/ir-missing-return.h b/source/slang/ir-missing-return.h deleted file mode 100644 index 0d22a07c4..000000000 --- a/source/slang/ir-missing-return.h +++ /dev/null @@ -1,12 +0,0 @@ -// ir-missing-return.h -#pragma once - -namespace Slang -{ - class DiagnosticSink; - struct IRModule; - - void checkForMissingReturns( - IRModule* module, - DiagnosticSink* sink); -} diff --git a/source/slang/ir-restructure-scoping.cpp b/source/slang/ir-restructure-scoping.cpp deleted file mode 100644 index c5e628e71..000000000 --- a/source/slang/ir-restructure-scoping.cpp +++ /dev/null @@ -1,434 +0,0 @@ -// ir-restructure-scoping.cpp -#include "ir-restructure-scoping.h" - -#include "ir.h" -#include "ir-insts.h" -#include "ir-restructure.h" - -namespace Slang -{ - -/// Try to find the first structured region that represents `block` -/// -/// In general the same block may appear as multiple regions, -/// so this will return the first region in the linked list. -static SimpleRegion* getFirstRegionForBlock( - RegionTree* regionTree, - IRBlock* block) -{ - SimpleRegion* region = nullptr; - if( regionTree->mapBlockToRegion.TryGetValue(block, region) ) - { - return region; - } - return nullptr; -} - -/// Try to find the first structured region that contains `inst`. -static SimpleRegion* getFirstRegionForInst( - RegionTree* regionTree, - IRInst* inst) -{ - auto ii = inst; - while(ii) - { - if(auto block = as(ii)) - return getFirstRegionForBlock(regionTree, block); - - ii = ii->getParent(); - } - - return nullptr; -} - -/// Compute the depth of a node in the region tree. -/// -/// This is the number of nodes (including `region`) -/// on a path from `region` to the root. -/// -static Int computeDepth(Region* region) -{ - Int depth = 0; - for( Region* rr = region; rr; rr = rr->getParent() ) - { - depth++; - } - return depth; -} - -/// Get the `n`th ancestor of `region`. -/// -/// When `n` is zero, this returns `region`. -/// When `n` is one, this returns the parent of `region`, and so forth. -/// -static Region* getAncestor(Region* region, Int n) -{ - Region* rr = region; - for( Int ii = 0; ii < n; ++ii ) - { - SLANG_ASSERT(rr); - rr = rr->getParent(); - } - return rr; -} - -/// Find a region that is an ancestor of both `left` and `right`. -static Region* findCommonAncestorRegion( - Region* left, - Region* right) -{ - // Rather than blinding search through each ancestor of `left` - // and see if it is also an ancestor of `right` and vice-versa, - // let's try to be smart about this. - // - // We will start by computing the depth of `left` and `right`: - // - Int leftDepth = computeDepth(left); - Int rightDepth = computeDepth(right); - - // Whatever the common ancestor is, it can't be any deeper - // than the minimum of these two depths. - // - Int minDepth = Math::Min(leftDepth, rightDepth); - - // Let's fetch the ancestor of each of `left` and `right` - // corresponding to that depth: - // - Region* leftAncestor = getAncestor(left, leftDepth - minDepth); - Region* rightAncestor = getAncestor(right, rightDepth - minDepth); - - // Now we know that `leftAncestor` and `rightAncestor` - // must have the same depth. Let's go ahead and assert - // it just to be safe: - // - SLANG_ASSERT(computeDepth(leftAncestor) == minDepth); - SLANG_ASSERT(computeDepth(rightAncestor) == minDepth); - - // If `leftAncestor` and `rightAncestor` are the same node, - // then we've found a common ancestor, otherwise we should - // look at their parents. Because the depth must match - // on both sides, we will never risk missing an ancestor. - // - while( leftAncestor != rightAncestor ) - { - leftAncestor = leftAncestor->getParent(); - rightAncestor = rightAncestor->getParent(); - } - - // Okay, we've found a common ancestor. - // - Region* commonAncestor = leftAncestor; - return commonAncestor; -} - -/// Find a simple region that is an ancestor of both `left` and `right`. -static SimpleRegion* findSimpleCommonAncestorRegion( - Region* left, - Region* right) -{ - // Start by finding a common ancestor without worrying about it being simple. - Region* ancestor = findCommonAncestorRegion(left, right); - - // Now search for a simple region up the tree. - while( ancestor ) - { - if(ancestor->getFlavor() == Region::Flavor::Simple) - return (SimpleRegion*) ancestor; - - ancestor = ancestor->getParent(); - } - - // This shouldn't ever occur. The root of the region tree should - // be a simple regions that represents the entry block of the - // function. - // - SLANG_UNEXPECTED("no common ancestor found in region tree"); - UNREACHABLE_RETURN(nullptr); -} - -IRInst* getDefaultInitVal( - IRBuilder* builder, - IRType* type) -{ - switch( type->op ) - { - default: - return nullptr; - - case kIROp_BoolType: - return builder->getBoolValue(false); - - case kIROp_IntType: - case kIROp_UIntType: - case kIROp_UInt64Type: - return builder->getIntValue(type, 0); - - case kIROp_HalfType: - case kIROp_FloatType: - case kIROp_DoubleType: - return builder->getFloatValue(type, 0.0); - - // TODO: handle vector/matrix types here, by - // creating an appropriate scalar value and - // then "splatting" it. - } -} - -/// Initialize a variable to a sane default value, if possible. -void defaultInitializeVar( - IRBuilder* builder, - IRVar* var, - IRType* type) -{ - IRInst* initVal = nullptr; - switch( type->op ) - { - case kIROp_VoidType: - default: - // By default, see if we can synthesize an IR value - // to be used as the default, and allow the logic - // below to store it into the variable. - initVal = getDefaultInitVal(builder, type); - break; - - // TODO: Handle aggregate types (structures, arrays) - // explicitly here, since they need to be careful about - // the cases where an element/field type might not - // be something we can default-initialize. - } - - if( initVal ) - { - builder->emitStore(var, initVal); - } -} - -/// Detect and fix any structured scoping issues for a given `def` instruction. -/// -/// The `defRegion` should be the region that contains `def`, and `regionTree` -/// should be the region tree for the function that contains `def`. -/// -static void fixValueScopingForInst( - IRInst* def, - SimpleRegion* defRegion, - RegionTree* regionTree) -{ - // This algorithm should not consider "phi nodes" for now, - // because the emit logic will already create variables for them. - // We could consider folding the logic to move out of SSA form - // into this function, but that would add a lot of complexity for now. - if(def->op == kIROp_Param) - return; - - // We would have a scoping violation if there exists some - // use `u` of `def` such that the region containing `u` - // (call it `useRegion`) is not a descendent of `defRegion` - // in the region tree. - // - // If there are no scoping violations, we don't want to do - // anything. If there *are* any scoping violations, then - // we ill need to introduce a temporary `tmp`, store into - // it right after `def`, and then load from it at any "bad" - // use sites. - // - // Of course, for the whole thing to work, we also need - // to put `tmp` into a block somwhere, and it needs to - // be a block that is visible to all of the uses, or we - // are just back int the same mess again. - // - // The right block to use for inserting `tmp` is the least - // common ancestor of `def` and all the "bad" uses, so - // we will get a bit "clever" and fold in the search for - // bad uses with the computation of the region we should - // insert `tmp` into (to avoid looping over the uses - // twice). - // - SimpleRegion* insertRegion = defRegion; - IRVar* tmp = nullptr; - - // If we end up needing to insert code we'll need an IR builder, - // so we will go ahead and create one now. - // - // TODO: the logic to compute `module` here could be hoisted - // out earlier, rather than being done per-instruction. - // - IRModule* module = regionTree->irCode->getModule(); - - SharedIRBuilder sharedBuilder; - sharedBuilder.session = module->session; - sharedBuilder.module = module; - - IRBuilder builder; - builder.sharedBuilder = &sharedBuilder; - - // Because we will be changing some of the uses of `def` - // to use other values while we iterate the list, we - // need to be a bit careful and extract the next use - // in the linked list *before* we operator on `u`. - // - IRUse* nextUse = nullptr; - for( auto u = def->firstUse; u; u = nextUse ) - { - nextUse = u->nextUse; - - // Looking at the use site `u`, we'd like to check if - // it violates our scoping rules. - // - // As a simple early-exit case, if the user is in - // the same block as the definition, there are no problems. - // - IRInst* user = u->getUser(); - if(user->getParent() == defRegion->block) - continue; - - // Otherwise, let's find the structures control-flow - // region that holds the user. We expect to always - // find one, because the use site must be in the same - // function. - // - // TODO: Double check that logic if we ever introduce - // things like nested function. - // - SimpleRegion* useRegion = getFirstRegionForInst(regionTree, user); - - // If there is no region associated with the use, then - // the use must be in unreachable code (part of the CFG, - // but not part of the region tree). We will skip - // such uses for now, since they won't even appear in - // the output. - // - if(!useRegion) - continue; - - // Now we want to check if `useRegion` is a child/descendent - // of a region that has the same block as `defRegion`. - // If it is, then there is no scoping problem with this use. - // - if(useRegion->isDescendentOf(defRegion->block)) - continue; - - // If we've gotten this far, we know that `u` is a "bad" - // use of `def`, and needs fixing. - // - // We will create the `tmp` variable on demand, so - // that we create it when the first "bad" use is encountered, - // and then re-use it for subsequent bad uses. - // - if( !tmp ) - { - // We will create a temporary to represent `def`, - // and insert a `store` into it right after `def`. - // - // Note: we are inserting the new variable right - // after `def` for now, just because we don't - // yet know the final region that it should be - // placed into. We will move it to the correct - // location when we are done. - // - builder.setInsertBefore(def->getNextInst()); - tmp = builder.emitVar(def->getDataType()); - builder.emitStore(tmp, def); - } - - // In order to know where `tmp` should be defined - // at the end of the algorithm, we need to compute - // a valid `insertRegion` that is an ancestor of - // all of the use sites (and it also a simple region - // so that we can insert into its IR block). - // - // We need to deal with one complexity in our restructuring - // process, which is that a block may be duplicated into - // one or more regions, so we loop over all the regions - // for the same block as `useRegion`. - // - for(auto rr = useRegion; rr; rr = rr->nextSimpleRegionForSameBlock) - { - insertRegion = findSimpleCommonAncestorRegion( - insertRegion, - rr); - } - - // To fix up the use `u`, we will need to change - // it from using `def` to using a load from `tmp` - // - builder.setInsertBefore(user); - IRInst* tmpVal = builder.emitLoad(tmp); - - // We are clobbering the value used by the `IRUse` `u`, - // while will cut it out of the list of uses for `def`. - // We need to be careful when doing this to not disrupt - // our iteration of the uses of `def`, so we carefully - // used the `nextUse` temporary at the start of the loop. - // - u->set(tmpVal); - } - - // At the end of the loop, the `tmp` variable will have - // been created if and only if we fixed up anything. - // - if( tmp ) - { - // If we created a temporary, then now we need to move - // its definition to the right place, which is the - // `insertRegion` that we computed during the loop. - // - // We'd like to insert our temporary near the top - // of the region, since that is the conventional - // place for local variables to go. - // - tmp->insertBefore( - insertRegion->block->getFirstOrdinaryInst()); - - // The whole point of the transformation we are doing - // here is that `def` is not on the "obvious" control - // flow path to one or more uses (which are now using - // `tmp`), but that means that it might not be "obvious" - // to a downstream compiler that `tmp` always gets - // initialized (by the code we inserted after `def`) - // before each of these use sites. - // - // We *know* that things are valid as long as our - // dominator tree was valid - there is no way to - // get to the block that loads from `tmp` without passing - // through the block that computes `def` (and then - // stores it into `tmp`) first. - // - // To avoid warnings/errros, we will go ahead and try - // to emit logic to "default initialize" the `tmp` - // variable if possible. - // - builder.setInsertBefore(tmp->getNextInst()); - defaultInitializeVar(&builder, tmp, def->getDataType()); - } -} - -void fixValueScoping(RegionTree* regionTree) -{ - // We are going to have to walk through every instruction - // in the code of the function to detect an bad cases. - // - auto code = regionTree->irCode; - for(auto block : code->getBlocks()) - { - // All of the instruction in `block` will have the same - // parent region, so we will look it up now rather than - // have to re-do this work on a per-instruction basis. - // - auto parentRegion = getFirstRegionForBlock(regionTree, block); - - // If a block has no region then it must be unreachable, - // so we will skip it entirely for this pass. - // - // TODO: we should be eliminating unrechable blocks anyway. - // - if(!parentRegion) - continue; - - for(auto inst : block->getDecorationsAndChildren()) - { - fixValueScopingForInst(inst, parentRegion, regionTree); - } - } -} - -} diff --git a/source/slang/ir-restructure-scoping.h b/source/slang/ir-restructure-scoping.h deleted file mode 100644 index 7840dda80..000000000 --- a/source/slang/ir-restructure-scoping.h +++ /dev/null @@ -1,24 +0,0 @@ -// ir-restructure-scoping.h -#pragma once - -namespace Slang -{ - -class RegionTree; - -/// Fix cases where a value might be used in a non-nested region. -/// -/// There can be cases where an IR value V in block A is used in -/// some block B, where A dominates B, *but* when we constructed -/// the region tree, the block B is not in a child/descendent -/// region of A's region, so that it won't be visible through the -/// scoping rules of a target language. -/// -/// This function detects such cases, and fixes them up by inserting -/// new temporaries into the IR code so that values that need -/// to survive across blocks are communicated through variables -/// declared at a sufficiently broad scope. -/// -void fixValueScoping(RegionTree* regionTree); - -} diff --git a/source/slang/ir-restructure.cpp b/source/slang/ir-restructure.cpp deleted file mode 100644 index 47a0d1fee..000000000 --- a/source/slang/ir-restructure.cpp +++ /dev/null @@ -1,663 +0,0 @@ -// ir-restructure.cpp -#include "ir-restructure.h" - -#include "ir.h" -#include "ir-insts.h" - -namespace Slang -{ - bool Region::isDescendentOf(Region* other) - { - Region* rr = this; - while( rr ) - { - if(rr == other) - return true; - - rr = rr->getParent(); - } - return false; - } - - bool Region::isDescendentOf(IRBlock* block) - { - Region* rr = this; - while( rr ) - { - if( rr->getFlavor() == Region::Flavor::Simple ) - { - SimpleRegion* simpleRegion = (SimpleRegion*) rr; - if(simpleRegion->block == block) - return true; - } - - rr = rr->getParent(); - } - return false; - } - - /// An "active" label during control flow (re)structuring. - struct LabelStack - { - /// Possible operations associated with labels. - enum class Op - { - Break, - Continue, - - CountOf, - }; - - /// What kind of operation does a branch to this label represent? - Op op; - - /// The next label down on the stack - LabelStack* parent; - - /// The block the represents this label in the IR control flow graph. - IRBlock* block; - - /// The region that represents this label in the structured program - Region* region; - }; - - /// State used when restructuring control flow. - struct ControlFlowRestructuringContext - { - /// Sink to use when diagnosing errors in control-flow restructuring. - /// - /// The restructuring pass should be able to handle anything the front-end - /// throws at it, so these errors will all be unexpected. Still, we need - /// a way to report them cleanly without crashing the process. - /// - DiagnosticSink* sink = nullptr; - DiagnosticSink* getSink() { return sink; } - - /// The region tree we are in the process of building. - RegionTree* regionTree = nullptr; - }; - - /// Convert a range of blocks in the IR CFG into a region. - /// - /// We want to generate a region that stands in for the - /// blocks that are logically in the internal [begin, end) - /// which we consider as representing a single-entry multiple-exit - /// sub-graph. Note that `end` is *not* part of the sub-graph, - /// but instead points to a block that is logically "after" - /// the sub-graph. `end` can be `null` to indicate that the - /// sub-graph extends as far as possible. - /// - /// Because there can be multiple exits, control flow may - /// exit the sub-graph without branching to `end`, any - /// such "non-local" branching should be to one of the - /// blocks stored in the current `LabelStack`. - /// - // TODO: Eventually we should replace all of this logic with - // a variation on the "Relooper" algorithm as it is used - // in Emscripten. - // - static RefPtr generateRegionsForIRBlocks( - ControlFlowRestructuringContext* ctx, - Region* inParentRegion, - IRBlock* begin, - IRBlock* end, - LabelStack* initialLabels, // Labels to use at the start - LabelStack* labels = nullptr) // Labels to switch to after emitting first basic block - { - if(!labels) - labels = initialLabels; - auto useLabels = initialLabels; - - // - // We will try to build up as long of a sequential/simple region - // as possible, to avoid deep recursion in this algorithm. - // - RefPtr resultRegion = nullptr; - RefPtr* resultLink = &resultRegion; - - // As we move along, the parent region to use for regions - // we create will shift, so we need a temporary to track - // the current parent region. - // - Region* parentRegion = inParentRegion; - - // - // We will start with the `begin` block, and try to proceed - // sequentially until we see the `end` block, or run into - // an edge that exits teh region. - // - IRBlock* block = begin; - while(block != end) - { - // If the block we are trying to emit has been registered as a - // destination label (e.g. for a loop or `switch`) then we - // need to exit the current region, which amounts to generating - // a `break` or `continue` operation. - // - // TODO: we eventually need to handle the possibility of - // multi-level break/continue targets, which could be challenging. - - // Because we will only support single-level break/continue, we - // want to resolve what is the most recent label that is "active" - // for the given operation (`break` or `continue`). - // - // We will do this with a naive loop, just to keep things simple. - // We start with no block "regsitered" as the target for each - // operation. - // - IRBlock* registeredBlock[(int)LabelStack::Op::CountOf] = {}; - for( auto ll = useLabels; ll; ll = ll->parent ) - { - // For each active label, see if it is the first one - // we encounter for the given op. - // - if(!registeredBlock[(int)ll->op]) - { - registeredBlock[(int)ll->op] = ll->block; - } - } - - // Next we will search through *all* of the registered labels, - // and see if one of them matches the current `block`. - // - for(auto ll = useLabels; ll; ll = ll->parent) - { - // Does this label match the block we are trying to translate? - if(ll->block != block) - continue; - - // Okay, the block we are trying to generate code for is a label - // that we should branch to (we shouldn't just emit the code here - // and now...) - // - // We should first confirm that the block is the inner-most label - // registered for the given control-flow op (`break` or `continue`) - // because if it *isn't* we currently can't generate code. - // - if(block != registeredBlock[(int)ll->op]) - { - ctx->getSink()->diagnose(block, Diagnostics::multiLevelBreakUnsupported); - } - - // Now we need to create a structured `break` or `continue` operation - // to match the operation associated with the target. - // - switch(ll->op) - { - case LabelStack::Op::Break: - { - auto outerRegion = (BreakableRegion*) ll->region; - RefPtr breakRegion = new BreakRegion(parentRegion, outerRegion); - - *resultLink = breakRegion; - resultLink = nullptr; - } - break; - - case LabelStack::Op::Continue: - { - auto outerRegion = (LoopRegion*) ll->region; - RefPtr continueRegion = new ContinueRegion(parentRegion, outerRegion); - - *resultLink = continueRegion; - resultLink = nullptr; - } - break; - } - - // If the `block` matched an active label, then we should have - // created a branch, and there is nothing to be done here. - return resultRegion; - } - - // We now know that the given `block` is part of our control-flow region, - // so we need to output a simple region that executes the code in that block. - // - RefPtr simpleRegion = new SimpleRegion(parentRegion, block); - - // We need to register the mapping from `block` to this region, but in - // general this isn't a one-to-one mapping, but rather one-to-many. - // This is because a "continue clause" in a `for` loop might get duplicated - // at each `continue` site in the output code. To deal with this - // we build a singly-linked list of regions for each block. - // - // TODO: confirm that continue clauses are the only case that leads - // to duplication. - // - // TODO: remove this workaround once we have a more powerful restructuring - // pass that avoids duplicating blocks (by introducing new temporaries...) - // - SimpleRegion* nextSimpleRegionForSameBlock = nullptr; - ctx->regionTree->mapBlockToRegion.TryGetValue(block, nextSimpleRegionForSameBlock); - ctx->regionTree->mapBlockToRegion[block] = simpleRegion; - - *resultLink = simpleRegion; - resultLink = &simpleRegion->nextRegion; - parentRegion = simpleRegion; - - // The simple region we created will represent all of the non-terminator - // instructions in the `block`, so now we need to figure out what to - // create to represent that terminator. - // - auto terminator = block->getTerminator(); - SLANG_ASSERT(terminator != nullptr); - switch (terminator->op) - { - default: - case kIROp_conditionalBranch: - // Note: we don't currently generate ordinary `conditionalBranch` instructions, - // and instead only generate `ifElse` instructions, which include additional - // information that can inform our control-flow restructuring pass. - // - SLANG_UNEXPECTED("unhandled terminator instruction opcode"); - ; // fall through to: - case kIROp_Unreachable: - case kIROp_MissingReturn: - case kIROp_ReturnVal: - case kIROp_ReturnVoid: - case kIROp_discard: - // These cases are all simple terminators that can be handled as-is - // without needing to construct a separate `Region` to encapsulate them. - // - // We will cap off the current sequence of simple regions and return. - // - *resultLink = nullptr; - return resultRegion; - - case kIROp_ifElse: - { - // Here we have a two-way branch, so that we will construct a - // region representing an `if` statement. - // - auto ifInst = (IRIfElse*)terminator; - auto condition = ifInst->getCondition(); - auto trueBlock = ifInst->getTrueBlock(); - auto falseBlock = ifInst->getFalseBlock(); - auto afterBlock = ifInst->getAfterBlock(); - - - RefPtr ifRegion = new IfRegion(parentRegion, condition); - - // The region for the "then" part of things will consist of - // the range of blocks `[trueBlock, afterBlock)`. - // - // This logic assumes that `afterBlock` is a valid structured - // "join point" such that any branch out of the sub-region - // either leads to `afterBlock` *or* one of the labels - // that is already present on our label stack. - // - ifRegion->thenRegion = generateRegionsForIRBlocks( - ctx, - ifRegion, - trueBlock, - afterBlock, - labels); - - // Generating a region for the `else` part is similar. - // Note that it is possible for this to be a `null` - // region, if `falseBlock == afterBlock`. - // - ifRegion->elseRegion = generateRegionsForIRBlocks( - ctx, - ifRegion, - falseBlock, - afterBlock, - labels); - - *resultLink = ifRegion; - resultLink = &ifRegion->nextRegion; - parentRegion = ifRegion; - - // Continue with the block after the `ifElse` instruction. - block = afterBlock; - } - break; - - case kIROp_loop: - { - // The terminator in this case is the header for a structured loop. - // - auto loopInst = (IRLoop*) terminator; - auto bodyBlock = loopInst->getTargetBlock(); - auto afterBlock = loopInst->getBreakBlock(); - - RefPtr loopRegion = new LoopRegion(parentRegion, loopInst); - - // We will need to set up entries on our label stack to - // represent the targets for `break` or `continue` - // operations inside the loop. - // - // First we set up the stack entry for the `break` label, - // which will refer to the block *after* the loop. - // - // The region we specify for the label will still be - // the loop region, though, because the loop is what - // we are breaking out of. - // - LabelStack loopBreakLabelStack; - loopBreakLabelStack.parent = labels; - loopBreakLabelStack.block = afterBlock; - loopBreakLabelStack.region = loopRegion; - loopBreakLabelStack.op = LabelStack::Op::Break; - - // - // The `continue` label warrants a bit more careful explanation, - // because it will *not* refer to the block that was regsitered - // as the continue target in the IR `loop` instruction. This - // is because we will always emit our loops as `for(;;) { ... }` - // with no continue clause at all, so that a `continue` in - // the output code will always refer to the top of the loop. - // - // This means that the `continue` label for the purposes of - // structured control flow will be the start of the loop body: - // - LabelStack loopContinueLabelStack; - loopContinueLabelStack.parent = &loopBreakLabelStack; - loopContinueLabelStack.block = bodyBlock; - loopContinueLabelStack.region = loopRegion; - loopContinueLabelStack.op = LabelStack::Op::Continue; - // - // Note: by ignoring the original continue block from the - // high-level loop, we create a situation where that code - // might get emitted more than once (once per implicit - // or explicit `continue` site in the original program). - // - // That is an acceptable trade-off for now, because continue - // blocks will usually be small (and fxc makes the same choice), - // but it could lead to Bad Things if somebody were to call - // a function in their continue clause, and that function does - // a compute shader barrier operation. - // - // A better long-term fix is to take a high-level loop like: - // - // for(A; B; C) { ... continue; ... break; ... } - // - // and translate it into something like the following (assuming - // we have labeled statements and multi-level `break`): - // - // A; - // Outer: for(;;) { - // Inner: for(;;) { - // if(B) {} else break Outer; - // ... - // break Inner; // `continue` becomes break of inner loop - // ... - // break Outer; // `break` becomes break of outer loop - // ... - // break; // inner loop unconditionally breaks at the end - // } - // C; // continue clause comes after inner loop - // } - // - // If you draw up a control flow graph for that code, you'll find - // it is equivalent to the orignal `for` loop, but now supports - // arbitrary code (not just a single expression) for the continue clause. - // Unlike the current code-duplication solution, `C` appears only once - // in the output, and seems to clearly be at a "joint point" for control - // flow so that it is clear that a barrier there is valid in GLSL. - // - // Anyway, back our regularly scheduled programming. - // - // With the label stack stuff set up, we want to take the region - // of the CFG defined by `[bodyBlock, afterBlock)` and turn it into - // the body region for our loop. - // - // The only thing we want to be a little bit careful about is - // that we don't want the logic at the top of this function - // that looks for a block it can translate into a `continue` - // to trigger on `bodyBlock`, since that means we'd just turn - // the whole body into a single `continue`. - // - // To avoid this problem, we pass in two different label stacks: - // one to use for the first block, and one to use for subsequent - // blocks. - // - loopRegion->body = generateRegionsForIRBlocks( - ctx, - loopRegion, - bodyBlock, - // TODO: should we pass `afterBlock` here instead of `null`? - nullptr, - // For the first block, we only want the `break` label active - &loopBreakLabelStack, - // After the first block, we can safely use the `continue` label too - &loopContinueLabelStack); - - *resultLink = loopRegion; - resultLink = &loopRegion->nextRegion; - parentRegion = loopRegion; - - // Continue with the block after the loop - block = afterBlock; - } - break; - - case kIROp_unconditionalBranch: - { - // Here we have an unconditional branch that was - // not covered by one of our labels for non-local - // branches (`break` or `continue`). - // - // We will thus assume that the target of the - // branch is part of the same region we are building, - // and continue with the target block; - // - auto branchInst = (IRUnconditionalBranch*) terminator; - block = branchInst->getTargetBlock(); - } - break; - - case kIROp_Switch: - { - // A `switch` instruction will always translate - // to a `SwitchRegion` and then to a `switch` statement. - // - // We will need to take care to emit `case`s in ways - // that avoid code duplication. - // - // The logic here isn't going to be robust in edge cases - // (please don't write Duff's Device in Slang just yet). - // Doing significantly better than what is here would - // require something like the Relooper algorithm, though. - // - auto switchInst = (IRSwitch*) terminator; - auto condition = switchInst->getCondition(); - auto breakLabel = switchInst->getBreakLabel(); - auto defaultLabel = switchInst->getDefaultLabel(); - - RefPtr switchRegion = new SwitchRegion(parentRegion, condition); - - // A direct branch to the block after the `switch` can - // be emitted as a `break` statement, so we will register - // the appropriate label on a label stack: - // - LabelStack switchBreakLabelStack; - switchBreakLabelStack.parent = labels; - switchBreakLabelStack.op = LabelStack::Op::Break; - switchBreakLabelStack.block = breakLabel; - switchBreakLabelStack.region = switchRegion; - - // We need to track whether we've dealt with - // the `default` case already. - // - bool defaultLabelHandled = false; - - // If the `default` case just branches to - // the join point, then we don't need to - // do anything with it. - // - if(defaultLabel == breakLabel) - defaultLabelHandled = true; - - // We will now iterate over the different `case`s, and - // try to group them together to minimize the number of - // sub-regions we have to create. - // - UInt caseIndex = 0; - UInt caseCount = switchInst->getCaseCount(); - while(caseIndex < caseCount) - { - // We are going to extract one case here, - // but we might need to fold additional - // cases into it, if they share the - // same label. - // - // Note: this makes assumptions that the - // IR code generator orders cases such - // that: (1) cases with the same label - // are consecutive, and (2) any case - // that "falls through" to another must - // come right before it in the list. - - auto caseVal = switchInst->getCaseValue(caseIndex); - auto caseLabel = switchInst->getCaseLabel(caseIndex); - caseIndex++; - - RefPtr currentCase = new SwitchRegion::Case(); - switchRegion->cases.add(currentCase); - - // Add the case value for this case, and any - // others that share the same label - // - for(;;) - { - currentCase->values.add(caseVal); - - // Are there any more `case`s left? - // - if(caseIndex >= caseCount) - break; - - // Does the next `case` share the same target label? - auto nextCaseLabel = switchInst->getCaseLabel(caseIndex); - if(nextCaseLabel != caseLabel) - break; - - // If those checks passed, then we will fold - // the next `case` into the same region, and - // keep looking. - caseVal = switchInst->getCaseValue(caseIndex); - caseIndex++; - } - - // The label for the current `case` might also - // be the label used by the `default` case, so - // check for that here. - // - if(caseLabel == defaultLabel) - { - switchRegion->defaultCase = currentCase; - defaultLabelHandled = true; - } - - // Now we need to generate a region for the instructions - // that make up this case. The 99% case will be that it - // will terminate with a `break` (or a `return`, - // `continue`, etc.) and so we can pass in `nullptr` - // for the ending block. - // - IRBlock* caseEndLabel = nullptr; - - // However, there is also the possibility that - // this `case` will fall through to the next, and - // so we need to prepare for that possibility here. - // - // If there *is* a next `case`, then we will set its - // label up as the "end" label when emitting - // the statements inside the block. - if(caseIndex < caseCount) - { - caseEndLabel = switchInst->getCaseLabel(caseIndex); - } - - // Now we can actually generate the region. - // - currentCase->body = generateRegionsForIRBlocks( - ctx, - switchRegion, - caseLabel, - caseEndLabel, - &switchBreakLabelStack); - } - - // If we've gone through all the cases and haven't - // managed to encounter the `default:` label, - // then assume it is a distinct case and handle it here. - if(!defaultLabelHandled) - { - RefPtr defaultCase = new SwitchRegion::Case(); - switchRegion->cases.add(defaultCase); - - // Note: we use `null` instead of `breakLabel` as the end block - // here, to ensure that the `default` region will end with an - // explicit `break` rather than just falling off the end. - - defaultCase->body = generateRegionsForIRBlocks( - ctx, - switchRegion, - defaultLabel, - nullptr, - &switchBreakLabelStack); - - switchRegion->defaultCase = defaultCase; - } - - *resultLink = switchRegion; - resultLink = &switchRegion->nextRegion; - parentRegion = switchRegion; - - // Continue with the block after the `switch` - block = breakLabel; - } - break; - } - - // After we've emitted the first block, we are safe from accidental - // cases where we'd emit an entire loop body as a single `continue`, - // so we can safely switch in whatever labels are intended to be used. - useLabels = labels; - - // If we reach this point, then we've emitted - // one block, and we have a new block where - // control flow continues. - // - // We need to handle a special case here, - // when control flow jumps back to the - // starting block of the range we were - // asked to work with: - if (block == begin) - { - break; - } - } - - // We seem to have reached the rend of the region - // without anything special happening. This means - // we should cap off the current sequence of regions - // and return what we have. - // - *resultLink = nullptr; - return resultRegion; - } - - RefPtr generateRegionTreeForFunc( - IRGlobalValueWithCode* code, - DiagnosticSink* sink) - { - RefPtr regionTree = new RegionTree(); - regionTree->irCode = code; - - ControlFlowRestructuringContext restructuringContext; - restructuringContext.sink = sink; - restructuringContext.regionTree = regionTree; - - regionTree->rootRegion = generateRegionsForIRBlocks( - &restructuringContext, - nullptr, - code->getFirstBlock(), - nullptr, - nullptr); - - return regionTree; - } -} diff --git a/source/slang/ir-restructure.h b/source/slang/ir-restructure.h deleted file mode 100644 index d27f7dbc8..000000000 --- a/source/slang/ir-restructure.h +++ /dev/null @@ -1,261 +0,0 @@ -// ir-restructure.h -#pragma once - -#include "../core/basic.h" - -namespace Slang -{ - class DiagnosticSink; - struct IRBlock; - struct IRGlobalValueWithCode; - struct IRInst; - struct IRLoop; - - /// A structured control-flow region. - /// - /// A `Region` is used to layer structured control flow information - /// over an existing IR control flow graph (CFG). Each `Region` - /// represents a sub-graph of the CFG such that control always - /// enters at the start of the region. - /// - class Region : public RefObject - { - public: - enum class Flavor - { - Simple, - If, - Break, - Continue, - Loop, - Switch, - }; - - Flavor getFlavor() { return flavor; } - - Region* getParent() { return parent; } - - /// Is this region a descendent of `other`? - /// - /// For the purpose of this query, a region - /// is a descendent of itself. - bool isDescendentOf(Region* other); - - /// Is this region a descendent of `block`? - /// - /// This tests is the region is a descendent - /// of any simple region for `block`. - bool isDescendentOf(IRBlock* block); - - protected: - Region(Flavor flavor, Region* parent) - : flavor(flavor) - , parent(parent) - {} - - /// What kind of region is this? - Flavor flavor; - - /// The parent region of this region. - Region* parent; - }; - - /// Base type for regions that have a "next" region. - /// - /// While we think of it as a region to execute - /// after this region, the `nextRegion` is actually - /// a *child* region, in that it can see local - /// values that were defined in this parent region - /// (and any other ancestor regions). - class SeqRegion : public Region - { - protected: - SeqRegion(Flavor flavor, Region* parent) - : Region(flavor, parent) - {} - - public: - /// The (child) region to execute after this one. - RefPtr nextRegion; - }; - - /// A simple region that encapsulates a basic block. - /// - class SimpleRegion : public SeqRegion - { - public: - SimpleRegion(Region* parent, IRBlock* block) - : SeqRegion(Region::Flavor::Simple, parent) - , block(block) - {} - - /// The basic block for this region. - IRBlock* block = nullptr; - - /// The next simple region for the same block - /// - /// A single IR basic block may turn into multiple regions, - /// if the restructuring pass has to duplicate it (this - /// currently happens for the continue clause in a `for` - /// loop if it has multiple `continue` sites. - /// - SimpleRegion* nextSimpleRegionForSameBlock = nullptr; - }; - - /// A conditional region, corresponding to an `if` - /// - class IfRegion : public SeqRegion - { - public: - IfRegion(Region* parent, IRInst* condition) - : SeqRegion(Region::Flavor::If, parent) - , condition(condition) - {} - - /// The IR value that controls the conditional branch - IRInst* condition; - - /// The region to execute if the `condition` is `true` - RefPtr thenRegion; - - /// The region to execute if the `condition` is `false` - RefPtr elseRegion; - }; - - /// Base type for regions that execution can `break` out of - class BreakableRegion : public SeqRegion - { - protected: - BreakableRegion(Flavor flavor, Region* parent) - : SeqRegion(flavor, parent) - {} - }; - - /// A region that expresses a `break` out of nested control flow. - /// - class BreakRegion : public Region - { - public: - BreakRegion(Region* parent, BreakableRegion* outerRegion) - : Region(Region::Flavor::Break, parent) - , outerRegion(outerRegion) - {} - - BreakableRegion* outerRegion; - }; - - /// A structured loop - class LoopRegion : public BreakableRegion - { - public: - LoopRegion(Region* parent, IRLoop* loopInst) - : BreakableRegion(Region::Flavor::Loop, parent) - , loopInst(loopInst) - {} - - /// The IR instruction that represents the branch into the loop. - /// We keep this instruction around because it may have decorations - /// that need to influence how we emit this loop. - /// - IRLoop* loopInst; - - /// The code inside the loop. - /// - /// The body region may include `break` or `continue` operations for this loop. - RefPtr body; - }; - - /// A region that expresses a `continue` for a structured loop. - /// - class ContinueRegion : public Region - { - public: - ContinueRegion(Region* parent, LoopRegion* outerRegion) - : Region(Region::Flavor::Continue, parent) - , outerRegion(outerRegion) - {} - - LoopRegion* outerRegion; - }; - - /// A structured `switch` statement. - class SwitchRegion : public BreakableRegion - { - public: - SwitchRegion(Region* parent, IRInst* condition) - : BreakableRegion(Region::Flavor::Switch, parent) - , condition(condition) - {} - - /// The IR value that controls the conditional branch - IRInst* condition; - - /// A collection of `case`s that share the same code. - class Case : public RefObject - { - public: - /// The various values that should branch to this case. - /// - /// It is possible for this list to be empty if this - /// is the `default` case and has no explicit values - /// that map to it. - /// - List values; - - /// The region to execute if this case is selected. - RefPtr body; - }; - - /// All of the cases for the `switch`. - /// - /// This includes any `default` cases. - /// - /// As an invariant, a case that "falls through" to another - /// should immediately precede its target in this list. - /// - List> cases; - - /// The default case, if any. - /// - /// It is valid for this to be `null` if there is no `default` case, - /// in which case the default behavior should be to branch to the region - /// after the `switch`. - /// - /// The default case must also be present in `cases`. - Case* defaultCase; - }; - - /// Container for all of the regions in a function. - /// - /// A `RegionTree` owns the `Region` objects associated with a function, - /// along with a mapping from basic blocks in the IR function to regions - /// in the tree. - /// - class RegionTree : public RefObject - { - public: - /// Type for the mapping from IR blocks to regions. - typedef Dictionary MapBlockToRegion; - - /// A dictionary to map from IR blocks to regions. - MapBlockToRegion mapBlockToRegion; - - /// The root region of the region tree. - RefPtr rootRegion; - - /// The IR function that was used to compute the region tree. - IRGlobalValueWithCode* irCode = nullptr; - }; - - /// Construct structrured regions to represent the control flow in an IR function. - /// - /// The resulting `RegionTree` will encode a structured (statement-like) - /// form for the control flow graph (CFG) of `code`. - /// In cases where our current restructuring approach is now powerful - /// enough to handle something in the input CFG, diagnostic messages - /// will be output to the given `sink`. - /// - RefPtr generateRegionTreeForFunc( - IRGlobalValueWithCode* code, - DiagnosticSink* sink); -} diff --git a/source/slang/ir-sccp.cpp b/source/slang/ir-sccp.cpp deleted file mode 100644 index 242ef0a37..000000000 --- a/source/slang/ir-sccp.cpp +++ /dev/null @@ -1,950 +0,0 @@ -// ir-sccp.cpp -#include "ir-sccp.h" - -#include "ir.h" -#include "ir-insts.h" - -namespace Slang { - - -// This file implements the Spare Conditional Constant Propagation (SCCP) optimization. -// -// We will apply the optimization over individual functions, so we will start with -// a context struct for the state that we will share across functions: -// -struct SharedSCCPContext -{ - IRModule* module; - SharedIRBuilder sharedBuilder; -}; -// -// Next we have a context struct that will be applied for each function (or other -// code-bearing value) that we optimize: -// -struct SCCPContext -{ - SharedSCCPContext* shared; // shared state across functions - IRGlobalValueWithCode* code; // the function/code we are optimizing - - // The SCCP algorithm applies abstract interpretation to the code of the - // function using a "lattice" of values. We can think of a node on the - // lattice as representing a set of values that a given instruction - // might take on. - // - struct LatticeVal - { - // We will use three "flavors" of values on our lattice. - // - enum class Flavor - { - // The `None` flavor represent an empty set of values, meaning - // that we've never seen any indication that the instruction - // produces a (well-defined) value. This could indicate an - // instruction that does not appear to execute, but it could - // also indicate an instruction that we know invokes undefined - // behavior, so we can freely pick a value for it on a whim. - None, - - // The `Constant` flavor represents an instuction that we - // have only ever seen produce a single, fixed value. It's - // `value` field will hold that constant value. - Constant, - - // The `Any` flavor represents an instruction that might produce - // different values at runtime, so we go ahead and approximate - // this as it potentially yielding any value whatsoever. A - // more precise analysis could use sets or intervals of values, - // but for SCCP anything that could take on more than 1 value - // at runtime is assumed to be able to take on *any* value. - Any, - }; - - // The flavor of this value (`None`, `Constant`, or `Any`) - Flavor flavor; - - // If this is a `Constant` lattice value, then this field - // points to the IR instruction that defines the actual constant value. - // For all other flavors it should be null. - IRInst* value = nullptr; - - // For convenience, we define `static` factory functions to - // produce values of each of the flavors. - - static LatticeVal getNone() - { - LatticeVal result; - result.flavor = Flavor::None; - return result; - } - - static LatticeVal getAny() - { - LatticeVal result; - result.flavor = Flavor::Any; - return result; - } - - static LatticeVal getConstant(IRInst* value) - { - LatticeVal result; - result.flavor = Flavor::Constant; - result.value = value; - return result; - } - - // We also need to be able to test if two lattice - // values are equal, so that we can avoid updating - // downstream dependencies if our knowledge about - // an instruction hasn't actually changed. - // - bool operator==(LatticeVal const& that) - { - return this->flavor == that.flavor - && this->value == that.value; - } - - bool operator!=(LatticeVal const& that) - { - return !( *this == that ); - } - }; - - // If we imagine a variable (actually an SSA phi node...) that - // might be assigned lattice value A at one point in the code, - // and lattice value B at another point, we need a way to - // combine these to form our knowledge of the possible value(s) - // for the variable. - // - // In terms of computation on a lattice, we want the "meet" - // operation, which computes the lower bound on what we know. - // If we interpret our lattice values as sets, then we are - // trying to compute the union. - // - LatticeVal meet(LatticeVal const& left, LatticeVal const& right) - { - // If either value is `None` (the empty set), then the union - // will be the other value. - // - if(left.flavor == LatticeVal::Flavor::None) return right; - if(right.flavor == LatticeVal::Flavor::None) return left; - - // If either value is `Any` (the universal set), then - // the union is also the universal set. - // - if(left.flavor == LatticeVal::Flavor::Any) return LatticeVal::getAny(); - if(right.flavor == LatticeVal::Flavor::Any) return LatticeVal::getAny(); - - // At this point we've ruled out the case where either value - // is `None` *or* `Any`, so we can assume both values are - // `Constant`s. - SLANG_ASSERT(left.flavor == LatticeVal::Flavor::Constant); - // - SLANG_ASSERT(right.flavor == LatticeVal::Flavor::Constant); - - // If the two lattice values represent the *same* constant value - // (they are the same singleton set) then the union is that - // singleton set as well. - // - // TODO: This comparison assumes that constants with - // the same value with be represented with the - // same instruction, which is not *always* - // guaranteed in the IR today. - // - if(left.value == right.value) - return left; - - // Otherwise, we have two distinct singleton sets, and their - // union should be a set with two elements. We can't represent - // that on the lattice for SCCP, so the proper lower bound - // is the universal set (`Any`) - // - return LatticeVal::getAny(); - } - - // During the execution of the SCCP algorithm, we will track our best - // "estimate" so far of the set of values each instruction could take - // on. This amounts to a mapping from IR instructions to lattice values, - // where any instruction not present in the map is assumed to default - // to the `None` case (the empty set) - // - Dictionary mapInstToLatticeVal; - - // Updating the lattice value for an instruction is easy, but we'll - // use a simple function to make our intention clear. - // - void setLatticeVal(IRInst* inst, LatticeVal const& val) - { - mapInstToLatticeVal[inst] = val; - } - - // Querying the lattice value for an instruction isn't *just* a matter - // of looking it up in the dictionary, because we need to account for - // cases of lattice values that might come from outside the current - // function. - // - LatticeVal getLatticeVal(IRInst* inst) - { - // Instructions that represent constant values should always - // have a lattice value that reflects this. - // - switch( inst->op ) - { - case kIROp_IntLit: - case kIROp_FloatLit: - case kIROp_StringLit: - case kIROp_BoolLit: - return LatticeVal::getConstant(inst); - break; - - // TODO: We might want to start having support for constant - // values of aggregate types (e.g., a `makeArray` or `makeStruct` - // where all the operands are constant is itself a constant). - - default: - break; - } - - // We might be asked for the lattice value of an instruction - // not contained in the current function. When that happens, - // we will treat it as having potentially any value, rather - // than the default of none. - // - auto parentBlock = as(inst->getParent()); - if(!parentBlock || parentBlock->getParent() != code) return LatticeVal::getAny(); - - // Once the special cases are dealt with, we can look up in - // the dictionary and just return the value we get from it, - // or default to the `None` (empty set) case. - LatticeVal latticeVal; - if(mapInstToLatticeVal.TryGetValue(inst, latticeVal)) - return latticeVal; - return LatticeVal::getNone(); - } - - // Along the way we might need to create new IR instructions - // to represnet new constant values we find, or new control - // flow instructiosn when we start simplifying things. - // - IRBuilder builderStorage; - IRBuilder* getBuilder() { return &builderStorage; } - - // In order to perform constant folding, we need to be able to - // interpret an instruction over the lattice values. - // - LatticeVal interpretOverLattice(IRInst* inst) - { - SLANG_UNUSED(inst); - - // Certain instruction always produce constants, and we - // want to special-case them here. - switch( inst->op ) - { - case kIROp_IntLit: - case kIROp_FloatLit: - case kIROp_StringLit: - case kIROp_BoolLit: - return LatticeVal::getConstant(inst); - - // TODO: we might also want to special-case certain - // instructions where we shouldn't bother trying to - // constant-fold them and should just default to the - // `Any` value right away. - - default: - break; - } - - // TODO: We should now look up the lattice values for - // the operands of the instruction. - // - // If all of the operands have `Constant` lattice values, - // then we can potential execute the operation directly - // on those constant values, create a fresh `IRConstant`, - // and return a `Constant` lattice value for it. This - // would allow us to achieve true constant folding here. - // - // Textbook discussions of SCCP often point out that it - // is also possible to perform certain algebraic simplifications - // here, such as evaluating a multiply by a `Constant` zero - // to zero. - // - // As a default, if any operand has the `Any` value - // then the result of the operation should be treated as - // `Any`. There are exceptions to this, however, with the - // multiply-by-zero example being an important example. - // If we had previously decided that (Any * None) -> Any - // but then we refine our estimates and have (Any * Constant(0)) -> Constant(0) - // then we have violated the monotonicity rules for how - // our values move through the lattice, and we may break - // the convergence guarantees of the analysis. - // - // When we have a mix of `None` and `Constant` operands, - // then the `None` values imply that our operation is using - // uninitialized data or the results of undefined behavior. - // We could try to propagate the `None` through, and allow - // the compiler to speculatively assume that the operation - // produces whatever value we find convenient. Alternatively, - // we can be less aggressive and treat an operation with - // `None` inputs as producing `Any` to make sure we don't - // optimize the code based on non-obvious assumptions. - // - // For now we aren't implementing *any* folding logic here, - // for simplicity. This is the right place to add folding - // optimizations if/when we need them. - // - - // A safe default is to assume that every instruction not - // handled by one of the cases above could produce *any* - // value whatsoever. - return LatticeVal::getAny(); - } - - - // For basic blocks, we will do tracking very similar to what we do for - // ordinary instructions, just with a simpler lattice: every block - // will either be marked as "never executed" or in a "possibly executed" - // state. We track this as a set of the blocks that have been - // marked as possibly executed, plus a getter and setter function. - - HashSet executedBlocks; - - bool isMarkedAsExecuted(IRBlock* block) - { - return executedBlocks.Contains(block); - } - - void markAsExecuted(IRBlock* block) - { - executedBlocks.Add(block); - } - - // The core of the algorithm is based on two work lists. - // One list holds CFG nodes (basic blocks) that we have - // discovered might execute, and thus need to be processed, - // and the other holds SSA nodes (instructions) that need - // their "estimated" value to be updated. - - List cfgWorkList; - List ssaWorkList; - - // A key operation is to take an IR instruction and update - // its "estimated" value on the lattice. This might happen when - // we first discover the instruction could be executed, or - // when we discover that one or more of its operands has - // changed its lattice value so that we need to update our estimate. - // - void updateValueForInst(IRInst* inst) - { - // Block parameters are conceptually SSA "phi nodes", and it - // doesn't make sense to update their values here, because the - // actual candidate values for them comes from the predecessor blocks - // that provide arguments. We will see that logic shortly, when - // handling `IRUnconditionalBranch`. - // - if(as(inst)) - return; - - // We want to special-case terminator instructions here, - // since abstract interpretation of them should cause blocks to - // be marked as executed, etc. - // - if( auto terminator = as(inst) ) - { - if( auto unconditionalBranch = as(inst) ) - { - // When our abstract interpreter "executes" an unconditional - // branch, it needs to mark the target block as potentially - // executed. We do this by adding the target to our CFG work list. - // - auto target = unconditionalBranch->getTargetBlock(); - cfgWorkList.add(target); - - // Besides transferring control to another block, the other - // thing our unconditional branch instructions do is provide - // the arguments for phi nodes in the target block. - // We thus need to interpret each argument on the branch - // instruction like an "assignment" to the corresponding - // parameter of the target block. - // - UInt argCount = unconditionalBranch->getArgCount(); - IRParam* pp = target->getFirstParam(); - for( UInt aa = 0; aa < argCount; ++aa, pp = pp->getNextParam() ) - { - IRInst* arg = unconditionalBranch->getArg(aa); - IRInst* param = pp; - - // We expect the number of arguments and parameters to match, - // or else the IR is violating its own invariants. - // - SLANG_ASSERT(param); - - // We will update the value for the target block's parameter - // using our "meet" operation (union of sets of possible values) - // - LatticeVal oldVal = getLatticeVal(param); - - // If we've already determined that the block parameter could - // have any value whatsoever, there is no reason to bother - // updating it. - // - if(oldVal.flavor == LatticeVal::Flavor::Any) - continue; - - // We can look up the lattice value for the argument, - // because we should have interpreted it already - // - LatticeVal argVal = getLatticeVal(arg); - - // Now we apply the meet operation and see if the value changed. - // - LatticeVal newVal = meet(oldVal, argVal); - if( newVal != oldVal ) - { - // If the "estimated" value for the parameter has changed, - // then we need to update it in our dictionary, and then - // make sure that all of the users of the parameter get - // their estimates updated as well. - // - setLatticeVal(param, newVal); - for( auto use = param->firstUse; use; use = use->nextUse ) - { - ssaWorkList.add(use->getUser()); - } - } - } - } - else if( auto conditionalBranch = as(inst) ) - { - // An `IRConditionalBranch` is used for two-way branches. - // We will look at the lattice value for the condition, - // to see if we can narrow down which of the two ways - // might actually be taken. - // - auto condVal = getLatticeVal(conditionalBranch->getCondition()); - - // We do not expect to see a `None` value here, because that - // would mean the user is branching based on an undefined - // value. - // - // TODO: We should make sure there is no way for the user - // to trigger this assert with bad code that involves - // uninitialized variables. Right now we don't special - // case the `undefined` instruction when computing lattice - // values, so it shouldn't be a problem. - // - SLANG_ASSERT(condVal.flavor != LatticeVal::Flavor::None); - - // If the branch condition is a constant, we expect it to - // be a Boolean constant. We won't assert that it is the - // case here, just to be defensive. - // - if( condVal.flavor == LatticeVal::Flavor::Constant ) - { - if( auto boolConst = as(condVal.value) ) - { - // Only one of the two targe blocks is possible to - // execute, based on what we know of the condition, - // so we will add that target to our work list and - // bail out now. - // - auto target = boolConst->getValue() ? conditionalBranch->getTrueBlock() : conditionalBranch->getFalseBlock(); - cfgWorkList.add(target); - return; - } - } - - // As a fallback, if the condition isn't constant - // (or somehow wasn't a Boolean constnat), we will - // assume that either side of the branch could be - // taken, so that both of the target blocks are - // potentially executed. - // - cfgWorkList.add(conditionalBranch->getTrueBlock()); - cfgWorkList.add(conditionalBranch->getFalseBlock()); - } - else if( auto switchInst = as(inst) ) - { - // The handling of a `switch` instruction is similar to the - // case for a two-way branch, with the main difference that - // we have to deal with an integer condition value. - - auto condVal = getLatticeVal(switchInst->getCondition()); - SLANG_ASSERT(condVal.flavor != LatticeVal::Flavor::None); - - UInt caseCount = switchInst->getCaseCount(); - if( condVal.flavor == LatticeVal::Flavor::Constant ) - { - if( auto condConst = as(condVal.value) ) - { - // At this point we have a constant integer condition - // value, and we just need to find the case (if any) - // that matches it. We will default to considering - // the `default` label as the target. - // - auto target = switchInst->getDefaultLabel(); - for( UInt cc = 0; cc < caseCount; ++cc ) - { - if( auto caseConst = as(switchInst->getCaseValue(cc)) ) - { - if(caseConst->getValue() == condConst->getValue()) - { - target = switchInst->getCaseLabel(cc); - break; - } - } - } - - // Whatever single block we decided will get executed, - // we need to make sure it gets processed and then bail. - // - cfgWorkList.add(target); - return; - } - } - - // The fallback is to assume that the `switch` instruction might - // branch to any of its cases, or the `default` label. - // - for( UInt cc = 0; cc < caseCount; ++cc ) - { - cfgWorkList.add(switchInst->getCaseLabel(cc)); - } - cfgWorkList.add(switchInst->getDefaultLabel()); - } - - // There are other cases of terminator instructions not handled - // above (e.g., `return` instructions), but these can't cause - // additional basic blocks in the CFG to execute, so we don't - // need to consider them here. - // - // No matter what, we are done with a terminator instruction - // after inspecting it, and there is no reason we have to - // try and compute its "value." - return; - } - - // For an "ordinary" instruction, we will first check what value - // has been registered for it already. - // - LatticeVal oldVal = getLatticeVal(inst); - - // If we have previous decided that the instruction could take - // on any value whatsoever, then any further update to our - // guess can't expand things more, and so there is nothing to do. - // - if( oldVal.flavor == LatticeVal::Flavor::Any ) - { - return; - } - - // Otherwise, we compute a new guess at the value of - // the instruction based on the lattice values of the - // stuff it depends on. - // - LatticeVal newVal = interpretOverLattice(inst); - - // If nothing changed about our guess, then there is nothing - // further to do, because users of this instruction have - // already computed their guess based on its current value. - // - if(newVal == oldVal) - { - return; - } - - // If the guess did change, then we want to register our - // new guess as the lattice value for this instruction. - // - setLatticeVal(inst, newVal); - - // Next we iterate over all the users of this instruction - // and add them to our work list so that we can update - // their values based on the new information. - // - for( auto use = inst->firstUse; use; use = use->nextUse ) - { - ssaWorkList.add(use->getUser()); - } - } - - // The `apply()` function will run the full algorithm. - // - void apply() - { - // We start with the busy-work of setting up our IR builder. - // - builderStorage.sharedBuilder = &shared->sharedBuilder; - - // We expect the caller to have filtered out functions with - // no bodies, so there should always be at least one basic block. - // - auto firstBlock = code->getFirstBlock(); - SLANG_ASSERT(firstBlock); - - // The entry block is always going to be executed when the - // function gets called, so we will process it right away. - // - cfgWorkList.add(firstBlock); - - // The parameters of the first block are our function parameters, - // and we want to operate on the assumption that they could have - // any value possible, so we will record that in our dictionary. - // - for( auto pp : firstBlock->getParams() ) - { - setLatticeVal(pp, LatticeVal::getAny()); - } - - // Now we will iterate until both of our work lists go dry. - // - while(cfgWorkList.getCount() || ssaWorkList.getCount()) - { - // Note: there is a design choice to be had here - // around whether we do `if if` or `while while` - // for these nested checks. The choice can affect - // how long things take to converge. - - // We will start by processing any blocks that we - // have determined are potentially reachable. - // - while( cfgWorkList.getCount() ) - { - // We pop one block off of the work list. - // - auto block = cfgWorkList[0]; - cfgWorkList.fastRemoveAt(0); - - // We only want to process blocks that haven't - // already been marked as executed, so that we - // don't do redundant work. - // - if( !isMarkedAsExecuted(block) ) - { - // We should mark this new block as executed, - // so we can ignore it if it ever ends up on - // the work list again. - // - markAsExecuted(block); - - // If the block is potentially executed, then - // that means the instructions in the block are too. - // We will walk through the block and update our - // guess at the value of each instruction, which - // may in turn add other blocks/instructions to - // the work lists. - // - for( auto inst : block->getDecorationsAndChildren() ) - { - updateValueForInst(inst); - } - } - } - - // Once we've cleared the work list of blocks, we - // will start looking at individual instructions that - // need to be updated. - // - while( ssaWorkList.getCount() ) - { - // We pop one instruction that needs an update. - // - auto inst = ssaWorkList[0]; - ssaWorkList.fastRemoveAt(0); - - // Before updating the instruction, we will check if - // the parent block of the instructin is marked as - // being executed. If it isn't, there is no reason - // to update the value for the instruction, since - // it might never be used anyway. - // - IRBlock* block = as(inst->getParent()); - - // It is possible that an instruction ended up on - // our SSA work list because it is a user of an - // instruction in a block of `code`, but it is not - // itself an instruction a block of `code`. - // - // For example, if `code` is an `IRGeneric` that - // yields a function, then `inst` might be an - // instruction of that nested function, and not - // an instruction of the generic itself. - // Note that in such a case, the `inst` cannot - // possible affect the values computed in the outer - // generic, or the control-flow paths it might take, - // so there is no reason to consider it. - // - // We guard against this case by only processing `inst` - // if it is a child of a block in the current `code`. - // - if(!block || block->getParent() != code) - continue; - - if( isMarkedAsExecuted(block) ) - { - // If the instruction is potentially executed, we update - // its lattice value based on our abstraction interpretation. - // - updateValueForInst(inst); - } - } - } - - // Once the work lists are empty, our "guesses" at the value - // of different instructions and the potentially-executed-ness - // of blocks should have converged to a conservative steady state. - // - // We are now equiped to start using the information we've gathered - // to modify the code. - - // First, we will walk through all the code and replace instructions - // with constants where it is possible. - // - List instsToRemove; - for( auto block : code->getBlocks() ) - { - for( auto inst : block->getDecorationsAndChildren() ) - { - // We look for instructions that have a constnat value on - // the lattice. - // - LatticeVal latticeVal = getLatticeVal(inst); - if(latticeVal.flavor != LatticeVal::Flavor::Constant) - continue; - - // As a small sanity check, we won't go replacing an - // instruction with itself (this shouldn't really come - // up, since constants are supposed to be at the global - // scope right now) - // - IRInst* constantVal = latticeVal.value; - if(constantVal == inst) - continue; - - // We replace any uses of the instruction with its - // constant expected value, and add it to a list of - // instructions to be removed *iff* the instruction - // is known to have no obersvable side effects. - // - inst->replaceUsesWith(constantVal); - if( !inst->mightHaveSideEffects() ) - { - instsToRemove.add(inst); - } - } - } - - // Once we've replaced the uses of instructions that evaluate - // to constants, we make a second pass to remove the instructions - // themselves (or at least those without side effects). - // - for( auto inst : instsToRemove ) - { - inst->removeAndDeallocate(); - } - - // Next we are going to walk through all of the terminator - // instructions on blocks and look for ones that branch - // based on a constant condition. These will be rewritten - // to use direct branching instructions, which will of course - // need to be emitted using a builder. - // - auto builder = getBuilder(); - for( auto block : code->getBlocks() ) - { - auto terminator = block->getTerminator(); - - // We check if we have a `switch` instruction with a constant - // integer as its condition. - // - if( auto switchInst = as(terminator) ) - { - if( auto constVal = as(switchInst->getCondition()) ) - { - // We will select the one branch that gets taken, based - // on the constant condition value. The `default` label - // will of course be taken if no `case` label matches. - // - IRBlock* target = switchInst->getDefaultLabel(); - UInt caseCount = switchInst->getCaseCount(); - for(UInt cc = 0; cc < caseCount; ++cc) - { - auto caseVal = switchInst->getCaseValue(cc); - if(auto caseConst = as(caseVal)) - { - if( caseConst->getValue() == constVal->getValue() ) - { - target = switchInst->getCaseLabel(cc); - break; - } - } - } - - // Once we've found the target, we will emit a direct - // branch to it before the old terminator, and then remove - // the old terminator instruction. - // - builder->setInsertBefore(terminator); - builder->emitBranch(target); - terminator->removeAndDeallocate(); - } - } - else if(auto condBranchInst = as(terminator)) - { - if( auto constVal = as(condBranchInst->getCondition()) ) - { - // The case for a two-sided conditional branch is similar - // to the `switch` case, but simpler. - - IRBlock* target = constVal->getValue() ? condBranchInst->getTrueBlock() : condBranchInst->getFalseBlock(); - - builder->setInsertBefore(terminator); - builder->emitBranch(target); - terminator->removeAndDeallocate(); - } - - } - } - - // At this point we've replaced some conditional branches - // that would always go the same way (e.g., a `while(true)`), - // which should render some of our blocks unreachable. - // We will collect all those unreachable blocks into a list - // of blocks to be removed, and then go about trying to - // remove them. - // - List unreachableBlocks; - for( auto block : code->getBlocks() ) - { - if( !isMarkedAsExecuted(block) ) - { - unreachableBlocks.add(block); - } - } - // - // It might seem like we could just do: - // - // block->removeAndDeallocate(); - // - // for each of the blocks in `unreachableBlocks`, but there - // is a subtle point that has to be considered: - // - // We have a structured control-flow representation where - // certain branching instructions name "join points" where - // control flow logically re-converges. It is possible that - // one of our unreachable blocks is still being used as - // a join point. - // - // For example: - // - // if(A) - // return B; - // else - // return C; - // D; - // - // In the above example, the block that computes `D` is - // unreachable, but it is still the join point for the `if(A)` - // branch. - // - // Rather than complicate the encoding of join points to - // try to special-case an unreachable join point, we will - // instead retain the join point as a block with only a single - // `unreachable` instruction. - // - // To detect which blocks are unreachable and unreferenced, - // we will check which blocks have any uses. Of course, it - // might be that some of our unreachable blocks still reference - // one another (e.g., an unreachable loop) so we will start - // by removing the instructions from the bodies of our unreachable - // blocks to eliminate any cross-references between them. - // - for( auto block : unreachableBlocks ) - { - // TODO: In principle we could produce a diagnostic here - // if any of these unreachable blocks appears to have - // "non-trivial" code in it (that is, any code explicitly - // written by the user, and not just code synthesized by - // the compiler to satisfy language rules). Making that - // determination could be tricky, so for now we will - // err on the side of allowing unreachable code without - // a warning. - // - block->removeAndDeallocateAllDecorationsAndChildren(); - } - // - // At this point every one of our unreachable blocks is empty, - // and there should be no branches from reachable blocks - // to unreachable ones. - // - // We will iterate over our unreachable blocks, and process - // them differently based on whether they have any remaining uses. - // - for( auto block : unreachableBlocks ) - { - // At this point there had better be no edges branching to - // our block. We determined it was unreachable, so there had - // better not be branches from reachable blocks to this one, - // and all the unreachable blocks had their instructions - // removed, so there should be no branches to it from other - // unreachable blocks (or itself). - // - SLANG_ASSERT(block->getPredecessors().isEmpty()); - - // If the block is completely unreferenced, we can safely - // remove and deallocate it now. - // - if( !block->hasUses() ) - { - block->removeAndDeallocate(); - } - else - { - // Otherwise, the block has at least one use (but - // no predecessors), which should indicate that it - // is an unreachable join point. - // - // We will keep the block around, but its entire - // body will consist of a single `unreachable` - // instruction. - // - builder->setInsertInto(block); - builder->emitUnreachable(); - } - } - } -}; - -static void applySparseConditionalConstantPropagationRec( - SharedSCCPContext* shared, - IRInst* inst) -{ - if( auto code = as(inst) ) - { - if( code->getFirstBlock() ) - { - SCCPContext context; - context.shared = shared; - context.code = code; - context.apply(); - } - } - - for( auto childInst : inst->getDecorationsAndChildren() ) - { - applySparseConditionalConstantPropagationRec(shared, childInst); - } -} - -void applySparseConditionalConstantPropagation( - IRModule* module) -{ - SharedSCCPContext shared; - shared.module = module; - shared.sharedBuilder.module = module; - shared.sharedBuilder.session = module->getSession(); - - applySparseConditionalConstantPropagationRec(&shared, module->getModuleInst()); -} - -} - diff --git a/source/slang/ir-sccp.h b/source/slang/ir-sccp.h deleted file mode 100644 index cd075761a..000000000 --- a/source/slang/ir-sccp.h +++ /dev/null @@ -1,18 +0,0 @@ -// ir-sccp.h -#pragma once - -namespace Slang -{ - struct IRModule; - - /// Apply Sparse Conditional Constant Propagation (SCCP) to a module. - /// - /// This optimization replaces instructions that can only ever evaluate - /// to a single (well-defined) value with that constant value, and - /// also eliminates conditional branches where the condition will - /// always evaluate to a constant (which can lead to entire blocks - /// becoming dead code) - void applySparseConditionalConstantPropagation( - IRModule* module); -} - diff --git a/source/slang/ir-serialize.cpp b/source/slang/ir-serialize.cpp deleted file mode 100644 index 33238a9a6..000000000 --- a/source/slang/ir-serialize.cpp +++ /dev/null @@ -1,2125 +0,0 @@ -// ir-serialize.cpp -#include "ir-serialize.h" - -#include "../core/text-io.h" -#include "../core/slang-byte-encode-util.h" - -#include "ir-insts.h" - -#include "../core/slang-math.h" - -namespace Slang { - -// Needed for linkage with some compilers -/* static */ const IRSerialData::StringIndex IRSerialData::kNullStringIndex; -/* static */ const IRSerialData::StringIndex IRSerialData::kEmptyStringIndex; - -/* Note that an IRInst can be derived from, but when it derived from it's new members are IRUse variables, and they in -effect alias over the operands - and reflected in the operand count. There _could_ be other members after these IRUse -variables, but only a few types include extra data, and these do not have any operands: - -* IRConstant - Needs special-case handling -* IRModuleInst - Presumably we can just set to the module pointer on reconstruction - -Note! That on an IRInst there is an IRType* variable (accessed as getFullType()). As it stands it may NOT actually point -to an IRType derived type. Its 'ok' as long as it's an instruction that can be used in the place of the type. So this code does not -bother to check if it's correct, and just casts it. -*/ - -/* static */const IRSerialData::PayloadInfo IRSerialData::s_payloadInfos[int(Inst::PayloadType::CountOf)] = -{ - { 0, 0 }, // Empty - { 1, 0 }, // Operand_1 - { 2, 0 }, // Operand_2 - { 1, 0 }, // OperandAndUInt32, - { 0, 0 }, // OperandExternal - This isn't correct, Operand has to be specially handled - { 0, 1 }, // String_1, - { 0, 2 }, // String_2, - { 0, 0 }, // UInt32, - { 0, 0 }, // Float64, - { 0, 0 } // Int64, -}; - -static bool isTextureTypeBase(IROp opIn) -{ - const int op = (kIROpMeta_PseudoOpMask & opIn); - return op >= kIROp_FirstTextureTypeBase && op <= kIROp_LastTextureTypeBase; -} - -static bool isConstant(IROp opIn) -{ - const int op = (kIROpMeta_PseudoOpMask & opIn); - return op >= kIROp_FirstConstant && op <= kIROp_LastConstant; -} - -struct PrefixString; - -namespace { // anonymous - -struct CharReader -{ - char operator()(int pos) const { SLANG_UNUSED(pos); return *m_pos++; } - CharReader(const char* pos) :m_pos(pos) {} - mutable const char* m_pos; -}; - -} // anonymous - -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! StringRepresentationCache !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - -StringRepresentationCache::StringRepresentationCache(): - m_stringTable(nullptr), - m_namePool(nullptr), - m_scopeManager(nullptr) -{ -} - -void StringRepresentationCache::init(const List* stringTable, NamePool* namePool, ObjectScopeManager* scopeManager) -{ - m_stringTable = stringTable; - m_namePool = namePool; - m_scopeManager = scopeManager; - - // Decode the table - m_entries.setCount(StringSlicePool::kNumDefaultHandles); - SLANG_COMPILE_TIME_ASSERT(StringSlicePool::kNumDefaultHandles == 2); - - { - Entry& entry = m_entries[0]; - entry.m_numChars = 0; - entry.m_startIndex = 0; - entry.m_object = nullptr; - } - { - Entry& entry = m_entries[1]; - entry.m_numChars = 0; - entry.m_startIndex = 0; - entry.m_object = nullptr; - } - - { - const char* start = stringTable->begin(); - const char* cur = start; - const char* end = stringTable->end(); - - while (cur < end) - { - CharReader reader(cur); - const int len = GetUnicodePointFromUTF8(reader); - - Entry entry; - entry.m_startIndex = uint32_t(reader.m_pos - start); - entry.m_numChars = len; - entry.m_object = nullptr; - - m_entries.add(entry); - - cur = reader.m_pos + len; - } - } - - m_entries.compress(); -} - -Name* StringRepresentationCache::getName(Handle handle) -{ - if (handle == StringSlicePool::kNullHandle) - { - return nullptr; - } - - Entry& entry = m_entries[int(handle)]; - if (entry.m_object) - { - Name* name = dynamicCast(entry.m_object); - if (name) - { - return name; - } - StringRepresentation* stringRep = static_cast(entry.m_object); - // Promote it to a name - name = m_namePool->getName(String(stringRep)); - entry.m_object = name; - return name; - } - - Name* name = m_namePool->getName(String(getStringSlice(handle))); - entry.m_object = name; - return name; -} - -String StringRepresentationCache::getString(Handle handle) -{ - return String(getStringRepresentation(handle)); -} - -UnownedStringSlice StringRepresentationCache::getStringSlice(Handle handle) const -{ - const Entry& entry = m_entries[int(handle)]; - const char* start = m_stringTable->begin(); - - return UnownedStringSlice(start + entry.m_startIndex, int(entry.m_numChars)); -} - -StringRepresentation* StringRepresentationCache::getStringRepresentation(Handle handle) -{ - if (handle == StringSlicePool::kNullHandle || handle == StringSlicePool::kEmptyHandle) - { - return nullptr; - } - - Entry& entry = m_entries[int(handle)]; - if (entry.m_object) - { - Name* name = dynamicCast(entry.m_object); - if (name) - { - return name->text.getStringRepresentation(); - } - return static_cast(entry.m_object); - } - - const UnownedStringSlice slice = getStringSlice(handle); - const UInt size = slice.size(); - - StringRepresentation* stringRep = StringRepresentation::createWithCapacityAndLength(size, size); - memcpy(stringRep->getData(), slice.begin(), size); - entry.m_object = stringRep; - - // Keep the StringRepresentation in scope - m_scopeManager->add(stringRep); - - return stringRep; -} - -char* StringRepresentationCache::getCStr(Handle handle) -{ - // It turns out StringRepresentation is always 0 terminated, so can just use that - StringRepresentation* rep = getStringRepresentation(handle); - return rep->getData(); -} - -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! SerialStringTableUtil !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - -/* static */void SerialStringTableUtil::encodeStringTable(const StringSlicePool& pool, List& stringTable) -{ - // Skip the default handles -> nothing is encoded via them - return encodeStringTable(pool.getSlices().begin() + StringSlicePool::kNumDefaultHandles, pool.getNumSlices() - StringSlicePool::kNumDefaultHandles, stringTable); -} - -/* static */void SerialStringTableUtil::encodeStringTable(const UnownedStringSlice* slices, size_t numSlices, List& stringTable) -{ - stringTable.clear(); - for (size_t i = 0; i < numSlices; ++i) - { - const UnownedStringSlice slice = slices[i]; - const int len = int(slice.size()); - - // We need to write into the the string array - char prefixBytes[6]; - const int numPrefixBytes = EncodeUnicodePointToUTF8(prefixBytes, len); - const Index baseIndex = stringTable.getCount(); - - stringTable.setCount(baseIndex + numPrefixBytes + len); - - char* dst = stringTable.begin() + baseIndex; - - memcpy(dst, prefixBytes, numPrefixBytes); - memcpy(dst + numPrefixBytes, slice.begin(), len); - } -} - -/* static */void SerialStringTableUtil::appendDecodedStringTable(const List& stringTable, List& slicesOut) -{ - const char* start = stringTable.begin(); - const char* cur = start; - const char* end = stringTable.end(); - - while (cur < end) - { - CharReader reader(cur); - const int len = GetUnicodePointFromUTF8(reader); - slicesOut.add(UnownedStringSlice(reader.m_pos, len)); - cur = reader.m_pos + len; - } -} - -/* static */void SerialStringTableUtil::decodeStringTable(const List& stringTable, List& slicesOut) -{ - slicesOut.setCount(2); - slicesOut[0] = UnownedStringSlice(nullptr, size_t(0)); - slicesOut[1] = UnownedStringSlice("", size_t(0)); - - appendDecodedStringTable(stringTable, slicesOut); -} - -/* static */void SerialStringTableUtil::calcStringSlicePoolMap(const List& slices, StringSlicePool& pool, List& indexMapOut) -{ - SLANG_ASSERT(slices.getCount() >= StringSlicePool::kNumDefaultHandles); - SLANG_ASSERT(slices[int(StringSlicePool::kNullHandle)] == "" && slices[int(StringSlicePool::kNullHandle)].begin() == nullptr); - SLANG_ASSERT(slices[int(StringSlicePool::kEmptyHandle)] == ""); - - indexMapOut.setCount(slices.getCount()); - // Set up all of the defaults - for (int i = 0; i < StringSlicePool::kNumDefaultHandles; ++i) - { - indexMapOut[i] = StringSlicePool::Handle(i); - } - - const Index numSlices = slices.getCount(); - for (Index i = StringSlicePool::kNumDefaultHandles; i < numSlices ; ++i) - { - indexMapOut[i] = pool.add(slices[i]); - } -} - -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! IRSerialData !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - -template -static size_t _calcArraySize(const List& list) -{ - return list.getCount() * sizeof(T); -} - -size_t IRSerialData::calcSizeInBytes() const -{ - return - _calcArraySize(m_insts) + - _calcArraySize(m_childRuns) + - _calcArraySize(m_externalOperands) + - _calcArraySize(m_stringTable) + - /* Raw source locs */ - _calcArraySize(m_rawSourceLocs) + - /* Debug */ - _calcArraySize(m_debugStringTable) + - _calcArraySize(m_debugLineInfos) + - _calcArraySize(m_debugSourceInfos) + - _calcArraySize(m_debugAdjustedLineInfos) + - _calcArraySize(m_debugSourceLocRuns); -} - -IRSerialData::IRSerialData() -{ - clear(); -} - -void IRSerialData::clear() -{ - // First Instruction is null - m_insts.setCount(1); - memset(&m_insts[0], 0, sizeof(Inst)); - - m_childRuns.clear(); - m_externalOperands.clear(); - m_rawSourceLocs.clear(); - - m_stringTable.clear(); - - // Debug data - m_debugLineInfos.clear(); - m_debugAdjustedLineInfos.clear(); - m_debugSourceInfos.clear(); - m_debugSourceLocRuns.clear(); - m_debugStringTable.clear(); -} - -template -static bool _isEqual(const List& aIn, const List& bIn) -{ - if (aIn.getCount() != bIn.getCount()) - { - return false; - } - - size_t size = size_t(aIn.getCount()); - - const T* a = aIn.begin(); - const T* b = bIn.begin(); - - if (a == b) - { - return true; - } - - for (size_t i = 0; i < size; ++i) - { - if (a[i] != b[i]) - { - return false; - } - } - - return true; -} - -bool IRSerialData::operator==(const ThisType& rhs) const -{ - return (this == &rhs) || - (_isEqual(m_insts, rhs.m_insts) && - _isEqual(m_childRuns, rhs.m_childRuns) && - _isEqual(m_externalOperands, rhs.m_externalOperands) && - _isEqual(m_rawSourceLocs, rhs.m_rawSourceLocs) && - _isEqual(m_stringTable, rhs.m_stringTable) && - /* Debug */ - _isEqual(m_debugStringTable, rhs.m_debugStringTable) && - _isEqual(m_debugLineInfos, rhs.m_debugLineInfos) && - _isEqual(m_debugAdjustedLineInfos, rhs.m_debugAdjustedLineInfos) && - _isEqual(m_debugSourceInfos, rhs.m_debugSourceInfos) && - _isEqual(m_debugSourceLocRuns, rhs.m_debugSourceLocRuns)); -} - -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! IRSerialWriter !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - -void IRSerialWriter::_addInstruction(IRInst* inst) -{ - // It cannot already be in the map - SLANG_ASSERT(!m_instMap.ContainsKey(inst)); - - // Add to the map - m_instMap.Add(inst, Ser::InstIndex(m_insts.getCount())); - m_insts.add(inst); -} - -#if 0 -// Find a view index that matches the view by file (and perhaps other characteristics in the future) -static int _findSourceViewIndex(const List& viewsIn, SourceView* view) -{ - const int numViews = int(viewsIn.Count()); - SourceView*const* views = viewsIn.begin(); - - SourceFile* sourceFile = view->getSourceFile(); - - for (int i = 0; i < numViews; ++i) - { - SourceView* curView = views[i]; - // For now we just match on source file - if (curView->getSourceFile() == sourceFile) - { - // It's a hit - return i; - } - } - return -1; -} -#endif - -void IRSerialWriter::_addDebugSourceLocRun(SourceLoc sourceLoc, uint32_t startInstIndex, uint32_t numInsts) -{ - SourceView* sourceView = m_sourceManager->findSourceView(sourceLoc); - if (!sourceView) - { - return; - } - - SourceFile* sourceFile = sourceView->getSourceFile(); - DebugSourceFile* debugSourceFile; - { - RefPtr* ptrDebugSourceFile = m_debugSourceFileMap.TryGetValue(sourceFile); - if (ptrDebugSourceFile == nullptr) - { - const SourceLoc::RawValue baseSourceLoc = m_debugFreeSourceLoc; - m_debugFreeSourceLoc += SourceLoc::RawValue(sourceView->getRange().getSize() + 1); - - debugSourceFile = new DebugSourceFile(sourceFile, baseSourceLoc); - m_debugSourceFileMap.Add(sourceFile, debugSourceFile); - } - else - { - debugSourceFile = *ptrDebugSourceFile; - } - } - - // We need to work out the line index - - int offset = sourceView->getRange().getOffset(sourceLoc); - int lineIndex = sourceFile->calcLineIndexFromOffset(offset); - - IRSerialData::DebugLineInfo lineInfo; - lineInfo.m_lineStartOffset = sourceFile->getLineBreakOffsets()[lineIndex]; - lineInfo.m_lineIndex = lineIndex; - - if (!debugSourceFile->hasLineIndex(lineIndex)) - { - // Add the information about the line - int entryIndex = sourceView->findEntryIndex(sourceLoc); - if (entryIndex < 0) - { - debugSourceFile->m_lineInfos.add(lineInfo); - } - else - { - const auto& entry = sourceView->getEntries()[entryIndex]; - - IRSerialData::DebugAdjustedLineInfo adjustedLineInfo; - adjustedLineInfo.m_lineInfo = lineInfo; - adjustedLineInfo.m_pathStringIndex = Ser::kNullStringIndex; - - if (StringSlicePool::hasContents(entry.m_pathHandle)) - { - UnownedStringSlice slice = sourceView->getSourceManager()->getStringSlicePool().getSlice(entry.m_pathHandle); - SLANG_ASSERT(slice.size() > 0); - adjustedLineInfo.m_pathStringIndex = Ser::StringIndex(m_debugStringSlicePool.add(slice)); - } - - adjustedLineInfo.m_adjustedLineIndex = lineIndex + entry.m_lineAdjust; - - debugSourceFile->m_adjustedLineInfos.add(adjustedLineInfo); - } - - debugSourceFile->setHasLineIndex(lineIndex); - } - - // Add the run - IRSerialData::SourceLocRun sourceLocRun; - sourceLocRun.m_numInst = numInsts; - sourceLocRun.m_startInstIndex = IRSerialData::InstIndex(startInstIndex); - sourceLocRun.m_sourceLoc = uint32_t(debugSourceFile->m_baseSourceLoc + offset); - - m_serialData->m_debugSourceLocRuns.add(sourceLocRun); -} - -Result IRSerialWriter::_calcDebugInfo() -{ - // We need to find the unique source Locs - // We are not going to store SourceLocs directly, because there may be multiple views mapping down to - // the same underlying source file - - // First find all the unique locs - struct InstLoc - { - typedef InstLoc ThisType; - - SLANG_FORCE_INLINE bool operator<(const ThisType& rhs) const { return sourceLoc < rhs.sourceLoc || (sourceLoc == rhs.sourceLoc && instIndex < rhs.instIndex); } - - uint32_t instIndex; - uint32_t sourceLoc; - }; - - // Find all of the source locations and their associated instructions - List instLocs; - const Index numInsts = m_insts.getCount(); - for (Index i = 1; i < numInsts; i++) - { - IRInst* srcInst = m_insts[i]; - if (!srcInst->sourceLoc.isValid()) - { - continue; - } - InstLoc instLoc; - instLoc.instIndex = uint32_t(i); - instLoc.sourceLoc = uint32_t(srcInst->sourceLoc.getRaw()); - instLocs.add(instLoc); - } - - // Sort them - instLocs.sort(); - m_debugFreeSourceLoc = 1; - - // Look for runs - const InstLoc* startInstLoc = instLocs.begin(); - const InstLoc* endInstLoc = instLocs.end(); - - while (startInstLoc < endInstLoc) - { - const uint32_t startSourceLoc = startInstLoc->sourceLoc; - - // Find the run with the same source loc - - const InstLoc* curInstLoc = startInstLoc + 1; - uint32_t curInstIndex = startInstLoc->instIndex + 1; - - // Find the run size with same source loc and run of instruction indices - for (; curInstLoc < endInstLoc && curInstLoc->sourceLoc == startSourceLoc && curInstLoc->instIndex == curInstIndex; ++curInstLoc, ++curInstIndex) - { - } - - // Try adding the run - _addDebugSourceLocRun(SourceLoc::fromRaw(startSourceLoc), startInstLoc->instIndex, curInstIndex - startInstLoc->instIndex); - - // Next - startInstLoc = curInstLoc; - } - - // Okay we can now calculate the final source information - - for (auto& pair : m_debugSourceFileMap) - { - DebugSourceFile* debugSourceFile = pair.Value; - SourceFile* sourceFile = debugSourceFile->m_sourceFile; - - IRSerialData::DebugSourceInfo sourceInfo; - - sourceInfo.m_numLines = uint32_t(debugSourceFile->m_sourceFile->getLineBreakOffsets().getCount()); - - sourceInfo.m_startSourceLoc = uint32_t(debugSourceFile->m_baseSourceLoc); - sourceInfo.m_endSourceLoc = uint32_t(debugSourceFile->m_baseSourceLoc + sourceFile->getContentSize()); - - sourceInfo.m_pathIndex = Ser::StringIndex(m_debugStringSlicePool.add(sourceFile->getPathInfo().foundPath)); - - sourceInfo.m_lineInfosStartIndex = uint32_t(m_serialData->m_debugLineInfos.getCount()); - sourceInfo.m_adjustedLineInfosStartIndex = uint32_t(m_serialData->m_debugAdjustedLineInfos.getCount()); - - sourceInfo.m_numLineInfos = uint32_t(debugSourceFile->m_lineInfos.getCount()); - sourceInfo.m_numAdjustedLineInfos = uint32_t(debugSourceFile->m_adjustedLineInfos.getCount()); - - // Add the line infos - m_serialData->m_debugLineInfos.addRange(debugSourceFile->m_lineInfos.begin(), debugSourceFile->m_lineInfos.getCount()); - m_serialData->m_debugAdjustedLineInfos.addRange(debugSourceFile->m_adjustedLineInfos.begin(), debugSourceFile->m_adjustedLineInfos.getCount()); - - // Add the source info - m_serialData->m_debugSourceInfos.add(sourceInfo); - } - - // Convert the string pool - SerialStringTableUtil::encodeStringTable(m_debugStringSlicePool, m_serialData->m_debugStringTable); - - return SLANG_OK; -} - -Result IRSerialWriter::write(IRModule* module, SourceManager* sourceManager, OptionFlags options, IRSerialData* serialData) -{ - typedef Ser::Inst::PayloadType PayloadType; - - m_sourceManager = sourceManager; - m_serialData = serialData; - - serialData->clear(); - - // We reserve 0 for null - m_insts.clear(); - m_insts.add(nullptr); - - // Reset - m_instMap.Clear(); - m_decorations.clear(); - - // Stack for parentInst - List parentInstStack; - - IRModuleInst* moduleInst = module->getModuleInst(); - parentInstStack.add(moduleInst); - - // Add to the map - _addInstruction(moduleInst); - - // Traverse all of the instructions - while (parentInstStack.getCount()) - { - // If it's in the stack it is assumed it is already in the inst map - IRInst* parentInst = parentInstStack.getLast(); - parentInstStack.removeLast(); - SLANG_ASSERT(m_instMap.ContainsKey(parentInst)); - - // Okay we go through each of the children in order. If they are IRInstParent derived, we add to stack to process later - // cos we want breadth first so the order of children is the same as their index order, meaning we don't need to store explicit indices - const Ser::InstIndex startChildInstIndex = Ser::InstIndex(m_insts.getCount()); - - IRInstListBase childrenList = parentInst->getDecorationsAndChildren(); - for (IRInst* child : childrenList) - { - // This instruction can't be in the map... - SLANG_ASSERT(!m_instMap.ContainsKey(child)); - - _addInstruction(child); - - parentInstStack.add(child); - } - - // If it had any children, then store the information about it - if (Ser::InstIndex(m_insts.getCount()) != startChildInstIndex) - { - Ser::InstRun run; - run.m_parentIndex = m_instMap[parentInst]; - run.m_startInstIndex = startChildInstIndex; - run.m_numChildren = Ser::SizeType(m_insts.getCount() - int(startChildInstIndex)); - - m_serialData->m_childRuns.add(run); - } - } - -#if 0 - { - List workInsts; - calcInstructionList(module, workInsts); - SLANG_ASSERT(workInsts.Count() == m_insts.Count()); - for (UInt i = 0; i < workInsts.Count(); ++i) - { - SLANG_ASSERT(workInsts[i] == m_insts[i]); - } - } -#endif - - // Set to the right size - m_serialData->m_insts.setCount(m_insts.getCount()); - // Clear all instructions - memset(m_serialData->m_insts.begin(), 0, sizeof(Ser::Inst) * m_serialData->m_insts.getCount()); - - // Need to set up the actual instructions - { - const Index numInsts = m_insts.getCount(); - - for (Index i = 1; i < numInsts; ++i) - { - IRInst* srcInst = m_insts[i]; - Ser::Inst& dstInst = m_serialData->m_insts[i]; - - // Can't be any pseudo ops - SLANG_ASSERT(!isPseudoOp(srcInst->op)); - - dstInst.m_op = uint8_t(srcInst->op & kIROpMeta_OpMask); - dstInst.m_payloadType = PayloadType::Empty; - - dstInst.m_resultTypeIndex = getInstIndex(srcInst->getFullType()); - - IRConstant* irConst = as(srcInst); - if (irConst) - { - switch (srcInst->op) - { - // Special handling for the ir const derived types - case kIROp_StringLit: - { - auto stringLit = static_cast(srcInst); - dstInst.m_payloadType = PayloadType::String_1; - dstInst.m_payload.m_stringIndices[0] = getStringIndex(stringLit->getStringSlice()); - break; - } - case kIROp_IntLit: - { - dstInst.m_payloadType = PayloadType::Int64; - dstInst.m_payload.m_int64 = irConst->value.intVal; - break; - } - case kIROp_PtrLit: - { - dstInst.m_payloadType = PayloadType::Int64; - dstInst.m_payload.m_int64 = (intptr_t) irConst->value.ptrVal; - break; - } - case kIROp_FloatLit: - { - dstInst.m_payloadType = PayloadType::Float64; - dstInst.m_payload.m_float64 = irConst->value.floatVal; - break; - } - case kIROp_BoolLit: - { - dstInst.m_payloadType = PayloadType::UInt32; - dstInst.m_payload.m_uint32 = irConst->value.intVal ? 1 : 0; - break; - } - default: - { - SLANG_RELEASE_ASSERT(!"Unhandled constant type"); - return SLANG_FAIL; - } - } - continue; - } - - IRTextureTypeBase* textureBase = as(srcInst); - if (textureBase) - { - dstInst.m_payloadType = PayloadType::OperandAndUInt32; - dstInst.m_payload.m_operandAndUInt32.m_uint32 = uint32_t(srcInst->op) >> kIROpMeta_OtherShift; - dstInst.m_payload.m_operandAndUInt32.m_operand = getInstIndex(textureBase->getElementType()); - continue; - } - - // ModuleInst is different, in so far as it holds a pointer to IRModule, but we don't need - // to save that off in a special way, so can just use regular path - - const int numOperands = int(srcInst->operandCount); - Ser::InstIndex* dstOperands = nullptr; - - if (numOperands <= Ser::Inst::kMaxOperands) - { - // Checks the compile below is valid - SLANG_COMPILE_TIME_ASSERT(PayloadType(0) == PayloadType::Empty && PayloadType(1) == PayloadType::Operand_1 && PayloadType(2) == PayloadType::Operand_2); - - dstInst.m_payloadType = PayloadType(numOperands); - dstOperands = dstInst.m_payload.m_operands; - } - else - { - dstInst.m_payloadType = PayloadType::OperandExternal; - - int operandArrayBaseIndex = int(m_serialData->m_externalOperands.getCount()); - m_serialData->m_externalOperands.setCount(operandArrayBaseIndex + numOperands); - - dstOperands = m_serialData->m_externalOperands.begin() + operandArrayBaseIndex; - - auto& externalOperands = dstInst.m_payload.m_externalOperand; - externalOperands.m_arrayIndex = Ser::ArrayIndex(operandArrayBaseIndex); - externalOperands.m_size = Ser::SizeType(numOperands); - } - - for (int j = 0; j < numOperands; ++j) - { - const Ser::InstIndex dstInstIndex = getInstIndex(srcInst->getOperand(j)); - dstOperands[j] = dstInstIndex; - } - } - } - - // Convert strings into a string table - { - SerialStringTableUtil::encodeStringTable(m_stringSlicePool, serialData->m_stringTable); - } - - // If the option to use RawSourceLocations is enabled, serialize out as is - if (options & OptionFlag::RawSourceLocation) - { - const Index numInsts = m_insts.getCount(); - serialData->m_rawSourceLocs.setCount(numInsts); - - Ser::RawSourceLoc* dstLocs = serialData->m_rawSourceLocs.begin(); - // 0 is null, just mark as no location - dstLocs[0] = Ser::RawSourceLoc(0); - for (Index i = 1; i < numInsts; ++i) - { - IRInst* srcInst = m_insts[i]; - dstLocs[i] = Ser::RawSourceLoc(srcInst->sourceLoc.getRaw()); - } - } - - if (options & OptionFlag::DebugInfo) - { - _calcDebugInfo(); - } - - m_serialData = nullptr; - return SLANG_OK; -} - -template -static size_t _calcChunkSize(IRSerialBinary::CompressionType compressionType, const List& array) -{ - typedef IRSerialBinary Bin; - - if (array.getCount()) - { - switch (compressionType) - { - case Bin::CompressionType::None: - { - const size_t size = sizeof(Bin::ArrayHeader) + sizeof(T) * array.getCount(); - return (size + 3) & ~size_t(3); - } - case Bin::CompressionType::VariableByteLite: - { - const size_t payloadSize = ByteEncodeUtil::calcEncodeLiteSizeUInt32((const uint32_t*)array.begin(), (array.getCount() * sizeof(T)) / sizeof(uint32_t)); - const size_t size = sizeof(Bin::CompressedArrayHeader) + payloadSize; - return (size + 3) & ~size_t(3); - } - default: - { - SLANG_ASSERT(!"Unhandled compression type"); - return 0; - } - } - } - else - { - return 0; - } -} - -static Result _writeArrayChunk(IRSerialBinary::CompressionType compressionType, uint32_t chunkId, const void* data, size_t numEntries, size_t typeSize, Stream* stream) -{ - typedef IRSerialBinary Bin; - - if (numEntries == 0) - { - return SLANG_OK; - } - - size_t payloadSize; - - switch (compressionType) - { - case Bin::CompressionType::None: - { - payloadSize = sizeof(Bin::ArrayHeader) - sizeof(Bin::Chunk) + typeSize * numEntries; - - Bin::ArrayHeader header; - header.m_chunk.m_type = chunkId; - header.m_chunk.m_size = uint32_t(payloadSize); - header.m_numEntries = uint32_t(numEntries); - - stream->Write(&header, sizeof(header)); - - stream->Write(data, typeSize * numEntries); - break; - } - case Bin::CompressionType::VariableByteLite: - { - List compressedPayload; - - size_t numCompressedEntries = (numEntries * typeSize) / sizeof(uint32_t); - - ByteEncodeUtil::encodeLiteUInt32((const uint32_t*)data, numCompressedEntries, compressedPayload); - - payloadSize = sizeof(Bin::CompressedArrayHeader) - sizeof(Bin::Chunk) + compressedPayload.getCount(); - - Bin::CompressedArrayHeader header; - header.m_chunk.m_type = SLANG_MAKE_COMPRESSED_FOUR_CC(chunkId); - header.m_chunk.m_size = uint32_t(payloadSize); - header.m_numEntries = uint32_t(numEntries); - header.m_numCompressedEntries = uint32_t(numCompressedEntries); - - stream->Write(&header, sizeof(header)); - - stream->Write(compressedPayload.begin(), compressedPayload.getCount()); - break; - } - default: - { - return SLANG_FAIL; - } - } - // All chunks have sizes rounded to dword size - if (payloadSize & 3) - { - const uint8_t pad[4] = { 0, 0, 0, 0 }; - // Pad outs - int padSize = 4 - (payloadSize & 3); - stream->Write(pad, padSize); - } - - return SLANG_OK; -} - -template -Result _writeArrayChunk(IRSerialBinary::CompressionType compressionType, uint32_t chunkId, const List& array, Stream* stream) -{ - return _writeArrayChunk(compressionType, chunkId, array.begin(), size_t(array.getCount()), sizeof(T), stream); -} - -Result _encodeInsts(IRSerialBinary::CompressionType compressionType, const List& instsIn, List& encodeArrayOut) -{ - typedef IRSerialBinary Bin; - typedef IRSerialData::Inst::PayloadType PayloadType; - - if (compressionType != Bin::CompressionType::VariableByteLite) - { - return SLANG_FAIL; - } - encodeArrayOut.clear(); - - const size_t numInsts = size_t(instsIn.getCount()); - const IRSerialData::Inst* insts = instsIn.begin(); - - uint8_t* encodeOut = encodeArrayOut.begin(); - uint8_t* encodeEnd = encodeArrayOut.end(); - - // Calculate the maximum instruction size with worst case possible encoding - // 2 bytes hold the payload size, and the result type - // Note that if there were some free bits, we could encode some of this stuff into bits, but if we remove payloadType, then there are no free bits - const size_t maxInstSize = 2 + ByteEncodeUtil::kMaxLiteEncodeUInt32 + Math::Max(sizeof(insts->m_payload.m_float64), size_t(2 * ByteEncodeUtil::kMaxLiteEncodeUInt32)); - - for (size_t i = 0; i < numInsts; ++i) - { - const auto& inst = insts[i]; - - // Make sure there is space for the largest possible instruction - if (encodeOut + maxInstSize >= encodeEnd) - { - const size_t offset = size_t(encodeOut - encodeArrayOut.begin()); - - const UInt oldCapacity = encodeArrayOut.getCapacity(); - - encodeArrayOut.reserve(oldCapacity + (oldCapacity >> 1) + maxInstSize); - const UInt capacity = encodeArrayOut.getCapacity(); - encodeArrayOut.setCount(capacity); - - encodeOut = encodeArrayOut.begin() + offset; - encodeEnd = encodeArrayOut.end(); - } - - *encodeOut++ = uint8_t(inst.m_op); - *encodeOut++ = uint8_t(inst.m_payloadType); - - encodeOut += ByteEncodeUtil::encodeLiteUInt32((uint32_t)inst.m_resultTypeIndex, encodeOut); - - switch (inst.m_payloadType) - { - case PayloadType::Empty: - { - break; - } - case PayloadType::Operand_1: - case PayloadType::String_1: - case PayloadType::UInt32: - { - // 1 UInt32 - encodeOut += ByteEncodeUtil::encodeLiteUInt32((uint32_t)inst.m_payload.m_operands[0], encodeOut); - break; - } - case PayloadType::Operand_2: - case PayloadType::OperandAndUInt32: - case PayloadType::OperandExternal: - case PayloadType::String_2: - { - // 2 UInt32 - encodeOut += ByteEncodeUtil::encodeLiteUInt32((uint32_t)inst.m_payload.m_operands[0], encodeOut); - encodeOut += ByteEncodeUtil::encodeLiteUInt32((uint32_t)inst.m_payload.m_operands[1], encodeOut); - break; - } - case PayloadType::Float64: - { - memcpy(encodeOut, &inst.m_payload.m_float64, sizeof(inst.m_payload.m_float64)); - encodeOut += sizeof(inst.m_payload.m_float64); - break; - } - case PayloadType::Int64: - { - memcpy(encodeOut, &inst.m_payload.m_int64, sizeof(inst.m_payload.m_int64)); - encodeOut += sizeof(inst.m_payload.m_int64); - break; - } - } - } - - // Fix the size - encodeArrayOut.setCount(UInt(encodeOut - encodeArrayOut.begin())); - return SLANG_OK; -} - -Result _writeInstArrayChunk(IRSerialBinary::CompressionType compressionType, uint32_t chunkId, const List& array, Stream* stream) -{ - typedef IRSerialBinary Bin; - if (array.getCount() == 0) - { - return SLANG_OK; - } - - switch (compressionType) - { - case Bin::CompressionType::None: - { - return _writeArrayChunk(compressionType, chunkId, array, stream); - } - case Bin::CompressionType::VariableByteLite: - { - List compressedPayload; - SLANG_RETURN_ON_FAIL(_encodeInsts(compressionType, array, compressedPayload)); - - size_t payloadSize = sizeof(Bin::CompressedArrayHeader) - sizeof(Bin::Chunk) + compressedPayload.getCount(); - - Bin::CompressedArrayHeader header; - header.m_chunk.m_type = SLANG_MAKE_COMPRESSED_FOUR_CC(chunkId); - header.m_chunk.m_size = uint32_t(payloadSize); - header.m_numEntries = uint32_t(array.getCount()); - header.m_numCompressedEntries = 0; - - stream->Write(&header, sizeof(header)); - stream->Write(compressedPayload.begin(), compressedPayload.getCount()); - - // All chunks have sizes rounded to dword size - if (payloadSize & 3) - { - const uint8_t pad[4] = { 0, 0, 0, 0 }; - // Pad outs - int padSize = 4 - (payloadSize & 3); - stream->Write(pad, padSize); - } - return SLANG_OK; - } - default: break; - } - return SLANG_FAIL; -} - -static size_t _calcInstChunkSize(IRSerialBinary::CompressionType compressionType, const List& instsIn) -{ - typedef IRSerialBinary Bin; - typedef IRSerialData::Inst::PayloadType PayloadType; - - switch (compressionType) - { - case Bin::CompressionType::None: - { - return _calcChunkSize(compressionType, instsIn); - } - case Bin::CompressionType::VariableByteLite: - { - size_t size = sizeof(Bin::CompressedArrayHeader); - - size_t numInsts = size_t(instsIn.getCount()); - size += numInsts * 2; // op and payload - - IRSerialData::Inst* insts = instsIn.begin(); - - for (size_t i = 0; i < numInsts; ++i) - { - const auto& inst = insts[i]; - - size += ByteEncodeUtil::calcEncodeLiteSizeUInt32((uint32_t)inst.m_resultTypeIndex); - - switch (inst.m_payloadType) - { - case PayloadType::Empty: - { - break; - } - case PayloadType::Operand_1: - case PayloadType::String_1: - case PayloadType::UInt32: - { - // 1 UInt32 - size += ByteEncodeUtil::calcEncodeLiteSizeUInt32((uint32_t)inst.m_payload.m_operands[0]); - break; - } - case PayloadType::Operand_2: - case PayloadType::OperandAndUInt32: - case PayloadType::OperandExternal: - case PayloadType::String_2: - { - // 2 UInt32 - size += ByteEncodeUtil::calcEncodeLiteSizeUInt32((uint32_t)inst.m_payload.m_operands[0]); - size += ByteEncodeUtil::calcEncodeLiteSizeUInt32((uint32_t)inst.m_payload.m_operands[1]); - break; - } - case PayloadType::Float64: - { - size += sizeof(inst.m_payload.m_float64); - break; - } - case PayloadType::Int64: - { - size += sizeof(inst.m_payload.m_int64); - break; - } - } - } - - return (size + 3) & ~size_t(3); - } - default: break; - } - - SLANG_ASSERT(!"Unhandled compression type"); - return 0; -} - -/* static */Result IRSerialWriter::writeStream(const IRSerialData& data, Bin::CompressionType compressionType, Stream* stream) -{ - size_t totalSize = 0; - - totalSize += sizeof(Bin::SlangHeader) + - _calcInstChunkSize(compressionType, data.m_insts) + - _calcChunkSize(compressionType, data.m_childRuns) + - _calcChunkSize(compressionType, data.m_externalOperands) + - _calcChunkSize(Bin::CompressionType::None, data.m_stringTable) + - _calcChunkSize(Bin::CompressionType::None, data.m_rawSourceLocs); - - if (data.m_debugSourceInfos.getCount()) - { - totalSize += _calcChunkSize(Bin::CompressionType::None, data.m_debugStringTable) + - _calcChunkSize(Bin::CompressionType::None, data.m_debugLineInfos) + - _calcChunkSize(Bin::CompressionType::None, data.m_debugAdjustedLineInfos) + - _calcChunkSize(Bin::CompressionType::None, data.m_debugSourceInfos) + - _calcChunkSize(compressionType, data.m_debugSourceLocRuns); - } - - { - Bin::Chunk riffHeader; - riffHeader.m_type = Bin::kRiffFourCc; - riffHeader.m_size = uint32_t(totalSize); - - stream->Write(&riffHeader, sizeof(riffHeader)); - } - { - Bin::SlangHeader slangHeader; - slangHeader.m_chunk.m_type = Bin::kSlangFourCc; - slangHeader.m_chunk.m_size = uint32_t(sizeof(slangHeader) - sizeof(Bin::Chunk)); - slangHeader.m_compressionType = uint32_t(Bin::CompressionType::VariableByteLite); - - stream->Write(&slangHeader, sizeof(slangHeader)); - } - - SLANG_RETURN_ON_FAIL(_writeInstArrayChunk(compressionType, Bin::kInstFourCc, data.m_insts, stream)); - SLANG_RETURN_ON_FAIL(_writeArrayChunk(compressionType, Bin::kChildRunFourCc, data.m_childRuns, stream)); - SLANG_RETURN_ON_FAIL(_writeArrayChunk(compressionType, Bin::kExternalOperandsFourCc, data.m_externalOperands, stream)); - SLANG_RETURN_ON_FAIL(_writeArrayChunk(Bin::CompressionType::None, Bin::kStringFourCc, data.m_stringTable, stream)); - - SLANG_RETURN_ON_FAIL(_writeArrayChunk(Bin::CompressionType::None, Bin::kUInt32SourceLocFourCc, data.m_rawSourceLocs, stream)); - - if (data.m_debugSourceInfos.getCount()) - { - _writeArrayChunk(Bin::CompressionType::None, Bin::kDebugStringFourCc, data.m_debugStringTable, stream); - _writeArrayChunk(Bin::CompressionType::None, Bin::kDebugLineInfoFourCc, data.m_debugLineInfos, stream); - _writeArrayChunk(Bin::CompressionType::None, Bin::kDebugAdjustedLineInfoFourCc, data.m_debugAdjustedLineInfos, stream); - _writeArrayChunk(Bin::CompressionType::None, Bin::kDebugSourceInfoFourCc, data.m_debugSourceInfos, stream); - _writeArrayChunk(compressionType, Bin::kDebugSourceLocRunFourCc, data.m_debugSourceLocRuns, stream); - } - - return SLANG_OK; -} - -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! IRSerialReader !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - -class ListResizer -{ - public: - virtual void* setSize(size_t newSize) = 0; - SLANG_FORCE_INLINE size_t getTypeSize() const { return m_typeSize; } - ListResizer(size_t typeSize):m_typeSize(typeSize) {} - - protected: - size_t m_typeSize; -}; - -template -class ListResizerForType: public ListResizer -{ - public: - typedef ListResizer Parent; - - SLANG_FORCE_INLINE ListResizerForType(List& list): - Parent(sizeof(T)), - m_list(list) - {} - - virtual void* setSize(size_t newSize) SLANG_OVERRIDE - { - m_list.setCount(UInt(newSize)); - return (void*)m_list.begin(); - } - - protected: - List& m_list; -}; - -static Result _readArrayChunk(IRSerialBinary::CompressionType compressionType, const IRSerialBinary::Chunk& chunk, Stream* stream, size_t* numReadInOut, ListResizer& listOut) -{ - typedef IRSerialBinary Bin; - - const size_t typeSize = listOut.getTypeSize(); - - switch (compressionType) - { - case Bin::CompressionType::VariableByteLite: - { - // We have a compressed header - Bin::CompressedArrayHeader header; - header.m_chunk = chunk; - - stream->Read(&header.m_chunk + 1, sizeof(header) - sizeof(Bin::Chunk)); - *numReadInOut += sizeof(header) - sizeof(Bin::Chunk); - - void* data = listOut.setSize(header.m_numEntries); - - // Need to read all the compressed data... - size_t payloadSize = header.m_chunk.m_size - (sizeof(header) - sizeof(Bin::Chunk)); - - List compressedPayload; - compressedPayload.setCount(payloadSize); - - stream->Read(compressedPayload.begin(), payloadSize); - *numReadInOut += payloadSize; - - SLANG_ASSERT(header.m_numCompressedEntries == uint32_t((header.m_numEntries * typeSize) / sizeof(uint32_t))); - - // Decode.. - ByteEncodeUtil::decodeLiteUInt32(compressedPayload.begin(), header.m_numCompressedEntries, (uint32_t*)data); - break; - } - case Bin::CompressionType::None: - { - // Read uncompressed - Bin::ArrayHeader header; - header.m_chunk = chunk; - - stream->Read(&header.m_chunk + 1, sizeof(header) - sizeof(Bin::Chunk)); - *numReadInOut += sizeof(header) - sizeof(Bin::Chunk); - - const size_t payloadSize = header.m_numEntries * typeSize; - - void* data = listOut.setSize(header.m_numEntries); - - stream->Read(data, payloadSize); - *numReadInOut += payloadSize; - break; - } - } - - // All chunks have sizes rounded to dword size - if (*numReadInOut & 3) - { - const uint8_t pad[4] = { 0, 0, 0, 0 }; - // Pad outs - int padSize = 4 - int(*numReadInOut & 3); - stream->Seek(SeekOrigin::Current, padSize); - - *numReadInOut += padSize; - } - - return SLANG_OK; -} - -template -Result _readArrayChunk(const IRSerialBinary::SlangHeader& header, const IRSerialBinary::Chunk& chunk, Stream* stream, size_t* numReadInOut, List& arrayOut) -{ - typedef IRSerialBinary Bin; - - Bin::CompressionType compressionType = Bin::CompressionType::None; - - if (chunk.m_type == SLANG_MAKE_COMPRESSED_FOUR_CC(chunk.m_type)) - { - // If it has compression, use the compression type set in the header - compressionType = Bin::CompressionType(header.m_compressionType); - } - ListResizerForType resizer(arrayOut); - return _readArrayChunk(compressionType, chunk, stream, numReadInOut, resizer); -} - -template -Result _readArrayUncompressedChunk(const IRSerialBinary::SlangHeader& header, const IRSerialBinary::Chunk& chunk, Stream* stream, size_t* numReadInOut, List& arrayOut) -{ - typedef IRSerialBinary Bin; - SLANG_UNUSED(header); - ListResizerForType resizer(arrayOut); - return _readArrayChunk(Bin::CompressionType::None, chunk, stream, numReadInOut, resizer); -} - -static Result _decodeInsts(IRSerialBinary::CompressionType compressionType, const List& encodeIn, List& instsOut) -{ - typedef IRSerialBinary Bin; - typedef IRSerialData::Inst::PayloadType PayloadType; - - if (compressionType != Bin::CompressionType::VariableByteLite) - { - return SLANG_FAIL; - } - - const size_t numInsts = size_t(instsOut.getCount()); - IRSerialData::Inst* insts = instsOut.begin(); - - const uint8_t* encodeCur = encodeIn.begin(); - - for (size_t i = 0; i < numInsts; ++i) - { - auto& inst = insts[i]; - - inst.m_op = *encodeCur++; - const PayloadType payloadType = PayloadType(*encodeCur++); - inst.m_payloadType = payloadType; - - // Read the result value - encodeCur += ByteEncodeUtil::decodeLiteUInt32(encodeCur, (uint32_t*)&inst.m_resultTypeIndex); - - switch (inst.m_payloadType) - { - case PayloadType::Empty: - { - break; - } - case PayloadType::Operand_1: - case PayloadType::String_1: - case PayloadType::UInt32: - { - // 1 UInt32 - encodeCur += ByteEncodeUtil::decodeLiteUInt32(encodeCur, (uint32_t*)&inst.m_payload.m_operands[0]); - break; - } - case PayloadType::Operand_2: - case PayloadType::OperandAndUInt32: - case PayloadType::OperandExternal: - case PayloadType::String_2: - { - // 2 UInt32 - encodeCur += ByteEncodeUtil::decodeLiteUInt32(encodeCur, 2, (uint32_t*)&inst.m_payload.m_operands[0]); - break; - } - case PayloadType::Float64: - { - memcpy(&inst.m_payload.m_float64, encodeCur, sizeof(inst.m_payload.m_float64)); - encodeCur += sizeof(inst.m_payload.m_float64); - break; - } - case PayloadType::Int64: - { - memcpy(&inst.m_payload.m_int64, encodeCur, sizeof(inst.m_payload.m_int64)); - encodeCur += sizeof(inst.m_payload.m_int64); - break; - } - } - } - - return SLANG_OK; -} - -Result _readInstArrayChunk(const IRSerialBinary::SlangHeader& slangHeader, const IRSerialBinary::Chunk& chunk, Stream* stream, size_t* numReadInOut, List& arrayOut) -{ - typedef IRSerialBinary Bin; - - Bin::CompressionType compressionType = Bin::CompressionType::None; - if (chunk.m_type == SLANG_MAKE_COMPRESSED_FOUR_CC(chunk.m_type)) - { - compressionType = Bin::CompressionType(slangHeader.m_compressionType); - } - - switch (compressionType) - { - case Bin::CompressionType::None: - { - ListResizerForType resizer(arrayOut); - return _readArrayChunk(compressionType, chunk, stream, numReadInOut, resizer); - } - case Bin::CompressionType::VariableByteLite: - { - // We have a compressed header - Bin::CompressedArrayHeader header; - header.m_chunk = chunk; - - stream->Read(&header.m_chunk + 1, sizeof(header) - sizeof(Bin::Chunk)); - *numReadInOut += sizeof(header) - sizeof(Bin::Chunk); - - // Need to read all the compressed data... - size_t payloadSize = header.m_chunk.m_size - (sizeof(header) - sizeof(Bin::Chunk)); - - List compressedPayload; - compressedPayload.setCount(payloadSize); - - stream->Read(compressedPayload.begin(), payloadSize); - *numReadInOut += payloadSize; - - arrayOut.setCount(header.m_numEntries); - - SLANG_RETURN_ON_FAIL(_decodeInsts(compressionType, compressedPayload, arrayOut)); - break; - } - default: - { - return SLANG_FAIL; - } - } - - // All chunks have sizes rounded to dword size - if (*numReadInOut & 3) - { - // Pad outs - int padSize = 4 - int(*numReadInOut & 3); - stream->Seek(SeekOrigin::Current, padSize); - *numReadInOut += padSize; - } - - return SLANG_OK; -} - -int64_t _calcChunkTotalSize(const IRSerialBinary::Chunk& chunk) -{ - int64_t size = chunk.m_size + sizeof(IRSerialBinary::Chunk); - return (size + 3) & ~int64_t(3); -} - -/* static */Result IRSerialReader::_skip(const IRSerialBinary::Chunk& chunk, Stream* stream, int64_t* remainingBytesInOut) -{ - typedef IRSerialBinary Bin; - int64_t chunkSize = _calcChunkTotalSize(chunk); - if (remainingBytesInOut) - { - *remainingBytesInOut -= chunkSize; - } - - // Skip the payload (we don't need to skip the Chunk because that was already read - stream->Seek(SeekOrigin::Current, chunkSize - sizeof(IRSerialBinary::Chunk)); - return SLANG_OK; -} - -/* static */Result IRSerialReader::readStream(Stream* stream, IRSerialData* dataOut) -{ - typedef IRSerialBinary Bin; - - dataOut->clear(); - - int64_t remainingBytes = 0; - { - Bin::Chunk header; - stream->Read(&header, sizeof(header)); - if (header.m_type != Bin::kRiffFourCc) - { - return SLANG_FAIL; - } - - remainingBytes = header.m_size; - } - - // Header - // Chunk will not be kSlangFourCC if not read yet - Bin::SlangHeader slangHeader; - memset(&slangHeader, 0, sizeof(slangHeader)); - - while (remainingBytes > 0) - { - Bin::Chunk chunk; - - stream->Read(&chunk, sizeof(chunk)); - - size_t bytesRead = sizeof(chunk); - - switch (chunk.m_type) - { - case Bin::kSlangFourCc: - { - // Slang header - slangHeader.m_chunk = chunk; - - // NOTE! Really we should only read what we know the size to be... - // and skip if it's larger - - stream->Read(&slangHeader.m_chunk + 1, sizeof(slangHeader) - sizeof(chunk)); - - remainingBytes -= _calcChunkTotalSize(chunk); - break; - } - case SLANG_MAKE_COMPRESSED_FOUR_CC(Bin::kInstFourCc): - case Bin::kInstFourCc: - { - SLANG_RETURN_ON_FAIL(_readInstArrayChunk(slangHeader, chunk, stream, &bytesRead, dataOut->m_insts)); - remainingBytes -= _calcChunkTotalSize(chunk); - break; - } - case SLANG_MAKE_COMPRESSED_FOUR_CC(Bin::kChildRunFourCc): - case Bin::kChildRunFourCc: - { - SLANG_RETURN_ON_FAIL(_readArrayChunk(slangHeader, chunk, stream, &bytesRead, dataOut->m_childRuns)); - remainingBytes -= _calcChunkTotalSize(chunk); - break; - } - case SLANG_MAKE_COMPRESSED_FOUR_CC(Bin::kExternalOperandsFourCc): - case Bin::kExternalOperandsFourCc: - { - SLANG_RETURN_ON_FAIL(_readArrayChunk(slangHeader, chunk, stream, &bytesRead, dataOut->m_externalOperands)); - remainingBytes -= _calcChunkTotalSize(chunk); - break; - } - case Bin::kStringFourCc: - { - SLANG_RETURN_ON_FAIL(_readArrayUncompressedChunk(slangHeader, chunk, stream, &bytesRead, dataOut->m_stringTable)); - remainingBytes -= _calcChunkTotalSize(chunk); - break; - } - case Bin::kUInt32SourceLocFourCc: - { - SLANG_RETURN_ON_FAIL(_readArrayUncompressedChunk(slangHeader, chunk, stream, &bytesRead, dataOut->m_rawSourceLocs)); - remainingBytes -= _calcChunkTotalSize(chunk); - break; - } - case Bin::kDebugStringFourCc: - { - SLANG_RETURN_ON_FAIL(_readArrayUncompressedChunk(slangHeader, chunk, stream, &bytesRead, dataOut->m_debugStringTable)); - remainingBytes -= _calcChunkTotalSize(chunk); - break; - } - case Bin::kDebugLineInfoFourCc: - { - SLANG_RETURN_ON_FAIL(_readArrayUncompressedChunk(slangHeader, chunk, stream, &bytesRead, dataOut->m_debugLineInfos)); - remainingBytes -= _calcChunkTotalSize(chunk); - break; - } - case Bin::kDebugAdjustedLineInfoFourCc: - { - SLANG_RETURN_ON_FAIL(_readArrayUncompressedChunk(slangHeader, chunk, stream, &bytesRead, dataOut->m_debugAdjustedLineInfos)); - remainingBytes -= _calcChunkTotalSize(chunk); - break; - } - case Bin::kDebugSourceInfoFourCc: - { - SLANG_RETURN_ON_FAIL(_readArrayChunk(slangHeader, chunk, stream, &bytesRead, dataOut->m_debugSourceInfos)); - remainingBytes -= _calcChunkTotalSize(chunk); - break; - } - case SLANG_MAKE_COMPRESSED_FOUR_CC(Bin::kDebugSourceLocRunFourCc): - case Bin::kDebugSourceLocRunFourCc: - { - SLANG_RETURN_ON_FAIL(_readArrayChunk(slangHeader, chunk, stream, &bytesRead, dataOut->m_debugSourceLocRuns)); - remainingBytes -= _calcChunkTotalSize(chunk); - break; - } - - default: - { - SLANG_RETURN_ON_FAIL(_skip(chunk, stream, &remainingBytes)); - break; - } - } - } - - return SLANG_OK; -} - -static SourceRange _toSourceRange(const IRSerialData::DebugSourceInfo& info) -{ - SourceRange range; - range.begin = SourceLoc::fromRaw(info.m_startSourceLoc); - range.end = SourceLoc::fromRaw(info.m_endSourceLoc); - return range; -} - -static int _findIndex(const List& infos, SourceLoc sourceLoc) -{ - const int numInfos = int(infos.getCount()); - for (int i = 0; i < numInfos; ++i) - { - if (_toSourceRange(infos[i]).contains(sourceLoc)) - { - return i; - } - } - - return -1; -} - -static int _calcFixSourceLoc(const IRSerialData::DebugSourceInfo& info, SourceView* sourceView, SourceRange& rangeOut) -{ - rangeOut = _toSourceRange(info); - return int(sourceView->getRange().begin.getRaw()) - int(info.m_startSourceLoc); -} - -/* static */Result IRSerialReader::read(const IRSerialData& data, Session* session, SourceManager* sourceManager, RefPtr& moduleOut) -{ - typedef Ser::Inst::PayloadType PayloadType; - - m_serialData = &data; - - auto module = new IRModule(); - moduleOut = module; - m_module = module; - - module->session = session; - - // Set up the string rep cache - m_stringRepresentationCache.init(&data.m_stringTable, session->getNamePool(), module->getObjectScopeManager()); - - // Add all the instructions - - List insts; - - const Index numInsts = data.m_insts.getCount(); - - SLANG_ASSERT(numInsts > 0); - - insts.setCount(numInsts); - insts[0] = nullptr; - - // 0 holds null - // 1 holds the IRModuleInst - { - // Check that insts[1] is the module inst - const Ser::Inst& srcInst = data.m_insts[1]; - SLANG_RELEASE_ASSERT(srcInst.m_op == kIROp_Module); - SLANG_ASSERT(srcInst.m_payloadType == PayloadType::Empty); - - // Create the module inst - auto moduleInst = static_cast(createEmptyInstWithSize(module, kIROp_Module, sizeof(IRModuleInst))); - module->moduleInst = moduleInst; - moduleInst->module = module; - - // Set the IRModuleInst - insts[1] = moduleInst; - } - - for (Index i = 2; i < numInsts; ++i) - { - const Ser::Inst& srcInst = data.m_insts[i]; - - const IROp op((IROp)srcInst.m_op); - - if (isConstant(op)) - { - // Handling of constants - - // Calculate the minimum object size (ie not including the payload of value) - const size_t prefixSize = SLANG_OFFSET_OF(IRConstant, value); - - IRConstant* irConst = nullptr; - switch (op) - { - case kIROp_BoolLit: - { - SLANG_ASSERT(srcInst.m_payloadType == PayloadType::UInt32); - irConst = static_cast(createEmptyInstWithSize(module, op, prefixSize + sizeof(IRIntegerValue))); - irConst->value.intVal = srcInst.m_payload.m_uint32 != 0; - break; - } - case kIROp_IntLit: - { - SLANG_ASSERT(srcInst.m_payloadType == PayloadType::Int64); - irConst = static_cast(createEmptyInstWithSize(module, op, prefixSize + sizeof(IRIntegerValue))); - irConst->value.intVal = srcInst.m_payload.m_int64; - break; - } - case kIROp_PtrLit: - { - SLANG_ASSERT(srcInst.m_payloadType == PayloadType::Int64); - irConst = static_cast(createEmptyInstWithSize(module, op, prefixSize + sizeof(void*))); - irConst->value.ptrVal = (void*) (intptr_t) srcInst.m_payload.m_int64; - break; - } - case kIROp_FloatLit: - { - SLANG_ASSERT(srcInst.m_payloadType == PayloadType::Float64); - irConst = static_cast(createEmptyInstWithSize(module, op, prefixSize + sizeof(IRFloatingPointValue))); - irConst->value.floatVal = srcInst.m_payload.m_float64; - break; - } - case kIROp_StringLit: - { - SLANG_ASSERT(srcInst.m_payloadType == PayloadType::String_1); - - const UnownedStringSlice slice = m_stringRepresentationCache.getStringSlice(StringHandle(srcInst.m_payload.m_stringIndices[0])); - - const size_t sliceSize = slice.size(); - const size_t instSize = prefixSize + SLANG_OFFSET_OF(IRConstant::StringValue, chars) + sliceSize; - - irConst = static_cast(createEmptyInstWithSize(module, op, instSize)); - - IRConstant::StringValue& dstString = irConst->value.stringVal; - - dstString.numChars = uint32_t(sliceSize); - // Turn into pointer to avoid warning of array overrun - char* dstChars = dstString.chars; - // Copy the chars - memcpy(dstChars, slice.begin(), sliceSize); - break; - } - default: - { - SLANG_ASSERT(!"Unknown constant type"); - return SLANG_FAIL; - } - } - - insts[i] = irConst; - } - else if (isTextureTypeBase(op)) - { - IRTextureTypeBase* inst = static_cast(createEmptyInst(module, op, 1)); - SLANG_ASSERT(srcInst.m_payloadType == PayloadType::OperandAndUInt32); - - // Reintroduce the texture type bits into the the - const uint32_t other = srcInst.m_payload.m_operandAndUInt32.m_uint32; - inst->op = IROp(uint32_t(inst->op) | (other << kIROpMeta_OtherShift)); - - insts[i] = inst; - } - else - { - int numOperands = srcInst.getNumOperands(); - insts[i] = createEmptyInst(module, op, numOperands); - } - } - - // Patch up the operands - for (Index i = 1; i < numInsts; ++i) - { - const Ser::Inst& srcInst = data.m_insts[i]; - const IROp op((IROp)srcInst.m_op); - - IRInst* dstInst = insts[i]; - - // Set the result type - if (srcInst.m_resultTypeIndex != Ser::InstIndex(0)) - { - IRInst* resultInst = insts[int(srcInst.m_resultTypeIndex)]; - // NOTE! Counter intuitively the IRType* paramter may not be IRType* derived for example - // IRGlobalGenericParam is valid, but isn't IRType* derived - - //SLANG_RELEASE_ASSERT(as(resultInst)); - dstInst->setFullType(static_cast(resultInst)); - } - - //if (!isParentDerived(op)) - { - const Ser::InstIndex* srcOperandIndices; - const int numOperands = data.getOperands(srcInst, &srcOperandIndices); - - for (int j = 0; j < numOperands; j++) - { - dstInst->setOperand(j, insts[int(srcOperandIndices[j])]); - } - } - } - - // Patch up the children - { - const Index numChildRuns = data.m_childRuns.getCount(); - for (Index i = 0; i < numChildRuns; i++) - { - const auto& run = data.m_childRuns[i]; - - IRInst* inst = insts[int(run.m_parentIndex)]; - - for (int j = 0; j < int(run.m_numChildren); ++j) - { - IRInst* child = insts[j + int(run.m_startInstIndex)]; - SLANG_ASSERT(child->parent == nullptr); - child->insertAtEnd(inst); - } - } - } - - // Re-add source locations, if they are defined - if (m_serialData->m_rawSourceLocs.getCount() == numInsts) - { - const Ser::RawSourceLoc* srcLocs = m_serialData->m_rawSourceLocs.begin(); - for (Index i = 1; i < numInsts; ++i) - { - IRInst* dstInst = insts[i]; - - dstInst->sourceLoc.setRaw(Slang::SourceLoc::RawValue(srcLocs[i])); - } - } - - if (sourceManager && m_serialData->m_debugSourceInfos.getCount()) - { - List debugStringSlices; - SerialStringTableUtil::decodeStringTable(m_serialData->m_debugStringTable, debugStringSlices); - - // All of the strings are placed in the manager (and its StringSlicePool) where the SourceView and SourceFile are constructed from - List stringMap; - SerialStringTableUtil::calcStringSlicePoolMap(debugStringSlices, sourceManager->getStringSlicePool(), stringMap); - - const List& sourceInfos = m_serialData->m_debugSourceInfos; - - // Construct the source files - Index numSourceFiles = sourceInfos.getCount(); - - // These hold the views (and SourceFile as there is only one SourceFile per view) in the same order as the sourceInfos - List sourceViews; - sourceViews.setCount(numSourceFiles); - - for (Index i = 0; i < numSourceFiles; ++i) - { - const IRSerialData::DebugSourceInfo& srcSourceInfo = sourceInfos[i]; - - PathInfo pathInfo; - pathInfo.type = PathInfo::Type::FoundPath; - pathInfo.foundPath = debugStringSlices[UInt(srcSourceInfo.m_pathIndex)]; - - SourceFile* sourceFile = sourceManager->createSourceFileWithSize(pathInfo, srcSourceInfo.m_endSourceLoc - srcSourceInfo.m_startSourceLoc); - SourceView* sourceView = sourceManager->createSourceView(sourceFile, nullptr); - - // We need to accumulate all line numbers, for this source file, both adjusted and unadjusted - List lineInfos; - // Add the adjusted lines - { - lineInfos.setCount(srcSourceInfo.m_numAdjustedLineInfos); - const IRSerialData::DebugAdjustedLineInfo* srcAdjustedLineInfos = m_serialData->m_debugAdjustedLineInfos.getBuffer() + srcSourceInfo.m_adjustedLineInfosStartIndex; - const int numAdjustedLines = int(srcSourceInfo.m_numAdjustedLineInfos); - for (int j = 0; j < numAdjustedLines; ++j) - { - lineInfos[j] = srcAdjustedLineInfos[j].m_lineInfo; - } - } - // Add regular lines - lineInfos.addRange(m_serialData->m_debugLineInfos.getBuffer() + srcSourceInfo.m_lineInfosStartIndex, srcSourceInfo.m_numLineInfos); - // Put in sourceloc order - lineInfos.sort(); - - List lineBreakOffsets; - - // We can now set up the line breaks array - const int numLines = int(srcSourceInfo.m_numLines); - lineBreakOffsets.setCount(numLines); - - { - const Index numLineInfos = lineInfos.getCount(); - Index lineIndex = 0; - - // Every line up and including should hold the same offset - for (Index lineInfoIndex = 0; lineInfoIndex < numLineInfos; ++lineInfoIndex) - { - const auto& lineInfo = lineInfos[lineInfoIndex]; - - const uint32_t offset = lineInfo.m_lineStartOffset; - SLANG_ASSERT(offset > 0); - const int finishIndex = int(lineInfo.m_lineIndex); - - SLANG_ASSERT(finishIndex < numLines); - - for (; lineIndex < finishIndex; ++lineIndex) - { - lineBreakOffsets[lineIndex] = offset - 1; - } - lineBreakOffsets[lineIndex] = offset; - lineIndex++; - } - - // Do the remaining lines - const uint32_t offset = uint32_t(srcSourceInfo.m_endSourceLoc - srcSourceInfo.m_startSourceLoc); - for (; lineIndex < numLines; ++lineIndex) - { - lineBreakOffsets[lineIndex] = offset; - } - } - - sourceFile->setLineBreakOffsets(lineBreakOffsets.getBuffer(), lineBreakOffsets.getCount()); - - if (srcSourceInfo.m_numAdjustedLineInfos) - { - List adjustedLineInfos; - - int numEntries = int(srcSourceInfo.m_numAdjustedLineInfos); - - adjustedLineInfos.addRange(m_serialData->m_debugAdjustedLineInfos.getBuffer() + srcSourceInfo.m_adjustedLineInfosStartIndex, numEntries); - adjustedLineInfos.sort(); - - // Work out the views adjustments, and place in dstEntries - List dstEntries; - dstEntries.setCount(numEntries); - - const uint32_t sourceLocOffset = uint32_t(sourceView->getRange().begin.getRaw()); - - for (int j = 0; j < numEntries; ++j) - { - const auto& srcEntry = adjustedLineInfos[j]; - auto& dstEntry = dstEntries[j]; - - dstEntry.m_pathHandle = stringMap[int(srcEntry.m_pathStringIndex)]; - dstEntry.m_startLoc = SourceLoc::fromRaw(srcEntry.m_lineInfo.m_lineStartOffset + sourceLocOffset); - dstEntry.m_lineAdjust = int32_t(srcEntry.m_adjustedLineIndex) - int32_t(srcEntry.m_lineInfo.m_lineIndex); - } - - // Set the adjustments on the view - sourceView->setEntries(dstEntries.getBuffer(), dstEntries.getCount()); - } - - sourceViews[i] = sourceView; - } - - // We now need to apply the runs - { - List sourceRuns(m_serialData->m_debugSourceLocRuns); - // They are now in source location order - sourceRuns.sort(); - - // Just guess initially 0 for the source file that contains the initial run - SourceRange range; - int fixSourceLoc = _calcFixSourceLoc(sourceInfos[0], sourceViews[0], range); - - const Index numRuns = sourceRuns.getCount(); - for (Index i = 0; i < numRuns; ++i) - { - const auto& run = sourceRuns[i]; - const SourceLoc srcSourceLoc = SourceLoc::fromRaw(run.m_sourceLoc); - - if (!range.contains(srcSourceLoc)) - { - int index = _findIndex(sourceInfos, srcSourceLoc); - if (index < 0) - { - // Didn't find the match - continue; - } - fixSourceLoc = _calcFixSourceLoc(sourceInfos[index], sourceViews[index], range); - SLANG_ASSERT(range.contains(srcSourceLoc)); - } - - // Work out the fixed source location - SourceLoc sourceLoc = SourceLoc::fromRaw(int(run.m_sourceLoc) + fixSourceLoc); - - SLANG_ASSERT(Index(uint32_t(run.m_startInstIndex) + run.m_numInst) <= insts.getCount()); - IRInst** dstInsts = insts.getBuffer() + int(run.m_startInstIndex); - - const int runSize = int(run.m_numInst); - for (int j = 0; j < runSize; ++j) - { - dstInsts[j]->sourceLoc = sourceLoc; - } - } - } - } - - return SLANG_OK; -} - -// !!!!!!!!!!!!!!!!!!!!!!!!!!!! IRSerialUtil !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - -/* static */void IRSerialUtil::calcInstructionList(IRModule* module, List& instsOut) -{ - // We reserve 0 for null - instsOut.setCount(1); - instsOut[0] = nullptr; - - // Stack for parentInst - List parentInstStack; - - IRModuleInst* moduleInst = module->getModuleInst(); - parentInstStack.add(moduleInst); - - // Add to list - instsOut.add(moduleInst); - - // Traverse all of the instructions - while (parentInstStack.getCount()) - { - // If it's in the stack it is assumed it is already in the inst map - IRInst* parentInst = parentInstStack.getLast(); - parentInstStack.removeLast(); - - IRInstListBase childrenList = parentInst->getDecorationsAndChildren(); - for (IRInst* child : childrenList) - { - instsOut.add(child); - parentInstStack.add(child); - } - } -} - -/* static */SlangResult IRSerialUtil::verifySerialize(IRModule* module, Session* session, SourceManager* sourceManager, IRSerialBinary::CompressionType compressionType, IRSerialWriter::OptionFlags optionFlags) -{ - // Verify if we can stream out with debug information - - List originalInsts; - calcInstructionList(module, originalInsts); - - IRSerialData serialData; - { - // Write IR out to serialData - copying over SourceLoc information directly - IRSerialWriter writer; - SLANG_RETURN_ON_FAIL(writer.write(module, sourceManager, optionFlags, &serialData)); - } - - // Write the data out to stream - MemoryStream memoryStream(FileAccess::ReadWrite); - SLANG_RETURN_ON_FAIL(IRSerialWriter::writeStream(serialData, compressionType, &memoryStream)); - - // Reset stream - memoryStream.Seek(SeekOrigin::Start, 0); - - IRSerialData readData; - - SLANG_RETURN_ON_FAIL(IRSerialReader::readStream(&memoryStream, &readData)); - - // Check the stream read data is the same - if (readData != serialData) - { - SLANG_ASSERT(!"Streamed in data doesn't match"); - return SLANG_FAIL; - } - - RefPtr irReadModule; - - SourceManager workSourceManager; - workSourceManager.initialize(sourceManager, nullptr); - - { - IRSerialReader reader; - SLANG_RETURN_ON_FAIL(reader.read(serialData, session, &workSourceManager, irReadModule)); - } - - List readInsts; - calcInstructionList(irReadModule, readInsts); - - if (readInsts.getCount() != originalInsts.getCount()) - { - SLANG_ASSERT(!"Instruction counts don't match"); - return SLANG_FAIL; - } - - if (optionFlags & IRSerialWriter::OptionFlag::RawSourceLocation) - { - SLANG_ASSERT(readInsts[0] == originalInsts[0]); - // All the source locs should be identical - for (Index i = 1; i < readInsts.getCount(); ++i) - { - IRInst* origInst = originalInsts[i]; - IRInst* readInst = readInsts[i]; - - if (origInst->sourceLoc.getRaw() != readInst->sourceLoc.getRaw()) - { - SLANG_ASSERT(!"Source locs don't match"); - return SLANG_FAIL; - } - } - } - else if (optionFlags & IRSerialWriter::OptionFlag::DebugInfo) - { - // They should be on the same line nos - for (Index i = 1; i < readInsts.getCount(); ++i) - { - IRInst* origInst = originalInsts[i]; - IRInst* readInst = readInsts[i]; - - if (origInst->sourceLoc.getRaw() == readInst->sourceLoc.getRaw()) - { - continue; - } - - // Work out the - SourceView* origSourceView = sourceManager->findSourceView(origInst->sourceLoc); - SourceView* readSourceView = workSourceManager.findSourceView(readInst->sourceLoc); - - // if both are null we are done - if (origSourceView == nullptr && origSourceView == readSourceView) - { - continue; - } - SLANG_ASSERT(origSourceView && readSourceView); - - { - auto origInfo = origSourceView->getHumaneLoc(origInst->sourceLoc, SourceLocType::Actual); - auto readInfo = readSourceView->getHumaneLoc(readInst->sourceLoc, SourceLocType::Actual); - - if (!(origInfo.line == readInfo.line && origInfo.column == readInfo.column && origInfo.pathInfo.foundPath == readInfo.pathInfo.foundPath)) - { - SLANG_ASSERT(!"Debug data didn't match"); - return SLANG_FAIL; - } - } - - // We may have adjusted line numbers -> but they may not match, because we only reconstruct one view - // So for now disable this test - - if (false) - { - auto origInfo = origSourceView->getHumaneLoc(origInst->sourceLoc, SourceLocType::Nominal); - auto readInfo = readSourceView->getHumaneLoc(readInst->sourceLoc, SourceLocType::Nominal); - - if (!(origInfo.line == readInfo.line && origInfo.column == readInfo.column && origInfo.pathInfo.foundPath == readInfo.pathInfo.foundPath)) - { - SLANG_ASSERT(!"Debug data didn't match"); - return SLANG_FAIL; - } - } - } - } - - return SLANG_OK; -} - -// !!!!!!!!!!!!!!!!!!!!!!!!!!!! Free functions !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - -#if 0 - -Result serializeModule(IRModule* module, SourceManager* sourceManager, Stream* stream) -{ - IRSerialWriter serializer; - IRSerialData serialData; - - SLANG_RETURN_ON_FAIL(serializer.write(module, sourceManager, IRSerialWriter::OptionFlag::RawSourceLocation, &serialData)); - - if (stream) - { - SLANG_RETURN_ON_FAIL(IRSerialWriter::writeStream(serialData, IRSerialBinary::CompressionType::VariableByteLite, stream)); - } - - return SLANG_OK; -} - -Result readModule(Session* session, Stream* stream, RefPtr& moduleOut) -{ - IRSerialData serialData; - IRSerialReader::readStream(stream, &serialData); - - IRSerialReader reader; - return reader.read(serialData, session, moduleOut); -} - -#endif - -} // namespace Slang diff --git a/source/slang/ir-serialize.h b/source/slang/ir-serialize.h deleted file mode 100644 index 8be852e95..000000000 --- a/source/slang/ir-serialize.h +++ /dev/null @@ -1,549 +0,0 @@ -// ir-serialize.h -#ifndef SLANG_IR_SERIALIZE_H_INCLUDED -#define SLANG_IR_SERIALIZE_H_INCLUDED - -#include "../core/basic.h" -#include "../core/stream.h" - -#include "../core/slang-object-scope-manager.h" - -#include "ir.h" - -// For TranslationUnitRequest -#include "compiler.h" - -namespace Slang { - -class StringRepresentationCache -{ - public: - typedef StringSlicePool::Handle Handle; - - struct Entry - { - uint32_t m_startIndex; - uint32_t m_numChars; - RefObject* m_object; ///< Could be nullptr, Name, or StringRepresentation. - }; - - /// Get as a name - Name* getName(Handle handle); - /// Get as a string - String getString(Handle handle); - /// Get as string representation - StringRepresentation* getStringRepresentation(Handle handle); - /// Get as a string slice - UnownedStringSlice getStringSlice(Handle handle) const; - /// Get as a 0 terminated 'c style' string - char* getCStr(Handle handle); - - /// Initialize a cache to use a string table, namePool and scopeManager - void init(const List* stringTable, NamePool* namePool, ObjectScopeManager* scopeManager); - - /// Ctor - StringRepresentationCache(); - - protected: - ObjectScopeManager* m_scopeManager; - NamePool* m_namePool; - const List* m_stringTable; - List m_entries; -}; - -struct SerialStringTableUtil -{ - /// Convert a pool into a string table - static void encodeStringTable(const StringSlicePool& pool, List& stringTable); - static void encodeStringTable(const UnownedStringSlice* slices, size_t numSlices, List& stringTable); - /// Appends the decoded strings into slicesOut - static void appendDecodedStringTable(const List& stringTable, List& slicesOut); - /// Decodes a string table (and does so such that the indices are compatible with StringSlicePool) - static void decodeStringTable(const List& stringTable, List& slicesOut); - - /// Produces an index map, from slices to indices in pool - static void calcStringSlicePoolMap(const List& slices, StringSlicePool& pool, List& indexMap); -}; - -// Pre-declare -class Name; - -struct IRSerialData -{ - typedef IRSerialData ThisType; - - enum class InstIndex : uint32_t; - enum class StringIndex : uint32_t; - enum class ArrayIndex : uint32_t; - - enum class RawSourceLoc : SourceLoc::RawValue; ///< This is just to copy over source loc data (ie not strictly serialize) - enum class StringOffset : uint32_t; ///< Offset into the m_stringsBuffer - - typedef uint32_t SizeType; - - static const StringIndex kNullStringIndex = StringIndex(StringSlicePool::kNullHandle); - static const StringIndex kEmptyStringIndex = StringIndex(StringSlicePool::kEmptyHandle); - - /// A run of instructions - struct InstRun - { - typedef InstRun ThisType; - SLANG_FORCE_INLINE bool operator==(const ThisType& rhs) const - { - return m_parentIndex == rhs.m_parentIndex && - m_startInstIndex == rhs.m_startInstIndex && - m_numChildren == rhs.m_numChildren; - } - SLANG_FORCE_INLINE bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } - - InstIndex m_parentIndex; ///< The parent instruction - InstIndex m_startInstIndex; ///< The index to the first instruction - SizeType m_numChildren; ///< The number of children - }; - - struct SourceLocRun - { - typedef SourceLocRun ThisType; - - bool operator==(const ThisType& rhs) const { return m_sourceLoc == rhs.m_sourceLoc && m_startInstIndex == rhs.m_startInstIndex && m_numInst == rhs.m_numInst; } - bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } - bool operator<(const ThisType& rhs) const { return m_sourceLoc < rhs.m_sourceLoc; } - - uint32_t m_sourceLoc; ///< The source location - InstIndex m_startInstIndex; ///< The index to the first instruction - SizeType m_numInst; ///< The number of children - }; - - struct PayloadInfo - { - uint8_t m_numOperands; - uint8_t m_numStrings; - }; - - struct DebugSourceInfo - { - typedef DebugSourceInfo ThisType; - - bool operator==(const ThisType& rhs) const - { - return m_pathIndex == rhs.m_pathIndex && - m_startSourceLoc == rhs.m_startSourceLoc && - m_endSourceLoc == rhs.m_endSourceLoc && - m_numLineInfos == rhs.m_numLineInfos && - m_lineInfosStartIndex == rhs.m_lineInfosStartIndex && - m_numLineInfos == rhs.m_numLineInfos; - } - bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } - - bool isSourceLocInRange(uint32_t sourceLoc) const { return sourceLoc >= m_startSourceLoc && sourceLoc <= m_endSourceLoc; } - - StringIndex m_pathIndex; ///< Index to the string table - uint32_t m_startSourceLoc; ///< The offset to the source - uint32_t m_endSourceLoc; ///< The number of bytes in the source - - uint32_t m_numLines; ///< Total number of lines in source file - - uint32_t m_lineInfosStartIndex; ///< Index into m_debugLineInfos - uint32_t m_numLineInfos; ///< The number of line infos - - uint32_t m_adjustedLineInfosStartIndex; ///< Adjusted start index - uint32_t m_numAdjustedLineInfos; ///< The number of line infos - }; - - struct DebugLineInfo - { - typedef DebugLineInfo ThisType; - bool operator<(const ThisType& rhs) const { return m_lineStartOffset < rhs.m_lineStartOffset; } - bool operator==(const ThisType& rhs) const - { - return m_lineStartOffset == rhs.m_lineStartOffset && - m_lineIndex == rhs.m_lineIndex; - } - bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } - - uint32_t m_lineStartOffset; ///< The offset into the source file - uint32_t m_lineIndex; ///< Original line index - }; - - struct DebugAdjustedLineInfo - { - typedef DebugAdjustedLineInfo ThisType; - bool operator==(const ThisType& rhs) const - { - return m_lineInfo == rhs.m_lineInfo && - m_adjustedLineIndex == rhs.m_adjustedLineIndex && - m_pathStringIndex == rhs.m_pathStringIndex; - } - bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } - bool operator<(const ThisType& rhs) const { return m_lineInfo < rhs.m_lineInfo; } - - DebugLineInfo m_lineInfo; - uint32_t m_adjustedLineIndex; ///< The line index with the adjustment (if there is any). Is 0 if m_pathStringIndex is 0. - StringIndex m_pathStringIndex; ///< The path as an index - }; - - // Instruction... - // We can store SourceLoc values separately. Just store per index information. - // Parent information is stored in m_childRuns - // Decoration information is stored in m_decorationRuns - struct Inst - { - typedef Inst ThisType; - enum - { - kMaxOperands = 2, ///< Maximum number of operands that can be held in an instruction (otherwise held 'externally') - }; - - // NOTE! Can't change order or list without changing appropriate s_payloadInfos - enum class PayloadType : uint8_t - { - // First 3 must be in this order so a cast from 0-2 is directly represented as number of operands - Empty, ///< Has no payload (or operands) - Operand_1, ///< 1 Operand - Operand_2, ///< 2 Operands - - OperandAndUInt32, ///< 1 Operand and a single UInt32 - OperandExternal, ///< Operands are held externally - String_1, ///< 1 String - String_2, ///< 2 Strings - UInt32, ///< Holds an unsigned 32 bit integral (might represent a type) - Float64, - Int64, - - CountOf, - }; - - /// Get the number of operands - SLANG_FORCE_INLINE int getNumOperands() const; - - bool operator==(const ThisType& rhs) const; - - SLANG_FORCE_INLINE bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } - - uint8_t m_op; ///< For now one of IROp - PayloadType m_payloadType; ///< The type of payload - uint16_t m_pad0; ///< Not currently used - - InstIndex m_resultTypeIndex; //< 0 if has no type. The result type of this instruction - - struct ExternalOperandPayload - { - ArrayIndex m_arrayIndex; ///< Index into the m_externalOperands table - SizeType m_size; ///< The amount of entries in that table - }; - - struct OperandAndUInt32 - { - InstIndex m_operand; - uint32_t m_uint32; - }; - - union Payload - { - double m_float64; - int64_t m_int64; - uint32_t m_uint32; ///< Unsigned integral value - IRFloatingPointValue m_float; ///< Floating point value - IRIntegerValue m_int; ///< Integral value - InstIndex m_operands[kMaxOperands]; ///< For items that 2 or less operands it can use this. - StringIndex m_stringIndices[kMaxOperands]; - ExternalOperandPayload m_externalOperand; ///< Operands are stored in an an index of an operand array - OperandAndUInt32 m_operandAndUInt32; - }; - - Payload m_payload; - }; - - /// Clear to initial state - void clear(); - /// Get the operands of an instruction - SLANG_FORCE_INLINE int getOperands(const Inst& inst, const InstIndex** operandsOut) const; - - /// == - bool operator==(const ThisType& rhs) const; - SLANG_FORCE_INLINE bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } - - /// Calculate the amount of memory used by this IRSerialData - size_t calcSizeInBytes() const; - - /// Ctor - IRSerialData(); - - List m_insts; ///< The instructions - - List m_childRuns; ///< Holds the information about children that belong to an instruction - - List m_externalOperands; ///< Holds external operands (for instructions with more than kNumOperands) - - List m_stringTable; ///< All strings. Indexed into by StringIndex - - List m_rawSourceLocs; ///< A source location per instruction (saved without modification from IRInst)s - - // Data only set if we have debug information - - List m_debugStringTable; ///< String table for debug use only - List m_debugLineInfos; ///< Debug line information - List m_debugAdjustedLineInfos; ///< Adjusted line infos - List m_debugSourceInfos; ///< Debug source information - List m_debugSourceLocRuns; ///< Runs of instructions that use a source loc - - static const PayloadInfo s_payloadInfos[int(Inst::PayloadType::CountOf)]; -}; - -// -------------------------------------------------------------------------- -SLANG_FORCE_INLINE int IRSerialData::Inst::getNumOperands() const -{ - return (m_payloadType == PayloadType::OperandExternal) ? m_payload.m_externalOperand.m_size : s_payloadInfos[int(m_payloadType)].m_numOperands; -} - -// -------------------------------------------------------------------------- -SLANG_FORCE_INLINE bool IRSerialData::Inst::operator==(const ThisType& rhs) const -{ - if (m_op == rhs.m_op && - m_payloadType == rhs.m_payloadType && - m_resultTypeIndex == rhs.m_resultTypeIndex) - { - switch (m_payloadType) - { - case PayloadType::Empty: - { - return true; - } - case PayloadType::Operand_1: - case PayloadType::String_1: - case PayloadType::UInt32: - { - return m_payload.m_operands[0] == rhs.m_payload.m_operands[0]; - } - case PayloadType::OperandAndUInt32: - case PayloadType::OperandExternal: - case PayloadType::Operand_2: - case PayloadType::String_2: - { - return m_payload.m_operands[0] == rhs.m_payload.m_operands[0] && - m_payload.m_operands[1] == rhs.m_payload.m_operands[1]; - } - case PayloadType::Float64: - case PayloadType::Int64: - { - return m_payload.m_int64 == rhs.m_payload.m_int64; - } - default: break; - } - } - - return false; -} -// -------------------------------------------------------------------------- -SLANG_FORCE_INLINE int IRSerialData::getOperands(const Inst& inst, const InstIndex** operandsOut) const -{ - if (inst.m_payloadType == Inst::PayloadType::OperandExternal) - { - *operandsOut = m_externalOperands.begin() + int(inst.m_payload.m_externalOperand.m_arrayIndex); - return int(inst.m_payload.m_externalOperand.m_size); - } - else - { - *operandsOut = inst.m_payload.m_operands; - return s_payloadInfos[int(inst.m_payloadType)].m_numOperands; - } -} - - -#define SLANG_FOUR_CC(c0, c1, c2, c3) ((uint32_t(c0) << 0) | (uint32_t(c1) << 8) | (uint32_t(c2) << 16) | (uint32_t(c3) << 24)) - -#define SLANG_MAKE_COMPRESSED_FOUR_CC(fourCc) (((fourCc) & 0xffff00ff) | (uint32_t('c') << 8)) - -struct IRSerialBinary -{ - // http://fileformats.archiveteam.org/wiki/RIFF - // http://www.fileformat.info/format/riff/egff.htm - - struct Chunk - { - uint32_t m_type; - uint32_t m_size; - }; - - enum class CompressionType - { - None, - VariableByteLite, - }; - - - static const uint32_t kRiffFourCc = SLANG_FOUR_CC('R', 'I', 'F', 'F'); - static const uint32_t kSlangFourCc = SLANG_FOUR_CC('S', 'L', 'N', 'G'); ///< Holds all the slang specific chunks - - static const uint32_t kInstFourCc = SLANG_FOUR_CC('S', 'L', 'i', 'n'); - static const uint32_t kChildRunFourCc = SLANG_FOUR_CC('S', 'L', 'c', 'r'); - static const uint32_t kExternalOperandsFourCc = SLANG_FOUR_CC('S', 'L', 'e', 'o'); - - static const uint32_t kCompressedInstFourCc = SLANG_MAKE_COMPRESSED_FOUR_CC(kInstFourCc); - static const uint32_t kCompressedChildRunFourCc = SLANG_MAKE_COMPRESSED_FOUR_CC(kChildRunFourCc); - static const uint32_t kCompressedExternalOperandsFourCc = SLANG_MAKE_COMPRESSED_FOUR_CC(kExternalOperandsFourCc); - - static const uint32_t kStringFourCc = SLANG_FOUR_CC('S', 'L', 's', 't'); - - static const uint32_t kUInt32SourceLocFourCc = SLANG_FOUR_CC('S', 'r', 's', '4'); - - static const uint32_t kDebugStringFourCc = SLANG_FOUR_CC('S', 'd', 's', 't'); - static const uint32_t kDebugLineInfoFourCc = SLANG_FOUR_CC('S', 'd', 'l', 'n'); - static const uint32_t kDebugAdjustedLineInfoFourCc = SLANG_FOUR_CC('S', 'd', 'a', 'l'); - static const uint32_t kDebugSourceInfoFourCc = SLANG_FOUR_CC('S', 'd', 's', 'o'); - static const uint32_t kDebugSourceLocRunFourCc = SLANG_FOUR_CC('S', 'd', 's', 'r'); - - struct SlangHeader - { - Chunk m_chunk; - uint32_t m_compressionType; ///< Holds the compression type used (if used at all) - }; - struct ArrayHeader - { - Chunk m_chunk; - uint32_t m_numEntries; - }; - struct CompressedArrayHeader - { - Chunk m_chunk; - uint32_t m_numEntries; ///< The number of entries - uint32_t m_numCompressedEntries; ///< The amount of compressed entries - }; -}; - - -struct IRSerialWriter -{ - typedef IRSerialData Ser; - typedef IRSerialBinary Bin; - - struct OptionFlag - { - typedef uint32_t Type; - enum Enum: Type - { - RawSourceLocation = 0x01, - DebugInfo = 0x02, - }; - }; - typedef OptionFlag::Type OptionFlags; - - Result write(IRModule* module, SourceManager* sourceManager, OptionFlags options, IRSerialData* serialData); - - static Result writeStream(const IRSerialData& data, Bin::CompressionType compressionType, Stream* stream); - - /// Get an instruction index from an instruction - Ser::InstIndex getInstIndex(IRInst* inst) const { return inst ? Ser::InstIndex(m_instMap[inst]) : Ser::InstIndex(0); } - - /// Get a slice from an index - UnownedStringSlice getStringSlice(Ser::StringIndex index) const { return m_stringSlicePool.getSlice(StringSlicePool::Handle(index)); } - /// Get index from string representations - Ser::StringIndex getStringIndex(StringRepresentation* string) { return Ser::StringIndex(m_stringSlicePool.add(string)); } - Ser::StringIndex getStringIndex(const UnownedStringSlice& slice) { return Ser::StringIndex(m_stringSlicePool.add(slice)); } - Ser::StringIndex getStringIndex(Name* name) { return name ? getStringIndex(name->text) : Ser::kNullStringIndex; } - Ser::StringIndex getStringIndex(const char* chars) { return Ser::StringIndex(m_stringSlicePool.add(chars)); } - Ser::StringIndex getStringIndex(const String& string) { return Ser::StringIndex(m_stringSlicePool.add(string.getUnownedSlice())); } - - StringSlicePool& getStringPool() { return m_stringSlicePool; } - StringSlicePool& getDebugStringPool() { return m_debugStringSlicePool; } - - IRSerialWriter() : - m_serialData(nullptr) - {} - -protected: - class DebugSourceFile : public RefObject - { - public: - DebugSourceFile(SourceFile* sourceFile, SourceLoc::RawValue baseSourceLoc): - m_sourceFile(sourceFile), - m_baseSourceLoc(baseSourceLoc) - { - // Need to know how many lines there are - const List& lineOffsets = sourceFile->getLineBreakOffsets(); - - const auto numLineIndices = lineOffsets.getCount(); - - // Set none as being used initially - m_lineIndexUsed.setCount(numLineIndices); - ::memset(m_lineIndexUsed.begin(), 0, numLineIndices * sizeof(uint8_t)); - } - /// True if we have information on that line index - bool hasLineIndex(int lineIndex) const { return m_lineIndexUsed[lineIndex] != 0; } - void setHasLineIndex(int lineIndex) { m_lineIndexUsed[lineIndex] = 1; } - - SourceLoc::RawValue m_baseSourceLoc; ///< The base source location - - SourceFile* m_sourceFile; ///< The source file - List m_lineIndexUsed; ///< Has 1 if the line is used - List m_usedLineIndices; ///< Holds the lines that have been hit - - List m_lineInfos; ///< The line infos - List m_adjustedLineInfos; ///< The adjusted line infos - }; - - void _addInstruction(IRInst* inst); - Result _calcDebugInfo(); - /// Returns the remapped sourceLoc, or 0 if sourceLoc couldn't be added - void _addDebugSourceLocRun(SourceLoc sourceLoc, uint32_t startInstIndex, uint32_t numInst); - - List m_insts; ///< Instructions in same order as stored in the - - List m_decorations; ///< Holds all decorations in order of the instructions as found - List m_instWithFirstDecoration; ///< All decorations are held in this order after all the regular instructions - - Dictionary m_instMap; ///< Map an instruction to an instruction index - - StringSlicePool m_stringSlicePool; - IRSerialData* m_serialData; ///< Where the data is stored - - StringSlicePool m_debugStringSlicePool; ///< Slices held just for debug usage - - SourceLoc::RawValue m_debugFreeSourceLoc; /// Locations greater than this are free - Dictionary > m_debugSourceFileMap; - - SourceManager* m_sourceManager; ///< The source manager -}; - -struct IRSerialReader -{ - typedef IRSerialData Ser; - typedef StringRepresentationCache::Handle StringHandle; - - /// Read a stream to fill in dataOut IRSerialData - static Result readStream(Stream* stream, IRSerialData* dataOut); - - /// Read a module from serial data - Result read(const IRSerialData& data, Session* session, SourceManager* sourceManager, RefPtr& moduleOut); - - /// Get the representation cache - StringRepresentationCache& getStringRepresentationCache() { return m_stringRepresentationCache; } - - IRSerialReader(): - m_serialData(nullptr), - m_module(nullptr) - { - } - - protected: - - static Result _skip(const IRSerialBinary::Chunk& chunk, Stream* stream, int64_t* remainingBytesInOut); - - StringRepresentationCache m_stringRepresentationCache; - - const IRSerialData* m_serialData; - IRModule* m_module; -}; - -struct IRSerialUtil -{ - /// Produces an instruction list which is in same order as written through IRSerialWriter - static void calcInstructionList(IRModule* module, List& instsOut); - - /// Verify serialization - static SlangResult verifySerialize(IRModule* module, Session* session, SourceManager* sourceManager, IRSerialBinary::CompressionType compressionType, IRSerialWriter::OptionFlags optionFlags); -}; - - -} // namespace Slang - -#endif diff --git a/source/slang/ir-specialize-resources.cpp b/source/slang/ir-specialize-resources.cpp deleted file mode 100644 index 96f328672..000000000 --- a/source/slang/ir-specialize-resources.cpp +++ /dev/null @@ -1,865 +0,0 @@ -// ir-specialize-resources.cpp -#include "ir-specialize-resources.h" - -#include "ir.h" -#include "ir-clone.h" -#include "ir-insts.h" - -namespace Slang -{ - -struct ResourceParameterSpecializationContext -{ - // This type implements a pass to specialize functions - // with resource parameters to ensure that they are - // legal for a given target. - // - // We start with member variables to stand in for - // the parameters that were passed to the top-level - // `specializeResourceParameters` function. - // - BackEndCompileRequest* compileRequest; - TargetRequest* targetRequest; - IRModule* module; - - // Our general approach will be to think in terms - // of specializing call sites, which amount to - // `IRCall` instructions. We will keep a work list - // of call sites in the program that may be worth - // considering for specialization. - // - List workList; - - // Because we may need to generate specialized functions - // and generate new calls to those functions, we'll - // need some IR building state to get our work done. - // - SharedIRBuilder sharedBuilderStorage; - IRBuilder builderStorage; - IRBuilder* getBuilder() { return &builderStorage; } - - // With the basic state out of the way, let's walk - // through the overall flow of the pass. - // - void processModule() - { - // We will start by initializing our IR building state. - // - sharedBuilderStorage.module = module; - sharedBuilderStorage.session = module->getSession(); - builderStorage.sharedBuilder = &sharedBuilderStorage; - - // Next we will populate our initial work list by - // recursively finding every single call site in the module. - // - addCallsToWorkListRec(module->getModuleInst()); - - // We will process the work list until it goes dry, - // treating it like a stack of work items. - // - while( workList.getCount() ) - { - auto call = workList.getLast(); - workList.removeLast(); - - // At each call site we first check whether it - // is something we can (and should) specialize, - // and if so, do it. The process of specializing - // a function may introduce new call sites that - // become candidates for specialization, so - // our work list may grow along the way. - // - if( canSpecializeCall(call) ) - { - specializeCall(call); - } - } - } - - // Setting up the work list is a simple recursive procedure. - // - void addCallsToWorkListRec(IRInst* inst) - { - // If we have a call site, then add it to the list. - // - if( auto call = as(inst) ) - { - workList.add(call); - } - - // Recursively walk through any children, to - // see if we uncover more call sites. - // - for( auto child : inst->getChildren() ) - { - addCallsToWorkListRec(child); - } - } - - // We need a way to decide for a given call site - // whether we can/must specialize it. - // - bool canSpecializeCall(IRCall* call) - { - // We can only specialize calls where the callee - // func can be statically identified, and where - // the callee is a definition (with body) rather - // than a declaration. Otherwise there is no - // way to generate a specialized callee function. - // - auto func = as(call->getCallee()); - if(!func) - return false; - if(!func->isDefinition()) - return false; - - // With the basic checks out of the way, there are - // two conditions we care about: - // - // 1. Should we specialize? This amounts to whether - // `func` has any parameters that need specialization. - // We will call those "specializable" parameters for - // lack of a better name. - // - // 2. Can we specialize? This amounts to whether the - // arguments in `call` that correspond to those - // specializable parameters are "suitable" for use - // in specialization. - // - // We are going to answer both of these queries in - // a single loop that walks over the parameters of - // `func` as well as the arguments to `call`. - // - // The loop may seem a bit awkward because we are - // doing a parallel iteration over a linked list - // (the parameters of `func`) and an array (the - // arguments of `call`). - // - bool anySpecializableParam = false; - UInt argCounter = 0; - for( auto param : func->getParams() ) - { - UInt argIndex = argCounter++; - SLANG_ASSERT(argIndex < call->getArgCount()); - auto arg = call->getArg(argIndex); - - // If the given parameter doesn't need specialization, - // then we need to keep looking. - // - if(!doesParamNeedSpecialization(param)) - continue; - - // If we have run into a `param` that needs specialization, - // then our first condition is met. - // - anySpecializableParam = true; - - // Now we need to check whether `arg` is actually suitable - // for specialization (our second condition). If not, we - // can bail out immediately because our second condition - // cannot be met. - // - if(!isArgSuitableForSpecialization(arg)) - return false; - } - - // If we exit the loop, then the second condition must have - // been met (all the arguments for specializable parameters - // were suitable for specialization), and the result of the - // query comes down to the first condition. - // - return anySpecializableParam; - } - - // Of course, now we need to back-fill the predicates that - // the above function used to evaluate prameters and arguments. - - bool doesParamNeedSpecialization(IRParam* param) - { - // Whether or not a parameter needs specialization is really - // a function of its type: - // - IRType* type = param->getDataType(); - - // What's more, if a parameter of type `T` would need - // specialization, then it seems clear that a parameter - // of type "array of `T`" would also need specialization. - // We will "unwrap" any outer arrays from the parameter - // type before moving on, since they won't affect - // our decision. - // - type = unwrapArray(type); - - // On all of our (current) targets, a function that - // takes a `ConstantBuffer` parameter requires - // specialization. Surprisingly this includes DXIL - // because dxc apparently does not treat `ConstantBuffer` - // as a first-class type. - // - if(as(type)) - return true; - - // For GL/Vulkan targets, we also need to specialize - // any parameters that use structured or byte-addressed - // buffers. - // - if( isKhronosTarget(targetRequest) ) - { - if(as(type)) - return true; - if(as(type)) - return true; - } - - // For now, we will not treat any other parameters as - // needing specialization, even if they use resource - // types like `Texure2D`, because these are allowed - // as function parameters in both HLSL and GLSL. - // - // TODO: Eventually, if we start generating SPIR-V - // directly rather than through glslang, we will need - // to specialize *all* resource-type parameters - // to follow the restrictions in the spec. - // - // TODO: We may want to perform more aggressive - // specialization in general, especially insofar - // as it could simplify the task of supporting - // functions with resource-type outputs. - - return false; - } - - bool isArgSuitableForSpecialization(IRInst* inArg) - { - // Determining if an argument is suitable for - // specializing a callee function requires - // looking at its (recurisve) structure. - // - // Rather than write a recursively procedure - // here, we will be tail-recursive by using - // a simple loop. - // - IRInst* arg = inArg; - for(;;) - { - // The leaf case we care about is when the - // argument at the call site is a global - // shader parameter, because then we can - // specialize a callee to refer to the same - // global parameter directly. - // - if(as(arg)) return true; - - // As we will see later, we can also - // specialize a call when the argument - // is the result of indexing into an - // array (`base[index]`) *if* the `base` - // of the indexing operation is also - // suitable for specialization. - // - if( arg->op == kIROp_getElement ) - { - auto base = arg->getOperand(0); - - // We will "recurse" on the base of - // the indexing operation by continuing - // our loop with the `base` as our new - // argument. - // - arg = base; - continue; - } - - // By default, we will *not* consider an argument - // suitable for specialization. - // - // TODO: There may be other cases that are worth - // handling here. The current code is based on - // observation of what simple shaders do in - // practice. - // - return false; - } - } - - // Once we'e determined that a given call site can/should - // be specialized, we need to perform the actual specialization. - // This is where things are going to get more involved. - // - // There are a few different concerns we need to deal with - // that mean we end up having two different passes that walk - // over the parameters/arguments of the call (in addition to - // the ones we had above for determining if we can/should - // specialize in the first place). - // - // The first of the two passes determines information - // relevant to the call site, comprising both the arguments - // that will be passed to the specialized function as - // well as a "key" to identify the specialized function - // that is required. - // - // We will use the key type defined as part of the IR cloning - // infrastructure, which uses a sequence of `IRInst*`s - // to hold the state of the key: - // - typedef IRSimpleSpecializationKey Key; - - // As indicated above, the information we collect about a call - // site consists of the key for the specialized function we - // will call, and a list of the arguments that will be passed - // to the call. - // - struct CallSpecializationInfo - { - Key key; - List newArgs; - }; - - // Once we've collected the information about a call site - // we can use a dictionary to see if we already created - // a specialized version of the callee that matches its - // requirements. - // - Dictionary specializedFuncs; - - // If the dictionary didn't have a specialized function - // suitable for a call site, we need a second information-gathering - // pass to decide what the new parameters of the specialized - // functions should be, and what instructions the new function - // must execute in its body to set up the replacements for the - // old parameters. - // - struct FuncSpecializationInfo - { - List newParams; - List newBodyInsts; - List replacementsForOldParameters; - }; - - // Before diving into how the different passes collect - // their information, we will dive into the main - // specialization logic first. - // - void specializeCall(IRCall* oldCall) - { - // We have an existing call site `oldCall` that - // we know can and should be specialized. - // - // That means the callee should be a known function - // definition, or else `canSpecializeCall` didn't - // correctly check the preconditions. - // - auto oldFunc = as(oldCall->getCallee()); - SLANG_ASSERT(oldFunc); - SLANG_ASSERT(oldFunc->isDefinition()); - - // Our first information-gathering pass will - // compute the key for the specialized function - // we want to call, and the arguments we will - // use for that call. - // - CallSpecializationInfo callInfo; - gatherCallInfo(oldCall, oldFunc, callInfo); - - // Once we have gathered information on the call, - // we can check if we have an existing specialization - // that we generated before (for another call site) - // that is suitable to this call site. - // - IRFunc* newFunc = nullptr; - if( !specializedFuncs.TryGetValue(callInfo.key, newFunc) ) - { - // If we didn't find a pre-existing specialized - // function, then we will go ahead and create one. - // - // We start by gathering the information from the call - // site that is relevant to generating a specialized - // callee function, which we avoided doing earlier - // because it might have been throwaway work. - // - FuncSpecializationInfo funcInfo; - gatherFuncInfo(oldCall, oldFunc, funcInfo); - - // Now we use the gathered information to generate - // a new callee function based on the original - // function and the information we gathered. - // - newFunc = generateSpecializedFunc(oldFunc, funcInfo); - specializedFuncs.Add(callInfo.key, newFunc); - } - - // Once we've other found or generated a specialized function - // we need to generate a call to it, and then use the new - // call as a replacement for the old one. - // - auto newCall = getBuilder()->emitCallInst( - oldCall->getFullType(), - newFunc, - callInfo.newArgs.getCount(), - callInfo.newArgs.getBuffer()); - - newCall->insertBefore(oldCall); - oldCall->replaceUsesWith(newCall); - oldCall->removeAndDeallocate(); - } - - // Before diving into the details on how we gather information - // and specialize callees, lets stop to think about what we'd - // like to do in terms of individual parameters and arguments. - // - // Suppose we are specializing both a call site C and the callee - // function F, and we are consisering a particular pair of - // a parmeter P of F, and an argument A at the call site. - // - // The full extent of information we might want to know given - // P and A is: - // - // * What arguments need to be added to the specialized call? - // * What parameters need to be added to the specialized callee? - // * What instructions are needed in the body of the specialized - // callee to synthesize the value that will stand in for P? - // * What information, if any, needs to be used to distinguish - // this specialized callee from others that might be generated for F? - // - // An easy case is when P is a parameter that doesn't need - // specialization. In that case: - // - // * The existing argument A should be used as an argument in - // the specialized call. - // * A clone P' of the existing parameter P should be used as a - // parameter of the specialized callee. - // * No additional instructions are needed in the body of - // the callee; the cloned parameter P' should stand in for P. - // * No information should be added to the specialization key - // based on P and A. - // - // The more interesting case is when P has a resource type, and - // A is some global shader parameter G. - // - // * No argument should be added at the new call site - // * No parameter should be added to the specialized callee - // * No additional instructions are needed in the body of - // the callee; the global G should stand in for P. - // * The global G should be used to distinguish this specialized - // callee from those that might be specialized for a different - // global shader parameter. - // - // As a final example, imagine that P is still a resource type, - // but A is now an indexing operation into an array: `G[idx]`: - // - // * An argument for `idx` should be added at the call site - // * A parameter `p_idx` with the same type as `idx` should be added - // to the specialized callee. - // * An instruction should be added to the specialized callee - // to compute `G[p_idx]` and use that to stand in for P. - // * The global G should still be used to distinguish this specialized - // call site from others. - // - // That's a lot of examples, I know, but hopefully it gives a - // sense of the information we are tracking and how it differs - // across the various cases. While the example only covered one - // level of indexing, the actual implementation will handle the - // case of arbitrarily many levels of indexing, which can mean - // piping through any number of additional integer parameters - // to the callee. - - // The information we gather for a call site (before we know - // whether a specialize calle is needed) is just the new - // argument list, and the "key" information that distinguishes - // what specialized callee we want/need. - // - void gatherCallInfo( - IRCall* oldCall, - IRFunc* oldFunc, - CallSpecializationInfo& callInfo) - { - // The specialized callee key always needs to include - // the original function, since different functions - // will always yield different specializations. - // - callInfo.key.vals.add(oldFunc); - - // The rest of the information is gathered by looking - // at parameter and argument pairs. - // - UInt oldArgCounter = 0; - for( auto oldParam : oldFunc->getParams() ) - { - UInt oldArgIndex = oldArgCounter++; - auto oldArg = oldCall->getArg(oldArgIndex); - - getCallInfoForParam(callInfo, oldParam, oldArg); - } - } - - void getCallInfoForParam( - CallSpecializationInfo& ioInfo, - IRParam* oldParam, - IRInst* oldArg) - { - // We know that the case where a parameter - // doesn't need specialization is easy. - // - if( !doesParamNeedSpecialization(oldParam) ) - { - // The new call site will use the same argument - // value as the old one, and we don't need - // to add any information to distinguish the - // specialized callee based on this paramter. - // - ioInfo.newArgs.add(oldArg); - } - else - { - // If specialization is needed, we need - // to inspect the argument value. This - // is handled with a different function - // because it needs to recurse in some cases. - // - getCallInfoForArg(ioInfo, oldArg); - } - } - - void getCallInfoForArg( - CallSpecializationInfo& ioInfo, - IRInst* oldArg) - { - // The base case we care about is when the original - // argument is a global shader parameter. - // - if( auto oldGlobalParam = as(oldArg) ) - { - // In this case we don't need to pass anything - // as an argument at the new call site (the - // global parameter will get specialized into - // the callee), but we *do* need to make sure - // that our key for identifying the specialized - // callee reflects that we are specializing - // to the chosen parameter. - // - ioInfo.key.vals.add(oldGlobalParam); - } - else if( oldArg->op == kIROp_getElement ) - { - // This is the case where the `oldArg` is - // in the form `oldBase[oldIndex]` - // - auto oldBase = oldArg->getOperand(0); - auto oldIndex = oldArg->getOperand(1); - - // Effectively, we act as if `oldBase` and - // `oldIndex` were passed to the callee separately, - // so that `oldBase` is an array-of-resouces and - // `oldIndex` is an ordinary integer argument. - // - // We start by recursively setting up whatever - // `oldBase` needs: - // - getCallInfoForArg(ioInfo, oldBase); - - // Then we process `oldIndex` just like we - // would have an ordinary argument that doesn't - // involve specialization: add its value to - // the arguments at the new call site, and - // don't add anything to the specialization key. - // - ioInfo.newArgs.add(oldIndex); - } - else - { - // If we fail to match any of the cases above - // then a precondition was violated in that - // `isArgSuitableForSpecialization` is allowing - // a case that this routine is not covering. - // - SLANG_UNEXPECTED("mising case in 'getCallInfoForArg'"); - } - } - - // The remaining information we've discussed is only - // gathered once we decide we want to generate a - // specialized function, but it follows much the same flow. - // - void gatherFuncInfo( - IRCall* oldCall, - IRFunc* oldFunc, - FuncSpecializationInfo& funcInfo) - { - UInt oldArgCounter = 0; - for( auto oldParam : oldFunc->getParams() ) - { - UInt oldArgIndex = oldArgCounter++; - auto oldArg = oldCall->getArg(oldArgIndex); - - // For each parameter and argument pair we will - // frame the main task as producing a value that - // will stand in for the parameter in the specialized - // function. - // - auto newVal = getSpecializedValueForParam(funcInfo, oldParam, oldArg); - - // We will collect the replacement value to use - // for each of the original parameters in an array. - // - funcInfo.replacementsForOldParameters.add(newVal); - } - } - - IRInst* getSpecializedValueForParam( - FuncSpecializationInfo& ioInfo, - IRParam* oldParam, - IRInst* oldArg) - { - // As always, the easy case is when the parameter of - // the original function doesn't need specialization. - // - if( !doesParamNeedSpecialization(oldParam) ) - { - // The specialized callee will need a new parameter - // that fills the same role as the old one, so we - // create it here. - // - auto newParam = getBuilder()->createParam(oldParam->getFullType()); - ioInfo.newParams.add(newParam); - - // The new parameter will be used as the replacement - // for the old one in the specialized function. - // - return newParam; - } - else - { - // If the parameter requires specialization, then it - // is time to look at the structure of the argument. - // - return getSpecializedValueForArg(ioInfo, oldArg); - } - } - - IRInst* getSpecializedValueForArg( - FuncSpecializationInfo& ioInfo, - IRInst* oldArg) - { - // The logic here parallels `gatherCallInfoForArg`, - // and only differs in what information it is gathering. - // - // As before, the base case is when we have a global - // shader parameter. - // - if( auto globalParam = as(oldArg) ) - { - // The specialized function will not need any - // parameter in this case, and the global itself - // should be used to stand in for the original - // parameter in the specialized function. - // - return globalParam; - } - else if( oldArg->op == kIROp_getElement ) - { - // This is the case where the argument is - // in the form `oldBase[oldIndex]`. - // - auto oldBase = oldArg->getOperand(0); - auto oldIndex = oldArg->getOperand(1); - - // In `gatherCallInfoForArg` this case was - // handled by acting as if `oldBase` and - // `oldIndex` were being passed as two - // separate arguments. - // - // We'll follow the same structure here, - // starting by recursively processing `oldBase` - // to get a value that can stand in for it - // in the specialized callee. - // - auto newBase = getSpecializedValueForArg(ioInfo, oldBase); - - // Next we'll process `oldIndex` as if it - // was an ordinary argument (not a specialized one), - // which means creating a parameter to receive its value, - // which will also stand in for `oldIndex` in - // the body of the specialized callee. - // - auto builder = getBuilder(); - auto newIndex = builder->createParam(oldIndex->getFullType()); - ioInfo.newParams.add(newIndex); - - // Finally, we need to compute a value that - // can stand in for `oldArg` (which was - // `oldBase[oldIndex]`) in the body of the - // specialized callee. - // - // Because we have both a `newBase` and a - // `newIndex` it is natural to construct - // `newBase[newIndex]` and use that. - // - // The only complication is that we need - // to make sure that our IR builder isn't - // set to insert newly created instructions - // anywhere, since the `emit*` functions - // will try to automatically insert new - // instructions if an insertion location - // is set. - // - builder->setInsertInto(nullptr); - auto newVal = builder->emitElementExtract( - oldArg->getFullType(), - newBase, - newIndex); - - // Because our new instruction wasn't - // actually inserted anywhere, we need to - // add it to our gathered list of instructions - // that should be inserted into the body of - // the specialized callee. - // - ioInfo.newBodyInsts.add(newVal); - - return newVal; - } - else - { - // If we don't match one of the above cases, - // then `isArgSuitableForSpecialization` is - // letting through cases that this function - // hasn't been updated to handle. - // - SLANG_UNEXPECTED("mising case in 'getSpecializedValueForArg'"); - UNREACHABLE_RETURN(nullptr); - } - } - - // With all of that data-gathering code out of the way, - // we are now prepared to walk through the process of - // specializing a given callee function based on - // the information we have gathered. - // - IRFunc* generateSpecializedFunc( - IRFunc* oldFunc, - FuncSpecializationInfo const& funcInfo) - { - // We will make use of the infrastructure for cloning - // IR code, that is defined in `ir-clone.{h,cpp}`. - // - // In order to do the cloning work we need an - // "environment" that will map old values to - // their replacements. - // - IRCloneEnv cloneEnv; - - // Next we iterate over the parameters of the old - // function, and register each as being mapped - // to its replacement in the `funcInfo` that was - // already gathered. - // - UInt paramCounter = 0; - for( auto oldParam : oldFunc->getParams() ) - { - UInt paramIndex = paramCounter++; - auto newVal = funcInfo.replacementsForOldParameters[paramIndex]; - cloneEnv.mapOldValToNew.Add(oldParam, newVal); - } - - // Next we will create the skeleton of the new - // specialized function, including its type. - // - // To get the type of the new function we will - // iterate over the collected list of new - // parameters (which may differ greatly from the - // parameter list of the original) and extract - // their types. - // - List paramTypes; - for( auto param : funcInfo.newParams ) - { - paramTypes.add(param->getFullType()); - } - - auto builder = getBuilder(); - IRType* funcType = builder->getFuncType( - paramTypes.getCount(), - paramTypes.getBuffer(), - oldFunc->getResultType()); - - IRFunc* newFunc = builder->createFunc(); - newFunc->setFullType(funcType); - - // The above step has accomplished the "first phase" - // of cloning the function (since `IRFunc`s have no - // operands). - // - // We can now use the shared IR cloning infrastructure - // to perform the second phase of cloning, which will recursively - // clone any nested decorations, blocks, and instructions. - // - cloneInstDecorationsAndChildren( - &cloneEnv, - builder->sharedBuilder, - oldFunc, - newFunc); - - // We are almost done at this point, except that `newFunc` - // is lacking its parameters, as well as any of the body - // instructions that we decided were needed during - // the information-gathering steps. - // - // We will insert these instructions into the first block - // of the function, before its first ordinary instruction. - // We know that these should exist because we had as - // a precondition that `oldFunc` was a definition (so it - // has at least one block), and in valid IR every block - // has at least one ordinary instruction (its terminator). - // - auto newEntryBlock = newFunc->getFirstBlock(); - SLANG_ASSERT(newEntryBlock); - auto newFirstOrdinary = newEntryBlock->getFirstOrdinaryInst(); - SLANG_ASSERT(newFirstOrdinary); - - // We simply iterate over the list of parameters and then - // body instructions that were produced in the information - // gathering step, and insert each before `newFirstOrdinary`, - // which has the effect or arranging them in the output - // in the order they are enumerated here. - // - for( auto newParam : funcInfo.newParams ) - { - newParam->insertBefore(newFirstOrdinary); - } - for( auto newBodyInst : funcInfo.newBodyInsts ) - { - newBodyInst->insertBefore(newFirstOrdinary); - } - - // At this point we've created a new specialized function, - // and as such it may contain call sites that were not - // covered when we built our initial work list. - // - // Before handing the specialized function back to the - // caller, we will make sure to recursively add any - // potentially-specializable call sites to our work list. - // - addCallsToWorkListRec(newFunc); - - return newFunc; - } -}; - -// The top-level function for invoking the specialization pass -// is straighforward. We set up the context object -// and then defer to it for the real work. -// -void specializeResourceParameters( - BackEndCompileRequest* compileRequest, - TargetRequest* targetRequest, - IRModule* module) -{ - ResourceParameterSpecializationContext context; - context.compileRequest = compileRequest; - context.targetRequest = targetRequest; - context.module = module; - - context.processModule(); -} - -} // namesapce Slang diff --git a/source/slang/ir-specialize-resources.h b/source/slang/ir-specialize-resources.h deleted file mode 100644 index 0e636318c..000000000 --- a/source/slang/ir-specialize-resources.h +++ /dev/null @@ -1,24 +0,0 @@ -// ir-specialize-resources.h -#pragma once - -namespace Slang -{ - class BackEndCompileRequest; - class TargetRequest; - struct IRModule; - - /// Specialize calls to functions with resource-type parameters. - /// - /// For any function that has resource-type input parameters that - /// would be invalid on the chosen target, this pass will rewrite - /// any call sites that pass suitable arguments (e.g., direct - /// references to global shader parameters) to instead call - /// a specialized variant of the function that does not have - /// those resource parameters (and instead, e.g, refers to the - /// global shader parameters directly). - /// - void specializeResourceParameters( - BackEndCompileRequest* compileRequest, - TargetRequest* targetRequest, - IRModule* module); -} diff --git a/source/slang/ir-specialize.cpp b/source/slang/ir-specialize.cpp deleted file mode 100644 index b57d2b58f..000000000 --- a/source/slang/ir-specialize.cpp +++ /dev/null @@ -1,1864 +0,0 @@ -// ir-specialize.cpp -#include "ir-specialize.h" - -#include "ir.h" -#include "ir-clone.h" -#include "ir-insts.h" - -namespace Slang -{ - -// This file implements the primary specialization pass, that takes -// generic/polymorphic Slang code and specializes/monomorphises it. -// -// At present this primarily means generating specialized copies -// of generic functions/types based on the concrete types used -// at specialization sites, and also specializing instances -// of witness-table lookup to directly refer to the concrete -// values for witnesses when witness tables are known. -// -// This pass also performs some amount of simplification and -// specialization for code using existential (interface) types -// for local variables and function parameters/results. -// -// Eventually, this pass will also need to perform specialization -// of functions to argument values for parameters that must -// be compile-time constants, -// -// All of these passes are inter-related in that applying -// simplifications/specializations of one category can open -// up opportunities for transformations in the other categories. - -struct SpecializationContext; - -IRInst* specializeGenericImpl( - IRGeneric* genericVal, - IRSpecialize* specializeInst, - IRModule* module, - SpecializationContext* context); - -struct SpecializationContext -{ - // For convenience, we will keep a pointer to the module - // we are specializing. - IRModule* module; - - // We know that we can only perform generic specialization when all - // of the arguments to a generic are also fully specialized. - // The "is fully specialized" condition is something we - // need to solve for over the program, because the fully- - // specialized-ness of an instruction depends on the - // fully-specialized-ness of its operands. - // - // We will build an explicit hash set to encode those - // instructions that are fully specialized. - // - HashSet fullySpecializedInsts; - - // An instruction is then fully specialized if and only - // if it is in our set. - // - bool isInstFullySpecialized( - IRInst* inst) - { - // A small wrinkle is that a null instruction pointer - // sometimes appears a a type, and so should be treated - // as fully specialized too. - // - // TODO: It would be nice to remove this wrinkle. - // - if(!inst) return true; - - return fullySpecializedInsts.Contains(inst); - } - - // When an instruction isn't fully specialized, but its operands *are* - // then it is a candidate for specialization itself, so we will have - // a query to check for the "all operands fully specialized" case. - // - bool areAllOperandsFullySpecialized( - IRInst* inst) - { - if(!isInstFullySpecialized(inst->getFullType())) - return false; - - UInt operandCount = inst->getOperandCount(); - for(UInt ii = 0; ii < operandCount; ++ii) - { - IRInst* operand = inst->getOperand(ii); - if(!isInstFullySpecialized(operand)) - return false; - } - - return true; - } - - // We will use a single work list of instructions that need - // to be considered for specialization or simplification, - // whether generic, existential, etc. - // - List workList; - HashSet workListSet; - - HashSet cleanInsts; - - void addToWorkList( - IRInst* inst) - { - // We will ignore any code that is nested under a generic, - // because it doesn't make sense to perform specialization - // on such code. - // - for( auto ii = inst->getParent(); ii; ii = ii->getParent() ) - { - if(as(ii)) - return; - } - - if(workListSet.Contains(inst)) - return; - - workList.add(inst); - workListSet.Add(inst); - cleanInsts.Remove(inst); - - addUsersToWorkList(inst); - } - - // When a transformation makes a change to an instruction, - // we may need to re-consider transformations for instructions - // that use its value. In those cases we will call `addUsersToWorkList` - // on the instruction that is being modified or replaced. - // - void addUsersToWorkList( - IRInst* inst) - { - for( auto use = inst->firstUse; use; use = use->nextUse ) - { - auto user = use->getUser(); - addToWorkList(user); - } - } - - // One of the main transformations we will apply is to - // consider an instruction as being fully specialized. - // - void markInstAsFullySpecialized( - IRInst* inst) - { - if(fullySpecializedInsts.Contains(inst)) - return; - fullySpecializedInsts.Add(inst); - - // If we know that an instruction is fully specialized, - // then we should start to consider its uses and children - // as candidates for being fully specialized too... - // - addUsersToWorkList(inst); - } - - - // Of course, somewhere along the way we expect - // to run into uses of `specialize(...)` instructions - // to bind a generic to arguments that we want to - // specialize into concrete code. - // - // We also know that if we encouter `specialize(g, a, b, c)` - // and then later `specialize(g, a, b, c)` again, we - // only want to generate the specialized code for `g` - // *once*, and re-use it for both versions. - // - // We will cache existing specializations of generic function/types - // using the simple key type defined as part of the IR cloning infrastructure. - // - typedef IRSimpleSpecializationKey Key; - Dictionary genericSpecializations; - - // We will also use some shared IR building state across - // all of our specialization/cloning steps. - // - SharedIRBuilder sharedBuilderStorage; - - // Now let's look at the task of finding or generation a - // specialization of some generic `g`, given a specialization - // instruction like `specialize(g, a, b, c)`. - // - // The `specializeGeneric` function will return a value - // suitable for use as a replacement for the `specialize(...)` - // instruction. - // - IRInst* specializeGeneric( - IRGeneric* genericVal, - IRSpecialize* specializeInst) - { - // First, we want to see if an existing specialization - // has already been made. To do that we will construct a key - // for lookup in the generic specialization context. - // - // Our key will consist of the identity of the generic - // being specialized, and each of the argument values - // being pased to it. In our hypothetical example of - // `specialize(g, a, b, c)` the key will then be - // the array `[g, a, b, c]`. - // - Key key; - key.vals.add(specializeInst->getBase()); - UInt argCount = specializeInst->getArgCount(); - for( UInt ii = 0; ii < argCount; ++ii ) - { - key.vals.add(specializeInst->getArg(ii)); - } - - { - // We use our generated key to look for an - // existing specialization that has been registered. - // If one is found, our work is done. - // - IRInst* specializedVal = nullptr; - if(genericSpecializations.TryGetValue(key, specializedVal)) - return specializedVal; - } - - // If no existing specialization is found, we need - // to create the specialization instead. - // This mostly amounts to evaluating the generic as - // if it were a function being called. - // - // We will use a free function to do the actual work - // of evaluating the generic, so that the logic - // can be re-used in other cases that need to - // do one-off specialization. - // - IRInst* specializedVal = specializeGenericImpl(genericVal, specializeInst, module, this); - - - // The value that was returned from evaluating - // the generic is the specialized value, and we - // need to remember it in our dictionary of - // specializations so that we don't instantiate - // this generic again for the same arguments. - // - genericSpecializations.Add(key, specializedVal); - - return specializedVal; - } - - // The logic for generating a specialization of an IR generic - // relies on the ability to "evaluate" the code in the body of - // the generic, but that obviously doesn't work if we don't - // actually have the full definition for the body. - // - // This can arise in particular for builtin operations/types. - // - // Before calling `specializeGeneric()` we need to make sure - // that the generic is actually amenable to specialization, - // by looking at whether it is a definition or a declaration. - // - bool canSpecializeGeneric( - IRGeneric* generic) - { - // It is possible to have multiple "layers" of generics - // (e.g., when a generic function is nested in a generic - // type). Therefore we need to drill down through all - // of the layers present to see if at the leaf we have - // something that looks like a definition. - // - IRGeneric* g = generic; - for(;;) - { - // Given the generic `g`, we will find the value - // it appears to return in its body. - // - auto val = findGenericReturnVal(g); - if(!val) - return false; - - // If `g` returns an inner generic, then we need - // to drill down further. - // - if (auto nestedGeneric = as(val)) - { - g = nestedGeneric; - continue; - } - - // Once we've found the leaf value that will be produced - // after all specialization is complete, we can check - // whether it looks like a definition or not. - // - return isDefinition(val); - } - } - - // Now that we know when we can specialize a generic, and how - // to do it, we can write a subroutine that takes a - // `specialize(g, a, b, c, ...)` instruction and performs - // specialization if it is possible. - // - void maybeSpecializeGeneric( - IRSpecialize* specInst) - { - // We will only attempt to specialize when all of the - // operands to the `speicalize(...)` instruction are - // themselves fully specialized. - // - if(!areAllOperandsFullySpecialized(specInst)) - return; - - // The invariant that the arguments are fully specialized - // should mean that `a, b, c, ...` are in a form that - // we can work with, but it does *not* guarantee - // that the `g` operand is something we can work with. - // - // We can only perform specialization in the case where - // the base `g` is a known `generic` instruction. - // - auto baseVal = specInst->getBase(); - auto genericVal = as(baseVal); - if(!genericVal) - return; - - // We can also only specialize a generic if it - // represents a definition rather than a declaration. - // - if(!canSpecializeGeneric(genericVal)) - return; - - // Once we know that specialization is possible, - // the actual work is fairly simple. - // - // First, we find or generate a specialized - // version of the result of the generic (a specialized - // type, function, or whatever). - // - auto specializedVal = specializeGeneric(genericVal, specInst); - - // Any uses of this `specialize(...)` instruction will - // become uses of `specializeVal`, so we want to re-consider - // them for subsequent transformations. - // - addUsersToWorkList(specInst); - - // Then we simply replace any uses of the `specialize(...)` - // instruction with the specialized value and delete - // the `specialize(...)` instruction from existence. - // - specInst->replaceUsesWith(specializedVal); - specInst->removeAndDeallocate(); - } - - // Generic specialization depends on identifying when - // instructions are fully specialized. - // - void maybeMarkAsFullySpecialized( - IRInst* inst) - { - switch(inst->op) - { - default: - // The default case is that an instruction can - // be considered as fully specialized as soon - // as all of its operands are. - // - // TODO: We realistically need a more refined - // check here that uses a white-list of instructions - // that can represent values suitable for use - // as generic arguments. - // - if(areAllOperandsFullySpecialized(inst)) - { - markInstAsFullySpecialized(inst); - } - break; - - // Certain instructions cannot ever be considered - // fully specialized because they should never - // be substituted into a generic as its arguments. - case kIROp_Specialize: - case kIROp_lookup_interface_method: - case kIROp_ExtractExistentialType: - case kIROp_BindExistentialsType: - break; - } - } - - // The core of this pass is to look at one instruction - // at a time, and try to perform whatever specialization - // is appropriate based on its opcode. - // - void maybeSpecializeInst( - IRInst* inst) - { - switch(inst->op) - { - default: - // By default we assume that specialization is - // not possible for a given opcode. - // - break; - - case kIROp_Specialize: - // The logic for specializing a `specialize(...)` - // instruction has already been elaborated above. - // - maybeSpecializeGeneric(cast(inst)); - break; - - case kIROp_lookup_interface_method: - // The remaining case we need to consider here for generics - // is when we have a `lookup_witness_method` instruction - // that is being applied to a concrete witness table, - // because we can specialize it to just be a direct - // reference to the actual witness value from the table. - // - maybeSpecializeWitnessLookup(cast(inst)); - break; - - case kIROp_Call: - // When writing functions with existential-type parameters, - // we need additional support to specialize a callee - // function based on the concrete type encapsulated in - // an argument of existential type. - // - maybeSpecializeExistentialsForCall(cast(inst)); - break; - - // The specialization of functions with existential-type - // parameters can create further opportunities for specialization, - // but in order to realize these we often need to propagate - // through local simplification on values of existential type. - // - case kIROp_ExtractExistentialType: - maybeSpecializeExtractExistentialType(inst); - break; - case kIROp_ExtractExistentialValue: - maybeSpecializeExtractExistentialValue(inst); - break; - case kIROp_ExtractExistentialWitnessTable: - maybeSpecializeExtractExistentialWitnessTable(inst); - break; - - case kIROp_Load: - maybeSpecializeLoad(as(inst)); - break; - - case kIROp_FieldExtract: - maybeSpecializeFieldExtract(as(inst)); - break; - case kIROp_FieldAddress: - maybeSpecializeFieldAddress(as(inst)); - break; - - case kIROp_BindExistentialsType: - maybeSpecializeBindExistentialsType(as(inst)); - break; - } - } - - // Specializing lookup on witness tables is a general - // transformation that helps with both generic and - // existential-based code. - // - void maybeSpecializeWitnessLookup( - IRLookupWitnessMethod* lookupInst) - { - // Note: While we currently have named the instruction - // `lookup_witness_method`, the `method` part is a misnomer - // and the same instruction can look up *any* interface - // requirement based on the witness table that provides - // a conformance, and the "key" that indicates the interface - // requirement. - - // We can only specialize in the case where the lookup - // is being done on a concrete witness table, and not - // the result of a `specialize` instruction or other - // operation that will yield such a table. - // - auto witnessTable = as(lookupInst->getWitnessTable()); - if(!witnessTable) - return; - - // Because we have a concrete witness table, we can - // use it to look up the IR value that satisfies - // the given interface requirement. - // - auto requirementKey = lookupInst->getRequirementKey(); - auto satisfyingVal = findWitnessVal(witnessTable, requirementKey); - - // We expect to always find a satisfying value, but - // we will go ahead and code defensively so that - // we leave "correct" but unspecialized code if - // we cannot find a concrete value to use. - // - if(!satisfyingVal) - return; - - // At this point, we know that `satisfyingVal` is what - // would result from executing this `lookup_witness_method` - // instruction dynamically, so we can go ahead and - // replace the original instruction with that value. - // - // We also make sure to add any uses of the lookup - // instruction to our work list, because subsequent - // simplifications might be possible now. - // - addUsersToWorkList(lookupInst); - lookupInst->replaceUsesWith(satisfyingVal); - lookupInst->removeAndDeallocate(); - } - - // The above subroutine needed a way to look up - // the satisfying value for a given requirement - // key in a concrete witness table, so let's - // define that now. - // - IRInst* findWitnessVal( - IRWitnessTable* witnessTable, - IRInst* requirementKey) - { - // A witness table is basically just a container - // for key-value pairs, and so the best we can - // do for now is a naive linear search. - // - for( auto entry : witnessTable->getEntries() ) - { - if (requirementKey == entry->getRequirementKey()) - { - return entry->getSatisfyingVal(); - } - } - return nullptr; - } - - // All of the machinery for generic specialization - // has been defined above, so we will now walk - // through the flow of the overall specialization pass. - // - void processModule() - { - // We start by initializing our shared IR building state, - // since we will re-use that state for any code we - // generate along the way. - // - SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; - sharedBuilder->module = module; - sharedBuilder->session = module->session; - - // The unspecialized IR we receive as input will have - // `IRBindGlobalGenericParam` instructions that associate - // each global-scope generic parameter (a type, witness - // table, or what-have-you) with the value that it should - // be bound to for the purposes of this code-generation - // pass. - // - // Before doing any other specialization work, we will - // iterate over these instructions (which may only - // appear at the global scope) and use them to drive - // replacement of the given generic type parameter with - // the desired concrete value. - // - // TODO: When we start to support global shader parameters - // that include existential/interface types, we will need - // to support a similar specialization step for them. - // - specializeGlobalGenericParameters(); - - // Now that we've eliminated all cases of global generic parameters, - // we should now have the properties that: - // - // 1. Execution starts in non-generic code, with no unbound - // generic parameters in scope. - // - // 2. Any case where non-generic code makes use of a generic - // type/function, there will be a `specialize` instruction - // that specifies both the generic and the (concrete) type - // arguments that should be provided to it. - // - // The basic approach now is to look for opportunities to apply - // our specialization rules (e.g., a `specialize` instruction - // where all the type arguments are concrete types) and then - // processing any additional opportunities created along the way. - // - // We start out simple by putting the root instruction for the - // module onto our work list. - // - addToWorkList(module->getModuleInst()); - - while(workList.getCount() != 0) - { - - // We will then iterate until our work list goes dry. - // - while(workList.getCount() != 0) - { - IRInst* inst = workList.getLast(); - - workList.removeLast(); - workListSet.Remove(inst); - cleanInsts.Add(inst); - - // For each instruction we process, we want to perform - // a few steps. - // - // First we will do any checking required to tag an - // instruction as being fully specialized. - // - maybeMarkAsFullySpecialized(inst); - - // Next we will look for all the general-purpose - // specialization opportunities (generic specialization, - // existential specialization, simplifications, etc.) - // - maybeSpecializeInst(inst); - - // Finally, we need to make our logic recurse through - // the whole IR module, so we want to add the children - // of any parent instructions to our work list so that - // we process them too. - // - // Note that we are adding the children of an instruction - // in reverse order. This is because the way we are - // using the work list treats it like a stack (LIFO) and - // we know that fully-specialized-ness will tend to flow - // top-down through the program, so that we want to process - // the children of an instruction in their original order. - // - for(auto child = inst->getLastChild(); child; child = child->getPrevInst()) - { - // Also note that `addToWorkList` has been written - // to avoid adding any instruction that is a descendent - // of an IR generic, because we don't actually want - // to perform specialization inside of generics. - // - addToWorkList(child); - } - } - - addDirtyInstsToWorkListRec(module->getModuleInst()); - - } - - // Once the work list has gone dry, we should have the invariant - // that there are no `specialize` instructions inside of non-generic - // functions that in turn reference a generic type/function, *except* - // in the case where that generic is for a builtin type/function, in - // which case we wouldn't want to specialize it anyway. - } - - void addDirtyInstsToWorkListRec(IRInst* inst) - { - if( !cleanInsts.Contains(inst) ) - { - addToWorkList(inst); - } - - for(auto child = inst->getLastChild(); child; child = child->getPrevInst()) - { - addDirtyInstsToWorkListRec(child); - } - } - - // Given a `call` instruction in the IR, we need to detect the case - // where the callee has some interface-type parameter(s) and at the - // call site it is statically clear what concrete type(s) the arguments - // will have. - // - void maybeSpecializeExistentialsForCall(IRCall* inst) - { - // We can only specialize a call when the callee function is known. - // - auto calleeFunc = as(inst->getCallee()); - if(!calleeFunc) - return; - - // We can only specialize if we have access to a body for the callee. - // - if(!calleeFunc->isDefinition()) - return; - - // We shouldn't bother specializing unless the callee has at least - // one parameter that has an existential/interface type. - // - bool shouldSpecialize = false; - UInt argCounter = 0; - for( auto param : calleeFunc->getParams() ) - { - auto arg = inst->getArg(argCounter++); - if( !isExistentialType(param->getDataType()) ) - continue; - - shouldSpecialize = true; - - // We *cannot* specialize unless the argument value corresponding - // to such a parameter is one we can specialize. - // - if( !canSpecializeExistentialArg(arg)) - return; - - } - // If we never found a parameter worth specializing, we should bail out. - // - if(!shouldSpecialize) - return; - - // At this point, we believe we *should* and *can* specialize. - // - // We need a specialized variant of the callee (with the concrete - // types substituted in for existential-type parameters), and then - // we can replace the call site to call the new function instead. - // - // Any two call sites where the argument types are the same can - // re-use the same callee, so we will cache and re-use the - // specialized functions that we generate (similar to how generic - // specialization works). Therefore we will construct a key - // for use when caching the specialized functions. - // - IRSimpleSpecializationKey key; - - // The specialized callee will always depend on the unspecialized - // function from which it is generated, so we add that to our key. - // - key.vals.add(calleeFunc); - - // Also, for any parameter that has an existential type, the - // specialized function will depend on the concrete type of the - // argument. - // - argCounter = 0; - for( auto param : calleeFunc->getParams() ) - { - auto arg = inst->getArg(argCounter++); - if( !isExistentialType(param->getDataType()) ) - continue; - - if( auto makeExistential = as(arg) ) - { - // Note that we use the *type* stored in the - // existential-type argument, but not anything to - // do with the particular value (otherwise we'd only - // be able to re-use the specialized callee for - // call sites that pass in the exact same argument). - // - auto val = makeExistential->getWrappedValue(); - auto valType = val->getFullType(); - key.vals.add(valType); - - // We are also including the witness table in the key. - // This isn't required with our current language model, - // since a given type can only conform to a given interface - // in one way (so there can be only one witness table). - // That means that the `valType` and the existential - // type of `param` above should uniquely determine - // the witness table we see. - // - // There are forward-looking cases where supporting - // "overlapping conformances" could be required, and - // there is low incremental cost to future-proofing - // this code, so we go ahead and add the witness - // table even if it is redundant. - // - auto witnessTable = makeExistential->getWitnessTable(); - key.vals.add(witnessTable); - } - else if( auto wrapExistential = as(arg) ) - { - auto val = wrapExistential->getWrappedValue(); - auto valType = val->getFullType(); - key.vals.add(valType); - - UInt slotOperandCount = wrapExistential->getSlotOperandCount(); - for( UInt ii = 0; ii < slotOperandCount; ++ii ) - { - auto slotOperand = wrapExistential->getSlotOperand(ii); - key.vals.add(slotOperand); - } - } - else - { - SLANG_UNEXPECTED("missing case for existential argument"); - } - } - - // Once we've constructed our key, we can try to look for an - // existing specialization of the callee that we can use. - // - IRFunc* specializedCallee = nullptr; - if( !existentialSpecializedFuncs.TryGetValue(key, specializedCallee) ) - { - // If we didn't find a specialized callee already made, then we - // will go ahead and create one, and then register it in our cache. - // - specializedCallee = createExistentialSpecializedFunc(inst, calleeFunc); - existentialSpecializedFuncs.Add(key, specializedCallee); - } - - // At this point we have found or generated a specialized version - // of the callee, and we need to emit a call to it. - // - // We will start by constructing the argument list for the new call. - // - argCounter = 0; - List newArgs; - for( auto param : calleeFunc->getParams() ) - { - auto arg = inst->getArg(argCounter++); - - // How we handle each argument depends on whether the corresponding - // parameter has an existential type or not. - // - if( !isExistentialType(param->getDataType()) ) - { - // If the parameter doesn't have an existential type, then we - // don't want to change up the argument we pass at all. - // - newArgs.add(arg); - } - else - { - // Any place where the original function had a parameter of - // existential type, we will now be passing in the concrete - // argument value instead of an existential wrapper. - // - if( auto makeExistential = as(arg) ) - { - auto val = makeExistential->getWrappedValue(); - newArgs.add(val); - } - else if( auto wrapExistential = as(arg) ) - { - auto val = wrapExistential->getWrappedValue(); - newArgs.add(val); - } - else - { - SLANG_UNEXPECTED("missing case for existential argument"); - } - } - } - - // Now that we've built up our argument list, it is simple enough - // to construct a new `call` instruction. - // - IRBuilder builderStorage; - auto builder = &builderStorage; - builder->sharedBuilder = &sharedBuilderStorage; - - builder->setInsertBefore(inst); - auto newCall = builder->emitCallInst( - inst->getFullType(), specializedCallee, newArgs); - - // We will completely replace the old `call` instruction with the - // new one, and will go so far as to transfer any decorations - // that were attached to the old call over to the new one. - // - inst->transferDecorationsTo(newCall); - inst->replaceUsesWith(newCall); - inst->removeAndDeallocate(); - - // Just in case, we will add any instructions that used the - // result of this call to our work list for re-consideration. - // At this moment this shouldn't open up new opportunities - // for specialization, but we can always play it safe. - // - addUsersToWorkList(newCall); - } - - // The above `maybeSpecializeExistentialsForCall` routine needed - // a few utilities, which we will now define. - - // First, we want to be able to test whether a type (used by - // a parameter) is an existential type so that we should specialize it. - // - bool isExistentialType(IRType* type) - { - // An IR-level interface type is always an existential. - // - if(as(type)) - return true; - - // Eventually we will also want to handle arrays over - // existential types, but that will require careful - // handling in many places. - - return false; - } - - // Similarly, we want to be able to test whether an instruction - // used as an argument for an existential-type parameter is - // suitable for use in specialization. - // - bool canSpecializeExistentialArg(IRInst* inst) - { - // A `makeExistential(v, w)` instruction can be used - // for specialization, since we have the concrete value `v` - // (which implicitly determines the concrete type), and - // the witness table `w. - // - if(as(inst)) - return true; - - // A `wrapExistential(v, T0,w0, T1, w1, ...)` instruction - // is just a generalization of `makeExistential`, so it - // can apply in the same cases. - // - if(as(inst)) - return true; - - // If we start to specialize functions that take arrays - // of existentials as input, we will need a strategy to - // determine arguments suitable for use in specializing - // them (these would need to be arrays that nominally - // have an existential element type, but somehow have - // annotations to indicate that the concrete type - // underlying the elements in homogeneous). - - return false; - } - - // In order to cache and re-use functions that have had existential-type - // parameters specialized, we need storage for the cache. - // - Dictionary existentialSpecializedFuncs; - - // The logic for creating a specialized callee function by plugging - // in concrete types for existentials is similar to other cases of - // specialization in the compiler. - // - IRFunc* createExistentialSpecializedFunc( - IRCall* oldCall, - IRFunc* oldFunc) - { - // We will make use of the infrastructure for cloning - // IR code, that is defined in `ir-clone.{h,cpp}`. - // - // In order to do the cloning work we need an - // "environment" that will map old values to - // their replacements. - // - IRCloneEnv cloneEnv; - - // We also need some IR building state, for any - // new instructions we will emit. - // - IRBuilder builderStorage; - auto builder = &builderStorage; - builder->sharedBuilder = &sharedBuilderStorage; - - // We will start out by determining what the parameters - // of the specialized function should be, based on - // the parameters of the original, and the concrete - // type of selected arguments at the call site. - // - // Along the way we will build up explicit lists of - // the parameters, as well as any new instructions - // that need to be added to the body of the function - // we generate (as a kind of "prologue"). We build - // the lists here because we don't yet have a basic - // block, or even a function, to insert them into. - // - List newParams; - List newBodyInsts; - UInt argCounter = 0; - for( auto oldParam : oldFunc->getParams() ) - { - auto arg = oldCall->getArg(argCounter++); - - // Given an old parameter, and the argument value at - // the (old) call site, we need to determine what - // value should stand in for that parameter in - // the specialized callee. - // - IRInst* replacementVal = nullptr; - - // The trickier case is when we have an existential-type - // parameter, because we need to extract out the concrete - // type that is coming from the call site. - // - if( auto oldMakeExistential = as(arg) ) - { - // In this case, the `arg` is `makeExistential(val, witnessTable)` - // and we know that the specialized call site will just be - // passing in `val`. - // - auto val = oldMakeExistential->getWrappedValue(); - auto witnessTable = oldMakeExistential->getWitnessTable(); - - // Our specialized function needs to take a parameter with the - // same type as `val`, to match the call site(s) that will be - // created. - // - auto valType = val->getFullType(); - auto newParam = builder->createParam(valType); - newParams.add(newParam); - - // Within the body of the function we cannot just use `val` - // directly, because the existing code expects an existential - // value, including its witness table. - // - // Therefore we will create a `makeExistential(newParam, witnessTable)` - // in the body of the new function and use *that* as the replacement - // value for the original parameter (since it will have the - // correct existential type, and stores the right witness table). - // - auto newMakeExistential = builder->emitMakeExistential(oldParam->getFullType(), newParam, witnessTable); - newBodyInsts.add(newMakeExistential); - replacementVal = newMakeExistential; - } - else if( auto oldWrapExistential = as(arg) ) - { - auto val = oldWrapExistential->getWrappedValue(); - auto valType = val->getFullType(); - - auto newParam = builder->createParam(valType); - newParams.add(newParam); - - // Within the body of the function we cannot just use `val` - // directly, because the existing code expects an existential - // value, including its witness table. - // - // Therefore we will create a `makeExistential(newParam, witnessTable)` - // in the body of the new function and use *that* as the replacement - // value for the original parameter (since it will have the - // correct existential type, and stores the right witness table). - // - auto newWrapExistential = builder->emitWrapExistential( - oldParam->getFullType(), - newParam, - oldWrapExistential->getSlotOperandCount(), - oldWrapExistential->getSlotOperands()); - newBodyInsts.add(newWrapExistential); - replacementVal = newWrapExistential; - } - else - { - // For parameters that don't have an existential type, - // there is nothing interesting to do. The new function - // will also have a parameter of the exact same type, - // and we'll use that instead of the original parameter. - // - auto newParam = builder->createParam(oldParam->getFullType()); - newParams.add(newParam); - replacementVal = newParam; - } - - // Whatever replacement value was constructed, we need to - // register it as the replacement for the original parameter. - // - cloneEnv.mapOldValToNew.Add(oldParam, replacementVal); - } - - // Next we will create the skeleton of the new - // specialized function, including its type. - // - // In order to construct the type of the new function, we - // need to extract the types of all its parameters. - // - List newParamTypes; - for( auto newParam : newParams ) - { - newParamTypes.add(newParam->getFullType()); - } - IRType* newFuncType = builder->getFuncType( - newParamTypes.getCount(), - newParamTypes.getBuffer(), - oldFunc->getResultType()); - IRFunc* newFunc = builder->createFunc(); - newFunc->setFullType(newFuncType); - - // By construction, our new function type will be - // "fully specialized" by the rules used for doing - // generic specialization elsewhere in this pass. - // - fullySpecializedInsts.Add(newFuncType); - - // The above steps have accomplished the "first phase" - // of cloning the function (since `IRFunc`s have no - // operands). - // - // We can now use the shared IR cloning infrastructure - // to perform the second phase of cloning, which will recursively - // clone any nested decorations, blocks, and instructions. - // - cloneInstDecorationsAndChildren( - &cloneEnv, - builder->sharedBuilder, - oldFunc, - newFunc); - - // Now that the main body of existing isntructions have - // been cloned into the new function, we can go ahead - // and insert all the parameters and body instructions - // we built up into the function at the right place. - // - // We expect the function to always have at least one - // block (this was an invariant established before - // we decided to specialize). - // - auto newEntryBlock = newFunc->getFirstBlock(); - SLANG_ASSERT(newEntryBlock); - - // We expect every valid block to have at least one - // "ordinary" instruction (it will at least have - // a terminator like a `return`). - // - auto newFirstOrdinary = newEntryBlock->getFirstOrdinaryInst(); - SLANG_ASSERT(newFirstOrdinary); - - // All of our parameters will get inserted before - // the first ordinary instruction (since the function parameters - // should come at the start of the first block). - // - for( auto newParam : newParams ) - { - newParam->insertBefore(newFirstOrdinary); - } - - // All of our new body instructions will *also* be inserted - // before the first ordinary instruction (but will come - // *after* the parameters by the order of these two loops). - // - for( auto newBodyInst : newBodyInsts ) - { - newBodyInst->insertBefore(newFirstOrdinary); - } - - // After all this work we have a valid `newFunc` that has been - // specialized to match the types at the call site. - // - // There might be further opportunities for simplification and - // specialization in the function body now that we've plugged - // in some more concrete type information, so we will - // add the whole function to our work list for subsequent - // consideration. - // - addToWorkList(newFunc); - - return newFunc; - } - - // When we've specialized a function with an interface-type parameter - // we will still end up with a `makeExistential` operation in its - // body, which could impede subequent specializations. - // - // For example, if we have the following after specialization: - // - // e = makeExistential(v, w1); - // w2 = extractExistentialWitnessTable(e); - // f = lookup_witness_method(w2, k); - // call(f, ...); - // - // We cannot then specialize the lookup for `f` in this code as written, - // but it seems obvious that we could replace `w2` with `w1` and maybe - // get further along. - // - // In order to set up further specialization opportunities we need - // to implement a few simplification rules around operations that - // extract from an existential, when their operand is a `makeExistential`. - // - // Let's start with the routine for the case above of extracting - // a witness table. - // - void maybeSpecializeExtractExistentialWitnessTable(IRInst* inst) - { - // We know `inst` is `extractExistentialWitnessTable(existentialArg)`. - // - auto existentialArg = inst->getOperand(0); - - if( auto makeExistential = as(existentialArg) ) - { - // In this case we know `inst` is: - // - // extractExistentialWitnessTable(makeExistential(..., witnessTable)) - // - // and we can just simplify that to `witnessTable`. - // - auto witnessTable = makeExistential->getWitnessTable(); - - // Anything that used this instruction is now a candidate for - // further simplification or specialization (e.g., one of - // the users of this instruction could be a `lookup_witness_method` - // that we can now specialize). - // - addUsersToWorkList(inst); - - inst->replaceUsesWith(witnessTable); - inst->removeAndDeallocate(); - } - } - - // The cases for simplifying `extractExistentialValue` is more or less the same - // as for witness tables. - // - void maybeSpecializeExtractExistentialValue(IRInst* inst) - { - // We know `inst` is `extractExistentialValue(existentialArg)`. - // - auto existentialArg = inst->getOperand(0); - if( auto makeExistential = as(existentialArg) ) - { - // Now we know `inst` is: - // - // extractExistentialValue(makeExistential(val, ...)) - // - // and we can just simplify that to `val`. - // - auto val = makeExistential->getWrappedValue(); - - addUsersToWorkList(inst); - - inst->replaceUsesWith(val); - inst->removeAndDeallocate(); - } - } - - // The cases for simplifying `extractExistentialType` is more or less the same - // as for witness tables. - // - void maybeSpecializeExtractExistentialType(IRInst* inst) - { - // We know `inst` is `extractExistentialValue(existentialArg)`. - // - auto existentialArg = inst->getOperand(0); - if( auto makeExistential = as(existentialArg) ) - { - // Now we know `inst` is: - // - // extractExistentialType(makeExistential(val, ...)) - // - // and we can just simplify that to type type of `val`. - // - auto val = makeExistential->getWrappedValue(); - auto valType = val->getFullType(); - - addUsersToWorkList(inst); - - inst->replaceUsesWith(valType); - inst->removeAndDeallocate(); - } - } - - void maybeSpecializeLoad(IRLoad* inst) - { - auto ptrArg = inst->ptr.get(); - - if( auto wrapInst = as(ptrArg) ) - { - // We have an instruction of the form `load(wrapExistential(val, ...))` - // - auto val = wrapInst->getWrappedValue(); - - // We know what type we are expected to - // produce (which should be the pointed-to - // type for whatever the type of the - // `wrapExistential` is). - // - auto resultType = inst->getFullType(); - - IRBuilder builder; - builder.sharedBuilder = &sharedBuilderStorage; - builder.setInsertBefore(inst); - - // We'd *like* to replace this instruction with - // `wrapExistential(load(val))` instead, since that - // will enable subsequent specializations. - // - // To do that, we need to be able to determine - // the type that `load(val)` should return. - // - auto elementType = tryGetPointedToType(&builder, val->getDataType()); - if(!elementType) - return; - - - List slotOperands; - UInt slotOperandCount = wrapInst->getSlotOperandCount(); - for( UInt ii = 0; ii < slotOperandCount; ++ii ) - { - slotOperands.add(wrapInst->getSlotOperand(ii)); - } - - auto newLoadInst = builder.emitLoad(elementType, val); - auto newWrapExistentialInst = builder.emitWrapExistential( - resultType, - newLoadInst, - slotOperandCount, - slotOperands.getBuffer()); - - addUsersToWorkList(inst); - - inst->replaceUsesWith(newWrapExistentialInst); - inst->removeAndDeallocate(); - } - } - - UInt calcExistentialBoxSlotCount(IRType* type) - { - top: - if( as(type) ) - { - return 2; - } - else if( auto ptrType = as(type) ) - { - type = ptrType->getValueType(); - goto top; - } - else if( auto ptrLikeType = as(type) ) - { - type = ptrLikeType->getElementType(); - goto top; - } - else if( auto structType = as(type) ) - { - UInt count = 0; - for( auto field : structType->getFields() ) - { - count += calcExistentialBoxSlotCount(field->getFieldType()); - } - return count; - } - else - { - return 0; - } - } - - void maybeSpecializeFieldExtract(IRFieldExtract* inst) - { - auto baseArg = inst->getBase(); - auto fieldKey = inst->getField(); - - if( auto wrapInst = as(baseArg) ) - { - // We have `getField(wrapExistential(val, ...), fieldKey)` - // - auto val = wrapInst->getWrappedValue(); - - // We know what type we are expected to produce. - // - auto resultType = inst->getFullType(); - - IRBuilder builder; - builder.sharedBuilder = &sharedBuilderStorage; - builder.setInsertBefore(inst); - - // We'd *like* to replace this instruction with - // `wrapExistential(getField(val, fieldKey), ...)` instead, since that - // will enable subsequent specializations. - // - // To do that, we need to figure out: - // - // 1. What type that inner `getField` would return (what - // is the type of the `fieldKey` field in `val`?) - // - // 2. Which of the existential slot operands in `...` there - // actually apply to the given field. - // - - // To determine these things, we need the type of - // `val` to be a structure type so that we can look - // up the field corresponding to `fieldKey`. - // - auto valType = val->getDataType(); - auto valStructType = as(valType); - if(!valStructType) - return; - - UInt slotOperandOffset = 0; - - IRStructField* foundField = nullptr; - for( auto valField : valStructType->getFields() ) - { - if( valField->getKey() == fieldKey ) - { - foundField = valField; - break; - } - - slotOperandOffset += calcExistentialBoxSlotCount(valField->getFieldType()); - } - - if(!foundField) - return; - - auto foundFieldType = foundField->getFieldType(); - - List slotOperands; - UInt slotOperandCount = calcExistentialBoxSlotCount(foundFieldType); - - for( UInt ii = 0; ii < slotOperandCount; ++ii ) - { - slotOperands.add(wrapInst->getSlotOperand(slotOperandOffset + ii)); - } - - auto newGetField = builder.emitFieldExtract( - foundFieldType, - val, - fieldKey); - - auto newWrapExistentialInst = builder.emitWrapExistential( - resultType, - newGetField, - slotOperandCount, - slotOperands.getBuffer()); - - addUsersToWorkList(inst); - inst->replaceUsesWith(newWrapExistentialInst); - inst->removeAndDeallocate(); - } - } - - - void maybeSpecializeFieldAddress(IRFieldAddress* inst) - { - auto baseArg = inst->getBase(); - auto fieldKey = inst->getField(); - - if( auto wrapInst = as(baseArg) ) - { - // We have `getFieldAddr(wrapExistential(val, ...), fieldKey)` - // - auto val = wrapInst->getWrappedValue(); - - // We know what type we are expected to produce. - // - auto resultType = inst->getFullType(); - - IRBuilder builder; - builder.sharedBuilder = &sharedBuilderStorage; - builder.setInsertBefore(inst); - - // We'd *like* to replace this instruction with - // `wrapExistential(getFieldAddr(val, fieldKey), ...)` instead, since that - // will enable subsequent specializations. - // - // To do that, we need to figure out: - // - // 1. What type that inner `getFieldAddr` would return (what - // is the type of the `fieldKey` field in `val`?) - // - // 2. Which of the existential slot operands in `...` there - // actually apply to the given field. - // - - // To determine these things, we need the type of - // `val` to be a (pointer to a) structure type so that we can look - // up the field corresponding to `fieldKey`. - // - auto valType = tryGetPointedToType(&builder, val->getDataType()); - if(!valType) - return; - - auto valStructType = as(valType); - if(!valStructType) - return; - - UInt slotOperandOffset = 0; - - IRStructField* foundField = nullptr; - for( auto valField : valStructType->getFields() ) - { - if( valField->getKey() == fieldKey ) - { - foundField = valField; - break; - } - - slotOperandOffset += calcExistentialBoxSlotCount(valField->getFieldType()); - } - - if(!foundField) - return; - - auto foundFieldType = foundField->getFieldType(); - - List slotOperands; - UInt slotOperandCount = calcExistentialBoxSlotCount(foundFieldType); - - for( UInt ii = 0; ii < slotOperandCount; ++ii ) - { - slotOperands.add(wrapInst->getSlotOperand(slotOperandOffset + ii)); - } - - auto newGetFieldAddr = builder.emitFieldAddress( - builder.getPtrType(foundFieldType), - val, - fieldKey); - - auto newWrapExistentialInst = builder.emitWrapExistential( - resultType, - newGetFieldAddr, - slotOperandCount, - slotOperands.getBuffer()); - - addUsersToWorkList(inst); - inst->replaceUsesWith(newWrapExistentialInst); - inst->removeAndDeallocate(); - } - } - - UInt calcExistentialTypeParamSlotCount(IRType* type) - { - top: - if( as(type) ) - { - return 2; - } - else if( auto ptrType = as(type) ) - { - type = ptrType->getValueType(); - goto top; - } - else if( auto ptrLikeType = as(type) ) - { - type = ptrLikeType->getElementType(); - goto top; - } - else if( auto structType = as(type) ) - { - UInt count = 0; - for( auto field : structType->getFields() ) - { - count += calcExistentialTypeParamSlotCount(field->getFieldType()); - } - return count; - } - else - { - return 0; - } - } - - Dictionary existentialSpecializedStructs; - - void maybeSpecializeBindExistentialsType(IRBindExistentialsType* type) - { - auto baseType = type->getBaseType(); - UInt slotOperandCount = type->getExistentialArgCount(); - - IRBuilder builder; - builder.sharedBuilder = &sharedBuilderStorage; - builder.setInsertBefore(type); - - if( auto baseInterfaceType = as(baseType) ) - { - // A `BindExistentials` can - // just be simplified to `ExistentialBox`. - // - // Note: We do *not* simplify straight to `ConcreteType`, because - // that would mess up the layout for aggregate types that - // contain interfaces. The logical indirection introduced - // by `ExistentialBox<...>` will be handled by a later type - // legalization pass that moved the type "pointed to" by - // the box out of line from other fields. - - // We always expect two slot operands, one for the concrete type - // and one for the witness table. - // - SLANG_ASSERT(slotOperandCount == 2); - if(slotOperandCount <= 1) return; - - auto concreteType = (IRType*) type->getExistentialArg(0); - auto newVal = builder.getPtrType(kIROp_ExistentialBoxType, concreteType); - - addUsersToWorkList(type); - type->replaceUsesWith(newVal); - type->removeAndDeallocate(); - return; - } - else if( auto basePtrLikeType = as(baseType) ) - { - // A `BindExistentials, ...>` can be simplified to - // `P>` when `P` is a pointer-like - // type constructor. - // - auto baseElementType = basePtrLikeType->getElementType(); - IRInst* wrappedElementType = builder.getBindExistentialsType( - baseElementType, - slotOperandCount, - type->getExistentialArgs()); - addToWorkList(wrappedElementType); - - auto newPtrLikeType = builder.getType( - basePtrLikeType->op, - 1, - &wrappedElementType); - addToWorkList(newPtrLikeType); - - addUsersToWorkList(type); - type->replaceUsesWith(newPtrLikeType); - type->removeAndDeallocate(); - return; - } - else if( auto baseStructType = as(baseType) ) - { - // In order to bind a `struct` type we will generate - // a new specialized `struct` type on demand and then - // cache and re-use it. - // - // We don't want to start specializing here unless - // all the operand types (and witness tables) we - // will be specializing to are themselves fully - // specialized, so that we can be sure that we - // have a unique type. - // - if( !areAllOperandsFullySpecialized(type) ) - return; - - // Now we we check to see if we've already created - // a specialized struct type or not. - // - IRSimpleSpecializationKey key; - key.vals.add(baseStructType); - for( UInt ii = 0; ii < slotOperandCount; ++ii ) - { - key.vals.add(type->getExistentialArg(ii)); - } - - IRStructType* newStructType = nullptr; - if( !existentialSpecializedStructs.TryGetValue(key, newStructType) ) - { - builder.setInsertBefore(baseStructType); - newStructType = builder.createStructType(); - - auto fieldSlotArgs = type->getExistentialArgs(); - - for( auto oldField : baseStructType->getFields() ) - { - // TODO: we need to figure out which of the specialization arguments - // apply to this field... - - auto oldFieldType = oldField->getFieldType(); - auto fieldSlotArgCount = calcExistentialTypeParamSlotCount(oldFieldType); - - auto newFieldType = builder.getBindExistentialsType( - oldFieldType, - fieldSlotArgCount, - fieldSlotArgs); - - addToWorkList(newFieldType); - - fieldSlotArgs += fieldSlotArgCount; - - builder.createStructField(newStructType, oldField->getKey(), newFieldType); - } - - existentialSpecializedStructs.Add(key, newStructType); - addToWorkList(newStructType); - } - - addUsersToWorkList(type); - type->replaceUsesWith(newStructType); - type->removeAndDeallocate(); - return; - - } - } - - // The handling of specialization for global generic type - // parameters involves searching for all `bind_global_generic_param` - // instructions in the input module. - // - void specializeGlobalGenericParameters() - { - auto moduleInst = module->getModuleInst(); - for(auto inst : moduleInst->getChildren()) - { - // We only want to consider the `bind_global_generic_param` - // instructions, and ignore everything else. - // - auto bindInst = as(inst); - if(!bindInst) - continue; - - // HACK: Our current front-end emit logic can end up emitting multiple - // `bind_global_generic_param` instructions for the same parameter. This is - // a buggy behavior, but a real fix would require refactoring the way - // global generic arguments are specified today. - // - // For now we will do a sanity check to detect parameters that - // have already been specialized. - // - if( !as(bindInst->getOperand(0)) ) - { - // The "parameter" operand is no longer a parameter, so it - // seems things must have been specialized already. - // - continue; - } - - // The actual logic for applying the substitution is - // almost trivial: we will replace any uses of the - // global generic parameter with its desired value. - // - auto param = bindInst->getParam(); - auto val = bindInst->getVal(); - param->replaceUsesWith(val); - } - { - // Now that we've replaced any uses of global generic - // parameters, we will do a second pass to remove - // the parameters and any `bind_global_generic_param` - // instructions, since both should be dead/unused. - // - IRInst* next = nullptr; - for(auto inst = moduleInst->getFirstChild(); inst; inst = next) - { - next = inst->getNextInst(); - - switch(inst->op) - { - default: - break; - - case kIROp_GlobalGenericParam: - case kIROp_BindGlobalGenericParam: - // A `bind_global_generic_param` instruction should - // have no uses in the first place, and all the global - // generic parameters should have had their uses replaced. - // - SLANG_ASSERT(!inst->firstUse); - inst->removeAndDeallocate(); - break; - } - } - } - } -}; - -void specializeModule( - IRModule* module) -{ - SpecializationContext context; - context.module = module; - context.processModule(); -} - - -IRInst* specializeGenericImpl( - IRGeneric* genericVal, - IRSpecialize* specializeInst, - IRModule* module, - SpecializationContext* context) -{ - // Effectively, specializing a generic amounts to "calling" the generic - // on its concrete argument values and computing the - // result it returns. - // - // For now, all of our generics consist of a single - // basic block, so we can "call" them just by - // cloning the instructions in their single block - // into the global scope, using an environment for - // cloning that maps the generic parameters to - // the concrete arguments that were provided - // by the `specialize(...)` instruction. - // - IRCloneEnv env; - - // We will walk through the parameters of the generic and - // register the corresponding argument of the `specialize` - // instruction to be used as the "cloned" value for each - // parameter. - // - // Suppose we are looking at `specialize(g, a, b, c)` and `g` has - // three generic parameters: `T`, `U`, and `V`. Then we will - // be initializing our environment to map `T -> a`, `U -> b`, - // and `V -> c`. - // - UInt argCounter = 0; - for( auto param : genericVal->getParams() ) - { - UInt argIndex = argCounter++; - SLANG_ASSERT(argIndex < specializeInst->getArgCount()); - - IRInst* arg = specializeInst->getArg(argIndex); - - env.mapOldValToNew.Add(param, arg); - } - - // We will set up an IR builder for insertion - // into the global scope, at the same location - // as the original generic. - // - SharedIRBuilder sharedBuilderStorage; - sharedBuilderStorage.module = module; - sharedBuilderStorage.session = module->getSession(); - - IRBuilder builderStorage; - IRBuilder* builder = &builderStorage; - builder->sharedBuilder = &sharedBuilderStorage; - builder->setInsertBefore(genericVal); - - // Now we will run through the body of the generic and - // clone each of its instructions into the global scope, - // until we reach a `return` instruction. - // - for( auto bb : genericVal->getBlocks() ) - { - // We expect a generic to only ever contain a single block. - // - SLANG_ASSERT(bb == genericVal->getFirstBlock()); - - // We will iterate over the non-parameter ("ordinary") - // instructions only, because parameters were dealt - // with explictly at an earlier point. - // - for( auto ii : bb->getOrdinaryInsts() ) - { - // The last block of the generic is expected to end with - // a `return` instruction for the specialized value that - // comes out of the abstraction. - // - // We thus use that cloned value as the result of the - // specialization step. - // - if( auto returnValInst = as(ii) ) - { - auto specializedVal = findCloneForOperand(&env, returnValInst->getVal()); - return specializedVal; - } - - // For any instruction other than a `return`, we will - // simply clone it completely into the global scope. - // - IRInst* clonedInst = cloneInst(&env, builder, ii); - - // Any new instructions we create during cloning were - // not present when we initially built our work list, - // so we need to make sure to consider them now. - // - // This is important for the cases where one generic - // invokes another, because there will be `specialize` - // operations nested inside the first generic that refer - // to the second. - // - if( context ) - { - context->addToWorkList(clonedInst); - } - } - } - - // If we reach this point, something went wrong, because we - // never encountered a `return` inside the body of the generic. - // - SLANG_UNEXPECTED("no return from generic"); - UNREACHABLE_RETURN(nullptr); -} - -IRInst* specializeGeneric( - IRSpecialize* specializeInst) -{ - auto baseGeneric = as(specializeInst->getBase()); - SLANG_ASSERT(baseGeneric); - if(!baseGeneric) return specializeInst; - - auto module = specializeInst->getModule(); - SLANG_ASSERT(module); - if(!module) return specializeInst; - - return specializeGenericImpl(baseGeneric, specializeInst, module, nullptr); -} - - -} // namespace Slang diff --git a/source/slang/ir-specialize.h b/source/slang/ir-specialize.h deleted file mode 100644 index 0b53d28eb..000000000 --- a/source/slang/ir-specialize.h +++ /dev/null @@ -1,12 +0,0 @@ -// ir-specialize.h -#pragma once - -namespace Slang -{ -struct IRModule; - - /// Specialize generic and interface-based code to use concrete types. -void specializeModule( - IRModule* module); - -} diff --git a/source/slang/ir-ssa.cpp b/source/slang/ir-ssa.cpp deleted file mode 100644 index 64f9210e1..000000000 --- a/source/slang/ir-ssa.cpp +++ /dev/null @@ -1,1159 +0,0 @@ -// ir-ssa.cpp -#include "ir-ssa.h" - -#include "ir.h" -#include "ir-clone.h" -#include "ir-insts.h" - -namespace Slang { - -// Track information on a phi node we are in -// the process of constructing. -struct PhiInfo : RefObject -{ - // The phi node will be represented as a parameter - // to a (non-entry) basic block. - IRParam* phi; - - // The original variable that this phi will be replacing. - IRVar* var; - - // The operands to the phi will be stored as uses here, - // because our IR parameters don't have operands. - // - // Once we've collected all the values we plan to use, - // we will turn this into argument in predecessor blocks - // that branch to this one. - // - // The order of elements in this list must match the - // order in which the predecessor blocks get enumerated. - List operands; - - // If this phi ended up being removed as trivial, then - // this will be the value that we replaced it with. - IRInst* replacement = nullptr; -}; - -// Information about a basic block that we generate/use -// during SSA construction. -struct SSABlockInfo : RefObject -{ - // Map a promotable variable to the value to - // use for that variable - Dictionary valueForVar; - - // The underlying basic block. - IRBlock* block; - - // Have we processed all the instructions in the - // body of this block (so that we would have - // found any stores to SSA variables)? - bool isFilled = false; - - // Have we filled all the predecessors of - // this block, so that we can actually perform - // look up in them? - bool isSealed = false; - - // An IR builder to use when we want to construct - // stuff in the context of this block - IRBuilder builder; - - // Phi nodes we are creating for this block. - List phis; - - // Arguments that this block needs to pass along - // to the phi nodes defined by is sucessor - List successorArgs; -}; - -// State for constructing SSA form for a global value -// with code (usually a function). -struct ConstructSSAContext -{ - // The value that we want to rewrite into SSA form - // (usually an IR function) - IRGlobalValueWithCode* globalVal; - - // Variables that we've identified for promotion - // to SSA values. - List promotableVars; - - // Information about each basic block - Dictionary> blockInfos; - - // IR building state to use during the operation - SharedIRBuilder sharedBuilder; - - // Instructions to remove during cleanup - List instsToRemove; - - IRBuilder builder; - IRBuilder* getBuilder() { return &builder; } - - - Dictionary> phiInfos; - - PhiInfo* getPhiInfo(IRParam* phi) - { - if(auto found = phiInfos.TryGetValue(phi)) - return *found; - return nullptr; - } -}; - -/// Do all uses of this instruction lead to a `load`? -/// -/// Checks if all uses of `inst` are either loads, -/// or get-element-address/get-field-address operations -/// that also lead to loads. -bool allUsesLeadToLoads(IRInst* inst) -{ - for (auto u = inst->firstUse; u; u = u->nextUse) - { - auto user = u->getUser(); - switch (user->op) - { - default: - return false; - - case kIROp_Load: - break; - - case kIROp_getElementPtr: - case kIROp_FieldAddress: - { - // Sanity check: the address being used should - // be the base-address operand, and not the field - // key or index (this should never be a problem). - if (u != &user->getOperands()[0]) - return false; - - if (!allUsesLeadToLoads(user)) - return false; - } - break; - } - } - - // If all of the uses passed our checking, then - // we are good to go. - return true; - -} - -// Is the given variable one that we can promote to SSA form? -bool isPromotableVar( - ConstructSSAContext* /*context*/, - IRVar* var) -{ - // We want to identify variables such that we can always - // determine what they will contain at a point in the - // program by directly inspecting their uses. - // - // The simplest possible answer would be instructions - // that are only ever used as the operand of "full" - // load and store instructions (loads and stores that - // write the entire variable). This is enough to - // promote simple scalar variables to SSA temporaries, - // but falls apart for aggregates and arrays. - // - // A slightly more powerful option (which is what we - // implement for now) is to promote variables when - // all of the stores are "full," and all other uses - // are in the form of a "chain" of `getElmeentAddress` - // or `getFieldAddress` operations that terminates - // with a load. - // - // An even more powerful option (which we do not yet - // implement) would be to handle cases where there are - // "chains" that end with stores, and to treat these - // as partial assignments (where we can still form - // an SSA value by creating a new temporary with just - // one element/field different). This kind of approach - // would be best if it is combined with scalarization, - // so that we don't need to construct aggregate temps. - // - - for (auto u = var->firstUse; u; u = u->nextUse) - { - auto user = u->getUser(); - switch (user->op) - { - default: - // If the variable gets used by any operation - // we can't account for directly, then it isn't - // promotable. - return false; - - case kIROp_Load: - { - // A load has only a single argument, so - // it had better be our pointer. - SLANG_ASSERT(u == &((IRLoad*) user)->ptr); - } - break; - - case kIROp_Store: - { - auto storeInst = (IRStore*)user; - - // We don't want to promote a variable if - // its address gets stored into another - // variable, so check for that case. - if (u == &storeInst->val) - return false; - - // Otherwise our variable is being used - // as the destination for the store, and - // that is okay by us. - SLANG_ASSERT(u == &storeInst->ptr); - } - break; - - case kIROp_getElementPtr: - case kIROp_FieldAddress: - { - // Sanity check: the address being used should - // be the base-address operand, and not the field - // key or index (this should never be a problem). - if (u != &user->getOperands()[0]) - return false; - - if (!allUsesLeadToLoads(user)) - return false; - } - break; - } - } - - // If all of the uses passed our checking, then - // we are good to go. - return true; -} - -// Identify local variables that can be promoted to SSA form -void identifyPromotableVars( - ConstructSSAContext* context) -{ - for (auto bb = context->globalVal->getFirstBlock(); bb; bb = bb->getNextBlock()) - { - for (auto ii = bb->getFirstInst(); ii; ii = ii->getNextInst()) - { - if (ii->op != kIROp_Var) - continue; - - IRVar* var = (IRVar*)ii; - - if (isPromotableVar(context, var)) - { - context->promotableVars.add(var); - } - } - } -} - -/// If `value` is a promotable variable, then cast and return it. -IRVar* asPromotableVar( - ConstructSSAContext* context, - IRInst* value) -{ - if (value->op != kIROp_Var) - return nullptr; - - IRVar* var = (IRVar*)value; - if (!context->promotableVars.contains(var)) - return nullptr; - - return var; -} - -/// If `value` is a promotable variable or an access chain -/// based on one, then cast and return the variable. -IRVar* asPromotableVarAccessChain( - ConstructSSAContext* context, - IRInst* value) -{ - switch (value->op) - { - case kIROp_Var: - return asPromotableVar(context, value); - - case kIROp_FieldAddress: - case kIROp_getElementPtr: - return asPromotableVarAccessChain(context, value->getOperand(0)); - - default: - return nullptr; - } -} - -/// After looking up the SSA value of avariable in some context, -/// apply whatever "access chain" was applied at the original use site. -/// -/// E.g., if the original operation was *((&a)->b) or *((&a) + i) and we've -/// resolved that the value of the variable `a` should be `v`, then -/// construct v.b or v[i]. -/// -IRInst* applyAccessChain( - ConstructSSAContext* context, - IRBuilder* builder, - IRInst* accessChain, - IRInst* leafVarValue) -{ - switch (accessChain->op) - { - default: - SLANG_UNEXPECTED("unexpected op along access chain"); - UNREACHABLE_RETURN(leafVarValue); - - case kIROp_Var: - return leafVarValue; - - case kIROp_FieldAddress: - { - SLANG_ASSERT(context->instsToRemove.contains(accessChain)); - - auto baseChain = accessChain->getOperand(0); - auto fieldKey = accessChain->getOperand(1); - auto type = cast(accessChain->getDataType())->getValueType(); - auto baseValue = applyAccessChain(context, builder, baseChain, leafVarValue); - return builder->emitFieldExtract( - type, - baseValue, - fieldKey); - } - - case kIROp_getElementPtr: - { - SLANG_ASSERT(context->instsToRemove.contains(accessChain)); - - auto baseChain = accessChain->getOperand(0); - auto index = accessChain->getOperand(1); - auto type = cast(accessChain->getDataType())->getValueType(); - auto baseValue = applyAccessChain(context, builder, baseChain, leafVarValue); - return builder->emitElementExtract( - type, - baseValue, - index); - } - } -} - -// Try to read the value of an SSA variable -// in the context of the given block. If -// the variable is defined in the block, then -// that value will be used. If not, this all -// may recursively work its way up through -// the predecessors of the block. -IRInst* readVar( - ConstructSSAContext* context, - SSABlockInfo* blockInfo, - IRVar* var); - - /// Try to copy any relevant decorations from `var` over to `val`. - /// -static void cloneRelevantDecorations( - IRVar* var, - IRInst* val) -{ - // Copy selected decorations over from the original - // variable to the SSA variable, when doing so is - // required for semantics. - // - for( auto decoration : var->getDecorations() ) - { - switch(decoration->op) - { - default: - // Ignore most decorations. - // - // TODO: Should we include or exclude by default? - break; - - case kIROp_PreciseDecoration: - case kIROp_NameHintDecoration: - // Copy these decorations if the target doesn't already have them, - // but don't make duplicate decorations on the target. - // - if( !val->findDecorationImpl(decoration->op) ) - { - cloneDecoration(decoration, val, var->getModule()); - } - break; - } - } -} - -// Add a phi node to represent the given variable -PhiInfo* addPhi( - ConstructSSAContext* context, - SSABlockInfo* blockInfo, - IRVar* var) -{ - auto builder = &blockInfo->builder; - - auto valueType = var->getDataType()->getValueType(); - if( auto rate = var->getRate() ) - { - valueType = context->getBuilder()->getRateQualifiedType(rate, valueType); - } - IRParam* phi = builder->createParam(valueType); - cloneRelevantDecorations(var, phi); - - RefPtr phiInfo = new PhiInfo(); - context->phiInfos.Add(phi, phiInfo); - - phiInfo->phi = phi; - phiInfo->var = var; - - blockInfo->phis.add(phiInfo); - - return phiInfo; -} - -IRInst* tryRemoveTrivialPhi( - ConstructSSAContext* context, - PhiInfo* phiInfo) -{ - auto phi = phiInfo->phi; - - // We are going to check if all of the operands - // to the phi are either the same, or are equal - // to the phi itself. - - IRInst* same = nullptr; - for (auto u : phiInfo->operands) - { - auto usedVal = u.get(); - SLANG_ASSERT(usedVal); - - if (usedVal == same || usedVal == phi) - { - // Either this is a self-reference, or it refers - // to the same value we've seen already. - continue; - } - if (same != nullptr) - { - // We've found at least two distinct values - // other than the phi itself, so this phi - // indeed appears to be non-trivial. - // - // We will keep the phi around. - return phi; - } - else - { - // This value is distinct from the phi itself, - // so we need to track its value. - same = usedVal; - } - } - - if (!same) - { - // There were no operands other than the phi itself. - // This implies that the value at the use sites should - // actually be undefined. - SLANG_UNIMPLEMENTED_X("trivial phi"); - } - - // Removing this phi as trivial may make other phi nodes - // become trivial. We will recognize such candidates - // by looking for phi nodes that use this node. - List otherPhis; - for( auto u = phi->firstUse; u; u = u->nextUse ) - { - auto user = u->user; - if(!user) continue; - if(user == phi) continue; - - if( user->op == kIROp_Param ) - { - auto maybeOtherPhi = (IRParam*) user; - if( auto otherPhiInfo = context->getPhiInfo(maybeOtherPhi) ) - { - otherPhis.add(otherPhiInfo); - } - } - } - - // replace uses of the phi (including its possible uses - // of itself) with the unique non-phi value. - phi->replaceUsesWith(same); - - // Clear out the operands to the phi, since they won't - // actually get used in the program any more. - for( auto& u : phiInfo->operands ) - { - u.clear(); - } - - // We will record the value that was used to replace this - // phi, so that we can easily look it up later. - phiInfo->replacement = same; - - // Now that we've cleaned up this phi, we need to consider - // other phis that might have become trivial. - for( auto otherPhi : otherPhis ) - { - tryRemoveTrivialPhi(context, otherPhi); - } - - return same; -} - -IRInst* addPhiOperands( - ConstructSSAContext* context, - SSABlockInfo* blockInfo, - PhiInfo* phiInfo) -{ - auto var = phiInfo->var; - - auto block = blockInfo->block; - - List operandValues; - for (auto predBlock : block->getPredecessors()) - { - // Precondition: if we have multiple predecessors, then - // each must have only one successor (no critical edges). - // - SLANG_ASSERT(predBlock->getSuccessors().getCount() == 1); - - auto predInfo = *context->blockInfos.TryGetValue(predBlock); - - auto phiOperand = readVar(context, predInfo, var); - - operandValues.add(phiOperand); - } - - // The `IRUse` type needs to stay at a stable location - // since they get threaded into lists. We allocate the - // list with its final size so that we can preserve the - // required invariant. - - UInt operandCount = operandValues.getCount(); - phiInfo->operands.setCount(operandCount); - for(UInt ii = 0; ii < operandCount; ++ii) - { - phiInfo->operands[ii].init(phiInfo->phi, operandValues[ii]); - } - - return tryRemoveTrivialPhi(context, phiInfo); -} - -void writeVar( - ConstructSSAContext* /*context*/, - SSABlockInfo* blockInfo, - IRVar* var, - IRInst* val) -{ - blockInfo->valueForVar[var] = val; -} - -void maybeSealBlock( - ConstructSSAContext* context, - SSABlockInfo* blockInfo) -{ - // We can't seal a block that has already been sealed. - if (blockInfo->isSealed) - return; - - // We can't seal a block until all of its predecessors - // have been filled. - for (auto pp : blockInfo->block->getPredecessors()) - { - auto predInfo = *context->blockInfos.TryGetValue(pp); - if (!predInfo->isFilled) - return; - } - - // All the checks passed, so it seems like we can be sealed. - - // We will loop over any incomplete phis that have been recoreded - // for this block, and complete them here. - // - // Note that we are doing the "inefficient" loop where we compute - // the count on each iteration to account for the possibility that - // new incomplete phis will get added while we are working. - for (Index ii = 0; ii < blockInfo->phis.getCount(); ++ii) - { - auto incompletePhi = blockInfo->phis[ii]; - addPhiOperands(context, blockInfo, incompletePhi); - } - - // After we've completed all our incomplete phis, we can mark this - // block as sealed and move along. - blockInfo->isSealed = true; -} - -// In some cases we may have a pointer to an IR value that -// represents a phi node that has been replaced with another -// IR value, because we discovered that the phi is no longer -// needed. -// -// The `maybeGetPhiReplacement` function will follow any -// chain of replacements that might be present, so that we -// don't end up referencing a dangling/unused value in -// the code that we generate. -// -IRInst* maybeGetPhiReplacement( - ConstructSSAContext* context, - IRInst* inVal) -{ - IRInst* val = inVal; - - while( val->op == kIROp_Param ) - { - // The value is a parameter, but is it a phi? - IRParam* maybePhi = (IRParam*) val; - RefPtr phiInfo = nullptr; - if(!context->phiInfos.TryGetValue(maybePhi, phiInfo)) - break; - - // Okay, this is indeed a phi we are adding, but - // is it one that got replaced? - if(!phiInfo->replacement) - break; - - // The phi we want to use got replaced, so we - // had better use the replacement instead. - val = phiInfo->replacement; - } - - return val; -} - -IRInst* readVarRec( - ConstructSSAContext* context, - SSABlockInfo* blockInfo, - IRVar* var) -{ - IRInst* val = nullptr; - if (!blockInfo->isSealed) - { - // If block isn't sealed, we need to - // speculatively add a phi to it. - // This phi may get removed later, once - // we are able to seal this block. - - PhiInfo* phiInfo = addPhi(context, blockInfo, var); - val = phiInfo->phi; - } - else - { - // If the block is sealed, then we are free to look at - // it predecessor list, and use that to decide what to do. - auto predecessors = blockInfo->block->getPredecessors(); - - // - IRBlock* firstPred = nullptr; - bool multiplePreds = false; - for (auto pp : predecessors) - { - if (!firstPred) - { - // A candidate for the sole predecessor - firstPred = pp; - } - else if (pp == firstPred) - { - // Same as existing predecessor - } - else - { - // Multiple unique predecessors - multiplePreds = true; - } - } - - if (!firstPred) - { - // The block had *no* predecssors. This will commonly - // happen for the entry block, but could also conceivably - // happen for a block that is somehow disconnected - // from the CFG and thus unreachable. - - // We would only reach this function (`readVarRec`) if - // a local lookup in the block had already failed, so - // at this point we are dealing with an undefined value. - - auto type = var->getDataType()->getValueType(); - val = blockInfo->builder.emitUndefined(type); - } - else if (!multiplePreds) - { - // There is only a single predecessor for this block, - // so there is no need to insert a phi. Instead, we - // just perform the lookup step recursively in - // the predecessor. - auto predInfo = *context->blockInfos.TryGetValue(firstPred); - val = readVar(context, predInfo, var); - } - else - { - // The default/fallback case requires us to create - // a phi node in the current block, and then look - // up the appropriate operands in the predecessor - // blocks, which will eventually become the operands - // that drive the phi. - - // Create the phi node for the given variable - PhiInfo* phiInfo = addPhi(context, blockInfo, var); - - // Mark the phi as the value for the variable inside - // this block - writeVar(context, blockInfo, var, phiInfo->phi); - - // Now add operands to the phi and maybe simplify - // it, based on what gets found. - - val = addPhiOperands(context, blockInfo, phiInfo); - } - } - - // Whatever value we find, we need to mark it as the - // value for the given variable in this block - writeVar(context, blockInfo, var, val); - - // If `val` represents a phi node (block parameter) then - // it is possible that some of the operations above might - // have caused it to be replaced with another value, - // and in that case we had better not return it to - // be referenced in user code. - // - // Note: it is okay for the `valueForVar` map that - // we update in `writeVar` to use the old value, so long - // as we do this replacement logic anywhere we might read - // from that map. - // - val = maybeGetPhiReplacement(context, val); - - return val; -} - - - -IRInst* readVar( - ConstructSSAContext* context, - SSABlockInfo* blockInfo, - IRVar* var) -{ - // In the easy case, there will be a preceeding - // store in the same block, so we can use - // that local value. - IRInst* val = nullptr; - if (blockInfo->valueForVar.TryGetValue(var, val)) - { - // Hooray, we found a value to use, and we - // can proceed without too many complications. - - // Just like in the `readVarRec` case above, we need - // to handle the case where `val` might represent - // a phi node that has subsequently been replaced. - // - val = maybeGetPhiReplacement(context, val); - - return val; - } - - // Otherwise we need to try to non-trivial/recursive - // case of lookup. - return readVarRec(context, blockInfo, var); -} - -void processBlock( - ConstructSSAContext* context, - IRBlock* block, - SSABlockInfo* blockInfo) -{ - // Before starting, check if this block can be sealed - maybeSealBlock(context, blockInfo); - - // Walk the instructions in the block, and either - // leave them as-is, or replace them with a value - // that we look up with local/global value numbering - - IRInst* next = nullptr; - for (auto ii = block->getFirstInst(); ii; ii = next) - { - next = ii->getNextInst(); - - // Any new instructions we create to represent - // the new value will get inserted before whatever - // instruction we are working with. - blockInfo->builder.setInsertBefore(ii); - - switch (ii->op) - { - default: - // Ordinary instruction -> leave as-is - break; - - case kIROp_Store: - { - auto storeInst = (IRStore*)ii; - auto ptrArg = storeInst->ptr.get(); - auto valArg = storeInst->val.get(); - - if (auto var = asPromotableVar(context, ptrArg)) - { - // We are storing to a promotable variable, - // so we want to register the value being - // stored as the value for the given SSA - // variable. - writeVar(context, blockInfo, var, valArg); - - // Also eliminate the store instruction, - // since it is no longer needed. - storeInst->removeAndDeallocate(); - } - } - break; - - case kIROp_Load: - { - IRLoad* loadInst = (IRLoad*)ii; - auto ptrArg = loadInst->ptr.get(); - - if (auto var = asPromotableVarAccessChain(context, ptrArg)) - { - // We are loading from a promotable variable. - // Look up the value in the context of this - // block. - auto val = readVar(context, blockInfo, var); - - cloneRelevantDecorations(var, val); - - val = applyAccessChain(context, &blockInfo->builder, ptrArg, val); - - // We can just replace all uses of this - // load instruction with the given value. - loadInst->replaceUsesWith(val); - - // Also eliminate the load instruction, - // since it is no longer needed. - loadInst->removeAndDeallocate(); - } - } - break; - - case kIROp_getElementPtr: - case kIROp_FieldAddress: - { - auto ptrArg = ii->getOperand(0); - if (auto var = asPromotableVarAccessChain(context, ptrArg)) - { - context->instsToRemove.add(ii); - } - } - break; - - - } - } - - auto terminator = block->getTerminator(); - SLANG_ASSERT(terminator); - blockInfo->builder.setInsertBefore(terminator); - - // Once we are done with all of the instructions - // in a block, we can mark it as "filled," which - // means we can actually consider lookups into - // it. - blockInfo->isFilled = true; - - // Having filled this block might allow us to seal some - // of its successor(s) - for (auto ss : block->getSuccessors()) - { - auto successorInfo = *context->blockInfos.TryGetValue(ss); - maybeSealBlock(context, successorInfo); - } -} - -static void breakCriticalEdges( - ConstructSSAContext* context) -{ - // A critical edge is an edge P -> S where - // P has multiple sucessors, and S has multiple - // predecessors. - // - // In the context of our CFG representation, such an edge - // will be an `IRUse` in the terminator instruction of block P, - // which refers to block S. - // - // We will make a pass over the CFG to collect all the critical - // edges, and then we will break them in a follow-up pass. - - List criticalEdges; - - auto globalVal = context->globalVal; - for (auto pred = globalVal->getFirstBlock(); pred; pred = pred->getNextBlock()) - { - auto successors = pred->getSuccessors(); - if (successors.getCount() <= 1) - continue; - - auto succIter = successors.begin(); - auto succEnd = successors.end(); - - for (; succIter != succEnd; ++succIter) - { - auto succ = *succIter; - - // For the edge to be critical, the successor must have - // more than one predecessor. - // More than that, we require that it has more than one - // *unique* predecessor, to handle the case where multiple - // cases of a `switch` might lead to the same block. - // - // To implement this, we test if it has any predecessor - // other than `pred` which we already know about. - - bool multiplePreds = false; - for (auto pp : succ->getPredecessors()) - { - if (pp != pred) - { - multiplePreds = true; - break; - } - } - if (!multiplePreds) - continue; - - // We have found a critical edge from `pred` to `succ`. - // - // Furthermore, the `IRUse` embedded in `succIter` represents - // that edge directly. - auto edgeUse = succIter.use; - criticalEdges.add(edgeUse); - } - } - - // Now we will iterate over the critical edges and break each - // one by inserting a new block. Note that we do not try - // to break the edges while doing the initial walk, because - // that would change the CFG while we are walking it. - - for (auto edgeUse : criticalEdges) - { - auto pred = cast(edgeUse->getUser()->parent); - auto succ = cast(edgeUse->get()); - - IRBuilder builder; - builder.sharedBuilder = &context->sharedBuilder; - builder.setInsertInto(pred); - - // Create a new block that will sit "along" the edge - IRBlock* edgeBlock = builder.createBlock(); - - edgeUse->debugValidate(); - - // The predecessor block should now branch to - // the edge block. - edgeUse->set(edgeBlock); - - // The edge block should branch (unconditionally) - // to the successor block. - builder.setInsertInto(edgeBlock); - builder.emitBranch(succ); - - // Insert the new block into the block list - // for the function. - // - // In principle, the order of this list shouldn't - // affect the semantics of a program, but we - // might want to be careful about ordering anyway. - edgeBlock->insertAfter(pred); - } -} - -// Construct SSA form for a global value with code -void constructSSA(ConstructSSAContext* context) -{ - // First, detect and and break any critical edges in the CFG, - // because our representation of SSA form doesn't allow for them. - breakCriticalEdges(context); - - // Figure out what variables we can promote to - // SSA temporaries. - identifyPromotableVars(context); - - // If none of the variables are promote-able, - // then we can exit without making any changes - if (context->promotableVars.getCount() == 0) - return; - - // We are going to walk the blocks in order, - // and try to process each, by replacing loads - // and stores of promotable variables with simple values. - - auto globalVal = context->globalVal; - for(auto bb : globalVal->getBlocks()) - { - auto blockInfo = new SSABlockInfo(); - blockInfo->block = bb; - - blockInfo->builder.sharedBuilder = &context->sharedBuilder; - blockInfo->builder.setInsertBefore(bb->getLastInst()); - - context->blockInfos.Add(bb, blockInfo); - } - for(auto bb : globalVal->getBlocks()) - { - auto blockInfo = * context->blockInfos.TryGetValue(bb); - processBlock(context, bb, blockInfo); - } - - // We need to transfer the logical arguments to our phi nodes - // from the phi nodes back to the predecessor blocks that will - // pass them in. - for(auto bb : globalVal->getBlocks()) - { - auto blockInfo = *context->blockInfos.TryGetValue(bb); - - for (auto phiInfo : blockInfo->phis) - { - // If we replaced this phi with another value, - // then we had better not include it in the result. - if (phiInfo->replacement) - continue; - - // We should add the phi as an explicit parameter of - // the given block. - bb->addParam(phiInfo->phi); - - UInt predCounter = 0; - for (auto pp : bb->getPredecessors()) - { - UInt predIndex = predCounter++; - auto predInfo = *context->blockInfos.TryGetValue(pp); - - IRInst* operandVal = phiInfo->operands[predIndex].get(); - - phiInfo->operands[predIndex].clear(); - - predInfo->successorArgs.add(operandVal); - } - } - } - - // Some blocks may now need to pass along arguments to their sucessor, - // which have been stored into the `SSABlockInfo::successorArgs` field. - for(auto bb : globalVal->getBlocks()) - { - auto blockInfo = * context->blockInfos.TryGetValue(bb); - - // Sanity check: all blocks should be filled and sealed. - SLANG_ASSERT(blockInfo->isSealed); - SLANG_ASSERT(blockInfo->isFilled); - - // Don't do any work for blocks that don't need to pass along - // values to the sucessor block. - auto addedArgCount = blockInfo->successorArgs.getCount(); - if (addedArgCount == 0) - continue; - - // We need to replace the terminator instruction with one that - // has additional arguments. - - IRTerminatorInst* oldTerminator = bb->getTerminator(); - SLANG_ASSERT(oldTerminator); - - blockInfo->builder.setInsertInto(bb); - - auto oldArgCount = oldTerminator->getOperandCount(); - auto newArgCount = oldArgCount + addedArgCount; - - List newArgs; - for (UInt aa = 0; aa < oldArgCount; ++aa) - { - newArgs.add(oldTerminator->getOperand(aa)); - } - for (Index aa = 0; aa < addedArgCount; ++aa) - { - newArgs.add(blockInfo->successorArgs[aa]); - } - - IRTerminatorInst* newTerminator = (IRTerminatorInst*)blockInfo->builder.emitIntrinsicInst( - oldTerminator->getFullType(), - oldTerminator->op, - newArgCount, - newArgs.getBuffer()); - - // Transfer decorations (a terminator should have no children) over to the new instruction. - // - oldTerminator->transferDecorationsTo(newTerminator); - - // A terminator better not have uses, so we shouldn't have - // to replace them. - SLANG_ASSERT(!oldTerminator->firstUse); - - - // Okay, we should be clear to remove the old terminator - oldTerminator->removeAndDeallocate(); - } - - // Remove all the instructions we marked for deletion along - // the way. - // - // Currently these are "access chain" instructions for - // loads from (parts of) variables that got promoted. - for (auto inst : context->instsToRemove) - { - // TODO: do we need to be careful here in case one - // of thes operations still has uses, as part of - // another to-be-remvoed instruction? - - inst->removeAndDeallocate(); - } - - // Now we should be able to go through and remove - // of of the variables - for (auto var : context->promotableVars) - { - var->removeAndDeallocate(); - } -} - -// Construct SSA form for a global value with code -void constructSSA(IRModule* module, IRGlobalValueWithCode* globalVal) -{ - ConstructSSAContext context; - context.globalVal = globalVal; - - context.sharedBuilder.module = module; - context.sharedBuilder.session = module->session; - - context.builder.sharedBuilder = &context.sharedBuilder; - context.builder.setInsertInto(module->moduleInst); - - constructSSA(&context); -} - -void constructSSA(IRModule* module, IRInst* globalVal) -{ - switch (globalVal->op) - { - case kIROp_Func: - case kIROp_GlobalVar: - case kIROp_GlobalConstant: - constructSSA(module, (IRGlobalValueWithCode*)globalVal); - - default: - break; - } -} - -void constructSSA(IRModule* module) -{ - for(auto ii : module->getGlobalInsts()) - { - constructSSA(module, ii); - } -} - -} diff --git a/source/slang/ir-ssa.h b/source/slang/ir-ssa.h deleted file mode 100644 index ad874845b..000000000 --- a/source/slang/ir-ssa.h +++ /dev/null @@ -1,9 +0,0 @@ -// ir-ssa.h -#pragma once - -namespace Slang -{ - struct IRModule; - - void constructSSA(IRModule* module); -} diff --git a/source/slang/ir-union.cpp b/source/slang/ir-union.cpp deleted file mode 100644 index b4fbc4c96..000000000 --- a/source/slang/ir-union.cpp +++ /dev/null @@ -1,776 +0,0 @@ -// ir-union.cpp -#include "ir-union.h" - -#include "ir.h" -#include "ir-insts.h" - -namespace Slang { - -// This file will implement a pass to replace any union types (currently -// just tagged unions) with plain `struct` types that attempt to provide -// equivalent semantics. This will necessarily be a bit fragile, and there -// will be fundamental limits to what the translation can support without -// improved features in the target shading languages/ILs. - -struct DesugarUnionTypesContext -{ - // We'll start with some basic state that we need to get the job done. - // - // This includes the IR module we are to process, as well as IR building - // state that we will initialize once and then use throughout the pass. - // - IRModule* module; - SharedIRBuilder sharedBuilderStorage; - IRBuilder builderStorage; - IRBuilder* getBuilder() { return &builderStorage; } - - // Because we will be replacing instructions that refer to unions with - // different logic, we'll want to remove the original instructions. - // However, we need to be careful about modifying the IR tree while also - // iterating it, and to keep things simple for ourselves we'll go ahead - // and build up a list of instruction to remove along the way, and then - // remove them all at the end. - // - List instsToRemove; - - // The overall flow of the pass is pretty simple, so we will walk through it now. - // - void processModule() - { - // We start by initializing our IR building state. - // - sharedBuilderStorage.session = module->session; - sharedBuilderStorage.module = module; - builderStorage.sharedBuilder = &sharedBuilderStorage; - - // Next, we will search for any instruction that create or use - // union types, and process them accordingingly (usually by - // constructing a new instruction to replace them). - // - processInstRec(module->getModuleInst()); - - // Along the way we will build up a list of the tagged union - // types that we encountered, but we will refrain from replacing - // them until we are done (so that we always know that the instructions - // we process above refer to the original type, and not its - // replacement. - // - for( auto info : taggedUnionInfos ) - { - auto taggedUnionType = info->taggedUnionType; - auto replacementInst = info->replacementInst; - - // TODO: We should consider transferring decorations from the source - // type to the destination, but doing so carelessly could create - // problems, since an IR struct type shouldn't have, e.g., a - // `TaggedUnionTypeLayout` attached to it. - - taggedUnionType->replaceUsesWith(replacementInst); - taggedUnionType->removeAndDeallocate(); - } - - // As described previously, we build up the `instsToRemove` list as - // we iterate so that we can remove them all here and not risk - // modifying the IR tree while also walking it. - // - // TODO: This might be overkill and we could conceivably just be - // a bit careful in `processInstRec`. - // - for(auto inst : instsToRemove) - { - inst->removeAndDeallocate(); - } - } - - // In order to replace a (tagged) union type, we will need to know - // something about it, and we will use the `TaggedUnionInfo` type - // to collect all the relevant information. - // - struct TaggedUnionInfo : public RefObject - { - // We obviously need to know the tagged union itself, and - // we will also use this structure to track the instruction - // (an IR struct type) that will replace it. - // - IRTaggedUnionType* taggedUnionType; - IRInst* replacementInst; - - // In order to compute a suitable layout for the replacement - // `struct` type we need to know how the tagged union itself - // would be laid out in memory, so we require that all tagged - // unions in the generated IR have an associated (target-specific) - // layout. - // - TaggedUnionTypeLayout* taggedUnionTypeLayout; - - // The basic approach we will use 16-byte chunks (represented as an array - // of `uint4`s) to reprent the "bulk" of a type, and then use a single field - // that could be up to 12 bytes to represent the "rest" of the type. - // - // Note that there are deeply ingrained assumptions here that all types - // are at least four bytes in size (so that unions cannot easily - // accomodate `half` value), and that any types *larger* than four bytes - // will need to be loaded/stored via multiple 4-byte loads/stores. - // - // With the basic idea out of the way, we need an IR level field - // in our struct to hold the bulk data, which comprises a "key" for - // looking up the field, and the type of the field itself. We also - // keep track of how many bytes we put in our bulk storage. - // - // The bulk field might be: - // - // - null, if none of the case types was 16 bytes or more - // - a single `uint4` for between 16 and 31 (inclusive) bytes - // - an array of `uint4`s for 32 or more bytes - // - UInt64 bulkSize = 0; - IRInst* bulkFieldKey = nullptr; - IRType* bulkFieldType = nullptr; - - // The same basic idea then applies to the rest of the data. - // - // The "rest" field will be either be absent (if the size of the - // type was evently divisible by 16), a scalar `uint`, or else - // a 2- or 3-component vector of `uint`. - // - UInt64 restSize = 0; - IRInst* restFieldKey = nullptr; - IRType* restFieldType = nullptr; - - // Finally, since we are currently working with tagged unions, - // we need a field to hold the tag, which will always be allocated - // after the fields that hold the bulk/rest of the payload. - // - // This field is always a single `uint`. - // - // TODO: if/when we support untagged unions, they could be handled - // by having this field be null. - // - IRInst* tagFieldKey; - }; - - // We will build up a list of all the tagged union types we encounter, - // so that we can replace them with the synthesized types when we are done. - // - List> taggedUnionInfos; - - // It is possible that we will see the same tagged union type referenced - // many times in the IR, but we only want to synthesize the information - // above (including the various IR structures) once, so we also maintain - // a map from the original IR type to the corresponding information. - // - Dictionary mapIRTypeToTaggedUnionInfo; - - // We will process all instructions in the module in a single recursive walk. - // - void processInstRec(IRInst* inst) - { - processInst(inst); - - for( auto child : inst->getChildren() ) - { - processInstRec(child); - } - } - // - // At each instruction, we will check if it is one of the union-related instructions - // we need to replace, and process it accordingly. - // - void processInst(IRInst* inst) - { - switch( inst->op ) - { - default: - // Any instruction not listed below either doesn't involve union types, - // or handles them in a hands-off fashion that we don't need to care about. - // - // E.g., a `load` of a union type from a constant buffer will turn into - // a load of the replacement `struct` type once we are done, and nothing - // needs to be done to the `load` instruction. - // - break; - - case kIROp_TaggedUnionType: - { - // We clearly need to process the tagged union type itself, but the actual - // work is handled by other functions. All we need to do here is ensure - // that the information for this type gets generated, and then we can - // rely on the main `processModule` function to do the actual replacement later. - // - auto type = cast(inst); - getTaggedUnionInfo(type); - } - break; - - case kIROp_ExtractTaggedUnionTag: - { - // The case of extracting the tag from a tagged union is relatively - // simple, because the replacement type will have a dedicated field or it. - // - // We start by finding the tagged union value the instruction is operating - // on, and then looking up the information for its type (which had - // better be a tagged union type). - // - auto taggedUnionVal = inst->getOperand(0); - auto taggedUnionInfo = getTaggedUnionInfo(taggedUnionVal->getDataType()); - - // Because the replacement type will have an explicit field for the tag, - // we can simply emit a single field-extract instruction to read its value - // out. - // - auto builder = getBuilder(); - builder->setInsertBefore(inst); - auto replacement = builder->emitFieldExtract( - inst->getFullType(), - taggedUnionVal, - taggedUnionInfo->tagFieldKey); - - // Now we can replace anything that used the original instruction with - // the new field-extract operation, and add this instruction to the - // list for later removal. - // - inst->replaceUsesWith(replacement); - instsToRemove.add(inst); - } - break; - - case kIROp_ExtractTaggedUnionPayload: - { - // The most interesting case is when we are trying to extract a particular - // payload (one of the case types) from a union. We may need to extract - // one or more fields from the data stored in the union's replacement - // type (the bulk/rest fields), and we may also have to convert them - // to the type expected via bit-casts. - - // We can start things off easily enough by extracting the tagged union - // value being operated on, as well as the information for its type. - // - auto taggedUnionVal = inst->getOperand(0); - auto taggedUnionInfo = getTaggedUnionInfo(taggedUnionVal->getDataType()); - - // Next we need to figure out which case is being extracted from the union. - // The operand for the case tag should be a literal by construction. - // - auto caseTagVal = inst->getOperand(1); - auto caseTagConst = as(caseTagVal); - SLANG_ASSERT(caseTagConst); - - // The case type we are extracting will be the result type of the instruciton. - // - auto caseType = inst->getDataType(); - // - // The tag value itself will be the index of the case type in the union - // type (and its layout). - // - auto caseTagIndex = UInt(caseTagConst->getValue()); - - // We can use the case tag value to look up the layout for the particular - // case type we are extracting (this will allow us to resolve byte offsets - // for fields, etc.). - // - auto taggedUnionTypeLayout = taggedUnionInfo->taggedUnionTypeLayout; - SLANG_ASSERT(caseTagIndex < UInt(taggedUnionTypeLayout->caseTypeLayouts.getCount())); - auto caseTypeLayout = taggedUnionTypeLayout->caseTypeLayouts[caseTagIndex]; - - // At this point we know the type we are trying to extract, as well - // as its layout. We will defer the actual implementation of extraction - // to a (recursive) subroutine that can extract a (sub-)field from the - // union at a given byte offset. Since we are extracting a full case - // right now, the byte offset will be zero. - // - auto payloadVal = extractPayload( - taggedUnionInfo, - taggedUnionVal, - caseType, - caseTypeLayout, - 0); - - // TODO: There is a significant flaw in the above approach when - // the case type might be (or contain) an array. If we have a setup - // like the following: - // - // union SomeUnion { float someCase[100]; ... } - // ... - // float result = someUnion.someCase[someIndex]; - // - // The current logic would desugar this into something like: - // - // struct SomeUnion { uint4 bulk[100]; ... } - // ... - // float[] tmp = { asfloat(someUnion.bulk[0].x), asfloat(someUnion.bulk[1].x), ... } - // float result = tmp[someIndex]; - // - // The result is that we copy an entire 100-element array into local memory - // just to fetch a single element, when it would be much nicer to just do: - // - // float result = asfloat(someUnion.bulk[someIndex].x); - // - // Achieving the latter code requires that rather than blindly translate - // the `extractTaggedUnionPayload` instruction into a semantically equiavlent - // value (which might lead to a big copy in the end), we should transitively - // chase down any "access chains" off of `inst` and see what leaf values are - // actually needed, and generated more tailored extraction logic for just - // the elements/fields that actually get referenced. - // - // The more refined approach can be built on top of many of the same primitives, - // so for now we will resign ourselves to the simpler but potentially less - // efficient approach. - - // Now that we've extracted the value for the payload from the fields of - // the replacement struct, we can use that extracted value to replace - // this instruction, and schedule the original instruction for removal. - // - inst->replaceUsesWith(payloadVal); - instsToRemove.add(inst); - } - break; - } - } - - // The `extractPayload` operation is the most important bit of translation we - // need to do to make unions work. We have as input the following: - // - IRInst* extractPayload( - - // - Information about a tagged union type and its layout. - TaggedUnionInfo* taggedUnionInfo, - - // - A single value of that tagged unon type. - IRInst* taggedUnionVal, - - // - Type type of some "payload" field we want to extract from the union. - IRType* payloadType, - - // - The memory layout of that payload type. - TypeLayout* payloadTypeLayout, - - // - The byte offset at which we want to fetch the payload. - UInt64 payloadOffset) - { - // We are going to be building some IR code no matter what. - // - auto builder = getBuilder(); - - // The basic approach here will be to look at the type we - // are trying to extract from the union, and whenever possible - // recursively walk its structure so that we can express things - // in terms of extraction of smaller/simpler types. - // - if( auto irStructType = as(payloadType) ) - { - // A structure type is a nice recursive case: we simply - // want to extract each of its field recursively, and - // then construct a fresh value of the `struct` type. - - // In all of the cases of this function we expect/require - // there to be complete type layout information for the - // types involved. - // - auto structTypeLayout = as(payloadTypeLayout); - SLANG_ASSERT(structTypeLayout); - - // We are going to emit code to extract each of the fields - // and collect them to use as operands to a `makeStruct`. - // - List fieldVals; - - // We need to walk over the fields in the order the IR expects them - UInt fieldCounter = 0; - for( auto irField : irStructType->getFields() ) - { - IRType* fieldType = irField->getFieldType(); - - // TODO: We need to confirm/enforce that the fields of the - // IR struct and the fields of the layout still align. - // - UInt fieldIndex = fieldCounter++; - auto fieldLayout = structTypeLayout->fields[fieldIndex]; - auto fieldTypeLayout = fieldLayout->getTypeLayout(); - - // The offset of the field can be computed from the base - // offset passed in, plus the reflection data for the field. - // - UInt64 fieldOffset = payloadOffset; - if(auto resInfo = fieldLayout->FindResourceInfo(LayoutResourceKind::Uniform)) - fieldOffset += resInfo->index; - - // We make a recursive call to extract each field, expecting - // that this will bottom out eventually. - // - IRInst* fieldVal = extractPayload( - taggedUnionInfo, - taggedUnionVal, - fieldType, - fieldTypeLayout, - fieldOffset); - fieldVals.add(fieldVal); - } - - // The final value is then just a new struct constructed from - // the extracted field values. - // - auto payloadVal = builder->emitMakeStruct(irStructType, fieldVals); - return payloadVal; - } - else if( auto vecType = as(payloadType) ) - { - auto elementType = vecType->getElementType(); - - // We expect that by the time we are desugaring union types - // all vector types have literal constant values for their - // element count. - // - auto elementCountVal = vecType->getElementCount(); - auto elementCountConst = as(elementCountVal); - SLANG_ASSERT(elementCountConst); - UInt elementCount = UInt(elementCountConst->getValue()); - - // HACK: There is currently no `VectorTypeLayout` and thus - // no way to query the layout of the elements of a vector - // type. Until that gets added we will kludge things here. - // - TypeLayout* elementTypeLayout = nullptr; - size_t elementSize = 0; - if(auto resInfo = payloadTypeLayout->FindResourceInfo(LayoutResourceKind::Uniform)) - elementSize = resInfo->count.getFiniteValue() / elementCount; - - // Similar to the `struct` case above, we will extract a - // value for each element of the vector, and then use - // `makeVector` to construct the result value. - // - List elementVals; - for(UInt ii = 0; ii < elementCount; ++ii) - { - auto elementVal = extractPayload( - taggedUnionInfo, - taggedUnionVal, - elementType, - elementTypeLayout, - payloadOffset + ii*elementSize); - elementVals.add(elementVal); - } - return builder->emitMakeVector(vecType, elementVals); - } - else if( auto matType = as(payloadType) ) - { - SLANG_UNIMPLEMENTED_X("matrix in union type"); - } - else if( auto arrayType = as(payloadType) ) - { - SLANG_UNIMPLEMENTED_X("array in union type"); - } - else - { - // If none of the above cases match, then we assume that - // we have an individual scalar field that we need to fetch. - // - UInt64 payloadSize = 0; - if( auto resInfo = payloadTypeLayout->FindResourceInfo(LayoutResourceKind::Uniform) ) - { - // TODO: somebody before this point should generate an error if - // we have a `union` type that contains a potentially unbounded - // amount of data. - // - payloadSize = resInfo->count.getFiniteValue(); - } - - if( payloadSize != 4 ) - { - // TODO: We should handle the case of 64-bit fields by fetching - // two `uint` values to form a `uint2`, and then using an - // appropriate bit-cast to get from `uint2` to, e.g., `double`. - // - // The case of 16-bit and smaller fields is more troublesome, but - // in the worst case we can load a `uint` and then use bitwise - // ops to extract what we need before bitcasting. - // - // The right long-term solution is for downstream languages to have - // better support for raw memory addressing. - - SLANG_UNIMPLEMENTED_X("leaf union field with size other than 4 bytes"); - } - - // We know that we want to fetch a value of size `payloadSize`, and - // we have a known base value and an initial offset into it. - // - IRInst* baseVal = taggedUnionVal; - UInt64 offset = payloadOffset; - - // We are going to refine our `baseVal` and `offset` as we go, by - // trying to narrow down the data we will access in the `struct` - // type that will provide storage for the union. - // - // The first thing we want to check is if the value sits in the - // "bulk" part of the storage, or the "rest." - // - UInt64 bulkSize = taggedUnionInfo->bulkSize; - if( offset < bulkSize ) - { - // If the value starts in the bulk area, then the whole - // thing had better fit in the bulk area. The 16-byte - // granularity rules for constant buffers should ensure - // this property for us on current targets. - // - SLANG_ASSERT(offset + payloadSize <= bulkSize); - - // Since we know we'll be accessing the bulk storage, - // we will extract it here. The extracted field will - // be our new base value, but the `offset` doesn't need - // to be updated since the bulk field sits at offset 0. - // - baseVal = builder->emitFieldExtract( - taggedUnionInfo->bulkFieldType, - baseVal, - taggedUnionInfo->bulkFieldKey); - - // The bulk storage could be an array, if there are 32 - // or more bytes of bulk storage. - // - if( auto baseArrayType = as(baseVal->getDataType()) ) - { - // If an array was allocated for bulk storage then - // our leaf value resides entirely within a single - // element (due to constant buffer layout rules), - // and so we will fetch the appropriate element here. - // - // We will change our `baseVal` to the extracted element, - // and then also adjust our `offset` to be relative - // to that element. - // - size_t bulkElementSize = 16; - auto index = offset / bulkElementSize; - baseVal = builder->emitElementExtract( - baseArrayType->getElementType(), - baseVal, - builder->getIntValue(builder->getIntType(), index)); - offset -= index*bulkElementSize; - } - } - else - { - // If the offset of the field we want is past the end of - // the bulk field then it must sit inside of the rest field, - // and we'll extract it here. This establishes a new - // base value, and we adjust the `offset` to be relative - // to the rest field (which starts at an offset equal to `bulkSize`). - // - baseVal = builder->emitFieldExtract( - taggedUnionInfo->restFieldType, - baseVal, - taggedUnionInfo->restFieldKey); - offset -= bulkSize; - } - - // We've now extracted a field that could be either a scalar or - // a vector, and we have an offset into it. In the case where - // the base value is a vector, we will extract out the appropriate - // element. - // - if( auto baseVecType = as(baseVal->getDataType()) ) - { - size_t vecElementSize = 4; - auto index = offset / vecElementSize; - baseVal = builder->emitElementExtract( - baseVecType->getElementType(), - baseVal, - builder->getIntValue(builder->getIntType(), index)); - offset -= index*vecElementSize; - } - - // At this point, our `baseVal` should be a single `uint`, and - // it should provide the storage for the exact thing we wanted - // to access (under the assumption that we always fetch 4 bytes - // on 4-byte alignment). - // - IRInst* payloadVal = baseVal; - SLANG_ASSERT(offset == 0); - - // TODO: we could imagine adding logic here to handle types less - // than 4 bytes in size by shifting and masking the value we - // just loaded. - - // The payload field we were trying to extract might have a type - // other than `uint`, and to handle that case we need to employ - // a bit-cast to get to the desired type. - // - if( payloadVal->getDataType() != payloadType ) - { - payloadVal = builder->emitBitCast( - payloadType, - payloadVal); - } - return payloadVal; - } - } - - // All of the logic so far as assumed we can just call `getTaggedUnionInfo` - // and have easy access to all the required information and the - // synthesized replacement type. - // - TaggedUnionInfo* getTaggedUnionInfo(IRType* type) - { - // The big picture is fairly simple: we will lazily build and - // memoize the information about tagged unions. - // - { - TaggedUnionInfo* info = nullptr; - if(mapIRTypeToTaggedUnionInfo.TryGetValue(type, info)) - return info; - } - - // When we don't find information in our memo-cache, we - // will construct it and add it to both the memo-cache - // *and* a global list of all tagged unions encountered, - // so that we can replacement them later. - // - auto info = createTaggedUnionInfo(type); - mapIRTypeToTaggedUnionInfo.Add(type, info.Ptr()); - taggedUnionInfos.add(info); - - return info; - } - - // The actual logic for creating a `TaggedUnionInfo` is relatively - // straightforward once we've decided what information we need. - // - RefPtr createTaggedUnionInfo(IRType* type) - { - // We expect that any type used as an operation to one of the - // `extractTaggedUnion*` operations must be an IR tagged union. - // - // Note: If/when we ever expose `union`s to user and allow - // then to create *generic* tagged union types it might appear - // that this needs to be changed to account for a `specialize` - // instruction in place of a concrete tagged union, but in - // practice this pass needs to be performed late enough that - // any such generic should be fully specialized. - // - auto taggedUnionType = as(type); - SLANG_ASSERT(taggedUnionType); - - RefPtr info = new TaggedUnionInfo(); - info->taggedUnionType = taggedUnionType; - - // We are going to create an instruction to replace `type`, - // and thus will be placing it into the same parent. - // - auto builder = getBuilder(); - builder->setInsertBefore(type); - - // A tagged union type will be replaced with an ordinary - // `struct` type with fields to store all the relevant - // data from any of the cases, plus a tag field. - // - auto structType = builder->createStructType(); - info->replacementInst = structType; - - // We require/expect the earlier code generation steps to have - // associated a layout with every tagged union that appears in - // the code. - // - auto layoutDecoration = type->findDecoration(); - SLANG_ASSERT(layoutDecoration); - auto layout = layoutDecoration->getLayout(); - SLANG_ASSERT(layout); - auto taggedUnionTypeLayout = as(layout); - SLANG_ASSERT(taggedUnionTypeLayout); - - info->taggedUnionTypeLayout = taggedUnionTypeLayout; - - // The size of the "payload" for the different cases (everything but - // the tag) is taken to be the offset of the tag itself. - // - // TODO: this might be inaccurate if the payload size isn't a multiple - // of the tag's alignment. We should deal with that when/if we support - // types smaller than 4 bytes in unions. - // - auto payloadSize = taggedUnionTypeLayout->tagOffset.getFiniteValue(); - - // We are going to be construction IR code that makes use of the `int` - // and `uint` types in several cases, so we go ahead and get a pointer - // to those types here. - // - auto intType = getBuilder()->getIntType(); - auto uintType = getBuilder()->getBasicType(BaseType::UInt); - - // For now we will use a simple stragegy for how we encode a union, - // which depends only on the total number of bytes needed, and not - // on the makeup of the values being stored. - // - // We will start by allocating one or more `uint4` values (in an - // array for the "or more" case) to hold the bulk of any large - // payload value. - // - size_t bulkVectorSize = 16; // Note: assuming `sizeof(uint4) == 16` on all targets - auto bulkVectorCount = payloadSize / bulkVectorSize; - auto bulkFieldSize = bulkVectorCount * bulkVectorSize; - if( bulkVectorCount ) - { - IRType* bulkFieldType = builder->getVectorType( - uintType, - builder->getIntValue(intType, 4)); - - if( bulkVectorCount > 1 ) - { - bulkFieldType = builder->getArrayType( - bulkFieldType, - builder->getIntValue(intType, bulkVectorCount)); - } - - auto bulkFieldKey = builder->createStructKey(); - builder->createStructField(structType, bulkFieldKey, bulkFieldType); - - info->bulkFieldKey = bulkFieldKey; - info->bulkFieldType = bulkFieldType; - } - info->bulkSize = bulkFieldSize; - - // The rest of the data (anything that doesn't fit in the bulk field), - // will get allocated into a single scalar or vector of `uint`. - // - auto restSize = payloadSize - bulkFieldSize; - if( restSize ) - { - size_t restElementSize = 4; // assuming `sizeof(uint) == 4` on all targets - auto restElementCount = restSize / restElementSize; - auto restFieldSize = restElementSize * restElementCount; - SLANG_ASSERT(restFieldSize == restSize); // Note: all our current targets have minimum 4-byte storage granularity - - IRType* restFieldType = uintType; - if( restElementCount > 1 ) - { - restFieldType = builder->getVectorType( - restFieldType, - builder->getIntValue(intType, restElementCount)); - } - - auto restFieldKey = builder->createStructKey(); - builder->createStructField(structType, restFieldKey, restFieldType); - - info->restFieldKey = restFieldKey; - info->restFieldType = restFieldType; - info->restSize = restFieldSize; - } - - // Finally, we add a field to represent the tag. - // - auto tagFieldType = uintType; - auto tagFieldKey = builder->createStructKey(); - builder->createStructField(structType, tagFieldKey, tagFieldType); - - info->tagFieldKey = tagFieldKey; - - return info; - } -}; - -void desugarUnionTypes( - IRModule* module) -{ - DesugarUnionTypesContext context; - context.module = module; - - context.processModule(); -} - -} // namespace Slang diff --git a/source/slang/ir-union.h b/source/slang/ir-union.h deleted file mode 100644 index 58de4e81e..000000000 --- a/source/slang/ir-union.h +++ /dev/null @@ -1,18 +0,0 @@ -// ir-union.h -#pragma once - -namespace Slang { - -struct IRModule; - - /// Desugar any unions types, and code using them, in `module` - /// - /// Union types will be replaced with ordinary `struct` types that store - /// the data of the underlying type as a "bag of bits" and references - /// to cases of the union will be replaced with logic to extract the - /// relevant bits. - /// -void desugarUnionTypes( - IRModule* module); - -} // namespace Slang diff --git a/source/slang/ir-validate.cpp b/source/slang/ir-validate.cpp deleted file mode 100644 index 9564873b1..000000000 --- a/source/slang/ir-validate.cpp +++ /dev/null @@ -1,207 +0,0 @@ -// ir-validate.cpp -#include "ir-validate.h" - -#include "ir.h" -#include "ir-insts.h" - -namespace Slang -{ - struct IRValidateContext - { - // The IR module we are validating. - IRModule* module; - - // A diagnostic sink to send errors to if anything is invalid. - DiagnosticSink* sink; - - DiagnosticSink* getSink() { return sink; } - - // A set of instructions we've seen, to help confirm that - // values are defined before they are used in a given block. - HashSet seenInsts; - }; - - void validateIRInst( - IRValidateContext* context, - IRInst* inst); - - void validate(IRValidateContext* context, bool condition, IRInst* inst, char const* message) - { - if (!condition) - { - context->getSink()->diagnose(inst, Diagnostics::irValidationFailed, message); - } - } - - void validateIRInstChildren( - IRValidateContext* context, - IRInst* parent) - { - IRInst* prevChild = nullptr; - for(auto child : parent->getDecorationsAndChildren() ) - { - // We need to check the integrity of the parent/next/prev links of - // all of our instructions - validate(context, child->parent == parent, child, "parent link"); - validate(context, child->prev == prevChild, child, "next/prev link"); - - // Recursively validate the instruction itself. - validateIRInst(context, child); - - // Do some extra validation around terminator instructions: - // - // * The last instruction of a block should always be a terminator - // * No other instruction should be a terminator - // - if(as(parent) && (child == parent->getLastDecorationOrChild())) - { - validate(context, as(child) != nullptr, child, "last instruction in block must be terminator"); - } - else - { - validate(context, !as(child), child, "terminator must be last instruction in a block"); - } - - - prevChild = child; - } - } - - void validateIRInstOperand( - IRValidateContext* context, - IRInst* inst, - IRUse* operandUse) - { - // The `IRUse` for the operand had better have `inst` as its user. - validate(context, operandUse->getUser() == inst, inst, "operand user"); - - // The value we are using needs to fit into one of a few cases. - // - // * If the parent of `inst` and of `operand` is the same block, then - // we require that `operand` is defined before `inst` - // - // * If the parents of `inst` and `operand` are both blocks in the - // same functin, then the block defining `operand` must dominate - // the block defining `inst`. - // - // * Otherwise, we simply require that the parent of `operand` be - // an ancestor (transitive parent) of `inst`. - - auto instParent = inst->getParent(); - - auto operandValue = operandUse->get(); - - if( !operandValue ) - { - // A null operand should almost always be an error, but - // we currently have a few cases where this arises. - // - // TODO: plug the leaks. - return; - } - - auto operandParent = operandValue->getParent(); - - if (auto instParentBlock = as(instParent)) - { - if (auto operandParentBlock = as(operandParent)) - { - if (instParentBlock == operandParentBlock) - { - // If `operandValue` precedes `inst`, then we should - // have already seen it, because we scan parent instructions - // in order. - validate(context, context->seenInsts.Contains(operandValue), inst, "def must come before use in same block"); - return; - } - - auto instFunc = instParentBlock->getParent(); - auto operandFunc = operandParentBlock->getParent(); - if (instFunc == operandFunc) - { - // The two instructions are defined in different blocks of - // the same function (or another value with code). We need - // to validate that `operandParentBlock` dominates `instParentBlock`. - // - // TODO: implement this validation once we compute dominator trees. - // - // validate(context, operandParentBlock->dominates(instParentBlock), inst, "def must dominate use"); - return; - } - } - } - - // If the special cases above did not trigger, then either the two values - // are nested in the same parent, but that parent isn't a block, or they - // are nested in distinct parents, and those parents aren't both children - // of a function. - // - // In either case, we need to enforce that the parent of `operand` needs - // to be an ancestor of `inst`. - // - for (auto pp = instParent; pp; pp = pp->getParent()) - { - if (pp == operandParent) - return; - } - // - // We failed to find `operandParent` while walking the ancestors of `inst`, - // so something had gone wrong. - validate(context, false, inst, "def must be ancestor of use"); - } - - void validateIRInstOperands( - IRValidateContext* context, - IRInst* inst) - { - if(inst->getFullType()) - validateIRInstOperand(context, inst, &inst->typeUse); - - UInt operandCount = inst->getOperandCount(); - for (UInt ii = 0; ii < operandCount; ++ii) - { - validateIRInstOperand(context, inst, inst->getOperands() + ii); - } - } - - void validateIRInst( - IRValidateContext* context, - IRInst* inst) - { - // Validate that any operands of the instruction are used appropriately - validateIRInstOperands(context, inst); - context->seenInsts.Add(inst); - - // If `inst` is itself a parent instruction, then we need to recursively - // validate its children. - validateIRInstChildren(context, inst); - } - - void validateIRModule(IRModule* module, DiagnosticSink* sink) - { - IRValidateContext contextStorage; - IRValidateContext* context = &contextStorage; - context->module = module; - context->sink = sink; - - auto moduleInst = module->moduleInst; - - validate(context, moduleInst != nullptr, moduleInst, "module instruction"); - validate(context, moduleInst->parent == nullptr, moduleInst, "module instruction parent"); - validate(context, moduleInst->prev == nullptr, moduleInst, "module instruction prev"); - validate(context, moduleInst->next == nullptr, moduleInst, "module instruction next"); - - validateIRInst(context, module->moduleInst); - } - - void validateIRModuleIfEnabled( - CompileRequestBase* compileRequest, - IRModule* module) - { - if (!compileRequest->shouldValidateIR) - return; - - auto sink = compileRequest->getSink(); - validateIRModule(module, sink); - } -} diff --git a/source/slang/ir-validate.h b/source/slang/ir-validate.h deleted file mode 100644 index 1cb30961d..000000000 --- a/source/slang/ir-validate.h +++ /dev/null @@ -1,35 +0,0 @@ -// ir-validate.h -#pragma once - -namespace Slang -{ - class CompileRequestBase; - class DiagnosticSink; - struct IRModule; - - - // Validate that an IR module obeys the invariants we need to enforce. - // For example: - // - // * Confirm that linked lists for children and for use-def chains are consistent - // (e.g., x.next.prev == x) - // - // * Confirm that parent/child relationships are correct (e.g., if is `x` is in - // `y.children`, then `x.parent == y` - // - // * Confirm that every operand of an instruction is valid to reference (i.e., it - // must either be defined earlier in the same block, in a different block that - // dominates the current one, or in a parent instruction of the block. - // - // * Confirm that every block ends with a terminator, and there are no terminators - // elsewhere in a block. - // - // * Confirm that all the parameters of a block come before any "ordinary" instructions. - void validateIRModule(IRModule* module, DiagnosticSink* sink); - - // A wrapper that calls `validateIRModule` only when IR validation is enabled - // for the given compile request. - void validateIRModuleIfEnabled( - CompileRequestBase* compileRequest, - IRModule* module); -} diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp deleted file mode 100644 index 1045b4884..000000000 --- a/source/slang/ir.cpp +++ /dev/null @@ -1,4511 +0,0 @@ -// ir.cpp -#include "ir.h" -#include "ir-insts.h" - -#include "../core/basic.h" - -#include "mangle.h" - -namespace Slang -{ - struct IRSpecContext; - - IRInst* cloneGlobalValueWithLinkage( - IRSpecContext* context, - IRInst* originalVal, - IRLinkageDecoration* originalLinkage); - - struct IROpMapEntry - { - IROp op; - IROpInfo info; - }; - - // TODO: We should ideally be speeding up the name->inst - // mapping by using a dictionary, or even by pre-computing - // a hash table to be stored as a `static const` array. - // - // NOTE! That this array is now constructed in such a way that looking up - // an entry from an op is fast, by keeping blocks of main, and pseudo ops in same order - // as the ops themselves. Care must be taken to keep this constraint. - static const IROpMapEntry kIROps[] = - { - - // Main ops in order -#define INST(ID, MNEMONIC, ARG_COUNT, FLAGS) \ - { kIROp_##ID, { #MNEMONIC, ARG_COUNT, FLAGS, } }, -#include "ir-inst-defs.h" - - // Pseudo ops -#define INST(ID, MNEMONIC, ARG_COUNT, FLAGS) /* empty */ -#define PSEUDO_INST(ID) \ - { kIRPseudoOp_##ID, { #ID, 0, 0 } }, - - // First is 'invalid' - { kIROp_Invalid,{ "invalid", 0, 0 } }, - // Then all the other psuedo ops -#include "ir-inst-defs.h" - - }; - - IROpInfo getIROpInfo(IROp opIn) - { - const int op = opIn & kIROpMeta_PseudoOpMask; - if ((op & kIROpMeta_IsPseudoOp) && op < kIRPseudoOp_LastPlusOne) - { - // It's a pseudo op - const int index = op - kIRPseudoOp_First; - // Pseudo ops start from kIROpcount - const auto& entry = kIROps[kIROpCount + index]; - SLANG_ASSERT(entry.op == op); - return entry.info; - } - else if (op < kIROpCount) - { - // It's a main op - const auto& entry = kIROps[op]; - SLANG_ASSERT(entry.op == op); - return entry.info; - } - - // Don't know what this is - SLANG_ASSERT(!"Invalid op"); - SLANG_ASSERT(kIROps[kIROpCount].op == kIROp_Invalid); - return kIROps[kIROpCount].info; - } - - IROp findIROp(const UnownedStringSlice& name) - { - for (auto ee : kIROps) - { - if (name == ee.info.name) - return ee.op; - } - - return IROp(kIROp_Invalid); - } - - - - // - - void IRUse::debugValidate() - { -#ifdef _DEBUG - auto uv = this->usedValue; - if(!uv) - { - assert(!nextUse); - assert(!prevLink); - return; - } - - auto pp = &uv->firstUse; - for(auto u = uv->firstUse; u;) - { - assert(u->prevLink == pp); - - pp = &u->nextUse; - u = u->nextUse; - } -#endif - } - - void IRUse::init(IRInst* u, IRInst* v) - { - clear(); - - user = u; - usedValue = v; - if(v) - { - nextUse = v->firstUse; - prevLink = &v->firstUse; - - if(nextUse) - { - nextUse->prevLink = &this->nextUse; - } - - v->firstUse = this; - } - - debugValidate(); - } - - void IRUse::set(IRInst* uv) - { - init(user, uv); - } - - void IRUse::clear() - { - // This `IRUse` is part of the linked list - // of uses for `usedValue`. - - debugValidate(); - - if (usedValue) - { - auto uv = usedValue; - - *prevLink = nextUse; - if(nextUse) - { - nextUse->prevLink = prevLink; - } - - user = nullptr; - usedValue = nullptr; - nextUse = nullptr; - prevLink = nullptr; - - if(uv->firstUse) - uv->firstUse->debugValidate(); - } - } - - // IRInstListBase - - void IRInstListBase::Iterator::operator++() - { - if (inst) - { - inst = inst->next; - } - } - - IRInstListBase::Iterator IRInstListBase::begin() { return Iterator(first); } - IRInstListBase::Iterator IRInstListBase::end() { return Iterator(last ? last->next : nullptr); } - - // - - IRUse* IRInst::getOperands() - { - // We assume that *all* instructions are laid out - // in memory such that their arguments come right - // after the first `sizeof(IRInst)` bytes. - // - // TODO: we probably need to be careful and make - // this more robust. - - return (IRUse*)(this + 1); - } - - IRDecoration* IRInst::findDecorationImpl(IROp decorationOp) - { - for(auto dd : getDecorations()) - { - if(dd->op == decorationOp) - return dd; - } - return nullptr; - } - - // IRConstant - - IRIntegerValue GetIntVal(IRInst* inst) - { - switch (inst->op) - { - default: - SLANG_UNEXPECTED("needed a known integer value"); - UNREACHABLE_RETURN(0); - - case kIROp_IntLit: - return static_cast(inst)->value.intVal; - break; - } - } - - // IRParam - - IRParam* IRParam::getNextParam() - { - return as(getNextInst()); - } - - // IRArrayTypeBase - - IRInst* IRArrayTypeBase::getElementCount() - { - if (auto arrayType = as(this)) - return arrayType->getElementCount(); - - return nullptr; - } - - // IRPtrTypeBase - - IRType* tryGetPointedToType( - IRBuilder* builder, - IRType* type) - { - if( auto rateQualType = as(type) ) - { - type = rateQualType->getDataType(); - } - - // The "true" pointers and the pointer-like stdlib types are the easy cases. - if( auto ptrType = as(type) ) - { - return ptrType->getValueType(); - } - else if( auto ptrLikeType = as(type) ) - { - return ptrLikeType->getElementType(); - } - // - // A more interesting case arises when we have a `BindExistentials, ...>` - // where `P` is a pointer(-like) type. - // - else if( auto bindExistentials = as(type) ) - { - // We know that `BindExistentials` won't introduce its own - // existential type parameters, nor will any of the pointer(-like) - // type constructors `P`. - // - // Thus we know that the type that is pointed to should be - // the same as `BindExistentials`. - // - auto baseType = bindExistentials->getBaseType(); - if( auto baseElementType = tryGetPointedToType(builder, baseType) ) - { - UInt existentialArgCount = bindExistentials->getExistentialArgCount(); - List existentialArgs; - for( UInt ii = 0; ii < existentialArgCount; ++ii ) - { - existentialArgs.add(bindExistentials->getExistentialArg(ii)); - } - return builder->getBindExistentialsType( - baseElementType, - existentialArgCount, - existentialArgs.getBuffer()); - } - } - - // TODO: We may need to handle other cases here. - - return nullptr; - } - - - // IRBlock - - IRParam* IRBlock::getLastParam() - { - IRParam* param = getFirstParam(); - if (!param) return nullptr; - - while (auto nextParam = param->getNextParam()) - param = nextParam; - - return param; - } - - void IRBlock::addParam(IRParam* param) - { - // If there are any existing parameters, - // then insert after the last of them. - // - if (auto lastParam = getLastParam()) - { - param->insertAfter(lastParam); - } - // - // Otherwise, if there are any existing - // "ordinary" instructions, insert before - // the first of them. - // - else if(auto firstOrdinary = getFirstOrdinaryInst()) - { - param->insertBefore(firstOrdinary); - } - // - // Otherwise the block currently has neither - // parameters nor orindary instructions, - // so we can safely insert at the end of - // the list of (raw) children. - // - else - { - param->insertAtEnd(this); - } - } - - IRInst* IRBlock::getFirstOrdinaryInst() - { - // Find the last parameter (if any) of the block - auto lastParam = getLastParam(); - if (lastParam) - { - // If there is a last parameter, then the - // instructions after it are the ordinary - // instructions. - return lastParam->getNextInst(); - } - else - { - // If there isn't a last parameter, then - // there must not have been *any* parameters, - // and so the first instruction in the block - // is also the first ordinary one. - return getFirstInst(); - } - } - - IRInst* IRBlock::getLastOrdinaryInst() - { - // Under normal circumstances, the last instruction - // in the block is also the last ordinary instruction. - // However, there is the special case of a block with - // only parameters (which might happen as a temporary - // state while we are building IR). - auto inst = getLastInst(); - - // If the last instruction is a parameter, then - // there are no ordinary instructions, so the last - // one is a null pointer. - if (as(inst)) - return nullptr; - - // Otherwise the last instruction is the last "ordinary" - // instruction as well. - return inst; - } - - - // The predecessors of a block should all show up as users - // of its value, so rather than explicitly store the CFG, - // we will recover it on demand from the use-def information. - // - // Note: we are really iterating over incoming/outgoing *edges* - // for a block, because there might be multiple uses of a block, - // if more than one way of an N-way branch targets the same block. - - // Get the list of successor blocks for an instruction, - // which we expect to be the last instruction in a block. - static IRBlock::SuccessorList getSuccessors(IRInst* terminator) - { - // If the block somehow isn't terminated, then - // there is no way to read its successors, so - // we return an empty list. - if (!terminator || !as(terminator)) - return IRBlock::SuccessorList(nullptr, nullptr); - - // Otherwise, based on the opcode of the terminator - // instruction, we will build up our list of uses. - IRUse* begin = nullptr; - IRUse* end = nullptr; - UInt stride = 1; - - auto operands = terminator->getOperands(); - switch (terminator->op) - { - case kIROp_ReturnVal: - case kIROp_ReturnVoid: - case kIROp_Unreachable: - case kIROp_MissingReturn: - case kIROp_discard: - break; - - case kIROp_unconditionalBranch: - case kIROp_loop: - // unconditonalBranch - begin = operands + 0; - end = begin + 1; - break; - - case kIROp_conditionalBranch: - case kIROp_ifElse: - // conditionalBranch - begin = operands + 1; - end = begin + 2; - break; - - case kIROp_Switch: - // switch ... - begin = operands + 2; - - // TODO: this ends up point one *after* the "one after the end" - // location, so we should really change the representation - // so that we don't need to form this pointer... - end = operands + terminator->getOperandCount() + 1; - stride = 2; - break; - - default: - SLANG_UNEXPECTED("unhandled terminator instruction"); - UNREACHABLE_RETURN(IRBlock::SuccessorList(nullptr, nullptr)); - } - - return IRBlock::SuccessorList(begin, end, stride); - } - - static IRUse* adjustPredecessorUse(IRUse* use) - { - // We will search until we either find a - // suitable use, or run out of uses. - for (;use; use = use->nextUse) - { - // We only want to deal with uses that represent - // a "sucessor" operand to some terminator instruction. - // We will re-use the logic for getting the successor - // list from such an instruction. - - auto successorList = getSuccessors((IRInst*) use->getUser()); - - if(use >= successorList.begin_ - && use < successorList.end_) - { - UInt index = (use - successorList.begin_); - if ((index % successorList.stride) == 0) - { - // This use is in the range of the sucessor list, - // and so it represents a real edge between - // blocks. - return use; - } - } - } - - // If we ran out of uses, then we are at the end - // of the list of incoming edges. - return nullptr; - } - - IRBlock::PredecessorList IRBlock::getPredecessors() - { - // We want to iterate over the predecessors of this block. - // First, we resign ourselves to iterating over the - // incoming edges, rather than the blocks themselves. - // This might sound like a trival distinction, but it is - // possible for there to be multiple edges between two - // blocks (as for a `switch` with multiple cases that - // map to the same code). Any client that wants just - // the unique predecessor blocks needs to deal with - // the deduplication themselves. - // - // Next, we note that for any predecessor edge, there will - // be a use of this block in the terminator instruction of - // the predecessor. We basically just want to iterate over - // the users of this block, then, but we need to be careful - // to rule out anything that doesn't actually represent - // an edge. The `adjustPredecessorUse` function will be - // used to search for a use that actually represents an edge. - - return PredecessorList( - adjustPredecessorUse(firstUse)); - } - - UInt IRBlock::PredecessorList::getCount() - { - UInt count = 0; - for (auto ii : *this) - { - (void)ii; - count++; - } - return count; - } - - bool IRBlock::PredecessorList::isEmpty() - { - return !(begin() != end()); - } - - - void IRBlock::PredecessorList::Iterator::operator++() - { - if (!use) return; - use = adjustPredecessorUse(use->nextUse); - } - - IRBlock* IRBlock::PredecessorList::Iterator::operator*() - { - if (!use) return nullptr; - return (IRBlock*)use->getUser()->parent; - } - - IRBlock::SuccessorList IRBlock::getSuccessors() - { - // The successors of a block will all be listed - // as operands of its terminator instruction. - // Depending on the terminator, we might have - // different numbers of operands to deal with. - // - // (We might also have to deal with a "stride" - // in the case where the basic-block operands - // are mixed up with non-block operands) - - auto terminator = getLastInst(); - return Slang::getSuccessors(terminator); - } - - UInt IRBlock::SuccessorList::getCount() - { - UInt count = 0; - for (auto ii : *this) - { - (void)ii; - count++; - } - return count; - } - - void IRBlock::SuccessorList::Iterator::operator++() - { - use += stride; - } - - IRBlock* IRBlock::SuccessorList::Iterator::operator*() - { - return (IRBlock*)use->get(); - } - - UInt IRUnconditionalBranch::getArgCount() - { - switch(op) - { - case kIROp_unconditionalBranch: - return getOperandCount() - 1; - - case kIROp_loop: - return getOperandCount() - 3; - - default: - SLANG_UNEXPECTED("unhandled unconditional branch opcode"); - UNREACHABLE_RETURN(0); - } - } - - IRUse* IRUnconditionalBranch::getArgs() - { - switch(op) - { - case kIROp_unconditionalBranch: - return getOperands() + 1; - - case kIROp_loop: - return getOperands() + 3; - - default: - SLANG_UNEXPECTED("unhandled unconditional branch opcode"); - UNREACHABLE_RETURN(0); - } - } - - IRInst* IRUnconditionalBranch::getArg(UInt index) - { - return getArgs()[index].usedValue; - } - - IRParam* IRGlobalValueWithParams::getFirstParam() - { - auto entryBlock = getFirstBlock(); - if(!entryBlock) return nullptr; - - return entryBlock->getFirstParam(); - } - - IRParam* IRGlobalValueWithParams::getLastParam() - { - auto entryBlock = getFirstBlock(); - if(!entryBlock) return nullptr; - - return entryBlock->getLastParam(); - } - - IRInstList IRGlobalValueWithParams::getParams() - { - auto entryBlock = getFirstBlock(); - if(!entryBlock) return IRInstList(); - - return entryBlock->getParams(); - } - - - // IRFunc - - IRType* IRFunc::getResultType() { return getDataType()->getResultType(); } - UInt IRFunc::getParamCount() { return getDataType()->getParamCount(); } - IRType* IRFunc::getParamType(UInt index) { return getDataType()->getParamType(index); } - - void IRGlobalValueWithCode::addBlock(IRBlock* block) - { - block->insertAtEnd(this); - } - - void fixUpFuncType(IRFunc* func) - { - SLANG_ASSERT(func); - - auto irModule = func->getModule(); - SLANG_ASSERT(irModule); - - SharedIRBuilder sharedBuilder; - sharedBuilder.module = irModule; - - IRBuilder builder; - builder.sharedBuilder = &sharedBuilder; - - builder.setInsertBefore(func); - - List paramTypes; - for(auto param : func->getParams()) - { - paramTypes.add(param->getFullType()); - } - - auto resultType = func->getResultType(); - - auto funcType = builder.getFuncType(paramTypes, resultType); - builder.setDataType(func, funcType); - } - - // - - bool isTerminatorInst(IROp op) - { - switch (op) - { - default: - return false; - - case kIROp_ReturnVal: - case kIROp_ReturnVoid: - case kIROp_unconditionalBranch: - case kIROp_conditionalBranch: - case kIROp_loop: - case kIROp_ifElse: - case kIROp_discard: - case kIROp_Switch: - case kIROp_Unreachable: - case kIROp_MissingReturn: - return true; - } - } - - bool isTerminatorInst(IRInst* inst) - { - if (!inst) return false; - return isTerminatorInst(inst->op); - } - - // - - IRBlock* IRBuilder::getBlock() - { - return as(insertIntoParent); - } - - // Get the current function (or other value with code) - // that we are inserting into (if any). - IRGlobalValueWithCode* IRBuilder::getFunc() - { - auto pp = insertIntoParent; - if (auto block = as(pp)) - { - pp = pp->getParent(); - } - return as(pp); - } - - - void IRBuilder::setInsertInto(IRInst* insertInto) - { - insertIntoParent = insertInto; - insertBeforeInst = nullptr; - } - - void IRBuilder::setInsertBefore(IRInst* insertBefore) - { - SLANG_ASSERT(insertBefore); - insertIntoParent = insertBefore->parent; - insertBeforeInst = insertBefore; - } - - - // Add an instruction into the current scope - void IRBuilder::addInst( - IRInst* inst) - { - if(insertBeforeInst) - { - inst->insertBefore(insertBeforeInst); - } - else if (insertIntoParent) - { - inst->insertAtEnd(insertIntoParent); - } - else - { - // Don't append the instruction anywhere - } - } - - // Given two parent instructions, pick the better one to use as as - // insertion location for a "hoistable" instruction. - // - IRInst* mergeCandidateParentsForHoistableInst(IRInst* left, IRInst* right) - { - // If the candidates are both the same, then who cares? - if(left == right) return left; - - // If either `left` or `right` is a block, then we need to be - // a bit careful, because blocks can see other values just using - // the dominance relationship, without a direct parent-child relationship. - // - // First, check if each of `left` and `right` is a block. - // - auto leftBlock = as(left); - auto rightBlock = as(right); - // - // As a special case, if both of these are blocks in the same parent, - // then we need to pick between them based on dominance. - // - if (leftBlock && rightBlock && (leftBlock->getParent() == rightBlock->getParent())) - { - // We assume that the order of basic blocks in a function is compatible - // with the dominance relationship (that is, if A dominates B, then - // A comes before B in the list of blocks), so it suffices to pick - // the *later* of the two blocks. - // - // There are ways we could try to speed up this search, but no matter - // what it will be O(n) in the number of blocks, unless we build - // an explicit dominator tree, which is infeasible during IR building. - // Thus we just do a simple linear walk here. - // - // We will start at `leftBlock` and walk forward, until either... - // - for (auto ll = leftBlock; ll; ll = ll->getNextBlock()) - { - // ... we see `rightBlock` (in which case `rightBlock` came later), or ... - // - if (ll == rightBlock) return rightBlock; - } - // - // ... we run out of blocks (in which case `leftBlock` came later). - // - return leftBlock; - } - - // - // If the special case above doesn't apply, then `left` or `right` might - // still be a block, but they aren't blocks nested in the same function. - // We will find the first non-block ancestor of `left` and/or `right`. - // This will either be the inst itself (it is isn't a block), or - // its immediate parent (if it *is* a block). - // - auto leftNonBlock = leftBlock ? leftBlock->getParent() : left; - auto rightNonBlock = rightBlock ? rightBlock->getParent() : right; - - // If either side is null, then take the non-null one. - // - if (!leftNonBlock) return right; - if (!rightNonBlock) return left; - - // If the non-block on the left or right is a descendent of - // the other, then that is what we should use. - // - IRInst* parentNonBlock = nullptr; - for (auto ll = leftNonBlock; ll; ll = ll->getParent()) - { - if (ll == rightNonBlock) - { - parentNonBlock = leftNonBlock; - break; - } - } - for (auto rr = rightNonBlock; rr; rr = rr->getParent()) - { - if (rr == leftNonBlock) - { - SLANG_ASSERT(!parentNonBlock || parentNonBlock == leftNonBlock); - parentNonBlock = rightNonBlock; - break; - } - } - - // As a matter of validity in the IR, we expect one - // of the two to be an ancestor (in the non-block case), - // because otherwise we'd be violating the basic dominance - // assumptions. - // - SLANG_ASSERT(parentNonBlock); - - // As a fallback, try to use the left parent as a default - // in case things go badly. - // - if (!parentNonBlock) - { - parentNonBlock = leftNonBlock; - } - - IRInst* parent = parentNonBlock; - - // At this point we've found a non-block parent where we - // could stick things, but we have to fix things up in - // case we should be inserting into a block beneath - // that non-block parent. - if (leftBlock && (parentNonBlock == leftNonBlock)) - { - // We have a left block, and have picked its parent. - - // It cannot be the case that there is a right block - // with the same parent, or else our special case - // would have triggered at the start. - SLANG_ASSERT(!rightBlock || (parentNonBlock != rightNonBlock)); - - parent = leftBlock; - } - else if (rightBlock && (parentNonBlock == rightNonBlock)) - { - // We have a right block, and have picked its parent. - - // We already tested above, so we know there isn't a - // matching situation on the left side. - - parent = rightBlock; - } - - // Okay, we've picked the parent we want to insert into, - // *but* one last special case arises, because an `IRGlobalValueWithCode` - // is not actually a suitable place to insert instructions. - // Furthermore, there is no actual need to insert instructions at - // that scope, because any parameters, etc. are actually attached - // to the block(s) within the function. - if (auto parentFunc = as(parent)) - { - // Insert in the parent of the function (or other value with code). - // We know that the parent must be able to hold ordinary instructions, - // because it was able to hold this `IRGlobalValueWithCode` - parent = parentFunc->getParent(); - } - - return parent; - } - - IRInst* createEmptyInst( - IRModule* module, - IROp op, - int totalArgCount) - { - size_t size = sizeof(IRInst) + (totalArgCount) * sizeof(IRUse); - - SLANG_ASSERT(module); - IRInst* inst = (IRInst*)module->memoryArena.allocateAndZero(size); - - inst->operandCount = uint32_t(totalArgCount); - inst->op = op; - - return inst; - } - - IRInst* createEmptyInstWithSize( - IRModule* module, - IROp op, - size_t totalSizeInBytes) - { - SLANG_ASSERT(totalSizeInBytes >= sizeof(IRInst)); - - SLANG_ASSERT(module); - IRInst* inst = (IRInst*)module->memoryArena.allocateAndZero(totalSizeInBytes); - - inst->operandCount = 0; - inst->op = op; - - return inst; - } - - // Given an instruction that represents a constant, a type, etc. - // Try to "hoist" it as far toward the global scope as possible - // to insert it at a location where it will be maximally visible. - // - void addHoistableInst( - IRBuilder* builder, - IRInst* inst) - { - // Start with the assumption that we would insert this instruction - // into the global scope (the instruction that represents the module) - IRInst* parent = builder->getModule()->getModuleInst(); - - // The above decision might be invalid, because there might be - // one or more operands of the instruction that are defined in - // more deeply nested parents than the global scope. - // - // Therefore, we will scan the operands of the instruction, and - // look at the parents that define them. - // - UInt operandCount = inst->getOperandCount(); - for (UInt ii = 0; ii < operandCount; ++ii) - { - auto operand = inst->getOperand(ii); - if (!operand) - continue; - - auto operandParent = operand->getParent(); - - parent = mergeCandidateParentsForHoistableInst(parent, operandParent); - } - - // We better have ended up with a place to insert. - SLANG_ASSERT(parent); - - // If we have chosen to insert into the same parent that the - // IRBuilder is configured to use, then respect its `insertBeforeInst` - // setting. - if (parent == builder->insertIntoParent) - { - builder->addInst(inst); - return; - } - - // Otherwise, we just want to insert at the end of the chosen parent. - // - // TODO: be careful about inserting after the terminator of a block... - - inst->insertAtEnd(parent); - } - - static void maybeSetSourceLoc( - IRBuilder* builder, - IRInst* value) - { - if(!builder) - return; - - auto sourceLocInfo = builder->sourceLocInfo; - if(!sourceLocInfo) - return; - - // Try to find something with usable location info - for(;;) - { - if(sourceLocInfo->sourceLoc.getRaw()) - break; - - if(!sourceLocInfo->next) - break; - - sourceLocInfo = sourceLocInfo->next; - } - - value->sourceLoc = sourceLocInfo->sourceLoc; - } - - // Create an IR instruction/value and initialize it. - // - // In this case `argCount` and `args` represent the - // arguments *after* the type (which is a mandatory - // argument for all instructions). - template - static T* createInstImpl( - IRModule* module, - IRBuilder* builder, - IROp op, - IRType* type, - UInt fixedArgCount, - IRInst* const* fixedArgs, - UInt varArgListCount, - UInt const* listArgCounts, - IRInst* const* const* listArgs) - { - UInt varArgCount = 0; - for (UInt ii = 0; ii < varArgListCount; ++ii) - { - varArgCount += listArgCounts[ii]; - } - - UInt size = sizeof(IRInst) + (fixedArgCount + varArgCount) * sizeof(IRUse); - if (sizeof(T) > size) - { - size = sizeof(T); - } - - SLANG_ASSERT(module); - T* inst = (T*)module->memoryArena.allocateAndZero(size); - - // TODO: Do we need to run ctor after zeroing? - new(inst)T(); - - inst->operandCount = (uint32_t)(fixedArgCount + varArgCount); - - inst->op = op; - - if (type) - { - inst->typeUse.init(inst, type); - } - - maybeSetSourceLoc(builder, inst); - - auto operand = inst->getOperands(); - - for( UInt aa = 0; aa < fixedArgCount; ++aa ) - { - if (fixedArgs) - { - operand->init(inst, fixedArgs[aa]); - } - operand++; - } - - for (UInt ii = 0; ii < varArgListCount; ++ii) - { - UInt listArgCount = listArgCounts[ii]; - for (UInt jj = 0; jj < listArgCount; ++jj) - { - if (listArgs[ii]) - { - operand->init(inst, listArgs[ii][jj]); - } - else - { - operand->init(inst, nullptr); - } - operand++; - } - } - return inst; - } - - static IRInst* createInstWithSizeImpl( - IRBuilder* builder, - IROp op, - IRType* type, - size_t sizeInBytes) - { - auto module = builder->getModule(); - IRInst* inst = (IRInst*)module->memoryArena.allocate(sizeInBytes); - // Zero only the 'type' - memset(inst, 0, sizeof(IRInst)); - // TODO: Do we need to run ctor after zeroing? - new (inst) IRInst; - - inst->op = op; - if (type) - { - inst->typeUse.init(inst, type); - } - maybeSetSourceLoc(builder, inst); - return inst; - } - - template - static T* createInstImpl( - IRBuilder* builder, - IROp op, - IRType* type, - UInt fixedArgCount, - IRInst* const* fixedArgs, - UInt varArgCount = 0, - IRInst* const* varArgs = nullptr) - { - return createInstImpl( - builder->getModule(), - builder, - op, - type, - fixedArgCount, - fixedArgs, - 1, - &varArgCount, - &varArgs); - } - - template - static T* createInstImpl( - IRBuilder* builder, - IROp op, - IRType* type, - UInt fixedArgCount, - IRInst* const* fixedArgs, - UInt varArgListCount, - UInt const* listArgCount, - IRInst* const* const* listArgs) - { - return createInstImpl( - builder->getModule(), - builder, - op, - type, - fixedArgCount, - fixedArgs, - varArgListCount, - listArgCount, - listArgs); - } - - template - static T* createInst( - IRBuilder* builder, - IROp op, - IRType* type, - UInt argCount, - IRInst* const* args) - { - return createInstImpl( - builder, - op, - type, - argCount, - args); - } - - template - static T* createInst( - IRBuilder* builder, - IROp op, - IRType* type) - { - return createInstImpl( - builder, - op, - type, - 0, - nullptr); - } - - template - static T* createInst( - IRBuilder* builder, - IROp op, - IRType* type, - IRInst* arg) - { - return createInstImpl( - builder, - op, - type, - 1, - &arg); - } - - template - static T* createInst( - IRBuilder* builder, - IROp op, - IRType* type, - IRInst* arg1, - IRInst* arg2) - { - IRInst* args[] = { arg1, arg2 }; - return createInstImpl( - builder, - op, - type, - 2, - &args[0]); - } - - template - static T* createInstWithTrailingArgs( - IRBuilder* builder, - IROp op, - IRType* type, - UInt argCount, - IRInst* const* args) - { - return createInstImpl( - builder, - op, - type, - argCount, - args); - } - - template - static T* createInstWithTrailingArgs( - IRBuilder* builder, - IROp op, - IRType* type, - UInt fixedArgCount, - IRInst* const* fixedArgs, - UInt varArgCount, - IRInst* const* varArgs) - { - return createInstImpl( - builder, - op, - type, - fixedArgCount, - fixedArgs, - varArgCount, - varArgs); - } - - template - static T* createInstWithTrailingArgs( - IRBuilder* builder, - IROp op, - IRType* type, - IRInst* arg1, - UInt varArgCount, - IRInst* const* varArgs) - { - IRInst* fixedArgs[] = { arg1 }; - UInt fixedArgCount = sizeof(fixedArgs) / sizeof(fixedArgs[0]); - - return createInstImpl( - builder, - op, - type, - fixedArgCount, - fixedArgs, - varArgCount, - varArgs); - } - // - - bool operator==(IRInstKey const& left, IRInstKey const& right) - { - if(left.inst->op != right.inst->op) return false; - if(left.inst->getFullType() != right.inst->getFullType()) return false; - if(left.inst->operandCount != right.inst->operandCount) return false; - - auto argCount = left.inst->operandCount; - auto leftArgs = left.inst->getOperands(); - auto rightArgs = right.inst->getOperands(); - for( UInt aa = 0; aa < argCount; ++aa ) - { - if(leftArgs[aa].get() != rightArgs[aa].get()) - return false; - } - - return true; - } - - int IRInstKey::GetHashCode() - { - auto code = Slang::GetHashCode(inst->op); - code = combineHash(code, Slang::GetHashCode(inst->getFullType())); - code = combineHash(code, Slang::GetHashCode(inst->getOperandCount())); - - auto argCount = inst->getOperandCount(); - auto args = inst->getOperands(); - for( UInt aa = 0; aa < argCount; ++aa ) - { - code = combineHash(code, Slang::GetHashCode(args[aa].get())); - } - return code; - } - - UnownedStringSlice IRConstant::getStringSlice() - { - assert(op == kIROp_StringLit); - // If the transitory decoration is set, then this is uses the transitoryStringVal for the text storage. - // This is typically used when we are using a transitory IRInst held on the stack (such that it can be looked up in cached), - // that just points to a string elsewhere, and NOT the typical normal style, where the string is held after the instruction in memory. - // - if(findDecorationImpl(kIROp_TransitoryDecoration)) - { - return UnownedStringSlice(value.transitoryStringVal.chars, value.transitoryStringVal.numChars); - } - else - { - return UnownedStringSlice(value.stringVal.chars, value.stringVal.numChars); - } - } - - bool IRConstant::isValueEqual(IRConstant* rhs) - { - // If they are literally the same thing.. - if (this == rhs) - { - return true; - } - // Check the type and they are the same op & same type - if (op != rhs->op) - { - return false; - } - - switch (op) - { - case kIROp_BoolLit: - case kIROp_FloatLit: - case kIROp_IntLit: - { - SLANG_COMPILE_TIME_ASSERT(sizeof(IRFloatingPointValue) == sizeof(IRIntegerValue)); - // ... we can just compare as bits - return value.intVal == rhs->value.intVal; - } - case kIROp_PtrLit: - { - return value.ptrVal == rhs->value.ptrVal; - } - case kIROp_StringLit: - { - return getStringSlice() == rhs->getStringSlice(); - } - default: break; - } - - SLANG_ASSERT(!"Unhandled type"); - return false; - } - - /// True if constants are equal - bool IRConstant::equal(IRConstant* rhs) - { - // TODO(JS): Only equal if pointer types are identical (to match how getHashCode works below) - return isValueEqual(rhs) && getFullType() == rhs->getFullType(); - } - - int IRConstant::getHashCode() - { - auto code = Slang::GetHashCode(op); - code = combineHash(code, Slang::GetHashCode(getFullType())); - - switch (op) - { - case kIROp_BoolLit: - case kIROp_FloatLit: - case kIROp_IntLit: - { - SLANG_COMPILE_TIME_ASSERT(sizeof(IRFloatingPointValue) == sizeof(IRIntegerValue)); - // ... we can just compare as bits - return combineHash(code, Slang::GetHashCode(value.intVal)); - } - case kIROp_PtrLit: - { - return combineHash(code, Slang::GetHashCode(value.ptrVal)); - } - case kIROp_StringLit: - { - const UnownedStringSlice slice = getStringSlice(); - return combineHash(code, Slang::GetHashCode(slice.begin(), slice.size())); - } - default: - { - SLANG_ASSERT(!"Invalid type"); - return 0; - } - } - } - - static IRConstant* findOrEmitConstant( - IRBuilder* builder, - IRConstant& keyInst) - { - // We now know where we want to insert, but there might - // already be an equivalent instruction in that block. - // - // We will check for such an instruction in a slightly hacky - // way: we will construct a temporary instruction and - // then use it to look up in a cache of instructions. - // The 'fake' instruction is passed in as keyInst. - - IRConstantKey key; - key.inst = &keyInst; - - IRConstant* irValue = nullptr; - if( builder->sharedBuilder->constantMap.TryGetValue(key, irValue) ) - { - // We found a match, so just use that. - return irValue; - } - - // Calculate the minimum object size (ie not including the payload of value) - const size_t prefixSize = SLANG_OFFSET_OF(IRConstant, value); - - switch (keyInst.op) - { - default: - SLANG_UNEXPECTED("missing case for IR constant"); - break; - - case kIROp_BoolLit: - case kIROp_IntLit: - { - irValue = static_cast(createInstWithSizeImpl(builder, keyInst.op, keyInst.getFullType(), prefixSize + sizeof(IRIntegerValue))); - irValue->value.intVal = keyInst.value.intVal; - break; - } - case kIROp_FloatLit: - { - irValue = static_cast(createInstWithSizeImpl(builder, keyInst.op, keyInst.getFullType(), prefixSize + sizeof(IRFloatingPointValue))); - irValue->value.floatVal = keyInst.value.floatVal; - break; - } - case kIROp_PtrLit: - { - irValue = static_cast(createInstWithSizeImpl(builder, keyInst.op, keyInst.getFullType(), prefixSize + sizeof(void*))); - irValue->value.ptrVal = keyInst.value.ptrVal; - break; - } - case kIROp_StringLit: - { - const UnownedStringSlice slice = keyInst.getStringSlice(); - - const size_t sliceSize = slice.size(); - const size_t instSize = prefixSize + offsetof(IRConstant::StringValue, chars) + sliceSize; - - irValue = static_cast(createInstWithSizeImpl(builder, keyInst.op, keyInst.getFullType(), instSize)); - - IRConstant::StringValue& dstString = irValue->value.stringVal; - - dstString.numChars = uint32_t(sliceSize); - // Turn into pointer to avoid warning of array overrun - char* dstChars = dstString.chars; - // Copy the chars - memcpy(dstChars, slice.begin(), sliceSize); - - break; - } - } - - key.inst = irValue; - builder->sharedBuilder->constantMap.Add(key, irValue); - - addHoistableInst(builder, irValue); - - return irValue; - } - - // - - IRInst* IRBuilder::getBoolValue(bool inValue) - { - IRConstant keyInst; - memset(&keyInst, 0, sizeof(keyInst)); - keyInst.op = kIROp_BoolLit; - keyInst.typeUse.usedValue = getBoolType(); - keyInst.value.intVal = IRIntegerValue(inValue); - return findOrEmitConstant(this, keyInst); - } - - IRInst* IRBuilder::getIntValue(IRType* type, IRIntegerValue inValue) - { - IRConstant keyInst; - memset(&keyInst, 0, sizeof(keyInst)); - keyInst.op = kIROp_IntLit; - keyInst.typeUse.usedValue = type; - keyInst.value.intVal = inValue; - return findOrEmitConstant(this, keyInst); - } - - IRInst* IRBuilder::getFloatValue(IRType* type, IRFloatingPointValue inValue) - { - IRConstant keyInst; - memset(&keyInst, 0, sizeof(keyInst)); - keyInst.op = kIROp_FloatLit; - keyInst.typeUse.usedValue = type; - keyInst.value.floatVal = inValue; - return findOrEmitConstant(this, keyInst); - } - - IRStringLit* IRBuilder::getStringValue(const UnownedStringSlice& inSlice) - { - IRConstant keyInst; - memset(&keyInst, 0, sizeof(keyInst)); - - // Mark that this is on the stack... - IRDecoration stackDecoration; - memset(&stackDecoration, 0, sizeof(stackDecoration)); - stackDecoration.op = kIROp_TransitoryDecoration; - stackDecoration.insertAtEnd(&keyInst); - - keyInst.op = kIROp_StringLit; - keyInst.typeUse.usedValue = getStringType(); - - IRConstant::StringSliceValue& dstSlice = keyInst.value.transitoryStringVal; - dstSlice.chars = const_cast(inSlice.begin()); - dstSlice.numChars = uint32_t(inSlice.size()); - - return static_cast(findOrEmitConstant(this, keyInst)); - } - - IRPtrLit* IRBuilder::getPtrValue(void* value) - { - IRType* type = getPtrType(getVoidType()); - - IRConstant keyInst; - memset(&keyInst, 0, sizeof(keyInst)); - keyInst.op = kIROp_PtrLit; - keyInst.typeUse.usedValue = type; - keyInst.value.ptrVal = value; - return (IRPtrLit*) findOrEmitConstant(this, keyInst); - } - - - IRInst* findOrEmitHoistableInst( - IRBuilder* builder, - IRType* type, - IROp op, - UInt operandListCount, - UInt const* listOperandCounts, - IRInst* const* const* listOperands) - { - UInt operandCount = 0; - for (UInt ii = 0; ii < operandListCount; ++ii) - { - operandCount += listOperandCounts[ii]; - } - - auto& memoryArena = builder->getModule()->memoryArena; - void* cursor = memoryArena.getCursor(); - - // We are going to create a 'dummy' instruction on the memoryArena - // which can be used as a key for lookup, so see if we - // already have an equivalent instruction available to use. - size_t keySize = sizeof(IRInst) + operandCount * sizeof(IRUse); - IRInst* inst = (IRInst*) memoryArena.allocateAndZero(keySize); - - void* endCursor = memoryArena.getCursor(); - // Mark as 'unused' cos it is unused on release builds. - SLANG_UNUSED(endCursor); - - new(inst) IRInst(); - inst->op = op; - inst->typeUse.usedValue = type; - inst->operandCount = (uint32_t) operandCount; - - // Don't link up as we may free (if we already have this key) - { - IRUse* operand = inst->getOperands(); - for (UInt ii = 0; ii < operandListCount; ++ii) - { - UInt listOperandCount = listOperandCounts[ii]; - for (UInt jj = 0; jj < listOperandCount; ++jj) - { - operand->usedValue = listOperands[ii][jj]; - operand++; - } - } - } - - // Find or add the key/inst - { - IRInstKey key = { inst }; - - // Ideally we would add if not found, else return if was found instead of testing & then adding. - IRInst** found = builder->sharedBuilder->globalValueNumberingMap.TryGetValueOrAdd(key, inst); - SLANG_ASSERT(endCursor == memoryArena.getCursor()); - // If it's found, just return, and throw away the instruction - if (found) - { - memoryArena.rewindToCursor(cursor); - return *found; - } - } - - // Make the lookup 'inst' instruction into 'proper' instruction. Equivalent to - // IRInst* inst = createInstImpl(builder, op, type, 0, nullptr, operandListCount, listOperandCounts, listOperands); - { - if (type) - { - inst->typeUse.usedValue = nullptr; - inst->typeUse.init(inst, type); - } - - maybeSetSourceLoc(builder, inst); - - IRUse*const operands = inst->getOperands(); - for (UInt i = 0; i < operandCount; ++i) - { - IRUse& operand = operands[i]; - auto value = operand.usedValue; - - operand.usedValue = nullptr; - operand.init(inst, value); - } - } - - addHoistableInst(builder, inst); - - return inst; - } - - IRInst* findOrEmitHoistableInst( - IRBuilder* builder, - IRType* type, - IROp op, - UInt operandCount, - IRInst* const* operands) - { - return findOrEmitHoistableInst( - builder, - type, - op, - 1, - &operandCount, - &operands); - } - - IRInst* findOrEmitHoistableInst( - IRBuilder* builder, - IRType* type, - IROp op, - IRInst* operand, - UInt operandCount, - IRInst* const* operands) - { - UInt counts[] = { 1, operandCount }; - IRInst* const* lists[] = { &operand, operands }; - - return findOrEmitHoistableInst( - builder, - type, - op, - 2, - counts, - lists); - } - - - IRType* IRBuilder::getType( - IROp op, - UInt operandCount, - IRInst* const* operands) - { - return (IRType*) findOrEmitHoistableInst( - this, - nullptr, - op, - operandCount, - operands); - } - - IRType* IRBuilder::getType( - IROp op) - { - return getType(op, 0, nullptr); - } - - IRBasicType* IRBuilder::getBasicType(BaseType baseType) - { - return (IRBasicType*)getType( - IROp((UInt)kIROp_FirstBasicType + (UInt)baseType)); - } - - IRBasicType* IRBuilder::getVoidType() - { - return (IRVoidType*)getType(kIROp_VoidType); - } - - IRBasicType* IRBuilder::getBoolType() - { - return (IRBoolType*)getType(kIROp_BoolType); - } - - IRBasicType* IRBuilder::getIntType() - { - return (IRBasicType*)getType(kIROp_IntType); - } - - IRStringType* IRBuilder::getStringType() - { - return (IRStringType*)getType(kIROp_StringType); - } - - IRBasicBlockType* IRBuilder::getBasicBlockType() - { - return (IRBasicBlockType*)getType(kIROp_BasicBlockType); - } - - IRTypeKind* IRBuilder::getTypeKind() - { - return (IRTypeKind*)getType(kIROp_TypeKind); - } - - IRGenericKind* IRBuilder::getGenericKind() - { - return (IRGenericKind*)getType(kIROp_GenericKind); - } - - IRPtrType* IRBuilder::getPtrType(IRType* valueType) - { - return (IRPtrType*) getPtrType(kIROp_PtrType, valueType); - } - - IROutType* IRBuilder::getOutType(IRType* valueType) - { - return (IROutType*) getPtrType(kIROp_OutType, valueType); - } - - IRInOutType* IRBuilder::getInOutType(IRType* valueType) - { - return (IRInOutType*) getPtrType(kIROp_InOutType, valueType); - } - - IRRefType* IRBuilder::getRefType(IRType* valueType) - { - return (IRRefType*) getPtrType(kIROp_RefType, valueType); - } - - IRPtrTypeBase* IRBuilder::getPtrType(IROp op, IRType* valueType) - { - IRInst* operands[] = { valueType }; - return (IRPtrTypeBase*) getType( - op, - 1, - operands); - } - - IRArrayTypeBase* IRBuilder::getArrayTypeBase( - IROp op, - IRType* elementType, - IRInst* elementCount) - { - IRInst* operands[] = { elementType, elementCount }; - return (IRArrayTypeBase*)getType( - op, - op == kIROp_ArrayType ? 2 : 1, - operands); - } - - IRArrayType* IRBuilder::getArrayType( - IRType* elementType, - IRInst* elementCount) - { - IRInst* operands[] = { elementType, elementCount }; - return (IRArrayType*)getType( - kIROp_ArrayType, - sizeof(operands) / sizeof(operands[0]), - operands); - } - - IRUnsizedArrayType* IRBuilder::getUnsizedArrayType( - IRType* elementType) - { - IRInst* operands[] = { elementType }; - return (IRUnsizedArrayType*)getType( - kIROp_UnsizedArrayType, - sizeof(operands) / sizeof(operands[0]), - operands); - } - - IRVectorType* IRBuilder::getVectorType( - IRType* elementType, - IRInst* elementCount) - { - IRInst* operands[] = { elementType, elementCount }; - return (IRVectorType*)getType( - kIROp_VectorType, - sizeof(operands) / sizeof(operands[0]), - operands); - } - - IRMatrixType* IRBuilder::getMatrixType( - IRType* elementType, - IRInst* rowCount, - IRInst* columnCount) - { - IRInst* operands[] = { elementType, rowCount, columnCount }; - return (IRMatrixType*)getType( - kIROp_MatrixType, - sizeof(operands) / sizeof(operands[0]), - operands); - } - - IRFuncType* IRBuilder::getFuncType( - UInt paramCount, - IRType* const* paramTypes, - IRType* resultType) - { - return (IRFuncType*) findOrEmitHoistableInst( - this, - nullptr, - kIROp_FuncType, - resultType, - paramCount, - (IRInst* const*) paramTypes); - } - - IRConstantBufferType* IRBuilder::getConstantBufferType(IRType* elementType) - { - IRInst* operands[] = { elementType }; - return (IRConstantBufferType*) getType( - kIROp_ConstantBufferType, - 1, - operands); - } - - IRConstExprRate* IRBuilder::getConstExprRate() - { - return (IRConstExprRate*)getType(kIROp_ConstExprRate); - } - - IRGroupSharedRate* IRBuilder::getGroupSharedRate() - { - return (IRGroupSharedRate*)getType(kIROp_GroupSharedRate); - } - - IRRateQualifiedType* IRBuilder::getRateQualifiedType( - IRRate* rate, - IRType* dataType) - { - IRInst* operands[] = { rate, dataType }; - return (IRRateQualifiedType*)getType( - kIROp_RateQualifiedType, - sizeof(operands) / sizeof(operands[0]), - operands); - } - - IRType* IRBuilder::getTaggedUnionType( - UInt caseCount, - IRType* const* caseTypes) - { - return (IRType*) findOrEmitHoistableInst( - this, - getTypeKind(), - kIROp_TaggedUnionType, - caseCount, - (IRInst* const*) caseTypes); - } - - IRType* IRBuilder::getBindExistentialsType( - IRInst* baseType, - UInt slotArgCount, - IRInst* const* slotArgs) - { - if(slotArgCount == 0) - return (IRType*) baseType; - - // If we are trying to bind an interface type, then - // we will go ahead and simplify the instruction - // away impmediately. - // - if(as(baseType)) - { - if(slotArgCount >= 1) - { - // We are being asked to emit `BindExistentials(someInterface, someConcreteType, ...)` - // so we just want to return `ExistentialBox`. - // - auto concreteType = (IRType*) slotArgs[0]; - auto ptrType = getPtrType(kIROp_ExistentialBoxType, concreteType); - return ptrType; - } - } - - return (IRType*) findOrEmitHoistableInst( - this, - getTypeKind(), - kIROp_BindExistentialsType, - baseType, - slotArgCount, - (IRInst* const*) slotArgs); - } - - IRType* IRBuilder::getBindExistentialsType( - IRInst* baseType, - UInt slotArgCount, - IRUse const* slotArgUses) - { - if(slotArgCount == 0) - return (IRType*) baseType; - - List slotArgs; - for( UInt ii = 0; ii < slotArgCount; ++ii ) - { - slotArgs.add(slotArgUses[ii].get()); - } - return getBindExistentialsType( - baseType, - slotArgCount, - slotArgs.getBuffer()); - } - - - - void IRBuilder::setDataType(IRInst* inst, IRType* dataType) - { - if (auto oldRateQualifiedType = as(inst->getFullType())) - { - // Construct a new rate-qualified type using the same rate. - - auto newRateQualifiedType = getRateQualifiedType( - oldRateQualifiedType->getRate(), - dataType); - - inst->setFullType(newRateQualifiedType); - } - else - { - // No rate? Just clobber the data type. - inst->setFullType(dataType); - } - } - - - IRUndefined* IRBuilder::emitUndefined(IRType* type) - { - auto inst = createInst( - this, - kIROp_undefined, - type); - - addInst(inst); - - return inst; - } - - IRInst* IRBuilder::emitExtractExistentialValue( - IRType* type, - IRInst* existentialValue) - { - auto inst = createInst( - this, - kIROp_ExtractExistentialValue, - type, - 1, - &existentialValue); - addInst(inst); - return inst; - } - - IRType* IRBuilder::emitExtractExistentialType( - IRInst* existentialValue) - { - auto type = getTypeKind(); - auto inst = createInst( - this, - kIROp_ExtractExistentialType, - type, - 1, - &existentialValue); - addInst(inst); - return (IRType*) inst; - } - - IRInst* IRBuilder::emitExtractExistentialWitnessTable( - IRInst* existentialValue) - { - auto type = getWitnessTableType(); - auto inst = createInst( - this, - kIROp_ExtractExistentialWitnessTable, - type, - 1, - &existentialValue); - addInst(inst); - return inst; - } - - IRInst* IRBuilder::emitSpecializeInst( - IRType* type, - IRInst* genericVal, - UInt argCount, - IRInst* const* args) - { - auto inst = createInstWithTrailingArgs( - this, - kIROp_Specialize, - type, - 1, - &genericVal, - argCount, - args); - - addInst(inst); - return inst; - } - - IRInst* IRBuilder::emitLookupInterfaceMethodInst( - IRType* type, - IRInst* witnessTableVal, - IRInst* interfaceMethodVal) - { - auto inst = createInst( - this, - kIROp_lookup_interface_method, - type, - witnessTableVal, - interfaceMethodVal); - - addInst(inst); - return inst; - } - - IRInst* IRBuilder::emitCallInst( - IRType* type, - IRInst* pFunc, - UInt argCount, - IRInst* const* args) - { - auto inst = createInstWithTrailingArgs( - this, - kIROp_Call, - type, - 1, - &pFunc, - argCount, - args); - addInst(inst); - return inst; - } - - IRInst* IRBuilder::createIntrinsicInst( - IRType* type, - IROp op, - UInt argCount, - IRInst* const* args) - { - return createInstWithTrailingArgs( - this, - op, - type, - argCount, - args); - } - - - IRInst* IRBuilder::emitIntrinsicInst( - IRType* type, - IROp op, - UInt argCount, - IRInst* const* args) - { - auto inst = createIntrinsicInst( - type, - op, - argCount, - args); - addInst(inst); - return inst; - } - - IRInst* IRBuilder::emitConstructorInst( - IRType* type, - UInt argCount, - IRInst* const* args) - { - auto inst = createInstWithTrailingArgs( - this, - kIROp_Construct, - type, - argCount, - args); - addInst(inst); - return inst; - } - - IRInst* IRBuilder::emitMakeVector( - IRType* type, - UInt argCount, - IRInst* const* args) - { - return emitIntrinsicInst(type, kIROp_makeVector, argCount, args); - } - - IRInst* IRBuilder::emitMakeMatrix( - IRType* type, - UInt argCount, - IRInst* const* args) - { - return emitIntrinsicInst(type, kIROp_MakeMatrix, argCount, args); - } - - IRInst* IRBuilder::emitMakeArray( - IRType* type, - UInt argCount, - IRInst* const* args) - { - return emitIntrinsicInst(type, kIROp_makeArray, argCount, args); - } - - IRInst* IRBuilder::emitMakeStruct( - IRType* type, - UInt argCount, - IRInst* const* args) - { - return emitIntrinsicInst(type, kIROp_makeStruct, argCount, args); - } - - IRInst* IRBuilder::emitMakeExistential( - IRType* type, - IRInst* value, - IRInst* witnessTable) - { - IRInst* args[] = {value, witnessTable}; - return emitIntrinsicInst(type, kIROp_MakeExistential, SLANG_COUNT_OF(args), args); - } - - IRInst* IRBuilder::emitWrapExistential( - IRType* type, - IRInst* value, - UInt slotArgCount, - IRInst* const* slotArgs) - { - if(slotArgCount == 0) - return value; - - // If we are wrapping a single concrete value into - // an interface type, then this is really a `makeExistential` - // - // TODO: We may want to check for a `specialize` of a generic interface as well. - // - if(as(type)) - { - if(slotArgCount >= 2) - { - // We are being asked to emit `wrapExistential(value, concreteType, witnessTable, ...) : someInterface` - // - // We also know that a concrete value being wrapped will always be an existential box, - // so we expect that `value : ExistentialBox` for some `T`. - // - // We want to emit `makeExistential(load(value), witnessTable)`. - // - auto deref = emitLoad(value); - return emitMakeExistential(type, deref, slotArgs[1]); - } - } - - IRInst* fixedArgs[] = {value}; - auto inst = createInstImpl( - this, - kIROp_WrapExistential, - type, - SLANG_COUNT_OF(fixedArgs), - fixedArgs, - slotArgCount, - slotArgs); - addInst(inst); - return inst; - } - - IRModule* IRBuilder::createModule() - { - auto module = new IRModule(); - module->session = getSession(); - - auto moduleInst = createInstImpl( - module, - this, - kIROp_Module, - nullptr, - 0, - nullptr, - 0, - nullptr, - nullptr); - module->moduleInst = moduleInst; - moduleInst->module = module; - - return module; - } - - void addGlobalValue( - IRBuilder* builder, - IRInst* value) - { - // Try to find a suitable parent for the - // global value we are emitting. - // - // We will start out search at the current - // parent instruction for the builder, and - // possibly work our way up. - // - auto parent = builder->insertIntoParent; - while(parent) - { - // Inserting into the top level of a module? - // That is fine, and we can stop searching. - if (as(parent)) - break; - - // Inserting into a basic block inside of - // a generic? That is okay too. - if (auto block = as(parent)) - { - if (as(block->parent)) - break; - } - - // Otherwise, move up the chain. - parent = parent->parent; - } - - // If we somehow ran out of parents (possibly - // because an instruction wasn't linked into - // the full hierarchy yet), then we will - // fall back to inserting into the overall module. - if (!parent) - { - parent = builder->getModule()->getModuleInst(); - } - - // If it turns out that we are inserting into the - // current "insert into" parent for the builder, then - // we need to respect its "insert before" setting - // as well. - if (parent == builder->insertIntoParent - && builder->insertBeforeInst) - { - value->insertBefore(builder->insertBeforeInst); - } - else - { - value->insertAtEnd(parent); - } - } - - IRFunc* IRBuilder::createFunc() - { - IRFunc* rsFunc = createInst( - this, - kIROp_Func, - nullptr); - maybeSetSourceLoc(this, rsFunc); - addGlobalValue(this, rsFunc); - return rsFunc; - } - - IRGlobalVar* IRBuilder::createGlobalVar( - IRType* valueType) - { - auto ptrType = getPtrType(valueType); - IRGlobalVar* globalVar = createInst( - this, - kIROp_GlobalVar, - ptrType); - maybeSetSourceLoc(this, globalVar); - addGlobalValue(this, globalVar); - return globalVar; - } - - IRGlobalConstant* IRBuilder::createGlobalConstant( - IRType* valueType) - { - IRGlobalConstant* globalConstant = createInst( - this, - kIROp_GlobalConstant, - valueType); - maybeSetSourceLoc(this, globalConstant); - addGlobalValue(this, globalConstant); - return globalConstant; - } - - IRGlobalParam* IRBuilder::createGlobalParam( - IRType* valueType) - { - IRGlobalParam* inst = createInst( - this, - kIROp_GlobalParam, - valueType); - maybeSetSourceLoc(this, inst); - addGlobalValue(this, inst); - return inst; - } - - IRWitnessTable* IRBuilder::createWitnessTable() - { - IRWitnessTable* witnessTable = createInst( - this, - kIROp_WitnessTable, - nullptr); - addGlobalValue(this, witnessTable); - return witnessTable; - } - - IRWitnessTableEntry* IRBuilder::createWitnessTableEntry( - IRWitnessTable* witnessTable, - IRInst* requirementKey, - IRInst* satisfyingVal) - { - IRWitnessTableEntry* entry = createInst( - this, - kIROp_WitnessTableEntry, - nullptr, - requirementKey, - satisfyingVal); - - if (witnessTable) - { - entry->insertAtEnd(witnessTable); - } - - return entry; - } - - IRStructType* IRBuilder::createStructType() - { - IRStructType* structType = createInst( - this, - kIROp_StructType, - nullptr); - addGlobalValue(this, structType); - return structType; - } - - IRInterfaceType* IRBuilder::createInterfaceType() - { - IRInterfaceType* interfaceType = createInst( - this, - kIROp_InterfaceType, - nullptr); - addGlobalValue(this, interfaceType); - return interfaceType; - } - - IRStructKey* IRBuilder::createStructKey() - { - IRStructKey* structKey = createInst( - this, - kIROp_StructKey, - nullptr); - addGlobalValue(this, structKey); - return structKey; - } - - // Create a field nested in a struct type, declaring that - // the specified field key maps to a field with the specified type. - IRStructField* IRBuilder::createStructField( - IRStructType* structType, - IRStructKey* fieldKey, - IRType* fieldType) - { - IRInst* operands[] = { fieldKey, fieldType }; - IRStructField* field = (IRStructField*) createInstWithTrailingArgs( - this, - kIROp_StructField, - nullptr, - 0, - nullptr, - 2, - operands); - - if (structType) - { - field->insertAtEnd(structType); - } - - return field; - } - - IRGeneric* IRBuilder::createGeneric() - { - IRGeneric* irGeneric = createInst( - this, - kIROp_Generic, - nullptr); - return irGeneric; - } - - IRGeneric* IRBuilder::emitGeneric() - { - auto irGeneric = createGeneric(); - addGlobalValue(this, irGeneric); - return irGeneric; - } - - IRBlock* IRBuilder::createBlock() - { - return createInst( - this, - kIROp_Block, - getBasicBlockType()); - } - - void IRBuilder::insertBlock(IRBlock* block) - { - // If we are emitting into a function - // (or another value with code), then - // append the block to the function and - // set this block as the new parent for - // subsequent instructions we insert. - // - // TODO: This should probably insert the block - // after the current "insert into" block if - // there is one. Right now we are always - // adding the block to the end of the list, - // which is technically valid (the ordering - // of blocks doesn't affect the CFG topology), - // but some later passes might assume the ordering - // is significant in representing the intent - // of the original code. - // - auto f = getFunc(); - if (f) - { - f->addBlock(block); - setInsertInto(block); - } - } - - IRBlock* IRBuilder::emitBlock() - { - auto block = createBlock(); - insertBlock(block); - return block; - } - - IRParam* IRBuilder::createParam( - IRType* type) - { - auto param = createInst( - this, - kIROp_Param, - type); - return param; - } - - IRParam* IRBuilder::emitParam( - IRType* type) - { - auto param = createParam(type); - if (auto bb = getBlock()) - { - bb->addParam(param); - } - return param; - } - - IRVar* IRBuilder::emitVar( - IRType* type) - { - auto allocatedType = getPtrType(type); - auto inst = createInst( - this, - kIROp_Var, - allocatedType); - addInst(inst); - return inst; - } - - IRInst* IRBuilder::emitLoad( - IRType* type, - IRInst* ptr) - { - auto inst = createInst( - this, - kIROp_Load, - type, - ptr); - - addInst(inst); - return inst; - } - - IRInst* IRBuilder::emitLoad( - IRInst* ptr) - { - // Note: a `load` operation does not consider the rate - // (if any) attached to its operand (see the use of `getDataType` - // below). This means that a load from a rate-qualified - // variable will still conceptually execute (and return - // results) at the "default" rate of the parent function, - // unless a subsequent analysis pass constraints it. - - IRType* valueType = tryGetPointedToType(this, ptr->getFullType()); - SLANG_ASSERT(valueType); - - // Ugly special case: if the front-end created a variable with - // type `Ptr<@R T>` instead of `@R Ptr`, then the above - // logic will yield `@R T` instead of `T`, and we need to - // try and fix that up here. - // - // TODO: Lowering to the IR should be fixed to never create - // that case: rate-qualified types should only be allowed - // to appear as the type of an instruction, and should not - // be allowed as operands to type constructors (except - // in special cases we decide to allow). - // - if(auto rateType = as(valueType)) - { - valueType = rateType->getValueType(); - } - - return emitLoad(valueType, ptr); - } - - IRInst* IRBuilder::emitStore( - IRInst* dstPtr, - IRInst* srcVal) - { - auto inst = createInst( - this, - kIROp_Store, - nullptr, - dstPtr, - srcVal); - - addInst(inst); - return inst; - } - - IRInst* IRBuilder::emitFieldExtract( - IRType* type, - IRInst* base, - IRInst* field) - { - auto inst = createInst( - this, - kIROp_FieldExtract, - type, - base, - field); - - addInst(inst); - return inst; - } - - IRInst* IRBuilder::emitFieldAddress( - IRType* type, - IRInst* base, - IRInst* field) - { - auto inst = createInst( - this, - kIROp_FieldAddress, - type, - base, - field); - - addInst(inst); - return inst; - } - - IRInst* IRBuilder::emitElementExtract( - IRType* type, - IRInst* base, - IRInst* index) - { - auto inst = createInst( - this, - kIROp_getElement, - type, - base, - index); - - addInst(inst); - return inst; - } - - IRInst* IRBuilder::emitElementAddress( - IRType* type, - IRInst* basePtr, - IRInst* index) - { - auto inst = createInst( - this, - kIROp_getElementPtr, - type, - basePtr, - index); - - addInst(inst); - return inst; - } - - IRInst* IRBuilder::emitSwizzle( - IRType* type, - IRInst* base, - UInt elementCount, - IRInst* const* elementIndices) - { - auto inst = createInstWithTrailingArgs( - this, - kIROp_swizzle, - type, - base, - elementCount, - elementIndices); - - addInst(inst); - return inst; - } - - IRInst* IRBuilder::emitSwizzle( - IRType* type, - IRInst* base, - UInt elementCount, - UInt const* elementIndices) - { - auto intType = getBasicType(BaseType::Int); - - IRInst* irElementIndices[4]; - for (UInt ii = 0; ii < elementCount; ++ii) - { - irElementIndices[ii] = getIntValue(intType, elementIndices[ii]); - } - - return emitSwizzle(type, base, elementCount, irElementIndices); - } - - - IRInst* IRBuilder::emitSwizzleSet( - IRType* type, - IRInst* base, - IRInst* source, - UInt elementCount, - IRInst* const* elementIndices) - { - IRInst* fixedArgs[] = { base, source }; - UInt fixedArgCount = sizeof(fixedArgs) / sizeof(fixedArgs[0]); - - auto inst = createInstWithTrailingArgs( - this, - kIROp_swizzleSet, - type, - fixedArgCount, - fixedArgs, - elementCount, - elementIndices); - - addInst(inst); - return inst; - } - - IRInst* IRBuilder::emitSwizzleSet( - IRType* type, - IRInst* base, - IRInst* source, - UInt elementCount, - UInt const* elementIndices) - { - auto intType = getBasicType(BaseType::Int); - - IRInst* irElementIndices[4]; - for (UInt ii = 0; ii < elementCount; ++ii) - { - irElementIndices[ii] = getIntValue(intType, elementIndices[ii]); - } - - return emitSwizzleSet(type, base, source, elementCount, irElementIndices); - } - - IRInst* IRBuilder::emitSwizzledStore( - IRInst* dest, - IRInst* source, - UInt elementCount, - IRInst* const* elementIndices) - { - IRInst* fixedArgs[] = { dest, source }; - UInt fixedArgCount = sizeof(fixedArgs) / sizeof(fixedArgs[0]); - - auto inst = createInstImpl( - this, - kIROp_SwizzledStore, - nullptr, - fixedArgCount, - fixedArgs, - elementCount, - elementIndices); - - addInst(inst); - return inst; - } - - IRInst* IRBuilder::emitSwizzledStore( - IRInst* dest, - IRInst* source, - UInt elementCount, - UInt const* elementIndices) - { - auto intType = getBasicType(BaseType::Int); - - IRInst* irElementIndices[4]; - for (UInt ii = 0; ii < elementCount; ++ii) - { - irElementIndices[ii] = getIntValue(intType, elementIndices[ii]); - } - - return emitSwizzledStore(dest, source, elementCount, irElementIndices); - } - - IRInst* IRBuilder::emitReturn( - IRInst* val) - { - auto inst = createInst( - this, - kIROp_ReturnVal, - nullptr, - val); - addInst(inst); - return inst; - } - - IRInst* IRBuilder::emitReturn() - { - auto inst = createInst( - this, - kIROp_ReturnVoid, - nullptr); - addInst(inst); - return inst; - } - - IRInst* IRBuilder::emitUnreachable() - { - auto inst = createInst( - this, - kIROp_Unreachable, - nullptr); - addInst(inst); - return inst; - } - - IRInst* IRBuilder::emitMissingReturn() - { - auto inst = createInst( - this, - kIROp_MissingReturn, - nullptr); - addInst(inst); - return inst; - } - - IRInst* IRBuilder::emitDiscard() - { - auto inst = createInst( - this, - kIROp_discard, - nullptr); - addInst(inst); - return inst; - } - - - IRInst* IRBuilder::emitBranch( - IRBlock* pBlock) - { - auto inst = createInst( - this, - kIROp_unconditionalBranch, - nullptr, - pBlock); - addInst(inst); - return inst; - } - - IRInst* IRBuilder::emitBreak( - IRBlock* target) - { - return emitBranch(target); - } - - IRInst* IRBuilder::emitContinue( - IRBlock* target) - { - return emitBranch(target); - } - - IRInst* IRBuilder::emitLoop( - IRBlock* target, - IRBlock* breakBlock, - IRBlock* continueBlock) - { - IRInst* args[] = { target, breakBlock, continueBlock }; - UInt argCount = sizeof(args) / sizeof(args[0]); - - auto inst = createInst( - this, - kIROp_loop, - nullptr, - argCount, - args); - addInst(inst); - return inst; - } - - IRInst* IRBuilder::emitBranch( - IRInst* val, - IRBlock* trueBlock, - IRBlock* falseBlock) - { - IRInst* args[] = { val, trueBlock, falseBlock }; - UInt argCount = sizeof(args) / sizeof(args[0]); - - auto inst = createInst( - this, - kIROp_conditionalBranch, - nullptr, - argCount, - args); - addInst(inst); - return inst; - } - - IRInst* IRBuilder::emitIfElse( - IRInst* val, - IRBlock* trueBlock, - IRBlock* falseBlock, - IRBlock* afterBlock) - { - IRInst* args[] = { val, trueBlock, falseBlock, afterBlock }; - UInt argCount = sizeof(args) / sizeof(args[0]); - - auto inst = createInst( - this, - kIROp_ifElse, - nullptr, - argCount, - args); - addInst(inst); - return inst; - } - - IRInst* IRBuilder::emitIf( - IRInst* val, - IRBlock* trueBlock, - IRBlock* afterBlock) - { - return emitIfElse(val, trueBlock, afterBlock, afterBlock); - } - - IRInst* IRBuilder::emitLoopTest( - IRInst* val, - IRBlock* bodyBlock, - IRBlock* breakBlock) - { - return emitIfElse(val, bodyBlock, breakBlock, bodyBlock); - } - - IRInst* IRBuilder::emitSwitch( - IRInst* val, - IRBlock* breakLabel, - IRBlock* defaultLabel, - UInt caseArgCount, - IRInst* const* caseArgs) - { - IRInst* fixedArgs[] = { val, breakLabel, defaultLabel }; - UInt fixedArgCount = sizeof(fixedArgs) / sizeof(fixedArgs[0]); - - auto inst = createInstWithTrailingArgs( - this, - kIROp_Switch, - nullptr, - fixedArgCount, - fixedArgs, - caseArgCount, - caseArgs); - addInst(inst); - return inst; - } - - IRGlobalGenericParam* IRBuilder::emitGlobalGenericParam() - { - IRGlobalGenericParam* irGenericParam = createInst( - this, - kIROp_GlobalGenericParam, - nullptr); - addGlobalValue(this, irGenericParam); - return irGenericParam; - } - - IRBindGlobalGenericParam* IRBuilder::emitBindGlobalGenericParam( - IRInst* param, - IRInst* val) - { - auto inst = createInst( - this, - kIROp_BindGlobalGenericParam, - nullptr, - param, - val); - addInst(inst); - return inst; - } - - IRInst* IRBuilder::emitBindGlobalExistentialSlots( - UInt argCount, - IRInst* const* args) - { - auto inst = createInstWithTrailingArgs( - this, - kIROp_BindGlobalExistentialSlots, - getVoidType(), - 0, - nullptr, - argCount, - args); - addInst(inst); - return inst; - } - - IRDecoration* IRBuilder::addBindExistentialSlotsDecoration( - IRInst* value, - UInt argCount, - IRInst* const* args) - { - auto decoration = createInstWithTrailingArgs( - this, - kIROp_BindExistentialSlotsDecoration, - getVoidType(), - 0, - nullptr, - argCount, - args); - - decoration->insertAtStart(value); - - return decoration; - } - - IRInst* IRBuilder::emitExtractTaggedUnionTag( - IRInst* val) - { - auto inst = createInst( - this, - kIROp_ExtractTaggedUnionTag, - getBasicType(BaseType::UInt), - val); - addInst(inst); - return inst; - } - - IRInst* IRBuilder::emitExtractTaggedUnionPayload( - IRType* type, - IRInst* val, - IRInst* tag) - { - auto inst = createInst( - this, - kIROp_ExtractTaggedUnionPayload, - type, - val, - tag); - addInst(inst); - return inst; - } - - IRInst* IRBuilder::emitBitCast( - IRType* type, - IRInst* val) - { - auto inst = createInst( - this, - kIROp_BitCast, - type, - val); - addInst(inst); - return inst; - } - - // - // Decorations - // - - IRDecoration* IRBuilder::addDecoration(IRInst* value, IROp op, IRInst* const* operands, Int operandCount) - { - auto decoration = createInstWithTrailingArgs( - this, - op, - getVoidType(), - operandCount, - operands); - - // Decoration order should not, in general, be semantically - // meaningful, so we will elect to insert a new decoration - // at the start of an instruction (constant time) rather - // than at the end of any existing list of deocrations - // (which would take time linear in the number of decorations). - // - // TODO: revisit this if maintaining decoration ordering - // from input source code is desirable. - // - decoration->insertAtStart(value); - - return decoration; - } - - - void IRBuilder::addHighLevelDeclDecoration(IRInst* inst, Decl* decl) - { - auto ptrConst = getPtrValue(addRefObjectToFree(decl)); - addDecoration(inst, kIROp_HighLevelDeclDecoration, ptrConst); - } - - void IRBuilder::addLayoutDecoration(IRInst* inst, Layout* layout) - { - auto ptrConst = getPtrValue(addRefObjectToFree(layout)); - addDecoration(inst, kIROp_LayoutDecoration, ptrConst); - } - - // - - - struct IRDumpContext - { - StringBuilder* builder = nullptr; - int indent = 0; - IRDumpMode mode = IRDumpMode::Simplified; - - Dictionary mapValueToName; - Dictionary uniqueNameCounters; - UInt uniqueIDCounter = 1; - }; - - static void dump( - IRDumpContext* context, - char const* text) - { - context->builder->append(text); - } - - static void dump( - IRDumpContext* context, - String const& text) - { - context->builder->append(text); - } - - /* - static void dump( - IRDumpContext* context, - UInt val) - { - context->builder->append(val); - } - */ - - static void dump( - IRDumpContext* context, - IntegerLiteralValue val) - { - context->builder->append(val); - } - - static void dump( - IRDumpContext* context, - FloatingPointLiteralValue val) - { - context->builder->append(val); - } - - static void dumpIndent( - IRDumpContext* context) - { - for (int ii = 0; ii < context->indent; ++ii) - { - dump(context, "\t"); - } - } - - bool opHasResult(IRInst* inst) - { - auto type = inst->getDataType(); - if (!type) return false; - - // As a bit of a hack right now, we need to check whether - // the function returns the distinguished `Void` type, - // since that is conceptually the same as "not returning - // a value." - if(type->op == kIROp_VoidType) - return false; - - return true; - } - - bool instHasUses(IRInst* inst) - { - return inst->firstUse != nullptr; - } - - static void scrubName( - String const& name, - StringBuilder& sb) - { - // Note: this function duplicates a lot of the logic - // in `EmitVisitor::scrubName`, so we should consider - // whether they can share code at some point. - // - // There is no requirement that assembly dumps and output - // code follow the same model, though, so this is just - // a nice-to-have rather than a maintenance problem - // waiting to happen. - - // Allow an empty nam - // Special case a name that is the empty string, just in case. - if(name.getLength() == 0) - { - sb.append('_'); - } - - int prevChar = -1; - for(auto c : name) - { - if(c == '.') - { - c = '_'; - } - - if(((c >= 'a') && (c <= 'z')) - || ((c >= 'A') && (c <= 'Z'))) - { - // Ordinary ASCII alphabetic characters are assumed - // to always be okay. - } - else if((c >= '0') && (c <= '9')) - { - // We don't want to allow a digit as the first - // byte in a name. - if(prevChar == -1) - { - sb.append('_'); - } - } - else - { - // If we run into a character that wouldn't normally - // be allowed in an identifier, we need to translate - // it into something that *is* valid. - // - // Our solution for now will be very clumsy: we will - // emit `x` and then the hexadecimal version of - // the byte we were given. - sb.append("x"); - sb.append(uint32_t((unsigned char) c), 16); - - // We don't want to apply the default handling below, - // so skip to the top of the loop now. - prevChar = c; - continue; - } - - sb.append(c); - prevChar = c; - } - - // If the whole thing ended with a digit, then add - // a final `_` just to make sure that we can append - // a unique ID suffix without risk of collisions. - if(('0' <= prevChar) && (prevChar <= '9')) - { - sb.append('_'); - } - } - - static String createName( - IRDumpContext* context, - IRInst* value) - { - if(auto nameHintDecoration = value->findDecoration()) - { - String nameHint = nameHintDecoration->getName(); - - StringBuilder sb; - scrubName(nameHint, sb); - - String key = sb.ProduceString(); - UInt count = 0; - context->uniqueNameCounters.TryGetValue(key, count); - - context->uniqueNameCounters[key] = count+1; - - if(count) - { - sb.append(count); - } - return sb.ProduceString(); - } - else - { - StringBuilder sb; - auto id = context->uniqueIDCounter++; - sb.append(id); - return sb.ProduceString(); - } - } - - static String getName( - IRDumpContext* context, - IRInst* value) - { - String name; - if (context->mapValueToName.TryGetValue(value, name)) - return name; - - name = createName(context, value); - context->mapValueToName.Add(value, name); - return name; - } - - static void dumpID( - IRDumpContext* context, - IRInst* inst) - { - if (!inst) - { - dump(context, ""); - return; - } - - if( opHasResult(inst) || instHasUses(inst) ) - { - dump(context, "%"); - dump(context, getName(context, inst)); - } - else - { - dump(context, "_"); - } - } - - - - struct StringEncoder - { - static char getHexChar(int v) - { - return (v <= 9) ? char(v + '0') : char(v - 10 + 'A'); - } - - void flush(const char* pos) - { - if (pos > m_runStart) - { - m_builder->append(m_runStart, pos); - } - m_runStart = pos + 1; - } - - void appendEscapedChar(const char* pos, char encodeChar) - { - flush(pos); - const char chars[] = { '\\', encodeChar }; - m_builder->Append(chars, 2); - } - - void appendAsHex(const char* pos) - { - flush(pos); - - const int v = *(const uint8_t*)pos; - - char buf[5]; - buf[0] = '\\'; - buf[1] = 'x'; - buf[2] = '0'; - - buf[3] = getHexChar(v >> 4); - buf[4] = getHexChar(v & 0xf); - - m_builder->Append(buf, 5); - } - - StringEncoder(StringBuilder* builder, const char* start): - m_runStart(start), - m_builder(builder) - {} - - StringBuilder* m_builder; - const char* m_runStart; - }; - - static void dumpEncodeString( - IRDumpContext* context, - const UnownedStringSlice& slice) - { - // https://msdn.microsoft.com/en-us/library/69ze775t.aspx - - StringBuilder& builder = *context->builder; - builder.Append('"'); - - { - const char* cur = slice.begin(); - StringEncoder encoder(&builder, cur); - const char* end = slice.end(); - - for (; cur < end; cur++) - { - const int8_t c = uint8_t(*cur); - switch (c) - { - case '\\': - encoder.appendEscapedChar(cur, '\\'); - break; - case '"': - encoder.appendEscapedChar(cur, '"'); - break; - case '\n': - encoder.appendEscapedChar(cur, 'n'); - break; - case '\t': - encoder.appendEscapedChar(cur, 't'); - break; - case '\r': - encoder.appendEscapedChar(cur, 'r'); - break; - case '\0': - encoder.appendEscapedChar(cur, '0'); - break; - default: - { - if (c < 32) - { - encoder.appendAsHex(cur); - } - break; - } - } - } - encoder.flush(end); - } - - builder.Append('"'); - } - - static void dumpType( - IRDumpContext* context, - IRType* type); - - static bool shouldFoldInstIntoUses( - IRDumpContext* context, - IRInst* inst) - { - // Never fold an instruction into its use site - // in the "detailed" mode, so that we always - // accurately reflect the structure of the IR. - // - if(context->mode == IRDumpMode::Detailed) - return false; - - if(as(inst)) - return true; - - // We are going to have a general rule that - // a type should be folded into its use site, - // which improves output in most cases, but - // we would like to not apply that rule to - // "nominal" types like `struct`s. - // - switch( inst->op ) - { - case kIROp_StructType: - case kIROp_InterfaceType: - return false; - - default: - break; - } - - if(as(inst)) - return true; - - return false; - } - - static void dumpInst( - IRDumpContext* context, - IRInst* inst); - - static void dumpInstBody( - IRDumpContext* context, - IRInst* inst); - - static void dumpInstExpr( - IRDumpContext* context, - IRInst* inst); - - static void dumpOperand( - IRDumpContext* context, - IRInst* inst) - { - // TODO: we should have a dedicated value for the `undef` case - if (!inst) - { - dumpID(context, inst); - return; - } - - if(shouldFoldInstIntoUses(context, inst)) - { - dumpInstExpr(context, inst); - return; - } - - dumpID(context, inst); - } - - static void dumpType( - IRDumpContext* context, - IRType* type) - { - if (!type) - { - dump(context, "_"); - return; - } - - // TODO: we should consider some special-case printing - // for types, so that the IR doesn't get too hard to read - // (always having to back-reference for what a type expands to) - dumpOperand(context, type); - } - - static void dumpInstTypeClause( - IRDumpContext* context, - IRType* type) - { - dump(context, "\t: "); - dumpType(context, type); - - } - - void dumpIRDecorations( - IRDumpContext* context, - IRInst* inst) - { - for(auto dd : inst->getDecorations()) - { - // Certain decorations aren't helpful to appear - // in output dumps, so we will only include them - // in the "detailed" dumping mode. - // - // For all other modes, we will check the opcode - // and skip selected decorations. - // - if(context->mode != IRDumpMode::Detailed) - { - switch(dd->op) - { - default: - break; - - case kIROp_HighLevelDeclDecoration: - case kIROp_LayoutDecoration: - continue; - } - } - - dump(context, "["); - dumpInstBody(context, dd); - dump(context, "]\n"); - - dumpIndent(context); - } - } - - static void dumpBlock( - IRDumpContext* context, - IRBlock* block) - { - context->indent--; - dump(context, "block "); - dumpID(context, block); - - IRInst* inst = block->getFirstInst(); - - // First walk through any `param` instructions, - // so that we can format them nicely - if (auto firstParam = as(inst)) - { - dump(context, "(\n"); - context->indent += 2; - - for(;;) - { - auto param = as(inst); - if (!param) - break; - - if (param != firstParam) - dump(context, ",\n"); - - inst = inst->getNextInst(); - - dumpIndent(context); - dumpIRDecorations(context, param); - dump(context, "param "); - dumpID(context, param); - dumpInstTypeClause(context, param->getFullType()); - } - context->indent -= 2; - dump(context, ")"); - } - dump(context, ":\n"); - context->indent++; - - for(; inst; inst = inst->getNextInst()) - { - dumpInst(context, inst); - } - } - - void dumpIRGlobalValueWithCode( - IRDumpContext* context, - IRGlobalValueWithCode* code) - { - auto opInfo = getIROpInfo(code->op); - - dumpIndent(context); - dump(context, opInfo.name); - dump(context, " "); - dumpID(context, code); - - dumpInstTypeClause(context, code->getFullType()); - - if (!code->getFirstBlock()) - { - // Just a declaration. - dump(context, ";\n"); - return; - } - - dump(context, "\n"); - - dumpIndent(context); - dump(context, "{\n"); - context->indent++; - - for (auto bb = code->getFirstBlock(); bb; bb = bb->getNextBlock()) - { - if (bb != code->getFirstBlock()) - dump(context, "\n"); - dumpBlock(context, bb); - } - - context->indent--; - dump(context, "}"); - } - - - void dumpIRWitnessTableEntry( - IRDumpContext* context, - IRWitnessTableEntry* entry) - { - dump(context, "witness_table_entry("); - dumpOperand(context, entry->requirementKey.get()); - dump(context, ","); - dumpOperand(context, entry->satisfyingVal.get()); - dump(context, ")\n"); - } - - void dumpIRParentInst( - IRDumpContext* context, - IRInst* inst) - { - auto opInfo = getIROpInfo(inst->op); - - dumpIndent(context); - dump(context, opInfo.name); - dump(context, " "); - dumpID(context, inst); - - dumpInstTypeClause(context, inst->getFullType()); - - if (!inst->getFirstChild()) - { - // Empty. - dump(context, ";\n"); - return; - } - - dump(context, "\n"); - - dumpIndent(context); - dump(context, "{\n"); - context->indent++; - - for(auto child : inst->getChildren()) - { - dumpInst(context, child); - } - - context->indent--; - dump(context, "}\n"); - } - - void dumpIRGeneric( - IRDumpContext* context, - IRGeneric* witnessTable) - { - dump(context, "\n"); - dumpIndent(context); - dump(context, "ir_witness_table "); - dumpID(context, witnessTable); - dump(context, "\n{\n"); - context->indent++; - - for (auto ii : witnessTable->getChildren()) - { - dumpInst(context, ii); - } - - context->indent--; - dump(context, "}\n"); - } - - static void dumpInstExpr( - IRDumpContext* context, - IRInst* inst) - { - if (!inst) - { - dump(context, ""); - return; - } - - auto op = inst->op; - auto opInfo = getIROpInfo(op); - - // Special-case the literal instructions. - if(auto irConst = as(inst)) - { - switch (op) - { - case kIROp_IntLit: - dump(context, irConst->value.intVal); - return; - - case kIROp_FloatLit: - dump(context, irConst->value.floatVal); - return; - - case kIROp_BoolLit: - dump(context, irConst->value.intVal ? "true" : "false"); - return; - - case kIROp_StringLit: - dumpEncodeString(context, irConst->getStringSlice()); - return; - - case kIROp_PtrLit: - dump(context, ""); - return; - - default: - break; - } - } - - dump(context, opInfo.name); - - UInt argCount = inst->getOperandCount(); - - if(argCount == 0) - return; - - UInt ii = 0; - - // Special case: make printing of `call` a bit - // nicer to look at - if (inst->op == kIROp_Call && argCount > 0) - { - dump(context, " "); - auto argVal = inst->getOperand(ii++); - dumpOperand(context, argVal); - } - - bool first = true; - dump(context, "("); - for (; ii < argCount; ++ii) - { - if (!first) - dump(context, ", "); - - auto argVal = inst->getOperand(ii); - - dumpOperand(context, argVal); - - first = false; - } - - dump(context, ")"); - - } - - static void dumpInstBody( - IRDumpContext* context, - IRInst* inst) - { - if (!inst) - { - dump(context, ""); - return; - } - - auto op = inst->op; - - dumpIRDecorations(context, inst); - - // There are several ops we want to special-case here, - // so that they will be more pleasant to look at. - // - switch (op) - { - case kIROp_Func: - case kIROp_GlobalVar: - case kIROp_GlobalConstant: - case kIROp_Generic: - dumpIRGlobalValueWithCode(context, (IRGlobalValueWithCode*)inst); - return; - - case kIROp_WitnessTable: - case kIROp_StructType: - dumpIRParentInst(context, inst); - return; - - case kIROp_WitnessTableEntry: - dumpIRWitnessTableEntry(context, (IRWitnessTableEntry*)inst); - return; - - default: - break; - } - - // Okay, we have a seemingly "ordinary" op now - auto dataType = inst->getDataType(); - auto rate = inst->getRate(); - - if(rate) - { - dump(context, "@"); - dumpOperand(context, rate); - dump(context, " "); - } - - if(opHasResult(inst) || instHasUses(inst)) - { - dump(context, "let "); - dumpID(context, inst); - dumpInstTypeClause(context, dataType); - dump(context, "\t= "); - } - else - { - // No result, okay... - } - - dumpInstExpr(context, inst); - } - - static void dumpInst( - IRDumpContext* context, - IRInst* inst) - { - if(shouldFoldInstIntoUses(context, inst)) - return; - - dumpIndent(context); - dumpInstBody(context, inst); - dump(context, "\n"); - } - - void dumpIRModule( - IRDumpContext* context, - IRModule* module) - { - for(auto ii : module->getGlobalInsts()) - { - dumpInst(context, ii); - } - } - - void printSlangIRAssembly(StringBuilder& builder, IRModule* module, IRDumpMode mode) - { - IRDumpContext context; - context.builder = &builder; - context.indent = 0; - context.mode = mode; - - dumpIRModule(&context, module); - } - - void dumpIR(IRInst* globalVal, ISlangWriter* writer, IRDumpMode mode) - { - StringBuilder sb; - - IRDumpContext context; - context.builder = &sb; - context.indent = 0; - context.mode = mode; - - dumpInst(&context, globalVal); - - writer->write(sb.getBuffer(), sb.getLength()); - writer->flush(); - } - - String getSlangIRAssembly(IRModule* module, IRDumpMode mode) - { - StringBuilder sb; - printSlangIRAssembly(sb, module, mode); - return sb; - } - - void dumpIR(IRModule* module, ISlangWriter* writer, IRDumpMode mode) - { - String ir = getSlangIRAssembly(module, mode); - writer->write(ir.getBuffer(), ir.getLength()); - writer->flush(); - } - - // Pre-declare - static bool _isTypeOperandEqual(IRInst* a, IRInst* b); - - static bool _areTypeOperandsEqual(IRInst* a, IRInst* b) - { - // Must have same number of operands - const auto operandCountA = Index(a->getOperandCount()); - if (operandCountA != Index(b->getOperandCount())) - { - return false; - } - - // All the operands must be equal - for (Index i = 0; i < operandCountA; ++i) - { - IRInst* operandA = a->getOperand(i); - IRInst* operandB = b->getOperand(i); - - if (!_isTypeOperandEqual(operandA, operandB)) - { - return false; - } - } - - return true; - } - - static bool _isNominalOp(IROp op) - { - // True if the op identity is 'nominal' - switch (op) - { - case kIROp_StructType: - case kIROp_InterfaceType: - case kIROp_Generic: - case kIROp_Param: - { - return true; - } - } - return false; - } - - // True if a type operand is equal. Operands are 'IRInst' - but it's only a restricted set that - // can be operands of IRType instructions - static bool _isTypeOperandEqual(IRInst* a, IRInst* b) - { - if (a == b) - { - return true; - } - - if (a == nullptr || b == nullptr) - { - return false; - } - - const IROp opA = IROp(a->op & kIROpMeta_PseudoOpMask); - const IROp opB = IROp(b->op & kIROpMeta_PseudoOpMask); - - if (opA != opB) - { - return false; - } - - // If the type is nominal - it can only be the same if the pointer is the same. - if (_isNominalOp(opA)) - { - // The pointer isn't the same (as that was already tested), so cannot be equal - return false; - } - - // Both are types - if (IRType::isaImpl(opA)) - { - if (IRBasicType::isaImpl(opA)) - { - // If it's a basic type, then their op being the same means we are done - return true; - } - - // We don't care about the parent or positioning - // We also don't care about 'type' - because these instructions are defining the type. - // - // We may want to care about decorations. - - // If it's a resource type - special case the handling of the resource flavor - if (IRResourceTypeBase::isaImpl(opA) && - static_cast(a)->getFlavor() != static_cast(b)->getFlavor()) - { - return false; - } - - // TODO(JS): There is a question here about what to do about decorations. - // For now we ignore decorations. Are two types potentially different if there decorations different? - // If decorations play a part in difference in types - the order of decorations presumably is not important. - - // All the operands of the types must be equal - return _areTypeOperandsEqual(a, b); - } - - // If it's a constant... - if (IRConstant::isaImpl(opA)) - { - // TODO: This is contrived in that we want two types that are the same, but are different - // pointers to match here. - // If we make GetHashCode for IRType* compatible with isTypeEqual, then we should probably use that. - return static_cast(a)->isValueEqual(static_cast(b)) && - isTypeEqual(a->getFullType(), b->getFullType()); - } - - SLANG_ASSERT(!"Unhandled comparison"); - - // We can't equate any other type.. - return false; - } - - bool isTypeEqual(IRType* a, IRType* b) - { - // _isTypeOperandEqual handles comparison of types so can defer to it - return _isTypeOperandEqual(a, b); - } - - void findAllInstsBreadthFirst(IRInst* inst, List& outInsts) - { - Index index = outInsts.getCount(); - - outInsts.add(inst); - - while (index < outInsts.getCount()) - { - IRInst* cur = outInsts[index++]; - - IRInstListBase childrenList = cur->getDecorationsAndChildren(); - for (IRInst* child : childrenList) - { - outInsts.add(child); - } - } - } - - IRDecoration* IRInst::getFirstDecoration() - { - return as(getFirstDecorationOrChild()); - } - - IRDecoration* IRInst::getLastDecoration() - { - IRDecoration* decoration = getFirstDecoration(); - if (!decoration) return nullptr; - - while (auto nextDecoration = decoration->getNextDecoration()) - decoration = nextDecoration; - - return decoration; - } - - IRInstList IRInst::getDecorations() - { - return IRInstList( - getFirstDecoration(), - getLastDecoration()); - } - - IRInst* IRInst::getFirstChild() - { - // The children come after any decorations, - // so if there are any decorations, then the - // first child is right after the last decoration. - // - if(auto lastDecoration = getLastDecoration()) - return lastDecoration->getNextInst(); - // - // Otherwise, there must be no decorations, so - // that the first "child or decoration" is a child. - // - return getFirstDecorationOrChild(); - } - - IRInst* IRInst::getLastChild() - { - // The children come after any decorations, so - // that the last item in the list of children - // and decorations is the last child *unless* - // it is a decoration, in which case there are - // no children. - // - auto lastChild = getLastDecorationOrChild(); - return as(lastChild) ? nullptr : lastChild; - } - - - IRRate* IRInst::getRate() - { - if(auto rateQualifiedType = as(getFullType())) - return rateQualifiedType->getRate(); - - return nullptr; - } - - IRType* IRInst::getDataType() - { - auto type = getFullType(); - if(auto rateQualifiedType = as(type)) - return rateQualifiedType->getValueType(); - - return type; - } - - void IRInst::replaceUsesWith(IRInst* other) - { - // Safety check: don't try to replace something with itself. - if(other == this) - return; - - // We will walk through the list of uses for the current - // instruction, and make them point to the other inst. - IRUse* ff = firstUse; - - // No uses? Nothing to do. - if(!ff) - return; - - ff->debugValidate(); - - IRUse* uu = ff; - for(;;) - { - // The uses had better all be uses of this - // instruction, or invariants are broken. - SLANG_ASSERT(uu->get() == this); - - // Swap this use over to use the other value. - uu->usedValue = other; - - // Try to move to the next use, but bail - // out if we are at the last one. - IRUse* nn = uu->nextUse; - if( !nn ) - break; - - uu = nn; - } - - // We are at the last use (and there must - // be at least one, because we handled - // the case of an empty list earlier). - SLANG_ASSERT(uu); - - // Our job at this point is to splice - // our list of uses onto the other - // value's uses. - // - // If the value already had uses, then - // we need to patch our new list onto - // the front. - if( auto nn = other->firstUse ) - { - uu->nextUse = nn; - nn->prevLink = &uu->nextUse; - } - - // No matter what, our list of - // uses will become the start - // of the list of uses for - // `other` - other->firstUse = ff; - ff->prevLink = &other->firstUse; - - // And `this` will have no uses any more. - this->firstUse = nullptr; - - ff->debugValidate(); - } - - // Insert this instruction into the same basic block - // as `other`, right before it. - void IRInst::insertBefore(IRInst* other) - { - SLANG_ASSERT(other); - _insertAt(other->getPrevInst(), other, other->getParent()); - } - - void IRInst::insertAtStart(IRInst* newParent) - { - SLANG_ASSERT(newParent); - _insertAt(nullptr, newParent->getFirstDecorationOrChild(), newParent); - } - - void IRInst::moveToStart() - { - auto p = parent; - removeFromParent(); - insertAtStart(p); - } - - void IRInst::_insertAt(IRInst* inPrev, IRInst* inNext, IRInst* inParent) - { - // Make sure this instruction has been removed from any previous parent - this->removeFromParent(); - - SLANG_ASSERT(inParent); - SLANG_ASSERT(!inPrev || (inPrev->getNextInst() == inNext) && (inPrev->getParent() == inParent)); - SLANG_ASSERT(!inNext || (inNext->getPrevInst() == inPrev) && (inNext->getParent() == inParent)); - - if( inPrev ) - { - inPrev->next = this; - } - else - { - inParent->m_decorationsAndChildren.first = this; - } - - if (inNext) - { - inNext->prev = this; - } - else - { - inParent->m_decorationsAndChildren.last = this; - } - - this->prev = inPrev; - this->next = inNext; - this->parent = inParent; - } - - void IRInst::insertAfter(IRInst* other) - { - SLANG_ASSERT(other); - removeFromParent(); - _insertAt(other, other->getNextInst(), other->getParent()); - } - - void IRInst::insertAtEnd(IRInst* newParent) - { - SLANG_ASSERT(newParent); - removeFromParent(); - _insertAt(newParent->getLastDecorationOrChild(), nullptr, newParent); - } - - void IRInst::moveToEnd() - { - auto p = parent; - removeFromParent(); - insertAtEnd(p); - } - - // Remove this instruction from its parent block, - // and then destroy it (it had better have no uses!) - void IRInst::removeFromParent() - { - auto oldParent = getParent(); - - // If we don't currently have a parent, then - // we are doing fine. - if(!oldParent) - return; - - auto pp = getPrevInst(); - auto nn = getNextInst(); - - if(pp) - { - SLANG_ASSERT(pp->getParent() == oldParent); - pp->next = nn; - } - else - { - oldParent->m_decorationsAndChildren.first = nn; - } - - if(nn) - { - SLANG_ASSERT(nn->getParent() == oldParent); - nn->prev = pp; - } - else - { - oldParent->m_decorationsAndChildren.last = pp; - } - - prev = nullptr; - next = nullptr; - parent = nullptr; - } - - void IRInst::removeArguments() - { - typeUse.clear(); - for( UInt aa = 0; aa < operandCount; ++aa ) - { - IRUse& use = getOperands()[aa]; - use.clear(); - } - } - - // Remove this instruction from its parent block, - // and then destroy it (it had better have no uses!) - void IRInst::removeAndDeallocate() - { - removeFromParent(); - removeArguments(); - removeAndDeallocateAllDecorationsAndChildren(); - - // Run destructor to be sure... - this->~IRInst(); - } - - void IRInst::removeAndDeallocateAllDecorationsAndChildren() - { - IRInst* nextChild = nullptr; - for( IRInst* child = getFirstDecorationOrChild(); child; child = nextChild ) - { - nextChild = child->getNextInst(); - child->removeAndDeallocate(); - } - } - - void IRInst::transferDecorationsTo(IRInst* target) - { - while( auto decoration = getFirstDecoration() ) - { - decoration->removeFromParent(); - decoration->insertAtStart(target); - } - } - - bool IRInst::mightHaveSideEffects() - { - // TODO: We should drive this based on flags specified - // in `ir-inst-defs.h` isntead of hard-coding things here, - // but this is good enough for now if we are conservative: - - if(as(this)) - return false; - - if(as(this)) - return false; - - switch(op) - { - // By default, assume that we might have side effects, - // to safely cover all the instructions we haven't had time to think about. - default: - return true; - - case kIROp_Call: - { - // In the general case, a function call must be assumed to - // have almost arbitrary side effects. - // - // However, it is possible that the callee can be identified, - // and it may be a function with an attribute that explicitly - // limits the side effects it is allowed to have. - // - // For now, we will explicitly check for the `[__readNone]` - // attribute, which was used to mark functions that compute - // their result strictly as a function of the arguments (and - // not anything they point to, or other non-argument state). - // Calls to such functions cannot have side effects (except - // for things like stack overflow that abstract language models - // tend to ignore), and can be subject to dead code elimination, - // common subexpression elimination, etc. - // - auto call = cast(this); - auto callee = getResolvedInstForDecorations(call->getCallee()); - if(callee->findDecoration()) - { - return false; - } - } - return true; - - // All of the cases for "global values" are side-effect-free. - case kIROp_StructType: - case kIROp_StructField: - case kIROp_Func: - case kIROp_Generic: - case kIROp_GlobalVar: - case kIROp_GlobalConstant: - case kIROp_GlobalParam: - case kIROp_StructKey: - case kIROp_GlobalGenericParam: - case kIROp_WitnessTable: - case kIROp_WitnessTableEntry: - case kIROp_Block: - return false; - - case kIROp_Nop: - case kIROp_Specialize: - case kIROp_lookup_interface_method: - case kIROp_Construct: - case kIROp_makeVector: - case kIROp_MakeMatrix: - case kIROp_makeArray: - case kIROp_makeStruct: - case kIROp_Load: // We are ignoring the possibility of loads from bad addresses, or `volatile` loads - case kIROp_FieldExtract: - case kIROp_FieldAddress: - case kIROp_getElement: - case kIROp_getElementPtr: - case kIROp_constructVectorFromScalar: - case kIROp_swizzle: - case kIROp_swizzleSet: // Doesn't actually "set" anything - just returns the resulting vector - case kIROp_Add: - case kIROp_Sub: - case kIROp_Mul: - //case kIROp_Div: // TODO: We could split out integer vs. floating-point div/mod and assume the floating-point cases have no side effects - //case kIROp_Mod: - case kIROp_Lsh: - case kIROp_Rsh: - case kIROp_Eql: - case kIROp_Neq: - case kIROp_Greater: - case kIROp_Less: - case kIROp_Geq: - case kIROp_Leq: - case kIROp_BitAnd: - case kIROp_BitXor: - case kIROp_BitOr: - case kIROp_And: - case kIROp_Or: - case kIROp_Neg: - case kIROp_Not: - case kIROp_BitNot: - case kIROp_Select: - case kIROp_Dot: - case kIROp_Mul_Vector_Matrix: - case kIROp_Mul_Matrix_Vector: - case kIROp_Mul_Matrix_Matrix: - case kIROp_MakeExistential: - case kIROp_ExtractExistentialType: - case kIROp_ExtractExistentialValue: - case kIROp_ExtractExistentialWitnessTable: - case kIROp_WrapExistential: - return false; - } - } - - IRModule* IRInst::getModule() - { - IRInst* ii = this; - while(ii) - { - if(auto moduleInst = as(ii)) - return moduleInst->module; - - ii = ii->getParent(); - } - return nullptr; - } - - // - // IRType - // - - IRType* unwrapArray(IRType* type) - { - IRType* t = type; - while( auto arrayType = as(t) ) - { - t = arrayType->getElementType(); - } - return t; - } - - IRTargetIntrinsicDecoration* findTargetIntrinsicDecoration( - IRInst* val, - String const& targetName) - { - for(auto dd : val->getDecorations()) - { - if(dd->op != kIROp_TargetIntrinsicDecoration) - continue; - - auto decoration = (IRTargetIntrinsicDecoration*) dd; - if(String(decoration->getTargetName()) == targetName) - return decoration; - } - - return nullptr; - } - -#if 0 - IRFunc* cloneSimpleFuncWithoutRegistering(IRSpecContextBase* context, IRFunc* originalFunc) - { - auto clonedFunc = context->builder->createFunc(); - cloneFunctionCommon(context, clonedFunc, originalFunc, false); - return clonedFunc; - } -#endif - - IRInst* findGenericReturnVal(IRGeneric* generic) - { - auto lastBlock = generic->getLastBlock(); - if (!lastBlock) - return nullptr; - - auto returnInst = as(lastBlock->getTerminator()); - if (!returnInst) - return nullptr; - - auto val = returnInst->getVal(); - return val; - } - - IRInst* getResolvedInstForDecorations(IRInst* inst) - { - IRInst* candidate = inst; - while(auto specInst = as(candidate)) - { - auto genericInst = as(specInst->getBase()); - if(!genericInst) - break; - - auto returnVal = findGenericReturnVal(genericInst); - if(!returnVal) - break; - - candidate = returnVal; - } - return candidate; - } - - bool isDefinition( - IRInst* inVal) - { - IRInst* val = inVal; - // unwrap any generic declarations to see - // the value they return. - for(;;) - { - auto genericInst = as(val); - if(!genericInst) - break; - - auto returnVal = findGenericReturnVal(genericInst); - if(!returnVal) - break; - - val = returnVal; - } - - // TODO: the logic here should probably - // be that anything with an `IRImportDecoration` - // is considered to be a declaration rather than definition. - - switch (val->op) - { - case kIROp_WitnessTable: - case kIROp_GlobalConstant: - case kIROp_Func: - case kIROp_Generic: - return val->getFirstChild() != nullptr; - - case kIROp_StructType: - case kIROp_GlobalVar: - case kIROp_GlobalParam: - return true; - - default: - return false; - } - } - - void markConstExpr( - IRBuilder* builder, - IRInst* irValue) - { - // We will take an IR value with type `T`, - // and turn it into one with type `@ConstExpr T`. - - // TODO: need to be careful if the value already has a rate - // qualifier set. - - irValue->setFullType( - builder->getRateQualifiedType( - builder->getConstExprRate(), - irValue->getDataType())); - } -} diff --git a/source/slang/ir.h b/source/slang/ir.h deleted file mode 100644 index e7f9dff75..000000000 --- a/source/slang/ir.h +++ /dev/null @@ -1,1202 +0,0 @@ -// ir.h -#ifndef SLANG_IR_H_INCLUDED -#define SLANG_IR_H_INCLUDED - -// This file defines the intermediate representation (IR) used for Slang -// shader code. This is a typed static single assignment (SSA) IR, -// similar in spirit to LLVM (but much simpler). -// - -#include "../core/basic.h" - -#include "source-loc.h" - -#include "../core/slang-memory-arena.h" -#include "../core/slang-object-scope-manager.h" - -#include "type-system-shared.h" - -namespace Slang { - -class Decl; -class GenericDecl; -class FuncType; -class Layout; -class Type; -class Session; -class Name; -struct IRBuilder; -struct IRFunc; -struct IRGlobalValueWithCode; -struct IRInst; -struct IRModule; - -typedef unsigned int IROpFlags; -enum : IROpFlags -{ - kIROpFlags_None = 0, - kIROpFlag_Parent = 1 << 0, ///< This op is a parent op - kIROpFlag_UseOther = 1 << 1, ///< If set this op can use 'other bits' to store information -}; - -/* Bit usage of IROp is a follows - - MainOp | Pseudo | Other -Bit range: 0-7 | 8 | Remaining bits - -If an instruction is 'pseudo' (ie shouldn't appear in output IR), then the Pseudo bit is set - and 'Invalid' falls into -this category as well as all pseudo ops. -For doing range checks (for example for doing isa tests), the value is masked by kIROpMeta_OpMask, such that the Other bits don't interfere. -The other bits can be used for storage for anything that needs to identify as a different 'op' or 'type'. It is currently -used currently for storing the TextureFlavor of a IRResourceTypeBase derived types for example. -*/ -enum IROp : int32_t -{ -#define INST(ID, MNEMONIC, ARG_COUNT, FLAGS) \ - kIROp_##ID, - -#include "ir-inst-defs.h" - - kIROpCount, - - // We use the range 0x100 to 0x1ff set for pseudo/non main codes - // Instructions that should not appear in valid IR. - - kIROp_Invalid = 0x100, ///< If bit set, then in pseudo/not normal space - kIRPseudoOp_First = kIROp_Invalid, - -#define INST(ID, MNEMONIC, ARG_COUNT, FLAGS) /* empty */ -#define PSEUDO_INST(ID) kIRPseudoOp_##ID, - - kIRPseudoOp_LastPlusOne, - -#include "ir-inst-defs.h" - -#define INST(ID, MNEMONIC, ARG_COUNT, FLAGS) /* empty */ -#define INST_RANGE(BASE, FIRST, LAST) \ - kIROp_First##BASE = kIROp_##FIRST, \ - kIROp_Last##BASE = kIROp_##LAST, - -#include "ir-inst-defs.h" -}; - -/* IROpMeta describe values for layout of IROp, as well as values for accessing aspects of IROp bits. */ -enum IROpMeta -{ - kIROpMeta_OtherShift = 9, ///< Number of bits for op/pseudo ops (shift right by this to get the other bits) - kIROpMeta_PseudoOpMask = (int32_t(1) << kIROpMeta_OtherShift) - 1, ///< Mask for ops including pseudo ops - kIROpMeta_OpMask = 0xff, ///< Mask for just ops - kIrOpMeta_OtherMask = ~kIROpMeta_PseudoOpMask, ///< Mask for bits that can be used for other purposes than 'op' ('other' bits) - kIROpMeta_IsPseudoOp = kIROp_Invalid, ///< 'And' with op, if set, the op is a pseudo op -}; - -// True if op is pseudo (or invalid which is 'pseudo-like' at least in as so far as current behavior) -SLANG_FORCE_INLINE bool isPseudoOp(IROp op) { return (op & kIROpMeta_IsPseudoOp) != 0; } - -IROp findIROp(const UnownedStringSlice& name); - -// A logical operation/opcode in the IR -struct IROpInfo -{ - // What is the name/mnemonic for this operation - char const* name; - - // How many required arguments are there - // (not including the mandatory type argument) - unsigned int fixedArgCount; - - // Flags to control how we emit additional info - IROpFlags flags; -}; - -// Look up the info for an op -IROpInfo getIROpInfo(IROp op); - -// A use of another value/inst within an IR operation -struct IRUse -{ - IRInst* get() const { return usedValue; } - IRInst* getUser() const { return user; } - - void init(IRInst* user, IRInst* usedValue); - void set(IRInst* usedValue); - void clear(); - - // The instruction that is being used - IRInst* usedValue = nullptr; - - // The instruction that is doing the using. - IRInst* user = nullptr; - - // The next use of the same value - IRUse* nextUse = nullptr; - - // A "link" back to where this use is referenced, - // so that we can simplify updates. - IRUse** prevLink = nullptr; - - void debugValidate(); -}; - -struct IRBlock; -struct IRDecoration; -struct IRRate; -struct IRType; - -// A double-linked list of instruction -struct IRInstListBase -{ - IRInstListBase() - {} - - IRInstListBase(IRInst* first, IRInst* last) - : first(first) - , last(last) - {} - - - - IRInst* first = nullptr; - IRInst* last = nullptr; - - IRInst* getFirst() { return first; } - IRInst* getLast() { return last; } - - struct Iterator - { - IRInst* inst; - - Iterator() : inst(nullptr) {} - Iterator(IRInst* inst) : inst(inst) {} - - void operator++(); - IRInst* operator*() - { - return inst; - } - - bool operator!=(Iterator const& i) - { - return inst != i.inst; - } - }; - - Iterator begin(); - Iterator end(); -}; - -// Specialization of `IRInstListBase` for the case where -// we know (or at least expect) all of the instructions -// to be of type `T` -template -struct IRInstList : IRInstListBase -{ - IRInstList() {} - - IRInstList(T* first, T* last) - : IRInstListBase(first, last) - {} - - explicit IRInstList(IRInstListBase const& list) - : IRInstListBase(list) - {} - - T* getFirst() { return (T*) first; } - T* getLast() { return (T*) last; } - - struct Iterator : public IRInstListBase::Iterator - { - Iterator() {} - Iterator(IRInst* inst) : IRInstListBase::Iterator(inst) {} - - T* operator*() - { - return (T*) inst; - } - }; - - Iterator begin() { return Iterator(first); } - Iterator end(); -}; - - - -// Every value in the IR is an instruction (even things -// like literal values). -// -struct IRInst -{ - // The operation that this value represents - IROp op; - - // The total number of operands of this instruction. - // - // TODO: We shouldn't need to allocate this on - // all instructions. Instead we should have - // instructions that need "vararg" support to - // allocate this field ahead of the `this` - // pointer. - uint32_t operandCount = 0; - - UInt getOperandCount() - { - return operandCount; - } - - // Source location information for this value, if any - SourceLoc sourceLoc; - - // Each instruction can have zero or more "decorations" - // attached to it. A decoration is a specialized kind - // of instruction that either attaches metadata to, - // or modifies the sematnics of, its parent instruction. - // - IRDecoration* getFirstDecoration(); - IRDecoration* getLastDecoration(); - IRInstList getDecorations(); - - // Look up a decoration in the list of decorations - IRDecoration* findDecorationImpl(IROp op); - template - T* findDecoration(); - - // The first use of this value (start of a linked list) - IRUse* firstUse = nullptr; - - - // The parent of this instruction. - IRInst* parent; - - IRInst* getParent() { return parent; } - - // The next and previous instructions with the same parent - IRInst* next; - IRInst* prev; - - IRInst* getNextInst() { return next; } - IRInst* getPrevInst() { return prev; } - - // An instruction can have zero or more children, although - // only certain instruction opcodes are allowed to have - // children. - // - // For example, a function will have children that are - // its basic blocks, and the basic blocks will have children - // that represent parameters and ordinary executable instructions. - // - IRInst* getFirstChild(); - IRInst* getLastChild(); - IRInstList getChildren() - { - return IRInstList( - getFirstChild(), - getLastChild()); - } - - /// A doubly-linked list containing any decorations and then any children of this instruction. - /// - /// We store both the decorations and children of an instruction - /// in the same list, to conserve space in the instruction itself - /// (rather than storing distinct lists for decorations and children). - /// - // Note: This field is *not* being declared `private` because doing so could - // mess with our required memory layout, where `typeUse` below is assumed - // to be the last field in `IRInst` and to come right before any additional - // `IRUse` values that represent operands. - // - IRInstListBase m_decorationsAndChildren; - - IRInst* getFirstDecorationOrChild() { return m_decorationsAndChildren.first; } - IRInst* getLastDecorationOrChild() { return m_decorationsAndChildren.last; } - IRInstListBase getDecorationsAndChildren() { return m_decorationsAndChildren; } - - void removeAndDeallocateAllDecorationsAndChildren(); - - // The type of the result value of this instruction, - // or `null` to indicate that the instruction has - // no value. - IRUse typeUse; - - IRType* getFullType() { return (IRType*) typeUse.get(); } - void setFullType(IRType* type) { typeUse.init(this, (IRInst*) type); } - - IRRate* getRate(); - - IRType* getDataType(); - - // After the type, we have data that is specific to - // the subtype of `IRInst`. In most cases, this is - // just a series of `IRUse` values representing - // operands of the instruction. - - IRUse* getOperands(); - - IRInst* getOperand(UInt index) - { - return getOperands()[index].get(); - } - - void setOperand(UInt index, IRInst* value) - { - getOperands()[index].set(value); - } - - - // - - // Replace all uses of this value with `other`, so - // that this value will now have no uses. - void replaceUsesWith(IRInst* other); - - // Insert this instruction into the same basic block - // as `other`, right before/after it. - void insertBefore(IRInst* other); - void insertAfter(IRInst* other); - - // Insert as first/last child of given parent - void insertAtStart(IRInst* parent); - void insertAtEnd(IRInst* parent); - - // Move to the start/end of current parent - void moveToStart(); - void moveToEnd(); - - // Remove this instruction from its parent block, - // but don't delete it, or replace uses. - void removeFromParent(); - - // Remove this instruction from its parent block, - // and then destroy it (it had better have no uses!) - void removeAndDeallocate(); - - // Clear out the arguments of this instruction, - // so that we don't appear on the list of uses - // for those values. - void removeArguments(); - - /// Transfer any decorations of this instruction to the `target` instruction. - void transferDecorationsTo(IRInst* target); - - /// Does this instruction have any uses? - bool hasUses() const { return firstUse != nullptr; } - - /// Does this instructiomn have more than one use? - bool hasMoreThanOneUse() const { return firstUse != nullptr && firstUse->nextUse != nullptr; } - - /// It is possible that this instruction has side effects? - /// - /// This is a conservative test, and will return `true` if an exact answer can't be determined. - bool mightHaveSideEffects(); - - // RTTI support - static bool isaImpl(IROp) { return true; } - - /// Find the module that this instruction is nested under. - /// - /// If this instruction is transitively nested inside some IR module, - /// this function will return it, and will otherwise return `null`. - IRModule* getModule(); - - /// Insert this instruction into `inParent`, after `inPrev` and before `inNext`. - /// - /// `inParent` must be non-null - /// If `inPrev` is non-null it must satisfy `inPrev->getNextInst() == inNext` and `inPrev->getParent() == inParent` - /// If `inNext` is non-null it must satisfy `inNext->getPrevInst() == inPrev` and `inNext->getParent() == inParent` - /// - /// If both `inPrev` and `inNext` are null, then `inParent` must have no (raw) children. - /// - void _insertAt(IRInst* inPrev, IRInst* inNext, IRInst* inParent); -}; - -template -T* dynamicCast(IRInst* inst) -{ - if (inst && T::isaImpl(inst->op)) - return static_cast(inst); - return nullptr; -} - -template -const T* dynamicCast(const IRInst* inst) -{ - if (inst && T::isaImpl(inst->op)) - return static_cast(inst); - return nullptr; -} - -// `dynamic_cast` equivalent (we just use dynamicCast) -template -T* as(IRInst* inst) -{ - return dynamicCast(inst); -} - -template -const T* as(const IRInst* inst) -{ - return dynamicCast(inst); -} - -// `static_cast` equivalent, with debug validation -template -T* cast(IRInst* inst, T* /* */ = nullptr) -{ - SLANG_ASSERT(!inst || as(inst)); - return (T*)inst; -} - -// Now that `IRInst` is defined we can back-fill the definitions that need to access it. - -template -T* IRInst::findDecoration() -{ - for( auto decoration : getDecorations() ) - { - if(auto match = as(decoration)) - return match; - } - return nullptr; -} - -template -typename IRInstList::Iterator IRInstList::end() -{ - return Iterator(last ? last->next : nullptr); -} - - -// Types - -#define IR_LEAF_ISA(NAME) static bool isaImpl(IROp op) { return (kIROpMeta_PseudoOpMask & op) == kIROp_##NAME; } -#define IR_PARENT_ISA(NAME) static bool isaImpl(IROp opIn) { const int op = (kIROpMeta_PseudoOpMask & opIn); return op >= kIROp_First##NAME && op <= kIROp_Last##NAME; } - -#define SIMPLE_IR_TYPE(NAME, BASE) struct IR##NAME : IR##BASE { IR_LEAF_ISA(NAME) }; -#define SIMPLE_IR_PARENT_TYPE(NAME, BASE) struct IR##NAME : IR##BASE { IR_PARENT_ISA(NAME) }; - - -// All types in the IR are represented as instructions which conceptually -// execute before run time. -struct IRType : IRInst -{ - IRType* getCanonicalType() { return this; } - - IR_PARENT_ISA(Type) -}; - -IRType* unwrapArray(IRType* type); - -struct IRBasicType : IRType -{ - BaseType getBaseType() { return BaseType(op - kIROp_FirstBasicType); } - - IR_PARENT_ISA(BasicType) -}; - -struct IRVoidType : IRBasicType -{ - IR_LEAF_ISA(VoidType) -}; - -struct IRBoolType : IRBasicType -{ - IR_LEAF_ISA(BoolType) -}; - -SIMPLE_IR_TYPE(StringType, Type) - - -// True if types are equal -// Note compares nominal types by name alone -bool isTypeEqual(IRType* a, IRType* b); - -void findAllInstsBreadthFirst(IRInst* inst, List& outInsts); - -// Constant Instructions - -typedef int64_t IRIntegerValue; -typedef double IRFloatingPointValue; - -struct IRConstant : IRInst -{ - struct StringValue - { - uint32_t numChars; ///< The number of chars - char chars[1]; ///< Chars added at end. NOTE! Must be last member of struct! - }; - struct StringSliceValue - { - uint32_t numChars; - char* chars; - }; - - union ValueUnion - { - IRIntegerValue intVal; ///< Used for integrals and boolean - IRFloatingPointValue floatVal; - void* ptrVal; - - /// Either of these types could be set with kIROp_StringLit. - /// Which is used is currently determined with decorations - if a kIROp_TransitoryDecoration is set, then the transitory StringVal is used, else stringVal - // which relies on chars being held after the struct). - StringValue stringVal; - StringSliceValue transitoryStringVal; - }; - - /// Returns a string slice (or empty string if not appropriate) - UnownedStringSlice getStringSlice(); - - /// True if constants are equal - bool equal(IRConstant* rhs); - /// True if the value is equal. - /// Does *NOT* compare if the type is equal. - bool isValueEqual(IRConstant* rhs); - - /// Get the hash - int getHashCode(); - - IR_PARENT_ISA(Constant) - - // Must be last member, because data may be held behind - // NOTE! The total size of IRConstant may not be allocated - only enough space is allocated for the value type held in the union. - ValueUnion value; -}; - -struct IRIntLit : IRConstant -{ - IRIntegerValue getValue() { return value.intVal; } - - IR_LEAF_ISA(IntLit); -}; - -struct IRBoolLit : IRConstant -{ - bool getValue() { return value.intVal != 0; } - - IR_LEAF_ISA(BoolLit); -}; - - -// Get the compile-time constant integer value of an instruction, -// if it has one, and assert-fail otherwise. -IRIntegerValue GetIntVal(IRInst* inst); - -struct IRStringLit : IRConstant -{ - - IR_LEAF_ISA(StringLit); -}; - -struct IRPtrLit : IRConstant -{ - IR_LEAF_ISA(PtrLit); - - void* getValue() { return value.ptrVal; } -}; - -// A instruction that ends a basic block (usually because of control flow) -struct IRTerminatorInst : IRInst -{ - IR_PARENT_ISA(TerminatorInst) -}; - -// A function parameter is owned by a basic block, and represents -// either an incoming function parameter (in the entry block), or -// a value that flows from one SSA block to another (in a non-entry -// block). -// -// In each case, the basic idea is that a block is a "label with -// arguments." -// -struct IRParam : IRInst -{ - IRParam* getNextParam(); - IRParam* getPrevParam(); - - IR_LEAF_ISA(Param) -}; - -// A basic block is a parent instruction that adds the constraint -// that all the children need to be "ordinary" instructions (so -// no function declarations, or nested blocks). We also expect -// that the previous/next instruction are always a basic block. -// -struct IRBlock : IRInst -{ - // Linked list of the instructions contained in this block - // - IRInst* getFirstInst() { return getChildren().first; } - IRInst* getLastInst() { return getChildren().last; } - - // In a valid program, every basic block should end with - // a "terminator" instruction. - // - // This function will return the terminator, if it exists, - // or `null` if there is none. - IRTerminatorInst* getTerminator() { return as(getLastDecorationOrChild()); } - - // We expect that the siblings of a basic block will - // always be other basic blocks (we don't allow - // mixing of blocks and other instructions in the - // same parent). - // - // The exception to this is that decorations on the function - // that contains a block could appear before the first block, - // so we need to be careful to do a dynamic cast (`as`) in - // the `getPrevBlock` case, but don't need to worry about - // it for `getNextBlock`. - IRBlock* getPrevBlock() { return as(getPrevInst()); } - IRBlock* getNextBlock() { return cast(getNextInst()); } - - // The parameters of a block are represented by `IRParam` - // instructions at the start of the block. These play - // the role of function parameters for the entry block - // of a function, and of phi nodes in other blocks. - IRParam* getFirstParam() { return as(getFirstInst()); } - IRParam* getLastParam(); - IRInstList getParams() - { - return IRInstList( - getFirstParam(), - getLastParam()); - } - - void addParam(IRParam* param); - - // The "ordinary" instructions come after the parameters - IRInst* getFirstOrdinaryInst(); - IRInst* getLastOrdinaryInst(); - IRInstList getOrdinaryInsts() - { - return IRInstList( - getFirstOrdinaryInst(), - getLastOrdinaryInst()); - } - - // The parent of a basic block is assumed to be a - // value with code (e.g., a function, global variable - // with initializer, etc.). - IRGlobalValueWithCode* getParent() { return cast(IRInst::getParent()); } - - // The predecessor and successor lists of a block are needed - // when we want to work with the control flow graph (CFG) of - // a function. Rather than store these explicitly (and thus - // need to update them when transformations might change the - // CFG), we compute predecessors and successors in an - // implicit fashion using the use-def information for a - // block itself. - // - // To a first approximation, the predecessors of a block - // are the blocks where the IR value of the block is used. - // Similarly, the successors of a block are all values used - // by the terminator instruction of the block. - // The `getPredecessors()` and `getSuccessors()` functions - // make this more precise. - // - struct PredecessorList - { - PredecessorList(IRUse* begin) : b(begin) {} - IRUse* b; - - UInt getCount(); - bool isEmpty(); - - struct Iterator - { - Iterator(IRUse* use) : use(use) {} - - IRBlock* operator*(); - - void operator++(); - - bool operator!=(Iterator const& that) - { - return use != that.use; - } - - IRUse* use; - }; - - Iterator begin() { return Iterator(b); } - Iterator end() { return Iterator(nullptr); } - }; - - struct SuccessorList - { - SuccessorList(IRUse* begin, IRUse* end, UInt stride = 1) : begin_(begin), end_(end), stride(stride) {} - IRUse* begin_; - IRUse* end_; - UInt stride; - - UInt getCount(); - - struct Iterator - { - Iterator(IRUse* use, UInt stride) : use(use), stride(stride) {} - - IRBlock* operator*(); - - void operator++(); - - bool operator!=(Iterator const& that) - { - return use != that.use; - } - - IRUse* use; - UInt stride; - }; - - Iterator begin() { return Iterator(begin_, stride); } - Iterator end() { return Iterator(end_, stride); } - }; - - PredecessorList getPredecessors(); - SuccessorList getSuccessors(); - - // - - IR_LEAF_ISA(Block) -}; - -SIMPLE_IR_TYPE(BasicBlockType, Type) - -struct IRResourceTypeBase : IRType -{ - TextureFlavor getFlavor() const - { - return TextureFlavor((op >> kIROpMeta_OtherShift) & 0xFFFF); - } - - TextureFlavor::Shape GetBaseShape() const - { - return getFlavor().GetBaseShape(); - } - bool isMultisample() const { return getFlavor().isMultisample(); } - bool isArray() const { return getFlavor().isArray(); } - SlangResourceShape getShape() const { return getFlavor().getShape(); } - SlangResourceAccess getAccess() const { return getFlavor().getAccess(); } - - IR_PARENT_ISA(ResourceTypeBase); -}; - -struct IRResourceType : IRResourceTypeBase -{ - IRType* getElementType() { return (IRType*)getOperand(0); } - - IR_PARENT_ISA(ResourceType) -}; - -struct IRTextureTypeBase : IRResourceType -{ - IR_PARENT_ISA(TextureTypeBase) -}; - -struct IRTextureType : IRTextureTypeBase -{ - IR_LEAF_ISA(TextureType) -}; - -struct IRTextureSamplerType : IRTextureTypeBase -{ - IR_LEAF_ISA(TextureSamplerType) -}; - -struct IRGLSLImageType : IRTextureTypeBase -{ - IR_LEAF_ISA(GLSLImageType) -}; - -struct IRSamplerStateTypeBase : IRType -{ - IR_PARENT_ISA(SamplerStateTypeBase) -}; - -SIMPLE_IR_TYPE(SamplerStateType, SamplerStateTypeBase) -SIMPLE_IR_TYPE(SamplerComparisonStateType, SamplerStateTypeBase) - -struct IRBuiltinGenericType : IRType -{ - IRType* getElementType() { return (IRType*)getOperand(0); } - - IR_PARENT_ISA(BuiltinGenericType) -}; - -SIMPLE_IR_PARENT_TYPE(PointerLikeType, BuiltinGenericType); -SIMPLE_IR_PARENT_TYPE(HLSLStructuredBufferTypeBase, BuiltinGenericType) -SIMPLE_IR_TYPE(HLSLStructuredBufferType, HLSLStructuredBufferTypeBase) -SIMPLE_IR_TYPE(HLSLRWStructuredBufferType, HLSLStructuredBufferTypeBase) -SIMPLE_IR_TYPE(HLSLRasterizerOrderedStructuredBufferType, HLSLStructuredBufferTypeBase) - -SIMPLE_IR_PARENT_TYPE(UntypedBufferResourceType, Type) -SIMPLE_IR_PARENT_TYPE(ByteAddressBufferTypeBase, UntypedBufferResourceType) -SIMPLE_IR_TYPE(HLSLByteAddressBufferType, ByteAddressBufferTypeBase) -SIMPLE_IR_TYPE(HLSLRWByteAddressBufferType, ByteAddressBufferTypeBase) -SIMPLE_IR_TYPE(HLSLRasterizerOrderedByteAddressBufferType, ByteAddressBufferTypeBase) - -SIMPLE_IR_TYPE(HLSLAppendStructuredBufferType, HLSLStructuredBufferTypeBase) -SIMPLE_IR_TYPE(HLSLConsumeStructuredBufferType, HLSLStructuredBufferTypeBase) - -struct IRHLSLPatchType : IRType -{ - IRType* getElementType() { return (IRType*)getOperand(0); } - IRInst* getElementCount() { return getOperand(1); } - - IR_PARENT_ISA(HLSLPatchType) -}; - -SIMPLE_IR_TYPE(HLSLInputPatchType, HLSLPatchType) -SIMPLE_IR_TYPE(HLSLOutputPatchType, HLSLPatchType) - -SIMPLE_IR_PARENT_TYPE(HLSLStreamOutputType, BuiltinGenericType) -SIMPLE_IR_TYPE(HLSLPointStreamType, HLSLStreamOutputType) -SIMPLE_IR_TYPE(HLSLLineStreamType, HLSLStreamOutputType) -SIMPLE_IR_TYPE(HLSLTriangleStreamType, HLSLStreamOutputType) - -SIMPLE_IR_TYPE(GLSLInputAttachmentType, Type) -SIMPLE_IR_PARENT_TYPE(ParameterGroupType, PointerLikeType) -SIMPLE_IR_PARENT_TYPE(UniformParameterGroupType, ParameterGroupType) -SIMPLE_IR_PARENT_TYPE(VaryingParameterGroupType, ParameterGroupType) -SIMPLE_IR_TYPE(ConstantBufferType, UniformParameterGroupType) -SIMPLE_IR_TYPE(TextureBufferType, UniformParameterGroupType) -SIMPLE_IR_TYPE(GLSLInputParameterGroupType, VaryingParameterGroupType) -SIMPLE_IR_TYPE(GLSLOutputParameterGroupType, VaryingParameterGroupType) -SIMPLE_IR_TYPE(GLSLShaderStorageBufferType, UniformParameterGroupType) -SIMPLE_IR_TYPE(ParameterBlockType, UniformParameterGroupType) - -struct IRArrayTypeBase : IRType -{ - IRType* getElementType() { return (IRType*)getOperand(0); } - - // Returns the element count for an `IRArrayType`, and null - // for an `IRUnsizedArrayType`. - IRInst* getElementCount(); - - IR_PARENT_ISA(ArrayTypeBase) -}; - -struct IRArrayType: IRArrayTypeBase -{ - IRInst* getElementCount() { return getOperand(1); } - - IR_LEAF_ISA(ArrayType) -}; - -SIMPLE_IR_TYPE(UnsizedArrayType, ArrayTypeBase) - -SIMPLE_IR_PARENT_TYPE(Rate, Type) -SIMPLE_IR_TYPE(ConstExprRate, Rate) -SIMPLE_IR_TYPE(GroupSharedRate, Rate) - -struct IRRateQualifiedType : IRType -{ - IRRate* getRate() { return (IRRate*) getOperand(0); } - IRType* getValueType() { return (IRType*) getOperand(1); } - - IR_LEAF_ISA(RateQualifiedType) -}; - - -// Unlike the AST-level type system where `TypeType` tracks the -// underlying type, the "type of types" in the IR is a simple -// value with no operands, so that all type nodes have the -// same type. -SIMPLE_IR_PARENT_TYPE(Kind, Type); -SIMPLE_IR_TYPE(TypeKind, Kind); - -// The kind of any and all generics. -// -// A more complete type system would include "arrow kinds" to -// be able to track the domain and range of generics (e.g., -// the `vector` generic maps a type and an integer to a type). -// This is only really needed if we ever wanted to support -// "higher-kinded" generics (e.g., a generic that takes another -// generic as a parameter). -// -SIMPLE_IR_TYPE(GenericKind, Kind) - -struct IRVectorType : IRType -{ - IRType* getElementType() { return (IRType*)getOperand(0); } - IRInst* getElementCount() { return getOperand(1); } - - IR_LEAF_ISA(VectorType) -}; - -struct IRMatrixType : IRType -{ - IRType* getElementType() { return (IRType*)getOperand(0); } - IRInst* getRowCount() { return getOperand(1); } - IRInst* getColumnCount() { return getOperand(2); } - - IR_LEAF_ISA(MatrixType) -}; - -struct IRPtrTypeBase : IRType -{ - IRType* getValueType() { return (IRType*)getOperand(0); } - - IR_PARENT_ISA(PtrTypeBase) -}; - -SIMPLE_IR_TYPE(PtrType, PtrTypeBase) -SIMPLE_IR_TYPE(RefType, PtrTypeBase) - -SIMPLE_IR_PARENT_TYPE(OutTypeBase, PtrTypeBase) -SIMPLE_IR_TYPE(OutType, OutTypeBase) -SIMPLE_IR_TYPE(InOutType, OutTypeBase) -SIMPLE_IR_TYPE(ExistentialBoxType, PtrTypeBase) - - /// Get the type pointed to be `ptrType`, or `nullptr` if it is not a pointer(-like) type. - /// - /// The given IR `builder` will be used if new instructions need to be created. -IRType* tryGetPointedToType( - IRBuilder* builder, - IRType* type); - -struct IRFuncType : IRType -{ - IRType* getResultType() { return (IRType*) getOperand(0); } - UInt getParamCount() { return getOperandCount() - 1; } - IRType* getParamType(UInt index) { return (IRType*)getOperand(1 + index); } - - IR_LEAF_ISA(FuncType) -}; - -bool isDefinition( - IRInst* inVal); - -// A structure type is represented as a parent instruction, -// where the child instructions represent the fields of the -// struct. -// -// The space of fields that a given struct type supports -// are defined as its "keys", which are global values -// (that is, they have mangled names that can be used -// for linkage). -// -struct IRStructKey : IRInst -{ - IR_LEAF_ISA(StructKey) -}; -// -// The fields of the struct are then defined as mappings -// from those keys to the associated type (in the case of -// the struct type) or to values (when lookup up a field). -// -// A struct field thus has two operands: the key, and the -// type of the field. -// -struct IRStructField : IRInst -{ - IRStructKey* getKey() { return cast(getOperand(0)); } - IRType* getFieldType() - { - // Note: We do not use `cast` here because there are - // cases of types (which we would like to conveniently - // refer to via an `IRType*`) which do not actually - // inherit from `IRType` in the hierarchy. - // - return (IRType*) getOperand(1); - } - - IR_LEAF_ISA(StructField) -}; -// -// The struct type is then represented as a parent instruction -// that contains the various fields. Note that a struct does -// *not* contain the keys, because code needs to be able to -// reference the keys from scopes outside of the struct. -// -struct IRStructType : IRType -{ - IRInstList getFields() { return IRInstList(getChildren()); } - - IR_LEAF_ISA(StructType) -}; - -struct IRInterfaceType : IRType -{ - IR_LEAF_ISA(InterfaceType) -}; - -struct IRTaggedUnionType : IRType -{ - IR_LEAF_ISA(TaggedUnionType) -}; - -struct IRBindExistentialsType : IRType -{ - IR_LEAF_ISA(BindExistentialsType) - - IRType* getBaseType() { return (IRType*) getOperand(0); } - UInt getExistentialArgCount() { return getOperandCount() - 1; } - IRUse* getExistentialArgs() { return getOperands() + 1; } - IRInst* getExistentialArg(UInt index) { return getExistentialArgs()[index].get(); } -}; - -/// @brief A global value that potentially holds executable code. -/// -struct IRGlobalValueWithCode : IRInst -{ - // The children of a value with code will be the basic - // blocks of its definition. - IRBlock* getFirstBlock() { return cast(getFirstChild()); } - IRBlock* getLastBlock() { return cast(getLastChild()); } - IRInstList getBlocks() - { - return IRInstList(getChildren()); - } - - // Add a block to the end of this function. - void addBlock(IRBlock* block); - - IR_PARENT_ISA(GlobalValueWithCode) -}; - -// A value that has parameters so that it can conceptually be called. -struct IRGlobalValueWithParams : IRGlobalValueWithCode -{ - // Convenience accessor for the IR parameters, - // which are actually the parameters of the first - // block. - IRParam* getFirstParam(); - IRParam* getLastParam(); - IRInstList getParams(); - - IR_PARENT_ISA(GlobalValueWithParams) -}; - -// A function is a parent to zero or more blocks of instructions. -// -// A function is itself a value, so that it can be a direct operand of -// an instruction (e.g., a call). -struct IRFunc : IRGlobalValueWithParams -{ - // The type of the IR-level function - IRFuncType* getDataType() { return (IRFuncType*) IRInst::getDataType(); } - - // Convenience accessors for working with the - // function's type. - IRType* getResultType(); - UInt getParamCount(); - IRType* getParamType(UInt index); - - bool isDefinition() { return getFirstBlock() != nullptr; } - - IR_LEAF_ISA(Func) -}; - - /// Adjust the type of an IR function based on its parameter list. -void fixUpFuncType(IRFunc* func); - -// A generic is akin to a function, but is conceptually executed -// before runtime, to specialize the code nested within. -// -// In practice, a generic always holds only a single block, and ends -// with a `return` instruction for the value that the generic yields. -struct IRGeneric : IRGlobalValueWithParams -{ - IR_LEAF_ISA(Generic) -}; - -// Find the value that is returned from a generic, so that -// a pass can glean information from it. -IRInst* findGenericReturnVal(IRGeneric* generic); - -// Resolve an instruction that might reference a static definition -// to the most specific IR node possible, so that we can read -// decorations from it (e.g., if this is a `specialize` instruction, -// then try to chase down the generic being specialized, and what -// it seems to return). -// -IRInst* getResolvedInstForDecorations(IRInst* inst); - -// The IR module itself is represented as an instruction, which -// serves at the root of the tree of all instructions in the module. -struct IRModuleInst : IRInst -{ - // Pointer back to the non-instruction object that represents - // the module, so that we can get back to it in algorithms - // that need it. - IRModule* module; - - IRInstListBase getGlobalInsts() { return getChildren(); } - - IR_LEAF_ISA(Module) -}; - -struct IRModule : RefObject -{ - enum - { - kMemoryArenaBlockSize = 16 * 1024, ///< Use 16k block size for memory arena - }; - - SLANG_FORCE_INLINE Session* getSession() const { return session; } - SLANG_FORCE_INLINE IRModuleInst* getModuleInst() const { return moduleInst; } - - IRInstListBase getGlobalInsts() const { return getModuleInst()->getChildren(); } - - /// Get the object scope manager - SLANG_FORCE_INLINE ObjectScopeManager* getObjectScopeManager() { return &m_objectScopeManager; } - - /// Ctor - IRModule(): - memoryArena(kMemoryArenaBlockSize) - { - } - - MemoryArena memoryArena; - - // The compilation session in use. - Session* session; - IRModuleInst* moduleInst; - - protected: - - ObjectScopeManager m_objectScopeManager; -}; - - /// How much detail to include in dumped IR. - /// - /// Used with the `dumpIR` functions to determine - /// whether a completely faithful, but verbose, IR - /// dump is produced, or something simplified for ease - /// or reading. - /// -enum class IRDumpMode -{ - /// Produce a simplified IR dump. - /// - /// Simplified IR dumping will skip certain instructions - /// and print them at their use sites instead, so that - /// the overall dump is shorter and easier to read. - Simplified, - - /// Produce a detailed/accurate IR dump. - /// - /// A detailed IR dump will make sure to emit exactly - /// the instructions that were present with no attempt - /// to selectively skip them or give special formatting. - /// - Detailed, -}; - -void printSlangIRAssembly(StringBuilder& builder, IRModule* module, IRDumpMode mode = IRDumpMode::Simplified); -String getSlangIRAssembly(IRModule* module, IRDumpMode mode = IRDumpMode::Simplified); - -void dumpIR(IRModule* module, ISlangWriter* writer, IRDumpMode mode = IRDumpMode::Simplified); -void dumpIR(IRInst* globalVal, ISlangWriter* writer, IRDumpMode mode = IRDumpMode::Simplified); - -IRInst* createEmptyInst( - IRModule* module, - IROp op, - int totalArgCount); - -IRInst* createEmptyInstWithSize( - IRModule* module, - IROp op, - size_t totalSizeInBytes); -} - -#endif diff --git a/source/slang/legalize-types.cpp b/source/slang/legalize-types.cpp deleted file mode 100644 index 64dafb938..000000000 --- a/source/slang/legalize-types.cpp +++ /dev/null @@ -1,1486 +0,0 @@ -// legalize-types.cpp -#include "legalize-types.h" - -#include "ir-insts.h" -#include "mangle.h" - -namespace Slang -{ - -LegalType LegalType::implicitDeref( - LegalType const& valueType) -{ - RefPtr obj = new ImplicitDerefType(); - obj->valueType = valueType; - - LegalType result; - result.flavor = Flavor::implicitDeref; - result.obj = obj; - return result; -} - -LegalType LegalType::tuple( - RefPtr tupleType) -{ - SLANG_ASSERT(tupleType->elements.getCount()); - - LegalType result; - result.flavor = Flavor::tuple; - result.obj = tupleType; - return result; -} - -LegalType LegalType::pair( - RefPtr pairType) -{ - LegalType result; - result.flavor = Flavor::pair; - result.obj = pairType; - return result; -} - -LegalType LegalType::pair( - LegalType const& ordinaryType, - LegalType const& specialType, - RefPtr pairInfo) -{ - // Handle some special cases for when - // one or the other of the types isn't - // actually used. - - if (ordinaryType.flavor == LegalType::Flavor::none) - { - // There was nothing ordinary. - return specialType; - } - - if (specialType.flavor == LegalType::Flavor::none) - { - return ordinaryType; - } - - // There were both ordinary and special fields, - // and so we need to handle them here. - - RefPtr obj = new PairPseudoType(); - obj->ordinaryType = ordinaryType; - obj->specialType = specialType; - obj->pairInfo = pairInfo; - return LegalType::pair(obj); -} - -LegalType LegalType::makeWrappedBuffer( - IRType* simpleType, - LegalElementWrapping const& elementInfo) -{ - RefPtr obj = new WrappedBufferPseudoType(); - obj->simpleType = simpleType; - obj->elementInfo = elementInfo; - - LegalType result; - result.flavor = Flavor::wrappedBuffer; - result.obj = obj; - return result; -} - -// - -LegalElementWrapping LegalElementWrapping::makeVoid() -{ - LegalElementWrapping result; - result.flavor = Flavor::none; - return result; -} - -LegalElementWrapping LegalElementWrapping::makeSimple(IRStructKey* key, IRType* type) -{ - RefPtr obj = new SimpleLegalElementWrappingObj(); - obj->key = key; - obj->type = type; - - LegalElementWrapping result; - result.flavor = Flavor::simple; - result.obj = obj; - return result; -} - -RefPtr LegalElementWrapping::getSimple() const -{ - SLANG_ASSERT(flavor == Flavor::simple); - return obj.as(); -} - -LegalElementWrapping LegalElementWrapping::makeImplicitDeref(LegalElementWrapping const& field) -{ - RefPtr obj = new ImplicitDerefLegalElementWrappingObj(); - obj->field = field; - - LegalElementWrapping result; - result.flavor = Flavor::implicitDeref; - result.obj = obj; - return result; -} - -RefPtr LegalElementWrapping::getImplicitDeref() const -{ - SLANG_ASSERT(flavor == Flavor::implicitDeref); - return obj.as(); -} - -LegalElementWrapping LegalElementWrapping::makePair( - LegalElementWrapping const& ordinary, - LegalElementWrapping const& special, - PairInfo* pairInfo) -{ - RefPtr obj = new PairLegalElementWrappingObj(); - obj->ordinary = ordinary; - obj->special = special; - obj->pairInfo = pairInfo; - - LegalElementWrapping result; - result.flavor = Flavor::pair; - result.obj = obj; - return result; -} - -RefPtr LegalElementWrapping::getPair() const -{ - SLANG_ASSERT(flavor == Flavor::pair); - return obj.as(); -} - -LegalElementWrapping LegalElementWrapping::makeTuple(TupleLegalElementWrappingObj* obj) -{ - LegalElementWrapping result; - result.flavor = Flavor::tuple; - result.obj = obj; - return result; -} - -RefPtr LegalElementWrapping::getTuple() const -{ - SLANG_ASSERT(flavor == Flavor::tuple); - return obj.as(); -} - -// - -bool isResourceType(IRType* type) -{ - while (auto arrayType = as(type)) - { - type = arrayType->getElementType(); - } - - if (auto resourceTypeBase = as(type)) - { - return true; - } - else if (auto builtinGenericType = as(type)) - { - return true; - } - else if (auto pointerLikeType = as(type)) - { - return true; - } - else if (auto samplerType = as(type)) - { - return true; - } - else if(auto untypedBufferType = as(type)) - { - return true; - } - - // TODO: need more comprehensive coverage here - - return false; -} - -ModuleDecl* findModuleForDecl( - Decl* decl) -{ - for (auto dd = decl; dd; dd = dd->ParentDecl) - { - if (auto moduleDecl = as(dd)) - return moduleDecl; - } - return nullptr; -} - - -// Helper type for legalization of aggregate types -// that might need to be turned into tuple pseudo-types. -struct TupleTypeBuilder -{ - TypeLegalizationContext* context; - IRType* type; - IRStructType* originalStructType; - - struct OrdinaryElement - { - IRStructKey* fieldKey = nullptr; - IRType* type = nullptr; - }; - - - List ordinaryElements; - List specialElements; - - List pairElements; - - // Did we have any fields that forced us to change - // the actual type away from the declared type? - bool anyComplex = false; - - // Did we have any fields that actually required - // storage in the "special" part of things? - bool anySpecial = false; - - // Did we have any fields that actually used ordinary storage? - bool anyOrdinary = false; - - // Add a field to the (pseudo-)type we are building - void addField( - IRStructKey* fieldKey, - LegalType legalFieldType, - LegalType legalLeafType, - bool isSpecial) - { - LegalType ordinaryType; - LegalType specialType; - RefPtr elementPairInfo; - switch (legalLeafType.flavor) - { - case LegalType::Flavor::simple: - { - // We need to add an actual field, but we need - // to check if it is a resource type to know - // whether it should go in the "ordinary" list or not. - if (!isSpecial) - { - ordinaryType = legalLeafType; - } - else - { - specialType = legalFieldType; - } - } - break; - - case LegalType::Flavor::none: - anyComplex = true; - break; - - case LegalType::Flavor::implicitDeref: - { - // TODO: we may want to say that any use - // of `implicitDeref` puts the entire thing - // into the "special" category, rather than - // try to look under the hood... - - anyComplex = true; - - // We want to recursively add data - // based on the unwrapped type. - // - // Note: this assumes we can't have a tuple - // or a pair "under" an `implicitDeref`, so - // we'll need to ensure that elsewhere. - addField( - fieldKey, - legalFieldType, - legalLeafType.getImplicitDeref()->valueType, - isSpecial); - return; - } - break; - - case LegalType::Flavor::pair: - { - // The field's type had both special and non-special parts - auto pairType = legalLeafType.getPair(); - - // If things originally started as a resource type, then - // we want to externalize all the fields that arose, even - // if there is (nominally) ordinary data. - // - // This is because the "ordinary" side of the legalization - // of `ConstantBuffer` will still be a resource type. - if(isSpecial) - { - specialType = legalFieldType; - } - else - { - ordinaryType = pairType->ordinaryType; - specialType = pairType->specialType; - elementPairInfo = pairType->pairInfo; - } - } - break; - - case LegalType::Flavor::tuple: - { - // A tuple always represents "special" data - specialType = legalFieldType; - } - break; - - default: - SLANG_UNEXPECTED("unknown legal type flavor"); - break; - } - - PairInfo::Element pairElement; - pairElement.flags = 0; - pairElement.key = fieldKey; - pairElement.fieldPairInfo = elementPairInfo; - - // We will always add a field to the "ordinary" - // side of things, even if it has no ordinary - // data, just to keep the list of fields aligned - // with the original type. - OrdinaryElement ordinaryElement; - ordinaryElement.fieldKey = fieldKey; - if (ordinaryType.flavor != LegalType::Flavor::none) - { - anyOrdinary = true; - pairElement.flags |= PairInfo::kFlag_hasOrdinary; - - LegalType ot = ordinaryType; - - // TODO: any cases we should "unwrap" here? - // E.g., `implicitDeref`? - - if(ot.flavor == LegalType::Flavor::simple) - { - ordinaryElement.type = ot.getSimple(); - } - else - { - SLANG_UNEXPECTED("unexpected ordinary field type"); - } - } - ordinaryElements.add(ordinaryElement); - - if (specialType.flavor != LegalType::Flavor::none) - { - anySpecial = true; - anyComplex = true; - pairElement.flags |= PairInfo::kFlag_hasSpecial; - - TuplePseudoType::Element specialElement; - specialElement.key = fieldKey; - specialElement.type = specialType; - specialElements.add(specialElement); - } - - pairElement.type = LegalType::pair(ordinaryType, specialType, elementPairInfo); - pairElements.add(pairElement); - } - - // Add a field to the (pseudo-)type we are building - void addField( - IRStructField* field) - { - auto fieldType = field->getFieldType(); - - bool isSpecialField = context->isSpecialType(fieldType); - auto legalFieldType = legalizeType(context, fieldType); - - addField( - field->getKey(), - legalFieldType, - legalFieldType, - isSpecialField); - } - - LegalType getResult() - { - // If this is an empty struct, return a none type - // This helps get rid of emtpy structs that often trips up the - // downstream compiler - if (!anyOrdinary && !anySpecial && !anyComplex) - return LegalType(); - - // If we didn't see anything "special" - // then we can use the type as-is. - // we can conceivably just use the type as-is - // - if (!anyComplex) - { - return LegalType::simple(type); - } - - // If there were any "ordinary" fields along the way, - // then we need to collect them into a new `struct` type - // that represents these fields. - // - LegalType ordinaryType; - if (anyOrdinary) - { - // We are going to create an new IR `struct` type that contains - // the "ordinary" fields from the original type. Note that these - // fields may have different types from what they did before, - // because the fields themselves might have been legalized. - // - // The new type will have the same mangled name as the old one, so - // downstream code is going to need to be careful not to emit declarations - // for both of them. This should be okay, though, because the original - // type was illegal (that was the whole point) and so it shouldn't be - // referenced in the output anyway. - // - IRBuilder* builder = context->getBuilder(); - IRStructType* ordinaryStructType = builder->createStructType(); - ordinaryStructType->sourceLoc = originalStructType->sourceLoc; - - if(auto nameHintDecoration = originalStructType->findDecoration()) - { - builder->addNameHintDecoration(ordinaryStructType, nameHintDecoration->getNameOperand()); - } - - // The new struct type will appear right after the original in the IR, - // so that we can be sure any instruction that could reference the - // original can also reference the new one. - ordinaryStructType->insertAfter(originalStructType); - - // Mark the original type for removal once all the other legalization - // activity is completed. This is necessary because both the original - // and replacement type have the same mangled name, so they would - // collide. - // - // (Also, the original type wasn't legal - that was the whole point...) - context->replacedInstructions.add(originalStructType); - - for(auto ee : ordinaryElements) - { - // We will ensure that all the original fields are represented, - // although they may have different types (due to legalization). - // For fields that have *no* ordinary data, we will give them - // a dummy `void` type and rely on downstream passes to not - // actually emit declarations for those fields. - // - // (This helps keeps things simple because both the original - // and modified type will have the same number of fields, so - // we can continue to look up field layouts by index in the - // emit logic) - // - // TODO: we should scrap that, and layout lookup should just - // be based on mangled field names in all cases. - // - IRType* fieldType = ee.type; - if(!fieldType) - fieldType = context->getBuilder()->getVoidType(); - - // TODO: shallow clone of modifiers, etc. - - builder->createStructField( - ordinaryStructType, - ee.fieldKey, - fieldType); - } - - ordinaryType = LegalType::simple((IRType*) ordinaryStructType); - } - - LegalType specialType; - if (anySpecial) - { - RefPtr specialTuple = new TuplePseudoType(); - specialTuple->elements = specialElements; - specialType = LegalType::tuple(specialTuple); - } - - RefPtr pairInfo; - if (anyOrdinary && anySpecial) - { - pairInfo = new PairInfo(); - pairInfo->elements = pairElements; - } - - return LegalType::pair(ordinaryType, specialType, pairInfo); - } - -}; - -static IRType* createBuiltinGenericType( - TypeLegalizationContext* context, - IROp op, - IRType* elementType) -{ - IRInst* operands[] = { elementType }; - return context->getBuilder()->getType( - op, - 1, - operands); -} - -// Create a uniform buffer type with a given legalized -// element type. -static LegalType createLegalUniformBufferType( - TypeLegalizationContext* context, - IROp op, - LegalType legalElementType) -{ - // We will handle some of the easy/non-interesting - // cases here in the main routine, but for all - // the non-trivial cases we will dispatch to logic - // on the `context` (which may differ depending - // on what we are using legalization to accomplish). - // - switch (legalElementType.flavor) - { - default: - return context->createLegalUniformBufferType( - op, - legalElementType); - - case LegalType::Flavor::none: - return LegalType(); - - case LegalType::Flavor::simple: - { - // Easy case: we just have a simple element type, - // so we want to create a uniform buffer that wraps it. - // - // TODO: This isn't *quite* right, since it won't handle something - // like a `ParameterBlock`, but that seems like - // an unlikely case in practice. - // - return LegalType::simple(createBuiltinGenericType( - context, - op, - legalElementType.getSimple())); - } - break; - - case LegalType::Flavor::implicitDeref: - { - // This is actually an annoying case, because - // we are being asked to convert, e.g.,: - // - // cbuffer Foo { ParameterBlock bar; } - // - // into the equivalent of: - // - // cbuffer Foo { Bar bar; } - // - // Which would really require a new `LegalType` that - // would reprerent a resource type with a modified - // element type. - // - // I'm going to attempt to hack this for now. - return LegalType::implicitDeref(createLegalUniformBufferType( - context, - op, - legalElementType.getImplicitDeref()->valueType)); - } - break; - } -} - -// Create a uniform buffer type with a given legalized element type, -// under the assumption that we are doing resource-based type legalization. -// -LegalType createLegalUniformBufferTypeForResources( - TypeLegalizationContext* context, - IROp op, - LegalType legalElementType) -{ - switch (legalElementType.flavor) - { - case LegalType::Flavor::simple: - { - // Seeing a simple type here means that it must be a - // "special" type (a resource type or array thereof) - // because otherwise the catch-all behavior in - // `createLegalUniformBufferType()` would have handled it. - // - // This case is the same as what we do for tuple elements below. - // - return LegalType::implicitDeref(legalElementType); - } - - case LegalType::Flavor::pair: - { - auto pairType = legalElementType.getPair(); - - // The pair has both an "ordinary" and a "special" - // side, where the ordinary side is just plain data - // that we can put in a constant buffer type without - // any problems. The special side will (recursively) - // contain any resource-type fields that were nested - // in the constant buffer, and we'll need to - // treat those as resources that stand alongside - // the original constant buffer. - // - // We can start with the ordinary side, which we - // just want to wrap up in an ordinary uniform - // buffer with the appropriate `op`, so that case - // is easy: - // - auto ordinaryType = createLegalUniformBufferType( - context, - op, - pairType->ordinaryType); - - // For the special side, we really just want to turn - // a special field of type `R` into a value of type - // `R`, and the main detail we have to be aware of - // is that any use sites for the original buffer/block - // will include a dereferencing step to get from - // the block to this field, so we need to add - // something to the type structure to account for - // that step. - // - // We handle that issue by wrapping the special - // part of the type in an `implicitDeref` wrapper, - // which indicates that we logically have `SomePtr` - // but we actually just have `R`, and any attempt to - // load from or otherwise indirect through that pointer - // will turn into a plain old reference to the `R` value. - // - auto specialType = LegalType::implicitDeref(pairType->specialType); - - // Once we've wrapped up both the ordinary and special - // sides suitably, we tie them back together in a pair - // and make that be the legalized type of the result. - // - return LegalType::pair(ordinaryType, specialType, pairType->pairInfo); - } - - case LegalType::Flavor::tuple: - { - // A tuple type always represents purely "special" data, - // which in this case means resources. - // - // As in the `pair` case, the main thing we have to - // take into account is that each of the entries in the - // tuple itself (e.g., a value of type `R`) and the code - // that uses the legalized buffer type will expect a - // `ConstantBuffer` or at least `SomePtrType`. - // - // We will construct a new tuple type that wraps each - // of the element types in an `implicitDeref` to - // account for the different in levels of indirection. - // - // TODO: This seems odd, because we *should* be able to - // just wrap the whole thing in an `implicitDeref` and - // have done. We should investigate why this roundabout - // way of doing things was ever necessary. - - auto elementPseudoTupleType = legalElementType.getTuple(); - RefPtr bufferPseudoTupleType = new TuplePseudoType(); - - for (auto ee : elementPseudoTupleType->elements) - { - TuplePseudoType::Element newElement; - - newElement.key = ee.key; - newElement.type = LegalType::implicitDeref(ee.type); - - bufferPseudoTupleType->elements.add(newElement); - } - - return LegalType::tuple(bufferPseudoTupleType); - } - break; - - default: - SLANG_UNEXPECTED("unhandled legal type flavor"); - UNREACHABLE_RETURN(LegalType()); - break; - } -} - -// Legalizing a uniform buffer/block type for existentials is -// more interesting, because we don't actually want to push -// the "special" fields out of the buffer entirely (as we -// do for resources), and instead we just want to place -// them in the buffer *after* all the ordinary data. -// -// In order to accomplish this we need a way to emit a -// constant buffer with a new element type, and then -// "wrap" that constant buffer so that it looks like -// something that matches the legalization of the original -// element type. -// -// As a concrete example, suppose we have: -// -// struct Params { ExistentialBox f; int x; ExistentialBox b; }; -// ConstantBuffer gParams; -// -// The legalized form of `Params` will be something like: -// -// Pair( -// /* ordinary: */ struct Params_Ordinary { int x; }, -// /* special: */ Tuple( -// f -> ImplicitDeref(Foo), -// b -> ImplicitDeref(Bar))) -// -// We need to be able to splat that all out into a single -// structure declaration like: -// -// struct Params_Reordered -// { -// Params_Ordinary ordinary; -// Foo f; -// Bar b; -// } -// -// That allows us to declare: -// -// ConstantBuffer gParams; -// -// That gets the in-memory layout of things correct for the -// way we are defining existential value slots to work. -// The challenge is that elsewehere in the code there are -// operations like `gParams.x` need to now refer to -// `gParams.ordinary.x`. Furthermore, even for something like -// `f` that seems fine in the example above, we have lost -// a level of indirection, so that where we had `load(gParams.f)` -// we now want just `gParams.f`. -// -// The solution is to take `gParams` as soon as it is declared -// and wrap it up as a new value: -// -// pair( -// /* ordinary: */ gParams.ordinary, -// /* special: */ tuple( -// f -> implicitDeref(gParams.f), -// b -> implicitDeref(gParams.b))) -// -// -// Let's begin by just defining a function that can take -// a `LegalType` and turn it into zero or more field -// declarations, and return enough tracking information -// for us to be able to reconstruct a value like the above. -// -LegalElementWrapping declareStructFields( - TypeLegalizationContext* context, - IRStructType* structType, - LegalType fieldType) -{ - // TODO: We should eventually thread through some kind - // of "name hint" that can be used to give the generated - // fields more useful names. - - switch(fieldType.flavor) - { - case LegalType::Flavor::none: - return LegalElementWrapping::makeVoid(); - - case LegalType::Flavor::simple: - { - auto simpleFieldType = fieldType.getSimple(); - auto builder = context->getBuilder(); - auto fieldKey = builder->createStructKey(); - builder->createStructField(structType, fieldKey, simpleFieldType); - return LegalElementWrapping::makeSimple(fieldKey, simpleFieldType); - } - - case LegalType::Flavor::implicitDeref: - { - auto subField = declareStructFields(context, structType, fieldType.getImplicitDeref()->valueType); - return LegalElementWrapping::makeImplicitDeref(subField); - } - - case LegalType::Flavor::pair: - { - auto pairType = fieldType.getPair(); - auto ordinaryField = declareStructFields(context, structType, pairType->ordinaryType); - auto specialField = declareStructFields(context, structType, pairType->specialType); - return LegalElementWrapping::makePair( - ordinaryField, - specialField, - pairType->pairInfo); - } - - case LegalType::Flavor::tuple: - { - auto tupleType = fieldType.getTuple(); - - RefPtr obj = new TupleLegalElementWrappingObj(); - for( auto ee : tupleType->elements ) - { - TupleLegalElementWrappingObj::Element element; - element.key = ee.key; - element.field = declareStructFields(context, structType, ee.type); - obj->elements.add(element); - } - return LegalElementWrapping::makeTuple(obj); - } - - default: - SLANG_UNEXPECTED("unhandled legal type flavor"); - UNREACHABLE_RETURN(LegalElementWrapping::makeVoid()); - break; - } -} - -LegalType createLegalUniformBufferTypeForExistentials( - TypeLegalizationContext* context, - IROp op, - LegalType legalElementType) -{ - auto builder = context->getBuilder(); - - // In order to wrap up all the data in `legalElementType`, - // will create a fresh `struct` type and then declare - // fields in it that are sufficient to hold that data - // in `legalElementType`. - // - auto structType = builder->createStructType(); - auto elementWrapping = declareStructFields( - context, structType, legalElementType); - - // Because the `structType` is an ordinary IR type - // (not a `LegalType`) we can go ahead and create an - // IR uniform buffer type that wraps it. - // - auto bufferType = createBuiltinGenericType( - context, - op, - structType); - - // The `elementWrapping` computed when we declared all - // the `struct` fields tells us how to get from the - // actual fields declared in the structure type to a - // `LegalVal` with the right shape for what users of - // the buffer will expect. We record both the underlying - // IR buffer type and that wrapping information into - // the resulting `LegalType` so that we can use it - // when declaring variables of this type. - // - return LegalType::makeWrappedBuffer(bufferType, elementWrapping); -} - -static LegalType createLegalUniformBufferType( - TypeLegalizationContext* context, - IRUniformParameterGroupType* uniformBufferType, - LegalType legalElementType) -{ - return createLegalUniformBufferType( - context, - uniformBufferType->op, - legalElementType); -} - -// Create a pointer type with a given legalized value type. -static LegalType createLegalPtrType( - TypeLegalizationContext* context, - IROp op, - LegalType legalValueType) -{ - switch (legalValueType.flavor) - { - case LegalType::Flavor::none: - return LegalType(); - - case LegalType::Flavor::simple: - { - // Easy case: we just have a simple element type, - // so we want to create a uniform buffer that wraps it. - return LegalType::simple(createBuiltinGenericType( - context, - op, - legalValueType.getSimple())); - } - break; - - case LegalType::Flavor::implicitDeref: - { - // We are being asked to create a pointer type to something - // that is implicitly dereferenced, meaning we had: - // - // Ptr(PtrLike(T)) - // - // and now are being asked to make: - // - // Ptr(implicitDeref(LegalT)) - // - // So it seems like we can just create: - // - // implicitDeref(Ptr(LegalT)) - // - // and nobody should really be able to tell the difference, right? - // - // TODO: invetigate whether there are situations where this - // will matter. - return LegalType::implicitDeref(createLegalPtrType( - context, - op, - legalValueType.getImplicitDeref()->valueType)); - } - break; - - case LegalType::Flavor::pair: - { - // We just need to pointer-ify both sides of the pair. - auto pairType = legalValueType.getPair(); - - auto ordinaryType = createLegalPtrType( - context, - op, - pairType->ordinaryType); - auto specialType = createLegalPtrType( - context, - op, - pairType->specialType); - - return LegalType::pair(ordinaryType, specialType, pairType->pairInfo); - } - - case LegalType::Flavor::tuple: - { - // Wrap each of the tuple elements up as a pointer. - auto valuePseudoTupleType = legalValueType.getTuple(); - - RefPtr ptrPseudoTupleType = new TuplePseudoType(); - - // Wrap all the pseudo-tuple elements with `implicitDeref`, - // since they used to be inside a tuple, but aren't any more. - for (auto ee : valuePseudoTupleType->elements) - { - TuplePseudoType::Element newElement; - - newElement.key = ee.key; - newElement.type = createLegalPtrType( - context, - op, - ee.type); - - ptrPseudoTupleType->elements.add(newElement); - } - - return LegalType::tuple(ptrPseudoTupleType); - } - break; - - default: - SLANG_UNEXPECTED("unknown legal type flavor"); - UNREACHABLE_RETURN(LegalType()); - break; - } -} - -struct LegalTypeWrapper -{ - virtual LegalType wrap(TypeLegalizationContext* context, IRType* type) = 0; -}; - -struct ArrayLegalTypeWrapper : LegalTypeWrapper -{ - IRArrayTypeBase* arrayType; - - LegalType wrap(TypeLegalizationContext* context, IRType* type) - { - return LegalType::simple(context->getBuilder()->getArrayTypeBase( - arrayType->op, - type, - arrayType->getElementCount())); - } -}; - -struct BuiltinGenericLegalTypeWrapper : LegalTypeWrapper -{ - IROp op; - - LegalType wrap(TypeLegalizationContext* context, IRType* type) - { - return LegalType::simple(createBuiltinGenericType( - context, - op, - type)); - } -}; - - -struct ImplicitDerefLegalTypeWrapper : LegalTypeWrapper -{ - LegalType wrap(TypeLegalizationContext*, IRType* type) - { - return LegalType::implicitDeref(LegalType::simple(type)); - } -}; - -static LegalType wrapLegalType( - TypeLegalizationContext* context, - LegalType legalType, - LegalTypeWrapper* ordinaryWrapper, - LegalTypeWrapper* specialWrapper) -{ - switch (legalType.flavor) - { - case LegalType::Flavor::none: - return LegalType(); - - case LegalType::Flavor::simple: - { - return ordinaryWrapper->wrap(context, legalType.getSimple()); - } - break; - - case LegalType::Flavor::implicitDeref: - { - return LegalType::implicitDeref(wrapLegalType( - context, - legalType, - ordinaryWrapper, - specialWrapper)); - } - break; - - case LegalType::Flavor::pair: - { - // We just need to pointer-ify both sides of the pair. - auto pairType = legalType.getPair(); - - auto ordinaryType = wrapLegalType( - context, - pairType->ordinaryType, - ordinaryWrapper, - ordinaryWrapper); - auto specialType = wrapLegalType( - context, - pairType->specialType, - specialWrapper, - specialWrapper); - - return LegalType::pair(ordinaryType, specialType, pairType->pairInfo); - } - - case LegalType::Flavor::tuple: - { - // Wrap each of the tuple elements up as a pointer. - auto tupleType = legalType.getTuple(); - - RefPtr resultTupleType = new TuplePseudoType(); - - // Wrap all the pseudo-tuple elements with `implicitDeref`, - // since they used to be inside a tuple, but aren't any more. - for (auto ee : tupleType->elements) - { - TuplePseudoType::Element element; - - element.key = ee.key; - element.type = wrapLegalType( - context, - ee.type, - ordinaryWrapper, - specialWrapper); - - resultTupleType->elements.add(element); - } - - return LegalType::tuple(resultTupleType); - } - break; - - default: - SLANG_UNEXPECTED("unknown legal type flavor"); - UNREACHABLE_RETURN(LegalType()); - break; - } -} - -// Legalize a type, including any nested types -// that it transitively contains. -LegalType legalizeTypeImpl( - TypeLegalizationContext* context, - IRType* type) -{ - if(!type) - return LegalType::simple(nullptr); - - context->builder->setInsertBefore(type); - - if (auto uniformBufferType = as(type)) - { - // We have one of: - // - // ConstantBuffer - // TextureBuffer - // ParameterBlock - // - // or some other pointer-like type that represents uniform - // parameters. We need to pull any resource-type fields out - // of it, but leave non-resource fields where they are. - // - // As a special case, if the type contains *no* uniform data, - // we'll want to completely eliminate the uniform/ordinary - // part. - - auto originalElementType = uniformBufferType->getElementType(); - - // Legalize the element type to see what we are working with. - auto legalElementType = legalizeType(context, - originalElementType); - - // As a bit of a corner case, if the user requested something - // like `ConstantBuffer` the element type would - // legalize to a "simple" type, and that would be interpreted - // as an *ordinary* type, but we really need to notice the - // case when the element type is simple, but *special*. - // - if( context->isSpecialType(originalElementType) ) - { - // Anything that has a special element type needs to - // be handled by the pass-specific logic in the context. - // - return context->createLegalUniformBufferType( - uniformBufferType->op, - legalElementType); - } - - // Note that even when legalElementType.flavor == Simple - // we still need to create a new uniform buffer type - // from `legalElementType` instead of `type` - // because the `legalElementType` may still differ from `type` - // if, e.g., `type` contains empty structs. - return createLegalUniformBufferType( - context, - uniformBufferType, - legalElementType); - - } - else if (isResourceType(type)) - { - // We assume that any resource types not handled above - // are legal as-is. - return LegalType::simple(type); - } - else if (as(type)) - { - return LegalType::simple(type); - } - else if (as(type)) - { - return LegalType::simple(type); - } - else if (as(type)) - { - return LegalType::simple(type); - } - else if( auto existentialPtrType = as(type)) - { - // We want to transform an `ExistentialBox` into just - // a `T`, with an `iplicitDeref` to make sure that any - // pointer-related operations on the box Just Work. - // - // Note: the logic here doesn't have to deal with moving - // existential-type fields to the end of their outer - // type(s) because that is mostly dealt with in the - // case for struct types below. - // - auto legalValueType = legalizeType(context, existentialPtrType->getValueType()); - return LegalType::implicitDeref(legalValueType); - } - else if (auto ptrType = as(type)) - { - auto legalValueType = legalizeType(context, ptrType->getValueType()); - return createLegalPtrType(context, ptrType->op, legalValueType); - } - else if(auto structType = as(type)) - { - // Look at the (non-static) fields, and - // see if anything needs to be cleaned up. - // The things that need to be "cleaned up" for - // our purposes are: - // - // - Fields of resource type, or any other future - // type we run into that isn't allowed in - // aggregates for at least some targets - // - // - Fields with types that themselves had to - // get legalized. - // - // If we don't run into any of these, we - // can just use the type as-is. Hooray! - // - // Otherwise, we are effectively going to split - // the type apart and create a `TuplePseudoType`. - // Every field of the original type will be - // represented as an element of this pseudo-type. - // Each element will record its `LegalType`, - // and the original field that it was created from. - // An element will also track whether it contains - // any "ordinary" data, and if so, it will remember - // an element index in a real (AST-level, non-pseudo) - // `TupleType` that is used to bundle together - // such fields. - // - // Storing all the simple fields together like this - // obviously adds complexity to the legalization - // pass, but it has important benefits: - // - // - It avoids creating functions with a very large - // number of parameters (when passing a structure - // with many fields), which might confuse downstream - // compilers. - // - // - It avoids applying AOS->SOA conversion to fields - // that don't actually need it, which is basically - // required if we want type layout to work. - // - // - It ensures that we can actually construct a - // constant-buffer type that wraps a legalized - // aggregate type; the ordinary fields will get - // placed inside a new constant-buffer type, - // while the special ones will get left outside. - // - - // TODO: there is a risk here that we might recursively - // invole `legalizeType` on the type that we are - // currently trying to legalize. We need to detect that - // situation somehow, by inserting a sentinel value - // into `mapTypeToLegalType` during the per-field - // legalization process, and then if we ever see that - // sentinel in a call to `legalizeType`, we need - // to construct some kind of proxy type to help resolve - // the problem. - - TupleTypeBuilder builder; - builder.context = context; - builder.type = type; - builder.originalStructType = structType; - - for (auto ff : structType->getFields()) - { - builder.addField(ff); - } - - return builder.getResult(); - } - else if(auto arrayType = as(type)) - { - auto legalElementType = legalizeType( - context, - arrayType->getElementType()); - - ArrayLegalTypeWrapper wrapper; - wrapper.arrayType = arrayType; - - return wrapLegalType( - context, - legalElementType, - &wrapper, - &wrapper); - } - - return LegalType::simple(type); -} - -LegalType legalizeType( - TypeLegalizationContext* context, - IRType* type) -{ - LegalType legalType; - if(context->mapTypeToLegalType.TryGetValue(type, legalType)) - return legalType; - - legalType = legalizeTypeImpl(context, type); - context->mapTypeToLegalType[type] = legalType; - return legalType; -} - -// - -RefPtr getDerefTypeLayout( - TypeLayout* typeLayout) -{ - if (!typeLayout) - return nullptr; - - if (auto parameterGroupTypeLayout = as(typeLayout)) - { - return parameterGroupTypeLayout->offsetElementTypeLayout; - } - - return typeLayout; -} - -RefPtr getFieldLayout( - TypeLayout* typeLayout, - IRInst* fieldKey) -{ - if (!typeLayout) - return nullptr; - - for(;;) - { - if(auto arrayTypeLayout = as(typeLayout)) - { - typeLayout = arrayTypeLayout->elementTypeLayout; - } - else if(auto parameterGroupTypeLayout = as(typeLayout)) - { - typeLayout = parameterGroupTypeLayout->offsetElementTypeLayout; - } - else - { - break; - } - } - - - if (auto structTypeLayout = as(typeLayout)) - { - // First, let's see if the field had a layout registered - // directly using its IR key. - // - RefPtr fieldLayout; - if(structTypeLayout->mapKeyToLayout.TryGetValue(fieldKey, fieldLayout)) - return fieldLayout; - - // Otherwise, fall back to doing lookup using the linkage - // attached to the key, and its mangled name. - // - auto fieldLinkage = fieldKey->findDecoration(); - if(!fieldLinkage) - return nullptr; - auto mangledFieldName = fieldLinkage->getMangledName(); - - // In this case we fall back to a linear search over the fields. - // - for(auto ff : structTypeLayout->fields) - { - if(mangledFieldName == getMangledName(ff->varDecl.getDecl()).getUnownedSlice() ) - { - return ff; - } - } - } - - return nullptr; -} - -RefPtr createSimpleVarLayout( - SimpleLegalVarChain* varChain, - TypeLayout* typeLayout) -{ - if (!typeLayout) - return nullptr; - - // We need to construct a layout for the new variable - // that reflects both the type we have given it, as - // well as all the offset information that has accumulated - // along the chain of parent variables. - - // TODO: this logic needs to propagate through semantics... - - RefPtr varLayout = new VarLayout(); - varLayout->typeLayout = typeLayout; - - // For most resource kinds, the register index/space to use should - // be the sum along the entire chain of variables. - // - // For example, if we had input: - // - // struct S { Texture2D a; Texture2D b; }; - // S s : register(t10); - // - // And we were generating a stand-alone variable for `s.b`, then - // we'd need to add the offset for `b` (1 texture register), to - // the offset for `s` (10 texture registers) to get the final - // binding to apply. - // - for (auto rr : typeLayout->resourceInfos) - { - auto resInfo = varLayout->findOrAddResourceInfo(rr.kind); - - for (auto vv = varChain; vv; vv = vv->next) - { - if (auto parentResInfo = vv->varLayout->FindResourceInfo(rr.kind)) - { - resInfo->index += parentResInfo->index; - resInfo->space += parentResInfo->space; - } - } - } - - // As a special case, if the leaf variable doesn't hold an entry for - // `RegisterSpace`, but at least one declaration in the chain *does*, - // then we want to make sure that we add such an entry. - if (!varLayout->FindResourceInfo(LayoutResourceKind::RegisterSpace)) - { - // Sum up contributions from all parents. - UInt space = 0; - for (auto vv = varChain; vv; vv = vv->next) - { - if (auto parentResInfo = vv->varLayout->FindResourceInfo(LayoutResourceKind::RegisterSpace)) - { - space += parentResInfo->index; - } - } - - // If there were non-zero contributions, then add an entry to represent them. - if (space) - { - varLayout->findOrAddResourceInfo(LayoutResourceKind::RegisterSpace)->index = space; - } - } - - return varLayout; -} - - -RefPtr createVarLayout( - LegalVarChain const& varChain, - TypeLayout* typeLayout) -{ - if(!typeLayout) - return nullptr; - - auto varLayout = createSimpleVarLayout(varChain.primaryChain, typeLayout); - - if(auto pendingDataTypeLayout = typeLayout->pendingDataTypeLayout) - { - varLayout->pendingVarLayout = createSimpleVarLayout(varChain.pendingChain, typeLayout); - } - - return varLayout; -} - -// - -// TODO(tfoley): The code captured here is the logic that used to be -// applied to decide whether or not to desugar aggregate types that -// contain resources. Right now the implementation will *always* legalize -// away such types (since the IR always does this), while the AST-to-AST -// pass would only do it if required (according to the tests below). -// -// For right now this is an academic distinction, since the only project -// using Slang right now enables this tansformation unconditionally, but -// we probably need to re-parent this code back into the `TypeLegalizationContext` -// somewhere. -#if 0 - -bool shouldDesugarTupleTypes = false; -if (getTarget() == CodeGenTarget::GLSL) -{ - // Always desugar this stuff for GLSL, since it doesn't - // support nesting of resources in structs. - // - // TODO: Need a way to make this more fine-grained to - // handle cases where a nested member might be allowed - // due to, e.g., bindless textures. - shouldDesugarTupleTypes = true; -} -else if( shared->compileRequest->compileFlags & SLANG_COMPILE_FLAG_SPLIT_MIXED_TYPES ) -{ - // If the user is directly asking us to do this transformation, - // then obviously we need to do it. - // - // TODO: The way this is defined here means it will even apply to user - // HLSL code (not just code written in Slang). We may want to - // reconsider that choice, and only split things that originated in Slang. - // - shouldDesugarTupleTypes = true; -} - -#endif - -} diff --git a/source/slang/legalize-types.h b/source/slang/legalize-types.h deleted file mode 100644 index b35ed9d3c..000000000 --- a/source/slang/legalize-types.h +++ /dev/null @@ -1,678 +0,0 @@ -// legalize-types.h -#ifndef SLANG_LEGALIZE_TYPES_H_INCLUDED -#define SLANG_LEGALIZE_TYPES_H_INCLUDED - -// This file and `legalize-types.cpp` implement the core -// logic for taking a `Type` as produced by the front-end, -// and turning it into a suitable representation for use -// on a particular back-end. -// -// The main work applies to aggregate (e.g., `struct`) types, -// since various targets have rules about what is and isn't -// allowed in an aggregate (or where aggregates are allowed -// to be used). -// -// We might completely replace an aggregate `Type` with a -// "pseudo-type" that is just the enumeration of its field -// types (sort of a tuple type) so that a variable declared -// with the original type should be transformed into a -// bunch of individual variables. -// -// Alternatively, we might replace an aggregate type, where -// only *some* of the fields are illegal with a combination -// of an aggregate (containing the legal/legalized fields), -// and some extra tuple-ified fields. - -#include "../core/basic.h" -#include "ir-insts.h" -#include "syntax.h" -#include "type-layout.h" -#include "name.h" - -namespace Slang -{ - -struct IRBuilder; - -struct LegalTypeImpl : RefObject -{ -}; -struct ImplicitDerefType; -struct TuplePseudoType; -struct PairPseudoType; -struct PairInfo; -struct LegalElementWrapping; -struct WrappedBufferPseudoType; - - /// A flavor for types or values that arise during legalization. -enum class LegalFlavor -{ - /// Nothing: an empty type or value. Equivalent to `void`. - none, - - /// A simple type/value that can be represented as an `IRType*` or `IRInst*` - simple, - - /// Logically, a pointer-like type/value, but represented as the type/value being pointed type. - implicitDeref, - - /// A compound type/value made up of the constituent fields of some original value. - tuple, - - /// A type/value that was split into "ordinary" and "special" parts. - pair, - - /// A type/value that represents, e.g., `ConstantBuffer` where `T` needed legalization. - wrappedBuffer, -}; - -struct LegalType -{ - typedef LegalFlavor Flavor; - - Flavor flavor = Flavor::none; - RefPtr obj; - IRType* irType; - - static LegalType simple(IRType* type) - { - LegalType result; - result.flavor = Flavor::simple; - result.irType = type; - return result; - } - - IRType* getSimple() const - { - SLANG_ASSERT(flavor == Flavor::simple); - return irType; - } - - static LegalType implicitDeref( - LegalType const& valueType); - - RefPtr getImplicitDeref() const - { - SLANG_ASSERT(flavor == Flavor::implicitDeref); - return obj.as(); - } - - static LegalType tuple( - RefPtr tupleType); - - RefPtr getTuple() const - { - SLANG_ASSERT(flavor == Flavor::tuple); - return obj.as(); - } - - static LegalType pair( - RefPtr pairType); - - static LegalType pair( - LegalType const& ordinaryType, - LegalType const& specialType, - RefPtr pairInfo); - - RefPtr getPair() const - { - SLANG_ASSERT(flavor == Flavor::pair); - return obj.as(); - } - - static LegalType makeWrappedBuffer( - IRType* simpleType, - LegalElementWrapping const& elementInfo); - - RefPtr getWrappedBuffer() const - { - SLANG_ASSERT(flavor == Flavor::wrappedBuffer); - return obj.as(); - } -}; - -struct LegalElementWrappingObj : RefObject -{ -}; - -struct SimpleLegalElementWrappingObj; -struct ImplicitDerefLegalElementWrappingObj; -struct PairLegalElementWrappingObj; -struct TupleLegalElementWrappingObj; - - /// Information on how the element type of a buffer needs to be wrapped. -struct LegalElementWrapping -{ - typedef LegalFlavor Flavor; - - Flavor flavor; - RefPtr obj; - - static LegalElementWrapping makeVoid(); - static LegalElementWrapping makeSimple(IRStructKey* key, IRType* type); - static LegalElementWrapping makeImplicitDeref(LegalElementWrapping const& field); - static LegalElementWrapping makePair( - LegalElementWrapping const& ordinary, - LegalElementWrapping const& special, - PairInfo* pairInfo); - static LegalElementWrapping makeTuple(TupleLegalElementWrappingObj* obj); - - RefPtr getSimple() const; - RefPtr getImplicitDeref() const; - RefPtr getPair() const; - RefPtr getTuple() const; -}; - -struct SimpleLegalElementWrappingObj : LegalElementWrappingObj -{ - IRStructKey* key; - IRType* type; -}; - -struct ImplicitDerefLegalElementWrappingObj : LegalElementWrappingObj -{ - LegalElementWrapping field; -}; - -struct PairLegalElementWrappingObj : LegalElementWrappingObj -{ - LegalElementWrapping ordinary; - LegalElementWrapping special; - RefPtr pairInfo; -}; - -struct TupleLegalElementWrappingObj : LegalElementWrappingObj -{ - struct Element - { - IRStructKey* key; - LegalElementWrapping field; - }; - - List elements; -}; - -// Represents the pseudo-type of a type that is pointer-like -// (and thus requires dereferencing, even if implicit), but -// was legalized to just use the type of the pointed-type value. -// -// The two cases where this comes up are: -// -// 1. When we have a type like `ConstantBuffer` that -// implies a level of indirection, but need to legalize it to just -// `Texture2D`, which eliminates that indirection. -// -// 2. When we have a type like `ExistentialBox` that will -// become just a `Foo` field, but which needs to be allocated -// out-of-line from the rest of its enclosing type. -// -struct ImplicitDerefType : LegalTypeImpl -{ - LegalType valueType; -}; - -// Represents the pseudo-type for a compound type -// that had to be broken apart because it contained -// one or more fields of types that shouldn't be -// allowed in aggregates. -// -// A tuple pseduo-type will have an element for -// each field of the original type, that represents -// the legalization of that field's type. -// -// It optionally also contains an "ordinary" type -// that packs together any per-field data that -// itself has (or contains) an ordinary type. -struct TuplePseudoType : LegalTypeImpl -{ - // Represents one element of the tuple pseudo-type - struct Element - { - // The field that this element replaces - IRStructKey* key; - - // The legalized type of the element - LegalType type; - }; - - // All of the elements of the tuple pseduo-type. - List elements; -}; - -struct IRStructKey; - -struct PairInfo : RefObject -{ - typedef unsigned int Flags; - enum - { - kFlag_hasOrdinary = 0x1, - kFlag_hasSpecial = 0x2, - }; - - - struct Element - { - // The original field the element represents - IRStructKey* key; - - // The conceptual type of the field. - // If both the `hasOrdinary` and - // `hasSpecial` bits are set, then - // this is expected to be a - // `LegalType::Flavor::pair` - LegalType type; - - // Is the value represented on - // the ordinary side, the special - // side, or both? - Flags flags; - - // If the type of this element is - // itself a pair type (that is, - // it both `hasOrdinary` and `hasSpecial`) - // then this is the `PairInfo` for that - // pair type: - RefPtr fieldPairInfo; - }; - - // For a pair type or value, we need to track - // which fields are on which side(s). - List elements; - - Element* findElement(IRStructKey* key) - { - for (auto& ee : elements) - { - if(ee.key == key) - return ⅇ - } - return nullptr; - } -}; - -struct PairPseudoType : LegalTypeImpl -{ - // Any field(s) with ordinary types will - // get captured here, usually as a single - // `simple` or `implicitDeref` type. - LegalType ordinaryType; - - // Any fields with "special" (not ordinary) - // types will get captured here (usually - // with a tuple). - LegalType specialType; - - // The `pairInfo` field helps to tell us which members - // of the original aggregate type appear on which side(s) - // of the new pair type. - RefPtr pairInfo; -}; - - -struct WrappedBufferPseudoType : LegalTypeImpl -{ - // The actual IR type that was used for the buffer. - IRType* simpleType; - - // Adjustments that need to be made when fetching - // an element from this buffer type. - // - LegalElementWrapping elementInfo; -}; - -// - -RefPtr getDerefTypeLayout( - TypeLayout* typeLayout); - -RefPtr getFieldLayout( - TypeLayout* typeLayout, - IRInst* fieldKey); - - /// Represents a "chain" of variables leading to some leaf field. - /// - /// Consider code like: - /// - /// struct Branch { int leaf; } - /// struct Tree { Branch left; Branch right; } - /// cbuffer Forest - /// { - /// int maxTreeHeight; - /// Tree tree; - /// } - /// - /// If we ask "what is the offset of `leaf`" the simple answer is zero, - /// but sometimes we are talking about `Forest.tree.right.leaf` which - /// will have a very different offset. In Slang parameters can consume - /// various (and multiple) resource kinds, so a single offset can't - /// be tunneled down through most recursive procedures. - /// - /// Instead we use a "chain" that works up through the stack, and - /// records the path from leaf field like `leaf` up to whatever - /// variable is the root for the curent operation. - /// - /// Operations like computing an offset can then be encoded by - /// starting with zero and then walking up the chain and adding in - /// offsets as encountered. - /// -struct SimpleLegalVarChain -{ - // The next link up the chain, or null if this is the end. - SimpleLegalVarChain* next = nullptr; - - // The layout for the variable at this link in thain. - VarLayout* varLayout = nullptr; -}; - - /// A "chain" of variable declarations that can handle both primary and "pending" data. - /// - /// In the presence of interface-type fields, a single variable may - /// have data that sits in two distinct allocations, and may have - /// `VarLayout`s that represent offseting into each of those - /// allocations. - /// - /// A `LegalVarChain` tracks two distinct `SimpleVarChain`s: one for - /// the primary/ordinary data allocation, and one for any pending - /// data. - /// - /// It is okay if the primary/pending chains have different numbers - /// of links in them. - /// - /// Offsets for particular resource kinds in the primary or pending - /// data allocation can be queried on the appropriate sub-chain. - /// -struct LegalVarChain -{ - // The chain of variables that represents the primary allocation. - SimpleLegalVarChain* primaryChain = nullptr; - - // The chain of variables that represents the pending allocation. - SimpleLegalVarChain* pendingChain = nullptr; - - // If the primary chain is non-empty, gets the variable at the leaf. - DeclRef getLeafVarDeclRef() const - { - if(!primaryChain) - return DeclRef(); - - return primaryChain->varLayout->varDecl; - } -}; - - /// RAII type for adding a link to a `LegalVarChain` as needed. - /// - /// This type handles the bookkeeping for creating a `LegalVarChain` - /// that links in one more variable. It will add a link to each of - /// the primary and pending sub-chains if and only if there is non-null - /// layout information for the primary/pending case. - /// - /// Typical usage in a recursive function is: - /// - /// void someRecursiveFunc(LegalVarChain const& outerChain, ...) - /// { - /// if(auto subVar = needToRecurse(...)) - /// { - /// LegalVarChainLink subChain(outerChain, subVar); - /// someRecursiveFunc(subChain, ...); - /// } - /// ... - /// } - /// -struct LegalVarChainLink : LegalVarChain -{ - /// Default constructor: yields an empty chain. - LegalVarChainLink() - { - } - - /// Copy constructor: yields a copy of the `parent` chain. - LegalVarChainLink(LegalVarChain const& parent) - : LegalVarChain(parent) - {} - - /// Construct a chain that extends `parent` with `varLayout`, if it is non-null. - LegalVarChainLink(LegalVarChain const& parent, VarLayout* varLayout) - : LegalVarChain(parent) - { - if( varLayout ) - { - primaryLink.next = parent.primaryChain; - primaryLink.varLayout = varLayout; - primaryChain = &primaryLink; - - if( auto pendingVarLayout = varLayout->pendingVarLayout ) - { - pendingLink.next = parent.pendingChain; - pendingLink.varLayout = pendingVarLayout; - pendingChain = &pendingLink; - } - } - } - - SimpleLegalVarChain primaryLink; - SimpleLegalVarChain pendingLink; -}; - -RefPtr createVarLayout( - LegalVarChain const& varChain, - TypeLayout* typeLayout); - -RefPtr createSimpleVarLayout( - SimpleLegalVarChain* varChain, - TypeLayout* typeLayout); - -// -// The result of legalizing an IR value will be -// represented with the `LegalVal` type. It is exposed -// in this header (rather than kept as an implementation -// detail, because the AST-based legalization logic needs -// a way to find the post-legalization version of a -// global name). -// -// TODO: We really shouldn't have this structure exposed, -// and instead should really be constructing AST-side -// `LegalExpr` values on-demand whenever we legalize something -// in the IR that will need to be used by the AST, and then -// store *those* in a map indexed in mangled names. -// - -struct LegalValImpl : RefObject -{ -}; -struct TuplePseudoVal; -struct PairPseudoVal; -struct WrappedBufferPseudoVal; - -struct LegalVal -{ - typedef LegalFlavor Flavor; - - Flavor flavor = Flavor::none; - RefPtr obj; - IRInst* irValue = nullptr; - - static LegalVal simple(IRInst* irValue) - { - LegalVal result; - result.flavor = Flavor::simple; - result.irValue = irValue; - return result; - } - - IRInst* getSimple() const - { - SLANG_ASSERT(flavor == Flavor::simple); - return irValue; - } - - static LegalVal tuple(RefPtr tupleVal); - - RefPtr getTuple() const - { - SLANG_ASSERT(flavor == Flavor::tuple); - return obj.as(); - } - - static LegalVal implicitDeref(LegalVal const& val); - LegalVal getImplicitDeref(); - - static LegalVal pair(RefPtr pairInfo); - static LegalVal pair( - LegalVal const& ordinaryVal, - LegalVal const& specialVal, - RefPtr pairInfo); - - RefPtr getPair() const - { - SLANG_ASSERT(flavor == Flavor::pair); - return obj.as(); - } - - static LegalVal wrappedBuffer( - LegalVal const& baseVal, - LegalElementWrapping const& elementInfo); - - RefPtr getWrappedBuffer() const - { - SLANG_ASSERT(flavor == Flavor::wrappedBuffer); - return obj.as(); - } -}; - -struct TuplePseudoVal : LegalValImpl -{ - struct Element - { - IRStructKey* key; - LegalVal val; - }; - - List elements; -}; - -struct PairPseudoVal : LegalValImpl -{ - LegalVal ordinaryVal; - LegalVal specialVal; - - // The info to tell us which fields - // are on which side(s) - RefPtr pairInfo; -}; - -struct ImplicitDerefVal : LegalValImpl -{ - LegalVal val; -}; - -struct WrappedBufferPseudoVal : LegalValImpl -{ - LegalVal base; - LegalElementWrapping elementInfo; -}; - -// - - /// Context that drives type legalization - /// - /// This type is an abstract base class, and there are - /// customization points that a concrete pass needs to - /// override (e.g., to specify what needs to be legalized). -struct IRTypeLegalizationContext -{ - Session* session; - IRModule* module; - IRBuilder* builder; - - SharedIRBuilder sharedBuilderStorage; - IRBuilder builderStorage; - - IRTypeLegalizationContext( - IRModule* inModule); - - // When inserting new globals, put them before this one. - IRInst* insertBeforeGlobal = nullptr; - - // When inserting new parameters, put them before this one. - IRParam* insertBeforeParam = nullptr; - - Dictionary mapValToLegalVal; - - IRVar* insertBeforeLocalVar = nullptr; - - // store instructions that have been replaced here, so we can free them - // when legalization has done - List replacedInstructions; - - Dictionary mapTypeToLegalType; - - IRBuilder* getBuilder() { return builder; } - - /// Customization point to decide what types are "special." - /// - /// When legalizing a `struct` type, any fields that have "special" - /// types will get moved out of the `struc` itself. - virtual bool isSpecialType(IRType* type) = 0; - - /// Customization point to construct uniform-buffer/block types. - /// - /// This function will only be called if `legalElementType` is - /// somehow non-trivial. - /// - virtual LegalType createLegalUniformBufferType( - IROp op, - LegalType legalElementType) = 0; -}; - -// This typedef exists to support pre-existing code from when -// `IRTypeLegalizationContext` and `TypeLegalizationContext` were -// two different types that had to coordinate. -typedef struct IRTypeLegalizationContext TypeLegalizationContext; - -LegalType legalizeType( - TypeLegalizationContext* context, - IRType* type); - -/// Try to find the module that (recursively) contains a given declaration. -ModuleDecl* findModuleForDecl( - Decl* decl); - - /// Create a uniform buffer type suitable for resource legalization. - /// - /// This will allocate a real buffer for the ordinary data (if any), - /// and leave the special data (if any) as a tuple. - /// -LegalType createLegalUniformBufferTypeForResources( - TypeLegalizationContext* context, - IROp op, - LegalType legalElementType); - - /// Create a uniform buffer type suitable for existential legalization. - /// - /// This will allocate a real uniform buffer for *all* the data, by - /// declaring an intermediate `struct` type to hold the ordinary and - /// special (existential-box) fields, if required. - /// -LegalType createLegalUniformBufferTypeForExistentials( - TypeLegalizationContext* context, - IROp op, - LegalType legalElementType); - - - - -void legalizeExistentialTypeLayout( - IRModule* module, - DiagnosticSink* sink); - -void legalizeResourceTypes( - IRModule* module, - DiagnosticSink* sink); - -bool isResourceType(IRType* type); - - -} - -#endif diff --git a/source/slang/lexer.cpp b/source/slang/lexer.cpp deleted file mode 100644 index 8cb9fa5ee..000000000 --- a/source/slang/lexer.cpp +++ /dev/null @@ -1,1334 +0,0 @@ -// lexer.cpp -#include "lexer.h" - -// This file implements the lexer/scanner, which is responsible for taking a raw stream of -// input bytes and turning it into semantically useful tokens. -// - -#include "compiler.h" -#include "source-loc.h" - -#include - -namespace Slang -{ - Token TokenReader::GetEndOfFileToken() - { - return Token(TokenType::EndOfFile, UnownedStringSlice::fromLiteral(""), SourceLoc()); - } - - Token* TokenList::begin() const - { - SLANG_ASSERT(mTokens.getCount()); - return &mTokens[0]; - } - - Token* TokenList::end() const - { - SLANG_ASSERT(mTokens.getCount()); - SLANG_ASSERT(mTokens[mTokens.getCount()-1].type == TokenType::EndOfFile); - return &mTokens[mTokens.getCount() - 1]; - } - - TokenSpan::TokenSpan() - : mBegin(NULL) - , mEnd (NULL) - {} - - TokenReader::TokenReader() - : mCursor(NULL) - , mEnd (NULL) - {} - - - Token& TokenReader::PeekToken() - { - return nextToken; - } - - TokenType TokenReader::PeekTokenType() const - { - return nextToken.type; - } - - SourceLoc TokenReader::PeekLoc() const - { - return nextToken.loc; - } - - Token TokenReader::AdvanceToken() - { - if (!mCursor) - return GetEndOfFileToken(); - - Token token = nextToken; - if (mCursor < mEnd) - { - mCursor++; - nextToken = *mCursor; - } - else - nextToken.type = TokenType::EndOfFile; - return token; - } - - // Lexer - - void Lexer::initialize( - SourceView* inSourceView, - DiagnosticSink* inSink, - NamePool* inNamePool, - MemoryArena* inMemoryArena) - { - sourceView = inSourceView; - sink = inSink; - namePool = inNamePool; - memoryArena = inMemoryArena; - - auto content = inSourceView->getContent(); - - begin = content.begin(); - cursor = content.begin(); - end = content.end(); - - // Set the start location - startLoc = inSourceView->getRange().begin; - - tokenFlags = TokenFlag::AtStartOfLine | TokenFlag::AfterWhitespace; - lexerFlags = 0; - } - - Lexer::~Lexer() - { - } - - enum { kEOF = -1 }; - - // Get the next input byte, without any handling of - // escaped newlines, non-ASCII code points, source locations, etc. - static int peekRaw(Lexer* lexer) - { - // If we are at the end of the input, return a designated end-of-file value - if(lexer->cursor == lexer->end) - return kEOF; - - // Otherwise, just look at the next byte - return *lexer->cursor; - } - - // Read one input byte without any special handling (similar to `peekRaw`) - static int advanceRaw(Lexer* lexer) - { - // The logic here is basically the same as for `peekRaw()`, - // escape we advance `cursor` if we aren't at the end. - - if (lexer->cursor == lexer->end) - return kEOF; - - return *lexer->cursor++; - } - - // When the cursor is already at the first byte of an end-of-line sequence, - // consume one or two bytes that compose the sequence. - // - // Basically, a newline is one of: - // - // "\n" - // "\r" - // "\r\n" - // "\n\r" - // - // We always look for the longest match possible. - // - static void handleNewLineInner(Lexer* lexer, int c) - { - SLANG_ASSERT(c == '\n' || c == '\r'); - - int d = peekRaw(lexer); - if( (c ^ d) == ('\n' ^ '\r') ) - { - advanceRaw(lexer); - } - } - - // Look ahead one code point, dealing with complications like - // escaped newlines. - static int peek(Lexer* lexer) - { - // Look at the next raw byte, and decide what to do - int c = peekRaw(lexer); - - if(c == '\\') - { - // We might have a backslash-escaped newline. - // Look at the next byte (if any) to see. - // - // Note(tfoley): We are assuming a null-terminated input here, - // so that we can safely look at the next byte without issue. - int d = lexer->cursor[1]; - switch (d) - { - case '\r': case '\n': - { - // The newline was escaped, so return the code point after *that* - - int e = lexer->cursor[2]; - if ((d ^ e) == ('\r' ^ '\n')) - return lexer->cursor[3]; - return e; - } - - default: - break; - } - } - // TODO: handle UTF-8 encoding for non-ASCII code points here - - // Default case is to just hand along the byte we read as an ASCII code point. - return c; - } - - // Get the next code point from the input, and advance the cursor. - static int advance(Lexer* lexer) - { - // We are going to loop, but only as a way of handling - // escaped line endings. - for (;;) - { - // If we are at the end of the input, then the task is easy. - if (lexer->cursor == lexer->end) - return kEOF; - - // Look at the next raw byte, and decide what to do - int c = *lexer->cursor++; - - if (c == '\\') - { - // We might have a backslash-escaped newline. - // Look at the next byte (if any) to see. - // - // Note(tfoley): We are assuming a null-terminated input here, - // so that we can safely look at the next byte without issue. - int d = *lexer->cursor; - switch (d) - { - case '\r': case '\n': - // handle the end-of-line for our source location tracking - lexer->cursor++; - handleNewLineInner(lexer, d); - - lexer->tokenFlags |= TokenFlag::ScrubbingNeeded; - - // Now try again, looking at the character after the - // escaped newline. - continue; - - default: - break; - } - } - - // TODO: Need to handle non-ASCII code points. - - // Default case is to return the raw byte we saw. - return c; - } - } - - static void handleNewLine(Lexer* lexer) - { - int c = advance(lexer); - handleNewLineInner(lexer, c); - } - - static void lexLineComment(Lexer* lexer) - { - for(;;) - { - switch(peek(lexer)) - { - case '\n': case '\r': case kEOF: - return; - - default: - advance(lexer); - continue; - } - } - } - - static void lexBlockComment(Lexer* lexer) - { - for(;;) - { - switch(peek(lexer)) - { - case kEOF: - // TODO(tfoley) diagnostic! - return; - - case '\n': case '\r': - handleNewLine(lexer); - continue; - - case '*': - advance(lexer); - switch( peek(lexer) ) - { - case '/': - advance(lexer); - return; - - default: - continue; - } - - default: - advance(lexer); - continue; - } - } - } - - static void lexHorizontalSpace(Lexer* lexer) - { - for(;;) - { - switch(peek(lexer)) - { - case ' ': case '\t': - advance(lexer); - continue; - - default: - return; - } - } - } - - static void lexIdentifier(Lexer* lexer) - { - for(;;) - { - int c = peek(lexer); - if(('a' <= c ) && (c <= 'z') - || ('A' <= c) && (c <= 'Z') - || ('0' <= c) && (c <= '9') - || (c == '_')) - { - advance(lexer); - continue; - } - - return; - } - } - - static SourceLoc getSourceLoc(Lexer* lexer) - { - return lexer->startLoc + (lexer->cursor - lexer->begin); - } - - static void lexDigits(Lexer* lexer, int base) - { - for(;;) - { - int c = peek(lexer); - - int digitVal = 0; - switch(c) - { - case '0': case '1': case '2': case '3': case '4': - case '5': case '6': case '7': case '8': case '9': - digitVal = c - '0'; - break; - - case 'a': case 'b': case 'c': case 'd': case 'e': case 'f': - if(base <= 10) return; - digitVal = 10 + c - 'a'; - break; - - case 'A': case 'B': case 'C': case 'D': case 'E': case 'F': - if(base <= 10) return; - digitVal = 10 + c - 'A'; - break; - - default: - // Not more digits! - return; - } - - if(digitVal >= base) - { - char buffer[] = { (char) c, 0 }; - lexer->sink->diagnose(getSourceLoc(lexer), Diagnostics::invalidDigitForBase, buffer, base); - } - - advance(lexer); - } - } - - static TokenType maybeLexNumberSuffix(Lexer* lexer, TokenType tokenType) - { - // Be liberal in what we accept here, so that figuring out - // the semantics of a numeric suffix is left up to the parser - // and semantic checking logic. - // - for( ;;) - { - int c = peek(lexer); - - // Accept any alphanumeric character, plus underscores. - if(('a' <= c ) && (c <= 'z') - || ('A' <= c) && (c <= 'Z') - || ('0' <= c) && (c <= '9') - || (c == '_')) - { - advance(lexer); - continue; - } - - // Stop at the first character that isn't - // alphanumeric. - return tokenType; - } - } - - static bool isNumberExponent(int c, int base) - { - switch( c ) - { - default: - return false; - - case 'e': case 'E': - if(base != 10) return false; - break; - - case 'p': case 'P': - if(base != 16) return false; - break; - } - - return true; - } - - static bool maybeLexNumberExponent(Lexer* lexer, int base) - { - if(!isNumberExponent(peek(lexer), base)) - return false; - - // we saw an exponent marker - advance(lexer); - - // Now start to read the exponent - switch( peek(lexer) ) - { - case '+': case '-': - advance(lexer); - break; - } - - // TODO(tfoley): it would be an error to not see digits here... - - lexDigits(lexer, 10); - - return true; - } - - static TokenType lexNumberAfterDecimalPoint(Lexer* lexer, int base) - { - lexDigits(lexer, base); - maybeLexNumberExponent(lexer, base); - - return maybeLexNumberSuffix(lexer, TokenType::FloatingPointLiteral); - } - - static TokenType lexNumber(Lexer* lexer, int base) - { - // TODO(tfoley): Need to consider whehter to allow any kind of digit separator character. - - TokenType tokenType = TokenType::IntegerLiteral; - - // At the start of things, we just concern ourselves with digits - lexDigits(lexer, base); - - if( peek(lexer) == '.' ) - { - tokenType = TokenType::FloatingPointLiteral; - - advance(lexer); - lexDigits(lexer, base); - } - - if( maybeLexNumberExponent(lexer, base)) - { - tokenType = TokenType::FloatingPointLiteral; - } - - maybeLexNumberSuffix(lexer, tokenType); - return tokenType; - } - - static int maybeReadDigit(char const** ioCursor, int base) - { - auto& cursor = *ioCursor; - - for(;;) - { - int c = *cursor; - switch(c) - { - default: - return -1; - - // TODO: need to decide on digit separator characters - case '_': - cursor++; - continue; - - case '0': case '1': case '2': case '3': case '4': - case '5': case '6': case '7': case '8': case '9': - cursor++; - return c - '0'; - - case 'a': case 'b': case 'c': case 'd': case 'e': case 'f': - if(base > 10) - { - cursor++; - return 10 + c - 'a'; - } - return -1; - - case 'A': case 'B': case 'C': case 'D': case 'E': case 'F': - if(base > 10) - { - cursor++; - return 10 + c - 'A'; - } - return -1; - } - } - } - - static int readOptionalBase(char const** ioCursor) - { - auto& cursor = *ioCursor; - if( *cursor == '0' ) - { - cursor++; - switch(*cursor) - { - case 'x': case 'X': - cursor++; - return 16; - - case 'b': case 'B': - cursor++; - return 2; - - case '0': case '1': case '2': case '3': case '4': - case '5': case '6': case '7': case '8': case '9': - return 8; - - default: - return 10; - } - } - - return 10; - } - - - - IntegerLiteralValue getIntegerLiteralValue(Token const& token, UnownedStringSlice* outSuffix) - { - IntegerLiteralValue value = 0; - - char const* cursor = token.Content.begin(); - char const* end = token.Content.end(); - - int base = readOptionalBase(&cursor); - - for( ;;) - { - int digit = maybeReadDigit(&cursor, base); - if(digit < 0) - break; - - value = value*base + digit; - } - - if(outSuffix) - { - *outSuffix = UnownedStringSlice(cursor, end); - } - - return value; - } - - FloatingPointLiteralValue getFloatingPointLiteralValue(Token const& token, UnownedStringSlice* outSuffix) - { - FloatingPointLiteralValue value = 0; - - char const* cursor = token.Content.begin(); - char const* end = token.Content.end(); - - int radix = readOptionalBase(&cursor); - - bool seenDot = false; - FloatingPointLiteralValue divisor = 1; - for( ;;) - { - if(*cursor == '.') - { - cursor++; - seenDot = true; - continue; - } - - int digit = maybeReadDigit(&cursor, radix); - if(digit < 0) - break; - - value = value*radix + digit; - - if(seenDot) - { - divisor *= radix; - } - } - - // Now read optional exponent - if(isNumberExponent(*cursor, radix)) - { - cursor++; - - bool exponentIsNegative = false; - switch(*cursor) - { - default: - break; - - case '-': - exponentIsNegative = true; - cursor++; - break; - - case '+': - cursor++; - break; - } - - int exponentRadix = 10; - int exponent = 0; - - for(;;) - { - int digit = maybeReadDigit(&cursor, exponentRadix); - if(digit < 0) - break; - - exponent = exponent*exponentRadix + digit; - } - - FloatingPointLiteralValue exponentBase = 10; - if(radix == 16) - { - exponentBase = 2; - } - - FloatingPointLiteralValue exponentValue = pow(exponentBase, exponent); - - if( exponentIsNegative ) - { - divisor *= exponentValue; - } - else - { - value *= exponentValue; - } - } - - value /= divisor; - - if(outSuffix) - { - *outSuffix = UnownedStringSlice(cursor, end); - } - - return value; - } - - static void lexStringLiteralBody(Lexer* lexer, char quote) - { - for(;;) - { - int c = peek(lexer); - if(c == quote) - { - advance(lexer); - return; - } - - switch(c) - { - case kEOF: - lexer->sink->diagnose(getSourceLoc(lexer), Diagnostics::endOfFileInLiteral); - return; - - case '\n': case '\r': - lexer->sink->diagnose(getSourceLoc(lexer), Diagnostics::newlineInLiteral); - return; - - case '\\': - // Need to handle various escape sequence cases - advance(lexer); - switch(peek(lexer)) - { - case '\'': - case '\"': - case '\\': - case '?': - case 'a': - case 'b': - case 'f': - case 'n': - case 'r': - case 't': - case 'v': - advance(lexer); - break; - - case '0': case '1': case '2': case '3': case '4': - case '5': case '6': case '7': - // octal escape: up to 3 characters - advance(lexer); - for(int ii = 0; ii < 3; ++ii) - { - int d = peek(lexer); - if(('0' <= d) && (d <= '7')) - { - advance(lexer); - continue; - } - else - { - break; - } - } - break; - - case 'x': - // hexadecimal escape: any number of characters - advance(lexer); - for(;;) - { - int d = peek(lexer); - if(('0' <= d) && (d <= '9') - || ('a' <= d) && (d <= 'f') - || ('A' <= d) && (d <= 'F')) - { - advance(lexer); - continue; - } - else - { - break; - } - } - break; - - // TODO: Unicode escape sequences - - } - break; - - default: - advance(lexer); - continue; - } - } - } - - String getStringLiteralTokenValue(Token const& token) - { - SLANG_ASSERT(token.type == TokenType::StringLiteral - || token.type == TokenType::CharLiteral); - - char const* cursor = token.Content.begin(); - char const* end = token.Content.end(); - SLANG_UNREFERENCED_VARIABLE(end); - - auto quote = *cursor++; - SLANG_ASSERT(quote == '\'' || quote == '"'); - - StringBuilder valueBuilder; - for(;;) - { - SLANG_ASSERT(cursor != end); - - auto c = *cursor++; - - // If we see a closing quote, then we are at the end of the string literal - if(c == quote) - { - SLANG_ASSERT(cursor == end); - return valueBuilder.ProduceString(); - } - - // Characters that don't being escape sequences are easy; - // just append them to the buffer and move on. - if(c != '\\') - { - valueBuilder.Append(c); - continue; - } - - // Now we look at another character to figure out the kind of - // escape sequence we are dealing with: - - char d = *cursor++; - - switch(d) - { - // Simple characters that just needed to be escaped - case '\'': - case '\"': - case '\\': - case '?': - valueBuilder.Append(d); - continue; - - // Traditional escape sequences for special characters - case 'a': valueBuilder.Append('\a'); continue; - case 'b': valueBuilder.Append('\b'); continue; - case 'f': valueBuilder.Append('\f'); continue; - case 'n': valueBuilder.Append('\n'); continue; - case 'r': valueBuilder.Append('\r'); continue; - case 't': valueBuilder.Append('\t'); continue; - case 'v': valueBuilder.Append('\v'); continue; - - // Octal escape: up to 3 characterws - case '0': case '1': case '2': case '3': case '4': - case '5': case '6': case '7': - { - cursor--; - int value = 0; - for(int ii = 0; ii < 3; ++ii) - { - d = *cursor; - if(('0' <= d) && (d <= '7')) - { - value = value*8 + (d - '0'); - - cursor++; - continue; - } - else - { - break; - } - } - - // TODO: add support for appending an arbitrary code point? - valueBuilder.Append((char) value); - } - continue; - - // Hexadecimal escape: any number of characters - case 'x': - { - cursor--; - int value = 0; - for(;;) - { - d = *cursor++; - int digitValue = 0; - if(('0' <= d) && (d <= '9')) - { - digitValue = d - '0'; - } - else if( ('a' <= d) && (d <= 'f') ) - { - digitValue = d - 'a'; - } - else if( ('A' <= d) && (d <= 'F') ) - { - digitValue = d - 'A'; - } - else - { - cursor--; - break; - } - - value = value*16 + digitValue; - } - - // TODO: add support for appending an arbitrary code point? - valueBuilder.Append((char) value); - } - continue; - - // TODO: Unicode escape sequences - - } - } - } - - String getFileNameTokenValue(Token const& token) - { - // A file name usually doesn't process escape sequences - // (this is import on Windows, where `\\` is a valid - // path separator character). - - // Just trim off the first and last characters to remove the quotes - // (whether they were `""` or `<>`. - return String(token.Content.begin() + 1, token.Content.end() - 1); - } - - - - static TokenType lexTokenImpl(Lexer* lexer, LexerFlags effectiveFlags) - { - if(effectiveFlags & kLexerFlag_ExpectDirectiveMessage) - { - for(;;) - { - switch(peek(lexer)) - { - default: - advance(lexer); - continue; - - case kEOF: case '\r': case '\n': - break; - } - break; - } - return TokenType::DirectiveMessage; - } - - switch(peek(lexer)) - { - default: - break; - - case kEOF: - if((effectiveFlags & kLexerFlag_InDirective) != 0) - return TokenType::EndOfDirective; - return TokenType::EndOfFile; - - case '\r': case '\n': - if((effectiveFlags & kLexerFlag_InDirective) != 0) - return TokenType::EndOfDirective; - handleNewLine(lexer); - return TokenType::NewLine; - - case ' ': case '\t': - lexHorizontalSpace(lexer); - return TokenType::WhiteSpace; - - case '.': - advance(lexer); - switch(peek(lexer)) - { - case '0': case '1': case '2': case '3': case '4': - case '5': case '6': case '7': case '8': case '9': - return lexNumberAfterDecimalPoint(lexer, 10); - - // TODO(tfoley): handle ellipsis (`...`) - - default: - return TokenType::Dot; - } - - case '1': case '2': case '3': case '4': - case '5': case '6': case '7': case '8': case '9': - return lexNumber(lexer, 10); - - case '0': - { - auto loc = getSourceLoc(lexer); - advance(lexer); - switch(peek(lexer)) - { - default: - return maybeLexNumberSuffix(lexer, TokenType::IntegerLiteral); - - case '.': - advance(lexer); - return lexNumberAfterDecimalPoint(lexer, 10); - - case 'x': case 'X': - advance(lexer); - return lexNumber(lexer, 16); - - case 'b': case 'B': - advance(lexer); - return lexNumber(lexer, 2); - - case '0': case '1': case '2': case '3': case '4': - case '5': case '6': case '7': case '8': case '9': - lexer->sink->diagnose(loc, Diagnostics::octalLiteral); - return lexNumber(lexer, 8); - } - } - - case 'a': case 'b': case 'c': case 'd': case 'e': - case 'f': case 'g': case 'h': case 'i': case 'j': - case 'k': case 'l': case 'm': case 'n': case 'o': - case 'p': case 'q': case 'r': case 's': case 't': - case 'u': case 'v': case 'w': case 'x': case 'y': - case 'z': - case 'A': case 'B': case 'C': case 'D': case 'E': - case 'F': case 'G': case 'H': case 'I': case 'J': - case 'K': case 'L': case 'M': case 'N': case 'O': - case 'P': case 'Q': case 'R': case 'S': case 'T': - case 'U': case 'V': case 'W': case 'X': case 'Y': - case 'Z': - case '_': - lexIdentifier(lexer); - return TokenType::Identifier; - - case '\"': - advance(lexer); - lexStringLiteralBody(lexer, '\"'); - return TokenType::StringLiteral; - - case '\'': - advance(lexer); - lexStringLiteralBody(lexer, '\''); - return TokenType::CharLiteral; - - case '+': - advance(lexer); - switch(peek(lexer)) - { - case '+': advance(lexer); return TokenType::OpInc; - case '=': advance(lexer); return TokenType::OpAddAssign; - default: - return TokenType::OpAdd; - } - - case '-': - advance(lexer); - switch(peek(lexer)) - { - case '-': advance(lexer); return TokenType::OpDec; - case '=': advance(lexer); return TokenType::OpSubAssign; - case '>': advance(lexer); return TokenType::RightArrow; - default: - return TokenType::OpSub; - } - - case '*': - advance(lexer); - switch(peek(lexer)) - { - case '=': advance(lexer); return TokenType::OpMulAssign; - default: - return TokenType::OpMul; - } - - case '/': - advance(lexer); - switch(peek(lexer)) - { - case '=': advance(lexer); return TokenType::OpDivAssign; - case '/': advance(lexer); lexLineComment(lexer); return TokenType::LineComment; - case '*': advance(lexer); lexBlockComment(lexer); return TokenType::BlockComment; - default: - return TokenType::OpDiv; - } - - case '%': - advance(lexer); - switch(peek(lexer)) - { - case '=': advance(lexer); return TokenType::OpModAssign; - default: - return TokenType::OpMod; - } - - case '|': - advance(lexer); - switch(peek(lexer)) - { - case '|': advance(lexer); return TokenType::OpOr; - case '=': advance(lexer); return TokenType::OpOrAssign; - default: - return TokenType::OpBitOr; - } - - case '&': - advance(lexer); - switch(peek(lexer)) - { - case '&': advance(lexer); return TokenType::OpAnd; - case '=': advance(lexer); return TokenType::OpAndAssign; - default: - return TokenType::OpBitAnd; - } - - case '^': - advance(lexer); - switch(peek(lexer)) - { - case '=': advance(lexer); return TokenType::OpXorAssign; - default: - return TokenType::OpBitXor; - } - - case '>': - advance(lexer); - switch(peek(lexer)) - { - case '>': - advance(lexer); - switch(peek(lexer)) - { - case '=': advance(lexer); return TokenType::OpShrAssign; - default: return TokenType::OpRsh; - } - case '=': advance(lexer); return TokenType::OpGeq; - default: - return TokenType::OpGreater; - } - - case '<': - advance(lexer); - switch(peek(lexer)) - { - case '<': - advance(lexer); - switch(peek(lexer)) - { - case '=': advance(lexer); return TokenType::OpShlAssign; - default: return TokenType::OpLsh; - } - case '=': advance(lexer); return TokenType::OpLeq; - default: - return TokenType::OpLess; - } - - case '=': - advance(lexer); - switch(peek(lexer)) - { - case '=': advance(lexer); return TokenType::OpEql; - default: - return TokenType::OpAssign; - } - - case '!': - advance(lexer); - switch(peek(lexer)) - { - case '=': advance(lexer); return TokenType::OpNeq; - default: - return TokenType::OpNot; - } - - case '#': - advance(lexer); - switch(peek(lexer)) - { - case '#': advance(lexer); return TokenType::PoundPound; - default: - return TokenType::Pound; - } - - case '~': advance(lexer); return TokenType::OpBitNot; - - case ':': - { - advance(lexer); - if (peek(lexer) == ':') - { - advance(lexer); - return TokenType::Scope; - } - return TokenType::Colon; - } - case ';': advance(lexer); return TokenType::Semicolon; - case ',': advance(lexer); return TokenType::Comma; - - case '{': advance(lexer); return TokenType::LBrace; - case '}': advance(lexer); return TokenType::RBrace; - case '[': advance(lexer); return TokenType::LBracket; - case ']': advance(lexer); return TokenType::RBracket; - case '(': advance(lexer); return TokenType::LParent; - case ')': advance(lexer); return TokenType::RParent; - - case '?': advance(lexer); return TokenType::QuestionMark; - case '@': advance(lexer); return TokenType::At; - case '$': advance(lexer); return TokenType::Dollar; - - } - - // TODO(tfoley): If we ever wanted to support proper Unicode - // in identifiers, etc., then this would be the right place - // to perform a more expensive dispatch based on the actual - // code point (and not just the first byte). - - { - // If none of the above cases matched, then we have an - // unexpected/invalid character. - - auto loc = getSourceLoc(lexer); - int c = advance(lexer); - if(!(effectiveFlags & kLexerFlag_IgnoreInvalid)) - { - auto sink = lexer->sink; - if(c >= 0x20 && c <= 0x7E) - { - char buffer[] = { (char) c, 0 }; - sink->diagnose(loc, Diagnostics::illegalCharacterPrint, buffer); - } - else - { - // Fallback: print as hexadecimal - sink->diagnose(loc, Diagnostics::illegalCharacterHex, String((unsigned char)c, 16)); - } - } - - return TokenType::Invalid; - } - } - - Token Lexer::lexToken(LexerFlags extraFlags) - { - auto& flags = this->tokenFlags; - for(;;) - { - Token token; - token.loc = getSourceLoc(this); - - char const* textBegin = cursor; - - auto tokenType = lexTokenImpl(this, this->lexerFlags | extraFlags); - - // The low-level lexer produces tokens for things we want - // to ignore, such as white space, so we skip them here. - switch(tokenType) - { - case TokenType::Invalid: - flags = 0; - continue; - - case TokenType::NewLine: - flags = TokenFlag::AtStartOfLine | TokenFlag::AfterWhitespace; - continue; - - case TokenType::WhiteSpace: - case TokenType::LineComment: - case TokenType::BlockComment: - flags |= TokenFlag::AfterWhitespace; - continue; - - // We don't want to skip the end-of-file token, but we *do* - // want to make sure it has appropriate flags to make our life easier - case TokenType::EndOfFile: - flags |= TokenFlag::AtStartOfLine | TokenFlag::AfterWhitespace; - break; - - // We will also do some book-keeping around preprocessor directives here: - // - // If we see a `#` at the start of a line, then we are entering a - // preprocessor directive. - case TokenType::Pound: - if((flags & TokenFlag::AtStartOfLine) != 0) - lexerFlags |= kLexerFlag_InDirective; - break; - // - // And if we saw an end-of-line during a directive, then we are - // now leaving that directive. - // - case TokenType::EndOfDirective: - lexerFlags &= ~kLexerFlag_InDirective; - break; - - default: - break; - } - - token.type = tokenType; - - char const* textEnd = cursor; - - // Note(tfoley): `StringBuilder::Append()` seems to crash when appending zero bytes - if(textEnd != textBegin) - { - // "scrubbing" token value here to remove escaped newlines... - // - // Only perform this work if we encountered an escaped newline - // while lexing this token (e.g., keep a flag on the lexer), or - // do it on-demand when the actual value of the token is needed. - if (tokenFlags & TokenFlag::ScrubbingNeeded) - { - // Allocate space that will always be more than enough for stripped contents - char* startDst = (char*)memoryArena->allocateUnaligned(textEnd - textBegin); - char* dst = startDst; - - auto tt = textBegin; - while (tt != textEnd) - { - char c = *tt++; - if (c == '\\') - { - char d = *tt; - switch (d) - { - case '\r': case '\n': - { - tt++; - char e = *tt; - if ((d ^ e) == ('\r' ^ '\n')) - { - tt++; - } - } - continue; - - default: - break; - } - } - *dst++ = c; - } - token.Content = UnownedStringSlice(startDst, dst); - } - else - { - token.Content = UnownedStringSlice(textBegin, textEnd); - } - } - - token.flags = flags; - - this->tokenFlags = 0; - - if (tokenType == TokenType::Identifier) - { - token.ptrValue = this->namePool->getName(token.Content); - } - - return token; - } - } - - TokenList Lexer::lexAllTokens() - { - TokenList tokenList; - for(;;) - { - Token token = lexToken(); - tokenList.mTokens.add(token); - - if(token.type == TokenType::EndOfFile) - return tokenList; - } - } -} diff --git a/source/slang/lexer.h b/source/slang/lexer.h deleted file mode 100644 index 8587cc904..000000000 --- a/source/slang/lexer.h +++ /dev/null @@ -1,136 +0,0 @@ -#ifndef RASTER_RENDERER_LEXER_H -#define RASTER_RENDERER_LEXER_H - -#include "../core/basic.h" -#include "diagnostics.h" - -namespace Slang -{ - struct NamePool; - - // - - struct TokenList - { - Token* begin() const; - Token* end() const; - - List mTokens; - }; - - struct TokenSpan - { - TokenSpan(); - TokenSpan( - TokenList const& tokenList) - : mBegin(tokenList.begin()) - , mEnd (tokenList.end ()) - {} - - Token* begin() const { return mBegin; } - Token* end () const { return mEnd ; } - - int GetCount() { return (int)(mEnd - mBegin); } - - Token* mBegin; - Token* mEnd; - }; - - struct TokenReader - { - Token nextToken; - TokenReader(); - explicit TokenReader(TokenSpan const& tokens) - : mCursor(tokens.begin()) - , mEnd (tokens.end ()) - , nextToken(tokens.begin() ? *tokens.begin() : GetEndOfFileToken()) - {} - explicit TokenReader(TokenList const& tokens) - : mCursor(tokens.begin()) - , mEnd (tokens.end ()) - , nextToken(tokens.begin() ? *tokens.begin() : GetEndOfFileToken()) - {} - struct ParsingCursor - { - Token nextToken; - Token* tokenReaderCursor = nullptr; - }; - ParsingCursor getCursor() - { - ParsingCursor rs; - rs.nextToken = nextToken; - rs.tokenReaderCursor = mCursor; - return rs; - } - void setCursor(ParsingCursor cursor) - { - mCursor = cursor.tokenReaderCursor; - nextToken = cursor.nextToken; - } - bool IsAtEnd() const { return mCursor == mEnd; } - Token& PeekToken(); - TokenType PeekTokenType() const; - SourceLoc PeekLoc() const; - - Token AdvanceToken(); - - int GetCount() { return (int)(mEnd - mCursor); } - - Token* mCursor; - Token* mEnd; - static Token GetEndOfFileToken(); - }; - - typedef unsigned int LexerFlags; - enum - { - kLexerFlag_InDirective = 1 << 0, ///< Turn end-of-line and end-of-file into end-of-directive - kLexerFlag_ExpectFileName = 1 << 1, ///< Support `<>` style strings for file paths - kLexerFlag_IgnoreInvalid = 1 << 2, ///< Suppress errors about invalid/unsupported characters - kLexerFlag_ExpectDirectiveMessage = 1 << 3, ///< Don't lexer ordinary tokens, and instead consume rest of line as a string - }; - - struct Lexer - { - void initialize( - SourceView* sourceView, - DiagnosticSink* sink, - NamePool* namePool, - MemoryArena* memoryArena); - - ~Lexer(); - - Token lexToken(LexerFlags extraFlags = 0); - - TokenList lexAllTokens(); - - SourceView* sourceView; - DiagnosticSink* sink; - NamePool* namePool; - - char const* cursor; - - char const* begin; - char const* end; - - /// The starting sourceLoc (same as first location of SourceView) - SourceLoc startLoc; - - TokenFlags tokenFlags; - LexerFlags lexerFlags; - - MemoryArena* memoryArena; - }; - - // Helper routines for extracting values from tokens - String getStringLiteralTokenValue(Token const& token); - String getFileNameTokenValue(Token const& token); - - typedef int64_t IntegerLiteralValue; - typedef double FloatingPointLiteralValue; - - IntegerLiteralValue getIntegerLiteralValue(Token const& token, UnownedStringSlice* outSuffix = 0); - FloatingPointLiteralValue getFloatingPointLiteralValue(Token const& token, UnownedStringSlice* outSuffix = 0); -} - -#endif diff --git a/source/slang/lookup.cpp b/source/slang/lookup.cpp deleted file mode 100644 index fe535d6e7..000000000 --- a/source/slang/lookup.cpp +++ /dev/null @@ -1,713 +0,0 @@ -// lookup.cpp -#include "lookup.h" -#include "name.h" - -namespace Slang { - -void checkDecl(SemanticsVisitor* visitor, Decl* decl); - -// - -DeclRef ApplyExtensionToType( - SemanticsVisitor* semantics, - ExtensionDecl* extDecl, - RefPtr type); - -// - - -// Helper for constructing breadcrumb trails during lookup, without unnecessary heap allocaiton -struct BreadcrumbInfo -{ - LookupResultItem::Breadcrumb::Kind kind; - LookupResultItem::Breadcrumb::ThisParameterMode thisParameterMode = LookupResultItem::Breadcrumb::ThisParameterMode::Default; - DeclRef declRef; - BreadcrumbInfo* prev = nullptr; -}; - -void DoLocalLookupImpl( - Session* session, - Name* name, - DeclRef containerDeclRef, - LookupRequest const& request, - LookupResult& result, - BreadcrumbInfo* inBreadcrumbs); - -// - -void buildMemberDictionary(ContainerDecl* decl) -{ - // Don't rebuild if already built - if (decl->memberDictionaryIsValid) - return; - - decl->memberDictionary.Clear(); - decl->transparentMembers.clear(); - - // are we a generic? - GenericDecl* genericDecl = as(decl); - - for (auto m : decl->Members) - { - auto name = m->getName(); - - // Add any transparent members to a separate list for lookup - if (m->HasModifier()) - { - TransparentMemberInfo info; - info.decl = m.Ptr(); - decl->transparentMembers.add(info); - } - - // Ignore members with no name - if (!name) - continue; - - // Ignore the "inner" member of a generic declaration - if (genericDecl && m == genericDecl->inner) - continue; - - - m->nextInContainerWithSameName = nullptr; - - Decl* next = nullptr; - if (decl->memberDictionary.TryGetValue(name, next)) - m->nextInContainerWithSameName = next; - - decl->memberDictionary[name] = m.Ptr(); - - } - decl->memberDictionaryIsValid = true; -} - - -bool DeclPassesLookupMask(Decl* decl, LookupMask mask) -{ - // type declarations - if(auto aggTypeDecl = as(decl)) - { - return int(mask) & int(LookupMask::type); - } - else if(auto simpleTypeDecl = as(decl)) - { - return int(mask) & int(LookupMask::type); - } - // function declarations - else if(auto funcDecl = as(decl)) - { - return (int(mask) & int(LookupMask::Function)) != 0; - } - // attribute declaration - else if( auto attrDecl = as(decl) ) - { - return (int(mask) & int(LookupMask::Attribute)) != 0; - } - - // default behavior is to assume a value declaration - // (no overloading allowed) - - return (int(mask) & int(LookupMask::Value)) != 0; -} - -void AddToLookupResult( - LookupResult& result, - LookupResultItem item) -{ - if (!result.isValid()) - { - // If we hadn't found a hit before, we have one now - result.item = item; - } - else if (!result.isOverloaded()) - { - // We are about to make this overloaded - result.items.add(result.item); - result.items.add(item); - } - else - { - // The result was already overloaded, so we pile on - result.items.add(item); - } -} - -LookupResult refineLookup(LookupResult const& inResult, LookupMask mask) -{ - if (!inResult.isValid()) return inResult; - if (!inResult.isOverloaded()) return inResult; - - LookupResult result; - for (auto item : inResult.items) - { - if (!DeclPassesLookupMask(item.declRef.getDecl(), mask)) - continue; - - AddToLookupResult(result, item); - } - return result; -} - -LookupResultItem CreateLookupResultItem( - DeclRef declRef, - BreadcrumbInfo* breadcrumbInfos) -{ - LookupResultItem item; - item.declRef = declRef; - - // breadcrumbs were constructed "backwards" on the stack, so we - // reverse them here by building a linked list the other way - RefPtr breadcrumbs; - for (auto bb = breadcrumbInfos; bb; bb = bb->prev) - { - breadcrumbs = new LookupResultItem::Breadcrumb( - bb->kind, - bb->declRef, - breadcrumbs, - bb->thisParameterMode); - } - item.breadcrumbs = breadcrumbs; - return item; -} - -void DoMemberLookupImpl( - Session* session, - Name* name, - RefPtr baseType, - LookupRequest const& request, - LookupResult& ioResult, - BreadcrumbInfo* breadcrumbs) -{ - if (!baseType) - { - return; - } - - // If the type was pointer-like, then dereference it - // automatically here. - if (auto pointerLikeType = as(baseType)) - { - // Need to leave a breadcrumb to indicate that we - // did an implicit dereference here - BreadcrumbInfo derefBreacrumb; - derefBreacrumb.kind = LookupResultItem::Breadcrumb::Kind::Deref; - derefBreacrumb.prev = breadcrumbs; - - // Recursively perform lookup on the result of deref - return DoMemberLookupImpl( - session, - name, pointerLikeType->elementType, request, ioResult, &derefBreacrumb); - } - - // Default case: no dereference needed - - if (auto baseDeclRefType = as(baseType)) - { - if (auto baseAggTypeDeclRef = baseDeclRefType->declRef.as()) - { - DoLocalLookupImpl( - session, - name, baseAggTypeDeclRef, request, ioResult, breadcrumbs); - } - } - - // TODO(tfoley): any other cases to handle here? -} - -void DoMemberLookupImpl( - Session* session, - Name* name, - DeclRef baseDeclRef, - LookupRequest const& request, - LookupResult& ioResult, - BreadcrumbInfo* breadcrumbs) -{ - auto baseType = getTypeForDeclRef( - session, - baseDeclRef); - return DoMemberLookupImpl( - session, - name, baseType, request, ioResult, breadcrumbs); -} - -// If we are about to perform lookup through an interface, then -// we need to specialize the decl-ref to that interface to include -// a "this type" subtitution. This function applies that substition -// when it is required, and returns the existing `declRef` otherwise. -DeclRef maybeSpecializeInterfaceDeclRef( - RefPtr subType, - RefPtr superType, - DeclRef superTypeDeclRef, // The decl-ref we are going to perform lookup in - DeclRef constraintDeclRef) // The type constraint that told us our type is a subtype -{ - if (auto superInterfaceDeclRef = superTypeDeclRef.as()) - { - // Create a subtype witness value to note the subtype relationship - // that makes this specialization valid. - // - // Note: this is to ensure that we can specialize the subtype witness - // later (e.g., by replacing a subtype witness that represents a generic - // constraint parameter with the concrete generic arguments that - // are used at a particular call site to the generic). - RefPtr subtypeWitness = new DeclaredSubtypeWitness(); - subtypeWitness->declRef = constraintDeclRef; - subtypeWitness->sub = subType; - subtypeWitness->sup = superType; - - RefPtr thisTypeSubst = new ThisTypeSubstitution(); - thisTypeSubst->interfaceDecl = superInterfaceDeclRef.getDecl(); - thisTypeSubst->witness = subtypeWitness; - thisTypeSubst->outer = superInterfaceDeclRef.substitutions.substitutions; - - auto specializedInterfaceDeclRef = DeclRef(superInterfaceDeclRef.getDecl(), thisTypeSubst); - return specializedInterfaceDeclRef; - } - - return superTypeDeclRef; -} - -// Same as the above, but we are specializing a type instead of a decl-ref -RefPtr maybeSpecializeInterfaceDeclRef( - Session* session, - RefPtr subType, - RefPtr superType, // The type we are going to perform lookup in - DeclRef constraintDeclRef) // The type constraint that told us our type is a subtype -{ - if (auto superDeclRefType = as(superType)) - { - if (auto superInterfaceDeclRef = superDeclRefType->declRef.as()) - { - auto specializedInterfaceDeclRef = maybeSpecializeInterfaceDeclRef( - subType, - superType, - superInterfaceDeclRef, - constraintDeclRef); - auto specializedInterfaceType = DeclRefType::Create(session, specializedInterfaceDeclRef); - return specializedInterfaceType; - } - } - - return superType; -} - - -// Look for members of the given name in the given container for declarations -void DoLocalLookupImpl( - Session* session, - Name* name, - DeclRef containerDeclRef, - LookupRequest const& request, - LookupResult& result, - BreadcrumbInfo* inBreadcrumbs) -{ - if (result.lookedupDecls.Contains(containerDeclRef)) - return; - result.lookedupDecls.Add(containerDeclRef); - - ContainerDecl* containerDecl = containerDeclRef.getDecl(); - - // Ensure that the lookup dictionary in the container is up to date - if (!containerDecl->memberDictionaryIsValid) - { - buildMemberDictionary(containerDecl); - } - - // Look up the declarations with the chosen name in the container. - Decl* firstDecl = nullptr; - containerDecl->memberDictionary.TryGetValue(name, firstDecl); - - // Now iterate over those declarations (if any) and see if - // we find any that meet our filtering criteria. - // For example, we might be filtering so that we only consider - // type declarations. - for (auto m = firstDecl; m; m = m->nextInContainerWithSameName) - { - if (!DeclPassesLookupMask(m, request.mask)) - continue; - - // The declaration passed the test, so add it! - AddToLookupResult(result, CreateLookupResultItem(DeclRef(m, containerDeclRef.substitutions), inBreadcrumbs)); - } - - - // TODO(tfoley): should we look up in the transparent decls - // if we already has a hit in the current container? - - for(auto transparentInfo : containerDecl->transparentMembers) - { - // The reference to the transparent member should use whatever - // substitutions we used in referring to its outer container - DeclRef transparentMemberDeclRef(transparentInfo.decl, containerDeclRef.substitutions); - - // We need to leave a breadcrumb so that we know that the result - // of lookup involves a member lookup step here - - BreadcrumbInfo memberRefBreadcrumb; - memberRefBreadcrumb.kind = LookupResultItem::Breadcrumb::Kind::Member; - memberRefBreadcrumb.declRef = transparentMemberDeclRef; - memberRefBreadcrumb.prev = inBreadcrumbs; - - DoMemberLookupImpl( - session, - name, - transparentMemberDeclRef, - request, - result, - &memberRefBreadcrumb); - } - - // Consider lookup via extension - if( auto aggTypeDeclRef = containerDeclRef.as() ) - { - RefPtr type = DeclRefType::Create( - session, - aggTypeDeclRef); - - for (auto ext = GetCandidateExtensions(aggTypeDeclRef); ext; ext = ext->nextCandidateExtension) - { - auto extDeclRef = ApplyExtensionToType(request.semantics, ext, type); - if (!extDeclRef) - continue; - - // TODO: eventually we need to insert a breadcrumb here so that - // the constructed result can somehow indicate that a member - // was found through an extension. - - DoLocalLookupImpl( - session, - name, extDeclRef, request, result, inBreadcrumbs); - } - - } - // for interface decls, also lookup in the base interfaces - if (request.semantics) - { - // TODO: - // The logic here is a bit gross, because it tries to work in terms of - // decl-refs instead of types (e.g., it asserts that the target type - // for an `extension` declaration must be a decl-ref type). - // - // This code should be converted to do a type-based lookup - // through declared bases for *any* aggregate type declaration. - // I think that logic is present in the type-based lookup path, but - // it would be needed here for when doing lookup from inside an - // aggregate declaration. - - // if we are looking at an extension, find the target decl that we are extending - DeclRef targetDeclRef = containerDeclRef; - RefPtr targetDeclRefType; - if (auto extDeclRef = containerDeclRef.as()) - { - targetDeclRefType = as(extDeclRef.getDecl()->targetType); - SLANG_ASSERT(targetDeclRefType); - int diff = 0; - targetDeclRef = targetDeclRefType->declRef.as().SubstituteImpl(containerDeclRef.substitutions, &diff); - } - - // if we are looking inside an interface decl, try find in the interfaces it inherits from - if (targetDeclRef.is()) - { - if(!targetDeclRefType) - { - targetDeclRefType = DeclRefType::Create(session, targetDeclRef); - } - - auto baseInterfaces = getMembersOfType(containerDeclRef); - for (auto inheritanceDeclRef : baseInterfaces) - { - checkDecl(request.semantics, inheritanceDeclRef.decl); - - auto baseType = inheritanceDeclRef.getDecl()->base.type.dynamicCast(); - SLANG_ASSERT(baseType); - int diff = 0; - auto baseInterfaceDeclRef = baseType->declRef.SubstituteImpl(containerDeclRef.substitutions, &diff); - - baseInterfaceDeclRef = maybeSpecializeInterfaceDeclRef( - targetDeclRefType, - baseType, - baseInterfaceDeclRef, - inheritanceDeclRef); - - DoLocalLookupImpl(session, name, baseInterfaceDeclRef.as(), request, result, inBreadcrumbs); - } - } - } -} - -void DoLookupImpl( - Session* session, - Name* name, - LookupRequest const& request, - LookupResult& result) -{ - auto thisParameterMode = LookupResultItem::Breadcrumb::ThisParameterMode::Default; - - auto scope = request.scope; - auto endScope = request.endScope; - for (;scope != endScope; scope = scope->parent) - { - // Note that we consider all "peer" scopes together, - // so that a hit in one of them does not preclude - // also finding a hit in another - for(auto link = scope; link; link = link->nextSibling) - { - auto containerDecl = link->containerDecl; - - if(!containerDecl) - continue; - - DeclRef containerDeclRef = - DeclRef(containerDecl, createDefaultSubstitutions(session, containerDecl)).as(); - - BreadcrumbInfo breadcrumb; - BreadcrumbInfo* breadcrumbs = nullptr; - - // Depending on the kind of container we are looking into, - // we may need to insert something like a `this` expression - // to resolve the lookup result. - // - // Note: We are checking for `AggTypeDeclBase` here, and not - // just `AggTypeDecl`, because we want to catch `extension` - // declarations as well. - // - if (auto aggTypeDeclRef = containerDeclRef.as()) - { - breadcrumb.kind = LookupResultItem::Breadcrumb::Kind::This; - breadcrumb.thisParameterMode = thisParameterMode; - breadcrumb.declRef = aggTypeDeclRef; - breadcrumb.prev = nullptr; - - breadcrumbs = &breadcrumb; - } - - // Now perform "local" lookup in the context of the container, - // as if we were looking up a member directly. - - // if we are currently in an extension decl, perform local lookup - // in the target decl we are extending - if (auto extDeclRef = containerDeclRef.as()) - { - if (extDeclRef.getDecl()->targetType) - { - if (auto targetDeclRef = as(extDeclRef.getDecl()->targetType)) - { - if (auto aggDeclRef = targetDeclRef->declRef.as()) - { - containerDeclRef = extDeclRef.Substitute(aggDeclRef); - } - } - } - } - DoLocalLookupImpl( - session, - name, containerDeclRef, request, result, breadcrumbs); - - if( auto funcDeclRef = containerDeclRef.as() ) - { - if( funcDeclRef.getDecl()->HasModifier() ) - { - thisParameterMode = LookupResultItem::Breadcrumb::ThisParameterMode::Mutating; - } - else - { - thisParameterMode = LookupResultItem::Breadcrumb::ThisParameterMode::Default; - } - } - } - - if (result.isValid()) - { - // If we've found a result in this scope, then there - // is no reason to look further up (for now). - return; - } - } - - // If we run out of scopes, then we are done. -} - -LookupResult DoLookup( - Session* session, - Name* name, - LookupRequest const& request) -{ - LookupResult result; - DoLookupImpl(session, name, request, result); - return result; -} - -LookupResult lookUp( - Session* session, - SemanticsVisitor* semantics, - Name* name, - RefPtr scope, - LookupMask mask) -{ - LookupRequest request; - request.semantics = semantics; - request.scope = scope; - request.mask = mask; - return DoLookup(session, name, request); -} - -// perform lookup within the context of a particular container declaration, -// and do *not* look further up the chain -LookupResult lookUpLocal( - Session* session, - SemanticsVisitor* semantics, - Name* name, - DeclRef containerDeclRef, - LookupMask mask) -{ - LookupRequest request; - request.semantics = semantics; - request.mask = mask; - - LookupResult result; - DoLocalLookupImpl(session, name, containerDeclRef, request, result, nullptr); - return result; -} - -void lookUpMemberImpl( - Session* session, - SemanticsVisitor* semantics, - Name* name, - Type* type, - LookupResult& ioResult, - BreadcrumbInfo* inBreadcrumbs, - LookupMask mask); - -// Perform lookup "through" the given constraint decl-ref, -// which should show that `subType` is a sub-type of some -// super-type (e.g., an interface). -// -void lookUpThroughConstraint( - Session* session, - SemanticsVisitor* semantics, - Name* name, - Type* subType, - DeclRef constraintDeclRef, - LookupResult& ioResult, - BreadcrumbInfo* inBreadcrumbs, - LookupMask mask) -{ - // The super-type in the constraint (e.g., `Foo` in `T : Foo`) - // will tell us a type we should use for lookup. - // - auto superType = GetSup(constraintDeclRef); - // - // We will go ahead and perform lookup using `superType`, - // after dealing with some details. - - // If we are looking up through an interface type, then - // we need to be sure that we add an appropriate - // "this type" substitution here, since that needs to - // be applied to any members we look up. - // - superType = maybeSpecializeInterfaceDeclRef( - session, - subType, - superType, - constraintDeclRef); - - // We need to track the indirection we took in lookup, - // so that we can construct an appropriate AST on the other - // side that includes the "upcase" from sub-type to super-type. - // - BreadcrumbInfo breadcrumb; - breadcrumb.prev = inBreadcrumbs; - breadcrumb.kind = LookupResultItem::Breadcrumb::Kind::Constraint; - breadcrumb.declRef = constraintDeclRef; - - // TODO: Need to consider case where this might recurse infinitely (e.g., - // if an inheritance clause does something like `Bad : Bad>`. - // - // TODO: The even simpler thing we need to worry about here is that if - // there is ever a "diamond" relationship in the inheritance hierarchy, - // we might end up seeing the same interface via different "paths" and - // we wouldn't want that to lead to overload-resolution failure. - // - lookUpMemberImpl(session, semantics, name, superType, ioResult, &breadcrumb, mask); -} - -void lookUpMemberImpl( - Session* session, - SemanticsVisitor* semantics, - Name* name, - Type* type, - LookupResult& ioResult, - BreadcrumbInfo* inBreadcrumbs, - LookupMask mask) -{ - if (auto declRefType = as(type)) - { - auto declRef = declRefType->declRef; - if (declRef.as() || declRef.as()) - { - for (auto constraintDeclRef : getMembersOfType(declRef.as())) - { - lookUpThroughConstraint( - session, - semantics, - name, - type, - constraintDeclRef, - ioResult, - inBreadcrumbs, - mask); - } - } - else if (auto aggTypeDeclRef = declRef.as()) - { - LookupRequest request; - request.semantics = semantics; - - DoLocalLookupImpl(session, name, aggTypeDeclRef, request, ioResult, inBreadcrumbs); - } - else if (auto genericTypeParamDeclRef = declRef.as()) - { - auto genericDeclRef = genericTypeParamDeclRef.GetParent().as(); - assert(genericDeclRef); - - for(auto constraintDeclRef : getMembersOfType(genericDeclRef)) - { - // Does this constraint pertain to the type we are working on? - // - // We want constraints of the form `T : Foo` where `T` is the - // generic parameter in question, and `Foo` is whatever we are - // constraining it to. - auto subType = GetSub(constraintDeclRef); - auto subDeclRefType = as(subType); - if(!subDeclRefType) - continue; - if(!subDeclRefType->declRef.Equals(genericTypeParamDeclRef)) - continue; - - lookUpThroughConstraint( - session, - semantics, - name, - type, - constraintDeclRef, - ioResult, - inBreadcrumbs, - mask); - } - } - - } - -} - -LookupResult lookUpMember( - Session* session, - SemanticsVisitor* semantics, - Name* name, - Type* type, - LookupMask mask) -{ - LookupResult result; - lookUpMemberImpl(session, semantics, name, type, result, nullptr, mask); - return result; -} - -} diff --git a/source/slang/lookup.h b/source/slang/lookup.h deleted file mode 100644 index 37ab5cf06..000000000 --- a/source/slang/lookup.h +++ /dev/null @@ -1,60 +0,0 @@ -#ifndef SLANG_LOOKUP_H_INCLUDED -#define SLANG_LOOKUP_H_INCLUDED - -#include "syntax.h" - -namespace Slang { - -struct SemanticsVisitor; - -// Take an existing lookup result and refine it to only include -// results that pass the given `LookupMask`. -LookupResult refineLookup(LookupResult const& inResult, LookupMask mask); - -// Ensure that the dictionary for name-based member lookup has been -// built for the given container declaration. -void buildMemberDictionary(ContainerDecl* decl); - -// Look up a name in the given scope, proceeding up through -// parent scopes as needed. -LookupResult lookUp( - Session* session, - SemanticsVisitor* semantics, - Name* name, - RefPtr scope, - LookupMask mask = LookupMask::Default); - -// perform lookup within the context of a particular container declaration, -// and do *not* look further up the chain -LookupResult lookUpLocal( - Session* session, - SemanticsVisitor* semantics, - Name* name, - DeclRef containerDeclRef, - LookupMask mask = LookupMask::Default); - -// Perform member lookup in the context of a type -LookupResult lookUpMember( - Session* session, - SemanticsVisitor* semantics, - Name* name, - Type* type, - LookupMask mask = LookupMask::Default); - -// TODO: this belongs somewhere else - -QualType getTypeForDeclRef( - Session* session, - SemanticsVisitor* sema, - DiagnosticSink* sink, - DeclRef declRef, - RefPtr* outTypeResult); - -QualType getTypeForDeclRef( - Session* session, - DeclRef declRef); - - -} - -#endif \ No newline at end of file diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp deleted file mode 100644 index a7be244c8..000000000 --- a/source/slang/lower-to-ir.cpp +++ /dev/null @@ -1,6498 +0,0 @@ -// lower.cpp -#include "lower-to-ir.h" - -#include "../../slang.h" - -#include "check.h" -#include "ir.h" -#include "ir-constexpr.h" -#include "ir-insts.h" -#include "ir-missing-return.h" -#include "ir-sccp.h" -#include "ir-ssa.h" -#include "ir-validate.h" -#include "mangle.h" -#include "type-layout.h" -#include "visitor.h" - -namespace Slang -{ - -// This file implements lowering of the Slang AST to a simpler SSA -// intermediate representation. -// -// IR is generated in a context (`IRGenContext`), which tracks the current -// location in the IR where code should be emitted (e.g., what basic -// block to add instructions to). Lowering a statement will emit some -// number of instructions to the context, and possibly change the -// insertion point (because of control flow). -// -// When lowering an expression we have a more interesting challenge, for -// two main reasons: -// -// 1. There might be types that are representible in the AST, but which -// we don't want to support natively in the IR. An example is a `struct` -// type with both ordinary and resource-type members; we might want to -// split values with such a type into distinct values during lowering. -// -// 2. We need to handle the difference between l-value and r-value expressions, -// and in particular the fact that HLSL/Slang supports complicated sorts -// of l-values (e.g., `someVector.zxy` is an l-value, even though it can't -// be represented by a single pointer), and also allows l-values to appear -// in multiple contexts (not just the left-hand side of assignment, but -// also as an argument to match an `out` or `in out` parameter). -// -// Our solution to both of these problems is the same. Rather than having -// the lowering of an expression return a single IR-level value (`IRInst*`), -// we have it return a more complex type (`LoweredValInfo`) which can represent -// a wider range of conceptual "values" which might correspond to multiple IR-level -// values, and/or represent a pointer to an l-value rather than the r-value itself. - -// We want to keep the representation of a `LoweringValInfo` relatively light -// - right now it is just a single pointer plus a "tag" to distinguish the cases. -// -// This means that cases that can't fit in a single pointer need a heap allocation -// to store their payload. For simplicity we represent all of these with a class -// hierarchy: -// -struct ExtendedValueInfo : RefObject -{}; - -// This case is used to indicate a value that is a reference -// to an AST-level subscript declaration. -// -struct SubscriptInfo : ExtendedValueInfo -{ - DeclRef declRef; -}; - -// This case is used to indicate a reference to an AST-level -// subscript operation bound to particular arguments. -// -// For example in a case like this: -// -// RWStructuredBuffer gBuffer; -// ... gBuffer[someIndex] ... -// -// the expression `gBuffer[someIndex]` will be lowered to -// a value that references `RWStructureBuffer::operator[]` -// with arguments `(gBuffer, someIndex)`. -// -// Such a value can be an l-value, and depending on the context -// where it is used, can lower into a call to either the getter -// or setter operations of the subscript. -// -struct BoundSubscriptInfo : ExtendedValueInfo -{ - DeclRef declRef; - IRType* type; - List args; -}; - -// Some cases of `ExtendedValueInfo` need to -// recursively contain `LoweredValInfo`s, and -// so we forward declare them here and fill -// them in later. -// -struct BoundMemberInfo; -struct SwizzledLValueInfo; - - -// This type is our core representation of lowered values. -// In the simple case, it just wraps an `IRInst*`. -// More complex cases, representing l-values or aggregate -// values are also supported. -struct LoweredValInfo -{ - // Which of the cases of value are we looking at? - enum class Flavor - { - // No value (akin to a null pointer) - None, - - // A simple IR value - Simple, - - // An l-value represented as an IR - // pointer to the value - Ptr, - - // A member declaration bound to a particular `this` value - BoundMember, - - // A reference to an AST-level subscript operation - Subscript, - - // An AST-level subscript operation bound to a particular - // object and arguments. - BoundSubscript, - - // The result of applying swizzling to an l-value - SwizzledLValue, - }; - - union - { - IRInst* val; - ExtendedValueInfo* ext; - }; - Flavor flavor; - - LoweredValInfo() - { - flavor = Flavor::None; - val = nullptr; - } - - LoweredValInfo(IRType* t) - { - flavor = Flavor::Simple; - val = t; - } - - static LoweredValInfo simple(IRInst* v) - { - LoweredValInfo info; - info.flavor = Flavor::Simple; - info.val = v; - return info; - } - - static LoweredValInfo ptr(IRInst* v) - { - LoweredValInfo info; - info.flavor = Flavor::Ptr; - info.val = v; - return info; - } - - static LoweredValInfo boundMember( - BoundMemberInfo* boundMemberInfo); - - BoundMemberInfo* getBoundMemberInfo() - { - SLANG_ASSERT(flavor == Flavor::BoundMember); - return (BoundMemberInfo*)ext; - } - - static LoweredValInfo subscript( - SubscriptInfo* subscriptInfo); - - SubscriptInfo* getSubscriptInfo() - { - SLANG_ASSERT(flavor == Flavor::Subscript); - return (SubscriptInfo*)ext; - } - - static LoweredValInfo boundSubscript( - BoundSubscriptInfo* boundSubscriptInfo); - - BoundSubscriptInfo* getBoundSubscriptInfo() - { - SLANG_ASSERT(flavor == Flavor::BoundSubscript); - return (BoundSubscriptInfo*)ext; - } - - static LoweredValInfo swizzledLValue( - SwizzledLValueInfo* extInfo); - - SwizzledLValueInfo* getSwizzledLValueInfo() - { - SLANG_ASSERT(flavor == Flavor::SwizzledLValue); - return (SwizzledLValueInfo*)ext; - } -}; - -// Represents some declaration bound to a particular -// object. For example, if we had `obj.f` where `f` -// is a member function, we'd use a `BoundMemberInfo` -// to represnet this. -// -// Note: This case is largely avoided by special-casing -// in the handling of calls (like `obj.f(arg)`), but -// it is being left here as an example of what we might -// need/want to do in the long term. -struct BoundMemberInfo : ExtendedValueInfo -{ - // The base object - LoweredValInfo base; - - // The (AST-level) declaration reference. - DeclRef declRef; - - // The type of this value - IRType* type; -}; - -// Represents the result of a swizzle operation in -// an l-value context. A swizzle without duplicate -// elements is allowed as an l-value, even if the -// element are non-contiguous (`.xz`) or out of -// order (`.zxy`). -// -struct SwizzledLValueInfo : ExtendedValueInfo -{ - // The type of the expression. - IRType* type; - - // The base expression (this should be an l-value) - LoweredValInfo base; - - // The number of elements in the swizzle - UInt elementCount; - - // THe indices for the elements being swizzled - UInt elementIndices[4]; -}; - -LoweredValInfo LoweredValInfo::boundMember( - BoundMemberInfo* boundMemberInfo) -{ - LoweredValInfo info; - info.flavor = Flavor::BoundMember; - info.ext = boundMemberInfo; - return info; -} - -LoweredValInfo LoweredValInfo::subscript( - SubscriptInfo* subscriptInfo) -{ - LoweredValInfo info; - info.flavor = Flavor::Subscript; - info.ext = subscriptInfo; - return info; -} - -LoweredValInfo LoweredValInfo::boundSubscript( - BoundSubscriptInfo* boundSubscriptInfo) -{ - LoweredValInfo info; - info.flavor = Flavor::BoundSubscript; - info.ext = boundSubscriptInfo; - return info; -} - -LoweredValInfo LoweredValInfo::swizzledLValue( - SwizzledLValueInfo* extInfo) -{ - LoweredValInfo info; - info.flavor = Flavor::SwizzledLValue; - info.ext = extInfo; - return info; -} - -// An "environment" for mapping AST declarations to IR values. -// -// This is required because in some cases we might lower the -// same AST declaration to the IR multiple times (e.g., when -// a generic transitively contains multiple functions, we -// will emit a distinct IR generic for each function, with -// its own copies of the generic parameters). -// -struct IRGenEnv -{ - // Map an AST-level declaration to the IR-level value that represents it. - Dictionary mapDeclToValue; - - // The next outer env around this one - IRGenEnv* outer = nullptr; -}; - -struct SharedIRGenContext -{ - SharedIRGenContext( - Session* session, - DiagnosticSink* sink, - ModuleDecl* mainModuleDecl = nullptr) - : m_session(session) - , m_sink(sink) - , m_mainModuleDecl(mainModuleDecl) - {} - - Session* m_session = nullptr; - DiagnosticSink* m_sink = nullptr; - ModuleDecl* m_mainModuleDecl = nullptr; - - // The "global" environment for mapping declarations to their IR values. - IRGenEnv globalEnv; - - // Map an AST-level declaration of an interface - // requirement to the IR-level "key" that - // is used to fetch that requirement from a - // witness table. - Dictionary interfaceRequirementKeys; - - // Arrays we keep around strictly for memory-management purposes: - - // Any extended values created during lowering need - // to be cleaned up after the fact. We don't try - // to reference-count these along the way because - // they need to get stored into a `union` inside `LoweredValInfo` - List> extValues; - - // Map from an AST-level statement that can be - // used as the target of a `break` or `continue` - // to the appropriate basic block to jump to. - Dictionary breakLabels; - Dictionary continueLabels; -}; - - -struct IRGenContext -{ - // Shared state for the IR generation process - SharedIRGenContext* shared; - - // environment for mapping AST decls to IR values - IRGenEnv* env; - - // IR builder to use when building code under this context - IRBuilder* irBuilder; - - // The value to use for any `this` expressions - // that appear in the current context. - // - // TODO: If we ever allow nesting of (non-static) - // types, then we may need to support references - // to an "outer `this`", and this representation - // might be insufficient. - LoweredValInfo thisVal; - - explicit IRGenContext(SharedIRGenContext* inShared) - : shared(inShared) - , env(&inShared->globalEnv) - , irBuilder(nullptr) - {} - - Session* getSession() - { - return shared->m_session; - } - - DiagnosticSink* getSink() - { - return shared->m_sink; - } - - ModuleDecl* getMainModuleDecl() - { - return shared->m_mainModuleDecl; - } -}; - -void setGlobalValue(SharedIRGenContext* sharedContext, Decl* decl, LoweredValInfo value) -{ - sharedContext->globalEnv.mapDeclToValue[decl] = value; -} - -void setGlobalValue(IRGenContext* context, Decl* decl, LoweredValInfo value) -{ - setGlobalValue(context->shared, decl, value); -} - -void setValue(IRGenContext* context, Decl* decl, LoweredValInfo value) -{ - context->env->mapDeclToValue[decl] = value; -} - -ModuleDecl* findModuleDecl(Decl* decl) -{ - for (auto dd = decl; dd; dd = dd->ParentDecl) - { - if (auto moduleDecl = as(dd)) - return moduleDecl; - } - return nullptr; -} - -bool isFromStdLib(Decl* decl) -{ - for (auto dd = decl; dd; dd = dd->ParentDecl) - { - if (dd->HasModifier()) - return true; - } - return false; -} - -bool isImportedDecl(IRGenContext* context, Decl* decl) -{ - ModuleDecl* moduleDecl = findModuleDecl(decl); - if (!moduleDecl) - return false; - - // HACK: don't treat standard library code as - // being imported for right now, just because - // we don't load its IR in the same way as - // for other imports. - // - // TODO: Fix this the right way, by having standard - // library declarations have IR modules that we link - // in via the normal means. - if (isFromStdLib(decl)) - return false; - - if (moduleDecl != context->getMainModuleDecl()) - return true; - - return false; -} - - /// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration? -bool isEffectivelyStatic( - Decl* decl, - ContainerDecl* parentDecl); - -// Ensure that a version of the given declaration has been emitted to the IR -LoweredValInfo ensureDecl( - IRGenContext* context, - Decl* decl); - -// Emit code as needed to construct a reference to the given declaration with -// any needed specializations in place. -LoweredValInfo emitDeclRef( - IRGenContext* context, - DeclRef declRef, - IRType* type); - -IRInst* getSimpleVal(IRGenContext* context, LoweredValInfo lowered); - -IROp getIntrinsicOp( - Decl* decl, - IntrinsicOpModifier* intrinsicOpMod) -{ - if (int(intrinsicOpMod->op) != 0) - return intrinsicOpMod->op; - - // No specified modifier? Then we need to look it up - // based on the name of the declaration... - - auto name = decl->getName(); - auto nameText = getUnownedStringSliceText(name); - - IROp op = findIROp(nameText); - SLANG_ASSERT(op != kIROp_Invalid); - return op; -} - -// Given a `LoweredValInfo` for something callable, along with a -// bunch of arguments, emit an appropriate call to it. -LoweredValInfo emitCallToVal( - IRGenContext* context, - IRType* type, - LoweredValInfo funcVal, - UInt argCount, - IRInst* const* args) -{ - auto builder = context->irBuilder; - switch (funcVal.flavor) - { - case LoweredValInfo::Flavor::None: - SLANG_UNEXPECTED("null function"); - default: - return LoweredValInfo::simple( - builder->emitCallInst(type, getSimpleVal(context, funcVal), argCount, args)); - } -} - -LoweredValInfo emitCompoundAssignOp( - IRGenContext* context, - IRType* type, - IROp op, - UInt argCount, - IRInst* const* args) -{ - auto builder = context->irBuilder; - SLANG_UNREFERENCED_PARAMETER(argCount); - SLANG_ASSERT(argCount == 2); - auto leftPtr = args[0]; - auto rightVal = args[1]; - - auto leftVal = builder->emitLoad(leftPtr); - - IRInst* innerArgs[] = { leftVal, rightVal }; - auto innerOp = builder->emitIntrinsicInst(type, op, 2, innerArgs); - - builder->emitStore(leftPtr, innerOp); - - return LoweredValInfo::ptr(leftPtr); -} - -IRInst* getOneValOfType( - IRGenContext* context, - IRType* type) -{ - switch(type->op) - { - case kIROp_IntType: - case kIROp_UIntType: - case kIROp_UInt64Type: - return context->irBuilder->getIntValue(type, 1); - - case kIROp_HalfType: - case kIROp_FloatType: - case kIROp_DoubleType: - return context->irBuilder->getFloatValue(type, 1.0); - - default: - break; - } - - // TODO: should make sure to handle vector and matrix types here - - SLANG_UNEXPECTED("inc/dec type"); - UNREACHABLE_RETURN(nullptr); -} - -LoweredValInfo emitPrefixIncDecOp( - IRGenContext* context, - IRType* type, - IROp op, - UInt argCount, - IRInst* const* args) -{ - auto builder = context->irBuilder; - SLANG_UNREFERENCED_PARAMETER(argCount); - SLANG_ASSERT(argCount == 1); - auto argPtr = args[0]; - - auto preVal = builder->emitLoad(argPtr); - - IRInst* oneVal = getOneValOfType(context, type); - - IRInst* innerArgs[] = { preVal, oneVal }; - auto innerOp = builder->emitIntrinsicInst(type, op, 2, innerArgs); - - builder->emitStore(argPtr, innerOp); - - // For a prefix operator like `++i` we return - // the value after the increment/decrement has - // been applied. In casual terms we "increment - // the varaible, then return its value." - // - return LoweredValInfo::simple(innerOp); -} - -LoweredValInfo emitPostfixIncDecOp( - IRGenContext* context, - IRType* type, - IROp op, - UInt argCount, - IRInst* const* args) -{ - auto builder = context->irBuilder; - SLANG_UNREFERENCED_PARAMETER(argCount); - SLANG_ASSERT(argCount == 1); - auto argPtr = args[0]; - - auto preVal = builder->emitLoad(argPtr); - - IRInst* oneVal = getOneValOfType(context, type); - - IRInst* innerArgs[] = { preVal, oneVal }; - auto innerOp = builder->emitIntrinsicInst(type, op, 2, innerArgs); - - builder->emitStore(argPtr, innerOp); - - // For a postfix operator like `i++` we return - // the value that we read before the increment/decrement - // gets applied. In casual terms we "read - // the variable, then increment it." - // - return LoweredValInfo::simple(preVal); -} - -LoweredValInfo lowerRValueExpr( - IRGenContext* context, - Expr* expr); - -IRType* lowerType( - IRGenContext* context, - Type* type); - -static IRType* lowerType( - IRGenContext* context, - QualType const& type) -{ - return lowerType(context, type.type); -} - -// Given a `DeclRef` for something callable, along with a bunch of -// arguments, emit an appropriate call to it. -LoweredValInfo emitCallToDeclRef( - IRGenContext* context, - IRType* type, - DeclRef funcDeclRef, - IRType* funcType, - UInt argCount, - IRInst* const* args) -{ - auto builder = context->irBuilder; - - - if (auto subscriptDeclRef = funcDeclRef.as()) - { - // A reference to a subscript declaration is a special case, - // because it is not possible to call a subscript directly; - // we must call one of its accessors. - // - // TODO: everything here will also apply to propery declarations - // once we have them, so some of this code might be shared - // some day. - - DeclRef getterDeclRef; - bool justAGetter = true; - for (auto accessorDeclRef : getMembersOfType(subscriptDeclRef)) - { - // We want to track whether this subscript has any accessors other than - // `get` (assuming that everything except `get` can be used for setting...). - - if (auto foundGetterDeclRef = accessorDeclRef.as()) - { - // We found a getter. - getterDeclRef = foundGetterDeclRef; - } - else - { - // There was something other than a getter, so we can't - // invoke an accessor just now. - justAGetter = false; - } - } - - if (!justAGetter || !getterDeclRef) - { - // We can't perform an actual call right now, because - // this expression might appear in an r-value or l-value - // position (or *both* if it is being passed as an argument - // for an `in out` parameter!). - // - // Instead, we will construct a special-case value to - // represent the latent subscript operation (abstractly - // this is a reference to a storage location). - - // The abstract storage location will need to include - // all the arguments being passed to the subscript operation. - - RefPtr boundSubscript = new BoundSubscriptInfo(); - boundSubscript->declRef = subscriptDeclRef; - boundSubscript->type = type; - boundSubscript->args.addRange(args, argCount); - - context->shared->extValues.add(boundSubscript); - - return LoweredValInfo::boundSubscript(boundSubscript); - } - - // Otherwise we are just call the getter, and so that - // is what we need to be emitting a call to... - funcDeclRef = getterDeclRef; - } - - auto funcDecl = funcDeclRef.getDecl(); - if(auto intrinsicOpModifier = funcDecl->FindModifier()) - { - auto op = getIntrinsicOp(funcDecl, intrinsicOpModifier); - - if (isPseudoOp(op)) - { - switch (op) - { - case kIRPseudoOp_Pos: - return LoweredValInfo::simple(args[0]); - - case kIRPseudoOp_Sequence: - // The main effect of "operator comma" is to enforce - // sequencing of its operands, but Slang already - // implements a strictly left-to-right evaluation - // order for function arguments, so in practice we - // just need to compile `a, b` to the value of `b` - // (because argument evaluation already happened). - return LoweredValInfo::simple(args[1]); - -#define CASE(COMPOUND, OP) \ - case COMPOUND: return emitCompoundAssignOp(context, type, OP, argCount, args) - - CASE(kIRPseudoOp_AddAssign, kIROp_Add); - CASE(kIRPseudoOp_SubAssign, kIROp_Sub); - CASE(kIRPseudoOp_MulAssign, kIROp_Mul); - CASE(kIRPseudoOp_DivAssign, kIROp_Div); - CASE(kIRPseudoOp_ModAssign, kIROp_Mod); - CASE(kIRPseudoOp_AndAssign, kIROp_BitAnd); - CASE(kIRPseudoOp_OrAssign, kIROp_BitOr); - CASE(kIRPseudoOp_XorAssign, kIROp_BitXor); - CASE(kIRPseudoOp_LshAssign, kIROp_Lsh); - CASE(kIRPseudoOp_RshAssign, kIROp_Rsh); - -#undef CASE - -#define CASE(COMPOUND, OP) \ - case COMPOUND: return emitPrefixIncDecOp(context, type, OP, argCount, args) - CASE(kIRPseudoOp_PreInc, kIROp_Add); - CASE(kIRPseudoOp_PreDec, kIROp_Sub); -#undef CASE - -#define CASE(COMPOUND, OP) \ - case COMPOUND: return emitPostfixIncDecOp(context, type, OP, argCount, args) - CASE(kIRPseudoOp_PostInc, kIROp_Add); - CASE(kIRPseudoOp_PostDec, kIROp_Sub); -#undef CASE - default: - SLANG_UNIMPLEMENTED_X("IR pseudo-op"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - } - - return LoweredValInfo::simple(builder->emitIntrinsicInst( - type, - op, - argCount, - args)); - } - // TODO: handle target intrinsic modifier too... - - if( auto ctorDeclRef = funcDeclRef.as() ) - { - // HACK: we know all constructors are builtins for now, - // so we need to emit them as a call to the corresponding - // builtin operation. - // - // TODO: these should all either be intrinsic operations, - // or calls to library functions. - - return LoweredValInfo::simple(builder->emitConstructorInst(type, argCount, args)); - } - - // Fallback case is to emit an actual call. - if(!funcType) - { - List argTypes; - for(UInt ii = 0; ii < argCount; ++ii) - { - argTypes.add(args[ii]->getDataType()); - } - funcType = builder->getFuncType(argCount, argTypes.getBuffer(), type); - } - LoweredValInfo funcVal = emitDeclRef(context, funcDeclRef, funcType); - return emitCallToVal(context, type, funcVal, argCount, args); -} - -LoweredValInfo emitCallToDeclRef( - IRGenContext* context, - IRType* type, - DeclRef funcDeclRef, - IRType* funcType, - List const& args) -{ - return emitCallToDeclRef(context, type, funcDeclRef, funcType, args.getCount(), args.getBuffer()); -} - -IRInst* getFieldKey( - IRGenContext* context, - DeclRef field) -{ - return getSimpleVal(context, emitDeclRef(context, field, context->irBuilder->getKeyType())); -} - -LoweredValInfo extractField( - IRGenContext* context, - IRType* fieldType, - LoweredValInfo base, - DeclRef field) -{ - IRBuilder* builder = context->irBuilder; - - switch (base.flavor) - { - default: - { - IRInst* irBase = getSimpleVal(context, base); - return LoweredValInfo::simple( - builder->emitFieldExtract( - fieldType, - irBase, - getFieldKey(context, field))); - } - break; - - case LoweredValInfo::Flavor::BoundMember: - case LoweredValInfo::Flavor::BoundSubscript: - { - // The base value is one that is trying to defer a get-vs-set - // decision, so we will need to do the same. - - RefPtr boundMemberInfo = new BoundMemberInfo(); - boundMemberInfo->type = fieldType; - boundMemberInfo->base = base; - boundMemberInfo->declRef = field; - - context->shared->extValues.add(boundMemberInfo); - return LoweredValInfo::boundMember(boundMemberInfo); - } - break; - - case LoweredValInfo::Flavor::Ptr: - { - // We are "extracting" a field from an lvalue address, - // which means we should just compute an lvalue - // representing the field address. - IRInst* irBasePtr = base.val; - return LoweredValInfo::ptr( - builder->emitFieldAddress( - builder->getPtrType(fieldType), - irBasePtr, - getFieldKey(context, field))); - } - break; - } -} - - - -LoweredValInfo materialize( - IRGenContext* context, - LoweredValInfo lowered) -{ - auto builder = context->irBuilder; - -top: - switch(lowered.flavor) - { - case LoweredValInfo::Flavor::None: - case LoweredValInfo::Flavor::Simple: - case LoweredValInfo::Flavor::Ptr: - return lowered; - - case LoweredValInfo::Flavor::BoundSubscript: - { - auto boundSubscriptInfo = lowered.getBoundSubscriptInfo(); - - // We are being asked to extract a value from a subscript call - // (e.g., `base[index]`). We will first check if the subscript - // declared a getter and use that if possible, and then fall - // back to a `ref` accessor if one is defined. - // - // (Picking the `get` over the `ref` accessor simplifies things - // in case the `get` operation has a natural translation for - // a target, while the general `ref` case does not...) - - auto getters = getMembersOfType(boundSubscriptInfo->declRef); - if (getters.Count()) - { - lowered = emitCallToDeclRef( - context, - boundSubscriptInfo->type, - *getters.begin(), - nullptr, - boundSubscriptInfo->args); - goto top; - } - - auto refAccessors = getMembersOfType(boundSubscriptInfo->declRef); - if(refAccessors.Count()) - { - // The `ref` accessor will return a pointer to the value, so - // we need to reflect that in the type of our `call` instruction. - IRType* ptrType = context->irBuilder->getPtrType(boundSubscriptInfo->type); - - LoweredValInfo refVal = emitCallToDeclRef( - context, - ptrType, - *refAccessors.begin(), - nullptr, - boundSubscriptInfo->args); - - // The result from the call needs to be implicitly dereferenced, - // so that it can work as an l-value of the desired result type. - lowered = LoweredValInfo::ptr(getSimpleVal(context, refVal)); - - goto top; - } - - SLANG_UNEXPECTED("subscript had no getter"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - break; - - case LoweredValInfo::Flavor::BoundMember: - { - auto boundMemberInfo = lowered.getBoundMemberInfo(); - auto base = materialize(context, boundMemberInfo->base); - - auto declRef = boundMemberInfo->declRef; - if( auto fieldDeclRef = declRef.as() ) - { - lowered = extractField(context, boundMemberInfo->type, base, fieldDeclRef); - goto top; - } - else - { - - SLANG_UNEXPECTED("unexpected member flavor"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - } - break; - - case LoweredValInfo::Flavor::SwizzledLValue: - { - auto swizzleInfo = lowered.getSwizzledLValueInfo(); - - return LoweredValInfo::simple(builder->emitSwizzle( - swizzleInfo->type, - getSimpleVal(context, swizzleInfo->base), - swizzleInfo->elementCount, - swizzleInfo->elementIndices)); - } - - default: - SLANG_UNEXPECTED("unhandled value flavor"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - -} - -IRInst* getSimpleVal(IRGenContext* context, LoweredValInfo lowered) -{ - auto builder = context->irBuilder; - - // First, try to eliminate any "bound" operations along the chain, - // so that we are dealing with an ordinary value, or an l-value pointer. - lowered = materialize(context, lowered); - - switch(lowered.flavor) - { - case LoweredValInfo::Flavor::None: - return nullptr; - - case LoweredValInfo::Flavor::Simple: - return lowered.val; - - case LoweredValInfo::Flavor::Ptr: - return builder->emitLoad(lowered.val); - - default: - SLANG_UNEXPECTED("unhandled value flavor"); - UNREACHABLE_RETURN(nullptr); - } -} - -LoweredValInfo lowerVal( - IRGenContext* context, - Val* val); - -IRInst* lowerSimpleVal( - IRGenContext* context, - Val* val) -{ - auto lowered = lowerVal(context, val); - return getSimpleVal(context, lowered); -} - -LoweredValInfo lowerLValueExpr( - IRGenContext* context, - Expr* expr); - -void assign( - IRGenContext* context, - LoweredValInfo const& left, - LoweredValInfo const& right); - -IRInst* getAddress( - IRGenContext* context, - LoweredValInfo const& inVal, - SourceLoc diagnosticLocation); - -void lowerStmt( - IRGenContext* context, - Stmt* stmt); - -LoweredValInfo lowerDecl( - IRGenContext* context, - DeclBase* decl); - -IRType* getIntType( - IRGenContext* context) -{ - return context->irBuilder->getBasicType(BaseType::Int); -} - -static IRGeneric* getOuterGeneric(IRInst* gv) -{ - auto parentBlock = as(gv->getParent()); - if (!parentBlock) return nullptr; - - auto parentGeneric = as(parentBlock->getParent()); - return parentGeneric; -} - -static void addLinkageDecoration( - IRGenContext* context, - IRInst* inInst, - Decl* decl, - UnownedStringSlice const& mangledName) -{ - // If the instruction is nested inside one or more generics, - // then the mangled name should really apply to the outer-most - // generic, and not the declaration nested inside. - - auto builder = context->irBuilder; - - IRInst* inst = inInst; - while (auto outerGeneric = getOuterGeneric(inst)) - { - inst = outerGeneric; - } - - if(isImportedDecl(context, decl)) - { - builder->addImportDecoration(inst, mangledName); - } - else - { - builder->addExportDecoration(inst, mangledName); - } -} - -static void addLinkageDecoration( - IRGenContext* context, - IRInst* inst, - Decl* decl) -{ - addLinkageDecoration(context, inst, decl, getMangledName(decl).getUnownedSlice()); -} - -IRStructKey* getInterfaceRequirementKey( - IRGenContext* context, - Decl* requirementDecl) -{ - IRStructKey* requirementKey = nullptr; - if(context->shared->interfaceRequirementKeys.TryGetValue(requirementDecl, requirementKey)) - { - return requirementKey; - } - - IRBuilder builderStorage = *context->irBuilder; - auto builder = &builderStorage; - - builder->setInsertInto(builder->sharedBuilder->module->getModuleInst()); - - // Construct a key to serve as the representation of - // this requirement in the IR, and to allow lookup - // into the declaration. - requirementKey = builder->createStructKey(); - - addLinkageDecoration(context, requirementKey, requirementDecl); - - context->shared->interfaceRequirementKeys.Add(requirementDecl, requirementKey); - - return requirementKey; -} - - -SubstitutionSet lowerSubstitutions(IRGenContext* context, SubstitutionSet subst); -// - -struct ValLoweringVisitor : ValVisitor -{ - IRGenContext* context; - - IRBuilder* getBuilder() { return context->irBuilder; } - - LoweredValInfo visitVal(Val* /*val*/) - { - SLANG_UNIMPLEMENTED_X("value lowering"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - - LoweredValInfo visitGenericParamIntVal(GenericParamIntVal* val) - { - return emitDeclRef(context, val->declRef, - lowerType(context, GetType(val->declRef))); - } - - LoweredValInfo visitDeclaredSubtypeWitness(DeclaredSubtypeWitness* val) - { - return emitDeclRef(context, val->declRef, - context->irBuilder->getWitnessTableType()); - } - - LoweredValInfo visitTransitiveSubtypeWitness( - TransitiveSubtypeWitness* val) - { - // The base (subToMid) will turn into a value with - // witness-table type. - IRInst* baseWitnessTable = lowerSimpleVal(context, val->subToMid); - - // The next step should map to an interface requirement - // that is itself an interface conformance, so the result - // of lowering this value should be a "key" that we can - // use to look up a witness table. - IRInst* requirementKey = getInterfaceRequirementKey(context, val->midToSup.getDecl()); - - // TODO: There are some ugly cases here if `midToSup` is allowed - // to be an arbitrary witness, rather than just a declared one, - // and we should probably change the front-end representation - // to reflect the right constraints. - - return LoweredValInfo::simple(getBuilder()->emitLookupInterfaceMethodInst( - nullptr, - baseWitnessTable, - requirementKey)); - } - - LoweredValInfo visitTaggedUnionSubtypeWitness( - TaggedUnionSubtypeWitness* val) - { - // The sub-type in this case is a tagged union `A | B | ...`, - // and the witness holds an array of witnesses showing that each - // "case" (`A`, `B`, etc.) is a subtype of the super-type. - - // We will start by getting the IR-level representation of the - // sub type (the tagged union type). - // - auto irTaggedUnionType = lowerType(context, val->sub); - - // We can turn each of those per-case witnesses into a witness - // table value: - // - auto caseCount = val->caseWitnesses.getCount(); - List caseWitnessTables; - for( auto caseWitness : val->caseWitnesses ) - { - auto caseWitnessTable = lowerSimpleVal(context, caseWitness); - caseWitnessTables.add(caseWitnessTable); - } - - // Now we need to synthesize a witness table for the tagged union - // value, showing how it can implement all of the requirements - // of the super type by delegating to the appropriate implementation - // on a per-case basis. - // - // We will assume here that the super-type is an interface, and it - // will be left to the front-end to ensure this property. - // - auto supDeclRefType = as(val->sup); - if(!supDeclRefType) - { - SLANG_UNEXPECTED("super-type not a decl-ref type when generating tagged union witness table"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - auto supInterfaceDeclRef = supDeclRefType->declRef.as(); - if( !supInterfaceDeclRef ) - { - SLANG_UNEXPECTED("super-type not an interface type when generating tagged union witness table"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - - auto irWitnessTable = getBuilder()->createWitnessTable(); - - // Now we will iterate over the requirements (members) of the - // interface and try to synthesize an appropriate value for each. - // - for( auto reqDeclRef : getMembers(supInterfaceDeclRef) ) - { - // TODO: if there are any members we shouldn't process as a requirement, - // then we should detect and skip them here. - // - - // Every interface requirement will have a unique key that is used - // when looking up the requirement in a concrete witness table. - // - auto irReqKey = getInterfaceRequirementKey(context, reqDeclRef.getDecl()); - - // We expect that each of the witness tables in `caseWitnessTables` - // will have an entry to match these keys. However, we may not - // have a concrete `IRWitnessTable` for each of the case types, either - // because they are a specialization of a generic (so that the witness - // table reference is a `specialize` instruction at this point), or - // they are a type external to this module (so that we have a declaration - // rather than a definition of the witness table). - - // Our task is to create an IR value that can satisfy the interface - // requirement for the tagged union type, by appropriately delegating - // to the implementations of the same requirement in the case types. - // - IRInst* irSatisfyingVal = nullptr; - - - - if(auto callableDeclRef = reqDeclRef.as()) - { - // We have something callable, so we need to synthesize - // a function to satisfy it. - // - auto irFunc = getBuilder()->createFunc(); - irSatisfyingVal = irFunc; - - IRBuilder subBuilderStorage; - auto subBuilder = &subBuilderStorage; - subBuilder->sharedBuilder = getBuilder()->sharedBuilder; - subBuilder->setInsertInto(irFunc); - - // We will start by setting up the function parameters, - // which live in the entry block of the IR function. - // - auto entryBlock = subBuilder->emitBlock(); - subBuilder->setInsertInto(entryBlock); - - // Create a `this` parameter of the tagged-union type. - // - // TODO: need to handle the `[mutating]` case here... - // - auto irThisType = irTaggedUnionType; - auto irThisParam = subBuilder->emitParam(irThisType); - - List irParamTypes; - irParamTypes.add(irThisType); - - // Create the remaining parameters of the callable, - // using a decl-ref specialized to the tagged union - // type (so that things like associated types are - // mapped to the correct witness value). - // - List irParams; - for( auto paramDeclRef : getMembersOfType(callableDeclRef) ) - { - // TODO: need to handle `out` and `in out` here. Over all - // there is a lot of duplication here with the existing logic - // for emitting the signature of a `CallableDecl`, and we should - // try to re-use that if at all possible. - // - auto irParamType = lowerType(context, GetType(paramDeclRef)); - auto irParam = subBuilder->emitParam(irParamType); - - irParams.add(irParam); - irParamTypes.add(irParamType); - } - - auto irResultType = lowerType(context, GetResultType(callableDeclRef)); - - auto irFuncType = subBuilder->getFuncType( - irParamTypes, - irResultType); - irFunc->setFullType(irFuncType); - - // The first thing our function needs to do is extract the tag - // from the incoming `this` parameter. - // - auto irTagVal = subBuilder->emitExtractTaggedUnionTag(irThisParam); - - // Next we want to emit a `switch` on the tag value, but before we - // do that we need to generate the code for each of the cases so that - // our `switch` has somewhere to branch to. - // - List switchCaseOperands; - - IRBlock* defaultLabel = nullptr; - - for( Index ii = 0; ii < caseCount; ++ii ) - { - auto caseTag = subBuilder->getIntValue(irTagVal->getDataType(), ii); - - subBuilder->setInsertInto(irFunc); - auto caseLabel = subBuilder->emitBlock(); - - if(!defaultLabel) - defaultLabel = caseLabel; - - switchCaseOperands.add(caseTag); - switchCaseOperands.add(caseLabel); - - subBuilder->setInsertInto(caseLabel); - - // We need to look up the satisfying value for this interface - // requirement on the witness table of the particular case value. - // - // We already have the witness table, and the requirement key is - // just `irReqKey`. - // - auto caseWitnessTable = caseWitnessTables[ii]; - - // The subtle bit here is determining the type we expect the - // satisfying value to have, since that depends on the actual - // type that is satisfying the requirement. - // - IRType* caseResultType = irResultType; - IRType* caseFuncType = nullptr; - auto caseFunc = subBuilder->emitLookupInterfaceMethodInst( - caseFuncType, - caseWitnessTable, - irReqKey); - - // We are going to emit a `call` to the satisfying value - // for the case type, so we will collect the arguments for that call. - // - List caseArgs; - - // The `this` argument to the call will need to represent the - // appropriate field of our tagged union. - // - IRType* caseThisType = (IRType*) irTaggedUnionType->getOperand(ii); - auto caseThisArg = subBuilder->emitExtractTaggedUnionPayload( - caseThisType, - irThisParam, caseTag); - caseArgs.add(caseThisArg); - - // The remaining arguments to the call will just be forwarded from - // the parameters of the wrapper function. - // - // TODO: This would need to change if/when we started allowing `This` type - // or associated-type parameters to be used at call sites where a tagged - // union is used. - // - for( auto param : irParams ) - { - caseArgs.add(param); - } - - auto caseCall = subBuilder->emitCallInst(caseResultType, caseFunc, caseArgs); - - if( as(irResultType->getDataType()) ) - { - subBuilder->emitReturn(); - } - else - { - subBuilder->emitReturn(caseCall); - } - } - - // We will create a block to represent the supposedly-unreachable - // code that will run if no `case` matches. - // - subBuilder->setInsertInto(irFunc); - auto invalidLabel = subBuilder->emitBlock(); - subBuilder->setInsertInto(invalidLabel); - subBuilder->emitUnreachable(); - - if(!defaultLabel) defaultLabel = invalidLabel; - - // Now we have enough information to go back and emit the `switch` instruction - // into the entry block. - subBuilder->setInsertInto(entryBlock); - subBuilder->emitSwitch( - irTagVal, // value to `switch` on - invalidLabel, // `break` label (block after the `switch` statement ends) - defaultLabel, // `default` label (where to go if no `case` matches) - switchCaseOperands.getCount(), - switchCaseOperands.getBuffer()); - } - else - { - // TODO: We need to handle other cases of interface requirements. - SLANG_UNEXPECTED("unexpceted interface requirement when generating tagged union witness table"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - - // Once we've generating a value to satisfying the requirement, we install - // it into the witness table for our tagged-union type. - // - getBuilder()->createWitnessTableEntry(irWitnessTable, irReqKey, irSatisfyingVal); - } - - return LoweredValInfo::simple(irWitnessTable); - } - - LoweredValInfo visitConstantIntVal(ConstantIntVal* val) - { - // TODO: it is a bit messy here that the `ConstantIntVal` representation - // has no notion of a *type* associated with the value... - - auto type = getIntType(context); - return LoweredValInfo::simple(getBuilder()->getIntValue(type, val->value)); - } - - IRFuncType* visitFuncType(FuncType* type) - { - IRType* resultType = lowerType(context, type->getResultType()); - UInt paramCount = type->getParamCount(); - List paramTypes; - for (UInt pp = 0; pp < paramCount; ++pp) - { - paramTypes.add(lowerType(context, type->getParamType(pp))); - } - return getBuilder()->getFuncType( - paramCount, - paramTypes.getBuffer(), - resultType); - } - - IRType* visitDeclRefType(DeclRefType* type) - { - auto declRef = type->declRef; - auto decl = declRef.getDecl(); - - // Check for types with teh `__intrinsic_type` modifier. - if(decl->FindModifier()) - { - return lowerSimpleIntrinsicType(type); - } - - - return (IRType*) getSimpleVal( - context, - emitDeclRef(context, declRef, - context->irBuilder->getTypeKind())); - } - - IRType* visitNamedExpressionType(NamedExpressionType* type) - { - return (IRType*)getSimpleVal(context, dispatchType(type->GetCanonicalType())); - } - - IRType* visitBasicExpressionType(BasicExpressionType* type) - { - return getBuilder()->getBasicType( - type->baseType); - } - - IRType* visitVectorExpressionType(VectorExpressionType* type) - { - auto elementType = lowerType(context, type->elementType); - auto elementCount = lowerSimpleVal(context, type->elementCount); - - return getBuilder()->getVectorType( - elementType, - elementCount); - } - - IRType* visitMatrixExpressionType(MatrixExpressionType* type) - { - auto elementType = lowerType(context, type->getElementType()); - auto rowCount = lowerSimpleVal(context, type->getRowCount()); - auto columnCount = lowerSimpleVal(context, type->getColumnCount()); - - return getBuilder()->getMatrixType( - elementType, - rowCount, - columnCount); - } - - IRType* visitArrayExpressionType(ArrayExpressionType* type) - { - auto elementType = lowerType(context, type->baseType); - if (type->ArrayLength) - { - auto elementCount = lowerSimpleVal(context, type->ArrayLength); - return getBuilder()->getArrayType( - elementType, - elementCount); - } - else - { - return getBuilder()->getUnsizedArrayType( - elementType); - } - } - - // Lower a type where the type declaration being referenced is assumed - // to be an intrinsic type, which can thus be lowered to a simple IR - // type with the appropriate opcode. - IRType* lowerSimpleIntrinsicType(DeclRefType* type) - { - auto intrinsicTypeModifier = type->declRef.getDecl()->FindModifier(); - SLANG_ASSERT(intrinsicTypeModifier); - IROp op = IROp(intrinsicTypeModifier->irOp); - return getBuilder()->getType(op); - } - - // Lower a type where the type declaration being referenced is assumed - // to be an intrinsic type with a single generic type parameter, and - // which can thus be lowered to a simple IR type with the appropriate opcode. - IRType* lowerGenericIntrinsicType(DeclRefType* type, Type* elementType) - { - auto intrinsicTypeModifier = type->declRef.getDecl()->FindModifier(); - SLANG_ASSERT(intrinsicTypeModifier); - IROp op = IROp(intrinsicTypeModifier->irOp); - IRInst* irElementType = lowerType(context, elementType); - return getBuilder()->getType( - op, - 1, - &irElementType); - } - - IRType* lowerGenericIntrinsicType(DeclRefType* type, Type* elementType, IntVal* count) - { - auto intrinsicTypeModifier = type->declRef.getDecl()->FindModifier(); - SLANG_ASSERT(intrinsicTypeModifier); - IROp op = IROp(intrinsicTypeModifier->irOp); - IRInst* irElementType = lowerType(context, elementType); - - IRInst* irCount = lowerSimpleVal(context, count); - - IRInst* const operands[2] = - { - irElementType, - irCount, - }; - - return getBuilder()->getType( - op, - SLANG_COUNT_OF(operands), - operands); - } - - IRType* visitResourceType(ResourceType* type) - { - return lowerGenericIntrinsicType(type, type->elementType); - } - - IRType* visitSamplerStateType(SamplerStateType* type) - { - return lowerSimpleIntrinsicType(type); - } - - IRType* visitBuiltinGenericType(BuiltinGenericType* type) - { - return lowerGenericIntrinsicType(type, type->elementType); - } - - IRType* visitUntypedBufferResourceType(UntypedBufferResourceType* type) - { - return lowerSimpleIntrinsicType(type); - } - - IRType* visitHLSLPatchType(HLSLPatchType* type) - { - Type* elementType = type->getElementType(); - IntVal* count = type->getElementCount(); - - return lowerGenericIntrinsicType(type, elementType, count); - } - - IRType* visitExtractExistentialType(ExtractExistentialType* type) - { - auto declRef = type->declRef; - auto existentialType = lowerType(context, GetType(declRef)); - IRInst* existentialVal = getSimpleVal(context, emitDeclRef(context, declRef, existentialType)); - return getBuilder()->emitExtractExistentialType(existentialVal); - } - - LoweredValInfo visitExtractExistentialSubtypeWitness(ExtractExistentialSubtypeWitness* witness) - { - auto declRef = witness->declRef; - auto existentialType = lowerType(context, GetType(declRef)); - IRInst* existentialVal = getSimpleVal(context, emitDeclRef(context, declRef, existentialType)); - return LoweredValInfo::simple(getBuilder()->emitExtractExistentialWitnessTable(existentialVal)); - } - - LoweredValInfo visitTaggedUnionType(TaggedUnionType* type) - { - // A tagged union type will lower into an IR `union` over the cases, - // along with an IR `struct` with a field for the union and a tag. - // (Note: we are placing the tag after the payload to avoid padding - // in the case where the payload is more aligned than the tag) - // - // TODO: should we be lowering directly like this, or have - // an IR-level representation of tagged unions? - // - - List irCaseTypes; - for(auto caseType : type->caseTypes) - { - auto irCaseType = lowerType(context, caseType); - irCaseTypes.add(irCaseType); - } - - auto irType = getBuilder()->getTaggedUnionType(irCaseTypes); - if(!irType->findDecoration()) - { - // We need a way for later passes to attach layout information - // to this type, so we will give it a mangled name here. - // - getBuilder()->addExportDecoration( - irType, - getMangledTypeName(type).getUnownedSlice()); - } - return LoweredValInfo::simple(irType); - } - - LoweredValInfo visitExistentialSpecializedType(ExistentialSpecializedType* type) - { - auto irBaseType = lowerType(context, type->baseType); - - List slotArgs; - for(auto arg : type->slots.args) - { - auto irArgType = lowerType(context, arg.type); - auto irArgWitness = lowerSimpleVal(context, arg.witness); - - slotArgs.add(irArgType); - slotArgs.add(irArgWitness); - } - - auto irType = getBuilder()->getBindExistentialsType(irBaseType, slotArgs.getCount(), slotArgs.getBuffer()); - return LoweredValInfo::simple(irType); - } - - // We do not expect to encounter the following types in ASTs that have - // passed front-end semantic checking. -#define UNEXPECTED_CASE(NAME) IRType* visit##NAME(NAME*) { SLANG_UNEXPECTED(#NAME); UNREACHABLE_RETURN(nullptr); } - UNEXPECTED_CASE(GenericDeclRefType) - UNEXPECTED_CASE(TypeType) - UNEXPECTED_CASE(ErrorType) - UNEXPECTED_CASE(InitializerListType) - UNEXPECTED_CASE(OverloadGroupType) -}; - -LoweredValInfo lowerVal( - IRGenContext* context, - Val* val) -{ - ValLoweringVisitor visitor; - visitor.context = context; - return visitor.dispatch(val); -} - -IRType* lowerType( - IRGenContext* context, - Type* type) -{ - ValLoweringVisitor visitor; - visitor.context = context; - return (IRType*) getSimpleVal(context, visitor.dispatchType(type)); -} - -void addVarDecorations( - IRGenContext* context, - IRInst* inst, - Decl* decl) -{ - auto builder = context->irBuilder; - for(RefPtr mod : decl->modifiers) - { - if(as(mod)) - { - builder->addInterpolationModeDecoration(inst, IRInterpolationMode::NoInterpolation); - } - else if(as(mod)) - { - builder->addInterpolationModeDecoration(inst, IRInterpolationMode::NoPerspective); - } - else if(as(mod)) - { - builder->addInterpolationModeDecoration(inst, IRInterpolationMode::Linear); - } - else if(as(mod)) - { - builder->addInterpolationModeDecoration(inst, IRInterpolationMode::Sample); - } - else if(as(mod)) - { - builder->addInterpolationModeDecoration(inst, IRInterpolationMode::Centroid); - } - else if(as(mod)) - { - builder->addSimpleDecoration(inst); - } - else if(as(mod)) - { - builder->addSimpleDecoration(inst); - } - else if(as(mod)) - { - builder->addSimpleDecoration(inst); - } - else if(as(mod)) - { - builder->addSimpleDecoration(inst); - } - else if(as(mod)) - { - builder->addSimpleDecoration(inst); - } - else if(auto formatAttr = as(mod)) - { - builder->addFormatDecoration(inst, formatAttr->format); - } - - // TODO: what are other modifiers we need to propagate through? - } -} - -/// If `decl` has a modifier that should turn into a -/// rate qualifier, then apply it to `inst`. -void maybeSetRate( - IRGenContext* context, - IRInst* inst, - Decl* decl) -{ - auto builder = context->irBuilder; - - if (decl->HasModifier()) - { - inst->setFullType(builder->getRateQualifiedType( - builder->getGroupSharedRate(), - inst->getFullType())); - } -} - -static String getNameForNameHint( - IRGenContext* context, - Decl* decl) -{ - // We will use a bit of an ad hoc convention here for now. - - Name* leafName = decl->getName(); - - // Handle custom name for a global parameter group (e.g., a `cbuffer`) - if(auto reflectionNameModifier = decl->FindModifier()) - { - leafName = reflectionNameModifier->nameAndLoc.name; - } - - // There is no point in trying to provide a name hint for something with no name, - // or with an empty name - if(!leafName) - return String(); - if(leafName->text.getLength() == 0) - return String(); - - - if(auto varDecl = as(decl)) - { - // For an ordinary local variable, global variable, - // parameter, or field, we will just use the name - // as declared, and now work in anything from - // its parent declaration(s). - // - // TODO: consider whether global/static variables should - // follow different rules. - // - return leafName->text; - } - - // For other cases of declaration, we want to consider - // merging its name with the name of its parent declaration. - auto parentDecl = decl->ParentDecl; - - // Skip past a generic parent, if we are a declaration nested in a generic. - if(auto genericParentDecl = as(parentDecl)) - parentDecl = genericParentDecl->ParentDecl; - - // A `ModuleDecl` can have a name too, but in the common case - // we don't want to generate name hints that include the module - // name, simply because they would lead to every global symbol - // getting a much longer name. - // - // TODO: We should probably include the module name for symbols - // being `import`ed, and not for symbols being compiled directly - // (those coming from a module that had no name given to it). - // - // For now we skip past a `ModuleDecl` parent. - // - if(auto moduleParentDecl = as(parentDecl)) - parentDecl = moduleParentDecl->ParentDecl; - - if(!parentDecl) - { - return leafName->text; - } - - auto parentName = getNameForNameHint(context, parentDecl); - if(parentName.getLength() == 0) - { - return leafName->text; - } - - // We will now construct a new `Name` to use as the hint, - // combining the name of the parent and the leaf declaration. - - StringBuilder sb; - sb.append(parentName); - sb.append("."); - sb.append(leafName->text); - - return sb.ProduceString(); -} - -/// Try to add an appropriate name hint to the instruction, -/// that can be used for back-end code emission or debug info. -static void addNameHint( - IRGenContext* context, - IRInst* inst, - Decl* decl) -{ - String name = getNameForNameHint(context, decl); - if(name.getLength() == 0) - return; - context->irBuilder->addNameHintDecoration(inst, name.getUnownedSlice()); -} - -/// Add a name hint based on a fixed string. -static void addNameHint( - IRGenContext* context, - IRInst* inst, - char const* text) -{ - context->irBuilder->addNameHintDecoration(inst, UnownedTerminatedStringSlice(text)); -} - -LoweredValInfo createVar( - IRGenContext* context, - IRType* type, - Decl* decl = nullptr) -{ - auto builder = context->irBuilder; - auto irAlloc = builder->emitVar(type); - - if (decl) - { - maybeSetRate(context, irAlloc, decl); - - addVarDecorations(context, irAlloc, decl); - - builder->addHighLevelDeclDecoration(irAlloc, decl); - - addNameHint(context, irAlloc, decl); - } - - return LoweredValInfo::ptr(irAlloc); -} - -void addArgs( - IRGenContext* context, - List* ioArgs, - LoweredValInfo argInfo) -{ - auto& args = *ioArgs; - switch( argInfo.flavor ) - { - case LoweredValInfo::Flavor::Simple: - case LoweredValInfo::Flavor::Ptr: - case LoweredValInfo::Flavor::SwizzledLValue: - case LoweredValInfo::Flavor::BoundSubscript: - case LoweredValInfo::Flavor::BoundMember: - args.add(getSimpleVal(context, argInfo)); - break; - - default: - SLANG_UNIMPLEMENTED_X("addArgs case"); - break; - } -} - -// - -// When we try to turn a `LoweredValInfo` into an address of some temporary storage, -// we can either do it "aggressively" or not (what we'll call the "default" behavior, -// although it isn't strictly more common). -// -// The case that this is mostly there to address is when somebody writes an operation -// like: -// -// foo[a] = b; -// -// In that case, we might as well just use the `set` accessor if there is one, rather -// than complicate things. However, in more complex cases like: -// -// foo[a].x = b; -// -// there is no way to satisfy the semantics of the code the user wrote (in terms of -// only writing one vector component, and not a full vector) by using the `set` -// accessor, and we need to be "aggressive" in turning the lvalue `foo[a]` into -// an address. -// -// TODO: realistically IR lowering is too early to be binding to this choice, -// because different accessors might be supported on different targets. -// -enum class TryGetAddressMode -{ - Default, - Aggressive, -}; - -/// Try to coerce `inVal` into a `LoweredValInfo::ptr()` with a simple address. -LoweredValInfo tryGetAddress( - IRGenContext* context, - LoweredValInfo const& inVal, - TryGetAddressMode mode); - - -// - -template -struct ExprLoweringVisitorBase : ExprVisitor -{ - IRGenContext* context; - - IRBuilder* getBuilder() { return context->irBuilder; } - - // Lower an expression that should have the same l-value-ness - // as the visitor itself. - LoweredValInfo lowerSubExpr(Expr* expr) - { - IRBuilderSourceLocRAII sourceLocInfo(getBuilder(), expr->loc); - return this->dispatch(expr); - } - - - LoweredValInfo visitVarExpr(VarExpr* expr) - { - LoweredValInfo info = emitDeclRef( - context, - expr->declRef, - lowerType(context, expr->type)); - return info; - } - - LoweredValInfo visitOverloadedExpr(OverloadedExpr* /*expr*/) - { - SLANG_UNEXPECTED("overloaded expressions should not occur in checked AST"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - - LoweredValInfo visitOverloadedExpr2(OverloadedExpr2* /*expr*/) - { - SLANG_UNEXPECTED("overloaded expressions should not occur in checked AST"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - - LoweredValInfo visitIndexExpr(IndexExpr* expr) - { - auto type = lowerType(context, expr->type); - auto baseVal = lowerSubExpr(expr->BaseExpression); - auto indexVal = getSimpleVal(context, lowerRValueExpr(context, expr->IndexExpression)); - - return subscriptValue(type, baseVal, indexVal); - } - - LoweredValInfo visitThisExpr(ThisExpr* /*expr*/) - { - return context->thisVal; - } - - LoweredValInfo visitMemberExpr(MemberExpr* expr) - { - auto loweredType = lowerType(context, expr->type); - auto loweredBase = lowerRValueExpr(context, expr->BaseExpression); - - auto declRef = expr->declRef; - if (auto fieldDeclRef = declRef.as()) - { - // Okay, easy enough: we have a reference to a field of a struct type... - return extractField(loweredType, loweredBase, fieldDeclRef); - } - else if (auto callableDeclRef = declRef.as()) - { - RefPtr boundMemberInfo = new BoundMemberInfo(); - boundMemberInfo->type = nullptr; - boundMemberInfo->base = loweredBase; - boundMemberInfo->declRef = callableDeclRef; - return LoweredValInfo::boundMember(boundMemberInfo); - } - else if(auto constraintDeclRef = declRef.as()) - { - // The code is making use of a "witness" that a value of - // some generic type conforms to an interface. - // - // For now we will just emit the base expression as-is. - // TODO: we may need to insert an explicit instruction - // for a cast here (that could become a no-op later). - return loweredBase; - } - - SLANG_UNIMPLEMENTED_X("codegen for subscript expression"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - - // We will always lower a dereference expression (`*ptr`) - // as an l-value, since that is the easiest way to handle it. - LoweredValInfo visitDerefExpr(DerefExpr* expr) - { - auto loweredBase = lowerRValueExpr(context, expr->base); - - // TODO: handle tupel-type for `base` - - // The type of the lowered base must by some kind of pointer, - // in order for a dereference to make senese, so we just - // need to extract the value type from that pointer here. - // - IRInst* loweredBaseVal = getSimpleVal(context, loweredBase); - IRType* loweredBaseType = loweredBaseVal->getDataType(); - - if (as(loweredBaseType) - || as(loweredBaseType)) - { - // Note that we do *not* perform an actual `load` operation - // here, but rather just use the pointer value to construct - // an appropriate `LoweredValInfo` representing the underlying - // dereference. - // - // This is important so that an expression like `&((*foo).bar)` - // (which is desugared from `&foo->bar`) can be handled; such - // an expression does *not* perform a dereference at runtime, - // and is just a bit of pointer math. - // - return LoweredValInfo::ptr(loweredBaseVal); - } - else - { - SLANG_UNIMPLEMENTED_X("codegen for deref expression"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - } - - LoweredValInfo visitParenExpr(ParenExpr* expr) - { - return lowerSubExpr(expr->base); - } - - LoweredValInfo getSimpleDefaultVal(IRType* type) - { - if(auto basicType = as(type)) - { - switch( basicType->getBaseType() ) - { - default: - SLANG_UNEXPECTED("missing case for getting IR default value"); - UNREACHABLE_RETURN(LoweredValInfo()); - break; - - case BaseType::Bool: - case BaseType::Int8: - case BaseType::Int16: - case BaseType::Int: - case BaseType::Int64: - case BaseType::UInt8: - case BaseType::UInt16: - case BaseType::UInt: - case BaseType::UInt64: - return LoweredValInfo::simple(getBuilder()->getIntValue(type, 0)); - - case BaseType::Half: - case BaseType::Float: - case BaseType::Double: - return LoweredValInfo::simple(getBuilder()->getFloatValue(type, 0.0)); - } - } - - SLANG_UNEXPECTED("missing case for getting IR default value"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - - LoweredValInfo getDefaultVal(Type* type) - { - auto irType = lowerType(context, type); - if (auto basicType = as(type)) - { - return getSimpleDefaultVal(irType); - } - else if (auto vectorType = as(type)) - { - UInt elementCount = (UInt) GetIntVal(vectorType->elementCount); - - auto irDefaultValue = getSimpleVal(context, getDefaultVal(vectorType->elementType)); - - List args; - for(UInt ee = 0; ee < elementCount; ++ee) - { - args.add(irDefaultValue); - } - return LoweredValInfo::simple( - getBuilder()->emitMakeVector(irType, args.getCount(), args.getBuffer())); - } - else if (auto matrixType = as(type)) - { - UInt rowCount = (UInt) GetIntVal(matrixType->getRowCount()); - - auto rowType = matrixType->getRowType(); - - auto irDefaultValue = getSimpleVal(context, getDefaultVal(rowType)); - - List args; - for(UInt rr = 0; rr < rowCount; ++rr) - { - args.add(irDefaultValue); - } - return LoweredValInfo::simple( - getBuilder()->emitMakeMatrix(irType, args.getCount(), args.getBuffer())); - } - else if (auto arrayType = as(type)) - { - UInt elementCount = (UInt) GetIntVal(arrayType->ArrayLength); - - auto irDefaultElement = getSimpleVal(context, getDefaultVal(arrayType->baseType)); - - List args; - for(UInt ee = 0; ee < elementCount; ++ee) - { - args.add(irDefaultElement); - } - - return LoweredValInfo::simple( - getBuilder()->emitMakeArray(irType, args.getCount(), args.getBuffer())); - } - else if (auto declRefType = as(type)) - { - DeclRef declRef = declRefType->declRef; - if (auto aggTypeDeclRef = declRef.as()) - { - List args; - for (auto ff : getMembersOfType(aggTypeDeclRef)) - { - if (ff.getDecl()->HasModifier()) - continue; - - auto irFieldVal = getSimpleVal(context, getDefaultVal(ff)); - args.add(irFieldVal); - } - - return LoweredValInfo::simple( - getBuilder()->emitMakeStruct(irType, args.getCount(), args.getBuffer())); - } - } - - SLANG_UNEXPECTED("unexpected type when creating default value"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - - LoweredValInfo getDefaultVal(VarDeclBase* decl) - { - if(auto initExpr = decl->initExpr) - { - return lowerRValueExpr(context, initExpr); - } - else - { - return getDefaultVal(decl->type); - } - } - - LoweredValInfo visitInitializerListExpr(InitializerListExpr* expr) - { - // Allocate a temporary of the given type - auto type = expr->type; - IRType* irType = lowerType(context, type); - List args; - - UInt argCount = expr->args.getCount(); - - // If the initializer list was empty, then the user was - // asking for default initialization, which should apply - // to (almost) any type. - // - if(argCount == 0) - { - return getDefaultVal(type.type); - } - - // Now for each argument in the initializer list, - // fill in the appropriate field of the result - if (auto arrayType = as(type)) - { - UInt elementCount = (UInt) GetIntVal(arrayType->ArrayLength); - - for (UInt ee = 0; ee < argCount; ++ee) - { - auto argExpr = expr->args[ee]; - LoweredValInfo argVal = lowerRValueExpr(context, argExpr); - args.add(getSimpleVal(context, argVal)); - } - if(elementCount > argCount) - { - auto irDefaultValue = getSimpleVal(context, getDefaultVal(arrayType->baseType)); - for(UInt ee = argCount; ee < elementCount; ++ee) - { - args.add(irDefaultValue); - } - } - - return LoweredValInfo::simple( - getBuilder()->emitMakeArray(irType, args.getCount(), args.getBuffer())); - } - else if (auto vectorType = as(type)) - { - UInt elementCount = (UInt) GetIntVal(vectorType->elementCount); - - for (UInt ee = 0; ee < argCount; ++ee) - { - auto argExpr = expr->args[ee]; - LoweredValInfo argVal = lowerRValueExpr(context, argExpr); - args.add(getSimpleVal(context, argVal)); - } - if(elementCount > argCount) - { - auto irDefaultValue = getSimpleVal(context, getDefaultVal(vectorType->elementType)); - for(UInt ee = argCount; ee < elementCount; ++ee) - { - args.add(irDefaultValue); - } - } - - return LoweredValInfo::simple( - getBuilder()->emitMakeVector(irType, args.getCount(), args.getBuffer())); - } - else if (auto matrixType = as(type)) - { - UInt rowCount = (UInt) GetIntVal(matrixType->getRowCount()); - - for (UInt rr = 0; rr < argCount; ++rr) - { - auto argExpr = expr->args[rr]; - LoweredValInfo argVal = lowerRValueExpr(context, argExpr); - args.add(getSimpleVal(context, argVal)); - } - if(rowCount > argCount) - { - auto rowType = matrixType->getRowType(); - auto irDefaultValue = getSimpleVal(context, getDefaultVal(rowType)); - - for(UInt rr = argCount; rr < rowCount; ++rr) - { - args.add(irDefaultValue); - } - } - - return LoweredValInfo::simple( - getBuilder()->emitMakeMatrix(irType, args.getCount(), args.getBuffer())); - } - else if (auto declRefType = as(type)) - { - DeclRef declRef = declRefType->declRef; - if (auto aggTypeDeclRef = declRef.as()) - { - UInt argCounter = 0; - for (auto ff : getMembersOfType(aggTypeDeclRef)) - { - if (ff.getDecl()->HasModifier()) - continue; - - UInt argIndex = argCounter++; - if (argIndex < argCount) - { - auto argExpr = expr->args[argIndex]; - LoweredValInfo argVal = lowerRValueExpr(context, argExpr); - args.add(getSimpleVal(context, argVal)); - } - else - { - auto irDefaultValue = getSimpleVal(context, getDefaultVal(ff)); - args.add(irDefaultValue); - } - } - - return LoweredValInfo::simple( - getBuilder()->emitMakeStruct(irType, args.getCount(), args.getBuffer())); - } - } - - // If none of the above cases matched, then we had better - // have zero arguments in the initializer list, in which - // case we are just looking for default initialization. - // - SLANG_UNEXPECTED("unhandled case for initializer list codegen"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - - LoweredValInfo visitBoolLiteralExpr(BoolLiteralExpr* expr) - { - return LoweredValInfo::simple(context->irBuilder->getBoolValue(expr->value)); - } - - LoweredValInfo visitIntegerLiteralExpr(IntegerLiteralExpr* expr) - { - auto type = lowerType(context, expr->type); - return LoweredValInfo::simple(context->irBuilder->getIntValue(type, expr->value)); - } - - LoweredValInfo visitFloatingPointLiteralExpr(FloatingPointLiteralExpr* expr) - { - auto type = lowerType(context, expr->type); - return LoweredValInfo::simple(context->irBuilder->getFloatValue(type, expr->value)); - } - - LoweredValInfo visitStringLiteralExpr(StringLiteralExpr* expr) - { - return LoweredValInfo::simple(context->irBuilder->getStringValue(expr->value.getUnownedSlice())); - } - - LoweredValInfo visitAggTypeCtorExpr(AggTypeCtorExpr* /*expr*/) - { - SLANG_UNIMPLEMENTED_X("codegen for aggregate type constructor expression"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - - // After a call to a function with `out` or `in out` - // parameters, we may need to copy data back into - // the l-value locations used for output arguments. - // - // During lowering of the argument list, we build - // up a list of these "fixup" assignments that need - // to be performed. - struct OutArgumentFixup - { - LoweredValInfo dst; - LoweredValInfo src; - }; - - void addDirectCallArgs( - InvokeExpr* expr, - DeclRef funcDeclRef, - List* ioArgs, - List* ioFixups) - { - UInt argCount = expr->Arguments.getCount(); - UInt argCounter = 0; - for (auto paramDeclRef : getMembersOfType(funcDeclRef)) - { - auto paramDecl = paramDeclRef.getDecl(); - IRType* paramType = lowerType(context, GetType(paramDeclRef)); - - UInt argIndex = argCounter++; - RefPtr argExpr; - if(argIndex < argCount) - { - argExpr = expr->Arguments[argIndex]; - } - else - { - // We have run out of arguments supplied at the call site, - // but there are still parameters remaining. This must mean - // that these parameters have default argument expressions - // associated with them. - argExpr = getInitExpr(paramDeclRef); - - // Assert that such an expression must have been present. - SLANG_ASSERT(argExpr); - - // TODO: The approach we are taking here to default arguments - // is simplistic, and has consequences for the front-end as - // well as binary serialization of modules. - // - // We could consider some more refined approaches where, e.g., - // functions with default arguments generate multiple IR-level - // functions, that compute and provide the default values. - // - // Alternatively, each parameter with defaults could be generated - // into its own callable function that provides the default value, - // so that calling modules can call into a pre-generated function. - // - // Each of these options involves trade-offs, and we need to - // make a conscious decision at some point. - } - - if(paramDecl->HasModifier()) - { - // A `ref` qualified parameter must be implemented with by-reference - // parameter passing, so the argument value should be lowered as - // an l-value. - // - LoweredValInfo loweredArg = lowerLValueExpr(context, argExpr); - - // According to our "calling convention" we need to - // pass a pointer into the callee. Unlike the case for - // `out` and `inout` below, it is never valid to do - // copy-in/copy-out for a `ref` parameter, so we just - // pass in the actual pointer. - // - IRInst* argPtr = getAddress(context, loweredArg, argExpr->loc); - (*ioArgs).add(argPtr); - } - else if (paramDecl->HasModifier() - || paramDecl->HasModifier()) - { - // This is a `out` or `inout` parameter, and so - // the argument must be lowered as an l-value. - - LoweredValInfo loweredArg = lowerLValueExpr(context, argExpr); - - // According to our "calling convention" we need to - // pass a pointer into the callee. - // - // A naive approach would be to just take the address - // of `loweredArg` above and pass it in, but that - // has two issues: - // - // 1. The l-value might not be something that has a single - // well-defined "address" (e.g., `foo.xzy`). - // - // 2. The l-value argument might actually alias some other - // storage that the callee will access (e.g., we are - // passing in a global variable, or two `out` parameters - // are being passed the same location in an array). - // - // In each of these cases, the safe option is to create - // a temporary variable to use for argument-passing, - // and then do copy-in/copy-out around the call. - - LoweredValInfo tempVar = createVar(context, paramType); - - // If the parameter is `in out` or `inout`, then we need - // to ensure that we pass in the original value stored - // in the argument, which we accomplish by assigning - // from the l-value to our temp. - if (paramDecl->HasModifier() - || paramDecl->HasModifier()) - { - assign(context, tempVar, loweredArg); - } - - // Now we can pass the address of the temporary variable - // to the callee as the actual argument for the `in out` - SLANG_ASSERT(tempVar.flavor == LoweredValInfo::Flavor::Ptr); - (*ioArgs).add(tempVar.val); - - // Finally, after the call we will need - // to copy in the other direction: from our - // temp back to the original l-value. - OutArgumentFixup fixup; - fixup.src = tempVar; - fixup.dst = loweredArg; - - (*ioFixups).add(fixup); - - } - else - { - // This is a pure input parameter, and so we will - // pass it as an r-value. - LoweredValInfo loweredArg = lowerRValueExpr(context, argExpr); - addArgs(context, ioArgs, loweredArg); - } - } - } - - // Add arguments that appeared directly in an argument list - // to the list of argument values for a call. - void addDirectCallArgs( - InvokeExpr* expr, - DeclRef funcDeclRef, - List* ioArgs, - List* ioFixups) - { - if (auto callableDeclRef = funcDeclRef.as()) - { - addDirectCallArgs(expr, callableDeclRef, ioArgs, ioFixups); - } - else - { - SLANG_UNEXPECTED("callee was not a callable decl"); - } - } - - void addFuncBaseArgs( - LoweredValInfo funcVal, - List* ioArgs) - { - switch (funcVal.flavor) - { - default: - return; - } - } - - void applyOutArgumentFixups(List const& fixups) - { - for (auto fixup : fixups) - { - assign(context, fixup.dst, fixup.src); - } - } - - struct ResolvedCallInfo - { - DeclRef funcDeclRef; - Expr* baseExpr = nullptr; - }; - - // Try to resolve a the function expression for a call - // into a reference to a specific declaration, along - // with some contextual information about the declaration - // we are calling. - bool tryResolveDeclRefForCall( - RefPtr funcExpr, - ResolvedCallInfo* outInfo) - { - // TODO: unwrap any "identity" expressions that might - // be wrapping the callee. - - // First look to see if the expression references a - // declaration at all. - auto declRefExpr = as(funcExpr); - if(!declRefExpr) - return false; - - // A little bit of future proofing here: if we ever - // allow higher-order functions, then we might be - // calling through a variable/field that has a function - // type, but is not itself a function. - // In such a case we should be careful to not statically - // resolve things. - // - if(auto callableDecl = as(declRefExpr->declRef.getDecl())) - { - // Okay, the declaration is directly callable, so we can continue. - } - else - { - // The callee declaration isn't itself a callable (it must have - // a function type, though). - return false; - } - - // Now we can look at the specific kinds of declaration references, - // and try to tease them apart. - if (auto memberFuncExpr = as(funcExpr)) - { - outInfo->funcDeclRef = memberFuncExpr->declRef; - outInfo->baseExpr = memberFuncExpr->BaseExpression; - return true; - } - else if (auto staticMemberFuncExpr = as(funcExpr)) - { - outInfo->funcDeclRef = staticMemberFuncExpr->declRef; - return true; - } - else if (auto varExpr = as(funcExpr)) - { - outInfo->funcDeclRef = varExpr->declRef; - return true; - } - else - { - // Seems to be a case of declaration-reference we don't know about. - SLANG_UNEXPECTED("unknown declaration reference kind"); - return false; - } - } - - - LoweredValInfo visitInvokeExpr(InvokeExpr* expr) - { - auto type = lowerType(context, expr->type); - - // We are going to look at the syntactic form of - // the "function" expression, so that we can avoid - // a lot of complexity that would come from lowering - // it as a general expression first, and then trying - // to apply it. For example, given `obj.f(a,b)` we - // will try to detect that we are trying to compute - // something like `ObjType::f(obj, a, b)` (in pseudo-code), - // rather than trying to construct a meaningful - // intermediate value for `obj.f` first. - // - // Note that this doe not preclude having support - // for directly generating code from `obj.f` - it - // just may be that such usage is more complicated. - - // Along the way, we may end up collecting additional - // arguments that will be part of the call. - List irArgs; - - // We will also collect "fixup" actions that need - // to be performed after the call, in order to - // copy the final values for `out` parameters - // back to their arguments. - List argFixups; - - auto funcExpr = expr->FunctionExpr; - ResolvedCallInfo resolvedInfo; - if( tryResolveDeclRefForCall(funcExpr, &resolvedInfo) ) - { - // In this case we know exactly what declaration we - // are going to call, and so we can resolve things - // appropriately. - auto funcDeclRef = resolvedInfo.funcDeclRef; - auto baseExpr = resolvedInfo.baseExpr; - - // First comes the `this` argument if we are calling - // a member function: - if( baseExpr ) - { - auto loweredBaseVal = lowerRValueExpr(context, baseExpr); - addArgs(context, &irArgs, loweredBaseVal); - } - - // Then we have the "direct" arguments to the call. - // These may include `out` and `inout` arguments that - // require "fixup" work on the other side. - // - auto funcType = lowerType(context, funcExpr->type); - addDirectCallArgs(expr, funcDeclRef, &irArgs, &argFixups); - auto result = emitCallToDeclRef( - context, - type, - funcDeclRef, - funcType, - irArgs); - applyOutArgumentFixups(argFixups); - return result; - } - - // TODO: In this case we should be emitting code for the callee as - // an ordinary expression, then emitting the arguments according - // to the type information on the callee (e.g., which parameters - // are `out` or `inout`, and then finally emitting the `call` - // instruction. - // - // We don't currently have the case of emitting arguments according - // to function type info (instead of declaration info), and really - // this case can't occur unless we start adding first-class functions - // to the source language. - // - // For now we just bail out with an error. - // - SLANG_UNEXPECTED("could not resolve target declaration for call"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - - LoweredValInfo visitCastToInterfaceExpr( - CastToInterfaceExpr* expr) - { - // We have an expression that is "up-casting" some concrete value - // to an existential type (aka interface type), using a subtype witness - // (which will lower as a witness table) to show that the conversion - // is valid. - // - // At the IR level, this will become a `makeExistential` instruction, - // which collects the above information into a single IR-level value. - // A dynamic CPU implementation of Slang might encode an existential - // as a "fat pointer" representation, which includes a pointer to - // data for the concrete value, plus a pointer to the witness table. - // - // Note: if/when Slang supports more general existential types, such - // as compositions of interface (e.g., `IReadable & IWritable`), then - // we should probably extend the AST and IR mechanism here to accept - // a sequence of witness tables. - // - auto existentialType = lowerType(context, expr->type); - auto concreteValue = getSimpleVal(context, lowerRValueExpr(context, expr->valueArg)); - auto witnessTable = lowerSimpleVal(context, expr->witnessArg); - auto existentialValue = getBuilder()->emitMakeExistential(existentialType, concreteValue, witnessTable); - return LoweredValInfo::simple(existentialValue); - } - - LoweredValInfo subscriptValue( - IRType* type, - LoweredValInfo baseVal, - IRInst* indexVal) - { - auto builder = getBuilder(); - - // The `tryGetAddress` operation will take a complex value representation - // and try to turn it into a single pointer, if possible. - // - baseVal = tryGetAddress(context, baseVal, TryGetAddressMode::Aggressive); - - // The `materialize` operation should ensure that we only have to deal - // with the small number of base cases for lowered value representations. - // - baseVal = materialize(context, baseVal); - - switch (baseVal.flavor) - { - case LoweredValInfo::Flavor::Simple: - return LoweredValInfo::simple( - builder->emitElementExtract( - type, - getSimpleVal(context, baseVal), - indexVal)); - - case LoweredValInfo::Flavor::Ptr: - return LoweredValInfo::ptr( - builder->emitElementAddress( - context->irBuilder->getPtrType(type), - baseVal.val, - indexVal)); - - default: - SLANG_UNIMPLEMENTED_X("subscript expr"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - - } - - LoweredValInfo extractField( - IRType* fieldType, - LoweredValInfo base, - DeclRef field) - { - return Slang::extractField(context, fieldType, base, field); - } - - LoweredValInfo visitStaticMemberExpr(StaticMemberExpr* expr) - { - return emitDeclRef(context, expr->declRef, - lowerType(context, expr->type)); - } - - LoweredValInfo visitGenericAppExpr(GenericAppExpr* /*expr*/) - { - SLANG_UNIMPLEMENTED_X("generic application expression during code generation"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - - LoweredValInfo visitSharedTypeExpr(SharedTypeExpr* /*expr*/) - { - SLANG_UNIMPLEMENTED_X("shared type expression during code generation"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - - LoweredValInfo visitTaggedUnionTypeExpr(TaggedUnionTypeExpr* /*expr*/) - { - SLANG_UNIMPLEMENTED_X("tagged union type expression during code generation"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - - LoweredValInfo visitAssignExpr(AssignExpr* expr) - { - // Because our representation of lowered "values" - // can encompass l-values explicitly, we can - // lower assignment easily. We just lower the left- - // and right-hand sides, and then perform an assignment - // based on the resulting values. - // - auto leftVal = lowerLValueExpr(context, expr->left); - auto rightVal = lowerRValueExpr(context, expr->right); - assign(context, leftVal, rightVal); - - // The result value of the assignment expression is - // the value of the left-hand side (and it is expected - // to be an l-value). - return leftVal; - } - - LoweredValInfo visitLetExpr(LetExpr* expr) - { - // TODO: deal with the case where we might want to capture - // a reference to the bound value... - - auto initVal = lowerLValueExpr(context, expr->decl->initExpr); - setGlobalValue(context, expr->decl, initVal); - auto bodyVal = lowerSubExpr(expr->body); - return bodyVal; - } - - LoweredValInfo visitExtractExistentialValueExpr(ExtractExistentialValueExpr* expr) - { - auto existentialType = lowerType(context, GetType(expr->declRef)); - auto existentialVal = getSimpleVal(context, emitDeclRef(context, expr->declRef, existentialType)); - - auto openedType = lowerType(context, expr->type); - - return LoweredValInfo::simple(getBuilder()->emitExtractExistentialValue(openedType, existentialVal)); - } -}; - -struct LValueExprLoweringVisitor : ExprLoweringVisitorBase -{ - // When visiting a swizzle expression in an l-value context, - // we need to construct a "sizzled l-value." - LoweredValInfo visitSwizzleExpr(SwizzleExpr* expr) - { - auto irType = lowerType(context, expr->type); - auto loweredBase = lowerRValueExpr(context, expr->base); - - RefPtr swizzledLValue = new SwizzledLValueInfo(); - swizzledLValue->type = irType; - - UInt elementCount = (UInt)expr->elementCount; - swizzledLValue->elementCount = elementCount; - - // As a small optimization, we will detect if the base expression - // has also lowered into a swizzle and only return a single - // swizzle instead of nested swizzles. - // - // E.g., if we have input like `foo[i].zw.y` we should optimize it - // down to just `foo[i].w`. - // - if(loweredBase.flavor == LoweredValInfo::Flavor::SwizzledLValue) - { - auto baseSwizzleInfo = loweredBase.getSwizzledLValueInfo(); - - // Our new swizzle will use the same base expression (e.g., - // `foo[i]` in our example above), but will need to remap - // the swizzle indices it uses. - // - swizzledLValue->base = baseSwizzleInfo->base; - for (UInt ii = 0; ii < elementCount; ++ii) - { - // First we get the swizzle element of the "outer" swizzle, - // as it was written by the user. In our running example of - // `foo[i].zw.y` this is the `y` element reference. - // - UInt originalElementIndex = UInt(expr->elementIndices[ii]); - - // Next we will use that original element index to figure - // out which of the elements of the original swizzle this - // should map to. - // - // In our example, `y` means index 1, and so we fetch - // element 1 from the inner swizzle sequence `zw`, to get `w`. - // - SLANG_ASSERT(originalElementIndex < baseSwizzleInfo->elementCount); - UInt remappedElementIndex = baseSwizzleInfo->elementIndices[originalElementIndex]; - - swizzledLValue->elementIndices[ii] = remappedElementIndex; - } - } - else - { - // In the default case, we can just copy the indices being - // used for the swizzle over directly from the expression, - // and use the base as-is. - // - swizzledLValue->base = loweredBase; - for (UInt ii = 0; ii < elementCount; ++ii) - { - swizzledLValue->elementIndices[ii] = (UInt) expr->elementIndices[ii]; - } - } - - context->shared->extValues.add(swizzledLValue); - return LoweredValInfo::swizzledLValue(swizzledLValue); - } -}; - -struct RValueExprLoweringVisitor : ExprLoweringVisitorBase -{ - // A swizzle in an r-value context can save time by just - // emitting the swizzle instructions directly. - LoweredValInfo visitSwizzleExpr(SwizzleExpr* expr) - { - auto irType = lowerType(context, expr->type); - auto irBase = getSimpleVal(context, lowerRValueExpr(context, expr->base)); - - auto builder = getBuilder(); - - auto irIntType = getIntType(context); - - UInt elementCount = (UInt)expr->elementCount; - IRInst* irElementIndices[4]; - for (UInt ii = 0; ii < elementCount; ++ii) - { - irElementIndices[ii] = builder->getIntValue( - irIntType, - (IRIntegerValue)expr->elementIndices[ii]); - } - - auto irSwizzle = builder->emitSwizzle( - irType, - irBase, - elementCount, - &irElementIndices[0]); - - return LoweredValInfo::simple(irSwizzle); - } -}; - -LoweredValInfo lowerLValueExpr( - IRGenContext* context, - Expr* expr) -{ - IRBuilderSourceLocRAII sourceLocInfo(context->irBuilder, expr->loc); - - LValueExprLoweringVisitor visitor; - visitor.context = context; - return visitor.dispatch(expr); -} - -LoweredValInfo lowerRValueExpr( - IRGenContext* context, - Expr* expr) -{ - IRBuilderSourceLocRAII sourceLocInfo(context->irBuilder, expr->loc); - - RValueExprLoweringVisitor visitor; - visitor.context = context; - return visitor.dispatch(expr); -} - -struct StmtLoweringVisitor : StmtVisitor -{ - IRGenContext* context; - - IRBuilder* getBuilder() { return context->irBuilder; } - - void visitEmptyStmt(EmptyStmt*) - { - // Nothing to do. - } - - void visitUnparsedStmt(UnparsedStmt*) - { - SLANG_UNEXPECTED("UnparsedStmt not supported by IR"); - } - - void visitCaseStmtBase(CaseStmtBase*) - { - SLANG_UNEXPECTED("`case` or `default` not under `switch`"); - } - - void visitCompileTimeForStmt(CompileTimeForStmt* stmt) - { - // The user is asking us to emit code for the loop - // body for each value in the given integer range. - // For now, we will handle this by repeatedly lowering - // the body statement, with the loop variable bound - // to a different integer literal value each time. - // - // TODO: eventually we might handle this as just an - // ordinary loop, with an `[unroll]` attribute on - // it that we would respect. - - auto rangeBeginVal = GetIntVal(stmt->rangeBeginVal); - auto rangeEndVal = GetIntVal(stmt->rangeEndVal); - - if (rangeBeginVal >= rangeEndVal) - return; - - auto varDecl = stmt->varDecl; - auto varType = lowerType(context, varDecl->type); - - IRGenEnv subEnvStorage; - IRGenEnv* subEnv = &subEnvStorage; - subEnv->outer = context->env; - - IRGenContext subContextStorage = *context; - IRGenContext* subContext = &subContextStorage; - subContext->env = subEnv; - - - - for (IntegerLiteralValue ii = rangeBeginVal; ii < rangeEndVal; ++ii) - { - auto constVal = getBuilder()->getIntValue( - varType, - ii); - - subEnv->mapDeclToValue[varDecl] = LoweredValInfo::simple(constVal); - - lowerStmt(subContext, stmt->body); - } - } - - // Create a basic block in the current function, - // so that it can be used for a label. - IRBlock* createBlock() - { - return getBuilder()->createBlock(); - } - - /// Does the given block have a terminator? - bool isBlockTerminated(IRBlock* block) - { - return block->getTerminator() != nullptr; - } - - /// Emit a branch to the target block if the current - /// block being inserted into is not already terminated. - void emitBranchIfNeeded(IRBlock* targetBlock) - { - auto builder = getBuilder(); - auto currentBlock = builder->getBlock(); - - // Don't emit if there is no current block. - if(!currentBlock) - return; - - // Don't emit if the block already has a terminator. - if(isBlockTerminated(currentBlock)) - return; - - // The block is unterminated, so cap it off with - // a terminator that branches to the target. - builder->emitBranch(targetBlock); - } - - /// Insert a block at the current location (ending - /// the previous block with an unconditional jump - /// if needed). - void insertBlock(IRBlock* block) - { - auto builder = getBuilder(); - - auto prevBlock = builder->getBlock(); - auto parentFunc = prevBlock ? prevBlock->getParent() : builder->getFunc(); - - // If the previous block doesn't already have - // a terminator instruction, then be sure to - // emit a branch to the new block. - emitBranchIfNeeded(block); - - // Add the new block to the function we are building, - // and setit as the block we will be inserting into. - parentFunc->addBlock(block); - builder->setInsertInto(block); - } - - // Start a new block at the current location. - // This is just the composition of `createBlock` - // and `insertBlock`. - IRBlock* startBlock() - { - auto block = createBlock(); - insertBlock(block); - return block; - } - - /// Start a new block if there isn't a current - /// block that we can append to. - /// - /// The `stmt` parameter is the statement we - /// are about to emit. - void startBlockIfNeeded(Stmt* stmt) - { - auto builder = getBuilder(); - auto currentBlock = builder->getBlock(); - - // If there is a current block and it hasn't - // been terminated, then we can just use that. - if(currentBlock && !isBlockTerminated(currentBlock)) - { - return; - } - - // We are about to emit code *after* a terminator - // instruction, and there is no label to allow - // branching into this code, so whatever we are - // about to emit is going to be unreachable. - // - // Let's diagnose that here just to help the user. - // - // TODO: We might want to have a more robust check - // for unreachable code based on IR analysis instead, - // at which point we'd probably disable this check. - // - context->getSink()->diagnose(stmt, Diagnostics::unreachableCode); - - startBlock(); - } - - void visitIfStmt(IfStmt* stmt) - { - auto builder = getBuilder(); - startBlockIfNeeded(stmt); - - auto condExpr = stmt->Predicate; - auto thenStmt = stmt->PositiveStatement; - auto elseStmt = stmt->NegativeStatement; - - auto irCond = getSimpleVal(context, - lowerRValueExpr(context, condExpr)); - - if (elseStmt) - { - auto thenBlock = createBlock(); - auto elseBlock = createBlock(); - auto afterBlock = createBlock(); - - builder->emitIfElse(irCond, thenBlock, elseBlock, afterBlock); - - insertBlock(thenBlock); - lowerStmt(context, thenStmt); - emitBranchIfNeeded(afterBlock); - - insertBlock(elseBlock); - lowerStmt(context, elseStmt); - - insertBlock(afterBlock); - } - else - { - auto thenBlock = createBlock(); - auto afterBlock = createBlock(); - - builder->emitIf(irCond, thenBlock, afterBlock); - - insertBlock(thenBlock); - lowerStmt(context, thenStmt); - - insertBlock(afterBlock); - } - } - - void addLoopDecorations( - IRInst* inst, - Stmt* stmt) - { - if( stmt->FindModifier() ) - { - getBuilder()->addLoopControlDecoration(inst, kIRLoopControl_Unroll); - } - // TODO: handle other cases here - } - - void visitForStmt(ForStmt* stmt) - { - auto builder = getBuilder(); - startBlockIfNeeded(stmt); - - // The initializer clause for the statement - // can always safetly be emitted to the current block. - if (auto initStmt = stmt->InitialStatement) - { - lowerStmt(context, initStmt); - } - - // We will create blocks for the various places - // we need to jump to inside the control flow, - // including the blocks that will be referenced - // by `continue` or `break` statements. - auto loopHead = createBlock(); - auto bodyLabel = createBlock(); - auto breakLabel = createBlock(); - auto continueLabel = createBlock(); - - // Register the `break` and `continue` labels so - // that we can find them for nested statements. - context->shared->breakLabels.Add(stmt, breakLabel); - context->shared->continueLabels.Add(stmt, continueLabel); - - // Emit the branch that will start out loop, - // and then insert the block for the head. - - auto loopInst = builder->emitLoop( - loopHead, - breakLabel, - continueLabel); - - addLoopDecorations(loopInst, stmt); - - insertBlock(loopHead); - - // Now that we are within the header block, we - // want to emit the expression for the loop condition: - if (auto condExpr = stmt->PredicateExpression) - { - auto irCondition = getSimpleVal(context, - lowerRValueExpr(context, stmt->PredicateExpression)); - - // Now we want to `break` if the loop condition is false. - builder->emitLoopTest( - irCondition, - bodyLabel, - breakLabel); - } - - // Emit the body of the loop - insertBlock(bodyLabel); - lowerStmt(context, stmt->Statement); - - // Insert the `continue` block - insertBlock(continueLabel); - if (auto incrExpr = stmt->SideEffectExpression) - { - lowerRValueExpr(context, incrExpr); - } - - // At the end of the body we need to jump back to the top. - emitBranchIfNeeded(loopHead); - - // Finally we insert the label that a `break` will jump to - insertBlock(breakLabel); - } - - void visitWhileStmt(WhileStmt* stmt) - { - // Generating IR for `while` statement is similar to a - // `for` statement, but without a lot of the complications. - - auto builder = getBuilder(); - startBlockIfNeeded(stmt); - - // We will create blocks for the various places - // we need to jump to inside the control flow, - // including the blocks that will be referenced - // by `continue` or `break` statements. - auto loopHead = createBlock(); - auto bodyLabel = createBlock(); - auto breakLabel = createBlock(); - - // A `continue` inside a `while` loop always - // jumps to the head of hte loop. - auto continueLabel = loopHead; - - // Register the `break` and `continue` labels so - // that we can find them for nested statements. - context->shared->breakLabels.Add(stmt, breakLabel); - context->shared->continueLabels.Add(stmt, continueLabel); - - // Emit the branch that will start out loop, - // and then insert the block for the head. - - auto loopInst = builder->emitLoop( - loopHead, - breakLabel, - continueLabel); - - addLoopDecorations(loopInst, stmt); - - insertBlock(loopHead); - - // Now that we are within the header block, we - // want to emit the expression for the loop condition: - if (auto condExpr = stmt->Predicate) - { - auto irCondition = getSimpleVal(context, - lowerRValueExpr(context, condExpr)); - - // Now we want to `break` if the loop condition is false. - builder->emitLoopTest( - irCondition, - bodyLabel, - breakLabel); - } - - // Emit the body of the loop - insertBlock(bodyLabel); - lowerStmt(context, stmt->Statement); - - // At the end of the body we need to jump back to the top. - emitBranchIfNeeded(loopHead); - - // Finally we insert the label that a `break` will jump to - insertBlock(breakLabel); - } - - void visitDoWhileStmt(DoWhileStmt* stmt) - { - // Generating IR for `do {...} while` statement is similar to a - // `while` statement, just with the test in a different place - - auto builder = getBuilder(); - startBlockIfNeeded(stmt); - - // We will create blocks for the various places - // we need to jump to inside the control flow, - // including the blocks that will be referenced - // by `continue` or `break` statements. - auto loopHead = createBlock(); - auto testLabel = createBlock(); - auto breakLabel = createBlock(); - - // A `continue` inside a `do { ... } while ( ... )` loop always - // jumps to the loop test. - auto continueLabel = testLabel; - - // Register the `break` and `continue` labels so - // that we can find them for nested statements. - context->shared->breakLabels.Add(stmt, breakLabel); - context->shared->continueLabels.Add(stmt, continueLabel); - - // Emit the branch that will start out loop, - // and then insert the block for the head. - - auto loopInst = builder->emitLoop( - loopHead, - breakLabel, - continueLabel); - - addLoopDecorations(loopInst, stmt); - - insertBlock(loopHead); - - // Emit the body of the loop - lowerStmt(context, stmt->Statement); - - insertBlock(testLabel); - - // Now that we are within the header block, we - // want to emit the expression for the loop condition: - if (auto condExpr = stmt->Predicate) - { - auto irCondition = getSimpleVal(context, - lowerRValueExpr(context, condExpr)); - - // Now we want to `break` if the loop condition is false, - // otherwise we will jump back to the head of the loop. - builder->emitLoopTest( - irCondition, - loopHead, - breakLabel); - } - - // Finally we insert the label that a `break` will jump to - insertBlock(breakLabel); - } - - void visitExpressionStmt(ExpressionStmt* stmt) - { - startBlockIfNeeded(stmt); - - // The statement evaluates an expression - // (for side effects, one assumes) and then - // discards the result. As such, we simply - // lower the expression, and don't use - // the result. - // - // Note that we lower using the l-value path, - // so that an expression statement that names - // a location (but doesn't load from it) - // will not actually emit a load. - lowerLValueExpr(context, stmt->Expression); - } - - void visitDeclStmt(DeclStmt* stmt) - { - startBlockIfNeeded(stmt); - - // For now, we lower a declaration directly - // into the current context. - // - // TODO: We may want to consider whether - // nested type/function declarations should - // be lowered into the global scope during - // IR generation, or whether they should - // be lifted later (pushing capture analysis - // down to the IR). - // - lowerDecl(context, stmt->decl); - } - - void visitSeqStmt(SeqStmt* stmt) - { - // To lower a sequence of statements, - // just lower each in order - for (auto ss : stmt->stmts) - { - lowerStmt(context, ss); - } - } - - void visitBlockStmt(BlockStmt* stmt) - { - // To lower a block (scope) statement, - // just lower its body. The IR doesn't - // need to reflect the scoping of the AST. - lowerStmt(context, stmt->body); - } - - void visitReturnStmt(ReturnStmt* stmt) - { - startBlockIfNeeded(stmt); - - // A `return` statement turns into a return - // instruction. If the statement had an argument - // expression, then we need to lower that to - // a value first, and then emit the resulting value. - if( auto expr = stmt->Expression ) - { - auto loweredExpr = lowerRValueExpr(context, expr); - - getBuilder()->emitReturn(getSimpleVal(context, loweredExpr)); - } - else - { - getBuilder()->emitReturn(); - } - } - - void visitDiscardStmt(DiscardStmt* stmt) - { - startBlockIfNeeded(stmt); - getBuilder()->emitDiscard(); - } - - void visitBreakStmt(BreakStmt* stmt) - { - startBlockIfNeeded(stmt); - - // Semantic checking is responsible for finding - // the statement taht this `break` breaks out of - auto parentStmt = stmt->parentStmt; - SLANG_ASSERT(parentStmt); - - // We just need to look up the basic block that - // corresponds to the break label for that statement, - // and then emit an instruction to jump to it. - IRBlock* targetBlock = nullptr; - context->shared->breakLabels.TryGetValue(parentStmt, targetBlock); - SLANG_ASSERT(targetBlock); - getBuilder()->emitBreak(targetBlock); - } - - void visitContinueStmt(ContinueStmt* stmt) - { - startBlockIfNeeded(stmt); - - // Semantic checking is responsible for finding - // the loop that this `continue` statement continues - auto parentStmt = stmt->parentStmt; - SLANG_ASSERT(parentStmt); - - - // We just need to look up the basic block that - // corresponds to the continue label for that statement, - // and then emit an instruction to jump to it. - IRBlock* targetBlock = nullptr; - context->shared->continueLabels.TryGetValue(parentStmt, targetBlock); - SLANG_ASSERT(targetBlock); - getBuilder()->emitContinue(targetBlock); - } - - // Lowering a `switch` statement can get pretty involved, - // so we need to track a bit of extra data: - struct SwitchStmtInfo - { - // The block that will be made to contain the `switch` statement - IRBlock* initialBlock = nullptr; - - // The label for the `default` case, if any. - IRBlock* defaultLabel = nullptr; - - // The label of the current "active" case block. - IRBlock* currentCaseLabel = nullptr; - - // Has anything been emitted to the current "active" case block? - bool anythingEmittedToCurrentCaseBlock = false; - - // The collected (value, label) pairs for - // all the `case` statements. - List cases; - }; - - // We need a label to use for a `case` or `default` statement, - // so either create one here, or re-use the current one if - // that is okay. - IRBlock* getLabelForCase(SwitchStmtInfo* info) - { - // Look at the "current" label we are working with. - auto currentCaseLabel = info->currentCaseLabel; - - // If there is a current block, and it is empty, - // then it is still a viable target (we are in - // a case of "trivial fall-through" from the previous - // block). - if(currentCaseLabel && !info->anythingEmittedToCurrentCaseBlock) - { - return currentCaseLabel; - } - - // Othwerise, we need to start a new block and use that. - IRBlock* newCaseLabel = createBlock(); - - // Note: if the previous block failed - // to end with a `break`, then inserting - // this block will append an unconditional - // branch to the end of it that will target - // this block. - insertBlock(newCaseLabel); - - info->currentCaseLabel = newCaseLabel; - info->anythingEmittedToCurrentCaseBlock = false; - return newCaseLabel; - } - - // Given a statement that appears as (or in) the body - // of a `switch` statement - void lowerSwitchCases(Stmt* inStmt, SwitchStmtInfo* info) - { - // TODO: in the general case (e.g., if we were going - // to eventual lower to an unstructured format like LLVM), - // the Right Way to handle C-style `switch` statements - // is just to emit the body directly as "normal" statements, - // and then treat `case` and `default` as special statements - // that start a new block and register a label with the - // enclosing `switch`. - // - // For now we will assume that any `case` and `default` - // statements need to be directly nested under the `switch`, - // and so we can find them with a simpler walk. - - Stmt* stmt = inStmt; - - // Unwrap any surrounding `{ ... }` so we can look - // at the statement inside. - while(auto blockStmt = as(stmt)) - { - stmt = blockStmt->body; - continue; - } - - if(auto seqStmt = as(stmt)) - { - // Walk through teh children and process each. - for(auto childStmt : seqStmt->stmts) - { - lowerSwitchCases(childStmt, info); - } - } - else if(auto caseStmt = as(stmt)) - { - // A full `case` statement has a value we need - // to test against. It is expected to be a - // compile-time constant, so we will emit - // it like an expression here, and then hope - // for the best. - // - // TODO: figure out something cleaner. - - // Actually, one gotcha is that if we ever allow non-constant - // expressions here (or anything that requires instructions - // to be emitted to yield its value), then those instructions - // need to go into an appropriate block. - - IRGenContext subContext = *context; - IRBuilder subBuilder = *getBuilder(); - subBuilder.setInsertInto(info->initialBlock); - subContext.irBuilder = &subBuilder; - auto caseVal = getSimpleVal(context, lowerRValueExpr(&subContext, caseStmt->expr)); - - // Figure out where we are branching to. - auto label = getLabelForCase(info); - - // Add this `case` to the list for the enclosing `switch`. - info->cases.add(caseVal); - info->cases.add(label); - } - else if(auto defaultStmt = as(stmt)) - { - auto label = getLabelForCase(info); - - // We expect to only find a single `default` stmt. - SLANG_ASSERT(!info->defaultLabel); - - info->defaultLabel = label; - } - else if(auto emptyStmt = as(stmt)) - { - // Special-case empty statements so they don't - // mess up our "trivial fall-through" optimization. - } - else - { - // We have an ordinary statement, that needs to get - // emitted to the current case block. - if(!info->currentCaseLabel) - { - // It possible in full C/C++ to have statements - // before the first `case`. Usually these are - // unreachable, unless they start with a label. - // - // We'll ignore them here, figuring they are - // dead. If we ever add `LabelStmt` then we'd - // need to emit these statements to a dummy - // block just in case. - } - else - { - // Emit the code to our current case block, - // and record that we've done so. - lowerStmt(context, stmt); - info->anythingEmittedToCurrentCaseBlock = true; - } - } - } - - void visitSwitchStmt(SwitchStmt* stmt) - { - auto builder = getBuilder(); - startBlockIfNeeded(stmt); - - // Given a statement: - // - // switch( CONDITION ) - // { - // case V0: - // S0; - // break; - // - // case V1: - // default: - // S1; - // break; - // } - // - // we want to generate IR like: - // - // let %c = ; - // switch %c, // value to switch on - // %breakLabel, // join point (and break target) - // %s1, // default label - // %v0, // first case value - // %s0, // first case label - // %v1, // second case value - // %s1 // second case label - // s0: - // - // break %breakLabel - // s1: - // - // break %breakLabel - // breakLabel: - // - - // First emit code to compute the condition: - auto conditionVal = getSimpleVal(context, lowerRValueExpr(context, stmt->condition)); - - // Remember the initial block so that we can add to it - // after we've collected all the `case`s - auto initialBlock = builder->getBlock(); - - // Next, create a block to use as the target for any `break` statements - auto breakLabel = createBlock(); - - // Register the `break` label so - // that we can find it for nested statements. - context->shared->breakLabels.Add(stmt, breakLabel); - - builder->setInsertInto(initialBlock->getParent()); - - // Iterate over the body of the statement, looking - // for `case` or `default` statements: - SwitchStmtInfo info; - info.initialBlock = initialBlock; - info.defaultLabel = nullptr; - lowerSwitchCases(stmt->body, &info); - - // TODO: once we've discovered the cases, we should - // be able to make a quick pass over the list and eliminate - // any cases that have the exact same label as the `default` - // case, since these don't actually need to be represented. - - // If the current block (the end of the last - // `case`) is not terminated, then terminate with a - // `break` operation. - // - // Double check that we aren't in the initial - // block, so we don't get tripped up on an - // empty `switch`. - auto curBlock = builder->getBlock(); - if(curBlock != initialBlock) - { - // Is the block already terminated? - if(!curBlock->getTerminator()) - { - // Not terminated, so add one. - builder->emitBreak(breakLabel); - } - } - - // If there was no `default` statement, then the - // default case will just branch directly to the end. - auto defaultLabel = info.defaultLabel ? info.defaultLabel : breakLabel; - - // Now that we've collected the cases, we are - // prepared to emit the `switch` instruction - // itself. - builder->setInsertInto(initialBlock); - builder->emitSwitch( - conditionVal, - breakLabel, - defaultLabel, - info.cases.getCount(), - info.cases.getBuffer()); - - // Finally we insert the label that a `break` will jump to - // (and that control flow will fall through to otherwise). - // This is the block that subsequent code will go into. - insertBlock(breakLabel); - context->shared->breakLabels.Remove(stmt); - } -}; - -void lowerStmt( - IRGenContext* context, - Stmt* stmt) -{ - IRBuilderSourceLocRAII sourceLocInfo(context->irBuilder, stmt->loc); - - StmtLoweringVisitor visitor; - visitor.context = context; - - try - { - visitor.dispatch(stmt); - } - // Don't emit any context message for an explicit `AbortCompilationException` - // because it should only happen when an error is already emitted. - catch(AbortCompilationException&) { throw; } - catch(...) - { - context->getSink()->noteInternalErrorLoc(stmt->loc); - throw; - } -} - -/// Create and return a mutable temporary initialized with `val` -static LoweredValInfo moveIntoMutableTemp( - IRGenContext* context, - LoweredValInfo const& val) -{ - IRInst* irVal = getSimpleVal(context, val); - auto type = irVal->getDataType(); - auto var = createVar(context, type); - - assign(context, var, LoweredValInfo::simple(irVal)); - return var; -} - -LoweredValInfo tryGetAddress( - IRGenContext* context, - LoweredValInfo const& inVal, - TryGetAddressMode mode) -{ - LoweredValInfo val = inVal; - - switch(val.flavor) - { - case LoweredValInfo::Flavor::Ptr: - // The `Ptr` case means that we already have an IR value with - // the address of our value. Easy! - return val; - - case LoweredValInfo::Flavor::BoundSubscript: - { - // If we are are trying to turn a subscript operation like `buffer[index]` - // into a pointer, then we need to find a `ref` accessor declared - // as part of the subscript operation being referenced. - // - auto subscriptInfo = val.getBoundSubscriptInfo(); - - // We don't want to immediately bind to a `ref` accessor if there is - // a `set` accessor available, unless we are in an "aggressive" mode - // where we really want/need a pointer to be able to make progress. - // - if(mode != TryGetAddressMode::Aggressive - && getMembersOfType(subscriptInfo->declRef).Count()) - { - // There is a setter that we should consider using, - // so don't go and aggressively collapse things just yet. - return val; - } - - auto refAccessors = getMembersOfType(subscriptInfo->declRef); - if(refAccessors.Count()) - { - // The `ref` accessor will return a pointer to the value, so - // we need to reflect that in the type of our `call` instruction. - IRType* ptrType = context->irBuilder->getPtrType(subscriptInfo->type); - - LoweredValInfo refVal = emitCallToDeclRef( - context, - ptrType, - *refAccessors.begin(), - nullptr, - subscriptInfo->args); - - // The result from the call should be a pointer, and it - // is the address that we wanted in the first place. - return LoweredValInfo::ptr(getSimpleVal(context, refVal)); - } - - // Otherwise, there was no `ref` accessor, and so it is not possible - // to materialize this location into a pointer for whatever purpose - // we have in mind (e.g., passing it to an atomic operation). - } - break; - - case LoweredValInfo::Flavor::BoundMember: - { - auto boundMemberInfo = val.getBoundMemberInfo(); - - // If we hit this case, then it means that we have a reference - // to a single field in something, but for whatever reason the - // higher-level logic was not able to turn it into a pointer - // already (maybe the base value for the field reference is - // a `BoundSubscript`, etc.). - // - // We need to read the entire base value out, modify the field - // we care about, and then write it back. - - auto declRef = boundMemberInfo->declRef; - if( auto fieldDeclRef = declRef.as() ) - { - auto baseVal = boundMemberInfo->base; - auto basePtr = tryGetAddress(context, baseVal, TryGetAddressMode::Aggressive); - - return extractField(context, boundMemberInfo->type, basePtr, fieldDeclRef); - } - - } - break; - - case LoweredValInfo::Flavor::SwizzledLValue: - { - auto originalSwizzleInfo = val.getSwizzledLValueInfo(); - auto originalBase = originalSwizzleInfo->base; - - UInt elementCount = originalSwizzleInfo->elementCount; - - auto newBase = tryGetAddress(context, originalBase, TryGetAddressMode::Aggressive); - RefPtr newSwizzleInfo = new SwizzledLValueInfo(); - context->shared->extValues.add(newSwizzleInfo); - - newSwizzleInfo->base = newBase; - newSwizzleInfo->type = originalSwizzleInfo->type; - newSwizzleInfo->elementCount = elementCount; - for(UInt ee = 0; ee < elementCount; ++ee) - newSwizzleInfo->elementIndices[ee] = originalSwizzleInfo->elementIndices[ee]; - - return LoweredValInfo::swizzledLValue(newSwizzleInfo); - } - break; - - // TODO: are there other cases we need to handled here? - - default: - break; - } - - // If none of the special cases above applied, then we werent' able to make - // this value into a pointer, and we should just return it as-is. - return val; -} - -IRInst* getAddress( - IRGenContext* context, - LoweredValInfo const& inVal, - SourceLoc diagnosticLocation) -{ - LoweredValInfo val = tryGetAddress(context, inVal, TryGetAddressMode::Aggressive); - - if( val.flavor == LoweredValInfo::Flavor::Ptr ) - { - return val.val; - } - - context->getSink()->diagnose(diagnosticLocation, Diagnostics::invalidLValueForRefParameter); - return nullptr; -} - -void assign( - IRGenContext* context, - LoweredValInfo const& inLeft, - LoweredValInfo const& inRight) -{ - LoweredValInfo left = inLeft; - LoweredValInfo right = inRight; - - // Before doing the case analysis on the shape of the `left` value, - // we might as well go ahead and see if we can coerce it into - // a simple pointer, since that would make our life a lot easier - // when handling complex cases. - // - left = tryGetAddress(context, left, TryGetAddressMode::Default); - - auto builder = context->irBuilder; - -top: - switch (left.flavor) - { - case LoweredValInfo::Flavor::Ptr: - { - // The `left` value is just a pointer, so we can emit - // a store to it directly. - // - builder->emitStore( - left.val, - getSimpleVal(context, right)); - } - break; - - case LoweredValInfo::Flavor::SwizzledLValue: - { - // The `left` value is of the form `.`. - // How we will handle this depends on what `base` looks like: - auto swizzleInfo = left.getSwizzledLValueInfo(); - auto loweredBase = swizzleInfo->base; - - // Note that the call to `tryGetAddress` at the start should - // ensure that `loweredBase` has been simplified as much as - // possible (e.g., if it is possible to turn it into a - // `LoweredValInfo::ptr()` then that will have been done). - - switch( loweredBase.flavor ) - { - default: - { - // Our fallback position is to lower via a temporary, e.g.: - // - // float4 tmp = ; - // tmp.xyz = float3(...); - // = tmp; - // - - // Load from the base value - IRInst* irLeftVal = getSimpleVal(context, loweredBase); - - // Extract a simple value for the right-hand side - IRInst* irRightVal = getSimpleVal(context, right); - - // Apply the swizzle - IRInst* irSwizzled = builder->emitSwizzleSet( - irLeftVal->getDataType(), - irLeftVal, - irRightVal, - swizzleInfo->elementCount, - swizzleInfo->elementIndices); - - // And finally, store the value back where we got it. - // - // Note: this is effectively a recursive call to - // `assign()`, so we do a simple tail-recursive call here. - left = loweredBase; - right = LoweredValInfo::simple(irSwizzled); - goto top; - } - break; - - case LoweredValInfo::Flavor::Ptr: - { - // We are writing through a pointer, which might be - // pointing into a UAV or other memory resource, so - // we can't introduce use a temporary like the case - // above, because then we would read and write bytes - // that are not strictly required for the store. - // - // Note that the messy case of a "swizzle of a swizzle" - // was handled already in lowering of a `SwizzleExpr`, - // so that we don't need to deal with that case here. - // - // TODO: we may need to consider whether there is - // enough value in a masked store like this to keep - // it around, in comparison to a simpler model where - // we simply form a pointer to each of the vector - // elements and write to them individually. - // - // TODO: we might also consider just special-casing - // single-element swizzles so that the common case - // can turn into a simple `store` instead of a - // `swizzledStore`. - // - IRInst* irRightVal = getSimpleVal(context, right); - builder->emitSwizzledStore( - loweredBase.val, - irRightVal, - swizzleInfo->elementCount, - swizzleInfo->elementIndices); - } - break; - } - } - break; - - case LoweredValInfo::Flavor::BoundSubscript: - { - // The `left` value refers to a subscript operation on - // a resource type, bound to particular arguments, e.g.: - // `someStructuredBuffer[index]`. - // - // When storing to such a value, we need to emit a call - // to the appropriate builtin "setter" accessor, if there - // is one, and then fall back to a `ref` accessor if - // there is no setter. - // - auto subscriptInfo = left.getBoundSubscriptInfo(); - - // Search for an appropriate "setter" declaration - auto setters = getMembersOfType(subscriptInfo->declRef); - if (setters.Count()) - { - auto allArgs = subscriptInfo->args; - addArgs(context, &allArgs, right); - - emitCallToDeclRef( - context, - builder->getVoidType(), - *setters.begin(), - nullptr, - allArgs); - return; - } - - auto refAccessors = getMembersOfType(subscriptInfo->declRef); - if(refAccessors.Count()) - { - // The `ref` accessor will return a pointer to the value, so - // we need to reflect that in the type of our `call` instruction. - IRType* ptrType = context->irBuilder->getPtrType(subscriptInfo->type); - - LoweredValInfo refVal = emitCallToDeclRef( - context, - ptrType, - *refAccessors.begin(), - nullptr, - subscriptInfo->args); - - // The result from the call needs to be implicitly dereferenced, - // so that it can work as an l-value of the desired result type. - left = LoweredValInfo::ptr(getSimpleVal(context, refVal)); - - // Tail-recursively attempt assignment again on the new l-value. - goto top; - } - - // No setter found? Then we have an error! - SLANG_UNEXPECTED("no setter found"); - break; - } - break; - - case LoweredValInfo::Flavor::BoundMember: - { - auto boundMemberInfo = left.getBoundMemberInfo(); - - // If we hit this case, then it means that we are trying to set - // a single field in someting that is not atomically set-able. - // (e.g., an element of a value where the `subscript` operation - // has `get` and `set` but not a `ref` accessor). - // - // We need to read the entire base value out, modify the field - // we care about, and then write it back. - - auto declRef = boundMemberInfo->declRef; - if( auto fieldDeclRef = declRef.as() ) - { - // materialize the base value and move it into - // a mutable temporary if needed - auto baseVal = boundMemberInfo->base; - auto tempVal = moveIntoMutableTemp(context, baseVal); - - // extract the field l-value out of the temporary - auto tempFieldVal = extractField(context, boundMemberInfo->type, tempVal, fieldDeclRef); - - // assign to the field of the temporary l-value - assign(context, tempFieldVal, right); - - // write back the modified temporary to the base l-value - assign(context, baseVal, tempVal); - - return; - } - else - { - SLANG_UNEXPECTED("handled member flavor"); - } - - } - break; - - default: - SLANG_UNIMPLEMENTED_X("assignment"); - break; - } -} - -struct DeclLoweringVisitor : DeclVisitor -{ - IRGenContext* context; - - IRBuilder* getBuilder() - { - return context->irBuilder; - } - - LoweredValInfo visitDeclBase(DeclBase* /*decl*/) - { - SLANG_UNIMPLEMENTED_X("decl catch-all"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - - LoweredValInfo visitDecl(Decl* /*decl*/) - { - SLANG_UNIMPLEMENTED_X("decl catch-all"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - - LoweredValInfo visitExtensionDecl(ExtensionDecl* decl) - { - for (auto & member : decl->Members) - ensureDecl(context, member); - return LoweredValInfo(); - } - - LoweredValInfo visitImportDecl(ImportDecl* /*decl*/) - { - return LoweredValInfo(); - } - - LoweredValInfo visitEmptyDecl(EmptyDecl* /*decl*/) - { - return LoweredValInfo(); - } - - LoweredValInfo visitSyntaxDecl(SyntaxDecl* /*decl*/) - { - return LoweredValInfo(); - } - - LoweredValInfo visitAttributeDecl(AttributeDecl* /*decl*/) - { - return LoweredValInfo(); - } - - LoweredValInfo visitTypeDefDecl(TypeDefDecl* decl) - { - // A type alias declaration may be generic, if it is - // nested under a generic type/function/etc. - // - NestedContext nested(this); - auto subBuilder = nested.getBuilder(); - auto subContext = nested.getContet(); - IRGeneric* outerGeneric = emitOuterGenerics(subContext, decl, decl); - - // TODO: if a type alias declaration can have linkage, - // we will need to lower it to some kind of global - // value in the IR so that we can attach a name to it. - // - // For now, we can only attach a name *if* the type - // alias is somehow generic. - if(outerGeneric) - { - addLinkageDecoration(context, outerGeneric, decl); - } - - auto type = lowerType(subContext, decl->type.type); - - return LoweredValInfo::simple(finishOuterGenerics(subBuilder, type)); - } - - LoweredValInfo visitGenericTypeParamDecl(GenericTypeParamDecl* /*decl*/) - { - return LoweredValInfo(); - } - - LoweredValInfo visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl) - { - // This might be a type constraint on an associated type, - // in which case it should lower as the key for that - // interface requirement. - if(auto assocTypeDecl = as(decl->ParentDecl)) - { - // TODO: might need extra steps if we ever allow - // generic associated types. - - - if(auto interfaceDecl = as(assocTypeDecl->ParentDecl)) - { - // Okay, this seems to be an interface rquirement, and - // we should lower it as such. - return LoweredValInfo::simple(getInterfaceRequirementKey(decl)); - } - } - - if(auto globalGenericParamDecl = as(decl->ParentDecl)) - { - // This is a constraint on a global generic type parameters, - // and so it should lower as a parameter of its own. - - auto inst = getBuilder()->emitGlobalGenericParam(); - addLinkageDecoration(context, inst, decl); - return LoweredValInfo::simple(inst); - } - - // Otherwise we really don't expect to see a type constraint - // declaration like this during lowering, because a generic - // should have set up a parameter for any constraints as - // part of being lowered. - - SLANG_UNEXPECTED("generic type constraint during lowering"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - - LoweredValInfo visitGlobalGenericParamDecl(GlobalGenericParamDecl* decl) - { - auto inst = getBuilder()->emitGlobalGenericParam(); - addLinkageDecoration(context, inst, decl); - return LoweredValInfo::simple(inst); - } - - void lowerWitnessTable( - IRGenContext* subContext, - WitnessTable* astWitnessTable, - IRWitnessTable* irWitnessTable, - Dictionary mapASTToIRWitnessTable) - { - auto subBuilder = subContext->irBuilder; - - for(auto entry : astWitnessTable->requirementDictionary) - { - auto requiredMemberDecl = entry.Key; - auto satisfyingWitness = entry.Value; - - auto irRequirementKey = getInterfaceRequirementKey(requiredMemberDecl); - IRInst* irSatisfyingVal = nullptr; - - switch(satisfyingWitness.getFlavor()) - { - case RequirementWitness::Flavor::declRef: - { - auto satisfyingDeclRef = satisfyingWitness.getDeclRef(); - irSatisfyingVal = getSimpleVal(subContext, - emitDeclRef(subContext, satisfyingDeclRef, - // TODO: we need to know what type to plug in here... - nullptr)); - } - break; - - case RequirementWitness::Flavor::val: - { - auto satisfyingVal = satisfyingWitness.getVal(); - irSatisfyingVal = lowerSimpleVal(subContext, satisfyingVal); - } - break; - - case RequirementWitness::Flavor::witnessTable: - { - auto astReqWitnessTable = satisfyingWitness.getWitnessTable(); - IRWitnessTable* irSatisfyingWitnessTable = nullptr; - if(!mapASTToIRWitnessTable.TryGetValue(astReqWitnessTable, irSatisfyingWitnessTable)) - { - // Need to construct a sub-witness-table - irSatisfyingWitnessTable = subBuilder->createWitnessTable(); - - // Recursively lower the sub-table. - lowerWitnessTable( - subContext, - astReqWitnessTable, - irSatisfyingWitnessTable, - mapASTToIRWitnessTable); - - irSatisfyingWitnessTable->moveToEnd(); - } - irSatisfyingVal = irSatisfyingWitnessTable; - } - break; - - default: - SLANG_UNEXPECTED("handled requirement witness case"); - break; - } - - - subBuilder->createWitnessTableEntry( - irWitnessTable, - irRequirementKey, - irSatisfyingVal); - } - } - - LoweredValInfo visitInheritanceDecl(InheritanceDecl* inheritanceDecl) - { - // An inheritance clause inside of an `interface` - // declaration should not give rise to a witness - // table, because it represents something the - // interface requires, and not what it provides. - // - auto parentDecl = inheritanceDecl->ParentDecl; - if (auto parentInterfaceDecl = as(parentDecl)) - { - return LoweredValInfo::simple(getInterfaceRequirementKey(inheritanceDecl)); - } - // - // We also need to cover the case where an `extension` - // declaration is being used to add a conformance to - // an existing `interface`: - // - if(auto parentExtensionDecl = as(parentDecl)) - { - auto targetType = parentExtensionDecl->targetType; - if(auto targetDeclRefType = as(targetType)) - { - if(auto targetInterfaceDeclRef = targetDeclRefType->declRef.as()) - { - return LoweredValInfo::simple(getInterfaceRequirementKey(inheritanceDecl)); - } - } - } - - // Find the type that is doing the inheriting. - // Under normal circumstances it is the type declaration that - // is the parent for the inheritance declaration, but if - // the inheritance declaration is on an `extension` declaration, - // then we need to identify the type being extended. - // - RefPtr subType; - if (auto extParentDecl = as(parentDecl)) - { - subType = extParentDecl->targetType.type; - } - else - { - subType = DeclRefType::Create( - context->getSession(), - makeDeclRef(parentDecl)); - } - - // What is the super-type that we have declared we inherit from? - RefPtr superType = inheritanceDecl->base.type; - - // Construct the mangled name for the witness table, which depends - // on the type that is conforming, and the type that it conforms to. - // - // TODO: This approach doesn't really make sense for generic `extension` conformances. - auto mangledName = getMangledNameForConformanceWitness(subType, superType); - - // A witness table may need to be generic, if the outer - // declaration (either a type declaration or an `extension`) - // is generic. - // - NestedContext nested(this); - auto subBuilder = nested.getBuilder(); - auto subContext = nested.getContet(); - emitOuterGenerics(subContext, inheritanceDecl, inheritanceDecl); - - // Lower the super-type to force its declaration to be lowered. - // - // Note: we are using the "sub-context" here because the - // type being inherited from could reference generic parameters, - // and we need those parameters to lower as references to - // the parameters of our IR-level generic. - // - lowerType(subContext, superType); - - // Create the IR-level witness table - auto irWitnessTable = subBuilder->createWitnessTable(); - addLinkageDecoration(context, irWitnessTable, inheritanceDecl, mangledName.getUnownedSlice()); - - // Register the value now, rather than later, to avoid any possible infinite recursion. - setGlobalValue(context, inheritanceDecl, LoweredValInfo::simple(irWitnessTable)); - - // Make sure that all the entries in the witness table have been filled in, - // including any cases where there are sub-witness-tables for conformances - Dictionary mapASTToIRWitnessTable; - lowerWitnessTable( - subContext, - inheritanceDecl->witnessTable, - irWitnessTable, - mapASTToIRWitnessTable); - - irWitnessTable->moveToEnd(); - - return LoweredValInfo::simple(finishOuterGenerics(subBuilder, irWitnessTable)); - } - - LoweredValInfo visitDeclGroup(DeclGroup* declGroup) - { - // To lower a group of declarations, we just - // lower each one individually. - // - for (auto decl : declGroup->decls) - { - IRBuilderSourceLocRAII sourceLocInfo(context->irBuilder, decl->loc); - - // Note: I am directly invoking `dispatch` here, - // instead of `ensureDecl` just to try and - // make sure that we don't accidentally - // emit things to an outer context. - // - // TODO: make sure that can't happen anyway. - dispatch(decl); - } - - return LoweredValInfo(); - } - - LoweredValInfo visitSubscriptDecl(SubscriptDecl* decl) - { - // A subscript operation may encompass one or more - // accessors, and these are what should actually - // get lowered (they are effectively functions). - - for (auto accessor : decl->getMembersOfType()) - { - if (accessor->HasModifier()) - continue; - - ensureDecl(context, accessor); - } - - // The subscript declaration itself won't correspond - // to anything in the lowered program, so we don't - // bother creating a representation here. - // - // Note: We may want to have a specific lowered value - // that can represent the combination of callables - // that make up the subscript operation. - return LoweredValInfo(); - } - - bool isGlobalVarDecl(VarDecl* decl) - { - auto parent = decl->ParentDecl; - if (as(parent)) - { - // Variable declared at global scope? -> Global. - return true; - } - else if(as(parent)) - { - if(decl->HasModifier()) - { - // A `static` member variable is effectively global. - return true; - } - } - - return false; - } - - bool isMemberVarDecl(VarDecl* decl) - { - auto parent = decl->ParentDecl; - if (as(parent)) - { - // A variable declared inside of an aggregate type declaration is a member. - return true; - } - - return false; - } - - LoweredValInfo lowerGlobalShaderParam(VarDecl* decl) - { - IRType* paramType = lowerType(context, decl->getType()); - - auto builder = getBuilder(); - - auto irParam = builder->createGlobalParam(paramType); - auto paramVal = LoweredValInfo::simple(irParam); - - addLinkageDecoration(context, irParam, decl); - addNameHint(context, irParam, decl); - maybeSetRate(context, irParam, decl); - addVarDecorations(context, irParam, decl); - - if (decl) - { - builder->addHighLevelDeclDecoration(irParam, decl); - } - - // A global variable's SSA value is a *pointer* to - // the underlying storage. - setGlobalValue(context, decl, paramVal); - - irParam->moveToEnd(); - - return paramVal; - } - - LoweredValInfo lowerGlobalVarDecl(VarDecl* decl) - { - if(isGlobalShaderParameter(decl)) - { - return lowerGlobalShaderParam(decl); - } - - IRType* varType = lowerType(context, decl->getType()); - - auto builder = getBuilder(); - - IRGlobalValueWithCode* irGlobal = nullptr; - LoweredValInfo globalVal; - - // a `static const` global is actually a compile-time constant - if (decl->HasModifier() && decl->HasModifier()) - { - irGlobal = builder->createGlobalConstant(varType); - globalVal = LoweredValInfo::simple(irGlobal); - } - else - { - irGlobal = builder->createGlobalVar(varType); - globalVal = LoweredValInfo::ptr(irGlobal); - } - addLinkageDecoration(context, irGlobal, decl); - addNameHint(context, irGlobal, decl); - - maybeSetRate(context, irGlobal, decl); - - addVarDecorations(context, irGlobal, decl); - - if (decl) - { - builder->addHighLevelDeclDecoration(irGlobal, decl); - } - - // A global variable's SSA value is a *pointer* to - // the underlying storage. - setGlobalValue(context, decl, globalVal); - - if (isImportedDecl(decl)) - { - // Always emit imported declarations as declarations, - // and not definitions. - } - else if( auto initExpr = decl->initExpr ) - { - IRBuilder subBuilderStorage = *getBuilder(); - IRBuilder* subBuilder = &subBuilderStorage; - - subBuilder->setInsertInto(irGlobal); - - IRGenContext subContextStorage = *context; - IRGenContext* subContext = &subContextStorage; - - subContext->irBuilder = subBuilder; - - // TODO: set up a parent IR decl to put the instructions into - - IRBlock* entryBlock = subBuilder->emitBlock(); - subBuilder->setInsertInto(entryBlock); - - LoweredValInfo initVal = lowerLValueExpr(subContext, initExpr); - subContext->irBuilder->emitReturn(getSimpleVal(subContext, initVal)); - } - - irGlobal->moveToEnd(); - - return globalVal; - } - - bool isFunctionStaticVarDecl(VarDeclBase* decl) - { - // Only a variable marked `static` can be static. - if(!decl->FindModifier()) - return false; - - // The immediate parent of a function-scope variable - // declaration will be a `ScopeDecl`. - // - // TODO: right now the parent links for scopes are *not* - // set correctly, so we can't just scan up and look - // for a function in the parent chain... - auto parent = decl->ParentDecl; - if( as(parent) ) - { - return true; - } - - return false; - } - - IRInst* defaultSpecializeOuterGeneric( - IRInst* outerVal, - IRType* type, - GenericDecl* genericDecl) - { - auto builder = getBuilder(); - - // We need to specialize any generics that are further out... - auto specialiedOuterVal = defaultSpecializeOuterGenerics( - outerVal, - builder->getGenericKind(), - genericDecl); - - List genericArgs; - - // Walk the parameters of the generic, and emit an argument for each, - // which will be a reference to binding for that parameter in the - // current scope. - // - // First we start with type and value parameters, - // in the order they were declared. - for (auto member : genericDecl->Members) - { - if (auto typeParamDecl = as(member)) - { - genericArgs.add(getSimpleVal(context, ensureDecl(context, typeParamDecl))); - } - else if (auto valDecl = as(member)) - { - genericArgs.add(getSimpleVal(context, ensureDecl(context, valDecl))); - } - } - // Then we emit constraint parameters, again in - // declaration order. - for (auto member : genericDecl->Members) - { - if (auto constraintDecl = as(member)) - { - genericArgs.add(getSimpleVal(context, ensureDecl(context, constraintDecl))); - } - } - - return builder->emitSpecializeInst(type, specialiedOuterVal, genericArgs.getCount(), genericArgs.getBuffer()); - } - - IRInst* defaultSpecializeOuterGenerics( - IRInst* val, - IRType* type, - Decl* decl) - { - if(!val) return nullptr; - - auto parentVal = val->getParent(); - while(parentVal) - { - if(as(parentVal)) - break; - parentVal = parentVal->getParent(); - } - if(!parentVal) - return val; - - for(auto pp = decl->ParentDecl; pp; pp = pp->ParentDecl) - { - if(auto genericAncestor = as(pp)) - { - return defaultSpecializeOuterGeneric(parentVal, type, genericAncestor); - } - } - - return val; - } - - struct NestedContext - { - IRGenEnv subEnvStorage; - IRBuilder subBuilderStorage; - IRGenContext subContextStorage; - - NestedContext(DeclLoweringVisitor* outer) - : subBuilderStorage(*outer->getBuilder()) - , subContextStorage(*outer->context) - { - auto outerContext = outer->context; - - subEnvStorage.outer = outerContext->env; - - subContextStorage.irBuilder = &subBuilderStorage; - subContextStorage.env = &subEnvStorage; - } - - IRBuilder* getBuilder() { return &subBuilderStorage; } - IRGenContext* getContet() { return &subContextStorage; } - }; - - LoweredValInfo lowerFunctionStaticConstVarDecl( - VarDeclBase* decl) - { - // We need to insert the constant at a level above - // the function being emitted. This will usually - // be the global scope, but it might be an outer - // generic if we are lowering a generic function. - // - NestedContext nestedContext(this); - auto subBuilder = nestedContext.getBuilder(); - auto subContext = nestedContext.getContet(); - - subBuilder->setInsertInto(subBuilder->getFunc()->getParent()); - - IRType* subVarType = lowerType(subContext, decl->getType()); - - IRGlobalConstant* irConstant = subBuilder->createGlobalConstant(subVarType); - addVarDecorations(subContext, irConstant, decl); - addNameHint(context, irConstant, decl); - maybeSetRate(context, irConstant, decl); - subBuilder->addHighLevelDeclDecoration(irConstant, decl); - - LoweredValInfo constantVal = LoweredValInfo::ptr(irConstant); - setValue(context, decl, constantVal); - - if( auto initExpr = decl->initExpr ) - { - NestedContext nestedInitContext(this); - auto initBuilder = nestedInitContext.getBuilder(); - auto initContext = nestedInitContext.getContet(); - - initBuilder->setInsertInto(irConstant); - - IRBlock* entryBlock = initBuilder->emitBlock(); - initBuilder->setInsertInto(entryBlock); - - LoweredValInfo initVal = lowerRValueExpr(initContext, initExpr); - initBuilder->emitReturn(getSimpleVal(initContext, initVal)); - } - - return constantVal; - } - - LoweredValInfo lowerFunctionStaticVarDecl( - VarDeclBase* decl) - { - // We know the variable is `static`, but it might also be `const. - if(decl->HasModifier()) - return lowerFunctionStaticConstVarDecl(decl); - - // A global variable may need to be generic, if one - // of the outer declarations is generic. - NestedContext nestedContext(this); - auto subBuilder = nestedContext.getBuilder(); - auto subContext = nestedContext.getContet(); - subBuilder->setInsertInto(subBuilder->getModule()->getModuleInst()); - emitOuterGenerics(subContext, decl, decl); - - IRType* subVarType = lowerType(subContext, decl->getType()); - - IRGlobalValueWithCode* irGlobal = subBuilder->createGlobalVar(subVarType); - addVarDecorations(subContext, irGlobal, decl); - - addNameHint(context, irGlobal, decl); - maybeSetRate(context, irGlobal, decl); - - subBuilder->addHighLevelDeclDecoration(irGlobal, decl); - - // We are inside of a function, and that function might be generic, - // in which case the `static` variable will be lowered to another - // generic. Let's start with a terrible example: - // - // interface IHasCount { int getCount(); } - // int incrementCounter(T val) { - // static int counter = 0; - // counter += val.getCount(); - // return counter; - // } - // - // In this case, `incrementCounter` will lower to a function - // nested in a generic, while `counter` will be lowered to - // a global variable nested in a *different* generic. - // The net result is something like this: - // - // int counter = 0; - // - // int incrementCounter(T val) { - // counter += val.getCount(); - // return counter; - // - // The references to `counter` inside of `incrementCounter` - // become references to `counter`. - // - // At the IR level, this means that the value we install - // for `decl` needs to be a specialized reference to `irGlobal`, - // for any outer generics. - // - IRType* varType = lowerType(context, decl->getType()); - IRType* varPtrType = getBuilder()->getPtrType(varType); - auto irSpecializedGlobal = defaultSpecializeOuterGenerics(irGlobal, varPtrType, decl); - LoweredValInfo globalVal = LoweredValInfo::ptr(irSpecializedGlobal); - setValue(context, decl, globalVal); - - // A `static` variable with an initializer needs special handling, - // at least if the initializer isn't a compile-time constant. - if( auto initExpr = decl->initExpr ) - { - // We must create an ordinary global `bool isInitialized = false` - // to represent whether we've initialized this before. - // Then emit code like: - // - // if(!isInitialized) { = ; isInitialized = true; } - // - // TODO: we could conceivably optimize this by detecting - // when the `initExpr` lowers to just a reference to a constant, - // and then either deleting the extra code structure there, - // or not generating it in the first place. That is a bit - // more complexity than I'm ready for at the moment. - // - - // Of course, if we are under a generic, then the Boolean - // variable need to be generic as well! - NestedContext nestedBoolContext(this); - auto boolBuilder = nestedBoolContext.getBuilder(); - auto boolContext = nestedBoolContext.getContet(); - boolBuilder->setInsertInto(boolBuilder->getModule()->getModuleInst()); - emitOuterGenerics(boolContext, decl, decl); - - auto irBoolType = boolBuilder->getBoolType(); - auto irBool = boolBuilder->createGlobalVar(irBoolType); - boolBuilder->setInsertInto(irBool); - boolBuilder->setInsertInto(boolBuilder->createBlock()); - boolBuilder->emitReturn(boolBuilder->getBoolValue(false)); - - auto boolVal = LoweredValInfo::ptr(defaultSpecializeOuterGenerics(irBool, irBoolType, decl)); - - - // Okay, with our global Boolean created, we can move on to - // generating the code we actually care about, back in the original function. - - auto builder = getBuilder(); - - auto initBlock = builder->createBlock(); - auto afterBlock = builder->createBlock(); - - builder->emitIfElse(getSimpleVal(context, boolVal), afterBlock, initBlock, afterBlock); - - builder->insertBlock(initBlock); - LoweredValInfo initVal = lowerLValueExpr(context, initExpr); - assign(context, globalVal, initVal); - assign(context, boolVal, LoweredValInfo::simple(builder->getBoolValue(true))); - builder->emitBranch(afterBlock); - - builder->insertBlock(afterBlock); - } - - irGlobal->moveToEnd(); - finishOuterGenerics(subBuilder, irGlobal); - return globalVal; - } - - LoweredValInfo visitGenericValueParamDecl(GenericValueParamDecl* decl) - { - return emitDeclRef(context, makeDeclRef(decl), - lowerType(context, decl->type)); - } - - LoweredValInfo visitVarDecl(VarDecl* decl) - { - // Detect global (or effectively global) variables - // and handle them differently. - if (isGlobalVarDecl(decl)) - { - return lowerGlobalVarDecl(decl); - } - - if(isFunctionStaticVarDecl(decl)) - { - return lowerFunctionStaticVarDecl(decl); - } - - if(isMemberVarDecl(decl)) - { - return lowerMemberVarDecl(decl); - } - - // A user-defined variable declaration will usually turn into - // an `alloca` operation for the variable's storage, - // plus some code to initialize it and then store to the variable. - - IRType* varType = lowerType(context, decl->getType()); - - // As a special case, an immutable local variable with an - // initializer can just lower to the SSA value of its initializer. - // - if(as(decl)) - { - if(auto initExpr = decl->initExpr) - { - auto initVal = lowerRValueExpr(context, initExpr); - initVal = materialize(context, initVal); - setGlobalValue(context, decl, initVal); - return initVal; - } - } - - - LoweredValInfo varVal = createVar(context, varType, decl); - - if( auto initExpr = decl->initExpr ) - { - auto initVal = lowerRValueExpr(context, initExpr); - - assign(context, varVal, initVal); - } - - setGlobalValue(context, decl, varVal); - - return varVal; - } - - IRStructKey* getInterfaceRequirementKey(Decl* requirementDecl) - { - return Slang::getInterfaceRequirementKey(context, requirementDecl); - } - - LoweredValInfo visitInterfaceDecl(InterfaceDecl* decl) - { - // The members of an interface will turn into the keys that will - // be used for lookup operations into witness - // tables that promise conformance to the interface. - // - // TODO: we don't handle the case here of an interface - // with concrete/default implementations for any - // of its members. - // - // TODO: If we want to support using an interface as - // an existential type, then we might need to emit - // a witness table for the interface type's conformance - // to its own interface. - // - for (auto requirementDecl : decl->Members) - { - getInterfaceRequirementKey(requirementDecl); - - // As a special case, any type constraints placed - // on an associated type will *also* need to be turned - // into requirement keys for this interface. - if (auto associatedTypeDecl = as(requirementDecl)) - { - for (auto constraintDecl : associatedTypeDecl->getMembersOfType()) - { - getInterfaceRequirementKey(constraintDecl); - } - } - } - - - NestedContext nestedContext(this); - auto subBuilder = nestedContext.getBuilder(); - auto subContext = nestedContext.getContet(); - - // Emit any generics that should wrap the actual type. - emitOuterGenerics(subContext, decl, decl); - - IRInterfaceType* irInterface = subBuilder->createInterfaceType(); - addNameHint(context, irInterface, decl); - addLinkageDecoration(context, irInterface, decl); - subBuilder->setInsertInto(irInterface); - - // TODO: are there any interface members that should be - // nested inside the interface type itself? - - irInterface->moveToEnd(); - - addTargetIntrinsicDecorations(irInterface, decl); - - - return LoweredValInfo::simple(finishOuterGenerics(subBuilder, irInterface)); - } - - LoweredValInfo visitEnumCaseDecl(EnumCaseDecl* decl) - { - // A case within an `enum` decl will lower to a value - // of the `enum`'s "tag" type. - // - // TODO: a bit more work will be needed if we allow for - // enum cases that have payloads, because then we need - // a function that constructs the value given arguments. - // - NestedContext nestedContext(this); - auto subContext = nestedContext.getContet(); - - // Emit any generics that should wrap the actual type. - emitOuterGenerics(subContext, decl, decl); - - return lowerRValueExpr(subContext, decl->tagExpr); - } - - LoweredValInfo visitEnumDecl(EnumDecl* decl) - { - // Given a declaration of a type, we need to make sure - // to output "witness tables" for any interfaces this - // type has declared conformance to. - for( auto inheritanceDecl : decl->getMembersOfType() ) - { - ensureDecl(context, inheritanceDecl); - } - - NestedContext nestedContext(this); - auto subBuilder = nestedContext.getBuilder(); - auto subContext = nestedContext.getContet(); - emitOuterGenerics(subContext, decl, decl); - - // An `enum` declaration will currently lower directly to its "tag" - // type, so that any references to the `enum` become referenes to - // the tag type instead. - // - // TODO: if we ever support `enum` types with payloads, we would - // need to make the `enum` lower to some kind of custom "tagged union" - // type. - - IRType* loweredTagType = lowerType(subContext, decl->tagType); - - return LoweredValInfo::simple(finishOuterGenerics(subBuilder, loweredTagType)); - } - - LoweredValInfo visitAggTypeDecl(AggTypeDecl* decl) - { - // Don't generate an IR `struct` for intrinsic types - if(decl->FindModifier() || decl->FindModifier()) - { - return LoweredValInfo(); - } - - // Given a declaration of a type, we need to make sure - // to output "witness tables" for any interfaces this - // type has declared conformance to. - for( auto inheritanceDecl : decl->getMembersOfType() ) - { - ensureDecl(context, inheritanceDecl); - } - - // We are going to create nested IR building state - // to use when emitting the members of the type. - // - NestedContext nestedContext(this); - auto subBuilder = nestedContext.getBuilder(); - auto subContext = nestedContext.getContet(); - - // Emit any generics that should wrap the actual type. - emitOuterGenerics(subContext, decl, decl); - - IRStructType* irStruct = subBuilder->createStructType(); - addNameHint(context, irStruct, decl); - addLinkageDecoration(context, irStruct, decl); - - subBuilder->setInsertInto(irStruct); - - for (auto fieldDecl : decl->getMembersOfType()) - { - if (fieldDecl->HasModifier()) - { - // A `static` field is actually a global variable, - // and we should emit it as such. - ensureDecl(context, fieldDecl); - continue; - } - - // Each ordinary field will need to turn into a struct "key" - // that is used for fetching the field. - IRInst* fieldKeyInst = getSimpleVal(context, - ensureDecl(context, fieldDecl)); - auto fieldKey = as(fieldKeyInst); - SLANG_ASSERT(fieldKey); - - // Note: we lower the type of the field in the "sub" - // context, so that any generic parameters that were - // set up for the type can be referenced by the field type. - IRType* fieldType = lowerType( - subContext, - fieldDecl->getType()); - - // Then, the parent `struct` instruction itself will have - // a "field" instruction. - subBuilder->createStructField( - irStruct, - fieldKey, - fieldType); - } - - // There may be members not handled by the above logic (e.g., - // member functions), but we will not immediately force them - // to be emitted here, so as not to risk a circular dependency. - // - // Instead we will force emission of all children of aggregate - // type declarations later, from the top-level emit logic. - - irStruct->moveToEnd(); - addTargetIntrinsicDecorations(irStruct, decl); - - return LoweredValInfo::simple(finishOuterGenerics(subBuilder, irStruct)); - } - - LoweredValInfo lowerMemberVarDecl(VarDecl* fieldDecl) - { - // Each field declaration in the AST translates into - // a "key" that can be used to extract field values - // from instances of struct types that contain the field. - // - // It is correct to say struct *types* because a `struct` - // nested under a generic can be used to realize a number - // of different concrete types, but all of these types - // will use the same space of keys. - - auto builder = getBuilder(); - auto irFieldKey = builder->createStructKey(); - addNameHint(context, irFieldKey, fieldDecl); - - addVarDecorations(context, irFieldKey, fieldDecl); - - addLinkageDecoration(context, irFieldKey, fieldDecl); - - if (auto semanticModifier = fieldDecl->FindModifier()) - { - builder->addSemanticDecoration(irFieldKey, semanticModifier->name.getName()->text.getUnownedSlice()); - } - - // We allow a field to be marked as a target intrinsic, - // so that we can override its mangled name in the - // output for the chosen target. - addTargetIntrinsicDecorations(irFieldKey, fieldDecl); - - - return LoweredValInfo::simple(irFieldKey); - } - - - DeclRef createDefaultSpecializedDeclRefImpl(Decl* decl) - { - DeclRef declRef; - declRef.decl = decl; - declRef.substitutions = createDefaultSubstitutions(context->getSession(), decl); - return declRef; - } - // - // The client should actually call the templated wrapper, to preserve type information. - template - DeclRef createDefaultSpecializedDeclRef(D* decl) - { - DeclRef declRef = createDefaultSpecializedDeclRefImpl(decl); - return declRef.as(); - } - - - // When lowering something callable (most commonly a function declaration), - // we need to construct an appropriate parameter list for the IR function - // that folds in any contributions from both the declaration itself *and* - // its parent declaration(s). - // - // For example, given code like: - // - // struct Foo { int bar(float y) { ... } }; - // - // we need to generate IR-level code something like: - // - // func Foo_bar(Foo this, float y) -> int; - // - // that is, the `this` parameter has become explicit. - // - // The same applies to generic parameters, and these - // should apply even if the nested declaration is `static`: - // - // struct Foo { static int bar(T y) { ... } }; - // - // becomes: - // - // func Foo_bar(T y) -> int; - // - // In order to implement this, we are going to do a recursive - // walk over a declaration and its parents, collecting separate - // lists of ordinary and generic parameters that will need - // to be included in the final declaration's parameter list. - // - // When doing code generation for an ordinary value parameter, - // we mostly care about its type, and then also its "direction" - // (`in`, `out`, `in out`). We sometimes need acess to the - // original declaration so that we can inspect it for meta-data, - // but in some cases there is no such declaration (e.g., a `this` - // parameter doesn't get an explicit declaration in the AST). - // To handle this we break out the relevant data into derived - // structures: - // - enum ParameterDirection - { - kParameterDirection_In, ///< Copy in - kParameterDirection_Out, ///< Copy out - kParameterDirection_InOut, ///< Copy in, copy out - kParameterDirection_Ref, ///< By-reference - }; - struct ParameterInfo - { - // This AST-level type of the parameter - RefPtr type; - - // The direction (`in` vs `out` vs `in out`) - ParameterDirection direction; - - // The variable/parameter declaration for - // this parameter (if any) - VarDeclBase* decl; - - // Is this the representation of a `this` parameter? - bool isThisParam = false; - }; - // - // We need a way to compute the appropriate `ParameterDirection` for a - // declared parameter: - // - ParameterDirection getParameterDirection(VarDeclBase* paramDecl) - { - if( paramDecl->HasModifier() ) - { - // The AST specified `ref`: - return kParameterDirection_Ref; - } - if( paramDecl->HasModifier() ) - { - // The AST specified `inout`: - return kParameterDirection_InOut; - } - if (paramDecl->HasModifier()) - { - // We saw an `out` modifier, so now we need - // to check if there was a paired `in`. - if(paramDecl->HasModifier()) - return kParameterDirection_InOut; - else - return kParameterDirection_Out; - } - else - { - // No direction modifier, or just `in`: - return kParameterDirection_In; - } - } - // We need a way to be able to create a `ParameterInfo` given the declaration - // of a parameter: - // - ParameterInfo getParameterInfo(VarDeclBase* paramDecl) - { - ParameterInfo info; - info.type = paramDecl->getType(); - info.decl = paramDecl; - info.direction = getParameterDirection(paramDecl); - info.isThisParam = false; - return info; - } - // - - // Here's the declaration for the type to hold the lists: - struct ParameterLists - { - List params; - }; - // - // Because there might be a `static` declaration somewhere - // along the lines, we need to be careful to prohibit adding - // non-generic parameters in some cases. - enum ParameterListCollectMode - { - // Collect everything: ordinary and generic parameters. - kParameterListCollectMode_Default, - - - // Only collect generic parameters. - kParameterListCollectMode_Static, - }; - // - // We also need to be able to detect whether a declaration is - // either explicitly or implicitly treated as `static`: - ParameterListCollectMode getModeForCollectingParentParameters( - Decl* decl, - ContainerDecl* parentDecl) - { - // If we have a `static` parameter, then it is obvious - // that we should use the `static` mode - if(isEffectivelyStatic(decl, parentDecl)) - return kParameterListCollectMode_Static; - - // Otherwise, let's default to collecting everything - return kParameterListCollectMode_Default; - } - // - // When dealing with a member function, we need to be able to add the `this` - // parameter for the enclosing type: - // - void addThisParameter( - ParameterDirection direction, - Type* type, - ParameterLists* ioParameterLists) - { - ParameterInfo info; - info.type = type; - info.decl = nullptr; - info.direction = direction; - info.isThisParam = true; - - ioParameterLists->params.add(info); - } - void addThisParameter( - ParameterDirection direction, - AggTypeDecl* typeDecl, - ParameterLists* ioParameterLists) - { - // We need to construct an appopriate declaration-reference - // for the type declaration we were given. In particular, - // we need to specialize it for any generic parameters - // that are in scope here. - auto declRef = createDefaultSpecializedDeclRef(typeDecl); - RefPtr type = DeclRefType::Create(context->getSession(), declRef); - addThisParameter( - direction, - type, - ioParameterLists); - } - // - // And here is our function that will do the recursive walk: - void collectParameterLists( - Decl* decl, - ParameterLists* ioParameterLists, - ParameterListCollectMode mode) - { - // The parameters introduced by any "parent" declarations - // will need to come first, so we'll deal with that - // logic here. - if( auto parentDecl = decl->ParentDecl ) - { - // Compute the mode to use when collecting parameters from - // the outer declaration. The most important question here - // is whether parameters of the outer declaration should - // also count as parameters of the inner declaration. - ParameterListCollectMode innerMode = getModeForCollectingParentParameters(decl, parentDecl); - - // Don't down-grade our `static`-ness along the chain. - if(innerMode < mode) - innerMode = mode; - - // Now collect any parameters from the parent declaration itself - collectParameterLists(parentDecl, ioParameterLists, innerMode); - - // We also need to consider whether the inner declaration needs to have a `this` - // parameter corresponding to the outer declaration. - if( innerMode != kParameterListCollectMode_Static ) - { - // For now we make any `this` parameter default to `in`. - // - ParameterDirection direction = kParameterDirection_In; - // - // Applications can opt in to a mutable `this` parameter, - // by applying the `[mutating]` attribute to their - // declaration. - // - if( decl->HasModifier() ) - { - direction = kParameterDirection_InOut; - } - - if( auto aggTypeDecl = as(parentDecl) ) - { - addThisParameter(direction, aggTypeDecl, ioParameterLists); - } - else if( auto extensionDecl = as(parentDecl) ) - { - addThisParameter(direction, extensionDecl->targetType, ioParameterLists); - } - } - } - - // Once we've added any parameters based on parent declarations, - // we can see if this declaration itself introduces parameters. - // - if( auto callableDecl = as(decl) ) - { - // Don't collect parameters from the outer scope if - // we are in a `static` context. - if( mode == kParameterListCollectMode_Default ) - { - for( auto paramDecl : callableDecl->GetParameters() ) - { - ioParameterLists->params.add(getParameterInfo(paramDecl)); - } - } - } - } - - bool isImportedDecl(Decl* decl) - { - return Slang::isImportedDecl(context, decl); - } - - bool isConstExprVar(Decl* decl) - { - if( decl->HasModifier() ) - { - return true; - } - else if(decl->HasModifier() && decl->HasModifier()) - { - return true; - } - - return false; - } - - IRType* maybeGetConstExprType(IRType* type, Decl* decl) - { - if(isConstExprVar(decl)) - { - return getBuilder()->getRateQualifiedType( - getBuilder()->getConstExprRate(), - type); - } - - return type; - } - - IRGeneric* emitOuterGeneric( - IRGenContext* subContext, - GenericDecl* genericDecl, - Decl* leafDecl) - { - auto subBuilder = subContext->irBuilder; - - // Of course, a generic might itself be nested inside of other generics... - emitOuterGenerics(subContext, genericDecl, leafDecl); - - // We need to create an IR generic - - auto irGeneric = subBuilder->emitGeneric(); - subBuilder->setInsertInto(irGeneric); - - auto irBlock = subBuilder->emitBlock(); - subBuilder->setInsertInto(irBlock); - - // Now emit any parameters of the generic - // - // First we start with type and value parameters, - // in the order they were declared. - for (auto member : genericDecl->Members) - { - if (auto typeParamDecl = as(member)) - { - // TODO: use a `TypeKind` to represent the - // classifier of the parameter. - auto param = subBuilder->emitParam(nullptr); - addNameHint(context, param, typeParamDecl); - setValue(subContext, typeParamDecl, LoweredValInfo::simple(param)); - } - else if (auto valDecl = as(member)) - { - auto paramType = lowerType(subContext, valDecl->getType()); - auto param = subBuilder->emitParam(paramType); - addNameHint(context, param, valDecl); - setValue(subContext, valDecl, LoweredValInfo::simple(param)); - } - } - // Then we emit constraint parameters, again in - // declaration order. - for (auto member : genericDecl->Members) - { - if (auto constraintDecl = as(member)) - { - // TODO: use a `WitnessTableKind` to represent the - // classifier of the parameter. - auto param = subBuilder->emitParam(nullptr); - addNameHint(context, param, constraintDecl); - setValue(subContext, constraintDecl, LoweredValInfo::simple(param)); - } - } - - return irGeneric; - } - - // If the given `decl` is enclosed in any generic declarations, then - // emit IR-level generics to represent them. - // The `leafDecl` represents the inner-most declaration we are actually - // trying to emit, which is the one that should receive the mangled name. - // - IRGeneric* emitOuterGenerics(IRGenContext* subContext, Decl* decl, Decl* leafDecl) - { - for(auto pp = decl->ParentDecl; pp; pp = pp->ParentDecl) - { - if(auto genericAncestor = as(pp)) - { - return emitOuterGeneric(subContext, genericAncestor, leafDecl); - } - } - - return nullptr; - } - - // If any generic declarations have been created by `emitOuterGenerics`, - // then finish them off by emitting `return` instructions for the - // values that they should produce. - // - // Return the outer-most generic (if there is one), or the original - // value (if there were no generics), which should be the IR-level - // representation of the original declaration. - // - IRInst* finishOuterGenerics( - IRBuilder* subBuilder, - IRInst* val) - { - IRInst* v = val; - for(;;) - { - auto parentBlock = as(v->getParent()); - if (!parentBlock) break; - - auto parentGeneric = as(parentBlock->getParent()); - if (!parentGeneric) break; - - subBuilder->setInsertInto(parentBlock); - subBuilder->emitReturn(v); - parentGeneric->moveToEnd(); - - // There might be more outer generics, - // so we need to loop until we run out. - v = parentGeneric; - } - return v; - } - - // Attach target-intrinsic decorations to an instruction, - // based on modifiers on an AST declaration. - void addTargetIntrinsicDecorations( - IRInst* irInst, - Decl* decl) - { - auto builder = getBuilder(); - - for (auto targetMod : decl->GetModifiersOfType()) - { - String definition; - auto definitionToken = targetMod->definitionToken; - if (definitionToken.type == TokenType::StringLiteral) - { - definition = getStringLiteralTokenValue(definitionToken); - } - else - { - definition = definitionToken.Content; - } - - builder->addTargetIntrinsicDecoration(irInst, targetMod->targetToken.Content, definition.getUnownedSlice()); - } - } - - void addParamNameHint(IRInst* inst, ParameterInfo info) - { - if(auto decl = info.decl) - { - addNameHint(context, inst, decl); - } - else if( info.isThisParam ) - { - addNameHint(context, inst, "this"); - } - } - - LoweredValInfo lowerFuncDecl(FunctionDeclBase* decl) - { - // We are going to use a nested builder, because we will - // change the parent node that things get nested into. - // - NestedContext nestedContext(this); - auto subBuilder = nestedContext.getBuilder(); - auto subContext = nestedContext.getContet(); - - // The actual `IRFunction` that we emit needs to be nested - // inside of one `IRGeneric` for every outer `GenericDecl` - // in the declaration hierarchy. - - emitOuterGenerics(subContext, decl, decl); - - // Collect the parameter lists we will use for our new function. - ParameterLists parameterLists; - collectParameterLists(decl, ¶meterLists, kParameterListCollectMode_Default); - - // TODO: if there are any generic parameters in the collected list, then - // we need to output an IR function with generic parameters (or a generic - // with a nested function... the exact representation is still TBD). - - // In most cases the return type for a declaration can be read off the declaration - // itself, but things get a bit more complicated when we have to deal with - // accessors for subscript declarations (and eventually for properties). - // - // We compute a declaration to use for looking up the return type here: - CallableDecl* declForReturnType = decl; - if (auto accessorDecl = as(decl)) - { - // We are some kind of accessor, so the parent declaration should - // know the correct return type to expose. - // - auto parentDecl = accessorDecl->ParentDecl; - if (auto subscriptDecl = as(parentDecl)) - { - declForReturnType = subscriptDecl; - } - } - - // need to create an IR function here - - IRFunc* irFunc = subBuilder->createFunc(); - addNameHint(context, irFunc, decl); - addLinkageDecoration(context, irFunc, decl); - - List paramTypes; - - for( auto paramInfo : parameterLists.params ) - { - IRType* irParamType = lowerType(subContext, paramInfo.type); - - switch( paramInfo.direction ) - { - case kParameterDirection_In: - // Simple case of a by-value input parameter. - break; - - // If the parameter is declared `out` or `inout`, - // then we will represent it with a pointer type in - // the IR, but we will use a specialized pointer - // type that encodes the parameter direction information. - case kParameterDirection_Out: - irParamType = subBuilder->getOutType(irParamType); - break; - case kParameterDirection_InOut: - irParamType = subBuilder->getInOutType(irParamType); - break; - case kParameterDirection_Ref: - irParamType = subBuilder->getRefType(irParamType); - break; - - default: - SLANG_UNEXPECTED("unknown parameter direction"); - break; - } - - // If the parameter was explicitly marked as being a compile-time - // constant (`constexpr`), then attach that information to its - // IR-level type explicitly. - if( paramInfo.decl ) - { - irParamType = maybeGetConstExprType(irParamType, paramInfo.decl); - } - - paramTypes.add(irParamType); - } - - auto irResultType = lowerType(subContext, declForReturnType->ReturnType); - - if (auto setterDecl = as(decl)) - { - // We are lowering a "setter" accessor inside a subscript - // declaration, which means we don't want to *return* the - // stated return type of the subscript, but instead take - // it as a parameter. - // - IRType* irParamType = irResultType; - paramTypes.add(irParamType); - - // Instead, a setter always returns `void` - // - irResultType = subBuilder->getVoidType(); - } - - if( auto refAccessorDecl = as(decl) ) - { - // A `ref` accessor needs to return a *pointer* to the value - // being accessed, rather than a simple value. - irResultType = subBuilder->getPtrType(irResultType); - } - - auto irFuncType = subBuilder->getFuncType( - paramTypes.getCount(), - paramTypes.getBuffer(), - irResultType); - irFunc->setFullType(irFuncType); - - subBuilder->setInsertInto(irFunc); - - if (isImportedDecl(decl)) - { - // Always emit imported declarations as declarations, - // and not definitions. - } - else if (!decl->Body) - { - // This is a function declaration without a body. - // In Slang we currently try not to support forward declarations - // (although we might have to give in eventually), so - // this case should really only occur for builtin declarations. - } - else - { - // This is a function definition, so we need to actually - // construct IR for the body... - IRBlock* entryBlock = subBuilder->emitBlock(); - subBuilder->setInsertInto(entryBlock); - - UInt paramTypeIndex = 0; - for( auto paramInfo : parameterLists.params ) - { - auto irParamType = paramTypes[paramTypeIndex++]; - - LoweredValInfo paramVal; - - switch( paramInfo.direction ) - { - default: - { - // The parameter is being used for input/output purposes, - // so it will lower to an actual parameter with a pointer type. - // - // TODO: Is this the best representation we can use? - - IRParam* irParamPtr = subBuilder->emitParam(irParamType); - if(auto paramDecl = paramInfo.decl) - { - addVarDecorations(context, irParamPtr, paramDecl); - subBuilder->addHighLevelDeclDecoration(irParamPtr, paramDecl); - } - addParamNameHint(irParamPtr, paramInfo); - - paramVal = LoweredValInfo::ptr(irParamPtr); - - // TODO: We might want to copy the pointed-to value into - // a temporary at the start of the function, and then copy - // back out at the end, so that we don't have to worry - // about things like aliasing in the function body. - // - // For now we will just use the storage that was passed - // in by the caller, knowing that our current lowering - // at call sites will guarantee a fresh/unique location. - } - break; - - case kParameterDirection_In: - { - // Simple case of a by-value input parameter. - // - // We start by declaring an IR parameter of the same type. - // - auto paramDecl = paramInfo.decl; - IRParam* irParam = subBuilder->emitParam(irParamType); - if( paramDecl ) - { - addVarDecorations(context, irParam, paramDecl); - subBuilder->addHighLevelDeclDecoration(irParam, paramDecl); - } - addParamNameHint(irParam, paramInfo); - paramVal = LoweredValInfo::simple(irParam); - // - // HLSL allows a function parameter to be used as a local - // variable in the function body (just like C/C++), so - // we need to support that case as well. - // - // However, if we notice that the parameter was marked - // `const`, then we can skip this step. - // - // TODO: we should consider having all parameter be implicitly - // immutable except in a specific "compatibility mode." - // - if(paramDecl && paramDecl->FindModifier()) - { - // This parameter was declared to be immutable, - // so there should be no assignment to it in the - // function body, and we don't need a temporary. - } - else - { - // The parameter migth get used as a temporary in - // the function body. We will allocate a mutable - // local variable for is value, and then assign - // from the parameter to the local at the start - // of the function. - // - auto irLocal = subBuilder->emitVar(irParamType); - auto localVal = LoweredValInfo::ptr(irLocal); - assign(subContext, localVal, paramVal); - // - // When code later in the body of the function refers - // to the parameter declaration, it will actually refer - // to the value stored in the local variable. - // - paramVal = localVal; - } - } - break; - } - - if( auto paramDecl = paramInfo.decl ) - { - setValue(subContext, paramDecl, paramVal); - } - - if (paramInfo.isThisParam) - { - subContext->thisVal = paramVal; - } - } - - if (auto setterDecl = as(decl)) - { - // Add the IR parameter for the new value - IRType* irParamType = irResultType; - auto irParam = subBuilder->emitParam(irParamType); - addNameHint(context, irParam, "newValue"); - - // TODO: we need some way to wire this up to the `newValue` - // or whatever name we give for that parameter inside - // the setter body. - } - - { - - auto attr = decl->FindModifier(); - - // I needed to test for patchConstantFuncDecl here - // because it is only set if validateEntryPoint is called with Hull as the required stage - // If I just build domain shader, and then the attribute exists, but patchConstantFuncDecl is not set - // and thus leads to a crash. - if (attr && attr->patchConstantFuncDecl) - { - // We need to lower the function - FuncDecl* patchConstantFunc = attr->patchConstantFuncDecl; - assert(patchConstantFunc); - - // Convert the patch constant function into IRInst - IRInst* irPatchConstantFunc = getSimpleVal(context, ensureDecl(subContext, patchConstantFunc)); - - // Attach a decoration so that our IR function references - // the patch constant function. - // - subContext->irBuilder->addPatchConstantFuncDecoration( - irFunc, - irPatchConstantFunc); - - } - } - - // Lower body - - lowerStmt(subContext, decl->Body); - - // We need to carefully add a terminator instruction to the end - // of the body, in case the user didn't do so. - if (!subContext->irBuilder->getBlock()->getTerminator()) - { - if(as(irResultType)) - { - // `void`-returning function can get an implicit - // return on exit of the body statement. - subContext->irBuilder->emitReturn(); - } - else - { - // Value-returning function is expected to `return` - // on every control-flow path. We need to enforce - // this by putting an `unreachable` terminator here, - // and then emit a dataflow error if this block - // can't be eliminated. - subContext->irBuilder->emitMissingReturn(); - } - } - } - - getBuilder()->addHighLevelDeclDecoration(irFunc, decl); - - // If this declaration was marked as being an intrinsic for a particular - // target, then we should reflect that here. - for( auto targetMod : decl->GetModifiersOfType() ) - { - // `targetMod` indicates that this particular declaration represents - // a specialized definition of the particular function for the given - // target, and we need to reflect that at the IR level. - - getBuilder()->addTargetDecoration(irFunc, targetMod->targetToken.Content); - } - - // If this declaration was marked as having a target-specific lowering - // for a particular target, then handle that here. - addTargetIntrinsicDecorations(irFunc, decl); - - // If this declaration requires certain GLSL extension (or a particular GLSL version) - // for it to be usable, then declare that here. - // - // TODO: We should wrap this an `SpecializedForTargetModifier` together into a single - // case for enumerating the "capabilities" that a declaration requires. - // - for(auto extensionMod : decl->GetModifiersOfType()) - { - getBuilder()->addRequireGLSLExtensionDecoration(irFunc, extensionMod->extensionNameToken.Content); - } - for(auto versionMod : decl->GetModifiersOfType()) - { - getBuilder()->addRequireGLSLVersionDecoration(irFunc, Int(getIntegerLiteralValue(versionMod->versionNumberToken))); - } - - if(decl->FindModifier()) - { - getBuilder()->addSimpleDecoration(irFunc); - } - - if (decl->FindModifier()) - { - getBuilder()->addSimpleDecoration(irFunc); - } - - // For convenience, ensure that any additional global - // values that were emitted while outputting the function - // body appear before the function itself in the list - // of global values. - irFunc->moveToEnd(); - return LoweredValInfo::simple(finishOuterGenerics(subBuilder, irFunc)); - } - - LoweredValInfo visitGenericDecl(GenericDecl * genDecl) - { - // TODO: Should this just always visit/lower the inner decl? - - if (auto innerFuncDecl = as(genDecl->inner)) - return ensureDecl(context, innerFuncDecl); - else if (auto innerStructDecl = as(genDecl->inner)) - { - ensureDecl(context, innerStructDecl); - return LoweredValInfo(); - } - else if( auto extensionDecl = as(genDecl->inner) ) - { - return ensureDecl(context, extensionDecl); - } - SLANG_RELEASE_ASSERT(false); - UNREACHABLE_RETURN(LoweredValInfo()); - } - - LoweredValInfo visitFunctionDeclBase(FunctionDeclBase* decl) - { - // A function declaration may have multiple, target-specific - // overloads, and we need to emit an IR version of each of these. - - // The front end will form a linked list of declarations with - // the same signature, whenever there is any kind of redeclaration. - // We will look to see if that linked list has been formed. - auto primaryDecl = decl->primaryDecl; - - if (!primaryDecl) - { - // If there is no linked list then we are in the ordinary - // case with a single declaration, and no special handling - // is needed. - return lowerFuncDecl(decl); - } - - // Otherwise, we need to walk the linked list of declarations - // and make sure to emit IR code for any targets that need it. - - // TODO: Need to be careful about how this is approached, - // to avoid emitting a bunch of extra definitions in the IR. - - auto primaryFuncDecl = as(primaryDecl); - SLANG_ASSERT(primaryFuncDecl); - LoweredValInfo result = lowerFuncDecl(primaryFuncDecl); - for (auto dd = primaryDecl->nextDecl; dd; dd = dd->nextDecl) - { - auto funcDecl = as(dd); - SLANG_ASSERT(funcDecl); - lowerFuncDecl(funcDecl); - } - return result; - } -}; - -LoweredValInfo lowerDecl( - IRGenContext* context, - DeclBase* decl) -{ - IRBuilderSourceLocRAII sourceLocInfo(context->irBuilder, decl->loc); - - DeclLoweringVisitor visitor; - visitor.context = context; - - try - { - return visitor.dispatch(decl); - } - // Don't emit any context message for an explicit `AbortCompilationException` - // because it should only happen when an error is already emitted. - catch(AbortCompilationException&) { throw; } - catch(...) - { - context->getSink()->noteInternalErrorLoc(decl->loc); - throw; - } -} - -// Ensure that a version of the given declaration has been emitted to the IR -LoweredValInfo ensureDecl( - IRGenContext* context, - Decl* decl) -{ - auto shared = context->shared; - - LoweredValInfo result; - - // Look for an existing value installed in this context - auto env = context->env; - while(env) - { - if(env->mapDeclToValue.TryGetValue(decl, result)) - return result; - - env = env->outer; - } - - IRBuilder subIRBuilder; - subIRBuilder.sharedBuilder = context->irBuilder->sharedBuilder; - subIRBuilder.setInsertInto(subIRBuilder.sharedBuilder->module->getModuleInst()); - - IRGenEnv subEnv; - subEnv.outer = context->env; - - IRGenContext subContext = *context; - subContext.irBuilder = &subIRBuilder; - subContext.env = &subEnv; - - result = lowerDecl(&subContext, decl); - - // By default assume that any value we are lowering represents - // something that should be installed globally. - setGlobalValue(shared, decl, result); - - return result; -} - -IRInst* lowerSubstitutionArg( - IRGenContext* context, - Val* val) -{ - if (auto type = dynamicCast(val)) - { - return lowerType(context, type); - } - else if (auto declaredSubtypeWitness = as(val)) - { - // We need to look up the IR-level representation of the witness (which will be a witness table). - auto irWitnessTable = getSimpleVal( - context, - emitDeclRef( - context, - declaredSubtypeWitness->declRef, - context->irBuilder->getWitnessTableType())); - return irWitnessTable; - } - else - { - SLANG_UNIMPLEMENTED_X("value cases"); - UNREACHABLE_RETURN(nullptr); - } -} - -// Can the IR lowered version of this declaration ever be an `IRGeneric`? -bool canDeclLowerToAGeneric(RefPtr decl) -{ - // A callable decl lowers to an `IRFunc`, and can be generic - if(as(decl)) return true; - - // An aggregate type decl lowers to an `IRStruct`, and can be generic - if(as(decl)) return true; - - // An inheritance decl lowers to an `IRWitnessTable`, and can be generic - if(as(decl)) return true; - - // A `typedef` declaration nested under a generic will turn into - // a generic that returns a type (a simple type-level function). - if(as(decl)) return true; - - return false; -} - -LoweredValInfo emitDeclRef( - IRGenContext* context, - RefPtr decl, - RefPtr subst, - IRType* type) -{ - // We need to proceed by considering the specializations that - // have been put in place. - - // Ignore any global generic type substitutions during lowering. - // Really, we don't even expect these to appear. - while(auto globalGenericSubst = as(subst)) - subst = globalGenericSubst->outer; - - // If the declaration would not get wrapped in a `IRGeneric`, - // even if it is nested inside of an AST `GenericDecl`, then - // we should also ignore any generic substitutions. - if(!canDeclLowerToAGeneric(decl)) - { - while(auto genericSubst = as(subst)) - subst = genericSubst->outer; - } - - // In the simplest case, there is no specialization going - // on, and the decl-ref turns into a reference to the - // lowered IR value for the declaration. - if(!subst) - { - LoweredValInfo loweredDecl = ensureDecl(context, decl); - return loweredDecl; - } - - // Otherwise, we look at the kind of substitution, and let it guide us. - if(auto genericSubst = subst.as()) - { - // A generic substitution means we will need to output - // a `specialize` instruction to specialize the generic. - // - // First we want to emit the value without generic specialization - // applied, to get a correct value for it. - // - // Note: we only "unwrap" a single layer from the - // substitutions here, because the underlying declaration - // might be nested in multiple generics, or it might - // come from an interface. - // - LoweredValInfo genericVal = emitDeclRef( - context, - decl, - genericSubst->outer, - context->irBuilder->getGenericKind()); - - // There's no reason to specialize something that maps to a NULL pointer. - if (genericVal.flavor == LoweredValInfo::Flavor::None) - return LoweredValInfo(); - - // We can only really specialize things that map to single values. - // It would be an error if we got a non-`None` value that - // wasn't somehow a single value. - auto irGenericVal = getSimpleVal(context, genericVal); - - // We have the IR value for the generic we'd like to specialize, - // and now we need to get the value for the arguments. - List irArgs; - for (auto argVal : genericSubst->args) - { - auto irArgVal = lowerSimpleVal(context, argVal); - SLANG_ASSERT(irArgVal); - irArgs.add(irArgVal); - } - - // Once we have both the generic and its arguments, - // we can emit a `specialize` instruction and use - // its value as the result. - auto irSpecializedVal = context->irBuilder->emitSpecializeInst( - type, - irGenericVal, - irArgs.getCount(), - irArgs.getBuffer()); - - return LoweredValInfo::simple(irSpecializedVal); - } - else if(auto thisTypeSubst = subst.as()) - { - if(decl.Ptr() == thisTypeSubst->interfaceDecl) - { - // This is a reference to the interface type itself, - // through the this-type substitution, so it is really - // a reference to the this-type. - return lowerType(context, thisTypeSubst->witness->sub); - } - - // Somebody is trying to look up an interface requirement - // "through" some concrete type. We need to lower this decl-ref - // as a lookup of the corresponding member in a witness table. - // - // The witness table itself is referenced by the this-type - // substitution, so we can just lower that. - // - // Note: unlike the case for generics above, in the interface-lookup - // case, we don't end up caring about any further outer substitutions. - // That is because even if we are naming `ISomething.doIt()`, - // a method inside a generic interface, we don't actually care - // about the substitution of `Foo` for the parameter `T` of - // `ISomething`. That is because we really care about the - // witness table for the concrete type that conforms to `ISomething`. - // - auto irWitnessTable = lowerSimpleVal(context, thisTypeSubst->witness); - // - // The key to use for looking up the interface member is - // derived from the declaration. - // - auto irRequirementKey = getInterfaceRequirementKey(context, decl); - // - // Those two pieces of information tell us what we need to - // do in order to look up the value that satisfied the requirement. - // - auto irSatisfyingVal = context->irBuilder->emitLookupInterfaceMethodInst( - type, - irWitnessTable, - irRequirementKey); - return LoweredValInfo::simple(irSatisfyingVal); - } - else - { - SLANG_UNEXPECTED("uhandled substitution type"); - UNREACHABLE_RETURN(LoweredValInfo()); - } -} - -LoweredValInfo emitDeclRef( - IRGenContext* context, - DeclRef declRef, - IRType* type) -{ - return emitDeclRef( - context, - declRef.decl, - declRef.substitutions.substitutions, - type); -} - -static void lowerFrontEndEntryPointToIR( - IRGenContext* context, - EntryPoint* entryPoint) -{ - // TODO: We should emit an entry point as a dedicated IR function - // (distinct from the IR function used if it were called normally), - // with a mangled name based on the original function name plus - // the stage for which it is being compiled as an entry point (so - // that entry points for distinct stages always have distinct names). - // - // For now we just have an (implicit) constraint that a given - // function should only be used as an entry point for one stage, - // and any such function should *not* be used as an ordinary function. - - auto entryPointFuncDecl = entryPoint->getFuncDecl(); - - auto builder = context->irBuilder; - builder->setInsertInto(builder->getModule()->getModuleInst()); - - auto loweredEntryPointFunc = getSimpleVal(context, - ensureDecl(context, entryPointFuncDecl)); - - // Attach a marker decoration so that we recognize - // this as an entry point. - // - IRInst* instToDecorate = loweredEntryPointFunc; - if(auto irGeneric = as(instToDecorate)) - { - instToDecorate = findGenericReturnVal(irGeneric); - } - builder->addEntryPointDecoration(instToDecorate); -} - -static void lowerProgramEntryPointToIR( - IRGenContext* context, - EntryPoint* entryPoint) -{ - // First, lower the entry point like an ordinary function - - auto session = context->getSession(); - auto entryPointFuncDeclRef = entryPoint->getFuncDeclRef(); - auto entryPointFuncType = lowerType(context, getFuncType(session, entryPointFuncDeclRef)); - - auto builder = context->irBuilder; - builder->setInsertInto(builder->getModule()->getModuleInst()); - - auto loweredEntryPointFunc = getSimpleVal(context, - emitDeclRef(context, entryPointFuncDeclRef, entryPointFuncType)); - - // - if(!loweredEntryPointFunc->findDecoration()) - { - builder->addExportDecoration(loweredEntryPointFunc, getMangledName(entryPointFuncDeclRef).getUnownedSlice()); - } - - // We may have shader parameters of interface/existential type, - // which need us to supply concrete type information for specialization. - // - auto existentialTypeArgCount = entryPoint->getExistentialTypeArgCount(); - if( existentialTypeArgCount ) - { - List existentialSlotArgs; - for( Index ii = 0; ii < existentialTypeArgCount; ++ii ) - { - auto arg = entryPoint->getExistentialTypeArg(ii); - - auto irArgType = lowerType(context, arg.type); - auto irWitnessTable = lowerSimpleVal(context, arg.witness); - - existentialSlotArgs.add(irArgType); - existentialSlotArgs.add(irWitnessTable); - } - - builder->addBindExistentialSlotsDecoration(loweredEntryPointFunc, existentialSlotArgs.getCount(), existentialSlotArgs.getBuffer()); - } - - - -} - - /// Ensure that `decl` and all relevant declarations under it get emitted. -static void ensureAllDeclsRec( - IRGenContext* context, - Decl* decl) -{ - ensureDecl(context, decl); - - // Note: We are checking here for aggregate type declarations, and - // not for `ContainerDecl`s in general. This is because many kinds - // of container declarations will already take responsibility for emitting - // their children directly (e.g., a function declaration is responsible - // for emitting its own parameters). - // - // Aggregate types are the main case where we can emit an outer declaration - // and not the stuff nested inside of it. - // - if(auto containerDecl = as(decl)) - { - for (auto memberDecl : containerDecl->Members) - { - ensureAllDeclsRec(context, memberDecl); - } - } -} - -IRModule* generateIRForTranslationUnit( - TranslationUnitRequest* translationUnit) -{ - auto compileRequest = translationUnit->compileRequest; - - SharedIRGenContext sharedContextStorage( - translationUnit->getSession(), - translationUnit->compileRequest->getSink(), - translationUnit->getModuleDecl()); - SharedIRGenContext* sharedContext = &sharedContextStorage; - - IRGenContext contextStorage(sharedContext); - IRGenContext* context = &contextStorage; - - SharedIRBuilder sharedBuilderStorage; - SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; - sharedBuilder->module = nullptr; - sharedBuilder->session = compileRequest->getSession(); - - IRBuilder builderStorage; - IRBuilder* builder = &builderStorage; - builder->sharedBuilder = sharedBuilder; - - IRModule* module = builder->createModule(); - sharedBuilder->module = module; - - context->irBuilder = builder; - - // We need to emit IR for all public/exported symbols - // in the translation unit. - // - // For now, we will assume that *all* global-scope declarations - // represent public/exported symbols. - - // First, ensure that all entry points have been emitted, - // in case they require special handling. - for (auto entryPoint : translationUnit->entryPoints) - { - lowerFrontEndEntryPointToIR(context, entryPoint); - } - - // - // Next, ensure that all other global declarations have - // been emitted. - for (auto decl : translationUnit->getModuleDecl()->Members) - { - ensureAllDeclsRec(context, decl); - } - -#if 0 - fprintf(stderr, "### GENERATED\n"); - dumpIR(module); - fprintf(stderr, "###\n"); -#endif - - validateIRModuleIfEnabled(compileRequest, module); - - // We will perform certain "mandatory" optimization passes now. - // These passes serve two purposes: - // - // 1. To simplify the code that we use in backend compilation, - // or when serializing/deserializing modules, so that we can - // amortize this effort when we compile multiple entry points - // that use the same module(s). - // - // 2. To ensure certain semantic properties that can't be - // validated without dataflow information. For example, we want - // to detect when a variable might be used before it is initialized. - - // Note: if you need to debug the IR that is created before - // any mandatory optimizations have been applied, then - // uncomment this line while debugging. - - // dumpIR(module); - - // First, attempt to promote local variables to SSA - // temporaries whenever possible. - constructSSA(module); - - // Do basic constant folding and dead code elimination - // using Sparse Conditional Constant Propagation (SCCP) - // - applySparseConditionalConstantPropagation(module); - - // Propagate `constexpr`-ness through the dataflow graph (and the - // call graph) based on constraints imposed by different instructions. - propagateConstExpr(module, compileRequest->getSink()); - - // TODO: give error messages if any `undefined` or - // `unreachable` instructions remain. - - checkForMissingReturns(module, compileRequest->getSink()); - - // TODO: consider doing some more aggressive optimizations - // (in particular specialization of generics) here, so - // that we can avoid doing them downstream. - // - // Note: doing specialization or inlining involving code - // from other modules potentially makes the IR we generate - // "fragile" in that we'd now need to recompile when - // a module we depend on changes. - - validateIRModuleIfEnabled(compileRequest, module); - - // If we are being sked to dump IR during compilation, - // then we can dump the initial IR for the module here. - if(compileRequest->shouldDumpIR) - { - DiagnosticSinkWriter writer(compileRequest->getSink()); - dumpIR(module, &writer); - } - - return module; -} - -RefPtr generateIRForProgram( - Session* session, - Program* program, - DiagnosticSink* sink) -{ -// auto compileRequest = translationUnit->compileRequest; - - SharedIRGenContext sharedContextStorage( - session, - sink); - SharedIRGenContext* sharedContext = &sharedContextStorage; - - IRGenContext contextStorage(sharedContext); - IRGenContext* context = &contextStorage; - - SharedIRBuilder sharedBuilderStorage; - SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; - sharedBuilder->module = nullptr; - sharedBuilder->session = session; - - IRBuilder builderStorage; - IRBuilder* builder = &builderStorage; - builder->sharedBuilder = sharedBuilder; - - RefPtr module = builder->createModule(); - sharedBuilder->module = module; - - context->irBuilder = builder; - - // We need to emit symbols for all of the entry - // points in the program; this is especially - // important in the case where a generic entry - // point is being specialized. - // - for(auto entryPoint : program->getEntryPoints()) - { - lowerProgramEntryPointToIR(context, entryPoint); - } - - - // Now lower all the arguments supplied for global generic - // type parameters. - // - for (RefPtr subst = program->getGlobalGenericSubstitution(); subst; subst = subst->outer) - { - auto gSubst = subst.as(); - if(!gSubst) - continue; - - IRInst* typeParam = getSimpleVal(context, ensureDecl(context, gSubst->paramDecl)); - IRType* typeVal = lowerType(context, gSubst->actualType); - - // bind `typeParam` to `typeVal` - builder->emitBindGlobalGenericParam(typeParam, typeVal); - - for (auto& constraintArg : gSubst->constraintArgs) - { - IRInst* constraintParam = getSimpleVal(context, ensureDecl(context, constraintArg.decl)); - IRInst* constraintVal = lowerSimpleVal(context, constraintArg.val); - - // bind `constraintParam` to `constraintVal` - builder->emitBindGlobalGenericParam(constraintParam, constraintVal); - } - } - - // We may have shader parameters of interface/existential type, - // which need us to supply concrete type information for specialization. - // - auto existentialTypeArgCount = program->getExistentialTypeArgCount(); - if( existentialTypeArgCount ) - { - List existentialSlotArgs; - for( Index ii = 0; ii < existentialTypeArgCount; ++ii ) - { - auto arg = program->getExistentialTypeArg(ii); - - auto irArgType = lowerType(context, arg.type); - auto irWitnessTable = lowerSimpleVal(context, arg.witness); - - existentialSlotArgs.add(irArgType); - existentialSlotArgs.add(irWitnessTable); - } - - builder->emitBindGlobalExistentialSlots(existentialSlotArgs.getCount(), existentialSlotArgs.getBuffer()); - } - - - // TODO: Should we apply any of the validation or - // mandatory optimization passes here? - - return module; -} - -} // namespace Slang diff --git a/source/slang/lower-to-ir.h b/source/slang/lower-to-ir.h deleted file mode 100644 index 6ac2e182a..000000000 --- a/source/slang/lower-to-ir.h +++ /dev/null @@ -1,28 +0,0 @@ -// lower.h -#ifndef SLANG_LOWER_TO_IR_H_INCLUDED -#define SLANG_LOWER_TO_IR_H_INCLUDED - -// The lowering step translates from a (type-checked) AST into -// our intermediate representation, to facilitate further -// optimization and transformation. - -#include "../core/basic.h" - -#include "compiler.h" -#include "ir.h" - -namespace Slang -{ - class EntryPoint; - class ProgramLayout; - class TranslationUnitRequest; - - IRModule* generateIRForTranslationUnit( - TranslationUnitRequest* translationUnit); - - RefPtr generateIRForProgram( - Session* session, - Program* program, - DiagnosticSink* sink); -} -#endif diff --git a/source/slang/mangle.cpp b/source/slang/mangle.cpp deleted file mode 100644 index 2ed9cfea6..000000000 --- a/source/slang/mangle.cpp +++ /dev/null @@ -1,478 +0,0 @@ -#include "mangle.h" - -#include "name.h" -#include "syntax.h" - -namespace Slang -{ - struct ManglingContext - { - StringBuilder sb; - }; - - void emitRaw( - ManglingContext* context, - char const* text) - { - context->sb.append(text); - } - - void emit( - ManglingContext* context, - UInt value) - { - context->sb.append(value); - } - - void emit( - ManglingContext* context, - String const& value) - { - context->sb.append(value); - } - - void emitName( - ManglingContext* context, - Name* name) - { - String str = getText(name); - - // If the name consists of only traditional "identifer characters" - // (`[a-zA-Z_]`), then we wnat to emit it more or less directly. - // - // If it contains code points outside that range, we'll need to - // do something to encode them. I don't want to deal with - // that right now, so I'm going to ignore it. - - // We prefix the string with its byte length, so that - // decoding doesn't have to worry about finding a terminator. - Index length = str.getLength(); - emit(context, length); - context->sb.append(str); - } - - void emitVal( - ManglingContext* context, - Val* val); - - void emitQualifiedName( - ManglingContext* context, - DeclRef declRef); - - void emitSimpleIntVal( - ManglingContext* context, - Val* val) - { - if( auto constVal = as(val) ) - { - auto cVal = constVal->value; - if(cVal >= 0 && cVal <= 9 ) - { - emit(context, (UInt)cVal); - return; - } - } - - // Fallback: - emitVal(context, val); - } - - void emitBaseType( - ManglingContext* context, - BaseType baseType) - { - switch( baseType ) - { - case BaseType::Void: emitRaw(context, "V"); break; - case BaseType::Bool: emitRaw(context, "b"); break; - case BaseType::Int8: emitRaw(context, "c"); break; - case BaseType::Int16: emitRaw(context, "s"); break; - case BaseType::Int: emitRaw(context, "i"); break; - case BaseType::Int64: emitRaw(context, "I"); break; - case BaseType::UInt8: emitRaw(context, "C"); break; - case BaseType::UInt16: emitRaw(context, "S"); break; - case BaseType::UInt: emitRaw(context, "u"); break; - case BaseType::UInt64: emitRaw(context, "U"); break; - case BaseType::Half: emitRaw(context, "h"); break; - case BaseType::Float: emitRaw(context, "f"); break; - case BaseType::Double: emitRaw(context, "d"); break; - break; - - default: - SLANG_UNEXPECTED("unimplemented case in mangling"); - break; - } - } - - void emitType( - ManglingContext* context, - Type* type) - { - // TODO: actually implement this bit... - - if( auto basicType = dynamicCast(type) ) - { - emitBaseType(context, basicType->baseType); - } - else if( auto vecType = dynamicCast(type) ) - { - emitRaw(context, "v"); - emitSimpleIntVal(context, vecType->elementCount); - emitType(context, vecType->elementType); - } - else if( auto matType = dynamicCast(type) ) - { - emitRaw(context, "m"); - emitSimpleIntVal(context, matType->getRowCount()); - emitRaw(context, "x"); - emitSimpleIntVal(context, matType->getColumnCount()); - emitType(context, matType->getElementType()); - } - else if( auto namedType = dynamicCast(type) ) - { - emitType(context, GetType(namedType->declRef)); - } - else if( auto declRefType = dynamicCast(type) ) - { - emitQualifiedName(context, declRefType->declRef); - } - else if (auto arrType = dynamicCast(type)) - { - emitRaw(context, "a"); - emitSimpleIntVal(context, arrType->ArrayLength); - emitType(context, arrType->baseType); - } - else if( auto taggedUnionType = dynamicCast(type) ) - { - emitRaw(context, "u"); - for( auto caseType : taggedUnionType->caseTypes ) - { - emitType(context, caseType); - } - emitRaw(context, "U"); - } - else - { - SLANG_UNEXPECTED("unimplemented case in mangling"); - } - } - - void emitVal( - ManglingContext* context, - Val* val) - { - if( auto type = dynamicCast(val) ) - { - emitType(context, type); - } - else if( auto witness = dynamicCast(val) ) - { - // We don't emit witnesses as part of a mangled - // name, because the way that the front-end - // arrived at the witness is not important; - // what matters is that the type constraint - // was satisfied. - // - // TODO: make sure we can't get name collisions - // between specializations of declarations - // with the same numbers of generic parameters, - // but different constraints. We might have - // to mangle in the constraints even when - // the whole thing is specialized... - } - else if( auto genericParamIntVal = dynamicCast(val) ) - { - // TODO: we shouldn't be including the names of generic parameters - // anywhere in mangled names, since changing parameter names - // shouldn't break binary compatibility. - // - // The right solution in the long term is for generic parameters - // (both types and values) to be mangled in terms of their - // "depth" (how many outer generics) and "index" (which - // parameter are they at the specified depth). - emitRaw(context, "K"); - emitName(context, genericParamIntVal->declRef.GetName()); - } - else if( auto constantIntVal = dynamicCast(val) ) - { - // TODO: need to figure out what prefix/suffix is needed - // to allow demangling later. - emitRaw(context, "k"); - emit(context, (UInt) constantIntVal->value); - } - else - { - SLANG_UNEXPECTED("unimplemented case in mangling"); - } - } - - void emitQualifiedName( - ManglingContext* context, - DeclRef declRef) - { - auto parentDeclRef = declRef.GetParent(); - auto parentGenericDeclRef = parentDeclRef.as(); - if( parentDeclRef ) - { - // In certain cases we want to skip emitting the parent - if(parentGenericDeclRef && (parentGenericDeclRef.getDecl()->inner.Ptr() != declRef.getDecl())) - { - } - else if(parentDeclRef.as()) - { - } - else - { - emitQualifiedName(context, parentDeclRef); - } - } - - // A generic declaration is kind of a pseudo-declaration - // as far as the user is concerned; so we don't want - // to emit its name. - if(auto genericDeclRef = declRef.as()) - { - return; - } - - // Inheritance declarations don't have meaningful names, - // and so we should emit them based on the type - // that is doing the inheriting. - if(auto inheritanceDeclRef = declRef.as()) - { - emit(context, "I"); - emitType(context, GetSup(inheritanceDeclRef)); - return; - } - - // Similarly, an extension doesn't have a name worth - // emitting, and we should base things on its target - // type instead. - if(auto extensionDeclRef = declRef.as()) - { - // TODO: as a special case, an "unconditional" extension - // that is in the same module as the type it extends should - // be treated as equivalent to the type itself. - emit(context, "X"); - emitType(context, GetTargetType(extensionDeclRef)); - return; - } - - emitName(context, declRef.GetName()); - - // Special case: accessors need some way to distinguish themselves - // so that a getter/setter/ref-er don't all compile to the same name. - { - if (declRef.is()) emitRaw(context, "Ag"); - if (declRef.is()) emitRaw(context, "As"); - if (declRef.is()) emitRaw(context, "Ar"); - } - - // Are we the "inner" declaration beneath a generic decl? - if(parentGenericDeclRef && (parentGenericDeclRef.getDecl()->inner.Ptr() == declRef.getDecl())) - { - // There are two cases here: either we have specializations - // in place for the parent generic declaration, or we don't. - - auto subst = findInnerMostGenericSubstitution(declRef.substitutions); - if( subst && subst->genericDecl == parentGenericDeclRef.getDecl() ) - { - // This is the case where we *do* have substitutions. - emitRaw(context, "G"); - UInt genericArgCount = subst->args.getCount(); - emit(context, genericArgCount); - for( auto aa : subst->args ) - { - emitVal(context, aa); - } - } - else - { - // We don't have substitutions, so we will emit - // information about the parameters of the generic here. - emitRaw(context, "g"); - UInt genericParameterCount = 0; - for( auto mm : getMembers(parentGenericDeclRef) ) - { - if(mm.is()) - { - genericParameterCount++; - } - else if(mm.is()) - { - genericParameterCount++; - } - else if(mm.is()) - { - genericParameterCount++; - } - else - { - } - } - - emit(context, genericParameterCount); - for( auto mm : getMembers(parentGenericDeclRef) ) - { - if(auto genericTypeParamDecl = mm.as()) - { - emitRaw(context, "T"); - } - else if(auto genericValueParamDecl = mm.as()) - { - emitRaw(context, "v"); - emitType(context, GetType(genericValueParamDecl)); - } - else if(mm.as()) - { - emitRaw(context, "C"); - // TODO: actually emit info about the constraint - } - else - { - } - } - } - } - - // If the declaration has parameters, then we need to emit - // those parameters to distinguish it from other declarations - // of the same name that might have different parameters. - // - // We'll also go ahead and emit the result type as well, - // just for completeness. - // - if( auto callableDeclRef = declRef.as()) - { - auto parameters = GetParameters(callableDeclRef); - UInt parameterCount = parameters.Count(); - - emitRaw(context, "p"); - emit(context, parameterCount); - emitRaw(context, "p"); - - for(auto paramDeclRef : parameters) - { - emitType(context, GetType(paramDeclRef)); - } - - // Don't print result type for an initializer/constructor, - // since it is implicit in the qualified name. - if (!callableDeclRef.is()) - { - emitType(context, GetResultType(callableDeclRef)); - } - } - } - - void mangleName( - ManglingContext* context, - DeclRef declRef) - { - // TODO: catch cases where the declaration should - // forward to something else? E.g., what if we - // are asked to mangle the name of a `typedef`? - - // We will start with a unique prefix to avoid - // clashes with user-defined symbols: - emitRaw(context, "_S"); - - auto decl = declRef.getDecl(); - - // Next we will add a bit of info to register - // the *kind* of declaration we are dealing with. - // - // Functions will get no prefix, since we assume - // they are a common case: - if(as(decl)) - {} - // Types will get a `T` prefix: - else if(as(decl)) - emitRaw(context, "T"); - else if(as(decl)) - emitRaw(context, "T"); - // Variables will get a `V` prefix: - // - // TODO: probably need to pull constant-buffer - // declarations out of this... - else if(as(decl)) - emitRaw(context, "V"); - else - { - // TODO: handle other cases - } - - // Now we encode the qualified name of the decl. - emitQualifiedName(context, declRef); - } - - String getMangledName(DeclRef const& declRef) - { - ManglingContext context; - mangleName(&context, declRef); - return context.sb.ProduceString(); - } - - String getMangledName(DeclRefBase const & declRef) - { - return getMangledName( - DeclRef(declRef.decl, declRef.substitutions)); - } - - String getMangledName(Decl* decl) - { - return getMangledName(makeDeclRef(decl)); - } - - String getMangledNameForConformanceWitness( - DeclRef sub, - DeclRef sup) - { - ManglingContext context; - emitRaw(&context, "_SW"); - emitQualifiedName(&context, sub); - emitQualifiedName(&context, sup); - return context.sb.ProduceString(); - } - - String getMangledNameForConformanceWitness( - DeclRef sub, - Type* sup) - { - // The mangled form for a witness that `sub` - // conforms to `sup` will be named: - // - // {Conforms(sub,sup)} => _SW{sub}{sup} - // - ManglingContext context; - emitRaw(&context, "_SW"); - emitQualifiedName(&context, sub); - emitType(&context, sup); - return context.sb.ProduceString(); - } - - String getMangledNameForConformanceWitness( - Type* sub, - Type* sup) - { - // The mangled form for a witness that `sub` - // conforms to `sup` will be named: - // - // {Conforms(sub,sup)} => _SW{sub}{sup} - // - ManglingContext context; - emitRaw(&context, "_SW"); - emitType(&context, sub); - emitType(&context, sup); - return context.sb.ProduceString(); - } - - String getMangledTypeName(Type* type) - { - ManglingContext context; - emitType(&context, type); - return context.sb.ProduceString(); - } - - -} diff --git a/source/slang/mangle.h b/source/slang/mangle.h deleted file mode 100644 index 7fc8d0d93..000000000 --- a/source/slang/mangle.h +++ /dev/null @@ -1,29 +0,0 @@ -#ifndef SLANG_MANGLE_H_INCLUDED -#define SLANG_MANGLE_H_INCLUDED - -// This file implements the name mangling scheme for the Slang language. - -#include "../core/basic.h" -#include "syntax.h" - -namespace Slang -{ - struct IRSpecialize; - - String getMangledName(Decl* decl); - String getMangledName(DeclRef const & declRef); - String getMangledName(DeclRefBase const & declRef); - - String getMangledNameForConformanceWitness( - Type* sub, - Type* sup); - String getMangledNameForConformanceWitness( - DeclRef sub, - DeclRef sup); - String getMangledNameForConformanceWitness( - DeclRef sub, - Type* sup); - String getMangledTypeName(Type* type); -} - -#endif \ No newline at end of file diff --git a/source/slang/modifier-defs.h b/source/slang/modifier-defs.h deleted file mode 100644 index 01ed792a9..000000000 --- a/source/slang/modifier-defs.h +++ /dev/null @@ -1,463 +0,0 @@ -// modifier-defs.h - -// Syntax class definitions for modifiers. - -// Simple modifiers have no state beyond their identity -#define SIMPLE_MODIFIER(NAME) \ - SIMPLE_SYNTAX_CLASS(NAME##Modifier, Modifier) - -SIMPLE_MODIFIER(In); -SIMPLE_MODIFIER(Out); -SIMPLE_MODIFIER(Const); -SIMPLE_MODIFIER(Instance); -SIMPLE_MODIFIER(Builtin); -SIMPLE_MODIFIER(Inline); -SIMPLE_MODIFIER(Public); -SIMPLE_MODIFIER(Require); -SIMPLE_MODIFIER(Param); -SIMPLE_MODIFIER(Extern); -SIMPLE_MODIFIER(Input); -SIMPLE_MODIFIER(Transparent); -SIMPLE_MODIFIER(FromStdLib); -SIMPLE_MODIFIER(Prefix); -SIMPLE_MODIFIER(Postfix); -SIMPLE_MODIFIER(Exported); -SIMPLE_MODIFIER(ConstExpr); -SIMPLE_MODIFIER(GloballyCoherent) - -#undef SIMPLE_MODIFIER - -// A modifier that marks something as an operation that -// has a one-to-one translation to the IR, and thus -// has no direct definition in the high-level language. -// -SYNTAX_CLASS(IntrinsicOpModifier, Modifier) - - // token that names the intrinsic op - FIELD(Token, opToken) - - // The opcode for the intrinsic operation - FIELD_INIT(IROp, op, kIROp_Nop) -END_SYNTAX_CLASS() - -// A modifier that marks something as an intrinsic function, -// for some subset of targets. -SYNTAX_CLASS(TargetIntrinsicModifier, Modifier) - // Token that names the target that the operation - // is an intrisic for. - FIELD(Token, targetToken) - - // A custom definition for the operation - FIELD(Token, definitionToken) -END_SYNTAX_CLASS() - -// A modifier that marks a declaration as representing a -// specialization that should be preferred on a particular -// target. -SYNTAX_CLASS(SpecializedForTargetModifier, Modifier) - // Token that names the target that the operation - // has been specialized for. - FIELD(Token, targetToken) -END_SYNTAX_CLASS() - -// A modifier to tag something as an intrinsic that requires -// a certain GLSL extension to be enabled when used -SYNTAX_CLASS(RequiredGLSLExtensionModifier, Modifier) -FIELD(Token, extensionNameToken) -END_SYNTAX_CLASS() - -// A modifier to tag something as an intrinsic that requires -// a certain GLSL version to be enabled when used -SYNTAX_CLASS(RequiredGLSLVersionModifier, Modifier) -FIELD(Token, versionNumberToken) -END_SYNTAX_CLASS() - - -SIMPLE_SYNTAX_CLASS(InOutModifier, OutModifier) - -// `__ref` modifier for by-reference parameter passing -SIMPLE_SYNTAX_CLASS(RefModifier, Modifier) - -// This is a special sentinel modifier that gets added -// to the list when we have multiple variable declarations -// all sharing the same modifiers: -// -// static uniform int a : FOO, *b : register(x0); -// -// In this case both `a` and `b` share the syntax -// for part of their modifier list, but then have -// their own modifiers as well: -// -// a: SemanticModifier("FOO") --> SharedModifiers --> StaticModifier --> UniformModifier -// / -// b: RegisterModifier("x0") / -// -SIMPLE_SYNTAX_CLASS(SharedModifiers, Modifier) - -// A GLSL `layout` modifier -// -// We use a distinct modifier for each key that -// appears within the `layout(...)` construct, -// and each key might have an optional value token. -// -// TODO: We probably want a notion of "modifier groups" -// so that we can recover good source location info -// for modifiers that were part of the same vs. -// different constructs. -ABSTRACT_SYNTAX_CLASS(GLSLLayoutModifier, Modifier) - -// The token used to introduce the modifier is stored -// as the `nameToken` field. - -// TODO: may want to accept a full expression here -FIELD(Token, valToken) -END_SYNTAX_CLASS() - -// AST nodes to represent the begin/end of a `layout` modifier group -ABSTRACT_SYNTAX_CLASS(GLSLLayoutModifierGroupMarker, Modifier) -END_SYNTAX_CLASS() -SIMPLE_SYNTAX_CLASS(GLSLLayoutModifierGroupBegin, GLSLLayoutModifierGroupMarker) -SIMPLE_SYNTAX_CLASS(GLSLLayoutModifierGroupEnd, GLSLLayoutModifierGroupMarker) - -// We divide GLSL `layout` modifiers into those we have parsed -// (in the sense of having some notion of their semantics), and -// those we have not. -ABSTRACT_SYNTAX_CLASS(GLSLParsedLayoutModifier , GLSLLayoutModifier) -END_SYNTAX_CLASS() - -SIMPLE_SYNTAX_CLASS(GLSLUnparsedLayoutModifier , GLSLLayoutModifier) - -// Specific cases for known GLSL `layout` modifiers that we need to work with -SIMPLE_SYNTAX_CLASS(GLSLConstantIDLayoutModifier , GLSLParsedLayoutModifier) -SIMPLE_SYNTAX_CLASS(GLSLLocationLayoutModifier , GLSLParsedLayoutModifier) - -SIMPLE_SYNTAX_CLASS(GLSLLocalSizeLayoutModifier, GLSLUnparsedLayoutModifier) -SIMPLE_SYNTAX_CLASS(GLSLLocalSizeXLayoutModifier, GLSLLocalSizeLayoutModifier) -SIMPLE_SYNTAX_CLASS(GLSLLocalSizeYLayoutModifier, GLSLLocalSizeLayoutModifier) -SIMPLE_SYNTAX_CLASS(GLSLLocalSizeZLayoutModifier, GLSLLocalSizeLayoutModifier) - -// A catch-all for single-keyword modifiers -SIMPLE_SYNTAX_CLASS(SimpleModifier, Modifier) - -// Some GLSL-specific modifiers -SIMPLE_SYNTAX_CLASS(GLSLBufferModifier , SimpleModifier) -SIMPLE_SYNTAX_CLASS(GLSLWriteOnlyModifier, SimpleModifier) -SIMPLE_SYNTAX_CLASS(GLSLReadOnlyModifier , SimpleModifier) -SIMPLE_SYNTAX_CLASS(GLSLPatchModifier , SimpleModifier) - -// Indicates that this is a variable declaration that corresponds to -// a parameter block declaration in the source program. -SIMPLE_SYNTAX_CLASS(ImplicitParameterGroupVariableModifier , Modifier) - -// Indicates that this is a type that corresponds to the element -// type of a parameter block declaration in the source program. -SIMPLE_SYNTAX_CLASS(ImplicitParameterGroupElementTypeModifier, Modifier) - -// An HLSL semantic -ABSTRACT_SYNTAX_CLASS(HLSLSemantic, Modifier) - FIELD(Token, name) -END_SYNTAX_CLASS() - -// An HLSL semantic that affects layout -SYNTAX_CLASS(HLSLLayoutSemantic, HLSLSemantic) - - FIELD(Token, registerName) - FIELD(Token, componentMask) -END_SYNTAX_CLASS() - -// An HLSL `register` semantic -SYNTAX_CLASS(HLSLRegisterSemantic, HLSLLayoutSemantic) - FIELD(Token, spaceName) -END_SYNTAX_CLASS() - -// TODO(tfoley): `packoffset` -SIMPLE_SYNTAX_CLASS(HLSLPackOffsetSemantic, HLSLLayoutSemantic) - -// An HLSL semantic that just associated a declaration with a semantic name -SIMPLE_SYNTAX_CLASS(HLSLSimpleSemantic, HLSLSemantic) - -// GLSL - -// Directives that came in via the preprocessor, but -// that we need to keep around for later steps -SIMPLE_SYNTAX_CLASS(GLSLPreprocessorDirective, Modifier) - -// A GLSL `#version` directive -SYNTAX_CLASS(GLSLVersionDirective, GLSLPreprocessorDirective) - - // Token giving the version number to use - FIELD(Token, versionNumberToken) - - // Optional token giving the sub-profile to be used - FIELD(Token, glslProfileToken) -END_SYNTAX_CLASS() - -// A GLSL `#extension` directive -SYNTAX_CLASS(GLSLExtensionDirective, GLSLPreprocessorDirective) - - // Token giving the version number to use - FIELD(Token, extensionNameToken) - - // Optional token giving the sub-profile to be used - FIELD(Token, dispositionToken) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(ParameterGroupReflectionName, Modifier) - FIELD(NameLoc, nameAndLoc) -END_SYNTAX_CLASS() - -// A modifier that indicates a built-in base type (e.g., `float`) -SYNTAX_CLASS(BuiltinTypeModifier, Modifier) - FIELD(BaseType, tag) -END_SYNTAX_CLASS() - -// A modifier that indicates a built-in type that isn't a base type (e.g., `vector`) -// -// TODO(tfoley): This deserves a better name than "magic" -SYNTAX_CLASS(MagicTypeModifier, Modifier) - FIELD(String, name) - FIELD(uint32_t, tag) -END_SYNTAX_CLASS() - -// A modifier applied to declarations of builtin types to indicate how they -// should be lowered to the IR. -// -// TODO: This should really subsume `BuiltinTypeModifier` and -// `MagicTypeModifier` so that we don't have to apply all of them. -SYNTAX_CLASS(IntrinsicTypeModifier, Modifier) - // The IR opcode to use when constructing a type - FIELD(uint32_t, irOp) - - // Additional literal opreands to provide when creating instances. - // (e.g., for a texture type this passes in shape/mutability info) - FIELD(List, irOperands) -END_SYNTAX_CLASS() - -// Modifiers that affect the storage layout for matrices -SIMPLE_SYNTAX_CLASS(MatrixLayoutModifier, Modifier) - -// Modifiers that specify row- and column-major layout, respectively -SIMPLE_SYNTAX_CLASS(RowMajorLayoutModifier, MatrixLayoutModifier) -SIMPLE_SYNTAX_CLASS(ColumnMajorLayoutModifier, MatrixLayoutModifier) - -// The HLSL flavor of those modifiers -SIMPLE_SYNTAX_CLASS(HLSLRowMajorLayoutModifier, RowMajorLayoutModifier) -SIMPLE_SYNTAX_CLASS(HLSLColumnMajorLayoutModifier, ColumnMajorLayoutModifier) - -// The GLSL flavor of those modifiers -// -// Note(tfoley): The GLSL versions of these modifiers are "backwards" -// in the sense that when a GLSL programmer requests row-major layout, -// we actually interpret that as requesting column-major. This makes -// sense because we interpret matrix conventions backwards from how -// GLSL specifies them. -SIMPLE_SYNTAX_CLASS(GLSLRowMajorLayoutModifier, ColumnMajorLayoutModifier) -SIMPLE_SYNTAX_CLASS(GLSLColumnMajorLayoutModifier, RowMajorLayoutModifier) - -// More HLSL Keyword - -ABSTRACT_SYNTAX_CLASS(InterpolationModeModifier, Modifier) -END_SYNTAX_CLASS() - -// HLSL `nointerpolation` modifier -SIMPLE_SYNTAX_CLASS(HLSLNoInterpolationModifier, InterpolationModeModifier) - -// HLSL `noperspective` modifier -SIMPLE_SYNTAX_CLASS(HLSLNoPerspectiveModifier, InterpolationModeModifier) - -// HLSL `linear` modifier -SIMPLE_SYNTAX_CLASS(HLSLLinearModifier, InterpolationModeModifier) - -// HLSL `sample` modifier -SIMPLE_SYNTAX_CLASS(HLSLSampleModifier, InterpolationModeModifier) - -// HLSL `centroid` modifier -SIMPLE_SYNTAX_CLASS(HLSLCentroidModifier, InterpolationModeModifier) - -// HLSL `precise` modifier -SIMPLE_SYNTAX_CLASS(PreciseModifier, Modifier) - -// HLSL `shared` modifier (which is used by the effect system, -// and shouldn't be confused with `groupshared`) -SIMPLE_SYNTAX_CLASS(HLSLEffectSharedModifier, Modifier) - -// HLSL `groupshared` modifier -SIMPLE_SYNTAX_CLASS(HLSLGroupSharedModifier, Modifier) - -// HLSL `static` modifier (probably doesn't need to be -// treated as HLSL-specific) -SIMPLE_SYNTAX_CLASS(HLSLStaticModifier, Modifier) - -// HLSL `uniform` modifier (distinct meaning from GLSL -// use of the keyword) -SIMPLE_SYNTAX_CLASS(HLSLUniformModifier, Modifier) - -// HLSL `volatile` modifier (ignored) -SIMPLE_SYNTAX_CLASS(HLSLVolatileModifier, Modifier) - -SYNTAX_CLASS(AttributeTargetModifier, Modifier) - // A class to which the declared attribute type is applicable - FIELD(SyntaxClass, syntaxClass) -END_SYNTAX_CLASS() - -// Base class for checked and unchecked `[name(arg0, ...)]` style attribute. -SYNTAX_CLASS(AttributeBase, Modifier) - SYNTAX_FIELD(List>, args) -END_SYNTAX_CLASS() - -// A `[name(...)]` attribute that hasn't undergone any semantic analysis. -// After analysis, this will be transformed into a more specific case. -SYNTAX_CLASS(UncheckedAttribute, AttributeBase) - FIELD(RefPtr, scope) -END_SYNTAX_CLASS() - -// A `[name(arg0, ...)]` style attribute that has been validated. -SYNTAX_CLASS(Attribute, AttributeBase) - FIELD(AttributeArgumentValueDict, intArgVals) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(UserDefinedAttribute, Attribute) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(AttributeUsageAttribute, Attribute) - FIELD(SyntaxClass, targetSyntaxClass) -END_SYNTAX_CLASS() - -// An `[unroll]` or `[unroll(count)]` attribute -SYNTAX_CLASS(UnrollAttribute, Attribute) - RAW(IntegerLiteralValue getCount();) -END_SYNTAX_CLASS() - -SIMPLE_SYNTAX_CLASS(LoopAttribute, Attribute) // `[loop]` -SIMPLE_SYNTAX_CLASS(FastOptAttribute, Attribute) // `[fastopt]` -SIMPLE_SYNTAX_CLASS(AllowUAVConditionAttribute, Attribute) // `[allow_uav_condition]` -SIMPLE_SYNTAX_CLASS(BranchAttribute, Attribute) // `[branch]` -SIMPLE_SYNTAX_CLASS(FlattenAttribute, Attribute) // `[flatten]` -SIMPLE_SYNTAX_CLASS(ForceCaseAttribute, Attribute) // `[forcecase]` -SIMPLE_SYNTAX_CLASS(CallAttribute, Attribute) // `[call]` - - -// [[vk_push_constant]] [[push_constant]] -SIMPLE_SYNTAX_CLASS(PushConstantAttribute, Attribute) - -// [[vk_shader_record]] [[shader_record]] -SIMPLE_SYNTAX_CLASS(ShaderRecordAttribute, Attribute) - -// [[vk_binding]] -SYNTAX_CLASS(GLSLBindingAttribute, Attribute) - FIELD(int32_t, binding = 0) - FIELD(int32_t, set = 0) -END_SYNTAX_CLASS() - -// TODO: for attributes that take arguments, the syntax node -// classes should provide accessors for the values of those arguments. - -SIMPLE_SYNTAX_CLASS(MaxTessFactorAttribute, Attribute) -SIMPLE_SYNTAX_CLASS(OutputControlPointsAttribute, Attribute) -SIMPLE_SYNTAX_CLASS(OutputTopologyAttribute, Attribute) -SIMPLE_SYNTAX_CLASS(PartitioningAttribute, Attribute) -SYNTAX_CLASS(PatchConstantFuncAttribute, Attribute) - FIELD(RefPtr, patchConstantFuncDecl) -END_SYNTAX_CLASS() -SIMPLE_SYNTAX_CLASS(DomainAttribute, Attribute) - -SIMPLE_SYNTAX_CLASS(EarlyDepthStencilAttribute, Attribute) // `[earlydepthstencil]` - -// An HLSL `[numthreads(x,y,z)]` attribute -SYNTAX_CLASS(NumThreadsAttribute, Attribute) - // The number of threads to use along each axis - // - // TODO: These should be accessors that use the - // ordinary `args` list, rather than side data. - FIELD(int32_t, x) - FIELD(int32_t, y) - FIELD(int32_t, z) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(MaxVertexCountAttribute, Attribute) - // The number of max vertex count for geometry shader - // - // TODO: This should be an accessor that uses the - // ordinary `args` list, rather than side data. - FIELD(int32_t, value) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(InstanceAttribute, Attribute) - // The number of instances to run for geometry shader - // - // TODO: This should be an accessor that uses the - // ordinary `args` list, rather than side data. - FIELD(int32_t, value) -END_SYNTAX_CLASS() - -// A `[shader("stageName")]` attribute, which marks an entry point -// to be compiled, and specifies the stage for that entry point -SYNTAX_CLASS(EntryPointAttribute, Attribute) - // The resolved stage that the entry point is targetting. - // - // TODO: This should be an accessor that uses the - // ordinary `args` list, rather than side data. - FIELD(Stage, stage); -END_SYNTAX_CLASS() - -// A `[__vulkanRayPayload]` attribute, which is used in the -// standard library implementation to indicate that a variable -// actually represents the input/output interface for a Vulkan -// ray tracing shader to pass per-ray payload information. -SIMPLE_SYNTAX_CLASS(VulkanRayPayloadAttribute, Attribute) - -// A `[__vulkanCallablePayload]` attribute, which is used in the -// standard library implementation to indicate that a variable -// actually represents the input/output interface for a Vulkan -// ray tracing shader to pass payload information to/from a callee. -SIMPLE_SYNTAX_CLASS(VulkanCallablePayloadAttribute, Attribute) - -// A `[__vulkanHitAttributes]` attribute, which is used in the -// standard library implementation to indicate that a variable -// actually represents the output interface for a Vulkan -// intersection shader to pass hit attribute information. -SIMPLE_SYNTAX_CLASS(VulkanHitAttributesAttribute, Attribute) - -// A `[mutating]` attribute, which indicates that a member -// function is allowed to modify things through its `this` -// argument. -// -SIMPLE_SYNTAX_CLASS(MutatingAttribute, Attribute) - -// A `[__readNone]` attribute, which indicates that a function -// computes its results strictly based on argument values, without -// reading or writing through any pointer arguments, or any other -// state that could be observed by a caller. -// -SIMPLE_SYNTAX_CLASS(ReadNoneAttribute, Attribute) - - -// HLSL modifiers for geometry shader input topology -SIMPLE_SYNTAX_CLASS(HLSLGeometryShaderInputPrimitiveTypeModifier, Modifier) -SIMPLE_SYNTAX_CLASS(HLSLPointModifier , HLSLGeometryShaderInputPrimitiveTypeModifier) -SIMPLE_SYNTAX_CLASS(HLSLLineModifier , HLSLGeometryShaderInputPrimitiveTypeModifier) -SIMPLE_SYNTAX_CLASS(HLSLTriangleModifier , HLSLGeometryShaderInputPrimitiveTypeModifier) -SIMPLE_SYNTAX_CLASS(HLSLLineAdjModifier , HLSLGeometryShaderInputPrimitiveTypeModifier) -SIMPLE_SYNTAX_CLASS(HLSLTriangleAdjModifier , HLSLGeometryShaderInputPrimitiveTypeModifier) - -// A modifier to be attached to syntax after we've computed layout -SYNTAX_CLASS(ComputedLayoutModifier, Modifier) - FIELD(RefPtr, layout) -END_SYNTAX_CLASS() - - -SYNTAX_CLASS(TupleVarModifier, Modifier) -// FIELD_INIT(TupleFieldModifier*, tupleField, nullptr) -END_SYNTAX_CLASS() - -// A modifier to indicate that a constructor/initializer can be used -// to perform implicit type conversion, and to specify the cost of -// the conversion, if applied. -SYNTAX_CLASS(ImplicitConversionModifier, Modifier) - // The conversion cost, used to rank conversions - FIELD(ConversionCost, cost) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(FormatAttribute, Attribute) - FIELD(ImageFormat, format) -END_SYNTAX_CLASS() diff --git a/source/slang/name.cpp b/source/slang/name.cpp deleted file mode 100644 index 21586e0b6..000000000 --- a/source/slang/name.cpp +++ /dev/null @@ -1,37 +0,0 @@ -// name.cpp -#include "name.h" - -namespace Slang { - -String getText(Name* name) -{ - if (!name) return String(); - return name->text; -} - -UnownedStringSlice getUnownedStringSliceText(Name* name) -{ - return name ? name->text.getUnownedSlice() : UnownedStringSlice(); -} - -Name* NamePool::getName(String const& text) -{ - RefPtr name; - if (rootPool->names.TryGetValue(text, name)) - return name; - - name = new Name(); - name->text = text; - rootPool->names.Add(text, name); - return name; -} - -Name* NamePool::tryGetName(String const& text) -{ - RefPtr name; - if (rootPool->names.TryGetValue(text, name)) - return name; - return nullptr; -} - -} // namespace Slang diff --git a/source/slang/name.h b/source/slang/name.h deleted file mode 100644 index a144fbb84..000000000 --- a/source/slang/name.h +++ /dev/null @@ -1,86 +0,0 @@ -// name.h -#ifndef SLANG_NAME_H_INCLUDED -#define SLANG_NAME_H_INCLUDED - -// This file defines the `Name` type, used to represent -// the name of types, variables, etc. in the AST. - -#include "../core/basic.h" - -namespace Slang { - -// The `Name` type is used to represent the name of a type, variable, etc. -// -// The key benefit of using `Name`s instead of raw strings is that `Name`s -// can be compared for equality just by testing pointer equality. Names -// also don't require any memory management; you can just retain an ordinary -// pointer to one and not deal with reference-counting overhead. -// -// In order to provide these benefits, a `Name` can only be created using -// a `NamePool` that owns the allocations for all the names (so they get -// cleaned up when the pool is deleted), and which is responsible for -// ensuring the uniqueness of name objects. -// -class Name : public RefObject -{ -public: - // The raw text of the name. - // - // Note that at some point in the future we might have other categories - // of name than "simple" names, and so this might change to a structured - // ADT instead of a simple string. - String text; -}; - -// Get the textual string representation of a name -// (e.g., so that it can be printed). -String getText(Name* name); - -/// Get the text as unowned string slice -UnownedStringSlice getUnownedStringSliceText(Name* name); - -// A `RootNamePool` is used to store and look up names. -// If two systems need to work together with names, and be sure that they -// get equivalent names for a string like `"Foo"`, then they need to use -// the same root name pool (directly or indirectly). -// -struct RootNamePool -{ - // The mapping from text strings to the corresponding name. - Dictionary > names; -}; - -// A `NamePool` is effectively a way of storing a subset of the -// names that have been created through a `RootNamePool`. -// -// The intention is that eventually we will add the ability to clean -// up a `NamePool`, and remove the names it created from the corresponding -// `RootNamePool` *if* those names are no longer in use. -// -// The goal of such an approach would be to ensure that the memory -// usage of a `Session` can't bloat over time just because of multiple -// `CompileRequest`s being created, used, and then destroyed (each time -// adding just a few more strings to the name mapping). -// -struct NamePool -{ - // Find or create the `Name` that represents the given `text`. - Name* getName(String const& text); - // Try find the `Name` that represents the given `text`. - // If the name does not exist, return nullptr - Name* tryGetName(String const& text); - // Set the parent name pool to use for lookup - void setRootNamePool(RootNamePool* rootNamePool) - { - this->rootPool = rootNamePool; - } - - // - - // The root name pool to use for storage/lookup - RootNamePool* rootPool = nullptr; -}; - -} // namespace Slang - -#endif diff --git a/source/slang/object-meta-begin.h b/source/slang/object-meta-begin.h deleted file mode 100644 index 7340ed413..000000000 --- a/source/slang/object-meta-begin.h +++ /dev/null @@ -1,43 +0,0 @@ -// object-meta-begin.h - -#ifndef SYNTAX_CLASS -#error The 'SYNTAX_CLASS' macro should be defined before including 'object-meta-begin.h' -#endif - -#ifndef ABSTRACT_SYNTAX_CLASS -#define ABSTRACT_SYNTAX_CLASS(NAME, BASE) SYNTAX_CLASS(NAME, BASE) -#endif - -#ifndef END_SYNTAX_CLASS -#define END_SYNTAX_CLASS() /* empty */ -#endif - -#ifndef DECL_FIELD -#define DECL_FIELD(TYPE, NAME) SYNTAX_FIELD(TYPE, NAME) -#endif - -#ifndef SYNTAX_FIELD -#define SYNTAX_FIELD(TYPE, NAME) FIELD(TYPE, NAME) -#endif - -#ifndef FIELD_INIT -#define FIELD_INIT(TYPE, NAME, INIT) FIELD(TYPE, NAME) -#endif - -#ifndef FIELD -#define FIELD(...) /* empty */ -#endif - -#ifndef RAW -#define RAW(...) /* empty */ -#endif - -#define SIMPLE_SYNTAX_CLASS(NAME, BASE) SYNTAX_CLASS(NAME, BASE) END_SYNTAX_CLASS() - -// Hack to remove 'warning C4702: unreachable code' on VS2017, blocking compilation -// Note! This is matched in object-meta-end.h -#if _MSC_VER >= 1910 -#pragma warning(push) -#pragma warning(disable: 4702) -#endif - diff --git a/source/slang/object-meta-end.h b/source/slang/object-meta-end.h deleted file mode 100644 index 96f41d39d..000000000 --- a/source/slang/object-meta-end.h +++ /dev/null @@ -1,17 +0,0 @@ -// object-meta-end.h - -#undef SYNTAX_CLASS -#undef ABSTRACT_SYNTAX_CLASS -#undef END_SYNTAX_CLASS -#undef SYNTAX_FIELD -#undef FIELD -#undef FIELD_INIT -#undef DECL_FIELD -#undef RAW -#undef SIMPLE_SYNTAX_CLASS - -// Hack to remove 'warning C4702: unreachable code' on VS2017, blocking compilation -// Note! This is matched in object-meta-begin.h -#if _MSC_VER >= 1910 -#pragma warning(pop) -#endif diff --git a/source/slang/options.cpp b/source/slang/options.cpp deleted file mode 100644 index 46e0203cf..000000000 --- a/source/slang/options.cpp +++ /dev/null @@ -1,1356 +0,0 @@ -// options.cpp - -// Implementation of options parsing for `slangc` command line, -// and also for API interface that takes command-line argument strings. - -#include "../../slang.h" - -#include "compiler.h" -#include "profile.h" - -#include - -namespace Slang { - -SlangResult tryReadCommandLineArgumentRaw(DiagnosticSink* sink, char const* option, char const* const**ioCursor, char const* const*end, char const** argOut) -{ - *argOut = nullptr; - char const* const*& cursor = *ioCursor; - if (cursor == end) - { - sink->diagnose(SourceLoc(), Diagnostics::expectedArgumentForOption, option); - return SLANG_FAIL; - } - else - { - *argOut = *cursor++; - return SLANG_OK; - } -} - -SlangResult tryReadCommandLineArgument(DiagnosticSink* sink, char const* option, char const* const**ioCursor, char const* const*end, String& argOut) -{ - const char* arg; - SLANG_RETURN_ON_FAIL(tryReadCommandLineArgumentRaw(sink, option, ioCursor, end, &arg)); - argOut = arg; - return SLANG_OK; -} - -struct OptionsParser -{ - SlangSession* session = nullptr; - SlangCompileRequest* compileRequest = nullptr; - - Slang::EndToEndCompileRequest* requestImpl = nullptr; - - Slang::RefPtr sharedLibraryLoader; - - // A "translation unit" represents one or more source files - // that are processed as a single entity when it comes to - // semantic checking. - // - // For languages like HLSL, GLSL, and C, a translation unit - // is usually a single source file (which can then go on - // to `#include` other files into the same translation unit). - // - // For Slang, we support having multiple source files in - // a single translation unit, and indeed command-line `slangc` - // will always put all the source files into a single translation - // unit. - // - // We track information on the translation units that we - // create during options parsing, so that we can assocaite - // other entities with these translation units: - // - struct RawTranslationUnit - { - // What language is the translation unit using? - // - // Note: We do not support translation units that mix - // languages. - // - SlangSourceLanguage sourceLanguage; - - // Certain naming conventions imply a stage for - // a file with only a single entry point, and in - // those cases we will try to infer the stage from - // the file when it is possible. - // - Stage impliedStage; - - // We retain the Slang API level translation unit index, - // which we will call an "ID" inside the options parsing code. - // - // This will almost always be the index into the - // `rawTranslationUnits` array below, but could conceivably, - // be mismatched if we were parsing options for a compile - // request that already had some translation unit(s) added - // manually. - // - int translationUnitID; - }; - List rawTranslationUnits; - - // If we already have a translation unit for Slang code, then this will give its index. - // If not, it will be `-1`. - int slangTranslationUnitIndex = -1; - - // The number of input files that have been specified - int inputPathCount = 0; - - int translationUnitCount = 0; - int currentTranslationUnitIndex= -1; - - // An entry point represents a function to be checked and possibly have - // code generated in one of our translation units. An entry point - // needs to have an associated stage, which might come via the - // `-stage` command line option, or a `[shader("...")]` attribute - // in the source code. - // - struct RawEntryPoint - { - String name; - Stage stage = Stage::Unknown; - int translationUnitIndex = -1; - int entryPointID = -1; - - // State for tracking command-line errors - bool conflictingStagesSet = false; - bool redundantStageSet = false; - }; - // - // We collect the entry points in a "raw" array so that we can - // possibly associate them with a stage or translation unit - // after the fact. - // - List rawEntryPoints; - - // In the case where we have only a single entry point, - // the entry point and its options might be specified out - // of order, so we will keep a single `RawEntryPoint` around - // and use it as the target for any state-setting options - // before the first "proper" entry point is specified. - RawEntryPoint defaultEntryPoint; - - SlangCompileFlags flags = 0; - - struct RawOutput - { - String path; - CodeGenTarget impliedFormat = CodeGenTarget::Unknown; - int targetIndex = -1; - int entryPointIndex = -1; - }; - List rawOutputs; - - struct RawTarget - { - CodeGenTarget format = CodeGenTarget::Unknown; - ProfileVersion profileVersion = ProfileVersion::Unknown; - SlangTargetFlags targetFlags = 0; - int targetID = -1; - FloatingPointMode floatingPointMode = FloatingPointMode::Default; - - // State for tracking command-line errors - bool conflictingProfilesSet = false; - bool redundantProfileSet = false; - - }; - List rawTargets; - - RawTarget defaultTarget; - - void addSharedLibraryPath(SharedLibraryType libType, const String& path) - { - if (!sharedLibraryLoader) - { - sharedLibraryLoader = new ConfigurableSharedLibraryLoader; - } - sharedLibraryLoader->addEntry(libType, ConfigurableSharedLibraryLoader::changePath, path); - } - - int addTranslationUnit( - SlangSourceLanguage language, - Stage impliedStage) - { - auto translationUnitIndex = rawTranslationUnits.getCount(); - auto translationUnitID = spAddTranslationUnit(compileRequest, language, nullptr); - - // As a sanity check: the API should be returning the same translation - // unit index as we maintain internally. This invariant would only - // be broken if we decide to support a mix of translation units specified - // via API, and ones specified via command-line arguments. - // - SLANG_RELEASE_ASSERT(Index(translationUnitID) == translationUnitIndex); - - RawTranslationUnit rawTranslationUnit; - rawTranslationUnit.sourceLanguage = language; - rawTranslationUnit.translationUnitID = translationUnitID; - rawTranslationUnit.impliedStage = impliedStage; - - rawTranslationUnits.add(rawTranslationUnit); - - return int(translationUnitIndex); - } - - void addInputSlangPath( - String const& path) - { - // All of the input .slang files will be grouped into a single logical translation unit, - // which we create lazily when the first .slang file is encountered. - if( slangTranslationUnitIndex == -1 ) - { - translationUnitCount++; - slangTranslationUnitIndex = addTranslationUnit(SLANG_SOURCE_LANGUAGE_SLANG, Stage::Unknown); - } - - spAddTranslationUnitSourceFile( - compileRequest, - rawTranslationUnits[slangTranslationUnitIndex].translationUnitID, - path.begin()); - - // Set the translation unit to be used by subsequent entry points - currentTranslationUnitIndex = slangTranslationUnitIndex; - } - - void addInputForeignShaderPath( - String const& path, - SlangSourceLanguage language, - Stage impliedStage) - { - translationUnitCount++; - currentTranslationUnitIndex = addTranslationUnit(language, impliedStage); - - spAddTranslationUnitSourceFile( - compileRequest, - rawTranslationUnits[currentTranslationUnitIndex].translationUnitID, - path.begin()); - } - - static Profile::RawVal findGlslProfileFromPath(const String& path) - { - struct Entry - { - const char* ext; - Profile::RawVal profileId; - }; - - static const Entry entries[] = - { - { ".frag", Profile::GLSL_Fragment }, - { ".geom", Profile::GLSL_Geometry }, - { ".tesc", Profile::GLSL_TessControl }, - { ".tese", Profile::GLSL_TessEval }, - { ".comp", Profile::GLSL_Compute } - }; - - for (int i = 0; i < SLANG_COUNT_OF(entries); ++i) - { - const Entry& entry = entries[i]; - if (path.endsWith(entry.ext)) - { - return entry.profileId; - } - } - return Profile::Unknown; - } - - static SlangSourceLanguage findSourceLanguageFromPath(const String& path, Stage& outImpliedStage) - { - struct Entry - { - const char* ext; - SlangSourceLanguage sourceLanguage; - SlangStage impliedStage; - }; - - static const Entry entries[] = - { - { ".slang", SLANG_SOURCE_LANGUAGE_SLANG, SLANG_STAGE_NONE }, - - { ".hlsl", SLANG_SOURCE_LANGUAGE_HLSL, SLANG_STAGE_NONE }, - { ".fx", SLANG_SOURCE_LANGUAGE_HLSL, SLANG_STAGE_NONE }, - - { ".glsl", SLANG_SOURCE_LANGUAGE_GLSL, SLANG_STAGE_NONE }, - { ".vert", SLANG_SOURCE_LANGUAGE_GLSL, SLANG_STAGE_VERTEX }, - { ".frag", SLANG_SOURCE_LANGUAGE_GLSL, SLANG_STAGE_FRAGMENT }, - { ".geom", SLANG_SOURCE_LANGUAGE_GLSL, SLANG_STAGE_GEOMETRY }, - { ".tesc", SLANG_SOURCE_LANGUAGE_GLSL, SLANG_STAGE_HULL }, - { ".tese", SLANG_SOURCE_LANGUAGE_GLSL, SLANG_STAGE_DOMAIN }, - { ".comp", SLANG_SOURCE_LANGUAGE_GLSL, SLANG_STAGE_COMPUTE }, - }; - - for (int i = 0; i < SLANG_COUNT_OF(entries); ++i) - { - const Entry& entry = entries[i]; - if (path.endsWith(entry.ext)) - { - outImpliedStage = Stage(entry.impliedStage); - return entry.sourceLanguage; - } - } - return SLANG_SOURCE_LANGUAGE_UNKNOWN; - } - - SlangResult addInputPath( - char const* inPath) - { - inputPathCount++; - - // look at the extension on the file name to determine - // how we should handle it. - String path = String(inPath); - - if( path.endsWith(".slang") ) - { - // Plain old slang code - addInputSlangPath(path); - return SLANG_OK; - } - - Stage impliedStage = Stage::Unknown; - SlangSourceLanguage sourceLanguage = findSourceLanguageFromPath(path, impliedStage); - - if (sourceLanguage == SLANG_SOURCE_LANGUAGE_UNKNOWN) - { - requestImpl->getSink()->diagnose(SourceLoc(), Diagnostics::cannotDeduceSourceLanguage, inPath); - return SLANG_FAIL; - } - - addInputForeignShaderPath(path, sourceLanguage, impliedStage); - - return SLANG_OK; - } - - void addOutputPath( - String const& path, - CodeGenTarget impliedFormat) - { - RawOutput rawOutput; - rawOutput.path = path; - rawOutput.impliedFormat = impliedFormat; - rawOutputs.add(rawOutput); - } - - void addOutputPath(char const* inPath) - { - String path = String(inPath); - - if (!inPath) {} -#define CASE(EXT, TARGET) \ - else if(path.endsWith(EXT)) do { addOutputPath(path, CodeGenTarget(SLANG_##TARGET)); } while(0) - - CASE(".hlsl", HLSL); - CASE(".fx", HLSL); - - CASE(".dxbc", DXBC); - CASE(".dxbc.asm", DXBC_ASM); - - CASE(".dxil", DXIL); - CASE(".dxil.asm", DXIL_ASM); - - CASE(".glsl", GLSL); - CASE(".vert", GLSL); - CASE(".frag", GLSL); - CASE(".geom", GLSL); - CASE(".tesc", GLSL); - CASE(".tese", GLSL); - CASE(".comp", GLSL); - - CASE(".spv", SPIRV); - CASE(".spv.asm", SPIRV_ASM); - - CASE(".c", C_SOURCE); - CASE(".cpp", CPP_SOURCE); - -#undef CASE - - else if (path.endsWith(".slang-module")) - { - spSetOutputContainerFormat(compileRequest, SLANG_CONTAINER_FORMAT_SLANG_MODULE); - requestImpl->containerOutputPath = path; - } - else - { - // Allow an unknown-format `-o`, assuming we get a target format - // from another argument. - addOutputPath(path, CodeGenTarget::Unknown); - } - } - - RawEntryPoint* getCurrentEntryPoint() - { - auto rawEntryPointCount = rawEntryPoints.getCount(); - return rawEntryPointCount ? &rawEntryPoints[rawEntryPointCount-1] : &defaultEntryPoint; - } - - void setStage(RawEntryPoint* rawEntryPoint, Stage stage) - { - if(rawEntryPoint->stage != Stage::Unknown) - { - rawEntryPoint->redundantStageSet = true; - if( stage != rawEntryPoint->stage ) - { - rawEntryPoint->conflictingStagesSet = true; - } - } - rawEntryPoint->stage = stage; - } - - RawTarget* getCurrentTarget() - { - auto rawTargetCount = rawTargets.getCount(); - return rawTargetCount ? &rawTargets[rawTargetCount-1] : &defaultTarget; - } - - void setProfileVersion(RawTarget* rawTarget, ProfileVersion profileVersion) - { - if(rawTarget->profileVersion != ProfileVersion::Unknown) - { - rawTarget->redundantProfileSet = true; - - if(profileVersion != rawTarget->profileVersion) - { - rawTarget->conflictingProfilesSet = true; - } - } - rawTarget->profileVersion = profileVersion; - } - - void setFloatingPointMode(RawTarget* rawTarget, FloatingPointMode mode) - { - rawTarget->floatingPointMode = mode; - } - - SlangResult parse( - int argc, - char const* const* argv) - { - // Copy some state out of the current request, in case we've been called - // after some other initialization has been performed. - flags = requestImpl->getFrontEndReq()->compileFlags; - - DiagnosticSink* sink = requestImpl->getSink(); - - SlangMatrixLayoutMode defaultMatrixLayoutMode = SLANG_MATRIX_LAYOUT_MODE_UNKNOWN; - - char const* const* argCursor = &argv[0]; - char const* const* argEnd = &argv[argc]; - while (argCursor != argEnd) - { - char const* arg = *argCursor++; - if (arg[0] == '-') - { - String argStr = String(arg); - - if(argStr == "-no-mangle" ) - { - flags |= SLANG_COMPILE_FLAG_NO_MANGLING; - } - else if (argStr == "-no-codegen") - { - flags |= SLANG_COMPILE_FLAG_NO_CODEGEN; - } - else if(argStr == "-dump-ir" ) - { - requestImpl->getFrontEndReq()->shouldDumpIR = true; - requestImpl->getBackEndReq()->shouldDumpIR = true; - } - else if (argStr == "-serial-ir") - { - requestImpl->getFrontEndReq()->useSerialIRBottleneck = true; - } - else if (argStr == "-verbose-paths") - { - requestImpl->getSink()->flags |= DiagnosticSink::Flag::VerbosePath; - } - else if (argStr == "-verify-debug-serial-ir") - { - requestImpl->getFrontEndReq()->verifyDebugSerialization = true; - } - else if(argStr == "-validate-ir" ) - { - requestImpl->getFrontEndReq()->shouldValidateIR = true; - requestImpl->getBackEndReq()->shouldValidateIR = true; - } - else if(argStr == "-skip-codegen" ) - { - requestImpl->shouldSkipCodegen = true; - } - else if(argStr == "-parameter-blocks-use-register-spaces" ) - { - getCurrentTarget()->targetFlags |= SLANG_TARGET_FLAG_PARAMETER_BLOCKS_USE_REGISTER_SPACES; - } - else if (argStr == "-target") - { - String name; - SLANG_RETURN_ON_FAIL(tryReadCommandLineArgument(sink, arg, &argCursor, argEnd, name)); - - SlangCompileTarget format = SLANG_TARGET_UNKNOWN; - - #define CASE(NAME, TARGET) \ - if(name == NAME) { format = SLANG_##TARGET; } else - - CASE("hlsl", HLSL) - CASE("glsl", GLSL) - CASE("dxbc", DXBC) - CASE("dxbc-assembly", DXBC_ASM) - CASE("dxbc-asm", DXBC_ASM) - CASE("spirv", SPIRV) - CASE("spirv-assembly", SPIRV_ASM) - CASE("spirv-asm", SPIRV_ASM) - CASE("dxil", DXIL) - CASE("dxil-assembly", DXIL_ASM) - CASE("dxil-asm", DXIL_ASM) - CASE("c", C_SOURCE) - CASE("cpp", CPP_SOURCE) - - #undef CASE - /* else */ - { - sink->diagnose(SourceLoc(), Diagnostics::unknownCodeGenerationTarget, name); - return SLANG_FAIL; - } - - RawTarget rawTarget; - rawTarget.format = CodeGenTarget(format); - - rawTargets.add(rawTarget); - } - // A "profile" can specify both a general capability level for - // a target, and also (as a legacy/compatibility feature) a - // specific stage to use for an entry point. - else if (argStr == "-profile") - { - String name; - SLANG_RETURN_ON_FAIL(tryReadCommandLineArgument(sink, arg, &argCursor, argEnd, name)); - - SlangProfileID profileID = spFindProfile(session, name.begin()); - if( profileID == SLANG_PROFILE_UNKNOWN ) - { - sink->diagnose(SourceLoc(), Diagnostics::unknownProfile, name); - return SLANG_FAIL; - } - else - { - auto profile = Profile(profileID); - - setProfileVersion(getCurrentTarget(), profile.GetVersion()); - - // A `-profile` option that also specifies a stage (e.g., `-profile vs_5_0`) - // should be treated like a composite (e.g., `-profile sm_5_0 -stage vertex`) - auto stage = profile.GetStage(); - if(stage != Stage::Unknown) - { - setStage(getCurrentEntryPoint(), stage); - } - } - } - else if (argStr == "-stage") - { - String name; - SLANG_RETURN_ON_FAIL(tryReadCommandLineArgument(sink, arg, &argCursor, argEnd, name)); - - Stage stage = findStageByName(name); - if( stage == Stage::Unknown ) - { - sink->diagnose(SourceLoc(), Diagnostics::unknownStage, name); - return SLANG_FAIL; - } - else - { - setStage(getCurrentEntryPoint(), stage); - } - } - else if (argStr == "-entry") - { - String name; - SLANG_RETURN_ON_FAIL(tryReadCommandLineArgument(sink, arg, &argCursor, argEnd, name)); - - RawEntryPoint rawEntryPoint; - rawEntryPoint.name = name; - rawEntryPoint.translationUnitIndex = currentTranslationUnitIndex; - - rawEntryPoints.add(rawEntryPoint); - } - else if (argStr == "-pass-through") - { - String name; - SLANG_RETURN_ON_FAIL(tryReadCommandLineArgument(sink, arg, &argCursor, argEnd, name)); - - SlangPassThrough passThrough = SLANG_PASS_THROUGH_NONE; - if (name == "fxc") { passThrough = SLANG_PASS_THROUGH_FXC; } - else if (name == "dxc") { passThrough = SLANG_PASS_THROUGH_DXC; } - else if (name == "glslang") { passThrough = SLANG_PASS_THROUGH_GLSLANG; } - else - { - sink->diagnose(SourceLoc(), Diagnostics::unknownPassThroughTarget, name); - return SLANG_FAIL; - } - - spSetPassThrough( - compileRequest, - passThrough); - } - else if (argStr == "-dxc-path") - { - String name; - SLANG_RETURN_ON_FAIL(tryReadCommandLineArgument(sink, arg, &argCursor, argEnd, name)); - addSharedLibraryPath(SharedLibraryType::Dxc, name); - addSharedLibraryPath(SharedLibraryType::Dxil, name); - } - else if (argStr == "-glslang-path") - { - String name; - SLANG_RETURN_ON_FAIL(tryReadCommandLineArgument(sink, arg, &argCursor, argEnd, name)); - addSharedLibraryPath(SharedLibraryType::Glslang, name); - } - else if (argStr == "-fxc-path") - { - String name; - SLANG_RETURN_ON_FAIL(tryReadCommandLineArgument(sink, arg, &argCursor, argEnd, name)); - addSharedLibraryPath(SharedLibraryType::Fxc, name); - } - else if (argStr[1] == 'D') - { - // The value to be defined might be part of the same option, as in: - // -DFOO - // or it might come separately, as in: - // -D FOO - char const* defineStr = arg + 2; - if (defineStr[0] == 0) - { - // Need to read another argument from the command line - SLANG_RETURN_ON_FAIL(tryReadCommandLineArgumentRaw(sink, arg, &argCursor, argEnd, &defineStr)); - } - // The string that sets up the define can have an `=` between - // the name to be defined and its value, so we search for one. - char const* eqPos = nullptr; - for(char const* dd = defineStr; *dd; ++dd) - { - if (*dd == '=') - { - eqPos = dd; - break; - } - } - - // Now set the preprocessor define - // - if (eqPos) - { - // If we found an `=`, we split the string... - - spAddPreprocessorDefine( - compileRequest, - String(defineStr, eqPos).begin(), - String(eqPos+1).begin()); - } - else - { - // If there was no `=`, then just #define it to an empty string - - spAddPreprocessorDefine( - compileRequest, - String(defineStr).begin(), - ""); - } - } - else if (argStr[1] == 'I') - { - // The value to be defined might be part of the same option, as in: - // -IFOO - // or it might come separately, as in: - // -I FOO - // (see handling of `-D` above) - char const* includeDirStr = arg + 2; - if (includeDirStr[0] == 0) - { - // Need to read another argument from the command line - SLANG_RETURN_ON_FAIL(tryReadCommandLineArgumentRaw(sink, arg, &argCursor, argEnd, &includeDirStr)); - } - - spAddSearchPath( - compileRequest, - String(includeDirStr).begin()); - } - // - // A `-o` option is used to specify a desired output file. - else if (argStr == "-o") - { - char const* outputPath = nullptr; - SLANG_RETURN_ON_FAIL(tryReadCommandLineArgumentRaw(sink, arg, &argCursor, argEnd, &outputPath)); - if (!outputPath) continue; - - addOutputPath(outputPath); - } - else if(argStr == "-matrix-layout-row-major") - { - defaultMatrixLayoutMode = kMatrixLayoutMode_RowMajor; - } - else if(argStr == "-matrix-layout-column-major") - { - defaultMatrixLayoutMode = kMatrixLayoutMode_ColumnMajor; - } - else if(argStr == "-line-directive-mode") - { - String name; - SLANG_RETURN_ON_FAIL(tryReadCommandLineArgument(sink, arg, &argCursor, argEnd, name)); - - SlangLineDirectiveMode mode = SLANG_LINE_DIRECTIVE_MODE_DEFAULT; - if(name == "none") - { - mode = SLANG_LINE_DIRECTIVE_MODE_NONE; - } - else - { - sink->diagnose(SourceLoc(), Diagnostics::unknownLineDirectiveMode, name); - return SLANG_FAIL; - } - - spSetLineDirectiveMode(compileRequest, mode); - - } - else if( argStr == "-fp-mode" || argStr == "-floating-point-mode" ) - { - String name; - SLANG_RETURN_ON_FAIL(tryReadCommandLineArgument(sink, arg, &argCursor, argEnd, name)); - - FloatingPointMode mode = FloatingPointMode::Default; - if(name == "fast") - { - mode = FloatingPointMode::Fast; - } - else if(name == "precise") - { - mode = FloatingPointMode::Precise; - } - else - { - sink->diagnose(SourceLoc(), Diagnostics::unknownFloatingPointMode, name); - return SLANG_FAIL; - } - - setFloatingPointMode(getCurrentTarget(), mode); - } - else if( argStr[1] == 'O' ) - { - char const* name = arg + 2; - SlangOptimizationLevel level = SLANG_OPTIMIZATION_LEVEL_DEFAULT; - - bool invalidOptimizationLevel = strlen(name) > 2; - switch( name[0] ) - { - case '0': level = SLANG_OPTIMIZATION_LEVEL_NONE; break; - case '1': level = SLANG_OPTIMIZATION_LEVEL_DEFAULT; break; - case '2': level = SLANG_OPTIMIZATION_LEVEL_HIGH; break; - case '3': level = SLANG_OPTIMIZATION_LEVEL_MAXIMAL; break; - case 0 : level = SLANG_OPTIMIZATION_LEVEL_DEFAULT; break; - default: - invalidOptimizationLevel = true; - break; - } - if( invalidOptimizationLevel ) - { - sink->diagnose(SourceLoc(), Diagnostics::unknownOptimiziationLevel, name); - return SLANG_FAIL; - } - - spSetOptimizationLevel(compileRequest, level); - } - - // Note: unlike with `-O` above, we have to consider that other - // options might have names that start with `-g` and so cannot - // just detect it as a prefix. - else if( argStr == "-g" || argStr == "-g2" ) - { - spSetDebugInfoLevel(compileRequest, SLANG_DEBUG_INFO_LEVEL_STANDARD); - } - else if( argStr == "-g0" ) - { - spSetDebugInfoLevel(compileRequest, SLANG_DEBUG_INFO_LEVEL_NONE); - } - else if( argStr == "-g1" ) - { - spSetDebugInfoLevel(compileRequest, SLANG_DEBUG_INFO_LEVEL_MINIMAL); - } - else if( argStr == "-g3" ) - { - spSetDebugInfoLevel(compileRequest, SLANG_DEBUG_INFO_LEVEL_MAXIMAL); - } - else if( argStr == "-default-image-format-unknown" ) - { - requestImpl->getBackEndReq()->useUnknownImageFormatAsDefault = true; - } - else if (argStr == "--") - { - // The `--` option causes us to stop trying to parse options, - // and treat the rest of the command line as input file names: - while (argCursor != argEnd) - { - SLANG_RETURN_ON_FAIL(addInputPath(*argCursor++)); - } - break; - } - else - { - sink->diagnose(SourceLoc(), Diagnostics::unknownCommandLineOption, argStr); - // TODO: print a usage message - return SLANG_FAIL; - } - } - else - { - SLANG_RETURN_ON_FAIL(addInputPath(arg)); - } - } - - spSetCompileFlags(compileRequest, flags); - - // As a compatability feature, if the user didn't list any explicit entry - // point names, *and* they are compiling a single translation unit, *and* they - // have either specified a stage, or we can assume one from the naming - // of the translation unit, then we assume they wanted to compile a single - // entry point named `main`. - // - if(rawEntryPoints.getCount() == 0 - && rawTranslationUnits.getCount() == 1 - && (defaultEntryPoint.stage != Stage::Unknown - || rawTranslationUnits[0].impliedStage != Stage::Unknown)) - { - RawEntryPoint entry; - entry.name = "main"; - entry.translationUnitIndex = 0; - rawEntryPoints.add(entry); - } - - // If the user (manually or implicitly) specified only a single entry point, - // then we allow the associated stage to be specified either before or after - // the entry point. This means that if there is a stage attached - // to the "default" entry point, we should copy it over to the - // explicit one. - // - if( rawEntryPoints.getCount() == 1 ) - { - if(defaultEntryPoint.stage != Stage::Unknown) - { - setStage(getCurrentEntryPoint(), defaultEntryPoint.stage); - } - - if(defaultEntryPoint.redundantStageSet) - getCurrentEntryPoint()->redundantStageSet = true; - if(defaultEntryPoint.conflictingStagesSet) - getCurrentEntryPoint()->conflictingStagesSet = true; - } - else - { - // If the "default" entry point has had a stage (or - // other state, if we add other per-entry-point state) - // specified, but there is more than one entry point, - // then that state doesn't apply to anything and we - // should issue an error to tell the user something - // funky is going on. - // - if( defaultEntryPoint.stage != Stage::Unknown ) - { - if( rawEntryPoints.getCount() == 0 ) - { - sink->diagnose(SourceLoc(), Diagnostics::stageSpecificationIgnoredBecauseNoEntryPoints); - } - else - { - sink->diagnose(SourceLoc(), Diagnostics::stageSpecificationIgnoredBecauseBeforeAllEntryPoints); - } - } - } - - // Slang requires that every explicit entry point indicate the translation - // unit it comes from. If there is only one translation unit specified, - // then implicitly all entry points come from it. - // - if(translationUnitCount == 1) - { - for( auto& entryPoint : rawEntryPoints ) - { - entryPoint.translationUnitIndex = 0; - } - } - else - { - // Otherwise, we require that all entry points be specified after - // the translation unit to which tye belong. - bool anyEntryPointWithoutTranslationUnit = false; - for( auto& entryPoint : rawEntryPoints ) - { - // Skip entry points that are already associated with a translation unit... - if( entryPoint.translationUnitIndex != -1 ) - continue; - - anyEntryPointWithoutTranslationUnit = true; - } - if( anyEntryPointWithoutTranslationUnit ) - { - sink->diagnose(SourceLoc(), Diagnostics::entryPointsNeedToBeAssociatedWithTranslationUnits); - return SLANG_FAIL; - } - } - - // Now that entry points are associated with translation units, - // we can make one additional pass where if an entry point has - // no specified stage, but the nameing of its translation unit - // implies a stage, we will use that (a manual `-stage` annotation - // will always win out in such a case). - // - for( auto& rawEntryPoint : rawEntryPoints ) - { - // Skip entry points that already have a stage. - if(rawEntryPoint.stage != Stage::Unknown) - continue; - - // Sanity check: don't process entry points with no associated translation unit. - if( rawEntryPoint.translationUnitIndex == -1 ) - continue; - - auto impliedStage = rawTranslationUnits[rawEntryPoint.translationUnitIndex].impliedStage; - if(impliedStage != Stage::Unknown) - rawEntryPoint.stage = impliedStage; - } - - // Note: it is possible that some entry points still won't have associated - // stages at this point, but we don't want to error out here, because - // those entry points might get stages later, as part of semantic checking, - // if the corresponding function has a `[shader("...")]` attribute. - - // Now that we've tried to establish stages for entry points, we can - // issue diagnostics for cases where stages were set redundantly or - // in conflicting ways. - // - for( auto& rawEntryPoint : rawEntryPoints ) - { - if( rawEntryPoint.conflictingStagesSet ) - { - sink->diagnose(SourceLoc(), Diagnostics::conflictingStagesForEntryPoint, rawEntryPoint.name); - } - else if( rawEntryPoint.redundantStageSet ) - { - sink->diagnose(SourceLoc(), Diagnostics::sameStageSpecifiedMoreThanOnce, rawEntryPoint.stage, rawEntryPoint.name); - } - else if( rawEntryPoint.translationUnitIndex != -1 ) - { - // As a quality-of-life feature, if the file name implies a particular - // stage, but the user manually specified something different for - // their entry point, give a warning in case they made a mistake. - - auto& rawTranslationUnit = rawTranslationUnits[rawEntryPoint.translationUnitIndex]; - if( rawTranslationUnit.impliedStage != Stage::Unknown - && rawEntryPoint.stage != Stage::Unknown - && rawTranslationUnit.impliedStage != rawEntryPoint.stage ) - { - sink->diagnose(SourceLoc(), Diagnostics::explicitStageDoesntMatchImpliedStage, rawEntryPoint.name, rawEntryPoint.stage, rawTranslationUnit.impliedStage); - } - } - } - - // If the user is requesting code generation via pass-through, - // then any entry points they specify need to have a stage set, - // because fxc/dxc/glslang don't have a facility for taking - // a named entry point and pulling its stage from an attribute. - // - if( requestImpl->passThrough != PassThroughMode::None ) - { - for( auto& rawEntryPoint : rawEntryPoints ) - { - if( rawEntryPoint.stage == Stage::Unknown ) - { - sink->diagnose(SourceLoc(), Diagnostics::noStageSpecifiedInPassThroughMode, rawEntryPoint.name); - } - } - } - - // We now have inferred enough information to add the - // entry points to our compile request. - // - for( auto& rawEntryPoint : rawEntryPoints ) - { - if(rawEntryPoint.translationUnitIndex < 0) - continue; - - auto translationUnitID = rawTranslationUnits[rawEntryPoint.translationUnitIndex].translationUnitID; - - int entryPointID = spAddEntryPoint( - compileRequest, - translationUnitID, - rawEntryPoint.name.begin(), - SlangStage(rawEntryPoint.stage)); - - rawEntryPoint.entryPointID = entryPointID; - } - - // We are going to build a mapping from target formats to the - // target that handles that format. - Dictionary mapFormatToTargetIndex; - - // If there was no explicit `-target` specified, then we will look - // at the `-o` options to see what we can infer. - // - if(rawTargets.getCount() == 0) - { - for(auto& rawOutput : rawOutputs) - { - // Some outputs don't imply a target format, and we shouldn't use those for inference. - auto impliedFormat = rawOutput.impliedFormat; - if( impliedFormat == CodeGenTarget::Unknown ) - continue; - - int targetIndex = 0; - if( !mapFormatToTargetIndex.TryGetValue(impliedFormat, targetIndex) ) - { - targetIndex = (int) rawTargets.getCount(); - - RawTarget rawTarget; - rawTarget.format = impliedFormat; - rawTargets.add(rawTarget); - - mapFormatToTargetIndex[impliedFormat] = targetIndex; - } - - rawOutput.targetIndex = targetIndex; - } - } - else - { - // If there were explicit targets, then we will use those, but still - // build up our mapping. We should object if the same target format - // is specified more than once (just because of the ambiguities - // it will create). - // - int targetCount = (int) rawTargets.getCount(); - for(int targetIndex = 0; targetIndex < targetCount; ++targetIndex) - { - auto format = rawTargets[targetIndex].format; - - if( mapFormatToTargetIndex.ContainsKey(format) ) - { - sink->diagnose(SourceLoc(), Diagnostics::duplicateTargets, format); - } - else - { - mapFormatToTargetIndex[format] = targetIndex; - } - } - } - - // If we weren't able to infer any targets from output paths (perhaps - // because there were no output paths), but there was a profile specified, - // then we can try to infer a target from the profile. - // - if( rawTargets.getCount() == 0 - && defaultTarget.profileVersion != ProfileVersion::Unknown - && !defaultTarget.conflictingProfilesSet) - { - // Let's see if the chosen profile allows us to infer - // the code gen target format that the user probably meant. - // - CodeGenTarget inferredFormat = CodeGenTarget::Unknown; - auto profileVersion = defaultTarget.profileVersion; - switch( Profile(profileVersion).getFamily() ) - { - default: - break; - - // For GLSL profile versions, we will assume SPIR-V - // is the output format the user intended. - case ProfileFamily::GLSL: - inferredFormat = CodeGenTarget::SPIRV; - break; - - // For DX profile versions, we will assume that the - // user wants DXIL for Shader Model 6.0 and up, - // and DXBC for all earlier versions. - // - // Note: There is overlap where both DXBC and DXIL - // nominally support SM 5.1, but in general we - // expect users to prefer to make a clean break - // at SM 6.0. Anybody who cares about the overlap - // cases should manually specify `-target dxil`. - // - case ProfileFamily::DX: - if( profileVersion >= ProfileVersion::DX_6_0 ) - { - inferredFormat = CodeGenTarget::DXIL; - } - else - { - inferredFormat = CodeGenTarget::DXBytecode; - } - break; - } - - if( inferredFormat != CodeGenTarget::Unknown ) - { - RawTarget rawTarget; - rawTarget.format = inferredFormat; - rawTargets.add(rawTarget); - } - } - - // Similar to the case for entry points, if there is a single target, - // then we allow some of its options to come from the "default" - // target state. - if(rawTargets.getCount() == 1) - { - if(defaultTarget.profileVersion != ProfileVersion::Unknown) - { - setProfileVersion(getCurrentTarget(), defaultTarget.profileVersion); - } - - getCurrentTarget()->targetFlags |= defaultTarget.targetFlags; - - if( defaultTarget.floatingPointMode != FloatingPointMode::Default ) - { - setFloatingPointMode(getCurrentTarget(), defaultTarget.floatingPointMode); - } - } - else - { - // If the "default" target has had a profile (or other state) - // specified, but there is != 1 taget, then that state doesn't - // apply to anythign and we should give the user an error. - // - if( defaultTarget.profileVersion != ProfileVersion::Unknown ) - { - if( rawTargets.getCount() == 0 ) - { - // This should only happen if there were multiple `-profile` options, - // so we didn't try to infer a target, or if the `-profile` option - // somehow didn't imply a target. - // - sink->diagnose(SourceLoc(), Diagnostics::profileSpecificationIgnoredBecauseNoTargets); - } - else - { - sink->diagnose(SourceLoc(), Diagnostics::profileSpecificationIgnoredBecauseBeforeAllTargets); - } - } - - if( defaultTarget.targetFlags ) - { - if( rawTargets.getCount() == 0 ) - { - sink->diagnose(SourceLoc(), Diagnostics::targetFlagsIgnoredBecauseNoTargets); - } - else - { - sink->diagnose(SourceLoc(), Diagnostics::targetFlagsIgnoredBecauseBeforeAllTargets); - } - } - - if( defaultTarget.floatingPointMode != FloatingPointMode::Default ) - { - if( rawTargets.getCount() == 0 ) - { - sink->diagnose(SourceLoc(), Diagnostics::targetFlagsIgnoredBecauseNoTargets); - } - else - { - sink->diagnose(SourceLoc(), Diagnostics::targetFlagsIgnoredBecauseBeforeAllTargets); - } - } - - } - - for(auto& rawTarget : rawTargets) - { - if( rawTarget.conflictingProfilesSet ) - { - sink->diagnose(SourceLoc(), Diagnostics::conflictingProfilesSpecifiedForTarget, rawTarget.format); - } - else if( rawTarget.redundantProfileSet ) - { - sink->diagnose(SourceLoc(), Diagnostics::sameProfileSpecifiedMoreThanOnce, rawTarget.profileVersion, rawTarget.format); - } - } - - // TODO: do we need to require that a target must have a profile specified, - // or will we continue to allow the profile to be inferred from the target? - - // We now have enough information to go ahead and declare the targets - // through the Slang API: - // - for(auto& rawTarget : rawTargets) - { - int targetID = spAddCodeGenTarget(compileRequest, SlangCompileTarget(rawTarget.format)); - rawTarget.targetID = targetID; - - if( rawTarget.profileVersion != ProfileVersion::Unknown ) - { - spSetTargetProfile(compileRequest, targetID, Profile(rawTarget.profileVersion).raw); - } - - if( rawTarget.targetFlags ) - { - spSetTargetFlags(compileRequest, targetID, rawTarget.targetFlags); - } - - if( rawTarget.floatingPointMode != FloatingPointMode::Default ) - { - spSetTargetFloatingPointMode(compileRequest, targetID, SlangFloatingPointMode(rawTarget.floatingPointMode)); - } - } - - if(defaultMatrixLayoutMode != SLANG_MATRIX_LAYOUT_MODE_UNKNOWN) - { - spSetMatrixLayoutMode(compileRequest, defaultMatrixLayoutMode); - } - - // Next we need to sort out the output files specified with `-o`, and - // figure out which entry point and/or target they apply to. - // - // If there is only a single entry point, then that is automatically - // the entry point that should be associated with all outputs. - // - if( rawEntryPoints.getCount() == 1 ) - { - for( auto& rawOutput : rawOutputs ) - { - rawOutput.entryPointIndex = 0; - } - } - // - // Similarly, if there is only one target, then all outputs must - // implicitly appertain to that target. - // - if( rawTargets.getCount() == 1 ) - { - for( auto& rawOutput : rawOutputs ) - { - rawOutput.targetIndex = 0; - } - } - - // Consider the output files specified via `-o` and try to figure - // out how to deal with them. - // - for(auto& rawOutput : rawOutputs) - { - // For now, all output formats need to be tightly bound to - // both a target and an entry point (down the road we will - // need to support output formats that can store multiple - // entry points in one file). - - // If an output doesn't have a target associated with - // it, then search for the target with the matching format. - if( rawOutput.targetIndex == -1 ) - { - auto impliedFormat = rawOutput.impliedFormat; - int targetIndex = -1; - - if(impliedFormat == CodeGenTarget::Unknown) - { - // If we hit this case, then it means that we need to pick the - // target to assocaite with this output based on its implied - // format, but the file path doesn't direclty imply a format - // (it doesn't have a suffix like `.spv` that tells us what to write). - // - sink->diagnose(SourceLoc(), Diagnostics::cannotDeduceOutputFormatFromPath, rawOutput.path); - } - else if( mapFormatToTargetIndex.TryGetValue(rawOutput.impliedFormat, targetIndex) ) - { - rawOutput.targetIndex = targetIndex; - } - else - { - sink->diagnose(SourceLoc(), Diagnostics::cannotMatchOutputFileToTarget, rawOutput.path, rawOutput.impliedFormat); - } - } - - // We won't do any searching to match an output file - // with an entry point, since the case of a single entry - // point was handled above, and the user is expected to - // follow the ordering rules when using multiple entry points. - // - if( rawOutput.entryPointIndex == -1 ) - { - sink->diagnose(SourceLoc(), Diagnostics::cannotMatchOutputFileToEntryPoint, rawOutput.path); - } - } - - // Now that we've diagnosed the output paths, we can add them - // to the compile request at the appropriate locations. - // - // We will consider the output files specified via `-o` and try to figure - // out how to deal with them. - // - for(auto& rawOutput : rawOutputs) - { - if(rawOutput.targetIndex == -1) continue; - if(rawOutput.entryPointIndex == -1) continue; - - auto targetID = rawTargets[rawOutput.targetIndex].targetID; - Int entryPointID = rawEntryPoints[rawOutput.entryPointIndex].entryPointID; - - auto target = requestImpl->getLinkage()->targets[targetID]; - auto entryPointReq = requestImpl->getFrontEndReq()->getEntryPointReqs()[entryPointID]; - - RefPtr targetInfo; - if( !requestImpl->targetInfos.TryGetValue(target, targetInfo) ) - { - targetInfo = new EndToEndCompileRequest::TargetInfo(); - requestImpl->targetInfos[target] = targetInfo; - } - - String outputPath; - if( targetInfo->entryPointOutputPaths.ContainsKey(entryPointID) ) - { - sink->diagnose(SourceLoc(), Diagnostics::duplicateOutputPathsForEntryPointAndTarget, entryPointReq->getName(), target->getTarget()); - } - else - { - targetInfo->entryPointOutputPaths[entryPointID] = rawOutput.path; - } - } - - if (sharedLibraryLoader) - { - spSessionSetSharedLibraryLoader(session, sharedLibraryLoader); - } - - return (sink->GetErrorCount() == 0) ? SLANG_OK : SLANG_FAIL; - } -}; - - -SlangResult parseOptions( - SlangCompileRequest* compileRequestIn, - int argc, - char const* const* argv) -{ - Slang::EndToEndCompileRequest* compileRequest = (Slang::EndToEndCompileRequest*) compileRequestIn; - - OptionsParser parser; - parser.compileRequest = compileRequestIn; - parser.requestImpl = compileRequest; - parser.session = (SlangSession*)compileRequest->getSession(); - - Result res = parser.parse(argc, argv); - - DiagnosticSink* sink = compileRequest->getSink(); - if (sink->GetErrorCount() > 0) - { - // Put the errors in the diagnostic - compileRequest->mDiagnosticOutput = sink->outputBuffer.ProduceString(); - } - - return res; -} - - -} // namespace Slang - -SLANG_API SlangResult spProcessCommandLineArguments( - SlangCompileRequest* request, - char const* const* args, - int argCount) -{ - return Slang::parseOptions(request, argCount, args); -} diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp deleted file mode 100644 index bdb76a005..000000000 --- a/source/slang/parameter-binding.cpp +++ /dev/null @@ -1,2583 +0,0 @@ -// parameter-binding.cpp -#include "parameter-binding.h" - -#include "lookup.h" -#include "compiler.h" -#include "type-layout.h" - -#include "../../slang.h" - -namespace Slang { - -struct ParameterInfo; - -// Information on ranges of registers already claimed/used -struct UsedRange -{ - // What parameter has claimed this range? - VarLayout* parameter; - - // Begin/end of the range (half-open interval) - UInt begin; - UInt end; -}; -bool operator<(UsedRange left, UsedRange right) -{ - if (left.begin != right.begin) - return left.begin < right.begin; - if (left.end != right.end) - return left.end < right.end; - return false; -} - -static bool rangesOverlap(UsedRange const& x, UsedRange const& y) -{ - SLANG_ASSERT(x.begin <= x.end); - SLANG_ASSERT(y.begin <= y.end); - - // If they don't overlap, then one must be earlier than the other, - // and that one must therefore *end* before the other *begins* - - if (x.end <= y.begin) return false; - if (y.end <= x.begin) return false; - - // Otherwise they must overlap - return true; -} - - -struct UsedRanges -{ - // The `ranges` array maintains a sorted list of `UsedRange` - // objects such that the `end` of a range is <= the `begin` - // of any range that comes after it. - // - // The values covered by each `[begin,end)` range are marked - // as used, and anything not in such an interval is implicitly - // free. - // - // TODO: if it ever starts to matter for performance, we - // could encode this information as a tree instead of an array. - // - List ranges; - - // Add a range to the set, either by extending - // existing range(s), or by adding a new one. - // - // If we find that the new range overlaps with - // an existing range for a *different* parameter - // then we return that parameter so that the - // caller can issue an error. - // - VarLayout* Add(UsedRange range) - { - // The invariant on entry to this - // function is that the `ranges` array - // is sorted and no two entries in the - // array intersect. We must preserve - // that property as a postcondition. - // - // The other postcondition is that the - // interval covered by the input `range` - // must be marked as consumed. - - // We will try track any parameter associated - // with an overlapping range that doesn't - // match the parameter on `range`, so that - // the compiler can issue useful diagnostics. - // - VarLayout* newParam = range.parameter; - VarLayout* existingParam = nullptr; - - // A clever algorithm might use a binary - // search to identify the first entry in `ranges` - // that might overlap `range`, but we are going - // to settle for being less clever for now, in - // the hopes that we can at least be correct. - // - // Note: we are going to iterate over `ranges` - // using indices, because we may actually modify - // the array as we go. - // - Int rangeCount = ranges.getCount(); - for(Int rr = 0; rr < rangeCount; ++rr) - { - auto existingRange = ranges[rr]; - - // The invariant on entry to each loop - // iteration will be that `range` does - // *not* intersect any preceding entry - // in the array. - // - // Note that this invariant might be - // true only because we modified - // `range` along the way. - // - // If `range` does not intertsect `existingRange` - // then our invariant will be trivially - // true for the next iteration. - // - if(!rangesOverlap(existingRange, range)) - { - continue; - } - - // We now know that `range` and `existingRange` - // intersect. The first thing to do - // is to check if we have a parameter - // associated with `existingRange`, so - // that we can use it for emitting diagnostics - // about the overlap: - // - if( existingRange.parameter - && existingRange.parameter != newParam) - { - // There was an overlap with a range that - // had a parameter specified, so we will - // use that parameter in any subsequent - // diagnostics. - // - existingParam = existingRange.parameter; - } - - // Before we can move on in our iteration, - // we need to re-establish our invariant by modifying - // `range` so that it doesn't overlap with `existingRange`. - // Of course we also want to end up with a correct - // result for the overall operation, so we can't just - // throw away intervals. - // - // We first note that if `range` starts before `existingRange`, - // then the interval from `range.begin` to `existingRange.begin` - // needs to be accounted for in the final result. Furthermore, - // the interval `[range.begin, existingRange.begin)` could not - // intersect with any range already in the `ranges` array, - // because it comes strictly before `existingRange`, and our - // invariant says there is no intersection with preceding ranges. - // - if(range.begin < existingRange.begin) - { - UsedRange prefix; - prefix.begin = range.begin; - prefix.end = existingRange.begin; - prefix.parameter = range.parameter; - ranges.add(prefix); - } - // - // Now we know that the interval `[range.begin, existingRange.begin)` - // is claimed, if it exists, and clearly the interval - // `[existingRange.begin, existingRange.end)` is already claimed, - // so the only interval left to consider would be - // `[existingRange.end, range.end)`, if it is non-empty. - // That range might intersect with others in the array, so - // we will need to continue iterating to deal with that - // possibility. - // - range.begin = existingRange.end; - - // If the range would be empty, then of course we have nothing - // left to do. - // - if(range.begin >= range.end) - break; - - // Otherwise, have can be sure that `range` now comes - // strictly *after* `existingRange`, and thus our invariant - // is preserved. - } - - // If we manage to exit the loop, then we have resolved - // an intersection with existing entries - possibly by - // adding some new entries. - // - // If the `range` we are left with is still non-empty, - // then we should go ahead and add it. - // - if(range.begin < range.end) - { - ranges.add(range); - } - - // Any ranges that got added along the way might not - // be in the proper sorted order, so we'll need to - // sort the array to restore our global invariant. - // - ranges.sort(); - - // We end by returning an overlapping parameter that - // we found along the way, if any. - // - return existingParam; - } - - VarLayout* Add(VarLayout* param, UInt begin, UInt end) - { - UsedRange range; - range.parameter = param; - range.begin = begin; - range.end = end; - return Add(range); - } - - VarLayout* Add(VarLayout* param, UInt begin, LayoutSize end) - { - UsedRange range; - range.parameter = param; - range.begin = begin; - range.end = end.isFinite() ? end.getFiniteValue() : UInt(-1); - return Add(range); - } - - bool contains(UInt index) - { - for (auto rr : ranges) - { - if (index < rr.begin) - return false; - - if (index >= rr.end) - continue; - - return true; - } - - return false; - } - - - // Try to find space for `count` entries - UInt Allocate(VarLayout* param, UInt count) - { - UInt begin = 0; - - UInt rangeCount = ranges.getCount(); - for (UInt rr = 0; rr < rangeCount; ++rr) - { - // try to fit in before this range... - - UInt end = ranges[rr].begin; - - // If there is enough space... - if (end >= begin + count) - { - // ... then claim it and be done - Add(param, begin, begin + count); - return begin; - } - - // ... otherwise, we need to look at the - // space between this range and the next - begin = ranges[rr].end; - } - - // We've run out of ranges to check, so we - // can safely go after the last one! - Add(param, begin, begin + count); - return begin; - } -}; - -struct ParameterBindingInfo -{ - size_t space = 0; - size_t index = 0; - LayoutSize count; -}; - -struct ParameterBindingAndKindInfo : ParameterBindingInfo -{ - LayoutResourceKind kind = LayoutResourceKind::None; -}; - -enum -{ - kLayoutResourceKindCount = SLANG_PARAMETER_CATEGORY_COUNT, -}; - -struct UsedRangeSet : RefObject -{ - // Information on what ranges of "registers" have already - // been claimed, for each resource type - UsedRanges usedResourceRanges[kLayoutResourceKindCount]; -}; - -// Information on a single parameter -struct ParameterInfo : RefObject -{ - // Layout info for the concrete variables that will make up this parameter - List> varLayouts; - - ParameterBindingInfo bindingInfo[kLayoutResourceKindCount]; - - // The translation unit this parameter is specific to, if any - TranslationUnitRequest* translationUnit = nullptr; - - ParameterInfo() - { - // Make sure we aren't claiming any resources yet - for( int ii = 0; ii < kLayoutResourceKindCount; ++ii ) - { - bindingInfo[ii].count = 0; - } - } -}; - -struct EntryPointParameterBindingContext -{ - // What ranges of resources bindings are already claimed for this translation unit - UsedRangeSet usedRangeSet; -}; - -// State that is shared during parameter binding, -// across all translation units -struct SharedParameterBindingContext -{ - SharedParameterBindingContext( - LayoutRulesFamilyImpl* defaultLayoutRules, - ProgramLayout* programLayout, - TargetRequest* targetReq, - DiagnosticSink* sink) - : defaultLayoutRules(defaultLayoutRules) - , programLayout(programLayout) - , targetRequest(targetReq) - , m_sink(sink) - { - } - - DiagnosticSink* m_sink = nullptr; - - // The program that we are laying out -// Program* program = nullptr; - - // The target request that is triggering layout - // - // TODO: We should eventually strip this down to - // just the subset of fields on the target that - // can influence layout decisions. - TargetRequest* targetRequest = nullptr; - - LayoutRulesFamilyImpl* defaultLayoutRules; - - // All shader parameters we've discovered so far, and started to lay out... - List> parameters; - - // The program layout we are trying to construct - RefPtr programLayout; - - // What ranges of resources bindings are already claimed at the global scope? - // We store one of these for each declared binding space/set. - // - Dictionary> globalSpaceUsedRangeSets; - - // Which register spaces have been claimed so far? - UsedRanges usedSpaces; - - // The space to use for auto-generated bindings. - UInt defaultSpace = 0; - - TargetRequest* getTargetRequest() { return targetRequest; } - DiagnosticSink* getSink() { return m_sink; } -}; - -static DiagnosticSink* getSink(SharedParameterBindingContext* shared) -{ - return shared->getSink(); -} - -// State that might be specific to a single translation unit -// or event to an entry point. -struct ParameterBindingContext -{ - // All the shared state needs to be available - SharedParameterBindingContext* shared; - - // The type layout context to use when computing - // the resource usage of shader parameters. - TypeLayoutContext layoutContext; - - // What stage (if any) are we compiling for? - Stage stage; - - // The entry point that is being processed right now. - EntryPointLayout* entryPointLayout = nullptr; - - TargetRequest* getTargetRequest() { return shared->getTargetRequest(); } - LayoutRulesFamilyImpl* getRulesFamily() { return layoutContext.getRulesFamily(); } -}; - -static DiagnosticSink* getSink(ParameterBindingContext* context) -{ - return getSink(context->shared); -} - - -struct LayoutSemanticInfo -{ - LayoutResourceKind kind; // the register kind - UInt space; - UInt index; - - // TODO: need to deal with component-granularity binding... -}; - -static bool isDigit(char c) -{ - return (c >= '0') && (c <= '9'); -} - -/// Given a string that specifies a name and index (e.g., `COLOR0`), -/// split it into slices for the name part and the index part. -static void splitNameAndIndex( - UnownedStringSlice const& text, - UnownedStringSlice& outName, - UnownedStringSlice& outDigits) -{ - char const* nameBegin = text.begin(); - char const* digitsEnd = text.end(); - - char const* nameEnd = digitsEnd; - while( nameEnd != nameBegin && isDigit(*(nameEnd - 1)) ) - { - nameEnd--; - } - char const* digitsBegin = nameEnd; - - outName = UnownedStringSlice(nameBegin, nameEnd); - outDigits = UnownedStringSlice(digitsBegin, digitsEnd); -} - -LayoutResourceKind findRegisterClassFromName(UnownedStringSlice const& registerClassName) -{ - switch( registerClassName.size() ) - { - case 1: - switch (*registerClassName.begin()) - { - case 'b': return LayoutResourceKind::ConstantBuffer; - case 't': return LayoutResourceKind::ShaderResource; - case 'u': return LayoutResourceKind::UnorderedAccess; - case 's': return LayoutResourceKind::SamplerState; - - default: - break; - } - break; - - case 5: - if( registerClassName == "space" ) - { - return LayoutResourceKind::RegisterSpace; - } - break; - - default: - break; - } - return LayoutResourceKind::None; -} - -LayoutSemanticInfo ExtractLayoutSemanticInfo( - ParameterBindingContext* context, - HLSLLayoutSemantic* semantic) -{ - LayoutSemanticInfo info; - info.space = 0; - info.index = 0; - info.kind = LayoutResourceKind::None; - - UnownedStringSlice registerName = semantic->registerName.Content; - if (registerName.size() == 0) - return info; - - // The register name is expected to be in the form: - // - // identifier-char+ digit+ - // - // where the identifier characters name a "register class" - // and the digits identify a register index within that class. - // - // We are going to split the string the user gave us - // into these constituent parts: - // - UnownedStringSlice registerClassName; - UnownedStringSlice registerIndexDigits; - splitNameAndIndex(registerName, registerClassName, registerIndexDigits); - - LayoutResourceKind kind = findRegisterClassFromName(registerClassName); - if(kind == LayoutResourceKind::None) - { - getSink(context)->diagnose(semantic->registerName, Diagnostics::unknownRegisterClass, registerClassName); - return info; - } - - // For a `register` semantic, the register index is not optional (unlike - // how it works for varying input/output semantics). - if( registerIndexDigits.size() == 0 ) - { - getSink(context)->diagnose(semantic->registerName, Diagnostics::expectedARegisterIndex, registerClassName); - } - - UInt index = 0; - for(auto c : registerIndexDigits) - { - SLANG_ASSERT(isDigit(c)); - index = index * 10 + (c - '0'); - } - - - UInt space = 0; - if( auto registerSemantic = as(semantic) ) - { - auto const& spaceName = registerSemantic->spaceName.Content; - if(spaceName.size() != 0) - { - UnownedStringSlice spaceSpelling; - UnownedStringSlice spaceDigits; - splitNameAndIndex(spaceName, spaceSpelling, spaceDigits); - - if( kind == LayoutResourceKind::RegisterSpace ) - { - getSink(context)->diagnose(registerSemantic->spaceName, Diagnostics::unexpectedSpecifierAfterSpace, spaceName); - } - else if( spaceSpelling != UnownedTerminatedStringSlice("space") ) - { - getSink(context)->diagnose(registerSemantic->spaceName, Diagnostics::expectedSpace, spaceSpelling); - } - else if( spaceDigits.size() == 0 ) - { - getSink(context)->diagnose(registerSemantic->spaceName, Diagnostics::expectedSpaceIndex); - } - else - { - for(auto c : spaceDigits) - { - SLANG_ASSERT(isDigit(c)); - space = space * 10 + (c - '0'); - } - } - } - } - - // TODO: handle component mask part of things... - if( semantic->componentMask.Content.size() != 0 ) - { - getSink(context)->diagnose(semantic->componentMask, Diagnostics::componentMaskNotSupported); - } - - info.kind = kind; - info.index = (int) index; - info.space = space; - return info; -} - - -// - -// Given a GLSL `layout` modifier, we need to be able to check for -// a particular sub-argument and extract its value if present. -template -static bool findLayoutArg( - RefPtr syntax, - UInt* outVal) -{ - for( auto modifier : syntax->GetModifiersOfType() ) - { - if( modifier ) - { - *outVal = (UInt) strtoull(String(modifier->valToken.Content).getBuffer(), nullptr, 10); - return true; - } - } - return false; -} - -template -static bool findLayoutArg( - DeclRef declRef, - UInt* outVal) -{ - return findLayoutArg(declRef.getDecl(), outVal); -} - - /// Determine how to lay out a global variable that might be a shader parameter. - /// - /// Returns `nullptr` if the declaration does not represent a shader parameter. -RefPtr getTypeLayoutForGlobalShaderParameter( - ParameterBindingContext* context, - VarDeclBase* varDecl, - Type* type) -{ - auto layoutContext = context->layoutContext; - auto rules = layoutContext.getRulesFamily(); - - if( varDecl->HasModifier() && as(type) ) - { - return createTypeLayout( - layoutContext.with(rules->getShaderRecordConstantBufferRules()), - type); - } - - - // We want to check for a constant-buffer type with a `push_constant` layout - // qualifier before we move on to anything else. - if (varDecl->HasModifier() && as(type)) - { - return createTypeLayout( - layoutContext.with(rules->getPushConstantBufferRules()), - type); - } - - // HLSL `static` modifier indicates "thread local" - if(varDecl->HasModifier()) - return nullptr; - - // HLSL `groupshared` modifier indicates "thread-group local" - if(varDecl->HasModifier()) - return nullptr; - - // TODO(tfoley): there may be other cases that we need to handle here - - // An "ordinary" global variable is implicitly a uniform - // shader parameter. - return createTypeLayout( - layoutContext.with(rules->getConstantBufferRules()), - type); -} - -// - -struct EntryPointParameterState -{ - String* optSemanticName = nullptr; - int* ioSemanticIndex = nullptr; - EntryPointParameterDirectionMask directionMask; - int semanticSlotCount; - Stage stage = Stage::Unknown; - bool isSampleRate = false; - SourceLoc loc; -}; - - -static RefPtr processEntryPointVaryingParameter( - ParameterBindingContext* context, - RefPtr type, - EntryPointParameterState const& state, - RefPtr varLayout); - -// Collect a single declaration into our set of parameters -static void collectGlobalGenericParameter( - ParameterBindingContext* context, - RefPtr paramDecl) -{ - RefPtr layout = new GenericParamLayout(); - layout->decl = paramDecl; - layout->index = (int)context->shared->programLayout->globalGenericParams.getCount(); - context->shared->programLayout->globalGenericParams.add(layout); - context->shared->programLayout->globalGenericParamsMap[layout->decl->getName()->text] = layout.Ptr(); -} - -// Collect a single declaration into our set of parameters -static void collectGlobalScopeParameter( - ParameterBindingContext* context, - GlobalShaderParamInfo const& shaderParamInfo, - SubstitutionSet globalGenericSubst) -{ - auto varDeclRef = shaderParamInfo.paramDeclRef; - - // We apply any substitutions for global generic parameters here. - auto type = GetType(varDeclRef)->Substitute(globalGenericSubst).as(); - - // We use a single operation to both check whether the - // variable represents a shader parameter, and to compute - // the layout for that parameter's type. - auto typeLayout = getTypeLayoutForGlobalShaderParameter( - context, - varDeclRef.getDecl(), - type); - - // If we did not find appropriate layout rules, then it - // must mean that this global variable is *not* a shader - // parameter. - if(!typeLayout) - return; - - // Now create a variable layout that we can use - RefPtr varLayout = new VarLayout(); - varLayout->typeLayout = typeLayout; - varLayout->varDecl = varDeclRef; - - // The logic in `check.cpp` that created the `GlobalShaderParamInfo` - // will have identified any cases where there might be multiple - // global variables that logically represent the same shader parameter. - // - // We will track the same basic information during layout using - // the `ParameterInfo` type. - // - // TODO: `ParameterInfo` should probably become `LayoutParamInfo`. - // - ParameterInfo* parameterInfo = new ParameterInfo(); - context->shared->parameters.add(parameterInfo); - - // Add the first variable declaration to the list of declarations for the parameter - parameterInfo->varLayouts.add(varLayout); - - // Add any additional variables to the list of declarations - for( auto additionalVarDeclRef : shaderParamInfo.additionalParamDeclRefs ) - { - // TODO: We should either eliminate the design choice where different - // declarations of the "same" shade parameter get merged across - // translation units (it is effectively just a compatiblity feature), - // or we should clean things up earlier in the chain so that we can - // re-use a single `VarLayout` across all of the different declarations. - // - // TODO: It would also make sense in these cases to ensure that - // such global shader parameters get the same mangled name across - // all translation units, so that they can automatically be collapsed - // during linking. - - RefPtr additionalVarLayout = new VarLayout(); - additionalVarLayout->typeLayout = typeLayout; - additionalVarLayout->varDecl = additionalVarDeclRef; - - parameterInfo->varLayouts.add(additionalVarLayout); - } -} - -static RefPtr findUsedRangeSetForSpace( - ParameterBindingContext* context, - UInt space) -{ - RefPtr usedRangeSet; - if (context->shared->globalSpaceUsedRangeSets.TryGetValue(space, usedRangeSet)) - return usedRangeSet; - - usedRangeSet = new UsedRangeSet(); - context->shared->globalSpaceUsedRangeSets.Add(space, usedRangeSet); - return usedRangeSet; -} - -// Record that a particular register space (or set, in the GLSL case) -// has been used in at least one binding, and so it should not -// be used by auto-generated bindings that need to claim entire -// spaces. -static void markSpaceUsed( - ParameterBindingContext* context, - UInt space) -{ - context->shared->usedSpaces.Add(nullptr, space, space+1); -} - -static UInt allocateUnusedSpaces( - ParameterBindingContext* context, - UInt count) -{ - return context->shared->usedSpaces.Allocate(nullptr, count); -} - -static void addExplicitParameterBinding( - ParameterBindingContext* context, - RefPtr parameterInfo, - VarDeclBase* varDecl, - LayoutSemanticInfo const& semanticInfo, - LayoutSize count, - RefPtr usedRangeSet = nullptr) -{ - auto kind = semanticInfo.kind; - - auto& bindingInfo = parameterInfo->bindingInfo[(int)kind]; - if( bindingInfo.count != 0 ) - { - // We already have a binding here, so we want to - // confirm that it matches the new one that is - // incoming... - if( bindingInfo.count != count - || bindingInfo.index != semanticInfo.index - || bindingInfo.space != semanticInfo.space ) - { - getSink(context)->diagnose(varDecl, Diagnostics::conflictingExplicitBindingsForParameter, getReflectionName(varDecl)); - - auto firstVarDecl = parameterInfo->varLayouts[0]->varDecl.getDecl(); - if( firstVarDecl != varDecl ) - { - getSink(context)->diagnose(firstVarDecl, Diagnostics::seeOtherDeclarationOf, getReflectionName(firstVarDecl)); - } - } - - // TODO(tfoley): `register` semantics can technically be - // profile-specific (not sure if anybody uses that)... - } - else - { - bindingInfo.count = count; - bindingInfo.index = semanticInfo.index; - bindingInfo.space = semanticInfo.space; - - if (!usedRangeSet) - { - usedRangeSet = findUsedRangeSetForSpace(context, semanticInfo.space); - - // Record that the particular binding space was - // used by an explicit binding, so that we don't - // claim it for auto-generated bindings that - // need to grab a full space - markSpaceUsed(context, semanticInfo.space); - } - auto overlappedVarLayout = usedRangeSet->usedResourceRanges[(int)semanticInfo.kind].Add( - parameterInfo->varLayouts[0], - semanticInfo.index, - semanticInfo.index + count); - - if (overlappedVarLayout) - { - auto paramA = parameterInfo->varLayouts[0]->varDecl.getDecl(); - auto paramB = overlappedVarLayout->varDecl.getDecl(); - - getSink(context)->diagnose(paramA, Diagnostics::parameterBindingsOverlap, - getReflectionName(paramA), - getReflectionName(paramB)); - - getSink(context)->diagnose(paramB, Diagnostics::seeDeclarationOf, getReflectionName(paramB)); - } - } -} - -static void addExplicitParameterBindings_HLSL( - ParameterBindingContext* context, - RefPtr parameterInfo, - RefPtr varLayout) -{ - // We only want to apply D3D `register` modifiers when compiling for - // D3D targets. - // - // TODO: Nominally, the `register` keyword allows for a shader - // profile to be specified, so that a given binding only - // applies for a specific profile: - // - // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/dx-graphics-hlsl-variable-register - // - // We might want to consider supporting that syntax in the - // long run, in order to handle bindings for multiple targets - // in a more consistent fashion (whereas using `register` for D3D - // and `[[vk::binding(...)]]` for Vulkan creates a lot of - // visual noise). - // - // For now we do the filtering on target in a very direct fashion: - // - if(!isD3DTarget(context->getTargetRequest())) - return; - - auto typeLayout = varLayout->typeLayout; - auto varDecl = varLayout->varDecl; - - // If the declaration has explicit binding modifiers, then - // here is where we want to extract and apply them... - - // Look for HLSL `register` or `packoffset` semantics. - for (auto semantic : varDecl.getDecl()->GetModifiersOfType()) - { - // Need to extract the information encoded in the semantic - LayoutSemanticInfo semanticInfo = ExtractLayoutSemanticInfo(context, semantic); - auto kind = semanticInfo.kind; - if (kind == LayoutResourceKind::None) - continue; - - // TODO: need to special-case when this is a `c` register binding... - - // Find the appropriate resource-binding information - // inside the type, to see if we even use any resources - // of the given kind. - - auto typeRes = typeLayout->FindResourceInfo(kind); - LayoutSize count = 0; - if (typeRes) - { - count = typeRes->count; - } - else - { - // TODO: warning here! - } - - addExplicitParameterBinding(context, parameterInfo, varDecl, semanticInfo, count); - } -} - -static void maybeDiagnoseMissingVulkanLayoutModifier( - ParameterBindingContext* context, - DeclRef const& varDecl) -{ - // If the user didn't specify a `binding` (and optional `set`) for Vulkan, - // but they *did* specify a `register` for D3D, then that is probably an - // oversight on their part. - if( auto registerModifier = varDecl.getDecl()->FindModifier() ) - { - getSink(context)->diagnose(registerModifier, Diagnostics::registerModifierButNoVulkanLayout, varDecl.GetName()); - } -} - -static void addExplicitParameterBindings_GLSL( - ParameterBindingContext* context, - RefPtr parameterInfo, - RefPtr varLayout) -{ - - // We only want to apply GLSL-style layout modifers - // when compiling for a Khronos-related target. - // - // TODO: This should have some finer granularity - // so that we are able to distinguish between - // Vulkan and OpenGL as targets. - // - if(!isKhronosTarget(context->getTargetRequest())) - return; - - auto typeLayout = varLayout->typeLayout; - auto varDecl = varLayout->varDecl; - - // The catch in GLSL is that the expected resource type - // is implied by the parameter declaration itself, and - // the `layout` modifier is only allowed to adjust - // the index/offset/etc. - // - - // We also may need to store explicit binding info in a different place, - // in the case of varying input/output, since we don't want to collect - // things globally; - RefPtr usedRangeSet; - - TypeLayout::ResourceInfo* resInfo = nullptr; - LayoutSemanticInfo semanticInfo; - semanticInfo.index = 0; - semanticInfo.space = 0; - if( (resInfo = typeLayout->FindResourceInfo(LayoutResourceKind::DescriptorTableSlot)) != nullptr ) - { - // Try to find `binding` and `set` - auto attr = varDecl.getDecl()->FindModifier(); - if (!attr) - { - maybeDiagnoseMissingVulkanLayoutModifier(context, varDecl); - return; - } - semanticInfo.index = attr->binding; - semanticInfo.space = attr->set; - } - else if( (resInfo = typeLayout->FindResourceInfo(LayoutResourceKind::RegisterSpace)) != nullptr ) - { - // Try to find `set` - auto attr = varDecl.getDecl()->FindModifier(); - if (!attr) - { - maybeDiagnoseMissingVulkanLayoutModifier(context, varDecl); - return; - } - if( attr->binding != 0) - { - getSink(context)->diagnose(attr, Diagnostics::wholeSpaceParameterRequiresZeroBinding, varDecl.GetName(), attr->binding); - } - semanticInfo.index = attr->set; - semanticInfo.space = 0; - } - else if( (resInfo = typeLayout->FindResourceInfo(LayoutResourceKind::SpecializationConstant)) != nullptr ) - { - // Try to find `constant_id` binding - if(!findLayoutArg(varDecl, &semanticInfo.index)) - return; - } - - // If we didn't find any matches, then bail - if(!resInfo) - return; - - auto kind = resInfo->kind; - auto count = resInfo->count; - semanticInfo.kind = kind; - - addExplicitParameterBinding(context, parameterInfo, varDecl, semanticInfo, count, usedRangeSet); -} - -// Given a single parameter, collect whatever information we have on -// how it has been explicitly bound, which may come from multiple declarations -void generateParameterBindings( - ParameterBindingContext* context, - RefPtr parameterInfo) -{ - // There must be at least one declaration for the parameter. - SLANG_RELEASE_ASSERT(parameterInfo->varLayouts.getCount() != 0); - - // Iterate over all declarations looking for explicit binding information. - for( auto& varLayout : parameterInfo->varLayouts ) - { - // Handle HLSL `register` and `packoffset` modifiers - addExplicitParameterBindings_HLSL(context, parameterInfo, varLayout); - - - // Handle GLSL `layout` modifiers - addExplicitParameterBindings_GLSL(context, parameterInfo, varLayout); - } -} - -// Generate the binding information for a shader parameter. -static void completeBindingsForParameterImpl( - ParameterBindingContext* context, - RefPtr firstVarLayout, - ParameterBindingInfo bindingInfos[kLayoutResourceKindCount], - RefPtr parameterInfo) -{ - // For any resource kind used by the parameter - // we need to update its layout information - // to include a binding for that resource kind. - // - auto firstTypeLayout = firstVarLayout->typeLayout; - - // We need to deal with allocation of full register spaces first, - // since that is the most complicated bit of logic. - // - // We will compute how many full register spaces the parameter - // needs to allocate, across all the kinds of resources it - // consumes, so that we can allocate a contiguous range of - // spaces. - // - UInt spacesToAllocateCount = 0; - for(auto typeRes : firstTypeLayout->resourceInfos) - { - auto kind = typeRes.kind; - - // We want to ignore resource kinds for which the user - // has specified an explicit binding, since those won't - // go into our contiguously allocated range. - // - auto& bindingInfo = bindingInfos[(int)kind]; - if( bindingInfo.count != 0 ) - { - continue; - } - - // Now we inspect the kind of resource to figure out - // its space requirements: - // - switch( kind ) - { - default: - // An unbounded-size array will need its own space. - // - if( typeRes.count.isInfinite() ) - { - spacesToAllocateCount++; - } - break; - - case LayoutResourceKind::RegisterSpace: - // If the parameter consumes any full spaces (e.g., it - // is a `struct` type with one or more unbounded arrays - // for fields), then we will include those spaces in - // our allocaiton. - // - // We assume/require here that we never end up needing - // an unbounded number of spaces. - // TODO: we should enforce that somewhere with an error. - // - spacesToAllocateCount += typeRes.count.getFiniteValue(); - break; - - case LayoutResourceKind::Uniform: - // We want to ignore uniform data for this calculation, - // since any uniform data in top-level shader parameters - // needs to go into a global constant buffer. - // - break; - - case LayoutResourceKind::GenericResource: - // This is more of a marker case, and shouldn't ever - // need a space allocated to it. - break; - } - } - - // If we compute that the parameter needs some number of full - // spaces allocated to it, then we will go ahead and allocate - // contiguous spaces here. - // - UInt firstAllocatedSpace = 0; - if(spacesToAllocateCount) - { - firstAllocatedSpace = allocateUnusedSpaces(context, spacesToAllocateCount); - } - - // We'll then dole the allocated spaces (if any) out to the resource - // categories that need them. - // - UInt currentAllocatedSpace = firstAllocatedSpace; - - for(auto typeRes : firstTypeLayout->resourceInfos) - { - // Did we already apply some explicit binding information - // for this resource kind? - auto kind = typeRes.kind; - auto& bindingInfo = bindingInfos[(int)kind]; - if( bindingInfo.count != 0 ) - { - // If things have already been bound, our work is done. - // - // TODO: it would be good to handle the case where a - // binding specified a space, but not an offset/index - // for some kind of resource. - // - continue; - } - - auto count = typeRes.count; - - // Certain resource kinds require special handling. - // - // Note: This `switch` statement should have a `case` for - // all of the special cases above that affect the computation of - // `spacesToAllocateCount`. - // - switch( kind ) - { - case LayoutResourceKind::RegisterSpace: - { - // The parameter's type needs to consume some number of whole - // register spaces, and we have already allocated a contiguous - // range of spaces above. - // - // As always, we can't handle the case of a parameter that needs - // an infinite number of spaces. - // - SLANG_ASSERT(count.isFinite()); - bindingInfo.count = count; - - // We will use the spaces we've allocated, and bump - // the variable tracking the "current" space by - // the number of spaces consumed. - // - bindingInfo.index = currentAllocatedSpace; - currentAllocatedSpace += count.getFiniteValue(); - - // TODO: what should we store as the "space" for - // an allocation of register spaces? Either zero - // or `space` makes sense, but it isn't clear - // which is a better choice. - bindingInfo.space = 0; - - continue; - } - - case LayoutResourceKind::GenericResource: - { - // `GenericResource` is somewhat confusingly named, - // but simply indicates that the type of this parameter - // in some way depends on a generic parameter that has - // not been bound to a concrete value, so that asking - // specific questions about its resource usage isn't - // really possible. - // - bindingInfo.space = 0; - bindingInfo.count = 1; - bindingInfo.index = 0; - continue; - } - - case LayoutResourceKind::Uniform: - // TODO: we don't currently handle global-scope uniform parameters. - break; - } - - // At this point, we know the parameter consumes some resource - // (e.g., D3D `t` registers or Vulkan `binding`s), and the user - // didn't specify an explicit binding, so we will have to - // assign one for them. - // - // If we are consuming an infinite amount of the given resource - // (e.g., an unbounded array of `Texure2D` requires an infinite - // number of `t` regisers in D3D), then we will go ahead - // and assign a full space: - // - if( count.isInfinite() ) - { - bindingInfo.count = count; - bindingInfo.index = 0; - bindingInfo.space = currentAllocatedSpace; - currentAllocatedSpace++; - } - else - { - // If we have a finite amount of resources, then - // we will go ahead and allocate from the "default" - // space. - - UInt space = context->shared->defaultSpace; - RefPtr usedRangeSet = findUsedRangeSetForSpace(context, space); - - bindingInfo.count = count; - bindingInfo.index = usedRangeSet->usedResourceRanges[(int)kind].Allocate(firstVarLayout, count.getFiniteValue()); - bindingInfo.space = space; - } - } -} - -static void applyBindingInfoToParameter( - RefPtr varLayout, - ParameterBindingInfo bindingInfos[kLayoutResourceKindCount]) -{ - for(auto k = 0; k < kLayoutResourceKindCount; ++k) - { - auto kind = LayoutResourceKind(k); - auto& bindingInfo = bindingInfos[k]; - - // skip resources we aren't consuming - if(bindingInfo.count == 0) - continue; - - // Add a record to the variable layout - auto varRes = varLayout->AddResourceInfo(kind); - varRes->space = (int) bindingInfo.space; - varRes->index = (int) bindingInfo.index; - } -} - -// Generate the binding information for a shader parameter. -static void completeBindingsForParameter( - ParameterBindingContext* context, - RefPtr parameterInfo) -{ - // We will use the first declaration of the parameter as - // a stand-in for all the declarations, so it is important - // that earlier code has validated that the declarations - // "match". - - SLANG_RELEASE_ASSERT(parameterInfo->varLayouts.getCount() != 0); - auto firstVarLayout = parameterInfo->varLayouts.getFirst(); - - completeBindingsForParameterImpl( - context, - firstVarLayout, - parameterInfo->bindingInfo, - parameterInfo); - - // At this point we should have explicit binding locations chosen for - // all the relevant resource kinds, so we can apply these to the - // declarations: - - for(auto& varLayout : parameterInfo->varLayouts) - { - applyBindingInfoToParameter(varLayout, parameterInfo->bindingInfo); - } -} - -static void completeBindingsForParameter( - ParameterBindingContext* context, - RefPtr varLayout) -{ - ParameterBindingInfo bindingInfos[kLayoutResourceKindCount]; - completeBindingsForParameterImpl( - context, - varLayout, - bindingInfos, - nullptr); - applyBindingInfoToParameter(varLayout, bindingInfos); -} - - /// Allocate binding location for any "pending" data in a shader parameter. - /// - /// When a parameter contains interface-type fields (recursively), we might - /// not have included them in the base layout for the parameter, and instead - /// need to allocate space for them after all other shader parameters have - /// been laid out. - /// - /// This function should be called on the `pendingVarLayout` field of an - /// existing `VarLayout` to ensure that its pending data has been properly - /// assigned storage. It handles the case where the `pendingVarLayout` - /// field is null. - /// -static void _allocateBindingsForPendingData( - ParameterBindingContext* context, - RefPtr pendingVarLayout) -{ - if(!pendingVarLayout) return; - - completeBindingsForParameter(context, pendingVarLayout); -} - -struct SimpleSemanticInfo -{ - String name; - int index; -}; - -SimpleSemanticInfo decomposeSimpleSemantic( - HLSLSimpleSemantic* semantic) -{ - auto composedName = semantic->name.Content; - - // look for a trailing sequence of decimal digits - // at the end of the composed name - UInt length = composedName.size(); - UInt indexLoc = length; - while( indexLoc > 0 ) - { - auto c = composedName[indexLoc-1]; - if( c >= '0' && c <= '9' ) - { - indexLoc--; - continue; - } - else - { - break; - } - } - - SimpleSemanticInfo info; - - // - if( indexLoc == length ) - { - // No index suffix - info.name = composedName; - info.index = 0; - } - else - { - // The name is everything before the digits - String stringComposedName(composedName); - - info.name = stringComposedName.subString(0, indexLoc); - info.index = strtol(stringComposedName.begin() + indexLoc, nullptr, 10); - } - return info; -} - -static RefPtr processSimpleEntryPointParameter( - ParameterBindingContext* context, - RefPtr type, - EntryPointParameterState const& inState, - RefPtr varLayout, - int semanticSlotCount = 1) -{ - EntryPointParameterState state = inState; - state.semanticSlotCount = semanticSlotCount; - - auto optSemanticName = state.optSemanticName; - auto semanticIndex = *state.ioSemanticIndex; - - String semanticName = optSemanticName ? *optSemanticName : ""; - String sn = semanticName.toLower(); - - RefPtr typeLayout; - if (sn.startsWith("sv_") - || sn.startsWith("nv_")) - { - // System-value semantic. - - if (state.directionMask & kEntryPointParameterDirection_Output) - { - // Note: I'm just doing something expedient here and detecting `SV_Target` - // outputs and claiming the appropriate register range right away. - // - // TODO: we should really be building up some representation of all of this, - // once we've gone to the trouble of looking it all up... - if( sn == "sv_target" ) - { - // TODO: construct a `ParameterInfo` we can use here so that - // overlapped layout errors get reported nicely. - - auto usedResourceSet = findUsedRangeSetForSpace(context, 0); - usedResourceSet->usedResourceRanges[int(LayoutResourceKind::UnorderedAccess)].Add(nullptr, semanticIndex, semanticIndex + semanticSlotCount); - - - // We also need to track this as an ordinary varying output from the stage, - // since that is how GLSL will want to see it. - // - typeLayout = getSimpleVaryingParameterTypeLayout( - context->layoutContext, - type, - kEntryPointParameterDirection_Output); - } - } - - if (state.directionMask & kEntryPointParameterDirection_Input) - { - if (sn == "sv_sampleindex") - { - state.isSampleRate = true; - } - } - - if( !typeLayout ) - { - // If we didn't compute a special-case layout for the - // system-value parameter (e.g., because it was an - // `SV_Target` output), then create a default layout - // that consumes no input/output varying slots. - // (since system parameters are distinct from - // user-defined parameters for layout purposes) - // - typeLayout = getSimpleVaryingParameterTypeLayout( - context->layoutContext, - type, - 0); - } - - // Remember the system-value semantic so that we can query it later - if (varLayout) - { - varLayout->systemValueSemantic = semanticName; - varLayout->systemValueSemanticIndex = semanticIndex; - } - - // TODO: add some kind of usage information for system input/output - } - else - { - // In this case we have a user-defined semantic, which means - // an ordinary input and/or output varying parameter. - // - typeLayout = getSimpleVaryingParameterTypeLayout( - context->layoutContext, - type, - state.directionMask); - } - - if (state.isSampleRate - && (state.directionMask & kEntryPointParameterDirection_Input) - && (context->stage == Stage::Fragment)) - { - if (auto entryPointLayout = context->entryPointLayout) - { - entryPointLayout->flags |= EntryPointLayout::Flag::usesAnySampleRateInput; - } - } - - *state.ioSemanticIndex += state.semanticSlotCount; - typeLayout->type = type; - - return typeLayout; -} - -static RefPtr processEntryPointVaryingParameterDecl( - ParameterBindingContext* context, - Decl* decl, - RefPtr type, - EntryPointParameterState const& inState, - RefPtr varLayout) -{ - SimpleSemanticInfo semanticInfo; - int semanticIndex = 0; - - EntryPointParameterState state = inState; - - // If there is no explicit semantic already in effect, *and* we find an explicit - // semantic on the associated declaration, then we'll use it. - if( !state.optSemanticName ) - { - if( auto semantic = decl->FindModifier() ) - { - semanticInfo = decomposeSimpleSemantic(semantic); - semanticIndex = semanticInfo.index; - - state.optSemanticName = &semanticInfo.name; - state.ioSemanticIndex = &semanticIndex; - } - } - - if (decl) - { - if (decl->FindModifier()) - { - state.isSampleRate = true; - } - } - - // Default case: either there was an explicit semantic in effect already, - // *or* we couldn't find an explicit semantic to apply on the given - // declaration, so we will just recursive with whatever we have at - // the moment. - return processEntryPointVaryingParameter(context, type, state, varLayout); -} - -static RefPtr processEntryPointVaryingParameter( - ParameterBindingContext* context, - RefPtr type, - EntryPointParameterState const& state, - RefPtr varLayout) -{ - // Make sure to associate a stage with every - // varying parameter (including sub-fields of - // `struct`-type parameters), since downstream - // code generation will need to look at the - // stage (possibly on individual leaf fields) to - // decide when to emit things like the `flat` - // interpolation modifier. - // - if( varLayout ) - { - varLayout->stage = state.stage; - } - - // The default handling of varying parameters should not apply - // to geometry shader output streams; they have their own special rules. - if( auto gsStreamType = as(type) ) - { - // - - auto elementType = gsStreamType->getElementType(); - - int semanticIndex = 0; - - EntryPointParameterState elementState; - elementState.directionMask = kEntryPointParameterDirection_Output; - elementState.ioSemanticIndex = &semanticIndex; - elementState.isSampleRate = false; - elementState.optSemanticName = nullptr; - elementState.semanticSlotCount = 0; - elementState.stage = state.stage; - elementState.loc = state.loc; - - auto elementTypeLayout = processEntryPointVaryingParameter(context, elementType, elementState, nullptr); - - RefPtr typeLayout = new StreamOutputTypeLayout(); - typeLayout->type = type; - typeLayout->rules = elementTypeLayout->rules; - typeLayout->elementTypeLayout = elementTypeLayout; - - for(auto resInfo : elementTypeLayout->resourceInfos) - typeLayout->addResourceUsage(resInfo); - - return typeLayout; - } - - // Raytracing shaders have a slightly different interpretation of their - // "varying" input/output parameters, since they don't have the same - // idea of previous/next stage as the rasterization shader types. - // - if( state.directionMask & kEntryPointParameterDirection_Output ) - { - // Note: we are silently treating `out` parameters as if they - // were `in out` for this test, under the assumption that - // an `out` parameter represents a write-only payload. - - switch(state.stage) - { - default: - // Not a raytracing shader. - break; - - case Stage::Intersection: - case Stage::RayGeneration: - // Don't expect this case to have any `in out` parameters. - getSink(context)->diagnose(state.loc, Diagnostics::dontExpectOutParametersForStage, getStageName(state.stage)); - break; - - case Stage::AnyHit: - case Stage::ClosestHit: - case Stage::Miss: - // `in out` or `out` parameter is payload - return createTypeLayout(context->layoutContext.with( - context->getRulesFamily()->getRayPayloadParameterRules()), - type); - - case Stage::Callable: - // `in out` or `out` parameter is payload - return createTypeLayout(context->layoutContext.with( - context->getRulesFamily()->getCallablePayloadParameterRules()), - type); - - } - } - else - { - switch(state.stage) - { - default: - // Not a raytracing shader. - break; - - case Stage::Intersection: - case Stage::RayGeneration: - case Stage::Miss: - case Stage::Callable: - // Don't expect this case to have any `in` parameters. - // - // TODO: For a miss or callable shader we could interpret - // an `in` parameter as indicating a payload that the - // programmer doesn't intend to write to. - // - getSink(context)->diagnose(state.loc, Diagnostics::dontExpectInParametersForStage, getStageName(state.stage)); - break; - - case Stage::AnyHit: - case Stage::ClosestHit: - // `in` parameter is hit attributes - return createTypeLayout(context->layoutContext.with( - context->getRulesFamily()->getHitAttributesParameterRules()), - type); - } - } - - // If there is an available semantic name and index, - // then we should apply it to this parameter unconditionally - // (that is, not just if it is a leaf parameter). - auto optSemanticName = state.optSemanticName; - if (optSemanticName && varLayout) - { - // Always store semantics in upper-case for - // reflection information, since they are - // supposed to be case-insensitive and - // upper-case is the dominant convention. - String semanticName = *optSemanticName; - String sn = semanticName.toUpper(); - - auto semanticIndex = *state.ioSemanticIndex; - - varLayout->semanticName = sn; - varLayout->semanticIndex = semanticIndex; - varLayout->flags |= VarLayoutFlag::HasSemantic; - } - - // Scalar and vector types are treated as outputs directly - if(auto basicType = as(type)) - { - return processSimpleEntryPointParameter(context, basicType, state, varLayout); - } - else if(auto vectorType = as(type)) - { - return processSimpleEntryPointParameter(context, vectorType, state, varLayout); - } - // A matrix is processed as if it was an array of rows - else if( auto matrixType = as(type) ) - { - auto rowCount = GetIntVal(matrixType->getRowCount()); - return processSimpleEntryPointParameter(context, matrixType, state, varLayout, (int) rowCount); - } - else if( auto arrayType = as(type) ) - { - // Note: Bad Things will happen if we have an array input - // without a semantic already being enforced. - - auto elementCount = (UInt) GetIntVal(arrayType->ArrayLength); - - // We use the first element to derive the layout for the element type - auto elementTypeLayout = processEntryPointVaryingParameter(context, arrayType->baseType, state, varLayout); - - // We still walk over subsequent elements to make sure they consume resources - // as needed - for( UInt ii = 1; ii < elementCount; ++ii ) - { - processEntryPointVaryingParameter(context, arrayType->baseType, state, nullptr); - } - - RefPtr arrayTypeLayout = new ArrayTypeLayout(); - arrayTypeLayout->elementTypeLayout = elementTypeLayout; - arrayTypeLayout->type = arrayType; - - for (auto rr : elementTypeLayout->resourceInfos) - { - arrayTypeLayout->findOrAddResourceInfo(rr.kind)->count = rr.count * elementCount; - } - - return arrayTypeLayout; - } - // Ignore a bunch of types that don't make sense here... - else if (auto textureType = as(type)) { return nullptr; } - else if(auto samplerStateType = as(type)) { return nullptr; } - else if(auto constantBufferType = as(type)) { return nullptr; } - // Catch declaration-reference types late in the sequence, since - // otherwise they will include all of the above cases... - else if( auto declRefType = as(type) ) - { - auto declRef = declRefType->declRef; - - if (auto structDeclRef = declRef.as()) - { - RefPtr structLayout = new StructTypeLayout(); - structLayout->type = type; - - // Need to recursively walk the fields of the structure now... - for( auto field : GetFields(structDeclRef) ) - { - RefPtr fieldVarLayout = new VarLayout(); - fieldVarLayout->varDecl = field; - - auto fieldTypeLayout = processEntryPointVaryingParameterDecl( - context, - field.getDecl(), - GetType(field), - state, - fieldVarLayout); - - if(fieldTypeLayout) - { - fieldVarLayout->typeLayout = fieldTypeLayout; - - for (auto rr : fieldTypeLayout->resourceInfos) - { - SLANG_RELEASE_ASSERT(rr.count != 0); - - auto structRes = structLayout->findOrAddResourceInfo(rr.kind); - fieldVarLayout->findOrAddResourceInfo(rr.kind)->index = structRes->count.getFiniteValue(); - structRes->count += rr.count; - } - } - - structLayout->fields.add(fieldVarLayout); - structLayout->mapVarToLayout.Add(field.getDecl(), fieldVarLayout); - } - - return structLayout; - } - else if (auto globalGenericParam = declRef.as()) - { - auto genParamTypeLayout = new GenericParamTypeLayout(); - // we should have already populated ProgramLayout::genericEntryPointParams list at this point, - // so we can find the index of this generic param decl in the list - genParamTypeLayout->type = type; - genParamTypeLayout->paramIndex = findGenericParam(context->shared->programLayout->globalGenericParams, globalGenericParam.getDecl()); - genParamTypeLayout->findOrAddResourceInfo(LayoutResourceKind::GenericResource)->count += 1; - return genParamTypeLayout; - } - else if (auto associatedTypeParam = declRef.as()) - { - RefPtr assocTypeLayout = new TypeLayout(); - assocTypeLayout->type = type; - return assocTypeLayout; - } - else - { - SLANG_UNEXPECTED("unhandled type kind"); - } - } - // If we ran into an error in checking the user's code, then skip this parameter - else if( auto errorType = as(type) ) - { - return nullptr; - } - - SLANG_UNEXPECTED("unhandled type kind"); - UNREACHABLE_RETURN(nullptr); -} - - /// Compute the type layout for a parameter declared directly on an entry point. -static RefPtr computeEntryPointParameterTypeLayout( - ParameterBindingContext* context, - SubstitutionSet typeSubst, - DeclRef paramDeclRef, - RefPtr paramVarLayout, - EntryPointParameterState& state) -{ - auto paramDeclRefType = GetType(paramDeclRef); - SLANG_ASSERT(paramDeclRefType); - - auto paramType = paramDeclRefType->Substitute(typeSubst).as(); - - if( paramDeclRef.getDecl()->HasModifier() ) - { - // An entry-point parameter that is explicitly marked `uniform` represents - // a uniform shader parameter passed via the implicitly-defined - // constant buffer (e.g., the `$Params` constant buffer seen in fxc/dxc output). - // - return createTypeLayout( - context->layoutContext.with(context->getRulesFamily()->getConstantBufferRules()), - paramType); - } - else - { - // The default case is a varying shader parameter, which could be used for - // input, output, or both. - // - // The varying case needs to not only compute a layout, but also assocaite - // "semantic" strings/indices with the varying parameters by recursively - // walking their structure. - - state.directionMask = 0; - - // If it appears to be an input, process it as such. - if( paramDeclRef.getDecl()->HasModifier() - || paramDeclRef.getDecl()->HasModifier() - || !paramDeclRef.getDecl()->HasModifier() ) - { - state.directionMask |= kEntryPointParameterDirection_Input; - } - - // If it appears to be an output, process it as such. - if(paramDeclRef.getDecl()->HasModifier() - || paramDeclRef.getDecl()->HasModifier()) - { - state.directionMask |= kEntryPointParameterDirection_Output; - } - - return processEntryPointVaryingParameterDecl( - context, - paramDeclRef.getDecl(), - paramType, - state, - paramVarLayout); - } -} - -// There are multiple places where we need to compute the layout -// for a "scope" such as the global scope or an entry point. -// The `ScopeLayoutBuilder` encapsulates the logic around: -// -// * Doing layout for the ordinary/uniform fields, which involves -// using the `struct` layout rules for constant buffers on -// the target. -// -// * Creating a final type/var layout that reflects whether the -// scope needs a constant buffer to be allocated to it. -// -struct ScopeLayoutBuilder -{ - ParameterBindingContext* m_context = nullptr; - LayoutRulesImpl* m_rules = nullptr; - RefPtr m_structLayout; - UniformLayoutInfo m_structLayoutInfo; - - // We need to compute a layout for any "pending" data inside - // of the parameters being added to the scope, to facilitate - // later allocating space for all the pending parameters after - // the primary shader parameters. - // - StructTypeLayoutBuilder m_pendingDataTypeLayoutBuilder; - - void beginLayout( - ParameterBindingContext* context) - { - m_context = context; - m_rules = context->getRulesFamily()->getConstantBufferRules(); - m_structLayout = new StructTypeLayout(); - m_structLayout->rules = m_rules; - - m_structLayoutInfo = m_rules->BeginStructLayout(); - } - - void _addParameter( - RefPtr firstVarLayout, - ParameterInfo* parameterInfo) - { - // Does the parameter have any uniform data? - auto layoutInfo = firstVarLayout->typeLayout->FindResourceInfo(LayoutResourceKind::Uniform); - LayoutSize uniformSize = layoutInfo ? layoutInfo->count : 0; - if( uniformSize != 0 ) - { - // Make sure uniform fields get laid out properly... - - UniformLayoutInfo fieldInfo( - uniformSize, - firstVarLayout->typeLayout->uniformAlignment); - - LayoutSize uniformOffset = m_rules->AddStructField( - &m_structLayoutInfo, - fieldInfo); - - if( parameterInfo ) - { - for( auto& varLayout : parameterInfo->varLayouts ) - { - varLayout->findOrAddResourceInfo(LayoutResourceKind::Uniform)->index = uniformOffset.getFiniteValue(); - } - } - else - { - firstVarLayout->findOrAddResourceInfo(LayoutResourceKind::Uniform)->index = uniformOffset.getFiniteValue(); - } - } - - m_structLayout->fields.add(firstVarLayout); - - if( parameterInfo ) - { - for( auto& varLayout : parameterInfo->varLayouts ) - { - m_structLayout->mapVarToLayout.Add(varLayout->varDecl.getDecl(), varLayout); - } - } - else - { - m_structLayout->mapVarToLayout.Add(firstVarLayout->varDecl.getDecl(), firstVarLayout); - } - - // Any "pending" items on a field type become "pending" items - // on the overall `struct` type layout. - // - // TODO: This logic ends up duplicated between here and the main - // `struct` layout logic in `type-layout.cpp`. If this gets any - // more complicated we should see if there is a way to share it. - // - if( auto fieldPendingDataTypeLayout = firstVarLayout->typeLayout->pendingDataTypeLayout ) - { - m_pendingDataTypeLayoutBuilder.beginLayoutIfNeeded(nullptr, m_rules); - auto fieldPendingDataVarLayout = m_pendingDataTypeLayoutBuilder.addField(firstVarLayout->varDecl, fieldPendingDataTypeLayout); - - m_structLayout->pendingDataTypeLayout = m_pendingDataTypeLayoutBuilder.getTypeLayout(); - - if( parameterInfo ) - { - for( auto& varLayout : parameterInfo->varLayouts ) - { - varLayout->pendingVarLayout = fieldPendingDataVarLayout; - } - } - else - { - firstVarLayout->pendingVarLayout = fieldPendingDataVarLayout; - } - } - } - - void addParameter( - RefPtr varLayout) - { - _addParameter(varLayout, nullptr); - } - - void addParameter( - ParameterInfo* parameterInfo) - { - SLANG_RELEASE_ASSERT(parameterInfo->varLayouts.getCount() != 0); - auto firstVarLayout = parameterInfo->varLayouts.getFirst(); - - _addParameter(firstVarLayout, parameterInfo); - } - - RefPtr endLayout() - { - // Finish computing the layout for the ordindary data (if any). - // - m_rules->EndStructLayout(&m_structLayoutInfo); - m_pendingDataTypeLayoutBuilder.endLayout(); - - // Copy the final layout information computed for ordinary data - // over to the struct type layout for the scope. - // - m_structLayout->addResourceUsage(LayoutResourceKind::Uniform, m_structLayoutInfo.size); - m_structLayout->uniformAlignment = m_structLayout->uniformAlignment; - - RefPtr scopeTypeLayout = m_structLayout; - - // If a constant buffer is needed (because there is a non-zero - // amount of uniform data), then we need to wrap up the layout - // to reflect the constant buffer that will be generated. - // - scopeTypeLayout = createConstantBufferTypeLayoutIfNeeded( - m_context->layoutContext, - scopeTypeLayout); - - // We now have a bunch of layout information, which we should - // record into a suitable object that represents the scope - RefPtr scopeVarLayout = new VarLayout(); - scopeVarLayout->typeLayout = scopeTypeLayout; - - if( auto pendingTypeLayout = scopeTypeLayout->pendingDataTypeLayout ) - { - RefPtr pendingVarLayout = new VarLayout(); - pendingVarLayout->typeLayout = pendingTypeLayout; - scopeVarLayout->pendingVarLayout = pendingVarLayout; - } - - return scopeVarLayout; - } -}; - - /// Helper routine to allocate a constant buffer binding if one is needed. - /// - /// This function primarily exists to encapsulate the logic for allocating - /// the resources required for a constant buffer in the appropriate - /// target-specific fashion. - /// -static ParameterBindingAndKindInfo maybeAllocateConstantBufferBinding( - ParameterBindingContext* context, - bool needConstantBuffer) -{ - if( !needConstantBuffer ) return ParameterBindingAndKindInfo(); - - UInt space = context->shared->defaultSpace; - auto usedRangeSet = findUsedRangeSetForSpace(context, space); - - auto layoutInfo = context->getRulesFamily()->getConstantBufferRules()->GetObjectLayout( - ShaderParameterKind::ConstantBuffer); - - ParameterBindingAndKindInfo info; - info.kind = layoutInfo.kind; - info.count = layoutInfo.size; - info.index = usedRangeSet->usedResourceRanges[(int)layoutInfo.kind].Allocate(nullptr, layoutInfo.size.getFiniteValue()); - info.space = space; - return info; -} - - /// Iterate over the parameters of an entry point to compute its requirements. - /// -static void collectEntryPointParameters( - ParameterBindingContext* context, - EntryPoint* entryPoint, - SubstitutionSet typeSubst) -{ - DeclRef entryPointFuncDeclRef = entryPoint->getFuncDeclRef(); - - // We will take responsibility for creating and filling in - // the `EntryPointLayout` object here. - // - RefPtr entryPointLayout = new EntryPointLayout(); - entryPointLayout->profile = entryPoint->getProfile(); - entryPointLayout->entryPoint = entryPointFuncDeclRef.getDecl(); - - // The entry point layout must be added to the output - // program layout so that it can be accessed by reflection. - // - context->shared->programLayout->entryPoints.add(entryPointLayout); - - // For the duration of our parameter collection work we will - // establish this entry point as the current one in the context. - // - context->entryPointLayout = entryPointLayout; - - // Note: this isn't really the best place for this logic to sit, - // but it is the simplest place where we have a direct correspondence - // between a single `EntryPoint` and its matching `EntryPointLayout`, - // so we'll use it. - // - for( auto taggedUnionType : entryPoint->getTaggedUnionTypes() ) - { - SLANG_ASSERT(taggedUnionType); - auto substType = taggedUnionType->Substitute(typeSubst).as(); - auto typeLayout = createTypeLayout(context->layoutContext, substType); - entryPointLayout->taggedUnionTypeLayouts.add(typeLayout); - } - - // We are going to iterate over the entry-point parameters, - // and while we do so we will go ahead and perform layout/binding - // assignment for two cases: - // - // First, the varying parameters of the entry point will have - // their semantics and locations assigned, so we set up state - // for tracking that layout. - // - int defaultSemanticIndex = 0; - EntryPointParameterState state; - state.ioSemanticIndex = &defaultSemanticIndex; - state.optSemanticName = nullptr; - state.semanticSlotCount = 0; - state.stage = entryPoint->getStage(); - - // Second, we will compute offsets for any "ordinary" data - // in the parameter list (e.g., a `uniform float4x4 mvp` parameter), - // which is what the `ScopeLayoutBuilder` is designed to help with. - // - ScopeLayoutBuilder scopeBuilder; - scopeBuilder.beginLayout(context); - auto paramsStructLayout = scopeBuilder.m_structLayout; - - for( auto& shaderParamInfo : entryPoint->getShaderParams() ) - { - auto paramDeclRef = shaderParamInfo.paramDeclRef; - - // When computing layout for an entry-point parameter, - // we want to make sure that the layout context has access - // to the existential type arguments (if any) that were - // provided for the entry-point existential type parameters (if any). - // - context->layoutContext= context->layoutContext - .withExistentialTypeArgs( - entryPoint->getExistentialTypeArgCount(), - entryPoint->getExistentialTypeArgs()) - .withExistentialTypeSlotsOffsetBy( - shaderParamInfo.firstExistentialTypeSlot); - - // Any error messages we emit during the process should - // refer to the location of this parameter. - // - state.loc = paramDeclRef.getLoc(); - - // We are going to construct the variable layout for this - // parameter *before* computing the type layout, because - // the type layout computation is also determining the effective - // semantic of the parameter, which needs to be stored - // back onto the `VarLayout`. - // - RefPtr paramVarLayout = new VarLayout(); - paramVarLayout->varDecl = paramDeclRef; - paramVarLayout->stage = state.stage; - - auto paramTypeLayout = computeEntryPointParameterTypeLayout( - context, - typeSubst, - paramDeclRef, - paramVarLayout, - state); - paramVarLayout->typeLayout = paramTypeLayout; - - // We expect to always be able to compute a layout for - // entry-point parameters, but to be defensive we will - // skip parameters that couldn't have a layout computed - // when assertions are disabled. - // - SLANG_ASSERT(paramTypeLayout); - if(!paramTypeLayout) - continue; - - // Now that we've computed the layout to use for the parameter, - // we need to add its resource usage to that of the entry - // point as a whole. - // - // Any "ordinary" data (e.g., a `float4x4`) needs to be accounted - // for using the `ScopeLayoutBuilder`, since it will handle - // the details of target-specific `struct` type layout. - // - scopeBuilder.addParameter(paramVarLayout); - - // All of the other resources types will be handled in a - // simpler loop that just increments the relevant counters. - // - for (auto paramTypeResInfo : paramTypeLayout->resourceInfos) - { - // We need to skip ordinary data because it is being - // handled by the `scopeBuilder`. - // - if(paramTypeResInfo.kind == LayoutResourceKind::Uniform) - continue; - - // Whatever resources the parameter uses, we need to - // assign the parameter's location/register/binding offset to - // be the sum of everything added so far. - // - auto entryPointResInfo = paramsStructLayout->findOrAddResourceInfo(paramTypeResInfo.kind); - paramVarLayout->findOrAddResourceInfo(paramTypeResInfo.kind)->index = entryPointResInfo->count.getFiniteValue(); - - // We then need to add the resources consumed by the parameter - // to those consumed by the entry point. - // - entryPointResInfo->count += paramTypeResInfo.count; - } - } - entryPointLayout->parametersLayout = scopeBuilder.endLayout(); - - // For an entry point with a non-`void` return type, we need to process the - // return type as a varying output parameter. - // - // TODO: Ideally we should make the layout process more robust to empty/void - // types and apply this logic unconditionally. - // - auto resultType = GetResultType(entryPointFuncDeclRef)->Substitute(typeSubst).as(); - SLANG_ASSERT(resultType); - - if( !resultType->Equals(resultType->getSession()->getVoidType()) ) - { - state.loc = entryPointFuncDeclRef.getLoc(); - state.directionMask = kEntryPointParameterDirection_Output; - - RefPtr resultLayout = new VarLayout(); - resultLayout->stage = state.stage; - - auto resultTypeLayout = processEntryPointVaryingParameterDecl( - context, - entryPointFuncDeclRef.getDecl(), - resultType->Substitute(typeSubst).as(), - state, - resultLayout); - - if( resultTypeLayout ) - { - resultLayout->typeLayout = resultTypeLayout; - - for (auto rr : resultTypeLayout->resourceInfos) - { - auto entryPointRes = paramsStructLayout->findOrAddResourceInfo(rr.kind); - resultLayout->findOrAddResourceInfo(rr.kind)->index = entryPointRes->count.getFiniteValue(); - entryPointRes->count += rr.count; - } - } - - entryPointLayout->resultLayout = resultLayout; - } -} - -static void collectParameters( - ParameterBindingContext* inContext, - Program* program) -{ - // All of the parameters in translation units directly - // referenced in the compile request are part of one - // logical namespace/"linkage" so that two parameters - // with the same name should represent the same - // parameter, and get the same binding(s) - - ParameterBindingContext contextData = *inContext; - auto context = &contextData; - context->stage = Stage::Unknown; - - auto globalGenericSubst = program->getGlobalGenericSubstitution(); - - // We will start by looking for any global generic type parameters. - - for(RefPtr module : program->getModuleDependencies()) - { - for( auto genParamDecl : module->getModuleDecl()->getMembersOfType() ) - { - collectGlobalGenericParameter(context, genParamDecl); - } - } - - // Once we have enumerated global generic type parameters, we can - // begin enumerating shader parameters, starting at the global scope. - // - // Because we have already enumerated the global generic type parameters, - // we will be able to look up the index of a global generic type parameter - // when we see it referenced in the type of one of the shader parameters. - - for(auto& globalParamInfo : program->getShaderParams() ) - { - // When computing layout for a global shader parameter, - // we want to make sure that the layout context has access - // to the existential type arguments (if any) that were - // provided for the global existential type parameters (if any). - // - context->layoutContext= context->layoutContext - .withExistentialTypeArgs( - program->getExistentialTypeArgCount(), - program->getExistentialTypeArgs()) - .withExistentialTypeSlotsOffsetBy( - globalParamInfo.firstExistentialTypeSlot); - - collectGlobalScopeParameter(context, globalParamInfo, globalGenericSubst); - } - - // Next consider parameters for entry points - for(auto entryPoint : program->getEntryPoints()) - { - context->stage = entryPoint->getStage(); - collectEntryPointParameters(context, entryPoint, globalGenericSubst); - } - context->entryPointLayout = nullptr; -} - - /// Emit a diagnostic about a uniform parameter at global scope. -void diagnoseGlobalUniform( - SharedParameterBindingContext* sharedContext, - VarDeclBase* varDecl) -{ - // It is entirely possible for Slang to support uniform parameters at the global scope, - // by bundling them into an implicit constant buffer, and indeed the layout algorithm - // implemented in this file computes a layout *as if* the Slang compiler does just that. - // - // The missing link is the downstream IR and code generation steps, where we would need - // to collect all of the global-scope uniforms into a common `struct` type and then - // create a new constant buffer parameter over that type. - // - // For now it is easier to simply ban this case, since most shader authors have - // switched to modern HLSL/GLSL style with `cbuffer` or `uniform` block declarations. - // - // TODO: In the long run it may be best to require *all* global-scope shader parameters - // to be marked with a keyword (e.g., `uniform`) so that ordinary global variable syntax can be - // used safely. - // - getSink(sharedContext)->diagnose(varDecl, Diagnostics::globalUniformsNotSupported, varDecl->getName()); -} - -static int _calcTotalNumUsedRegistersForLayoutResourceKind(ParameterBindingContext* bindingContext, LayoutResourceKind kind) -{ - int numUsed = 0; - for (auto& pair : bindingContext->shared->globalSpaceUsedRangeSets) - { - UsedRangeSet* rangeSet = pair.Value; - const auto& usedRanges = rangeSet->usedResourceRanges[kind]; - for (const auto& usedRange : usedRanges.ranges) - { - numUsed += int(usedRange.end - usedRange.begin); - } - } - return numUsed; -} - -RefPtr generateParameterBindings( - TargetProgram* targetProgram, - DiagnosticSink* sink) -{ - auto program = targetProgram->getProgram(); - auto targetReq = targetProgram->getTargetReq(); - - RefPtr programLayout = new ProgramLayout(); - programLayout->targetProgram = targetProgram; - - // Try to find rules based on the selected code-generation target - auto layoutContext = getInitialLayoutContextForTarget(targetReq, programLayout); - - // If there was no target, or there are no rules for the target, - // then bail out here. - if (!layoutContext.rules) - return nullptr; - - // Create a context to hold shared state during the process - // of generating parameter bindings - SharedParameterBindingContext sharedContext( - layoutContext.getRulesFamily(), - programLayout, - targetReq, - sink); - - // Create a sub-context to collect parameters that get - // declared into the global scope - ParameterBindingContext context; - context.shared = &sharedContext; - context.layoutContext = layoutContext; - - // Walk through AST to discover all the parameters - collectParameters(&context, program); - - // Now walk through the parameters to generate initial binding information - for( auto& parameter : sharedContext.parameters ) - { - generateParameterBindings(&context, parameter); - } - - // Determine if there are any global-scope parameters that use `Uniform` - // resources, and thus need to get packaged into a constant buffer. - // - // Note: this doesn't account for GLSL's support for "legacy" uniforms - // at global scope, which don't get assigned a CB. - bool needDefaultConstantBuffer = false; - for( auto& parameterInfo : sharedContext.parameters ) - { - SLANG_RELEASE_ASSERT(parameterInfo->varLayouts.getCount() != 0); - auto firstVarLayout = parameterInfo->varLayouts.getFirst(); - - // Does the field have any uniform data? - if( firstVarLayout->typeLayout->FindResourceInfo(LayoutResourceKind::Uniform) ) - { - needDefaultConstantBuffer = true; - diagnoseGlobalUniform(&sharedContext, firstVarLayout->varDecl); - } - } - - // Next, we want to determine if there are any global-scope parameters - // that don't just allocate a whole register space to themselves; these - // parameters will need to go into a "default" space, which should always - // be the first space we allocate. - // - // As a starting point, we will definitely need a "default" space if - // we are creating a default constant buffer, since it should get - // a binding in that "default" space. - // - bool needDefaultSpace = needDefaultConstantBuffer; - if (!needDefaultSpace) - { - // Next we will look at the global-scope parameters and see if - // any of them requires a `register` or `binding` that will - // thus need to land in a default space. - // - for (auto& parameterInfo : sharedContext.parameters) - { - SLANG_RELEASE_ASSERT(parameterInfo->varLayouts.getCount() != 0); - auto firstVarLayout = parameterInfo->varLayouts.getFirst(); - - // For each parameter, we will look at each resource it consumes. - // - for (auto resInfo : firstVarLayout->typeLayout->resourceInfos) - { - // We don't care about whole register spaces/sets, since - // we don't need to allocate a default space/set for a parameter - // that itself consumes a whole space/set. - // - if( resInfo.kind == LayoutResourceKind::RegisterSpace ) - continue; - - // We also don't want to consider resource kinds for which - // the variable already has an (explicit) binding, since - // the space from the explicit binding will be used, so - // that a default space isn't needed. - // - if( parameterInfo->bindingInfo[resInfo.kind].count != 0 ) - continue; - - // Otherwise, we have a shader parameter that will need - // a default space or set to live in. - // - needDefaultSpace = true; - break; - } - } - } - - // If we need a space for default bindings, then allocate it here. - if (needDefaultSpace) - { - UInt defaultSpace = 0; - - // Check if space #0 has been allocated yet. If not, then we'll - // want to use it. - if (sharedContext.usedSpaces.contains(0)) - { - // Somebody has already put things in space zero. - // - // TODO: There are two cases to handle here: - // - // 1) If there is any free register ranges in space #0, - // then we should keep using it as the default space. - // - // 2) If somebody went and put an HLSL unsized array into space #0, - // *or* if they manually placed something like a paramter block - // there (which should consume whole spaces), then we need to - // allocate an unused space instead. - // - // For now we don't deal with the concept of unsized arrays, or - // manually assigning parameter blocks to spaces, so we punt - // on this and assume case (1). - - defaultSpace = 0; - } - else - { - // Nobody has used space zero yet, so we need - // to make sure to reserve it for defaults. - defaultSpace = allocateUnusedSpaces(&context, 1); - - // The result of this allocation had better be that - // we got space #0, or else something has gone wrong. - SLANG_ASSERT(defaultSpace == 0); - } - - sharedContext.defaultSpace = defaultSpace; - } - - // If there are any global-scope uniforms, then we need to - // allocate a constant-buffer binding for them here. - // - ParameterBindingAndKindInfo globalConstantBufferBinding = maybeAllocateConstantBufferBinding( - &context, - needDefaultConstantBuffer); - - // Now walk through again to actually give everything - // ranges of registers... - for( auto& parameter : sharedContext.parameters ) - { - completeBindingsForParameter(&context, parameter); - } - - // After we have allocated registers/bindings to everything - // in the global scope we will process the parameters - // of each entry point in order. - // - // Note: the effect of the current implementation is to - // allocate non-overlapping registers/bindings between all - // the entry points in the compile request (e.g., if you - // have a vertex and fragment shader being compiled together, - // we will allocate distinct constant buffer registers for - // their uniform parameters). - // - // TODO: We probably need to provide some more nuanced control - // over whether entry points get overlapping or non-overlapping - // bindings. It seems clear that if we were compiling multiple - // compute kernels in one invocation we'd want them to get - // overlapping bindings, because we cannot ever have them bound - // together in a single pipeline state. - // - // Similarly, entry point parameters of DirectX Raytracing (DXR) - // shaders should probably be allowed to overlap by default, - // since those parameters should really go into the "local root signature." - // (Note: there is a bit more subtlety around ray tracing - // shaders that will be assembled into a "hit group") - // - // For now we are just doing the simplest thing, which will be - // appropriate for: - // - // * Compiling a single compute shader in a compile request. - // * Compiling some number of rasterization shader entry points - // in a single request, to be used together. - // * Compiling a single ray-tracing shader in a compile request. - // - for( auto entryPoint : sharedContext.programLayout->entryPoints ) - { - auto entryPointParamsLayout = entryPoint->parametersLayout; - completeBindingsForParameter(&context, entryPointParamsLayout); - } - - // Next we need to create a type layout to reflect the information - // we have collected, and we will use the `ScopeLayoutBuilder` - // to encapsulate the logic that can be shared with the entry-point - // case. - // - ScopeLayoutBuilder globalScopeLayoutBuilder; - globalScopeLayoutBuilder.beginLayout(&context); - for( auto& parameterInfo : sharedContext.parameters ) - { - globalScopeLayoutBuilder.addParameter(parameterInfo); - } - - auto globalScopeVarLayout = globalScopeLayoutBuilder.endLayout(); - if( globalConstantBufferBinding.count != 0 ) - { - auto cbInfo = globalScopeVarLayout->findOrAddResourceInfo(globalConstantBufferBinding.kind); - cbInfo->space = globalConstantBufferBinding.space; - cbInfo->index = globalConstantBufferBinding.index; - } - - // After we have laid out all the ordinary parameters, - // we need to go through the global scope plus each entry point, - // and "flush" out any pending data that was associated with - // those scopes as part of dealing with interface-type parameters. - // - _allocateBindingsForPendingData(&context, globalScopeVarLayout->pendingVarLayout); - for( auto entryPoint : sharedContext.programLayout->entryPoints ) - { - _allocateBindingsForPendingData(&context, entryPoint->parametersLayout->pendingVarLayout); - } - - - // HACK: we want global parameters to not have to deal with offsetting - // by the `VarLayout` stored in `globalScopeVarLayout`, so we will scan - // through and for any global parameter that used "pending" data, we will manually - // offset all of its resource infos to account for where the global pending data - // got placed. - // - // TODO: A more appropriate solution would be to pass the `globalScopeVarLayout` - // down into the pass that puts layout information onto global parameters in - // the IR, and apply the offsetting there. - // - for( auto& parameterInfo : sharedContext.parameters ) - { - for( auto varLayout : parameterInfo->varLayouts ) - { - auto pendingVarLayout = varLayout->pendingVarLayout; - if(!pendingVarLayout) continue; - - for( auto& resInfo : pendingVarLayout->resourceInfos ) - { - if( auto globalResInfo = globalScopeVarLayout->pendingVarLayout->FindResourceInfo(resInfo.kind) ) - { - resInfo.index += globalResInfo->index; - resInfo.space += globalResInfo->space; - } - } - } - } - - programLayout->parametersLayout = globalScopeVarLayout; - - { - const int numShaderRecordRegs = _calcTotalNumUsedRegistersForLayoutResourceKind(&context, LayoutResourceKind::ShaderRecord); - if (numShaderRecordRegs > 1) - { - sink->diagnose(SourceLoc(), Diagnostics::tooManyShaderRecordConstantBuffers, numShaderRecordRegs); - } - } - - return programLayout; -} - -ProgramLayout* TargetProgram::getOrCreateLayout(DiagnosticSink* sink) -{ - if( !m_layout ) - { - m_layout = generateParameterBindings(this, sink); - } - return m_layout; -} - -void generateParameterBindings( - Program* program, - TargetRequest* targetReq, - DiagnosticSink* sink) -{ - program->getTargetProgram(targetReq)->getOrCreateLayout(sink); -} - -} // namespace Slang diff --git a/source/slang/parameter-binding.h b/source/slang/parameter-binding.h deleted file mode 100644 index 82b114021..000000000 --- a/source/slang/parameter-binding.h +++ /dev/null @@ -1,34 +0,0 @@ -#ifndef SLANG_PARAMETER_BINDING_H -#define SLANG_PARAMETER_BINDING_H - -#include "../core/basic.h" -#include "syntax.h" - -#include "../../slang.h" - -namespace Slang { - -class Program; -class TargetRequest; - -// The parameter-binding interface is responsible for assigning -// binding locations/registers to every parameter of a shader -// program. This can include both parameters declared on a -// particular entry point, as well as parameters declared at -// global scope. -// - - -// Generate binding information for the given program, -// represented as a collection of different translation units, -// and attach that information to the syntax nodes -// of the program. - -void generateParameterBindings( - Program* program, - TargetRequest* targetReq, - DiagnosticSink* sink); - -} - -#endif // SLANG_REFLECTION_H diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp deleted file mode 100644 index 825aea324..000000000 --- a/source/slang/parser.cpp +++ /dev/null @@ -1,4725 +0,0 @@ -#include "parser.h" - -#include - -#include "compiler.h" -#include "lookup.h" -#include "visitor.h" - -namespace Slang -{ - // pre-declare - static Name* getName(Parser* parser, String const& text); - - // Helper class useful to build a list of modifiers. - struct ModifierListBuilder - { - ModifierListBuilder() - { - m_next = &m_result; - } - void add(Modifier* modifier) - { - // Doesn't handle SharedModifiers - SLANG_ASSERT(as(modifier) == nullptr); - - // Splice at end - *m_next = modifier; - m_next = &modifier->next; - } - template - T* find() const - { - Modifier* cur = m_result; - while (cur) - { - T* castCur = as(cur); - if (castCur) - { - return castCur; - } - cur = cur->next; - } - return nullptr; - } - template - bool hasType() const - { - return find() != nullptr; - } - RefPtr getFirst() { return m_result; }; - protected: - - RefPtr m_result; - RefPtr* m_next; - }; - - enum Precedence : int - { - Invalid = -1, - Comma, - Assignment, - TernaryConditional, - LogicalOr, - LogicalAnd, - BitOr, - BitXor, - BitAnd, - EqualityComparison, - RelationalComparison, - BitShift, - Additive, - Multiplicative, - Prefix, - Postfix, - }; - - // TODO: implement two pass parsing for file reference and struct type recognition - - class Parser - { - public: - NamePool* namePool; - SourceLanguage sourceLanguage; - - NamePool* getNamePool() { return namePool; } - SourceLanguage getSourceLanguage() { return sourceLanguage; } - - int anonymousCounter = 0; - - RefPtr outerScope; - RefPtr currentScope; - - TokenReader tokenReader; - DiagnosticSink* sink; - int genericDepth = 0; - - // Have we seen any `import` declarations? If so, we need - // to parse function bodies completely, even if we are in - // "rewrite" mode. - bool haveSeenAnyImportDecls = false; - - // Is the parser in a "recovering" state? - // During recovery we don't emit additional errors, until we find - // a token that we expected, when we exit recovery. - bool isRecovering = false; - - void FillPosition(SyntaxNode * node) - { - node->loc = tokenReader.PeekLoc(); - } - void PushScope(ContainerDecl* containerDecl) - { - RefPtr newScope = new Scope(); - newScope->containerDecl = containerDecl; - newScope->parent = currentScope; - - currentScope = newScope; - } - - void pushScopeAndSetParent(ContainerDecl* containerDecl) - { - containerDecl->ParentDecl = currentScope->containerDecl; - PushScope(containerDecl); - } - - void PopScope() - { - currentScope = currentScope->parent; - } - Parser( - Session* session, - TokenSpan const& _tokens, - DiagnosticSink * sink, - RefPtr const& outerScope) - : tokenReader(_tokens) - , sink(sink) - , outerScope(outerScope) - , m_session(session) - {} - Parser(const Parser & other) = default; - - Session* m_session = nullptr; - Session* getSession() { return m_session; } - - Token ReadToken(); - Token ReadToken(TokenType type); - Token ReadToken(const char * string); - bool LookAheadToken(TokenType type, int offset = 0); - bool LookAheadToken(const char * string, int offset = 0); - void parseSourceFile(ModuleDecl* program); - RefPtr ParseStruct(); - RefPtr ParseClass(); - RefPtr ParseStatement(); - RefPtr parseBlockStatement(); - RefPtr parseVarDeclrStatement(Modifiers modifiers); - RefPtr parseIfStatement(); - RefPtr ParseForStatement(); - RefPtr ParseWhileStatement(); - RefPtr ParseDoWhileStatement(); - RefPtr ParseBreakStatement(); - RefPtr ParseContinueStatement(); - RefPtr ParseReturnStatement(); - RefPtr ParseExpressionStatement(); - RefPtr ParseExpression(Precedence level = Precedence::Comma); - - // Parse an expression that might be used in an initializer or argument context, so we should avoid operator-comma - inline RefPtr ParseInitExpr() { return ParseExpression(Precedence::Assignment); } - inline RefPtr ParseArgExpr() { return ParseExpression(Precedence::Assignment); } - - RefPtr ParseLeafExpression(); - RefPtr ParseParameter(); - RefPtr ParseType(); - TypeExp ParseTypeExp(); - - Parser & operator = (const Parser &) = delete; - }; - - // Forward Declarations - - static void ParseDeclBody( - Parser* parser, - ContainerDecl* containerDecl, - TokenType closingToken); - - static RefPtr parseEnumDecl(Parser* parser); - - // Parse the `{}`-delimeted body of an aggregate type declaration - static void parseAggTypeDeclBody( - Parser* parser, - AggTypeDeclBase* decl); - - static RefPtr ParseOptSemantics( - Parser* parser); - - static void ParseOptSemantics( - Parser* parser, - Decl* decl); - - static RefPtr ParseDecl( - Parser* parser, - ContainerDecl* containerDecl); - - static RefPtr ParseSingleDecl( - Parser* parser, - ContainerDecl* containerDecl); - - // - - static void Unexpected( - Parser* parser) - { - // Don't emit "unexpected token" errors if we are in recovering mode - if (!parser->isRecovering) - { - parser->sink->diagnose(parser->tokenReader.PeekLoc(), Diagnostics::unexpectedToken, - parser->tokenReader.PeekTokenType()); - - // Switch into recovery mode, to suppress additional errors - parser->isRecovering = true; - } - } - - static void Unexpected( - Parser* parser, - char const* expected) - { - // Don't emit "unexpected token" errors if we are in recovering mode - if (!parser->isRecovering) - { - parser->sink->diagnose(parser->tokenReader.PeekLoc(), Diagnostics::unexpectedTokenExpectedTokenName, - parser->tokenReader.PeekTokenType(), - expected); - - // Switch into recovery mode, to suppress additional errors - parser->isRecovering = true; - } - } - - static void Unexpected( - Parser* parser, - TokenType expected) - { - // Don't emit "unexpected token" errors if we are in recovering mode - if (!parser->isRecovering) - { - parser->sink->diagnose(parser->tokenReader.PeekLoc(), Diagnostics::unexpectedTokenExpectedTokenType, - parser->tokenReader.PeekTokenType(), - expected); - - // Switch into recovery mode, to suppress additional errors - parser->isRecovering = true; - } - } - - static TokenType SkipToMatchingToken(TokenReader* reader, TokenType tokenType); - - // Skip a singel balanced token, which is either a single token in - // the common case, or a matched pair of tokens for `()`, `[]`, and `{}` - static TokenType SkipBalancedToken( - TokenReader* reader) - { - TokenType tokenType = reader->AdvanceToken().type; - switch (tokenType) - { - default: - break; - - case TokenType::LParent: tokenType = SkipToMatchingToken(reader, TokenType::RParent); break; - case TokenType::LBrace: tokenType = SkipToMatchingToken(reader, TokenType::RBrace); break; - case TokenType::LBracket: tokenType = SkipToMatchingToken(reader, TokenType::RBracket); break; - } - return tokenType; - } - - // Skip balanced - static TokenType SkipToMatchingToken( - TokenReader* reader, - TokenType tokenType) - { - for (;;) - { - if (reader->IsAtEnd()) return TokenType::EndOfFile; - if (reader->PeekTokenType() == tokenType) - { - reader->AdvanceToken(); - return tokenType; - } - SkipBalancedToken(reader); - } - } - - // Is the given token type one that is used to "close" a - // balanced construct. - static bool IsClosingToken(TokenType tokenType) - { - switch (tokenType) - { - case TokenType::EndOfFile: - case TokenType::RBracket: - case TokenType::RParent: - case TokenType::RBrace: - return true; - - default: - return false; - } - } - - - // Expect an identifier token with the given content, and consume it. - Token Parser::ReadToken(const char* expected) - { - if (tokenReader.PeekTokenType() == TokenType::Identifier - && tokenReader.PeekToken().Content == expected) - { - isRecovering = false; - return tokenReader.AdvanceToken(); - } - - if (!isRecovering) - { - Unexpected(this, expected); - return tokenReader.PeekToken(); - } - else - { - // Try to find a place to recover - for (;;) - { - // The token we expected? - // Then exit recovery mode and pretend like all is well. - if (tokenReader.PeekTokenType() == TokenType::Identifier - && tokenReader.PeekToken().Content == expected) - { - isRecovering = false; - return tokenReader.AdvanceToken(); - } - - - // Don't skip past any "closing" tokens. - if (IsClosingToken(tokenReader.PeekTokenType())) - { - return tokenReader.PeekToken(); - } - - // Skip balanced tokens and try again. - SkipBalancedToken(&tokenReader); - } - } - } - - Token Parser::ReadToken() - { - return tokenReader.AdvanceToken(); - } - - static bool TryRecover( - Parser* parser, - TokenType const* recoverBefore, - int recoverBeforeCount, - TokenType const* recoverAfter, - int recoverAfterCount) - { - if (!parser->isRecovering) - return true; - - // Determine if we are looking for common closing tokens, - // so that we can know whether or we are allowed to skip - // over them. - - bool lookingForEOF = false; - bool lookingForRCurly = false; - bool lookingForRParen = false; - bool lookingForRSquare = false; - - for (int ii = 0; ii < recoverBeforeCount; ++ii) - { - switch (recoverBefore[ii]) - { - default: - break; - - case TokenType::EndOfFile: lookingForEOF = true; break; - case TokenType::RBrace: lookingForRCurly = true; break; - case TokenType::RParent: lookingForRParen = true; break; - case TokenType::RBracket: lookingForRSquare = true; break; - } - } - for (int ii = 0; ii < recoverAfterCount; ++ii) - { - switch (recoverAfter[ii]) - { - default: - break; - - case TokenType::EndOfFile: lookingForEOF = true; break; - case TokenType::RBrace: lookingForRCurly = true; break; - case TokenType::RParent: lookingForRParen = true; break; - case TokenType::RBracket: lookingForRSquare = true; break; - } - } - - TokenReader* tokenReader = &parser->tokenReader; - for (;;) - { - TokenType peek = tokenReader->PeekTokenType(); - - // Is the next token in our recover-before set? - // If so, then we have recovered successfully! - for (int ii = 0; ii < recoverBeforeCount; ++ii) - { - if (peek == recoverBefore[ii]) - { - parser->isRecovering = false; - return true; - } - } - - // If we are looking at a token in our recover-after set, - // then consume it and recover - for (int ii = 0; ii < recoverAfterCount; ++ii) - { - if (peek == recoverAfter[ii]) - { - tokenReader->AdvanceToken(); - parser->isRecovering = false; - return true; - } - } - - // Don't try to skip past end of file - if (peek == TokenType::EndOfFile) - return false; - - switch (peek) - { - // Don't skip past simple "closing" tokens, *unless* - // we are looking for a closing token - case TokenType::RParent: - case TokenType::RBracket: - if (lookingForRParen || lookingForRSquare || lookingForRCurly || lookingForEOF) - { - // We are looking for a closing token, so it is okay to skip these - } - else - return false; - break; - - // Don't skip a `}`, to avoid spurious errors, - // with the exception of when we are looking for EOF - case TokenType::RBrace: - if (lookingForRCurly || lookingForEOF) - { - // We are looking for end-of-file, so it is okay to skip here - } - else - { - return false; - } - } - - // Skip balanced tokens and try again. - TokenType skipped = SkipBalancedToken(tokenReader); - - // If we happened to find a matched pair of tokens, and - // the end of it was a token we were looking for, - // then recover here - for (int ii = 0; ii < recoverAfterCount; ++ii) - { - if (skipped == recoverAfter[ii]) - { - parser->isRecovering = false; - return true; - } - } - } - } - - static bool TryRecoverBefore( - Parser* parser, - TokenType before0) - { - TokenType recoverBefore[] = { before0 }; - return TryRecover(parser, recoverBefore, 1, nullptr, 0); - } - - // Default recovery strategy, to use inside `{}`-delimeted blocks. - static bool TryRecover( - Parser* parser) - { - TokenType recoverBefore[] = { TokenType::RBrace }; - TokenType recoverAfter[] = { TokenType::Semicolon }; - return TryRecover(parser, recoverBefore, 1, recoverAfter, 1); - } - - Token Parser::ReadToken(TokenType expected) - { - if (tokenReader.PeekTokenType() == expected) - { - isRecovering = false; - return tokenReader.AdvanceToken(); - } - - if (!isRecovering) - { - Unexpected(this, expected); - return tokenReader.PeekToken(); - } - else - { - // Try to find a place to recover - if (TryRecoverBefore(this, expected)) - { - isRecovering = false; - return tokenReader.AdvanceToken(); - } - - return tokenReader.PeekToken(); - } - } - - bool Parser::LookAheadToken(const char * string, int offset) - { - TokenReader r = tokenReader; - for (int ii = 0; ii < offset; ++ii) - r.AdvanceToken(); - - return r.PeekTokenType() == TokenType::Identifier - && r.PeekToken().Content == string; -} - - bool Parser::LookAheadToken(TokenType type, int offset) - { - TokenReader r = tokenReader; - for (int ii = 0; ii < offset; ++ii) - r.AdvanceToken(); - - return r.PeekTokenType() == type; - } - - // Consume a token and return true it if matches, otherwise false - bool AdvanceIf(Parser* parser, TokenType tokenType) - { - if (parser->LookAheadToken(tokenType)) - { - parser->ReadToken(); - return true; - } - return false; - } - - // Consume a token and return true it if matches, otherwise false - bool AdvanceIf(Parser* parser, char const* text) - { - if (parser->LookAheadToken(text)) - { - parser->ReadToken(); - return true; - } - return false; - } - - // Consume a token and return true if it matches, otherwise check - // for end-of-file and expect that token (potentially producing - // an error) and return true to maintain forward progress. - // Otherwise return false. - bool AdvanceIfMatch(Parser* parser, TokenType tokenType) - { - // If we've run into a syntax error, but haven't recovered inside - // the block, then try to recover here. - if (parser->isRecovering) - { - TryRecoverBefore(parser, tokenType); - } - if (AdvanceIf(parser, tokenType)) - return true; - if (parser->tokenReader.PeekTokenType() == TokenType::EndOfFile) - { - parser->ReadToken(tokenType); - return true; - } - return false; - } - - RefPtr ParseTypeDef(Parser* parser, void* /*userData*/) - { - RefPtr typeDefDecl = new TypeDefDecl(); - - // TODO(tfoley): parse an actual declarator - auto type = parser->ParseTypeExp(); - - auto nameToken = parser->ReadToken(TokenType::Identifier); - typeDefDecl->loc = nameToken.loc; - - typeDefDecl->nameAndLoc = NameLoc(nameToken); - typeDefDecl->type = type; - - return typeDefDecl; - } - - // Add a modifier to a list of modifiers being built - static void AddModifier(RefPtr** ioModifierLink, RefPtr modifier) - { - RefPtr*& modifierLink = *ioModifierLink; - - // We'd like to add the modifier to the end of the list, - // but we need to be careful, in case there is a "shared" - // section of modifiers for multiple declarations. - // - // TODO: This whole approach is a mess because we are "accidentally quadratic" - // when adding many modifiers. - for(;;) - { - // At end of the chain? Done. - if(!*modifierLink) - break; - - // About to look at shared modifiers? Done. - RefPtr linkMod = *modifierLink; - if(as(linkMod)) - { - break; - } - - // Otherwise: keep traversing the modifier list. - modifierLink = &(*modifierLink)->next; - } - - // Splice the modifier into the linked list - - // We need to deal with the case where the modifier to - // be spliced in might actually be a modifier *list*, - // so that we actually want to splice in at the - // end of the new list... - auto spliceLink = &modifier->next; - while(*spliceLink) - spliceLink = &(*spliceLink)->next; - - // Do the splice. - *spliceLink = *modifierLink; - - *modifierLink = modifier; - modifierLink = &modifier->next; - } - - void addModifier( - RefPtr syntax, - RefPtr modifier) - { - auto modifierLink = &syntax->modifiers.first; - AddModifier(&modifierLink, modifier); - } - - // - // '::'? identifier ('::' identifier)* - static Token parseAttributeName(Parser* parser) - { - const SourceLoc scopedIdSourceLoc = parser->tokenReader.PeekLoc(); - - // Strip initial :: if there is one - const TokenType initialTokenType = parser->tokenReader.PeekTokenType(); - if (initialTokenType == TokenType::Scope) - { - parser->ReadToken(TokenType::Scope); - } - - const Token firstIdentifier = parser->ReadToken(TokenType::Identifier); - if (initialTokenType != TokenType::Scope && parser->tokenReader.PeekTokenType() != TokenType::Scope) - { - return firstIdentifier; - } - - // Build up scoped string - StringBuilder scopedIdentifierBuilder; - if (initialTokenType == TokenType::Scope) - { - scopedIdentifierBuilder.Append('_'); - } - scopedIdentifierBuilder.Append(firstIdentifier.Content); - - while (parser->tokenReader.PeekTokenType() == TokenType::Scope) - { - parser->ReadToken(TokenType::Scope); - scopedIdentifierBuilder.Append('_'); - - const Token nextIdentifier(parser->ReadToken(TokenType::Identifier)); - scopedIdentifierBuilder.Append(nextIdentifier.Content); - } - - // Make a 'token' - SourceManager* sourceManager = parser->sink->sourceManager; - const UnownedStringSlice scopedIdentifier(sourceManager->allocateStringSlice(scopedIdentifierBuilder.getUnownedSlice())); - Token token(TokenType::Identifier, scopedIdentifier, scopedIdSourceLoc); - - // Get the name pool - auto namePool = parser->getNamePool(); - - // Since it's an Identifier have to set the name. - token.ptrValue = namePool->getName(token.Content); - - return token; - } - - // Parse HLSL-style `[name(arg, ...)]` style "attribute" modifiers - static void ParseSquareBracketAttributes(Parser* parser, RefPtr** ioModifierLink) - { - parser->ReadToken(TokenType::LBracket); - - const bool hasDoubleBracket = AdvanceIf(parser, TokenType::LBracket); - - for(;;) - { - // Note: When parsing we just construct an AST node for an - // "unchecked" attribute, and defer all detailed semantic - // checking until later. - // - // An alternative would be to perform lookup of an `AttributeDecl` - // at this point, similar to what we do for `SyntaxDecl`, but it - // seems better to not complicate the parsing process any more. - // - - Token nameToken = parseAttributeName(parser); - - RefPtr modifier = new UncheckedAttribute(); - modifier->name = nameToken.getName(); - modifier->loc = nameToken.getLoc(); - modifier->scope = parser->currentScope; - - if (AdvanceIf(parser, TokenType::LParent)) - { - // HLSL-style `[name(arg0, ...)]` attribute - - while (!AdvanceIfMatch(parser, TokenType::RParent)) - { - auto arg = parser->ParseArgExpr(); - if (arg) - { - modifier->args.add(arg); - } - - if (AdvanceIfMatch(parser, TokenType::RParent)) - break; - - parser->ReadToken(TokenType::Comma); - } - } - AddModifier(ioModifierLink, modifier); - - - if (AdvanceIfMatch(parser, TokenType::RBracket)) - break; - - parser->ReadToken(TokenType::Comma); - } - - if (hasDoubleBracket) - { - // Read the second ] - parser->ReadToken(TokenType::RBracket); - } - } - - static TokenType peekTokenType(Parser* parser) - { - return parser->tokenReader.PeekTokenType(); - } - - static Token advanceToken(Parser* parser) - { - return parser->ReadToken(); - } - - static Token peekToken(Parser* parser) - { - return parser->tokenReader.PeekToken(); - } - - static SyntaxDecl* tryLookUpSyntaxDecl( - Parser* parser, - Name* name) - { - // Let's look up the name and see what we find. - - auto lookupResult = lookUp( - parser->getSession(), - nullptr, // no semantics visitor available yet - name, - parser->currentScope); - - // If we didn't find anything, or the result was overloaded, - // then we aren't going to be able to extract a single decl. - if(!lookupResult.isValid() || lookupResult.isOverloaded()) - return nullptr; - - auto decl = lookupResult.item.declRef.getDecl(); - if( auto syntaxDecl = as(decl) ) - { - return syntaxDecl; - } - else - { - return nullptr; - } - } - - template - bool tryParseUsingSyntaxDecl( - Parser* parser, - SyntaxDecl* syntaxDecl, - RefPtr* outSyntax) - { - if (!syntaxDecl) - return false; - - if (!syntaxDecl->syntaxClass.isSubClassOf()) - return false; - - // Consume the token that specified the keyword - auto keywordToken = advanceToken(parser); - - RefPtr parsedObject = syntaxDecl->parseCallback(parser, syntaxDecl->parseUserData); - if (!parsedObject) - { - return false; - } - - auto syntax = as(parsedObject); - if (syntax) - { - if (!syntax->loc.isValid()) - { - syntax->loc = keywordToken.loc; - } - } - else if (parsedObject) - { - // Something was parsed, but it didn't have the expected type! - SLANG_DIAGNOSE_UNEXPECTED(parser->sink, keywordToken, "parser callback did not return the expected type"); - } - - *outSyntax = syntax; - return true; - } - - template - bool tryParseUsingSyntaxDecl( - Parser* parser, - RefPtr* outSyntax) - { - if (peekTokenType(parser) != TokenType::Identifier) - return false; - - auto nameToken = peekToken(parser); - auto name = nameToken.getName(); - - auto syntaxDecl = tryLookUpSyntaxDecl(parser, name); - - if (!syntaxDecl) - return false; - - return tryParseUsingSyntaxDecl(parser, syntaxDecl, outSyntax); - } - - static Modifiers ParseModifiers(Parser* parser) - { - Modifiers modifiers; - RefPtr* modifierLink = &modifiers.first; - for (;;) - { - SourceLoc loc = parser->tokenReader.PeekLoc(); - - switch (peekTokenType(parser)) - { - default: - // If we don't see a token type that we recognize, then - // assume we are done with the modifier sequence. - return modifiers; - - case TokenType::Identifier: - { - // We see an identifier ahead, and it might be the name - // of a modifier keyword of some kind. - - Token nameToken = peekToken(parser); - - RefPtr parsedModifier; - if (tryParseUsingSyntaxDecl(parser, &parsedModifier)) - { - parsedModifier->name = nameToken.getName(); - if (!parsedModifier->loc.isValid()) - { - parsedModifier->loc = nameToken.loc; - } - - AddModifier(&modifierLink, parsedModifier); - continue; - } - - // If there was no match for a modifier keyword, then we - // must be at the end of the modifier sequence - return modifiers; - } - break; - - // HLSL uses `[attributeName]` style for its modifiers, which closely - // matches the C++ `[[attributeName]]` style. - case TokenType::LBracket: - ParseSquareBracketAttributes(parser, &modifierLink); - break; - } - } - } - - static Name* getName(Parser* parser, String const& text) - { - return parser->getNamePool()->getName(text); - } - - static NameLoc expectIdentifier(Parser* parser) - { - return NameLoc(parser->ReadToken(TokenType::Identifier)); - } - - - static RefPtr parseImportDecl( - Parser* parser, void* /*userData*/) - { - parser->haveSeenAnyImportDecls = true; - - auto decl = new ImportDecl(); - decl->scope = parser->currentScope; - - if (peekTokenType(parser) == TokenType::StringLiteral) - { - auto nameToken = parser->ReadToken(TokenType::StringLiteral); - auto nameString = getStringLiteralTokenValue(nameToken); - auto moduleName = getName(parser, nameString); - - decl->moduleNameAndLoc = NameLoc(moduleName, nameToken.loc); - } - else - { - auto moduleNameAndLoc = expectIdentifier(parser); - - // We allow a dotted format for the name, as sugar - if (peekTokenType(parser) == TokenType::Dot) - { - StringBuilder sb; - sb << getText(moduleNameAndLoc.name); - while (AdvanceIf(parser, TokenType::Dot)) - { - sb << "/"; - sb << parser->ReadToken(TokenType::Identifier).Content; - } - - moduleNameAndLoc.name = getName(parser, sb.ProduceString()); - } - - decl->moduleNameAndLoc = moduleNameAndLoc; - } - - parser->ReadToken(TokenType::Semicolon); - - return decl; - } - - static NameLoc ParseDeclName( - Parser* parser) - { - Token nameToken; - if (AdvanceIf(parser, "operator")) - { - nameToken = parser->ReadToken(); - switch (nameToken.type) - { - case TokenType::OpAdd: case TokenType::OpSub: case TokenType::OpMul: case TokenType::OpDiv: - case TokenType::OpMod: case TokenType::OpNot: case TokenType::OpBitNot: case TokenType::OpLsh: case TokenType::OpRsh: - case TokenType::OpEql: case TokenType::OpNeq: case TokenType::OpGreater: case TokenType::OpLess: case TokenType::OpGeq: - case TokenType::OpLeq: case TokenType::OpAnd: case TokenType::OpOr: case TokenType::OpBitXor: case TokenType::OpBitAnd: - case TokenType::OpBitOr: case TokenType::OpInc: case TokenType::OpDec: - case TokenType::OpAddAssign: - case TokenType::OpSubAssign: - case TokenType::OpMulAssign: - case TokenType::OpDivAssign: - case TokenType::OpModAssign: - case TokenType::OpShlAssign: - case TokenType::OpShrAssign: - case TokenType::OpOrAssign: - case TokenType::OpAndAssign: - case TokenType::OpXorAssign: - - // Note(tfoley): A bit of a hack: - case TokenType::Comma: - case TokenType::OpAssign: - break; - - // Note(tfoley): Even more of a hack! - case TokenType::QuestionMark: - if (AdvanceIf(parser, TokenType::Colon)) - { - // Concat : onto ? - nameToken.Content = UnownedStringSlice::fromLiteral("?:"); - break; - } - ; // fall-thru - default: - parser->sink->diagnose(nameToken.loc, Diagnostics::invalidOperator, nameToken); - break; - } - - return NameLoc( - getName(parser, nameToken.Content), - nameToken.loc); - } - else - { - nameToken = parser->ReadToken(TokenType::Identifier); - return NameLoc(nameToken); - } - } - - // A "declarator" as used in C-style languages - struct Declarator : RefObject - { - // Different cases of declarator appear as "flavors" here - enum class Flavor - { - name, - Pointer, - Array, - }; - Flavor flavor; - }; - - // The most common case of declarator uses a simple name - struct NameDeclarator : Declarator - { - NameLoc nameAndLoc; - }; - - // A declarator that declares a pointer type - struct PointerDeclarator : Declarator - { - // location of the `*` token - SourceLoc starLoc; - - RefPtr inner; - }; - - // A declarator that declares an array type - struct ArrayDeclarator : Declarator - { - RefPtr inner; - - // location of the `[` token - SourceLoc openBracketLoc; - - // The expression that yields the element count, or NULL - RefPtr elementCountExpr; - }; - - // "Unwrapped" information about a declarator - struct DeclaratorInfo - { - RefPtr typeSpec; - NameLoc nameAndLoc; - RefPtr semantics; - RefPtr initializer; - }; - - // Add a member declaration to its container, and ensure that its - // parent link is set up correctly. - static void AddMember(RefPtr container, RefPtr member) - { - if (container) - { - member->ParentDecl = container.Ptr(); - container->Members.add(member); - - container->memberDictionaryIsValid = false; - } - } - - static void AddMember(RefPtr scope, RefPtr member) - { - if (scope) - { - AddMember(scope->containerDecl, member); - } - } - - static RefPtr ParseGenericParamDecl( - Parser* parser, - RefPtr genericDecl) - { - // simple syntax to introduce a value parameter - if (AdvanceIf(parser, "let")) - { - // default case is a type parameter - auto paramDecl = new GenericValueParamDecl(); - paramDecl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); - if (AdvanceIf(parser, TokenType::Colon)) - { - paramDecl->type = parser->ParseTypeExp(); - } - if (AdvanceIf(parser, TokenType::OpAssign)) - { - paramDecl->initExpr = parser->ParseInitExpr(); - } - return paramDecl; - } - else - { - // default case is a type parameter - RefPtr paramDecl = new GenericTypeParamDecl(); - parser->FillPosition(paramDecl); - paramDecl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); - if (AdvanceIf(parser, TokenType::Colon)) - { - // The user is apply a constraint to this type parameter... - - auto paramConstraint = new GenericTypeConstraintDecl(); - parser->FillPosition(paramConstraint); - - auto paramType = DeclRefType::Create( - parser->getSession(), - DeclRef(paramDecl, nullptr)); - - auto paramTypeExpr = new SharedTypeExpr(); - paramTypeExpr->loc = paramDecl->loc; - paramTypeExpr->base.type = paramType; - paramTypeExpr->type = QualType(getTypeType(paramType)); - - paramConstraint->sub = TypeExp(paramTypeExpr); - paramConstraint->sup = parser->ParseTypeExp(); - - AddMember(genericDecl, paramConstraint); - - - } - if (AdvanceIf(parser, TokenType::OpAssign)) - { - paramDecl->initType = parser->ParseTypeExp(); - } - return paramDecl; - } - } - - template - static void ParseGenericDeclImpl( - Parser* parser, GenericDecl* decl, const TFunc & parseInnerFunc) - { - parser->ReadToken(TokenType::OpLess); - parser->genericDepth++; - while (!parser->LookAheadToken(TokenType::OpGreater)) - { - AddMember(decl, ParseGenericParamDecl(parser, decl)); - - if (parser->LookAheadToken(TokenType::OpGreater)) - break; - - parser->ReadToken(TokenType::Comma); - } - parser->genericDepth--; - parser->ReadToken(TokenType::OpGreater); - decl->inner = parseInnerFunc(decl); - decl->inner->ParentDecl = decl; - - // A generic decl hijacks the name of the declaration - // it wraps, so that lookup can find it. - if (decl->inner) - { - decl->nameAndLoc = decl->inner->nameAndLoc; - decl->loc = decl->inner->loc; - } - } - - template - static RefPtr parseOptGenericDecl( - Parser* parser, const ParseFunc& parseInner) - { - // TODO: may want more advanced disambiguation than this... - if (parser->LookAheadToken(TokenType::OpLess)) - { - RefPtr genericDecl = new GenericDecl(); - parser->FillPosition(genericDecl); - parser->PushScope(genericDecl); - ParseGenericDeclImpl(parser, genericDecl, parseInner); - parser->PopScope(); - return genericDecl; - } - else - { - return parseInner(nullptr); - } - } - - static RefPtr ParseGenericDecl(Parser* parser, void*) - { - RefPtr decl = new GenericDecl(); - parser->FillPosition(decl.Ptr()); - parser->PushScope(decl.Ptr()); - ParseGenericDeclImpl(parser, decl.Ptr(), [=](GenericDecl* genDecl) {return ParseSingleDecl(parser, genDecl); }); - parser->PopScope(); - return decl; - } - - static void parseParameterList( - Parser* parser, - RefPtr decl) - { - parser->ReadToken(TokenType::LParent); - - // Allow a declaration to use the keyword `void` for a parameter list, - // since that was required in ancient C, and continues to be supported - // in a bunc hof its derivatives even if it is a Bad Design Choice - // - // TODO: conditionalize this so we don't keep this around for "pure" - // Slang code - if( parser->LookAheadToken("void") && parser->LookAheadToken(TokenType::RParent, 1) ) - { - parser->ReadToken("void"); - parser->ReadToken(TokenType::RParent); - return; - } - - while (!AdvanceIfMatch(parser, TokenType::RParent)) - { - AddMember(decl, parser->ParseParameter()); - if (AdvanceIf(parser, TokenType::RParent)) - break; - parser->ReadToken(TokenType::Comma); - } - } - - // systematically replace all scopes in an expression tree - class ReplaceScopeVisitor : public ExprVisitor - { - public: - RefPtr scope; - void visitDeclRefExpr(DeclRefExpr* expr) - { - expr->scope = scope; - } - void visitGenericAppExpr(GenericAppExpr * expr) - { - expr->FunctionExpr->accept(this, nullptr); - for (auto arg : expr->Arguments) - arg->accept(this, nullptr); - } - void visitIndexExpr(IndexExpr * expr) - { - expr->BaseExpression->accept(this, nullptr); - expr->IndexExpression->accept(this, nullptr); - } - void visitMemberExpr(MemberExpr * expr) - { - expr->BaseExpression->accept(this, nullptr); - expr->scope = scope; - } - void visitStaticMemberExpr(StaticMemberExpr * expr) - { - expr->BaseExpression->accept(this, nullptr); - expr->scope = scope; - } - void visitExpr(Expr* /*expr*/) - {} - }; - - /// Parse an optional body statement for a declaration that can have a body. - static RefPtr parseOptBody(Parser* parser) - { - if (AdvanceIf(parser, TokenType::Semicolon)) - { - // empty body - return nullptr; - } - else - { - return parser->parseBlockStatement(); - } - } - - /// Complete parsing of a function using traditional (C-like) declarator syntax - static RefPtr parseTraditionalFuncDecl( - Parser* parser, - DeclaratorInfo const& declaratorInfo) - { - RefPtr decl = new FuncDecl(); - parser->FillPosition(decl.Ptr()); - decl->loc = declaratorInfo.nameAndLoc.loc; - decl->nameAndLoc = declaratorInfo.nameAndLoc; - - return parseOptGenericDecl(parser, [&](GenericDecl*) - { - // HACK: The return type of the function will already have been - // parsed in a scope that didn't include the function's generic - // parameters. - // - // We will use a visitor here to try and replace the scope associated - // with any name expressiosn in the reuslt type. - // - // TODO: This should be fixed by not associating scopes with - // such expressions at parse time, and instead pushing down scopes - // as part of the state during semantic checking. - // - ReplaceScopeVisitor replaceScopeVisitor; - replaceScopeVisitor.scope = parser->currentScope; - declaratorInfo.typeSpec->accept(&replaceScopeVisitor, nullptr); - - decl->ReturnType = TypeExp(declaratorInfo.typeSpec); - - parser->PushScope(decl); - - parseParameterList(parser, decl); - ParseOptSemantics(parser, decl.Ptr()); - decl->Body = parseOptBody(parser); - - parser->PopScope(); - - return decl; - }); - } - - static RefPtr CreateVarDeclForContext( - ContainerDecl* containerDecl ) - { - if (as(containerDecl)) - { - // Function parameters always use their dedicated syntax class. - // - return new ParamDecl(); - } - else - { - // Globals, locals, and member variables all use the same syntax class. - // - return new VarDecl(); - } - } - - // Add modifiers to the end of the modifier list for a declaration - void AddModifiers(Decl* decl, RefPtr modifiers) - { - if (!modifiers) - return; - - RefPtr* link = &decl->modifiers.first; - while (*link) - { - link = &(*link)->next; - } - *link = modifiers; - } - - static Name* generateName(Parser* parser, String const& base) - { - // TODO: somehow mangle the name to avoid clashes - return getName(parser, "SLANG_" + base); - } - - static Name* generateName(Parser* parser) - { - return generateName(parser, "anonymous_" + String(parser->anonymousCounter++)); - } - - - // Set up a variable declaration based on what we saw in its declarator... - static void CompleteVarDecl( - Parser* parser, - RefPtr decl, - DeclaratorInfo const& declaratorInfo) - { - parser->FillPosition(decl.Ptr()); - - if( !declaratorInfo.nameAndLoc.name ) - { - // HACK(tfoley): we always give a name, even if the declarator didn't include one... :( - decl->nameAndLoc = NameLoc(generateName(parser)); - } - else - { - decl->loc = declaratorInfo.nameAndLoc.loc; - decl->nameAndLoc = declaratorInfo.nameAndLoc; - } - decl->type = TypeExp(declaratorInfo.typeSpec); - - AddModifiers(decl.Ptr(), declaratorInfo.semantics); - - decl->initExpr = declaratorInfo.initializer; - } - - static RefPtr ParseDeclarator(Parser* parser); - - static RefPtr ParseDirectAbstractDeclarator( - Parser* parser) - { - RefPtr declarator; - switch( parser->tokenReader.PeekTokenType() ) - { - case TokenType::Identifier: - { - auto nameDeclarator = new NameDeclarator(); - nameDeclarator->flavor = Declarator::Flavor::name; - nameDeclarator->nameAndLoc = ParseDeclName(parser); - declarator = nameDeclarator; - } - break; - - case TokenType::LParent: - { - // Note(tfoley): This is a point where disambiguation is required. - // We could be looking at an abstract declarator for a function-type - // parameter: - // - // void F( int(int) ); - // - // Or we could be looking at the use of parenthesese in an ordinary - // declarator: - // - // void (*f)(int); - // - // The difference really doesn't matter right now, but we err in - // the direction of assuming the second case. - parser->ReadToken(TokenType::LParent); - declarator = ParseDeclarator(parser); - parser->ReadToken(TokenType::RParent); - } - break; - - default: - // an empty declarator is allowed - return nullptr; - } - - // postifx additions - for( ;;) - { - switch( parser->tokenReader.PeekTokenType() ) - { - case TokenType::LBracket: - { - auto arrayDeclarator = new ArrayDeclarator(); - arrayDeclarator->openBracketLoc = parser->tokenReader.PeekLoc(); - arrayDeclarator->flavor = Declarator::Flavor::Array; - arrayDeclarator->inner = declarator; - - parser->ReadToken(TokenType::LBracket); - if( parser->tokenReader.PeekTokenType() != TokenType::RBracket ) - { - arrayDeclarator->elementCountExpr = parser->ParseExpression(); - } - parser->ReadToken(TokenType::RBracket); - - declarator = arrayDeclarator; - continue; - } - - case TokenType::LParent: - break; - - default: - break; - } - - break; - } - - return declarator; - } - - // Parse a declarator (or at least as much of one as we support) - static RefPtr ParseDeclarator( - Parser* parser) - { - if( parser->tokenReader.PeekTokenType() == TokenType::OpMul ) - { - auto ptrDeclarator = new PointerDeclarator(); - ptrDeclarator->starLoc = parser->tokenReader.PeekLoc(); - ptrDeclarator->flavor = Declarator::Flavor::Pointer; - - parser->ReadToken(TokenType::OpMul); - - // TODO(tfoley): allow qualifiers like `const` here? - - ptrDeclarator->inner = ParseDeclarator(parser); - return ptrDeclarator; - } - else - { - return ParseDirectAbstractDeclarator(parser); - } - } - - // A declarator plus optional semantics and initializer - struct InitDeclarator - { - RefPtr declarator; - RefPtr semantics; - RefPtr initializer; - }; - - // Parse a declarator plus optional semantics - static InitDeclarator ParseSemanticDeclarator( - Parser* parser) - { - InitDeclarator result; - result.declarator = ParseDeclarator(parser); - result.semantics = ParseOptSemantics(parser); - return result; - } - - // Parse a declarator plus optional semantics and initializer - static InitDeclarator ParseInitDeclarator( - Parser* parser) - { - InitDeclarator result = ParseSemanticDeclarator(parser); - if (AdvanceIf(parser, TokenType::OpAssign)) - { - result.initializer = parser->ParseInitExpr(); - } - return result; - } - - static void UnwrapDeclarator( - RefPtr declarator, - DeclaratorInfo* ioInfo) - { - while( declarator ) - { - switch(declarator->flavor) - { - case Declarator::Flavor::name: - { - auto nameDeclarator = (NameDeclarator*) declarator.Ptr(); - ioInfo->nameAndLoc = nameDeclarator->nameAndLoc; - return; - } - break; - - case Declarator::Flavor::Pointer: - { - auto ptrDeclarator = (PointerDeclarator*) declarator.Ptr(); - - // TODO(tfoley): we don't support pointers for now - // ioInfo->typeSpec = new PointerTypeExpr(ioInfo->typeSpec); - - declarator = ptrDeclarator->inner; - } - break; - - case Declarator::Flavor::Array: - { - // TODO(tfoley): we don't support pointers for now - auto arrayDeclarator = (ArrayDeclarator*) declarator.Ptr(); - - auto arrayTypeExpr = new IndexExpr(); - arrayTypeExpr->loc = arrayDeclarator->openBracketLoc; - arrayTypeExpr->BaseExpression = ioInfo->typeSpec; - arrayTypeExpr->IndexExpression = arrayDeclarator->elementCountExpr; - ioInfo->typeSpec = arrayTypeExpr; - - declarator = arrayDeclarator->inner; - } - break; - - default: - SLANG_UNREACHABLE("all cases handled"); - break; - } - } - } - - static void UnwrapDeclarator( - InitDeclarator const& initDeclarator, - DeclaratorInfo* ioInfo) - { - UnwrapDeclarator(initDeclarator.declarator, ioInfo); - ioInfo->semantics = initDeclarator.semantics; - ioInfo->initializer = initDeclarator.initializer; - } - - // Either a single declaration, or a group of them - struct DeclGroupBuilder - { - SourceLoc startPosition; - RefPtr decl; - RefPtr group; - - // Add a new declaration to the potential group - void addDecl( - RefPtr newDecl) - { - SLANG_ASSERT(newDecl); - - if( decl ) - { - group = new DeclGroup(); - group->loc = startPosition; - group->decls.add(decl); - decl = nullptr; - } - - if( group ) - { - group->decls.add(newDecl); - } - else - { - decl = newDecl; - } - } - - RefPtr getResult() - { - if(group) return group; - return decl; - } - }; - - // Pares an argument to an application of a generic - RefPtr ParseGenericArg(Parser* parser) - { - return parser->ParseArgExpr(); - } - - // Create a type expression that will refer to the given declaration - static RefPtr - createDeclRefType(Parser* parser, RefPtr decl) - { - // For now we just construct an expression that - // will look up the given declaration by name. - // - // TODO: do this better, e.g. by filling in the `declRef` field directly - - auto expr = new VarExpr(); - expr->scope = parser->currentScope.Ptr(); - expr->loc = decl->getNameLoc(); - expr->name = decl->getName(); - return expr; - } - - // Representation for a parsed type specifier, which might - // include a declaration (e.g., of a `struct` type) - struct TypeSpec - { - // If the type-spec declared something, then put it here - RefPtr decl; - - // Put the resulting expression (which should evaluate to a type) here - RefPtr expr; - }; - - static RefPtr parseGenericApp( - Parser* parser, - RefPtr base) - { - RefPtr genericApp = new GenericAppExpr(); - - parser->FillPosition(genericApp.Ptr()); // set up scope for lookup - genericApp->FunctionExpr = base; - parser->ReadToken(TokenType::OpLess); - parser->genericDepth++; - // For now assume all generics have at least one argument - genericApp->Arguments.add(ParseGenericArg(parser)); - while (AdvanceIf(parser, TokenType::Comma)) - { - genericApp->Arguments.add(ParseGenericArg(parser)); - } - parser->genericDepth--; - - if (parser->tokenReader.PeekToken().type == TokenType::OpRsh) - { - parser->tokenReader.PeekToken().type = TokenType::OpGreater; - parser->tokenReader.PeekToken().loc.setRaw(parser->tokenReader.PeekToken().loc.getRaw() + 1); - } - else if (parser->LookAheadToken(TokenType::OpGreater)) - parser->ReadToken(TokenType::OpGreater); - else - parser->sink->diagnose(parser->tokenReader.PeekToken(), Diagnostics::tokenTypeExpected, "'>'"); - return genericApp; - } - - static bool isGenericName(Parser* parser, Name* name) - { - auto lookupResult = lookUp( - parser->getSession(), - nullptr, // no semantics visitor available yet - name, - parser->currentScope); - if (!lookupResult.isValid() || lookupResult.isOverloaded()) - return false; - - return lookupResult.item.declRef.is(); - } - - static RefPtr tryParseGenericApp( - Parser* parser, - RefPtr base) - { - Name * baseName = nullptr; - if (auto varExpr = as(base)) - baseName = varExpr->name; - // if base is a known generics, parse as generics - if (baseName && isGenericName(parser, baseName)) - return parseGenericApp(parser, base); - - // otherwise, we speculate as generics, and fallback to comparison when parsing failed - TokenSpan tokenSpan; - tokenSpan.mBegin = parser->tokenReader.mCursor; - tokenSpan.mEnd = parser->tokenReader.mEnd; - DiagnosticSink newSink; - newSink.sourceManager = parser->sink->sourceManager; - Parser newParser(*parser); - newParser.sink = &newSink; - auto speculateParseRs = parseGenericApp(&newParser, base); - if (newSink.errorCount == 0) - { - // disambiguate based on FOLLOW set - switch (peekTokenType(&newParser)) - { - case TokenType::Dot: - case TokenType::LParent: - case TokenType::RParent: - case TokenType::RBracket: - case TokenType::Colon: - case TokenType::Comma: - case TokenType::QuestionMark: - case TokenType::Semicolon: - case TokenType::OpEql: - case TokenType::OpNeq: - { - return parseGenericApp(parser, base); - } - } - } - return base; - } - static RefPtr parseMemberType(Parser * parser, RefPtr base) - { - // When called the :: or . have been consumed, so don't need to consume here. - - RefPtr memberExpr = new MemberExpr(); - - parser->FillPosition(memberExpr.Ptr()); - memberExpr->BaseExpression = base; - memberExpr->name = expectIdentifier(parser).name; - return memberExpr; - } - - // Parse option `[]` braces after a type expression, that indicate an array type - static RefPtr parsePostfixTypeSuffix( - Parser* parser, - RefPtr inTypeExpr) - { - auto typeExpr = inTypeExpr; - while (parser->LookAheadToken(TokenType::LBracket)) - { - RefPtr arrType = new IndexExpr(); - arrType->loc = typeExpr->loc; - arrType->BaseExpression = typeExpr; - parser->ReadToken(TokenType::LBracket); - if (!parser->LookAheadToken(TokenType::RBracket)) - { - arrType->IndexExpression = parser->ParseExpression(); - } - parser->ReadToken(TokenType::RBracket); - typeExpr = arrType; - } - return typeExpr; - } - - static RefPtr parseTaggedUnionType(Parser* parser) - { - RefPtr taggedUnionType = new TaggedUnionTypeExpr(); - - parser->ReadToken(TokenType::LParent); - while(!AdvanceIfMatch(parser, TokenType::RParent)) - { - auto caseType = parser->ParseTypeExp(); - taggedUnionType->caseTypes.add(caseType); - - if(AdvanceIf(parser, TokenType::RParent)) - break; - - parser->ReadToken(TokenType::Comma); - } - - return taggedUnionType; - } - - static TypeSpec parseTypeSpec(Parser* parser) - { - TypeSpec typeSpec; - - // We may see a `struct` (or `enum` or `class`) tag specified here, and need to act accordingly - // - // TODO(tfoley): Handle the case where the user is just using `struct` - // as a way to name an existing struct "tag" (e.g., `struct Foo foo;`) - // - // TODO: We should really make these keywords be registered like any other - // syntax category, rather than be special-cased here. The main issue here - // is that we need to allow them to be used as type specifiers, as in: - // - // struct Foo { int x } foo; - // - // The ideal answer would be to register certain keywords as being able - // to parse a type specifier, and look for those keywords here. - // We should ideally add special case logic that bails out of declarator - // parsing iff we have one of these kinds of type specifiers and the - // closing `}` is at the end of its line, as a bit of a special case - // to allow the common idiom. - // - if( parser->LookAheadToken("struct") ) - { - auto decl = parser->ParseStruct(); - typeSpec.decl = decl; - typeSpec.expr = createDeclRefType(parser, decl); - return typeSpec; - } - else if( parser->LookAheadToken("class") ) - { - auto decl = parser->ParseClass(); - typeSpec.decl = decl; - typeSpec.expr = createDeclRefType(parser, decl); - return typeSpec; - } - else if(parser->LookAheadToken("enum")) - { - auto decl = parseEnumDecl(parser); - typeSpec.decl = decl; - typeSpec.expr = createDeclRefType(parser, decl); - return typeSpec; - } - else if(AdvanceIf(parser, "__TaggedUnion")) - { - typeSpec.expr = parseTaggedUnionType(parser); - return typeSpec; - } - - Token typeName = parser->ReadToken(TokenType::Identifier); - - auto basicType = new VarExpr(); - basicType->scope = parser->currentScope.Ptr(); - basicType->loc = typeName.loc; - basicType->name = typeName.getNameOrNull(); - - RefPtr typeExpr = basicType; - - bool shouldLoop = true; - while (shouldLoop) - { - switch (peekTokenType(parser)) - { - case TokenType::OpLess: - typeExpr = parseGenericApp(parser, typeExpr); - break; - case TokenType::Scope: - parser->ReadToken(TokenType::Scope); - typeExpr = parseMemberType(parser, typeExpr); - break; - case TokenType::Dot: - parser->ReadToken(TokenType::Dot); - typeExpr = parseMemberType(parser, typeExpr); - break; - default: - shouldLoop = false; - } - } - - typeSpec.expr = typeExpr; - return typeSpec; - } - - static RefPtr ParseDeclaratorDecl( - Parser* parser, - ContainerDecl* containerDecl) - { - SourceLoc startPosition = parser->tokenReader.PeekLoc(); - - auto typeSpec = parseTypeSpec(parser); - - // We may need to build up multiple declarations in a group, - // but the common case will be when we have just a single - // declaration - DeclGroupBuilder declGroupBuilder; - declGroupBuilder.startPosition = startPosition; - - // The type specifier may include a declaration. E.g., - // it might declare a `struct` type. - if(typeSpec.decl) - declGroupBuilder.addDecl(typeSpec.decl); - - if( AdvanceIf(parser, TokenType::Semicolon) ) - { - // No actual variable is being declared here, but - // that might not be an error. - - auto result = declGroupBuilder.getResult(); - if( !result ) - { - parser->sink->diagnose(startPosition, Diagnostics::declarationDidntDeclareAnything); - } - return result; - } - - // It is possible that we have a plain `struct`, `enum`, - // or similar declaration that isn't being used to declare - // any variable, and the user didn't put a trailing - // semicolon on it: - // - // struct Batman - // { - // int cape; - // } - // - // We want to allow this syntax (rather than give an - // inscrutable error), but also support the less common - // idiom where that declaration is used as part of - // a variable declaration: - // - // struct Robin - // { - // float tights; - // } boyWonder; - // - // As a bit of a hack (insofar as it means we aren't - // *really* compatible with arbitrary HLSL code), we - // will check if there are any more tokens on the - // same line as the closing `}`, and if not, we - // will treat it like the end of the declaration. - // - // Just as a safety net, only apply this logic for - // a file that is being passed in as "true" Slang code. - // - if(parser->getSourceLanguage() == SourceLanguage::Slang) - { - if(typeSpec.decl) - { - if(peekToken(parser).flags & TokenFlag::AtStartOfLine) - { - // The token after the `}` is at the start of its - // own line, which means it can't be on the same line. - // - // This means the programmer probably wants to - // just treat this as a declaration. - return declGroupBuilder.getResult(); - } - } - } - - - InitDeclarator initDeclarator = ParseInitDeclarator(parser); - - DeclaratorInfo declaratorInfo; - declaratorInfo.typeSpec = typeSpec.expr; - - - // Rather than parse function declarators properly for now, - // we'll just do a quick disambiguation here. This won't - // matter unless we actually decide to support function-type parameters, - // using C syntax. - // - if ((parser->tokenReader.PeekTokenType() == TokenType::LParent || - parser->tokenReader.PeekTokenType() == TokenType::OpLess) - - // Only parse as a function if we didn't already see mutually-exclusive - // constructs when parsing the declarator. - && !initDeclarator.initializer - && !initDeclarator.semantics) - { - // Looks like a function, so parse it like one. - UnwrapDeclarator(initDeclarator, &declaratorInfo); - return parseTraditionalFuncDecl(parser, declaratorInfo); - } - - // Otherwise we are looking at a variable declaration, which could be one in a sequence... - - if( AdvanceIf(parser, TokenType::Semicolon) ) - { - // easy case: we only had a single declaration! - UnwrapDeclarator(initDeclarator, &declaratorInfo); - RefPtr firstDecl = CreateVarDeclForContext(containerDecl); - CompleteVarDecl(parser, firstDecl, declaratorInfo); - - declGroupBuilder.addDecl(firstDecl); - return declGroupBuilder.getResult(); - } - - // Otherwise we have multiple declarations in a sequence, and these - // declarations need to somehow share both the type spec and modifiers. - // - // If there are any errors in the type specifier, we only want to hear - // about it once, so we need to share structure rather than just - // clone syntax. - - auto sharedTypeSpec = new SharedTypeExpr(); - sharedTypeSpec->loc = typeSpec.expr->loc; - sharedTypeSpec->base = TypeExp(typeSpec.expr); - - for(;;) - { - declaratorInfo.typeSpec = sharedTypeSpec; - UnwrapDeclarator(initDeclarator, &declaratorInfo); - - RefPtr varDecl = CreateVarDeclForContext(containerDecl); - CompleteVarDecl(parser, varDecl, declaratorInfo); - - declGroupBuilder.addDecl(varDecl); - - // end of the sequence? - if(AdvanceIf(parser, TokenType::Semicolon)) - return declGroupBuilder.getResult(); - - // ad-hoc recovery, to avoid infinite loops - if( parser->isRecovering ) - { - parser->ReadToken(TokenType::Semicolon); - return declGroupBuilder.getResult(); - } - - // Let's default to assuming that a missing `,` - // indicates the end of a declaration, - // where a `;` would be expected, and not - // a continuation of this declaration, where - // a `,` would be expected (this is tailoring - // the diagnostic message a bit). - // - // TODO: a more advanced heuristic here might - // look at whether the next token is on the - // same line, to predict whether `,` or `;` - // would be more likely... - - if (!AdvanceIf(parser, TokenType::Comma)) - { - parser->ReadToken(TokenType::Semicolon); - return declGroupBuilder.getResult(); - } - - // expect another variable declaration... - initDeclarator = ParseInitDeclarator(parser); - } - } - - /// Parse the "register name" part of a `register` or `packoffset` semantic. - /// - /// The syntax matched is: - /// - /// register-name-and-component-mask ::= register-name component-mask? - /// register-name ::= identifier - /// component-mask ::= '.' identifier - /// - static void parseHLSLRegisterNameAndOptionalComponentMask( - Parser* parser, - HLSLLayoutSemantic* semantic) - { - semantic->registerName = parser->ReadToken(TokenType::Identifier); - if (AdvanceIf(parser, TokenType::Dot)) - { - semantic->componentMask = parser->ReadToken(TokenType::Identifier); - } - } - - /// Parse an HLSL `register` semantic. - /// - /// The syntax matched is: - /// - /// register-semantic ::= 'register' '(' register-name-and-component-mask register-space? ')' - /// register-space ::= ',' identifier - /// - static void parseHLSLRegisterSemantic( - Parser* parser, - HLSLRegisterSemantic* semantic) - { - // Read the `register` keyword - semantic->name = parser->ReadToken(TokenType::Identifier); - - // Expect a parenthized list of additional arguments - parser->ReadToken(TokenType::LParent); - - // First argument is a required register name and optional component mask - parseHLSLRegisterNameAndOptionalComponentMask(parser, semantic); - - // Second argument is an optional register space - if(AdvanceIf(parser, TokenType::Comma)) - { - semantic->spaceName = parser->ReadToken(TokenType::Identifier); - } - - parser->ReadToken(TokenType::RParent); - } - - /// Parse an HLSL `packoffset` semantic. - /// - /// The syntax matched is: - /// - /// packoffset-semantic ::= 'packoffset' '(' register-name-and-component-mask ')' - /// - static void parseHLSLPackOffsetSemantic( - Parser* parser, - HLSLPackOffsetSemantic* semantic) - { - // Read the `packoffset` keyword - semantic->name = parser->ReadToken(TokenType::Identifier); - - // Expect a parenthized list of additional arguments - parser->ReadToken(TokenType::LParent); - - // First and only argument is a required register name and optional component mask - parseHLSLRegisterNameAndOptionalComponentMask(parser, semantic); - - parser->ReadToken(TokenType::RParent); - - parser->sink->diagnose(semantic, Diagnostics::packOffsetNotSupported); - } - - // - // semantic ::= identifier ( '(' args ')' )? - // - static RefPtr ParseSemantic( - Parser* parser) - { - if (parser->LookAheadToken("register")) - { - RefPtr semantic = new HLSLRegisterSemantic(); - parser->FillPosition(semantic); - parseHLSLRegisterSemantic(parser, semantic.Ptr()); - return semantic; - } - else if (parser->LookAheadToken("packoffset")) - { - RefPtr semantic = new HLSLPackOffsetSemantic(); - parser->FillPosition(semantic); - parseHLSLPackOffsetSemantic(parser, semantic.Ptr()); - return semantic; - } - else if (parser->LookAheadToken(TokenType::Identifier)) - { - RefPtr semantic = new HLSLSimpleSemantic(); - parser->FillPosition(semantic); - semantic->name = parser->ReadToken(TokenType::Identifier); - return semantic; - } - else - { - // expect an identifier, just to produce an error message - parser->ReadToken(TokenType::Identifier); - return nullptr; - } - } - - // - // opt-semantics ::= (':' semantic)* - // - static RefPtr ParseOptSemantics( - Parser* parser) - { - if (!AdvanceIf(parser, TokenType::Colon)) - return nullptr; - - RefPtr result; - RefPtr* link = &result; - SLANG_ASSERT(!*link); - - for (;;) - { - RefPtr semantic = ParseSemantic(parser); - if (semantic) - { - *link = semantic; - link = &semantic->next; - } - - // If we see another `:`, then that means there - // is yet another semantic to be processed. - // Otherwise we assume we are at the end of the list. - // - // TODO: This could produce sub-optimal diagnostics - // when the user *meant* to apply multiple semantics - // to a single declaration: - // - // Foo foo : register(t0) register(s0); - // ^ - // missing ':' here | - // - // However, that is an uncommon occurence, and trying - // to continue parsing semantics here even if we didn't - // see a colon forces us to be careful about - // avoiding an infinite loop here. - if (!AdvanceIf(parser, TokenType::Colon)) - { - return result; - } - } - - } - - - static void ParseOptSemantics( - Parser* parser, - Decl* decl) - { - AddModifiers(decl, ParseOptSemantics(parser)); - } - - static RefPtr ParseHLSLBufferDecl( - Parser* parser, - String bufferWrapperTypeName) - { - // An HLSL declaration of a constant buffer like this: - // - // cbuffer Foo : register(b0) { int a; float b; }; - // - // is treated as syntax sugar for a type declaration - // and then a global variable declaration using that type: - // - // struct $anonymous { int a; float b; }; - // ConstantBuffer<$anonymous> Foo; - // - // where `$anonymous` is a fresh name, and the variable - // declaration is made to be "transparent" so that lookup - // will see through it to the members inside. - - auto bufferWrapperTypeNamePos = parser->tokenReader.PeekLoc(); - - // We are going to represent each buffer as a pair of declarations. - // The first is a type declaration that holds all the members, while - // the second is a variable declaration that uses the buffer type. - RefPtr bufferDataTypeDecl = new StructDecl(); - RefPtr bufferVarDecl = new VarDecl(); - - // Both declarations will have a location that points to the name - parser->FillPosition(bufferDataTypeDecl.Ptr()); - parser->FillPosition(bufferVarDecl.Ptr()); - - auto reflectionNameToken = parser->ReadToken(TokenType::Identifier); - - // Attach the reflection name to the block so we can use it - auto reflectionNameModifier = new ParameterGroupReflectionName(); - reflectionNameModifier->nameAndLoc = NameLoc(reflectionNameToken); - addModifier(bufferVarDecl, reflectionNameModifier); - - // Both the buffer variable and its type need to have names generated - bufferVarDecl->nameAndLoc.name = generateName(parser, "parameterGroup_" + String(reflectionNameToken.Content)); - bufferDataTypeDecl->nameAndLoc.name = generateName(parser, "ParameterGroup_" + String(reflectionNameToken.Content)); - - addModifier(bufferDataTypeDecl, new ImplicitParameterGroupElementTypeModifier()); - addModifier(bufferVarDecl, new ImplicitParameterGroupVariableModifier()); - - // TODO(tfoley): We end up constructing unchecked syntax here that - // is expected to type check into the right form, but it might be - // cleaner to have a more explicit desugaring pass where we parse - // these constructs directly into the AST and *then* desugar them. - - // Construct a type expression to reference the buffer data type - auto bufferDataTypeExpr = new VarExpr(); - bufferDataTypeExpr->loc = bufferDataTypeDecl->loc; - bufferDataTypeExpr->name = bufferDataTypeDecl->nameAndLoc.name; - bufferDataTypeExpr->scope = parser->currentScope.Ptr(); - - // Construct a type expression to reference the type constructor - auto bufferWrapperTypeExpr = new VarExpr(); - bufferWrapperTypeExpr->loc = bufferWrapperTypeNamePos; - bufferWrapperTypeExpr->name = getName(parser, bufferWrapperTypeName); - - // Always need to look this up in the outer scope, - // so that it won't collide with, e.g., a local variable called `ConstantBuffer` - bufferWrapperTypeExpr->scope = parser->outerScope; - - // Construct a type expression that represents the type for the variable, - // which is the wrapper type applied to the data type - auto bufferVarTypeExpr = new GenericAppExpr(); - bufferVarTypeExpr->loc = bufferVarDecl->loc; - bufferVarTypeExpr->FunctionExpr = bufferWrapperTypeExpr; - bufferVarTypeExpr->Arguments.add(bufferDataTypeExpr); - - bufferVarDecl->type.exp = bufferVarTypeExpr; - - // Any semantics applied to the buffer declaration are taken as applying - // to the variable instead. - ParseOptSemantics(parser, bufferVarDecl.Ptr()); - - // The declarations in the body belong to the data type. - parseAggTypeDeclBody(parser, bufferDataTypeDecl.Ptr()); - - // All HLSL buffer declarations are "transparent" in that their - // members are implicitly made visible in the parent scope. - // We achieve this by applying the transparent modifier to the variable. - auto transparentModifier = new TransparentModifier(); - transparentModifier->next = bufferVarDecl->modifiers.first; - bufferVarDecl->modifiers.first = transparentModifier; - - // Because we are constructing two declarations, we have a thorny - // issue that were are only supposed to return one. - // For now we handle this by adding the type declaration to - // the current scope manually, and then returning the variable - // declaration. - // - // Note: this means that any modifiers that have already been parsed - // will get attached to the variable declaration, not the type. - // There might be cases where we need to shuffle things around. - - AddMember(parser->currentScope, bufferDataTypeDecl); - - return bufferVarDecl; - } - - static RefPtr parseHLSLCBufferDecl( - Parser* parser, void* /*userData*/) - { - return ParseHLSLBufferDecl(parser, "ConstantBuffer"); - } - - static RefPtr parseHLSLTBufferDecl( - Parser* parser, void* /*userData*/) - { - return ParseHLSLBufferDecl(parser, "TextureBuffer"); - } - - static void parseOptionalInheritanceClause(Parser* parser, AggTypeDeclBase* decl) - { - if (AdvanceIf(parser, TokenType::Colon)) - { - do - { - auto base = parser->ParseTypeExp(); - - auto inheritanceDecl = new InheritanceDecl(); - inheritanceDecl->loc = base.exp->loc; - inheritanceDecl->nameAndLoc.name = getName(parser, "$inheritance"); - inheritanceDecl->base = base; - - AddMember(decl, inheritanceDecl); - - } while (AdvanceIf(parser, TokenType::Comma)); - } - } - - static RefPtr ParseExtensionDecl(Parser* parser, void* /*userData*/) - { - RefPtr decl = new ExtensionDecl(); - parser->FillPosition(decl.Ptr()); - decl->targetType = parser->ParseTypeExp(); - parseOptionalInheritanceClause(parser, decl); - parseAggTypeDeclBody(parser, decl.Ptr()); - - return decl; - } - - - void parseOptionalGenericConstraints(Parser * parser, ContainerDecl* decl) - { - if (AdvanceIf(parser, TokenType::Colon)) - { - do - { - RefPtr paramConstraint = new GenericTypeConstraintDecl(); - parser->FillPosition(paramConstraint); - - // substitution needs to be filled during check - RefPtr paramType = DeclRefType::Create( - parser->getSession(), - DeclRef(decl, nullptr)); - - RefPtr paramTypeExpr = new SharedTypeExpr(); - paramTypeExpr->loc = decl->loc; - paramTypeExpr->base.type = paramType; - paramTypeExpr->type = QualType(getTypeType(paramType)); - - paramConstraint->sub = TypeExp(paramTypeExpr); - paramConstraint->sup = parser->ParseTypeExp(); - - AddMember(decl, paramConstraint); - } while (AdvanceIf(parser, TokenType::Comma)); - } - } - - RefPtr parseAssocType(Parser * parser, void *) - { - RefPtr assocTypeDecl = new AssocTypeDecl(); - - auto nameToken = parser->ReadToken(TokenType::Identifier); - assocTypeDecl->nameAndLoc = NameLoc(nameToken); - assocTypeDecl->loc = nameToken.loc; - parseOptionalGenericConstraints(parser, assocTypeDecl); - parser->ReadToken(TokenType::Semicolon); - return assocTypeDecl; - } - - RefPtr parseGlobalGenericParamDecl(Parser * parser, void *) - { - RefPtr genParamDecl = new GlobalGenericParamDecl(); - auto nameToken = parser->ReadToken(TokenType::Identifier); - genParamDecl->nameAndLoc = NameLoc(nameToken); - genParamDecl->loc = nameToken.loc; - parseOptionalGenericConstraints(parser, genParamDecl); - parser->ReadToken(TokenType::Semicolon); - return genParamDecl; - } - - static RefPtr parseInterfaceDecl(Parser* parser, void* /*userData*/) - { - RefPtr decl = new InterfaceDecl(); - parser->FillPosition(decl.Ptr()); - decl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); - - parseOptionalInheritanceClause(parser, decl.Ptr()); - - parseAggTypeDeclBody(parser, decl.Ptr()); - - return decl; - } - - static RefPtr parseConstructorDecl(Parser* parser, void* /*userData*/) - { - RefPtr decl = new ConstructorDecl(); - parser->FillPosition(decl.Ptr()); - - // TODO: we need to make sure that all initializers have - // the same name, but that this name doesn't conflict - // with any user-defined names. - // Giving them a name (rather than leaving it null) - // ensures that we can use name-based lookup to find - // all of the initializers on a type (and has - // the potential to unify initializer lookup with - // ordinary member lookup). - decl->nameAndLoc.name = getName(parser, "$init"); - - parseParameterList(parser, decl); - - decl->Body = parseOptBody(parser); - - return decl; - } - - static RefPtr parseAccessorDecl(Parser* parser) - { - Modifiers modifiers = ParseModifiers(parser); - - RefPtr decl; - if( AdvanceIf(parser, "get") ) - { - decl = new GetterDecl(); - } - else if( AdvanceIf(parser, "set") ) - { - decl = new SetterDecl(); - } - else if( AdvanceIf(parser, "ref") ) - { - decl = new RefAccessorDecl(); - } - else - { - Unexpected(parser); - return nullptr; - } - - AddModifiers(decl, modifiers.first); - - if( parser->tokenReader.PeekTokenType() == TokenType::LBrace ) - { - decl->Body = parser->parseBlockStatement(); - } - else - { - parser->ReadToken(TokenType::Semicolon); - } - - return decl; - } - - static RefPtr ParseSubscriptDecl(Parser* parser, void* /*userData*/) - { - RefPtr decl = new SubscriptDecl(); - parser->FillPosition(decl.Ptr()); - - // TODO: the use of this name here is a bit magical... - decl->nameAndLoc.name = getName(parser, "operator[]"); - - parseParameterList(parser, decl); - - if( AdvanceIf(parser, TokenType::RightArrow) ) - { - decl->ReturnType = parser->ParseTypeExp(); - } - - if( AdvanceIf(parser, TokenType::LBrace) ) - { - // We want to parse nested "accessor" declarations - while( !AdvanceIfMatch(parser, TokenType::RBrace) ) - { - auto accessor = parseAccessorDecl(parser); - AddMember(decl, accessor); - } - } - else - { - parser->ReadToken(TokenType::Semicolon); - - // empty body should be treated like `{ get; }` - } - - return decl; - } - - static bool expect(Parser* parser, TokenType tokenType) - { - return parser->ReadToken(tokenType).type == tokenType; - } - - static void parseModernVarDeclBaseCommon( - Parser* parser, - RefPtr decl) - { - parser->FillPosition(decl.Ptr()); - decl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); - - if(AdvanceIf(parser, TokenType::Colon)) - { - decl->type = parser->ParseTypeExp(); - } - - if(AdvanceIf(parser, TokenType::OpAssign)) - { - decl->initExpr = parser->ParseInitExpr(); - } - } - - static void parseModernVarDeclCommon( - Parser* parser, - RefPtr decl) - { - parseModernVarDeclBaseCommon(parser, decl); - expect(parser, TokenType::Semicolon); - } - - static RefPtr parseLetDecl( - Parser* parser, void* /*userData*/) - { - RefPtr decl = new LetDecl(); - parseModernVarDeclCommon(parser, decl); - return decl; - } - - static RefPtr parseVarDecl( - Parser* parser, void* /*userData*/) - { - RefPtr decl = new VarDecl(); - parseModernVarDeclCommon(parser, decl); - return decl; - } - - static RefPtr parseModernParamDecl( - Parser* parser) - { - RefPtr decl = new ParamDecl(); - - // TODO: "modern" parameters should not accept keyword-based - // modifiers and should only accept `[attribute]` syntax for - // modifiers to keep the grammar as simple as possible. - // - // Further, they should accept `out` and `in out`/`inout` - // before the type (e.g., `a: inout float4`). - // - decl->modifiers = ParseModifiers(parser); - parseModernVarDeclBaseCommon(parser, decl); - return decl; - } - - static void parseModernParamList( - Parser* parser, - RefPtr decl) - { - parser->ReadToken(TokenType::LParent); - - while (!AdvanceIfMatch(parser, TokenType::RParent)) - { - AddMember(decl, parseModernParamDecl(parser)); - if (AdvanceIf(parser, TokenType::RParent)) - break; - parser->ReadToken(TokenType::Comma); - } - } - - static RefPtr parseFuncDecl( - Parser* parser, void* /*userData*/) - { - RefPtr decl = new FuncDecl(); - - parser->FillPosition(decl.Ptr()); - decl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); - - return parseOptGenericDecl(parser, [&](GenericDecl*) - { - parser->PushScope(decl.Ptr()); - parseModernParamList(parser, decl); - if(AdvanceIf(parser, TokenType::RightArrow)) - { - decl->ReturnType = parser->ParseTypeExp(); - } - decl->Body = parseOptBody(parser); - parser->PopScope(); - return decl; - }); - } - - static RefPtr parseTypeAliasDecl( - Parser* parser, void* /*userData*/) - { - RefPtr decl = new TypeAliasDecl(); - - parser->FillPosition(decl.Ptr()); - decl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); - - return parseOptGenericDecl(parser, [&](GenericDecl*) - { - if( expect(parser, TokenType::OpAssign) ) - { - decl->type = parser->ParseTypeExp(); - } - expect(parser, TokenType::Semicolon); - return decl; - }); - } - - // This is a catch-all syntax-construction callback to handle cases where - // a piece of syntax is fully defined by the keyword to use, along with - // the class of AST node to construct. - static RefPtr parseSimpleSyntax(Parser* /*parser*/, void* userData) - { - SyntaxClassBase syntaxClass((SyntaxClassBase::ClassInfo*) userData); - return (RefObject*) syntaxClass.createInstanceImpl(); - } - - // Parse a declaration of a keyword that can be used to define further syntax. - static RefPtr parseSyntaxDecl(Parser* parser, void* /*userData*/) - { - // Right now the basic form is: - // - // syntax [: ] [= ]; - // - // - `name` gives the name of the keyword to define. - // - `syntaxClass` is the name of an AST node class that we expect - // this syntax to construct when parsed. - // - `existingKeyword` is the name of an existing keyword that - // the new syntax should be an alias for. - - // First we parse the keyword name. - auto nameAndLoc = expectIdentifier(parser); - - // Next we look for a clause that specified the AST node class. - SyntaxClass syntaxClass; - if (AdvanceIf(parser, TokenType::Colon)) - { - // User is specifying the class that should be construted - auto classNameAndLoc = expectIdentifier(parser); - - syntaxClass = parser->getSession()->findSyntaxClass(classNameAndLoc.name); - } - - // If the user specified a syntax class, then we will default - // to the `parseSimpleSyntax` callback that will just construct - // an instance of that type to represent the keyword in the AST. - SyntaxParseCallback parseCallback = &parseSimpleSyntax; - void* parseUserData = (void*) syntaxClass.classInfo; - - // Next we look for an initializer that will make this keyword - // an alias for some existing keyword. - if (AdvanceIf(parser, TokenType::OpAssign)) - { - auto existingKeywordNameAndLoc = expectIdentifier(parser); - - auto existingSyntax = tryLookUpSyntaxDecl(parser, existingKeywordNameAndLoc.name); - if (!existingSyntax) - { - // TODO: diagnose: keyword did not name syntax - } - else - { - // The user is expecting us to parse our new syntax like - // the existing syntax given, so we need to override - // the callback. - parseCallback = existingSyntax->parseCallback; - parseUserData = existingSyntax->parseUserData; - - // If we don't already have a syntax class specified, then - // we will crib the one from the existing syntax, to ensure - // that we are creating a drop-in alias. - if (!syntaxClass.classInfo) - syntaxClass = existingSyntax->syntaxClass; - } - } - - // It is an error if the user didn't give us either an existing keyword - // to use to the define the callback, or a valid AST node class to construct. - // - // TODO: down the line this should be expanded so that the user can reference - // an existing *function* to use to parse the chosen syntax. - if (!syntaxClass.classInfo) - { - // TODO: diagnose: either a type or an existing keyword needs to be specified - } - - expect(parser, TokenType::Semicolon); - - // TODO: skip creating the declaration if anything failed, just to not screw things - // up for downstream code? - - RefPtr syntaxDecl = new SyntaxDecl(); - syntaxDecl->nameAndLoc = nameAndLoc; - syntaxDecl->loc = nameAndLoc.loc; - syntaxDecl->syntaxClass = syntaxClass; - syntaxDecl->parseCallback = parseCallback; - syntaxDecl->parseUserData = parseUserData; - return syntaxDecl; - } - - // A parameter declaration in an attribute declaration. - // - // We are going to use `name: type` syntax just for simplicty, and let the type - // be optional, because we don't actually need it in all cases. - // - static RefPtr parseAttributeParamDecl(Parser* parser) - { - auto nameAndLoc = expectIdentifier(parser); - - RefPtr paramDecl = new ParamDecl(); - paramDecl->nameAndLoc = nameAndLoc; - - if(AdvanceIf(parser, TokenType::Colon)) - { - paramDecl->type = parser->ParseTypeExp(); - } - - if(AdvanceIf(parser, TokenType::OpAssign)) - { - paramDecl->initExpr = parser->ParseInitExpr(); - } - - return paramDecl; - } - - // Parse declaration of a name to be used for resolving `[attribute(...)]` style modifiers. - // - // These are distinct from `syntax` declarations, because their names don't get added - // to the current scope using their default name. - // - // Also, attribute-specific code doesn't get invokved during parsing. We always parse - // using the default attribute-parsing logic and then all specialized behavior takes - // place during semantic checking. - // - static RefPtr parseAttributeSyntaxDecl(Parser* parser, void* /*userData*/) - { - // Right now the basic form is: - // - // attribute_syntax : ; - // - // - `name` gives the name of the attribute to define. - // - `syntaxClass` is the name of an AST node class that we expect - // this attribute to create when checked. - // - `existingKeyword` is the name of an existing keyword that - // the new syntax should be an alias for. - - expect(parser, TokenType::LBracket); - - // First we parse the attribute name. - auto nameAndLoc = expectIdentifier(parser); - - RefPtr attrDecl = new AttributeDecl(); - if(AdvanceIf(parser, TokenType::LParent)) - { - while(!AdvanceIfMatch(parser, TokenType::RParent)) - { - auto param = parseAttributeParamDecl(parser); - - AddMember(attrDecl, param); - - if(AdvanceIfMatch(parser, TokenType::RParent)) - break; - - expect(parser, TokenType::Comma); - } - } - - expect(parser, TokenType::RBracket); - - // TODO: we should allow parameters to be specified here, to cut down - // on the amount of per-attribute-type logic that has to occur later. - - // Next we look for a clause that specified the AST node class. - SyntaxClass syntaxClass; - if (AdvanceIf(parser, TokenType::Colon)) - { - // User is specifying the class that should be construted - auto classNameAndLoc = expectIdentifier(parser); - - syntaxClass = parser->getSession()->findSyntaxClass(classNameAndLoc.name); - } - else - { - // For now we don't support the alternative approach where - // an existing piece of syntax is named to provide the parsing - // support. - - // TODO: diagnose: a syntax class must be specified. - } - - expect(parser, TokenType::Semicolon); - - // TODO: skip creating the declaration if anything failed, just to not screw things - // up for downstream code? - - attrDecl->nameAndLoc = nameAndLoc; - attrDecl->loc = nameAndLoc.loc; - attrDecl->syntaxClass = syntaxClass; - return attrDecl; - } - - // Finish up work on a declaration that was parsed - static void CompleteDecl( - Parser* /*parser*/, - RefPtr decl, - ContainerDecl* containerDecl, - Modifiers modifiers) - { - // Add any modifiers we parsed before the declaration to the list - // of modifiers on the declaration itself. - // - // We need to be careful, because if `decl` is a generic declaration, - // then we really want the modifiers to apply to the inner declaration. - // - RefPtr declToModify = decl; - if(auto genericDecl = as(decl)) - declToModify = genericDecl->inner; - AddModifiers(declToModify.Ptr(), modifiers.first); - - // Make sure the decl is properly nested inside its lexical parent - if (containerDecl) - { - AddMember(containerDecl, decl); - } - } - - static RefPtr ParseDeclWithModifiers( - Parser* parser, - ContainerDecl* containerDecl, - Modifiers modifiers ) - { - RefPtr decl; - - auto loc = parser->tokenReader.PeekLoc(); - - switch (peekTokenType(parser)) - { - case TokenType::Identifier: - { - // A declaration that starts with an identifier might be: - // - // - A keyword-based declaration (e.g., `cbuffer ...`) - // - The beginning of a type in a declarator-based declaration (e.g., `int ...`) - - // First we will check whether we can use the identifier token - // as a declaration keyword and parse a declaration using - // its associated callback: - RefPtr parsedDecl; - if (tryParseUsingSyntaxDecl(parser, &parsedDecl)) - { - decl = parsedDecl; - break; - } - - // Our final fallback case is to assume that the user is - // probably writing a C-style declarator-based declaration. - decl = ParseDeclaratorDecl(parser, containerDecl); - break; - } - break; - - // It is valid in HLSL/GLSL to have an "empty" declaration - // that consists of just a semicolon. In particular, this - // gets used a lot in GLSL to attach custom semantics to - // shader input or output. - // - case TokenType::Semicolon: - { - advanceToken(parser); - - decl = new EmptyDecl(); - decl->loc = loc; - } - break; - - // If nothing else matched, we try to parse an "ordinary" declarator-based declaration - default: - decl = ParseDeclaratorDecl(parser, containerDecl); - break; - } - - if (decl) - { - if( auto dd = as(decl) ) - { - CompleteDecl(parser, dd, containerDecl, modifiers); - } - else if(auto declGroup = as(decl)) - { - // We are going to add the same modifiers to *all* of these declarations, - // so we want to give later passes a way to detect which modifiers - // were shared, vs. which ones are specific to a single declaration. - - auto sharedModifiers = new SharedModifiers(); - sharedModifiers->next = modifiers.first; - modifiers.first = sharedModifiers; - - for( auto subDecl : declGroup->decls ) - { - CompleteDecl(parser, subDecl, containerDecl, modifiers); - } - } - } - return decl; - } - - static RefPtr ParseDecl( - Parser* parser, - ContainerDecl* containerDecl) - { - Modifiers modifiers = ParseModifiers(parser); - return ParseDeclWithModifiers(parser, containerDecl, modifiers); - } - - static RefPtr ParseSingleDecl( - Parser* parser, - ContainerDecl* containerDecl) - { - auto declBase = ParseDecl(parser, containerDecl); - if(!declBase) - return nullptr; - if( auto decl = as(declBase) ) - { - return decl; - } - else if( auto declGroup = as(declBase) ) - { - if( declGroup->decls.getCount() == 1 ) - { - return declGroup->decls[0]; - } - } - - parser->sink->diagnose(declBase->loc, Diagnostics::unimplemented, "didn't expect multiple declarations here"); - return nullptr; - } - - - // Parse a body consisting of declarations - static void ParseDeclBody( - Parser* parser, - ContainerDecl* containerDecl, - TokenType closingToken) - { - while(!AdvanceIfMatch(parser, closingToken)) - { - ParseDecl(parser, containerDecl); - } - } - - // Parse the `{}`-delimeted body of an aggregate type declaration - static void parseAggTypeDeclBody( - Parser* parser, - AggTypeDeclBase* decl) - { - // TODO: the scope used for the body might need to be - // slightly specialized to deal with the complexity - // of how `this` works. - // - // Alternatively, that complexity can be pushed down - // to semantic analysis so that it doesn't clutter - // things here. - parser->PushScope(decl); - - parser->ReadToken(TokenType::LBrace); - ParseDeclBody(parser, decl, TokenType::RBrace); - - parser->PopScope(); - } - - - void Parser::parseSourceFile(ModuleDecl* program) - { - if (outerScope) - { - currentScope = outerScope; - } - - PushScope(program); - program->loc = tokenReader.PeekLoc(); - program->scope = currentScope; - ParseDeclBody(this, program, TokenType::EndOfFile); - PopScope(); - - SLANG_RELEASE_ASSERT(currentScope == outerScope); - currentScope = nullptr; - } - - RefPtr Parser::ParseStruct() - { - RefPtr rs = new StructDecl(); - FillPosition(rs.Ptr()); - ReadToken("struct"); - - // TODO: support `struct` declaration without tag - rs->nameAndLoc = expectIdentifier(this); - - return parseOptGenericDecl(this, [&](GenericDecl*) - { - // We allow for an inheritance clause on a `struct` - // so that it can conform to interfaces. - parseOptionalInheritanceClause(this, rs.Ptr()); - parseAggTypeDeclBody(this, rs.Ptr()); - return rs; - }); - } - - RefPtr Parser::ParseClass() - { - RefPtr rs = new ClassDecl(); - FillPosition(rs.Ptr()); - ReadToken("class"); - rs->nameAndLoc = expectIdentifier(this); - - parseOptionalInheritanceClause(this, rs.Ptr()); - - parseAggTypeDeclBody(this, rs.Ptr()); - return rs; - } - - static RefPtr parseEnumCaseDecl(Parser* parser) - { - RefPtr decl = new EnumCaseDecl(); - decl->nameAndLoc = expectIdentifier(parser); - - if(AdvanceIf(parser, TokenType::OpAssign)) - { - decl->tagExpr = parser->ParseArgExpr(); - } - - return decl; - } - - static RefPtr parseEnumDecl(Parser* parser) - { - RefPtr decl = new EnumDecl(); - parser->FillPosition(decl); - - parser->ReadToken("enum"); - - // HACK: allow the user to write `enum class` in case - // they are trying to share a header between C++ and Slang. - // - // TODO: diagnose this with a warning some day, and move - // toward deprecating it. - // - AdvanceIf(parser, "class"); - - decl->nameAndLoc = expectIdentifier(parser); - - - return parseOptGenericDecl(parser, [&](GenericDecl*) - { - parseOptionalInheritanceClause(parser, decl); - parser->ReadToken(TokenType::LBrace); - - while(!AdvanceIfMatch(parser, TokenType::RBrace)) - { - RefPtr caseDecl = parseEnumCaseDecl(parser); - AddMember(decl, caseDecl); - - if(AdvanceIf(parser, TokenType::RBrace)) - break; - - parser->ReadToken(TokenType::Comma); - } - return decl; - }); - } - - static RefPtr ParseSwitchStmt(Parser* parser) - { - RefPtr stmt = new SwitchStmt(); - parser->FillPosition(stmt.Ptr()); - parser->ReadToken("switch"); - parser->ReadToken(TokenType::LParent); - stmt->condition = parser->ParseExpression(); - parser->ReadToken(TokenType::RParent); - stmt->body = parser->parseBlockStatement(); - return stmt; - } - - static RefPtr ParseCaseStmt(Parser* parser) - { - RefPtr stmt = new CaseStmt(); - parser->FillPosition(stmt.Ptr()); - parser->ReadToken("case"); - stmt->expr = parser->ParseExpression(); - parser->ReadToken(TokenType::Colon); - return stmt; - } - - static RefPtr ParseDefaultStmt(Parser* parser) - { - RefPtr stmt = new DefaultStmt(); - parser->FillPosition(stmt.Ptr()); - parser->ReadToken("default"); - parser->ReadToken(TokenType::Colon); - return stmt; - } - - static bool isTypeName(Parser* parser, Name* name) - { - auto lookupResult = lookUp( - parser->getSession(), - nullptr, // no semantics visitor available yet - name, - parser->currentScope); - if(!lookupResult.isValid() || lookupResult.isOverloaded()) - return false; - - auto decl = lookupResult.item.declRef.getDecl(); - if( auto typeDecl = as(decl) ) - { - return true; - } - else if( auto typeVarDecl = as(decl) ) - { - return true; - } - else - { - return false; - } - } - - static bool peekTypeName(Parser* parser) - { - if(!parser->LookAheadToken(TokenType::Identifier)) - return false; - - auto name = parser->tokenReader.PeekToken().getName(); - return isTypeName(parser, name); - } - - RefPtr parseCompileTimeForStmt( - Parser* parser) - { - RefPtr scopeDecl = new ScopeDecl(); - RefPtr stmt = new CompileTimeForStmt(); - stmt->scopeDecl = scopeDecl; - - - parser->ReadToken("for"); - parser->ReadToken(TokenType::LParent); - - NameLoc varNameAndLoc = expectIdentifier(parser); - RefPtr varDecl = new VarDecl(); - varDecl->nameAndLoc = varNameAndLoc; - varDecl->loc = varNameAndLoc.loc; - - stmt->varDecl = varDecl; - - parser->ReadToken("in"); - parser->ReadToken("Range"); - parser->ReadToken(TokenType::LParent); - - RefPtr rangeBeginExpr; - RefPtr rangeEndExpr = parser->ParseArgExpr(); - if (AdvanceIf(parser, TokenType::Comma)) - { - rangeBeginExpr = rangeEndExpr; - rangeEndExpr = parser->ParseArgExpr(); - } - - stmt->rangeBeginExpr = rangeBeginExpr; - stmt->rangeEndExpr = rangeEndExpr; - - parser->ReadToken(TokenType::RParent); - parser->ReadToken(TokenType::RParent); - - parser->pushScopeAndSetParent(scopeDecl); - AddMember(parser->currentScope, varDecl); - - stmt->body = parser->ParseStatement(); - - parser->PopScope(); - - return stmt; - } - - RefPtr parseCompileTimeStmt( - Parser* parser) - { - parser->ReadToken(TokenType::Dollar); - if (parser->LookAheadToken("for")) - { - return parseCompileTimeForStmt(parser); - } - else - { - Unexpected(parser); - return nullptr; - } - } - - RefPtr Parser::ParseStatement() - { - auto modifiers = ParseModifiers(this); - - RefPtr statement; - if (LookAheadToken(TokenType::LBrace)) - statement = parseBlockStatement(); - else if (peekTypeName(this)) - statement = parseVarDeclrStatement(modifiers); - else if (LookAheadToken("if")) - statement = parseIfStatement(); - else if (LookAheadToken("for")) - statement = ParseForStatement(); - else if (LookAheadToken("while")) - statement = ParseWhileStatement(); - else if (LookAheadToken("do")) - statement = ParseDoWhileStatement(); - else if (LookAheadToken("break")) - statement = ParseBreakStatement(); - else if (LookAheadToken("continue")) - statement = ParseContinueStatement(); - else if (LookAheadToken("return")) - statement = ParseReturnStatement(); - else if (LookAheadToken("discard")) - { - statement = new DiscardStmt(); - FillPosition(statement.Ptr()); - ReadToken("discard"); - ReadToken(TokenType::Semicolon); - } - else if (LookAheadToken("switch")) - statement = ParseSwitchStmt(this); - else if (LookAheadToken("case")) - statement = ParseCaseStmt(this); - else if (LookAheadToken("default")) - statement = ParseDefaultStmt(this); - else if (LookAheadToken(TokenType::Dollar)) - { - statement = parseCompileTimeStmt(this); - } - else if (LookAheadToken(TokenType::Identifier)) - { - // We might be looking at a local declaration, or an - // expression statement, and we need to figure out which. - // - // We'll solve this with backtracking for now. - - TokenReader::ParsingCursor startPos = tokenReader.getCursor(); - - // Try to parse a type (knowing that the type grammar is - // a subset of the expression grammar, and so this should - // always succeed). - RefPtr type = ParseType(); - // We don't actually care about the type, though, so - // don't retain it - type = nullptr; - - // If the next token after we parsed a type looks like - // we are going to declare a variable, then lets guess - // that this is a declaration. - // - // TODO(tfoley): this wouldn't be robust for more - // general kinds of declarators (notably pointer declarators), - // so we'll need to be careful about this. - if (LookAheadToken(TokenType::Identifier)) - { - // Reset the cursor and try to parse a declaration now. - // Note: the declaration will consume any modifiers - // that had been in place on the statement. - tokenReader.setCursor(startPos); - statement = parseVarDeclrStatement(modifiers); - return statement; - } - - // Fallback: reset and parse an expression - tokenReader.setCursor(startPos); - statement = ParseExpressionStatement(); - } - else if (LookAheadToken(TokenType::Semicolon)) - { - statement = new EmptyStmt(); - FillPosition(statement.Ptr()); - ReadToken(TokenType::Semicolon); - } - else - { - // Default case should always fall back to parsing an expression, - // and then let that detect any errors - statement = ParseExpressionStatement(); - } - - if (statement && !as(statement)) - { - // Install any modifiers onto the statement. - // Note: this path is bypassed in the case of a - // declaration statement, so we don't end up - // doubling up the modifiers. - statement->modifiers = modifiers; - } - - return statement; - } - - RefPtr Parser::parseBlockStatement() - { - RefPtr scopeDecl = new ScopeDecl(); - RefPtr blockStatement = new BlockStmt(); - blockStatement->scopeDecl = scopeDecl; - pushScopeAndSetParent(scopeDecl.Ptr()); - ReadToken(TokenType::LBrace); - - RefPtr body; - - if(!tokenReader.IsAtEnd()) - { - FillPosition(blockStatement.Ptr()); - } - while (!AdvanceIfMatch(this, TokenType::RBrace)) - { - auto stmt = ParseStatement(); - if(stmt) - { - if (!body) - { - body = stmt; - } - else if (auto seqStmt = as(body)) - { - seqStmt->stmts.add(stmt); - } - else - { - RefPtr newBody = new SeqStmt(); - newBody->loc = blockStatement->loc; - newBody->stmts.add(body); - newBody->stmts.add(stmt); - - body = newBody; - } - } - TryRecover(this); - } - PopScope(); - - if(!body) - { - body = new EmptyStmt(); - body->loc = blockStatement->loc; - } - - blockStatement->body = body; - return blockStatement; - } - - RefPtr Parser::parseVarDeclrStatement( - Modifiers modifiers) - { - RefPtrvarDeclrStatement = new DeclStmt(); - - FillPosition(varDeclrStatement.Ptr()); - auto decl = ParseDeclWithModifiers(this, currentScope->containerDecl, modifiers); - varDeclrStatement->decl = decl; - return varDeclrStatement; - } - - RefPtr Parser::parseIfStatement() - { - RefPtr ifStatement = new IfStmt(); - FillPosition(ifStatement.Ptr()); - ReadToken("if"); - ReadToken(TokenType::LParent); - ifStatement->Predicate = ParseExpression(); - ReadToken(TokenType::RParent); - ifStatement->PositiveStatement = ParseStatement(); - if (LookAheadToken("else")) - { - ReadToken("else"); - ifStatement->NegativeStatement = ParseStatement(); - } - return ifStatement; - } - - RefPtr Parser::ParseForStatement() - { - RefPtr scopeDecl = new ScopeDecl(); - - // HLSL implements the bad approach to scoping a `for` loop - // variable, and we want to respect that, but *only* when - // parsing HLSL code. - // - - bool brokenScoping = getSourceLanguage() == SourceLanguage::HLSL; - - // We will create a distinct syntax node class for the unscoped - // case, just so that we can correctly handle it in downstream - // logic. - // - RefPtr stmt; - if (brokenScoping) - { - stmt = new UnscopedForStmt(); - } - else - { - stmt = new ForStmt(); - } - - stmt->scopeDecl = scopeDecl; - - if(!brokenScoping) - pushScopeAndSetParent(scopeDecl.Ptr()); - FillPosition(stmt.Ptr()); - ReadToken("for"); - ReadToken(TokenType::LParent); - if (peekTypeName(this)) - { - stmt->InitialStatement = parseVarDeclrStatement(Modifiers()); - } - else - { - if (!LookAheadToken(TokenType::Semicolon)) - { - stmt->InitialStatement = ParseExpressionStatement(); - } - else - { - ReadToken(TokenType::Semicolon); - } - } - if (!LookAheadToken(TokenType::Semicolon)) - stmt->PredicateExpression = ParseExpression(); - ReadToken(TokenType::Semicolon); - if (!LookAheadToken(TokenType::RParent)) - stmt->SideEffectExpression = ParseExpression(); - ReadToken(TokenType::RParent); - stmt->Statement = ParseStatement(); - - if (!brokenScoping) - PopScope(); - - return stmt; - } - - RefPtr Parser::ParseWhileStatement() - { - RefPtr whileStatement = new WhileStmt(); - FillPosition(whileStatement.Ptr()); - ReadToken("while"); - ReadToken(TokenType::LParent); - whileStatement->Predicate = ParseExpression(); - ReadToken(TokenType::RParent); - whileStatement->Statement = ParseStatement(); - return whileStatement; - } - - RefPtr Parser::ParseDoWhileStatement() - { - RefPtr doWhileStatement = new DoWhileStmt(); - FillPosition(doWhileStatement.Ptr()); - ReadToken("do"); - doWhileStatement->Statement = ParseStatement(); - ReadToken("while"); - ReadToken(TokenType::LParent); - doWhileStatement->Predicate = ParseExpression(); - ReadToken(TokenType::RParent); - ReadToken(TokenType::Semicolon); - return doWhileStatement; - } - - RefPtr Parser::ParseBreakStatement() - { - RefPtr breakStatement = new BreakStmt(); - FillPosition(breakStatement.Ptr()); - ReadToken("break"); - ReadToken(TokenType::Semicolon); - return breakStatement; - } - - RefPtr Parser::ParseContinueStatement() - { - RefPtr continueStatement = new ContinueStmt(); - FillPosition(continueStatement.Ptr()); - ReadToken("continue"); - ReadToken(TokenType::Semicolon); - return continueStatement; - } - - RefPtr Parser::ParseReturnStatement() - { - RefPtr returnStatement = new ReturnStmt(); - FillPosition(returnStatement.Ptr()); - ReadToken("return"); - if (!LookAheadToken(TokenType::Semicolon)) - returnStatement->Expression = ParseExpression(); - ReadToken(TokenType::Semicolon); - return returnStatement; - } - - RefPtr Parser::ParseExpressionStatement() - { - RefPtr statement = new ExpressionStmt(); - - FillPosition(statement.Ptr()); - statement->Expression = ParseExpression(); - - ReadToken(TokenType::Semicolon); - return statement; - } - - RefPtr Parser::ParseParameter() - { - RefPtr parameter = new ParamDecl(); - parameter->modifiers = ParseModifiers(this); - - DeclaratorInfo declaratorInfo; - declaratorInfo.typeSpec = ParseType(); - - InitDeclarator initDeclarator = ParseInitDeclarator(this); - UnwrapDeclarator(initDeclarator, &declaratorInfo); - - // Assume it is a variable-like declarator - CompleteVarDecl(this, parameter, declaratorInfo); - return parameter; - } - - RefPtr Parser::ParseType() - { - auto typeSpec = parseTypeSpec(this); - if( typeSpec.decl ) - { - AddMember(currentScope, typeSpec.decl); - } - auto typeExpr = typeSpec.expr; - - typeExpr = parsePostfixTypeSuffix(this, typeExpr); - - return typeExpr; - } - - - - TypeExp Parser::ParseTypeExp() - { - return TypeExp(ParseType()); - } - - enum class Associativity - { - Left, Right - }; - - - - Associativity GetAssociativityFromLevel(Precedence level) - { - if (level == Precedence::Assignment) - return Associativity::Right; - else - return Associativity::Left; - } - - - - - Precedence GetOpLevel(Parser* parser, TokenType type) - { - switch(type) - { - case TokenType::QuestionMark: - return Precedence::TernaryConditional; - case TokenType::Comma: - return Precedence::Comma; - case TokenType::OpAssign: - case TokenType::OpMulAssign: - case TokenType::OpDivAssign: - case TokenType::OpAddAssign: - case TokenType::OpSubAssign: - case TokenType::OpModAssign: - case TokenType::OpShlAssign: - case TokenType::OpShrAssign: - case TokenType::OpOrAssign: - case TokenType::OpAndAssign: - case TokenType::OpXorAssign: - return Precedence::Assignment; - case TokenType::OpOr: - return Precedence::LogicalOr; - case TokenType::OpAnd: - return Precedence::LogicalAnd; - case TokenType::OpBitOr: - return Precedence::BitOr; - case TokenType::OpBitXor: - return Precedence::BitXor; - case TokenType::OpBitAnd: - return Precedence::BitAnd; - case TokenType::OpEql: - case TokenType::OpNeq: - return Precedence::EqualityComparison; - case TokenType::OpGreater: - case TokenType::OpGeq: - // Don't allow these ops inside a generic argument - if (parser->genericDepth > 0) return Precedence::Invalid; - ; // fall-thru - case TokenType::OpLeq: - case TokenType::OpLess: - return Precedence::RelationalComparison; - case TokenType::OpRsh: - // Don't allow this op inside a generic argument - if (parser->genericDepth > 0) return Precedence::Invalid; - ; // fall-thru - case TokenType::OpLsh: - return Precedence::BitShift; - case TokenType::OpAdd: - case TokenType::OpSub: - return Precedence::Additive; - case TokenType::OpMul: - case TokenType::OpDiv: - case TokenType::OpMod: - return Precedence::Multiplicative; - default: - return Precedence::Invalid; - } - } - - static RefPtr parseOperator(Parser* parser) - { - Token opToken; - switch(parser->tokenReader.PeekTokenType()) - { - case TokenType::QuestionMark: - opToken = parser->ReadToken(); - opToken.Content = UnownedStringSlice::fromLiteral("?:"); - break; - - default: - opToken = parser->ReadToken(); - break; - } - - auto opExpr = new VarExpr(); - opExpr->name = getName(parser, opToken.Content); - opExpr->scope = parser->currentScope; - opExpr->loc = opToken.loc; - - return opExpr; - - } - - static RefPtr createInfixExpr( - Parser* /*parser*/, - RefPtr left, - RefPtr op, - RefPtr right) - { - RefPtr expr = new InfixExpr(); - expr->loc = op->loc; - expr->FunctionExpr = op; - expr->Arguments.add(left); - expr->Arguments.add(right); - return expr; - } - - static RefPtr parseInfixExprWithPrecedence( - Parser* parser, - RefPtr inExpr, - Precedence prec) - { - auto expr = inExpr; - for(;;) - { - auto opTokenType = parser->tokenReader.PeekTokenType(); - auto opPrec = GetOpLevel(parser, opTokenType); - if(opPrec < prec) - break; - - auto op = parseOperator(parser); - - // Special case the `?:` operator since it is the - // one non-binary case we need to deal with. - if(opTokenType == TokenType::QuestionMark) - { - RefPtr select = new SelectExpr(); - select->loc = op->loc; - select->FunctionExpr = op; - - select->Arguments.add(expr); - - select->Arguments.add(parser->ParseExpression(opPrec)); - parser->ReadToken(TokenType::Colon); - select->Arguments.add(parser->ParseExpression(opPrec)); - - expr = select; - continue; - } - - auto right = parser->ParseLeafExpression(); - - for(;;) - { - auto nextOpPrec = GetOpLevel(parser, parser->tokenReader.PeekTokenType()); - - if((GetAssociativityFromLevel(nextOpPrec) == Associativity::Right) ? (nextOpPrec < opPrec) : (nextOpPrec <= opPrec)) - break; - - right = parseInfixExprWithPrecedence(parser, right, nextOpPrec); - } - - if (opTokenType == TokenType::OpAssign) - { - RefPtr assignExpr = new AssignExpr(); - assignExpr->loc = op->loc; - assignExpr->left = expr; - assignExpr->right = right; - - expr = assignExpr; - } - else - { - expr = createInfixExpr(parser, expr, op, right); - } - } - return expr; - } - - RefPtr Parser::ParseExpression(Precedence level) - { - auto expr = ParseLeafExpression(); - return parseInfixExprWithPrecedence(this, expr, level); - -#if 0 - - if (level == Precedence::Prefix) - return ParseLeafExpression(); - if (level == Precedence::TernaryConditional) - { - // parse select clause - auto condition = ParseExpression(Precedence(level + 1)); - if (LookAheadToken(TokenType::QuestionMark)) - { - RefPtr select = new SelectExpr(); - FillPosition(select.Ptr()); - - select->Arguments.Add(condition); - - select->FunctionExpr = parseOperator(this); - - select->Arguments.Add(ParseExpression(level)); - ReadToken(TokenType::Colon); - select->Arguments.Add(ParseExpression(level)); - return select; - } - else - return condition; - } - else - { - if (GetAssociativityFromLevel(level) == Associativity::Left) - { - auto left = ParseExpression(Precedence(level + 1)); - while (GetOpLevel(this, tokenReader.PeekTokenType()) == level) - { - RefPtr tmp = new InfixExpr(); - tmp->FunctionExpr = parseOperator(this); - - tmp->Arguments.Add(left); - FillPosition(tmp.Ptr()); - tmp->Arguments.Add(ParseExpression(Precedence(level + 1))); - left = tmp; - } - return left; - } - else - { - auto left = ParseExpression(Precedence(level + 1)); - if (GetOpLevel(this, tokenReader.PeekTokenType()) == level) - { - RefPtr tmp = new InfixExpr(); - tmp->Arguments.Add(left); - FillPosition(tmp.Ptr()); - tmp->FunctionExpr = parseOperator(this); - tmp->Arguments.Add(ParseExpression(level)); - left = tmp; - } - return left; - } - } -#endif - } - - // We *might* be looking at an application of a generic to arguments, - // but we need to disambiguate to make sure. - static RefPtr maybeParseGenericApp( - Parser* parser, - - // TODO: need to support more general expressions here - RefPtr base) - { - if(peekTokenType(parser) != TokenType::OpLess) - return base; - return tryParseGenericApp(parser, base); - } - - static RefPtr parsePrefixExpr(Parser* parser); - - // Parse OOP `this` expression syntax - static RefPtr parseThisExpr(Parser* parser, void* /*userData*/) - { - RefPtr expr = new ThisExpr(); - expr->scope = parser->currentScope; - return expr; - } - - static RefPtr parseBoolLitExpr(Parser* /*parser*/, bool value) - { - RefPtr expr = new BoolLiteralExpr(); - expr->value = value; - return expr; - } - - static RefPtr parseTrueExpr(Parser* parser, void* /*userData*/) - { - return parseBoolLitExpr(parser, true); - } - - static RefPtr parseFalseExpr(Parser* parser, void* /*userData*/) - { - return parseBoolLitExpr(parser, false); - } - - static RefPtr parseAtomicExpr(Parser* parser) - { - switch( peekTokenType(parser) ) - { - default: - // TODO: should this return an error expression instead of NULL? - parser->sink->diagnose(parser->tokenReader.PeekLoc(), Diagnostics::syntaxError); - return nullptr; - - // Either: - // - parenthized expression `(exp)` - // - cast `(type) exp` - // - // Proper disambiguation requires mixing up parsing - // and semantic checking (which we should do eventually) - // but for now we will follow some heuristics. - case TokenType::LParent: - { - Token openParen = parser->ReadToken(TokenType::LParent); - - if (peekTypeName(parser) && parser->LookAheadToken(TokenType::RParent, 1)) - { - RefPtr tcexpr = new ExplicitCastExpr(); - parser->FillPosition(tcexpr.Ptr()); - tcexpr->FunctionExpr = parser->ParseType(); - parser->ReadToken(TokenType::RParent); - - auto arg = parsePrefixExpr(parser); - tcexpr->Arguments.add(arg); - - return tcexpr; - } - else - { - RefPtr base = parser->ParseExpression(); - parser->ReadToken(TokenType::RParent); - - RefPtr parenExpr = new ParenExpr(); - parenExpr->loc = openParen.loc; - parenExpr->base = base; - return parenExpr; - } - } - - // An initializer list `{ expr, ... }` - case TokenType::LBrace: - { - RefPtr initExpr = new InitializerListExpr(); - parser->FillPosition(initExpr.Ptr()); - - // Initializer list - parser->ReadToken(TokenType::LBrace); - - List> exprs; - - for(;;) - { - if(AdvanceIfMatch(parser, TokenType::RBrace)) - break; - - auto expr = parser->ParseArgExpr(); - if( expr ) - { - initExpr->args.add(expr); - } - - if(AdvanceIfMatch(parser, TokenType::RBrace)) - break; - - parser->ReadToken(TokenType::Comma); - } - - return initExpr; - } - - case TokenType::IntegerLiteral: - { - RefPtr constExpr = new IntegerLiteralExpr(); - parser->FillPosition(constExpr.Ptr()); - - auto token = parser->tokenReader.AdvanceToken(); - constExpr->token = token; - - UnownedStringSlice suffix; - IntegerLiteralValue value = getIntegerLiteralValue(token, &suffix); - - // Look at any suffix on the value - char const* suffixCursor = suffix.begin(); - const char*const suffixEnd = suffix.end(); - - RefPtr suffixType = nullptr; - if( suffixCursor < suffixEnd ) - { - int lCount = 0; - int uCount = 0; - int unknownCount = 0; - while(suffixCursor < suffixEnd) - { - switch( *suffixCursor++ ) - { - case 'l': case 'L': - lCount++; - break; - - case 'u': case 'U': - uCount++; - break; - - default: - unknownCount++; - break; - } - } - - if(unknownCount) - { - parser->sink->diagnose(token, Diagnostics::invalidIntegerLiteralSuffix, suffix); - suffixType = parser->getSession()->getErrorType(); - } - // `u` or `ul` suffix -> `uint` - else if(uCount == 1 && (lCount <= 1)) - { - suffixType = parser->getSession()->getUIntType(); - } - // `l` suffix on integer -> `int` (== `long`) - else if(lCount == 1 && !uCount) - { - suffixType = parser->getSession()->getIntType(); - } - // `ull` suffix -> `uint64_t` - else if(uCount == 1 && lCount == 2) - { - suffixType = parser->getSession()->getUInt64Type(); - } - // `ll` suffix -> `int64_t` - else if(uCount == 0 && lCount == 2) - { - suffixType = parser->getSession()->getInt64Type(); - } - // TODO: do we need suffixes for smaller integer types? - else - { - parser->sink->diagnose(token, Diagnostics::invalidIntegerLiteralSuffix, suffix); - suffixType = parser->getSession()->getErrorType(); - } - } - - constExpr->value = value; - constExpr->type = QualType(suffixType); - - return constExpr; - } - - - case TokenType::FloatingPointLiteral: - { - RefPtr constExpr = new FloatingPointLiteralExpr(); - parser->FillPosition(constExpr.Ptr()); - - auto token = parser->tokenReader.AdvanceToken(); - constExpr->token = token; - - UnownedStringSlice suffix; - FloatingPointLiteralValue value = getFloatingPointLiteralValue(token, &suffix); - - // Look at any suffix on the value - char const* suffixCursor = suffix.begin(); - const char*const suffixEnd = suffix.end(); - - RefPtr suffixType = nullptr; - if( suffixCursor < suffixEnd ) - { - int fCount = 0; - int lCount = 0; - int hCount = 0; - int unknownCount = 0; - while(suffixCursor < suffixEnd) - { - switch( *suffixCursor++ ) - { - case 'f': case 'F': - fCount++; - break; - - case 'l': case 'L': - lCount++; - break; - - case 'h': case 'H': - hCount++; - break; - - default: - unknownCount++; - break; - } - } - - if (unknownCount) - { - parser->sink->diagnose(token, Diagnostics::invalidFloatingPointLiteralSuffix, suffix); - suffixType = parser->getSession()->getErrorType(); - } - // `f` suffix -> `float` - if(fCount == 1 && !lCount) - { - suffixType = parser->getSession()->getFloatType(); - } - // `l` or `lf` suffix on floating-point literal -> `double` - else if(lCount == 1 && (fCount <= 1)) - { - suffixType = parser->getSession()->getDoubleType(); - } - // `h` or `hf` suffix on floating-point literal -> `half` - else if(lCount == 1 && (fCount <= 1)) - { - suffixType = parser->getSession()->getHalfType(); - } - // TODO: are there other suffixes we need to handle? - else - { - parser->sink->diagnose(token, Diagnostics::invalidFloatingPointLiteralSuffix, suffix); - suffixType = parser->getSession()->getErrorType(); - } - } - - constExpr->value = value; - constExpr->type = QualType(suffixType); - - return constExpr; - } - - case TokenType::StringLiteral: - { - RefPtr constExpr = new StringLiteralExpr(); - auto token = parser->tokenReader.AdvanceToken(); - constExpr->token = token; - parser->FillPosition(constExpr.Ptr()); - - if (!parser->LookAheadToken(TokenType::StringLiteral)) - { - // Easy/common case: a single string - constExpr->value = getStringLiteralTokenValue(token); - } - else - { - StringBuilder sb; - sb << getStringLiteralTokenValue(token); - while (parser->LookAheadToken(TokenType::StringLiteral)) - { - token = parser->tokenReader.AdvanceToken(); - sb << getStringLiteralTokenValue(token); - } - constExpr->value = sb.ProduceString(); - } - - return constExpr; - } - - case TokenType::Identifier: - { - // We will perform name lookup here so that we can find syntax - // keywords registered for use as expressions. - Token nameToken = peekToken(parser); - - RefPtr parsedExpr; - if (tryParseUsingSyntaxDecl(parser, &parsedExpr)) - { - if (!parsedExpr->loc.isValid()) - { - parsedExpr->loc = nameToken.loc; - } - return parsedExpr; - } - - // Default behavior is just to create a name expression - RefPtr varExpr = new VarExpr(); - varExpr->scope = parser->currentScope.Ptr(); - parser->FillPosition(varExpr.Ptr()); - - auto nameAndLoc = expectIdentifier(parser); - varExpr->name = nameAndLoc.name; - - if(peekTokenType(parser) == TokenType::OpLess) - { - return maybeParseGenericApp(parser, varExpr); - } - - return varExpr; - } - } - } - - static RefPtr parsePostfixExpr(Parser* parser) - { - auto expr = parseAtomicExpr(parser); - for(;;) - { - switch( peekTokenType(parser) ) - { - default: - return expr; - - // Postfix increment/decrement - case TokenType::OpInc: - case TokenType::OpDec: - { - RefPtr postfixExpr = new PostfixExpr(); - parser->FillPosition(postfixExpr.Ptr()); - postfixExpr->FunctionExpr = parseOperator(parser); - postfixExpr->Arguments.add(expr); - - expr = postfixExpr; - } - break; - - // Subscript operation `a[i]` - case TokenType::LBracket: - { - RefPtr indexExpr = new IndexExpr(); - indexExpr->BaseExpression = expr; - parser->FillPosition(indexExpr.Ptr()); - parser->ReadToken(TokenType::LBracket); - // TODO: eventually we may want to support multiple arguments inside the `[]` - if (!parser->LookAheadToken(TokenType::RBracket)) - { - indexExpr->IndexExpression = parser->ParseExpression(); - } - parser->ReadToken(TokenType::RBracket); - - expr = indexExpr; - } - break; - - // Call oepration `f(x)` - case TokenType::LParent: - { - RefPtr invokeExpr = new InvokeExpr(); - invokeExpr->FunctionExpr = expr; - parser->FillPosition(invokeExpr.Ptr()); - parser->ReadToken(TokenType::LParent); - while (!parser->tokenReader.IsAtEnd()) - { - if (!parser->LookAheadToken(TokenType::RParent)) - invokeExpr->Arguments.add(parser->ParseArgExpr()); - else - { - break; - } - if (!parser->LookAheadToken(TokenType::Comma)) - break; - parser->ReadToken(TokenType::Comma); - } - parser->ReadToken(TokenType::RParent); - - expr = invokeExpr; - } - break; - - // Scope access `x::m` - case TokenType::Scope: - { - RefPtr staticMemberExpr = new StaticMemberExpr(); - - // TODO(tfoley): why would a member expression need this? - staticMemberExpr->scope = parser->currentScope.Ptr(); - - parser->FillPosition(staticMemberExpr.Ptr()); - staticMemberExpr->BaseExpression = expr; - parser->ReadToken(TokenType::Scope); - staticMemberExpr->name = expectIdentifier(parser).name; - - if (peekTokenType(parser) == TokenType::OpLess) - expr = maybeParseGenericApp(parser, staticMemberExpr); - else - expr = staticMemberExpr; - - break; - } - // Member access `x.m` - case TokenType::Dot: - { - RefPtr memberExpr = new MemberExpr(); - - // TODO(tfoley): why would a member expression need this? - memberExpr->scope = parser->currentScope.Ptr(); - - parser->FillPosition(memberExpr.Ptr()); - memberExpr->BaseExpression = expr; - parser->ReadToken(TokenType::Dot); - memberExpr->name = expectIdentifier(parser).name; - - if (peekTokenType(parser) == TokenType::OpLess) - expr = maybeParseGenericApp(parser, memberExpr); - else - expr = memberExpr; - } - break; - } - } - } - - static RefPtr parsePrefixExpr(Parser* parser) - { - switch( peekTokenType(parser) ) - { - default: - return parsePostfixExpr(parser); - - case TokenType::OpInc: - case TokenType::OpDec: - case TokenType::OpNot: - case TokenType::OpBitNot: - case TokenType::OpAdd: - case TokenType::OpSub: - { - RefPtr prefixExpr = new PrefixExpr(); - parser->FillPosition(prefixExpr.Ptr()); - prefixExpr->FunctionExpr = parseOperator(parser); - prefixExpr->Arguments.add(parsePrefixExpr(parser)); - return prefixExpr; - } - break; - } - } - - RefPtr Parser::ParseLeafExpression() - { - return parsePrefixExpr(this); - } - - RefPtr parseTypeFromSourceFile( - Session* session, - TokenSpan const& tokens, - DiagnosticSink* sink, - RefPtr const& outerScope, - NamePool* namePool, - SourceLanguage sourceLanguage) - { - Parser parser(session, tokens, sink, outerScope); - parser.currentScope = outerScope; - parser.namePool = namePool; - parser.sourceLanguage = sourceLanguage; - return parser.ParseType(); - } - - // Parse a source file into an existing translation unit - void parseSourceFile( - TranslationUnitRequest* translationUnit, - TokenSpan const& tokens, - DiagnosticSink* sink, - RefPtr const& outerScope) - { - Parser parser(translationUnit->getSession(), tokens, sink, outerScope); - parser.namePool = translationUnit->getNamePool(); - parser.sourceLanguage = translationUnit->sourceLanguage; - - return parser.parseSourceFile(translationUnit->getModuleDecl()); - } - - static void addBuiltinSyntaxImpl( - Session* session, - Scope* scope, - char const* nameText, - SyntaxParseCallback callback, - void* userData, - SyntaxClass syntaxClass) - { - Name* name = session->getNamePool()->getName(nameText); - - RefPtr syntaxDecl = new SyntaxDecl(); - syntaxDecl->nameAndLoc = NameLoc(name); - syntaxDecl->syntaxClass = syntaxClass; - syntaxDecl->parseCallback = callback; - syntaxDecl->parseUserData = userData; - - AddMember(scope, syntaxDecl); - } - - template - static void addBuiltinSyntax( - Session* session, - Scope* scope, - char const* name, - SyntaxParseCallback callback, - void* userData = nullptr) - { - addBuiltinSyntaxImpl(session, scope, name, callback, userData, getClass()); - } - - template - static void addSimpleModifierSyntax( - Session* session, - Scope* scope, - char const* name) - { - auto syntaxClass = getClass(); - addBuiltinSyntaxImpl(session, scope, name, &parseSimpleSyntax, (void*) syntaxClass.classInfo, getClass()); - } - - static RefPtr parseIntrinsicOpModifier(Parser* parser, void* /*userData*/) - { - RefPtr modifier = new IntrinsicOpModifier(); - - // We allow a few difference forms here: - // - // First, we can specify the intrinsic op `enum` value directly: - // - // __intrinsic_op() - // - // Second, we can specify the operation by name: - // - // __intrinsic_op() - // - // Finally, we can leave off the specification, so that the - // op name will be derived from the function name: - // - // __intrinsic_op - // - if (AdvanceIf(parser, TokenType::LParent)) - { - if (AdvanceIf(parser, TokenType::OpSub)) - { - modifier->op = IROp(-StringToInt(parser->ReadToken().Content)); - } - else if (parser->LookAheadToken(TokenType::IntegerLiteral)) - { - modifier->op = IROp(StringToInt(parser->ReadToken().Content)); - } - else - { - modifier->opToken = parser->ReadToken(TokenType::Identifier); - - modifier->op = findIROp(modifier->opToken.Content); - - if (modifier->op == kIROp_Invalid) - { - parser->sink->diagnose(modifier->opToken, Diagnostics::unimplemented, "unknown intrinsic op"); - } - } - - parser->ReadToken(TokenType::RParent); - } - - - return modifier; - } - - static RefPtr parseTargetIntrinsicModifier(Parser* parser, void* /*userData*/) - { - auto modifier = new TargetIntrinsicModifier(); - - if (AdvanceIf(parser, TokenType::LParent)) - { - modifier->targetToken = parser->ReadToken(TokenType::Identifier); - - if( AdvanceIf(parser, TokenType::Comma) ) - { - if( parser->LookAheadToken(TokenType::StringLiteral) ) - { - modifier->definitionToken = parser->ReadToken(); - } - else - { - modifier->definitionToken = parser->ReadToken(TokenType::Identifier); - } - } - - parser->ReadToken(TokenType::RParent); - } - - return modifier; - } - - static RefPtr parseSpecializedForTargetModifier(Parser* parser, void* /*userData*/) - { - auto modifier = new SpecializedForTargetModifier(); - if (AdvanceIf(parser, TokenType::LParent)) - { - modifier->targetToken = parser->ReadToken(TokenType::Identifier); - parser->ReadToken(TokenType::RParent); - } - return modifier; - } - - static RefPtr parseGLSLExtensionModifier(Parser* parser, void* /*userData*/) - { - auto modifier = new RequiredGLSLExtensionModifier(); - - parser->ReadToken(TokenType::LParent); - modifier->extensionNameToken = parser->ReadToken(TokenType::Identifier); - parser->ReadToken(TokenType::RParent); - - return modifier; - } - - static RefPtr parseGLSLVersionModifier(Parser* parser, void* /*userData*/) - { - auto modifier = new RequiredGLSLVersionModifier(); - - parser->ReadToken(TokenType::LParent); - modifier->versionNumberToken = parser->ReadToken(TokenType::IntegerLiteral); - parser->ReadToken(TokenType::RParent); - - return modifier; - } - - static RefPtr parseLayoutModifier(Parser* parser, void* /*userData*/) - { - ModifierListBuilder listBuilder; - - listBuilder.add(new GLSLLayoutModifierGroupBegin()); - - parser->ReadToken(TokenType::LParent); - while (!AdvanceIfMatch(parser, TokenType::RParent)) - { - auto nameAndLoc = expectIdentifier(parser); - const String& nameText = nameAndLoc.name->text; - - if (nameText == "binding" || - nameText == "set") - { - GLSLBindingAttribute* attr = listBuilder.find(); - if (!attr) - { - attr = new GLSLBindingAttribute(); - listBuilder.add(attr); - } - - parser->ReadToken(TokenType::OpAssign); - - // If the token asked for is not returned found will put in recovering state, and return token found - Token valToken = parser->ReadToken(TokenType::IntegerLiteral); - // If wasn't the desired IntegerLiteral return that couldn't parse - if (valToken.type != TokenType::IntegerLiteral) - { - return nullptr; - } - - // Work out the value - auto value = getIntegerLiteralValue(valToken); - - if (nameText == "binding") - { - attr->binding = int32_t(value); - } - else - { - attr->set = int32_t(value); - } - } - else - { - RefPtr modifier; - -#define CASE(key, type) if (nameText == #key) { modifier = new type; } else - CASE(push_constant, PushConstantAttribute) - CASE(shaderRecordNV, ShaderRecordAttribute) - CASE(constant_id, GLSLConstantIDLayoutModifier) - CASE(location, GLSLLocationLayoutModifier) - CASE(local_size_x, GLSLLocalSizeXLayoutModifier) - CASE(local_size_y, GLSLLocalSizeYLayoutModifier) - CASE(local_size_z, GLSLLocalSizeZLayoutModifier) - { - modifier = new GLSLUnparsedLayoutModifier(); - } - SLANG_ASSERT(modifier); -#undef CASE - - modifier->name = nameAndLoc.name; - modifier->loc = nameAndLoc.loc; - - // Special handling for GLSLLayoutModifier - if (auto glslModifier = as(modifier)) - { - if (AdvanceIf(parser, TokenType::OpAssign)) - { - glslModifier->valToken = parser->ReadToken(TokenType::IntegerLiteral); - } - } - - listBuilder.add(modifier); - } - - if (AdvanceIf(parser, TokenType::RParent)) - break; - parser->ReadToken(TokenType::Comma); - } - - listBuilder.add(new GLSLLayoutModifierGroupEnd()); - - return listBuilder.getFirst(); - } - - static RefPtr parseBuiltinTypeModifier(Parser* parser, void* /*userData*/) - { - RefPtr modifier = new BuiltinTypeModifier(); - parser->ReadToken(TokenType::LParent); - modifier->tag = BaseType(StringToInt(parser->ReadToken(TokenType::IntegerLiteral).Content)); - parser->ReadToken(TokenType::RParent); - - return modifier; - } - - static RefPtr parseMagicTypeModifier(Parser* parser, void* /*userData*/) - { - RefPtr modifier = new MagicTypeModifier(); - parser->ReadToken(TokenType::LParent); - modifier->name = parser->ReadToken(TokenType::Identifier).Content; - if (AdvanceIf(parser, TokenType::Comma)) - { - modifier->tag = uint32_t(StringToInt(parser->ReadToken(TokenType::IntegerLiteral).Content)); - } - parser->ReadToken(TokenType::RParent); - - return modifier; - } - - static RefPtr parseIntrinsicTypeModifier(Parser* parser, void* /*userData*/) - { - RefPtr modifier = new IntrinsicTypeModifier(); - parser->ReadToken(TokenType::LParent); - modifier->irOp = uint32_t(StringToInt(parser->ReadToken(TokenType::IntegerLiteral).Content)); - while( AdvanceIf(parser, TokenType::Comma) ) - { - auto operand = uint32_t(StringToInt(parser->ReadToken(TokenType::IntegerLiteral).Content)); - modifier->irOperands.add(operand); - } - parser->ReadToken(TokenType::RParent); - - return modifier; - } - static RefPtr parseImplicitConversionModifier(Parser* parser, void* /*userData*/) - { - RefPtr modifier = new ImplicitConversionModifier(); - - ConversionCost cost = kConversionCost_Default; - if( AdvanceIf(parser, TokenType::LParent) ) - { - cost = ConversionCost(StringToInt(parser->ReadToken(TokenType::IntegerLiteral).Content)); - parser->ReadToken(TokenType::RParent); - } - modifier->cost = cost; - return modifier; - } - - static RefPtr parseAttributeTargetModifier(Parser* parser, void* /*userData*/) - { - expect(parser, TokenType::LParent); - auto syntaxClassNameAndLoc = expectIdentifier(parser); - expect(parser, TokenType::RParent); - - auto syntaxClass = parser->getSession()->findSyntaxClass(syntaxClassNameAndLoc.name); - - RefPtr modifier = new AttributeTargetModifier(); - modifier->syntaxClass = syntaxClass; - - return modifier; - } - - RefPtr populateBaseLanguageModule( - Session* session, - RefPtr scope) - { - RefPtr moduleDecl = new ModuleDecl(); - scope->containerDecl = moduleDecl; - - // Add syntax for declaration keywords - #define DECL(KEYWORD, CALLBACK) \ - addBuiltinSyntax(session, scope, #KEYWORD, &CALLBACK) - DECL(typedef, ParseTypeDef); - DECL(associatedtype, parseAssocType); - DECL(type_param, parseGlobalGenericParamDecl); - DECL(cbuffer, parseHLSLCBufferDecl); - DECL(tbuffer, parseHLSLTBufferDecl); - DECL(__generic, ParseGenericDecl); - DECL(__extension, ParseExtensionDecl); - DECL(extension, ParseExtensionDecl); - DECL(__init, parseConstructorDecl); - DECL(__subscript, ParseSubscriptDecl); - DECL(interface, parseInterfaceDecl); - DECL(syntax, parseSyntaxDecl); - DECL(attribute_syntax,parseAttributeSyntaxDecl); - DECL(__import, parseImportDecl); - DECL(import, parseImportDecl); - DECL(let, parseLetDecl); - DECL(var, parseVarDecl); - DECL(func, parseFuncDecl); - DECL(typealias, parseTypeAliasDecl); - - #undef DECL - - // Add syntax for "simple" modifier keywords. - // These are the ones that just appear as a single - // keyword (no further tokens expected/allowed), - // and which can be represented just by creating - // a new AST node of the corresponding type. - #define MODIFIER(KEYWORD, CLASS) \ - addSimpleModifierSyntax(session, scope, #KEYWORD) - - MODIFIER(in, InModifier); - MODIFIER(input, InputModifier); - MODIFIER(out, OutModifier); - MODIFIER(inout, InOutModifier); - MODIFIER(__ref, RefModifier); - MODIFIER(const, ConstModifier); - MODIFIER(instance, InstanceModifier); - MODIFIER(__builtin, BuiltinModifier); - - MODIFIER(inline, InlineModifier); - MODIFIER(public, PublicModifier); - MODIFIER(require, RequireModifier); - MODIFIER(param, ParamModifier); - MODIFIER(extern, ExternModifier); - - MODIFIER(row_major, HLSLRowMajorLayoutModifier); - MODIFIER(column_major, HLSLColumnMajorLayoutModifier); - - MODIFIER(nointerpolation, HLSLNoInterpolationModifier); - MODIFIER(noperspective, HLSLNoPerspectiveModifier); - MODIFIER(linear, HLSLLinearModifier); - MODIFIER(sample, HLSLSampleModifier); - MODIFIER(centroid, HLSLCentroidModifier); - MODIFIER(precise, PreciseModifier); - MODIFIER(shared, HLSLEffectSharedModifier); - MODIFIER(groupshared, HLSLGroupSharedModifier); - MODIFIER(static, HLSLStaticModifier); - MODIFIER(uniform, HLSLUniformModifier); - MODIFIER(volatile, HLSLVolatileModifier); - - // Modifiers for geometry shader input - MODIFIER(point, HLSLPointModifier); - MODIFIER(line, HLSLLineModifier); - MODIFIER(triangle, HLSLTriangleModifier); - MODIFIER(lineadj, HLSLLineAdjModifier); - MODIFIER(triangleadj, HLSLTriangleAdjModifier); - - // Modifiers for unary operator declarations - MODIFIER(__prefix, PrefixModifier); - MODIFIER(__postfix, PostfixModifier); - - // Modifier to apply to `import` that should be re-exported - MODIFIER(__exported, ExportedModifier); - - #undef MODIFIER - - // Add syntax for more complex modifiers, which allow - // or expect more tokens after the initial keyword. - #define MODIFIER(KEYWORD, CALLBACK) \ - addBuiltinSyntax(session, scope, #KEYWORD, &CALLBACK) - - MODIFIER(layout, parseLayoutModifier); - - MODIFIER(__intrinsic_op, parseIntrinsicOpModifier); - MODIFIER(__target_intrinsic, parseTargetIntrinsicModifier); - MODIFIER(__specialized_for_target, parseSpecializedForTargetModifier); - MODIFIER(__glsl_extension, parseGLSLExtensionModifier); - MODIFIER(__glsl_version, parseGLSLVersionModifier); - - MODIFIER(__builtin_type, parseBuiltinTypeModifier); - MODIFIER(__magic_type, parseMagicTypeModifier); - MODIFIER(__intrinsic_type, parseIntrinsicTypeModifier); - MODIFIER(__implicit_conversion, parseImplicitConversionModifier); - - MODIFIER(__attributeTarget, parseAttributeTargetModifier); - - -#undef MODIFIER - - // Add syntax for expression keywords - #define EXPR(KEYWORD, CALLBACK) \ - addBuiltinSyntax(session, scope, #KEYWORD, &CALLBACK) - - EXPR(this, parseThisExpr); - EXPR(true, parseTrueExpr); - EXPR(false, parseFalseExpr); - - #undef EXPR - - return moduleDecl; - } - -} diff --git a/source/slang/parser.h b/source/slang/parser.h deleted file mode 100644 index abad902da..000000000 --- a/source/slang/parser.h +++ /dev/null @@ -1,30 +0,0 @@ -#ifndef RASTER_RENDERER_PARSER_H -#define RASTER_RENDERER_PARSER_H - -#include "lexer.h" -#include "compiler.h" -#include "syntax.h" - -namespace Slang -{ - // Parse a source file into an existing translation unit - void parseSourceFile( - TranslationUnitRequest* translationUnit, - TokenSpan const& tokens, - DiagnosticSink* sink, - RefPtr const& outerScope); - - RefPtr parseTypeFromSourceFile( - Session* session, - TokenSpan const& tokens, - DiagnosticSink* sink, - RefPtr const& outerScope, - NamePool* namePool, - SourceLanguage sourceLanguage); - - RefPtr populateBaseLanguageModule( - Session* session, - RefPtr scope); -} - -#endif \ No newline at end of file diff --git a/source/slang/preprocessor.cpp b/source/slang/preprocessor.cpp deleted file mode 100644 index bf6f7b7ca..000000000 --- a/source/slang/preprocessor.cpp +++ /dev/null @@ -1,2302 +0,0 @@ -// preprocessor.cpp -#include "preprocessor.h" - -#include "compiler.h" -#include "diagnostics.h" -#include "lexer.h" -// Needed so that we can construct modifier syntax to represent GLSL directives -#include "syntax.h" - -#include - -// This file provides an implementation of a simple C-style preprocessor. -// It does not aim for 100% compatibility with any particular preprocessor -// specification, but the goal is to have it accept the most common -// idioms for using the preprocessor, found in shader code in the wild. - - -namespace Slang { - -// State of a preprocessor conditional, which can change when -// we encounter directives like `#elif` or `#endif` -enum class PreprocessorConditionalState -{ - Before, // We have not yet seen a branch with a `true` condition. - During, // We are inside the branch with a `true` condition. - After, // We have already seen the branch with a `true` condition. -}; - -// Represents a preprocessor conditional that we are currently -// nested inside. -struct PreprocessorConditional -{ - // The next outer conditional in the current file/stream, or NULL. - PreprocessorConditional* parent; - - // The directive token that started the conditional (an `#if` or `#ifdef`) - Token ifToken; - - // The `#else` directive token, if one has been seen (otherwise `TokenType::Unknown`) - Token elseToken; - - // The state of the conditional - PreprocessorConditionalState state; -}; - -struct PreprocessorMacro; - -struct PreprocessorEnvironment -{ - // The "outer" environment, to be used if lookup in this env fails - PreprocessorEnvironment* parent = NULL; - - // Macros defined in this environment - Dictionary macros; - - ~PreprocessorEnvironment(); -}; - -// Input tokens can either come from source text, or from macro expansion. -// In general, input streams can be nested, so we have to keep a conceptual -// stack of input. - -struct PrimaryInputStream; - -// A stream of input tokens to be consumed -struct PreprocessorInputStream -{ - // The primary input stream that is the parent to this one, - // or NULL if this stream is itself a primary stream. - PrimaryInputStream* primaryStream; - - // The next input stream up the stack, if any. - PreprocessorInputStream* parent; - - // Environment to use when looking up macros - PreprocessorEnvironment* environment; - - // Destructor is virtual so that we can clean up - // after concrete subtypes. - virtual ~PreprocessorInputStream() = default; -}; - -// A "primary" input stream represents the top-level context of a file -// being parsed, and tracks things like preprocessor conditional state -struct PrimaryInputStream : PreprocessorInputStream -{ - // The next *primary* input stream up the stack - PrimaryInputStream* parentPrimaryInputStream; - - // The deepest preprocessor conditional active for this stream. - PreprocessorConditional* conditional; - - // The lexer state that will provide input - Lexer lexer; - - // One token of lookahead - Token token; -}; - -// A "secondary" input stream represents code that is being expanded -// into the current scope, but which had already been tokenized before. -// -struct PretokenizedInputStream : PreprocessorInputStream -{ - // Reader for pre-tokenized input - TokenReader tokenReader; -}; - -// A pre-tokenized input stream that will only be used once, and which -// therefore owns the memory for its tokens. -struct SimpleTokenInputStream : PretokenizedInputStream -{ - // A list of raw tokens that will provide input - TokenList lexedTokens; -}; - -struct MacroExpansion : PretokenizedInputStream -{ - // The macro we will expand - PreprocessorMacro* macro; -}; - -struct ObjectLikeMacroExpansion : MacroExpansion -{ -}; - -struct FunctionLikeMacroExpansion : MacroExpansion -{ - // Environment for macro arguments - PreprocessorEnvironment argumentEnvironment; -}; - -// An enumeration for the diferent types of macros -enum class PreprocessorMacroFlavor -{ - ObjectLike, - FunctionArg, - FunctionLike, -}; - -// In the current design (which we may want to re-consider), -// a macro is a specialized flavor of input stream, that -// captures the token list in its expansion, and then -// can be "played back." -struct PreprocessorMacro -{ - // The name under which the macro was `#define`d - NameLoc nameAndLoc; - - // Parameters of the macro, in case of a function-like macro - List params; - - // The tokens that make up the macro body - TokenList tokens; - - // The flavor of macro - PreprocessorMacroFlavor flavor; - - // The environment in which this macro needs to be expanded. - // For ordinary macros this will be the global environment, - // while for function-like macro arguments, it will be - // the environment of the macro invocation. - PreprocessorEnvironment* environment; - - // - Name* getName() - { - return nameAndLoc.name; - } - - SourceLoc getLoc() - { - return nameAndLoc.loc; - } -}; - -// State of the preprocessor -struct Preprocessor -{ - // diagnostics sink to use when writing messages - DiagnosticSink* sink; - - // An external callback interface to use when looking - // for files in a `#include` directive - IncludeHandler* includeHandler; - - // Current input stream (top of the stack of input) - PreprocessorInputStream* inputStream; - - // Currently-defined macros - PreprocessorEnvironment globalEnv; - - // A pre-allocated token that can be returned to - // represent end-of-input situations. - Token endOfFileToken; - - /// The linkage the provides the context for preprocessing - Linkage* linkage = nullptr; - - /// The module, if any, that the preprocessed result will belong to - Module* parentModule = nullptr; - - // The unique identities of any paths that have issued `#pragma once` directives to - // stop them from being included again. - HashSet pragmaOnceUniqueIdentities; - - NamePool* getNamePool() { return linkage->getNamePool(); } - SourceManager* getSourceManager() { return linkage->getSourceManager(); } -}; - -// Convenience routine to access the diagnostic sink -static DiagnosticSink* GetSink(Preprocessor* preprocessor) -{ - return preprocessor->sink; -} - -// -// Forward declarations -// - -static void DestroyConditional(PreprocessorConditional* conditional); -static void DestroyMacro(Preprocessor* preprocessor, PreprocessorMacro* macro); -static bool IsSkipping(Preprocessor* preprocessor); - -// -// Basic Input Handling -// - -// Create a fresh input stream -static void initializeInputStream(Preprocessor* preprocessor, PreprocessorInputStream* inputStream) -{ - inputStream->parent = NULL; - inputStream->environment = &preprocessor->globalEnv; -} - -static void initializePrimaryInputStream(Preprocessor* preprocessor, PrimaryInputStream* inputStream) -{ - initializeInputStream(preprocessor, inputStream); - inputStream->primaryStream = inputStream; - inputStream->conditional = NULL; -} - -// Destroy an input stream -static void destroyInputStream(Preprocessor* /*preprocessor*/, PreprocessorInputStream* inputStream) -{ - delete inputStream; -} - -// Create an input stream to represent a pre-tokenized input file. -// TODO(tfoley): pre-tokenizing files isn't going to work in the long run. -static PreprocessorInputStream* CreateInputStreamForSource( - Preprocessor* preprocessor, - SourceView* sourceView) -{ - MemoryArena* memoryArena = sourceView->getSourceManager()->getMemoryArena(); - - PrimaryInputStream* inputStream = new PrimaryInputStream(); - initializePrimaryInputStream(preprocessor, inputStream); - - // initialize the embedded lexer so that it can generate a token stream - inputStream->lexer.initialize(sourceView, GetSink(preprocessor), preprocessor->getNamePool(), memoryArena); - inputStream->token = inputStream->lexer.lexToken(); - - return inputStream; -} - -static PrimaryInputStream* asPrimaryInputStream(PreprocessorInputStream* inputStream) -{ - auto primaryStream = inputStream->primaryStream; - if(primaryStream == inputStream) - return primaryStream; - return nullptr; -} - - -static void PushInputStream(Preprocessor* preprocessor, PreprocessorInputStream* inputStream) -{ - inputStream->parent = preprocessor->inputStream; - if(!asPrimaryInputStream(inputStream)) - inputStream->primaryStream = preprocessor->inputStream->primaryStream; - preprocessor->inputStream = inputStream; -} - -// Called when we reach the end of an input stream. -// Performs some validation and then destroys the input stream if required. -static void EndInputStream(Preprocessor* preprocessor, PreprocessorInputStream* inputStream) -{ - if(auto primaryStream = asPrimaryInputStream(inputStream)) - { - // If there are any conditionals that weren't completed, then it is an error - if (primaryStream->conditional) - { - PreprocessorConditional* conditional = primaryStream->conditional; - - GetSink(preprocessor)->diagnose(conditional->ifToken.loc, Diagnostics::endOfFileInPreprocessorConditional); - - while (conditional) - { - PreprocessorConditional* parent = conditional->parent; - DestroyConditional(conditional); - conditional = parent; - } - } - } - - destroyInputStream(preprocessor, inputStream); -} - -// Consume one token from an input stream -static Token AdvanceRawToken(PreprocessorInputStream* inputStream, LexerFlags lexerFlags = 0) -{ - if( auto primaryStream = asPrimaryInputStream(inputStream) ) - { - auto result = primaryStream->token; - primaryStream->token = primaryStream->lexer.lexToken(lexerFlags); - return result; - } - else - { - PretokenizedInputStream* pretokenized = (PretokenizedInputStream*) inputStream; - return pretokenized->tokenReader.AdvanceToken(); - } -} - -// Peek one token from an input stream -static Token PeekRawToken(PreprocessorInputStream* inputStream) -{ - if( auto primaryStream = asPrimaryInputStream(inputStream) ) - { - return primaryStream->token; - } - else - { - PretokenizedInputStream* pretokenized = (PretokenizedInputStream*) inputStream; - return pretokenized->tokenReader.PeekToken(); - } -} - -// Peek one token type from an input stream -static TokenType PeekRawTokenType(PreprocessorInputStream* inputStream) -{ - if( auto primaryStream = asPrimaryInputStream(inputStream) ) - { - return primaryStream->token.type; - } - else - { - PretokenizedInputStream* pretokenized = (PretokenizedInputStream*) inputStream; - return pretokenized->tokenReader.PeekTokenType(); - } -} - - -// Read one token in "raw" mode (meaning don't expand macros) -static Token AdvanceRawToken(Preprocessor* preprocessor, LexerFlags lexerFlags = 0) -{ - for(;;) - { - // Look at the input stream on top of the stack - PreprocessorInputStream* inputStream = preprocessor->inputStream; - - // If there isn't one, then there is no more input left to read. - if(!inputStream) - { - return preprocessor->endOfFileToken; - } - - // The top-most input stream may be at its end - if(PeekRawTokenType(inputStream) == TokenType::EndOfFile) - { - // If there is another stream remaining, switch to it - if(inputStream->parent) - { - preprocessor->inputStream = inputStream->parent; - EndInputStream(preprocessor, inputStream); - continue; - } - } - - // Everything worked, so read a token from the top-most stream - return AdvanceRawToken( - inputStream, - lexerFlags | (IsSkipping(preprocessor) ? kLexerFlag_IgnoreInvalid : 0)); - } -} - -// Return the next token in "raw" mode, but don't advance the -// current token state. -static Token PeekRawToken(Preprocessor* preprocessor) -{ - // We need to find the stream that `advanceRawToken` would read from. - PreprocessorInputStream* inputStream = preprocessor->inputStream; - for (;;) - { - if (!inputStream) - { - // No more input streams left to read - return preprocessor->endOfFileToken; - } - - // The top-most input stream may be at its end, so - // look one entry up the stack (don't actually pop - // here, since we are just peeking) - if (PeekRawTokenType(inputStream) == TokenType::EndOfFile) - { - if (inputStream->parent) - { - inputStream = inputStream->parent; - continue; - } - } - - // Everything worked, so the token we just peeked is fine. - return PeekRawToken(inputStream); - } -} - -// Get the location of the current (raw) token -static SourceLoc PeekLoc(Preprocessor* preprocessor) -{ - return PeekRawToken(preprocessor).loc; -} - -// Get the `TokenType` of the current (raw) token -static TokenType PeekRawTokenType(Preprocessor* preprocessor) -{ - return PeekRawToken(preprocessor).type; -} - -// -// Macros -// - -// Create a macro -static PreprocessorMacro* CreateMacro(Preprocessor* preprocessor) -{ - // TODO(tfoley): Allocate these more intelligently. - // For example, consider pooling them on the preprocessor. - - PreprocessorMacro* macro = new PreprocessorMacro(); - macro->flavor = PreprocessorMacroFlavor::ObjectLike; - macro->environment = &preprocessor->globalEnv; - return macro; -} - -// Destroy a macro -static void DestroyMacro(Preprocessor* /*preprocessor*/, PreprocessorMacro* macro) -{ - delete macro; -} - - -// Find the currently-defined macro of the given name, or return NULL -static PreprocessorMacro* LookupMacro(PreprocessorEnvironment* environment, Name* name) -{ - for(PreprocessorEnvironment* e = environment; e; e = e->parent) - { - PreprocessorMacro* macro = NULL; - if (e->macros.TryGetValue(name, macro)) - return macro; - } - - return NULL; -} - -static PreprocessorEnvironment* GetCurrentEnvironment(Preprocessor* preprocessor) -{ - // The environment we will use for looking up a macro is associated - // with the current input stream (because it may include entries - // for macro arguments). - // - // We need to be careful, though, when we are at the end of an - // input stream (e.g., representing one argument), so that we - // don't use its environment. - - PreprocessorInputStream* inputStream = preprocessor->inputStream; - - for(;;) - { - // If there is no input stream that isn't at its end, - // then fall back to the global environment. - if (!inputStream) - return &preprocessor->globalEnv; - - // If the current input stream is at its end, then - // fall back to its parent stream. - if (PeekRawTokenType(inputStream) == TokenType::EndOfFile) - { - inputStream = inputStream->parent; - continue; - } - - // If we've found an active stream that isn't at its end, - // then use that for lookup. - return inputStream->environment; - } -} - -static PreprocessorMacro* LookupMacro(Preprocessor* preprocessor, Name* name) -{ - return LookupMacro(GetCurrentEnvironment(preprocessor), name); -} - -// A macro is "busy" if it is currently being used for expansion. -// A macro cannot be expanded again while busy, to avoid infinite recursion. -static bool IsMacroBusy(PreprocessorMacro* /*macro*/) -{ - // TODO: need to implement this correctly - // - // The challenge here is that we are implementing expansion - // for argumenst to function-like macros in a "lazy" fashion. - // - // The letter of the spec is that we should macro expand - // each argument *before* substitution, and then go and - // macro-expand the substituted body. This means that we - // can invoke a macro as part of an argument to an - // invocation of the same macro: - // - // FOO( 1, FOO(22), 333 ); - // - // In our implementation, the "inner" invocation of `FOO` - // gets expanded at the point where it gets referenced - // in the body of the "outer" invocation of `FOO`. - // Doing things this way leads to greatly simplified - // code for handling expansion. - // - // A proper implementation of `IsMacroBusy` needs to - // take context into account, so that it bans recursive - // use of a macro when it occurs (indirectly) through - // the *body* of the expansion, but not when it occcurs - // only through an *argument*. - return false; -} - -// -// Reading Tokens With Expansion -// - -static void InitializeMacroExpansion( - Preprocessor* preprocessor, - MacroExpansion* expansion, - PreprocessorMacro* macro) -{ - initializeInputStream(preprocessor, expansion); - - expansion->parent = preprocessor->inputStream; - expansion->primaryStream = preprocessor->inputStream->primaryStream; - - expansion->environment = macro->environment; - expansion->macro = macro; - expansion->tokenReader = TokenReader(macro->tokens); -} - -static void PushMacroExpansion( - Preprocessor* preprocessor, - MacroExpansion* expansion) -{ - PushInputStream(preprocessor, expansion); -} - -static void AddEndOfStreamToken( - Preprocessor* preprocessor, - PreprocessorMacro* macro) -{ - Token token = PeekRawToken(preprocessor); - token.type = TokenType::EndOfFile; - macro->tokens.mTokens.add(token); -} - -static SimpleTokenInputStream* createSimpleInputStream( - Preprocessor* preprocessor, - Token const& token) -{ - SimpleTokenInputStream* inputStream = new SimpleTokenInputStream(); - initializeInputStream(preprocessor, inputStream); - - inputStream->lexedTokens.mTokens.add(token); - - Token eofToken; - eofToken.type = TokenType::EndOfFile; - eofToken.loc = token.loc; - eofToken.flags = TokenFlag::AfterWhitespace | TokenFlag::AtStartOfLine; - inputStream->lexedTokens.mTokens.add(eofToken); - - inputStream->tokenReader = TokenReader(inputStream->lexedTokens); - - return inputStream; -} - -// Check whether the current token on the given input stream should be -// treated as a macro invocation, and if so set up state for expanding -// that macro. -static void MaybeBeginMacroExpansion( - Preprocessor* preprocessor ) -{ - // We iterate because the first token in the expansion of one - // macro may be another macro invocation. - for (;;) - { - // Look at the next token ahead of us - Token token = PeekRawToken(preprocessor); - - // Not an identifier? Can't be a macro. - if (token.type != TokenType::Identifier) - return; - - // Look for a macro with the given name. - Name* name = token.getName(); - PreprocessorMacro* macro = LookupMacro(preprocessor, name); - - // Not a macro? Can't be an invocation. - if (!macro) - return; - - // If the macro is busy (already being expanded), - // don't try to trigger recursive expansion - if (IsMacroBusy(macro)) - return; - - // We might already have looked at this token, - // and need to suppress expansion - if (token.flags & TokenFlag::SuppressMacroExpansion) - return; - - // A function-style macro invocation should only match - // if the token *after* the identifier is `(`. This - // requires more lookahead than we usually have/need - if (macro->flavor == PreprocessorMacroFlavor::FunctionLike) - { - // Consume the token that (possibly) triggered macro expansion - AdvanceRawToken(preprocessor); - - // Look at the next token, and see if it is an opening `(` - // that indicates we should actually expand a macro. - if(PeekRawTokenType(preprocessor) != TokenType::LParent) - { - // In this case, we are in a bit of a mess, because we have - // consumed the token that named the macro, but we need to - // make sure that token (and not whatever came after it) - // gets returned to the user. - // - // To work around this we will construct a short-lived input - // stream just to handle that one token, and also set - // a flag on the token to keep us from doing this logic again. - - token.flags |= TokenFlag::SuppressMacroExpansion; - - SimpleTokenInputStream* simpleStream = createSimpleInputStream(preprocessor, token); - PushInputStream(preprocessor, simpleStream); - return; - } - - // Consume the opening `(` - Token leftParen = AdvanceRawToken(preprocessor); - - FunctionLikeMacroExpansion* expansion = new FunctionLikeMacroExpansion(); - InitializeMacroExpansion(preprocessor, expansion, macro); - expansion->argumentEnvironment.parent = &preprocessor->globalEnv; - expansion->environment = &expansion->argumentEnvironment; - - // Try to read any arguments present. - UInt paramCount = macro->params.getCount(); - UInt argIndex = 0; - - switch (PeekRawTokenType(preprocessor)) - { - case TokenType::EndOfFile: - case TokenType::RParent: - // No arguments. - break; - - default: - // At least one argument - while(argIndex < paramCount) - { - // Read an argument - - // Create the argument, represented as a special flavor of macro - PreprocessorMacro* arg = CreateMacro(preprocessor); - arg->flavor = PreprocessorMacroFlavor::FunctionArg; - arg->environment = GetCurrentEnvironment(preprocessor); - - // Associate the new macro with its parameter name - NameLoc paramNameAndLoc = macro->params[argIndex]; - Name* paramName = paramNameAndLoc.name; - arg->nameAndLoc = paramNameAndLoc; - expansion->argumentEnvironment.macros[paramName] = arg; - argIndex++; - - // Read tokens for the argument - - // We track the nesting depth, since we don't break - // arguments on a `,` nested in balanced parentheses - // - int nesting = 0; - for (;;) - { - switch (PeekRawTokenType(preprocessor)) - { - case TokenType::EndOfFile: - // if we reach the end of the file, - // then we have an error, and need to - // bail out - AddEndOfStreamToken(preprocessor, arg); - goto doneWithAllArguments; - - case TokenType::RParent: - // If we see a right paren when we aren't nested - // then we are at the end of an argument - if (nesting == 0) - { - AddEndOfStreamToken(preprocessor, arg); - goto doneWithAllArguments; - } - // Otherwise we decrease our nesting depth, add - // the token, and keep going - nesting--; - break; - - case TokenType::Comma: - // If we see a comma when we aren't nested - // then we are at the end of an argument - if (nesting == 0) - { - AddEndOfStreamToken(preprocessor, arg); - AdvanceRawToken(preprocessor); - goto doneWithArgument; - } - // Otherwise we add it as a normal token - break; - - case TokenType::LParent: - // If we see a left paren then we need to - // increase our tracking of nesting - nesting++; - break; - - default: - break; - } - - // Add the token and continue parsing. - arg->tokens.mTokens.add(AdvanceRawToken(preprocessor)); - } - doneWithArgument: {} - // We've parsed an argument and should move onto - // the next one. - } - break; - } - doneWithAllArguments: - // TODO: handle possible varargs - - // Expect closing right paren - if (PeekRawTokenType(preprocessor) == TokenType::RParent) - { - AdvanceRawToken(preprocessor); - } - else - { - GetSink(preprocessor)->diagnose(PeekLoc(preprocessor), Diagnostics::expectedTokenInMacroArguments, TokenType::RParent, PeekRawTokenType(preprocessor)); - } - - UInt argCount = argIndex; - if (argCount != paramCount) - { - GetSink(preprocessor)->diagnose(PeekLoc(preprocessor), Diagnostics::wrongNumberOfArgumentsToMacro, paramCount, argCount); - } - - // We are ready to expand. - PushMacroExpansion(preprocessor, expansion); - } - else - { - // Consume the token that triggered macro expansion - AdvanceRawToken(preprocessor); - - // Object-like macros are the easy case. - ObjectLikeMacroExpansion* expansion = new ObjectLikeMacroExpansion(); - InitializeMacroExpansion(preprocessor, expansion, macro); - PushMacroExpansion(preprocessor, expansion); - } - } -} - -// Read one token with macro-expansion enabled. -static Token AdvanceToken(Preprocessor* preprocessor) -{ -top: - // Check whether we need to macro expand at the cursor. - MaybeBeginMacroExpansion(preprocessor); - - // Read a raw token (now that expansion has been triggered) - Token token = AdvanceRawToken(preprocessor); - - // Check if we need to perform token pasting - if (PeekRawTokenType(preprocessor) != TokenType::PoundPound) - { - // If we aren't token pasting, then we are done - return token; - } - else - { - // We are pasting tokens, which could get messy - - StringBuilder sb; - sb << token.Content; - - while (PeekRawTokenType(preprocessor) == TokenType::PoundPound) - { - // Consume the `##` - AdvanceRawToken(preprocessor); - - // Possibly macro-expand the next token - MaybeBeginMacroExpansion(preprocessor); - - // Read the next raw token (now that expansion has been triggered) - Token nextToken = AdvanceRawToken(preprocessor); - - sb << nextToken.Content; - } - - // Now re-lex the input - - SourceManager* sourceManager = preprocessor->getSourceManager(); - - // We create a dummy file to represent the token-paste operation - PathInfo pathInfo = PathInfo::makeTokenPaste(); - - SourceFile* sourceFile = sourceManager->createSourceFileWithString(pathInfo, sb.ProduceString()); - SourceView* sourceView = sourceManager->createSourceView(sourceFile, nullptr); - - Lexer lexer; - lexer.initialize(sourceView, GetSink(preprocessor), preprocessor->getNamePool(), sourceManager->getMemoryArena()); - - SimpleTokenInputStream* inputStream = new SimpleTokenInputStream(); - initializeInputStream(preprocessor, inputStream); - - inputStream->lexedTokens = lexer.lexAllTokens(); - inputStream->tokenReader = TokenReader(inputStream->lexedTokens); - - // We expect the reuslt of lexing to be two tokens: one for the actual value, - // and one for the end-of-input marker. - if (inputStream->tokenReader.GetCount() != 2) - { - // We expect a token paste to produce a single token - // TODO(tfoley): emit a diagnostic here - } - - PushInputStream(preprocessor, inputStream); - goto top; - } -} - -// Read one token with macro-expansion enabled. -// -// Note that because triggering macro expansion may -// involve changing the input-stream state, this -// operation *can* have side effects. -static Token PeekToken(Preprocessor* preprocessor) -{ - // Check whether we need to macro expand at the cursor. - MaybeBeginMacroExpansion(preprocessor); - - // Peek a raw token (now that expansion has been triggered) - return PeekRawToken(preprocessor); - - // TODO: need a plan for how to handle token pasting - // here without it being onerous. Would be nice if we - // didn't have to re-do pasting on a "peek"... -} - -// Peek the type of the next token, including macro expansion. -static TokenType PeekTokenType(Preprocessor* preprocessor) -{ - return PeekToken(preprocessor).type; -} - -// -// Preprocessor Directives -// - -// When reading a preprocessor directive, we use a context -// to wrap the direct preprocessor routines defines so far. -// -// One of the most important things the directive context -// does is give us a convenient way to read tokens with -// a guarantee that we won't read past the end of a line. -struct PreprocessorDirectiveContext -{ - // The preprocessor that is parsing the directive. - Preprocessor* preprocessor; - - // The directive token (e.g., the `if` in `#if`). - // Useful for reference in diagnostic messages. - Token directiveToken; - - // Has any kind of parse error been encountered in - // the directive so far? - bool parseError; - - // Have we done the necessary checks at the end - // of the directive already? - bool haveDoneEndOfDirectiveChecks; -}; - -// Get the token for the preprocessor directive being parsed. -inline Token const& GetDirective(PreprocessorDirectiveContext* context) -{ - return context->directiveToken; -} - -// Get the name of the directive being parsed. -inline UnownedStringSlice const& GetDirectiveName(PreprocessorDirectiveContext* context) -{ - return context->directiveToken.Content; -} - -// Get the location of the directive being parsed. -inline SourceLoc const& GetDirectiveLoc(PreprocessorDirectiveContext* context) -{ - return context->directiveToken.loc; -} - -// Wrapper to get the diagnostic sink in the context of a directive. -static inline DiagnosticSink* GetSink(PreprocessorDirectiveContext* context) -{ - return GetSink(context->preprocessor); -} - -// Wrapper to get a "current" location when parsing a directive -static SourceLoc PeekLoc(PreprocessorDirectiveContext* context) -{ - return PeekLoc(context->preprocessor); -} - -// Wrapper to look up a macro in the context of a directive. -static PreprocessorMacro* LookupMacro(PreprocessorDirectiveContext* context, Name* name) -{ - return LookupMacro(context->preprocessor, name); -} - -// Determine if we have read everything on the directive's line. -static bool IsEndOfLine(PreprocessorDirectiveContext* context) -{ - return PeekRawToken(context->preprocessor).type == TokenType::EndOfDirective; -} - -// Peek one raw token in a directive, without going past the end of the line. -static Token PeekRawToken(PreprocessorDirectiveContext* context) -{ - return PeekRawToken(context->preprocessor); -} - -// Read one raw token in a directive, without going past the end of the line. -static Token AdvanceRawToken(PreprocessorDirectiveContext* context, LexerFlags lexerFlags = 0) -{ - if (IsEndOfLine(context)) - return PeekRawToken(context); - return AdvanceRawToken(context->preprocessor, lexerFlags); -} - -// Peek next raw token type, without going past the end of the line. -static TokenType PeekRawTokenType(PreprocessorDirectiveContext* context) -{ - return PeekRawTokenType(context->preprocessor); -} - -// Read one token, with macro-expansion, without going past the end of the line. -static Token AdvanceToken(PreprocessorDirectiveContext* context) -{ - if (IsEndOfLine(context)) - return PeekRawToken(context); - return AdvanceToken(context->preprocessor); -} - -// Peek one token, with macro-expansion, without going past the end of the line. -static Token PeekToken(PreprocessorDirectiveContext* context) -{ - if (IsEndOfLine(context)) - return context->preprocessor->endOfFileToken; - return PeekToken(context->preprocessor); -} - -// Peek next token type, with macro-expansion, without going past the end of the line. -static TokenType PeekTokenType(PreprocessorDirectiveContext* context) -{ - if (IsEndOfLine(context)) - return TokenType::EndOfDirective; - return PeekTokenType(context->preprocessor); -} - -// Skip to the end of the line (useful for recovering from errors in a directive) -static void SkipToEndOfLine(PreprocessorDirectiveContext* context) -{ - while(!IsEndOfLine(context)) - { - AdvanceRawToken(context); - } -} - -static bool ExpectRaw(PreprocessorDirectiveContext* context, TokenType tokenType, DiagnosticInfo const& diagnostic, Token* outToken = NULL) -{ - if (PeekRawTokenType(context) != tokenType) - { - // Only report the first parse error within a directive - if (!context->parseError) - { - GetSink(context)->diagnose(PeekLoc(context), diagnostic, tokenType, GetDirectiveName(context)); - } - context->parseError = true; - return false; - } - Token const& token = AdvanceRawToken(context); - if (outToken) - *outToken = token; - return true; -} - -static bool Expect(PreprocessorDirectiveContext* context, TokenType tokenType, DiagnosticInfo const& diagnostic, Token* outToken = NULL) -{ - if (PeekTokenType(context) != tokenType) - { - // Only report the first parse error within a directive - if (!context->parseError) - { - GetSink(context)->diagnose(PeekLoc(context), diagnostic, tokenType, GetDirectiveName(context)); - context->parseError = true; - } - return false; - } - Token const& token = AdvanceToken(context); - if (outToken) - *outToken = token; - return true; -} - - - -// -// Preprocessor Conditionals -// - -// Determine whether the current preprocessor state means we -// should be skipping tokens. -static bool IsSkipping(Preprocessor* preprocessor) -{ - PreprocessorInputStream* inputStream = preprocessor->inputStream; - if (!inputStream) return false; - - PrimaryInputStream* primaryStream = inputStream->primaryStream; - if(!primaryStream) return false; - - // If we are not inside a preprocessor conditional, then don't skip - PreprocessorConditional* conditional = primaryStream->conditional; - if (!conditional) return false; - - // skip tokens unless the conditional is inside its `true` case - return conditional->state != PreprocessorConditionalState::During; -} - -// Wrapper for use inside directives -static inline bool IsSkipping(PreprocessorDirectiveContext* context) -{ - return IsSkipping(context->preprocessor); -} - -// Create a preprocessor conditional -static PreprocessorConditional* CreateConditional(Preprocessor* /*preprocessor*/) -{ - // TODO(tfoley): allocate these more intelligently (for example, - // pool them on the `Preprocessor`. - return new PreprocessorConditional(); -} - -// Destroy a preprocessor conditional. -static void DestroyConditional(PreprocessorConditional* conditional) -{ - delete conditional; -} - -// Start a preprocessor conditional, with an initial enable/disable state. -static void beginConditional( - PreprocessorDirectiveContext* context, - PreprocessorInputStream* inputStream, - bool enable) -{ - Preprocessor* preprocessor = context->preprocessor; - SLANG_ASSERT(inputStream); - - PreprocessorConditional* conditional = CreateConditional(preprocessor); - - conditional->ifToken = context->directiveToken; - - // Set state of this condition appropriately. - // - // Default to the "haven't yet seen a `true` branch" state. - PreprocessorConditionalState state = PreprocessorConditionalState::Before; - // - // If we are nested inside a `false` branch of another condition, then - // we never want to enable, so we act as if we already *saw* the `true` branch. - // - if (IsSkipping(preprocessor)) state = PreprocessorConditionalState::After; - // - // Similarly, if we ran into any parse errors when dealing with the - // opening directive, then things are probably screwy and we should just - // skip all the branches. - if (IsSkipping(preprocessor)) state = PreprocessorConditionalState::After; - // - // Otherwise, if our condition was true, then set us to be inside the `true` branch - else if (enable) state = PreprocessorConditionalState::During; - - conditional->state = state; - - // Push conditional onto the stack - auto primaryStream = inputStream->primaryStream; - conditional->parent = primaryStream->conditional; - primaryStream->conditional = conditional; -} - -// Start a preprocessor conditional, with an initial enable/disable state. -static void beginConditional( - PreprocessorDirectiveContext* context, - bool enable) -{ - beginConditional(context, context->preprocessor->inputStream, enable); -} - -// -// Preprocessor Conditional Expressions -// - -// Conditional expressions are always of type `int` -typedef int PreprocessorExpressionValue; - -// Forward-declaretion -static PreprocessorExpressionValue ParseAndEvaluateExpression(PreprocessorDirectiveContext* context); - -// Parse a unary (prefix) expression inside of a preprocessor directive. -static PreprocessorExpressionValue ParseAndEvaluateUnaryExpression(PreprocessorDirectiveContext* context) -{ - switch (PeekTokenType(context)) - { - // handle prefix unary ops - case TokenType::OpSub: - AdvanceToken(context); - return -ParseAndEvaluateUnaryExpression(context); - case TokenType::OpNot: - AdvanceToken(context); - return !ParseAndEvaluateUnaryExpression(context); - case TokenType::OpBitNot: - AdvanceToken(context); - return ~ParseAndEvaluateUnaryExpression(context); - - // handle parenthized sub-expression - case TokenType::LParent: - { - Token leftParen = AdvanceToken(context); - PreprocessorExpressionValue value = ParseAndEvaluateExpression(context); - if (!Expect(context, TokenType::RParent, Diagnostics::expectedTokenInPreprocessorExpression)) - { - GetSink(context)->diagnose(leftParen.loc, Diagnostics::seeOpeningToken, leftParen); - } - return value; - } - - case TokenType::IntegerLiteral: - return StringToInt(AdvanceToken(context).Content); - - case TokenType::Identifier: - { - Token token = AdvanceToken(context); - if (token.Content == "defined") - { - // handle `defined(someName)` - - // Possibly parse a `(` - Token leftParen; - if (PeekRawTokenType(context) == TokenType::LParent) - { - leftParen = AdvanceRawToken(context); - } - - // Expect an identifier - Token nameToken; - if (!ExpectRaw(context, TokenType::Identifier, Diagnostics::expectedTokenInDefinedExpression, &nameToken)) - { - return 0; - } - Name* name = nameToken.getName(); - - // If we saw an opening `(`, then expect one to close - if (leftParen.type != TokenType::Unknown) - { - if(!ExpectRaw(context, TokenType::RParent, Diagnostics::expectedTokenInDefinedExpression)) - { - GetSink(context)->diagnose(leftParen.loc, Diagnostics::seeOpeningToken, leftParen); - return 0; - } - } - - return LookupMacro(context, name) != NULL; - } - - // An identifier here means it was not defined as a macro (or - // it is defined, but as a function-like macro. These should - // just evaluate to zero (possibly with a warning) - GetSink(context)->diagnose(token.loc, Diagnostics::undefinedIdentifierInPreprocessorExpression, token.getName()); - return 0; - } - - default: - GetSink(context)->diagnose(PeekLoc(context), Diagnostics::syntaxErrorInPreprocessorExpression); - return 0; - } -} - -// Determine the precedence level of an infix operator -// for use in parsing preprocessor conditionals. -static int GetInfixOpPrecedence(Token const& opToken) -{ - // If token is on another line, it is not part of the - // expression - if (opToken.flags & TokenFlag::AtStartOfLine) - return -1; - - // otherwise we look at the token type to figure - // out what precedence it should be parse with - switch (opToken.type) - { - default: - // tokens that aren't infix operators should - // cause us to stop parsing an expression - return -1; - - case TokenType::OpMul: return 10; - case TokenType::OpDiv: return 10; - case TokenType::OpMod: return 10; - - case TokenType::OpAdd: return 9; - case TokenType::OpSub: return 9; - - case TokenType::OpLsh: return 8; - case TokenType::OpRsh: return 8; - - case TokenType::OpLess: return 7; - case TokenType::OpGreater: return 7; - case TokenType::OpLeq: return 7; - case TokenType::OpGeq: return 7; - - case TokenType::OpEql: return 6; - case TokenType::OpNeq: return 6; - - case TokenType::OpBitAnd: return 5; - case TokenType::OpBitOr: return 4; - case TokenType::OpBitXor: return 3; - case TokenType::OpAnd: return 2; - case TokenType::OpOr: return 1; - } -}; - -// Evaluate one infix operation in a preprocessor -// conditional expression -static PreprocessorExpressionValue EvaluateInfixOp( - PreprocessorDirectiveContext* context, - Token const& opToken, - PreprocessorExpressionValue left, - PreprocessorExpressionValue right) -{ - switch (opToken.type) - { - default: -// SLANG_INTERNAL_ERROR(getSink(preprocessor), opToken); - return 0; - break; - - case TokenType::OpMul: return left * right; - case TokenType::OpDiv: - { - if (right == 0) - { - if (!context->parseError) - { - GetSink(context)->diagnose(opToken.loc, Diagnostics::divideByZeroInPreprocessorExpression); - } - return 0; - } - return left / right; - } - case TokenType::OpMod: - { - if (right == 0) - { - if (!context->parseError) - { - GetSink(context)->diagnose(opToken.loc, Diagnostics::divideByZeroInPreprocessorExpression); - } - return 0; - } - return left % right; - } - case TokenType::OpAdd: return left + right; - case TokenType::OpSub: return left - right; - case TokenType::OpLsh: return left << right; - case TokenType::OpRsh: return left >> right; - case TokenType::OpLess: return left < right ? 1 : 0; - case TokenType::OpGreater: return left > right ? 1 : 0; - case TokenType::OpLeq: return left <= right ? 1 : 0; - case TokenType::OpGeq: return left >= right ? 1 : 0; - case TokenType::OpEql: return left == right ? 1 : 0; - case TokenType::OpNeq: return left != right ? 1 : 0; - case TokenType::OpBitAnd: return left & right; - case TokenType::OpBitOr: return left | right; - case TokenType::OpBitXor: return left ^ right; - case TokenType::OpAnd: return left && right; - case TokenType::OpOr: return left || right; - } -} - -// Parse the rest of an infix preprocessor expression with -// precedence greater than or equal to the given `precedence` argument. -// The value of the left-hand-side expression is provided as -// an argument. -// This is used to form a simple recursive-descent expression parser. -static PreprocessorExpressionValue ParseAndEvaluateInfixExpressionWithPrecedence( - PreprocessorDirectiveContext* context, - PreprocessorExpressionValue left, - int precedence) -{ - for (;;) - { - // Look at the next token, and see if it is an operator of - // high enough precedence to be included in our expression - Token opToken = PeekToken(context); - int opPrecedence = GetInfixOpPrecedence(opToken); - - // If it isn't an operator of high enough precedence, we are done. - if(opPrecedence < precedence) - break; - - // Otherwise we need to consume the operator token. - AdvanceToken(context); - - // Next we parse a right-hand-side expression by starting with - // a unary expression and absorbing and many infix operators - // as possible with strictly higher precedence than the operator - // we found above. - PreprocessorExpressionValue right = ParseAndEvaluateUnaryExpression(context); - for (;;) - { - // Look for an operator token - Token rightOpToken = PeekToken(context); - int rightOpPrecedence = GetInfixOpPrecedence(rightOpToken); - - // If no operator was found, or the operator wasn't high - // enough precedence to fold into the right-hand-side, - // exit this loop. - if (rightOpPrecedence <= opPrecedence) - break; - - // Now invoke the parser recursively, passing in our - // existing right-hand side to form an even larger one. - right = ParseAndEvaluateInfixExpressionWithPrecedence( - context, - right, - rightOpPrecedence); - } - - // Now combine the left- and right-hand sides using - // the operator we found above. - left = EvaluateInfixOp(context, opToken, left, right); - } - return left; -} - -// Parse a complete (infix) preprocessor expression, and return its value -static PreprocessorExpressionValue ParseAndEvaluateExpression(PreprocessorDirectiveContext* context) -{ - // First read in the left-hand side (or the whole expression in the unary case) - PreprocessorExpressionValue value = ParseAndEvaluateUnaryExpression(context); - - // Try to read in trailing infix operators with correct precedence - return ParseAndEvaluateInfixExpressionWithPrecedence(context, value, 0); -} - -// Handle a `#if` directive -static void HandleIfDirective(PreprocessorDirectiveContext* context) -{ - // Record current input stream in case preprocessor expression - // changes the input stream to a macro expansion while we - // are parsing. - auto inputStream = context->preprocessor->inputStream; - - // If we are skipping, we can just consume the expression, and assume true - if (IsSkipping(context->preprocessor)) - { - // Consume everything until the end of the line - SkipToEndOfLine(context); - // Begin a preprocessor block, assume true based on the expression - // (contents will all be ignored because skipping). - beginConditional(context, inputStream, true); - } - else - { - // Parse a preprocessor expression. - PreprocessorExpressionValue value = ParseAndEvaluateExpression(context); - - // Begin a preprocessor block, enabled based on the expression. - beginConditional(context, inputStream, value != 0); - } -} - -// Handle a `#ifdef` directive -static void HandleIfDefDirective(PreprocessorDirectiveContext* context) -{ - // Expect a raw identifier, so we can check if it is defined - Token nameToken; - if(!ExpectRaw(context, TokenType::Identifier, Diagnostics::expectedTokenInPreprocessorDirective, &nameToken)) - return; - Name* name = nameToken.getName(); - - // Check if the name is defined. - beginConditional(context, LookupMacro(context, name) != NULL); -} - -// Handle a `#ifndef` directive -static void HandleIfNDefDirective(PreprocessorDirectiveContext* context) -{ - // Expect a raw identifier, so we can check if it is defined - Token nameToken; - if(!ExpectRaw(context, TokenType::Identifier, Diagnostics::expectedTokenInPreprocessorDirective, &nameToken)) - return; - Name* name = nameToken.getName(); - - // Check if the name is defined. - beginConditional(context, LookupMacro(context, name) == NULL); -} - -// Handle a `#else` directive -static void HandleElseDirective(PreprocessorDirectiveContext* context) -{ - PreprocessorInputStream* inputStream = context->preprocessor->inputStream; - SLANG_ASSERT(inputStream); - - // if we aren't inside a conditional, then error - PreprocessorConditional* conditional = inputStream->primaryStream->conditional; - if (!conditional) - { - GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::directiveWithoutIf, GetDirectiveName(context)); - return; - } - - // if we've already seen a `#else`, then it is an error - if (conditional->elseToken.type != TokenType::Unknown) - { - GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::directiveAfterElse, GetDirectiveName(context)); - GetSink(context)->diagnose(conditional->elseToken.loc, Diagnostics::seeDirective); - return; - } - conditional->elseToken = context->directiveToken; - - switch (conditional->state) - { - case PreprocessorConditionalState::Before: - conditional->state = PreprocessorConditionalState::During; - break; - - case PreprocessorConditionalState::During: - conditional->state = PreprocessorConditionalState::After; - break; - - default: - break; - } -} - -// Handle a `#elif` directive -static void HandleElifDirective(PreprocessorDirectiveContext* context) -{ - // Need to grab current input stream *before* we try to parse - // the conditional expression. - PreprocessorInputStream* inputStream = context->preprocessor->inputStream; - SLANG_ASSERT(inputStream); - - // HACK(tfoley): handle an empty `elif` like an `else` directive - // - // This is the behavior expected by at least one input program. - // We will eventually want to be pedantic about this. - // even if t - if (PeekRawTokenType(context) == TokenType::EndOfDirective) - { - GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::directiveExpectsExpression, GetDirectiveName(context)); - HandleElseDirective(context); - return; - } - - PreprocessorExpressionValue value = ParseAndEvaluateExpression(context); - - // if we aren't inside a conditional, then error - PreprocessorConditional* conditional = inputStream->primaryStream->conditional; - if (!conditional) - { - GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::directiveWithoutIf, GetDirectiveName(context)); - return; - } - - // if we've already seen a `#else`, then it is an error - if (conditional->elseToken.type != TokenType::Unknown) - { - GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::directiveAfterElse, GetDirectiveName(context)); - GetSink(context)->diagnose(conditional->elseToken.loc, Diagnostics::seeDirective); - return; - } - - switch (conditional->state) - { - case PreprocessorConditionalState::Before: - if(value) - conditional->state = PreprocessorConditionalState::During; - break; - - case PreprocessorConditionalState::During: - conditional->state = PreprocessorConditionalState::After; - break; - - default: - break; - } -} - -// Handle a `#endif` directive -static void HandleEndIfDirective(PreprocessorDirectiveContext* context) -{ - PreprocessorInputStream* inputStream = context->preprocessor->inputStream; - SLANG_ASSERT(inputStream); - - // if we aren't inside a conditional, then error - PreprocessorConditional* conditional = inputStream->primaryStream->conditional; - if (!conditional) - { - GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::directiveWithoutIf, GetDirectiveName(context)); - return; - } - - inputStream->primaryStream->conditional = conditional->parent; - DestroyConditional(conditional); -} - -// Helper routine to check that we find the end of a directive where -// we expect it. -// -// Most directives do not need to call this directly, since we have -// a catch-all case in the main `HandleDirective()` function. -// The `#include` case will call it directly to avoid complications -// when it switches the input stream. -static void expectEndOfDirective(PreprocessorDirectiveContext* context) -{ - if(context->haveDoneEndOfDirectiveChecks) - return; - - context->haveDoneEndOfDirectiveChecks = true; - - if (!IsEndOfLine(context)) - { - // If we already saw a previous parse error, then don't - // emit another one for the same directive. - if (!context->parseError) - { - GetSink(context)->diagnose(PeekLoc(context), Diagnostics::unexpectedTokensAfterDirective, GetDirectiveName(context)); - } - SkipToEndOfLine(context); - } - - // Clear out the end-of-directive token - AdvanceRawToken(context->preprocessor); -} - - /// Read a file in the context of handling a preprocessor directive -static SlangResult readFile( - PreprocessorDirectiveContext* context, - String const& path, - ISlangBlob** outBlob) -{ - // The actual file loading will be handled by the file system - // associated with the parent linkage. - // - auto linkage = context->preprocessor->linkage; - auto fileSystemExt = linkage->getFileSystemExt(); - SLANG_RETURN_ON_FAIL(fileSystemExt->loadFile(path.getBuffer(), outBlob)); - - // If we are running the preprocessor as part of compiling a - // specific module, then we must keep track of the file we've - // read as yet another file that the module will depend on. - // - if(auto module = context->preprocessor->parentModule) - { - module->addFilePathDependency(path); - } - - return SLANG_OK; -} - -// Handle a `#include` directive -static void HandleIncludeDirective(PreprocessorDirectiveContext* context) -{ - // Consume the directive, and inform the lexer to process the remainder of the line as a file path. - AdvanceRawToken(context, kLexerFlag_ExpectFileName); - - Token pathToken; - if(!Expect(context, TokenType::StringLiteral, Diagnostics::expectedTokenInPreprocessorDirective, &pathToken)) - return; - - String path = getFileNameTokenValue(pathToken); - - auto directiveLoc = GetDirectiveLoc(context); - - PathInfo includedFromPathInfo = context->preprocessor->getSourceManager()->getPathInfo(directiveLoc, SourceLocType::Actual); - - IncludeHandler* includeHandler = context->preprocessor->includeHandler; - if (!includeHandler) - { - GetSink(context)->diagnose(pathToken.loc, Diagnostics::includeFailed, path); - GetSink(context)->diagnose(pathToken.loc, Diagnostics::noIncludeHandlerSpecified); - return; - } - - /* Find the path relative to the foundPath */ - PathInfo filePathInfo; - if (SLANG_FAILED(includeHandler->findFile(path, includedFromPathInfo.foundPath, filePathInfo))) - { - GetSink(context)->diagnose(pathToken.loc, Diagnostics::includeFailed, path); - return; - } - - // We must have a uniqueIdentity to be compare - if (!filePathInfo.hasUniqueIdentity()) - { - GetSink(context)->diagnose(pathToken.loc, Diagnostics::noUniqueIdentity, path); - return; - } - - // Do all checking related to the end of this directive before we push a new stream, - // just to avoid complications where that check would need to deal with - // a switch of input stream - expectEndOfDirective(context); - - // Check whether we've previously included this file and seen a `#pragma once` directive - if(context->preprocessor->pragmaOnceUniqueIdentities.Contains(filePathInfo.uniqueIdentity)) - { - return; - } - - // Simplify the path - filePathInfo.foundPath = includeHandler->simplifyPath(filePathInfo.foundPath); - - // Push the new file onto our stack of input streams - // TODO(tfoley): check if we have made our include stack too deep - auto sourceManager = context->preprocessor->getSourceManager(); - - // See if this an already loaded source file - SourceFile* sourceFile = sourceManager->findSourceFileRecursively(filePathInfo.uniqueIdentity); - // If not create a new one, and add to the list of known source files - if (!sourceFile) - { - ComPtr foundSourceBlob; - if (SLANG_FAILED(readFile(context, filePathInfo.foundPath, foundSourceBlob.writeRef()))) - { - GetSink(context)->diagnose(pathToken.loc, Diagnostics::includeFailed, path); - return; - } - - - sourceFile = sourceManager->createSourceFileWithBlob(filePathInfo, foundSourceBlob); - sourceManager->addSourceFile(filePathInfo.uniqueIdentity, sourceFile); - } - - // This is a new parse (even if it's a pre-existing source file), so create a new SourceUnit - SourceView* sourceView = sourceManager->createSourceView(sourceFile, &filePathInfo); - - PreprocessorInputStream* inputStream = CreateInputStreamForSource(context->preprocessor, sourceView); - inputStream->parent = context->preprocessor->inputStream; - context->preprocessor->inputStream = inputStream; -} - -// Handle a `#define` directive -static void HandleDefineDirective(PreprocessorDirectiveContext* context) -{ - Token nameToken; - if (!ExpectRaw(context, TokenType::Identifier, Diagnostics::expectedTokenInPreprocessorDirective, &nameToken)) - return; - Name* name = nameToken.getName(); - - PreprocessorMacro* macro = CreateMacro(context->preprocessor); - macro->nameAndLoc = NameLoc(nameToken); - - PreprocessorMacro* oldMacro = LookupMacro(&context->preprocessor->globalEnv, name); - if (oldMacro) - { - GetSink(context)->diagnose(nameToken.loc, Diagnostics::macroRedefinition, name); - GetSink(context)->diagnose(oldMacro->getLoc(), Diagnostics::seePreviousDefinitionOf, name); - - DestroyMacro(context->preprocessor, oldMacro); - } - context->preprocessor->globalEnv.macros[name] = macro; - - // If macro name is immediately followed (with no space) by `(`, - // then we have a function-like macro - if (PeekRawTokenType(context) == TokenType::LParent) - { - if (!(PeekRawToken(context).flags & TokenFlag::AfterWhitespace)) - { - // This is a function-like macro, so we need to remember that - // and start capturing parameters - macro->flavor = PreprocessorMacroFlavor::FunctionLike; - - AdvanceRawToken(context); - - // If there are any parameters, parse them - if (PeekRawTokenType(context) != TokenType::RParent) - { - for (;;) - { - // TODO: handle elipsis (`...`) for varags - - // A macro parameter name should be a raw identifier - Token paramToken; - if (!ExpectRaw(context, TokenType::Identifier, Diagnostics::expectedTokenInMacroParameters, ¶mToken)) - break; - - // TODO(tfoley): some validation on parameter name. - // Certain names (e.g., `defined` and `__VA_ARGS__` - // are not allowed to be used as macros or parameters). - - // Add the parameter to the macro being deifned - macro->params.add(paramToken); - - // If we see `)` then we are done with arguments - if (PeekRawTokenType(context) == TokenType::RParent) - break; - - ExpectRaw(context, TokenType::Comma, Diagnostics::expectedTokenInMacroParameters); - } - } - - ExpectRaw(context, TokenType::RParent, Diagnostics::expectedTokenInMacroParameters); - } - } - - // consume tokens until end-of-line - for(;;) - { - Token token = AdvanceRawToken(context); - if( token.type == TokenType::EndOfDirective ) - { - // Last token on line will be turned into a conceptual end-of-file - // token for the sub-stream that the macro expands into. - token.type = TokenType::EndOfFile; - macro->tokens.mTokens.add(token); - break; - } - - // In the ordinary case, we just add the token to the definition - macro->tokens.mTokens.add(token); - } -} - -// Handle a `#undef` directive -static void HandleUndefDirective(PreprocessorDirectiveContext* context) -{ - Token nameToken; - if (!ExpectRaw(context, TokenType::Identifier, Diagnostics::expectedTokenInPreprocessorDirective, &nameToken)) - return; - Name* name = nameToken.getName(); - - PreprocessorEnvironment* env = &context->preprocessor->globalEnv; - PreprocessorMacro* macro = LookupMacro(env, name); - if (macro != NULL) - { - // name was defined, so remove it - env->macros.Remove(name); - - DestroyMacro(context->preprocessor, macro); - } - else - { - // name wasn't defined - GetSink(context)->diagnose(nameToken.loc, Diagnostics::macroNotDefined, name); - } -} - -// Handle a `#warning` directive -static void HandleWarningDirective(PreprocessorDirectiveContext* context) -{ - // Consume the directive, and inform the lexer to process the remainder of the line as a custom message. - AdvanceRawToken(context, kLexerFlag_ExpectDirectiveMessage); - - // Read the message token. - Token messageToken; - Expect(context, TokenType::DirectiveMessage, Diagnostics::expectedTokenInPreprocessorDirective, &messageToken); - - // Report the custom error. - GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::userDefinedWarning, messageToken.Content); -} - -// Handle a `#error` directive -static void HandleErrorDirective(PreprocessorDirectiveContext* context) -{ - // Consume the directive, and inform the lexer to process the remainder of the line as a custom message. - AdvanceRawToken(context, kLexerFlag_ExpectDirectiveMessage); - - // Read the message token. - Token messageToken; - Expect(context, TokenType::DirectiveMessage, Diagnostics::expectedTokenInPreprocessorDirective, &messageToken); - - // Report the custom error. - GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::userDefinedError, messageToken.Content); -} - -// Handle a `#line` directive -static void HandleLineDirective(PreprocessorDirectiveContext* context) -{ - auto inputStream = context->preprocessor->inputStream; - - int line = 0; - - SourceLoc directiveLoc = GetDirectiveLoc(context); - - // `#line ...` - if (PeekTokenType(context) == TokenType::IntegerLiteral) - { - line = StringToInt(AdvanceToken(context).Content); - } - // `#line` - // `#line default` - else if ( - PeekTokenType(context) == TokenType::EndOfDirective - || (PeekTokenType(context) == TokenType::Identifier - && PeekToken(context).Content == "default")) - { - AdvanceToken(context); - - // Stop overriding source locations. - auto sourceView = inputStream->primaryStream->lexer.sourceView; - sourceView->addDefaultLineDirective(directiveLoc); - return; - } - else - { - GetSink(context)->diagnose(PeekLoc(context), Diagnostics::expected2TokensInPreprocessorDirective, - TokenType::IntegerLiteral, - "default", - GetDirectiveName(context)); - context->parseError = true; - return; - } - - auto sourceManager = context->preprocessor->getSourceManager(); - - String file; - if (PeekTokenType(context) == TokenType::EndOfDirective) - { - file = sourceManager->getPathInfo(directiveLoc).foundPath; - } - else if (PeekTokenType(context) == TokenType::StringLiteral) - { - file = getStringLiteralTokenValue(AdvanceToken(context)); - } - else if (PeekTokenType(context) == TokenType::IntegerLiteral) - { - // Note(tfoley): GLSL allows the "source string" to be indicated by an integer - // TODO(tfoley): Figure out a better way to handle this, if it matters - file = AdvanceToken(context).Content; - } - else - { - Expect(context, TokenType::StringLiteral, Diagnostics::expectedTokenInPreprocessorDirective); - return; - } - - auto sourceView = inputStream->primaryStream->lexer.sourceView; - sourceView->addLineDirective(directiveLoc, file, line); -} - -#define SLANG_PRAGMA_DIRECTIVE_CALLBACK(NAME) \ - void NAME(PreprocessorDirectiveContext* context, Token subDirectiveToken) - -// Callback interface used by `#pragma` directives -typedef SLANG_PRAGMA_DIRECTIVE_CALLBACK((*PragmaDirectiveCallback)); - -SLANG_PRAGMA_DIRECTIVE_CALLBACK(handleUnknownPragmaDirective) -{ - GetSink(context)->diagnose(subDirectiveToken, Diagnostics::unknownPragmaDirectiveIgnored, subDirectiveToken.getName()); - SkipToEndOfLine(context); - return; -} - -SLANG_PRAGMA_DIRECTIVE_CALLBACK(handlePragmaOnceDirective) -{ - // We need to identify the path of the file we are preprocessing, - // so that we can avoid including it again. - // - // We are using the 'uniqueIdentity' as determined by the ISlangFileSystemEx interface to determine file identities. - - auto directiveLoc = GetDirectiveLoc(context); - auto issuedFromPathInfo = context->preprocessor->getSourceManager()->getPathInfo(directiveLoc, SourceLocType::Actual); - - // Must have uniqueIdentity for a #pragma once to work - if (!issuedFromPathInfo.hasUniqueIdentity()) - { - GetSink(context)->diagnose(subDirectiveToken, Diagnostics::pragmaOnceIgnored); - return; - } - - context->preprocessor->pragmaOnceUniqueIdentities.Add(issuedFromPathInfo.uniqueIdentity); -} - -// Information about a specific `#pragma` directive -struct PragmaDirective -{ - // name of the directive - char const* name; - - // Callback to handle the directive - PragmaDirectiveCallback callback; -}; - -// A simple array of all the `#pragma` directives we know how to handle. -static const PragmaDirective kPragmaDirectives[] = -{ - { "once", &handlePragmaOnceDirective }, - - { NULL, NULL }, -}; - -static const PragmaDirective kUnknownPragmaDirective = { - NULL, &handleUnknownPragmaDirective, -}; - -// Look up the `#pragma` directive with the given name. -static PragmaDirective const* findPragmaDirective(String const& name) -{ - char const* nameStr = name.getBuffer(); - for (int ii = 0; kPragmaDirectives[ii].name; ++ii) - { - if (strcmp(kPragmaDirectives[ii].name, nameStr) != 0) - continue; - - return &kPragmaDirectives[ii]; - } - - return &kUnknownPragmaDirective; -} - -// Handle a `#pragma` directive -static void HandlePragmaDirective(PreprocessorDirectiveContext* context) -{ - // Try to read the sub-directive name. - Token subDirectiveToken = PeekRawToken(context); - - // The sub-directive had better be an identifier - if (subDirectiveToken.type != TokenType::Identifier) - { - GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::expectedPragmaDirectiveName); - SkipToEndOfLine(context); - return; - } - AdvanceRawToken(context); - - // Look up the handler for the sub-directive. - PragmaDirective const* subDirective = findPragmaDirective(subDirectiveToken.getName()->text); - - // Apply the sub-directive-specific callback - (subDirective->callback)(context, subDirectiveToken); -} - -// Handle an invalid directive -static void HandleInvalidDirective(PreprocessorDirectiveContext* context) -{ - GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::unknownPreprocessorDirective, GetDirectiveName(context)); - SkipToEndOfLine(context); -} - -// Callback interface used by preprocessor directives -typedef void (*PreprocessorDirectiveCallback)(PreprocessorDirectiveContext* context); - -enum PreprocessorDirectiveFlag : unsigned int -{ - // Should this directive be handled even when skipping disbaled code? - ProcessWhenSkipping = 1 << 0, - - /// Allow the handler for this directive to advance past the - /// directive token itself, so that it can control lexer behavior - /// more closely. - DontConsumeDirectiveAutomatically = 1 << 1, -}; - -// Information about a specific directive -struct PreprocessorDirective -{ - // name of the directive - char const* name; - - // Callback to handle the directive - PreprocessorDirectiveCallback callback; - - unsigned int flags; -}; - -// A simple array of all the directives we know how to handle. -// TODO(tfoley): considering making this into a real hash map, -// and then make it easy-ish for users of the codebase to add -// their own directives as desired. -static const PreprocessorDirective kDirectives[] = -{ - { "if", &HandleIfDirective, ProcessWhenSkipping }, - { "ifdef", &HandleIfDefDirective, ProcessWhenSkipping }, - { "ifndef", &HandleIfNDefDirective, ProcessWhenSkipping }, - { "else", &HandleElseDirective, ProcessWhenSkipping }, - { "elif", &HandleElifDirective, ProcessWhenSkipping }, - { "endif", &HandleEndIfDirective, ProcessWhenSkipping }, - - { "include", &HandleIncludeDirective, DontConsumeDirectiveAutomatically }, - { "define", &HandleDefineDirective, 0 }, - { "undef", &HandleUndefDirective, 0 }, - { "warning", &HandleWarningDirective, DontConsumeDirectiveAutomatically }, - { "error", &HandleErrorDirective, DontConsumeDirectiveAutomatically }, - { "line", &HandleLineDirective, 0 }, - { "pragma", &HandlePragmaDirective, 0 }, - - { nullptr, nullptr, 0 }, -}; - -static const PreprocessorDirective kInvalidDirective = { - nullptr, &HandleInvalidDirective, 0, -}; - -// Look up the directive with the given name. -static PreprocessorDirective const* FindDirective(String const& name) -{ - char const* nameStr = name.getBuffer(); - for (int ii = 0; kDirectives[ii].name; ++ii) - { - if (strcmp(kDirectives[ii].name, nameStr) != 0) - continue; - - return &kDirectives[ii]; - } - - return &kInvalidDirective; -} - -// Process a directive, where the preprocessor has already consumed the -// `#` token that started the directive line. -static void HandleDirective(PreprocessorDirectiveContext* context) -{ - // Try to read the directive name. - context->directiveToken = PeekRawToken(context); - - TokenType directiveTokenType = GetDirective(context).type; - - // An empty directive is allowed, and ignored. - if (directiveTokenType == TokenType::EndOfDirective) - { - return; - } - // Otherwise the directive name had better be an identifier - else if (directiveTokenType != TokenType::Identifier) - { - GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::expectedPreprocessorDirectiveName); - SkipToEndOfLine(context); - return; - } - - // Look up the handler for the directive. - PreprocessorDirective const* directive = FindDirective(GetDirectiveName(context)); - - // If we are skipping disabled code, and the directive is not one - // of the small number that need to run even in that case, skip it. - if (IsSkipping(context) && !(directive->flags & PreprocessorDirectiveFlag::ProcessWhenSkipping)) - { - SkipToEndOfLine(context); - return; - } - - if(!(directive->flags & PreprocessorDirectiveFlag::DontConsumeDirectiveAutomatically)) - { - // Consume the directive name token. - AdvanceRawToken(context); - } - - // Apply the directive-specific callback - (directive->callback)(context); - - // We expect the directive callback to consume the entire line, so if - // it hasn't that is a parse error. - expectEndOfDirective(context); -} - -// Read one token using the full preprocessor, with all its behaviors. -static Token ReadToken(Preprocessor* preprocessor) -{ - for (;;) - { - // Depending on what the lookahead token is, we - // might need to start expanding it. - // - // Note: doing this at the start of this loop - // is important, in case a macro has an empty - // expansion, and we end up looking at a different - // token after applying the expansion. - if(!IsSkipping(preprocessor)) - { - MaybeBeginMacroExpansion(preprocessor); - } - - // Look at the next raw token in the input. - Token const& token = PeekRawToken(preprocessor); - if (token.type == TokenType::EndOfFile) - return token; - - // If we have a directive (`#` at start of line) then handle it - if ((token.type == TokenType::Pound) && (token.flags & TokenFlag::AtStartOfLine)) - { - // Skip the `#` - AdvanceRawToken(preprocessor); - - // Create a context for parsing the directive - PreprocessorDirectiveContext directiveContext; - directiveContext.preprocessor = preprocessor; - directiveContext.parseError = false; - directiveContext.haveDoneEndOfDirectiveChecks = false; - - // Parse and handle the directive - HandleDirective(&directiveContext); - continue; - } - - // otherwise, if we are currently in a skipping mode, then skip tokens - if (IsSkipping(preprocessor)) - { - AdvanceRawToken(preprocessor); - continue; - } - - // otherwise read a token, which may involve macro expansion - return AdvanceToken(preprocessor); - } -} - -// intialize a preprocessor context, using the given sink for errros -static void InitializePreprocessor( - Preprocessor* preprocessor, - DiagnosticSink* sink) -{ - preprocessor->sink = sink; - preprocessor->includeHandler = NULL; - preprocessor->endOfFileToken.type = TokenType::EndOfFile; - preprocessor->endOfFileToken.flags = TokenFlag::AtStartOfLine; -} - -// clean up after an environment -PreprocessorEnvironment::~PreprocessorEnvironment() -{ - for (auto pair : this->macros) - { - DestroyMacro(NULL, pair.Value); - } -} - -// finalize a preprocessor and free any memory still in use -static void FinalizePreprocessor( - Preprocessor* preprocessor) -{ - // Clear out any waiting input streams - PreprocessorInputStream* input = preprocessor->inputStream; - while (input) - { - PreprocessorInputStream* parent = input->parent; - EndInputStream(preprocessor, input); - input = parent; - } - -#if 0 - // clean up any macros that were allocated - for (auto pair : preprocessor->globalEnv.macros) - { - DestroyMacro(preprocessor, pair.Value); - } -#endif -} - -// Add a simple macro definition from a string (e.g., for a -// `-D` option passed on the command line -static void DefineMacro( - Preprocessor* preprocessor, - String const& key, - String const& value) -{ - PathInfo pathInfo = PathInfo::makeCommandLine(); - - PreprocessorMacro* macro = CreateMacro(preprocessor); - - auto sourceManager = preprocessor->getSourceManager(); - - SourceFile* keyFile = sourceManager->createSourceFileWithString(pathInfo, key); - SourceFile* valueFile = sourceManager->createSourceFileWithString(pathInfo, value); - - SourceView* keyView = sourceManager->createSourceView(keyFile, nullptr); - SourceView* valueView = sourceManager->createSourceView(valueFile, nullptr); - - // Use existing `Lexer` to generate a token stream. - Lexer lexer; - lexer.initialize(valueView, GetSink(preprocessor), preprocessor->getNamePool(), sourceManager->getMemoryArena()); - macro->tokens = lexer.lexAllTokens(); - - Name* keyName = preprocessor->getNamePool()->getName(key); - - macro->nameAndLoc.name = keyName; - macro->nameAndLoc.loc = keyView->getRange().begin; - - PreprocessorMacro* oldMacro = NULL; - if (preprocessor->globalEnv.macros.TryGetValue(keyName, oldMacro)) - { - DestroyMacro(preprocessor, oldMacro); - } - - preprocessor->globalEnv.macros[keyName] = macro; -} - -// read the entire input into tokens -static TokenList ReadAllTokens( - Preprocessor* preprocessor) -{ - TokenList tokens; - for (;;) - { - Token token = ReadToken(preprocessor); - - tokens.mTokens.add(token); - - // Note: we include the EOF token in the list, - // since that is expected by the `TokenList` type. - if (token.type == TokenType::EndOfFile) - break; - } - return tokens; -} - -TokenList preprocessSource( - SourceFile* file, - DiagnosticSink* sink, - IncludeHandler* includeHandler, - Dictionary defines, - Linkage* linkage, - Module* parentModule) -{ - Preprocessor preprocessor; - InitializePreprocessor(&preprocessor, sink); - preprocessor.linkage = linkage; - preprocessor.parentModule = parentModule; - - preprocessor.includeHandler = includeHandler; - for (auto p : defines) - { - DefineMacro(&preprocessor, p.Key, p.Value); - } - - SourceManager* sourceManager = linkage->getSourceManager(); - - SourceView* sourceView = sourceManager->createSourceView(file, nullptr); - - // create an initial input stream based on the provided buffer - preprocessor.inputStream = CreateInputStreamForSource(&preprocessor, sourceView); - - TokenList tokens = ReadAllTokens(&preprocessor); - - FinalizePreprocessor(&preprocessor); - - // debugging: build the pre-processed source back together -#if 0 - StringBuilder sb; - for (auto t : tokens) - { - if (t.flags & TokenFlag::AtStartOfLine) - { - sb << "\n"; - } - else if (t.flags & TokenFlag::AfterWhitespace) - { - sb << " "; - } - - sb << t.Content; - } - - String s = sb.ProduceString(); -#endif - - return tokens; -} - -} diff --git a/source/slang/preprocessor.h b/source/slang/preprocessor.h deleted file mode 100644 index 6e8ac1c69..000000000 --- a/source/slang/preprocessor.h +++ /dev/null @@ -1,38 +0,0 @@ -// Preprocessor.h -#ifndef SLANG_PREPROCESSOR_H_INCLUDED -#define SLANG_PREPROCESSOR_H_INCLUDED - -#include "../core/basic.h" -#include "../slang/lexer.h" - -namespace Slang { - -class DiagnosticSink; -class Linkage; -class Module; -class ModuleDecl; - -// Callback interface for the preprocessor to use when looking -// for files in `#include` directives. -struct IncludeHandler -{ - - virtual SlangResult findFile(const String& pathToInclude, - const String& pathIncludedFrom, - PathInfo& pathInfoOut) = 0; - - virtual String simplifyPath(const String& path) = 0; -}; - -// Take a string of source code and preprocess it into a list of tokens. -TokenList preprocessSource( - SourceFile* file, - DiagnosticSink* sink, - IncludeHandler* includeHandler, - Dictionary defines, - Linkage* linkage, - Module* parentModule); - -} // namespace Slang - -#endif diff --git a/source/slang/profile-defs.h b/source/slang/profile-defs.h deleted file mode 100644 index 238621084..000000000 --- a/source/slang/profile-defs.h +++ /dev/null @@ -1,305 +0,0 @@ -// - -// Define all the various language "profiles" we want to support. - -#ifndef LANGUAGE -#define LANGUAGE(TAG, NAME) /* emptry */ -#endif - -#ifndef LANGUAGE_ALIAS -#define LANGUAGE_ALIAS(TAG, NAME) /* empty */ -#endif - -#ifndef PROFILE_FAMILY -#define PROFILE_FAMILY(TAG) /* empty */ -#endif - -#ifndef PROFILE_VERSION -#define PROFILE_VERSION(TAG, FAMILY) /* empty */ -#endif - -#ifndef PROFILE_STAGE -#define PROFILE_STAGE(TAG, NAME, VAL) /* empty */ -#endif - -#ifndef PROFILE_STAGE_ALIAS -#define PROFILE_STAGE_ALIAS(TAG, NAME, VAL) /* empty */ -#endif - - -#ifndef PROFILE -#define PROFILE(TAG, NAME, STAGE, VERSION) /* empty */ -#endif - -#ifndef PROFILE_ALIAS -#define PROFILE_ALIAS(TAG, DEF, NAME) /* empty */ -#endif - -// Source and destination languages - -LANGUAGE(HLSL, hlsl) -LANGUAGE(DXBytecode, dxbc) -LANGUAGE(DXBytecodeAssembly,dxbc_asm) -LANGUAGE(DXIL, dxil) -LANGUAGE(DXILAssembly, dxil_asm) -LANGUAGE(GLSL, glsl) -LANGUAGE(GLSL_ES, glsl_es) -LANGUAGE(GLSL_VK, glsl_vk) -LANGUAGE(SPIRV, spirv) -LANGUAGE(SPIRV_GL, spirv_gl) - -LANGUAGE_ALIAS(GLSL, glsl_gl) -LANGUAGE_ALIAS(SPIRV, spirv_vk) - - -// Pipeline stages to target -PROFILE_STAGE(Vertex, vertex, SLANG_STAGE_VERTEX) -PROFILE_STAGE(Hull, hull, SLANG_STAGE_HULL) -PROFILE_STAGE(Domain, domain, SLANG_STAGE_DOMAIN) -PROFILE_STAGE(Geometry, geometry, SLANG_STAGE_GEOMETRY) -PROFILE_STAGE(Pixel, pixel, SLANG_STAGE_FRAGMENT) -PROFILE_STAGE(Compute, compute, SLANG_STAGE_COMPUTE) - -PROFILE_STAGE(RayGeneration, raygeneration, SLANG_STAGE_RAY_GENERATION) -PROFILE_STAGE(Intersection, intersection, SLANG_STAGE_INTERSECTION) -PROFILE_STAGE(AnyHit, anyhit, SLANG_STAGE_ANY_HIT) -PROFILE_STAGE(ClosestHit, closesthit, SLANG_STAGE_CLOSEST_HIT) -PROFILE_STAGE(Miss, miss, SLANG_STAGE_MISS) -PROFILE_STAGE(Callable, callable, SLANG_STAGE_CALLABLE) - - -// Note: HLSL and Direct3D convention erroneously uses the term "Pixel Shader" -// for the thing that shades *fragments*. Slang strives to treat the more correct -// term "Fragment Shader" as the primary one, but in order to be compatible with -// existing HLSL conventions, we need to treat `pixel` as the official stage -// name and `fragment` as an alias for it here, because the lower-case stage -// names are used to drive output HLSL generation. -// -PROFILE_STAGE_ALIAS(Fragment, fragment, Pixel) - -// Profile families - -PROFILE_FAMILY(DX) -PROFILE_FAMILY(GLSL) - -// Profile versions - - -PROFILE_VERSION(DX_4_0, DX) -PROFILE_VERSION(DX_4_0_Level_9_0, DX) -PROFILE_VERSION(DX_4_0_Level_9_1, DX) -PROFILE_VERSION(DX_4_0_Level_9_3, DX) -PROFILE_VERSION(DX_4_1, DX) -PROFILE_VERSION(DX_5_0, DX) -PROFILE_VERSION(DX_5_1, DX) -PROFILE_VERSION(DX_6_0, DX) -PROFILE_VERSION(DX_6_1, DX) -PROFILE_VERSION(DX_6_2, DX) -PROFILE_VERSION(DX_6_3, DX) - -PROFILE_VERSION(GLSL_110, GLSL) -PROFILE_VERSION(GLSL_120, GLSL) -PROFILE_VERSION(GLSL_130, GLSL) -PROFILE_VERSION(GLSL_140, GLSL) -PROFILE_VERSION(GLSL_150, GLSL) -PROFILE_VERSION(GLSL_330, GLSL) -PROFILE_VERSION(GLSL_400, GLSL) -PROFILE_VERSION(GLSL_410, GLSL) -PROFILE_VERSION(GLSL_420, GLSL) -PROFILE_VERSION(GLSL_430, GLSL) -PROFILE_VERSION(GLSL_440, GLSL) -PROFILE_VERSION(GLSL_450, GLSL) -PROFILE_VERSION(GLSL_460, GLSL) - - -// Specific profiles - -PROFILE(DX_Compute_4_0, cs_4_0, Compute, DX_4_0) -PROFILE(DX_Compute_4_1, cs_4_1, Compute, DX_4_1) -PROFILE(DX_Compute_5_0, cs_5_0, Compute, DX_5_0) -PROFILE(DX_Compute_5_1, cs_5_1, Compute, DX_5_1) -PROFILE(DX_Compute_6_0, cs_6_0, Compute, DX_6_0) -PROFILE(DX_Compute_6_1, cs_6_1, Compute, DX_6_1) -PROFILE(DX_Compute_6_2, cs_6_2, Compute, DX_6_2) -PROFILE(DX_Compute_6_3, cs_6_3, Compute, DX_6_3) - -PROFILE(DX_Domain_5_0, ds_5_0, Domain, DX_5_0) -PROFILE(DX_Domain_5_1, ds_5_1, Domain, DX_5_1) -PROFILE(DX_Domain_6_0, ds_6_0, Domain, DX_6_0) -PROFILE(DX_Domain_6_1, ds_6_1, Domain, DX_6_1) -PROFILE(DX_Domain_6_2, ds_6_2, Domain, DX_6_2) -PROFILE(DX_Domain_6_3, ds_6_3, Domain, DX_6_3) - -PROFILE(DX_Geometry_4_0, gs_4_0, Geometry, DX_4_0) -PROFILE(DX_Geometry_4_1, gs_4_1, Geometry, DX_4_1) -PROFILE(DX_Geometry_5_0, gs_5_0, Geometry, DX_5_0) -PROFILE(DX_Geometry_5_1, gs_5_1, Geometry, DX_5_1) -PROFILE(DX_Geometry_6_0, gs_6_0, Geometry, DX_6_0) -PROFILE(DX_Geometry_6_1, gs_6_1, Geometry, DX_6_1) -PROFILE(DX_Geometry_6_2, gs_6_2, Geometry, DX_6_2) -PROFILE(DX_Geometry_6_3, gs_6_3, Geometry, DX_6_3) - - -PROFILE(DX_Hull_5_0, hs_5_0, Hull, DX_5_0) -PROFILE(DX_Hull_5_1, hs_5_1, Hull, DX_5_1) -PROFILE(DX_Hull_6_0, hs_6_0, Hull, DX_6_0) -PROFILE(DX_Hull_6_1, hs_6_1, Hull, DX_6_1) -PROFILE(DX_Hull_6_2, hs_6_2, Hull, DX_6_2) -PROFILE(DX_Hull_6_3, hs_6_3, Hull, DX_6_3) - - -PROFILE(DX_Fragment_4_0, ps_4_0, Fragment, DX_4_0) -PROFILE(DX_Fragment_4_0_Level_9_0, ps_4_0_level_9_0, Fragment, DX_4_0_Level_9_0) -PROFILE(DX_Fragment_4_0_Level_9_1, ps_4_0_level_9_1, Fragment, DX_4_0_Level_9_1) -PROFILE(DX_Fragment_4_0_Level_9_3, ps_4_0_level_9_3, Fragment, DX_4_0_Level_9_3) -PROFILE(DX_Fragment_4_1, ps_4_1, Fragment, DX_4_1) -PROFILE(DX_Fragment_5_0, ps_5_0, Fragment, DX_5_0) -PROFILE(DX_Fragment_5_1, ps_5_1, Fragment, DX_5_1) -PROFILE(DX_Fragment_6_0, ps_6_0, Fragment, DX_6_0) -PROFILE(DX_Fragment_6_1, ps_6_1, Fragment, DX_6_1) -PROFILE(DX_Fragment_6_2, ps_6_2, Fragment, DX_6_2) -PROFILE(DX_Fragment_6_3, ps_6_3, Fragment, DX_6_3) - - -PROFILE(DX_Vertex_4_0, vs_4_0, Vertex, DX_4_0) -PROFILE(DX_Vertex_4_0_Level_9_0, vs_4_0_level_9_0, Vertex, DX_4_0_Level_9_0) -PROFILE(DX_Vertex_4_0_Level_9_1, vs_4_0_level_9_1, Vertex, DX_4_0_Level_9_1) -PROFILE(DX_Vertex_4_0_Level_9_3, vs_4_0_level_9_3, Vertex, DX_4_0_Level_9_3) -PROFILE(DX_Vertex_4_1, vs_4_1, Vertex, DX_4_1) -PROFILE(DX_Vertex_5_0, vs_5_0, Vertex, DX_5_0) -PROFILE(DX_Vertex_5_1, vs_5_1, Vertex, DX_5_1) -PROFILE(DX_Vertex_6_0, vs_6_0, Vertex, DX_6_0) -PROFILE(DX_Vertex_6_1, vs_6_1, Vertex, DX_6_1) -PROFILE(DX_Vertex_6_2, vs_6_2, Vertex, DX_6_2) -PROFILE(DX_Vertex_6_3, vs_6_3, Vertex, DX_6_3) - -// TODO: consider making `lib_*_*` alias these... -PROFILE(DX_None_4_0, sm_4_0, Unknown, DX_4_0) -PROFILE(DX_None_4_0_Level_9_0, sm_4_0_level_9_0, Unknown, DX_4_0_Level_9_0) -PROFILE(DX_None_4_0_Level_9_1, sm_4_0_level_9_1, Unknown, DX_4_0_Level_9_1) -PROFILE(DX_None_4_0_Level_9_3, sm_4_0_level_9_3, Unknown, DX_4_0_Level_9_3) -PROFILE(DX_None_4_1, sm_4_1, Unknown, DX_4_1) -PROFILE(DX_None_5_0, sm_5_0, Unknown, DX_5_0) -PROFILE(DX_None_5_1, sm_5_1, Unknown, DX_5_1) -PROFILE(DX_None_6_0, sm_6_0, Unknown, DX_6_0) - -// From Shader Model 6.1 on, the dxc compiler recognizes a `lib` profile -// that can be used to compile multiple entry points. We want that -// `lib` name to be the default for how we render these profiles when -// invoking downstream tools, so we use that instead of the `sm_` -// prefix, and then re-introduce the `sm_` variants as aliases. -// -// TODO: We may eventually want a split between how Slang represents -// profiles and their names to users, vs. how it renders them when -// invoking downstream tools, so that the profile name in any -// error messages can be consistent with our `sm_*` naems above -// -PROFILE(DX_Lib_6_1, lib_6_1, Unknown, DX_6_1) -PROFILE(DX_Lib_6_2, lib_6_2, Unknown, DX_6_2) -PROFILE(DX_Lib_6_3, lib_6_3, Unknown, DX_6_3) - -PROFILE_ALIAS(DX_None_6_1, DX_Lib_6_1, sm_6_1) -PROFILE_ALIAS(DX_None_6_2, DX_Lib_6_2, sm_6_2) -PROFILE_ALIAS(DX_None_6_3, DX_Lib_6_3, sm_6_3) - - -// Define all the GLSL profiles - -PROFILE(GLSL_None_110, glsl_110, Unknown, GLSL_110) -PROFILE(GLSL_None_120, glsl_120, Unknown, GLSL_120) -PROFILE(GLSL_None_130, glsl_130, Unknown, GLSL_130) -PROFILE(GLSL_None_140, glsl_140, Unknown, GLSL_140) -PROFILE(GLSL_None_150, glsl_150, Unknown, GLSL_150) -PROFILE(GLSL_None_330, glsl_330, Unknown, GLSL_330) -PROFILE(GLSL_None_400, glsl_400, Unknown, GLSL_400) -PROFILE(GLSL_None_410, glsl_410, Unknown, GLSL_410) -PROFILE(GLSL_None_420, glsl_420, Unknown, GLSL_420) -PROFILE(GLSL_None_430, glsl_430, Unknown, GLSL_430) -PROFILE(GLSL_None_440, glsl_440, Unknown, GLSL_440) -PROFILE(GLSL_None_450, glsl_450, Unknown, GLSL_450) -PROFILE(GLSL_None_460, glsl_460, Unknown, GLSL_460) - -#define P(UPPER, LOWER, VERSION) \ -PROFILE(GLSL_##UPPER##_##VERSION, glsl_##LOWER##_##VERSION, UPPER, GLSL_##VERSION) - -P(Vertex, vertex, 110) -P(Vertex, vertex, 120) -P(Vertex, vertex, 130) -P(Vertex, vertex, 140) -P(Vertex, vertex, 150) -P(Vertex, vertex, 330) -P(Vertex, vertex, 400) -P(Vertex, vertex, 410) -P(Vertex, vertex, 420) -P(Vertex, vertex, 430) -P(Vertex, vertex, 440) -P(Vertex, vertex, 450) - -P(Fragment, fragment, 110) -P(Fragment, fragment, 120) -P(Fragment, fragment, 130) -P(Fragment, fragment, 140) -P(Fragment, fragment, 150) -P(Fragment, fragment, 330) -P(Fragment, fragment, 400) -P(Fragment, fragment, 410) -P(Fragment, fragment, 420) -P(Fragment, fragment, 430) -P(Fragment, fragment, 440) -P(Fragment, fragment, 450) - -P(Geometry, geometry, 150) -P(Geometry, geometry, 330) -P(Geometry, geometry, 400) -P(Geometry, geometry, 410) -P(Geometry, geometry, 420) -P(Geometry, geometry, 430) -P(Geometry, geometry, 440) -P(Geometry, geometry, 450) - -P(Compute, compute, 430) -P(Compute, compute, 440) -P(Compute, compute, 450) - -#undef P -#define P(UPPER, LOWER, STAGE, VERSION) \ -PROFILE(GLSL_##UPPER##_##VERSION, glsl_##LOWER##_##VERSION, STAGE, GLSL_##VERSION) - -P(TessControl, tess_control, Hull, 400) -P(TessControl, tess_control, Hull, 410) -P(TessControl, tess_control, Hull, 420) -P(TessControl, tess_control, Hull, 430) -P(TessControl, tess_control, Hull, 440) -P(TessControl, tess_control, Hull, 450) - -P(TessEval, tess_eval, Domain, 400) -P(TessEval, tess_eval, Domain, 410) -P(TessEval, tess_eval, Domain, 420) -P(TessEval, tess_eval, Domain, 430) -P(TessEval, tess_eval, Domain, 440) -P(TessEval, tess_eval, Domain, 450) - -#undef P - -// Define a default profile for each GLSL stage that just -// uses the latest language version we know of - -PROFILE_ALIAS(GLSL_Vertex, GLSL_Vertex_450, glsl_vertex) -PROFILE_ALIAS(GLSL_Fragment, GLSL_Fragment_450, glsl_fragment) -PROFILE_ALIAS(GLSL_Geometry, GLSL_Geometry_450, glsl_geometry) -PROFILE_ALIAS(GLSL_TessControl, GLSL_TessControl_450, glsl_tess_control) -PROFILE_ALIAS(GLSL_TessEval, GLSL_TessEval_450, glsl_tess_eval) -PROFILE_ALIAS(GLSL_Compute, GLSL_Compute_450, glsl_compute) - -// TODO: define a profile for each GLSL *version* that we can -// use as a catch-all when the stage can be inferred from -// something else - -#undef LANGUAGE -#undef LANGUAGE_ALIAS -#undef PROFILE_FAMILY -#undef PROFILE_VERSION -#undef PROFILE_STAGE -#undef PROFILE_STAGE_ALIAS -#undef PROFILE -#undef PROFILE_ALIAS diff --git a/source/slang/profile.cpp b/source/slang/profile.cpp deleted file mode 100644 index 5f506741f..000000000 --- a/source/slang/profile.cpp +++ /dev/null @@ -1,34 +0,0 @@ -// profile.cpp -#include "profile.h" - -namespace Slang { - -ProfileFamily getProfileFamily(ProfileVersion version) -{ - switch( version ) - { - default: return ProfileFamily::Unknown; - -#define PROFILE_VERSION(TAG, FAMILY) case ProfileVersion::TAG: return ProfileFamily::FAMILY; -#include "profile-defs.h" - } -} - -const char* getStageName(Stage stage) -{ - switch(stage) - { -#define PROFILE_STAGE(ID, NAME, ENUM) \ - case Stage::ID: return #NAME; - -#include "profile-defs.h" - - default: - return nullptr; - } - -} - - - -} diff --git a/source/slang/profile.h b/source/slang/profile.h deleted file mode 100644 index cc142bc2a..000000000 --- a/source/slang/profile.h +++ /dev/null @@ -1,106 +0,0 @@ -#ifndef SLANG_PROFILE_H_INCLUDED -#define SLANG_PROFILE_H_INCLUDED - -#include "../core/basic.h" -#include "../../slang.h" - -namespace Slang -{ - // Flavors of translation unit - enum class SourceLanguage : SlangSourceLanguage - { - Unknown = SLANG_SOURCE_LANGUAGE_UNKNOWN, // should not occur - Slang = SLANG_SOURCE_LANGUAGE_SLANG, - HLSL = SLANG_SOURCE_LANGUAGE_HLSL, - GLSL = SLANG_SOURCE_LANGUAGE_GLSL, - }; - - // TODO(tfoley): This should merge with the above... - enum class Language - { - Unknown, -#define LANGUAGE(TAG, NAME) TAG, -#include "profile-defs.h" - }; - - enum class ProfileFamily - { - Unknown, -#define PROFILE_FAMILY(TAG) TAG, -#include "profile-defs.h" - }; - - enum class ProfileVersion - { - Unknown, -#define PROFILE_VERSION(TAG, FAMILY) TAG, -#include "profile-defs.h" - }; - - enum class Stage : SlangStage - { - Unknown = SLANG_STAGE_NONE, -#define PROFILE_STAGE(TAG, NAME, VAL) TAG = VAL, -#define PROFILE_STAGE_ALIAS(TAG, NAME, VAL) TAG = VAL, -#include "profile-defs.h" - }; - - const char* getStageName(Stage stage); - - ProfileFamily getProfileFamily(ProfileVersion version); - - struct Profile - { - typedef uint32_t RawVal; - enum RawEnum : RawVal - { - Unknown, - -#define PROFILE(TAG, NAME, STAGE, VERSION) TAG = (uint32_t(ProfileVersion::VERSION) << 16) | uint32_t(Stage::STAGE), -#define PROFILE_ALIAS(TAG, DEF, NAME) TAG = DEF, -#include "profile-defs.h" - }; - - Profile() {} - Profile(RawEnum raw) - : raw(raw) - {} - explicit Profile(RawVal raw) - : raw(raw) - {} - explicit Profile(Stage stage) - { - setStage(stage); - } - explicit Profile(ProfileVersion version) - { - setVersion(version); - } - - bool operator==(Profile const& other) const { return raw == other.raw; } - bool operator!=(Profile const& other) const { return raw != other.raw; } - - Stage GetStage() const { return Stage(uint32_t(raw) & 0xFFFF); } - void setStage(Stage stage) - { - raw = (raw & ~0xFFFF) | uint32_t(stage); - } - - ProfileVersion GetVersion() const { return ProfileVersion((uint32_t(raw) >> 16) & 0xFFFF); } - void setVersion(ProfileVersion version) - { - raw = (raw & 0x0000FFFF) | (uint32_t(version) << 16); - } - - ProfileFamily getFamily() const { return getProfileFamily(GetVersion()); } - - static Profile LookUp(char const* name); - char const* getName(); - - RawVal raw = Unknown; - }; - - Stage findStageByName(String const& name); -} - -#endif diff --git a/source/slang/reflection.cpp b/source/slang/reflection.cpp deleted file mode 100644 index 4d13052f6..000000000 --- a/source/slang/reflection.cpp +++ /dev/null @@ -1,1451 +0,0 @@ -// reflection.cpp -#include "reflection.h" - -#include "compiler.h" -#include "type-layout.h" -#include "syntax.h" -#include - -// Don't signal errors for stuff we don't implement here, -// and instead just try to return things defensively -// -// Slang developers can switch this when debugging. -#define SLANG_REFLECTION_UNEXPECTED() do {} while(0) - -// Implementation to back public-facing reflection API - -using namespace Slang; - - -// Conversion routines to help with strongly-typed reflection API -static inline Session* convert(SlangSession* session) -{ - return (Session*)session; -} - -static inline UserDefinedAttribute* convert(SlangReflectionUserAttribute* attrib) -{ - return (UserDefinedAttribute*)attrib; -} -static inline SlangReflectionUserAttribute* convert(UserDefinedAttribute* attrib) -{ - return (SlangReflectionUserAttribute*)attrib; -} -static inline Type* convert(SlangReflectionType* type) -{ - return (Type*) type; -} - -static inline SlangReflectionType* convert(Type* type) -{ - return (SlangReflectionType*) type; -} - -static inline TypeLayout* convert(SlangReflectionTypeLayout* type) -{ - return (TypeLayout*) type; -} - -static inline SlangReflectionTypeLayout* convert(TypeLayout* type) -{ - return (SlangReflectionTypeLayout*) type; -} - -static inline GenericParamLayout* convert(SlangReflectionTypeParameter * typeParam) -{ - return (GenericParamLayout*)typeParam; -} - -static inline VarDeclBase* convert(SlangReflectionVariable* var) -{ - return (VarDeclBase*) var; -} - -static inline SlangReflectionVariable* convert(VarDeclBase* var) -{ - return (SlangReflectionVariable*) var; -} - -static inline VarLayout* convert(SlangReflectionVariableLayout* var) -{ - return (VarLayout*) var; -} - -static inline SlangReflectionVariableLayout* convert(VarLayout* var) -{ - return (SlangReflectionVariableLayout*) var; -} - -static inline EntryPointLayout* convert(SlangReflectionEntryPoint* entryPoint) -{ - return (EntryPointLayout*) entryPoint; -} - -static inline SlangReflectionEntryPoint* convert(EntryPointLayout* entryPoint) -{ - return (SlangReflectionEntryPoint*) entryPoint; -} - - -static inline ProgramLayout* convert(SlangReflection* program) -{ - return (ProgramLayout*) program; -} - -static inline SlangReflection* convert(ProgramLayout* program) -{ - return (SlangReflection*) program; -} - -// user attaribute - -unsigned int getUserAttributeCount(Decl* decl) -{ - unsigned int count = 0; - for (auto x : decl->GetModifiersOfType()) - { - SLANG_UNUSED(x); - count++; - } - return count; -} - -SlangReflectionUserAttribute* findUserAttributeByName(Session* session, Decl* decl, const char* name) -{ - auto nameObj = session->tryGetNameObj(name); - for (auto x : decl->GetModifiersOfType()) - { - if (x->name == nameObj) - return (SlangReflectionUserAttribute*)(x); - } - return nullptr; -} - -SlangReflectionUserAttribute* getUserAttributeByIndex(Decl* decl, unsigned int index) -{ - unsigned int id = 0; - for (auto x : decl->GetModifiersOfType()) - { - if (id == index) - return convert(x); - id++; - } - return nullptr; -} - -SLANG_API char const* spReflectionUserAttribute_GetName(SlangReflectionUserAttribute* attrib) -{ - auto userAttr = convert(attrib); - if (!userAttr) return nullptr; - return userAttr->getName()->text.getBuffer(); -} -SLANG_API unsigned int spReflectionUserAttribute_GetArgumentCount(SlangReflectionUserAttribute* attrib) -{ - auto userAttr = convert(attrib); - if (!userAttr) return 0; - return (unsigned int)userAttr->args.getCount(); -} -SlangReflectionType* spReflectionUserAttribute_GetArgumentType(SlangReflectionUserAttribute* attrib, unsigned int index) -{ - auto userAttr = convert(attrib); - if (!userAttr) return nullptr; - return convert(userAttr->args[index]->type.type.Ptr()); -} -SLANG_API SlangResult spReflectionUserAttribute_GetArgumentValueInt(SlangReflectionUserAttribute* attrib, unsigned int index, int * rs) -{ - auto userAttr = convert(attrib); - if (!userAttr) return SLANG_ERROR_INVALID_PARAMETER; - if (index >= (unsigned int)userAttr->args.getCount()) return SLANG_ERROR_INVALID_PARAMETER; - RefPtr val; - if (userAttr->intArgVals.TryGetValue(index, val)) - { - *rs = (int)as(val)->value; - return 0; - } - return SLANG_ERROR_INVALID_PARAMETER; -} -SLANG_API SlangResult spReflectionUserAttribute_GetArgumentValueFloat(SlangReflectionUserAttribute* attrib, unsigned int index, float * rs) -{ - auto userAttr = convert(attrib); - if (!userAttr) return SLANG_ERROR_INVALID_PARAMETER; - if (index >= (unsigned int)userAttr->args.getCount()) return SLANG_ERROR_INVALID_PARAMETER; - if (auto cexpr = as(userAttr->args[index])) - { - *rs = (float)cexpr->value; - return 0; - } - return SLANG_ERROR_INVALID_PARAMETER; -} -SLANG_API const char* spReflectionUserAttribute_GetArgumentValueString(SlangReflectionUserAttribute* attrib, unsigned int index, size_t* bufLen) -{ - auto userAttr = convert(attrib); - if (!userAttr) return nullptr; - if (index >= (unsigned int)userAttr->args.getCount()) return nullptr; - if (auto cexpr = as(userAttr->args[index])) - { - if (bufLen) - *bufLen = cexpr->token.Content.size(); - return cexpr->token.Content.begin(); - } - return nullptr; -} - - - -// type Reflection - - -SLANG_API SlangTypeKind spReflectionType_GetKind(SlangReflectionType* inType) -{ - auto type = convert(inType); - if(!type) return SLANG_TYPE_KIND_NONE; - - // TODO(tfoley: Don't emit the same type more than once... - - if (auto basicType = as(type)) - { - return SLANG_TYPE_KIND_SCALAR; - } - else if (auto vectorType = as(type)) - { - return SLANG_TYPE_KIND_VECTOR; - } - else if (auto matrixType = as(type)) - { - return SLANG_TYPE_KIND_MATRIX; - } - else if (auto parameterBlockType = as(type)) - { - return SLANG_TYPE_KIND_PARAMETER_BLOCK; - } - else if (auto constantBufferType = as(type)) - { - return SLANG_TYPE_KIND_CONSTANT_BUFFER; - } - else if( auto streamOutputType = as(type) ) - { - return SLANG_TYPE_KIND_OUTPUT_STREAM; - } - else if (as(type)) - { - return SLANG_TYPE_KIND_TEXTURE_BUFFER; - } - else if (as(type)) - { - return SLANG_TYPE_KIND_SHADER_STORAGE_BUFFER; - } - else if (auto samplerStateType = as(type)) - { - return SLANG_TYPE_KIND_SAMPLER_STATE; - } - else if (auto textureType = as(type)) - { - return SLANG_TYPE_KIND_RESOURCE; - } - // TODO: need a better way to handle this stuff... -#define CASE(TYPE) \ - else if(as(type)) do { \ - return SLANG_TYPE_KIND_RESOURCE; \ - } while(0) - - CASE(HLSLStructuredBufferType); - CASE(HLSLRWStructuredBufferType); - CASE(HLSLRasterizerOrderedStructuredBufferType); - CASE(HLSLAppendStructuredBufferType); - CASE(HLSLConsumeStructuredBufferType); - CASE(HLSLByteAddressBufferType); - CASE(HLSLRWByteAddressBufferType); - CASE(HLSLRasterizerOrderedByteAddressBufferType); - CASE(UntypedBufferResourceType); -#undef CASE - - else if (auto arrayType = as(type)) - { - return SLANG_TYPE_KIND_ARRAY; - } - else if( auto declRefType = as(type) ) - { - const auto& declRef = declRefType->declRef; - if(declRef.is() ) - { - return SLANG_TYPE_KIND_STRUCT; - } - else if (declRef.is()) - { - return SLANG_TYPE_KIND_GENERIC_TYPE_PARAMETER; - } - else if (declRef.is()) - { - return SLANG_TYPE_KIND_INTERFACE; - } - } - else if( auto specializedType = as(type) ) - { - return SLANG_TYPE_KIND_SPECIALIZED; - } - else if (auto errorType = as(type)) - { - // This means we saw a type we didn't understand in the user's code - return SLANG_TYPE_KIND_NONE; - } - - SLANG_REFLECTION_UNEXPECTED(); - return SLANG_TYPE_KIND_NONE; -} - -SLANG_API unsigned int spReflectionType_GetFieldCount(SlangReflectionType* inType) -{ - auto type = convert(inType); - if(!type) return 0; - - // TODO: maybe filter based on kind - - if(auto declRefType = as(type)) - { - auto declRef = declRefType->declRef; - if( auto structDeclRef = declRef.as()) - { - return GetFields(structDeclRef).Count(); - } - } - - return 0; -} - -SLANG_API SlangReflectionVariable* spReflectionType_GetFieldByIndex(SlangReflectionType* inType, unsigned index) -{ - auto type = convert(inType); - if(!type) return nullptr; - - // TODO: maybe filter based on kind - - if(auto declRefType = as(type)) - { - auto declRef = declRefType->declRef; - if( auto structDeclRef = declRef.as()) - { - auto fieldDeclRef = GetFields(structDeclRef).ToArray()[index]; - return (SlangReflectionVariable*) fieldDeclRef.getDecl(); - } - } - - return nullptr; -} - -SLANG_API size_t spReflectionType_GetElementCount(SlangReflectionType* inType) -{ - auto type = convert(inType); - if(!type) return 0; - - if(auto arrayType = as(type)) - { - return arrayType->ArrayLength ? (size_t) GetIntVal(arrayType->ArrayLength) : 0; - } - else if( auto vectorType = as(type)) - { - return (size_t) GetIntVal(vectorType->elementCount); - } - - return 0; -} - -SLANG_API SlangReflectionType* spReflectionType_GetElementType(SlangReflectionType* inType) -{ - auto type = convert(inType); - if(!type) return nullptr; - - if(auto arrayType = as(type)) - { - return (SlangReflectionType*) arrayType->baseType.Ptr(); - } - else if( auto constantBufferType = as(type)) - { - return convert(constantBufferType->elementType.Ptr()); - } - else if( auto vectorType = as(type)) - { - return convert(vectorType->elementType.Ptr()); - } - else if( auto matrixType = as(type)) - { - return convert(matrixType->getElementType()); - } - - return nullptr; -} - -SLANG_API unsigned int spReflectionType_GetRowCount(SlangReflectionType* inType) -{ - auto type = convert(inType); - if(!type) return 0; - - if(auto matrixType = as(type)) - { - return (unsigned int) GetIntVal(matrixType->getRowCount()); - } - else if(auto vectorType = as(type)) - { - return 1; - } - else if( auto basicType = as(type) ) - { - return 1; - } - - return 0; -} - -SLANG_API unsigned int spReflectionType_GetColumnCount(SlangReflectionType* inType) -{ - auto type = convert(inType); - if(!type) return 0; - - if(auto matrixType = as(type)) - { - return (unsigned int) GetIntVal(matrixType->getColumnCount()); - } - else if(auto vectorType = as(type)) - { - return (unsigned int) GetIntVal(vectorType->elementCount); - } - else if( auto basicType = as(type) ) - { - return 1; - } - - return 0; -} - -SLANG_API SlangScalarType spReflectionType_GetScalarType(SlangReflectionType* inType) -{ - auto type = convert(inType); - if(!type) return 0; - - if(auto matrixType = as(type)) - { - type = matrixType->getElementType(); - } - else if(auto vectorType = as(type)) - { - type = vectorType->elementType.Ptr(); - } - - if(auto basicType = as(type)) - { - switch (basicType->baseType) - { -#define CASE(BASE, TAG) \ - case BaseType::BASE: return SLANG_SCALAR_TYPE_##TAG - - CASE(Void, VOID); - CASE(Bool, BOOL); - CASE(Int8, INT8); - CASE(Int16, INT16); - CASE(Int, INT32); - CASE(Int64, INT64); - CASE(UInt8, UINT8); - CASE(UInt16, UINT16); - CASE(UInt, UINT32); - CASE(UInt64, UINT64); - CASE(Half, FLOAT16); - CASE(Float, FLOAT32); - CASE(Double, FLOAT64); - -#undef CASE - - default: - SLANG_REFLECTION_UNEXPECTED(); - return SLANG_SCALAR_TYPE_NONE; - break; - } - } - - return SLANG_SCALAR_TYPE_NONE; -} - -SLANG_API unsigned int spReflectionType_GetUserAttributeCount(SlangReflectionType* inType) -{ - auto type = convert(inType); - if (!type) return 0; - if (auto declRefType = as(type)) - { - return getUserAttributeCount(declRefType->declRef.getDecl()); - } - return 0; -} -SLANG_API SlangReflectionUserAttribute* spReflectionType_GetUserAttribute(SlangReflectionType* inType, unsigned int index) -{ - auto type = convert(inType); - if (!type) return 0; - if (auto declRefType = as(type)) - { - return getUserAttributeByIndex(declRefType->declRef.getDecl(), index); - } - return 0; -} -SLANG_API SlangReflectionUserAttribute* spReflectionType_FindUserAttributeByName(SlangReflectionType* inType, char const* name) -{ - auto type = convert(inType); - if (!type) return 0; - if (auto declRefType = as(type)) - { - return findUserAttributeByName(declRefType->getSession(), declRefType->declRef.getDecl(), name); - } - return 0; -} - -SLANG_API SlangResourceShape spReflectionType_GetResourceShape(SlangReflectionType* inType) -{ - auto type = convert(inType); - if(!type) return 0; - - while(auto arrayType = as(type)) - { - type = arrayType->baseType.Ptr(); - } - - if(auto textureType = as(type)) - { - return textureType->getShape(); - } - - // TODO: need a better way to handle this stuff... -#define CASE(TYPE, SHAPE, ACCESS) \ - else if(as(type)) do { \ - return SHAPE; \ - } while(0) - - CASE(HLSLStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_READ); - CASE(HLSLRWStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE); - CASE(HLSLRasterizerOrderedStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_RASTER_ORDERED); - CASE(HLSLAppendStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_APPEND); - CASE(HLSLConsumeStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_CONSUME); - CASE(HLSLByteAddressBufferType, SLANG_BYTE_ADDRESS_BUFFER, SLANG_RESOURCE_ACCESS_READ); - CASE(HLSLRWByteAddressBufferType, SLANG_BYTE_ADDRESS_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE); - CASE(HLSLRasterizerOrderedByteAddressBufferType, SLANG_BYTE_ADDRESS_BUFFER, SLANG_RESOURCE_ACCESS_RASTER_ORDERED); - CASE(RaytracingAccelerationStructureType, SLANG_ACCELERATION_STRUCTURE, SLANG_RESOURCE_ACCESS_READ); - CASE(UntypedBufferResourceType, SLANG_BYTE_ADDRESS_BUFFER, SLANG_RESOURCE_ACCESS_READ); -#undef CASE - - return SLANG_RESOURCE_NONE; -} - -SLANG_API SlangResourceAccess spReflectionType_GetResourceAccess(SlangReflectionType* inType) -{ - auto type = convert(inType); - if(!type) return 0; - - while(auto arrayType = as(type)) - { - type = arrayType->baseType.Ptr(); - } - - if(auto textureType = as(type)) - { - return textureType->getAccess(); - } - - // TODO: need a better way to handle this stuff... -#define CASE(TYPE, SHAPE, ACCESS) \ - else if(as(type)) do { \ - return ACCESS; \ - } while(0) - - CASE(HLSLStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_READ); - CASE(HLSLRWStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE); - CASE(HLSLRasterizerOrderedStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_RASTER_ORDERED); - CASE(HLSLAppendStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_APPEND); - CASE(HLSLConsumeStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_CONSUME); - CASE(HLSLByteAddressBufferType, SLANG_BYTE_ADDRESS_BUFFER, SLANG_RESOURCE_ACCESS_READ); - CASE(HLSLRWByteAddressBufferType, SLANG_BYTE_ADDRESS_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE); - CASE(HLSLRasterizerOrderedByteAddressBufferType, SLANG_BYTE_ADDRESS_BUFFER, SLANG_RESOURCE_ACCESS_RASTER_ORDERED); - CASE(UntypedBufferResourceType, SLANG_BYTE_ADDRESS_BUFFER, SLANG_RESOURCE_ACCESS_READ); - - // This isn't entirely accurate, but I can live with it for now - CASE(GLSLShaderStorageBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE); -#undef CASE - - return SLANG_RESOURCE_ACCESS_NONE; -} - -SLANG_API char const* spReflectionType_GetName(SlangReflectionType* inType) -{ - auto type = convert(inType); - - if( auto declRefType = as(type) ) - { - auto declRef = declRefType->declRef; - - // Don't return a name for auto-generated anonymous types - // that represent `cbuffer` members, etc. - auto decl = declRef.getDecl(); - if(decl->HasModifier()) - return nullptr; - - return getText(declRef.GetName()).begin(); - } - - return nullptr; -} - -SLANG_API SlangReflectionType * spReflection_FindTypeByName(SlangReflection * reflection, char const * name) -{ - auto programLayout = convert(reflection); - auto program = programLayout->getProgram(); - - // TODO: We should extend this API to support getting error messages - // when type lookup fails. - // - Slang::DiagnosticSink sink; - - sink.sourceManager = programLayout->getTargetReq()->getLinkage()->getSourceManager();; - RefPtr result = program->getTypeFromString(name, &sink); - return (SlangReflectionType*)result.Ptr(); -} - -SLANG_API SlangReflectionTypeLayout* spReflection_GetTypeLayout( - SlangReflection* reflection, - SlangReflectionType* inType, - SlangLayoutRules /*rules*/) -{ - auto context = convert(reflection); - auto type = convert(inType); - auto targetReq = context->getTargetReq(); - auto layoutContext = getInitialLayoutContextForTarget(targetReq, context); - RefPtr result; - if (targetReq->getTypeLayouts().TryGetValue(type, result)) - return (SlangReflectionTypeLayout*)result.Ptr(); - result = createTypeLayout(layoutContext, type); - targetReq->getTypeLayouts()[type] = result; - return (SlangReflectionTypeLayout*)result.Ptr(); -} - -SLANG_API SlangReflectionType* spReflectionType_GetResourceResultType(SlangReflectionType* inType) -{ - auto type = convert(inType); - if(!type) return nullptr; - - while(auto arrayType = as(type)) - { - type = arrayType->baseType.Ptr(); - } - - if (auto textureType = as(type)) - { - return convert(textureType->elementType.Ptr()); - } - - // TODO: need a better way to handle this stuff... -#define CASE(TYPE, SHAPE, ACCESS) \ - else if(as(type)) do { \ - return convert(as(type)->elementType.Ptr()); \ - } while(0) - - // TODO: structured buffer needs to expose type layout! - - CASE(HLSLStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_READ); - CASE(HLSLRWStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE); - CASE(HLSLRasterizerOrderedStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_RASTER_ORDERED); - CASE(HLSLAppendStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_APPEND); - CASE(HLSLConsumeStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_CONSUME); -#undef CASE - - return nullptr; -} - -// type Layout Reflection - -SLANG_API SlangReflectionType* spReflectionTypeLayout_GetType(SlangReflectionTypeLayout* inTypeLayout) -{ - auto typeLayout = convert(inTypeLayout); - if(!typeLayout) return nullptr; - - return (SlangReflectionType*) typeLayout->type.Ptr(); -} - -namespace -{ - static size_t getReflectionSize(LayoutSize size) - { - if(size.isFinite()) - return size.getFiniteValue(); - - return SLANG_UNBOUNDED_SIZE; - } -} - -SLANG_API size_t spReflectionTypeLayout_GetSize(SlangReflectionTypeLayout* inTypeLayout, SlangParameterCategory category) -{ - auto typeLayout = convert(inTypeLayout); - if(!typeLayout) return 0; - - auto info = typeLayout->FindResourceInfo(LayoutResourceKind(category)); - if(!info) return 0; - - return getReflectionSize(info->count); -} - -SLANG_API SlangReflectionVariableLayout* spReflectionTypeLayout_GetFieldByIndex(SlangReflectionTypeLayout* inTypeLayout, unsigned index) -{ - auto typeLayout = convert(inTypeLayout); - if(!typeLayout) return nullptr; - - if(auto structTypeLayout = as(typeLayout)) - { - return (SlangReflectionVariableLayout*) structTypeLayout->fields[index].Ptr(); - } - - return nullptr; -} - -SLANG_API size_t spReflectionTypeLayout_GetElementStride(SlangReflectionTypeLayout* inTypeLayout, SlangParameterCategory category) -{ - auto typeLayout = convert(inTypeLayout); - if(!typeLayout) return 0; - - if( auto arrayTypeLayout = as(typeLayout)) - { - switch (category) - { - // We store the stride explicitly for the uniform case - case SLANG_PARAMETER_CATEGORY_UNIFORM: - return arrayTypeLayout->uniformStride; - - // For most other cases (resource registers), the "stride" - // of an array is simply the number of resources (if any) - // consumed by its element type. - default: - { - auto elementTypeLayout = arrayTypeLayout->elementTypeLayout; - auto info = elementTypeLayout->FindResourceInfo(LayoutResourceKind(category)); - if(!info) return 0; - return getReflectionSize(info->count); - } - - // An important special case, though, is Vulkan descriptor-table slots, - // where an entire array will use a single `binding`, so that the - // effective stride is zero: - case SLANG_PARAMETER_CATEGORY_DESCRIPTOR_TABLE_SLOT: - return 0; - } - } - - return 0; -} - -SLANG_API SlangReflectionTypeLayout* spReflectionTypeLayout_GetElementTypeLayout(SlangReflectionTypeLayout* inTypeLayout) -{ - auto typeLayout = convert(inTypeLayout); - if(!typeLayout) return nullptr; - - if( auto arrayTypeLayout = as(typeLayout)) - { - return (SlangReflectionTypeLayout*) arrayTypeLayout->elementTypeLayout.Ptr(); - } - else if( auto constantBufferTypeLayout = as(typeLayout)) - { - return convert(constantBufferTypeLayout->offsetElementTypeLayout.Ptr()); - } - else if( auto structuredBufferTypeLayout = as(typeLayout)) - { - return convert(structuredBufferTypeLayout->elementTypeLayout.Ptr()); - } - else if( auto specializedTypeLayout = as(typeLayout) ) - { - return convert(specializedTypeLayout->baseTypeLayout.Ptr()); - } - - return nullptr; -} - -SLANG_API SlangReflectionVariableLayout* spReflectionTypeLayout_GetElementVarLayout(SlangReflectionTypeLayout* inTypeLayout) -{ - auto typeLayout = convert(inTypeLayout); - if(!typeLayout) return nullptr; - - if( auto constantBufferTypeLayout = as(typeLayout)) - { - return convert(constantBufferTypeLayout->elementVarLayout.Ptr()); - } - - return nullptr; -} - -static SlangParameterCategory getParameterCategory( - LayoutResourceKind kind) -{ - return SlangParameterCategory(kind); -} - -static SlangParameterCategory getParameterCategory( - TypeLayout* typeLayout) -{ - auto resourceInfoCount = typeLayout->resourceInfos.getCount(); - if(resourceInfoCount == 1) - { - return getParameterCategory(typeLayout->resourceInfos[0].kind); - } - else if(resourceInfoCount == 0) - { - // TODO: can this ever happen? - return SLANG_PARAMETER_CATEGORY_NONE; - } - return SLANG_PARAMETER_CATEGORY_MIXED; -} - -static TypeLayout* maybeGetContainerLayout(TypeLayout* typeLayout) -{ - if (auto parameterGroupTypeLayout = as(typeLayout)) - { - auto containerTypeLayout = parameterGroupTypeLayout->containerVarLayout->typeLayout; - if (containerTypeLayout->resourceInfos.getCount() != 0) - { - return containerTypeLayout; - } - } - - return typeLayout; -} - -SLANG_API SlangParameterCategory spReflectionTypeLayout_GetParameterCategory(SlangReflectionTypeLayout* inTypeLayout) -{ - auto typeLayout = convert(inTypeLayout); - if(!typeLayout) return SLANG_PARAMETER_CATEGORY_NONE; - - typeLayout = maybeGetContainerLayout(typeLayout); - - return getParameterCategory(typeLayout); -} - -SLANG_API unsigned spReflectionTypeLayout_GetCategoryCount(SlangReflectionTypeLayout* inTypeLayout) -{ - auto typeLayout = convert(inTypeLayout); - if(!typeLayout) return 0; - - typeLayout = maybeGetContainerLayout(typeLayout); - - return (unsigned) typeLayout->resourceInfos.getCount(); -} - -SLANG_API SlangParameterCategory spReflectionTypeLayout_GetCategoryByIndex(SlangReflectionTypeLayout* inTypeLayout, unsigned index) -{ - auto typeLayout = convert(inTypeLayout); - if(!typeLayout) return SLANG_PARAMETER_CATEGORY_NONE; - - typeLayout = maybeGetContainerLayout(typeLayout); - - return typeLayout->resourceInfos[index].kind; -} - -SLANG_API SlangMatrixLayoutMode spReflectionTypeLayout_GetMatrixLayoutMode(SlangReflectionTypeLayout* inTypeLayout) -{ - auto typeLayout = convert(inTypeLayout); - if(!typeLayout) return SLANG_MATRIX_LAYOUT_MODE_UNKNOWN; - - if( auto matrixLayout = as(typeLayout) ) - { - return matrixLayout->mode; - } - else - { - return SLANG_MATRIX_LAYOUT_MODE_UNKNOWN; - } - -} - -SLANG_API int spReflectionTypeLayout_getGenericParamIndex(SlangReflectionTypeLayout* inTypeLayout) -{ - auto typeLayout = convert(inTypeLayout); - if(!typeLayout) return -1; - - if(auto genericParamTypeLayout = as(typeLayout)) - { - return genericParamTypeLayout->paramIndex; - } - else - { - return -1; - } -} - -SLANG_API SlangReflectionTypeLayout* spReflectionTypeLayout_getPendingDataTypeLayout(SlangReflectionTypeLayout* inTypeLayout) -{ - auto typeLayout = convert(inTypeLayout); - if(!typeLayout) return nullptr; - - auto pendingDataTypeLayout = typeLayout->pendingDataTypeLayout.Ptr(); - return convert(pendingDataTypeLayout); -} - -SLANG_API SlangReflectionVariableLayout* spReflectionVariableLayout_getPendingDataLayout(SlangReflectionVariableLayout* inVarLayout) -{ - auto varLayout = convert(inVarLayout); - if(!varLayout) return nullptr; - - auto pendingDataLayout = varLayout->pendingVarLayout.Ptr(); - return convert(pendingDataLayout); -} - -SLANG_API SlangReflectionVariableLayout* spReflectionTypeLayout_getSpecializedTypePendingDataVarLayout(SlangReflectionTypeLayout* inTypeLayout) -{ - auto typeLayout = convert(inTypeLayout); - if(!typeLayout) return nullptr; - - if( auto specializedTypeLayout = as(typeLayout) ) - { - auto pendingDataVarLayout = specializedTypeLayout->pendingDataVarLayout.Ptr(); - return convert(pendingDataVarLayout); - } - else - { - return nullptr; - } -} - - -// Variable Reflection - -SLANG_API char const* spReflectionVariable_GetName(SlangReflectionVariable* inVar) -{ - auto var = convert(inVar); - if(!var) return nullptr; - - // If the variable is one that has an "external" name that is supposed - // to be exposed for reflection, then report it here - if(auto reflectionNameMod = var->FindModifier()) - return getText(reflectionNameMod->nameAndLoc.name).getBuffer(); - - return getText(var->getName()).getBuffer(); -} - -SLANG_API SlangReflectionType* spReflectionVariable_GetType(SlangReflectionVariable* inVar) -{ - auto var = convert(inVar); - if(!var) return nullptr; - - return convert(var->getType()); -} - -SLANG_API SlangReflectionModifier* spReflectionVariable_FindModifier(SlangReflectionVariable* inVar, SlangModifierID modifierID) -{ - auto var = convert(inVar); - if(!var) return nullptr; - - Modifier* modifier = nullptr; - switch( modifierID ) - { - case SLANG_MODIFIER_SHARED: - modifier = var->FindModifier(); - break; - - default: - return nullptr; - } - - return (SlangReflectionModifier*) modifier; -} - -SLANG_API unsigned int spReflectionVariable_GetUserAttributeCount(SlangReflectionVariable* inVar) -{ - auto varDecl = convert(inVar); - if (!varDecl) return 0; - return getUserAttributeCount(varDecl); -} -SLANG_API SlangReflectionUserAttribute* spReflectionVariable_GetUserAttribute(SlangReflectionVariable* inVar, unsigned int index) -{ - auto varDecl = convert(inVar); - if (!varDecl) return 0; - return getUserAttributeByIndex(varDecl, index); -} -SLANG_API SlangReflectionUserAttribute* spReflectionVariable_FindUserAttributeByName(SlangReflectionVariable* inVar, SlangSession* session, char const* name) -{ - auto varDecl = convert(inVar); - if (!varDecl) return 0; - return findUserAttributeByName(convert(session), varDecl, name); -} - -// Variable Layout Reflection - -SLANG_API SlangReflectionVariable* spReflectionVariableLayout_GetVariable(SlangReflectionVariableLayout* inVarLayout) -{ - auto varLayout = convert(inVarLayout); - if(!varLayout) return nullptr; - - return convert(varLayout->varDecl.getDecl()); -} - -SLANG_API SlangReflectionTypeLayout* spReflectionVariableLayout_GetTypeLayout(SlangReflectionVariableLayout* inVarLayout) -{ - auto varLayout = convert(inVarLayout); - if(!varLayout) return nullptr; - - return convert(varLayout->getTypeLayout()); -} - -namespace Slang -{ - // Attempt "do what I mean" remapping from the parameter category the user asked about, - // over to a parameter category that they might have meant. - static SlangParameterCategory maybeRemapParameterCategory( - TypeLayout* typeLayout, - SlangParameterCategory category) - { - // Do we have an entry for the category they asked about? Then use that. - if (typeLayout->FindResourceInfo(LayoutResourceKind(category))) - return category; - - // Do we have an entry for the `DescriptorTableSlot` category? - if (typeLayout->FindResourceInfo(LayoutResourceKind::DescriptorTableSlot)) - { - // Is the category they were asking about one that makes sense for the type - // of this variable? - Type* type = typeLayout->getType(); - while (auto arrayType = as(type)) - type = arrayType->baseType; - switch (spReflectionType_GetKind(convert(type))) - { - case SLANG_TYPE_KIND_CONSTANT_BUFFER: - if(category == SLANG_PARAMETER_CATEGORY_CONSTANT_BUFFER) - return SLANG_PARAMETER_CATEGORY_DESCRIPTOR_TABLE_SLOT; - break; - - case SLANG_TYPE_KIND_RESOURCE: - if(category == SLANG_PARAMETER_CATEGORY_SHADER_RESOURCE) - return SLANG_PARAMETER_CATEGORY_DESCRIPTOR_TABLE_SLOT; - break; - - case SLANG_TYPE_KIND_SAMPLER_STATE: - if(category == SLANG_PARAMETER_CATEGORY_SAMPLER_STATE) - return SLANG_PARAMETER_CATEGORY_DESCRIPTOR_TABLE_SLOT; - break; - - // TODO: implement more helpers here - - default: - break; - } - } - - return category; - } -} - -SLANG_API size_t spReflectionVariableLayout_GetOffset(SlangReflectionVariableLayout* inVarLayout, SlangParameterCategory category) -{ - auto varLayout = convert(inVarLayout); - if(!varLayout) return 0; - - auto info = varLayout->FindResourceInfo(LayoutResourceKind(category)); - - if (!info) - { - // No match with requested category? Try again with one they might have meant... - category = maybeRemapParameterCategory(varLayout->getTypeLayout(), category); - info = varLayout->FindResourceInfo(LayoutResourceKind(category)); - } - - if(!info) return 0; - - return info->index; -} - -SLANG_API size_t spReflectionVariableLayout_GetSpace(SlangReflectionVariableLayout* inVarLayout, SlangParameterCategory category) -{ - auto varLayout = convert(inVarLayout); - if(!varLayout) return 0; - - - auto info = varLayout->FindResourceInfo(LayoutResourceKind(category)); - if (!info) - { - // No match with requested category? Try again with one they might have meant... - category = maybeRemapParameterCategory(varLayout->getTypeLayout(), category); - info = varLayout->FindResourceInfo(LayoutResourceKind(category)); - } - - UInt space = 0; - - // First, deal with any offset applied to the specific resource kind specified - if (info) - { - space += info->space; - } - - // Next, deal with any dedicated register-space offset applied to, e.g., a parameter block - if (auto spaceInfo = varLayout->FindResourceInfo(LayoutResourceKind::RegisterSpace)) - { - space += spaceInfo->index; - } - - return space; -} - -SLANG_API char const* spReflectionVariableLayout_GetSemanticName(SlangReflectionVariableLayout* inVarLayout) -{ - auto varLayout = convert(inVarLayout); - if(!varLayout) return 0; - - if (!(varLayout->flags & Slang::VarLayoutFlag::HasSemantic)) - return 0; - - return varLayout->semanticName.getBuffer(); -} - -SLANG_API size_t spReflectionVariableLayout_GetSemanticIndex(SlangReflectionVariableLayout* inVarLayout) -{ - auto varLayout = convert(inVarLayout); - if(!varLayout) return 0; - - if (!(varLayout->flags & Slang::VarLayoutFlag::HasSemantic)) - return 0; - - return varLayout->semanticIndex; -} - -SLANG_API SlangStage spReflectionVariableLayout_getStage( - SlangReflectionVariableLayout* inVarLayout) -{ - auto varLayout = convert(inVarLayout); - if(!varLayout) return SLANG_STAGE_NONE; - - // A parameter that is not a varying input or output is - // not considered to belong to a single stage. - // - // TODO: We might need to reconsider this for, e.g., entry - // point parameters, where they might be stage-specific even - // if they are uniform. - if (!varLayout->FindResourceInfo(Slang::LayoutResourceKind::VaryingInput) - && !varLayout->FindResourceInfo(Slang::LayoutResourceKind::VaryingOutput)) - { - return SLANG_STAGE_NONE; - } - - // TODO: We should find the stage for a variable layout by - // walking up the tree of layout information, until we find - // something that has a definitive stage attached to it (e.g., - // either an entry point or a GLSL translation unit). - // - // We don't currently have parent links in the reflection layout - // information, so doing that walk would be tricky right now, so - // it is easier to just bloat the representation and store yet another - // field on every variable layout. - return (SlangStage) varLayout->stage; -} - - -// Shader Parameter Reflection - -SLANG_API unsigned spReflectionParameter_GetBindingIndex(SlangReflectionParameter* inVarLayout) -{ - SlangReflectionVariableLayout* varLayout = (SlangReflectionVariableLayout*)inVarLayout; - return (unsigned) spReflectionVariableLayout_GetOffset( - varLayout, - spReflectionTypeLayout_GetParameterCategory( - spReflectionVariableLayout_GetTypeLayout(varLayout))); -} - -SLANG_API unsigned spReflectionParameter_GetBindingSpace(SlangReflectionParameter* inVarLayout) -{ - SlangReflectionVariableLayout* varLayout = (SlangReflectionVariableLayout*)inVarLayout; - return (unsigned) spReflectionVariableLayout_GetSpace( - varLayout, - spReflectionTypeLayout_GetParameterCategory( - spReflectionVariableLayout_GetTypeLayout(varLayout))); -} - -// Helpers for getting parameter count - -namespace Slang -{ - static unsigned getParameterCount(RefPtr typeLayout) - { - if(auto parameterGroupLayout = as(typeLayout)) - { - typeLayout = parameterGroupLayout->offsetElementTypeLayout; - } - - if(auto structLayout = as(typeLayout)) - { - return (unsigned) structLayout->fields.getCount(); - } - - return 0; - } - - static VarLayout* getParameterByIndex(RefPtr typeLayout, unsigned index) - { - if(auto parameterGroupLayout = as(typeLayout)) - { - typeLayout = parameterGroupLayout->offsetElementTypeLayout; - } - - if(auto structLayout = as(typeLayout)) - { - return structLayout->fields[index]; - } - - return 0; - } -} - -// Entry Point Reflection - -SLANG_API char const* spReflectionEntryPoint_getName( - SlangReflectionEntryPoint* inEntryPoint) -{ - auto entryPointLayout = convert(inEntryPoint); - if(!entryPointLayout) return 0; - - return getText(entryPointLayout->entryPoint->getName()).begin(); -} - -SLANG_API unsigned spReflectionEntryPoint_getParameterCount( - SlangReflectionEntryPoint* inEntryPoint) -{ - auto entryPointLayout = convert(inEntryPoint); - if(!entryPointLayout) return 0; - - return getParameterCount(entryPointLayout->parametersLayout->typeLayout); -} - -SLANG_API SlangReflectionVariableLayout* spReflectionEntryPoint_getParameterByIndex( - SlangReflectionEntryPoint* inEntryPoint, - unsigned index) -{ - auto entryPointLayout = convert(inEntryPoint); - if(!entryPointLayout) return 0; - - return convert(getParameterByIndex(entryPointLayout->parametersLayout->typeLayout, index)); -} - -SLANG_API SlangStage spReflectionEntryPoint_getStage(SlangReflectionEntryPoint* inEntryPoint) -{ - auto entryPointLayout = convert(inEntryPoint); - - if(!entryPointLayout) return SLANG_STAGE_NONE; - - return SlangStage(entryPointLayout->profile.GetStage()); -} - -SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize( - SlangReflectionEntryPoint* inEntryPoint, - SlangUInt axisCount, - SlangUInt* outSizeAlongAxis) -{ - auto entryPointLayout = convert(inEntryPoint); - - if(!entryPointLayout) return; - if(!axisCount) return; - if(!outSizeAlongAxis) return; - - auto entryPointFunc = entryPointLayout->entryPoint; - if(!entryPointFunc) return; - - SlangUInt sizeAlongAxis[3] = { 1, 1, 1 }; - - // First look for the HLSL case, where we have an attribute attached to the entry point function - auto numThreadsAttribute = entryPointFunc->FindModifier(); - if (numThreadsAttribute) - { - sizeAlongAxis[0] = numThreadsAttribute->x; - sizeAlongAxis[1] = numThreadsAttribute->y; - sizeAlongAxis[2] = numThreadsAttribute->z; - } - else - { - // Fall back to the GLSL case, which requires a search over global-scope declarations - // to look for as with the `local_size_*` qualifier - auto module = as(entryPointFunc->ParentDecl); - if (module) - { - for (auto dd : module->Members) - { - for (auto mod : dd->GetModifiersOfType()) - { - if (auto xMod = as(mod)) - sizeAlongAxis[0] = (SlangUInt) getIntegerLiteralValue(xMod->valToken); - else if (auto yMod = as(mod)) - sizeAlongAxis[1] = (SlangUInt) getIntegerLiteralValue(yMod->valToken); - else if (auto zMod = as(mod)) - sizeAlongAxis[2] = (SlangUInt) getIntegerLiteralValue(zMod->valToken); - } - } - } - } - - // - - if(axisCount > 0) outSizeAlongAxis[0] = sizeAlongAxis[0]; - if(axisCount > 1) outSizeAlongAxis[1] = sizeAlongAxis[1]; - if(axisCount > 2) outSizeAlongAxis[2] = sizeAlongAxis[2]; - for( SlangUInt aa = 3; aa < axisCount; ++aa ) - { - outSizeAlongAxis[aa] = 1; - } -} - -SLANG_API int spReflectionEntryPoint_usesAnySampleRateInput( - SlangReflectionEntryPoint* inEntryPoint) -{ - auto entryPointLayout = convert(inEntryPoint); - if(!entryPointLayout) - return 0; - - if (entryPointLayout->profile.GetStage() != Stage::Fragment) - return 0; - - return (entryPointLayout->flags & EntryPointLayout::Flag::usesAnySampleRateInput) != 0; -} - -// SlangReflectionTypeParameter -SLANG_API char const* spReflectionTypeParameter_GetName(SlangReflectionTypeParameter * inTypeParam) -{ - auto typeParam = convert(inTypeParam); - return typeParam->decl->getName()->text.getBuffer(); -} - -SLANG_API unsigned spReflectionTypeParameter_GetIndex(SlangReflectionTypeParameter * inTypeParam) -{ - auto typeParam = convert(inTypeParam); - return (unsigned)(typeParam->index); -} - -SLANG_API unsigned int spReflectionTypeParameter_GetConstraintCount(SlangReflectionTypeParameter* inTypeParam) -{ - auto typeParam = convert(inTypeParam); - auto constraints = typeParam->decl->getMembersOfType(); - return (unsigned int)constraints.getCount(); -} - -SLANG_API SlangReflectionType* spReflectionTypeParameter_GetConstraintByIndex(SlangReflectionTypeParameter * inTypeParam, unsigned index) -{ - auto typeParam = convert(inTypeParam); - auto constraints = typeParam->decl->getMembersOfType(); - return (SlangReflectionType*)constraints.toArray()[index]->sup.Ptr(); -} - -// Shader Reflection - -SLANG_API unsigned spReflection_GetParameterCount(SlangReflection* inProgram) -{ - auto program = convert(inProgram); - if(!program) return 0; - - auto globalStructLayout = getGlobalStructLayout(program); - if (!globalStructLayout) - return 0; - - return (unsigned) globalStructLayout->fields.getCount(); -} - -SLANG_API SlangReflectionParameter* spReflection_GetParameterByIndex(SlangReflection* inProgram, unsigned index) -{ - auto program = convert(inProgram); - if(!program) return nullptr; - - auto globalStructLayout = getGlobalStructLayout(program); - if (!globalStructLayout) - return 0; - - return convert(globalStructLayout->fields[index].Ptr()); -} - -SLANG_API unsigned int spReflection_GetTypeParameterCount(SlangReflection * reflection) -{ - auto program = convert(reflection); - return (unsigned int)program->globalGenericParams.getCount(); -} - -SLANG_API SlangReflectionTypeParameter* spReflection_GetTypeParameterByIndex(SlangReflection * reflection, unsigned int index) -{ - auto program = convert(reflection); - return (SlangReflectionTypeParameter*)program->globalGenericParams[index].Ptr(); -} - -SLANG_API SlangReflectionTypeParameter * spReflection_FindTypeParameter(SlangReflection * inProgram, char const * name) -{ - auto program = convert(inProgram); - if (!program) return nullptr; - GenericParamLayout * result = nullptr; - program->globalGenericParamsMap.TryGetValue(name, result); - return (SlangReflectionTypeParameter*)result; -} - -SLANG_API SlangUInt spReflection_getEntryPointCount(SlangReflection* inProgram) -{ - auto program = convert(inProgram); - if(!program) return 0; - - return SlangUInt(program->entryPoints.getCount()); -} - -SLANG_API SlangReflectionEntryPoint* spReflection_getEntryPointByIndex(SlangReflection* inProgram, SlangUInt index) -{ - auto program = convert(inProgram); - if(!program) return 0; - - return convert(program->entryPoints[(int) index].Ptr()); -} - -SLANG_API SlangReflectionEntryPoint* spReflection_findEntryPointByName(SlangReflection* inProgram, char const* name) -{ - auto program = convert(inProgram); - if(!program) return 0; - - // TODO: improve on dumb linear search - for(auto ep : program->entryPoints) - { - if(ep->entryPoint->getName()->text == name) - { - return convert(ep); - } - } - - return nullptr; -} - - -SLANG_API SlangUInt spReflection_getGlobalConstantBufferBinding(SlangReflection* inProgram) -{ - auto program = convert(inProgram); - if (!program) return 0; - auto cb = program->parametersLayout->FindResourceInfo(LayoutResourceKind::ConstantBuffer); - if (!cb) return 0; - return cb->index; -} - -SLANG_API size_t spReflection_getGlobalConstantBufferSize(SlangReflection* inProgram) -{ - auto program = convert(inProgram); - if (!program) return 0; - auto structLayout = getGlobalStructLayout(program); - auto uniform = structLayout->FindResourceInfo(LayoutResourceKind::Uniform); - if (!uniform) return 0; - return getReflectionSize(uniform->count); -} - -SLANG_API SlangReflectionType* spReflection_specializeType( - SlangReflection* inProgramLayout, - SlangReflectionType* inType, - SlangInt specializationArgCount, - SlangReflectionType* const* specializationArgs, - ISlangBlob** outDiagnostics) -{ - auto programLayout = convert(inProgramLayout); - if(!programLayout) return nullptr; - - auto unspecializedType = convert(inType); - if(!unspecializedType) return nullptr; - - auto linkage = programLayout->getProgram()->getLinkage(); - - DiagnosticSink sink; - sink.sourceManager = linkage->getSourceManager(); - - auto specializedType = linkage->specializeType(unspecializedType, specializationArgCount, (Type* const*) specializationArgs, &sink); - - sink.getBlobIfNeeded(outDiagnostics); - - return convert(specializedType); -} - diff --git a/source/slang/reflection.h b/source/slang/reflection.h deleted file mode 100644 index 09f02d8dd..000000000 --- a/source/slang/reflection.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef SLANG_REFLECTION_H -#define SLANG_REFLECTION_H - -#include "../core/basic.h" -#include "syntax.h" - -#include "../../slang.h" - -namespace Slang { - -class ProgramLayout; -class TypeLayout; - -// - -SlangTypeKind getReflectionTypeKind(Type* type); - -SlangTypeKind getReflectionParameterCategory(TypeLayout* typeLayout); - -UInt getReflectionFieldCount(Type* type); -UInt getReflectionFieldByIndex(Type* type, UInt index); -UInt getReflectionFieldByIndex(TypeLayout* typeLayout, UInt index); - -} - -#endif // SLANG_REFLECTION_H diff --git a/source/slang/slang-c-like-source-emitter.cpp b/source/slang/slang-c-like-source-emitter.cpp index 917779b6d..dacb5e9b5 100644 --- a/source/slang/slang-c-like-source-emitter.cpp +++ b/source/slang/slang-c-like-source-emitter.cpp @@ -2,25 +2,25 @@ #include "slang-c-like-source-emitter.h" #include "../core/slang-writer.h" -#include "ir-bind-existentials.h" -#include "ir-dce.h" -#include "ir-entry-point-uniforms.h" -#include "ir-glsl-legalize.h" - -#include "ir-link.h" -#include "ir-restructure-scoping.h" -#include "ir-specialize.h" -#include "ir-specialize-resources.h" -#include "ir-ssa.h" -#include "ir-union.h" -#include "ir-validate.h" -#include "legalize-types.h" -#include "lower-to-ir.h" -#include "mangle.h" -#include "name.h" -#include "syntax.h" -#include "type-layout.h" -#include "visitor.h" +#include "slang-ir-bind-existentials.h" +#include "slang-ir-dce.h" +#include "slang-ir-entry-point-uniforms.h" +#include "slang-ir-glsl-legalize.h" + +#include "slang-ir-link.h" +#include "slang-ir-restructure-scoping.h" +#include "slang-ir-specialize.h" +#include "slang-ir-specialize-resources.h" +#include "slang-ir-ssa.h" +#include "slang-ir-union.h" +#include "slang-ir-validate.h" +#include "slang-legalize-types.h" +#include "slang-lower-to-ir.h" +#include "slang-mangle.h" +#include "slang-name.h" +#include "slang-syntax.h" +#include "slang-type-layout.h" +#include "slang-visitor.h" #include "slang-source-stream.h" #include "slang-emit-context.h" diff --git a/source/slang/slang-c-like-source-emitter.h b/source/slang/slang-c-like-source-emitter.h index 3c8f3dbef..4fcbb0b1c 100644 --- a/source/slang/slang-c-like-source-emitter.h +++ b/source/slang/slang-c-like-source-emitter.h @@ -2,17 +2,17 @@ #ifndef SLANG_C_LIKE_SOURCE_EMITTER_H_INCLUDED #define SLANG_C_LIKE_SOURCE_EMITTER_H_INCLUDED -#include "../core/basic.h" +#include "../core/slang-basic.h" -#include "compiler.h" +#include "slang-compiler.h" #include "slang-emit-context.h" #include "slang-extension-usage-tracker.h" #include "slang-emit-precedence.h" -#include "ir.h" -#include "ir-insts.h" -#include "ir-restructure.h" +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-ir-restructure.h" namespace Slang { diff --git a/source/slang/slang-check.cpp b/source/slang/slang-check.cpp new file mode 100644 index 000000000..90947cf54 --- /dev/null +++ b/source/slang/slang-check.cpp @@ -0,0 +1,11334 @@ +#include "slang-syntax-visitors.h" + +#include "slang-lookup.h" +#include "slang-compiler.h" +#include "slang-visitor.h" + +#include "../core/slang-secure-crt.h" +#include + +namespace Slang +{ + RefPtr getTypeType( + Type* type); + + /// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration? + bool isEffectivelyStatic( + Decl* decl, + ContainerDecl* parentDecl) + { + // Things at the global scope are always "members" of their module. + // + if(as(parentDecl)) + return false; + + // Anything explicitly marked `static` and not at module scope + // counts as a static rather than instance declaration. + // + if(decl->HasModifier()) + return true; + + // Next we need to deal with cases where a declaration is + // effectively `static` even if the language doesn't make + // the user say so. Most languages make the default assumption + // that nested types are `static` even if they don't say + // so (Java is an exception here, perhaps due to some + // influence from the Scandanavian OOP tradition of Beta/gbeta). + // + if(as(decl)) + return true; + if(as(decl)) + return true; + + // Things nested inside functions may have dependencies + // on values from the enclosing scope, but this needs to + // be dealt with via "capture" so they are also effectively + // `static` + // + if(as(parentDecl)) + return true; + + // Type constraint declarations are used in member-reference + // context as a form of casting operation, so we treat them + // as if they are instance members. This is a bit of a hack, + // but it achieves the result we want until we have an + // explicit representation of up-cast operations in the + // AST. + // + if(as(decl)) + return false; + + return false; + } + + /// Should the given `decl` be treated as a static rather than instance declaration? + bool isEffectivelyStatic( + Decl* decl) + { + // For the purposes of an ordinary declaration, when determining if + // it is static or per-instance, the "parent" declaration we really + // care about is the next outer non-generic declaration. + // + // TODO: This idiom of getting the "next outer non-generic declaration" + // comes up just enough that we should probably have a convenience + // function for it. + + auto parentDecl = decl->ParentDecl; + if(auto genericDecl = as(parentDecl)) + parentDecl = genericDecl->ParentDecl; + + return isEffectivelyStatic(decl, parentDecl); + } + + /// Is `decl` a global shader parameter declaration? + bool isGlobalShaderParameter(VarDeclBase* decl) + { + // A global shader parameter must be declared at global (module) scope. + // + if(!as(decl->ParentDecl)) return false; + + // A global variable marked `static` indicates a traditional + // global variable (albeit one that is implicitly local to + // the translation unit) + // + if(decl->HasModifier()) return false; + + // The `groupshared` modifier indicates that a variable cannot + // be a shader parameters, but is instead transient storage + // allocated for the duration of a thread-group's execution. + // + if(decl->HasModifier()) return false; + + return true; + } + + // A flat representation of basic types (scalars, vectors and matrices) + // that can be used as lookup key in caches + struct BasicTypeKey + { + union + { + struct + { + unsigned char type : 4; + unsigned char dim1 : 2; + unsigned char dim2 : 2; + } data; + unsigned char aggVal; + }; + bool fromType(Type* typeIn) + { + aggVal = 0; + if (auto basicType = as(typeIn)) + { + data.type = (unsigned char)basicType->baseType; + data.dim1 = data.dim2 = 0; + } + else if (auto vectorType = as(typeIn)) + { + if (auto elemCount = as(vectorType->elementCount)) + { + data.dim1 = elemCount->value - 1; + auto elementBasicType = as(vectorType->elementType); + data.type = (unsigned char)elementBasicType->baseType; + data.dim2 = 0; + } + else + return false; + } + else if (auto matrixType = as(typeIn)) + { + if (auto elemCount1 = as(matrixType->getRowCount())) + { + if (auto elemCount2 = as(matrixType->getColumnCount())) + { + auto elemBasicType = as(matrixType->getElementType()); + data.type = (unsigned char)elemBasicType->baseType; + data.dim1 = elemCount1->value - 1; + data.dim2 = elemCount2->value - 1; + } + } + else + return false; + } + else + return false; + return true; + } + }; + + struct BasicTypeKeyPair + { + BasicTypeKey type1, type2; + bool operator == (BasicTypeKeyPair p) + { + return type1.aggVal == p.type1.aggVal && type2.aggVal == p.type2.aggVal; + } + int GetHashCode() + { + return combineHash(type1.aggVal, type2.aggVal); + } + }; + + struct OverloadCandidate + { + enum class Flavor + { + Func, + Generic, + UnspecializedGeneric, + }; + Flavor flavor; + + enum class Status + { + GenericArgumentInferenceFailed, + Unchecked, + ArityChecked, + FixityChecked, + TypeChecked, + DirectionChecked, + Applicable, + }; + Status status = Status::Unchecked; + + // Reference to the declaration being applied + LookupResultItem item; + + // The type of the result expression if this candidate is selected + RefPtr resultType; + + // A system for tracking constraints introduced on generic parameters + // ConstraintSystem constraintSystem; + + // How much conversion cost should be considered for this overload, + // when ranking candidates. + ConversionCost conversionCostSum = kConversionCost_None; + + // When required, a candidate can store a pre-checked list of + // arguments so that we don't have to repeat work across checking + // phases. Currently this is only needed for generics. + RefPtr subst; + }; + + struct OperatorOverloadCacheKey + { + IROp operatorName; + BasicTypeKey args[2]; + bool operator == (OperatorOverloadCacheKey key) + { + return operatorName == key.operatorName && args[0].aggVal == key.args[0].aggVal + && args[1].aggVal == key.args[1].aggVal; + } + int GetHashCode() + { + return ((int)(UInt64)(void*)(operatorName) << 16) ^ (args[0].aggVal << 8) ^ (args[1].aggVal); + } + bool fromOperatorExpr(OperatorExpr* opExpr) + { + // First, lets see if the argument types are ones + // that we can encode in our space of keys. + args[0].aggVal = 0; + args[1].aggVal = 0; + if (opExpr->Arguments.getCount() > 2) + return false; + + for (Index i = 0; i < opExpr->Arguments.getCount(); i++) + { + if (!args[i].fromType(opExpr->Arguments[i]->type.Ptr())) + return false; + } + + // Next, lets see if we can find an intrinsic opcode + // attached to an overloaded definition (filtered for + // definitions that could conceivably apply to us). + // + // TODO: This should really be parsed on the operator name + // plus fixity, rather than the intrinsic opcode... + // + // We will need to reject postfix definitions for prefix + // operators, and vice versa, to ensure things work. + // + auto prefixExpr = as(opExpr); + auto postfixExpr = as(opExpr); + + if (auto overloadedBase = as(opExpr->FunctionExpr)) + { + for(auto item : overloadedBase->lookupResult2 ) + { + // Look at a candidate definition to be called and + // see if it gives us a key to work with. + // + Decl* funcDecl = overloadedBase->lookupResult2.item.declRef.decl; + if (auto genDecl = as(funcDecl)) + funcDecl = genDecl->inner.Ptr(); + + // Reject definitions that have the wrong fixity. + // + if(prefixExpr && !funcDecl->FindModifier()) + continue; + if(postfixExpr && !funcDecl->FindModifier()) + continue; + + if (auto intrinsicOp = funcDecl->FindModifier()) + { + operatorName = intrinsicOp->op; + return true; + } + } + } + return false; + } + }; + + struct TypeCheckingCache + { + Dictionary resolvedOperatorOverloadCache; + Dictionary conversionCostCache; + }; + + TypeCheckingCache* Session::getTypeCheckingCache() + { + if (!typeCheckingCache) + typeCheckingCache = new TypeCheckingCache(); + return typeCheckingCache; + } + + void Session::destroyTypeCheckingCache() + { + delete typeCheckingCache; + typeCheckingCache = nullptr; + } + + namespace { // anonymous + struct FunctionInfo + { + const char* name; + SharedLibraryType libraryType; + }; + } // anonymous + + static FunctionInfo _getFunctionInfo(Session::SharedLibraryFuncType funcType) + { + typedef Session::SharedLibraryFuncType FuncType; + typedef SharedLibraryType LibType; + + switch (funcType) + { + case FuncType::Glslang_Compile: return { "glslang_compile", LibType::Glslang } ; + case FuncType::Fxc_D3DCompile: return { "D3DCompile", LibType::Fxc }; + case FuncType::Fxc_D3DDisassemble: return { "D3DDisassemble", LibType::Fxc }; + case FuncType::Dxc_DxcCreateInstance: return { "DxcCreateInstance", LibType::Dxc }; + default: return { nullptr, LibType::Unknown }; + } + } + + ISlangSharedLibrary* Session::getOrLoadSharedLibrary(SharedLibraryType type, DiagnosticSink* sink) + { + // If not loaded, try loading it + if (!sharedLibraries[int(type)]) + { + // Try to preload dxil first, if loading dxc + if (type == SharedLibraryType::Dxc) + { + // Pass nullptr as the sink, because if it fails we don't want to report as error + getOrLoadSharedLibrary(SharedLibraryType::Dxil, nullptr); + } + + const char* libName = DefaultSharedLibraryLoader::getSharedLibraryNameFromType(type); + if (SLANG_FAILED(sharedLibraryLoader->loadSharedLibrary(libName, sharedLibraries[int(type)].writeRef()))) + { + if (sink) + { + sink->diagnose(SourceLoc(), Diagnostics::failedToLoadDynamicLibrary, libName); + } + return nullptr; + } + } + return sharedLibraries[int(type)]; + } + + SlangFuncPtr Session::getSharedLibraryFunc(SharedLibraryFuncType type, DiagnosticSink* sink) + { + if (sharedLibraryFunctions[int(type)]) + { + return sharedLibraryFunctions[int(type)]; + } + // do we have the library + FunctionInfo info = _getFunctionInfo(type); + if (info.name == nullptr) + { + return nullptr; + } + // Try loading the library + ISlangSharedLibrary* sharedLib = getOrLoadSharedLibrary(info.libraryType, sink); + if (!sharedLib) + { + return nullptr; + } + + // Okay now access the func + SlangFuncPtr func = sharedLib->findFuncByName(info.name); + if (!func) + { + const char* libName = DefaultSharedLibraryLoader::getSharedLibraryNameFromType(info.libraryType); + sink->diagnose(SourceLoc(), Diagnostics::failedToFindFunctionInSharedLibrary, info.name, libName); + return nullptr; + } + + // Store in the function cache + sharedLibraryFunctions[int(type)] = func; + return func; + } + + + enum class CheckingPhase + { + Header, Body + }; + + struct SemanticsVisitor + : ExprVisitor> + , StmtVisitor + , DeclVisitor + { + CheckingPhase checkingPhase = CheckingPhase::Header; + DeclCheckState getCheckedState() + { + if (checkingPhase == CheckingPhase::Body) + return DeclCheckState::Checked; + else + return DeclCheckState::CheckedHeader; + } + + Linkage* m_linkage = nullptr; + DiagnosticSink* m_sink = nullptr; + + DiagnosticSink* getSink() + { + return m_sink; + } + +// ModuleDecl * program = nullptr; + FuncDecl * function = nullptr; + + + // lexical outer statements + List outerStmts; + + // We need to track what has been `import`ed, + // to avoid importing the same thing more than once + // + // TODO: a smarter approach might be to filter + // out duplicate references during lookup. + HashSet importedModules; + + public: + SemanticsVisitor( + Linkage* linkage, + DiagnosticSink* sink) + : m_linkage(linkage) + , m_sink(sink) + {} + + Session* getSession() + { + return m_linkage->getSession(); + } + + public: + // Translate Types + RefPtr typeResult; + RefPtr TranslateTypeNodeImpl(const RefPtr & node) + { + if (!node) return nullptr; + + auto expr = CheckTerm(node); + expr = ExpectATypeRepr(expr); + return expr; + } + RefPtr ExtractTypeFromTypeRepr(const RefPtr& typeRepr) + { + if (!typeRepr) return nullptr; + if (auto typeType = as(typeRepr->type)) + { + return typeType->type; + } + return getSession()->getErrorType(); + } + RefPtr TranslateTypeNode(const RefPtr & node) + { + if (!node) return nullptr; + auto typeRepr = TranslateTypeNodeImpl(node); + return ExtractTypeFromTypeRepr(typeRepr); + } + TypeExp TranslateTypeNodeForced(TypeExp const& typeExp) + { + auto typeRepr = TranslateTypeNodeImpl(typeExp.exp); + + TypeExp result; + result.exp = typeRepr; + result.type = ExtractTypeFromTypeRepr(typeRepr); + return result; + } + TypeExp TranslateTypeNode(TypeExp const& typeExp) + { + // HACK(tfoley): It seems that in some cases we end up re-checking + // syntax that we've already checked. We need to root-cause that + // issue, but for now a quick fix in this case is to early + // exist if we've already got a type associated here: + if (typeExp.type) + { + return typeExp; + } + return TranslateTypeNodeForced(typeExp); + } + + RefPtr getExprDeclRefType(Expr * expr) + { + if (auto typetype = as(expr->type)) + return typetype->type.dynamicCast(); + else + return as(expr->type); + } + + /// Is `decl` usable as a static member? + bool isDeclUsableAsStaticMember( + Decl* decl) + { + if(decl->HasModifier()) + return true; + + if(as(decl)) + return true; + + if(as(decl)) + return true; + + if(as(decl)) + return true; + + if(as(decl)) + return true; + + return false; + } + + /// Is `item` usable as a static member? + bool isUsableAsStaticMember( + LookupResultItem const& item) + { + // There's a bit of a gotcha here, because a lookup result + // item might include "breadcrumbs" that indicate more steps + // along the lookup path. As a result it isn't always + // valid to just check whether the final decl is usable + // as a static member, because it might not even be a + // member of the thing we are trying to work with. + // + + Decl* decl = item.declRef.getDecl(); + for(auto bb = item.breadcrumbs; bb; bb = bb->next) + { + switch(bb->kind) + { + // In case lookup went through a `__transparent` member, + // we are interested in the static-ness of that transparent + // member, and *not* the static-ness of whatever was inside + // of it. + // + // TODO: This would need some work if we ever had + // transparent *type* members. + // + case LookupResultItem::Breadcrumb::Kind::Member: + decl = bb->declRef.getDecl(); + break; + + // TODO: Are there any other cases that need special-case + // handling here? + + default: + break; + } + } + + // Okay, we've found the declaration we should actually + // be checking, so lets validate that. + + return isDeclUsableAsStaticMember(decl); + } + + /// Move `expr` into a temporary variable and execute `func` on that variable. + /// + /// Returns an expression that wraps both the creation and initialization of + /// the temporary, and the computation created by `func`. + /// + template + RefPtr moveTemp(RefPtr const& expr, F const& func) + { + RefPtr varDecl = new VarDecl(); + varDecl->ParentDecl = nullptr; // TODO: need to fill this in somehow! + varDecl->checkState = DeclCheckState::Checked; + varDecl->nameAndLoc.loc = expr->loc; + varDecl->initExpr = expr; + varDecl->type.type = expr->type.type; + + auto varDeclRef = makeDeclRef(varDecl.Ptr()); + + RefPtr letExpr = new LetExpr(); + letExpr->decl = varDecl; + + auto body = func(varDeclRef); + + letExpr->body = body; + letExpr->type = body->type; + + return letExpr; + } + + /// Execute `func` on a variable with the value of `expr`. + /// + /// If `expr` is just a reference to an immutable (e.g., `let`) variable + /// then this might use the existing variable. Otherwise it will create + /// a new variable to hold `expr`, using `moveTemp()`. + /// + template + RefPtr maybeMoveTemp(RefPtr const& expr, F const& func) + { + if(auto varExpr = as(expr)) + { + auto declRef = varExpr->declRef; + if(auto varDeclRef = declRef.as()) + return func(varDeclRef); + } + + return moveTemp(expr, func); + } + + /// Return an expression that represents "opening" the existential `expr`. + /// + /// The type of `expr` must be an interface type, matching `interfaceDeclRef`. + /// + /// If we scope down the PL theory to just the case that Slang cares about, + /// a value of an existential type like `IMover` is a tuple of: + /// + /// * a concrete type `X` + /// * a witness `w` of the fact that `X` implements `IMover` + /// * a value `v` of type `X` + /// + /// "Opening" an existential value is the process of decomposing a single + /// value `e : IMover` into the pieces `X`, `w`, and `v`. + /// + /// Rather than return all those pieces individually, this operation + /// returns an expression that logically corresponds to `v`: an expression + /// of type `X`, where the type carries the knowledge that `X` implements `IMover`. + /// + RefPtr openExistential( + RefPtr expr, + DeclRef interfaceDeclRef) + { + // If `expr` refers to an immutable binding, + // then we can use it directly. If it refers + // to an arbitrary expression or a mutable + // binding, we will move its value into an + // immutable temporary so that we can use + // it directly. + // + auto interfaceDecl = interfaceDeclRef.getDecl(); + return maybeMoveTemp(expr, [&](DeclRef varDeclRef) + { + RefPtr openedType = new ExtractExistentialType(); + openedType->declRef = varDeclRef; + + RefPtr openedWitness = new ExtractExistentialSubtypeWitness(); + openedWitness->sub = openedType; + openedWitness->sup = expr->type.type; + openedWitness->declRef = varDeclRef; + + RefPtr openedThisType = new ThisTypeSubstitution(); + openedThisType->outer = interfaceDeclRef.substitutions.substitutions; + openedThisType->interfaceDecl = interfaceDecl; + openedThisType->witness = openedWitness; + + DeclRef substDeclRef = DeclRef(interfaceDecl, openedThisType); + auto substDeclRefType = DeclRefType::Create(getSession(), substDeclRef); + + RefPtr openedValue = new ExtractExistentialValueExpr(); + openedValue->declRef = varDeclRef; + openedValue->type = QualType(substDeclRefType); + + return openedValue; + }); + } + + /// If `expr` has existential type, then open it. + /// + /// Returns an expression that opens `expr` if it had existential type. + /// Otherwise returns `expr` itself. + /// + /// See `openExistential` for a discussion of what "opening" an + /// existential-type value means. + /// + RefPtr maybeOpenExistential(RefPtr expr) + { + auto exprType = expr->type.type; + + if(auto declRefType = as(exprType)) + { + if(auto interfaceDeclRef = declRefType->declRef.as()) + { + // Is there an this-type substitution being applied, so that + // we are referencing the interface type through a concrete + // type (e.g., a type parameter constrained to this interface)? + // + // Because of the way that substitutions need to mirror the nesting + // hierarchy of declarations, any this-type substitution pertaining + // to the chosen interface decl must be the first substitution on + // the list (which is a linked list from the "inside" out). + // + auto thisTypeSubst = interfaceDeclRef.substitutions.substitutions.as(); + if(thisTypeSubst && thisTypeSubst->interfaceDecl == interfaceDeclRef.decl) + { + // This isn't really an existential type, because somebody + // has already filled in a this-type substitution. + } + else + { + // Okay, here is the case that matters. + // + return openExistential(expr, interfaceDeclRef); + } + } + } + + // Default: apply the callback to the original expression; + return expr; + } + + RefPtr ConstructDeclRefExpr( + DeclRef declRef, + RefPtr baseExpr, + SourceLoc loc) + { + // Compute the type that this declaration reference will have in context. + // + auto type = GetTypeForDeclRef(declRef); + + // Construct an appropriate expression based on the structured of + // the declaration reference. + // + if (baseExpr) + { + // If there was a base expression, we will have some kind of + // member expression. + + // We want to check for the case where the base "expression" + // actually names a type, because in that case we are doing + // a static member reference. + // + // TODO: Should we be checking if the member is static here? + // If it isn't, should we be automatically producing a "curried" + // form (e.g., for a member function, return a value usable + // for referencing it as a free function). + // + if (as(baseExpr->type)) + { + auto expr = new StaticMemberExpr(); + expr->loc = loc; + expr->type = type; + expr->BaseExpression = baseExpr; + expr->name = declRef.GetName(); + expr->declRef = declRef; + return expr; + } + else if(isEffectivelyStatic(declRef.getDecl())) + { + // Extract the type of the baseExpr + auto baseExprType = baseExpr->type.type; + RefPtr baseTypeExpr = new SharedTypeExpr(); + baseTypeExpr->base.type = baseExprType; + baseTypeExpr->type.type = getTypeType(baseExprType); + + auto expr = new StaticMemberExpr(); + expr->loc = loc; + expr->type = type; + expr->BaseExpression = baseTypeExpr; + expr->name = declRef.GetName(); + expr->declRef = declRef; + return expr; + } + else + { + // If the base expression wasn't a type, then this + // is a normal member expression. + // + auto expr = new MemberExpr(); + expr->loc = loc; + expr->type = type; + expr->BaseExpression = baseExpr; + expr->name = declRef.GetName(); + expr->declRef = declRef; + + // When referring to a member through an expression, + // the result is only an l-value if both the base + // expression and the member agree that it should be. + // + // We have already used the `QualType` from the member + // above (that is `type`), so we need to take the + // l-value status of the base expression into account now. + if(!baseExpr->type.IsLeftValue) + { + expr->type.IsLeftValue = false; + } + + return expr; + } + } + else + { + // If there is no base expression, then the result must + // be an ordinary variable expression. + // + auto expr = new VarExpr(); + expr->loc = loc; + expr->name = declRef.GetName(); + expr->type = type; + expr->declRef = declRef; + return expr; + } + } + + RefPtr ConstructDerefExpr( + RefPtr base, + SourceLoc loc) + { + auto ptrLikeType = as(base->type); + SLANG_ASSERT(ptrLikeType); + + auto derefExpr = new DerefExpr(); + derefExpr->loc = loc; + derefExpr->base = base; + derefExpr->type = QualType(ptrLikeType->elementType); + + // TODO(tfoley): handle l-value status here + + return derefExpr; + } + + RefPtr createImplicitThisMemberExpr( + Type* type, + SourceLoc loc, + LookupResultItem::Breadcrumb::ThisParameterMode thisParameterMode) + { + RefPtr expr = new ThisExpr(); + expr->type = type; + expr->type.IsLeftValue = thisParameterMode == LookupResultItem::Breadcrumb::ThisParameterMode::Mutating; + expr->loc = loc; + return expr; + } + + RefPtr ConstructLookupResultExpr( + LookupResultItem const& item, + RefPtr baseExpr, + SourceLoc loc) + { + // If we collected any breadcrumbs, then these represent + // additional segments of the lookup path that we need + // to expand here. + auto bb = baseExpr; + for (auto breadcrumb = item.breadcrumbs; breadcrumb; breadcrumb = breadcrumb->next) + { + switch (breadcrumb->kind) + { + case LookupResultItem::Breadcrumb::Kind::Member: + bb = ConstructDeclRefExpr(breadcrumb->declRef, bb, loc); + break; + + case LookupResultItem::Breadcrumb::Kind::Deref: + bb = ConstructDerefExpr(bb, loc); + break; + + case LookupResultItem::Breadcrumb::Kind::Constraint: + { + // TODO: do we need to make something more + // explicit here? + bb = ConstructDeclRefExpr( + breadcrumb->declRef, + bb, + loc); + } + break; + + case LookupResultItem::Breadcrumb::Kind::This: + { + // We expect a `this` to always come + // at the start of a chain. + SLANG_ASSERT(bb == nullptr); + + // The member was looked up via a `this` expression, + // so we need to create one here. + if (auto extensionDeclRef = breadcrumb->declRef.as()) + { + bb = createImplicitThisMemberExpr( + GetTargetType(extensionDeclRef), + loc, + breadcrumb->thisParameterMode); + } + else + { + auto type = DeclRefType::Create(getSession(), breadcrumb->declRef); + bb = createImplicitThisMemberExpr( + type, + loc, + breadcrumb->thisParameterMode); + } + } + break; + + default: + SLANG_UNREACHABLE("all cases handle"); + } + } + + return ConstructDeclRefExpr(item.declRef, bb, loc); + } + + RefPtr createLookupResultExpr( + LookupResult const& lookupResult, + RefPtr baseExpr, + SourceLoc loc) + { + if (lookupResult.isOverloaded()) + { + auto overloadedExpr = new OverloadedExpr(); + overloadedExpr->loc = loc; + overloadedExpr->type = QualType( + getSession()->getOverloadedType()); + overloadedExpr->base = baseExpr; + overloadedExpr->lookupResult2 = lookupResult; + return overloadedExpr; + } + else + { + return ConstructLookupResultExpr(lookupResult.item, baseExpr, loc); + } + } + + RefPtr ResolveOverloadedExpr(RefPtr overloadedExpr, LookupMask mask) + { + auto lookupResult = overloadedExpr->lookupResult2; + SLANG_RELEASE_ASSERT(lookupResult.isValid() && lookupResult.isOverloaded()); + + // Take the lookup result we had, and refine it based on what is expected in context. + lookupResult = refineLookup(lookupResult, mask); + + if (!lookupResult.isValid()) + { + // If we didn't find any symbols after filtering, then just + // use the original and report errors that way + return overloadedExpr; + } + + if (lookupResult.isOverloaded()) + { + // We had an ambiguity anyway, so report it. + getSink()->diagnose(overloadedExpr, Diagnostics::ambiguousReference, lookupResult.items[0].declRef.GetName()); + + for(auto item : lookupResult.items) + { + String declString = getDeclSignatureString(item); + getSink()->diagnose(item.declRef, Diagnostics::overloadCandidate, declString); + } + + // TODO(tfoley): should we construct a new ErrorExpr here? + return CreateErrorExpr(overloadedExpr); + } + + // otherwise, we had a single decl and it was valid, hooray! + return ConstructLookupResultExpr(lookupResult.item, overloadedExpr->base, overloadedExpr->loc); + } + + RefPtr ExpectATypeRepr(RefPtr expr) + { + if (auto overloadedExpr = as(expr)) + { + expr = ResolveOverloadedExpr(overloadedExpr, LookupMask::type); + } + + if (auto typeType = as(expr->type)) + { + return expr; + } + else if (auto errorType = as(expr->type)) + { + return expr; + } + + getSink()->diagnose(expr, Diagnostics::unimplemented, "expected a type"); + return CreateErrorExpr(expr); + } + + RefPtr ExpectAType(RefPtr expr) + { + auto typeRepr = ExpectATypeRepr(expr); + if (auto typeType = as(typeRepr->type)) + { + return typeType->type; + } + return getSession()->getErrorType(); + } + + RefPtr ExtractGenericArgType(RefPtr exp) + { + return ExpectAType(exp); + } + + RefPtr ExtractGenericArgInteger(RefPtr exp) + { + return CheckIntegerConstantExpression(exp.Ptr()); + } + + RefPtr ExtractGenericArgVal(RefPtr exp) + { + if (auto overloadedExpr = as(exp)) + { + // assume that if it is overloaded, we want a type + exp = ResolveOverloadedExpr(overloadedExpr, LookupMask::type); + } + + if (auto typeType = as(exp->type)) + { + return typeType->type; + } + else if (auto errorType = as(exp->type)) + { + return exp->type.type; + } + else + { + return ExtractGenericArgInteger(exp); + } + } + + // Construct a type representing the instantiation of + // the given generic declaration for the given arguments. + // The arguments should already be checked against + // the declaration. + RefPtr InstantiateGenericType( + DeclRef genericDeclRef, + List> const& args) + { + RefPtr subst = new GenericSubstitution(); + subst->genericDecl = genericDeclRef.getDecl(); + subst->outer = genericDeclRef.substitutions.substitutions; + + for (auto argExpr : args) + { + subst->args.add(ExtractGenericArgVal(argExpr)); + } + + DeclRef innerDeclRef; + innerDeclRef.decl = GetInner(genericDeclRef); + innerDeclRef.substitutions = SubstitutionSet(subst); + + return DeclRefType::Create( + getSession(), + innerDeclRef); + } + + // This routine is a bottleneck for all declaration checking, + // so that we can add some quality-of-life features for users + // in cases where the compiler crashes + void dispatchDecl(DeclBase* decl) + { + try + { + DeclVisitor::dispatch(decl); + } + // Don't emit any context message for an explicit `AbortCompilationException` + // because it should only happen when an error is already emitted. + catch(AbortCompilationException&) { throw; } + catch(...) + { + getSink()->noteInternalErrorLoc(decl->loc); + throw; + } + } + void dispatchStmt(Stmt* stmt) + { + try + { + StmtVisitor::dispatch(stmt); + } + catch(AbortCompilationException&) { throw; } + catch(...) + { + getSink()->noteInternalErrorLoc(stmt->loc); + throw; + } + } + void dispatchExpr(Expr* expr) + { + try + { + ExprVisitor::dispatch(expr); + } + catch(AbortCompilationException&) { throw; } + catch(...) + { + getSink()->noteInternalErrorLoc(expr->loc); + throw; + } + } + + // Make sure a declaration has been checked, so we can refer to it. + // Note that this may lead to us recursively invoking checking, + // so this may not be the best way to handle things. + void EnsureDecl(RefPtr decl, DeclCheckState state) + { + if (decl->IsChecked(state)) return; + if (decl->checkState == DeclCheckState::CheckingHeader) + { + // We tried to reference the same declaration while checking it! + // + // TODO: we should ideally be tracking a "chain" of declarations + // being checked on the stack, so that we can report the full + // chain that leads from this declaration back to itself. + // + getSink()->diagnose(decl, Diagnostics::cyclicReference, decl); + return; + } + + // Hack: if we are somehow referencing a local variable declaration + // before the line of code that defines it, then we need to diagnose + // an error. + // + // TODO: The right answer is that lookup should have been performed in + // the scope that was in place *before* the variable was declared, but + // this is a quick fix that at least alerts the user to how we are + // interpreting their code. + // + if (auto varDecl = as(decl)) + { + if (auto parenScope = as(varDecl->ParentDecl)) + { + // TODO: This diagnostic should be emitted on the line that is referencing + // the declaration. That requires `EnsureDecl` to take the requesting + // location as a parameter. + getSink()->diagnose(decl, Diagnostics::localVariableUsedBeforeDeclared, decl); + return; + } + } + + if (DeclCheckState::CheckingHeader > decl->checkState) + { + decl->SetCheckState(DeclCheckState::CheckingHeader); + } + + // Check the modifiers on the declaration first, in case + // semantics of the body itself will depend on them. + checkModifiers(decl); + + // Use visitor pattern to dispatch to correct case + dispatchDecl(decl); + + if(state > decl->checkState) + { + decl->SetCheckState(state); + } + } + + void EnusreAllDeclsRec(RefPtr decl) + { + checkDecl(decl); + if (auto containerDecl = as(decl)) + { + for (auto m : containerDecl->Members) + { + EnusreAllDeclsRec(m); + } + } + } + + // A "proper" type is one that can be used as the type of an expression. + // Put simply, it can be a concrete type like `int`, or a generic + // type that is applied to arguments, like `Texture2D`. + // The type `void` is also a proper type, since we can have expressions + // that return a `void` result (e.g., many function calls). + // + // A "non-proper" type is any type that can't actually have values. + // A simple example of this in C++ is `std::vector` - you can't have + // a value of this type. + // + // Part of what this function does is give errors if somebody tries + // to use a non-proper type as the type of a variable (or anything + // else that needs a proper type). + // + // The other thing it handles is the fact that HLSL lets you use + // the name of a non-proper type, and then have the compiler fill + // in the default values for its type arguments (e.g., a variable + // given type `Texture2D` will actually have type `Texture2D`). + bool CoerceToProperTypeImpl( + TypeExp const& typeExp, + RefPtr* outProperType, + DiagnosticSink* diagSink) + { + Type* type = typeExp.type.Ptr(); + if(!type && typeExp.exp) + { + if(auto typeType = as(typeExp.exp->type)) + { + type = typeType->type; + } + } + + if (!type) + { + if (outProperType) + { + *outProperType = nullptr; + } + return false; + } + + if (auto genericDeclRefType = as(type)) + { + // We are using a reference to a generic declaration as a concrete + // type. This means we should substitute in any default parameter values + // if they are available. + // + // TODO(tfoley): A more expressive type system would substitute in + // "fresh" variables and then solve for their values... + // + + auto genericDeclRef = genericDeclRefType->GetDeclRef(); + checkDecl(genericDeclRef.decl); + List> args; + for (RefPtr member : genericDeclRef.getDecl()->Members) + { + if (auto typeParam = as(member)) + { + if (!typeParam->initType.exp) + { + if (diagSink) + { + diagSink->diagnose(typeExp.exp.Ptr(), Diagnostics::genericTypeNeedsArgs, typeExp); + *outProperType = getSession()->getErrorType(); + } + return false; + } + + // TODO: this is one place where syntax should get cloned! + if (outProperType) + args.add(typeParam->initType.exp); + } + else if (auto valParam = as(member)) + { + if (!valParam->initExpr) + { + if (diagSink) + { + diagSink->diagnose(typeExp.exp.Ptr(), Diagnostics::unimplemented, "can't fill in default for generic type parameter"); + *outProperType = getSession()->getErrorType(); + } + return false; + } + + // TODO: this is one place where syntax should get cloned! + if (outProperType) + args.add(valParam->initExpr); + } + else + { + // ignore non-parameter members + } + } + + if (outProperType) + { + *outProperType = InstantiateGenericType(genericDeclRef, args); + } + return true; + } + + // default case: we expect this to already be a proper type + if (outProperType) + { + *outProperType = type; + } + return true; + } + + + + TypeExp CoerceToProperType(TypeExp const& typeExp) + { + TypeExp result = typeExp; + CoerceToProperTypeImpl(typeExp, &result.type, getSink()); + return result; + } + + TypeExp tryCoerceToProperType(TypeExp const& typeExp) + { + TypeExp result = typeExp; + if(!CoerceToProperTypeImpl(typeExp, &result.type, nullptr)) + return TypeExp(); + return result; + } + + // Check a type, and coerce it to be proper + TypeExp CheckProperType(TypeExp typeExp) + { + return CoerceToProperType(TranslateTypeNode(typeExp)); + } + + // For our purposes, a "usable" type is one that can be + // used to declare a function parameter, variable, etc. + // These turn out to be all the proper types except + // `void`. + // + // TODO(tfoley): consider just allowing `void` as a + // simple example of a "unit" type, and get rid of + // this check. + TypeExp CoerceToUsableType(TypeExp const& typeExp) + { + TypeExp result = CoerceToProperType(typeExp); + Type* type = result.type.Ptr(); + if (auto basicType = as(type)) + { + // TODO: `void` shouldn't be a basic type, to make this easier to avoid + if (basicType->baseType == BaseType::Void) + { + // TODO(tfoley): pick the right diagnostic message + getSink()->diagnose(result.exp.Ptr(), Diagnostics::invalidTypeVoid); + result.type = getSession()->getErrorType(); + return result; + } + } + return result; + } + + // Check a type, and coerce it to be usable + TypeExp CheckUsableType(TypeExp typeExp) + { + return CoerceToUsableType(TranslateTypeNode(typeExp)); + } + + RefPtr CheckTerm(RefPtr term) + { + if (!term) return nullptr; + return ExprVisitor::dispatch(term); + } + + RefPtr CreateErrorExpr(Expr* expr) + { + expr->type = QualType(getSession()->getErrorType()); + return expr; + } + + bool IsErrorExpr(RefPtr expr) + { + // TODO: we may want other cases here... + + if (auto errorType = as(expr->type)) + return true; + + return false; + } + + // Capture the "base" expression in case this is a member reference + RefPtr GetBaseExpr(RefPtr expr) + { + if (auto memberExpr = as(expr)) + { + return memberExpr->BaseExpression; + } + else if(auto overloadedExpr = as(expr)) + { + return overloadedExpr->base; + } + return nullptr; + } + + public: + + bool ValuesAreEqual( + RefPtr left, + RefPtr right) + { + if(left == right) return true; + + if(auto leftConst = as(left)) + { + if(auto rightConst = as(right)) + { + return leftConst->value == rightConst->value; + } + } + + if(auto leftVar = as(left)) + { + if(auto rightVar = as(right)) + { + return leftVar->declRef.Equals(rightVar->declRef); + } + } + + return false; + } + + // Compute the cost of using a particular declaration to + // perform implicit type conversion. + ConversionCost getImplicitConversionCost( + Decl* decl) + { + if(auto modifier = decl->FindModifier()) + { + return modifier->cost; + } + + return kConversionCost_Explicit; + } + + bool isEffectivelyScalarForInitializerLists( + RefPtr type) + { + if(as(type)) return false; + if(as(type)) return false; + if(as(type)) return false; + + if(as(type)) + { + return true; + } + + if(as(type)) + { + return true; + } + if(as(type)) + { + return true; + } + if(as(type)) + { + return true; + } + + if(auto declRefType = as(type)) + { + if(as(declRefType->declRef)) + return false; + } + + return true; + } + + /// Should the provided expression (from an initializer list) be used directly to initialize `toType`? + bool shouldUseInitializerDirectly( + RefPtr toType, + RefPtr fromExpr) + { + // A nested initializer list should always be used directly. + // + if(as(fromExpr)) + { + return true; + } + + // If the desired type is a scalar, then we should always initialize + // directly, since it isn't an aggregate. + // + if(isEffectivelyScalarForInitializerLists(toType)) + return true; + + // If the type we are initializing isn't effectively scalar, + // but the initialization expression *is*, then it doesn't + // seem like direct initialization is intended. + // + if(isEffectivelyScalarForInitializerLists(fromExpr->type)) + return false; + + // Once the above cases are handled, the main thing + // we want to check for is whether a direct initialization + // is possible (a type conversion exists). + // + return canCoerce(toType, fromExpr->type); + } + + /// Read a value from an initializer list expression. + /// + /// This reads one or more argument from the initializer list + /// given as `fromInitializerListExpr` to initialize a value + /// of type `toType`. This may involve reading one or + /// more arguments from the initializer list, depending + /// on whether `toType` is an aggregate or not, and on + /// whether the next argument in the initializer list is + /// itself an initializer list. + /// + /// This routine returns `true` if it was able to read + /// arguments that can form a value of type `toType`, + /// and `false` otherwise. + /// + /// If the routine succeeds and `outToExpr` is non-null, + /// then it will be filled in with an expression + /// representing the value (or type `toType`) that was read, + /// or it will be left null to indicate that a default + /// value should be used. + /// + /// If the routine fails and `outToExpr` is non-null, + /// then a suitable diagnostic will be emitted. + /// + bool _readValueFromInitializerList( + RefPtr toType, + RefPtr* outToExpr, + RefPtr fromInitializerListExpr, + UInt &ioInitArgIndex) + { + // First, we will check if we have run out of arguments + // on the initializer list. + // + UInt initArgCount = fromInitializerListExpr->args.getCount(); + if(ioInitArgIndex >= initArgCount) + { + // If we are at the end of the initializer list, + // then our ability to read an argument depends + // on whether the type we are trying to read + // is default-initializable. + // + // For now, we will just pretend like everything + // is default-initializable and move along. + return true; + } + + // Okay, we have at least one initializer list expression, + // so we will look at the next expression and decide + // whether to use it to initialize the desired type + // directly (possibly via casts), or as the first sub-expression + // for aggregate initialization. + // + auto firstInitExpr = fromInitializerListExpr->args[ioInitArgIndex]; + if(shouldUseInitializerDirectly(toType, firstInitExpr)) + { + ioInitArgIndex++; + return _coerce( + toType, + outToExpr, + firstInitExpr->type, + firstInitExpr, + nullptr); + } + + // If there is somehow an error in one of the initialization + // expressions, then everything could be thrown off and we + // shouldn't keep trying to read arguments. + // + if( IsErrorExpr(firstInitExpr) ) + { + // Stop reading arguments, as if we'd reached + // the end of the list. + // + ioInitArgIndex = initArgCount; + return true; + } + + // The fallback case is to recursively read the + // type from the same list as an aggregate. + // + return _readAggregateValueFromInitializerList( + toType, + outToExpr, + fromInitializerListExpr, + ioInitArgIndex); + } + + /// Read an aggregate value from an initializer list expression. + /// + /// This reads one or more arguments from the initializer list + /// given as `fromInitializerListExpr` to initialize the + /// fields/elements of a value of type `toType`. + /// + /// This routine returns `true` if it was able to read + /// arguments that can form a value of type `toType`, + /// and `false` otherwise. + /// + /// If the routine succeeds and `outToExpr` is non-null, + /// then it will be filled in with an expression + /// representing the value (or type `toType`) that was read, + /// or it will be left null to indicate that a default + /// value should be used. + /// + /// If the routine fails and `outToExpr` is non-null, + /// then a suitable diagnostic will be emitted. + /// + bool _readAggregateValueFromInitializerList( + RefPtr inToType, + RefPtr* outToExpr, + RefPtr fromInitializerListExpr, + UInt &ioArgIndex) + { + auto toType = inToType; + UInt argCount = fromInitializerListExpr->args.getCount(); + + // In the case where we need to build a result expression, + // we will collect the new arguments here + List> coercedArgs; + + if(isEffectivelyScalarForInitializerLists(toType)) + { + // For any type that is effectively a non-aggregate, + // we expect to read a single value from the initializer list + // + if(ioArgIndex < argCount) + { + auto arg = fromInitializerListExpr->args[ioArgIndex++]; + return _coerce( + toType, + outToExpr, + arg->type, + arg, + nullptr); + } + else + { + // If there wasn't an initialization + // expression to be found, then we need + // to perform default initialization here. + // + // We will let this case come through the front-end + // as an `InitializerListExpr` with zero arguments, + // and then have the IR generation logic deal with + // synthesizing default values. + } + } + else if (auto toVecType = as(toType)) + { + auto toElementCount = toVecType->elementCount; + auto toElementType = toVecType->elementType; + + UInt elementCount = 0; + if (auto constElementCount = as(toElementCount)) + { + elementCount = (UInt) constElementCount->value; + } + else + { + // We don't know the element count statically, + // so what are we supposed to be doing? + // + if(outToExpr) + { + getSink()->diagnose(fromInitializerListExpr, Diagnostics::cannotUseInitializerListForVectorOfUnknownSize, toElementCount); + } + return false; + } + + for(UInt ee = 0; ee < elementCount; ++ee) + { + RefPtr coercedArg; + bool argResult = _readValueFromInitializerList( + toElementType, + outToExpr ? &coercedArg : nullptr, + fromInitializerListExpr, + ioArgIndex); + + // No point in trying further if any argument fails + if(!argResult) + return false; + + if( coercedArg ) + { + coercedArgs.add(coercedArg); + } + } + } + else if(auto toArrayType = as(toType)) + { + // TODO(tfoley): If we can compute the size of the array statically, + // then we want to check that there aren't too many initializers present + + auto toElementType = toArrayType->baseType; + + if(auto toElementCount = toArrayType->ArrayLength) + { + // In the case of a sized array, we need to check that the number + // of elements being initialized matches what was declared. + // + UInt elementCount = 0; + if (auto constElementCount = as(toElementCount)) + { + elementCount = (UInt) constElementCount->value; + } + else + { + // We don't know the element count statically, + // so what are we supposed to be doing? + // + if(outToExpr) + { + getSink()->diagnose(fromInitializerListExpr, Diagnostics::cannotUseInitializerListForArrayOfUnknownSize, toElementCount); + } + return false; + } + + for(UInt ee = 0; ee < elementCount; ++ee) + { + RefPtr coercedArg; + bool argResult = _readValueFromInitializerList( + toElementType, + outToExpr ? &coercedArg : nullptr, + fromInitializerListExpr, + ioArgIndex); + + // No point in trying further if any argument fails + if(!argResult) + return false; + + if( coercedArg ) + { + coercedArgs.add(coercedArg); + } + } + } + else + { + // In the case of an unsized array type, we will use the + // number of arguments to the initializer to determine + // the element count. + // + UInt elementCount = 0; + while(ioArgIndex < argCount) + { + RefPtr coercedArg; + bool argResult = _readValueFromInitializerList( + toElementType, + outToExpr ? &coercedArg : nullptr, + fromInitializerListExpr, + ioArgIndex); + + // No point in trying further if any argument fails + if(!argResult) + return false; + + elementCount++; + + if( coercedArg ) + { + coercedArgs.add(coercedArg); + } + } + + // We have a new type for the conversion, based on what + // we learned. + toType = getSession()->getArrayType( + toElementType, + new ConstantIntVal(elementCount)); + } + } + else if(auto toMatrixType = as(toType)) + { + // In the general case, the initializer list might comprise + // both vectors and scalars. + // + // The traditional HLSL compilers treat any vectors in + // the initializer list exactly equivalent to their sequence + // of scalar elements, and don't care how this might, or + // might not, align with the rows of the matrix. + // + // We will draw a line in the sand and say that an initializer + // list for a matrix will act as if the matrix type were an + // array of vectors for the rows. + + + UInt rowCount = 0; + auto toRowType = createVectorType( + toMatrixType->getElementType(), + toMatrixType->getColumnCount()); + + if (auto constRowCount = as(toMatrixType->getRowCount())) + { + rowCount = (UInt) constRowCount->value; + } + else + { + // We don't know the element count statically, + // so what are we supposed to be doing? + // + if(outToExpr) + { + getSink()->diagnose(fromInitializerListExpr, Diagnostics::cannotUseInitializerListForMatrixOfUnknownSize, toMatrixType->getRowCount()); + } + return false; + } + + for(UInt rr = 0; rr < rowCount; ++rr) + { + RefPtr coercedArg; + bool argResult = _readValueFromInitializerList( + toRowType, + outToExpr ? &coercedArg : nullptr, + fromInitializerListExpr, + ioArgIndex); + + // No point in trying further if any argument fails + if(!argResult) + return false; + + if( coercedArg ) + { + coercedArgs.add(coercedArg); + } + } + } + else if(auto toDeclRefType = as(toType)) + { + auto toTypeDeclRef = toDeclRefType->declRef; + if(auto toStructDeclRef = toTypeDeclRef.as()) + { + // Trying to initialize a `struct` type given an initializer list. + // We will go through the fields in order and try to match them + // up with initializer arguments. + // + for(auto fieldDeclRef : getMembersOfType(toStructDeclRef)) + { + RefPtr coercedArg; + bool argResult = _readValueFromInitializerList( + GetType(fieldDeclRef), + outToExpr ? &coercedArg : nullptr, + fromInitializerListExpr, + ioArgIndex); + + // No point in trying further if any argument fails + if(!argResult) + return false; + + if( coercedArg ) + { + coercedArgs.add(coercedArg); + } + } + } + } + else + { + // We shouldn't get to this case in practice, + // but just in case we'll consider an initializer + // list invalid if we are trying to read something + // off of it that wasn't handled by the cases above. + // + if(outToExpr) + { + getSink()->diagnose(fromInitializerListExpr, Diagnostics::cannotUseInitializerListForType, inToType); + } + return false; + } + + // We were able to coerce all the arguments given, and so + // we need to construct a suitable expression to remember the result + // + if(outToExpr) + { + auto toInitializerListExpr = new InitializerListExpr(); + toInitializerListExpr->loc = fromInitializerListExpr->loc; + toInitializerListExpr->type = QualType(toType); + toInitializerListExpr->args = coercedArgs; + + *outToExpr = toInitializerListExpr; + } + + return true; + } + + /// Coerce an initializer-list expression to a specific type. + /// + /// This reads one or more arguments from the initializer list + /// given as `fromInitializerListExpr` to initialize the + /// fields/elements of a value of type `toType`. + /// + /// This routine returns `true` if it was able to read + /// arguments that can form a value of type `toType`, + /// with no arguments left over, and `false` otherwise. + /// + /// If the routine succeeds and `outToExpr` is non-null, + /// then it will be filled in with an expression + /// representing the value (or type `toType`) that was read, + /// or it will be left null to indicate that a default + /// value should be used. + /// + /// If the routine fails and `outToExpr` is non-null, + /// then a suitable diagnostic will be emitted. + /// + bool _coerceInitializerList( + RefPtr toType, + RefPtr* outToExpr, + RefPtr fromInitializerListExpr) + { + UInt argCount = fromInitializerListExpr->args.getCount(); + UInt argIndex = 0; + + // TODO: we should handle the special case of `{0}` as an initializer + // for arbitrary `struct` types here. + + if(!_readAggregateValueFromInitializerList(toType, outToExpr, fromInitializerListExpr, argIndex)) + return false; + + if(argIndex != argCount) + { + if( outToExpr ) + { + getSink()->diagnose(fromInitializerListExpr, Diagnostics::tooManyInitializers, argIndex, argCount); + } + } + + return true; + } + + /// Report that implicit type coercion is not possible. + bool _failedCoercion( + RefPtr toType, + RefPtr* outToExpr, + RefPtr fromExpr) + { + if(outToExpr) + { + getSink()->diagnose(fromExpr->loc, Diagnostics::typeMismatch, toType, fromExpr->type); + } + return false; + } + + /// Central engine for implementing implicit coercion logic + /// + /// This function tries to find an implicit conversion path from + /// `fromType` to `toType`. It returns `true` if a conversion + /// is found, and `false` if not. + /// + /// If a conversion is found, then its cost will be written to `outCost`. + /// + /// If a `fromExpr` is provided, it must be of type `fromType`, + /// and represent a value to be converted. + /// + /// If `outToExpr` is non-null, and if a conversion is found, then + /// `*outToExpr` will be set to an expression that performs the + /// implicit conversion of `fromExpr` (which must be non-null + /// to `toType`). + /// + /// The case where `outToExpr` is non-null is used to identify + /// when a conversion is being done "for real" so that diagnostics + /// should be emitted on failure. + /// + bool _coerce( + RefPtr toType, + RefPtr* outToExpr, + RefPtr fromType, + RefPtr fromExpr, + ConversionCost* outCost) + { + // An important and easy case is when the "to" and "from" types are equal. + // + if( toType->Equals(fromType) ) + { + if(outToExpr) + *outToExpr = fromExpr; + if(outCost) + *outCost = kConversionCost_None; + return true; + } + + // Another important case is when either the "to" or "from" type + // represents an error. In such a case we must have already + // reporeted the error, so it is better to allow the conversion + // to pass than to report a "cascading" error that might not + // make any sense. + // + if(as(toType) || as(fromType)) + { + if(outToExpr) + *outToExpr = CreateImplicitCastExpr(toType, fromExpr); + if(outCost) + *outCost = kConversionCost_None; + return true; + } + + // Coercion from an initializer list is allowed for many types, + // so we will farm that out to its own subroutine. + // + if( auto fromInitializerListExpr = as(fromExpr)) + { + if( !_coerceInitializerList( + toType, + outToExpr, + fromInitializerListExpr) ) + { + return false; + } + + // For now, we treat coercion from an initializer list + // as having no cost, so that all conversions from initializer + // lists are equally valid. This is fine given where initializer + // lists are allowed to appear now, but might need to be made + // more strict if we allow for initializer lists in more + // places in the language (e.g., as function arguments). + // + if(outCost) + { + *outCost = kConversionCost_None; + } + + return true; + } + + // If we are casting to an interface type, then that will succeed + // if the "from" type conforms to the interface. + // + if (auto toDeclRefType = as(toType)) + { + auto toTypeDeclRef = toDeclRefType->declRef; + if (auto interfaceDeclRef = toTypeDeclRef.as()) + { + if(auto witness = tryGetInterfaceConformanceWitness(fromType, interfaceDeclRef)) + { + if (outToExpr) + *outToExpr = createCastToInterfaceExpr(toType, fromExpr, witness); + if (outCost) + *outCost = kConversionCost_CastToInterface; + return true; + } + } + } + + // We allow implicit conversion of a parameter group type like + // `ConstantBuffer` or `ParameterBlock` to its element + // type `X`. + // + if(auto fromParameterGroupType = as(fromType)) + { + auto fromElementType = fromParameterGroupType->getElementType(); + + // If we convert, e.g., `ConstantBuffer to `A`, we will allow + // subsequent conversion of `A` to `B` if such a conversion + // is possible. + // + ConversionCost subCost = kConversionCost_None; + + RefPtr derefExpr; + if(outToExpr) + { + derefExpr = new DerefExpr(); + derefExpr->base = fromExpr; + derefExpr->type = QualType(fromElementType); + } + + if(!_coerce( + toType, + outToExpr, + fromElementType, + derefExpr, + &subCost)) + { + return false; + } + + if(outCost) + *outCost = subCost + kConversionCost_ImplicitDereference; + return true; + } + + // The main general-purpose approach for conversion is + // using suitable marked initializer ("constructor") + // declarations on the target type. + // + // This is treated as a form of overload resolution, + // since we are effectively forming an overloaded + // call to one of the initializers in the target type. + + OverloadResolveContext overloadContext; + overloadContext.disallowNestedConversions = true; + overloadContext.argCount = 1; + overloadContext.argTypes = &fromType; + + overloadContext.originalExpr = nullptr; + if(fromExpr) + { + overloadContext.loc = fromExpr->loc; + overloadContext.funcLoc = fromExpr->loc; + overloadContext.args = &fromExpr; + } + + overloadContext.baseExpr = nullptr; + overloadContext.mode = OverloadResolveContext::Mode::JustTrying; + + AddTypeOverloadCandidates(toType, overloadContext, toType); + + // After all of the overload candidates have been added + // to the context and processed, we need to see whether + // there was one best overload or not. + // + if(overloadContext.bestCandidates.getCount() != 0) + { + // In this case there were multiple equally-good candidates to call. + // + // We will start by checking if the candidates are + // even applicable, because if not, then we shouldn't + // consider the conversion as possible. + // + if(overloadContext.bestCandidates[0].status != OverloadCandidate::Status::Applicable) + return _failedCoercion(toType, outToExpr, fromExpr); + + // If all of the candidates in `bestCandidates` are applicable, + // then we have an ambiguity. + // + // We will compute a nominal conversion cost as the minimum over + // all the conversions available. + // + ConversionCost bestCost = kConversionCost_Explicit; + for(auto candidate : overloadContext.bestCandidates) + { + ConversionCost candidateCost = getImplicitConversionCost( + candidate.item.declRef.getDecl()); + + if(candidateCost < bestCost) + bestCost = candidateCost; + } + + // Conceptually, we want to treat the conversion as + // possible, but report it as ambiguous if we actually + // need to reify the result as an expression. + // + if(outToExpr) + { + getSink()->diagnose(fromExpr, Diagnostics::ambiguousConversion, fromType, toType); + + *outToExpr = CreateErrorExpr(fromExpr); + } + + if(outCost) + *outCost = bestCost; + + return true; + } + else if(overloadContext.bestCandidate) + { + // If there is a single best candidate for conversion, + // then we want to use it. + // + // It is possible that there was a single best candidate, + // but it wasn't actually usable, so we will check for + // that case first. + // + if(overloadContext.bestCandidate->status != OverloadCandidate::Status::Applicable) + return _failedCoercion(toType, outToExpr, fromExpr); + + // Next, we need to look at the implicit conversion + // cost associated with the initializer we are invoking. + // + ConversionCost cost = getImplicitConversionCost( + overloadContext.bestCandidate->item.declRef.getDecl());; + + // If the cost is too high to be usable as an + // implicit conversion, then we will report the + // conversion as possible (so that an overload involving + // this conversion will be selected over one without), + // but then emit a diagnostic when actually reifying + // the result expression. + // + if( cost >= kConversionCost_Explicit ) + { + if( outToExpr ) + { + getSink()->diagnose(fromExpr, Diagnostics::typeMismatch, toType, fromType); + getSink()->diagnose(fromExpr, Diagnostics::noteExplicitConversionPossible, fromType, toType); + } + } + + if(outCost) + *outCost = cost; + + if(outToExpr) + { + // The logic here is a bit ugly, to deal with the fact that + // `CompleteOverloadCandidate` will, left to its own devices, + // construct a vanilla `InvokeExpr` to represent the call + // to the initializer we found, while we *want* it to + // create some variety of `ImplicitCastExpr`. + // + // Now, it just so happens that `CompleteOverloadCandidate` + // will use the "original" expression if one is available, + // so we'll create one and initialize it here. + // We fill in the location and arguments, but not the + // base expression (the callee), since that will come + // from the selected overload candidate. + // + auto castExpr = createImplicitCastExpr(); + castExpr->loc = fromExpr->loc; + castExpr->Arguments.add(fromExpr); + // + // Next we need to set our cast expression as the "original" + // expression and then complete the overload process. + // + overloadContext.originalExpr = castExpr; + *outToExpr = CompleteOverloadCandidate(overloadContext, *overloadContext.bestCandidate); + // + // However, the above isn't *quite* enough, because + // the process of completing the overload candidate + // might overwrite the argument list that was passed + // in to overload resolution, and in this case that + // "argument list" was just a pointer to `fromExpr`. + // + // That means we need to clear the argument list and + // reload it from `fromExpr` to make sure that we + // got the arguments *after* any transformations + // were applied. + // For right now this probably doesn't matter, + // because we don't allow nested implicit conversions, + // but I'd rather play it safe. + // + castExpr->Arguments.clear(); + castExpr->Arguments.add(fromExpr); + } + + return true; + } + + return _failedCoercion(toType, outToExpr, fromExpr); + } + + /// Check whether implicit type coercion from `fromType` to `toType` is possible. + /// + /// If conversion is possible, returns `true` and sets `outCost` to the cost + /// of the conversion found (if `outCost` is non-null). + /// + /// If conversion is not possible, returns `false`. + /// + bool canCoerce( + RefPtr toType, + RefPtr fromType, + ConversionCost* outCost = 0) + { + // As an optimization, we will maintain a cache of conversion results + // for basic types such as scalars and vectors. + // + BasicTypeKey key1, key2; + BasicTypeKeyPair cacheKey; + bool shouldAddToCache = false; + ConversionCost cost; + TypeCheckingCache* typeCheckingCache = getSession()->getTypeCheckingCache(); + if( key1.fromType(toType.Ptr()) && key2.fromType(fromType.Ptr()) ) + { + cacheKey.type1 = key1; + cacheKey.type2 = key2; + + if (typeCheckingCache->conversionCostCache.TryGetValue(cacheKey, cost)) + { + if (outCost) + *outCost = cost; + return cost != kConversionCost_Impossible; + } + else + shouldAddToCache = true; + } + + // If there was no suitable entry in the cache, + // then we fall back to the general-purpose + // conversion checking logic. + // + // Note that we are passing in `nullptr` as + // the output expression to be constructed, + // which suppresses emission of any diagnostics + // during the coercion process. + // + bool rs = _coerce( + toType, + nullptr, + fromType, + nullptr, + &cost); + + if (outCost) + *outCost = cost; + + if (shouldAddToCache) + { + if (!rs) + cost = kConversionCost_Impossible; + typeCheckingCache->conversionCostCache[cacheKey] = cost; + } + + return rs; + } + + RefPtr createImplicitCastExpr() + { + return new ImplicitCastExpr(); + } + + RefPtr CreateImplicitCastExpr( + RefPtr toType, + RefPtr fromExpr) + { + RefPtr castExpr = createImplicitCastExpr(); + + auto typeType = getTypeType(toType); + + auto typeExpr = new SharedTypeExpr(); + typeExpr->type.type = typeType; + typeExpr->base.type = toType; + + castExpr->loc = fromExpr->loc; + castExpr->FunctionExpr = typeExpr; + castExpr->type = QualType(toType); + castExpr->Arguments.add(fromExpr); + return castExpr; + } + + /// Create an "up-cast" from a value to an interface type + /// + /// This operation logically constructs an "existential" value, + /// which packages up the value, its type, and the witness + /// of its conformance to the interface. + /// + RefPtr createCastToInterfaceExpr( + RefPtr toType, + RefPtr fromExpr, + RefPtr witness) + { + RefPtr expr = new CastToInterfaceExpr(); + expr->loc = fromExpr->loc; + expr->type = QualType(toType); + expr->valueArg = fromExpr; + expr->witnessArg = witness; + return expr; + } + + /// Implicitly coerce `fromExpr` to `toType` and diagnose errors if it isn't possible + RefPtr coerce( + RefPtr toType, + RefPtr fromExpr) + { + RefPtr expr; + if (!_coerce( + toType, + &expr, + fromExpr->type.Ptr(), + fromExpr.Ptr(), + nullptr)) + { + // Note(tfoley): We don't call `CreateErrorExpr` here, because that would + // clobber the type on `fromExpr`, and an invariant here is that coercion + // really shouldn't *change* the expression that is passed in, but should + // introduce new AST nodes to coerce its value to a different type... + return CreateImplicitCastExpr( + getSession()->getErrorType(), + fromExpr); + } + return expr; + } + + void CheckVarDeclCommon(RefPtr varDecl) + { + // A variable that didn't have an explicit type written must + // have its type inferred from the initial-value expression. + // + if(!varDecl->type.exp) + { + // In this case we need to perform all checking of the + // variable (including semantic checking of the initial-value + // expression) during the first phase of checking. + + auto initExpr = varDecl->initExpr; + if(!initExpr) + { + getSink()->diagnose(varDecl, Diagnostics::varWithoutTypeMustHaveInitializer); + varDecl->type.type = getSession()->getErrorType(); + } + else + { + initExpr = CheckExpr(initExpr); + + // TODO: We might need some additional steps here to ensure + // that the type of the expression is one we are okay with + // inferring. E.g., if we ever decide that integer and floating-point + // literals have a distinct type from the standard int/float types, + // then we would need to "decay" a literal to an explicit type here. + + varDecl->initExpr = initExpr; + varDecl->type.type = initExpr->type; + } + + varDecl->SetCheckState(DeclCheckState::Checked); + } + else + { + if (function || checkingPhase == CheckingPhase::Header) + { + TypeExp typeExp = CheckUsableType(varDecl->type); + varDecl->type = typeExp; + if (varDecl->type.Equals(getSession()->getVoidType())) + { + getSink()->diagnose(varDecl, Diagnostics::invalidTypeVoid); + } + } + + if (checkingPhase == CheckingPhase::Body) + { + if (auto initExpr = varDecl->initExpr) + { + initExpr = CheckTerm(initExpr); + initExpr = coerce(varDecl->type.Ptr(), initExpr); + varDecl->initExpr = initExpr; + + // If this is an array variable, then we first want to give + // it a chance to infer an array size from its initializer + // + // TODO(tfoley): May need to extend this to handle the + // multi-dimensional case... + // + maybeInferArraySizeForVariable(varDecl); + // + // Next we want to make sure that the declared (or inferred) + // size for the array meets whatever language-specific + // constraints we want to enforce (e.g., disallow empty + // arrays in specific cases) + // + validateArraySizeForVariable(varDecl); + } + } + + } + varDecl->SetCheckState(getCheckedState()); + } + + // Fill in default substitutions for the 'subtype' part of a type constraint decl + void CheckConstraintSubType(TypeExp& typeExp) + { + if (auto sharedTypeExpr = as(typeExp.exp)) + { + if (auto declRefType = as(sharedTypeExpr->base)) + { + declRefType->declRef.substitutions = createDefaultSubstitutions(getSession(), declRefType->declRef.getDecl()); + + if (auto typetype = as(typeExp.exp->type)) + typetype->type = declRefType; + } + } + } + + void CheckGenericConstraintDecl(GenericTypeConstraintDecl* decl) + { + // TODO: are there any other validations we can do at this point? + // + // There probably needs to be a kind of "occurs check" to make + // sure that the constraint actually applies to at least one + // of the parameters of the generic. + if (decl->checkState == DeclCheckState::Unchecked) + { + decl->checkState = getCheckedState(); + CheckConstraintSubType(decl->sub); + decl->sub = TranslateTypeNodeForced(decl->sub); + decl->sup = TranslateTypeNodeForced(decl->sup); + } + } + + void checkDecl(Decl* decl) + { + EnsureDecl(decl, checkingPhase == CheckingPhase::Header ? DeclCheckState::CheckedHeader : DeclCheckState::Checked); + } + + void checkGenericDeclHeader(GenericDecl* genericDecl) + { + if (genericDecl->IsChecked(DeclCheckState::CheckedHeader)) + return; + // check the parameters + for (auto m : genericDecl->Members) + { + if (auto typeParam = as(m)) + { + typeParam->initType = CheckProperType(typeParam->initType); + } + else if (auto valParam = as(m)) + { + // TODO: some real checking here... + CheckVarDeclCommon(valParam); + } + else if (auto constraint = as(m)) + { + CheckGenericConstraintDecl(constraint); + } + } + + genericDecl->SetCheckState(DeclCheckState::CheckedHeader); + } + + void visitGenericDecl(GenericDecl* genericDecl) + { + checkGenericDeclHeader(genericDecl); + + // check the nested declaration + // TODO: this needs to be done in an appropriate environment... + checkDecl(genericDecl->inner); + genericDecl->SetCheckState(getCheckedState()); + } + + void visitGenericTypeConstraintDecl(GenericTypeConstraintDecl * genericConstraintDecl) + { + if (genericConstraintDecl->IsChecked(DeclCheckState::CheckedHeader)) + return; + // check the type being inherited from + auto base = genericConstraintDecl->sup; + base = TranslateTypeNode(base); + genericConstraintDecl->sup = base; + } + + void visitInheritanceDecl(InheritanceDecl* inheritanceDecl) + { + if (inheritanceDecl->IsChecked(DeclCheckState::CheckedHeader)) + return; + // check the type being inherited from + auto base = inheritanceDecl->base; + CheckConstraintSubType(base); + base = TranslateTypeNode(base); + inheritanceDecl->base = base; + + // For now we only allow inheritance from interfaces, so + // we will validate that the type expression names an interface + + if(auto declRefType = as(base.type)) + { + if(auto interfaceDeclRef = declRefType->declRef.as()) + { + return; + } + } + else if(base.type.is()) + { + // If an error was already produced, don't emit a cascading error. + return; + } + + // If type expression didn't name an interface, we'll emit an error here + // TODO: deal with the case of an error in the type expression (don't cascade) + getSink()->diagnose( base.exp, Diagnostics::expectedAnInterfaceGot, base.type); + } + + RefPtr checkConstantIntVal( + RefPtr expr) + { + // First type-check the expression as normal + expr = CheckExpr(expr); + + auto intVal = CheckIntegerConstantExpression(expr.Ptr()); + if(!intVal) + return nullptr; + + auto constIntVal = as(intVal); + if(!constIntVal) + { + getSink()->diagnose(expr->loc, Diagnostics::expectedIntegerConstantNotLiteral); + return nullptr; + } + return constIntVal; + } + + RefPtr checkConstantEnumVal( + RefPtr expr) + { + // First type-check the expression as normal + expr = CheckExpr(expr); + + auto intVal = CheckEnumConstantExpression(expr.Ptr()); + if(!intVal) + return nullptr; + + auto constIntVal = as(intVal); + if(!constIntVal) + { + getSink()->diagnose(expr->loc, Diagnostics::expectedIntegerConstantNotLiteral); + return nullptr; + } + return constIntVal; + } + + // Check an expression, coerce it to the `String` type, and then + // ensure that it has a literal (not just compile-time constant) value. + bool checkLiteralStringVal( + RefPtr expr, + String* outVal) + { + // TODO: This should actually perform semantic checking, etc., + // but for now we are just going to look for a direct string + // literal AST node. + + if(auto stringLitExpr = as(expr)) + { + if(outVal) + { + *outVal = stringLitExpr->value; + } + return true; + } + + getSink()->diagnose(expr, Diagnostics::expectedAStringLiteral); + + return false; + } + + void visitSyntaxDecl(SyntaxDecl*) + { + // These are only used in the stdlib, so no checking is needed + } + + void visitAttributeDecl(AttributeDecl*) + { + // These are only used in the stdlib, so no checking is needed + } + + void visitGenericTypeParamDecl(GenericTypeParamDecl*) + { + // These are only used in the stdlib, so no checking is needed for now + } + + void visitGenericValueParamDecl(GenericValueParamDecl*) + { + // These are only used in the stdlib, so no checking is needed for now + } + + void visitModifier(Modifier*) + { + // Do nothing with modifiers for now + } + + AttributeDecl* lookUpAttributeDecl(Name* attributeName, Scope* scope) + { + // Look up the name and see what we find. + // + // TODO: This needs to have some special filtering or naming + // rules to keep us from seeing shadowing variable declarations. + auto lookupResult = lookUp(getSession(), this, attributeName, scope, LookupMask::Attribute); + + // If the result was overloaded, + // then we aren't going to be able to extract a single decl. + if(lookupResult.isOverloaded()) + return nullptr; + + if (lookupResult.isValid()) + { + auto decl = lookupResult.item.declRef.getDecl(); + if (auto attributeDecl = as(decl)) + { + return attributeDecl; + } + else + { + return nullptr; + } + } + + // If we couldn't find a system attribute, try looking up as a user defined attribute + // A user defined attribute class is defined as a struct type with a "UserDefinedAttributeAttribute" modifier + lookupResult = lookUp(getSession(), this, getSession()->getNameObj(attributeName->text + "Attribute"), scope, LookupMask::type); + if (lookupResult.isOverloaded()) + { + // see if we have already created an AttributeDecl for this attribute struct + for (auto alt : lookupResult.items) + { + if (auto adecl = alt.declRef.as()) + return adecl.getDecl(); + } + } + // If we still cannot find any thing, quit + if (!lookupResult.isValid() || lookupResult.isOverloaded()) + return nullptr; + // Now construct an AttributeDecl for this user defined attribute class for future lookup + auto userDefAttribAttrib = lookupResult.item.declRef.decl->FindModifier(); + if (!userDefAttribAttrib) + return nullptr; + // create an AttributeDecl for the user defined attribute + auto structAttribDef = lookupResult.item.declRef.as().getDecl(); + RefPtr attribDecl = new AttributeDecl(); + attribDecl->nameAndLoc = structAttribDef->nameAndLoc; + attribDecl->loc = structAttribDef->loc; + attribDecl->nextInContainerWithSameName = structAttribDef->nextInContainerWithSameName; + // create a __attributeTarget modifier for the attribute class definition + RefPtr targetModifier = new AttributeTargetModifier(); + targetModifier->syntaxClass = userDefAttribAttrib->targetSyntaxClass; + targetModifier->loc = structAttribDef->loc; + targetModifier->next = attribDecl->modifiers.first; + attribDecl->modifiers.first = targetModifier; + structAttribDef->nextInContainerWithSameName = attribDecl.Ptr(); + // we should create UserDefinedAttribute nodes for all user defined attribute instances + attribDecl->syntaxClass = getSession()->findSyntaxClass(getSession()->getNameObj("UserDefinedAttribute")); + for (auto member : structAttribDef->Members) + { + if (auto varMember = as(member)) + { + RefPtr param = new ParamDecl(); + param->nameAndLoc = member->nameAndLoc; + param->type = varMember->type; + param->loc = member->loc; + attribDecl->Members.add(param); + } + } + // add the attribute class definition to the syntax tree, so it can be found + structAttribDef->ParentDecl->Members.add(attribDecl.Ptr()); + structAttribDef->ParentDecl->memberDictionaryIsValid = false; + // do necessary checks on this newly constructed node + checkDecl(attribDecl.Ptr()); + return attribDecl.Ptr(); + } + + bool hasIntArgs(Attribute* attr, int numArgs) + { + if (int(attr->args.getCount()) != numArgs) + { + return false; + } + for (int i = 0; i < numArgs; ++i) + { + if (!as(attr->args[i])) + { + return false; + } + } + return true; + } + + bool hasStringArgs(Attribute* attr, int numArgs) + { + if (int(attr->args.getCount()) != numArgs) + { + return false; + } + for (int i = 0; i < numArgs; ++i) + { + if (!as(attr->args[i])) + { + return false; + } + } + return true; + } + + bool getAttributeTargetSyntaxClasses(SyntaxClass & cls, uint32_t typeFlags) + { + if (typeFlags == (int)UserDefinedAttributeTargets::Struct) + { + cls = getSession()->findSyntaxClass(getSession()->getNameObj("StructDecl")); + return true; + } + if (typeFlags == (int)UserDefinedAttributeTargets::Var) + { + cls = getSession()->findSyntaxClass(getSession()->getNameObj("VarDecl")); + return true; + } + if (typeFlags == (int)UserDefinedAttributeTargets::Function) + { + cls = getSession()->findSyntaxClass(getSession()->getNameObj("FuncDecl")); + return true; + } + return false; + } + + bool validateAttribute(RefPtr attr, AttributeDecl* attribClassDecl) + { + if(auto numThreadsAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 3); + auto xVal = checkConstantIntVal(attr->args[0]); + auto yVal = checkConstantIntVal(attr->args[1]); + auto zVal = checkConstantIntVal(attr->args[2]); + + if(!xVal) return false; + if(!yVal) return false; + if(!zVal) return false; + + numThreadsAttr->x = (int32_t) xVal->value; + numThreadsAttr->y = (int32_t) yVal->value; + numThreadsAttr->z = (int32_t) zVal->value; + } + else if (auto bindingAttr = as(attr)) + { + // This must be vk::binding or gl::binding (as specified in core.meta.slang under vk_binding/gl_binding) + // Must have 2 int parameters. Ideally this would all be checked from the specification + // in core.meta.slang, but that's not completely implemented. So for now we check here. + if (attr->args.getCount() != 2) + { + return false; + } + + // TODO(JS): Prior validation currently doesn't ensure both args are ints (as specified in core.meta.slang), so check here + // to make sure they both are + auto binding = checkConstantIntVal(attr->args[0]); + auto set = checkConstantIntVal(attr->args[1]); + + if (binding == nullptr || set == nullptr) + { + return false; + } + + bindingAttr->binding = int32_t(binding->value); + bindingAttr->set = int32_t(set->value); + } + else if (auto maxVertexCountAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + auto val = checkConstantIntVal(attr->args[0]); + + if(!val) return false; + + maxVertexCountAttr->value = (int32_t)val->value; + } + else if(auto instanceAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + auto val = checkConstantIntVal(attr->args[0]); + + if(!val) return false; + + instanceAttr->value = (int32_t)val->value; + } + else if(auto entryPointAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + + String stageName; + if(!checkLiteralStringVal(attr->args[0], &stageName)) + { + return false; + } + + auto stage = findStageByName(stageName); + if(stage == Stage::Unknown) + { + getSink()->diagnose(attr->args[0], Diagnostics::unknownStageName, stageName); + } + + entryPointAttr->stage = stage; + } + else if ((as(attr)) || + (as(attr)) || + (as(attr)) || + (as(attr)) || + (as(attr))) + { + // Let it go thru iff single string attribute + if (!hasStringArgs(attr, 1)) + { + getSink()->diagnose(attr, Diagnostics::expectedSingleStringArg, attr->name); + } + } + else if (as(attr)) + { + // Let it go thru iff single integral attribute + if (!hasIntArgs(attr, 1)) + { + getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->name); + } + } + else if (as(attr)) + { + // Has no args + SLANG_ASSERT(attr->args.getCount() == 0); + } + else if (as(attr)) + { + // Has no args + SLANG_ASSERT(attr->args.getCount() == 0); + } + else if (as(attr)) + { + // Has no args + SLANG_ASSERT(attr->args.getCount() == 0); + } + else if (auto attrUsageAttr = as(attr)) + { + uint32_t targetClassId = (uint32_t)UserDefinedAttributeTargets::None; + if (attr->args.getCount() == 1) + { + RefPtr outIntVal; + if (auto cInt = checkConstantEnumVal(attr->args[0])) + { + targetClassId = (uint32_t)(cInt->value); + } + else + { + getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->name); + return false; + } + } + if (!getAttributeTargetSyntaxClasses(attrUsageAttr->targetSyntaxClass, targetClassId)) + { + getSink()->diagnose(attr, Diagnostics::invalidAttributeTarget); + return false; + } + } + else if (auto unrollAttr = as(attr)) + { + // Check has an argument. We need this because default behavior is to give an error + // if an attribute has arguments, but not handled explicitly (and the default param will come through + // as 1 arg if nothing is specified) + SLANG_ASSERT(attr->args.getCount() == 1); + } + else if (auto userDefAttr = as(attr)) + { + // check arguments against attribute parameters defined in attribClassDecl + Index paramIndex = 0; + auto params = attribClassDecl->getMembersOfType(); + for (auto paramDecl : params) + { + if (paramIndex < attr->args.getCount()) + { + auto & arg = attr->args[paramIndex]; + bool typeChecked = false; + if (auto basicType = as(paramDecl->getType())) + { + if (basicType->baseType == BaseType::Int) + { + if (auto cint = checkConstantIntVal(arg)) + { + attr->intArgVals[(uint32_t)paramIndex] = cint; + } + typeChecked = true; + } + } + if (!typeChecked) + { + arg = CheckExpr(arg); + arg = coerce(paramDecl->getType(), arg); + } + } + paramIndex++; + } + if (params.getCount() < attr->args.getCount()) + { + getSink()->diagnose(attr, Diagnostics::tooManyArguments, attr->args.getCount(), params.getCount()); + } + else if (params.getCount() > attr->args.getCount()) + { + getSink()->diagnose(attr, Diagnostics::notEnoughArguments, attr->args.getCount(), params.getCount()); + } + } + else if (auto formatAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + + String formatName; + if(!checkLiteralStringVal(attr->args[0], &formatName)) + { + return false; + } + + ImageFormat format = ImageFormat::unknown; + if(!findImageFormatByName(formatName.getBuffer(), &format)) + { + getSink()->diagnose(attr->args[0], Diagnostics::unknownImageFormatName, formatName); + } + + formatAttr->format = format; + } + else + { + if(attr->args.getCount() == 0) + { + // If the attribute took no arguments, then we will + // assume it is valid as written. + } + else + { + // We should be special-casing the checking of any attribute + // with a non-zero number of arguments. + SLANG_DIAGNOSE_UNEXPECTED(getSink(), attr, "unhandled attribute"); + return false; + } + } + + return true; + } + + RefPtr checkAttribute( + UncheckedAttribute* uncheckedAttr, + ModifiableSyntaxNode* attrTarget) + { + auto attrName = uncheckedAttr->getName(); + auto attrDecl = lookUpAttributeDecl( + attrName, + uncheckedAttr->scope); + + if(!attrDecl) + { + getSink()->diagnose(uncheckedAttr, Diagnostics::unknownAttributeName, attrName); + return uncheckedAttr; + } + + if(!attrDecl->syntaxClass.isSubClassOf()) + { + SLANG_DIAGNOSE_UNEXPECTED(getSink(), attrDecl, "attribute declaration does not reference an attribute class"); + return uncheckedAttr; + } + + // Manage scope + RefPtr attrInstance = attrDecl->syntaxClass.createInstance(); + auto attr = attrInstance.as(); + if(!attr) + { + SLANG_DIAGNOSE_UNEXPECTED(getSink(), attrDecl, "attribute class did not yield an attribute object"); + return uncheckedAttr; + } + + // We are going to replace the unchecked attribute with the checked one. + + // First copy all of the state over from the original attribute. + attr->name = uncheckedAttr->name; + attr->args = uncheckedAttr->args; + attr->loc = uncheckedAttr->loc; + + // We will start with checking steps that can be applied independent + // of the concrete attribute type that was selected. These only need + // us to look at the attribute declaration itself. + // + // Start by doing argument/parameter matching + UInt argCount = attr->args.getCount(); + UInt paramCounter = 0; + bool mismatch = false; + for(auto paramDecl : attrDecl->getMembersOfType()) + { + UInt paramIndex = paramCounter++; + if( paramIndex < argCount ) + { + // TODO: support checking the argument against the declared + // type for the parameter. + } + else + { + // We didn't have enough arguments for the + // number of parameters declared. + if(auto defaultArg = paramDecl->initExpr) + { + // The attribute declaration provided a default, + // so we should use that. + // + // TODO: we need to figure out how to hook up + // default arguments as needed. + // For now just copy the expression over. + + attr->args.add(paramDecl->initExpr); + } + else + { + mismatch = true; + } + } + } + UInt paramCount = paramCounter; + + if(mismatch) + { + getSink()->diagnose(attr, Diagnostics::attributeArgumentCountMismatch, attrName, paramCount, argCount); + return uncheckedAttr; + } + + // The next bit of validation that we can apply semi-generically + // is to validate that the target for this attribute is a valid + // one for the chosen attribute. + // + // The attribute declaration will have one or more `AttributeTargetModifier`s + // that each specify a syntax class that the attribute can be applied to. + // If any of these match `attrTarget`, then we are good. + // + bool validTarget = false; + for(auto attrTargetMod : attrDecl->GetModifiersOfType()) + { + if(attrTarget->getClass().isSubClassOf(attrTargetMod->syntaxClass)) + { + validTarget = true; + break; + } + } + if(!validTarget) + { + getSink()->diagnose(attr, Diagnostics::attributeNotApplicable, attrName); + return uncheckedAttr; + } + + // Now apply type-specific validation to the attribute. + if(!validateAttribute(attr, attrDecl)) + { + return uncheckedAttr; + } + + + return attr; + } + + RefPtr checkModifier( + RefPtr m, + ModifiableSyntaxNode* syntaxNode) + { + if(auto hlslUncheckedAttribute = as(m)) + { + // We have an HLSL `[name(arg,...)]` attribute, and we'd like + // to check that it is provides all the expected arguments + // + // First, look up the attribute name in the current scope to find + // the right syntax class to instantiate. + // + + return checkAttribute(hlslUncheckedAttribute, syntaxNode); + } + // Default behavior is to leave things as they are, + // and assume that modifiers are mostly already checked. + // + // TODO: This would be a good place to validate that + // a modifier is actually valid for the thing it is + // being applied to, and potentially to check that + // it isn't in conflict with any other modifiers + // on the same declaration. + + return m; + } + + + void checkModifiers(ModifiableSyntaxNode* syntaxNode) + { + // TODO(tfoley): need to make sure this only + // performs semantic checks on a `SharedModifier` once... + + // The process of checking a modifier may produce a new modifier in its place, + // so we will build up a new linked list of modifiers that will replace + // the old list. + RefPtr resultModifiers; + RefPtr* resultModifierLink = &resultModifiers; + + RefPtr modifier = syntaxNode->modifiers.first; + while(modifier) + { + // Because we are rewriting the list in place, we need to extract + // the next modifier here (not at the end of the loop). + auto next = modifier->next; + + // We also go ahead and clobber the `next` field on the modifier + // itself, so that the default behavior of `checkModifier()` can + // be to return a single unlinked modifier. + modifier->next = nullptr; + + auto checkedModifier = checkModifier(modifier, syntaxNode); + if(checkedModifier) + { + // If checking gave us a modifier to add, then we + // had better add it. + + // Just in case `checkModifier` ever returns multiple + // modifiers, lets advance to the end of the list we + // are building. + while(*resultModifierLink) + resultModifierLink = &(*resultModifierLink)->next; + + // attach the new modifier at the end of the list, + // and now set the "link" to point to its `next` field + *resultModifierLink = checkedModifier; + resultModifierLink = &checkedModifier->next; + } + + // Move along to the next modifier + modifier = next; + } + + // Whether we actually re-wrote anything or note, lets + // install the new list of modifiers on the declaration + syntaxNode->modifiers.first = resultModifiers; + } + + /// Perform checking of interface conformaces for this decl and all its children + void checkInterfaceConformancesRec(Decl* decl) + { + // Any user-defined type may have declared interface conformances, + // which we should check. + // + if( auto aggTypeDecl = as(decl) ) + { + checkAggTypeConformance(aggTypeDecl); + } + // Conformances can also come via `extension` declarations, and + // we should check them against the type(s) being extended. + // + else if(auto extensionDecl = as(decl)) + { + checkExtensionConformance(extensionDecl); + } + + // We need to handle the recursive cases here, the first + // of which is a generic decl, where we want to recurivsely + // check the inner declaration. + // + if(auto genericDecl = as(decl)) + { + checkInterfaceConformancesRec(genericDecl->inner); + } + // For any other kind of container declaration, we will + // recurse into all of its member declarations, so that + // we can handle, e.g., nested `struct` types. + // + else if(auto containerDecl = as(decl)) + { + for(auto member : containerDecl->Members) + { + checkInterfaceConformancesRec(member); + } + } + } + + void visitModuleDecl(ModuleDecl* programNode) + { + // Try to register all the builtin decls + for (auto decl : programNode->Members) + { + auto inner = decl; + if (auto genericDecl = as(decl)) + { + inner = genericDecl->inner; + } + + if (auto builtinMod = inner->FindModifier()) + { + registerBuiltinDecl(getSession(), decl, builtinMod); + } + if (auto magicMod = inner->FindModifier()) + { + registerMagicDecl(getSession(), decl, magicMod); + } + } + + // We need/want to visit any `import` declarations before + // anything else, to make sure that scoping works. + for(auto& importDecl : programNode->getMembersOfType()) + { + checkDecl(importDecl); + } + // register all extensions + for (auto & s : programNode->getMembersOfType()) + registerExtension(s); + for (auto & g : programNode->getMembersOfType()) + { + if (auto extDecl = as(g->inner)) + { + checkGenericDeclHeader(g); + registerExtension(extDecl); + } + } + // check user defined attribute classes first + for (auto decl : programNode->Members) + { + if (auto typeMember = as(decl)) + { + bool isTypeAttributeClass = false; + for (auto attrib : typeMember->GetModifiersOfType()) + { + if (attrib->name == getSession()->getNameObj("AttributeUsageAttribute")) + { + isTypeAttributeClass = true; + break; + } + } + if (isTypeAttributeClass) + checkDecl(decl); + } + } + // check types + for (auto & s : programNode->getMembersOfType()) + checkDecl(s.Ptr()); + + for (int pass = 0; pass < 2; pass++) + { + checkingPhase = pass == 0 ? CheckingPhase::Header : CheckingPhase::Body; + + for (auto & s : programNode->getMembersOfType()) + { + checkDecl(s.Ptr()); + } + // HACK(tfoley): Visiting all generic declarations here, + // because otherwise they won't get visited. + for (auto & g : programNode->getMembersOfType()) + { + checkDecl(g.Ptr()); + } + + // before checking conformance, make sure we check all the extension bodies + // generic extension decls are already checked by the loop above + for (auto & s : programNode->getMembersOfType()) + checkDecl(s); + + for (auto & func : programNode->getMembersOfType()) + { + if (!func->IsChecked(getCheckedState())) + { + VisitFunctionDeclaration(func.Ptr()); + } + } + for (auto & func : programNode->getMembersOfType()) + { + checkDecl(func); + } + + if (getSink()->GetErrorCount() != 0) + return; + + // Force everything to be fully checked, just in case + // Note that we don't just call this on the program, + // because we'd end up recursing into this very code path... + for (auto d : programNode->Members) + { + EnusreAllDeclsRec(d); + } + + if (pass == 0) + { + checkInterfaceConformancesRec(programNode); + } + } + } + + bool doesSignatureMatchRequirement( + DeclRef satisfyingMemberDeclRef, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable) + { + if(satisfyingMemberDeclRef.getDecl()->HasModifier() + && !requiredMemberDeclRef.getDecl()->HasModifier()) + { + // A `[mutating]` method can't satisfy a non-`[mutating]` requirement, + // but vice-versa is okay. + return false; + } + + if(satisfyingMemberDeclRef.getDecl()->HasModifier() + != requiredMemberDeclRef.getDecl()->HasModifier()) + { + // A `static` method can't satisfy a non-`static` requirement and vice versa. + return false; + } + + // TODO: actually implement matching here. For now we'll + // just pretend that things are satisfied in order to make progress.. + witnessTable->requirementDictionary.Add( + requiredMemberDeclRef.getDecl(), + RequirementWitness(satisfyingMemberDeclRef)); + return true; + } + + bool doesGenericSignatureMatchRequirement( + DeclRef genDecl, + DeclRef requirementGenDecl, + RefPtr witnessTable) + { + if (genDecl.getDecl()->Members.getCount() != requirementGenDecl.getDecl()->Members.getCount()) + return false; + for (Index i = 0; i < genDecl.getDecl()->Members.getCount(); i++) + { + auto genMbr = genDecl.getDecl()->Members[i]; + auto requiredGenMbr = genDecl.getDecl()->Members[i]; + if (auto genTypeMbr = as(genMbr)) + { + if (auto requiredGenTypeMbr = as(requiredGenMbr)) + { + } + else + return false; + } + else if (auto genValMbr = as(genMbr)) + { + if (auto requiredGenValMbr = as(requiredGenMbr)) + { + if (!genValMbr->type->Equals(requiredGenValMbr->type)) + return false; + } + else + return false; + } + else if (auto genTypeConstraintMbr = as(genMbr)) + { + if (auto requiredTypeConstraintMbr = as(requiredGenMbr)) + { + if (!genTypeConstraintMbr->sup->Equals(requiredTypeConstraintMbr->sup)) + { + return false; + } + } + else + return false; + } + } + + // TODO: this isn't right, because we need to specialize the + // declarations of the generics to a common set of substitutions, + // so that their types are comparable (e.g., foo and foo + // need to have substitutions applies so that they are both foo, + // after which uses of the type X in their parameter lists can + // be compared). + + return doesMemberSatisfyRequirement( + DeclRef(genDecl.getDecl()->inner.Ptr(), genDecl.substitutions), + DeclRef(requirementGenDecl.getDecl()->inner.Ptr(), requirementGenDecl.substitutions), + witnessTable); + } + + bool doesTypeSatisfyAssociatedTypeRequirement( + RefPtr satisfyingType, + DeclRef requiredAssociatedTypeDeclRef, + RefPtr witnessTable) + { + // We need to confirm that the chosen type `satisfyingType`, + // meets all the constraints placed on the associated type + // requirement `requiredAssociatedTypeDeclRef`. + // + // We will enumerate the type constraints placed on the + // associated type and see if they can be satisfied. + // + bool conformance = true; + for (auto requiredConstraintDeclRef : getMembersOfType(requiredAssociatedTypeDeclRef)) + { + // Grab the type we expect to conform to from the constraint. + auto requiredSuperType = GetSup(requiredConstraintDeclRef); + + // Perform a search for a witness to the subtype relationship. + auto witness = tryGetSubtypeWitness(satisfyingType, requiredSuperType); + if(witness) + { + // If a subtype witness was found, then the conformance + // appears to hold, and we can satisfy that requirement. + witnessTable->requirementDictionary.Add(requiredConstraintDeclRef, RequirementWitness(witness)); + } + else + { + // If a witness couldn't be found, then the conformance + // seems like it will fail. + conformance = false; + } + } + + // TODO: if any conformance check failed, we should probably include + // that in an error message produced about not satisfying the requirement. + + if(conformance) + { + // If all the constraints were satisfied, then the chosen + // type can indeed satisfy the interface requirement. + witnessTable->requirementDictionary.Add( + requiredAssociatedTypeDeclRef.getDecl(), + RequirementWitness(satisfyingType)); + } + + return conformance; + } + + // Does the given `memberDecl` work as an implementation + // to satisfy the requirement `requiredMemberDeclRef` + // from an interface? + // + // If it does, then inserts a witness into `witnessTable` + // and returns `true`, otherwise returns `false` + bool doesMemberSatisfyRequirement( + DeclRef memberDeclRef, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable) + { + // At a high level, we want to check that the + // `memberDecl` and the `requiredMemberDeclRef` + // have the same AST node class, and then also + // check that their signatures match. + // + // There are a bunch of detailed decisions that + // have to be made, though, because we might, e.g., + // allow a function with more general parameter + // types to satisfy a requirement with more + // specific parameter types. + // + // If we ever allow for "property" declarations, + // then we would probably need to allow an + // ordinary field to satisfy a property requirement. + // + // An associated type requirement should be allowed + // to be satisfied by any type declaration: + // a typedef, a `struct`, etc. + // + if (auto memberFuncDecl = memberDeclRef.as()) + { + if (auto requiredFuncDeclRef = requiredMemberDeclRef.as()) + { + // Check signature match. + return doesSignatureMatchRequirement( + memberFuncDecl, + requiredFuncDeclRef, + witnessTable); + } + } + else if (auto memberInitDecl = memberDeclRef.as()) + { + if (auto requiredInitDecl = requiredMemberDeclRef.as()) + { + // Check signature match. + return doesSignatureMatchRequirement( + memberInitDecl, + requiredInitDecl, + witnessTable); + } + } + else if (auto genDecl = memberDeclRef.as()) + { + // For a generic member, we will check if it can satisfy + // a generic requirement in the interface. + // + // TODO: we could also conceivably check that the generic + // could be *specialized* to satisfy the requirement, + // and then install a specialization of the generic into + // the witness table. Actually doing this would seem + // to require performing something akin to overload + // resolution as part of requirement satisfaction. + // + if (auto requiredGenDeclRef = requiredMemberDeclRef.as()) + { + return doesGenericSignatureMatchRequirement(genDecl, requiredGenDeclRef, witnessTable); + } + } + else if (auto subAggTypeDeclRef = memberDeclRef.as()) + { + if(auto requiredTypeDeclRef = requiredMemberDeclRef.as()) + { + checkDecl(subAggTypeDeclRef.getDecl()); + + auto satisfyingType = DeclRefType::Create(getSession(), subAggTypeDeclRef); + return doesTypeSatisfyAssociatedTypeRequirement(satisfyingType, requiredTypeDeclRef, witnessTable); + } + } + else if (auto typedefDeclRef = memberDeclRef.as()) + { + // this is a type-def decl in an aggregate type + // check if the specified type satisfies the constraints defined by the associated type + if (auto requiredTypeDeclRef = requiredMemberDeclRef.as()) + { + checkDecl(typedefDeclRef.getDecl()); + + auto satisfyingType = getNamedType(getSession(), typedefDeclRef); + return doesTypeSatisfyAssociatedTypeRequirement(satisfyingType, requiredTypeDeclRef, witnessTable); + } + } + // Default: just assume that thing aren't being satisfied. + return false; + } + + // State used while checking if a declaration (either a type declaration + // or an extension of that type) conforms to the interfaces it claims + // via its inheritance clauses. + // + struct ConformanceCheckingContext + { + Dictionary, RefPtr> mapInterfaceToWitnessTable; + }; + + // Find the appropriate member of a declared type to + // satisfy a requirement of an interface the type + // claims to conform to. + // + // The type declaration `typeDecl` has declared that it + // conforms to the interface `interfaceDeclRef`, and + // `requiredMemberDeclRef` is a required member of + // the interface. + // + // If a satisfying value is found, registers it in + // `witnessTable` and returns `true`, otherwise + // returns `false`. + // + bool findWitnessForInterfaceRequirement( + ConformanceCheckingContext* context, + DeclRef typeDeclRef, + InheritanceDecl* inheritanceDecl, + DeclRef interfaceDeclRef, + DeclRef requiredMemberDeclRef, + RefPtr witnessTable) + { + // The goal of this function is to find a suitable + // value to satisfy the requirement. + // + // The 99% case is that the requirement is a named member + // of the interface, and we need to search for a member + // with the same name in the type declaration and + // its (known) extensions. + + // As a first pass, lets check if we already have a + // witness in the table for the requirement, so + // that we can bail out early. + // + if(witnessTable->requirementDictionary.ContainsKey(requiredMemberDeclRef.getDecl())) + { + return true; + } + + + // An important exception to the above is that an + // inheritance declaration in the interface is not going + // to be satisfied by an inheritance declaration in the + // conforming type, but rather by a full "witness table" + // full of the satisfying values for each requirement + // in the inherited-from interface. + // + if( auto requiredInheritanceDeclRef = requiredMemberDeclRef.as() ) + { + // Recursively check that the type conforms + // to the inherited interface. + // + // TODO: we *really* need a linearization step here!!!! + + RefPtr satisfyingWitnessTable = checkConformanceToType( + context, + typeDeclRef, + requiredInheritanceDeclRef.getDecl(), + getBaseType(requiredInheritanceDeclRef)); + + if(!satisfyingWitnessTable) + return false; + + witnessTable->requirementDictionary.Add( + requiredInheritanceDeclRef.getDecl(), + RequirementWitness(satisfyingWitnessTable)); + return true; + } + + // We will look up members with the same name, + // since only same-name members will be able to + // satisfy the requirement. + // + // TODO: this won't work right now for members that + // don't have names, which right now includes + // initializers/constructors. + Name* name = requiredMemberDeclRef.GetName(); + + // We are basically looking up members of the + // given type, but we need to be a bit careful. + // We *cannot* perfom lookup "through" inheritance + // declarations for this or other interfaces, + // since that would let us satisfy a requirement + // with itself. + // + // There's also an interesting question of whether + // we can/should support innterface requirements + // being satisfied via `__transparent` members. + // This seems like a "clever" idea rather than + // a useful one, and IR generation would + // need to construct real IR to trampoline over + // to the implementation. + // + // The final case that can't be reduced to just + // "a directly declared member with the same name" + // is the case where the type inherits a member + // that can satisfy the requirement from a base type. + // We are ignoring implementation inheritance for + // now, so we won't worry about this. + + // Make sure that by-name lookup is possible. + buildMemberDictionary(typeDeclRef.getDecl()); + auto lookupResult = lookUpLocal(getSession(), this, name, typeDeclRef); + + if (!lookupResult.isValid()) + { + getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, typeDeclRef, requiredMemberDeclRef); + return false; + } + + // Iterate over the members and look for one that matches + // the expected signature for the requirement. + for (auto member : lookupResult) + { + if (doesMemberSatisfyRequirement(member.declRef, requiredMemberDeclRef, witnessTable)) + return true; + } + + // No suitable member found, although there were candidates. + // + // TODO: Eventually we might want something akin to the current + // overload resolution logic, where we keep track of a list + // of "candidates" for satisfaction of the requirement, + // and if nothing is found we print the candidates + + getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, typeDeclRef, requiredMemberDeclRef); + return false; + } + + // Check that the type declaration `typeDecl`, which + // declares conformance to the interface `interfaceDeclRef`, + // (via the given `inheritanceDecl`) actually provides + // members to satisfy all the requirements in the interface. + RefPtr checkInterfaceConformance( + ConformanceCheckingContext* context, + DeclRef typeDeclRef, + InheritanceDecl* inheritanceDecl, + DeclRef interfaceDeclRef) + { + // Has somebody already checked this conformance, + // and/or is in the middle of checking it? + RefPtr witnessTable; + if(context->mapInterfaceToWitnessTable.TryGetValue(interfaceDeclRef, witnessTable)) + return witnessTable; + + // We need to check the declaration of the interface + // before we can check that we conform to it. + checkDecl(interfaceDeclRef.getDecl()); + + // We will construct the witness table, and register it + // *before* we go about checking fine-grained requirements, + // in order to short-circuit any potential for infinite recursion. + + // Note: we will re-use the witnes table attached to the inheritance decl, + // if there is one. This catches cases where semantic checking might + // have synthesized some of the conformance witnesses for us. + // + witnessTable = inheritanceDecl->witnessTable; + if(!witnessTable) + { + witnessTable = new WitnessTable(); + } + context->mapInterfaceToWitnessTable.Add(interfaceDeclRef, witnessTable); + + bool result = true; + + // TODO: If we ever allow for implementation inheritance, + // then we will need to consider the case where a type + // declares that it conforms to an interface, but one of + // its (non-interface) base types already conforms to + // that interface, so that all of the requirements are + // already satisfied with inherited implementations... + for(auto requiredMemberDeclRef : getMembers(interfaceDeclRef)) + { + auto requirementSatisfied = findWitnessForInterfaceRequirement( + context, + typeDeclRef, + inheritanceDecl, + interfaceDeclRef, + requiredMemberDeclRef, + witnessTable); + + result = result && requirementSatisfied; + } + + // Extensions that apply to the interface type can create new conformances + // for the concrete types that inherit from the interface. + // + // These new conformances should not be able to introduce new *requirements* + // for an implementing interface (although they currently can), but we + // still need to go through this logic to find the appropriate value + // that will satisfy the requirement in these cases, and also to put + // the required entry into the witness table for the interface itself. + // + // TODO: This logic is a bit slippery, and we need to figure out what + // it means in the context of separate compilation. If module A defines + // an interface IA, module B defines a type C that conforms to IA, and then + // module C defines an extension that makes IA conform to IC, then it is + // unreasonable to expect the {B:IA} witness table to contain an entry + // corresponding to {IA:IC}. + // + // The simple answer then would be that the {IA:IC} conformance should be + // fixed, with a single witness table for {IA:IC}, but then what should + // happen in B explicitly conformed to IC already? + // + // For now we will just walk through the extensions that are known at + // the time we are compiling and handle those, and punt on the larger issue + // for abit longer. + for(auto candidateExt = interfaceDeclRef.getDecl()->candidateExtensions; candidateExt; candidateExt = candidateExt->nextCandidateExtension) + { + // We need to apply the extension to the interface type that our + // concrete type is inheriting from. + // + // TODO: need to decide if a this-type substitution is needed here. + // It probably it. + RefPtr targetType = DeclRefType::Create( + getSession(), + interfaceDeclRef); + auto extDeclRef = ApplyExtensionToType(candidateExt, targetType); + if(!extDeclRef) + continue; + + // Only inheritance clauses from the extension matter right now. + for(auto requiredInheritanceDeclRef : getMembersOfType(extDeclRef)) + { + auto requirementSatisfied = findWitnessForInterfaceRequirement( + context, + typeDeclRef, + inheritanceDecl, + interfaceDeclRef, + requiredInheritanceDeclRef, + witnessTable); + + result = result && requirementSatisfied; + } + } + + // If we failed to satisfy any requirements along the way, + // then we don't actually want to keep the witness table + // we've been constructing, because the whole thing was a failure. + if(!result) + { + return nullptr; + } + + return witnessTable; + } + + RefPtr checkConformanceToType( + ConformanceCheckingContext* context, + DeclRef typeDeclRef, + InheritanceDecl* inheritanceDecl, + Type* baseType) + { + if (auto baseDeclRefType = as(baseType)) + { + auto baseTypeDeclRef = baseDeclRefType->declRef; + if (auto baseInterfaceDeclRef = baseTypeDeclRef.as()) + { + // The type is stating that it conforms to an interface. + // We need to check that it provides all of the members + // required by that interface. + return checkInterfaceConformance( + context, + typeDeclRef, + inheritanceDecl, + baseInterfaceDeclRef); + } + } + + getSink()->diagnose(inheritanceDecl, Diagnostics::unimplemented, "type not supported for inheritance"); + return nullptr; + } + + // Check that the type (or extension) declaration `declRef`, + // which declares that it inherits from another type via + // `inheritanceDecl` actually does what it needs to + // for that inheritance to be valid. + bool checkConformance( + DeclRef declRef, + InheritanceDecl* inheritanceDecl) + { + declRef = createDefaultSubstitutionsIfNeeded(getSession(), declRef).as(); + + // Don't check conformances for abstract types that + // are being used to express *required* conformances. + if (auto assocTypeDeclRef = declRef.as()) + { + // An associated type declaration represents a requirement + // in an outer interface declaration, and its members + // (type constraints) represent additional requirements. + return true; + } + else if (auto interfaceDeclRef = declRef.as()) + { + // HACK: Our semantics as they stand today are that an + // `extension` of an interface that adds a new inheritance + // clause acts *as if* that inheritnace clause had been + // attached to the original `interface` decl: that is, + // it adds additional requirements. + // + // This is *not* a reasonable semantic to keep long-term, + // but it is required for some of our current example + // code to work. + return true; + } + + + // Look at the type being inherited from, and validate + // appropriately. + auto baseType = inheritanceDecl->base.type; + + ConformanceCheckingContext context; + RefPtr witnessTable = checkConformanceToType(&context, declRef, inheritanceDecl, baseType); + if(!witnessTable) + return false; + + inheritanceDecl->witnessTable = witnessTable; + return true; + } + + void checkExtensionConformance(ExtensionDecl* decl) + { + if (auto targetDeclRefType = as(decl->targetType)) + { + if (auto aggTypeDeclRef = targetDeclRefType->declRef.as()) + { + for (auto inheritanceDecl : decl->getMembersOfType()) + { + checkConformance(aggTypeDeclRef, inheritanceDecl); + } + } + } + } + + void checkAggTypeConformance(AggTypeDecl* decl) + { + // After we've checked members, we need to go through + // any inheritance clauses on the type itself, and + // confirm that the type actually provides whatever + // those clauses require. + + if (auto interfaceDecl = as(decl)) + { + // Don't check that an interface conforms to the + // things it inherits from. + } + else if (auto assocTypeDecl = as(decl)) + { + // Don't check that an associated type decl conforms to the + // things it inherits from. + } + else + { + // For non-interface types we need to check conformance. + // + // TODO: Need to figure out what this should do for + // `abstract` types if we ever add them. Should they + // be required to implement all interface requirements, + // just with `abstract` methods that replicate things? + // (That's what C# does). + for (auto inheritanceDecl : decl->getMembersOfType()) + { + checkConformance(makeDeclRef(decl), inheritanceDecl); + } + } + } + + void visitAggTypeDecl(AggTypeDecl* decl) + { + if (decl->IsChecked(getCheckedState())) + return; + + // TODO: we should check inheritance declarations + // first, since they need to be validated before + // we can make use of the type (e.g., you need + // to know that `A` inherits from `B` in order + // to check an expression like `aValue.bMethod()` + // where `aValue` is of type `A` but `bMethod` + // is defined in type `B`. + // + // TODO: We should also add a pass that takes + // all the stated inheritance relationships, + // expands them to include implicit inheritance, + // and then linearizes them. This would allow + // later passes that need to know everything + // a type inherits from to proceed linearly + // through the list, rather than having to + // recurse (and potentially see the same interface + // more than once). + + decl->SetCheckState(DeclCheckState::CheckedHeader); + + // Now check all of the member declarations. + for (auto member : decl->Members) + { + checkDecl(member); + } + decl->SetCheckState(getCheckedState()); + } + + bool isIntegerBaseType(BaseType baseType) + { + switch(baseType) + { + default: + return false; + + case BaseType::Int8: + case BaseType::Int16: + case BaseType::Int: + case BaseType::Int64: + case BaseType::UInt8: + case BaseType::UInt16: + case BaseType::UInt: + case BaseType::UInt64: + return true; + } + } + + // Validate that `type` is a suitable type to use + // as the tag type for an `enum` + void validateEnumTagType(Type* type, SourceLoc const& loc) + { + if(auto basicType = as(type)) + { + // Allow the built-in integer types. + if(isIntegerBaseType(basicType->baseType)) + return; + + // By default, don't allow other types to be used + // as an `enum` tag type. + } + + getSink()->diagnose(loc, Diagnostics::invalidEnumTagType, type); + } + + void visitEnumDecl(EnumDecl* decl) + { + if (decl->IsChecked(getCheckedState())) + return; + + // We need to be careful to avoid recursion in the + // type-checking logic. We will do the minimal work + // to make the type usable in the first phase, and + // then check the actual cases in the second phase. + // + if(this->checkingPhase == CheckingPhase::Header) + { + // Look at inheritance clauses, and + // see if one of them is making the enum + // "inherit" from a concrete type. + // This will become the "tag" type + // of the enum. + RefPtr tagType; + InheritanceDecl* tagTypeInheritanceDecl = nullptr; + for(auto inheritanceDecl : decl->getMembersOfType()) + { + checkDecl(inheritanceDecl); + + // Look at the type being inherited from. + auto superType = inheritanceDecl->base.type; + + if(auto errorType = as(superType)) + { + // Ignore any erroneous inheritance clauses. + continue; + } + else if(auto declRefType = as(superType)) + { + if(auto interfaceDeclRef = declRefType->declRef.as()) + { + // Don't consider interface bases as candidates for + // the tag type. + continue; + } + } + + if(tagType) + { + // We already found a tag type. + getSink()->diagnose(inheritanceDecl, Diagnostics::enumTypeAlreadyHasTagType); + getSink()->diagnose(tagTypeInheritanceDecl, Diagnostics::seePreviousTagType); + break; + } + else + { + tagType = superType; + tagTypeInheritanceDecl = inheritanceDecl; + } + } + + // If a tag type has not been set, then we + // default it to the built-in `int` type. + // + // TODO: In the far-flung future we may want to distinguish + // `enum` types that have a "raw representation" like this from + // ones that are purely abstract and don't expose their + // type of their tag. + if(!tagType) + { + tagType = getSession()->getIntType(); + } + else + { + // TODO: Need to establish that the tag + // type is suitable. (e.g., if we are going + // to allow raw values for case tags to be + // derived automatically, then the tag + // type needs to be some kind of integer type...) + // + // For now we will just be harsh and require it + // to be one of a few builtin types. + validateEnumTagType(tagType, tagTypeInheritanceDecl->loc); + } + decl->tagType = tagType; + + + // An `enum` type should automatically conform to the `__EnumType` interface. + // The compiler needs to insert this conformance behind the scenes, and this + // seems like the best place to do it. + { + // First, look up the type of the `__EnumType` interface. + RefPtr enumTypeType = getSession()->getEnumTypeType(); + + RefPtr enumConformanceDecl = new InheritanceDecl(); + enumConformanceDecl->ParentDecl = decl; + enumConformanceDecl->loc = decl->loc; + enumConformanceDecl->base.type = getSession()->getEnumTypeType(); + decl->Members.add(enumConformanceDecl); + + // The `__EnumType` interface has one required member, the `__Tag` type. + // We need to satisfy this requirement automatically, rather than require + // the user to actually declare a member with this name (otherwise we wouldn't + // let them define a tag value with the name `__Tag`). + // + RefPtr witnessTable = new WitnessTable(); + enumConformanceDecl->witnessTable = witnessTable; + + Name* tagAssociatedTypeName = getSession()->getNameObj("__Tag"); + Decl* tagAssociatedTypeDecl = nullptr; + if(auto enumTypeTypeDeclRefType = enumTypeType.dynamicCast()) + { + if(auto enumTypeTypeInterfaceDecl = as(enumTypeTypeDeclRefType->declRef.getDecl())) + { + for(auto memberDecl : enumTypeTypeInterfaceDecl->Members) + { + if(memberDecl->getName() == tagAssociatedTypeName) + { + tagAssociatedTypeDecl = memberDecl; + break; + } + } + } + } + if(!tagAssociatedTypeDecl) + { + SLANG_DIAGNOSE_UNEXPECTED(getSink(), decl, "failed to find built-in declaration '__Tag'"); + } + + // Okay, add the conformance witness for `__Tag` being satisfied by `tagType` + witnessTable->requirementDictionary.Add(tagAssociatedTypeDecl, RequirementWitness(tagType)); + + // TODO: we actually also need to synthesize a witness for the conformance of `tagType` + // to the `__BuiltinIntegerType` interface, because that is a constraint on the + // associated type `__Tag`. + + // TODO: eventually we should consider synthesizing other requirements for + // the min/max tag values, or the total number of tags, so that people don't + // have to declare these as additional cases. + + enumConformanceDecl->SetCheckState(DeclCheckState::Checked); + } + } + else if( checkingPhase == CheckingPhase::Body ) + { + auto enumType = DeclRefType::Create( + getSession(), + makeDeclRef(decl)); + + auto tagType = decl->tagType; + + // Check the enum cases in order. + for(auto caseDecl : decl->getMembersOfType()) + { + // Each case defines a value of the enum's type. + // + // TODO: If we ever support enum cases with payloads, + // then they would probably have a type that is a + // `FunctionType` from the payload types to the + // enum type. + // + caseDecl->type.type = enumType; + + checkDecl(caseDecl); + } + + // For any enum case that didn't provide an explicit + // tag value, derived an appropriate tag value. + IntegerLiteralValue defaultTag = 0; + for(auto caseDecl : decl->getMembersOfType()) + { + if(auto explicitTagValExpr = caseDecl->tagExpr) + { + // This tag has an initializer, so it should establish + // the tag value for a successor case that doesn't + // provide an explicit tag. + + RefPtr explicitTagVal = TryConstantFoldExpr(explicitTagValExpr); + if(explicitTagVal) + { + if(auto constIntVal = as(explicitTagVal)) + { + defaultTag = constIntVal->value; + } + else + { + // TODO: need to handle other possibilities here + getSink()->diagnose(explicitTagValExpr, Diagnostics::unexpectedEnumTagExpr); + } + } + else + { + // If this happens, then the explicit tag value expression + // doesn't seem to be a constant after all. In this case + // we expect the checking logic to have applied already. + } + } + else + { + // This tag has no initializer, so it should use + // the default tag value we are tracking. + RefPtr tagValExpr = new IntegerLiteralExpr(); + tagValExpr->loc = caseDecl->loc; + tagValExpr->type = QualType(tagType); + tagValExpr->value = defaultTag; + + caseDecl->tagExpr = tagValExpr; + } + + // Default tag for the next case will be one more than + // for the most recent case. + // + // TODO: We might consider adding a `[flags]` attribute + // that modifies this behavior to be `defaultTagForCase <<= 1`. + // + defaultTag++; + } + + // Now check any other member declarations. + for(auto memberDecl : decl->Members) + { + // Already checked inheritance declarations above. + if(auto inheritanceDecl = as(memberDecl)) + continue; + + // Already checked enum case declarations above. + if(auto caseDecl = as(memberDecl)) + continue; + + // TODO: Right now we don't support other kinds of + // member declarations on an `enum`, but that is + // something we may want to allow in the long run. + // + checkDecl(memberDecl); + } + } + decl->SetCheckState(getCheckedState()); + } + + void visitEnumCaseDecl(EnumCaseDecl* decl) + { + if (decl->IsChecked(getCheckedState())) + return; + + if(checkingPhase == CheckingPhase::Body) + { + // An enum case had better appear inside an enum! + // + // TODO: Do we need/want to support generic cases some day? + auto parentEnumDecl = as(decl->ParentDecl); + SLANG_ASSERT(parentEnumDecl); + + // The tag type should have already been set by + // the surrounding `enum` declaration. + auto tagType = parentEnumDecl->tagType; + SLANG_ASSERT(tagType); + + // Need to check the init expression, if present, since + // that represents the explicit tag for this case. + if(auto initExpr = decl->tagExpr) + { + initExpr = CheckExpr(initExpr); + initExpr = coerce(tagType, initExpr); + + // We want to enforce that this is an integer constant + // expression, but we don't actually care to retain + // the value. + CheckIntegerConstantExpression(initExpr); + + decl->tagExpr = initExpr; + } + } + + decl->SetCheckState(getCheckedState()); + } + + void visitDeclGroup(DeclGroup* declGroup) + { + for (auto decl : declGroup->decls) + { + dispatchDecl(decl); + } + } + + void visitTypeDefDecl(TypeDefDecl* decl) + { + if (decl->IsChecked(getCheckedState())) return; + if (checkingPhase == CheckingPhase::Header) + { + decl->type = CheckProperType(decl->type); + } + decl->SetCheckState(getCheckedState()); + } + + void visitGlobalGenericParamDecl(GlobalGenericParamDecl* decl) + { + if (decl->IsChecked(getCheckedState())) return; + if (checkingPhase == CheckingPhase::Header) + { + decl->SetCheckState(DeclCheckState::CheckedHeader); + // global generic param only allowed in global scope + auto program = as(decl->ParentDecl); + if (!program) + getSink()->diagnose(decl, Slang::Diagnostics::globalGenParamInGlobalScopeOnly); + // Now check all of the member declarations. + for (auto member : decl->Members) + { + checkDecl(member); + } + } + decl->SetCheckState(getCheckedState()); + } + + void visitAssocTypeDecl(AssocTypeDecl* decl) + { + if (decl->IsChecked(getCheckedState())) return; + if (checkingPhase == CheckingPhase::Header) + { + decl->SetCheckState(DeclCheckState::CheckedHeader); + + // assoctype only allowed in an interface + auto interfaceDecl = as(decl->ParentDecl); + if (!interfaceDecl) + getSink()->diagnose(decl, Slang::Diagnostics::assocTypeInInterfaceOnly); + + // Now check all of the member declarations. + for (auto member : decl->Members) + { + checkDecl(member); + } + } + decl->SetCheckState(getCheckedState()); + } + + void checkStmt(Stmt* stmt) + { + if (!stmt) return; + dispatchStmt(stmt); + checkModifiers(stmt); + } + + void visitFuncDecl(FuncDecl* functionNode) + { + if (functionNode->IsChecked(getCheckedState())) + return; + + if (checkingPhase == CheckingPhase::Header) + { + VisitFunctionDeclaration(functionNode); + } + // TODO: This should really only set "checked header" + functionNode->SetCheckState(getCheckedState()); + + if (checkingPhase == CheckingPhase::Body) + { + // TODO: should put the checking of the body onto a "work list" + // to avoid recursion here. + if (functionNode->Body) + { + auto oldFunc = function; + this->function = functionNode; + checkStmt(functionNode->Body); + this->function = oldFunc; + } + } + } + + void getGenericParams( + GenericDecl* decl, + List& outParams, + List outConstraints) + { + for (auto dd : decl->Members) + { + if (dd == decl->inner) + continue; + + if (auto typeParamDecl = as(dd)) + outParams.add(typeParamDecl); + else if (auto valueParamDecl = as(dd)) + outParams.add(valueParamDecl); + else if (auto constraintDecl = as(dd)) + outConstraints.add(constraintDecl); + } + } + + bool doGenericSignaturesMatch( + GenericDecl* fst, + GenericDecl* snd) + { + // First we'll extract the parameters and constraints + // in each generic signature. We will consider parameters + // and constraints separately so that we are independent + // of the order in which constraints are given (that is, + // a constraint like `` would be considered + // the same as `` with a later `where T : IFoo`. + + List fstParams; + List fstConstraints; + getGenericParams(fst, fstParams, fstConstraints); + + List sndParams; + List sndConstraints; + getGenericParams(snd, sndParams, sndConstraints); + + // For there to be any hope of a match, the + // two need to have the same number of parameters. + Index paramCount = fstParams.getCount(); + if (paramCount != sndParams.getCount()) + return false; + + // Now we'll walk through the parameters. + for (Index pp = 0; pp < paramCount; ++pp) + { + Decl* fstParam = fstParams[pp]; + Decl* sndParam = sndParams[pp]; + + if (auto fstTypeParam = as(fstParam)) + { + if (auto sndTypeParam = as(sndParam)) + { + // TODO: is there any validation that needs to be performed here? + } + else + { + // Type and non-type parameters can't match. + return false; + } + } + else if (auto fstValueParam = as(fstParam)) + { + if (auto sndValueParam = as(sndParam)) + { + // Need to check that the parameters have the same type. + // + // Note: We are assuming here that the type of a value + // parameter cannot be dependent on any of the type + // parameters in the same signature. This is a reasonable + // assumption for now, but could get thorny down the road. + if (!fstValueParam->getType()->Equals(sndValueParam->getType())) + { + // Type mismatch. + return false; + } + + // TODO: This is not the right place to check on default + // values for the parameter, because they won't affect + // the signature, but we should make sure to do validation + // later on (e.g., that only one declaration can/should + // be allowed to provide a default). + } + else + { + // Value and non-value parameters can't match. + return false; + } + } + } + + // If we got this far, then it means the parameter signatures *seem* + // to match up all right, but now we need to check that the constraints + // placed on those parameters are also consistent. + // + // For now I'm going to assume/require that all declarations must + // declare the signature in a way that matches exactly. + Index constraintCount = fstConstraints.getCount(); + if(constraintCount != sndConstraints.getCount()) + return false; + + for (Index cc = 0; cc < constraintCount; ++cc) + { + //auto fstConstraint = fstConstraints[cc]; + //auto sndConstraint = sndConstraints[cc]; + + // TODO: the challenge here is that the + // constraints are going to be expressed + // in terms of the parameters, which means + // we need to be doing substitution here. + } + + // HACK: okay, we'll just assume things match for now. + return true; + } + + // Check if two functions have the same signature for the purposes + // of overload resolution. + bool doFunctionSignaturesMatch( + DeclRef fst, + DeclRef snd) + { + + // TODO(tfoley): This copies the parameter array, which is bad for performance. + auto fstParams = GetParameters(fst).ToArray(); + auto sndParams = GetParameters(snd).ToArray(); + + // If the functions have different numbers of parameters, then + // their signatures trivially don't match. + auto fstParamCount = fstParams.getCount(); + auto sndParamCount = sndParams.getCount(); + if (fstParamCount != sndParamCount) + return false; + + for (Index ii = 0; ii < fstParamCount; ++ii) + { + auto fstParam = fstParams[ii]; + auto sndParam = sndParams[ii]; + + // If a given parameter type doesn't match, then signatures don't match + if (!GetType(fstParam)->Equals(GetType(sndParam))) + return false; + + // If one parameter is `out` and the other isn't, then they don't match + // + // Note(tfoley): we don't consider `out` and `inout` as distinct here, + // because there is no way for overload resolution to pick between them. + if (fstParam.getDecl()->HasModifier() != sndParam.getDecl()->HasModifier()) + return false; + + // If one parameter is `ref` and the other isn't, then they don't match. + // + if(fstParam.getDecl()->HasModifier() != sndParam.getDecl()->HasModifier()) + return false; + } + + // Note(tfoley): return type doesn't enter into it, because we can't take + // calling context into account during overload resolution. + + return true; + } + + RefPtr createDummySubstitutions( + GenericDecl* genericDecl) + { + RefPtr subst = new GenericSubstitution(); + subst->genericDecl = genericDecl; + for (auto dd : genericDecl->Members) + { + if (dd == genericDecl->inner) + continue; + + if (auto typeParam = as(dd)) + { + auto type = DeclRefType::Create(getSession(), + makeDeclRef(typeParam)); + subst->args.add(type); + } + else if (auto valueParam = as(dd)) + { + auto val = new GenericParamIntVal( + makeDeclRef(valueParam)); + subst->args.add(val); + } + // TODO: need to handle constraints here? + } + return subst; + } + + void ValidateFunctionRedeclaration(FuncDecl* funcDecl) + { + auto parentDecl = funcDecl->ParentDecl; + SLANG_ASSERT(parentDecl); + if (!parentDecl) return; + + Decl* childDecl = funcDecl; + + // If this is a generic function (that is, its parent + // declaration is a generic), then we need to look + // for sibling declarations of the parent. + auto genericDecl = as(parentDecl); + if (genericDecl) + { + parentDecl = genericDecl->ParentDecl; + childDecl = genericDecl; + } + + // Look at previously-declared functions with the same name, + // in the same container + // + // Note: there is an assumption here that declarations that + // occur earlier in the program text will be *later* in + // the linked list of declarations with the same name. + // We are also assuming/requiring that the check here is + // symmetric, in that it is okay to test (A,B) or (B,A), + // and there is no need to test both. + // + buildMemberDictionary(parentDecl); + for (auto pp = childDecl->nextInContainerWithSameName; pp; pp = pp->nextInContainerWithSameName) + { + auto prevDecl = pp; + + // Look through generics to the declaration underneath + auto prevGenericDecl = as(prevDecl); + if (prevGenericDecl) + prevDecl = prevGenericDecl->inner.Ptr(); + + // We only care about previously-declared functions + // Note(tfoley): although we should really error out if the + // name is already in use for something else, like a variable... + auto prevFuncDecl = as(prevDecl); + if (!prevFuncDecl) + continue; + + // If one declaration is a prefix/postfix operator, and the + // other is not a matching operator, then don't consider these + // to be re-declarations. + // + // Note(tfoley): Any attempt to call such an operator using + // ordinary function-call syntax (if we decided to allow it) + // would be ambiguous in such a case, of course. + // + if (funcDecl->HasModifier() != prevDecl->HasModifier()) + continue; + if (funcDecl->HasModifier() != prevDecl->HasModifier()) + continue; + + // If one is generic and the other isn't, then there is no match. + if ((genericDecl != nullptr) != (prevGenericDecl != nullptr)) + continue; + + // We are going to be comparing the signatures of the + // two functions, but if they are *generic* functions + // then we will need to compare them with consistent + // specializations in place. + // + // We'll go ahead and create some (unspecialized) declaration + // references here, just to be prepared. + DeclRef funcDeclRef(funcDecl, nullptr); + DeclRef prevFuncDeclRef(prevFuncDecl, nullptr); + + // If we are working with generic functions, then we need to + // consider if their generic signatures match. + if (genericDecl) + { + SLANG_ASSERT(prevGenericDecl); // already checked above + if (!doGenericSignaturesMatch(genericDecl, prevGenericDecl)) + continue; + + // Now we need specialize the declaration references + // consistently, so that we can compare. + // + // First we create a "dummy" set of substitutions that + // just reference the parameters of the first generic. + auto subst = createDummySubstitutions(genericDecl); + // + // Then we use those parameters to specialize the *other* + // generic. + // + subst->genericDecl = prevGenericDecl; + prevFuncDeclRef.substitutions.substitutions = subst; + // + // One way to think about it is that if we have these + // declarations (ignore the name differences...): + // + // // prevFuncDecl: + // void foo1(T x); + // + // // funcDecl: + // void foo2(U x); + // + // Then we will compare `foo2` against `foo1`. + } + + // If the parameter signatures don't match, then don't worry + if (!doFunctionSignaturesMatch(funcDeclRef, prevFuncDeclRef)) + continue; + + // If we get this far, then we've got two declarations in the same + // scope, with the same name and signature, so they appear + // to be redeclarations. + // + // We will track that redeclaration occured, so that we can + // take it into account for overload resolution. + // + // A huge complication that we'll need to deal with is that + // multiple declarations might introduce default values for + // (different) parameters, and we might need to merge across + // all of them (which could get complicated if defaults for + // parameters can reference earlier parameters). + + // If the previous declaration wasn't already recorded + // as being part of a redeclaration family, then make + // it the primary declaration of a new family. + if (!prevFuncDecl->primaryDecl) + { + prevFuncDecl->primaryDecl = prevFuncDecl; + } + + // The new declaration will belong to the family of + // the previous one, and so it will share the same + // primary declaration. + funcDecl->primaryDecl = prevFuncDecl->primaryDecl; + funcDecl->nextDecl = nullptr; + + // Next we want to chain the new declaration onto + // the linked list of redeclarations. + auto link = &prevFuncDecl->nextDecl; + while (*link) + link = &(*link)->nextDecl; + *link = funcDecl; + + // Now that we've added things to a group of redeclarations, + // we can do some additional validation. + + // First, we will ensure that the return types match + // between the declarations, so that they are truly + // interchangeable. + // + // Note(tfoley): If we ever decide to add a beefier type + // system to Slang, we might allow overloads like this, + // so long as the desired result type can be disambiguated + // based on context at the call type. In that case we would + // consider result types earlier, as part of the signature + // matching step. + // + auto resultType = GetResultType(funcDeclRef); + auto prevResultType = GetResultType(prevFuncDeclRef); + if (!resultType->Equals(prevResultType)) + { + // Bad redeclaration + getSink()->diagnose(funcDecl, Diagnostics::functionRedeclarationWithDifferentReturnType, funcDecl->getName(), resultType, prevResultType); + getSink()->diagnose(prevFuncDecl, Diagnostics::seePreviousDeclarationOf, funcDecl->getName()); + + // Don't bother emitting other errors at this point + break; + } + + // Note(tfoley): several of the following checks should + // really be looping over all the previous declarations + // in the same group, and not just the one previous + // declaration we found just now. + + // TODO: Enforce that the new declaration had better + // not specify a default value for any parameter that + // already had a default value in a prior declaration. + + // We are going to want to enforce that we cannot have + // two declarations of a function both specify bodies. + // Before we make that check, however, we need to deal + // with the case where the two function declarations + // might represent different target-specific versions + // of a function. + // + // TODO: if the two declarations are specialized for + // different targets, then skip the body checks below. + + // If both of the declarations have a body, then there + // is trouble, because we wouldn't know which one to + // use during code generation. + if (funcDecl->Body && prevFuncDecl->Body) + { + // Redefinition + getSink()->diagnose(funcDecl, Diagnostics::functionRedefinition, funcDecl->getName()); + getSink()->diagnose(prevFuncDecl, Diagnostics::seePreviousDefinitionOf, funcDecl->getName()); + + // Don't bother emitting other errors + break; + } + + // At this point we've processed the redeclaration and + // put it into a group, so there is no reason to keep + // looping and looking at prior declarations. + return; + } + } + + void visitScopeDecl(ScopeDecl*) + { + // Nothing to do + } + + void visitParamDecl(ParamDecl* paramDecl) + { + // TODO: This logic should be shared with the other cases of + // variable declarations. The main reason I am not doing it + // yet is that we use a `ParamDecl` with a null type as a + // special case in attribute declarations, and that could + // trip up the ordinary variable checks. + + auto typeExpr = paramDecl->type; + if(typeExpr.exp) + { + typeExpr = CheckUsableType(typeExpr); + paramDecl->type = typeExpr; + } + + paramDecl->SetCheckState(DeclCheckState::CheckedHeader); + + // The "initializer" expression for a parameter represents + // a default argument value to use if an explicit one is + // not supplied. + if(auto initExpr = paramDecl->initExpr) + { + // We must check the expression and coerce it to the + // actual type of the parameter. + // + initExpr = CheckExpr(initExpr); + initExpr = coerce(typeExpr.type, initExpr); + paramDecl->initExpr = initExpr; + + // TODO: a default argument expression needs to + // conform to other constraints to be valid. + // For example, it should not be allowed to refer + // to other parameters of the same function (or maybe + // only the parameters to its left...). + + // A default argument value should not be allowed on an + // `out` or `inout` parameter. + // + // TODO: we could relax this by requiring the expression + // to yield an lvalue, but that seems like a feature + // with limited practical utility (and an easy source + // of confusing behavior). + // + // Note: the `InOutModifier` class inherits from `OutModifier`, + // so we only need to check for the base case. + // + if(paramDecl->FindModifier()) + { + getSink()->diagnose(initExpr, Diagnostics::outputParameterCannotHaveDefaultValue); + } + } + + paramDecl->SetCheckState(DeclCheckState::Checked); + } + + void VisitFunctionDeclaration(FuncDecl *functionNode) + { + if (functionNode->IsChecked(DeclCheckState::CheckedHeader)) return; + functionNode->SetCheckState(DeclCheckState::CheckingHeader); + auto oldFunc = this->function; + this->function = functionNode; + + auto resultType = functionNode->ReturnType; + if(resultType.exp) + { + resultType = CheckProperType(functionNode->ReturnType); + } + else + { + resultType = TypeExp(getSession()->getVoidType()); + } + functionNode->ReturnType = resultType; + + + HashSet paraNames; + for (auto & para : functionNode->GetParameters()) + { + EnsureDecl(para, DeclCheckState::CheckedHeader); + + if (paraNames.Contains(para->getName())) + { + getSink()->diagnose(para, Diagnostics::parameterAlreadyDefined, para->getName()); + } + else + paraNames.Add(para->getName()); + } + this->function = oldFunc; + functionNode->SetCheckState(DeclCheckState::CheckedHeader); + + // One last bit of validation: check if we are redeclaring an existing function + ValidateFunctionRedeclaration(functionNode); + } + + void visitDeclStmt(DeclStmt* stmt) + { + // We directly dispatch here instead of using `EnsureDecl()` for two + // reasons: + // + // 1. We expect that a local declaration won't have been referenced + // before it is declared, so that we can just check things in-order + // + // 2. `EnsureDecl()` is specialized for `Decl*` instead of `DeclBase*` + // and trying to special case `DeclGroup*` here feels silly. + // + dispatchDecl(stmt->decl); + checkModifiers(stmt->decl); + } + + void visitBlockStmt(BlockStmt* stmt) + { + checkStmt(stmt->body); + } + + void visitSeqStmt(SeqStmt* stmt) + { + for(auto ss : stmt->stmts) + { + checkStmt(ss); + } + } + + template + T* FindOuterStmt() + { + const Index outerStmtCount = outerStmts.getCount(); + for (Index ii = outerStmtCount; ii > 0; --ii) + { + auto outerStmt = outerStmts[ii-1]; + auto found = as(outerStmt); + if (found) + return found; + } + return nullptr; + } + + void visitBreakStmt(BreakStmt *stmt) + { + auto outer = FindOuterStmt(); + if (!outer) + { + getSink()->diagnose(stmt, Diagnostics::breakOutsideLoop); + } + stmt->parentStmt = outer; + } + void visitContinueStmt(ContinueStmt *stmt) + { + auto outer = FindOuterStmt(); + if (!outer) + { + getSink()->diagnose(stmt, Diagnostics::continueOutsideLoop); + } + stmt->parentStmt = outer; + } + + void PushOuterStmt(Stmt* stmt) + { + outerStmts.add(stmt); + } + + void PopOuterStmt(Stmt* /*stmt*/) + { + outerStmts.removeAt(outerStmts.getCount() - 1); + } + + RefPtr checkPredicateExpr(Expr* expr) + { + RefPtr e = expr; + e = CheckTerm(e); + e = coerce(getSession()->getBoolType(), e); + return e; + } + + void visitDoWhileStmt(DoWhileStmt *stmt) + { + PushOuterStmt(stmt); + stmt->Predicate = checkPredicateExpr(stmt->Predicate); + checkStmt(stmt->Statement); + + PopOuterStmt(stmt); + } + void visitForStmt(ForStmt *stmt) + { + PushOuterStmt(stmt); + checkStmt(stmt->InitialStatement); + if (stmt->PredicateExpression) + { + stmt->PredicateExpression = checkPredicateExpr(stmt->PredicateExpression); + } + if (stmt->SideEffectExpression) + { + stmt->SideEffectExpression = CheckExpr(stmt->SideEffectExpression); + } + checkStmt(stmt->Statement); + + PopOuterStmt(stmt); + } + + RefPtr checkExpressionAndExpectIntegerConstant(RefPtr expr, RefPtr* outIntVal) + { + expr = CheckExpr(expr); + auto intVal = CheckIntegerConstantExpression(expr); + if (outIntVal) + *outIntVal = intVal; + return expr; + } + + void visitCompileTimeForStmt(CompileTimeForStmt* stmt) + { + PushOuterStmt(stmt); + + stmt->varDecl->type.type = getSession()->getIntType(); + addModifier(stmt->varDecl, new ConstModifier()); + stmt->varDecl->SetCheckState(DeclCheckState::Checked); + + RefPtr rangeBeginVal; + RefPtr rangeEndVal; + + if (stmt->rangeBeginExpr) + { + stmt->rangeBeginExpr = checkExpressionAndExpectIntegerConstant(stmt->rangeBeginExpr, &rangeBeginVal); + } + else + { + RefPtr rangeBeginConst = new ConstantIntVal(); + rangeBeginConst->value = 0; + rangeBeginVal = rangeBeginConst; + } + + stmt->rangeEndExpr = checkExpressionAndExpectIntegerConstant(stmt->rangeEndExpr, &rangeEndVal); + + stmt->rangeBeginVal = rangeBeginVal; + stmt->rangeEndVal = rangeEndVal; + + checkStmt(stmt->body); + + + PopOuterStmt(stmt); + } + + void visitSwitchStmt(SwitchStmt* stmt) + { + PushOuterStmt(stmt); + // TODO(tfoley): need to coerce condition to an integral type... + stmt->condition = CheckExpr(stmt->condition); + checkStmt(stmt->body); + + // TODO(tfoley): need to check that all case tags are unique + + // TODO(tfoley): check that there is at most one `default` clause + + PopOuterStmt(stmt); + } + void visitCaseStmt(CaseStmt* stmt) + { + // TODO(tfoley): Need to coerce to type being switch on, + // and ensure that value is a compile-time constant + auto expr = CheckExpr(stmt->expr); + auto switchStmt = FindOuterStmt(); + + if (!switchStmt) + { + getSink()->diagnose(stmt, Diagnostics::caseOutsideSwitch); + } + else + { + // TODO: need to do some basic matching to ensure the type + // for the `case` is consistent with the type for the `switch`... + } + + stmt->expr = expr; + stmt->parentStmt = switchStmt; + } + void visitDefaultStmt(DefaultStmt* stmt) + { + auto switchStmt = FindOuterStmt(); + if (!switchStmt) + { + getSink()->diagnose(stmt, Diagnostics::defaultOutsideSwitch); + } + stmt->parentStmt = switchStmt; + } + void visitIfStmt(IfStmt *stmt) + { + stmt->Predicate = checkPredicateExpr(stmt->Predicate); + checkStmt(stmt->PositiveStatement); + checkStmt(stmt->NegativeStatement); + } + + void visitUnparsedStmt(UnparsedStmt*) + { + // Nothing to do + } + + void visitEmptyStmt(EmptyStmt*) + { + // Nothing to do + } + + void visitDiscardStmt(DiscardStmt*) + { + // Nothing to do + } + + void visitReturnStmt(ReturnStmt *stmt) + { + if (!stmt->Expression) + { + if (function && !function->ReturnType.Equals(getSession()->getVoidType())) + { + getSink()->diagnose(stmt, Diagnostics::returnNeedsExpression); + } + } + else + { + stmt->Expression = CheckTerm(stmt->Expression); + if (!stmt->Expression->type->Equals(getSession()->getErrorType())) + { + if (function) + { + stmt->Expression = coerce(function->ReturnType.Ptr(), stmt->Expression); + } + else + { + // TODO(tfoley): this case currently gets triggered for member functions, + // which aren't being checked consistently (because of the whole symbol + // table idea getting in the way). + +// getSink()->diagnose(stmt, Diagnostics::unimplemented, "case for return stmt"); + } + } + } + } + + IntegerLiteralValue GetMinBound(RefPtr val) + { + if (auto constantVal = as(val)) + return constantVal->value; + + // TODO(tfoley): Need to track intervals so that this isn't just a lie... + return 1; + } + + void maybeInferArraySizeForVariable(VarDeclBase* varDecl) + { + // Not an array? + auto arrayType = as(varDecl->type); + if (!arrayType) return; + + // Explicit element count given? + auto elementCount = arrayType->ArrayLength; + if (elementCount) return; + + // No initializer? + auto initExpr = varDecl->initExpr; + if(!initExpr) return; + + // Is the type of the initializer an array type? + if(auto arrayInitType = as(initExpr->type)) + { + elementCount = arrayInitType->ArrayLength; + } + else + { + // Nothing to do: we couldn't infer a size + return; + } + + // Create a new array type based on the size we found, + // and install it into our type. + varDecl->type.type = getArrayType( + arrayType->baseType, + elementCount); + } + + void validateArraySizeForVariable(VarDeclBase* varDecl) + { + auto arrayType = as(varDecl->type); + if (!arrayType) return; + + auto elementCount = arrayType->ArrayLength; + if (!elementCount) + { + // Note(tfoley): For now we allow arrays of unspecified size + // everywhere, because some source languages (e.g., GLSL) + // allow them in specific cases. +#if 0 + getSink()->diagnose(varDecl, Diagnostics::invalidArraySize); +#endif + return; + } + + // TODO(tfoley): How to handle the case where bound isn't known? + if (GetMinBound(elementCount) <= 0) + { + getSink()->diagnose(varDecl, Diagnostics::invalidArraySize); + return; + } + } + + void visitVarDecl(VarDecl* varDecl) + { + CheckVarDeclCommon(varDecl); + } + + void visitWhileStmt(WhileStmt *stmt) + { + PushOuterStmt(stmt); + stmt->Predicate = checkPredicateExpr(stmt->Predicate); + checkStmt(stmt->Statement); + PopOuterStmt(stmt); + } + void visitExpressionStmt(ExpressionStmt *stmt) + { + stmt->Expression = CheckExpr(stmt->Expression); + } + + RefPtr visitBoolLiteralExpr(BoolLiteralExpr* expr) + { + expr->type = getSession()->getBoolType(); + return expr; + } + + RefPtr visitIntegerLiteralExpr(IntegerLiteralExpr* expr) + { + // The expression might already have a type, determined by its suffix. + // It it doesn't, we will give it a default type. + // + // TODO: We should be careful to pick a "big enough" type + // based on the size of the value (e.g., don't try to stuff + // a constant in an `int` if it requires 64 or more bits). + // + // The long-term solution here is to give a type to a literal + // based on the context where it is used, but that requires + // a more sophisticated type system than we have today. + // + if(!expr->type.type) + { + expr->type = getSession()->getIntType(); + } + return expr; + } + + RefPtr visitFloatingPointLiteralExpr(FloatingPointLiteralExpr* expr) + { + if(!expr->type.type) + { + expr->type = getSession()->getFloatType(); + } + return expr; + } + + RefPtr visitStringLiteralExpr(StringLiteralExpr* expr) + { + expr->type = getSession()->getStringType(); + return expr; + } + + IntVal* GetIntVal(IntegerLiteralExpr* expr) + { + // TODO(tfoley): don't keep allocating here! + return new ConstantIntVal(expr->value); + } + + Linkage* getLinkage() { return m_linkage; } + NamePool* getNamePool() { return getLinkage()->getNamePool(); } + + Name* getName(String const& text) + { + return getNamePool()->getName(text); + } + + RefPtr TryConstantFoldExpr( + InvokeExpr* invokeExpr) + { + // We need all the operands to the expression + + // Check if the callee is an operation that is amenable to constant-folding. + // + // For right now we will look for calls to intrinsic functions, and then inspect + // their names (this is bad and slow). + auto funcDeclRefExpr = invokeExpr->FunctionExpr.as(); + if (!funcDeclRefExpr) return nullptr; + + auto funcDeclRef = funcDeclRefExpr->declRef; + auto intrinsicMod = funcDeclRef.getDecl()->FindModifier(); + if (!intrinsicMod) + { + // We can't constant fold anything that doesn't map to a builtin + // operation right now. + // + // TODO: we should really allow constant-folding for anything + // that can be lowered to our bytecode... + return nullptr; + } + + + + // Let's not constant-fold operations with more than a certain number of arguments, for simplicity + static const int kMaxArgs = 8; + if (invokeExpr->Arguments.getCount() > kMaxArgs) + return nullptr; + + // Before checking the operation name, let's look at the arguments + RefPtr argVals[kMaxArgs]; + IntegerLiteralValue constArgVals[kMaxArgs]; + int argCount = 0; + bool allConst = true; + for (auto argExpr : invokeExpr->Arguments) + { + auto argVal = TryCheckIntegerConstantExpression(argExpr.Ptr()); + if (!argVal) + return nullptr; + + argVals[argCount] = argVal; + + if (auto constArgVal = as(argVal)) + { + constArgVals[argCount] = constArgVal->value; + } + else + { + allConst = false; + } + argCount++; + } + + if (!allConst) + { + // TODO(tfoley): We probably want to support a very limited number of operations + // on "constants" that aren't actually known, to be able to handle a generic + // that takes an integer `N` but then constructs a vector of size `N+1`. + // + // The hard part there is implementing the rules for value unification in the + // presence of more complicated `IntVal` subclasses, like `SumIntVal`. You'd + // need inference to be smart enough to know that `2 + N` and `N + 2` are the + // same value, as are `N + M + 1 + 1` and `M + 2 + N`. + // + // For now we can just bail in this case. + return nullptr; + } + + // At this point, all the operands had simple integer values, so we are golden. + IntegerLiteralValue resultValue = 0; + auto opName = funcDeclRef.GetName(); + + // handle binary operators + if (opName == getName("-")) + { + if (argCount == 1) + { + resultValue = -constArgVals[0]; + } + else if (argCount == 2) + { + resultValue = constArgVals[0] - constArgVals[1]; + } + } + + // simple binary operators +#define CASE(OP) \ + else if(opName == getName(#OP)) do { \ + if(argCount != 2) return nullptr; \ + resultValue = constArgVals[0] OP constArgVals[1]; \ + } while(0) + + CASE(+); // TODO: this can also be unary... + CASE(*); +#undef CASE + + // binary operators with chance of divide-by-zero + // TODO: issue a suitable error in that case +#define CASE(OP) \ + else if(opName == getName(#OP)) do { \ + if(argCount != 2) return nullptr; \ + if(!constArgVals[1]) return nullptr; \ + resultValue = constArgVals[0] OP constArgVals[1]; \ + } while(0) + + CASE(/); + CASE(%); +#undef CASE + + // TODO(tfoley): more cases + else + { + return nullptr; + } + + RefPtr result = new ConstantIntVal(resultValue); + return result; + } + + RefPtr TryConstantFoldExpr( + Expr* expr) + { + // Unwrap any "identity" expressions + while (auto parenExpr = as(expr)) + { + expr = parenExpr->base; + } + + // TODO(tfoley): more serious constant folding here + if (auto intLitExpr = as(expr)) + { + return GetIntVal(intLitExpr); + } + + // it is possible that we are referring to a generic value param + if (auto declRefExpr = as(expr)) + { + auto declRef = declRefExpr->declRef; + + if (auto genericValParamRef = declRef.as()) + { + // TODO(tfoley): handle the case of non-`int` value parameters... + return new GenericParamIntVal(genericValParamRef); + } + + // We may also need to check for references to variables that + // are defined in a way that can be used as a constant expression: + if(auto varRef = declRef.as()) + { + auto varDecl = varRef.getDecl(); + + // In HLSL, `static const` is used to mark compile-time constant expressions + if(auto staticAttr = varDecl->FindModifier()) + { + if(auto constAttr = varDecl->FindModifier()) + { + // HLSL `static const` can be used as a constant expression + if(auto initExpr = getInitExpr(varRef)) + { + return TryConstantFoldExpr(initExpr.Ptr()); + } + } + } + } + else if(auto enumRef = declRef.as()) + { + // The cases in an `enum` declaration can also be used as constant expressions, + if(auto tagExpr = getTagExpr(enumRef)) + { + return TryConstantFoldExpr(tagExpr.Ptr()); + } + } + } + + if(auto castExpr = as(expr)) + { + auto val = TryConstantFoldExpr(castExpr->Arguments[0].Ptr()); + if(val) + return val; + } + else if (auto invokeExpr = as(expr)) + { + auto val = TryConstantFoldExpr(invokeExpr); + if (val) + return val; + } + + return nullptr; + } + + // Try to check an integer constant expression, either returning the value, + // or NULL if the expression isn't recognized as a constant. + RefPtr TryCheckIntegerConstantExpression(Expr* exp) + { + // Check if type is acceptable for an integer constant expression + if(auto basicType = as(exp->type.type)) + { + if(!isIntegerBaseType(basicType->baseType)) + return nullptr; + } + else + { + return nullptr; + } + + // Consider operations that we might be able to constant-fold... + return TryConstantFoldExpr(exp); + } + + // Enforce that an expression resolves to an integer constant, and get its value + RefPtr CheckIntegerConstantExpression(Expr* inExpr) + { + // No need to issue further errors if the expression didn't even type-check. + if(IsErrorExpr(inExpr)) return nullptr; + + // First coerce the expression to the expected type + auto expr = coerce(getSession()->getIntType(),inExpr); + + // No need to issue further errors if the type coercion failed. + if(IsErrorExpr(expr)) return nullptr; + + auto result = TryCheckIntegerConstantExpression(expr.Ptr()); + if (!result) + { + getSink()->diagnose(expr, Diagnostics::expectedIntegerConstantNotConstant); + } + return result; + } + + RefPtr CheckEnumConstantExpression(Expr* expr) + { + // No need to issue further errors if the expression didn't even type-check. + if(IsErrorExpr(expr)) return nullptr; + + // No need to issue further errors if the type coercion failed. + if(IsErrorExpr(expr)) return nullptr; + + auto result = TryConstantFoldExpr(expr); + if (!result) + { + getSink()->diagnose(expr, Diagnostics::expectedIntegerConstantNotConstant); + } + return result; + } + + + RefPtr CheckSimpleSubscriptExpr( + RefPtr subscriptExpr, + RefPtr elementType) + { + auto baseExpr = subscriptExpr->BaseExpression; + auto indexExpr = subscriptExpr->IndexExpression; + + if (!indexExpr->type->Equals(getSession()->getIntType()) && + !indexExpr->type->Equals(getSession()->getUIntType())) + { + getSink()->diagnose(indexExpr, Diagnostics::subscriptIndexNonInteger); + return CreateErrorExpr(subscriptExpr.Ptr()); + } + + subscriptExpr->type = QualType(elementType); + + // TODO(tfoley): need to be more careful about this stuff + subscriptExpr->type.IsLeftValue = baseExpr->type.IsLeftValue; + + return subscriptExpr; + } + + // The way that we have designed out type system, pretyt much *every* + // type is a reference to some declaration in the standard library. + // That means that when we construct a new type on the fly, we need + // to make sure that it is wired up to reference the appropriate + // declaration, or else it won't compare as equal to other types + // that *do* reference the declaration. + // + // This function is used to construct a `vector` type + // programmatically, so that it will work just like a type of + // that form constructed by the user. + RefPtr createVectorType( + RefPtr elementType, + RefPtr elementCount) + { + auto session = getSession(); + auto vectorGenericDecl = findMagicDecl( + session, "Vector").as(); + auto vectorTypeDecl = vectorGenericDecl->inner; + + auto substitutions = new GenericSubstitution(); + substitutions->genericDecl = vectorGenericDecl.Ptr(); + substitutions->args.add(elementType); + substitutions->args.add(elementCount); + + auto declRef = DeclRef(vectorTypeDecl.Ptr(), substitutions); + + return DeclRefType::Create( + session, + declRef).as(); + } + + RefPtr visitIndexExpr(IndexExpr* subscriptExpr) + { + auto baseExpr = subscriptExpr->BaseExpression; + baseExpr = CheckExpr(baseExpr); + + RefPtr indexExpr = subscriptExpr->IndexExpression; + if (indexExpr) + { + indexExpr = CheckExpr(indexExpr); + } + + subscriptExpr->BaseExpression = baseExpr; + subscriptExpr->IndexExpression = indexExpr; + + // If anything went wrong in the base expression, + // then just move along... + if (IsErrorExpr(baseExpr)) + return CreateErrorExpr(subscriptExpr); + + // Otherwise, we need to look at the type of the base expression, + // to figure out how subscripting should work. + auto baseType = baseExpr->type.Ptr(); + if (auto baseTypeType = as(baseType)) + { + // We are trying to "index" into a type, so we have an expression like `float[2]` + // which should be interpreted as resolving to an array type. + + RefPtr elementCount = nullptr; + if (indexExpr) + { + elementCount = CheckIntegerConstantExpression(indexExpr.Ptr()); + } + + auto elementType = CoerceToUsableType(TypeExp(baseExpr, baseTypeType->type)); + auto arrayType = getArrayType( + elementType, + elementCount); + + typeResult = arrayType; + subscriptExpr->type = QualType(getTypeType(arrayType)); + return subscriptExpr; + } + else if (auto baseArrayType = as(baseType)) + { + return CheckSimpleSubscriptExpr( + subscriptExpr, + baseArrayType->baseType); + } + else if (auto vecType = as(baseType)) + { + return CheckSimpleSubscriptExpr( + subscriptExpr, + vecType->elementType); + } + else if (auto matType = as(baseType)) + { + // TODO(tfoley): We shouldn't go and recompute + // row types over and over like this... :( + auto rowType = createVectorType( + matType->getElementType(), + matType->getColumnCount()); + + return CheckSimpleSubscriptExpr( + subscriptExpr, + rowType); + } + + // Default behavior is to look at all available `__subscript` + // declarations on the type and try to call one of them. + + { + LookupResult lookupResult = lookUpMember( + getSession(), + this, + getName("operator[]"), + baseType); + if (!lookupResult.isValid()) + { + goto fail; + } + + // Now that we know there is at least one subscript member, + // we will construct a reference to it and try to call it. + // + // Note: the expression may be an `OverloadedExpr`, in which + // case the attempt to call it will trigger overload + // resolution. + RefPtr subscriptFuncExpr = createLookupResultExpr( + lookupResult, subscriptExpr->BaseExpression, subscriptExpr->loc); + + RefPtr subscriptCallExpr = new InvokeExpr(); + subscriptCallExpr->loc = subscriptExpr->loc; + subscriptCallExpr->FunctionExpr = subscriptFuncExpr; + + // TODO(tfoley): This path can support multiple arguments easily + subscriptCallExpr->Arguments.add(subscriptExpr->IndexExpression); + + return CheckInvokeExprWithCheckedOperands(subscriptCallExpr.Ptr()); + } + + fail: + { + getSink()->diagnose(subscriptExpr, Diagnostics::subscriptNonArray, baseType); + return CreateErrorExpr(subscriptExpr); + } + } + + bool MatchArguments(FuncDecl * functionNode, List > &args) + { + if (functionNode->GetParameters().getCount() != args.getCount()) + return false; + Index i = 0; + for (auto param : functionNode->GetParameters()) + { + if (!param->type.Equals(args[i]->type.Ptr())) + return false; + i++; + } + return true; + } + + RefPtr visitParenExpr(ParenExpr* expr) + { + auto base = expr->base; + base = CheckTerm(base); + + expr->base = base; + expr->type = base->type; + return expr; + } + + // + + /// Given an immutable `expr` used as an l-value emit a special diagnostic if it was derived from `this`. + void maybeDiagnoseThisNotLValue(Expr* expr) + { + // We will try to handle expressions of the form: + // + // e ::= "this" + // | e . name + // | e [ expr ] + // + // We will unwrap the `e.name` and `e[expr]` cases in a loop. + RefPtr e = expr; + for(;;) + { + if(auto memberExpr = as(e)) + { + e = memberExpr->BaseExpression; + } + else if(auto subscriptExpr = as(e)) + { + e = subscriptExpr->BaseExpression; + } + else + { + break; + } + } + // + // Now we check to see if we have a `this` expression, + // and if it is immutable. + if(auto thisExpr = as(e)) + { + if(!thisExpr->type.IsLeftValue) + { + getSink()->diagnose(thisExpr, Diagnostics::thisIsImmutableByDefault); + } + } + } + + RefPtr visitAssignExpr(AssignExpr* expr) + { + expr->left = CheckExpr(expr->left); + + auto type = expr->left->type; + + expr->right = coerce(type, CheckTerm(expr->right)); + + if (!type.IsLeftValue) + { + if (as(type)) + { + // Don't report an l-value issue on an erroneous expression + } + else + { + getSink()->diagnose(expr, Diagnostics::assignNonLValue); + + // As a special case, check if the LHS expression is derived + // from a `this` parameter (implicitly or explicitly), which + // is immutable. We can give the user a bit more context into + // what is going on. + // + maybeDiagnoseThisNotLValue(expr->left); + } + } + expr->type = type; + return expr; + } + + void registerExtension(ExtensionDecl* decl) + { + if (decl->IsChecked(DeclCheckState::CheckedHeader)) + return; + + decl->SetCheckState(DeclCheckState::CheckingHeader); + decl->targetType = CheckProperType(decl->targetType); + decl->SetCheckState(DeclCheckState::CheckedHeader); + + // TODO: need to check that the target type names a declaration... + + if (auto targetDeclRefType = as(decl->targetType)) + { + // Attach our extension to that type as a candidate... + if (auto aggTypeDeclRef = targetDeclRefType->declRef.as()) + { + auto aggTypeDecl = aggTypeDeclRef.getDecl(); + decl->nextCandidateExtension = aggTypeDecl->candidateExtensions; + aggTypeDecl->candidateExtensions = decl; + return; + } + } + getSink()->diagnose(decl->targetType.exp, Diagnostics::unimplemented, "expected a nominal type here"); + } + + void visitExtensionDecl(ExtensionDecl* decl) + { + if (decl->IsChecked(getCheckedState())) return; + + if (!as(decl->targetType)) + { + getSink()->diagnose(decl->targetType.exp, Diagnostics::unimplemented, "expected a nominal type here"); + } + // now check the members of the extension + for (auto m : decl->Members) + { + checkDecl(m); + } + decl->SetCheckState(getCheckedState()); + } + + // Figure out what type an initializer/constructor declaration + // is supposed to return. In most cases this is just the type + // declaration that its declaration is nested inside. + RefPtr findResultTypeForConstructorDecl(ConstructorDecl* decl) + { + // We want to look at the parent of the declaration, + // but if the declaration is generic, the parent will be + // the `GenericDecl` and we need to skip past that to + // the grandparent. + // + auto parent = decl->ParentDecl; + auto genericParent = as(parent); + if (genericParent) + { + parent = genericParent->ParentDecl; + } + + // Now look at the type of the parent (or grandparent). + if (auto aggTypeDecl = as(parent)) + { + // We are nested in an aggregate type declaration, + // so the result type of the initializer will just + // be the surrounding type. + return DeclRefType::Create( + getSession(), + makeDeclRef(aggTypeDecl)); + } + else if (auto extDecl = as(parent)) + { + // We are nested inside an extension, so the result + // type needs to be the type being extended. + return extDecl->targetType.type; + } + else + { + getSink()->diagnose(decl, Diagnostics::initializerNotInsideType); + return nullptr; + } + } + + void visitConstructorDecl(ConstructorDecl* decl) + { + if (decl->IsChecked(getCheckedState())) return; + if (checkingPhase == CheckingPhase::Header) + { + decl->SetCheckState(DeclCheckState::CheckingHeader); + + for (auto& paramDecl : decl->GetParameters()) + { + paramDecl->type = CheckUsableType(paramDecl->type); + } + + // We need to compute the result tyep for this declaration, + // since it wasn't filled in for us. + decl->ReturnType.type = findResultTypeForConstructorDecl(decl); + } + else + { + // TODO(tfoley): check body + } + decl->SetCheckState(getCheckedState()); + } + + + void visitSubscriptDecl(SubscriptDecl* decl) + { + if (decl->IsChecked(getCheckedState())) return; + for (auto& paramDecl : decl->GetParameters()) + { + paramDecl->type = CheckUsableType(paramDecl->type); + } + + decl->ReturnType = CheckUsableType(decl->ReturnType); + + // If we have a subscript declaration with no accessor declarations, + // then we should create a single `GetterDecl` to represent + // the implicit meaning of their declaration, so: + // + // subscript(uint index) -> T; + // + // becomes: + // + // subscript(uint index) -> T { get; } + // + + bool anyAccessors = false; + for(auto accessorDecl : decl->getMembersOfType()) + { + anyAccessors = true; + } + + if(!anyAccessors) + { + RefPtr getterDecl = new GetterDecl(); + getterDecl->loc = decl->loc; + + getterDecl->ParentDecl = decl; + decl->Members.add(getterDecl); + } + + for(auto mm : decl->Members) + { + checkDecl(mm); + } + + decl->SetCheckState(getCheckedState()); + } + + void visitAccessorDecl(AccessorDecl* decl) + { + if (checkingPhase == CheckingPhase::Header) + { + // An accessor must appear nested inside a subscript declaration (today), + // or a property declaration (when we add them). It will derive + // its return type from the outer declaration, so we handle both + // of these checks at the same place. + auto parent = decl->ParentDecl; + if (auto parentSubscript = as(parent)) + { + decl->ReturnType = parentSubscript->ReturnType; + } + // TODO: when we add "property" declarations, check for them here + else + { + getSink()->diagnose(decl, Diagnostics::accessorMustBeInsideSubscriptOrProperty); + } + + } + else + { + // TODO: check the body! + } + decl->SetCheckState(getCheckedState()); + } + + + // + + struct Constraint + { + Decl* decl; // the declaration of the thing being constraints + RefPtr val; // the value to which we are constraining it + bool satisfied = false; // Has this constraint been met? + }; + + // A collection of constraints that will need to be satisfied (solved) + // in order for checking to succeed. + struct ConstraintSystem + { + // A source location to use in reporting any issues + SourceLoc loc; + + // The generic declaration whose parameters we + // are trying to solve for. + RefPtr genericDecl; + + // Constraints we have accumulated, which constrain + // the possible arguments for those parameters. + List constraints; + }; + + RefPtr TryJoinVectorAndScalarType( + RefPtr vectorType, + RefPtr scalarType) + { + // Join( vector, S ) -> vetor + // + // That is, the join of a vector and a scalar type is + // a vector type with a joined element type. + auto joinElementType = TryJoinTypes( + vectorType->elementType, + scalarType); + if(!joinElementType) + return nullptr; + + return createVectorType( + joinElementType, + vectorType->elementCount); + } + + struct TypeWitnessBreadcrumb + { + TypeWitnessBreadcrumb* prev; + + RefPtr sub; + RefPtr sup; + DeclRef declRef; + }; + + // Crete a subtype witness based on the declared relationship + // found in a single breadcrumb + RefPtr createSimpleSubtypeWitness( + TypeWitnessBreadcrumb* breadcrumb) + { + RefPtr witness = new DeclaredSubtypeWitness(); + witness->sub = breadcrumb->sub; + witness->sup = breadcrumb->sup; + witness->declRef = breadcrumb->declRef; + return witness; + } + + RefPtr createTypeWitness( + RefPtr type, + DeclRef interfaceDeclRef, + TypeWitnessBreadcrumb* inBreadcrumbs) + { + if(!inBreadcrumbs) + { + // We need to construct a witness to the fact + // that `type` has been proven to be *equal* + // to `interfaceDeclRef`. + // + SLANG_UNEXPECTED("reflexive type witness"); + UNREACHABLE_RETURN(nullptr); + } + + // We might have one or more steps in the breadcrumb trail, e.g.: + // + // {A : B} {B : C} {C : D} + // + // The chain is stored as a reversed linked list, so that + // the first entry would be the `(C : D)` relationship + // above. + // + // We need to walk the list and build up a suitable witness, + // which in the above case would look like: + // + // Transitive( + // Transitive( + // Declared({A : B}), + // {B : C}), + // {C : D}) + // + // Because of the ordering of the breadcrumb trail, along + // with the way the `Transitive` case nests, we will be + // building these objects outside-in, and keeping + // track of the "hole" where the next step goes. + // + auto bb = inBreadcrumbs; + + // `witness` here will hold the first (outer-most) object + // we create, which is the overall result. + RefPtr witness; + + // `link` will point at the remaining "hole" in the + // data structure, to be filled in. + RefPtr* link = &witness; + + // As long as there is more than one breadcrumb, we + // need to be creating transitive witnesses. + while(bb->prev) + { + // On the first iteration when processing the list + // above, the breadcrumb would be for `{ C : D }`, + // and so we'd create: + // + // Transitive( + // [...], + // { C : D}) + // + // where `[...]` represents the "hole" we leave + // open to fill in next. + // + RefPtr transitiveWitness = new TransitiveSubtypeWitness(); + transitiveWitness->sub = bb->sub; + transitiveWitness->sup = bb->sup; + transitiveWitness->midToSup = bb->declRef; + + // Fill in the current hole, and then set the + // hole to point into the node we just created. + *link = transitiveWitness; + link = &transitiveWitness->subToMid; + + // Move on with the list. + bb = bb->prev; + } + + // If we exit the loop, then there is only one breadcrumb left. + // In our running example this would be `{ A : B }`. We create + // a simple (declared) subtype witness for it, and plug the + // final hole, after which there shouldn't be a hole to deal with. + RefPtr declaredWitness = createSimpleSubtypeWitness(bb); + *link = declaredWitness; + + // We now know that our original `witness` variable has been + // filled in, and there are no other holes. + return witness; + } + + /// Is the given interface one that a tagged-union type can conform to? + /// + /// If a tagged union type `__TaggedUnion(A,B)` is going to be + /// plugged in for a type parameter `T : IFoo` then we need to + /// be sure that the interface `IFoo` doesn't have anything + /// that could lead to unsafe/unsound behavior. This function + /// checks that all the requirements on the interfaceare safe ones. + /// + bool isInterfaceSafeForTaggedUnion( + DeclRef interfaceDeclRef) + { + for( auto memberDeclRef : getMembers(interfaceDeclRef) ) + { + if(!isInterfaceRequirementSafeForTaggedUnion(interfaceDeclRef, memberDeclRef)) + return false; + } + + return true; + } + + /// Is the given interface requirement one that a tagged-union type can satisfy? + /// + /// Unsafe requirements include any `static` requirements, + /// any associated types, and also any requirements that make + /// use of the `This` type (once we support it). + /// + bool isInterfaceRequirementSafeForTaggedUnion( + DeclRef interfaceDeclRef, + DeclRef requirementDeclRef) + { + if(auto callableDeclRef = requirementDeclRef.as()) + { + // A `static` method requirement can't be satisfied by a + // tagged union, because there is no tag to dispatch on. + // + if(requirementDeclRef.getDecl()->HasModifier()) + return false; + + // TODO: We will eventually want to check that any callable + // requirements do not use the `This` type or any associated + // types in ways that could lead to errors. + // + // For now we are disallowing interfaces that have associated + // types completely, and we haven't implemented the `This` + // type, so we should be safe. + + return true; + } + else + { + return false; + } + } + + bool doesTypeConformToInterfaceImpl( + RefPtr originalType, + RefPtr type, + DeclRef interfaceDeclRef, + RefPtr* outWitness, + TypeWitnessBreadcrumb* inBreadcrumbs) + { + // for now look up a conformance member... + if(auto declRefType = as(type)) + { + auto declRef = declRefType->declRef; + + // Easy case: a type conforms to itself. + // + // TODO: This is actually a bit more complicated, as + // the interface needs to be "object-safe" for us to + // really make this determination... + if(declRef == interfaceDeclRef) + { + if(outWitness) + { + *outWitness = createTypeWitness(originalType, interfaceDeclRef, inBreadcrumbs); + } + return true; + } + + if( auto aggTypeDeclRef = declRef.as() ) + { + checkDecl(aggTypeDeclRef.getDecl()); + + for( auto inheritanceDeclRef : getMembersOfTypeWithExt(aggTypeDeclRef)) + { + checkDecl(inheritanceDeclRef.getDecl()); + + // Here we will recursively look up conformance on the type + // that is being inherited from. This is dangerous because + // it might lead to infinite loops. + // + // TODO: A better approach would be to create a linearized list + // of all the interfaces that a given type directly or indirectly + // inherits, and store it with the type, so that we don't have + // to recurse in places like this (and can maybe catch infinite + // loops better). This would also help avoid checking multiply-inherited + // conformances multiple times. + + auto inheritedType = getBaseType(inheritanceDeclRef); + + // We need to ensure that the witness that gets created + // is a composite one, reflecting lookup through + // the inheritance declaration. + TypeWitnessBreadcrumb breadcrumb; + breadcrumb.prev = inBreadcrumbs; + + breadcrumb.sub = type; + breadcrumb.sup = inheritedType; + breadcrumb.declRef = inheritanceDeclRef; + + if(doesTypeConformToInterfaceImpl(originalType, inheritedType, interfaceDeclRef, outWitness, &breadcrumb)) + { + return true; + } + } + // if an inheritance decl is not found, try to find a GenericTypeConstraintDecl + for (auto genConstraintDeclRef : getMembersOfType(aggTypeDeclRef)) + { + checkDecl(genConstraintDeclRef.getDecl()); + auto inheritedType = GetSup(genConstraintDeclRef); + TypeWitnessBreadcrumb breadcrumb; + breadcrumb.prev = inBreadcrumbs; + breadcrumb.sub = type; + breadcrumb.sup = inheritedType; + breadcrumb.declRef = genConstraintDeclRef; + if (doesTypeConformToInterfaceImpl(originalType, inheritedType, interfaceDeclRef, outWitness, &breadcrumb)) + { + return true; + } + } + } + else if( auto genericTypeParamDeclRef = declRef.as() ) + { + // We need to enumerate the constraints placed on this type by its outer + // generic declaration, and see if any of them guarantees that we + // satisfy the given interface.. + auto genericDeclRef = genericTypeParamDeclRef.GetParent().as(); + SLANG_ASSERT(genericDeclRef); + + for( auto constraintDeclRef : getMembersOfType(genericDeclRef) ) + { + auto sub = GetSub(constraintDeclRef); + auto sup = GetSup(constraintDeclRef); + + auto subDeclRef = as(sub); + if(!subDeclRef) + continue; + if(subDeclRef->declRef != genericTypeParamDeclRef) + continue; + + // The witness that we create needs to reflect that + // it found the needed conformance by lookup through + // a generic type constraint. + + TypeWitnessBreadcrumb breadcrumb; + breadcrumb.prev = inBreadcrumbs; + breadcrumb.sub = sub; + breadcrumb.sup = sup; + breadcrumb.declRef = constraintDeclRef; + + if(doesTypeConformToInterfaceImpl(originalType, sup, interfaceDeclRef, outWitness, &breadcrumb)) + { + return true; + } + } + } + } + else if(auto taggedUnionType = as(type)) + { + // A tagged union type conforms to an interface if all of + // the constituent types in the tagged union conform. + // + // We will iterate over the "case" types in the tagged + // union, and check if they conform to the interface. + // Along the way we will collect the conformance witness + // values *if* we are being asked to produce a witness + // value for the tagged union itself (that is, if + // `outWitness` is non-null). + // + List> caseWitnesses; + for(auto caseType : taggedUnionType->caseTypes) + { + RefPtr caseWitness; + + if(!doesTypeConformToInterfaceImpl( + caseType, + caseType, + interfaceDeclRef, + outWitness ? &caseWitness : nullptr, + nullptr)) + { + return false; + } + + if(outWitness) + { + caseWitnesses.add(caseWitness); + } + } + + // We also need to validate the requirements on + // the interface to make sure that they are suitable for + // use with a tagged-union type. + // + // For example, if the interface includes a `static` method + // (which can therefore be called without a particular instance), + // then we wouldn't know what implementation of that method + // to use because there is no tag value to dispatch on. + // + // We will start out being conservative about what we accept + // here, just to keep things simple. + // + if(!isInterfaceSafeForTaggedUnion(interfaceDeclRef)) + return false; + + // If we reach this point then we have a concrete + // witness for each of the case types, and that is + // enough to build a witness for the tagged union. + // + if(outWitness) + { + RefPtr taggedUnionWitness = new TaggedUnionSubtypeWitness(); + taggedUnionWitness->sub = taggedUnionType; + taggedUnionWitness->sup = DeclRefType::Create(getSession(), interfaceDeclRef); + taggedUnionWitness->caseWitnesses.swapWith(caseWitnesses); + + *outWitness = taggedUnionWitness; + } + return true; + } + + // default is failure + return false; + } + + bool DoesTypeConformToInterface( + RefPtr type, + DeclRef interfaceDeclRef) + { + return doesTypeConformToInterfaceImpl(type, type, interfaceDeclRef, nullptr, nullptr); + } + + RefPtr tryGetInterfaceConformanceWitness( + RefPtr type, + DeclRef interfaceDeclRef) + { + RefPtr result; + doesTypeConformToInterfaceImpl(type, type, interfaceDeclRef, &result, nullptr); + return result; + } + + /// Does there exist an implicit conversion from `fromType` to `toType`? + bool canConvertImplicitly( + RefPtr toType, + RefPtr fromType) + { + // Can we convert at all? + ConversionCost conversionCost; + if(!canCoerce(toType, fromType, &conversionCost)) + return false; + + // Is the conversion cheap enough to be done implicitly? + if(conversionCost >= kConversionCost_GeneralConversion) + return false; + + return true; + } + + RefPtr TryJoinTypeWithInterface( + RefPtr type, + DeclRef interfaceDeclRef) + { + // The most basic test here should be: does the type declare conformance to the trait. + if(DoesTypeConformToInterface(type, interfaceDeclRef)) + return type; + + // Just because `type` doesn't conform to the given `interfaceDeclRef`, that + // doesn't necessarily indicate a failure. It is possible that we have a call + // like `sqrt(2)` so that `type` is `int` and `interfaceDeclRef` is + // `__BuiltinFloatingPointType`. The "obvious" answer is that we should infer + // the type `float`, but it seems like the compiler would have to synthesize + // that answer from thin air. + // + // A robsut/correct solution here might be to enumerate set of types types `S` + // such that for each type `X` in `S`: + // + // * `type` is implicitly convertible to `X` + // * `X` conforms to the interface named by `interfaceDeclRef` + // + // If the set `S` is non-empty then we would try to pick the "best" type from `S`. + // The "best" type would be a type `Y` such that `Y` is implicitly convertible to + // every other type in `S`. + // + // We are going to implement a much simpler strategy for now, where we only apply + // the search process if `type` is a builtin scalar type, and then we only search + // through types `X` that are also builtin scalar types. + // + RefPtr bestType; + if(auto basicType = type.dynamicCast()) + { + for(Int baseTypeFlavorIndex = 0; baseTypeFlavorIndex < Int(BaseType::CountOf); baseTypeFlavorIndex++) + { + // Don't consider `type`, since we already know it doesn't work. + if(baseTypeFlavorIndex == Int(basicType->baseType)) + continue; + + // Look up the type in our session. + auto candidateType = type->getSession()->getBuiltinType(BaseType(baseTypeFlavorIndex)); + if(!candidateType) + continue; + + // We only want to consider types that implement the target interface. + if(!DoesTypeConformToInterface(candidateType, interfaceDeclRef)) + continue; + + // We only want to consider types where we can implicitly convert from `type` + if(!canConvertImplicitly(candidateType, type)) + continue; + + // At this point, we have a candidate type that is usable. + // + // If this is our first viable candidate, then it is our best one: + // + if(!bestType) + { + bestType = candidateType; + } + else + { + // Otherwise, we want to pick the "better" type between `candidateType` + // and `bestType`. + // + // We are going to be a bit loose here, and not worry about the + // case where conversion is allowed in both directions. + // + // TODO: make this completely robust. + // + if(canConvertImplicitly(bestType, candidateType)) + { + // Our candidate can convert to the current "best" type, so + // it is logically a more specific type that satisfies our + // constraints, therefore we should keep it. + // + bestType = candidateType; + } + } + } + if(bestType) + return bestType; + } + + // For all other cases, we will just bail out for now. + // + // TODO: In the future we should build some kind of side data structure + // to accelerate either one or both of these queries: + // + // * Given a type `T`, what types `U` can it convert to implicitly? + // + // * Given an interface `I`, what types `U` conform to it? + // + // The intersection of the sets returned by these two queries is + // the set of candidates we would like to consider here. + + return nullptr; + } + + // Try to compute the "join" between two types + RefPtr TryJoinTypes( + RefPtr left, + RefPtr right) + { + // Easy case: they are the same type! + if (left->Equals(right)) + return left; + + // We can join two basic types by picking the "better" of the two + if (auto leftBasic = as(left)) + { + if (auto rightBasic = as(right)) + { + auto leftFlavor = leftBasic->baseType; + auto rightFlavor = rightBasic->baseType; + + // TODO(tfoley): Need a special-case rule here that if + // either operand is of type `half`, then we promote + // to at least `float` + + // Return the one that had higher rank... + if (leftFlavor > rightFlavor) + return left; + else + { + SLANG_ASSERT(rightFlavor > leftFlavor); // equality was handles at the top of this function + return right; + } + } + + // We can also join a vector and a scalar + if(auto rightVector = as(right)) + { + return TryJoinVectorAndScalarType(rightVector, leftBasic); + } + } + + // We can join two vector types by joining their element types + // (and also their sizes...) + if( auto leftVector = as(left)) + { + if(auto rightVector = as(right)) + { + // Check if the vector sizes match + if(!leftVector->elementCount->EqualsVal(rightVector->elementCount.Ptr())) + return nullptr; + + // Try to join the element types + auto joinElementType = TryJoinTypes( + leftVector->elementType, + rightVector->elementType); + if(!joinElementType) + return nullptr; + + return createVectorType( + joinElementType, + leftVector->elementCount); + } + + // We can also join a vector and a scalar + if(auto rightBasic = as(right)) + { + return TryJoinVectorAndScalarType(leftVector, rightBasic); + } + } + + // HACK: trying to work trait types in here... + if(auto leftDeclRefType = as(left)) + { + if( auto leftInterfaceRef = leftDeclRefType->declRef.as() ) + { + // + return TryJoinTypeWithInterface(right, leftInterfaceRef); + } + } + if(auto rightDeclRefType = as(right)) + { + if( auto rightInterfaceRef = rightDeclRefType->declRef.as() ) + { + // + return TryJoinTypeWithInterface(left, rightInterfaceRef); + } + } + + // TODO: all the cases for vectors apply to matrices too! + + // Default case is that we just fail. + return nullptr; + } + + // Try to solve a system of generic constraints. + // The `system` argument provides the constraints. + // The `varSubst` argument provides the list of constraint + // variables that were created for the system. + // + // Returns a new substitution representing the values that + // we solved for along the way. + SubstitutionSet TrySolveConstraintSystem( + ConstraintSystem* system, + DeclRef genericDeclRef) + { + // For now the "solver" is going to be ridiculously simplistic. + + // The generic itself will have some constraints, and for now we add these + // to the system of constrains we will use for solving for the type variables. + // + // TODO: we need to decide whether constraints are used like this to influence + // how we solve for type/value variables, or whether constraints in the parameter + // list just work as a validation step *after* we've solved for the types. + // + // That is, should we allow `` to be written, and cause us to "infer" + // that `T` should be the type `Int`? That seems a little silly. + // + // Eventually, though, we may want to support type identity constraints, especially + // on associated types, like `` + // These seem more reasonable to have influence constraint solving, since it could + // conceivably let us specialize a `X : IContainer` to `X` if we find + // that `X.IndexType == T`. + for( auto constraintDeclRef : getMembersOfType(genericDeclRef) ) + { + if(!TryUnifyTypes(*system, GetSub(constraintDeclRef), GetSup(constraintDeclRef))) + return SubstitutionSet(); + } + SubstitutionSet resultSubst = genericDeclRef.substitutions; + // We will loop over the generic parameters, and for + // each we will try to find a way to satisfy all + // the constraints for that parameter + List> args; + for (auto m : getMembers(genericDeclRef)) + { + if (auto typeParam = m.as()) + { + RefPtr type = nullptr; + for (auto& c : system->constraints) + { + if (c.decl != typeParam.getDecl()) + continue; + + auto cType = as(c.val); + SLANG_RELEASE_ASSERT(cType); + + if (!type) + { + type = cType; + } + else + { + auto joinType = TryJoinTypes(type, cType); + if (!joinType) + { + // failure! + return SubstitutionSet(); + } + type = joinType; + } + + c.satisfied = true; + } + + if (!type) + { + // failure! + return SubstitutionSet(); + } + args.add(type); + } + else if (auto valParam = m.as()) + { + // TODO(tfoley): maybe support more than integers some day? + // TODO(tfoley): figure out how this needs to interact with + // compile-time integers that aren't just constants... + RefPtr val = nullptr; + for (auto& c : system->constraints) + { + if (c.decl != valParam.getDecl()) + continue; + + auto cVal = as(c.val); + SLANG_RELEASE_ASSERT(cVal); + + if (!val) + { + val = cVal; + } + else + { + if(!val->EqualsVal(cVal)) + { + // failure! + return SubstitutionSet(); + } + } + + c.satisfied = true; + } + + if (!val) + { + // failure! + return SubstitutionSet(); + } + args.add(val); + } + else + { + // ignore anything that isn't a generic parameter + } + } + + // After we've solved for the explicit arguments, we need to + // make a second pass and consider the implicit arguments, + // based on what we've already determined to be the values + // for the explicit arguments. + + // Before we begin, we are going to go ahead and create the + // "solved" substitution that we will return if everything works. + // This is because we are going to use this substitution, + // partially filled in with the results we know so far, + // in order to specialize any constraints on the generic. + // + // E.g., if the generic parameters were ``, and + // we've already decided that `T` is `Robin`, then we want to + // search for a conformance `Robin : ISidekick`, which involved + // apply the substitutions we already know... + + RefPtr solvedSubst = new GenericSubstitution(); + solvedSubst->genericDecl = genericDeclRef.getDecl(); + solvedSubst->outer = genericDeclRef.substitutions.substitutions; + solvedSubst->args = args; + resultSubst.substitutions = solvedSubst; + + for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType() ) + { + DeclRef constraintDeclRef( + constraintDecl, + solvedSubst); + + // Extract the (substituted) sub- and super-type from the constraint. + auto sub = GetSub(constraintDeclRef); + auto sup = GetSup(constraintDeclRef); + + // Search for a witness that shows the constraint is satisfied. + auto subTypeWitness = tryGetSubtypeWitness(sub, sup); + if(subTypeWitness) + { + // We found a witness, so it will become an (implicit) argument. + solvedSubst->args.add(subTypeWitness); + } + else + { + // No witness was found, so the inference will now fail. + // + // TODO: Ideally we should print an error message in + // this case, to let the user know why things failed. + return SubstitutionSet(); + } + + // TODO: We may need to mark some constrains in our constraint + // system as being solved now, as a result of the witness we found. + } + + // Make sure we haven't constructed any spurious constraints + // that we aren't able to satisfy: + for (auto c : system->constraints) + { + if (!c.satisfied) + { + return SubstitutionSet(); + } + } + + return resultSubst; + } + + + // State related to overload resolution for a call + // to an overloaded symbol + struct OverloadResolveContext + { + enum class Mode + { + // We are just checking if a candidate works or not + JustTrying, + + // We want to actually update the AST for a chosen candidate + ForReal, + }; + + // Location to use when reporting overload-resolution errors. + SourceLoc loc; + + // The original expression (if any) that triggered things + RefPtr originalExpr; + + // Source location of the "function" part of the expression, if any + SourceLoc funcLoc; + + // The original arguments to the call + Index argCount = 0; + RefPtr* args = nullptr; + RefPtr* argTypes = nullptr; + + Index getArgCount() { return argCount; } + RefPtr& getArg(Index index) { return args[index]; } + RefPtr& getArgType(Index index) + { + if(argTypes) + return argTypes[index]; + else + return getArg(index)->type.type; + } + + bool disallowNestedConversions = false; + + RefPtr baseExpr; + + // Are we still trying out candidates, or are we + // checking the chosen one for real? + Mode mode = Mode::JustTrying; + + // We store one candidate directly, so that we don't + // need to do dynamic allocation on the list every time + OverloadCandidate bestCandidateStorage; + OverloadCandidate* bestCandidate = nullptr; + + // Full list of all candidates being considered, in the ambiguous case + List bestCandidates; + }; + + struct ParamCounts + { + UInt required; + UInt allowed; + }; + + // count the number of parameters required/allowed for a callable + ParamCounts CountParameters(FilteredMemberRefList params) + { + ParamCounts counts = { 0, 0 }; + for (auto param : params) + { + counts.allowed++; + + // No initializer means no default value + // + // TODO(tfoley): The logic here is currently broken in two ways: + // + // 1. We are assuming that once one parameter has a default, then all do. + // This can/should be validated earlier, so that we can assume it here. + // + // 2. We are not handling the possibility of multiple declarations for + // a single function, where we'd need to merge default parameters across + // all the declarations. + if (!param.getDecl()->initExpr) + { + counts.required++; + } + } + return counts; + } + + // count the number of parameters required/allowed for a generic + ParamCounts CountParameters(DeclRef genericRef) + { + ParamCounts counts = { 0, 0 }; + for (auto m : genericRef.getDecl()->Members) + { + if (auto typeParam = as(m)) + { + counts.allowed++; + if (!typeParam->initType.Ptr()) + { + counts.required++; + } + } + else if (auto valParam = as(m)) + { + counts.allowed++; + if (!valParam->initExpr) + { + counts.required++; + } + } + } + return counts; + } + + bool TryCheckOverloadCandidateArity( + OverloadResolveContext& context, + OverloadCandidate const& candidate) + { + UInt argCount = context.getArgCount(); + ParamCounts paramCounts = { 0, 0 }; + switch (candidate.flavor) + { + case OverloadCandidate::Flavor::Func: + paramCounts = CountParameters(GetParameters(candidate.item.declRef.as())); + break; + + case OverloadCandidate::Flavor::Generic: + paramCounts = CountParameters(candidate.item.declRef.as()); + break; + + default: + SLANG_UNEXPECTED("unknown flavor of overload candidate"); + break; + } + + if (argCount >= paramCounts.required && argCount <= paramCounts.allowed) + return true; + + // Emit an error message if we are checking this call for real + if (context.mode != OverloadResolveContext::Mode::JustTrying) + { + if (argCount < paramCounts.required) + { + getSink()->diagnose(context.loc, Diagnostics::notEnoughArguments, argCount, paramCounts.required); + } + else + { + SLANG_ASSERT(argCount > paramCounts.allowed); + getSink()->diagnose(context.loc, Diagnostics::tooManyArguments, argCount, paramCounts.allowed); + } + } + + return false; + } + + bool TryCheckOverloadCandidateFixity( + OverloadResolveContext& context, + OverloadCandidate const& candidate) + { + auto expr = context.originalExpr; + + auto decl = candidate.item.declRef.decl; + + if(auto prefixExpr = as(expr)) + { + if(decl->HasModifier()) + return true; + + if (context.mode != OverloadResolveContext::Mode::JustTrying) + { + getSink()->diagnose(context.loc, Diagnostics::expectedPrefixOperator); + getSink()->diagnose(decl, Diagnostics::seeDefinitionOf, decl->getName()); + } + + return false; + } + else if(auto postfixExpr = as(expr)) + { + if(decl->HasModifier()) + return true; + + if (context.mode != OverloadResolveContext::Mode::JustTrying) + { + getSink()->diagnose(context.loc, Diagnostics::expectedPostfixOperator); + getSink()->diagnose(decl, Diagnostics::seeDefinitionOf, decl->getName()); + } + + return false; + } + else + { + return true; + } + + return false; + } + + bool TryCheckGenericOverloadCandidateTypes( + OverloadResolveContext& context, + OverloadCandidate& candidate) + { + auto genericDeclRef = candidate.item.declRef.as(); + + // We will go ahead and hang onto the arguments that we've + // already checked, since downstream validation might need + // them. + auto genSubst = new GenericSubstitution(); + candidate.subst = genSubst; + auto& checkedArgs = genSubst->args; + + Index aa = 0; + for (auto memberRef : getMembers(genericDeclRef)) + { + if (auto typeParamRef = memberRef.as()) + { + if (aa >= context.argCount) + { + return false; + } + auto arg = context.getArg(aa++); + + TypeExp typeExp; + if (context.mode == OverloadResolveContext::Mode::JustTrying) + { + typeExp = tryCoerceToProperType(TypeExp(arg)); + if(!typeExp.type) + { + return false; + } + } + else + { + typeExp = CoerceToProperType(TypeExp(arg)); + } + checkedArgs.add(typeExp.type); + } + else if (auto valParamRef = memberRef.as()) + { + auto arg = context.getArg(aa++); + + if (context.mode == OverloadResolveContext::Mode::JustTrying) + { + ConversionCost cost = kConversionCost_None; + if (!canCoerce(GetType(valParamRef), arg->type, &cost)) + { + return false; + } + candidate.conversionCostSum += cost; + } + + arg = coerce(GetType(valParamRef), arg); + auto val = ExtractGenericArgInteger(arg); + checkedArgs.add(val); + } + else + { + continue; + } + } + + // Okay, we've made it! + return true; + } + + bool TryCheckOverloadCandidateTypes( + OverloadResolveContext& context, + OverloadCandidate& candidate) + { + Index argCount = context.getArgCount(); + + List> params; + switch (candidate.flavor) + { + case OverloadCandidate::Flavor::Func: + params = GetParameters(candidate.item.declRef.as()).ToArray(); + break; + + case OverloadCandidate::Flavor::Generic: + return TryCheckGenericOverloadCandidateTypes(context, candidate); + + default: + SLANG_UNEXPECTED("unknown flavor of overload candidate"); + break; + } + + // Note(tfoley): We might have fewer arguments than parameters in the + // case where one or more parameters had defaults. + SLANG_RELEASE_ASSERT(argCount <= params.getCount()); + + for (Index ii = 0; ii < argCount; ++ii) + { + auto& arg = context.getArg(ii); + auto argType = context.getArgType(ii); + auto param = params[ii]; + + if (context.mode == OverloadResolveContext::Mode::JustTrying) + { + ConversionCost cost = kConversionCost_None; + if( context.disallowNestedConversions ) + { + // We need an exact match in this case. + if(!GetType(param)->Equals(argType)) + return false; + } + else if (!canCoerce(GetType(param), argType, &cost)) + { + return false; + } + candidate.conversionCostSum += cost; + } + else + { + arg = coerce(GetType(param), arg); + } + } + return true; + } + + bool TryCheckOverloadCandidateDirections( + OverloadResolveContext& /*context*/, + OverloadCandidate const& /*candidate*/) + { + // TODO(tfoley): check `in` and `out` markers, as needed. + return true; + } + + // Create a witness that attests to the fact that `type` + // is equal to itself. + RefPtr createTypeEqualityWitness( + Type* type) + { + RefPtr rs = new TypeEqualityWitness(); + rs->sub = type; + rs->sup = type; + return rs; + } + + // If `sub` is a subtype of `sup`, then return a value that + // can serve as a "witness" for that fact. + RefPtr tryGetSubtypeWitness( + RefPtr sub, + RefPtr sup) + { + if(sub->Equals(sup)) + { + // They are the same type, so we just need a witness + // for type equality. + return createTypeEqualityWitness(sub); + } + + if(auto supDeclRefType = as(sup)) + { + auto supDeclRef = supDeclRefType->declRef; + if(auto supInterfaceDeclRef = supDeclRef.as()) + { + if(auto witness = tryGetInterfaceConformanceWitness(sub, supInterfaceDeclRef)) + { + return witness; + } + } + } + + return nullptr; + } + + // In the case where we are explicitly applying a generic + // to arguments (e.g., `G`) check that the constraints + // on those parameters are satisfied. + // + // Note: the constraints actually work as additional parameters/arguments + // of the generic, and so we need to reify them into the final + // argument list. + // + bool TryCheckOverloadCandidateConstraints( + OverloadResolveContext& context, + OverloadCandidate const& candidate) + { + // We only need this step for generics, so always succeed on + // everything else. + if(candidate.flavor != OverloadCandidate::Flavor::Generic) + return true; + + auto genericDeclRef = candidate.item.declRef.as(); + SLANG_ASSERT(genericDeclRef); // otherwise we wouldn't be a generic candidate... + + // We should have the existing arguments to the generic + // handy, so that we can construct a substitution list. + + auto subst = candidate.subst.as(); + SLANG_ASSERT(subst); + + subst->genericDecl = genericDeclRef.getDecl(); + subst->outer = genericDeclRef.substitutions.substitutions; + + for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType() ) + { + auto subset = genericDeclRef.substitutions; + subset.substitutions = subst; + DeclRef constraintDeclRef( + constraintDecl, subset); + + auto sub = GetSub(constraintDeclRef); + auto sup = GetSup(constraintDeclRef); + + auto subTypeWitness = tryGetSubtypeWitness(sub, sup); + if(subTypeWitness) + { + subst->args.add(subTypeWitness); + } + else + { + if(context.mode != OverloadResolveContext::Mode::JustTrying) + { + // TODO: diagnose a problem here + getSink()->diagnose(context.loc, Diagnostics::unimplemented, "generic constraint not satisfied"); + } + return false; + } + } + + // Done checking all the constraints, hooray. + return true; + } + + // Try to check an overload candidate, but bail out + // if any step fails + void TryCheckOverloadCandidate( + OverloadResolveContext& context, + OverloadCandidate& candidate) + { + if (!TryCheckOverloadCandidateArity(context, candidate)) + return; + + candidate.status = OverloadCandidate::Status::ArityChecked; + if (!TryCheckOverloadCandidateFixity(context, candidate)) + return; + + candidate.status = OverloadCandidate::Status::FixityChecked; + if (!TryCheckOverloadCandidateTypes(context, candidate)) + return; + + candidate.status = OverloadCandidate::Status::TypeChecked; + if (!TryCheckOverloadCandidateDirections(context, candidate)) + return; + + candidate.status = OverloadCandidate::Status::DirectionChecked; + if (!TryCheckOverloadCandidateConstraints(context, candidate)) + return; + + candidate.status = OverloadCandidate::Status::Applicable; + } + + // Create the representation of a given generic applied to some arguments + RefPtr createGenericDeclRef( + RefPtr baseExpr, + RefPtr originalExpr, + RefPtr subst) + { + auto baseDeclRefExpr = as(baseExpr); + if (!baseDeclRefExpr) + { + SLANG_DIAGNOSE_UNEXPECTED(getSink(), baseExpr, "expected a reference to a generic declaration"); + return CreateErrorExpr(originalExpr); + } + auto baseGenericRef = baseDeclRefExpr->declRef.as(); + if (!baseGenericRef) + { + SLANG_DIAGNOSE_UNEXPECTED(getSink(), baseExpr, "expected a reference to a generic declaration"); + return CreateErrorExpr(originalExpr); + } + + subst->genericDecl = baseGenericRef.getDecl(); + subst->outer = baseGenericRef.substitutions.substitutions; + + DeclRef innerDeclRef(GetInner(baseGenericRef), subst); + + RefPtr base; + if (auto mbrExpr = as(baseExpr)) + base = mbrExpr->BaseExpression; + + return ConstructDeclRefExpr( + innerDeclRef, + base, + originalExpr->loc); + } + + // Take an overload candidate that previously got through + // `TryCheckOverloadCandidate` above, and try to finish + // up the work and turn it into a real expression. + // + // If the candidate isn't actually applicable, this is + // where we'd start reporting the issue(s). + RefPtr CompleteOverloadCandidate( + OverloadResolveContext& context, + OverloadCandidate& candidate) + { + // special case for generic argument inference failure + if (candidate.status == OverloadCandidate::Status::GenericArgumentInferenceFailed) + { + String callString = getCallSignatureString(context); + getSink()->diagnose( + context.loc, + Diagnostics::genericArgumentInferenceFailed, + callString); + + String declString = getDeclSignatureString(candidate.item); + getSink()->diagnose(candidate.item.declRef, Diagnostics::genericSignatureTried, declString); + goto error; + } + + context.mode = OverloadResolveContext::Mode::ForReal; + + if (!TryCheckOverloadCandidateArity(context, candidate)) + goto error; + + if (!TryCheckOverloadCandidateFixity(context, candidate)) + goto error; + + if (!TryCheckOverloadCandidateTypes(context, candidate)) + goto error; + + if (!TryCheckOverloadCandidateDirections(context, candidate)) + goto error; + + if (!TryCheckOverloadCandidateConstraints(context, candidate)) + goto error; + + { + auto baseExpr = ConstructLookupResultExpr( + candidate.item, context.baseExpr, context.funcLoc); + + switch(candidate.flavor) + { + case OverloadCandidate::Flavor::Func: + { + RefPtr callExpr = as(context.originalExpr); + if(!callExpr) + { + callExpr = new InvokeExpr(); + callExpr->loc = context.loc; + + for(Index aa = 0; aa < context.argCount; ++aa) + callExpr->Arguments.add(context.getArg(aa)); + } + + + callExpr->FunctionExpr = baseExpr; + callExpr->type = QualType(candidate.resultType); + + // A call may yield an l-value, and we should take a look at the candidate to be sure + if(auto subscriptDeclRef = candidate.item.declRef.as()) + { + for(auto setter : subscriptDeclRef.getDecl()->getMembersOfType()) + { + callExpr->type.IsLeftValue = true; + } + for(auto refAccessor : subscriptDeclRef.getDecl()->getMembersOfType()) + { + callExpr->type.IsLeftValue = true; + } + } + + // TODO: there may be other cases that confer l-value-ness + + return callExpr; + } + + break; + + case OverloadCandidate::Flavor::Generic: + return createGenericDeclRef( + baseExpr, + context.originalExpr, + candidate.subst.as()); + break; + + default: + SLANG_DIAGNOSE_UNEXPECTED(getSink(), context.loc, "unknown overload candidate flavor"); + break; + } + } + + + error: + + if(context.originalExpr) + { + return CreateErrorExpr(context.originalExpr.Ptr()); + } + else + { + SLANG_DIAGNOSE_UNEXPECTED(getSink(), context.loc, "no original expression for overload result"); + return nullptr; + } + } + + // Implement a comparison operation between overload candidates, + // so that the better candidate compares as less-than the other + int CompareOverloadCandidates( + OverloadCandidate* left, + OverloadCandidate* right) + { + // If one candidate got further along in validation, pick it + if (left->status != right->status) + return int(right->status) - int(left->status); + + // If both candidates are applicable, then we need to compare + // the costs of their type conversion sequences + if(left->status == OverloadCandidate::Status::Applicable) + { + if (left->conversionCostSum != right->conversionCostSum) + return left->conversionCostSum - right->conversionCostSum; + } + + return 0; + } + + void AddOverloadCandidateInner( + OverloadResolveContext& context, + OverloadCandidate& candidate) + { + // Filter our existing candidates, to remove any that are worse than our new one + + bool keepThisCandidate = true; // should this candidate be kept? + + if (context.bestCandidates.getCount() != 0) + { + // We have multiple candidates right now, so filter them. + bool anyFiltered = false; + // Note that we are querying the list length on every iteration, + // because we might remove things. + for (Index cc = 0; cc < context.bestCandidates.getCount(); ++cc) + { + int cmp = CompareOverloadCandidates(&candidate, &context.bestCandidates[cc]); + if (cmp < 0) + { + // our new candidate is better! + + // remove it from the list (by swapping in a later one) + context.bestCandidates.fastRemoveAt(cc); + // and then reduce our index so that we re-visit the same index + --cc; + + anyFiltered = true; + } + else if(cmp > 0) + { + // our candidate is worse! + keepThisCandidate = false; + } + } + // It should not be possible that we removed some existing candidate *and* + // chose not to keep this candidate (otherwise the better-ness relation + // isn't transitive). Therefore we confirm that we either chose to keep + // this candidate (in which case filtering is okay), or we didn't filter + // anything. + SLANG_ASSERT(keepThisCandidate || !anyFiltered); + } + else if(context.bestCandidate) + { + // There's only one candidate so far + int cmp = CompareOverloadCandidates(&candidate, context.bestCandidate); + if(cmp < 0) + { + // our new candidate is better! + context.bestCandidate = nullptr; + } + else if (cmp > 0) + { + // our candidate is worse! + keepThisCandidate = false; + } + } + + // If our candidate isn't good enough, then drop it + if (!keepThisCandidate) + return; + + // Otherwise we want to keep the candidate + if (context.bestCandidates.getCount() > 0) + { + // There were already multiple candidates, and we are adding one more + context.bestCandidates.add(candidate); + } + else if (context.bestCandidate) + { + // There was a unique best candidate, but now we are ambiguous + context.bestCandidates.add(*context.bestCandidate); + context.bestCandidates.add(candidate); + context.bestCandidate = nullptr; + } + else + { + // This is the only candidate worth keeping track of right now + context.bestCandidateStorage = candidate; + context.bestCandidate = &context.bestCandidateStorage; + } + } + + void AddOverloadCandidate( + OverloadResolveContext& context, + OverloadCandidate& candidate) + { + // Try the candidate out, to see if it is applicable at all. + TryCheckOverloadCandidate(context, candidate); + + // Now (potentially) add it to the set of candidate overloads to consider. + AddOverloadCandidateInner(context, candidate); + } + + void AddFuncOverloadCandidate( + LookupResultItem item, + DeclRef funcDeclRef, + OverloadResolveContext& context) + { + auto funcDecl = funcDeclRef.getDecl(); + checkDecl(funcDecl); + + // If this function is a redeclaration, + // then we don't want to include it multiple times, + // and mistakenly think we have an ambiguous call. + // + // Instead, we will carefully consider only the + // "primary" declaration of any callable. + if (auto primaryDecl = funcDecl->primaryDecl) + { + if (funcDecl != primaryDecl) + { + // This is a redeclaration, so we don't + // want to consider it. The primary + // declaration should also get considered + // for the call site and it will match + // anything this declaration would have + // matched. + return; + } + } + + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Func; + candidate.item = item; + candidate.resultType = GetResultType(funcDeclRef); + + AddOverloadCandidate(context, candidate); + } + + void AddFuncOverloadCandidate( + RefPtr /*funcType*/, + OverloadResolveContext& /*context*/) + { +#if 0 + if (funcType->decl) + { + AddFuncOverloadCandidate(funcType->decl, context); + } + else if (funcType->Func) + { + AddFuncOverloadCandidate(funcType->Func->SyntaxNode, context); + } + else if (funcType->Component) + { + AddComponentFuncOverloadCandidate(funcType->Component, context); + } +#else + throw "unimplemented"; +#endif + } + + // Add a candidate callee for overload resolution, based on + // calling a particular `ConstructorDecl`. + void AddCtorOverloadCandidate( + LookupResultItem typeItem, + RefPtr type, + DeclRef ctorDeclRef, + OverloadResolveContext& context, + RefPtr resultType) + { + checkDecl(ctorDeclRef.getDecl()); + + // `typeItem` refers to the type being constructed (the thing + // that was applied as a function) so we need to construct + // a `LookupResultItem` that refers to the constructor instead + + LookupResultItem ctorItem; + ctorItem.declRef = ctorDeclRef; + ctorItem.breadcrumbs = new LookupResultItem::Breadcrumb( + LookupResultItem::Breadcrumb::Kind::Member, + typeItem.declRef, + typeItem.breadcrumbs); + + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Func; + candidate.item = ctorItem; + candidate.resultType = resultType; + + AddOverloadCandidate(context, candidate); + } + + // If the given declaration has generic parameters, then + // return the corresponding `GenericDecl` that holds the + // parameters, etc. + GenericDecl* GetOuterGeneric(Decl* decl) + { + auto parentDecl = decl->ParentDecl; + if (!parentDecl) return nullptr; + auto parentGeneric = as(parentDecl); + return parentGeneric; + } + + // Try to find a unification for two values + bool TryUnifyVals( + ConstraintSystem& constraints, + RefPtr fst, + RefPtr snd) + { + // if both values are types, then unify types + if (auto fstType = as(fst)) + { + if (auto sndType = as(snd)) + { + return TryUnifyTypes(constraints, fstType, sndType); + } + } + + // if both values are constant integers, then compare them + if (auto fstIntVal = as(fst)) + { + if (auto sndIntVal = as(snd)) + { + return fstIntVal->value == sndIntVal->value; + } + } + + // Check if both are integer values in general + if (auto fstInt = as(fst)) + { + if (auto sndInt = as(snd)) + { + auto fstParam = as(fstInt); + auto sndParam = as(sndInt); + + bool okay = false; + if (fstParam) + { + if(TryUnifyIntParam(constraints, fstParam->declRef, sndInt)) + okay = true; + } + if (sndParam) + { + if(TryUnifyIntParam(constraints, sndParam->declRef, fstInt)) + okay = true; + } + return okay; + } + } + + if (auto fstWit = as(fst)) + { + if (auto sndWit = as(snd)) + { + auto constraintDecl1 = fstWit->declRef.as(); + auto constraintDecl2 = sndWit->declRef.as(); + SLANG_ASSERT(constraintDecl1); + SLANG_ASSERT(constraintDecl2); + return TryUnifyTypes(constraints, + constraintDecl1.getDecl()->getSup().type, + constraintDecl2.getDecl()->getSup().type); + } + } + + SLANG_UNIMPLEMENTED_X("value unification case"); + + // default: fail + return false; + } + + bool tryUnifySubstitutions( + ConstraintSystem& constraints, + RefPtr fst, + RefPtr snd) + { + // They must both be NULL or non-NULL + if (!fst || !snd) + return !fst && !snd; + + if(auto fstGeneric = as(fst)) + { + if(auto sndGeneric = as(snd)) + { + return tryUnifyGenericSubstitutions( + constraints, + fstGeneric, + sndGeneric); + } + } + + // TODO: need to handle other cases here + + return false; + } + + bool tryUnifyGenericSubstitutions( + ConstraintSystem& constraints, + RefPtr fst, + RefPtr snd) + { + SLANG_ASSERT(fst); + SLANG_ASSERT(snd); + + auto fstGen = fst; + auto sndGen = snd; + // They must be specializing the same generic + if (fstGen->genericDecl != sndGen->genericDecl) + return false; + + // Their arguments must unify + SLANG_RELEASE_ASSERT(fstGen->args.getCount() == sndGen->args.getCount()); + Index argCount = fstGen->args.getCount(); + bool okay = true; + for (Index aa = 0; aa < argCount; ++aa) + { + if (!TryUnifyVals(constraints, fstGen->args[aa], sndGen->args[aa])) + { + okay = false; + } + } + + // Their "base" specializations must unify + if (!tryUnifySubstitutions(constraints, fstGen->outer, sndGen->outer)) + { + okay = false; + } + + return okay; + } + + bool TryUnifyTypeParam( + ConstraintSystem& constraints, + RefPtr typeParamDecl, + RefPtr type) + { + // We want to constrain the given type parameter + // to equal the given type. + Constraint constraint; + constraint.decl = typeParamDecl.Ptr(); + constraint.val = type; + + constraints.constraints.add(constraint); + + return true; + } + + bool TryUnifyIntParam( + ConstraintSystem& constraints, + RefPtr paramDecl, + RefPtr val) + { + // We only want to accumulate constraints on + // the parameters of the declarations being + // specialized (don't accidentially constrain + // parameters of a generic function based on + // calls in its body). + if(paramDecl->ParentDecl != constraints.genericDecl) + return false; + + // We want to constrain the given parameter to equal the given value. + Constraint constraint; + constraint.decl = paramDecl.Ptr(); + constraint.val = val; + + constraints.constraints.add(constraint); + + return true; + } + + bool TryUnifyIntParam( + ConstraintSystem& constraints, + DeclRef const& varRef, + RefPtr val) + { + if(auto genericValueParamRef = varRef.as()) + { + return TryUnifyIntParam(constraints, RefPtr(genericValueParamRef.getDecl()), val); + } + else + { + return false; + } + } + + bool TryUnifyTypesByStructuralMatch( + ConstraintSystem& constraints, + RefPtr fst, + RefPtr snd) + { + if (auto fstDeclRefType = as(fst)) + { + auto fstDeclRef = fstDeclRefType->declRef; + + if (auto typeParamDecl = as(fstDeclRef.getDecl())) + return TryUnifyTypeParam(constraints, typeParamDecl, snd); + + if (auto sndDeclRefType = as(snd)) + { + auto sndDeclRef = sndDeclRefType->declRef; + + if (auto typeParamDecl = as(sndDeclRef.getDecl())) + return TryUnifyTypeParam(constraints, typeParamDecl, fst); + + // can't be unified if they refer to different declarations. + if (fstDeclRef.getDecl() != sndDeclRef.getDecl()) return false; + + // next we need to unify the substitutions applied + // to each declaration reference. + if (!tryUnifySubstitutions( + constraints, + fstDeclRef.substitutions.substitutions, + sndDeclRef.substitutions.substitutions)) + { + return false; + } + + return true; + } + } + + return false; + } + + bool TryUnifyTypes( + ConstraintSystem& constraints, + RefPtr fst, + RefPtr snd) + { + if (fst->Equals(snd)) return true; + + // An error type can unify with anything, just so we avoid cascading errors. + + if (auto fstErrorType = as(fst)) + return true; + + if (auto sndErrorType = as(snd)) + return true; + + // A generic parameter type can unify with anything. + // TODO: there actually needs to be some kind of "occurs check" sort + // of thing here... + + if (auto fstDeclRefType = as(fst)) + { + auto fstDeclRef = fstDeclRefType->declRef; + + if (auto typeParamDecl = as(fstDeclRef.getDecl())) + { + if(typeParamDecl->ParentDecl == constraints.genericDecl ) + return TryUnifyTypeParam(constraints, typeParamDecl, snd); + } + } + + if (auto sndDeclRefType = as(snd)) + { + auto sndDeclRef = sndDeclRefType->declRef; + + if (auto typeParamDecl = as(sndDeclRef.getDecl())) + { + if(typeParamDecl->ParentDecl == constraints.genericDecl ) + return TryUnifyTypeParam(constraints, typeParamDecl, fst); + } + } + + // If we can unify the types structurally, then we are golden + if(TryUnifyTypesByStructuralMatch(constraints, fst, snd)) + return true; + + // Now we need to consider cases where coercion might + // need to be applied. For now we can try to do this + // in a completely ad hoc fashion, but eventually we'd + // want to do it more formally. + + if(auto fstVectorType = as(fst)) + { + if(auto sndScalarType = as(snd)) + { + return TryUnifyTypes( + constraints, + fstVectorType->elementType, + sndScalarType); + } + } + + if(auto fstScalarType = as(fst)) + { + if(auto sndVectorType = as(snd)) + { + return TryUnifyTypes( + constraints, + fstScalarType, + sndVectorType->elementType); + } + } + + // TODO: the same thing for vectors... + + return false; + } + + // Is the candidate extension declaration actually applicable to the given type + DeclRef ApplyExtensionToType( + ExtensionDecl* extDecl, + RefPtr type) + { + DeclRef extDeclRef = makeDeclRef(extDecl); + + // If the extension is a generic extension, then we + // need to infer type arguments that will give + // us a target type that matches `type`. + // + if (auto extGenericDecl = GetOuterGeneric(extDecl)) + { + ConstraintSystem constraints; + constraints.loc = extDecl->loc; + constraints.genericDecl = extGenericDecl; + + if (!TryUnifyTypes(constraints, extDecl->targetType.Ptr(), type)) + return DeclRef(); + + auto constraintSubst = TrySolveConstraintSystem(&constraints, DeclRef(extGenericDecl, nullptr).as()); + if (!constraintSubst) + { + return DeclRef(); + } + + // Construct a reference to the extension with our constraint variables + // set as they were found by solving the constraint system. + extDeclRef = DeclRef(extDecl, constraintSubst).as(); + } + + // Now extract the target type from our (possibly specialized) extension decl-ref. + RefPtr targetType = GetTargetType(extDeclRef); + + // As a bit of a kludge here, if the target type of the extension is + // an interface, and the `type` we are trying to match up has a this-type + // substitution for that interface, then we want to attach a matching + // substitution to the extension decl-ref. + if(auto targetDeclRefType = as(targetType)) + { + if(auto targetInterfaceDeclRef = targetDeclRefType->declRef.as()) + { + // Okay, the target type is an interface. + // + // Is the type we want to apply to also an interface? + if(auto appDeclRefType = as(type)) + { + if(auto appInterfaceDeclRef = appDeclRefType->declRef.as()) + { + if(appInterfaceDeclRef.getDecl() == targetInterfaceDeclRef.getDecl()) + { + // Looks like we have a match in the types, + // now let's see if we have a this-type substitution. + if(auto appThisTypeSubst = appInterfaceDeclRef.substitutions.substitutions.as()) + { + if(appThisTypeSubst->interfaceDecl == appInterfaceDeclRef.getDecl()) + { + // The type we want to apply to has a this-type substitution, + // and (by construction) the target type currently does not. + // + SLANG_ASSERT(!targetInterfaceDeclRef.substitutions.substitutions.as()); + + // We will create a new substitution to apply to the target type. + RefPtr newTargetSubst = new ThisTypeSubstitution(); + newTargetSubst->interfaceDecl = appThisTypeSubst->interfaceDecl; + newTargetSubst->witness = appThisTypeSubst->witness; + newTargetSubst->outer = targetInterfaceDeclRef.substitutions.substitutions; + + targetType = DeclRefType::Create(getSession(), + DeclRef(targetInterfaceDeclRef.getDecl(), newTargetSubst)); + + // Note: we are constructing a this-type substitution that + // we will apply to the extension declaration as well. + // This is not strictly allowed by our current representation + // choices, but we need it in order to make sure that + // references to the target type of the extension + // declaration have a chance to resolve the way we want them to. + + RefPtr newExtSubst = new ThisTypeSubstitution(); + newExtSubst->interfaceDecl = appThisTypeSubst->interfaceDecl; + newExtSubst->witness = appThisTypeSubst->witness; + newExtSubst->outer = extDeclRef.substitutions.substitutions; + + extDeclRef = DeclRef( + extDeclRef.getDecl(), + newExtSubst); + + // TODO: Ideally we should also apply the chosen specialization to + // the decl-ref for the extension, so that subsequent lookup through + // the members of this extension will retain that substitution and + // be able to apply it. + // + // E.g., if an extension method returns a value of an associated + // type, then we'd want that to become specialized to a concrete + // type when using the extension method on a value of concrete type. + // + // The challenge here that makes me reluctant to just staple on + // such a substitution is that it wouldn't follow our implicit + // rules about where `ThisTypeSubstitution`s can appear. + } + } + } + } + } + } + } + + // In order for this extension to apply to the given type, we + // need to have a match on the target types. + if (!type->Equals(targetType)) + return DeclRef(); + + + return extDeclRef; + } + +#if 0 + bool TryUnifyArgAndParamTypes( + ConstraintSystem& system, + RefPtr argExpr, + DeclRef paramDeclRef) + { + // TODO(tfoley): potentially need a bit more + // nuance in case where argument might be + // an overload group... + return TryUnifyTypes(system, argExpr->type, GetType(paramDeclRef)); + } +#endif + + // Take a generic declaration and try to specialize its parameters + // so that the resulting inner declaration can be applicable in + // a particular context... + DeclRef SpecializeGenericForOverload( + DeclRef genericDeclRef, + OverloadResolveContext& context) + { + checkDecl(genericDeclRef.getDecl()); + + ConstraintSystem constraints; + constraints.loc = context.loc; + constraints.genericDecl = genericDeclRef.getDecl(); + + // Construct a reference to the inner declaration that has any generic + // parameter substitutions in place already, but *not* any substutions + // for the generic declaration we are currently trying to infer. + auto innerDecl = GetInner(genericDeclRef); + DeclRef unspecializedInnerRef = DeclRef(innerDecl, genericDeclRef.substitutions); + + // Check what type of declaration we are dealing with, and then try + // to match it up with the arguments accordingly... + if (auto funcDeclRef = unspecializedInnerRef.as()) + { + auto params = GetParameters(funcDeclRef).ToArray(); + + Index argCount = context.getArgCount(); + Index paramCount = params.getCount(); + + // Bail out on mismatch. + // TODO(tfoley): need more nuance here + if (argCount != paramCount) + { + return DeclRef(nullptr, nullptr); + } + + for (Index aa = 0; aa < argCount; ++aa) + { +#if 0 + if (!TryUnifyArgAndParamTypes(constraints, args[aa], params[aa])) + return DeclRef(nullptr, nullptr); +#else + // The question here is whether failure to "unify" an argument + // and parameter should lead to immediate failure. + // + // The case that is interesting is if we want to unify, say: + // `vector` and `vector` + // + // It is clear that we should solve with `N = 3`, and then + // a later step may find that the resulting types aren't + // actually a match. + // + // A more refined approach to "unification" could of course + // see that `int` can convert to `float` and use that fact. + // (and indeed we already use something like this to unify + // `float` and `vector`) + // + // So the question is then whether a mismatch during the + // unification step should be taken as an immediate failure... + + TryUnifyTypes(constraints, context.getArgType(aa), GetType(params[aa])); +#endif + } + } + else + { + // TODO(tfoley): any other cases needed here? + return DeclRef(nullptr, nullptr); + } + + auto constraintSubst = TrySolveConstraintSystem(&constraints, genericDeclRef); + if (!constraintSubst) + { + // constraint solving failed + return DeclRef(nullptr, nullptr); + } + + // We can now construct a reference to the inner declaration using + // the solution to our constraints. + return DeclRef(innerDecl, constraintSubst); + } + + void AddAggTypeOverloadCandidates( + LookupResultItem typeItem, + RefPtr type, + DeclRef aggTypeDeclRef, + OverloadResolveContext& context, + RefPtr resultType) + { + for (auto ctorDeclRef : getMembersOfType(aggTypeDeclRef)) + { + // now work through this candidate... + AddCtorOverloadCandidate(typeItem, type, ctorDeclRef, context, resultType); + } + + // Also check for generic constructors. + // + // TODO: There is way too much duplication between this case and the extension + // handling below, and all of this is *also* duplicative with the ordinary + // overload resolution logic for function. + // + // The right solution is to handle a "constructor" call expression by + // first doing member lookup in the type (for initializer members, which + // should all share a common name), and then to do overload resolution using + // the (possibly overloaded) result of that lookup. + // + for (auto genericDeclRef : getMembersOfType(aggTypeDeclRef)) + { + if (auto ctorDecl = as(genericDeclRef.getDecl()->inner)) + { + DeclRef innerRef = SpecializeGenericForOverload(genericDeclRef, context); + if (!innerRef) + continue; + + DeclRef innerCtorRef = innerRef.as(); + AddCtorOverloadCandidate(typeItem, type, innerCtorRef, context, resultType); + } + } + + // Now walk through any extensions we can find for this types + for (auto ext = GetCandidateExtensions(aggTypeDeclRef); ext; ext = ext->nextCandidateExtension) + { + auto extDeclRef = ApplyExtensionToType(ext, type); + if (!extDeclRef) + continue; + + for (auto ctorDeclRef : getMembersOfType(extDeclRef)) + { + // TODO(tfoley): `typeItem` here should really reference the extension... + + // now work through this candidate... + AddCtorOverloadCandidate(typeItem, type, ctorDeclRef, context, resultType); + } + + // Also check for generic constructors + for (auto genericDeclRef : getMembersOfType(extDeclRef)) + { + if (auto ctorDecl = genericDeclRef.getDecl()->inner.as()) + { + DeclRef innerRef = SpecializeGenericForOverload(genericDeclRef, context); + if (!innerRef) + continue; + + DeclRef innerCtorRef = innerRef.as(); + + AddCtorOverloadCandidate(typeItem, type, innerCtorRef, context, resultType); + + // TODO(tfoley): need a way to do the solving step for the constraint system + } + } + } + } + + void addGenericTypeParamOverloadCandidates( + DeclRef typeDeclRef, + OverloadResolveContext& context, + RefPtr resultType) + { + // We need to look for any constraints placed on the generic + // type parameter, since they will give us information on + // interfaces that the type must conform to. + + // We expect the parent of the generic type parameter to be a generic... + auto genericDeclRef = typeDeclRef.GetParent().as(); + SLANG_ASSERT(genericDeclRef); + + for(auto constraintDeclRef : getMembersOfType(genericDeclRef)) + { + // Does this constraint pertain to the type we are working on? + // + // We want constraints of the form `T : Foo` where `T` is the + // generic parameter in question, and `Foo` is whatever we are + // constraining it to. + auto subType = GetSub(constraintDeclRef); + auto subDeclRefType = as(subType); + if(!subDeclRefType) + continue; + if(!subDeclRefType->declRef.Equals(typeDeclRef)) + continue; + + // The super-type in the constraint (e.g., `Foo` in `T : Foo`) + // will tell us a type we should use for lookup. + auto bound = GetSup(constraintDeclRef); + + // Go ahead and use the target type: + // + // TODO: Need to consider case where this might recurse infinitely. + AddTypeOverloadCandidates(bound, context, resultType); + } + } + + void AddTypeOverloadCandidates( + RefPtr type, + OverloadResolveContext& context, + RefPtr resultType) + { + if (auto declRefType = as(type)) + { + auto declRef = declRefType->declRef; + if (auto aggTypeDeclRef = declRef.as()) + { + AddAggTypeOverloadCandidates(LookupResultItem(aggTypeDeclRef), type, aggTypeDeclRef, context, resultType); + } + else if(auto genericTypeParamDeclRef = declRef.as()) + { + addGenericTypeParamOverloadCandidates( + genericTypeParamDeclRef, + context, + resultType); + } + } + } + + void AddDeclRefOverloadCandidates( + LookupResultItem item, + OverloadResolveContext& context) + { + auto declRef = item.declRef; + + if (auto funcDeclRef = item.declRef.as()) + { + AddFuncOverloadCandidate(item, funcDeclRef, context); + } + else if (auto aggTypeDeclRef = item.declRef.as()) + { + auto type = DeclRefType::Create( + getSession(), + aggTypeDeclRef); + AddAggTypeOverloadCandidates(item, type, aggTypeDeclRef, context, type); + } + else if (auto genericDeclRef = item.declRef.as()) + { + // Try to infer generic arguments, based on the context + DeclRef innerRef = SpecializeGenericForOverload(genericDeclRef, context); + + if (innerRef) + { + // If inference works, then we've now got a + // specialized declaration reference we can apply. + + LookupResultItem innerItem; + innerItem.breadcrumbs = item.breadcrumbs; + innerItem.declRef = innerRef; + + AddDeclRefOverloadCandidates(innerItem, context); + } + else + { + // If inference failed, then we need to create + // a candidate that can be used to reflect that fact + // (so we can report a good error) + OverloadCandidate candidate; + candidate.item = item; + candidate.flavor = OverloadCandidate::Flavor::UnspecializedGeneric; + candidate.status = OverloadCandidate::Status::GenericArgumentInferenceFailed; + + AddOverloadCandidateInner(context, candidate); + } + } + else if( auto typeDefDeclRef = item.declRef.as() ) + { + auto type = getNamedType(getSession(), typeDefDeclRef); + AddTypeOverloadCandidates(GetType(typeDefDeclRef), context, type); + } + else if( auto genericTypeParamDeclRef = item.declRef.as() ) + { + auto type = DeclRefType::Create( + getSession(), + genericTypeParamDeclRef); + addGenericTypeParamOverloadCandidates(genericTypeParamDeclRef, context, type); + } + else + { + // TODO(tfoley): any other cases needed here? + } + } + + void AddOverloadCandidates( + RefPtr funcExpr, + OverloadResolveContext& context) + { + auto funcExprType = funcExpr->type; + + if (auto declRefExpr = as(funcExpr)) + { + // The expression directly referenced a declaration, + // so we can use that declaration directly to look + // for anything applicable. + AddDeclRefOverloadCandidates(LookupResultItem(declRefExpr->declRef), context); + } + else if (auto funcType = as(funcExprType)) + { + // TODO(tfoley): deprecate this path... + AddFuncOverloadCandidate(funcType, context); + } + else if (auto overloadedExpr = as(funcExpr)) + { + auto lookupResult = overloadedExpr->lookupResult2; + SLANG_RELEASE_ASSERT(lookupResult.isOverloaded()); + for(auto item : lookupResult.items) + { + AddDeclRefOverloadCandidates(item, context); + } + } + else if (auto overloadedExpr2 = as(funcExpr)) + { + for (auto item : overloadedExpr2->candidiateExprs) + { + AddOverloadCandidates(item, context); + } + } + else if (auto typeType = as(funcExprType)) + { + // If none of the above cases matched, but we are + // looking at a type, then I suppose we have + // a constructor call on our hands. + // + // TODO(tfoley): are there any meaningful types left + // that aren't declaration references? + auto type = typeType->type; + AddTypeOverloadCandidates(type, context, type); + return; + } + } + + void formatType(StringBuilder& sb, RefPtr type) + { + sb << type->ToString(); + } + + void formatVal(StringBuilder& sb, RefPtr val) + { + sb << val->ToString(); + } + + void formatDeclPath(StringBuilder& sb, DeclRef declRef) + { + // Find the parent declaration + auto parentDeclRef = declRef.GetParent(); + + // If the immediate parent is a generic, then we probably + // want the declaration above that... + auto parentGenericDeclRef = parentDeclRef.as(); + if(parentGenericDeclRef) + { + parentDeclRef = parentGenericDeclRef.GetParent(); + } + + // Depending on what the parent is, we may want to format things specially + if(auto aggTypeDeclRef = parentDeclRef.as()) + { + formatDeclPath(sb, aggTypeDeclRef); + sb << "."; + } + + sb << getText(declRef.GetName()); + + // If the parent declaration is a generic, then we need to print out its + // signature + if( parentGenericDeclRef ) + { + auto genSubst = declRef.substitutions.substitutions.as(); + SLANG_RELEASE_ASSERT(genSubst); + SLANG_RELEASE_ASSERT(genSubst->genericDecl == parentGenericDeclRef.getDecl()); + + sb << "<"; + bool first = true; + for(auto arg : genSubst->args) + { + if(!first) sb << ", "; + formatVal(sb, arg); + first = false; + } + sb << ">"; + } + } + + void formatDeclParams(StringBuilder& sb, DeclRef declRef) + { + if (auto funcDeclRef = declRef.as()) + { + + // This is something callable, so we need to also print parameter types for overloading + sb << "("; + + bool first = true; + for (auto paramDeclRef : GetParameters(funcDeclRef)) + { + if (!first) sb << ", "; + + formatType(sb, GetType(paramDeclRef)); + + first = false; + + } + + sb << ")"; + } + else if(auto genericDeclRef = declRef.as()) + { + sb << "<"; + bool first = true; + for (auto paramDeclRef : getMembers(genericDeclRef)) + { + if(auto genericTypeParam = paramDeclRef.as()) + { + if (!first) sb << ", "; + first = false; + + sb << getText(genericTypeParam.GetName()); + } + else if(auto genericValParam = paramDeclRef.as()) + { + if (!first) sb << ", "; + first = false; + + formatType(sb, GetType(genericValParam)); + sb << " "; + sb << getText(genericValParam.GetName()); + } + else + {} + } + sb << ">"; + + formatDeclParams(sb, DeclRef(GetInner(genericDeclRef), genericDeclRef.substitutions)); + } + else + { + } + } + + void formatDeclSignature(StringBuilder& sb, DeclRef declRef) + { + formatDeclPath(sb, declRef); + formatDeclParams(sb, declRef); + } + + String getDeclSignatureString(DeclRef declRef) + { + StringBuilder sb; + formatDeclSignature(sb, declRef); + return sb.ProduceString(); + } + + String getDeclSignatureString(LookupResultItem item) + { + return getDeclSignatureString(item.declRef); + } + + String getCallSignatureString( + OverloadResolveContext& context) + { + StringBuilder argsListBuilder; + argsListBuilder << "("; + + UInt argCount = context.getArgCount(); + for( UInt aa = 0; aa < argCount; ++aa ) + { + if(aa != 0) argsListBuilder << ", "; + argsListBuilder << context.getArgType(aa)->ToString(); + } + argsListBuilder << ")"; + return argsListBuilder.ProduceString(); + } + +#if 0 + String GetCallSignatureString(RefPtr expr) + { + return getCallSignatureString(expr->Arguments); + } +#endif + + RefPtr ResolveInvoke(InvokeExpr * expr) + { + OverloadResolveContext context; + // check if this is a stdlib operator call, if so we want to use cached results + // to speed up compilation + bool shouldAddToCache = false; + OperatorOverloadCacheKey key; + TypeCheckingCache* typeCheckingCache = getSession()->getTypeCheckingCache(); + if (auto opExpr = as(expr)) + { + if (key.fromOperatorExpr(opExpr)) + { + OverloadCandidate candidate; + if (typeCheckingCache->resolvedOperatorOverloadCache.TryGetValue(key, candidate)) + { + context.bestCandidateStorage = candidate; + context.bestCandidate = &context.bestCandidateStorage; + } + else + { + shouldAddToCache = true; + } + } + } + + // Look at the base expression for the call, and figure out how to invoke it. + auto funcExpr = expr->FunctionExpr; + auto funcExprType = funcExpr->type; + + // If we are trying to apply an erroneous expression, then just bail out now. + if(IsErrorExpr(funcExpr)) + { + return CreateErrorExpr(expr); + } + // If any of the arguments is an error, then we should bail out, to avoid + // cascading errors where we successfully pick an overload, but not the one + // the user meant. + for (auto arg : expr->Arguments) + { + if (IsErrorExpr(arg)) + return CreateErrorExpr(expr); + } + + context.originalExpr = expr; + context.funcLoc = funcExpr->loc; + + context.argCount = expr->Arguments.getCount(); + context.args = expr->Arguments.getBuffer(); + context.loc = expr->loc; + + if (auto funcMemberExpr = as(funcExpr)) + { + context.baseExpr = funcMemberExpr->BaseExpression; + } + else if (auto funcOverloadExpr = as(funcExpr)) + { + context.baseExpr = funcOverloadExpr->base; + } + else if (auto funcOverloadExpr2 = as(funcExpr)) + { + context.baseExpr = funcOverloadExpr2->base; + } + + if (!context.bestCandidate) + { + AddOverloadCandidates(funcExpr, context); + } + + if (context.bestCandidates.getCount() > 0) + { + // Things were ambiguous. + + // It might be that things were only ambiguous because + // one of the argument expressions had an error, and + // so a bunch of candidates could match at that position. + // + // If any argument was an error, we skip out on printing + // another message, to avoid cascading errors. + for (auto arg : expr->Arguments) + { + if (IsErrorExpr(arg)) + { + return CreateErrorExpr(expr); + } + } + + Name* funcName = nullptr; + if (auto baseVar = as(funcExpr)) + funcName = baseVar->name; + else if(auto baseMemberRef = as(funcExpr)) + funcName = baseMemberRef->name; + + String argsList = getCallSignatureString(context); + + if (context.bestCandidates[0].status != OverloadCandidate::Status::Applicable) + { + // There were multiple equally-good candidates, but none actually usable. + // We will construct a diagnostic message to help out. + + if (funcName) + { + getSink()->diagnose(expr, Diagnostics::noApplicableOverloadForNameWithArgs, funcName, argsList); + } + else + { + getSink()->diagnose(expr, Diagnostics::noApplicableWithArgs, argsList); + } + } + else + { + // There were multiple applicable candidates, so we need to report them. + + if (funcName) + { + getSink()->diagnose(expr, Diagnostics::ambiguousOverloadForNameWithArgs, funcName, argsList); + } + else + { + getSink()->diagnose(expr, Diagnostics::ambiguousOverloadWithArgs, argsList); + } + } + + { + Index candidateCount = context.bestCandidates.getCount(); + Index maxCandidatesToPrint = 10; // don't show too many candidates at once... + Index candidateIndex = 0; + for (auto candidate : context.bestCandidates) + { + String declString = getDeclSignatureString(candidate.item); + +// declString = declString + "[" + String(candidate.conversionCostSum) + "]"; + +#if 0 + // Debugging: ensure that we don't consider multiple declarations of the same operation + if (auto decl = as(candidate.item.declRef.decl)) + { + char buffer[1024]; + sprintf_s(buffer, sizeof(buffer), "[this:%p, primary:%p, next:%p]", + decl, + decl->primaryDecl, + decl->nextDecl); + declString.append(buffer); + } +#endif + + getSink()->diagnose(candidate.item.declRef, Diagnostics::overloadCandidate, declString); + + candidateIndex++; + if (candidateIndex == maxCandidatesToPrint) + break; + } + if (candidateIndex != candidateCount) + { + getSink()->diagnose(expr, Diagnostics::moreOverloadCandidates, candidateCount - candidateIndex); + } + } + + return CreateErrorExpr(expr); + } + else if (context.bestCandidate) + { + // There was one best candidate, even if it might not have been + // applicable in the end. + // We will report errors for this one candidate, then, to give + // the user the most help we can. + if (shouldAddToCache) + typeCheckingCache->resolvedOperatorOverloadCache[key] = *context.bestCandidate; + return CompleteOverloadCandidate(context, *context.bestCandidate); + } + else + { + // Nothing at all was found that we could even consider invoking + getSink()->diagnose(expr->FunctionExpr, Diagnostics::expectedFunction, funcExprType); + expr->type = QualType(getSession()->getErrorType()); + return expr; + } + } + + void AddGenericOverloadCandidate( + LookupResultItem baseItem, + OverloadResolveContext& context) + { + if (auto genericDeclRef = baseItem.declRef.as()) + { + checkDecl(genericDeclRef.getDecl()); + + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Generic; + candidate.item = baseItem; + candidate.resultType = nullptr; + + AddOverloadCandidate(context, candidate); + } + } + + void AddGenericOverloadCandidates( + RefPtr baseExpr, + OverloadResolveContext& context) + { + if(auto baseDeclRefExpr = as(baseExpr)) + { + auto declRef = baseDeclRefExpr->declRef; + AddGenericOverloadCandidate(LookupResultItem(declRef), context); + } + else if (auto overloadedExpr = as(baseExpr)) + { + // We are referring to a bunch of declarations, each of which might be generic + LookupResult result; + for (auto item : overloadedExpr->lookupResult2.items) + { + AddGenericOverloadCandidate(item, context); + } + } + else + { + // any other cases? + } + } + + RefPtr visitGenericAppExpr(GenericAppExpr* genericAppExpr) + { + // Start by checking the base expression and arguments. + auto& baseExpr = genericAppExpr->FunctionExpr; + baseExpr = CheckTerm(baseExpr); + auto& args = genericAppExpr->Arguments; + for (auto& arg : args) + { + arg = CheckTerm(arg); + } + + return checkGenericAppWithCheckedArgs(genericAppExpr); + } + + /// Check a generic application where the operands have already been checked. + RefPtr checkGenericAppWithCheckedArgs(GenericAppExpr* genericAppExpr) + { + // We are applying a generic to arguments, but there might be multiple generic + // declarations with the same name, so this becomes a specialized case of + // overload resolution. + + auto& baseExpr = genericAppExpr->FunctionExpr; + auto& args = genericAppExpr->Arguments; + + // If there was an error in the base expression, or in any of + // the arguments, then just bail. + if (IsErrorExpr(baseExpr)) + { + return CreateErrorExpr(genericAppExpr); + } + for (auto argExpr : args) + { + if (IsErrorExpr(argExpr)) + { + return CreateErrorExpr(genericAppExpr); + } + } + + // Otherwise, let's start looking at how to find an overload... + + OverloadResolveContext context; + context.originalExpr = genericAppExpr; + context.funcLoc = baseExpr->loc; + context.argCount = args.getCount(); + context.args = args.getBuffer(); + context.loc = genericAppExpr->loc; + + context.baseExpr = GetBaseExpr(baseExpr); + + AddGenericOverloadCandidates(baseExpr, context); + + if (context.bestCandidates.getCount() > 0) + { + // Things were ambiguous. + if (context.bestCandidates[0].status != OverloadCandidate::Status::Applicable) + { + // There were multiple equally-good candidates, but none actually usable. + // We will construct a diagnostic message to help out. + + // TODO(tfoley): print a reasonable message here... + + getSink()->diagnose(genericAppExpr, Diagnostics::unimplemented, "no applicable generic"); + + return CreateErrorExpr(genericAppExpr); + } + else + { + // There were multiple viable candidates, but that isn't an error: we just need + // to complete all of them and create an overloaded expression as a result. + + auto overloadedExpr = new OverloadedExpr2(); + overloadedExpr->base = context.baseExpr; + for (auto candidate : context.bestCandidates) + { + auto candidateExpr = CompleteOverloadCandidate(context, candidate); + overloadedExpr->candidiateExprs.add(candidateExpr); + } + return overloadedExpr; + } + } + else if (context.bestCandidate) + { + // There was one best candidate, even if it might not have been + // applicable in the end. + // We will report errors for this one candidate, then, to give + // the user the most help we can. + return CompleteOverloadCandidate(context, *context.bestCandidate); + } + else + { + // Nothing at all was found that we could even consider invoking + getSink()->diagnose(genericAppExpr, Diagnostics::unimplemented, "expected a generic"); + return CreateErrorExpr(genericAppExpr); + } + } + + RefPtr visitSharedTypeExpr(SharedTypeExpr* expr) + { + if (!expr->type.Ptr()) + { + expr->base = CheckProperType(expr->base); + expr->type = expr->base.exp->type; + } + return expr; + } + + RefPtr visitTaggedUnionTypeExpr(TaggedUnionTypeExpr* expr) + { + // We have an expression of the form `__TaggedUnion(A, B, ...)` + // which will evaluate to a tagged-union type over `A`, `B`, etc. + // + RefPtr type = new TaggedUnionType(); + expr->type = QualType(getTypeType(type)); + + for( auto& caseTypeExpr : expr->caseTypes ) + { + caseTypeExpr = CheckProperType(caseTypeExpr); + type->caseTypes.add(caseTypeExpr.type); + } + + return expr; + } + + + + + RefPtr CheckExpr(RefPtr expr) + { + auto term = CheckTerm(expr); + + // TODO(tfoley): Need a step here to ensure that the term actually + // resolves to a (single) expression with a real type. + + return term; + } + + RefPtr CheckInvokeExprWithCheckedOperands(InvokeExpr *expr) + { + auto rs = ResolveInvoke(expr); + if (auto invoke = as(rs.Ptr())) + { + // if this is still an invoke expression, test arguments passed to inout/out parameter are LValues + if(auto funcType = as(invoke->FunctionExpr->type)) + { + Index paramCount = funcType->getParamCount(); + for (Index pp = 0; pp < paramCount; ++pp) + { + auto paramType = funcType->getParamType(pp); + if (as(paramType) || as(paramType)) + { + // `out`, `inout`, and `ref` parameters currently require + // an *exact* match on the type of the argument. + // + // TODO: relax this requirement by allowing an argument + // for an `inout` parameter to be converted in both + // directions. + // + if( pp < expr->Arguments.getCount() ) + { + auto argExpr = expr->Arguments[pp]; + if( !argExpr->type.IsLeftValue ) + { + getSink()->diagnose( + argExpr, + Diagnostics::argumentExpectedLValue, + pp); + + if( auto implicitCastExpr = as(argExpr) ) + { + getSink()->diagnose( + argExpr, + Diagnostics::implicitCastUsedAsLValue, + implicitCastExpr->Arguments[0]->type, + implicitCastExpr->type); + } + + maybeDiagnoseThisNotLValue(argExpr); + } + } + else + { + // There are two ways we could get here, both involving + // a call where the number of argument expressions is + // less than the number of parameters on the callee: + // + // 1. There might be fewer arguments than parameters + // because the trailing parameters should be defaulted + // + // 2. There might be fewer arguments than parameters + // because the call is incorrect. + // + // In case (2) an error would have already been diagnosed, + // and we don't want to emit another cascading error here. + // + // In case (1) this implies the user declared an `out` + // or `inout` parameter with a default argument expression. + // That should be an error, but it should be detected + // on the declaration instead of here at the use site. + // + // Thus, it makes sense to ignore this case here. + } + } + } + } + } + return rs; + } + + RefPtr visitInvokeExpr(InvokeExpr *expr) + { + // check the base expression first + expr->FunctionExpr = CheckExpr(expr->FunctionExpr); + // Next check the argument expressions + for (auto & arg : expr->Arguments) + { + arg = CheckExpr(arg); + } + + return CheckInvokeExprWithCheckedOperands(expr); + } + + + RefPtr visitVarExpr(VarExpr *expr) + { + // If we've already resolved this expression, don't try again. + if (expr->declRef) + return expr; + + expr->type = QualType(getSession()->getErrorType()); + auto lookupResult = lookUp( + getSession(), + this, expr->name, expr->scope); + if (lookupResult.isValid()) + { + return createLookupResultExpr( + lookupResult, + nullptr, + expr->loc); + } + + getSink()->diagnose(expr, Diagnostics::undefinedIdentifier2, expr->name); + + return expr; + } + + RefPtr visitTypeCastExpr(TypeCastExpr * expr) + { + // Check the term we are applying first + auto funcExpr = expr->FunctionExpr; + funcExpr = CheckTerm(funcExpr); + + // Now ensure that the term represnets a (proper) type. + TypeExp typeExp; + typeExp.exp = funcExpr; + typeExp = CheckProperType(typeExp); + + expr->FunctionExpr = typeExp.exp; + expr->type.type = typeExp.type; + + // Next check the argument expression (there should be only one) + for (auto & arg : expr->Arguments) + { + arg = CheckExpr(arg); + } + + // Now process this like any other explicit call (so casts + // and constructor calls are semantically equivalent). + return CheckInvokeExprWithCheckedOperands(expr); + } + + // Get the type to use when referencing a declaration + QualType GetTypeForDeclRef(DeclRef declRef) + { + return getTypeForDeclRef( + getSession(), + this, + getSink(), + declRef, + &typeResult); + } + + // + // Some syntax nodes should not occur in the concrete input syntax, + // and will only appear *after* checking is complete. We need to + // deal with this cases here, even if they are no-ops. + // + + #define CASE(NAME) \ + RefPtr visit##NAME(NAME* expr) \ + { \ + SLANG_DIAGNOSE_UNEXPECTED(getSink(), expr, \ + "should not appear in input syntax"); \ + return expr; \ + } + + CASE(DerefExpr) + CASE(SwizzleExpr) + CASE(OverloadedExpr) + CASE(OverloadedExpr2) + CASE(AggTypeCtorExpr) + CASE(CastToInterfaceExpr) + CASE(LetExpr) + CASE(ExtractExistentialValueExpr) + + #undef CASE + + // + // + // + + RefPtr MaybeDereference(RefPtr inExpr) + { + RefPtr expr = inExpr; + for (;;) + { + auto baseType = expr->type; + if (auto pointerLikeType = as(baseType)) + { + auto elementType = QualType(pointerLikeType->elementType); + elementType.IsLeftValue = baseType.IsLeftValue; + + auto derefExpr = new DerefExpr(); + derefExpr->base = expr; + derefExpr->type = elementType; + + expr = derefExpr; + continue; + } + + // Default case: just use the expression as-is + return expr; + } + } + + RefPtr CheckSwizzleExpr( + MemberExpr* memberRefExpr, + RefPtr baseElementType, + IntegerLiteralValue baseElementCount) + { + RefPtr swizExpr = new SwizzleExpr(); + swizExpr->loc = memberRefExpr->loc; + swizExpr->base = memberRefExpr->BaseExpression; + + IntegerLiteralValue limitElement = baseElementCount; + + int elementIndices[4]; + int elementCount = 0; + + bool elementUsed[4] = { false, false, false, false }; + bool anyDuplicates = false; + bool anyError = false; + + auto swizzleText = getText(memberRefExpr->name); + + for (Index i = 0; i < swizzleText.getLength(); i++) + { + auto ch = swizzleText[i]; + int elementIndex = -1; + switch (ch) + { + case 'x': case 'r': elementIndex = 0; break; + case 'y': case 'g': elementIndex = 1; break; + case 'z': case 'b': elementIndex = 2; break; + case 'w': case 'a': elementIndex = 3; break; + default: + // An invalid character in the swizzle is an error + getSink()->diagnose(swizExpr, Diagnostics::invalidSwizzleExpr, swizzleText, baseElementType->ToString()); + anyError = true; + continue; + } + + // TODO(tfoley): GLSL requires that all component names + // come from the same "family"... + + // Make sure the index is in range for the source type + if (elementIndex >= limitElement) + { + getSink()->diagnose(swizExpr, Diagnostics::invalidSwizzleExpr, swizzleText, baseElementType->ToString()); + anyError = true; + continue; + } + + // Check if we've seen this index before + for (int ee = 0; ee < elementCount; ee++) + { + if (elementIndices[ee] == elementIndex) + anyDuplicates = true; + } + + // add to our list... + elementIndices[elementCount++] = elementIndex; + } + + for (int ee = 0; ee < elementCount; ++ee) + { + swizExpr->elementIndices[ee] = elementIndices[ee]; + } + swizExpr->elementCount = elementCount; + + if (anyError) + { + return CreateErrorExpr(memberRefExpr); + } + else if (elementCount == 1) + { + // single-component swizzle produces a scalar + // + // Note(tfoley): the official HLSL rules seem to be that it produces + // a one-component vector, which is then implicitly convertible to + // a scalar, but that seems like it just adds complexity. + swizExpr->type = QualType(baseElementType); + } + else + { + // TODO(tfoley): would be nice to "re-sugar" type + // here if the input type had a sugared name... + swizExpr->type = QualType(createVectorType( + baseElementType, + new ConstantIntVal(elementCount))); + } + + // A swizzle can be used as an l-value as long as there + // were no duplicates in the list of components + swizExpr->type.IsLeftValue = !anyDuplicates; + + return swizExpr; + } + + RefPtr CheckSwizzleExpr( + MemberExpr* memberRefExpr, + RefPtr baseElementType, + RefPtr baseElementCount) + { + if (auto constantElementCount = as(baseElementCount)) + { + return CheckSwizzleExpr(memberRefExpr, baseElementType, constantElementCount->value); + } + else + { + getSink()->diagnose(memberRefExpr, Diagnostics::unimplemented, "swizzle on vector of unknown size"); + return CreateErrorExpr(memberRefExpr); + } + } + + // Look up a static member + // @param expr Can be StaticMemberExpr or MemberExpr + // @param baseExpression Is the underlying type expression determined from resolving expr + RefPtr _lookupStaticMember(RefPtr expr, RefPtr baseExpression) + { + auto& baseType = baseExpression->type; + + if (auto typeType = as(baseType)) + { + // We are looking up a member inside a type. + // We want to be careful here because we should only find members + // that are implicitly or explicitly `static`. + // + // TODO: this duplicates a *lot* of logic with the case below. + // We need to fix that. + auto type = typeType->type; + + if (as(type)) + { + return CreateErrorExpr(expr); + } + + LookupResult lookupResult = lookUpMember( + getSession(), + this, + expr->name, + type); + if (!lookupResult.isValid()) + { + return lookupMemberResultFailure(expr, baseType); + } + + // We need to confirm that whatever member we + // are trying to refer to is usable via static reference. + // + // TODO: eventually we might allow a non-static + // member to be adapted by turning it into something + // like a closure that takes the missing `this` parameter. + // + // E.g., a static reference to a method could be treated + // as a value with a function type, where the first parameter + // is `type`. + // + // The biggest challenge there is that we'd need to arrange + // to generate "dispatcher" functions that could be used + // to implement that function, in the case where we are + // making a static reference to some kind of polymorphic declaration. + // + // (Also, static references to fields/properties would get even + // harder, because you'd have to know whether a getter/setter/ref-er + // is needed). + // + // For now let's just be expedient and disallow all of that, because + // we can always add it back in later. + + if (!lookupResult.isOverloaded()) + { + // The non-overloaded case is relatively easy. We just want + // to look at the member being referenced, and check if + // it is allowed in a `static` context: + // + if (!isUsableAsStaticMember(lookupResult.item)) + { + getSink()->diagnose( + expr->loc, + Diagnostics::staticRefToNonStaticMember, + type, + expr->name); + } + } + else + { + // The overloaded case is trickier, because we should first + // filter the list of candidates, because if there is anything + // that *is* usable in a static context, then we should assume + // the user just wants to reference that. We should only + // issue an error if *all* of the items that were discovered + // are non-static. + bool anyNonStatic = false; + List staticItems; + for (auto item : lookupResult.items) + { + // Is this item usable as a static member? + if (isUsableAsStaticMember(item)) + { + // If yes, then it will be part of the output. + staticItems.add(item); + } + else + { + // If no, then we might need to output an error. + anyNonStatic = true; + } + } + + // Was there anything non-static in the list? + if (anyNonStatic) + { + // If we had some static items, then that's okay, + // we just want to use our newly-filtered list. + if (staticItems.getCount()) + { + lookupResult.items = staticItems; + } + else + { + // Otherwise, it is time to report an error. + getSink()->diagnose( + expr->loc, + Diagnostics::staticRefToNonStaticMember, + type, + expr->name); + } + } + // If there were no non-static items, then the `items` + // array already represents what we'd get by filtering... + } + + return createLookupResultExpr( + lookupResult, + baseExpression, + expr->loc); + } + else if (as(baseType)) + { + return CreateErrorExpr(expr); + } + + // Failure + return lookupMemberResultFailure(expr, baseType); + } + + RefPtr visitStaticMemberExpr(StaticMemberExpr* expr) + { + expr->BaseExpression = CheckExpr(expr->BaseExpression); + + // Not sure this is needed -> but guess someone could do + expr->BaseExpression = MaybeDereference(expr->BaseExpression); + + // If the base of the member lookup has an interface type + // *without* a suitable this-type substitution, then we are + // trying to perform lookup on a value of existential type, + // and we should "open" the existential here so that we + // can expose its structure. + // + + expr->BaseExpression = maybeOpenExistential(expr->BaseExpression); + // Do a static lookup + return _lookupStaticMember(expr, expr->BaseExpression); + } + + RefPtr lookupMemberResultFailure( + DeclRefExpr* expr, + QualType const& baseType) + { + // Check it's a member expression + SLANG_ASSERT(as(expr) || as(expr)); + + getSink()->diagnose(expr, Diagnostics::noMemberOfNameInType, expr->name, baseType); + expr->type = QualType(getSession()->getErrorType()); + return expr; + } + + RefPtr visitMemberExpr(MemberExpr * expr) + { + expr->BaseExpression = CheckExpr(expr->BaseExpression); + + expr->BaseExpression = MaybeDereference(expr->BaseExpression); + + // If the base of the member lookup has an interface type + // *without* a suitable this-type substitution, then we are + // trying to perform lookup on a value of existential type, + // and we should "open" the existential here so that we + // can expose its structure. + // + expr->BaseExpression = maybeOpenExistential(expr->BaseExpression); + + auto & baseType = expr->BaseExpression->type; + + // Note: Checking for vector types before declaration-reference types, + // because vectors are also declaration reference types... + // + // Also note: the way this is done right now means that the ability + // to swizzle vectors interferes with any chance of looking up + // members via extension, for vector or scalar types. + // + // TODO: Matrix swizzles probably need to be handled at some point. + if (auto baseVecType = as(baseType)) + { + return CheckSwizzleExpr( + expr, + baseVecType->elementType, + baseVecType->elementCount); + } + else if(auto baseScalarType = as(baseType)) + { + // Treat scalar like a 1-element vector when swizzling + return CheckSwizzleExpr( + expr, + baseScalarType, + 1); + } + else if(auto typeType = as(baseType)) + { + return _lookupStaticMember(expr, expr->BaseExpression); + } + else if (as(baseType)) + { + return CreateErrorExpr(expr); + } + else + { + LookupResult lookupResult = lookUpMember( + getSession(), + this, + expr->name, + baseType.Ptr()); + if (!lookupResult.isValid()) + { + return lookupMemberResultFailure(expr, baseType); + } + + // TODO: need to filter for declarations that are valid to refer + // to in this context... + + return createLookupResultExpr( + lookupResult, + expr->BaseExpression, + expr->loc); + } + } + SemanticsVisitor & operator = (const SemanticsVisitor &) = delete; + + + // + + RefPtr visitInitializerListExpr(InitializerListExpr* expr) + { + // When faced with an initializer list, we first just check the sub-expressions blindly. + // Actually making them conform to a desired type will wait for when we know the desired + // type based on context. + + for( auto& arg : expr->args ) + { + arg = CheckTerm(arg); + } + + expr->type = getSession()->getInitializerListType(); + + return expr; + } + + void importModuleIntoScope(Scope* scope, ModuleDecl* moduleDecl) + { + // If we've imported this one already, then + // skip the step where we modify the current scope. + if (importedModules.Contains(moduleDecl)) + { + return; + } + importedModules.Add(moduleDecl); + + + // Create a new sub-scope to wire the module + // into our lookup chain. + auto subScope = new Scope(); + subScope->containerDecl = moduleDecl; + + subScope->nextSibling = scope->nextSibling; + scope->nextSibling = subScope; + + // Also import any modules from nested `import` declarations + // with the `__exported` modifier + for (auto importDecl : moduleDecl->getMembersOfType()) + { + if (!importDecl->HasModifier()) + continue; + + importModuleIntoScope(scope, importDecl->importedModuleDecl.Ptr()); + } + } + + void visitEmptyDecl(EmptyDecl* /*decl*/) + { + // nothing to do + } + + void visitImportDecl(ImportDecl* decl) + { + if(decl->IsChecked(DeclCheckState::CheckedHeader)) + return; + + // We need to look for a module with the specified name + // (whether it has already been loaded, or needs to + // be loaded), and then put its declarations into + // the current scope. + + auto name = decl->moduleNameAndLoc.name; + auto scope = decl->scope; + + // Try to load a module matching the name + auto importedModule = findOrImportModule( + getLinkage(), + name, + decl->moduleNameAndLoc.loc, + getSink()); + + // If we didn't find a matching module, then bail out + if (!importedModule) + return; + + // Record the module that was imported, so that we can use + // it later during code generation. + auto importedModuleDecl = importedModule->getModuleDecl(); + decl->importedModuleDecl = importedModuleDecl; + + // Add the declarations from the imported module into the scope + // that the `import` declaration is set to extend. + // + importModuleIntoScope(scope.Ptr(), importedModuleDecl); + + // Record the `import`ed module (and everything it depends on) + // as a dependency of the module we are compiling. + if(auto module = getModule(decl)) + { + module->addModuleDependency(importedModule); + } + + decl->SetCheckState(getCheckedState()); + } + + // Perform semantic checking of an object-oriented `this` + // expression. + RefPtr visitThisExpr(ThisExpr* expr) + { + // A `this` expression will default to immutable. + expr->type.IsLeftValue = false; + + // We will do an upwards search starting in the current + // scope, looking for a surrounding type (or `extension`) + // declaration that could be the referrant of the expression. + auto scope = expr->scope; + while (scope) + { + auto containerDecl = scope->containerDecl; + + if( auto funcDeclBase = as(containerDecl) ) + { + if( funcDeclBase->HasModifier() ) + { + expr->type.IsLeftValue = true; + } + } + else if (auto aggTypeDecl = as(containerDecl)) + { + checkDecl(aggTypeDecl); + + // Okay, we are using `this` in the context of an + // aggregate type, so the expression should be + // of the corresponding type. + expr->type.type = DeclRefType::Create( + getSession(), + makeDeclRef(aggTypeDecl)); + return expr; + } + else if (auto extensionDecl = as(containerDecl)) + { + checkDecl(extensionDecl); + + // When `this` is used in the context of an `extension` + // declaration, then it should refer to an instance of + // the type being extended. + // + // TODO: There is potentially a small gotcha here that + // lookup through such a `this` expression should probably + // prioritize members declared in the current extension + // if there are multiple extensions in scope that add + // members with the same name... + // + expr->type.type = extensionDecl->targetType.type; + return expr; + } + + scope = scope->parent; + } + + getSink()->diagnose(expr, Diagnostics::thisExpressionOutsideOfTypeDecl); + return CreateErrorExpr(expr); + } + }; + + bool isPrimaryDecl( + CallableDecl* decl) + { + SLANG_ASSERT(decl); + return (!decl->primaryDecl) || (decl == decl->primaryDecl); + } + + RefPtr checkProperType( + Linkage* linkage, + TypeExp typeExp, + DiagnosticSink* sink) + { + SemanticsVisitor visitor( + linkage, + sink); + auto typeOut = visitor.CheckProperType(typeExp); + return typeOut.type; + } + + + FuncDecl* findFunctionDeclByName( + Module* translationUnit, + Name* name, + DiagnosticSink* sink) + { + auto translationUnitSyntax = translationUnit->getModuleDecl(); + + // Make sure we've got a query-able member dictionary + buildMemberDictionary(translationUnitSyntax); + + // We will look up any global-scope declarations in the translation + // unit that match the name of our entry point. + Decl* firstDeclWithName = nullptr; + if (!translationUnitSyntax->memberDictionary.TryGetValue(name, firstDeclWithName)) + { + // If there doesn't appear to be any such declaration, then we are done. + + sink->diagnose(translationUnitSyntax, Diagnostics::entryPointFunctionNotFound, name); + + return nullptr; + } + + // We found at least one global-scope declaration with the right name, + // but (1) it might not be a function, and (2) there might be + // more than one function. + // + // We'll walk the linked list of declarations with the same name, + // to see what we find. Along the way we'll keep track of the + // first function declaration we find, if any: + FuncDecl* entryPointFuncDecl = nullptr; + for (auto ee = firstDeclWithName; ee; ee = ee->nextInContainerWithSameName) + { + // Is this declaration a function? + if (auto funcDecl = as(ee)) + { + // Skip non-primary declarations, so that + // we don't give an error when an entry + // point is forward-declared. + if (!isPrimaryDecl(funcDecl)) + continue; + + // is this the first one we've seen? + if (!entryPointFuncDecl) + { + // If so, this is a candidate to be + // the entry point function. + entryPointFuncDecl = funcDecl; + } + else + { + // Uh-oh! We've already seen a function declaration with this + // name before, so the whole thing is ambiguous. We need + // to diagnose and bail out. + + sink->diagnose(translationUnitSyntax, Diagnostics::ambiguousEntryPoint, name); + + // List all of the declarations that the user *might* mean + for (auto ff = firstDeclWithName; ff; ff = ff->nextInContainerWithSameName) + { + if (auto candidate = as(ff)) + { + sink->diagnose(candidate, Diagnostics::entryPointCandidate, candidate->getName()); + } + } + + // Bail out. + return nullptr; + } + } + } + + return entryPointFuncDecl; + } + + static bool isValidThreadDispatchIDType(Type* type) + { + // Can accept a single int/unit + { + auto basicType = as(type); + if (basicType) + { + return (basicType->baseType == BaseType::Int || basicType->baseType == BaseType::UInt); + } + } + // Can be an int/uint vector from size 1 to 3 + { + auto vectorType = as(type); + if (!vectorType) + { + return false; + } + auto elemCount = as(vectorType->elementCount); + if (elemCount->value < 1 || elemCount->value > 3) + { + return false; + } + // Must be a basic type + auto basicType = as(vectorType->elementType); + if (!basicType) + { + return false; + } + + // Must be integral + return (basicType->baseType == BaseType::Int || basicType->baseType == BaseType::UInt); + } + } + + /// Recursively walk `paramDeclRef` and add any required existential slots to `ioSlots`. + static void _collectExistentialTypeParamsRec( + ExistentialTypeSlots& ioSlots, + DeclRef paramDeclRef); + + /// Recursively walk `type` and discover any required existential type parameters. + static void _collectExistentialTypeParamsRec( + ExistentialTypeSlots& ioSlots, + Type* type) + { + // Whether or not something is an array does not affect + // the number of existential slots it introduces. + // + while( auto arrayType = as(type) ) + { + type = arrayType->baseType; + } + + if( auto parameterGroupType = as(type) ) + { + _collectExistentialTypeParamsRec(ioSlots, parameterGroupType->getElementType()); + return; + } + + if( auto declRefType = as(type) ) + { + auto typeDeclRef = declRefType->declRef; + if( auto interfaceDeclRef = typeDeclRef.as() ) + { + // Each leaf parameter of interface type adds one slot. + // + ioSlots.paramTypes.add(type); + } + else if( auto structDeclRef = typeDeclRef.as() ) + { + // A structure type should recursively introduce + // existential slots for its fields. + // + for( auto fieldDeclRef : GetFields(structDeclRef) ) + { + if(fieldDeclRef.getDecl()->HasModifier()) + continue; + + _collectExistentialTypeParamsRec(ioSlots, fieldDeclRef); + } + } + } + + // TODO: We eventually need to handle cases like constant + // buffers and parameter blocks that may have existential + // element types. + } + + static void _collectExistentialTypeParamsRec( + ExistentialTypeSlots& ioSlots, + DeclRef paramDeclRef) + { + _collectExistentialTypeParamsRec(ioSlots, GetType(paramDeclRef)); + } + + + /// Add information about a shader parameter to `ioParams` and `ioSlots` + static void _collectExistentialSlotsForShaderParam( + ShaderParamInfo& ioParamInfo, + ExistentialTypeSlots& ioSlots, + DeclRef paramDeclRef) + { + Index startSlot = ioSlots.paramTypes.getCount(); + _collectExistentialTypeParamsRec(ioSlots, paramDeclRef); + Index endSlot = ioSlots.paramTypes.getCount(); + + ioParamInfo.firstExistentialTypeSlot = UInt(startSlot); + ioParamInfo.existentialTypeSlotCount = UInt(endSlot - startSlot);; + } + + /// Enumerate the existential-type parameters of an `EntryPoint`. + /// + /// Any parameters found will be added to the list of existential slots on `this`. + /// + void EntryPoint::_collectShaderParams() + { + // Note: we defensively test whether there is a function decl-ref + // because this routine gets called from the constructor, and + // a "dummy" entry point will have a null pointer for the function. + // + if( auto funcDeclRef = getFuncDeclRef() ) + { + for( auto paramDeclRef : GetParameters(funcDeclRef) ) + { + ShaderParamInfo shaderParamInfo; + shaderParamInfo.paramDeclRef = paramDeclRef; + + _collectExistentialSlotsForShaderParam( + shaderParamInfo, + m_existentialSlots, + paramDeclRef); + + m_shaderParams.add(shaderParamInfo); + } + } + } + + // Validate that an entry point function conforms to any additional + // constraints based on the stage (and profile?) it specifies. + void validateEntryPoint( + EntryPoint* entryPoint, + DiagnosticSink* sink) + { + auto entryPointFuncDecl = entryPoint->getFuncDecl(); + auto stage = entryPoint->getStage(); + + // TODO: We currently do minimal checking here, but this is the + // right place to perform the following validation checks: + // + + // * Are the function input/output parameters and result type + // all valid for the chosen stage? (e.g., there shouldn't be + // an `OutputStream` type in a vertex shader signature) + // + // * For any varying input/output, are there semantics specified + // (Note: this potentially overlaps with layout logic...), and + // are the system-value semantics valid for the given stage? + // + // There's actually a lot of detail to semantic checking, in + // that the AST-level code should probably be validating the + // use of system-value semantics by linking them to explicit + // declarations in the standard library. We should also be + // using profile information on those declarations to infer + // appropriate profile restrictions on the entry point. + // + // * Is the entry point actually usable on the given stage/profile? + // E.g., if we have a vertex shader that (transitively) calls + // `Texture2D.Sample`, then that should produce an error because + // that function is specific to the fragment profile/stage. + // + + auto entryPointName = entryPointFuncDecl->getName(); + + auto module = getModule(entryPointFuncDecl); + auto linkage = module->getLinkage(); + + + // Every entry point needs to have a stage specified either via + // command-line/API options, or via an explicit `[shader("...")]` attribute. + // + if( stage == Stage::Unknown ) + { + sink->diagnose(entryPointFuncDecl, Diagnostics::entryPointHasNoStage, entryPointName); + } + + if( stage == Stage::Hull ) + { + // TODO: We could consider *always* checking any `[patchconsantfunc("...")]` + // attributes, so that they need to resolve to a function. + + auto attr = entryPointFuncDecl->FindModifier(); + + if (attr) + { + if (attr->args.getCount() != 1) + { + sink->diagnose(attr, Diagnostics::badlyDefinedPatchConstantFunc, entryPointName); + return; + } + + Expr* expr = attr->args[0]; + StringLiteralExpr* stringLit = as(expr); + + if (!stringLit) + { + sink->diagnose(expr, Diagnostics::badlyDefinedPatchConstantFunc, entryPointName); + return; + } + + // We look up the patch-constant function by its name in the module + // scope of the translation unit that declared the HS entry point. + // + // TODO: Eventually we probably want to do the lookup in the scope + // of the parent declarations of the entry point. E.g., if the entry + // point is a member function of a `struct`, then its patch-constant + // function should be allowed to be another member function of + // the same `struct`. + // + // In the extremely long run we may want to support an alternative to + // this attribute-based linkage between the two functions that + // make up the entry point. + // + Name* name = linkage->getNamePool()->getName(stringLit->value); + FuncDecl* patchConstantFuncDecl = findFunctionDeclByName( + module, + name, + sink); + if (!patchConstantFuncDecl) + { + sink->diagnose(expr, Diagnostics::attributeFunctionNotFound, name, "patchconstantfunc"); + return; + } + + attr->patchConstantFuncDecl = patchConstantFuncDecl; + } + } + else if(stage == Stage::Compute) + { + for(const auto& param : entryPointFuncDecl->GetParameters()) + { + if(auto semantic = param->FindModifier()) + { + const auto& semanticToken = semantic->name; + + String lowerName = String(semanticToken.Content).toLower(); + + if(lowerName == "sv_dispatchthreadid") + { + Type* paramType = param->getType(); + + if(!isValidThreadDispatchIDType(paramType)) + { + String typeString = paramType->ToString(); + sink->diagnose(param->loc, Diagnostics::invalidDispatchThreadIDType, typeString); + return; + } + } + } + } + } + } + + // Given an entry point specified via API or command line options, + // attempt to find a matching AST declaration that implements the specified + // entry point. If such a function is found, then validate that it actually + // meets the requirements for the selected stage/profile. + // + // Returns an `EntryPoint` object representing the (unspecialized) + // entry point if it is found and validated, and null otherwise. + // + RefPtr findAndValidateEntryPoint( + FrontEndEntryPointRequest* entryPointReq) + { + // The first step in validating the entry point is to find + // the (unique) function declaration that matches its name. + // + // TODO: We may eventually want/need to extend this to + // account for nested names like `SomeStruct.vsMain`, or + // indeed even to handle generics. + // + auto compileRequest = entryPointReq->getCompileRequest(); + auto translationUnit = entryPointReq->getTranslationUnit(); + auto sink = compileRequest->getSink(); + auto translationUnitSyntax = translationUnit->getModuleDecl(); + + auto entryPointName = entryPointReq->getName(); + + // Make sure we've got a query-able member dictionary + buildMemberDictionary(translationUnitSyntax); + + // We will look up any global-scope declarations in the translation + // unit that match the name of our entry point. + Decl* firstDeclWithName = nullptr; + if( !translationUnitSyntax->memberDictionary.TryGetValue(entryPointName, firstDeclWithName) ) + { + // If there doesn't appear to be any such declaration, then + // we need to diagnose it as an error, and then bail out. + sink->diagnose(translationUnitSyntax, Diagnostics::entryPointFunctionNotFound, entryPointName); + return nullptr; + } + + // We found at least one global-scope declaration with the right name, + // but (1) it might not be a function, and (2) there might be + // more than one function. + // + // We'll walk the linked list of declarations with the same name, + // to see what we find. Along the way we'll keep track of the + // first function declaration we find, if any: + // + FuncDecl* entryPointFuncDecl = nullptr; + for(auto ee = firstDeclWithName; ee; ee = ee->nextInContainerWithSameName) + { + // We want to support the case where the declaration is + // a generic function, so we will automatically + // unwrap any outer `GenericDecl` we find here. + // + auto decl = ee; + if(auto genericDecl = as(decl)) + decl = genericDecl->inner; + + // Is this declaration a function? + if (auto funcDecl = as(decl)) + { + // Skip non-primary declarations, so that + // we don't give an error when an entry + // point is forward-declared. + if (!isPrimaryDecl(funcDecl)) + continue; + + // is this the first one we've seen? + if (!entryPointFuncDecl) + { + // If so, this is a candidate to be + // the entry point function. + entryPointFuncDecl = funcDecl; + } + else + { + // Uh-oh! We've already seen a function declaration with this + // name before, so the whole thing is ambiguous. We need + // to diagnose and bail out. + + sink->diagnose(translationUnitSyntax, Diagnostics::ambiguousEntryPoint, entryPointName); + + // List all of the declarations that the user *might* mean + for (auto ff = firstDeclWithName; ff; ff = ff->nextInContainerWithSameName) + { + if (auto candidate = as(ff)) + { + sink->diagnose(candidate, Diagnostics::entryPointCandidate, candidate->getName()); + } + } + + // Bail out. + return nullptr; + } + } + } + + // Did we find a function declaration in our search? + if(!entryPointFuncDecl) + { + // If not, then we need to diagnose the error. + // For convenience, we will point to the first + // declaration with the right name, that wasn't a function. + sink->diagnose(firstDeclWithName, Diagnostics::entryPointSymbolNotAFunction, entryPointName); + return nullptr; + } + + // TODO: it is possible that the entry point was declared with + // profile or target overloading. Is there anything that we need + // to do at this point to filter out declarations that aren't + // relevant to the selected profile for the entry point? + + // We found something, and can start doing some basic checking. + // + // If the entry point specifies a stage via a `[shader("...")]` attribute, + // then we might be able to infer a stage for the entry point request if + // it didn't have one, *or* issue a diagnostic if there is a mismatch. + // + auto entryPointProfile = entryPointReq->getProfile(); + if( auto entryPointAttribute = entryPointFuncDecl->FindModifier() ) + { + auto entryPointStage = entryPointProfile.GetStage(); + if( entryPointStage == Stage::Unknown ) + { + entryPointProfile.setStage(entryPointAttribute->stage); + } + else if( entryPointAttribute->stage != entryPointStage ) + { + sink->diagnose(entryPointFuncDecl, Diagnostics::specifiedStageDoesntMatchAttribute, entryPointName, entryPointStage, entryPointAttribute->stage); + } + } + else + { + // TODO: Should we attach a `[shader(...)]` attribute to an + // entry point that didn't have one, so that we can have + // a more uniform representation in the AST? + } + + + RefPtr entryPoint = EntryPoint::create( + makeDeclRef(entryPointFuncDecl), + entryPointProfile); + + // Now that we've *found* the entry point, it is time to validate + // that it actually meets the constraints for the chosen stage/profile. + // + validateEntryPoint(entryPoint, sink); + + return entryPoint; + } + + /// Get the name a variable will use for reflection purposes +Name* getReflectionName(VarDeclBase* varDecl) +{ + if (auto reflectionNameModifier = varDecl->FindModifier()) + return reflectionNameModifier->nameAndLoc.name; + + return varDecl->getName(); +} + +// Information tracked when doing a structural +// match of types. +struct StructuralTypeMatchStack +{ + DeclRef leftDecl; + DeclRef rightDecl; + StructuralTypeMatchStack* parent; +}; + +static void diagnoseParameterTypeMismatch( + DiagnosticSink* sink, + StructuralTypeMatchStack* inStack) +{ + SLANG_ASSERT(inStack); + + // The bottom-most entry in the stack should represent + // the shader parameters that kicked things off + auto stack = inStack; + while(stack->parent) + stack = stack->parent; + + sink->diagnose(stack->leftDecl, Diagnostics::shaderParameterDeclarationsDontMatch, getReflectionName(stack->leftDecl)); + sink->diagnose(stack->rightDecl, Diagnostics::seeOtherDeclarationOf, getReflectionName(stack->rightDecl)); +} + +// Two types that were expected to match did not. +// Inform the user with a suitable message. +static void diagnoseTypeMismatch( + DiagnosticSink* sink, + StructuralTypeMatchStack* inStack) +{ + auto stack = inStack; + SLANG_ASSERT(stack); + diagnoseParameterTypeMismatch(sink, stack); + + auto leftType = GetType(stack->leftDecl); + auto rightType = GetType(stack->rightDecl); + + if( stack->parent ) + { + sink->diagnose(stack->leftDecl, Diagnostics::fieldTypeMisMatch, getReflectionName(stack->leftDecl), leftType, rightType); + sink->diagnose(stack->rightDecl, Diagnostics::seeOtherDeclarationOf, getReflectionName(stack->rightDecl)); + + stack = stack->parent; + if( stack ) + { + while( stack->parent ) + { + sink->diagnose(stack->leftDecl, Diagnostics::usedInDeclarationOf, getReflectionName(stack->leftDecl)); + stack = stack->parent; + } + } + } + else + { + sink->diagnose(stack->leftDecl, Diagnostics::shaderParameterTypeMismatch, leftType, rightType); + } +} + +// Two types that were expected to match did not. +// Inform the user with a suitable message. +static void diagnoseTypeFieldsMismatch( + DiagnosticSink* sink, + DeclRef const& left, + DeclRef const& right, + StructuralTypeMatchStack* stack) +{ + diagnoseParameterTypeMismatch(sink, stack); + + sink->diagnose(left, Diagnostics::fieldDeclarationsDontMatch, left.GetName()); + sink->diagnose(right, Diagnostics::seeOtherDeclarationOf, right.GetName()); + + if( stack ) + { + while( stack->parent ) + { + sink->diagnose(stack->leftDecl, Diagnostics::usedInDeclarationOf, getReflectionName(stack->leftDecl)); + stack = stack->parent; + } + } +} + +static void collectFields( + DeclRef declRef, + List>& outFields) +{ + for( auto fieldDeclRef : getMembersOfType(declRef) ) + { + if(fieldDeclRef.getDecl()->HasModifier()) + continue; + + outFields.add(fieldDeclRef); + } +} + +static bool validateTypesMatch( + DiagnosticSink* sink, + Type* left, + Type* right, + StructuralTypeMatchStack* stack); + +static bool validateIntValuesMatch( + DiagnosticSink* sink, + IntVal* left, + IntVal* right, + StructuralTypeMatchStack* stack) +{ + if(left->EqualsVal(right)) + return true; + + // TODO: are there other cases we need to handle here? + + diagnoseTypeMismatch(sink, stack); + return false; +} + + +static bool validateValuesMatch( + DiagnosticSink* sink, + Val* left, + Val* right, + StructuralTypeMatchStack* stack) +{ + if( auto leftType = dynamicCast(left) ) + { + if( auto rightType = dynamicCast(right) ) + { + return validateTypesMatch(sink, leftType, rightType, stack); + } + } + + if( auto leftInt = dynamicCast(left) ) + { + if( auto rightInt = dynamicCast(right) ) + { + return validateIntValuesMatch(sink, leftInt, rightInt, stack); + } + } + + if( auto leftWitness = dynamicCast(left) ) + { + if( auto rightWitness = dynamicCast(right) ) + { + return true; + } + } + + diagnoseTypeMismatch(sink, stack); + return false; +} + +static bool validateGenericSubstitutionsMatch( + DiagnosticSink* sink, + GenericSubstitution* left, + GenericSubstitution* right, + StructuralTypeMatchStack* stack) +{ + if( !left ) + { + if( !right ) + { + return true; + } + + diagnoseTypeMismatch(sink, stack); + return false; + } + + + + Index argCount = left->args.getCount(); + if( argCount != right->args.getCount() ) + { + diagnoseTypeMismatch(sink, stack); + return false; + } + + for( Index aa = 0; aa < argCount; ++aa ) + { + auto leftArg = left->args[aa]; + auto rightArg = right->args[aa]; + + if(!validateValuesMatch(sink, leftArg, rightArg, stack)) + return false; + } + + return true; +} + +static bool validateThisTypeSubstitutionsMatch( + DiagnosticSink* /*sink*/, + ThisTypeSubstitution* /*left*/, + ThisTypeSubstitution* /*right*/, + StructuralTypeMatchStack* /*stack*/) +{ + // TODO: actual checking. + return true; +} + +static bool validateSpecializationsMatch( + DiagnosticSink* sink, + SubstitutionSet left, + SubstitutionSet right, + StructuralTypeMatchStack* stack) +{ + auto ll = left.substitutions; + auto rr = right.substitutions; + for(;;) + { + // Skip any global generic substitutions. + if(auto leftGlobalGeneric = as(ll)) + { + ll = leftGlobalGeneric->outer; + continue; + } + if(auto rightGlobalGeneric = as(rr)) + { + rr = rightGlobalGeneric->outer; + continue; + } + + // If either ran out, then we expect both to have run out. + if(!ll || !rr) + return !ll && !rr; + + auto leftSubst = ll; + auto rightSubst = rr; + + ll = ll->outer; + rr = rr->outer; + + if(auto leftGeneric = as(leftSubst)) + { + if(auto rightGeneric = as(rightSubst)) + { + if(validateGenericSubstitutionsMatch(sink, leftGeneric, rightGeneric, stack)) + { + continue; + } + } + } + else if(auto leftThisType = as(leftSubst)) + { + if(auto rightThisType = as(rightSubst)) + { + if(validateThisTypeSubstitutionsMatch(sink, leftThisType, rightThisType, stack)) + { + continue; + } + } + } + + return false; + } + + return true; +} + +// Determine if two types "match" for the purposes of `cbuffer` layout rules. +// +static bool validateTypesMatch( + DiagnosticSink* sink, + Type* left, + Type* right, + StructuralTypeMatchStack* stack) +{ + if(left->Equals(right)) + return true; + + // It is possible that the types don't match exactly, but + // they *do* match structurally. + + // Note: the following code will lead to infinite recursion if there + // are ever recursive types. We'd need a more refined system to + // cache the matches we've already found. + + if( auto leftDeclRefType = as(left) ) + { + if( auto rightDeclRefType = as(right) ) + { + // Are they references to matching decl refs? + auto leftDeclRef = leftDeclRefType->declRef; + auto rightDeclRef = rightDeclRefType->declRef; + + // Do the reference the same declaration? Or declarations + // with the same name? + // + // TODO: we should only consider the same-name case if the + // declarations come from translation units being compiled + // (and not an imported module). + if( leftDeclRef.getDecl() == rightDeclRef.getDecl() + || leftDeclRef.GetName() == rightDeclRef.GetName() ) + { + // Check that any generic arguments match + if( !validateSpecializationsMatch( + sink, + leftDeclRef.substitutions, + rightDeclRef.substitutions, + stack) ) + { + return false; + } + + // Check that any declared fields match too. + if( auto leftStructDeclRef = leftDeclRef.as() ) + { + if( auto rightStructDeclRef = rightDeclRef.as() ) + { + List> leftFields; + List> rightFields; + + collectFields(leftStructDeclRef, leftFields); + collectFields(rightStructDeclRef, rightFields); + + Index leftFieldCount = leftFields.getCount(); + Index rightFieldCount = rightFields.getCount(); + + if( leftFieldCount != rightFieldCount ) + { + diagnoseTypeFieldsMismatch(sink, leftDeclRef, rightDeclRef, stack); + return false; + } + + for( Index ii = 0; ii < leftFieldCount; ++ii ) + { + auto leftField = leftFields[ii]; + auto rightField = rightFields[ii]; + + if( leftField.GetName() != rightField.GetName() ) + { + diagnoseTypeFieldsMismatch(sink, leftDeclRef, rightDeclRef, stack); + return false; + } + + auto leftFieldType = GetType(leftField); + auto rightFieldType = GetType(rightField); + + StructuralTypeMatchStack subStack; + subStack.parent = stack; + subStack.leftDecl = leftField; + subStack.rightDecl = rightField; + + if(!validateTypesMatch(sink, leftFieldType,rightFieldType, &subStack)) + return false; + } + } + } + + // Everything seemed to match recursively. + return true; + } + } + } + + // If we are looking at `T[N]` and `U[M]` we want to check that + // `T` is structurally equivalent to `U` and `N` is the same as `M`. + else if( auto leftArrayType = as(left) ) + { + if( auto rightArrayType = as(right) ) + { + if(!validateTypesMatch(sink, leftArrayType->baseType, rightArrayType->baseType, stack) ) + return false; + + if(!validateValuesMatch(sink, leftArrayType->ArrayLength, rightArrayType->ArrayLength, stack)) + return false; + + return true; + } + } + + diagnoseTypeMismatch(sink, stack); + return false; +} + +// This function is supposed to determine if two global shader +// parameter declarations represent the same logical parameter +// (so that they should get the exact same binding(s) allocated). +// +static bool doesParameterMatch( + DiagnosticSink* sink, + DeclRef varDeclRef, + DeclRef existingVarDeclRef) +{ + StructuralTypeMatchStack stack; + stack.parent = nullptr; + stack.leftDecl = varDeclRef; + stack.rightDecl = existingVarDeclRef; + + validateTypesMatch(sink, GetType(varDeclRef), GetType(existingVarDeclRef), &stack); + + return true; +} + + + + + /// Enumerate the existential-type parameters of a `Program`. + /// + /// Any parameters found will be added to the list of existential slots on `this`. + /// + void Program::_collectShaderParams(DiagnosticSink* sink) + { + // We need to collect all of the global shader parameters + // referenced by the compile request, and for each we + // need to do a few things: + // + // * We need to determine if the parameter is a duplicate/redeclaration + // of the "same" parameter in another translation unit, and collapse + // those into one logical shader parameter if so. + // + // * We need to determine what existential type slots are introduced + // by the parameter, and associate that information with the parameter. + // + // To deal with the first issue, we will maintain a map from a parameter + // name to the index of an existing parameter with that name. + // + Dictionary mapNameToParamIndex; + + for( auto module : getModuleDependencies() ) + { + auto moduleDecl = module->getModuleDecl(); + for( auto globalVar : moduleDecl->getMembersOfType() ) + { + if(!isGlobalShaderParameter(globalVar)) + continue; + + // This declaration may represent the same logical parameter + // as a declaration that came from a different translation unit. + // If that is the case, we want to re-use the same `ShaderParamInfo` + // across both parameters. + // + // TODO: This logic currently detects *any* global-scope parameters + // with matching names, but it should eventually be narrowly + // scoped so that it only applies to parameters from unnamed modules + // (that is, modules that represent directly-compiled shader files + // and not `import`ed code). + // + // First we look for an existing entry matching the name + // of this parameter: + // + auto paramName = getReflectionName(globalVar); + Int existingParamIndex = -1; + if( mapNameToParamIndex.TryGetValue(paramName, existingParamIndex) ) + { + // If the parameters have the same name, but don't "match" according to some reasonable rules, + // then we will treat them as distinct global parameters. + // + // Note: all of the mismatch cases currently report errors, so that + // compilation will fail on a mismatch. + // + auto& existingParam = m_shaderParams[existingParamIndex]; + if( doesParameterMatch(sink, makeDeclRef(globalVar.Ptr()), existingParam.paramDeclRef) ) + { + // If we hit this case, then we had a match, and we should + // consider the new variable to be a redclaration of + // the existing one. + + existingParam.additionalParamDeclRefs.add( + makeDeclRef(globalVar.Ptr())); + continue; + } + } + + Int newParamIndex = Int(m_shaderParams.getCount()); + mapNameToParamIndex.Add(paramName, newParamIndex); + + GlobalShaderParamInfo shaderParamInfo; + shaderParamInfo.paramDeclRef = makeDeclRef(globalVar.Ptr()); + + _collectExistentialSlotsForShaderParam( + shaderParamInfo, + m_globalExistentialSlots, + makeDeclRef(globalVar.Ptr())); + + m_shaderParams.add(shaderParamInfo); + } + } + } + + /// Create a `Program` to represent the compiled code. + /// + /// The created program will comprise all of the translation + /// units that were compiled as part of the request, as + /// well as any entry points in those translation units. + /// + RefPtr createUnspecializedProgram( + FrontEndCompileRequest* compileRequest) + { + // We want our resulting program to depend on + // all the translation units the user specified, + // even if some of them don't contain entry points + // (this is important for parameter layout/binding). + // + // We also want to ensure that the modules for the + // translation units comes first in the enumerated + // order for dependencies, to match the pre-existing + // compiler behavior (at least for now). + // + auto linkage = compileRequest->getLinkage(); + auto sink = compileRequest->getSink(); + auto program = new Program(linkage); + for(auto translationUnit : compileRequest->translationUnits ) + { + program->addReferencedLeafModule(translationUnit->getModule()); + } + for(auto translationUnit : compileRequest->translationUnits ) + { + program->addReferencedModule(translationUnit->getModule()); + } + + + // The validation of entry points here will be modal, and controlled + // by whether the user specified any entry points directly via + // API or command-line options. + // + // TODO: We may want to make this choice explicit rather than implicit. + // + // First, check if the user requested any entry points explicitly via + // the API or command line. + // + bool anyExplicitEntryPoints = compileRequest->getEntryPointReqCount() != 0; + + if( anyExplicitEntryPoints ) + { + // If there were any explicit requests for entry points to be + // checked, then we will *only* check those. + // + for(auto entryPointReq : compileRequest->getEntryPointReqs()) + { + auto entryPoint = findAndValidateEntryPoint( + entryPointReq); + if( entryPoint ) + { + program->addEntryPoint(entryPoint); + entryPointReq->getTranslationUnit()->entryPoints.add(entryPoint); + } + } + + // TODO: We should consider always processing both categories, + // and just making sure to only check each entry point function + // declaration once... + } + else + { + // Otherwise, scan for any `[shader(...)]` attributes in + // the user's code, and construct `EntryPoint`s to + // represent them. + // + // This ensures that downstream code only has to consider + // the central list of entry point requests, and doesn't + // have to know where they came from. + + // TODO: A comprehensive approach here would need to search + // recursively for entry points, because they might appear + // as, e.g., member function of a `struct` type. + // + // For now we'll start with an extremely basic approach that + // should work for typical HLSL code. + // + Index translationUnitCount = compileRequest->translationUnits.getCount(); + for(Index tt = 0; tt < translationUnitCount; ++tt) + { + auto translationUnit = compileRequest->translationUnits[tt]; + for( auto globalDecl : translationUnit->getModuleDecl()->Members ) + { + auto maybeFuncDecl = globalDecl; + if( auto genericDecl = as(maybeFuncDecl) ) + { + maybeFuncDecl = genericDecl->inner; + } + + auto funcDecl = as(maybeFuncDecl); + if(!funcDecl) + continue; + + auto entryPointAttr = funcDecl->FindModifier(); + if(!entryPointAttr) + continue; + + // We've discovered a valid entry point. It is a function (possibly + // generic) that has a `[shader(...)]` attribute to mark it as an + // entry point. + // + // We will now register that entry point as an `EntryPoint` + // with an appropriately chosen profile. + // + // The profile will only include a stage, so that the profile "family" + // and "version" are left unspecified. Downstream code will need + // to be able to handle this case. + // + Profile profile; + profile.setStage(entryPointAttr->stage); + + RefPtr entryPoint = EntryPoint::create( + makeDeclRef(funcDecl), + profile); + + validateEntryPoint(entryPoint, sink); + + program->addEntryPoint(entryPoint); + translationUnit->entryPoints.add(entryPoint); + } + } + } + + program->_collectShaderParams(sink); + + return program; + } + + static void _specializeExistentialTypeParams( + Linkage* linkage, + ExistentialTypeSlots& ioSlots, + List> const& args, + DiagnosticSink* sink) + { + Index slotCount = ioSlots.paramTypes.getCount(); + Index argCount = args.getCount(); + + if( slotCount != argCount ) + { + sink->diagnose(SourceLoc(), Diagnostics::mismatchExistentialSlotArgCount, slotCount, argCount); + return; + } + + SemanticsVisitor visitor(linkage, sink); + + for( Index ii = 0; ii < slotCount; ++ii ) + { + auto slotType = ioSlots.paramTypes[ii]; + auto argExpr = args[ii]; + + auto argType = checkProperType(linkage, TypeExp(argExpr), sink); + if(!argType) + { + // TODO: Each slot should track a source location and/or a `VarDeclBase` + // that names the parameter that the slot corresponds to. + + sink->diagnose(SourceLoc(), Diagnostics::existentialSlotArgNotAType, ii); + return; + } + + + auto witness = visitor.tryGetSubtypeWitness(argType, slotType); + if (!witness) + { + // If no witness was found, then we will be unable to satisfy + // the conformances required. + sink->diagnose(SourceLoc(), Diagnostics::existentialSlotArgDoesNotConform, ii, slotType); + return; + } + + ExistentialTypeSlots::Arg arg; + arg.type = argType; + arg.witness = witness; + ioSlots.args.add(arg); + } + } + + void EntryPoint::_specializeExistentialTypeParams( + List> const& args, + DiagnosticSink* sink) + { + Slang::_specializeExistentialTypeParams(getLinkage(), m_existentialSlots, args, sink); + } + + /// Create a specialization an existing entry point based on generic arguments. + RefPtr createSpecializedEntryPoint( + EntryPoint* unspecializedEntryPoint, + List> const& genericArgs, + List> const& existentialArgs, + DiagnosticSink* sink) + { + auto linkage = unspecializedEntryPoint->getLinkage(); + + // TODO: Need to be careful in case entry point already has a decl-ref, + // pertaining to outer specializations (e.g., when entry point was + // nested in a generic type. + // + auto entryPointFuncDecl = unspecializedEntryPoint->getFuncDecl(); + + SemanticsVisitor semantics( + linkage, + sink); + + DeclRef entryPointFuncDeclRef = makeDeclRef(entryPointFuncDecl.Ptr()); + if( auto genericDecl = as(entryPointFuncDecl->ParentDecl) ) + { + // We will construct a suitable `GenericAppExpr` to represent + // the user-specified `genericDecl` being applied to the + // supplied `genericArgs`, and then use the existing + // semantic checking logic that would apply to an explicit + // generic application like `F` if it were + // encountered in the source code. + + auto session = linkage->getSession(); + auto genericDeclRef = makeDeclRef(genericDecl); + + // The first pieces is a `VarExpr` that refers to `genericDecl`. + // + // TODO: This would not be needed if we instead parsed + // the supplied entry-point name into an expression + // earlier in this function. + // + RefPtr genericExpr = new VarExpr(); + genericExpr->declRef = genericDeclRef; + genericExpr->type.type = getTypeForDeclRef(session, genericDeclRef); + + // Next we construct the actual `GenericAppExpr` + // + RefPtr genericAppExpr = new GenericAppExpr(); + genericAppExpr->FunctionExpr = genericExpr; + genericAppExpr->Arguments = genericArgs; + + // We use the semantics visitor to perform the + // actual checking logic (this might report + // errors) + // + auto checkedExpr = semantics.checkGenericAppWithCheckedArgs(genericAppExpr); + + // Now we need to extract an appropriate decl-ref for the entry + // point from the `checkedExpr`. + // + if( auto declRefExpr = checkedExpr.as() ) + { + // TODO: We should eventually check for the case + // where we have a `MemberExpr` or another case of + // `DeclRefExpr` that cannot be summarized as just + // its decl-ref. + // + // The basic `VarExpr` and `StaticMemberExpr` cases + // should be allow-able. + + entryPointFuncDeclRef = declRefExpr->declRef.as(); + } + else if( semantics.IsErrorExpr(checkedExpr) ) + { + // Any semantic error that occured should have been + // reported already. + return nullptr; + } + else + { + // The result of specializing a reference to a generic + // function should always be a `DeclRefExpr` + // + SLANG_UNEXPECTED("reference to generic decl wasn't a `DeclRefExpr`"); + UNREACHABLE_RETURN(nullptr); + } + } + + RefPtr specializedEntryPoint = EntryPoint::create( + entryPointFuncDeclRef, + unspecializedEntryPoint->getProfile()); + + // Next we need to validate the existential arguments. + specializedEntryPoint->_specializeExistentialTypeParams(existentialArgs, sink); + + return specializedEntryPoint; + } + + /// Parse an array of strings as generic arguments. + /// + /// Names in the strings will be parsed in the context of + /// the code loaded into the given compile request. + /// + void parseGenericArgStrings( + EndToEndCompileRequest* endToEndReq, + List const& genericArgStrings, + List>& outGenericArgs) + { + auto unspecialiedProgram = endToEndReq->getUnspecializedProgram(); + + // TODO: Building a list of `scopesToTry` here shouldn't + // be required, since the `Scope` type itself has the ability + // for form chains for lookup purposes (e.g., the way that + // `import` is handled by modifying a scope). + // + List> scopesToTry; + for( auto module : unspecialiedProgram->getModuleDependencies() ) + scopesToTry.add(module->getModuleDecl()->scope); + + // We are going to do some semantic checking, so we need to + // set up a `SemanticsVistitor` that we can use. + // + auto linkage = endToEndReq->getLinkage(); + auto sink = endToEndReq->getSink(); + SemanticsVisitor semantics( + linkage, + sink); + + // We will be looping over the generic argument strings + // that the user provided via the API (or command line), + // and parsing+checking each into an `Expr`. + // + // This loop will *not* handle coercing the arguments + // to be types. + // + for(auto name : genericArgStrings) + { + RefPtr argExpr; + for (auto & s : scopesToTry) + { + argExpr = linkage->parseTypeString(name, s); + argExpr = semantics.CheckTerm(argExpr); + if( argExpr ) + { + break; + } + } + + outGenericArgs.add(argExpr); + } + } + + void Program::_specializeExistentialTypeParams( + List> const& args, + DiagnosticSink* sink) + { + Slang::_specializeExistentialTypeParams(getLinkage(), m_globalExistentialSlots, args, sink); + } + + Type* Linkage::specializeType( + Type* unspecializedType, + Int argCount, + Type* const* args, + DiagnosticSink* sink) + { + // TODO: We should cache and re-use specialized types + // when the exact same arguments are provided again later. + + SemanticsVisitor visitor(this, sink); + + + ExistentialTypeSlots slots; + _collectExistentialTypeParamsRec(slots, unspecializedType); + + assert(slots.paramTypes.getCount() == argCount); + + for( Int aa = 0; aa < argCount; ++aa ) + { + auto argType = args[aa]; + + ExistentialTypeSlots::Arg arg; + arg.type = argType; + arg.witness = visitor.tryGetSubtypeWitness(argType, slots.paramTypes[aa]); + slots.args.add(arg); + } + + RefPtr specializedType = new ExistentialSpecializedType(); + specializedType->baseType = unspecializedType; + specializedType->slots = slots; + + m_specializedTypes.add(specializedType); + + return specializedType; + } + + /// Specialize a program to global generic arguments + RefPtr createSpecializedProgram( + Linkage* linkage, + Program* unspecializedProgram, + List> const& globalGenericArgs, + List> const& globalExistentialArgs, + DiagnosticSink* sink) + { + // The given `unspecializedProgram` should be one that + // was checked through the front-end, so that now we + // only need to check if the given arguments can satisfy + // the requirements of the global generic parameters. + // + // The new program needs to start off with the same + // module dependency list as the original. + // + RefPtr specializedProgram = new Program(linkage); + for(auto module : unspecializedProgram->getModuleDependencies()) + { + specializedProgram->addReferencedLeafModule(module); + } + + + // We will collect all the global generic parameters + // defined in the modules being referenced, to find + // the global generic parameter signature of the + // program. + // + // TODO: Note that this doesn't handle the case where one + // or more of the type *arguments* that we are specifying + // ends up requiring additional modules to be referenced, + // which might in turn introduce new global generic parameters. + // + List> globalGenericParams; + for(auto module : unspecializedProgram->getModuleDependencies()) + { + for(auto param : module->getModuleDecl()->getMembersOfType()) + globalGenericParams.add(param); + } + + // Next, we will check whether the supplied arguments can + // satisfy those parameters. + // + // An easy early-out case will be if the number of + // arguments isn't correct. + // + if (globalGenericParams.getCount() != globalGenericArgs.getCount()) + { + sink->diagnose(SourceLoc(), Diagnostics::mismatchGlobalGenericArguments, + globalGenericParams.getCount(), + globalGenericArgs.getCount()); + return nullptr; + } + + // We have an appropriate number of arguments for the global generic parameters, + // and now we need to check that the arguments conform to the declared constraints. + // + SemanticsVisitor visitor(linkage, sink); + + // Along the way, we will build up an appropriate set of substitutions to represent + // the generic arguments and their conformances. + // + RefPtr globalGenericSubsts; + auto globalGenericSubstLink = &globalGenericSubsts; + // + // TODO: There is a serious flaw to this checking logic if we ever have cases where + // the constraints on one `type_param` can depend on another `type_param`, e.g.: + // + // type_param A; + // type_param B : ISidekick; + // + // In that case, if a user tries to set `B` to `Robin` and `Robin` conforms to + // `ISidekick`, then the compiler needs to know whether `A` is being + // set to `Batman` to know whether the setting for `B` is valid. In this limit + // the constraints can be mutually recursive (so `A : IMentor`). + // + // The only way to check things correctly is to validate each conformance under + // a set of assumptions (substitutions) that includes all the type substitutions, + // and possibly also all the other constraints *except* the one to be validated. + // + // We will punt on this for now, and just check each constraint in isolation. + // + Index argCounter = 0; + for(auto& globalGenericParam : globalGenericParams) + { + // Get the argument that matches this parameter. + Index argIndex = argCounter++; + SLANG_ASSERT(argIndex < globalGenericArgs.getCount()); + auto globalGenericArg = checkProperType(linkage, TypeExp(globalGenericArgs[argIndex]), sink); + if (!globalGenericArg) + { + sink->diagnose(globalGenericParam, Diagnostics::globalGenericArgumentNotAType, globalGenericParam->getName()); + return nullptr; + } + + // As a quick sanity check, see if the argument that is being supplied for a parameter + // is just the parameter itself, because this should always be an error: + // + if( auto argDeclRefType = globalGenericArg.as() ) + { + auto argDeclRef = argDeclRefType->declRef; + if(auto argGenericParamDeclRef = argDeclRef.as()) + { + if(argGenericParamDeclRef.getDecl() == globalGenericParam) + { + // We are trying to specialize a generic parameter using itself. + sink->diagnose(globalGenericParam, + Diagnostics::cannotSpecializeGlobalGenericToItself, + globalGenericParam->getName()); + continue; + } + else + { + // We are trying to specialize a generic parameter using a *different* + // global generic type parameter. + sink->diagnose(globalGenericParam, + Diagnostics::cannotSpecializeGlobalGenericToAnotherGenericParam, + globalGenericParam->getName(), + argGenericParamDeclRef.GetName()); + continue; + } + } + } + + // Create a substitution for this parameter/argument. + RefPtr subst = new GlobalGenericParamSubstitution(); + subst->paramDecl = globalGenericParam; + subst->actualType = globalGenericArg; + + // Walk through the declared constraints for the parameter, + // and check that the argument actually satisfies them. + for(auto constraint : globalGenericParam->getMembersOfType()) + { + // Get the type that the constraint is enforcing conformance to + auto interfaceType = GetSup(DeclRef(constraint, nullptr)); + + // Use our semantic-checking logic to search for a witness to the required conformance + auto witness = visitor.tryGetSubtypeWitness(globalGenericArg, interfaceType); + if (!witness) + { + // If no witness was found, then we will be unable to satisfy + // the conformances required. + sink->diagnose(globalGenericParam, + Diagnostics::typeArgumentDoesNotConformToInterface, + globalGenericParam->nameAndLoc.name, + globalGenericArg, + interfaceType); + } + + // Attach the concrete witness for this conformance to the + // substutiton + GlobalGenericParamSubstitution::ConstraintArg constraintArg; + constraintArg.decl = constraint; + constraintArg.val = witness; + subst->constraintArgs.add(constraintArg); + } + + // Add the substitution for this parameter to the global substitution + // set that we are building. + + *globalGenericSubstLink = subst; + globalGenericSubstLink = &subst->outer; + } + if(sink->GetErrorCount()) + return nullptr; + + specializedProgram->setGlobalGenericSubsitution(globalGenericSubsts); + + // Now deal with the shader parameters and existential arguments + // + // Note: We should in theory be able to just copy over the shader + // parameters and existential slot information from the unspecialized + // program. This could save some time, but it would also mean that + // the only way to create a specialized program is by creating an + // unspecialized on first, which is maybe not always desirable. + // + specializedProgram->_collectShaderParams(sink); + specializedProgram->_specializeExistentialTypeParams(globalExistentialArgs, sink); + + return specializedProgram; + } + + /// Specialize an entry point that was checked by the front-end, based on generic arguments. + /// + /// If the end-to-end compile request included generic argument strings + /// for this entry point, then they will be parsed, checked, and used + /// as arguments to the generic entry point. + /// + /// Returns a specialized entry point if everything worked as expected. + /// Returns null and diagnoses errors if anything goes wrong. + /// + RefPtr createSpecializedEntryPoint( + EndToEndCompileRequest* endToEndReq, + EntryPoint* unspecializedEntryPoint, + EndToEndCompileRequest::EntryPointInfo const& entryPointInfo) + { + auto sink = endToEndReq->getSink(); + auto entryPointFuncDecl = unspecializedEntryPoint->getFuncDecl(); + + // If the user specified generic arguments for the entry point, + // then we will need to parse the arguments first. + // + List> genericArgs; + parseGenericArgStrings( + endToEndReq, + entryPointInfo.genericArgStrings, + genericArgs); + + List> existentialArgs; + parseGenericArgStrings( + endToEndReq, + entryPointInfo.existentialArgStrings, + existentialArgs); + + // Next we specialize the entry point function given the parsed + // generic argument expressions. + // + auto entryPoint = createSpecializedEntryPoint( + unspecializedEntryPoint, + genericArgs, + existentialArgs, + sink); + + return entryPoint; + } + + /// Create a specialized program based on the given compile request. + /// + RefPtr createSpecializedProgram( + EndToEndCompileRequest* endToEndReq) + { + // The compile request must have already completed front-end processing, + // so that we have an unspecialized program available, and now only need + // to parse and check any generic arguments that are being supplied for + // global or entry-point generic parameters. + // + auto unspecializedProgram = endToEndReq->getUnspecializedProgram(); + + // First, let's parse the generic argument strings that were + // provided via the API, so taht we can match them + // against what was declared in the program. + // + List> globalGenericArgs; + parseGenericArgStrings( + endToEndReq, + endToEndReq->globalGenericArgStrings, + globalGenericArgs); + + // Also handle global existential type arguments. + List> globalExistentialArgs; + parseGenericArgStrings( + endToEndReq, + endToEndReq->globalExistentialSlotArgStrings, + globalExistentialArgs); + + // Now we create the initial specialized program by + // applying the global generic arguments (if any) to the + // unspecialized program. + // + auto specializedProgram = createSpecializedProgram( + endToEndReq->getLinkage(), + unspecializedProgram, + globalGenericArgs, + globalExistentialArgs, + endToEndReq->getSink()); + + // If anything went wrong with the global generic + // arguments, then bail out now. + // + if(!specializedProgram) + return nullptr; + + // Next we will deal with the entry points for the + // new specialized program. + // + // If the user specified explicit entry points as part of the + // end-to-end request, then we only want to process those (and + // ignore any other `[shader(...)]`-attributed entry points). + // + // However, if the user specified *no* entry points as part + // of the end-to-end request, then we would like to go + // ahead and consider all the entry points that were found + // by the front-end. + // + Index entryPointCount = endToEndReq->entryPoints.getCount(); + if( entryPointCount == 0 ) + { + entryPointCount = unspecializedProgram->getEntryPointCount(); + endToEndReq->entryPoints.setCount(entryPointCount); + } + + for( Index ii = 0; ii < entryPointCount; ++ii ) + { + auto unspecializedEntryPoint = unspecializedProgram->getEntryPoint(ii); + auto& entryPointInfo = endToEndReq->entryPoints[ii]; + + auto specializedEntryPoint = createSpecializedEntryPoint(endToEndReq, unspecializedEntryPoint, entryPointInfo); + specializedProgram->addEntryPoint(specializedEntryPoint); + } + + return specializedProgram; + } + + void checkTranslationUnit( + TranslationUnitRequest* translationUnit) + { + SemanticsVisitor visitor( + translationUnit->compileRequest->getLinkage(), + translationUnit->compileRequest->getSink()); + + // Apply the visitor to do the main semantic + // checking that is required on all declarations + // in the translation unit. + visitor.checkDecl(translationUnit->getModuleDecl()); + } + + + // + + // Get the type to use when referencing a declaration + QualType getTypeForDeclRef( + Session* session, + SemanticsVisitor* sema, + DiagnosticSink* sink, + DeclRef declRef, + RefPtr* outTypeResult) + { + if( sema ) + { + sema->checkDecl(declRef.getDecl()); + } + + // We need to insert an appropriate type for the expression, based on + // what we found. + if (auto varDeclRef = declRef.as()) + { + QualType qualType; + qualType.type = GetType(varDeclRef); + + bool isLValue = true; + if(varDeclRef.getDecl()->FindModifier()) + isLValue = false; + + // Global-scope shader parameters should not be writable, + // since they are effectively program inputs. + // + // TODO: We could eventually treat a mutable global shader + // parameter as a shorthand for an immutable parameter and + // a global variable that gets initialized from that parameter, + // but in order to do so we'd need to support global variables + // with resource types better in the back-end. + // + if(isGlobalShaderParameter(varDeclRef.getDecl())) + isLValue = false; + + // Variables declared with `let` are always immutable. + if(varDeclRef.is()) + isLValue = false; + + // Generic value parameters are always immutable + if(varDeclRef.is()) + isLValue = false; + + // Function parameters declared in the "modern" style + // are immutable unless they have an `out` or `inout` modifier. + if(varDeclRef.is()) + { + // Note: the `inout` modifier AST class inherits from + // the class for the `out` modifier so that we can + // make simple checks like this. + // + if( !varDeclRef.getDecl()->HasModifier() ) + { + isLValue = false; + } + } + + qualType.IsLeftValue = isLValue; + return qualType; + } + else if( auto enumCaseDeclRef = declRef.as() ) + { + QualType qualType; + qualType.type = getType(enumCaseDeclRef); + qualType.IsLeftValue = false; + return qualType; + } + else if (auto typeAliasDeclRef = declRef.as()) + { + auto type = getNamedType(session, typeAliasDeclRef); + *outTypeResult = type; + return QualType(getTypeType(type)); + } + else if (auto aggTypeDeclRef = declRef.as()) + { + auto type = DeclRefType::Create(session, aggTypeDeclRef); + *outTypeResult = type; + return QualType(getTypeType(type)); + } + else if (auto simpleTypeDeclRef = declRef.as()) + { + auto type = DeclRefType::Create(session, simpleTypeDeclRef); + *outTypeResult = type; + return QualType(getTypeType(type)); + } + else if (auto genericDeclRef = declRef.as()) + { + auto type = getGenericDeclRefType(session, genericDeclRef); + *outTypeResult = type; + return QualType(getTypeType(type)); + } + else if (auto funcDeclRef = declRef.as()) + { + auto type = getFuncType(session, funcDeclRef); + return QualType(type); + } + else if (auto constraintDeclRef = declRef.as()) + { + // When we access a constraint or an inheritance decl (as a member), + // we are conceptually performing a "cast" to the given super-type, + // with the declaration showing that such a cast is legal. + auto type = GetSup(constraintDeclRef); + return QualType(type); + } + if( sink ) + { + sink->diagnose(declRef, Diagnostics::unimplemented, "cannot form reference to this kind of declaration"); + } + return QualType(session->getErrorType()); + } + + QualType getTypeForDeclRef( + Session* session, + DeclRef declRef) + { + RefPtr typeResult; + return getTypeForDeclRef(session, nullptr, nullptr, declRef, &typeResult); + } + + DeclRef ApplyExtensionToType( + SemanticsVisitor* semantics, + ExtensionDecl* extDecl, + RefPtr type) + { + if(!semantics) + return DeclRef(); + + return semantics->ApplyExtensionToType(extDecl, type); + } + + RefPtr createDefaultSubsitutionsForGeneric( + Session* session, + GenericDecl* genericDecl, + RefPtr outerSubst) + { + RefPtr genericSubst = new GenericSubstitution(); + genericSubst->genericDecl = genericDecl; + genericSubst->outer = outerSubst; + + for( auto mm : genericDecl->Members ) + { + if( auto genericTypeParamDecl = as(mm) ) + { + genericSubst->args.add(DeclRefType::Create(session, DeclRef(genericTypeParamDecl, outerSubst))); + } + else if( auto genericValueParamDecl = as(mm) ) + { + genericSubst->args.add(new GenericParamIntVal(DeclRef(genericValueParamDecl, outerSubst))); + } + } + + // create default substitution arguments for constraints + for (auto mm : genericDecl->Members) + { + if (auto genericTypeConstraintDecl = as(mm)) + { + RefPtr witness = new DeclaredSubtypeWitness(); + witness->declRef = DeclRef(genericTypeConstraintDecl, outerSubst); + witness->sub = genericTypeConstraintDecl->sub.type; + witness->sup = genericTypeConstraintDecl->sup.type; + genericSubst->args.add(witness); + } + } + + return genericSubst; + } + + // Sometimes we need to refer to a declaration the way that it would be specialized + // inside the context where it is declared (e.g., with generic parameters filled in + // using their archetypes). + // + SubstitutionSet createDefaultSubstitutions( + Session* session, + Decl* decl, + SubstitutionSet outerSubstSet) + { + auto dd = decl->ParentDecl; + if( auto genericDecl = as(dd) ) + { + // We don't want to specialize references to anything + // other than the "inner" declaration itself. + if(decl != genericDecl->inner) + return outerSubstSet; + + RefPtr genericSubst = createDefaultSubsitutionsForGeneric( + session, + genericDecl, + outerSubstSet.substitutions); + + return SubstitutionSet(genericSubst); + } + + return outerSubstSet; + } + + SubstitutionSet createDefaultSubstitutions( + Session* session, + Decl* decl) + { + SubstitutionSet subst; + if( auto parentDecl = decl->ParentDecl ) + { + subst = createDefaultSubstitutions(session, parentDecl); + } + subst = createDefaultSubstitutions(session, decl, subst); + return subst; + } + + void checkDecl(SemanticsVisitor* visitor, Decl* decl) + { + visitor->checkDecl(decl); + } +} diff --git a/source/slang/slang-check.h b/source/slang/slang-check.h new file mode 100644 index 000000000..e243d572b --- /dev/null +++ b/source/slang/slang-check.h @@ -0,0 +1,7 @@ +// slang-check.h +#pragma once + +namespace Slang +{ + bool isGlobalShaderParameter(VarDeclBase* decl); +} diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp new file mode 100644 index 000000000..94f350ce3 --- /dev/null +++ b/source/slang/slang-compiler.cpp @@ -0,0 +1,1645 @@ +// Compiler.cpp : Defines the entry point for the console application. +// +#include "../core/slang-basic.h" +#include "../core/slang-platform.h" +#include "../core/slang-io.h" +#include "../core/slang-string-util.h" + +#include "slang-compiler.h" +#include "slang-lexer.h" +#include "slang-lower-to-ir.h" +#include "slang-parameter-binding.h" +#include "slang-parser.h" +#include "slang-preprocessor.h" +#include "slang-syntax-visitors.h" +#include "slang-type-layout.h" +#include "slang-reflection.h" +#include "slang-emit.h" + +// Enable calling through to `fxc` or `dxc` to +// generate code on Windows. +#ifdef _WIN32 + #define WIN32_LEAN_AND_MEAN + #define NOMINMAX + #include + #undef WIN32_LEAN_AND_MEAN + #undef NOMINMAX + #include + #ifndef SLANG_ENABLE_DXBC_SUPPORT + #define SLANG_ENABLE_DXBC_SUPPORT 1 + #endif + #ifndef SLANG_ENABLE_DXIL_SUPPORT + #define SLANG_ENABLE_DXIL_SUPPORT 1 + #endif +#endif +// +// Otherwise, don't enable DXBC/DXIL by default: +#ifndef SLANG_ENABLE_DXBC_SUPPORT + #define SLANG_ENABLE_DXBC_SUPPORT 0 +#endif +#ifndef SLANG_ENABLE_DXIL_SUPPORT + #define SLANG_ENABLE_DXIL_SUPPORT 0 +#endif + +// Enable calling through to `glslang` on +// all platforms. +#ifndef SLANG_ENABLE_GLSLANG_SUPPORT + #define SLANG_ENABLE_GLSLANG_SUPPORT 1 +#endif + +#if SLANG_ENABLE_GLSLANG_SUPPORT +#include "../slang-glslang/slang-glslang.h" +#endif + +// Includes to allow us to control console +// output when writing assembly dumps. +#include +#ifdef _WIN32 +#include +#else +#include +#endif + +#ifdef _MSC_VER +#pragma warning(disable: 4996) +#endif + +#ifdef CreateDirectory +#undef CreateDirectory +#endif + +namespace Slang +{ + + // CompileResult + + void CompileResult::append(CompileResult const& result) + { + // Find which to append to + ResultFormat appendTo = ResultFormat::None; + + if (format == ResultFormat::None) + { + format = result.format; + appendTo = result.format; + } + else if (format == result.format) + { + appendTo = format; + } + + if (appendTo == ResultFormat::Text) + { + outputString.append(result.outputString.getBuffer()); + } + else if (appendTo == ResultFormat::Binary) + { + outputBinary.addRange(result.outputBinary.getBuffer(), result.outputBinary.getCount()); + } + } + + ComPtr CompileResult::getBlob() + { + if(!blob) + { + switch(format) + { + case ResultFormat::None: + default: + break; + + case ResultFormat::Text: + blob = StringUtil::createStringBlob(outputString); + break; + + case ResultFormat::Binary: + blob = createRawBlob(outputBinary.getBuffer(), outputBinary.getCount()); + break; + } + } + return blob; + } + + // + // FrontEndEntryPointRequest + // + + FrontEndEntryPointRequest::FrontEndEntryPointRequest( + FrontEndCompileRequest* compileRequest, + int translationUnitIndex, + Name* name, + Profile profile) + : m_compileRequest(compileRequest) + , m_translationUnitIndex(translationUnitIndex) + , m_name(name) + , m_profile(profile) + {} + + + TranslationUnitRequest* FrontEndEntryPointRequest::getTranslationUnit() + { + return getCompileRequest()->translationUnits[m_translationUnitIndex]; + } + + // + // EntryPoint + // + + RefPtr EntryPoint::create( + DeclRef funcDeclRef, + Profile profile) + { + RefPtr entryPoint = new EntryPoint( + funcDeclRef.GetName(), + profile, + funcDeclRef); + return entryPoint; + } + + RefPtr EntryPoint::createDummyForPassThrough( + Name* name, + Profile profile) + { + RefPtr entryPoint = new EntryPoint( + name, + profile, + DeclRef()); + return entryPoint; + } + + EntryPoint::EntryPoint( + Name* name, + Profile profile, + DeclRef funcDeclRef) + : m_name(name) + , m_profile(profile) + , m_funcDeclRef(funcDeclRef) + { + // In order for later code generation to work, we need to track what + // modules each entry point depends on. We will build up the dependency + // list here when an `EntryPoint` gets created. + // + // We know an entry point depends on the module that declared the + // entry-point function itself. + // + // Note: we are carefully handling the case where `module` could + // be null, becase of "dummy" entry points created for pass-through + // compilation. + // + if(auto module = getModule()) + { + m_dependencyList.addDependency(module); + } + // + // TODO: We also need to include the modules needed by any generic + // arguments in the dependency list, since in the general case they + // might come from modules other than the one defining the entry point. + + // The following is a bit of a hack. + // + // Back-end code generation relies on us having computed layouts for all tagged + // unions that end up being used in the code, which means we need a way to find + // all such types that get used in a program (and the stuff it imports). + // + // For now we are assuming a tagged union type only comes into existence + // as a (top-level) argument for a generic type parameter, so that we + // can check for them here and cache them on the entry point. + // + // A longer-term strategy might need to consider any (tagged or untagged) + // union types that get used inside of a module, and also take + // those lists into account. + // + // An even longer-term strategy would be to allow type layout to + // be performed on IR types, so taht we don't need to have front-end + // code worrying about this stuff. + // + for( auto subst = funcDeclRef.substitutions.substitutions; subst; subst = subst->outer ) + { + if( auto genericSubst = as(subst) ) + { + for( auto arg : genericSubst->args ) + { + if( auto taggedUnionType = as(arg) ) + { + m_taggedUnionTypes.add(taggedUnionType); + } + } + } + } + + // Collect any existential-type parameters used by the entry point + // + _collectShaderParams(); + } + + Module* EntryPoint::getModule() + { + return Slang::getModule(getFuncDecl()); + } + + Linkage* EntryPoint::getLinkage() + { + return getModule()->getLinkage(); + } + + // + + Profile Profile::LookUp(char const* name) + { + #define PROFILE(TAG, NAME, STAGE, VERSION) if(strcmp(name, #NAME) == 0) return Profile::TAG; + #define PROFILE_ALIAS(TAG, DEF, NAME) if(strcmp(name, #NAME) == 0) return Profile::TAG; + #include "slang-profile-defs.h" + + return Profile::Unknown; + } + + char const* Profile::getName() + { + switch( raw ) + { + default: + return "unknown"; + + #define PROFILE(TAG, NAME, STAGE, VERSION) case Profile::TAG: return #NAME; + #define PROFILE_ALIAS(TAG, DEF, NAME) /* empty */ + #include "slang-profile-defs.h" + } + } + + Stage findStageByName(String const& name) + { + static const struct + { + char const* name; + Stage stage; + } kStages[] = + { + #define PROFILE_STAGE(ID, NAME, ENUM) \ + { #NAME, Stage::ID }, + + #define PROFILE_STAGE_ALIAS(ID, NAME, VAL) \ + { #NAME, Stage::ID }, + + #include "slang-profile-defs.h" + }; + + for(auto entry : kStages) + { + if(name == entry.name) + { + return entry.stage; + } + } + + return Stage::Unknown; + } + + SlangResult checkExternalCompilerSupport(Session* session, PassThroughMode passThrough) + { + switch (passThrough) + { + case PassThroughMode::None: + { + // If no pass through -> that will always work! + return SLANG_OK; + } + case PassThroughMode::dxc: + { +#if SLANG_ENABLE_DXIL_SUPPORT + // Must have dxc + return session->getOrLoadSharedLibrary(SharedLibraryType::Dxc, nullptr) ? SLANG_OK : SLANG_E_NOT_FOUND; +#endif + break; + } + case PassThroughMode::fxc: + { +#if SLANG_ENABLE_DXBC_SUPPORT + // Must have fxc + return session->getOrLoadSharedLibrary(SharedLibraryType::Fxc, nullptr) ? SLANG_OK : SLANG_E_NOT_FOUND; +#endif + break; + } + case PassThroughMode::glslang: + { +#if SLANG_ENABLE_GLSLANG_SUPPORT + return session->getOrLoadSharedLibrary(Slang::SharedLibraryType::Glslang, nullptr) ? SLANG_OK : SLANG_E_NOT_FOUND; +#endif + break; + } + } + return SLANG_E_NOT_IMPLEMENTED; + } + + static PassThroughMode _getExternalCompilerRequiredForTarget(CodeGenTarget target) + { + switch (target) + { + case CodeGenTarget::None: + { + return PassThroughMode::None; + } + case CodeGenTarget::GLSL: + case CodeGenTarget::GLSL_Vulkan: + case CodeGenTarget::GLSL_Vulkan_OneDesc: + { + // Can always output GLSL + return PassThroughMode::None; + } + case CodeGenTarget::HLSL: + { + // Can always output HLSL + return PassThroughMode::None; + } + case CodeGenTarget::SPIRVAssembly: + case CodeGenTarget::SPIRV: + { + return PassThroughMode::glslang; + } + case CodeGenTarget::DXBytecode: + case CodeGenTarget::DXBytecodeAssembly: + { + return PassThroughMode::fxc; + } + case CodeGenTarget::DXIL: + case CodeGenTarget::DXILAssembly: + { + return PassThroughMode::dxc; + } + case CodeGenTarget::CPPSource: + case CodeGenTarget::CSource: + { + // Don't need an external compiler to output C and C++ code + return PassThroughMode::None; + } + + default: break; + } + + SLANG_ASSERT(!"Unhandled target"); + return PassThroughMode::None; + } + + SlangResult checkCompileTargetSupport(Session* session, CodeGenTarget target) + { + const PassThroughMode mode = _getExternalCompilerRequiredForTarget(target); + return (mode != PassThroughMode::None) ? + checkExternalCompilerSupport(session, mode) : + SLANG_OK; + } + + // + + /// If there is a pass-through compile going on, find the translation unit for the given entry point. + TranslationUnitRequest* findPassThroughTranslationUnit( + EndToEndCompileRequest* endToEndReq, + Int entryPointIndex) + { + // If there isn't an end-to-end compile going on, + // there can be no pass-through. + // + if(!endToEndReq) return nullptr; + + // And if pass-through isn't set, we don't need + // access to the translation unit. + // + if(endToEndReq->passThrough == PassThroughMode::None) return nullptr; + + auto frontEndReq = endToEndReq->getFrontEndReq(); + auto entryPointReq = frontEndReq->getEntryPointReq(entryPointIndex); + auto translationUnit = entryPointReq->getTranslationUnit(); + return translationUnit; + } + + String emitHLSLForEntryPoint( + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq) + { + if(auto translationUnit = findPassThroughTranslationUnit(endToEndReq, entryPointIndex)) + { + // Generate a string that includes the content of + // the source file(s), along with a line directive + // to ensure that we get reasonable messages + // from the downstream compiler when in pass-through + // mode. + + StringBuilder codeBuilder; + for(auto sourceFile : translationUnit->getSourceFiles()) + { + codeBuilder << "#line 1 \""; + + const String& path = sourceFile->getPathInfo().foundPath; + + for(auto c : path) + { + char buffer[] = { c, 0 }; + switch(c) + { + default: + codeBuilder << buffer; + break; + + case '\\': + codeBuilder << "\\\\"; + } + } + codeBuilder << "\"\n"; + + codeBuilder << sourceFile->getContent() << "\n"; + } + + return codeBuilder.ProduceString(); + } + else + { + return emitEntryPoint( + compileRequest, + entryPoint, + CodeGenTarget::HLSL, + targetReq); + } + } + + String emitGLSLForEntryPoint( + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq) + { + if(auto translationUnit = findPassThroughTranslationUnit(endToEndReq, entryPointIndex)) + { + // Generate a string that includes the content of + // the source file(s), along with a line directive + // to ensure that we get reasonable messages + // from the downstream compiler when in pass-through + // mode. + + StringBuilder codeBuilder; + int translationUnitCounter = 0; + for(auto sourceFile : translationUnit->getSourceFiles()) + { + int translationUnitIndex = translationUnitCounter++; + + // We want to output `#line` directives, but we need + // to skip this for the first file, since otherwise + // some GLSL implementations will get tripped up by + // not having the `#version` directive be the first + // thing in the file. + if(translationUnitIndex != 0) + { + codeBuilder << "#line 1 " << translationUnitIndex << "\n"; + } + codeBuilder << sourceFile->getContent() << "\n"; + } + + return codeBuilder.ProduceString(); + } + else + { + // TODO(tfoley): need to pass along the entry point + // so that we properly emit it as the `main` function. + return emitEntryPoint( + compileRequest, + entryPoint, + CodeGenTarget::GLSL, + targetReq); + } + } + + String GetHLSLProfileName(Profile profile) + { + switch( profile.getFamily() ) + { + case ProfileFamily::DX: + // Profile version is a DX one, so stick with it. + break; + + default: + // Profile is a non-DX profile family, so we need to try + // to clobber it with something to get a default. + // + // TODO: This is a huge hack... + profile.setVersion(ProfileVersion::DX_5_0); + break; + } + + char const* stagePrefix = nullptr; + switch( profile.GetStage() ) + { + // Note: All of the raytracing-related stages require + // compiling for a `lib_*` profile, even when only a + // single entry point is present. + // + // We also go ahead and use this target in any case + // where we don't know the actual stage to compiel for, + // as a fallback option. + // + // TODO: We also want to use this option when compiling + // multiple entry points to a DXIL library. + // + default: + stagePrefix = "lib"; + break; + + // The traditional rasterization pipeline and compute + // shaders all have custom profile names that identify + // both the stage and shader model, which need to be + // used when compiling a single entry point. + // + #define CASE(NAME, PREFIX) case Stage::NAME: stagePrefix = #PREFIX; break + CASE(Vertex, vs); + CASE(Hull, hs); + CASE(Domain, ds); + CASE(Geometry, gs); + CASE(Fragment, ps); + CASE(Compute, cs); + #undef CASE + } + + char const* versionSuffix = nullptr; + switch(profile.GetVersion()) + { + #define CASE(TAG, SUFFIX) case ProfileVersion::TAG: versionSuffix = #SUFFIX; break + CASE(DX_4_0, _4_0); + CASE(DX_4_0_Level_9_0, _4_0_level_9_0); + CASE(DX_4_0_Level_9_1, _4_0_level_9_1); + CASE(DX_4_0_Level_9_3, _4_0_level_9_3); + CASE(DX_4_1, _4_1); + CASE(DX_5_0, _5_0); + CASE(DX_5_1, _5_1); + CASE(DX_6_0, _6_0); + CASE(DX_6_1, _6_1); + CASE(DX_6_2, _6_2); + CASE(DX_6_3, _6_3); + #undef CASE + + default: + return "unknown"; + } + + String result; + result.append(stagePrefix); + result.append(versionSuffix); + return result; + } + + void reportExternalCompileError(const char* compilerName, SlangResult res, const UnownedStringSlice& diagnostic, DiagnosticSink* sink) + { + StringBuilder builder; + if (compilerName) + { + builder << compilerName << ": "; + } + + if (diagnostic.size() > 0) + { + builder.Append(diagnostic); + } + + if (SLANG_FAILED(res) && res != SLANG_FAIL) + { + { + char tmp[17]; + sprintf_s(tmp, SLANG_COUNT_OF(tmp), "0x%08x", uint32_t(res)); + builder << "Result(" << tmp << ") "; + } + + PlatformUtil::appendResult(res, builder); + } + + // TODO(tfoley): need a better policy for how we translate diagnostics + // back into the Slang world (although we should always try to generate + // HLSL that doesn't produce any diagnostics...) + sink->diagnoseRaw(SLANG_FAILED(res) ? Severity::Error : Severity::Warning, builder.getUnownedSlice()); + } + + static String _getDisplayPath(DiagnosticSink* sink, SourceFile* sourceFile) + { + if (sink->flags & DiagnosticSink::Flag::VerbosePath) + { + return sourceFile->calcVerbosePath(); + } + else + { + return sourceFile->getPathInfo().foundPath; + } + } + + String calcSourcePathForEntryPoint( + EndToEndCompileRequest* endToEndReq, + UInt entryPointIndex) + { + auto translationUnitRequest = findPassThroughTranslationUnit(endToEndReq, entryPointIndex); + if(!translationUnitRequest) + return "slang-generated"; + + auto sink = endToEndReq->getSink(); + + const auto& sourceFiles = translationUnitRequest->getSourceFiles(); + + const Index numSourceFiles = sourceFiles.getCount(); + + switch (numSourceFiles) + { + case 0: return "unknown"; + case 1: return _getDisplayPath(sink, sourceFiles[0]); + default: + { + StringBuilder builder; + builder << _getDisplayPath(sink, sourceFiles[0]); + for (int i = 1; i < numSourceFiles; ++i) + { + builder << ";" << _getDisplayPath(sink, sourceFiles[i]); + } + return builder; + } + } + } + +#if SLANG_ENABLE_DXBC_SUPPORT + + static UnownedStringSlice _getSlice(ID3DBlob* blob) + { + if (blob) + { + const char* chars = (const char*)blob->GetBufferPointer(); + size_t len = blob->GetBufferSize(); + len -= size_t(len > 0 && chars[len - 1] == 0); + return UnownedStringSlice(chars, len); + } + return UnownedStringSlice(); + } + + SlangResult emitDXBytecodeForEntryPoint( + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq, + List& byteCodeOut) + { + byteCodeOut.clear(); + + auto session = compileRequest->getSession(); + auto sink = compileRequest->getSink(); + + auto compileFunc = (pD3DCompile)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Fxc_D3DCompile, sink); + if (!compileFunc) + { + return SLANG_FAIL; + } + + auto hlslCode = emitHLSLForEntryPoint(compileRequest, entryPoint, entryPointIndex, targetReq, endToEndReq); + maybeDumpIntermediate(compileRequest, hlslCode.getBuffer(), CodeGenTarget::HLSL); + + auto profile = getEffectiveProfile(entryPoint, targetReq); + + // If we have been invoked in a pass-through mode, then we need to make sure + // that the downstream compiler sees whatever options were passed to Slang + // via the command line or API. + // + // TODO: more pieces of information should be added here as needed. + // + List dxMacrosStorage; + D3D_SHADER_MACRO const* dxMacros = nullptr; + if(auto translationUnit = findPassThroughTranslationUnit(endToEndReq, entryPointIndex)) + { + for( auto& define : translationUnit->compileRequest->preprocessorDefinitions ) + { + D3D_SHADER_MACRO dxMacro; + dxMacro.Name = define.Key.getBuffer(); + dxMacro.Definition = define.Value.getBuffer(); + dxMacrosStorage.add(dxMacro); + } + for( auto& define : translationUnit->preprocessorDefinitions ) + { + D3D_SHADER_MACRO dxMacro; + dxMacro.Name = define.Key.getBuffer(); + dxMacro.Definition = define.Value.getBuffer(); + dxMacrosStorage.add(dxMacro); + } + D3D_SHADER_MACRO nullTerminator = { 0, 0 }; + dxMacrosStorage.add(nullTerminator); + + dxMacros = dxMacrosStorage.getBuffer(); + } + + DWORD flags = 0; + + switch( targetReq->floatingPointMode ) + { + default: + break; + + case FloatingPointMode::Precise: + flags |= D3DCOMPILE_IEEE_STRICTNESS; + break; + } + + // Some of the `D3DCOMPILE_*` constants aren't available in all + // versions of `d3dcompiler.h`, so we define them here just in case + #ifndef D3DCOMPILE_ENABLE_UNBOUNDED_DESCRIPTOR_TABLES + #define D3DCOMPILE_ENABLE_UNBOUNDED_DESCRIPTOR_TABLES (1 << 20) + #endif + + #ifndef D3DCOMPILE_ALL_RESOURCES_BOUND + #define D3DCOMPILE_ALL_RESOURCES_BOUND (1 << 21) + #endif + + flags |= D3DCOMPILE_ENABLE_STRICTNESS; + flags |= D3DCOMPILE_ENABLE_UNBOUNDED_DESCRIPTOR_TABLES; + + auto linkage = compileRequest->getLinkage(); + switch( linkage->optimizationLevel ) + { + default: + break; + + case OptimizationLevel::None: flags |= D3DCOMPILE_OPTIMIZATION_LEVEL0; break; + case OptimizationLevel::Default: flags |= D3DCOMPILE_OPTIMIZATION_LEVEL1; break; + case OptimizationLevel::High: flags |= D3DCOMPILE_OPTIMIZATION_LEVEL2; break; + case OptimizationLevel::Maximal: flags |= D3DCOMPILE_OPTIMIZATION_LEVEL3; break; + } + + switch( linkage->debugInfoLevel ) + { + case DebugInfoLevel::None: + break; + + default: + flags |= D3DCOMPILE_DEBUG; + break; + } + + const String sourcePath = calcSourcePathForEntryPoint(endToEndReq, entryPointIndex); + + ComPtr codeBlob; + ComPtr diagnosticsBlob; + HRESULT hr = compileFunc( + hlslCode.begin(), + hlslCode.getLength(), + sourcePath.getBuffer(), + dxMacros, + nullptr, + getText(entryPoint->getName()).begin(), + GetHLSLProfileName(profile).getBuffer(), + flags, + 0, // unused: effect flags + codeBlob.writeRef(), + diagnosticsBlob.writeRef()); + + if (codeBlob && SLANG_SUCCEEDED(hr)) + { + byteCodeOut.addRange((uint8_t const*)codeBlob->GetBufferPointer(), (int)codeBlob->GetBufferSize()); + } + + if (FAILED(hr)) + { + reportExternalCompileError("fxc", hr, _getSlice(diagnosticsBlob), sink); + } + + return hr; + } + + SlangResult dissassembleDXBC( + BackEndCompileRequest* compileRequest, + void const* data, + size_t size, + String& assemOut) + { + assemOut = String(); + + auto session = compileRequest->getSession(); + auto sink = compileRequest->getSink(); + + auto disassembleFunc = (pD3DDisassemble)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Fxc_D3DDisassemble, sink); + if (!disassembleFunc) + { + return SLANG_E_NOT_FOUND; + } + + if (!data || !size) + { + return SLANG_FAIL; + } + + ComPtr codeBlob; + SlangResult res = disassembleFunc(data, size, 0, nullptr, codeBlob.writeRef()); + + if (codeBlob) + { + assemOut = _getSlice(codeBlob); + } + if (FAILED(res)) + { + // TODO(tfoley): need to figure out what to diagnose here... + reportExternalCompileError("fxc", res, UnownedStringSlice(), sink); + } + + return res; + } + + SlangResult emitDXBytecodeAssemblyForEntryPoint( + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq, + String& assemOut) + { + + List dxbc; + SLANG_RETURN_ON_FAIL(emitDXBytecodeForEntryPoint( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq, + dxbc)); + if (!dxbc.getCount()) + { + return SLANG_FAIL; + } + return dissassembleDXBC(compileRequest, dxbc.getBuffer(), dxbc.getCount(), assemOut); + } +#endif + +#if SLANG_ENABLE_DXIL_SUPPORT + +// Implementations in `dxc-support.cpp` + +SlangResult emitDXILForEntryPointUsingDXC( + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq, + List& outCode); + +SlangResult dissassembleDXILUsingDXC( + BackEndCompileRequest* compileRequest, + void const* data, + size_t size, + String& stringOut); + +#endif + +#if SLANG_ENABLE_GLSLANG_SUPPORT + SlangResult invokeGLSLCompiler( + BackEndCompileRequest* slangCompileRequest, + glslang_CompileRequest& request) + { + Session* session = slangCompileRequest->getSession(); + auto sink = slangCompileRequest->getSink(); + + auto glslang_compile = (glslang_CompileFunc)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Glslang_Compile, sink); + if (!glslang_compile) + { + return SLANG_FAIL; + } + + StringBuilder diagnosticOutput; + + auto diagnosticOutputFunc = [](void const* data, size_t size, void* userData) + { + (*(StringBuilder*)userData).append((char const*)data, (char const*)data + size); + }; + + request.diagnosticFunc = diagnosticOutputFunc; + request.diagnosticUserData = &diagnosticOutput; + + int err = glslang_compile(&request); + + if (err) + { + reportExternalCompileError("glslang", SLANG_FAIL, diagnosticOutput.getUnownedSlice(), sink); + return SLANG_FAIL; + } + + return SLANG_OK; + } + + SlangResult dissassembleSPIRV( + BackEndCompileRequest* slangRequest, + void const* data, + size_t size, + String& stringOut) + { + stringOut = String(); + + String output; + auto outputFunc = [](void const* data, size_t size, void* userData) + { + (*(String*)userData).append((char const*)data, (char const*)data + size); + }; + + glslang_CompileRequest request; + request.action = GLSLANG_ACTION_DISSASSEMBLE_SPIRV; + + request.sourcePath = nullptr; + + request.inputBegin = data; + request.inputEnd = (char*)data + size; + + request.outputFunc = outputFunc; + request.outputUserData = &output; + + SLANG_RETURN_ON_FAIL(invokeGLSLCompiler(slangRequest, request)); + + stringOut = output; + return SLANG_OK; + } + + SlangResult emitSPIRVForEntryPoint( + BackEndCompileRequest* slangRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq, + List& spirvOut) + { + spirvOut.clear(); + + String rawGLSL = emitGLSLForEntryPoint( + slangRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq); + maybeDumpIntermediate(slangRequest, rawGLSL.getBuffer(), CodeGenTarget::GLSL); + + auto outputFunc = [](void const* data, size_t size, void* userData) + { + ((List*)userData)->addRange((uint8_t*)data, size); + }; + + const String sourcePath = calcSourcePathForEntryPoint(endToEndReq, entryPointIndex); + + glslang_CompileRequest request; + request.action = GLSLANG_ACTION_COMPILE_GLSL_TO_SPIRV; + request.sourcePath = sourcePath.getBuffer(); + request.slangStage = (SlangStage)entryPoint->getStage(); + + request.inputBegin = rawGLSL.begin(); + request.inputEnd = rawGLSL.end(); + + request.outputFunc = outputFunc; + request.outputUserData = &spirvOut; + + SLANG_RETURN_ON_FAIL(invokeGLSLCompiler(slangRequest, request)); + return SLANG_OK; + } + + SlangResult emitSPIRVAssemblyForEntryPoint( + BackEndCompileRequest* slangRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq, + String& assemblyOut) + { + List spirv; + SLANG_RETURN_ON_FAIL(emitSPIRVForEntryPoint( + slangRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq, + spirv)); + + if (spirv.getCount() == 0) + return SLANG_FAIL; + + return dissassembleSPIRV(slangRequest, spirv.begin(), spirv.getCount(), assemblyOut); + } +#endif + + // Do emit logic for a single entry point + CompileResult emitEntryPoint( + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq) + { + CompileResult result; + + auto target = targetReq->target; + + switch (target) + { + case CodeGenTarget::HLSL: + { + String code = emitHLSLForEntryPoint( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq); + maybeDumpIntermediate(compileRequest, code.getBuffer(), target); + result = CompileResult(code); + } + break; + + case CodeGenTarget::GLSL: + { + String code = emitGLSLForEntryPoint( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq); + maybeDumpIntermediate(compileRequest, code.getBuffer(), target); + result = CompileResult(code); + } + break; + + case CodeGenTarget::CPPSource: + case CodeGenTarget::CSource: + { + return emitEntryPoint( + compileRequest, + entryPoint, + target, + targetReq); + } + break; + +#if SLANG_ENABLE_DXBC_SUPPORT + case CodeGenTarget::DXBytecode: + { + List code; + if (SLANG_SUCCEEDED(emitDXBytecodeForEntryPoint( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq, + code))) + { + maybeDumpIntermediate(compileRequest, code.getBuffer(), code.getCount(), target); + result = CompileResult(code); + } + } + break; + + case CodeGenTarget::DXBytecodeAssembly: + { + String code; + if (SLANG_SUCCEEDED(emitDXBytecodeAssemblyForEntryPoint( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq, + code))) + { + maybeDumpIntermediate(compileRequest, code.getBuffer(), target); + result = CompileResult(code); + } + } + break; +#endif + +#if SLANG_ENABLE_DXIL_SUPPORT + case CodeGenTarget::DXIL: + { + List code; + if (SLANG_SUCCEEDED(emitDXILForEntryPointUsingDXC( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq, + code))) + { + maybeDumpIntermediate(compileRequest, code.getBuffer(), code.getCount(), target); + result = CompileResult(code); + } + } + break; + + case CodeGenTarget::DXILAssembly: + { + List code; + if (SLANG_SUCCEEDED(emitDXILForEntryPointUsingDXC( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq, + code))) + { + String assembly; + dissassembleDXILUsingDXC( + compileRequest, + code.getBuffer(), + code.getCount(), + assembly); + + maybeDumpIntermediate(compileRequest, assembly.getBuffer(), target); + + result = CompileResult(assembly); + } + } + break; +#endif + + case CodeGenTarget::SPIRV: + { + List code; + if (SLANG_SUCCEEDED(emitSPIRVForEntryPoint( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq, + code))) + { + maybeDumpIntermediate(compileRequest, code.getBuffer(), code.getCount(), target); + result = CompileResult(code); + } + } + break; + + case CodeGenTarget::SPIRVAssembly: + { + String code; + if (SLANG_SUCCEEDED(emitSPIRVAssemblyForEntryPoint( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq, + code))) + { + maybeDumpIntermediate(compileRequest, code.getBuffer(), target); + result = CompileResult(code); + } + } + break; + + case CodeGenTarget::None: + // The user requested no output + break; + + // Note(tfoley): We currently hit this case when compiling the stdlib + case CodeGenTarget::Unknown: + break; + + default: + SLANG_UNEXPECTED("unhandled code generation target"); + break; + } + + return result; + } + + enum class OutputFileKind + { + Text, + Binary, + }; + + static void writeOutputFile( + BackEndCompileRequest* compileRequest, + FILE* file, + String const& path, + void const* data, + size_t size) + { + size_t count = fwrite(data, size, 1, file); + if (count != 1) + { + compileRequest->getSink()->diagnose( + SourceLoc(), + Diagnostics::cannotWriteOutputFile, + path); + } + } + + static void writeOutputFile( + BackEndCompileRequest* compileRequest, + ISlangWriter* writer, + String const& path, + void const* data, + size_t size) + { + + if (SLANG_FAILED(writer->write((const char*)data, size))) + { + compileRequest->getSink()->diagnose( + SourceLoc(), + Diagnostics::cannotWriteOutputFile, + path); + } + } + + static void writeOutputFile( + BackEndCompileRequest* compileRequest, + String const& path, + void const* data, + size_t size, + OutputFileKind kind) + { + FILE* file = fopen( + path.getBuffer(), + kind == OutputFileKind::Binary ? "wb" : "w"); + if (!file) + { + compileRequest->getSink()->diagnose( + SourceLoc(), + Diagnostics::cannotWriteOutputFile, + path); + return; + } + + writeOutputFile(compileRequest, file, path, data, size); + fclose(file); + } + + static void writeEntryPointResultToFile( + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + String const& outputPath, + CompileResult const& result) + { + SLANG_UNUSED(entryPoint); + + switch (result.format) + { + case ResultFormat::Text: + { + auto text = result.outputString; + writeOutputFile(compileRequest, + outputPath, + text.begin(), + text.end() - text.begin(), + OutputFileKind::Text); + } + break; + + case ResultFormat::Binary: + { + auto& data = result.outputBinary; + writeOutputFile(compileRequest, + outputPath, + data.begin(), + data.end() - data.begin(), + OutputFileKind::Binary); + } + break; + + default: + SLANG_UNEXPECTED("unhandled output format"); + break; + } + + } + + static void writeOutputToConsole( + ISlangWriter* writer, + String const& text) + { + writer->write(text.getBuffer(), text.getLength()); + } + + static void writeEntryPointResultToStandardOutput( + EndToEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + TargetRequest* targetReq, + CompileResult const& result) + { + SLANG_UNUSED(entryPoint); + + ISlangWriter* writer = compileRequest->getWriter(WriterChannel::StdOutput); + auto backEndReq = compileRequest->getBackEndReq(); + + switch (result.format) + { + case ResultFormat::Text: + writeOutputToConsole(writer, result.outputString); + break; + + case ResultFormat::Binary: + { + auto& data = result.outputBinary; + + if (writer->isConsole()) + { + // Writing to console, so we need to generate text output. + + switch (targetReq->target) + { + #if SLANG_ENABLE_DXBC_SUPPORT + case CodeGenTarget::DXBytecode: + { + String assembly; + dissassembleDXBC(backEndReq, + data.begin(), + data.end() - data.begin(), assembly); + writeOutputToConsole(writer, assembly); + } + break; + #endif + + #if SLANG_ENABLE_DXIL_SUPPORT + case CodeGenTarget::DXIL: + { + String assembly; + dissassembleDXILUsingDXC(backEndReq, + data.begin(), + data.end() - data.begin(), + assembly); + writeOutputToConsole(writer, assembly); + } + break; + #endif + + case CodeGenTarget::SPIRV: + { + String assembly; + dissassembleSPIRV(backEndReq, + data.begin(), + data.end() - data.begin(), assembly); + writeOutputToConsole(writer, assembly); + } + break; + + default: + SLANG_UNEXPECTED("unhandled output format"); + return; + } + } + else + { + // Redirecting stdout to a file, so do the usual thing + writer->setMode(SLANG_WRITER_MODE_BINARY); + + writeOutputFile( + backEndReq, + writer, + "stdout", + data.begin(), + data.end() - data.begin()); + } + } + break; + + default: + SLANG_UNEXPECTED("unhandled output format"); + break; + } + + } + + static void writeEntryPointResult( + EndToEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + TargetRequest* targetReq, + Int entryPointIndex) + { + auto program = compileRequest->getSpecializedProgram(); + auto targetProgram = program->getTargetProgram(targetReq); + auto backEndReq = compileRequest->getBackEndReq(); + + auto& result = targetProgram->getExistingEntryPointResult(entryPointIndex); + + // Skip the case with no output + if (result.format == ResultFormat::None) + return; + + // It is possible that we are dynamically discovering entry + // points (using `[shader(...)]` attributes), so that there + // might be entry points added to the program that did not + // get paths specified via command-line options. + // + RefPtr targetInfo; + if(compileRequest->targetInfos.TryGetValue(targetReq, targetInfo)) + { + String outputPath; + if(targetInfo->entryPointOutputPaths.TryGetValue(entryPointIndex, outputPath)) + { + writeEntryPointResultToFile(backEndReq, entryPoint, outputPath, result); + return; + } + } + + writeEntryPointResultToStandardOutput(compileRequest, entryPoint, targetReq, result); + } + + void generateOutputForTarget( + BackEndCompileRequest* compileReq, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq) + { + auto program = compileReq->getProgram(); + auto targetProgram = program->getTargetProgram(targetReq); + + // Generate target code any entry points that + // have been requested for compilation. + auto entryPointCount = program->getEntryPointCount(); + for(Index ii = 0; ii < entryPointCount; ++ii) + { + auto entryPoint = program->getEntryPoint(ii); + CompileResult entryPointResult = emitEntryPoint( + compileReq, + entryPoint, + ii, + targetReq, + endToEndReq); + targetProgram->setEntryPointResult(ii, entryPointResult); + } + } + + static void _generateOutput( + BackEndCompileRequest* compileRequest, + EndToEndCompileRequest* endToEndReq) + { + // Go through the code-generation targets that the user + // has specified, and generate code for each of them. + // + auto linkage = compileRequest->getLinkage(); + for (auto targetReq : linkage->targets) + { + generateOutputForTarget(compileRequest, targetReq, endToEndReq); + } + } + + void generateOutput( + BackEndCompileRequest* compileRequest) + { + _generateOutput(compileRequest, nullptr); + } + + void generateOutput( + EndToEndCompileRequest* compileRequest) + { + _generateOutput(compileRequest->getBackEndReq(), compileRequest); + + // If we are in command-line mode, we might be expected to actually + // write output to one or more files here. + + if (compileRequest->isCommandLineCompile) + { + auto linkage = compileRequest->getLinkage(); + auto program = compileRequest->getSpecializedProgram(); + for (auto targetReq : linkage->targets) + { + Index entryPointCount = program->getEntryPointCount(); + for (Index ee = 0; ee < entryPointCount; ++ee) + { + writeEntryPointResult( + compileRequest, + program->getEntryPoint(ee), + targetReq, + ee); + } + } + } + } + + // Debug logic for dumping intermediate outputs + + // + + void dumpIntermediate( + BackEndCompileRequest*, + void const* data, + size_t size, + char const* ext, + bool isBinary) + { + // Try to generate a unique ID for the file to dump, + // even in cases where there might be multiple threads + // doing compilation. + // + // This is primarily a debugging aid, so we don't + // really need/want to do anything too elaborate + + static uint32_t counter = 0; +#ifdef WIN32 + uint32_t id = InterlockedIncrement(&counter); +#else + // TODO: actually implement the case for other platforms + uint32_t id = counter++; +#endif + + String path; + path.append("slang-dump-"); + path.append(id); + path.append(ext); + + FILE* file = fopen(path.getBuffer(), isBinary ? "wb" : "w"); + if (!file) return; + + fwrite(data, size, 1, file); + fclose(file); + } + + void dumpIntermediateText( + BackEndCompileRequest* compileRequest, + void const* data, + size_t size, + char const* ext) + { + dumpIntermediate(compileRequest, data, size, ext, false); + } + + void dumpIntermediateBinary( + BackEndCompileRequest* compileRequest, + void const* data, + size_t size, + char const* ext) + { + dumpIntermediate(compileRequest, data, size, ext, true); + } + + void maybeDumpIntermediate( + BackEndCompileRequest* compileRequest, + void const* data, + size_t size, + CodeGenTarget target) + { + if (!compileRequest->shouldDumpIntermediates) + return; + + switch (target) + { + default: + break; + + case CodeGenTarget::HLSL: + dumpIntermediateText(compileRequest, data, size, ".hlsl"); + break; + + case CodeGenTarget::GLSL: + dumpIntermediateText(compileRequest, data, size, ".glsl"); + break; + + case CodeGenTarget::SPIRVAssembly: + dumpIntermediateText(compileRequest, data, size, ".spv.asm"); + break; + +#if 0 + case CodeGenTarget::SlangIRAssembly: + dumpIntermediateText(compileRequest, data, size, ".slang-ir.asm"); + break; +#endif + + case CodeGenTarget::SPIRV: + dumpIntermediateBinary(compileRequest, data, size, ".spv"); + { + String spirvAssembly; + dissassembleSPIRV(compileRequest, data, size, spirvAssembly); + dumpIntermediateText(compileRequest, spirvAssembly.begin(), spirvAssembly.getLength(), ".spv.asm"); + } + break; + + #if SLANG_ENABLE_DXBC_SUPPORT + case CodeGenTarget::DXBytecodeAssembly: + dumpIntermediateText(compileRequest, data, size, ".dxbc.asm"); + break; + + case CodeGenTarget::DXBytecode: + dumpIntermediateBinary(compileRequest, data, size, ".dxbc"); + { + String dxbcAssembly; + dissassembleDXBC(compileRequest, data, size, dxbcAssembly); + dumpIntermediateText(compileRequest, dxbcAssembly.begin(), dxbcAssembly.getLength(), ".dxbc.asm"); + } + break; + #endif + + #if SLANG_ENABLE_DXIL_SUPPORT + case CodeGenTarget::DXILAssembly: + dumpIntermediateText(compileRequest, data, size, ".dxil.asm"); + break; + + case CodeGenTarget::DXIL: + dumpIntermediateBinary(compileRequest, data, size, ".dxil"); + { + String dxilAssembly; + dissassembleDXILUsingDXC(compileRequest, data, size, dxilAssembly); + dumpIntermediateText(compileRequest, dxilAssembly.begin(), dxilAssembly.getLength(), ".dxil.asm"); + } + break; + #endif + } + } + + void maybeDumpIntermediate( + BackEndCompileRequest* compileRequest, + char const* text, + CodeGenTarget target) + { + if (!compileRequest->shouldDumpIntermediates) + return; + + maybeDumpIntermediate(compileRequest, text, strlen(text), target); + } + +} diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h new file mode 100644 index 000000000..9ea2f520c --- /dev/null +++ b/source/slang/slang-compiler.h @@ -0,0 +1,1423 @@ +#ifndef SLANG_COMPILER_H_INCLUDED +#define SLANG_COMPILER_H_INCLUDED + +#include "../core/slang-basic.h" +#include "../core/slang-shared-library.h" + +#include "../../slang-com-ptr.h" + +#include "slang-diagnostics.h" +#include "slang-name.h" +#include "slang-profile.h" +#include "slang-syntax.h" + +#include "../../slang.h" + +namespace Slang +{ + struct PathInfo; + struct IncludeHandler; + class ProgramLayout; + class PtrType; + class TargetProgram; + class TargetRequest; + class TypeLayout; + + enum class CompilerMode + { + ProduceLibrary, + ProduceShader, + GenerateChoice + }; + + enum class StageTarget + { + Unknown, + VertexShader, + HullShader, + DomainShader, + GeometryShader, + FragmentShader, + ComputeShader, + }; + + enum class CodeGenTarget + { + Unknown = SLANG_TARGET_UNKNOWN, + None = SLANG_TARGET_NONE, + GLSL = SLANG_GLSL, + GLSL_Vulkan = SLANG_GLSL_VULKAN, + GLSL_Vulkan_OneDesc = SLANG_GLSL_VULKAN_ONE_DESC, + HLSL = SLANG_HLSL, + SPIRV = SLANG_SPIRV, + SPIRVAssembly = SLANG_SPIRV_ASM, + DXBytecode = SLANG_DXBC, + DXBytecodeAssembly = SLANG_DXBC_ASM, + DXIL = SLANG_DXIL, + DXILAssembly = SLANG_DXIL_ASM, + CSource = SLANG_C_SOURCE, + CPPSource = SLANG_CPP_SOURCE, + }; + + enum class ContainerFormat + { + None = SLANG_CONTAINER_FORMAT_NONE, + SlangModule = SLANG_CONTAINER_FORMAT_SLANG_MODULE, + }; + + enum class LineDirectiveMode : SlangLineDirectiveMode + { + Default = SLANG_LINE_DIRECTIVE_MODE_DEFAULT, + None = SLANG_LINE_DIRECTIVE_MODE_NONE, + Standard = SLANG_LINE_DIRECTIVE_MODE_STANDARD, + GLSL = SLANG_LINE_DIRECTIVE_MODE_GLSL, + }; + + enum class ResultFormat + { + None, + Text, + Binary + }; + + // When storing the layout for a matrix-type + // value, we need to know whether it has been + // laid out with row-major or column-major + // storage. + // + enum MatrixLayoutMode + { + kMatrixLayoutMode_RowMajor = SLANG_MATRIX_LAYOUT_ROW_MAJOR, + kMatrixLayoutMode_ColumnMajor = SLANG_MATRIX_LAYOUT_COLUMN_MAJOR, + }; + + enum class DebugInfoLevel : SlangDebugInfoLevel + { + None = SLANG_DEBUG_INFO_LEVEL_NONE, + Minimal = SLANG_DEBUG_INFO_LEVEL_MINIMAL, + Standard = SLANG_DEBUG_INFO_LEVEL_STANDARD, + Maximal = SLANG_DEBUG_INFO_LEVEL_MAXIMAL, + }; + + enum class OptimizationLevel : SlangOptimizationLevel + { + None = SLANG_OPTIMIZATION_LEVEL_NONE, + Default = SLANG_OPTIMIZATION_LEVEL_DEFAULT, + High = SLANG_OPTIMIZATION_LEVEL_HIGH, + Maximal = SLANG_OPTIMIZATION_LEVEL_MAXIMAL, + }; + + class Linkage; + class Module; + class Program; + class FrontEndCompileRequest; + class BackEndCompileRequest; + class EndToEndCompileRequest; + class TranslationUnitRequest; + + // Result of compiling an entry point. + // Should only ever be string OR binary. + class CompileResult + { + public: + CompileResult() = default; + CompileResult(String const& str) : format(ResultFormat::Text), outputString(str) {} + CompileResult(List const& buffer) : format(ResultFormat::Binary), outputBinary(buffer) {} + + void append(CompileResult const& result); + + ComPtr getBlob(); + + ResultFormat format = ResultFormat::None; + String outputString; + List outputBinary; + + ComPtr blob; + }; + + /// Information collected about global or entry-point shader parameters + struct ShaderParamInfo + { + DeclRef paramDeclRef; + UInt firstExistentialTypeSlot = 0; + UInt existentialTypeSlotCount = 0; + }; + + /// Extended information specific to global shader parameters + struct GlobalShaderParamInfo : ShaderParamInfo + { + // Additional global-scope declarations that are conceptually + // declaring the "same" parameter as the `paramDeclRef`. + List> additionalParamDeclRefs; + }; + + /// A request for the front-end to find and validate an entry-point function + struct FrontEndEntryPointRequest : RefObject + { + public: + /// Create a request for an entry point. + FrontEndEntryPointRequest( + FrontEndCompileRequest* compileRequest, + int translationUnitIndex, + Name* name, + Profile profile); + + /// Get the parent front-end compile request. + FrontEndCompileRequest* getCompileRequest() { return m_compileRequest; } + + /// Get the translation unit that contains the entry point. + TranslationUnitRequest* getTranslationUnit(); + + /// Get the name of the entry point to find. + Name* getName() { return m_name; } + + /// Get the stage that the entry point is to be compiled for + Stage getStage() { return m_profile.GetStage(); } + + /// Get the profile that the entry point is to be compiled for + Profile getProfile() { return m_profile; } + + private: + // The parent compile request + FrontEndCompileRequest* m_compileRequest; + + // The index of the translation unit that will hold the entry point + int m_translationUnitIndex; + + // The name of the entry point function to look for + Name* m_name; + + // The profile to compile for (including stage) + Profile m_profile; + }; + + /// Tracks an ordered list of modules that something depends on. + struct ModuleDependencyList + { + public: + /// Get the list of modules that are depended on. + List> const& getModuleList() { return m_moduleList; } + + /// Add a module and everything it depends on to the list. + void addDependency(Module* module); + + /// Add a module to the list, but not the modules it depends on. + void addLeafDependency(Module* module); + + private: + void _addDependency(Module* module); + + List> m_moduleList; + HashSet m_moduleSet; + }; + + /// Tracks an unordered list of filesystem paths that something depends on + struct FilePathDependencyList + { + public: + /// Get the list of paths that are depended on. + List const& getFilePathList() { return m_filePathList; } + + /// Add a path to the list, if it is not already present + void addDependency(String const& path); + + /// Add all of the paths that `module` depends on to the list + void addDependency(Module* module); + + private: + + // TODO: We are using a `HashSet` here to deduplicate + // the paths so that we don't return the same path + // multiple times from `getFilePathList`, but because + // order isn't important, we could potentially do better + // in terms of memory (at some cost in performance) by + // just sorting the `m_filePathList` every once in + // a while and then deduplicating. + + List m_filePathList; + HashSet m_filePathSet; + }; + + /// Describes an entry point for the purposes of layout and code generation. + /// + /// This class also tracks any generic arguments to the entry point, + /// in the case that it is a specialization of a generic entry point. + /// + /// There is also a provision for creating a "dummy" entry point for + /// the purposes of pass-through compilation modes. Only the + /// `getName()` and `getProfile()` methods should be expected to + /// return useful data on pass-through entry points. + /// + class EntryPoint : public RefObject + { + public: + /// Create an entry point that refers to the given function. + static RefPtr create( + DeclRef funcDeclRef, + Profile profile); + + /// Get the function decl-ref, including any generic arguments. + DeclRef getFuncDeclRef() { return m_funcDeclRef; } + + /// Get the function declaration (without generic arguments). + RefPtr getFuncDecl() { return m_funcDeclRef.getDecl(); } + + /// Get the name of the entry point + Name* getName() { return m_name; } + + /// Get the profile associated with the entry point + /// + /// Note: only the stage part of the profile is expected + /// to contain useful data, but certain legacy code paths + /// allow for "shader model" information to come via this path. + /// + Profile getProfile() { return m_profile; } + + /// Get the stage that the entry point is for. + Stage getStage() { return m_profile.GetStage(); } + + /// Get the module that contains the entry point. + Module* getModule(); + + /// Get the linkage that contains the module for this entry point. + Linkage* getLinkage(); + + /// Get a list of modules that this entry point depends on. + /// + /// This will include the module that defines the entry point (see `getModule()`), + /// but may also include modules that are required by its generic type arguments. + /// + List> getModuleDependencies() { return m_dependencyList.getModuleList(); } + + /// Get a list of tagged-union types referenced by the entry point's generic parameters. + List> const& getTaggedUnionTypes() { return m_taggedUnionTypes; } + + /// Create a dummy `EntryPoint` that is only usable for pass-through compilation. + static RefPtr createDummyForPassThrough( + Name* name, + Profile profile); + + /// Get the number of existential type parameters for the entry point. + Index getExistentialTypeParamCount() { return m_existentialSlots.paramTypes.getCount(); } + + /// Get the existential type parameter at `index`. + Type* getExistentialTypeParam(Index index) { return m_existentialSlots.paramTypes[index]; } + + /// Get the number of arguments supplied for existential type parameters. + /// + /// Note that the number of arguments may not match the number of parameters. + /// In particular, an unspecialized entry point may have many parameters, but zero arguments. + Index getExistentialTypeArgCount() { return m_existentialSlots.args.getCount(); } + + /// Get the existential type argument (type and witness table) at `index`. + ExistentialTypeSlots::Arg getExistentialTypeArg(Index index) { return m_existentialSlots.args[index]; } + + /// Get an array of all existential type arguments. + ExistentialTypeSlots::Arg const* getExistentialTypeArgs() { return m_existentialSlots.args.getBuffer(); } + + /// Get an array of all entry-point shader parameters. + List const& getShaderParams() { return m_shaderParams; } + + void _specializeExistentialTypeParams( + List> const& args, + DiagnosticSink* sink); + + private: + EntryPoint( + Name* name, + Profile profile, + DeclRef funcDeclRef); + + void _collectShaderParams(); + + // The name of the entry point function (e.g., `main`) + // + Name* m_name = nullptr; + + // The declaration of the entry-point function itself. + // + DeclRef m_funcDeclRef; + + /// The existential/interface slots associated with the entry point parameter scope. + ExistentialTypeSlots m_existentialSlots; + + /// Information about entry-point parameters + List m_shaderParams; + + // The profile that the entry point will be compiled for + // (this is a combination of the target stage, and also + // a feature level that sets capabilities) + // + // Note: the profile-version part of this should probably + // be moving towards deprecation, in favor of the version + // information (e.g., "Shader Model 5.1") always coming + // from the target, while the stage part is all that is + // intrinsic to the entry point. + // + Profile m_profile; + + // Any tagged union types that were referenced by the generic arguments of the entry point. + List> m_taggedUnionTypes; + + // Modules the entry point depends on. + ModuleDependencyList m_dependencyList; + }; + + enum class PassThroughMode : SlangPassThrough + { + None = SLANG_PASS_THROUGH_NONE, // don't pass through: use Slang compiler + fxc = SLANG_PASS_THROUGH_FXC, // pass through HLSL to `D3DCompile` API + dxc = SLANG_PASS_THROUGH_DXC, // pass through HLSL to `IDxcCompiler` API + glslang = SLANG_PASS_THROUGH_GLSLANG, // pass through GLSL to `glslang` library + }; + + class SourceFile; + + /// A module of code that has been compiled through the front-end + /// + /// A module comprises all the code from one translation unit (which + /// may span multiple Slang source files), and provides access + /// to both the AST and IR representations of that code. + /// + class Module : public RefObject + { + public: + /// Create a module (initially empty). + Module(Linkage* linkage); + + /// Get the parent linkage of this module. + Linkage* getLinkage() { return m_linkage; } + + /// Get the AST for the module (if it has been parsed) + ModuleDecl* getModuleDecl() { return m_moduleDecl; } + + /// The the IR for the module (if it has been generated) + IRModule* getIRModule() { return m_irModule; } + + /// Get the list of other modules this module depends on + List> const& getModuleDependencyList() { return m_moduleDependencyList.getModuleList(); } + + /// Get the list of filesystem paths this module depends on + List const& getFilePathDependencyList() { return m_filePathDependencyList.getFilePathList(); } + + /// Register a module that this module depends on + void addModuleDependency(Module* module); + + /// Register a filesystem path that this module depends on + void addFilePathDependency(String const& path); + + /// Set the AST for this module. + /// + /// This should only be called once, during creation of the module. + /// + void setModuleDecl(ModuleDecl* moduleDecl) { m_moduleDecl = moduleDecl; } + + /// Set the IR for this module. + /// + /// This should only be called once, during creation of the module. + /// + void setIRModule(IRModule* irModule) { m_irModule = irModule; } + + private: + // The parent linkage + Linkage* m_linkage = nullptr; + + // The AST for the module + RefPtr m_moduleDecl; + + // The IR for the module + RefPtr m_irModule = nullptr; + + // List of modules this module depends on + ModuleDependencyList m_moduleDependencyList; + + // List of filesystem paths this module depends on + FilePathDependencyList m_filePathDependencyList; + }; + typedef Module LoadedModule; + + /// A request for the front-end to compile a translation unit. + class TranslationUnitRequest : public RefObject + { + public: + TranslationUnitRequest( + FrontEndCompileRequest* compileRequest); + + // The parent compile request + FrontEndCompileRequest* compileRequest = nullptr; + + // The language in which the source file(s) + // are assumed to be written + SourceLanguage sourceLanguage = SourceLanguage::Unknown; + + // The source file(s) that will be compiled to form this translation unit + // + // Usually, for HLSL or GLSL there will be only one file. + List m_sourceFiles; + + List const& getSourceFiles() { return m_sourceFiles; } + void addSourceFile(SourceFile* sourceFile); + + // The entry points associated with this translation unit + List> entryPoints; + + // Preprocessor definitions to use for this translation unit only + // (whereas the ones on `compileRequest` will be shared) + Dictionary preprocessorDefinitions; + + /// The name that will be used for the module this translation unit produces. + Name* moduleName = nullptr; + + /// Result of compiling this translation unit (a module) + RefPtr module; + + Module* getModule() { return module; } + RefPtr getModuleDecl() { return module->getModuleDecl(); } + + Session* getSession(); + NamePool* getNamePool(); + SourceManager* getSourceManager(); + }; + + enum class FloatingPointMode : SlangFloatingPointMode + { + Default = SLANG_FLOATING_POINT_MODE_DEFAULT, + Fast = SLANG_FLOATING_POINT_MODE_FAST, + Precise = SLANG_FLOATING_POINT_MODE_PRECISE, + }; + + enum class WriterChannel : SlangWriterChannel + { + Diagnostic = SLANG_WRITER_CHANNEL_DIAGNOSTIC, + StdOutput = SLANG_WRITER_CHANNEL_STD_OUTPUT, + StdError = SLANG_WRITER_CHANNEL_STD_ERROR, + CountOf = SLANG_WRITER_CHANNEL_COUNT_OF, + }; + + enum class WriterMode : SlangWriterMode + { + Text = SLANG_WRITER_MODE_TEXT, + Binary = SLANG_WRITER_MODE_BINARY, + }; + + /// A request to generate output in some target format. + class TargetRequest : public RefObject + { + public: + Linkage* linkage; + CodeGenTarget target; + SlangTargetFlags targetFlags = 0; + Slang::Profile targetProfile = Slang::Profile(); + FloatingPointMode floatingPointMode = FloatingPointMode::Default; + + Linkage* getLinkage() { return linkage; } + CodeGenTarget getTarget() { return target; } + Profile getTargetProfile() { return targetProfile; } + FloatingPointMode getFloatingPointMode() { return floatingPointMode; } + + Session* getSession(); + MatrixLayoutMode getDefaultMatrixLayoutMode(); + + // TypeLayouts created on the fly by reflection API + Dictionary> typeLayouts; + + Dictionary>& getTypeLayouts() { return typeLayouts; } + }; + + /// Are we generating code for a D3D API? + bool isD3DTarget(TargetRequest* targetReq); + + /// Are we generating code for a Khronos API (OpenGL or Vulkan)? + bool isKhronosTarget(TargetRequest* targetReq); + + // Compute the "effective" profile to use when outputting the given entry point + // for the chosen code-generation target. + // + // The stage of the effective profile will always come from the entry point, while + // the profile version (aka "shader model") will be computed as follows: + // + // - If the entry point and target belong to the same profile family, then take + // the latest version between the two (e.g., if the entry point specified `ps_5_1` + // and the target specifies `sm_5_0` then use `sm_5_1` as the version). + // + // - If the entry point and target disagree on the profile family, always use the + // profile family and version from the target. + // + Profile getEffectiveProfile(EntryPoint* entryPoint, TargetRequest* target); + + + // A directory to be searched when looking for files (e.g., `#include`) + struct SearchDirectory + { + SearchDirectory() = default; + SearchDirectory(SearchDirectory const& other) = default; + SearchDirectory(String const& path) + : path(path) + {} + + String path; + }; + + /// A list of directories to search for files (e.g., `#include`) + struct SearchDirectoryList + { + // A parent list that should also be searched + SearchDirectoryList* parent = nullptr; + + // Directories to be searched + List searchDirectories; + }; + + /// Create a blob that will retain (a copy of) raw data. + /// + ComPtr createRawBlob(void const* data, size_t size); + + /// A context for loading and re-using code modules. + class Linkage : public RefObject + { + public: + /// Create an initially-empty linkage + Linkage(Session* session); + + /// Get the parent session for this linkage + Session* getSession() { return m_session; } + + // Information on the targets we are being asked to + // generate code for. + List> targets; + + // Directories to search for `#include` files or `import`ed modules + SearchDirectoryList searchDirectories; + + SearchDirectoryList const& getSearchDirectories() { return searchDirectories; } + + // Definitions to provide during preprocessing + Dictionary preprocessorDefinitions; + + // Source manager to help track files loaded + SourceManager m_defaultSourceManager; + SourceManager* m_sourceManager = nullptr; + + // Name pool for looking up names + NamePool namePool; + + NamePool* getNamePool() { return &namePool; } + + // Modules that have been dynamically loaded via `import` + // + // This is a list of unique modules loaded, in the order they were encountered. + List > loadedModulesList; + + // Map from the path (or uniqueIdentity if available) of a module file to its definition + Dictionary> mapPathToLoadedModule; + + // Map from the logical name of a module to its definition + Dictionary> mapNameToLoadedModules; + + // The resulting specialized IR module for each entry point request + List> compiledModules; + + /// File system implementation to use when loading files from disk. + /// + /// If this member is `null`, a default implementation that tries + /// to use the native OS filesystem will be used instead. + /// + ComPtr fileSystem; + + /// The extended file system implementation. Will be set to a default implementation + /// if fileSystem is nullptr. Otherwise it will either be fileSystem's interface, + /// or a wrapped impl that makes fileSystem operate as fileSystemExt + ComPtr fileSystemExt; + + ISlangFileSystemExt* getFileSystemExt() { return fileSystemExt; } + + /// Load a file into memory using the configured file system. + /// + /// @param path The path to attempt to load from + /// @param outBlob A destination pointer to receive the loaded blob + /// @returns A `SlangResult` to indicate success or failure. + /// + SlangResult loadFile(String const& path, ISlangBlob** outBlob); + + + RefPtr parseTypeString(String typeStr, RefPtr scope); + + Type* specializeType( + Type* unspecializedType, + Int argCount, + Type* const* args, + DiagnosticSink* sink); + + /// Add a mew target amd return its index. + UInt addTarget( + CodeGenTarget target); + + RefPtr loadModule( + Name* name, + const PathInfo& filePathInfo, + ISlangBlob* fileContentsBlob, + SourceLoc const& loc, + DiagnosticSink* sink); + + void loadParsedModule( + RefPtr translationUnit, + Name* name, + PathInfo const& pathInfo); + + /// Load a module of the given name. + Module* loadModule(String const& name); + + RefPtr findOrImportModule( + Name* name, + SourceLoc const& loc, + DiagnosticSink* sink); + + SourceManager* getSourceManager() + { + return m_sourceManager; + } + + /// Override the source manager for the linakge. + /// + /// This is only used to install a temporary override when + /// parsing stuff from strings (where we don't want to retain + /// full source files for the parsed result). + /// + /// TODO: We should remove the need for this hack. + /// + void setSourceManager(SourceManager* sourceManager) + { + m_sourceManager = sourceManager; + } + + void setFileSystem(ISlangFileSystem* fileSystem); + + /// The layout to use for matrices by default (row/column major) + MatrixLayoutMode defaultMatrixLayoutMode = kMatrixLayoutMode_ColumnMajor; + MatrixLayoutMode getDefaultMatrixLayoutMode() { return defaultMatrixLayoutMode; } + + DebugInfoLevel debugInfoLevel = DebugInfoLevel::None; + + OptimizationLevel optimizationLevel = OptimizationLevel::Default; + + private: + Session* m_session = nullptr; + + /// Tracks state of modules currently being loaded. + /// + /// This information is used to diagnose cases where + /// a user tries to recursively import the same module + /// (possibly along a transitive chain of `import`s). + /// + struct ModuleBeingImportedRAII + { + public: + ModuleBeingImportedRAII( + Linkage* linkage, + Module* module) + : linkage(linkage) + , module(module) + { + next = linkage->m_modulesBeingImported; + linkage->m_modulesBeingImported = this; + } + + ~ModuleBeingImportedRAII() + { + linkage->m_modulesBeingImported = next; + } + + Linkage* linkage; + Module* module; + ModuleBeingImportedRAII* next; + }; + + // Any modules currently being imported will be listed here + ModuleBeingImportedRAII* m_modulesBeingImported = nullptr; + + /// Is the given module in the middle of being imported? + bool isBeingImported(Module* module); + + List> m_specializedTypes; + }; + + /// Shared functionality between front- and back-end compile requests. + /// + /// This is the base class for both `FrontEndCompileRequest` and + /// `BackEndCompileRequest`, and allows a small number of parts of + /// the compiler to be easily invocable from either front-end or + /// back-end work. + /// + class CompileRequestBase : public RefObject + { + // TODO: We really shouldn't need this type in the long run. + // The few places that rely on it should be refactored to just + // depend on the underlying information (a linkage and a diagnostic + // sink) directly. + // + // The flags to control dumping and validation of IR should be + // moved to some kind of shared settings/options `struct` that + // both front-end and back-end requests can store. + + public: + Session* getSession(); + Linkage* getLinkage() { return m_linkage; } + DiagnosticSink* getSink() { return m_sink; } + SourceManager* getSourceManager() { return getLinkage()->getSourceManager(); } + NamePool* getNamePool() { return getLinkage()->getNamePool(); } + ISlangFileSystemExt* getFileSystemExt() { return getLinkage()->getFileSystemExt(); } + SlangResult loadFile(String const& path, ISlangBlob** outBlob) { return getLinkage()->loadFile(path, outBlob); } + + bool shouldDumpIR = false; + bool shouldValidateIR = false; + + protected: + CompileRequestBase( + Linkage* linkage, + DiagnosticSink* sink); + + private: + Linkage* m_linkage = nullptr; + DiagnosticSink* m_sink = nullptr; + }; + + /// A request to compile source code to an AST + IR. + class FrontEndCompileRequest : public CompileRequestBase + { + public: + FrontEndCompileRequest( + Linkage* linkage, + DiagnosticSink* sink); + + int addEntryPoint( + int translationUnitIndex, + String const& name, + Profile entryPointProfile); + + // Translation units we are being asked to compile + List > translationUnits; + + RefPtr getTranslationUnit(UInt index) { return translationUnits[index]; } + + // Compile flags to be shared by all translation units + SlangCompileFlags compileFlags = 0; + + // If true then generateIR will serialize out IR, and serialize back in again. Making + // serialization a bottleneck or firewall between the front end and the backend + bool useSerialIRBottleneck = false; + + // If true will serialize and de-serialize with debug information + bool verifyDebugSerialization = false; + + List> m_entryPointReqs; + + List> const& getEntryPointReqs() { return m_entryPointReqs; } + UInt getEntryPointReqCount() { return m_entryPointReqs.getCount(); } + FrontEndEntryPointRequest* getEntryPointReq(UInt index) { return m_entryPointReqs[index]; } + + // Directories to search for `#include` files or `import`ed modules + // NOTE! That for now these search directories are not settable via the API + // so the search directories on Linkage is used for #include as well as for modules. + SearchDirectoryList searchDirectories; + + SearchDirectoryList const& getSearchDirectories() { return searchDirectories; } + + // Definitions to provide during preprocessing + Dictionary preprocessorDefinitions; + + void parseTranslationUnit( + TranslationUnitRequest* translationUnit); + + // Perform primary semantic checking on all + // of the translation units in the program + void checkAllTranslationUnits(); + + void generateIR(); + + SlangResult executeActionsInner(); + + /// Add a translation unit to be compiled. + /// + /// @param language The source language that the translation unit will use (e.g., `SourceLanguage::Slang` + /// @param moduleName The name that will be used for the module compile from the translation unit. + /// @return The zero-based index of the translation unit in this compile request. + int addTranslationUnit(SourceLanguage language, Name* moduleName); + + /// Add a translation unit to be compiled. + /// + /// @param language The source language that the translation unit will use (e.g., `SourceLanguage::Slang` + /// @return The zero-based index of the translation unit in this compile request. + /// + /// The module name for the translation unit will be automatically generated. + /// If all translation units in a compile request use automatically generated + /// module names, then they are guaranteed not to conflict with one another. + /// + int addTranslationUnit(SourceLanguage language); + + void addTranslationUnitSourceFile( + int translationUnitIndex, + SourceFile* sourceFile); + + void addTranslationUnitSourceBlob( + int translationUnitIndex, + String const& path, + ISlangBlob* sourceBlob); + + void addTranslationUnitSourceString( + int translationUnitIndex, + String const& path, + String const& source); + + void addTranslationUnitSourceFile( + int translationUnitIndex, + String const& path); + + Program* getProgram() { return m_program; } + + private: + RefPtr m_program; + }; + + /// A collection of code modules and entry points that are intended to be used together. + /// + /// A `Program` establishes that certain pieces of code are intended + /// to be used togehter so that, e.g., layout can make sure to allocate + /// space for the global shader parameters in all referenced modules. + /// + class Program : public RefObject + { + public: + /// Create a new program, initially empty. + /// + /// All code loaded into the program must come + /// from the given `linkage`. + Program( + Linkage* linkage); + + /// Get the linkage that this program uses. + Linkage* getLinkage() { return m_linkage; } + + /// Get the number of entry points added to the program + Index getEntryPointCount() { return m_entryPoints.getCount(); } + + /// Get the entry point at the given `index`. + RefPtr getEntryPoint(Index index) { return m_entryPoints[index]; } + + /// Get the full ist of entry points on the program. + List> const& getEntryPoints() { return m_entryPoints; } + + /// Get the substitution (if any) that represents how global generics are specialized. + RefPtr getGlobalGenericSubstitution() { return m_globalGenericSubst; } + + /// Get the full list of modules this program depends on + List> getModuleDependencies() { return m_moduleDependencyList.getModuleList(); } + + /// Get the full list of filesystem paths this program depends on + List getFilePathDependencies() { return m_filePathDependencyList.getFilePathList(); } + + /// Get the target-specific version of this program for the given `target`. + /// + /// The `target` must be a target on the `Linkage` that was used to create this program. + TargetProgram* getTargetProgram(TargetRequest* target); + + /// Add a module (and everything it depends on) to the list of references + void addReferencedModule(Module* module); + + /// Add a module (but not the things it depends on) to the list of references + /// + /// This is a compatiblity hack for legacy compiler behavior. + void addReferencedLeafModule(Module* module); + + + /// Add an entry point to the program + /// + /// This also adds everything the entry point depends on to the list of references. + /// + void addEntryPoint(EntryPoint* entryPoint); + + /// Set the global generic argument substitution to use. + void setGlobalGenericSubsitution(RefPtr subst) + { + m_globalGenericSubst = subst; + } + + /// Parse a type from a string, in the context of this program. + /// + /// Any names in the string will be resolved using the modules + /// referenced by the program. + /// + /// On an error, returns null and reports diagnostic messages + /// to the provided `sink`. + /// + Type* getTypeFromString(String typeStr, DiagnosticSink* sink); + + /// Get the IR module that represents this program and its entry points. + /// + /// The IR module for a program tries to be minimal, and in the + /// common case will only include symbols with `[import]` declarations + /// for the entry point(s) of the program, and any types they + /// depend on. + /// + /// This IR module is intended to be linked against the IR modules + /// for all of the dependencies (see `getModuleDependencies()`) to + /// provide complete code. + /// + RefPtr getOrCreateIRModule(DiagnosticSink* sink); + + /// Get the number of existential type parameters for the program. + Index getExistentialTypeParamCount() { return m_globalExistentialSlots.paramTypes.getCount(); } + + /// Get the existential type parameter at `index`. + Type* getExistentialTypeParam(Index index) { return m_globalExistentialSlots.paramTypes[index]; } + + /// Get the number of arguments supplied for existential type parameters. + /// + /// Note that the number of arguments may not match the number of parameters. + /// In particular, an unspecialized program may have many parameters, but zero arguments. + Index getExistentialTypeArgCount() { return m_globalExistentialSlots.args.getCount(); } + + /// Get the existential type argument (type and witness table) at `index`. + ExistentialTypeSlots::Arg getExistentialTypeArg(Index index) { return m_globalExistentialSlots.args[index]; } + + /// Get an array of all existential type arguments. + ExistentialTypeSlots::Arg const* getExistentialTypeArgs() { return m_globalExistentialSlots.args.getBuffer(); } + + /// Get an array of all global shader parameters. + List const& getShaderParams() { return m_shaderParams; } + + void _collectShaderParams(DiagnosticSink* sink); + void _specializeExistentialTypeParams( + List> const& args, + DiagnosticSink* sink); + + private: + + // The linakge this program is associated with. + // + // Note that a `Program` keeps its associated linkage alive, + // and not vice versa. + // + RefPtr m_linkage; + + // Tracking data for the list of modules dependend on + ModuleDependencyList m_moduleDependencyList; + + // Tracking data for the list of filesystem paths dependend on + FilePathDependencyList m_filePathDependencyList; + + // Entry points that are part of the program. + List > m_entryPoints; + + // Specializations for global generic parameters (if any) + RefPtr m_globalGenericSubst; + + // The existential/interface slots associated with the global scope. + ExistentialTypeSlots m_globalExistentialSlots; + + /// Information about global shader parameters + List m_shaderParams; + + // Generated IR for this program. + RefPtr m_irModule; + + // Cache of target-specific programs for each target. + Dictionary> m_targetPrograms; + + // Any types looked up dynamically using `getTypeFromString` + Dictionary> m_types; + }; + + /// A `Program` specialized for a particular `TargetRequest` + class TargetProgram : public RefObject + { + public: + TargetProgram( + Program* program, + TargetRequest* targetReq); + + /// Get the underlying program + Program* getProgram() { return m_program; } + + /// Get the underlying target + TargetRequest* getTargetReq() { return m_targetReq; } + + /// Get the layout for the program on the target. + /// + /// If this is the first time the layout has been + /// requested, report any errors that arise during + /// layout to the given `sink`. + /// + ProgramLayout* getOrCreateLayout(DiagnosticSink* sink); + + /// Get the layout for the program on the taarget. + /// + /// This routine assumes that `getOrCreateLayout` + /// has already been called previously. + /// + ProgramLayout* getExistingLayout() + { + SLANG_ASSERT(m_layout); + return m_layout; + } + + /// Get the compiled code for an entry point on the target. + /// + /// This routine assumes code generation has already been + /// performed and called `setEntryPointResult`. + /// + CompileResult& getExistingEntryPointResult(Int entryPointIndex) + { + return m_entryPointResults[entryPointIndex]; + } + + // TODO: Need a lazy `getOrCreateEntryPointResult` + + /// Set the compiled code for an entry point. + /// + /// Should only be called by code generation. + void setEntryPointResult(Int entryPointIndex, CompileResult const& result) + { + m_entryPointResults[entryPointIndex] = result; + } + + private: + // The program being compiled or laid out + Program* m_program; + + // The target that code/layout will be generated for + TargetRequest* m_targetReq; + + // The computed layout, if it has been generated yet + RefPtr m_layout; + + // Generated compile results for each entry point + // in the parent `Program` (indexing matches + // the order they are given in the `Program`) + List m_entryPointResults; + }; + + /// A request to generate code for a program + class BackEndCompileRequest : public CompileRequestBase + { + public: + BackEndCompileRequest( + Linkage* linkage, + DiagnosticSink* sink, + Program* program = nullptr); + + // Should we dump intermediate results along the way, for debugging? + bool shouldDumpIntermediates = false; + + // How should `#line` directives be emitted (if at all)? + LineDirectiveMode lineDirectiveMode = LineDirectiveMode::Default; + + LineDirectiveMode getLineDirectiveMode() { return lineDirectiveMode; } + + Program* getProgram() { return m_program; } + void setProgram(Program* program) { m_program = program; } + + // Should R/W images without explicit formats be assumed to have "unknown" format? + // + // The default behavior is to make a best-effort guess as to what format is intended. + // + bool useUnknownImageFormatAsDefault = false; + + private: + RefPtr m_program; + }; + + /// A compile request that spans the front and back ends of the compiler + /// + /// This is what the command-line `slangc` uses, as well as the legacy + /// C API. It ties together the functionality of `Linkage`, + /// `FrontEndCompileRequest`, and `BackEndCompileRequest`, plus a small + /// number of additional features that primarily make sense for + /// command-line usage. + /// + class EndToEndCompileRequest : public RefObject + { + public: + EndToEndCompileRequest( + Session* session); + + // What container format are we being asked to generate? + // + // Note: This field is unused except by the options-parsing + // logic; it exists to support wriiting out binary modules + // once that feature is ready. + // + ContainerFormat containerFormat = ContainerFormat::None; + + // Path to output container to + // + // Note: This field exists to support wriiting out binary modules + // once that feature is ready. + // + String containerOutputPath; + + // Should we just pass the input to another compiler? + PassThroughMode passThrough = PassThroughMode::None; + + /// Source code for the generic arguments to use for the global generic parameters of the program. + List globalGenericArgStrings; + + /// Types to use to fill global existential "slots" + List globalExistentialSlotArgStrings; + + bool shouldSkipCodegen = false; + + // Are we being driven by the command-line `slangc`, and should act accordingly? + bool isCommandLineCompile = false; + + String mDiagnosticOutput; + + /// A blob holding the diagnostic output + ComPtr diagnosticOutputBlob; + + /// Per-entry-point information not tracked by other compile requests + class EntryPointInfo : public RefObject + { + public: + /// Source code for the generic arguments to use for the generic parameters of the entry point. + List genericArgStrings; + + /// Source code for the type arguments to plug into the existential type "slots" of the entry point + List existentialArgStrings; + }; + List entryPoints; + + /// Per-target information only needed for command-line compiles + class TargetInfo : public RefObject + { + public: + // Requested output paths for each entry point. + // An empty string indices no output desired for + // the given entry point. + Dictionary entryPointOutputPaths; + }; + Dictionary> targetInfos; + + Linkage* getLinkage() { return m_linkage; } + + int addEntryPoint( + int translationUnitIndex, + String const& name, + Profile profile, + List const & genericTypeNames); + + void setWriter(WriterChannel chan, ISlangWriter* writer); + ISlangWriter* getWriter(WriterChannel chan) const { return m_writers[int(chan)]; } + + SlangResult executeActionsInner(); + SlangResult executeActions(); + + Session* getSession() { return m_session; } + DiagnosticSink* getSink() { return &m_sink; } + NamePool* getNamePool() { return getLinkage()->getNamePool(); } + + FrontEndCompileRequest* getFrontEndReq() { return m_frontEndReq; } + BackEndCompileRequest* getBackEndReq() { return m_backEndReq; } + Program* getUnspecializedProgram() { return getFrontEndReq()->getProgram(); } + Program* getSpecializedProgram() { return m_specializedProgram; } + + private: + Session* m_session = nullptr; + RefPtr m_linkage; + DiagnosticSink m_sink; + RefPtr m_frontEndReq; + RefPtr m_unspecializedProgram; + RefPtr m_specializedProgram; + RefPtr m_backEndReq; + + // For output + ComPtr m_writers[SLANG_WRITER_CHANNEL_COUNT_OF]; + }; + + void generateOutput( + BackEndCompileRequest* compileRequest); + + void generateOutput( + EndToEndCompileRequest* compileRequest); + + // Helper to dump intermediate output when debugging + void maybeDumpIntermediate( + BackEndCompileRequest* compileRequest, + void const* data, + size_t size, + CodeGenTarget target); + void maybeDumpIntermediate( + BackEndCompileRequest* compileRequest, + char const* text, + CodeGenTarget target); + + /* Returns SLANG_OK if a codeGen target is available. */ + SlangResult checkCompileTargetSupport(Session* session, CodeGenTarget target); + /* Returns SLANG_OK if pass through support is available */ + SlangResult checkExternalCompilerSupport(Session* session, PassThroughMode passThrough); + + /* Report an error appearing from external compiler to the diagnostic sink error to the diagnostic sink. + @param compilerName The name of the compiler the error came for (or nullptr if not known) + @param res Result associated with the error. The error code will be reported. (Can take HRESULT - and will expand to string if known) + @param diagnostic The diagnostic string associated with the compile failure + @param sink The diagnostic sink to report to */ + void reportExternalCompileError(const char* compilerName, SlangResult res, const UnownedStringSlice& diagnostic, DiagnosticSink* sink); + + /* Determines a suitable filename to identify the input for a given entry point being compiled. + If the end-to-end compile is a pass-through case, will attempt to find the (unique) source file + pathname for the translation unit containing the entry point at `entryPointIndex. + If the compilation is not in a pass-through case, then always returns `"slang-generated"`. + @param endToEndReq The end-to-end compile request which might be using pass-through copmilation + @param entryPointIndex The index of the entry point to compute a filename for. + @return the appropriate source filename */ + String calcSourcePathForEntryPoint(EndToEndCompileRequest* endToEndReq, UInt entryPointIndex); + + struct TypeCheckingCache; + // + + class Session + { + public: + enum class SharedLibraryFuncType + { + Glslang_Compile, + Fxc_D3DCompile, + Fxc_D3DDisassemble, + Dxc_DxcCreateInstance, + CountOf, + }; + + // + + RefPtr baseLanguageScope; + RefPtr coreLanguageScope; + RefPtr hlslLanguageScope; + RefPtr slangLanguageScope; + + List> loadedModuleCode; + + SourceManager builtinSourceManager; + + SourceManager* getBuiltinSourceManager() { return &builtinSourceManager; } + + // Name pool stuff for unique-ing identifiers + + RootNamePool rootNamePool; + NamePool namePool; + + RootNamePool* getRootNamePool() { return &rootNamePool; } + NamePool* getNamePool() { return &namePool; } + Name* getNameObj(String name) { return namePool.getName(name); } + Name* tryGetNameObj(String name) { return namePool.tryGetName(name); } + // + + // Generated code for stdlib, etc. + String stdlibPath; + String coreLibraryCode; + String slangLibraryCode; + String hlslLibraryCode; + String glslLibraryCode; + + String getStdlibPath(); + String getCoreLibraryCode(); + String getHLSLLibraryCode(); + + // Basic types that we don't want to re-create all the time + RefPtr errorType; + RefPtr initializerListType; + RefPtr overloadedType; + RefPtr constExprRate; + RefPtr irBasicBlockType; + + RefPtr stringType; + RefPtr enumTypeType; + + ComPtr sharedLibraryLoader; ///< The shared library loader (never null) + ComPtr sharedLibraries[int(SharedLibraryType::CountOf)]; ///< The loaded shared libraries + SlangFuncPtr sharedLibraryFunctions[int(SharedLibraryFuncType::CountOf)]; + + Dictionary> builtinTypes; + Dictionary magicDecls; + + void initializeTypes(); + + Type* getBoolType(); + Type* getHalfType(); + Type* getFloatType(); + Type* getDoubleType(); + Type* getIntType(); + Type* getInt64Type(); + Type* getUIntType(); + Type* getUInt64Type(); + Type* getVoidType(); + Type* getBuiltinType(BaseType flavor); + + Type* getInitializerListType(); + Type* getOverloadedType(); + Type* getErrorType(); + Type* getStringType(); + + Type* getEnumTypeType(); + + // Construct the type `Ptr`, where `Ptr` + // is looked up as a builtin type. + RefPtr getPtrType(RefPtr valueType); + + // Construct the type `Out` + RefPtr getOutType(RefPtr valueType); + + // Construct the type `InOut` + RefPtr getInOutType(RefPtr valueType); + + // Construct the type `Ref` + RefPtr getRefType(RefPtr valueType); + + // Construct a pointer type like `Ptr`, but where + // the actual type name for the pointer type is given by `ptrTypeName` + RefPtr getPtrType(RefPtr valueType, char const* ptrTypeName); + + // Construct a pointer type like `Ptr`, but where + // the generic declaration for the pointer type is `genericDecl` + RefPtr getPtrType(RefPtr valueType, GenericDecl* genericDecl); + + RefPtr getArrayType( + Type* elementType, + IntVal* elementCount); + + RefPtr getVectorType( + RefPtr elementType, + RefPtr elementCount); + + SyntaxClass findSyntaxClass(Name* name); + + Dictionary > mapNameToSyntaxClass; + + // cache used by type checking, implemented in check.cpp + TypeCheckingCache* typeCheckingCache = nullptr; + TypeCheckingCache* getTypeCheckingCache(); + void destroyTypeCheckingCache(); + // + + /// Will try to load the library by specified name (using the set loader), if not one already available. + ISlangSharedLibrary* getOrLoadSharedLibrary(SharedLibraryType type, DiagnosticSink* sink); + + /// Gets a shared library by type, or null if not loaded + ISlangSharedLibrary* getSharedLibrary(SharedLibraryType type) const { return sharedLibraries[int(type)]; } + + SlangFuncPtr getSharedLibraryFunc(SharedLibraryFuncType type, DiagnosticSink* sink); + + Session(); + + void addBuiltinSource( + RefPtr const& scope, + String const& path, + String const& source); + ~Session(); + + private: + /// Linkage used for all built-in (stdlib) code. + RefPtr m_builtinLinkage; + }; + +} + +#endif diff --git a/source/slang/slang-decl-defs.h b/source/slang/slang-decl-defs.h new file mode 100644 index 000000000..04c733aac --- /dev/null +++ b/source/slang/slang-decl-defs.h @@ -0,0 +1,325 @@ +// slang-decl-defs.h + +// Syntax class definitions for declarations. + +// A group of declarations that should be treated as a unit +SYNTAX_CLASS(DeclGroup, DeclBase) + SYNTAX_FIELD(List>, decls) +END_SYNTAX_CLASS() + +// A "container" decl is a parent to other declarations +ABSTRACT_SYNTAX_CLASS(ContainerDecl, Decl) + SYNTAX_FIELD(List>, Members) + + RAW( + template + FilteredMemberList getMembersOfType() + { + return FilteredMemberList(Members); + } + + + // Dictionary for looking up members by name. + // This is built on demand before performing lookup. + Dictionary memberDictionary; + + // Whether the `memberDictionary` is valid. + // Should be set to `false` if any members get added/remoed. + bool memberDictionaryIsValid = false; + + // A list of transparent members, to be used in lookup + // Note: this is only valid if `memberDictionaryIsValid` is true + List transparentMembers; + ) +END_SYNTAX_CLASS() + +// Base class for all variable declarations +ABSTRACT_SYNTAX_CLASS(VarDeclBase, Decl) + + // type of the variable + SYNTAX_FIELD(TypeExp, type) + + RAW( + Type* getType() { return type.type.Ptr(); } + ) + + // Initializer expression (optional) + SYNTAX_FIELD(RefPtr, initExpr) +END_SYNTAX_CLASS() + +// Ordinary potentially-mutable variables (locals, globals, and member variables) +SYNTAX_CLASS(VarDecl, VarDeclBase) +END_SYNTAX_CLASS() + +// A variable declaration that is always immutable (whether local, global, or member variable) +SYNTAX_CLASS(LetDecl, VarDecl) +END_SYNTAX_CLASS() + +// An `AggTypeDeclBase` captures the shared functionality +// between true aggregate type declarations and extension +// declarations: +// +// - Both can container members (they are `ContainerDecl`s) +// - Both can have declared bases +// - Both expose a `this` variable in their body +// +ABSTRACT_SYNTAX_CLASS(AggTypeDeclBase, ContainerDecl) +END_SYNTAX_CLASS() + +// An extension to apply to an existing type +SYNTAX_CLASS(ExtensionDecl, AggTypeDeclBase) + SYNTAX_FIELD(TypeExp, targetType) + + // next extension attached to the same nominal type + DECL_FIELD(ExtensionDecl*, nextCandidateExtension RAW(= nullptr)) +END_SYNTAX_CLASS() + +// Declaration of a type that represents some sort of aggregate +ABSTRACT_SYNTAX_CLASS(AggTypeDecl, AggTypeDeclBase) + +RAW( + // extensions that might apply to this declaration + ExtensionDecl* candidateExtensions = nullptr; + FilteredMemberList GetFields() + { + return getMembersOfType(); + } + ) +END_SYNTAX_CLASS() + +SIMPLE_SYNTAX_CLASS(StructDecl, AggTypeDecl) + +SIMPLE_SYNTAX_CLASS(ClassDecl, AggTypeDecl) + +// TODO: Is it appropriate to treat an `enum` as an aggregate type? +// Most code that looks for, e.g., conformances assumes user-defined +// types are all `AggTypeDecl`, so this is the right choice for now +// if we want `enum` types to be able to implement interfaces, etc. +// +SYNTAX_CLASS(EnumDecl, AggTypeDecl) +RAW( + RefPtr tagType; +) +END_SYNTAX_CLASS() + +// A single case in an enum. +// +// E.g., in a declaration like: +// +// enum Color { Red = 0, Green, Blue }; +// +// The `Red = 0` is the declaration of the `Red` +// case, with `0` as an explicit expression for its +// _tag value_. +// +SYNTAX_CLASS(EnumCaseDecl, Decl) + + // type of the parent `enum` + SYNTAX_FIELD(TypeExp, type) + + RAW( + Type* getType() { return type.type.Ptr(); } + ) + + // Tag value + SYNTAX_FIELD(RefPtr, tagExpr) +END_SYNTAX_CLASS() + +// An interface which other types can conform to +SIMPLE_SYNTAX_CLASS(InterfaceDecl, AggTypeDecl) + +ABSTRACT_SYNTAX_CLASS(TypeConstraintDecl, Decl) + RAW( + virtual TypeExp& getSup() = 0; + ) +END_SYNTAX_CLASS() + +// A kind of pseudo-member that represents an explicit +// or implicit inheritance relationship. +// +SYNTAX_CLASS(InheritanceDecl, TypeConstraintDecl) +// The type expression as written + SYNTAX_FIELD(TypeExp, base) + + RAW( + // After checking, this dictionary will map members + // required by the base type to their concrete + // implementations in the type that contains + // this inheritance declaration. + RefPtr witnessTable; + virtual TypeExp& getSup() override + { + return base; + } + ) +END_SYNTAX_CLASS() + +// TODO: may eventually need sub-classes for explicit/direct vs. implicit/indirect inheritance + + +// A declaration that represents a simple (non-aggregate) type +// +// TODO: probably all types will be aggregate decls eventually, +// so that we can easily store conformances/constraints on type variables +ABSTRACT_SYNTAX_CLASS(SimpleTypeDecl, Decl) +END_SYNTAX_CLASS() + +// A `typedef` declaration +SYNTAX_CLASS(TypeDefDecl, SimpleTypeDecl) + SYNTAX_FIELD(TypeExp, type) +END_SYNTAX_CLASS() + +SIMPLE_SYNTAX_CLASS(TypeAliasDecl, TypeDefDecl) + +// An 'assoctype' declaration, it is a container of inheritance clauses +SYNTAX_CLASS(AssocTypeDecl, AggTypeDecl) +END_SYNTAX_CLASS() + +// A 'type_param' declaration, which defines a generic +// entry-point parameter. Is a container of GenericTypeConstraintDecl +SYNTAX_CLASS(GlobalGenericParamDecl, AggTypeDecl) +END_SYNTAX_CLASS() + +// A scope for local declarations (e.g., as part of a statement) +SIMPLE_SYNTAX_CLASS(ScopeDecl, ContainerDecl) + +// A function/initializer/subscript parameter (potentially mutable) +SIMPLE_SYNTAX_CLASS(ParamDecl, VarDeclBase) + +// A parameter of a function declared in "modern" types (immutable unless explicitly `out` or `inout`) +SIMPLE_SYNTAX_CLASS(ModernParamDecl, ParamDecl) + +// Base class for things that have parameter lists and can thus be applied to arguments ("called") +ABSTRACT_SYNTAX_CLASS(CallableDecl, ContainerDecl) + RAW( + FilteredMemberList GetParameters() + { + return getMembersOfType(); + }) + + SYNTAX_FIELD(TypeExp, ReturnType) + + // Fields related to redeclaration, so that we + // can support multiple specialized varaitions + // of the "same" logical function. + // + // This should also help us to support redeclaration + // of functions when handling HLSL/GLSL. + + // The "primary" declaration of the function, which will + // be used whenever we need to unique things. + FIELD_INIT(CallableDecl*, primaryDecl, nullptr) + + // The next declaration of the "same" function (that is, + // with the same `primaryDecl`). + FIELD_INIT(CallableDecl*, nextDecl, nullptr); + +END_SYNTAX_CLASS() + +// Base class for callable things that may also have a body that is evaluated to produce their result +ABSTRACT_SYNTAX_CLASS(FunctionDeclBase, CallableDecl) + SYNTAX_FIELD(RefPtr, Body) +END_SYNTAX_CLASS() + +// A constructor/initializer to create instances of a type +SIMPLE_SYNTAX_CLASS(ConstructorDecl, FunctionDeclBase) + +// A subscript operation used to index instances of a type +SIMPLE_SYNTAX_CLASS(SubscriptDecl, CallableDecl) + +// An "accessor" for a subscript or property +SIMPLE_SYNTAX_CLASS(AccessorDecl, FunctionDeclBase) + +SIMPLE_SYNTAX_CLASS(GetterDecl, AccessorDecl) +SIMPLE_SYNTAX_CLASS(SetterDecl, AccessorDecl) +SIMPLE_SYNTAX_CLASS(RefAccessorDecl, AccessorDecl) + +SIMPLE_SYNTAX_CLASS(FuncDecl, FunctionDeclBase) + +// A "module" of code (essentiately, a single translation unit) +// that provides a scope for some number of declarations. +SYNTAX_CLASS(ModuleDecl, ContainerDecl) + FIELD(RefPtr, scope) + + // The API-level module that this declaration belong to. + // + // This field allows lookup of the `Module` based on a + // declaration nested under a `ModuleDecl` by following + // its chain of parents. + // + RAW(Module* module = nullptr;) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(ImportDecl, Decl) + // The name of the module we are trying to import + FIELD(NameLoc, moduleNameAndLoc) + + // The scope that we want to import into + FIELD(RefPtr, scope) + + // The module that actually got imported + DECL_FIELD(RefPtr, importedModuleDecl) +END_SYNTAX_CLASS() + +// A generic declaration, parameterized on types/values +SYNTAX_CLASS(GenericDecl, ContainerDecl) + // The decl that is genericized... + SYNTAX_FIELD(RefPtr, inner) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(GenericTypeParamDecl, SimpleTypeDecl) + // The bound for the type parameter represents a trait that any + // type used as this parameter must conform to +// TypeExp bound; + + // The "initializer" for the parameter represents a default value + SYNTAX_FIELD(TypeExp, initType) +END_SYNTAX_CLASS() + +// A constraint placed as part of a generic declaration +SYNTAX_CLASS(GenericTypeConstraintDecl, TypeConstraintDecl) + // A type constraint like `T : U` is constraining `T` to be "below" `U` + // on a lattice of types. This may not be a subtyping relationship + // per se, but it makes sense to use that terminology here, so we + // think of these fields as the sub-type and sup-ertype, respectively. + SYNTAX_FIELD(TypeExp, sub) + SYNTAX_FIELD(TypeExp, sup) + RAW( + virtual TypeExp& getSup() override + { + return sup; + } + ) +END_SYNTAX_CLASS() + +SIMPLE_SYNTAX_CLASS(GenericValueParamDecl, VarDeclBase) + +// An empty declaration (which might still have modifiers attached). +// +// An empty declaration is uncommon in HLSL, but +// in GLSL it is often used at the global scope +// to declare metadata that logically belongs +// to the entry point, e.g.: +// +// layout(local_size_x = 16) in; +// +SIMPLE_SYNTAX_CLASS(EmptyDecl, Decl) + +// A declaration used by the implementation to put syntax keywords +// into the current scope. +// +SYNTAX_CLASS(SyntaxDecl, Decl) + // What type of syntax node will be produced when parsing with this keyword? + FIELD(SyntaxClass, syntaxClass) + + // Callback to invoke in order to parse syntax with this keyword. + FIELD(SyntaxParseCallback, parseCallback) + FIELD(void*, parseUserData) +END_SYNTAX_CLASS() + +// A declaration of an attribute to be used with `[name(...)]` syntax. +// +SYNTAX_CLASS(AttributeDecl, ContainerDecl) + // What type of syntax node will be produced to represent this attribute. + FIELD(SyntaxClass, syntaxClass) +END_SYNTAX_CLASS() diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h new file mode 100644 index 000000000..59d840997 --- /dev/null +++ b/source/slang/slang-diagnostic-defs.h @@ -0,0 +1,487 @@ +// + +// The file is meant to be included multiple times, to produce different +// pieces of declaration/definition code related to diagnostic messages +// +// Each diagnostic is declared here with: +// +// DIAGNOSTIC(id, severity, name, messageFormat) +// +// Where `id` is the unique diagnostic ID, `severity` is the default +// severity (from the `Severity` enum), `name` is a name used to refer +// to this diagnostic from code, and `messageFormat` is the default +// (non-localized) message for the diagnostic, with placeholders +// for any arguments. + +#ifndef DIAGNOSTIC +#error Need to #define DIAGNOSTIC(...) before including "DiagnosticDefs.h" +#define DIAGNOSTIC(id, severity, name, messageFormat) /* */ +#endif + +// +// -1 - Notes that decorate another diagnostic. +// + +DIAGNOSTIC(-1, Note, alsoSeePipelineDefinition, "also see pipeline definition"); +DIAGNOSTIC(-1, Note, implicitParameterMatchingFailedBecauseNameNotAccessible, "implicit parameter matching failed because the component of the same name is not accessible from '$0'.\ncheck if you have declared necessary requirements and properly used the 'public' qualifier.") +DIAGNOSTIC(-1, Note, implicitParameterMatchingFailedBecauseShaderDoesNotDefineComponent, "implicit parameter matching failed because shader '$0' does not define component '$1'.") +DIAGNOSTIC(-1, Note, implicitParameterMatchingFailedBecauseTypeMismatch, "implicit parameter matching failed because the component of the same name does not match parameter type '$0'.") +DIAGNOSTIC(-1, Note, noteShaderIsTargetingPipeine, "shader '$0' is targeting pipeline '$1'") +DIAGNOSTIC(-1, Note, seeDefinitionOf, "see definition of '$0'") +DIAGNOSTIC(-1, Note, seeInterfaceDefinitionOf, "see interface definition of '$0'") +DIAGNOSTIC(-1, Note, seeUsingOf, "see using of '$0'") +DIAGNOSTIC(-1, Note, seeDefinitionOfShader, "see definition of shader '$0'") +DIAGNOSTIC(-1, Note, seeInclusionOf, "see inclusion of '$0'") +DIAGNOSTIC(-1, Note, seeModuleBeingUsedIn, "see module '$0' being used in '$1'") +DIAGNOSTIC(-1, Note, seePipelineRequirementDefinition, "see pipeline requirement definition") +DIAGNOSTIC(-1, Note, seePotentialDefinitionOfComponent, "see potential definition of component '$0'") +DIAGNOSTIC(-1, Note, seePreviousDefinition, "see previous definition") +DIAGNOSTIC(-1, Note, seePreviousDefinitionOf, "see previous definition of '$0'") +DIAGNOSTIC(-1, Note, seeRequirementDeclaration, "see requirement declaration") +DIAGNOSTIC(-1, Note, doYouForgetToMakeComponentAccessible, "do you forget to make component '$0' acessible from '$1' (missing public qualifier)?") + +DIAGNOSTIC(-1, Note, seeDeclarationOf, "see declaration of '$0'") +DIAGNOSTIC(-1, Note, seeOtherDeclarationOf, "see other declaration of '$0'") +DIAGNOSTIC(-1, Note, seePreviousDeclarationOf, "see previous declaration of '$0'") + +// +// 0xxxx - Command line and interaction with host platform APIs. +// + +DIAGNOSTIC( 1, Error, cannotOpenFile, "cannot open file '$0'.") +DIAGNOSTIC( 2, Error, cannotFindFile, "cannot find file '$0'.") +DIAGNOSTIC( 2, Error, unsupportedCompilerMode, "unsupported compiler mode.") +DIAGNOSTIC( 4, Error, cannotWriteOutputFile, "cannot write output file '$0'.") +DIAGNOSTIC( 5, Error, failedToLoadDynamicLibrary, "failed to load dynamic library '$0'") +DIAGNOSTIC( 6, Error, tooManyOutputPathsSpecified, "$0 output paths specified, but only $1 entry points given") + +DIAGNOSTIC( 7, Error, noOutputPathSpecifiedForEntryPoint, + "no output path specified for entry point '$0' (the '-o' option for an entry point must precede the corresponding '-entry')") + +DIAGNOSTIC( 8, Error, outputPathsImplyDifferentFormats, + "the output paths '$0' and '$1' require different code-generation targets") + +DIAGNOSTIC( 10, Error, explicitOutputPathsAndMultipleTargets, "canot use both explicit output paths ('-o') and multiple targets ('-target')") +DIAGNOSTIC( 11, Error, glslIsNotSupported, "the Slang compiler does not support GLSL as a source language"); +DIAGNOSTIC( 12, Error, cannotDeduceSourceLanguage, "can't deduce language for input file '$0'"); +DIAGNOSTIC( 13, Error, unknownCodeGenerationTarget, "unknown code generation target '$0'"); +DIAGNOSTIC( 14, Error, unknownProfile, "unknown profile '$0'"); +DIAGNOSTIC( 15, Error, unknownStage, "unknown stage '$0'"); +DIAGNOSTIC( 16, Error, unknownPassThroughTarget, "unknown pass-through target '$0'"); +DIAGNOSTIC( 17, Error, unknownCommandLineOption, "unknown command-line option '$0'"); +DIAGNOSTIC( 20, Error, entryPointsNeedToBeAssociatedWithTranslationUnits, "when using multiple source files, entry points must be specified after their corresponding source file(s)"); +DIAGNOSTIC( 21, Error, expectedArgumentForOption, "expected an argument for command-line option '$0'"); + +DIAGNOSTIC( 24, Error, unknownLineDirectiveMode, "unknown '#line' directive mode '$0'"); +DIAGNOSTIC( 25, Error, unknownFloatingPointMode, "unknown floating-point mode '$0'"); +DIAGNOSTIC( 26, Error, unknownOptimiziationLevel, "unknown optimization level '$0'"); +DIAGNOSTIC( 27, Error, uknownDebugInfoLevel, "unknown debug info level '$0'"); + +DIAGNOSTIC( 30, Warning, sameStageSpecifiedMoreThanOnce, "the stage '$0' was specified more than once for entry point '$1'") +DIAGNOSTIC( 31, Error, conflictingStagesForEntryPoint, "conflicting stages have been specified for entry point '$0'") +DIAGNOSTIC( 32, Warning, explicitStageDoesntMatchImpliedStage, "the stage specified for entry point '$0' ('$1') does not match the stage implied by the source file name ('$2')") +DIAGNOSTIC( 33, Error, stageSpecificationIgnoredBecauseNoEntryPoints, "one or more stages were specified, but no entry points were specified with '-entry'") +DIAGNOSTIC( 34, Error, stageSpecificationIgnoredBecauseBeforeAllEntryPoints, "when compiling multiple entry points, any '-stage' options must follow the '-entry' option that they apply to") +DIAGNOSTIC( 35, Error, noStageSpecifiedInPassThroughMode, "no stage was specified for entry point '$0'; when using the '-pass-through' option, stages must be fully specified on the command line") + +DIAGNOSTIC( 40, Warning, sameProfileSpecifiedMoreThanOnce, "the '$0' was specified more than once for target '$0'") +DIAGNOSTIC( 41, Error, conflictingProfilesSpecifiedForTarget, "conflicting profiles have been specified for target '$0'") + +DIAGNOSTIC( 42, Error, profileSpecificationIgnoredBecauseNoTargets, "a '-profile' option was specified, but no target was specified with '-target'") +DIAGNOSTIC( 43, Error, profileSpecificationIgnoredBecauseBeforeAllTargets, "when using multiple targets, any '-profile' option must follow the '-target' it applies to") + +DIAGNOSTIC( 42, Error, targetFlagsIgnoredBecauseNoTargets, "target options were specified, but no target was specified with '-target'") +DIAGNOSTIC( 43, Error, targetFlagsIgnoredBecauseBeforeAllTargets, "when using multiple targets, any target options must follow the '-target' they apply to") + +DIAGNOSTIC( 50, Error, duplicateTargets, "the target '$0' has been specified more than once") + +DIAGNOSTIC( 60, Error, cannotDeduceOutputFormatFromPath, "cannot infer an output format from the output path '$0'") +DIAGNOSTIC( 61, Error, cannotMatchOutputFileToTarget, "no specified '-target' option matches the output path '$0', which implies the '$1' format") + +DIAGNOSTIC( 62, Error, failedToFindFunctionInSharedLibrary, "failed to find function '$0' in shared/dynamic library '$1'") + +DIAGNOSTIC( 70, Error, cannotMatchOutputFileToEntryPoint, "the output path '$0' is not associated with any entry point; a '-o' option for a compiled kernel must follow the '-entry' option for its corresponding entry point") + +DIAGNOSTIC( 80, Error, duplicateOutputPathsForEntryPointAndTarget, "multiple output paths have been specified entry point '$0' on target '$1'") + +// +// 1xxxx - Lexical anaylsis +// + +DIAGNOSTIC(10000, Error, illegalCharacterPrint, "illegal character '$0'"); +DIAGNOSTIC(10000, Error, illegalCharacterHex, "illegal character (0x$0)"); +DIAGNOSTIC(10001, Error, illegalCharacterLiteral, "illegal character literal"); + +DIAGNOSTIC(10002, Warning, octalLiteral, "'0' prefix indicates octal literal") +DIAGNOSTIC(10003, Error, invalidDigitForBase, "invalid digit for base-$1 literal: '$0'") + +DIAGNOSTIC(10004, Error, endOfFileInLiteral, "end of file in literal"); +DIAGNOSTIC(10005, Error, newlineInLiteral, "newline in literal"); + +// +// 15xxx - Preprocessing +// + +// 150xx - conditionals +DIAGNOSTIC(15000, Error, endOfFileInPreprocessorConditional, "end of file encountered during preprocessor conditional") +DIAGNOSTIC(15001, Error, directiveWithoutIf, "'$0' directive without '#if'") +DIAGNOSTIC(15002, Error, directiveAfterElse , "'$0' directive without '#if'") + +DIAGNOSTIC(-1, Note, seeDirective, "see '$0' directive") + +// 151xx - directive parsing +DIAGNOSTIC(15100, Error, expectedPreprocessorDirectiveName, "expected preprocessor directive name") +DIAGNOSTIC(15101, Error, unknownPreprocessorDirective, "unknown preprocessor directive '$0'") +DIAGNOSTIC(15102, Error, expectedTokenInPreprocessorDirective, "expected '$0' in '$1' directive") +DIAGNOSTIC(15102, Error, expected2TokensInPreprocessorDirective, "expected '$0' or '$1' in '$2' directive") +DIAGNOSTIC(15103, Error, unexpectedTokensAfterDirective, "unexpected tokens following '$0' directive") + + +// 152xx - preprocessor expressions +DIAGNOSTIC(15200, Error, expectedTokenInPreprocessorExpression, "expected '$0' in preprocessor expression"); +DIAGNOSTIC(15201, Error, syntaxErrorInPreprocessorExpression, "syntax error in preprocessor expression"); +DIAGNOSTIC(15202, Error, divideByZeroInPreprocessorExpression, "division by zero in preprocessor expression"); +DIAGNOSTIC(15203, Error, expectedTokenInDefinedExpression, "expected '$0' in 'defined' expression"); +DIAGNOSTIC(15204, Warning, directiveExpectsExpression, "'$0' directive requires an expression"); +DIAGNOSTIC(15205, Warning, undefinedIdentifierInPreprocessorExpression, "undefined identifier '$0' in preprocessor expression will evaluate to zero") + +DIAGNOSTIC(-1, Note, seeOpeningToken, "see opening '$0'") + +// 153xx - #include +DIAGNOSTIC(15300, Error, includeFailed, "failed to find include file '$0'") +DIAGNOSTIC(15301, Error, importFailed, "failed to find imported file '$0'") +DIAGNOSTIC(-1, Error, noIncludeHandlerSpecified, "no `#include` handler was specified") +DIAGNOSTIC(15302, Error, noUniqueIdentity, "`#include` handler didn't generate a unique identity for file '$0'") + + +// 154xx - macro definition +DIAGNOSTIC(15400, Warning, macroRedefinition, "redefinition of macro '$0'") +DIAGNOSTIC(15401, Warning, macroNotDefined, "macro '$0' is not defined") +DIAGNOSTIC(15403, Error, expectedTokenInMacroParameters, "expected '$0' in macro parameters") + +// 155xx - macro expansion +DIAGNOSTIC(15500, Warning, expectedTokenInMacroArguments, "expected '$0' in macro invocation") +DIAGNOSTIC(15501, Error, wrongNumberOfArgumentsToMacro, "wrong number of arguments to macro (expected $0, got $1)") + +// 156xx - pragmas +DIAGNOSTIC(15600, Error, expectedPragmaDirectiveName, "expected a name after '#pragma'") +DIAGNOSTIC(15601, Warning, unknownPragmaDirectiveIgnored, "ignoring unknown directive '#pragma $0'") +DIAGNOSTIC(15602, Warning, pragmaOnceIgnored, "pragma once was ignored - this is typically because is not placed in an include") + +// 159xx - user-defined error/warning +DIAGNOSTIC(15900, Error, userDefinedError, "#error: $0") +DIAGNOSTIC(15901, Warning, userDefinedWarning, "#warning: $0") + +// +// 2xxxx - Parsing +// + +DIAGNOSTIC(20003, Error, unexpectedToken, "unexpected $0"); +DIAGNOSTIC(20001, Error, unexpectedTokenExpectedTokenType, "unexpected $0, expected $1"); +DIAGNOSTIC(20001, Error, unexpectedTokenExpectedTokenName, "unexpected $0, expected '$1'"); + +DIAGNOSTIC(0, Error, tokenNameExpectedButEOF, "\"$0\" expected but end of file encountered."); +DIAGNOSTIC(0, Error, tokenTypeExpectedButEOF, "$0 expected but end of file encountered."); +DIAGNOSTIC(20001, Error, tokenNameExpected, "\"$0\" expected"); +DIAGNOSTIC(20001, Error, tokenNameExpectedButEOF2, "\"$0\" expected but end of file encountered."); +DIAGNOSTIC(20001, Error, tokenTypeExpected, "$0 expected"); +DIAGNOSTIC(20001, Error, tokenTypeExpectedButEOF2, "$0 expected but end of file encountered."); +DIAGNOSTIC(20001, Error, typeNameExpectedBut, "unexpected $0, expected type name"); +DIAGNOSTIC(20001, Error, typeNameExpectedButEOF, "type name expected but end of file encountered."); +DIAGNOSTIC(20001, Error, unexpectedEOF, " Unexpected end of file."); +DIAGNOSTIC(20002, Error, syntaxError, "syntax error."); +DIAGNOSTIC(20004, Error, unexpectedTokenExpectedComponentDefinition, "unexpected token '$0', only component definitions are allowed in a shader scope.") +DIAGNOSTIC(20008, Error, invalidOperator, "invalid operator '$0'."); +DIAGNOSTIC(20011, Error, unexpectedColon, "unexpected ':'.") + +// +// 3xxxx - Semantic analysis +// + +DIAGNOSTIC(30002, Error, parameterAlreadyDefined, "parameter '$0' already defined.") +DIAGNOSTIC(30003, Error, breakOutsideLoop, "'break' must appear inside loop constructs.") +DIAGNOSTIC(30004, Error, continueOutsideLoop, "'continue' must appear inside loop constructs.") +DIAGNOSTIC(30005, Error, whilePredicateTypeError, "'while': expression must evaluate to int.") +DIAGNOSTIC(30006, Error, ifPredicateTypeError, "'if': expression must evaluate to int.") +DIAGNOSTIC(30006, Error, returnNeedsExpression, "'return' should have an expression.") +DIAGNOSTIC(30007, Error, componentReturnTypeMismatch, "expression type '$0' does not match component's type '$1'") +DIAGNOSTIC(30007, Error, functionReturnTypeMismatch, "expression type '$0' does not match function's return type '$1'") +DIAGNOSTIC(30008, Error, variableNameAlreadyDefined, "variable $0 already defined.") +DIAGNOSTIC(30009, Error, invalidTypeVoid, "invalid type 'void'.") +DIAGNOSTIC(30010, Error, whilePredicateTypeError2, "'while': expression must evaluate to int.") +DIAGNOSTIC(30011, Error, assignNonLValue, "left of '=' is not an l-value.") +DIAGNOSTIC(30012, Error, noApplicationUnaryOperator, "no overload found for operator $0 ($1).") +DIAGNOSTIC(30012, Error, noOverloadFoundForBinOperatorOnTypes, "no overload found for operator $0 ($1, $2).") +DIAGNOSTIC(30013, Error, subscriptNonArray, "no subscript operation found for type '$0'") +DIAGNOSTIC(30014, Error, subscriptIndexNonInteger, "index expression must evaluate to int.") +DIAGNOSTIC(30015, Error, undefinedIdentifier, "'$0': undefined identifier.") +DIAGNOSTIC(30015, Error, undefinedIdentifier2, "undefined identifier '$0'.") +DIAGNOSTIC(30017, Error, componentNotAccessibleFromShader, "component '$0' is not accessible from shader '$1'.") +DIAGNOSTIC(30019, Error, typeMismatch, "expected an expression of type '$0', got '$1'") +DIAGNOSTIC(30020, Error, importOperatorReturnTypeMismatch, "import operator should return '$1', but the expression has type '$0''. do you forget 'project'?") +DIAGNOSTIC(30021, Error, noApplicationFunction, "$0: no overload takes arguments ($1)") +DIAGNOSTIC(30022, Error, invalidTypeCast, "invalid type cast between \"$0\" and \"$1\".") +DIAGNOSTIC(30023, Error, typeHasNoPublicMemberOfName, "\"$0\" does not have public member \"$1\"."); +DIAGNOSTIC(30025, Error, invalidArraySize, "array size must be larger than zero.") +DIAGNOSTIC(30026, Error, returnInComponentMustComeLast, "'return' can only appear as the last statement in component definition.") +DIAGNOSTIC(30027, Error, noMemberOfNameInType, "'$0' is not a member of '$1'."); +DIAGNOSTIC(30028, Error, forPredicateTypeError, "'for': predicate expression must evaluate to bool.") +DIAGNOSTIC(30030, Error, projectionOutsideImportOperator, "'project': invalid use outside import operator.") +DIAGNOSTIC(30031, Error, projectTypeMismatch, "'project': expression must evaluate to record type '$0'.") +DIAGNOSTIC(30033, Error, invalidTypeForLocalVariable, "cannot declare a local variable of this type.") +DIAGNOSTIC(30035, Error, componentOverloadTypeMismatch, "'$0': type of overloaded component mismatches previous definition.") +DIAGNOSTIC(30041, Error, bitOperationNonIntegral, "bit operation: operand must be integral type.") +DIAGNOSTIC(30047, Error, argumentExpectedLValue, "argument passed to parameter '$0' must be l-value.") +DIAGNOSTIC(30048, Note, implicitCastUsedAsLValue, "argument was implicitly cast from '$0' to '$1', and Slang does not support using an implicit cast as an l-value") +DIAGNOSTIC(30049, Note, thisIsImmutableByDefault, "a 'this' parameter is an immutable parameter by default in Slang; apply the `[mutating]` attribute to the function declaration to opt in to a mutable `this`") + +DIAGNOSTIC(30051, Error, invalidValueForArgument, "invalid value for argument '$0'") +DIAGNOSTIC(30052, Error, invalidSwizzleExpr, "invalid swizzle pattern '$0' on type '$1'") + +DIAGNOSTIC(30100, Error, staticRefToNonStaticMember, "type '$0' cannot be used to refer to non-static member '$1'") + +DIAGNOSTIC(30201, Error, functionRedefinition, "function '$0' already has a body") +DIAGNOSTIC(30202, Error, functionRedeclarationWithDifferentReturnType, "function '$0' declared to return '$1' was previously declared to return '$2'") + + +DIAGNOSTIC(33070, Error, expectedFunction, "expected a function, got '$0'") +DIAGNOSTIC(33071, Error, expectedAStringLiteral, "expected a string literal") + +DIAGNOSTIC( -1, Note, noteExplicitConversionPossible, "explicit conversion from '$0' to '$1' is possible") +DIAGNOSTIC(30080, Error, ambiguousConversion, "more than one implicit conversion exists from '$0' to '$1'") + + + + +// Attributes +DIAGNOSTIC(31000, Error, unknownAttributeName, "unknown attribute '$0'") +DIAGNOSTIC(31001, Error, attributeArgumentCountMismatch, "attribute '$0' expects $1 arguments ($2 provided)") +DIAGNOSTIC(31002, Error, attributeNotApplicable, "attribute '$0' is not valid here") + +DIAGNOSTIC(31003, Error, badlyDefinedPatchConstantFunc, "hull shader '$0' has has badly defined 'patchconstantfunc' attribute."); + +DIAGNOSTIC(31004, Error, expectedSingleIntArg, "attribute '$0' expects a single int argument") +DIAGNOSTIC(31005, Error, expectedSingleStringArg, "attribute '$0' expects a single string argument") + +DIAGNOSTIC(31006, Error, attributeFunctionNotFound, "Could not find function '$0' for attribute'$1'") + +DIAGNOSTIC(31100, Error, unknownStageName, "unknown stage name '$0'") +DIAGNOSTIC(31101, Error, unknownImageFormatName, "unknown image format '$0'") + +DIAGNOSTIC(31120, Error, invalidAttributeTarget, "invalid syntax target for user defined attribute") + +// Enums + +DIAGNOSTIC(32000, Error, invalidEnumTagType, "invalid tag type for 'enum': '$0'") +DIAGNOSTIC(32001, Error, enumTypeAlreadyHasTagType, "'enum' type has already declared a tag type") +DIAGNOSTIC(32002, Note, seePreviousTagType, "see previous tag type declaration") +DIAGNOSTIC(32003, Error, unexpectedEnumTagExpr, "unexpected form for 'enum' tag value expression") + + + +// 303xx: interfaces and associated types +DIAGNOSTIC(30300, Error, assocTypeInInterfaceOnly, "'associatedtype' can only be defined in an 'interface'.") +DIAGNOSTIC(30301, Error, globalGenParamInGlobalScopeOnly, "'type_param' can only be defined global scope.") +// TODO: need to assign numbers to all these extra diagnostics... +DIAGNOSTIC(39999, Fatal, cyclicReference, "cyclic reference '$0'.") +DIAGNOSTIC(39999, Fatal, localVariableUsedBeforeDeclared, "local variable '$0' is being used before its declaration.") + +// 304xx: generics +DIAGNOSTIC(30400, Error, genericTypeNeedsArgs, "generic type '$0' used without argument") + +// 305xx: initializer lists +DIAGNOSTIC(30500, Error, tooManyInitializers, "too many initializers (expected $0, got $1)"); +DIAGNOSTIC(30501, Error, cannotUseInitializerListForArrayOfUnknownSize, "cannot use initializer list for array of statically unknown size '$0'"); +DIAGNOSTIC(30502, Error, cannotUseInitializerListForVectorOfUnknownSize, "cannot use initializer list for vector of statically unknown size '$0'"); +DIAGNOSTIC(30503, Error, cannotUseInitializerListForMatrixOfUnknownSize, "cannot use initializer list for matrix of statically unknown size '$0' rows"); +DIAGNOSTIC(30504, Error, cannotUseInitializerListForType, "cannot use initializer list for type '$0'") + +// 306xx: variables +DIAGNOSTIC(30600, Error, varWithoutTypeMustHaveInitializer, "a variable declaration without an initial-value expression must be given an explicit type"); + +// 307xx: parameters +DIAGNOSTIC(30700, Error, outputParameterCannotHaveDefaultValue, "an 'out' or 'inout' parameter cannot have a default-value expression"); + +DIAGNOSTIC(39999, Error, expectedIntegerConstantWrongType, "expected integer constant (found: '$0')") +DIAGNOSTIC(39999, Error, expectedIntegerConstantNotConstant, "expression does not evaluate to a compile-time constant") +DIAGNOSTIC(39999, Error, expectedIntegerConstantNotLiteral, "could not extract value from integer constant") + +DIAGNOSTIC(39999, Error, noApplicableOverloadForNameWithArgs, "no overload for '$0' applicable to arguments of type $1") +DIAGNOSTIC(39999, Error, noApplicableWithArgs, "no overload applicable to arguments of type $0") + +DIAGNOSTIC(39999, Error, ambiguousOverloadForNameWithArgs, "ambiguous call to '$0' operation with arguments of type $1") +DIAGNOSTIC(39999, Error, ambiguousOverloadWithArgs, "ambiguous call to overloaded operation with arguments of type $0") + +DIAGNOSTIC(39999, Note, overloadCandidate, "candidate: $0") +DIAGNOSTIC(39999, Note, moreOverloadCandidates, "$0 more overload candidates") + +DIAGNOSTIC(39999, Error, caseOutsideSwitch, "'case' not allowed outside of a 'switch' statement") +DIAGNOSTIC(39999, Error, defaultOutsideSwitch, "'default' not allowed outside of a 'switch' statement") + +DIAGNOSTIC(39999, Error, expectedAGeneric, "expected a generic when using '<...>' (found: '$0')") + +DIAGNOSTIC(39999, Error, genericArgumentInferenceFailed, "could not specialize generic for arguments of type $0") +DIAGNOSTIC(39999, Note, genericSignatureTried, "see declaration of $0") + +DIAGNOSTIC(39999, Error, expectedAnInterfaceGot, "expected an interface, got '$0'") + +DIAGNOSTIC(39999, Error, ambiguousReference, "amiguous reference to '$0'"); + +DIAGNOSTIC(39999, Error, declarationDidntDeclareAnything, "declaration does not declare anything"); + + +DIAGNOSTIC(39999, Error, expectedPrefixOperator, "function called as prefix operator was not declared `__prefix`") +DIAGNOSTIC(39999, Error, expectedPostfixOperator, "function called as postfix operator was not declared `__postfix`") + +DIAGNOSTIC(39999, Error, notEnoughArguments, "not enough arguments to call (got $0, expected $1)") +DIAGNOSTIC(39999, Error, tooManyArguments, "too many arguments to call (got $0, expected $1)") + +DIAGNOSTIC(39999, Error, invalidIntegerLiteralSuffix, "invalid suffix '$0' on integer literal") +DIAGNOSTIC(39999, Error, invalidFloatingPointLiteralSuffix, "invalid suffix '$0' on floating-point literal") + + + +DIAGNOSTIC(38000, Error, entryPointFunctionNotFound, "no function found matching entry point name '$0'") +DIAGNOSTIC(38001, Error, ambiguousEntryPoint, "more than one function matches entry point name '$0'") +DIAGNOSTIC(38002, Note, entryPointCandidate, "see candidate declaration for entry point '$0'") +DIAGNOSTIC(38003, Error, entryPointSymbolNotAFunction, "entry point '$0' must be declared as a function") + +DIAGNOSTIC(38004, Error, entryPointTypeParameterNotFound, "no type found matching entry-point type parameter name '$0'") +DIAGNOSTIC(38005, Error, globalGenericArgumentNotAType, "argument for global generic parameter '$0' must be a type") + +DIAGNOSTIC(38006, Warning, specifiedStageDoesntMatchAttribute, "entry point '$0' being compiled for the '$1' stage has a '[shader(...)]' attribute that specifies the '$2' stage") +DIAGNOSTIC(38007, Error, entryPointHasNoStage, "no stage specified for entry point '$0'; use either a '[shader(\"name\")]' function attribute or the '-stage ' command-line option to specify a stage") + +DIAGNOSTIC(38100, Error, typeDoesntImplementInterfaceRequirement, "type '$0' does not provide required interface member '$1'") +DIAGNOSTIC(38101, Error, thisExpressionOutsideOfTypeDecl, "'this' expression can only be used in members of an aggregate type") +DIAGNOSTIC(38102, Error, initializerNotInsideType, "an 'init' declaration is only allowed inside a type or 'extension' declaration") +DIAGNOSTIC(38102, Error, accessorMustBeInsideSubscriptOrProperty, "an accessor declaration is only allowed inside a subscript or property declaration") + +DIAGNOSTIC(38020, Error, mismatchEntryPointTypeArgument, "expecting $0 entry-point type arguments, provided $1.") +DIAGNOSTIC(38021, Error, typeArgumentDoesNotConformToInterface, "type argument `$1` for generic parameter `$0` does not conform to interface `$2`.") + +DIAGNOSTIC(38022, Error, cannotSpecializeGlobalGenericToItself, "the global type parameter '$0' cannot be specialized to itself") +DIAGNOSTIC(38023, Error, cannotSpecializeGlobalGenericToAnotherGenericParam, "the global type parameter '$0' cannot be specialized using another global type parameter ('$1')") + +DIAGNOSTIC(38024, Error, invalidDispatchThreadIDType, "parameter with SV_DispatchThreadID must be either scalar or vector (1 to 3) of uint/int but is $0"); + +DIAGNOSTIC(-1, Note, noteWhenCompilingEntryPoint, "when compiling entry point '$0'") + +DIAGNOSTIC(38025, Error, mismatchGlobalGenericArguments, "expected $0 global generic arguments ($1 provided)") +DIAGNOSTIC(38026, Error, globalTypeArgumentDoesNotConformToInterface, "type argument `$1` for global generic parameter `$0` does not conform to interface `$2`.") + +DIAGNOSTIC(38027, Error, mismatchExistentialSlotArgCount, "expected $0 existential slot arguments ($1 provided)") +DIAGNOSTIC(38028, Error, existentialSlotArgNotAType, "existential slot argument $0 was not a type") +DIAGNOSTIC(38029, Error, existentialSlotArgDoesNotConform, "existential slot argument $0 does not conform to the required interface '$1'") + +DIAGNOSTIC(38200, Error, recursiveModuleImport, "module `$0` recursively imports itself") +DIAGNOSTIC(39999, Fatal, errorInImportedModule, "error in imported module, compilation ceased.") + +// 39xxx - Type layout and parameter binding. + +DIAGNOSTIC(39000, Error, conflictingExplicitBindingsForParameter, "conflicting explicit bindings for parameter '$0'") +DIAGNOSTIC(39001, Warning, parameterBindingsOverlap, "explicit binding for parameter '$0' overlaps with parameter '$1'") + + +DIAGNOSTIC(39002, Error, shaderParameterDeclarationsDontMatch, "declarations of shader parameter '$0' in different translation units don't match") + +DIAGNOSTIC(39003, Note, shaderParameterTypeMismatch, "type is declared as '$0' in one translation unit, and '$0' in another") +DIAGNOSTIC(39004, Note, fieldTypeMisMatch, "type of field '$0' is declared as '$1' in one translation unit, and '$2' in another") +DIAGNOSTIC(39005, Note, fieldDeclarationsDontMatch, "type '$0' is declared with different fields in each translation unit") +DIAGNOSTIC(39006, Note, usedInDeclarationOf, "used in declaration of '$0'") + +DIAGNOSTIC(39007, Error, unknownRegisterClass, "unknown register class: '$0'") +DIAGNOSTIC(39008, Error, expectedARegisterIndex, "expected a register index after '$0'") +DIAGNOSTIC(39009, Error, expectedSpace, "expected 'space', got '$0'") +DIAGNOSTIC(39010, Error, expectedSpaceIndex, "expected a register space index after 'space'") +DIAGNOSTIC(39011, Error, componentMaskNotSupported, "explicit register component masks are not yet supported in Slang") +DIAGNOSTIC(39012, Error, packOffsetNotSupported, "explicit 'packoffset' bindings are not yet supported in Slang") +DIAGNOSTIC(39013, Warning, registerModifierButNoVulkanLayout, "shader parameter '$0' has a 'register' specified for D3D, but no '[[vk::binding(...)]]` specified for Vulkan") +DIAGNOSTIC(39014, Error, unexpectedSpecifierAfterSpace, "unexpected specifier after register space: '$0'") +DIAGNOSTIC(39015, Error, wholeSpaceParameterRequiresZeroBinding, "shader parameter '$0' consumes whole descriptor sets, so the binding must be in the form '[[vk::binding(0, ...)]]'; the non-zero binding '$1' is not allowed") + +DIAGNOSTIC(39013, Error, dontExpectOutParametersForStage, "the '$0' stage does not support `out` or `inout` entry point parameters") +DIAGNOSTIC(39014, Error, dontExpectInParametersForStage, "the '$0' stage does not support `in` entry point parameters") + +DIAGNOSTIC(39016, Error, globalUniformsNotSupported, "'$0' is implicitly a global uniform shader parameter, which is currently unsupported by Slang. If a uniform parameter is intended, use a constant buffer or parameter block. If a global is intended, use the 'static' modifier.") + +DIAGNOSTIC(39017, Error, tooManyShaderRecordConstantBuffers, "can have at most one 'shader record' attributed constant buffer; found $0.") + +// +// 4xxxx - IL code generation. +// +DIAGNOSTIC(40001, Error, bindingAlreadyOccupiedByComponent, "resource binding location '$0' is already occupied by component '$1'.") +DIAGNOSTIC(40002, Error, invalidBindingValue, "binding location '$0' is out of valid range.") +DIAGNOSTIC(40003, Error, bindingExceedsLimit, "binding location '$0' assigned to component '$1' exceeds maximum limit.") +DIAGNOSTIC(40004, Error, bindingAlreadyOccupiedByModule, "DescriptorSet ID '$0' is already occupied by module instance '$1'.") +DIAGNOSTIC(40005, Error, topLevelModuleUsedWithoutSpecifyingBinding, "top level module '$0' is being used without specifying binding location. Use [Binding: \"index\"] attribute to provide a binding location.") + + +DIAGNOSTIC(49999, Error, unknownSystemValueSemantic, "unknown system-value semantic '$0'") + +DIAGNOSTIC(40006, Error, needCompileTimeConstant, "expected a compile-time constant") + +DIAGNOSTIC(40007, Internal, irValidationFailed, "IR validation failed: $0") + +DIAGNOSTIC(40008, Error, invalidLValueForRefParameter, "the form of this l-value argument is not valid for a `ref` parameter") + +// 41000 - IR-level validation issues + +DIAGNOSTIC(41000, Warning, unreachableCode, "unreachable code detected") + +DIAGNOSTIC(41010, Warning, missingReturn, "control flow may reach end of non-'void' function") + +// +// 5xxxx - Target code generation. +// + +DIAGNOSTIC(50010, Internal, missingExistentialBindingsForParameter, "missing argument for existential parameter slot"); + +DIAGNOSTIC(50020, Error, invalidTessCoordType, "TessCoord must have vec2 or vec3 type.") +DIAGNOSTIC(50020, Error, invalidFragCoordType, "FragCoord must be a vec4.") +DIAGNOSTIC(50020, Error, invalidInvocationIdType, "InvocationId must have int type.") +DIAGNOSTIC(50020, Error, invalidThreadIdType, "ThreadId must have int type.") +DIAGNOSTIC(50020, Error, invalidPrimitiveIdType, "PrimitiveId must have int type.") +DIAGNOSTIC(50020, Error, invalidPatchVertexCountType, "PatchVertexCount must have int type.") +DIAGNOSTIC(50022, Error, worldIsNotDefined, "world '$0' is not defined."); +DIAGNOSTIC(50023, Error, stageShouldProvideWorldAttribute, "'$0' should provide 'World' attribute."); +DIAGNOSTIC(50040, Error, componentHasInvalidTypeForPositionOutput, "'$0': component used as 'loc' output must be of vec4 type.") +DIAGNOSTIC(50041, Error, componentNotDefined, "'$0': component not defined.") + +DIAGNOSTIC(50052, Error, domainShaderRequiresControlPointCount, "'DomainShader' requires attribute 'ControlPointCount'."); +DIAGNOSTIC(50052, Error, hullShaderRequiresControlPointCount, "'HullShader' requires attribute 'ControlPointCount'.") +DIAGNOSTIC(50052, Error, hullShaderRequiresControlPointWorld, "'HullShader' requires attribute 'ControlPointWorld'."); +DIAGNOSTIC(50052, Error, hullShaderRequiresCornerPointWorld, "'HullShader' requires attribute 'CornerPointWorld'."); +DIAGNOSTIC(50052, Error, hullShaderRequiresDomain, "'HullShader' requires attribute 'Domain'."); +DIAGNOSTIC(50052, Error, hullShaderRequiresInputControlPointCount, "'HullShader' requires attribute 'InputControlPointCount'.") +DIAGNOSTIC(50052, Error, hullShaderRequiresOutputTopology, "'HullShader' requires attribute 'OutputTopology'.") +DIAGNOSTIC(50052, Error, hullShaderRequiresPartitioning, "'HullShader' requires attribute 'Partitioning'.") +DIAGNOSTIC(50052, Error, hullShaderRequiresPatchWorld, "'HullShader' requires attribute 'PatchWorld'."); +DIAGNOSTIC(50052, Error, hullShaderRequiresTessLevelInner, "'HullShader' requires attribute 'TessLevelInner'.") +DIAGNOSTIC(50052, Error, hullShaderRequiresTessLevelOuter, "'HullShader' requires attribute 'TessLevelOuter'.") + +DIAGNOSTIC(50053, Error, invalidTessellationDomian, "'Domain' should be either 'triangles' or 'quads'."); +DIAGNOSTIC(50053, Error, invalidTessellationOutputTopology, "'OutputTopology' must be one of: 'point', 'line', 'triangle_cw', or 'triangle_ccw'."); +DIAGNOSTIC(50053, Error, invalidTessellationPartitioning, "'Partitioning' must be one of: 'integer', 'pow2', 'fractional_even', or 'fractional_odd'.") +DIAGNOSTIC(50053, Error, invalidTessellationDomain, "'Domain' should be either 'triangles' or 'quads'.") + +DIAGNOSTIC(50082, Error, importingFromPackedBufferUnsupported, "importing type '$0' from PackedBuffer is not supported by the GLSL backend.") +DIAGNOSTIC(51090, Error, cannotGenerateCodeForExternComponentType, "cannot generate code for extern component type '$0'.") +DIAGNOSTIC(51091, Error, typeCannotBePlacedInATexture, "type '$0' cannot be placed in a texture.") +DIAGNOSTIC(51092, Error, stageDoesntHaveInputWorld, "'$0' doesn't appear to have any input world"); + +DIAGNOSTIC(52000, Error, multiLevelBreakUnsupported, "control flow appears to require multi-level `break`, which Slang does not yet support"); + +DIAGNOSTIC(52001, Warning, dxilNotFound, "dxil shared library not found, so 'dxc' output cannot be signed! Shader code will not be runnable in non-development environments."); + +// 99999 - Internal compiler errors, and not-yet-classified diagnostics. + +DIAGNOSTIC(99999, Internal, unimplemented, "unimplemented feature in Slang compiler: $0") +DIAGNOSTIC(99999, Internal, unexpected, "unexpected condition encountered in Slang compiler: $0") +DIAGNOSTIC(99999, Internal, internalCompilerError, "Slang internal compiler error") +DIAGNOSTIC(99999, Error, compilationAborted, "Slang compilation aborted due to internal error"); +DIAGNOSTIC(99999, Error, compilationAbortedDueToException, "Slang compilation aborted due to an exception of $0: $1"); +DIAGNOSTIC(99999, Note, noteLocationOfInternalError, "the Slang compiler threw an exception while working on code near this location"); +DIAGNOSTIC(99999, Internal, serialDebugVerificationFailed, "Verification of serial debug information failed."); + +#undef DIAGNOSTIC diff --git a/source/slang/slang-diagnostics.cpp b/source/slang/slang-diagnostics.cpp new file mode 100644 index 000000000..4aabd3ab9 --- /dev/null +++ b/source/slang/slang-diagnostics.cpp @@ -0,0 +1,350 @@ +// slang-diagnostics.cpp +#include "slang-diagnostics.h" + +#include "slang-compiler.h" +#include "slang-name.h" +#include "slang-syntax.h" + +#include + +#ifdef _WIN32 +#define WIN32_LEAN_AND_MEAN +#define NOMINMAX +#include +#undef WIN32_LEAN_AND_MEAN +#undef NOMINMAX +#include +#endif + +namespace Slang { + +void printDiagnosticArg(StringBuilder& sb, char const* str) +{ + sb << str; +} + +void printDiagnosticArg(StringBuilder& sb, int32_t val) +{ + sb << val; +} + +void printDiagnosticArg(StringBuilder& sb, uint32_t val) +{ + sb << val; +} + +void printDiagnosticArg(StringBuilder& sb, int64_t val) +{ + sb << val; +} + +void printDiagnosticArg(StringBuilder& sb, uint64_t val) +{ + sb << val; +} + +void printDiagnosticArg(StringBuilder& sb, Slang::String const& str) +{ + sb << str; +} + +void printDiagnosticArg(StringBuilder& sb, Slang::UnownedStringSlice const& str) +{ + sb.append(str); +} + + +void printDiagnosticArg(StringBuilder& sb, Name* name) +{ + sb << getText(name); +} + + +void printDiagnosticArg(StringBuilder& sb, Decl* decl) +{ + sb << getText(decl->getName()); +} + +void printDiagnosticArg(StringBuilder& sb, Type* type) +{ + sb << type->ToString(); +} + +void printDiagnosticArg(StringBuilder& sb, Val* val) +{ + sb << val->ToString(); +} + +void printDiagnosticArg(StringBuilder& sb, TypeExp const& type) +{ + sb << type.type->ToString(); +} + +void printDiagnosticArg(StringBuilder& sb, QualType const& type) +{ + if (type.type) + sb << type.type->ToString(); + else + sb << ""; +} + +void printDiagnosticArg(StringBuilder& sb, TokenType tokenType) +{ + sb << TokenTypeToString(tokenType); +} + +void printDiagnosticArg(StringBuilder& sb, Token const& token) +{ + sb << token.Content; +} + +void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) +{ + switch( val ) + { + default: + sb << ""; + break; + +#define CASE(TAG, STR) case CodeGenTarget::TAG: sb << STR; break + CASE(GLSL, "glsl"); + CASE(HLSL, "hlsl"); + CASE(SPIRV, "spirv"); + CASE(SPIRVAssembly, "spriv-assembly"); + CASE(DXBytecode, "dxbc"); + CASE(DXBytecodeAssembly, "dxbc-assembly"); + CASE(DXIL, "dxil"); + CASE(DXILAssembly, "dxil-assembly"); +#undef CASE + } +} + +void printDiagnosticArg(StringBuilder& sb, Stage val) +{ + sb << getStageName(val); +} + +void printDiagnosticArg(StringBuilder& sb, ProfileVersion val) +{ + sb << Profile(val).getName(); +} + + +SourceLoc const& getDiagnosticPos(SyntaxNode const* syntax) +{ + return syntax->loc; +} + +SourceLoc const& getDiagnosticPos(Token const& token) +{ + return token.loc; +} + +SourceLoc const& getDiagnosticPos(TypeExp const& typeExp) +{ + return typeExp.exp->loc; +} + +SourceLoc const& getDiagnosticPos(IRInst* inst) +{ + return inst->sourceLoc; +} + + +// Take the format string for a diagnostic message, along with its arguments, and turn it into a +static void formatDiagnosticMessage(StringBuilder& sb, char const* format, int argCount, DiagnosticArg const* const* args) +{ + char const* spanBegin = format; + for(;;) + { + char const* spanEnd = spanBegin; + while (int c = *spanEnd) + { + if (c == '$') + break; + spanEnd++; + } + + sb.Append(spanBegin, int(spanEnd - spanBegin)); + if (!*spanEnd) + return; + + SLANG_ASSERT(*spanEnd == '$'); + spanEnd++; + int d = *spanEnd++; + switch (d) + { + // A double dollar sign `$$` is used to emit a single `$` + case '$': + sb.Append('$'); + break; + + // A single digit means to emit the corresponding argument. + // TODO: support more than 10 arguments, and add options + // to control formatting, etc. + case '0': case '1': case '2': case '3': case '4': + case '5': case '6': case '7': case '8': case '9': + { + int index = d - '0'; + if (index >= argCount) + { + // TODO(tfoley): figure out what a good policy will be for "panic" situations like this + throw InvalidOperationException("too few arguments for diagnostic message"); + } + else + { + DiagnosticArg const* arg = args[index]; + arg->printFunc(sb, arg->data); + } + } + break; + + default: + throw InvalidOperationException("invalid diagnostic message format"); + break; + } + + spanBegin = spanEnd; + } +} + +static void formatDiagnostic(const HumaneSourceLoc& humaneLoc, Diagnostic const& diagnostic, StringBuilder& outBuilder) +{ + outBuilder << humaneLoc.pathInfo.foundPath; + outBuilder << "("; + outBuilder << Int32(humaneLoc.line); + outBuilder << "): "; + + outBuilder << getSeverityName(diagnostic.severity); + + if (diagnostic.ErrorID >= 0) + { + outBuilder << " "; + outBuilder << diagnostic.ErrorID; + } + + outBuilder << ": "; + outBuilder << diagnostic.Message; + outBuilder << "\n"; +} + +static void formatDiagnostic( + DiagnosticSink* sink, + Diagnostic const& diagnostic, + StringBuilder& sb) +{ + auto sourceManager = sink->sourceManager; + + SourceView* sourceView = nullptr; + HumaneSourceLoc humaneLoc; + const auto sourceLoc = diagnostic.loc; + { + sourceView = sourceManager->findSourceViewRecursively(sourceLoc); + if (sourceView) + { + humaneLoc = sourceView->getHumaneLoc(sourceLoc); + } + formatDiagnostic(humaneLoc, diagnostic, sb); + } + + if (sourceView && (sink->flags & DiagnosticSink::Flag::VerbosePath)) + { + auto actualHumaneLoc = sourceView->getHumaneLoc(diagnostic.loc, SourceLocType::Actual); + + // Look up the path verbosely (will get the canonical path if necessary) + actualHumaneLoc.pathInfo.foundPath = sourceView->getSourceFile()->calcVerbosePath(); + + // Only output if it's actually different + if (actualHumaneLoc.pathInfo.foundPath != humaneLoc.pathInfo.foundPath || + actualHumaneLoc.line != humaneLoc.line || + actualHumaneLoc.column != humaneLoc.column) + { + formatDiagnostic(actualHumaneLoc, diagnostic, sb); + } + } +} + +void DiagnosticSink::diagnoseImpl(SourceLoc const& pos, DiagnosticInfo const& info, int argCount, DiagnosticArg const* const* args) +{ + StringBuilder sb; + formatDiagnosticMessage(sb, info.messageFormat, argCount, args); + + Diagnostic diagnostic; + diagnostic.ErrorID = info.id; + diagnostic.Message = sb.ProduceString(); + diagnostic.loc = pos; + diagnostic.severity = info.severity; + + if (diagnostic.severity >= Severity::Error) + { + errorCount++; + } + + // Did the client supply a callback for us to use? + if( writer ) + { + // If so, pass the error string along to them + StringBuilder messageBuilder; + formatDiagnostic(this, diagnostic, messageBuilder); + + writer->write(messageBuilder.getBuffer(), messageBuilder.getLength()); + } + else + { + // If the user doesn't have a callback, then just + // collect our diagnostic messages into a buffer + formatDiagnostic(this, diagnostic, outputBuffer); + } + + if (diagnostic.severity >= Severity::Fatal) + { + // TODO: figure out a better policy for aborting compilation + throw AbortCompilationException(); + } +} + +void DiagnosticSink::diagnoseRaw( + Severity severity, + char const* message) +{ + return diagnoseRaw(severity, UnownedStringSlice(message)); +} + +void DiagnosticSink::diagnoseRaw( + Severity severity, + const UnownedStringSlice& message) +{ + if (severity >= Severity::Error) + { + errorCount++; + } + + // Did the client supply a callback for us to use? + if(writer) + { + // If so, pass the error string along to them + writer->write(message.begin(), message.size()); + } + else + { + // If the user doesn't have a callback, then just + // collect our diagnostic messages into a buffer + outputBuffer.append(message); + } + + if (severity >= Severity::Fatal) + { + // TODO: figure out a better policy for aborting compilation + throw InvalidOperationException(); + } +} + + +namespace Diagnostics +{ +#define DIAGNOSTIC(id, severity, name, messageFormat) const DiagnosticInfo name = { id, Severity::severity, messageFormat }; +#include "slang-diagnostic-defs.h" +} + + +} // namespace Slang diff --git a/source/slang/slang-diagnostics.h b/source/slang/slang-diagnostics.h new file mode 100644 index 000000000..e1b9846d7 --- /dev/null +++ b/source/slang/slang-diagnostics.h @@ -0,0 +1,280 @@ +#ifndef RASTER_RENDERER_COMPILE_ERROR_H +#define RASTER_RENDERER_COMPILE_ERROR_H + +#include "../core/slang-basic.h" +#include "../core/slang-writer.h" + +#include "slang-source-loc.h" +#include "slang-token.h" + +#include "../../slang.h" + +namespace Slang +{ + enum class Severity + { + Note, + Warning, + Error, + Fatal, + Internal, + }; + + // TODO(tfoley): move this into a source file... + inline const char* getSeverityName(Severity severity) + { + switch (severity) + { + case Severity::Note: return "note"; + case Severity::Warning: return "warning"; + case Severity::Error: return "error"; + case Severity::Fatal: return "fatal error"; + case Severity::Internal: return "internal error"; + default: return "unknown error"; + } + } + + // A structure to be used in static data describing different + // diagnostic messages. + struct DiagnosticInfo + { + int id; + Severity severity; + char const* messageFormat; + }; + + class Diagnostic + { + public: + String Message; + SourceLoc loc; + int ErrorID; + Severity severity; + + Diagnostic() + { + ErrorID = -1; + } + Diagnostic( + const String & msg, + int id, + const SourceLoc & pos, + Severity severity) + : severity(severity) + { + Message = msg; + ErrorID = id; + loc = pos; + } + }; + + class Name; + class Decl; + struct QualType; + class Type; + struct TypeExp; + class Val; + + enum class CodeGenTarget; + enum class Stage : SlangStage; + enum class ProfileVersion; + + void printDiagnosticArg(StringBuilder& sb, char const* str); + + void printDiagnosticArg(StringBuilder& sb, int32_t val); + void printDiagnosticArg(StringBuilder& sb, uint32_t val); + + void printDiagnosticArg(StringBuilder& sb, int64_t val); + void printDiagnosticArg(StringBuilder& sb, uint64_t val); + + void printDiagnosticArg(StringBuilder& sb, Slang::String const& str); + void printDiagnosticArg(StringBuilder& sb, Slang::UnownedStringSlice const& str); + void printDiagnosticArg(StringBuilder& sb, Name* name); + void printDiagnosticArg(StringBuilder& sb, Decl* decl); + void printDiagnosticArg(StringBuilder& sb, Type* type); + void printDiagnosticArg(StringBuilder& sb, TypeExp const& type); + void printDiagnosticArg(StringBuilder& sb, QualType const& type); + void printDiagnosticArg(StringBuilder& sb, TokenType tokenType); + void printDiagnosticArg(StringBuilder& sb, Token const& token); + void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val); + void printDiagnosticArg(StringBuilder& sb, Stage val); + void printDiagnosticArg(StringBuilder& sb, ProfileVersion val); + void printDiagnosticArg(StringBuilder& sb, Val* val); + + template + void printDiagnosticArg(StringBuilder& sb, RefPtr ptr) + { + printDiagnosticArg(sb, ptr.Ptr()); + } + + inline SourceLoc const& getDiagnosticPos(SourceLoc const& pos) { return pos; } + + class SyntaxNode; + SourceLoc const& getDiagnosticPos(SyntaxNode const* syntax); + SourceLoc const& getDiagnosticPos(Token const& token); + SourceLoc const& getDiagnosticPos(TypeExp const& typeExp); + + struct IRInst; + SourceLoc const& getDiagnosticPos(IRInst* inst); + + template + SourceLoc getDiagnosticPos(RefPtr const& ptr) + { + return getDiagnosticPos(ptr.Ptr()); + } + + struct DiagnosticArg + { + void* data; + void (*printFunc)(StringBuilder&, void*); + + template + struct Helper + { + static void printFunc(StringBuilder& sb, void* data) { printDiagnosticArg(sb, *(T*)data); } + }; + + template + DiagnosticArg(T const& arg) + : data((void*)&arg) + , printFunc(&Helper::printFunc) + {} + }; + + class DiagnosticSink + { + public: + struct Flag + { + enum Enum: uint32_t + { + VerbosePath = 0x1, ///< Will display a more verbose path (if available) - such as a canonical or absolute path + }; + }; + typedef uint32_t Flags; + + StringBuilder outputBuffer; +// List diagnostics; + int errorCount = 0; + int internalErrorLocsNoted = 0; + + ISlangWriter* writer = nullptr; + Flags flags = 0; + + // The source manager to use when mapping source locations to file+line info + SourceManager* sourceManager = nullptr; + +/* + void Error(int id, const String & msg, const SourceLoc & pos) + { + diagnostics.Add(Diagnostic(msg, id, pos, Severity::Error)); + errorCount++; + } + + void Warning(int id, const String & msg, const SourceLoc & pos) + { + diagnostics.Add(Diagnostic(msg, id, pos, Severity::Warning)); + } +*/ + int GetErrorCount() { return errorCount; } + + void diagnoseDispatch(SourceLoc const& pos, DiagnosticInfo const& info) + { + diagnoseImpl(pos, info, 0, nullptr); + } + + void diagnoseDispatch(SourceLoc const& pos, DiagnosticInfo const& info, DiagnosticArg const& arg0) + { + DiagnosticArg const* args[] = { &arg0 }; + diagnoseImpl(pos, info, 1, args); + } + + void diagnoseDispatch(SourceLoc const& pos, DiagnosticInfo const& info, DiagnosticArg const& arg0, DiagnosticArg const& arg1) + { + DiagnosticArg const* args[] = { &arg0, &arg1 }; + diagnoseImpl(pos, info, 2, args); + } + + void diagnoseDispatch(SourceLoc const& pos, DiagnosticInfo const& info, DiagnosticArg const& arg0, DiagnosticArg const& arg1, DiagnosticArg const& arg2) + { + DiagnosticArg const* args[] = { &arg0, &arg1, &arg2 }; + diagnoseImpl(pos, info, 3, args); + } + + void diagnoseDispatch(SourceLoc const& pos, DiagnosticInfo const& info, DiagnosticArg const& arg0, DiagnosticArg const& arg1, DiagnosticArg const& arg2, DiagnosticArg const& arg3) + { + DiagnosticArg const* args[] = { &arg0, &arg1, &arg2, &arg3 }; + diagnoseImpl(pos, info, 4, args); + } + + template + void diagnose(P const& pos, DiagnosticInfo const& info, Args const&... args ) + { + diagnoseDispatch(getDiagnosticPos(pos), info, args...); + } + + void diagnoseImpl(SourceLoc const& pos, DiagnosticInfo const& info, int argCount, DiagnosticArg const* const* args); + + // Add a diagnostic with raw text + // (used when we get errors from a downstream compiler) + void diagnoseRaw( + Severity severity, + char const* message); + void diagnoseRaw( + Severity severity, + const UnownedStringSlice& message); + + /// During propagation of an exception for an internal + /// error, note that this source location was involved + void noteInternalErrorLoc(SourceLoc const& loc); + + SlangResult getBlobIfNeeded(ISlangBlob** outBlob); + }; + + /// An `ISlangWriter` that writes directly to a diagnostic sink. + class DiagnosticSinkWriter : public AppendBufferWriter + { + public: + typedef AppendBufferWriter Super; + + DiagnosticSinkWriter(DiagnosticSink* sink) + : Super(WriterFlag::IsStatic) + , m_sink(sink) + {} + + // ISlangWriter + SLANG_NO_THROW virtual SlangResult SLANG_MCALL write(const char* chars, size_t numChars) SLANG_OVERRIDE + { + m_sink->diagnoseRaw(Severity::Note, UnownedStringSlice(chars, chars+numChars)); + return SLANG_OK; + } + + private: + DiagnosticSink* m_sink = nullptr; + }; + + namespace Diagnostics + { +#define DIAGNOSTIC(id, severity, name, messageFormat) extern const DiagnosticInfo name; +#include "slang-diagnostic-defs.h" + } +} + +#ifdef _DEBUG +#define SLANG_INTERNAL_ERROR(sink, pos) \ + (sink)->diagnose(Slang::SourceLoc(__LINE__, 0, 0, __FILE__), Slang::Diagnostics::internalCompilerError) +#define SLANG_UNIMPLEMENTED(sink, pos, what) \ + (sink)->diagnose(Slang::SourceLoc(__LINE__, 0, 0, __FILE__), Slang::Diagnostics::unimplemented, what) + +#else +#define SLANG_INTERNAL_ERROR(sink, pos) \ + (sink)->diagnose(pos, Slang::Diagnostics::internalCompilerError) +#define SLANG_UNIMPLEMENTED(sink, pos, what) \ + (sink)->diagnose(pos, Slang::Diagnostics::unimplemented, what) + +#endif + +#define SLANG_DIAGNOSE_UNEXPECTED(sink, pos, message) \ + (sink)->diagnose(pos, Slang::Diagnostics::unexpected, message) + +#endif diff --git a/source/slang/slang-dxc-support.cpp b/source/slang/slang-dxc-support.cpp new file mode 100644 index 000000000..b4bc77fe5 --- /dev/null +++ b/source/slang/slang-dxc-support.cpp @@ -0,0 +1,302 @@ +// slang-dxc-support.cpp +#include "slang-compiler.h" + +// This file implements support for invoking the `dxcompiler` +// library to translate HLSL to DXIL. + +#if defined(_WIN32) +# if !defined(SLANG_ENABLE_DXIL_SUPPORT) +# define SLANG_ENABLE_DXIL_SUPPORT 1 +# endif +#endif + +#if !defined(SLANG_ENABLE_DXIL_SUPPORT) +# define SLANG_ENABLE_DXIL_SUPPORT 0 +#endif + +#if SLANG_ENABLE_DXIL_SUPPORT + +#define WIN32_LEAN_AND_MEAN +#define NOMINMAX +#include +#include +#include "../../external/dxc/dxcapi.h" +#undef WIN32_LEAN_AND_MEAN +#undef NOMINMAX + +#include "../core/slang-platform.h" + +namespace Slang +{ + String GetHLSLProfileName(Profile profile); + String emitHLSLForEntryPoint( + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq); + + static UnownedStringSlice _getSlice(IDxcBlob* blob) + { + if (blob) + { + const char* chars = (const char*)blob->GetBufferPointer(); + size_t len = blob->GetBufferSize(); + len -= size_t(len > 0 && chars[len - 1] == 0); + return UnownedStringSlice(chars, len); + } + return UnownedStringSlice(); + } + + SlangResult emitDXILForEntryPointUsingDXC( + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq, + List& outCode) + { + auto session = compileRequest->getSession(); + auto sink = compileRequest->getSink(); + + // First deal with all the rigamarole of loading + // the `dxcompiler` library, and creating the + // top-level COM objects that will be used to + // compile things. + + auto dxcCreateInstance = (DxcCreateInstanceProc)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Dxc_DxcCreateInstance, sink); + if (!dxcCreateInstance) + { + return SLANG_FAIL; + } + + { + if (!session->getSharedLibrary(SharedLibraryType::Dxil)) + { + // If can't load dxil - dxc will not be able to sign output + // Output a suitable warning to the user + sink->diagnose(SourceLoc(), Diagnostics::dxilNotFound); + } + } + + ComPtr dxcCompiler; + SLANG_RETURN_ON_FAIL(dxcCreateInstance( + CLSID_DxcCompiler, + __uuidof(dxcCompiler), + (LPVOID*)dxcCompiler.writeRef())); + + ComPtr dxcLibrary; + SLANG_RETURN_ON_FAIL(dxcCreateInstance( + CLSID_DxcLibrary, + __uuidof(dxcLibrary), + (LPVOID*)dxcLibrary.writeRef())); + + // Now let's go ahead and generate HLSL for the entry + // point, since we'll need that to feed into dxc. + auto hlslCode = emitHLSLForEntryPoint( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq); + maybeDumpIntermediate(compileRequest, hlslCode.getBuffer(), CodeGenTarget::HLSL); + + // Wrap the + + // Create blob from the string + ComPtr dxcSourceBlob; + SLANG_RETURN_ON_FAIL(dxcLibrary->CreateBlobWithEncodingFromPinned( + (LPBYTE)hlslCode.getBuffer(), + (UINT32)hlslCode.getLength(), + 0, + dxcSourceBlob.writeRef())); + + WCHAR const* args[16]; + UINT32 argCount = 0; + + // TODO: deal with + bool treatWarningsAsErrors = false; + if (treatWarningsAsErrors) + { + args[argCount++] = L"-WX"; + } + + switch( targetReq->getDefaultMatrixLayoutMode() ) + { + default: + break; + + case kMatrixLayoutMode_RowMajor: + args[argCount++] = L"-Zpr"; + break; + } + + switch( targetReq->getFloatingPointMode() ) + { + default: + break; + + case FloatingPointMode::Precise: + args[argCount++] = L"-Gis"; // "force IEEE strictness" + break; + } + + auto linkage = compileRequest->getLinkage(); + switch( linkage->optimizationLevel ) + { + default: + break; + + case OptimizationLevel::None: args[argCount++] = L"-Od"; break; + case OptimizationLevel::Default: args[argCount++] = L"-O1"; break; + case OptimizationLevel::High: args[argCount++] = L"-O2"; break; + case OptimizationLevel::Maximal: args[argCount++] = L"-O3"; break; + } + + switch( linkage->debugInfoLevel ) + { + case DebugInfoLevel::None: + break; + + default: + args[argCount++] = L"-Zi"; + break; + } + + // Slang strives to produce correct code, and by default + // we do not show the user warnings produced by a downstream + // compiler. When the downstream compiler *does* produce an + // error, then we dump its entire diagnostic log, which can + // include many distracting spurious warnings that have nothing + // to do with the user's code, and just relate to the idiomatic + // way that Slang outputs HLSL. + // + // It would be nice to use fine-grained flags to disable specific + // warnings here, so that we keep ourselves honest (e.g., only + // use `-Wno-parentheses` to eliminate that class of false positives), + // but alas dxc doesn't support these options even though they + // work on mainline Clang. Thus the only option we have available + // is the big hammer of turning off *all* warnings coming from dxc. + // + args[argCount++] = L"-no-warnings"; + + String entryPointName = getText(entryPoint->getName()); + OSString wideEntryPointName = entryPointName.toWString(); + + auto profile = getEffectiveProfile(entryPoint, targetReq); + String profileName = GetHLSLProfileName(profile); + OSString wideProfileName = profileName.toWString(); + + // We will enable the flag to generate proper code for 16-bit types + // by default, as long as the user is requesting a sufficiently + // high shader model. + // + // TODO: Need to check that this is safe to enable in all cases, + // or if it will make a shader demand hardware features that + // aren't always present. + // + // TODO: Ideally the dxc back-end should be passed some information + // on the "capabilities" that were used and/or requested in the code. + // + if( profile.GetVersion() >= ProfileVersion::DX_6_2 ) + { + args[argCount++] = L"-enable-16bit-types"; + } + + const String sourcePath = calcSourcePathForEntryPoint(endToEndReq, entryPointIndex); + + ComPtr dxcResult; + SLANG_RETURN_ON_FAIL(dxcCompiler->Compile(dxcSourceBlob, + sourcePath.toWString().begin(), + profile.GetStage() == Stage::Unknown ? L"" : wideEntryPointName.begin(), + wideProfileName.begin(), + args, + argCount, + nullptr, // `#define`s + 0, // `#define` count + nullptr, // `#include` handler + dxcResult.writeRef())); + + // Retrieve result. + HRESULT resultCode = S_OK; + SLANG_RETURN_ON_FAIL(dxcResult->GetStatus(&resultCode)); + + // Note: it seems like the dxcompiler interface + // doesn't support querying diagnostic output + // *unless* the compile failed (no way to get + // warnings out!?). + + // Verify compile result + if (SLANG_FAILED(resultCode)) + { + // Compilation failed. + // Try to read any diagnostic output. + ComPtr dxcErrorBlob; + SLANG_RETURN_ON_FAIL(dxcResult->GetErrorBuffer(dxcErrorBlob.writeRef())); + + // Note: the error blob returned by dxc doesn't always seem + // to be nul-terminated, so we should be careful and turn it + // into a string for safety. + // + + reportExternalCompileError("dxc", resultCode, _getSlice(dxcErrorBlob), compileRequest->getSink()); + return resultCode; + } + + // Okay, the compile supposedly succeeded, so we + // just need to grab the buffer with the output DXIL. + ComPtr dxcResultBlob; + SLANG_RETURN_ON_FAIL(dxcResult->GetResult(dxcResultBlob.writeRef())); + + outCode.addRange( + (uint8_t const*)dxcResultBlob->GetBufferPointer(), + (int) dxcResultBlob->GetBufferSize()); + + return SLANG_OK; + } + + SlangResult dissassembleDXILUsingDXC( + BackEndCompileRequest* compileRequest, + void const* data, + size_t size, + String& stringOut) + { + stringOut = String(); + auto session = compileRequest->getSession(); + auto sink = compileRequest->getSink(); + + // First deal with all the rigamarole of loading + // the `dxcompiler` library, and creating the + // top-level COM objects that will be used to + // compile things. + + auto dxcCreateInstance = (DxcCreateInstanceProc)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Dxc_DxcCreateInstance, sink); + if (!dxcCreateInstance) + { + return SLANG_FAIL; + } + + ComPtr dxcCompiler; + SLANG_RETURN_ON_FAIL(dxcCreateInstance(CLSID_DxcCompiler, __uuidof(dxcCompiler), (LPVOID*) dxcCompiler.writeRef())); + ComPtr dxcLibrary; + SLANG_RETURN_ON_FAIL(dxcCreateInstance(CLSID_DxcLibrary, __uuidof(dxcLibrary), (LPVOID*) dxcLibrary.writeRef())); + + // Create blob from the input data + ComPtr dxcSourceBlob; + SLANG_RETURN_ON_FAIL(dxcLibrary->CreateBlobWithEncodingFromPinned((LPBYTE) data, (UINT32) size, 0, dxcSourceBlob.writeRef())); + + ComPtr dxcResultBlob; + SLANG_RETURN_ON_FAIL(dxcCompiler->Disassemble(dxcSourceBlob, dxcResultBlob.writeRef())); + + stringOut = _getSlice(dxcResultBlob); + + return SLANG_OK; + } + + +} // namespace Slang + +#endif + + + diff --git a/source/slang/slang-emit-context.h b/source/slang/slang-emit-context.h index 75e65feee..337d0434f 100644 --- a/source/slang/slang-emit-context.h +++ b/source/slang/slang-emit-context.h @@ -2,10 +2,10 @@ #ifndef SLANG_EMIT_CONTEXT_H_INCLUDED #define SLANG_EMIT_CONTEXT_H_INCLUDED -#include "../core/basic.h" +#include "../core/slang-basic.h" -#include "compiler.h" -#include "type-layout.h" +#include "slang-compiler.h" +#include "slang-type-layout.h" #include "slang-source-stream.h" #include "slang-extension-usage-tracker.h" diff --git a/source/slang/slang-emit-precedence.h b/source/slang/slang-emit-precedence.h index 30783f685..8d8a146c6 100644 --- a/source/slang/slang-emit-precedence.h +++ b/source/slang/slang-emit-precedence.h @@ -2,7 +2,7 @@ #ifndef SLANG_EMIT_PRECEDENCE_H_INCLUDED #define SLANG_EMIT_PRECEDENCE_H_INCLUDED -#include "../core/basic.h" +#include "../core/slang-basic.h" namespace Slang { diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp new file mode 100644 index 000000000..67b57660a --- /dev/null +++ b/source/slang/slang-emit.cpp @@ -0,0 +1,510 @@ +// slang-emit.cpp +#include "slang-emit.h" + +#include "../core/slang-writer.h" +#include "slang-ir-bind-existentials.h" +#include "slang-ir-dce.h" +#include "slang-ir-entry-point-uniforms.h" +#include "slang-ir-glsl-legalize.h" +#include "slang-ir-insts.h" +#include "slang-ir-link.h" +#include "slang-ir-restructure.h" +#include "slang-ir-restructure-scoping.h" +#include "slang-ir-specialize.h" +#include "slang-ir-specialize-resources.h" +#include "slang-ir-ssa.h" +#include "slang-ir-union.h" +#include "slang-ir-validate.h" +#include "slang-legalize-types.h" +#include "slang-lower-to-ir.h" +#include "slang-mangle.h" +#include "slang-name.h" +#include "slang-syntax.h" +#include "slang-type-layout.h" +#include "slang-visitor.h" + +#include "slang-source-stream.h" +#include "slang-emit-context.h" + +#include "slang-c-like-source-emitter.h" + +#include + +namespace Slang { + +enum class BuiltInCOp +{ + Splat, //< Splat a single value to all values of a vector or matrix type + Init, //< Initialize with parameters (must match the type) +}; + + +// + + +// + +EntryPointLayout* findEntryPointLayout( + ProgramLayout* programLayout, + EntryPoint* entryPoint) +{ + for( auto entryPointLayout : programLayout->entryPoints ) + { + if(entryPointLayout->entryPoint->getName() != entryPoint->getName()) + continue; + + // TODO: We need to be careful about this check, since it relies on + // the profile information in the layout matching that in the request. + // + // What we really seem to want here is some dictionary mapping the + // `EntryPoint` directly to the `EntryPointLayout`, and maybe + // that is precisely what we should build... + // + if(entryPointLayout->profile != entryPoint->getProfile()) + continue; + + // TODO: can't easily filter on translation unit here... + // Ideally the `EntryPoint` should get filled in with a pointer + // the specific function declaration that represents the entry point. + + return entryPointLayout.Ptr(); + } + + return nullptr; +} + + /// Given a layout computed for a scope, get the layout to use when lookup up variables. + /// + /// A scope (such as the global scope of a program) groups its + /// parameters into a pseudo-`struct` type for layout purposes, + /// and in some cases that type will in turn be wrapped in a + /// `ConstantBuffer` type to indicate that the parameters needed + /// an implicit constant buffer to be allocated. + /// + /// This function "unwraps" the type layout to find the structure + /// type layout that must be stored inside. + /// +StructTypeLayout* getScopeStructLayout( + ScopeLayout* scopeLayout) +{ + auto scopeTypeLayout = scopeLayout->parametersLayout->typeLayout; + + if( auto constantBufferTypeLayout = as(scopeTypeLayout) ) + { + scopeTypeLayout = constantBufferTypeLayout->offsetElementTypeLayout; + } + + if( auto structTypeLayout = as(scopeTypeLayout) ) + { + return structTypeLayout; + } + + SLANG_UNEXPECTED("uhandled global-scope binding layout"); + return nullptr; +} + + /// Given a layout computed for a program, get the layout to use when lookup up variables. + /// + /// This is just an alias of `getScopeStructLayout`. + /// +StructTypeLayout* getGlobalStructLayout( + ProgramLayout* programLayout) +{ + return getScopeStructLayout(programLayout); +} + +static void dumpIR( + BackEndCompileRequest* compileRequest, + IRModule* irModule, + char const* label) +{ + DiagnosticSinkWriter writerImpl(compileRequest->getSink()); + WriterHelper writer(&writerImpl); + + if(label) + { + writer.put("### "); + writer.put(label); + writer.put(":\n"); + } + + dumpIR(irModule, writer.getWriter()); + + if( label ) + { + writer.put("###\n"); + } +} + +static void dumpIRIfEnabled( + BackEndCompileRequest* compileRequest, + IRModule* irModule, + char const* label = nullptr) +{ + if(compileRequest->shouldDumpIR) + { + dumpIR(compileRequest, irModule, label); + } +} + +String emitEntryPoint( + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + CodeGenTarget target, + TargetRequest* targetRequest) +{ + auto sink = compileRequest->getSink(); + auto program = compileRequest->getProgram(); + auto targetProgram = program->getTargetProgram(targetRequest); + auto programLayout = targetProgram->getOrCreateLayout(sink); + +// auto translationUnit = entryPoint->getTranslationUnit(); + + auto lineDirectiveMode = compileRequest->getLineDirectiveMode(); + // To try to make the default behavior reasonable, we will + // always use C-style line directives (to give the user + // good source locations on error messages from downstream + // compilers) *unless* they requested raw GLSL as the + // output (in which case we want to maximize compatibility + // with downstream tools). + if (lineDirectiveMode == LineDirectiveMode::Default && targetRequest->getTarget() == CodeGenTarget::GLSL) + { + lineDirectiveMode = LineDirectiveMode::GLSL; + } + + SourceStream sourceStream(compileRequest->getSourceManager(), lineDirectiveMode ); + + EmitContext emitContext; + emitContext.compileRequest = compileRequest; + emitContext.target = target; + emitContext.entryPoint = entryPoint; + emitContext.effectiveProfile = getEffectiveProfile(entryPoint, targetRequest); + emitContext.stream = &sourceStream; + + if (entryPoint && programLayout) + { + emitContext.entryPointLayout = findEntryPointLayout( + programLayout, + entryPoint); + } + + emitContext.programLayout = programLayout; + + // Layout information for the global scope is either an ordinary + // `struct` in the common case, or a constant buffer in the case + // where there were global-scope uniforms. + + StructTypeLayout* globalStructLayout = programLayout ? getGlobalStructLayout(programLayout) : nullptr; + emitContext.globalStructLayout = globalStructLayout; + + CLikeSourceEmitter sourceEmitter(&emitContext); + + { + auto session = targetRequest->getSession(); + + // We start out by performing "linking" at the level of the IR. + // This step will create a fresh IR module to be used for + // code generation, and will copy in any IR definitions that + // the desired entry point requires. Along the way it will + // resolve references to imported/exported symbols across + // modules, and also select between the definitions of + // any "profile-overloaded" symbols. + // + auto linkedIR = linkIR( + compileRequest, + entryPoint, + programLayout, + target, + targetRequest); + auto irModule = linkedIR.module; + auto irEntryPoint = linkedIR.entryPoint; + +#if 0 + dumpIRIfEnabled(compileRequest, irModule, "LINKED"); +#endif + + validateIRModuleIfEnabled(compileRequest, irModule); + + // If the user specified the flag that they want us to dump + // IR, then do it here, for the target-specific, but + // un-specialized IR. + dumpIRIfEnabled(compileRequest, irModule); + + // When there are top-level existential-type parameters + // to the shader, we need to take the side-band information + // on how the existential "slots" were bound to concrete + // types, and use it to introduce additional explicit + // shader parameters for those slots, to be wired up to + // use sites. + // + bindExistentialSlots(irModule, sink); +#if 0 + dumpIRIfEnabled(compileRequest, irModule, "EXISTENTIALS BOUND"); +#endif + validateIRModuleIfEnabled(compileRequest, irModule); + + + + + + // Now that we've linked the IR code, any layout/binding + // information has been attached to shader parameters + // and entry points. Now we are safe to make transformations + // that might move code without worrying about losing + // the connection between a parameter and its layout. + // + // An easy transformation of this kind is to take uniform + // parameters of a shader entry point and move them into + // the global scope instead. + // + moveEntryPointUniformParamsToGlobalScope(irModule); +#if 0 + dumpIRIfEnabled(compileRequest, irModule, "ENTRY POINT UNIFORMS MOVED"); +#endif + validateIRModuleIfEnabled(compileRequest, irModule); + + // Desguar any union types, since these will be illegal on + // various targets. + // + desugarUnionTypes(irModule); +#if 0 + dumpIRIfEnabled(compileRequest, irModule, "UNIONS DESUGARED"); +#endif + validateIRModuleIfEnabled(compileRequest, irModule); + + // Next, we need to ensure that the code we emit for + // the target doesn't contain any operations that would + // be illegal on the target platform. For example, + // none of our target supports generics, or interfaces, + // so we need to specialize those away. + // + // Simplification of existential-based and generics-based + // code may each open up opportunities for the other, so + // the relevant specialization transformations are handled in a + // single pass that looks for all simplification opportunities. + // + // TODO: We also need to extend this pass so that it will "expose" + // existential values that are nested inside of other types, + // so that the simplifications can be applied. + // + // TODO: This pass is *also* likely to be the place where we + // perform specialization of functions based on parameter + // values that need to be compile-time constants. + // + specializeModule(irModule); + + // Debugging code for IR transformations... +#if 0 + dumpIRIfEnabled(compileRequest, irModule, "SPECIALIZED"); +#endif + validateIRModuleIfEnabled(compileRequest, irModule); + + + // Specialization can introduce dead code that could trip + // up downstream passes like type legalization, so we + // will run a DCE pass to clean up after the specialization. + // + // TODO: Are there other cleanup optimizations we should + // apply at this point? + // + eliminateDeadCode(compileRequest, irModule); +#if 0 + dumpIRIfEnabled(compileRequest, irModule, "AFTER DCE"); +#endif + validateIRModuleIfEnabled(compileRequest, irModule); + + // The Slang language allows interfaces to be used like + // ordinary types (including placing them in constant + // buffers and entry-point parameter lists), but then + // getting them to lay out in a reasonable way requires + // us to treat fields/variables with interface type + // *as if* they were pointers to heap-allocated "objects." + // + // Specialization will have replaced fields/variables + // with interface types like `IFoo` with fields/variables + // with pointer-like types like `ExistentialBox`. + // + // We need to legalize these pointer-like types away, + // which involves two main changes: + // + // 1. Any `ExistentialBox<...>` fields need to be moved + // out of their enclosing `struct` type, so that the layout + // of the enclosing type is computed as if the field had + // zero size. + // + // 2. Once an `ExistentialBox` has been floated out + // of its parent and landed somwhere permanent (e.g., either + // a dedicated variable, or a field of constant buffer), + // we need to replace it with just an `X`, after which we + // will have (more) legal shader code. + // + legalizeExistentialTypeLayout( + irModule, + sink); + eliminateDeadCode(compileRequest, irModule); + +#if 0 + dumpIRIfEnabled(compileRequest, irModule, "EXISTENTIALS LEGALIZED"); +#endif + validateIRModuleIfEnabled(compileRequest, irModule); + + // Many of our target languages and/or downstream compilers + // don't support `struct` types that have resource-type fields. + // In order to work around this limitation, we will rewrite the + // IR so that any structure types with resource-type fields get + // split into a "tuple" that comprises the ordinary fields (still + // bundles up as a `struct`) and one element for each resource-type + // field (recursively). + // + // What used to be individual variables/parameters/arguments/etc. + // then become multiple variables/parameters/arguments/etc. + // + legalizeResourceTypes( + irModule, + sink); + eliminateDeadCode(compileRequest, irModule); + + // Debugging output of legalization +#if 0 + dumpIRIfEnabled(compileRequest, irModule, "LEGALIZED"); +#endif + validateIRModuleIfEnabled(compileRequest, irModule); + + // Once specialization and type legalization have been performed, + // we should perform some of our basic optimization steps again, + // to see if we can clean up any temporaries created by legalization. + // (e.g., things that used to be aggregated might now be split up, + // so that we can work with the individual fields). + constructSSA(irModule); + +#if 0 + dumpIRIfEnabled(compileRequest, irModule, "AFTER SSA"); +#endif + validateIRModuleIfEnabled(compileRequest, irModule); + + // After type legalization and subsequent SSA cleanup we expect + // that any resource types passed to functions are exposed + // as their own top-level parameters (which might have + // resource or array-of-...-resource types). + // + // Many of our targets place restrictions on how certain + // resource types can be used, so that having them as + // function parameters is invalid. To clean this up, + // we will try to specialize called functions based + // on the actual resources that are being passed to them + // at specific call sites. + // + // Because the legalization may depend on what target + // we are compiling for (certain things might be okay + // for D3D targets that are not okay for Vulkan), we + // pass down the target request along with the IR. + // + specializeResourceParameters(compileRequest, targetRequest, irModule); + +#if 0 + dumpIRIfEnabled(compileRequest, irModule, "AFTER RESOURCE SPECIALIZATION"); +#endif + validateIRModuleIfEnabled(compileRequest, irModule); + + + // For GLSL only, we will need to perform "legalization" of + // the entry point and any entry-point parameters. + // + // TODO: We should consider moving this legalization work + // as late as possible, so that it doesn't affect how other + // optimization passes need to work. + // + switch (target) + { + case CodeGenTarget::GLSL: + { + legalizeEntryPointForGLSL( + session, + irModule, + irEntryPoint, + compileRequest->getSink(), + &emitContext.extensionUsageTracker); + +#if 0 + dumpIRIfEnabled(compileRequest, irModule, "GLSL LEGALIZED"); +#endif + validateIRModuleIfEnabled(compileRequest, irModule); + } + break; + + default: + break; + } + + // The resource-based specialization pass above + // may create specialized versions of functions, but + // it does not try to completely eliminate the original + // functions, so there might still be invalid code in + // our IR module. + // + // To clean up the code, we will apply a fairly general + // dead-code-elimination (DCE) pass that only retains + // whatever code is "live." + // + eliminateDeadCode(compileRequest, irModule); +#if 0 + dumpIRIfEnabled(compileRequest, irModule, "AFTER DCE"); +#endif + validateIRModuleIfEnabled(compileRequest, irModule); + + // After all of the required optimization and legalization + // passes have been performed, we can emit target code from + // the IR module. + // + // TODO: do we want to emit directly from IR, or translate the + // IR back into AST for emission? + sourceEmitter.emitIRModule(irModule); + } + + // Deal with cases where a particular stage requires certain GLSL versions + // and/or extensions. + switch( entryPoint->getStage() ) + { + default: + break; + + case Stage::AnyHit: + case Stage::Callable: + case Stage::ClosestHit: + case Stage::Intersection: + case Stage::Miss: + case Stage::RayGeneration: + if( target == CodeGenTarget::GLSL ) + { + emitContext.extensionUsageTracker.requireGLSLExtension("GL_NV_ray_tracing"); + emitContext.extensionUsageTracker.requireGLSLVersion(ProfileVersion::GLSL_460); + } + break; + } + + String code = sourceStream.getContent(); + sourceStream.clearContent(); + + // Now that we've emitted the code for all the declarations in the file, + // it is time to stitch together the final output. + + // There may be global-scope modifiers that we should emit now + sourceEmitter.emitGLSLPreprocessorDirectives(); + + sourceEmitter.emitLayoutDirectives(targetRequest); + + String prefix = sourceStream.getContent(); + + StringBuilder finalResultBuilder; + finalResultBuilder << prefix; + + finalResultBuilder << emitContext.extensionUsageTracker.getGLSLExtensionRequireLines(); + + finalResultBuilder << code; + + String finalResult = finalResultBuilder.ProduceString(); + + return finalResult; +} + +} // namespace Slang diff --git a/source/slang/slang-emit.h b/source/slang/slang-emit.h new file mode 100644 index 000000000..c0981a5e4 --- /dev/null +++ b/source/slang/slang-emit.h @@ -0,0 +1,27 @@ +// slang-emit.h +#ifndef SLANG_EMIT_H_INCLUDED +#define SLANG_EMIT_H_INCLUDED + +#include "../core/slang-basic.h" + +#include "slang-compiler.h" + +namespace Slang +{ + class EntryPoint; + class ProgramLayout; + class TranslationUnitRequest; + + // Emit code for a single entry point, based on + // the input translation unit. + String emitEntryPoint( + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + + // The target language to generate code in (e.g., HLSL/GLSL) + CodeGenTarget target, + + // The full target request + TargetRequest* targetRequest); +} +#endif diff --git a/source/slang/slang-expr-defs.h b/source/slang/slang-expr-defs.h new file mode 100644 index 000000000..6cd893302 --- /dev/null +++ b/source/slang/slang-expr-defs.h @@ -0,0 +1,206 @@ +// slang-expr-defs.h + +// Syntax class definitions for expressions. + + +// Base class for expressions that will reference declarations +ABSTRACT_SYNTAX_CLASS(DeclRefExpr, Expr) + +// The scope in which to perform lookup + FIELD(RefPtr, scope) + + // The declaration of the symbol being referenced + DECL_FIELD(DeclRef, declRef) + + // The name of the symbol being referenced + FIELD(Name*, name) +END_SYNTAX_CLASS() + +SIMPLE_SYNTAX_CLASS(VarExpr, DeclRefExpr) + +// An expression that references an overloaded set of declarations +// having the same name. +SYNTAX_CLASS(OverloadedExpr, Expr) + + // Optional: the base expression is this overloaded result + // arose from a member-reference expression. + SYNTAX_FIELD(RefPtr, base) + + // The lookup result that was ambiguous + FIELD(LookupResult, lookupResult2) +END_SYNTAX_CLASS() + +// An expression that references an overloaded set of declarations +// having the same name. +SYNTAX_CLASS(OverloadedExpr2, Expr) + + // Optional: the base expression is this overloaded result + // arose from a member-reference expression. + SYNTAX_FIELD(RefPtr, base) + + // The lookup result that was ambiguous + FIELD(List>, candidiateExprs) +END_SYNTAX_CLASS() + +ABSTRACT_SYNTAX_CLASS(LiteralExpr, Expr) + // The token that was used to express the literal. This can be + // used to get the raw text of the literal, including any suffix. + FIELD(Token, token) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(IntegerLiteralExpr, LiteralExpr) + FIELD(IntegerLiteralValue, value) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(FloatingPointLiteralExpr, LiteralExpr) + FIELD(FloatingPointLiteralValue, value) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(BoolLiteralExpr, LiteralExpr) + FIELD(bool, value) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(StringLiteralExpr, LiteralExpr) + // TODO: consider storing the "segments" of the string + // literal, in the case where multiple literals were + //lined up at the lexer level, e.g.: + // + // "first" "second" "third" + // + FIELD(String, value) +END_SYNTAX_CLASS() + +// An initializer list, e.g. `{ 1, 2, 3 }` +SYNTAX_CLASS(InitializerListExpr, Expr) + SYNTAX_FIELD(List>, args) +END_SYNTAX_CLASS() + +// A base class for expressions with arguments +ABSTRACT_SYNTAX_CLASS(ExprWithArgsBase, Expr) + SYNTAX_FIELD(List>, Arguments) +END_SYNTAX_CLASS() + +// An aggregate type constructor +SYNTAX_CLASS(AggTypeCtorExpr, ExprWithArgsBase) + SYNTAX_FIELD(TypeExp, base); +END_SYNTAX_CLASS() + + +// A base expression being applied to arguments: covers +// both ordinary `()` function calls and `<>` generic application +ABSTRACT_SYNTAX_CLASS(AppExprBase, ExprWithArgsBase) + SYNTAX_FIELD(RefPtr, FunctionExpr) +END_SYNTAX_CLASS() + +SIMPLE_SYNTAX_CLASS(InvokeExpr, AppExprBase) + +SIMPLE_SYNTAX_CLASS(OperatorExpr, InvokeExpr) + +SIMPLE_SYNTAX_CLASS(InfixExpr , OperatorExpr) +SIMPLE_SYNTAX_CLASS(PrefixExpr , OperatorExpr) +SIMPLE_SYNTAX_CLASS(PostfixExpr, OperatorExpr) + +SYNTAX_CLASS(IndexExpr, Expr) + SYNTAX_FIELD(RefPtr, BaseExpression) + SYNTAX_FIELD(RefPtr, IndexExpression) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(MemberExpr, DeclRefExpr) + SYNTAX_FIELD(RefPtr, BaseExpression) +END_SYNTAX_CLASS() + +// Member looked up on a type, rather than a value +SYNTAX_CLASS(StaticMemberExpr, DeclRefExpr) + SYNTAX_FIELD(RefPtr, BaseExpression) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(SwizzleExpr, Expr) + SYNTAX_FIELD(RefPtr, base) + FIELD(int, elementCount) + FIELD(int, elementIndices[4]) +END_SYNTAX_CLASS() + +// A dereference of a pointer or pointer-like type +SYNTAX_CLASS(DerefExpr, Expr) + SYNTAX_FIELD(RefPtr, base) +END_SYNTAX_CLASS() + +// Any operation that performs type-casting +SYNTAX_CLASS(TypeCastExpr, InvokeExpr) +// SYNTAX_FIELD(TypeExp, TargetType) +// SYNTAX_FIELD(RefPtr, Expression) +END_SYNTAX_CLASS() + +// An explicit type-cast that appear in the user's code with `(type) expr` syntax +SYNTAX_CLASS(ExplicitCastExpr, TypeCastExpr) +END_SYNTAX_CLASS() + +// An implicit type-cast inserted during semantic checking +SYNTAX_CLASS(ImplicitCastExpr, TypeCastExpr) +END_SYNTAX_CLASS() + + /// A cast from a value to an interface ("existential") type. +SYNTAX_CLASS(CastToInterfaceExpr, Expr) +RAW( + /// The value being cast to an interface type + RefPtr valueArg; + + /// A witness showing that `valueArg` conforms to the chosen interface + RefPtr witnessArg; +) +END_SYNTAX_CLASS() + +SIMPLE_SYNTAX_CLASS(SelectExpr, OperatorExpr) + +SIMPLE_SYNTAX_CLASS(GenericAppExpr, AppExprBase) + +// An expression representing re-use of the syntax for a type in more +// than once conceptually-distinct declaration +SYNTAX_CLASS(SharedTypeExpr, Expr) + // The underlying type expression that we want to share + SYNTAX_FIELD(TypeExp, base) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(AssignExpr, Expr) + SYNTAX_FIELD(RefPtr, left); + SYNTAX_FIELD(RefPtr, right); +END_SYNTAX_CLASS() + +// Just an expression inside parentheses `(exp)` +// +// We keep this around explicitly to be sure we don't lose any structure +// when we do rewriter stuff. +SYNTAX_CLASS(ParenExpr, Expr) + SYNTAX_FIELD(RefPtr, base); +END_SYNTAX_CLASS() + +// An object-oriented `this` expression, used to +// refer to the current instance of an enclosing type. +SYNTAX_CLASS(ThisExpr, Expr) + FIELD(RefPtr, scope); +END_SYNTAX_CLASS() + +// An expression that binds a temporary variable in a local expression context +SYNTAX_CLASS(LetExpr, Expr) +RAW( + RefPtr decl; + RefPtr body; +) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(ExtractExistentialValueExpr, Expr) +RAW( + DeclRef declRef; +) +END_SYNTAX_CLASS() + + /// A type expression of the form `__TaggedUnion(A, ...)`. + /// + /// An expression of this form will resolve to a `TaggedUnionType` + /// when checked. + /// +SYNTAX_CLASS(TaggedUnionTypeExpr, Expr) +RAW( + List caseTypes; +) +END_SYNTAX_CLASS() diff --git a/source/slang/slang-extension-usage-tracker.h b/source/slang/slang-extension-usage-tracker.h index 32002261d..d17a7a6a1 100644 --- a/source/slang/slang-extension-usage-tracker.h +++ b/source/slang/slang-extension-usage-tracker.h @@ -2,9 +2,9 @@ #ifndef SLANG_EXTENSION_USAGE_TRACKER_H_INCLUDED #define SLANG_EXTENSION_USAGE_TRACKER_H_INCLUDED -#include "../core/basic.h" +#include "../core/slang-basic.h" -#include "compiler.h" +#include "slang-compiler.h" namespace Slang { diff --git a/source/slang/slang-file-system.cpp b/source/slang/slang-file-system.cpp index 9cd2ee035..f8c423e84 100644 --- a/source/slang/slang-file-system.cpp +++ b/source/slang/slang-file-system.cpp @@ -4,7 +4,7 @@ #include "../core/slang-io.h" #include "../core/slang-string-util.h" -#include "compiler.h" +#include "slang-compiler.h" namespace Slang { diff --git a/source/slang/slang-file-system.h b/source/slang/slang-file-system.h index a9e92dc1e..e89bf45db 100644 --- a/source/slang/slang-file-system.h +++ b/source/slang/slang-file-system.h @@ -6,7 +6,7 @@ #include "../../slang-com-ptr.h" #include "../core/slang-string-util.h" -#include "../core/dictionary.h" +#include "../core/slang-dictionary.h" namespace Slang { diff --git a/source/slang/slang-image-format-defs.h b/source/slang/slang-image-format-defs.h new file mode 100644 index 000000000..aa6ffec50 --- /dev/null +++ b/source/slang/slang-image-format-defs.h @@ -0,0 +1,47 @@ +// slang-image-format-defs.h +#ifndef FORMAT +#error Must define FORMAT macro before including image-format-defs.h +#endif + +FORMAT(unknown) +FORMAT(rgba32f) +FORMAT(rgba16f) +FORMAT(rg32f) +FORMAT(rg16f) +FORMAT(r11f_g11f_b10f) +FORMAT(r32f) +FORMAT(r16f) +FORMAT(rgba16) +FORMAT(rgb10_a2) +FORMAT(rgba8) +FORMAT(rg16) +FORMAT(rg8) +FORMAT(r16) +FORMAT(r8) +FORMAT(rgba16_snorm) +FORMAT(rgba8_snorm) +FORMAT(rg16_snorm) +FORMAT(rg8_snorm) +FORMAT(r16_snorm) +FORMAT(r8_snorm) +FORMAT(rgba32i) +FORMAT(rgba16i) +FORMAT(rgba8i) +FORMAT(rg32i) +FORMAT(rg16i) +FORMAT(rg8i) +FORMAT(r32i) +FORMAT(r16i) +FORMAT(r8i) +FORMAT(rgba32ui) +FORMAT(rgba16ui) +FORMAT(rgb10_a2ui) +FORMAT(rgba8ui) +FORMAT(rg32ui) +FORMAT(rg16ui) +FORMAT(rg8ui) +FORMAT(r32ui) +FORMAT(r16ui) +FORMAT(r8ui) + +#undef FORMAT diff --git a/source/slang/slang-ir-bind-existentials.cpp b/source/slang/slang-ir-bind-existentials.cpp new file mode 100644 index 000000000..e426e6e92 --- /dev/null +++ b/source/slang/slang-ir-bind-existentials.cpp @@ -0,0 +1,352 @@ +// slang-ir-bind-existentials.cpp +#include "slang-ir-bind-existentials.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" + +namespace Slang +{ + +// The code that comes out of the linking step will have instructions added +// that indicate how parameters with existential (interface) types are supposed +// to be specialized to concrete types. +// +// If there are any global existential-type parameters there should be a +// `bindGlobalExistentialSlots(...)` instruction at module scope. +// +// For each entry point with entry-point existential parameters, there should +// be a `[bindExistentialSlots(...)]` decoration attached to the entry +// point itself. +// +// In each case, the operands of the instruction should be a sequence of +// pairs. The number of pairs should match the number of existential "slots" +// at global or entry-point scope. Each pair should comprise a type `T` +// to plug into the slot, and a witness table `w` for the conformance of +// `T` to the interface type in that slot. +// +// In the simplest case, if we have a global shader parameter of interface +// type: +// +// IFoo p; +// +// Then this will lower to the IR as: +// +// global_param p : IFoo; +// +// And if the user tries to specialize `p` to type `Bar`, and a witness +// table `bar_is_ifoo`, we've have: +// +// bindGlobalExistentialSlots(Bar, bar_is_ifoo); +// +// The goal of this pass is to replace the parameter of interface type +// with one of concrete type: +// +// global_param p_new : Bar; +// +// and replace any reference to the old `p` parameter with +// a `makeExistential(p_new, bar_is_ifoo)`. That preserves the +// fact that a reference to `p` is conceptually of type `IFoo`, +// but allows downstream optimization passes to start specializing +// code based on the concrete knowledge that the value "backing" +// the parameter is actaully of type `Bar`. + +// As is typically for IR passes, we will encapsulate all the +// logic in a `struct` type. +// +struct BindExistentialSlots +{ + IRModule* module = nullptr; + DiagnosticSink* sink = nullptr; + + void processModule() + { + // We will start by dealing with the global existential slots. + processGlobalExistentialSlots(); + + // Then we will process the per-entry-point existential slots. + processEntryPointExistentialSlots(); + } + + void processGlobalExistentialSlots() + { + // If there are any global existential slots, we will expect + // to find a `bindGlobalExistentialSlots` instruction at module scope. + // + // We will start out by finding that instruction, if it exists. + // + IRInst* bindGlobalExistentialSlotsInst = nullptr; + for( auto inst : module->getGlobalInsts() ) + { + if( inst->op == kIROp_BindGlobalExistentialSlots ) + { + bindGlobalExistentialSlotsInst = inst; + break; + } + } + + // Now we will start looking for global shader parameters that make + // use of existential slots (we can determine this from their + // layout). + // + for( auto inst : module->getGlobalInsts() ) + { + // We only care about global shader parameters. + // + auto globalParam = as(inst); + if(!globalParam) + continue; + + // We will delegate to a subroutine for the meat + // of the work, since much of it can be shared + // with the case for entry-point existential + // parameters. + // + processParameter(globalParam, bindGlobalExistentialSlotsInst); + } + + // Once we are done looping over global shader parameters, + // all of the relevant information from the + // `bindGlobalExistentialSlots` instruction will have + // been moved to the parameters themselves, so we + // can eliminate the binding instruction. + // + if( bindGlobalExistentialSlotsInst ) + { + bindGlobalExistentialSlotsInst->removeAndDeallocate(); + } + } + + void processEntryPointExistentialSlots() + { + // The overall flow for the entry-point case is similar + // to the global case. + // + // We start by iterating over all the functions at + // global scope and look for entry points. + // + for( auto inst : module->getGlobalInsts() ) + { + auto func = as(inst); + if(!func) + continue; + + if(!func->findDecorationImpl(kIROp_EntryPointDecoration)) + continue; + + // We then process each entry point we find. + // + processEntryPointExistentialSlots(func); + } + } + + void processEntryPointExistentialSlots(IRFunc* func) + { + // When looking at a single `func`, we need + // to find the `[bindExistentialSlots(...)]` decoration, + // if it has one. + // + auto bindEntryPointExistentialSlotsInst = func->findDecorationImpl(kIROp_BindExistentialSlotsDecoration); + + // We then need to process each of the entry-point + // parameters just like we did for global parameters. + // + for( auto param : func->getParams() ) + { + processParameter(param, bindEntryPointExistentialSlotsInst); + } + + // TODO: We would need to consider what to do if + // we had an existential return type for `func`. + // + // In general, it probably doesn't make sense to + // have existential types in varying input/output + // at all, so the front-end should probably be + // validating that. + + // Once we've processed all the parameters, the information + // in the `[bindExistentialSlots(...)]` decoration is + // no longer needed, and we can remove it. + // + if( bindEntryPointExistentialSlotsInst ) + { + bindEntryPointExistentialSlotsInst->removeAndDeallocate(); + } + } + + // When processing a single parameter we need to have access + // to the corresponding instruction that will bind its slots. + // + // We don't care whether we have a `global_param` and a + // `bindGlobalExistentialSlots` instruction, or an entry-point + // function `param` and a `[bindExistentialSlots(...)]` + // decoration; both use the same subroutine. + // + void processParameter( + IRInst* param, + IRInst* bindSlotsInst) + { + // We expect all shader parameters to have layout information, + // but to be defensive we will skip any that don't. + // + auto layoutDecoration = param->findDecoration(); + if(!layoutDecoration) + return; + auto varLayout = as(layoutDecoration->getLayout()); + if(!varLayout) + return; + + // We only care about parameters that are associated + // with one or more existential slots. + // + auto resInfo = varLayout->FindResourceInfo(LayoutResourceKind::ExistentialTypeParam); + if(!resInfo) + return; + + // We will use the layout information on the variable to + // find out the stating slot, and the information on + // the type to find out the number of slots. + // + UInt firstSlot = resInfo->index; + UInt slotCount = 0; + if(auto typeResInfo = varLayout->getTypeLayout()->FindResourceInfo(LayoutResourceKind::ExistentialTypeParam)) + slotCount = UInt(typeResInfo->count.getFiniteValue()); + + // At this point we know that the parameter consumes + // some number of slots, so it would be an error + // if we don't have an instruction to bind the slots. + // + if( !bindSlotsInst ) + { + // Note: This error is considered an internal error because + // we should be detecting and diagnosing this problem before + // we make it to back-end code generation. + // + sink->diagnose(param->sourceLoc, Diagnostics::missingExistentialBindingsForParameter); + return; + } + + // Each existential slot corresponds to *two* arguments + // on the binding instruction: one for the type, and + // another for the witness table. + // + // We will check to make sure we have enough operands to cover + // this parameter. + // + UInt bindOperandCount = bindSlotsInst->getOperandCount(); + if( 2*(firstSlot + slotCount) > bindOperandCount ) + { + sink->diagnose(param->sourceLoc, Diagnostics::missingExistentialBindingsForParameter); + return; + } + // + // If there are enough operands, then we will offset to + // get to the starting point for the current parameter, + // keeping in mind that each slot accounts for two + // operands. + // + auto operandsForInst = bindSlotsInst->getOperands() + firstSlot; + + // Once we've found the operands that are relevent to + // the slots used by `param`, we will defer to a routine + // that replaces the type of `param` based on the + // information in the slots. + // + replaceTypeUsingExistentialSlots( + param, + slotCount, + operandsForInst); + } + + void replaceTypeUsingExistentialSlots( + IRInst* inst, + UInt slotCount, + IRUse const* slotArgs) + { + SLANG_UNUSED(slotCount); + + // We are going to alter the type of the + // given `inst` based on information in + // the `slotArgs`. + + auto fullType = inst->getFullType(); + + SharedIRBuilder sharedBuilder; + sharedBuilder.session = module->getSession(); + sharedBuilder.module = module; + + IRBuilder builder; + builder.sharedBuilder = &sharedBuilder; + + // Every argument that is filling an existential + // type param/slot comprises both a type and + // a witness table, so the total number of operands + // is twice the number of slots we are filling. + // + UInt slotOperandCount = slotCount*2; + List slotOperands; + for(UInt ii = 0; ii < slotOperandCount; ++ii) + slotOperands.add(slotArgs[ii].get()); + + // We are going to create a proxy type that represents + // the results of plugging all the information + // from the existential slots into the original type. + // + auto newType = builder.getBindExistentialsType( + fullType, + slotOperandCount, + slotOperands.getBuffer()); + + // We will replace the type of the original parameter + // with the new proxy type. + // + builder.setDataType(inst, newType); + + // Next we want to replace all uses of `inst` (which + // expect a value of its old type) with a fresh + // `wrapExistential(...)` instruction that refers to + // `inst` with its new type. + // + // Note: we make a copy of the list of uses for `inst` + // before going through and replacing them, because + // during the replacement we make *more* uses of `inst`, + // as an operand to the `makeExistential` instructions. + // We only want to replace the old uses, and not the + // new ones we'll be making. + // + List usesToReplace; + for(auto use = inst->firstUse; use; use = use->nextUse ) + usesToReplace.add(use); + + // Now we can loop over our list of uses and replace each. + // + for(auto use : usesToReplace) + { + // First we emit a `makeExisential` right before the + // use site. + // + builder.setInsertBefore(use->getUser()); + auto newVal = builder.emitWrapExistential( + fullType, + inst, + slotOperandCount, + slotOperands.getBuffer()); + + // Second we make the use site point at the new + // value instead. + // + use->set(newVal); + } + } +}; + +void bindExistentialSlots( + IRModule* module, + DiagnosticSink* sink) +{ + BindExistentialSlots context; + context.module = module; + context.sink = sink; + context.processModule(); +} + +} diff --git a/source/slang/slang-ir-bind-existentials.h b/source/slang/slang-ir-bind-existentials.h new file mode 100644 index 000000000..4c32ee8c2 --- /dev/null +++ b/source/slang/slang-ir-bind-existentials.h @@ -0,0 +1,15 @@ +// slang-ir-bind-existentials.h +#pragma once + +namespace Slang +{ + +class DiagnosticSink; +struct IRModule; + + /// Bind concrete types to paameters that use existential slots. +void bindExistentialSlots( + IRModule* module, + DiagnosticSink* sink); + +} diff --git a/source/slang/slang-ir-clone.cpp b/source/slang/slang-ir-clone.cpp new file mode 100644 index 000000000..df0555f9b --- /dev/null +++ b/source/slang/slang-ir-clone.cpp @@ -0,0 +1,295 @@ +// slang-ir-clone.cpp +#include "slang-ir-clone.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" + +namespace Slang +{ + +IRInst* lookUp(IRCloneEnv* env, IRInst* oldVal) +{ + for( auto ee = env; ee; ee = ee->parent ) + { + IRInst* newVal = nullptr; + if(ee->mapOldValToNew.TryGetValue(oldVal, newVal)) + return newVal; + } + return nullptr; +} + +IRInst* findCloneForOperand( + IRCloneEnv* env, + IRInst* oldOperand) +{ + if(!oldOperand) return nullptr; + + // If there is a registered replacement for + // the existing operand, then use it. + // + if( IRInst* newVal = lookUp(env, oldOperand) ) + return newVal; + + // Otherwise, we assume that the caller wants + // to default to using existing values wherever + // an explicit replacement hasn't been registered. + // + // This is, notably, the right default whenever + // `oldOperand` is a global value or constant + // and our cloned code will sit in the same + // module as the original. + // + // TODO: We could make this a customization point + // down the road, if we ever had a case where + // we want to clone things with a different policy. + // + return oldOperand; +} + +IRInst* cloneInstAndOperands( + IRCloneEnv* env, + IRBuilder* builder, + IRInst* oldInst) +{ + SLANG_ASSERT(env); + SLANG_ASSERT(builder); + SLANG_ASSERT(oldInst); + + // This logic will not handle any instructions + // with special-case data attached, but that only + // applies to `IRConstant`s at this point, and those + // should only appear at the global scope rather than + // in function bodies. + // + // TODO: It would be easy enough to extend this logic + // to handle constants gracefully, if it ever comes up. + // + SLANG_ASSERT(!as(oldInst)); + + // We start by mapping the type of the orignal instruction + // to its replacement value, if any. + // + auto oldType = oldInst->getFullType(); + auto newType = (IRType*) findCloneForOperand(env, oldType); + + // Next we will create an empty shell of the instruction, + // with space for the operands, but no actual operand + // values attached. + // + UInt operandCount = oldInst->getOperandCount(); + auto newInst = builder->emitIntrinsicInst( + newType, + oldInst->op, + operandCount, + nullptr); + + // Finally we will iterate over the operands of `oldInst` + // to find their replacements and install them as + // the operands of `newInst`. + // + for(UInt ii = 0; ii < operandCount; ++ii) + { + auto oldOperand = oldInst->getOperand(ii); + auto newOperand = findCloneForOperand(env, oldOperand); + + newInst->getOperands()[ii].init(newInst, newOperand); + } + + return newInst; +} + +// The complexity of the second phase of cloning (the +// one that deals with decorations and children) comes +// from the fact that it needs to sequence the two phases +// of cloning for any child instructions. We will do this +// by performing the first phase of cloning, and building +// up a list of children that require the second phase of processing. +// Each entry in that list will be a pair of an old instruction +// and its new clone. +// +struct IRCloningOldNewPair +{ + IRInst* oldInst; + IRInst* newInst; +}; + +// We will use an internal variant of `cloneInstDecorationsAndChildren` +// that modifies the provided `env` as it goes as the main +// workhorse, since we need to make sure that instructions in +// earlier blocks are visible to those in other, later, blocks +// when cloning a function, so that strict scoping along the +// lines of the nesting of instructions isn't sufficient. +// +static void _cloneInstDecorationsAndChildren( + IRCloneEnv* env, + SharedIRBuilder* sharedBuilder, + IRInst* oldInst, + IRInst* newInst) +{ + SLANG_ASSERT(env); + SLANG_ASSERT(sharedBuilder); + SLANG_ASSERT(oldInst); + SLANG_ASSERT(newInst); + + // We will set up an IR builder that inserts + // into the new parent instruction. + // + IRBuilder builderStorage; + auto builder = &builderStorage; + builder->sharedBuilder = sharedBuilder; + builder->setInsertInto(newInst); + + // When applying the first phase of cloning to + // children, we will keep track of those that + // require the second phase. + // + List pairs; + + for( auto oldChild : oldInst->getDecorationsAndChildren() ) + { + // As a very subtle special case, if one of the children + // of our `oldInst` already has a registered replacement, + // then we don't want to clone it (not least because + // the `Dictionary::Add` method would give us an error + // when we try to insert a new value for the same key). + // + // This arises for entries in `mapOldValToNew` that were + // seeded before cloning begain (e.g., function + // parameters that are to be replaced). + // + if(lookUp(env, oldChild)) + continue; + + // Now we can perform the first phase of cloning + // on the child, and register it in our map from + // old to new values. + // + auto newChild = cloneInstAndOperands(env, builder, oldChild); + env->mapOldValToNew.Add(oldChild, newChild); + + // If and only if the old child had decorations + // or children, we will register it into our + // list for processing in the second phase. + // + if( oldChild->getFirstDecorationOrChild() ) + { + IRCloningOldNewPair pair; + pair.oldInst = oldChild; + pair.newInst = newChild; + pairs.add(pair); + } + } + + // Once we have done first-phase processing for + // all child instructions, we scan through those + // in the list that required second-phase processing, + // and clone their decorations and/or children recursively. + // + for( auto pair : pairs ) + { + auto oldChild = pair.oldInst; + auto newChild = pair.newInst; + + _cloneInstDecorationsAndChildren(env, sharedBuilder, oldChild, newChild); + } +} + +// The public version of `cloneInstDecorationsAndChildren` is then +// just a wrapper over the internal one that sets up a temporary +// environment to use for the cloning process, so that we do +// not leave any lasting changes in the user-provided `env`. +// +void cloneInstDecorationsAndChildren( + IRCloneEnv* env, + SharedIRBuilder* sharedBuilder, + IRInst* oldInst, + IRInst* newInst) +{ + SLANG_ASSERT(sharedBuilder); + SLANG_ASSERT(oldInst); + SLANG_ASSERT(newInst); + + IRCloneEnv subEnvStorage; + auto subEnv = &subEnvStorage; + subEnv->parent = env; + + _cloneInstDecorationsAndChildren(subEnv, sharedBuilder, oldInst, newInst); +} + +// The convenience function `cloneInst` just sequences the +// operations that have already been defined. +// +IRInst* cloneInst( + IRCloneEnv* env, + IRBuilder* builder, + IRInst* oldInst) +{ + SLANG_ASSERT(env); + SLANG_ASSERT(builder); + SLANG_ASSERT(oldInst); + + auto newInst = cloneInstAndOperands( + env, builder, oldInst); + + env->mapOldValToNew.Add(oldInst, newInst); + + cloneInstDecorationsAndChildren( + env, builder->sharedBuilder, oldInst, newInst); + + return newInst; +} + +void cloneDecoration( + IRDecoration* oldDecoration, + IRInst* newParent, + IRModule* module) +{ + SharedIRBuilder sharedBuilder; + sharedBuilder.module = module; + + IRBuilder builder; + builder.sharedBuilder = &sharedBuilder; + + if(auto first = newParent->getFirstDecorationOrChild()) + builder.setInsertBefore(first); + else + builder.setInsertInto(newParent); + + IRCloneEnv env; + cloneInst(&env, &builder, oldDecoration); +} + +void cloneDecoration( + IRDecoration* oldDecoration, + IRInst* newParent) +{ + cloneDecoration( + oldDecoration, + newParent, + newParent->getModule()); +} + +bool IRSimpleSpecializationKey::operator==(IRSimpleSpecializationKey const& other) const +{ + auto valCount = vals.getCount(); + if(valCount != other.vals.getCount()) return false; + for( Index ii = 0; ii < valCount; ++ii ) + { + if(vals[ii] != other.vals[ii]) return false; + } + return true; +} + +int IRSimpleSpecializationKey::GetHashCode() const +{ + auto valCount = vals.getCount(); + int hash = Slang::GetHashCode(valCount); + for( Index ii = 0; ii < valCount; ++ii ) + { + hash = combineHash(hash, Slang::GetHashCode(vals[ii])); + } + return hash; +} + + +} // namespace Slang diff --git a/source/slang/slang-ir-clone.h b/source/slang/slang-ir-clone.h new file mode 100644 index 000000000..d2d3b1f55 --- /dev/null +++ b/source/slang/slang-ir-clone.h @@ -0,0 +1,183 @@ +// slang-ir-clone.h +#pragma once + +#include "../core/slang-dictionary.h" + +#include "slang-ir.h" + +namespace Slang +{ +struct IRBuilder; +struct IRInst; +struct SharedIRBuilder; + +// This file provides an interface to simplify the task of +// correcting "cloning" IR code, whether individual +// instructions, or whole functions. + + /// An environment for mapping existing values to their cloned replacements. + /// + /// This type serves two main roles in the process of IR cloning: + /// + /// * Before cloning begins, a client will usually + /// register the mapping from things that are to be + /// replaced entirely (like function parameters to + /// be specialized away) to their replacements (e.g., + /// a constant value). + /// + /// * During the process of cloning, env environment + /// will be maintained and updated so that when, e.g., + /// an instruction later in a function refers to + /// something from earlier, we can look up the + /// replacement. + /// +struct IRCloneEnv +{ + /// A mapping from old values to their replacements. + Dictionary mapOldValToNew; + + /// A parent environment to fall back to if `mapOldValToNew` doesn't contain a key. + IRCloneEnv* parent = nullptr; +}; + + /// Look up the replacement for `oldVal`, if any, registered in `env`. + /// + /// Returns `nullptr` if `oldVal` has no registered replacement. + /// +IRInst* lookUp(IRCloneEnv* env, IRInst* oldVal); + +// The SSA property and the way we have structured +// our "phi nodes" (block parameters) means that +// just going through the children of a function, +// and then the children of a block will generally +// do the Right Thing and always visit an instruction +// before its uses. +// +// The big exception to this is that branch instructions +// can refer to blocks later in the same function. +// +// We work around this sort of problem in a fairly +// general fashion, by splitting the cloning of +// an instruction into two steps. +// +// The first step is just to clone the instruction +// and its direct operands, but not any decorations +// or children. + + /// Clone `oldInst` and its direct operands. + /// + /// The "direct operands" include the type of the instruction. + /// The type and operands of `oldInst` will be mapped to now + /// values using `findOrCloneOperand` with the given `env`. + /// + /// Any new instruction that gets emitted will be output to + /// the provided `builder`, which must be non-null. + /// + /// This operation does *not* clone any children or decorations on `oldInst`. + /// This operation does *not* register its result as a replacement + /// for `oldInst` in the given `env`. + /// +IRInst* cloneInstAndOperands( + IRCloneEnv* env, + IRBuilder* builder, + IRInst* oldInst); + +// The second phase of cloning an instruction is to clone +// its decorations and children. This step only needs to +// be performed on those instructions that *have* decorations +// and/or children. + + /// Clone any decorations and/or children of `oldInst` onto `newInst` + /// + /// Any new instructions that get emitted will use the + /// provided `sharedBuilder`, which must be non-null. + /// + /// During the process of cloning decorations/children, operand values + /// will be looked up in the provided `env`, which should provide + /// replacement values for instructions that should have a different + /// identity in the clone. + /// The provided `env` will *not* be updated/modified during the + /// process of cloding decorations/children. + /// + /// If any child or decoration on `oldInst` already has a replacement + /// registered in `env`, it will *not* be cloned into `newInst`. + /// +void cloneInstDecorationsAndChildren( + IRCloneEnv* env, + SharedIRBuilder* sharedBuilder, + IRInst* oldInst, + IRInst* newInst); + +// For the case where the user knows the sequencing constraints +// on cloning operands before uses can be satisfied, we provide +// a convenience wrapper around the two phases of cloning: + + /// Clone `oldInst` and return the cloned value. + /// + /// This function is a convenience wrapper around + /// `cloneInstAndOperands` and `cloneInstDecorationsAndChildren`. + /// It also registers the resultint instruction as + /// the replacement value for `oldInst` in the given `env` + /// which must therefore be non-null. + /// +IRInst* cloneInst( + IRCloneEnv* env, + IRBuilder* builder, + IRInst* oldInst); + + /// Clone `oldDecoration` and attach the clone to `newParent`. + /// + /// Uses `module` to allocate any new instructions. + /// +void cloneDecoration( + IRDecoration* oldDecoration, + IRInst* newParent, + IRModule* module); + + /// Clone `oldDecoration` and attach the clone to `newParent`. + /// + /// Uses the module of `newParent` to allocate any new instructions, + /// so that `newParent` must already be installed somewhere + /// in the ownership hierarchy of an existing module. + /// +void cloneDecoration( + IRDecoration* oldDecoration, + IRInst* newParent); + + + /// Find the "cloned" value to use for an operand. + /// + /// This either returns the value registered for `oldOperand` + /// in `env`, or else `oldOperand` itself. +IRInst* findCloneForOperand( + IRCloneEnv* env, + IRInst* oldOperand); + +// It isn't technically part of the cloning infrastructure, +// but when make specialized copies of IR instructions via +// cloning we often need a simple kind of key suitable +// for caching existing specializations, so we'll define +// it here so that is is easily accessible to code that +// needs it. + +struct IRSimpleSpecializationKey +{ + // The structure of a specialization key will be a list + // of instructions, typically starting with the function, + // generic, or other object to be specialized, and then + // having one or more entries to represent the specialization + // arguments. + // + List vals; + + // In order to use this type as a `Dictionary` key we + // need it to support equality and hashing. + // + // TODO: honestly we might consider having `GetHashCode` + // and `operator==` defined for `List`. + + bool operator==(IRSimpleSpecializationKey const& other) const; + int GetHashCode() const; +}; + +} diff --git a/source/slang/slang-ir-constexpr.cpp b/source/slang/slang-ir-constexpr.cpp new file mode 100644 index 000000000..f041d7ae6 --- /dev/null +++ b/source/slang/slang-ir-constexpr.cpp @@ -0,0 +1,553 @@ +// slang-ir-constexpr.cpp +#include "slang-ir-constexpr.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" + +namespace Slang { + +struct PropagateConstExprContext +{ + IRModule* module; + IRModule* getModule() { return module; } + + DiagnosticSink* sink; + + SharedIRBuilder sharedBuilder; + IRBuilder builder; + + List workList; + HashSet onWorkList; + + IRBuilder* getBuilder() { return &builder; } + + Session* getSession() { return sharedBuilder.session; } + + DiagnosticSink* getSink() { return sink; } +}; + +bool isConstExpr(IRType* fullType) +{ + if( auto rateQualifiedType = as(fullType)) + { + auto rate = rateQualifiedType->getRate(); + if(auto constExprRate = as(rate)) + return true; + } + + return false; +} + +bool isConstExpr(IRInst* value) +{ + // Certain IR value ops are implicitly `constexpr` + // + // TODO: should we just go ahead and make that explicit + // in the type system? + switch(value->op) + { + case kIROp_IntLit: + case kIROp_FloatLit: + case kIROp_BoolLit: + case kIROp_Func: + return true; + + default: + break; + } + + if(isConstExpr(value->getFullType())) + return true; + + return false; +} + +bool opCanBeConstExpr(IROp op) +{ + switch( op ) + { + case kIROp_IntLit: + case kIROp_FloatLit: + case kIROp_BoolLit: + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_Div: + case kIROp_Mod: + case kIROp_Neg: + case kIROp_Construct: + case kIROp_makeVector: + case kIROp_makeArray: + case kIROp_MakeMatrix: + // TODO: more cases + return true; + + default: + return false; + } +} + +bool opCanBeConstExpr(IRInst* value) +{ + // TODO: realistically need to special-case `call` + // operations here, so that we check whether the + // callee function is fixed/known, and if it is + // whether it has been decoared as constant-foldable + + return opCanBeConstExpr(value->op); +} + +void markConstExpr( + PropagateConstExprContext* context, + IRInst* value) +{ + Slang::markConstExpr(context->getBuilder(), value); +} + + +// Propagate `constexpr`-ness in a forward direction, from the +// operands of an instruction to the instruction itself. +bool propagateConstExprForward( + PropagateConstExprContext* context, + IRGlobalValueWithCode* code) +{ + bool anyChanges = false; + for(;;) + { + bool changedThisIteration = false; + for( auto bb = code->getFirstBlock(); bb; bb = bb->getNextBlock() ) + { + for( auto ii = bb->getFirstInst(); ii; ii = ii->getNextInst() ) + { + // Instruction already `constexpr`? Then skip it. + if(isConstExpr(ii)) + continue; + + // Is the operation one that we can actually make be constexpr? + if(!opCanBeConstExpr(ii)) + continue; + + // Are all arguments `constexpr`? + bool allArgsConstExpr = true; + UInt argCount = ii->getOperandCount(); + for( UInt aa = 0; aa < argCount; ++aa ) + { + auto arg = ii->getOperand(aa); + + if( !isConstExpr(arg) ) + { + allArgsConstExpr = false; + break; + } + } + if(!allArgsConstExpr) + continue; + + // Seems like this operation can/should be made constexpr + markConstExpr(context, ii); + changedThisIteration = true; + } + } + + if( !changedThisIteration ) + return anyChanges; + + anyChanges = true; + } +} + +void maybeAddToWorkList( + PropagateConstExprContext* context, + IRInst* gv) +{ + if( !context->onWorkList.Contains(gv) ) + { + context->workList.add(gv); + context->onWorkList.Add(gv); + } +} + +bool maybeMarkConstExpr( + PropagateConstExprContext* context, + IRInst* value) +{ + if(isConstExpr(value)) + return false; + + if(!opCanBeConstExpr(value)) + return false; + + markConstExpr(context, value); + + // TODO: we should only allow function parameters to be + // changed to be `constexpr` when we are compiling "application" + // code, and not library code. + // (Or eventually we'd have a rule that only non-`public` symbols + // can have this kind of propagation applied). + + if(value->op == kIROp_Param) + { + auto param = (IRParam*) value; + auto block = (IRBlock*) param->parent; + auto code = block->getParent(); + + if(block == code->getFirstBlock()) + { + // We've just changed a function parameter to + // be `constexpr`. We need to remember that + // fact so taht we can mark callers of this + // function as `constexpr` themselves. + + for( auto u = code->firstUse; u; u = u->nextUse ) + { + auto user = u->getUser(); + + switch( user->op ) + { + case kIROp_Call: + { + auto inst = (IRCall*) user; + auto caller = as(inst->getParent()->getParent()); + maybeAddToWorkList(context, caller); + } + break; + + default: + break; + } + } + } + } + + return true; +} + +// Propagate `constexpr`-ness in a backward direction, from an instruction +// to its operands. +bool propagateConstExprBackward( + PropagateConstExprContext* context, + IRGlobalValueWithCode* code) +{ + SharedIRBuilder sharedBuilder; + sharedBuilder.module = context->getModule(); + sharedBuilder.session = sharedBuilder.module->session; + + IRBuilder builder; + builder.sharedBuilder = &sharedBuilder; + builder.setInsertInto(code); + + bool anyChanges = false; + for(;;) + { + // Note: we are walking the list of blocks and the instructions + // in each block in reverse order, to maximize the chances that + // we propagate multiple changes in a each pass. + // + // TODO: this should probably all be done with a work list instead, + // but that requires being able to detect instructions vs. other + // values. + + bool changedThisIteration = false; + for( auto bb = code->getLastBlock(); bb; bb = bb->getPrevBlock() ) + { + for( auto ii = bb->getLastInst(); ii; ii = ii->getPrevInst() ) + { + if( isConstExpr(ii) ) + { + // If this instruction is `constexpr`, then its operands should be too. + UInt argCount = ii->getOperandCount(); + for( UInt aa = 0; aa < argCount; ++aa ) + { + auto arg = ii->getOperand(aa); + if(isConstExpr(arg)) + continue; + + if(!opCanBeConstExpr(arg)) + continue; + + if( maybeMarkConstExpr(context, arg) ) + { + changedThisIteration = true; + } + } + } + else if( ii->op == kIROp_Call ) + { + // A non-constexpr call might be calling a function with one or + // more constexpr parameters. We should check if we can resolve + // the callee for this call statically, and if so try to propagate + // constexpr from the parameters back to the arguments. + auto callInst = (IRCall*) ii; + + UInt operandCount = callInst->getOperandCount(); + + UInt firstCallArg = 1; + UInt callArgCount = operandCount - firstCallArg; + + auto callee = callInst->getOperand(0); + + // If we are calling a generic operation, then + // try to follow through the `specialize` chain + // and find the callee. + // + // TODO: This probably shouldn't be required, + // since we can hopefully use the type of the + // callee in all cases. + // + while(auto specInst = as(callee)) + { + auto genericInst = as(specInst->getBase()); + if(!genericInst) + break; + + auto returnVal = findGenericReturnVal(genericInst); + if(!returnVal) + break; + + callee = returnVal; + } + + auto calleeFunc = as(callee); + if(calleeFunc && isDefinition(calleeFunc)) + { + // We have an IR-level function definition we are calling, + // and thus we can propagate `constexpr` information + // through its `IRParam`s. + + auto calleeFuncType = calleeFunc->getDataType(); + + UInt callParamCount = calleeFuncType->getParamCount(); + SLANG_RELEASE_ASSERT(callParamCount == callArgCount); + + // If the callee has a definition, then we can read `constexpr` + // information off of the parameters of its first IR block. + if(auto calleeFirstBlock = calleeFunc->getFirstBlock()) + { + UInt paramCounter = 0; + for(auto pp = calleeFirstBlock->getFirstParam(); pp; pp = pp->getNextParam()) + { + UInt paramIndex = paramCounter++; + + auto param = pp; + auto arg = callInst->getOperand(firstCallArg + paramIndex); + + if(isConstExpr(param)) + { + if(maybeMarkConstExpr(context, arg)) + { + changedThisIteration = true; + } + } + } + } + } + else + { + // If we don't have a concrete callee function + // definition, then we need to extract the + // type of the callee instruction, and try to work + // with that. + // + // Note that this does not allow us to propagate + // `constexpr` information from the body of a callee + // back to call sites. + auto calleeType = callee->getDataType(); + if(auto caleeFuncType = as(calleeType)) + { + auto paramCount = caleeFuncType->getParamCount(); + for( UInt pp = 0; pp < paramCount; ++pp ) + { + auto paramType = caleeFuncType->getParamType(pp); + auto arg = callInst->getOperand(firstCallArg + pp); + if( isConstExpr(paramType) ) + { + if( maybeMarkConstExpr(context, arg) ) + { + changedThisIteration = true; + } + } + } + } + } + } + } + + if( bb != code->getFirstBlock() ) + { + // A parameter in anything butr the first block is + // conceptually a phi node, which means its operands + // are the corresponding values from the terminating + // branch in a predecessor block. + + UInt paramCounter = 0; + for( auto pp = bb->getFirstParam(); pp; pp = pp->getNextParam() ) + { + UInt paramIndex = paramCounter++; + + if(!isConstExpr(pp)) + continue; + + for(auto pred : bb->getPredecessors()) + { + auto terminator = pred->getLastInst(); + if(terminator->op != kIROp_unconditionalBranch) + continue; + + UInt operandIndex = paramIndex + 1; + SLANG_RELEASE_ASSERT(operandIndex < terminator->getOperandCount()); + + auto operand = terminator->getOperand(operandIndex); + if( maybeMarkConstExpr(context, operand) ) + { + changedThisIteration = true; + } + } + } + } + + } + + if( !changedThisIteration ) + return anyChanges; + + anyChanges = true; + } +} + +// Validate use of `constexpr` within a function (in particular, +// diagnose places where a value that must be contexpr depends +// on a value that cannot be) +void validateConstExpr( + PropagateConstExprContext* context, + IRGlobalValueWithCode* code) +{ + for( auto bb = code->getFirstBlock(); bb; bb = bb->getNextBlock() ) + { + for( auto ii = bb->getFirstInst(); ii; ii = ii->getNextInst() ) + { + if(isConstExpr(ii)) + { + // For an instruction that must be `constexpr`, we need + // to ensure that its argumenst are all `constexpr` + + UInt argCount = ii->getOperandCount(); + for( UInt aa = 0; aa < argCount; ++aa ) + { + auto arg = ii->getOperand(aa); + + if( !isConstExpr(arg) ) + { + // Diagnose the failure. + + context->getSink()->diagnose(ii->sourceLoc, Diagnostics::needCompileTimeConstant); + + break; + } + } + } + } + } +} + +void propagateConstExpr( + IRModule* module, + DiagnosticSink* sink) +{ + auto session = module->session; + + PropagateConstExprContext context; + context.module = module; + context.sink = sink; + context.sharedBuilder.module = module; + context.sharedBuilder.session = session; + context.builder.sharedBuilder = &context.sharedBuilder; + + + // We need to propagate information both forward and backward. + // + // In the forward direction we need to check if all of the operands + // to an instruction are `constexpr` *and* if the operation is + // one that can conceptually be "promoted" to the constexpr rate. + // + // In the backward direction, if an instruction has already been + // marked as needing to be `constexpr`, then its operands had + // better be too. + // + // The backward direction needs to be interprocedural, because + // a parameter to a function might be `constexpr`, so that callers + // of that function would need to be marked too. If backwards + // propagation in any of the callers leads to some of their + // parameters being marked constexpr, then we would need to + // revisit their callers. + + // We will build an initial work list with all of the global values in it. + + for( auto ii : module->getGlobalInsts() ) + { + maybeAddToWorkList(&context, ii); + } + + // We will iterate applying propagation to one global value at a time + // until we run out. + while( context.workList.getCount() ) + { + auto gv = context.workList[0]; + context.workList.fastRemoveAt(0); + context.onWorkList.Remove(gv); + + switch( gv->op ) + { + default: + break; + + case kIROp_Func: + case kIROp_GlobalVar: + case kIROp_GlobalConstant: + { + IRGlobalValueWithCode* code = (IRGlobalValueWithCode*) gv; + + for( ;;) + { + bool anyChange = false; + if( propagateConstExprForward(&context, code) ) + { + anyChange = true; + } + if( propagateConstExprBackward(&context, code) ) + { + anyChange = true; + } + if(!anyChange) + break; + } + } + break; + } + } + + // Okay, we've processed all our functions and found a steady state. + // Now we need to try and issue diagnostics for any IR values where + // we find that they are *required* to be `constexpr`, but *cannot* + // be, for some reason. + + for(auto ii : module->getGlobalInsts()) + { + switch( ii->op ) + { + default: + break; + + case kIROp_Func: + case kIROp_GlobalVar: + case kIROp_GlobalConstant: + { + IRGlobalValueWithCode* code = (IRGlobalValueWithCode*) ii; + validateConstExpr(&context, code); + } + break; + } + } + +} + +} diff --git a/source/slang/slang-ir-constexpr.h b/source/slang/slang-ir-constexpr.h new file mode 100644 index 000000000..92678c8f6 --- /dev/null +++ b/source/slang/slang-ir-constexpr.h @@ -0,0 +1,12 @@ +// slang-ir-constexpr.h +#pragma once + +namespace Slang +{ + class DiagnosticSink; + struct IRModule; + + void propagateConstExpr( + IRModule* module, + DiagnosticSink* sink); +} diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp new file mode 100644 index 000000000..6dc315c76 --- /dev/null +++ b/source/slang/slang-ir-dce.cpp @@ -0,0 +1,325 @@ +// slang-ir-dce.cpp +#include "slang-ir-dce.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" + +namespace Slang +{ + +struct DeadCodeEliminationContext +{ + // This type implements a simple global DCE pass over + // an entire module. + // + // We start with member variables to stand in for + // the parameters that were passed to the top-level + // `eliminateDeadCode` function. + // + BackEndCompileRequest* compileRequest; + IRModule* module; + + // Our overall process is going to be to determine + // which instructions in the module are "live" + // and then eliminate anything that wasn't found to + // be live. + // + // We will track the liveness state by keeping + // a set of all instructions we have so far determined + // to be live. + // + HashSet liveInsts; + + // Querying whether an instruction has been + // determined to be live is easy. + // + bool isInstLive(IRInst* inst) + { + // The only wrinkle is that we want to safeguard + // against a null instruction (there are some + // corner cases where we still construct IR + // instructions with a null type). + // + if(!inst) return false; + + return liveInsts.Contains(inst); + } + + // We are going to do an iterative analysis + // where we mark instructions we know are + // live, and then see if that can help us + // identify any other instructions that + // must also be live. + // + // For this, we will use a work list of + // instructions that have been marked + // as live, but for which we haven't + // looked at their impact on other + // instructions. + // + List workList; + + // When we discover that an instruction seems + // to be live, we will add it to our set, + // and also the work list, but only if we + // haven't done so previously. + // + void markInstAsLive(IRInst* inst) + { + // Again, we safeguard against null instructions + // just in case. + // + if(!inst) return; + + if(liveInsts.Contains(inst)) + return; + liveInsts.Add(inst); + workList.add(inst); + } + + // Given the basic infrastructrure above, let's + // dive into the task of actually finding all + // the live code in a module. + // + void processModule() + { + // First of all, we know that the root module instruction + // should be considered as live, because otherwise + // we'd end up eliminating it, so that is a + // good place to start. + // + markInstAsLive(module->getModuleInst()); + + // Marking the module as live should have + // seeded our work list, so we can now start + // processing entries off of our work list + // until it goes dry. + // + while( workList.getCount() ) + { + auto inst = workList.getLast(); + workList.removeLast(); + + // At this point we know that `inst` is live, + // and we want to start considering which other + // instructions must be live because of that + // knowlege. + // + // A first easy case is that the parent (if any) + // of a live instruction had better be live, or + // else we might delete the parent, and + // the child with it. + // + markInstAsLive(inst->getParent()); + + // Next the type of a live instruction, and all + // of its operands must also be live, or else + // we won't be able to compute its value. + // + markInstAsLive(inst->getFullType()); + UInt operandCount = inst->getOperandCount(); + for( UInt ii = 0; ii < operandCount; ++ii ) + { + markInstAsLive(inst->getOperand(ii)); + } + + // Finally, we need to consider the children + // and decorations of the instruction. + // + // Note that just because an instruction is + // live doesn't mean its children must be, or + // else we'd never eliminate *anything* (we + // marked the whole module as live, and everything + // is a transitive child of the module). + // + // Decorations, in contrast, are always live if their + // parents are (because we don't want to silently drop + // decorations). It is still important to *mark* + // decorations as live, because they have operands, + // and those operands need to be marked as live. + // We will fold decorations into the same loop + // as children for simplicity. + // + // To keep the code here simple, we'll defer the + // decision of whether a child (or decoration) + // should be live when its parent is to a subroutine. + // + for( auto child : inst->getDecorationsAndChildren() ) + { + if(shouldInstBeLiveIfParentIsLive(child)) + { + // In this case, we know `inst` is live and + // its `child` should be live if its parent is, + // so the `child` must be live too. + // + markInstAsLive(child); + } + } + } + + // If our work list runs dry, that means we've reached a steady + // state where everything that is transitively relevant to + // the "outputs" of the module has been marked as live. + // + // Now we can simply walk through all of our instructions + // recursively and eliminate those that are "dead" by + // virtue of not having been found live. + // + eliminateDeadInstsRec(module->getModuleInst()); + } + + void eliminateDeadInstsRec(IRInst* inst) + { + // Given the instruction `inst` we need to eliminate + // any dead code at, or under it. + // + // The easy case is if `inst` is dead (that is, not live). + // + if( !isInstLive(inst) ) + { + // We can simply remove and deallocate `inst` because it is + // dead, and not worry about any of its descendents, + // because they must have been dead too (since we always + // mark the parent of a live instruction as live). + // + inst->removeAndDeallocate(); + } + else + { + // If `inst` is live, then we need to deal with the possibility + // that its children/decorations (or descendents in general) + // might still be dead. + // + // The biggest wrinkle is that we walk the linked list of + // children/decorations a bit carefully, using a temporary + // to hold the next node, in case we eliminate one of + // the children as we go. + // + IRInst* next = nullptr; + for( IRInst* child = inst->getFirstDecorationOrChild(); child; child = next ) + { + next = child->getNextInst(); + eliminateDeadInstsRec(child); + } + } + } + + // Now we come to the decision procedure we put off before: + // should a given `inst` be live if its parent is? + // + bool shouldInstBeLiveIfParentIsLive(IRInst* inst) + { + // The main source of confusion/complexity here is that + // we are using the same routine to decide: + // + // * Should some ordinary instruction in a basic block be kept around? + // * Should a basic block in some function be kept around? + // * Should a function/type/variable in a module be kept around? + // + // Still, there are a few basic patterns we can observe. + // First, if `inst` is an instruction that might have some effects + // when it is executed, then we should keep it around. + // + if(inst->mightHaveSideEffects()) + return true; + // + // The `mightHaveSideEffects` query is conservative, and will + // return `true` as its default mode, so once we are past that + // query we know that `inst` is either something "structural" + // (that makes up the program) rather than executable, or it + // is executable but was on a white list of things that are + // safe to eliminate. + + // Most top-level objects (functions, types, etc.) obviously + // do *not* have side effects. That creates the risk that + // we'll just go ahead and eliminate every single function/type + // in a module. There needs to be a way to identify the + // functions we want to keep around, and for right now + // that is handled with the `[keepAlive]` decoration. + // + if(inst->findDecorationImpl(kIROp_KeepAliveDecoration)) + return true; + // + // TODO: Eventually it would make sense to consider everything + // with an `[export(...)]` decoration as live, but our current + // approach to linking for back-end compilation leaves many + // linkage decorations in place that we seemingly don't need/want. + + // A basic block is an interesting case. Knowing that a function + // is live means that its entry block is live, but the liveness + // of any other blocks is determined by whether they are referenced + // by other instructions (e.g., a branch from one block to + // another). + // + if( auto block = as(inst) ) + { + // To determine whether this is the first block in its + // parent function (or what-have-you) we can simply + // check if there is a previous block before it. + // + auto prevBlock = block->getPrevBlock(); + return prevBlock == nullptr; + } + + // There are a few special cases of "structural" instructions + // that we don't want to eliminate, so we'll check for those next. + // + switch( inst->op ) + { + // Function parameters obviously shouldn't get eliminated, + // even if nothing references them, and block parameters + // (phi nodes) will be considered live when their block is, + // just so that we don't have to deal with any complications + // around re-writing the relevant inter-block argument passing. + // + // TODO: A smarter DCE pass could deal with this case more + // carefully, or we could improve the interprocedural SCCP + // pass to deal with block parameters instead. + // + case kIROp_Param: + return true; + + // IR struct types and witness tables are currently kludged + // so that they have child instructions that represent their + // entries (effectively `(key,value)` pairs), and those child + // instructions are never directly referenced (e.g., an access + // to a struct field references the *key* but not the `(key,value)` + // pair that is the `IRField` instruction. + // + // TODO: at some point the IR should use a different representation + // for struct types and witness tables that does away with + // this problem. + // + case kIROp_StructField: + case kIROp_WitnessTableEntry: + return true; + + default: + break; + } + + // If none of the explicit cases above matched, then we will consider + // the instruction to not be live just because its parent is. Further + // analysis could still lead to a change in the status of `inst`, if + // an instruction that uses it as an operand is marked live. + // + return false; + } +}; + +// The top-level function for invoking the DCE pass +// is straighforward. We set up the context object +// and then defer to it for the real work. +// +void eliminateDeadCode( + BackEndCompileRequest* compileRequest, + IRModule* module) +{ + DeadCodeEliminationContext context; + context.compileRequest = compileRequest; + context.module = module; + + context.processModule(); +} + +} diff --git a/source/slang/slang-ir-dce.h b/source/slang/slang-ir-dce.h new file mode 100644 index 000000000..b568d9883 --- /dev/null +++ b/source/slang/slang-ir-dce.h @@ -0,0 +1,19 @@ +// slang-ir-dce.h +#pragma once + +namespace Slang +{ + class BackEndCompileRequest; + struct IRModule; + + /// Eliminate "dead" code from the given IR module. + /// + /// This pass is primarily designed for flow-insensitive + /// "global" dead code elimination (DCE), such as removing + /// types that are unused, functions that are never called, + /// etc. + /// + void eliminateDeadCode( + BackEndCompileRequest* compileRequest, + IRModule* module); +} diff --git a/source/slang/slang-ir-dominators.cpp b/source/slang/slang-ir-dominators.cpp new file mode 100644 index 000000000..7960bcaf1 --- /dev/null +++ b/source/slang/slang-ir-dominators.cpp @@ -0,0 +1,720 @@ +// slang-ir-dominators.cpp +#include "slang-ir-dominators.h" + +// +// This file implements the public interface of the `IRDominatorTree` type, +// to enable queries on dominance relationships in a control-flow graph. +// +// It also implements computation of the dominator tree for a CFG using +// the algorithm presented in "A Simple, Fast Dominance Algorithm" by +// Keith D. Cooper, Timothy J. Harvey, and Ken Kennedy. +// +// The algorithm is *not* the most efficinet one, asymptotically, but +// it is one that is easy to implement and explain, and so we favor it +// in order to get something up and running with a reasonable level of +// confidence that the results are correct. +// + +#include "slang-ir.h" + +namespace Slang { + +// +// Let's start with the implementation of the public API for `IRDominatorTree` +// + +// IRDominatorTree + +bool IRDominatorTree::immediatelyDominates(IRBlock* dominator, IRBlock* dominated) +{ + // To test if block A immediately dominates block B, we just + // check if A is the (one and only) immediate dominator of B. + return dominator == getImmediateDominator(dominated); +} + +bool IRDominatorTree::properlyDominates(IRBlock* dominator, IRBlock* dominated) +{ + // Because of how we laid out the tree, we can test if one node + // properly dominates another in constant time. + // + // We simply need to test if the node index for `dominated` falls + // in the range of indices for the descendents of `dominator`. + // + + Int dominatorIndex = getBlockIndex(dominator); + Int dominatedIndex = getBlockIndex(dominated); + Node& dominatorNode = nodes[dominatorIndex]; + + return (dominatedIndex >= dominatorNode.beginDescendents) + && (dominatedIndex < dominatorNode.endDescendents); +} + +bool IRDominatorTree::dominates(IRBlock* dominator, IRBlock* dominated) +{ + // We need to check two cases here. + // + // First, a node always dominated itself, so if the blocks are + // the the same, then we are done: + // + if(dominator == dominated) + return true; + // + // Otherwise, for distinct blocks we just check for + // proper dominance: + // + return properlyDominates(dominator, dominated); +} + +IRBlock* IRDominatorTree::getImmediateDominator(IRBlock* block) +{ + // The immediate dominator of a block is its parent + // in the dominator tree. Looking this up is straightforward, + // and we just need to be a bit careful to deal with + // invalid node indices. + + Int blockIndex = getBlockIndex(block); + if(blockIndex == kInvalidIndex) return nullptr; + + Int parentIndex = nodes[blockIndex].parent; + if(parentIndex == kInvalidIndex) return nullptr; + + return nodes[parentIndex].block; +} + +IRDominatorTree::DominatedList IRDominatorTree::getImmediatelyDominatedBlocks(IRBlock* block) +{ + // Because of our representation, the immediately dominated blocks + // for a node are contiguous, and we store their range in the + // node already. + + Int blockIndex = getBlockIndex(block); + if(blockIndex == kInvalidIndex) return DominatedList(); + + Node& node = nodes[blockIndex]; + return DominatedList( + this, + node.beginDescendents, + node.endChildren); +} + +IRDominatorTree::DominatedList IRDominatorTree::getProperlyDominatedBlocks(IRBlock* block) +{ + // Because of our representation, the properly dominated blocks + // for a node are contiguous, and we store their range in the + // node already. + + Int blockIndex = getBlockIndex(block); + if(blockIndex == kInvalidIndex) return DominatedList(); + + Node& node = nodes[blockIndex]; + return DominatedList( + this, + node.beginDescendents, + node.endDescendents); +} + +Int IRDominatorTree::getBlockIndex(IRBlock* block) +{ + Int index = kInvalidIndex; + if(!mapBlockToIndex.TryGetValue(block, index)) + { + SLANG_UNEXPECTED("block was not present in dominator tree"); + } + return index; +} + +// IRDominatorTree::DominatedList + +IRDominatorTree::DominatedList::DominatedList() + : mTree(nullptr) + , mBegin(0) + , mEnd(0) +{} + +IRDominatorTree::DominatedList::DominatedList( + IRDominatorTree* tree, + Int begin, + Int end) + : mTree(tree) + , mBegin(begin) + , mEnd(end) +{} + +IRDominatorTree::DominatedList::Iterator IRDominatorTree::DominatedList::begin() const +{ + return Iterator(mTree, mBegin); +} + +IRDominatorTree::DominatedList::Iterator IRDominatorTree::DominatedList::end() const +{ + return Iterator(mTree, mEnd); +} + + +// IRDominatorTree::DominatedList::Iterator + +IRDominatorTree::DominatedList::Iterator::Iterator() + : mTree(nullptr) + , mIndex(0) +{} + +IRDominatorTree::DominatedList::Iterator::Iterator( + IRDominatorTree* tree, + Int index) + : mTree(tree) + , mIndex(index) +{} + +IRBlock* IRDominatorTree::DominatedList::Iterator::operator*() const +{ + return mTree->nodes[mIndex].block; +} + +void IRDominatorTree::DominatedList::Iterator::operator++() +{ + mIndex++; +} + +bool IRDominatorTree::DominatedList::Iterator::operator==(Iterator const& that) const +{ + SLANG_ASSERT(mTree == that.mTree); + return mIndex == that.mIndex; +} + +// +// The dominance computation algorithm we are using relies on being able to compute +// a reverse postorder traversal of the nodes in the CFG, which is done using a depth-first +// search (DFS). We don't currently have infrastructure for DFS in the compiler, so +// we will implement it here for now, and plan to move it into its own file once +// we have a second use case. +// + +/// A base "visitor" class for use in depth-first search algorithms on an IR CFG. +struct DepthFirstSearchContext +{ + /// The blocks in the CFG that we've already visited. + HashSet visited; + + /// Walk a (previously unvisited) block. + /// + /// This will perform any pre-order actions on the block, + /// then recursively visit its (unvisited) successors, and + /// then perform any post-actions. + /// + void walk(IRBlock* block) + { + visited.Add(block); + preVisit(block); + for(auto succ : block->getSuccessors()) + { + if(!visited.Contains(succ)) + { + walk(succ); + } + } + postVisit(block); + } + + /// Walk the blocks in a function (or other code-bearing value). + void walk(IRGlobalValueWithCode* code) + { + auto root = code->getFirstBlock(); + if(!root) + return; + walk(root); + } + + /// Overridable action to perform on first entering a CFG node. + virtual void preVisit(IRBlock* /*block*/) {} + + /// Overridable action to perform on exiting a CFG node + virtual void postVisit(IRBlock* /*block*/) {} +}; + +// +// With DFS traversal factored out, computing a post-order walk +// of the CFG is a simple matter of defining a visitor that appends +// to an order as a post-action: +// + +/// A visitor that computes a postorder traversal for a CFG. +struct PostorderComputationContext : public DepthFirstSearchContext +{ + /// List to append the computed order onto + List* order; + + virtual void postVisit(IRBlock* block) SLANG_OVERRIDE + { + order->add(block); + } +}; + +/// Compute a postorder traversal of the blocks in `code`, writing the resulting order to `outOrder`. +void computePostorder(IRGlobalValueWithCode* code, List& outOrder) +{ + PostorderComputationContext context; + context.order = &outOrder; + context.walk(code); +} + +// +// With the preliminaries out of the way, we are ready to implement +// the dominator tree construction algorithm as described by Cooper, Harvey, and Kennedy. +// The actual code for the algorithm is given in Figure 3 of the paper. +// +// We will wrap the subroutines of their algorithm in a `struct` type +// to allow the temporary structures to be shared. +// +struct DominatorTreeComputationContext +{ + // We will use signed integers to represent the "name" of a block. + // The integers will reflect the a postorder traversal, and this + // property will be exploited in the `intersect()` function. + // + typedef Int BlockName; + // + // An invalid/undefined block name will be represented as -1. + // + static const BlockName kUndefined = BlockName(-1); + // + // We will explicitly store the blocks visited in the postorder + // traversal, so that we can look up a block based on its "name" + // + List postorder; + + // + // We need a way to map our actual IR blocks to their names for + // the purpose of this algorithm. This mapping step adds overhead, + // but it seems unavoidable unless we also translate the CFG itself + // to an index-based representation. + // + Dictionary mapBlockToName; + BlockName getBlockName(IRBlock* block) + { + return mapBlockToName[block]; + } + + // + // The algorithm iteratively builds up an array `doms` that upon + // completion will directly encode the immediate dominator for each + // node. During the iterative steps it is used to implicitly encode + // a representation of the set of dominators for each node. + // + List doms; + + + // + // Here we get to the meat of the algorithm presented in Cooper et al. + // Figure 3: + // + void iterativelyComputeImmediateDominators(IRGlobalValueWithCode* code) + { + // First we compute the postorder traversal order for the blocks in the CFG. + computePostorder(code, postorder); + + // We will initialize our map from the block objects to their "name" + // (index in the traversal order), before moving on. + BlockName blockCount = BlockName(postorder.getCount()); + for(BlockName bb = 0; bb < blockCount; ++bb) + { + mapBlockToName[postorder[bb]] = bb; + } + + // Next we initialize the `doms` array that we will iteratively turn + // into an encoding of the dominator tree. + doms.setCount(blockCount); + for(BlockName bb = 0; bb < blockCount; ++bb) + { + doms[bb] = kUndefined; + } + + // The start node is special, since it is the root of the dominator tree. + // Technically it doesn't have an immediate dominator, but we will set + // its entry in `doms` to refer to itself, to indicate that we are done + // processing the given node. + // + BlockName startNode = getBlockName(code->getFirstBlock()); + doms[startNode] = startNode; + + // Given that we computed a postorder traversal of the graph, we know + // that the start node should be the last one in the computed order. + // + SLANG_ASSERT(startNode == blockCount - 1); + + // We are using an iterative algorithm, so we will detect that we + // have reached a fixed point when we hit an iteration where nothing + // changes. + // + bool changed = true; + while(changed) + { + changed = false; + + // The algorithm specifies that we should walk through the blocks + // in *reverse* postorder, since this speeds up convergence. + // Because we've numbered the blocks in postorder, walking them + // in reverse numerical order will do the trick. + // + // We don't want to include the start node in our iteration + // (since we already know its dominators), and because we know + // that the start node is always the last in the order (`blockCount - 1`) + // we can just start at the next node after it (`blockCount - 2`). + // + // Note: it is important that we are using signed integers for + // block numbers here, since we will drop below zero before exiting + // the loop, and if the CFG had only a single block, then our *starting* + // block index would be `-1`. + // + for(auto b = blockCount - 2; b >= 0; --b) + { + // We are walking through block indices, but the predecessor + // lists are encoded in the IR blocks themselves. + // + IRBlock* block = postorder[b]; + + // The algorithm description in the paper says to pick the + // initial value for the `new_idom` variable from the "first + // (processed) predecessor of b (pick one)". + // After that step, the algorithm walks over the remaining + // predecessors, and for the ones that have a valid entry + // in the `doms` array, performs an intersection of their + // implicitly-represented dominator sets. + // + // The paper doesn't precisely clarify what they mean by + // a "processed" predecessor, but it seems to mean one that + // has a valid value in the `doms` array, which is what + // the subsequent loop is already checking. + // + // We are going to fold this logic together into a single loop. + // We will start with an invalid/undefined value for + // `new_idom`, which represents our best guess at the + // immediate dominator for block `b`: + // + BlockName new_idom = kUndefined; + + // Now we will loop over *all* of the predecessors, ... + for(auto pred : block->getPredecessors()) + { + // ... and skip those that haven't been "processed". + BlockName p = getBlockName(pred); + BlockName dominatorOfPredecessor = doms[p]; + if(dominatorOfPredecessor == kUndefined) + continue; + + // When we encounter the first "processed" predecessor, + // we can initialize the variable tracking our best + // guess at the immediate dominator. + // + if(new_idom == kUndefined) + { + new_idom = p; + } + // + // Otherwise, we need to merge information between + // the predecessor `p` and our best-guess immediate + // dominator `new_idom`. We need a node that dominates + // both of them to be the immediate dominator of `b`. + // + else + { + new_idom = intersect(p, new_idom); + } + } + + // After we've computed a new best guess at the immediate + // dominator for `b`, we need to see if the computed + // value differs from what we'd previously stored in the + // `doms` array. If anything changed, then we haven't + // converged yet, and we need to keep going. + // + BlockName oldDominator = doms[b]; + if(oldDominator != new_idom) + { + doms[b] = new_idom; + changed = true; + } + } + } + + // Upon exiting the loop, things should have converged with + // the `doms` array being an explicit encoding of the immediate + // dominator for each node, with one small error: there is no + // immediate dominator for the start node: + doms[startNode] = kUndefined; + } + + // + // The algorithm above relied on a utility routine `intersect()` that + // is implicitly used to compute intersections between sets of nodes, + // but explicitly takes the form of a routine that computes a common + // parent in the dominator tree for two nodes. + // + // We present that subroutine here, almost identical to how it + // is presented in Cooper et al. Figure 3: + // + BlockName intersect(BlockName b1, BlockName b2) + { + // We need to find a common ancestor of both `b1` and `b2`, + // and will do this by tracking two "fingers," each initially + // pointing at one node, and then iteratively move the finger + // that is furthest to the "left" (earlier in the postorder + // traversal to the left until) to the "right" (by moving + // the immediate dominator of the node we are pointing at), + // until the two fingers are pointing at the same place. + // + // Termination is guaranteed because we are always moving the + // fingers from a node to its immediate dominator, and the + // entry node is guaranteed to be at the root of the dominator + // tree. + // + // The use of the postorder here relies on the (subtle) fact + // that the immediate dominator of a node must come later + // in a postorder traversal. + // + BlockName finger1 = b1; + BlockName finger2 = b2; + + while(finger1 != finger2) + { + while(finger1 < finger2) + finger1 = doms[finger1]; + while(finger2 < finger1) + finger2 = doms[finger2]; + } + return finger1; + } + + // + // Now that we've implemented Cooper et al. fairly close to how + // it was presented, we can build an array encoding the immediate + // dominator relationship. We still need to expand that array + // into an encoding that lets us efficiently answer queries + // about dominance. + // + // In order to do that, we need to expand the information we + // have built on each block (currently just an immediate dominator) + // into a bit more detail: + // + struct BlockInfo + { + // How many children does this node/block have in the dominator tree? + Int childCount = 0; + + // How many indirect (non-child) descendents? + Int indirectDescendentCount = 0; + + // What is the 0-based offset of this node among all the children of its parent? + Int childOffsetInParent = 0; + + // What is the 0-based offset for this node's descendent list, + // among all the children in its parent? + Int descendentOffsetInParent = 0; + + Int nodeIndex = 0; + Int firstDescendentIndex = 0; + }; + // + + RefPtr createDominatorTree(IRGlobalValueWithCode* code) + { + // We first run the Cooper et al. algorithm to compute the `doms` array + // which encodes immediate dominators. + // + iterativelyComputeImmediateDominators(code); + + // We will build some intermediate information on each + // block to help us fill out the tree. + BlockName blockCount = BlockName(doms.getCount()); + List blockInfos; + for(BlockName bb = 0; bb < blockCount; ++bb) + { + blockInfos.add(BlockInfo()); + } + + // We will propagate layout information in two passes over the tree. + // + // First we will perform a "bottom up" pass that will accumulate + // the number of children and the total number of descendents for + // each node, and also assign each child its relative offsets within + // the storage for its parent. + // + // Because our blocks are ordered in postorder, we can do this + // bottom-up walk just by iterating over them in the given order. + // + for(BlockName bb = 0; bb < blockCount; ++bb) + { + BlockName parent = doms[bb]; + if(parent == kUndefined) + continue; + + // For our iteration order to make sense, we need to be certain + // that parent nodes come after their child nodes in the postorder traversal. + SLANG_ASSERT(parent > bb); + + // Compute the 0-based index of this child among all the children + // with the same parent, and increment its child count. + blockInfos[bb].childOffsetInParent = blockInfos[parent].childCount; + blockInfos[parent].childCount++; + + // Our layout for the descendents of a node will put all the immediate + // child nodes contiguously first, followed by their descendents (in contiguous blocks). + // + // We need to compute an offset for where the descendents of this node will + // be stored, within the overall space carved out for the "indirect" descendents + // of the parent node. + // + blockInfos[bb].descendentOffsetInParent = blockInfos[parent].indirectDescendentCount; + // + // When adding up the indirect descendents of `parent`, we need to include both + // the direct and indirect descendents of our node `bb`. + blockInfos[parent].indirectDescendentCount += blockInfos[bb].childCount + + blockInfos[bb].indirectDescendentCount; + } + // + // The next pass is a top-down pass that uses the accumulated + // information to assign absolute indices to each node. + // + // For each node, we want to compute its absolute index in + // the overall array of nodes, and then we also want to compute + // the index where its first descendent node will be placed + // (which can then be used by child nodes to compute their + // index). + // + // The start node in the CFG is special, and will always get + // index zero, with its first desecendent at index 1. + // + BlockName startBlock = getBlockName(code->getFirstBlock()); + blockInfos[startBlock].nodeIndex = 0; + blockInfos[startBlock].firstDescendentIndex = 1; + // + // For the remaining nodes, we'll compute them in a top-down + // pass (using reverse postorder). + // + for(BlockName bb = blockCount-1; bb >= 0; --bb) + { + // We will skip nodes without a parent in the dominator tree. + // This should really only be the start node, but it might + // happen that we have some unreachable nodes that shouldn't + // appear in the dominator tree at all. + // + // TODO: make sure we either handle those correctly, or + // else add a pass to eliminate unreachable blocks first. + // + BlockName parent = doms[bb]; + if(parent == kUndefined) + continue; + + // The absolute index of a node is the absolute index for its + // parent's descendent list, plus the relative offset of this + // child node in its parent. + // + blockInfos[bb].nodeIndex = blockInfos[parent].firstDescendentIndex + + blockInfos[bb].childOffsetInParent; + + // The other descendents of a node are always laid out in the space + // after its immediate children. Thus, the index for where this node + // will place its descendents (direct + indirect) must come after + // the storage for the children of the parent. + // + blockInfos[bb].firstDescendentIndex = blockInfos[parent].firstDescendentIndex + + blockInfos[parent].childCount + + blockInfos[bb].descendentOffsetInParent; + } + + // We now have all the information we need, and can start to fill in + // the actual `IRDominatorTree` structure with the encoded information. + // + RefPtr dominatorTree = new IRDominatorTree(); + dominatorTree->code = code; + dominatorTree->nodes.setCount(blockCount); + + // We will iterate over all of the blocks, and fill in the corresponding + // dominator tree node for each. + // + // Note that the number of the blocks (in postorder) and the numbering + // of the nodes (in breadth-first order) will not match, so we have + // to be careful around whehter we are working with a block index/name, + // or a node index. + // + for(BlockName bb = 0; bb < blockCount; ++bb) + { + // Find the IR block, look up our pre-computed information, + // and find the corresponding node in the dominator tree. + // + IRBlock* block = postorder[bb]; + BlockInfo const& blockInfo = blockInfos[bb]; + Int nodeIndex = blockInfo.nodeIndex; + IRDominatorTree::Node& node = dominatorTree->nodes[nodeIndex]; + + // We will now start filling in the node. Filling in the block is + // trial, and while we are at it we can add an entry to the mapping + // from the block to the node index. + // + node.block = block; + dominatorTree->mapBlockToIndex.Add(block, nodeIndex); + + // Filling in the parent is easy enough, just with the detail that + // we need to handle the invalid case explicitly (for a node with + // no parent), and need to carefully map the block index `parent` + // over to its corresponding node index. + // + BlockName parent = doms[bb]; + node.parent = parent == kUndefined ? IRDominatorTree::kInvalidIndex : blockInfos[parent].nodeIndex; + + // Finally we need to compute the range information to use for the + // descendents (both immediate children and indirect descendents). + // + // All of the relevant information was computed in our two passes + // above, so all that has to happen here is adding together the + // absolute start index for the descendent range with the counts + // we accumulated. + // + Int beginDescendents = blockInfo.firstDescendentIndex; + Int endChildren = beginDescendents + blockInfo.childCount; + // + // The indirect descendents of a node will always come after + // its direct descenents. + // + Int endDescendents = endChildren + blockInfo.indirectDescendentCount; + node.beginDescendents = beginDescendents; + node.endChildren = endChildren; + node.endDescendents = endDescendents; + } + +#if 0 + // Let's do some ad hoc validation here, just to be sure we built the + // data structure reasonably. + for(BlockName ii = 0; ii < blockCount; ++ii) + { + for(BlockName jj = 0; jj < blockCount; ++jj) + { + IRBlock* i = postorder[ii]; + IRBlock* j = postorder[jj]; + + SLANG_RELEASE_ASSERT(dominatorTree->immediatelyDominates(i, j) == (ii == doms[jj])); + + Int dd = jj; + while(dd != kUndefined) + { + if(dd == ii) + break; + dd = doms[dd]; + } + SLANG_RELEASE_ASSERT(dominatorTree->dominates(i, j) == (dd != kUndefined)); + + } + } +#endif + + return dominatorTree; + } +}; + + +RefPtr computeDominatorTree(IRGlobalValueWithCode* code) +{ + DominatorTreeComputationContext context; + return context.createDominatorTree(code); +} + +} diff --git a/source/slang/slang-ir-dominators.h b/source/slang/slang-ir-dominators.h new file mode 100644 index 000000000..7ec31b821 --- /dev/null +++ b/source/slang/slang-ir-dominators.h @@ -0,0 +1,162 @@ +// slang-ir-dominators.h +#pragma once + +#include "../core/slang-basic.h" + +namespace Slang +{ + struct IRBlock; + struct IRGlobalValueWithCode; + + /// The computed dominator tree for an IR control flow graph. + struct IRDominatorTree : public RefObject + { + /// The function or other code-bearing value for which the dominator tree was computed. + IRGlobalValueWithCode* code; + + /// Does the first block dominate the second? + /// + /// A block A dominates block B iff every control-flow path + /// that starts at the entry block of the CFG and passes + /// through B must first pass through A. + /// + bool dominates(IRBlock* dominator, IRBlock* dominated); + + /// Does the first block properly dominate the second? + /// + /// Block A properly dominates block B iff A dominates B + /// and A != B. + /// + bool properlyDominates(IRBlock* dominator, IRBlock* dominated); + + /// Does the first block immediately dominate the second? + /// + /// Block A immediately dominates block B iff A dominates B + /// and for any block X that dominates B, X also dominates A. + /// + bool immediatelyDominates(IRBlock* dominator, IRBlock* dominated); + + /// Get the immediate dominator (idom) of a block. + /// + /// This is the parent of `block` in the dominator tree. + IRBlock* getImmediateDominator(IRBlock* block); + + /// An iterable collection of the blocks dominated by a specific block + struct DominatedList; + + /// Get the blocks that a block immediately dominates. + /// + /// These are the children of the block in the dominator tree. + DominatedList getImmediatelyDominatedBlocks(IRBlock* block); + + /// Get the blocks that a block properly dominates. + /// + /// These are the descendents of the block in the dominator tree. + DominatedList getProperlyDominatedBlocks(IRBlock* block); + + struct DominatedList + { + public: + DominatedList(); + + struct Iterator + { + public: + Iterator(); + + IRBlock* operator*() const; + void operator++(); + bool operator==(Iterator const& that) const; + + private: + friend struct DominatedList; + Iterator( + IRDominatorTree* tree, + Int index); + + IRDominatorTree* mTree; + Int mIndex; + }; + + Iterator begin() const; + Iterator end() const; + + private: + friend struct IRDominatorTree; + DominatedList( + IRDominatorTree* tree, + Int begin, + Int end); + + IRDominatorTree* mTree; + Int mBegin; + Int mEnd; + }; + + private: + // + // The layout of an `IRDominatorTree` uses a dense array for all of the nodes in the CFG. + // We therefore need a way to map an `IRBlock*` pointer over to an index in this array: + // + + /// Map a block to its index in the `nodes` array + Int getBlockIndex(IRBlock* block); + + /// Dictionary used to accelerate `getBlockIndex` + Dictionary mapBlockToIndex; + + // + // In order to accelerate queries on the tree structure, we will order the tree nodes + // carefully, so that all of the descendants of a node are contiguous, with all of + // the immediate children coming first. + // + // Each node thus needs to remember its parent (immediate dominator), and the range + // of indices that represent children and descendents (respectively), with the knowledge + // that the first child and first descendent share the same index. + // + + /// Information about one node in the dominator tree + struct Node + { + /// The block associated with this tree node + IRBlock* block; + + /// Index of the parent node or -1 if no parent + Int parent; + + /// Index of first descendent + Int beginDescendents; + + /// "One after the end" value for range of child node indices. + Int endChildren; + + /// "One after the end" value for range of descendent node indices. + Int endDescendents; + }; + + /// Storage for the dominator tree itself + List nodes; + + /// Value to use for invalid node indices (e.g., + /// when a node has no parent). + static const Int kInvalidIndex = -1; + + // + // The `DominatedList` type needs direct access to all of this + // data in order to provide iteration. + // + friend struct DominatedList; + friend struct DominatedList::Iterator; + // + // The context type we will use to compute the dominator tree + // also needs to be able to access all the fields to initialze + // an `IRDominatorTree` + // + friend struct DominatorTreeComputationContext; + + // TODO: we should probably build/store a postdominator + // tree in the same structure, just to make life simpler. + }; + + RefPtr computeDominatorTree(IRGlobalValueWithCode* code); +} diff --git a/source/slang/slang-ir-entry-point-uniforms.cpp b/source/slang/slang-ir-entry-point-uniforms.cpp new file mode 100644 index 000000000..20e726f25 --- /dev/null +++ b/source/slang/slang-ir-entry-point-uniforms.cpp @@ -0,0 +1,425 @@ +// slang-ir-entry-point-uniforms.cpp +#include "slang-ir-entry-point-uniforms.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" + +#include "slang-mangle.h" + +namespace Slang +{ + + +// The transformation in this file will solve the problem of taking +// code like the following: +// +// float4 fragmentMain( +// uniform Texture2D t, +// uniform SamplerState s; +// uniform float4 c, +// float2 uv : UV) : SV_Target +// { +// return t.Sample(s, uv) + c; +// } +// +// and transforming into code like this: +// +// struct Params +// { +// Texture2D t; +// SamplerState s; +// float4 c; +// } +// ConstantBuffer params; +// +// float4 fragmentMain( +// float2 uv : UV) : SV_Target +// { +// return params.t.Sample(params.s, uv) + params.c; +// } +// +// As can be seen in this example, the `uniform` parameters +// declared as entry point parameters have been moved into +// a `struct` declaration that we then use to declare a global +// shader parameter that is a `ConstantBuffer`. We then +// rewrite references to those parameters to refer to the +// contents of the new constant buffer instead. +// +// We perform this transformation after the target-specific +// linking step, because that will have attached layout information +// to the entry point and its parameters. We need that layout +// information so that we can: +// +// * Identify which parameters are uniform vs. varying. +// * Have an appropriate layout to attached to the synthesized +// global shader parameter `params`. +// +// One additional wrinkle this pass has to deal with is that +// in the case where the shader doesn't have any "ordinary" +// uniform parameters like `c` (e.g., it only has resource/object +// parameters), we do *not* wrap the parameter `struct` in +// a `ConstantBuffer`. For example, suppose we have: +// +// float4 fragmentMain( +// uniform Texture2D t, +// uniform SamplerState s; +// float2 uv : UV) : SV_Target +// { +// return t.Sample(s, uv); +// } +// +// In this case the output of the transformation should be: +// +// struct Params +// { +// Texture2D t; +// SamplerState s; +// } +// Params params; +// +// float4 fragmentMain( +// float2 uv : UV) : SV_Target +// { +// return params.t.Sample(params.s, uv) + params.c; +// } +// +// Note that this pass should always come before type legalization, +// which will take responsibility for turning a variable like +// `params` above into individual variables for the `t` and +// `s` fields. + +// The overall structure here is similar to many other IR passes. +// We define a "context" structure to encapsulate the pass. +// +struct MoveEntryPointUniformParametersToGlobalScope +{ + // We'll hang on to the module we are processing, + // so that we can refer to it when setting up `IRBuilder`s. + // + IRModule* module; + + // We will process a whole module by visiting all + // its global functions, looking for entry points. + // + void processModule() + { + // Note that we are only looking at true global-scope + // functions and not functions nested inside of + // IR generics. When using generic entry points, this + // pass should be run after the entry point(s) have + // been specialized to their generic type parameters. + + for( auto inst : module->getGlobalInsts() ) + { + // We are only interested in entry points. + // + // Every entry point must be a function. + // + auto func = as(inst); + if( !func ) + continue; + + // Entry points will always have the `[entryPoint]` + // decoration to differentiate them from ordinary + // functions. + // + // TODO: we could make `IREntryPoint` a subclass of + // `IRFunc` if desired, to avoid having to attach + // an explicit decoration to identify them. + // + if( !func->findDecorationImpl(kIROp_EntryPointDecoration) ) + continue; + + // If we fine a candidate entry point, then we + // will process it. + // + processEntryPoint(func); + } + } + + void processEntryPoint(IRFunc* func) + { + // We expect all entry points to have explicit layout information attached. + // + // We will assert that we have the information we need, but try to be + // defensive and bail out in the failure case in release builds. + // + auto funcLayoutDecoration = func->findDecoration(); + SLANG_ASSERT(funcLayoutDecoration); + if(!funcLayoutDecoration) + return; + + auto entryPointLayout = as(funcLayoutDecoration->getLayout()); + SLANG_ASSERT(entryPointLayout); + if(!entryPointLayout) + return; + + // The parameter layout for an entry point will either be a structure + // type layout, or a constant buffer (a case of parameter group) + // wrapped around such a structure. + // + // If we are in the latter case we will need to make sure to allocate + // an explicit IR constant buffer for that wrapper, + // + auto entryPointParamsLayout = entryPointLayout->parametersLayout; + bool needConstantBuffer = entryPointParamsLayout->typeLayout.is(); + + // We will set up an IR builder so that we are ready to generate code. + // + SharedIRBuilder sharedBuilderStorage; + auto sharedBuilder = &sharedBuilderStorage; + sharedBuilder->module = module; + sharedBuilder->session = module->getSession(); + + IRBuilder builderStorage; + auto builder = &builderStorage; + builder->sharedBuilder = sharedBuilder; + + // *If* the entry point has any uniform parameter then we want to create a + // structure type to house them, and a global shader parameter (either + // an instance of that type or a constant buffer). + // + // We only want to create these if actually needed, so we will declare + // them here and then initialize them on-demand. + // + IRStructType* paramStructType = nullptr; + IRGlobalParam* globalParam = nullptr; + + // We will be removing any uniform parameters we run into, so we + // need to iterate the parameter list carefully to deal with + // us modifying it along the way. + // + IRParam* nextParam = nullptr; + for( IRParam* param = func->getFirstParam(); param; param = nextParam ) + { + nextParam = param->getNextParam(); + + // We expect all entry-point parameters to have layout information, + // but we will be defensive and skip parameters without the required + // information when we are in a release build. + // + auto layoutDecoration = param->findDecoration(); + SLANG_ASSERT(layoutDecoration); + if(!layoutDecoration) + continue; + auto paramLayout = as(layoutDecoration->getLayout()); + SLANG_ASSERT(paramLayout); + if(!paramLayout) + continue; + + // A parameter that has varying input/output behavior should be left alone, + // since this pass is only supposed to apply to uniform (non-varying) + // parameters. + // + if(isVaryingParameter(paramLayout)) + continue; + + // At this point we know that `param` is not a varying shader parameter, + // so that we want to turn it into an equivalent global shader parameter. + // + // If this is the first parameter we are running into, then we need + // to deal with creating the structure type and global shader + // parameter that our transformed entry point will use. + // + if( !paramStructType ) + { + // First we create the structure to hold the parameters. + // + builder->setInsertBefore(func); + paramStructType = builder->createStructType(); + + if( needConstantBuffer ) + { + // If we need a constant buffer, then the global + // shader parameter will be a `ConstantBuffer` + // + auto constantBufferType = builder->getConstantBufferType(paramStructType); + globalParam = builder->createGlobalParam(constantBufferType); + } + else + { + // Otherwise, the global shader parameter is just + // an instance of `paramStructType`. + // + globalParam = builder->createGlobalParam(paramStructType); + } + + // No matter what, the global shader parameter should have the layout + // information from the entry point attached to it, so that the + // contained parameters will end up in the right place(s). + // + builder->addLayoutDecoration(globalParam, entryPointParamsLayout); + } + + // Now that we've ensured the global `struct` type and shader paramter + // exist, we need to add a field to the `struct` to represent the + // current parameter. + // + + auto paramType = param->getFullType(); + + builder->setInsertBefore(paramStructType); + auto paramFieldKey = builder->createStructKey(); + auto paramField = builder->createStructField(paramStructType, paramFieldKey, paramType); + SLANG_UNUSED(paramField); + + // We will transfer all decorations on the parameter over to the key + // so that they can affect downstream emit logic. + // + // TODO: We should double-check whether any of the decorations should + // be moved to the *field* instead. + // + param->transferDecorationsTo(paramFieldKey); + + // There is a bit of a hacky issue, where downstream passes (notably + // type legalization) require the field keys for `struct` types to + // have mangled names, because those mangled names will be used to + // lookup field layout information inside of the layout information + // for the `struct` type. + // + // TODO: We should fix that design choice in how layout information + // is stored, to avoid the reliance on name strings. + // + builder->addExportDecoration(paramFieldKey, getMangledName(paramLayout->varDecl).getUnownedSlice()); + + // At this point we want to eliminate the original entry point + // parameter, in favor of the `struct` field we declared. + // That required replacing any uses of the parameter with + // appropriate code to pull out the field. + // + // We *could* extract the field at the start of the shader + // and then do a `replaceAllUsesWith` to propragate it + // down, but in practice we expect that it is better for + // performance to "rematerialize" the value of a shader + // parameter as close to where it is used as possible. + // + // We are therefore going to replace the uses one at a time. + // + while(auto use = param->firstUse ) + { + // Given a `use` of the paramter, we will insert + // the replacement code right before the instruction + // that is doing the using. + // + builder->setInsertBefore(use->getUser()); + + // The way to extract the field that corresponds + // to the parameter depends on whether or not + // we generated a constant buffer. + // + IRInst* fieldVal = nullptr; + if( needConstantBuffer ) + { + // A constant buffer behaves like a pointer + // at the IR level, so we first do a pointer + // offset operation to compute what amounts + // to `&cb->field`, and then load from that address. + // + auto fieldAddress = builder->emitFieldAddress( + builder->getPtrType(paramType), + globalParam, + paramFieldKey); + fieldVal = builder->emitLoad(fieldAddress); + } + else + { + // In the ordinary struct case, the parameter + // has an ordinary `struct` type (not a pointer), + // so we just extract the field directly. + // + fieldVal = builder->emitFieldExtract( + paramType, + globalParam, + paramFieldKey); + } + + // We replace the value used at this use site, which + // will have a side effect of making `use` no longer + // be on the list of uses for `param`, so that when + // we get back to the top of the loop the list of + // uses will be shorter. + // + use->set(fieldVal); + } + + // Once we've replaced all the uses of `param`, we + // can go ahead and remove it completely. + // + param->removeAndDeallocate(); + } + + fixUpFuncType(func); + } + + // We need to be able to determine if a parameter is logically + // a "varying" parameter based on its layout. + // + bool isVaryingParameter(VarLayout* layout) + { + // If *any* of the resources consumed by the parameter + // is a varying resource kind (e.g., varying input) then + // we consider the whole parameter to be varying. + // + // This is reasonable because there is no way to declare + // a parameter that mixes varying and non-varying fields. + // + for( auto resInfo : layout->resourceInfos ) + { + if(isVaryingResourceKind(resInfo.kind)) + return true; + } + + // Varying parameters with "system value" semantics currently show up as + // consuming no resources, so we need to special-case that here. + // + // Note: an empty `struct` parameter would also show up the same way, but + // we should eliminate any such parameters later on during type legalization. + // + if(layout->resourceInfos.getCount() == 0) + return true; + + // if none of the above tests determined that the + // parameter was varying, then we can safely consider + // it to be non-varying (uniform): + return false; + } + + // In order to determine whether a parameter is varying based on its + // layout, we need to know which resource kinds represent varying + // shader parameters. + // + bool isVaryingResourceKind(LayoutResourceKind kind) + { + switch( kind ) + { + default: + return false; + + // Note: The set of cases that are considered + // varying here would need to be extended if we + // add more fine-grained resource kinds (e.g., + // if we ever add an explicit resource kind + // for geometry shader output streams). + // + // Ordinary varying input/output: + case LayoutResourceKind::VaryingInput: + case LayoutResourceKind::VaryingOutput: + // + // Ray-tracing shader input/output: + case LayoutResourceKind::CallablePayload: + case LayoutResourceKind::HitAttributes: + case LayoutResourceKind::RayPayload: + return true; + } + } +}; + +void moveEntryPointUniformParamsToGlobalScope( + IRModule* module) +{ + MoveEntryPointUniformParametersToGlobalScope context; + context.module = module; + context.processModule(); +} + +} diff --git a/source/slang/slang-ir-entry-point-uniforms.h b/source/slang/slang-ir-entry-point-uniforms.h new file mode 100644 index 000000000..49994c202 --- /dev/null +++ b/source/slang/slang-ir-entry-point-uniforms.h @@ -0,0 +1,12 @@ +// slang-ir-entry-point-uniform.h +#pragma once + +namespace Slang +{ +struct IRModule; + + /// Move any uniform parameters of entry points to the global scope instead. +void moveEntryPointUniformParamsToGlobalScope( + IRModule* module); + +} diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp new file mode 100644 index 000000000..1e42cba62 --- /dev/null +++ b/source/slang/slang-ir-glsl-legalize.cpp @@ -0,0 +1,1687 @@ +// slang-ir-glsl-legalize.cpp +#include "slang-ir-glsl-legalize.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" + +#include "slang-extension-usage-tracker.h" + +namespace Slang +{ + +// +// Legalization of entry points for GLSL: +// + +IRGlobalParam* addGlobalParam( + IRModule* module, + IRType* valueType) +{ + auto session = module->session; + + SharedIRBuilder shared; + shared.module = module; + shared.session = session; + + IRBuilder builder; + builder.sharedBuilder = &shared; + return builder.createGlobalParam(valueType); +} + +void moveValueBefore( + IRInst* valueToMove, + IRInst* placeBefore) +{ + valueToMove->removeFromParent(); + valueToMove->insertBefore(placeBefore); +} + +IRType* getFieldType( + IRType* baseType, + IRStructKey* fieldKey) +{ + if(auto structType = as(baseType)) + { + for(auto ff : structType->getFields()) + { + if(ff->getKey() == fieldKey) + return ff->getFieldType(); + } + } + + SLANG_UNEXPECTED("no such field"); + UNREACHABLE_RETURN(nullptr); +} + + + +// When scalarizing shader inputs/outputs for GLSL, we need a way +// to refer to a conceptual "value" that might comprise multiple +// IR-level values. We could in principle introduce tuple types +// into the IR so that everything stays at the IR level, but +// it seems easier to just layer it over the top for now. +// +// The `ScalarizedVal` type deals with the "tuple or single value?" +// question, and also the "l-value or r-value?" question. +struct ScalarizedValImpl : RefObject +{}; +struct ScalarizedTupleValImpl; +struct ScalarizedTypeAdapterValImpl; +struct ScalarizedVal +{ + enum class Flavor + { + // no value (null pointer) + none, + + // A simple `IRInst*` that represents the actual value + value, + + // An `IRInst*` that represents the address of the actual value + address, + + // A `TupleValImpl` that represents zero or more `ScalarizedVal`s + tuple, + + // A `TypeAdapterValImpl` that wraps a single `ScalarizedVal` and + // represents an implicit type conversion applied to it on read + // or write. + typeAdapter, + }; + + // Create a value representing a simple value + static ScalarizedVal value(IRInst* irValue) + { + ScalarizedVal result; + result.flavor = Flavor::value; + result.irValue = irValue; + return result; + } + + + // Create a value representing an address + static ScalarizedVal address(IRInst* irValue) + { + ScalarizedVal result; + result.flavor = Flavor::address; + result.irValue = irValue; + return result; + } + + static ScalarizedVal tuple(ScalarizedTupleValImpl* impl) + { + ScalarizedVal result; + result.flavor = Flavor::tuple; + result.impl = (ScalarizedValImpl*)impl; + return result; + } + + static ScalarizedVal typeAdapter(ScalarizedTypeAdapterValImpl* impl) + { + ScalarizedVal result; + result.flavor = Flavor::typeAdapter; + result.impl = (ScalarizedValImpl*)impl; + return result; + } + + Flavor flavor = Flavor::none; + IRInst* irValue = nullptr; + RefPtr impl; +}; + +// This is the case for a value that is a "tuple" of other values +struct ScalarizedTupleValImpl : ScalarizedValImpl +{ + struct Element + { + IRStructKey* key; + ScalarizedVal val; + }; + + IRType* type; + List elements; +}; + +// This is the case for a value that is stored with one type, +// but needs to present itself as having a different type +struct ScalarizedTypeAdapterValImpl : ScalarizedValImpl +{ + ScalarizedVal val; + IRType* actualType; // the actual type of `val` + IRType* pretendType; // the type this value pretends to have +}; + +struct GlobalVaryingDeclarator +{ + enum class Flavor + { + array, + }; + + Flavor flavor; + IRInst* elementCount; + GlobalVaryingDeclarator* next; +}; + +struct GLSLSystemValueInfo +{ + // The name of the built-in GLSL variable + char const* name; + + // The name of an outer array that wraps + // the variable, in the case of a GS input + char const* outerArrayName; + + // The required type of the built-in variable + IRType* requiredType; +}; + +struct GLSLLegalizationContext +{ + Session* session; + ExtensionUsageTracker* extensionUsageTracker; + DiagnosticSink* sink; + Stage stage; + + void requireGLSLExtension(String const& name) + { + extensionUsageTracker->requireGLSLExtension(name); + } + + void requireGLSLVersion(ProfileVersion version) + { + extensionUsageTracker->requireGLSLVersion(version); + } + + Stage getStage() + { + return stage; + } + + DiagnosticSink* getSink() + { + return sink; + } + + IRBuilder* builder; + IRBuilder* getBuilder() { return builder; } +}; + +GLSLSystemValueInfo* getGLSLSystemValueInfo( + GLSLLegalizationContext* context, + VarLayout* varLayout, + LayoutResourceKind kind, + Stage stage, + GLSLSystemValueInfo* inStorage) +{ + char const* name = nullptr; + char const* outerArrayName = nullptr; + + auto semanticNameSpelling = varLayout->systemValueSemantic; + if(semanticNameSpelling.getLength() == 0) + return nullptr; + + auto semanticName = semanticNameSpelling.toLower(); + + // HLSL semantic types can be found here + // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/dx-graphics-hlsl-semantics + /// NOTE! While there might be an "official" type for most of these in HLSL, in practice the user is allowed to declare almost anything + /// that the HLSL compiler can implicitly convert to/from the correct type + + auto builder = context->getBuilder(); + IRType* requiredType = nullptr; + + if(semanticName == "sv_position") + { + // float4 in hlsl & glsl + // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/gl_FragCoord.xhtml + // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/gl_Position.xhtml + + // This semantic can either work like `gl_FragCoord` + // when it is used as a fragment shader input, or + // like `gl_Position` when used in other stages. + // + // Note: This isn't as simple as testing input-vs-output, + // because a user might have a VS output `SV_Position`, + // and then pass it along to a GS that reads it as input. + // + if( stage == Stage::Fragment + && kind == LayoutResourceKind::VaryingInput ) + { + name = "gl_FragCoord"; + } + else if( stage == Stage::Geometry + && kind == LayoutResourceKind::VaryingInput ) + { + // As a GS input, the correct syntax is `gl_in[...].gl_Position`, + // but that is not compatible with picking the array dimension later, + // of course. + outerArrayName = "gl_in"; + name = "gl_Position"; + } + else + { + name = "gl_Position"; + } + + requiredType = builder->getVectorType(builder->getBasicType(BaseType::Float), builder->getIntValue(builder->getIntType(), 4)); + } + else if(semanticName == "sv_target") + { + // Note: we do *not* need to generate some kind of `gl_` + // builtin for fragment-shader outputs: they are just + // ordinary `out` variables, with ordinary `location`s, + // as far as GLSL is concerned. + return nullptr; + } + else if(semanticName == "sv_clipdistance") + { + // TODO: type conversion is required here. + + // float in hlsl & glsl. + // "Clip distance data. SV_ClipDistance values are each assumed to be a float32 signed distance to a plane." + // In glsl clipping value meaning is probably different + // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/gl_ClipDistance.xhtml + + name = "gl_ClipDistance"; + requiredType = builder->getBasicType(BaseType::Float); + } + else if(semanticName == "sv_culldistance") + { + // float in hlsl & glsl. + // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/gl_CullDistance.xhtml + + context->requireGLSLExtension("ARB_cull_distance"); + + // TODO: type conversion is required here. + name = "gl_CullDistance"; + requiredType = builder->getBasicType(BaseType::Float); + } + else if(semanticName == "sv_coverage") + { + // TODO: deal with `gl_SampleMaskIn` when used as an input. + + // TODO: type conversion is required here. + + // uint in hlsl, int in glsl + // https://www.opengl.org/sdk/docs/manglsl/docbook4/xhtml/gl_SampleMask.xml + + requiredType = builder->getBasicType(BaseType::Int); + + name = "gl_SampleMask"; + } + else if(semanticName == "sv_depth") + { + // Float in hlsl & glsl + // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/gl_FragDepth.xhtml + name = "gl_FragDepth"; + requiredType = builder->getBasicType(BaseType::Float); + } + else if(semanticName == "sv_depthgreaterequal") + { + // TODO: layout(depth_greater) out float gl_FragDepth; + + // Type is 'unknown' in hlsl + name = "gl_FragDepth"; + requiredType = builder->getBasicType(BaseType::Float); + } + else if(semanticName == "sv_depthlessequal") + { + // TODO: layout(depth_greater) out float gl_FragDepth; + + // 'unknown' in hlsl, float in glsl + name = "gl_FragDepth"; + requiredType = builder->getBasicType(BaseType::Float); + } + else if(semanticName == "sv_dispatchthreadid") + { + // uint3 in hlsl, uvec3 in glsl + // https://www.opengl.org/sdk/docs/manglsl/docbook4/xhtml/gl_GlobalInvocationID.xml + name = "gl_GlobalInvocationID"; + + requiredType = builder->getVectorType(builder->getBasicType(BaseType::UInt), builder->getIntValue(builder->getIntType(), 3)); + } + else if(semanticName == "sv_domainlocation") + { + // float2|3 in hlsl, vec3 in glsl + // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/gl_TessCoord.xhtml + + requiredType = builder->getVectorType(builder->getBasicType(BaseType::Float), builder->getIntValue(builder->getIntType(), 3)); + + name = "gl_TessCoord"; + } + else if(semanticName == "sv_groupid") + { + // uint3 in hlsl, uvec3 in glsl + // https://www.opengl.org/sdk/docs/manglsl/docbook4/xhtml/gl_WorkGroupID.xml + name = "gl_WorkGroupID"; + + requiredType = builder->getVectorType(builder->getBasicType(BaseType::UInt), builder->getIntValue(builder->getIntType(), 3)); + } + else if(semanticName == "sv_groupindex") + { + // uint in hlsl & in glsl + name = "gl_LocalInvocationIndex"; + requiredType = builder->getBasicType(BaseType::UInt); + } + else if(semanticName == "sv_groupthreadid") + { + // uint3 in hlsl, uvec3 in glsl + name = "gl_LocalInvocationID"; + + requiredType = builder->getVectorType(builder->getBasicType(BaseType::UInt), builder->getIntValue(builder->getIntType(), 3)); + } + else if(semanticName == "sv_gsinstanceid") + { + // uint in hlsl, int in glsl + // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/gl_InvocationID.xhtml + + requiredType = builder->getBasicType(BaseType::Int); + name = "gl_InvocationID"; + } + else if(semanticName == "sv_instanceid") + { + // https://docs.microsoft.com/en-us/windows/desktop/direct3d11/d3d10-graphics-programming-guide-input-assembler-stage-using#instanceid + // uint in hlsl, int in glsl + + requiredType = builder->getBasicType(BaseType::Int); + name = "gl_InstanceIndex"; + } + else if(semanticName == "sv_isfrontface") + { + // bool in hlsl & glsl + // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/gl_FrontFacing.xhtml + name = "gl_FrontFacing"; + requiredType = builder->getBasicType(BaseType::Bool); + } + else if(semanticName == "sv_outputcontrolpointid") + { + // uint in hlsl, int in glsl + // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/gl_InvocationID.xhtml + + name = "gl_InvocationID"; + + requiredType = builder->getBasicType(BaseType::Int); + } + else if (semanticName == "sv_pointsize") + { + // float in hlsl & glsl + name = "gl_PointSize"; + requiredType = builder->getBasicType(BaseType::Float); + } + else if(semanticName == "sv_primitiveid") + { + // uint in hlsl, int in glsl + // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/gl_PrimitiveID.xhtml + name = "gl_PrimitiveID"; + + requiredType = builder->getBasicType(BaseType::Int); + } + else if (semanticName == "sv_rendertargetarrayindex") + { + // uint on hlsl, int on glsl + // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/gl_Layer.xhtml + + switch (context->getStage()) + { + case Stage::Geometry: + context->requireGLSLVersion(ProfileVersion::GLSL_150); + break; + + case Stage::Fragment: + context->requireGLSLVersion(ProfileVersion::GLSL_430); + break; + + default: + context->requireGLSLVersion(ProfileVersion::GLSL_450); + context->requireGLSLExtension("GL_ARB_shader_viewport_layer_array"); + break; + } + + name = "gl_Layer"; + requiredType = builder->getBasicType(BaseType::Int); + } + else if (semanticName == "sv_sampleindex") + { + // uint in hlsl, int in glsl + // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/gl_SampleID.xhtml + + requiredType = builder->getBasicType(BaseType::Int); + name = "gl_SampleID"; + } + else if (semanticName == "sv_stencilref") + { + // uint in hlsl, int in glsl + // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_shader_stencil_export.txt + + requiredType = builder->getBasicType(BaseType::Int); + + context->requireGLSLExtension("ARB_shader_stencil_export"); + name = "gl_FragStencilRef"; + } + else if (semanticName == "sv_tessfactor") + { + // TODO(JS): Adjust type does *not* handle the conversion correctly. More specifically a float array hlsl + // parameter goes through code to make SOA in createGLSLGlobalVaryingsImpl. + // + // Can be input and output. + // + // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/sv-tessfactor + // "Tessellation factors must be declared as an array; they cannot be packed into a single vector." + // + // float[2|3|4] in hlsl, float[4] on glsl (ie both are arrays but might be different size) + // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/gl_TessLevelOuter.xhtml + + name = "gl_TessLevelOuter"; + + // float[4] on glsl + requiredType = builder->getArrayType(builder->getBasicType(BaseType::Float), builder->getIntValue(builder->getIntType(), 4)); + } + else if (semanticName == "sv_vertexid") + { + // uint in hlsl, int in glsl (https://www.khronos.org/opengl/wiki/Built-in_Variable_(GLSL)) + requiredType = builder->getBasicType(BaseType::Int); + name = "gl_VertexIndex"; + } + else if (semanticName == "sv_viewportarrayindex") + { + // uint on hlsl, int on glsl + // https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/gl_ViewportIndex.xhtml + + requiredType = builder->getBasicType(BaseType::Int); + name = "gl_ViewportIndex"; + } + else if (semanticName == "nv_x_right") + { + context->requireGLSLVersion(ProfileVersion::GLSL_450); + context->requireGLSLExtension("GL_NVX_multiview_per_view_attributes"); + + // The actual output in GLSL is: + // + // vec4 gl_PositionPerViewNV[]; + // + // and is meant to support an arbitrary number of views, + // while the HLSL case just defines a second position + // output. + // + // For now we will hack this by: + // 1. Mapping an `NV_X_Right` output to `gl_PositionPerViewNV[1]` + // (that is, just one element of the output array) + // 2. Adding logic to copy the traditional `gl_Position` output + // over to `gl_PositionPerViewNV[0]` + // + + name = "gl_PositionPerViewNV[1]"; + +// shared->requiresCopyGLPositionToPositionPerView = true; + } + else if (semanticName == "nv_viewport_mask") + { + // TODO: This doesn't seem to work correctly on it's own between hlsl/glsl + + // Indeed on slang issue 109 claims this remains a problem + // https://github.com/shader-slang/slang/issues/109 + + // On hlsl it's UINT related. "higher 16 bits for the right view, lower 16 bits for the left view." + // There is use in hlsl shader code as uint4 - not clear if that varies + // https://github.com/KhronosGroup/GLSL/blob/master/extensions/nvx/GL_NVX_multiview_per_view_attributes.txt + // On glsl its highp int gl_ViewportMaskPerViewNV[]; + + context->requireGLSLVersion(ProfileVersion::GLSL_450); + context->requireGLSLExtension("GL_NVX_multiview_per_view_attributes"); + + name = "gl_ViewportMaskPerViewNV"; +// globalVarExpr = createGLSLBuiltinRef("gl_ViewportMaskPerViewNV", +// getUnsizedArrayType(getIntType())); + } + + if( name ) + { + inStorage->name = name; + inStorage->outerArrayName = outerArrayName; + inStorage->requiredType = requiredType; + return inStorage; + } + + context->getSink()->diagnose(varLayout->varDecl.getDecl()->loc, Diagnostics::unknownSystemValueSemantic, semanticNameSpelling); + return nullptr; +} + +ScalarizedVal createSimpleGLSLGlobalVarying( + GLSLLegalizationContext* context, + IRBuilder* builder, + IRType* inType, + VarLayout* inVarLayout, + TypeLayout* inTypeLayout, + LayoutResourceKind kind, + Stage stage, + UInt bindingIndex, + GlobalVaryingDeclarator* declarator) +{ + // Check if we have a system value on our hands. + GLSLSystemValueInfo systemValueInfoStorage; + auto systemValueInfo = getGLSLSystemValueInfo( + context, + inVarLayout, + kind, + stage, + &systemValueInfoStorage); + + IRType* type = inType; + + // A system-value semantic might end up needing to override the type + // that the user specified. + if( systemValueInfo && systemValueInfo->requiredType ) + { + type = systemValueInfo->requiredType; + } + + // Construct the actual type and type-layout for the global variable + // + RefPtr typeLayout = inTypeLayout; + for( auto dd = declarator; dd; dd = dd->next ) + { + // We only have one declarator case right now... + SLANG_ASSERT(dd->flavor == GlobalVaryingDeclarator::Flavor::array); + + auto arrayType = builder->getArrayType( + type, + dd->elementCount); + + RefPtr arrayTypeLayout = new ArrayTypeLayout(); +// arrayTypeLayout->type = arrayType; + arrayTypeLayout->rules = typeLayout->rules; + arrayTypeLayout->originalElementTypeLayout = typeLayout; + arrayTypeLayout->elementTypeLayout = typeLayout; + arrayTypeLayout->uniformStride = 0; + + if( auto resInfo = inTypeLayout->FindResourceInfo(kind) ) + { + // TODO: it is kind of gross to be re-running some + // of the type layout logic here. + + UInt elementCount = (UInt) GetIntVal(dd->elementCount); + arrayTypeLayout->addResourceUsage( + kind, + resInfo->count * elementCount); + } + + type = arrayType; + typeLayout = arrayTypeLayout; + } + + // We need to construct a fresh layout for the variable, even + // if the original had its own layout, because it might be + // an `inout` parameter, and we only want to deal with the case + // described by our `kind` parameter. + RefPtr varLayout = new VarLayout(); + varLayout->varDecl = inVarLayout->varDecl; + varLayout->typeLayout = typeLayout; + varLayout->flags = inVarLayout->flags; + varLayout->systemValueSemantic = inVarLayout->systemValueSemantic; + varLayout->systemValueSemanticIndex = inVarLayout->systemValueSemanticIndex; + varLayout->semanticName = inVarLayout->semanticName; + varLayout->semanticIndex = inVarLayout->semanticIndex; + varLayout->stage = inVarLayout->stage; + varLayout->AddResourceInfo(kind)->index = bindingIndex; + + // We are going to be creating a global parameter to replace + // the function parameter, but we need to handle the case + // where the parameter represents a varying *output* and not + // just an input. + // + // Our IR global shader parameters are read-only, just + // like our IR function parameters, and need a wrapper + // `Out<...>` type to represent outputs. + // + bool isOutput = kind == LayoutResourceKind::VaryingOutput; + IRType* paramType = isOutput ? builder->getOutType(type) : type; + + auto globalParam = addGlobalParam(builder->getModule(), paramType); + moveValueBefore(globalParam, builder->getFunc()); + + ScalarizedVal val = isOutput ? ScalarizedVal::address(globalParam) : ScalarizedVal::value(globalParam); + + if( systemValueInfo ) + { + builder->addImportDecoration(globalParam, UnownedTerminatedStringSlice(systemValueInfo->name)); + + if( auto fromType = systemValueInfo->requiredType ) + { + // We may need to adapt from the declared type to/from + // the actual type of the GLSL global. + auto toType = inType; + + if( !isTypeEqual(fromType, toType )) + { + RefPtr typeAdapter = new ScalarizedTypeAdapterValImpl; + typeAdapter->actualType = systemValueInfo->requiredType; + typeAdapter->pretendType = inType; + typeAdapter->val = val; + + val = ScalarizedVal::typeAdapter(typeAdapter); + } + } + + if(auto outerArrayName = systemValueInfo->outerArrayName) + { + builder->addGLSLOuterArrayDecoration(globalParam, UnownedTerminatedStringSlice(outerArrayName)); + } + } + + builder->addLayoutDecoration(globalParam, varLayout); + + return val; +} + +ScalarizedVal createGLSLGlobalVaryingsImpl( + GLSLLegalizationContext* context, + IRBuilder* builder, + IRType* type, + VarLayout* varLayout, + TypeLayout* typeLayout, + LayoutResourceKind kind, + Stage stage, + UInt bindingIndex, + GlobalVaryingDeclarator* declarator) +{ + if (as(type)) + { + return ScalarizedVal(); + } + else if( as(type) ) + { + return createSimpleGLSLGlobalVarying( + context, + builder, type, varLayout, typeLayout, kind, stage, bindingIndex, declarator); + } + else if( as(type) ) + { + return createSimpleGLSLGlobalVarying( + context, + builder, type, varLayout, typeLayout, kind, stage, bindingIndex, declarator); + } + else if( as(type) ) + { + // TODO: a matrix-type varying should probably be handled like an array of rows + return createSimpleGLSLGlobalVarying( + context, + builder, type, varLayout, typeLayout, kind, stage, bindingIndex, declarator); + } + else if( auto arrayType = as(type) ) + { + // We will need to SOA-ize any nested types. + + auto elementType = arrayType->getElementType(); + auto elementCount = arrayType->getElementCount(); + auto arrayLayout = as(typeLayout); + SLANG_ASSERT(arrayLayout); + auto elementTypeLayout = arrayLayout->elementTypeLayout; + + GlobalVaryingDeclarator arrayDeclarator; + arrayDeclarator.flavor = GlobalVaryingDeclarator::Flavor::array; + arrayDeclarator.elementCount = elementCount; + arrayDeclarator.next = declarator; + + return createGLSLGlobalVaryingsImpl( + context, + builder, + elementType, + varLayout, + elementTypeLayout, + kind, + stage, + bindingIndex, + &arrayDeclarator); + } + else if( auto streamType = as(type)) + { + auto elementType = streamType->getElementType(); + auto streamLayout = as(typeLayout); + SLANG_ASSERT(streamLayout); + auto elementTypeLayout = streamLayout->elementTypeLayout; + + return createGLSLGlobalVaryingsImpl( + context, + builder, + elementType, + varLayout, + elementTypeLayout, + kind, + stage, + bindingIndex, + declarator); + } + else if(auto structType = as(type)) + { + // We need to recurse down into the individual fields, + // and generate a variable for each of them. + + auto structTypeLayout = as(typeLayout); + SLANG_ASSERT(structTypeLayout); + RefPtr tupleValImpl = new ScalarizedTupleValImpl(); + + + // Construct the actual type for the tuple (including any outer arrays) + IRType* fullType = type; + for( auto dd = declarator; dd; dd = dd->next ) + { + SLANG_ASSERT(dd->flavor == GlobalVaryingDeclarator::Flavor::array); + fullType = builder->getArrayType( + fullType, + dd->elementCount); + } + + tupleValImpl->type = fullType; + + // Okay, we want to walk through the fields here, and + // generate one variable for each. + UInt fieldCounter = 0; + for(auto field : structType->getFields()) + { + UInt fieldIndex = fieldCounter++; + + auto fieldLayout = structTypeLayout->fields[fieldIndex]; + + UInt fieldBindingIndex = bindingIndex; + if(auto fieldResInfo = fieldLayout->FindResourceInfo(kind)) + fieldBindingIndex += fieldResInfo->index; + + auto fieldVal = createGLSLGlobalVaryingsImpl( + context, + builder, + field->getFieldType(), + fieldLayout, + fieldLayout->typeLayout, + kind, + stage, + fieldBindingIndex, + declarator); + if (fieldVal.flavor != ScalarizedVal::Flavor::none) + { + ScalarizedTupleValImpl::Element element; + element.val = fieldVal; + element.key = field->getKey(); + + tupleValImpl->elements.add(element); + } + } + + return ScalarizedVal::tuple(tupleValImpl); + } + + // Default case is to fall back on the simple behavior + return createSimpleGLSLGlobalVarying( + context, + builder, type, varLayout, typeLayout, kind, stage, bindingIndex, declarator); +} + +ScalarizedVal createGLSLGlobalVaryings( + GLSLLegalizationContext* context, + IRBuilder* builder, + IRType* type, + VarLayout* layout, + LayoutResourceKind kind, + Stage stage) +{ + UInt bindingIndex = 0; + if(auto rr = layout->FindResourceInfo(kind)) + bindingIndex = rr->index; + return createGLSLGlobalVaryingsImpl( + context, + builder, type, layout, layout->typeLayout, kind, stage, bindingIndex, nullptr); +} + +ScalarizedVal extractField( + IRBuilder* builder, + ScalarizedVal const& val, + UInt fieldIndex, + IRStructKey* fieldKey) +{ + switch( val.flavor ) + { + case ScalarizedVal::Flavor::value: + return ScalarizedVal::value( + builder->emitFieldExtract( + getFieldType(val.irValue->getDataType(), fieldKey), + val.irValue, + fieldKey)); + + case ScalarizedVal::Flavor::address: + { + auto ptrType = as(val.irValue->getDataType()); + auto valType = ptrType->getValueType(); + auto fieldType = getFieldType(valType, fieldKey); + auto fieldPtrType = builder->getPtrType(ptrType->op, fieldType); + return ScalarizedVal::address( + builder->emitFieldAddress( + fieldPtrType, + val.irValue, + fieldKey)); + } + + case ScalarizedVal::Flavor::tuple: + { + auto tupleVal = as(val.impl); + return tupleVal->elements[fieldIndex].val; + } + + default: + SLANG_UNEXPECTED("unimplemented"); + UNREACHABLE_RETURN(ScalarizedVal()); + } + +} + +ScalarizedVal adaptType( + IRBuilder* builder, + IRInst* val, + IRType* toType, + IRType* /*fromType*/) +{ + // TODO: actually consider what needs to go on here... + return ScalarizedVal::value(builder->emitConstructorInst( + toType, + 1, + &val)); +} + +ScalarizedVal adaptType( + IRBuilder* builder, + ScalarizedVal const& val, + IRType* toType, + IRType* fromType) +{ + switch( val.flavor ) + { + case ScalarizedVal::Flavor::value: + return adaptType(builder, val.irValue, toType, fromType); + break; + + case ScalarizedVal::Flavor::address: + { + auto loaded = builder->emitLoad(val.irValue); + return adaptType(builder, loaded, toType, fromType); + } + break; + + default: + SLANG_UNEXPECTED("unimplemented"); + UNREACHABLE_RETURN(ScalarizedVal()); + } +} + +void assign( + IRBuilder* builder, + ScalarizedVal const& left, + ScalarizedVal const& right) +{ + switch( left.flavor ) + { + case ScalarizedVal::Flavor::address: + switch( right.flavor ) + { + case ScalarizedVal::Flavor::value: + { + builder->emitStore(left.irValue, right.irValue); + } + break; + + case ScalarizedVal::Flavor::address: + { + auto val = builder->emitLoad(right.irValue); + builder->emitStore(left.irValue, val); + } + break; + + case ScalarizedVal::Flavor::tuple: + { + // We are assigning from a tuple to a destination + // that is not a tuple. We will perform assignment + // element-by-element. + auto rightTupleVal = as(right.impl); + Index elementCount = rightTupleVal->elements.getCount(); + + for( Index ee = 0; ee < elementCount; ++ee ) + { + auto rightElement = rightTupleVal->elements[ee]; + auto leftElementVal = extractField( + builder, + left, + ee, + rightElement.key); + assign(builder, leftElementVal, rightElement.val); + } + } + break; + + default: + SLANG_UNEXPECTED("unimplemented"); + break; + } + break; + + case ScalarizedVal::Flavor::tuple: + { + // We have a tuple, so we are going to need to try and assign + // to each of its constituent fields. + auto leftTupleVal = as(left.impl); + Index elementCount = leftTupleVal->elements.getCount(); + + for( Index ee = 0; ee < elementCount; ++ee ) + { + auto rightElementVal = extractField( + builder, + right, + ee, + leftTupleVal->elements[ee].key); + assign(builder, leftTupleVal->elements[ee].val, rightElementVal); + } + } + break; + + case ScalarizedVal::Flavor::typeAdapter: + { + // We are trying to assign to something that had its type adjusted, + // so we will need to adjust the type of the right-hand side first. + // + // In this case we are converting to the actual type of the GLSL variable, + // from the "pretend" type that it had in the IR before. + auto typeAdapter = as(left.impl); + auto adaptedRight = adaptType(builder, right, typeAdapter->actualType, typeAdapter->pretendType); + assign(builder, typeAdapter->val, adaptedRight); + } + break; + + default: + SLANG_UNEXPECTED("unimplemented"); + break; + } +} + +ScalarizedVal getSubscriptVal( + IRBuilder* builder, + IRType* elementType, + ScalarizedVal val, + IRInst* indexVal) +{ + switch( val.flavor ) + { + case ScalarizedVal::Flavor::value: + return ScalarizedVal::value( + builder->emitElementExtract( + elementType, + val.irValue, + indexVal)); + + case ScalarizedVal::Flavor::address: + return ScalarizedVal::address( + builder->emitElementAddress( + builder->getPtrType(elementType), + val.irValue, + indexVal)); + + case ScalarizedVal::Flavor::tuple: + { + auto inputTuple = val.impl.as(); + + RefPtr resultTuple = new ScalarizedTupleValImpl(); + resultTuple->type = elementType; + + Index elementCount = inputTuple->elements.getCount(); + Index elementCounter = 0; + + auto structType = as(elementType); + for(auto field : structType->getFields()) + { + auto tupleElementType = field->getFieldType(); + + Index elementIndex = elementCounter++; + + SLANG_RELEASE_ASSERT(elementIndex < elementCount); + auto inputElement = inputTuple->elements[elementIndex]; + + ScalarizedTupleValImpl::Element resultElement; + resultElement.key = inputElement.key; + resultElement.val = getSubscriptVal( + builder, + tupleElementType, + inputElement.val, + indexVal); + + resultTuple->elements.add(resultElement); + } + SLANG_RELEASE_ASSERT(elementCounter == elementCount); + + return ScalarizedVal::tuple(resultTuple); + } + + default: + SLANG_UNEXPECTED("unimplemented"); + UNREACHABLE_RETURN(ScalarizedVal()); + } +} + +ScalarizedVal getSubscriptVal( + IRBuilder* builder, + IRType* elementType, + ScalarizedVal val, + UInt index) +{ + return getSubscriptVal( + builder, + elementType, + val, + builder->getIntValue( + builder->getIntType(), + index)); +} + +IRInst* materializeValue( + IRBuilder* builder, + ScalarizedVal const& val); + +IRInst* materializeTupleValue( + IRBuilder* builder, + ScalarizedVal val) +{ + auto tupleVal = val.impl.as(); + SLANG_ASSERT(tupleVal); + + Index elementCount = tupleVal->elements.getCount(); + auto type = tupleVal->type; + + if( auto arrayType = as(type)) + { + // The tuple represent an array, which means that the + // individual elements are expected to yield arrays as well. + // + // We will extract a value for each array element, and + // then use these to construct our result. + + List arrayElementVals; + UInt arrayElementCount = (UInt) GetIntVal(arrayType->getElementCount()); + + for( UInt ii = 0; ii < arrayElementCount; ++ii ) + { + auto arrayElementPseudoVal = getSubscriptVal( + builder, + arrayType->getElementType(), + val, + ii); + + auto arrayElementVal = materializeValue( + builder, + arrayElementPseudoVal); + + arrayElementVals.add(arrayElementVal); + } + + return builder->emitMakeArray( + arrayType, + arrayElementVals.getCount(), + arrayElementVals.getBuffer()); + } + else + { + // The tuple represents a value of some aggregate type, + // so we can simply materialize the elements and then + // construct a value of that type. + // + // TODO: this should be using a `makeStruct` instruction. + + List elementVals; + for( Index ee = 0; ee < elementCount; ++ee ) + { + auto elementVal = materializeValue(builder, tupleVal->elements[ee].val); + elementVals.add(elementVal); + } + + return builder->emitConstructorInst( + tupleVal->type, + elementVals.getCount(), + elementVals.getBuffer()); + } +} + +IRInst* materializeValue( + IRBuilder* builder, + ScalarizedVal const& val) +{ + switch( val.flavor ) + { + case ScalarizedVal::Flavor::value: + return val.irValue; + + case ScalarizedVal::Flavor::address: + { + auto loadInst = builder->emitLoad(val.irValue); + return loadInst; + } + break; + + case ScalarizedVal::Flavor::tuple: + { + //auto tupleVal = as(val.impl); + return materializeTupleValue(builder, val); + } + break; + + case ScalarizedVal::Flavor::typeAdapter: + { + // Somebody is trying to use a value where its actual type + // doesn't match the type it pretends to have. To make this + // work we need to adapt the type from its actual type over + // to its pretend type. + auto typeAdapter = as(val.impl); + auto adapted = adaptType(builder, typeAdapter->val, typeAdapter->pretendType, typeAdapter->actualType); + return materializeValue(builder, adapted); + } + break; + + default: + SLANG_UNEXPECTED("unimplemented"); + break; + } +} + +void legalizeRayTracingEntryPointParameterForGLSL( + GLSLLegalizationContext* context, + IRFunc* func, + IRParam* pp, + VarLayout* paramLayout) +{ + auto builder = context->getBuilder(); + auto paramType = pp->getDataType(); + + // The parameter might be either an `in` parameter, + // or an `out` or `in out` parameter, and in those + // latter cases its IR-level type will include a + // wrapping "pointer-like" type (e.g., `Out` + // instead of just `Float`). + // + // Because global shader parameters are read-only + // in the same way function types are, we can take + // care of that detail here just by allocating a + // global shader parameter with exactly the type + // of the original function parameter. + // + auto globalParam = addGlobalParam(builder->getModule(), paramType); + builder->addLayoutDecoration(globalParam, paramLayout); + moveValueBefore(globalParam, builder->getFunc()); + pp->replaceUsesWith(globalParam); + + // Because linkage between ray-tracing shaders is + // based on the type of incoming/outgoing payload + // and attribute parameters, it would be an error to + // eliminate the global parameter *even if* it is + // not actually used inside the entry point. + // + // We attach a decoration to the entry point that + // makes note of the dependency, so that steps + // like dead code elimination cannot get rid of + // the parameter. + // + // TODO: We could consider using a structure like + // this for *all* of the entry point parameters + // that get moved to the global scope, since SPIR-V + // ends up requiring such information on an `OpEntryPoint`. + // + // As a further alternative, we could decide to + // keep entry point varying input/outtput attached + // to the parameter list through all of the Slang IR + // steps, and only declare it as global variables at + // the last minute when emitting a GLSL `main` or + // SPIR-V for an entry point. + // + builder->addDependsOnDecoration(func, globalParam); +} + +void legalizeEntryPointParameterForGLSL( + GLSLLegalizationContext* context, + IRFunc* func, + IRParam* pp, + VarLayout* paramLayout) +{ + auto builder = context->getBuilder(); + auto stage = context->getStage(); + + // We need to create a global variable that will replace the parameter. + // It seems superficially obvious that the variable should have + // the same type as the parameter. + // However, if the parameter was a pointer, in order to + // support `out` or `in out` parameter passing, we need + // to be sure to allocate a variable of the pointed-to + // type instead. + // + // We also need to replace uses of the parameter with + // uses of the variable, and the exact logic there + // will differ a bit between the pointer and non-pointer + // cases. + auto paramType = pp->getDataType(); + + // First we will special-case stage input/outputs that + // don't fit into the standard varying model. + // For right now we are only doing special-case handling + // of geometry shader output streams. + if( auto paramPtrType = as(paramType) ) + { + auto valueType = paramPtrType->getValueType(); + if( auto gsStreamType = as(valueType) ) + { + // An output stream type like `TriangleStream` should + // more or less translate into `out Foo` (plus scalarization). + + auto globalOutputVal = createGLSLGlobalVaryings( + context, + builder, + valueType, + paramLayout, + LayoutResourceKind::VaryingOutput, + stage); + + // TODO: a GS output stream might be passed into other + // functions, so that we should really be modifying + // any function that has one of these in its parameter + // list (and in the limit we should be leagalizing any + // type that nests these...). + // + // For now we will just try to deal with `Append` calls + // directly in this function. + + + + for( auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock() ) + { + for( auto ii = bb->getFirstInst(); ii; ii = ii->getNextInst() ) + { + // Is it a call? + if(ii->op != kIROp_Call) + continue; + + // Is it calling the append operation? + auto callee = ii->getOperand(0); + for(;;) + { + // If the instruction is `specialize(X,...)` then + // we want to look at `X`, and if it is `generic { ... return R; }` + // then we want to look at `R`. We handle this + // iteratively here. + // + // TODO: This idiom seems to come up enough that we + // should probably have a dedicated convenience routine + // for this. + // + // Alternatively, we could switch the IR encoding so + // that decorations are added to the generic instead of the + // value it returns. + // + switch(callee->op) + { + case kIROp_Specialize: + { + callee = cast(callee)->getOperand(0); + continue; + } + + case kIROp_Generic: + { + auto genericResult = findGenericReturnVal(cast(callee)); + if(genericResult) + { + callee = genericResult; + continue; + } + } + + default: + break; + } + break; + } + if(callee->op != kIROp_Func) + continue; + + // HACK: we will identify the operation based + // on the target-intrinsic definition that was + // given to it. + auto decoration = findTargetIntrinsicDecoration(callee, "glsl"); + if(!decoration) + continue; + + if(decoration->getDefinition() != UnownedStringSlice::fromLiteral("EmitVertex()")) + { + continue; + } + + // Okay, we have a declaration, and we want to modify it! + + builder->setInsertBefore(ii); + + assign(builder, globalOutputVal, ScalarizedVal::value(ii->getOperand(2))); + } + } + + // We will still have references to the parameter coming + // from the `EmitVertex` calls, so we need to replace it + // with something. There isn't anything reasonable to + // replace it with that would have the right type, so + // we will replace it with an undefined value, knowing + // that the emitted code will not actually reference it. + // + // TODO: This approach to generating geometry shader code + // is not ideal, and we should strive to find a better + // approach that involes coding the `EmitVertex` operation + // directly in the stdlib, similar to how ray-tracing + // operations like `TraceRay` are handled. + // + builder->setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); + auto undefinedVal = builder->emitUndefined(pp->getFullType()); + pp->replaceUsesWith(undefinedVal); + + return; + } + } + + // When we have an HLSL ray tracing shader entry point, + // we don't want to translate the inputs/outputs for GLSL/SPIR-V + // according to our default rules, for two reasons: + // + // 1. The input and output for these stages are expected to + // be packaged into `struct` types rather than be scalarized, + // so the usual scalarization approach we take here should + // not be applied. + // + // 2. An `in out` parameter isn't just sugar for a combination + // of an `in` and an `out` parameter, and instead represents the + // read/write "payload" that was passed in. It should legalize + // to a single variable, and we can lower reads/writes of it + // directly, rather than introduce an intermediate temporary. + // + switch( stage ) + { + default: + break; + + case Stage::AnyHit: + case Stage::Callable: + case Stage::ClosestHit: + case Stage::Intersection: + case Stage::Miss: + case Stage::RayGeneration: + legalizeRayTracingEntryPointParameterForGLSL(context, func, pp, paramLayout); + return; + } + + // Is the parameter type a special pointer type + // that indicates the parameter is used for `out` + // or `inout` access? + if(auto paramPtrType = as(paramType) ) + { + // Okay, we have the more interesting case here, + // where the parameter was being passed by reference. + // We are going to create a local variable of the appropriate + // type, which will replace the parameter, along with + // one or more global variables for the actual input/output. + + auto valueType = paramPtrType->getValueType(); + + auto localVariable = builder->emitVar(valueType); + auto localVal = ScalarizedVal::address(localVariable); + + if( auto inOutType = as(paramPtrType) ) + { + // In the `in out` case we need to declare two + // sets of global variables: one for the `in` + // side and one for the `out` side. + auto globalInputVal = createGLSLGlobalVaryings( + context, + builder, valueType, paramLayout, LayoutResourceKind::VaryingInput, stage); + + assign(builder, localVal, globalInputVal); + } + + // Any places where the original parameter was used inside + // the function body should instead use the new local variable. + // Since the parameter was a pointer, we use the variable instruction + // itself (which is an `alloca`d pointer) directly: + pp->replaceUsesWith(localVariable); + + // We also need one or more global variables to write the output to + // when the function is done. We create them here. + auto globalOutputVal = createGLSLGlobalVaryings( + context, + builder, valueType, paramLayout, LayoutResourceKind::VaryingOutput, stage); + + // Now we need to iterate over all the blocks in the function looking + // for any `return*` instructions, so that we can write to the output variable + for( auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock() ) + { + auto terminatorInst = bb->getLastInst(); + if(!terminatorInst) + continue; + + switch( terminatorInst->op ) + { + default: + continue; + + case kIROp_ReturnVal: + case kIROp_ReturnVoid: + break; + } + + // We dont' re-use `builder` here because we don't want to + // disrupt the source location it is using for inserting + // temporary variables at the top of the function. + // + IRBuilder terminatorBuilder; + terminatorBuilder.sharedBuilder = builder->sharedBuilder; + terminatorBuilder.setInsertBefore(terminatorInst); + + // Assign from the local variabel to the global output + // variable before the actual `return` takes place. + assign(&terminatorBuilder, globalOutputVal, localVal); + } + } + else + { + // This is the "easy" case where the parameter wasn't + // being passed by reference. We start by just creating + // one or more global variables to represent the parameter, + // and attach the required layout information to it along + // the way. + + auto globalValue = createGLSLGlobalVaryings( + context, + builder, paramType, paramLayout, LayoutResourceKind::VaryingInput, stage); + + // Next we need to replace uses of the parameter with + // references to the variable(s). We are going to do that + // somewhat naively, by simply materializing the + // variables at the start. + IRInst* materialized = materializeValue(builder, globalValue); + + pp->replaceUsesWith(materialized); + } +} + +void legalizeEntryPointForGLSL( + Session* session, + IRModule* module, + IRFunc* func, + DiagnosticSink* sink, + ExtensionUsageTracker* extensionUsageTracker) +{ + auto layoutDecoration = func->findDecoration(); + SLANG_ASSERT(layoutDecoration); + + auto entryPointLayout = as(layoutDecoration->getLayout()); + SLANG_ASSERT(entryPointLayout); + + GLSLLegalizationContext context; + context.session = session; + context.stage = entryPointLayout->profile.GetStage(); + context.sink = sink; + context.extensionUsageTracker = extensionUsageTracker; + + Stage stage = entryPointLayout->profile.GetStage(); + + // We require that the entry-point function has no uses, + // because otherwise we'd invalidate the signature + // at all existing call sites. + // + // TODO: the right thing to do here is to split any + // function that both gets called as an entry point + // and as an ordinary function. + SLANG_ASSERT(!func->firstUse); + + // We create a dummy IR builder, since some of + // the functions require it. + // + // TODO: make some of these free functions... + // + SharedIRBuilder shared; + shared.module = module; + shared.session = session; + IRBuilder builder; + builder.sharedBuilder = &shared; + builder.setInsertInto(func); + + context.builder = &builder; + + // We will start by looking at the return type of the + // function, because that will enable us to do an + // early-out check to avoid more work. + // + // Specifically, we need to check if the function has + // a `void` return type, because there is no work + // to be done on its return value in that case. + auto resultType = func->getResultType(); + if(as(resultType)) + { + // In this case, the function doesn't return a value + // so we don't need to transform its `return` sites. + // + // We can also use this opportunity to quickly + // check if the function has any parameters, and if + // it doesn't use the chance to bail out immediately. + if( func->getParamCount() == 0 ) + { + // This function is already legal for GLSL + // (at least in terms of parameter/result signature), + // so we won't bother doing anything at all. + return; + } + + // If the function does have parameters, then we need + // to let the logic later in this function handle them. + } + else + { + // Function returns a value, so we need + // to introduce a new global variable + // to hold that value, and then replace + // any `returnVal` instructions with + // code to write to that variable. + + auto resultGlobal = createGLSLGlobalVaryings( + &context, + &builder, + resultType, + entryPointLayout->resultLayout, + LayoutResourceKind::VaryingOutput, + stage); + + for( auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock() ) + { + // TODO: This is silly, because we are looking at every instruction, + // when we know that a `returnVal` should only ever appear as a + // terminator... + for( auto ii = bb->getFirstInst(); ii; ii = ii->getNextInst() ) + { + if(ii->op != kIROp_ReturnVal) + continue; + + IRReturnVal* returnInst = (IRReturnVal*) ii; + IRInst* returnValue = returnInst->getVal(); + + // Make sure we add these instructions to the right block + builder.setInsertInto(bb); + + // Write to our global variable(s) from the value being returned. + assign(&builder, resultGlobal, ScalarizedVal::value(returnValue)); + + // Emit a `returnVoid` to end the block + auto returnVoid = builder.emitReturn(); + + // Remove the old `returnVal` instruction. + returnInst->removeAndDeallocate(); + + // Make sure to resume our iteration at an + // appropriate instruciton, since we deleted + // the one we had been using. + ii = returnVoid; + } + } + } + + // Next we will walk through any parameters of the entry-point function, + // and turn them into global variables. + if( auto firstBlock = func->getFirstBlock() ) + { + // Any initialization code we insert for parameters needs + // to be at the start of the "ordinary" instructions in the block: + builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); + + for( auto pp = firstBlock->getFirstParam(); pp; pp = pp->getNextParam() ) + { + // We assume that the entry-point parameters will all have + // layout information attached to them, which is kept up-to-date + // by any transformations affecting the parameter list. + // + auto paramLayoutDecoration = pp->findDecoration(); + SLANG_ASSERT(paramLayoutDecoration); + auto paramLayout = as(paramLayoutDecoration->getLayout()); + SLANG_ASSERT(paramLayout); + + legalizeEntryPointParameterForGLSL( + &context, + func, + pp, + paramLayout); + } + + // At this point we should have eliminated all uses of the + // parameters of the entry block. Also, our control-flow + // rules mean that the entry block cannot be the target + // of any branches in the code, so there can't be + // any control-flow ops that try to match the parameter + // list. + // + // We can safely go through and destroy the parameters + // themselves, and then clear out the parameter list. + + for( auto pp = firstBlock->getFirstParam(); pp; ) + { + auto next = pp->getNextParam(); + pp->removeAndDeallocate(); + pp = next; + } + } + + // Finally, we need to patch up the type of the entry point, + // because it is no longer accurate. + + IRFuncType* voidFuncType = builder.getFuncType( + 0, + nullptr, + builder.getVoidType()); + func->setFullType(voidFuncType); + + // TODO: we should technically be constructing + // a new `EntryPointLayout` here to reflect + // the way that things have been moved around. +} + +} // namespace Slang diff --git a/source/slang/slang-ir-glsl-legalize.h b/source/slang/slang-ir-glsl-legalize.h new file mode 100644 index 000000000..3694005b3 --- /dev/null +++ b/source/slang/slang-ir-glsl-legalize.h @@ -0,0 +1,22 @@ +// slang-ir-glsl-legalize.h +#pragma once + +namespace Slang +{ + +class DiagnosticSink; +class Session; + +class ExtensionUsageTracker; + +struct IRFunc; +struct IRModule; + +void legalizeEntryPointForGLSL( + Session* session, + IRModule* module, + IRFunc* func, + DiagnosticSink* sink, + ExtensionUsageTracker* extensionUsageTracker); + +} diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h new file mode 100644 index 000000000..abd677979 --- /dev/null +++ b/source/slang/slang-ir-inst-defs.h @@ -0,0 +1,480 @@ +// slang-ir-inst-defs.h + +#ifndef INST +#error Must #define `INST` before including `ir-inst-defs.h` +#endif + +#ifndef INST_RANGE +#define INST_RANGE(BASE, FIRST, LAST) /* empty */ +#endif + +#ifndef PSEUDO_INST +#define PSEUDO_INST(ID) /* empty */ +#endif + +#define PARENT kIROpFlag_Parent +#define USE_OTHER kIROpFlag_UseOther + +INST(Nop, nop, 0, 0) + +/* Types */ + + /* Basic Types */ + + #define DEFINE_BASE_TYPE_INST(NAME) INST(NAME ## Type, NAME, 0, 0) + FOREACH_BASE_TYPE(DEFINE_BASE_TYPE_INST) + #undef DEFINE_BASE_TYPE_INST + INST(AfterBaseType, afterBaseType, 0, 0) + + INST_RANGE(BasicType, VoidType, AfterBaseType) + + INST(StringType, String, 0, 0) + + /* ArrayTypeBase */ + INST(ArrayType, Array, 2, 0) + INST(UnsizedArrayType, UnsizedArray, 1, 0) + INST_RANGE(ArrayTypeBase, ArrayType, UnsizedArrayType) + + INST(FuncType, Func, 0, 0) + INST(BasicBlockType, BasicBlock, 0, 0) + + INST(VectorType, Vec, 2, 0) + INST(MatrixType, Mat, 3, 0) + + INST(TaggedUnionType, TaggedUnion, 0, 0) + + // A `BindExistentials` represents + // taking type `B` and binding each of its existential type + // parameters, recursively, with the specified arguments, + // where each `Ti, wi` pair represents the concrete type + // and witness table to plug in for parameter `i`. + // + INST(BindExistentialsType, BindExistentials, 1, 0) + + /* Rate */ + INST(ConstExprRate, ConstExpr, 0, 0) + INST(GroupSharedRate, GroupShared, 0, 0) + INST_RANGE(Rate, ConstExprRate, GroupSharedRate) + + INST(RateQualifiedType, RateQualified, 2, 0) + + // Kinds represent the "types of types." + // They should not really be nested under `IRType` + // in the overall hierarchy, but we can fix that later. + // + /* Kind */ + INST(TypeKind, Type, 0, 0) + INST(RateKind, Rate, 0, 0) + INST(GenericKind, Generic, 0, 0) + INST_RANGE(Kind, TypeKind, GenericKind) + + /* PtrTypeBase */ + INST(PtrType, Ptr, 1, 0) + INST(RefType, Ref, 1, 0) + + // An `ExistentialBox` represents a logical pointer to a value of type `T`. + // On targets that support pointers this might lower to a pointer, but on + // current targets it will lower to zero bytes, with a value of type `T` + // being stored "out of line" somewhere. + // + INST(ExistentialBoxType, ExistentialBox, 1, 0) + + /* OutTypeBase */ + INST(OutType, Out, 1, 0) + INST(InOutType, InOut, 1, 0) + INST_RANGE(OutTypeBase, OutType, InOutType) + INST_RANGE(PtrTypeBase, PtrType, InOutType) + + /* SamplerStateTypeBase */ + INST(SamplerStateType, SamplerState, 0, 0) + INST(SamplerComparisonStateType, SamplerComparisonState, 0, 0) + INST_RANGE(SamplerStateTypeBase, SamplerStateType, SamplerComparisonStateType) + + // TODO: Why do we have all this hierarchy here, when everything + // that actually matters is currently nested under `TextureTypeBase`? + /* ResourceTypeBase */ + /* ResourceType */ + /* TextureTypeBase */ + // NOTE! TextureFlavor::Flavor is stored in 'other' bits for these types. + /* TextureType */ + INST(TextureType, TextureType, 0, USE_OTHER) + /* TextureSamplerType */ + INST(TextureSamplerType, TextureSamplerType, 0, USE_OTHER) + /* GLSLImageType */ + INST(GLSLImageType, GLSLImageType, 0, USE_OTHER) + INST_RANGE(TextureTypeBase, TextureType, GLSLImageType) + INST_RANGE(ResourceType, TextureType, GLSLImageType) + INST_RANGE(ResourceTypeBase, TextureType, GLSLImageType) + + + /* UntypedBufferResourceType */ + /* ByteAddressBufferTypeBase */ + INST(HLSLByteAddressBufferType, ByteAddressBuffer, 0, 0) + INST(HLSLRWByteAddressBufferType, RWByteAddressBuffer, 0, 0) + INST(HLSLRasterizerOrderedByteAddressBufferType, RasterizerOrderedByteAddressBuffer, 0, 0) + INST_RANGE(ByteAddressBufferTypeBase, HLSLByteAddressBufferType, HLSLRasterizerOrderedByteAddressBufferType) + INST(RaytracingAccelerationStructureType, RaytracingAccelerationStructure, 0, 0) + INST_RANGE(UntypedBufferResourceType, HLSLByteAddressBufferType, RaytracingAccelerationStructureType) + + /* HLSLPatchType */ + INST(HLSLInputPatchType, InputPatch, 2, 0) + INST(HLSLOutputPatchType, OutputPatch, 2, 0) + INST_RANGE(HLSLPatchType, HLSLInputPatchType, HLSLOutputPatchType) + + INST(GLSLInputAttachmentType, GLSLInputAttachment, 0, 0) + + /* BuiltinGenericType */ + /* HLSLStreamOutputType */ + INST(HLSLPointStreamType, PointStream, 1, 0) + INST(HLSLLineStreamType, LineStream, 1, 0) + INST(HLSLTriangleStreamType, TriangleStream, 1, 0) + INST_RANGE(HLSLStreamOutputType, HLSLPointStreamType, HLSLTriangleStreamType) + + /* HLSLStructuredBufferTypeBase */ + INST(HLSLStructuredBufferType, StructuredBuffer, 0, 0) + INST(HLSLRWStructuredBufferType, RWStructuredBuffer, 0, 0) + INST(HLSLRasterizerOrderedStructuredBufferType, RasterizerOrderedStructuredBuffer, 0, 0) + INST(HLSLAppendStructuredBufferType, AppendStructuredBuffer, 0, 0) + INST(HLSLConsumeStructuredBufferType, ConsumeStructuredBuffer, 0, 0) + INST_RANGE(HLSLStructuredBufferTypeBase, HLSLStructuredBufferType, HLSLConsumeStructuredBufferType) + + /* PointerLikeType */ + /* ParameterGroupType */ + /* UniformParameterGroupType */ + INST(ConstantBufferType, ConstantBuffer, 1, 0) + INST(TextureBufferType, TextureBuffer, 1, 0) + INST(ParameterBlockType, ParameterBlock, 1, 0) + INST(GLSLShaderStorageBufferType, GLSLShaderStorageBuffer, 0, 0) + INST_RANGE(UniformParameterGroupType, ConstantBufferType, GLSLShaderStorageBufferType) + + /* VaryingParameterGroupType */ + INST(GLSLInputParameterGroupType, GLSLInputParameterGroup, 0, 0) + INST(GLSLOutputParameterGroupType, GLSLOutputParameterGroup, 0, 0) + INST_RANGE(VaryingParameterGroupType, GLSLInputParameterGroupType, GLSLOutputParameterGroupType) + INST_RANGE(ParameterGroupType, ConstantBufferType, GLSLOutputParameterGroupType) + INST_RANGE(PointerLikeType, ConstantBufferType, GLSLOutputParameterGroupType) + INST_RANGE(BuiltinGenericType, HLSLPointStreamType, GLSLOutputParameterGroupType) + + + + +// A user-defined structure declaration at the IR level. +// Unlike in the AST where there is a distinction between +// a `StructDecl` and a `DeclRefType` that refers to it, +// at the IR level the struct declaration and the type +// are the same IR instruction. +// +// This is a parent instruction that holds zero or more +// `field` instructions. +// +INST(StructType, struct, 0, PARENT) +INST(InterfaceType, interface, 0, PARENT) + +INST_RANGE(Type, VoidType, InterfaceType) + +/*IRGlobalValueWithCode*/ + /* IRGlobalValueWIthParams*/ + INST(Func, func, 0, PARENT) + INST(Generic, generic, 0, PARENT) + INST_RANGE(GlobalValueWithParams, Func, Generic) + + INST(GlobalVar, global_var, 0, 0) + INST(GlobalConstant, global_constant, 0, 0) +INST_RANGE(GlobalValueWithCode, Func, GlobalConstant) + +INST(GlobalParam, global_param, 0, 0) + +INST(StructKey, key, 0, 0) +INST(GlobalGenericParam, global_generic_param, 0, 0) +INST(WitnessTable, witness_table, 0, 0) + +INST(Module, module, 0, PARENT) + +INST(Block, block, 0, PARENT) + +/* IRConstant */ + INST(BoolLit, boolConst, 0, 0) + INST(IntLit, integer_constant, 0, 0) + INST(FloatLit, float_constant, 0, 0) + INST(PtrLit, ptr_constant, 0, 0) + INST(StringLit, string_constant, 0, 0) +INST_RANGE(Constant, BoolLit, StringLit) + +INST(undefined, undefined, 0, 0) + +INST(Specialize, specialize, 2, 0) +INST(lookup_interface_method, lookup_interface_method, 2, 0) +INST(lookup_witness_table, lookup_witness_table, 2, 0) +INST(BindGlobalGenericParam, bind_global_generic_param, 2, 0) +INST(BindGlobalExistentialSlots, bindGlobalExistentialSlots, 0, 0) + +INST(Construct, construct, 0, 0) + +INST(makeVector, makeVector, 0, 0) +INST(MakeMatrix, makeMatrix, 0, 0) +INST(makeArray, makeArray, 0, 0) +INST(makeStruct, makeStruct, 0, 0) + +INST(Call, call, 1, 0) + + +INST(WitnessTableEntry, witness_table_entry, 2, 0) + +INST(Param, param, 0, 0) +INST(StructField, field, 2, 0) +INST(Var, var, 0, 0) + +INST(Load, load, 1, 0) +INST(Store, store, 2, 0) + +INST(FieldExtract, get_field, 2, 0) +INST(FieldAddress, get_field_addr, 2, 0) + +INST(getElement, getElement, 2, 0) +INST(getElementPtr, getElementPtr, 2, 0) + +// "Subscript" an image at a pixel coordinate to get pointer +INST(ImageSubscript, imageSubscript, 2, 0) + +// Construct a vector from a scalar +// +// %dst = constructVectorFromScalar %T %N %val +// +// where +// - `T` is a `Type` +// - `N` is a (compile-time) `Int` +// - `val` is a `T` +// - dst is a `Vec` +// +INST(constructVectorFromScalar, constructVectorFromScalar, 3, 0) + +// A swizzle of a vector: +// +// %dst = swizzle %src %idx0 %idx1 ... +// +// where: +// - `src` is a vector +// - `dst` is a vector +// - `idx0` through `idx[M-1]` are literal integers +// +INST(swizzle, swizzle, 1, 0) + +// Setting a vector via swizzle +// +// %dst = swizzle %base %src %idx0 %idx1 ... +// +// where: +// - `base` is a vector +// - `dst` is a vector +// - `src` is a vector +// - `idx0` through `idx[M-1]` are literal integers +// +// The semantics of the op is: +// +// dst = base; +// for(ii : 0 ... M-1 ) +// dst[ii] = src[idx[ii]]; +// +INST(swizzleSet, swizzleSet, 2, 0) + +// Store to memory with a swizzle +// +// TODO: eventually this should be reduced to just +// a write mask by moving the actual swizzle to the RHS. +// +// swizzleStore %dst %src %idx0 %idx1 ... +// +// where: +// - `dst` is a vector +// - `src` is a vector +// - `idx0` through `idx[M-1]` are literal integers +// +// The semantics of the op is: +// +// for(ii : 0 ... M-1 ) +// dst[ii] = src[idx[ii]]; +// +INST(SwizzledStore, swizzledStore, 2, 0) + + +/* IRTerminatorInst */ + + INST(ReturnVal, return_val, 1, 0) + INST(ReturnVoid, return_void, 1, 0) + + /* IRUnconditionalBranch */ + // unconditionalBranch + INST(unconditionalBranch, unconditionalBranch, 1, 0) + + // loop + INST(loop, loop, 3, 0) + INST_RANGE(UnconditionalBranch, unconditionalBranch, loop) + + /* IRConditionalbranch */ + + // conditionalBranch + INST(conditionalBranch, conditionalBranch, 3, 0) + + // ifElse + INST(ifElse, ifElse, 4, 0) + INST_RANGE(ConditionalBranch, conditionalBranch, ifElse) + + // switch ... + INST(Switch, switch, 3, 0) + + INST(discard, discard, 0, 0) + + /* IRUnreachable */ + INST(MissingReturn, missingReturn, 0, 0) + INST(Unreachable, unreachable, 0, 0) + INST_RANGE(Unreachable, MissingReturn, Unreachable) + +INST_RANGE(TerminatorInst, ReturnVal, Unreachable) + +INST(Add, add, 2, 0) +INST(Sub, sub, 2, 0) +INST(Mul, mul, 2, 0) +INST(Div, div, 2, 0) +INST(Mod, mod, 2, 0) + +INST(Lsh, shl, 2, 0) +INST(Rsh, shr, 2, 0) + +INST(Eql, cmpEQ, 2, 0) +INST(Neq, cmpNE, 2, 0) +INST(Greater, cmpGT, 2, 0) +INST(Less, cmpLT, 2, 0) +INST(Geq, cmpGE, 2, 0) +INST(Leq, cmpLE, 2, 0) + +INST(BitAnd, and, 2, 0) +INST(BitXor, xor, 2, 0) +INST(BitOr, or , 2, 0) + +INST(And, logicalAnd, 2, 0) +INST(Or, logicalOr, 2, 0) + +INST(Neg, neg, 1, 0) +INST(Not, not, 1, 0) +INST(BitNot, bitnot, 1, 0) + +INST(Select, select, 3, 0) + +INST(Dot, dot, 2, 0) + +INST(Mul_Vector_Matrix, mulVectorMatrix, 2, 0) +INST(Mul_Matrix_Vector, mulMatrixVector, 2, 0) +INST(Mul_Matrix_Matrix, mulMatrixMatrix, 2, 0) + +// Texture sampling operation of the form `t.Sample(s,u)` +INST(Sample, sample, 3, 0) + +INST(SampleGrad, sampleGrad, 4, 0) + +INST(GroupMemoryBarrierWithGroupSync, GroupMemoryBarrierWithGroupSync, 0, 0) + +/* Decoration */ + +INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) + INST(LayoutDecoration, layout, 1, 0) + INST(LoopControlDecoration, loopControl, 1, 0) + /* TargetSpecificDecoration */ + INST(TargetDecoration, target, 1, 0) + INST(TargetIntrinsicDecoration, targetIntrinsic, 2, 0) + INST_RANGE(TargetSpecificDecoration, TargetDecoration, TargetIntrinsicDecoration) + INST(GLSLOuterArrayDecoration, glslOuterArray, 1, 0) + INST(SemanticDecoration, semantic, 1, 0) + INST(InterpolationModeDecoration, interpolationMode, 1, 0) + INST(NameHintDecoration, nameHint, 1, 0) + + /** The decorated _instruction_ is transitory. Such a decoration should NEVER be found on an output instruction a module. + Typically used mark an instruction so can be specially handled - say when creating a IRConstant literal, and the payload of + needs to be special cased for lookup. */ + INST(TransitoryDecoration, transitory, 0, 0) + + INST(VulkanRayPayloadDecoration, vulkanRayPayload, 0, 0) + INST(VulkanHitAttributesDecoration, vulkanHitAttributes, 0, 0) + INST(RequireGLSLVersionDecoration, requireGLSLVersion, 1, 0) + INST(RequireGLSLExtensionDecoration, requireGLSLExtension, 1, 0) + INST(ReadNoneDecoration, readNone, 0, 0) + INST(VulkanCallablePayloadDecoration, vulkanCallablePayload, 0, 0) + INST(EarlyDepthStencilDecoration, earlyDepthStencil, 0, 0) + INST(GloballyCoherentDecoration, globallyCoherent, 0, 0) + INST(PreciseDecoration, precise, 0, 0) + INST(PatchConstantFuncDecoration, patchConstantFunc, 1, 0) + + /// An `[entryPoint]` decoration marks a function that represents a shader entry point. + INST(EntryPointDecoration, entryPoint, 0, 0) + + /// A `[dependsOn(x)]` decoration indicates that the parent instruction depends on `x` + /// even if it does not otherwise reference it. + INST(DependsOnDecoration, dependsOn, 1, 0) + + /// A `[keepAlive]` decoration marks an instruction that should not be eliminated. + INST(KeepAliveDecoration, keepAlive, 0, 0) + + INST(BindExistentialSlotsDecoration, bindExistentialSlots, 0, 0) + + /// A `[format(f)]` decoration specifies that the format of an image should be `f` + INST(FormatDecoration, format, 1, 0) + + /* LinkageDecoration */ + INST(ImportDecoration, import, 1, 0) + INST(ExportDecoration, export, 1, 0) + INST_RANGE(LinkageDecoration, ImportDecoration, ExportDecoration) + +INST_RANGE(Decoration, HighLevelDeclDecoration, ExportDecoration) + + +// + +// A `makeExistential(v : C, w) : I` instruction takes a value `v` of type `C` +// and produces a value of interface type `I` by using the witness `w` which +// shows that `C` conforms to `I`. +// +INST(MakeExistential, makeExistential, 2, 0) + +// A `wrapExistential(v, T0,w0, T1,w0) : T` instruction is similar to `makeExistential`. +// but applies to a value `v` that is of type `BindExistentials(T, T0,w0, ...)`. The +// result of the `wrapExistentials` operation is a value of type `T`, allowing us to +// "smuggle" a value of specialized type into computations that expect an unspecialized type. +// +INST(WrapExistential, wrapExistential, 2, 0) + +INST(ExtractExistentialValue, extractExistentialValue, 1, 0) +INST(ExtractExistentialType, extractExistentialType, 1, 0) +INST(ExtractExistentialWitnessTable, extractExistentialWitnessTable, 1, 0) + +INST(ExtractTaggedUnionTag, extractTaggedUnionTag, 1, 0) +INST(ExtractTaggedUnionPayload, extractTaggedUnionPayload, 1, 0) + +INST(BitCast, bitCast, 1, 0) + +PSEUDO_INST(Pos) +PSEUDO_INST(PreInc) + +PSEUDO_INST(PreDec) +PSEUDO_INST(PostInc) +PSEUDO_INST(PostDec) +PSEUDO_INST(Sequence) +PSEUDO_INST(AddAssign) +PSEUDO_INST(SubAssign) +PSEUDO_INST(MulAssign) +PSEUDO_INST(DivAssign) +PSEUDO_INST(ModAssign) +PSEUDO_INST(AndAssign) +PSEUDO_INST(OrAssign) +PSEUDO_INST(XorAssign ) +PSEUDO_INST(LshAssign) +PSEUDO_INST(RshAssign) +PSEUDO_INST(Assign) +PSEUDO_INST(And) +PSEUDO_INST(Or) + + +#undef PSEUDO_INST +#undef PARENT +#undef USE_OTHER +#undef INST_RANGE +#undef INST + diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h new file mode 100644 index 000000000..7229f04d1 --- /dev/null +++ b/source/slang/slang-ir-insts.h @@ -0,0 +1,1343 @@ +// slang-ir-insts.h +#ifndef SLANG_IR_INSTS_H_INCLUDED +#define SLANG_IR_INSTS_H_INCLUDED + +// This file extends the core definitions in `ir.h` +// with a wider variety of concrete instructions, +// and a "builder" abstraction. +// +// TODO: the builder probably needs its own file. + +#include "slang-compiler.h" +#include "slang-ir.h" +#include "slang-syntax.h" +#include "slang-type-layout.h" + +namespace Slang { + +class Decl; + +struct IRDecoration : IRInst +{ + IR_PARENT_ISA(Decoration) + + IRDecoration* getNextDecoration() + { + return as(getNextInst()); + } +}; + +// Associates an IR-level decoration with a source declaration +// in the high-level AST, that can be used to extract +// additional information that informs code emission. +struct IRHighLevelDeclDecoration : IRDecoration +{ + enum { kOp = kIROp_HighLevelDeclDecoration }; + IR_LEAF_ISA(HighLevelDeclDecoration) + + IRPtrLit* getDeclOperand() { return cast(getOperand(0)); } + Decl* getDecl() { return (Decl*) getDeclOperand()->getValue(); } +}; + +// Associates an IR-level decoration with a source layout +struct IRLayoutDecoration : IRDecoration +{ + enum { kOp = kIROp_LayoutDecoration }; + IR_LEAF_ISA(LayoutDecoration) + + IRPtrLit* getLayoutOperand() { return cast(getOperand(0)); } + Layout* getLayout() { return (Layout*) getLayoutOperand()->getValue(); } +}; + +enum IRLoopControl +{ + kIRLoopControl_Unroll, +}; + +struct IRLoopControlDecoration : IRDecoration +{ + enum { kOp = kIROp_LoopControlDecoration }; + IR_LEAF_ISA(LoopControlDecoration) + + IRConstant* getModeOperand() { return cast(getOperand(0)); } + + IRLoopControl getMode() + { + return IRLoopControl(getModeOperand()->value.intVal); + } +}; + + +struct IRTargetSpecificDecoration : IRDecoration +{ + IR_PARENT_ISA(TargetSpecificDecoration) + + IRStringLit* getTargetNameOperand() { return cast(getOperand(0)); } + + UnownedStringSlice getTargetName() + { + return getTargetNameOperand()->getStringSlice(); + } +}; + +struct IRTargetDecoration : IRTargetSpecificDecoration +{ + enum { kOp = kIROp_TargetDecoration }; + IR_LEAF_ISA(TargetDecoration) +}; + +struct IRTargetIntrinsicDecoration : IRTargetSpecificDecoration +{ + enum { kOp = kIROp_TargetIntrinsicDecoration }; + IR_LEAF_ISA(TargetIntrinsicDecoration) + + IRStringLit* getDefinitionOperand() { return cast(getOperand(1)); } + + UnownedStringSlice getDefinition() + { + return getDefinitionOperand()->getStringSlice(); + } +}; + +struct IRGLSLOuterArrayDecoration : IRDecoration +{ + enum { kOp = kIROp_GLSLOuterArrayDecoration }; + IR_LEAF_ISA(GLSLOuterArrayDecoration) + + IRStringLit* getOuterArraynameOperand() { return cast(getOperand(0)); } + + UnownedStringSlice getOuterArrayName() + { + return getOuterArraynameOperand()->getStringSlice(); + } +}; + +// A decoration that marks a field key as having been associated +// with a particular simple semantic (e.g., `COLOR` or `SV_Position`, +// but not a `register` semantic). +// +// This is currently needed so that we can round-trip HLSL `struct` +// types that get used for varying input/output. This is an unfortunate +// case where some amount of "layout" information can't just come +// in via the `TypeLayout` part of things. +// +struct IRSemanticDecoration : IRDecoration +{ + enum { kOp = kIROp_SemanticDecoration }; + IR_LEAF_ISA(SemanticDecoration) + + IRStringLit* getSemanticNameOperand() { return cast(getOperand(0)); } + + UnownedStringSlice getSemanticName() + { + return getSemanticNameOperand()->getStringSlice(); + } +}; + +enum class IRInterpolationMode +{ + Linear, + NoPerspective, + NoInterpolation, + + Centroid, + Sample, +}; + +struct IRInterpolationModeDecoration : IRDecoration +{ + enum { kOp = kIROp_InterpolationModeDecoration }; + IR_LEAF_ISA(InterpolationModeDecoration) + + IRConstant* getModeOperand() { return cast(getOperand(0)); } + + IRInterpolationMode getMode() + { + return IRInterpolationMode(getModeOperand()->value.intVal); + } +}; + +/// A decoration that provides a desired name to be used +/// in conjunction with the given instruction. Back-end +/// code generation may use this to help derive symbol +/// names, emit debug information, etc. +struct IRNameHintDecoration : IRDecoration +{ + enum { kOp = kIROp_NameHintDecoration }; + IR_LEAF_ISA(NameHintDecoration) + + IRStringLit* getNameOperand() { return cast(getOperand(0)); } + + UnownedStringSlice getName() + { + return getNameOperand()->getStringSlice(); + } +}; + +#define IR_SIMPLE_DECORATION(NAME) \ + struct IR##NAME : IRDecoration \ + { \ + enum { kOp = kIROp_##NAME }; \ + IR_LEAF_ISA(NAME) \ + }; \ + /**/ + +/// A decoration that indicates that a variable represents +/// a vulkan ray payload, and should have a location assigned +/// to it. +IR_SIMPLE_DECORATION(VulkanRayPayloadDecoration) + +/// A decoration that indicates that a variable represents +/// a vulkan callable shader payload, and should have a location assigned +/// to it. +IR_SIMPLE_DECORATION(VulkanCallablePayloadDecoration) + +/// A decoration that indicates that a variable represents +/// vulkan hit attributes, and should have a location assigned +/// to it. +IR_SIMPLE_DECORATION(VulkanHitAttributesDecoration) + +struct IRRequireGLSLVersionDecoration : IRDecoration +{ + enum { kOp = kIROp_RequireGLSLVersionDecoration }; + IR_LEAF_ISA(RequireGLSLVersionDecoration) + + IRConstant* getLanguageVersionOperand() { return cast(getOperand(0)); } + + Int getLanguageVersion() + { + return Int(getLanguageVersionOperand()->value.intVal); + } +}; + +struct IRRequireGLSLExtensionDecoration : IRDecoration +{ + enum { kOp = kIROp_RequireGLSLExtensionDecoration }; + IR_LEAF_ISA(RequireGLSLExtensionDecoration) + + IRStringLit* getExtensionNameOperand() { return cast(getOperand(0)); } + + UnownedStringSlice getExtensionName() + { + return getExtensionNameOperand()->getStringSlice(); + } +}; + +IR_SIMPLE_DECORATION(ReadNoneDecoration) +IR_SIMPLE_DECORATION(EarlyDepthStencilDecoration) +IR_SIMPLE_DECORATION(GloballyCoherentDecoration) +IR_SIMPLE_DECORATION(PreciseDecoration) + + /// A decoration that marks a value as having linkage. + /// + /// A value with linkage is either exported from its module, + /// or will have a definition imported from another module. + /// In either case, it requires a mangled name to use when + /// matching imports and exports. + /// +struct IRLinkageDecoration : IRDecoration +{ + IR_PARENT_ISA(LinkageDecoration) + + IRStringLit* getMangledNameOperand() { return cast(getOperand(0)); } + + UnownedStringSlice getMangledName() + { + return getMangledNameOperand()->getStringSlice(); + } +}; + +struct IRImportDecoration : IRLinkageDecoration +{ + enum { kOp = kIROp_ImportDecoration }; + IR_LEAF_ISA(ImportDecoration) +}; + +struct IRExportDecoration : IRLinkageDecoration +{ + enum { kOp = kIROp_ExportDecoration }; + IR_LEAF_ISA(ExportDecoration) +}; + +struct IRFormatDecoration : IRDecoration +{ + enum { kOp = kIROp_FormatDecoration }; + IR_LEAF_ISA(FormatDecoration) + + IRConstant* getFormatOperand() { return cast(getOperand(0)); } + + ImageFormat getFormat() + { + return ImageFormat(getFormatOperand()->value.intVal); + } +}; + +// An instruction that specializes another IR value +// (representing a generic) to a particular set of generic arguments +// (instructions representing types, witness tables, etc.) +struct IRSpecialize : IRInst +{ + // The "base" for the call is the generic to be specialized + IRUse base; + IRInst* getBase() { return getOperand(0); } + + // after the generic value come the arguments + UInt getArgCount() { return getOperandCount() - 1; } + IRInst* getArg(UInt index) { return getOperand(index + 1); } + + IR_LEAF_ISA(Specialize) +}; + +// An instruction that looks up the implementation +// of an interface operation identified by `requirementDeclRef` +// in the witness table `witnessTable` which should +// hold the conformance information for a specific type. +struct IRLookupWitnessMethod : IRInst +{ + IRUse witnessTable; + IRUse requirementKey; + + IRInst* getWitnessTable() { return witnessTable.get(); } + IRInst* getRequirementKey() { return requirementKey.get(); } +}; + +struct IRLookupWitnessTable : IRInst +{ + IRUse sourceType; + IRUse interfaceType; +}; + +// + +struct IRCall : IRInst +{ + IR_LEAF_ISA(Call) + + IRInst* getCallee() { return getOperand(0); } + + UInt getArgCount() { return getOperandCount() - 1; } + IRInst* getArg(UInt index) { return getOperand(index + 1); } +}; + +struct IRLoad : IRInst +{ + IRUse ptr; +}; + +struct IRStore : IRInst +{ + IRUse ptr; + IRUse val; +}; + +struct IRFieldExtract : IRInst +{ + IRUse base; + IRUse field; + + IRInst* getBase() { return base.get(); } + IRInst* getField() { return field.get(); } +}; + +struct IRFieldAddress : IRInst +{ + IRUse base; + IRUse field; + + IRInst* getBase() { return base.get(); } + IRInst* getField() { return field.get(); } +}; + +// Terminators + +struct IRReturn : IRTerminatorInst +{}; + +struct IRReturnVal : IRReturn +{ + IRUse val; + + IRInst* getVal() { return val.get(); } +}; + +struct IRReturnVoid : IRReturn +{}; + +struct IRDiscard : IRTerminatorInst +{}; + +// Signals that this point in the code should be unreachable. +// We can/should emit a dataflow error if we can ever determine +// that a block ending in one of these can actually be +// executed. +struct IRUnreachable : IRTerminatorInst +{ + IR_PARENT_ISA(Unreachable); +}; + +struct IRMissingReturn : IRUnreachable +{ + IR_LEAF_ISA(MissingReturn); +}; + +struct IRBlock; + +struct IRUnconditionalBranch : IRTerminatorInst +{ + IRUse block; + + IRBlock* getTargetBlock() { return (IRBlock*)block.get(); } + + UInt getArgCount(); + IRUse* getArgs(); + IRInst* getArg(UInt index); + + IR_PARENT_ISA(UnconditionalBranch); +}; + +// Special cases of unconditional branch, to handle +// structured control flow: +struct IRBreak : IRUnconditionalBranch {}; +struct IRContinue : IRUnconditionalBranch {}; + +// The start of a loop is a special control-flow +// instruction, that records relevant information +// about the loop structure: +struct IRLoop : IRUnconditionalBranch +{ + // The next block after the loop, which + // is where we expect control flow to + // re-converge, and also where a + // `break` will target. + IRUse breakBlock; + + // The block where control flow will go + // on a `continue`. + IRUse continueBlock; + + IRBlock* getBreakBlock() { return (IRBlock*)breakBlock.get(); } + IRBlock* getContinueBlock() { return (IRBlock*)continueBlock.get(); } +}; + +struct IRConditionalBranch : IRTerminatorInst +{ + IR_PARENT_ISA(ConditionalBranch) + + IRUse condition; + IRUse trueBlock; + IRUse falseBlock; + + IRInst* getCondition() { return condition.get(); } + IRBlock* getTrueBlock() { return (IRBlock*)trueBlock.get(); } + IRBlock* getFalseBlock() { return (IRBlock*)falseBlock.get(); } +}; + +// A conditional branch that represent the test inside a loop +struct IRLoopTest : IRConditionalBranch +{ +}; + +// A conditional branch that represents a one-sided `if`: +// +// if( ) { } +// +struct IRIf : IRConditionalBranch +{ + IRBlock* getAfterBlock() { return getFalseBlock(); } +}; + +// A conditional branch that represents a two-sided `if`: +// +// if( ) { } +// else { } +// +// +struct IRIfElse : IRConditionalBranch +{ + IRUse afterBlock; + + IRBlock* getAfterBlock() { return (IRBlock*)afterBlock.get(); } +}; + +// A multi-way branch that represents a source-level `switch` +struct IRSwitch : IRTerminatorInst +{ + IR_LEAF_ISA(Switch); + + IRUse condition; + IRUse breakLabel; + IRUse defaultLabel; + + IRInst* getCondition() { return condition.get(); } + IRBlock* getBreakLabel() { return (IRBlock*) breakLabel.get(); } + IRBlock* getDefaultLabel() { return (IRBlock*) defaultLabel.get(); } + + // remaining args are: caseVal, caseLabel, ... + + UInt getCaseCount() { return (getOperandCount() - 3) / 2; } + IRInst* getCaseValue(UInt index) { return getOperand(3 + index*2 + 0); } + IRBlock* getCaseLabel(UInt index) { return (IRBlock*) getOperand(3 + index*2 + 1); } +}; + +struct IRSwizzle : IRInst +{ + IRUse base; + + IRInst* getBase() { return base.get(); } + UInt getElementCount() + { + return getOperandCount() - 1; + } + IRInst* getElementIndex(UInt index) + { + return getOperand(index + 1); + } +}; + +struct IRSwizzleSet : IRInst +{ + IRUse base; + IRUse source; + + IRInst* getBase() { return base.get(); } + IRInst* getSource() { return source.get(); } + UInt getElementCount() + { + return getOperandCount() - 2; + } + IRInst* getElementIndex(UInt index) + { + return getOperand(index + 2); + } +}; + +struct IRSwizzledStore : IRInst +{ + IRInst* getDest() { return getOperand(0); } + IRInst* getSource() { return getOperand(1); } + UInt getElementCount() + { + return getOperandCount() - 2; + } + IRInst* getElementIndex(UInt index) + { + return getOperand(index + 2); + } + + IR_LEAF_ISA(SwizzledStore) +}; + + +struct IRPatchConstantFuncDecoration : IRDecoration +{ + enum { kOp = kIROp_PatchConstantFuncDecoration }; + IR_LEAF_ISA(PatchConstantFuncDecoration) + + IRInst* getFunc() { return getOperand(0); } +}; + +// An IR `var` instruction conceptually represents +// a stack allocation of some memory. +struct IRVar : IRInst +{ + IRPtrType* getDataType() + { + return cast(IRInst::getDataType()); + } + + static bool isaImpl(IROp op) { return op == kIROp_Var; } +}; + +/// @brief A global variable. +/// +/// Represents a global variable in the IR. +/// If the variable has an initializer, then +/// it is represented by the code in the basic +/// blocks nested inside this value. +struct IRGlobalVar : IRGlobalValueWithCode +{ + IRPtrType* getDataType() + { + return cast(IRInst::getDataType()); + } +}; + +/// @brief A global constant. +/// +/// Represents a global-scope constant value in the IR. +/// The initializer for the constant is represented by +/// the code in the basic block(s) nested in this value. +struct IRGlobalConstant : IRGlobalValueWithCode +{ + IR_LEAF_ISA(GlobalConstant) +}; + +struct IRGlobalParam : IRInst +{ + IR_LEAF_ISA(GlobalParam) +}; + + +// An entry in a witness table (see below) +struct IRWitnessTableEntry : IRInst +{ + // The AST-level requirement + IRUse requirementKey; + + // The IR-level value that satisfies the requirement + IRUse satisfyingVal; + + IRInst* getRequirementKey() { return getOperand(0); } + IRInst* getSatisfyingVal() { return getOperand(1); } + + IR_LEAF_ISA(WitnessTableEntry) +}; + +// A witness table is a global value that stores +// information about how a type conforms to some +// interface. It basically takes the form of a +// map from the required members of the interface +// to the IR values that satisfy those requirements. +struct IRWitnessTable : IRInst +{ + IRInstList getEntries() + { + return IRInstList(getChildren()); + } + + IR_LEAF_ISA(WitnessTable) +}; + +// An instruction that yields an undefined value. +// +// Note that we make this an instruction rather than a value, +// so that we will be able to identify a variable that is +// used when undefined. +struct IRUndefined : IRInst +{ +}; + +// A global-scope generic parameter (a type parameter, a +// constraint parameter, etc.) +struct IRGlobalGenericParam : IRInst +{ + IR_LEAF_ISA(GlobalGenericParam) +}; + +// An instruction that binds a global generic parameter +// to a particular value. +struct IRBindGlobalGenericParam : IRInst +{ + IRGlobalGenericParam* getParam() { return cast(getOperand(0)); } + IRInst* getVal() { return getOperand(1); } + + IR_LEAF_ISA(BindGlobalGenericParam) +}; + + + /// An instruction that packs a concrete value into an existential-type "box" +struct IRMakeExistential : IRInst +{ + IRInst* getWrappedValue() { return getOperand(0); } + IRInst* getWitnessTable() { return getOperand(1); } + + IR_LEAF_ISA(MakeExistential) +}; + + /// Generalizes `IRMakeExistential` by allowing a type with existential sub-fields to be boxed +struct IRWrapExistential : IRInst +{ + IRInst* getWrappedValue() { return getOperand(0); } + + UInt getSlotOperandCount() { return getOperandCount() - 1; } + IRInst* getSlotOperand(UInt index) { return getOperand(index + 1); } + IRUse* getSlotOperands() { return getOperands() + 1; } + + IR_LEAF_ISA(WrapExistential) +}; + + +// Description of an instruction to be used for global value numbering +struct IRInstKey +{ + IRInst* inst; + + int GetHashCode(); +}; + +bool operator==(IRInstKey const& left, IRInstKey const& right); + +struct IRConstantKey +{ + IRConstant* inst; + + bool operator==(const IRConstantKey& rhs) const { return inst->equal(rhs.inst); } + int GetHashCode() const { return inst->getHashCode(); } +}; + +struct SharedIRBuilder +{ + // The parent compilation session + Session* session; + Session* getSession() + { + return session; + } + + // The module that will own all of the IR + IRModule* module; + + Dictionary globalValueNumberingMap; + Dictionary constantMap; +}; + +struct IRBuilderSourceLocRAII; + +struct IRBuilder +{ + // Shared state for all IR builders working on the same module + SharedIRBuilder* sharedBuilder; + + Session* getSession() + { + return sharedBuilder->getSession(); + } + + IRModule* getModule() { return sharedBuilder->module; } + + // The current parent being inserted into (this might + // be the global scope, a function, a block inside + // a function, etc.) + IRInst* insertIntoParent = nullptr; + // + // An instruction in the current parent that we should insert before + IRInst* insertBeforeInst = nullptr; + + // Get the current basic block we are inserting into (if any) + IRBlock* getBlock(); + + // Get the current function (or other value with code) + // that we are inserting into (if any). + IRGlobalValueWithCode* getFunc(); + + void setInsertInto(IRInst* insertInto); + void setInsertBefore(IRInst* insertBefore); + + IRBuilderSourceLocRAII* sourceLocInfo = nullptr; + + void addInst(IRInst* inst); + + IRInst* getBoolValue(bool value); + IRInst* getIntValue(IRType* type, IRIntegerValue value); + IRInst* getFloatValue(IRType* type, IRFloatingPointValue value); + IRStringLit* getStringValue(const UnownedStringSlice& slice); + IRPtrLit* getPtrValue(void* value); + + IRBasicType* getBasicType(BaseType baseType); + IRBasicType* getVoidType(); + IRBasicType* getBoolType(); + IRBasicType* getIntType(); + IRStringType* getStringType(); + + IRBasicBlockType* getBasicBlockType(); + IRType* getWitnessTableType() { return nullptr; } + IRType* getKeyType() { return nullptr; } + + IRTypeKind* getTypeKind(); + IRGenericKind* getGenericKind(); + + IRPtrType* getPtrType(IRType* valueType); + IROutType* getOutType(IRType* valueType); + IRInOutType* getInOutType(IRType* valueType); + IRRefType* getRefType(IRType* valueType); + IRPtrTypeBase* getPtrType(IROp op, IRType* valueType); + + IRArrayTypeBase* getArrayTypeBase( + IROp op, + IRType* elementType, + IRInst* elementCount); + + IRArrayType* getArrayType( + IRType* elementType, + IRInst* elementCount); + + IRUnsizedArrayType* getUnsizedArrayType( + IRType* elementType); + + IRVectorType* getVectorType( + IRType* elementType, + IRInst* elementCount); + + IRMatrixType* getMatrixType( + IRType* elementType, + IRInst* rowCount, + IRInst* columnCount); + + IRFuncType* getFuncType( + UInt paramCount, + IRType* const* paramTypes, + IRType* resultType); + + IRFuncType* getFuncType( + List const& paramTypes, + IRType* resultType) + { + return getFuncType(paramTypes.getCount(), paramTypes.getBuffer(), resultType); + } + + IRConstantBufferType* getConstantBufferType( + IRType* elementType); + + IRConstExprRate* getConstExprRate(); + IRGroupSharedRate* getGroupSharedRate(); + + IRRateQualifiedType* getRateQualifiedType( + IRRate* rate, + IRType* dataType); + + IRType* getTaggedUnionType( + UInt caseCount, + IRType* const* caseTypes); + + IRType* getTaggedUnionType( + List const& caseTypes) + { + return getTaggedUnionType(caseTypes.getCount(), caseTypes.getBuffer()); + } + + IRType* getBindExistentialsType( + IRInst* baseType, + UInt slotArgCount, + IRInst* const* slotArgs); + + IRType* getBindExistentialsType( + IRInst* baseType, + UInt slotArgCount, + IRUse const* slotArgs); + + // Set the data type of an instruction, while preserving + // its rate, if any. + void setDataType(IRInst* inst, IRType* dataType); + + /// Given an existential value, extract the underlying "real" value + IRInst* emitExtractExistentialValue( + IRType* type, + IRInst* existentialValue); + + /// Given an existential value, extract the underlying "real" type + IRType* emitExtractExistentialType( + IRInst* existentialValue); + + /// Given an existential value, extract the witness table showing how the value conforms to the existential type. + IRInst* emitExtractExistentialWitnessTable( + IRInst* existentialValue); + + IRInst* emitSpecializeInst( + IRType* type, + IRInst* genericVal, + UInt argCount, + IRInst* const* args); + + IRInst* emitLookupInterfaceMethodInst( + IRType* type, + IRInst* witnessTableVal, + IRInst* interfaceMethodVal); + + IRInst* emitCallInst( + IRType* type, + IRInst* func, + UInt argCount, + IRInst* const* args); + + IRInst* emitCallInst( + IRType* type, + IRInst* func, + List const& args) + { + return emitCallInst(type, func, args.getCount(), args.getBuffer()); + } + + IRInst* createIntrinsicInst( + IRType* type, + IROp op, + UInt argCount, + IRInst* const* args); + + IRInst* emitIntrinsicInst( + IRType* type, + IROp op, + UInt argCount, + IRInst* const* args); + + IRInst* emitConstructorInst( + IRType* type, + UInt argCount, + IRInst* const* args); + + IRInst* emitMakeVector( + IRType* type, + UInt argCount, + IRInst* const* args); + + IRInst* emitMakeVector( + IRType* type, + List const& args) + { + return emitMakeVector(type, args.getCount(), args.getBuffer()); + } + + IRInst* emitMakeMatrix( + IRType* type, + UInt argCount, + IRInst* const* args); + + IRInst* emitMakeArray( + IRType* type, + UInt argCount, + IRInst* const* args); + + IRInst* emitMakeStruct( + IRType* type, + UInt argCount, + IRInst* const* args); + + IRInst* emitMakeStruct( + IRType* type, + List const& args) + { + return emitMakeStruct(type, args.getCount(), args.getBuffer()); + } + + IRInst* emitMakeExistential( + IRType* type, + IRInst* value, + IRInst* witnessTable); + + IRInst* emitWrapExistential( + IRType* type, + IRInst* value, + UInt slotArgCount, + IRInst* const* slotArgs); + + IRInst* emitWrapExistential( + IRType* type, + IRInst* value, + UInt slotArgCount, + IRUse const* slotArgs) + { + List slotArgVals; + for(UInt ii = 0; ii < slotArgCount; ++ii) + slotArgVals.add(slotArgs[ii].get()); + + return emitWrapExistential(type, value, slotArgCount, slotArgVals.getBuffer()); + } + + IRUndefined* emitUndefined(IRType* type); + + + + IRModule* createModule(); + + IRFunc* createFunc(); + IRGlobalVar* createGlobalVar( + IRType* valueType); + IRGlobalConstant* createGlobalConstant( + IRType* valueType); + IRGlobalParam* createGlobalParam( + IRType* valueType); + IRWitnessTable* createWitnessTable(); + IRWitnessTableEntry* createWitnessTableEntry( + IRWitnessTable* witnessTable, + IRInst* requirementKey, + IRInst* satisfyingVal); + + // Create an initially empty `struct` type. + IRStructType* createStructType(); + + // Create an empty `interface` type. + IRInterfaceType* createInterfaceType(); + + // Create a global "key" to use for indexing into a `struct` type. + IRStructKey* createStructKey(); + + // Create a field nested in a struct type, declaring that + // the specified field key maps to a field with the specified type. + IRStructField* createStructField( + IRStructType* structType, + IRStructKey* fieldKey, + IRType* fieldType); + + IRGeneric* createGeneric(); + IRGeneric* emitGeneric(); + + // Low-level operation for creating a type. + IRType* getType( + IROp op, + UInt operandCount, + IRInst* const* operands); + IRType* getType( + IROp op); + + /// Create an empty basic block. + /// + /// The created block will not be inserted into the current + /// function; call `insertBlock()` to attach the block + /// at an appropriate point. + /// + IRBlock* createBlock(); + + /// Insert a block into the current function. + /// + /// This attaches the given `block` to the current function, + /// and makes it the current block for + /// new instructions that get emitted. + /// + void insertBlock(IRBlock* block); + + /// Emit a new block into the current function. + /// + /// This function is equivalent to using `createBlock()` + /// and then `insertBlock()`. + /// + IRBlock* emitBlock(); + + + + IRParam* createParam( + IRType* type); + IRParam* emitParam( + IRType* type); + + IRVar* emitVar( + IRType* type); + + IRInst* emitLoad( + IRType* type, + IRInst* ptr); + + IRInst* emitLoad( + IRInst* ptr); + + IRInst* emitStore( + IRInst* dstPtr, + IRInst* srcVal); + + IRInst* emitFieldExtract( + IRType* type, + IRInst* base, + IRInst* field); + + IRInst* emitFieldAddress( + IRType* type, + IRInst* basePtr, + IRInst* field); + + IRInst* emitElementExtract( + IRType* type, + IRInst* base, + IRInst* index); + + IRInst* emitElementAddress( + IRType* type, + IRInst* basePtr, + IRInst* index); + + IRInst* emitSwizzle( + IRType* type, + IRInst* base, + UInt elementCount, + IRInst* const* elementIndices); + + IRInst* emitSwizzle( + IRType* type, + IRInst* base, + UInt elementCount, + UInt const* elementIndices); + + IRInst* emitSwizzleSet( + IRType* type, + IRInst* base, + IRInst* source, + UInt elementCount, + IRInst* const* elementIndices); + + IRInst* emitSwizzleSet( + IRType* type, + IRInst* base, + IRInst* source, + UInt elementCount, + UInt const* elementIndices); + + IRInst* emitSwizzledStore( + IRInst* dest, + IRInst* source, + UInt elementCount, + IRInst* const* elementIndices); + + IRInst* emitSwizzledStore( + IRInst* dest, + IRInst* source, + UInt elementCount, + UInt const* elementIndices); + + + + IRInst* emitReturn( + IRInst* val); + + IRInst* emitReturn(); + + IRInst* emitDiscard(); + + IRInst* emitUnreachable(); + IRInst* emitMissingReturn(); + + IRInst* emitBranch( + IRBlock* block); + + IRInst* emitBreak( + IRBlock* target); + + IRInst* emitContinue( + IRBlock* target); + + IRInst* emitLoop( + IRBlock* target, + IRBlock* breakBlock, + IRBlock* continueBlock); + + IRInst* emitBranch( + IRInst* val, + IRBlock* trueBlock, + IRBlock* falseBlock); + + IRInst* emitIf( + IRInst* val, + IRBlock* trueBlock, + IRBlock* afterBlock); + + IRInst* emitIfElse( + IRInst* val, + IRBlock* trueBlock, + IRBlock* falseBlock, + IRBlock* afterBlock); + + IRInst* emitLoopTest( + IRInst* val, + IRBlock* bodyBlock, + IRBlock* breakBlock); + + IRInst* emitSwitch( + IRInst* val, + IRBlock* breakLabel, + IRBlock* defaultLabel, + UInt caseArgCount, + IRInst* const* caseArgs); + + IRGlobalGenericParam* emitGlobalGenericParam(); + + IRBindGlobalGenericParam* emitBindGlobalGenericParam( + IRInst* param, + IRInst* val); + + IRInst* emitBindGlobalExistentialSlots( + UInt argCount, + IRInst* const* args); + + IRDecoration* addBindExistentialSlotsDecoration( + IRInst* value, + UInt argCount, + IRInst* const* args); + + IRInst* emitExtractTaggedUnionTag( + IRInst* val); + + IRInst* emitExtractTaggedUnionPayload( + IRType* type, + IRInst* val, + IRInst* tag); + + IRInst* emitBitCast( + IRType* type, + IRInst* val); + + // + // Decorations + // + + IRDecoration* addDecoration(IRInst* value, IROp op, IRInst* const* operands, Int operandCount); + + IRDecoration* addDecoration(IRInst* value, IROp op) + { + return addDecoration(value, op, (IRInst* const*) nullptr, 0); + } + + IRDecoration* addDecoration(IRInst* value, IROp op, IRInst* operand) + { + return addDecoration(value, op, &operand, 1); + } + + IRDecoration* addDecoration(IRInst* value, IROp op, IRInst* operand0, IRInst* operand1) + { + IRInst* operands[] = { operand0, operand1 }; + return addDecoration(value, op, operands, SLANG_COUNT_OF(operands)); + } + + template + T* addRefObjectToFree(T* ptr) + { + getModule()->getObjectScopeManager()->addMaybeNull(ptr); + return ptr; + } + + template + void addSimpleDecoration(IRInst* value) + { + addDecoration(value, IROp(T::kOp), (IRInst* const*) nullptr, 0); + } + + void addHighLevelDeclDecoration(IRInst* value, Decl* decl); + void addLayoutDecoration(IRInst* value, Layout* layout); + + void addNameHintDecoration(IRInst* value, IRStringLit* name) + { + addDecoration(value, kIROp_NameHintDecoration, name); + } + + void addNameHintDecoration(IRInst* value, UnownedStringSlice const& text) + { + addNameHintDecoration(value, getStringValue(text)); + } + + void addGLSLOuterArrayDecoration(IRInst* value, UnownedStringSlice const& text) + { + addDecoration(value, kIROp_GLSLOuterArrayDecoration, getStringValue(text)); + } + + void addInterpolationModeDecoration(IRInst* value, IRInterpolationMode mode) + { + addDecoration(value, kIROp_InterpolationModeDecoration, getIntValue(getIntType(), IRIntegerValue(mode))); + } + + void addLoopControlDecoration(IRInst* value, IRLoopControl mode) + { + addDecoration(value, kIROp_LoopControlDecoration, getIntValue(getIntType(), IRIntegerValue(mode))); + } + + void addSemanticDecoration(IRInst* value, UnownedStringSlice const& text) + { + addDecoration(value, kIROp_SemanticDecoration, getStringValue(text)); + } + + void addTargetIntrinsicDecoration(IRInst* value, UnownedStringSlice const& target, UnownedStringSlice const& definition) + { + addDecoration(value, kIROp_TargetIntrinsicDecoration, getStringValue(target), getStringValue(definition)); + } + + void addTargetDecoration(IRInst* value, UnownedStringSlice const& target) + { + addDecoration(value, kIROp_TargetDecoration, getStringValue(target)); + } + + void addRequireGLSLExtensionDecoration(IRInst* value, UnownedStringSlice const& extensionName) + { + addDecoration(value, kIROp_RequireGLSLExtensionDecoration, getStringValue(extensionName)); + } + + void addRequireGLSLVersionDecoration(IRInst* value, Int version) + { + addDecoration(value, kIROp_RequireGLSLVersionDecoration, getIntValue(getIntType(), IRIntegerValue(version))); + } + + void addPatchConstantFuncDecoration(IRInst* value, IRInst* patchConstantFunc) + { + addDecoration(value, kIROp_PatchConstantFuncDecoration, patchConstantFunc); + } + + void addImportDecoration(IRInst* value, UnownedStringSlice const& mangledName) + { + addDecoration(value, kIROp_ImportDecoration, getStringValue(mangledName)); + } + + void addExportDecoration(IRInst* value, UnownedStringSlice const& mangledName) + { + addDecoration(value, kIROp_ExportDecoration, getStringValue(mangledName)); + } + + void addEntryPointDecoration(IRInst* value) + { + addDecoration(value, kIROp_EntryPointDecoration); + } + + void addKeepAliveDecoration(IRInst* value) + { + addDecoration(value, kIROp_KeepAliveDecoration); + } + + /// Add a decoration that indicates that the given `inst` depends on the given `dependency`. + /// + /// This decoration can be used to ensure that a value that an instruction + /// implicitly depends on cannot be eliminated so long as the instruction + /// itself is kept alive. + /// + void addDependsOnDecoration(IRInst* inst, IRInst* dependency) + { + addDecoration(inst, kIROp_DependsOnDecoration, dependency); + } + + void addFormatDecoration(IRInst* inst, ImageFormat format) + { + addFormatDecoration(inst, getIntValue(getIntType(), IRIntegerValue(format))); + } + + void addFormatDecoration(IRInst* inst, IRInst* format) + { + addDecoration(inst, kIROp_FormatDecoration, format); + } +}; + +void addHoistableInst( + IRBuilder* builder, + IRInst* inst); + +// Helper to establish the source location that will be used +// by an IRBuilder. +struct IRBuilderSourceLocRAII +{ + IRBuilder* builder; + SourceLoc sourceLoc; + IRBuilderSourceLocRAII* next; + + IRBuilderSourceLocRAII( + IRBuilder* builder, + SourceLoc sourceLoc) + : builder(builder) + , sourceLoc(sourceLoc) + , next(nullptr) + { + next = builder->sourceLocInfo; + builder->sourceLocInfo = this; + } + + ~IRBuilderSourceLocRAII() + { + SLANG_ASSERT(builder->sourceLocInfo == this); + builder->sourceLocInfo = next; + } +}; + +// + +void markConstExpr( + IRBuilder* builder, + IRInst* irValue); + +// + +IRTargetIntrinsicDecoration* findTargetIntrinsicDecoration( + IRInst* val, + String const& targetName); + +} + +#endif diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp new file mode 100644 index 000000000..815afea72 --- /dev/null +++ b/source/slang/slang-ir-legalize-types.cpp @@ -0,0 +1,2626 @@ +// slang-ir-legalize-types.cpp + +// This file implements type legalization for the IR. +// It uses the core legalization logic in +// `legalize-types.{h,cpp}` to decide what to do with +// the types, while this file handles the actual +// rewriting of the IR to use the new types. +// +// This pass should only be applied to IR that has been +// fully specialized (no more generics/interfaces), so +// that the concrete type of everything is known. + +#include "slang-ir.h" +#include "slang-ir-clone.h" +#include "slang-ir-insts.h" +#include "slang-legalize-types.h" +#include "slang-mangle.h" +#include "slang-name.h" + +namespace Slang +{ + +LegalVal LegalVal::tuple(RefPtr tupleVal) +{ + SLANG_ASSERT(tupleVal->elements.getCount()); + + LegalVal result; + result.flavor = LegalVal::Flavor::tuple; + result.obj = tupleVal; + return result; +} + +LegalVal LegalVal::pair(RefPtr pairInfo) +{ + LegalVal result; + result.flavor = LegalVal::Flavor::pair; + result.obj = pairInfo; + return result; +} + +LegalVal LegalVal::pair( + LegalVal const& ordinaryVal, + LegalVal const& specialVal, + RefPtr pairInfo) +{ + if (ordinaryVal.flavor == LegalVal::Flavor::none) + return specialVal; + + if (specialVal.flavor == LegalVal::Flavor::none) + return ordinaryVal; + + + RefPtr obj = new PairPseudoVal(); + obj->ordinaryVal = ordinaryVal; + obj->specialVal = specialVal; + obj->pairInfo = pairInfo; + + return LegalVal::pair(obj); +} + +LegalVal LegalVal::implicitDeref(LegalVal const& val) +{ + RefPtr implicitDerefVal = new ImplicitDerefVal(); + implicitDerefVal->val = val; + + LegalVal result; + result.flavor = LegalVal::Flavor::implicitDeref; + result.obj = implicitDerefVal; + return result; +} + +LegalVal LegalVal::getImplicitDeref() +{ + SLANG_ASSERT(flavor == Flavor::implicitDeref); + return as(obj)->val; +} + +LegalVal LegalVal::wrappedBuffer( + LegalVal const& baseVal, + LegalElementWrapping const& elementInfo) +{ + RefPtr obj = new WrappedBufferPseudoVal(); + obj->base = baseVal; + obj->elementInfo = elementInfo; + + LegalVal result; + result.flavor = LegalVal::Flavor::wrappedBuffer; + result.obj = obj; + return result; +} + +// + +IRTypeLegalizationContext::IRTypeLegalizationContext( + IRModule* inModule) +{ + session = inModule->getSession(); + module = inModule; + + auto sharedBuilder = &sharedBuilderStorage; + sharedBuilder->session = session; + sharedBuilder->module = module; + + builder = &builderStorage; + builder->sharedBuilder = sharedBuilder; +} + +static void registerLegalizedValue( + IRTypeLegalizationContext* context, + IRInst* irValue, + LegalVal const& legalVal) +{ + context->mapValToLegalVal[irValue] = legalVal; +} + +struct IRGlobalNameInfo +{ + IRInst* globalVar; + UInt counter; +}; + +static LegalVal declareVars( + IRTypeLegalizationContext* context, + IROp op, + LegalType type, + TypeLayout* typeLayout, + LegalVarChain const& varChain, + UnownedStringSlice nameHint, + IRInst* leafVar, + IRGlobalNameInfo* globalNameInfo, + bool isSpecial); + + /// Unwrap a value with flavor `wrappedBuffer` + /// + /// The original `legalPtrOperand` has a wrapped-buffer type + /// which encodes the way that, e.g., a `ConstantBuffer` + /// where `Foo` includes interface types, got legalized + /// into a buffer that stores a `Foo` value plus addition + /// fields for the concrete types that got plugged in. + /// + /// The `elementInfo` is the layout information for the + /// modified ("wrapped") buffer type, and specifies how + /// the logical element type was expanded into actual fields. + /// + /// This function returns a new value that undoes all of + /// the wrapping and produces a new `LegalVal` that matches + /// the nominal type of the original buffer. + /// +static LegalVal unwrapBufferValue( + IRTypeLegalizationContext* context, + LegalVal legalPtrOperand, + LegalElementWrapping const& elementInfo); + + /// Perform any actions required to materialize `val` into a usable value. + /// + /// Certain case of `LegalVal` (currently just the `wrappedBuffer` case) are + /// suitable for use to represent a variable, but cannot be used directly + /// in computations, because their structured needs to be "unwrapped." + /// + /// This function unwraps any `val` that needs it, which may involve + /// emitting additional IR instructions, and returns the unmodified + /// `val` otherwise. + /// +static LegalVal maybeMaterializeWrappedValue( + IRTypeLegalizationContext* context, + LegalVal val) +{ + if(val.flavor != LegalVal::Flavor::wrappedBuffer) + return val; + + auto wrappedBufferVal = val.getWrappedBuffer(); + return unwrapBufferValue( + context, + wrappedBufferVal->base, + wrappedBufferVal->elementInfo); +} + +// Take a value that is being used as an operand, +// and turn it into the equivalent legalized value. +static LegalVal legalizeOperand( + IRTypeLegalizationContext* context, + IRInst* irValue) +{ + LegalVal legalVal; + if( context->mapValToLegalVal.TryGetValue(irValue, legalVal) ) + { + return maybeMaterializeWrappedValue(context, legalVal); + } + + // For now, assume that anything not covered + // by the mapping is legal as-is. + + return LegalVal::simple(irValue); +} + +static void getArgumentValues( + List & instArgs, + LegalVal val) +{ + switch (val.flavor) + { + case LegalVal::Flavor::none: + break; + + case LegalVal::Flavor::simple: + instArgs.add(val.getSimple()); + break; + + case LegalVal::Flavor::implicitDeref: + getArgumentValues(instArgs, val.getImplicitDeref()); + break; + + case LegalVal::Flavor::pair: + { + auto pairVal = val.getPair(); + getArgumentValues(instArgs, pairVal->ordinaryVal); + getArgumentValues(instArgs, pairVal->specialVal); + } + break; + + case LegalVal::Flavor::tuple: + { + auto tuplePsuedoVal = val.getTuple(); + for (auto elem : val.getTuple()->elements) + { + getArgumentValues(instArgs, elem.val); + } + } + break; + + default: + SLANG_UNEXPECTED("uhandled val flavor"); + break; + } +} + +static LegalVal legalizeCall( + IRTypeLegalizationContext* context, + IRCall* callInst) +{ + auto retType = legalizeType(context, callInst->getFullType()); + IRType* retIRType = nullptr; + switch (retType.flavor) + { + case LegalType::Flavor::simple: + retIRType = retType.getSimple(); + break; + case LegalType::Flavor::none: + retIRType = context->builder->getVoidType(); + break; + default: + // TODO: implement legalization of non-simple return types + SLANG_UNEXPECTED("unimplemented legalized return type for IRInstCall."); + } + + List instArgs; + for (auto i = 1u; i < callInst->getOperandCount(); i++) + getArgumentValues(instArgs, legalizeOperand(context, callInst->getOperand(i))); + + return LegalVal::simple(context->builder->emitCallInst( + retIRType, + callInst->getCallee(), + instArgs.getCount(), + instArgs.getBuffer())); +} + +static LegalVal legalizeRetVal(IRTypeLegalizationContext* context, + LegalVal retVal) +{ + switch (retVal.flavor) + { + case LegalVal::Flavor::simple: + return LegalVal::simple(context->builder->emitReturn(retVal.getSimple())); + case LegalVal::Flavor::none: + return LegalVal::simple(context->builder->emitReturn()); + default: + // TODO: implement legalization of non-simple return types + SLANG_UNEXPECTED("unimplemented legalized return type for IRReturnVal."); + } +} + +static LegalVal legalizeLoad( + IRTypeLegalizationContext* context, + LegalVal legalPtrVal) +{ + switch (legalPtrVal.flavor) + { + case LegalVal::Flavor::none: + return LegalVal(); + + case LegalVal::Flavor::simple: + { + return LegalVal::simple( + context->builder->emitLoad(legalPtrVal.getSimple())); + } + break; + + case LegalVal::Flavor::implicitDeref: + // We have turne a pointer(-like) type into its pointed-to (value) + // type, and so the operation of loading goes away; we just use + // the underlying value. + return legalPtrVal.getImplicitDeref(); + + case LegalVal::Flavor::pair: + { + auto ptrPairVal = legalPtrVal.getPair(); + + auto ordinaryVal = legalizeLoad(context, ptrPairVal->ordinaryVal); + auto specialVal = legalizeLoad(context, ptrPairVal->specialVal); + return LegalVal::pair(ordinaryVal, specialVal, ptrPairVal->pairInfo); + } + + case LegalVal::Flavor::tuple: + { + // We need to emit a load for each element of + // the tuple. + auto ptrTupleVal = legalPtrVal.getTuple(); + RefPtr tupleVal = new TuplePseudoVal(); + + for (auto ee : legalPtrVal.getTuple()->elements) + { + TuplePseudoVal::Element element; + element.key = ee.key; + element.val = legalizeLoad(context, ee.val); + + tupleVal->elements.add(element); + } + return LegalVal::tuple(tupleVal); + } + break; + + default: + SLANG_UNEXPECTED("unhandled case"); + break; + } +} + +static LegalVal legalizeStore( + IRTypeLegalizationContext* context, + LegalVal legalPtrVal, + LegalVal legalVal) +{ + switch (legalPtrVal.flavor) + { + case LegalVal::Flavor::none: + return LegalVal(); + + case LegalVal::Flavor::simple: + { + context->builder->emitStore(legalPtrVal.getSimple(), legalVal.getSimple()); + return legalVal; + } + break; + + case LegalVal::Flavor::implicitDeref: + // TODO: what is the right behavior here? + // + // The crux of the problem is that we may legalize a pointer-to-pointer + // type in cases where one of the two needs to become an implicit-deref, + // so that we have `PtrA>` become, say, `PtrA` with + // an `implicitDeref` wrapper. When we encounter a store to that + // wrapped value, we seemingly need to know whether the original code + // meant to store to `*ptrPtr` or `**ptrPtr`, and need to legalize + // the result accordingly... + // + if( legalVal.flavor == LegalVal::Flavor::implicitDeref ) + return legalizeStore(context, legalPtrVal.getImplicitDeref(), legalVal.getImplicitDeref()); + else + return legalizeStore(context, legalPtrVal.getImplicitDeref(), legalVal); + + case LegalVal::Flavor::pair: + { + auto destPair = legalPtrVal.getPair(); + auto valPair = legalVal.getPair(); + legalizeStore(context, destPair->ordinaryVal, valPair->ordinaryVal); + legalizeStore(context, destPair->specialVal, valPair->specialVal); + return LegalVal(); + } + + case LegalVal::Flavor::tuple: + { + // We need to emit a store for each element of + // the tuple. + auto destTuple = legalPtrVal.getTuple(); + auto valTuple = legalVal.getTuple(); + SLANG_ASSERT(destTuple->elements.getCount() == valTuple->elements.getCount()); + + for (Index i = 0; i < valTuple->elements.getCount(); i++) + { + legalizeStore(context, destTuple->elements[i].val, valTuple->elements[i].val); + } + return legalVal; + } + break; + + default: + SLANG_UNEXPECTED("unhandled case"); + break; + } +} + +static LegalVal legalizeFieldExtract( + IRTypeLegalizationContext* context, + LegalType type, + LegalVal legalStructOperand, + IRStructKey* fieldKey) +{ + auto builder = context->builder; + + if (type.flavor == LegalType::Flavor::none) + return LegalVal(); + + switch (legalStructOperand.flavor) + { + case LegalVal::Flavor::none: + return LegalVal(); + + case LegalVal::Flavor::simple: + return LegalVal::simple( + builder->emitFieldExtract( + type.getSimple(), + legalStructOperand.getSimple(), + fieldKey)); + + case LegalVal::Flavor::pair: + { + // There are two sides, the ordinary and the special, + // and we basically just dispatch to both of them. + auto pairVal = legalStructOperand.getPair(); + auto pairInfo = pairVal->pairInfo; + auto pairElement = pairInfo->findElement(fieldKey); + if (!pairElement) + { + SLANG_UNEXPECTED("didn't find tuple element"); + UNREACHABLE_RETURN(LegalVal()); + } + + // If the field we are extracting has a pair type, + // that means it exists on both the ordinary and + // special sides. + RefPtr fieldPairInfo; + LegalType ordinaryType = type; + LegalType specialType = type; + if (type.flavor == LegalType::Flavor::pair) + { + auto fieldPairType = type.getPair(); + fieldPairInfo = fieldPairType->pairInfo; + ordinaryType = fieldPairType->ordinaryType; + specialType = fieldPairType->specialType; + } + + LegalVal ordinaryVal; + LegalVal specialVal; + + if (pairElement->flags & PairInfo::kFlag_hasOrdinary) + { + ordinaryVal = legalizeFieldExtract( + context, + ordinaryType, + pairVal->ordinaryVal, + fieldKey); + } + + if (pairElement->flags & PairInfo::kFlag_hasSpecial) + { + specialVal = legalizeFieldExtract( + context, + specialType, + pairVal->specialVal, + fieldKey); + } + return LegalVal::pair(ordinaryVal, specialVal, fieldPairInfo); + } + break; + + case LegalVal::Flavor::tuple: + { + // The operand is a tuple of pointer-like + // values, we want to extract the element + // corresponding to a field. We will handle + // this by simply returning the corresponding + // element from the operand. + auto ptrTupleInfo = legalStructOperand.getTuple(); + for (auto ee : ptrTupleInfo->elements) + { + if (ee.key == fieldKey) + { + return ee.val; + } + } + + // TODO: we can legally reach this case now + // when the field is "ordinary". + + SLANG_UNEXPECTED("didn't find tuple element"); + UNREACHABLE_RETURN(LegalVal()); + } + + default: + SLANG_UNEXPECTED("unhandled"); + UNREACHABLE_RETURN(LegalVal()); + } +} + +static LegalVal legalizeFieldExtract( + IRTypeLegalizationContext* context, + LegalType type, + LegalVal legalPtrOperand, + LegalVal legalFieldOperand) +{ + // We don't expect any legalization to affect + // the "field" argument. + auto fieldKey = legalFieldOperand.getSimple(); + + return legalizeFieldExtract( + context, + type, + legalPtrOperand, + (IRStructKey*) fieldKey); +} + + /// Take a value of some buffer/pointer type and unwrap it according to provided info. +static LegalVal unwrapBufferValue( + IRTypeLegalizationContext* context, + LegalVal legalPtrOperand, + LegalElementWrapping const& elementInfo) +{ + // The `elementInfo` tells us how a non-simple element + // type was wrapped up into a new structure types used + // as the element type of the buffer. + // + // This function will recurse through the structure of + // `elementInfo` to pull out all the required data from + // the buffer represented by `legalPtrOperand`. + + switch( elementInfo.flavor ) + { + default: + SLANG_UNEXPECTED("unhandled"); + UNREACHABLE_RETURN(LegalVal()); + break; + + case LegalElementWrapping::Flavor::none: + return LegalVal(); + + case LegalElementWrapping::Flavor::simple: + { + // In the leaf case, we just had to store some + // data of a simple type in the buffer. We can + // produce a valid result by computing the + // address of the field used to represent the + // element, and then returning *that* as if + // it were the buffer type itself. + // + // (Basically instead of `someBuffer` we will + // end up with `&(someBuffer->field)`. + // + auto builder = context->getBuilder(); + + auto simpleElementInfo = elementInfo.getSimple(); + auto valPtr = builder->emitFieldAddress( + builder->getPtrType(simpleElementInfo->type), + legalPtrOperand.getSimple(), + simpleElementInfo->key); + + return LegalVal::simple(valPtr); + } + + case LegalElementWrapping::Flavor::implicitDeref: + { + // If the element type was logically `ImplicitDeref`, + // then we declared actual fields based on `T`, and + // we need to extract references to those fields and + // wrap them up in an `implicitDeref` value. + // + auto derefField = elementInfo.getImplicitDeref(); + auto baseVal = unwrapBufferValue(context, legalPtrOperand, derefField->field); + return LegalVal::implicitDeref(baseVal); + } + + case LegalElementWrapping::Flavor::pair: + { + // If the element type was logically a `Pair` + // then we encoded fields for both `O` and `S` into + // the actual element type, and now we need to + // extract references to both and pair them up. + // + auto pairField = elementInfo.getPair(); + auto pairInfo = pairField->pairInfo; + + auto ordinaryVal = unwrapBufferValue(context, legalPtrOperand, pairField->ordinary); + auto specialVal = unwrapBufferValue(context, legalPtrOperand, pairField->special); + return LegalVal::pair(ordinaryVal, specialVal, pairInfo); + } + + case LegalElementWrapping::Flavor::tuple: + { + // If the element type was logically a `Tuple` + // then we encoded fields for each of the `Ei` and + // need to extract references to all of them and + // encode them as a tuple. + // + auto tupleField = elementInfo.getTuple(); + + RefPtr obj = new TuplePseudoVal(); + for( auto ee : tupleField->elements ) + { + auto elementVal = unwrapBufferValue( + context, + legalPtrOperand, + ee.field); + + TuplePseudoVal::Element element; + element.key = ee.key; + element.val = unwrapBufferValue( + context, + legalPtrOperand, + ee.field); + obj->elements.add(element); + } + + return LegalVal::tuple(obj); + } + } +} + +static IRType* getPointedToType( + IRTypeLegalizationContext* context, + IRType* ptrType) +{ + auto valueType = tryGetPointedToType(context->builder, ptrType); + if( !valueType ) + { + SLANG_UNEXPECTED("expected a pointer type during type legalization"); + } + return valueType; +} + +static LegalType getPointedToType( + IRTypeLegalizationContext* context, + LegalType type) +{ + switch( type.flavor ) + { + case LegalType::Flavor::none: + return LegalType(); + + case LegalType::Flavor::simple: + return LegalType::simple(getPointedToType(context, type.getSimple())); + + case LegalType::Flavor::implicitDeref: + return type.getImplicitDeref()->valueType; + + case LegalType::Flavor::pair: + { + auto pairType = type.getPair(); + auto ordinary = getPointedToType(context, pairType->ordinaryType); + auto special = getPointedToType(context, pairType->specialType); + return LegalType::pair(ordinary, special, pairType->pairInfo); + } + + case LegalType::Flavor::tuple: + { + auto tupleType = type.getTuple(); + RefPtr resultTuple = new TuplePseudoType(); + for( auto ee : tupleType->elements ) + { + TuplePseudoType::Element resultElement; + resultElement.key = ee.key; + resultElement.type = getPointedToType(context, ee.type); + resultTuple->elements.add(resultElement); + } + return LegalType::tuple(resultTuple); + } + + default: + SLANG_UNEXPECTED("unhandled case in type legalization"); + UNREACHABLE_RETURN(LegalType()); + } +} + +static LegalVal legalizeFieldAddress( + IRTypeLegalizationContext* context, + LegalType type, + LegalVal legalPtrOperand, + IRStructKey* fieldKey) +{ + auto builder = context->builder; + if (type.flavor == LegalType::Flavor::none) + return LegalVal(); + + switch (legalPtrOperand.flavor) + { + case LegalVal::Flavor::none: + return LegalVal(); + + case LegalVal::Flavor::simple: + switch( type.flavor ) + { + case LegalType::Flavor::implicitDeref: + // TODO: Should this case be needed? + return legalizeFieldAddress( + context, + type.getImplicitDeref()->valueType, + legalPtrOperand, + fieldKey); + + default: + return LegalVal::simple( + builder->emitFieldAddress( + type.getSimple(), + legalPtrOperand.getSimple(), + fieldKey)); + } + + case LegalVal::Flavor::pair: + { + // There are two sides, the ordinary and the special, + // and we basically just dispatch to both of them. + auto pairVal = legalPtrOperand.getPair(); + auto pairInfo = pairVal->pairInfo; + auto pairElement = pairInfo->findElement(fieldKey); + if (!pairElement) + { + SLANG_UNEXPECTED("didn't find tuple element"); + UNREACHABLE_RETURN(LegalVal()); + } + + // If the field we are extracting has a pair type, + // that means it exists on both the ordinary and + // special sides. + RefPtr fieldPairInfo; + LegalType ordinaryType = type; + LegalType specialType = type; + if (type.flavor == LegalType::Flavor::pair) + { + auto fieldPairType = type.getPair(); + fieldPairInfo = fieldPairType->pairInfo; + ordinaryType = fieldPairType->ordinaryType; + specialType = fieldPairType->specialType; + } + + LegalVal ordinaryVal; + LegalVal specialVal; + + if (pairElement->flags & PairInfo::kFlag_hasOrdinary) + { + ordinaryVal = legalizeFieldAddress( + context, + ordinaryType, + pairVal->ordinaryVal, + fieldKey); + } + + if (pairElement->flags & PairInfo::kFlag_hasSpecial) + { + specialVal = legalizeFieldAddress( + context, + specialType, + pairVal->specialVal, + fieldKey); + } + return LegalVal::pair(ordinaryVal, specialVal, fieldPairInfo); + } + break; + + case LegalVal::Flavor::tuple: + { + // The operand is a tuple of pointer-like + // values, we want to extract the element + // corresponding to a field. We will handle + // this by simply returning the corresponding + // element from the operand. + auto ptrTupleInfo = legalPtrOperand.getTuple(); + for (auto ee : ptrTupleInfo->elements) + { + if (ee.key == fieldKey) + { + return ee.val; + } + } + + // TODO: we can legally reach this case now + // when the field is "ordinary". + + SLANG_UNEXPECTED("didn't find tuple element"); + UNREACHABLE_RETURN(LegalVal()); + } + + case LegalVal::Flavor::implicitDeref: + { + // The original value had a level of indirection + // that is now being removed, so should not be + // able to get at the *address* of the field any + // more, and need to resign ourselves to just + // getting at the field *value* and then + // adding an implicit dereference on top of that. + // + auto implicitDerefVal = legalPtrOperand.getImplicitDeref(); + auto valueType = getPointedToType(context, type); + return LegalVal::implicitDeref(legalizeFieldExtract(context, valueType, implicitDerefVal, fieldKey)); + } + + default: + SLANG_UNEXPECTED("unhandled"); + UNREACHABLE_RETURN(LegalVal()); + } +} + +static LegalVal legalizeFieldAddress( + IRTypeLegalizationContext* context, + LegalType type, + LegalVal legalPtrOperand, + LegalVal legalFieldOperand) +{ + // We don't expect any legalization to affect + // the "field" argument. + auto fieldKey = legalFieldOperand.getSimple(); + + return legalizeFieldAddress( + context, + type, + legalPtrOperand, + (IRStructKey*) fieldKey); +} + +static LegalVal legalizeGetElement( + IRTypeLegalizationContext* context, + LegalType type, + LegalVal legalPtrOperand, + IRInst* indexOperand) +{ + auto builder = context->builder; + + switch (legalPtrOperand.flavor) + { + case LegalVal::Flavor::none: + return LegalVal(); + + case LegalVal::Flavor::simple: + return LegalVal::simple( + builder->emitElementExtract( + type.getSimple(), + legalPtrOperand.getSimple(), + indexOperand)); + + case LegalVal::Flavor::pair: + { + // There are two sides, the ordinary and the special, + // and we basically just dispatch to both of them. + auto pairVal = legalPtrOperand.getPair(); + auto pairInfo = pairVal->pairInfo; + + LegalType ordinaryType = type; + LegalType specialType = type; + if (type.flavor == LegalType::Flavor::pair) + { + auto pairType = type.getPair(); + ordinaryType = pairType->ordinaryType; + specialType = pairType->specialType; + } + + LegalVal ordinaryVal = legalizeGetElement( + context, + ordinaryType, + pairVal->ordinaryVal, + indexOperand); + + LegalVal specialVal = legalizeGetElement( + context, + specialType, + pairVal->specialVal, + indexOperand); + + return LegalVal::pair(ordinaryVal, specialVal, pairInfo); + } + break; + + case LegalVal::Flavor::tuple: + { + // The operand is a tuple of pointer-like + // values, we want to extract the element + // corresponding to a field. We will handle + // this by simply returning the corresponding + // element from the operand. + auto ptrTupleInfo = legalPtrOperand.getTuple(); + + RefPtr resTupleInfo = new TuplePseudoVal(); + + auto tupleType = type.getTuple(); + SLANG_ASSERT(tupleType); + + auto elemCount = ptrTupleInfo->elements.getCount(); + SLANG_ASSERT(elemCount == tupleType->elements.getCount()); + + for(Index ee = 0; ee < elemCount; ++ee) + { + auto ptrElem = ptrTupleInfo->elements[ee]; + auto elemType = tupleType->elements[ee].type; + + TuplePseudoVal::Element resElem; + resElem.key = ptrElem.key; + resElem.val = legalizeGetElement( + context, + elemType, + ptrElem.val, + indexOperand); + + resTupleInfo->elements.add(resElem); + } + + return LegalVal::tuple(resTupleInfo); + } + + default: + SLANG_UNEXPECTED("unhandled"); + UNREACHABLE_RETURN(LegalVal()); + } +} + +static LegalVal legalizeGetElement( + IRTypeLegalizationContext* context, + LegalType type, + LegalVal legalPtrOperand, + LegalVal legalIndexOperand) +{ + // We don't expect any legalization to affect + // the "index" argument. + auto indexOperand = legalIndexOperand.getSimple(); + + return legalizeGetElement( + context, + type, + legalPtrOperand, + indexOperand); +} + +static LegalVal legalizeGetElementPtr( + IRTypeLegalizationContext* context, + LegalType type, + LegalVal legalPtrOperand, + IRInst* indexOperand) +{ + auto builder = context->builder; + + switch (legalPtrOperand.flavor) + { + case LegalVal::Flavor::none: + return LegalVal(); + + case LegalVal::Flavor::simple: + return LegalVal::simple( + builder->emitElementAddress( + type.getSimple(), + legalPtrOperand.getSimple(), + indexOperand)); + + case LegalVal::Flavor::pair: + { + // There are two sides, the ordinary and the special, + // and we basically just dispatch to both of them. + auto pairVal = legalPtrOperand.getPair(); + auto pairInfo = pairVal->pairInfo; + + LegalType ordinaryType = type; + LegalType specialType = type; + if (type.flavor == LegalType::Flavor::pair) + { + auto pairType = type.getPair(); + ordinaryType = pairType->ordinaryType; + specialType = pairType->specialType; + } + + LegalVal ordinaryVal = legalizeGetElementPtr( + context, + ordinaryType, + pairVal->ordinaryVal, + indexOperand); + + LegalVal specialVal = legalizeGetElementPtr( + context, + specialType, + pairVal->specialVal, + indexOperand); + + return LegalVal::pair(ordinaryVal, specialVal, pairInfo); + } + break; + + case LegalVal::Flavor::tuple: + { + // The operand is a tuple of pointer-like + // values, we want to extract the element + // corresponding to a field. We will handle + // this by simply returning the corresponding + // element from the operand. + auto ptrTupleInfo = legalPtrOperand.getTuple(); + + RefPtr resTupleInfo = new TuplePseudoVal(); + + auto tupleType = type.getTuple(); + SLANG_ASSERT(tupleType); + + auto elemCount = ptrTupleInfo->elements.getCount(); + SLANG_ASSERT(elemCount == tupleType->elements.getCount()); + + for(Index ee = 0; ee < elemCount; ++ee) + { + auto ptrElem = ptrTupleInfo->elements[ee]; + auto elemType = tupleType->elements[ee].type; + + TuplePseudoVal::Element resElem; + resElem.key = ptrElem.key; + resElem.val = legalizeGetElementPtr( + context, + elemType, + ptrElem.val, + indexOperand); + + resTupleInfo->elements.add(resElem); + } + + return LegalVal::tuple(resTupleInfo); + } + + case LegalVal::Flavor::implicitDeref: + { + // The original value used to be a pointer to an array, + // and somebody is trying to get at an element pointer. + // Now we just have an array (wrapped with an implicit + // dereference) and need to just fetch the chosen element + // instead (and then wrap the element value with an + // implicit dereference). + // + // The result type for our `getElement` instruction needs + // to be the type *pointed to* by `type`, and not `type. + // + auto valueType = getPointedToType(context, type); + + auto implicitDerefVal = legalPtrOperand.getImplicitDeref(); + return LegalVal::implicitDeref(legalizeGetElement( + context, + valueType, + implicitDerefVal, + indexOperand)); + } + + default: + SLANG_UNEXPECTED("unhandled"); + UNREACHABLE_RETURN(LegalVal()); + } +} + +static LegalVal legalizeGetElementPtr( + IRTypeLegalizationContext* context, + LegalType type, + LegalVal legalPtrOperand, + LegalVal legalIndexOperand) +{ + // We don't expect any legalization to affect + // the "index" argument. + auto indexOperand = legalIndexOperand.getSimple(); + + return legalizeGetElementPtr( + context, + type, + legalPtrOperand, + indexOperand); +} + +static LegalVal legalizeMakeStruct( + IRTypeLegalizationContext* context, + LegalType legalType, + LegalVal const* legalArgs, + UInt argCount) +{ + auto builder = context->builder; + + switch(legalType.flavor) + { + case LegalType::Flavor::none: + return LegalVal(); + + case LegalType::Flavor::simple: + { + List args; + for(UInt aa = 0; aa < argCount; ++aa) + { + // Note: we assume that all the arguments + // must be simple here, because otherwise + // the `struct` type with them as fields + // would not be simple... + // + args.add(legalArgs[aa].getSimple()); + } + return LegalVal::simple( + builder->emitMakeStruct( + legalType.getSimple(), + argCount, + args.getBuffer())); + } + + case LegalType::Flavor::pair: + { + // There are two sides, the ordinary and the special, + // and we basically just dispatch to both of them. + auto pairType = legalType.getPair(); + auto pairInfo = pairType->pairInfo; + LegalType ordinaryType = pairType->ordinaryType; + LegalType specialType = pairType->specialType; + + List ordinaryArgs; + List specialArgs; + UInt argCounter = 0; + for(auto ee : pairInfo->elements) + { + UInt argIndex = argCounter++; + LegalVal arg = legalArgs[argIndex]; + + if( arg.flavor == LegalVal::Flavor::pair ) + { + // The argument is itself a pair + auto argPair = arg.getPair(); + ordinaryArgs.add(argPair->ordinaryVal); + specialArgs.add(argPair->specialVal); + } + else if(ee.flags & Slang::PairInfo::kFlag_hasOrdinary) + { + ordinaryArgs.add(arg); + } + else if(ee.flags & Slang::PairInfo::kFlag_hasSpecial) + { + specialArgs.add(arg); + } + } + + LegalVal ordinaryVal = legalizeMakeStruct( + context, + ordinaryType, + ordinaryArgs.getBuffer(), + ordinaryArgs.getCount()); + + LegalVal specialVal = legalizeMakeStruct( + context, + specialType, + specialArgs.getBuffer(), + specialArgs.getCount()); + + return LegalVal::pair(ordinaryVal, specialVal, pairInfo); + } + break; + + case LegalType::Flavor::tuple: + { + // We are constructing a tuple of values from + // the individual fields. We need to identify + // for each tuple element what field it uses, + // and then extract that field's value. + + auto tupleType = legalType.getTuple(); + + RefPtr resTupleInfo = new TuplePseudoVal(); + UInt argCounter = 0; + for(auto typeElem : tupleType->elements) + { + auto elemKey = typeElem.key; + UInt argIndex = argCounter++; + SLANG_ASSERT(argIndex < argCount); + + LegalVal argVal = legalArgs[argIndex]; + + TuplePseudoVal::Element resElem; + resElem.key = elemKey; + resElem.val = argVal; + + resTupleInfo->elements.add(resElem); + } + return LegalVal::tuple(resTupleInfo); + } + + default: + SLANG_UNEXPECTED("unhandled"); + UNREACHABLE_RETURN(LegalVal()); + } +} + +static LegalVal legalizeConstruct(IRTypeLegalizationContext* context, + LegalType type) +{ + switch (type.flavor) + { + case LegalType::Flavor::none: + return LegalVal(); + case LegalType::Flavor::simple: + return LegalVal::simple(context->builder->emitConstructorInst(type.getSimple(), 0, nullptr)); + default: + SLANG_UNEXPECTED("unhandled legalization case for construct inst."); + UNREACHABLE_RETURN(LegalVal()); + } +} + +static LegalVal legalizeInst( + IRTypeLegalizationContext* context, + IRInst* inst, + LegalType type, + LegalVal const* args) +{ + switch (inst->op) + { + case kIROp_Load: + return legalizeLoad(context, args[0]); + + case kIROp_FieldAddress: + return legalizeFieldAddress(context, type, args[0], args[1]); + + case kIROp_FieldExtract: + return legalizeFieldExtract(context, type, args[0], args[1]); + + case kIROp_getElement: + return legalizeGetElement(context, type, args[0], args[1]); + + case kIROp_getElementPtr: + return legalizeGetElementPtr(context, type, args[0], args[1]); + + case kIROp_Store: + return legalizeStore(context, args[0], args[1]); + + case kIROp_Call: + return legalizeCall(context, (IRCall*)inst); + case kIROp_ReturnVal: + return legalizeRetVal(context, args[0]); + case kIROp_makeStruct: + return legalizeMakeStruct( + context, + type, + args, + inst->getOperandCount()); + case kIROp_Construct: + return legalizeConstruct(context, type); + case kIROp_undefined: + return LegalVal(); + default: + // TODO: produce a user-visible diagnostic here + SLANG_UNEXPECTED("non-simple operand(s)!"); + break; + } +} + +RefPtr findVarLayout(IRInst* value) +{ + if (auto layoutDecoration = value->findDecoration()) + return as(layoutDecoration->getLayout()); + return nullptr; +} + +static UnownedStringSlice findNameHint(IRInst* inst) +{ + if( auto nameHintDecoration = inst->findDecoration() ) + { + return nameHintDecoration->getName(); + } + return UnownedStringSlice(); +} + +static LegalVal legalizeLocalVar( + IRTypeLegalizationContext* context, + IRVar* irLocalVar) +{ + // Legalize the type for the variable's value + auto originalValueType = irLocalVar->getDataType()->getValueType(); + auto legalValueType = legalizeType( + context, + originalValueType); + + auto originalRate = irLocalVar->getRate(); + + RefPtr varLayout = findVarLayout(irLocalVar); + RefPtr typeLayout = varLayout ? varLayout->typeLayout : nullptr; + + // If we've decided to do implicit deref on the type, + // then go ahead and declare a value of the pointed-to type. + LegalType maybeSimpleType = legalValueType; + while (maybeSimpleType.flavor == LegalType::Flavor::implicitDeref) + { + maybeSimpleType = maybeSimpleType.getImplicitDeref()->valueType; + } + + switch (maybeSimpleType.flavor) + { + case LegalType::Flavor::simple: + { + // Easy case: the type is usable as-is, and we + // should just do that. + auto type = maybeSimpleType.getSimple(); + type = context->builder->getPtrType(type); + if( originalRate ) + { + type = context->builder->getRateQualifiedType( + originalRate, + type); + } + irLocalVar->setFullType(type); + return LegalVal::simple(irLocalVar); + } + + default: + { + // TODO: We don't handle rates in this path. + + context->insertBeforeLocalVar = irLocalVar; + + LegalVarChainLink varChain(LegalVarChain(), varLayout); + + UnownedStringSlice nameHint = findNameHint(irLocalVar); + context->builder->setInsertBefore(irLocalVar); + LegalVal newVal = declareVars(context, kIROp_Var, legalValueType, typeLayout, varChain, nameHint, irLocalVar, nullptr, context->isSpecialType(originalValueType)); + + // Remove the old local var. + irLocalVar->removeFromParent(); + // add old local var to list + context->replacedInstructions.add(irLocalVar); + return newVal; + } + break; + } +} + +static LegalVal legalizeParam( + IRTypeLegalizationContext* context, + IRParam* originalParam) +{ + auto legalParamType = legalizeType(context, originalParam->getFullType()); + if (legalParamType.flavor == LegalType::Flavor::simple) + { + // Simple case: things were legalized to a simple type, + // so we can just use the original parameter as-is. + originalParam->setFullType(legalParamType.getSimple()); + return LegalVal::simple(originalParam); + } + else + { + // Complex case: we need to insert zero or more new parameters, + // which will replace the old ones. + + context->insertBeforeParam = originalParam; + + UnownedStringSlice nameHint = findNameHint(originalParam); + + context->builder->setInsertBefore(originalParam); + auto newVal = declareVars(context, kIROp_Param, legalParamType, nullptr, LegalVarChain(), nameHint, originalParam, nullptr, context->isSpecialType(originalParam->getDataType())); + + originalParam->removeFromParent(); + context->replacedInstructions.add(originalParam); + return newVal; + } +} + +static LegalVal legalizeFunc( + IRTypeLegalizationContext* context, + IRFunc* irFunc); + +static LegalVal legalizeGlobalVar( + IRTypeLegalizationContext* context, + IRGlobalVar* irGlobalVar); + +static LegalVal legalizeGlobalConstant( + IRTypeLegalizationContext* context, + IRGlobalConstant* irGlobalConstant); + +static LegalVal legalizeGlobalParam( + IRTypeLegalizationContext* context, + IRGlobalParam* irGlobalParam); + +static LegalVal legalizeInst( + IRTypeLegalizationContext* context, + IRInst* inst) +{ + // Any additional instructions we need to emit + // in the process of legalizing `inst` should + // by default be insertied right before `inst`. + // + context->builder->setInsertBefore(inst); + + // Special-case certain operations + switch (inst->op) + { + case kIROp_Var: + return legalizeLocalVar(context, cast(inst)); + + case kIROp_Param: + return legalizeParam(context, cast(inst)); + + case kIROp_WitnessTable: + // Just skip these. + break; + + case kIROp_Func: + return legalizeFunc(context, cast(inst)); + + case kIROp_GlobalVar: + return legalizeGlobalVar(context, cast(inst)); + + case kIROp_GlobalConstant: + return legalizeGlobalConstant(context, cast(inst)); + + case kIROp_GlobalParam: + return legalizeGlobalParam(context, cast(inst)); + + default: + break; + } + + // We will iterate over all the operands, extract the legalized + // value of each, and collect them in an array for subsequent use. + // + auto argCount = inst->getOperandCount(); + List legalArgs; + // + // Along the way we will also note whether there were any operands + // with non-simple legalized values. + // + bool anyComplex = false; + for (UInt aa = 0; aa < argCount; ++aa) + { + auto oldArg = inst->getOperand(aa); + auto legalArg = legalizeOperand(context, oldArg); + legalArgs.add(legalArg); + + if (legalArg.flavor != LegalVal::Flavor::simple) + anyComplex = true; + } + + // We must also legalize the type of the instruction, since that + // is implicitly one of its operands. + // + LegalType legalType = legalizeType(context, inst->getFullType()); + + // If there was nothing interesting that occured for the operands + // then we can re-use this instruction as-is. + // + if (!anyComplex && legalType.flavor == LegalType::Flavor::simple) + { + // While the operands are all "simple," they might not necessarily + // be equal to the operands we started with. + // + for (UInt aa = 0; aa < argCount; ++aa) + { + auto legalArg = legalArgs[aa]; + inst->setOperand(aa, legalArg.getSimple()); + } + + inst->setFullType(legalType.getSimple()); + + return LegalVal::simple(inst); + } + + // We have at least one "complex" operand, and we + // need to figure out what to do with it. The anwer + // will, in general, depend on what we are doing. + + // We will set up the IR builder so that any new + // instructions generated will be placed before + // the location of the original instruction. + auto builder = context->builder; + builder->setInsertBefore(inst); + + LegalVal legalVal = legalizeInst( + context, + inst, + legalType, + legalArgs.getBuffer()); + + // After we are done, we will eliminate the + // original instruction by removing it from + // the IR. + // + inst->removeFromParent(); + context->replacedInstructions.add(inst); + + // The value to be used when referencing + // the original instruction will now be + // whatever value(s) we created to replace it. + return legalVal; +} + +static void addParamType(List& ioParamTypes, LegalType t) +{ + switch (t.flavor) + { + case LegalType::Flavor::none: + break; + + case LegalType::Flavor::simple: + ioParamTypes.add(t.getSimple()); + break; + + case LegalType::Flavor::implicitDeref: + { + auto imp = t.getImplicitDeref(); + addParamType(ioParamTypes, imp->valueType); + break; + } + case LegalType::Flavor::pair: + { + auto pairInfo = t.getPair(); + addParamType(ioParamTypes, pairInfo->ordinaryType); + addParamType(ioParamTypes, pairInfo->specialType); + } + break; + case LegalType::Flavor::tuple: + { + auto tup = t.getTuple(); + for (auto & elem : tup->elements) + addParamType(ioParamTypes, elem.type); + } + break; + default: + SLANG_UNEXPECTED("unknown legalized type flavor"); + } +} + +static void legalizeInstsInParent( + IRTypeLegalizationContext* context, + IRInst* parent) +{ + IRInst* nextChild = nullptr; + for(auto child = parent->getFirstDecorationOrChild(); child; child = nextChild) + { + nextChild = child->getNextInst(); + + if (auto block = as(child)) + { + legalizeInstsInParent(context, block); + } + else + { + LegalVal legalVal = legalizeInst(context, child); + registerLegalizedValue(context, child, legalVal); + } + } +} + +static LegalVal legalizeFunc( + IRTypeLegalizationContext* context, + IRFunc* irFunc) +{ + // Overwrite the function's type with the result of legalization. + + IRFuncType* oldFuncType = irFunc->getDataType(); + UInt oldParamCount = oldFuncType->getParamCount(); + + // TODO: we should give an error message when the result type of a function + // can't be legalized (e.g., trying to return a texture, or a structue that + // contains one). + auto legalReturnType = legalizeType(context, oldFuncType->getResultType()); + IRType* newResultType = nullptr; + switch (legalReturnType.flavor) + { + case LegalType::Flavor::simple: + newResultType = legalReturnType.getSimple(); + break; + case LegalType::Flavor::none: + newResultType = context->builder->getVoidType(); + break; + default: + SLANG_UNEXPECTED("unknown legalized function return type."); + } + List newParamTypes; + for (UInt pp = 0; pp < oldParamCount; ++pp) + { + auto legalParamType = legalizeType(context, oldFuncType->getParamType(pp)); + addParamType(newParamTypes, legalParamType); + } + + auto newFuncType = context->builder->getFuncType( + newParamTypes.getCount(), + newParamTypes.getBuffer(), + newResultType); + + context->builder->setDataType(irFunc, newFuncType); + + legalizeInstsInParent(context, irFunc); + return LegalVal::simple(irFunc); +} + +static LegalVal declareSimpleVar( + IRTypeLegalizationContext* context, + IROp op, + IRType* type, + TypeLayout* typeLayout, + LegalVarChain const& varChain, + UnownedStringSlice nameHint, + IRInst* leafVar, + IRGlobalNameInfo* globalNameInfo) +{ + SLANG_UNUSED(globalNameInfo); + + RefPtr varLayout = createVarLayout(varChain, typeLayout); + + DeclRef varDeclRef = varChain.getLeafVarDeclRef(); + + IRBuilder* builder = context->builder; + + IRInst* irVar = nullptr; + LegalVal legalVarVal; + + switch (op) + { + case kIROp_GlobalVar: + { + auto globalVar = builder->createGlobalVar(type); + globalVar->removeFromParent(); + globalVar->insertBefore(context->insertBeforeGlobal); + + irVar = globalVar; + legalVarVal = LegalVal::simple(irVar); + } + break; + + case kIROp_GlobalConstant: + { + auto globalConst = builder->createGlobalConstant(type); + globalConst->removeFromParent(); + globalConst->insertBefore(context->insertBeforeGlobal); + + irVar = globalConst; + legalVarVal = LegalVal::simple(globalConst); + } + break; + + case kIROp_GlobalParam: + { + auto globalParam = builder->createGlobalParam(type); + globalParam->removeFromParent(); + globalParam->insertBefore(context->insertBeforeGlobal); + + irVar = globalParam; + legalVarVal = LegalVal::simple(globalParam); + } + break; + + case kIROp_Var: + { + builder->setInsertBefore(context->insertBeforeLocalVar); + auto localVar = builder->emitVar(type); + + irVar = localVar; + legalVarVal = LegalVal::simple(irVar); + + } + break; + + case kIROp_Param: + { + auto param = builder->emitParam(type); + param->insertBefore(context->insertBeforeParam); + + irVar = param; + legalVarVal = LegalVal::simple(irVar); + } + break; + + default: + SLANG_UNEXPECTED("unexpected IR opcode"); + break; + } + + if (irVar) + { + if (varLayout) + { + builder->addLayoutDecoration(irVar, varLayout); + } + + if (varDeclRef) + { + builder->addHighLevelDeclDecoration(irVar, varDeclRef.getDecl()); + } + + if( nameHint.size() ) + { + context->builder->addNameHintDecoration(irVar, nameHint); + } + + if( leafVar ) + { + for( auto decoration : leafVar->getDecorations() ) + { + switch( decoration->op ) + { + case kIROp_FormatDecoration: + cloneDecoration(decoration, irVar); + break; + + default: + break; + } + } + } + + } + + return legalVarVal; +} + + /// Add layout information for the fields of a wrapped buffer type. + /// + /// A wrapped buffer type encodes a buffer like `ConstantBuffer` + /// where `Foo` might have interface-type fields that have been + /// specialized to a concrete type. E.g.: + /// + /// struct Car { IDriver driver; int mph; }; + /// ConstantBuffer machOne; + /// + /// In a case where the `machOne.driver` field has been specialized + /// to the type `SpeedRacer`, we need to generate a legalized + /// buffer layout something like: + /// + /// struct Car_0 { int mph; } + /// struct Wrapped { Car_0 car; SpeedRacer card_d; } + /// ConstantBuffer machOne; + /// + /// The layout information for the existing `machOne` clearly + /// can't apply because we have a new element type with new fields. + /// + /// This function is used to recursively fill in the layout for + /// the fields of the `Wrapped` type, using information recorded + /// when the legal wrapped buffer type was created. + /// +static void _addFieldsToWrappedBufferElementTypeLayout( + TypeLayout* elementTypeLayout, // layout of the original field type + StructTypeLayout* newTypeLayout, // layout we are filling in + LegalElementWrapping const& elementInfo, // information on how the original type got wrapped + LegalVarChain const& varChain, // chain of variables that is leading to this field + bool isSpecial) // should we assume a leaf field is a special (interface) type? +{ + // The way we handle things depends primary on the + // `elementInfo`, because that tells us how things + // were wrapped up when the type was legalized. + + switch( elementInfo.flavor ) + { + case LegalElementWrapping::Flavor::none: + // A leaf `none` value meant there was nothing + // to encode for a particular field (probably + // had a `void` or empty structure type). + break; + + case LegalElementWrapping::Flavor::simple: + { + auto simpleInfo = elementInfo.getSimple(); + + // A `simple` wrapping means we hit a leaf + // field that can be encoded directly. + // What we do here depends on whether we've + // reached an ordinary field of the original + // data type, or if we've reached a leaf + // field of interface type. + // + // We've been tracking a `varChain` that + // remembers all the parent `struct` fields + // we've navigated through to get here, and + // that information has been tracking two + // different pieces of layout: + // + // * The "primary" layout represents the storage + // of the buffer element type as we usually + // think of its (e.g., the bytes starting at offset zero). + // + // * The "pending" layout tells us where all the + // fields representing concrete types plugged in + // for interface-type slots got placed. + // + // We have tunneled down info to tell us which case + // we should use (`isSpecial`). + // + // Most of the logic is the same between the two + // cases. We will be computing layout information + // for a field of the new/wrapped buffer element type. + // + RefPtr newFieldLayout; + if(isSpecial) + { + // In the special case, that field will be laid out + // based on the "pending" var chain, and the type + // of the pending data for the element. + // + newFieldLayout = createSimpleVarLayout(varChain.pendingChain, elementTypeLayout->pendingDataTypeLayout); + } + else + { + // The ordinary case just uses the primary layout + // information and the primary/nominal type of + // the field. + // + newFieldLayout = createSimpleVarLayout(varChain.primaryChain, elementTypeLayout); + } + + // Either way, we add the new field to the struct type + // layout we are building, and also update the mapping + // information so that we can find the field layout + // based on the IR key for the struct field. + // + newTypeLayout->fields.add(newFieldLayout); + newTypeLayout->mapKeyToLayout.Add(simpleInfo->key, newFieldLayout); + } + break; + + case LegalElementWrapping::Flavor::implicitDeref: + { + // This is the case where a field in the element type + // has been legalized from `SomePtrLikeType` to + // `T`, so there is a different in levels of indirection. + // + // We need to recurse and see how the type `T` + // got laid out to know what field(s) it might comprise. + // + auto implicitDerefInfo = elementInfo.getImplicitDeref(); + _addFieldsToWrappedBufferElementTypeLayout( + elementTypeLayout, + newTypeLayout, + implicitDerefInfo->field, + varChain, + isSpecial); + return; + } + break; + + case LegalElementWrapping::Flavor::pair: + { + // The pair case is the first main workhorse where + // if we had a type that mixed ordinary and interface-type + // fields, it would get split into an ordinary part + // and a "special" part, each of which might comprise + // zero or more fields. + // + // Here we recurse on both the ordinary and special + // sides, and the only interesting tidbit is that + // we pass along appropriate values for the `isSpecial` + // flag so that we act appropriately upon running + // into a leaf field. + // + auto pairElementInfo = elementInfo.getPair(); + _addFieldsToWrappedBufferElementTypeLayout( + elementTypeLayout, + newTypeLayout, + pairElementInfo->ordinary, + varChain, + false); + _addFieldsToWrappedBufferElementTypeLayout( + elementTypeLayout, + newTypeLayout, + pairElementInfo->special, + varChain, + true); + } + break; + + case LegalElementWrapping::Flavor::tuple: + { + // A tuple comes up when we've turned an aggregate + // with one or more interface-type fields into + // distinct fields at the top level. + // + // For the most part we just recurse on each field, + // but note that we set the `isSpecial` flag on + // the recursive calls, since we never use tuples + // to store anything that isn't special. + + auto tupleInfo = elementInfo.getTuple(); + for( auto ee : tupleInfo->elements ) + { + auto oldFieldLayout = getFieldLayout(elementTypeLayout, ee.key); + SLANG_ASSERT(oldFieldLayout); + + LegalVarChainLink fieldChain(varChain, oldFieldLayout); + + _addFieldsToWrappedBufferElementTypeLayout( + oldFieldLayout->typeLayout, + newTypeLayout, + ee.field, + fieldChain, + true); + } + } + break; + + default: + SLANG_UNEXPECTED("unhandled element wrapping flavor"); + break; + } +} + + /// Add offset information for `kind` to `resultVarLayout`, + /// if it doesn't already exist, and adjust the offset so + /// that it will represent an offset relative to the + /// "primary" data for the surrounding type, rather than + /// being relative to the "pending" data. + /// +static void _addOffsetVarLayoutEntry( + VarLayout* resultVarLayout, + LegalVarChain const& varChain, + LayoutResourceKind kind) +{ + // If the target already has an offset for this kind, bail out. + // + if(resultVarLayout->FindResourceInfo(kind)) + return; + + // Add the `ResourceInfo` that will represent the offset for + // this resource kind (it will be initialized to zero by default) + // + auto resultResInfo = resultVarLayout->findOrAddResourceInfo(kind); + + // Add in any contributions from the "pending" var chain, since + // that chain of offsets will accumulate to get the leaf offset + // within the pending data, which in this case we assume amounts + // to an *absolute* offset. + // + for(auto vv = varChain.pendingChain; vv; vv = vv->next ) + { + if( auto chainResInfo = vv->varLayout->FindResourceInfo(kind) ) + { + resultResInfo->index += chainResInfo->index; + resultResInfo->space += chainResInfo->space; + } + } + + // Subtract any contributions from the primary var chain, since + // we want the resulting offset to be relative to the same + // base as that chain. + // + for(auto vv = varChain.primaryChain; vv; vv = vv->next ) + { + if( auto chainResInfo = vv->varLayout->FindResourceInfo(kind) ) + { + resultResInfo->index -= chainResInfo->index; + resultResInfo->space -= chainResInfo->space; + } + } +} + + /// Create a variable layout for an field with "pending" type. + /// + /// The given `typeLayout` should represent the type of a field + /// that is being stored in "pending" data, but that now needs + /// to be made relative to the "primary" data, because we are + /// legalizing the pending data out of the code. + /// +static RefPtr _createOffsetVarLayout( + LegalVarChain const& varChain, + TypeLayout* typeLayout) +{ + RefPtr resultVarLayout = new VarLayout(); + + // For every resource kind the type consumes, we will + // compute an adjusted offset for the variable that + // encodes the (absolute) offset of the pending data + // in `varChain` relative to its primary data. + // + for( auto resInfo : typeLayout->resourceInfos ) + { + _addOffsetVarLayoutEntry(resultVarLayout, varChain, resInfo.kind); + } + + return resultVarLayout; +} + + /// Place offset information from `srcResInfo` onto `dstLayout`, + /// offset by whatever is in `offsetVarLayout` +static void addOffsetResInfo( + VarLayout* dstLayout, + VarLayout::ResourceInfo const& srcResInfo, + VarLayout* offsetVarLayout) +{ + auto kind = srcResInfo.kind; + auto dstResInfo = dstLayout->findOrAddResourceInfo(kind); + + dstResInfo->index = srcResInfo.index; + dstResInfo->space = srcResInfo.space; + + if( auto offsetResInfo = offsetVarLayout->findOrAddResourceInfo(kind) ) + { + dstResInfo->index += offsetResInfo->index; + dstResInfo->space += offsetResInfo->space; + } +} + + /// Create layout information for a wrapped buffer type. + /// + /// A wrapped buffer type encodes a buffer like `ConstantBuffer` + /// where `Foo` might have interface-type fields that have been + /// specialized to a concrete type. + /// + /// Consider: + /// + /// struct Car { IDriver driver; int mph; }; + /// ConstantBuffer machOne; + /// + /// In a case where the `machOne.driver` field has been specialized + /// to the type `SpeedRacer`, we need to generate a legalized + /// buffer layout something like: + /// + /// struct Car_0 { int mph; } + /// struct Wrapped { Car_0 car; SpeedRacer card_d; } + /// ConstantBuffer machOne; + /// + /// The layout information for the existing `machOne` clearly + /// can't apply because we have a new element type with new fields. + /// + /// This function is used to create a layout for a legalized + /// buffer type that requires wrapping, based on the original + /// type layout information and the variable layout information + /// of the surrounding context (e.g., the global shader parameter + /// that has this type). + /// +static RefPtr _createWrappedBufferTypeLayout( + TypeLayout* oldTypeLayout, + WrappedBufferPseudoType* wrappedBufferTypeInfo, + LegalVarChain const& outerVarChain) +{ + // We shouldn't get invoked unless there was a parameter group type, + // so we will sanity check for that just to be sure. + // + auto oldParameterGroupTypeLayout = as(oldTypeLayout); + SLANG_ASSERT(oldParameterGroupTypeLayout); + if(!oldParameterGroupTypeLayout) + return oldTypeLayout; + + // The original type must have been split between the direct/primary + // data and some amount of "pending" data to deal with interface-type + // data in the element type of the parameter group. + // + // The legalization step will have already flattened the data inside of + // the group to a single `struct` type, which places the primary data first, + // and then any pending data into additional fields. + // + // Our job is to compute a type layout that we can apply to that new + // element type, and to a parameter group surrounding it, that will + // re-create the original intention of the split layout (both primary + // and pending data) for a type that now only has the "primary" data. + // + RefPtr newTypeLayout = new ParameterGroupTypeLayout(); + newTypeLayout->type = oldTypeLayout->type; + newTypeLayout->rules = oldTypeLayout->rules; + newTypeLayout->uniformAlignment = oldTypeLayout->uniformAlignment; + for(auto resInfo : oldTypeLayout->resourceInfos) + newTypeLayout->addResourceUsage(resInfo); + + // Any fields in the "pending" data will have offset information + // that is relative to the pending data for their parent, and so on. + // We need to compute layout information that only includes primary + // data, so any offset information that is relative to the pending data + // needs to instead be relative to the primary data. That amounts to + // computing the absolute offset of each pending field, and then + // subtracting off the absolute offset of the primary data. + // + // We will compute the offset that needs to be added up front, + // and store it in the form of a `VarLayout`. The offsets we need + // can be computed from the `outerVarChain`, and we only need to + // store offset information for resource kinds actually consumed + // by the pending data type for the buffer as a whole (e.g., we + // don't need to apply offsetting to uniform bytes, because + // those don't show up in the resource usage of a constant buffer + // itself, and so the offsets already *are* relative to the start + // of the buffer). + // + auto offsetVarLayout = _createOffsetVarLayout(outerVarChain, oldTypeLayout->pendingDataTypeLayout); + LegalVarChainLink offsetVarChain(LegalVarChain(), offsetVarLayout); + + // We will start our construction of the pieces of the output + // type layout by looking at the "container" type/variable. + // + // A parameter block or constant buffer in Slang needs to + // distinguish between the resource usage of the thing in + // the block/buffer, vs. the resource usage of the block/buffer + // itself. Consider: + // + // struct Material { float4 color; Texture2D tex; } + // ConstantBuffer gMat; + // + // When compiling for Vulkan, the `gMat` constant buffer needs + // a `binding`, and the `tex` field does too, so the overall + // resource usage of `gMat` is two bindings, but we need a + // way to encode which of those bindings goes to `gMat.tex` + // and which to the constant buffer for `gMat` itself. + // + { + // We will start by extracting the "primary" part of the old + // container type/var layout, and constructing new objects + // that will represent the layout for our wrapped buffer. + // + auto oldPrimaryContainerVarLayout = oldParameterGroupTypeLayout->containerVarLayout; + auto oldPrimaryContainerTypeLayout = oldPrimaryContainerVarLayout->typeLayout; + + RefPtr newContainerTypeLayout = new TypeLayout(); + newContainerTypeLayout->type = oldPrimaryContainerTypeLayout->type; + + RefPtr newContainerVarLayout = new VarLayout(); + newContainerVarLayout->typeLayout = newContainerTypeLayout; + + newTypeLayout->containerVarLayout = newContainerVarLayout; + + // Whatever got allocated for the primary container should get copied + // over to the new layout (e.g., if we allocated a constant buffer + // for `gMat` then we need to retain that information). + // + newContainerTypeLayout->addResourceUsageFrom(oldPrimaryContainerTypeLayout); + for( auto resInfo : oldPrimaryContainerVarLayout->resourceInfos ) + { + auto newResInfo = newContainerVarLayout->findOrAddResourceInfo(resInfo.kind); + newResInfo->index = resInfo.index; + newResInfo->space = resInfo.space; + } + + // It is possible that a constant buffer and/or space didn't get + // allocated for the "primary" data, but ended up being required for + // the "pending" data (this would happen if, e.g., a constant buffer + // didn't appear to have any uniform data in it, but then once we + // plugged in concrete types for interface fields it did...), so + // we need to account for that case and copy over the relevant + // resource usage from the pending data, if there is any. + // + if( auto oldPendingContainerVarLayout = oldPrimaryContainerVarLayout->pendingVarLayout ) + { + // Whatever resources were allocated for the pending data type, + // our new combined container type needs to account for them + // (e.g., if we didn't have a constant buffer in the primary + // data, but one got allocated in the pending data, we need + // to end up with type layout information that includes a + // constnat buffer). + // + auto oldPendingContainerTypeLayout = oldPendingContainerVarLayout->typeLayout; + newContainerTypeLayout->addResourceUsageFrom(oldPendingContainerTypeLayout); + + // We also need to add offset information based on the "pending" + // var layout, but we need to deal with the fact that this information + // is currently stored relative to the pending var layout for the surrounding + // context (passed in as `outerVarChain.pendingChain`), but we need it to be + // relative to the primary layout for the surrounding context (`outerVarChain.primaryChain`). + // This is where the `offsetVarLayout` we computed above comes + // in handy, because it represents the value(s) we need to + // add to each of the per-resource-kind offsets. + // + for( auto resInfo : oldPendingContainerVarLayout->resourceInfos ) + { + addOffsetResInfo(newContainerVarLayout, resInfo, offsetVarLayout); + } + } + } + + // Now that we've dealt with the container variable, we can turn + // our attention to the element type. This is the part that + // actually got legalized and required us to create a "wrapped" + // buffer type in the first place, so we know that it will + // have both primary and "pending" parts. + // + // Let's start by extracting the fields we care about from + // the original element type/var layout, and constructing + // the objects we'll use to represent the type/var layout for + // the new element type. + // + auto oldElementVarLayout = oldParameterGroupTypeLayout->elementVarLayout; + auto oldElementTypeLayout = oldElementVarLayout->typeLayout; + + // Now matter what, the element type of a wrapped buffer + // will always have a structure type. + // + RefPtr newElementTypeLayout = new StructTypeLayout(); + newElementTypeLayout->type = oldElementTypeLayout->type; + + // The `wrappedBufferTypeInfo` that was passed in tells + // us how the fields of the original type got turned into + // zero or more fields in the new element type, so we + // need to follow its recursive structure to build + // layout information for each of the new fields. + // + // We will track a "chain" of parent variables that + // determines how we got to each leaf field, and is + // used to add up the offsets that will be stored + // in the new `VarLayout`s that get created. + // We know we need to add in some offsets (usually + // negative) to any fields that were pending data, + // so we will account for that in the initial + // chain of outer variables that we pass in. + // + LegalVarChain varChainForElementType; + varChainForElementType.primaryChain = nullptr; + varChainForElementType.pendingChain = offsetVarChain.primaryChain; + + _addFieldsToWrappedBufferElementTypeLayout( + oldElementTypeLayout, + newElementTypeLayout, + wrappedBufferTypeInfo->elementInfo, + varChainForElementType, + true); + + // A parameter group type layout holds a `VarLayout` for the element type, + // which encodes the offset of the element type with respect to the + // start of the parameter group as a whole (e.g., to handle the case + // where a constant buffer needs a `binding`, and so does its + // element type, so the offset to the first `binding` for the element + // type is one, not zero. + // + LegalVarChainLink elementVarChain(LegalVarChain(), oldParameterGroupTypeLayout->elementVarLayout); + auto newElementVarLayout = createVarLayout(elementVarChain, newElementTypeLayout); + newTypeLayout->elementVarLayout = newElementVarLayout; + + // For legacy/API reasons, we also need to compute a version of the + // element type where the offset stored in the `elementVarLayout` + // gets "baked in" to the fields of the element type. + // + newTypeLayout->offsetElementTypeLayout = applyOffsetToTypeLayout( + newElementTypeLayout, + newElementVarLayout); + + return newTypeLayout; +} + +static LegalVal declareVars( + IRTypeLegalizationContext* context, + IROp op, + LegalType type, + TypeLayout* inTypeLayout, + LegalVarChain const& inVarChain, + UnownedStringSlice nameHint, + IRInst* leafVar, + IRGlobalNameInfo* globalNameInfo, + bool isSpecial) +{ + LegalVarChain varChain = inVarChain; + TypeLayout* typeLayout = inTypeLayout; + if( isSpecial ) + { + if( varChain.pendingChain ) + { + varChain.primaryChain = varChain.pendingChain; + varChain.pendingChain = nullptr; + } + if( typeLayout ) + { + if( auto pendingTypeLayout = typeLayout->pendingDataTypeLayout ) + { + typeLayout = pendingTypeLayout; + } + } + } + + switch (type.flavor) + { + case LegalType::Flavor::none: + return LegalVal(); + + case LegalType::Flavor::simple: + return declareSimpleVar(context, op, type.getSimple(), typeLayout, varChain, nameHint, leafVar, globalNameInfo); + break; + + case LegalType::Flavor::implicitDeref: + { + // Just declare a variable of the pointed-to type, + // since we are removing the indirection. + + auto val = declareVars( + context, + op, + type.getImplicitDeref()->valueType, + typeLayout, + varChain, + nameHint, + leafVar, + globalNameInfo, + isSpecial); + return LegalVal::implicitDeref(val); + } + break; + + case LegalType::Flavor::pair: + { + auto pairType = type.getPair(); + auto ordinaryVal = declareVars(context, op, pairType->ordinaryType, typeLayout, varChain, nameHint, leafVar, globalNameInfo, false); + auto specialVal = declareVars(context, op, pairType->specialType, typeLayout, varChain, nameHint, leafVar, globalNameInfo, true); + return LegalVal::pair(ordinaryVal, specialVal, pairType->pairInfo); + } + + case LegalType::Flavor::tuple: + { + // Declare one variable for each element of the tuple + auto tupleType = type.getTuple(); + + RefPtr tupleVal = new TuplePseudoVal(); + + for (auto ee : tupleType->elements) + { + auto fieldLayout = getFieldLayout(typeLayout, ee.key); + RefPtr fieldTypeLayout = fieldLayout ? fieldLayout->typeLayout : nullptr; + + // If we have a type layout coming in, we really expect to have a layout for each field. + SLANG_ASSERT(fieldLayout || !typeLayout); + + // If we are processing layout information, then + // we need to create a new link in the chain + // of variables that will determine offsets + // for the eventual leaf fields... + // + LegalVarChainLink newVarChain(varChain, fieldLayout); + + UnownedStringSlice fieldNameHint; + String joinedNameHintStorage; + if( nameHint.size() ) + { + if( auto fieldNameHintDecoration = ee.key->findDecoration() ) + { + joinedNameHintStorage.append(nameHint); + joinedNameHintStorage.append("."); + joinedNameHintStorage.append(fieldNameHintDecoration->getName()); + + fieldNameHint = joinedNameHintStorage.getUnownedSlice(); + } + + } + + LegalVal fieldVal = declareVars( + context, + op, + ee.type, + fieldTypeLayout, + newVarChain, + fieldNameHint, + ee.key, + globalNameInfo, + true); + + TuplePseudoVal::Element element; + element.key = ee.key; + element.val = fieldVal; + tupleVal->elements.add(element); + } + + return LegalVal::tuple(tupleVal); + } + break; + + case LegalType::Flavor::wrappedBuffer: + { + auto wrappedBuffer = type.getWrappedBuffer(); + + auto wrappedTypeLayout = _createWrappedBufferTypeLayout(typeLayout, wrappedBuffer, varChain); + + auto innerVal = declareSimpleVar( + context, + op, + wrappedBuffer->simpleType, + wrappedTypeLayout, + varChain, + nameHint, + leafVar, + globalNameInfo); + + return LegalVal::wrappedBuffer(innerVal, wrappedBuffer->elementInfo); + } + + default: + SLANG_UNEXPECTED("unhandled"); + UNREACHABLE_RETURN(LegalVal()); + break; + } +} + +static LegalVal legalizeGlobalVar( + IRTypeLegalizationContext* context, + IRGlobalVar* irGlobalVar) +{ + // Legalize the type for the variable's value + auto originalValueType = irGlobalVar->getDataType()->getValueType(); + auto legalValueType = legalizeType( + context, + originalValueType); + + switch (legalValueType.flavor) + { + case LegalType::Flavor::simple: + // Easy case: the type is usable as-is, and we + // should just do that. + context->builder->setDataType( + irGlobalVar, + context->builder->getPtrType( + legalValueType.getSimple())); + return LegalVal::simple(irGlobalVar); + + default: + { + context->insertBeforeGlobal = irGlobalVar->getNextInst(); + + IRGlobalNameInfo globalNameInfo; + globalNameInfo.globalVar = irGlobalVar; + globalNameInfo.counter = 0; + + UnownedStringSlice nameHint = findNameHint(irGlobalVar); + context->builder->setInsertBefore(irGlobalVar); + LegalVal newVal = declareVars(context, kIROp_GlobalVar, legalValueType, nullptr, LegalVarChain(), nameHint, irGlobalVar, &globalNameInfo, context->isSpecialType(originalValueType)); + + // Register the new value as the replacement for the old + registerLegalizedValue(context, irGlobalVar, newVal); + + // Remove the old global from the module. + irGlobalVar->removeFromParent(); + context->replacedInstructions.add(irGlobalVar); + + return newVal; + } + break; + } +} + +static LegalVal legalizeGlobalConstant( + IRTypeLegalizationContext* context, + IRGlobalConstant* irGlobalConstant) +{ + // Legalize the type for the variable's value + auto legalValueType = legalizeType( + context, + irGlobalConstant->getFullType()); + + switch (legalValueType.flavor) + { + case LegalType::Flavor::simple: + // Easy case: the type is usable as-is, and we + // should just do that. + irGlobalConstant->setFullType(legalValueType.getSimple()); + return LegalVal::simple(irGlobalConstant); + + default: + { + context->insertBeforeGlobal = irGlobalConstant->getNextInst(); + + IRGlobalNameInfo globalNameInfo; + globalNameInfo.globalVar = irGlobalConstant; + globalNameInfo.counter = 0; + + // TODO: need to handle initializer here! + + UnownedStringSlice nameHint = findNameHint(irGlobalConstant); + context->builder->setInsertBefore(irGlobalConstant); + LegalVal newVal = declareVars(context, kIROp_GlobalConstant, legalValueType, nullptr, LegalVarChain(), nameHint, irGlobalConstant, &globalNameInfo, context->isSpecialType(irGlobalConstant->getDataType())); + + // Register the new value as the replacement for the old + registerLegalizedValue(context, irGlobalConstant, newVal); + + // Remove the old global from the module. + irGlobalConstant->removeFromParent(); + context->replacedInstructions.add(irGlobalConstant); + + return newVal; + } + break; + } +} + +static LegalVal legalizeGlobalParam( + IRTypeLegalizationContext* context, + IRGlobalParam* irGlobalParam) +{ + // Legalize the type for the variable's value + auto legalValueType = legalizeType( + context, + irGlobalParam->getFullType()); + + RefPtr varLayout = findVarLayout(irGlobalParam); + RefPtr typeLayout = varLayout ? varLayout->typeLayout : nullptr; + + switch (legalValueType.flavor) + { + case LegalType::Flavor::simple: + // Easy case: the type is usable as-is, and we + // should just do that. + irGlobalParam->setFullType(legalValueType.getSimple()); + return LegalVal::simple(irGlobalParam); + + default: + { + context->insertBeforeGlobal = irGlobalParam->getNextInst(); + + LegalVarChainLink varChain(LegalVarChain(), varLayout); + + IRGlobalNameInfo globalNameInfo; + globalNameInfo.globalVar = irGlobalParam; + globalNameInfo.counter = 0; + + // TODO: need to handle initializer here! + + UnownedStringSlice nameHint = findNameHint(irGlobalParam); + context->builder->setInsertBefore(irGlobalParam); + LegalVal newVal = declareVars(context, kIROp_GlobalParam, legalValueType, typeLayout, varChain, nameHint, irGlobalParam, &globalNameInfo, context->isSpecialType(irGlobalParam->getDataType())); + + // Register the new value as the replacement for the old + registerLegalizedValue(context, irGlobalParam, newVal); + + // Remove the old global from the module. + irGlobalParam->removeFromParent(); + context->replacedInstructions.add(irGlobalParam); + + return newVal; + } + break; + } +} + + +static void legalizeTypes( + IRTypeLegalizationContext* context) +{ + // Legalize all the top-level instructions in the module + auto module = context->module; + legalizeInstsInParent(context, module->moduleInst); + + // Clean up after any instructions we replaced along the way. + for (auto& lv : context->replacedInstructions) + { + lv->removeAndDeallocate(); + } +} + +// We use the same basic type legalization machinery for both simplifying +// away resource-type fields nested in `struct`s and for shuffling around +// exisential-box fields to get the layout right. +// +// The differences between the two passes come down to some very small +// distinctions about what types each pass considers "special" (e.g., +// resources in one case and existential boxes in the other), along +// with what they want to do when a uniform/constant buffer needs to +// be made where the element type is non-simple (that is, includes +// some fields of "special" type). +// +// The resource case is then the simpler one: +// +struct IRResourceTypeLegalizationContext : IRTypeLegalizationContext +{ + IRResourceTypeLegalizationContext(IRModule* module) + : IRTypeLegalizationContext(module) + {} + + bool isSpecialType(IRType* type) override + { + // For resource type legalization, the "special" types + // we are working with are resource types. + // + return isResourceType(type); + } + + LegalType createLegalUniformBufferType( + IROp op, + LegalType legalElementType) override + { + // The appropriate strategy for legalizing uniform buffers + // with resources inside already exists, so we can delegate to it. + // + return createLegalUniformBufferTypeForResources( + this, + op, + legalElementType); + } +}; + +// The case for legalizing existential box types is then similar. +// +struct IRExistentialTypeLegalizationContext : IRTypeLegalizationContext +{ + IRExistentialTypeLegalizationContext(IRModule* module) + : IRTypeLegalizationContext(module) + {} + + bool isSpecialType(IRType* inType) override + { + // The "special" types for our purposes are existential + // boxes, or arrays thereof. + // + auto type = unwrapArray(inType); + return as(type) != nullptr; + } + + LegalType createLegalUniformBufferType( + IROp op, + LegalType legalElementType) override + { + // We'll delegate the logic for creating uniform buffers + // over a mix of ordinary and existential-box types to + // a subroutine so it can live near the resource case. + // + // TODO: We should eventually try to refactor this code + // so that related functionality is grouped together. + // + return createLegalUniformBufferTypeForExistentials( + this, + op, + legalElementType); + } +}; + +// The main entry points that are used when transforming IR code +// to get it ready for lower-level codegen are then simple +// wrappers around `legalizeTypes()` that pick an appropriately +// specialized context type to use to get the job done. + +void legalizeResourceTypes( + IRModule* module, + DiagnosticSink* sink) +{ + SLANG_UNUSED(sink); + + IRResourceTypeLegalizationContext context(module); + legalizeTypes(&context); +} + +void legalizeExistentialTypeLayout( + IRModule* module, + DiagnosticSink* sink) +{ + SLANG_UNUSED(module); + SLANG_UNUSED(sink); + + IRExistentialTypeLegalizationContext context(module); + legalizeTypes(&context); +} + + +} diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp new file mode 100644 index 000000000..4c1f72adb --- /dev/null +++ b/source/slang/slang-ir-link.cpp @@ -0,0 +1,1361 @@ +// slang-ir-link.cpp +#include "slang-ir-link.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-mangle.h" + +namespace Slang +{ + +// Needed for lookup up entry-point layouts. +// +// TODO: maybe arrange so that codegen is driven from the layout layer +// instead of the input/request layer. +EntryPointLayout* findEntryPointLayout( + ProgramLayout* programLayout, + EntryPoint* EntryPoint); + +struct IRSpecSymbol : RefObject +{ + IRInst* irGlobalValue; + RefPtr nextWithSameName; +}; + +struct IRSpecEnv +{ + IRSpecEnv* parent = nullptr; + + // A map from original values to their cloned equivalents. + typedef Dictionary ClonedValueDictionary; + ClonedValueDictionary clonedValues; +}; + +struct IRSharedSpecContext +{ + // The code-generation target in use + CodeGenTarget target; + + // The specialized module we are building + RefPtr module; + + // A map from mangled symbol names to zero or + // more global IR values that have that name, + // in the *original* module. + typedef Dictionary> SymbolDictionary; + SymbolDictionary symbols; + + SharedIRBuilder sharedBuilderStorage; + IRBuilder builderStorage; + + // The "global" specialization environment. + IRSpecEnv globalEnv; +}; + +struct IRSpecContextBase +{ + // A map from the mangled name of a global variable + // to the layout to use for it. + Dictionary globalVarLayouts; + + IRSharedSpecContext* shared; + + IRSharedSpecContext* getShared() { return shared; } + + IRModule* getModule() { return getShared()->module; } + + IRSharedSpecContext::SymbolDictionary& getSymbols() { return getShared()->symbols; } + + // The current specialization environment to use. + IRSpecEnv* env = nullptr; + IRSpecEnv* getEnv() + { + // TODO: need to actually establish environments on contexts we create. + // + // Or more realistically we need to change the whole approach + // to specialization and cloning so that we don't try to share + // logic between two very different cases. + + + return env; + } + + // The IR builder to use for creating nodes + IRBuilder* builder; + + // A callback to be used when a value that is not registerd in `clonedValues` + // is needed during cloning. This gives the subtype a chance to intercept + // the operation and clone (or not) as needed. + virtual IRInst* maybeCloneValue(IRInst* originalVal) + { + return originalVal; + } +}; + +void registerClonedValue( + IRSpecContextBase* context, + IRInst* clonedValue, + IRInst* originalValue) +{ + if(!originalValue) + return; + + // TODO: now that things are scoped using environments, we + // shouldn't be running into the cases where a value with + // the same key already exists. This should be changed to + // an `Add()` call. + // + context->getEnv()->clonedValues[originalValue] = clonedValue; +} + +// Information on values to use when registering a cloned value +struct IROriginalValuesForClone +{ + IRInst* originalVal = nullptr; + IRSpecSymbol* sym = nullptr; + + IROriginalValuesForClone() {} + + IROriginalValuesForClone(IRInst* originalValue) + : originalVal(originalValue) + {} + + IROriginalValuesForClone(IRSpecSymbol* symbol) + : sym(symbol) + {} +}; + +void registerClonedValue( + IRSpecContextBase* context, + IRInst* clonedValue, + IROriginalValuesForClone const& originalValues) +{ + registerClonedValue(context, clonedValue, originalValues.originalVal); + for( auto s = originalValues.sym; s; s = s->nextWithSameName ) + { + registerClonedValue(context, clonedValue, s->irGlobalValue); + } +} + +IRInst* cloneInst( + IRSpecContextBase* context, + IRBuilder* builder, + IRInst* originalInst, + IROriginalValuesForClone const& originalValues); + +IRInst* cloneInst( + IRSpecContextBase* context, + IRBuilder* builder, + IRInst* originalInst) +{ + return cloneInst(context, builder, originalInst, originalInst); +} + + /// Clone any decorations from `originalValue` onto `clonedValue` +void cloneDecorations( + IRSpecContextBase* context, + IRInst* clonedValue, + IRInst* originalValue) +{ + // TODO: In many cases we might be able to use this as a general-purpose + // place to do cloning of *all* the children of an instruction, and + // not just its decorations. We should look to refactor this code + // later. + + IRBuilder builderStorage = *context->builder; + IRBuilder* builder = &builderStorage; + builder->setInsertInto(clonedValue); + + + SLANG_UNUSED(context); + for(auto originalDecoration : originalValue->getDecorations()) + { + cloneInst(context, builder, originalDecoration); + } + + // We will also clone the location here, just because this is a convenient bottleneck + clonedValue->sourceLoc = originalValue->sourceLoc; +} + + /// Clone any decorations and children from `originalValue` onto `clonedValue` +void cloneDecorationsAndChildren( + IRSpecContextBase* context, + IRInst* clonedValue, + IRInst* originalValue) +{ + IRBuilder builderStorage = *context->builder; + IRBuilder* builder = &builderStorage; + builder->setInsertInto(clonedValue); + + SLANG_UNUSED(context); + for(auto originalItem : originalValue->getDecorationsAndChildren()) + { + cloneInst(context, builder, originalItem); + } + + // We will also clone the location here, just because this is a convenient bottleneck + clonedValue->sourceLoc = originalValue->sourceLoc; +} + +// We use an `IRSpecContext` for the case where we are cloning +// code from one or more input modules to create a "linked" output +// module. Along the way, we will resolve profile-specific functions +// to the best definition for a given target. +// +struct IRSpecContext : IRSpecContextBase +{ + // Override the "maybe clone" logic so that we always clone + virtual IRInst* maybeCloneValue(IRInst* originalVal) override; +}; + + +IRInst* cloneGlobalValue(IRSpecContext* context, IRInst* originalVal); + +IRInst* cloneValue( + IRSpecContextBase* context, + IRInst* originalValue); + +IRType* cloneType( + IRSpecContextBase* context, + IRType* originalType); + +IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue) +{ + switch (originalValue->op) + { + case kIROp_StructType: + case kIROp_Func: + case kIROp_Generic: + case kIROp_GlobalVar: + case kIROp_GlobalConstant: + case kIROp_GlobalParam: + case kIROp_StructKey: + case kIROp_GlobalGenericParam: + case kIROp_WitnessTable: + return cloneGlobalValue(this, originalValue); + + case kIROp_BoolLit: + { + IRConstant* c = (IRConstant*)originalValue; + return builder->getBoolValue(c->value.intVal != 0); + } + break; + + + case kIROp_IntLit: + { + IRConstant* c = (IRConstant*)originalValue; + return builder->getIntValue(cloneType(this, c->getDataType()), c->value.intVal); + } + break; + + case kIROp_FloatLit: + { + IRConstant* c = (IRConstant*)originalValue; + return builder->getFloatValue(cloneType(this, c->getDataType()), c->value.floatVal); + } + break; + + case kIROp_StringLit: + { + IRConstant* c = (IRConstant*)originalValue; + return builder->getStringValue(c->getStringSlice()); + } + break; + + case kIROp_PtrLit: + { + IRConstant* c = (IRConstant*)originalValue; + return builder->getPtrValue(c->value.ptrVal); + } + break; + + default: + { + // In the deafult case, assume that we have some sort of "hoistable" + // instruction that requires us to create a clone of it. + + UInt argCount = originalValue->getOperandCount(); + IRInst* clonedValue = builder->createIntrinsicInst( + cloneType(this, originalValue->getFullType()), + originalValue->op, + argCount, nullptr); + registerClonedValue(this, clonedValue, originalValue); + for (UInt aa = 0; aa < argCount; ++aa) + { + IRInst* originalArg = originalValue->getOperand(aa); + IRInst* clonedArg = cloneValue(this, originalArg); + clonedValue->getOperands()[aa].init(clonedValue, clonedArg); + } + cloneDecorationsAndChildren(this, clonedValue, originalValue); + + addHoistableInst(builder, clonedValue); + + return clonedValue; + } + break; + } +} + +IRInst* cloneValue( + IRSpecContextBase* context, + IRInst* originalValue); + +// Find a pre-existing cloned value, or return null if none is available. +IRInst* findClonedValue( + IRSpecContextBase* context, + IRInst* originalValue) +{ + IRInst* clonedValue = nullptr; + for (auto env = context->getEnv(); env; env = env->parent) + { + if (env->clonedValues.TryGetValue(originalValue, clonedValue)) + { + return clonedValue; + } + } + + return nullptr; +} + +IRInst* cloneValue( + IRSpecContextBase* context, + IRInst* originalValue) +{ + if (!originalValue) + return nullptr; + + if (IRInst* clonedValue = findClonedValue(context, originalValue)) + return clonedValue; + + return context->maybeCloneValue(originalValue); +} + +IRType* cloneType( + IRSpecContextBase* context, + IRType* originalType) +{ + return (IRType*)cloneValue(context, originalType); +} + +void cloneGlobalValueWithCodeCommon( + IRSpecContextBase* context, + IRGlobalValueWithCode* clonedValue, + IRGlobalValueWithCode* originalValue); + +IRRate* cloneRate( + IRSpecContextBase* context, + IRRate* rate) +{ + return (IRRate*) cloneType(context, rate); +} + +void maybeSetClonedRate( + IRSpecContextBase* context, + IRBuilder* builder, + IRInst* clonedValue, + IRInst* originalValue) +{ + if(auto rate = originalValue->getRate() ) + { + clonedValue->setFullType(builder->getRateQualifiedType( + cloneRate(context, rate), + clonedValue->getFullType())); + } +} + +IRGlobalVar* cloneGlobalVarImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IRGlobalVar* originalVar, + IROriginalValuesForClone const& originalValues) +{ + auto clonedVar = builder->createGlobalVar( + cloneType(context, originalVar->getDataType()->getValueType())); + + maybeSetClonedRate(context, builder, clonedVar, originalVar); + + registerClonedValue(context, clonedVar, originalValues); + + // Clone any code in the body of the variable, since this + // represents the initializer. + cloneGlobalValueWithCodeCommon( + context, + clonedVar, + originalVar); + + return clonedVar; +} + +IRGlobalConstant* cloneGlobalConstantImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IRGlobalConstant* originalVal, + IROriginalValuesForClone const& originalValues) +{ + auto clonedVal = builder->createGlobalConstant( + cloneType(context, originalVal->getFullType())); + registerClonedValue(context, clonedVal, originalValues); + + // Clone any code in the body of the constant, since this + // represents the initializer. + cloneGlobalValueWithCodeCommon( + context, + clonedVal, + originalVal); + + return clonedVal; +} + +void cloneSimpleGlobalValueImpl( + IRSpecContextBase* context, + IRInst* originalInst, + IROriginalValuesForClone const& originalValues, + IRInst* clonedInst, + bool registerValue = true) +{ + if (registerValue) + registerClonedValue(context, clonedInst, originalValues); + + // Set up an IR builder for inserting into the inst + IRBuilder builderStorage = *context->builder; + IRBuilder* builder = &builderStorage; + builder->setInsertInto(clonedInst); + + // Clone any children of the instruction + for (auto child : originalInst->getDecorationsAndChildren()) + { + cloneInst(context, builder, child); + } +} + +IRGlobalParam* cloneGlobalParamImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IRGlobalParam* originalVal, + IROriginalValuesForClone const& originalValues) +{ + auto clonedVal = builder->createGlobalParam( + cloneType(context, originalVal->getFullType())); + cloneSimpleGlobalValueImpl(context, originalVal, originalValues, clonedVal); + + if(auto linkage = originalVal->findDecoration()) + { + auto mangledName = String(linkage->getMangledName()); + VarLayout* layout = nullptr; + if (context->globalVarLayouts.TryGetValue(mangledName, layout)) + { + builder->addLayoutDecoration(clonedVal, layout); + } + } + + return clonedVal; +} + +IRGeneric* cloneGenericImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IRGeneric* originalVal, + IROriginalValuesForClone const& originalValues) +{ + auto clonedVal = builder->emitGeneric(); + registerClonedValue(context, clonedVal, originalValues); + + // Clone any code in the body of the generic, since this + // computes its result value. + cloneGlobalValueWithCodeCommon( + context, + clonedVal, + originalVal); + + return clonedVal; +} + +IRStructKey* cloneStructKeyImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IRStructKey* originalVal, + IROriginalValuesForClone const& originalValues) +{ + auto clonedVal = builder->createStructKey(); + cloneSimpleGlobalValueImpl(context, originalVal, originalValues, clonedVal); + return clonedVal; +} + +IRGlobalGenericParam* cloneGlobalGenericParamImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IRGlobalGenericParam* originalVal, + IROriginalValuesForClone const& originalValues) +{ + auto clonedVal = builder->emitGlobalGenericParam(); + cloneSimpleGlobalValueImpl(context, originalVal, originalValues, clonedVal); + return clonedVal; +} + + +IRWitnessTable* cloneWitnessTableImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IRWitnessTable* originalTable, + IROriginalValuesForClone const& originalValues, + IRWitnessTable* dstTable = nullptr, + bool registerValue = true) +{ + auto clonedTable = dstTable ? dstTable : builder->createWitnessTable(); + cloneSimpleGlobalValueImpl(context, originalTable, originalValues, clonedTable, registerValue); + return clonedTable; +} + +IRWitnessTable* cloneWitnessTableWithoutRegistering( + IRSpecContextBase* context, + IRBuilder* builder, + IRWitnessTable* originalTable, + IRWitnessTable* dstTable = nullptr) +{ + return cloneWitnessTableImpl(context, builder, originalTable, IROriginalValuesForClone(), dstTable, false); +} + +IRStructType* cloneStructTypeImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IRStructType* originalStruct, + IROriginalValuesForClone const& originalValues) +{ + auto clonedStruct = builder->createStructType(); + cloneSimpleGlobalValueImpl(context, originalStruct, originalValues, clonedStruct); + return clonedStruct; +} + + +IRInterfaceType* cloneInterfaceTypeImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IRInterfaceType* originalInterface, + IROriginalValuesForClone const& originalValues) +{ + auto clonedInterface = builder->createInterfaceType(); + cloneSimpleGlobalValueImpl(context, originalInterface, originalValues, clonedInterface); + return clonedInterface; +} + +void cloneGlobalValueWithCodeCommon( + IRSpecContextBase* context, + IRGlobalValueWithCode* clonedValue, + IRGlobalValueWithCode* originalValue) +{ + // Next we are going to clone the actual code. + IRBuilder builderStorage = *context->builder; + IRBuilder* builder = &builderStorage; + builder->setInsertInto(clonedValue); + + cloneDecorations(context, clonedValue, originalValue); + + // We will walk through the blocks of the function, and clone each of them. + // + // We need to create the cloned blocks first, and then walk through them, + // because blocks might be forward referenced (this is not possible + // for other cases of instructions). + for (auto originalBlock = originalValue->getFirstBlock(); + originalBlock; + originalBlock = originalBlock->getNextBlock()) + { + IRBlock* clonedBlock = builder->createBlock(); + clonedValue->addBlock(clonedBlock); + registerClonedValue(context, clonedBlock, originalBlock); + +#if 0 + // We can go ahead and clone parameters here, while we are at it. + builder->curBlock = clonedBlock; + for (auto originalParam = originalBlock->getFirstParam(); + originalParam; + originalParam = originalParam->getNextParam()) + { + IRParam* clonedParam = builder->emitParam( + context->maybeCloneType( + originalParam->getFullType())); + cloneDecorations(context, clonedParam, originalParam); + registerClonedValue(context, clonedParam, originalParam); + } +#endif + } + + // Okay, now we are in a good position to start cloning + // the instructions inside the blocks. + { + IRBlock* ob = originalValue->getFirstBlock(); + IRBlock* cb = clonedValue->getFirstBlock(); + while (ob) + { + SLANG_ASSERT(cb); + + builder->setInsertInto(cb); + for (auto oi = ob->getFirstInst(); oi; oi = oi->getNextInst()) + { + cloneInst(context, builder, oi); + } + + ob = ob->getNextBlock(); + cb = cb->getNextBlock(); + } + } + +} + +void checkIRDuplicate(IRInst* inst, IRInst* moduleInst, UnownedStringSlice const& mangledName) +{ +#ifdef _DEBUG + for (auto child : moduleInst->getDecorationsAndChildren()) + { + if (child == inst) + continue; + + if(auto childLinkage = child->findDecoration()) + { + if(mangledName == childLinkage->getMangledName()) + { + SLANG_UNEXPECTED("duplicate global instruction"); + } + } + } +#else + SLANG_UNREFERENCED_PARAMETER(inst); + SLANG_UNREFERENCED_PARAMETER(moduleInst); + SLANG_UNREFERENCED_PARAMETER(mangledName); +#endif +} + +void cloneFunctionCommon( + IRSpecContextBase* context, + IRFunc* clonedFunc, + IRFunc* originalFunc, + bool checkDuplicate = true) +{ + // First clone all the simple properties. + clonedFunc->setFullType(cloneType(context, originalFunc->getFullType())); + + cloneGlobalValueWithCodeCommon( + context, + clonedFunc, + originalFunc); + + // Shuffle the function to the end of the list, because + // it needs to follow its dependencies. + // + // TODO: This isn't really a good requirement to place on the IR... + clonedFunc->moveToEnd(); + + if( checkDuplicate ) + { + if( auto linkage = clonedFunc->findDecoration() ) + { + checkIRDuplicate(clonedFunc, context->getModule()->getModuleInst(), linkage->getMangledName()); + } + } +} + +// We will forward-declare the subroutine for eagerly specializing +// an IR-level generic to argument values, because `specializeIRForEntryPoint` +// needs to perform this operation even though it is logically part of +// the later generic specialization pass. +// +IRInst* specializeGeneric( + IRSpecialize* specializeInst); + +IRFunc* specializeIRForEntryPoint( + IRSpecContext* context, + EntryPoint* entryPoint, + EntryPointLayout* entryPointLayout) +{ + // We start by looking up the IR symbol that + // matches the mangled name given to the + // function we want to emit. + // + // Note: the function decl-ref may refer to + // a specialization of a generic function, + // so that the mangled name of the decl-ref is + // not the same as the mangled name of the decl. + // + auto mangledName = getMangledName(entryPoint->getFuncDeclRef()); + RefPtr sym; + if (!context->getSymbols().TryGetValue(mangledName, sym)) + { + SLANG_UNEXPECTED("no matching IR symbol"); + return nullptr; + } + + // TODO: deal with the case where we might + // have multiple (profile-overloaded) versions... + // + auto originalVal = sym->irGlobalValue; + + // We will start by cloning the entry point reference + // like any other global value. + // + auto clonedVal = cloneGlobalValue(context, originalVal); + + // In the case where the user is requesting a specialization + // of a generic entry point, we have a bit of a problem. + // + // This function is expected to return an `IRFunc` and + // subsequent passes expect to find, e.g., layout information + // attached to the parameters of such a func. + // + // In the generic case, the `clonedValue` won't be an + // `IRFunc`, but instead an `IRSpecialize`. + // + if(auto clonedSpec = as(clonedVal)) + { + // The Right Thing to do here is to perform some + // amount of generic specialization, at least + // until we get back an `IRFunc`. + // + // The dangerous thing is that the generic specialization + // pass can, in principle, change the signature of + // functions, so that attaching parameter layout + // information *after* specialization might not work. + // + // The compromise we make here is to directly + // invoke the logic for specializing a generic. + // + // In theory this isn't valid, because there is no + // way we can register the specialized function we + // create so that it would be re-used by other instantiations + // with the same arguments (because we cannot be + // sure the generic arguments are themselves fully specialized) + // + // In practice this isn't really a problem, because + // we don't want to share the definition between + // an entry point and an ordinary function anyway. + // + clonedVal = specializeGeneric(clonedSpec); + } + + // TODO: If there is an existential-related decoration + // on the entry point, we need to transfer it over + // to the specialized function. + if( auto bindExistentialSlots = originalVal->findDecorationImpl(kIROp_BindExistentialSlotsDecoration) ) + { + if( !clonedVal->findDecorationImpl(kIROp_BindExistentialSlotsDecoration) ) + { + IRBuilder builderStorage = *context->builder; + IRBuilder* builder = &builderStorage; + builder->setInsertInto(clonedVal); + + auto clonedBind = cloneInst(context, builder, bindExistentialSlots); + clonedBind->moveToStart(); + } + } + + + auto clonedFunc = as(clonedVal); + if(!clonedFunc) + { + SLANG_UNEXPECTED("expected entry point to be a function"); + return nullptr; + } + + if( !clonedFunc->findDecorationImpl(kIROp_KeepAliveDecoration) ) + { + context->builder->addKeepAliveDecoration(clonedFunc); + } + + // We need to attach the layout information for + // the entry point to this declaration, so that + // we can use it to inform downstream code emit. + // + context->builder->addLayoutDecoration( + clonedFunc, + entryPointLayout); + + // We will also go on and attach layout information + // to the function parameters, so that we have it + // available directly on the parameters, rather + // than having to look it up on the original entry-point layout. + if( auto firstBlock = clonedFunc->getFirstBlock() ) + { + auto paramsStructLayout = getScopeStructLayout(entryPointLayout); + Index paramLayoutCount = paramsStructLayout->fields.getCount(); + Index paramCounter = 0; + for( auto pp = firstBlock->getFirstParam(); pp; pp = pp->getNextParam() ) + { + Index paramIndex = paramCounter++; + if( paramIndex < paramLayoutCount ) + { + auto paramLayout = paramsStructLayout->fields[paramIndex]; + context->builder->addLayoutDecoration( + pp, + paramLayout); + } + else + { + SLANG_UNEXPECTED("too many parameters"); + } + } + } + + return clonedFunc; +} + +// Get a string form of the target so that we can +// use it to match against target-specialization modifiers +// +// TODO: We shouldn't be using strings for this. +String getTargetName(IRSpecContext* context) +{ + switch( context->shared->target ) + { + case CodeGenTarget::HLSL: + return "hlsl"; + + case CodeGenTarget::GLSL: + return "glsl"; + + case CodeGenTarget::CSource: + return "c"; + + case CodeGenTarget::CPPSource: + return "cpp"; + + default: + SLANG_UNEXPECTED("unhandled case"); + UNREACHABLE_RETURN("unknown"); + } +} + +// How specialized is a given declaration for the chosen target? +enum class TargetSpecializationLevel +{ + specializedForOtherTarget = 0, + notSpecialized, + specializedForTarget, +}; + +TargetSpecializationLevel getTargetSpecialiationLevel( + IRInst* inVal, + String const& targetName) +{ + // HACK: Currently the front-end is placing modifiers related + // to target specialization on nodes like functions, even when + // those functions are being returned by a generic. This + // means that we need to try and inspect the value being + // returned by the generic if we are looking at a generic. + IRInst* val = inVal; + while( auto genericVal = as(val) ) + { + auto firstBlock = genericVal->getFirstBlock(); + if(!firstBlock) break; + + auto returnInst = as(firstBlock->getLastInst()); + if(!returnInst) break; + + val = returnInst->getVal(); + } + + TargetSpecializationLevel result = TargetSpecializationLevel::notSpecialized; + for(auto dd : val->getDecorations()) + { + if(dd->op != kIROp_TargetDecoration) + continue; + + auto decoration = (IRTargetDecoration*) dd; + if(String(decoration->getTargetName()) == targetName) + return TargetSpecializationLevel::specializedForTarget; + + result = TargetSpecializationLevel::specializedForOtherTarget; + } + + return result; +} + +// Is `newVal` marked as being a better match for our +// chosen code-generation target? +// +// TODO: there is a missing step here where we need +// to check if things are even available in the first place... +bool isBetterForTarget( + IRSpecContext* context, + IRInst* newVal, + IRInst* oldVal) +{ + String targetName = getTargetName(context); + + // For right now every declaration might have zero or more + // modifiers, representing the targets for which it is specialized. + // Each modifier has a single string "tag" to represent a target. + // We thus decide that a declaration is "more specialized" by: + // + // - Does it have a modifier with a tag with the string for the current target? + // If yes, it is the most specialized it can be. + // + // - Does it have a no tags? Then it is "unspecialized" and that is okay. + // + // - Does it have a modifier with a tag for a *different* target? + // If yes, then it shouldn't even be usable on this target. + // + // Longer term a better approach is to think of this in terms + // of a "disjunction of conjunctions" that is: + // + // (A and B and C) or (A and D) or (E) or (F and G) ... + // + // A code generation target would then consist of a + // conjunction of invidual tags: + // + // (HLSL and SM_4_0 and Vertex and ...) + // + // A declaration is *applicable* on a target if one of + // its conjunctions of tags is a subset of the target's. + // + // One declaration is *better* than another on a target + // if it is applicable and its tags are a superset + // of the other's. + + auto newLevel = getTargetSpecialiationLevel(newVal, targetName); + auto oldLevel = getTargetSpecialiationLevel(oldVal, targetName); + if(newLevel != oldLevel) + return UInt(newLevel) > UInt(oldLevel); + + // All preceding factors being equal, an `[export]` is better + // than an `[import]`. + // + bool newIsExport = newVal->findDecoration() != nullptr; + bool oldIsExport = oldVal->findDecoration() != nullptr; + if(newIsExport != oldIsExport) + return newIsExport; + + // All preceding factors being equal, a definition is + // better than a declaration. + auto newIsDef = isDefinition(newVal); + auto oldIsDef = isDefinition(oldVal); + if (newIsDef != oldIsDef) + return newIsDef; + + return false; +} + +IRFunc* cloneFuncImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IRFunc* originalFunc, + IROriginalValuesForClone const& originalValues) +{ + auto clonedFunc = builder->createFunc(); + registerClonedValue(context, clonedFunc, originalValues); + cloneFunctionCommon(context, clonedFunc, originalFunc); + return clonedFunc; +} + + +IRInst* cloneInst( + IRSpecContextBase* context, + IRBuilder* builder, + IRInst* originalInst, + IROriginalValuesForClone const& originalValues) +{ + switch (originalInst->op) + { + // We need to special-case any instruction that is not + // allocated like an ordinary `IRInst` with trailing args. + case kIROp_Func: + return cloneFuncImpl(context, builder, cast(originalInst), originalValues); + + case kIROp_GlobalVar: + return cloneGlobalVarImpl(context, builder, cast(originalInst), originalValues); + + case kIROp_GlobalConstant: + return cloneGlobalConstantImpl(context, builder, cast(originalInst), originalValues); + + case kIROp_GlobalParam: + return cloneGlobalParamImpl(context, builder, cast(originalInst), originalValues); + + case kIROp_WitnessTable: + return cloneWitnessTableImpl(context, builder, cast(originalInst), originalValues); + + case kIROp_StructType: + return cloneStructTypeImpl(context, builder, cast(originalInst), originalValues); + + case kIROp_InterfaceType: + return cloneInterfaceTypeImpl(context, builder, cast(originalInst), originalValues); + + case kIROp_Generic: + return cloneGenericImpl(context, builder, cast(originalInst), originalValues); + + case kIROp_StructKey: + return cloneStructKeyImpl(context, builder, cast(originalInst), originalValues); + + case kIROp_GlobalGenericParam: + return cloneGlobalGenericParamImpl(context, builder, cast(originalInst), originalValues); + + default: + break; + } + + // The common case is that we just need to construct a cloned + // instruction with the right number of operands, intialize + // it, and then add it to the sequence. + UInt argCount = originalInst->getOperandCount(); + IRInst* clonedInst = builder->createIntrinsicInst( + cloneType(context, originalInst->getFullType()), + originalInst->op, + argCount, nullptr); + registerClonedValue(context, clonedInst, originalValues); + auto oldBuilder = context->builder; + context->builder = builder; + for (UInt aa = 0; aa < argCount; ++aa) + { + IRInst* originalArg = originalInst->getOperand(aa); + IRInst* clonedArg = cloneValue(context, originalArg); + clonedInst->getOperands()[aa].init(clonedInst, clonedArg); + } + builder->addInst(clonedInst); + context->builder = oldBuilder; + cloneDecorations(context, clonedInst, originalInst); + + return clonedInst; +} + +IRInst* cloneGlobalValueImpl( + IRSpecContext* context, + IRInst* originalInst, + IROriginalValuesForClone const& originalValues) +{ + auto clonedValue = cloneInst(context, &context->shared->builderStorage, originalInst, originalValues); + clonedValue->moveToEnd(); + return clonedValue; +} + + + /// Clone a global value, which has the given `originalLinkage`. + /// + /// The `originalVal` is a known global IR value with that linkage, if one is available. + /// (It is okay for this parameter to be null). + /// +IRInst* cloneGlobalValueWithLinkage( + IRSpecContext* context, + IRInst* originalVal, + IRLinkageDecoration* originalLinkage) +{ + // If the global value being cloned is already in target module, don't clone + // Why checking this? + // When specializing a generic function G (which is already in target module), + // where G calls a normal function F (which is already in target module), + // then when we are making a copy of G via cloneFuncCommom(), it will recursively clone F, + // however we don't want to make a duplicate of F in the target module. + if (originalVal->getParent() == context->getModule()->getModuleInst()) + return originalVal; + + // Check if we've already cloned this value, for the case where + // an original value has already been established. + if (originalVal) + { + if (IRInst* clonedVal = findClonedValue(context, originalVal)) + { + return clonedVal; + } + } + + if(!originalLinkage) + { + // If there is no mangled name, then we assume this is a local symbol, + // and it can't possibly have multiple declarations. + return cloneGlobalValueImpl(context, originalVal, IROriginalValuesForClone(originalVal)); + } + + // + // We will scan through all of the available declarations + // with the same mangled name as `originalVal` and try + // to pick the "best" one for our target. + + auto mangledName = String(originalLinkage->getMangledName()); + RefPtr sym; + if( !context->getSymbols().TryGetValue(mangledName, sym) ) + { + if(!originalVal) + return nullptr; + + // This shouldn't happen! + SLANG_UNEXPECTED("no matching values registered"); + UNREACHABLE_RETURN(cloneGlobalValueImpl(context, originalVal, IROriginalValuesForClone())); + } + + // We will try to track the "best" declaration we can find. + // + // Generally, one declaration wil lbe better than another if it is + // more specialized for the chosen target. Otherwise, we simply favor + // definitions over declarations. + // + IRInst* bestVal = sym->irGlobalValue; + for( auto ss = sym->nextWithSameName; ss; ss = ss->nextWithSameName ) + { + IRInst* newVal = ss->irGlobalValue; + if(isBetterForTarget(context, newVal, bestVal)) + bestVal = newVal; + } + + // Check if we've already cloned this value, for the case where + // we didn't have an original value (just a name), but we've + // now found a representative value. + if (!originalVal) + { + if (IRInst* clonedVal = findClonedValue(context, bestVal)) + { + return clonedVal; + } + } + + return cloneGlobalValueImpl(context, bestVal, IROriginalValuesForClone(sym)); +} + +// Clone a global value, where `originalVal` is one declaration/definition, but we might +// have to consider others, in order to find the "best" version of the symbol. +IRInst* cloneGlobalValue(IRSpecContext* context, IRInst* originalVal) +{ + // We are being asked to clone a particular global value, but in + // the IR that comes out of the front-end there could still + // be multiple, target-specific, declarations of any given + // global value, all of which share the same mangled name. + return cloneGlobalValueWithLinkage( + context, + originalVal, + originalVal->findDecoration()); +} + +void insertGlobalValueSymbol( + IRSharedSpecContext* sharedContext, + IRInst* gv) +{ + auto linkage = gv->findDecoration(); + + // Don't try to register a symbol for global values + // that don't have linkage. + // + if (!linkage) + return; + + auto mangledName = String(linkage->getMangledName()); + + RefPtr sym = new IRSpecSymbol(); + sym->irGlobalValue = gv; + + RefPtr prev; + if (sharedContext->symbols.TryGetValue(mangledName, prev)) + { + sym->nextWithSameName = prev->nextWithSameName; + prev->nextWithSameName = sym; + } + else + { + sharedContext->symbols.Add(mangledName, sym); + } +} + +void insertGlobalValueSymbols( + IRSharedSpecContext* sharedContext, + IRModule* originalModule) +{ + if (!originalModule) + return; + + for(auto ii : originalModule->getGlobalInsts()) + { + insertGlobalValueSymbol(sharedContext, ii); + } +} + +void initializeSharedSpecContext( + IRSharedSpecContext* sharedContext, + Session* session, + IRModule* module, + CodeGenTarget target) +{ + + SharedIRBuilder* sharedBuilder = &sharedContext->sharedBuilderStorage; + sharedBuilder->module = nullptr; + sharedBuilder->session = session; + + IRBuilder* builder = &sharedContext->builderStorage; + builder->sharedBuilder = sharedBuilder; + + if( !module ) + { + module = builder->createModule(); + } + + sharedBuilder->module = module; + sharedContext->module = module; + sharedContext->target = target; +} + +struct IRSpecializationState +{ + ProgramLayout* programLayout; + CodeGenTarget target; + TargetRequest* targetReq; + + IRModule* irModule = nullptr; + + IRSharedSpecContext sharedContextStorage; + IRSpecContext contextStorage; + + IRSpecEnv globalEnv; + + IRSharedSpecContext* getSharedContext() { return &sharedContextStorage; } + IRSpecContext* getContext() { return &contextStorage; } + + IRSpecializationState() + { + contextStorage.env = &globalEnv; + } + + ~IRSpecializationState() + { + contextStorage = IRSpecContext(); + sharedContextStorage = IRSharedSpecContext(); + } +}; + +LinkedIR linkIR( + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + ProgramLayout* programLayout, + CodeGenTarget target, + TargetRequest* targetReq) +{ + auto sink = compileRequest->getSink(); + + IRSpecializationState stateStorage; + auto state = &stateStorage; + + state->programLayout = programLayout; + state->target = target; + state->targetReq = targetReq; + + auto program = compileRequest->getProgram(); + + auto sharedContext = state->getSharedContext(); + initializeSharedSpecContext( + sharedContext, + compileRequest->getSession(), + nullptr, + target); + + state->irModule = sharedContext->module; + + // We need to be able to look up IR definitions for any symbols in + // modules that the program depends on (transitively). To + // accelerate lookup, we will create a symbol table for looking + // up IR definitions by their mangled name. + // + auto originalProgramIRModule = program->getOrCreateIRModule(sink); + insertGlobalValueSymbols(sharedContext, originalProgramIRModule); + for (auto module : program->getModuleDependencies()) + { + insertGlobalValueSymbols(sharedContext, module->getIRModule()); + } + + auto context = state->getContext(); + context->shared = sharedContext; + context->builder = &sharedContext->builderStorage; + + // Next, we want to optimize lookup for layout information + // associated with global declarations, so that we can + // look things up based on the IR values (using mangled names) + // + // Note: We are scanning over all the key-value pairs for + // entries in the global scope, to account for the fact + // that the "same" shader parameter could be declared in + // multiple translation units, and thus end up with + // multiple mangled names (when the unique translation + // unit name gets involved). + // + auto globalStructLayout = getScopeStructLayout(programLayout); + for(auto entry : globalStructLayout->mapVarToLayout) + { + auto mangledName = getMangledName(entry.Key); + auto globalVarLayout = entry.Value; + context->globalVarLayouts.AddIfNotExists(mangledName, globalVarLayout); + } + + context->builder->setInsertInto(context->getModule()->getModuleInst()); + + // for now, clone all unreferenced witness tables + // + // TODO: This step should *not* be needed with the current IR + // specialization approach, so we should consider removing it. + // + for (auto sym :context->getSymbols()) + { + if (sym.Value->irGlobalValue->op == kIROp_WitnessTable) + cloneGlobalValue(context, (IRWitnessTable*)sym.Value->irGlobalValue); + } + + auto entryPointLayout = findEntryPointLayout(programLayout, entryPoint); + + // Next, we make sure to clone the global value for + // the entry point function itself, and rely on + // this step to recursively copy over anything else + // it might reference. + auto irEntryPoint = specializeIRForEntryPoint(context, entryPoint, entryPointLayout); + + // HACK: right now the bindings for global generic parameters are coming in + // as part of the original IR module, and we need to make sure these get + // copied over, even if they aren't referenced. + // + for(auto inst : originalProgramIRModule->getGlobalInsts()) + { + auto bindInst = as(inst); + if(!bindInst) + continue; + + cloneValue(context, bindInst); + } + + for(auto inst : originalProgramIRModule->getGlobalInsts()) + { + if(inst->op != kIROp_BindGlobalExistentialSlots) + continue; + + cloneValue(context, inst); + } + + // HACK: we need to ensure that any tagged union types + // in the IR module have layout information copied over to them. + // + // Note that we do this *after* cloning the `bindGlobalGenericParam` + // instructions, since we expected the tagged union type(s) to + // be referenced by them. + // + for( auto taggedUnionTypeLayout : entryPointLayout->taggedUnionTypeLayouts ) + { + auto taggedUnionType = taggedUnionTypeLayout->getType(); + auto mangledName = getMangledTypeName(taggedUnionType); + + RefPtr sym; + if(!context->getSymbols().TryGetValue(mangledName, sym)) + continue; + + IRInst* clonedType = findClonedValue(context, sym->irGlobalValue); + if(!clonedType) + continue; + + context->builder->addLayoutDecoration(clonedType, taggedUnionTypeLayout); + } + + // TODO: *technically* we should consider the case where + // we have global variables with initializers, since + // these should get run whether or not the entry point + // references them. + + // Now that we've cloned the entry point and everything + // it refers to, we can package up the data we return + // to the caller. + // + LinkedIR linkedIR; + linkedIR.module = state->irModule; + linkedIR.entryPoint = irEntryPoint; + return linkedIR; +} + + + +} // namespace Slang diff --git a/source/slang/slang-ir-link.h b/source/slang/slang-ir-link.h new file mode 100644 index 000000000..89486b28f --- /dev/null +++ b/source/slang/slang-ir-link.h @@ -0,0 +1,27 @@ +// slang-ir-link.h +#pragma once + +#include "slang-compiler.h" + +namespace Slang +{ + struct LinkedIR + { + RefPtr module; + IRFunc* entryPoint; + }; + + + // Clone the IR values reachable from the given entry point + // into the IR module associated with the specialization state. + // When multiple definitions of a symbol are found, the one + // that is best specialized for the given `targetReq` will be + // used. + // + LinkedIR linkIR( + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + ProgramLayout* programLayout, + CodeGenTarget target, + TargetRequest* targetReq); +} diff --git a/source/slang/slang-ir-missing-return.cpp b/source/slang/slang-ir-missing-return.cpp new file mode 100644 index 000000000..527fdda5f --- /dev/null +++ b/source/slang/slang-ir-missing-return.cpp @@ -0,0 +1,43 @@ +// ir-missing-return.cpp +#include "slang-ir-missing-return.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" + +namespace Slang { + +class DiagnosticSink; +struct IRModule; + +void checkForMissingReturnsRec( + IRInst* inst, + DiagnosticSink* sink) +{ + if( auto code = as(inst) ) + { + for( auto block : code->getBlocks() ) + { + auto terminator = block->getTerminator(); + + if( auto missingReturn = as(terminator) ) + { + sink->diagnose(missingReturn, Diagnostics::missingReturn); + } + } + } + + for( auto childInst : inst->getDecorationsAndChildren() ) + { + checkForMissingReturnsRec(childInst, sink); + } +} + +void checkForMissingReturns( + IRModule* module, + DiagnosticSink* sink) +{ + // Look for any `missingReturn` instructions + checkForMissingReturnsRec(module->getModuleInst(), sink); +} + +} diff --git a/source/slang/slang-ir-missing-return.h b/source/slang/slang-ir-missing-return.h new file mode 100644 index 000000000..547737f62 --- /dev/null +++ b/source/slang/slang-ir-missing-return.h @@ -0,0 +1,12 @@ +// slang-ir-missing-return.h +#pragma once + +namespace Slang +{ + class DiagnosticSink; + struct IRModule; + + void checkForMissingReturns( + IRModule* module, + DiagnosticSink* sink); +} diff --git a/source/slang/slang-ir-restructure-scoping.cpp b/source/slang/slang-ir-restructure-scoping.cpp new file mode 100644 index 000000000..c16db8f3c --- /dev/null +++ b/source/slang/slang-ir-restructure-scoping.cpp @@ -0,0 +1,434 @@ +// slang-ir-restructure-scoping.cpp +#include "slang-ir-restructure-scoping.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-ir-restructure.h" + +namespace Slang +{ + +/// Try to find the first structured region that represents `block` +/// +/// In general the same block may appear as multiple regions, +/// so this will return the first region in the linked list. +static SimpleRegion* getFirstRegionForBlock( + RegionTree* regionTree, + IRBlock* block) +{ + SimpleRegion* region = nullptr; + if( regionTree->mapBlockToRegion.TryGetValue(block, region) ) + { + return region; + } + return nullptr; +} + +/// Try to find the first structured region that contains `inst`. +static SimpleRegion* getFirstRegionForInst( + RegionTree* regionTree, + IRInst* inst) +{ + auto ii = inst; + while(ii) + { + if(auto block = as(ii)) + return getFirstRegionForBlock(regionTree, block); + + ii = ii->getParent(); + } + + return nullptr; +} + +/// Compute the depth of a node in the region tree. +/// +/// This is the number of nodes (including `region`) +/// on a path from `region` to the root. +/// +static Int computeDepth(Region* region) +{ + Int depth = 0; + for( Region* rr = region; rr; rr = rr->getParent() ) + { + depth++; + } + return depth; +} + +/// Get the `n`th ancestor of `region`. +/// +/// When `n` is zero, this returns `region`. +/// When `n` is one, this returns the parent of `region`, and so forth. +/// +static Region* getAncestor(Region* region, Int n) +{ + Region* rr = region; + for( Int ii = 0; ii < n; ++ii ) + { + SLANG_ASSERT(rr); + rr = rr->getParent(); + } + return rr; +} + +/// Find a region that is an ancestor of both `left` and `right`. +static Region* findCommonAncestorRegion( + Region* left, + Region* right) +{ + // Rather than blinding search through each ancestor of `left` + // and see if it is also an ancestor of `right` and vice-versa, + // let's try to be smart about this. + // + // We will start by computing the depth of `left` and `right`: + // + Int leftDepth = computeDepth(left); + Int rightDepth = computeDepth(right); + + // Whatever the common ancestor is, it can't be any deeper + // than the minimum of these two depths. + // + Int minDepth = Math::Min(leftDepth, rightDepth); + + // Let's fetch the ancestor of each of `left` and `right` + // corresponding to that depth: + // + Region* leftAncestor = getAncestor(left, leftDepth - minDepth); + Region* rightAncestor = getAncestor(right, rightDepth - minDepth); + + // Now we know that `leftAncestor` and `rightAncestor` + // must have the same depth. Let's go ahead and assert + // it just to be safe: + // + SLANG_ASSERT(computeDepth(leftAncestor) == minDepth); + SLANG_ASSERT(computeDepth(rightAncestor) == minDepth); + + // If `leftAncestor` and `rightAncestor` are the same node, + // then we've found a common ancestor, otherwise we should + // look at their parents. Because the depth must match + // on both sides, we will never risk missing an ancestor. + // + while( leftAncestor != rightAncestor ) + { + leftAncestor = leftAncestor->getParent(); + rightAncestor = rightAncestor->getParent(); + } + + // Okay, we've found a common ancestor. + // + Region* commonAncestor = leftAncestor; + return commonAncestor; +} + +/// Find a simple region that is an ancestor of both `left` and `right`. +static SimpleRegion* findSimpleCommonAncestorRegion( + Region* left, + Region* right) +{ + // Start by finding a common ancestor without worrying about it being simple. + Region* ancestor = findCommonAncestorRegion(left, right); + + // Now search for a simple region up the tree. + while( ancestor ) + { + if(ancestor->getFlavor() == Region::Flavor::Simple) + return (SimpleRegion*) ancestor; + + ancestor = ancestor->getParent(); + } + + // This shouldn't ever occur. The root of the region tree should + // be a simple regions that represents the entry block of the + // function. + // + SLANG_UNEXPECTED("no common ancestor found in region tree"); + UNREACHABLE_RETURN(nullptr); +} + +IRInst* getDefaultInitVal( + IRBuilder* builder, + IRType* type) +{ + switch( type->op ) + { + default: + return nullptr; + + case kIROp_BoolType: + return builder->getBoolValue(false); + + case kIROp_IntType: + case kIROp_UIntType: + case kIROp_UInt64Type: + return builder->getIntValue(type, 0); + + case kIROp_HalfType: + case kIROp_FloatType: + case kIROp_DoubleType: + return builder->getFloatValue(type, 0.0); + + // TODO: handle vector/matrix types here, by + // creating an appropriate scalar value and + // then "splatting" it. + } +} + +/// Initialize a variable to a sane default value, if possible. +void defaultInitializeVar( + IRBuilder* builder, + IRVar* var, + IRType* type) +{ + IRInst* initVal = nullptr; + switch( type->op ) + { + case kIROp_VoidType: + default: + // By default, see if we can synthesize an IR value + // to be used as the default, and allow the logic + // below to store it into the variable. + initVal = getDefaultInitVal(builder, type); + break; + + // TODO: Handle aggregate types (structures, arrays) + // explicitly here, since they need to be careful about + // the cases where an element/field type might not + // be something we can default-initialize. + } + + if( initVal ) + { + builder->emitStore(var, initVal); + } +} + +/// Detect and fix any structured scoping issues for a given `def` instruction. +/// +/// The `defRegion` should be the region that contains `def`, and `regionTree` +/// should be the region tree for the function that contains `def`. +/// +static void fixValueScopingForInst( + IRInst* def, + SimpleRegion* defRegion, + RegionTree* regionTree) +{ + // This algorithm should not consider "phi nodes" for now, + // because the emit logic will already create variables for them. + // We could consider folding the logic to move out of SSA form + // into this function, but that would add a lot of complexity for now. + if(def->op == kIROp_Param) + return; + + // We would have a scoping violation if there exists some + // use `u` of `def` such that the region containing `u` + // (call it `useRegion`) is not a descendent of `defRegion` + // in the region tree. + // + // If there are no scoping violations, we don't want to do + // anything. If there *are* any scoping violations, then + // we ill need to introduce a temporary `tmp`, store into + // it right after `def`, and then load from it at any "bad" + // use sites. + // + // Of course, for the whole thing to work, we also need + // to put `tmp` into a block somwhere, and it needs to + // be a block that is visible to all of the uses, or we + // are just back int the same mess again. + // + // The right block to use for inserting `tmp` is the least + // common ancestor of `def` and all the "bad" uses, so + // we will get a bit "clever" and fold in the search for + // bad uses with the computation of the region we should + // insert `tmp` into (to avoid looping over the uses + // twice). + // + SimpleRegion* insertRegion = defRegion; + IRVar* tmp = nullptr; + + // If we end up needing to insert code we'll need an IR builder, + // so we will go ahead and create one now. + // + // TODO: the logic to compute `module` here could be hoisted + // out earlier, rather than being done per-instruction. + // + IRModule* module = regionTree->irCode->getModule(); + + SharedIRBuilder sharedBuilder; + sharedBuilder.session = module->session; + sharedBuilder.module = module; + + IRBuilder builder; + builder.sharedBuilder = &sharedBuilder; + + // Because we will be changing some of the uses of `def` + // to use other values while we iterate the list, we + // need to be a bit careful and extract the next use + // in the linked list *before* we operator on `u`. + // + IRUse* nextUse = nullptr; + for( auto u = def->firstUse; u; u = nextUse ) + { + nextUse = u->nextUse; + + // Looking at the use site `u`, we'd like to check if + // it violates our scoping rules. + // + // As a simple early-exit case, if the user is in + // the same block as the definition, there are no problems. + // + IRInst* user = u->getUser(); + if(user->getParent() == defRegion->block) + continue; + + // Otherwise, let's find the structures control-flow + // region that holds the user. We expect to always + // find one, because the use site must be in the same + // function. + // + // TODO: Double check that logic if we ever introduce + // things like nested function. + // + SimpleRegion* useRegion = getFirstRegionForInst(regionTree, user); + + // If there is no region associated with the use, then + // the use must be in unreachable code (part of the CFG, + // but not part of the region tree). We will skip + // such uses for now, since they won't even appear in + // the output. + // + if(!useRegion) + continue; + + // Now we want to check if `useRegion` is a child/descendent + // of a region that has the same block as `defRegion`. + // If it is, then there is no scoping problem with this use. + // + if(useRegion->isDescendentOf(defRegion->block)) + continue; + + // If we've gotten this far, we know that `u` is a "bad" + // use of `def`, and needs fixing. + // + // We will create the `tmp` variable on demand, so + // that we create it when the first "bad" use is encountered, + // and then re-use it for subsequent bad uses. + // + if( !tmp ) + { + // We will create a temporary to represent `def`, + // and insert a `store` into it right after `def`. + // + // Note: we are inserting the new variable right + // after `def` for now, just because we don't + // yet know the final region that it should be + // placed into. We will move it to the correct + // location when we are done. + // + builder.setInsertBefore(def->getNextInst()); + tmp = builder.emitVar(def->getDataType()); + builder.emitStore(tmp, def); + } + + // In order to know where `tmp` should be defined + // at the end of the algorithm, we need to compute + // a valid `insertRegion` that is an ancestor of + // all of the use sites (and it also a simple region + // so that we can insert into its IR block). + // + // We need to deal with one complexity in our restructuring + // process, which is that a block may be duplicated into + // one or more regions, so we loop over all the regions + // for the same block as `useRegion`. + // + for(auto rr = useRegion; rr; rr = rr->nextSimpleRegionForSameBlock) + { + insertRegion = findSimpleCommonAncestorRegion( + insertRegion, + rr); + } + + // To fix up the use `u`, we will need to change + // it from using `def` to using a load from `tmp` + // + builder.setInsertBefore(user); + IRInst* tmpVal = builder.emitLoad(tmp); + + // We are clobbering the value used by the `IRUse` `u`, + // while will cut it out of the list of uses for `def`. + // We need to be careful when doing this to not disrupt + // our iteration of the uses of `def`, so we carefully + // used the `nextUse` temporary at the start of the loop. + // + u->set(tmpVal); + } + + // At the end of the loop, the `tmp` variable will have + // been created if and only if we fixed up anything. + // + if( tmp ) + { + // If we created a temporary, then now we need to move + // its definition to the right place, which is the + // `insertRegion` that we computed during the loop. + // + // We'd like to insert our temporary near the top + // of the region, since that is the conventional + // place for local variables to go. + // + tmp->insertBefore( + insertRegion->block->getFirstOrdinaryInst()); + + // The whole point of the transformation we are doing + // here is that `def` is not on the "obvious" control + // flow path to one or more uses (which are now using + // `tmp`), but that means that it might not be "obvious" + // to a downstream compiler that `tmp` always gets + // initialized (by the code we inserted after `def`) + // before each of these use sites. + // + // We *know* that things are valid as long as our + // dominator tree was valid - there is no way to + // get to the block that loads from `tmp` without passing + // through the block that computes `def` (and then + // stores it into `tmp`) first. + // + // To avoid warnings/errros, we will go ahead and try + // to emit logic to "default initialize" the `tmp` + // variable if possible. + // + builder.setInsertBefore(tmp->getNextInst()); + defaultInitializeVar(&builder, tmp, def->getDataType()); + } +} + +void fixValueScoping(RegionTree* regionTree) +{ + // We are going to have to walk through every instruction + // in the code of the function to detect an bad cases. + // + auto code = regionTree->irCode; + for(auto block : code->getBlocks()) + { + // All of the instruction in `block` will have the same + // parent region, so we will look it up now rather than + // have to re-do this work on a per-instruction basis. + // + auto parentRegion = getFirstRegionForBlock(regionTree, block); + + // If a block has no region then it must be unreachable, + // so we will skip it entirely for this pass. + // + // TODO: we should be eliminating unrechable blocks anyway. + // + if(!parentRegion) + continue; + + for(auto inst : block->getDecorationsAndChildren()) + { + fixValueScopingForInst(inst, parentRegion, regionTree); + } + } +} + +} diff --git a/source/slang/slang-ir-restructure-scoping.h b/source/slang/slang-ir-restructure-scoping.h new file mode 100644 index 000000000..6c9266754 --- /dev/null +++ b/source/slang/slang-ir-restructure-scoping.h @@ -0,0 +1,24 @@ +// slang-ir-restructure-scoping.h +#pragma once + +namespace Slang +{ + +class RegionTree; + +/// Fix cases where a value might be used in a non-nested region. +/// +/// There can be cases where an IR value V in block A is used in +/// some block B, where A dominates B, *but* when we constructed +/// the region tree, the block B is not in a child/descendent +/// region of A's region, so that it won't be visible through the +/// scoping rules of a target language. +/// +/// This function detects such cases, and fixes them up by inserting +/// new temporaries into the IR code so that values that need +/// to survive across blocks are communicated through variables +/// declared at a sufficiently broad scope. +/// +void fixValueScoping(RegionTree* regionTree); + +} diff --git a/source/slang/slang-ir-restructure.cpp b/source/slang/slang-ir-restructure.cpp new file mode 100644 index 000000000..e88078376 --- /dev/null +++ b/source/slang/slang-ir-restructure.cpp @@ -0,0 +1,663 @@ +// ir-restructure.cpp +#include "slang-ir-restructure.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" + +namespace Slang +{ + bool Region::isDescendentOf(Region* other) + { + Region* rr = this; + while( rr ) + { + if(rr == other) + return true; + + rr = rr->getParent(); + } + return false; + } + + bool Region::isDescendentOf(IRBlock* block) + { + Region* rr = this; + while( rr ) + { + if( rr->getFlavor() == Region::Flavor::Simple ) + { + SimpleRegion* simpleRegion = (SimpleRegion*) rr; + if(simpleRegion->block == block) + return true; + } + + rr = rr->getParent(); + } + return false; + } + + /// An "active" label during control flow (re)structuring. + struct LabelStack + { + /// Possible operations associated with labels. + enum class Op + { + Break, + Continue, + + CountOf, + }; + + /// What kind of operation does a branch to this label represent? + Op op; + + /// The next label down on the stack + LabelStack* parent; + + /// The block the represents this label in the IR control flow graph. + IRBlock* block; + + /// The region that represents this label in the structured program + Region* region; + }; + + /// State used when restructuring control flow. + struct ControlFlowRestructuringContext + { + /// Sink to use when diagnosing errors in control-flow restructuring. + /// + /// The restructuring pass should be able to handle anything the front-end + /// throws at it, so these errors will all be unexpected. Still, we need + /// a way to report them cleanly without crashing the process. + /// + DiagnosticSink* sink = nullptr; + DiagnosticSink* getSink() { return sink; } + + /// The region tree we are in the process of building. + RegionTree* regionTree = nullptr; + }; + + /// Convert a range of blocks in the IR CFG into a region. + /// + /// We want to generate a region that stands in for the + /// blocks that are logically in the internal [begin, end) + /// which we consider as representing a single-entry multiple-exit + /// sub-graph. Note that `end` is *not* part of the sub-graph, + /// but instead points to a block that is logically "after" + /// the sub-graph. `end` can be `null` to indicate that the + /// sub-graph extends as far as possible. + /// + /// Because there can be multiple exits, control flow may + /// exit the sub-graph without branching to `end`, any + /// such "non-local" branching should be to one of the + /// blocks stored in the current `LabelStack`. + /// + // TODO: Eventually we should replace all of this logic with + // a variation on the "Relooper" algorithm as it is used + // in Emscripten. + // + static RefPtr generateRegionsForIRBlocks( + ControlFlowRestructuringContext* ctx, + Region* inParentRegion, + IRBlock* begin, + IRBlock* end, + LabelStack* initialLabels, // Labels to use at the start + LabelStack* labels = nullptr) // Labels to switch to after emitting first basic block + { + if(!labels) + labels = initialLabels; + auto useLabels = initialLabels; + + // + // We will try to build up as long of a sequential/simple region + // as possible, to avoid deep recursion in this algorithm. + // + RefPtr resultRegion = nullptr; + RefPtr* resultLink = &resultRegion; + + // As we move along, the parent region to use for regions + // we create will shift, so we need a temporary to track + // the current parent region. + // + Region* parentRegion = inParentRegion; + + // + // We will start with the `begin` block, and try to proceed + // sequentially until we see the `end` block, or run into + // an edge that exits teh region. + // + IRBlock* block = begin; + while(block != end) + { + // If the block we are trying to emit has been registered as a + // destination label (e.g. for a loop or `switch`) then we + // need to exit the current region, which amounts to generating + // a `break` or `continue` operation. + // + // TODO: we eventually need to handle the possibility of + // multi-level break/continue targets, which could be challenging. + + // Because we will only support single-level break/continue, we + // want to resolve what is the most recent label that is "active" + // for the given operation (`break` or `continue`). + // + // We will do this with a naive loop, just to keep things simple. + // We start with no block "regsitered" as the target for each + // operation. + // + IRBlock* registeredBlock[(int)LabelStack::Op::CountOf] = {}; + for( auto ll = useLabels; ll; ll = ll->parent ) + { + // For each active label, see if it is the first one + // we encounter for the given op. + // + if(!registeredBlock[(int)ll->op]) + { + registeredBlock[(int)ll->op] = ll->block; + } + } + + // Next we will search through *all* of the registered labels, + // and see if one of them matches the current `block`. + // + for(auto ll = useLabels; ll; ll = ll->parent) + { + // Does this label match the block we are trying to translate? + if(ll->block != block) + continue; + + // Okay, the block we are trying to generate code for is a label + // that we should branch to (we shouldn't just emit the code here + // and now...) + // + // We should first confirm that the block is the inner-most label + // registered for the given control-flow op (`break` or `continue`) + // because if it *isn't* we currently can't generate code. + // + if(block != registeredBlock[(int)ll->op]) + { + ctx->getSink()->diagnose(block, Diagnostics::multiLevelBreakUnsupported); + } + + // Now we need to create a structured `break` or `continue` operation + // to match the operation associated with the target. + // + switch(ll->op) + { + case LabelStack::Op::Break: + { + auto outerRegion = (BreakableRegion*) ll->region; + RefPtr breakRegion = new BreakRegion(parentRegion, outerRegion); + + *resultLink = breakRegion; + resultLink = nullptr; + } + break; + + case LabelStack::Op::Continue: + { + auto outerRegion = (LoopRegion*) ll->region; + RefPtr continueRegion = new ContinueRegion(parentRegion, outerRegion); + + *resultLink = continueRegion; + resultLink = nullptr; + } + break; + } + + // If the `block` matched an active label, then we should have + // created a branch, and there is nothing to be done here. + return resultRegion; + } + + // We now know that the given `block` is part of our control-flow region, + // so we need to output a simple region that executes the code in that block. + // + RefPtr simpleRegion = new SimpleRegion(parentRegion, block); + + // We need to register the mapping from `block` to this region, but in + // general this isn't a one-to-one mapping, but rather one-to-many. + // This is because a "continue clause" in a `for` loop might get duplicated + // at each `continue` site in the output code. To deal with this + // we build a singly-linked list of regions for each block. + // + // TODO: confirm that continue clauses are the only case that leads + // to duplication. + // + // TODO: remove this workaround once we have a more powerful restructuring + // pass that avoids duplicating blocks (by introducing new temporaries...) + // + SimpleRegion* nextSimpleRegionForSameBlock = nullptr; + ctx->regionTree->mapBlockToRegion.TryGetValue(block, nextSimpleRegionForSameBlock); + ctx->regionTree->mapBlockToRegion[block] = simpleRegion; + + *resultLink = simpleRegion; + resultLink = &simpleRegion->nextRegion; + parentRegion = simpleRegion; + + // The simple region we created will represent all of the non-terminator + // instructions in the `block`, so now we need to figure out what to + // create to represent that terminator. + // + auto terminator = block->getTerminator(); + SLANG_ASSERT(terminator != nullptr); + switch (terminator->op) + { + default: + case kIROp_conditionalBranch: + // Note: we don't currently generate ordinary `conditionalBranch` instructions, + // and instead only generate `ifElse` instructions, which include additional + // information that can inform our control-flow restructuring pass. + // + SLANG_UNEXPECTED("unhandled terminator instruction opcode"); + ; // fall through to: + case kIROp_Unreachable: + case kIROp_MissingReturn: + case kIROp_ReturnVal: + case kIROp_ReturnVoid: + case kIROp_discard: + // These cases are all simple terminators that can be handled as-is + // without needing to construct a separate `Region` to encapsulate them. + // + // We will cap off the current sequence of simple regions and return. + // + *resultLink = nullptr; + return resultRegion; + + case kIROp_ifElse: + { + // Here we have a two-way branch, so that we will construct a + // region representing an `if` statement. + // + auto ifInst = (IRIfElse*)terminator; + auto condition = ifInst->getCondition(); + auto trueBlock = ifInst->getTrueBlock(); + auto falseBlock = ifInst->getFalseBlock(); + auto afterBlock = ifInst->getAfterBlock(); + + + RefPtr ifRegion = new IfRegion(parentRegion, condition); + + // The region for the "then" part of things will consist of + // the range of blocks `[trueBlock, afterBlock)`. + // + // This logic assumes that `afterBlock` is a valid structured + // "join point" such that any branch out of the sub-region + // either leads to `afterBlock` *or* one of the labels + // that is already present on our label stack. + // + ifRegion->thenRegion = generateRegionsForIRBlocks( + ctx, + ifRegion, + trueBlock, + afterBlock, + labels); + + // Generating a region for the `else` part is similar. + // Note that it is possible for this to be a `null` + // region, if `falseBlock == afterBlock`. + // + ifRegion->elseRegion = generateRegionsForIRBlocks( + ctx, + ifRegion, + falseBlock, + afterBlock, + labels); + + *resultLink = ifRegion; + resultLink = &ifRegion->nextRegion; + parentRegion = ifRegion; + + // Continue with the block after the `ifElse` instruction. + block = afterBlock; + } + break; + + case kIROp_loop: + { + // The terminator in this case is the header for a structured loop. + // + auto loopInst = (IRLoop*) terminator; + auto bodyBlock = loopInst->getTargetBlock(); + auto afterBlock = loopInst->getBreakBlock(); + + RefPtr loopRegion = new LoopRegion(parentRegion, loopInst); + + // We will need to set up entries on our label stack to + // represent the targets for `break` or `continue` + // operations inside the loop. + // + // First we set up the stack entry for the `break` label, + // which will refer to the block *after* the loop. + // + // The region we specify for the label will still be + // the loop region, though, because the loop is what + // we are breaking out of. + // + LabelStack loopBreakLabelStack; + loopBreakLabelStack.parent = labels; + loopBreakLabelStack.block = afterBlock; + loopBreakLabelStack.region = loopRegion; + loopBreakLabelStack.op = LabelStack::Op::Break; + + // + // The `continue` label warrants a bit more careful explanation, + // because it will *not* refer to the block that was regsitered + // as the continue target in the IR `loop` instruction. This + // is because we will always emit our loops as `for(;;) { ... }` + // with no continue clause at all, so that a `continue` in + // the output code will always refer to the top of the loop. + // + // This means that the `continue` label for the purposes of + // structured control flow will be the start of the loop body: + // + LabelStack loopContinueLabelStack; + loopContinueLabelStack.parent = &loopBreakLabelStack; + loopContinueLabelStack.block = bodyBlock; + loopContinueLabelStack.region = loopRegion; + loopContinueLabelStack.op = LabelStack::Op::Continue; + // + // Note: by ignoring the original continue block from the + // high-level loop, we create a situation where that code + // might get emitted more than once (once per implicit + // or explicit `continue` site in the original program). + // + // That is an acceptable trade-off for now, because continue + // blocks will usually be small (and fxc makes the same choice), + // but it could lead to Bad Things if somebody were to call + // a function in their continue clause, and that function does + // a compute shader barrier operation. + // + // A better long-term fix is to take a high-level loop like: + // + // for(A; B; C) { ... continue; ... break; ... } + // + // and translate it into something like the following (assuming + // we have labeled statements and multi-level `break`): + // + // A; + // Outer: for(;;) { + // Inner: for(;;) { + // if(B) {} else break Outer; + // ... + // break Inner; // `continue` becomes break of inner loop + // ... + // break Outer; // `break` becomes break of outer loop + // ... + // break; // inner loop unconditionally breaks at the end + // } + // C; // continue clause comes after inner loop + // } + // + // If you draw up a control flow graph for that code, you'll find + // it is equivalent to the orignal `for` loop, but now supports + // arbitrary code (not just a single expression) for the continue clause. + // Unlike the current code-duplication solution, `C` appears only once + // in the output, and seems to clearly be at a "joint point" for control + // flow so that it is clear that a barrier there is valid in GLSL. + // + // Anyway, back our regularly scheduled programming. + // + // With the label stack stuff set up, we want to take the region + // of the CFG defined by `[bodyBlock, afterBlock)` and turn it into + // the body region for our loop. + // + // The only thing we want to be a little bit careful about is + // that we don't want the logic at the top of this function + // that looks for a block it can translate into a `continue` + // to trigger on `bodyBlock`, since that means we'd just turn + // the whole body into a single `continue`. + // + // To avoid this problem, we pass in two different label stacks: + // one to use for the first block, and one to use for subsequent + // blocks. + // + loopRegion->body = generateRegionsForIRBlocks( + ctx, + loopRegion, + bodyBlock, + // TODO: should we pass `afterBlock` here instead of `null`? + nullptr, + // For the first block, we only want the `break` label active + &loopBreakLabelStack, + // After the first block, we can safely use the `continue` label too + &loopContinueLabelStack); + + *resultLink = loopRegion; + resultLink = &loopRegion->nextRegion; + parentRegion = loopRegion; + + // Continue with the block after the loop + block = afterBlock; + } + break; + + case kIROp_unconditionalBranch: + { + // Here we have an unconditional branch that was + // not covered by one of our labels for non-local + // branches (`break` or `continue`). + // + // We will thus assume that the target of the + // branch is part of the same region we are building, + // and continue with the target block; + // + auto branchInst = (IRUnconditionalBranch*) terminator; + block = branchInst->getTargetBlock(); + } + break; + + case kIROp_Switch: + { + // A `switch` instruction will always translate + // to a `SwitchRegion` and then to a `switch` statement. + // + // We will need to take care to emit `case`s in ways + // that avoid code duplication. + // + // The logic here isn't going to be robust in edge cases + // (please don't write Duff's Device in Slang just yet). + // Doing significantly better than what is here would + // require something like the Relooper algorithm, though. + // + auto switchInst = (IRSwitch*) terminator; + auto condition = switchInst->getCondition(); + auto breakLabel = switchInst->getBreakLabel(); + auto defaultLabel = switchInst->getDefaultLabel(); + + RefPtr switchRegion = new SwitchRegion(parentRegion, condition); + + // A direct branch to the block after the `switch` can + // be emitted as a `break` statement, so we will register + // the appropriate label on a label stack: + // + LabelStack switchBreakLabelStack; + switchBreakLabelStack.parent = labels; + switchBreakLabelStack.op = LabelStack::Op::Break; + switchBreakLabelStack.block = breakLabel; + switchBreakLabelStack.region = switchRegion; + + // We need to track whether we've dealt with + // the `default` case already. + // + bool defaultLabelHandled = false; + + // If the `default` case just branches to + // the join point, then we don't need to + // do anything with it. + // + if(defaultLabel == breakLabel) + defaultLabelHandled = true; + + // We will now iterate over the different `case`s, and + // try to group them together to minimize the number of + // sub-regions we have to create. + // + UInt caseIndex = 0; + UInt caseCount = switchInst->getCaseCount(); + while(caseIndex < caseCount) + { + // We are going to extract one case here, + // but we might need to fold additional + // cases into it, if they share the + // same label. + // + // Note: this makes assumptions that the + // IR code generator orders cases such + // that: (1) cases with the same label + // are consecutive, and (2) any case + // that "falls through" to another must + // come right before it in the list. + + auto caseVal = switchInst->getCaseValue(caseIndex); + auto caseLabel = switchInst->getCaseLabel(caseIndex); + caseIndex++; + + RefPtr currentCase = new SwitchRegion::Case(); + switchRegion->cases.add(currentCase); + + // Add the case value for this case, and any + // others that share the same label + // + for(;;) + { + currentCase->values.add(caseVal); + + // Are there any more `case`s left? + // + if(caseIndex >= caseCount) + break; + + // Does the next `case` share the same target label? + auto nextCaseLabel = switchInst->getCaseLabel(caseIndex); + if(nextCaseLabel != caseLabel) + break; + + // If those checks passed, then we will fold + // the next `case` into the same region, and + // keep looking. + caseVal = switchInst->getCaseValue(caseIndex); + caseIndex++; + } + + // The label for the current `case` might also + // be the label used by the `default` case, so + // check for that here. + // + if(caseLabel == defaultLabel) + { + switchRegion->defaultCase = currentCase; + defaultLabelHandled = true; + } + + // Now we need to generate a region for the instructions + // that make up this case. The 99% case will be that it + // will terminate with a `break` (or a `return`, + // `continue`, etc.) and so we can pass in `nullptr` + // for the ending block. + // + IRBlock* caseEndLabel = nullptr; + + // However, there is also the possibility that + // this `case` will fall through to the next, and + // so we need to prepare for that possibility here. + // + // If there *is* a next `case`, then we will set its + // label up as the "end" label when emitting + // the statements inside the block. + if(caseIndex < caseCount) + { + caseEndLabel = switchInst->getCaseLabel(caseIndex); + } + + // Now we can actually generate the region. + // + currentCase->body = generateRegionsForIRBlocks( + ctx, + switchRegion, + caseLabel, + caseEndLabel, + &switchBreakLabelStack); + } + + // If we've gone through all the cases and haven't + // managed to encounter the `default:` label, + // then assume it is a distinct case and handle it here. + if(!defaultLabelHandled) + { + RefPtr defaultCase = new SwitchRegion::Case(); + switchRegion->cases.add(defaultCase); + + // Note: we use `null` instead of `breakLabel` as the end block + // here, to ensure that the `default` region will end with an + // explicit `break` rather than just falling off the end. + + defaultCase->body = generateRegionsForIRBlocks( + ctx, + switchRegion, + defaultLabel, + nullptr, + &switchBreakLabelStack); + + switchRegion->defaultCase = defaultCase; + } + + *resultLink = switchRegion; + resultLink = &switchRegion->nextRegion; + parentRegion = switchRegion; + + // Continue with the block after the `switch` + block = breakLabel; + } + break; + } + + // After we've emitted the first block, we are safe from accidental + // cases where we'd emit an entire loop body as a single `continue`, + // so we can safely switch in whatever labels are intended to be used. + useLabels = labels; + + // If we reach this point, then we've emitted + // one block, and we have a new block where + // control flow continues. + // + // We need to handle a special case here, + // when control flow jumps back to the + // starting block of the range we were + // asked to work with: + if (block == begin) + { + break; + } + } + + // We seem to have reached the rend of the region + // without anything special happening. This means + // we should cap off the current sequence of regions + // and return what we have. + // + *resultLink = nullptr; + return resultRegion; + } + + RefPtr generateRegionTreeForFunc( + IRGlobalValueWithCode* code, + DiagnosticSink* sink) + { + RefPtr regionTree = new RegionTree(); + regionTree->irCode = code; + + ControlFlowRestructuringContext restructuringContext; + restructuringContext.sink = sink; + restructuringContext.regionTree = regionTree; + + regionTree->rootRegion = generateRegionsForIRBlocks( + &restructuringContext, + nullptr, + code->getFirstBlock(), + nullptr, + nullptr); + + return regionTree; + } +} diff --git a/source/slang/slang-ir-restructure.h b/source/slang/slang-ir-restructure.h new file mode 100644 index 000000000..6ec15f6d7 --- /dev/null +++ b/source/slang/slang-ir-restructure.h @@ -0,0 +1,261 @@ +// slang-ir-restructure.h +#pragma once + +#include "../core/slang-basic.h" + +namespace Slang +{ + class DiagnosticSink; + struct IRBlock; + struct IRGlobalValueWithCode; + struct IRInst; + struct IRLoop; + + /// A structured control-flow region. + /// + /// A `Region` is used to layer structured control flow information + /// over an existing IR control flow graph (CFG). Each `Region` + /// represents a sub-graph of the CFG such that control always + /// enters at the start of the region. + /// + class Region : public RefObject + { + public: + enum class Flavor + { + Simple, + If, + Break, + Continue, + Loop, + Switch, + }; + + Flavor getFlavor() { return flavor; } + + Region* getParent() { return parent; } + + /// Is this region a descendent of `other`? + /// + /// For the purpose of this query, a region + /// is a descendent of itself. + bool isDescendentOf(Region* other); + + /// Is this region a descendent of `block`? + /// + /// This tests is the region is a descendent + /// of any simple region for `block`. + bool isDescendentOf(IRBlock* block); + + protected: + Region(Flavor flavor, Region* parent) + : flavor(flavor) + , parent(parent) + {} + + /// What kind of region is this? + Flavor flavor; + + /// The parent region of this region. + Region* parent; + }; + + /// Base type for regions that have a "next" region. + /// + /// While we think of it as a region to execute + /// after this region, the `nextRegion` is actually + /// a *child* region, in that it can see local + /// values that were defined in this parent region + /// (and any other ancestor regions). + class SeqRegion : public Region + { + protected: + SeqRegion(Flavor flavor, Region* parent) + : Region(flavor, parent) + {} + + public: + /// The (child) region to execute after this one. + RefPtr nextRegion; + }; + + /// A simple region that encapsulates a basic block. + /// + class SimpleRegion : public SeqRegion + { + public: + SimpleRegion(Region* parent, IRBlock* block) + : SeqRegion(Region::Flavor::Simple, parent) + , block(block) + {} + + /// The basic block for this region. + IRBlock* block = nullptr; + + /// The next simple region for the same block + /// + /// A single IR basic block may turn into multiple regions, + /// if the restructuring pass has to duplicate it (this + /// currently happens for the continue clause in a `for` + /// loop if it has multiple `continue` sites. + /// + SimpleRegion* nextSimpleRegionForSameBlock = nullptr; + }; + + /// A conditional region, corresponding to an `if` + /// + class IfRegion : public SeqRegion + { + public: + IfRegion(Region* parent, IRInst* condition) + : SeqRegion(Region::Flavor::If, parent) + , condition(condition) + {} + + /// The IR value that controls the conditional branch + IRInst* condition; + + /// The region to execute if the `condition` is `true` + RefPtr thenRegion; + + /// The region to execute if the `condition` is `false` + RefPtr elseRegion; + }; + + /// Base type for regions that execution can `break` out of + class BreakableRegion : public SeqRegion + { + protected: + BreakableRegion(Flavor flavor, Region* parent) + : SeqRegion(flavor, parent) + {} + }; + + /// A region that expresses a `break` out of nested control flow. + /// + class BreakRegion : public Region + { + public: + BreakRegion(Region* parent, BreakableRegion* outerRegion) + : Region(Region::Flavor::Break, parent) + , outerRegion(outerRegion) + {} + + BreakableRegion* outerRegion; + }; + + /// A structured loop + class LoopRegion : public BreakableRegion + { + public: + LoopRegion(Region* parent, IRLoop* loopInst) + : BreakableRegion(Region::Flavor::Loop, parent) + , loopInst(loopInst) + {} + + /// The IR instruction that represents the branch into the loop. + /// We keep this instruction around because it may have decorations + /// that need to influence how we emit this loop. + /// + IRLoop* loopInst; + + /// The code inside the loop. + /// + /// The body region may include `break` or `continue` operations for this loop. + RefPtr body; + }; + + /// A region that expresses a `continue` for a structured loop. + /// + class ContinueRegion : public Region + { + public: + ContinueRegion(Region* parent, LoopRegion* outerRegion) + : Region(Region::Flavor::Continue, parent) + , outerRegion(outerRegion) + {} + + LoopRegion* outerRegion; + }; + + /// A structured `switch` statement. + class SwitchRegion : public BreakableRegion + { + public: + SwitchRegion(Region* parent, IRInst* condition) + : BreakableRegion(Region::Flavor::Switch, parent) + , condition(condition) + {} + + /// The IR value that controls the conditional branch + IRInst* condition; + + /// A collection of `case`s that share the same code. + class Case : public RefObject + { + public: + /// The various values that should branch to this case. + /// + /// It is possible for this list to be empty if this + /// is the `default` case and has no explicit values + /// that map to it. + /// + List values; + + /// The region to execute if this case is selected. + RefPtr body; + }; + + /// All of the cases for the `switch`. + /// + /// This includes any `default` cases. + /// + /// As an invariant, a case that "falls through" to another + /// should immediately precede its target in this list. + /// + List> cases; + + /// The default case, if any. + /// + /// It is valid for this to be `null` if there is no `default` case, + /// in which case the default behavior should be to branch to the region + /// after the `switch`. + /// + /// The default case must also be present in `cases`. + Case* defaultCase; + }; + + /// Container for all of the regions in a function. + /// + /// A `RegionTree` owns the `Region` objects associated with a function, + /// along with a mapping from basic blocks in the IR function to regions + /// in the tree. + /// + class RegionTree : public RefObject + { + public: + /// Type for the mapping from IR blocks to regions. + typedef Dictionary MapBlockToRegion; + + /// A dictionary to map from IR blocks to regions. + MapBlockToRegion mapBlockToRegion; + + /// The root region of the region tree. + RefPtr rootRegion; + + /// The IR function that was used to compute the region tree. + IRGlobalValueWithCode* irCode = nullptr; + }; + + /// Construct structrured regions to represent the control flow in an IR function. + /// + /// The resulting `RegionTree` will encode a structured (statement-like) + /// form for the control flow graph (CFG) of `code`. + /// In cases where our current restructuring approach is now powerful + /// enough to handle something in the input CFG, diagnostic messages + /// will be output to the given `sink`. + /// + RefPtr generateRegionTreeForFunc( + IRGlobalValueWithCode* code, + DiagnosticSink* sink); +} diff --git a/source/slang/slang-ir-sccp.cpp b/source/slang/slang-ir-sccp.cpp new file mode 100644 index 000000000..c330d2e0a --- /dev/null +++ b/source/slang/slang-ir-sccp.cpp @@ -0,0 +1,950 @@ +// slang-ir-sccp.cpp +#include "slang-ir-sccp.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" + +namespace Slang { + + +// This file implements the Spare Conditional Constant Propagation (SCCP) optimization. +// +// We will apply the optimization over individual functions, so we will start with +// a context struct for the state that we will share across functions: +// +struct SharedSCCPContext +{ + IRModule* module; + SharedIRBuilder sharedBuilder; +}; +// +// Next we have a context struct that will be applied for each function (or other +// code-bearing value) that we optimize: +// +struct SCCPContext +{ + SharedSCCPContext* shared; // shared state across functions + IRGlobalValueWithCode* code; // the function/code we are optimizing + + // The SCCP algorithm applies abstract interpretation to the code of the + // function using a "lattice" of values. We can think of a node on the + // lattice as representing a set of values that a given instruction + // might take on. + // + struct LatticeVal + { + // We will use three "flavors" of values on our lattice. + // + enum class Flavor + { + // The `None` flavor represent an empty set of values, meaning + // that we've never seen any indication that the instruction + // produces a (well-defined) value. This could indicate an + // instruction that does not appear to execute, but it could + // also indicate an instruction that we know invokes undefined + // behavior, so we can freely pick a value for it on a whim. + None, + + // The `Constant` flavor represents an instuction that we + // have only ever seen produce a single, fixed value. It's + // `value` field will hold that constant value. + Constant, + + // The `Any` flavor represents an instruction that might produce + // different values at runtime, so we go ahead and approximate + // this as it potentially yielding any value whatsoever. A + // more precise analysis could use sets or intervals of values, + // but for SCCP anything that could take on more than 1 value + // at runtime is assumed to be able to take on *any* value. + Any, + }; + + // The flavor of this value (`None`, `Constant`, or `Any`) + Flavor flavor; + + // If this is a `Constant` lattice value, then this field + // points to the IR instruction that defines the actual constant value. + // For all other flavors it should be null. + IRInst* value = nullptr; + + // For convenience, we define `static` factory functions to + // produce values of each of the flavors. + + static LatticeVal getNone() + { + LatticeVal result; + result.flavor = Flavor::None; + return result; + } + + static LatticeVal getAny() + { + LatticeVal result; + result.flavor = Flavor::Any; + return result; + } + + static LatticeVal getConstant(IRInst* value) + { + LatticeVal result; + result.flavor = Flavor::Constant; + result.value = value; + return result; + } + + // We also need to be able to test if two lattice + // values are equal, so that we can avoid updating + // downstream dependencies if our knowledge about + // an instruction hasn't actually changed. + // + bool operator==(LatticeVal const& that) + { + return this->flavor == that.flavor + && this->value == that.value; + } + + bool operator!=(LatticeVal const& that) + { + return !( *this == that ); + } + }; + + // If we imagine a variable (actually an SSA phi node...) that + // might be assigned lattice value A at one point in the code, + // and lattice value B at another point, we need a way to + // combine these to form our knowledge of the possible value(s) + // for the variable. + // + // In terms of computation on a lattice, we want the "meet" + // operation, which computes the lower bound on what we know. + // If we interpret our lattice values as sets, then we are + // trying to compute the union. + // + LatticeVal meet(LatticeVal const& left, LatticeVal const& right) + { + // If either value is `None` (the empty set), then the union + // will be the other value. + // + if(left.flavor == LatticeVal::Flavor::None) return right; + if(right.flavor == LatticeVal::Flavor::None) return left; + + // If either value is `Any` (the universal set), then + // the union is also the universal set. + // + if(left.flavor == LatticeVal::Flavor::Any) return LatticeVal::getAny(); + if(right.flavor == LatticeVal::Flavor::Any) return LatticeVal::getAny(); + + // At this point we've ruled out the case where either value + // is `None` *or* `Any`, so we can assume both values are + // `Constant`s. + SLANG_ASSERT(left.flavor == LatticeVal::Flavor::Constant); + // + SLANG_ASSERT(right.flavor == LatticeVal::Flavor::Constant); + + // If the two lattice values represent the *same* constant value + // (they are the same singleton set) then the union is that + // singleton set as well. + // + // TODO: This comparison assumes that constants with + // the same value with be represented with the + // same instruction, which is not *always* + // guaranteed in the IR today. + // + if(left.value == right.value) + return left; + + // Otherwise, we have two distinct singleton sets, and their + // union should be a set with two elements. We can't represent + // that on the lattice for SCCP, so the proper lower bound + // is the universal set (`Any`) + // + return LatticeVal::getAny(); + } + + // During the execution of the SCCP algorithm, we will track our best + // "estimate" so far of the set of values each instruction could take + // on. This amounts to a mapping from IR instructions to lattice values, + // where any instruction not present in the map is assumed to default + // to the `None` case (the empty set) + // + Dictionary mapInstToLatticeVal; + + // Updating the lattice value for an instruction is easy, but we'll + // use a simple function to make our intention clear. + // + void setLatticeVal(IRInst* inst, LatticeVal const& val) + { + mapInstToLatticeVal[inst] = val; + } + + // Querying the lattice value for an instruction isn't *just* a matter + // of looking it up in the dictionary, because we need to account for + // cases of lattice values that might come from outside the current + // function. + // + LatticeVal getLatticeVal(IRInst* inst) + { + // Instructions that represent constant values should always + // have a lattice value that reflects this. + // + switch( inst->op ) + { + case kIROp_IntLit: + case kIROp_FloatLit: + case kIROp_StringLit: + case kIROp_BoolLit: + return LatticeVal::getConstant(inst); + break; + + // TODO: We might want to start having support for constant + // values of aggregate types (e.g., a `makeArray` or `makeStruct` + // where all the operands are constant is itself a constant). + + default: + break; + } + + // We might be asked for the lattice value of an instruction + // not contained in the current function. When that happens, + // we will treat it as having potentially any value, rather + // than the default of none. + // + auto parentBlock = as(inst->getParent()); + if(!parentBlock || parentBlock->getParent() != code) return LatticeVal::getAny(); + + // Once the special cases are dealt with, we can look up in + // the dictionary and just return the value we get from it, + // or default to the `None` (empty set) case. + LatticeVal latticeVal; + if(mapInstToLatticeVal.TryGetValue(inst, latticeVal)) + return latticeVal; + return LatticeVal::getNone(); + } + + // Along the way we might need to create new IR instructions + // to represnet new constant values we find, or new control + // flow instructiosn when we start simplifying things. + // + IRBuilder builderStorage; + IRBuilder* getBuilder() { return &builderStorage; } + + // In order to perform constant folding, we need to be able to + // interpret an instruction over the lattice values. + // + LatticeVal interpretOverLattice(IRInst* inst) + { + SLANG_UNUSED(inst); + + // Certain instruction always produce constants, and we + // want to special-case them here. + switch( inst->op ) + { + case kIROp_IntLit: + case kIROp_FloatLit: + case kIROp_StringLit: + case kIROp_BoolLit: + return LatticeVal::getConstant(inst); + + // TODO: we might also want to special-case certain + // instructions where we shouldn't bother trying to + // constant-fold them and should just default to the + // `Any` value right away. + + default: + break; + } + + // TODO: We should now look up the lattice values for + // the operands of the instruction. + // + // If all of the operands have `Constant` lattice values, + // then we can potential execute the operation directly + // on those constant values, create a fresh `IRConstant`, + // and return a `Constant` lattice value for it. This + // would allow us to achieve true constant folding here. + // + // Textbook discussions of SCCP often point out that it + // is also possible to perform certain algebraic simplifications + // here, such as evaluating a multiply by a `Constant` zero + // to zero. + // + // As a default, if any operand has the `Any` value + // then the result of the operation should be treated as + // `Any`. There are exceptions to this, however, with the + // multiply-by-zero example being an important example. + // If we had previously decided that (Any * None) -> Any + // but then we refine our estimates and have (Any * Constant(0)) -> Constant(0) + // then we have violated the monotonicity rules for how + // our values move through the lattice, and we may break + // the convergence guarantees of the analysis. + // + // When we have a mix of `None` and `Constant` operands, + // then the `None` values imply that our operation is using + // uninitialized data or the results of undefined behavior. + // We could try to propagate the `None` through, and allow + // the compiler to speculatively assume that the operation + // produces whatever value we find convenient. Alternatively, + // we can be less aggressive and treat an operation with + // `None` inputs as producing `Any` to make sure we don't + // optimize the code based on non-obvious assumptions. + // + // For now we aren't implementing *any* folding logic here, + // for simplicity. This is the right place to add folding + // optimizations if/when we need them. + // + + // A safe default is to assume that every instruction not + // handled by one of the cases above could produce *any* + // value whatsoever. + return LatticeVal::getAny(); + } + + + // For basic blocks, we will do tracking very similar to what we do for + // ordinary instructions, just with a simpler lattice: every block + // will either be marked as "never executed" or in a "possibly executed" + // state. We track this as a set of the blocks that have been + // marked as possibly executed, plus a getter and setter function. + + HashSet executedBlocks; + + bool isMarkedAsExecuted(IRBlock* block) + { + return executedBlocks.Contains(block); + } + + void markAsExecuted(IRBlock* block) + { + executedBlocks.Add(block); + } + + // The core of the algorithm is based on two work lists. + // One list holds CFG nodes (basic blocks) that we have + // discovered might execute, and thus need to be processed, + // and the other holds SSA nodes (instructions) that need + // their "estimated" value to be updated. + + List cfgWorkList; + List ssaWorkList; + + // A key operation is to take an IR instruction and update + // its "estimated" value on the lattice. This might happen when + // we first discover the instruction could be executed, or + // when we discover that one or more of its operands has + // changed its lattice value so that we need to update our estimate. + // + void updateValueForInst(IRInst* inst) + { + // Block parameters are conceptually SSA "phi nodes", and it + // doesn't make sense to update their values here, because the + // actual candidate values for them comes from the predecessor blocks + // that provide arguments. We will see that logic shortly, when + // handling `IRUnconditionalBranch`. + // + if(as(inst)) + return; + + // We want to special-case terminator instructions here, + // since abstract interpretation of them should cause blocks to + // be marked as executed, etc. + // + if( auto terminator = as(inst) ) + { + if( auto unconditionalBranch = as(inst) ) + { + // When our abstract interpreter "executes" an unconditional + // branch, it needs to mark the target block as potentially + // executed. We do this by adding the target to our CFG work list. + // + auto target = unconditionalBranch->getTargetBlock(); + cfgWorkList.add(target); + + // Besides transferring control to another block, the other + // thing our unconditional branch instructions do is provide + // the arguments for phi nodes in the target block. + // We thus need to interpret each argument on the branch + // instruction like an "assignment" to the corresponding + // parameter of the target block. + // + UInt argCount = unconditionalBranch->getArgCount(); + IRParam* pp = target->getFirstParam(); + for( UInt aa = 0; aa < argCount; ++aa, pp = pp->getNextParam() ) + { + IRInst* arg = unconditionalBranch->getArg(aa); + IRInst* param = pp; + + // We expect the number of arguments and parameters to match, + // or else the IR is violating its own invariants. + // + SLANG_ASSERT(param); + + // We will update the value for the target block's parameter + // using our "meet" operation (union of sets of possible values) + // + LatticeVal oldVal = getLatticeVal(param); + + // If we've already determined that the block parameter could + // have any value whatsoever, there is no reason to bother + // updating it. + // + if(oldVal.flavor == LatticeVal::Flavor::Any) + continue; + + // We can look up the lattice value for the argument, + // because we should have interpreted it already + // + LatticeVal argVal = getLatticeVal(arg); + + // Now we apply the meet operation and see if the value changed. + // + LatticeVal newVal = meet(oldVal, argVal); + if( newVal != oldVal ) + { + // If the "estimated" value for the parameter has changed, + // then we need to update it in our dictionary, and then + // make sure that all of the users of the parameter get + // their estimates updated as well. + // + setLatticeVal(param, newVal); + for( auto use = param->firstUse; use; use = use->nextUse ) + { + ssaWorkList.add(use->getUser()); + } + } + } + } + else if( auto conditionalBranch = as(inst) ) + { + // An `IRConditionalBranch` is used for two-way branches. + // We will look at the lattice value for the condition, + // to see if we can narrow down which of the two ways + // might actually be taken. + // + auto condVal = getLatticeVal(conditionalBranch->getCondition()); + + // We do not expect to see a `None` value here, because that + // would mean the user is branching based on an undefined + // value. + // + // TODO: We should make sure there is no way for the user + // to trigger this assert with bad code that involves + // uninitialized variables. Right now we don't special + // case the `undefined` instruction when computing lattice + // values, so it shouldn't be a problem. + // + SLANG_ASSERT(condVal.flavor != LatticeVal::Flavor::None); + + // If the branch condition is a constant, we expect it to + // be a Boolean constant. We won't assert that it is the + // case here, just to be defensive. + // + if( condVal.flavor == LatticeVal::Flavor::Constant ) + { + if( auto boolConst = as(condVal.value) ) + { + // Only one of the two targe blocks is possible to + // execute, based on what we know of the condition, + // so we will add that target to our work list and + // bail out now. + // + auto target = boolConst->getValue() ? conditionalBranch->getTrueBlock() : conditionalBranch->getFalseBlock(); + cfgWorkList.add(target); + return; + } + } + + // As a fallback, if the condition isn't constant + // (or somehow wasn't a Boolean constnat), we will + // assume that either side of the branch could be + // taken, so that both of the target blocks are + // potentially executed. + // + cfgWorkList.add(conditionalBranch->getTrueBlock()); + cfgWorkList.add(conditionalBranch->getFalseBlock()); + } + else if( auto switchInst = as(inst) ) + { + // The handling of a `switch` instruction is similar to the + // case for a two-way branch, with the main difference that + // we have to deal with an integer condition value. + + auto condVal = getLatticeVal(switchInst->getCondition()); + SLANG_ASSERT(condVal.flavor != LatticeVal::Flavor::None); + + UInt caseCount = switchInst->getCaseCount(); + if( condVal.flavor == LatticeVal::Flavor::Constant ) + { + if( auto condConst = as(condVal.value) ) + { + // At this point we have a constant integer condition + // value, and we just need to find the case (if any) + // that matches it. We will default to considering + // the `default` label as the target. + // + auto target = switchInst->getDefaultLabel(); + for( UInt cc = 0; cc < caseCount; ++cc ) + { + if( auto caseConst = as(switchInst->getCaseValue(cc)) ) + { + if(caseConst->getValue() == condConst->getValue()) + { + target = switchInst->getCaseLabel(cc); + break; + } + } + } + + // Whatever single block we decided will get executed, + // we need to make sure it gets processed and then bail. + // + cfgWorkList.add(target); + return; + } + } + + // The fallback is to assume that the `switch` instruction might + // branch to any of its cases, or the `default` label. + // + for( UInt cc = 0; cc < caseCount; ++cc ) + { + cfgWorkList.add(switchInst->getCaseLabel(cc)); + } + cfgWorkList.add(switchInst->getDefaultLabel()); + } + + // There are other cases of terminator instructions not handled + // above (e.g., `return` instructions), but these can't cause + // additional basic blocks in the CFG to execute, so we don't + // need to consider them here. + // + // No matter what, we are done with a terminator instruction + // after inspecting it, and there is no reason we have to + // try and compute its "value." + return; + } + + // For an "ordinary" instruction, we will first check what value + // has been registered for it already. + // + LatticeVal oldVal = getLatticeVal(inst); + + // If we have previous decided that the instruction could take + // on any value whatsoever, then any further update to our + // guess can't expand things more, and so there is nothing to do. + // + if( oldVal.flavor == LatticeVal::Flavor::Any ) + { + return; + } + + // Otherwise, we compute a new guess at the value of + // the instruction based on the lattice values of the + // stuff it depends on. + // + LatticeVal newVal = interpretOverLattice(inst); + + // If nothing changed about our guess, then there is nothing + // further to do, because users of this instruction have + // already computed their guess based on its current value. + // + if(newVal == oldVal) + { + return; + } + + // If the guess did change, then we want to register our + // new guess as the lattice value for this instruction. + // + setLatticeVal(inst, newVal); + + // Next we iterate over all the users of this instruction + // and add them to our work list so that we can update + // their values based on the new information. + // + for( auto use = inst->firstUse; use; use = use->nextUse ) + { + ssaWorkList.add(use->getUser()); + } + } + + // The `apply()` function will run the full algorithm. + // + void apply() + { + // We start with the busy-work of setting up our IR builder. + // + builderStorage.sharedBuilder = &shared->sharedBuilder; + + // We expect the caller to have filtered out functions with + // no bodies, so there should always be at least one basic block. + // + auto firstBlock = code->getFirstBlock(); + SLANG_ASSERT(firstBlock); + + // The entry block is always going to be executed when the + // function gets called, so we will process it right away. + // + cfgWorkList.add(firstBlock); + + // The parameters of the first block are our function parameters, + // and we want to operate on the assumption that they could have + // any value possible, so we will record that in our dictionary. + // + for( auto pp : firstBlock->getParams() ) + { + setLatticeVal(pp, LatticeVal::getAny()); + } + + // Now we will iterate until both of our work lists go dry. + // + while(cfgWorkList.getCount() || ssaWorkList.getCount()) + { + // Note: there is a design choice to be had here + // around whether we do `if if` or `while while` + // for these nested checks. The choice can affect + // how long things take to converge. + + // We will start by processing any blocks that we + // have determined are potentially reachable. + // + while( cfgWorkList.getCount() ) + { + // We pop one block off of the work list. + // + auto block = cfgWorkList[0]; + cfgWorkList.fastRemoveAt(0); + + // We only want to process blocks that haven't + // already been marked as executed, so that we + // don't do redundant work. + // + if( !isMarkedAsExecuted(block) ) + { + // We should mark this new block as executed, + // so we can ignore it if it ever ends up on + // the work list again. + // + markAsExecuted(block); + + // If the block is potentially executed, then + // that means the instructions in the block are too. + // We will walk through the block and update our + // guess at the value of each instruction, which + // may in turn add other blocks/instructions to + // the work lists. + // + for( auto inst : block->getDecorationsAndChildren() ) + { + updateValueForInst(inst); + } + } + } + + // Once we've cleared the work list of blocks, we + // will start looking at individual instructions that + // need to be updated. + // + while( ssaWorkList.getCount() ) + { + // We pop one instruction that needs an update. + // + auto inst = ssaWorkList[0]; + ssaWorkList.fastRemoveAt(0); + + // Before updating the instruction, we will check if + // the parent block of the instructin is marked as + // being executed. If it isn't, there is no reason + // to update the value for the instruction, since + // it might never be used anyway. + // + IRBlock* block = as(inst->getParent()); + + // It is possible that an instruction ended up on + // our SSA work list because it is a user of an + // instruction in a block of `code`, but it is not + // itself an instruction a block of `code`. + // + // For example, if `code` is an `IRGeneric` that + // yields a function, then `inst` might be an + // instruction of that nested function, and not + // an instruction of the generic itself. + // Note that in such a case, the `inst` cannot + // possible affect the values computed in the outer + // generic, or the control-flow paths it might take, + // so there is no reason to consider it. + // + // We guard against this case by only processing `inst` + // if it is a child of a block in the current `code`. + // + if(!block || block->getParent() != code) + continue; + + if( isMarkedAsExecuted(block) ) + { + // If the instruction is potentially executed, we update + // its lattice value based on our abstraction interpretation. + // + updateValueForInst(inst); + } + } + } + + // Once the work lists are empty, our "guesses" at the value + // of different instructions and the potentially-executed-ness + // of blocks should have converged to a conservative steady state. + // + // We are now equiped to start using the information we've gathered + // to modify the code. + + // First, we will walk through all the code and replace instructions + // with constants where it is possible. + // + List instsToRemove; + for( auto block : code->getBlocks() ) + { + for( auto inst : block->getDecorationsAndChildren() ) + { + // We look for instructions that have a constnat value on + // the lattice. + // + LatticeVal latticeVal = getLatticeVal(inst); + if(latticeVal.flavor != LatticeVal::Flavor::Constant) + continue; + + // As a small sanity check, we won't go replacing an + // instruction with itself (this shouldn't really come + // up, since constants are supposed to be at the global + // scope right now) + // + IRInst* constantVal = latticeVal.value; + if(constantVal == inst) + continue; + + // We replace any uses of the instruction with its + // constant expected value, and add it to a list of + // instructions to be removed *iff* the instruction + // is known to have no obersvable side effects. + // + inst->replaceUsesWith(constantVal); + if( !inst->mightHaveSideEffects() ) + { + instsToRemove.add(inst); + } + } + } + + // Once we've replaced the uses of instructions that evaluate + // to constants, we make a second pass to remove the instructions + // themselves (or at least those without side effects). + // + for( auto inst : instsToRemove ) + { + inst->removeAndDeallocate(); + } + + // Next we are going to walk through all of the terminator + // instructions on blocks and look for ones that branch + // based on a constant condition. These will be rewritten + // to use direct branching instructions, which will of course + // need to be emitted using a builder. + // + auto builder = getBuilder(); + for( auto block : code->getBlocks() ) + { + auto terminator = block->getTerminator(); + + // We check if we have a `switch` instruction with a constant + // integer as its condition. + // + if( auto switchInst = as(terminator) ) + { + if( auto constVal = as(switchInst->getCondition()) ) + { + // We will select the one branch that gets taken, based + // on the constant condition value. The `default` label + // will of course be taken if no `case` label matches. + // + IRBlock* target = switchInst->getDefaultLabel(); + UInt caseCount = switchInst->getCaseCount(); + for(UInt cc = 0; cc < caseCount; ++cc) + { + auto caseVal = switchInst->getCaseValue(cc); + if(auto caseConst = as(caseVal)) + { + if( caseConst->getValue() == constVal->getValue() ) + { + target = switchInst->getCaseLabel(cc); + break; + } + } + } + + // Once we've found the target, we will emit a direct + // branch to it before the old terminator, and then remove + // the old terminator instruction. + // + builder->setInsertBefore(terminator); + builder->emitBranch(target); + terminator->removeAndDeallocate(); + } + } + else if(auto condBranchInst = as(terminator)) + { + if( auto constVal = as(condBranchInst->getCondition()) ) + { + // The case for a two-sided conditional branch is similar + // to the `switch` case, but simpler. + + IRBlock* target = constVal->getValue() ? condBranchInst->getTrueBlock() : condBranchInst->getFalseBlock(); + + builder->setInsertBefore(terminator); + builder->emitBranch(target); + terminator->removeAndDeallocate(); + } + + } + } + + // At this point we've replaced some conditional branches + // that would always go the same way (e.g., a `while(true)`), + // which should render some of our blocks unreachable. + // We will collect all those unreachable blocks into a list + // of blocks to be removed, and then go about trying to + // remove them. + // + List unreachableBlocks; + for( auto block : code->getBlocks() ) + { + if( !isMarkedAsExecuted(block) ) + { + unreachableBlocks.add(block); + } + } + // + // It might seem like we could just do: + // + // block->removeAndDeallocate(); + // + // for each of the blocks in `unreachableBlocks`, but there + // is a subtle point that has to be considered: + // + // We have a structured control-flow representation where + // certain branching instructions name "join points" where + // control flow logically re-converges. It is possible that + // one of our unreachable blocks is still being used as + // a join point. + // + // For example: + // + // if(A) + // return B; + // else + // return C; + // D; + // + // In the above example, the block that computes `D` is + // unreachable, but it is still the join point for the `if(A)` + // branch. + // + // Rather than complicate the encoding of join points to + // try to special-case an unreachable join point, we will + // instead retain the join point as a block with only a single + // `unreachable` instruction. + // + // To detect which blocks are unreachable and unreferenced, + // we will check which blocks have any uses. Of course, it + // might be that some of our unreachable blocks still reference + // one another (e.g., an unreachable loop) so we will start + // by removing the instructions from the bodies of our unreachable + // blocks to eliminate any cross-references between them. + // + for( auto block : unreachableBlocks ) + { + // TODO: In principle we could produce a diagnostic here + // if any of these unreachable blocks appears to have + // "non-trivial" code in it (that is, any code explicitly + // written by the user, and not just code synthesized by + // the compiler to satisfy language rules). Making that + // determination could be tricky, so for now we will + // err on the side of allowing unreachable code without + // a warning. + // + block->removeAndDeallocateAllDecorationsAndChildren(); + } + // + // At this point every one of our unreachable blocks is empty, + // and there should be no branches from reachable blocks + // to unreachable ones. + // + // We will iterate over our unreachable blocks, and process + // them differently based on whether they have any remaining uses. + // + for( auto block : unreachableBlocks ) + { + // At this point there had better be no edges branching to + // our block. We determined it was unreachable, so there had + // better not be branches from reachable blocks to this one, + // and all the unreachable blocks had their instructions + // removed, so there should be no branches to it from other + // unreachable blocks (or itself). + // + SLANG_ASSERT(block->getPredecessors().isEmpty()); + + // If the block is completely unreferenced, we can safely + // remove and deallocate it now. + // + if( !block->hasUses() ) + { + block->removeAndDeallocate(); + } + else + { + // Otherwise, the block has at least one use (but + // no predecessors), which should indicate that it + // is an unreachable join point. + // + // We will keep the block around, but its entire + // body will consist of a single `unreachable` + // instruction. + // + builder->setInsertInto(block); + builder->emitUnreachable(); + } + } + } +}; + +static void applySparseConditionalConstantPropagationRec( + SharedSCCPContext* shared, + IRInst* inst) +{ + if( auto code = as(inst) ) + { + if( code->getFirstBlock() ) + { + SCCPContext context; + context.shared = shared; + context.code = code; + context.apply(); + } + } + + for( auto childInst : inst->getDecorationsAndChildren() ) + { + applySparseConditionalConstantPropagationRec(shared, childInst); + } +} + +void applySparseConditionalConstantPropagation( + IRModule* module) +{ + SharedSCCPContext shared; + shared.module = module; + shared.sharedBuilder.module = module; + shared.sharedBuilder.session = module->getSession(); + + applySparseConditionalConstantPropagationRec(&shared, module->getModuleInst()); +} + +} + diff --git a/source/slang/slang-ir-sccp.h b/source/slang/slang-ir-sccp.h new file mode 100644 index 000000000..b557eefe3 --- /dev/null +++ b/source/slang/slang-ir-sccp.h @@ -0,0 +1,18 @@ +// slang-ir-sccp.h +#pragma once + +namespace Slang +{ + struct IRModule; + + /// Apply Sparse Conditional Constant Propagation (SCCP) to a module. + /// + /// This optimization replaces instructions that can only ever evaluate + /// to a single (well-defined) value with that constant value, and + /// also eliminates conditional branches where the condition will + /// always evaluate to a constant (which can lead to entire blocks + /// becoming dead code) + void applySparseConditionalConstantPropagation( + IRModule* module); +} + diff --git a/source/slang/slang-ir-serialize.cpp b/source/slang/slang-ir-serialize.cpp new file mode 100644 index 000000000..cbb774794 --- /dev/null +++ b/source/slang/slang-ir-serialize.cpp @@ -0,0 +1,2125 @@ +// slang-ir-serialize.cpp +#include "slang-ir-serialize.h" + +#include "../core/slang-text-io.h" +#include "../core/slang-byte-encode-util.h" + +#include "slang-ir-insts.h" + +#include "../core/slang-math.h" + +namespace Slang { + +// Needed for linkage with some compilers +/* static */ const IRSerialData::StringIndex IRSerialData::kNullStringIndex; +/* static */ const IRSerialData::StringIndex IRSerialData::kEmptyStringIndex; + +/* Note that an IRInst can be derived from, but when it derived from it's new members are IRUse variables, and they in +effect alias over the operands - and reflected in the operand count. There _could_ be other members after these IRUse +variables, but only a few types include extra data, and these do not have any operands: + +* IRConstant - Needs special-case handling +* IRModuleInst - Presumably we can just set to the module pointer on reconstruction + +Note! That on an IRInst there is an IRType* variable (accessed as getFullType()). As it stands it may NOT actually point +to an IRType derived type. Its 'ok' as long as it's an instruction that can be used in the place of the type. So this code does not +bother to check if it's correct, and just casts it. +*/ + +/* static */const IRSerialData::PayloadInfo IRSerialData::s_payloadInfos[int(Inst::PayloadType::CountOf)] = +{ + { 0, 0 }, // Empty + { 1, 0 }, // Operand_1 + { 2, 0 }, // Operand_2 + { 1, 0 }, // OperandAndUInt32, + { 0, 0 }, // OperandExternal - This isn't correct, Operand has to be specially handled + { 0, 1 }, // String_1, + { 0, 2 }, // String_2, + { 0, 0 }, // UInt32, + { 0, 0 }, // Float64, + { 0, 0 } // Int64, +}; + +static bool isTextureTypeBase(IROp opIn) +{ + const int op = (kIROpMeta_PseudoOpMask & opIn); + return op >= kIROp_FirstTextureTypeBase && op <= kIROp_LastTextureTypeBase; +} + +static bool isConstant(IROp opIn) +{ + const int op = (kIROpMeta_PseudoOpMask & opIn); + return op >= kIROp_FirstConstant && op <= kIROp_LastConstant; +} + +struct PrefixString; + +namespace { // anonymous + +struct CharReader +{ + char operator()(int pos) const { SLANG_UNUSED(pos); return *m_pos++; } + CharReader(const char* pos) :m_pos(pos) {} + mutable const char* m_pos; +}; + +} // anonymous + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! StringRepresentationCache !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +StringRepresentationCache::StringRepresentationCache(): + m_stringTable(nullptr), + m_namePool(nullptr), + m_scopeManager(nullptr) +{ +} + +void StringRepresentationCache::init(const List* stringTable, NamePool* namePool, ObjectScopeManager* scopeManager) +{ + m_stringTable = stringTable; + m_namePool = namePool; + m_scopeManager = scopeManager; + + // Decode the table + m_entries.setCount(StringSlicePool::kNumDefaultHandles); + SLANG_COMPILE_TIME_ASSERT(StringSlicePool::kNumDefaultHandles == 2); + + { + Entry& entry = m_entries[0]; + entry.m_numChars = 0; + entry.m_startIndex = 0; + entry.m_object = nullptr; + } + { + Entry& entry = m_entries[1]; + entry.m_numChars = 0; + entry.m_startIndex = 0; + entry.m_object = nullptr; + } + + { + const char* start = stringTable->begin(); + const char* cur = start; + const char* end = stringTable->end(); + + while (cur < end) + { + CharReader reader(cur); + const int len = GetUnicodePointFromUTF8(reader); + + Entry entry; + entry.m_startIndex = uint32_t(reader.m_pos - start); + entry.m_numChars = len; + entry.m_object = nullptr; + + m_entries.add(entry); + + cur = reader.m_pos + len; + } + } + + m_entries.compress(); +} + +Name* StringRepresentationCache::getName(Handle handle) +{ + if (handle == StringSlicePool::kNullHandle) + { + return nullptr; + } + + Entry& entry = m_entries[int(handle)]; + if (entry.m_object) + { + Name* name = dynamicCast(entry.m_object); + if (name) + { + return name; + } + StringRepresentation* stringRep = static_cast(entry.m_object); + // Promote it to a name + name = m_namePool->getName(String(stringRep)); + entry.m_object = name; + return name; + } + + Name* name = m_namePool->getName(String(getStringSlice(handle))); + entry.m_object = name; + return name; +} + +String StringRepresentationCache::getString(Handle handle) +{ + return String(getStringRepresentation(handle)); +} + +UnownedStringSlice StringRepresentationCache::getStringSlice(Handle handle) const +{ + const Entry& entry = m_entries[int(handle)]; + const char* start = m_stringTable->begin(); + + return UnownedStringSlice(start + entry.m_startIndex, int(entry.m_numChars)); +} + +StringRepresentation* StringRepresentationCache::getStringRepresentation(Handle handle) +{ + if (handle == StringSlicePool::kNullHandle || handle == StringSlicePool::kEmptyHandle) + { + return nullptr; + } + + Entry& entry = m_entries[int(handle)]; + if (entry.m_object) + { + Name* name = dynamicCast(entry.m_object); + if (name) + { + return name->text.getStringRepresentation(); + } + return static_cast(entry.m_object); + } + + const UnownedStringSlice slice = getStringSlice(handle); + const UInt size = slice.size(); + + StringRepresentation* stringRep = StringRepresentation::createWithCapacityAndLength(size, size); + memcpy(stringRep->getData(), slice.begin(), size); + entry.m_object = stringRep; + + // Keep the StringRepresentation in scope + m_scopeManager->add(stringRep); + + return stringRep; +} + +char* StringRepresentationCache::getCStr(Handle handle) +{ + // It turns out StringRepresentation is always 0 terminated, so can just use that + StringRepresentation* rep = getStringRepresentation(handle); + return rep->getData(); +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! SerialStringTableUtil !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +/* static */void SerialStringTableUtil::encodeStringTable(const StringSlicePool& pool, List& stringTable) +{ + // Skip the default handles -> nothing is encoded via them + return encodeStringTable(pool.getSlices().begin() + StringSlicePool::kNumDefaultHandles, pool.getNumSlices() - StringSlicePool::kNumDefaultHandles, stringTable); +} + +/* static */void SerialStringTableUtil::encodeStringTable(const UnownedStringSlice* slices, size_t numSlices, List& stringTable) +{ + stringTable.clear(); + for (size_t i = 0; i < numSlices; ++i) + { + const UnownedStringSlice slice = slices[i]; + const int len = int(slice.size()); + + // We need to write into the the string array + char prefixBytes[6]; + const int numPrefixBytes = EncodeUnicodePointToUTF8(prefixBytes, len); + const Index baseIndex = stringTable.getCount(); + + stringTable.setCount(baseIndex + numPrefixBytes + len); + + char* dst = stringTable.begin() + baseIndex; + + memcpy(dst, prefixBytes, numPrefixBytes); + memcpy(dst + numPrefixBytes, slice.begin(), len); + } +} + +/* static */void SerialStringTableUtil::appendDecodedStringTable(const List& stringTable, List& slicesOut) +{ + const char* start = stringTable.begin(); + const char* cur = start; + const char* end = stringTable.end(); + + while (cur < end) + { + CharReader reader(cur); + const int len = GetUnicodePointFromUTF8(reader); + slicesOut.add(UnownedStringSlice(reader.m_pos, len)); + cur = reader.m_pos + len; + } +} + +/* static */void SerialStringTableUtil::decodeStringTable(const List& stringTable, List& slicesOut) +{ + slicesOut.setCount(2); + slicesOut[0] = UnownedStringSlice(nullptr, size_t(0)); + slicesOut[1] = UnownedStringSlice("", size_t(0)); + + appendDecodedStringTable(stringTable, slicesOut); +} + +/* static */void SerialStringTableUtil::calcStringSlicePoolMap(const List& slices, StringSlicePool& pool, List& indexMapOut) +{ + SLANG_ASSERT(slices.getCount() >= StringSlicePool::kNumDefaultHandles); + SLANG_ASSERT(slices[int(StringSlicePool::kNullHandle)] == "" && slices[int(StringSlicePool::kNullHandle)].begin() == nullptr); + SLANG_ASSERT(slices[int(StringSlicePool::kEmptyHandle)] == ""); + + indexMapOut.setCount(slices.getCount()); + // Set up all of the defaults + for (int i = 0; i < StringSlicePool::kNumDefaultHandles; ++i) + { + indexMapOut[i] = StringSlicePool::Handle(i); + } + + const Index numSlices = slices.getCount(); + for (Index i = StringSlicePool::kNumDefaultHandles; i < numSlices ; ++i) + { + indexMapOut[i] = pool.add(slices[i]); + } +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! IRSerialData !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +template +static size_t _calcArraySize(const List& list) +{ + return list.getCount() * sizeof(T); +} + +size_t IRSerialData::calcSizeInBytes() const +{ + return + _calcArraySize(m_insts) + + _calcArraySize(m_childRuns) + + _calcArraySize(m_externalOperands) + + _calcArraySize(m_stringTable) + + /* Raw source locs */ + _calcArraySize(m_rawSourceLocs) + + /* Debug */ + _calcArraySize(m_debugStringTable) + + _calcArraySize(m_debugLineInfos) + + _calcArraySize(m_debugSourceInfos) + + _calcArraySize(m_debugAdjustedLineInfos) + + _calcArraySize(m_debugSourceLocRuns); +} + +IRSerialData::IRSerialData() +{ + clear(); +} + +void IRSerialData::clear() +{ + // First Instruction is null + m_insts.setCount(1); + memset(&m_insts[0], 0, sizeof(Inst)); + + m_childRuns.clear(); + m_externalOperands.clear(); + m_rawSourceLocs.clear(); + + m_stringTable.clear(); + + // Debug data + m_debugLineInfos.clear(); + m_debugAdjustedLineInfos.clear(); + m_debugSourceInfos.clear(); + m_debugSourceLocRuns.clear(); + m_debugStringTable.clear(); +} + +template +static bool _isEqual(const List& aIn, const List& bIn) +{ + if (aIn.getCount() != bIn.getCount()) + { + return false; + } + + size_t size = size_t(aIn.getCount()); + + const T* a = aIn.begin(); + const T* b = bIn.begin(); + + if (a == b) + { + return true; + } + + for (size_t i = 0; i < size; ++i) + { + if (a[i] != b[i]) + { + return false; + } + } + + return true; +} + +bool IRSerialData::operator==(const ThisType& rhs) const +{ + return (this == &rhs) || + (_isEqual(m_insts, rhs.m_insts) && + _isEqual(m_childRuns, rhs.m_childRuns) && + _isEqual(m_externalOperands, rhs.m_externalOperands) && + _isEqual(m_rawSourceLocs, rhs.m_rawSourceLocs) && + _isEqual(m_stringTable, rhs.m_stringTable) && + /* Debug */ + _isEqual(m_debugStringTable, rhs.m_debugStringTable) && + _isEqual(m_debugLineInfos, rhs.m_debugLineInfos) && + _isEqual(m_debugAdjustedLineInfos, rhs.m_debugAdjustedLineInfos) && + _isEqual(m_debugSourceInfos, rhs.m_debugSourceInfos) && + _isEqual(m_debugSourceLocRuns, rhs.m_debugSourceLocRuns)); +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! IRSerialWriter !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +void IRSerialWriter::_addInstruction(IRInst* inst) +{ + // It cannot already be in the map + SLANG_ASSERT(!m_instMap.ContainsKey(inst)); + + // Add to the map + m_instMap.Add(inst, Ser::InstIndex(m_insts.getCount())); + m_insts.add(inst); +} + +#if 0 +// Find a view index that matches the view by file (and perhaps other characteristics in the future) +static int _findSourceViewIndex(const List& viewsIn, SourceView* view) +{ + const int numViews = int(viewsIn.Count()); + SourceView*const* views = viewsIn.begin(); + + SourceFile* sourceFile = view->getSourceFile(); + + for (int i = 0; i < numViews; ++i) + { + SourceView* curView = views[i]; + // For now we just match on source file + if (curView->getSourceFile() == sourceFile) + { + // It's a hit + return i; + } + } + return -1; +} +#endif + +void IRSerialWriter::_addDebugSourceLocRun(SourceLoc sourceLoc, uint32_t startInstIndex, uint32_t numInsts) +{ + SourceView* sourceView = m_sourceManager->findSourceView(sourceLoc); + if (!sourceView) + { + return; + } + + SourceFile* sourceFile = sourceView->getSourceFile(); + DebugSourceFile* debugSourceFile; + { + RefPtr* ptrDebugSourceFile = m_debugSourceFileMap.TryGetValue(sourceFile); + if (ptrDebugSourceFile == nullptr) + { + const SourceLoc::RawValue baseSourceLoc = m_debugFreeSourceLoc; + m_debugFreeSourceLoc += SourceLoc::RawValue(sourceView->getRange().getSize() + 1); + + debugSourceFile = new DebugSourceFile(sourceFile, baseSourceLoc); + m_debugSourceFileMap.Add(sourceFile, debugSourceFile); + } + else + { + debugSourceFile = *ptrDebugSourceFile; + } + } + + // We need to work out the line index + + int offset = sourceView->getRange().getOffset(sourceLoc); + int lineIndex = sourceFile->calcLineIndexFromOffset(offset); + + IRSerialData::DebugLineInfo lineInfo; + lineInfo.m_lineStartOffset = sourceFile->getLineBreakOffsets()[lineIndex]; + lineInfo.m_lineIndex = lineIndex; + + if (!debugSourceFile->hasLineIndex(lineIndex)) + { + // Add the information about the line + int entryIndex = sourceView->findEntryIndex(sourceLoc); + if (entryIndex < 0) + { + debugSourceFile->m_lineInfos.add(lineInfo); + } + else + { + const auto& entry = sourceView->getEntries()[entryIndex]; + + IRSerialData::DebugAdjustedLineInfo adjustedLineInfo; + adjustedLineInfo.m_lineInfo = lineInfo; + adjustedLineInfo.m_pathStringIndex = Ser::kNullStringIndex; + + if (StringSlicePool::hasContents(entry.m_pathHandle)) + { + UnownedStringSlice slice = sourceView->getSourceManager()->getStringSlicePool().getSlice(entry.m_pathHandle); + SLANG_ASSERT(slice.size() > 0); + adjustedLineInfo.m_pathStringIndex = Ser::StringIndex(m_debugStringSlicePool.add(slice)); + } + + adjustedLineInfo.m_adjustedLineIndex = lineIndex + entry.m_lineAdjust; + + debugSourceFile->m_adjustedLineInfos.add(adjustedLineInfo); + } + + debugSourceFile->setHasLineIndex(lineIndex); + } + + // Add the run + IRSerialData::SourceLocRun sourceLocRun; + sourceLocRun.m_numInst = numInsts; + sourceLocRun.m_startInstIndex = IRSerialData::InstIndex(startInstIndex); + sourceLocRun.m_sourceLoc = uint32_t(debugSourceFile->m_baseSourceLoc + offset); + + m_serialData->m_debugSourceLocRuns.add(sourceLocRun); +} + +Result IRSerialWriter::_calcDebugInfo() +{ + // We need to find the unique source Locs + // We are not going to store SourceLocs directly, because there may be multiple views mapping down to + // the same underlying source file + + // First find all the unique locs + struct InstLoc + { + typedef InstLoc ThisType; + + SLANG_FORCE_INLINE bool operator<(const ThisType& rhs) const { return sourceLoc < rhs.sourceLoc || (sourceLoc == rhs.sourceLoc && instIndex < rhs.instIndex); } + + uint32_t instIndex; + uint32_t sourceLoc; + }; + + // Find all of the source locations and their associated instructions + List instLocs; + const Index numInsts = m_insts.getCount(); + for (Index i = 1; i < numInsts; i++) + { + IRInst* srcInst = m_insts[i]; + if (!srcInst->sourceLoc.isValid()) + { + continue; + } + InstLoc instLoc; + instLoc.instIndex = uint32_t(i); + instLoc.sourceLoc = uint32_t(srcInst->sourceLoc.getRaw()); + instLocs.add(instLoc); + } + + // Sort them + instLocs.sort(); + m_debugFreeSourceLoc = 1; + + // Look for runs + const InstLoc* startInstLoc = instLocs.begin(); + const InstLoc* endInstLoc = instLocs.end(); + + while (startInstLoc < endInstLoc) + { + const uint32_t startSourceLoc = startInstLoc->sourceLoc; + + // Find the run with the same source loc + + const InstLoc* curInstLoc = startInstLoc + 1; + uint32_t curInstIndex = startInstLoc->instIndex + 1; + + // Find the run size with same source loc and run of instruction indices + for (; curInstLoc < endInstLoc && curInstLoc->sourceLoc == startSourceLoc && curInstLoc->instIndex == curInstIndex; ++curInstLoc, ++curInstIndex) + { + } + + // Try adding the run + _addDebugSourceLocRun(SourceLoc::fromRaw(startSourceLoc), startInstLoc->instIndex, curInstIndex - startInstLoc->instIndex); + + // Next + startInstLoc = curInstLoc; + } + + // Okay we can now calculate the final source information + + for (auto& pair : m_debugSourceFileMap) + { + DebugSourceFile* debugSourceFile = pair.Value; + SourceFile* sourceFile = debugSourceFile->m_sourceFile; + + IRSerialData::DebugSourceInfo sourceInfo; + + sourceInfo.m_numLines = uint32_t(debugSourceFile->m_sourceFile->getLineBreakOffsets().getCount()); + + sourceInfo.m_startSourceLoc = uint32_t(debugSourceFile->m_baseSourceLoc); + sourceInfo.m_endSourceLoc = uint32_t(debugSourceFile->m_baseSourceLoc + sourceFile->getContentSize()); + + sourceInfo.m_pathIndex = Ser::StringIndex(m_debugStringSlicePool.add(sourceFile->getPathInfo().foundPath)); + + sourceInfo.m_lineInfosStartIndex = uint32_t(m_serialData->m_debugLineInfos.getCount()); + sourceInfo.m_adjustedLineInfosStartIndex = uint32_t(m_serialData->m_debugAdjustedLineInfos.getCount()); + + sourceInfo.m_numLineInfos = uint32_t(debugSourceFile->m_lineInfos.getCount()); + sourceInfo.m_numAdjustedLineInfos = uint32_t(debugSourceFile->m_adjustedLineInfos.getCount()); + + // Add the line infos + m_serialData->m_debugLineInfos.addRange(debugSourceFile->m_lineInfos.begin(), debugSourceFile->m_lineInfos.getCount()); + m_serialData->m_debugAdjustedLineInfos.addRange(debugSourceFile->m_adjustedLineInfos.begin(), debugSourceFile->m_adjustedLineInfos.getCount()); + + // Add the source info + m_serialData->m_debugSourceInfos.add(sourceInfo); + } + + // Convert the string pool + SerialStringTableUtil::encodeStringTable(m_debugStringSlicePool, m_serialData->m_debugStringTable); + + return SLANG_OK; +} + +Result IRSerialWriter::write(IRModule* module, SourceManager* sourceManager, OptionFlags options, IRSerialData* serialData) +{ + typedef Ser::Inst::PayloadType PayloadType; + + m_sourceManager = sourceManager; + m_serialData = serialData; + + serialData->clear(); + + // We reserve 0 for null + m_insts.clear(); + m_insts.add(nullptr); + + // Reset + m_instMap.Clear(); + m_decorations.clear(); + + // Stack for parentInst + List parentInstStack; + + IRModuleInst* moduleInst = module->getModuleInst(); + parentInstStack.add(moduleInst); + + // Add to the map + _addInstruction(moduleInst); + + // Traverse all of the instructions + while (parentInstStack.getCount()) + { + // If it's in the stack it is assumed it is already in the inst map + IRInst* parentInst = parentInstStack.getLast(); + parentInstStack.removeLast(); + SLANG_ASSERT(m_instMap.ContainsKey(parentInst)); + + // Okay we go through each of the children in order. If they are IRInstParent derived, we add to stack to process later + // cos we want breadth first so the order of children is the same as their index order, meaning we don't need to store explicit indices + const Ser::InstIndex startChildInstIndex = Ser::InstIndex(m_insts.getCount()); + + IRInstListBase childrenList = parentInst->getDecorationsAndChildren(); + for (IRInst* child : childrenList) + { + // This instruction can't be in the map... + SLANG_ASSERT(!m_instMap.ContainsKey(child)); + + _addInstruction(child); + + parentInstStack.add(child); + } + + // If it had any children, then store the information about it + if (Ser::InstIndex(m_insts.getCount()) != startChildInstIndex) + { + Ser::InstRun run; + run.m_parentIndex = m_instMap[parentInst]; + run.m_startInstIndex = startChildInstIndex; + run.m_numChildren = Ser::SizeType(m_insts.getCount() - int(startChildInstIndex)); + + m_serialData->m_childRuns.add(run); + } + } + +#if 0 + { + List workInsts; + calcInstructionList(module, workInsts); + SLANG_ASSERT(workInsts.Count() == m_insts.Count()); + for (UInt i = 0; i < workInsts.Count(); ++i) + { + SLANG_ASSERT(workInsts[i] == m_insts[i]); + } + } +#endif + + // Set to the right size + m_serialData->m_insts.setCount(m_insts.getCount()); + // Clear all instructions + memset(m_serialData->m_insts.begin(), 0, sizeof(Ser::Inst) * m_serialData->m_insts.getCount()); + + // Need to set up the actual instructions + { + const Index numInsts = m_insts.getCount(); + + for (Index i = 1; i < numInsts; ++i) + { + IRInst* srcInst = m_insts[i]; + Ser::Inst& dstInst = m_serialData->m_insts[i]; + + // Can't be any pseudo ops + SLANG_ASSERT(!isPseudoOp(srcInst->op)); + + dstInst.m_op = uint8_t(srcInst->op & kIROpMeta_OpMask); + dstInst.m_payloadType = PayloadType::Empty; + + dstInst.m_resultTypeIndex = getInstIndex(srcInst->getFullType()); + + IRConstant* irConst = as(srcInst); + if (irConst) + { + switch (srcInst->op) + { + // Special handling for the ir const derived types + case kIROp_StringLit: + { + auto stringLit = static_cast(srcInst); + dstInst.m_payloadType = PayloadType::String_1; + dstInst.m_payload.m_stringIndices[0] = getStringIndex(stringLit->getStringSlice()); + break; + } + case kIROp_IntLit: + { + dstInst.m_payloadType = PayloadType::Int64; + dstInst.m_payload.m_int64 = irConst->value.intVal; + break; + } + case kIROp_PtrLit: + { + dstInst.m_payloadType = PayloadType::Int64; + dstInst.m_payload.m_int64 = (intptr_t) irConst->value.ptrVal; + break; + } + case kIROp_FloatLit: + { + dstInst.m_payloadType = PayloadType::Float64; + dstInst.m_payload.m_float64 = irConst->value.floatVal; + break; + } + case kIROp_BoolLit: + { + dstInst.m_payloadType = PayloadType::UInt32; + dstInst.m_payload.m_uint32 = irConst->value.intVal ? 1 : 0; + break; + } + default: + { + SLANG_RELEASE_ASSERT(!"Unhandled constant type"); + return SLANG_FAIL; + } + } + continue; + } + + IRTextureTypeBase* textureBase = as(srcInst); + if (textureBase) + { + dstInst.m_payloadType = PayloadType::OperandAndUInt32; + dstInst.m_payload.m_operandAndUInt32.m_uint32 = uint32_t(srcInst->op) >> kIROpMeta_OtherShift; + dstInst.m_payload.m_operandAndUInt32.m_operand = getInstIndex(textureBase->getElementType()); + continue; + } + + // ModuleInst is different, in so far as it holds a pointer to IRModule, but we don't need + // to save that off in a special way, so can just use regular path + + const int numOperands = int(srcInst->operandCount); + Ser::InstIndex* dstOperands = nullptr; + + if (numOperands <= Ser::Inst::kMaxOperands) + { + // Checks the compile below is valid + SLANG_COMPILE_TIME_ASSERT(PayloadType(0) == PayloadType::Empty && PayloadType(1) == PayloadType::Operand_1 && PayloadType(2) == PayloadType::Operand_2); + + dstInst.m_payloadType = PayloadType(numOperands); + dstOperands = dstInst.m_payload.m_operands; + } + else + { + dstInst.m_payloadType = PayloadType::OperandExternal; + + int operandArrayBaseIndex = int(m_serialData->m_externalOperands.getCount()); + m_serialData->m_externalOperands.setCount(operandArrayBaseIndex + numOperands); + + dstOperands = m_serialData->m_externalOperands.begin() + operandArrayBaseIndex; + + auto& externalOperands = dstInst.m_payload.m_externalOperand; + externalOperands.m_arrayIndex = Ser::ArrayIndex(operandArrayBaseIndex); + externalOperands.m_size = Ser::SizeType(numOperands); + } + + for (int j = 0; j < numOperands; ++j) + { + const Ser::InstIndex dstInstIndex = getInstIndex(srcInst->getOperand(j)); + dstOperands[j] = dstInstIndex; + } + } + } + + // Convert strings into a string table + { + SerialStringTableUtil::encodeStringTable(m_stringSlicePool, serialData->m_stringTable); + } + + // If the option to use RawSourceLocations is enabled, serialize out as is + if (options & OptionFlag::RawSourceLocation) + { + const Index numInsts = m_insts.getCount(); + serialData->m_rawSourceLocs.setCount(numInsts); + + Ser::RawSourceLoc* dstLocs = serialData->m_rawSourceLocs.begin(); + // 0 is null, just mark as no location + dstLocs[0] = Ser::RawSourceLoc(0); + for (Index i = 1; i < numInsts; ++i) + { + IRInst* srcInst = m_insts[i]; + dstLocs[i] = Ser::RawSourceLoc(srcInst->sourceLoc.getRaw()); + } + } + + if (options & OptionFlag::DebugInfo) + { + _calcDebugInfo(); + } + + m_serialData = nullptr; + return SLANG_OK; +} + +template +static size_t _calcChunkSize(IRSerialBinary::CompressionType compressionType, const List& array) +{ + typedef IRSerialBinary Bin; + + if (array.getCount()) + { + switch (compressionType) + { + case Bin::CompressionType::None: + { + const size_t size = sizeof(Bin::ArrayHeader) + sizeof(T) * array.getCount(); + return (size + 3) & ~size_t(3); + } + case Bin::CompressionType::VariableByteLite: + { + const size_t payloadSize = ByteEncodeUtil::calcEncodeLiteSizeUInt32((const uint32_t*)array.begin(), (array.getCount() * sizeof(T)) / sizeof(uint32_t)); + const size_t size = sizeof(Bin::CompressedArrayHeader) + payloadSize; + return (size + 3) & ~size_t(3); + } + default: + { + SLANG_ASSERT(!"Unhandled compression type"); + return 0; + } + } + } + else + { + return 0; + } +} + +static Result _writeArrayChunk(IRSerialBinary::CompressionType compressionType, uint32_t chunkId, const void* data, size_t numEntries, size_t typeSize, Stream* stream) +{ + typedef IRSerialBinary Bin; + + if (numEntries == 0) + { + return SLANG_OK; + } + + size_t payloadSize; + + switch (compressionType) + { + case Bin::CompressionType::None: + { + payloadSize = sizeof(Bin::ArrayHeader) - sizeof(Bin::Chunk) + typeSize * numEntries; + + Bin::ArrayHeader header; + header.m_chunk.m_type = chunkId; + header.m_chunk.m_size = uint32_t(payloadSize); + header.m_numEntries = uint32_t(numEntries); + + stream->Write(&header, sizeof(header)); + + stream->Write(data, typeSize * numEntries); + break; + } + case Bin::CompressionType::VariableByteLite: + { + List compressedPayload; + + size_t numCompressedEntries = (numEntries * typeSize) / sizeof(uint32_t); + + ByteEncodeUtil::encodeLiteUInt32((const uint32_t*)data, numCompressedEntries, compressedPayload); + + payloadSize = sizeof(Bin::CompressedArrayHeader) - sizeof(Bin::Chunk) + compressedPayload.getCount(); + + Bin::CompressedArrayHeader header; + header.m_chunk.m_type = SLANG_MAKE_COMPRESSED_FOUR_CC(chunkId); + header.m_chunk.m_size = uint32_t(payloadSize); + header.m_numEntries = uint32_t(numEntries); + header.m_numCompressedEntries = uint32_t(numCompressedEntries); + + stream->Write(&header, sizeof(header)); + + stream->Write(compressedPayload.begin(), compressedPayload.getCount()); + break; + } + default: + { + return SLANG_FAIL; + } + } + // All chunks have sizes rounded to dword size + if (payloadSize & 3) + { + const uint8_t pad[4] = { 0, 0, 0, 0 }; + // Pad outs + int padSize = 4 - (payloadSize & 3); + stream->Write(pad, padSize); + } + + return SLANG_OK; +} + +template +Result _writeArrayChunk(IRSerialBinary::CompressionType compressionType, uint32_t chunkId, const List& array, Stream* stream) +{ + return _writeArrayChunk(compressionType, chunkId, array.begin(), size_t(array.getCount()), sizeof(T), stream); +} + +Result _encodeInsts(IRSerialBinary::CompressionType compressionType, const List& instsIn, List& encodeArrayOut) +{ + typedef IRSerialBinary Bin; + typedef IRSerialData::Inst::PayloadType PayloadType; + + if (compressionType != Bin::CompressionType::VariableByteLite) + { + return SLANG_FAIL; + } + encodeArrayOut.clear(); + + const size_t numInsts = size_t(instsIn.getCount()); + const IRSerialData::Inst* insts = instsIn.begin(); + + uint8_t* encodeOut = encodeArrayOut.begin(); + uint8_t* encodeEnd = encodeArrayOut.end(); + + // Calculate the maximum instruction size with worst case possible encoding + // 2 bytes hold the payload size, and the result type + // Note that if there were some free bits, we could encode some of this stuff into bits, but if we remove payloadType, then there are no free bits + const size_t maxInstSize = 2 + ByteEncodeUtil::kMaxLiteEncodeUInt32 + Math::Max(sizeof(insts->m_payload.m_float64), size_t(2 * ByteEncodeUtil::kMaxLiteEncodeUInt32)); + + for (size_t i = 0; i < numInsts; ++i) + { + const auto& inst = insts[i]; + + // Make sure there is space for the largest possible instruction + if (encodeOut + maxInstSize >= encodeEnd) + { + const size_t offset = size_t(encodeOut - encodeArrayOut.begin()); + + const UInt oldCapacity = encodeArrayOut.getCapacity(); + + encodeArrayOut.reserve(oldCapacity + (oldCapacity >> 1) + maxInstSize); + const UInt capacity = encodeArrayOut.getCapacity(); + encodeArrayOut.setCount(capacity); + + encodeOut = encodeArrayOut.begin() + offset; + encodeEnd = encodeArrayOut.end(); + } + + *encodeOut++ = uint8_t(inst.m_op); + *encodeOut++ = uint8_t(inst.m_payloadType); + + encodeOut += ByteEncodeUtil::encodeLiteUInt32((uint32_t)inst.m_resultTypeIndex, encodeOut); + + switch (inst.m_payloadType) + { + case PayloadType::Empty: + { + break; + } + case PayloadType::Operand_1: + case PayloadType::String_1: + case PayloadType::UInt32: + { + // 1 UInt32 + encodeOut += ByteEncodeUtil::encodeLiteUInt32((uint32_t)inst.m_payload.m_operands[0], encodeOut); + break; + } + case PayloadType::Operand_2: + case PayloadType::OperandAndUInt32: + case PayloadType::OperandExternal: + case PayloadType::String_2: + { + // 2 UInt32 + encodeOut += ByteEncodeUtil::encodeLiteUInt32((uint32_t)inst.m_payload.m_operands[0], encodeOut); + encodeOut += ByteEncodeUtil::encodeLiteUInt32((uint32_t)inst.m_payload.m_operands[1], encodeOut); + break; + } + case PayloadType::Float64: + { + memcpy(encodeOut, &inst.m_payload.m_float64, sizeof(inst.m_payload.m_float64)); + encodeOut += sizeof(inst.m_payload.m_float64); + break; + } + case PayloadType::Int64: + { + memcpy(encodeOut, &inst.m_payload.m_int64, sizeof(inst.m_payload.m_int64)); + encodeOut += sizeof(inst.m_payload.m_int64); + break; + } + } + } + + // Fix the size + encodeArrayOut.setCount(UInt(encodeOut - encodeArrayOut.begin())); + return SLANG_OK; +} + +Result _writeInstArrayChunk(IRSerialBinary::CompressionType compressionType, uint32_t chunkId, const List& array, Stream* stream) +{ + typedef IRSerialBinary Bin; + if (array.getCount() == 0) + { + return SLANG_OK; + } + + switch (compressionType) + { + case Bin::CompressionType::None: + { + return _writeArrayChunk(compressionType, chunkId, array, stream); + } + case Bin::CompressionType::VariableByteLite: + { + List compressedPayload; + SLANG_RETURN_ON_FAIL(_encodeInsts(compressionType, array, compressedPayload)); + + size_t payloadSize = sizeof(Bin::CompressedArrayHeader) - sizeof(Bin::Chunk) + compressedPayload.getCount(); + + Bin::CompressedArrayHeader header; + header.m_chunk.m_type = SLANG_MAKE_COMPRESSED_FOUR_CC(chunkId); + header.m_chunk.m_size = uint32_t(payloadSize); + header.m_numEntries = uint32_t(array.getCount()); + header.m_numCompressedEntries = 0; + + stream->Write(&header, sizeof(header)); + stream->Write(compressedPayload.begin(), compressedPayload.getCount()); + + // All chunks have sizes rounded to dword size + if (payloadSize & 3) + { + const uint8_t pad[4] = { 0, 0, 0, 0 }; + // Pad outs + int padSize = 4 - (payloadSize & 3); + stream->Write(pad, padSize); + } + return SLANG_OK; + } + default: break; + } + return SLANG_FAIL; +} + +static size_t _calcInstChunkSize(IRSerialBinary::CompressionType compressionType, const List& instsIn) +{ + typedef IRSerialBinary Bin; + typedef IRSerialData::Inst::PayloadType PayloadType; + + switch (compressionType) + { + case Bin::CompressionType::None: + { + return _calcChunkSize(compressionType, instsIn); + } + case Bin::CompressionType::VariableByteLite: + { + size_t size = sizeof(Bin::CompressedArrayHeader); + + size_t numInsts = size_t(instsIn.getCount()); + size += numInsts * 2; // op and payload + + IRSerialData::Inst* insts = instsIn.begin(); + + for (size_t i = 0; i < numInsts; ++i) + { + const auto& inst = insts[i]; + + size += ByteEncodeUtil::calcEncodeLiteSizeUInt32((uint32_t)inst.m_resultTypeIndex); + + switch (inst.m_payloadType) + { + case PayloadType::Empty: + { + break; + } + case PayloadType::Operand_1: + case PayloadType::String_1: + case PayloadType::UInt32: + { + // 1 UInt32 + size += ByteEncodeUtil::calcEncodeLiteSizeUInt32((uint32_t)inst.m_payload.m_operands[0]); + break; + } + case PayloadType::Operand_2: + case PayloadType::OperandAndUInt32: + case PayloadType::OperandExternal: + case PayloadType::String_2: + { + // 2 UInt32 + size += ByteEncodeUtil::calcEncodeLiteSizeUInt32((uint32_t)inst.m_payload.m_operands[0]); + size += ByteEncodeUtil::calcEncodeLiteSizeUInt32((uint32_t)inst.m_payload.m_operands[1]); + break; + } + case PayloadType::Float64: + { + size += sizeof(inst.m_payload.m_float64); + break; + } + case PayloadType::Int64: + { + size += sizeof(inst.m_payload.m_int64); + break; + } + } + } + + return (size + 3) & ~size_t(3); + } + default: break; + } + + SLANG_ASSERT(!"Unhandled compression type"); + return 0; +} + +/* static */Result IRSerialWriter::writeStream(const IRSerialData& data, Bin::CompressionType compressionType, Stream* stream) +{ + size_t totalSize = 0; + + totalSize += sizeof(Bin::SlangHeader) + + _calcInstChunkSize(compressionType, data.m_insts) + + _calcChunkSize(compressionType, data.m_childRuns) + + _calcChunkSize(compressionType, data.m_externalOperands) + + _calcChunkSize(Bin::CompressionType::None, data.m_stringTable) + + _calcChunkSize(Bin::CompressionType::None, data.m_rawSourceLocs); + + if (data.m_debugSourceInfos.getCount()) + { + totalSize += _calcChunkSize(Bin::CompressionType::None, data.m_debugStringTable) + + _calcChunkSize(Bin::CompressionType::None, data.m_debugLineInfos) + + _calcChunkSize(Bin::CompressionType::None, data.m_debugAdjustedLineInfos) + + _calcChunkSize(Bin::CompressionType::None, data.m_debugSourceInfos) + + _calcChunkSize(compressionType, data.m_debugSourceLocRuns); + } + + { + Bin::Chunk riffHeader; + riffHeader.m_type = Bin::kRiffFourCc; + riffHeader.m_size = uint32_t(totalSize); + + stream->Write(&riffHeader, sizeof(riffHeader)); + } + { + Bin::SlangHeader slangHeader; + slangHeader.m_chunk.m_type = Bin::kSlangFourCc; + slangHeader.m_chunk.m_size = uint32_t(sizeof(slangHeader) - sizeof(Bin::Chunk)); + slangHeader.m_compressionType = uint32_t(Bin::CompressionType::VariableByteLite); + + stream->Write(&slangHeader, sizeof(slangHeader)); + } + + SLANG_RETURN_ON_FAIL(_writeInstArrayChunk(compressionType, Bin::kInstFourCc, data.m_insts, stream)); + SLANG_RETURN_ON_FAIL(_writeArrayChunk(compressionType, Bin::kChildRunFourCc, data.m_childRuns, stream)); + SLANG_RETURN_ON_FAIL(_writeArrayChunk(compressionType, Bin::kExternalOperandsFourCc, data.m_externalOperands, stream)); + SLANG_RETURN_ON_FAIL(_writeArrayChunk(Bin::CompressionType::None, Bin::kStringFourCc, data.m_stringTable, stream)); + + SLANG_RETURN_ON_FAIL(_writeArrayChunk(Bin::CompressionType::None, Bin::kUInt32SourceLocFourCc, data.m_rawSourceLocs, stream)); + + if (data.m_debugSourceInfos.getCount()) + { + _writeArrayChunk(Bin::CompressionType::None, Bin::kDebugStringFourCc, data.m_debugStringTable, stream); + _writeArrayChunk(Bin::CompressionType::None, Bin::kDebugLineInfoFourCc, data.m_debugLineInfos, stream); + _writeArrayChunk(Bin::CompressionType::None, Bin::kDebugAdjustedLineInfoFourCc, data.m_debugAdjustedLineInfos, stream); + _writeArrayChunk(Bin::CompressionType::None, Bin::kDebugSourceInfoFourCc, data.m_debugSourceInfos, stream); + _writeArrayChunk(compressionType, Bin::kDebugSourceLocRunFourCc, data.m_debugSourceLocRuns, stream); + } + + return SLANG_OK; +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! IRSerialReader !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +class ListResizer +{ + public: + virtual void* setSize(size_t newSize) = 0; + SLANG_FORCE_INLINE size_t getTypeSize() const { return m_typeSize; } + ListResizer(size_t typeSize):m_typeSize(typeSize) {} + + protected: + size_t m_typeSize; +}; + +template +class ListResizerForType: public ListResizer +{ + public: + typedef ListResizer Parent; + + SLANG_FORCE_INLINE ListResizerForType(List& list): + Parent(sizeof(T)), + m_list(list) + {} + + virtual void* setSize(size_t newSize) SLANG_OVERRIDE + { + m_list.setCount(UInt(newSize)); + return (void*)m_list.begin(); + } + + protected: + List& m_list; +}; + +static Result _readArrayChunk(IRSerialBinary::CompressionType compressionType, const IRSerialBinary::Chunk& chunk, Stream* stream, size_t* numReadInOut, ListResizer& listOut) +{ + typedef IRSerialBinary Bin; + + const size_t typeSize = listOut.getTypeSize(); + + switch (compressionType) + { + case Bin::CompressionType::VariableByteLite: + { + // We have a compressed header + Bin::CompressedArrayHeader header; + header.m_chunk = chunk; + + stream->Read(&header.m_chunk + 1, sizeof(header) - sizeof(Bin::Chunk)); + *numReadInOut += sizeof(header) - sizeof(Bin::Chunk); + + void* data = listOut.setSize(header.m_numEntries); + + // Need to read all the compressed data... + size_t payloadSize = header.m_chunk.m_size - (sizeof(header) - sizeof(Bin::Chunk)); + + List compressedPayload; + compressedPayload.setCount(payloadSize); + + stream->Read(compressedPayload.begin(), payloadSize); + *numReadInOut += payloadSize; + + SLANG_ASSERT(header.m_numCompressedEntries == uint32_t((header.m_numEntries * typeSize) / sizeof(uint32_t))); + + // Decode.. + ByteEncodeUtil::decodeLiteUInt32(compressedPayload.begin(), header.m_numCompressedEntries, (uint32_t*)data); + break; + } + case Bin::CompressionType::None: + { + // Read uncompressed + Bin::ArrayHeader header; + header.m_chunk = chunk; + + stream->Read(&header.m_chunk + 1, sizeof(header) - sizeof(Bin::Chunk)); + *numReadInOut += sizeof(header) - sizeof(Bin::Chunk); + + const size_t payloadSize = header.m_numEntries * typeSize; + + void* data = listOut.setSize(header.m_numEntries); + + stream->Read(data, payloadSize); + *numReadInOut += payloadSize; + break; + } + } + + // All chunks have sizes rounded to dword size + if (*numReadInOut & 3) + { + const uint8_t pad[4] = { 0, 0, 0, 0 }; + // Pad outs + int padSize = 4 - int(*numReadInOut & 3); + stream->Seek(SeekOrigin::Current, padSize); + + *numReadInOut += padSize; + } + + return SLANG_OK; +} + +template +Result _readArrayChunk(const IRSerialBinary::SlangHeader& header, const IRSerialBinary::Chunk& chunk, Stream* stream, size_t* numReadInOut, List& arrayOut) +{ + typedef IRSerialBinary Bin; + + Bin::CompressionType compressionType = Bin::CompressionType::None; + + if (chunk.m_type == SLANG_MAKE_COMPRESSED_FOUR_CC(chunk.m_type)) + { + // If it has compression, use the compression type set in the header + compressionType = Bin::CompressionType(header.m_compressionType); + } + ListResizerForType resizer(arrayOut); + return _readArrayChunk(compressionType, chunk, stream, numReadInOut, resizer); +} + +template +Result _readArrayUncompressedChunk(const IRSerialBinary::SlangHeader& header, const IRSerialBinary::Chunk& chunk, Stream* stream, size_t* numReadInOut, List& arrayOut) +{ + typedef IRSerialBinary Bin; + SLANG_UNUSED(header); + ListResizerForType resizer(arrayOut); + return _readArrayChunk(Bin::CompressionType::None, chunk, stream, numReadInOut, resizer); +} + +static Result _decodeInsts(IRSerialBinary::CompressionType compressionType, const List& encodeIn, List& instsOut) +{ + typedef IRSerialBinary Bin; + typedef IRSerialData::Inst::PayloadType PayloadType; + + if (compressionType != Bin::CompressionType::VariableByteLite) + { + return SLANG_FAIL; + } + + const size_t numInsts = size_t(instsOut.getCount()); + IRSerialData::Inst* insts = instsOut.begin(); + + const uint8_t* encodeCur = encodeIn.begin(); + + for (size_t i = 0; i < numInsts; ++i) + { + auto& inst = insts[i]; + + inst.m_op = *encodeCur++; + const PayloadType payloadType = PayloadType(*encodeCur++); + inst.m_payloadType = payloadType; + + // Read the result value + encodeCur += ByteEncodeUtil::decodeLiteUInt32(encodeCur, (uint32_t*)&inst.m_resultTypeIndex); + + switch (inst.m_payloadType) + { + case PayloadType::Empty: + { + break; + } + case PayloadType::Operand_1: + case PayloadType::String_1: + case PayloadType::UInt32: + { + // 1 UInt32 + encodeCur += ByteEncodeUtil::decodeLiteUInt32(encodeCur, (uint32_t*)&inst.m_payload.m_operands[0]); + break; + } + case PayloadType::Operand_2: + case PayloadType::OperandAndUInt32: + case PayloadType::OperandExternal: + case PayloadType::String_2: + { + // 2 UInt32 + encodeCur += ByteEncodeUtil::decodeLiteUInt32(encodeCur, 2, (uint32_t*)&inst.m_payload.m_operands[0]); + break; + } + case PayloadType::Float64: + { + memcpy(&inst.m_payload.m_float64, encodeCur, sizeof(inst.m_payload.m_float64)); + encodeCur += sizeof(inst.m_payload.m_float64); + break; + } + case PayloadType::Int64: + { + memcpy(&inst.m_payload.m_int64, encodeCur, sizeof(inst.m_payload.m_int64)); + encodeCur += sizeof(inst.m_payload.m_int64); + break; + } + } + } + + return SLANG_OK; +} + +Result _readInstArrayChunk(const IRSerialBinary::SlangHeader& slangHeader, const IRSerialBinary::Chunk& chunk, Stream* stream, size_t* numReadInOut, List& arrayOut) +{ + typedef IRSerialBinary Bin; + + Bin::CompressionType compressionType = Bin::CompressionType::None; + if (chunk.m_type == SLANG_MAKE_COMPRESSED_FOUR_CC(chunk.m_type)) + { + compressionType = Bin::CompressionType(slangHeader.m_compressionType); + } + + switch (compressionType) + { + case Bin::CompressionType::None: + { + ListResizerForType resizer(arrayOut); + return _readArrayChunk(compressionType, chunk, stream, numReadInOut, resizer); + } + case Bin::CompressionType::VariableByteLite: + { + // We have a compressed header + Bin::CompressedArrayHeader header; + header.m_chunk = chunk; + + stream->Read(&header.m_chunk + 1, sizeof(header) - sizeof(Bin::Chunk)); + *numReadInOut += sizeof(header) - sizeof(Bin::Chunk); + + // Need to read all the compressed data... + size_t payloadSize = header.m_chunk.m_size - (sizeof(header) - sizeof(Bin::Chunk)); + + List compressedPayload; + compressedPayload.setCount(payloadSize); + + stream->Read(compressedPayload.begin(), payloadSize); + *numReadInOut += payloadSize; + + arrayOut.setCount(header.m_numEntries); + + SLANG_RETURN_ON_FAIL(_decodeInsts(compressionType, compressedPayload, arrayOut)); + break; + } + default: + { + return SLANG_FAIL; + } + } + + // All chunks have sizes rounded to dword size + if (*numReadInOut & 3) + { + // Pad outs + int padSize = 4 - int(*numReadInOut & 3); + stream->Seek(SeekOrigin::Current, padSize); + *numReadInOut += padSize; + } + + return SLANG_OK; +} + +int64_t _calcChunkTotalSize(const IRSerialBinary::Chunk& chunk) +{ + int64_t size = chunk.m_size + sizeof(IRSerialBinary::Chunk); + return (size + 3) & ~int64_t(3); +} + +/* static */Result IRSerialReader::_skip(const IRSerialBinary::Chunk& chunk, Stream* stream, int64_t* remainingBytesInOut) +{ + typedef IRSerialBinary Bin; + int64_t chunkSize = _calcChunkTotalSize(chunk); + if (remainingBytesInOut) + { + *remainingBytesInOut -= chunkSize; + } + + // Skip the payload (we don't need to skip the Chunk because that was already read + stream->Seek(SeekOrigin::Current, chunkSize - sizeof(IRSerialBinary::Chunk)); + return SLANG_OK; +} + +/* static */Result IRSerialReader::readStream(Stream* stream, IRSerialData* dataOut) +{ + typedef IRSerialBinary Bin; + + dataOut->clear(); + + int64_t remainingBytes = 0; + { + Bin::Chunk header; + stream->Read(&header, sizeof(header)); + if (header.m_type != Bin::kRiffFourCc) + { + return SLANG_FAIL; + } + + remainingBytes = header.m_size; + } + + // Header + // Chunk will not be kSlangFourCC if not read yet + Bin::SlangHeader slangHeader; + memset(&slangHeader, 0, sizeof(slangHeader)); + + while (remainingBytes > 0) + { + Bin::Chunk chunk; + + stream->Read(&chunk, sizeof(chunk)); + + size_t bytesRead = sizeof(chunk); + + switch (chunk.m_type) + { + case Bin::kSlangFourCc: + { + // Slang header + slangHeader.m_chunk = chunk; + + // NOTE! Really we should only read what we know the size to be... + // and skip if it's larger + + stream->Read(&slangHeader.m_chunk + 1, sizeof(slangHeader) - sizeof(chunk)); + + remainingBytes -= _calcChunkTotalSize(chunk); + break; + } + case SLANG_MAKE_COMPRESSED_FOUR_CC(Bin::kInstFourCc): + case Bin::kInstFourCc: + { + SLANG_RETURN_ON_FAIL(_readInstArrayChunk(slangHeader, chunk, stream, &bytesRead, dataOut->m_insts)); + remainingBytes -= _calcChunkTotalSize(chunk); + break; + } + case SLANG_MAKE_COMPRESSED_FOUR_CC(Bin::kChildRunFourCc): + case Bin::kChildRunFourCc: + { + SLANG_RETURN_ON_FAIL(_readArrayChunk(slangHeader, chunk, stream, &bytesRead, dataOut->m_childRuns)); + remainingBytes -= _calcChunkTotalSize(chunk); + break; + } + case SLANG_MAKE_COMPRESSED_FOUR_CC(Bin::kExternalOperandsFourCc): + case Bin::kExternalOperandsFourCc: + { + SLANG_RETURN_ON_FAIL(_readArrayChunk(slangHeader, chunk, stream, &bytesRead, dataOut->m_externalOperands)); + remainingBytes -= _calcChunkTotalSize(chunk); + break; + } + case Bin::kStringFourCc: + { + SLANG_RETURN_ON_FAIL(_readArrayUncompressedChunk(slangHeader, chunk, stream, &bytesRead, dataOut->m_stringTable)); + remainingBytes -= _calcChunkTotalSize(chunk); + break; + } + case Bin::kUInt32SourceLocFourCc: + { + SLANG_RETURN_ON_FAIL(_readArrayUncompressedChunk(slangHeader, chunk, stream, &bytesRead, dataOut->m_rawSourceLocs)); + remainingBytes -= _calcChunkTotalSize(chunk); + break; + } + case Bin::kDebugStringFourCc: + { + SLANG_RETURN_ON_FAIL(_readArrayUncompressedChunk(slangHeader, chunk, stream, &bytesRead, dataOut->m_debugStringTable)); + remainingBytes -= _calcChunkTotalSize(chunk); + break; + } + case Bin::kDebugLineInfoFourCc: + { + SLANG_RETURN_ON_FAIL(_readArrayUncompressedChunk(slangHeader, chunk, stream, &bytesRead, dataOut->m_debugLineInfos)); + remainingBytes -= _calcChunkTotalSize(chunk); + break; + } + case Bin::kDebugAdjustedLineInfoFourCc: + { + SLANG_RETURN_ON_FAIL(_readArrayUncompressedChunk(slangHeader, chunk, stream, &bytesRead, dataOut->m_debugAdjustedLineInfos)); + remainingBytes -= _calcChunkTotalSize(chunk); + break; + } + case Bin::kDebugSourceInfoFourCc: + { + SLANG_RETURN_ON_FAIL(_readArrayChunk(slangHeader, chunk, stream, &bytesRead, dataOut->m_debugSourceInfos)); + remainingBytes -= _calcChunkTotalSize(chunk); + break; + } + case SLANG_MAKE_COMPRESSED_FOUR_CC(Bin::kDebugSourceLocRunFourCc): + case Bin::kDebugSourceLocRunFourCc: + { + SLANG_RETURN_ON_FAIL(_readArrayChunk(slangHeader, chunk, stream, &bytesRead, dataOut->m_debugSourceLocRuns)); + remainingBytes -= _calcChunkTotalSize(chunk); + break; + } + + default: + { + SLANG_RETURN_ON_FAIL(_skip(chunk, stream, &remainingBytes)); + break; + } + } + } + + return SLANG_OK; +} + +static SourceRange _toSourceRange(const IRSerialData::DebugSourceInfo& info) +{ + SourceRange range; + range.begin = SourceLoc::fromRaw(info.m_startSourceLoc); + range.end = SourceLoc::fromRaw(info.m_endSourceLoc); + return range; +} + +static int _findIndex(const List& infos, SourceLoc sourceLoc) +{ + const int numInfos = int(infos.getCount()); + for (int i = 0; i < numInfos; ++i) + { + if (_toSourceRange(infos[i]).contains(sourceLoc)) + { + return i; + } + } + + return -1; +} + +static int _calcFixSourceLoc(const IRSerialData::DebugSourceInfo& info, SourceView* sourceView, SourceRange& rangeOut) +{ + rangeOut = _toSourceRange(info); + return int(sourceView->getRange().begin.getRaw()) - int(info.m_startSourceLoc); +} + +/* static */Result IRSerialReader::read(const IRSerialData& data, Session* session, SourceManager* sourceManager, RefPtr& moduleOut) +{ + typedef Ser::Inst::PayloadType PayloadType; + + m_serialData = &data; + + auto module = new IRModule(); + moduleOut = module; + m_module = module; + + module->session = session; + + // Set up the string rep cache + m_stringRepresentationCache.init(&data.m_stringTable, session->getNamePool(), module->getObjectScopeManager()); + + // Add all the instructions + + List insts; + + const Index numInsts = data.m_insts.getCount(); + + SLANG_ASSERT(numInsts > 0); + + insts.setCount(numInsts); + insts[0] = nullptr; + + // 0 holds null + // 1 holds the IRModuleInst + { + // Check that insts[1] is the module inst + const Ser::Inst& srcInst = data.m_insts[1]; + SLANG_RELEASE_ASSERT(srcInst.m_op == kIROp_Module); + SLANG_ASSERT(srcInst.m_payloadType == PayloadType::Empty); + + // Create the module inst + auto moduleInst = static_cast(createEmptyInstWithSize(module, kIROp_Module, sizeof(IRModuleInst))); + module->moduleInst = moduleInst; + moduleInst->module = module; + + // Set the IRModuleInst + insts[1] = moduleInst; + } + + for (Index i = 2; i < numInsts; ++i) + { + const Ser::Inst& srcInst = data.m_insts[i]; + + const IROp op((IROp)srcInst.m_op); + + if (isConstant(op)) + { + // Handling of constants + + // Calculate the minimum object size (ie not including the payload of value) + const size_t prefixSize = SLANG_OFFSET_OF(IRConstant, value); + + IRConstant* irConst = nullptr; + switch (op) + { + case kIROp_BoolLit: + { + SLANG_ASSERT(srcInst.m_payloadType == PayloadType::UInt32); + irConst = static_cast(createEmptyInstWithSize(module, op, prefixSize + sizeof(IRIntegerValue))); + irConst->value.intVal = srcInst.m_payload.m_uint32 != 0; + break; + } + case kIROp_IntLit: + { + SLANG_ASSERT(srcInst.m_payloadType == PayloadType::Int64); + irConst = static_cast(createEmptyInstWithSize(module, op, prefixSize + sizeof(IRIntegerValue))); + irConst->value.intVal = srcInst.m_payload.m_int64; + break; + } + case kIROp_PtrLit: + { + SLANG_ASSERT(srcInst.m_payloadType == PayloadType::Int64); + irConst = static_cast(createEmptyInstWithSize(module, op, prefixSize + sizeof(void*))); + irConst->value.ptrVal = (void*) (intptr_t) srcInst.m_payload.m_int64; + break; + } + case kIROp_FloatLit: + { + SLANG_ASSERT(srcInst.m_payloadType == PayloadType::Float64); + irConst = static_cast(createEmptyInstWithSize(module, op, prefixSize + sizeof(IRFloatingPointValue))); + irConst->value.floatVal = srcInst.m_payload.m_float64; + break; + } + case kIROp_StringLit: + { + SLANG_ASSERT(srcInst.m_payloadType == PayloadType::String_1); + + const UnownedStringSlice slice = m_stringRepresentationCache.getStringSlice(StringHandle(srcInst.m_payload.m_stringIndices[0])); + + const size_t sliceSize = slice.size(); + const size_t instSize = prefixSize + SLANG_OFFSET_OF(IRConstant::StringValue, chars) + sliceSize; + + irConst = static_cast(createEmptyInstWithSize(module, op, instSize)); + + IRConstant::StringValue& dstString = irConst->value.stringVal; + + dstString.numChars = uint32_t(sliceSize); + // Turn into pointer to avoid warning of array overrun + char* dstChars = dstString.chars; + // Copy the chars + memcpy(dstChars, slice.begin(), sliceSize); + break; + } + default: + { + SLANG_ASSERT(!"Unknown constant type"); + return SLANG_FAIL; + } + } + + insts[i] = irConst; + } + else if (isTextureTypeBase(op)) + { + IRTextureTypeBase* inst = static_cast(createEmptyInst(module, op, 1)); + SLANG_ASSERT(srcInst.m_payloadType == PayloadType::OperandAndUInt32); + + // Reintroduce the texture type bits into the the + const uint32_t other = srcInst.m_payload.m_operandAndUInt32.m_uint32; + inst->op = IROp(uint32_t(inst->op) | (other << kIROpMeta_OtherShift)); + + insts[i] = inst; + } + else + { + int numOperands = srcInst.getNumOperands(); + insts[i] = createEmptyInst(module, op, numOperands); + } + } + + // Patch up the operands + for (Index i = 1; i < numInsts; ++i) + { + const Ser::Inst& srcInst = data.m_insts[i]; + const IROp op((IROp)srcInst.m_op); + + IRInst* dstInst = insts[i]; + + // Set the result type + if (srcInst.m_resultTypeIndex != Ser::InstIndex(0)) + { + IRInst* resultInst = insts[int(srcInst.m_resultTypeIndex)]; + // NOTE! Counter intuitively the IRType* paramter may not be IRType* derived for example + // IRGlobalGenericParam is valid, but isn't IRType* derived + + //SLANG_RELEASE_ASSERT(as(resultInst)); + dstInst->setFullType(static_cast(resultInst)); + } + + //if (!isParentDerived(op)) + { + const Ser::InstIndex* srcOperandIndices; + const int numOperands = data.getOperands(srcInst, &srcOperandIndices); + + for (int j = 0; j < numOperands; j++) + { + dstInst->setOperand(j, insts[int(srcOperandIndices[j])]); + } + } + } + + // Patch up the children + { + const Index numChildRuns = data.m_childRuns.getCount(); + for (Index i = 0; i < numChildRuns; i++) + { + const auto& run = data.m_childRuns[i]; + + IRInst* inst = insts[int(run.m_parentIndex)]; + + for (int j = 0; j < int(run.m_numChildren); ++j) + { + IRInst* child = insts[j + int(run.m_startInstIndex)]; + SLANG_ASSERT(child->parent == nullptr); + child->insertAtEnd(inst); + } + } + } + + // Re-add source locations, if they are defined + if (m_serialData->m_rawSourceLocs.getCount() == numInsts) + { + const Ser::RawSourceLoc* srcLocs = m_serialData->m_rawSourceLocs.begin(); + for (Index i = 1; i < numInsts; ++i) + { + IRInst* dstInst = insts[i]; + + dstInst->sourceLoc.setRaw(Slang::SourceLoc::RawValue(srcLocs[i])); + } + } + + if (sourceManager && m_serialData->m_debugSourceInfos.getCount()) + { + List debugStringSlices; + SerialStringTableUtil::decodeStringTable(m_serialData->m_debugStringTable, debugStringSlices); + + // All of the strings are placed in the manager (and its StringSlicePool) where the SourceView and SourceFile are constructed from + List stringMap; + SerialStringTableUtil::calcStringSlicePoolMap(debugStringSlices, sourceManager->getStringSlicePool(), stringMap); + + const List& sourceInfos = m_serialData->m_debugSourceInfos; + + // Construct the source files + Index numSourceFiles = sourceInfos.getCount(); + + // These hold the views (and SourceFile as there is only one SourceFile per view) in the same order as the sourceInfos + List sourceViews; + sourceViews.setCount(numSourceFiles); + + for (Index i = 0; i < numSourceFiles; ++i) + { + const IRSerialData::DebugSourceInfo& srcSourceInfo = sourceInfos[i]; + + PathInfo pathInfo; + pathInfo.type = PathInfo::Type::FoundPath; + pathInfo.foundPath = debugStringSlices[UInt(srcSourceInfo.m_pathIndex)]; + + SourceFile* sourceFile = sourceManager->createSourceFileWithSize(pathInfo, srcSourceInfo.m_endSourceLoc - srcSourceInfo.m_startSourceLoc); + SourceView* sourceView = sourceManager->createSourceView(sourceFile, nullptr); + + // We need to accumulate all line numbers, for this source file, both adjusted and unadjusted + List lineInfos; + // Add the adjusted lines + { + lineInfos.setCount(srcSourceInfo.m_numAdjustedLineInfos); + const IRSerialData::DebugAdjustedLineInfo* srcAdjustedLineInfos = m_serialData->m_debugAdjustedLineInfos.getBuffer() + srcSourceInfo.m_adjustedLineInfosStartIndex; + const int numAdjustedLines = int(srcSourceInfo.m_numAdjustedLineInfos); + for (int j = 0; j < numAdjustedLines; ++j) + { + lineInfos[j] = srcAdjustedLineInfos[j].m_lineInfo; + } + } + // Add regular lines + lineInfos.addRange(m_serialData->m_debugLineInfos.getBuffer() + srcSourceInfo.m_lineInfosStartIndex, srcSourceInfo.m_numLineInfos); + // Put in sourceloc order + lineInfos.sort(); + + List lineBreakOffsets; + + // We can now set up the line breaks array + const int numLines = int(srcSourceInfo.m_numLines); + lineBreakOffsets.setCount(numLines); + + { + const Index numLineInfos = lineInfos.getCount(); + Index lineIndex = 0; + + // Every line up and including should hold the same offset + for (Index lineInfoIndex = 0; lineInfoIndex < numLineInfos; ++lineInfoIndex) + { + const auto& lineInfo = lineInfos[lineInfoIndex]; + + const uint32_t offset = lineInfo.m_lineStartOffset; + SLANG_ASSERT(offset > 0); + const int finishIndex = int(lineInfo.m_lineIndex); + + SLANG_ASSERT(finishIndex < numLines); + + for (; lineIndex < finishIndex; ++lineIndex) + { + lineBreakOffsets[lineIndex] = offset - 1; + } + lineBreakOffsets[lineIndex] = offset; + lineIndex++; + } + + // Do the remaining lines + const uint32_t offset = uint32_t(srcSourceInfo.m_endSourceLoc - srcSourceInfo.m_startSourceLoc); + for (; lineIndex < numLines; ++lineIndex) + { + lineBreakOffsets[lineIndex] = offset; + } + } + + sourceFile->setLineBreakOffsets(lineBreakOffsets.getBuffer(), lineBreakOffsets.getCount()); + + if (srcSourceInfo.m_numAdjustedLineInfos) + { + List adjustedLineInfos; + + int numEntries = int(srcSourceInfo.m_numAdjustedLineInfos); + + adjustedLineInfos.addRange(m_serialData->m_debugAdjustedLineInfos.getBuffer() + srcSourceInfo.m_adjustedLineInfosStartIndex, numEntries); + adjustedLineInfos.sort(); + + // Work out the views adjustments, and place in dstEntries + List dstEntries; + dstEntries.setCount(numEntries); + + const uint32_t sourceLocOffset = uint32_t(sourceView->getRange().begin.getRaw()); + + for (int j = 0; j < numEntries; ++j) + { + const auto& srcEntry = adjustedLineInfos[j]; + auto& dstEntry = dstEntries[j]; + + dstEntry.m_pathHandle = stringMap[int(srcEntry.m_pathStringIndex)]; + dstEntry.m_startLoc = SourceLoc::fromRaw(srcEntry.m_lineInfo.m_lineStartOffset + sourceLocOffset); + dstEntry.m_lineAdjust = int32_t(srcEntry.m_adjustedLineIndex) - int32_t(srcEntry.m_lineInfo.m_lineIndex); + } + + // Set the adjustments on the view + sourceView->setEntries(dstEntries.getBuffer(), dstEntries.getCount()); + } + + sourceViews[i] = sourceView; + } + + // We now need to apply the runs + { + List sourceRuns(m_serialData->m_debugSourceLocRuns); + // They are now in source location order + sourceRuns.sort(); + + // Just guess initially 0 for the source file that contains the initial run + SourceRange range; + int fixSourceLoc = _calcFixSourceLoc(sourceInfos[0], sourceViews[0], range); + + const Index numRuns = sourceRuns.getCount(); + for (Index i = 0; i < numRuns; ++i) + { + const auto& run = sourceRuns[i]; + const SourceLoc srcSourceLoc = SourceLoc::fromRaw(run.m_sourceLoc); + + if (!range.contains(srcSourceLoc)) + { + int index = _findIndex(sourceInfos, srcSourceLoc); + if (index < 0) + { + // Didn't find the match + continue; + } + fixSourceLoc = _calcFixSourceLoc(sourceInfos[index], sourceViews[index], range); + SLANG_ASSERT(range.contains(srcSourceLoc)); + } + + // Work out the fixed source location + SourceLoc sourceLoc = SourceLoc::fromRaw(int(run.m_sourceLoc) + fixSourceLoc); + + SLANG_ASSERT(Index(uint32_t(run.m_startInstIndex) + run.m_numInst) <= insts.getCount()); + IRInst** dstInsts = insts.getBuffer() + int(run.m_startInstIndex); + + const int runSize = int(run.m_numInst); + for (int j = 0; j < runSize; ++j) + { + dstInsts[j]->sourceLoc = sourceLoc; + } + } + } + } + + return SLANG_OK; +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!! IRSerialUtil !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +/* static */void IRSerialUtil::calcInstructionList(IRModule* module, List& instsOut) +{ + // We reserve 0 for null + instsOut.setCount(1); + instsOut[0] = nullptr; + + // Stack for parentInst + List parentInstStack; + + IRModuleInst* moduleInst = module->getModuleInst(); + parentInstStack.add(moduleInst); + + // Add to list + instsOut.add(moduleInst); + + // Traverse all of the instructions + while (parentInstStack.getCount()) + { + // If it's in the stack it is assumed it is already in the inst map + IRInst* parentInst = parentInstStack.getLast(); + parentInstStack.removeLast(); + + IRInstListBase childrenList = parentInst->getDecorationsAndChildren(); + for (IRInst* child : childrenList) + { + instsOut.add(child); + parentInstStack.add(child); + } + } +} + +/* static */SlangResult IRSerialUtil::verifySerialize(IRModule* module, Session* session, SourceManager* sourceManager, IRSerialBinary::CompressionType compressionType, IRSerialWriter::OptionFlags optionFlags) +{ + // Verify if we can stream out with debug information + + List originalInsts; + calcInstructionList(module, originalInsts); + + IRSerialData serialData; + { + // Write IR out to serialData - copying over SourceLoc information directly + IRSerialWriter writer; + SLANG_RETURN_ON_FAIL(writer.write(module, sourceManager, optionFlags, &serialData)); + } + + // Write the data out to stream + MemoryStream memoryStream(FileAccess::ReadWrite); + SLANG_RETURN_ON_FAIL(IRSerialWriter::writeStream(serialData, compressionType, &memoryStream)); + + // Reset stream + memoryStream.Seek(SeekOrigin::Start, 0); + + IRSerialData readData; + + SLANG_RETURN_ON_FAIL(IRSerialReader::readStream(&memoryStream, &readData)); + + // Check the stream read data is the same + if (readData != serialData) + { + SLANG_ASSERT(!"Streamed in data doesn't match"); + return SLANG_FAIL; + } + + RefPtr irReadModule; + + SourceManager workSourceManager; + workSourceManager.initialize(sourceManager, nullptr); + + { + IRSerialReader reader; + SLANG_RETURN_ON_FAIL(reader.read(serialData, session, &workSourceManager, irReadModule)); + } + + List readInsts; + calcInstructionList(irReadModule, readInsts); + + if (readInsts.getCount() != originalInsts.getCount()) + { + SLANG_ASSERT(!"Instruction counts don't match"); + return SLANG_FAIL; + } + + if (optionFlags & IRSerialWriter::OptionFlag::RawSourceLocation) + { + SLANG_ASSERT(readInsts[0] == originalInsts[0]); + // All the source locs should be identical + for (Index i = 1; i < readInsts.getCount(); ++i) + { + IRInst* origInst = originalInsts[i]; + IRInst* readInst = readInsts[i]; + + if (origInst->sourceLoc.getRaw() != readInst->sourceLoc.getRaw()) + { + SLANG_ASSERT(!"Source locs don't match"); + return SLANG_FAIL; + } + } + } + else if (optionFlags & IRSerialWriter::OptionFlag::DebugInfo) + { + // They should be on the same line nos + for (Index i = 1; i < readInsts.getCount(); ++i) + { + IRInst* origInst = originalInsts[i]; + IRInst* readInst = readInsts[i]; + + if (origInst->sourceLoc.getRaw() == readInst->sourceLoc.getRaw()) + { + continue; + } + + // Work out the + SourceView* origSourceView = sourceManager->findSourceView(origInst->sourceLoc); + SourceView* readSourceView = workSourceManager.findSourceView(readInst->sourceLoc); + + // if both are null we are done + if (origSourceView == nullptr && origSourceView == readSourceView) + { + continue; + } + SLANG_ASSERT(origSourceView && readSourceView); + + { + auto origInfo = origSourceView->getHumaneLoc(origInst->sourceLoc, SourceLocType::Actual); + auto readInfo = readSourceView->getHumaneLoc(readInst->sourceLoc, SourceLocType::Actual); + + if (!(origInfo.line == readInfo.line && origInfo.column == readInfo.column && origInfo.pathInfo.foundPath == readInfo.pathInfo.foundPath)) + { + SLANG_ASSERT(!"Debug data didn't match"); + return SLANG_FAIL; + } + } + + // We may have adjusted line numbers -> but they may not match, because we only reconstruct one view + // So for now disable this test + + if (false) + { + auto origInfo = origSourceView->getHumaneLoc(origInst->sourceLoc, SourceLocType::Nominal); + auto readInfo = readSourceView->getHumaneLoc(readInst->sourceLoc, SourceLocType::Nominal); + + if (!(origInfo.line == readInfo.line && origInfo.column == readInfo.column && origInfo.pathInfo.foundPath == readInfo.pathInfo.foundPath)) + { + SLANG_ASSERT(!"Debug data didn't match"); + return SLANG_FAIL; + } + } + } + } + + return SLANG_OK; +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!! Free functions !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +#if 0 + +Result serializeModule(IRModule* module, SourceManager* sourceManager, Stream* stream) +{ + IRSerialWriter serializer; + IRSerialData serialData; + + SLANG_RETURN_ON_FAIL(serializer.write(module, sourceManager, IRSerialWriter::OptionFlag::RawSourceLocation, &serialData)); + + if (stream) + { + SLANG_RETURN_ON_FAIL(IRSerialWriter::writeStream(serialData, IRSerialBinary::CompressionType::VariableByteLite, stream)); + } + + return SLANG_OK; +} + +Result readModule(Session* session, Stream* stream, RefPtr& moduleOut) +{ + IRSerialData serialData; + IRSerialReader::readStream(stream, &serialData); + + IRSerialReader reader; + return reader.read(serialData, session, moduleOut); +} + +#endif + +} // namespace Slang diff --git a/source/slang/slang-ir-serialize.h b/source/slang/slang-ir-serialize.h new file mode 100644 index 000000000..4f4301666 --- /dev/null +++ b/source/slang/slang-ir-serialize.h @@ -0,0 +1,549 @@ +// slang-ir-serialize.h +#ifndef SLANG_IR_SERIALIZE_H_INCLUDED +#define SLANG_IR_SERIALIZE_H_INCLUDED + +#include "../core/slang-basic.h" +#include "../core/slang-stream.h" + +#include "../core/slang-object-scope-manager.h" + +#include "slang-ir.h" + +// For TranslationUnitRequest +#include "slang-compiler.h" + +namespace Slang { + +class StringRepresentationCache +{ + public: + typedef StringSlicePool::Handle Handle; + + struct Entry + { + uint32_t m_startIndex; + uint32_t m_numChars; + RefObject* m_object; ///< Could be nullptr, Name, or StringRepresentation. + }; + + /// Get as a name + Name* getName(Handle handle); + /// Get as a string + String getString(Handle handle); + /// Get as string representation + StringRepresentation* getStringRepresentation(Handle handle); + /// Get as a string slice + UnownedStringSlice getStringSlice(Handle handle) const; + /// Get as a 0 terminated 'c style' string + char* getCStr(Handle handle); + + /// Initialize a cache to use a string table, namePool and scopeManager + void init(const List* stringTable, NamePool* namePool, ObjectScopeManager* scopeManager); + + /// Ctor + StringRepresentationCache(); + + protected: + ObjectScopeManager* m_scopeManager; + NamePool* m_namePool; + const List* m_stringTable; + List m_entries; +}; + +struct SerialStringTableUtil +{ + /// Convert a pool into a string table + static void encodeStringTable(const StringSlicePool& pool, List& stringTable); + static void encodeStringTable(const UnownedStringSlice* slices, size_t numSlices, List& stringTable); + /// Appends the decoded strings into slicesOut + static void appendDecodedStringTable(const List& stringTable, List& slicesOut); + /// Decodes a string table (and does so such that the indices are compatible with StringSlicePool) + static void decodeStringTable(const List& stringTable, List& slicesOut); + + /// Produces an index map, from slices to indices in pool + static void calcStringSlicePoolMap(const List& slices, StringSlicePool& pool, List& indexMap); +}; + +// Pre-declare +class Name; + +struct IRSerialData +{ + typedef IRSerialData ThisType; + + enum class InstIndex : uint32_t; + enum class StringIndex : uint32_t; + enum class ArrayIndex : uint32_t; + + enum class RawSourceLoc : SourceLoc::RawValue; ///< This is just to copy over source loc data (ie not strictly serialize) + enum class StringOffset : uint32_t; ///< Offset into the m_stringsBuffer + + typedef uint32_t SizeType; + + static const StringIndex kNullStringIndex = StringIndex(StringSlicePool::kNullHandle); + static const StringIndex kEmptyStringIndex = StringIndex(StringSlicePool::kEmptyHandle); + + /// A run of instructions + struct InstRun + { + typedef InstRun ThisType; + SLANG_FORCE_INLINE bool operator==(const ThisType& rhs) const + { + return m_parentIndex == rhs.m_parentIndex && + m_startInstIndex == rhs.m_startInstIndex && + m_numChildren == rhs.m_numChildren; + } + SLANG_FORCE_INLINE bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } + + InstIndex m_parentIndex; ///< The parent instruction + InstIndex m_startInstIndex; ///< The index to the first instruction + SizeType m_numChildren; ///< The number of children + }; + + struct SourceLocRun + { + typedef SourceLocRun ThisType; + + bool operator==(const ThisType& rhs) const { return m_sourceLoc == rhs.m_sourceLoc && m_startInstIndex == rhs.m_startInstIndex && m_numInst == rhs.m_numInst; } + bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } + bool operator<(const ThisType& rhs) const { return m_sourceLoc < rhs.m_sourceLoc; } + + uint32_t m_sourceLoc; ///< The source location + InstIndex m_startInstIndex; ///< The index to the first instruction + SizeType m_numInst; ///< The number of children + }; + + struct PayloadInfo + { + uint8_t m_numOperands; + uint8_t m_numStrings; + }; + + struct DebugSourceInfo + { + typedef DebugSourceInfo ThisType; + + bool operator==(const ThisType& rhs) const + { + return m_pathIndex == rhs.m_pathIndex && + m_startSourceLoc == rhs.m_startSourceLoc && + m_endSourceLoc == rhs.m_endSourceLoc && + m_numLineInfos == rhs.m_numLineInfos && + m_lineInfosStartIndex == rhs.m_lineInfosStartIndex && + m_numLineInfos == rhs.m_numLineInfos; + } + bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } + + bool isSourceLocInRange(uint32_t sourceLoc) const { return sourceLoc >= m_startSourceLoc && sourceLoc <= m_endSourceLoc; } + + StringIndex m_pathIndex; ///< Index to the string table + uint32_t m_startSourceLoc; ///< The offset to the source + uint32_t m_endSourceLoc; ///< The number of bytes in the source + + uint32_t m_numLines; ///< Total number of lines in source file + + uint32_t m_lineInfosStartIndex; ///< Index into m_debugLineInfos + uint32_t m_numLineInfos; ///< The number of line infos + + uint32_t m_adjustedLineInfosStartIndex; ///< Adjusted start index + uint32_t m_numAdjustedLineInfos; ///< The number of line infos + }; + + struct DebugLineInfo + { + typedef DebugLineInfo ThisType; + bool operator<(const ThisType& rhs) const { return m_lineStartOffset < rhs.m_lineStartOffset; } + bool operator==(const ThisType& rhs) const + { + return m_lineStartOffset == rhs.m_lineStartOffset && + m_lineIndex == rhs.m_lineIndex; + } + bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } + + uint32_t m_lineStartOffset; ///< The offset into the source file + uint32_t m_lineIndex; ///< Original line index + }; + + struct DebugAdjustedLineInfo + { + typedef DebugAdjustedLineInfo ThisType; + bool operator==(const ThisType& rhs) const + { + return m_lineInfo == rhs.m_lineInfo && + m_adjustedLineIndex == rhs.m_adjustedLineIndex && + m_pathStringIndex == rhs.m_pathStringIndex; + } + bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } + bool operator<(const ThisType& rhs) const { return m_lineInfo < rhs.m_lineInfo; } + + DebugLineInfo m_lineInfo; + uint32_t m_adjustedLineIndex; ///< The line index with the adjustment (if there is any). Is 0 if m_pathStringIndex is 0. + StringIndex m_pathStringIndex; ///< The path as an index + }; + + // Instruction... + // We can store SourceLoc values separately. Just store per index information. + // Parent information is stored in m_childRuns + // Decoration information is stored in m_decorationRuns + struct Inst + { + typedef Inst ThisType; + enum + { + kMaxOperands = 2, ///< Maximum number of operands that can be held in an instruction (otherwise held 'externally') + }; + + // NOTE! Can't change order or list without changing appropriate s_payloadInfos + enum class PayloadType : uint8_t + { + // First 3 must be in this order so a cast from 0-2 is directly represented as number of operands + Empty, ///< Has no payload (or operands) + Operand_1, ///< 1 Operand + Operand_2, ///< 2 Operands + + OperandAndUInt32, ///< 1 Operand and a single UInt32 + OperandExternal, ///< Operands are held externally + String_1, ///< 1 String + String_2, ///< 2 Strings + UInt32, ///< Holds an unsigned 32 bit integral (might represent a type) + Float64, + Int64, + + CountOf, + }; + + /// Get the number of operands + SLANG_FORCE_INLINE int getNumOperands() const; + + bool operator==(const ThisType& rhs) const; + + SLANG_FORCE_INLINE bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } + + uint8_t m_op; ///< For now one of IROp + PayloadType m_payloadType; ///< The type of payload + uint16_t m_pad0; ///< Not currently used + + InstIndex m_resultTypeIndex; //< 0 if has no type. The result type of this instruction + + struct ExternalOperandPayload + { + ArrayIndex m_arrayIndex; ///< Index into the m_externalOperands table + SizeType m_size; ///< The amount of entries in that table + }; + + struct OperandAndUInt32 + { + InstIndex m_operand; + uint32_t m_uint32; + }; + + union Payload + { + double m_float64; + int64_t m_int64; + uint32_t m_uint32; ///< Unsigned integral value + IRFloatingPointValue m_float; ///< Floating point value + IRIntegerValue m_int; ///< Integral value + InstIndex m_operands[kMaxOperands]; ///< For items that 2 or less operands it can use this. + StringIndex m_stringIndices[kMaxOperands]; + ExternalOperandPayload m_externalOperand; ///< Operands are stored in an an index of an operand array + OperandAndUInt32 m_operandAndUInt32; + }; + + Payload m_payload; + }; + + /// Clear to initial state + void clear(); + /// Get the operands of an instruction + SLANG_FORCE_INLINE int getOperands(const Inst& inst, const InstIndex** operandsOut) const; + + /// == + bool operator==(const ThisType& rhs) const; + SLANG_FORCE_INLINE bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } + + /// Calculate the amount of memory used by this IRSerialData + size_t calcSizeInBytes() const; + + /// Ctor + IRSerialData(); + + List m_insts; ///< The instructions + + List m_childRuns; ///< Holds the information about children that belong to an instruction + + List m_externalOperands; ///< Holds external operands (for instructions with more than kNumOperands) + + List m_stringTable; ///< All strings. Indexed into by StringIndex + + List m_rawSourceLocs; ///< A source location per instruction (saved without modification from IRInst)s + + // Data only set if we have debug information + + List m_debugStringTable; ///< String table for debug use only + List m_debugLineInfos; ///< Debug line information + List m_debugAdjustedLineInfos; ///< Adjusted line infos + List m_debugSourceInfos; ///< Debug source information + List m_debugSourceLocRuns; ///< Runs of instructions that use a source loc + + static const PayloadInfo s_payloadInfos[int(Inst::PayloadType::CountOf)]; +}; + +// -------------------------------------------------------------------------- +SLANG_FORCE_INLINE int IRSerialData::Inst::getNumOperands() const +{ + return (m_payloadType == PayloadType::OperandExternal) ? m_payload.m_externalOperand.m_size : s_payloadInfos[int(m_payloadType)].m_numOperands; +} + +// -------------------------------------------------------------------------- +SLANG_FORCE_INLINE bool IRSerialData::Inst::operator==(const ThisType& rhs) const +{ + if (m_op == rhs.m_op && + m_payloadType == rhs.m_payloadType && + m_resultTypeIndex == rhs.m_resultTypeIndex) + { + switch (m_payloadType) + { + case PayloadType::Empty: + { + return true; + } + case PayloadType::Operand_1: + case PayloadType::String_1: + case PayloadType::UInt32: + { + return m_payload.m_operands[0] == rhs.m_payload.m_operands[0]; + } + case PayloadType::OperandAndUInt32: + case PayloadType::OperandExternal: + case PayloadType::Operand_2: + case PayloadType::String_2: + { + return m_payload.m_operands[0] == rhs.m_payload.m_operands[0] && + m_payload.m_operands[1] == rhs.m_payload.m_operands[1]; + } + case PayloadType::Float64: + case PayloadType::Int64: + { + return m_payload.m_int64 == rhs.m_payload.m_int64; + } + default: break; + } + } + + return false; +} +// -------------------------------------------------------------------------- +SLANG_FORCE_INLINE int IRSerialData::getOperands(const Inst& inst, const InstIndex** operandsOut) const +{ + if (inst.m_payloadType == Inst::PayloadType::OperandExternal) + { + *operandsOut = m_externalOperands.begin() + int(inst.m_payload.m_externalOperand.m_arrayIndex); + return int(inst.m_payload.m_externalOperand.m_size); + } + else + { + *operandsOut = inst.m_payload.m_operands; + return s_payloadInfos[int(inst.m_payloadType)].m_numOperands; + } +} + + +#define SLANG_FOUR_CC(c0, c1, c2, c3) ((uint32_t(c0) << 0) | (uint32_t(c1) << 8) | (uint32_t(c2) << 16) | (uint32_t(c3) << 24)) + +#define SLANG_MAKE_COMPRESSED_FOUR_CC(fourCc) (((fourCc) & 0xffff00ff) | (uint32_t('c') << 8)) + +struct IRSerialBinary +{ + // http://fileformats.archiveteam.org/wiki/RIFF + // http://www.fileformat.info/format/riff/egff.htm + + struct Chunk + { + uint32_t m_type; + uint32_t m_size; + }; + + enum class CompressionType + { + None, + VariableByteLite, + }; + + + static const uint32_t kRiffFourCc = SLANG_FOUR_CC('R', 'I', 'F', 'F'); + static const uint32_t kSlangFourCc = SLANG_FOUR_CC('S', 'L', 'N', 'G'); ///< Holds all the slang specific chunks + + static const uint32_t kInstFourCc = SLANG_FOUR_CC('S', 'L', 'i', 'n'); + static const uint32_t kChildRunFourCc = SLANG_FOUR_CC('S', 'L', 'c', 'r'); + static const uint32_t kExternalOperandsFourCc = SLANG_FOUR_CC('S', 'L', 'e', 'o'); + + static const uint32_t kCompressedInstFourCc = SLANG_MAKE_COMPRESSED_FOUR_CC(kInstFourCc); + static const uint32_t kCompressedChildRunFourCc = SLANG_MAKE_COMPRESSED_FOUR_CC(kChildRunFourCc); + static const uint32_t kCompressedExternalOperandsFourCc = SLANG_MAKE_COMPRESSED_FOUR_CC(kExternalOperandsFourCc); + + static const uint32_t kStringFourCc = SLANG_FOUR_CC('S', 'L', 's', 't'); + + static const uint32_t kUInt32SourceLocFourCc = SLANG_FOUR_CC('S', 'r', 's', '4'); + + static const uint32_t kDebugStringFourCc = SLANG_FOUR_CC('S', 'd', 's', 't'); + static const uint32_t kDebugLineInfoFourCc = SLANG_FOUR_CC('S', 'd', 'l', 'n'); + static const uint32_t kDebugAdjustedLineInfoFourCc = SLANG_FOUR_CC('S', 'd', 'a', 'l'); + static const uint32_t kDebugSourceInfoFourCc = SLANG_FOUR_CC('S', 'd', 's', 'o'); + static const uint32_t kDebugSourceLocRunFourCc = SLANG_FOUR_CC('S', 'd', 's', 'r'); + + struct SlangHeader + { + Chunk m_chunk; + uint32_t m_compressionType; ///< Holds the compression type used (if used at all) + }; + struct ArrayHeader + { + Chunk m_chunk; + uint32_t m_numEntries; + }; + struct CompressedArrayHeader + { + Chunk m_chunk; + uint32_t m_numEntries; ///< The number of entries + uint32_t m_numCompressedEntries; ///< The amount of compressed entries + }; +}; + + +struct IRSerialWriter +{ + typedef IRSerialData Ser; + typedef IRSerialBinary Bin; + + struct OptionFlag + { + typedef uint32_t Type; + enum Enum: Type + { + RawSourceLocation = 0x01, + DebugInfo = 0x02, + }; + }; + typedef OptionFlag::Type OptionFlags; + + Result write(IRModule* module, SourceManager* sourceManager, OptionFlags options, IRSerialData* serialData); + + static Result writeStream(const IRSerialData& data, Bin::CompressionType compressionType, Stream* stream); + + /// Get an instruction index from an instruction + Ser::InstIndex getInstIndex(IRInst* inst) const { return inst ? Ser::InstIndex(m_instMap[inst]) : Ser::InstIndex(0); } + + /// Get a slice from an index + UnownedStringSlice getStringSlice(Ser::StringIndex index) const { return m_stringSlicePool.getSlice(StringSlicePool::Handle(index)); } + /// Get index from string representations + Ser::StringIndex getStringIndex(StringRepresentation* string) { return Ser::StringIndex(m_stringSlicePool.add(string)); } + Ser::StringIndex getStringIndex(const UnownedStringSlice& slice) { return Ser::StringIndex(m_stringSlicePool.add(slice)); } + Ser::StringIndex getStringIndex(Name* name) { return name ? getStringIndex(name->text) : Ser::kNullStringIndex; } + Ser::StringIndex getStringIndex(const char* chars) { return Ser::StringIndex(m_stringSlicePool.add(chars)); } + Ser::StringIndex getStringIndex(const String& string) { return Ser::StringIndex(m_stringSlicePool.add(string.getUnownedSlice())); } + + StringSlicePool& getStringPool() { return m_stringSlicePool; } + StringSlicePool& getDebugStringPool() { return m_debugStringSlicePool; } + + IRSerialWriter() : + m_serialData(nullptr) + {} + +protected: + class DebugSourceFile : public RefObject + { + public: + DebugSourceFile(SourceFile* sourceFile, SourceLoc::RawValue baseSourceLoc): + m_sourceFile(sourceFile), + m_baseSourceLoc(baseSourceLoc) + { + // Need to know how many lines there are + const List& lineOffsets = sourceFile->getLineBreakOffsets(); + + const auto numLineIndices = lineOffsets.getCount(); + + // Set none as being used initially + m_lineIndexUsed.setCount(numLineIndices); + ::memset(m_lineIndexUsed.begin(), 0, numLineIndices * sizeof(uint8_t)); + } + /// True if we have information on that line index + bool hasLineIndex(int lineIndex) const { return m_lineIndexUsed[lineIndex] != 0; } + void setHasLineIndex(int lineIndex) { m_lineIndexUsed[lineIndex] = 1; } + + SourceLoc::RawValue m_baseSourceLoc; ///< The base source location + + SourceFile* m_sourceFile; ///< The source file + List m_lineIndexUsed; ///< Has 1 if the line is used + List m_usedLineIndices; ///< Holds the lines that have been hit + + List m_lineInfos; ///< The line infos + List m_adjustedLineInfos; ///< The adjusted line infos + }; + + void _addInstruction(IRInst* inst); + Result _calcDebugInfo(); + /// Returns the remapped sourceLoc, or 0 if sourceLoc couldn't be added + void _addDebugSourceLocRun(SourceLoc sourceLoc, uint32_t startInstIndex, uint32_t numInst); + + List m_insts; ///< Instructions in same order as stored in the + + List m_decorations; ///< Holds all decorations in order of the instructions as found + List m_instWithFirstDecoration; ///< All decorations are held in this order after all the regular instructions + + Dictionary m_instMap; ///< Map an instruction to an instruction index + + StringSlicePool m_stringSlicePool; + IRSerialData* m_serialData; ///< Where the data is stored + + StringSlicePool m_debugStringSlicePool; ///< Slices held just for debug usage + + SourceLoc::RawValue m_debugFreeSourceLoc; /// Locations greater than this are free + Dictionary > m_debugSourceFileMap; + + SourceManager* m_sourceManager; ///< The source manager +}; + +struct IRSerialReader +{ + typedef IRSerialData Ser; + typedef StringRepresentationCache::Handle StringHandle; + + /// Read a stream to fill in dataOut IRSerialData + static Result readStream(Stream* stream, IRSerialData* dataOut); + + /// Read a module from serial data + Result read(const IRSerialData& data, Session* session, SourceManager* sourceManager, RefPtr& moduleOut); + + /// Get the representation cache + StringRepresentationCache& getStringRepresentationCache() { return m_stringRepresentationCache; } + + IRSerialReader(): + m_serialData(nullptr), + m_module(nullptr) + { + } + + protected: + + static Result _skip(const IRSerialBinary::Chunk& chunk, Stream* stream, int64_t* remainingBytesInOut); + + StringRepresentationCache m_stringRepresentationCache; + + const IRSerialData* m_serialData; + IRModule* m_module; +}; + +struct IRSerialUtil +{ + /// Produces an instruction list which is in same order as written through IRSerialWriter + static void calcInstructionList(IRModule* module, List& instsOut); + + /// Verify serialization + static SlangResult verifySerialize(IRModule* module, Session* session, SourceManager* sourceManager, IRSerialBinary::CompressionType compressionType, IRSerialWriter::OptionFlags optionFlags); +}; + + +} // namespace Slang + +#endif diff --git a/source/slang/slang-ir-specialize-resources.cpp b/source/slang/slang-ir-specialize-resources.cpp new file mode 100644 index 000000000..f72ca6b38 --- /dev/null +++ b/source/slang/slang-ir-specialize-resources.cpp @@ -0,0 +1,865 @@ +// slang-ir-specialize-resources.cpp +#include "slang-ir-specialize-resources.h" + +#include "slang-ir.h" +#include "slang-ir-clone.h" +#include "slang-ir-insts.h" + +namespace Slang +{ + +struct ResourceParameterSpecializationContext +{ + // This type implements a pass to specialize functions + // with resource parameters to ensure that they are + // legal for a given target. + // + // We start with member variables to stand in for + // the parameters that were passed to the top-level + // `specializeResourceParameters` function. + // + BackEndCompileRequest* compileRequest; + TargetRequest* targetRequest; + IRModule* module; + + // Our general approach will be to think in terms + // of specializing call sites, which amount to + // `IRCall` instructions. We will keep a work list + // of call sites in the program that may be worth + // considering for specialization. + // + List workList; + + // Because we may need to generate specialized functions + // and generate new calls to those functions, we'll + // need some IR building state to get our work done. + // + SharedIRBuilder sharedBuilderStorage; + IRBuilder builderStorage; + IRBuilder* getBuilder() { return &builderStorage; } + + // With the basic state out of the way, let's walk + // through the overall flow of the pass. + // + void processModule() + { + // We will start by initializing our IR building state. + // + sharedBuilderStorage.module = module; + sharedBuilderStorage.session = module->getSession(); + builderStorage.sharedBuilder = &sharedBuilderStorage; + + // Next we will populate our initial work list by + // recursively finding every single call site in the module. + // + addCallsToWorkListRec(module->getModuleInst()); + + // We will process the work list until it goes dry, + // treating it like a stack of work items. + // + while( workList.getCount() ) + { + auto call = workList.getLast(); + workList.removeLast(); + + // At each call site we first check whether it + // is something we can (and should) specialize, + // and if so, do it. The process of specializing + // a function may introduce new call sites that + // become candidates for specialization, so + // our work list may grow along the way. + // + if( canSpecializeCall(call) ) + { + specializeCall(call); + } + } + } + + // Setting up the work list is a simple recursive procedure. + // + void addCallsToWorkListRec(IRInst* inst) + { + // If we have a call site, then add it to the list. + // + if( auto call = as(inst) ) + { + workList.add(call); + } + + // Recursively walk through any children, to + // see if we uncover more call sites. + // + for( auto child : inst->getChildren() ) + { + addCallsToWorkListRec(child); + } + } + + // We need a way to decide for a given call site + // whether we can/must specialize it. + // + bool canSpecializeCall(IRCall* call) + { + // We can only specialize calls where the callee + // func can be statically identified, and where + // the callee is a definition (with body) rather + // than a declaration. Otherwise there is no + // way to generate a specialized callee function. + // + auto func = as(call->getCallee()); + if(!func) + return false; + if(!func->isDefinition()) + return false; + + // With the basic checks out of the way, there are + // two conditions we care about: + // + // 1. Should we specialize? This amounts to whether + // `func` has any parameters that need specialization. + // We will call those "specializable" parameters for + // lack of a better name. + // + // 2. Can we specialize? This amounts to whether the + // arguments in `call` that correspond to those + // specializable parameters are "suitable" for use + // in specialization. + // + // We are going to answer both of these queries in + // a single loop that walks over the parameters of + // `func` as well as the arguments to `call`. + // + // The loop may seem a bit awkward because we are + // doing a parallel iteration over a linked list + // (the parameters of `func`) and an array (the + // arguments of `call`). + // + bool anySpecializableParam = false; + UInt argCounter = 0; + for( auto param : func->getParams() ) + { + UInt argIndex = argCounter++; + SLANG_ASSERT(argIndex < call->getArgCount()); + auto arg = call->getArg(argIndex); + + // If the given parameter doesn't need specialization, + // then we need to keep looking. + // + if(!doesParamNeedSpecialization(param)) + continue; + + // If we have run into a `param` that needs specialization, + // then our first condition is met. + // + anySpecializableParam = true; + + // Now we need to check whether `arg` is actually suitable + // for specialization (our second condition). If not, we + // can bail out immediately because our second condition + // cannot be met. + // + if(!isArgSuitableForSpecialization(arg)) + return false; + } + + // If we exit the loop, then the second condition must have + // been met (all the arguments for specializable parameters + // were suitable for specialization), and the result of the + // query comes down to the first condition. + // + return anySpecializableParam; + } + + // Of course, now we need to back-fill the predicates that + // the above function used to evaluate prameters and arguments. + + bool doesParamNeedSpecialization(IRParam* param) + { + // Whether or not a parameter needs specialization is really + // a function of its type: + // + IRType* type = param->getDataType(); + + // What's more, if a parameter of type `T` would need + // specialization, then it seems clear that a parameter + // of type "array of `T`" would also need specialization. + // We will "unwrap" any outer arrays from the parameter + // type before moving on, since they won't affect + // our decision. + // + type = unwrapArray(type); + + // On all of our (current) targets, a function that + // takes a `ConstantBuffer` parameter requires + // specialization. Surprisingly this includes DXIL + // because dxc apparently does not treat `ConstantBuffer` + // as a first-class type. + // + if(as(type)) + return true; + + // For GL/Vulkan targets, we also need to specialize + // any parameters that use structured or byte-addressed + // buffers. + // + if( isKhronosTarget(targetRequest) ) + { + if(as(type)) + return true; + if(as(type)) + return true; + } + + // For now, we will not treat any other parameters as + // needing specialization, even if they use resource + // types like `Texure2D`, because these are allowed + // as function parameters in both HLSL and GLSL. + // + // TODO: Eventually, if we start generating SPIR-V + // directly rather than through glslang, we will need + // to specialize *all* resource-type parameters + // to follow the restrictions in the spec. + // + // TODO: We may want to perform more aggressive + // specialization in general, especially insofar + // as it could simplify the task of supporting + // functions with resource-type outputs. + + return false; + } + + bool isArgSuitableForSpecialization(IRInst* inArg) + { + // Determining if an argument is suitable for + // specializing a callee function requires + // looking at its (recurisve) structure. + // + // Rather than write a recursively procedure + // here, we will be tail-recursive by using + // a simple loop. + // + IRInst* arg = inArg; + for(;;) + { + // The leaf case we care about is when the + // argument at the call site is a global + // shader parameter, because then we can + // specialize a callee to refer to the same + // global parameter directly. + // + if(as(arg)) return true; + + // As we will see later, we can also + // specialize a call when the argument + // is the result of indexing into an + // array (`base[index]`) *if* the `base` + // of the indexing operation is also + // suitable for specialization. + // + if( arg->op == kIROp_getElement ) + { + auto base = arg->getOperand(0); + + // We will "recurse" on the base of + // the indexing operation by continuing + // our loop with the `base` as our new + // argument. + // + arg = base; + continue; + } + + // By default, we will *not* consider an argument + // suitable for specialization. + // + // TODO: There may be other cases that are worth + // handling here. The current code is based on + // observation of what simple shaders do in + // practice. + // + return false; + } + } + + // Once we'e determined that a given call site can/should + // be specialized, we need to perform the actual specialization. + // This is where things are going to get more involved. + // + // There are a few different concerns we need to deal with + // that mean we end up having two different passes that walk + // over the parameters/arguments of the call (in addition to + // the ones we had above for determining if we can/should + // specialize in the first place). + // + // The first of the two passes determines information + // relevant to the call site, comprising both the arguments + // that will be passed to the specialized function as + // well as a "key" to identify the specialized function + // that is required. + // + // We will use the key type defined as part of the IR cloning + // infrastructure, which uses a sequence of `IRInst*`s + // to hold the state of the key: + // + typedef IRSimpleSpecializationKey Key; + + // As indicated above, the information we collect about a call + // site consists of the key for the specialized function we + // will call, and a list of the arguments that will be passed + // to the call. + // + struct CallSpecializationInfo + { + Key key; + List newArgs; + }; + + // Once we've collected the information about a call site + // we can use a dictionary to see if we already created + // a specialized version of the callee that matches its + // requirements. + // + Dictionary specializedFuncs; + + // If the dictionary didn't have a specialized function + // suitable for a call site, we need a second information-gathering + // pass to decide what the new parameters of the specialized + // functions should be, and what instructions the new function + // must execute in its body to set up the replacements for the + // old parameters. + // + struct FuncSpecializationInfo + { + List newParams; + List newBodyInsts; + List replacementsForOldParameters; + }; + + // Before diving into how the different passes collect + // their information, we will dive into the main + // specialization logic first. + // + void specializeCall(IRCall* oldCall) + { + // We have an existing call site `oldCall` that + // we know can and should be specialized. + // + // That means the callee should be a known function + // definition, or else `canSpecializeCall` didn't + // correctly check the preconditions. + // + auto oldFunc = as(oldCall->getCallee()); + SLANG_ASSERT(oldFunc); + SLANG_ASSERT(oldFunc->isDefinition()); + + // Our first information-gathering pass will + // compute the key for the specialized function + // we want to call, and the arguments we will + // use for that call. + // + CallSpecializationInfo callInfo; + gatherCallInfo(oldCall, oldFunc, callInfo); + + // Once we have gathered information on the call, + // we can check if we have an existing specialization + // that we generated before (for another call site) + // that is suitable to this call site. + // + IRFunc* newFunc = nullptr; + if( !specializedFuncs.TryGetValue(callInfo.key, newFunc) ) + { + // If we didn't find a pre-existing specialized + // function, then we will go ahead and create one. + // + // We start by gathering the information from the call + // site that is relevant to generating a specialized + // callee function, which we avoided doing earlier + // because it might have been throwaway work. + // + FuncSpecializationInfo funcInfo; + gatherFuncInfo(oldCall, oldFunc, funcInfo); + + // Now we use the gathered information to generate + // a new callee function based on the original + // function and the information we gathered. + // + newFunc = generateSpecializedFunc(oldFunc, funcInfo); + specializedFuncs.Add(callInfo.key, newFunc); + } + + // Once we've other found or generated a specialized function + // we need to generate a call to it, and then use the new + // call as a replacement for the old one. + // + auto newCall = getBuilder()->emitCallInst( + oldCall->getFullType(), + newFunc, + callInfo.newArgs.getCount(), + callInfo.newArgs.getBuffer()); + + newCall->insertBefore(oldCall); + oldCall->replaceUsesWith(newCall); + oldCall->removeAndDeallocate(); + } + + // Before diving into the details on how we gather information + // and specialize callees, lets stop to think about what we'd + // like to do in terms of individual parameters and arguments. + // + // Suppose we are specializing both a call site C and the callee + // function F, and we are consisering a particular pair of + // a parmeter P of F, and an argument A at the call site. + // + // The full extent of information we might want to know given + // P and A is: + // + // * What arguments need to be added to the specialized call? + // * What parameters need to be added to the specialized callee? + // * What instructions are needed in the body of the specialized + // callee to synthesize the value that will stand in for P? + // * What information, if any, needs to be used to distinguish + // this specialized callee from others that might be generated for F? + // + // An easy case is when P is a parameter that doesn't need + // specialization. In that case: + // + // * The existing argument A should be used as an argument in + // the specialized call. + // * A clone P' of the existing parameter P should be used as a + // parameter of the specialized callee. + // * No additional instructions are needed in the body of + // the callee; the cloned parameter P' should stand in for P. + // * No information should be added to the specialization key + // based on P and A. + // + // The more interesting case is when P has a resource type, and + // A is some global shader parameter G. + // + // * No argument should be added at the new call site + // * No parameter should be added to the specialized callee + // * No additional instructions are needed in the body of + // the callee; the global G should stand in for P. + // * The global G should be used to distinguish this specialized + // callee from those that might be specialized for a different + // global shader parameter. + // + // As a final example, imagine that P is still a resource type, + // but A is now an indexing operation into an array: `G[idx]`: + // + // * An argument for `idx` should be added at the call site + // * A parameter `p_idx` with the same type as `idx` should be added + // to the specialized callee. + // * An instruction should be added to the specialized callee + // to compute `G[p_idx]` and use that to stand in for P. + // * The global G should still be used to distinguish this specialized + // call site from others. + // + // That's a lot of examples, I know, but hopefully it gives a + // sense of the information we are tracking and how it differs + // across the various cases. While the example only covered one + // level of indexing, the actual implementation will handle the + // case of arbitrarily many levels of indexing, which can mean + // piping through any number of additional integer parameters + // to the callee. + + // The information we gather for a call site (before we know + // whether a specialize calle is needed) is just the new + // argument list, and the "key" information that distinguishes + // what specialized callee we want/need. + // + void gatherCallInfo( + IRCall* oldCall, + IRFunc* oldFunc, + CallSpecializationInfo& callInfo) + { + // The specialized callee key always needs to include + // the original function, since different functions + // will always yield different specializations. + // + callInfo.key.vals.add(oldFunc); + + // The rest of the information is gathered by looking + // at parameter and argument pairs. + // + UInt oldArgCounter = 0; + for( auto oldParam : oldFunc->getParams() ) + { + UInt oldArgIndex = oldArgCounter++; + auto oldArg = oldCall->getArg(oldArgIndex); + + getCallInfoForParam(callInfo, oldParam, oldArg); + } + } + + void getCallInfoForParam( + CallSpecializationInfo& ioInfo, + IRParam* oldParam, + IRInst* oldArg) + { + // We know that the case where a parameter + // doesn't need specialization is easy. + // + if( !doesParamNeedSpecialization(oldParam) ) + { + // The new call site will use the same argument + // value as the old one, and we don't need + // to add any information to distinguish the + // specialized callee based on this paramter. + // + ioInfo.newArgs.add(oldArg); + } + else + { + // If specialization is needed, we need + // to inspect the argument value. This + // is handled with a different function + // because it needs to recurse in some cases. + // + getCallInfoForArg(ioInfo, oldArg); + } + } + + void getCallInfoForArg( + CallSpecializationInfo& ioInfo, + IRInst* oldArg) + { + // The base case we care about is when the original + // argument is a global shader parameter. + // + if( auto oldGlobalParam = as(oldArg) ) + { + // In this case we don't need to pass anything + // as an argument at the new call site (the + // global parameter will get specialized into + // the callee), but we *do* need to make sure + // that our key for identifying the specialized + // callee reflects that we are specializing + // to the chosen parameter. + // + ioInfo.key.vals.add(oldGlobalParam); + } + else if( oldArg->op == kIROp_getElement ) + { + // This is the case where the `oldArg` is + // in the form `oldBase[oldIndex]` + // + auto oldBase = oldArg->getOperand(0); + auto oldIndex = oldArg->getOperand(1); + + // Effectively, we act as if `oldBase` and + // `oldIndex` were passed to the callee separately, + // so that `oldBase` is an array-of-resouces and + // `oldIndex` is an ordinary integer argument. + // + // We start by recursively setting up whatever + // `oldBase` needs: + // + getCallInfoForArg(ioInfo, oldBase); + + // Then we process `oldIndex` just like we + // would have an ordinary argument that doesn't + // involve specialization: add its value to + // the arguments at the new call site, and + // don't add anything to the specialization key. + // + ioInfo.newArgs.add(oldIndex); + } + else + { + // If we fail to match any of the cases above + // then a precondition was violated in that + // `isArgSuitableForSpecialization` is allowing + // a case that this routine is not covering. + // + SLANG_UNEXPECTED("mising case in 'getCallInfoForArg'"); + } + } + + // The remaining information we've discussed is only + // gathered once we decide we want to generate a + // specialized function, but it follows much the same flow. + // + void gatherFuncInfo( + IRCall* oldCall, + IRFunc* oldFunc, + FuncSpecializationInfo& funcInfo) + { + UInt oldArgCounter = 0; + for( auto oldParam : oldFunc->getParams() ) + { + UInt oldArgIndex = oldArgCounter++; + auto oldArg = oldCall->getArg(oldArgIndex); + + // For each parameter and argument pair we will + // frame the main task as producing a value that + // will stand in for the parameter in the specialized + // function. + // + auto newVal = getSpecializedValueForParam(funcInfo, oldParam, oldArg); + + // We will collect the replacement value to use + // for each of the original parameters in an array. + // + funcInfo.replacementsForOldParameters.add(newVal); + } + } + + IRInst* getSpecializedValueForParam( + FuncSpecializationInfo& ioInfo, + IRParam* oldParam, + IRInst* oldArg) + { + // As always, the easy case is when the parameter of + // the original function doesn't need specialization. + // + if( !doesParamNeedSpecialization(oldParam) ) + { + // The specialized callee will need a new parameter + // that fills the same role as the old one, so we + // create it here. + // + auto newParam = getBuilder()->createParam(oldParam->getFullType()); + ioInfo.newParams.add(newParam); + + // The new parameter will be used as the replacement + // for the old one in the specialized function. + // + return newParam; + } + else + { + // If the parameter requires specialization, then it + // is time to look at the structure of the argument. + // + return getSpecializedValueForArg(ioInfo, oldArg); + } + } + + IRInst* getSpecializedValueForArg( + FuncSpecializationInfo& ioInfo, + IRInst* oldArg) + { + // The logic here parallels `gatherCallInfoForArg`, + // and only differs in what information it is gathering. + // + // As before, the base case is when we have a global + // shader parameter. + // + if( auto globalParam = as(oldArg) ) + { + // The specialized function will not need any + // parameter in this case, and the global itself + // should be used to stand in for the original + // parameter in the specialized function. + // + return globalParam; + } + else if( oldArg->op == kIROp_getElement ) + { + // This is the case where the argument is + // in the form `oldBase[oldIndex]`. + // + auto oldBase = oldArg->getOperand(0); + auto oldIndex = oldArg->getOperand(1); + + // In `gatherCallInfoForArg` this case was + // handled by acting as if `oldBase` and + // `oldIndex` were being passed as two + // separate arguments. + // + // We'll follow the same structure here, + // starting by recursively processing `oldBase` + // to get a value that can stand in for it + // in the specialized callee. + // + auto newBase = getSpecializedValueForArg(ioInfo, oldBase); + + // Next we'll process `oldIndex` as if it + // was an ordinary argument (not a specialized one), + // which means creating a parameter to receive its value, + // which will also stand in for `oldIndex` in + // the body of the specialized callee. + // + auto builder = getBuilder(); + auto newIndex = builder->createParam(oldIndex->getFullType()); + ioInfo.newParams.add(newIndex); + + // Finally, we need to compute a value that + // can stand in for `oldArg` (which was + // `oldBase[oldIndex]`) in the body of the + // specialized callee. + // + // Because we have both a `newBase` and a + // `newIndex` it is natural to construct + // `newBase[newIndex]` and use that. + // + // The only complication is that we need + // to make sure that our IR builder isn't + // set to insert newly created instructions + // anywhere, since the `emit*` functions + // will try to automatically insert new + // instructions if an insertion location + // is set. + // + builder->setInsertInto(nullptr); + auto newVal = builder->emitElementExtract( + oldArg->getFullType(), + newBase, + newIndex); + + // Because our new instruction wasn't + // actually inserted anywhere, we need to + // add it to our gathered list of instructions + // that should be inserted into the body of + // the specialized callee. + // + ioInfo.newBodyInsts.add(newVal); + + return newVal; + } + else + { + // If we don't match one of the above cases, + // then `isArgSuitableForSpecialization` is + // letting through cases that this function + // hasn't been updated to handle. + // + SLANG_UNEXPECTED("mising case in 'getSpecializedValueForArg'"); + UNREACHABLE_RETURN(nullptr); + } + } + + // With all of that data-gathering code out of the way, + // we are now prepared to walk through the process of + // specializing a given callee function based on + // the information we have gathered. + // + IRFunc* generateSpecializedFunc( + IRFunc* oldFunc, + FuncSpecializationInfo const& funcInfo) + { + // We will make use of the infrastructure for cloning + // IR code, that is defined in `ir-clone.{h,cpp}`. + // + // In order to do the cloning work we need an + // "environment" that will map old values to + // their replacements. + // + IRCloneEnv cloneEnv; + + // Next we iterate over the parameters of the old + // function, and register each as being mapped + // to its replacement in the `funcInfo` that was + // already gathered. + // + UInt paramCounter = 0; + for( auto oldParam : oldFunc->getParams() ) + { + UInt paramIndex = paramCounter++; + auto newVal = funcInfo.replacementsForOldParameters[paramIndex]; + cloneEnv.mapOldValToNew.Add(oldParam, newVal); + } + + // Next we will create the skeleton of the new + // specialized function, including its type. + // + // To get the type of the new function we will + // iterate over the collected list of new + // parameters (which may differ greatly from the + // parameter list of the original) and extract + // their types. + // + List paramTypes; + for( auto param : funcInfo.newParams ) + { + paramTypes.add(param->getFullType()); + } + + auto builder = getBuilder(); + IRType* funcType = builder->getFuncType( + paramTypes.getCount(), + paramTypes.getBuffer(), + oldFunc->getResultType()); + + IRFunc* newFunc = builder->createFunc(); + newFunc->setFullType(funcType); + + // The above step has accomplished the "first phase" + // of cloning the function (since `IRFunc`s have no + // operands). + // + // We can now use the shared IR cloning infrastructure + // to perform the second phase of cloning, which will recursively + // clone any nested decorations, blocks, and instructions. + // + cloneInstDecorationsAndChildren( + &cloneEnv, + builder->sharedBuilder, + oldFunc, + newFunc); + + // We are almost done at this point, except that `newFunc` + // is lacking its parameters, as well as any of the body + // instructions that we decided were needed during + // the information-gathering steps. + // + // We will insert these instructions into the first block + // of the function, before its first ordinary instruction. + // We know that these should exist because we had as + // a precondition that `oldFunc` was a definition (so it + // has at least one block), and in valid IR every block + // has at least one ordinary instruction (its terminator). + // + auto newEntryBlock = newFunc->getFirstBlock(); + SLANG_ASSERT(newEntryBlock); + auto newFirstOrdinary = newEntryBlock->getFirstOrdinaryInst(); + SLANG_ASSERT(newFirstOrdinary); + + // We simply iterate over the list of parameters and then + // body instructions that were produced in the information + // gathering step, and insert each before `newFirstOrdinary`, + // which has the effect or arranging them in the output + // in the order they are enumerated here. + // + for( auto newParam : funcInfo.newParams ) + { + newParam->insertBefore(newFirstOrdinary); + } + for( auto newBodyInst : funcInfo.newBodyInsts ) + { + newBodyInst->insertBefore(newFirstOrdinary); + } + + // At this point we've created a new specialized function, + // and as such it may contain call sites that were not + // covered when we built our initial work list. + // + // Before handing the specialized function back to the + // caller, we will make sure to recursively add any + // potentially-specializable call sites to our work list. + // + addCallsToWorkListRec(newFunc); + + return newFunc; + } +}; + +// The top-level function for invoking the specialization pass +// is straighforward. We set up the context object +// and then defer to it for the real work. +// +void specializeResourceParameters( + BackEndCompileRequest* compileRequest, + TargetRequest* targetRequest, + IRModule* module) +{ + ResourceParameterSpecializationContext context; + context.compileRequest = compileRequest; + context.targetRequest = targetRequest; + context.module = module; + + context.processModule(); +} + +} // namesapce Slang diff --git a/source/slang/slang-ir-specialize-resources.h b/source/slang/slang-ir-specialize-resources.h new file mode 100644 index 000000000..1a5e0f7d8 --- /dev/null +++ b/source/slang/slang-ir-specialize-resources.h @@ -0,0 +1,24 @@ +// slang-ir-specialize-resources.h +#pragma once + +namespace Slang +{ + class BackEndCompileRequest; + class TargetRequest; + struct IRModule; + + /// Specialize calls to functions with resource-type parameters. + /// + /// For any function that has resource-type input parameters that + /// would be invalid on the chosen target, this pass will rewrite + /// any call sites that pass suitable arguments (e.g., direct + /// references to global shader parameters) to instead call + /// a specialized variant of the function that does not have + /// those resource parameters (and instead, e.g, refers to the + /// global shader parameters directly). + /// + void specializeResourceParameters( + BackEndCompileRequest* compileRequest, + TargetRequest* targetRequest, + IRModule* module); +} diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp new file mode 100644 index 000000000..fe6b82184 --- /dev/null +++ b/source/slang/slang-ir-specialize.cpp @@ -0,0 +1,1864 @@ +// slang-ir-specialize.cpp +#include "slang-ir-specialize.h" + +#include "slang-ir.h" +#include "slang-ir-clone.h" +#include "slang-ir-insts.h" + +namespace Slang +{ + +// This file implements the primary specialization pass, that takes +// generic/polymorphic Slang code and specializes/monomorphises it. +// +// At present this primarily means generating specialized copies +// of generic functions/types based on the concrete types used +// at specialization sites, and also specializing instances +// of witness-table lookup to directly refer to the concrete +// values for witnesses when witness tables are known. +// +// This pass also performs some amount of simplification and +// specialization for code using existential (interface) types +// for local variables and function parameters/results. +// +// Eventually, this pass will also need to perform specialization +// of functions to argument values for parameters that must +// be compile-time constants, +// +// All of these passes are inter-related in that applying +// simplifications/specializations of one category can open +// up opportunities for transformations in the other categories. + +struct SpecializationContext; + +IRInst* specializeGenericImpl( + IRGeneric* genericVal, + IRSpecialize* specializeInst, + IRModule* module, + SpecializationContext* context); + +struct SpecializationContext +{ + // For convenience, we will keep a pointer to the module + // we are specializing. + IRModule* module; + + // We know that we can only perform generic specialization when all + // of the arguments to a generic are also fully specialized. + // The "is fully specialized" condition is something we + // need to solve for over the program, because the fully- + // specialized-ness of an instruction depends on the + // fully-specialized-ness of its operands. + // + // We will build an explicit hash set to encode those + // instructions that are fully specialized. + // + HashSet fullySpecializedInsts; + + // An instruction is then fully specialized if and only + // if it is in our set. + // + bool isInstFullySpecialized( + IRInst* inst) + { + // A small wrinkle is that a null instruction pointer + // sometimes appears a a type, and so should be treated + // as fully specialized too. + // + // TODO: It would be nice to remove this wrinkle. + // + if(!inst) return true; + + return fullySpecializedInsts.Contains(inst); + } + + // When an instruction isn't fully specialized, but its operands *are* + // then it is a candidate for specialization itself, so we will have + // a query to check for the "all operands fully specialized" case. + // + bool areAllOperandsFullySpecialized( + IRInst* inst) + { + if(!isInstFullySpecialized(inst->getFullType())) + return false; + + UInt operandCount = inst->getOperandCount(); + for(UInt ii = 0; ii < operandCount; ++ii) + { + IRInst* operand = inst->getOperand(ii); + if(!isInstFullySpecialized(operand)) + return false; + } + + return true; + } + + // We will use a single work list of instructions that need + // to be considered for specialization or simplification, + // whether generic, existential, etc. + // + List workList; + HashSet workListSet; + + HashSet cleanInsts; + + void addToWorkList( + IRInst* inst) + { + // We will ignore any code that is nested under a generic, + // because it doesn't make sense to perform specialization + // on such code. + // + for( auto ii = inst->getParent(); ii; ii = ii->getParent() ) + { + if(as(ii)) + return; + } + + if(workListSet.Contains(inst)) + return; + + workList.add(inst); + workListSet.Add(inst); + cleanInsts.Remove(inst); + + addUsersToWorkList(inst); + } + + // When a transformation makes a change to an instruction, + // we may need to re-consider transformations for instructions + // that use its value. In those cases we will call `addUsersToWorkList` + // on the instruction that is being modified or replaced. + // + void addUsersToWorkList( + IRInst* inst) + { + for( auto use = inst->firstUse; use; use = use->nextUse ) + { + auto user = use->getUser(); + addToWorkList(user); + } + } + + // One of the main transformations we will apply is to + // consider an instruction as being fully specialized. + // + void markInstAsFullySpecialized( + IRInst* inst) + { + if(fullySpecializedInsts.Contains(inst)) + return; + fullySpecializedInsts.Add(inst); + + // If we know that an instruction is fully specialized, + // then we should start to consider its uses and children + // as candidates for being fully specialized too... + // + addUsersToWorkList(inst); + } + + + // Of course, somewhere along the way we expect + // to run into uses of `specialize(...)` instructions + // to bind a generic to arguments that we want to + // specialize into concrete code. + // + // We also know that if we encouter `specialize(g, a, b, c)` + // and then later `specialize(g, a, b, c)` again, we + // only want to generate the specialized code for `g` + // *once*, and re-use it for both versions. + // + // We will cache existing specializations of generic function/types + // using the simple key type defined as part of the IR cloning infrastructure. + // + typedef IRSimpleSpecializationKey Key; + Dictionary genericSpecializations; + + // We will also use some shared IR building state across + // all of our specialization/cloning steps. + // + SharedIRBuilder sharedBuilderStorage; + + // Now let's look at the task of finding or generation a + // specialization of some generic `g`, given a specialization + // instruction like `specialize(g, a, b, c)`. + // + // The `specializeGeneric` function will return a value + // suitable for use as a replacement for the `specialize(...)` + // instruction. + // + IRInst* specializeGeneric( + IRGeneric* genericVal, + IRSpecialize* specializeInst) + { + // First, we want to see if an existing specialization + // has already been made. To do that we will construct a key + // for lookup in the generic specialization context. + // + // Our key will consist of the identity of the generic + // being specialized, and each of the argument values + // being pased to it. In our hypothetical example of + // `specialize(g, a, b, c)` the key will then be + // the array `[g, a, b, c]`. + // + Key key; + key.vals.add(specializeInst->getBase()); + UInt argCount = specializeInst->getArgCount(); + for( UInt ii = 0; ii < argCount; ++ii ) + { + key.vals.add(specializeInst->getArg(ii)); + } + + { + // We use our generated key to look for an + // existing specialization that has been registered. + // If one is found, our work is done. + // + IRInst* specializedVal = nullptr; + if(genericSpecializations.TryGetValue(key, specializedVal)) + return specializedVal; + } + + // If no existing specialization is found, we need + // to create the specialization instead. + // This mostly amounts to evaluating the generic as + // if it were a function being called. + // + // We will use a free function to do the actual work + // of evaluating the generic, so that the logic + // can be re-used in other cases that need to + // do one-off specialization. + // + IRInst* specializedVal = specializeGenericImpl(genericVal, specializeInst, module, this); + + + // The value that was returned from evaluating + // the generic is the specialized value, and we + // need to remember it in our dictionary of + // specializations so that we don't instantiate + // this generic again for the same arguments. + // + genericSpecializations.Add(key, specializedVal); + + return specializedVal; + } + + // The logic for generating a specialization of an IR generic + // relies on the ability to "evaluate" the code in the body of + // the generic, but that obviously doesn't work if we don't + // actually have the full definition for the body. + // + // This can arise in particular for builtin operations/types. + // + // Before calling `specializeGeneric()` we need to make sure + // that the generic is actually amenable to specialization, + // by looking at whether it is a definition or a declaration. + // + bool canSpecializeGeneric( + IRGeneric* generic) + { + // It is possible to have multiple "layers" of generics + // (e.g., when a generic function is nested in a generic + // type). Therefore we need to drill down through all + // of the layers present to see if at the leaf we have + // something that looks like a definition. + // + IRGeneric* g = generic; + for(;;) + { + // Given the generic `g`, we will find the value + // it appears to return in its body. + // + auto val = findGenericReturnVal(g); + if(!val) + return false; + + // If `g` returns an inner generic, then we need + // to drill down further. + // + if (auto nestedGeneric = as(val)) + { + g = nestedGeneric; + continue; + } + + // Once we've found the leaf value that will be produced + // after all specialization is complete, we can check + // whether it looks like a definition or not. + // + return isDefinition(val); + } + } + + // Now that we know when we can specialize a generic, and how + // to do it, we can write a subroutine that takes a + // `specialize(g, a, b, c, ...)` instruction and performs + // specialization if it is possible. + // + void maybeSpecializeGeneric( + IRSpecialize* specInst) + { + // We will only attempt to specialize when all of the + // operands to the `speicalize(...)` instruction are + // themselves fully specialized. + // + if(!areAllOperandsFullySpecialized(specInst)) + return; + + // The invariant that the arguments are fully specialized + // should mean that `a, b, c, ...` are in a form that + // we can work with, but it does *not* guarantee + // that the `g` operand is something we can work with. + // + // We can only perform specialization in the case where + // the base `g` is a known `generic` instruction. + // + auto baseVal = specInst->getBase(); + auto genericVal = as(baseVal); + if(!genericVal) + return; + + // We can also only specialize a generic if it + // represents a definition rather than a declaration. + // + if(!canSpecializeGeneric(genericVal)) + return; + + // Once we know that specialization is possible, + // the actual work is fairly simple. + // + // First, we find or generate a specialized + // version of the result of the generic (a specialized + // type, function, or whatever). + // + auto specializedVal = specializeGeneric(genericVal, specInst); + + // Any uses of this `specialize(...)` instruction will + // become uses of `specializeVal`, so we want to re-consider + // them for subsequent transformations. + // + addUsersToWorkList(specInst); + + // Then we simply replace any uses of the `specialize(...)` + // instruction with the specialized value and delete + // the `specialize(...)` instruction from existence. + // + specInst->replaceUsesWith(specializedVal); + specInst->removeAndDeallocate(); + } + + // Generic specialization depends on identifying when + // instructions are fully specialized. + // + void maybeMarkAsFullySpecialized( + IRInst* inst) + { + switch(inst->op) + { + default: + // The default case is that an instruction can + // be considered as fully specialized as soon + // as all of its operands are. + // + // TODO: We realistically need a more refined + // check here that uses a white-list of instructions + // that can represent values suitable for use + // as generic arguments. + // + if(areAllOperandsFullySpecialized(inst)) + { + markInstAsFullySpecialized(inst); + } + break; + + // Certain instructions cannot ever be considered + // fully specialized because they should never + // be substituted into a generic as its arguments. + case kIROp_Specialize: + case kIROp_lookup_interface_method: + case kIROp_ExtractExistentialType: + case kIROp_BindExistentialsType: + break; + } + } + + // The core of this pass is to look at one instruction + // at a time, and try to perform whatever specialization + // is appropriate based on its opcode. + // + void maybeSpecializeInst( + IRInst* inst) + { + switch(inst->op) + { + default: + // By default we assume that specialization is + // not possible for a given opcode. + // + break; + + case kIROp_Specialize: + // The logic for specializing a `specialize(...)` + // instruction has already been elaborated above. + // + maybeSpecializeGeneric(cast(inst)); + break; + + case kIROp_lookup_interface_method: + // The remaining case we need to consider here for generics + // is when we have a `lookup_witness_method` instruction + // that is being applied to a concrete witness table, + // because we can specialize it to just be a direct + // reference to the actual witness value from the table. + // + maybeSpecializeWitnessLookup(cast(inst)); + break; + + case kIROp_Call: + // When writing functions with existential-type parameters, + // we need additional support to specialize a callee + // function based on the concrete type encapsulated in + // an argument of existential type. + // + maybeSpecializeExistentialsForCall(cast(inst)); + break; + + // The specialization of functions with existential-type + // parameters can create further opportunities for specialization, + // but in order to realize these we often need to propagate + // through local simplification on values of existential type. + // + case kIROp_ExtractExistentialType: + maybeSpecializeExtractExistentialType(inst); + break; + case kIROp_ExtractExistentialValue: + maybeSpecializeExtractExistentialValue(inst); + break; + case kIROp_ExtractExistentialWitnessTable: + maybeSpecializeExtractExistentialWitnessTable(inst); + break; + + case kIROp_Load: + maybeSpecializeLoad(as(inst)); + break; + + case kIROp_FieldExtract: + maybeSpecializeFieldExtract(as(inst)); + break; + case kIROp_FieldAddress: + maybeSpecializeFieldAddress(as(inst)); + break; + + case kIROp_BindExistentialsType: + maybeSpecializeBindExistentialsType(as(inst)); + break; + } + } + + // Specializing lookup on witness tables is a general + // transformation that helps with both generic and + // existential-based code. + // + void maybeSpecializeWitnessLookup( + IRLookupWitnessMethod* lookupInst) + { + // Note: While we currently have named the instruction + // `lookup_witness_method`, the `method` part is a misnomer + // and the same instruction can look up *any* interface + // requirement based on the witness table that provides + // a conformance, and the "key" that indicates the interface + // requirement. + + // We can only specialize in the case where the lookup + // is being done on a concrete witness table, and not + // the result of a `specialize` instruction or other + // operation that will yield such a table. + // + auto witnessTable = as(lookupInst->getWitnessTable()); + if(!witnessTable) + return; + + // Because we have a concrete witness table, we can + // use it to look up the IR value that satisfies + // the given interface requirement. + // + auto requirementKey = lookupInst->getRequirementKey(); + auto satisfyingVal = findWitnessVal(witnessTable, requirementKey); + + // We expect to always find a satisfying value, but + // we will go ahead and code defensively so that + // we leave "correct" but unspecialized code if + // we cannot find a concrete value to use. + // + if(!satisfyingVal) + return; + + // At this point, we know that `satisfyingVal` is what + // would result from executing this `lookup_witness_method` + // instruction dynamically, so we can go ahead and + // replace the original instruction with that value. + // + // We also make sure to add any uses of the lookup + // instruction to our work list, because subsequent + // simplifications might be possible now. + // + addUsersToWorkList(lookupInst); + lookupInst->replaceUsesWith(satisfyingVal); + lookupInst->removeAndDeallocate(); + } + + // The above subroutine needed a way to look up + // the satisfying value for a given requirement + // key in a concrete witness table, so let's + // define that now. + // + IRInst* findWitnessVal( + IRWitnessTable* witnessTable, + IRInst* requirementKey) + { + // A witness table is basically just a container + // for key-value pairs, and so the best we can + // do for now is a naive linear search. + // + for( auto entry : witnessTable->getEntries() ) + { + if (requirementKey == entry->getRequirementKey()) + { + return entry->getSatisfyingVal(); + } + } + return nullptr; + } + + // All of the machinery for generic specialization + // has been defined above, so we will now walk + // through the flow of the overall specialization pass. + // + void processModule() + { + // We start by initializing our shared IR building state, + // since we will re-use that state for any code we + // generate along the way. + // + SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; + sharedBuilder->module = module; + sharedBuilder->session = module->session; + + // The unspecialized IR we receive as input will have + // `IRBindGlobalGenericParam` instructions that associate + // each global-scope generic parameter (a type, witness + // table, or what-have-you) with the value that it should + // be bound to for the purposes of this code-generation + // pass. + // + // Before doing any other specialization work, we will + // iterate over these instructions (which may only + // appear at the global scope) and use them to drive + // replacement of the given generic type parameter with + // the desired concrete value. + // + // TODO: When we start to support global shader parameters + // that include existential/interface types, we will need + // to support a similar specialization step for them. + // + specializeGlobalGenericParameters(); + + // Now that we've eliminated all cases of global generic parameters, + // we should now have the properties that: + // + // 1. Execution starts in non-generic code, with no unbound + // generic parameters in scope. + // + // 2. Any case where non-generic code makes use of a generic + // type/function, there will be a `specialize` instruction + // that specifies both the generic and the (concrete) type + // arguments that should be provided to it. + // + // The basic approach now is to look for opportunities to apply + // our specialization rules (e.g., a `specialize` instruction + // where all the type arguments are concrete types) and then + // processing any additional opportunities created along the way. + // + // We start out simple by putting the root instruction for the + // module onto our work list. + // + addToWorkList(module->getModuleInst()); + + while(workList.getCount() != 0) + { + + // We will then iterate until our work list goes dry. + // + while(workList.getCount() != 0) + { + IRInst* inst = workList.getLast(); + + workList.removeLast(); + workListSet.Remove(inst); + cleanInsts.Add(inst); + + // For each instruction we process, we want to perform + // a few steps. + // + // First we will do any checking required to tag an + // instruction as being fully specialized. + // + maybeMarkAsFullySpecialized(inst); + + // Next we will look for all the general-purpose + // specialization opportunities (generic specialization, + // existential specialization, simplifications, etc.) + // + maybeSpecializeInst(inst); + + // Finally, we need to make our logic recurse through + // the whole IR module, so we want to add the children + // of any parent instructions to our work list so that + // we process them too. + // + // Note that we are adding the children of an instruction + // in reverse order. This is because the way we are + // using the work list treats it like a stack (LIFO) and + // we know that fully-specialized-ness will tend to flow + // top-down through the program, so that we want to process + // the children of an instruction in their original order. + // + for(auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + // Also note that `addToWorkList` has been written + // to avoid adding any instruction that is a descendent + // of an IR generic, because we don't actually want + // to perform specialization inside of generics. + // + addToWorkList(child); + } + } + + addDirtyInstsToWorkListRec(module->getModuleInst()); + + } + + // Once the work list has gone dry, we should have the invariant + // that there are no `specialize` instructions inside of non-generic + // functions that in turn reference a generic type/function, *except* + // in the case where that generic is for a builtin type/function, in + // which case we wouldn't want to specialize it anyway. + } + + void addDirtyInstsToWorkListRec(IRInst* inst) + { + if( !cleanInsts.Contains(inst) ) + { + addToWorkList(inst); + } + + for(auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + addDirtyInstsToWorkListRec(child); + } + } + + // Given a `call` instruction in the IR, we need to detect the case + // where the callee has some interface-type parameter(s) and at the + // call site it is statically clear what concrete type(s) the arguments + // will have. + // + void maybeSpecializeExistentialsForCall(IRCall* inst) + { + // We can only specialize a call when the callee function is known. + // + auto calleeFunc = as(inst->getCallee()); + if(!calleeFunc) + return; + + // We can only specialize if we have access to a body for the callee. + // + if(!calleeFunc->isDefinition()) + return; + + // We shouldn't bother specializing unless the callee has at least + // one parameter that has an existential/interface type. + // + bool shouldSpecialize = false; + UInt argCounter = 0; + for( auto param : calleeFunc->getParams() ) + { + auto arg = inst->getArg(argCounter++); + if( !isExistentialType(param->getDataType()) ) + continue; + + shouldSpecialize = true; + + // We *cannot* specialize unless the argument value corresponding + // to such a parameter is one we can specialize. + // + if( !canSpecializeExistentialArg(arg)) + return; + + } + // If we never found a parameter worth specializing, we should bail out. + // + if(!shouldSpecialize) + return; + + // At this point, we believe we *should* and *can* specialize. + // + // We need a specialized variant of the callee (with the concrete + // types substituted in for existential-type parameters), and then + // we can replace the call site to call the new function instead. + // + // Any two call sites where the argument types are the same can + // re-use the same callee, so we will cache and re-use the + // specialized functions that we generate (similar to how generic + // specialization works). Therefore we will construct a key + // for use when caching the specialized functions. + // + IRSimpleSpecializationKey key; + + // The specialized callee will always depend on the unspecialized + // function from which it is generated, so we add that to our key. + // + key.vals.add(calleeFunc); + + // Also, for any parameter that has an existential type, the + // specialized function will depend on the concrete type of the + // argument. + // + argCounter = 0; + for( auto param : calleeFunc->getParams() ) + { + auto arg = inst->getArg(argCounter++); + if( !isExistentialType(param->getDataType()) ) + continue; + + if( auto makeExistential = as(arg) ) + { + // Note that we use the *type* stored in the + // existential-type argument, but not anything to + // do with the particular value (otherwise we'd only + // be able to re-use the specialized callee for + // call sites that pass in the exact same argument). + // + auto val = makeExistential->getWrappedValue(); + auto valType = val->getFullType(); + key.vals.add(valType); + + // We are also including the witness table in the key. + // This isn't required with our current language model, + // since a given type can only conform to a given interface + // in one way (so there can be only one witness table). + // That means that the `valType` and the existential + // type of `param` above should uniquely determine + // the witness table we see. + // + // There are forward-looking cases where supporting + // "overlapping conformances" could be required, and + // there is low incremental cost to future-proofing + // this code, so we go ahead and add the witness + // table even if it is redundant. + // + auto witnessTable = makeExistential->getWitnessTable(); + key.vals.add(witnessTable); + } + else if( auto wrapExistential = as(arg) ) + { + auto val = wrapExistential->getWrappedValue(); + auto valType = val->getFullType(); + key.vals.add(valType); + + UInt slotOperandCount = wrapExistential->getSlotOperandCount(); + for( UInt ii = 0; ii < slotOperandCount; ++ii ) + { + auto slotOperand = wrapExistential->getSlotOperand(ii); + key.vals.add(slotOperand); + } + } + else + { + SLANG_UNEXPECTED("missing case for existential argument"); + } + } + + // Once we've constructed our key, we can try to look for an + // existing specialization of the callee that we can use. + // + IRFunc* specializedCallee = nullptr; + if( !existentialSpecializedFuncs.TryGetValue(key, specializedCallee) ) + { + // If we didn't find a specialized callee already made, then we + // will go ahead and create one, and then register it in our cache. + // + specializedCallee = createExistentialSpecializedFunc(inst, calleeFunc); + existentialSpecializedFuncs.Add(key, specializedCallee); + } + + // At this point we have found or generated a specialized version + // of the callee, and we need to emit a call to it. + // + // We will start by constructing the argument list for the new call. + // + argCounter = 0; + List newArgs; + for( auto param : calleeFunc->getParams() ) + { + auto arg = inst->getArg(argCounter++); + + // How we handle each argument depends on whether the corresponding + // parameter has an existential type or not. + // + if( !isExistentialType(param->getDataType()) ) + { + // If the parameter doesn't have an existential type, then we + // don't want to change up the argument we pass at all. + // + newArgs.add(arg); + } + else + { + // Any place where the original function had a parameter of + // existential type, we will now be passing in the concrete + // argument value instead of an existential wrapper. + // + if( auto makeExistential = as(arg) ) + { + auto val = makeExistential->getWrappedValue(); + newArgs.add(val); + } + else if( auto wrapExistential = as(arg) ) + { + auto val = wrapExistential->getWrappedValue(); + newArgs.add(val); + } + else + { + SLANG_UNEXPECTED("missing case for existential argument"); + } + } + } + + // Now that we've built up our argument list, it is simple enough + // to construct a new `call` instruction. + // + IRBuilder builderStorage; + auto builder = &builderStorage; + builder->sharedBuilder = &sharedBuilderStorage; + + builder->setInsertBefore(inst); + auto newCall = builder->emitCallInst( + inst->getFullType(), specializedCallee, newArgs); + + // We will completely replace the old `call` instruction with the + // new one, and will go so far as to transfer any decorations + // that were attached to the old call over to the new one. + // + inst->transferDecorationsTo(newCall); + inst->replaceUsesWith(newCall); + inst->removeAndDeallocate(); + + // Just in case, we will add any instructions that used the + // result of this call to our work list for re-consideration. + // At this moment this shouldn't open up new opportunities + // for specialization, but we can always play it safe. + // + addUsersToWorkList(newCall); + } + + // The above `maybeSpecializeExistentialsForCall` routine needed + // a few utilities, which we will now define. + + // First, we want to be able to test whether a type (used by + // a parameter) is an existential type so that we should specialize it. + // + bool isExistentialType(IRType* type) + { + // An IR-level interface type is always an existential. + // + if(as(type)) + return true; + + // Eventually we will also want to handle arrays over + // existential types, but that will require careful + // handling in many places. + + return false; + } + + // Similarly, we want to be able to test whether an instruction + // used as an argument for an existential-type parameter is + // suitable for use in specialization. + // + bool canSpecializeExistentialArg(IRInst* inst) + { + // A `makeExistential(v, w)` instruction can be used + // for specialization, since we have the concrete value `v` + // (which implicitly determines the concrete type), and + // the witness table `w. + // + if(as(inst)) + return true; + + // A `wrapExistential(v, T0,w0, T1, w1, ...)` instruction + // is just a generalization of `makeExistential`, so it + // can apply in the same cases. + // + if(as(inst)) + return true; + + // If we start to specialize functions that take arrays + // of existentials as input, we will need a strategy to + // determine arguments suitable for use in specializing + // them (these would need to be arrays that nominally + // have an existential element type, but somehow have + // annotations to indicate that the concrete type + // underlying the elements in homogeneous). + + return false; + } + + // In order to cache and re-use functions that have had existential-type + // parameters specialized, we need storage for the cache. + // + Dictionary existentialSpecializedFuncs; + + // The logic for creating a specialized callee function by plugging + // in concrete types for existentials is similar to other cases of + // specialization in the compiler. + // + IRFunc* createExistentialSpecializedFunc( + IRCall* oldCall, + IRFunc* oldFunc) + { + // We will make use of the infrastructure for cloning + // IR code, that is defined in `ir-clone.{h,cpp}`. + // + // In order to do the cloning work we need an + // "environment" that will map old values to + // their replacements. + // + IRCloneEnv cloneEnv; + + // We also need some IR building state, for any + // new instructions we will emit. + // + IRBuilder builderStorage; + auto builder = &builderStorage; + builder->sharedBuilder = &sharedBuilderStorage; + + // We will start out by determining what the parameters + // of the specialized function should be, based on + // the parameters of the original, and the concrete + // type of selected arguments at the call site. + // + // Along the way we will build up explicit lists of + // the parameters, as well as any new instructions + // that need to be added to the body of the function + // we generate (as a kind of "prologue"). We build + // the lists here because we don't yet have a basic + // block, or even a function, to insert them into. + // + List newParams; + List newBodyInsts; + UInt argCounter = 0; + for( auto oldParam : oldFunc->getParams() ) + { + auto arg = oldCall->getArg(argCounter++); + + // Given an old parameter, and the argument value at + // the (old) call site, we need to determine what + // value should stand in for that parameter in + // the specialized callee. + // + IRInst* replacementVal = nullptr; + + // The trickier case is when we have an existential-type + // parameter, because we need to extract out the concrete + // type that is coming from the call site. + // + if( auto oldMakeExistential = as(arg) ) + { + // In this case, the `arg` is `makeExistential(val, witnessTable)` + // and we know that the specialized call site will just be + // passing in `val`. + // + auto val = oldMakeExistential->getWrappedValue(); + auto witnessTable = oldMakeExistential->getWitnessTable(); + + // Our specialized function needs to take a parameter with the + // same type as `val`, to match the call site(s) that will be + // created. + // + auto valType = val->getFullType(); + auto newParam = builder->createParam(valType); + newParams.add(newParam); + + // Within the body of the function we cannot just use `val` + // directly, because the existing code expects an existential + // value, including its witness table. + // + // Therefore we will create a `makeExistential(newParam, witnessTable)` + // in the body of the new function and use *that* as the replacement + // value for the original parameter (since it will have the + // correct existential type, and stores the right witness table). + // + auto newMakeExistential = builder->emitMakeExistential(oldParam->getFullType(), newParam, witnessTable); + newBodyInsts.add(newMakeExistential); + replacementVal = newMakeExistential; + } + else if( auto oldWrapExistential = as(arg) ) + { + auto val = oldWrapExistential->getWrappedValue(); + auto valType = val->getFullType(); + + auto newParam = builder->createParam(valType); + newParams.add(newParam); + + // Within the body of the function we cannot just use `val` + // directly, because the existing code expects an existential + // value, including its witness table. + // + // Therefore we will create a `makeExistential(newParam, witnessTable)` + // in the body of the new function and use *that* as the replacement + // value for the original parameter (since it will have the + // correct existential type, and stores the right witness table). + // + auto newWrapExistential = builder->emitWrapExistential( + oldParam->getFullType(), + newParam, + oldWrapExistential->getSlotOperandCount(), + oldWrapExistential->getSlotOperands()); + newBodyInsts.add(newWrapExistential); + replacementVal = newWrapExistential; + } + else + { + // For parameters that don't have an existential type, + // there is nothing interesting to do. The new function + // will also have a parameter of the exact same type, + // and we'll use that instead of the original parameter. + // + auto newParam = builder->createParam(oldParam->getFullType()); + newParams.add(newParam); + replacementVal = newParam; + } + + // Whatever replacement value was constructed, we need to + // register it as the replacement for the original parameter. + // + cloneEnv.mapOldValToNew.Add(oldParam, replacementVal); + } + + // Next we will create the skeleton of the new + // specialized function, including its type. + // + // In order to construct the type of the new function, we + // need to extract the types of all its parameters. + // + List newParamTypes; + for( auto newParam : newParams ) + { + newParamTypes.add(newParam->getFullType()); + } + IRType* newFuncType = builder->getFuncType( + newParamTypes.getCount(), + newParamTypes.getBuffer(), + oldFunc->getResultType()); + IRFunc* newFunc = builder->createFunc(); + newFunc->setFullType(newFuncType); + + // By construction, our new function type will be + // "fully specialized" by the rules used for doing + // generic specialization elsewhere in this pass. + // + fullySpecializedInsts.Add(newFuncType); + + // The above steps have accomplished the "first phase" + // of cloning the function (since `IRFunc`s have no + // operands). + // + // We can now use the shared IR cloning infrastructure + // to perform the second phase of cloning, which will recursively + // clone any nested decorations, blocks, and instructions. + // + cloneInstDecorationsAndChildren( + &cloneEnv, + builder->sharedBuilder, + oldFunc, + newFunc); + + // Now that the main body of existing isntructions have + // been cloned into the new function, we can go ahead + // and insert all the parameters and body instructions + // we built up into the function at the right place. + // + // We expect the function to always have at least one + // block (this was an invariant established before + // we decided to specialize). + // + auto newEntryBlock = newFunc->getFirstBlock(); + SLANG_ASSERT(newEntryBlock); + + // We expect every valid block to have at least one + // "ordinary" instruction (it will at least have + // a terminator like a `return`). + // + auto newFirstOrdinary = newEntryBlock->getFirstOrdinaryInst(); + SLANG_ASSERT(newFirstOrdinary); + + // All of our parameters will get inserted before + // the first ordinary instruction (since the function parameters + // should come at the start of the first block). + // + for( auto newParam : newParams ) + { + newParam->insertBefore(newFirstOrdinary); + } + + // All of our new body instructions will *also* be inserted + // before the first ordinary instruction (but will come + // *after* the parameters by the order of these two loops). + // + for( auto newBodyInst : newBodyInsts ) + { + newBodyInst->insertBefore(newFirstOrdinary); + } + + // After all this work we have a valid `newFunc` that has been + // specialized to match the types at the call site. + // + // There might be further opportunities for simplification and + // specialization in the function body now that we've plugged + // in some more concrete type information, so we will + // add the whole function to our work list for subsequent + // consideration. + // + addToWorkList(newFunc); + + return newFunc; + } + + // When we've specialized a function with an interface-type parameter + // we will still end up with a `makeExistential` operation in its + // body, which could impede subequent specializations. + // + // For example, if we have the following after specialization: + // + // e = makeExistential(v, w1); + // w2 = extractExistentialWitnessTable(e); + // f = lookup_witness_method(w2, k); + // call(f, ...); + // + // We cannot then specialize the lookup for `f` in this code as written, + // but it seems obvious that we could replace `w2` with `w1` and maybe + // get further along. + // + // In order to set up further specialization opportunities we need + // to implement a few simplification rules around operations that + // extract from an existential, when their operand is a `makeExistential`. + // + // Let's start with the routine for the case above of extracting + // a witness table. + // + void maybeSpecializeExtractExistentialWitnessTable(IRInst* inst) + { + // We know `inst` is `extractExistentialWitnessTable(existentialArg)`. + // + auto existentialArg = inst->getOperand(0); + + if( auto makeExistential = as(existentialArg) ) + { + // In this case we know `inst` is: + // + // extractExistentialWitnessTable(makeExistential(..., witnessTable)) + // + // and we can just simplify that to `witnessTable`. + // + auto witnessTable = makeExistential->getWitnessTable(); + + // Anything that used this instruction is now a candidate for + // further simplification or specialization (e.g., one of + // the users of this instruction could be a `lookup_witness_method` + // that we can now specialize). + // + addUsersToWorkList(inst); + + inst->replaceUsesWith(witnessTable); + inst->removeAndDeallocate(); + } + } + + // The cases for simplifying `extractExistentialValue` is more or less the same + // as for witness tables. + // + void maybeSpecializeExtractExistentialValue(IRInst* inst) + { + // We know `inst` is `extractExistentialValue(existentialArg)`. + // + auto existentialArg = inst->getOperand(0); + if( auto makeExistential = as(existentialArg) ) + { + // Now we know `inst` is: + // + // extractExistentialValue(makeExistential(val, ...)) + // + // and we can just simplify that to `val`. + // + auto val = makeExistential->getWrappedValue(); + + addUsersToWorkList(inst); + + inst->replaceUsesWith(val); + inst->removeAndDeallocate(); + } + } + + // The cases for simplifying `extractExistentialType` is more or less the same + // as for witness tables. + // + void maybeSpecializeExtractExistentialType(IRInst* inst) + { + // We know `inst` is `extractExistentialValue(existentialArg)`. + // + auto existentialArg = inst->getOperand(0); + if( auto makeExistential = as(existentialArg) ) + { + // Now we know `inst` is: + // + // extractExistentialType(makeExistential(val, ...)) + // + // and we can just simplify that to type type of `val`. + // + auto val = makeExistential->getWrappedValue(); + auto valType = val->getFullType(); + + addUsersToWorkList(inst); + + inst->replaceUsesWith(valType); + inst->removeAndDeallocate(); + } + } + + void maybeSpecializeLoad(IRLoad* inst) + { + auto ptrArg = inst->ptr.get(); + + if( auto wrapInst = as(ptrArg) ) + { + // We have an instruction of the form `load(wrapExistential(val, ...))` + // + auto val = wrapInst->getWrappedValue(); + + // We know what type we are expected to + // produce (which should be the pointed-to + // type for whatever the type of the + // `wrapExistential` is). + // + auto resultType = inst->getFullType(); + + IRBuilder builder; + builder.sharedBuilder = &sharedBuilderStorage; + builder.setInsertBefore(inst); + + // We'd *like* to replace this instruction with + // `wrapExistential(load(val))` instead, since that + // will enable subsequent specializations. + // + // To do that, we need to be able to determine + // the type that `load(val)` should return. + // + auto elementType = tryGetPointedToType(&builder, val->getDataType()); + if(!elementType) + return; + + + List slotOperands; + UInt slotOperandCount = wrapInst->getSlotOperandCount(); + for( UInt ii = 0; ii < slotOperandCount; ++ii ) + { + slotOperands.add(wrapInst->getSlotOperand(ii)); + } + + auto newLoadInst = builder.emitLoad(elementType, val); + auto newWrapExistentialInst = builder.emitWrapExistential( + resultType, + newLoadInst, + slotOperandCount, + slotOperands.getBuffer()); + + addUsersToWorkList(inst); + + inst->replaceUsesWith(newWrapExistentialInst); + inst->removeAndDeallocate(); + } + } + + UInt calcExistentialBoxSlotCount(IRType* type) + { + top: + if( as(type) ) + { + return 2; + } + else if( auto ptrType = as(type) ) + { + type = ptrType->getValueType(); + goto top; + } + else if( auto ptrLikeType = as(type) ) + { + type = ptrLikeType->getElementType(); + goto top; + } + else if( auto structType = as(type) ) + { + UInt count = 0; + for( auto field : structType->getFields() ) + { + count += calcExistentialBoxSlotCount(field->getFieldType()); + } + return count; + } + else + { + return 0; + } + } + + void maybeSpecializeFieldExtract(IRFieldExtract* inst) + { + auto baseArg = inst->getBase(); + auto fieldKey = inst->getField(); + + if( auto wrapInst = as(baseArg) ) + { + // We have `getField(wrapExistential(val, ...), fieldKey)` + // + auto val = wrapInst->getWrappedValue(); + + // We know what type we are expected to produce. + // + auto resultType = inst->getFullType(); + + IRBuilder builder; + builder.sharedBuilder = &sharedBuilderStorage; + builder.setInsertBefore(inst); + + // We'd *like* to replace this instruction with + // `wrapExistential(getField(val, fieldKey), ...)` instead, since that + // will enable subsequent specializations. + // + // To do that, we need to figure out: + // + // 1. What type that inner `getField` would return (what + // is the type of the `fieldKey` field in `val`?) + // + // 2. Which of the existential slot operands in `...` there + // actually apply to the given field. + // + + // To determine these things, we need the type of + // `val` to be a structure type so that we can look + // up the field corresponding to `fieldKey`. + // + auto valType = val->getDataType(); + auto valStructType = as(valType); + if(!valStructType) + return; + + UInt slotOperandOffset = 0; + + IRStructField* foundField = nullptr; + for( auto valField : valStructType->getFields() ) + { + if( valField->getKey() == fieldKey ) + { + foundField = valField; + break; + } + + slotOperandOffset += calcExistentialBoxSlotCount(valField->getFieldType()); + } + + if(!foundField) + return; + + auto foundFieldType = foundField->getFieldType(); + + List slotOperands; + UInt slotOperandCount = calcExistentialBoxSlotCount(foundFieldType); + + for( UInt ii = 0; ii < slotOperandCount; ++ii ) + { + slotOperands.add(wrapInst->getSlotOperand(slotOperandOffset + ii)); + } + + auto newGetField = builder.emitFieldExtract( + foundFieldType, + val, + fieldKey); + + auto newWrapExistentialInst = builder.emitWrapExistential( + resultType, + newGetField, + slotOperandCount, + slotOperands.getBuffer()); + + addUsersToWorkList(inst); + inst->replaceUsesWith(newWrapExistentialInst); + inst->removeAndDeallocate(); + } + } + + + void maybeSpecializeFieldAddress(IRFieldAddress* inst) + { + auto baseArg = inst->getBase(); + auto fieldKey = inst->getField(); + + if( auto wrapInst = as(baseArg) ) + { + // We have `getFieldAddr(wrapExistential(val, ...), fieldKey)` + // + auto val = wrapInst->getWrappedValue(); + + // We know what type we are expected to produce. + // + auto resultType = inst->getFullType(); + + IRBuilder builder; + builder.sharedBuilder = &sharedBuilderStorage; + builder.setInsertBefore(inst); + + // We'd *like* to replace this instruction with + // `wrapExistential(getFieldAddr(val, fieldKey), ...)` instead, since that + // will enable subsequent specializations. + // + // To do that, we need to figure out: + // + // 1. What type that inner `getFieldAddr` would return (what + // is the type of the `fieldKey` field in `val`?) + // + // 2. Which of the existential slot operands in `...` there + // actually apply to the given field. + // + + // To determine these things, we need the type of + // `val` to be a (pointer to a) structure type so that we can look + // up the field corresponding to `fieldKey`. + // + auto valType = tryGetPointedToType(&builder, val->getDataType()); + if(!valType) + return; + + auto valStructType = as(valType); + if(!valStructType) + return; + + UInt slotOperandOffset = 0; + + IRStructField* foundField = nullptr; + for( auto valField : valStructType->getFields() ) + { + if( valField->getKey() == fieldKey ) + { + foundField = valField; + break; + } + + slotOperandOffset += calcExistentialBoxSlotCount(valField->getFieldType()); + } + + if(!foundField) + return; + + auto foundFieldType = foundField->getFieldType(); + + List slotOperands; + UInt slotOperandCount = calcExistentialBoxSlotCount(foundFieldType); + + for( UInt ii = 0; ii < slotOperandCount; ++ii ) + { + slotOperands.add(wrapInst->getSlotOperand(slotOperandOffset + ii)); + } + + auto newGetFieldAddr = builder.emitFieldAddress( + builder.getPtrType(foundFieldType), + val, + fieldKey); + + auto newWrapExistentialInst = builder.emitWrapExistential( + resultType, + newGetFieldAddr, + slotOperandCount, + slotOperands.getBuffer()); + + addUsersToWorkList(inst); + inst->replaceUsesWith(newWrapExistentialInst); + inst->removeAndDeallocate(); + } + } + + UInt calcExistentialTypeParamSlotCount(IRType* type) + { + top: + if( as(type) ) + { + return 2; + } + else if( auto ptrType = as(type) ) + { + type = ptrType->getValueType(); + goto top; + } + else if( auto ptrLikeType = as(type) ) + { + type = ptrLikeType->getElementType(); + goto top; + } + else if( auto structType = as(type) ) + { + UInt count = 0; + for( auto field : structType->getFields() ) + { + count += calcExistentialTypeParamSlotCount(field->getFieldType()); + } + return count; + } + else + { + return 0; + } + } + + Dictionary existentialSpecializedStructs; + + void maybeSpecializeBindExistentialsType(IRBindExistentialsType* type) + { + auto baseType = type->getBaseType(); + UInt slotOperandCount = type->getExistentialArgCount(); + + IRBuilder builder; + builder.sharedBuilder = &sharedBuilderStorage; + builder.setInsertBefore(type); + + if( auto baseInterfaceType = as(baseType) ) + { + // A `BindExistentials` can + // just be simplified to `ExistentialBox`. + // + // Note: We do *not* simplify straight to `ConcreteType`, because + // that would mess up the layout for aggregate types that + // contain interfaces. The logical indirection introduced + // by `ExistentialBox<...>` will be handled by a later type + // legalization pass that moved the type "pointed to" by + // the box out of line from other fields. + + // We always expect two slot operands, one for the concrete type + // and one for the witness table. + // + SLANG_ASSERT(slotOperandCount == 2); + if(slotOperandCount <= 1) return; + + auto concreteType = (IRType*) type->getExistentialArg(0); + auto newVal = builder.getPtrType(kIROp_ExistentialBoxType, concreteType); + + addUsersToWorkList(type); + type->replaceUsesWith(newVal); + type->removeAndDeallocate(); + return; + } + else if( auto basePtrLikeType = as(baseType) ) + { + // A `BindExistentials, ...>` can be simplified to + // `P>` when `P` is a pointer-like + // type constructor. + // + auto baseElementType = basePtrLikeType->getElementType(); + IRInst* wrappedElementType = builder.getBindExistentialsType( + baseElementType, + slotOperandCount, + type->getExistentialArgs()); + addToWorkList(wrappedElementType); + + auto newPtrLikeType = builder.getType( + basePtrLikeType->op, + 1, + &wrappedElementType); + addToWorkList(newPtrLikeType); + + addUsersToWorkList(type); + type->replaceUsesWith(newPtrLikeType); + type->removeAndDeallocate(); + return; + } + else if( auto baseStructType = as(baseType) ) + { + // In order to bind a `struct` type we will generate + // a new specialized `struct` type on demand and then + // cache and re-use it. + // + // We don't want to start specializing here unless + // all the operand types (and witness tables) we + // will be specializing to are themselves fully + // specialized, so that we can be sure that we + // have a unique type. + // + if( !areAllOperandsFullySpecialized(type) ) + return; + + // Now we we check to see if we've already created + // a specialized struct type or not. + // + IRSimpleSpecializationKey key; + key.vals.add(baseStructType); + for( UInt ii = 0; ii < slotOperandCount; ++ii ) + { + key.vals.add(type->getExistentialArg(ii)); + } + + IRStructType* newStructType = nullptr; + if( !existentialSpecializedStructs.TryGetValue(key, newStructType) ) + { + builder.setInsertBefore(baseStructType); + newStructType = builder.createStructType(); + + auto fieldSlotArgs = type->getExistentialArgs(); + + for( auto oldField : baseStructType->getFields() ) + { + // TODO: we need to figure out which of the specialization arguments + // apply to this field... + + auto oldFieldType = oldField->getFieldType(); + auto fieldSlotArgCount = calcExistentialTypeParamSlotCount(oldFieldType); + + auto newFieldType = builder.getBindExistentialsType( + oldFieldType, + fieldSlotArgCount, + fieldSlotArgs); + + addToWorkList(newFieldType); + + fieldSlotArgs += fieldSlotArgCount; + + builder.createStructField(newStructType, oldField->getKey(), newFieldType); + } + + existentialSpecializedStructs.Add(key, newStructType); + addToWorkList(newStructType); + } + + addUsersToWorkList(type); + type->replaceUsesWith(newStructType); + type->removeAndDeallocate(); + return; + + } + } + + // The handling of specialization for global generic type + // parameters involves searching for all `bind_global_generic_param` + // instructions in the input module. + // + void specializeGlobalGenericParameters() + { + auto moduleInst = module->getModuleInst(); + for(auto inst : moduleInst->getChildren()) + { + // We only want to consider the `bind_global_generic_param` + // instructions, and ignore everything else. + // + auto bindInst = as(inst); + if(!bindInst) + continue; + + // HACK: Our current front-end emit logic can end up emitting multiple + // `bind_global_generic_param` instructions for the same parameter. This is + // a buggy behavior, but a real fix would require refactoring the way + // global generic arguments are specified today. + // + // For now we will do a sanity check to detect parameters that + // have already been specialized. + // + if( !as(bindInst->getOperand(0)) ) + { + // The "parameter" operand is no longer a parameter, so it + // seems things must have been specialized already. + // + continue; + } + + // The actual logic for applying the substitution is + // almost trivial: we will replace any uses of the + // global generic parameter with its desired value. + // + auto param = bindInst->getParam(); + auto val = bindInst->getVal(); + param->replaceUsesWith(val); + } + { + // Now that we've replaced any uses of global generic + // parameters, we will do a second pass to remove + // the parameters and any `bind_global_generic_param` + // instructions, since both should be dead/unused. + // + IRInst* next = nullptr; + for(auto inst = moduleInst->getFirstChild(); inst; inst = next) + { + next = inst->getNextInst(); + + switch(inst->op) + { + default: + break; + + case kIROp_GlobalGenericParam: + case kIROp_BindGlobalGenericParam: + // A `bind_global_generic_param` instruction should + // have no uses in the first place, and all the global + // generic parameters should have had their uses replaced. + // + SLANG_ASSERT(!inst->firstUse); + inst->removeAndDeallocate(); + break; + } + } + } + } +}; + +void specializeModule( + IRModule* module) +{ + SpecializationContext context; + context.module = module; + context.processModule(); +} + + +IRInst* specializeGenericImpl( + IRGeneric* genericVal, + IRSpecialize* specializeInst, + IRModule* module, + SpecializationContext* context) +{ + // Effectively, specializing a generic amounts to "calling" the generic + // on its concrete argument values and computing the + // result it returns. + // + // For now, all of our generics consist of a single + // basic block, so we can "call" them just by + // cloning the instructions in their single block + // into the global scope, using an environment for + // cloning that maps the generic parameters to + // the concrete arguments that were provided + // by the `specialize(...)` instruction. + // + IRCloneEnv env; + + // We will walk through the parameters of the generic and + // register the corresponding argument of the `specialize` + // instruction to be used as the "cloned" value for each + // parameter. + // + // Suppose we are looking at `specialize(g, a, b, c)` and `g` has + // three generic parameters: `T`, `U`, and `V`. Then we will + // be initializing our environment to map `T -> a`, `U -> b`, + // and `V -> c`. + // + UInt argCounter = 0; + for( auto param : genericVal->getParams() ) + { + UInt argIndex = argCounter++; + SLANG_ASSERT(argIndex < specializeInst->getArgCount()); + + IRInst* arg = specializeInst->getArg(argIndex); + + env.mapOldValToNew.Add(param, arg); + } + + // We will set up an IR builder for insertion + // into the global scope, at the same location + // as the original generic. + // + SharedIRBuilder sharedBuilderStorage; + sharedBuilderStorage.module = module; + sharedBuilderStorage.session = module->getSession(); + + IRBuilder builderStorage; + IRBuilder* builder = &builderStorage; + builder->sharedBuilder = &sharedBuilderStorage; + builder->setInsertBefore(genericVal); + + // Now we will run through the body of the generic and + // clone each of its instructions into the global scope, + // until we reach a `return` instruction. + // + for( auto bb : genericVal->getBlocks() ) + { + // We expect a generic to only ever contain a single block. + // + SLANG_ASSERT(bb == genericVal->getFirstBlock()); + + // We will iterate over the non-parameter ("ordinary") + // instructions only, because parameters were dealt + // with explictly at an earlier point. + // + for( auto ii : bb->getOrdinaryInsts() ) + { + // The last block of the generic is expected to end with + // a `return` instruction for the specialized value that + // comes out of the abstraction. + // + // We thus use that cloned value as the result of the + // specialization step. + // + if( auto returnValInst = as(ii) ) + { + auto specializedVal = findCloneForOperand(&env, returnValInst->getVal()); + return specializedVal; + } + + // For any instruction other than a `return`, we will + // simply clone it completely into the global scope. + // + IRInst* clonedInst = cloneInst(&env, builder, ii); + + // Any new instructions we create during cloning were + // not present when we initially built our work list, + // so we need to make sure to consider them now. + // + // This is important for the cases where one generic + // invokes another, because there will be `specialize` + // operations nested inside the first generic that refer + // to the second. + // + if( context ) + { + context->addToWorkList(clonedInst); + } + } + } + + // If we reach this point, something went wrong, because we + // never encountered a `return` inside the body of the generic. + // + SLANG_UNEXPECTED("no return from generic"); + UNREACHABLE_RETURN(nullptr); +} + +IRInst* specializeGeneric( + IRSpecialize* specializeInst) +{ + auto baseGeneric = as(specializeInst->getBase()); + SLANG_ASSERT(baseGeneric); + if(!baseGeneric) return specializeInst; + + auto module = specializeInst->getModule(); + SLANG_ASSERT(module); + if(!module) return specializeInst; + + return specializeGenericImpl(baseGeneric, specializeInst, module, nullptr); +} + + +} // namespace Slang diff --git a/source/slang/slang-ir-specialize.h b/source/slang/slang-ir-specialize.h new file mode 100644 index 000000000..9c2c19785 --- /dev/null +++ b/source/slang/slang-ir-specialize.h @@ -0,0 +1,12 @@ +// slang-ir-specialize.h +#pragma once + +namespace Slang +{ +struct IRModule; + + /// Specialize generic and interface-based code to use concrete types. +void specializeModule( + IRModule* module); + +} diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp new file mode 100644 index 000000000..9390b1e69 --- /dev/null +++ b/source/slang/slang-ir-ssa.cpp @@ -0,0 +1,1159 @@ +// slang-ir-ssa.cpp +#include "slang-ir-ssa.h" + +#include "slang-ir.h" +#include "slang-ir-clone.h" +#include "slang-ir-insts.h" + +namespace Slang { + +// Track information on a phi node we are in +// the process of constructing. +struct PhiInfo : RefObject +{ + // The phi node will be represented as a parameter + // to a (non-entry) basic block. + IRParam* phi; + + // The original variable that this phi will be replacing. + IRVar* var; + + // The operands to the phi will be stored as uses here, + // because our IR parameters don't have operands. + // + // Once we've collected all the values we plan to use, + // we will turn this into argument in predecessor blocks + // that branch to this one. + // + // The order of elements in this list must match the + // order in which the predecessor blocks get enumerated. + List operands; + + // If this phi ended up being removed as trivial, then + // this will be the value that we replaced it with. + IRInst* replacement = nullptr; +}; + +// Information about a basic block that we generate/use +// during SSA construction. +struct SSABlockInfo : RefObject +{ + // Map a promotable variable to the value to + // use for that variable + Dictionary valueForVar; + + // The underlying basic block. + IRBlock* block; + + // Have we processed all the instructions in the + // body of this block (so that we would have + // found any stores to SSA variables)? + bool isFilled = false; + + // Have we filled all the predecessors of + // this block, so that we can actually perform + // look up in them? + bool isSealed = false; + + // An IR builder to use when we want to construct + // stuff in the context of this block + IRBuilder builder; + + // Phi nodes we are creating for this block. + List phis; + + // Arguments that this block needs to pass along + // to the phi nodes defined by is sucessor + List successorArgs; +}; + +// State for constructing SSA form for a global value +// with code (usually a function). +struct ConstructSSAContext +{ + // The value that we want to rewrite into SSA form + // (usually an IR function) + IRGlobalValueWithCode* globalVal; + + // Variables that we've identified for promotion + // to SSA values. + List promotableVars; + + // Information about each basic block + Dictionary> blockInfos; + + // IR building state to use during the operation + SharedIRBuilder sharedBuilder; + + // Instructions to remove during cleanup + List instsToRemove; + + IRBuilder builder; + IRBuilder* getBuilder() { return &builder; } + + + Dictionary> phiInfos; + + PhiInfo* getPhiInfo(IRParam* phi) + { + if(auto found = phiInfos.TryGetValue(phi)) + return *found; + return nullptr; + } +}; + +/// Do all uses of this instruction lead to a `load`? +/// +/// Checks if all uses of `inst` are either loads, +/// or get-element-address/get-field-address operations +/// that also lead to loads. +bool allUsesLeadToLoads(IRInst* inst) +{ + for (auto u = inst->firstUse; u; u = u->nextUse) + { + auto user = u->getUser(); + switch (user->op) + { + default: + return false; + + case kIROp_Load: + break; + + case kIROp_getElementPtr: + case kIROp_FieldAddress: + { + // Sanity check: the address being used should + // be the base-address operand, and not the field + // key or index (this should never be a problem). + if (u != &user->getOperands()[0]) + return false; + + if (!allUsesLeadToLoads(user)) + return false; + } + break; + } + } + + // If all of the uses passed our checking, then + // we are good to go. + return true; + +} + +// Is the given variable one that we can promote to SSA form? +bool isPromotableVar( + ConstructSSAContext* /*context*/, + IRVar* var) +{ + // We want to identify variables such that we can always + // determine what they will contain at a point in the + // program by directly inspecting their uses. + // + // The simplest possible answer would be instructions + // that are only ever used as the operand of "full" + // load and store instructions (loads and stores that + // write the entire variable). This is enough to + // promote simple scalar variables to SSA temporaries, + // but falls apart for aggregates and arrays. + // + // A slightly more powerful option (which is what we + // implement for now) is to promote variables when + // all of the stores are "full," and all other uses + // are in the form of a "chain" of `getElmeentAddress` + // or `getFieldAddress` operations that terminates + // with a load. + // + // An even more powerful option (which we do not yet + // implement) would be to handle cases where there are + // "chains" that end with stores, and to treat these + // as partial assignments (where we can still form + // an SSA value by creating a new temporary with just + // one element/field different). This kind of approach + // would be best if it is combined with scalarization, + // so that we don't need to construct aggregate temps. + // + + for (auto u = var->firstUse; u; u = u->nextUse) + { + auto user = u->getUser(); + switch (user->op) + { + default: + // If the variable gets used by any operation + // we can't account for directly, then it isn't + // promotable. + return false; + + case kIROp_Load: + { + // A load has only a single argument, so + // it had better be our pointer. + SLANG_ASSERT(u == &((IRLoad*) user)->ptr); + } + break; + + case kIROp_Store: + { + auto storeInst = (IRStore*)user; + + // We don't want to promote a variable if + // its address gets stored into another + // variable, so check for that case. + if (u == &storeInst->val) + return false; + + // Otherwise our variable is being used + // as the destination for the store, and + // that is okay by us. + SLANG_ASSERT(u == &storeInst->ptr); + } + break; + + case kIROp_getElementPtr: + case kIROp_FieldAddress: + { + // Sanity check: the address being used should + // be the base-address operand, and not the field + // key or index (this should never be a problem). + if (u != &user->getOperands()[0]) + return false; + + if (!allUsesLeadToLoads(user)) + return false; + } + break; + } + } + + // If all of the uses passed our checking, then + // we are good to go. + return true; +} + +// Identify local variables that can be promoted to SSA form +void identifyPromotableVars( + ConstructSSAContext* context) +{ + for (auto bb = context->globalVal->getFirstBlock(); bb; bb = bb->getNextBlock()) + { + for (auto ii = bb->getFirstInst(); ii; ii = ii->getNextInst()) + { + if (ii->op != kIROp_Var) + continue; + + IRVar* var = (IRVar*)ii; + + if (isPromotableVar(context, var)) + { + context->promotableVars.add(var); + } + } + } +} + +/// If `value` is a promotable variable, then cast and return it. +IRVar* asPromotableVar( + ConstructSSAContext* context, + IRInst* value) +{ + if (value->op != kIROp_Var) + return nullptr; + + IRVar* var = (IRVar*)value; + if (!context->promotableVars.contains(var)) + return nullptr; + + return var; +} + +/// If `value` is a promotable variable or an access chain +/// based on one, then cast and return the variable. +IRVar* asPromotableVarAccessChain( + ConstructSSAContext* context, + IRInst* value) +{ + switch (value->op) + { + case kIROp_Var: + return asPromotableVar(context, value); + + case kIROp_FieldAddress: + case kIROp_getElementPtr: + return asPromotableVarAccessChain(context, value->getOperand(0)); + + default: + return nullptr; + } +} + +/// After looking up the SSA value of avariable in some context, +/// apply whatever "access chain" was applied at the original use site. +/// +/// E.g., if the original operation was *((&a)->b) or *((&a) + i) and we've +/// resolved that the value of the variable `a` should be `v`, then +/// construct v.b or v[i]. +/// +IRInst* applyAccessChain( + ConstructSSAContext* context, + IRBuilder* builder, + IRInst* accessChain, + IRInst* leafVarValue) +{ + switch (accessChain->op) + { + default: + SLANG_UNEXPECTED("unexpected op along access chain"); + UNREACHABLE_RETURN(leafVarValue); + + case kIROp_Var: + return leafVarValue; + + case kIROp_FieldAddress: + { + SLANG_ASSERT(context->instsToRemove.contains(accessChain)); + + auto baseChain = accessChain->getOperand(0); + auto fieldKey = accessChain->getOperand(1); + auto type = cast(accessChain->getDataType())->getValueType(); + auto baseValue = applyAccessChain(context, builder, baseChain, leafVarValue); + return builder->emitFieldExtract( + type, + baseValue, + fieldKey); + } + + case kIROp_getElementPtr: + { + SLANG_ASSERT(context->instsToRemove.contains(accessChain)); + + auto baseChain = accessChain->getOperand(0); + auto index = accessChain->getOperand(1); + auto type = cast(accessChain->getDataType())->getValueType(); + auto baseValue = applyAccessChain(context, builder, baseChain, leafVarValue); + return builder->emitElementExtract( + type, + baseValue, + index); + } + } +} + +// Try to read the value of an SSA variable +// in the context of the given block. If +// the variable is defined in the block, then +// that value will be used. If not, this all +// may recursively work its way up through +// the predecessors of the block. +IRInst* readVar( + ConstructSSAContext* context, + SSABlockInfo* blockInfo, + IRVar* var); + + /// Try to copy any relevant decorations from `var` over to `val`. + /// +static void cloneRelevantDecorations( + IRVar* var, + IRInst* val) +{ + // Copy selected decorations over from the original + // variable to the SSA variable, when doing so is + // required for semantics. + // + for( auto decoration : var->getDecorations() ) + { + switch(decoration->op) + { + default: + // Ignore most decorations. + // + // TODO: Should we include or exclude by default? + break; + + case kIROp_PreciseDecoration: + case kIROp_NameHintDecoration: + // Copy these decorations if the target doesn't already have them, + // but don't make duplicate decorations on the target. + // + if( !val->findDecorationImpl(decoration->op) ) + { + cloneDecoration(decoration, val, var->getModule()); + } + break; + } + } +} + +// Add a phi node to represent the given variable +PhiInfo* addPhi( + ConstructSSAContext* context, + SSABlockInfo* blockInfo, + IRVar* var) +{ + auto builder = &blockInfo->builder; + + auto valueType = var->getDataType()->getValueType(); + if( auto rate = var->getRate() ) + { + valueType = context->getBuilder()->getRateQualifiedType(rate, valueType); + } + IRParam* phi = builder->createParam(valueType); + cloneRelevantDecorations(var, phi); + + RefPtr phiInfo = new PhiInfo(); + context->phiInfos.Add(phi, phiInfo); + + phiInfo->phi = phi; + phiInfo->var = var; + + blockInfo->phis.add(phiInfo); + + return phiInfo; +} + +IRInst* tryRemoveTrivialPhi( + ConstructSSAContext* context, + PhiInfo* phiInfo) +{ + auto phi = phiInfo->phi; + + // We are going to check if all of the operands + // to the phi are either the same, or are equal + // to the phi itself. + + IRInst* same = nullptr; + for (auto u : phiInfo->operands) + { + auto usedVal = u.get(); + SLANG_ASSERT(usedVal); + + if (usedVal == same || usedVal == phi) + { + // Either this is a self-reference, or it refers + // to the same value we've seen already. + continue; + } + if (same != nullptr) + { + // We've found at least two distinct values + // other than the phi itself, so this phi + // indeed appears to be non-trivial. + // + // We will keep the phi around. + return phi; + } + else + { + // This value is distinct from the phi itself, + // so we need to track its value. + same = usedVal; + } + } + + if (!same) + { + // There were no operands other than the phi itself. + // This implies that the value at the use sites should + // actually be undefined. + SLANG_UNIMPLEMENTED_X("trivial phi"); + } + + // Removing this phi as trivial may make other phi nodes + // become trivial. We will recognize such candidates + // by looking for phi nodes that use this node. + List otherPhis; + for( auto u = phi->firstUse; u; u = u->nextUse ) + { + auto user = u->user; + if(!user) continue; + if(user == phi) continue; + + if( user->op == kIROp_Param ) + { + auto maybeOtherPhi = (IRParam*) user; + if( auto otherPhiInfo = context->getPhiInfo(maybeOtherPhi) ) + { + otherPhis.add(otherPhiInfo); + } + } + } + + // replace uses of the phi (including its possible uses + // of itself) with the unique non-phi value. + phi->replaceUsesWith(same); + + // Clear out the operands to the phi, since they won't + // actually get used in the program any more. + for( auto& u : phiInfo->operands ) + { + u.clear(); + } + + // We will record the value that was used to replace this + // phi, so that we can easily look it up later. + phiInfo->replacement = same; + + // Now that we've cleaned up this phi, we need to consider + // other phis that might have become trivial. + for( auto otherPhi : otherPhis ) + { + tryRemoveTrivialPhi(context, otherPhi); + } + + return same; +} + +IRInst* addPhiOperands( + ConstructSSAContext* context, + SSABlockInfo* blockInfo, + PhiInfo* phiInfo) +{ + auto var = phiInfo->var; + + auto block = blockInfo->block; + + List operandValues; + for (auto predBlock : block->getPredecessors()) + { + // Precondition: if we have multiple predecessors, then + // each must have only one successor (no critical edges). + // + SLANG_ASSERT(predBlock->getSuccessors().getCount() == 1); + + auto predInfo = *context->blockInfos.TryGetValue(predBlock); + + auto phiOperand = readVar(context, predInfo, var); + + operandValues.add(phiOperand); + } + + // The `IRUse` type needs to stay at a stable location + // since they get threaded into lists. We allocate the + // list with its final size so that we can preserve the + // required invariant. + + UInt operandCount = operandValues.getCount(); + phiInfo->operands.setCount(operandCount); + for(UInt ii = 0; ii < operandCount; ++ii) + { + phiInfo->operands[ii].init(phiInfo->phi, operandValues[ii]); + } + + return tryRemoveTrivialPhi(context, phiInfo); +} + +void writeVar( + ConstructSSAContext* /*context*/, + SSABlockInfo* blockInfo, + IRVar* var, + IRInst* val) +{ + blockInfo->valueForVar[var] = val; +} + +void maybeSealBlock( + ConstructSSAContext* context, + SSABlockInfo* blockInfo) +{ + // We can't seal a block that has already been sealed. + if (blockInfo->isSealed) + return; + + // We can't seal a block until all of its predecessors + // have been filled. + for (auto pp : blockInfo->block->getPredecessors()) + { + auto predInfo = *context->blockInfos.TryGetValue(pp); + if (!predInfo->isFilled) + return; + } + + // All the checks passed, so it seems like we can be sealed. + + // We will loop over any incomplete phis that have been recoreded + // for this block, and complete them here. + // + // Note that we are doing the "inefficient" loop where we compute + // the count on each iteration to account for the possibility that + // new incomplete phis will get added while we are working. + for (Index ii = 0; ii < blockInfo->phis.getCount(); ++ii) + { + auto incompletePhi = blockInfo->phis[ii]; + addPhiOperands(context, blockInfo, incompletePhi); + } + + // After we've completed all our incomplete phis, we can mark this + // block as sealed and move along. + blockInfo->isSealed = true; +} + +// In some cases we may have a pointer to an IR value that +// represents a phi node that has been replaced with another +// IR value, because we discovered that the phi is no longer +// needed. +// +// The `maybeGetPhiReplacement` function will follow any +// chain of replacements that might be present, so that we +// don't end up referencing a dangling/unused value in +// the code that we generate. +// +IRInst* maybeGetPhiReplacement( + ConstructSSAContext* context, + IRInst* inVal) +{ + IRInst* val = inVal; + + while( val->op == kIROp_Param ) + { + // The value is a parameter, but is it a phi? + IRParam* maybePhi = (IRParam*) val; + RefPtr phiInfo = nullptr; + if(!context->phiInfos.TryGetValue(maybePhi, phiInfo)) + break; + + // Okay, this is indeed a phi we are adding, but + // is it one that got replaced? + if(!phiInfo->replacement) + break; + + // The phi we want to use got replaced, so we + // had better use the replacement instead. + val = phiInfo->replacement; + } + + return val; +} + +IRInst* readVarRec( + ConstructSSAContext* context, + SSABlockInfo* blockInfo, + IRVar* var) +{ + IRInst* val = nullptr; + if (!blockInfo->isSealed) + { + // If block isn't sealed, we need to + // speculatively add a phi to it. + // This phi may get removed later, once + // we are able to seal this block. + + PhiInfo* phiInfo = addPhi(context, blockInfo, var); + val = phiInfo->phi; + } + else + { + // If the block is sealed, then we are free to look at + // it predecessor list, and use that to decide what to do. + auto predecessors = blockInfo->block->getPredecessors(); + + // + IRBlock* firstPred = nullptr; + bool multiplePreds = false; + for (auto pp : predecessors) + { + if (!firstPred) + { + // A candidate for the sole predecessor + firstPred = pp; + } + else if (pp == firstPred) + { + // Same as existing predecessor + } + else + { + // Multiple unique predecessors + multiplePreds = true; + } + } + + if (!firstPred) + { + // The block had *no* predecssors. This will commonly + // happen for the entry block, but could also conceivably + // happen for a block that is somehow disconnected + // from the CFG and thus unreachable. + + // We would only reach this function (`readVarRec`) if + // a local lookup in the block had already failed, so + // at this point we are dealing with an undefined value. + + auto type = var->getDataType()->getValueType(); + val = blockInfo->builder.emitUndefined(type); + } + else if (!multiplePreds) + { + // There is only a single predecessor for this block, + // so there is no need to insert a phi. Instead, we + // just perform the lookup step recursively in + // the predecessor. + auto predInfo = *context->blockInfos.TryGetValue(firstPred); + val = readVar(context, predInfo, var); + } + else + { + // The default/fallback case requires us to create + // a phi node in the current block, and then look + // up the appropriate operands in the predecessor + // blocks, which will eventually become the operands + // that drive the phi. + + // Create the phi node for the given variable + PhiInfo* phiInfo = addPhi(context, blockInfo, var); + + // Mark the phi as the value for the variable inside + // this block + writeVar(context, blockInfo, var, phiInfo->phi); + + // Now add operands to the phi and maybe simplify + // it, based on what gets found. + + val = addPhiOperands(context, blockInfo, phiInfo); + } + } + + // Whatever value we find, we need to mark it as the + // value for the given variable in this block + writeVar(context, blockInfo, var, val); + + // If `val` represents a phi node (block parameter) then + // it is possible that some of the operations above might + // have caused it to be replaced with another value, + // and in that case we had better not return it to + // be referenced in user code. + // + // Note: it is okay for the `valueForVar` map that + // we update in `writeVar` to use the old value, so long + // as we do this replacement logic anywhere we might read + // from that map. + // + val = maybeGetPhiReplacement(context, val); + + return val; +} + + + +IRInst* readVar( + ConstructSSAContext* context, + SSABlockInfo* blockInfo, + IRVar* var) +{ + // In the easy case, there will be a preceeding + // store in the same block, so we can use + // that local value. + IRInst* val = nullptr; + if (blockInfo->valueForVar.TryGetValue(var, val)) + { + // Hooray, we found a value to use, and we + // can proceed without too many complications. + + // Just like in the `readVarRec` case above, we need + // to handle the case where `val` might represent + // a phi node that has subsequently been replaced. + // + val = maybeGetPhiReplacement(context, val); + + return val; + } + + // Otherwise we need to try to non-trivial/recursive + // case of lookup. + return readVarRec(context, blockInfo, var); +} + +void processBlock( + ConstructSSAContext* context, + IRBlock* block, + SSABlockInfo* blockInfo) +{ + // Before starting, check if this block can be sealed + maybeSealBlock(context, blockInfo); + + // Walk the instructions in the block, and either + // leave them as-is, or replace them with a value + // that we look up with local/global value numbering + + IRInst* next = nullptr; + for (auto ii = block->getFirstInst(); ii; ii = next) + { + next = ii->getNextInst(); + + // Any new instructions we create to represent + // the new value will get inserted before whatever + // instruction we are working with. + blockInfo->builder.setInsertBefore(ii); + + switch (ii->op) + { + default: + // Ordinary instruction -> leave as-is + break; + + case kIROp_Store: + { + auto storeInst = (IRStore*)ii; + auto ptrArg = storeInst->ptr.get(); + auto valArg = storeInst->val.get(); + + if (auto var = asPromotableVar(context, ptrArg)) + { + // We are storing to a promotable variable, + // so we want to register the value being + // stored as the value for the given SSA + // variable. + writeVar(context, blockInfo, var, valArg); + + // Also eliminate the store instruction, + // since it is no longer needed. + storeInst->removeAndDeallocate(); + } + } + break; + + case kIROp_Load: + { + IRLoad* loadInst = (IRLoad*)ii; + auto ptrArg = loadInst->ptr.get(); + + if (auto var = asPromotableVarAccessChain(context, ptrArg)) + { + // We are loading from a promotable variable. + // Look up the value in the context of this + // block. + auto val = readVar(context, blockInfo, var); + + cloneRelevantDecorations(var, val); + + val = applyAccessChain(context, &blockInfo->builder, ptrArg, val); + + // We can just replace all uses of this + // load instruction with the given value. + loadInst->replaceUsesWith(val); + + // Also eliminate the load instruction, + // since it is no longer needed. + loadInst->removeAndDeallocate(); + } + } + break; + + case kIROp_getElementPtr: + case kIROp_FieldAddress: + { + auto ptrArg = ii->getOperand(0); + if (auto var = asPromotableVarAccessChain(context, ptrArg)) + { + context->instsToRemove.add(ii); + } + } + break; + + + } + } + + auto terminator = block->getTerminator(); + SLANG_ASSERT(terminator); + blockInfo->builder.setInsertBefore(terminator); + + // Once we are done with all of the instructions + // in a block, we can mark it as "filled," which + // means we can actually consider lookups into + // it. + blockInfo->isFilled = true; + + // Having filled this block might allow us to seal some + // of its successor(s) + for (auto ss : block->getSuccessors()) + { + auto successorInfo = *context->blockInfos.TryGetValue(ss); + maybeSealBlock(context, successorInfo); + } +} + +static void breakCriticalEdges( + ConstructSSAContext* context) +{ + // A critical edge is an edge P -> S where + // P has multiple sucessors, and S has multiple + // predecessors. + // + // In the context of our CFG representation, such an edge + // will be an `IRUse` in the terminator instruction of block P, + // which refers to block S. + // + // We will make a pass over the CFG to collect all the critical + // edges, and then we will break them in a follow-up pass. + + List criticalEdges; + + auto globalVal = context->globalVal; + for (auto pred = globalVal->getFirstBlock(); pred; pred = pred->getNextBlock()) + { + auto successors = pred->getSuccessors(); + if (successors.getCount() <= 1) + continue; + + auto succIter = successors.begin(); + auto succEnd = successors.end(); + + for (; succIter != succEnd; ++succIter) + { + auto succ = *succIter; + + // For the edge to be critical, the successor must have + // more than one predecessor. + // More than that, we require that it has more than one + // *unique* predecessor, to handle the case where multiple + // cases of a `switch` might lead to the same block. + // + // To implement this, we test if it has any predecessor + // other than `pred` which we already know about. + + bool multiplePreds = false; + for (auto pp : succ->getPredecessors()) + { + if (pp != pred) + { + multiplePreds = true; + break; + } + } + if (!multiplePreds) + continue; + + // We have found a critical edge from `pred` to `succ`. + // + // Furthermore, the `IRUse` embedded in `succIter` represents + // that edge directly. + auto edgeUse = succIter.use; + criticalEdges.add(edgeUse); + } + } + + // Now we will iterate over the critical edges and break each + // one by inserting a new block. Note that we do not try + // to break the edges while doing the initial walk, because + // that would change the CFG while we are walking it. + + for (auto edgeUse : criticalEdges) + { + auto pred = cast(edgeUse->getUser()->parent); + auto succ = cast(edgeUse->get()); + + IRBuilder builder; + builder.sharedBuilder = &context->sharedBuilder; + builder.setInsertInto(pred); + + // Create a new block that will sit "along" the edge + IRBlock* edgeBlock = builder.createBlock(); + + edgeUse->debugValidate(); + + // The predecessor block should now branch to + // the edge block. + edgeUse->set(edgeBlock); + + // The edge block should branch (unconditionally) + // to the successor block. + builder.setInsertInto(edgeBlock); + builder.emitBranch(succ); + + // Insert the new block into the block list + // for the function. + // + // In principle, the order of this list shouldn't + // affect the semantics of a program, but we + // might want to be careful about ordering anyway. + edgeBlock->insertAfter(pred); + } +} + +// Construct SSA form for a global value with code +void constructSSA(ConstructSSAContext* context) +{ + // First, detect and and break any critical edges in the CFG, + // because our representation of SSA form doesn't allow for them. + breakCriticalEdges(context); + + // Figure out what variables we can promote to + // SSA temporaries. + identifyPromotableVars(context); + + // If none of the variables are promote-able, + // then we can exit without making any changes + if (context->promotableVars.getCount() == 0) + return; + + // We are going to walk the blocks in order, + // and try to process each, by replacing loads + // and stores of promotable variables with simple values. + + auto globalVal = context->globalVal; + for(auto bb : globalVal->getBlocks()) + { + auto blockInfo = new SSABlockInfo(); + blockInfo->block = bb; + + blockInfo->builder.sharedBuilder = &context->sharedBuilder; + blockInfo->builder.setInsertBefore(bb->getLastInst()); + + context->blockInfos.Add(bb, blockInfo); + } + for(auto bb : globalVal->getBlocks()) + { + auto blockInfo = * context->blockInfos.TryGetValue(bb); + processBlock(context, bb, blockInfo); + } + + // We need to transfer the logical arguments to our phi nodes + // from the phi nodes back to the predecessor blocks that will + // pass them in. + for(auto bb : globalVal->getBlocks()) + { + auto blockInfo = *context->blockInfos.TryGetValue(bb); + + for (auto phiInfo : blockInfo->phis) + { + // If we replaced this phi with another value, + // then we had better not include it in the result. + if (phiInfo->replacement) + continue; + + // We should add the phi as an explicit parameter of + // the given block. + bb->addParam(phiInfo->phi); + + UInt predCounter = 0; + for (auto pp : bb->getPredecessors()) + { + UInt predIndex = predCounter++; + auto predInfo = *context->blockInfos.TryGetValue(pp); + + IRInst* operandVal = phiInfo->operands[predIndex].get(); + + phiInfo->operands[predIndex].clear(); + + predInfo->successorArgs.add(operandVal); + } + } + } + + // Some blocks may now need to pass along arguments to their sucessor, + // which have been stored into the `SSABlockInfo::successorArgs` field. + for(auto bb : globalVal->getBlocks()) + { + auto blockInfo = * context->blockInfos.TryGetValue(bb); + + // Sanity check: all blocks should be filled and sealed. + SLANG_ASSERT(blockInfo->isSealed); + SLANG_ASSERT(blockInfo->isFilled); + + // Don't do any work for blocks that don't need to pass along + // values to the sucessor block. + auto addedArgCount = blockInfo->successorArgs.getCount(); + if (addedArgCount == 0) + continue; + + // We need to replace the terminator instruction with one that + // has additional arguments. + + IRTerminatorInst* oldTerminator = bb->getTerminator(); + SLANG_ASSERT(oldTerminator); + + blockInfo->builder.setInsertInto(bb); + + auto oldArgCount = oldTerminator->getOperandCount(); + auto newArgCount = oldArgCount + addedArgCount; + + List newArgs; + for (UInt aa = 0; aa < oldArgCount; ++aa) + { + newArgs.add(oldTerminator->getOperand(aa)); + } + for (Index aa = 0; aa < addedArgCount; ++aa) + { + newArgs.add(blockInfo->successorArgs[aa]); + } + + IRTerminatorInst* newTerminator = (IRTerminatorInst*)blockInfo->builder.emitIntrinsicInst( + oldTerminator->getFullType(), + oldTerminator->op, + newArgCount, + newArgs.getBuffer()); + + // Transfer decorations (a terminator should have no children) over to the new instruction. + // + oldTerminator->transferDecorationsTo(newTerminator); + + // A terminator better not have uses, so we shouldn't have + // to replace them. + SLANG_ASSERT(!oldTerminator->firstUse); + + + // Okay, we should be clear to remove the old terminator + oldTerminator->removeAndDeallocate(); + } + + // Remove all the instructions we marked for deletion along + // the way. + // + // Currently these are "access chain" instructions for + // loads from (parts of) variables that got promoted. + for (auto inst : context->instsToRemove) + { + // TODO: do we need to be careful here in case one + // of thes operations still has uses, as part of + // another to-be-remvoed instruction? + + inst->removeAndDeallocate(); + } + + // Now we should be able to go through and remove + // of of the variables + for (auto var : context->promotableVars) + { + var->removeAndDeallocate(); + } +} + +// Construct SSA form for a global value with code +void constructSSA(IRModule* module, IRGlobalValueWithCode* globalVal) +{ + ConstructSSAContext context; + context.globalVal = globalVal; + + context.sharedBuilder.module = module; + context.sharedBuilder.session = module->session; + + context.builder.sharedBuilder = &context.sharedBuilder; + context.builder.setInsertInto(module->moduleInst); + + constructSSA(&context); +} + +void constructSSA(IRModule* module, IRInst* globalVal) +{ + switch (globalVal->op) + { + case kIROp_Func: + case kIROp_GlobalVar: + case kIROp_GlobalConstant: + constructSSA(module, (IRGlobalValueWithCode*)globalVal); + + default: + break; + } +} + +void constructSSA(IRModule* module) +{ + for(auto ii : module->getGlobalInsts()) + { + constructSSA(module, ii); + } +} + +} diff --git a/source/slang/slang-ir-ssa.h b/source/slang/slang-ir-ssa.h new file mode 100644 index 000000000..635810c08 --- /dev/null +++ b/source/slang/slang-ir-ssa.h @@ -0,0 +1,9 @@ +// slang-ir-ssa.h +#pragma once + +namespace Slang +{ + struct IRModule; + + void constructSSA(IRModule* module); +} diff --git a/source/slang/slang-ir-union.cpp b/source/slang/slang-ir-union.cpp new file mode 100644 index 000000000..e39fae262 --- /dev/null +++ b/source/slang/slang-ir-union.cpp @@ -0,0 +1,776 @@ +// slang-ir-union.cpp +#include "slang-ir-union.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" + +namespace Slang { + +// This file will implement a pass to replace any union types (currently +// just tagged unions) with plain `struct` types that attempt to provide +// equivalent semantics. This will necessarily be a bit fragile, and there +// will be fundamental limits to what the translation can support without +// improved features in the target shading languages/ILs. + +struct DesugarUnionTypesContext +{ + // We'll start with some basic state that we need to get the job done. + // + // This includes the IR module we are to process, as well as IR building + // state that we will initialize once and then use throughout the pass. + // + IRModule* module; + SharedIRBuilder sharedBuilderStorage; + IRBuilder builderStorage; + IRBuilder* getBuilder() { return &builderStorage; } + + // Because we will be replacing instructions that refer to unions with + // different logic, we'll want to remove the original instructions. + // However, we need to be careful about modifying the IR tree while also + // iterating it, and to keep things simple for ourselves we'll go ahead + // and build up a list of instruction to remove along the way, and then + // remove them all at the end. + // + List instsToRemove; + + // The overall flow of the pass is pretty simple, so we will walk through it now. + // + void processModule() + { + // We start by initializing our IR building state. + // + sharedBuilderStorage.session = module->session; + sharedBuilderStorage.module = module; + builderStorage.sharedBuilder = &sharedBuilderStorage; + + // Next, we will search for any instruction that create or use + // union types, and process them accordingingly (usually by + // constructing a new instruction to replace them). + // + processInstRec(module->getModuleInst()); + + // Along the way we will build up a list of the tagged union + // types that we encountered, but we will refrain from replacing + // them until we are done (so that we always know that the instructions + // we process above refer to the original type, and not its + // replacement. + // + for( auto info : taggedUnionInfos ) + { + auto taggedUnionType = info->taggedUnionType; + auto replacementInst = info->replacementInst; + + // TODO: We should consider transferring decorations from the source + // type to the destination, but doing so carelessly could create + // problems, since an IR struct type shouldn't have, e.g., a + // `TaggedUnionTypeLayout` attached to it. + + taggedUnionType->replaceUsesWith(replacementInst); + taggedUnionType->removeAndDeallocate(); + } + + // As described previously, we build up the `instsToRemove` list as + // we iterate so that we can remove them all here and not risk + // modifying the IR tree while also walking it. + // + // TODO: This might be overkill and we could conceivably just be + // a bit careful in `processInstRec`. + // + for(auto inst : instsToRemove) + { + inst->removeAndDeallocate(); + } + } + + // In order to replace a (tagged) union type, we will need to know + // something about it, and we will use the `TaggedUnionInfo` type + // to collect all the relevant information. + // + struct TaggedUnionInfo : public RefObject + { + // We obviously need to know the tagged union itself, and + // we will also use this structure to track the instruction + // (an IR struct type) that will replace it. + // + IRTaggedUnionType* taggedUnionType; + IRInst* replacementInst; + + // In order to compute a suitable layout for the replacement + // `struct` type we need to know how the tagged union itself + // would be laid out in memory, so we require that all tagged + // unions in the generated IR have an associated (target-specific) + // layout. + // + TaggedUnionTypeLayout* taggedUnionTypeLayout; + + // The basic approach we will use 16-byte chunks (represented as an array + // of `uint4`s) to reprent the "bulk" of a type, and then use a single field + // that could be up to 12 bytes to represent the "rest" of the type. + // + // Note that there are deeply ingrained assumptions here that all types + // are at least four bytes in size (so that unions cannot easily + // accomodate `half` value), and that any types *larger* than four bytes + // will need to be loaded/stored via multiple 4-byte loads/stores. + // + // With the basic idea out of the way, we need an IR level field + // in our struct to hold the bulk data, which comprises a "key" for + // looking up the field, and the type of the field itself. We also + // keep track of how many bytes we put in our bulk storage. + // + // The bulk field might be: + // + // - null, if none of the case types was 16 bytes or more + // - a single `uint4` for between 16 and 31 (inclusive) bytes + // - an array of `uint4`s for 32 or more bytes + // + UInt64 bulkSize = 0; + IRInst* bulkFieldKey = nullptr; + IRType* bulkFieldType = nullptr; + + // The same basic idea then applies to the rest of the data. + // + // The "rest" field will be either be absent (if the size of the + // type was evently divisible by 16), a scalar `uint`, or else + // a 2- or 3-component vector of `uint`. + // + UInt64 restSize = 0; + IRInst* restFieldKey = nullptr; + IRType* restFieldType = nullptr; + + // Finally, since we are currently working with tagged unions, + // we need a field to hold the tag, which will always be allocated + // after the fields that hold the bulk/rest of the payload. + // + // This field is always a single `uint`. + // + // TODO: if/when we support untagged unions, they could be handled + // by having this field be null. + // + IRInst* tagFieldKey; + }; + + // We will build up a list of all the tagged union types we encounter, + // so that we can replace them with the synthesized types when we are done. + // + List> taggedUnionInfos; + + // It is possible that we will see the same tagged union type referenced + // many times in the IR, but we only want to synthesize the information + // above (including the various IR structures) once, so we also maintain + // a map from the original IR type to the corresponding information. + // + Dictionary mapIRTypeToTaggedUnionInfo; + + // We will process all instructions in the module in a single recursive walk. + // + void processInstRec(IRInst* inst) + { + processInst(inst); + + for( auto child : inst->getChildren() ) + { + processInstRec(child); + } + } + // + // At each instruction, we will check if it is one of the union-related instructions + // we need to replace, and process it accordingly. + // + void processInst(IRInst* inst) + { + switch( inst->op ) + { + default: + // Any instruction not listed below either doesn't involve union types, + // or handles them in a hands-off fashion that we don't need to care about. + // + // E.g., a `load` of a union type from a constant buffer will turn into + // a load of the replacement `struct` type once we are done, and nothing + // needs to be done to the `load` instruction. + // + break; + + case kIROp_TaggedUnionType: + { + // We clearly need to process the tagged union type itself, but the actual + // work is handled by other functions. All we need to do here is ensure + // that the information for this type gets generated, and then we can + // rely on the main `processModule` function to do the actual replacement later. + // + auto type = cast(inst); + getTaggedUnionInfo(type); + } + break; + + case kIROp_ExtractTaggedUnionTag: + { + // The case of extracting the tag from a tagged union is relatively + // simple, because the replacement type will have a dedicated field or it. + // + // We start by finding the tagged union value the instruction is operating + // on, and then looking up the information for its type (which had + // better be a tagged union type). + // + auto taggedUnionVal = inst->getOperand(0); + auto taggedUnionInfo = getTaggedUnionInfo(taggedUnionVal->getDataType()); + + // Because the replacement type will have an explicit field for the tag, + // we can simply emit a single field-extract instruction to read its value + // out. + // + auto builder = getBuilder(); + builder->setInsertBefore(inst); + auto replacement = builder->emitFieldExtract( + inst->getFullType(), + taggedUnionVal, + taggedUnionInfo->tagFieldKey); + + // Now we can replace anything that used the original instruction with + // the new field-extract operation, and add this instruction to the + // list for later removal. + // + inst->replaceUsesWith(replacement); + instsToRemove.add(inst); + } + break; + + case kIROp_ExtractTaggedUnionPayload: + { + // The most interesting case is when we are trying to extract a particular + // payload (one of the case types) from a union. We may need to extract + // one or more fields from the data stored in the union's replacement + // type (the bulk/rest fields), and we may also have to convert them + // to the type expected via bit-casts. + + // We can start things off easily enough by extracting the tagged union + // value being operated on, as well as the information for its type. + // + auto taggedUnionVal = inst->getOperand(0); + auto taggedUnionInfo = getTaggedUnionInfo(taggedUnionVal->getDataType()); + + // Next we need to figure out which case is being extracted from the union. + // The operand for the case tag should be a literal by construction. + // + auto caseTagVal = inst->getOperand(1); + auto caseTagConst = as(caseTagVal); + SLANG_ASSERT(caseTagConst); + + // The case type we are extracting will be the result type of the instruciton. + // + auto caseType = inst->getDataType(); + // + // The tag value itself will be the index of the case type in the union + // type (and its layout). + // + auto caseTagIndex = UInt(caseTagConst->getValue()); + + // We can use the case tag value to look up the layout for the particular + // case type we are extracting (this will allow us to resolve byte offsets + // for fields, etc.). + // + auto taggedUnionTypeLayout = taggedUnionInfo->taggedUnionTypeLayout; + SLANG_ASSERT(caseTagIndex < UInt(taggedUnionTypeLayout->caseTypeLayouts.getCount())); + auto caseTypeLayout = taggedUnionTypeLayout->caseTypeLayouts[caseTagIndex]; + + // At this point we know the type we are trying to extract, as well + // as its layout. We will defer the actual implementation of extraction + // to a (recursive) subroutine that can extract a (sub-)field from the + // union at a given byte offset. Since we are extracting a full case + // right now, the byte offset will be zero. + // + auto payloadVal = extractPayload( + taggedUnionInfo, + taggedUnionVal, + caseType, + caseTypeLayout, + 0); + + // TODO: There is a significant flaw in the above approach when + // the case type might be (or contain) an array. If we have a setup + // like the following: + // + // union SomeUnion { float someCase[100]; ... } + // ... + // float result = someUnion.someCase[someIndex]; + // + // The current logic would desugar this into something like: + // + // struct SomeUnion { uint4 bulk[100]; ... } + // ... + // float[] tmp = { asfloat(someUnion.bulk[0].x), asfloat(someUnion.bulk[1].x), ... } + // float result = tmp[someIndex]; + // + // The result is that we copy an entire 100-element array into local memory + // just to fetch a single element, when it would be much nicer to just do: + // + // float result = asfloat(someUnion.bulk[someIndex].x); + // + // Achieving the latter code requires that rather than blindly translate + // the `extractTaggedUnionPayload` instruction into a semantically equiavlent + // value (which might lead to a big copy in the end), we should transitively + // chase down any "access chains" off of `inst` and see what leaf values are + // actually needed, and generated more tailored extraction logic for just + // the elements/fields that actually get referenced. + // + // The more refined approach can be built on top of many of the same primitives, + // so for now we will resign ourselves to the simpler but potentially less + // efficient approach. + + // Now that we've extracted the value for the payload from the fields of + // the replacement struct, we can use that extracted value to replace + // this instruction, and schedule the original instruction for removal. + // + inst->replaceUsesWith(payloadVal); + instsToRemove.add(inst); + } + break; + } + } + + // The `extractPayload` operation is the most important bit of translation we + // need to do to make unions work. We have as input the following: + // + IRInst* extractPayload( + + // - Information about a tagged union type and its layout. + TaggedUnionInfo* taggedUnionInfo, + + // - A single value of that tagged unon type. + IRInst* taggedUnionVal, + + // - Type type of some "payload" field we want to extract from the union. + IRType* payloadType, + + // - The memory layout of that payload type. + TypeLayout* payloadTypeLayout, + + // - The byte offset at which we want to fetch the payload. + UInt64 payloadOffset) + { + // We are going to be building some IR code no matter what. + // + auto builder = getBuilder(); + + // The basic approach here will be to look at the type we + // are trying to extract from the union, and whenever possible + // recursively walk its structure so that we can express things + // in terms of extraction of smaller/simpler types. + // + if( auto irStructType = as(payloadType) ) + { + // A structure type is a nice recursive case: we simply + // want to extract each of its field recursively, and + // then construct a fresh value of the `struct` type. + + // In all of the cases of this function we expect/require + // there to be complete type layout information for the + // types involved. + // + auto structTypeLayout = as(payloadTypeLayout); + SLANG_ASSERT(structTypeLayout); + + // We are going to emit code to extract each of the fields + // and collect them to use as operands to a `makeStruct`. + // + List fieldVals; + + // We need to walk over the fields in the order the IR expects them + UInt fieldCounter = 0; + for( auto irField : irStructType->getFields() ) + { + IRType* fieldType = irField->getFieldType(); + + // TODO: We need to confirm/enforce that the fields of the + // IR struct and the fields of the layout still align. + // + UInt fieldIndex = fieldCounter++; + auto fieldLayout = structTypeLayout->fields[fieldIndex]; + auto fieldTypeLayout = fieldLayout->getTypeLayout(); + + // The offset of the field can be computed from the base + // offset passed in, plus the reflection data for the field. + // + UInt64 fieldOffset = payloadOffset; + if(auto resInfo = fieldLayout->FindResourceInfo(LayoutResourceKind::Uniform)) + fieldOffset += resInfo->index; + + // We make a recursive call to extract each field, expecting + // that this will bottom out eventually. + // + IRInst* fieldVal = extractPayload( + taggedUnionInfo, + taggedUnionVal, + fieldType, + fieldTypeLayout, + fieldOffset); + fieldVals.add(fieldVal); + } + + // The final value is then just a new struct constructed from + // the extracted field values. + // + auto payloadVal = builder->emitMakeStruct(irStructType, fieldVals); + return payloadVal; + } + else if( auto vecType = as(payloadType) ) + { + auto elementType = vecType->getElementType(); + + // We expect that by the time we are desugaring union types + // all vector types have literal constant values for their + // element count. + // + auto elementCountVal = vecType->getElementCount(); + auto elementCountConst = as(elementCountVal); + SLANG_ASSERT(elementCountConst); + UInt elementCount = UInt(elementCountConst->getValue()); + + // HACK: There is currently no `VectorTypeLayout` and thus + // no way to query the layout of the elements of a vector + // type. Until that gets added we will kludge things here. + // + TypeLayout* elementTypeLayout = nullptr; + size_t elementSize = 0; + if(auto resInfo = payloadTypeLayout->FindResourceInfo(LayoutResourceKind::Uniform)) + elementSize = resInfo->count.getFiniteValue() / elementCount; + + // Similar to the `struct` case above, we will extract a + // value for each element of the vector, and then use + // `makeVector` to construct the result value. + // + List elementVals; + for(UInt ii = 0; ii < elementCount; ++ii) + { + auto elementVal = extractPayload( + taggedUnionInfo, + taggedUnionVal, + elementType, + elementTypeLayout, + payloadOffset + ii*elementSize); + elementVals.add(elementVal); + } + return builder->emitMakeVector(vecType, elementVals); + } + else if( auto matType = as(payloadType) ) + { + SLANG_UNIMPLEMENTED_X("matrix in union type"); + } + else if( auto arrayType = as(payloadType) ) + { + SLANG_UNIMPLEMENTED_X("array in union type"); + } + else + { + // If none of the above cases match, then we assume that + // we have an individual scalar field that we need to fetch. + // + UInt64 payloadSize = 0; + if( auto resInfo = payloadTypeLayout->FindResourceInfo(LayoutResourceKind::Uniform) ) + { + // TODO: somebody before this point should generate an error if + // we have a `union` type that contains a potentially unbounded + // amount of data. + // + payloadSize = resInfo->count.getFiniteValue(); + } + + if( payloadSize != 4 ) + { + // TODO: We should handle the case of 64-bit fields by fetching + // two `uint` values to form a `uint2`, and then using an + // appropriate bit-cast to get from `uint2` to, e.g., `double`. + // + // The case of 16-bit and smaller fields is more troublesome, but + // in the worst case we can load a `uint` and then use bitwise + // ops to extract what we need before bitcasting. + // + // The right long-term solution is for downstream languages to have + // better support for raw memory addressing. + + SLANG_UNIMPLEMENTED_X("leaf union field with size other than 4 bytes"); + } + + // We know that we want to fetch a value of size `payloadSize`, and + // we have a known base value and an initial offset into it. + // + IRInst* baseVal = taggedUnionVal; + UInt64 offset = payloadOffset; + + // We are going to refine our `baseVal` and `offset` as we go, by + // trying to narrow down the data we will access in the `struct` + // type that will provide storage for the union. + // + // The first thing we want to check is if the value sits in the + // "bulk" part of the storage, or the "rest." + // + UInt64 bulkSize = taggedUnionInfo->bulkSize; + if( offset < bulkSize ) + { + // If the value starts in the bulk area, then the whole + // thing had better fit in the bulk area. The 16-byte + // granularity rules for constant buffers should ensure + // this property for us on current targets. + // + SLANG_ASSERT(offset + payloadSize <= bulkSize); + + // Since we know we'll be accessing the bulk storage, + // we will extract it here. The extracted field will + // be our new base value, but the `offset` doesn't need + // to be updated since the bulk field sits at offset 0. + // + baseVal = builder->emitFieldExtract( + taggedUnionInfo->bulkFieldType, + baseVal, + taggedUnionInfo->bulkFieldKey); + + // The bulk storage could be an array, if there are 32 + // or more bytes of bulk storage. + // + if( auto baseArrayType = as(baseVal->getDataType()) ) + { + // If an array was allocated for bulk storage then + // our leaf value resides entirely within a single + // element (due to constant buffer layout rules), + // and so we will fetch the appropriate element here. + // + // We will change our `baseVal` to the extracted element, + // and then also adjust our `offset` to be relative + // to that element. + // + size_t bulkElementSize = 16; + auto index = offset / bulkElementSize; + baseVal = builder->emitElementExtract( + baseArrayType->getElementType(), + baseVal, + builder->getIntValue(builder->getIntType(), index)); + offset -= index*bulkElementSize; + } + } + else + { + // If the offset of the field we want is past the end of + // the bulk field then it must sit inside of the rest field, + // and we'll extract it here. This establishes a new + // base value, and we adjust the `offset` to be relative + // to the rest field (which starts at an offset equal to `bulkSize`). + // + baseVal = builder->emitFieldExtract( + taggedUnionInfo->restFieldType, + baseVal, + taggedUnionInfo->restFieldKey); + offset -= bulkSize; + } + + // We've now extracted a field that could be either a scalar or + // a vector, and we have an offset into it. In the case where + // the base value is a vector, we will extract out the appropriate + // element. + // + if( auto baseVecType = as(baseVal->getDataType()) ) + { + size_t vecElementSize = 4; + auto index = offset / vecElementSize; + baseVal = builder->emitElementExtract( + baseVecType->getElementType(), + baseVal, + builder->getIntValue(builder->getIntType(), index)); + offset -= index*vecElementSize; + } + + // At this point, our `baseVal` should be a single `uint`, and + // it should provide the storage for the exact thing we wanted + // to access (under the assumption that we always fetch 4 bytes + // on 4-byte alignment). + // + IRInst* payloadVal = baseVal; + SLANG_ASSERT(offset == 0); + + // TODO: we could imagine adding logic here to handle types less + // than 4 bytes in size by shifting and masking the value we + // just loaded. + + // The payload field we were trying to extract might have a type + // other than `uint`, and to handle that case we need to employ + // a bit-cast to get to the desired type. + // + if( payloadVal->getDataType() != payloadType ) + { + payloadVal = builder->emitBitCast( + payloadType, + payloadVal); + } + return payloadVal; + } + } + + // All of the logic so far as assumed we can just call `getTaggedUnionInfo` + // and have easy access to all the required information and the + // synthesized replacement type. + // + TaggedUnionInfo* getTaggedUnionInfo(IRType* type) + { + // The big picture is fairly simple: we will lazily build and + // memoize the information about tagged unions. + // + { + TaggedUnionInfo* info = nullptr; + if(mapIRTypeToTaggedUnionInfo.TryGetValue(type, info)) + return info; + } + + // When we don't find information in our memo-cache, we + // will construct it and add it to both the memo-cache + // *and* a global list of all tagged unions encountered, + // so that we can replacement them later. + // + auto info = createTaggedUnionInfo(type); + mapIRTypeToTaggedUnionInfo.Add(type, info.Ptr()); + taggedUnionInfos.add(info); + + return info; + } + + // The actual logic for creating a `TaggedUnionInfo` is relatively + // straightforward once we've decided what information we need. + // + RefPtr createTaggedUnionInfo(IRType* type) + { + // We expect that any type used as an operation to one of the + // `extractTaggedUnion*` operations must be an IR tagged union. + // + // Note: If/when we ever expose `union`s to user and allow + // then to create *generic* tagged union types it might appear + // that this needs to be changed to account for a `specialize` + // instruction in place of a concrete tagged union, but in + // practice this pass needs to be performed late enough that + // any such generic should be fully specialized. + // + auto taggedUnionType = as(type); + SLANG_ASSERT(taggedUnionType); + + RefPtr info = new TaggedUnionInfo(); + info->taggedUnionType = taggedUnionType; + + // We are going to create an instruction to replace `type`, + // and thus will be placing it into the same parent. + // + auto builder = getBuilder(); + builder->setInsertBefore(type); + + // A tagged union type will be replaced with an ordinary + // `struct` type with fields to store all the relevant + // data from any of the cases, plus a tag field. + // + auto structType = builder->createStructType(); + info->replacementInst = structType; + + // We require/expect the earlier code generation steps to have + // associated a layout with every tagged union that appears in + // the code. + // + auto layoutDecoration = type->findDecoration(); + SLANG_ASSERT(layoutDecoration); + auto layout = layoutDecoration->getLayout(); + SLANG_ASSERT(layout); + auto taggedUnionTypeLayout = as(layout); + SLANG_ASSERT(taggedUnionTypeLayout); + + info->taggedUnionTypeLayout = taggedUnionTypeLayout; + + // The size of the "payload" for the different cases (everything but + // the tag) is taken to be the offset of the tag itself. + // + // TODO: this might be inaccurate if the payload size isn't a multiple + // of the tag's alignment. We should deal with that when/if we support + // types smaller than 4 bytes in unions. + // + auto payloadSize = taggedUnionTypeLayout->tagOffset.getFiniteValue(); + + // We are going to be construction IR code that makes use of the `int` + // and `uint` types in several cases, so we go ahead and get a pointer + // to those types here. + // + auto intType = getBuilder()->getIntType(); + auto uintType = getBuilder()->getBasicType(BaseType::UInt); + + // For now we will use a simple stragegy for how we encode a union, + // which depends only on the total number of bytes needed, and not + // on the makeup of the values being stored. + // + // We will start by allocating one or more `uint4` values (in an + // array for the "or more" case) to hold the bulk of any large + // payload value. + // + size_t bulkVectorSize = 16; // Note: assuming `sizeof(uint4) == 16` on all targets + auto bulkVectorCount = payloadSize / bulkVectorSize; + auto bulkFieldSize = bulkVectorCount * bulkVectorSize; + if( bulkVectorCount ) + { + IRType* bulkFieldType = builder->getVectorType( + uintType, + builder->getIntValue(intType, 4)); + + if( bulkVectorCount > 1 ) + { + bulkFieldType = builder->getArrayType( + bulkFieldType, + builder->getIntValue(intType, bulkVectorCount)); + } + + auto bulkFieldKey = builder->createStructKey(); + builder->createStructField(structType, bulkFieldKey, bulkFieldType); + + info->bulkFieldKey = bulkFieldKey; + info->bulkFieldType = bulkFieldType; + } + info->bulkSize = bulkFieldSize; + + // The rest of the data (anything that doesn't fit in the bulk field), + // will get allocated into a single scalar or vector of `uint`. + // + auto restSize = payloadSize - bulkFieldSize; + if( restSize ) + { + size_t restElementSize = 4; // assuming `sizeof(uint) == 4` on all targets + auto restElementCount = restSize / restElementSize; + auto restFieldSize = restElementSize * restElementCount; + SLANG_ASSERT(restFieldSize == restSize); // Note: all our current targets have minimum 4-byte storage granularity + + IRType* restFieldType = uintType; + if( restElementCount > 1 ) + { + restFieldType = builder->getVectorType( + restFieldType, + builder->getIntValue(intType, restElementCount)); + } + + auto restFieldKey = builder->createStructKey(); + builder->createStructField(structType, restFieldKey, restFieldType); + + info->restFieldKey = restFieldKey; + info->restFieldType = restFieldType; + info->restSize = restFieldSize; + } + + // Finally, we add a field to represent the tag. + // + auto tagFieldType = uintType; + auto tagFieldKey = builder->createStructKey(); + builder->createStructField(structType, tagFieldKey, tagFieldType); + + info->tagFieldKey = tagFieldKey; + + return info; + } +}; + +void desugarUnionTypes( + IRModule* module) +{ + DesugarUnionTypesContext context; + context.module = module; + + context.processModule(); +} + +} // namespace Slang diff --git a/source/slang/slang-ir-union.h b/source/slang/slang-ir-union.h new file mode 100644 index 000000000..81757dced --- /dev/null +++ b/source/slang/slang-ir-union.h @@ -0,0 +1,18 @@ +// slang-ir-union.h +#pragma once + +namespace Slang { + +struct IRModule; + + /// Desugar any unions types, and code using them, in `module` + /// + /// Union types will be replaced with ordinary `struct` types that store + /// the data of the underlying type as a "bag of bits" and references + /// to cases of the union will be replaced with logic to extract the + /// relevant bits. + /// +void desugarUnionTypes( + IRModule* module); + +} // namespace Slang diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp new file mode 100644 index 000000000..15228200e --- /dev/null +++ b/source/slang/slang-ir-validate.cpp @@ -0,0 +1,207 @@ +// slang-ir-validate.cpp +#include "slang-ir-validate.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" + +namespace Slang +{ + struct IRValidateContext + { + // The IR module we are validating. + IRModule* module; + + // A diagnostic sink to send errors to if anything is invalid. + DiagnosticSink* sink; + + DiagnosticSink* getSink() { return sink; } + + // A set of instructions we've seen, to help confirm that + // values are defined before they are used in a given block. + HashSet seenInsts; + }; + + void validateIRInst( + IRValidateContext* context, + IRInst* inst); + + void validate(IRValidateContext* context, bool condition, IRInst* inst, char const* message) + { + if (!condition) + { + context->getSink()->diagnose(inst, Diagnostics::irValidationFailed, message); + } + } + + void validateIRInstChildren( + IRValidateContext* context, + IRInst* parent) + { + IRInst* prevChild = nullptr; + for(auto child : parent->getDecorationsAndChildren() ) + { + // We need to check the integrity of the parent/next/prev links of + // all of our instructions + validate(context, child->parent == parent, child, "parent link"); + validate(context, child->prev == prevChild, child, "next/prev link"); + + // Recursively validate the instruction itself. + validateIRInst(context, child); + + // Do some extra validation around terminator instructions: + // + // * The last instruction of a block should always be a terminator + // * No other instruction should be a terminator + // + if(as(parent) && (child == parent->getLastDecorationOrChild())) + { + validate(context, as(child) != nullptr, child, "last instruction in block must be terminator"); + } + else + { + validate(context, !as(child), child, "terminator must be last instruction in a block"); + } + + + prevChild = child; + } + } + + void validateIRInstOperand( + IRValidateContext* context, + IRInst* inst, + IRUse* operandUse) + { + // The `IRUse` for the operand had better have `inst` as its user. + validate(context, operandUse->getUser() == inst, inst, "operand user"); + + // The value we are using needs to fit into one of a few cases. + // + // * If the parent of `inst` and of `operand` is the same block, then + // we require that `operand` is defined before `inst` + // + // * If the parents of `inst` and `operand` are both blocks in the + // same functin, then the block defining `operand` must dominate + // the block defining `inst`. + // + // * Otherwise, we simply require that the parent of `operand` be + // an ancestor (transitive parent) of `inst`. + + auto instParent = inst->getParent(); + + auto operandValue = operandUse->get(); + + if( !operandValue ) + { + // A null operand should almost always be an error, but + // we currently have a few cases where this arises. + // + // TODO: plug the leaks. + return; + } + + auto operandParent = operandValue->getParent(); + + if (auto instParentBlock = as(instParent)) + { + if (auto operandParentBlock = as(operandParent)) + { + if (instParentBlock == operandParentBlock) + { + // If `operandValue` precedes `inst`, then we should + // have already seen it, because we scan parent instructions + // in order. + validate(context, context->seenInsts.Contains(operandValue), inst, "def must come before use in same block"); + return; + } + + auto instFunc = instParentBlock->getParent(); + auto operandFunc = operandParentBlock->getParent(); + if (instFunc == operandFunc) + { + // The two instructions are defined in different blocks of + // the same function (or another value with code). We need + // to validate that `operandParentBlock` dominates `instParentBlock`. + // + // TODO: implement this validation once we compute dominator trees. + // + // validate(context, operandParentBlock->dominates(instParentBlock), inst, "def must dominate use"); + return; + } + } + } + + // If the special cases above did not trigger, then either the two values + // are nested in the same parent, but that parent isn't a block, or they + // are nested in distinct parents, and those parents aren't both children + // of a function. + // + // In either case, we need to enforce that the parent of `operand` needs + // to be an ancestor of `inst`. + // + for (auto pp = instParent; pp; pp = pp->getParent()) + { + if (pp == operandParent) + return; + } + // + // We failed to find `operandParent` while walking the ancestors of `inst`, + // so something had gone wrong. + validate(context, false, inst, "def must be ancestor of use"); + } + + void validateIRInstOperands( + IRValidateContext* context, + IRInst* inst) + { + if(inst->getFullType()) + validateIRInstOperand(context, inst, &inst->typeUse); + + UInt operandCount = inst->getOperandCount(); + for (UInt ii = 0; ii < operandCount; ++ii) + { + validateIRInstOperand(context, inst, inst->getOperands() + ii); + } + } + + void validateIRInst( + IRValidateContext* context, + IRInst* inst) + { + // Validate that any operands of the instruction are used appropriately + validateIRInstOperands(context, inst); + context->seenInsts.Add(inst); + + // If `inst` is itself a parent instruction, then we need to recursively + // validate its children. + validateIRInstChildren(context, inst); + } + + void validateIRModule(IRModule* module, DiagnosticSink* sink) + { + IRValidateContext contextStorage; + IRValidateContext* context = &contextStorage; + context->module = module; + context->sink = sink; + + auto moduleInst = module->moduleInst; + + validate(context, moduleInst != nullptr, moduleInst, "module instruction"); + validate(context, moduleInst->parent == nullptr, moduleInst, "module instruction parent"); + validate(context, moduleInst->prev == nullptr, moduleInst, "module instruction prev"); + validate(context, moduleInst->next == nullptr, moduleInst, "module instruction next"); + + validateIRInst(context, module->moduleInst); + } + + void validateIRModuleIfEnabled( + CompileRequestBase* compileRequest, + IRModule* module) + { + if (!compileRequest->shouldValidateIR) + return; + + auto sink = compileRequest->getSink(); + validateIRModule(module, sink); + } +} diff --git a/source/slang/slang-ir-validate.h b/source/slang/slang-ir-validate.h new file mode 100644 index 000000000..c9b0016f4 --- /dev/null +++ b/source/slang/slang-ir-validate.h @@ -0,0 +1,35 @@ +// slang-ir-validate.h +#pragma once + +namespace Slang +{ + class CompileRequestBase; + class DiagnosticSink; + struct IRModule; + + + // Validate that an IR module obeys the invariants we need to enforce. + // For example: + // + // * Confirm that linked lists for children and for use-def chains are consistent + // (e.g., x.next.prev == x) + // + // * Confirm that parent/child relationships are correct (e.g., if is `x` is in + // `y.children`, then `x.parent == y` + // + // * Confirm that every operand of an instruction is valid to reference (i.e., it + // must either be defined earlier in the same block, in a different block that + // dominates the current one, or in a parent instruction of the block. + // + // * Confirm that every block ends with a terminator, and there are no terminators + // elsewhere in a block. + // + // * Confirm that all the parameters of a block come before any "ordinary" instructions. + void validateIRModule(IRModule* module, DiagnosticSink* sink); + + // A wrapper that calls `validateIRModule` only when IR validation is enabled + // for the given compile request. + void validateIRModuleIfEnabled( + CompileRequestBase* compileRequest, + IRModule* module); +} diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp new file mode 100644 index 000000000..4975ac824 --- /dev/null +++ b/source/slang/slang-ir.cpp @@ -0,0 +1,4511 @@ +// slang-ir.cpp +#include "slang-ir.h" +#include "slang-ir-insts.h" + +#include "../core/slang-basic.h" + +#include "slang-mangle.h" + +namespace Slang +{ + struct IRSpecContext; + + IRInst* cloneGlobalValueWithLinkage( + IRSpecContext* context, + IRInst* originalVal, + IRLinkageDecoration* originalLinkage); + + struct IROpMapEntry + { + IROp op; + IROpInfo info; + }; + + // TODO: We should ideally be speeding up the name->inst + // mapping by using a dictionary, or even by pre-computing + // a hash table to be stored as a `static const` array. + // + // NOTE! That this array is now constructed in such a way that looking up + // an entry from an op is fast, by keeping blocks of main, and pseudo ops in same order + // as the ops themselves. Care must be taken to keep this constraint. + static const IROpMapEntry kIROps[] = + { + + // Main ops in order +#define INST(ID, MNEMONIC, ARG_COUNT, FLAGS) \ + { kIROp_##ID, { #MNEMONIC, ARG_COUNT, FLAGS, } }, +#include "slang-ir-inst-defs.h" + + // Pseudo ops +#define INST(ID, MNEMONIC, ARG_COUNT, FLAGS) /* empty */ +#define PSEUDO_INST(ID) \ + { kIRPseudoOp_##ID, { #ID, 0, 0 } }, + + // First is 'invalid' + { kIROp_Invalid,{ "invalid", 0, 0 } }, + // Then all the other psuedo ops +#include "slang-ir-inst-defs.h" + + }; + + IROpInfo getIROpInfo(IROp opIn) + { + const int op = opIn & kIROpMeta_PseudoOpMask; + if ((op & kIROpMeta_IsPseudoOp) && op < kIRPseudoOp_LastPlusOne) + { + // It's a pseudo op + const int index = op - kIRPseudoOp_First; + // Pseudo ops start from kIROpcount + const auto& entry = kIROps[kIROpCount + index]; + SLANG_ASSERT(entry.op == op); + return entry.info; + } + else if (op < kIROpCount) + { + // It's a main op + const auto& entry = kIROps[op]; + SLANG_ASSERT(entry.op == op); + return entry.info; + } + + // Don't know what this is + SLANG_ASSERT(!"Invalid op"); + SLANG_ASSERT(kIROps[kIROpCount].op == kIROp_Invalid); + return kIROps[kIROpCount].info; + } + + IROp findIROp(const UnownedStringSlice& name) + { + for (auto ee : kIROps) + { + if (name == ee.info.name) + return ee.op; + } + + return IROp(kIROp_Invalid); + } + + + + // + + void IRUse::debugValidate() + { +#ifdef _DEBUG + auto uv = this->usedValue; + if(!uv) + { + assert(!nextUse); + assert(!prevLink); + return; + } + + auto pp = &uv->firstUse; + for(auto u = uv->firstUse; u;) + { + assert(u->prevLink == pp); + + pp = &u->nextUse; + u = u->nextUse; + } +#endif + } + + void IRUse::init(IRInst* u, IRInst* v) + { + clear(); + + user = u; + usedValue = v; + if(v) + { + nextUse = v->firstUse; + prevLink = &v->firstUse; + + if(nextUse) + { + nextUse->prevLink = &this->nextUse; + } + + v->firstUse = this; + } + + debugValidate(); + } + + void IRUse::set(IRInst* uv) + { + init(user, uv); + } + + void IRUse::clear() + { + // This `IRUse` is part of the linked list + // of uses for `usedValue`. + + debugValidate(); + + if (usedValue) + { + auto uv = usedValue; + + *prevLink = nextUse; + if(nextUse) + { + nextUse->prevLink = prevLink; + } + + user = nullptr; + usedValue = nullptr; + nextUse = nullptr; + prevLink = nullptr; + + if(uv->firstUse) + uv->firstUse->debugValidate(); + } + } + + // IRInstListBase + + void IRInstListBase::Iterator::operator++() + { + if (inst) + { + inst = inst->next; + } + } + + IRInstListBase::Iterator IRInstListBase::begin() { return Iterator(first); } + IRInstListBase::Iterator IRInstListBase::end() { return Iterator(last ? last->next : nullptr); } + + // + + IRUse* IRInst::getOperands() + { + // We assume that *all* instructions are laid out + // in memory such that their arguments come right + // after the first `sizeof(IRInst)` bytes. + // + // TODO: we probably need to be careful and make + // this more robust. + + return (IRUse*)(this + 1); + } + + IRDecoration* IRInst::findDecorationImpl(IROp decorationOp) + { + for(auto dd : getDecorations()) + { + if(dd->op == decorationOp) + return dd; + } + return nullptr; + } + + // IRConstant + + IRIntegerValue GetIntVal(IRInst* inst) + { + switch (inst->op) + { + default: + SLANG_UNEXPECTED("needed a known integer value"); + UNREACHABLE_RETURN(0); + + case kIROp_IntLit: + return static_cast(inst)->value.intVal; + break; + } + } + + // IRParam + + IRParam* IRParam::getNextParam() + { + return as(getNextInst()); + } + + // IRArrayTypeBase + + IRInst* IRArrayTypeBase::getElementCount() + { + if (auto arrayType = as(this)) + return arrayType->getElementCount(); + + return nullptr; + } + + // IRPtrTypeBase + + IRType* tryGetPointedToType( + IRBuilder* builder, + IRType* type) + { + if( auto rateQualType = as(type) ) + { + type = rateQualType->getDataType(); + } + + // The "true" pointers and the pointer-like stdlib types are the easy cases. + if( auto ptrType = as(type) ) + { + return ptrType->getValueType(); + } + else if( auto ptrLikeType = as(type) ) + { + return ptrLikeType->getElementType(); + } + // + // A more interesting case arises when we have a `BindExistentials, ...>` + // where `P` is a pointer(-like) type. + // + else if( auto bindExistentials = as(type) ) + { + // We know that `BindExistentials` won't introduce its own + // existential type parameters, nor will any of the pointer(-like) + // type constructors `P`. + // + // Thus we know that the type that is pointed to should be + // the same as `BindExistentials`. + // + auto baseType = bindExistentials->getBaseType(); + if( auto baseElementType = tryGetPointedToType(builder, baseType) ) + { + UInt existentialArgCount = bindExistentials->getExistentialArgCount(); + List existentialArgs; + for( UInt ii = 0; ii < existentialArgCount; ++ii ) + { + existentialArgs.add(bindExistentials->getExistentialArg(ii)); + } + return builder->getBindExistentialsType( + baseElementType, + existentialArgCount, + existentialArgs.getBuffer()); + } + } + + // TODO: We may need to handle other cases here. + + return nullptr; + } + + + // IRBlock + + IRParam* IRBlock::getLastParam() + { + IRParam* param = getFirstParam(); + if (!param) return nullptr; + + while (auto nextParam = param->getNextParam()) + param = nextParam; + + return param; + } + + void IRBlock::addParam(IRParam* param) + { + // If there are any existing parameters, + // then insert after the last of them. + // + if (auto lastParam = getLastParam()) + { + param->insertAfter(lastParam); + } + // + // Otherwise, if there are any existing + // "ordinary" instructions, insert before + // the first of them. + // + else if(auto firstOrdinary = getFirstOrdinaryInst()) + { + param->insertBefore(firstOrdinary); + } + // + // Otherwise the block currently has neither + // parameters nor orindary instructions, + // so we can safely insert at the end of + // the list of (raw) children. + // + else + { + param->insertAtEnd(this); + } + } + + IRInst* IRBlock::getFirstOrdinaryInst() + { + // Find the last parameter (if any) of the block + auto lastParam = getLastParam(); + if (lastParam) + { + // If there is a last parameter, then the + // instructions after it are the ordinary + // instructions. + return lastParam->getNextInst(); + } + else + { + // If there isn't a last parameter, then + // there must not have been *any* parameters, + // and so the first instruction in the block + // is also the first ordinary one. + return getFirstInst(); + } + } + + IRInst* IRBlock::getLastOrdinaryInst() + { + // Under normal circumstances, the last instruction + // in the block is also the last ordinary instruction. + // However, there is the special case of a block with + // only parameters (which might happen as a temporary + // state while we are building IR). + auto inst = getLastInst(); + + // If the last instruction is a parameter, then + // there are no ordinary instructions, so the last + // one is a null pointer. + if (as(inst)) + return nullptr; + + // Otherwise the last instruction is the last "ordinary" + // instruction as well. + return inst; + } + + + // The predecessors of a block should all show up as users + // of its value, so rather than explicitly store the CFG, + // we will recover it on demand from the use-def information. + // + // Note: we are really iterating over incoming/outgoing *edges* + // for a block, because there might be multiple uses of a block, + // if more than one way of an N-way branch targets the same block. + + // Get the list of successor blocks for an instruction, + // which we expect to be the last instruction in a block. + static IRBlock::SuccessorList getSuccessors(IRInst* terminator) + { + // If the block somehow isn't terminated, then + // there is no way to read its successors, so + // we return an empty list. + if (!terminator || !as(terminator)) + return IRBlock::SuccessorList(nullptr, nullptr); + + // Otherwise, based on the opcode of the terminator + // instruction, we will build up our list of uses. + IRUse* begin = nullptr; + IRUse* end = nullptr; + UInt stride = 1; + + auto operands = terminator->getOperands(); + switch (terminator->op) + { + case kIROp_ReturnVal: + case kIROp_ReturnVoid: + case kIROp_Unreachable: + case kIROp_MissingReturn: + case kIROp_discard: + break; + + case kIROp_unconditionalBranch: + case kIROp_loop: + // unconditonalBranch + begin = operands + 0; + end = begin + 1; + break; + + case kIROp_conditionalBranch: + case kIROp_ifElse: + // conditionalBranch + begin = operands + 1; + end = begin + 2; + break; + + case kIROp_Switch: + // switch ... + begin = operands + 2; + + // TODO: this ends up point one *after* the "one after the end" + // location, so we should really change the representation + // so that we don't need to form this pointer... + end = operands + terminator->getOperandCount() + 1; + stride = 2; + break; + + default: + SLANG_UNEXPECTED("unhandled terminator instruction"); + UNREACHABLE_RETURN(IRBlock::SuccessorList(nullptr, nullptr)); + } + + return IRBlock::SuccessorList(begin, end, stride); + } + + static IRUse* adjustPredecessorUse(IRUse* use) + { + // We will search until we either find a + // suitable use, or run out of uses. + for (;use; use = use->nextUse) + { + // We only want to deal with uses that represent + // a "sucessor" operand to some terminator instruction. + // We will re-use the logic for getting the successor + // list from such an instruction. + + auto successorList = getSuccessors((IRInst*) use->getUser()); + + if(use >= successorList.begin_ + && use < successorList.end_) + { + UInt index = (use - successorList.begin_); + if ((index % successorList.stride) == 0) + { + // This use is in the range of the sucessor list, + // and so it represents a real edge between + // blocks. + return use; + } + } + } + + // If we ran out of uses, then we are at the end + // of the list of incoming edges. + return nullptr; + } + + IRBlock::PredecessorList IRBlock::getPredecessors() + { + // We want to iterate over the predecessors of this block. + // First, we resign ourselves to iterating over the + // incoming edges, rather than the blocks themselves. + // This might sound like a trival distinction, but it is + // possible for there to be multiple edges between two + // blocks (as for a `switch` with multiple cases that + // map to the same code). Any client that wants just + // the unique predecessor blocks needs to deal with + // the deduplication themselves. + // + // Next, we note that for any predecessor edge, there will + // be a use of this block in the terminator instruction of + // the predecessor. We basically just want to iterate over + // the users of this block, then, but we need to be careful + // to rule out anything that doesn't actually represent + // an edge. The `adjustPredecessorUse` function will be + // used to search for a use that actually represents an edge. + + return PredecessorList( + adjustPredecessorUse(firstUse)); + } + + UInt IRBlock::PredecessorList::getCount() + { + UInt count = 0; + for (auto ii : *this) + { + (void)ii; + count++; + } + return count; + } + + bool IRBlock::PredecessorList::isEmpty() + { + return !(begin() != end()); + } + + + void IRBlock::PredecessorList::Iterator::operator++() + { + if (!use) return; + use = adjustPredecessorUse(use->nextUse); + } + + IRBlock* IRBlock::PredecessorList::Iterator::operator*() + { + if (!use) return nullptr; + return (IRBlock*)use->getUser()->parent; + } + + IRBlock::SuccessorList IRBlock::getSuccessors() + { + // The successors of a block will all be listed + // as operands of its terminator instruction. + // Depending on the terminator, we might have + // different numbers of operands to deal with. + // + // (We might also have to deal with a "stride" + // in the case where the basic-block operands + // are mixed up with non-block operands) + + auto terminator = getLastInst(); + return Slang::getSuccessors(terminator); + } + + UInt IRBlock::SuccessorList::getCount() + { + UInt count = 0; + for (auto ii : *this) + { + (void)ii; + count++; + } + return count; + } + + void IRBlock::SuccessorList::Iterator::operator++() + { + use += stride; + } + + IRBlock* IRBlock::SuccessorList::Iterator::operator*() + { + return (IRBlock*)use->get(); + } + + UInt IRUnconditionalBranch::getArgCount() + { + switch(op) + { + case kIROp_unconditionalBranch: + return getOperandCount() - 1; + + case kIROp_loop: + return getOperandCount() - 3; + + default: + SLANG_UNEXPECTED("unhandled unconditional branch opcode"); + UNREACHABLE_RETURN(0); + } + } + + IRUse* IRUnconditionalBranch::getArgs() + { + switch(op) + { + case kIROp_unconditionalBranch: + return getOperands() + 1; + + case kIROp_loop: + return getOperands() + 3; + + default: + SLANG_UNEXPECTED("unhandled unconditional branch opcode"); + UNREACHABLE_RETURN(0); + } + } + + IRInst* IRUnconditionalBranch::getArg(UInt index) + { + return getArgs()[index].usedValue; + } + + IRParam* IRGlobalValueWithParams::getFirstParam() + { + auto entryBlock = getFirstBlock(); + if(!entryBlock) return nullptr; + + return entryBlock->getFirstParam(); + } + + IRParam* IRGlobalValueWithParams::getLastParam() + { + auto entryBlock = getFirstBlock(); + if(!entryBlock) return nullptr; + + return entryBlock->getLastParam(); + } + + IRInstList IRGlobalValueWithParams::getParams() + { + auto entryBlock = getFirstBlock(); + if(!entryBlock) return IRInstList(); + + return entryBlock->getParams(); + } + + + // IRFunc + + IRType* IRFunc::getResultType() { return getDataType()->getResultType(); } + UInt IRFunc::getParamCount() { return getDataType()->getParamCount(); } + IRType* IRFunc::getParamType(UInt index) { return getDataType()->getParamType(index); } + + void IRGlobalValueWithCode::addBlock(IRBlock* block) + { + block->insertAtEnd(this); + } + + void fixUpFuncType(IRFunc* func) + { + SLANG_ASSERT(func); + + auto irModule = func->getModule(); + SLANG_ASSERT(irModule); + + SharedIRBuilder sharedBuilder; + sharedBuilder.module = irModule; + + IRBuilder builder; + builder.sharedBuilder = &sharedBuilder; + + builder.setInsertBefore(func); + + List paramTypes; + for(auto param : func->getParams()) + { + paramTypes.add(param->getFullType()); + } + + auto resultType = func->getResultType(); + + auto funcType = builder.getFuncType(paramTypes, resultType); + builder.setDataType(func, funcType); + } + + // + + bool isTerminatorInst(IROp op) + { + switch (op) + { + default: + return false; + + case kIROp_ReturnVal: + case kIROp_ReturnVoid: + case kIROp_unconditionalBranch: + case kIROp_conditionalBranch: + case kIROp_loop: + case kIROp_ifElse: + case kIROp_discard: + case kIROp_Switch: + case kIROp_Unreachable: + case kIROp_MissingReturn: + return true; + } + } + + bool isTerminatorInst(IRInst* inst) + { + if (!inst) return false; + return isTerminatorInst(inst->op); + } + + // + + IRBlock* IRBuilder::getBlock() + { + return as(insertIntoParent); + } + + // Get the current function (or other value with code) + // that we are inserting into (if any). + IRGlobalValueWithCode* IRBuilder::getFunc() + { + auto pp = insertIntoParent; + if (auto block = as(pp)) + { + pp = pp->getParent(); + } + return as(pp); + } + + + void IRBuilder::setInsertInto(IRInst* insertInto) + { + insertIntoParent = insertInto; + insertBeforeInst = nullptr; + } + + void IRBuilder::setInsertBefore(IRInst* insertBefore) + { + SLANG_ASSERT(insertBefore); + insertIntoParent = insertBefore->parent; + insertBeforeInst = insertBefore; + } + + + // Add an instruction into the current scope + void IRBuilder::addInst( + IRInst* inst) + { + if(insertBeforeInst) + { + inst->insertBefore(insertBeforeInst); + } + else if (insertIntoParent) + { + inst->insertAtEnd(insertIntoParent); + } + else + { + // Don't append the instruction anywhere + } + } + + // Given two parent instructions, pick the better one to use as as + // insertion location for a "hoistable" instruction. + // + IRInst* mergeCandidateParentsForHoistableInst(IRInst* left, IRInst* right) + { + // If the candidates are both the same, then who cares? + if(left == right) return left; + + // If either `left` or `right` is a block, then we need to be + // a bit careful, because blocks can see other values just using + // the dominance relationship, without a direct parent-child relationship. + // + // First, check if each of `left` and `right` is a block. + // + auto leftBlock = as(left); + auto rightBlock = as(right); + // + // As a special case, if both of these are blocks in the same parent, + // then we need to pick between them based on dominance. + // + if (leftBlock && rightBlock && (leftBlock->getParent() == rightBlock->getParent())) + { + // We assume that the order of basic blocks in a function is compatible + // with the dominance relationship (that is, if A dominates B, then + // A comes before B in the list of blocks), so it suffices to pick + // the *later* of the two blocks. + // + // There are ways we could try to speed up this search, but no matter + // what it will be O(n) in the number of blocks, unless we build + // an explicit dominator tree, which is infeasible during IR building. + // Thus we just do a simple linear walk here. + // + // We will start at `leftBlock` and walk forward, until either... + // + for (auto ll = leftBlock; ll; ll = ll->getNextBlock()) + { + // ... we see `rightBlock` (in which case `rightBlock` came later), or ... + // + if (ll == rightBlock) return rightBlock; + } + // + // ... we run out of blocks (in which case `leftBlock` came later). + // + return leftBlock; + } + + // + // If the special case above doesn't apply, then `left` or `right` might + // still be a block, but they aren't blocks nested in the same function. + // We will find the first non-block ancestor of `left` and/or `right`. + // This will either be the inst itself (it is isn't a block), or + // its immediate parent (if it *is* a block). + // + auto leftNonBlock = leftBlock ? leftBlock->getParent() : left; + auto rightNonBlock = rightBlock ? rightBlock->getParent() : right; + + // If either side is null, then take the non-null one. + // + if (!leftNonBlock) return right; + if (!rightNonBlock) return left; + + // If the non-block on the left or right is a descendent of + // the other, then that is what we should use. + // + IRInst* parentNonBlock = nullptr; + for (auto ll = leftNonBlock; ll; ll = ll->getParent()) + { + if (ll == rightNonBlock) + { + parentNonBlock = leftNonBlock; + break; + } + } + for (auto rr = rightNonBlock; rr; rr = rr->getParent()) + { + if (rr == leftNonBlock) + { + SLANG_ASSERT(!parentNonBlock || parentNonBlock == leftNonBlock); + parentNonBlock = rightNonBlock; + break; + } + } + + // As a matter of validity in the IR, we expect one + // of the two to be an ancestor (in the non-block case), + // because otherwise we'd be violating the basic dominance + // assumptions. + // + SLANG_ASSERT(parentNonBlock); + + // As a fallback, try to use the left parent as a default + // in case things go badly. + // + if (!parentNonBlock) + { + parentNonBlock = leftNonBlock; + } + + IRInst* parent = parentNonBlock; + + // At this point we've found a non-block parent where we + // could stick things, but we have to fix things up in + // case we should be inserting into a block beneath + // that non-block parent. + if (leftBlock && (parentNonBlock == leftNonBlock)) + { + // We have a left block, and have picked its parent. + + // It cannot be the case that there is a right block + // with the same parent, or else our special case + // would have triggered at the start. + SLANG_ASSERT(!rightBlock || (parentNonBlock != rightNonBlock)); + + parent = leftBlock; + } + else if (rightBlock && (parentNonBlock == rightNonBlock)) + { + // We have a right block, and have picked its parent. + + // We already tested above, so we know there isn't a + // matching situation on the left side. + + parent = rightBlock; + } + + // Okay, we've picked the parent we want to insert into, + // *but* one last special case arises, because an `IRGlobalValueWithCode` + // is not actually a suitable place to insert instructions. + // Furthermore, there is no actual need to insert instructions at + // that scope, because any parameters, etc. are actually attached + // to the block(s) within the function. + if (auto parentFunc = as(parent)) + { + // Insert in the parent of the function (or other value with code). + // We know that the parent must be able to hold ordinary instructions, + // because it was able to hold this `IRGlobalValueWithCode` + parent = parentFunc->getParent(); + } + + return parent; + } + + IRInst* createEmptyInst( + IRModule* module, + IROp op, + int totalArgCount) + { + size_t size = sizeof(IRInst) + (totalArgCount) * sizeof(IRUse); + + SLANG_ASSERT(module); + IRInst* inst = (IRInst*)module->memoryArena.allocateAndZero(size); + + inst->operandCount = uint32_t(totalArgCount); + inst->op = op; + + return inst; + } + + IRInst* createEmptyInstWithSize( + IRModule* module, + IROp op, + size_t totalSizeInBytes) + { + SLANG_ASSERT(totalSizeInBytes >= sizeof(IRInst)); + + SLANG_ASSERT(module); + IRInst* inst = (IRInst*)module->memoryArena.allocateAndZero(totalSizeInBytes); + + inst->operandCount = 0; + inst->op = op; + + return inst; + } + + // Given an instruction that represents a constant, a type, etc. + // Try to "hoist" it as far toward the global scope as possible + // to insert it at a location where it will be maximally visible. + // + void addHoistableInst( + IRBuilder* builder, + IRInst* inst) + { + // Start with the assumption that we would insert this instruction + // into the global scope (the instruction that represents the module) + IRInst* parent = builder->getModule()->getModuleInst(); + + // The above decision might be invalid, because there might be + // one or more operands of the instruction that are defined in + // more deeply nested parents than the global scope. + // + // Therefore, we will scan the operands of the instruction, and + // look at the parents that define them. + // + UInt operandCount = inst->getOperandCount(); + for (UInt ii = 0; ii < operandCount; ++ii) + { + auto operand = inst->getOperand(ii); + if (!operand) + continue; + + auto operandParent = operand->getParent(); + + parent = mergeCandidateParentsForHoistableInst(parent, operandParent); + } + + // We better have ended up with a place to insert. + SLANG_ASSERT(parent); + + // If we have chosen to insert into the same parent that the + // IRBuilder is configured to use, then respect its `insertBeforeInst` + // setting. + if (parent == builder->insertIntoParent) + { + builder->addInst(inst); + return; + } + + // Otherwise, we just want to insert at the end of the chosen parent. + // + // TODO: be careful about inserting after the terminator of a block... + + inst->insertAtEnd(parent); + } + + static void maybeSetSourceLoc( + IRBuilder* builder, + IRInst* value) + { + if(!builder) + return; + + auto sourceLocInfo = builder->sourceLocInfo; + if(!sourceLocInfo) + return; + + // Try to find something with usable location info + for(;;) + { + if(sourceLocInfo->sourceLoc.getRaw()) + break; + + if(!sourceLocInfo->next) + break; + + sourceLocInfo = sourceLocInfo->next; + } + + value->sourceLoc = sourceLocInfo->sourceLoc; + } + + // Create an IR instruction/value and initialize it. + // + // In this case `argCount` and `args` represent the + // arguments *after* the type (which is a mandatory + // argument for all instructions). + template + static T* createInstImpl( + IRModule* module, + IRBuilder* builder, + IROp op, + IRType* type, + UInt fixedArgCount, + IRInst* const* fixedArgs, + UInt varArgListCount, + UInt const* listArgCounts, + IRInst* const* const* listArgs) + { + UInt varArgCount = 0; + for (UInt ii = 0; ii < varArgListCount; ++ii) + { + varArgCount += listArgCounts[ii]; + } + + UInt size = sizeof(IRInst) + (fixedArgCount + varArgCount) * sizeof(IRUse); + if (sizeof(T) > size) + { + size = sizeof(T); + } + + SLANG_ASSERT(module); + T* inst = (T*)module->memoryArena.allocateAndZero(size); + + // TODO: Do we need to run ctor after zeroing? + new(inst)T(); + + inst->operandCount = (uint32_t)(fixedArgCount + varArgCount); + + inst->op = op; + + if (type) + { + inst->typeUse.init(inst, type); + } + + maybeSetSourceLoc(builder, inst); + + auto operand = inst->getOperands(); + + for( UInt aa = 0; aa < fixedArgCount; ++aa ) + { + if (fixedArgs) + { + operand->init(inst, fixedArgs[aa]); + } + operand++; + } + + for (UInt ii = 0; ii < varArgListCount; ++ii) + { + UInt listArgCount = listArgCounts[ii]; + for (UInt jj = 0; jj < listArgCount; ++jj) + { + if (listArgs[ii]) + { + operand->init(inst, listArgs[ii][jj]); + } + else + { + operand->init(inst, nullptr); + } + operand++; + } + } + return inst; + } + + static IRInst* createInstWithSizeImpl( + IRBuilder* builder, + IROp op, + IRType* type, + size_t sizeInBytes) + { + auto module = builder->getModule(); + IRInst* inst = (IRInst*)module->memoryArena.allocate(sizeInBytes); + // Zero only the 'type' + memset(inst, 0, sizeof(IRInst)); + // TODO: Do we need to run ctor after zeroing? + new (inst) IRInst; + + inst->op = op; + if (type) + { + inst->typeUse.init(inst, type); + } + maybeSetSourceLoc(builder, inst); + return inst; + } + + template + static T* createInstImpl( + IRBuilder* builder, + IROp op, + IRType* type, + UInt fixedArgCount, + IRInst* const* fixedArgs, + UInt varArgCount = 0, + IRInst* const* varArgs = nullptr) + { + return createInstImpl( + builder->getModule(), + builder, + op, + type, + fixedArgCount, + fixedArgs, + 1, + &varArgCount, + &varArgs); + } + + template + static T* createInstImpl( + IRBuilder* builder, + IROp op, + IRType* type, + UInt fixedArgCount, + IRInst* const* fixedArgs, + UInt varArgListCount, + UInt const* listArgCount, + IRInst* const* const* listArgs) + { + return createInstImpl( + builder->getModule(), + builder, + op, + type, + fixedArgCount, + fixedArgs, + varArgListCount, + listArgCount, + listArgs); + } + + template + static T* createInst( + IRBuilder* builder, + IROp op, + IRType* type, + UInt argCount, + IRInst* const* args) + { + return createInstImpl( + builder, + op, + type, + argCount, + args); + } + + template + static T* createInst( + IRBuilder* builder, + IROp op, + IRType* type) + { + return createInstImpl( + builder, + op, + type, + 0, + nullptr); + } + + template + static T* createInst( + IRBuilder* builder, + IROp op, + IRType* type, + IRInst* arg) + { + return createInstImpl( + builder, + op, + type, + 1, + &arg); + } + + template + static T* createInst( + IRBuilder* builder, + IROp op, + IRType* type, + IRInst* arg1, + IRInst* arg2) + { + IRInst* args[] = { arg1, arg2 }; + return createInstImpl( + builder, + op, + type, + 2, + &args[0]); + } + + template + static T* createInstWithTrailingArgs( + IRBuilder* builder, + IROp op, + IRType* type, + UInt argCount, + IRInst* const* args) + { + return createInstImpl( + builder, + op, + type, + argCount, + args); + } + + template + static T* createInstWithTrailingArgs( + IRBuilder* builder, + IROp op, + IRType* type, + UInt fixedArgCount, + IRInst* const* fixedArgs, + UInt varArgCount, + IRInst* const* varArgs) + { + return createInstImpl( + builder, + op, + type, + fixedArgCount, + fixedArgs, + varArgCount, + varArgs); + } + + template + static T* createInstWithTrailingArgs( + IRBuilder* builder, + IROp op, + IRType* type, + IRInst* arg1, + UInt varArgCount, + IRInst* const* varArgs) + { + IRInst* fixedArgs[] = { arg1 }; + UInt fixedArgCount = sizeof(fixedArgs) / sizeof(fixedArgs[0]); + + return createInstImpl( + builder, + op, + type, + fixedArgCount, + fixedArgs, + varArgCount, + varArgs); + } + // + + bool operator==(IRInstKey const& left, IRInstKey const& right) + { + if(left.inst->op != right.inst->op) return false; + if(left.inst->getFullType() != right.inst->getFullType()) return false; + if(left.inst->operandCount != right.inst->operandCount) return false; + + auto argCount = left.inst->operandCount; + auto leftArgs = left.inst->getOperands(); + auto rightArgs = right.inst->getOperands(); + for( UInt aa = 0; aa < argCount; ++aa ) + { + if(leftArgs[aa].get() != rightArgs[aa].get()) + return false; + } + + return true; + } + + int IRInstKey::GetHashCode() + { + auto code = Slang::GetHashCode(inst->op); + code = combineHash(code, Slang::GetHashCode(inst->getFullType())); + code = combineHash(code, Slang::GetHashCode(inst->getOperandCount())); + + auto argCount = inst->getOperandCount(); + auto args = inst->getOperands(); + for( UInt aa = 0; aa < argCount; ++aa ) + { + code = combineHash(code, Slang::GetHashCode(args[aa].get())); + } + return code; + } + + UnownedStringSlice IRConstant::getStringSlice() + { + assert(op == kIROp_StringLit); + // If the transitory decoration is set, then this is uses the transitoryStringVal for the text storage. + // This is typically used when we are using a transitory IRInst held on the stack (such that it can be looked up in cached), + // that just points to a string elsewhere, and NOT the typical normal style, where the string is held after the instruction in memory. + // + if(findDecorationImpl(kIROp_TransitoryDecoration)) + { + return UnownedStringSlice(value.transitoryStringVal.chars, value.transitoryStringVal.numChars); + } + else + { + return UnownedStringSlice(value.stringVal.chars, value.stringVal.numChars); + } + } + + bool IRConstant::isValueEqual(IRConstant* rhs) + { + // If they are literally the same thing.. + if (this == rhs) + { + return true; + } + // Check the type and they are the same op & same type + if (op != rhs->op) + { + return false; + } + + switch (op) + { + case kIROp_BoolLit: + case kIROp_FloatLit: + case kIROp_IntLit: + { + SLANG_COMPILE_TIME_ASSERT(sizeof(IRFloatingPointValue) == sizeof(IRIntegerValue)); + // ... we can just compare as bits + return value.intVal == rhs->value.intVal; + } + case kIROp_PtrLit: + { + return value.ptrVal == rhs->value.ptrVal; + } + case kIROp_StringLit: + { + return getStringSlice() == rhs->getStringSlice(); + } + default: break; + } + + SLANG_ASSERT(!"Unhandled type"); + return false; + } + + /// True if constants are equal + bool IRConstant::equal(IRConstant* rhs) + { + // TODO(JS): Only equal if pointer types are identical (to match how getHashCode works below) + return isValueEqual(rhs) && getFullType() == rhs->getFullType(); + } + + int IRConstant::getHashCode() + { + auto code = Slang::GetHashCode(op); + code = combineHash(code, Slang::GetHashCode(getFullType())); + + switch (op) + { + case kIROp_BoolLit: + case kIROp_FloatLit: + case kIROp_IntLit: + { + SLANG_COMPILE_TIME_ASSERT(sizeof(IRFloatingPointValue) == sizeof(IRIntegerValue)); + // ... we can just compare as bits + return combineHash(code, Slang::GetHashCode(value.intVal)); + } + case kIROp_PtrLit: + { + return combineHash(code, Slang::GetHashCode(value.ptrVal)); + } + case kIROp_StringLit: + { + const UnownedStringSlice slice = getStringSlice(); + return combineHash(code, Slang::GetHashCode(slice.begin(), slice.size())); + } + default: + { + SLANG_ASSERT(!"Invalid type"); + return 0; + } + } + } + + static IRConstant* findOrEmitConstant( + IRBuilder* builder, + IRConstant& keyInst) + { + // We now know where we want to insert, but there might + // already be an equivalent instruction in that block. + // + // We will check for such an instruction in a slightly hacky + // way: we will construct a temporary instruction and + // then use it to look up in a cache of instructions. + // The 'fake' instruction is passed in as keyInst. + + IRConstantKey key; + key.inst = &keyInst; + + IRConstant* irValue = nullptr; + if( builder->sharedBuilder->constantMap.TryGetValue(key, irValue) ) + { + // We found a match, so just use that. + return irValue; + } + + // Calculate the minimum object size (ie not including the payload of value) + const size_t prefixSize = SLANG_OFFSET_OF(IRConstant, value); + + switch (keyInst.op) + { + default: + SLANG_UNEXPECTED("missing case for IR constant"); + break; + + case kIROp_BoolLit: + case kIROp_IntLit: + { + irValue = static_cast(createInstWithSizeImpl(builder, keyInst.op, keyInst.getFullType(), prefixSize + sizeof(IRIntegerValue))); + irValue->value.intVal = keyInst.value.intVal; + break; + } + case kIROp_FloatLit: + { + irValue = static_cast(createInstWithSizeImpl(builder, keyInst.op, keyInst.getFullType(), prefixSize + sizeof(IRFloatingPointValue))); + irValue->value.floatVal = keyInst.value.floatVal; + break; + } + case kIROp_PtrLit: + { + irValue = static_cast(createInstWithSizeImpl(builder, keyInst.op, keyInst.getFullType(), prefixSize + sizeof(void*))); + irValue->value.ptrVal = keyInst.value.ptrVal; + break; + } + case kIROp_StringLit: + { + const UnownedStringSlice slice = keyInst.getStringSlice(); + + const size_t sliceSize = slice.size(); + const size_t instSize = prefixSize + offsetof(IRConstant::StringValue, chars) + sliceSize; + + irValue = static_cast(createInstWithSizeImpl(builder, keyInst.op, keyInst.getFullType(), instSize)); + + IRConstant::StringValue& dstString = irValue->value.stringVal; + + dstString.numChars = uint32_t(sliceSize); + // Turn into pointer to avoid warning of array overrun + char* dstChars = dstString.chars; + // Copy the chars + memcpy(dstChars, slice.begin(), sliceSize); + + break; + } + } + + key.inst = irValue; + builder->sharedBuilder->constantMap.Add(key, irValue); + + addHoistableInst(builder, irValue); + + return irValue; + } + + // + + IRInst* IRBuilder::getBoolValue(bool inValue) + { + IRConstant keyInst; + memset(&keyInst, 0, sizeof(keyInst)); + keyInst.op = kIROp_BoolLit; + keyInst.typeUse.usedValue = getBoolType(); + keyInst.value.intVal = IRIntegerValue(inValue); + return findOrEmitConstant(this, keyInst); + } + + IRInst* IRBuilder::getIntValue(IRType* type, IRIntegerValue inValue) + { + IRConstant keyInst; + memset(&keyInst, 0, sizeof(keyInst)); + keyInst.op = kIROp_IntLit; + keyInst.typeUse.usedValue = type; + keyInst.value.intVal = inValue; + return findOrEmitConstant(this, keyInst); + } + + IRInst* IRBuilder::getFloatValue(IRType* type, IRFloatingPointValue inValue) + { + IRConstant keyInst; + memset(&keyInst, 0, sizeof(keyInst)); + keyInst.op = kIROp_FloatLit; + keyInst.typeUse.usedValue = type; + keyInst.value.floatVal = inValue; + return findOrEmitConstant(this, keyInst); + } + + IRStringLit* IRBuilder::getStringValue(const UnownedStringSlice& inSlice) + { + IRConstant keyInst; + memset(&keyInst, 0, sizeof(keyInst)); + + // Mark that this is on the stack... + IRDecoration stackDecoration; + memset(&stackDecoration, 0, sizeof(stackDecoration)); + stackDecoration.op = kIROp_TransitoryDecoration; + stackDecoration.insertAtEnd(&keyInst); + + keyInst.op = kIROp_StringLit; + keyInst.typeUse.usedValue = getStringType(); + + IRConstant::StringSliceValue& dstSlice = keyInst.value.transitoryStringVal; + dstSlice.chars = const_cast(inSlice.begin()); + dstSlice.numChars = uint32_t(inSlice.size()); + + return static_cast(findOrEmitConstant(this, keyInst)); + } + + IRPtrLit* IRBuilder::getPtrValue(void* value) + { + IRType* type = getPtrType(getVoidType()); + + IRConstant keyInst; + memset(&keyInst, 0, sizeof(keyInst)); + keyInst.op = kIROp_PtrLit; + keyInst.typeUse.usedValue = type; + keyInst.value.ptrVal = value; + return (IRPtrLit*) findOrEmitConstant(this, keyInst); + } + + + IRInst* findOrEmitHoistableInst( + IRBuilder* builder, + IRType* type, + IROp op, + UInt operandListCount, + UInt const* listOperandCounts, + IRInst* const* const* listOperands) + { + UInt operandCount = 0; + for (UInt ii = 0; ii < operandListCount; ++ii) + { + operandCount += listOperandCounts[ii]; + } + + auto& memoryArena = builder->getModule()->memoryArena; + void* cursor = memoryArena.getCursor(); + + // We are going to create a 'dummy' instruction on the memoryArena + // which can be used as a key for lookup, so see if we + // already have an equivalent instruction available to use. + size_t keySize = sizeof(IRInst) + operandCount * sizeof(IRUse); + IRInst* inst = (IRInst*) memoryArena.allocateAndZero(keySize); + + void* endCursor = memoryArena.getCursor(); + // Mark as 'unused' cos it is unused on release builds. + SLANG_UNUSED(endCursor); + + new(inst) IRInst(); + inst->op = op; + inst->typeUse.usedValue = type; + inst->operandCount = (uint32_t) operandCount; + + // Don't link up as we may free (if we already have this key) + { + IRUse* operand = inst->getOperands(); + for (UInt ii = 0; ii < operandListCount; ++ii) + { + UInt listOperandCount = listOperandCounts[ii]; + for (UInt jj = 0; jj < listOperandCount; ++jj) + { + operand->usedValue = listOperands[ii][jj]; + operand++; + } + } + } + + // Find or add the key/inst + { + IRInstKey key = { inst }; + + // Ideally we would add if not found, else return if was found instead of testing & then adding. + IRInst** found = builder->sharedBuilder->globalValueNumberingMap.TryGetValueOrAdd(key, inst); + SLANG_ASSERT(endCursor == memoryArena.getCursor()); + // If it's found, just return, and throw away the instruction + if (found) + { + memoryArena.rewindToCursor(cursor); + return *found; + } + } + + // Make the lookup 'inst' instruction into 'proper' instruction. Equivalent to + // IRInst* inst = createInstImpl(builder, op, type, 0, nullptr, operandListCount, listOperandCounts, listOperands); + { + if (type) + { + inst->typeUse.usedValue = nullptr; + inst->typeUse.init(inst, type); + } + + maybeSetSourceLoc(builder, inst); + + IRUse*const operands = inst->getOperands(); + for (UInt i = 0; i < operandCount; ++i) + { + IRUse& operand = operands[i]; + auto value = operand.usedValue; + + operand.usedValue = nullptr; + operand.init(inst, value); + } + } + + addHoistableInst(builder, inst); + + return inst; + } + + IRInst* findOrEmitHoistableInst( + IRBuilder* builder, + IRType* type, + IROp op, + UInt operandCount, + IRInst* const* operands) + { + return findOrEmitHoistableInst( + builder, + type, + op, + 1, + &operandCount, + &operands); + } + + IRInst* findOrEmitHoistableInst( + IRBuilder* builder, + IRType* type, + IROp op, + IRInst* operand, + UInt operandCount, + IRInst* const* operands) + { + UInt counts[] = { 1, operandCount }; + IRInst* const* lists[] = { &operand, operands }; + + return findOrEmitHoistableInst( + builder, + type, + op, + 2, + counts, + lists); + } + + + IRType* IRBuilder::getType( + IROp op, + UInt operandCount, + IRInst* const* operands) + { + return (IRType*) findOrEmitHoistableInst( + this, + nullptr, + op, + operandCount, + operands); + } + + IRType* IRBuilder::getType( + IROp op) + { + return getType(op, 0, nullptr); + } + + IRBasicType* IRBuilder::getBasicType(BaseType baseType) + { + return (IRBasicType*)getType( + IROp((UInt)kIROp_FirstBasicType + (UInt)baseType)); + } + + IRBasicType* IRBuilder::getVoidType() + { + return (IRVoidType*)getType(kIROp_VoidType); + } + + IRBasicType* IRBuilder::getBoolType() + { + return (IRBoolType*)getType(kIROp_BoolType); + } + + IRBasicType* IRBuilder::getIntType() + { + return (IRBasicType*)getType(kIROp_IntType); + } + + IRStringType* IRBuilder::getStringType() + { + return (IRStringType*)getType(kIROp_StringType); + } + + IRBasicBlockType* IRBuilder::getBasicBlockType() + { + return (IRBasicBlockType*)getType(kIROp_BasicBlockType); + } + + IRTypeKind* IRBuilder::getTypeKind() + { + return (IRTypeKind*)getType(kIROp_TypeKind); + } + + IRGenericKind* IRBuilder::getGenericKind() + { + return (IRGenericKind*)getType(kIROp_GenericKind); + } + + IRPtrType* IRBuilder::getPtrType(IRType* valueType) + { + return (IRPtrType*) getPtrType(kIROp_PtrType, valueType); + } + + IROutType* IRBuilder::getOutType(IRType* valueType) + { + return (IROutType*) getPtrType(kIROp_OutType, valueType); + } + + IRInOutType* IRBuilder::getInOutType(IRType* valueType) + { + return (IRInOutType*) getPtrType(kIROp_InOutType, valueType); + } + + IRRefType* IRBuilder::getRefType(IRType* valueType) + { + return (IRRefType*) getPtrType(kIROp_RefType, valueType); + } + + IRPtrTypeBase* IRBuilder::getPtrType(IROp op, IRType* valueType) + { + IRInst* operands[] = { valueType }; + return (IRPtrTypeBase*) getType( + op, + 1, + operands); + } + + IRArrayTypeBase* IRBuilder::getArrayTypeBase( + IROp op, + IRType* elementType, + IRInst* elementCount) + { + IRInst* operands[] = { elementType, elementCount }; + return (IRArrayTypeBase*)getType( + op, + op == kIROp_ArrayType ? 2 : 1, + operands); + } + + IRArrayType* IRBuilder::getArrayType( + IRType* elementType, + IRInst* elementCount) + { + IRInst* operands[] = { elementType, elementCount }; + return (IRArrayType*)getType( + kIROp_ArrayType, + sizeof(operands) / sizeof(operands[0]), + operands); + } + + IRUnsizedArrayType* IRBuilder::getUnsizedArrayType( + IRType* elementType) + { + IRInst* operands[] = { elementType }; + return (IRUnsizedArrayType*)getType( + kIROp_UnsizedArrayType, + sizeof(operands) / sizeof(operands[0]), + operands); + } + + IRVectorType* IRBuilder::getVectorType( + IRType* elementType, + IRInst* elementCount) + { + IRInst* operands[] = { elementType, elementCount }; + return (IRVectorType*)getType( + kIROp_VectorType, + sizeof(operands) / sizeof(operands[0]), + operands); + } + + IRMatrixType* IRBuilder::getMatrixType( + IRType* elementType, + IRInst* rowCount, + IRInst* columnCount) + { + IRInst* operands[] = { elementType, rowCount, columnCount }; + return (IRMatrixType*)getType( + kIROp_MatrixType, + sizeof(operands) / sizeof(operands[0]), + operands); + } + + IRFuncType* IRBuilder::getFuncType( + UInt paramCount, + IRType* const* paramTypes, + IRType* resultType) + { + return (IRFuncType*) findOrEmitHoistableInst( + this, + nullptr, + kIROp_FuncType, + resultType, + paramCount, + (IRInst* const*) paramTypes); + } + + IRConstantBufferType* IRBuilder::getConstantBufferType(IRType* elementType) + { + IRInst* operands[] = { elementType }; + return (IRConstantBufferType*) getType( + kIROp_ConstantBufferType, + 1, + operands); + } + + IRConstExprRate* IRBuilder::getConstExprRate() + { + return (IRConstExprRate*)getType(kIROp_ConstExprRate); + } + + IRGroupSharedRate* IRBuilder::getGroupSharedRate() + { + return (IRGroupSharedRate*)getType(kIROp_GroupSharedRate); + } + + IRRateQualifiedType* IRBuilder::getRateQualifiedType( + IRRate* rate, + IRType* dataType) + { + IRInst* operands[] = { rate, dataType }; + return (IRRateQualifiedType*)getType( + kIROp_RateQualifiedType, + sizeof(operands) / sizeof(operands[0]), + operands); + } + + IRType* IRBuilder::getTaggedUnionType( + UInt caseCount, + IRType* const* caseTypes) + { + return (IRType*) findOrEmitHoistableInst( + this, + getTypeKind(), + kIROp_TaggedUnionType, + caseCount, + (IRInst* const*) caseTypes); + } + + IRType* IRBuilder::getBindExistentialsType( + IRInst* baseType, + UInt slotArgCount, + IRInst* const* slotArgs) + { + if(slotArgCount == 0) + return (IRType*) baseType; + + // If we are trying to bind an interface type, then + // we will go ahead and simplify the instruction + // away impmediately. + // + if(as(baseType)) + { + if(slotArgCount >= 1) + { + // We are being asked to emit `BindExistentials(someInterface, someConcreteType, ...)` + // so we just want to return `ExistentialBox`. + // + auto concreteType = (IRType*) slotArgs[0]; + auto ptrType = getPtrType(kIROp_ExistentialBoxType, concreteType); + return ptrType; + } + } + + return (IRType*) findOrEmitHoistableInst( + this, + getTypeKind(), + kIROp_BindExistentialsType, + baseType, + slotArgCount, + (IRInst* const*) slotArgs); + } + + IRType* IRBuilder::getBindExistentialsType( + IRInst* baseType, + UInt slotArgCount, + IRUse const* slotArgUses) + { + if(slotArgCount == 0) + return (IRType*) baseType; + + List slotArgs; + for( UInt ii = 0; ii < slotArgCount; ++ii ) + { + slotArgs.add(slotArgUses[ii].get()); + } + return getBindExistentialsType( + baseType, + slotArgCount, + slotArgs.getBuffer()); + } + + + + void IRBuilder::setDataType(IRInst* inst, IRType* dataType) + { + if (auto oldRateQualifiedType = as(inst->getFullType())) + { + // Construct a new rate-qualified type using the same rate. + + auto newRateQualifiedType = getRateQualifiedType( + oldRateQualifiedType->getRate(), + dataType); + + inst->setFullType(newRateQualifiedType); + } + else + { + // No rate? Just clobber the data type. + inst->setFullType(dataType); + } + } + + + IRUndefined* IRBuilder::emitUndefined(IRType* type) + { + auto inst = createInst( + this, + kIROp_undefined, + type); + + addInst(inst); + + return inst; + } + + IRInst* IRBuilder::emitExtractExistentialValue( + IRType* type, + IRInst* existentialValue) + { + auto inst = createInst( + this, + kIROp_ExtractExistentialValue, + type, + 1, + &existentialValue); + addInst(inst); + return inst; + } + + IRType* IRBuilder::emitExtractExistentialType( + IRInst* existentialValue) + { + auto type = getTypeKind(); + auto inst = createInst( + this, + kIROp_ExtractExistentialType, + type, + 1, + &existentialValue); + addInst(inst); + return (IRType*) inst; + } + + IRInst* IRBuilder::emitExtractExistentialWitnessTable( + IRInst* existentialValue) + { + auto type = getWitnessTableType(); + auto inst = createInst( + this, + kIROp_ExtractExistentialWitnessTable, + type, + 1, + &existentialValue); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitSpecializeInst( + IRType* type, + IRInst* genericVal, + UInt argCount, + IRInst* const* args) + { + auto inst = createInstWithTrailingArgs( + this, + kIROp_Specialize, + type, + 1, + &genericVal, + argCount, + args); + + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitLookupInterfaceMethodInst( + IRType* type, + IRInst* witnessTableVal, + IRInst* interfaceMethodVal) + { + auto inst = createInst( + this, + kIROp_lookup_interface_method, + type, + witnessTableVal, + interfaceMethodVal); + + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitCallInst( + IRType* type, + IRInst* pFunc, + UInt argCount, + IRInst* const* args) + { + auto inst = createInstWithTrailingArgs( + this, + kIROp_Call, + type, + 1, + &pFunc, + argCount, + args); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::createIntrinsicInst( + IRType* type, + IROp op, + UInt argCount, + IRInst* const* args) + { + return createInstWithTrailingArgs( + this, + op, + type, + argCount, + args); + } + + + IRInst* IRBuilder::emitIntrinsicInst( + IRType* type, + IROp op, + UInt argCount, + IRInst* const* args) + { + auto inst = createIntrinsicInst( + type, + op, + argCount, + args); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitConstructorInst( + IRType* type, + UInt argCount, + IRInst* const* args) + { + auto inst = createInstWithTrailingArgs( + this, + kIROp_Construct, + type, + argCount, + args); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitMakeVector( + IRType* type, + UInt argCount, + IRInst* const* args) + { + return emitIntrinsicInst(type, kIROp_makeVector, argCount, args); + } + + IRInst* IRBuilder::emitMakeMatrix( + IRType* type, + UInt argCount, + IRInst* const* args) + { + return emitIntrinsicInst(type, kIROp_MakeMatrix, argCount, args); + } + + IRInst* IRBuilder::emitMakeArray( + IRType* type, + UInt argCount, + IRInst* const* args) + { + return emitIntrinsicInst(type, kIROp_makeArray, argCount, args); + } + + IRInst* IRBuilder::emitMakeStruct( + IRType* type, + UInt argCount, + IRInst* const* args) + { + return emitIntrinsicInst(type, kIROp_makeStruct, argCount, args); + } + + IRInst* IRBuilder::emitMakeExistential( + IRType* type, + IRInst* value, + IRInst* witnessTable) + { + IRInst* args[] = {value, witnessTable}; + return emitIntrinsicInst(type, kIROp_MakeExistential, SLANG_COUNT_OF(args), args); + } + + IRInst* IRBuilder::emitWrapExistential( + IRType* type, + IRInst* value, + UInt slotArgCount, + IRInst* const* slotArgs) + { + if(slotArgCount == 0) + return value; + + // If we are wrapping a single concrete value into + // an interface type, then this is really a `makeExistential` + // + // TODO: We may want to check for a `specialize` of a generic interface as well. + // + if(as(type)) + { + if(slotArgCount >= 2) + { + // We are being asked to emit `wrapExistential(value, concreteType, witnessTable, ...) : someInterface` + // + // We also know that a concrete value being wrapped will always be an existential box, + // so we expect that `value : ExistentialBox` for some `T`. + // + // We want to emit `makeExistential(load(value), witnessTable)`. + // + auto deref = emitLoad(value); + return emitMakeExistential(type, deref, slotArgs[1]); + } + } + + IRInst* fixedArgs[] = {value}; + auto inst = createInstImpl( + this, + kIROp_WrapExistential, + type, + SLANG_COUNT_OF(fixedArgs), + fixedArgs, + slotArgCount, + slotArgs); + addInst(inst); + return inst; + } + + IRModule* IRBuilder::createModule() + { + auto module = new IRModule(); + module->session = getSession(); + + auto moduleInst = createInstImpl( + module, + this, + kIROp_Module, + nullptr, + 0, + nullptr, + 0, + nullptr, + nullptr); + module->moduleInst = moduleInst; + moduleInst->module = module; + + return module; + } + + void addGlobalValue( + IRBuilder* builder, + IRInst* value) + { + // Try to find a suitable parent for the + // global value we are emitting. + // + // We will start out search at the current + // parent instruction for the builder, and + // possibly work our way up. + // + auto parent = builder->insertIntoParent; + while(parent) + { + // Inserting into the top level of a module? + // That is fine, and we can stop searching. + if (as(parent)) + break; + + // Inserting into a basic block inside of + // a generic? That is okay too. + if (auto block = as(parent)) + { + if (as(block->parent)) + break; + } + + // Otherwise, move up the chain. + parent = parent->parent; + } + + // If we somehow ran out of parents (possibly + // because an instruction wasn't linked into + // the full hierarchy yet), then we will + // fall back to inserting into the overall module. + if (!parent) + { + parent = builder->getModule()->getModuleInst(); + } + + // If it turns out that we are inserting into the + // current "insert into" parent for the builder, then + // we need to respect its "insert before" setting + // as well. + if (parent == builder->insertIntoParent + && builder->insertBeforeInst) + { + value->insertBefore(builder->insertBeforeInst); + } + else + { + value->insertAtEnd(parent); + } + } + + IRFunc* IRBuilder::createFunc() + { + IRFunc* rsFunc = createInst( + this, + kIROp_Func, + nullptr); + maybeSetSourceLoc(this, rsFunc); + addGlobalValue(this, rsFunc); + return rsFunc; + } + + IRGlobalVar* IRBuilder::createGlobalVar( + IRType* valueType) + { + auto ptrType = getPtrType(valueType); + IRGlobalVar* globalVar = createInst( + this, + kIROp_GlobalVar, + ptrType); + maybeSetSourceLoc(this, globalVar); + addGlobalValue(this, globalVar); + return globalVar; + } + + IRGlobalConstant* IRBuilder::createGlobalConstant( + IRType* valueType) + { + IRGlobalConstant* globalConstant = createInst( + this, + kIROp_GlobalConstant, + valueType); + maybeSetSourceLoc(this, globalConstant); + addGlobalValue(this, globalConstant); + return globalConstant; + } + + IRGlobalParam* IRBuilder::createGlobalParam( + IRType* valueType) + { + IRGlobalParam* inst = createInst( + this, + kIROp_GlobalParam, + valueType); + maybeSetSourceLoc(this, inst); + addGlobalValue(this, inst); + return inst; + } + + IRWitnessTable* IRBuilder::createWitnessTable() + { + IRWitnessTable* witnessTable = createInst( + this, + kIROp_WitnessTable, + nullptr); + addGlobalValue(this, witnessTable); + return witnessTable; + } + + IRWitnessTableEntry* IRBuilder::createWitnessTableEntry( + IRWitnessTable* witnessTable, + IRInst* requirementKey, + IRInst* satisfyingVal) + { + IRWitnessTableEntry* entry = createInst( + this, + kIROp_WitnessTableEntry, + nullptr, + requirementKey, + satisfyingVal); + + if (witnessTable) + { + entry->insertAtEnd(witnessTable); + } + + return entry; + } + + IRStructType* IRBuilder::createStructType() + { + IRStructType* structType = createInst( + this, + kIROp_StructType, + nullptr); + addGlobalValue(this, structType); + return structType; + } + + IRInterfaceType* IRBuilder::createInterfaceType() + { + IRInterfaceType* interfaceType = createInst( + this, + kIROp_InterfaceType, + nullptr); + addGlobalValue(this, interfaceType); + return interfaceType; + } + + IRStructKey* IRBuilder::createStructKey() + { + IRStructKey* structKey = createInst( + this, + kIROp_StructKey, + nullptr); + addGlobalValue(this, structKey); + return structKey; + } + + // Create a field nested in a struct type, declaring that + // the specified field key maps to a field with the specified type. + IRStructField* IRBuilder::createStructField( + IRStructType* structType, + IRStructKey* fieldKey, + IRType* fieldType) + { + IRInst* operands[] = { fieldKey, fieldType }; + IRStructField* field = (IRStructField*) createInstWithTrailingArgs( + this, + kIROp_StructField, + nullptr, + 0, + nullptr, + 2, + operands); + + if (structType) + { + field->insertAtEnd(structType); + } + + return field; + } + + IRGeneric* IRBuilder::createGeneric() + { + IRGeneric* irGeneric = createInst( + this, + kIROp_Generic, + nullptr); + return irGeneric; + } + + IRGeneric* IRBuilder::emitGeneric() + { + auto irGeneric = createGeneric(); + addGlobalValue(this, irGeneric); + return irGeneric; + } + + IRBlock* IRBuilder::createBlock() + { + return createInst( + this, + kIROp_Block, + getBasicBlockType()); + } + + void IRBuilder::insertBlock(IRBlock* block) + { + // If we are emitting into a function + // (or another value with code), then + // append the block to the function and + // set this block as the new parent for + // subsequent instructions we insert. + // + // TODO: This should probably insert the block + // after the current "insert into" block if + // there is one. Right now we are always + // adding the block to the end of the list, + // which is technically valid (the ordering + // of blocks doesn't affect the CFG topology), + // but some later passes might assume the ordering + // is significant in representing the intent + // of the original code. + // + auto f = getFunc(); + if (f) + { + f->addBlock(block); + setInsertInto(block); + } + } + + IRBlock* IRBuilder::emitBlock() + { + auto block = createBlock(); + insertBlock(block); + return block; + } + + IRParam* IRBuilder::createParam( + IRType* type) + { + auto param = createInst( + this, + kIROp_Param, + type); + return param; + } + + IRParam* IRBuilder::emitParam( + IRType* type) + { + auto param = createParam(type); + if (auto bb = getBlock()) + { + bb->addParam(param); + } + return param; + } + + IRVar* IRBuilder::emitVar( + IRType* type) + { + auto allocatedType = getPtrType(type); + auto inst = createInst( + this, + kIROp_Var, + allocatedType); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitLoad( + IRType* type, + IRInst* ptr) + { + auto inst = createInst( + this, + kIROp_Load, + type, + ptr); + + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitLoad( + IRInst* ptr) + { + // Note: a `load` operation does not consider the rate + // (if any) attached to its operand (see the use of `getDataType` + // below). This means that a load from a rate-qualified + // variable will still conceptually execute (and return + // results) at the "default" rate of the parent function, + // unless a subsequent analysis pass constraints it. + + IRType* valueType = tryGetPointedToType(this, ptr->getFullType()); + SLANG_ASSERT(valueType); + + // Ugly special case: if the front-end created a variable with + // type `Ptr<@R T>` instead of `@R Ptr`, then the above + // logic will yield `@R T` instead of `T`, and we need to + // try and fix that up here. + // + // TODO: Lowering to the IR should be fixed to never create + // that case: rate-qualified types should only be allowed + // to appear as the type of an instruction, and should not + // be allowed as operands to type constructors (except + // in special cases we decide to allow). + // + if(auto rateType = as(valueType)) + { + valueType = rateType->getValueType(); + } + + return emitLoad(valueType, ptr); + } + + IRInst* IRBuilder::emitStore( + IRInst* dstPtr, + IRInst* srcVal) + { + auto inst = createInst( + this, + kIROp_Store, + nullptr, + dstPtr, + srcVal); + + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitFieldExtract( + IRType* type, + IRInst* base, + IRInst* field) + { + auto inst = createInst( + this, + kIROp_FieldExtract, + type, + base, + field); + + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitFieldAddress( + IRType* type, + IRInst* base, + IRInst* field) + { + auto inst = createInst( + this, + kIROp_FieldAddress, + type, + base, + field); + + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitElementExtract( + IRType* type, + IRInst* base, + IRInst* index) + { + auto inst = createInst( + this, + kIROp_getElement, + type, + base, + index); + + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitElementAddress( + IRType* type, + IRInst* basePtr, + IRInst* index) + { + auto inst = createInst( + this, + kIROp_getElementPtr, + type, + basePtr, + index); + + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitSwizzle( + IRType* type, + IRInst* base, + UInt elementCount, + IRInst* const* elementIndices) + { + auto inst = createInstWithTrailingArgs( + this, + kIROp_swizzle, + type, + base, + elementCount, + elementIndices); + + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitSwizzle( + IRType* type, + IRInst* base, + UInt elementCount, + UInt const* elementIndices) + { + auto intType = getBasicType(BaseType::Int); + + IRInst* irElementIndices[4]; + for (UInt ii = 0; ii < elementCount; ++ii) + { + irElementIndices[ii] = getIntValue(intType, elementIndices[ii]); + } + + return emitSwizzle(type, base, elementCount, irElementIndices); + } + + + IRInst* IRBuilder::emitSwizzleSet( + IRType* type, + IRInst* base, + IRInst* source, + UInt elementCount, + IRInst* const* elementIndices) + { + IRInst* fixedArgs[] = { base, source }; + UInt fixedArgCount = sizeof(fixedArgs) / sizeof(fixedArgs[0]); + + auto inst = createInstWithTrailingArgs( + this, + kIROp_swizzleSet, + type, + fixedArgCount, + fixedArgs, + elementCount, + elementIndices); + + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitSwizzleSet( + IRType* type, + IRInst* base, + IRInst* source, + UInt elementCount, + UInt const* elementIndices) + { + auto intType = getBasicType(BaseType::Int); + + IRInst* irElementIndices[4]; + for (UInt ii = 0; ii < elementCount; ++ii) + { + irElementIndices[ii] = getIntValue(intType, elementIndices[ii]); + } + + return emitSwizzleSet(type, base, source, elementCount, irElementIndices); + } + + IRInst* IRBuilder::emitSwizzledStore( + IRInst* dest, + IRInst* source, + UInt elementCount, + IRInst* const* elementIndices) + { + IRInst* fixedArgs[] = { dest, source }; + UInt fixedArgCount = sizeof(fixedArgs) / sizeof(fixedArgs[0]); + + auto inst = createInstImpl( + this, + kIROp_SwizzledStore, + nullptr, + fixedArgCount, + fixedArgs, + elementCount, + elementIndices); + + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitSwizzledStore( + IRInst* dest, + IRInst* source, + UInt elementCount, + UInt const* elementIndices) + { + auto intType = getBasicType(BaseType::Int); + + IRInst* irElementIndices[4]; + for (UInt ii = 0; ii < elementCount; ++ii) + { + irElementIndices[ii] = getIntValue(intType, elementIndices[ii]); + } + + return emitSwizzledStore(dest, source, elementCount, irElementIndices); + } + + IRInst* IRBuilder::emitReturn( + IRInst* val) + { + auto inst = createInst( + this, + kIROp_ReturnVal, + nullptr, + val); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitReturn() + { + auto inst = createInst( + this, + kIROp_ReturnVoid, + nullptr); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitUnreachable() + { + auto inst = createInst( + this, + kIROp_Unreachable, + nullptr); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitMissingReturn() + { + auto inst = createInst( + this, + kIROp_MissingReturn, + nullptr); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitDiscard() + { + auto inst = createInst( + this, + kIROp_discard, + nullptr); + addInst(inst); + return inst; + } + + + IRInst* IRBuilder::emitBranch( + IRBlock* pBlock) + { + auto inst = createInst( + this, + kIROp_unconditionalBranch, + nullptr, + pBlock); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitBreak( + IRBlock* target) + { + return emitBranch(target); + } + + IRInst* IRBuilder::emitContinue( + IRBlock* target) + { + return emitBranch(target); + } + + IRInst* IRBuilder::emitLoop( + IRBlock* target, + IRBlock* breakBlock, + IRBlock* continueBlock) + { + IRInst* args[] = { target, breakBlock, continueBlock }; + UInt argCount = sizeof(args) / sizeof(args[0]); + + auto inst = createInst( + this, + kIROp_loop, + nullptr, + argCount, + args); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitBranch( + IRInst* val, + IRBlock* trueBlock, + IRBlock* falseBlock) + { + IRInst* args[] = { val, trueBlock, falseBlock }; + UInt argCount = sizeof(args) / sizeof(args[0]); + + auto inst = createInst( + this, + kIROp_conditionalBranch, + nullptr, + argCount, + args); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitIfElse( + IRInst* val, + IRBlock* trueBlock, + IRBlock* falseBlock, + IRBlock* afterBlock) + { + IRInst* args[] = { val, trueBlock, falseBlock, afterBlock }; + UInt argCount = sizeof(args) / sizeof(args[0]); + + auto inst = createInst( + this, + kIROp_ifElse, + nullptr, + argCount, + args); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitIf( + IRInst* val, + IRBlock* trueBlock, + IRBlock* afterBlock) + { + return emitIfElse(val, trueBlock, afterBlock, afterBlock); + } + + IRInst* IRBuilder::emitLoopTest( + IRInst* val, + IRBlock* bodyBlock, + IRBlock* breakBlock) + { + return emitIfElse(val, bodyBlock, breakBlock, bodyBlock); + } + + IRInst* IRBuilder::emitSwitch( + IRInst* val, + IRBlock* breakLabel, + IRBlock* defaultLabel, + UInt caseArgCount, + IRInst* const* caseArgs) + { + IRInst* fixedArgs[] = { val, breakLabel, defaultLabel }; + UInt fixedArgCount = sizeof(fixedArgs) / sizeof(fixedArgs[0]); + + auto inst = createInstWithTrailingArgs( + this, + kIROp_Switch, + nullptr, + fixedArgCount, + fixedArgs, + caseArgCount, + caseArgs); + addInst(inst); + return inst; + } + + IRGlobalGenericParam* IRBuilder::emitGlobalGenericParam() + { + IRGlobalGenericParam* irGenericParam = createInst( + this, + kIROp_GlobalGenericParam, + nullptr); + addGlobalValue(this, irGenericParam); + return irGenericParam; + } + + IRBindGlobalGenericParam* IRBuilder::emitBindGlobalGenericParam( + IRInst* param, + IRInst* val) + { + auto inst = createInst( + this, + kIROp_BindGlobalGenericParam, + nullptr, + param, + val); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitBindGlobalExistentialSlots( + UInt argCount, + IRInst* const* args) + { + auto inst = createInstWithTrailingArgs( + this, + kIROp_BindGlobalExistentialSlots, + getVoidType(), + 0, + nullptr, + argCount, + args); + addInst(inst); + return inst; + } + + IRDecoration* IRBuilder::addBindExistentialSlotsDecoration( + IRInst* value, + UInt argCount, + IRInst* const* args) + { + auto decoration = createInstWithTrailingArgs( + this, + kIROp_BindExistentialSlotsDecoration, + getVoidType(), + 0, + nullptr, + argCount, + args); + + decoration->insertAtStart(value); + + return decoration; + } + + IRInst* IRBuilder::emitExtractTaggedUnionTag( + IRInst* val) + { + auto inst = createInst( + this, + kIROp_ExtractTaggedUnionTag, + getBasicType(BaseType::UInt), + val); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitExtractTaggedUnionPayload( + IRType* type, + IRInst* val, + IRInst* tag) + { + auto inst = createInst( + this, + kIROp_ExtractTaggedUnionPayload, + type, + val, + tag); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitBitCast( + IRType* type, + IRInst* val) + { + auto inst = createInst( + this, + kIROp_BitCast, + type, + val); + addInst(inst); + return inst; + } + + // + // Decorations + // + + IRDecoration* IRBuilder::addDecoration(IRInst* value, IROp op, IRInst* const* operands, Int operandCount) + { + auto decoration = createInstWithTrailingArgs( + this, + op, + getVoidType(), + operandCount, + operands); + + // Decoration order should not, in general, be semantically + // meaningful, so we will elect to insert a new decoration + // at the start of an instruction (constant time) rather + // than at the end of any existing list of deocrations + // (which would take time linear in the number of decorations). + // + // TODO: revisit this if maintaining decoration ordering + // from input source code is desirable. + // + decoration->insertAtStart(value); + + return decoration; + } + + + void IRBuilder::addHighLevelDeclDecoration(IRInst* inst, Decl* decl) + { + auto ptrConst = getPtrValue(addRefObjectToFree(decl)); + addDecoration(inst, kIROp_HighLevelDeclDecoration, ptrConst); + } + + void IRBuilder::addLayoutDecoration(IRInst* inst, Layout* layout) + { + auto ptrConst = getPtrValue(addRefObjectToFree(layout)); + addDecoration(inst, kIROp_LayoutDecoration, ptrConst); + } + + // + + + struct IRDumpContext + { + StringBuilder* builder = nullptr; + int indent = 0; + IRDumpMode mode = IRDumpMode::Simplified; + + Dictionary mapValueToName; + Dictionary uniqueNameCounters; + UInt uniqueIDCounter = 1; + }; + + static void dump( + IRDumpContext* context, + char const* text) + { + context->builder->append(text); + } + + static void dump( + IRDumpContext* context, + String const& text) + { + context->builder->append(text); + } + + /* + static void dump( + IRDumpContext* context, + UInt val) + { + context->builder->append(val); + } + */ + + static void dump( + IRDumpContext* context, + IntegerLiteralValue val) + { + context->builder->append(val); + } + + static void dump( + IRDumpContext* context, + FloatingPointLiteralValue val) + { + context->builder->append(val); + } + + static void dumpIndent( + IRDumpContext* context) + { + for (int ii = 0; ii < context->indent; ++ii) + { + dump(context, "\t"); + } + } + + bool opHasResult(IRInst* inst) + { + auto type = inst->getDataType(); + if (!type) return false; + + // As a bit of a hack right now, we need to check whether + // the function returns the distinguished `Void` type, + // since that is conceptually the same as "not returning + // a value." + if(type->op == kIROp_VoidType) + return false; + + return true; + } + + bool instHasUses(IRInst* inst) + { + return inst->firstUse != nullptr; + } + + static void scrubName( + String const& name, + StringBuilder& sb) + { + // Note: this function duplicates a lot of the logic + // in `EmitVisitor::scrubName`, so we should consider + // whether they can share code at some point. + // + // There is no requirement that assembly dumps and output + // code follow the same model, though, so this is just + // a nice-to-have rather than a maintenance problem + // waiting to happen. + + // Allow an empty nam + // Special case a name that is the empty string, just in case. + if(name.getLength() == 0) + { + sb.append('_'); + } + + int prevChar = -1; + for(auto c : name) + { + if(c == '.') + { + c = '_'; + } + + if(((c >= 'a') && (c <= 'z')) + || ((c >= 'A') && (c <= 'Z'))) + { + // Ordinary ASCII alphabetic characters are assumed + // to always be okay. + } + else if((c >= '0') && (c <= '9')) + { + // We don't want to allow a digit as the first + // byte in a name. + if(prevChar == -1) + { + sb.append('_'); + } + } + else + { + // If we run into a character that wouldn't normally + // be allowed in an identifier, we need to translate + // it into something that *is* valid. + // + // Our solution for now will be very clumsy: we will + // emit `x` and then the hexadecimal version of + // the byte we were given. + sb.append("x"); + sb.append(uint32_t((unsigned char) c), 16); + + // We don't want to apply the default handling below, + // so skip to the top of the loop now. + prevChar = c; + continue; + } + + sb.append(c); + prevChar = c; + } + + // If the whole thing ended with a digit, then add + // a final `_` just to make sure that we can append + // a unique ID suffix without risk of collisions. + if(('0' <= prevChar) && (prevChar <= '9')) + { + sb.append('_'); + } + } + + static String createName( + IRDumpContext* context, + IRInst* value) + { + if(auto nameHintDecoration = value->findDecoration()) + { + String nameHint = nameHintDecoration->getName(); + + StringBuilder sb; + scrubName(nameHint, sb); + + String key = sb.ProduceString(); + UInt count = 0; + context->uniqueNameCounters.TryGetValue(key, count); + + context->uniqueNameCounters[key] = count+1; + + if(count) + { + sb.append(count); + } + return sb.ProduceString(); + } + else + { + StringBuilder sb; + auto id = context->uniqueIDCounter++; + sb.append(id); + return sb.ProduceString(); + } + } + + static String getName( + IRDumpContext* context, + IRInst* value) + { + String name; + if (context->mapValueToName.TryGetValue(value, name)) + return name; + + name = createName(context, value); + context->mapValueToName.Add(value, name); + return name; + } + + static void dumpID( + IRDumpContext* context, + IRInst* inst) + { + if (!inst) + { + dump(context, ""); + return; + } + + if( opHasResult(inst) || instHasUses(inst) ) + { + dump(context, "%"); + dump(context, getName(context, inst)); + } + else + { + dump(context, "_"); + } + } + + + + struct StringEncoder + { + static char getHexChar(int v) + { + return (v <= 9) ? char(v + '0') : char(v - 10 + 'A'); + } + + void flush(const char* pos) + { + if (pos > m_runStart) + { + m_builder->append(m_runStart, pos); + } + m_runStart = pos + 1; + } + + void appendEscapedChar(const char* pos, char encodeChar) + { + flush(pos); + const char chars[] = { '\\', encodeChar }; + m_builder->Append(chars, 2); + } + + void appendAsHex(const char* pos) + { + flush(pos); + + const int v = *(const uint8_t*)pos; + + char buf[5]; + buf[0] = '\\'; + buf[1] = 'x'; + buf[2] = '0'; + + buf[3] = getHexChar(v >> 4); + buf[4] = getHexChar(v & 0xf); + + m_builder->Append(buf, 5); + } + + StringEncoder(StringBuilder* builder, const char* start): + m_runStart(start), + m_builder(builder) + {} + + StringBuilder* m_builder; + const char* m_runStart; + }; + + static void dumpEncodeString( + IRDumpContext* context, + const UnownedStringSlice& slice) + { + // https://msdn.microsoft.com/en-us/library/69ze775t.aspx + + StringBuilder& builder = *context->builder; + builder.Append('"'); + + { + const char* cur = slice.begin(); + StringEncoder encoder(&builder, cur); + const char* end = slice.end(); + + for (; cur < end; cur++) + { + const int8_t c = uint8_t(*cur); + switch (c) + { + case '\\': + encoder.appendEscapedChar(cur, '\\'); + break; + case '"': + encoder.appendEscapedChar(cur, '"'); + break; + case '\n': + encoder.appendEscapedChar(cur, 'n'); + break; + case '\t': + encoder.appendEscapedChar(cur, 't'); + break; + case '\r': + encoder.appendEscapedChar(cur, 'r'); + break; + case '\0': + encoder.appendEscapedChar(cur, '0'); + break; + default: + { + if (c < 32) + { + encoder.appendAsHex(cur); + } + break; + } + } + } + encoder.flush(end); + } + + builder.Append('"'); + } + + static void dumpType( + IRDumpContext* context, + IRType* type); + + static bool shouldFoldInstIntoUses( + IRDumpContext* context, + IRInst* inst) + { + // Never fold an instruction into its use site + // in the "detailed" mode, so that we always + // accurately reflect the structure of the IR. + // + if(context->mode == IRDumpMode::Detailed) + return false; + + if(as(inst)) + return true; + + // We are going to have a general rule that + // a type should be folded into its use site, + // which improves output in most cases, but + // we would like to not apply that rule to + // "nominal" types like `struct`s. + // + switch( inst->op ) + { + case kIROp_StructType: + case kIROp_InterfaceType: + return false; + + default: + break; + } + + if(as(inst)) + return true; + + return false; + } + + static void dumpInst( + IRDumpContext* context, + IRInst* inst); + + static void dumpInstBody( + IRDumpContext* context, + IRInst* inst); + + static void dumpInstExpr( + IRDumpContext* context, + IRInst* inst); + + static void dumpOperand( + IRDumpContext* context, + IRInst* inst) + { + // TODO: we should have a dedicated value for the `undef` case + if (!inst) + { + dumpID(context, inst); + return; + } + + if(shouldFoldInstIntoUses(context, inst)) + { + dumpInstExpr(context, inst); + return; + } + + dumpID(context, inst); + } + + static void dumpType( + IRDumpContext* context, + IRType* type) + { + if (!type) + { + dump(context, "_"); + return; + } + + // TODO: we should consider some special-case printing + // for types, so that the IR doesn't get too hard to read + // (always having to back-reference for what a type expands to) + dumpOperand(context, type); + } + + static void dumpInstTypeClause( + IRDumpContext* context, + IRType* type) + { + dump(context, "\t: "); + dumpType(context, type); + + } + + void dumpIRDecorations( + IRDumpContext* context, + IRInst* inst) + { + for(auto dd : inst->getDecorations()) + { + // Certain decorations aren't helpful to appear + // in output dumps, so we will only include them + // in the "detailed" dumping mode. + // + // For all other modes, we will check the opcode + // and skip selected decorations. + // + if(context->mode != IRDumpMode::Detailed) + { + switch(dd->op) + { + default: + break; + + case kIROp_HighLevelDeclDecoration: + case kIROp_LayoutDecoration: + continue; + } + } + + dump(context, "["); + dumpInstBody(context, dd); + dump(context, "]\n"); + + dumpIndent(context); + } + } + + static void dumpBlock( + IRDumpContext* context, + IRBlock* block) + { + context->indent--; + dump(context, "block "); + dumpID(context, block); + + IRInst* inst = block->getFirstInst(); + + // First walk through any `param` instructions, + // so that we can format them nicely + if (auto firstParam = as(inst)) + { + dump(context, "(\n"); + context->indent += 2; + + for(;;) + { + auto param = as(inst); + if (!param) + break; + + if (param != firstParam) + dump(context, ",\n"); + + inst = inst->getNextInst(); + + dumpIndent(context); + dumpIRDecorations(context, param); + dump(context, "param "); + dumpID(context, param); + dumpInstTypeClause(context, param->getFullType()); + } + context->indent -= 2; + dump(context, ")"); + } + dump(context, ":\n"); + context->indent++; + + for(; inst; inst = inst->getNextInst()) + { + dumpInst(context, inst); + } + } + + void dumpIRGlobalValueWithCode( + IRDumpContext* context, + IRGlobalValueWithCode* code) + { + auto opInfo = getIROpInfo(code->op); + + dumpIndent(context); + dump(context, opInfo.name); + dump(context, " "); + dumpID(context, code); + + dumpInstTypeClause(context, code->getFullType()); + + if (!code->getFirstBlock()) + { + // Just a declaration. + dump(context, ";\n"); + return; + } + + dump(context, "\n"); + + dumpIndent(context); + dump(context, "{\n"); + context->indent++; + + for (auto bb = code->getFirstBlock(); bb; bb = bb->getNextBlock()) + { + if (bb != code->getFirstBlock()) + dump(context, "\n"); + dumpBlock(context, bb); + } + + context->indent--; + dump(context, "}"); + } + + + void dumpIRWitnessTableEntry( + IRDumpContext* context, + IRWitnessTableEntry* entry) + { + dump(context, "witness_table_entry("); + dumpOperand(context, entry->requirementKey.get()); + dump(context, ","); + dumpOperand(context, entry->satisfyingVal.get()); + dump(context, ")\n"); + } + + void dumpIRParentInst( + IRDumpContext* context, + IRInst* inst) + { + auto opInfo = getIROpInfo(inst->op); + + dumpIndent(context); + dump(context, opInfo.name); + dump(context, " "); + dumpID(context, inst); + + dumpInstTypeClause(context, inst->getFullType()); + + if (!inst->getFirstChild()) + { + // Empty. + dump(context, ";\n"); + return; + } + + dump(context, "\n"); + + dumpIndent(context); + dump(context, "{\n"); + context->indent++; + + for(auto child : inst->getChildren()) + { + dumpInst(context, child); + } + + context->indent--; + dump(context, "}\n"); + } + + void dumpIRGeneric( + IRDumpContext* context, + IRGeneric* witnessTable) + { + dump(context, "\n"); + dumpIndent(context); + dump(context, "ir_witness_table "); + dumpID(context, witnessTable); + dump(context, "\n{\n"); + context->indent++; + + for (auto ii : witnessTable->getChildren()) + { + dumpInst(context, ii); + } + + context->indent--; + dump(context, "}\n"); + } + + static void dumpInstExpr( + IRDumpContext* context, + IRInst* inst) + { + if (!inst) + { + dump(context, ""); + return; + } + + auto op = inst->op; + auto opInfo = getIROpInfo(op); + + // Special-case the literal instructions. + if(auto irConst = as(inst)) + { + switch (op) + { + case kIROp_IntLit: + dump(context, irConst->value.intVal); + return; + + case kIROp_FloatLit: + dump(context, irConst->value.floatVal); + return; + + case kIROp_BoolLit: + dump(context, irConst->value.intVal ? "true" : "false"); + return; + + case kIROp_StringLit: + dumpEncodeString(context, irConst->getStringSlice()); + return; + + case kIROp_PtrLit: + dump(context, ""); + return; + + default: + break; + } + } + + dump(context, opInfo.name); + + UInt argCount = inst->getOperandCount(); + + if(argCount == 0) + return; + + UInt ii = 0; + + // Special case: make printing of `call` a bit + // nicer to look at + if (inst->op == kIROp_Call && argCount > 0) + { + dump(context, " "); + auto argVal = inst->getOperand(ii++); + dumpOperand(context, argVal); + } + + bool first = true; + dump(context, "("); + for (; ii < argCount; ++ii) + { + if (!first) + dump(context, ", "); + + auto argVal = inst->getOperand(ii); + + dumpOperand(context, argVal); + + first = false; + } + + dump(context, ")"); + + } + + static void dumpInstBody( + IRDumpContext* context, + IRInst* inst) + { + if (!inst) + { + dump(context, ""); + return; + } + + auto op = inst->op; + + dumpIRDecorations(context, inst); + + // There are several ops we want to special-case here, + // so that they will be more pleasant to look at. + // + switch (op) + { + case kIROp_Func: + case kIROp_GlobalVar: + case kIROp_GlobalConstant: + case kIROp_Generic: + dumpIRGlobalValueWithCode(context, (IRGlobalValueWithCode*)inst); + return; + + case kIROp_WitnessTable: + case kIROp_StructType: + dumpIRParentInst(context, inst); + return; + + case kIROp_WitnessTableEntry: + dumpIRWitnessTableEntry(context, (IRWitnessTableEntry*)inst); + return; + + default: + break; + } + + // Okay, we have a seemingly "ordinary" op now + auto dataType = inst->getDataType(); + auto rate = inst->getRate(); + + if(rate) + { + dump(context, "@"); + dumpOperand(context, rate); + dump(context, " "); + } + + if(opHasResult(inst) || instHasUses(inst)) + { + dump(context, "let "); + dumpID(context, inst); + dumpInstTypeClause(context, dataType); + dump(context, "\t= "); + } + else + { + // No result, okay... + } + + dumpInstExpr(context, inst); + } + + static void dumpInst( + IRDumpContext* context, + IRInst* inst) + { + if(shouldFoldInstIntoUses(context, inst)) + return; + + dumpIndent(context); + dumpInstBody(context, inst); + dump(context, "\n"); + } + + void dumpIRModule( + IRDumpContext* context, + IRModule* module) + { + for(auto ii : module->getGlobalInsts()) + { + dumpInst(context, ii); + } + } + + void printSlangIRAssembly(StringBuilder& builder, IRModule* module, IRDumpMode mode) + { + IRDumpContext context; + context.builder = &builder; + context.indent = 0; + context.mode = mode; + + dumpIRModule(&context, module); + } + + void dumpIR(IRInst* globalVal, ISlangWriter* writer, IRDumpMode mode) + { + StringBuilder sb; + + IRDumpContext context; + context.builder = &sb; + context.indent = 0; + context.mode = mode; + + dumpInst(&context, globalVal); + + writer->write(sb.getBuffer(), sb.getLength()); + writer->flush(); + } + + String getSlangIRAssembly(IRModule* module, IRDumpMode mode) + { + StringBuilder sb; + printSlangIRAssembly(sb, module, mode); + return sb; + } + + void dumpIR(IRModule* module, ISlangWriter* writer, IRDumpMode mode) + { + String ir = getSlangIRAssembly(module, mode); + writer->write(ir.getBuffer(), ir.getLength()); + writer->flush(); + } + + // Pre-declare + static bool _isTypeOperandEqual(IRInst* a, IRInst* b); + + static bool _areTypeOperandsEqual(IRInst* a, IRInst* b) + { + // Must have same number of operands + const auto operandCountA = Index(a->getOperandCount()); + if (operandCountA != Index(b->getOperandCount())) + { + return false; + } + + // All the operands must be equal + for (Index i = 0; i < operandCountA; ++i) + { + IRInst* operandA = a->getOperand(i); + IRInst* operandB = b->getOperand(i); + + if (!_isTypeOperandEqual(operandA, operandB)) + { + return false; + } + } + + return true; + } + + static bool _isNominalOp(IROp op) + { + // True if the op identity is 'nominal' + switch (op) + { + case kIROp_StructType: + case kIROp_InterfaceType: + case kIROp_Generic: + case kIROp_Param: + { + return true; + } + } + return false; + } + + // True if a type operand is equal. Operands are 'IRInst' - but it's only a restricted set that + // can be operands of IRType instructions + static bool _isTypeOperandEqual(IRInst* a, IRInst* b) + { + if (a == b) + { + return true; + } + + if (a == nullptr || b == nullptr) + { + return false; + } + + const IROp opA = IROp(a->op & kIROpMeta_PseudoOpMask); + const IROp opB = IROp(b->op & kIROpMeta_PseudoOpMask); + + if (opA != opB) + { + return false; + } + + // If the type is nominal - it can only be the same if the pointer is the same. + if (_isNominalOp(opA)) + { + // The pointer isn't the same (as that was already tested), so cannot be equal + return false; + } + + // Both are types + if (IRType::isaImpl(opA)) + { + if (IRBasicType::isaImpl(opA)) + { + // If it's a basic type, then their op being the same means we are done + return true; + } + + // We don't care about the parent or positioning + // We also don't care about 'type' - because these instructions are defining the type. + // + // We may want to care about decorations. + + // If it's a resource type - special case the handling of the resource flavor + if (IRResourceTypeBase::isaImpl(opA) && + static_cast(a)->getFlavor() != static_cast(b)->getFlavor()) + { + return false; + } + + // TODO(JS): There is a question here about what to do about decorations. + // For now we ignore decorations. Are two types potentially different if there decorations different? + // If decorations play a part in difference in types - the order of decorations presumably is not important. + + // All the operands of the types must be equal + return _areTypeOperandsEqual(a, b); + } + + // If it's a constant... + if (IRConstant::isaImpl(opA)) + { + // TODO: This is contrived in that we want two types that are the same, but are different + // pointers to match here. + // If we make GetHashCode for IRType* compatible with isTypeEqual, then we should probably use that. + return static_cast(a)->isValueEqual(static_cast(b)) && + isTypeEqual(a->getFullType(), b->getFullType()); + } + + SLANG_ASSERT(!"Unhandled comparison"); + + // We can't equate any other type.. + return false; + } + + bool isTypeEqual(IRType* a, IRType* b) + { + // _isTypeOperandEqual handles comparison of types so can defer to it + return _isTypeOperandEqual(a, b); + } + + void findAllInstsBreadthFirst(IRInst* inst, List& outInsts) + { + Index index = outInsts.getCount(); + + outInsts.add(inst); + + while (index < outInsts.getCount()) + { + IRInst* cur = outInsts[index++]; + + IRInstListBase childrenList = cur->getDecorationsAndChildren(); + for (IRInst* child : childrenList) + { + outInsts.add(child); + } + } + } + + IRDecoration* IRInst::getFirstDecoration() + { + return as(getFirstDecorationOrChild()); + } + + IRDecoration* IRInst::getLastDecoration() + { + IRDecoration* decoration = getFirstDecoration(); + if (!decoration) return nullptr; + + while (auto nextDecoration = decoration->getNextDecoration()) + decoration = nextDecoration; + + return decoration; + } + + IRInstList IRInst::getDecorations() + { + return IRInstList( + getFirstDecoration(), + getLastDecoration()); + } + + IRInst* IRInst::getFirstChild() + { + // The children come after any decorations, + // so if there are any decorations, then the + // first child is right after the last decoration. + // + if(auto lastDecoration = getLastDecoration()) + return lastDecoration->getNextInst(); + // + // Otherwise, there must be no decorations, so + // that the first "child or decoration" is a child. + // + return getFirstDecorationOrChild(); + } + + IRInst* IRInst::getLastChild() + { + // The children come after any decorations, so + // that the last item in the list of children + // and decorations is the last child *unless* + // it is a decoration, in which case there are + // no children. + // + auto lastChild = getLastDecorationOrChild(); + return as(lastChild) ? nullptr : lastChild; + } + + + IRRate* IRInst::getRate() + { + if(auto rateQualifiedType = as(getFullType())) + return rateQualifiedType->getRate(); + + return nullptr; + } + + IRType* IRInst::getDataType() + { + auto type = getFullType(); + if(auto rateQualifiedType = as(type)) + return rateQualifiedType->getValueType(); + + return type; + } + + void IRInst::replaceUsesWith(IRInst* other) + { + // Safety check: don't try to replace something with itself. + if(other == this) + return; + + // We will walk through the list of uses for the current + // instruction, and make them point to the other inst. + IRUse* ff = firstUse; + + // No uses? Nothing to do. + if(!ff) + return; + + ff->debugValidate(); + + IRUse* uu = ff; + for(;;) + { + // The uses had better all be uses of this + // instruction, or invariants are broken. + SLANG_ASSERT(uu->get() == this); + + // Swap this use over to use the other value. + uu->usedValue = other; + + // Try to move to the next use, but bail + // out if we are at the last one. + IRUse* nn = uu->nextUse; + if( !nn ) + break; + + uu = nn; + } + + // We are at the last use (and there must + // be at least one, because we handled + // the case of an empty list earlier). + SLANG_ASSERT(uu); + + // Our job at this point is to splice + // our list of uses onto the other + // value's uses. + // + // If the value already had uses, then + // we need to patch our new list onto + // the front. + if( auto nn = other->firstUse ) + { + uu->nextUse = nn; + nn->prevLink = &uu->nextUse; + } + + // No matter what, our list of + // uses will become the start + // of the list of uses for + // `other` + other->firstUse = ff; + ff->prevLink = &other->firstUse; + + // And `this` will have no uses any more. + this->firstUse = nullptr; + + ff->debugValidate(); + } + + // Insert this instruction into the same basic block + // as `other`, right before it. + void IRInst::insertBefore(IRInst* other) + { + SLANG_ASSERT(other); + _insertAt(other->getPrevInst(), other, other->getParent()); + } + + void IRInst::insertAtStart(IRInst* newParent) + { + SLANG_ASSERT(newParent); + _insertAt(nullptr, newParent->getFirstDecorationOrChild(), newParent); + } + + void IRInst::moveToStart() + { + auto p = parent; + removeFromParent(); + insertAtStart(p); + } + + void IRInst::_insertAt(IRInst* inPrev, IRInst* inNext, IRInst* inParent) + { + // Make sure this instruction has been removed from any previous parent + this->removeFromParent(); + + SLANG_ASSERT(inParent); + SLANG_ASSERT(!inPrev || (inPrev->getNextInst() == inNext) && (inPrev->getParent() == inParent)); + SLANG_ASSERT(!inNext || (inNext->getPrevInst() == inPrev) && (inNext->getParent() == inParent)); + + if( inPrev ) + { + inPrev->next = this; + } + else + { + inParent->m_decorationsAndChildren.first = this; + } + + if (inNext) + { + inNext->prev = this; + } + else + { + inParent->m_decorationsAndChildren.last = this; + } + + this->prev = inPrev; + this->next = inNext; + this->parent = inParent; + } + + void IRInst::insertAfter(IRInst* other) + { + SLANG_ASSERT(other); + removeFromParent(); + _insertAt(other, other->getNextInst(), other->getParent()); + } + + void IRInst::insertAtEnd(IRInst* newParent) + { + SLANG_ASSERT(newParent); + removeFromParent(); + _insertAt(newParent->getLastDecorationOrChild(), nullptr, newParent); + } + + void IRInst::moveToEnd() + { + auto p = parent; + removeFromParent(); + insertAtEnd(p); + } + + // Remove this instruction from its parent block, + // and then destroy it (it had better have no uses!) + void IRInst::removeFromParent() + { + auto oldParent = getParent(); + + // If we don't currently have a parent, then + // we are doing fine. + if(!oldParent) + return; + + auto pp = getPrevInst(); + auto nn = getNextInst(); + + if(pp) + { + SLANG_ASSERT(pp->getParent() == oldParent); + pp->next = nn; + } + else + { + oldParent->m_decorationsAndChildren.first = nn; + } + + if(nn) + { + SLANG_ASSERT(nn->getParent() == oldParent); + nn->prev = pp; + } + else + { + oldParent->m_decorationsAndChildren.last = pp; + } + + prev = nullptr; + next = nullptr; + parent = nullptr; + } + + void IRInst::removeArguments() + { + typeUse.clear(); + for( UInt aa = 0; aa < operandCount; ++aa ) + { + IRUse& use = getOperands()[aa]; + use.clear(); + } + } + + // Remove this instruction from its parent block, + // and then destroy it (it had better have no uses!) + void IRInst::removeAndDeallocate() + { + removeFromParent(); + removeArguments(); + removeAndDeallocateAllDecorationsAndChildren(); + + // Run destructor to be sure... + this->~IRInst(); + } + + void IRInst::removeAndDeallocateAllDecorationsAndChildren() + { + IRInst* nextChild = nullptr; + for( IRInst* child = getFirstDecorationOrChild(); child; child = nextChild ) + { + nextChild = child->getNextInst(); + child->removeAndDeallocate(); + } + } + + void IRInst::transferDecorationsTo(IRInst* target) + { + while( auto decoration = getFirstDecoration() ) + { + decoration->removeFromParent(); + decoration->insertAtStart(target); + } + } + + bool IRInst::mightHaveSideEffects() + { + // TODO: We should drive this based on flags specified + // in `ir-inst-defs.h` isntead of hard-coding things here, + // but this is good enough for now if we are conservative: + + if(as(this)) + return false; + + if(as(this)) + return false; + + switch(op) + { + // By default, assume that we might have side effects, + // to safely cover all the instructions we haven't had time to think about. + default: + return true; + + case kIROp_Call: + { + // In the general case, a function call must be assumed to + // have almost arbitrary side effects. + // + // However, it is possible that the callee can be identified, + // and it may be a function with an attribute that explicitly + // limits the side effects it is allowed to have. + // + // For now, we will explicitly check for the `[__readNone]` + // attribute, which was used to mark functions that compute + // their result strictly as a function of the arguments (and + // not anything they point to, or other non-argument state). + // Calls to such functions cannot have side effects (except + // for things like stack overflow that abstract language models + // tend to ignore), and can be subject to dead code elimination, + // common subexpression elimination, etc. + // + auto call = cast(this); + auto callee = getResolvedInstForDecorations(call->getCallee()); + if(callee->findDecoration()) + { + return false; + } + } + return true; + + // All of the cases for "global values" are side-effect-free. + case kIROp_StructType: + case kIROp_StructField: + case kIROp_Func: + case kIROp_Generic: + case kIROp_GlobalVar: + case kIROp_GlobalConstant: + case kIROp_GlobalParam: + case kIROp_StructKey: + case kIROp_GlobalGenericParam: + case kIROp_WitnessTable: + case kIROp_WitnessTableEntry: + case kIROp_Block: + return false; + + case kIROp_Nop: + case kIROp_Specialize: + case kIROp_lookup_interface_method: + case kIROp_Construct: + case kIROp_makeVector: + case kIROp_MakeMatrix: + case kIROp_makeArray: + case kIROp_makeStruct: + case kIROp_Load: // We are ignoring the possibility of loads from bad addresses, or `volatile` loads + case kIROp_FieldExtract: + case kIROp_FieldAddress: + case kIROp_getElement: + case kIROp_getElementPtr: + case kIROp_constructVectorFromScalar: + case kIROp_swizzle: + case kIROp_swizzleSet: // Doesn't actually "set" anything - just returns the resulting vector + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + //case kIROp_Div: // TODO: We could split out integer vs. floating-point div/mod and assume the floating-point cases have no side effects + //case kIROp_Mod: + case kIROp_Lsh: + case kIROp_Rsh: + case kIROp_Eql: + case kIROp_Neq: + case kIROp_Greater: + case kIROp_Less: + case kIROp_Geq: + case kIROp_Leq: + case kIROp_BitAnd: + case kIROp_BitXor: + case kIROp_BitOr: + case kIROp_And: + case kIROp_Or: + case kIROp_Neg: + case kIROp_Not: + case kIROp_BitNot: + case kIROp_Select: + case kIROp_Dot: + case kIROp_Mul_Vector_Matrix: + case kIROp_Mul_Matrix_Vector: + case kIROp_Mul_Matrix_Matrix: + case kIROp_MakeExistential: + case kIROp_ExtractExistentialType: + case kIROp_ExtractExistentialValue: + case kIROp_ExtractExistentialWitnessTable: + case kIROp_WrapExistential: + return false; + } + } + + IRModule* IRInst::getModule() + { + IRInst* ii = this; + while(ii) + { + if(auto moduleInst = as(ii)) + return moduleInst->module; + + ii = ii->getParent(); + } + return nullptr; + } + + // + // IRType + // + + IRType* unwrapArray(IRType* type) + { + IRType* t = type; + while( auto arrayType = as(t) ) + { + t = arrayType->getElementType(); + } + return t; + } + + IRTargetIntrinsicDecoration* findTargetIntrinsicDecoration( + IRInst* val, + String const& targetName) + { + for(auto dd : val->getDecorations()) + { + if(dd->op != kIROp_TargetIntrinsicDecoration) + continue; + + auto decoration = (IRTargetIntrinsicDecoration*) dd; + if(String(decoration->getTargetName()) == targetName) + return decoration; + } + + return nullptr; + } + +#if 0 + IRFunc* cloneSimpleFuncWithoutRegistering(IRSpecContextBase* context, IRFunc* originalFunc) + { + auto clonedFunc = context->builder->createFunc(); + cloneFunctionCommon(context, clonedFunc, originalFunc, false); + return clonedFunc; + } +#endif + + IRInst* findGenericReturnVal(IRGeneric* generic) + { + auto lastBlock = generic->getLastBlock(); + if (!lastBlock) + return nullptr; + + auto returnInst = as(lastBlock->getTerminator()); + if (!returnInst) + return nullptr; + + auto val = returnInst->getVal(); + return val; + } + + IRInst* getResolvedInstForDecorations(IRInst* inst) + { + IRInst* candidate = inst; + while(auto specInst = as(candidate)) + { + auto genericInst = as(specInst->getBase()); + if(!genericInst) + break; + + auto returnVal = findGenericReturnVal(genericInst); + if(!returnVal) + break; + + candidate = returnVal; + } + return candidate; + } + + bool isDefinition( + IRInst* inVal) + { + IRInst* val = inVal; + // unwrap any generic declarations to see + // the value they return. + for(;;) + { + auto genericInst = as(val); + if(!genericInst) + break; + + auto returnVal = findGenericReturnVal(genericInst); + if(!returnVal) + break; + + val = returnVal; + } + + // TODO: the logic here should probably + // be that anything with an `IRImportDecoration` + // is considered to be a declaration rather than definition. + + switch (val->op) + { + case kIROp_WitnessTable: + case kIROp_GlobalConstant: + case kIROp_Func: + case kIROp_Generic: + return val->getFirstChild() != nullptr; + + case kIROp_StructType: + case kIROp_GlobalVar: + case kIROp_GlobalParam: + return true; + + default: + return false; + } + } + + void markConstExpr( + IRBuilder* builder, + IRInst* irValue) + { + // We will take an IR value with type `T`, + // and turn it into one with type `@ConstExpr T`. + + // TODO: need to be careful if the value already has a rate + // qualifier set. + + irValue->setFullType( + builder->getRateQualifiedType( + builder->getConstExprRate(), + irValue->getDataType())); + } +} diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h new file mode 100644 index 000000000..956d97ad3 --- /dev/null +++ b/source/slang/slang-ir.h @@ -0,0 +1,1202 @@ +// slang-ir.h +#ifndef SLANG_IR_H_INCLUDED +#define SLANG_IR_H_INCLUDED + +// This file defines the intermediate representation (IR) used for Slang +// shader code. This is a typed static single assignment (SSA) IR, +// similar in spirit to LLVM (but much simpler). +// + +#include "../core/slang-basic.h" + +#include "slang-source-loc.h" + +#include "../core/slang-memory-arena.h" +#include "../core/slang-object-scope-manager.h" + +#include "slang-type-system-shared.h" + +namespace Slang { + +class Decl; +class GenericDecl; +class FuncType; +class Layout; +class Type; +class Session; +class Name; +struct IRBuilder; +struct IRFunc; +struct IRGlobalValueWithCode; +struct IRInst; +struct IRModule; + +typedef unsigned int IROpFlags; +enum : IROpFlags +{ + kIROpFlags_None = 0, + kIROpFlag_Parent = 1 << 0, ///< This op is a parent op + kIROpFlag_UseOther = 1 << 1, ///< If set this op can use 'other bits' to store information +}; + +/* Bit usage of IROp is a follows + + MainOp | Pseudo | Other +Bit range: 0-7 | 8 | Remaining bits + +If an instruction is 'pseudo' (ie shouldn't appear in output IR), then the Pseudo bit is set - and 'Invalid' falls into +this category as well as all pseudo ops. +For doing range checks (for example for doing isa tests), the value is masked by kIROpMeta_OpMask, such that the Other bits don't interfere. +The other bits can be used for storage for anything that needs to identify as a different 'op' or 'type'. It is currently +used currently for storing the TextureFlavor of a IRResourceTypeBase derived types for example. +*/ +enum IROp : int32_t +{ +#define INST(ID, MNEMONIC, ARG_COUNT, FLAGS) \ + kIROp_##ID, + +#include "slang-ir-inst-defs.h" + + kIROpCount, + + // We use the range 0x100 to 0x1ff set for pseudo/non main codes + // Instructions that should not appear in valid IR. + + kIROp_Invalid = 0x100, ///< If bit set, then in pseudo/not normal space + kIRPseudoOp_First = kIROp_Invalid, + +#define INST(ID, MNEMONIC, ARG_COUNT, FLAGS) /* empty */ +#define PSEUDO_INST(ID) kIRPseudoOp_##ID, + + kIRPseudoOp_LastPlusOne, + +#include "slang-ir-inst-defs.h" + +#define INST(ID, MNEMONIC, ARG_COUNT, FLAGS) /* empty */ +#define INST_RANGE(BASE, FIRST, LAST) \ + kIROp_First##BASE = kIROp_##FIRST, \ + kIROp_Last##BASE = kIROp_##LAST, + +#include "slang-ir-inst-defs.h" +}; + +/* IROpMeta describe values for layout of IROp, as well as values for accessing aspects of IROp bits. */ +enum IROpMeta +{ + kIROpMeta_OtherShift = 9, ///< Number of bits for op/pseudo ops (shift right by this to get the other bits) + kIROpMeta_PseudoOpMask = (int32_t(1) << kIROpMeta_OtherShift) - 1, ///< Mask for ops including pseudo ops + kIROpMeta_OpMask = 0xff, ///< Mask for just ops + kIrOpMeta_OtherMask = ~kIROpMeta_PseudoOpMask, ///< Mask for bits that can be used for other purposes than 'op' ('other' bits) + kIROpMeta_IsPseudoOp = kIROp_Invalid, ///< 'And' with op, if set, the op is a pseudo op +}; + +// True if op is pseudo (or invalid which is 'pseudo-like' at least in as so far as current behavior) +SLANG_FORCE_INLINE bool isPseudoOp(IROp op) { return (op & kIROpMeta_IsPseudoOp) != 0; } + +IROp findIROp(const UnownedStringSlice& name); + +// A logical operation/opcode in the IR +struct IROpInfo +{ + // What is the name/mnemonic for this operation + char const* name; + + // How many required arguments are there + // (not including the mandatory type argument) + unsigned int fixedArgCount; + + // Flags to control how we emit additional info + IROpFlags flags; +}; + +// Look up the info for an op +IROpInfo getIROpInfo(IROp op); + +// A use of another value/inst within an IR operation +struct IRUse +{ + IRInst* get() const { return usedValue; } + IRInst* getUser() const { return user; } + + void init(IRInst* user, IRInst* usedValue); + void set(IRInst* usedValue); + void clear(); + + // The instruction that is being used + IRInst* usedValue = nullptr; + + // The instruction that is doing the using. + IRInst* user = nullptr; + + // The next use of the same value + IRUse* nextUse = nullptr; + + // A "link" back to where this use is referenced, + // so that we can simplify updates. + IRUse** prevLink = nullptr; + + void debugValidate(); +}; + +struct IRBlock; +struct IRDecoration; +struct IRRate; +struct IRType; + +// A double-linked list of instruction +struct IRInstListBase +{ + IRInstListBase() + {} + + IRInstListBase(IRInst* first, IRInst* last) + : first(first) + , last(last) + {} + + + + IRInst* first = nullptr; + IRInst* last = nullptr; + + IRInst* getFirst() { return first; } + IRInst* getLast() { return last; } + + struct Iterator + { + IRInst* inst; + + Iterator() : inst(nullptr) {} + Iterator(IRInst* inst) : inst(inst) {} + + void operator++(); + IRInst* operator*() + { + return inst; + } + + bool operator!=(Iterator const& i) + { + return inst != i.inst; + } + }; + + Iterator begin(); + Iterator end(); +}; + +// Specialization of `IRInstListBase` for the case where +// we know (or at least expect) all of the instructions +// to be of type `T` +template +struct IRInstList : IRInstListBase +{ + IRInstList() {} + + IRInstList(T* first, T* last) + : IRInstListBase(first, last) + {} + + explicit IRInstList(IRInstListBase const& list) + : IRInstListBase(list) + {} + + T* getFirst() { return (T*) first; } + T* getLast() { return (T*) last; } + + struct Iterator : public IRInstListBase::Iterator + { + Iterator() {} + Iterator(IRInst* inst) : IRInstListBase::Iterator(inst) {} + + T* operator*() + { + return (T*) inst; + } + }; + + Iterator begin() { return Iterator(first); } + Iterator end(); +}; + + + +// Every value in the IR is an instruction (even things +// like literal values). +// +struct IRInst +{ + // The operation that this value represents + IROp op; + + // The total number of operands of this instruction. + // + // TODO: We shouldn't need to allocate this on + // all instructions. Instead we should have + // instructions that need "vararg" support to + // allocate this field ahead of the `this` + // pointer. + uint32_t operandCount = 0; + + UInt getOperandCount() + { + return operandCount; + } + + // Source location information for this value, if any + SourceLoc sourceLoc; + + // Each instruction can have zero or more "decorations" + // attached to it. A decoration is a specialized kind + // of instruction that either attaches metadata to, + // or modifies the sematnics of, its parent instruction. + // + IRDecoration* getFirstDecoration(); + IRDecoration* getLastDecoration(); + IRInstList getDecorations(); + + // Look up a decoration in the list of decorations + IRDecoration* findDecorationImpl(IROp op); + template + T* findDecoration(); + + // The first use of this value (start of a linked list) + IRUse* firstUse = nullptr; + + + // The parent of this instruction. + IRInst* parent; + + IRInst* getParent() { return parent; } + + // The next and previous instructions with the same parent + IRInst* next; + IRInst* prev; + + IRInst* getNextInst() { return next; } + IRInst* getPrevInst() { return prev; } + + // An instruction can have zero or more children, although + // only certain instruction opcodes are allowed to have + // children. + // + // For example, a function will have children that are + // its basic blocks, and the basic blocks will have children + // that represent parameters and ordinary executable instructions. + // + IRInst* getFirstChild(); + IRInst* getLastChild(); + IRInstList getChildren() + { + return IRInstList( + getFirstChild(), + getLastChild()); + } + + /// A doubly-linked list containing any decorations and then any children of this instruction. + /// + /// We store both the decorations and children of an instruction + /// in the same list, to conserve space in the instruction itself + /// (rather than storing distinct lists for decorations and children). + /// + // Note: This field is *not* being declared `private` because doing so could + // mess with our required memory layout, where `typeUse` below is assumed + // to be the last field in `IRInst` and to come right before any additional + // `IRUse` values that represent operands. + // + IRInstListBase m_decorationsAndChildren; + + IRInst* getFirstDecorationOrChild() { return m_decorationsAndChildren.first; } + IRInst* getLastDecorationOrChild() { return m_decorationsAndChildren.last; } + IRInstListBase getDecorationsAndChildren() { return m_decorationsAndChildren; } + + void removeAndDeallocateAllDecorationsAndChildren(); + + // The type of the result value of this instruction, + // or `null` to indicate that the instruction has + // no value. + IRUse typeUse; + + IRType* getFullType() { return (IRType*) typeUse.get(); } + void setFullType(IRType* type) { typeUse.init(this, (IRInst*) type); } + + IRRate* getRate(); + + IRType* getDataType(); + + // After the type, we have data that is specific to + // the subtype of `IRInst`. In most cases, this is + // just a series of `IRUse` values representing + // operands of the instruction. + + IRUse* getOperands(); + + IRInst* getOperand(UInt index) + { + return getOperands()[index].get(); + } + + void setOperand(UInt index, IRInst* value) + { + getOperands()[index].set(value); + } + + + // + + // Replace all uses of this value with `other`, so + // that this value will now have no uses. + void replaceUsesWith(IRInst* other); + + // Insert this instruction into the same basic block + // as `other`, right before/after it. + void insertBefore(IRInst* other); + void insertAfter(IRInst* other); + + // Insert as first/last child of given parent + void insertAtStart(IRInst* parent); + void insertAtEnd(IRInst* parent); + + // Move to the start/end of current parent + void moveToStart(); + void moveToEnd(); + + // Remove this instruction from its parent block, + // but don't delete it, or replace uses. + void removeFromParent(); + + // Remove this instruction from its parent block, + // and then destroy it (it had better have no uses!) + void removeAndDeallocate(); + + // Clear out the arguments of this instruction, + // so that we don't appear on the list of uses + // for those values. + void removeArguments(); + + /// Transfer any decorations of this instruction to the `target` instruction. + void transferDecorationsTo(IRInst* target); + + /// Does this instruction have any uses? + bool hasUses() const { return firstUse != nullptr; } + + /// Does this instructiomn have more than one use? + bool hasMoreThanOneUse() const { return firstUse != nullptr && firstUse->nextUse != nullptr; } + + /// It is possible that this instruction has side effects? + /// + /// This is a conservative test, and will return `true` if an exact answer can't be determined. + bool mightHaveSideEffects(); + + // RTTI support + static bool isaImpl(IROp) { return true; } + + /// Find the module that this instruction is nested under. + /// + /// If this instruction is transitively nested inside some IR module, + /// this function will return it, and will otherwise return `null`. + IRModule* getModule(); + + /// Insert this instruction into `inParent`, after `inPrev` and before `inNext`. + /// + /// `inParent` must be non-null + /// If `inPrev` is non-null it must satisfy `inPrev->getNextInst() == inNext` and `inPrev->getParent() == inParent` + /// If `inNext` is non-null it must satisfy `inNext->getPrevInst() == inPrev` and `inNext->getParent() == inParent` + /// + /// If both `inPrev` and `inNext` are null, then `inParent` must have no (raw) children. + /// + void _insertAt(IRInst* inPrev, IRInst* inNext, IRInst* inParent); +}; + +template +T* dynamicCast(IRInst* inst) +{ + if (inst && T::isaImpl(inst->op)) + return static_cast(inst); + return nullptr; +} + +template +const T* dynamicCast(const IRInst* inst) +{ + if (inst && T::isaImpl(inst->op)) + return static_cast(inst); + return nullptr; +} + +// `dynamic_cast` equivalent (we just use dynamicCast) +template +T* as(IRInst* inst) +{ + return dynamicCast(inst); +} + +template +const T* as(const IRInst* inst) +{ + return dynamicCast(inst); +} + +// `static_cast` equivalent, with debug validation +template +T* cast(IRInst* inst, T* /* */ = nullptr) +{ + SLANG_ASSERT(!inst || as(inst)); + return (T*)inst; +} + +// Now that `IRInst` is defined we can back-fill the definitions that need to access it. + +template +T* IRInst::findDecoration() +{ + for( auto decoration : getDecorations() ) + { + if(auto match = as(decoration)) + return match; + } + return nullptr; +} + +template +typename IRInstList::Iterator IRInstList::end() +{ + return Iterator(last ? last->next : nullptr); +} + + +// Types + +#define IR_LEAF_ISA(NAME) static bool isaImpl(IROp op) { return (kIROpMeta_PseudoOpMask & op) == kIROp_##NAME; } +#define IR_PARENT_ISA(NAME) static bool isaImpl(IROp opIn) { const int op = (kIROpMeta_PseudoOpMask & opIn); return op >= kIROp_First##NAME && op <= kIROp_Last##NAME; } + +#define SIMPLE_IR_TYPE(NAME, BASE) struct IR##NAME : IR##BASE { IR_LEAF_ISA(NAME) }; +#define SIMPLE_IR_PARENT_TYPE(NAME, BASE) struct IR##NAME : IR##BASE { IR_PARENT_ISA(NAME) }; + + +// All types in the IR are represented as instructions which conceptually +// execute before run time. +struct IRType : IRInst +{ + IRType* getCanonicalType() { return this; } + + IR_PARENT_ISA(Type) +}; + +IRType* unwrapArray(IRType* type); + +struct IRBasicType : IRType +{ + BaseType getBaseType() { return BaseType(op - kIROp_FirstBasicType); } + + IR_PARENT_ISA(BasicType) +}; + +struct IRVoidType : IRBasicType +{ + IR_LEAF_ISA(VoidType) +}; + +struct IRBoolType : IRBasicType +{ + IR_LEAF_ISA(BoolType) +}; + +SIMPLE_IR_TYPE(StringType, Type) + + +// True if types are equal +// Note compares nominal types by name alone +bool isTypeEqual(IRType* a, IRType* b); + +void findAllInstsBreadthFirst(IRInst* inst, List& outInsts); + +// Constant Instructions + +typedef int64_t IRIntegerValue; +typedef double IRFloatingPointValue; + +struct IRConstant : IRInst +{ + struct StringValue + { + uint32_t numChars; ///< The number of chars + char chars[1]; ///< Chars added at end. NOTE! Must be last member of struct! + }; + struct StringSliceValue + { + uint32_t numChars; + char* chars; + }; + + union ValueUnion + { + IRIntegerValue intVal; ///< Used for integrals and boolean + IRFloatingPointValue floatVal; + void* ptrVal; + + /// Either of these types could be set with kIROp_StringLit. + /// Which is used is currently determined with decorations - if a kIROp_TransitoryDecoration is set, then the transitory StringVal is used, else stringVal + // which relies on chars being held after the struct). + StringValue stringVal; + StringSliceValue transitoryStringVal; + }; + + /// Returns a string slice (or empty string if not appropriate) + UnownedStringSlice getStringSlice(); + + /// True if constants are equal + bool equal(IRConstant* rhs); + /// True if the value is equal. + /// Does *NOT* compare if the type is equal. + bool isValueEqual(IRConstant* rhs); + + /// Get the hash + int getHashCode(); + + IR_PARENT_ISA(Constant) + + // Must be last member, because data may be held behind + // NOTE! The total size of IRConstant may not be allocated - only enough space is allocated for the value type held in the union. + ValueUnion value; +}; + +struct IRIntLit : IRConstant +{ + IRIntegerValue getValue() { return value.intVal; } + + IR_LEAF_ISA(IntLit); +}; + +struct IRBoolLit : IRConstant +{ + bool getValue() { return value.intVal != 0; } + + IR_LEAF_ISA(BoolLit); +}; + + +// Get the compile-time constant integer value of an instruction, +// if it has one, and assert-fail otherwise. +IRIntegerValue GetIntVal(IRInst* inst); + +struct IRStringLit : IRConstant +{ + + IR_LEAF_ISA(StringLit); +}; + +struct IRPtrLit : IRConstant +{ + IR_LEAF_ISA(PtrLit); + + void* getValue() { return value.ptrVal; } +}; + +// A instruction that ends a basic block (usually because of control flow) +struct IRTerminatorInst : IRInst +{ + IR_PARENT_ISA(TerminatorInst) +}; + +// A function parameter is owned by a basic block, and represents +// either an incoming function parameter (in the entry block), or +// a value that flows from one SSA block to another (in a non-entry +// block). +// +// In each case, the basic idea is that a block is a "label with +// arguments." +// +struct IRParam : IRInst +{ + IRParam* getNextParam(); + IRParam* getPrevParam(); + + IR_LEAF_ISA(Param) +}; + +// A basic block is a parent instruction that adds the constraint +// that all the children need to be "ordinary" instructions (so +// no function declarations, or nested blocks). We also expect +// that the previous/next instruction are always a basic block. +// +struct IRBlock : IRInst +{ + // Linked list of the instructions contained in this block + // + IRInst* getFirstInst() { return getChildren().first; } + IRInst* getLastInst() { return getChildren().last; } + + // In a valid program, every basic block should end with + // a "terminator" instruction. + // + // This function will return the terminator, if it exists, + // or `null` if there is none. + IRTerminatorInst* getTerminator() { return as(getLastDecorationOrChild()); } + + // We expect that the siblings of a basic block will + // always be other basic blocks (we don't allow + // mixing of blocks and other instructions in the + // same parent). + // + // The exception to this is that decorations on the function + // that contains a block could appear before the first block, + // so we need to be careful to do a dynamic cast (`as`) in + // the `getPrevBlock` case, but don't need to worry about + // it for `getNextBlock`. + IRBlock* getPrevBlock() { return as(getPrevInst()); } + IRBlock* getNextBlock() { return cast(getNextInst()); } + + // The parameters of a block are represented by `IRParam` + // instructions at the start of the block. These play + // the role of function parameters for the entry block + // of a function, and of phi nodes in other blocks. + IRParam* getFirstParam() { return as(getFirstInst()); } + IRParam* getLastParam(); + IRInstList getParams() + { + return IRInstList( + getFirstParam(), + getLastParam()); + } + + void addParam(IRParam* param); + + // The "ordinary" instructions come after the parameters + IRInst* getFirstOrdinaryInst(); + IRInst* getLastOrdinaryInst(); + IRInstList getOrdinaryInsts() + { + return IRInstList( + getFirstOrdinaryInst(), + getLastOrdinaryInst()); + } + + // The parent of a basic block is assumed to be a + // value with code (e.g., a function, global variable + // with initializer, etc.). + IRGlobalValueWithCode* getParent() { return cast(IRInst::getParent()); } + + // The predecessor and successor lists of a block are needed + // when we want to work with the control flow graph (CFG) of + // a function. Rather than store these explicitly (and thus + // need to update them when transformations might change the + // CFG), we compute predecessors and successors in an + // implicit fashion using the use-def information for a + // block itself. + // + // To a first approximation, the predecessors of a block + // are the blocks where the IR value of the block is used. + // Similarly, the successors of a block are all values used + // by the terminator instruction of the block. + // The `getPredecessors()` and `getSuccessors()` functions + // make this more precise. + // + struct PredecessorList + { + PredecessorList(IRUse* begin) : b(begin) {} + IRUse* b; + + UInt getCount(); + bool isEmpty(); + + struct Iterator + { + Iterator(IRUse* use) : use(use) {} + + IRBlock* operator*(); + + void operator++(); + + bool operator!=(Iterator const& that) + { + return use != that.use; + } + + IRUse* use; + }; + + Iterator begin() { return Iterator(b); } + Iterator end() { return Iterator(nullptr); } + }; + + struct SuccessorList + { + SuccessorList(IRUse* begin, IRUse* end, UInt stride = 1) : begin_(begin), end_(end), stride(stride) {} + IRUse* begin_; + IRUse* end_; + UInt stride; + + UInt getCount(); + + struct Iterator + { + Iterator(IRUse* use, UInt stride) : use(use), stride(stride) {} + + IRBlock* operator*(); + + void operator++(); + + bool operator!=(Iterator const& that) + { + return use != that.use; + } + + IRUse* use; + UInt stride; + }; + + Iterator begin() { return Iterator(begin_, stride); } + Iterator end() { return Iterator(end_, stride); } + }; + + PredecessorList getPredecessors(); + SuccessorList getSuccessors(); + + // + + IR_LEAF_ISA(Block) +}; + +SIMPLE_IR_TYPE(BasicBlockType, Type) + +struct IRResourceTypeBase : IRType +{ + TextureFlavor getFlavor() const + { + return TextureFlavor((op >> kIROpMeta_OtherShift) & 0xFFFF); + } + + TextureFlavor::Shape GetBaseShape() const + { + return getFlavor().GetBaseShape(); + } + bool isMultisample() const { return getFlavor().isMultisample(); } + bool isArray() const { return getFlavor().isArray(); } + SlangResourceShape getShape() const { return getFlavor().getShape(); } + SlangResourceAccess getAccess() const { return getFlavor().getAccess(); } + + IR_PARENT_ISA(ResourceTypeBase); +}; + +struct IRResourceType : IRResourceTypeBase +{ + IRType* getElementType() { return (IRType*)getOperand(0); } + + IR_PARENT_ISA(ResourceType) +}; + +struct IRTextureTypeBase : IRResourceType +{ + IR_PARENT_ISA(TextureTypeBase) +}; + +struct IRTextureType : IRTextureTypeBase +{ + IR_LEAF_ISA(TextureType) +}; + +struct IRTextureSamplerType : IRTextureTypeBase +{ + IR_LEAF_ISA(TextureSamplerType) +}; + +struct IRGLSLImageType : IRTextureTypeBase +{ + IR_LEAF_ISA(GLSLImageType) +}; + +struct IRSamplerStateTypeBase : IRType +{ + IR_PARENT_ISA(SamplerStateTypeBase) +}; + +SIMPLE_IR_TYPE(SamplerStateType, SamplerStateTypeBase) +SIMPLE_IR_TYPE(SamplerComparisonStateType, SamplerStateTypeBase) + +struct IRBuiltinGenericType : IRType +{ + IRType* getElementType() { return (IRType*)getOperand(0); } + + IR_PARENT_ISA(BuiltinGenericType) +}; + +SIMPLE_IR_PARENT_TYPE(PointerLikeType, BuiltinGenericType); +SIMPLE_IR_PARENT_TYPE(HLSLStructuredBufferTypeBase, BuiltinGenericType) +SIMPLE_IR_TYPE(HLSLStructuredBufferType, HLSLStructuredBufferTypeBase) +SIMPLE_IR_TYPE(HLSLRWStructuredBufferType, HLSLStructuredBufferTypeBase) +SIMPLE_IR_TYPE(HLSLRasterizerOrderedStructuredBufferType, HLSLStructuredBufferTypeBase) + +SIMPLE_IR_PARENT_TYPE(UntypedBufferResourceType, Type) +SIMPLE_IR_PARENT_TYPE(ByteAddressBufferTypeBase, UntypedBufferResourceType) +SIMPLE_IR_TYPE(HLSLByteAddressBufferType, ByteAddressBufferTypeBase) +SIMPLE_IR_TYPE(HLSLRWByteAddressBufferType, ByteAddressBufferTypeBase) +SIMPLE_IR_TYPE(HLSLRasterizerOrderedByteAddressBufferType, ByteAddressBufferTypeBase) + +SIMPLE_IR_TYPE(HLSLAppendStructuredBufferType, HLSLStructuredBufferTypeBase) +SIMPLE_IR_TYPE(HLSLConsumeStructuredBufferType, HLSLStructuredBufferTypeBase) + +struct IRHLSLPatchType : IRType +{ + IRType* getElementType() { return (IRType*)getOperand(0); } + IRInst* getElementCount() { return getOperand(1); } + + IR_PARENT_ISA(HLSLPatchType) +}; + +SIMPLE_IR_TYPE(HLSLInputPatchType, HLSLPatchType) +SIMPLE_IR_TYPE(HLSLOutputPatchType, HLSLPatchType) + +SIMPLE_IR_PARENT_TYPE(HLSLStreamOutputType, BuiltinGenericType) +SIMPLE_IR_TYPE(HLSLPointStreamType, HLSLStreamOutputType) +SIMPLE_IR_TYPE(HLSLLineStreamType, HLSLStreamOutputType) +SIMPLE_IR_TYPE(HLSLTriangleStreamType, HLSLStreamOutputType) + +SIMPLE_IR_TYPE(GLSLInputAttachmentType, Type) +SIMPLE_IR_PARENT_TYPE(ParameterGroupType, PointerLikeType) +SIMPLE_IR_PARENT_TYPE(UniformParameterGroupType, ParameterGroupType) +SIMPLE_IR_PARENT_TYPE(VaryingParameterGroupType, ParameterGroupType) +SIMPLE_IR_TYPE(ConstantBufferType, UniformParameterGroupType) +SIMPLE_IR_TYPE(TextureBufferType, UniformParameterGroupType) +SIMPLE_IR_TYPE(GLSLInputParameterGroupType, VaryingParameterGroupType) +SIMPLE_IR_TYPE(GLSLOutputParameterGroupType, VaryingParameterGroupType) +SIMPLE_IR_TYPE(GLSLShaderStorageBufferType, UniformParameterGroupType) +SIMPLE_IR_TYPE(ParameterBlockType, UniformParameterGroupType) + +struct IRArrayTypeBase : IRType +{ + IRType* getElementType() { return (IRType*)getOperand(0); } + + // Returns the element count for an `IRArrayType`, and null + // for an `IRUnsizedArrayType`. + IRInst* getElementCount(); + + IR_PARENT_ISA(ArrayTypeBase) +}; + +struct IRArrayType: IRArrayTypeBase +{ + IRInst* getElementCount() { return getOperand(1); } + + IR_LEAF_ISA(ArrayType) +}; + +SIMPLE_IR_TYPE(UnsizedArrayType, ArrayTypeBase) + +SIMPLE_IR_PARENT_TYPE(Rate, Type) +SIMPLE_IR_TYPE(ConstExprRate, Rate) +SIMPLE_IR_TYPE(GroupSharedRate, Rate) + +struct IRRateQualifiedType : IRType +{ + IRRate* getRate() { return (IRRate*) getOperand(0); } + IRType* getValueType() { return (IRType*) getOperand(1); } + + IR_LEAF_ISA(RateQualifiedType) +}; + + +// Unlike the AST-level type system where `TypeType` tracks the +// underlying type, the "type of types" in the IR is a simple +// value with no operands, so that all type nodes have the +// same type. +SIMPLE_IR_PARENT_TYPE(Kind, Type); +SIMPLE_IR_TYPE(TypeKind, Kind); + +// The kind of any and all generics. +// +// A more complete type system would include "arrow kinds" to +// be able to track the domain and range of generics (e.g., +// the `vector` generic maps a type and an integer to a type). +// This is only really needed if we ever wanted to support +// "higher-kinded" generics (e.g., a generic that takes another +// generic as a parameter). +// +SIMPLE_IR_TYPE(GenericKind, Kind) + +struct IRVectorType : IRType +{ + IRType* getElementType() { return (IRType*)getOperand(0); } + IRInst* getElementCount() { return getOperand(1); } + + IR_LEAF_ISA(VectorType) +}; + +struct IRMatrixType : IRType +{ + IRType* getElementType() { return (IRType*)getOperand(0); } + IRInst* getRowCount() { return getOperand(1); } + IRInst* getColumnCount() { return getOperand(2); } + + IR_LEAF_ISA(MatrixType) +}; + +struct IRPtrTypeBase : IRType +{ + IRType* getValueType() { return (IRType*)getOperand(0); } + + IR_PARENT_ISA(PtrTypeBase) +}; + +SIMPLE_IR_TYPE(PtrType, PtrTypeBase) +SIMPLE_IR_TYPE(RefType, PtrTypeBase) + +SIMPLE_IR_PARENT_TYPE(OutTypeBase, PtrTypeBase) +SIMPLE_IR_TYPE(OutType, OutTypeBase) +SIMPLE_IR_TYPE(InOutType, OutTypeBase) +SIMPLE_IR_TYPE(ExistentialBoxType, PtrTypeBase) + + /// Get the type pointed to be `ptrType`, or `nullptr` if it is not a pointer(-like) type. + /// + /// The given IR `builder` will be used if new instructions need to be created. +IRType* tryGetPointedToType( + IRBuilder* builder, + IRType* type); + +struct IRFuncType : IRType +{ + IRType* getResultType() { return (IRType*) getOperand(0); } + UInt getParamCount() { return getOperandCount() - 1; } + IRType* getParamType(UInt index) { return (IRType*)getOperand(1 + index); } + + IR_LEAF_ISA(FuncType) +}; + +bool isDefinition( + IRInst* inVal); + +// A structure type is represented as a parent instruction, +// where the child instructions represent the fields of the +// struct. +// +// The space of fields that a given struct type supports +// are defined as its "keys", which are global values +// (that is, they have mangled names that can be used +// for linkage). +// +struct IRStructKey : IRInst +{ + IR_LEAF_ISA(StructKey) +}; +// +// The fields of the struct are then defined as mappings +// from those keys to the associated type (in the case of +// the struct type) or to values (when lookup up a field). +// +// A struct field thus has two operands: the key, and the +// type of the field. +// +struct IRStructField : IRInst +{ + IRStructKey* getKey() { return cast(getOperand(0)); } + IRType* getFieldType() + { + // Note: We do not use `cast` here because there are + // cases of types (which we would like to conveniently + // refer to via an `IRType*`) which do not actually + // inherit from `IRType` in the hierarchy. + // + return (IRType*) getOperand(1); + } + + IR_LEAF_ISA(StructField) +}; +// +// The struct type is then represented as a parent instruction +// that contains the various fields. Note that a struct does +// *not* contain the keys, because code needs to be able to +// reference the keys from scopes outside of the struct. +// +struct IRStructType : IRType +{ + IRInstList getFields() { return IRInstList(getChildren()); } + + IR_LEAF_ISA(StructType) +}; + +struct IRInterfaceType : IRType +{ + IR_LEAF_ISA(InterfaceType) +}; + +struct IRTaggedUnionType : IRType +{ + IR_LEAF_ISA(TaggedUnionType) +}; + +struct IRBindExistentialsType : IRType +{ + IR_LEAF_ISA(BindExistentialsType) + + IRType* getBaseType() { return (IRType*) getOperand(0); } + UInt getExistentialArgCount() { return getOperandCount() - 1; } + IRUse* getExistentialArgs() { return getOperands() + 1; } + IRInst* getExistentialArg(UInt index) { return getExistentialArgs()[index].get(); } +}; + +/// @brief A global value that potentially holds executable code. +/// +struct IRGlobalValueWithCode : IRInst +{ + // The children of a value with code will be the basic + // blocks of its definition. + IRBlock* getFirstBlock() { return cast(getFirstChild()); } + IRBlock* getLastBlock() { return cast(getLastChild()); } + IRInstList getBlocks() + { + return IRInstList(getChildren()); + } + + // Add a block to the end of this function. + void addBlock(IRBlock* block); + + IR_PARENT_ISA(GlobalValueWithCode) +}; + +// A value that has parameters so that it can conceptually be called. +struct IRGlobalValueWithParams : IRGlobalValueWithCode +{ + // Convenience accessor for the IR parameters, + // which are actually the parameters of the first + // block. + IRParam* getFirstParam(); + IRParam* getLastParam(); + IRInstList getParams(); + + IR_PARENT_ISA(GlobalValueWithParams) +}; + +// A function is a parent to zero or more blocks of instructions. +// +// A function is itself a value, so that it can be a direct operand of +// an instruction (e.g., a call). +struct IRFunc : IRGlobalValueWithParams +{ + // The type of the IR-level function + IRFuncType* getDataType() { return (IRFuncType*) IRInst::getDataType(); } + + // Convenience accessors for working with the + // function's type. + IRType* getResultType(); + UInt getParamCount(); + IRType* getParamType(UInt index); + + bool isDefinition() { return getFirstBlock() != nullptr; } + + IR_LEAF_ISA(Func) +}; + + /// Adjust the type of an IR function based on its parameter list. +void fixUpFuncType(IRFunc* func); + +// A generic is akin to a function, but is conceptually executed +// before runtime, to specialize the code nested within. +// +// In practice, a generic always holds only a single block, and ends +// with a `return` instruction for the value that the generic yields. +struct IRGeneric : IRGlobalValueWithParams +{ + IR_LEAF_ISA(Generic) +}; + +// Find the value that is returned from a generic, so that +// a pass can glean information from it. +IRInst* findGenericReturnVal(IRGeneric* generic); + +// Resolve an instruction that might reference a static definition +// to the most specific IR node possible, so that we can read +// decorations from it (e.g., if this is a `specialize` instruction, +// then try to chase down the generic being specialized, and what +// it seems to return). +// +IRInst* getResolvedInstForDecorations(IRInst* inst); + +// The IR module itself is represented as an instruction, which +// serves at the root of the tree of all instructions in the module. +struct IRModuleInst : IRInst +{ + // Pointer back to the non-instruction object that represents + // the module, so that we can get back to it in algorithms + // that need it. + IRModule* module; + + IRInstListBase getGlobalInsts() { return getChildren(); } + + IR_LEAF_ISA(Module) +}; + +struct IRModule : RefObject +{ + enum + { + kMemoryArenaBlockSize = 16 * 1024, ///< Use 16k block size for memory arena + }; + + SLANG_FORCE_INLINE Session* getSession() const { return session; } + SLANG_FORCE_INLINE IRModuleInst* getModuleInst() const { return moduleInst; } + + IRInstListBase getGlobalInsts() const { return getModuleInst()->getChildren(); } + + /// Get the object scope manager + SLANG_FORCE_INLINE ObjectScopeManager* getObjectScopeManager() { return &m_objectScopeManager; } + + /// Ctor + IRModule(): + memoryArena(kMemoryArenaBlockSize) + { + } + + MemoryArena memoryArena; + + // The compilation session in use. + Session* session; + IRModuleInst* moduleInst; + + protected: + + ObjectScopeManager m_objectScopeManager; +}; + + /// How much detail to include in dumped IR. + /// + /// Used with the `dumpIR` functions to determine + /// whether a completely faithful, but verbose, IR + /// dump is produced, or something simplified for ease + /// or reading. + /// +enum class IRDumpMode +{ + /// Produce a simplified IR dump. + /// + /// Simplified IR dumping will skip certain instructions + /// and print them at their use sites instead, so that + /// the overall dump is shorter and easier to read. + Simplified, + + /// Produce a detailed/accurate IR dump. + /// + /// A detailed IR dump will make sure to emit exactly + /// the instructions that were present with no attempt + /// to selectively skip them or give special formatting. + /// + Detailed, +}; + +void printSlangIRAssembly(StringBuilder& builder, IRModule* module, IRDumpMode mode = IRDumpMode::Simplified); +String getSlangIRAssembly(IRModule* module, IRDumpMode mode = IRDumpMode::Simplified); + +void dumpIR(IRModule* module, ISlangWriter* writer, IRDumpMode mode = IRDumpMode::Simplified); +void dumpIR(IRInst* globalVal, ISlangWriter* writer, IRDumpMode mode = IRDumpMode::Simplified); + +IRInst* createEmptyInst( + IRModule* module, + IROp op, + int totalArgCount); + +IRInst* createEmptyInstWithSize( + IRModule* module, + IROp op, + size_t totalSizeInBytes); +} + +#endif diff --git a/source/slang/slang-legalize-types.cpp b/source/slang/slang-legalize-types.cpp new file mode 100644 index 000000000..d541818ba --- /dev/null +++ b/source/slang/slang-legalize-types.cpp @@ -0,0 +1,1486 @@ +// slang-legalize-types.cpp +#include "slang-legalize-types.h" + +#include "slang-ir-insts.h" +#include "slang-mangle.h" + +namespace Slang +{ + +LegalType LegalType::implicitDeref( + LegalType const& valueType) +{ + RefPtr obj = new ImplicitDerefType(); + obj->valueType = valueType; + + LegalType result; + result.flavor = Flavor::implicitDeref; + result.obj = obj; + return result; +} + +LegalType LegalType::tuple( + RefPtr tupleType) +{ + SLANG_ASSERT(tupleType->elements.getCount()); + + LegalType result; + result.flavor = Flavor::tuple; + result.obj = tupleType; + return result; +} + +LegalType LegalType::pair( + RefPtr pairType) +{ + LegalType result; + result.flavor = Flavor::pair; + result.obj = pairType; + return result; +} + +LegalType LegalType::pair( + LegalType const& ordinaryType, + LegalType const& specialType, + RefPtr pairInfo) +{ + // Handle some special cases for when + // one or the other of the types isn't + // actually used. + + if (ordinaryType.flavor == LegalType::Flavor::none) + { + // There was nothing ordinary. + return specialType; + } + + if (specialType.flavor == LegalType::Flavor::none) + { + return ordinaryType; + } + + // There were both ordinary and special fields, + // and so we need to handle them here. + + RefPtr obj = new PairPseudoType(); + obj->ordinaryType = ordinaryType; + obj->specialType = specialType; + obj->pairInfo = pairInfo; + return LegalType::pair(obj); +} + +LegalType LegalType::makeWrappedBuffer( + IRType* simpleType, + LegalElementWrapping const& elementInfo) +{ + RefPtr obj = new WrappedBufferPseudoType(); + obj->simpleType = simpleType; + obj->elementInfo = elementInfo; + + LegalType result; + result.flavor = Flavor::wrappedBuffer; + result.obj = obj; + return result; +} + +// + +LegalElementWrapping LegalElementWrapping::makeVoid() +{ + LegalElementWrapping result; + result.flavor = Flavor::none; + return result; +} + +LegalElementWrapping LegalElementWrapping::makeSimple(IRStructKey* key, IRType* type) +{ + RefPtr obj = new SimpleLegalElementWrappingObj(); + obj->key = key; + obj->type = type; + + LegalElementWrapping result; + result.flavor = Flavor::simple; + result.obj = obj; + return result; +} + +RefPtr LegalElementWrapping::getSimple() const +{ + SLANG_ASSERT(flavor == Flavor::simple); + return obj.as(); +} + +LegalElementWrapping LegalElementWrapping::makeImplicitDeref(LegalElementWrapping const& field) +{ + RefPtr obj = new ImplicitDerefLegalElementWrappingObj(); + obj->field = field; + + LegalElementWrapping result; + result.flavor = Flavor::implicitDeref; + result.obj = obj; + return result; +} + +RefPtr LegalElementWrapping::getImplicitDeref() const +{ + SLANG_ASSERT(flavor == Flavor::implicitDeref); + return obj.as(); +} + +LegalElementWrapping LegalElementWrapping::makePair( + LegalElementWrapping const& ordinary, + LegalElementWrapping const& special, + PairInfo* pairInfo) +{ + RefPtr obj = new PairLegalElementWrappingObj(); + obj->ordinary = ordinary; + obj->special = special; + obj->pairInfo = pairInfo; + + LegalElementWrapping result; + result.flavor = Flavor::pair; + result.obj = obj; + return result; +} + +RefPtr LegalElementWrapping::getPair() const +{ + SLANG_ASSERT(flavor == Flavor::pair); + return obj.as(); +} + +LegalElementWrapping LegalElementWrapping::makeTuple(TupleLegalElementWrappingObj* obj) +{ + LegalElementWrapping result; + result.flavor = Flavor::tuple; + result.obj = obj; + return result; +} + +RefPtr LegalElementWrapping::getTuple() const +{ + SLANG_ASSERT(flavor == Flavor::tuple); + return obj.as(); +} + +// + +bool isResourceType(IRType* type) +{ + while (auto arrayType = as(type)) + { + type = arrayType->getElementType(); + } + + if (auto resourceTypeBase = as(type)) + { + return true; + } + else if (auto builtinGenericType = as(type)) + { + return true; + } + else if (auto pointerLikeType = as(type)) + { + return true; + } + else if (auto samplerType = as(type)) + { + return true; + } + else if(auto untypedBufferType = as(type)) + { + return true; + } + + // TODO: need more comprehensive coverage here + + return false; +} + +ModuleDecl* findModuleForDecl( + Decl* decl) +{ + for (auto dd = decl; dd; dd = dd->ParentDecl) + { + if (auto moduleDecl = as(dd)) + return moduleDecl; + } + return nullptr; +} + + +// Helper type for legalization of aggregate types +// that might need to be turned into tuple pseudo-types. +struct TupleTypeBuilder +{ + TypeLegalizationContext* context; + IRType* type; + IRStructType* originalStructType; + + struct OrdinaryElement + { + IRStructKey* fieldKey = nullptr; + IRType* type = nullptr; + }; + + + List ordinaryElements; + List specialElements; + + List pairElements; + + // Did we have any fields that forced us to change + // the actual type away from the declared type? + bool anyComplex = false; + + // Did we have any fields that actually required + // storage in the "special" part of things? + bool anySpecial = false; + + // Did we have any fields that actually used ordinary storage? + bool anyOrdinary = false; + + // Add a field to the (pseudo-)type we are building + void addField( + IRStructKey* fieldKey, + LegalType legalFieldType, + LegalType legalLeafType, + bool isSpecial) + { + LegalType ordinaryType; + LegalType specialType; + RefPtr elementPairInfo; + switch (legalLeafType.flavor) + { + case LegalType::Flavor::simple: + { + // We need to add an actual field, but we need + // to check if it is a resource type to know + // whether it should go in the "ordinary" list or not. + if (!isSpecial) + { + ordinaryType = legalLeafType; + } + else + { + specialType = legalFieldType; + } + } + break; + + case LegalType::Flavor::none: + anyComplex = true; + break; + + case LegalType::Flavor::implicitDeref: + { + // TODO: we may want to say that any use + // of `implicitDeref` puts the entire thing + // into the "special" category, rather than + // try to look under the hood... + + anyComplex = true; + + // We want to recursively add data + // based on the unwrapped type. + // + // Note: this assumes we can't have a tuple + // or a pair "under" an `implicitDeref`, so + // we'll need to ensure that elsewhere. + addField( + fieldKey, + legalFieldType, + legalLeafType.getImplicitDeref()->valueType, + isSpecial); + return; + } + break; + + case LegalType::Flavor::pair: + { + // The field's type had both special and non-special parts + auto pairType = legalLeafType.getPair(); + + // If things originally started as a resource type, then + // we want to externalize all the fields that arose, even + // if there is (nominally) ordinary data. + // + // This is because the "ordinary" side of the legalization + // of `ConstantBuffer` will still be a resource type. + if(isSpecial) + { + specialType = legalFieldType; + } + else + { + ordinaryType = pairType->ordinaryType; + specialType = pairType->specialType; + elementPairInfo = pairType->pairInfo; + } + } + break; + + case LegalType::Flavor::tuple: + { + // A tuple always represents "special" data + specialType = legalFieldType; + } + break; + + default: + SLANG_UNEXPECTED("unknown legal type flavor"); + break; + } + + PairInfo::Element pairElement; + pairElement.flags = 0; + pairElement.key = fieldKey; + pairElement.fieldPairInfo = elementPairInfo; + + // We will always add a field to the "ordinary" + // side of things, even if it has no ordinary + // data, just to keep the list of fields aligned + // with the original type. + OrdinaryElement ordinaryElement; + ordinaryElement.fieldKey = fieldKey; + if (ordinaryType.flavor != LegalType::Flavor::none) + { + anyOrdinary = true; + pairElement.flags |= PairInfo::kFlag_hasOrdinary; + + LegalType ot = ordinaryType; + + // TODO: any cases we should "unwrap" here? + // E.g., `implicitDeref`? + + if(ot.flavor == LegalType::Flavor::simple) + { + ordinaryElement.type = ot.getSimple(); + } + else + { + SLANG_UNEXPECTED("unexpected ordinary field type"); + } + } + ordinaryElements.add(ordinaryElement); + + if (specialType.flavor != LegalType::Flavor::none) + { + anySpecial = true; + anyComplex = true; + pairElement.flags |= PairInfo::kFlag_hasSpecial; + + TuplePseudoType::Element specialElement; + specialElement.key = fieldKey; + specialElement.type = specialType; + specialElements.add(specialElement); + } + + pairElement.type = LegalType::pair(ordinaryType, specialType, elementPairInfo); + pairElements.add(pairElement); + } + + // Add a field to the (pseudo-)type we are building + void addField( + IRStructField* field) + { + auto fieldType = field->getFieldType(); + + bool isSpecialField = context->isSpecialType(fieldType); + auto legalFieldType = legalizeType(context, fieldType); + + addField( + field->getKey(), + legalFieldType, + legalFieldType, + isSpecialField); + } + + LegalType getResult() + { + // If this is an empty struct, return a none type + // This helps get rid of emtpy structs that often trips up the + // downstream compiler + if (!anyOrdinary && !anySpecial && !anyComplex) + return LegalType(); + + // If we didn't see anything "special" + // then we can use the type as-is. + // we can conceivably just use the type as-is + // + if (!anyComplex) + { + return LegalType::simple(type); + } + + // If there were any "ordinary" fields along the way, + // then we need to collect them into a new `struct` type + // that represents these fields. + // + LegalType ordinaryType; + if (anyOrdinary) + { + // We are going to create an new IR `struct` type that contains + // the "ordinary" fields from the original type. Note that these + // fields may have different types from what they did before, + // because the fields themselves might have been legalized. + // + // The new type will have the same mangled name as the old one, so + // downstream code is going to need to be careful not to emit declarations + // for both of them. This should be okay, though, because the original + // type was illegal (that was the whole point) and so it shouldn't be + // referenced in the output anyway. + // + IRBuilder* builder = context->getBuilder(); + IRStructType* ordinaryStructType = builder->createStructType(); + ordinaryStructType->sourceLoc = originalStructType->sourceLoc; + + if(auto nameHintDecoration = originalStructType->findDecoration()) + { + builder->addNameHintDecoration(ordinaryStructType, nameHintDecoration->getNameOperand()); + } + + // The new struct type will appear right after the original in the IR, + // so that we can be sure any instruction that could reference the + // original can also reference the new one. + ordinaryStructType->insertAfter(originalStructType); + + // Mark the original type for removal once all the other legalization + // activity is completed. This is necessary because both the original + // and replacement type have the same mangled name, so they would + // collide. + // + // (Also, the original type wasn't legal - that was the whole point...) + context->replacedInstructions.add(originalStructType); + + for(auto ee : ordinaryElements) + { + // We will ensure that all the original fields are represented, + // although they may have different types (due to legalization). + // For fields that have *no* ordinary data, we will give them + // a dummy `void` type and rely on downstream passes to not + // actually emit declarations for those fields. + // + // (This helps keeps things simple because both the original + // and modified type will have the same number of fields, so + // we can continue to look up field layouts by index in the + // emit logic) + // + // TODO: we should scrap that, and layout lookup should just + // be based on mangled field names in all cases. + // + IRType* fieldType = ee.type; + if(!fieldType) + fieldType = context->getBuilder()->getVoidType(); + + // TODO: shallow clone of modifiers, etc. + + builder->createStructField( + ordinaryStructType, + ee.fieldKey, + fieldType); + } + + ordinaryType = LegalType::simple((IRType*) ordinaryStructType); + } + + LegalType specialType; + if (anySpecial) + { + RefPtr specialTuple = new TuplePseudoType(); + specialTuple->elements = specialElements; + specialType = LegalType::tuple(specialTuple); + } + + RefPtr pairInfo; + if (anyOrdinary && anySpecial) + { + pairInfo = new PairInfo(); + pairInfo->elements = pairElements; + } + + return LegalType::pair(ordinaryType, specialType, pairInfo); + } + +}; + +static IRType* createBuiltinGenericType( + TypeLegalizationContext* context, + IROp op, + IRType* elementType) +{ + IRInst* operands[] = { elementType }; + return context->getBuilder()->getType( + op, + 1, + operands); +} + +// Create a uniform buffer type with a given legalized +// element type. +static LegalType createLegalUniformBufferType( + TypeLegalizationContext* context, + IROp op, + LegalType legalElementType) +{ + // We will handle some of the easy/non-interesting + // cases here in the main routine, but for all + // the non-trivial cases we will dispatch to logic + // on the `context` (which may differ depending + // on what we are using legalization to accomplish). + // + switch (legalElementType.flavor) + { + default: + return context->createLegalUniformBufferType( + op, + legalElementType); + + case LegalType::Flavor::none: + return LegalType(); + + case LegalType::Flavor::simple: + { + // Easy case: we just have a simple element type, + // so we want to create a uniform buffer that wraps it. + // + // TODO: This isn't *quite* right, since it won't handle something + // like a `ParameterBlock`, but that seems like + // an unlikely case in practice. + // + return LegalType::simple(createBuiltinGenericType( + context, + op, + legalElementType.getSimple())); + } + break; + + case LegalType::Flavor::implicitDeref: + { + // This is actually an annoying case, because + // we are being asked to convert, e.g.,: + // + // cbuffer Foo { ParameterBlock bar; } + // + // into the equivalent of: + // + // cbuffer Foo { Bar bar; } + // + // Which would really require a new `LegalType` that + // would reprerent a resource type with a modified + // element type. + // + // I'm going to attempt to hack this for now. + return LegalType::implicitDeref(createLegalUniformBufferType( + context, + op, + legalElementType.getImplicitDeref()->valueType)); + } + break; + } +} + +// Create a uniform buffer type with a given legalized element type, +// under the assumption that we are doing resource-based type legalization. +// +LegalType createLegalUniformBufferTypeForResources( + TypeLegalizationContext* context, + IROp op, + LegalType legalElementType) +{ + switch (legalElementType.flavor) + { + case LegalType::Flavor::simple: + { + // Seeing a simple type here means that it must be a + // "special" type (a resource type or array thereof) + // because otherwise the catch-all behavior in + // `createLegalUniformBufferType()` would have handled it. + // + // This case is the same as what we do for tuple elements below. + // + return LegalType::implicitDeref(legalElementType); + } + + case LegalType::Flavor::pair: + { + auto pairType = legalElementType.getPair(); + + // The pair has both an "ordinary" and a "special" + // side, where the ordinary side is just plain data + // that we can put in a constant buffer type without + // any problems. The special side will (recursively) + // contain any resource-type fields that were nested + // in the constant buffer, and we'll need to + // treat those as resources that stand alongside + // the original constant buffer. + // + // We can start with the ordinary side, which we + // just want to wrap up in an ordinary uniform + // buffer with the appropriate `op`, so that case + // is easy: + // + auto ordinaryType = createLegalUniformBufferType( + context, + op, + pairType->ordinaryType); + + // For the special side, we really just want to turn + // a special field of type `R` into a value of type + // `R`, and the main detail we have to be aware of + // is that any use sites for the original buffer/block + // will include a dereferencing step to get from + // the block to this field, so we need to add + // something to the type structure to account for + // that step. + // + // We handle that issue by wrapping the special + // part of the type in an `implicitDeref` wrapper, + // which indicates that we logically have `SomePtr` + // but we actually just have `R`, and any attempt to + // load from or otherwise indirect through that pointer + // will turn into a plain old reference to the `R` value. + // + auto specialType = LegalType::implicitDeref(pairType->specialType); + + // Once we've wrapped up both the ordinary and special + // sides suitably, we tie them back together in a pair + // and make that be the legalized type of the result. + // + return LegalType::pair(ordinaryType, specialType, pairType->pairInfo); + } + + case LegalType::Flavor::tuple: + { + // A tuple type always represents purely "special" data, + // which in this case means resources. + // + // As in the `pair` case, the main thing we have to + // take into account is that each of the entries in the + // tuple itself (e.g., a value of type `R`) and the code + // that uses the legalized buffer type will expect a + // `ConstantBuffer` or at least `SomePtrType`. + // + // We will construct a new tuple type that wraps each + // of the element types in an `implicitDeref` to + // account for the different in levels of indirection. + // + // TODO: This seems odd, because we *should* be able to + // just wrap the whole thing in an `implicitDeref` and + // have done. We should investigate why this roundabout + // way of doing things was ever necessary. + + auto elementPseudoTupleType = legalElementType.getTuple(); + RefPtr bufferPseudoTupleType = new TuplePseudoType(); + + for (auto ee : elementPseudoTupleType->elements) + { + TuplePseudoType::Element newElement; + + newElement.key = ee.key; + newElement.type = LegalType::implicitDeref(ee.type); + + bufferPseudoTupleType->elements.add(newElement); + } + + return LegalType::tuple(bufferPseudoTupleType); + } + break; + + default: + SLANG_UNEXPECTED("unhandled legal type flavor"); + UNREACHABLE_RETURN(LegalType()); + break; + } +} + +// Legalizing a uniform buffer/block type for existentials is +// more interesting, because we don't actually want to push +// the "special" fields out of the buffer entirely (as we +// do for resources), and instead we just want to place +// them in the buffer *after* all the ordinary data. +// +// In order to accomplish this we need a way to emit a +// constant buffer with a new element type, and then +// "wrap" that constant buffer so that it looks like +// something that matches the legalization of the original +// element type. +// +// As a concrete example, suppose we have: +// +// struct Params { ExistentialBox f; int x; ExistentialBox b; }; +// ConstantBuffer gParams; +// +// The legalized form of `Params` will be something like: +// +// Pair( +// /* ordinary: */ struct Params_Ordinary { int x; }, +// /* special: */ Tuple( +// f -> ImplicitDeref(Foo), +// b -> ImplicitDeref(Bar))) +// +// We need to be able to splat that all out into a single +// structure declaration like: +// +// struct Params_Reordered +// { +// Params_Ordinary ordinary; +// Foo f; +// Bar b; +// } +// +// That allows us to declare: +// +// ConstantBuffer gParams; +// +// That gets the in-memory layout of things correct for the +// way we are defining existential value slots to work. +// The challenge is that elsewehere in the code there are +// operations like `gParams.x` need to now refer to +// `gParams.ordinary.x`. Furthermore, even for something like +// `f` that seems fine in the example above, we have lost +// a level of indirection, so that where we had `load(gParams.f)` +// we now want just `gParams.f`. +// +// The solution is to take `gParams` as soon as it is declared +// and wrap it up as a new value: +// +// pair( +// /* ordinary: */ gParams.ordinary, +// /* special: */ tuple( +// f -> implicitDeref(gParams.f), +// b -> implicitDeref(gParams.b))) +// +// +// Let's begin by just defining a function that can take +// a `LegalType` and turn it into zero or more field +// declarations, and return enough tracking information +// for us to be able to reconstruct a value like the above. +// +LegalElementWrapping declareStructFields( + TypeLegalizationContext* context, + IRStructType* structType, + LegalType fieldType) +{ + // TODO: We should eventually thread through some kind + // of "name hint" that can be used to give the generated + // fields more useful names. + + switch(fieldType.flavor) + { + case LegalType::Flavor::none: + return LegalElementWrapping::makeVoid(); + + case LegalType::Flavor::simple: + { + auto simpleFieldType = fieldType.getSimple(); + auto builder = context->getBuilder(); + auto fieldKey = builder->createStructKey(); + builder->createStructField(structType, fieldKey, simpleFieldType); + return LegalElementWrapping::makeSimple(fieldKey, simpleFieldType); + } + + case LegalType::Flavor::implicitDeref: + { + auto subField = declareStructFields(context, structType, fieldType.getImplicitDeref()->valueType); + return LegalElementWrapping::makeImplicitDeref(subField); + } + + case LegalType::Flavor::pair: + { + auto pairType = fieldType.getPair(); + auto ordinaryField = declareStructFields(context, structType, pairType->ordinaryType); + auto specialField = declareStructFields(context, structType, pairType->specialType); + return LegalElementWrapping::makePair( + ordinaryField, + specialField, + pairType->pairInfo); + } + + case LegalType::Flavor::tuple: + { + auto tupleType = fieldType.getTuple(); + + RefPtr obj = new TupleLegalElementWrappingObj(); + for( auto ee : tupleType->elements ) + { + TupleLegalElementWrappingObj::Element element; + element.key = ee.key; + element.field = declareStructFields(context, structType, ee.type); + obj->elements.add(element); + } + return LegalElementWrapping::makeTuple(obj); + } + + default: + SLANG_UNEXPECTED("unhandled legal type flavor"); + UNREACHABLE_RETURN(LegalElementWrapping::makeVoid()); + break; + } +} + +LegalType createLegalUniformBufferTypeForExistentials( + TypeLegalizationContext* context, + IROp op, + LegalType legalElementType) +{ + auto builder = context->getBuilder(); + + // In order to wrap up all the data in `legalElementType`, + // will create a fresh `struct` type and then declare + // fields in it that are sufficient to hold that data + // in `legalElementType`. + // + auto structType = builder->createStructType(); + auto elementWrapping = declareStructFields( + context, structType, legalElementType); + + // Because the `structType` is an ordinary IR type + // (not a `LegalType`) we can go ahead and create an + // IR uniform buffer type that wraps it. + // + auto bufferType = createBuiltinGenericType( + context, + op, + structType); + + // The `elementWrapping` computed when we declared all + // the `struct` fields tells us how to get from the + // actual fields declared in the structure type to a + // `LegalVal` with the right shape for what users of + // the buffer will expect. We record both the underlying + // IR buffer type and that wrapping information into + // the resulting `LegalType` so that we can use it + // when declaring variables of this type. + // + return LegalType::makeWrappedBuffer(bufferType, elementWrapping); +} + +static LegalType createLegalUniformBufferType( + TypeLegalizationContext* context, + IRUniformParameterGroupType* uniformBufferType, + LegalType legalElementType) +{ + return createLegalUniformBufferType( + context, + uniformBufferType->op, + legalElementType); +} + +// Create a pointer type with a given legalized value type. +static LegalType createLegalPtrType( + TypeLegalizationContext* context, + IROp op, + LegalType legalValueType) +{ + switch (legalValueType.flavor) + { + case LegalType::Flavor::none: + return LegalType(); + + case LegalType::Flavor::simple: + { + // Easy case: we just have a simple element type, + // so we want to create a uniform buffer that wraps it. + return LegalType::simple(createBuiltinGenericType( + context, + op, + legalValueType.getSimple())); + } + break; + + case LegalType::Flavor::implicitDeref: + { + // We are being asked to create a pointer type to something + // that is implicitly dereferenced, meaning we had: + // + // Ptr(PtrLike(T)) + // + // and now are being asked to make: + // + // Ptr(implicitDeref(LegalT)) + // + // So it seems like we can just create: + // + // implicitDeref(Ptr(LegalT)) + // + // and nobody should really be able to tell the difference, right? + // + // TODO: invetigate whether there are situations where this + // will matter. + return LegalType::implicitDeref(createLegalPtrType( + context, + op, + legalValueType.getImplicitDeref()->valueType)); + } + break; + + case LegalType::Flavor::pair: + { + // We just need to pointer-ify both sides of the pair. + auto pairType = legalValueType.getPair(); + + auto ordinaryType = createLegalPtrType( + context, + op, + pairType->ordinaryType); + auto specialType = createLegalPtrType( + context, + op, + pairType->specialType); + + return LegalType::pair(ordinaryType, specialType, pairType->pairInfo); + } + + case LegalType::Flavor::tuple: + { + // Wrap each of the tuple elements up as a pointer. + auto valuePseudoTupleType = legalValueType.getTuple(); + + RefPtr ptrPseudoTupleType = new TuplePseudoType(); + + // Wrap all the pseudo-tuple elements with `implicitDeref`, + // since they used to be inside a tuple, but aren't any more. + for (auto ee : valuePseudoTupleType->elements) + { + TuplePseudoType::Element newElement; + + newElement.key = ee.key; + newElement.type = createLegalPtrType( + context, + op, + ee.type); + + ptrPseudoTupleType->elements.add(newElement); + } + + return LegalType::tuple(ptrPseudoTupleType); + } + break; + + default: + SLANG_UNEXPECTED("unknown legal type flavor"); + UNREACHABLE_RETURN(LegalType()); + break; + } +} + +struct LegalTypeWrapper +{ + virtual LegalType wrap(TypeLegalizationContext* context, IRType* type) = 0; +}; + +struct ArrayLegalTypeWrapper : LegalTypeWrapper +{ + IRArrayTypeBase* arrayType; + + LegalType wrap(TypeLegalizationContext* context, IRType* type) + { + return LegalType::simple(context->getBuilder()->getArrayTypeBase( + arrayType->op, + type, + arrayType->getElementCount())); + } +}; + +struct BuiltinGenericLegalTypeWrapper : LegalTypeWrapper +{ + IROp op; + + LegalType wrap(TypeLegalizationContext* context, IRType* type) + { + return LegalType::simple(createBuiltinGenericType( + context, + op, + type)); + } +}; + + +struct ImplicitDerefLegalTypeWrapper : LegalTypeWrapper +{ + LegalType wrap(TypeLegalizationContext*, IRType* type) + { + return LegalType::implicitDeref(LegalType::simple(type)); + } +}; + +static LegalType wrapLegalType( + TypeLegalizationContext* context, + LegalType legalType, + LegalTypeWrapper* ordinaryWrapper, + LegalTypeWrapper* specialWrapper) +{ + switch (legalType.flavor) + { + case LegalType::Flavor::none: + return LegalType(); + + case LegalType::Flavor::simple: + { + return ordinaryWrapper->wrap(context, legalType.getSimple()); + } + break; + + case LegalType::Flavor::implicitDeref: + { + return LegalType::implicitDeref(wrapLegalType( + context, + legalType, + ordinaryWrapper, + specialWrapper)); + } + break; + + case LegalType::Flavor::pair: + { + // We just need to pointer-ify both sides of the pair. + auto pairType = legalType.getPair(); + + auto ordinaryType = wrapLegalType( + context, + pairType->ordinaryType, + ordinaryWrapper, + ordinaryWrapper); + auto specialType = wrapLegalType( + context, + pairType->specialType, + specialWrapper, + specialWrapper); + + return LegalType::pair(ordinaryType, specialType, pairType->pairInfo); + } + + case LegalType::Flavor::tuple: + { + // Wrap each of the tuple elements up as a pointer. + auto tupleType = legalType.getTuple(); + + RefPtr resultTupleType = new TuplePseudoType(); + + // Wrap all the pseudo-tuple elements with `implicitDeref`, + // since they used to be inside a tuple, but aren't any more. + for (auto ee : tupleType->elements) + { + TuplePseudoType::Element element; + + element.key = ee.key; + element.type = wrapLegalType( + context, + ee.type, + ordinaryWrapper, + specialWrapper); + + resultTupleType->elements.add(element); + } + + return LegalType::tuple(resultTupleType); + } + break; + + default: + SLANG_UNEXPECTED("unknown legal type flavor"); + UNREACHABLE_RETURN(LegalType()); + break; + } +} + +// Legalize a type, including any nested types +// that it transitively contains. +LegalType legalizeTypeImpl( + TypeLegalizationContext* context, + IRType* type) +{ + if(!type) + return LegalType::simple(nullptr); + + context->builder->setInsertBefore(type); + + if (auto uniformBufferType = as(type)) + { + // We have one of: + // + // ConstantBuffer + // TextureBuffer + // ParameterBlock + // + // or some other pointer-like type that represents uniform + // parameters. We need to pull any resource-type fields out + // of it, but leave non-resource fields where they are. + // + // As a special case, if the type contains *no* uniform data, + // we'll want to completely eliminate the uniform/ordinary + // part. + + auto originalElementType = uniformBufferType->getElementType(); + + // Legalize the element type to see what we are working with. + auto legalElementType = legalizeType(context, + originalElementType); + + // As a bit of a corner case, if the user requested something + // like `ConstantBuffer` the element type would + // legalize to a "simple" type, and that would be interpreted + // as an *ordinary* type, but we really need to notice the + // case when the element type is simple, but *special*. + // + if( context->isSpecialType(originalElementType) ) + { + // Anything that has a special element type needs to + // be handled by the pass-specific logic in the context. + // + return context->createLegalUniformBufferType( + uniformBufferType->op, + legalElementType); + } + + // Note that even when legalElementType.flavor == Simple + // we still need to create a new uniform buffer type + // from `legalElementType` instead of `type` + // because the `legalElementType` may still differ from `type` + // if, e.g., `type` contains empty structs. + return createLegalUniformBufferType( + context, + uniformBufferType, + legalElementType); + + } + else if (isResourceType(type)) + { + // We assume that any resource types not handled above + // are legal as-is. + return LegalType::simple(type); + } + else if (as(type)) + { + return LegalType::simple(type); + } + else if (as(type)) + { + return LegalType::simple(type); + } + else if (as(type)) + { + return LegalType::simple(type); + } + else if( auto existentialPtrType = as(type)) + { + // We want to transform an `ExistentialBox` into just + // a `T`, with an `iplicitDeref` to make sure that any + // pointer-related operations on the box Just Work. + // + // Note: the logic here doesn't have to deal with moving + // existential-type fields to the end of their outer + // type(s) because that is mostly dealt with in the + // case for struct types below. + // + auto legalValueType = legalizeType(context, existentialPtrType->getValueType()); + return LegalType::implicitDeref(legalValueType); + } + else if (auto ptrType = as(type)) + { + auto legalValueType = legalizeType(context, ptrType->getValueType()); + return createLegalPtrType(context, ptrType->op, legalValueType); + } + else if(auto structType = as(type)) + { + // Look at the (non-static) fields, and + // see if anything needs to be cleaned up. + // The things that need to be "cleaned up" for + // our purposes are: + // + // - Fields of resource type, or any other future + // type we run into that isn't allowed in + // aggregates for at least some targets + // + // - Fields with types that themselves had to + // get legalized. + // + // If we don't run into any of these, we + // can just use the type as-is. Hooray! + // + // Otherwise, we are effectively going to split + // the type apart and create a `TuplePseudoType`. + // Every field of the original type will be + // represented as an element of this pseudo-type. + // Each element will record its `LegalType`, + // and the original field that it was created from. + // An element will also track whether it contains + // any "ordinary" data, and if so, it will remember + // an element index in a real (AST-level, non-pseudo) + // `TupleType` that is used to bundle together + // such fields. + // + // Storing all the simple fields together like this + // obviously adds complexity to the legalization + // pass, but it has important benefits: + // + // - It avoids creating functions with a very large + // number of parameters (when passing a structure + // with many fields), which might confuse downstream + // compilers. + // + // - It avoids applying AOS->SOA conversion to fields + // that don't actually need it, which is basically + // required if we want type layout to work. + // + // - It ensures that we can actually construct a + // constant-buffer type that wraps a legalized + // aggregate type; the ordinary fields will get + // placed inside a new constant-buffer type, + // while the special ones will get left outside. + // + + // TODO: there is a risk here that we might recursively + // invole `legalizeType` on the type that we are + // currently trying to legalize. We need to detect that + // situation somehow, by inserting a sentinel value + // into `mapTypeToLegalType` during the per-field + // legalization process, and then if we ever see that + // sentinel in a call to `legalizeType`, we need + // to construct some kind of proxy type to help resolve + // the problem. + + TupleTypeBuilder builder; + builder.context = context; + builder.type = type; + builder.originalStructType = structType; + + for (auto ff : structType->getFields()) + { + builder.addField(ff); + } + + return builder.getResult(); + } + else if(auto arrayType = as(type)) + { + auto legalElementType = legalizeType( + context, + arrayType->getElementType()); + + ArrayLegalTypeWrapper wrapper; + wrapper.arrayType = arrayType; + + return wrapLegalType( + context, + legalElementType, + &wrapper, + &wrapper); + } + + return LegalType::simple(type); +} + +LegalType legalizeType( + TypeLegalizationContext* context, + IRType* type) +{ + LegalType legalType; + if(context->mapTypeToLegalType.TryGetValue(type, legalType)) + return legalType; + + legalType = legalizeTypeImpl(context, type); + context->mapTypeToLegalType[type] = legalType; + return legalType; +} + +// + +RefPtr getDerefTypeLayout( + TypeLayout* typeLayout) +{ + if (!typeLayout) + return nullptr; + + if (auto parameterGroupTypeLayout = as(typeLayout)) + { + return parameterGroupTypeLayout->offsetElementTypeLayout; + } + + return typeLayout; +} + +RefPtr getFieldLayout( + TypeLayout* typeLayout, + IRInst* fieldKey) +{ + if (!typeLayout) + return nullptr; + + for(;;) + { + if(auto arrayTypeLayout = as(typeLayout)) + { + typeLayout = arrayTypeLayout->elementTypeLayout; + } + else if(auto parameterGroupTypeLayout = as(typeLayout)) + { + typeLayout = parameterGroupTypeLayout->offsetElementTypeLayout; + } + else + { + break; + } + } + + + if (auto structTypeLayout = as(typeLayout)) + { + // First, let's see if the field had a layout registered + // directly using its IR key. + // + RefPtr fieldLayout; + if(structTypeLayout->mapKeyToLayout.TryGetValue(fieldKey, fieldLayout)) + return fieldLayout; + + // Otherwise, fall back to doing lookup using the linkage + // attached to the key, and its mangled name. + // + auto fieldLinkage = fieldKey->findDecoration(); + if(!fieldLinkage) + return nullptr; + auto mangledFieldName = fieldLinkage->getMangledName(); + + // In this case we fall back to a linear search over the fields. + // + for(auto ff : structTypeLayout->fields) + { + if(mangledFieldName == getMangledName(ff->varDecl.getDecl()).getUnownedSlice() ) + { + return ff; + } + } + } + + return nullptr; +} + +RefPtr createSimpleVarLayout( + SimpleLegalVarChain* varChain, + TypeLayout* typeLayout) +{ + if (!typeLayout) + return nullptr; + + // We need to construct a layout for the new variable + // that reflects both the type we have given it, as + // well as all the offset information that has accumulated + // along the chain of parent variables. + + // TODO: this logic needs to propagate through semantics... + + RefPtr varLayout = new VarLayout(); + varLayout->typeLayout = typeLayout; + + // For most resource kinds, the register index/space to use should + // be the sum along the entire chain of variables. + // + // For example, if we had input: + // + // struct S { Texture2D a; Texture2D b; }; + // S s : register(t10); + // + // And we were generating a stand-alone variable for `s.b`, then + // we'd need to add the offset for `b` (1 texture register), to + // the offset for `s` (10 texture registers) to get the final + // binding to apply. + // + for (auto rr : typeLayout->resourceInfos) + { + auto resInfo = varLayout->findOrAddResourceInfo(rr.kind); + + for (auto vv = varChain; vv; vv = vv->next) + { + if (auto parentResInfo = vv->varLayout->FindResourceInfo(rr.kind)) + { + resInfo->index += parentResInfo->index; + resInfo->space += parentResInfo->space; + } + } + } + + // As a special case, if the leaf variable doesn't hold an entry for + // `RegisterSpace`, but at least one declaration in the chain *does*, + // then we want to make sure that we add such an entry. + if (!varLayout->FindResourceInfo(LayoutResourceKind::RegisterSpace)) + { + // Sum up contributions from all parents. + UInt space = 0; + for (auto vv = varChain; vv; vv = vv->next) + { + if (auto parentResInfo = vv->varLayout->FindResourceInfo(LayoutResourceKind::RegisterSpace)) + { + space += parentResInfo->index; + } + } + + // If there were non-zero contributions, then add an entry to represent them. + if (space) + { + varLayout->findOrAddResourceInfo(LayoutResourceKind::RegisterSpace)->index = space; + } + } + + return varLayout; +} + + +RefPtr createVarLayout( + LegalVarChain const& varChain, + TypeLayout* typeLayout) +{ + if(!typeLayout) + return nullptr; + + auto varLayout = createSimpleVarLayout(varChain.primaryChain, typeLayout); + + if(auto pendingDataTypeLayout = typeLayout->pendingDataTypeLayout) + { + varLayout->pendingVarLayout = createSimpleVarLayout(varChain.pendingChain, typeLayout); + } + + return varLayout; +} + +// + +// TODO(tfoley): The code captured here is the logic that used to be +// applied to decide whether or not to desugar aggregate types that +// contain resources. Right now the implementation will *always* legalize +// away such types (since the IR always does this), while the AST-to-AST +// pass would only do it if required (according to the tests below). +// +// For right now this is an academic distinction, since the only project +// using Slang right now enables this tansformation unconditionally, but +// we probably need to re-parent this code back into the `TypeLegalizationContext` +// somewhere. +#if 0 + +bool shouldDesugarTupleTypes = false; +if (getTarget() == CodeGenTarget::GLSL) +{ + // Always desugar this stuff for GLSL, since it doesn't + // support nesting of resources in structs. + // + // TODO: Need a way to make this more fine-grained to + // handle cases where a nested member might be allowed + // due to, e.g., bindless textures. + shouldDesugarTupleTypes = true; +} +else if( shared->compileRequest->compileFlags & SLANG_COMPILE_FLAG_SPLIT_MIXED_TYPES ) +{ + // If the user is directly asking us to do this transformation, + // then obviously we need to do it. + // + // TODO: The way this is defined here means it will even apply to user + // HLSL code (not just code written in Slang). We may want to + // reconsider that choice, and only split things that originated in Slang. + // + shouldDesugarTupleTypes = true; +} + +#endif + +} diff --git a/source/slang/slang-legalize-types.h b/source/slang/slang-legalize-types.h new file mode 100644 index 000000000..e92a9fc41 --- /dev/null +++ b/source/slang/slang-legalize-types.h @@ -0,0 +1,678 @@ +// slang-legalize-types.h +#ifndef SLANG_LEGALIZE_TYPES_H_INCLUDED +#define SLANG_LEGALIZE_TYPES_H_INCLUDED + +// This file and `legalize-types.cpp` implement the core +// logic for taking a `Type` as produced by the front-end, +// and turning it into a suitable representation for use +// on a particular back-end. +// +// The main work applies to aggregate (e.g., `struct`) types, +// since various targets have rules about what is and isn't +// allowed in an aggregate (or where aggregates are allowed +// to be used). +// +// We might completely replace an aggregate `Type` with a +// "pseudo-type" that is just the enumeration of its field +// types (sort of a tuple type) so that a variable declared +// with the original type should be transformed into a +// bunch of individual variables. +// +// Alternatively, we might replace an aggregate type, where +// only *some* of the fields are illegal with a combination +// of an aggregate (containing the legal/legalized fields), +// and some extra tuple-ified fields. + +#include "../core/slang-basic.h" +#include "slang-ir-insts.h" +#include "slang-syntax.h" +#include "slang-type-layout.h" +#include "slang-name.h" + +namespace Slang +{ + +struct IRBuilder; + +struct LegalTypeImpl : RefObject +{ +}; +struct ImplicitDerefType; +struct TuplePseudoType; +struct PairPseudoType; +struct PairInfo; +struct LegalElementWrapping; +struct WrappedBufferPseudoType; + + /// A flavor for types or values that arise during legalization. +enum class LegalFlavor +{ + /// Nothing: an empty type or value. Equivalent to `void`. + none, + + /// A simple type/value that can be represented as an `IRType*` or `IRInst*` + simple, + + /// Logically, a pointer-like type/value, but represented as the type/value being pointed type. + implicitDeref, + + /// A compound type/value made up of the constituent fields of some original value. + tuple, + + /// A type/value that was split into "ordinary" and "special" parts. + pair, + + /// A type/value that represents, e.g., `ConstantBuffer` where `T` needed legalization. + wrappedBuffer, +}; + +struct LegalType +{ + typedef LegalFlavor Flavor; + + Flavor flavor = Flavor::none; + RefPtr obj; + IRType* irType; + + static LegalType simple(IRType* type) + { + LegalType result; + result.flavor = Flavor::simple; + result.irType = type; + return result; + } + + IRType* getSimple() const + { + SLANG_ASSERT(flavor == Flavor::simple); + return irType; + } + + static LegalType implicitDeref( + LegalType const& valueType); + + RefPtr getImplicitDeref() const + { + SLANG_ASSERT(flavor == Flavor::implicitDeref); + return obj.as(); + } + + static LegalType tuple( + RefPtr tupleType); + + RefPtr getTuple() const + { + SLANG_ASSERT(flavor == Flavor::tuple); + return obj.as(); + } + + static LegalType pair( + RefPtr pairType); + + static LegalType pair( + LegalType const& ordinaryType, + LegalType const& specialType, + RefPtr pairInfo); + + RefPtr getPair() const + { + SLANG_ASSERT(flavor == Flavor::pair); + return obj.as(); + } + + static LegalType makeWrappedBuffer( + IRType* simpleType, + LegalElementWrapping const& elementInfo); + + RefPtr getWrappedBuffer() const + { + SLANG_ASSERT(flavor == Flavor::wrappedBuffer); + return obj.as(); + } +}; + +struct LegalElementWrappingObj : RefObject +{ +}; + +struct SimpleLegalElementWrappingObj; +struct ImplicitDerefLegalElementWrappingObj; +struct PairLegalElementWrappingObj; +struct TupleLegalElementWrappingObj; + + /// Information on how the element type of a buffer needs to be wrapped. +struct LegalElementWrapping +{ + typedef LegalFlavor Flavor; + + Flavor flavor; + RefPtr obj; + + static LegalElementWrapping makeVoid(); + static LegalElementWrapping makeSimple(IRStructKey* key, IRType* type); + static LegalElementWrapping makeImplicitDeref(LegalElementWrapping const& field); + static LegalElementWrapping makePair( + LegalElementWrapping const& ordinary, + LegalElementWrapping const& special, + PairInfo* pairInfo); + static LegalElementWrapping makeTuple(TupleLegalElementWrappingObj* obj); + + RefPtr getSimple() const; + RefPtr getImplicitDeref() const; + RefPtr getPair() const; + RefPtr getTuple() const; +}; + +struct SimpleLegalElementWrappingObj : LegalElementWrappingObj +{ + IRStructKey* key; + IRType* type; +}; + +struct ImplicitDerefLegalElementWrappingObj : LegalElementWrappingObj +{ + LegalElementWrapping field; +}; + +struct PairLegalElementWrappingObj : LegalElementWrappingObj +{ + LegalElementWrapping ordinary; + LegalElementWrapping special; + RefPtr pairInfo; +}; + +struct TupleLegalElementWrappingObj : LegalElementWrappingObj +{ + struct Element + { + IRStructKey* key; + LegalElementWrapping field; + }; + + List elements; +}; + +// Represents the pseudo-type of a type that is pointer-like +// (and thus requires dereferencing, even if implicit), but +// was legalized to just use the type of the pointed-type value. +// +// The two cases where this comes up are: +// +// 1. When we have a type like `ConstantBuffer` that +// implies a level of indirection, but need to legalize it to just +// `Texture2D`, which eliminates that indirection. +// +// 2. When we have a type like `ExistentialBox` that will +// become just a `Foo` field, but which needs to be allocated +// out-of-line from the rest of its enclosing type. +// +struct ImplicitDerefType : LegalTypeImpl +{ + LegalType valueType; +}; + +// Represents the pseudo-type for a compound type +// that had to be broken apart because it contained +// one or more fields of types that shouldn't be +// allowed in aggregates. +// +// A tuple pseduo-type will have an element for +// each field of the original type, that represents +// the legalization of that field's type. +// +// It optionally also contains an "ordinary" type +// that packs together any per-field data that +// itself has (or contains) an ordinary type. +struct TuplePseudoType : LegalTypeImpl +{ + // Represents one element of the tuple pseudo-type + struct Element + { + // The field that this element replaces + IRStructKey* key; + + // The legalized type of the element + LegalType type; + }; + + // All of the elements of the tuple pseduo-type. + List elements; +}; + +struct IRStructKey; + +struct PairInfo : RefObject +{ + typedef unsigned int Flags; + enum + { + kFlag_hasOrdinary = 0x1, + kFlag_hasSpecial = 0x2, + }; + + + struct Element + { + // The original field the element represents + IRStructKey* key; + + // The conceptual type of the field. + // If both the `hasOrdinary` and + // `hasSpecial` bits are set, then + // this is expected to be a + // `LegalType::Flavor::pair` + LegalType type; + + // Is the value represented on + // the ordinary side, the special + // side, or both? + Flags flags; + + // If the type of this element is + // itself a pair type (that is, + // it both `hasOrdinary` and `hasSpecial`) + // then this is the `PairInfo` for that + // pair type: + RefPtr fieldPairInfo; + }; + + // For a pair type or value, we need to track + // which fields are on which side(s). + List elements; + + Element* findElement(IRStructKey* key) + { + for (auto& ee : elements) + { + if(ee.key == key) + return ⅇ + } + return nullptr; + } +}; + +struct PairPseudoType : LegalTypeImpl +{ + // Any field(s) with ordinary types will + // get captured here, usually as a single + // `simple` or `implicitDeref` type. + LegalType ordinaryType; + + // Any fields with "special" (not ordinary) + // types will get captured here (usually + // with a tuple). + LegalType specialType; + + // The `pairInfo` field helps to tell us which members + // of the original aggregate type appear on which side(s) + // of the new pair type. + RefPtr pairInfo; +}; + + +struct WrappedBufferPseudoType : LegalTypeImpl +{ + // The actual IR type that was used for the buffer. + IRType* simpleType; + + // Adjustments that need to be made when fetching + // an element from this buffer type. + // + LegalElementWrapping elementInfo; +}; + +// + +RefPtr getDerefTypeLayout( + TypeLayout* typeLayout); + +RefPtr getFieldLayout( + TypeLayout* typeLayout, + IRInst* fieldKey); + + /// Represents a "chain" of variables leading to some leaf field. + /// + /// Consider code like: + /// + /// struct Branch { int leaf; } + /// struct Tree { Branch left; Branch right; } + /// cbuffer Forest + /// { + /// int maxTreeHeight; + /// Tree tree; + /// } + /// + /// If we ask "what is the offset of `leaf`" the simple answer is zero, + /// but sometimes we are talking about `Forest.tree.right.leaf` which + /// will have a very different offset. In Slang parameters can consume + /// various (and multiple) resource kinds, so a single offset can't + /// be tunneled down through most recursive procedures. + /// + /// Instead we use a "chain" that works up through the stack, and + /// records the path from leaf field like `leaf` up to whatever + /// variable is the root for the curent operation. + /// + /// Operations like computing an offset can then be encoded by + /// starting with zero and then walking up the chain and adding in + /// offsets as encountered. + /// +struct SimpleLegalVarChain +{ + // The next link up the chain, or null if this is the end. + SimpleLegalVarChain* next = nullptr; + + // The layout for the variable at this link in thain. + VarLayout* varLayout = nullptr; +}; + + /// A "chain" of variable declarations that can handle both primary and "pending" data. + /// + /// In the presence of interface-type fields, a single variable may + /// have data that sits in two distinct allocations, and may have + /// `VarLayout`s that represent offseting into each of those + /// allocations. + /// + /// A `LegalVarChain` tracks two distinct `SimpleVarChain`s: one for + /// the primary/ordinary data allocation, and one for any pending + /// data. + /// + /// It is okay if the primary/pending chains have different numbers + /// of links in them. + /// + /// Offsets for particular resource kinds in the primary or pending + /// data allocation can be queried on the appropriate sub-chain. + /// +struct LegalVarChain +{ + // The chain of variables that represents the primary allocation. + SimpleLegalVarChain* primaryChain = nullptr; + + // The chain of variables that represents the pending allocation. + SimpleLegalVarChain* pendingChain = nullptr; + + // If the primary chain is non-empty, gets the variable at the leaf. + DeclRef getLeafVarDeclRef() const + { + if(!primaryChain) + return DeclRef(); + + return primaryChain->varLayout->varDecl; + } +}; + + /// RAII type for adding a link to a `LegalVarChain` as needed. + /// + /// This type handles the bookkeeping for creating a `LegalVarChain` + /// that links in one more variable. It will add a link to each of + /// the primary and pending sub-chains if and only if there is non-null + /// layout information for the primary/pending case. + /// + /// Typical usage in a recursive function is: + /// + /// void someRecursiveFunc(LegalVarChain const& outerChain, ...) + /// { + /// if(auto subVar = needToRecurse(...)) + /// { + /// LegalVarChainLink subChain(outerChain, subVar); + /// someRecursiveFunc(subChain, ...); + /// } + /// ... + /// } + /// +struct LegalVarChainLink : LegalVarChain +{ + /// Default constructor: yields an empty chain. + LegalVarChainLink() + { + } + + /// Copy constructor: yields a copy of the `parent` chain. + LegalVarChainLink(LegalVarChain const& parent) + : LegalVarChain(parent) + {} + + /// Construct a chain that extends `parent` with `varLayout`, if it is non-null. + LegalVarChainLink(LegalVarChain const& parent, VarLayout* varLayout) + : LegalVarChain(parent) + { + if( varLayout ) + { + primaryLink.next = parent.primaryChain; + primaryLink.varLayout = varLayout; + primaryChain = &primaryLink; + + if( auto pendingVarLayout = varLayout->pendingVarLayout ) + { + pendingLink.next = parent.pendingChain; + pendingLink.varLayout = pendingVarLayout; + pendingChain = &pendingLink; + } + } + } + + SimpleLegalVarChain primaryLink; + SimpleLegalVarChain pendingLink; +}; + +RefPtr createVarLayout( + LegalVarChain const& varChain, + TypeLayout* typeLayout); + +RefPtr createSimpleVarLayout( + SimpleLegalVarChain* varChain, + TypeLayout* typeLayout); + +// +// The result of legalizing an IR value will be +// represented with the `LegalVal` type. It is exposed +// in this header (rather than kept as an implementation +// detail, because the AST-based legalization logic needs +// a way to find the post-legalization version of a +// global name). +// +// TODO: We really shouldn't have this structure exposed, +// and instead should really be constructing AST-side +// `LegalExpr` values on-demand whenever we legalize something +// in the IR that will need to be used by the AST, and then +// store *those* in a map indexed in mangled names. +// + +struct LegalValImpl : RefObject +{ +}; +struct TuplePseudoVal; +struct PairPseudoVal; +struct WrappedBufferPseudoVal; + +struct LegalVal +{ + typedef LegalFlavor Flavor; + + Flavor flavor = Flavor::none; + RefPtr obj; + IRInst* irValue = nullptr; + + static LegalVal simple(IRInst* irValue) + { + LegalVal result; + result.flavor = Flavor::simple; + result.irValue = irValue; + return result; + } + + IRInst* getSimple() const + { + SLANG_ASSERT(flavor == Flavor::simple); + return irValue; + } + + static LegalVal tuple(RefPtr tupleVal); + + RefPtr getTuple() const + { + SLANG_ASSERT(flavor == Flavor::tuple); + return obj.as(); + } + + static LegalVal implicitDeref(LegalVal const& val); + LegalVal getImplicitDeref(); + + static LegalVal pair(RefPtr pairInfo); + static LegalVal pair( + LegalVal const& ordinaryVal, + LegalVal const& specialVal, + RefPtr pairInfo); + + RefPtr getPair() const + { + SLANG_ASSERT(flavor == Flavor::pair); + return obj.as(); + } + + static LegalVal wrappedBuffer( + LegalVal const& baseVal, + LegalElementWrapping const& elementInfo); + + RefPtr getWrappedBuffer() const + { + SLANG_ASSERT(flavor == Flavor::wrappedBuffer); + return obj.as(); + } +}; + +struct TuplePseudoVal : LegalValImpl +{ + struct Element + { + IRStructKey* key; + LegalVal val; + }; + + List elements; +}; + +struct PairPseudoVal : LegalValImpl +{ + LegalVal ordinaryVal; + LegalVal specialVal; + + // The info to tell us which fields + // are on which side(s) + RefPtr pairInfo; +}; + +struct ImplicitDerefVal : LegalValImpl +{ + LegalVal val; +}; + +struct WrappedBufferPseudoVal : LegalValImpl +{ + LegalVal base; + LegalElementWrapping elementInfo; +}; + +// + + /// Context that drives type legalization + /// + /// This type is an abstract base class, and there are + /// customization points that a concrete pass needs to + /// override (e.g., to specify what needs to be legalized). +struct IRTypeLegalizationContext +{ + Session* session; + IRModule* module; + IRBuilder* builder; + + SharedIRBuilder sharedBuilderStorage; + IRBuilder builderStorage; + + IRTypeLegalizationContext( + IRModule* inModule); + + // When inserting new globals, put them before this one. + IRInst* insertBeforeGlobal = nullptr; + + // When inserting new parameters, put them before this one. + IRParam* insertBeforeParam = nullptr; + + Dictionary mapValToLegalVal; + + IRVar* insertBeforeLocalVar = nullptr; + + // store instructions that have been replaced here, so we can free them + // when legalization has done + List replacedInstructions; + + Dictionary mapTypeToLegalType; + + IRBuilder* getBuilder() { return builder; } + + /// Customization point to decide what types are "special." + /// + /// When legalizing a `struct` type, any fields that have "special" + /// types will get moved out of the `struc` itself. + virtual bool isSpecialType(IRType* type) = 0; + + /// Customization point to construct uniform-buffer/block types. + /// + /// This function will only be called if `legalElementType` is + /// somehow non-trivial. + /// + virtual LegalType createLegalUniformBufferType( + IROp op, + LegalType legalElementType) = 0; +}; + +// This typedef exists to support pre-existing code from when +// `IRTypeLegalizationContext` and `TypeLegalizationContext` were +// two different types that had to coordinate. +typedef struct IRTypeLegalizationContext TypeLegalizationContext; + +LegalType legalizeType( + TypeLegalizationContext* context, + IRType* type); + +/// Try to find the module that (recursively) contains a given declaration. +ModuleDecl* findModuleForDecl( + Decl* decl); + + /// Create a uniform buffer type suitable for resource legalization. + /// + /// This will allocate a real buffer for the ordinary data (if any), + /// and leave the special data (if any) as a tuple. + /// +LegalType createLegalUniformBufferTypeForResources( + TypeLegalizationContext* context, + IROp op, + LegalType legalElementType); + + /// Create a uniform buffer type suitable for existential legalization. + /// + /// This will allocate a real uniform buffer for *all* the data, by + /// declaring an intermediate `struct` type to hold the ordinary and + /// special (existential-box) fields, if required. + /// +LegalType createLegalUniformBufferTypeForExistentials( + TypeLegalizationContext* context, + IROp op, + LegalType legalElementType); + + + + +void legalizeExistentialTypeLayout( + IRModule* module, + DiagnosticSink* sink); + +void legalizeResourceTypes( + IRModule* module, + DiagnosticSink* sink); + +bool isResourceType(IRType* type); + + +} + +#endif diff --git a/source/slang/slang-lexer.cpp b/source/slang/slang-lexer.cpp new file mode 100644 index 000000000..d7c086fba --- /dev/null +++ b/source/slang/slang-lexer.cpp @@ -0,0 +1,1334 @@ +// slang-lexer.cpp +#include "slang-lexer.h" + +// This file implements the lexer/scanner, which is responsible for taking a raw stream of +// input bytes and turning it into semantically useful tokens. +// + +#include "slang-compiler.h" +#include "slang-source-loc.h" + +#include + +namespace Slang +{ + Token TokenReader::GetEndOfFileToken() + { + return Token(TokenType::EndOfFile, UnownedStringSlice::fromLiteral(""), SourceLoc()); + } + + Token* TokenList::begin() const + { + SLANG_ASSERT(mTokens.getCount()); + return &mTokens[0]; + } + + Token* TokenList::end() const + { + SLANG_ASSERT(mTokens.getCount()); + SLANG_ASSERT(mTokens[mTokens.getCount()-1].type == TokenType::EndOfFile); + return &mTokens[mTokens.getCount() - 1]; + } + + TokenSpan::TokenSpan() + : mBegin(NULL) + , mEnd (NULL) + {} + + TokenReader::TokenReader() + : mCursor(NULL) + , mEnd (NULL) + {} + + + Token& TokenReader::PeekToken() + { + return nextToken; + } + + TokenType TokenReader::PeekTokenType() const + { + return nextToken.type; + } + + SourceLoc TokenReader::PeekLoc() const + { + return nextToken.loc; + } + + Token TokenReader::AdvanceToken() + { + if (!mCursor) + return GetEndOfFileToken(); + + Token token = nextToken; + if (mCursor < mEnd) + { + mCursor++; + nextToken = *mCursor; + } + else + nextToken.type = TokenType::EndOfFile; + return token; + } + + // Lexer + + void Lexer::initialize( + SourceView* inSourceView, + DiagnosticSink* inSink, + NamePool* inNamePool, + MemoryArena* inMemoryArena) + { + sourceView = inSourceView; + sink = inSink; + namePool = inNamePool; + memoryArena = inMemoryArena; + + auto content = inSourceView->getContent(); + + begin = content.begin(); + cursor = content.begin(); + end = content.end(); + + // Set the start location + startLoc = inSourceView->getRange().begin; + + tokenFlags = TokenFlag::AtStartOfLine | TokenFlag::AfterWhitespace; + lexerFlags = 0; + } + + Lexer::~Lexer() + { + } + + enum { kEOF = -1 }; + + // Get the next input byte, without any handling of + // escaped newlines, non-ASCII code points, source locations, etc. + static int peekRaw(Lexer* lexer) + { + // If we are at the end of the input, return a designated end-of-file value + if(lexer->cursor == lexer->end) + return kEOF; + + // Otherwise, just look at the next byte + return *lexer->cursor; + } + + // Read one input byte without any special handling (similar to `peekRaw`) + static int advanceRaw(Lexer* lexer) + { + // The logic here is basically the same as for `peekRaw()`, + // escape we advance `cursor` if we aren't at the end. + + if (lexer->cursor == lexer->end) + return kEOF; + + return *lexer->cursor++; + } + + // When the cursor is already at the first byte of an end-of-line sequence, + // consume one or two bytes that compose the sequence. + // + // Basically, a newline is one of: + // + // "\n" + // "\r" + // "\r\n" + // "\n\r" + // + // We always look for the longest match possible. + // + static void handleNewLineInner(Lexer* lexer, int c) + { + SLANG_ASSERT(c == '\n' || c == '\r'); + + int d = peekRaw(lexer); + if( (c ^ d) == ('\n' ^ '\r') ) + { + advanceRaw(lexer); + } + } + + // Look ahead one code point, dealing with complications like + // escaped newlines. + static int peek(Lexer* lexer) + { + // Look at the next raw byte, and decide what to do + int c = peekRaw(lexer); + + if(c == '\\') + { + // We might have a backslash-escaped newline. + // Look at the next byte (if any) to see. + // + // Note(tfoley): We are assuming a null-terminated input here, + // so that we can safely look at the next byte without issue. + int d = lexer->cursor[1]; + switch (d) + { + case '\r': case '\n': + { + // The newline was escaped, so return the code point after *that* + + int e = lexer->cursor[2]; + if ((d ^ e) == ('\r' ^ '\n')) + return lexer->cursor[3]; + return e; + } + + default: + break; + } + } + // TODO: handle UTF-8 encoding for non-ASCII code points here + + // Default case is to just hand along the byte we read as an ASCII code point. + return c; + } + + // Get the next code point from the input, and advance the cursor. + static int advance(Lexer* lexer) + { + // We are going to loop, but only as a way of handling + // escaped line endings. + for (;;) + { + // If we are at the end of the input, then the task is easy. + if (lexer->cursor == lexer->end) + return kEOF; + + // Look at the next raw byte, and decide what to do + int c = *lexer->cursor++; + + if (c == '\\') + { + // We might have a backslash-escaped newline. + // Look at the next byte (if any) to see. + // + // Note(tfoley): We are assuming a null-terminated input here, + // so that we can safely look at the next byte without issue. + int d = *lexer->cursor; + switch (d) + { + case '\r': case '\n': + // handle the end-of-line for our source location tracking + lexer->cursor++; + handleNewLineInner(lexer, d); + + lexer->tokenFlags |= TokenFlag::ScrubbingNeeded; + + // Now try again, looking at the character after the + // escaped newline. + continue; + + default: + break; + } + } + + // TODO: Need to handle non-ASCII code points. + + // Default case is to return the raw byte we saw. + return c; + } + } + + static void handleNewLine(Lexer* lexer) + { + int c = advance(lexer); + handleNewLineInner(lexer, c); + } + + static void lexLineComment(Lexer* lexer) + { + for(;;) + { + switch(peek(lexer)) + { + case '\n': case '\r': case kEOF: + return; + + default: + advance(lexer); + continue; + } + } + } + + static void lexBlockComment(Lexer* lexer) + { + for(;;) + { + switch(peek(lexer)) + { + case kEOF: + // TODO(tfoley) diagnostic! + return; + + case '\n': case '\r': + handleNewLine(lexer); + continue; + + case '*': + advance(lexer); + switch( peek(lexer) ) + { + case '/': + advance(lexer); + return; + + default: + continue; + } + + default: + advance(lexer); + continue; + } + } + } + + static void lexHorizontalSpace(Lexer* lexer) + { + for(;;) + { + switch(peek(lexer)) + { + case ' ': case '\t': + advance(lexer); + continue; + + default: + return; + } + } + } + + static void lexIdentifier(Lexer* lexer) + { + for(;;) + { + int c = peek(lexer); + if(('a' <= c ) && (c <= 'z') + || ('A' <= c) && (c <= 'Z') + || ('0' <= c) && (c <= '9') + || (c == '_')) + { + advance(lexer); + continue; + } + + return; + } + } + + static SourceLoc getSourceLoc(Lexer* lexer) + { + return lexer->startLoc + (lexer->cursor - lexer->begin); + } + + static void lexDigits(Lexer* lexer, int base) + { + for(;;) + { + int c = peek(lexer); + + int digitVal = 0; + switch(c) + { + case '0': case '1': case '2': case '3': case '4': + case '5': case '6': case '7': case '8': case '9': + digitVal = c - '0'; + break; + + case 'a': case 'b': case 'c': case 'd': case 'e': case 'f': + if(base <= 10) return; + digitVal = 10 + c - 'a'; + break; + + case 'A': case 'B': case 'C': case 'D': case 'E': case 'F': + if(base <= 10) return; + digitVal = 10 + c - 'A'; + break; + + default: + // Not more digits! + return; + } + + if(digitVal >= base) + { + char buffer[] = { (char) c, 0 }; + lexer->sink->diagnose(getSourceLoc(lexer), Diagnostics::invalidDigitForBase, buffer, base); + } + + advance(lexer); + } + } + + static TokenType maybeLexNumberSuffix(Lexer* lexer, TokenType tokenType) + { + // Be liberal in what we accept here, so that figuring out + // the semantics of a numeric suffix is left up to the parser + // and semantic checking logic. + // + for( ;;) + { + int c = peek(lexer); + + // Accept any alphanumeric character, plus underscores. + if(('a' <= c ) && (c <= 'z') + || ('A' <= c) && (c <= 'Z') + || ('0' <= c) && (c <= '9') + || (c == '_')) + { + advance(lexer); + continue; + } + + // Stop at the first character that isn't + // alphanumeric. + return tokenType; + } + } + + static bool isNumberExponent(int c, int base) + { + switch( c ) + { + default: + return false; + + case 'e': case 'E': + if(base != 10) return false; + break; + + case 'p': case 'P': + if(base != 16) return false; + break; + } + + return true; + } + + static bool maybeLexNumberExponent(Lexer* lexer, int base) + { + if(!isNumberExponent(peek(lexer), base)) + return false; + + // we saw an exponent marker + advance(lexer); + + // Now start to read the exponent + switch( peek(lexer) ) + { + case '+': case '-': + advance(lexer); + break; + } + + // TODO(tfoley): it would be an error to not see digits here... + + lexDigits(lexer, 10); + + return true; + } + + static TokenType lexNumberAfterDecimalPoint(Lexer* lexer, int base) + { + lexDigits(lexer, base); + maybeLexNumberExponent(lexer, base); + + return maybeLexNumberSuffix(lexer, TokenType::FloatingPointLiteral); + } + + static TokenType lexNumber(Lexer* lexer, int base) + { + // TODO(tfoley): Need to consider whehter to allow any kind of digit separator character. + + TokenType tokenType = TokenType::IntegerLiteral; + + // At the start of things, we just concern ourselves with digits + lexDigits(lexer, base); + + if( peek(lexer) == '.' ) + { + tokenType = TokenType::FloatingPointLiteral; + + advance(lexer); + lexDigits(lexer, base); + } + + if( maybeLexNumberExponent(lexer, base)) + { + tokenType = TokenType::FloatingPointLiteral; + } + + maybeLexNumberSuffix(lexer, tokenType); + return tokenType; + } + + static int maybeReadDigit(char const** ioCursor, int base) + { + auto& cursor = *ioCursor; + + for(;;) + { + int c = *cursor; + switch(c) + { + default: + return -1; + + // TODO: need to decide on digit separator characters + case '_': + cursor++; + continue; + + case '0': case '1': case '2': case '3': case '4': + case '5': case '6': case '7': case '8': case '9': + cursor++; + return c - '0'; + + case 'a': case 'b': case 'c': case 'd': case 'e': case 'f': + if(base > 10) + { + cursor++; + return 10 + c - 'a'; + } + return -1; + + case 'A': case 'B': case 'C': case 'D': case 'E': case 'F': + if(base > 10) + { + cursor++; + return 10 + c - 'A'; + } + return -1; + } + } + } + + static int readOptionalBase(char const** ioCursor) + { + auto& cursor = *ioCursor; + if( *cursor == '0' ) + { + cursor++; + switch(*cursor) + { + case 'x': case 'X': + cursor++; + return 16; + + case 'b': case 'B': + cursor++; + return 2; + + case '0': case '1': case '2': case '3': case '4': + case '5': case '6': case '7': case '8': case '9': + return 8; + + default: + return 10; + } + } + + return 10; + } + + + + IntegerLiteralValue getIntegerLiteralValue(Token const& token, UnownedStringSlice* outSuffix) + { + IntegerLiteralValue value = 0; + + char const* cursor = token.Content.begin(); + char const* end = token.Content.end(); + + int base = readOptionalBase(&cursor); + + for( ;;) + { + int digit = maybeReadDigit(&cursor, base); + if(digit < 0) + break; + + value = value*base + digit; + } + + if(outSuffix) + { + *outSuffix = UnownedStringSlice(cursor, end); + } + + return value; + } + + FloatingPointLiteralValue getFloatingPointLiteralValue(Token const& token, UnownedStringSlice* outSuffix) + { + FloatingPointLiteralValue value = 0; + + char const* cursor = token.Content.begin(); + char const* end = token.Content.end(); + + int radix = readOptionalBase(&cursor); + + bool seenDot = false; + FloatingPointLiteralValue divisor = 1; + for( ;;) + { + if(*cursor == '.') + { + cursor++; + seenDot = true; + continue; + } + + int digit = maybeReadDigit(&cursor, radix); + if(digit < 0) + break; + + value = value*radix + digit; + + if(seenDot) + { + divisor *= radix; + } + } + + // Now read optional exponent + if(isNumberExponent(*cursor, radix)) + { + cursor++; + + bool exponentIsNegative = false; + switch(*cursor) + { + default: + break; + + case '-': + exponentIsNegative = true; + cursor++; + break; + + case '+': + cursor++; + break; + } + + int exponentRadix = 10; + int exponent = 0; + + for(;;) + { + int digit = maybeReadDigit(&cursor, exponentRadix); + if(digit < 0) + break; + + exponent = exponent*exponentRadix + digit; + } + + FloatingPointLiteralValue exponentBase = 10; + if(radix == 16) + { + exponentBase = 2; + } + + FloatingPointLiteralValue exponentValue = pow(exponentBase, exponent); + + if( exponentIsNegative ) + { + divisor *= exponentValue; + } + else + { + value *= exponentValue; + } + } + + value /= divisor; + + if(outSuffix) + { + *outSuffix = UnownedStringSlice(cursor, end); + } + + return value; + } + + static void lexStringLiteralBody(Lexer* lexer, char quote) + { + for(;;) + { + int c = peek(lexer); + if(c == quote) + { + advance(lexer); + return; + } + + switch(c) + { + case kEOF: + lexer->sink->diagnose(getSourceLoc(lexer), Diagnostics::endOfFileInLiteral); + return; + + case '\n': case '\r': + lexer->sink->diagnose(getSourceLoc(lexer), Diagnostics::newlineInLiteral); + return; + + case '\\': + // Need to handle various escape sequence cases + advance(lexer); + switch(peek(lexer)) + { + case '\'': + case '\"': + case '\\': + case '?': + case 'a': + case 'b': + case 'f': + case 'n': + case 'r': + case 't': + case 'v': + advance(lexer); + break; + + case '0': case '1': case '2': case '3': case '4': + case '5': case '6': case '7': + // octal escape: up to 3 characters + advance(lexer); + for(int ii = 0; ii < 3; ++ii) + { + int d = peek(lexer); + if(('0' <= d) && (d <= '7')) + { + advance(lexer); + continue; + } + else + { + break; + } + } + break; + + case 'x': + // hexadecimal escape: any number of characters + advance(lexer); + for(;;) + { + int d = peek(lexer); + if(('0' <= d) && (d <= '9') + || ('a' <= d) && (d <= 'f') + || ('A' <= d) && (d <= 'F')) + { + advance(lexer); + continue; + } + else + { + break; + } + } + break; + + // TODO: Unicode escape sequences + + } + break; + + default: + advance(lexer); + continue; + } + } + } + + String getStringLiteralTokenValue(Token const& token) + { + SLANG_ASSERT(token.type == TokenType::StringLiteral + || token.type == TokenType::CharLiteral); + + char const* cursor = token.Content.begin(); + char const* end = token.Content.end(); + SLANG_UNREFERENCED_VARIABLE(end); + + auto quote = *cursor++; + SLANG_ASSERT(quote == '\'' || quote == '"'); + + StringBuilder valueBuilder; + for(;;) + { + SLANG_ASSERT(cursor != end); + + auto c = *cursor++; + + // If we see a closing quote, then we are at the end of the string literal + if(c == quote) + { + SLANG_ASSERT(cursor == end); + return valueBuilder.ProduceString(); + } + + // Characters that don't being escape sequences are easy; + // just append them to the buffer and move on. + if(c != '\\') + { + valueBuilder.Append(c); + continue; + } + + // Now we look at another character to figure out the kind of + // escape sequence we are dealing with: + + char d = *cursor++; + + switch(d) + { + // Simple characters that just needed to be escaped + case '\'': + case '\"': + case '\\': + case '?': + valueBuilder.Append(d); + continue; + + // Traditional escape sequences for special characters + case 'a': valueBuilder.Append('\a'); continue; + case 'b': valueBuilder.Append('\b'); continue; + case 'f': valueBuilder.Append('\f'); continue; + case 'n': valueBuilder.Append('\n'); continue; + case 'r': valueBuilder.Append('\r'); continue; + case 't': valueBuilder.Append('\t'); continue; + case 'v': valueBuilder.Append('\v'); continue; + + // Octal escape: up to 3 characterws + case '0': case '1': case '2': case '3': case '4': + case '5': case '6': case '7': + { + cursor--; + int value = 0; + for(int ii = 0; ii < 3; ++ii) + { + d = *cursor; + if(('0' <= d) && (d <= '7')) + { + value = value*8 + (d - '0'); + + cursor++; + continue; + } + else + { + break; + } + } + + // TODO: add support for appending an arbitrary code point? + valueBuilder.Append((char) value); + } + continue; + + // Hexadecimal escape: any number of characters + case 'x': + { + cursor--; + int value = 0; + for(;;) + { + d = *cursor++; + int digitValue = 0; + if(('0' <= d) && (d <= '9')) + { + digitValue = d - '0'; + } + else if( ('a' <= d) && (d <= 'f') ) + { + digitValue = d - 'a'; + } + else if( ('A' <= d) && (d <= 'F') ) + { + digitValue = d - 'A'; + } + else + { + cursor--; + break; + } + + value = value*16 + digitValue; + } + + // TODO: add support for appending an arbitrary code point? + valueBuilder.Append((char) value); + } + continue; + + // TODO: Unicode escape sequences + + } + } + } + + String getFileNameTokenValue(Token const& token) + { + // A file name usually doesn't process escape sequences + // (this is import on Windows, where `\\` is a valid + // path separator character). + + // Just trim off the first and last characters to remove the quotes + // (whether they were `""` or `<>`. + return String(token.Content.begin() + 1, token.Content.end() - 1); + } + + + + static TokenType lexTokenImpl(Lexer* lexer, LexerFlags effectiveFlags) + { + if(effectiveFlags & kLexerFlag_ExpectDirectiveMessage) + { + for(;;) + { + switch(peek(lexer)) + { + default: + advance(lexer); + continue; + + case kEOF: case '\r': case '\n': + break; + } + break; + } + return TokenType::DirectiveMessage; + } + + switch(peek(lexer)) + { + default: + break; + + case kEOF: + if((effectiveFlags & kLexerFlag_InDirective) != 0) + return TokenType::EndOfDirective; + return TokenType::EndOfFile; + + case '\r': case '\n': + if((effectiveFlags & kLexerFlag_InDirective) != 0) + return TokenType::EndOfDirective; + handleNewLine(lexer); + return TokenType::NewLine; + + case ' ': case '\t': + lexHorizontalSpace(lexer); + return TokenType::WhiteSpace; + + case '.': + advance(lexer); + switch(peek(lexer)) + { + case '0': case '1': case '2': case '3': case '4': + case '5': case '6': case '7': case '8': case '9': + return lexNumberAfterDecimalPoint(lexer, 10); + + // TODO(tfoley): handle ellipsis (`...`) + + default: + return TokenType::Dot; + } + + case '1': case '2': case '3': case '4': + case '5': case '6': case '7': case '8': case '9': + return lexNumber(lexer, 10); + + case '0': + { + auto loc = getSourceLoc(lexer); + advance(lexer); + switch(peek(lexer)) + { + default: + return maybeLexNumberSuffix(lexer, TokenType::IntegerLiteral); + + case '.': + advance(lexer); + return lexNumberAfterDecimalPoint(lexer, 10); + + case 'x': case 'X': + advance(lexer); + return lexNumber(lexer, 16); + + case 'b': case 'B': + advance(lexer); + return lexNumber(lexer, 2); + + case '0': case '1': case '2': case '3': case '4': + case '5': case '6': case '7': case '8': case '9': + lexer->sink->diagnose(loc, Diagnostics::octalLiteral); + return lexNumber(lexer, 8); + } + } + + case 'a': case 'b': case 'c': case 'd': case 'e': + case 'f': case 'g': case 'h': case 'i': case 'j': + case 'k': case 'l': case 'm': case 'n': case 'o': + case 'p': case 'q': case 'r': case 's': case 't': + case 'u': case 'v': case 'w': case 'x': case 'y': + case 'z': + case 'A': case 'B': case 'C': case 'D': case 'E': + case 'F': case 'G': case 'H': case 'I': case 'J': + case 'K': case 'L': case 'M': case 'N': case 'O': + case 'P': case 'Q': case 'R': case 'S': case 'T': + case 'U': case 'V': case 'W': case 'X': case 'Y': + case 'Z': + case '_': + lexIdentifier(lexer); + return TokenType::Identifier; + + case '\"': + advance(lexer); + lexStringLiteralBody(lexer, '\"'); + return TokenType::StringLiteral; + + case '\'': + advance(lexer); + lexStringLiteralBody(lexer, '\''); + return TokenType::CharLiteral; + + case '+': + advance(lexer); + switch(peek(lexer)) + { + case '+': advance(lexer); return TokenType::OpInc; + case '=': advance(lexer); return TokenType::OpAddAssign; + default: + return TokenType::OpAdd; + } + + case '-': + advance(lexer); + switch(peek(lexer)) + { + case '-': advance(lexer); return TokenType::OpDec; + case '=': advance(lexer); return TokenType::OpSubAssign; + case '>': advance(lexer); return TokenType::RightArrow; + default: + return TokenType::OpSub; + } + + case '*': + advance(lexer); + switch(peek(lexer)) + { + case '=': advance(lexer); return TokenType::OpMulAssign; + default: + return TokenType::OpMul; + } + + case '/': + advance(lexer); + switch(peek(lexer)) + { + case '=': advance(lexer); return TokenType::OpDivAssign; + case '/': advance(lexer); lexLineComment(lexer); return TokenType::LineComment; + case '*': advance(lexer); lexBlockComment(lexer); return TokenType::BlockComment; + default: + return TokenType::OpDiv; + } + + case '%': + advance(lexer); + switch(peek(lexer)) + { + case '=': advance(lexer); return TokenType::OpModAssign; + default: + return TokenType::OpMod; + } + + case '|': + advance(lexer); + switch(peek(lexer)) + { + case '|': advance(lexer); return TokenType::OpOr; + case '=': advance(lexer); return TokenType::OpOrAssign; + default: + return TokenType::OpBitOr; + } + + case '&': + advance(lexer); + switch(peek(lexer)) + { + case '&': advance(lexer); return TokenType::OpAnd; + case '=': advance(lexer); return TokenType::OpAndAssign; + default: + return TokenType::OpBitAnd; + } + + case '^': + advance(lexer); + switch(peek(lexer)) + { + case '=': advance(lexer); return TokenType::OpXorAssign; + default: + return TokenType::OpBitXor; + } + + case '>': + advance(lexer); + switch(peek(lexer)) + { + case '>': + advance(lexer); + switch(peek(lexer)) + { + case '=': advance(lexer); return TokenType::OpShrAssign; + default: return TokenType::OpRsh; + } + case '=': advance(lexer); return TokenType::OpGeq; + default: + return TokenType::OpGreater; + } + + case '<': + advance(lexer); + switch(peek(lexer)) + { + case '<': + advance(lexer); + switch(peek(lexer)) + { + case '=': advance(lexer); return TokenType::OpShlAssign; + default: return TokenType::OpLsh; + } + case '=': advance(lexer); return TokenType::OpLeq; + default: + return TokenType::OpLess; + } + + case '=': + advance(lexer); + switch(peek(lexer)) + { + case '=': advance(lexer); return TokenType::OpEql; + default: + return TokenType::OpAssign; + } + + case '!': + advance(lexer); + switch(peek(lexer)) + { + case '=': advance(lexer); return TokenType::OpNeq; + default: + return TokenType::OpNot; + } + + case '#': + advance(lexer); + switch(peek(lexer)) + { + case '#': advance(lexer); return TokenType::PoundPound; + default: + return TokenType::Pound; + } + + case '~': advance(lexer); return TokenType::OpBitNot; + + case ':': + { + advance(lexer); + if (peek(lexer) == ':') + { + advance(lexer); + return TokenType::Scope; + } + return TokenType::Colon; + } + case ';': advance(lexer); return TokenType::Semicolon; + case ',': advance(lexer); return TokenType::Comma; + + case '{': advance(lexer); return TokenType::LBrace; + case '}': advance(lexer); return TokenType::RBrace; + case '[': advance(lexer); return TokenType::LBracket; + case ']': advance(lexer); return TokenType::RBracket; + case '(': advance(lexer); return TokenType::LParent; + case ')': advance(lexer); return TokenType::RParent; + + case '?': advance(lexer); return TokenType::QuestionMark; + case '@': advance(lexer); return TokenType::At; + case '$': advance(lexer); return TokenType::Dollar; + + } + + // TODO(tfoley): If we ever wanted to support proper Unicode + // in identifiers, etc., then this would be the right place + // to perform a more expensive dispatch based on the actual + // code point (and not just the first byte). + + { + // If none of the above cases matched, then we have an + // unexpected/invalid character. + + auto loc = getSourceLoc(lexer); + int c = advance(lexer); + if(!(effectiveFlags & kLexerFlag_IgnoreInvalid)) + { + auto sink = lexer->sink; + if(c >= 0x20 && c <= 0x7E) + { + char buffer[] = { (char) c, 0 }; + sink->diagnose(loc, Diagnostics::illegalCharacterPrint, buffer); + } + else + { + // Fallback: print as hexadecimal + sink->diagnose(loc, Diagnostics::illegalCharacterHex, String((unsigned char)c, 16)); + } + } + + return TokenType::Invalid; + } + } + + Token Lexer::lexToken(LexerFlags extraFlags) + { + auto& flags = this->tokenFlags; + for(;;) + { + Token token; + token.loc = getSourceLoc(this); + + char const* textBegin = cursor; + + auto tokenType = lexTokenImpl(this, this->lexerFlags | extraFlags); + + // The low-level lexer produces tokens for things we want + // to ignore, such as white space, so we skip them here. + switch(tokenType) + { + case TokenType::Invalid: + flags = 0; + continue; + + case TokenType::NewLine: + flags = TokenFlag::AtStartOfLine | TokenFlag::AfterWhitespace; + continue; + + case TokenType::WhiteSpace: + case TokenType::LineComment: + case TokenType::BlockComment: + flags |= TokenFlag::AfterWhitespace; + continue; + + // We don't want to skip the end-of-file token, but we *do* + // want to make sure it has appropriate flags to make our life easier + case TokenType::EndOfFile: + flags |= TokenFlag::AtStartOfLine | TokenFlag::AfterWhitespace; + break; + + // We will also do some book-keeping around preprocessor directives here: + // + // If we see a `#` at the start of a line, then we are entering a + // preprocessor directive. + case TokenType::Pound: + if((flags & TokenFlag::AtStartOfLine) != 0) + lexerFlags |= kLexerFlag_InDirective; + break; + // + // And if we saw an end-of-line during a directive, then we are + // now leaving that directive. + // + case TokenType::EndOfDirective: + lexerFlags &= ~kLexerFlag_InDirective; + break; + + default: + break; + } + + token.type = tokenType; + + char const* textEnd = cursor; + + // Note(tfoley): `StringBuilder::Append()` seems to crash when appending zero bytes + if(textEnd != textBegin) + { + // "scrubbing" token value here to remove escaped newlines... + // + // Only perform this work if we encountered an escaped newline + // while lexing this token (e.g., keep a flag on the lexer), or + // do it on-demand when the actual value of the token is needed. + if (tokenFlags & TokenFlag::ScrubbingNeeded) + { + // Allocate space that will always be more than enough for stripped contents + char* startDst = (char*)memoryArena->allocateUnaligned(textEnd - textBegin); + char* dst = startDst; + + auto tt = textBegin; + while (tt != textEnd) + { + char c = *tt++; + if (c == '\\') + { + char d = *tt; + switch (d) + { + case '\r': case '\n': + { + tt++; + char e = *tt; + if ((d ^ e) == ('\r' ^ '\n')) + { + tt++; + } + } + continue; + + default: + break; + } + } + *dst++ = c; + } + token.Content = UnownedStringSlice(startDst, dst); + } + else + { + token.Content = UnownedStringSlice(textBegin, textEnd); + } + } + + token.flags = flags; + + this->tokenFlags = 0; + + if (tokenType == TokenType::Identifier) + { + token.ptrValue = this->namePool->getName(token.Content); + } + + return token; + } + } + + TokenList Lexer::lexAllTokens() + { + TokenList tokenList; + for(;;) + { + Token token = lexToken(); + tokenList.mTokens.add(token); + + if(token.type == TokenType::EndOfFile) + return tokenList; + } + } +} diff --git a/source/slang/slang-lexer.h b/source/slang/slang-lexer.h new file mode 100644 index 000000000..d3bf68e45 --- /dev/null +++ b/source/slang/slang-lexer.h @@ -0,0 +1,136 @@ +#ifndef SLANG_LEXER_H +#define SLANG_LEXER_H + +#include "../core/slang-basic.h" +#include "slang-diagnostics.h" + +namespace Slang +{ + struct NamePool; + + // + + struct TokenList + { + Token* begin() const; + Token* end() const; + + List mTokens; + }; + + struct TokenSpan + { + TokenSpan(); + TokenSpan( + TokenList const& tokenList) + : mBegin(tokenList.begin()) + , mEnd (tokenList.end ()) + {} + + Token* begin() const { return mBegin; } + Token* end () const { return mEnd ; } + + int GetCount() { return (int)(mEnd - mBegin); } + + Token* mBegin; + Token* mEnd; + }; + + struct TokenReader + { + Token nextToken; + TokenReader(); + explicit TokenReader(TokenSpan const& tokens) + : mCursor(tokens.begin()) + , mEnd (tokens.end ()) + , nextToken(tokens.begin() ? *tokens.begin() : GetEndOfFileToken()) + {} + explicit TokenReader(TokenList const& tokens) + : mCursor(tokens.begin()) + , mEnd (tokens.end ()) + , nextToken(tokens.begin() ? *tokens.begin() : GetEndOfFileToken()) + {} + struct ParsingCursor + { + Token nextToken; + Token* tokenReaderCursor = nullptr; + }; + ParsingCursor getCursor() + { + ParsingCursor rs; + rs.nextToken = nextToken; + rs.tokenReaderCursor = mCursor; + return rs; + } + void setCursor(ParsingCursor cursor) + { + mCursor = cursor.tokenReaderCursor; + nextToken = cursor.nextToken; + } + bool IsAtEnd() const { return mCursor == mEnd; } + Token& PeekToken(); + TokenType PeekTokenType() const; + SourceLoc PeekLoc() const; + + Token AdvanceToken(); + + int GetCount() { return (int)(mEnd - mCursor); } + + Token* mCursor; + Token* mEnd; + static Token GetEndOfFileToken(); + }; + + typedef unsigned int LexerFlags; + enum + { + kLexerFlag_InDirective = 1 << 0, ///< Turn end-of-line and end-of-file into end-of-directive + kLexerFlag_ExpectFileName = 1 << 1, ///< Support `<>` style strings for file paths + kLexerFlag_IgnoreInvalid = 1 << 2, ///< Suppress errors about invalid/unsupported characters + kLexerFlag_ExpectDirectiveMessage = 1 << 3, ///< Don't lexer ordinary tokens, and instead consume rest of line as a string + }; + + struct Lexer + { + void initialize( + SourceView* sourceView, + DiagnosticSink* sink, + NamePool* namePool, + MemoryArena* memoryArena); + + ~Lexer(); + + Token lexToken(LexerFlags extraFlags = 0); + + TokenList lexAllTokens(); + + SourceView* sourceView; + DiagnosticSink* sink; + NamePool* namePool; + + char const* cursor; + + char const* begin; + char const* end; + + /// The starting sourceLoc (same as first location of SourceView) + SourceLoc startLoc; + + TokenFlags tokenFlags; + LexerFlags lexerFlags; + + MemoryArena* memoryArena; + }; + + // Helper routines for extracting values from tokens + String getStringLiteralTokenValue(Token const& token); + String getFileNameTokenValue(Token const& token); + + typedef int64_t IntegerLiteralValue; + typedef double FloatingPointLiteralValue; + + IntegerLiteralValue getIntegerLiteralValue(Token const& token, UnownedStringSlice* outSuffix = 0); + FloatingPointLiteralValue getFloatingPointLiteralValue(Token const& token, UnownedStringSlice* outSuffix = 0); +} + +#endif diff --git a/source/slang/slang-lookup.cpp b/source/slang/slang-lookup.cpp new file mode 100644 index 000000000..0a77a259a --- /dev/null +++ b/source/slang/slang-lookup.cpp @@ -0,0 +1,713 @@ +// slang-lookup.cpp +#include "slang-lookup.h" +#include "slang-name.h" + +namespace Slang { + +void checkDecl(SemanticsVisitor* visitor, Decl* decl); + +// + +DeclRef ApplyExtensionToType( + SemanticsVisitor* semantics, + ExtensionDecl* extDecl, + RefPtr type); + +// + + +// Helper for constructing breadcrumb trails during lookup, without unnecessary heap allocaiton +struct BreadcrumbInfo +{ + LookupResultItem::Breadcrumb::Kind kind; + LookupResultItem::Breadcrumb::ThisParameterMode thisParameterMode = LookupResultItem::Breadcrumb::ThisParameterMode::Default; + DeclRef declRef; + BreadcrumbInfo* prev = nullptr; +}; + +void DoLocalLookupImpl( + Session* session, + Name* name, + DeclRef containerDeclRef, + LookupRequest const& request, + LookupResult& result, + BreadcrumbInfo* inBreadcrumbs); + +// + +void buildMemberDictionary(ContainerDecl* decl) +{ + // Don't rebuild if already built + if (decl->memberDictionaryIsValid) + return; + + decl->memberDictionary.Clear(); + decl->transparentMembers.clear(); + + // are we a generic? + GenericDecl* genericDecl = as(decl); + + for (auto m : decl->Members) + { + auto name = m->getName(); + + // Add any transparent members to a separate list for lookup + if (m->HasModifier()) + { + TransparentMemberInfo info; + info.decl = m.Ptr(); + decl->transparentMembers.add(info); + } + + // Ignore members with no name + if (!name) + continue; + + // Ignore the "inner" member of a generic declaration + if (genericDecl && m == genericDecl->inner) + continue; + + + m->nextInContainerWithSameName = nullptr; + + Decl* next = nullptr; + if (decl->memberDictionary.TryGetValue(name, next)) + m->nextInContainerWithSameName = next; + + decl->memberDictionary[name] = m.Ptr(); + + } + decl->memberDictionaryIsValid = true; +} + + +bool DeclPassesLookupMask(Decl* decl, LookupMask mask) +{ + // type declarations + if(auto aggTypeDecl = as(decl)) + { + return int(mask) & int(LookupMask::type); + } + else if(auto simpleTypeDecl = as(decl)) + { + return int(mask) & int(LookupMask::type); + } + // function declarations + else if(auto funcDecl = as(decl)) + { + return (int(mask) & int(LookupMask::Function)) != 0; + } + // attribute declaration + else if( auto attrDecl = as(decl) ) + { + return (int(mask) & int(LookupMask::Attribute)) != 0; + } + + // default behavior is to assume a value declaration + // (no overloading allowed) + + return (int(mask) & int(LookupMask::Value)) != 0; +} + +void AddToLookupResult( + LookupResult& result, + LookupResultItem item) +{ + if (!result.isValid()) + { + // If we hadn't found a hit before, we have one now + result.item = item; + } + else if (!result.isOverloaded()) + { + // We are about to make this overloaded + result.items.add(result.item); + result.items.add(item); + } + else + { + // The result was already overloaded, so we pile on + result.items.add(item); + } +} + +LookupResult refineLookup(LookupResult const& inResult, LookupMask mask) +{ + if (!inResult.isValid()) return inResult; + if (!inResult.isOverloaded()) return inResult; + + LookupResult result; + for (auto item : inResult.items) + { + if (!DeclPassesLookupMask(item.declRef.getDecl(), mask)) + continue; + + AddToLookupResult(result, item); + } + return result; +} + +LookupResultItem CreateLookupResultItem( + DeclRef declRef, + BreadcrumbInfo* breadcrumbInfos) +{ + LookupResultItem item; + item.declRef = declRef; + + // breadcrumbs were constructed "backwards" on the stack, so we + // reverse them here by building a linked list the other way + RefPtr breadcrumbs; + for (auto bb = breadcrumbInfos; bb; bb = bb->prev) + { + breadcrumbs = new LookupResultItem::Breadcrumb( + bb->kind, + bb->declRef, + breadcrumbs, + bb->thisParameterMode); + } + item.breadcrumbs = breadcrumbs; + return item; +} + +void DoMemberLookupImpl( + Session* session, + Name* name, + RefPtr baseType, + LookupRequest const& request, + LookupResult& ioResult, + BreadcrumbInfo* breadcrumbs) +{ + if (!baseType) + { + return; + } + + // If the type was pointer-like, then dereference it + // automatically here. + if (auto pointerLikeType = as(baseType)) + { + // Need to leave a breadcrumb to indicate that we + // did an implicit dereference here + BreadcrumbInfo derefBreacrumb; + derefBreacrumb.kind = LookupResultItem::Breadcrumb::Kind::Deref; + derefBreacrumb.prev = breadcrumbs; + + // Recursively perform lookup on the result of deref + return DoMemberLookupImpl( + session, + name, pointerLikeType->elementType, request, ioResult, &derefBreacrumb); + } + + // Default case: no dereference needed + + if (auto baseDeclRefType = as(baseType)) + { + if (auto baseAggTypeDeclRef = baseDeclRefType->declRef.as()) + { + DoLocalLookupImpl( + session, + name, baseAggTypeDeclRef, request, ioResult, breadcrumbs); + } + } + + // TODO(tfoley): any other cases to handle here? +} + +void DoMemberLookupImpl( + Session* session, + Name* name, + DeclRef baseDeclRef, + LookupRequest const& request, + LookupResult& ioResult, + BreadcrumbInfo* breadcrumbs) +{ + auto baseType = getTypeForDeclRef( + session, + baseDeclRef); + return DoMemberLookupImpl( + session, + name, baseType, request, ioResult, breadcrumbs); +} + +// If we are about to perform lookup through an interface, then +// we need to specialize the decl-ref to that interface to include +// a "this type" subtitution. This function applies that substition +// when it is required, and returns the existing `declRef` otherwise. +DeclRef maybeSpecializeInterfaceDeclRef( + RefPtr subType, + RefPtr superType, + DeclRef superTypeDeclRef, // The decl-ref we are going to perform lookup in + DeclRef constraintDeclRef) // The type constraint that told us our type is a subtype +{ + if (auto superInterfaceDeclRef = superTypeDeclRef.as()) + { + // Create a subtype witness value to note the subtype relationship + // that makes this specialization valid. + // + // Note: this is to ensure that we can specialize the subtype witness + // later (e.g., by replacing a subtype witness that represents a generic + // constraint parameter with the concrete generic arguments that + // are used at a particular call site to the generic). + RefPtr subtypeWitness = new DeclaredSubtypeWitness(); + subtypeWitness->declRef = constraintDeclRef; + subtypeWitness->sub = subType; + subtypeWitness->sup = superType; + + RefPtr thisTypeSubst = new ThisTypeSubstitution(); + thisTypeSubst->interfaceDecl = superInterfaceDeclRef.getDecl(); + thisTypeSubst->witness = subtypeWitness; + thisTypeSubst->outer = superInterfaceDeclRef.substitutions.substitutions; + + auto specializedInterfaceDeclRef = DeclRef(superInterfaceDeclRef.getDecl(), thisTypeSubst); + return specializedInterfaceDeclRef; + } + + return superTypeDeclRef; +} + +// Same as the above, but we are specializing a type instead of a decl-ref +RefPtr maybeSpecializeInterfaceDeclRef( + Session* session, + RefPtr subType, + RefPtr superType, // The type we are going to perform lookup in + DeclRef constraintDeclRef) // The type constraint that told us our type is a subtype +{ + if (auto superDeclRefType = as(superType)) + { + if (auto superInterfaceDeclRef = superDeclRefType->declRef.as()) + { + auto specializedInterfaceDeclRef = maybeSpecializeInterfaceDeclRef( + subType, + superType, + superInterfaceDeclRef, + constraintDeclRef); + auto specializedInterfaceType = DeclRefType::Create(session, specializedInterfaceDeclRef); + return specializedInterfaceType; + } + } + + return superType; +} + + +// Look for members of the given name in the given container for declarations +void DoLocalLookupImpl( + Session* session, + Name* name, + DeclRef containerDeclRef, + LookupRequest const& request, + LookupResult& result, + BreadcrumbInfo* inBreadcrumbs) +{ + if (result.lookedupDecls.Contains(containerDeclRef)) + return; + result.lookedupDecls.Add(containerDeclRef); + + ContainerDecl* containerDecl = containerDeclRef.getDecl(); + + // Ensure that the lookup dictionary in the container is up to date + if (!containerDecl->memberDictionaryIsValid) + { + buildMemberDictionary(containerDecl); + } + + // Look up the declarations with the chosen name in the container. + Decl* firstDecl = nullptr; + containerDecl->memberDictionary.TryGetValue(name, firstDecl); + + // Now iterate over those declarations (if any) and see if + // we find any that meet our filtering criteria. + // For example, we might be filtering so that we only consider + // type declarations. + for (auto m = firstDecl; m; m = m->nextInContainerWithSameName) + { + if (!DeclPassesLookupMask(m, request.mask)) + continue; + + // The declaration passed the test, so add it! + AddToLookupResult(result, CreateLookupResultItem(DeclRef(m, containerDeclRef.substitutions), inBreadcrumbs)); + } + + + // TODO(tfoley): should we look up in the transparent decls + // if we already has a hit in the current container? + + for(auto transparentInfo : containerDecl->transparentMembers) + { + // The reference to the transparent member should use whatever + // substitutions we used in referring to its outer container + DeclRef transparentMemberDeclRef(transparentInfo.decl, containerDeclRef.substitutions); + + // We need to leave a breadcrumb so that we know that the result + // of lookup involves a member lookup step here + + BreadcrumbInfo memberRefBreadcrumb; + memberRefBreadcrumb.kind = LookupResultItem::Breadcrumb::Kind::Member; + memberRefBreadcrumb.declRef = transparentMemberDeclRef; + memberRefBreadcrumb.prev = inBreadcrumbs; + + DoMemberLookupImpl( + session, + name, + transparentMemberDeclRef, + request, + result, + &memberRefBreadcrumb); + } + + // Consider lookup via extension + if( auto aggTypeDeclRef = containerDeclRef.as() ) + { + RefPtr type = DeclRefType::Create( + session, + aggTypeDeclRef); + + for (auto ext = GetCandidateExtensions(aggTypeDeclRef); ext; ext = ext->nextCandidateExtension) + { + auto extDeclRef = ApplyExtensionToType(request.semantics, ext, type); + if (!extDeclRef) + continue; + + // TODO: eventually we need to insert a breadcrumb here so that + // the constructed result can somehow indicate that a member + // was found through an extension. + + DoLocalLookupImpl( + session, + name, extDeclRef, request, result, inBreadcrumbs); + } + + } + // for interface decls, also lookup in the base interfaces + if (request.semantics) + { + // TODO: + // The logic here is a bit gross, because it tries to work in terms of + // decl-refs instead of types (e.g., it asserts that the target type + // for an `extension` declaration must be a decl-ref type). + // + // This code should be converted to do a type-based lookup + // through declared bases for *any* aggregate type declaration. + // I think that logic is present in the type-based lookup path, but + // it would be needed here for when doing lookup from inside an + // aggregate declaration. + + // if we are looking at an extension, find the target decl that we are extending + DeclRef targetDeclRef = containerDeclRef; + RefPtr targetDeclRefType; + if (auto extDeclRef = containerDeclRef.as()) + { + targetDeclRefType = as(extDeclRef.getDecl()->targetType); + SLANG_ASSERT(targetDeclRefType); + int diff = 0; + targetDeclRef = targetDeclRefType->declRef.as().SubstituteImpl(containerDeclRef.substitutions, &diff); + } + + // if we are looking inside an interface decl, try find in the interfaces it inherits from + if (targetDeclRef.is()) + { + if(!targetDeclRefType) + { + targetDeclRefType = DeclRefType::Create(session, targetDeclRef); + } + + auto baseInterfaces = getMembersOfType(containerDeclRef); + for (auto inheritanceDeclRef : baseInterfaces) + { + checkDecl(request.semantics, inheritanceDeclRef.decl); + + auto baseType = inheritanceDeclRef.getDecl()->base.type.dynamicCast(); + SLANG_ASSERT(baseType); + int diff = 0; + auto baseInterfaceDeclRef = baseType->declRef.SubstituteImpl(containerDeclRef.substitutions, &diff); + + baseInterfaceDeclRef = maybeSpecializeInterfaceDeclRef( + targetDeclRefType, + baseType, + baseInterfaceDeclRef, + inheritanceDeclRef); + + DoLocalLookupImpl(session, name, baseInterfaceDeclRef.as(), request, result, inBreadcrumbs); + } + } + } +} + +void DoLookupImpl( + Session* session, + Name* name, + LookupRequest const& request, + LookupResult& result) +{ + auto thisParameterMode = LookupResultItem::Breadcrumb::ThisParameterMode::Default; + + auto scope = request.scope; + auto endScope = request.endScope; + for (;scope != endScope; scope = scope->parent) + { + // Note that we consider all "peer" scopes together, + // so that a hit in one of them does not preclude + // also finding a hit in another + for(auto link = scope; link; link = link->nextSibling) + { + auto containerDecl = link->containerDecl; + + if(!containerDecl) + continue; + + DeclRef containerDeclRef = + DeclRef(containerDecl, createDefaultSubstitutions(session, containerDecl)).as(); + + BreadcrumbInfo breadcrumb; + BreadcrumbInfo* breadcrumbs = nullptr; + + // Depending on the kind of container we are looking into, + // we may need to insert something like a `this` expression + // to resolve the lookup result. + // + // Note: We are checking for `AggTypeDeclBase` here, and not + // just `AggTypeDecl`, because we want to catch `extension` + // declarations as well. + // + if (auto aggTypeDeclRef = containerDeclRef.as()) + { + breadcrumb.kind = LookupResultItem::Breadcrumb::Kind::This; + breadcrumb.thisParameterMode = thisParameterMode; + breadcrumb.declRef = aggTypeDeclRef; + breadcrumb.prev = nullptr; + + breadcrumbs = &breadcrumb; + } + + // Now perform "local" lookup in the context of the container, + // as if we were looking up a member directly. + + // if we are currently in an extension decl, perform local lookup + // in the target decl we are extending + if (auto extDeclRef = containerDeclRef.as()) + { + if (extDeclRef.getDecl()->targetType) + { + if (auto targetDeclRef = as(extDeclRef.getDecl()->targetType)) + { + if (auto aggDeclRef = targetDeclRef->declRef.as()) + { + containerDeclRef = extDeclRef.Substitute(aggDeclRef); + } + } + } + } + DoLocalLookupImpl( + session, + name, containerDeclRef, request, result, breadcrumbs); + + if( auto funcDeclRef = containerDeclRef.as() ) + { + if( funcDeclRef.getDecl()->HasModifier() ) + { + thisParameterMode = LookupResultItem::Breadcrumb::ThisParameterMode::Mutating; + } + else + { + thisParameterMode = LookupResultItem::Breadcrumb::ThisParameterMode::Default; + } + } + } + + if (result.isValid()) + { + // If we've found a result in this scope, then there + // is no reason to look further up (for now). + return; + } + } + + // If we run out of scopes, then we are done. +} + +LookupResult DoLookup( + Session* session, + Name* name, + LookupRequest const& request) +{ + LookupResult result; + DoLookupImpl(session, name, request, result); + return result; +} + +LookupResult lookUp( + Session* session, + SemanticsVisitor* semantics, + Name* name, + RefPtr scope, + LookupMask mask) +{ + LookupRequest request; + request.semantics = semantics; + request.scope = scope; + request.mask = mask; + return DoLookup(session, name, request); +} + +// perform lookup within the context of a particular container declaration, +// and do *not* look further up the chain +LookupResult lookUpLocal( + Session* session, + SemanticsVisitor* semantics, + Name* name, + DeclRef containerDeclRef, + LookupMask mask) +{ + LookupRequest request; + request.semantics = semantics; + request.mask = mask; + + LookupResult result; + DoLocalLookupImpl(session, name, containerDeclRef, request, result, nullptr); + return result; +} + +void lookUpMemberImpl( + Session* session, + SemanticsVisitor* semantics, + Name* name, + Type* type, + LookupResult& ioResult, + BreadcrumbInfo* inBreadcrumbs, + LookupMask mask); + +// Perform lookup "through" the given constraint decl-ref, +// which should show that `subType` is a sub-type of some +// super-type (e.g., an interface). +// +void lookUpThroughConstraint( + Session* session, + SemanticsVisitor* semantics, + Name* name, + Type* subType, + DeclRef constraintDeclRef, + LookupResult& ioResult, + BreadcrumbInfo* inBreadcrumbs, + LookupMask mask) +{ + // The super-type in the constraint (e.g., `Foo` in `T : Foo`) + // will tell us a type we should use for lookup. + // + auto superType = GetSup(constraintDeclRef); + // + // We will go ahead and perform lookup using `superType`, + // after dealing with some details. + + // If we are looking up through an interface type, then + // we need to be sure that we add an appropriate + // "this type" substitution here, since that needs to + // be applied to any members we look up. + // + superType = maybeSpecializeInterfaceDeclRef( + session, + subType, + superType, + constraintDeclRef); + + // We need to track the indirection we took in lookup, + // so that we can construct an appropriate AST on the other + // side that includes the "upcase" from sub-type to super-type. + // + BreadcrumbInfo breadcrumb; + breadcrumb.prev = inBreadcrumbs; + breadcrumb.kind = LookupResultItem::Breadcrumb::Kind::Constraint; + breadcrumb.declRef = constraintDeclRef; + + // TODO: Need to consider case where this might recurse infinitely (e.g., + // if an inheritance clause does something like `Bad : Bad>`. + // + // TODO: The even simpler thing we need to worry about here is that if + // there is ever a "diamond" relationship in the inheritance hierarchy, + // we might end up seeing the same interface via different "paths" and + // we wouldn't want that to lead to overload-resolution failure. + // + lookUpMemberImpl(session, semantics, name, superType, ioResult, &breadcrumb, mask); +} + +void lookUpMemberImpl( + Session* session, + SemanticsVisitor* semantics, + Name* name, + Type* type, + LookupResult& ioResult, + BreadcrumbInfo* inBreadcrumbs, + LookupMask mask) +{ + if (auto declRefType = as(type)) + { + auto declRef = declRefType->declRef; + if (declRef.as() || declRef.as()) + { + for (auto constraintDeclRef : getMembersOfType(declRef.as())) + { + lookUpThroughConstraint( + session, + semantics, + name, + type, + constraintDeclRef, + ioResult, + inBreadcrumbs, + mask); + } + } + else if (auto aggTypeDeclRef = declRef.as()) + { + LookupRequest request; + request.semantics = semantics; + + DoLocalLookupImpl(session, name, aggTypeDeclRef, request, ioResult, inBreadcrumbs); + } + else if (auto genericTypeParamDeclRef = declRef.as()) + { + auto genericDeclRef = genericTypeParamDeclRef.GetParent().as(); + assert(genericDeclRef); + + for(auto constraintDeclRef : getMembersOfType(genericDeclRef)) + { + // Does this constraint pertain to the type we are working on? + // + // We want constraints of the form `T : Foo` where `T` is the + // generic parameter in question, and `Foo` is whatever we are + // constraining it to. + auto subType = GetSub(constraintDeclRef); + auto subDeclRefType = as(subType); + if(!subDeclRefType) + continue; + if(!subDeclRefType->declRef.Equals(genericTypeParamDeclRef)) + continue; + + lookUpThroughConstraint( + session, + semantics, + name, + type, + constraintDeclRef, + ioResult, + inBreadcrumbs, + mask); + } + } + + } + +} + +LookupResult lookUpMember( + Session* session, + SemanticsVisitor* semantics, + Name* name, + Type* type, + LookupMask mask) +{ + LookupResult result; + lookUpMemberImpl(session, semantics, name, type, result, nullptr, mask); + return result; +} + +} diff --git a/source/slang/slang-lookup.h b/source/slang/slang-lookup.h new file mode 100644 index 000000000..705b952f3 --- /dev/null +++ b/source/slang/slang-lookup.h @@ -0,0 +1,60 @@ +#ifndef SLANG_LOOKUP_H_INCLUDED +#define SLANG_LOOKUP_H_INCLUDED + +#include "slang-syntax.h" + +namespace Slang { + +struct SemanticsVisitor; + +// Take an existing lookup result and refine it to only include +// results that pass the given `LookupMask`. +LookupResult refineLookup(LookupResult const& inResult, LookupMask mask); + +// Ensure that the dictionary for name-based member lookup has been +// built for the given container declaration. +void buildMemberDictionary(ContainerDecl* decl); + +// Look up a name in the given scope, proceeding up through +// parent scopes as needed. +LookupResult lookUp( + Session* session, + SemanticsVisitor* semantics, + Name* name, + RefPtr scope, + LookupMask mask = LookupMask::Default); + +// perform lookup within the context of a particular container declaration, +// and do *not* look further up the chain +LookupResult lookUpLocal( + Session* session, + SemanticsVisitor* semantics, + Name* name, + DeclRef containerDeclRef, + LookupMask mask = LookupMask::Default); + +// Perform member lookup in the context of a type +LookupResult lookUpMember( + Session* session, + SemanticsVisitor* semantics, + Name* name, + Type* type, + LookupMask mask = LookupMask::Default); + +// TODO: this belongs somewhere else + +QualType getTypeForDeclRef( + Session* session, + SemanticsVisitor* sema, + DiagnosticSink* sink, + DeclRef declRef, + RefPtr* outTypeResult); + +QualType getTypeForDeclRef( + Session* session, + DeclRef declRef); + + +} + +#endif diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp new file mode 100644 index 000000000..8425b9664 --- /dev/null +++ b/source/slang/slang-lower-to-ir.cpp @@ -0,0 +1,6498 @@ +// lower.cpp +#include "slang-lower-to-ir.h" + +#include "../../slang.h" + +#include "slang-check.h" +#include "slang-ir.h" +#include "slang-ir-constexpr.h" +#include "slang-ir-insts.h" +#include "slang-ir-missing-return.h" +#include "slang-ir-sccp.h" +#include "slang-ir-ssa.h" +#include "slang-ir-validate.h" +#include "slang-mangle.h" +#include "slang-type-layout.h" +#include "slang-visitor.h" + +namespace Slang +{ + +// This file implements lowering of the Slang AST to a simpler SSA +// intermediate representation. +// +// IR is generated in a context (`IRGenContext`), which tracks the current +// location in the IR where code should be emitted (e.g., what basic +// block to add instructions to). Lowering a statement will emit some +// number of instructions to the context, and possibly change the +// insertion point (because of control flow). +// +// When lowering an expression we have a more interesting challenge, for +// two main reasons: +// +// 1. There might be types that are representible in the AST, but which +// we don't want to support natively in the IR. An example is a `struct` +// type with both ordinary and resource-type members; we might want to +// split values with such a type into distinct values during lowering. +// +// 2. We need to handle the difference between l-value and r-value expressions, +// and in particular the fact that HLSL/Slang supports complicated sorts +// of l-values (e.g., `someVector.zxy` is an l-value, even though it can't +// be represented by a single pointer), and also allows l-values to appear +// in multiple contexts (not just the left-hand side of assignment, but +// also as an argument to match an `out` or `in out` parameter). +// +// Our solution to both of these problems is the same. Rather than having +// the lowering of an expression return a single IR-level value (`IRInst*`), +// we have it return a more complex type (`LoweredValInfo`) which can represent +// a wider range of conceptual "values" which might correspond to multiple IR-level +// values, and/or represent a pointer to an l-value rather than the r-value itself. + +// We want to keep the representation of a `LoweringValInfo` relatively light +// - right now it is just a single pointer plus a "tag" to distinguish the cases. +// +// This means that cases that can't fit in a single pointer need a heap allocation +// to store their payload. For simplicity we represent all of these with a class +// hierarchy: +// +struct ExtendedValueInfo : RefObject +{}; + +// This case is used to indicate a value that is a reference +// to an AST-level subscript declaration. +// +struct SubscriptInfo : ExtendedValueInfo +{ + DeclRef declRef; +}; + +// This case is used to indicate a reference to an AST-level +// subscript operation bound to particular arguments. +// +// For example in a case like this: +// +// RWStructuredBuffer gBuffer; +// ... gBuffer[someIndex] ... +// +// the expression `gBuffer[someIndex]` will be lowered to +// a value that references `RWStructureBuffer::operator[]` +// with arguments `(gBuffer, someIndex)`. +// +// Such a value can be an l-value, and depending on the context +// where it is used, can lower into a call to either the getter +// or setter operations of the subscript. +// +struct BoundSubscriptInfo : ExtendedValueInfo +{ + DeclRef declRef; + IRType* type; + List args; +}; + +// Some cases of `ExtendedValueInfo` need to +// recursively contain `LoweredValInfo`s, and +// so we forward declare them here and fill +// them in later. +// +struct BoundMemberInfo; +struct SwizzledLValueInfo; + + +// This type is our core representation of lowered values. +// In the simple case, it just wraps an `IRInst*`. +// More complex cases, representing l-values or aggregate +// values are also supported. +struct LoweredValInfo +{ + // Which of the cases of value are we looking at? + enum class Flavor + { + // No value (akin to a null pointer) + None, + + // A simple IR value + Simple, + + // An l-value represented as an IR + // pointer to the value + Ptr, + + // A member declaration bound to a particular `this` value + BoundMember, + + // A reference to an AST-level subscript operation + Subscript, + + // An AST-level subscript operation bound to a particular + // object and arguments. + BoundSubscript, + + // The result of applying swizzling to an l-value + SwizzledLValue, + }; + + union + { + IRInst* val; + ExtendedValueInfo* ext; + }; + Flavor flavor; + + LoweredValInfo() + { + flavor = Flavor::None; + val = nullptr; + } + + LoweredValInfo(IRType* t) + { + flavor = Flavor::Simple; + val = t; + } + + static LoweredValInfo simple(IRInst* v) + { + LoweredValInfo info; + info.flavor = Flavor::Simple; + info.val = v; + return info; + } + + static LoweredValInfo ptr(IRInst* v) + { + LoweredValInfo info; + info.flavor = Flavor::Ptr; + info.val = v; + return info; + } + + static LoweredValInfo boundMember( + BoundMemberInfo* boundMemberInfo); + + BoundMemberInfo* getBoundMemberInfo() + { + SLANG_ASSERT(flavor == Flavor::BoundMember); + return (BoundMemberInfo*)ext; + } + + static LoweredValInfo subscript( + SubscriptInfo* subscriptInfo); + + SubscriptInfo* getSubscriptInfo() + { + SLANG_ASSERT(flavor == Flavor::Subscript); + return (SubscriptInfo*)ext; + } + + static LoweredValInfo boundSubscript( + BoundSubscriptInfo* boundSubscriptInfo); + + BoundSubscriptInfo* getBoundSubscriptInfo() + { + SLANG_ASSERT(flavor == Flavor::BoundSubscript); + return (BoundSubscriptInfo*)ext; + } + + static LoweredValInfo swizzledLValue( + SwizzledLValueInfo* extInfo); + + SwizzledLValueInfo* getSwizzledLValueInfo() + { + SLANG_ASSERT(flavor == Flavor::SwizzledLValue); + return (SwizzledLValueInfo*)ext; + } +}; + +// Represents some declaration bound to a particular +// object. For example, if we had `obj.f` where `f` +// is a member function, we'd use a `BoundMemberInfo` +// to represnet this. +// +// Note: This case is largely avoided by special-casing +// in the handling of calls (like `obj.f(arg)`), but +// it is being left here as an example of what we might +// need/want to do in the long term. +struct BoundMemberInfo : ExtendedValueInfo +{ + // The base object + LoweredValInfo base; + + // The (AST-level) declaration reference. + DeclRef declRef; + + // The type of this value + IRType* type; +}; + +// Represents the result of a swizzle operation in +// an l-value context. A swizzle without duplicate +// elements is allowed as an l-value, even if the +// element are non-contiguous (`.xz`) or out of +// order (`.zxy`). +// +struct SwizzledLValueInfo : ExtendedValueInfo +{ + // The type of the expression. + IRType* type; + + // The base expression (this should be an l-value) + LoweredValInfo base; + + // The number of elements in the swizzle + UInt elementCount; + + // THe indices for the elements being swizzled + UInt elementIndices[4]; +}; + +LoweredValInfo LoweredValInfo::boundMember( + BoundMemberInfo* boundMemberInfo) +{ + LoweredValInfo info; + info.flavor = Flavor::BoundMember; + info.ext = boundMemberInfo; + return info; +} + +LoweredValInfo LoweredValInfo::subscript( + SubscriptInfo* subscriptInfo) +{ + LoweredValInfo info; + info.flavor = Flavor::Subscript; + info.ext = subscriptInfo; + return info; +} + +LoweredValInfo LoweredValInfo::boundSubscript( + BoundSubscriptInfo* boundSubscriptInfo) +{ + LoweredValInfo info; + info.flavor = Flavor::BoundSubscript; + info.ext = boundSubscriptInfo; + return info; +} + +LoweredValInfo LoweredValInfo::swizzledLValue( + SwizzledLValueInfo* extInfo) +{ + LoweredValInfo info; + info.flavor = Flavor::SwizzledLValue; + info.ext = extInfo; + return info; +} + +// An "environment" for mapping AST declarations to IR values. +// +// This is required because in some cases we might lower the +// same AST declaration to the IR multiple times (e.g., when +// a generic transitively contains multiple functions, we +// will emit a distinct IR generic for each function, with +// its own copies of the generic parameters). +// +struct IRGenEnv +{ + // Map an AST-level declaration to the IR-level value that represents it. + Dictionary mapDeclToValue; + + // The next outer env around this one + IRGenEnv* outer = nullptr; +}; + +struct SharedIRGenContext +{ + SharedIRGenContext( + Session* session, + DiagnosticSink* sink, + ModuleDecl* mainModuleDecl = nullptr) + : m_session(session) + , m_sink(sink) + , m_mainModuleDecl(mainModuleDecl) + {} + + Session* m_session = nullptr; + DiagnosticSink* m_sink = nullptr; + ModuleDecl* m_mainModuleDecl = nullptr; + + // The "global" environment for mapping declarations to their IR values. + IRGenEnv globalEnv; + + // Map an AST-level declaration of an interface + // requirement to the IR-level "key" that + // is used to fetch that requirement from a + // witness table. + Dictionary interfaceRequirementKeys; + + // Arrays we keep around strictly for memory-management purposes: + + // Any extended values created during lowering need + // to be cleaned up after the fact. We don't try + // to reference-count these along the way because + // they need to get stored into a `union` inside `LoweredValInfo` + List> extValues; + + // Map from an AST-level statement that can be + // used as the target of a `break` or `continue` + // to the appropriate basic block to jump to. + Dictionary breakLabels; + Dictionary continueLabels; +}; + + +struct IRGenContext +{ + // Shared state for the IR generation process + SharedIRGenContext* shared; + + // environment for mapping AST decls to IR values + IRGenEnv* env; + + // IR builder to use when building code under this context + IRBuilder* irBuilder; + + // The value to use for any `this` expressions + // that appear in the current context. + // + // TODO: If we ever allow nesting of (non-static) + // types, then we may need to support references + // to an "outer `this`", and this representation + // might be insufficient. + LoweredValInfo thisVal; + + explicit IRGenContext(SharedIRGenContext* inShared) + : shared(inShared) + , env(&inShared->globalEnv) + , irBuilder(nullptr) + {} + + Session* getSession() + { + return shared->m_session; + } + + DiagnosticSink* getSink() + { + return shared->m_sink; + } + + ModuleDecl* getMainModuleDecl() + { + return shared->m_mainModuleDecl; + } +}; + +void setGlobalValue(SharedIRGenContext* sharedContext, Decl* decl, LoweredValInfo value) +{ + sharedContext->globalEnv.mapDeclToValue[decl] = value; +} + +void setGlobalValue(IRGenContext* context, Decl* decl, LoweredValInfo value) +{ + setGlobalValue(context->shared, decl, value); +} + +void setValue(IRGenContext* context, Decl* decl, LoweredValInfo value) +{ + context->env->mapDeclToValue[decl] = value; +} + +ModuleDecl* findModuleDecl(Decl* decl) +{ + for (auto dd = decl; dd; dd = dd->ParentDecl) + { + if (auto moduleDecl = as(dd)) + return moduleDecl; + } + return nullptr; +} + +bool isFromStdLib(Decl* decl) +{ + for (auto dd = decl; dd; dd = dd->ParentDecl) + { + if (dd->HasModifier()) + return true; + } + return false; +} + +bool isImportedDecl(IRGenContext* context, Decl* decl) +{ + ModuleDecl* moduleDecl = findModuleDecl(decl); + if (!moduleDecl) + return false; + + // HACK: don't treat standard library code as + // being imported for right now, just because + // we don't load its IR in the same way as + // for other imports. + // + // TODO: Fix this the right way, by having standard + // library declarations have IR modules that we link + // in via the normal means. + if (isFromStdLib(decl)) + return false; + + if (moduleDecl != context->getMainModuleDecl()) + return true; + + return false; +} + + /// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration? +bool isEffectivelyStatic( + Decl* decl, + ContainerDecl* parentDecl); + +// Ensure that a version of the given declaration has been emitted to the IR +LoweredValInfo ensureDecl( + IRGenContext* context, + Decl* decl); + +// Emit code as needed to construct a reference to the given declaration with +// any needed specializations in place. +LoweredValInfo emitDeclRef( + IRGenContext* context, + DeclRef declRef, + IRType* type); + +IRInst* getSimpleVal(IRGenContext* context, LoweredValInfo lowered); + +IROp getIntrinsicOp( + Decl* decl, + IntrinsicOpModifier* intrinsicOpMod) +{ + if (int(intrinsicOpMod->op) != 0) + return intrinsicOpMod->op; + + // No specified modifier? Then we need to look it up + // based on the name of the declaration... + + auto name = decl->getName(); + auto nameText = getUnownedStringSliceText(name); + + IROp op = findIROp(nameText); + SLANG_ASSERT(op != kIROp_Invalid); + return op; +} + +// Given a `LoweredValInfo` for something callable, along with a +// bunch of arguments, emit an appropriate call to it. +LoweredValInfo emitCallToVal( + IRGenContext* context, + IRType* type, + LoweredValInfo funcVal, + UInt argCount, + IRInst* const* args) +{ + auto builder = context->irBuilder; + switch (funcVal.flavor) + { + case LoweredValInfo::Flavor::None: + SLANG_UNEXPECTED("null function"); + default: + return LoweredValInfo::simple( + builder->emitCallInst(type, getSimpleVal(context, funcVal), argCount, args)); + } +} + +LoweredValInfo emitCompoundAssignOp( + IRGenContext* context, + IRType* type, + IROp op, + UInt argCount, + IRInst* const* args) +{ + auto builder = context->irBuilder; + SLANG_UNREFERENCED_PARAMETER(argCount); + SLANG_ASSERT(argCount == 2); + auto leftPtr = args[0]; + auto rightVal = args[1]; + + auto leftVal = builder->emitLoad(leftPtr); + + IRInst* innerArgs[] = { leftVal, rightVal }; + auto innerOp = builder->emitIntrinsicInst(type, op, 2, innerArgs); + + builder->emitStore(leftPtr, innerOp); + + return LoweredValInfo::ptr(leftPtr); +} + +IRInst* getOneValOfType( + IRGenContext* context, + IRType* type) +{ + switch(type->op) + { + case kIROp_IntType: + case kIROp_UIntType: + case kIROp_UInt64Type: + return context->irBuilder->getIntValue(type, 1); + + case kIROp_HalfType: + case kIROp_FloatType: + case kIROp_DoubleType: + return context->irBuilder->getFloatValue(type, 1.0); + + default: + break; + } + + // TODO: should make sure to handle vector and matrix types here + + SLANG_UNEXPECTED("inc/dec type"); + UNREACHABLE_RETURN(nullptr); +} + +LoweredValInfo emitPrefixIncDecOp( + IRGenContext* context, + IRType* type, + IROp op, + UInt argCount, + IRInst* const* args) +{ + auto builder = context->irBuilder; + SLANG_UNREFERENCED_PARAMETER(argCount); + SLANG_ASSERT(argCount == 1); + auto argPtr = args[0]; + + auto preVal = builder->emitLoad(argPtr); + + IRInst* oneVal = getOneValOfType(context, type); + + IRInst* innerArgs[] = { preVal, oneVal }; + auto innerOp = builder->emitIntrinsicInst(type, op, 2, innerArgs); + + builder->emitStore(argPtr, innerOp); + + // For a prefix operator like `++i` we return + // the value after the increment/decrement has + // been applied. In casual terms we "increment + // the varaible, then return its value." + // + return LoweredValInfo::simple(innerOp); +} + +LoweredValInfo emitPostfixIncDecOp( + IRGenContext* context, + IRType* type, + IROp op, + UInt argCount, + IRInst* const* args) +{ + auto builder = context->irBuilder; + SLANG_UNREFERENCED_PARAMETER(argCount); + SLANG_ASSERT(argCount == 1); + auto argPtr = args[0]; + + auto preVal = builder->emitLoad(argPtr); + + IRInst* oneVal = getOneValOfType(context, type); + + IRInst* innerArgs[] = { preVal, oneVal }; + auto innerOp = builder->emitIntrinsicInst(type, op, 2, innerArgs); + + builder->emitStore(argPtr, innerOp); + + // For a postfix operator like `i++` we return + // the value that we read before the increment/decrement + // gets applied. In casual terms we "read + // the variable, then increment it." + // + return LoweredValInfo::simple(preVal); +} + +LoweredValInfo lowerRValueExpr( + IRGenContext* context, + Expr* expr); + +IRType* lowerType( + IRGenContext* context, + Type* type); + +static IRType* lowerType( + IRGenContext* context, + QualType const& type) +{ + return lowerType(context, type.type); +} + +// Given a `DeclRef` for something callable, along with a bunch of +// arguments, emit an appropriate call to it. +LoweredValInfo emitCallToDeclRef( + IRGenContext* context, + IRType* type, + DeclRef funcDeclRef, + IRType* funcType, + UInt argCount, + IRInst* const* args) +{ + auto builder = context->irBuilder; + + + if (auto subscriptDeclRef = funcDeclRef.as()) + { + // A reference to a subscript declaration is a special case, + // because it is not possible to call a subscript directly; + // we must call one of its accessors. + // + // TODO: everything here will also apply to propery declarations + // once we have them, so some of this code might be shared + // some day. + + DeclRef getterDeclRef; + bool justAGetter = true; + for (auto accessorDeclRef : getMembersOfType(subscriptDeclRef)) + { + // We want to track whether this subscript has any accessors other than + // `get` (assuming that everything except `get` can be used for setting...). + + if (auto foundGetterDeclRef = accessorDeclRef.as()) + { + // We found a getter. + getterDeclRef = foundGetterDeclRef; + } + else + { + // There was something other than a getter, so we can't + // invoke an accessor just now. + justAGetter = false; + } + } + + if (!justAGetter || !getterDeclRef) + { + // We can't perform an actual call right now, because + // this expression might appear in an r-value or l-value + // position (or *both* if it is being passed as an argument + // for an `in out` parameter!). + // + // Instead, we will construct a special-case value to + // represent the latent subscript operation (abstractly + // this is a reference to a storage location). + + // The abstract storage location will need to include + // all the arguments being passed to the subscript operation. + + RefPtr boundSubscript = new BoundSubscriptInfo(); + boundSubscript->declRef = subscriptDeclRef; + boundSubscript->type = type; + boundSubscript->args.addRange(args, argCount); + + context->shared->extValues.add(boundSubscript); + + return LoweredValInfo::boundSubscript(boundSubscript); + } + + // Otherwise we are just call the getter, and so that + // is what we need to be emitting a call to... + funcDeclRef = getterDeclRef; + } + + auto funcDecl = funcDeclRef.getDecl(); + if(auto intrinsicOpModifier = funcDecl->FindModifier()) + { + auto op = getIntrinsicOp(funcDecl, intrinsicOpModifier); + + if (isPseudoOp(op)) + { + switch (op) + { + case kIRPseudoOp_Pos: + return LoweredValInfo::simple(args[0]); + + case kIRPseudoOp_Sequence: + // The main effect of "operator comma" is to enforce + // sequencing of its operands, but Slang already + // implements a strictly left-to-right evaluation + // order for function arguments, so in practice we + // just need to compile `a, b` to the value of `b` + // (because argument evaluation already happened). + return LoweredValInfo::simple(args[1]); + +#define CASE(COMPOUND, OP) \ + case COMPOUND: return emitCompoundAssignOp(context, type, OP, argCount, args) + + CASE(kIRPseudoOp_AddAssign, kIROp_Add); + CASE(kIRPseudoOp_SubAssign, kIROp_Sub); + CASE(kIRPseudoOp_MulAssign, kIROp_Mul); + CASE(kIRPseudoOp_DivAssign, kIROp_Div); + CASE(kIRPseudoOp_ModAssign, kIROp_Mod); + CASE(kIRPseudoOp_AndAssign, kIROp_BitAnd); + CASE(kIRPseudoOp_OrAssign, kIROp_BitOr); + CASE(kIRPseudoOp_XorAssign, kIROp_BitXor); + CASE(kIRPseudoOp_LshAssign, kIROp_Lsh); + CASE(kIRPseudoOp_RshAssign, kIROp_Rsh); + +#undef CASE + +#define CASE(COMPOUND, OP) \ + case COMPOUND: return emitPrefixIncDecOp(context, type, OP, argCount, args) + CASE(kIRPseudoOp_PreInc, kIROp_Add); + CASE(kIRPseudoOp_PreDec, kIROp_Sub); +#undef CASE + +#define CASE(COMPOUND, OP) \ + case COMPOUND: return emitPostfixIncDecOp(context, type, OP, argCount, args) + CASE(kIRPseudoOp_PostInc, kIROp_Add); + CASE(kIRPseudoOp_PostDec, kIROp_Sub); +#undef CASE + default: + SLANG_UNIMPLEMENTED_X("IR pseudo-op"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + } + + return LoweredValInfo::simple(builder->emitIntrinsicInst( + type, + op, + argCount, + args)); + } + // TODO: handle target intrinsic modifier too... + + if( auto ctorDeclRef = funcDeclRef.as() ) + { + // HACK: we know all constructors are builtins for now, + // so we need to emit them as a call to the corresponding + // builtin operation. + // + // TODO: these should all either be intrinsic operations, + // or calls to library functions. + + return LoweredValInfo::simple(builder->emitConstructorInst(type, argCount, args)); + } + + // Fallback case is to emit an actual call. + if(!funcType) + { + List argTypes; + for(UInt ii = 0; ii < argCount; ++ii) + { + argTypes.add(args[ii]->getDataType()); + } + funcType = builder->getFuncType(argCount, argTypes.getBuffer(), type); + } + LoweredValInfo funcVal = emitDeclRef(context, funcDeclRef, funcType); + return emitCallToVal(context, type, funcVal, argCount, args); +} + +LoweredValInfo emitCallToDeclRef( + IRGenContext* context, + IRType* type, + DeclRef funcDeclRef, + IRType* funcType, + List const& args) +{ + return emitCallToDeclRef(context, type, funcDeclRef, funcType, args.getCount(), args.getBuffer()); +} + +IRInst* getFieldKey( + IRGenContext* context, + DeclRef field) +{ + return getSimpleVal(context, emitDeclRef(context, field, context->irBuilder->getKeyType())); +} + +LoweredValInfo extractField( + IRGenContext* context, + IRType* fieldType, + LoweredValInfo base, + DeclRef field) +{ + IRBuilder* builder = context->irBuilder; + + switch (base.flavor) + { + default: + { + IRInst* irBase = getSimpleVal(context, base); + return LoweredValInfo::simple( + builder->emitFieldExtract( + fieldType, + irBase, + getFieldKey(context, field))); + } + break; + + case LoweredValInfo::Flavor::BoundMember: + case LoweredValInfo::Flavor::BoundSubscript: + { + // The base value is one that is trying to defer a get-vs-set + // decision, so we will need to do the same. + + RefPtr boundMemberInfo = new BoundMemberInfo(); + boundMemberInfo->type = fieldType; + boundMemberInfo->base = base; + boundMemberInfo->declRef = field; + + context->shared->extValues.add(boundMemberInfo); + return LoweredValInfo::boundMember(boundMemberInfo); + } + break; + + case LoweredValInfo::Flavor::Ptr: + { + // We are "extracting" a field from an lvalue address, + // which means we should just compute an lvalue + // representing the field address. + IRInst* irBasePtr = base.val; + return LoweredValInfo::ptr( + builder->emitFieldAddress( + builder->getPtrType(fieldType), + irBasePtr, + getFieldKey(context, field))); + } + break; + } +} + + + +LoweredValInfo materialize( + IRGenContext* context, + LoweredValInfo lowered) +{ + auto builder = context->irBuilder; + +top: + switch(lowered.flavor) + { + case LoweredValInfo::Flavor::None: + case LoweredValInfo::Flavor::Simple: + case LoweredValInfo::Flavor::Ptr: + return lowered; + + case LoweredValInfo::Flavor::BoundSubscript: + { + auto boundSubscriptInfo = lowered.getBoundSubscriptInfo(); + + // We are being asked to extract a value from a subscript call + // (e.g., `base[index]`). We will first check if the subscript + // declared a getter and use that if possible, and then fall + // back to a `ref` accessor if one is defined. + // + // (Picking the `get` over the `ref` accessor simplifies things + // in case the `get` operation has a natural translation for + // a target, while the general `ref` case does not...) + + auto getters = getMembersOfType(boundSubscriptInfo->declRef); + if (getters.Count()) + { + lowered = emitCallToDeclRef( + context, + boundSubscriptInfo->type, + *getters.begin(), + nullptr, + boundSubscriptInfo->args); + goto top; + } + + auto refAccessors = getMembersOfType(boundSubscriptInfo->declRef); + if(refAccessors.Count()) + { + // The `ref` accessor will return a pointer to the value, so + // we need to reflect that in the type of our `call` instruction. + IRType* ptrType = context->irBuilder->getPtrType(boundSubscriptInfo->type); + + LoweredValInfo refVal = emitCallToDeclRef( + context, + ptrType, + *refAccessors.begin(), + nullptr, + boundSubscriptInfo->args); + + // The result from the call needs to be implicitly dereferenced, + // so that it can work as an l-value of the desired result type. + lowered = LoweredValInfo::ptr(getSimpleVal(context, refVal)); + + goto top; + } + + SLANG_UNEXPECTED("subscript had no getter"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + break; + + case LoweredValInfo::Flavor::BoundMember: + { + auto boundMemberInfo = lowered.getBoundMemberInfo(); + auto base = materialize(context, boundMemberInfo->base); + + auto declRef = boundMemberInfo->declRef; + if( auto fieldDeclRef = declRef.as() ) + { + lowered = extractField(context, boundMemberInfo->type, base, fieldDeclRef); + goto top; + } + else + { + + SLANG_UNEXPECTED("unexpected member flavor"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + } + break; + + case LoweredValInfo::Flavor::SwizzledLValue: + { + auto swizzleInfo = lowered.getSwizzledLValueInfo(); + + return LoweredValInfo::simple(builder->emitSwizzle( + swizzleInfo->type, + getSimpleVal(context, swizzleInfo->base), + swizzleInfo->elementCount, + swizzleInfo->elementIndices)); + } + + default: + SLANG_UNEXPECTED("unhandled value flavor"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + +} + +IRInst* getSimpleVal(IRGenContext* context, LoweredValInfo lowered) +{ + auto builder = context->irBuilder; + + // First, try to eliminate any "bound" operations along the chain, + // so that we are dealing with an ordinary value, or an l-value pointer. + lowered = materialize(context, lowered); + + switch(lowered.flavor) + { + case LoweredValInfo::Flavor::None: + return nullptr; + + case LoweredValInfo::Flavor::Simple: + return lowered.val; + + case LoweredValInfo::Flavor::Ptr: + return builder->emitLoad(lowered.val); + + default: + SLANG_UNEXPECTED("unhandled value flavor"); + UNREACHABLE_RETURN(nullptr); + } +} + +LoweredValInfo lowerVal( + IRGenContext* context, + Val* val); + +IRInst* lowerSimpleVal( + IRGenContext* context, + Val* val) +{ + auto lowered = lowerVal(context, val); + return getSimpleVal(context, lowered); +} + +LoweredValInfo lowerLValueExpr( + IRGenContext* context, + Expr* expr); + +void assign( + IRGenContext* context, + LoweredValInfo const& left, + LoweredValInfo const& right); + +IRInst* getAddress( + IRGenContext* context, + LoweredValInfo const& inVal, + SourceLoc diagnosticLocation); + +void lowerStmt( + IRGenContext* context, + Stmt* stmt); + +LoweredValInfo lowerDecl( + IRGenContext* context, + DeclBase* decl); + +IRType* getIntType( + IRGenContext* context) +{ + return context->irBuilder->getBasicType(BaseType::Int); +} + +static IRGeneric* getOuterGeneric(IRInst* gv) +{ + auto parentBlock = as(gv->getParent()); + if (!parentBlock) return nullptr; + + auto parentGeneric = as(parentBlock->getParent()); + return parentGeneric; +} + +static void addLinkageDecoration( + IRGenContext* context, + IRInst* inInst, + Decl* decl, + UnownedStringSlice const& mangledName) +{ + // If the instruction is nested inside one or more generics, + // then the mangled name should really apply to the outer-most + // generic, and not the declaration nested inside. + + auto builder = context->irBuilder; + + IRInst* inst = inInst; + while (auto outerGeneric = getOuterGeneric(inst)) + { + inst = outerGeneric; + } + + if(isImportedDecl(context, decl)) + { + builder->addImportDecoration(inst, mangledName); + } + else + { + builder->addExportDecoration(inst, mangledName); + } +} + +static void addLinkageDecoration( + IRGenContext* context, + IRInst* inst, + Decl* decl) +{ + addLinkageDecoration(context, inst, decl, getMangledName(decl).getUnownedSlice()); +} + +IRStructKey* getInterfaceRequirementKey( + IRGenContext* context, + Decl* requirementDecl) +{ + IRStructKey* requirementKey = nullptr; + if(context->shared->interfaceRequirementKeys.TryGetValue(requirementDecl, requirementKey)) + { + return requirementKey; + } + + IRBuilder builderStorage = *context->irBuilder; + auto builder = &builderStorage; + + builder->setInsertInto(builder->sharedBuilder->module->getModuleInst()); + + // Construct a key to serve as the representation of + // this requirement in the IR, and to allow lookup + // into the declaration. + requirementKey = builder->createStructKey(); + + addLinkageDecoration(context, requirementKey, requirementDecl); + + context->shared->interfaceRequirementKeys.Add(requirementDecl, requirementKey); + + return requirementKey; +} + + +SubstitutionSet lowerSubstitutions(IRGenContext* context, SubstitutionSet subst); +// + +struct ValLoweringVisitor : ValVisitor +{ + IRGenContext* context; + + IRBuilder* getBuilder() { return context->irBuilder; } + + LoweredValInfo visitVal(Val* /*val*/) + { + SLANG_UNIMPLEMENTED_X("value lowering"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + + LoweredValInfo visitGenericParamIntVal(GenericParamIntVal* val) + { + return emitDeclRef(context, val->declRef, + lowerType(context, GetType(val->declRef))); + } + + LoweredValInfo visitDeclaredSubtypeWitness(DeclaredSubtypeWitness* val) + { + return emitDeclRef(context, val->declRef, + context->irBuilder->getWitnessTableType()); + } + + LoweredValInfo visitTransitiveSubtypeWitness( + TransitiveSubtypeWitness* val) + { + // The base (subToMid) will turn into a value with + // witness-table type. + IRInst* baseWitnessTable = lowerSimpleVal(context, val->subToMid); + + // The next step should map to an interface requirement + // that is itself an interface conformance, so the result + // of lowering this value should be a "key" that we can + // use to look up a witness table. + IRInst* requirementKey = getInterfaceRequirementKey(context, val->midToSup.getDecl()); + + // TODO: There are some ugly cases here if `midToSup` is allowed + // to be an arbitrary witness, rather than just a declared one, + // and we should probably change the front-end representation + // to reflect the right constraints. + + return LoweredValInfo::simple(getBuilder()->emitLookupInterfaceMethodInst( + nullptr, + baseWitnessTable, + requirementKey)); + } + + LoweredValInfo visitTaggedUnionSubtypeWitness( + TaggedUnionSubtypeWitness* val) + { + // The sub-type in this case is a tagged union `A | B | ...`, + // and the witness holds an array of witnesses showing that each + // "case" (`A`, `B`, etc.) is a subtype of the super-type. + + // We will start by getting the IR-level representation of the + // sub type (the tagged union type). + // + auto irTaggedUnionType = lowerType(context, val->sub); + + // We can turn each of those per-case witnesses into a witness + // table value: + // + auto caseCount = val->caseWitnesses.getCount(); + List caseWitnessTables; + for( auto caseWitness : val->caseWitnesses ) + { + auto caseWitnessTable = lowerSimpleVal(context, caseWitness); + caseWitnessTables.add(caseWitnessTable); + } + + // Now we need to synthesize a witness table for the tagged union + // value, showing how it can implement all of the requirements + // of the super type by delegating to the appropriate implementation + // on a per-case basis. + // + // We will assume here that the super-type is an interface, and it + // will be left to the front-end to ensure this property. + // + auto supDeclRefType = as(val->sup); + if(!supDeclRefType) + { + SLANG_UNEXPECTED("super-type not a decl-ref type when generating tagged union witness table"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + auto supInterfaceDeclRef = supDeclRefType->declRef.as(); + if( !supInterfaceDeclRef ) + { + SLANG_UNEXPECTED("super-type not an interface type when generating tagged union witness table"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + + auto irWitnessTable = getBuilder()->createWitnessTable(); + + // Now we will iterate over the requirements (members) of the + // interface and try to synthesize an appropriate value for each. + // + for( auto reqDeclRef : getMembers(supInterfaceDeclRef) ) + { + // TODO: if there are any members we shouldn't process as a requirement, + // then we should detect and skip them here. + // + + // Every interface requirement will have a unique key that is used + // when looking up the requirement in a concrete witness table. + // + auto irReqKey = getInterfaceRequirementKey(context, reqDeclRef.getDecl()); + + // We expect that each of the witness tables in `caseWitnessTables` + // will have an entry to match these keys. However, we may not + // have a concrete `IRWitnessTable` for each of the case types, either + // because they are a specialization of a generic (so that the witness + // table reference is a `specialize` instruction at this point), or + // they are a type external to this module (so that we have a declaration + // rather than a definition of the witness table). + + // Our task is to create an IR value that can satisfy the interface + // requirement for the tagged union type, by appropriately delegating + // to the implementations of the same requirement in the case types. + // + IRInst* irSatisfyingVal = nullptr; + + + + if(auto callableDeclRef = reqDeclRef.as()) + { + // We have something callable, so we need to synthesize + // a function to satisfy it. + // + auto irFunc = getBuilder()->createFunc(); + irSatisfyingVal = irFunc; + + IRBuilder subBuilderStorage; + auto subBuilder = &subBuilderStorage; + subBuilder->sharedBuilder = getBuilder()->sharedBuilder; + subBuilder->setInsertInto(irFunc); + + // We will start by setting up the function parameters, + // which live in the entry block of the IR function. + // + auto entryBlock = subBuilder->emitBlock(); + subBuilder->setInsertInto(entryBlock); + + // Create a `this` parameter of the tagged-union type. + // + // TODO: need to handle the `[mutating]` case here... + // + auto irThisType = irTaggedUnionType; + auto irThisParam = subBuilder->emitParam(irThisType); + + List irParamTypes; + irParamTypes.add(irThisType); + + // Create the remaining parameters of the callable, + // using a decl-ref specialized to the tagged union + // type (so that things like associated types are + // mapped to the correct witness value). + // + List irParams; + for( auto paramDeclRef : getMembersOfType(callableDeclRef) ) + { + // TODO: need to handle `out` and `in out` here. Over all + // there is a lot of duplication here with the existing logic + // for emitting the signature of a `CallableDecl`, and we should + // try to re-use that if at all possible. + // + auto irParamType = lowerType(context, GetType(paramDeclRef)); + auto irParam = subBuilder->emitParam(irParamType); + + irParams.add(irParam); + irParamTypes.add(irParamType); + } + + auto irResultType = lowerType(context, GetResultType(callableDeclRef)); + + auto irFuncType = subBuilder->getFuncType( + irParamTypes, + irResultType); + irFunc->setFullType(irFuncType); + + // The first thing our function needs to do is extract the tag + // from the incoming `this` parameter. + // + auto irTagVal = subBuilder->emitExtractTaggedUnionTag(irThisParam); + + // Next we want to emit a `switch` on the tag value, but before we + // do that we need to generate the code for each of the cases so that + // our `switch` has somewhere to branch to. + // + List switchCaseOperands; + + IRBlock* defaultLabel = nullptr; + + for( Index ii = 0; ii < caseCount; ++ii ) + { + auto caseTag = subBuilder->getIntValue(irTagVal->getDataType(), ii); + + subBuilder->setInsertInto(irFunc); + auto caseLabel = subBuilder->emitBlock(); + + if(!defaultLabel) + defaultLabel = caseLabel; + + switchCaseOperands.add(caseTag); + switchCaseOperands.add(caseLabel); + + subBuilder->setInsertInto(caseLabel); + + // We need to look up the satisfying value for this interface + // requirement on the witness table of the particular case value. + // + // We already have the witness table, and the requirement key is + // just `irReqKey`. + // + auto caseWitnessTable = caseWitnessTables[ii]; + + // The subtle bit here is determining the type we expect the + // satisfying value to have, since that depends on the actual + // type that is satisfying the requirement. + // + IRType* caseResultType = irResultType; + IRType* caseFuncType = nullptr; + auto caseFunc = subBuilder->emitLookupInterfaceMethodInst( + caseFuncType, + caseWitnessTable, + irReqKey); + + // We are going to emit a `call` to the satisfying value + // for the case type, so we will collect the arguments for that call. + // + List caseArgs; + + // The `this` argument to the call will need to represent the + // appropriate field of our tagged union. + // + IRType* caseThisType = (IRType*) irTaggedUnionType->getOperand(ii); + auto caseThisArg = subBuilder->emitExtractTaggedUnionPayload( + caseThisType, + irThisParam, caseTag); + caseArgs.add(caseThisArg); + + // The remaining arguments to the call will just be forwarded from + // the parameters of the wrapper function. + // + // TODO: This would need to change if/when we started allowing `This` type + // or associated-type parameters to be used at call sites where a tagged + // union is used. + // + for( auto param : irParams ) + { + caseArgs.add(param); + } + + auto caseCall = subBuilder->emitCallInst(caseResultType, caseFunc, caseArgs); + + if( as(irResultType->getDataType()) ) + { + subBuilder->emitReturn(); + } + else + { + subBuilder->emitReturn(caseCall); + } + } + + // We will create a block to represent the supposedly-unreachable + // code that will run if no `case` matches. + // + subBuilder->setInsertInto(irFunc); + auto invalidLabel = subBuilder->emitBlock(); + subBuilder->setInsertInto(invalidLabel); + subBuilder->emitUnreachable(); + + if(!defaultLabel) defaultLabel = invalidLabel; + + // Now we have enough information to go back and emit the `switch` instruction + // into the entry block. + subBuilder->setInsertInto(entryBlock); + subBuilder->emitSwitch( + irTagVal, // value to `switch` on + invalidLabel, // `break` label (block after the `switch` statement ends) + defaultLabel, // `default` label (where to go if no `case` matches) + switchCaseOperands.getCount(), + switchCaseOperands.getBuffer()); + } + else + { + // TODO: We need to handle other cases of interface requirements. + SLANG_UNEXPECTED("unexpceted interface requirement when generating tagged union witness table"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + + // Once we've generating a value to satisfying the requirement, we install + // it into the witness table for our tagged-union type. + // + getBuilder()->createWitnessTableEntry(irWitnessTable, irReqKey, irSatisfyingVal); + } + + return LoweredValInfo::simple(irWitnessTable); + } + + LoweredValInfo visitConstantIntVal(ConstantIntVal* val) + { + // TODO: it is a bit messy here that the `ConstantIntVal` representation + // has no notion of a *type* associated with the value... + + auto type = getIntType(context); + return LoweredValInfo::simple(getBuilder()->getIntValue(type, val->value)); + } + + IRFuncType* visitFuncType(FuncType* type) + { + IRType* resultType = lowerType(context, type->getResultType()); + UInt paramCount = type->getParamCount(); + List paramTypes; + for (UInt pp = 0; pp < paramCount; ++pp) + { + paramTypes.add(lowerType(context, type->getParamType(pp))); + } + return getBuilder()->getFuncType( + paramCount, + paramTypes.getBuffer(), + resultType); + } + + IRType* visitDeclRefType(DeclRefType* type) + { + auto declRef = type->declRef; + auto decl = declRef.getDecl(); + + // Check for types with teh `__intrinsic_type` modifier. + if(decl->FindModifier()) + { + return lowerSimpleIntrinsicType(type); + } + + + return (IRType*) getSimpleVal( + context, + emitDeclRef(context, declRef, + context->irBuilder->getTypeKind())); + } + + IRType* visitNamedExpressionType(NamedExpressionType* type) + { + return (IRType*)getSimpleVal(context, dispatchType(type->GetCanonicalType())); + } + + IRType* visitBasicExpressionType(BasicExpressionType* type) + { + return getBuilder()->getBasicType( + type->baseType); + } + + IRType* visitVectorExpressionType(VectorExpressionType* type) + { + auto elementType = lowerType(context, type->elementType); + auto elementCount = lowerSimpleVal(context, type->elementCount); + + return getBuilder()->getVectorType( + elementType, + elementCount); + } + + IRType* visitMatrixExpressionType(MatrixExpressionType* type) + { + auto elementType = lowerType(context, type->getElementType()); + auto rowCount = lowerSimpleVal(context, type->getRowCount()); + auto columnCount = lowerSimpleVal(context, type->getColumnCount()); + + return getBuilder()->getMatrixType( + elementType, + rowCount, + columnCount); + } + + IRType* visitArrayExpressionType(ArrayExpressionType* type) + { + auto elementType = lowerType(context, type->baseType); + if (type->ArrayLength) + { + auto elementCount = lowerSimpleVal(context, type->ArrayLength); + return getBuilder()->getArrayType( + elementType, + elementCount); + } + else + { + return getBuilder()->getUnsizedArrayType( + elementType); + } + } + + // Lower a type where the type declaration being referenced is assumed + // to be an intrinsic type, which can thus be lowered to a simple IR + // type with the appropriate opcode. + IRType* lowerSimpleIntrinsicType(DeclRefType* type) + { + auto intrinsicTypeModifier = type->declRef.getDecl()->FindModifier(); + SLANG_ASSERT(intrinsicTypeModifier); + IROp op = IROp(intrinsicTypeModifier->irOp); + return getBuilder()->getType(op); + } + + // Lower a type where the type declaration being referenced is assumed + // to be an intrinsic type with a single generic type parameter, and + // which can thus be lowered to a simple IR type with the appropriate opcode. + IRType* lowerGenericIntrinsicType(DeclRefType* type, Type* elementType) + { + auto intrinsicTypeModifier = type->declRef.getDecl()->FindModifier(); + SLANG_ASSERT(intrinsicTypeModifier); + IROp op = IROp(intrinsicTypeModifier->irOp); + IRInst* irElementType = lowerType(context, elementType); + return getBuilder()->getType( + op, + 1, + &irElementType); + } + + IRType* lowerGenericIntrinsicType(DeclRefType* type, Type* elementType, IntVal* count) + { + auto intrinsicTypeModifier = type->declRef.getDecl()->FindModifier(); + SLANG_ASSERT(intrinsicTypeModifier); + IROp op = IROp(intrinsicTypeModifier->irOp); + IRInst* irElementType = lowerType(context, elementType); + + IRInst* irCount = lowerSimpleVal(context, count); + + IRInst* const operands[2] = + { + irElementType, + irCount, + }; + + return getBuilder()->getType( + op, + SLANG_COUNT_OF(operands), + operands); + } + + IRType* visitResourceType(ResourceType* type) + { + return lowerGenericIntrinsicType(type, type->elementType); + } + + IRType* visitSamplerStateType(SamplerStateType* type) + { + return lowerSimpleIntrinsicType(type); + } + + IRType* visitBuiltinGenericType(BuiltinGenericType* type) + { + return lowerGenericIntrinsicType(type, type->elementType); + } + + IRType* visitUntypedBufferResourceType(UntypedBufferResourceType* type) + { + return lowerSimpleIntrinsicType(type); + } + + IRType* visitHLSLPatchType(HLSLPatchType* type) + { + Type* elementType = type->getElementType(); + IntVal* count = type->getElementCount(); + + return lowerGenericIntrinsicType(type, elementType, count); + } + + IRType* visitExtractExistentialType(ExtractExistentialType* type) + { + auto declRef = type->declRef; + auto existentialType = lowerType(context, GetType(declRef)); + IRInst* existentialVal = getSimpleVal(context, emitDeclRef(context, declRef, existentialType)); + return getBuilder()->emitExtractExistentialType(existentialVal); + } + + LoweredValInfo visitExtractExistentialSubtypeWitness(ExtractExistentialSubtypeWitness* witness) + { + auto declRef = witness->declRef; + auto existentialType = lowerType(context, GetType(declRef)); + IRInst* existentialVal = getSimpleVal(context, emitDeclRef(context, declRef, existentialType)); + return LoweredValInfo::simple(getBuilder()->emitExtractExistentialWitnessTable(existentialVal)); + } + + LoweredValInfo visitTaggedUnionType(TaggedUnionType* type) + { + // A tagged union type will lower into an IR `union` over the cases, + // along with an IR `struct` with a field for the union and a tag. + // (Note: we are placing the tag after the payload to avoid padding + // in the case where the payload is more aligned than the tag) + // + // TODO: should we be lowering directly like this, or have + // an IR-level representation of tagged unions? + // + + List irCaseTypes; + for(auto caseType : type->caseTypes) + { + auto irCaseType = lowerType(context, caseType); + irCaseTypes.add(irCaseType); + } + + auto irType = getBuilder()->getTaggedUnionType(irCaseTypes); + if(!irType->findDecoration()) + { + // We need a way for later passes to attach layout information + // to this type, so we will give it a mangled name here. + // + getBuilder()->addExportDecoration( + irType, + getMangledTypeName(type).getUnownedSlice()); + } + return LoweredValInfo::simple(irType); + } + + LoweredValInfo visitExistentialSpecializedType(ExistentialSpecializedType* type) + { + auto irBaseType = lowerType(context, type->baseType); + + List slotArgs; + for(auto arg : type->slots.args) + { + auto irArgType = lowerType(context, arg.type); + auto irArgWitness = lowerSimpleVal(context, arg.witness); + + slotArgs.add(irArgType); + slotArgs.add(irArgWitness); + } + + auto irType = getBuilder()->getBindExistentialsType(irBaseType, slotArgs.getCount(), slotArgs.getBuffer()); + return LoweredValInfo::simple(irType); + } + + // We do not expect to encounter the following types in ASTs that have + // passed front-end semantic checking. +#define UNEXPECTED_CASE(NAME) IRType* visit##NAME(NAME*) { SLANG_UNEXPECTED(#NAME); UNREACHABLE_RETURN(nullptr); } + UNEXPECTED_CASE(GenericDeclRefType) + UNEXPECTED_CASE(TypeType) + UNEXPECTED_CASE(ErrorType) + UNEXPECTED_CASE(InitializerListType) + UNEXPECTED_CASE(OverloadGroupType) +}; + +LoweredValInfo lowerVal( + IRGenContext* context, + Val* val) +{ + ValLoweringVisitor visitor; + visitor.context = context; + return visitor.dispatch(val); +} + +IRType* lowerType( + IRGenContext* context, + Type* type) +{ + ValLoweringVisitor visitor; + visitor.context = context; + return (IRType*) getSimpleVal(context, visitor.dispatchType(type)); +} + +void addVarDecorations( + IRGenContext* context, + IRInst* inst, + Decl* decl) +{ + auto builder = context->irBuilder; + for(RefPtr mod : decl->modifiers) + { + if(as(mod)) + { + builder->addInterpolationModeDecoration(inst, IRInterpolationMode::NoInterpolation); + } + else if(as(mod)) + { + builder->addInterpolationModeDecoration(inst, IRInterpolationMode::NoPerspective); + } + else if(as(mod)) + { + builder->addInterpolationModeDecoration(inst, IRInterpolationMode::Linear); + } + else if(as(mod)) + { + builder->addInterpolationModeDecoration(inst, IRInterpolationMode::Sample); + } + else if(as(mod)) + { + builder->addInterpolationModeDecoration(inst, IRInterpolationMode::Centroid); + } + else if(as(mod)) + { + builder->addSimpleDecoration(inst); + } + else if(as(mod)) + { + builder->addSimpleDecoration(inst); + } + else if(as(mod)) + { + builder->addSimpleDecoration(inst); + } + else if(as(mod)) + { + builder->addSimpleDecoration(inst); + } + else if(as(mod)) + { + builder->addSimpleDecoration(inst); + } + else if(auto formatAttr = as(mod)) + { + builder->addFormatDecoration(inst, formatAttr->format); + } + + // TODO: what are other modifiers we need to propagate through? + } +} + +/// If `decl` has a modifier that should turn into a +/// rate qualifier, then apply it to `inst`. +void maybeSetRate( + IRGenContext* context, + IRInst* inst, + Decl* decl) +{ + auto builder = context->irBuilder; + + if (decl->HasModifier()) + { + inst->setFullType(builder->getRateQualifiedType( + builder->getGroupSharedRate(), + inst->getFullType())); + } +} + +static String getNameForNameHint( + IRGenContext* context, + Decl* decl) +{ + // We will use a bit of an ad hoc convention here for now. + + Name* leafName = decl->getName(); + + // Handle custom name for a global parameter group (e.g., a `cbuffer`) + if(auto reflectionNameModifier = decl->FindModifier()) + { + leafName = reflectionNameModifier->nameAndLoc.name; + } + + // There is no point in trying to provide a name hint for something with no name, + // or with an empty name + if(!leafName) + return String(); + if(leafName->text.getLength() == 0) + return String(); + + + if(auto varDecl = as(decl)) + { + // For an ordinary local variable, global variable, + // parameter, or field, we will just use the name + // as declared, and now work in anything from + // its parent declaration(s). + // + // TODO: consider whether global/static variables should + // follow different rules. + // + return leafName->text; + } + + // For other cases of declaration, we want to consider + // merging its name with the name of its parent declaration. + auto parentDecl = decl->ParentDecl; + + // Skip past a generic parent, if we are a declaration nested in a generic. + if(auto genericParentDecl = as(parentDecl)) + parentDecl = genericParentDecl->ParentDecl; + + // A `ModuleDecl` can have a name too, but in the common case + // we don't want to generate name hints that include the module + // name, simply because they would lead to every global symbol + // getting a much longer name. + // + // TODO: We should probably include the module name for symbols + // being `import`ed, and not for symbols being compiled directly + // (those coming from a module that had no name given to it). + // + // For now we skip past a `ModuleDecl` parent. + // + if(auto moduleParentDecl = as(parentDecl)) + parentDecl = moduleParentDecl->ParentDecl; + + if(!parentDecl) + { + return leafName->text; + } + + auto parentName = getNameForNameHint(context, parentDecl); + if(parentName.getLength() == 0) + { + return leafName->text; + } + + // We will now construct a new `Name` to use as the hint, + // combining the name of the parent and the leaf declaration. + + StringBuilder sb; + sb.append(parentName); + sb.append("."); + sb.append(leafName->text); + + return sb.ProduceString(); +} + +/// Try to add an appropriate name hint to the instruction, +/// that can be used for back-end code emission or debug info. +static void addNameHint( + IRGenContext* context, + IRInst* inst, + Decl* decl) +{ + String name = getNameForNameHint(context, decl); + if(name.getLength() == 0) + return; + context->irBuilder->addNameHintDecoration(inst, name.getUnownedSlice()); +} + +/// Add a name hint based on a fixed string. +static void addNameHint( + IRGenContext* context, + IRInst* inst, + char const* text) +{ + context->irBuilder->addNameHintDecoration(inst, UnownedTerminatedStringSlice(text)); +} + +LoweredValInfo createVar( + IRGenContext* context, + IRType* type, + Decl* decl = nullptr) +{ + auto builder = context->irBuilder; + auto irAlloc = builder->emitVar(type); + + if (decl) + { + maybeSetRate(context, irAlloc, decl); + + addVarDecorations(context, irAlloc, decl); + + builder->addHighLevelDeclDecoration(irAlloc, decl); + + addNameHint(context, irAlloc, decl); + } + + return LoweredValInfo::ptr(irAlloc); +} + +void addArgs( + IRGenContext* context, + List* ioArgs, + LoweredValInfo argInfo) +{ + auto& args = *ioArgs; + switch( argInfo.flavor ) + { + case LoweredValInfo::Flavor::Simple: + case LoweredValInfo::Flavor::Ptr: + case LoweredValInfo::Flavor::SwizzledLValue: + case LoweredValInfo::Flavor::BoundSubscript: + case LoweredValInfo::Flavor::BoundMember: + args.add(getSimpleVal(context, argInfo)); + break; + + default: + SLANG_UNIMPLEMENTED_X("addArgs case"); + break; + } +} + +// + +// When we try to turn a `LoweredValInfo` into an address of some temporary storage, +// we can either do it "aggressively" or not (what we'll call the "default" behavior, +// although it isn't strictly more common). +// +// The case that this is mostly there to address is when somebody writes an operation +// like: +// +// foo[a] = b; +// +// In that case, we might as well just use the `set` accessor if there is one, rather +// than complicate things. However, in more complex cases like: +// +// foo[a].x = b; +// +// there is no way to satisfy the semantics of the code the user wrote (in terms of +// only writing one vector component, and not a full vector) by using the `set` +// accessor, and we need to be "aggressive" in turning the lvalue `foo[a]` into +// an address. +// +// TODO: realistically IR lowering is too early to be binding to this choice, +// because different accessors might be supported on different targets. +// +enum class TryGetAddressMode +{ + Default, + Aggressive, +}; + +/// Try to coerce `inVal` into a `LoweredValInfo::ptr()` with a simple address. +LoweredValInfo tryGetAddress( + IRGenContext* context, + LoweredValInfo const& inVal, + TryGetAddressMode mode); + + +// + +template +struct ExprLoweringVisitorBase : ExprVisitor +{ + IRGenContext* context; + + IRBuilder* getBuilder() { return context->irBuilder; } + + // Lower an expression that should have the same l-value-ness + // as the visitor itself. + LoweredValInfo lowerSubExpr(Expr* expr) + { + IRBuilderSourceLocRAII sourceLocInfo(getBuilder(), expr->loc); + return this->dispatch(expr); + } + + + LoweredValInfo visitVarExpr(VarExpr* expr) + { + LoweredValInfo info = emitDeclRef( + context, + expr->declRef, + lowerType(context, expr->type)); + return info; + } + + LoweredValInfo visitOverloadedExpr(OverloadedExpr* /*expr*/) + { + SLANG_UNEXPECTED("overloaded expressions should not occur in checked AST"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + + LoweredValInfo visitOverloadedExpr2(OverloadedExpr2* /*expr*/) + { + SLANG_UNEXPECTED("overloaded expressions should not occur in checked AST"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + + LoweredValInfo visitIndexExpr(IndexExpr* expr) + { + auto type = lowerType(context, expr->type); + auto baseVal = lowerSubExpr(expr->BaseExpression); + auto indexVal = getSimpleVal(context, lowerRValueExpr(context, expr->IndexExpression)); + + return subscriptValue(type, baseVal, indexVal); + } + + LoweredValInfo visitThisExpr(ThisExpr* /*expr*/) + { + return context->thisVal; + } + + LoweredValInfo visitMemberExpr(MemberExpr* expr) + { + auto loweredType = lowerType(context, expr->type); + auto loweredBase = lowerRValueExpr(context, expr->BaseExpression); + + auto declRef = expr->declRef; + if (auto fieldDeclRef = declRef.as()) + { + // Okay, easy enough: we have a reference to a field of a struct type... + return extractField(loweredType, loweredBase, fieldDeclRef); + } + else if (auto callableDeclRef = declRef.as()) + { + RefPtr boundMemberInfo = new BoundMemberInfo(); + boundMemberInfo->type = nullptr; + boundMemberInfo->base = loweredBase; + boundMemberInfo->declRef = callableDeclRef; + return LoweredValInfo::boundMember(boundMemberInfo); + } + else if(auto constraintDeclRef = declRef.as()) + { + // The code is making use of a "witness" that a value of + // some generic type conforms to an interface. + // + // For now we will just emit the base expression as-is. + // TODO: we may need to insert an explicit instruction + // for a cast here (that could become a no-op later). + return loweredBase; + } + + SLANG_UNIMPLEMENTED_X("codegen for subscript expression"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + + // We will always lower a dereference expression (`*ptr`) + // as an l-value, since that is the easiest way to handle it. + LoweredValInfo visitDerefExpr(DerefExpr* expr) + { + auto loweredBase = lowerRValueExpr(context, expr->base); + + // TODO: handle tupel-type for `base` + + // The type of the lowered base must by some kind of pointer, + // in order for a dereference to make senese, so we just + // need to extract the value type from that pointer here. + // + IRInst* loweredBaseVal = getSimpleVal(context, loweredBase); + IRType* loweredBaseType = loweredBaseVal->getDataType(); + + if (as(loweredBaseType) + || as(loweredBaseType)) + { + // Note that we do *not* perform an actual `load` operation + // here, but rather just use the pointer value to construct + // an appropriate `LoweredValInfo` representing the underlying + // dereference. + // + // This is important so that an expression like `&((*foo).bar)` + // (which is desugared from `&foo->bar`) can be handled; such + // an expression does *not* perform a dereference at runtime, + // and is just a bit of pointer math. + // + return LoweredValInfo::ptr(loweredBaseVal); + } + else + { + SLANG_UNIMPLEMENTED_X("codegen for deref expression"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + } + + LoweredValInfo visitParenExpr(ParenExpr* expr) + { + return lowerSubExpr(expr->base); + } + + LoweredValInfo getSimpleDefaultVal(IRType* type) + { + if(auto basicType = as(type)) + { + switch( basicType->getBaseType() ) + { + default: + SLANG_UNEXPECTED("missing case for getting IR default value"); + UNREACHABLE_RETURN(LoweredValInfo()); + break; + + case BaseType::Bool: + case BaseType::Int8: + case BaseType::Int16: + case BaseType::Int: + case BaseType::Int64: + case BaseType::UInt8: + case BaseType::UInt16: + case BaseType::UInt: + case BaseType::UInt64: + return LoweredValInfo::simple(getBuilder()->getIntValue(type, 0)); + + case BaseType::Half: + case BaseType::Float: + case BaseType::Double: + return LoweredValInfo::simple(getBuilder()->getFloatValue(type, 0.0)); + } + } + + SLANG_UNEXPECTED("missing case for getting IR default value"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + + LoweredValInfo getDefaultVal(Type* type) + { + auto irType = lowerType(context, type); + if (auto basicType = as(type)) + { + return getSimpleDefaultVal(irType); + } + else if (auto vectorType = as(type)) + { + UInt elementCount = (UInt) GetIntVal(vectorType->elementCount); + + auto irDefaultValue = getSimpleVal(context, getDefaultVal(vectorType->elementType)); + + List args; + for(UInt ee = 0; ee < elementCount; ++ee) + { + args.add(irDefaultValue); + } + return LoweredValInfo::simple( + getBuilder()->emitMakeVector(irType, args.getCount(), args.getBuffer())); + } + else if (auto matrixType = as(type)) + { + UInt rowCount = (UInt) GetIntVal(matrixType->getRowCount()); + + auto rowType = matrixType->getRowType(); + + auto irDefaultValue = getSimpleVal(context, getDefaultVal(rowType)); + + List args; + for(UInt rr = 0; rr < rowCount; ++rr) + { + args.add(irDefaultValue); + } + return LoweredValInfo::simple( + getBuilder()->emitMakeMatrix(irType, args.getCount(), args.getBuffer())); + } + else if (auto arrayType = as(type)) + { + UInt elementCount = (UInt) GetIntVal(arrayType->ArrayLength); + + auto irDefaultElement = getSimpleVal(context, getDefaultVal(arrayType->baseType)); + + List args; + for(UInt ee = 0; ee < elementCount; ++ee) + { + args.add(irDefaultElement); + } + + return LoweredValInfo::simple( + getBuilder()->emitMakeArray(irType, args.getCount(), args.getBuffer())); + } + else if (auto declRefType = as(type)) + { + DeclRef declRef = declRefType->declRef; + if (auto aggTypeDeclRef = declRef.as()) + { + List args; + for (auto ff : getMembersOfType(aggTypeDeclRef)) + { + if (ff.getDecl()->HasModifier()) + continue; + + auto irFieldVal = getSimpleVal(context, getDefaultVal(ff)); + args.add(irFieldVal); + } + + return LoweredValInfo::simple( + getBuilder()->emitMakeStruct(irType, args.getCount(), args.getBuffer())); + } + } + + SLANG_UNEXPECTED("unexpected type when creating default value"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + + LoweredValInfo getDefaultVal(VarDeclBase* decl) + { + if(auto initExpr = decl->initExpr) + { + return lowerRValueExpr(context, initExpr); + } + else + { + return getDefaultVal(decl->type); + } + } + + LoweredValInfo visitInitializerListExpr(InitializerListExpr* expr) + { + // Allocate a temporary of the given type + auto type = expr->type; + IRType* irType = lowerType(context, type); + List args; + + UInt argCount = expr->args.getCount(); + + // If the initializer list was empty, then the user was + // asking for default initialization, which should apply + // to (almost) any type. + // + if(argCount == 0) + { + return getDefaultVal(type.type); + } + + // Now for each argument in the initializer list, + // fill in the appropriate field of the result + if (auto arrayType = as(type)) + { + UInt elementCount = (UInt) GetIntVal(arrayType->ArrayLength); + + for (UInt ee = 0; ee < argCount; ++ee) + { + auto argExpr = expr->args[ee]; + LoweredValInfo argVal = lowerRValueExpr(context, argExpr); + args.add(getSimpleVal(context, argVal)); + } + if(elementCount > argCount) + { + auto irDefaultValue = getSimpleVal(context, getDefaultVal(arrayType->baseType)); + for(UInt ee = argCount; ee < elementCount; ++ee) + { + args.add(irDefaultValue); + } + } + + return LoweredValInfo::simple( + getBuilder()->emitMakeArray(irType, args.getCount(), args.getBuffer())); + } + else if (auto vectorType = as(type)) + { + UInt elementCount = (UInt) GetIntVal(vectorType->elementCount); + + for (UInt ee = 0; ee < argCount; ++ee) + { + auto argExpr = expr->args[ee]; + LoweredValInfo argVal = lowerRValueExpr(context, argExpr); + args.add(getSimpleVal(context, argVal)); + } + if(elementCount > argCount) + { + auto irDefaultValue = getSimpleVal(context, getDefaultVal(vectorType->elementType)); + for(UInt ee = argCount; ee < elementCount; ++ee) + { + args.add(irDefaultValue); + } + } + + return LoweredValInfo::simple( + getBuilder()->emitMakeVector(irType, args.getCount(), args.getBuffer())); + } + else if (auto matrixType = as(type)) + { + UInt rowCount = (UInt) GetIntVal(matrixType->getRowCount()); + + for (UInt rr = 0; rr < argCount; ++rr) + { + auto argExpr = expr->args[rr]; + LoweredValInfo argVal = lowerRValueExpr(context, argExpr); + args.add(getSimpleVal(context, argVal)); + } + if(rowCount > argCount) + { + auto rowType = matrixType->getRowType(); + auto irDefaultValue = getSimpleVal(context, getDefaultVal(rowType)); + + for(UInt rr = argCount; rr < rowCount; ++rr) + { + args.add(irDefaultValue); + } + } + + return LoweredValInfo::simple( + getBuilder()->emitMakeMatrix(irType, args.getCount(), args.getBuffer())); + } + else if (auto declRefType = as(type)) + { + DeclRef declRef = declRefType->declRef; + if (auto aggTypeDeclRef = declRef.as()) + { + UInt argCounter = 0; + for (auto ff : getMembersOfType(aggTypeDeclRef)) + { + if (ff.getDecl()->HasModifier()) + continue; + + UInt argIndex = argCounter++; + if (argIndex < argCount) + { + auto argExpr = expr->args[argIndex]; + LoweredValInfo argVal = lowerRValueExpr(context, argExpr); + args.add(getSimpleVal(context, argVal)); + } + else + { + auto irDefaultValue = getSimpleVal(context, getDefaultVal(ff)); + args.add(irDefaultValue); + } + } + + return LoweredValInfo::simple( + getBuilder()->emitMakeStruct(irType, args.getCount(), args.getBuffer())); + } + } + + // If none of the above cases matched, then we had better + // have zero arguments in the initializer list, in which + // case we are just looking for default initialization. + // + SLANG_UNEXPECTED("unhandled case for initializer list codegen"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + + LoweredValInfo visitBoolLiteralExpr(BoolLiteralExpr* expr) + { + return LoweredValInfo::simple(context->irBuilder->getBoolValue(expr->value)); + } + + LoweredValInfo visitIntegerLiteralExpr(IntegerLiteralExpr* expr) + { + auto type = lowerType(context, expr->type); + return LoweredValInfo::simple(context->irBuilder->getIntValue(type, expr->value)); + } + + LoweredValInfo visitFloatingPointLiteralExpr(FloatingPointLiteralExpr* expr) + { + auto type = lowerType(context, expr->type); + return LoweredValInfo::simple(context->irBuilder->getFloatValue(type, expr->value)); + } + + LoweredValInfo visitStringLiteralExpr(StringLiteralExpr* expr) + { + return LoweredValInfo::simple(context->irBuilder->getStringValue(expr->value.getUnownedSlice())); + } + + LoweredValInfo visitAggTypeCtorExpr(AggTypeCtorExpr* /*expr*/) + { + SLANG_UNIMPLEMENTED_X("codegen for aggregate type constructor expression"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + + // After a call to a function with `out` or `in out` + // parameters, we may need to copy data back into + // the l-value locations used for output arguments. + // + // During lowering of the argument list, we build + // up a list of these "fixup" assignments that need + // to be performed. + struct OutArgumentFixup + { + LoweredValInfo dst; + LoweredValInfo src; + }; + + void addDirectCallArgs( + InvokeExpr* expr, + DeclRef funcDeclRef, + List* ioArgs, + List* ioFixups) + { + UInt argCount = expr->Arguments.getCount(); + UInt argCounter = 0; + for (auto paramDeclRef : getMembersOfType(funcDeclRef)) + { + auto paramDecl = paramDeclRef.getDecl(); + IRType* paramType = lowerType(context, GetType(paramDeclRef)); + + UInt argIndex = argCounter++; + RefPtr argExpr; + if(argIndex < argCount) + { + argExpr = expr->Arguments[argIndex]; + } + else + { + // We have run out of arguments supplied at the call site, + // but there are still parameters remaining. This must mean + // that these parameters have default argument expressions + // associated with them. + argExpr = getInitExpr(paramDeclRef); + + // Assert that such an expression must have been present. + SLANG_ASSERT(argExpr); + + // TODO: The approach we are taking here to default arguments + // is simplistic, and has consequences for the front-end as + // well as binary serialization of modules. + // + // We could consider some more refined approaches where, e.g., + // functions with default arguments generate multiple IR-level + // functions, that compute and provide the default values. + // + // Alternatively, each parameter with defaults could be generated + // into its own callable function that provides the default value, + // so that calling modules can call into a pre-generated function. + // + // Each of these options involves trade-offs, and we need to + // make a conscious decision at some point. + } + + if(paramDecl->HasModifier()) + { + // A `ref` qualified parameter must be implemented with by-reference + // parameter passing, so the argument value should be lowered as + // an l-value. + // + LoweredValInfo loweredArg = lowerLValueExpr(context, argExpr); + + // According to our "calling convention" we need to + // pass a pointer into the callee. Unlike the case for + // `out` and `inout` below, it is never valid to do + // copy-in/copy-out for a `ref` parameter, so we just + // pass in the actual pointer. + // + IRInst* argPtr = getAddress(context, loweredArg, argExpr->loc); + (*ioArgs).add(argPtr); + } + else if (paramDecl->HasModifier() + || paramDecl->HasModifier()) + { + // This is a `out` or `inout` parameter, and so + // the argument must be lowered as an l-value. + + LoweredValInfo loweredArg = lowerLValueExpr(context, argExpr); + + // According to our "calling convention" we need to + // pass a pointer into the callee. + // + // A naive approach would be to just take the address + // of `loweredArg` above and pass it in, but that + // has two issues: + // + // 1. The l-value might not be something that has a single + // well-defined "address" (e.g., `foo.xzy`). + // + // 2. The l-value argument might actually alias some other + // storage that the callee will access (e.g., we are + // passing in a global variable, or two `out` parameters + // are being passed the same location in an array). + // + // In each of these cases, the safe option is to create + // a temporary variable to use for argument-passing, + // and then do copy-in/copy-out around the call. + + LoweredValInfo tempVar = createVar(context, paramType); + + // If the parameter is `in out` or `inout`, then we need + // to ensure that we pass in the original value stored + // in the argument, which we accomplish by assigning + // from the l-value to our temp. + if (paramDecl->HasModifier() + || paramDecl->HasModifier()) + { + assign(context, tempVar, loweredArg); + } + + // Now we can pass the address of the temporary variable + // to the callee as the actual argument for the `in out` + SLANG_ASSERT(tempVar.flavor == LoweredValInfo::Flavor::Ptr); + (*ioArgs).add(tempVar.val); + + // Finally, after the call we will need + // to copy in the other direction: from our + // temp back to the original l-value. + OutArgumentFixup fixup; + fixup.src = tempVar; + fixup.dst = loweredArg; + + (*ioFixups).add(fixup); + + } + else + { + // This is a pure input parameter, and so we will + // pass it as an r-value. + LoweredValInfo loweredArg = lowerRValueExpr(context, argExpr); + addArgs(context, ioArgs, loweredArg); + } + } + } + + // Add arguments that appeared directly in an argument list + // to the list of argument values for a call. + void addDirectCallArgs( + InvokeExpr* expr, + DeclRef funcDeclRef, + List* ioArgs, + List* ioFixups) + { + if (auto callableDeclRef = funcDeclRef.as()) + { + addDirectCallArgs(expr, callableDeclRef, ioArgs, ioFixups); + } + else + { + SLANG_UNEXPECTED("callee was not a callable decl"); + } + } + + void addFuncBaseArgs( + LoweredValInfo funcVal, + List* ioArgs) + { + switch (funcVal.flavor) + { + default: + return; + } + } + + void applyOutArgumentFixups(List const& fixups) + { + for (auto fixup : fixups) + { + assign(context, fixup.dst, fixup.src); + } + } + + struct ResolvedCallInfo + { + DeclRef funcDeclRef; + Expr* baseExpr = nullptr; + }; + + // Try to resolve a the function expression for a call + // into a reference to a specific declaration, along + // with some contextual information about the declaration + // we are calling. + bool tryResolveDeclRefForCall( + RefPtr funcExpr, + ResolvedCallInfo* outInfo) + { + // TODO: unwrap any "identity" expressions that might + // be wrapping the callee. + + // First look to see if the expression references a + // declaration at all. + auto declRefExpr = as(funcExpr); + if(!declRefExpr) + return false; + + // A little bit of future proofing here: if we ever + // allow higher-order functions, then we might be + // calling through a variable/field that has a function + // type, but is not itself a function. + // In such a case we should be careful to not statically + // resolve things. + // + if(auto callableDecl = as(declRefExpr->declRef.getDecl())) + { + // Okay, the declaration is directly callable, so we can continue. + } + else + { + // The callee declaration isn't itself a callable (it must have + // a function type, though). + return false; + } + + // Now we can look at the specific kinds of declaration references, + // and try to tease them apart. + if (auto memberFuncExpr = as(funcExpr)) + { + outInfo->funcDeclRef = memberFuncExpr->declRef; + outInfo->baseExpr = memberFuncExpr->BaseExpression; + return true; + } + else if (auto staticMemberFuncExpr = as(funcExpr)) + { + outInfo->funcDeclRef = staticMemberFuncExpr->declRef; + return true; + } + else if (auto varExpr = as(funcExpr)) + { + outInfo->funcDeclRef = varExpr->declRef; + return true; + } + else + { + // Seems to be a case of declaration-reference we don't know about. + SLANG_UNEXPECTED("unknown declaration reference kind"); + return false; + } + } + + + LoweredValInfo visitInvokeExpr(InvokeExpr* expr) + { + auto type = lowerType(context, expr->type); + + // We are going to look at the syntactic form of + // the "function" expression, so that we can avoid + // a lot of complexity that would come from lowering + // it as a general expression first, and then trying + // to apply it. For example, given `obj.f(a,b)` we + // will try to detect that we are trying to compute + // something like `ObjType::f(obj, a, b)` (in pseudo-code), + // rather than trying to construct a meaningful + // intermediate value for `obj.f` first. + // + // Note that this doe not preclude having support + // for directly generating code from `obj.f` - it + // just may be that such usage is more complicated. + + // Along the way, we may end up collecting additional + // arguments that will be part of the call. + List irArgs; + + // We will also collect "fixup" actions that need + // to be performed after the call, in order to + // copy the final values for `out` parameters + // back to their arguments. + List argFixups; + + auto funcExpr = expr->FunctionExpr; + ResolvedCallInfo resolvedInfo; + if( tryResolveDeclRefForCall(funcExpr, &resolvedInfo) ) + { + // In this case we know exactly what declaration we + // are going to call, and so we can resolve things + // appropriately. + auto funcDeclRef = resolvedInfo.funcDeclRef; + auto baseExpr = resolvedInfo.baseExpr; + + // First comes the `this` argument if we are calling + // a member function: + if( baseExpr ) + { + auto loweredBaseVal = lowerRValueExpr(context, baseExpr); + addArgs(context, &irArgs, loweredBaseVal); + } + + // Then we have the "direct" arguments to the call. + // These may include `out` and `inout` arguments that + // require "fixup" work on the other side. + // + auto funcType = lowerType(context, funcExpr->type); + addDirectCallArgs(expr, funcDeclRef, &irArgs, &argFixups); + auto result = emitCallToDeclRef( + context, + type, + funcDeclRef, + funcType, + irArgs); + applyOutArgumentFixups(argFixups); + return result; + } + + // TODO: In this case we should be emitting code for the callee as + // an ordinary expression, then emitting the arguments according + // to the type information on the callee (e.g., which parameters + // are `out` or `inout`, and then finally emitting the `call` + // instruction. + // + // We don't currently have the case of emitting arguments according + // to function type info (instead of declaration info), and really + // this case can't occur unless we start adding first-class functions + // to the source language. + // + // For now we just bail out with an error. + // + SLANG_UNEXPECTED("could not resolve target declaration for call"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + + LoweredValInfo visitCastToInterfaceExpr( + CastToInterfaceExpr* expr) + { + // We have an expression that is "up-casting" some concrete value + // to an existential type (aka interface type), using a subtype witness + // (which will lower as a witness table) to show that the conversion + // is valid. + // + // At the IR level, this will become a `makeExistential` instruction, + // which collects the above information into a single IR-level value. + // A dynamic CPU implementation of Slang might encode an existential + // as a "fat pointer" representation, which includes a pointer to + // data for the concrete value, plus a pointer to the witness table. + // + // Note: if/when Slang supports more general existential types, such + // as compositions of interface (e.g., `IReadable & IWritable`), then + // we should probably extend the AST and IR mechanism here to accept + // a sequence of witness tables. + // + auto existentialType = lowerType(context, expr->type); + auto concreteValue = getSimpleVal(context, lowerRValueExpr(context, expr->valueArg)); + auto witnessTable = lowerSimpleVal(context, expr->witnessArg); + auto existentialValue = getBuilder()->emitMakeExistential(existentialType, concreteValue, witnessTable); + return LoweredValInfo::simple(existentialValue); + } + + LoweredValInfo subscriptValue( + IRType* type, + LoweredValInfo baseVal, + IRInst* indexVal) + { + auto builder = getBuilder(); + + // The `tryGetAddress` operation will take a complex value representation + // and try to turn it into a single pointer, if possible. + // + baseVal = tryGetAddress(context, baseVal, TryGetAddressMode::Aggressive); + + // The `materialize` operation should ensure that we only have to deal + // with the small number of base cases for lowered value representations. + // + baseVal = materialize(context, baseVal); + + switch (baseVal.flavor) + { + case LoweredValInfo::Flavor::Simple: + return LoweredValInfo::simple( + builder->emitElementExtract( + type, + getSimpleVal(context, baseVal), + indexVal)); + + case LoweredValInfo::Flavor::Ptr: + return LoweredValInfo::ptr( + builder->emitElementAddress( + context->irBuilder->getPtrType(type), + baseVal.val, + indexVal)); + + default: + SLANG_UNIMPLEMENTED_X("subscript expr"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + + } + + LoweredValInfo extractField( + IRType* fieldType, + LoweredValInfo base, + DeclRef field) + { + return Slang::extractField(context, fieldType, base, field); + } + + LoweredValInfo visitStaticMemberExpr(StaticMemberExpr* expr) + { + return emitDeclRef(context, expr->declRef, + lowerType(context, expr->type)); + } + + LoweredValInfo visitGenericAppExpr(GenericAppExpr* /*expr*/) + { + SLANG_UNIMPLEMENTED_X("generic application expression during code generation"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + + LoweredValInfo visitSharedTypeExpr(SharedTypeExpr* /*expr*/) + { + SLANG_UNIMPLEMENTED_X("shared type expression during code generation"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + + LoweredValInfo visitTaggedUnionTypeExpr(TaggedUnionTypeExpr* /*expr*/) + { + SLANG_UNIMPLEMENTED_X("tagged union type expression during code generation"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + + LoweredValInfo visitAssignExpr(AssignExpr* expr) + { + // Because our representation of lowered "values" + // can encompass l-values explicitly, we can + // lower assignment easily. We just lower the left- + // and right-hand sides, and then perform an assignment + // based on the resulting values. + // + auto leftVal = lowerLValueExpr(context, expr->left); + auto rightVal = lowerRValueExpr(context, expr->right); + assign(context, leftVal, rightVal); + + // The result value of the assignment expression is + // the value of the left-hand side (and it is expected + // to be an l-value). + return leftVal; + } + + LoweredValInfo visitLetExpr(LetExpr* expr) + { + // TODO: deal with the case where we might want to capture + // a reference to the bound value... + + auto initVal = lowerLValueExpr(context, expr->decl->initExpr); + setGlobalValue(context, expr->decl, initVal); + auto bodyVal = lowerSubExpr(expr->body); + return bodyVal; + } + + LoweredValInfo visitExtractExistentialValueExpr(ExtractExistentialValueExpr* expr) + { + auto existentialType = lowerType(context, GetType(expr->declRef)); + auto existentialVal = getSimpleVal(context, emitDeclRef(context, expr->declRef, existentialType)); + + auto openedType = lowerType(context, expr->type); + + return LoweredValInfo::simple(getBuilder()->emitExtractExistentialValue(openedType, existentialVal)); + } +}; + +struct LValueExprLoweringVisitor : ExprLoweringVisitorBase +{ + // When visiting a swizzle expression in an l-value context, + // we need to construct a "sizzled l-value." + LoweredValInfo visitSwizzleExpr(SwizzleExpr* expr) + { + auto irType = lowerType(context, expr->type); + auto loweredBase = lowerRValueExpr(context, expr->base); + + RefPtr swizzledLValue = new SwizzledLValueInfo(); + swizzledLValue->type = irType; + + UInt elementCount = (UInt)expr->elementCount; + swizzledLValue->elementCount = elementCount; + + // As a small optimization, we will detect if the base expression + // has also lowered into a swizzle and only return a single + // swizzle instead of nested swizzles. + // + // E.g., if we have input like `foo[i].zw.y` we should optimize it + // down to just `foo[i].w`. + // + if(loweredBase.flavor == LoweredValInfo::Flavor::SwizzledLValue) + { + auto baseSwizzleInfo = loweredBase.getSwizzledLValueInfo(); + + // Our new swizzle will use the same base expression (e.g., + // `foo[i]` in our example above), but will need to remap + // the swizzle indices it uses. + // + swizzledLValue->base = baseSwizzleInfo->base; + for (UInt ii = 0; ii < elementCount; ++ii) + { + // First we get the swizzle element of the "outer" swizzle, + // as it was written by the user. In our running example of + // `foo[i].zw.y` this is the `y` element reference. + // + UInt originalElementIndex = UInt(expr->elementIndices[ii]); + + // Next we will use that original element index to figure + // out which of the elements of the original swizzle this + // should map to. + // + // In our example, `y` means index 1, and so we fetch + // element 1 from the inner swizzle sequence `zw`, to get `w`. + // + SLANG_ASSERT(originalElementIndex < baseSwizzleInfo->elementCount); + UInt remappedElementIndex = baseSwizzleInfo->elementIndices[originalElementIndex]; + + swizzledLValue->elementIndices[ii] = remappedElementIndex; + } + } + else + { + // In the default case, we can just copy the indices being + // used for the swizzle over directly from the expression, + // and use the base as-is. + // + swizzledLValue->base = loweredBase; + for (UInt ii = 0; ii < elementCount; ++ii) + { + swizzledLValue->elementIndices[ii] = (UInt) expr->elementIndices[ii]; + } + } + + context->shared->extValues.add(swizzledLValue); + return LoweredValInfo::swizzledLValue(swizzledLValue); + } +}; + +struct RValueExprLoweringVisitor : ExprLoweringVisitorBase +{ + // A swizzle in an r-value context can save time by just + // emitting the swizzle instructions directly. + LoweredValInfo visitSwizzleExpr(SwizzleExpr* expr) + { + auto irType = lowerType(context, expr->type); + auto irBase = getSimpleVal(context, lowerRValueExpr(context, expr->base)); + + auto builder = getBuilder(); + + auto irIntType = getIntType(context); + + UInt elementCount = (UInt)expr->elementCount; + IRInst* irElementIndices[4]; + for (UInt ii = 0; ii < elementCount; ++ii) + { + irElementIndices[ii] = builder->getIntValue( + irIntType, + (IRIntegerValue)expr->elementIndices[ii]); + } + + auto irSwizzle = builder->emitSwizzle( + irType, + irBase, + elementCount, + &irElementIndices[0]); + + return LoweredValInfo::simple(irSwizzle); + } +}; + +LoweredValInfo lowerLValueExpr( + IRGenContext* context, + Expr* expr) +{ + IRBuilderSourceLocRAII sourceLocInfo(context->irBuilder, expr->loc); + + LValueExprLoweringVisitor visitor; + visitor.context = context; + return visitor.dispatch(expr); +} + +LoweredValInfo lowerRValueExpr( + IRGenContext* context, + Expr* expr) +{ + IRBuilderSourceLocRAII sourceLocInfo(context->irBuilder, expr->loc); + + RValueExprLoweringVisitor visitor; + visitor.context = context; + return visitor.dispatch(expr); +} + +struct StmtLoweringVisitor : StmtVisitor +{ + IRGenContext* context; + + IRBuilder* getBuilder() { return context->irBuilder; } + + void visitEmptyStmt(EmptyStmt*) + { + // Nothing to do. + } + + void visitUnparsedStmt(UnparsedStmt*) + { + SLANG_UNEXPECTED("UnparsedStmt not supported by IR"); + } + + void visitCaseStmtBase(CaseStmtBase*) + { + SLANG_UNEXPECTED("`case` or `default` not under `switch`"); + } + + void visitCompileTimeForStmt(CompileTimeForStmt* stmt) + { + // The user is asking us to emit code for the loop + // body for each value in the given integer range. + // For now, we will handle this by repeatedly lowering + // the body statement, with the loop variable bound + // to a different integer literal value each time. + // + // TODO: eventually we might handle this as just an + // ordinary loop, with an `[unroll]` attribute on + // it that we would respect. + + auto rangeBeginVal = GetIntVal(stmt->rangeBeginVal); + auto rangeEndVal = GetIntVal(stmt->rangeEndVal); + + if (rangeBeginVal >= rangeEndVal) + return; + + auto varDecl = stmt->varDecl; + auto varType = lowerType(context, varDecl->type); + + IRGenEnv subEnvStorage; + IRGenEnv* subEnv = &subEnvStorage; + subEnv->outer = context->env; + + IRGenContext subContextStorage = *context; + IRGenContext* subContext = &subContextStorage; + subContext->env = subEnv; + + + + for (IntegerLiteralValue ii = rangeBeginVal; ii < rangeEndVal; ++ii) + { + auto constVal = getBuilder()->getIntValue( + varType, + ii); + + subEnv->mapDeclToValue[varDecl] = LoweredValInfo::simple(constVal); + + lowerStmt(subContext, stmt->body); + } + } + + // Create a basic block in the current function, + // so that it can be used for a label. + IRBlock* createBlock() + { + return getBuilder()->createBlock(); + } + + /// Does the given block have a terminator? + bool isBlockTerminated(IRBlock* block) + { + return block->getTerminator() != nullptr; + } + + /// Emit a branch to the target block if the current + /// block being inserted into is not already terminated. + void emitBranchIfNeeded(IRBlock* targetBlock) + { + auto builder = getBuilder(); + auto currentBlock = builder->getBlock(); + + // Don't emit if there is no current block. + if(!currentBlock) + return; + + // Don't emit if the block already has a terminator. + if(isBlockTerminated(currentBlock)) + return; + + // The block is unterminated, so cap it off with + // a terminator that branches to the target. + builder->emitBranch(targetBlock); + } + + /// Insert a block at the current location (ending + /// the previous block with an unconditional jump + /// if needed). + void insertBlock(IRBlock* block) + { + auto builder = getBuilder(); + + auto prevBlock = builder->getBlock(); + auto parentFunc = prevBlock ? prevBlock->getParent() : builder->getFunc(); + + // If the previous block doesn't already have + // a terminator instruction, then be sure to + // emit a branch to the new block. + emitBranchIfNeeded(block); + + // Add the new block to the function we are building, + // and setit as the block we will be inserting into. + parentFunc->addBlock(block); + builder->setInsertInto(block); + } + + // Start a new block at the current location. + // This is just the composition of `createBlock` + // and `insertBlock`. + IRBlock* startBlock() + { + auto block = createBlock(); + insertBlock(block); + return block; + } + + /// Start a new block if there isn't a current + /// block that we can append to. + /// + /// The `stmt` parameter is the statement we + /// are about to emit. + void startBlockIfNeeded(Stmt* stmt) + { + auto builder = getBuilder(); + auto currentBlock = builder->getBlock(); + + // If there is a current block and it hasn't + // been terminated, then we can just use that. + if(currentBlock && !isBlockTerminated(currentBlock)) + { + return; + } + + // We are about to emit code *after* a terminator + // instruction, and there is no label to allow + // branching into this code, so whatever we are + // about to emit is going to be unreachable. + // + // Let's diagnose that here just to help the user. + // + // TODO: We might want to have a more robust check + // for unreachable code based on IR analysis instead, + // at which point we'd probably disable this check. + // + context->getSink()->diagnose(stmt, Diagnostics::unreachableCode); + + startBlock(); + } + + void visitIfStmt(IfStmt* stmt) + { + auto builder = getBuilder(); + startBlockIfNeeded(stmt); + + auto condExpr = stmt->Predicate; + auto thenStmt = stmt->PositiveStatement; + auto elseStmt = stmt->NegativeStatement; + + auto irCond = getSimpleVal(context, + lowerRValueExpr(context, condExpr)); + + if (elseStmt) + { + auto thenBlock = createBlock(); + auto elseBlock = createBlock(); + auto afterBlock = createBlock(); + + builder->emitIfElse(irCond, thenBlock, elseBlock, afterBlock); + + insertBlock(thenBlock); + lowerStmt(context, thenStmt); + emitBranchIfNeeded(afterBlock); + + insertBlock(elseBlock); + lowerStmt(context, elseStmt); + + insertBlock(afterBlock); + } + else + { + auto thenBlock = createBlock(); + auto afterBlock = createBlock(); + + builder->emitIf(irCond, thenBlock, afterBlock); + + insertBlock(thenBlock); + lowerStmt(context, thenStmt); + + insertBlock(afterBlock); + } + } + + void addLoopDecorations( + IRInst* inst, + Stmt* stmt) + { + if( stmt->FindModifier() ) + { + getBuilder()->addLoopControlDecoration(inst, kIRLoopControl_Unroll); + } + // TODO: handle other cases here + } + + void visitForStmt(ForStmt* stmt) + { + auto builder = getBuilder(); + startBlockIfNeeded(stmt); + + // The initializer clause for the statement + // can always safetly be emitted to the current block. + if (auto initStmt = stmt->InitialStatement) + { + lowerStmt(context, initStmt); + } + + // We will create blocks for the various places + // we need to jump to inside the control flow, + // including the blocks that will be referenced + // by `continue` or `break` statements. + auto loopHead = createBlock(); + auto bodyLabel = createBlock(); + auto breakLabel = createBlock(); + auto continueLabel = createBlock(); + + // Register the `break` and `continue` labels so + // that we can find them for nested statements. + context->shared->breakLabels.Add(stmt, breakLabel); + context->shared->continueLabels.Add(stmt, continueLabel); + + // Emit the branch that will start out loop, + // and then insert the block for the head. + + auto loopInst = builder->emitLoop( + loopHead, + breakLabel, + continueLabel); + + addLoopDecorations(loopInst, stmt); + + insertBlock(loopHead); + + // Now that we are within the header block, we + // want to emit the expression for the loop condition: + if (auto condExpr = stmt->PredicateExpression) + { + auto irCondition = getSimpleVal(context, + lowerRValueExpr(context, stmt->PredicateExpression)); + + // Now we want to `break` if the loop condition is false. + builder->emitLoopTest( + irCondition, + bodyLabel, + breakLabel); + } + + // Emit the body of the loop + insertBlock(bodyLabel); + lowerStmt(context, stmt->Statement); + + // Insert the `continue` block + insertBlock(continueLabel); + if (auto incrExpr = stmt->SideEffectExpression) + { + lowerRValueExpr(context, incrExpr); + } + + // At the end of the body we need to jump back to the top. + emitBranchIfNeeded(loopHead); + + // Finally we insert the label that a `break` will jump to + insertBlock(breakLabel); + } + + void visitWhileStmt(WhileStmt* stmt) + { + // Generating IR for `while` statement is similar to a + // `for` statement, but without a lot of the complications. + + auto builder = getBuilder(); + startBlockIfNeeded(stmt); + + // We will create blocks for the various places + // we need to jump to inside the control flow, + // including the blocks that will be referenced + // by `continue` or `break` statements. + auto loopHead = createBlock(); + auto bodyLabel = createBlock(); + auto breakLabel = createBlock(); + + // A `continue` inside a `while` loop always + // jumps to the head of hte loop. + auto continueLabel = loopHead; + + // Register the `break` and `continue` labels so + // that we can find them for nested statements. + context->shared->breakLabels.Add(stmt, breakLabel); + context->shared->continueLabels.Add(stmt, continueLabel); + + // Emit the branch that will start out loop, + // and then insert the block for the head. + + auto loopInst = builder->emitLoop( + loopHead, + breakLabel, + continueLabel); + + addLoopDecorations(loopInst, stmt); + + insertBlock(loopHead); + + // Now that we are within the header block, we + // want to emit the expression for the loop condition: + if (auto condExpr = stmt->Predicate) + { + auto irCondition = getSimpleVal(context, + lowerRValueExpr(context, condExpr)); + + // Now we want to `break` if the loop condition is false. + builder->emitLoopTest( + irCondition, + bodyLabel, + breakLabel); + } + + // Emit the body of the loop + insertBlock(bodyLabel); + lowerStmt(context, stmt->Statement); + + // At the end of the body we need to jump back to the top. + emitBranchIfNeeded(loopHead); + + // Finally we insert the label that a `break` will jump to + insertBlock(breakLabel); + } + + void visitDoWhileStmt(DoWhileStmt* stmt) + { + // Generating IR for `do {...} while` statement is similar to a + // `while` statement, just with the test in a different place + + auto builder = getBuilder(); + startBlockIfNeeded(stmt); + + // We will create blocks for the various places + // we need to jump to inside the control flow, + // including the blocks that will be referenced + // by `continue` or `break` statements. + auto loopHead = createBlock(); + auto testLabel = createBlock(); + auto breakLabel = createBlock(); + + // A `continue` inside a `do { ... } while ( ... )` loop always + // jumps to the loop test. + auto continueLabel = testLabel; + + // Register the `break` and `continue` labels so + // that we can find them for nested statements. + context->shared->breakLabels.Add(stmt, breakLabel); + context->shared->continueLabels.Add(stmt, continueLabel); + + // Emit the branch that will start out loop, + // and then insert the block for the head. + + auto loopInst = builder->emitLoop( + loopHead, + breakLabel, + continueLabel); + + addLoopDecorations(loopInst, stmt); + + insertBlock(loopHead); + + // Emit the body of the loop + lowerStmt(context, stmt->Statement); + + insertBlock(testLabel); + + // Now that we are within the header block, we + // want to emit the expression for the loop condition: + if (auto condExpr = stmt->Predicate) + { + auto irCondition = getSimpleVal(context, + lowerRValueExpr(context, condExpr)); + + // Now we want to `break` if the loop condition is false, + // otherwise we will jump back to the head of the loop. + builder->emitLoopTest( + irCondition, + loopHead, + breakLabel); + } + + // Finally we insert the label that a `break` will jump to + insertBlock(breakLabel); + } + + void visitExpressionStmt(ExpressionStmt* stmt) + { + startBlockIfNeeded(stmt); + + // The statement evaluates an expression + // (for side effects, one assumes) and then + // discards the result. As such, we simply + // lower the expression, and don't use + // the result. + // + // Note that we lower using the l-value path, + // so that an expression statement that names + // a location (but doesn't load from it) + // will not actually emit a load. + lowerLValueExpr(context, stmt->Expression); + } + + void visitDeclStmt(DeclStmt* stmt) + { + startBlockIfNeeded(stmt); + + // For now, we lower a declaration directly + // into the current context. + // + // TODO: We may want to consider whether + // nested type/function declarations should + // be lowered into the global scope during + // IR generation, or whether they should + // be lifted later (pushing capture analysis + // down to the IR). + // + lowerDecl(context, stmt->decl); + } + + void visitSeqStmt(SeqStmt* stmt) + { + // To lower a sequence of statements, + // just lower each in order + for (auto ss : stmt->stmts) + { + lowerStmt(context, ss); + } + } + + void visitBlockStmt(BlockStmt* stmt) + { + // To lower a block (scope) statement, + // just lower its body. The IR doesn't + // need to reflect the scoping of the AST. + lowerStmt(context, stmt->body); + } + + void visitReturnStmt(ReturnStmt* stmt) + { + startBlockIfNeeded(stmt); + + // A `return` statement turns into a return + // instruction. If the statement had an argument + // expression, then we need to lower that to + // a value first, and then emit the resulting value. + if( auto expr = stmt->Expression ) + { + auto loweredExpr = lowerRValueExpr(context, expr); + + getBuilder()->emitReturn(getSimpleVal(context, loweredExpr)); + } + else + { + getBuilder()->emitReturn(); + } + } + + void visitDiscardStmt(DiscardStmt* stmt) + { + startBlockIfNeeded(stmt); + getBuilder()->emitDiscard(); + } + + void visitBreakStmt(BreakStmt* stmt) + { + startBlockIfNeeded(stmt); + + // Semantic checking is responsible for finding + // the statement taht this `break` breaks out of + auto parentStmt = stmt->parentStmt; + SLANG_ASSERT(parentStmt); + + // We just need to look up the basic block that + // corresponds to the break label for that statement, + // and then emit an instruction to jump to it. + IRBlock* targetBlock = nullptr; + context->shared->breakLabels.TryGetValue(parentStmt, targetBlock); + SLANG_ASSERT(targetBlock); + getBuilder()->emitBreak(targetBlock); + } + + void visitContinueStmt(ContinueStmt* stmt) + { + startBlockIfNeeded(stmt); + + // Semantic checking is responsible for finding + // the loop that this `continue` statement continues + auto parentStmt = stmt->parentStmt; + SLANG_ASSERT(parentStmt); + + + // We just need to look up the basic block that + // corresponds to the continue label for that statement, + // and then emit an instruction to jump to it. + IRBlock* targetBlock = nullptr; + context->shared->continueLabels.TryGetValue(parentStmt, targetBlock); + SLANG_ASSERT(targetBlock); + getBuilder()->emitContinue(targetBlock); + } + + // Lowering a `switch` statement can get pretty involved, + // so we need to track a bit of extra data: + struct SwitchStmtInfo + { + // The block that will be made to contain the `switch` statement + IRBlock* initialBlock = nullptr; + + // The label for the `default` case, if any. + IRBlock* defaultLabel = nullptr; + + // The label of the current "active" case block. + IRBlock* currentCaseLabel = nullptr; + + // Has anything been emitted to the current "active" case block? + bool anythingEmittedToCurrentCaseBlock = false; + + // The collected (value, label) pairs for + // all the `case` statements. + List cases; + }; + + // We need a label to use for a `case` or `default` statement, + // so either create one here, or re-use the current one if + // that is okay. + IRBlock* getLabelForCase(SwitchStmtInfo* info) + { + // Look at the "current" label we are working with. + auto currentCaseLabel = info->currentCaseLabel; + + // If there is a current block, and it is empty, + // then it is still a viable target (we are in + // a case of "trivial fall-through" from the previous + // block). + if(currentCaseLabel && !info->anythingEmittedToCurrentCaseBlock) + { + return currentCaseLabel; + } + + // Othwerise, we need to start a new block and use that. + IRBlock* newCaseLabel = createBlock(); + + // Note: if the previous block failed + // to end with a `break`, then inserting + // this block will append an unconditional + // branch to the end of it that will target + // this block. + insertBlock(newCaseLabel); + + info->currentCaseLabel = newCaseLabel; + info->anythingEmittedToCurrentCaseBlock = false; + return newCaseLabel; + } + + // Given a statement that appears as (or in) the body + // of a `switch` statement + void lowerSwitchCases(Stmt* inStmt, SwitchStmtInfo* info) + { + // TODO: in the general case (e.g., if we were going + // to eventual lower to an unstructured format like LLVM), + // the Right Way to handle C-style `switch` statements + // is just to emit the body directly as "normal" statements, + // and then treat `case` and `default` as special statements + // that start a new block and register a label with the + // enclosing `switch`. + // + // For now we will assume that any `case` and `default` + // statements need to be directly nested under the `switch`, + // and so we can find them with a simpler walk. + + Stmt* stmt = inStmt; + + // Unwrap any surrounding `{ ... }` so we can look + // at the statement inside. + while(auto blockStmt = as(stmt)) + { + stmt = blockStmt->body; + continue; + } + + if(auto seqStmt = as(stmt)) + { + // Walk through teh children and process each. + for(auto childStmt : seqStmt->stmts) + { + lowerSwitchCases(childStmt, info); + } + } + else if(auto caseStmt = as(stmt)) + { + // A full `case` statement has a value we need + // to test against. It is expected to be a + // compile-time constant, so we will emit + // it like an expression here, and then hope + // for the best. + // + // TODO: figure out something cleaner. + + // Actually, one gotcha is that if we ever allow non-constant + // expressions here (or anything that requires instructions + // to be emitted to yield its value), then those instructions + // need to go into an appropriate block. + + IRGenContext subContext = *context; + IRBuilder subBuilder = *getBuilder(); + subBuilder.setInsertInto(info->initialBlock); + subContext.irBuilder = &subBuilder; + auto caseVal = getSimpleVal(context, lowerRValueExpr(&subContext, caseStmt->expr)); + + // Figure out where we are branching to. + auto label = getLabelForCase(info); + + // Add this `case` to the list for the enclosing `switch`. + info->cases.add(caseVal); + info->cases.add(label); + } + else if(auto defaultStmt = as(stmt)) + { + auto label = getLabelForCase(info); + + // We expect to only find a single `default` stmt. + SLANG_ASSERT(!info->defaultLabel); + + info->defaultLabel = label; + } + else if(auto emptyStmt = as(stmt)) + { + // Special-case empty statements so they don't + // mess up our "trivial fall-through" optimization. + } + else + { + // We have an ordinary statement, that needs to get + // emitted to the current case block. + if(!info->currentCaseLabel) + { + // It possible in full C/C++ to have statements + // before the first `case`. Usually these are + // unreachable, unless they start with a label. + // + // We'll ignore them here, figuring they are + // dead. If we ever add `LabelStmt` then we'd + // need to emit these statements to a dummy + // block just in case. + } + else + { + // Emit the code to our current case block, + // and record that we've done so. + lowerStmt(context, stmt); + info->anythingEmittedToCurrentCaseBlock = true; + } + } + } + + void visitSwitchStmt(SwitchStmt* stmt) + { + auto builder = getBuilder(); + startBlockIfNeeded(stmt); + + // Given a statement: + // + // switch( CONDITION ) + // { + // case V0: + // S0; + // break; + // + // case V1: + // default: + // S1; + // break; + // } + // + // we want to generate IR like: + // + // let %c = ; + // switch %c, // value to switch on + // %breakLabel, // join point (and break target) + // %s1, // default label + // %v0, // first case value + // %s0, // first case label + // %v1, // second case value + // %s1 // second case label + // s0: + // + // break %breakLabel + // s1: + // + // break %breakLabel + // breakLabel: + // + + // First emit code to compute the condition: + auto conditionVal = getSimpleVal(context, lowerRValueExpr(context, stmt->condition)); + + // Remember the initial block so that we can add to it + // after we've collected all the `case`s + auto initialBlock = builder->getBlock(); + + // Next, create a block to use as the target for any `break` statements + auto breakLabel = createBlock(); + + // Register the `break` label so + // that we can find it for nested statements. + context->shared->breakLabels.Add(stmt, breakLabel); + + builder->setInsertInto(initialBlock->getParent()); + + // Iterate over the body of the statement, looking + // for `case` or `default` statements: + SwitchStmtInfo info; + info.initialBlock = initialBlock; + info.defaultLabel = nullptr; + lowerSwitchCases(stmt->body, &info); + + // TODO: once we've discovered the cases, we should + // be able to make a quick pass over the list and eliminate + // any cases that have the exact same label as the `default` + // case, since these don't actually need to be represented. + + // If the current block (the end of the last + // `case`) is not terminated, then terminate with a + // `break` operation. + // + // Double check that we aren't in the initial + // block, so we don't get tripped up on an + // empty `switch`. + auto curBlock = builder->getBlock(); + if(curBlock != initialBlock) + { + // Is the block already terminated? + if(!curBlock->getTerminator()) + { + // Not terminated, so add one. + builder->emitBreak(breakLabel); + } + } + + // If there was no `default` statement, then the + // default case will just branch directly to the end. + auto defaultLabel = info.defaultLabel ? info.defaultLabel : breakLabel; + + // Now that we've collected the cases, we are + // prepared to emit the `switch` instruction + // itself. + builder->setInsertInto(initialBlock); + builder->emitSwitch( + conditionVal, + breakLabel, + defaultLabel, + info.cases.getCount(), + info.cases.getBuffer()); + + // Finally we insert the label that a `break` will jump to + // (and that control flow will fall through to otherwise). + // This is the block that subsequent code will go into. + insertBlock(breakLabel); + context->shared->breakLabels.Remove(stmt); + } +}; + +void lowerStmt( + IRGenContext* context, + Stmt* stmt) +{ + IRBuilderSourceLocRAII sourceLocInfo(context->irBuilder, stmt->loc); + + StmtLoweringVisitor visitor; + visitor.context = context; + + try + { + visitor.dispatch(stmt); + } + // Don't emit any context message for an explicit `AbortCompilationException` + // because it should only happen when an error is already emitted. + catch(AbortCompilationException&) { throw; } + catch(...) + { + context->getSink()->noteInternalErrorLoc(stmt->loc); + throw; + } +} + +/// Create and return a mutable temporary initialized with `val` +static LoweredValInfo moveIntoMutableTemp( + IRGenContext* context, + LoweredValInfo const& val) +{ + IRInst* irVal = getSimpleVal(context, val); + auto type = irVal->getDataType(); + auto var = createVar(context, type); + + assign(context, var, LoweredValInfo::simple(irVal)); + return var; +} + +LoweredValInfo tryGetAddress( + IRGenContext* context, + LoweredValInfo const& inVal, + TryGetAddressMode mode) +{ + LoweredValInfo val = inVal; + + switch(val.flavor) + { + case LoweredValInfo::Flavor::Ptr: + // The `Ptr` case means that we already have an IR value with + // the address of our value. Easy! + return val; + + case LoweredValInfo::Flavor::BoundSubscript: + { + // If we are are trying to turn a subscript operation like `buffer[index]` + // into a pointer, then we need to find a `ref` accessor declared + // as part of the subscript operation being referenced. + // + auto subscriptInfo = val.getBoundSubscriptInfo(); + + // We don't want to immediately bind to a `ref` accessor if there is + // a `set` accessor available, unless we are in an "aggressive" mode + // where we really want/need a pointer to be able to make progress. + // + if(mode != TryGetAddressMode::Aggressive + && getMembersOfType(subscriptInfo->declRef).Count()) + { + // There is a setter that we should consider using, + // so don't go and aggressively collapse things just yet. + return val; + } + + auto refAccessors = getMembersOfType(subscriptInfo->declRef); + if(refAccessors.Count()) + { + // The `ref` accessor will return a pointer to the value, so + // we need to reflect that in the type of our `call` instruction. + IRType* ptrType = context->irBuilder->getPtrType(subscriptInfo->type); + + LoweredValInfo refVal = emitCallToDeclRef( + context, + ptrType, + *refAccessors.begin(), + nullptr, + subscriptInfo->args); + + // The result from the call should be a pointer, and it + // is the address that we wanted in the first place. + return LoweredValInfo::ptr(getSimpleVal(context, refVal)); + } + + // Otherwise, there was no `ref` accessor, and so it is not possible + // to materialize this location into a pointer for whatever purpose + // we have in mind (e.g., passing it to an atomic operation). + } + break; + + case LoweredValInfo::Flavor::BoundMember: + { + auto boundMemberInfo = val.getBoundMemberInfo(); + + // If we hit this case, then it means that we have a reference + // to a single field in something, but for whatever reason the + // higher-level logic was not able to turn it into a pointer + // already (maybe the base value for the field reference is + // a `BoundSubscript`, etc.). + // + // We need to read the entire base value out, modify the field + // we care about, and then write it back. + + auto declRef = boundMemberInfo->declRef; + if( auto fieldDeclRef = declRef.as() ) + { + auto baseVal = boundMemberInfo->base; + auto basePtr = tryGetAddress(context, baseVal, TryGetAddressMode::Aggressive); + + return extractField(context, boundMemberInfo->type, basePtr, fieldDeclRef); + } + + } + break; + + case LoweredValInfo::Flavor::SwizzledLValue: + { + auto originalSwizzleInfo = val.getSwizzledLValueInfo(); + auto originalBase = originalSwizzleInfo->base; + + UInt elementCount = originalSwizzleInfo->elementCount; + + auto newBase = tryGetAddress(context, originalBase, TryGetAddressMode::Aggressive); + RefPtr newSwizzleInfo = new SwizzledLValueInfo(); + context->shared->extValues.add(newSwizzleInfo); + + newSwizzleInfo->base = newBase; + newSwizzleInfo->type = originalSwizzleInfo->type; + newSwizzleInfo->elementCount = elementCount; + for(UInt ee = 0; ee < elementCount; ++ee) + newSwizzleInfo->elementIndices[ee] = originalSwizzleInfo->elementIndices[ee]; + + return LoweredValInfo::swizzledLValue(newSwizzleInfo); + } + break; + + // TODO: are there other cases we need to handled here? + + default: + break; + } + + // If none of the special cases above applied, then we werent' able to make + // this value into a pointer, and we should just return it as-is. + return val; +} + +IRInst* getAddress( + IRGenContext* context, + LoweredValInfo const& inVal, + SourceLoc diagnosticLocation) +{ + LoweredValInfo val = tryGetAddress(context, inVal, TryGetAddressMode::Aggressive); + + if( val.flavor == LoweredValInfo::Flavor::Ptr ) + { + return val.val; + } + + context->getSink()->diagnose(diagnosticLocation, Diagnostics::invalidLValueForRefParameter); + return nullptr; +} + +void assign( + IRGenContext* context, + LoweredValInfo const& inLeft, + LoweredValInfo const& inRight) +{ + LoweredValInfo left = inLeft; + LoweredValInfo right = inRight; + + // Before doing the case analysis on the shape of the `left` value, + // we might as well go ahead and see if we can coerce it into + // a simple pointer, since that would make our life a lot easier + // when handling complex cases. + // + left = tryGetAddress(context, left, TryGetAddressMode::Default); + + auto builder = context->irBuilder; + +top: + switch (left.flavor) + { + case LoweredValInfo::Flavor::Ptr: + { + // The `left` value is just a pointer, so we can emit + // a store to it directly. + // + builder->emitStore( + left.val, + getSimpleVal(context, right)); + } + break; + + case LoweredValInfo::Flavor::SwizzledLValue: + { + // The `left` value is of the form `.`. + // How we will handle this depends on what `base` looks like: + auto swizzleInfo = left.getSwizzledLValueInfo(); + auto loweredBase = swizzleInfo->base; + + // Note that the call to `tryGetAddress` at the start should + // ensure that `loweredBase` has been simplified as much as + // possible (e.g., if it is possible to turn it into a + // `LoweredValInfo::ptr()` then that will have been done). + + switch( loweredBase.flavor ) + { + default: + { + // Our fallback position is to lower via a temporary, e.g.: + // + // float4 tmp = ; + // tmp.xyz = float3(...); + // = tmp; + // + + // Load from the base value + IRInst* irLeftVal = getSimpleVal(context, loweredBase); + + // Extract a simple value for the right-hand side + IRInst* irRightVal = getSimpleVal(context, right); + + // Apply the swizzle + IRInst* irSwizzled = builder->emitSwizzleSet( + irLeftVal->getDataType(), + irLeftVal, + irRightVal, + swizzleInfo->elementCount, + swizzleInfo->elementIndices); + + // And finally, store the value back where we got it. + // + // Note: this is effectively a recursive call to + // `assign()`, so we do a simple tail-recursive call here. + left = loweredBase; + right = LoweredValInfo::simple(irSwizzled); + goto top; + } + break; + + case LoweredValInfo::Flavor::Ptr: + { + // We are writing through a pointer, which might be + // pointing into a UAV or other memory resource, so + // we can't introduce use a temporary like the case + // above, because then we would read and write bytes + // that are not strictly required for the store. + // + // Note that the messy case of a "swizzle of a swizzle" + // was handled already in lowering of a `SwizzleExpr`, + // so that we don't need to deal with that case here. + // + // TODO: we may need to consider whether there is + // enough value in a masked store like this to keep + // it around, in comparison to a simpler model where + // we simply form a pointer to each of the vector + // elements and write to them individually. + // + // TODO: we might also consider just special-casing + // single-element swizzles so that the common case + // can turn into a simple `store` instead of a + // `swizzledStore`. + // + IRInst* irRightVal = getSimpleVal(context, right); + builder->emitSwizzledStore( + loweredBase.val, + irRightVal, + swizzleInfo->elementCount, + swizzleInfo->elementIndices); + } + break; + } + } + break; + + case LoweredValInfo::Flavor::BoundSubscript: + { + // The `left` value refers to a subscript operation on + // a resource type, bound to particular arguments, e.g.: + // `someStructuredBuffer[index]`. + // + // When storing to such a value, we need to emit a call + // to the appropriate builtin "setter" accessor, if there + // is one, and then fall back to a `ref` accessor if + // there is no setter. + // + auto subscriptInfo = left.getBoundSubscriptInfo(); + + // Search for an appropriate "setter" declaration + auto setters = getMembersOfType(subscriptInfo->declRef); + if (setters.Count()) + { + auto allArgs = subscriptInfo->args; + addArgs(context, &allArgs, right); + + emitCallToDeclRef( + context, + builder->getVoidType(), + *setters.begin(), + nullptr, + allArgs); + return; + } + + auto refAccessors = getMembersOfType(subscriptInfo->declRef); + if(refAccessors.Count()) + { + // The `ref` accessor will return a pointer to the value, so + // we need to reflect that in the type of our `call` instruction. + IRType* ptrType = context->irBuilder->getPtrType(subscriptInfo->type); + + LoweredValInfo refVal = emitCallToDeclRef( + context, + ptrType, + *refAccessors.begin(), + nullptr, + subscriptInfo->args); + + // The result from the call needs to be implicitly dereferenced, + // so that it can work as an l-value of the desired result type. + left = LoweredValInfo::ptr(getSimpleVal(context, refVal)); + + // Tail-recursively attempt assignment again on the new l-value. + goto top; + } + + // No setter found? Then we have an error! + SLANG_UNEXPECTED("no setter found"); + break; + } + break; + + case LoweredValInfo::Flavor::BoundMember: + { + auto boundMemberInfo = left.getBoundMemberInfo(); + + // If we hit this case, then it means that we are trying to set + // a single field in someting that is not atomically set-able. + // (e.g., an element of a value where the `subscript` operation + // has `get` and `set` but not a `ref` accessor). + // + // We need to read the entire base value out, modify the field + // we care about, and then write it back. + + auto declRef = boundMemberInfo->declRef; + if( auto fieldDeclRef = declRef.as() ) + { + // materialize the base value and move it into + // a mutable temporary if needed + auto baseVal = boundMemberInfo->base; + auto tempVal = moveIntoMutableTemp(context, baseVal); + + // extract the field l-value out of the temporary + auto tempFieldVal = extractField(context, boundMemberInfo->type, tempVal, fieldDeclRef); + + // assign to the field of the temporary l-value + assign(context, tempFieldVal, right); + + // write back the modified temporary to the base l-value + assign(context, baseVal, tempVal); + + return; + } + else + { + SLANG_UNEXPECTED("handled member flavor"); + } + + } + break; + + default: + SLANG_UNIMPLEMENTED_X("assignment"); + break; + } +} + +struct DeclLoweringVisitor : DeclVisitor +{ + IRGenContext* context; + + IRBuilder* getBuilder() + { + return context->irBuilder; + } + + LoweredValInfo visitDeclBase(DeclBase* /*decl*/) + { + SLANG_UNIMPLEMENTED_X("decl catch-all"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + + LoweredValInfo visitDecl(Decl* /*decl*/) + { + SLANG_UNIMPLEMENTED_X("decl catch-all"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + + LoweredValInfo visitExtensionDecl(ExtensionDecl* decl) + { + for (auto & member : decl->Members) + ensureDecl(context, member); + return LoweredValInfo(); + } + + LoweredValInfo visitImportDecl(ImportDecl* /*decl*/) + { + return LoweredValInfo(); + } + + LoweredValInfo visitEmptyDecl(EmptyDecl* /*decl*/) + { + return LoweredValInfo(); + } + + LoweredValInfo visitSyntaxDecl(SyntaxDecl* /*decl*/) + { + return LoweredValInfo(); + } + + LoweredValInfo visitAttributeDecl(AttributeDecl* /*decl*/) + { + return LoweredValInfo(); + } + + LoweredValInfo visitTypeDefDecl(TypeDefDecl* decl) + { + // A type alias declaration may be generic, if it is + // nested under a generic type/function/etc. + // + NestedContext nested(this); + auto subBuilder = nested.getBuilder(); + auto subContext = nested.getContet(); + IRGeneric* outerGeneric = emitOuterGenerics(subContext, decl, decl); + + // TODO: if a type alias declaration can have linkage, + // we will need to lower it to some kind of global + // value in the IR so that we can attach a name to it. + // + // For now, we can only attach a name *if* the type + // alias is somehow generic. + if(outerGeneric) + { + addLinkageDecoration(context, outerGeneric, decl); + } + + auto type = lowerType(subContext, decl->type.type); + + return LoweredValInfo::simple(finishOuterGenerics(subBuilder, type)); + } + + LoweredValInfo visitGenericTypeParamDecl(GenericTypeParamDecl* /*decl*/) + { + return LoweredValInfo(); + } + + LoweredValInfo visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl) + { + // This might be a type constraint on an associated type, + // in which case it should lower as the key for that + // interface requirement. + if(auto assocTypeDecl = as(decl->ParentDecl)) + { + // TODO: might need extra steps if we ever allow + // generic associated types. + + + if(auto interfaceDecl = as(assocTypeDecl->ParentDecl)) + { + // Okay, this seems to be an interface rquirement, and + // we should lower it as such. + return LoweredValInfo::simple(getInterfaceRequirementKey(decl)); + } + } + + if(auto globalGenericParamDecl = as(decl->ParentDecl)) + { + // This is a constraint on a global generic type parameters, + // and so it should lower as a parameter of its own. + + auto inst = getBuilder()->emitGlobalGenericParam(); + addLinkageDecoration(context, inst, decl); + return LoweredValInfo::simple(inst); + } + + // Otherwise we really don't expect to see a type constraint + // declaration like this during lowering, because a generic + // should have set up a parameter for any constraints as + // part of being lowered. + + SLANG_UNEXPECTED("generic type constraint during lowering"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + + LoweredValInfo visitGlobalGenericParamDecl(GlobalGenericParamDecl* decl) + { + auto inst = getBuilder()->emitGlobalGenericParam(); + addLinkageDecoration(context, inst, decl); + return LoweredValInfo::simple(inst); + } + + void lowerWitnessTable( + IRGenContext* subContext, + WitnessTable* astWitnessTable, + IRWitnessTable* irWitnessTable, + Dictionary mapASTToIRWitnessTable) + { + auto subBuilder = subContext->irBuilder; + + for(auto entry : astWitnessTable->requirementDictionary) + { + auto requiredMemberDecl = entry.Key; + auto satisfyingWitness = entry.Value; + + auto irRequirementKey = getInterfaceRequirementKey(requiredMemberDecl); + IRInst* irSatisfyingVal = nullptr; + + switch(satisfyingWitness.getFlavor()) + { + case RequirementWitness::Flavor::declRef: + { + auto satisfyingDeclRef = satisfyingWitness.getDeclRef(); + irSatisfyingVal = getSimpleVal(subContext, + emitDeclRef(subContext, satisfyingDeclRef, + // TODO: we need to know what type to plug in here... + nullptr)); + } + break; + + case RequirementWitness::Flavor::val: + { + auto satisfyingVal = satisfyingWitness.getVal(); + irSatisfyingVal = lowerSimpleVal(subContext, satisfyingVal); + } + break; + + case RequirementWitness::Flavor::witnessTable: + { + auto astReqWitnessTable = satisfyingWitness.getWitnessTable(); + IRWitnessTable* irSatisfyingWitnessTable = nullptr; + if(!mapASTToIRWitnessTable.TryGetValue(astReqWitnessTable, irSatisfyingWitnessTable)) + { + // Need to construct a sub-witness-table + irSatisfyingWitnessTable = subBuilder->createWitnessTable(); + + // Recursively lower the sub-table. + lowerWitnessTable( + subContext, + astReqWitnessTable, + irSatisfyingWitnessTable, + mapASTToIRWitnessTable); + + irSatisfyingWitnessTable->moveToEnd(); + } + irSatisfyingVal = irSatisfyingWitnessTable; + } + break; + + default: + SLANG_UNEXPECTED("handled requirement witness case"); + break; + } + + + subBuilder->createWitnessTableEntry( + irWitnessTable, + irRequirementKey, + irSatisfyingVal); + } + } + + LoweredValInfo visitInheritanceDecl(InheritanceDecl* inheritanceDecl) + { + // An inheritance clause inside of an `interface` + // declaration should not give rise to a witness + // table, because it represents something the + // interface requires, and not what it provides. + // + auto parentDecl = inheritanceDecl->ParentDecl; + if (auto parentInterfaceDecl = as(parentDecl)) + { + return LoweredValInfo::simple(getInterfaceRequirementKey(inheritanceDecl)); + } + // + // We also need to cover the case where an `extension` + // declaration is being used to add a conformance to + // an existing `interface`: + // + if(auto parentExtensionDecl = as(parentDecl)) + { + auto targetType = parentExtensionDecl->targetType; + if(auto targetDeclRefType = as(targetType)) + { + if(auto targetInterfaceDeclRef = targetDeclRefType->declRef.as()) + { + return LoweredValInfo::simple(getInterfaceRequirementKey(inheritanceDecl)); + } + } + } + + // Find the type that is doing the inheriting. + // Under normal circumstances it is the type declaration that + // is the parent for the inheritance declaration, but if + // the inheritance declaration is on an `extension` declaration, + // then we need to identify the type being extended. + // + RefPtr subType; + if (auto extParentDecl = as(parentDecl)) + { + subType = extParentDecl->targetType.type; + } + else + { + subType = DeclRefType::Create( + context->getSession(), + makeDeclRef(parentDecl)); + } + + // What is the super-type that we have declared we inherit from? + RefPtr superType = inheritanceDecl->base.type; + + // Construct the mangled name for the witness table, which depends + // on the type that is conforming, and the type that it conforms to. + // + // TODO: This approach doesn't really make sense for generic `extension` conformances. + auto mangledName = getMangledNameForConformanceWitness(subType, superType); + + // A witness table may need to be generic, if the outer + // declaration (either a type declaration or an `extension`) + // is generic. + // + NestedContext nested(this); + auto subBuilder = nested.getBuilder(); + auto subContext = nested.getContet(); + emitOuterGenerics(subContext, inheritanceDecl, inheritanceDecl); + + // Lower the super-type to force its declaration to be lowered. + // + // Note: we are using the "sub-context" here because the + // type being inherited from could reference generic parameters, + // and we need those parameters to lower as references to + // the parameters of our IR-level generic. + // + lowerType(subContext, superType); + + // Create the IR-level witness table + auto irWitnessTable = subBuilder->createWitnessTable(); + addLinkageDecoration(context, irWitnessTable, inheritanceDecl, mangledName.getUnownedSlice()); + + // Register the value now, rather than later, to avoid any possible infinite recursion. + setGlobalValue(context, inheritanceDecl, LoweredValInfo::simple(irWitnessTable)); + + // Make sure that all the entries in the witness table have been filled in, + // including any cases where there are sub-witness-tables for conformances + Dictionary mapASTToIRWitnessTable; + lowerWitnessTable( + subContext, + inheritanceDecl->witnessTable, + irWitnessTable, + mapASTToIRWitnessTable); + + irWitnessTable->moveToEnd(); + + return LoweredValInfo::simple(finishOuterGenerics(subBuilder, irWitnessTable)); + } + + LoweredValInfo visitDeclGroup(DeclGroup* declGroup) + { + // To lower a group of declarations, we just + // lower each one individually. + // + for (auto decl : declGroup->decls) + { + IRBuilderSourceLocRAII sourceLocInfo(context->irBuilder, decl->loc); + + // Note: I am directly invoking `dispatch` here, + // instead of `ensureDecl` just to try and + // make sure that we don't accidentally + // emit things to an outer context. + // + // TODO: make sure that can't happen anyway. + dispatch(decl); + } + + return LoweredValInfo(); + } + + LoweredValInfo visitSubscriptDecl(SubscriptDecl* decl) + { + // A subscript operation may encompass one or more + // accessors, and these are what should actually + // get lowered (they are effectively functions). + + for (auto accessor : decl->getMembersOfType()) + { + if (accessor->HasModifier()) + continue; + + ensureDecl(context, accessor); + } + + // The subscript declaration itself won't correspond + // to anything in the lowered program, so we don't + // bother creating a representation here. + // + // Note: We may want to have a specific lowered value + // that can represent the combination of callables + // that make up the subscript operation. + return LoweredValInfo(); + } + + bool isGlobalVarDecl(VarDecl* decl) + { + auto parent = decl->ParentDecl; + if (as(parent)) + { + // Variable declared at global scope? -> Global. + return true; + } + else if(as(parent)) + { + if(decl->HasModifier()) + { + // A `static` member variable is effectively global. + return true; + } + } + + return false; + } + + bool isMemberVarDecl(VarDecl* decl) + { + auto parent = decl->ParentDecl; + if (as(parent)) + { + // A variable declared inside of an aggregate type declaration is a member. + return true; + } + + return false; + } + + LoweredValInfo lowerGlobalShaderParam(VarDecl* decl) + { + IRType* paramType = lowerType(context, decl->getType()); + + auto builder = getBuilder(); + + auto irParam = builder->createGlobalParam(paramType); + auto paramVal = LoweredValInfo::simple(irParam); + + addLinkageDecoration(context, irParam, decl); + addNameHint(context, irParam, decl); + maybeSetRate(context, irParam, decl); + addVarDecorations(context, irParam, decl); + + if (decl) + { + builder->addHighLevelDeclDecoration(irParam, decl); + } + + // A global variable's SSA value is a *pointer* to + // the underlying storage. + setGlobalValue(context, decl, paramVal); + + irParam->moveToEnd(); + + return paramVal; + } + + LoweredValInfo lowerGlobalVarDecl(VarDecl* decl) + { + if(isGlobalShaderParameter(decl)) + { + return lowerGlobalShaderParam(decl); + } + + IRType* varType = lowerType(context, decl->getType()); + + auto builder = getBuilder(); + + IRGlobalValueWithCode* irGlobal = nullptr; + LoweredValInfo globalVal; + + // a `static const` global is actually a compile-time constant + if (decl->HasModifier() && decl->HasModifier()) + { + irGlobal = builder->createGlobalConstant(varType); + globalVal = LoweredValInfo::simple(irGlobal); + } + else + { + irGlobal = builder->createGlobalVar(varType); + globalVal = LoweredValInfo::ptr(irGlobal); + } + addLinkageDecoration(context, irGlobal, decl); + addNameHint(context, irGlobal, decl); + + maybeSetRate(context, irGlobal, decl); + + addVarDecorations(context, irGlobal, decl); + + if (decl) + { + builder->addHighLevelDeclDecoration(irGlobal, decl); + } + + // A global variable's SSA value is a *pointer* to + // the underlying storage. + setGlobalValue(context, decl, globalVal); + + if (isImportedDecl(decl)) + { + // Always emit imported declarations as declarations, + // and not definitions. + } + else if( auto initExpr = decl->initExpr ) + { + IRBuilder subBuilderStorage = *getBuilder(); + IRBuilder* subBuilder = &subBuilderStorage; + + subBuilder->setInsertInto(irGlobal); + + IRGenContext subContextStorage = *context; + IRGenContext* subContext = &subContextStorage; + + subContext->irBuilder = subBuilder; + + // TODO: set up a parent IR decl to put the instructions into + + IRBlock* entryBlock = subBuilder->emitBlock(); + subBuilder->setInsertInto(entryBlock); + + LoweredValInfo initVal = lowerLValueExpr(subContext, initExpr); + subContext->irBuilder->emitReturn(getSimpleVal(subContext, initVal)); + } + + irGlobal->moveToEnd(); + + return globalVal; + } + + bool isFunctionStaticVarDecl(VarDeclBase* decl) + { + // Only a variable marked `static` can be static. + if(!decl->FindModifier()) + return false; + + // The immediate parent of a function-scope variable + // declaration will be a `ScopeDecl`. + // + // TODO: right now the parent links for scopes are *not* + // set correctly, so we can't just scan up and look + // for a function in the parent chain... + auto parent = decl->ParentDecl; + if( as(parent) ) + { + return true; + } + + return false; + } + + IRInst* defaultSpecializeOuterGeneric( + IRInst* outerVal, + IRType* type, + GenericDecl* genericDecl) + { + auto builder = getBuilder(); + + // We need to specialize any generics that are further out... + auto specialiedOuterVal = defaultSpecializeOuterGenerics( + outerVal, + builder->getGenericKind(), + genericDecl); + + List genericArgs; + + // Walk the parameters of the generic, and emit an argument for each, + // which will be a reference to binding for that parameter in the + // current scope. + // + // First we start with type and value parameters, + // in the order they were declared. + for (auto member : genericDecl->Members) + { + if (auto typeParamDecl = as(member)) + { + genericArgs.add(getSimpleVal(context, ensureDecl(context, typeParamDecl))); + } + else if (auto valDecl = as(member)) + { + genericArgs.add(getSimpleVal(context, ensureDecl(context, valDecl))); + } + } + // Then we emit constraint parameters, again in + // declaration order. + for (auto member : genericDecl->Members) + { + if (auto constraintDecl = as(member)) + { + genericArgs.add(getSimpleVal(context, ensureDecl(context, constraintDecl))); + } + } + + return builder->emitSpecializeInst(type, specialiedOuterVal, genericArgs.getCount(), genericArgs.getBuffer()); + } + + IRInst* defaultSpecializeOuterGenerics( + IRInst* val, + IRType* type, + Decl* decl) + { + if(!val) return nullptr; + + auto parentVal = val->getParent(); + while(parentVal) + { + if(as(parentVal)) + break; + parentVal = parentVal->getParent(); + } + if(!parentVal) + return val; + + for(auto pp = decl->ParentDecl; pp; pp = pp->ParentDecl) + { + if(auto genericAncestor = as(pp)) + { + return defaultSpecializeOuterGeneric(parentVal, type, genericAncestor); + } + } + + return val; + } + + struct NestedContext + { + IRGenEnv subEnvStorage; + IRBuilder subBuilderStorage; + IRGenContext subContextStorage; + + NestedContext(DeclLoweringVisitor* outer) + : subBuilderStorage(*outer->getBuilder()) + , subContextStorage(*outer->context) + { + auto outerContext = outer->context; + + subEnvStorage.outer = outerContext->env; + + subContextStorage.irBuilder = &subBuilderStorage; + subContextStorage.env = &subEnvStorage; + } + + IRBuilder* getBuilder() { return &subBuilderStorage; } + IRGenContext* getContet() { return &subContextStorage; } + }; + + LoweredValInfo lowerFunctionStaticConstVarDecl( + VarDeclBase* decl) + { + // We need to insert the constant at a level above + // the function being emitted. This will usually + // be the global scope, but it might be an outer + // generic if we are lowering a generic function. + // + NestedContext nestedContext(this); + auto subBuilder = nestedContext.getBuilder(); + auto subContext = nestedContext.getContet(); + + subBuilder->setInsertInto(subBuilder->getFunc()->getParent()); + + IRType* subVarType = lowerType(subContext, decl->getType()); + + IRGlobalConstant* irConstant = subBuilder->createGlobalConstant(subVarType); + addVarDecorations(subContext, irConstant, decl); + addNameHint(context, irConstant, decl); + maybeSetRate(context, irConstant, decl); + subBuilder->addHighLevelDeclDecoration(irConstant, decl); + + LoweredValInfo constantVal = LoweredValInfo::ptr(irConstant); + setValue(context, decl, constantVal); + + if( auto initExpr = decl->initExpr ) + { + NestedContext nestedInitContext(this); + auto initBuilder = nestedInitContext.getBuilder(); + auto initContext = nestedInitContext.getContet(); + + initBuilder->setInsertInto(irConstant); + + IRBlock* entryBlock = initBuilder->emitBlock(); + initBuilder->setInsertInto(entryBlock); + + LoweredValInfo initVal = lowerRValueExpr(initContext, initExpr); + initBuilder->emitReturn(getSimpleVal(initContext, initVal)); + } + + return constantVal; + } + + LoweredValInfo lowerFunctionStaticVarDecl( + VarDeclBase* decl) + { + // We know the variable is `static`, but it might also be `const. + if(decl->HasModifier()) + return lowerFunctionStaticConstVarDecl(decl); + + // A global variable may need to be generic, if one + // of the outer declarations is generic. + NestedContext nestedContext(this); + auto subBuilder = nestedContext.getBuilder(); + auto subContext = nestedContext.getContet(); + subBuilder->setInsertInto(subBuilder->getModule()->getModuleInst()); + emitOuterGenerics(subContext, decl, decl); + + IRType* subVarType = lowerType(subContext, decl->getType()); + + IRGlobalValueWithCode* irGlobal = subBuilder->createGlobalVar(subVarType); + addVarDecorations(subContext, irGlobal, decl); + + addNameHint(context, irGlobal, decl); + maybeSetRate(context, irGlobal, decl); + + subBuilder->addHighLevelDeclDecoration(irGlobal, decl); + + // We are inside of a function, and that function might be generic, + // in which case the `static` variable will be lowered to another + // generic. Let's start with a terrible example: + // + // interface IHasCount { int getCount(); } + // int incrementCounter(T val) { + // static int counter = 0; + // counter += val.getCount(); + // return counter; + // } + // + // In this case, `incrementCounter` will lower to a function + // nested in a generic, while `counter` will be lowered to + // a global variable nested in a *different* generic. + // The net result is something like this: + // + // int counter = 0; + // + // int incrementCounter(T val) { + // counter += val.getCount(); + // return counter; + // + // The references to `counter` inside of `incrementCounter` + // become references to `counter`. + // + // At the IR level, this means that the value we install + // for `decl` needs to be a specialized reference to `irGlobal`, + // for any outer generics. + // + IRType* varType = lowerType(context, decl->getType()); + IRType* varPtrType = getBuilder()->getPtrType(varType); + auto irSpecializedGlobal = defaultSpecializeOuterGenerics(irGlobal, varPtrType, decl); + LoweredValInfo globalVal = LoweredValInfo::ptr(irSpecializedGlobal); + setValue(context, decl, globalVal); + + // A `static` variable with an initializer needs special handling, + // at least if the initializer isn't a compile-time constant. + if( auto initExpr = decl->initExpr ) + { + // We must create an ordinary global `bool isInitialized = false` + // to represent whether we've initialized this before. + // Then emit code like: + // + // if(!isInitialized) { = ; isInitialized = true; } + // + // TODO: we could conceivably optimize this by detecting + // when the `initExpr` lowers to just a reference to a constant, + // and then either deleting the extra code structure there, + // or not generating it in the first place. That is a bit + // more complexity than I'm ready for at the moment. + // + + // Of course, if we are under a generic, then the Boolean + // variable need to be generic as well! + NestedContext nestedBoolContext(this); + auto boolBuilder = nestedBoolContext.getBuilder(); + auto boolContext = nestedBoolContext.getContet(); + boolBuilder->setInsertInto(boolBuilder->getModule()->getModuleInst()); + emitOuterGenerics(boolContext, decl, decl); + + auto irBoolType = boolBuilder->getBoolType(); + auto irBool = boolBuilder->createGlobalVar(irBoolType); + boolBuilder->setInsertInto(irBool); + boolBuilder->setInsertInto(boolBuilder->createBlock()); + boolBuilder->emitReturn(boolBuilder->getBoolValue(false)); + + auto boolVal = LoweredValInfo::ptr(defaultSpecializeOuterGenerics(irBool, irBoolType, decl)); + + + // Okay, with our global Boolean created, we can move on to + // generating the code we actually care about, back in the original function. + + auto builder = getBuilder(); + + auto initBlock = builder->createBlock(); + auto afterBlock = builder->createBlock(); + + builder->emitIfElse(getSimpleVal(context, boolVal), afterBlock, initBlock, afterBlock); + + builder->insertBlock(initBlock); + LoweredValInfo initVal = lowerLValueExpr(context, initExpr); + assign(context, globalVal, initVal); + assign(context, boolVal, LoweredValInfo::simple(builder->getBoolValue(true))); + builder->emitBranch(afterBlock); + + builder->insertBlock(afterBlock); + } + + irGlobal->moveToEnd(); + finishOuterGenerics(subBuilder, irGlobal); + return globalVal; + } + + LoweredValInfo visitGenericValueParamDecl(GenericValueParamDecl* decl) + { + return emitDeclRef(context, makeDeclRef(decl), + lowerType(context, decl->type)); + } + + LoweredValInfo visitVarDecl(VarDecl* decl) + { + // Detect global (or effectively global) variables + // and handle them differently. + if (isGlobalVarDecl(decl)) + { + return lowerGlobalVarDecl(decl); + } + + if(isFunctionStaticVarDecl(decl)) + { + return lowerFunctionStaticVarDecl(decl); + } + + if(isMemberVarDecl(decl)) + { + return lowerMemberVarDecl(decl); + } + + // A user-defined variable declaration will usually turn into + // an `alloca` operation for the variable's storage, + // plus some code to initialize it and then store to the variable. + + IRType* varType = lowerType(context, decl->getType()); + + // As a special case, an immutable local variable with an + // initializer can just lower to the SSA value of its initializer. + // + if(as(decl)) + { + if(auto initExpr = decl->initExpr) + { + auto initVal = lowerRValueExpr(context, initExpr); + initVal = materialize(context, initVal); + setGlobalValue(context, decl, initVal); + return initVal; + } + } + + + LoweredValInfo varVal = createVar(context, varType, decl); + + if( auto initExpr = decl->initExpr ) + { + auto initVal = lowerRValueExpr(context, initExpr); + + assign(context, varVal, initVal); + } + + setGlobalValue(context, decl, varVal); + + return varVal; + } + + IRStructKey* getInterfaceRequirementKey(Decl* requirementDecl) + { + return Slang::getInterfaceRequirementKey(context, requirementDecl); + } + + LoweredValInfo visitInterfaceDecl(InterfaceDecl* decl) + { + // The members of an interface will turn into the keys that will + // be used for lookup operations into witness + // tables that promise conformance to the interface. + // + // TODO: we don't handle the case here of an interface + // with concrete/default implementations for any + // of its members. + // + // TODO: If we want to support using an interface as + // an existential type, then we might need to emit + // a witness table for the interface type's conformance + // to its own interface. + // + for (auto requirementDecl : decl->Members) + { + getInterfaceRequirementKey(requirementDecl); + + // As a special case, any type constraints placed + // on an associated type will *also* need to be turned + // into requirement keys for this interface. + if (auto associatedTypeDecl = as(requirementDecl)) + { + for (auto constraintDecl : associatedTypeDecl->getMembersOfType()) + { + getInterfaceRequirementKey(constraintDecl); + } + } + } + + + NestedContext nestedContext(this); + auto subBuilder = nestedContext.getBuilder(); + auto subContext = nestedContext.getContet(); + + // Emit any generics that should wrap the actual type. + emitOuterGenerics(subContext, decl, decl); + + IRInterfaceType* irInterface = subBuilder->createInterfaceType(); + addNameHint(context, irInterface, decl); + addLinkageDecoration(context, irInterface, decl); + subBuilder->setInsertInto(irInterface); + + // TODO: are there any interface members that should be + // nested inside the interface type itself? + + irInterface->moveToEnd(); + + addTargetIntrinsicDecorations(irInterface, decl); + + + return LoweredValInfo::simple(finishOuterGenerics(subBuilder, irInterface)); + } + + LoweredValInfo visitEnumCaseDecl(EnumCaseDecl* decl) + { + // A case within an `enum` decl will lower to a value + // of the `enum`'s "tag" type. + // + // TODO: a bit more work will be needed if we allow for + // enum cases that have payloads, because then we need + // a function that constructs the value given arguments. + // + NestedContext nestedContext(this); + auto subContext = nestedContext.getContet(); + + // Emit any generics that should wrap the actual type. + emitOuterGenerics(subContext, decl, decl); + + return lowerRValueExpr(subContext, decl->tagExpr); + } + + LoweredValInfo visitEnumDecl(EnumDecl* decl) + { + // Given a declaration of a type, we need to make sure + // to output "witness tables" for any interfaces this + // type has declared conformance to. + for( auto inheritanceDecl : decl->getMembersOfType() ) + { + ensureDecl(context, inheritanceDecl); + } + + NestedContext nestedContext(this); + auto subBuilder = nestedContext.getBuilder(); + auto subContext = nestedContext.getContet(); + emitOuterGenerics(subContext, decl, decl); + + // An `enum` declaration will currently lower directly to its "tag" + // type, so that any references to the `enum` become referenes to + // the tag type instead. + // + // TODO: if we ever support `enum` types with payloads, we would + // need to make the `enum` lower to some kind of custom "tagged union" + // type. + + IRType* loweredTagType = lowerType(subContext, decl->tagType); + + return LoweredValInfo::simple(finishOuterGenerics(subBuilder, loweredTagType)); + } + + LoweredValInfo visitAggTypeDecl(AggTypeDecl* decl) + { + // Don't generate an IR `struct` for intrinsic types + if(decl->FindModifier() || decl->FindModifier()) + { + return LoweredValInfo(); + } + + // Given a declaration of a type, we need to make sure + // to output "witness tables" for any interfaces this + // type has declared conformance to. + for( auto inheritanceDecl : decl->getMembersOfType() ) + { + ensureDecl(context, inheritanceDecl); + } + + // We are going to create nested IR building state + // to use when emitting the members of the type. + // + NestedContext nestedContext(this); + auto subBuilder = nestedContext.getBuilder(); + auto subContext = nestedContext.getContet(); + + // Emit any generics that should wrap the actual type. + emitOuterGenerics(subContext, decl, decl); + + IRStructType* irStruct = subBuilder->createStructType(); + addNameHint(context, irStruct, decl); + addLinkageDecoration(context, irStruct, decl); + + subBuilder->setInsertInto(irStruct); + + for (auto fieldDecl : decl->getMembersOfType()) + { + if (fieldDecl->HasModifier()) + { + // A `static` field is actually a global variable, + // and we should emit it as such. + ensureDecl(context, fieldDecl); + continue; + } + + // Each ordinary field will need to turn into a struct "key" + // that is used for fetching the field. + IRInst* fieldKeyInst = getSimpleVal(context, + ensureDecl(context, fieldDecl)); + auto fieldKey = as(fieldKeyInst); + SLANG_ASSERT(fieldKey); + + // Note: we lower the type of the field in the "sub" + // context, so that any generic parameters that were + // set up for the type can be referenced by the field type. + IRType* fieldType = lowerType( + subContext, + fieldDecl->getType()); + + // Then, the parent `struct` instruction itself will have + // a "field" instruction. + subBuilder->createStructField( + irStruct, + fieldKey, + fieldType); + } + + // There may be members not handled by the above logic (e.g., + // member functions), but we will not immediately force them + // to be emitted here, so as not to risk a circular dependency. + // + // Instead we will force emission of all children of aggregate + // type declarations later, from the top-level emit logic. + + irStruct->moveToEnd(); + addTargetIntrinsicDecorations(irStruct, decl); + + return LoweredValInfo::simple(finishOuterGenerics(subBuilder, irStruct)); + } + + LoweredValInfo lowerMemberVarDecl(VarDecl* fieldDecl) + { + // Each field declaration in the AST translates into + // a "key" that can be used to extract field values + // from instances of struct types that contain the field. + // + // It is correct to say struct *types* because a `struct` + // nested under a generic can be used to realize a number + // of different concrete types, but all of these types + // will use the same space of keys. + + auto builder = getBuilder(); + auto irFieldKey = builder->createStructKey(); + addNameHint(context, irFieldKey, fieldDecl); + + addVarDecorations(context, irFieldKey, fieldDecl); + + addLinkageDecoration(context, irFieldKey, fieldDecl); + + if (auto semanticModifier = fieldDecl->FindModifier()) + { + builder->addSemanticDecoration(irFieldKey, semanticModifier->name.getName()->text.getUnownedSlice()); + } + + // We allow a field to be marked as a target intrinsic, + // so that we can override its mangled name in the + // output for the chosen target. + addTargetIntrinsicDecorations(irFieldKey, fieldDecl); + + + return LoweredValInfo::simple(irFieldKey); + } + + + DeclRef createDefaultSpecializedDeclRefImpl(Decl* decl) + { + DeclRef declRef; + declRef.decl = decl; + declRef.substitutions = createDefaultSubstitutions(context->getSession(), decl); + return declRef; + } + // + // The client should actually call the templated wrapper, to preserve type information. + template + DeclRef createDefaultSpecializedDeclRef(D* decl) + { + DeclRef declRef = createDefaultSpecializedDeclRefImpl(decl); + return declRef.as(); + } + + + // When lowering something callable (most commonly a function declaration), + // we need to construct an appropriate parameter list for the IR function + // that folds in any contributions from both the declaration itself *and* + // its parent declaration(s). + // + // For example, given code like: + // + // struct Foo { int bar(float y) { ... } }; + // + // we need to generate IR-level code something like: + // + // func Foo_bar(Foo this, float y) -> int; + // + // that is, the `this` parameter has become explicit. + // + // The same applies to generic parameters, and these + // should apply even if the nested declaration is `static`: + // + // struct Foo { static int bar(T y) { ... } }; + // + // becomes: + // + // func Foo_bar(T y) -> int; + // + // In order to implement this, we are going to do a recursive + // walk over a declaration and its parents, collecting separate + // lists of ordinary and generic parameters that will need + // to be included in the final declaration's parameter list. + // + // When doing code generation for an ordinary value parameter, + // we mostly care about its type, and then also its "direction" + // (`in`, `out`, `in out`). We sometimes need acess to the + // original declaration so that we can inspect it for meta-data, + // but in some cases there is no such declaration (e.g., a `this` + // parameter doesn't get an explicit declaration in the AST). + // To handle this we break out the relevant data into derived + // structures: + // + enum ParameterDirection + { + kParameterDirection_In, ///< Copy in + kParameterDirection_Out, ///< Copy out + kParameterDirection_InOut, ///< Copy in, copy out + kParameterDirection_Ref, ///< By-reference + }; + struct ParameterInfo + { + // This AST-level type of the parameter + RefPtr type; + + // The direction (`in` vs `out` vs `in out`) + ParameterDirection direction; + + // The variable/parameter declaration for + // this parameter (if any) + VarDeclBase* decl; + + // Is this the representation of a `this` parameter? + bool isThisParam = false; + }; + // + // We need a way to compute the appropriate `ParameterDirection` for a + // declared parameter: + // + ParameterDirection getParameterDirection(VarDeclBase* paramDecl) + { + if( paramDecl->HasModifier() ) + { + // The AST specified `ref`: + return kParameterDirection_Ref; + } + if( paramDecl->HasModifier() ) + { + // The AST specified `inout`: + return kParameterDirection_InOut; + } + if (paramDecl->HasModifier()) + { + // We saw an `out` modifier, so now we need + // to check if there was a paired `in`. + if(paramDecl->HasModifier()) + return kParameterDirection_InOut; + else + return kParameterDirection_Out; + } + else + { + // No direction modifier, or just `in`: + return kParameterDirection_In; + } + } + // We need a way to be able to create a `ParameterInfo` given the declaration + // of a parameter: + // + ParameterInfo getParameterInfo(VarDeclBase* paramDecl) + { + ParameterInfo info; + info.type = paramDecl->getType(); + info.decl = paramDecl; + info.direction = getParameterDirection(paramDecl); + info.isThisParam = false; + return info; + } + // + + // Here's the declaration for the type to hold the lists: + struct ParameterLists + { + List params; + }; + // + // Because there might be a `static` declaration somewhere + // along the lines, we need to be careful to prohibit adding + // non-generic parameters in some cases. + enum ParameterListCollectMode + { + // Collect everything: ordinary and generic parameters. + kParameterListCollectMode_Default, + + + // Only collect generic parameters. + kParameterListCollectMode_Static, + }; + // + // We also need to be able to detect whether a declaration is + // either explicitly or implicitly treated as `static`: + ParameterListCollectMode getModeForCollectingParentParameters( + Decl* decl, + ContainerDecl* parentDecl) + { + // If we have a `static` parameter, then it is obvious + // that we should use the `static` mode + if(isEffectivelyStatic(decl, parentDecl)) + return kParameterListCollectMode_Static; + + // Otherwise, let's default to collecting everything + return kParameterListCollectMode_Default; + } + // + // When dealing with a member function, we need to be able to add the `this` + // parameter for the enclosing type: + // + void addThisParameter( + ParameterDirection direction, + Type* type, + ParameterLists* ioParameterLists) + { + ParameterInfo info; + info.type = type; + info.decl = nullptr; + info.direction = direction; + info.isThisParam = true; + + ioParameterLists->params.add(info); + } + void addThisParameter( + ParameterDirection direction, + AggTypeDecl* typeDecl, + ParameterLists* ioParameterLists) + { + // We need to construct an appopriate declaration-reference + // for the type declaration we were given. In particular, + // we need to specialize it for any generic parameters + // that are in scope here. + auto declRef = createDefaultSpecializedDeclRef(typeDecl); + RefPtr type = DeclRefType::Create(context->getSession(), declRef); + addThisParameter( + direction, + type, + ioParameterLists); + } + // + // And here is our function that will do the recursive walk: + void collectParameterLists( + Decl* decl, + ParameterLists* ioParameterLists, + ParameterListCollectMode mode) + { + // The parameters introduced by any "parent" declarations + // will need to come first, so we'll deal with that + // logic here. + if( auto parentDecl = decl->ParentDecl ) + { + // Compute the mode to use when collecting parameters from + // the outer declaration. The most important question here + // is whether parameters of the outer declaration should + // also count as parameters of the inner declaration. + ParameterListCollectMode innerMode = getModeForCollectingParentParameters(decl, parentDecl); + + // Don't down-grade our `static`-ness along the chain. + if(innerMode < mode) + innerMode = mode; + + // Now collect any parameters from the parent declaration itself + collectParameterLists(parentDecl, ioParameterLists, innerMode); + + // We also need to consider whether the inner declaration needs to have a `this` + // parameter corresponding to the outer declaration. + if( innerMode != kParameterListCollectMode_Static ) + { + // For now we make any `this` parameter default to `in`. + // + ParameterDirection direction = kParameterDirection_In; + // + // Applications can opt in to a mutable `this` parameter, + // by applying the `[mutating]` attribute to their + // declaration. + // + if( decl->HasModifier() ) + { + direction = kParameterDirection_InOut; + } + + if( auto aggTypeDecl = as(parentDecl) ) + { + addThisParameter(direction, aggTypeDecl, ioParameterLists); + } + else if( auto extensionDecl = as(parentDecl) ) + { + addThisParameter(direction, extensionDecl->targetType, ioParameterLists); + } + } + } + + // Once we've added any parameters based on parent declarations, + // we can see if this declaration itself introduces parameters. + // + if( auto callableDecl = as(decl) ) + { + // Don't collect parameters from the outer scope if + // we are in a `static` context. + if( mode == kParameterListCollectMode_Default ) + { + for( auto paramDecl : callableDecl->GetParameters() ) + { + ioParameterLists->params.add(getParameterInfo(paramDecl)); + } + } + } + } + + bool isImportedDecl(Decl* decl) + { + return Slang::isImportedDecl(context, decl); + } + + bool isConstExprVar(Decl* decl) + { + if( decl->HasModifier() ) + { + return true; + } + else if(decl->HasModifier() && decl->HasModifier()) + { + return true; + } + + return false; + } + + IRType* maybeGetConstExprType(IRType* type, Decl* decl) + { + if(isConstExprVar(decl)) + { + return getBuilder()->getRateQualifiedType( + getBuilder()->getConstExprRate(), + type); + } + + return type; + } + + IRGeneric* emitOuterGeneric( + IRGenContext* subContext, + GenericDecl* genericDecl, + Decl* leafDecl) + { + auto subBuilder = subContext->irBuilder; + + // Of course, a generic might itself be nested inside of other generics... + emitOuterGenerics(subContext, genericDecl, leafDecl); + + // We need to create an IR generic + + auto irGeneric = subBuilder->emitGeneric(); + subBuilder->setInsertInto(irGeneric); + + auto irBlock = subBuilder->emitBlock(); + subBuilder->setInsertInto(irBlock); + + // Now emit any parameters of the generic + // + // First we start with type and value parameters, + // in the order they were declared. + for (auto member : genericDecl->Members) + { + if (auto typeParamDecl = as(member)) + { + // TODO: use a `TypeKind` to represent the + // classifier of the parameter. + auto param = subBuilder->emitParam(nullptr); + addNameHint(context, param, typeParamDecl); + setValue(subContext, typeParamDecl, LoweredValInfo::simple(param)); + } + else if (auto valDecl = as(member)) + { + auto paramType = lowerType(subContext, valDecl->getType()); + auto param = subBuilder->emitParam(paramType); + addNameHint(context, param, valDecl); + setValue(subContext, valDecl, LoweredValInfo::simple(param)); + } + } + // Then we emit constraint parameters, again in + // declaration order. + for (auto member : genericDecl->Members) + { + if (auto constraintDecl = as(member)) + { + // TODO: use a `WitnessTableKind` to represent the + // classifier of the parameter. + auto param = subBuilder->emitParam(nullptr); + addNameHint(context, param, constraintDecl); + setValue(subContext, constraintDecl, LoweredValInfo::simple(param)); + } + } + + return irGeneric; + } + + // If the given `decl` is enclosed in any generic declarations, then + // emit IR-level generics to represent them. + // The `leafDecl` represents the inner-most declaration we are actually + // trying to emit, which is the one that should receive the mangled name. + // + IRGeneric* emitOuterGenerics(IRGenContext* subContext, Decl* decl, Decl* leafDecl) + { + for(auto pp = decl->ParentDecl; pp; pp = pp->ParentDecl) + { + if(auto genericAncestor = as(pp)) + { + return emitOuterGeneric(subContext, genericAncestor, leafDecl); + } + } + + return nullptr; + } + + // If any generic declarations have been created by `emitOuterGenerics`, + // then finish them off by emitting `return` instructions for the + // values that they should produce. + // + // Return the outer-most generic (if there is one), or the original + // value (if there were no generics), which should be the IR-level + // representation of the original declaration. + // + IRInst* finishOuterGenerics( + IRBuilder* subBuilder, + IRInst* val) + { + IRInst* v = val; + for(;;) + { + auto parentBlock = as(v->getParent()); + if (!parentBlock) break; + + auto parentGeneric = as(parentBlock->getParent()); + if (!parentGeneric) break; + + subBuilder->setInsertInto(parentBlock); + subBuilder->emitReturn(v); + parentGeneric->moveToEnd(); + + // There might be more outer generics, + // so we need to loop until we run out. + v = parentGeneric; + } + return v; + } + + // Attach target-intrinsic decorations to an instruction, + // based on modifiers on an AST declaration. + void addTargetIntrinsicDecorations( + IRInst* irInst, + Decl* decl) + { + auto builder = getBuilder(); + + for (auto targetMod : decl->GetModifiersOfType()) + { + String definition; + auto definitionToken = targetMod->definitionToken; + if (definitionToken.type == TokenType::StringLiteral) + { + definition = getStringLiteralTokenValue(definitionToken); + } + else + { + definition = definitionToken.Content; + } + + builder->addTargetIntrinsicDecoration(irInst, targetMod->targetToken.Content, definition.getUnownedSlice()); + } + } + + void addParamNameHint(IRInst* inst, ParameterInfo info) + { + if(auto decl = info.decl) + { + addNameHint(context, inst, decl); + } + else if( info.isThisParam ) + { + addNameHint(context, inst, "this"); + } + } + + LoweredValInfo lowerFuncDecl(FunctionDeclBase* decl) + { + // We are going to use a nested builder, because we will + // change the parent node that things get nested into. + // + NestedContext nestedContext(this); + auto subBuilder = nestedContext.getBuilder(); + auto subContext = nestedContext.getContet(); + + // The actual `IRFunction` that we emit needs to be nested + // inside of one `IRGeneric` for every outer `GenericDecl` + // in the declaration hierarchy. + + emitOuterGenerics(subContext, decl, decl); + + // Collect the parameter lists we will use for our new function. + ParameterLists parameterLists; + collectParameterLists(decl, ¶meterLists, kParameterListCollectMode_Default); + + // TODO: if there are any generic parameters in the collected list, then + // we need to output an IR function with generic parameters (or a generic + // with a nested function... the exact representation is still TBD). + + // In most cases the return type for a declaration can be read off the declaration + // itself, but things get a bit more complicated when we have to deal with + // accessors for subscript declarations (and eventually for properties). + // + // We compute a declaration to use for looking up the return type here: + CallableDecl* declForReturnType = decl; + if (auto accessorDecl = as(decl)) + { + // We are some kind of accessor, so the parent declaration should + // know the correct return type to expose. + // + auto parentDecl = accessorDecl->ParentDecl; + if (auto subscriptDecl = as(parentDecl)) + { + declForReturnType = subscriptDecl; + } + } + + // need to create an IR function here + + IRFunc* irFunc = subBuilder->createFunc(); + addNameHint(context, irFunc, decl); + addLinkageDecoration(context, irFunc, decl); + + List paramTypes; + + for( auto paramInfo : parameterLists.params ) + { + IRType* irParamType = lowerType(subContext, paramInfo.type); + + switch( paramInfo.direction ) + { + case kParameterDirection_In: + // Simple case of a by-value input parameter. + break; + + // If the parameter is declared `out` or `inout`, + // then we will represent it with a pointer type in + // the IR, but we will use a specialized pointer + // type that encodes the parameter direction information. + case kParameterDirection_Out: + irParamType = subBuilder->getOutType(irParamType); + break; + case kParameterDirection_InOut: + irParamType = subBuilder->getInOutType(irParamType); + break; + case kParameterDirection_Ref: + irParamType = subBuilder->getRefType(irParamType); + break; + + default: + SLANG_UNEXPECTED("unknown parameter direction"); + break; + } + + // If the parameter was explicitly marked as being a compile-time + // constant (`constexpr`), then attach that information to its + // IR-level type explicitly. + if( paramInfo.decl ) + { + irParamType = maybeGetConstExprType(irParamType, paramInfo.decl); + } + + paramTypes.add(irParamType); + } + + auto irResultType = lowerType(subContext, declForReturnType->ReturnType); + + if (auto setterDecl = as(decl)) + { + // We are lowering a "setter" accessor inside a subscript + // declaration, which means we don't want to *return* the + // stated return type of the subscript, but instead take + // it as a parameter. + // + IRType* irParamType = irResultType; + paramTypes.add(irParamType); + + // Instead, a setter always returns `void` + // + irResultType = subBuilder->getVoidType(); + } + + if( auto refAccessorDecl = as(decl) ) + { + // A `ref` accessor needs to return a *pointer* to the value + // being accessed, rather than a simple value. + irResultType = subBuilder->getPtrType(irResultType); + } + + auto irFuncType = subBuilder->getFuncType( + paramTypes.getCount(), + paramTypes.getBuffer(), + irResultType); + irFunc->setFullType(irFuncType); + + subBuilder->setInsertInto(irFunc); + + if (isImportedDecl(decl)) + { + // Always emit imported declarations as declarations, + // and not definitions. + } + else if (!decl->Body) + { + // This is a function declaration without a body. + // In Slang we currently try not to support forward declarations + // (although we might have to give in eventually), so + // this case should really only occur for builtin declarations. + } + else + { + // This is a function definition, so we need to actually + // construct IR for the body... + IRBlock* entryBlock = subBuilder->emitBlock(); + subBuilder->setInsertInto(entryBlock); + + UInt paramTypeIndex = 0; + for( auto paramInfo : parameterLists.params ) + { + auto irParamType = paramTypes[paramTypeIndex++]; + + LoweredValInfo paramVal; + + switch( paramInfo.direction ) + { + default: + { + // The parameter is being used for input/output purposes, + // so it will lower to an actual parameter with a pointer type. + // + // TODO: Is this the best representation we can use? + + IRParam* irParamPtr = subBuilder->emitParam(irParamType); + if(auto paramDecl = paramInfo.decl) + { + addVarDecorations(context, irParamPtr, paramDecl); + subBuilder->addHighLevelDeclDecoration(irParamPtr, paramDecl); + } + addParamNameHint(irParamPtr, paramInfo); + + paramVal = LoweredValInfo::ptr(irParamPtr); + + // TODO: We might want to copy the pointed-to value into + // a temporary at the start of the function, and then copy + // back out at the end, so that we don't have to worry + // about things like aliasing in the function body. + // + // For now we will just use the storage that was passed + // in by the caller, knowing that our current lowering + // at call sites will guarantee a fresh/unique location. + } + break; + + case kParameterDirection_In: + { + // Simple case of a by-value input parameter. + // + // We start by declaring an IR parameter of the same type. + // + auto paramDecl = paramInfo.decl; + IRParam* irParam = subBuilder->emitParam(irParamType); + if( paramDecl ) + { + addVarDecorations(context, irParam, paramDecl); + subBuilder->addHighLevelDeclDecoration(irParam, paramDecl); + } + addParamNameHint(irParam, paramInfo); + paramVal = LoweredValInfo::simple(irParam); + // + // HLSL allows a function parameter to be used as a local + // variable in the function body (just like C/C++), so + // we need to support that case as well. + // + // However, if we notice that the parameter was marked + // `const`, then we can skip this step. + // + // TODO: we should consider having all parameter be implicitly + // immutable except in a specific "compatibility mode." + // + if(paramDecl && paramDecl->FindModifier()) + { + // This parameter was declared to be immutable, + // so there should be no assignment to it in the + // function body, and we don't need a temporary. + } + else + { + // The parameter migth get used as a temporary in + // the function body. We will allocate a mutable + // local variable for is value, and then assign + // from the parameter to the local at the start + // of the function. + // + auto irLocal = subBuilder->emitVar(irParamType); + auto localVal = LoweredValInfo::ptr(irLocal); + assign(subContext, localVal, paramVal); + // + // When code later in the body of the function refers + // to the parameter declaration, it will actually refer + // to the value stored in the local variable. + // + paramVal = localVal; + } + } + break; + } + + if( auto paramDecl = paramInfo.decl ) + { + setValue(subContext, paramDecl, paramVal); + } + + if (paramInfo.isThisParam) + { + subContext->thisVal = paramVal; + } + } + + if (auto setterDecl = as(decl)) + { + // Add the IR parameter for the new value + IRType* irParamType = irResultType; + auto irParam = subBuilder->emitParam(irParamType); + addNameHint(context, irParam, "newValue"); + + // TODO: we need some way to wire this up to the `newValue` + // or whatever name we give for that parameter inside + // the setter body. + } + + { + + auto attr = decl->FindModifier(); + + // I needed to test for patchConstantFuncDecl here + // because it is only set if validateEntryPoint is called with Hull as the required stage + // If I just build domain shader, and then the attribute exists, but patchConstantFuncDecl is not set + // and thus leads to a crash. + if (attr && attr->patchConstantFuncDecl) + { + // We need to lower the function + FuncDecl* patchConstantFunc = attr->patchConstantFuncDecl; + assert(patchConstantFunc); + + // Convert the patch constant function into IRInst + IRInst* irPatchConstantFunc = getSimpleVal(context, ensureDecl(subContext, patchConstantFunc)); + + // Attach a decoration so that our IR function references + // the patch constant function. + // + subContext->irBuilder->addPatchConstantFuncDecoration( + irFunc, + irPatchConstantFunc); + + } + } + + // Lower body + + lowerStmt(subContext, decl->Body); + + // We need to carefully add a terminator instruction to the end + // of the body, in case the user didn't do so. + if (!subContext->irBuilder->getBlock()->getTerminator()) + { + if(as(irResultType)) + { + // `void`-returning function can get an implicit + // return on exit of the body statement. + subContext->irBuilder->emitReturn(); + } + else + { + // Value-returning function is expected to `return` + // on every control-flow path. We need to enforce + // this by putting an `unreachable` terminator here, + // and then emit a dataflow error if this block + // can't be eliminated. + subContext->irBuilder->emitMissingReturn(); + } + } + } + + getBuilder()->addHighLevelDeclDecoration(irFunc, decl); + + // If this declaration was marked as being an intrinsic for a particular + // target, then we should reflect that here. + for( auto targetMod : decl->GetModifiersOfType() ) + { + // `targetMod` indicates that this particular declaration represents + // a specialized definition of the particular function for the given + // target, and we need to reflect that at the IR level. + + getBuilder()->addTargetDecoration(irFunc, targetMod->targetToken.Content); + } + + // If this declaration was marked as having a target-specific lowering + // for a particular target, then handle that here. + addTargetIntrinsicDecorations(irFunc, decl); + + // If this declaration requires certain GLSL extension (or a particular GLSL version) + // for it to be usable, then declare that here. + // + // TODO: We should wrap this an `SpecializedForTargetModifier` together into a single + // case for enumerating the "capabilities" that a declaration requires. + // + for(auto extensionMod : decl->GetModifiersOfType()) + { + getBuilder()->addRequireGLSLExtensionDecoration(irFunc, extensionMod->extensionNameToken.Content); + } + for(auto versionMod : decl->GetModifiersOfType()) + { + getBuilder()->addRequireGLSLVersionDecoration(irFunc, Int(getIntegerLiteralValue(versionMod->versionNumberToken))); + } + + if(decl->FindModifier()) + { + getBuilder()->addSimpleDecoration(irFunc); + } + + if (decl->FindModifier()) + { + getBuilder()->addSimpleDecoration(irFunc); + } + + // For convenience, ensure that any additional global + // values that were emitted while outputting the function + // body appear before the function itself in the list + // of global values. + irFunc->moveToEnd(); + return LoweredValInfo::simple(finishOuterGenerics(subBuilder, irFunc)); + } + + LoweredValInfo visitGenericDecl(GenericDecl * genDecl) + { + // TODO: Should this just always visit/lower the inner decl? + + if (auto innerFuncDecl = as(genDecl->inner)) + return ensureDecl(context, innerFuncDecl); + else if (auto innerStructDecl = as(genDecl->inner)) + { + ensureDecl(context, innerStructDecl); + return LoweredValInfo(); + } + else if( auto extensionDecl = as(genDecl->inner) ) + { + return ensureDecl(context, extensionDecl); + } + SLANG_RELEASE_ASSERT(false); + UNREACHABLE_RETURN(LoweredValInfo()); + } + + LoweredValInfo visitFunctionDeclBase(FunctionDeclBase* decl) + { + // A function declaration may have multiple, target-specific + // overloads, and we need to emit an IR version of each of these. + + // The front end will form a linked list of declarations with + // the same signature, whenever there is any kind of redeclaration. + // We will look to see if that linked list has been formed. + auto primaryDecl = decl->primaryDecl; + + if (!primaryDecl) + { + // If there is no linked list then we are in the ordinary + // case with a single declaration, and no special handling + // is needed. + return lowerFuncDecl(decl); + } + + // Otherwise, we need to walk the linked list of declarations + // and make sure to emit IR code for any targets that need it. + + // TODO: Need to be careful about how this is approached, + // to avoid emitting a bunch of extra definitions in the IR. + + auto primaryFuncDecl = as(primaryDecl); + SLANG_ASSERT(primaryFuncDecl); + LoweredValInfo result = lowerFuncDecl(primaryFuncDecl); + for (auto dd = primaryDecl->nextDecl; dd; dd = dd->nextDecl) + { + auto funcDecl = as(dd); + SLANG_ASSERT(funcDecl); + lowerFuncDecl(funcDecl); + } + return result; + } +}; + +LoweredValInfo lowerDecl( + IRGenContext* context, + DeclBase* decl) +{ + IRBuilderSourceLocRAII sourceLocInfo(context->irBuilder, decl->loc); + + DeclLoweringVisitor visitor; + visitor.context = context; + + try + { + return visitor.dispatch(decl); + } + // Don't emit any context message for an explicit `AbortCompilationException` + // because it should only happen when an error is already emitted. + catch(AbortCompilationException&) { throw; } + catch(...) + { + context->getSink()->noteInternalErrorLoc(decl->loc); + throw; + } +} + +// Ensure that a version of the given declaration has been emitted to the IR +LoweredValInfo ensureDecl( + IRGenContext* context, + Decl* decl) +{ + auto shared = context->shared; + + LoweredValInfo result; + + // Look for an existing value installed in this context + auto env = context->env; + while(env) + { + if(env->mapDeclToValue.TryGetValue(decl, result)) + return result; + + env = env->outer; + } + + IRBuilder subIRBuilder; + subIRBuilder.sharedBuilder = context->irBuilder->sharedBuilder; + subIRBuilder.setInsertInto(subIRBuilder.sharedBuilder->module->getModuleInst()); + + IRGenEnv subEnv; + subEnv.outer = context->env; + + IRGenContext subContext = *context; + subContext.irBuilder = &subIRBuilder; + subContext.env = &subEnv; + + result = lowerDecl(&subContext, decl); + + // By default assume that any value we are lowering represents + // something that should be installed globally. + setGlobalValue(shared, decl, result); + + return result; +} + +IRInst* lowerSubstitutionArg( + IRGenContext* context, + Val* val) +{ + if (auto type = dynamicCast(val)) + { + return lowerType(context, type); + } + else if (auto declaredSubtypeWitness = as(val)) + { + // We need to look up the IR-level representation of the witness (which will be a witness table). + auto irWitnessTable = getSimpleVal( + context, + emitDeclRef( + context, + declaredSubtypeWitness->declRef, + context->irBuilder->getWitnessTableType())); + return irWitnessTable; + } + else + { + SLANG_UNIMPLEMENTED_X("value cases"); + UNREACHABLE_RETURN(nullptr); + } +} + +// Can the IR lowered version of this declaration ever be an `IRGeneric`? +bool canDeclLowerToAGeneric(RefPtr decl) +{ + // A callable decl lowers to an `IRFunc`, and can be generic + if(as(decl)) return true; + + // An aggregate type decl lowers to an `IRStruct`, and can be generic + if(as(decl)) return true; + + // An inheritance decl lowers to an `IRWitnessTable`, and can be generic + if(as(decl)) return true; + + // A `typedef` declaration nested under a generic will turn into + // a generic that returns a type (a simple type-level function). + if(as(decl)) return true; + + return false; +} + +LoweredValInfo emitDeclRef( + IRGenContext* context, + RefPtr decl, + RefPtr subst, + IRType* type) +{ + // We need to proceed by considering the specializations that + // have been put in place. + + // Ignore any global generic type substitutions during lowering. + // Really, we don't even expect these to appear. + while(auto globalGenericSubst = as(subst)) + subst = globalGenericSubst->outer; + + // If the declaration would not get wrapped in a `IRGeneric`, + // even if it is nested inside of an AST `GenericDecl`, then + // we should also ignore any generic substitutions. + if(!canDeclLowerToAGeneric(decl)) + { + while(auto genericSubst = as(subst)) + subst = genericSubst->outer; + } + + // In the simplest case, there is no specialization going + // on, and the decl-ref turns into a reference to the + // lowered IR value for the declaration. + if(!subst) + { + LoweredValInfo loweredDecl = ensureDecl(context, decl); + return loweredDecl; + } + + // Otherwise, we look at the kind of substitution, and let it guide us. + if(auto genericSubst = subst.as()) + { + // A generic substitution means we will need to output + // a `specialize` instruction to specialize the generic. + // + // First we want to emit the value without generic specialization + // applied, to get a correct value for it. + // + // Note: we only "unwrap" a single layer from the + // substitutions here, because the underlying declaration + // might be nested in multiple generics, or it might + // come from an interface. + // + LoweredValInfo genericVal = emitDeclRef( + context, + decl, + genericSubst->outer, + context->irBuilder->getGenericKind()); + + // There's no reason to specialize something that maps to a NULL pointer. + if (genericVal.flavor == LoweredValInfo::Flavor::None) + return LoweredValInfo(); + + // We can only really specialize things that map to single values. + // It would be an error if we got a non-`None` value that + // wasn't somehow a single value. + auto irGenericVal = getSimpleVal(context, genericVal); + + // We have the IR value for the generic we'd like to specialize, + // and now we need to get the value for the arguments. + List irArgs; + for (auto argVal : genericSubst->args) + { + auto irArgVal = lowerSimpleVal(context, argVal); + SLANG_ASSERT(irArgVal); + irArgs.add(irArgVal); + } + + // Once we have both the generic and its arguments, + // we can emit a `specialize` instruction and use + // its value as the result. + auto irSpecializedVal = context->irBuilder->emitSpecializeInst( + type, + irGenericVal, + irArgs.getCount(), + irArgs.getBuffer()); + + return LoweredValInfo::simple(irSpecializedVal); + } + else if(auto thisTypeSubst = subst.as()) + { + if(decl.Ptr() == thisTypeSubst->interfaceDecl) + { + // This is a reference to the interface type itself, + // through the this-type substitution, so it is really + // a reference to the this-type. + return lowerType(context, thisTypeSubst->witness->sub); + } + + // Somebody is trying to look up an interface requirement + // "through" some concrete type. We need to lower this decl-ref + // as a lookup of the corresponding member in a witness table. + // + // The witness table itself is referenced by the this-type + // substitution, so we can just lower that. + // + // Note: unlike the case for generics above, in the interface-lookup + // case, we don't end up caring about any further outer substitutions. + // That is because even if we are naming `ISomething.doIt()`, + // a method inside a generic interface, we don't actually care + // about the substitution of `Foo` for the parameter `T` of + // `ISomething`. That is because we really care about the + // witness table for the concrete type that conforms to `ISomething`. + // + auto irWitnessTable = lowerSimpleVal(context, thisTypeSubst->witness); + // + // The key to use for looking up the interface member is + // derived from the declaration. + // + auto irRequirementKey = getInterfaceRequirementKey(context, decl); + // + // Those two pieces of information tell us what we need to + // do in order to look up the value that satisfied the requirement. + // + auto irSatisfyingVal = context->irBuilder->emitLookupInterfaceMethodInst( + type, + irWitnessTable, + irRequirementKey); + return LoweredValInfo::simple(irSatisfyingVal); + } + else + { + SLANG_UNEXPECTED("uhandled substitution type"); + UNREACHABLE_RETURN(LoweredValInfo()); + } +} + +LoweredValInfo emitDeclRef( + IRGenContext* context, + DeclRef declRef, + IRType* type) +{ + return emitDeclRef( + context, + declRef.decl, + declRef.substitutions.substitutions, + type); +} + +static void lowerFrontEndEntryPointToIR( + IRGenContext* context, + EntryPoint* entryPoint) +{ + // TODO: We should emit an entry point as a dedicated IR function + // (distinct from the IR function used if it were called normally), + // with a mangled name based on the original function name plus + // the stage for which it is being compiled as an entry point (so + // that entry points for distinct stages always have distinct names). + // + // For now we just have an (implicit) constraint that a given + // function should only be used as an entry point for one stage, + // and any such function should *not* be used as an ordinary function. + + auto entryPointFuncDecl = entryPoint->getFuncDecl(); + + auto builder = context->irBuilder; + builder->setInsertInto(builder->getModule()->getModuleInst()); + + auto loweredEntryPointFunc = getSimpleVal(context, + ensureDecl(context, entryPointFuncDecl)); + + // Attach a marker decoration so that we recognize + // this as an entry point. + // + IRInst* instToDecorate = loweredEntryPointFunc; + if(auto irGeneric = as(instToDecorate)) + { + instToDecorate = findGenericReturnVal(irGeneric); + } + builder->addEntryPointDecoration(instToDecorate); +} + +static void lowerProgramEntryPointToIR( + IRGenContext* context, + EntryPoint* entryPoint) +{ + // First, lower the entry point like an ordinary function + + auto session = context->getSession(); + auto entryPointFuncDeclRef = entryPoint->getFuncDeclRef(); + auto entryPointFuncType = lowerType(context, getFuncType(session, entryPointFuncDeclRef)); + + auto builder = context->irBuilder; + builder->setInsertInto(builder->getModule()->getModuleInst()); + + auto loweredEntryPointFunc = getSimpleVal(context, + emitDeclRef(context, entryPointFuncDeclRef, entryPointFuncType)); + + // + if(!loweredEntryPointFunc->findDecoration()) + { + builder->addExportDecoration(loweredEntryPointFunc, getMangledName(entryPointFuncDeclRef).getUnownedSlice()); + } + + // We may have shader parameters of interface/existential type, + // which need us to supply concrete type information for specialization. + // + auto existentialTypeArgCount = entryPoint->getExistentialTypeArgCount(); + if( existentialTypeArgCount ) + { + List existentialSlotArgs; + for( Index ii = 0; ii < existentialTypeArgCount; ++ii ) + { + auto arg = entryPoint->getExistentialTypeArg(ii); + + auto irArgType = lowerType(context, arg.type); + auto irWitnessTable = lowerSimpleVal(context, arg.witness); + + existentialSlotArgs.add(irArgType); + existentialSlotArgs.add(irWitnessTable); + } + + builder->addBindExistentialSlotsDecoration(loweredEntryPointFunc, existentialSlotArgs.getCount(), existentialSlotArgs.getBuffer()); + } + + + +} + + /// Ensure that `decl` and all relevant declarations under it get emitted. +static void ensureAllDeclsRec( + IRGenContext* context, + Decl* decl) +{ + ensureDecl(context, decl); + + // Note: We are checking here for aggregate type declarations, and + // not for `ContainerDecl`s in general. This is because many kinds + // of container declarations will already take responsibility for emitting + // their children directly (e.g., a function declaration is responsible + // for emitting its own parameters). + // + // Aggregate types are the main case where we can emit an outer declaration + // and not the stuff nested inside of it. + // + if(auto containerDecl = as(decl)) + { + for (auto memberDecl : containerDecl->Members) + { + ensureAllDeclsRec(context, memberDecl); + } + } +} + +IRModule* generateIRForTranslationUnit( + TranslationUnitRequest* translationUnit) +{ + auto compileRequest = translationUnit->compileRequest; + + SharedIRGenContext sharedContextStorage( + translationUnit->getSession(), + translationUnit->compileRequest->getSink(), + translationUnit->getModuleDecl()); + SharedIRGenContext* sharedContext = &sharedContextStorage; + + IRGenContext contextStorage(sharedContext); + IRGenContext* context = &contextStorage; + + SharedIRBuilder sharedBuilderStorage; + SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; + sharedBuilder->module = nullptr; + sharedBuilder->session = compileRequest->getSession(); + + IRBuilder builderStorage; + IRBuilder* builder = &builderStorage; + builder->sharedBuilder = sharedBuilder; + + IRModule* module = builder->createModule(); + sharedBuilder->module = module; + + context->irBuilder = builder; + + // We need to emit IR for all public/exported symbols + // in the translation unit. + // + // For now, we will assume that *all* global-scope declarations + // represent public/exported symbols. + + // First, ensure that all entry points have been emitted, + // in case they require special handling. + for (auto entryPoint : translationUnit->entryPoints) + { + lowerFrontEndEntryPointToIR(context, entryPoint); + } + + // + // Next, ensure that all other global declarations have + // been emitted. + for (auto decl : translationUnit->getModuleDecl()->Members) + { + ensureAllDeclsRec(context, decl); + } + +#if 0 + fprintf(stderr, "### GENERATED\n"); + dumpIR(module); + fprintf(stderr, "###\n"); +#endif + + validateIRModuleIfEnabled(compileRequest, module); + + // We will perform certain "mandatory" optimization passes now. + // These passes serve two purposes: + // + // 1. To simplify the code that we use in backend compilation, + // or when serializing/deserializing modules, so that we can + // amortize this effort when we compile multiple entry points + // that use the same module(s). + // + // 2. To ensure certain semantic properties that can't be + // validated without dataflow information. For example, we want + // to detect when a variable might be used before it is initialized. + + // Note: if you need to debug the IR that is created before + // any mandatory optimizations have been applied, then + // uncomment this line while debugging. + + // dumpIR(module); + + // First, attempt to promote local variables to SSA + // temporaries whenever possible. + constructSSA(module); + + // Do basic constant folding and dead code elimination + // using Sparse Conditional Constant Propagation (SCCP) + // + applySparseConditionalConstantPropagation(module); + + // Propagate `constexpr`-ness through the dataflow graph (and the + // call graph) based on constraints imposed by different instructions. + propagateConstExpr(module, compileRequest->getSink()); + + // TODO: give error messages if any `undefined` or + // `unreachable` instructions remain. + + checkForMissingReturns(module, compileRequest->getSink()); + + // TODO: consider doing some more aggressive optimizations + // (in particular specialization of generics) here, so + // that we can avoid doing them downstream. + // + // Note: doing specialization or inlining involving code + // from other modules potentially makes the IR we generate + // "fragile" in that we'd now need to recompile when + // a module we depend on changes. + + validateIRModuleIfEnabled(compileRequest, module); + + // If we are being sked to dump IR during compilation, + // then we can dump the initial IR for the module here. + if(compileRequest->shouldDumpIR) + { + DiagnosticSinkWriter writer(compileRequest->getSink()); + dumpIR(module, &writer); + } + + return module; +} + +RefPtr generateIRForProgram( + Session* session, + Program* program, + DiagnosticSink* sink) +{ +// auto compileRequest = translationUnit->compileRequest; + + SharedIRGenContext sharedContextStorage( + session, + sink); + SharedIRGenContext* sharedContext = &sharedContextStorage; + + IRGenContext contextStorage(sharedContext); + IRGenContext* context = &contextStorage; + + SharedIRBuilder sharedBuilderStorage; + SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; + sharedBuilder->module = nullptr; + sharedBuilder->session = session; + + IRBuilder builderStorage; + IRBuilder* builder = &builderStorage; + builder->sharedBuilder = sharedBuilder; + + RefPtr module = builder->createModule(); + sharedBuilder->module = module; + + context->irBuilder = builder; + + // We need to emit symbols for all of the entry + // points in the program; this is especially + // important in the case where a generic entry + // point is being specialized. + // + for(auto entryPoint : program->getEntryPoints()) + { + lowerProgramEntryPointToIR(context, entryPoint); + } + + + // Now lower all the arguments supplied for global generic + // type parameters. + // + for (RefPtr subst = program->getGlobalGenericSubstitution(); subst; subst = subst->outer) + { + auto gSubst = subst.as(); + if(!gSubst) + continue; + + IRInst* typeParam = getSimpleVal(context, ensureDecl(context, gSubst->paramDecl)); + IRType* typeVal = lowerType(context, gSubst->actualType); + + // bind `typeParam` to `typeVal` + builder->emitBindGlobalGenericParam(typeParam, typeVal); + + for (auto& constraintArg : gSubst->constraintArgs) + { + IRInst* constraintParam = getSimpleVal(context, ensureDecl(context, constraintArg.decl)); + IRInst* constraintVal = lowerSimpleVal(context, constraintArg.val); + + // bind `constraintParam` to `constraintVal` + builder->emitBindGlobalGenericParam(constraintParam, constraintVal); + } + } + + // We may have shader parameters of interface/existential type, + // which need us to supply concrete type information for specialization. + // + auto existentialTypeArgCount = program->getExistentialTypeArgCount(); + if( existentialTypeArgCount ) + { + List existentialSlotArgs; + for( Index ii = 0; ii < existentialTypeArgCount; ++ii ) + { + auto arg = program->getExistentialTypeArg(ii); + + auto irArgType = lowerType(context, arg.type); + auto irWitnessTable = lowerSimpleVal(context, arg.witness); + + existentialSlotArgs.add(irArgType); + existentialSlotArgs.add(irWitnessTable); + } + + builder->emitBindGlobalExistentialSlots(existentialSlotArgs.getCount(), existentialSlotArgs.getBuffer()); + } + + + // TODO: Should we apply any of the validation or + // mandatory optimization passes here? + + return module; +} + +} // namespace Slang diff --git a/source/slang/slang-lower-to-ir.h b/source/slang/slang-lower-to-ir.h new file mode 100644 index 000000000..060efb88b --- /dev/null +++ b/source/slang/slang-lower-to-ir.h @@ -0,0 +1,28 @@ +// slang-lower-to-ir.h +#ifndef SLANG_LOWER_TO_IR_H_INCLUDED +#define SLANG_LOWER_TO_IR_H_INCLUDED + +// The lowering step translates from a (type-checked) AST into +// our intermediate representation, to facilitate further +// optimization and transformation. + +#include "../core/slang-basic.h" + +#include "slang-compiler.h" +#include "slang-ir.h" + +namespace Slang +{ + class EntryPoint; + class ProgramLayout; + class TranslationUnitRequest; + + IRModule* generateIRForTranslationUnit( + TranslationUnitRequest* translationUnit); + + RefPtr generateIRForProgram( + Session* session, + Program* program, + DiagnosticSink* sink); +} +#endif diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp new file mode 100644 index 000000000..16f7b64bb --- /dev/null +++ b/source/slang/slang-mangle.cpp @@ -0,0 +1,478 @@ +#include "slang-mangle.h" + +#include "slang-name.h" +#include "slang-syntax.h" + +namespace Slang +{ + struct ManglingContext + { + StringBuilder sb; + }; + + void emitRaw( + ManglingContext* context, + char const* text) + { + context->sb.append(text); + } + + void emit( + ManglingContext* context, + UInt value) + { + context->sb.append(value); + } + + void emit( + ManglingContext* context, + String const& value) + { + context->sb.append(value); + } + + void emitName( + ManglingContext* context, + Name* name) + { + String str = getText(name); + + // If the name consists of only traditional "identifer characters" + // (`[a-zA-Z_]`), then we wnat to emit it more or less directly. + // + // If it contains code points outside that range, we'll need to + // do something to encode them. I don't want to deal with + // that right now, so I'm going to ignore it. + + // We prefix the string with its byte length, so that + // decoding doesn't have to worry about finding a terminator. + Index length = str.getLength(); + emit(context, length); + context->sb.append(str); + } + + void emitVal( + ManglingContext* context, + Val* val); + + void emitQualifiedName( + ManglingContext* context, + DeclRef declRef); + + void emitSimpleIntVal( + ManglingContext* context, + Val* val) + { + if( auto constVal = as(val) ) + { + auto cVal = constVal->value; + if(cVal >= 0 && cVal <= 9 ) + { + emit(context, (UInt)cVal); + return; + } + } + + // Fallback: + emitVal(context, val); + } + + void emitBaseType( + ManglingContext* context, + BaseType baseType) + { + switch( baseType ) + { + case BaseType::Void: emitRaw(context, "V"); break; + case BaseType::Bool: emitRaw(context, "b"); break; + case BaseType::Int8: emitRaw(context, "c"); break; + case BaseType::Int16: emitRaw(context, "s"); break; + case BaseType::Int: emitRaw(context, "i"); break; + case BaseType::Int64: emitRaw(context, "I"); break; + case BaseType::UInt8: emitRaw(context, "C"); break; + case BaseType::UInt16: emitRaw(context, "S"); break; + case BaseType::UInt: emitRaw(context, "u"); break; + case BaseType::UInt64: emitRaw(context, "U"); break; + case BaseType::Half: emitRaw(context, "h"); break; + case BaseType::Float: emitRaw(context, "f"); break; + case BaseType::Double: emitRaw(context, "d"); break; + break; + + default: + SLANG_UNEXPECTED("unimplemented case in mangling"); + break; + } + } + + void emitType( + ManglingContext* context, + Type* type) + { + // TODO: actually implement this bit... + + if( auto basicType = dynamicCast(type) ) + { + emitBaseType(context, basicType->baseType); + } + else if( auto vecType = dynamicCast(type) ) + { + emitRaw(context, "v"); + emitSimpleIntVal(context, vecType->elementCount); + emitType(context, vecType->elementType); + } + else if( auto matType = dynamicCast(type) ) + { + emitRaw(context, "m"); + emitSimpleIntVal(context, matType->getRowCount()); + emitRaw(context, "x"); + emitSimpleIntVal(context, matType->getColumnCount()); + emitType(context, matType->getElementType()); + } + else if( auto namedType = dynamicCast(type) ) + { + emitType(context, GetType(namedType->declRef)); + } + else if( auto declRefType = dynamicCast(type) ) + { + emitQualifiedName(context, declRefType->declRef); + } + else if (auto arrType = dynamicCast(type)) + { + emitRaw(context, "a"); + emitSimpleIntVal(context, arrType->ArrayLength); + emitType(context, arrType->baseType); + } + else if( auto taggedUnionType = dynamicCast(type) ) + { + emitRaw(context, "u"); + for( auto caseType : taggedUnionType->caseTypes ) + { + emitType(context, caseType); + } + emitRaw(context, "U"); + } + else + { + SLANG_UNEXPECTED("unimplemented case in mangling"); + } + } + + void emitVal( + ManglingContext* context, + Val* val) + { + if( auto type = dynamicCast(val) ) + { + emitType(context, type); + } + else if( auto witness = dynamicCast(val) ) + { + // We don't emit witnesses as part of a mangled + // name, because the way that the front-end + // arrived at the witness is not important; + // what matters is that the type constraint + // was satisfied. + // + // TODO: make sure we can't get name collisions + // between specializations of declarations + // with the same numbers of generic parameters, + // but different constraints. We might have + // to mangle in the constraints even when + // the whole thing is specialized... + } + else if( auto genericParamIntVal = dynamicCast(val) ) + { + // TODO: we shouldn't be including the names of generic parameters + // anywhere in mangled names, since changing parameter names + // shouldn't break binary compatibility. + // + // The right solution in the long term is for generic parameters + // (both types and values) to be mangled in terms of their + // "depth" (how many outer generics) and "index" (which + // parameter are they at the specified depth). + emitRaw(context, "K"); + emitName(context, genericParamIntVal->declRef.GetName()); + } + else if( auto constantIntVal = dynamicCast(val) ) + { + // TODO: need to figure out what prefix/suffix is needed + // to allow demangling later. + emitRaw(context, "k"); + emit(context, (UInt) constantIntVal->value); + } + else + { + SLANG_UNEXPECTED("unimplemented case in mangling"); + } + } + + void emitQualifiedName( + ManglingContext* context, + DeclRef declRef) + { + auto parentDeclRef = declRef.GetParent(); + auto parentGenericDeclRef = parentDeclRef.as(); + if( parentDeclRef ) + { + // In certain cases we want to skip emitting the parent + if(parentGenericDeclRef && (parentGenericDeclRef.getDecl()->inner.Ptr() != declRef.getDecl())) + { + } + else if(parentDeclRef.as()) + { + } + else + { + emitQualifiedName(context, parentDeclRef); + } + } + + // A generic declaration is kind of a pseudo-declaration + // as far as the user is concerned; so we don't want + // to emit its name. + if(auto genericDeclRef = declRef.as()) + { + return; + } + + // Inheritance declarations don't have meaningful names, + // and so we should emit them based on the type + // that is doing the inheriting. + if(auto inheritanceDeclRef = declRef.as()) + { + emit(context, "I"); + emitType(context, GetSup(inheritanceDeclRef)); + return; + } + + // Similarly, an extension doesn't have a name worth + // emitting, and we should base things on its target + // type instead. + if(auto extensionDeclRef = declRef.as()) + { + // TODO: as a special case, an "unconditional" extension + // that is in the same module as the type it extends should + // be treated as equivalent to the type itself. + emit(context, "X"); + emitType(context, GetTargetType(extensionDeclRef)); + return; + } + + emitName(context, declRef.GetName()); + + // Special case: accessors need some way to distinguish themselves + // so that a getter/setter/ref-er don't all compile to the same name. + { + if (declRef.is()) emitRaw(context, "Ag"); + if (declRef.is()) emitRaw(context, "As"); + if (declRef.is()) emitRaw(context, "Ar"); + } + + // Are we the "inner" declaration beneath a generic decl? + if(parentGenericDeclRef && (parentGenericDeclRef.getDecl()->inner.Ptr() == declRef.getDecl())) + { + // There are two cases here: either we have specializations + // in place for the parent generic declaration, or we don't. + + auto subst = findInnerMostGenericSubstitution(declRef.substitutions); + if( subst && subst->genericDecl == parentGenericDeclRef.getDecl() ) + { + // This is the case where we *do* have substitutions. + emitRaw(context, "G"); + UInt genericArgCount = subst->args.getCount(); + emit(context, genericArgCount); + for( auto aa : subst->args ) + { + emitVal(context, aa); + } + } + else + { + // We don't have substitutions, so we will emit + // information about the parameters of the generic here. + emitRaw(context, "g"); + UInt genericParameterCount = 0; + for( auto mm : getMembers(parentGenericDeclRef) ) + { + if(mm.is()) + { + genericParameterCount++; + } + else if(mm.is()) + { + genericParameterCount++; + } + else if(mm.is()) + { + genericParameterCount++; + } + else + { + } + } + + emit(context, genericParameterCount); + for( auto mm : getMembers(parentGenericDeclRef) ) + { + if(auto genericTypeParamDecl = mm.as()) + { + emitRaw(context, "T"); + } + else if(auto genericValueParamDecl = mm.as()) + { + emitRaw(context, "v"); + emitType(context, GetType(genericValueParamDecl)); + } + else if(mm.as()) + { + emitRaw(context, "C"); + // TODO: actually emit info about the constraint + } + else + { + } + } + } + } + + // If the declaration has parameters, then we need to emit + // those parameters to distinguish it from other declarations + // of the same name that might have different parameters. + // + // We'll also go ahead and emit the result type as well, + // just for completeness. + // + if( auto callableDeclRef = declRef.as()) + { + auto parameters = GetParameters(callableDeclRef); + UInt parameterCount = parameters.Count(); + + emitRaw(context, "p"); + emit(context, parameterCount); + emitRaw(context, "p"); + + for(auto paramDeclRef : parameters) + { + emitType(context, GetType(paramDeclRef)); + } + + // Don't print result type for an initializer/constructor, + // since it is implicit in the qualified name. + if (!callableDeclRef.is()) + { + emitType(context, GetResultType(callableDeclRef)); + } + } + } + + void mangleName( + ManglingContext* context, + DeclRef declRef) + { + // TODO: catch cases where the declaration should + // forward to something else? E.g., what if we + // are asked to mangle the name of a `typedef`? + + // We will start with a unique prefix to avoid + // clashes with user-defined symbols: + emitRaw(context, "_S"); + + auto decl = declRef.getDecl(); + + // Next we will add a bit of info to register + // the *kind* of declaration we are dealing with. + // + // Functions will get no prefix, since we assume + // they are a common case: + if(as(decl)) + {} + // Types will get a `T` prefix: + else if(as(decl)) + emitRaw(context, "T"); + else if(as(decl)) + emitRaw(context, "T"); + // Variables will get a `V` prefix: + // + // TODO: probably need to pull constant-buffer + // declarations out of this... + else if(as(decl)) + emitRaw(context, "V"); + else + { + // TODO: handle other cases + } + + // Now we encode the qualified name of the decl. + emitQualifiedName(context, declRef); + } + + String getMangledName(DeclRef const& declRef) + { + ManglingContext context; + mangleName(&context, declRef); + return context.sb.ProduceString(); + } + + String getMangledName(DeclRefBase const & declRef) + { + return getMangledName( + DeclRef(declRef.decl, declRef.substitutions)); + } + + String getMangledName(Decl* decl) + { + return getMangledName(makeDeclRef(decl)); + } + + String getMangledNameForConformanceWitness( + DeclRef sub, + DeclRef sup) + { + ManglingContext context; + emitRaw(&context, "_SW"); + emitQualifiedName(&context, sub); + emitQualifiedName(&context, sup); + return context.sb.ProduceString(); + } + + String getMangledNameForConformanceWitness( + DeclRef sub, + Type* sup) + { + // The mangled form for a witness that `sub` + // conforms to `sup` will be named: + // + // {Conforms(sub,sup)} => _SW{sub}{sup} + // + ManglingContext context; + emitRaw(&context, "_SW"); + emitQualifiedName(&context, sub); + emitType(&context, sup); + return context.sb.ProduceString(); + } + + String getMangledNameForConformanceWitness( + Type* sub, + Type* sup) + { + // The mangled form for a witness that `sub` + // conforms to `sup` will be named: + // + // {Conforms(sub,sup)} => _SW{sub}{sup} + // + ManglingContext context; + emitRaw(&context, "_SW"); + emitType(&context, sub); + emitType(&context, sup); + return context.sb.ProduceString(); + } + + String getMangledTypeName(Type* type) + { + ManglingContext context; + emitType(&context, type); + return context.sb.ProduceString(); + } + + +} diff --git a/source/slang/slang-mangle.h b/source/slang/slang-mangle.h new file mode 100644 index 000000000..5e03f8228 --- /dev/null +++ b/source/slang/slang-mangle.h @@ -0,0 +1,29 @@ +#ifndef SLANG_MANGLE_H_INCLUDED +#define SLANG_MANGLE_H_INCLUDED + +// This file implements the name mangling scheme for the Slang language. + +#include "../core/slang-basic.h" +#include "slang-syntax.h" + +namespace Slang +{ + struct IRSpecialize; + + String getMangledName(Decl* decl); + String getMangledName(DeclRef const & declRef); + String getMangledName(DeclRefBase const & declRef); + + String getMangledNameForConformanceWitness( + Type* sub, + Type* sup); + String getMangledNameForConformanceWitness( + DeclRef sub, + DeclRef sup); + String getMangledNameForConformanceWitness( + DeclRef sub, + Type* sup); + String getMangledTypeName(Type* type); +} + +#endif diff --git a/source/slang/slang-mangled-lexer.h b/source/slang/slang-mangled-lexer.h index 4890ae80f..8ec86c982 100644 --- a/source/slang/slang-mangled-lexer.h +++ b/source/slang/slang-mangled-lexer.h @@ -2,9 +2,9 @@ #ifndef SLANG_MANGLED_LEXER_H_INCLUDED #define SLANG_MANGLED_LEXER_H_INCLUDED -#include "../core/basic.h" +#include "../core/slang-basic.h" -#include "compiler.h" +#include "slang-compiler.h" namespace Slang { diff --git a/source/slang/slang-modifier-defs.h b/source/slang/slang-modifier-defs.h new file mode 100644 index 000000000..2af1f0f8f --- /dev/null +++ b/source/slang/slang-modifier-defs.h @@ -0,0 +1,463 @@ +// slang-modifier-defs.h + +// Syntax class definitions for modifiers. + +// Simple modifiers have no state beyond their identity +#define SIMPLE_MODIFIER(NAME) \ + SIMPLE_SYNTAX_CLASS(NAME##Modifier, Modifier) + +SIMPLE_MODIFIER(In); +SIMPLE_MODIFIER(Out); +SIMPLE_MODIFIER(Const); +SIMPLE_MODIFIER(Instance); +SIMPLE_MODIFIER(Builtin); +SIMPLE_MODIFIER(Inline); +SIMPLE_MODIFIER(Public); +SIMPLE_MODIFIER(Require); +SIMPLE_MODIFIER(Param); +SIMPLE_MODIFIER(Extern); +SIMPLE_MODIFIER(Input); +SIMPLE_MODIFIER(Transparent); +SIMPLE_MODIFIER(FromStdLib); +SIMPLE_MODIFIER(Prefix); +SIMPLE_MODIFIER(Postfix); +SIMPLE_MODIFIER(Exported); +SIMPLE_MODIFIER(ConstExpr); +SIMPLE_MODIFIER(GloballyCoherent) + +#undef SIMPLE_MODIFIER + +// A modifier that marks something as an operation that +// has a one-to-one translation to the IR, and thus +// has no direct definition in the high-level language. +// +SYNTAX_CLASS(IntrinsicOpModifier, Modifier) + + // token that names the intrinsic op + FIELD(Token, opToken) + + // The opcode for the intrinsic operation + FIELD_INIT(IROp, op, kIROp_Nop) +END_SYNTAX_CLASS() + +// A modifier that marks something as an intrinsic function, +// for some subset of targets. +SYNTAX_CLASS(TargetIntrinsicModifier, Modifier) + // Token that names the target that the operation + // is an intrisic for. + FIELD(Token, targetToken) + + // A custom definition for the operation + FIELD(Token, definitionToken) +END_SYNTAX_CLASS() + +// A modifier that marks a declaration as representing a +// specialization that should be preferred on a particular +// target. +SYNTAX_CLASS(SpecializedForTargetModifier, Modifier) + // Token that names the target that the operation + // has been specialized for. + FIELD(Token, targetToken) +END_SYNTAX_CLASS() + +// A modifier to tag something as an intrinsic that requires +// a certain GLSL extension to be enabled when used +SYNTAX_CLASS(RequiredGLSLExtensionModifier, Modifier) +FIELD(Token, extensionNameToken) +END_SYNTAX_CLASS() + +// A modifier to tag something as an intrinsic that requires +// a certain GLSL version to be enabled when used +SYNTAX_CLASS(RequiredGLSLVersionModifier, Modifier) +FIELD(Token, versionNumberToken) +END_SYNTAX_CLASS() + + +SIMPLE_SYNTAX_CLASS(InOutModifier, OutModifier) + +// `__ref` modifier for by-reference parameter passing +SIMPLE_SYNTAX_CLASS(RefModifier, Modifier) + +// This is a special sentinel modifier that gets added +// to the list when we have multiple variable declarations +// all sharing the same modifiers: +// +// static uniform int a : FOO, *b : register(x0); +// +// In this case both `a` and `b` share the syntax +// for part of their modifier list, but then have +// their own modifiers as well: +// +// a: SemanticModifier("FOO") --> SharedModifiers --> StaticModifier --> UniformModifier +// / +// b: RegisterModifier("x0") / +// +SIMPLE_SYNTAX_CLASS(SharedModifiers, Modifier) + +// A GLSL `layout` modifier +// +// We use a distinct modifier for each key that +// appears within the `layout(...)` construct, +// and each key might have an optional value token. +// +// TODO: We probably want a notion of "modifier groups" +// so that we can recover good source location info +// for modifiers that were part of the same vs. +// different constructs. +ABSTRACT_SYNTAX_CLASS(GLSLLayoutModifier, Modifier) + +// The token used to introduce the modifier is stored +// as the `nameToken` field. + +// TODO: may want to accept a full expression here +FIELD(Token, valToken) +END_SYNTAX_CLASS() + +// AST nodes to represent the begin/end of a `layout` modifier group +ABSTRACT_SYNTAX_CLASS(GLSLLayoutModifierGroupMarker, Modifier) +END_SYNTAX_CLASS() +SIMPLE_SYNTAX_CLASS(GLSLLayoutModifierGroupBegin, GLSLLayoutModifierGroupMarker) +SIMPLE_SYNTAX_CLASS(GLSLLayoutModifierGroupEnd, GLSLLayoutModifierGroupMarker) + +// We divide GLSL `layout` modifiers into those we have parsed +// (in the sense of having some notion of their semantics), and +// those we have not. +ABSTRACT_SYNTAX_CLASS(GLSLParsedLayoutModifier , GLSLLayoutModifier) +END_SYNTAX_CLASS() + +SIMPLE_SYNTAX_CLASS(GLSLUnparsedLayoutModifier , GLSLLayoutModifier) + +// Specific cases for known GLSL `layout` modifiers that we need to work with +SIMPLE_SYNTAX_CLASS(GLSLConstantIDLayoutModifier , GLSLParsedLayoutModifier) +SIMPLE_SYNTAX_CLASS(GLSLLocationLayoutModifier , GLSLParsedLayoutModifier) + +SIMPLE_SYNTAX_CLASS(GLSLLocalSizeLayoutModifier, GLSLUnparsedLayoutModifier) +SIMPLE_SYNTAX_CLASS(GLSLLocalSizeXLayoutModifier, GLSLLocalSizeLayoutModifier) +SIMPLE_SYNTAX_CLASS(GLSLLocalSizeYLayoutModifier, GLSLLocalSizeLayoutModifier) +SIMPLE_SYNTAX_CLASS(GLSLLocalSizeZLayoutModifier, GLSLLocalSizeLayoutModifier) + +// A catch-all for single-keyword modifiers +SIMPLE_SYNTAX_CLASS(SimpleModifier, Modifier) + +// Some GLSL-specific modifiers +SIMPLE_SYNTAX_CLASS(GLSLBufferModifier , SimpleModifier) +SIMPLE_SYNTAX_CLASS(GLSLWriteOnlyModifier, SimpleModifier) +SIMPLE_SYNTAX_CLASS(GLSLReadOnlyModifier , SimpleModifier) +SIMPLE_SYNTAX_CLASS(GLSLPatchModifier , SimpleModifier) + +// Indicates that this is a variable declaration that corresponds to +// a parameter block declaration in the source program. +SIMPLE_SYNTAX_CLASS(ImplicitParameterGroupVariableModifier , Modifier) + +// Indicates that this is a type that corresponds to the element +// type of a parameter block declaration in the source program. +SIMPLE_SYNTAX_CLASS(ImplicitParameterGroupElementTypeModifier, Modifier) + +// An HLSL semantic +ABSTRACT_SYNTAX_CLASS(HLSLSemantic, Modifier) + FIELD(Token, name) +END_SYNTAX_CLASS() + +// An HLSL semantic that affects layout +SYNTAX_CLASS(HLSLLayoutSemantic, HLSLSemantic) + + FIELD(Token, registerName) + FIELD(Token, componentMask) +END_SYNTAX_CLASS() + +// An HLSL `register` semantic +SYNTAX_CLASS(HLSLRegisterSemantic, HLSLLayoutSemantic) + FIELD(Token, spaceName) +END_SYNTAX_CLASS() + +// TODO(tfoley): `packoffset` +SIMPLE_SYNTAX_CLASS(HLSLPackOffsetSemantic, HLSLLayoutSemantic) + +// An HLSL semantic that just associated a declaration with a semantic name +SIMPLE_SYNTAX_CLASS(HLSLSimpleSemantic, HLSLSemantic) + +// GLSL + +// Directives that came in via the preprocessor, but +// that we need to keep around for later steps +SIMPLE_SYNTAX_CLASS(GLSLPreprocessorDirective, Modifier) + +// A GLSL `#version` directive +SYNTAX_CLASS(GLSLVersionDirective, GLSLPreprocessorDirective) + + // Token giving the version number to use + FIELD(Token, versionNumberToken) + + // Optional token giving the sub-profile to be used + FIELD(Token, glslProfileToken) +END_SYNTAX_CLASS() + +// A GLSL `#extension` directive +SYNTAX_CLASS(GLSLExtensionDirective, GLSLPreprocessorDirective) + + // Token giving the version number to use + FIELD(Token, extensionNameToken) + + // Optional token giving the sub-profile to be used + FIELD(Token, dispositionToken) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(ParameterGroupReflectionName, Modifier) + FIELD(NameLoc, nameAndLoc) +END_SYNTAX_CLASS() + +// A modifier that indicates a built-in base type (e.g., `float`) +SYNTAX_CLASS(BuiltinTypeModifier, Modifier) + FIELD(BaseType, tag) +END_SYNTAX_CLASS() + +// A modifier that indicates a built-in type that isn't a base type (e.g., `vector`) +// +// TODO(tfoley): This deserves a better name than "magic" +SYNTAX_CLASS(MagicTypeModifier, Modifier) + FIELD(String, name) + FIELD(uint32_t, tag) +END_SYNTAX_CLASS() + +// A modifier applied to declarations of builtin types to indicate how they +// should be lowered to the IR. +// +// TODO: This should really subsume `BuiltinTypeModifier` and +// `MagicTypeModifier` so that we don't have to apply all of them. +SYNTAX_CLASS(IntrinsicTypeModifier, Modifier) + // The IR opcode to use when constructing a type + FIELD(uint32_t, irOp) + + // Additional literal opreands to provide when creating instances. + // (e.g., for a texture type this passes in shape/mutability info) + FIELD(List, irOperands) +END_SYNTAX_CLASS() + +// Modifiers that affect the storage layout for matrices +SIMPLE_SYNTAX_CLASS(MatrixLayoutModifier, Modifier) + +// Modifiers that specify row- and column-major layout, respectively +SIMPLE_SYNTAX_CLASS(RowMajorLayoutModifier, MatrixLayoutModifier) +SIMPLE_SYNTAX_CLASS(ColumnMajorLayoutModifier, MatrixLayoutModifier) + +// The HLSL flavor of those modifiers +SIMPLE_SYNTAX_CLASS(HLSLRowMajorLayoutModifier, RowMajorLayoutModifier) +SIMPLE_SYNTAX_CLASS(HLSLColumnMajorLayoutModifier, ColumnMajorLayoutModifier) + +// The GLSL flavor of those modifiers +// +// Note(tfoley): The GLSL versions of these modifiers are "backwards" +// in the sense that when a GLSL programmer requests row-major layout, +// we actually interpret that as requesting column-major. This makes +// sense because we interpret matrix conventions backwards from how +// GLSL specifies them. +SIMPLE_SYNTAX_CLASS(GLSLRowMajorLayoutModifier, ColumnMajorLayoutModifier) +SIMPLE_SYNTAX_CLASS(GLSLColumnMajorLayoutModifier, RowMajorLayoutModifier) + +// More HLSL Keyword + +ABSTRACT_SYNTAX_CLASS(InterpolationModeModifier, Modifier) +END_SYNTAX_CLASS() + +// HLSL `nointerpolation` modifier +SIMPLE_SYNTAX_CLASS(HLSLNoInterpolationModifier, InterpolationModeModifier) + +// HLSL `noperspective` modifier +SIMPLE_SYNTAX_CLASS(HLSLNoPerspectiveModifier, InterpolationModeModifier) + +// HLSL `linear` modifier +SIMPLE_SYNTAX_CLASS(HLSLLinearModifier, InterpolationModeModifier) + +// HLSL `sample` modifier +SIMPLE_SYNTAX_CLASS(HLSLSampleModifier, InterpolationModeModifier) + +// HLSL `centroid` modifier +SIMPLE_SYNTAX_CLASS(HLSLCentroidModifier, InterpolationModeModifier) + +// HLSL `precise` modifier +SIMPLE_SYNTAX_CLASS(PreciseModifier, Modifier) + +// HLSL `shared` modifier (which is used by the effect system, +// and shouldn't be confused with `groupshared`) +SIMPLE_SYNTAX_CLASS(HLSLEffectSharedModifier, Modifier) + +// HLSL `groupshared` modifier +SIMPLE_SYNTAX_CLASS(HLSLGroupSharedModifier, Modifier) + +// HLSL `static` modifier (probably doesn't need to be +// treated as HLSL-specific) +SIMPLE_SYNTAX_CLASS(HLSLStaticModifier, Modifier) + +// HLSL `uniform` modifier (distinct meaning from GLSL +// use of the keyword) +SIMPLE_SYNTAX_CLASS(HLSLUniformModifier, Modifier) + +// HLSL `volatile` modifier (ignored) +SIMPLE_SYNTAX_CLASS(HLSLVolatileModifier, Modifier) + +SYNTAX_CLASS(AttributeTargetModifier, Modifier) + // A class to which the declared attribute type is applicable + FIELD(SyntaxClass, syntaxClass) +END_SYNTAX_CLASS() + +// Base class for checked and unchecked `[name(arg0, ...)]` style attribute. +SYNTAX_CLASS(AttributeBase, Modifier) + SYNTAX_FIELD(List>, args) +END_SYNTAX_CLASS() + +// A `[name(...)]` attribute that hasn't undergone any semantic analysis. +// After analysis, this will be transformed into a more specific case. +SYNTAX_CLASS(UncheckedAttribute, AttributeBase) + FIELD(RefPtr, scope) +END_SYNTAX_CLASS() + +// A `[name(arg0, ...)]` style attribute that has been validated. +SYNTAX_CLASS(Attribute, AttributeBase) + FIELD(AttributeArgumentValueDict, intArgVals) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(UserDefinedAttribute, Attribute) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(AttributeUsageAttribute, Attribute) + FIELD(SyntaxClass, targetSyntaxClass) +END_SYNTAX_CLASS() + +// An `[unroll]` or `[unroll(count)]` attribute +SYNTAX_CLASS(UnrollAttribute, Attribute) + RAW(IntegerLiteralValue getCount();) +END_SYNTAX_CLASS() + +SIMPLE_SYNTAX_CLASS(LoopAttribute, Attribute) // `[loop]` +SIMPLE_SYNTAX_CLASS(FastOptAttribute, Attribute) // `[fastopt]` +SIMPLE_SYNTAX_CLASS(AllowUAVConditionAttribute, Attribute) // `[allow_uav_condition]` +SIMPLE_SYNTAX_CLASS(BranchAttribute, Attribute) // `[branch]` +SIMPLE_SYNTAX_CLASS(FlattenAttribute, Attribute) // `[flatten]` +SIMPLE_SYNTAX_CLASS(ForceCaseAttribute, Attribute) // `[forcecase]` +SIMPLE_SYNTAX_CLASS(CallAttribute, Attribute) // `[call]` + + +// [[vk_push_constant]] [[push_constant]] +SIMPLE_SYNTAX_CLASS(PushConstantAttribute, Attribute) + +// [[vk_shader_record]] [[shader_record]] +SIMPLE_SYNTAX_CLASS(ShaderRecordAttribute, Attribute) + +// [[vk_binding]] +SYNTAX_CLASS(GLSLBindingAttribute, Attribute) + FIELD(int32_t, binding = 0) + FIELD(int32_t, set = 0) +END_SYNTAX_CLASS() + +// TODO: for attributes that take arguments, the syntax node +// classes should provide accessors for the values of those arguments. + +SIMPLE_SYNTAX_CLASS(MaxTessFactorAttribute, Attribute) +SIMPLE_SYNTAX_CLASS(OutputControlPointsAttribute, Attribute) +SIMPLE_SYNTAX_CLASS(OutputTopologyAttribute, Attribute) +SIMPLE_SYNTAX_CLASS(PartitioningAttribute, Attribute) +SYNTAX_CLASS(PatchConstantFuncAttribute, Attribute) + FIELD(RefPtr, patchConstantFuncDecl) +END_SYNTAX_CLASS() +SIMPLE_SYNTAX_CLASS(DomainAttribute, Attribute) + +SIMPLE_SYNTAX_CLASS(EarlyDepthStencilAttribute, Attribute) // `[earlydepthstencil]` + +// An HLSL `[numthreads(x,y,z)]` attribute +SYNTAX_CLASS(NumThreadsAttribute, Attribute) + // The number of threads to use along each axis + // + // TODO: These should be accessors that use the + // ordinary `args` list, rather than side data. + FIELD(int32_t, x) + FIELD(int32_t, y) + FIELD(int32_t, z) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(MaxVertexCountAttribute, Attribute) + // The number of max vertex count for geometry shader + // + // TODO: This should be an accessor that uses the + // ordinary `args` list, rather than side data. + FIELD(int32_t, value) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(InstanceAttribute, Attribute) + // The number of instances to run for geometry shader + // + // TODO: This should be an accessor that uses the + // ordinary `args` list, rather than side data. + FIELD(int32_t, value) +END_SYNTAX_CLASS() + +// A `[shader("stageName")]` attribute, which marks an entry point +// to be compiled, and specifies the stage for that entry point +SYNTAX_CLASS(EntryPointAttribute, Attribute) + // The resolved stage that the entry point is targetting. + // + // TODO: This should be an accessor that uses the + // ordinary `args` list, rather than side data. + FIELD(Stage, stage); +END_SYNTAX_CLASS() + +// A `[__vulkanRayPayload]` attribute, which is used in the +// standard library implementation to indicate that a variable +// actually represents the input/output interface for a Vulkan +// ray tracing shader to pass per-ray payload information. +SIMPLE_SYNTAX_CLASS(VulkanRayPayloadAttribute, Attribute) + +// A `[__vulkanCallablePayload]` attribute, which is used in the +// standard library implementation to indicate that a variable +// actually represents the input/output interface for a Vulkan +// ray tracing shader to pass payload information to/from a callee. +SIMPLE_SYNTAX_CLASS(VulkanCallablePayloadAttribute, Attribute) + +// A `[__vulkanHitAttributes]` attribute, which is used in the +// standard library implementation to indicate that a variable +// actually represents the output interface for a Vulkan +// intersection shader to pass hit attribute information. +SIMPLE_SYNTAX_CLASS(VulkanHitAttributesAttribute, Attribute) + +// A `[mutating]` attribute, which indicates that a member +// function is allowed to modify things through its `this` +// argument. +// +SIMPLE_SYNTAX_CLASS(MutatingAttribute, Attribute) + +// A `[__readNone]` attribute, which indicates that a function +// computes its results strictly based on argument values, without +// reading or writing through any pointer arguments, or any other +// state that could be observed by a caller. +// +SIMPLE_SYNTAX_CLASS(ReadNoneAttribute, Attribute) + + +// HLSL modifiers for geometry shader input topology +SIMPLE_SYNTAX_CLASS(HLSLGeometryShaderInputPrimitiveTypeModifier, Modifier) +SIMPLE_SYNTAX_CLASS(HLSLPointModifier , HLSLGeometryShaderInputPrimitiveTypeModifier) +SIMPLE_SYNTAX_CLASS(HLSLLineModifier , HLSLGeometryShaderInputPrimitiveTypeModifier) +SIMPLE_SYNTAX_CLASS(HLSLTriangleModifier , HLSLGeometryShaderInputPrimitiveTypeModifier) +SIMPLE_SYNTAX_CLASS(HLSLLineAdjModifier , HLSLGeometryShaderInputPrimitiveTypeModifier) +SIMPLE_SYNTAX_CLASS(HLSLTriangleAdjModifier , HLSLGeometryShaderInputPrimitiveTypeModifier) + +// A modifier to be attached to syntax after we've computed layout +SYNTAX_CLASS(ComputedLayoutModifier, Modifier) + FIELD(RefPtr, layout) +END_SYNTAX_CLASS() + + +SYNTAX_CLASS(TupleVarModifier, Modifier) +// FIELD_INIT(TupleFieldModifier*, tupleField, nullptr) +END_SYNTAX_CLASS() + +// A modifier to indicate that a constructor/initializer can be used +// to perform implicit type conversion, and to specify the cost of +// the conversion, if applied. +SYNTAX_CLASS(ImplicitConversionModifier, Modifier) + // The conversion cost, used to rank conversions + FIELD(ConversionCost, cost) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(FormatAttribute, Attribute) + FIELD(ImageFormat, format) +END_SYNTAX_CLASS() diff --git a/source/slang/slang-name.cpp b/source/slang/slang-name.cpp new file mode 100644 index 000000000..8934b17bd --- /dev/null +++ b/source/slang/slang-name.cpp @@ -0,0 +1,37 @@ +// slang-name.cpp +#include "slang-name.h" + +namespace Slang { + +String getText(Name* name) +{ + if (!name) return String(); + return name->text; +} + +UnownedStringSlice getUnownedStringSliceText(Name* name) +{ + return name ? name->text.getUnownedSlice() : UnownedStringSlice(); +} + +Name* NamePool::getName(String const& text) +{ + RefPtr name; + if (rootPool->names.TryGetValue(text, name)) + return name; + + name = new Name(); + name->text = text; + rootPool->names.Add(text, name); + return name; +} + +Name* NamePool::tryGetName(String const& text) +{ + RefPtr name; + if (rootPool->names.TryGetValue(text, name)) + return name; + return nullptr; +} + +} // namespace Slang diff --git a/source/slang/slang-name.h b/source/slang/slang-name.h new file mode 100644 index 000000000..de04d5fdf --- /dev/null +++ b/source/slang/slang-name.h @@ -0,0 +1,86 @@ +// slang-name.h +#ifndef SLANG_NAME_H_INCLUDED +#define SLANG_NAME_H_INCLUDED + +// This file defines the `Name` type, used to represent +// the name of types, variables, etc. in the AST. + +#include "../core/slang-basic.h" + +namespace Slang { + +// The `Name` type is used to represent the name of a type, variable, etc. +// +// The key benefit of using `Name`s instead of raw strings is that `Name`s +// can be compared for equality just by testing pointer equality. Names +// also don't require any memory management; you can just retain an ordinary +// pointer to one and not deal with reference-counting overhead. +// +// In order to provide these benefits, a `Name` can only be created using +// a `NamePool` that owns the allocations for all the names (so they get +// cleaned up when the pool is deleted), and which is responsible for +// ensuring the uniqueness of name objects. +// +class Name : public RefObject +{ +public: + // The raw text of the name. + // + // Note that at some point in the future we might have other categories + // of name than "simple" names, and so this might change to a structured + // ADT instead of a simple string. + String text; +}; + +// Get the textual string representation of a name +// (e.g., so that it can be printed). +String getText(Name* name); + +/// Get the text as unowned string slice +UnownedStringSlice getUnownedStringSliceText(Name* name); + +// A `RootNamePool` is used to store and look up names. +// If two systems need to work together with names, and be sure that they +// get equivalent names for a string like `"Foo"`, then they need to use +// the same root name pool (directly or indirectly). +// +struct RootNamePool +{ + // The mapping from text strings to the corresponding name. + Dictionary > names; +}; + +// A `NamePool` is effectively a way of storing a subset of the +// names that have been created through a `RootNamePool`. +// +// The intention is that eventually we will add the ability to clean +// up a `NamePool`, and remove the names it created from the corresponding +// `RootNamePool` *if* those names are no longer in use. +// +// The goal of such an approach would be to ensure that the memory +// usage of a `Session` can't bloat over time just because of multiple +// `CompileRequest`s being created, used, and then destroyed (each time +// adding just a few more strings to the name mapping). +// +struct NamePool +{ + // Find or create the `Name` that represents the given `text`. + Name* getName(String const& text); + // Try find the `Name` that represents the given `text`. + // If the name does not exist, return nullptr + Name* tryGetName(String const& text); + // Set the parent name pool to use for lookup + void setRootNamePool(RootNamePool* rootNamePool) + { + this->rootPool = rootNamePool; + } + + // + + // The root name pool to use for storage/lookup + RootNamePool* rootPool = nullptr; +}; + +} // namespace Slang + +#endif diff --git a/source/slang/slang-object-meta-begin.h b/source/slang/slang-object-meta-begin.h new file mode 100644 index 000000000..9c09e845d --- /dev/null +++ b/source/slang/slang-object-meta-begin.h @@ -0,0 +1,43 @@ +// slang-object-meta-begin.h + +#ifndef SYNTAX_CLASS +#error The 'SYNTAX_CLASS' macro should be defined before including 'object-meta-begin.h' +#endif + +#ifndef ABSTRACT_SYNTAX_CLASS +#define ABSTRACT_SYNTAX_CLASS(NAME, BASE) SYNTAX_CLASS(NAME, BASE) +#endif + +#ifndef END_SYNTAX_CLASS +#define END_SYNTAX_CLASS() /* empty */ +#endif + +#ifndef DECL_FIELD +#define DECL_FIELD(TYPE, NAME) SYNTAX_FIELD(TYPE, NAME) +#endif + +#ifndef SYNTAX_FIELD +#define SYNTAX_FIELD(TYPE, NAME) FIELD(TYPE, NAME) +#endif + +#ifndef FIELD_INIT +#define FIELD_INIT(TYPE, NAME, INIT) FIELD(TYPE, NAME) +#endif + +#ifndef FIELD +#define FIELD(...) /* empty */ +#endif + +#ifndef RAW +#define RAW(...) /* empty */ +#endif + +#define SIMPLE_SYNTAX_CLASS(NAME, BASE) SYNTAX_CLASS(NAME, BASE) END_SYNTAX_CLASS() + +// Hack to remove 'warning C4702: unreachable code' on VS2017, blocking compilation +// Note! This is matched in object-meta-end.h +#if _MSC_VER >= 1910 +#pragma warning(push) +#pragma warning(disable: 4702) +#endif + diff --git a/source/slang/slang-object-meta-end.h b/source/slang/slang-object-meta-end.h new file mode 100644 index 000000000..5018b5ede --- /dev/null +++ b/source/slang/slang-object-meta-end.h @@ -0,0 +1,17 @@ +// slang-object-meta-end.h + +#undef SYNTAX_CLASS +#undef ABSTRACT_SYNTAX_CLASS +#undef END_SYNTAX_CLASS +#undef SYNTAX_FIELD +#undef FIELD +#undef FIELD_INIT +#undef DECL_FIELD +#undef RAW +#undef SIMPLE_SYNTAX_CLASS + +// Hack to remove 'warning C4702: unreachable code' on VS2017, blocking compilation +// Note! This is matched in object-meta-begin.h +#if _MSC_VER >= 1910 +#pragma warning(pop) +#endif diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp new file mode 100644 index 000000000..85a2662bc --- /dev/null +++ b/source/slang/slang-options.cpp @@ -0,0 +1,1356 @@ +// slang-options.cpp + +// Implementation of options parsing for `slangc` command line, +// and also for API interface that takes command-line argument strings. + +#include "../../slang.h" + +#include "slang-compiler.h" +#include "slang-profile.h" + +#include + +namespace Slang { + +SlangResult tryReadCommandLineArgumentRaw(DiagnosticSink* sink, char const* option, char const* const**ioCursor, char const* const*end, char const** argOut) +{ + *argOut = nullptr; + char const* const*& cursor = *ioCursor; + if (cursor == end) + { + sink->diagnose(SourceLoc(), Diagnostics::expectedArgumentForOption, option); + return SLANG_FAIL; + } + else + { + *argOut = *cursor++; + return SLANG_OK; + } +} + +SlangResult tryReadCommandLineArgument(DiagnosticSink* sink, char const* option, char const* const**ioCursor, char const* const*end, String& argOut) +{ + const char* arg; + SLANG_RETURN_ON_FAIL(tryReadCommandLineArgumentRaw(sink, option, ioCursor, end, &arg)); + argOut = arg; + return SLANG_OK; +} + +struct OptionsParser +{ + SlangSession* session = nullptr; + SlangCompileRequest* compileRequest = nullptr; + + Slang::EndToEndCompileRequest* requestImpl = nullptr; + + Slang::RefPtr sharedLibraryLoader; + + // A "translation unit" represents one or more source files + // that are processed as a single entity when it comes to + // semantic checking. + // + // For languages like HLSL, GLSL, and C, a translation unit + // is usually a single source file (which can then go on + // to `#include` other files into the same translation unit). + // + // For Slang, we support having multiple source files in + // a single translation unit, and indeed command-line `slangc` + // will always put all the source files into a single translation + // unit. + // + // We track information on the translation units that we + // create during options parsing, so that we can assocaite + // other entities with these translation units: + // + struct RawTranslationUnit + { + // What language is the translation unit using? + // + // Note: We do not support translation units that mix + // languages. + // + SlangSourceLanguage sourceLanguage; + + // Certain naming conventions imply a stage for + // a file with only a single entry point, and in + // those cases we will try to infer the stage from + // the file when it is possible. + // + Stage impliedStage; + + // We retain the Slang API level translation unit index, + // which we will call an "ID" inside the options parsing code. + // + // This will almost always be the index into the + // `rawTranslationUnits` array below, but could conceivably, + // be mismatched if we were parsing options for a compile + // request that already had some translation unit(s) added + // manually. + // + int translationUnitID; + }; + List rawTranslationUnits; + + // If we already have a translation unit for Slang code, then this will give its index. + // If not, it will be `-1`. + int slangTranslationUnitIndex = -1; + + // The number of input files that have been specified + int inputPathCount = 0; + + int translationUnitCount = 0; + int currentTranslationUnitIndex= -1; + + // An entry point represents a function to be checked and possibly have + // code generated in one of our translation units. An entry point + // needs to have an associated stage, which might come via the + // `-stage` command line option, or a `[shader("...")]` attribute + // in the source code. + // + struct RawEntryPoint + { + String name; + Stage stage = Stage::Unknown; + int translationUnitIndex = -1; + int entryPointID = -1; + + // State for tracking command-line errors + bool conflictingStagesSet = false; + bool redundantStageSet = false; + }; + // + // We collect the entry points in a "raw" array so that we can + // possibly associate them with a stage or translation unit + // after the fact. + // + List rawEntryPoints; + + // In the case where we have only a single entry point, + // the entry point and its options might be specified out + // of order, so we will keep a single `RawEntryPoint` around + // and use it as the target for any state-setting options + // before the first "proper" entry point is specified. + RawEntryPoint defaultEntryPoint; + + SlangCompileFlags flags = 0; + + struct RawOutput + { + String path; + CodeGenTarget impliedFormat = CodeGenTarget::Unknown; + int targetIndex = -1; + int entryPointIndex = -1; + }; + List rawOutputs; + + struct RawTarget + { + CodeGenTarget format = CodeGenTarget::Unknown; + ProfileVersion profileVersion = ProfileVersion::Unknown; + SlangTargetFlags targetFlags = 0; + int targetID = -1; + FloatingPointMode floatingPointMode = FloatingPointMode::Default; + + // State for tracking command-line errors + bool conflictingProfilesSet = false; + bool redundantProfileSet = false; + + }; + List rawTargets; + + RawTarget defaultTarget; + + void addSharedLibraryPath(SharedLibraryType libType, const String& path) + { + if (!sharedLibraryLoader) + { + sharedLibraryLoader = new ConfigurableSharedLibraryLoader; + } + sharedLibraryLoader->addEntry(libType, ConfigurableSharedLibraryLoader::changePath, path); + } + + int addTranslationUnit( + SlangSourceLanguage language, + Stage impliedStage) + { + auto translationUnitIndex = rawTranslationUnits.getCount(); + auto translationUnitID = spAddTranslationUnit(compileRequest, language, nullptr); + + // As a sanity check: the API should be returning the same translation + // unit index as we maintain internally. This invariant would only + // be broken if we decide to support a mix of translation units specified + // via API, and ones specified via command-line arguments. + // + SLANG_RELEASE_ASSERT(Index(translationUnitID) == translationUnitIndex); + + RawTranslationUnit rawTranslationUnit; + rawTranslationUnit.sourceLanguage = language; + rawTranslationUnit.translationUnitID = translationUnitID; + rawTranslationUnit.impliedStage = impliedStage; + + rawTranslationUnits.add(rawTranslationUnit); + + return int(translationUnitIndex); + } + + void addInputSlangPath( + String const& path) + { + // All of the input .slang files will be grouped into a single logical translation unit, + // which we create lazily when the first .slang file is encountered. + if( slangTranslationUnitIndex == -1 ) + { + translationUnitCount++; + slangTranslationUnitIndex = addTranslationUnit(SLANG_SOURCE_LANGUAGE_SLANG, Stage::Unknown); + } + + spAddTranslationUnitSourceFile( + compileRequest, + rawTranslationUnits[slangTranslationUnitIndex].translationUnitID, + path.begin()); + + // Set the translation unit to be used by subsequent entry points + currentTranslationUnitIndex = slangTranslationUnitIndex; + } + + void addInputForeignShaderPath( + String const& path, + SlangSourceLanguage language, + Stage impliedStage) + { + translationUnitCount++; + currentTranslationUnitIndex = addTranslationUnit(language, impliedStage); + + spAddTranslationUnitSourceFile( + compileRequest, + rawTranslationUnits[currentTranslationUnitIndex].translationUnitID, + path.begin()); + } + + static Profile::RawVal findGlslProfileFromPath(const String& path) + { + struct Entry + { + const char* ext; + Profile::RawVal profileId; + }; + + static const Entry entries[] = + { + { ".frag", Profile::GLSL_Fragment }, + { ".geom", Profile::GLSL_Geometry }, + { ".tesc", Profile::GLSL_TessControl }, + { ".tese", Profile::GLSL_TessEval }, + { ".comp", Profile::GLSL_Compute } + }; + + for (int i = 0; i < SLANG_COUNT_OF(entries); ++i) + { + const Entry& entry = entries[i]; + if (path.endsWith(entry.ext)) + { + return entry.profileId; + } + } + return Profile::Unknown; + } + + static SlangSourceLanguage findSourceLanguageFromPath(const String& path, Stage& outImpliedStage) + { + struct Entry + { + const char* ext; + SlangSourceLanguage sourceLanguage; + SlangStage impliedStage; + }; + + static const Entry entries[] = + { + { ".slang", SLANG_SOURCE_LANGUAGE_SLANG, SLANG_STAGE_NONE }, + + { ".hlsl", SLANG_SOURCE_LANGUAGE_HLSL, SLANG_STAGE_NONE }, + { ".fx", SLANG_SOURCE_LANGUAGE_HLSL, SLANG_STAGE_NONE }, + + { ".glsl", SLANG_SOURCE_LANGUAGE_GLSL, SLANG_STAGE_NONE }, + { ".vert", SLANG_SOURCE_LANGUAGE_GLSL, SLANG_STAGE_VERTEX }, + { ".frag", SLANG_SOURCE_LANGUAGE_GLSL, SLANG_STAGE_FRAGMENT }, + { ".geom", SLANG_SOURCE_LANGUAGE_GLSL, SLANG_STAGE_GEOMETRY }, + { ".tesc", SLANG_SOURCE_LANGUAGE_GLSL, SLANG_STAGE_HULL }, + { ".tese", SLANG_SOURCE_LANGUAGE_GLSL, SLANG_STAGE_DOMAIN }, + { ".comp", SLANG_SOURCE_LANGUAGE_GLSL, SLANG_STAGE_COMPUTE }, + }; + + for (int i = 0; i < SLANG_COUNT_OF(entries); ++i) + { + const Entry& entry = entries[i]; + if (path.endsWith(entry.ext)) + { + outImpliedStage = Stage(entry.impliedStage); + return entry.sourceLanguage; + } + } + return SLANG_SOURCE_LANGUAGE_UNKNOWN; + } + + SlangResult addInputPath( + char const* inPath) + { + inputPathCount++; + + // look at the extension on the file name to determine + // how we should handle it. + String path = String(inPath); + + if( path.endsWith(".slang") ) + { + // Plain old slang code + addInputSlangPath(path); + return SLANG_OK; + } + + Stage impliedStage = Stage::Unknown; + SlangSourceLanguage sourceLanguage = findSourceLanguageFromPath(path, impliedStage); + + if (sourceLanguage == SLANG_SOURCE_LANGUAGE_UNKNOWN) + { + requestImpl->getSink()->diagnose(SourceLoc(), Diagnostics::cannotDeduceSourceLanguage, inPath); + return SLANG_FAIL; + } + + addInputForeignShaderPath(path, sourceLanguage, impliedStage); + + return SLANG_OK; + } + + void addOutputPath( + String const& path, + CodeGenTarget impliedFormat) + { + RawOutput rawOutput; + rawOutput.path = path; + rawOutput.impliedFormat = impliedFormat; + rawOutputs.add(rawOutput); + } + + void addOutputPath(char const* inPath) + { + String path = String(inPath); + + if (!inPath) {} +#define CASE(EXT, TARGET) \ + else if(path.endsWith(EXT)) do { addOutputPath(path, CodeGenTarget(SLANG_##TARGET)); } while(0) + + CASE(".hlsl", HLSL); + CASE(".fx", HLSL); + + CASE(".dxbc", DXBC); + CASE(".dxbc.asm", DXBC_ASM); + + CASE(".dxil", DXIL); + CASE(".dxil.asm", DXIL_ASM); + + CASE(".glsl", GLSL); + CASE(".vert", GLSL); + CASE(".frag", GLSL); + CASE(".geom", GLSL); + CASE(".tesc", GLSL); + CASE(".tese", GLSL); + CASE(".comp", GLSL); + + CASE(".spv", SPIRV); + CASE(".spv.asm", SPIRV_ASM); + + CASE(".c", C_SOURCE); + CASE(".cpp", CPP_SOURCE); + +#undef CASE + + else if (path.endsWith(".slang-module")) + { + spSetOutputContainerFormat(compileRequest, SLANG_CONTAINER_FORMAT_SLANG_MODULE); + requestImpl->containerOutputPath = path; + } + else + { + // Allow an unknown-format `-o`, assuming we get a target format + // from another argument. + addOutputPath(path, CodeGenTarget::Unknown); + } + } + + RawEntryPoint* getCurrentEntryPoint() + { + auto rawEntryPointCount = rawEntryPoints.getCount(); + return rawEntryPointCount ? &rawEntryPoints[rawEntryPointCount-1] : &defaultEntryPoint; + } + + void setStage(RawEntryPoint* rawEntryPoint, Stage stage) + { + if(rawEntryPoint->stage != Stage::Unknown) + { + rawEntryPoint->redundantStageSet = true; + if( stage != rawEntryPoint->stage ) + { + rawEntryPoint->conflictingStagesSet = true; + } + } + rawEntryPoint->stage = stage; + } + + RawTarget* getCurrentTarget() + { + auto rawTargetCount = rawTargets.getCount(); + return rawTargetCount ? &rawTargets[rawTargetCount-1] : &defaultTarget; + } + + void setProfileVersion(RawTarget* rawTarget, ProfileVersion profileVersion) + { + if(rawTarget->profileVersion != ProfileVersion::Unknown) + { + rawTarget->redundantProfileSet = true; + + if(profileVersion != rawTarget->profileVersion) + { + rawTarget->conflictingProfilesSet = true; + } + } + rawTarget->profileVersion = profileVersion; + } + + void setFloatingPointMode(RawTarget* rawTarget, FloatingPointMode mode) + { + rawTarget->floatingPointMode = mode; + } + + SlangResult parse( + int argc, + char const* const* argv) + { + // Copy some state out of the current request, in case we've been called + // after some other initialization has been performed. + flags = requestImpl->getFrontEndReq()->compileFlags; + + DiagnosticSink* sink = requestImpl->getSink(); + + SlangMatrixLayoutMode defaultMatrixLayoutMode = SLANG_MATRIX_LAYOUT_MODE_UNKNOWN; + + char const* const* argCursor = &argv[0]; + char const* const* argEnd = &argv[argc]; + while (argCursor != argEnd) + { + char const* arg = *argCursor++; + if (arg[0] == '-') + { + String argStr = String(arg); + + if(argStr == "-no-mangle" ) + { + flags |= SLANG_COMPILE_FLAG_NO_MANGLING; + } + else if (argStr == "-no-codegen") + { + flags |= SLANG_COMPILE_FLAG_NO_CODEGEN; + } + else if(argStr == "-dump-ir" ) + { + requestImpl->getFrontEndReq()->shouldDumpIR = true; + requestImpl->getBackEndReq()->shouldDumpIR = true; + } + else if (argStr == "-serial-ir") + { + requestImpl->getFrontEndReq()->useSerialIRBottleneck = true; + } + else if (argStr == "-verbose-paths") + { + requestImpl->getSink()->flags |= DiagnosticSink::Flag::VerbosePath; + } + else if (argStr == "-verify-debug-serial-ir") + { + requestImpl->getFrontEndReq()->verifyDebugSerialization = true; + } + else if(argStr == "-validate-ir" ) + { + requestImpl->getFrontEndReq()->shouldValidateIR = true; + requestImpl->getBackEndReq()->shouldValidateIR = true; + } + else if(argStr == "-skip-codegen" ) + { + requestImpl->shouldSkipCodegen = true; + } + else if(argStr == "-parameter-blocks-use-register-spaces" ) + { + getCurrentTarget()->targetFlags |= SLANG_TARGET_FLAG_PARAMETER_BLOCKS_USE_REGISTER_SPACES; + } + else if (argStr == "-target") + { + String name; + SLANG_RETURN_ON_FAIL(tryReadCommandLineArgument(sink, arg, &argCursor, argEnd, name)); + + SlangCompileTarget format = SLANG_TARGET_UNKNOWN; + + #define CASE(NAME, TARGET) \ + if(name == NAME) { format = SLANG_##TARGET; } else + + CASE("hlsl", HLSL) + CASE("glsl", GLSL) + CASE("dxbc", DXBC) + CASE("dxbc-assembly", DXBC_ASM) + CASE("dxbc-asm", DXBC_ASM) + CASE("spirv", SPIRV) + CASE("spirv-assembly", SPIRV_ASM) + CASE("spirv-asm", SPIRV_ASM) + CASE("dxil", DXIL) + CASE("dxil-assembly", DXIL_ASM) + CASE("dxil-asm", DXIL_ASM) + CASE("c", C_SOURCE) + CASE("cpp", CPP_SOURCE) + + #undef CASE + /* else */ + { + sink->diagnose(SourceLoc(), Diagnostics::unknownCodeGenerationTarget, name); + return SLANG_FAIL; + } + + RawTarget rawTarget; + rawTarget.format = CodeGenTarget(format); + + rawTargets.add(rawTarget); + } + // A "profile" can specify both a general capability level for + // a target, and also (as a legacy/compatibility feature) a + // specific stage to use for an entry point. + else if (argStr == "-profile") + { + String name; + SLANG_RETURN_ON_FAIL(tryReadCommandLineArgument(sink, arg, &argCursor, argEnd, name)); + + SlangProfileID profileID = spFindProfile(session, name.begin()); + if( profileID == SLANG_PROFILE_UNKNOWN ) + { + sink->diagnose(SourceLoc(), Diagnostics::unknownProfile, name); + return SLANG_FAIL; + } + else + { + auto profile = Profile(profileID); + + setProfileVersion(getCurrentTarget(), profile.GetVersion()); + + // A `-profile` option that also specifies a stage (e.g., `-profile vs_5_0`) + // should be treated like a composite (e.g., `-profile sm_5_0 -stage vertex`) + auto stage = profile.GetStage(); + if(stage != Stage::Unknown) + { + setStage(getCurrentEntryPoint(), stage); + } + } + } + else if (argStr == "-stage") + { + String name; + SLANG_RETURN_ON_FAIL(tryReadCommandLineArgument(sink, arg, &argCursor, argEnd, name)); + + Stage stage = findStageByName(name); + if( stage == Stage::Unknown ) + { + sink->diagnose(SourceLoc(), Diagnostics::unknownStage, name); + return SLANG_FAIL; + } + else + { + setStage(getCurrentEntryPoint(), stage); + } + } + else if (argStr == "-entry") + { + String name; + SLANG_RETURN_ON_FAIL(tryReadCommandLineArgument(sink, arg, &argCursor, argEnd, name)); + + RawEntryPoint rawEntryPoint; + rawEntryPoint.name = name; + rawEntryPoint.translationUnitIndex = currentTranslationUnitIndex; + + rawEntryPoints.add(rawEntryPoint); + } + else if (argStr == "-pass-through") + { + String name; + SLANG_RETURN_ON_FAIL(tryReadCommandLineArgument(sink, arg, &argCursor, argEnd, name)); + + SlangPassThrough passThrough = SLANG_PASS_THROUGH_NONE; + if (name == "fxc") { passThrough = SLANG_PASS_THROUGH_FXC; } + else if (name == "dxc") { passThrough = SLANG_PASS_THROUGH_DXC; } + else if (name == "glslang") { passThrough = SLANG_PASS_THROUGH_GLSLANG; } + else + { + sink->diagnose(SourceLoc(), Diagnostics::unknownPassThroughTarget, name); + return SLANG_FAIL; + } + + spSetPassThrough( + compileRequest, + passThrough); + } + else if (argStr == "-dxc-path") + { + String name; + SLANG_RETURN_ON_FAIL(tryReadCommandLineArgument(sink, arg, &argCursor, argEnd, name)); + addSharedLibraryPath(SharedLibraryType::Dxc, name); + addSharedLibraryPath(SharedLibraryType::Dxil, name); + } + else if (argStr == "-glslang-path") + { + String name; + SLANG_RETURN_ON_FAIL(tryReadCommandLineArgument(sink, arg, &argCursor, argEnd, name)); + addSharedLibraryPath(SharedLibraryType::Glslang, name); + } + else if (argStr == "-fxc-path") + { + String name; + SLANG_RETURN_ON_FAIL(tryReadCommandLineArgument(sink, arg, &argCursor, argEnd, name)); + addSharedLibraryPath(SharedLibraryType::Fxc, name); + } + else if (argStr[1] == 'D') + { + // The value to be defined might be part of the same option, as in: + // -DFOO + // or it might come separately, as in: + // -D FOO + char const* defineStr = arg + 2; + if (defineStr[0] == 0) + { + // Need to read another argument from the command line + SLANG_RETURN_ON_FAIL(tryReadCommandLineArgumentRaw(sink, arg, &argCursor, argEnd, &defineStr)); + } + // The string that sets up the define can have an `=` between + // the name to be defined and its value, so we search for one. + char const* eqPos = nullptr; + for(char const* dd = defineStr; *dd; ++dd) + { + if (*dd == '=') + { + eqPos = dd; + break; + } + } + + // Now set the preprocessor define + // + if (eqPos) + { + // If we found an `=`, we split the string... + + spAddPreprocessorDefine( + compileRequest, + String(defineStr, eqPos).begin(), + String(eqPos+1).begin()); + } + else + { + // If there was no `=`, then just #define it to an empty string + + spAddPreprocessorDefine( + compileRequest, + String(defineStr).begin(), + ""); + } + } + else if (argStr[1] == 'I') + { + // The value to be defined might be part of the same option, as in: + // -IFOO + // or it might come separately, as in: + // -I FOO + // (see handling of `-D` above) + char const* includeDirStr = arg + 2; + if (includeDirStr[0] == 0) + { + // Need to read another argument from the command line + SLANG_RETURN_ON_FAIL(tryReadCommandLineArgumentRaw(sink, arg, &argCursor, argEnd, &includeDirStr)); + } + + spAddSearchPath( + compileRequest, + String(includeDirStr).begin()); + } + // + // A `-o` option is used to specify a desired output file. + else if (argStr == "-o") + { + char const* outputPath = nullptr; + SLANG_RETURN_ON_FAIL(tryReadCommandLineArgumentRaw(sink, arg, &argCursor, argEnd, &outputPath)); + if (!outputPath) continue; + + addOutputPath(outputPath); + } + else if(argStr == "-matrix-layout-row-major") + { + defaultMatrixLayoutMode = kMatrixLayoutMode_RowMajor; + } + else if(argStr == "-matrix-layout-column-major") + { + defaultMatrixLayoutMode = kMatrixLayoutMode_ColumnMajor; + } + else if(argStr == "-line-directive-mode") + { + String name; + SLANG_RETURN_ON_FAIL(tryReadCommandLineArgument(sink, arg, &argCursor, argEnd, name)); + + SlangLineDirectiveMode mode = SLANG_LINE_DIRECTIVE_MODE_DEFAULT; + if(name == "none") + { + mode = SLANG_LINE_DIRECTIVE_MODE_NONE; + } + else + { + sink->diagnose(SourceLoc(), Diagnostics::unknownLineDirectiveMode, name); + return SLANG_FAIL; + } + + spSetLineDirectiveMode(compileRequest, mode); + + } + else if( argStr == "-fp-mode" || argStr == "-floating-point-mode" ) + { + String name; + SLANG_RETURN_ON_FAIL(tryReadCommandLineArgument(sink, arg, &argCursor, argEnd, name)); + + FloatingPointMode mode = FloatingPointMode::Default; + if(name == "fast") + { + mode = FloatingPointMode::Fast; + } + else if(name == "precise") + { + mode = FloatingPointMode::Precise; + } + else + { + sink->diagnose(SourceLoc(), Diagnostics::unknownFloatingPointMode, name); + return SLANG_FAIL; + } + + setFloatingPointMode(getCurrentTarget(), mode); + } + else if( argStr[1] == 'O' ) + { + char const* name = arg + 2; + SlangOptimizationLevel level = SLANG_OPTIMIZATION_LEVEL_DEFAULT; + + bool invalidOptimizationLevel = strlen(name) > 2; + switch( name[0] ) + { + case '0': level = SLANG_OPTIMIZATION_LEVEL_NONE; break; + case '1': level = SLANG_OPTIMIZATION_LEVEL_DEFAULT; break; + case '2': level = SLANG_OPTIMIZATION_LEVEL_HIGH; break; + case '3': level = SLANG_OPTIMIZATION_LEVEL_MAXIMAL; break; + case 0 : level = SLANG_OPTIMIZATION_LEVEL_DEFAULT; break; + default: + invalidOptimizationLevel = true; + break; + } + if( invalidOptimizationLevel ) + { + sink->diagnose(SourceLoc(), Diagnostics::unknownOptimiziationLevel, name); + return SLANG_FAIL; + } + + spSetOptimizationLevel(compileRequest, level); + } + + // Note: unlike with `-O` above, we have to consider that other + // options might have names that start with `-g` and so cannot + // just detect it as a prefix. + else if( argStr == "-g" || argStr == "-g2" ) + { + spSetDebugInfoLevel(compileRequest, SLANG_DEBUG_INFO_LEVEL_STANDARD); + } + else if( argStr == "-g0" ) + { + spSetDebugInfoLevel(compileRequest, SLANG_DEBUG_INFO_LEVEL_NONE); + } + else if( argStr == "-g1" ) + { + spSetDebugInfoLevel(compileRequest, SLANG_DEBUG_INFO_LEVEL_MINIMAL); + } + else if( argStr == "-g3" ) + { + spSetDebugInfoLevel(compileRequest, SLANG_DEBUG_INFO_LEVEL_MAXIMAL); + } + else if( argStr == "-default-image-format-unknown" ) + { + requestImpl->getBackEndReq()->useUnknownImageFormatAsDefault = true; + } + else if (argStr == "--") + { + // The `--` option causes us to stop trying to parse options, + // and treat the rest of the command line as input file names: + while (argCursor != argEnd) + { + SLANG_RETURN_ON_FAIL(addInputPath(*argCursor++)); + } + break; + } + else + { + sink->diagnose(SourceLoc(), Diagnostics::unknownCommandLineOption, argStr); + // TODO: print a usage message + return SLANG_FAIL; + } + } + else + { + SLANG_RETURN_ON_FAIL(addInputPath(arg)); + } + } + + spSetCompileFlags(compileRequest, flags); + + // As a compatability feature, if the user didn't list any explicit entry + // point names, *and* they are compiling a single translation unit, *and* they + // have either specified a stage, or we can assume one from the naming + // of the translation unit, then we assume they wanted to compile a single + // entry point named `main`. + // + if(rawEntryPoints.getCount() == 0 + && rawTranslationUnits.getCount() == 1 + && (defaultEntryPoint.stage != Stage::Unknown + || rawTranslationUnits[0].impliedStage != Stage::Unknown)) + { + RawEntryPoint entry; + entry.name = "main"; + entry.translationUnitIndex = 0; + rawEntryPoints.add(entry); + } + + // If the user (manually or implicitly) specified only a single entry point, + // then we allow the associated stage to be specified either before or after + // the entry point. This means that if there is a stage attached + // to the "default" entry point, we should copy it over to the + // explicit one. + // + if( rawEntryPoints.getCount() == 1 ) + { + if(defaultEntryPoint.stage != Stage::Unknown) + { + setStage(getCurrentEntryPoint(), defaultEntryPoint.stage); + } + + if(defaultEntryPoint.redundantStageSet) + getCurrentEntryPoint()->redundantStageSet = true; + if(defaultEntryPoint.conflictingStagesSet) + getCurrentEntryPoint()->conflictingStagesSet = true; + } + else + { + // If the "default" entry point has had a stage (or + // other state, if we add other per-entry-point state) + // specified, but there is more than one entry point, + // then that state doesn't apply to anything and we + // should issue an error to tell the user something + // funky is going on. + // + if( defaultEntryPoint.stage != Stage::Unknown ) + { + if( rawEntryPoints.getCount() == 0 ) + { + sink->diagnose(SourceLoc(), Diagnostics::stageSpecificationIgnoredBecauseNoEntryPoints); + } + else + { + sink->diagnose(SourceLoc(), Diagnostics::stageSpecificationIgnoredBecauseBeforeAllEntryPoints); + } + } + } + + // Slang requires that every explicit entry point indicate the translation + // unit it comes from. If there is only one translation unit specified, + // then implicitly all entry points come from it. + // + if(translationUnitCount == 1) + { + for( auto& entryPoint : rawEntryPoints ) + { + entryPoint.translationUnitIndex = 0; + } + } + else + { + // Otherwise, we require that all entry points be specified after + // the translation unit to which tye belong. + bool anyEntryPointWithoutTranslationUnit = false; + for( auto& entryPoint : rawEntryPoints ) + { + // Skip entry points that are already associated with a translation unit... + if( entryPoint.translationUnitIndex != -1 ) + continue; + + anyEntryPointWithoutTranslationUnit = true; + } + if( anyEntryPointWithoutTranslationUnit ) + { + sink->diagnose(SourceLoc(), Diagnostics::entryPointsNeedToBeAssociatedWithTranslationUnits); + return SLANG_FAIL; + } + } + + // Now that entry points are associated with translation units, + // we can make one additional pass where if an entry point has + // no specified stage, but the nameing of its translation unit + // implies a stage, we will use that (a manual `-stage` annotation + // will always win out in such a case). + // + for( auto& rawEntryPoint : rawEntryPoints ) + { + // Skip entry points that already have a stage. + if(rawEntryPoint.stage != Stage::Unknown) + continue; + + // Sanity check: don't process entry points with no associated translation unit. + if( rawEntryPoint.translationUnitIndex == -1 ) + continue; + + auto impliedStage = rawTranslationUnits[rawEntryPoint.translationUnitIndex].impliedStage; + if(impliedStage != Stage::Unknown) + rawEntryPoint.stage = impliedStage; + } + + // Note: it is possible that some entry points still won't have associated + // stages at this point, but we don't want to error out here, because + // those entry points might get stages later, as part of semantic checking, + // if the corresponding function has a `[shader("...")]` attribute. + + // Now that we've tried to establish stages for entry points, we can + // issue diagnostics for cases where stages were set redundantly or + // in conflicting ways. + // + for( auto& rawEntryPoint : rawEntryPoints ) + { + if( rawEntryPoint.conflictingStagesSet ) + { + sink->diagnose(SourceLoc(), Diagnostics::conflictingStagesForEntryPoint, rawEntryPoint.name); + } + else if( rawEntryPoint.redundantStageSet ) + { + sink->diagnose(SourceLoc(), Diagnostics::sameStageSpecifiedMoreThanOnce, rawEntryPoint.stage, rawEntryPoint.name); + } + else if( rawEntryPoint.translationUnitIndex != -1 ) + { + // As a quality-of-life feature, if the file name implies a particular + // stage, but the user manually specified something different for + // their entry point, give a warning in case they made a mistake. + + auto& rawTranslationUnit = rawTranslationUnits[rawEntryPoint.translationUnitIndex]; + if( rawTranslationUnit.impliedStage != Stage::Unknown + && rawEntryPoint.stage != Stage::Unknown + && rawTranslationUnit.impliedStage != rawEntryPoint.stage ) + { + sink->diagnose(SourceLoc(), Diagnostics::explicitStageDoesntMatchImpliedStage, rawEntryPoint.name, rawEntryPoint.stage, rawTranslationUnit.impliedStage); + } + } + } + + // If the user is requesting code generation via pass-through, + // then any entry points they specify need to have a stage set, + // because fxc/dxc/glslang don't have a facility for taking + // a named entry point and pulling its stage from an attribute. + // + if( requestImpl->passThrough != PassThroughMode::None ) + { + for( auto& rawEntryPoint : rawEntryPoints ) + { + if( rawEntryPoint.stage == Stage::Unknown ) + { + sink->diagnose(SourceLoc(), Diagnostics::noStageSpecifiedInPassThroughMode, rawEntryPoint.name); + } + } + } + + // We now have inferred enough information to add the + // entry points to our compile request. + // + for( auto& rawEntryPoint : rawEntryPoints ) + { + if(rawEntryPoint.translationUnitIndex < 0) + continue; + + auto translationUnitID = rawTranslationUnits[rawEntryPoint.translationUnitIndex].translationUnitID; + + int entryPointID = spAddEntryPoint( + compileRequest, + translationUnitID, + rawEntryPoint.name.begin(), + SlangStage(rawEntryPoint.stage)); + + rawEntryPoint.entryPointID = entryPointID; + } + + // We are going to build a mapping from target formats to the + // target that handles that format. + Dictionary mapFormatToTargetIndex; + + // If there was no explicit `-target` specified, then we will look + // at the `-o` options to see what we can infer. + // + if(rawTargets.getCount() == 0) + { + for(auto& rawOutput : rawOutputs) + { + // Some outputs don't imply a target format, and we shouldn't use those for inference. + auto impliedFormat = rawOutput.impliedFormat; + if( impliedFormat == CodeGenTarget::Unknown ) + continue; + + int targetIndex = 0; + if( !mapFormatToTargetIndex.TryGetValue(impliedFormat, targetIndex) ) + { + targetIndex = (int) rawTargets.getCount(); + + RawTarget rawTarget; + rawTarget.format = impliedFormat; + rawTargets.add(rawTarget); + + mapFormatToTargetIndex[impliedFormat] = targetIndex; + } + + rawOutput.targetIndex = targetIndex; + } + } + else + { + // If there were explicit targets, then we will use those, but still + // build up our mapping. We should object if the same target format + // is specified more than once (just because of the ambiguities + // it will create). + // + int targetCount = (int) rawTargets.getCount(); + for(int targetIndex = 0; targetIndex < targetCount; ++targetIndex) + { + auto format = rawTargets[targetIndex].format; + + if( mapFormatToTargetIndex.ContainsKey(format) ) + { + sink->diagnose(SourceLoc(), Diagnostics::duplicateTargets, format); + } + else + { + mapFormatToTargetIndex[format] = targetIndex; + } + } + } + + // If we weren't able to infer any targets from output paths (perhaps + // because there were no output paths), but there was a profile specified, + // then we can try to infer a target from the profile. + // + if( rawTargets.getCount() == 0 + && defaultTarget.profileVersion != ProfileVersion::Unknown + && !defaultTarget.conflictingProfilesSet) + { + // Let's see if the chosen profile allows us to infer + // the code gen target format that the user probably meant. + // + CodeGenTarget inferredFormat = CodeGenTarget::Unknown; + auto profileVersion = defaultTarget.profileVersion; + switch( Profile(profileVersion).getFamily() ) + { + default: + break; + + // For GLSL profile versions, we will assume SPIR-V + // is the output format the user intended. + case ProfileFamily::GLSL: + inferredFormat = CodeGenTarget::SPIRV; + break; + + // For DX profile versions, we will assume that the + // user wants DXIL for Shader Model 6.0 and up, + // and DXBC for all earlier versions. + // + // Note: There is overlap where both DXBC and DXIL + // nominally support SM 5.1, but in general we + // expect users to prefer to make a clean break + // at SM 6.0. Anybody who cares about the overlap + // cases should manually specify `-target dxil`. + // + case ProfileFamily::DX: + if( profileVersion >= ProfileVersion::DX_6_0 ) + { + inferredFormat = CodeGenTarget::DXIL; + } + else + { + inferredFormat = CodeGenTarget::DXBytecode; + } + break; + } + + if( inferredFormat != CodeGenTarget::Unknown ) + { + RawTarget rawTarget; + rawTarget.format = inferredFormat; + rawTargets.add(rawTarget); + } + } + + // Similar to the case for entry points, if there is a single target, + // then we allow some of its options to come from the "default" + // target state. + if(rawTargets.getCount() == 1) + { + if(defaultTarget.profileVersion != ProfileVersion::Unknown) + { + setProfileVersion(getCurrentTarget(), defaultTarget.profileVersion); + } + + getCurrentTarget()->targetFlags |= defaultTarget.targetFlags; + + if( defaultTarget.floatingPointMode != FloatingPointMode::Default ) + { + setFloatingPointMode(getCurrentTarget(), defaultTarget.floatingPointMode); + } + } + else + { + // If the "default" target has had a profile (or other state) + // specified, but there is != 1 taget, then that state doesn't + // apply to anythign and we should give the user an error. + // + if( defaultTarget.profileVersion != ProfileVersion::Unknown ) + { + if( rawTargets.getCount() == 0 ) + { + // This should only happen if there were multiple `-profile` options, + // so we didn't try to infer a target, or if the `-profile` option + // somehow didn't imply a target. + // + sink->diagnose(SourceLoc(), Diagnostics::profileSpecificationIgnoredBecauseNoTargets); + } + else + { + sink->diagnose(SourceLoc(), Diagnostics::profileSpecificationIgnoredBecauseBeforeAllTargets); + } + } + + if( defaultTarget.targetFlags ) + { + if( rawTargets.getCount() == 0 ) + { + sink->diagnose(SourceLoc(), Diagnostics::targetFlagsIgnoredBecauseNoTargets); + } + else + { + sink->diagnose(SourceLoc(), Diagnostics::targetFlagsIgnoredBecauseBeforeAllTargets); + } + } + + if( defaultTarget.floatingPointMode != FloatingPointMode::Default ) + { + if( rawTargets.getCount() == 0 ) + { + sink->diagnose(SourceLoc(), Diagnostics::targetFlagsIgnoredBecauseNoTargets); + } + else + { + sink->diagnose(SourceLoc(), Diagnostics::targetFlagsIgnoredBecauseBeforeAllTargets); + } + } + + } + + for(auto& rawTarget : rawTargets) + { + if( rawTarget.conflictingProfilesSet ) + { + sink->diagnose(SourceLoc(), Diagnostics::conflictingProfilesSpecifiedForTarget, rawTarget.format); + } + else if( rawTarget.redundantProfileSet ) + { + sink->diagnose(SourceLoc(), Diagnostics::sameProfileSpecifiedMoreThanOnce, rawTarget.profileVersion, rawTarget.format); + } + } + + // TODO: do we need to require that a target must have a profile specified, + // or will we continue to allow the profile to be inferred from the target? + + // We now have enough information to go ahead and declare the targets + // through the Slang API: + // + for(auto& rawTarget : rawTargets) + { + int targetID = spAddCodeGenTarget(compileRequest, SlangCompileTarget(rawTarget.format)); + rawTarget.targetID = targetID; + + if( rawTarget.profileVersion != ProfileVersion::Unknown ) + { + spSetTargetProfile(compileRequest, targetID, Profile(rawTarget.profileVersion).raw); + } + + if( rawTarget.targetFlags ) + { + spSetTargetFlags(compileRequest, targetID, rawTarget.targetFlags); + } + + if( rawTarget.floatingPointMode != FloatingPointMode::Default ) + { + spSetTargetFloatingPointMode(compileRequest, targetID, SlangFloatingPointMode(rawTarget.floatingPointMode)); + } + } + + if(defaultMatrixLayoutMode != SLANG_MATRIX_LAYOUT_MODE_UNKNOWN) + { + spSetMatrixLayoutMode(compileRequest, defaultMatrixLayoutMode); + } + + // Next we need to sort out the output files specified with `-o`, and + // figure out which entry point and/or target they apply to. + // + // If there is only a single entry point, then that is automatically + // the entry point that should be associated with all outputs. + // + if( rawEntryPoints.getCount() == 1 ) + { + for( auto& rawOutput : rawOutputs ) + { + rawOutput.entryPointIndex = 0; + } + } + // + // Similarly, if there is only one target, then all outputs must + // implicitly appertain to that target. + // + if( rawTargets.getCount() == 1 ) + { + for( auto& rawOutput : rawOutputs ) + { + rawOutput.targetIndex = 0; + } + } + + // Consider the output files specified via `-o` and try to figure + // out how to deal with them. + // + for(auto& rawOutput : rawOutputs) + { + // For now, all output formats need to be tightly bound to + // both a target and an entry point (down the road we will + // need to support output formats that can store multiple + // entry points in one file). + + // If an output doesn't have a target associated with + // it, then search for the target with the matching format. + if( rawOutput.targetIndex == -1 ) + { + auto impliedFormat = rawOutput.impliedFormat; + int targetIndex = -1; + + if(impliedFormat == CodeGenTarget::Unknown) + { + // If we hit this case, then it means that we need to pick the + // target to assocaite with this output based on its implied + // format, but the file path doesn't direclty imply a format + // (it doesn't have a suffix like `.spv` that tells us what to write). + // + sink->diagnose(SourceLoc(), Diagnostics::cannotDeduceOutputFormatFromPath, rawOutput.path); + } + else if( mapFormatToTargetIndex.TryGetValue(rawOutput.impliedFormat, targetIndex) ) + { + rawOutput.targetIndex = targetIndex; + } + else + { + sink->diagnose(SourceLoc(), Diagnostics::cannotMatchOutputFileToTarget, rawOutput.path, rawOutput.impliedFormat); + } + } + + // We won't do any searching to match an output file + // with an entry point, since the case of a single entry + // point was handled above, and the user is expected to + // follow the ordering rules when using multiple entry points. + // + if( rawOutput.entryPointIndex == -1 ) + { + sink->diagnose(SourceLoc(), Diagnostics::cannotMatchOutputFileToEntryPoint, rawOutput.path); + } + } + + // Now that we've diagnosed the output paths, we can add them + // to the compile request at the appropriate locations. + // + // We will consider the output files specified via `-o` and try to figure + // out how to deal with them. + // + for(auto& rawOutput : rawOutputs) + { + if(rawOutput.targetIndex == -1) continue; + if(rawOutput.entryPointIndex == -1) continue; + + auto targetID = rawTargets[rawOutput.targetIndex].targetID; + Int entryPointID = rawEntryPoints[rawOutput.entryPointIndex].entryPointID; + + auto target = requestImpl->getLinkage()->targets[targetID]; + auto entryPointReq = requestImpl->getFrontEndReq()->getEntryPointReqs()[entryPointID]; + + RefPtr targetInfo; + if( !requestImpl->targetInfos.TryGetValue(target, targetInfo) ) + { + targetInfo = new EndToEndCompileRequest::TargetInfo(); + requestImpl->targetInfos[target] = targetInfo; + } + + String outputPath; + if( targetInfo->entryPointOutputPaths.ContainsKey(entryPointID) ) + { + sink->diagnose(SourceLoc(), Diagnostics::duplicateOutputPathsForEntryPointAndTarget, entryPointReq->getName(), target->getTarget()); + } + else + { + targetInfo->entryPointOutputPaths[entryPointID] = rawOutput.path; + } + } + + if (sharedLibraryLoader) + { + spSessionSetSharedLibraryLoader(session, sharedLibraryLoader); + } + + return (sink->GetErrorCount() == 0) ? SLANG_OK : SLANG_FAIL; + } +}; + + +SlangResult parseOptions( + SlangCompileRequest* compileRequestIn, + int argc, + char const* const* argv) +{ + Slang::EndToEndCompileRequest* compileRequest = (Slang::EndToEndCompileRequest*) compileRequestIn; + + OptionsParser parser; + parser.compileRequest = compileRequestIn; + parser.requestImpl = compileRequest; + parser.session = (SlangSession*)compileRequest->getSession(); + + Result res = parser.parse(argc, argv); + + DiagnosticSink* sink = compileRequest->getSink(); + if (sink->GetErrorCount() > 0) + { + // Put the errors in the diagnostic + compileRequest->mDiagnosticOutput = sink->outputBuffer.ProduceString(); + } + + return res; +} + + +} // namespace Slang + +SLANG_API SlangResult spProcessCommandLineArguments( + SlangCompileRequest* request, + char const* const* args, + int argCount) +{ + return Slang::parseOptions(request, argCount, args); +} diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp new file mode 100644 index 000000000..5c4fd24b5 --- /dev/null +++ b/source/slang/slang-parameter-binding.cpp @@ -0,0 +1,2583 @@ +// slang-parameter-binding.cpp +#include "slang-parameter-binding.h" + +#include "slang-lookup.h" +#include "slang-compiler.h" +#include "slang-type-layout.h" + +#include "../../slang.h" + +namespace Slang { + +struct ParameterInfo; + +// Information on ranges of registers already claimed/used +struct UsedRange +{ + // What parameter has claimed this range? + VarLayout* parameter; + + // Begin/end of the range (half-open interval) + UInt begin; + UInt end; +}; +bool operator<(UsedRange left, UsedRange right) +{ + if (left.begin != right.begin) + return left.begin < right.begin; + if (left.end != right.end) + return left.end < right.end; + return false; +} + +static bool rangesOverlap(UsedRange const& x, UsedRange const& y) +{ + SLANG_ASSERT(x.begin <= x.end); + SLANG_ASSERT(y.begin <= y.end); + + // If they don't overlap, then one must be earlier than the other, + // and that one must therefore *end* before the other *begins* + + if (x.end <= y.begin) return false; + if (y.end <= x.begin) return false; + + // Otherwise they must overlap + return true; +} + + +struct UsedRanges +{ + // The `ranges` array maintains a sorted list of `UsedRange` + // objects such that the `end` of a range is <= the `begin` + // of any range that comes after it. + // + // The values covered by each `[begin,end)` range are marked + // as used, and anything not in such an interval is implicitly + // free. + // + // TODO: if it ever starts to matter for performance, we + // could encode this information as a tree instead of an array. + // + List ranges; + + // Add a range to the set, either by extending + // existing range(s), or by adding a new one. + // + // If we find that the new range overlaps with + // an existing range for a *different* parameter + // then we return that parameter so that the + // caller can issue an error. + // + VarLayout* Add(UsedRange range) + { + // The invariant on entry to this + // function is that the `ranges` array + // is sorted and no two entries in the + // array intersect. We must preserve + // that property as a postcondition. + // + // The other postcondition is that the + // interval covered by the input `range` + // must be marked as consumed. + + // We will try track any parameter associated + // with an overlapping range that doesn't + // match the parameter on `range`, so that + // the compiler can issue useful diagnostics. + // + VarLayout* newParam = range.parameter; + VarLayout* existingParam = nullptr; + + // A clever algorithm might use a binary + // search to identify the first entry in `ranges` + // that might overlap `range`, but we are going + // to settle for being less clever for now, in + // the hopes that we can at least be correct. + // + // Note: we are going to iterate over `ranges` + // using indices, because we may actually modify + // the array as we go. + // + Int rangeCount = ranges.getCount(); + for(Int rr = 0; rr < rangeCount; ++rr) + { + auto existingRange = ranges[rr]; + + // The invariant on entry to each loop + // iteration will be that `range` does + // *not* intersect any preceding entry + // in the array. + // + // Note that this invariant might be + // true only because we modified + // `range` along the way. + // + // If `range` does not intertsect `existingRange` + // then our invariant will be trivially + // true for the next iteration. + // + if(!rangesOverlap(existingRange, range)) + { + continue; + } + + // We now know that `range` and `existingRange` + // intersect. The first thing to do + // is to check if we have a parameter + // associated with `existingRange`, so + // that we can use it for emitting diagnostics + // about the overlap: + // + if( existingRange.parameter + && existingRange.parameter != newParam) + { + // There was an overlap with a range that + // had a parameter specified, so we will + // use that parameter in any subsequent + // diagnostics. + // + existingParam = existingRange.parameter; + } + + // Before we can move on in our iteration, + // we need to re-establish our invariant by modifying + // `range` so that it doesn't overlap with `existingRange`. + // Of course we also want to end up with a correct + // result for the overall operation, so we can't just + // throw away intervals. + // + // We first note that if `range` starts before `existingRange`, + // then the interval from `range.begin` to `existingRange.begin` + // needs to be accounted for in the final result. Furthermore, + // the interval `[range.begin, existingRange.begin)` could not + // intersect with any range already in the `ranges` array, + // because it comes strictly before `existingRange`, and our + // invariant says there is no intersection with preceding ranges. + // + if(range.begin < existingRange.begin) + { + UsedRange prefix; + prefix.begin = range.begin; + prefix.end = existingRange.begin; + prefix.parameter = range.parameter; + ranges.add(prefix); + } + // + // Now we know that the interval `[range.begin, existingRange.begin)` + // is claimed, if it exists, and clearly the interval + // `[existingRange.begin, existingRange.end)` is already claimed, + // so the only interval left to consider would be + // `[existingRange.end, range.end)`, if it is non-empty. + // That range might intersect with others in the array, so + // we will need to continue iterating to deal with that + // possibility. + // + range.begin = existingRange.end; + + // If the range would be empty, then of course we have nothing + // left to do. + // + if(range.begin >= range.end) + break; + + // Otherwise, have can be sure that `range` now comes + // strictly *after* `existingRange`, and thus our invariant + // is preserved. + } + + // If we manage to exit the loop, then we have resolved + // an intersection with existing entries - possibly by + // adding some new entries. + // + // If the `range` we are left with is still non-empty, + // then we should go ahead and add it. + // + if(range.begin < range.end) + { + ranges.add(range); + } + + // Any ranges that got added along the way might not + // be in the proper sorted order, so we'll need to + // sort the array to restore our global invariant. + // + ranges.sort(); + + // We end by returning an overlapping parameter that + // we found along the way, if any. + // + return existingParam; + } + + VarLayout* Add(VarLayout* param, UInt begin, UInt end) + { + UsedRange range; + range.parameter = param; + range.begin = begin; + range.end = end; + return Add(range); + } + + VarLayout* Add(VarLayout* param, UInt begin, LayoutSize end) + { + UsedRange range; + range.parameter = param; + range.begin = begin; + range.end = end.isFinite() ? end.getFiniteValue() : UInt(-1); + return Add(range); + } + + bool contains(UInt index) + { + for (auto rr : ranges) + { + if (index < rr.begin) + return false; + + if (index >= rr.end) + continue; + + return true; + } + + return false; + } + + + // Try to find space for `count` entries + UInt Allocate(VarLayout* param, UInt count) + { + UInt begin = 0; + + UInt rangeCount = ranges.getCount(); + for (UInt rr = 0; rr < rangeCount; ++rr) + { + // try to fit in before this range... + + UInt end = ranges[rr].begin; + + // If there is enough space... + if (end >= begin + count) + { + // ... then claim it and be done + Add(param, begin, begin + count); + return begin; + } + + // ... otherwise, we need to look at the + // space between this range and the next + begin = ranges[rr].end; + } + + // We've run out of ranges to check, so we + // can safely go after the last one! + Add(param, begin, begin + count); + return begin; + } +}; + +struct ParameterBindingInfo +{ + size_t space = 0; + size_t index = 0; + LayoutSize count; +}; + +struct ParameterBindingAndKindInfo : ParameterBindingInfo +{ + LayoutResourceKind kind = LayoutResourceKind::None; +}; + +enum +{ + kLayoutResourceKindCount = SLANG_PARAMETER_CATEGORY_COUNT, +}; + +struct UsedRangeSet : RefObject +{ + // Information on what ranges of "registers" have already + // been claimed, for each resource type + UsedRanges usedResourceRanges[kLayoutResourceKindCount]; +}; + +// Information on a single parameter +struct ParameterInfo : RefObject +{ + // Layout info for the concrete variables that will make up this parameter + List> varLayouts; + + ParameterBindingInfo bindingInfo[kLayoutResourceKindCount]; + + // The translation unit this parameter is specific to, if any + TranslationUnitRequest* translationUnit = nullptr; + + ParameterInfo() + { + // Make sure we aren't claiming any resources yet + for( int ii = 0; ii < kLayoutResourceKindCount; ++ii ) + { + bindingInfo[ii].count = 0; + } + } +}; + +struct EntryPointParameterBindingContext +{ + // What ranges of resources bindings are already claimed for this translation unit + UsedRangeSet usedRangeSet; +}; + +// State that is shared during parameter binding, +// across all translation units +struct SharedParameterBindingContext +{ + SharedParameterBindingContext( + LayoutRulesFamilyImpl* defaultLayoutRules, + ProgramLayout* programLayout, + TargetRequest* targetReq, + DiagnosticSink* sink) + : defaultLayoutRules(defaultLayoutRules) + , programLayout(programLayout) + , targetRequest(targetReq) + , m_sink(sink) + { + } + + DiagnosticSink* m_sink = nullptr; + + // The program that we are laying out +// Program* program = nullptr; + + // The target request that is triggering layout + // + // TODO: We should eventually strip this down to + // just the subset of fields on the target that + // can influence layout decisions. + TargetRequest* targetRequest = nullptr; + + LayoutRulesFamilyImpl* defaultLayoutRules; + + // All shader parameters we've discovered so far, and started to lay out... + List> parameters; + + // The program layout we are trying to construct + RefPtr programLayout; + + // What ranges of resources bindings are already claimed at the global scope? + // We store one of these for each declared binding space/set. + // + Dictionary> globalSpaceUsedRangeSets; + + // Which register spaces have been claimed so far? + UsedRanges usedSpaces; + + // The space to use for auto-generated bindings. + UInt defaultSpace = 0; + + TargetRequest* getTargetRequest() { return targetRequest; } + DiagnosticSink* getSink() { return m_sink; } +}; + +static DiagnosticSink* getSink(SharedParameterBindingContext* shared) +{ + return shared->getSink(); +} + +// State that might be specific to a single translation unit +// or event to an entry point. +struct ParameterBindingContext +{ + // All the shared state needs to be available + SharedParameterBindingContext* shared; + + // The type layout context to use when computing + // the resource usage of shader parameters. + TypeLayoutContext layoutContext; + + // What stage (if any) are we compiling for? + Stage stage; + + // The entry point that is being processed right now. + EntryPointLayout* entryPointLayout = nullptr; + + TargetRequest* getTargetRequest() { return shared->getTargetRequest(); } + LayoutRulesFamilyImpl* getRulesFamily() { return layoutContext.getRulesFamily(); } +}; + +static DiagnosticSink* getSink(ParameterBindingContext* context) +{ + return getSink(context->shared); +} + + +struct LayoutSemanticInfo +{ + LayoutResourceKind kind; // the register kind + UInt space; + UInt index; + + // TODO: need to deal with component-granularity binding... +}; + +static bool isDigit(char c) +{ + return (c >= '0') && (c <= '9'); +} + +/// Given a string that specifies a name and index (e.g., `COLOR0`), +/// split it into slices for the name part and the index part. +static void splitNameAndIndex( + UnownedStringSlice const& text, + UnownedStringSlice& outName, + UnownedStringSlice& outDigits) +{ + char const* nameBegin = text.begin(); + char const* digitsEnd = text.end(); + + char const* nameEnd = digitsEnd; + while( nameEnd != nameBegin && isDigit(*(nameEnd - 1)) ) + { + nameEnd--; + } + char const* digitsBegin = nameEnd; + + outName = UnownedStringSlice(nameBegin, nameEnd); + outDigits = UnownedStringSlice(digitsBegin, digitsEnd); +} + +LayoutResourceKind findRegisterClassFromName(UnownedStringSlice const& registerClassName) +{ + switch( registerClassName.size() ) + { + case 1: + switch (*registerClassName.begin()) + { + case 'b': return LayoutResourceKind::ConstantBuffer; + case 't': return LayoutResourceKind::ShaderResource; + case 'u': return LayoutResourceKind::UnorderedAccess; + case 's': return LayoutResourceKind::SamplerState; + + default: + break; + } + break; + + case 5: + if( registerClassName == "space" ) + { + return LayoutResourceKind::RegisterSpace; + } + break; + + default: + break; + } + return LayoutResourceKind::None; +} + +LayoutSemanticInfo ExtractLayoutSemanticInfo( + ParameterBindingContext* context, + HLSLLayoutSemantic* semantic) +{ + LayoutSemanticInfo info; + info.space = 0; + info.index = 0; + info.kind = LayoutResourceKind::None; + + UnownedStringSlice registerName = semantic->registerName.Content; + if (registerName.size() == 0) + return info; + + // The register name is expected to be in the form: + // + // identifier-char+ digit+ + // + // where the identifier characters name a "register class" + // and the digits identify a register index within that class. + // + // We are going to split the string the user gave us + // into these constituent parts: + // + UnownedStringSlice registerClassName; + UnownedStringSlice registerIndexDigits; + splitNameAndIndex(registerName, registerClassName, registerIndexDigits); + + LayoutResourceKind kind = findRegisterClassFromName(registerClassName); + if(kind == LayoutResourceKind::None) + { + getSink(context)->diagnose(semantic->registerName, Diagnostics::unknownRegisterClass, registerClassName); + return info; + } + + // For a `register` semantic, the register index is not optional (unlike + // how it works for varying input/output semantics). + if( registerIndexDigits.size() == 0 ) + { + getSink(context)->diagnose(semantic->registerName, Diagnostics::expectedARegisterIndex, registerClassName); + } + + UInt index = 0; + for(auto c : registerIndexDigits) + { + SLANG_ASSERT(isDigit(c)); + index = index * 10 + (c - '0'); + } + + + UInt space = 0; + if( auto registerSemantic = as(semantic) ) + { + auto const& spaceName = registerSemantic->spaceName.Content; + if(spaceName.size() != 0) + { + UnownedStringSlice spaceSpelling; + UnownedStringSlice spaceDigits; + splitNameAndIndex(spaceName, spaceSpelling, spaceDigits); + + if( kind == LayoutResourceKind::RegisterSpace ) + { + getSink(context)->diagnose(registerSemantic->spaceName, Diagnostics::unexpectedSpecifierAfterSpace, spaceName); + } + else if( spaceSpelling != UnownedTerminatedStringSlice("space") ) + { + getSink(context)->diagnose(registerSemantic->spaceName, Diagnostics::expectedSpace, spaceSpelling); + } + else if( spaceDigits.size() == 0 ) + { + getSink(context)->diagnose(registerSemantic->spaceName, Diagnostics::expectedSpaceIndex); + } + else + { + for(auto c : spaceDigits) + { + SLANG_ASSERT(isDigit(c)); + space = space * 10 + (c - '0'); + } + } + } + } + + // TODO: handle component mask part of things... + if( semantic->componentMask.Content.size() != 0 ) + { + getSink(context)->diagnose(semantic->componentMask, Diagnostics::componentMaskNotSupported); + } + + info.kind = kind; + info.index = (int) index; + info.space = space; + return info; +} + + +// + +// Given a GLSL `layout` modifier, we need to be able to check for +// a particular sub-argument and extract its value if present. +template +static bool findLayoutArg( + RefPtr syntax, + UInt* outVal) +{ + for( auto modifier : syntax->GetModifiersOfType() ) + { + if( modifier ) + { + *outVal = (UInt) strtoull(String(modifier->valToken.Content).getBuffer(), nullptr, 10); + return true; + } + } + return false; +} + +template +static bool findLayoutArg( + DeclRef declRef, + UInt* outVal) +{ + return findLayoutArg(declRef.getDecl(), outVal); +} + + /// Determine how to lay out a global variable that might be a shader parameter. + /// + /// Returns `nullptr` if the declaration does not represent a shader parameter. +RefPtr getTypeLayoutForGlobalShaderParameter( + ParameterBindingContext* context, + VarDeclBase* varDecl, + Type* type) +{ + auto layoutContext = context->layoutContext; + auto rules = layoutContext.getRulesFamily(); + + if( varDecl->HasModifier() && as(type) ) + { + return createTypeLayout( + layoutContext.with(rules->getShaderRecordConstantBufferRules()), + type); + } + + + // We want to check for a constant-buffer type with a `push_constant` layout + // qualifier before we move on to anything else. + if (varDecl->HasModifier() && as(type)) + { + return createTypeLayout( + layoutContext.with(rules->getPushConstantBufferRules()), + type); + } + + // HLSL `static` modifier indicates "thread local" + if(varDecl->HasModifier()) + return nullptr; + + // HLSL `groupshared` modifier indicates "thread-group local" + if(varDecl->HasModifier()) + return nullptr; + + // TODO(tfoley): there may be other cases that we need to handle here + + // An "ordinary" global variable is implicitly a uniform + // shader parameter. + return createTypeLayout( + layoutContext.with(rules->getConstantBufferRules()), + type); +} + +// + +struct EntryPointParameterState +{ + String* optSemanticName = nullptr; + int* ioSemanticIndex = nullptr; + EntryPointParameterDirectionMask directionMask; + int semanticSlotCount; + Stage stage = Stage::Unknown; + bool isSampleRate = false; + SourceLoc loc; +}; + + +static RefPtr processEntryPointVaryingParameter( + ParameterBindingContext* context, + RefPtr type, + EntryPointParameterState const& state, + RefPtr varLayout); + +// Collect a single declaration into our set of parameters +static void collectGlobalGenericParameter( + ParameterBindingContext* context, + RefPtr paramDecl) +{ + RefPtr layout = new GenericParamLayout(); + layout->decl = paramDecl; + layout->index = (int)context->shared->programLayout->globalGenericParams.getCount(); + context->shared->programLayout->globalGenericParams.add(layout); + context->shared->programLayout->globalGenericParamsMap[layout->decl->getName()->text] = layout.Ptr(); +} + +// Collect a single declaration into our set of parameters +static void collectGlobalScopeParameter( + ParameterBindingContext* context, + GlobalShaderParamInfo const& shaderParamInfo, + SubstitutionSet globalGenericSubst) +{ + auto varDeclRef = shaderParamInfo.paramDeclRef; + + // We apply any substitutions for global generic parameters here. + auto type = GetType(varDeclRef)->Substitute(globalGenericSubst).as(); + + // We use a single operation to both check whether the + // variable represents a shader parameter, and to compute + // the layout for that parameter's type. + auto typeLayout = getTypeLayoutForGlobalShaderParameter( + context, + varDeclRef.getDecl(), + type); + + // If we did not find appropriate layout rules, then it + // must mean that this global variable is *not* a shader + // parameter. + if(!typeLayout) + return; + + // Now create a variable layout that we can use + RefPtr varLayout = new VarLayout(); + varLayout->typeLayout = typeLayout; + varLayout->varDecl = varDeclRef; + + // The logic in `check.cpp` that created the `GlobalShaderParamInfo` + // will have identified any cases where there might be multiple + // global variables that logically represent the same shader parameter. + // + // We will track the same basic information during layout using + // the `ParameterInfo` type. + // + // TODO: `ParameterInfo` should probably become `LayoutParamInfo`. + // + ParameterInfo* parameterInfo = new ParameterInfo(); + context->shared->parameters.add(parameterInfo); + + // Add the first variable declaration to the list of declarations for the parameter + parameterInfo->varLayouts.add(varLayout); + + // Add any additional variables to the list of declarations + for( auto additionalVarDeclRef : shaderParamInfo.additionalParamDeclRefs ) + { + // TODO: We should either eliminate the design choice where different + // declarations of the "same" shade parameter get merged across + // translation units (it is effectively just a compatiblity feature), + // or we should clean things up earlier in the chain so that we can + // re-use a single `VarLayout` across all of the different declarations. + // + // TODO: It would also make sense in these cases to ensure that + // such global shader parameters get the same mangled name across + // all translation units, so that they can automatically be collapsed + // during linking. + + RefPtr additionalVarLayout = new VarLayout(); + additionalVarLayout->typeLayout = typeLayout; + additionalVarLayout->varDecl = additionalVarDeclRef; + + parameterInfo->varLayouts.add(additionalVarLayout); + } +} + +static RefPtr findUsedRangeSetForSpace( + ParameterBindingContext* context, + UInt space) +{ + RefPtr usedRangeSet; + if (context->shared->globalSpaceUsedRangeSets.TryGetValue(space, usedRangeSet)) + return usedRangeSet; + + usedRangeSet = new UsedRangeSet(); + context->shared->globalSpaceUsedRangeSets.Add(space, usedRangeSet); + return usedRangeSet; +} + +// Record that a particular register space (or set, in the GLSL case) +// has been used in at least one binding, and so it should not +// be used by auto-generated bindings that need to claim entire +// spaces. +static void markSpaceUsed( + ParameterBindingContext* context, + UInt space) +{ + context->shared->usedSpaces.Add(nullptr, space, space+1); +} + +static UInt allocateUnusedSpaces( + ParameterBindingContext* context, + UInt count) +{ + return context->shared->usedSpaces.Allocate(nullptr, count); +} + +static void addExplicitParameterBinding( + ParameterBindingContext* context, + RefPtr parameterInfo, + VarDeclBase* varDecl, + LayoutSemanticInfo const& semanticInfo, + LayoutSize count, + RefPtr usedRangeSet = nullptr) +{ + auto kind = semanticInfo.kind; + + auto& bindingInfo = parameterInfo->bindingInfo[(int)kind]; + if( bindingInfo.count != 0 ) + { + // We already have a binding here, so we want to + // confirm that it matches the new one that is + // incoming... + if( bindingInfo.count != count + || bindingInfo.index != semanticInfo.index + || bindingInfo.space != semanticInfo.space ) + { + getSink(context)->diagnose(varDecl, Diagnostics::conflictingExplicitBindingsForParameter, getReflectionName(varDecl)); + + auto firstVarDecl = parameterInfo->varLayouts[0]->varDecl.getDecl(); + if( firstVarDecl != varDecl ) + { + getSink(context)->diagnose(firstVarDecl, Diagnostics::seeOtherDeclarationOf, getReflectionName(firstVarDecl)); + } + } + + // TODO(tfoley): `register` semantics can technically be + // profile-specific (not sure if anybody uses that)... + } + else + { + bindingInfo.count = count; + bindingInfo.index = semanticInfo.index; + bindingInfo.space = semanticInfo.space; + + if (!usedRangeSet) + { + usedRangeSet = findUsedRangeSetForSpace(context, semanticInfo.space); + + // Record that the particular binding space was + // used by an explicit binding, so that we don't + // claim it for auto-generated bindings that + // need to grab a full space + markSpaceUsed(context, semanticInfo.space); + } + auto overlappedVarLayout = usedRangeSet->usedResourceRanges[(int)semanticInfo.kind].Add( + parameterInfo->varLayouts[0], + semanticInfo.index, + semanticInfo.index + count); + + if (overlappedVarLayout) + { + auto paramA = parameterInfo->varLayouts[0]->varDecl.getDecl(); + auto paramB = overlappedVarLayout->varDecl.getDecl(); + + getSink(context)->diagnose(paramA, Diagnostics::parameterBindingsOverlap, + getReflectionName(paramA), + getReflectionName(paramB)); + + getSink(context)->diagnose(paramB, Diagnostics::seeDeclarationOf, getReflectionName(paramB)); + } + } +} + +static void addExplicitParameterBindings_HLSL( + ParameterBindingContext* context, + RefPtr parameterInfo, + RefPtr varLayout) +{ + // We only want to apply D3D `register` modifiers when compiling for + // D3D targets. + // + // TODO: Nominally, the `register` keyword allows for a shader + // profile to be specified, so that a given binding only + // applies for a specific profile: + // + // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/dx-graphics-hlsl-variable-register + // + // We might want to consider supporting that syntax in the + // long run, in order to handle bindings for multiple targets + // in a more consistent fashion (whereas using `register` for D3D + // and `[[vk::binding(...)]]` for Vulkan creates a lot of + // visual noise). + // + // For now we do the filtering on target in a very direct fashion: + // + if(!isD3DTarget(context->getTargetRequest())) + return; + + auto typeLayout = varLayout->typeLayout; + auto varDecl = varLayout->varDecl; + + // If the declaration has explicit binding modifiers, then + // here is where we want to extract and apply them... + + // Look for HLSL `register` or `packoffset` semantics. + for (auto semantic : varDecl.getDecl()->GetModifiersOfType()) + { + // Need to extract the information encoded in the semantic + LayoutSemanticInfo semanticInfo = ExtractLayoutSemanticInfo(context, semantic); + auto kind = semanticInfo.kind; + if (kind == LayoutResourceKind::None) + continue; + + // TODO: need to special-case when this is a `c` register binding... + + // Find the appropriate resource-binding information + // inside the type, to see if we even use any resources + // of the given kind. + + auto typeRes = typeLayout->FindResourceInfo(kind); + LayoutSize count = 0; + if (typeRes) + { + count = typeRes->count; + } + else + { + // TODO: warning here! + } + + addExplicitParameterBinding(context, parameterInfo, varDecl, semanticInfo, count); + } +} + +static void maybeDiagnoseMissingVulkanLayoutModifier( + ParameterBindingContext* context, + DeclRef const& varDecl) +{ + // If the user didn't specify a `binding` (and optional `set`) for Vulkan, + // but they *did* specify a `register` for D3D, then that is probably an + // oversight on their part. + if( auto registerModifier = varDecl.getDecl()->FindModifier() ) + { + getSink(context)->diagnose(registerModifier, Diagnostics::registerModifierButNoVulkanLayout, varDecl.GetName()); + } +} + +static void addExplicitParameterBindings_GLSL( + ParameterBindingContext* context, + RefPtr parameterInfo, + RefPtr varLayout) +{ + + // We only want to apply GLSL-style layout modifers + // when compiling for a Khronos-related target. + // + // TODO: This should have some finer granularity + // so that we are able to distinguish between + // Vulkan and OpenGL as targets. + // + if(!isKhronosTarget(context->getTargetRequest())) + return; + + auto typeLayout = varLayout->typeLayout; + auto varDecl = varLayout->varDecl; + + // The catch in GLSL is that the expected resource type + // is implied by the parameter declaration itself, and + // the `layout` modifier is only allowed to adjust + // the index/offset/etc. + // + + // We also may need to store explicit binding info in a different place, + // in the case of varying input/output, since we don't want to collect + // things globally; + RefPtr usedRangeSet; + + TypeLayout::ResourceInfo* resInfo = nullptr; + LayoutSemanticInfo semanticInfo; + semanticInfo.index = 0; + semanticInfo.space = 0; + if( (resInfo = typeLayout->FindResourceInfo(LayoutResourceKind::DescriptorTableSlot)) != nullptr ) + { + // Try to find `binding` and `set` + auto attr = varDecl.getDecl()->FindModifier(); + if (!attr) + { + maybeDiagnoseMissingVulkanLayoutModifier(context, varDecl); + return; + } + semanticInfo.index = attr->binding; + semanticInfo.space = attr->set; + } + else if( (resInfo = typeLayout->FindResourceInfo(LayoutResourceKind::RegisterSpace)) != nullptr ) + { + // Try to find `set` + auto attr = varDecl.getDecl()->FindModifier(); + if (!attr) + { + maybeDiagnoseMissingVulkanLayoutModifier(context, varDecl); + return; + } + if( attr->binding != 0) + { + getSink(context)->diagnose(attr, Diagnostics::wholeSpaceParameterRequiresZeroBinding, varDecl.GetName(), attr->binding); + } + semanticInfo.index = attr->set; + semanticInfo.space = 0; + } + else if( (resInfo = typeLayout->FindResourceInfo(LayoutResourceKind::SpecializationConstant)) != nullptr ) + { + // Try to find `constant_id` binding + if(!findLayoutArg(varDecl, &semanticInfo.index)) + return; + } + + // If we didn't find any matches, then bail + if(!resInfo) + return; + + auto kind = resInfo->kind; + auto count = resInfo->count; + semanticInfo.kind = kind; + + addExplicitParameterBinding(context, parameterInfo, varDecl, semanticInfo, count, usedRangeSet); +} + +// Given a single parameter, collect whatever information we have on +// how it has been explicitly bound, which may come from multiple declarations +void generateParameterBindings( + ParameterBindingContext* context, + RefPtr parameterInfo) +{ + // There must be at least one declaration for the parameter. + SLANG_RELEASE_ASSERT(parameterInfo->varLayouts.getCount() != 0); + + // Iterate over all declarations looking for explicit binding information. + for( auto& varLayout : parameterInfo->varLayouts ) + { + // Handle HLSL `register` and `packoffset` modifiers + addExplicitParameterBindings_HLSL(context, parameterInfo, varLayout); + + + // Handle GLSL `layout` modifiers + addExplicitParameterBindings_GLSL(context, parameterInfo, varLayout); + } +} + +// Generate the binding information for a shader parameter. +static void completeBindingsForParameterImpl( + ParameterBindingContext* context, + RefPtr firstVarLayout, + ParameterBindingInfo bindingInfos[kLayoutResourceKindCount], + RefPtr parameterInfo) +{ + // For any resource kind used by the parameter + // we need to update its layout information + // to include a binding for that resource kind. + // + auto firstTypeLayout = firstVarLayout->typeLayout; + + // We need to deal with allocation of full register spaces first, + // since that is the most complicated bit of logic. + // + // We will compute how many full register spaces the parameter + // needs to allocate, across all the kinds of resources it + // consumes, so that we can allocate a contiguous range of + // spaces. + // + UInt spacesToAllocateCount = 0; + for(auto typeRes : firstTypeLayout->resourceInfos) + { + auto kind = typeRes.kind; + + // We want to ignore resource kinds for which the user + // has specified an explicit binding, since those won't + // go into our contiguously allocated range. + // + auto& bindingInfo = bindingInfos[(int)kind]; + if( bindingInfo.count != 0 ) + { + continue; + } + + // Now we inspect the kind of resource to figure out + // its space requirements: + // + switch( kind ) + { + default: + // An unbounded-size array will need its own space. + // + if( typeRes.count.isInfinite() ) + { + spacesToAllocateCount++; + } + break; + + case LayoutResourceKind::RegisterSpace: + // If the parameter consumes any full spaces (e.g., it + // is a `struct` type with one or more unbounded arrays + // for fields), then we will include those spaces in + // our allocaiton. + // + // We assume/require here that we never end up needing + // an unbounded number of spaces. + // TODO: we should enforce that somewhere with an error. + // + spacesToAllocateCount += typeRes.count.getFiniteValue(); + break; + + case LayoutResourceKind::Uniform: + // We want to ignore uniform data for this calculation, + // since any uniform data in top-level shader parameters + // needs to go into a global constant buffer. + // + break; + + case LayoutResourceKind::GenericResource: + // This is more of a marker case, and shouldn't ever + // need a space allocated to it. + break; + } + } + + // If we compute that the parameter needs some number of full + // spaces allocated to it, then we will go ahead and allocate + // contiguous spaces here. + // + UInt firstAllocatedSpace = 0; + if(spacesToAllocateCount) + { + firstAllocatedSpace = allocateUnusedSpaces(context, spacesToAllocateCount); + } + + // We'll then dole the allocated spaces (if any) out to the resource + // categories that need them. + // + UInt currentAllocatedSpace = firstAllocatedSpace; + + for(auto typeRes : firstTypeLayout->resourceInfos) + { + // Did we already apply some explicit binding information + // for this resource kind? + auto kind = typeRes.kind; + auto& bindingInfo = bindingInfos[(int)kind]; + if( bindingInfo.count != 0 ) + { + // If things have already been bound, our work is done. + // + // TODO: it would be good to handle the case where a + // binding specified a space, but not an offset/index + // for some kind of resource. + // + continue; + } + + auto count = typeRes.count; + + // Certain resource kinds require special handling. + // + // Note: This `switch` statement should have a `case` for + // all of the special cases above that affect the computation of + // `spacesToAllocateCount`. + // + switch( kind ) + { + case LayoutResourceKind::RegisterSpace: + { + // The parameter's type needs to consume some number of whole + // register spaces, and we have already allocated a contiguous + // range of spaces above. + // + // As always, we can't handle the case of a parameter that needs + // an infinite number of spaces. + // + SLANG_ASSERT(count.isFinite()); + bindingInfo.count = count; + + // We will use the spaces we've allocated, and bump + // the variable tracking the "current" space by + // the number of spaces consumed. + // + bindingInfo.index = currentAllocatedSpace; + currentAllocatedSpace += count.getFiniteValue(); + + // TODO: what should we store as the "space" for + // an allocation of register spaces? Either zero + // or `space` makes sense, but it isn't clear + // which is a better choice. + bindingInfo.space = 0; + + continue; + } + + case LayoutResourceKind::GenericResource: + { + // `GenericResource` is somewhat confusingly named, + // but simply indicates that the type of this parameter + // in some way depends on a generic parameter that has + // not been bound to a concrete value, so that asking + // specific questions about its resource usage isn't + // really possible. + // + bindingInfo.space = 0; + bindingInfo.count = 1; + bindingInfo.index = 0; + continue; + } + + case LayoutResourceKind::Uniform: + // TODO: we don't currently handle global-scope uniform parameters. + break; + } + + // At this point, we know the parameter consumes some resource + // (e.g., D3D `t` registers or Vulkan `binding`s), and the user + // didn't specify an explicit binding, so we will have to + // assign one for them. + // + // If we are consuming an infinite amount of the given resource + // (e.g., an unbounded array of `Texure2D` requires an infinite + // number of `t` regisers in D3D), then we will go ahead + // and assign a full space: + // + if( count.isInfinite() ) + { + bindingInfo.count = count; + bindingInfo.index = 0; + bindingInfo.space = currentAllocatedSpace; + currentAllocatedSpace++; + } + else + { + // If we have a finite amount of resources, then + // we will go ahead and allocate from the "default" + // space. + + UInt space = context->shared->defaultSpace; + RefPtr usedRangeSet = findUsedRangeSetForSpace(context, space); + + bindingInfo.count = count; + bindingInfo.index = usedRangeSet->usedResourceRanges[(int)kind].Allocate(firstVarLayout, count.getFiniteValue()); + bindingInfo.space = space; + } + } +} + +static void applyBindingInfoToParameter( + RefPtr varLayout, + ParameterBindingInfo bindingInfos[kLayoutResourceKindCount]) +{ + for(auto k = 0; k < kLayoutResourceKindCount; ++k) + { + auto kind = LayoutResourceKind(k); + auto& bindingInfo = bindingInfos[k]; + + // skip resources we aren't consuming + if(bindingInfo.count == 0) + continue; + + // Add a record to the variable layout + auto varRes = varLayout->AddResourceInfo(kind); + varRes->space = (int) bindingInfo.space; + varRes->index = (int) bindingInfo.index; + } +} + +// Generate the binding information for a shader parameter. +static void completeBindingsForParameter( + ParameterBindingContext* context, + RefPtr parameterInfo) +{ + // We will use the first declaration of the parameter as + // a stand-in for all the declarations, so it is important + // that earlier code has validated that the declarations + // "match". + + SLANG_RELEASE_ASSERT(parameterInfo->varLayouts.getCount() != 0); + auto firstVarLayout = parameterInfo->varLayouts.getFirst(); + + completeBindingsForParameterImpl( + context, + firstVarLayout, + parameterInfo->bindingInfo, + parameterInfo); + + // At this point we should have explicit binding locations chosen for + // all the relevant resource kinds, so we can apply these to the + // declarations: + + for(auto& varLayout : parameterInfo->varLayouts) + { + applyBindingInfoToParameter(varLayout, parameterInfo->bindingInfo); + } +} + +static void completeBindingsForParameter( + ParameterBindingContext* context, + RefPtr varLayout) +{ + ParameterBindingInfo bindingInfos[kLayoutResourceKindCount]; + completeBindingsForParameterImpl( + context, + varLayout, + bindingInfos, + nullptr); + applyBindingInfoToParameter(varLayout, bindingInfos); +} + + /// Allocate binding location for any "pending" data in a shader parameter. + /// + /// When a parameter contains interface-type fields (recursively), we might + /// not have included them in the base layout for the parameter, and instead + /// need to allocate space for them after all other shader parameters have + /// been laid out. + /// + /// This function should be called on the `pendingVarLayout` field of an + /// existing `VarLayout` to ensure that its pending data has been properly + /// assigned storage. It handles the case where the `pendingVarLayout` + /// field is null. + /// +static void _allocateBindingsForPendingData( + ParameterBindingContext* context, + RefPtr pendingVarLayout) +{ + if(!pendingVarLayout) return; + + completeBindingsForParameter(context, pendingVarLayout); +} + +struct SimpleSemanticInfo +{ + String name; + int index; +}; + +SimpleSemanticInfo decomposeSimpleSemantic( + HLSLSimpleSemantic* semantic) +{ + auto composedName = semantic->name.Content; + + // look for a trailing sequence of decimal digits + // at the end of the composed name + UInt length = composedName.size(); + UInt indexLoc = length; + while( indexLoc > 0 ) + { + auto c = composedName[indexLoc-1]; + if( c >= '0' && c <= '9' ) + { + indexLoc--; + continue; + } + else + { + break; + } + } + + SimpleSemanticInfo info; + + // + if( indexLoc == length ) + { + // No index suffix + info.name = composedName; + info.index = 0; + } + else + { + // The name is everything before the digits + String stringComposedName(composedName); + + info.name = stringComposedName.subString(0, indexLoc); + info.index = strtol(stringComposedName.begin() + indexLoc, nullptr, 10); + } + return info; +} + +static RefPtr processSimpleEntryPointParameter( + ParameterBindingContext* context, + RefPtr type, + EntryPointParameterState const& inState, + RefPtr varLayout, + int semanticSlotCount = 1) +{ + EntryPointParameterState state = inState; + state.semanticSlotCount = semanticSlotCount; + + auto optSemanticName = state.optSemanticName; + auto semanticIndex = *state.ioSemanticIndex; + + String semanticName = optSemanticName ? *optSemanticName : ""; + String sn = semanticName.toLower(); + + RefPtr typeLayout; + if (sn.startsWith("sv_") + || sn.startsWith("nv_")) + { + // System-value semantic. + + if (state.directionMask & kEntryPointParameterDirection_Output) + { + // Note: I'm just doing something expedient here and detecting `SV_Target` + // outputs and claiming the appropriate register range right away. + // + // TODO: we should really be building up some representation of all of this, + // once we've gone to the trouble of looking it all up... + if( sn == "sv_target" ) + { + // TODO: construct a `ParameterInfo` we can use here so that + // overlapped layout errors get reported nicely. + + auto usedResourceSet = findUsedRangeSetForSpace(context, 0); + usedResourceSet->usedResourceRanges[int(LayoutResourceKind::UnorderedAccess)].Add(nullptr, semanticIndex, semanticIndex + semanticSlotCount); + + + // We also need to track this as an ordinary varying output from the stage, + // since that is how GLSL will want to see it. + // + typeLayout = getSimpleVaryingParameterTypeLayout( + context->layoutContext, + type, + kEntryPointParameterDirection_Output); + } + } + + if (state.directionMask & kEntryPointParameterDirection_Input) + { + if (sn == "sv_sampleindex") + { + state.isSampleRate = true; + } + } + + if( !typeLayout ) + { + // If we didn't compute a special-case layout for the + // system-value parameter (e.g., because it was an + // `SV_Target` output), then create a default layout + // that consumes no input/output varying slots. + // (since system parameters are distinct from + // user-defined parameters for layout purposes) + // + typeLayout = getSimpleVaryingParameterTypeLayout( + context->layoutContext, + type, + 0); + } + + // Remember the system-value semantic so that we can query it later + if (varLayout) + { + varLayout->systemValueSemantic = semanticName; + varLayout->systemValueSemanticIndex = semanticIndex; + } + + // TODO: add some kind of usage information for system input/output + } + else + { + // In this case we have a user-defined semantic, which means + // an ordinary input and/or output varying parameter. + // + typeLayout = getSimpleVaryingParameterTypeLayout( + context->layoutContext, + type, + state.directionMask); + } + + if (state.isSampleRate + && (state.directionMask & kEntryPointParameterDirection_Input) + && (context->stage == Stage::Fragment)) + { + if (auto entryPointLayout = context->entryPointLayout) + { + entryPointLayout->flags |= EntryPointLayout::Flag::usesAnySampleRateInput; + } + } + + *state.ioSemanticIndex += state.semanticSlotCount; + typeLayout->type = type; + + return typeLayout; +} + +static RefPtr processEntryPointVaryingParameterDecl( + ParameterBindingContext* context, + Decl* decl, + RefPtr type, + EntryPointParameterState const& inState, + RefPtr varLayout) +{ + SimpleSemanticInfo semanticInfo; + int semanticIndex = 0; + + EntryPointParameterState state = inState; + + // If there is no explicit semantic already in effect, *and* we find an explicit + // semantic on the associated declaration, then we'll use it. + if( !state.optSemanticName ) + { + if( auto semantic = decl->FindModifier() ) + { + semanticInfo = decomposeSimpleSemantic(semantic); + semanticIndex = semanticInfo.index; + + state.optSemanticName = &semanticInfo.name; + state.ioSemanticIndex = &semanticIndex; + } + } + + if (decl) + { + if (decl->FindModifier()) + { + state.isSampleRate = true; + } + } + + // Default case: either there was an explicit semantic in effect already, + // *or* we couldn't find an explicit semantic to apply on the given + // declaration, so we will just recursive with whatever we have at + // the moment. + return processEntryPointVaryingParameter(context, type, state, varLayout); +} + +static RefPtr processEntryPointVaryingParameter( + ParameterBindingContext* context, + RefPtr type, + EntryPointParameterState const& state, + RefPtr varLayout) +{ + // Make sure to associate a stage with every + // varying parameter (including sub-fields of + // `struct`-type parameters), since downstream + // code generation will need to look at the + // stage (possibly on individual leaf fields) to + // decide when to emit things like the `flat` + // interpolation modifier. + // + if( varLayout ) + { + varLayout->stage = state.stage; + } + + // The default handling of varying parameters should not apply + // to geometry shader output streams; they have their own special rules. + if( auto gsStreamType = as(type) ) + { + // + + auto elementType = gsStreamType->getElementType(); + + int semanticIndex = 0; + + EntryPointParameterState elementState; + elementState.directionMask = kEntryPointParameterDirection_Output; + elementState.ioSemanticIndex = &semanticIndex; + elementState.isSampleRate = false; + elementState.optSemanticName = nullptr; + elementState.semanticSlotCount = 0; + elementState.stage = state.stage; + elementState.loc = state.loc; + + auto elementTypeLayout = processEntryPointVaryingParameter(context, elementType, elementState, nullptr); + + RefPtr typeLayout = new StreamOutputTypeLayout(); + typeLayout->type = type; + typeLayout->rules = elementTypeLayout->rules; + typeLayout->elementTypeLayout = elementTypeLayout; + + for(auto resInfo : elementTypeLayout->resourceInfos) + typeLayout->addResourceUsage(resInfo); + + return typeLayout; + } + + // Raytracing shaders have a slightly different interpretation of their + // "varying" input/output parameters, since they don't have the same + // idea of previous/next stage as the rasterization shader types. + // + if( state.directionMask & kEntryPointParameterDirection_Output ) + { + // Note: we are silently treating `out` parameters as if they + // were `in out` for this test, under the assumption that + // an `out` parameter represents a write-only payload. + + switch(state.stage) + { + default: + // Not a raytracing shader. + break; + + case Stage::Intersection: + case Stage::RayGeneration: + // Don't expect this case to have any `in out` parameters. + getSink(context)->diagnose(state.loc, Diagnostics::dontExpectOutParametersForStage, getStageName(state.stage)); + break; + + case Stage::AnyHit: + case Stage::ClosestHit: + case Stage::Miss: + // `in out` or `out` parameter is payload + return createTypeLayout(context->layoutContext.with( + context->getRulesFamily()->getRayPayloadParameterRules()), + type); + + case Stage::Callable: + // `in out` or `out` parameter is payload + return createTypeLayout(context->layoutContext.with( + context->getRulesFamily()->getCallablePayloadParameterRules()), + type); + + } + } + else + { + switch(state.stage) + { + default: + // Not a raytracing shader. + break; + + case Stage::Intersection: + case Stage::RayGeneration: + case Stage::Miss: + case Stage::Callable: + // Don't expect this case to have any `in` parameters. + // + // TODO: For a miss or callable shader we could interpret + // an `in` parameter as indicating a payload that the + // programmer doesn't intend to write to. + // + getSink(context)->diagnose(state.loc, Diagnostics::dontExpectInParametersForStage, getStageName(state.stage)); + break; + + case Stage::AnyHit: + case Stage::ClosestHit: + // `in` parameter is hit attributes + return createTypeLayout(context->layoutContext.with( + context->getRulesFamily()->getHitAttributesParameterRules()), + type); + } + } + + // If there is an available semantic name and index, + // then we should apply it to this parameter unconditionally + // (that is, not just if it is a leaf parameter). + auto optSemanticName = state.optSemanticName; + if (optSemanticName && varLayout) + { + // Always store semantics in upper-case for + // reflection information, since they are + // supposed to be case-insensitive and + // upper-case is the dominant convention. + String semanticName = *optSemanticName; + String sn = semanticName.toUpper(); + + auto semanticIndex = *state.ioSemanticIndex; + + varLayout->semanticName = sn; + varLayout->semanticIndex = semanticIndex; + varLayout->flags |= VarLayoutFlag::HasSemantic; + } + + // Scalar and vector types are treated as outputs directly + if(auto basicType = as(type)) + { + return processSimpleEntryPointParameter(context, basicType, state, varLayout); + } + else if(auto vectorType = as(type)) + { + return processSimpleEntryPointParameter(context, vectorType, state, varLayout); + } + // A matrix is processed as if it was an array of rows + else if( auto matrixType = as(type) ) + { + auto rowCount = GetIntVal(matrixType->getRowCount()); + return processSimpleEntryPointParameter(context, matrixType, state, varLayout, (int) rowCount); + } + else if( auto arrayType = as(type) ) + { + // Note: Bad Things will happen if we have an array input + // without a semantic already being enforced. + + auto elementCount = (UInt) GetIntVal(arrayType->ArrayLength); + + // We use the first element to derive the layout for the element type + auto elementTypeLayout = processEntryPointVaryingParameter(context, arrayType->baseType, state, varLayout); + + // We still walk over subsequent elements to make sure they consume resources + // as needed + for( UInt ii = 1; ii < elementCount; ++ii ) + { + processEntryPointVaryingParameter(context, arrayType->baseType, state, nullptr); + } + + RefPtr arrayTypeLayout = new ArrayTypeLayout(); + arrayTypeLayout->elementTypeLayout = elementTypeLayout; + arrayTypeLayout->type = arrayType; + + for (auto rr : elementTypeLayout->resourceInfos) + { + arrayTypeLayout->findOrAddResourceInfo(rr.kind)->count = rr.count * elementCount; + } + + return arrayTypeLayout; + } + // Ignore a bunch of types that don't make sense here... + else if (auto textureType = as(type)) { return nullptr; } + else if(auto samplerStateType = as(type)) { return nullptr; } + else if(auto constantBufferType = as(type)) { return nullptr; } + // Catch declaration-reference types late in the sequence, since + // otherwise they will include all of the above cases... + else if( auto declRefType = as(type) ) + { + auto declRef = declRefType->declRef; + + if (auto structDeclRef = declRef.as()) + { + RefPtr structLayout = new StructTypeLayout(); + structLayout->type = type; + + // Need to recursively walk the fields of the structure now... + for( auto field : GetFields(structDeclRef) ) + { + RefPtr fieldVarLayout = new VarLayout(); + fieldVarLayout->varDecl = field; + + auto fieldTypeLayout = processEntryPointVaryingParameterDecl( + context, + field.getDecl(), + GetType(field), + state, + fieldVarLayout); + + if(fieldTypeLayout) + { + fieldVarLayout->typeLayout = fieldTypeLayout; + + for (auto rr : fieldTypeLayout->resourceInfos) + { + SLANG_RELEASE_ASSERT(rr.count != 0); + + auto structRes = structLayout->findOrAddResourceInfo(rr.kind); + fieldVarLayout->findOrAddResourceInfo(rr.kind)->index = structRes->count.getFiniteValue(); + structRes->count += rr.count; + } + } + + structLayout->fields.add(fieldVarLayout); + structLayout->mapVarToLayout.Add(field.getDecl(), fieldVarLayout); + } + + return structLayout; + } + else if (auto globalGenericParam = declRef.as()) + { + auto genParamTypeLayout = new GenericParamTypeLayout(); + // we should have already populated ProgramLayout::genericEntryPointParams list at this point, + // so we can find the index of this generic param decl in the list + genParamTypeLayout->type = type; + genParamTypeLayout->paramIndex = findGenericParam(context->shared->programLayout->globalGenericParams, globalGenericParam.getDecl()); + genParamTypeLayout->findOrAddResourceInfo(LayoutResourceKind::GenericResource)->count += 1; + return genParamTypeLayout; + } + else if (auto associatedTypeParam = declRef.as()) + { + RefPtr assocTypeLayout = new TypeLayout(); + assocTypeLayout->type = type; + return assocTypeLayout; + } + else + { + SLANG_UNEXPECTED("unhandled type kind"); + } + } + // If we ran into an error in checking the user's code, then skip this parameter + else if( auto errorType = as(type) ) + { + return nullptr; + } + + SLANG_UNEXPECTED("unhandled type kind"); + UNREACHABLE_RETURN(nullptr); +} + + /// Compute the type layout for a parameter declared directly on an entry point. +static RefPtr computeEntryPointParameterTypeLayout( + ParameterBindingContext* context, + SubstitutionSet typeSubst, + DeclRef paramDeclRef, + RefPtr paramVarLayout, + EntryPointParameterState& state) +{ + auto paramDeclRefType = GetType(paramDeclRef); + SLANG_ASSERT(paramDeclRefType); + + auto paramType = paramDeclRefType->Substitute(typeSubst).as(); + + if( paramDeclRef.getDecl()->HasModifier() ) + { + // An entry-point parameter that is explicitly marked `uniform` represents + // a uniform shader parameter passed via the implicitly-defined + // constant buffer (e.g., the `$Params` constant buffer seen in fxc/dxc output). + // + return createTypeLayout( + context->layoutContext.with(context->getRulesFamily()->getConstantBufferRules()), + paramType); + } + else + { + // The default case is a varying shader parameter, which could be used for + // input, output, or both. + // + // The varying case needs to not only compute a layout, but also assocaite + // "semantic" strings/indices with the varying parameters by recursively + // walking their structure. + + state.directionMask = 0; + + // If it appears to be an input, process it as such. + if( paramDeclRef.getDecl()->HasModifier() + || paramDeclRef.getDecl()->HasModifier() + || !paramDeclRef.getDecl()->HasModifier() ) + { + state.directionMask |= kEntryPointParameterDirection_Input; + } + + // If it appears to be an output, process it as such. + if(paramDeclRef.getDecl()->HasModifier() + || paramDeclRef.getDecl()->HasModifier()) + { + state.directionMask |= kEntryPointParameterDirection_Output; + } + + return processEntryPointVaryingParameterDecl( + context, + paramDeclRef.getDecl(), + paramType, + state, + paramVarLayout); + } +} + +// There are multiple places where we need to compute the layout +// for a "scope" such as the global scope or an entry point. +// The `ScopeLayoutBuilder` encapsulates the logic around: +// +// * Doing layout for the ordinary/uniform fields, which involves +// using the `struct` layout rules for constant buffers on +// the target. +// +// * Creating a final type/var layout that reflects whether the +// scope needs a constant buffer to be allocated to it. +// +struct ScopeLayoutBuilder +{ + ParameterBindingContext* m_context = nullptr; + LayoutRulesImpl* m_rules = nullptr; + RefPtr m_structLayout; + UniformLayoutInfo m_structLayoutInfo; + + // We need to compute a layout for any "pending" data inside + // of the parameters being added to the scope, to facilitate + // later allocating space for all the pending parameters after + // the primary shader parameters. + // + StructTypeLayoutBuilder m_pendingDataTypeLayoutBuilder; + + void beginLayout( + ParameterBindingContext* context) + { + m_context = context; + m_rules = context->getRulesFamily()->getConstantBufferRules(); + m_structLayout = new StructTypeLayout(); + m_structLayout->rules = m_rules; + + m_structLayoutInfo = m_rules->BeginStructLayout(); + } + + void _addParameter( + RefPtr firstVarLayout, + ParameterInfo* parameterInfo) + { + // Does the parameter have any uniform data? + auto layoutInfo = firstVarLayout->typeLayout->FindResourceInfo(LayoutResourceKind::Uniform); + LayoutSize uniformSize = layoutInfo ? layoutInfo->count : 0; + if( uniformSize != 0 ) + { + // Make sure uniform fields get laid out properly... + + UniformLayoutInfo fieldInfo( + uniformSize, + firstVarLayout->typeLayout->uniformAlignment); + + LayoutSize uniformOffset = m_rules->AddStructField( + &m_structLayoutInfo, + fieldInfo); + + if( parameterInfo ) + { + for( auto& varLayout : parameterInfo->varLayouts ) + { + varLayout->findOrAddResourceInfo(LayoutResourceKind::Uniform)->index = uniformOffset.getFiniteValue(); + } + } + else + { + firstVarLayout->findOrAddResourceInfo(LayoutResourceKind::Uniform)->index = uniformOffset.getFiniteValue(); + } + } + + m_structLayout->fields.add(firstVarLayout); + + if( parameterInfo ) + { + for( auto& varLayout : parameterInfo->varLayouts ) + { + m_structLayout->mapVarToLayout.Add(varLayout->varDecl.getDecl(), varLayout); + } + } + else + { + m_structLayout->mapVarToLayout.Add(firstVarLayout->varDecl.getDecl(), firstVarLayout); + } + + // Any "pending" items on a field type become "pending" items + // on the overall `struct` type layout. + // + // TODO: This logic ends up duplicated between here and the main + // `struct` layout logic in `type-layout.cpp`. If this gets any + // more complicated we should see if there is a way to share it. + // + if( auto fieldPendingDataTypeLayout = firstVarLayout->typeLayout->pendingDataTypeLayout ) + { + m_pendingDataTypeLayoutBuilder.beginLayoutIfNeeded(nullptr, m_rules); + auto fieldPendingDataVarLayout = m_pendingDataTypeLayoutBuilder.addField(firstVarLayout->varDecl, fieldPendingDataTypeLayout); + + m_structLayout->pendingDataTypeLayout = m_pendingDataTypeLayoutBuilder.getTypeLayout(); + + if( parameterInfo ) + { + for( auto& varLayout : parameterInfo->varLayouts ) + { + varLayout->pendingVarLayout = fieldPendingDataVarLayout; + } + } + else + { + firstVarLayout->pendingVarLayout = fieldPendingDataVarLayout; + } + } + } + + void addParameter( + RefPtr varLayout) + { + _addParameter(varLayout, nullptr); + } + + void addParameter( + ParameterInfo* parameterInfo) + { + SLANG_RELEASE_ASSERT(parameterInfo->varLayouts.getCount() != 0); + auto firstVarLayout = parameterInfo->varLayouts.getFirst(); + + _addParameter(firstVarLayout, parameterInfo); + } + + RefPtr endLayout() + { + // Finish computing the layout for the ordindary data (if any). + // + m_rules->EndStructLayout(&m_structLayoutInfo); + m_pendingDataTypeLayoutBuilder.endLayout(); + + // Copy the final layout information computed for ordinary data + // over to the struct type layout for the scope. + // + m_structLayout->addResourceUsage(LayoutResourceKind::Uniform, m_structLayoutInfo.size); + m_structLayout->uniformAlignment = m_structLayout->uniformAlignment; + + RefPtr scopeTypeLayout = m_structLayout; + + // If a constant buffer is needed (because there is a non-zero + // amount of uniform data), then we need to wrap up the layout + // to reflect the constant buffer that will be generated. + // + scopeTypeLayout = createConstantBufferTypeLayoutIfNeeded( + m_context->layoutContext, + scopeTypeLayout); + + // We now have a bunch of layout information, which we should + // record into a suitable object that represents the scope + RefPtr scopeVarLayout = new VarLayout(); + scopeVarLayout->typeLayout = scopeTypeLayout; + + if( auto pendingTypeLayout = scopeTypeLayout->pendingDataTypeLayout ) + { + RefPtr pendingVarLayout = new VarLayout(); + pendingVarLayout->typeLayout = pendingTypeLayout; + scopeVarLayout->pendingVarLayout = pendingVarLayout; + } + + return scopeVarLayout; + } +}; + + /// Helper routine to allocate a constant buffer binding if one is needed. + /// + /// This function primarily exists to encapsulate the logic for allocating + /// the resources required for a constant buffer in the appropriate + /// target-specific fashion. + /// +static ParameterBindingAndKindInfo maybeAllocateConstantBufferBinding( + ParameterBindingContext* context, + bool needConstantBuffer) +{ + if( !needConstantBuffer ) return ParameterBindingAndKindInfo(); + + UInt space = context->shared->defaultSpace; + auto usedRangeSet = findUsedRangeSetForSpace(context, space); + + auto layoutInfo = context->getRulesFamily()->getConstantBufferRules()->GetObjectLayout( + ShaderParameterKind::ConstantBuffer); + + ParameterBindingAndKindInfo info; + info.kind = layoutInfo.kind; + info.count = layoutInfo.size; + info.index = usedRangeSet->usedResourceRanges[(int)layoutInfo.kind].Allocate(nullptr, layoutInfo.size.getFiniteValue()); + info.space = space; + return info; +} + + /// Iterate over the parameters of an entry point to compute its requirements. + /// +static void collectEntryPointParameters( + ParameterBindingContext* context, + EntryPoint* entryPoint, + SubstitutionSet typeSubst) +{ + DeclRef entryPointFuncDeclRef = entryPoint->getFuncDeclRef(); + + // We will take responsibility for creating and filling in + // the `EntryPointLayout` object here. + // + RefPtr entryPointLayout = new EntryPointLayout(); + entryPointLayout->profile = entryPoint->getProfile(); + entryPointLayout->entryPoint = entryPointFuncDeclRef.getDecl(); + + // The entry point layout must be added to the output + // program layout so that it can be accessed by reflection. + // + context->shared->programLayout->entryPoints.add(entryPointLayout); + + // For the duration of our parameter collection work we will + // establish this entry point as the current one in the context. + // + context->entryPointLayout = entryPointLayout; + + // Note: this isn't really the best place for this logic to sit, + // but it is the simplest place where we have a direct correspondence + // between a single `EntryPoint` and its matching `EntryPointLayout`, + // so we'll use it. + // + for( auto taggedUnionType : entryPoint->getTaggedUnionTypes() ) + { + SLANG_ASSERT(taggedUnionType); + auto substType = taggedUnionType->Substitute(typeSubst).as(); + auto typeLayout = createTypeLayout(context->layoutContext, substType); + entryPointLayout->taggedUnionTypeLayouts.add(typeLayout); + } + + // We are going to iterate over the entry-point parameters, + // and while we do so we will go ahead and perform layout/binding + // assignment for two cases: + // + // First, the varying parameters of the entry point will have + // their semantics and locations assigned, so we set up state + // for tracking that layout. + // + int defaultSemanticIndex = 0; + EntryPointParameterState state; + state.ioSemanticIndex = &defaultSemanticIndex; + state.optSemanticName = nullptr; + state.semanticSlotCount = 0; + state.stage = entryPoint->getStage(); + + // Second, we will compute offsets for any "ordinary" data + // in the parameter list (e.g., a `uniform float4x4 mvp` parameter), + // which is what the `ScopeLayoutBuilder` is designed to help with. + // + ScopeLayoutBuilder scopeBuilder; + scopeBuilder.beginLayout(context); + auto paramsStructLayout = scopeBuilder.m_structLayout; + + for( auto& shaderParamInfo : entryPoint->getShaderParams() ) + { + auto paramDeclRef = shaderParamInfo.paramDeclRef; + + // When computing layout for an entry-point parameter, + // we want to make sure that the layout context has access + // to the existential type arguments (if any) that were + // provided for the entry-point existential type parameters (if any). + // + context->layoutContext= context->layoutContext + .withExistentialTypeArgs( + entryPoint->getExistentialTypeArgCount(), + entryPoint->getExistentialTypeArgs()) + .withExistentialTypeSlotsOffsetBy( + shaderParamInfo.firstExistentialTypeSlot); + + // Any error messages we emit during the process should + // refer to the location of this parameter. + // + state.loc = paramDeclRef.getLoc(); + + // We are going to construct the variable layout for this + // parameter *before* computing the type layout, because + // the type layout computation is also determining the effective + // semantic of the parameter, which needs to be stored + // back onto the `VarLayout`. + // + RefPtr paramVarLayout = new VarLayout(); + paramVarLayout->varDecl = paramDeclRef; + paramVarLayout->stage = state.stage; + + auto paramTypeLayout = computeEntryPointParameterTypeLayout( + context, + typeSubst, + paramDeclRef, + paramVarLayout, + state); + paramVarLayout->typeLayout = paramTypeLayout; + + // We expect to always be able to compute a layout for + // entry-point parameters, but to be defensive we will + // skip parameters that couldn't have a layout computed + // when assertions are disabled. + // + SLANG_ASSERT(paramTypeLayout); + if(!paramTypeLayout) + continue; + + // Now that we've computed the layout to use for the parameter, + // we need to add its resource usage to that of the entry + // point as a whole. + // + // Any "ordinary" data (e.g., a `float4x4`) needs to be accounted + // for using the `ScopeLayoutBuilder`, since it will handle + // the details of target-specific `struct` type layout. + // + scopeBuilder.addParameter(paramVarLayout); + + // All of the other resources types will be handled in a + // simpler loop that just increments the relevant counters. + // + for (auto paramTypeResInfo : paramTypeLayout->resourceInfos) + { + // We need to skip ordinary data because it is being + // handled by the `scopeBuilder`. + // + if(paramTypeResInfo.kind == LayoutResourceKind::Uniform) + continue; + + // Whatever resources the parameter uses, we need to + // assign the parameter's location/register/binding offset to + // be the sum of everything added so far. + // + auto entryPointResInfo = paramsStructLayout->findOrAddResourceInfo(paramTypeResInfo.kind); + paramVarLayout->findOrAddResourceInfo(paramTypeResInfo.kind)->index = entryPointResInfo->count.getFiniteValue(); + + // We then need to add the resources consumed by the parameter + // to those consumed by the entry point. + // + entryPointResInfo->count += paramTypeResInfo.count; + } + } + entryPointLayout->parametersLayout = scopeBuilder.endLayout(); + + // For an entry point with a non-`void` return type, we need to process the + // return type as a varying output parameter. + // + // TODO: Ideally we should make the layout process more robust to empty/void + // types and apply this logic unconditionally. + // + auto resultType = GetResultType(entryPointFuncDeclRef)->Substitute(typeSubst).as(); + SLANG_ASSERT(resultType); + + if( !resultType->Equals(resultType->getSession()->getVoidType()) ) + { + state.loc = entryPointFuncDeclRef.getLoc(); + state.directionMask = kEntryPointParameterDirection_Output; + + RefPtr resultLayout = new VarLayout(); + resultLayout->stage = state.stage; + + auto resultTypeLayout = processEntryPointVaryingParameterDecl( + context, + entryPointFuncDeclRef.getDecl(), + resultType->Substitute(typeSubst).as(), + state, + resultLayout); + + if( resultTypeLayout ) + { + resultLayout->typeLayout = resultTypeLayout; + + for (auto rr : resultTypeLayout->resourceInfos) + { + auto entryPointRes = paramsStructLayout->findOrAddResourceInfo(rr.kind); + resultLayout->findOrAddResourceInfo(rr.kind)->index = entryPointRes->count.getFiniteValue(); + entryPointRes->count += rr.count; + } + } + + entryPointLayout->resultLayout = resultLayout; + } +} + +static void collectParameters( + ParameterBindingContext* inContext, + Program* program) +{ + // All of the parameters in translation units directly + // referenced in the compile request are part of one + // logical namespace/"linkage" so that two parameters + // with the same name should represent the same + // parameter, and get the same binding(s) + + ParameterBindingContext contextData = *inContext; + auto context = &contextData; + context->stage = Stage::Unknown; + + auto globalGenericSubst = program->getGlobalGenericSubstitution(); + + // We will start by looking for any global generic type parameters. + + for(RefPtr module : program->getModuleDependencies()) + { + for( auto genParamDecl : module->getModuleDecl()->getMembersOfType() ) + { + collectGlobalGenericParameter(context, genParamDecl); + } + } + + // Once we have enumerated global generic type parameters, we can + // begin enumerating shader parameters, starting at the global scope. + // + // Because we have already enumerated the global generic type parameters, + // we will be able to look up the index of a global generic type parameter + // when we see it referenced in the type of one of the shader parameters. + + for(auto& globalParamInfo : program->getShaderParams() ) + { + // When computing layout for a global shader parameter, + // we want to make sure that the layout context has access + // to the existential type arguments (if any) that were + // provided for the global existential type parameters (if any). + // + context->layoutContext= context->layoutContext + .withExistentialTypeArgs( + program->getExistentialTypeArgCount(), + program->getExistentialTypeArgs()) + .withExistentialTypeSlotsOffsetBy( + globalParamInfo.firstExistentialTypeSlot); + + collectGlobalScopeParameter(context, globalParamInfo, globalGenericSubst); + } + + // Next consider parameters for entry points + for(auto entryPoint : program->getEntryPoints()) + { + context->stage = entryPoint->getStage(); + collectEntryPointParameters(context, entryPoint, globalGenericSubst); + } + context->entryPointLayout = nullptr; +} + + /// Emit a diagnostic about a uniform parameter at global scope. +void diagnoseGlobalUniform( + SharedParameterBindingContext* sharedContext, + VarDeclBase* varDecl) +{ + // It is entirely possible for Slang to support uniform parameters at the global scope, + // by bundling them into an implicit constant buffer, and indeed the layout algorithm + // implemented in this file computes a layout *as if* the Slang compiler does just that. + // + // The missing link is the downstream IR and code generation steps, where we would need + // to collect all of the global-scope uniforms into a common `struct` type and then + // create a new constant buffer parameter over that type. + // + // For now it is easier to simply ban this case, since most shader authors have + // switched to modern HLSL/GLSL style with `cbuffer` or `uniform` block declarations. + // + // TODO: In the long run it may be best to require *all* global-scope shader parameters + // to be marked with a keyword (e.g., `uniform`) so that ordinary global variable syntax can be + // used safely. + // + getSink(sharedContext)->diagnose(varDecl, Diagnostics::globalUniformsNotSupported, varDecl->getName()); +} + +static int _calcTotalNumUsedRegistersForLayoutResourceKind(ParameterBindingContext* bindingContext, LayoutResourceKind kind) +{ + int numUsed = 0; + for (auto& pair : bindingContext->shared->globalSpaceUsedRangeSets) + { + UsedRangeSet* rangeSet = pair.Value; + const auto& usedRanges = rangeSet->usedResourceRanges[kind]; + for (const auto& usedRange : usedRanges.ranges) + { + numUsed += int(usedRange.end - usedRange.begin); + } + } + return numUsed; +} + +RefPtr generateParameterBindings( + TargetProgram* targetProgram, + DiagnosticSink* sink) +{ + auto program = targetProgram->getProgram(); + auto targetReq = targetProgram->getTargetReq(); + + RefPtr programLayout = new ProgramLayout(); + programLayout->targetProgram = targetProgram; + + // Try to find rules based on the selected code-generation target + auto layoutContext = getInitialLayoutContextForTarget(targetReq, programLayout); + + // If there was no target, or there are no rules for the target, + // then bail out here. + if (!layoutContext.rules) + return nullptr; + + // Create a context to hold shared state during the process + // of generating parameter bindings + SharedParameterBindingContext sharedContext( + layoutContext.getRulesFamily(), + programLayout, + targetReq, + sink); + + // Create a sub-context to collect parameters that get + // declared into the global scope + ParameterBindingContext context; + context.shared = &sharedContext; + context.layoutContext = layoutContext; + + // Walk through AST to discover all the parameters + collectParameters(&context, program); + + // Now walk through the parameters to generate initial binding information + for( auto& parameter : sharedContext.parameters ) + { + generateParameterBindings(&context, parameter); + } + + // Determine if there are any global-scope parameters that use `Uniform` + // resources, and thus need to get packaged into a constant buffer. + // + // Note: this doesn't account for GLSL's support for "legacy" uniforms + // at global scope, which don't get assigned a CB. + bool needDefaultConstantBuffer = false; + for( auto& parameterInfo : sharedContext.parameters ) + { + SLANG_RELEASE_ASSERT(parameterInfo->varLayouts.getCount() != 0); + auto firstVarLayout = parameterInfo->varLayouts.getFirst(); + + // Does the field have any uniform data? + if( firstVarLayout->typeLayout->FindResourceInfo(LayoutResourceKind::Uniform) ) + { + needDefaultConstantBuffer = true; + diagnoseGlobalUniform(&sharedContext, firstVarLayout->varDecl); + } + } + + // Next, we want to determine if there are any global-scope parameters + // that don't just allocate a whole register space to themselves; these + // parameters will need to go into a "default" space, which should always + // be the first space we allocate. + // + // As a starting point, we will definitely need a "default" space if + // we are creating a default constant buffer, since it should get + // a binding in that "default" space. + // + bool needDefaultSpace = needDefaultConstantBuffer; + if (!needDefaultSpace) + { + // Next we will look at the global-scope parameters and see if + // any of them requires a `register` or `binding` that will + // thus need to land in a default space. + // + for (auto& parameterInfo : sharedContext.parameters) + { + SLANG_RELEASE_ASSERT(parameterInfo->varLayouts.getCount() != 0); + auto firstVarLayout = parameterInfo->varLayouts.getFirst(); + + // For each parameter, we will look at each resource it consumes. + // + for (auto resInfo : firstVarLayout->typeLayout->resourceInfos) + { + // We don't care about whole register spaces/sets, since + // we don't need to allocate a default space/set for a parameter + // that itself consumes a whole space/set. + // + if( resInfo.kind == LayoutResourceKind::RegisterSpace ) + continue; + + // We also don't want to consider resource kinds for which + // the variable already has an (explicit) binding, since + // the space from the explicit binding will be used, so + // that a default space isn't needed. + // + if( parameterInfo->bindingInfo[resInfo.kind].count != 0 ) + continue; + + // Otherwise, we have a shader parameter that will need + // a default space or set to live in. + // + needDefaultSpace = true; + break; + } + } + } + + // If we need a space for default bindings, then allocate it here. + if (needDefaultSpace) + { + UInt defaultSpace = 0; + + // Check if space #0 has been allocated yet. If not, then we'll + // want to use it. + if (sharedContext.usedSpaces.contains(0)) + { + // Somebody has already put things in space zero. + // + // TODO: There are two cases to handle here: + // + // 1) If there is any free register ranges in space #0, + // then we should keep using it as the default space. + // + // 2) If somebody went and put an HLSL unsized array into space #0, + // *or* if they manually placed something like a paramter block + // there (which should consume whole spaces), then we need to + // allocate an unused space instead. + // + // For now we don't deal with the concept of unsized arrays, or + // manually assigning parameter blocks to spaces, so we punt + // on this and assume case (1). + + defaultSpace = 0; + } + else + { + // Nobody has used space zero yet, so we need + // to make sure to reserve it for defaults. + defaultSpace = allocateUnusedSpaces(&context, 1); + + // The result of this allocation had better be that + // we got space #0, or else something has gone wrong. + SLANG_ASSERT(defaultSpace == 0); + } + + sharedContext.defaultSpace = defaultSpace; + } + + // If there are any global-scope uniforms, then we need to + // allocate a constant-buffer binding for them here. + // + ParameterBindingAndKindInfo globalConstantBufferBinding = maybeAllocateConstantBufferBinding( + &context, + needDefaultConstantBuffer); + + // Now walk through again to actually give everything + // ranges of registers... + for( auto& parameter : sharedContext.parameters ) + { + completeBindingsForParameter(&context, parameter); + } + + // After we have allocated registers/bindings to everything + // in the global scope we will process the parameters + // of each entry point in order. + // + // Note: the effect of the current implementation is to + // allocate non-overlapping registers/bindings between all + // the entry points in the compile request (e.g., if you + // have a vertex and fragment shader being compiled together, + // we will allocate distinct constant buffer registers for + // their uniform parameters). + // + // TODO: We probably need to provide some more nuanced control + // over whether entry points get overlapping or non-overlapping + // bindings. It seems clear that if we were compiling multiple + // compute kernels in one invocation we'd want them to get + // overlapping bindings, because we cannot ever have them bound + // together in a single pipeline state. + // + // Similarly, entry point parameters of DirectX Raytracing (DXR) + // shaders should probably be allowed to overlap by default, + // since those parameters should really go into the "local root signature." + // (Note: there is a bit more subtlety around ray tracing + // shaders that will be assembled into a "hit group") + // + // For now we are just doing the simplest thing, which will be + // appropriate for: + // + // * Compiling a single compute shader in a compile request. + // * Compiling some number of rasterization shader entry points + // in a single request, to be used together. + // * Compiling a single ray-tracing shader in a compile request. + // + for( auto entryPoint : sharedContext.programLayout->entryPoints ) + { + auto entryPointParamsLayout = entryPoint->parametersLayout; + completeBindingsForParameter(&context, entryPointParamsLayout); + } + + // Next we need to create a type layout to reflect the information + // we have collected, and we will use the `ScopeLayoutBuilder` + // to encapsulate the logic that can be shared with the entry-point + // case. + // + ScopeLayoutBuilder globalScopeLayoutBuilder; + globalScopeLayoutBuilder.beginLayout(&context); + for( auto& parameterInfo : sharedContext.parameters ) + { + globalScopeLayoutBuilder.addParameter(parameterInfo); + } + + auto globalScopeVarLayout = globalScopeLayoutBuilder.endLayout(); + if( globalConstantBufferBinding.count != 0 ) + { + auto cbInfo = globalScopeVarLayout->findOrAddResourceInfo(globalConstantBufferBinding.kind); + cbInfo->space = globalConstantBufferBinding.space; + cbInfo->index = globalConstantBufferBinding.index; + } + + // After we have laid out all the ordinary parameters, + // we need to go through the global scope plus each entry point, + // and "flush" out any pending data that was associated with + // those scopes as part of dealing with interface-type parameters. + // + _allocateBindingsForPendingData(&context, globalScopeVarLayout->pendingVarLayout); + for( auto entryPoint : sharedContext.programLayout->entryPoints ) + { + _allocateBindingsForPendingData(&context, entryPoint->parametersLayout->pendingVarLayout); + } + + + // HACK: we want global parameters to not have to deal with offsetting + // by the `VarLayout` stored in `globalScopeVarLayout`, so we will scan + // through and for any global parameter that used "pending" data, we will manually + // offset all of its resource infos to account for where the global pending data + // got placed. + // + // TODO: A more appropriate solution would be to pass the `globalScopeVarLayout` + // down into the pass that puts layout information onto global parameters in + // the IR, and apply the offsetting there. + // + for( auto& parameterInfo : sharedContext.parameters ) + { + for( auto varLayout : parameterInfo->varLayouts ) + { + auto pendingVarLayout = varLayout->pendingVarLayout; + if(!pendingVarLayout) continue; + + for( auto& resInfo : pendingVarLayout->resourceInfos ) + { + if( auto globalResInfo = globalScopeVarLayout->pendingVarLayout->FindResourceInfo(resInfo.kind) ) + { + resInfo.index += globalResInfo->index; + resInfo.space += globalResInfo->space; + } + } + } + } + + programLayout->parametersLayout = globalScopeVarLayout; + + { + const int numShaderRecordRegs = _calcTotalNumUsedRegistersForLayoutResourceKind(&context, LayoutResourceKind::ShaderRecord); + if (numShaderRecordRegs > 1) + { + sink->diagnose(SourceLoc(), Diagnostics::tooManyShaderRecordConstantBuffers, numShaderRecordRegs); + } + } + + return programLayout; +} + +ProgramLayout* TargetProgram::getOrCreateLayout(DiagnosticSink* sink) +{ + if( !m_layout ) + { + m_layout = generateParameterBindings(this, sink); + } + return m_layout; +} + +void generateParameterBindings( + Program* program, + TargetRequest* targetReq, + DiagnosticSink* sink) +{ + program->getTargetProgram(targetReq)->getOrCreateLayout(sink); +} + +} // namespace Slang diff --git a/source/slang/slang-parameter-binding.h b/source/slang/slang-parameter-binding.h new file mode 100644 index 000000000..139064680 --- /dev/null +++ b/source/slang/slang-parameter-binding.h @@ -0,0 +1,34 @@ +#ifndef SLANG_PARAMETER_BINDING_H +#define SLANG_PARAMETER_BINDING_H + +#include "../core/slang-basic.h" +#include "slang-syntax.h" + +#include "../../slang.h" + +namespace Slang { + +class Program; +class TargetRequest; + +// The parameter-binding interface is responsible for assigning +// binding locations/registers to every parameter of a shader +// program. This can include both parameters declared on a +// particular entry point, as well as parameters declared at +// global scope. +// + + +// Generate binding information for the given program, +// represented as a collection of different translation units, +// and attach that information to the syntax nodes +// of the program. + +void generateParameterBindings( + Program* program, + TargetRequest* targetReq, + DiagnosticSink* sink); + +} + +#endif diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp new file mode 100644 index 000000000..bb70347c9 --- /dev/null +++ b/source/slang/slang-parser.cpp @@ -0,0 +1,4725 @@ +#include "slang-parser.h" + +#include + +#include "slang-compiler.h" +#include "slang-lookup.h" +#include "slang-visitor.h" + +namespace Slang +{ + // pre-declare + static Name* getName(Parser* parser, String const& text); + + // Helper class useful to build a list of modifiers. + struct ModifierListBuilder + { + ModifierListBuilder() + { + m_next = &m_result; + } + void add(Modifier* modifier) + { + // Doesn't handle SharedModifiers + SLANG_ASSERT(as(modifier) == nullptr); + + // Splice at end + *m_next = modifier; + m_next = &modifier->next; + } + template + T* find() const + { + Modifier* cur = m_result; + while (cur) + { + T* castCur = as(cur); + if (castCur) + { + return castCur; + } + cur = cur->next; + } + return nullptr; + } + template + bool hasType() const + { + return find() != nullptr; + } + RefPtr getFirst() { return m_result; }; + protected: + + RefPtr m_result; + RefPtr* m_next; + }; + + enum Precedence : int + { + Invalid = -1, + Comma, + Assignment, + TernaryConditional, + LogicalOr, + LogicalAnd, + BitOr, + BitXor, + BitAnd, + EqualityComparison, + RelationalComparison, + BitShift, + Additive, + Multiplicative, + Prefix, + Postfix, + }; + + // TODO: implement two pass parsing for file reference and struct type recognition + + class Parser + { + public: + NamePool* namePool; + SourceLanguage sourceLanguage; + + NamePool* getNamePool() { return namePool; } + SourceLanguage getSourceLanguage() { return sourceLanguage; } + + int anonymousCounter = 0; + + RefPtr outerScope; + RefPtr currentScope; + + TokenReader tokenReader; + DiagnosticSink* sink; + int genericDepth = 0; + + // Have we seen any `import` declarations? If so, we need + // to parse function bodies completely, even if we are in + // "rewrite" mode. + bool haveSeenAnyImportDecls = false; + + // Is the parser in a "recovering" state? + // During recovery we don't emit additional errors, until we find + // a token that we expected, when we exit recovery. + bool isRecovering = false; + + void FillPosition(SyntaxNode * node) + { + node->loc = tokenReader.PeekLoc(); + } + void PushScope(ContainerDecl* containerDecl) + { + RefPtr newScope = new Scope(); + newScope->containerDecl = containerDecl; + newScope->parent = currentScope; + + currentScope = newScope; + } + + void pushScopeAndSetParent(ContainerDecl* containerDecl) + { + containerDecl->ParentDecl = currentScope->containerDecl; + PushScope(containerDecl); + } + + void PopScope() + { + currentScope = currentScope->parent; + } + Parser( + Session* session, + TokenSpan const& _tokens, + DiagnosticSink * sink, + RefPtr const& outerScope) + : tokenReader(_tokens) + , sink(sink) + , outerScope(outerScope) + , m_session(session) + {} + Parser(const Parser & other) = default; + + Session* m_session = nullptr; + Session* getSession() { return m_session; } + + Token ReadToken(); + Token ReadToken(TokenType type); + Token ReadToken(const char * string); + bool LookAheadToken(TokenType type, int offset = 0); + bool LookAheadToken(const char * string, int offset = 0); + void parseSourceFile(ModuleDecl* program); + RefPtr ParseStruct(); + RefPtr ParseClass(); + RefPtr ParseStatement(); + RefPtr parseBlockStatement(); + RefPtr parseVarDeclrStatement(Modifiers modifiers); + RefPtr parseIfStatement(); + RefPtr ParseForStatement(); + RefPtr ParseWhileStatement(); + RefPtr ParseDoWhileStatement(); + RefPtr ParseBreakStatement(); + RefPtr ParseContinueStatement(); + RefPtr ParseReturnStatement(); + RefPtr ParseExpressionStatement(); + RefPtr ParseExpression(Precedence level = Precedence::Comma); + + // Parse an expression that might be used in an initializer or argument context, so we should avoid operator-comma + inline RefPtr ParseInitExpr() { return ParseExpression(Precedence::Assignment); } + inline RefPtr ParseArgExpr() { return ParseExpression(Precedence::Assignment); } + + RefPtr ParseLeafExpression(); + RefPtr ParseParameter(); + RefPtr ParseType(); + TypeExp ParseTypeExp(); + + Parser & operator = (const Parser &) = delete; + }; + + // Forward Declarations + + static void ParseDeclBody( + Parser* parser, + ContainerDecl* containerDecl, + TokenType closingToken); + + static RefPtr parseEnumDecl(Parser* parser); + + // Parse the `{}`-delimeted body of an aggregate type declaration + static void parseAggTypeDeclBody( + Parser* parser, + AggTypeDeclBase* decl); + + static RefPtr ParseOptSemantics( + Parser* parser); + + static void ParseOptSemantics( + Parser* parser, + Decl* decl); + + static RefPtr ParseDecl( + Parser* parser, + ContainerDecl* containerDecl); + + static RefPtr ParseSingleDecl( + Parser* parser, + ContainerDecl* containerDecl); + + // + + static void Unexpected( + Parser* parser) + { + // Don't emit "unexpected token" errors if we are in recovering mode + if (!parser->isRecovering) + { + parser->sink->diagnose(parser->tokenReader.PeekLoc(), Diagnostics::unexpectedToken, + parser->tokenReader.PeekTokenType()); + + // Switch into recovery mode, to suppress additional errors + parser->isRecovering = true; + } + } + + static void Unexpected( + Parser* parser, + char const* expected) + { + // Don't emit "unexpected token" errors if we are in recovering mode + if (!parser->isRecovering) + { + parser->sink->diagnose(parser->tokenReader.PeekLoc(), Diagnostics::unexpectedTokenExpectedTokenName, + parser->tokenReader.PeekTokenType(), + expected); + + // Switch into recovery mode, to suppress additional errors + parser->isRecovering = true; + } + } + + static void Unexpected( + Parser* parser, + TokenType expected) + { + // Don't emit "unexpected token" errors if we are in recovering mode + if (!parser->isRecovering) + { + parser->sink->diagnose(parser->tokenReader.PeekLoc(), Diagnostics::unexpectedTokenExpectedTokenType, + parser->tokenReader.PeekTokenType(), + expected); + + // Switch into recovery mode, to suppress additional errors + parser->isRecovering = true; + } + } + + static TokenType SkipToMatchingToken(TokenReader* reader, TokenType tokenType); + + // Skip a singel balanced token, which is either a single token in + // the common case, or a matched pair of tokens for `()`, `[]`, and `{}` + static TokenType SkipBalancedToken( + TokenReader* reader) + { + TokenType tokenType = reader->AdvanceToken().type; + switch (tokenType) + { + default: + break; + + case TokenType::LParent: tokenType = SkipToMatchingToken(reader, TokenType::RParent); break; + case TokenType::LBrace: tokenType = SkipToMatchingToken(reader, TokenType::RBrace); break; + case TokenType::LBracket: tokenType = SkipToMatchingToken(reader, TokenType::RBracket); break; + } + return tokenType; + } + + // Skip balanced + static TokenType SkipToMatchingToken( + TokenReader* reader, + TokenType tokenType) + { + for (;;) + { + if (reader->IsAtEnd()) return TokenType::EndOfFile; + if (reader->PeekTokenType() == tokenType) + { + reader->AdvanceToken(); + return tokenType; + } + SkipBalancedToken(reader); + } + } + + // Is the given token type one that is used to "close" a + // balanced construct. + static bool IsClosingToken(TokenType tokenType) + { + switch (tokenType) + { + case TokenType::EndOfFile: + case TokenType::RBracket: + case TokenType::RParent: + case TokenType::RBrace: + return true; + + default: + return false; + } + } + + + // Expect an identifier token with the given content, and consume it. + Token Parser::ReadToken(const char* expected) + { + if (tokenReader.PeekTokenType() == TokenType::Identifier + && tokenReader.PeekToken().Content == expected) + { + isRecovering = false; + return tokenReader.AdvanceToken(); + } + + if (!isRecovering) + { + Unexpected(this, expected); + return tokenReader.PeekToken(); + } + else + { + // Try to find a place to recover + for (;;) + { + // The token we expected? + // Then exit recovery mode and pretend like all is well. + if (tokenReader.PeekTokenType() == TokenType::Identifier + && tokenReader.PeekToken().Content == expected) + { + isRecovering = false; + return tokenReader.AdvanceToken(); + } + + + // Don't skip past any "closing" tokens. + if (IsClosingToken(tokenReader.PeekTokenType())) + { + return tokenReader.PeekToken(); + } + + // Skip balanced tokens and try again. + SkipBalancedToken(&tokenReader); + } + } + } + + Token Parser::ReadToken() + { + return tokenReader.AdvanceToken(); + } + + static bool TryRecover( + Parser* parser, + TokenType const* recoverBefore, + int recoverBeforeCount, + TokenType const* recoverAfter, + int recoverAfterCount) + { + if (!parser->isRecovering) + return true; + + // Determine if we are looking for common closing tokens, + // so that we can know whether or we are allowed to skip + // over them. + + bool lookingForEOF = false; + bool lookingForRCurly = false; + bool lookingForRParen = false; + bool lookingForRSquare = false; + + for (int ii = 0; ii < recoverBeforeCount; ++ii) + { + switch (recoverBefore[ii]) + { + default: + break; + + case TokenType::EndOfFile: lookingForEOF = true; break; + case TokenType::RBrace: lookingForRCurly = true; break; + case TokenType::RParent: lookingForRParen = true; break; + case TokenType::RBracket: lookingForRSquare = true; break; + } + } + for (int ii = 0; ii < recoverAfterCount; ++ii) + { + switch (recoverAfter[ii]) + { + default: + break; + + case TokenType::EndOfFile: lookingForEOF = true; break; + case TokenType::RBrace: lookingForRCurly = true; break; + case TokenType::RParent: lookingForRParen = true; break; + case TokenType::RBracket: lookingForRSquare = true; break; + } + } + + TokenReader* tokenReader = &parser->tokenReader; + for (;;) + { + TokenType peek = tokenReader->PeekTokenType(); + + // Is the next token in our recover-before set? + // If so, then we have recovered successfully! + for (int ii = 0; ii < recoverBeforeCount; ++ii) + { + if (peek == recoverBefore[ii]) + { + parser->isRecovering = false; + return true; + } + } + + // If we are looking at a token in our recover-after set, + // then consume it and recover + for (int ii = 0; ii < recoverAfterCount; ++ii) + { + if (peek == recoverAfter[ii]) + { + tokenReader->AdvanceToken(); + parser->isRecovering = false; + return true; + } + } + + // Don't try to skip past end of file + if (peek == TokenType::EndOfFile) + return false; + + switch (peek) + { + // Don't skip past simple "closing" tokens, *unless* + // we are looking for a closing token + case TokenType::RParent: + case TokenType::RBracket: + if (lookingForRParen || lookingForRSquare || lookingForRCurly || lookingForEOF) + { + // We are looking for a closing token, so it is okay to skip these + } + else + return false; + break; + + // Don't skip a `}`, to avoid spurious errors, + // with the exception of when we are looking for EOF + case TokenType::RBrace: + if (lookingForRCurly || lookingForEOF) + { + // We are looking for end-of-file, so it is okay to skip here + } + else + { + return false; + } + } + + // Skip balanced tokens and try again. + TokenType skipped = SkipBalancedToken(tokenReader); + + // If we happened to find a matched pair of tokens, and + // the end of it was a token we were looking for, + // then recover here + for (int ii = 0; ii < recoverAfterCount; ++ii) + { + if (skipped == recoverAfter[ii]) + { + parser->isRecovering = false; + return true; + } + } + } + } + + static bool TryRecoverBefore( + Parser* parser, + TokenType before0) + { + TokenType recoverBefore[] = { before0 }; + return TryRecover(parser, recoverBefore, 1, nullptr, 0); + } + + // Default recovery strategy, to use inside `{}`-delimeted blocks. + static bool TryRecover( + Parser* parser) + { + TokenType recoverBefore[] = { TokenType::RBrace }; + TokenType recoverAfter[] = { TokenType::Semicolon }; + return TryRecover(parser, recoverBefore, 1, recoverAfter, 1); + } + + Token Parser::ReadToken(TokenType expected) + { + if (tokenReader.PeekTokenType() == expected) + { + isRecovering = false; + return tokenReader.AdvanceToken(); + } + + if (!isRecovering) + { + Unexpected(this, expected); + return tokenReader.PeekToken(); + } + else + { + // Try to find a place to recover + if (TryRecoverBefore(this, expected)) + { + isRecovering = false; + return tokenReader.AdvanceToken(); + } + + return tokenReader.PeekToken(); + } + } + + bool Parser::LookAheadToken(const char * string, int offset) + { + TokenReader r = tokenReader; + for (int ii = 0; ii < offset; ++ii) + r.AdvanceToken(); + + return r.PeekTokenType() == TokenType::Identifier + && r.PeekToken().Content == string; +} + + bool Parser::LookAheadToken(TokenType type, int offset) + { + TokenReader r = tokenReader; + for (int ii = 0; ii < offset; ++ii) + r.AdvanceToken(); + + return r.PeekTokenType() == type; + } + + // Consume a token and return true it if matches, otherwise false + bool AdvanceIf(Parser* parser, TokenType tokenType) + { + if (parser->LookAheadToken(tokenType)) + { + parser->ReadToken(); + return true; + } + return false; + } + + // Consume a token and return true it if matches, otherwise false + bool AdvanceIf(Parser* parser, char const* text) + { + if (parser->LookAheadToken(text)) + { + parser->ReadToken(); + return true; + } + return false; + } + + // Consume a token and return true if it matches, otherwise check + // for end-of-file and expect that token (potentially producing + // an error) and return true to maintain forward progress. + // Otherwise return false. + bool AdvanceIfMatch(Parser* parser, TokenType tokenType) + { + // If we've run into a syntax error, but haven't recovered inside + // the block, then try to recover here. + if (parser->isRecovering) + { + TryRecoverBefore(parser, tokenType); + } + if (AdvanceIf(parser, tokenType)) + return true; + if (parser->tokenReader.PeekTokenType() == TokenType::EndOfFile) + { + parser->ReadToken(tokenType); + return true; + } + return false; + } + + RefPtr ParseTypeDef(Parser* parser, void* /*userData*/) + { + RefPtr typeDefDecl = new TypeDefDecl(); + + // TODO(tfoley): parse an actual declarator + auto type = parser->ParseTypeExp(); + + auto nameToken = parser->ReadToken(TokenType::Identifier); + typeDefDecl->loc = nameToken.loc; + + typeDefDecl->nameAndLoc = NameLoc(nameToken); + typeDefDecl->type = type; + + return typeDefDecl; + } + + // Add a modifier to a list of modifiers being built + static void AddModifier(RefPtr** ioModifierLink, RefPtr modifier) + { + RefPtr*& modifierLink = *ioModifierLink; + + // We'd like to add the modifier to the end of the list, + // but we need to be careful, in case there is a "shared" + // section of modifiers for multiple declarations. + // + // TODO: This whole approach is a mess because we are "accidentally quadratic" + // when adding many modifiers. + for(;;) + { + // At end of the chain? Done. + if(!*modifierLink) + break; + + // About to look at shared modifiers? Done. + RefPtr linkMod = *modifierLink; + if(as(linkMod)) + { + break; + } + + // Otherwise: keep traversing the modifier list. + modifierLink = &(*modifierLink)->next; + } + + // Splice the modifier into the linked list + + // We need to deal with the case where the modifier to + // be spliced in might actually be a modifier *list*, + // so that we actually want to splice in at the + // end of the new list... + auto spliceLink = &modifier->next; + while(*spliceLink) + spliceLink = &(*spliceLink)->next; + + // Do the splice. + *spliceLink = *modifierLink; + + *modifierLink = modifier; + modifierLink = &modifier->next; + } + + void addModifier( + RefPtr syntax, + RefPtr modifier) + { + auto modifierLink = &syntax->modifiers.first; + AddModifier(&modifierLink, modifier); + } + + // + // '::'? identifier ('::' identifier)* + static Token parseAttributeName(Parser* parser) + { + const SourceLoc scopedIdSourceLoc = parser->tokenReader.PeekLoc(); + + // Strip initial :: if there is one + const TokenType initialTokenType = parser->tokenReader.PeekTokenType(); + if (initialTokenType == TokenType::Scope) + { + parser->ReadToken(TokenType::Scope); + } + + const Token firstIdentifier = parser->ReadToken(TokenType::Identifier); + if (initialTokenType != TokenType::Scope && parser->tokenReader.PeekTokenType() != TokenType::Scope) + { + return firstIdentifier; + } + + // Build up scoped string + StringBuilder scopedIdentifierBuilder; + if (initialTokenType == TokenType::Scope) + { + scopedIdentifierBuilder.Append('_'); + } + scopedIdentifierBuilder.Append(firstIdentifier.Content); + + while (parser->tokenReader.PeekTokenType() == TokenType::Scope) + { + parser->ReadToken(TokenType::Scope); + scopedIdentifierBuilder.Append('_'); + + const Token nextIdentifier(parser->ReadToken(TokenType::Identifier)); + scopedIdentifierBuilder.Append(nextIdentifier.Content); + } + + // Make a 'token' + SourceManager* sourceManager = parser->sink->sourceManager; + const UnownedStringSlice scopedIdentifier(sourceManager->allocateStringSlice(scopedIdentifierBuilder.getUnownedSlice())); + Token token(TokenType::Identifier, scopedIdentifier, scopedIdSourceLoc); + + // Get the name pool + auto namePool = parser->getNamePool(); + + // Since it's an Identifier have to set the name. + token.ptrValue = namePool->getName(token.Content); + + return token; + } + + // Parse HLSL-style `[name(arg, ...)]` style "attribute" modifiers + static void ParseSquareBracketAttributes(Parser* parser, RefPtr** ioModifierLink) + { + parser->ReadToken(TokenType::LBracket); + + const bool hasDoubleBracket = AdvanceIf(parser, TokenType::LBracket); + + for(;;) + { + // Note: When parsing we just construct an AST node for an + // "unchecked" attribute, and defer all detailed semantic + // checking until later. + // + // An alternative would be to perform lookup of an `AttributeDecl` + // at this point, similar to what we do for `SyntaxDecl`, but it + // seems better to not complicate the parsing process any more. + // + + Token nameToken = parseAttributeName(parser); + + RefPtr modifier = new UncheckedAttribute(); + modifier->name = nameToken.getName(); + modifier->loc = nameToken.getLoc(); + modifier->scope = parser->currentScope; + + if (AdvanceIf(parser, TokenType::LParent)) + { + // HLSL-style `[name(arg0, ...)]` attribute + + while (!AdvanceIfMatch(parser, TokenType::RParent)) + { + auto arg = parser->ParseArgExpr(); + if (arg) + { + modifier->args.add(arg); + } + + if (AdvanceIfMatch(parser, TokenType::RParent)) + break; + + parser->ReadToken(TokenType::Comma); + } + } + AddModifier(ioModifierLink, modifier); + + + if (AdvanceIfMatch(parser, TokenType::RBracket)) + break; + + parser->ReadToken(TokenType::Comma); + } + + if (hasDoubleBracket) + { + // Read the second ] + parser->ReadToken(TokenType::RBracket); + } + } + + static TokenType peekTokenType(Parser* parser) + { + return parser->tokenReader.PeekTokenType(); + } + + static Token advanceToken(Parser* parser) + { + return parser->ReadToken(); + } + + static Token peekToken(Parser* parser) + { + return parser->tokenReader.PeekToken(); + } + + static SyntaxDecl* tryLookUpSyntaxDecl( + Parser* parser, + Name* name) + { + // Let's look up the name and see what we find. + + auto lookupResult = lookUp( + parser->getSession(), + nullptr, // no semantics visitor available yet + name, + parser->currentScope); + + // If we didn't find anything, or the result was overloaded, + // then we aren't going to be able to extract a single decl. + if(!lookupResult.isValid() || lookupResult.isOverloaded()) + return nullptr; + + auto decl = lookupResult.item.declRef.getDecl(); + if( auto syntaxDecl = as(decl) ) + { + return syntaxDecl; + } + else + { + return nullptr; + } + } + + template + bool tryParseUsingSyntaxDecl( + Parser* parser, + SyntaxDecl* syntaxDecl, + RefPtr* outSyntax) + { + if (!syntaxDecl) + return false; + + if (!syntaxDecl->syntaxClass.isSubClassOf()) + return false; + + // Consume the token that specified the keyword + auto keywordToken = advanceToken(parser); + + RefPtr parsedObject = syntaxDecl->parseCallback(parser, syntaxDecl->parseUserData); + if (!parsedObject) + { + return false; + } + + auto syntax = as(parsedObject); + if (syntax) + { + if (!syntax->loc.isValid()) + { + syntax->loc = keywordToken.loc; + } + } + else if (parsedObject) + { + // Something was parsed, but it didn't have the expected type! + SLANG_DIAGNOSE_UNEXPECTED(parser->sink, keywordToken, "parser callback did not return the expected type"); + } + + *outSyntax = syntax; + return true; + } + + template + bool tryParseUsingSyntaxDecl( + Parser* parser, + RefPtr* outSyntax) + { + if (peekTokenType(parser) != TokenType::Identifier) + return false; + + auto nameToken = peekToken(parser); + auto name = nameToken.getName(); + + auto syntaxDecl = tryLookUpSyntaxDecl(parser, name); + + if (!syntaxDecl) + return false; + + return tryParseUsingSyntaxDecl(parser, syntaxDecl, outSyntax); + } + + static Modifiers ParseModifiers(Parser* parser) + { + Modifiers modifiers; + RefPtr* modifierLink = &modifiers.first; + for (;;) + { + SourceLoc loc = parser->tokenReader.PeekLoc(); + + switch (peekTokenType(parser)) + { + default: + // If we don't see a token type that we recognize, then + // assume we are done with the modifier sequence. + return modifiers; + + case TokenType::Identifier: + { + // We see an identifier ahead, and it might be the name + // of a modifier keyword of some kind. + + Token nameToken = peekToken(parser); + + RefPtr parsedModifier; + if (tryParseUsingSyntaxDecl(parser, &parsedModifier)) + { + parsedModifier->name = nameToken.getName(); + if (!parsedModifier->loc.isValid()) + { + parsedModifier->loc = nameToken.loc; + } + + AddModifier(&modifierLink, parsedModifier); + continue; + } + + // If there was no match for a modifier keyword, then we + // must be at the end of the modifier sequence + return modifiers; + } + break; + + // HLSL uses `[attributeName]` style for its modifiers, which closely + // matches the C++ `[[attributeName]]` style. + case TokenType::LBracket: + ParseSquareBracketAttributes(parser, &modifierLink); + break; + } + } + } + + static Name* getName(Parser* parser, String const& text) + { + return parser->getNamePool()->getName(text); + } + + static NameLoc expectIdentifier(Parser* parser) + { + return NameLoc(parser->ReadToken(TokenType::Identifier)); + } + + + static RefPtr parseImportDecl( + Parser* parser, void* /*userData*/) + { + parser->haveSeenAnyImportDecls = true; + + auto decl = new ImportDecl(); + decl->scope = parser->currentScope; + + if (peekTokenType(parser) == TokenType::StringLiteral) + { + auto nameToken = parser->ReadToken(TokenType::StringLiteral); + auto nameString = getStringLiteralTokenValue(nameToken); + auto moduleName = getName(parser, nameString); + + decl->moduleNameAndLoc = NameLoc(moduleName, nameToken.loc); + } + else + { + auto moduleNameAndLoc = expectIdentifier(parser); + + // We allow a dotted format for the name, as sugar + if (peekTokenType(parser) == TokenType::Dot) + { + StringBuilder sb; + sb << getText(moduleNameAndLoc.name); + while (AdvanceIf(parser, TokenType::Dot)) + { + sb << "/"; + sb << parser->ReadToken(TokenType::Identifier).Content; + } + + moduleNameAndLoc.name = getName(parser, sb.ProduceString()); + } + + decl->moduleNameAndLoc = moduleNameAndLoc; + } + + parser->ReadToken(TokenType::Semicolon); + + return decl; + } + + static NameLoc ParseDeclName( + Parser* parser) + { + Token nameToken; + if (AdvanceIf(parser, "operator")) + { + nameToken = parser->ReadToken(); + switch (nameToken.type) + { + case TokenType::OpAdd: case TokenType::OpSub: case TokenType::OpMul: case TokenType::OpDiv: + case TokenType::OpMod: case TokenType::OpNot: case TokenType::OpBitNot: case TokenType::OpLsh: case TokenType::OpRsh: + case TokenType::OpEql: case TokenType::OpNeq: case TokenType::OpGreater: case TokenType::OpLess: case TokenType::OpGeq: + case TokenType::OpLeq: case TokenType::OpAnd: case TokenType::OpOr: case TokenType::OpBitXor: case TokenType::OpBitAnd: + case TokenType::OpBitOr: case TokenType::OpInc: case TokenType::OpDec: + case TokenType::OpAddAssign: + case TokenType::OpSubAssign: + case TokenType::OpMulAssign: + case TokenType::OpDivAssign: + case TokenType::OpModAssign: + case TokenType::OpShlAssign: + case TokenType::OpShrAssign: + case TokenType::OpOrAssign: + case TokenType::OpAndAssign: + case TokenType::OpXorAssign: + + // Note(tfoley): A bit of a hack: + case TokenType::Comma: + case TokenType::OpAssign: + break; + + // Note(tfoley): Even more of a hack! + case TokenType::QuestionMark: + if (AdvanceIf(parser, TokenType::Colon)) + { + // Concat : onto ? + nameToken.Content = UnownedStringSlice::fromLiteral("?:"); + break; + } + ; // fall-thru + default: + parser->sink->diagnose(nameToken.loc, Diagnostics::invalidOperator, nameToken); + break; + } + + return NameLoc( + getName(parser, nameToken.Content), + nameToken.loc); + } + else + { + nameToken = parser->ReadToken(TokenType::Identifier); + return NameLoc(nameToken); + } + } + + // A "declarator" as used in C-style languages + struct Declarator : RefObject + { + // Different cases of declarator appear as "flavors" here + enum class Flavor + { + name, + Pointer, + Array, + }; + Flavor flavor; + }; + + // The most common case of declarator uses a simple name + struct NameDeclarator : Declarator + { + NameLoc nameAndLoc; + }; + + // A declarator that declares a pointer type + struct PointerDeclarator : Declarator + { + // location of the `*` token + SourceLoc starLoc; + + RefPtr inner; + }; + + // A declarator that declares an array type + struct ArrayDeclarator : Declarator + { + RefPtr inner; + + // location of the `[` token + SourceLoc openBracketLoc; + + // The expression that yields the element count, or NULL + RefPtr elementCountExpr; + }; + + // "Unwrapped" information about a declarator + struct DeclaratorInfo + { + RefPtr typeSpec; + NameLoc nameAndLoc; + RefPtr semantics; + RefPtr initializer; + }; + + // Add a member declaration to its container, and ensure that its + // parent link is set up correctly. + static void AddMember(RefPtr container, RefPtr member) + { + if (container) + { + member->ParentDecl = container.Ptr(); + container->Members.add(member); + + container->memberDictionaryIsValid = false; + } + } + + static void AddMember(RefPtr scope, RefPtr member) + { + if (scope) + { + AddMember(scope->containerDecl, member); + } + } + + static RefPtr ParseGenericParamDecl( + Parser* parser, + RefPtr genericDecl) + { + // simple syntax to introduce a value parameter + if (AdvanceIf(parser, "let")) + { + // default case is a type parameter + auto paramDecl = new GenericValueParamDecl(); + paramDecl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); + if (AdvanceIf(parser, TokenType::Colon)) + { + paramDecl->type = parser->ParseTypeExp(); + } + if (AdvanceIf(parser, TokenType::OpAssign)) + { + paramDecl->initExpr = parser->ParseInitExpr(); + } + return paramDecl; + } + else + { + // default case is a type parameter + RefPtr paramDecl = new GenericTypeParamDecl(); + parser->FillPosition(paramDecl); + paramDecl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); + if (AdvanceIf(parser, TokenType::Colon)) + { + // The user is apply a constraint to this type parameter... + + auto paramConstraint = new GenericTypeConstraintDecl(); + parser->FillPosition(paramConstraint); + + auto paramType = DeclRefType::Create( + parser->getSession(), + DeclRef(paramDecl, nullptr)); + + auto paramTypeExpr = new SharedTypeExpr(); + paramTypeExpr->loc = paramDecl->loc; + paramTypeExpr->base.type = paramType; + paramTypeExpr->type = QualType(getTypeType(paramType)); + + paramConstraint->sub = TypeExp(paramTypeExpr); + paramConstraint->sup = parser->ParseTypeExp(); + + AddMember(genericDecl, paramConstraint); + + + } + if (AdvanceIf(parser, TokenType::OpAssign)) + { + paramDecl->initType = parser->ParseTypeExp(); + } + return paramDecl; + } + } + + template + static void ParseGenericDeclImpl( + Parser* parser, GenericDecl* decl, const TFunc & parseInnerFunc) + { + parser->ReadToken(TokenType::OpLess); + parser->genericDepth++; + while (!parser->LookAheadToken(TokenType::OpGreater)) + { + AddMember(decl, ParseGenericParamDecl(parser, decl)); + + if (parser->LookAheadToken(TokenType::OpGreater)) + break; + + parser->ReadToken(TokenType::Comma); + } + parser->genericDepth--; + parser->ReadToken(TokenType::OpGreater); + decl->inner = parseInnerFunc(decl); + decl->inner->ParentDecl = decl; + + // A generic decl hijacks the name of the declaration + // it wraps, so that lookup can find it. + if (decl->inner) + { + decl->nameAndLoc = decl->inner->nameAndLoc; + decl->loc = decl->inner->loc; + } + } + + template + static RefPtr parseOptGenericDecl( + Parser* parser, const ParseFunc& parseInner) + { + // TODO: may want more advanced disambiguation than this... + if (parser->LookAheadToken(TokenType::OpLess)) + { + RefPtr genericDecl = new GenericDecl(); + parser->FillPosition(genericDecl); + parser->PushScope(genericDecl); + ParseGenericDeclImpl(parser, genericDecl, parseInner); + parser->PopScope(); + return genericDecl; + } + else + { + return parseInner(nullptr); + } + } + + static RefPtr ParseGenericDecl(Parser* parser, void*) + { + RefPtr decl = new GenericDecl(); + parser->FillPosition(decl.Ptr()); + parser->PushScope(decl.Ptr()); + ParseGenericDeclImpl(parser, decl.Ptr(), [=](GenericDecl* genDecl) {return ParseSingleDecl(parser, genDecl); }); + parser->PopScope(); + return decl; + } + + static void parseParameterList( + Parser* parser, + RefPtr decl) + { + parser->ReadToken(TokenType::LParent); + + // Allow a declaration to use the keyword `void` for a parameter list, + // since that was required in ancient C, and continues to be supported + // in a bunc hof its derivatives even if it is a Bad Design Choice + // + // TODO: conditionalize this so we don't keep this around for "pure" + // Slang code + if( parser->LookAheadToken("void") && parser->LookAheadToken(TokenType::RParent, 1) ) + { + parser->ReadToken("void"); + parser->ReadToken(TokenType::RParent); + return; + } + + while (!AdvanceIfMatch(parser, TokenType::RParent)) + { + AddMember(decl, parser->ParseParameter()); + if (AdvanceIf(parser, TokenType::RParent)) + break; + parser->ReadToken(TokenType::Comma); + } + } + + // systematically replace all scopes in an expression tree + class ReplaceScopeVisitor : public ExprVisitor + { + public: + RefPtr scope; + void visitDeclRefExpr(DeclRefExpr* expr) + { + expr->scope = scope; + } + void visitGenericAppExpr(GenericAppExpr * expr) + { + expr->FunctionExpr->accept(this, nullptr); + for (auto arg : expr->Arguments) + arg->accept(this, nullptr); + } + void visitIndexExpr(IndexExpr * expr) + { + expr->BaseExpression->accept(this, nullptr); + expr->IndexExpression->accept(this, nullptr); + } + void visitMemberExpr(MemberExpr * expr) + { + expr->BaseExpression->accept(this, nullptr); + expr->scope = scope; + } + void visitStaticMemberExpr(StaticMemberExpr * expr) + { + expr->BaseExpression->accept(this, nullptr); + expr->scope = scope; + } + void visitExpr(Expr* /*expr*/) + {} + }; + + /// Parse an optional body statement for a declaration that can have a body. + static RefPtr parseOptBody(Parser* parser) + { + if (AdvanceIf(parser, TokenType::Semicolon)) + { + // empty body + return nullptr; + } + else + { + return parser->parseBlockStatement(); + } + } + + /// Complete parsing of a function using traditional (C-like) declarator syntax + static RefPtr parseTraditionalFuncDecl( + Parser* parser, + DeclaratorInfo const& declaratorInfo) + { + RefPtr decl = new FuncDecl(); + parser->FillPosition(decl.Ptr()); + decl->loc = declaratorInfo.nameAndLoc.loc; + decl->nameAndLoc = declaratorInfo.nameAndLoc; + + return parseOptGenericDecl(parser, [&](GenericDecl*) + { + // HACK: The return type of the function will already have been + // parsed in a scope that didn't include the function's generic + // parameters. + // + // We will use a visitor here to try and replace the scope associated + // with any name expressiosn in the reuslt type. + // + // TODO: This should be fixed by not associating scopes with + // such expressions at parse time, and instead pushing down scopes + // as part of the state during semantic checking. + // + ReplaceScopeVisitor replaceScopeVisitor; + replaceScopeVisitor.scope = parser->currentScope; + declaratorInfo.typeSpec->accept(&replaceScopeVisitor, nullptr); + + decl->ReturnType = TypeExp(declaratorInfo.typeSpec); + + parser->PushScope(decl); + + parseParameterList(parser, decl); + ParseOptSemantics(parser, decl.Ptr()); + decl->Body = parseOptBody(parser); + + parser->PopScope(); + + return decl; + }); + } + + static RefPtr CreateVarDeclForContext( + ContainerDecl* containerDecl ) + { + if (as(containerDecl)) + { + // Function parameters always use their dedicated syntax class. + // + return new ParamDecl(); + } + else + { + // Globals, locals, and member variables all use the same syntax class. + // + return new VarDecl(); + } + } + + // Add modifiers to the end of the modifier list for a declaration + void AddModifiers(Decl* decl, RefPtr modifiers) + { + if (!modifiers) + return; + + RefPtr* link = &decl->modifiers.first; + while (*link) + { + link = &(*link)->next; + } + *link = modifiers; + } + + static Name* generateName(Parser* parser, String const& base) + { + // TODO: somehow mangle the name to avoid clashes + return getName(parser, "SLANG_" + base); + } + + static Name* generateName(Parser* parser) + { + return generateName(parser, "anonymous_" + String(parser->anonymousCounter++)); + } + + + // Set up a variable declaration based on what we saw in its declarator... + static void CompleteVarDecl( + Parser* parser, + RefPtr decl, + DeclaratorInfo const& declaratorInfo) + { + parser->FillPosition(decl.Ptr()); + + if( !declaratorInfo.nameAndLoc.name ) + { + // HACK(tfoley): we always give a name, even if the declarator didn't include one... :( + decl->nameAndLoc = NameLoc(generateName(parser)); + } + else + { + decl->loc = declaratorInfo.nameAndLoc.loc; + decl->nameAndLoc = declaratorInfo.nameAndLoc; + } + decl->type = TypeExp(declaratorInfo.typeSpec); + + AddModifiers(decl.Ptr(), declaratorInfo.semantics); + + decl->initExpr = declaratorInfo.initializer; + } + + static RefPtr ParseDeclarator(Parser* parser); + + static RefPtr ParseDirectAbstractDeclarator( + Parser* parser) + { + RefPtr declarator; + switch( parser->tokenReader.PeekTokenType() ) + { + case TokenType::Identifier: + { + auto nameDeclarator = new NameDeclarator(); + nameDeclarator->flavor = Declarator::Flavor::name; + nameDeclarator->nameAndLoc = ParseDeclName(parser); + declarator = nameDeclarator; + } + break; + + case TokenType::LParent: + { + // Note(tfoley): This is a point where disambiguation is required. + // We could be looking at an abstract declarator for a function-type + // parameter: + // + // void F( int(int) ); + // + // Or we could be looking at the use of parenthesese in an ordinary + // declarator: + // + // void (*f)(int); + // + // The difference really doesn't matter right now, but we err in + // the direction of assuming the second case. + parser->ReadToken(TokenType::LParent); + declarator = ParseDeclarator(parser); + parser->ReadToken(TokenType::RParent); + } + break; + + default: + // an empty declarator is allowed + return nullptr; + } + + // postifx additions + for( ;;) + { + switch( parser->tokenReader.PeekTokenType() ) + { + case TokenType::LBracket: + { + auto arrayDeclarator = new ArrayDeclarator(); + arrayDeclarator->openBracketLoc = parser->tokenReader.PeekLoc(); + arrayDeclarator->flavor = Declarator::Flavor::Array; + arrayDeclarator->inner = declarator; + + parser->ReadToken(TokenType::LBracket); + if( parser->tokenReader.PeekTokenType() != TokenType::RBracket ) + { + arrayDeclarator->elementCountExpr = parser->ParseExpression(); + } + parser->ReadToken(TokenType::RBracket); + + declarator = arrayDeclarator; + continue; + } + + case TokenType::LParent: + break; + + default: + break; + } + + break; + } + + return declarator; + } + + // Parse a declarator (or at least as much of one as we support) + static RefPtr ParseDeclarator( + Parser* parser) + { + if( parser->tokenReader.PeekTokenType() == TokenType::OpMul ) + { + auto ptrDeclarator = new PointerDeclarator(); + ptrDeclarator->starLoc = parser->tokenReader.PeekLoc(); + ptrDeclarator->flavor = Declarator::Flavor::Pointer; + + parser->ReadToken(TokenType::OpMul); + + // TODO(tfoley): allow qualifiers like `const` here? + + ptrDeclarator->inner = ParseDeclarator(parser); + return ptrDeclarator; + } + else + { + return ParseDirectAbstractDeclarator(parser); + } + } + + // A declarator plus optional semantics and initializer + struct InitDeclarator + { + RefPtr declarator; + RefPtr semantics; + RefPtr initializer; + }; + + // Parse a declarator plus optional semantics + static InitDeclarator ParseSemanticDeclarator( + Parser* parser) + { + InitDeclarator result; + result.declarator = ParseDeclarator(parser); + result.semantics = ParseOptSemantics(parser); + return result; + } + + // Parse a declarator plus optional semantics and initializer + static InitDeclarator ParseInitDeclarator( + Parser* parser) + { + InitDeclarator result = ParseSemanticDeclarator(parser); + if (AdvanceIf(parser, TokenType::OpAssign)) + { + result.initializer = parser->ParseInitExpr(); + } + return result; + } + + static void UnwrapDeclarator( + RefPtr declarator, + DeclaratorInfo* ioInfo) + { + while( declarator ) + { + switch(declarator->flavor) + { + case Declarator::Flavor::name: + { + auto nameDeclarator = (NameDeclarator*) declarator.Ptr(); + ioInfo->nameAndLoc = nameDeclarator->nameAndLoc; + return; + } + break; + + case Declarator::Flavor::Pointer: + { + auto ptrDeclarator = (PointerDeclarator*) declarator.Ptr(); + + // TODO(tfoley): we don't support pointers for now + // ioInfo->typeSpec = new PointerTypeExpr(ioInfo->typeSpec); + + declarator = ptrDeclarator->inner; + } + break; + + case Declarator::Flavor::Array: + { + // TODO(tfoley): we don't support pointers for now + auto arrayDeclarator = (ArrayDeclarator*) declarator.Ptr(); + + auto arrayTypeExpr = new IndexExpr(); + arrayTypeExpr->loc = arrayDeclarator->openBracketLoc; + arrayTypeExpr->BaseExpression = ioInfo->typeSpec; + arrayTypeExpr->IndexExpression = arrayDeclarator->elementCountExpr; + ioInfo->typeSpec = arrayTypeExpr; + + declarator = arrayDeclarator->inner; + } + break; + + default: + SLANG_UNREACHABLE("all cases handled"); + break; + } + } + } + + static void UnwrapDeclarator( + InitDeclarator const& initDeclarator, + DeclaratorInfo* ioInfo) + { + UnwrapDeclarator(initDeclarator.declarator, ioInfo); + ioInfo->semantics = initDeclarator.semantics; + ioInfo->initializer = initDeclarator.initializer; + } + + // Either a single declaration, or a group of them + struct DeclGroupBuilder + { + SourceLoc startPosition; + RefPtr decl; + RefPtr group; + + // Add a new declaration to the potential group + void addDecl( + RefPtr newDecl) + { + SLANG_ASSERT(newDecl); + + if( decl ) + { + group = new DeclGroup(); + group->loc = startPosition; + group->decls.add(decl); + decl = nullptr; + } + + if( group ) + { + group->decls.add(newDecl); + } + else + { + decl = newDecl; + } + } + + RefPtr getResult() + { + if(group) return group; + return decl; + } + }; + + // Pares an argument to an application of a generic + RefPtr ParseGenericArg(Parser* parser) + { + return parser->ParseArgExpr(); + } + + // Create a type expression that will refer to the given declaration + static RefPtr + createDeclRefType(Parser* parser, RefPtr decl) + { + // For now we just construct an expression that + // will look up the given declaration by name. + // + // TODO: do this better, e.g. by filling in the `declRef` field directly + + auto expr = new VarExpr(); + expr->scope = parser->currentScope.Ptr(); + expr->loc = decl->getNameLoc(); + expr->name = decl->getName(); + return expr; + } + + // Representation for a parsed type specifier, which might + // include a declaration (e.g., of a `struct` type) + struct TypeSpec + { + // If the type-spec declared something, then put it here + RefPtr decl; + + // Put the resulting expression (which should evaluate to a type) here + RefPtr expr; + }; + + static RefPtr parseGenericApp( + Parser* parser, + RefPtr base) + { + RefPtr genericApp = new GenericAppExpr(); + + parser->FillPosition(genericApp.Ptr()); // set up scope for lookup + genericApp->FunctionExpr = base; + parser->ReadToken(TokenType::OpLess); + parser->genericDepth++; + // For now assume all generics have at least one argument + genericApp->Arguments.add(ParseGenericArg(parser)); + while (AdvanceIf(parser, TokenType::Comma)) + { + genericApp->Arguments.add(ParseGenericArg(parser)); + } + parser->genericDepth--; + + if (parser->tokenReader.PeekToken().type == TokenType::OpRsh) + { + parser->tokenReader.PeekToken().type = TokenType::OpGreater; + parser->tokenReader.PeekToken().loc.setRaw(parser->tokenReader.PeekToken().loc.getRaw() + 1); + } + else if (parser->LookAheadToken(TokenType::OpGreater)) + parser->ReadToken(TokenType::OpGreater); + else + parser->sink->diagnose(parser->tokenReader.PeekToken(), Diagnostics::tokenTypeExpected, "'>'"); + return genericApp; + } + + static bool isGenericName(Parser* parser, Name* name) + { + auto lookupResult = lookUp( + parser->getSession(), + nullptr, // no semantics visitor available yet + name, + parser->currentScope); + if (!lookupResult.isValid() || lookupResult.isOverloaded()) + return false; + + return lookupResult.item.declRef.is(); + } + + static RefPtr tryParseGenericApp( + Parser* parser, + RefPtr base) + { + Name * baseName = nullptr; + if (auto varExpr = as(base)) + baseName = varExpr->name; + // if base is a known generics, parse as generics + if (baseName && isGenericName(parser, baseName)) + return parseGenericApp(parser, base); + + // otherwise, we speculate as generics, and fallback to comparison when parsing failed + TokenSpan tokenSpan; + tokenSpan.mBegin = parser->tokenReader.mCursor; + tokenSpan.mEnd = parser->tokenReader.mEnd; + DiagnosticSink newSink; + newSink.sourceManager = parser->sink->sourceManager; + Parser newParser(*parser); + newParser.sink = &newSink; + auto speculateParseRs = parseGenericApp(&newParser, base); + if (newSink.errorCount == 0) + { + // disambiguate based on FOLLOW set + switch (peekTokenType(&newParser)) + { + case TokenType::Dot: + case TokenType::LParent: + case TokenType::RParent: + case TokenType::RBracket: + case TokenType::Colon: + case TokenType::Comma: + case TokenType::QuestionMark: + case TokenType::Semicolon: + case TokenType::OpEql: + case TokenType::OpNeq: + { + return parseGenericApp(parser, base); + } + } + } + return base; + } + static RefPtr parseMemberType(Parser * parser, RefPtr base) + { + // When called the :: or . have been consumed, so don't need to consume here. + + RefPtr memberExpr = new MemberExpr(); + + parser->FillPosition(memberExpr.Ptr()); + memberExpr->BaseExpression = base; + memberExpr->name = expectIdentifier(parser).name; + return memberExpr; + } + + // Parse option `[]` braces after a type expression, that indicate an array type + static RefPtr parsePostfixTypeSuffix( + Parser* parser, + RefPtr inTypeExpr) + { + auto typeExpr = inTypeExpr; + while (parser->LookAheadToken(TokenType::LBracket)) + { + RefPtr arrType = new IndexExpr(); + arrType->loc = typeExpr->loc; + arrType->BaseExpression = typeExpr; + parser->ReadToken(TokenType::LBracket); + if (!parser->LookAheadToken(TokenType::RBracket)) + { + arrType->IndexExpression = parser->ParseExpression(); + } + parser->ReadToken(TokenType::RBracket); + typeExpr = arrType; + } + return typeExpr; + } + + static RefPtr parseTaggedUnionType(Parser* parser) + { + RefPtr taggedUnionType = new TaggedUnionTypeExpr(); + + parser->ReadToken(TokenType::LParent); + while(!AdvanceIfMatch(parser, TokenType::RParent)) + { + auto caseType = parser->ParseTypeExp(); + taggedUnionType->caseTypes.add(caseType); + + if(AdvanceIf(parser, TokenType::RParent)) + break; + + parser->ReadToken(TokenType::Comma); + } + + return taggedUnionType; + } + + static TypeSpec parseTypeSpec(Parser* parser) + { + TypeSpec typeSpec; + + // We may see a `struct` (or `enum` or `class`) tag specified here, and need to act accordingly + // + // TODO(tfoley): Handle the case where the user is just using `struct` + // as a way to name an existing struct "tag" (e.g., `struct Foo foo;`) + // + // TODO: We should really make these keywords be registered like any other + // syntax category, rather than be special-cased here. The main issue here + // is that we need to allow them to be used as type specifiers, as in: + // + // struct Foo { int x } foo; + // + // The ideal answer would be to register certain keywords as being able + // to parse a type specifier, and look for those keywords here. + // We should ideally add special case logic that bails out of declarator + // parsing iff we have one of these kinds of type specifiers and the + // closing `}` is at the end of its line, as a bit of a special case + // to allow the common idiom. + // + if( parser->LookAheadToken("struct") ) + { + auto decl = parser->ParseStruct(); + typeSpec.decl = decl; + typeSpec.expr = createDeclRefType(parser, decl); + return typeSpec; + } + else if( parser->LookAheadToken("class") ) + { + auto decl = parser->ParseClass(); + typeSpec.decl = decl; + typeSpec.expr = createDeclRefType(parser, decl); + return typeSpec; + } + else if(parser->LookAheadToken("enum")) + { + auto decl = parseEnumDecl(parser); + typeSpec.decl = decl; + typeSpec.expr = createDeclRefType(parser, decl); + return typeSpec; + } + else if(AdvanceIf(parser, "__TaggedUnion")) + { + typeSpec.expr = parseTaggedUnionType(parser); + return typeSpec; + } + + Token typeName = parser->ReadToken(TokenType::Identifier); + + auto basicType = new VarExpr(); + basicType->scope = parser->currentScope.Ptr(); + basicType->loc = typeName.loc; + basicType->name = typeName.getNameOrNull(); + + RefPtr typeExpr = basicType; + + bool shouldLoop = true; + while (shouldLoop) + { + switch (peekTokenType(parser)) + { + case TokenType::OpLess: + typeExpr = parseGenericApp(parser, typeExpr); + break; + case TokenType::Scope: + parser->ReadToken(TokenType::Scope); + typeExpr = parseMemberType(parser, typeExpr); + break; + case TokenType::Dot: + parser->ReadToken(TokenType::Dot); + typeExpr = parseMemberType(parser, typeExpr); + break; + default: + shouldLoop = false; + } + } + + typeSpec.expr = typeExpr; + return typeSpec; + } + + static RefPtr ParseDeclaratorDecl( + Parser* parser, + ContainerDecl* containerDecl) + { + SourceLoc startPosition = parser->tokenReader.PeekLoc(); + + auto typeSpec = parseTypeSpec(parser); + + // We may need to build up multiple declarations in a group, + // but the common case will be when we have just a single + // declaration + DeclGroupBuilder declGroupBuilder; + declGroupBuilder.startPosition = startPosition; + + // The type specifier may include a declaration. E.g., + // it might declare a `struct` type. + if(typeSpec.decl) + declGroupBuilder.addDecl(typeSpec.decl); + + if( AdvanceIf(parser, TokenType::Semicolon) ) + { + // No actual variable is being declared here, but + // that might not be an error. + + auto result = declGroupBuilder.getResult(); + if( !result ) + { + parser->sink->diagnose(startPosition, Diagnostics::declarationDidntDeclareAnything); + } + return result; + } + + // It is possible that we have a plain `struct`, `enum`, + // or similar declaration that isn't being used to declare + // any variable, and the user didn't put a trailing + // semicolon on it: + // + // struct Batman + // { + // int cape; + // } + // + // We want to allow this syntax (rather than give an + // inscrutable error), but also support the less common + // idiom where that declaration is used as part of + // a variable declaration: + // + // struct Robin + // { + // float tights; + // } boyWonder; + // + // As a bit of a hack (insofar as it means we aren't + // *really* compatible with arbitrary HLSL code), we + // will check if there are any more tokens on the + // same line as the closing `}`, and if not, we + // will treat it like the end of the declaration. + // + // Just as a safety net, only apply this logic for + // a file that is being passed in as "true" Slang code. + // + if(parser->getSourceLanguage() == SourceLanguage::Slang) + { + if(typeSpec.decl) + { + if(peekToken(parser).flags & TokenFlag::AtStartOfLine) + { + // The token after the `}` is at the start of its + // own line, which means it can't be on the same line. + // + // This means the programmer probably wants to + // just treat this as a declaration. + return declGroupBuilder.getResult(); + } + } + } + + + InitDeclarator initDeclarator = ParseInitDeclarator(parser); + + DeclaratorInfo declaratorInfo; + declaratorInfo.typeSpec = typeSpec.expr; + + + // Rather than parse function declarators properly for now, + // we'll just do a quick disambiguation here. This won't + // matter unless we actually decide to support function-type parameters, + // using C syntax. + // + if ((parser->tokenReader.PeekTokenType() == TokenType::LParent || + parser->tokenReader.PeekTokenType() == TokenType::OpLess) + + // Only parse as a function if we didn't already see mutually-exclusive + // constructs when parsing the declarator. + && !initDeclarator.initializer + && !initDeclarator.semantics) + { + // Looks like a function, so parse it like one. + UnwrapDeclarator(initDeclarator, &declaratorInfo); + return parseTraditionalFuncDecl(parser, declaratorInfo); + } + + // Otherwise we are looking at a variable declaration, which could be one in a sequence... + + if( AdvanceIf(parser, TokenType::Semicolon) ) + { + // easy case: we only had a single declaration! + UnwrapDeclarator(initDeclarator, &declaratorInfo); + RefPtr firstDecl = CreateVarDeclForContext(containerDecl); + CompleteVarDecl(parser, firstDecl, declaratorInfo); + + declGroupBuilder.addDecl(firstDecl); + return declGroupBuilder.getResult(); + } + + // Otherwise we have multiple declarations in a sequence, and these + // declarations need to somehow share both the type spec and modifiers. + // + // If there are any errors in the type specifier, we only want to hear + // about it once, so we need to share structure rather than just + // clone syntax. + + auto sharedTypeSpec = new SharedTypeExpr(); + sharedTypeSpec->loc = typeSpec.expr->loc; + sharedTypeSpec->base = TypeExp(typeSpec.expr); + + for(;;) + { + declaratorInfo.typeSpec = sharedTypeSpec; + UnwrapDeclarator(initDeclarator, &declaratorInfo); + + RefPtr varDecl = CreateVarDeclForContext(containerDecl); + CompleteVarDecl(parser, varDecl, declaratorInfo); + + declGroupBuilder.addDecl(varDecl); + + // end of the sequence? + if(AdvanceIf(parser, TokenType::Semicolon)) + return declGroupBuilder.getResult(); + + // ad-hoc recovery, to avoid infinite loops + if( parser->isRecovering ) + { + parser->ReadToken(TokenType::Semicolon); + return declGroupBuilder.getResult(); + } + + // Let's default to assuming that a missing `,` + // indicates the end of a declaration, + // where a `;` would be expected, and not + // a continuation of this declaration, where + // a `,` would be expected (this is tailoring + // the diagnostic message a bit). + // + // TODO: a more advanced heuristic here might + // look at whether the next token is on the + // same line, to predict whether `,` or `;` + // would be more likely... + + if (!AdvanceIf(parser, TokenType::Comma)) + { + parser->ReadToken(TokenType::Semicolon); + return declGroupBuilder.getResult(); + } + + // expect another variable declaration... + initDeclarator = ParseInitDeclarator(parser); + } + } + + /// Parse the "register name" part of a `register` or `packoffset` semantic. + /// + /// The syntax matched is: + /// + /// register-name-and-component-mask ::= register-name component-mask? + /// register-name ::= identifier + /// component-mask ::= '.' identifier + /// + static void parseHLSLRegisterNameAndOptionalComponentMask( + Parser* parser, + HLSLLayoutSemantic* semantic) + { + semantic->registerName = parser->ReadToken(TokenType::Identifier); + if (AdvanceIf(parser, TokenType::Dot)) + { + semantic->componentMask = parser->ReadToken(TokenType::Identifier); + } + } + + /// Parse an HLSL `register` semantic. + /// + /// The syntax matched is: + /// + /// register-semantic ::= 'register' '(' register-name-and-component-mask register-space? ')' + /// register-space ::= ',' identifier + /// + static void parseHLSLRegisterSemantic( + Parser* parser, + HLSLRegisterSemantic* semantic) + { + // Read the `register` keyword + semantic->name = parser->ReadToken(TokenType::Identifier); + + // Expect a parenthized list of additional arguments + parser->ReadToken(TokenType::LParent); + + // First argument is a required register name and optional component mask + parseHLSLRegisterNameAndOptionalComponentMask(parser, semantic); + + // Second argument is an optional register space + if(AdvanceIf(parser, TokenType::Comma)) + { + semantic->spaceName = parser->ReadToken(TokenType::Identifier); + } + + parser->ReadToken(TokenType::RParent); + } + + /// Parse an HLSL `packoffset` semantic. + /// + /// The syntax matched is: + /// + /// packoffset-semantic ::= 'packoffset' '(' register-name-and-component-mask ')' + /// + static void parseHLSLPackOffsetSemantic( + Parser* parser, + HLSLPackOffsetSemantic* semantic) + { + // Read the `packoffset` keyword + semantic->name = parser->ReadToken(TokenType::Identifier); + + // Expect a parenthized list of additional arguments + parser->ReadToken(TokenType::LParent); + + // First and only argument is a required register name and optional component mask + parseHLSLRegisterNameAndOptionalComponentMask(parser, semantic); + + parser->ReadToken(TokenType::RParent); + + parser->sink->diagnose(semantic, Diagnostics::packOffsetNotSupported); + } + + // + // semantic ::= identifier ( '(' args ')' )? + // + static RefPtr ParseSemantic( + Parser* parser) + { + if (parser->LookAheadToken("register")) + { + RefPtr semantic = new HLSLRegisterSemantic(); + parser->FillPosition(semantic); + parseHLSLRegisterSemantic(parser, semantic.Ptr()); + return semantic; + } + else if (parser->LookAheadToken("packoffset")) + { + RefPtr semantic = new HLSLPackOffsetSemantic(); + parser->FillPosition(semantic); + parseHLSLPackOffsetSemantic(parser, semantic.Ptr()); + return semantic; + } + else if (parser->LookAheadToken(TokenType::Identifier)) + { + RefPtr semantic = new HLSLSimpleSemantic(); + parser->FillPosition(semantic); + semantic->name = parser->ReadToken(TokenType::Identifier); + return semantic; + } + else + { + // expect an identifier, just to produce an error message + parser->ReadToken(TokenType::Identifier); + return nullptr; + } + } + + // + // opt-semantics ::= (':' semantic)* + // + static RefPtr ParseOptSemantics( + Parser* parser) + { + if (!AdvanceIf(parser, TokenType::Colon)) + return nullptr; + + RefPtr result; + RefPtr* link = &result; + SLANG_ASSERT(!*link); + + for (;;) + { + RefPtr semantic = ParseSemantic(parser); + if (semantic) + { + *link = semantic; + link = &semantic->next; + } + + // If we see another `:`, then that means there + // is yet another semantic to be processed. + // Otherwise we assume we are at the end of the list. + // + // TODO: This could produce sub-optimal diagnostics + // when the user *meant* to apply multiple semantics + // to a single declaration: + // + // Foo foo : register(t0) register(s0); + // ^ + // missing ':' here | + // + // However, that is an uncommon occurence, and trying + // to continue parsing semantics here even if we didn't + // see a colon forces us to be careful about + // avoiding an infinite loop here. + if (!AdvanceIf(parser, TokenType::Colon)) + { + return result; + } + } + + } + + + static void ParseOptSemantics( + Parser* parser, + Decl* decl) + { + AddModifiers(decl, ParseOptSemantics(parser)); + } + + static RefPtr ParseHLSLBufferDecl( + Parser* parser, + String bufferWrapperTypeName) + { + // An HLSL declaration of a constant buffer like this: + // + // cbuffer Foo : register(b0) { int a; float b; }; + // + // is treated as syntax sugar for a type declaration + // and then a global variable declaration using that type: + // + // struct $anonymous { int a; float b; }; + // ConstantBuffer<$anonymous> Foo; + // + // where `$anonymous` is a fresh name, and the variable + // declaration is made to be "transparent" so that lookup + // will see through it to the members inside. + + auto bufferWrapperTypeNamePos = parser->tokenReader.PeekLoc(); + + // We are going to represent each buffer as a pair of declarations. + // The first is a type declaration that holds all the members, while + // the second is a variable declaration that uses the buffer type. + RefPtr bufferDataTypeDecl = new StructDecl(); + RefPtr bufferVarDecl = new VarDecl(); + + // Both declarations will have a location that points to the name + parser->FillPosition(bufferDataTypeDecl.Ptr()); + parser->FillPosition(bufferVarDecl.Ptr()); + + auto reflectionNameToken = parser->ReadToken(TokenType::Identifier); + + // Attach the reflection name to the block so we can use it + auto reflectionNameModifier = new ParameterGroupReflectionName(); + reflectionNameModifier->nameAndLoc = NameLoc(reflectionNameToken); + addModifier(bufferVarDecl, reflectionNameModifier); + + // Both the buffer variable and its type need to have names generated + bufferVarDecl->nameAndLoc.name = generateName(parser, "parameterGroup_" + String(reflectionNameToken.Content)); + bufferDataTypeDecl->nameAndLoc.name = generateName(parser, "ParameterGroup_" + String(reflectionNameToken.Content)); + + addModifier(bufferDataTypeDecl, new ImplicitParameterGroupElementTypeModifier()); + addModifier(bufferVarDecl, new ImplicitParameterGroupVariableModifier()); + + // TODO(tfoley): We end up constructing unchecked syntax here that + // is expected to type check into the right form, but it might be + // cleaner to have a more explicit desugaring pass where we parse + // these constructs directly into the AST and *then* desugar them. + + // Construct a type expression to reference the buffer data type + auto bufferDataTypeExpr = new VarExpr(); + bufferDataTypeExpr->loc = bufferDataTypeDecl->loc; + bufferDataTypeExpr->name = bufferDataTypeDecl->nameAndLoc.name; + bufferDataTypeExpr->scope = parser->currentScope.Ptr(); + + // Construct a type expression to reference the type constructor + auto bufferWrapperTypeExpr = new VarExpr(); + bufferWrapperTypeExpr->loc = bufferWrapperTypeNamePos; + bufferWrapperTypeExpr->name = getName(parser, bufferWrapperTypeName); + + // Always need to look this up in the outer scope, + // so that it won't collide with, e.g., a local variable called `ConstantBuffer` + bufferWrapperTypeExpr->scope = parser->outerScope; + + // Construct a type expression that represents the type for the variable, + // which is the wrapper type applied to the data type + auto bufferVarTypeExpr = new GenericAppExpr(); + bufferVarTypeExpr->loc = bufferVarDecl->loc; + bufferVarTypeExpr->FunctionExpr = bufferWrapperTypeExpr; + bufferVarTypeExpr->Arguments.add(bufferDataTypeExpr); + + bufferVarDecl->type.exp = bufferVarTypeExpr; + + // Any semantics applied to the buffer declaration are taken as applying + // to the variable instead. + ParseOptSemantics(parser, bufferVarDecl.Ptr()); + + // The declarations in the body belong to the data type. + parseAggTypeDeclBody(parser, bufferDataTypeDecl.Ptr()); + + // All HLSL buffer declarations are "transparent" in that their + // members are implicitly made visible in the parent scope. + // We achieve this by applying the transparent modifier to the variable. + auto transparentModifier = new TransparentModifier(); + transparentModifier->next = bufferVarDecl->modifiers.first; + bufferVarDecl->modifiers.first = transparentModifier; + + // Because we are constructing two declarations, we have a thorny + // issue that were are only supposed to return one. + // For now we handle this by adding the type declaration to + // the current scope manually, and then returning the variable + // declaration. + // + // Note: this means that any modifiers that have already been parsed + // will get attached to the variable declaration, not the type. + // There might be cases where we need to shuffle things around. + + AddMember(parser->currentScope, bufferDataTypeDecl); + + return bufferVarDecl; + } + + static RefPtr parseHLSLCBufferDecl( + Parser* parser, void* /*userData*/) + { + return ParseHLSLBufferDecl(parser, "ConstantBuffer"); + } + + static RefPtr parseHLSLTBufferDecl( + Parser* parser, void* /*userData*/) + { + return ParseHLSLBufferDecl(parser, "TextureBuffer"); + } + + static void parseOptionalInheritanceClause(Parser* parser, AggTypeDeclBase* decl) + { + if (AdvanceIf(parser, TokenType::Colon)) + { + do + { + auto base = parser->ParseTypeExp(); + + auto inheritanceDecl = new InheritanceDecl(); + inheritanceDecl->loc = base.exp->loc; + inheritanceDecl->nameAndLoc.name = getName(parser, "$inheritance"); + inheritanceDecl->base = base; + + AddMember(decl, inheritanceDecl); + + } while (AdvanceIf(parser, TokenType::Comma)); + } + } + + static RefPtr ParseExtensionDecl(Parser* parser, void* /*userData*/) + { + RefPtr decl = new ExtensionDecl(); + parser->FillPosition(decl.Ptr()); + decl->targetType = parser->ParseTypeExp(); + parseOptionalInheritanceClause(parser, decl); + parseAggTypeDeclBody(parser, decl.Ptr()); + + return decl; + } + + + void parseOptionalGenericConstraints(Parser * parser, ContainerDecl* decl) + { + if (AdvanceIf(parser, TokenType::Colon)) + { + do + { + RefPtr paramConstraint = new GenericTypeConstraintDecl(); + parser->FillPosition(paramConstraint); + + // substitution needs to be filled during check + RefPtr paramType = DeclRefType::Create( + parser->getSession(), + DeclRef(decl, nullptr)); + + RefPtr paramTypeExpr = new SharedTypeExpr(); + paramTypeExpr->loc = decl->loc; + paramTypeExpr->base.type = paramType; + paramTypeExpr->type = QualType(getTypeType(paramType)); + + paramConstraint->sub = TypeExp(paramTypeExpr); + paramConstraint->sup = parser->ParseTypeExp(); + + AddMember(decl, paramConstraint); + } while (AdvanceIf(parser, TokenType::Comma)); + } + } + + RefPtr parseAssocType(Parser * parser, void *) + { + RefPtr assocTypeDecl = new AssocTypeDecl(); + + auto nameToken = parser->ReadToken(TokenType::Identifier); + assocTypeDecl->nameAndLoc = NameLoc(nameToken); + assocTypeDecl->loc = nameToken.loc; + parseOptionalGenericConstraints(parser, assocTypeDecl); + parser->ReadToken(TokenType::Semicolon); + return assocTypeDecl; + } + + RefPtr parseGlobalGenericParamDecl(Parser * parser, void *) + { + RefPtr genParamDecl = new GlobalGenericParamDecl(); + auto nameToken = parser->ReadToken(TokenType::Identifier); + genParamDecl->nameAndLoc = NameLoc(nameToken); + genParamDecl->loc = nameToken.loc; + parseOptionalGenericConstraints(parser, genParamDecl); + parser->ReadToken(TokenType::Semicolon); + return genParamDecl; + } + + static RefPtr parseInterfaceDecl(Parser* parser, void* /*userData*/) + { + RefPtr decl = new InterfaceDecl(); + parser->FillPosition(decl.Ptr()); + decl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); + + parseOptionalInheritanceClause(parser, decl.Ptr()); + + parseAggTypeDeclBody(parser, decl.Ptr()); + + return decl; + } + + static RefPtr parseConstructorDecl(Parser* parser, void* /*userData*/) + { + RefPtr decl = new ConstructorDecl(); + parser->FillPosition(decl.Ptr()); + + // TODO: we need to make sure that all initializers have + // the same name, but that this name doesn't conflict + // with any user-defined names. + // Giving them a name (rather than leaving it null) + // ensures that we can use name-based lookup to find + // all of the initializers on a type (and has + // the potential to unify initializer lookup with + // ordinary member lookup). + decl->nameAndLoc.name = getName(parser, "$init"); + + parseParameterList(parser, decl); + + decl->Body = parseOptBody(parser); + + return decl; + } + + static RefPtr parseAccessorDecl(Parser* parser) + { + Modifiers modifiers = ParseModifiers(parser); + + RefPtr decl; + if( AdvanceIf(parser, "get") ) + { + decl = new GetterDecl(); + } + else if( AdvanceIf(parser, "set") ) + { + decl = new SetterDecl(); + } + else if( AdvanceIf(parser, "ref") ) + { + decl = new RefAccessorDecl(); + } + else + { + Unexpected(parser); + return nullptr; + } + + AddModifiers(decl, modifiers.first); + + if( parser->tokenReader.PeekTokenType() == TokenType::LBrace ) + { + decl->Body = parser->parseBlockStatement(); + } + else + { + parser->ReadToken(TokenType::Semicolon); + } + + return decl; + } + + static RefPtr ParseSubscriptDecl(Parser* parser, void* /*userData*/) + { + RefPtr decl = new SubscriptDecl(); + parser->FillPosition(decl.Ptr()); + + // TODO: the use of this name here is a bit magical... + decl->nameAndLoc.name = getName(parser, "operator[]"); + + parseParameterList(parser, decl); + + if( AdvanceIf(parser, TokenType::RightArrow) ) + { + decl->ReturnType = parser->ParseTypeExp(); + } + + if( AdvanceIf(parser, TokenType::LBrace) ) + { + // We want to parse nested "accessor" declarations + while( !AdvanceIfMatch(parser, TokenType::RBrace) ) + { + auto accessor = parseAccessorDecl(parser); + AddMember(decl, accessor); + } + } + else + { + parser->ReadToken(TokenType::Semicolon); + + // empty body should be treated like `{ get; }` + } + + return decl; + } + + static bool expect(Parser* parser, TokenType tokenType) + { + return parser->ReadToken(tokenType).type == tokenType; + } + + static void parseModernVarDeclBaseCommon( + Parser* parser, + RefPtr decl) + { + parser->FillPosition(decl.Ptr()); + decl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); + + if(AdvanceIf(parser, TokenType::Colon)) + { + decl->type = parser->ParseTypeExp(); + } + + if(AdvanceIf(parser, TokenType::OpAssign)) + { + decl->initExpr = parser->ParseInitExpr(); + } + } + + static void parseModernVarDeclCommon( + Parser* parser, + RefPtr decl) + { + parseModernVarDeclBaseCommon(parser, decl); + expect(parser, TokenType::Semicolon); + } + + static RefPtr parseLetDecl( + Parser* parser, void* /*userData*/) + { + RefPtr decl = new LetDecl(); + parseModernVarDeclCommon(parser, decl); + return decl; + } + + static RefPtr parseVarDecl( + Parser* parser, void* /*userData*/) + { + RefPtr decl = new VarDecl(); + parseModernVarDeclCommon(parser, decl); + return decl; + } + + static RefPtr parseModernParamDecl( + Parser* parser) + { + RefPtr decl = new ParamDecl(); + + // TODO: "modern" parameters should not accept keyword-based + // modifiers and should only accept `[attribute]` syntax for + // modifiers to keep the grammar as simple as possible. + // + // Further, they should accept `out` and `in out`/`inout` + // before the type (e.g., `a: inout float4`). + // + decl->modifiers = ParseModifiers(parser); + parseModernVarDeclBaseCommon(parser, decl); + return decl; + } + + static void parseModernParamList( + Parser* parser, + RefPtr decl) + { + parser->ReadToken(TokenType::LParent); + + while (!AdvanceIfMatch(parser, TokenType::RParent)) + { + AddMember(decl, parseModernParamDecl(parser)); + if (AdvanceIf(parser, TokenType::RParent)) + break; + parser->ReadToken(TokenType::Comma); + } + } + + static RefPtr parseFuncDecl( + Parser* parser, void* /*userData*/) + { + RefPtr decl = new FuncDecl(); + + parser->FillPosition(decl.Ptr()); + decl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); + + return parseOptGenericDecl(parser, [&](GenericDecl*) + { + parser->PushScope(decl.Ptr()); + parseModernParamList(parser, decl); + if(AdvanceIf(parser, TokenType::RightArrow)) + { + decl->ReturnType = parser->ParseTypeExp(); + } + decl->Body = parseOptBody(parser); + parser->PopScope(); + return decl; + }); + } + + static RefPtr parseTypeAliasDecl( + Parser* parser, void* /*userData*/) + { + RefPtr decl = new TypeAliasDecl(); + + parser->FillPosition(decl.Ptr()); + decl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); + + return parseOptGenericDecl(parser, [&](GenericDecl*) + { + if( expect(parser, TokenType::OpAssign) ) + { + decl->type = parser->ParseTypeExp(); + } + expect(parser, TokenType::Semicolon); + return decl; + }); + } + + // This is a catch-all syntax-construction callback to handle cases where + // a piece of syntax is fully defined by the keyword to use, along with + // the class of AST node to construct. + static RefPtr parseSimpleSyntax(Parser* /*parser*/, void* userData) + { + SyntaxClassBase syntaxClass((SyntaxClassBase::ClassInfo*) userData); + return (RefObject*) syntaxClass.createInstanceImpl(); + } + + // Parse a declaration of a keyword that can be used to define further syntax. + static RefPtr parseSyntaxDecl(Parser* parser, void* /*userData*/) + { + // Right now the basic form is: + // + // syntax [: ] [= ]; + // + // - `name` gives the name of the keyword to define. + // - `syntaxClass` is the name of an AST node class that we expect + // this syntax to construct when parsed. + // - `existingKeyword` is the name of an existing keyword that + // the new syntax should be an alias for. + + // First we parse the keyword name. + auto nameAndLoc = expectIdentifier(parser); + + // Next we look for a clause that specified the AST node class. + SyntaxClass syntaxClass; + if (AdvanceIf(parser, TokenType::Colon)) + { + // User is specifying the class that should be construted + auto classNameAndLoc = expectIdentifier(parser); + + syntaxClass = parser->getSession()->findSyntaxClass(classNameAndLoc.name); + } + + // If the user specified a syntax class, then we will default + // to the `parseSimpleSyntax` callback that will just construct + // an instance of that type to represent the keyword in the AST. + SyntaxParseCallback parseCallback = &parseSimpleSyntax; + void* parseUserData = (void*) syntaxClass.classInfo; + + // Next we look for an initializer that will make this keyword + // an alias for some existing keyword. + if (AdvanceIf(parser, TokenType::OpAssign)) + { + auto existingKeywordNameAndLoc = expectIdentifier(parser); + + auto existingSyntax = tryLookUpSyntaxDecl(parser, existingKeywordNameAndLoc.name); + if (!existingSyntax) + { + // TODO: diagnose: keyword did not name syntax + } + else + { + // The user is expecting us to parse our new syntax like + // the existing syntax given, so we need to override + // the callback. + parseCallback = existingSyntax->parseCallback; + parseUserData = existingSyntax->parseUserData; + + // If we don't already have a syntax class specified, then + // we will crib the one from the existing syntax, to ensure + // that we are creating a drop-in alias. + if (!syntaxClass.classInfo) + syntaxClass = existingSyntax->syntaxClass; + } + } + + // It is an error if the user didn't give us either an existing keyword + // to use to the define the callback, or a valid AST node class to construct. + // + // TODO: down the line this should be expanded so that the user can reference + // an existing *function* to use to parse the chosen syntax. + if (!syntaxClass.classInfo) + { + // TODO: diagnose: either a type or an existing keyword needs to be specified + } + + expect(parser, TokenType::Semicolon); + + // TODO: skip creating the declaration if anything failed, just to not screw things + // up for downstream code? + + RefPtr syntaxDecl = new SyntaxDecl(); + syntaxDecl->nameAndLoc = nameAndLoc; + syntaxDecl->loc = nameAndLoc.loc; + syntaxDecl->syntaxClass = syntaxClass; + syntaxDecl->parseCallback = parseCallback; + syntaxDecl->parseUserData = parseUserData; + return syntaxDecl; + } + + // A parameter declaration in an attribute declaration. + // + // We are going to use `name: type` syntax just for simplicty, and let the type + // be optional, because we don't actually need it in all cases. + // + static RefPtr parseAttributeParamDecl(Parser* parser) + { + auto nameAndLoc = expectIdentifier(parser); + + RefPtr paramDecl = new ParamDecl(); + paramDecl->nameAndLoc = nameAndLoc; + + if(AdvanceIf(parser, TokenType::Colon)) + { + paramDecl->type = parser->ParseTypeExp(); + } + + if(AdvanceIf(parser, TokenType::OpAssign)) + { + paramDecl->initExpr = parser->ParseInitExpr(); + } + + return paramDecl; + } + + // Parse declaration of a name to be used for resolving `[attribute(...)]` style modifiers. + // + // These are distinct from `syntax` declarations, because their names don't get added + // to the current scope using their default name. + // + // Also, attribute-specific code doesn't get invokved during parsing. We always parse + // using the default attribute-parsing logic and then all specialized behavior takes + // place during semantic checking. + // + static RefPtr parseAttributeSyntaxDecl(Parser* parser, void* /*userData*/) + { + // Right now the basic form is: + // + // attribute_syntax : ; + // + // - `name` gives the name of the attribute to define. + // - `syntaxClass` is the name of an AST node class that we expect + // this attribute to create when checked. + // - `existingKeyword` is the name of an existing keyword that + // the new syntax should be an alias for. + + expect(parser, TokenType::LBracket); + + // First we parse the attribute name. + auto nameAndLoc = expectIdentifier(parser); + + RefPtr attrDecl = new AttributeDecl(); + if(AdvanceIf(parser, TokenType::LParent)) + { + while(!AdvanceIfMatch(parser, TokenType::RParent)) + { + auto param = parseAttributeParamDecl(parser); + + AddMember(attrDecl, param); + + if(AdvanceIfMatch(parser, TokenType::RParent)) + break; + + expect(parser, TokenType::Comma); + } + } + + expect(parser, TokenType::RBracket); + + // TODO: we should allow parameters to be specified here, to cut down + // on the amount of per-attribute-type logic that has to occur later. + + // Next we look for a clause that specified the AST node class. + SyntaxClass syntaxClass; + if (AdvanceIf(parser, TokenType::Colon)) + { + // User is specifying the class that should be construted + auto classNameAndLoc = expectIdentifier(parser); + + syntaxClass = parser->getSession()->findSyntaxClass(classNameAndLoc.name); + } + else + { + // For now we don't support the alternative approach where + // an existing piece of syntax is named to provide the parsing + // support. + + // TODO: diagnose: a syntax class must be specified. + } + + expect(parser, TokenType::Semicolon); + + // TODO: skip creating the declaration if anything failed, just to not screw things + // up for downstream code? + + attrDecl->nameAndLoc = nameAndLoc; + attrDecl->loc = nameAndLoc.loc; + attrDecl->syntaxClass = syntaxClass; + return attrDecl; + } + + // Finish up work on a declaration that was parsed + static void CompleteDecl( + Parser* /*parser*/, + RefPtr decl, + ContainerDecl* containerDecl, + Modifiers modifiers) + { + // Add any modifiers we parsed before the declaration to the list + // of modifiers on the declaration itself. + // + // We need to be careful, because if `decl` is a generic declaration, + // then we really want the modifiers to apply to the inner declaration. + // + RefPtr declToModify = decl; + if(auto genericDecl = as(decl)) + declToModify = genericDecl->inner; + AddModifiers(declToModify.Ptr(), modifiers.first); + + // Make sure the decl is properly nested inside its lexical parent + if (containerDecl) + { + AddMember(containerDecl, decl); + } + } + + static RefPtr ParseDeclWithModifiers( + Parser* parser, + ContainerDecl* containerDecl, + Modifiers modifiers ) + { + RefPtr decl; + + auto loc = parser->tokenReader.PeekLoc(); + + switch (peekTokenType(parser)) + { + case TokenType::Identifier: + { + // A declaration that starts with an identifier might be: + // + // - A keyword-based declaration (e.g., `cbuffer ...`) + // - The beginning of a type in a declarator-based declaration (e.g., `int ...`) + + // First we will check whether we can use the identifier token + // as a declaration keyword and parse a declaration using + // its associated callback: + RefPtr parsedDecl; + if (tryParseUsingSyntaxDecl(parser, &parsedDecl)) + { + decl = parsedDecl; + break; + } + + // Our final fallback case is to assume that the user is + // probably writing a C-style declarator-based declaration. + decl = ParseDeclaratorDecl(parser, containerDecl); + break; + } + break; + + // It is valid in HLSL/GLSL to have an "empty" declaration + // that consists of just a semicolon. In particular, this + // gets used a lot in GLSL to attach custom semantics to + // shader input or output. + // + case TokenType::Semicolon: + { + advanceToken(parser); + + decl = new EmptyDecl(); + decl->loc = loc; + } + break; + + // If nothing else matched, we try to parse an "ordinary" declarator-based declaration + default: + decl = ParseDeclaratorDecl(parser, containerDecl); + break; + } + + if (decl) + { + if( auto dd = as(decl) ) + { + CompleteDecl(parser, dd, containerDecl, modifiers); + } + else if(auto declGroup = as(decl)) + { + // We are going to add the same modifiers to *all* of these declarations, + // so we want to give later passes a way to detect which modifiers + // were shared, vs. which ones are specific to a single declaration. + + auto sharedModifiers = new SharedModifiers(); + sharedModifiers->next = modifiers.first; + modifiers.first = sharedModifiers; + + for( auto subDecl : declGroup->decls ) + { + CompleteDecl(parser, subDecl, containerDecl, modifiers); + } + } + } + return decl; + } + + static RefPtr ParseDecl( + Parser* parser, + ContainerDecl* containerDecl) + { + Modifiers modifiers = ParseModifiers(parser); + return ParseDeclWithModifiers(parser, containerDecl, modifiers); + } + + static RefPtr ParseSingleDecl( + Parser* parser, + ContainerDecl* containerDecl) + { + auto declBase = ParseDecl(parser, containerDecl); + if(!declBase) + return nullptr; + if( auto decl = as(declBase) ) + { + return decl; + } + else if( auto declGroup = as(declBase) ) + { + if( declGroup->decls.getCount() == 1 ) + { + return declGroup->decls[0]; + } + } + + parser->sink->diagnose(declBase->loc, Diagnostics::unimplemented, "didn't expect multiple declarations here"); + return nullptr; + } + + + // Parse a body consisting of declarations + static void ParseDeclBody( + Parser* parser, + ContainerDecl* containerDecl, + TokenType closingToken) + { + while(!AdvanceIfMatch(parser, closingToken)) + { + ParseDecl(parser, containerDecl); + } + } + + // Parse the `{}`-delimeted body of an aggregate type declaration + static void parseAggTypeDeclBody( + Parser* parser, + AggTypeDeclBase* decl) + { + // TODO: the scope used for the body might need to be + // slightly specialized to deal with the complexity + // of how `this` works. + // + // Alternatively, that complexity can be pushed down + // to semantic analysis so that it doesn't clutter + // things here. + parser->PushScope(decl); + + parser->ReadToken(TokenType::LBrace); + ParseDeclBody(parser, decl, TokenType::RBrace); + + parser->PopScope(); + } + + + void Parser::parseSourceFile(ModuleDecl* program) + { + if (outerScope) + { + currentScope = outerScope; + } + + PushScope(program); + program->loc = tokenReader.PeekLoc(); + program->scope = currentScope; + ParseDeclBody(this, program, TokenType::EndOfFile); + PopScope(); + + SLANG_RELEASE_ASSERT(currentScope == outerScope); + currentScope = nullptr; + } + + RefPtr Parser::ParseStruct() + { + RefPtr rs = new StructDecl(); + FillPosition(rs.Ptr()); + ReadToken("struct"); + + // TODO: support `struct` declaration without tag + rs->nameAndLoc = expectIdentifier(this); + + return parseOptGenericDecl(this, [&](GenericDecl*) + { + // We allow for an inheritance clause on a `struct` + // so that it can conform to interfaces. + parseOptionalInheritanceClause(this, rs.Ptr()); + parseAggTypeDeclBody(this, rs.Ptr()); + return rs; + }); + } + + RefPtr Parser::ParseClass() + { + RefPtr rs = new ClassDecl(); + FillPosition(rs.Ptr()); + ReadToken("class"); + rs->nameAndLoc = expectIdentifier(this); + + parseOptionalInheritanceClause(this, rs.Ptr()); + + parseAggTypeDeclBody(this, rs.Ptr()); + return rs; + } + + static RefPtr parseEnumCaseDecl(Parser* parser) + { + RefPtr decl = new EnumCaseDecl(); + decl->nameAndLoc = expectIdentifier(parser); + + if(AdvanceIf(parser, TokenType::OpAssign)) + { + decl->tagExpr = parser->ParseArgExpr(); + } + + return decl; + } + + static RefPtr parseEnumDecl(Parser* parser) + { + RefPtr decl = new EnumDecl(); + parser->FillPosition(decl); + + parser->ReadToken("enum"); + + // HACK: allow the user to write `enum class` in case + // they are trying to share a header between C++ and Slang. + // + // TODO: diagnose this with a warning some day, and move + // toward deprecating it. + // + AdvanceIf(parser, "class"); + + decl->nameAndLoc = expectIdentifier(parser); + + + return parseOptGenericDecl(parser, [&](GenericDecl*) + { + parseOptionalInheritanceClause(parser, decl); + parser->ReadToken(TokenType::LBrace); + + while(!AdvanceIfMatch(parser, TokenType::RBrace)) + { + RefPtr caseDecl = parseEnumCaseDecl(parser); + AddMember(decl, caseDecl); + + if(AdvanceIf(parser, TokenType::RBrace)) + break; + + parser->ReadToken(TokenType::Comma); + } + return decl; + }); + } + + static RefPtr ParseSwitchStmt(Parser* parser) + { + RefPtr stmt = new SwitchStmt(); + parser->FillPosition(stmt.Ptr()); + parser->ReadToken("switch"); + parser->ReadToken(TokenType::LParent); + stmt->condition = parser->ParseExpression(); + parser->ReadToken(TokenType::RParent); + stmt->body = parser->parseBlockStatement(); + return stmt; + } + + static RefPtr ParseCaseStmt(Parser* parser) + { + RefPtr stmt = new CaseStmt(); + parser->FillPosition(stmt.Ptr()); + parser->ReadToken("case"); + stmt->expr = parser->ParseExpression(); + parser->ReadToken(TokenType::Colon); + return stmt; + } + + static RefPtr ParseDefaultStmt(Parser* parser) + { + RefPtr stmt = new DefaultStmt(); + parser->FillPosition(stmt.Ptr()); + parser->ReadToken("default"); + parser->ReadToken(TokenType::Colon); + return stmt; + } + + static bool isTypeName(Parser* parser, Name* name) + { + auto lookupResult = lookUp( + parser->getSession(), + nullptr, // no semantics visitor available yet + name, + parser->currentScope); + if(!lookupResult.isValid() || lookupResult.isOverloaded()) + return false; + + auto decl = lookupResult.item.declRef.getDecl(); + if( auto typeDecl = as(decl) ) + { + return true; + } + else if( auto typeVarDecl = as(decl) ) + { + return true; + } + else + { + return false; + } + } + + static bool peekTypeName(Parser* parser) + { + if(!parser->LookAheadToken(TokenType::Identifier)) + return false; + + auto name = parser->tokenReader.PeekToken().getName(); + return isTypeName(parser, name); + } + + RefPtr parseCompileTimeForStmt( + Parser* parser) + { + RefPtr scopeDecl = new ScopeDecl(); + RefPtr stmt = new CompileTimeForStmt(); + stmt->scopeDecl = scopeDecl; + + + parser->ReadToken("for"); + parser->ReadToken(TokenType::LParent); + + NameLoc varNameAndLoc = expectIdentifier(parser); + RefPtr varDecl = new VarDecl(); + varDecl->nameAndLoc = varNameAndLoc; + varDecl->loc = varNameAndLoc.loc; + + stmt->varDecl = varDecl; + + parser->ReadToken("in"); + parser->ReadToken("Range"); + parser->ReadToken(TokenType::LParent); + + RefPtr rangeBeginExpr; + RefPtr rangeEndExpr = parser->ParseArgExpr(); + if (AdvanceIf(parser, TokenType::Comma)) + { + rangeBeginExpr = rangeEndExpr; + rangeEndExpr = parser->ParseArgExpr(); + } + + stmt->rangeBeginExpr = rangeBeginExpr; + stmt->rangeEndExpr = rangeEndExpr; + + parser->ReadToken(TokenType::RParent); + parser->ReadToken(TokenType::RParent); + + parser->pushScopeAndSetParent(scopeDecl); + AddMember(parser->currentScope, varDecl); + + stmt->body = parser->ParseStatement(); + + parser->PopScope(); + + return stmt; + } + + RefPtr parseCompileTimeStmt( + Parser* parser) + { + parser->ReadToken(TokenType::Dollar); + if (parser->LookAheadToken("for")) + { + return parseCompileTimeForStmt(parser); + } + else + { + Unexpected(parser); + return nullptr; + } + } + + RefPtr Parser::ParseStatement() + { + auto modifiers = ParseModifiers(this); + + RefPtr statement; + if (LookAheadToken(TokenType::LBrace)) + statement = parseBlockStatement(); + else if (peekTypeName(this)) + statement = parseVarDeclrStatement(modifiers); + else if (LookAheadToken("if")) + statement = parseIfStatement(); + else if (LookAheadToken("for")) + statement = ParseForStatement(); + else if (LookAheadToken("while")) + statement = ParseWhileStatement(); + else if (LookAheadToken("do")) + statement = ParseDoWhileStatement(); + else if (LookAheadToken("break")) + statement = ParseBreakStatement(); + else if (LookAheadToken("continue")) + statement = ParseContinueStatement(); + else if (LookAheadToken("return")) + statement = ParseReturnStatement(); + else if (LookAheadToken("discard")) + { + statement = new DiscardStmt(); + FillPosition(statement.Ptr()); + ReadToken("discard"); + ReadToken(TokenType::Semicolon); + } + else if (LookAheadToken("switch")) + statement = ParseSwitchStmt(this); + else if (LookAheadToken("case")) + statement = ParseCaseStmt(this); + else if (LookAheadToken("default")) + statement = ParseDefaultStmt(this); + else if (LookAheadToken(TokenType::Dollar)) + { + statement = parseCompileTimeStmt(this); + } + else if (LookAheadToken(TokenType::Identifier)) + { + // We might be looking at a local declaration, or an + // expression statement, and we need to figure out which. + // + // We'll solve this with backtracking for now. + + TokenReader::ParsingCursor startPos = tokenReader.getCursor(); + + // Try to parse a type (knowing that the type grammar is + // a subset of the expression grammar, and so this should + // always succeed). + RefPtr type = ParseType(); + // We don't actually care about the type, though, so + // don't retain it + type = nullptr; + + // If the next token after we parsed a type looks like + // we are going to declare a variable, then lets guess + // that this is a declaration. + // + // TODO(tfoley): this wouldn't be robust for more + // general kinds of declarators (notably pointer declarators), + // so we'll need to be careful about this. + if (LookAheadToken(TokenType::Identifier)) + { + // Reset the cursor and try to parse a declaration now. + // Note: the declaration will consume any modifiers + // that had been in place on the statement. + tokenReader.setCursor(startPos); + statement = parseVarDeclrStatement(modifiers); + return statement; + } + + // Fallback: reset and parse an expression + tokenReader.setCursor(startPos); + statement = ParseExpressionStatement(); + } + else if (LookAheadToken(TokenType::Semicolon)) + { + statement = new EmptyStmt(); + FillPosition(statement.Ptr()); + ReadToken(TokenType::Semicolon); + } + else + { + // Default case should always fall back to parsing an expression, + // and then let that detect any errors + statement = ParseExpressionStatement(); + } + + if (statement && !as(statement)) + { + // Install any modifiers onto the statement. + // Note: this path is bypassed in the case of a + // declaration statement, so we don't end up + // doubling up the modifiers. + statement->modifiers = modifiers; + } + + return statement; + } + + RefPtr Parser::parseBlockStatement() + { + RefPtr scopeDecl = new ScopeDecl(); + RefPtr blockStatement = new BlockStmt(); + blockStatement->scopeDecl = scopeDecl; + pushScopeAndSetParent(scopeDecl.Ptr()); + ReadToken(TokenType::LBrace); + + RefPtr body; + + if(!tokenReader.IsAtEnd()) + { + FillPosition(blockStatement.Ptr()); + } + while (!AdvanceIfMatch(this, TokenType::RBrace)) + { + auto stmt = ParseStatement(); + if(stmt) + { + if (!body) + { + body = stmt; + } + else if (auto seqStmt = as(body)) + { + seqStmt->stmts.add(stmt); + } + else + { + RefPtr newBody = new SeqStmt(); + newBody->loc = blockStatement->loc; + newBody->stmts.add(body); + newBody->stmts.add(stmt); + + body = newBody; + } + } + TryRecover(this); + } + PopScope(); + + if(!body) + { + body = new EmptyStmt(); + body->loc = blockStatement->loc; + } + + blockStatement->body = body; + return blockStatement; + } + + RefPtr Parser::parseVarDeclrStatement( + Modifiers modifiers) + { + RefPtrvarDeclrStatement = new DeclStmt(); + + FillPosition(varDeclrStatement.Ptr()); + auto decl = ParseDeclWithModifiers(this, currentScope->containerDecl, modifiers); + varDeclrStatement->decl = decl; + return varDeclrStatement; + } + + RefPtr Parser::parseIfStatement() + { + RefPtr ifStatement = new IfStmt(); + FillPosition(ifStatement.Ptr()); + ReadToken("if"); + ReadToken(TokenType::LParent); + ifStatement->Predicate = ParseExpression(); + ReadToken(TokenType::RParent); + ifStatement->PositiveStatement = ParseStatement(); + if (LookAheadToken("else")) + { + ReadToken("else"); + ifStatement->NegativeStatement = ParseStatement(); + } + return ifStatement; + } + + RefPtr Parser::ParseForStatement() + { + RefPtr scopeDecl = new ScopeDecl(); + + // HLSL implements the bad approach to scoping a `for` loop + // variable, and we want to respect that, but *only* when + // parsing HLSL code. + // + + bool brokenScoping = getSourceLanguage() == SourceLanguage::HLSL; + + // We will create a distinct syntax node class for the unscoped + // case, just so that we can correctly handle it in downstream + // logic. + // + RefPtr stmt; + if (brokenScoping) + { + stmt = new UnscopedForStmt(); + } + else + { + stmt = new ForStmt(); + } + + stmt->scopeDecl = scopeDecl; + + if(!brokenScoping) + pushScopeAndSetParent(scopeDecl.Ptr()); + FillPosition(stmt.Ptr()); + ReadToken("for"); + ReadToken(TokenType::LParent); + if (peekTypeName(this)) + { + stmt->InitialStatement = parseVarDeclrStatement(Modifiers()); + } + else + { + if (!LookAheadToken(TokenType::Semicolon)) + { + stmt->InitialStatement = ParseExpressionStatement(); + } + else + { + ReadToken(TokenType::Semicolon); + } + } + if (!LookAheadToken(TokenType::Semicolon)) + stmt->PredicateExpression = ParseExpression(); + ReadToken(TokenType::Semicolon); + if (!LookAheadToken(TokenType::RParent)) + stmt->SideEffectExpression = ParseExpression(); + ReadToken(TokenType::RParent); + stmt->Statement = ParseStatement(); + + if (!brokenScoping) + PopScope(); + + return stmt; + } + + RefPtr Parser::ParseWhileStatement() + { + RefPtr whileStatement = new WhileStmt(); + FillPosition(whileStatement.Ptr()); + ReadToken("while"); + ReadToken(TokenType::LParent); + whileStatement->Predicate = ParseExpression(); + ReadToken(TokenType::RParent); + whileStatement->Statement = ParseStatement(); + return whileStatement; + } + + RefPtr Parser::ParseDoWhileStatement() + { + RefPtr doWhileStatement = new DoWhileStmt(); + FillPosition(doWhileStatement.Ptr()); + ReadToken("do"); + doWhileStatement->Statement = ParseStatement(); + ReadToken("while"); + ReadToken(TokenType::LParent); + doWhileStatement->Predicate = ParseExpression(); + ReadToken(TokenType::RParent); + ReadToken(TokenType::Semicolon); + return doWhileStatement; + } + + RefPtr Parser::ParseBreakStatement() + { + RefPtr breakStatement = new BreakStmt(); + FillPosition(breakStatement.Ptr()); + ReadToken("break"); + ReadToken(TokenType::Semicolon); + return breakStatement; + } + + RefPtr Parser::ParseContinueStatement() + { + RefPtr continueStatement = new ContinueStmt(); + FillPosition(continueStatement.Ptr()); + ReadToken("continue"); + ReadToken(TokenType::Semicolon); + return continueStatement; + } + + RefPtr Parser::ParseReturnStatement() + { + RefPtr returnStatement = new ReturnStmt(); + FillPosition(returnStatement.Ptr()); + ReadToken("return"); + if (!LookAheadToken(TokenType::Semicolon)) + returnStatement->Expression = ParseExpression(); + ReadToken(TokenType::Semicolon); + return returnStatement; + } + + RefPtr Parser::ParseExpressionStatement() + { + RefPtr statement = new ExpressionStmt(); + + FillPosition(statement.Ptr()); + statement->Expression = ParseExpression(); + + ReadToken(TokenType::Semicolon); + return statement; + } + + RefPtr Parser::ParseParameter() + { + RefPtr parameter = new ParamDecl(); + parameter->modifiers = ParseModifiers(this); + + DeclaratorInfo declaratorInfo; + declaratorInfo.typeSpec = ParseType(); + + InitDeclarator initDeclarator = ParseInitDeclarator(this); + UnwrapDeclarator(initDeclarator, &declaratorInfo); + + // Assume it is a variable-like declarator + CompleteVarDecl(this, parameter, declaratorInfo); + return parameter; + } + + RefPtr Parser::ParseType() + { + auto typeSpec = parseTypeSpec(this); + if( typeSpec.decl ) + { + AddMember(currentScope, typeSpec.decl); + } + auto typeExpr = typeSpec.expr; + + typeExpr = parsePostfixTypeSuffix(this, typeExpr); + + return typeExpr; + } + + + + TypeExp Parser::ParseTypeExp() + { + return TypeExp(ParseType()); + } + + enum class Associativity + { + Left, Right + }; + + + + Associativity GetAssociativityFromLevel(Precedence level) + { + if (level == Precedence::Assignment) + return Associativity::Right; + else + return Associativity::Left; + } + + + + + Precedence GetOpLevel(Parser* parser, TokenType type) + { + switch(type) + { + case TokenType::QuestionMark: + return Precedence::TernaryConditional; + case TokenType::Comma: + return Precedence::Comma; + case TokenType::OpAssign: + case TokenType::OpMulAssign: + case TokenType::OpDivAssign: + case TokenType::OpAddAssign: + case TokenType::OpSubAssign: + case TokenType::OpModAssign: + case TokenType::OpShlAssign: + case TokenType::OpShrAssign: + case TokenType::OpOrAssign: + case TokenType::OpAndAssign: + case TokenType::OpXorAssign: + return Precedence::Assignment; + case TokenType::OpOr: + return Precedence::LogicalOr; + case TokenType::OpAnd: + return Precedence::LogicalAnd; + case TokenType::OpBitOr: + return Precedence::BitOr; + case TokenType::OpBitXor: + return Precedence::BitXor; + case TokenType::OpBitAnd: + return Precedence::BitAnd; + case TokenType::OpEql: + case TokenType::OpNeq: + return Precedence::EqualityComparison; + case TokenType::OpGreater: + case TokenType::OpGeq: + // Don't allow these ops inside a generic argument + if (parser->genericDepth > 0) return Precedence::Invalid; + ; // fall-thru + case TokenType::OpLeq: + case TokenType::OpLess: + return Precedence::RelationalComparison; + case TokenType::OpRsh: + // Don't allow this op inside a generic argument + if (parser->genericDepth > 0) return Precedence::Invalid; + ; // fall-thru + case TokenType::OpLsh: + return Precedence::BitShift; + case TokenType::OpAdd: + case TokenType::OpSub: + return Precedence::Additive; + case TokenType::OpMul: + case TokenType::OpDiv: + case TokenType::OpMod: + return Precedence::Multiplicative; + default: + return Precedence::Invalid; + } + } + + static RefPtr parseOperator(Parser* parser) + { + Token opToken; + switch(parser->tokenReader.PeekTokenType()) + { + case TokenType::QuestionMark: + opToken = parser->ReadToken(); + opToken.Content = UnownedStringSlice::fromLiteral("?:"); + break; + + default: + opToken = parser->ReadToken(); + break; + } + + auto opExpr = new VarExpr(); + opExpr->name = getName(parser, opToken.Content); + opExpr->scope = parser->currentScope; + opExpr->loc = opToken.loc; + + return opExpr; + + } + + static RefPtr createInfixExpr( + Parser* /*parser*/, + RefPtr left, + RefPtr op, + RefPtr right) + { + RefPtr expr = new InfixExpr(); + expr->loc = op->loc; + expr->FunctionExpr = op; + expr->Arguments.add(left); + expr->Arguments.add(right); + return expr; + } + + static RefPtr parseInfixExprWithPrecedence( + Parser* parser, + RefPtr inExpr, + Precedence prec) + { + auto expr = inExpr; + for(;;) + { + auto opTokenType = parser->tokenReader.PeekTokenType(); + auto opPrec = GetOpLevel(parser, opTokenType); + if(opPrec < prec) + break; + + auto op = parseOperator(parser); + + // Special case the `?:` operator since it is the + // one non-binary case we need to deal with. + if(opTokenType == TokenType::QuestionMark) + { + RefPtr select = new SelectExpr(); + select->loc = op->loc; + select->FunctionExpr = op; + + select->Arguments.add(expr); + + select->Arguments.add(parser->ParseExpression(opPrec)); + parser->ReadToken(TokenType::Colon); + select->Arguments.add(parser->ParseExpression(opPrec)); + + expr = select; + continue; + } + + auto right = parser->ParseLeafExpression(); + + for(;;) + { + auto nextOpPrec = GetOpLevel(parser, parser->tokenReader.PeekTokenType()); + + if((GetAssociativityFromLevel(nextOpPrec) == Associativity::Right) ? (nextOpPrec < opPrec) : (nextOpPrec <= opPrec)) + break; + + right = parseInfixExprWithPrecedence(parser, right, nextOpPrec); + } + + if (opTokenType == TokenType::OpAssign) + { + RefPtr assignExpr = new AssignExpr(); + assignExpr->loc = op->loc; + assignExpr->left = expr; + assignExpr->right = right; + + expr = assignExpr; + } + else + { + expr = createInfixExpr(parser, expr, op, right); + } + } + return expr; + } + + RefPtr Parser::ParseExpression(Precedence level) + { + auto expr = ParseLeafExpression(); + return parseInfixExprWithPrecedence(this, expr, level); + +#if 0 + + if (level == Precedence::Prefix) + return ParseLeafExpression(); + if (level == Precedence::TernaryConditional) + { + // parse select clause + auto condition = ParseExpression(Precedence(level + 1)); + if (LookAheadToken(TokenType::QuestionMark)) + { + RefPtr select = new SelectExpr(); + FillPosition(select.Ptr()); + + select->Arguments.Add(condition); + + select->FunctionExpr = parseOperator(this); + + select->Arguments.Add(ParseExpression(level)); + ReadToken(TokenType::Colon); + select->Arguments.Add(ParseExpression(level)); + return select; + } + else + return condition; + } + else + { + if (GetAssociativityFromLevel(level) == Associativity::Left) + { + auto left = ParseExpression(Precedence(level + 1)); + while (GetOpLevel(this, tokenReader.PeekTokenType()) == level) + { + RefPtr tmp = new InfixExpr(); + tmp->FunctionExpr = parseOperator(this); + + tmp->Arguments.Add(left); + FillPosition(tmp.Ptr()); + tmp->Arguments.Add(ParseExpression(Precedence(level + 1))); + left = tmp; + } + return left; + } + else + { + auto left = ParseExpression(Precedence(level + 1)); + if (GetOpLevel(this, tokenReader.PeekTokenType()) == level) + { + RefPtr tmp = new InfixExpr(); + tmp->Arguments.Add(left); + FillPosition(tmp.Ptr()); + tmp->FunctionExpr = parseOperator(this); + tmp->Arguments.Add(ParseExpression(level)); + left = tmp; + } + return left; + } + } +#endif + } + + // We *might* be looking at an application of a generic to arguments, + // but we need to disambiguate to make sure. + static RefPtr maybeParseGenericApp( + Parser* parser, + + // TODO: need to support more general expressions here + RefPtr base) + { + if(peekTokenType(parser) != TokenType::OpLess) + return base; + return tryParseGenericApp(parser, base); + } + + static RefPtr parsePrefixExpr(Parser* parser); + + // Parse OOP `this` expression syntax + static RefPtr parseThisExpr(Parser* parser, void* /*userData*/) + { + RefPtr expr = new ThisExpr(); + expr->scope = parser->currentScope; + return expr; + } + + static RefPtr parseBoolLitExpr(Parser* /*parser*/, bool value) + { + RefPtr expr = new BoolLiteralExpr(); + expr->value = value; + return expr; + } + + static RefPtr parseTrueExpr(Parser* parser, void* /*userData*/) + { + return parseBoolLitExpr(parser, true); + } + + static RefPtr parseFalseExpr(Parser* parser, void* /*userData*/) + { + return parseBoolLitExpr(parser, false); + } + + static RefPtr parseAtomicExpr(Parser* parser) + { + switch( peekTokenType(parser) ) + { + default: + // TODO: should this return an error expression instead of NULL? + parser->sink->diagnose(parser->tokenReader.PeekLoc(), Diagnostics::syntaxError); + return nullptr; + + // Either: + // - parenthized expression `(exp)` + // - cast `(type) exp` + // + // Proper disambiguation requires mixing up parsing + // and semantic checking (which we should do eventually) + // but for now we will follow some heuristics. + case TokenType::LParent: + { + Token openParen = parser->ReadToken(TokenType::LParent); + + if (peekTypeName(parser) && parser->LookAheadToken(TokenType::RParent, 1)) + { + RefPtr tcexpr = new ExplicitCastExpr(); + parser->FillPosition(tcexpr.Ptr()); + tcexpr->FunctionExpr = parser->ParseType(); + parser->ReadToken(TokenType::RParent); + + auto arg = parsePrefixExpr(parser); + tcexpr->Arguments.add(arg); + + return tcexpr; + } + else + { + RefPtr base = parser->ParseExpression(); + parser->ReadToken(TokenType::RParent); + + RefPtr parenExpr = new ParenExpr(); + parenExpr->loc = openParen.loc; + parenExpr->base = base; + return parenExpr; + } + } + + // An initializer list `{ expr, ... }` + case TokenType::LBrace: + { + RefPtr initExpr = new InitializerListExpr(); + parser->FillPosition(initExpr.Ptr()); + + // Initializer list + parser->ReadToken(TokenType::LBrace); + + List> exprs; + + for(;;) + { + if(AdvanceIfMatch(parser, TokenType::RBrace)) + break; + + auto expr = parser->ParseArgExpr(); + if( expr ) + { + initExpr->args.add(expr); + } + + if(AdvanceIfMatch(parser, TokenType::RBrace)) + break; + + parser->ReadToken(TokenType::Comma); + } + + return initExpr; + } + + case TokenType::IntegerLiteral: + { + RefPtr constExpr = new IntegerLiteralExpr(); + parser->FillPosition(constExpr.Ptr()); + + auto token = parser->tokenReader.AdvanceToken(); + constExpr->token = token; + + UnownedStringSlice suffix; + IntegerLiteralValue value = getIntegerLiteralValue(token, &suffix); + + // Look at any suffix on the value + char const* suffixCursor = suffix.begin(); + const char*const suffixEnd = suffix.end(); + + RefPtr suffixType = nullptr; + if( suffixCursor < suffixEnd ) + { + int lCount = 0; + int uCount = 0; + int unknownCount = 0; + while(suffixCursor < suffixEnd) + { + switch( *suffixCursor++ ) + { + case 'l': case 'L': + lCount++; + break; + + case 'u': case 'U': + uCount++; + break; + + default: + unknownCount++; + break; + } + } + + if(unknownCount) + { + parser->sink->diagnose(token, Diagnostics::invalidIntegerLiteralSuffix, suffix); + suffixType = parser->getSession()->getErrorType(); + } + // `u` or `ul` suffix -> `uint` + else if(uCount == 1 && (lCount <= 1)) + { + suffixType = parser->getSession()->getUIntType(); + } + // `l` suffix on integer -> `int` (== `long`) + else if(lCount == 1 && !uCount) + { + suffixType = parser->getSession()->getIntType(); + } + // `ull` suffix -> `uint64_t` + else if(uCount == 1 && lCount == 2) + { + suffixType = parser->getSession()->getUInt64Type(); + } + // `ll` suffix -> `int64_t` + else if(uCount == 0 && lCount == 2) + { + suffixType = parser->getSession()->getInt64Type(); + } + // TODO: do we need suffixes for smaller integer types? + else + { + parser->sink->diagnose(token, Diagnostics::invalidIntegerLiteralSuffix, suffix); + suffixType = parser->getSession()->getErrorType(); + } + } + + constExpr->value = value; + constExpr->type = QualType(suffixType); + + return constExpr; + } + + + case TokenType::FloatingPointLiteral: + { + RefPtr constExpr = new FloatingPointLiteralExpr(); + parser->FillPosition(constExpr.Ptr()); + + auto token = parser->tokenReader.AdvanceToken(); + constExpr->token = token; + + UnownedStringSlice suffix; + FloatingPointLiteralValue value = getFloatingPointLiteralValue(token, &suffix); + + // Look at any suffix on the value + char const* suffixCursor = suffix.begin(); + const char*const suffixEnd = suffix.end(); + + RefPtr suffixType = nullptr; + if( suffixCursor < suffixEnd ) + { + int fCount = 0; + int lCount = 0; + int hCount = 0; + int unknownCount = 0; + while(suffixCursor < suffixEnd) + { + switch( *suffixCursor++ ) + { + case 'f': case 'F': + fCount++; + break; + + case 'l': case 'L': + lCount++; + break; + + case 'h': case 'H': + hCount++; + break; + + default: + unknownCount++; + break; + } + } + + if (unknownCount) + { + parser->sink->diagnose(token, Diagnostics::invalidFloatingPointLiteralSuffix, suffix); + suffixType = parser->getSession()->getErrorType(); + } + // `f` suffix -> `float` + if(fCount == 1 && !lCount) + { + suffixType = parser->getSession()->getFloatType(); + } + // `l` or `lf` suffix on floating-point literal -> `double` + else if(lCount == 1 && (fCount <= 1)) + { + suffixType = parser->getSession()->getDoubleType(); + } + // `h` or `hf` suffix on floating-point literal -> `half` + else if(lCount == 1 && (fCount <= 1)) + { + suffixType = parser->getSession()->getHalfType(); + } + // TODO: are there other suffixes we need to handle? + else + { + parser->sink->diagnose(token, Diagnostics::invalidFloatingPointLiteralSuffix, suffix); + suffixType = parser->getSession()->getErrorType(); + } + } + + constExpr->value = value; + constExpr->type = QualType(suffixType); + + return constExpr; + } + + case TokenType::StringLiteral: + { + RefPtr constExpr = new StringLiteralExpr(); + auto token = parser->tokenReader.AdvanceToken(); + constExpr->token = token; + parser->FillPosition(constExpr.Ptr()); + + if (!parser->LookAheadToken(TokenType::StringLiteral)) + { + // Easy/common case: a single string + constExpr->value = getStringLiteralTokenValue(token); + } + else + { + StringBuilder sb; + sb << getStringLiteralTokenValue(token); + while (parser->LookAheadToken(TokenType::StringLiteral)) + { + token = parser->tokenReader.AdvanceToken(); + sb << getStringLiteralTokenValue(token); + } + constExpr->value = sb.ProduceString(); + } + + return constExpr; + } + + case TokenType::Identifier: + { + // We will perform name lookup here so that we can find syntax + // keywords registered for use as expressions. + Token nameToken = peekToken(parser); + + RefPtr parsedExpr; + if (tryParseUsingSyntaxDecl(parser, &parsedExpr)) + { + if (!parsedExpr->loc.isValid()) + { + parsedExpr->loc = nameToken.loc; + } + return parsedExpr; + } + + // Default behavior is just to create a name expression + RefPtr varExpr = new VarExpr(); + varExpr->scope = parser->currentScope.Ptr(); + parser->FillPosition(varExpr.Ptr()); + + auto nameAndLoc = expectIdentifier(parser); + varExpr->name = nameAndLoc.name; + + if(peekTokenType(parser) == TokenType::OpLess) + { + return maybeParseGenericApp(parser, varExpr); + } + + return varExpr; + } + } + } + + static RefPtr parsePostfixExpr(Parser* parser) + { + auto expr = parseAtomicExpr(parser); + for(;;) + { + switch( peekTokenType(parser) ) + { + default: + return expr; + + // Postfix increment/decrement + case TokenType::OpInc: + case TokenType::OpDec: + { + RefPtr postfixExpr = new PostfixExpr(); + parser->FillPosition(postfixExpr.Ptr()); + postfixExpr->FunctionExpr = parseOperator(parser); + postfixExpr->Arguments.add(expr); + + expr = postfixExpr; + } + break; + + // Subscript operation `a[i]` + case TokenType::LBracket: + { + RefPtr indexExpr = new IndexExpr(); + indexExpr->BaseExpression = expr; + parser->FillPosition(indexExpr.Ptr()); + parser->ReadToken(TokenType::LBracket); + // TODO: eventually we may want to support multiple arguments inside the `[]` + if (!parser->LookAheadToken(TokenType::RBracket)) + { + indexExpr->IndexExpression = parser->ParseExpression(); + } + parser->ReadToken(TokenType::RBracket); + + expr = indexExpr; + } + break; + + // Call oepration `f(x)` + case TokenType::LParent: + { + RefPtr invokeExpr = new InvokeExpr(); + invokeExpr->FunctionExpr = expr; + parser->FillPosition(invokeExpr.Ptr()); + parser->ReadToken(TokenType::LParent); + while (!parser->tokenReader.IsAtEnd()) + { + if (!parser->LookAheadToken(TokenType::RParent)) + invokeExpr->Arguments.add(parser->ParseArgExpr()); + else + { + break; + } + if (!parser->LookAheadToken(TokenType::Comma)) + break; + parser->ReadToken(TokenType::Comma); + } + parser->ReadToken(TokenType::RParent); + + expr = invokeExpr; + } + break; + + // Scope access `x::m` + case TokenType::Scope: + { + RefPtr staticMemberExpr = new StaticMemberExpr(); + + // TODO(tfoley): why would a member expression need this? + staticMemberExpr->scope = parser->currentScope.Ptr(); + + parser->FillPosition(staticMemberExpr.Ptr()); + staticMemberExpr->BaseExpression = expr; + parser->ReadToken(TokenType::Scope); + staticMemberExpr->name = expectIdentifier(parser).name; + + if (peekTokenType(parser) == TokenType::OpLess) + expr = maybeParseGenericApp(parser, staticMemberExpr); + else + expr = staticMemberExpr; + + break; + } + // Member access `x.m` + case TokenType::Dot: + { + RefPtr memberExpr = new MemberExpr(); + + // TODO(tfoley): why would a member expression need this? + memberExpr->scope = parser->currentScope.Ptr(); + + parser->FillPosition(memberExpr.Ptr()); + memberExpr->BaseExpression = expr; + parser->ReadToken(TokenType::Dot); + memberExpr->name = expectIdentifier(parser).name; + + if (peekTokenType(parser) == TokenType::OpLess) + expr = maybeParseGenericApp(parser, memberExpr); + else + expr = memberExpr; + } + break; + } + } + } + + static RefPtr parsePrefixExpr(Parser* parser) + { + switch( peekTokenType(parser) ) + { + default: + return parsePostfixExpr(parser); + + case TokenType::OpInc: + case TokenType::OpDec: + case TokenType::OpNot: + case TokenType::OpBitNot: + case TokenType::OpAdd: + case TokenType::OpSub: + { + RefPtr prefixExpr = new PrefixExpr(); + parser->FillPosition(prefixExpr.Ptr()); + prefixExpr->FunctionExpr = parseOperator(parser); + prefixExpr->Arguments.add(parsePrefixExpr(parser)); + return prefixExpr; + } + break; + } + } + + RefPtr Parser::ParseLeafExpression() + { + return parsePrefixExpr(this); + } + + RefPtr parseTypeFromSourceFile( + Session* session, + TokenSpan const& tokens, + DiagnosticSink* sink, + RefPtr const& outerScope, + NamePool* namePool, + SourceLanguage sourceLanguage) + { + Parser parser(session, tokens, sink, outerScope); + parser.currentScope = outerScope; + parser.namePool = namePool; + parser.sourceLanguage = sourceLanguage; + return parser.ParseType(); + } + + // Parse a source file into an existing translation unit + void parseSourceFile( + TranslationUnitRequest* translationUnit, + TokenSpan const& tokens, + DiagnosticSink* sink, + RefPtr const& outerScope) + { + Parser parser(translationUnit->getSession(), tokens, sink, outerScope); + parser.namePool = translationUnit->getNamePool(); + parser.sourceLanguage = translationUnit->sourceLanguage; + + return parser.parseSourceFile(translationUnit->getModuleDecl()); + } + + static void addBuiltinSyntaxImpl( + Session* session, + Scope* scope, + char const* nameText, + SyntaxParseCallback callback, + void* userData, + SyntaxClass syntaxClass) + { + Name* name = session->getNamePool()->getName(nameText); + + RefPtr syntaxDecl = new SyntaxDecl(); + syntaxDecl->nameAndLoc = NameLoc(name); + syntaxDecl->syntaxClass = syntaxClass; + syntaxDecl->parseCallback = callback; + syntaxDecl->parseUserData = userData; + + AddMember(scope, syntaxDecl); + } + + template + static void addBuiltinSyntax( + Session* session, + Scope* scope, + char const* name, + SyntaxParseCallback callback, + void* userData = nullptr) + { + addBuiltinSyntaxImpl(session, scope, name, callback, userData, getClass()); + } + + template + static void addSimpleModifierSyntax( + Session* session, + Scope* scope, + char const* name) + { + auto syntaxClass = getClass(); + addBuiltinSyntaxImpl(session, scope, name, &parseSimpleSyntax, (void*) syntaxClass.classInfo, getClass()); + } + + static RefPtr parseIntrinsicOpModifier(Parser* parser, void* /*userData*/) + { + RefPtr modifier = new IntrinsicOpModifier(); + + // We allow a few difference forms here: + // + // First, we can specify the intrinsic op `enum` value directly: + // + // __intrinsic_op() + // + // Second, we can specify the operation by name: + // + // __intrinsic_op() + // + // Finally, we can leave off the specification, so that the + // op name will be derived from the function name: + // + // __intrinsic_op + // + if (AdvanceIf(parser, TokenType::LParent)) + { + if (AdvanceIf(parser, TokenType::OpSub)) + { + modifier->op = IROp(-StringToInt(parser->ReadToken().Content)); + } + else if (parser->LookAheadToken(TokenType::IntegerLiteral)) + { + modifier->op = IROp(StringToInt(parser->ReadToken().Content)); + } + else + { + modifier->opToken = parser->ReadToken(TokenType::Identifier); + + modifier->op = findIROp(modifier->opToken.Content); + + if (modifier->op == kIROp_Invalid) + { + parser->sink->diagnose(modifier->opToken, Diagnostics::unimplemented, "unknown intrinsic op"); + } + } + + parser->ReadToken(TokenType::RParent); + } + + + return modifier; + } + + static RefPtr parseTargetIntrinsicModifier(Parser* parser, void* /*userData*/) + { + auto modifier = new TargetIntrinsicModifier(); + + if (AdvanceIf(parser, TokenType::LParent)) + { + modifier->targetToken = parser->ReadToken(TokenType::Identifier); + + if( AdvanceIf(parser, TokenType::Comma) ) + { + if( parser->LookAheadToken(TokenType::StringLiteral) ) + { + modifier->definitionToken = parser->ReadToken(); + } + else + { + modifier->definitionToken = parser->ReadToken(TokenType::Identifier); + } + } + + parser->ReadToken(TokenType::RParent); + } + + return modifier; + } + + static RefPtr parseSpecializedForTargetModifier(Parser* parser, void* /*userData*/) + { + auto modifier = new SpecializedForTargetModifier(); + if (AdvanceIf(parser, TokenType::LParent)) + { + modifier->targetToken = parser->ReadToken(TokenType::Identifier); + parser->ReadToken(TokenType::RParent); + } + return modifier; + } + + static RefPtr parseGLSLExtensionModifier(Parser* parser, void* /*userData*/) + { + auto modifier = new RequiredGLSLExtensionModifier(); + + parser->ReadToken(TokenType::LParent); + modifier->extensionNameToken = parser->ReadToken(TokenType::Identifier); + parser->ReadToken(TokenType::RParent); + + return modifier; + } + + static RefPtr parseGLSLVersionModifier(Parser* parser, void* /*userData*/) + { + auto modifier = new RequiredGLSLVersionModifier(); + + parser->ReadToken(TokenType::LParent); + modifier->versionNumberToken = parser->ReadToken(TokenType::IntegerLiteral); + parser->ReadToken(TokenType::RParent); + + return modifier; + } + + static RefPtr parseLayoutModifier(Parser* parser, void* /*userData*/) + { + ModifierListBuilder listBuilder; + + listBuilder.add(new GLSLLayoutModifierGroupBegin()); + + parser->ReadToken(TokenType::LParent); + while (!AdvanceIfMatch(parser, TokenType::RParent)) + { + auto nameAndLoc = expectIdentifier(parser); + const String& nameText = nameAndLoc.name->text; + + if (nameText == "binding" || + nameText == "set") + { + GLSLBindingAttribute* attr = listBuilder.find(); + if (!attr) + { + attr = new GLSLBindingAttribute(); + listBuilder.add(attr); + } + + parser->ReadToken(TokenType::OpAssign); + + // If the token asked for is not returned found will put in recovering state, and return token found + Token valToken = parser->ReadToken(TokenType::IntegerLiteral); + // If wasn't the desired IntegerLiteral return that couldn't parse + if (valToken.type != TokenType::IntegerLiteral) + { + return nullptr; + } + + // Work out the value + auto value = getIntegerLiteralValue(valToken); + + if (nameText == "binding") + { + attr->binding = int32_t(value); + } + else + { + attr->set = int32_t(value); + } + } + else + { + RefPtr modifier; + +#define CASE(key, type) if (nameText == #key) { modifier = new type; } else + CASE(push_constant, PushConstantAttribute) + CASE(shaderRecordNV, ShaderRecordAttribute) + CASE(constant_id, GLSLConstantIDLayoutModifier) + CASE(location, GLSLLocationLayoutModifier) + CASE(local_size_x, GLSLLocalSizeXLayoutModifier) + CASE(local_size_y, GLSLLocalSizeYLayoutModifier) + CASE(local_size_z, GLSLLocalSizeZLayoutModifier) + { + modifier = new GLSLUnparsedLayoutModifier(); + } + SLANG_ASSERT(modifier); +#undef CASE + + modifier->name = nameAndLoc.name; + modifier->loc = nameAndLoc.loc; + + // Special handling for GLSLLayoutModifier + if (auto glslModifier = as(modifier)) + { + if (AdvanceIf(parser, TokenType::OpAssign)) + { + glslModifier->valToken = parser->ReadToken(TokenType::IntegerLiteral); + } + } + + listBuilder.add(modifier); + } + + if (AdvanceIf(parser, TokenType::RParent)) + break; + parser->ReadToken(TokenType::Comma); + } + + listBuilder.add(new GLSLLayoutModifierGroupEnd()); + + return listBuilder.getFirst(); + } + + static RefPtr parseBuiltinTypeModifier(Parser* parser, void* /*userData*/) + { + RefPtr modifier = new BuiltinTypeModifier(); + parser->ReadToken(TokenType::LParent); + modifier->tag = BaseType(StringToInt(parser->ReadToken(TokenType::IntegerLiteral).Content)); + parser->ReadToken(TokenType::RParent); + + return modifier; + } + + static RefPtr parseMagicTypeModifier(Parser* parser, void* /*userData*/) + { + RefPtr modifier = new MagicTypeModifier(); + parser->ReadToken(TokenType::LParent); + modifier->name = parser->ReadToken(TokenType::Identifier).Content; + if (AdvanceIf(parser, TokenType::Comma)) + { + modifier->tag = uint32_t(StringToInt(parser->ReadToken(TokenType::IntegerLiteral).Content)); + } + parser->ReadToken(TokenType::RParent); + + return modifier; + } + + static RefPtr parseIntrinsicTypeModifier(Parser* parser, void* /*userData*/) + { + RefPtr modifier = new IntrinsicTypeModifier(); + parser->ReadToken(TokenType::LParent); + modifier->irOp = uint32_t(StringToInt(parser->ReadToken(TokenType::IntegerLiteral).Content)); + while( AdvanceIf(parser, TokenType::Comma) ) + { + auto operand = uint32_t(StringToInt(parser->ReadToken(TokenType::IntegerLiteral).Content)); + modifier->irOperands.add(operand); + } + parser->ReadToken(TokenType::RParent); + + return modifier; + } + static RefPtr parseImplicitConversionModifier(Parser* parser, void* /*userData*/) + { + RefPtr modifier = new ImplicitConversionModifier(); + + ConversionCost cost = kConversionCost_Default; + if( AdvanceIf(parser, TokenType::LParent) ) + { + cost = ConversionCost(StringToInt(parser->ReadToken(TokenType::IntegerLiteral).Content)); + parser->ReadToken(TokenType::RParent); + } + modifier->cost = cost; + return modifier; + } + + static RefPtr parseAttributeTargetModifier(Parser* parser, void* /*userData*/) + { + expect(parser, TokenType::LParent); + auto syntaxClassNameAndLoc = expectIdentifier(parser); + expect(parser, TokenType::RParent); + + auto syntaxClass = parser->getSession()->findSyntaxClass(syntaxClassNameAndLoc.name); + + RefPtr modifier = new AttributeTargetModifier(); + modifier->syntaxClass = syntaxClass; + + return modifier; + } + + RefPtr populateBaseLanguageModule( + Session* session, + RefPtr scope) + { + RefPtr moduleDecl = new ModuleDecl(); + scope->containerDecl = moduleDecl; + + // Add syntax for declaration keywords + #define DECL(KEYWORD, CALLBACK) \ + addBuiltinSyntax(session, scope, #KEYWORD, &CALLBACK) + DECL(typedef, ParseTypeDef); + DECL(associatedtype, parseAssocType); + DECL(type_param, parseGlobalGenericParamDecl); + DECL(cbuffer, parseHLSLCBufferDecl); + DECL(tbuffer, parseHLSLTBufferDecl); + DECL(__generic, ParseGenericDecl); + DECL(__extension, ParseExtensionDecl); + DECL(extension, ParseExtensionDecl); + DECL(__init, parseConstructorDecl); + DECL(__subscript, ParseSubscriptDecl); + DECL(interface, parseInterfaceDecl); + DECL(syntax, parseSyntaxDecl); + DECL(attribute_syntax,parseAttributeSyntaxDecl); + DECL(__import, parseImportDecl); + DECL(import, parseImportDecl); + DECL(let, parseLetDecl); + DECL(var, parseVarDecl); + DECL(func, parseFuncDecl); + DECL(typealias, parseTypeAliasDecl); + + #undef DECL + + // Add syntax for "simple" modifier keywords. + // These are the ones that just appear as a single + // keyword (no further tokens expected/allowed), + // and which can be represented just by creating + // a new AST node of the corresponding type. + #define MODIFIER(KEYWORD, CLASS) \ + addSimpleModifierSyntax(session, scope, #KEYWORD) + + MODIFIER(in, InModifier); + MODIFIER(input, InputModifier); + MODIFIER(out, OutModifier); + MODIFIER(inout, InOutModifier); + MODIFIER(__ref, RefModifier); + MODIFIER(const, ConstModifier); + MODIFIER(instance, InstanceModifier); + MODIFIER(__builtin, BuiltinModifier); + + MODIFIER(inline, InlineModifier); + MODIFIER(public, PublicModifier); + MODIFIER(require, RequireModifier); + MODIFIER(param, ParamModifier); + MODIFIER(extern, ExternModifier); + + MODIFIER(row_major, HLSLRowMajorLayoutModifier); + MODIFIER(column_major, HLSLColumnMajorLayoutModifier); + + MODIFIER(nointerpolation, HLSLNoInterpolationModifier); + MODIFIER(noperspective, HLSLNoPerspectiveModifier); + MODIFIER(linear, HLSLLinearModifier); + MODIFIER(sample, HLSLSampleModifier); + MODIFIER(centroid, HLSLCentroidModifier); + MODIFIER(precise, PreciseModifier); + MODIFIER(shared, HLSLEffectSharedModifier); + MODIFIER(groupshared, HLSLGroupSharedModifier); + MODIFIER(static, HLSLStaticModifier); + MODIFIER(uniform, HLSLUniformModifier); + MODIFIER(volatile, HLSLVolatileModifier); + + // Modifiers for geometry shader input + MODIFIER(point, HLSLPointModifier); + MODIFIER(line, HLSLLineModifier); + MODIFIER(triangle, HLSLTriangleModifier); + MODIFIER(lineadj, HLSLLineAdjModifier); + MODIFIER(triangleadj, HLSLTriangleAdjModifier); + + // Modifiers for unary operator declarations + MODIFIER(__prefix, PrefixModifier); + MODIFIER(__postfix, PostfixModifier); + + // Modifier to apply to `import` that should be re-exported + MODIFIER(__exported, ExportedModifier); + + #undef MODIFIER + + // Add syntax for more complex modifiers, which allow + // or expect more tokens after the initial keyword. + #define MODIFIER(KEYWORD, CALLBACK) \ + addBuiltinSyntax(session, scope, #KEYWORD, &CALLBACK) + + MODIFIER(layout, parseLayoutModifier); + + MODIFIER(__intrinsic_op, parseIntrinsicOpModifier); + MODIFIER(__target_intrinsic, parseTargetIntrinsicModifier); + MODIFIER(__specialized_for_target, parseSpecializedForTargetModifier); + MODIFIER(__glsl_extension, parseGLSLExtensionModifier); + MODIFIER(__glsl_version, parseGLSLVersionModifier); + + MODIFIER(__builtin_type, parseBuiltinTypeModifier); + MODIFIER(__magic_type, parseMagicTypeModifier); + MODIFIER(__intrinsic_type, parseIntrinsicTypeModifier); + MODIFIER(__implicit_conversion, parseImplicitConversionModifier); + + MODIFIER(__attributeTarget, parseAttributeTargetModifier); + + +#undef MODIFIER + + // Add syntax for expression keywords + #define EXPR(KEYWORD, CALLBACK) \ + addBuiltinSyntax(session, scope, #KEYWORD, &CALLBACK) + + EXPR(this, parseThisExpr); + EXPR(true, parseTrueExpr); + EXPR(false, parseFalseExpr); + + #undef EXPR + + return moduleDecl; + } + +} diff --git a/source/slang/slang-parser.h b/source/slang/slang-parser.h new file mode 100644 index 000000000..98fd9ed65 --- /dev/null +++ b/source/slang/slang-parser.h @@ -0,0 +1,30 @@ +#ifndef SLANG_PARSER_H +#define SLANG_PARSER_H + +#include "slang-lexer.h" +#include "slang-compiler.h" +#include "slang-syntax.h" + +namespace Slang +{ + // Parse a source file into an existing translation unit + void parseSourceFile( + TranslationUnitRequest* translationUnit, + TokenSpan const& tokens, + DiagnosticSink* sink, + RefPtr const& outerScope); + + RefPtr parseTypeFromSourceFile( + Session* session, + TokenSpan const& tokens, + DiagnosticSink* sink, + RefPtr const& outerScope, + NamePool* namePool, + SourceLanguage sourceLanguage); + + RefPtr populateBaseLanguageModule( + Session* session, + RefPtr scope); +} + +#endif diff --git a/source/slang/slang-preprocessor.cpp b/source/slang/slang-preprocessor.cpp new file mode 100644 index 000000000..16a64e571 --- /dev/null +++ b/source/slang/slang-preprocessor.cpp @@ -0,0 +1,2302 @@ +// slang-preprocessor.cpp +#include "slang-preprocessor.h" + +#include "slang-compiler.h" +#include "slang-diagnostics.h" +#include "slang-lexer.h" +// Needed so that we can construct modifier syntax to represent GLSL directives +#include "slang-syntax.h" + +#include + +// This file provides an implementation of a simple C-style preprocessor. +// It does not aim for 100% compatibility with any particular preprocessor +// specification, but the goal is to have it accept the most common +// idioms for using the preprocessor, found in shader code in the wild. + + +namespace Slang { + +// State of a preprocessor conditional, which can change when +// we encounter directives like `#elif` or `#endif` +enum class PreprocessorConditionalState +{ + Before, // We have not yet seen a branch with a `true` condition. + During, // We are inside the branch with a `true` condition. + After, // We have already seen the branch with a `true` condition. +}; + +// Represents a preprocessor conditional that we are currently +// nested inside. +struct PreprocessorConditional +{ + // The next outer conditional in the current file/stream, or NULL. + PreprocessorConditional* parent; + + // The directive token that started the conditional (an `#if` or `#ifdef`) + Token ifToken; + + // The `#else` directive token, if one has been seen (otherwise `TokenType::Unknown`) + Token elseToken; + + // The state of the conditional + PreprocessorConditionalState state; +}; + +struct PreprocessorMacro; + +struct PreprocessorEnvironment +{ + // The "outer" environment, to be used if lookup in this env fails + PreprocessorEnvironment* parent = NULL; + + // Macros defined in this environment + Dictionary macros; + + ~PreprocessorEnvironment(); +}; + +// Input tokens can either come from source text, or from macro expansion. +// In general, input streams can be nested, so we have to keep a conceptual +// stack of input. + +struct PrimaryInputStream; + +// A stream of input tokens to be consumed +struct PreprocessorInputStream +{ + // The primary input stream that is the parent to this one, + // or NULL if this stream is itself a primary stream. + PrimaryInputStream* primaryStream; + + // The next input stream up the stack, if any. + PreprocessorInputStream* parent; + + // Environment to use when looking up macros + PreprocessorEnvironment* environment; + + // Destructor is virtual so that we can clean up + // after concrete subtypes. + virtual ~PreprocessorInputStream() = default; +}; + +// A "primary" input stream represents the top-level context of a file +// being parsed, and tracks things like preprocessor conditional state +struct PrimaryInputStream : PreprocessorInputStream +{ + // The next *primary* input stream up the stack + PrimaryInputStream* parentPrimaryInputStream; + + // The deepest preprocessor conditional active for this stream. + PreprocessorConditional* conditional; + + // The lexer state that will provide input + Lexer lexer; + + // One token of lookahead + Token token; +}; + +// A "secondary" input stream represents code that is being expanded +// into the current scope, but which had already been tokenized before. +// +struct PretokenizedInputStream : PreprocessorInputStream +{ + // Reader for pre-tokenized input + TokenReader tokenReader; +}; + +// A pre-tokenized input stream that will only be used once, and which +// therefore owns the memory for its tokens. +struct SimpleTokenInputStream : PretokenizedInputStream +{ + // A list of raw tokens that will provide input + TokenList lexedTokens; +}; + +struct MacroExpansion : PretokenizedInputStream +{ + // The macro we will expand + PreprocessorMacro* macro; +}; + +struct ObjectLikeMacroExpansion : MacroExpansion +{ +}; + +struct FunctionLikeMacroExpansion : MacroExpansion +{ + // Environment for macro arguments + PreprocessorEnvironment argumentEnvironment; +}; + +// An enumeration for the diferent types of macros +enum class PreprocessorMacroFlavor +{ + ObjectLike, + FunctionArg, + FunctionLike, +}; + +// In the current design (which we may want to re-consider), +// a macro is a specialized flavor of input stream, that +// captures the token list in its expansion, and then +// can be "played back." +struct PreprocessorMacro +{ + // The name under which the macro was `#define`d + NameLoc nameAndLoc; + + // Parameters of the macro, in case of a function-like macro + List params; + + // The tokens that make up the macro body + TokenList tokens; + + // The flavor of macro + PreprocessorMacroFlavor flavor; + + // The environment in which this macro needs to be expanded. + // For ordinary macros this will be the global environment, + // while for function-like macro arguments, it will be + // the environment of the macro invocation. + PreprocessorEnvironment* environment; + + // + Name* getName() + { + return nameAndLoc.name; + } + + SourceLoc getLoc() + { + return nameAndLoc.loc; + } +}; + +// State of the preprocessor +struct Preprocessor +{ + // diagnostics sink to use when writing messages + DiagnosticSink* sink; + + // An external callback interface to use when looking + // for files in a `#include` directive + IncludeHandler* includeHandler; + + // Current input stream (top of the stack of input) + PreprocessorInputStream* inputStream; + + // Currently-defined macros + PreprocessorEnvironment globalEnv; + + // A pre-allocated token that can be returned to + // represent end-of-input situations. + Token endOfFileToken; + + /// The linkage the provides the context for preprocessing + Linkage* linkage = nullptr; + + /// The module, if any, that the preprocessed result will belong to + Module* parentModule = nullptr; + + // The unique identities of any paths that have issued `#pragma once` directives to + // stop them from being included again. + HashSet pragmaOnceUniqueIdentities; + + NamePool* getNamePool() { return linkage->getNamePool(); } + SourceManager* getSourceManager() { return linkage->getSourceManager(); } +}; + +// Convenience routine to access the diagnostic sink +static DiagnosticSink* GetSink(Preprocessor* preprocessor) +{ + return preprocessor->sink; +} + +// +// Forward declarations +// + +static void DestroyConditional(PreprocessorConditional* conditional); +static void DestroyMacro(Preprocessor* preprocessor, PreprocessorMacro* macro); +static bool IsSkipping(Preprocessor* preprocessor); + +// +// Basic Input Handling +// + +// Create a fresh input stream +static void initializeInputStream(Preprocessor* preprocessor, PreprocessorInputStream* inputStream) +{ + inputStream->parent = NULL; + inputStream->environment = &preprocessor->globalEnv; +} + +static void initializePrimaryInputStream(Preprocessor* preprocessor, PrimaryInputStream* inputStream) +{ + initializeInputStream(preprocessor, inputStream); + inputStream->primaryStream = inputStream; + inputStream->conditional = NULL; +} + +// Destroy an input stream +static void destroyInputStream(Preprocessor* /*preprocessor*/, PreprocessorInputStream* inputStream) +{ + delete inputStream; +} + +// Create an input stream to represent a pre-tokenized input file. +// TODO(tfoley): pre-tokenizing files isn't going to work in the long run. +static PreprocessorInputStream* CreateInputStreamForSource( + Preprocessor* preprocessor, + SourceView* sourceView) +{ + MemoryArena* memoryArena = sourceView->getSourceManager()->getMemoryArena(); + + PrimaryInputStream* inputStream = new PrimaryInputStream(); + initializePrimaryInputStream(preprocessor, inputStream); + + // initialize the embedded lexer so that it can generate a token stream + inputStream->lexer.initialize(sourceView, GetSink(preprocessor), preprocessor->getNamePool(), memoryArena); + inputStream->token = inputStream->lexer.lexToken(); + + return inputStream; +} + +static PrimaryInputStream* asPrimaryInputStream(PreprocessorInputStream* inputStream) +{ + auto primaryStream = inputStream->primaryStream; + if(primaryStream == inputStream) + return primaryStream; + return nullptr; +} + + +static void PushInputStream(Preprocessor* preprocessor, PreprocessorInputStream* inputStream) +{ + inputStream->parent = preprocessor->inputStream; + if(!asPrimaryInputStream(inputStream)) + inputStream->primaryStream = preprocessor->inputStream->primaryStream; + preprocessor->inputStream = inputStream; +} + +// Called when we reach the end of an input stream. +// Performs some validation and then destroys the input stream if required. +static void EndInputStream(Preprocessor* preprocessor, PreprocessorInputStream* inputStream) +{ + if(auto primaryStream = asPrimaryInputStream(inputStream)) + { + // If there are any conditionals that weren't completed, then it is an error + if (primaryStream->conditional) + { + PreprocessorConditional* conditional = primaryStream->conditional; + + GetSink(preprocessor)->diagnose(conditional->ifToken.loc, Diagnostics::endOfFileInPreprocessorConditional); + + while (conditional) + { + PreprocessorConditional* parent = conditional->parent; + DestroyConditional(conditional); + conditional = parent; + } + } + } + + destroyInputStream(preprocessor, inputStream); +} + +// Consume one token from an input stream +static Token AdvanceRawToken(PreprocessorInputStream* inputStream, LexerFlags lexerFlags = 0) +{ + if( auto primaryStream = asPrimaryInputStream(inputStream) ) + { + auto result = primaryStream->token; + primaryStream->token = primaryStream->lexer.lexToken(lexerFlags); + return result; + } + else + { + PretokenizedInputStream* pretokenized = (PretokenizedInputStream*) inputStream; + return pretokenized->tokenReader.AdvanceToken(); + } +} + +// Peek one token from an input stream +static Token PeekRawToken(PreprocessorInputStream* inputStream) +{ + if( auto primaryStream = asPrimaryInputStream(inputStream) ) + { + return primaryStream->token; + } + else + { + PretokenizedInputStream* pretokenized = (PretokenizedInputStream*) inputStream; + return pretokenized->tokenReader.PeekToken(); + } +} + +// Peek one token type from an input stream +static TokenType PeekRawTokenType(PreprocessorInputStream* inputStream) +{ + if( auto primaryStream = asPrimaryInputStream(inputStream) ) + { + return primaryStream->token.type; + } + else + { + PretokenizedInputStream* pretokenized = (PretokenizedInputStream*) inputStream; + return pretokenized->tokenReader.PeekTokenType(); + } +} + + +// Read one token in "raw" mode (meaning don't expand macros) +static Token AdvanceRawToken(Preprocessor* preprocessor, LexerFlags lexerFlags = 0) +{ + for(;;) + { + // Look at the input stream on top of the stack + PreprocessorInputStream* inputStream = preprocessor->inputStream; + + // If there isn't one, then there is no more input left to read. + if(!inputStream) + { + return preprocessor->endOfFileToken; + } + + // The top-most input stream may be at its end + if(PeekRawTokenType(inputStream) == TokenType::EndOfFile) + { + // If there is another stream remaining, switch to it + if(inputStream->parent) + { + preprocessor->inputStream = inputStream->parent; + EndInputStream(preprocessor, inputStream); + continue; + } + } + + // Everything worked, so read a token from the top-most stream + return AdvanceRawToken( + inputStream, + lexerFlags | (IsSkipping(preprocessor) ? kLexerFlag_IgnoreInvalid : 0)); + } +} + +// Return the next token in "raw" mode, but don't advance the +// current token state. +static Token PeekRawToken(Preprocessor* preprocessor) +{ + // We need to find the stream that `advanceRawToken` would read from. + PreprocessorInputStream* inputStream = preprocessor->inputStream; + for (;;) + { + if (!inputStream) + { + // No more input streams left to read + return preprocessor->endOfFileToken; + } + + // The top-most input stream may be at its end, so + // look one entry up the stack (don't actually pop + // here, since we are just peeking) + if (PeekRawTokenType(inputStream) == TokenType::EndOfFile) + { + if (inputStream->parent) + { + inputStream = inputStream->parent; + continue; + } + } + + // Everything worked, so the token we just peeked is fine. + return PeekRawToken(inputStream); + } +} + +// Get the location of the current (raw) token +static SourceLoc PeekLoc(Preprocessor* preprocessor) +{ + return PeekRawToken(preprocessor).loc; +} + +// Get the `TokenType` of the current (raw) token +static TokenType PeekRawTokenType(Preprocessor* preprocessor) +{ + return PeekRawToken(preprocessor).type; +} + +// +// Macros +// + +// Create a macro +static PreprocessorMacro* CreateMacro(Preprocessor* preprocessor) +{ + // TODO(tfoley): Allocate these more intelligently. + // For example, consider pooling them on the preprocessor. + + PreprocessorMacro* macro = new PreprocessorMacro(); + macro->flavor = PreprocessorMacroFlavor::ObjectLike; + macro->environment = &preprocessor->globalEnv; + return macro; +} + +// Destroy a macro +static void DestroyMacro(Preprocessor* /*preprocessor*/, PreprocessorMacro* macro) +{ + delete macro; +} + + +// Find the currently-defined macro of the given name, or return NULL +static PreprocessorMacro* LookupMacro(PreprocessorEnvironment* environment, Name* name) +{ + for(PreprocessorEnvironment* e = environment; e; e = e->parent) + { + PreprocessorMacro* macro = NULL; + if (e->macros.TryGetValue(name, macro)) + return macro; + } + + return NULL; +} + +static PreprocessorEnvironment* GetCurrentEnvironment(Preprocessor* preprocessor) +{ + // The environment we will use for looking up a macro is associated + // with the current input stream (because it may include entries + // for macro arguments). + // + // We need to be careful, though, when we are at the end of an + // input stream (e.g., representing one argument), so that we + // don't use its environment. + + PreprocessorInputStream* inputStream = preprocessor->inputStream; + + for(;;) + { + // If there is no input stream that isn't at its end, + // then fall back to the global environment. + if (!inputStream) + return &preprocessor->globalEnv; + + // If the current input stream is at its end, then + // fall back to its parent stream. + if (PeekRawTokenType(inputStream) == TokenType::EndOfFile) + { + inputStream = inputStream->parent; + continue; + } + + // If we've found an active stream that isn't at its end, + // then use that for lookup. + return inputStream->environment; + } +} + +static PreprocessorMacro* LookupMacro(Preprocessor* preprocessor, Name* name) +{ + return LookupMacro(GetCurrentEnvironment(preprocessor), name); +} + +// A macro is "busy" if it is currently being used for expansion. +// A macro cannot be expanded again while busy, to avoid infinite recursion. +static bool IsMacroBusy(PreprocessorMacro* /*macro*/) +{ + // TODO: need to implement this correctly + // + // The challenge here is that we are implementing expansion + // for argumenst to function-like macros in a "lazy" fashion. + // + // The letter of the spec is that we should macro expand + // each argument *before* substitution, and then go and + // macro-expand the substituted body. This means that we + // can invoke a macro as part of an argument to an + // invocation of the same macro: + // + // FOO( 1, FOO(22), 333 ); + // + // In our implementation, the "inner" invocation of `FOO` + // gets expanded at the point where it gets referenced + // in the body of the "outer" invocation of `FOO`. + // Doing things this way leads to greatly simplified + // code for handling expansion. + // + // A proper implementation of `IsMacroBusy` needs to + // take context into account, so that it bans recursive + // use of a macro when it occurs (indirectly) through + // the *body* of the expansion, but not when it occcurs + // only through an *argument*. + return false; +} + +// +// Reading Tokens With Expansion +// + +static void InitializeMacroExpansion( + Preprocessor* preprocessor, + MacroExpansion* expansion, + PreprocessorMacro* macro) +{ + initializeInputStream(preprocessor, expansion); + + expansion->parent = preprocessor->inputStream; + expansion->primaryStream = preprocessor->inputStream->primaryStream; + + expansion->environment = macro->environment; + expansion->macro = macro; + expansion->tokenReader = TokenReader(macro->tokens); +} + +static void PushMacroExpansion( + Preprocessor* preprocessor, + MacroExpansion* expansion) +{ + PushInputStream(preprocessor, expansion); +} + +static void AddEndOfStreamToken( + Preprocessor* preprocessor, + PreprocessorMacro* macro) +{ + Token token = PeekRawToken(preprocessor); + token.type = TokenType::EndOfFile; + macro->tokens.mTokens.add(token); +} + +static SimpleTokenInputStream* createSimpleInputStream( + Preprocessor* preprocessor, + Token const& token) +{ + SimpleTokenInputStream* inputStream = new SimpleTokenInputStream(); + initializeInputStream(preprocessor, inputStream); + + inputStream->lexedTokens.mTokens.add(token); + + Token eofToken; + eofToken.type = TokenType::EndOfFile; + eofToken.loc = token.loc; + eofToken.flags = TokenFlag::AfterWhitespace | TokenFlag::AtStartOfLine; + inputStream->lexedTokens.mTokens.add(eofToken); + + inputStream->tokenReader = TokenReader(inputStream->lexedTokens); + + return inputStream; +} + +// Check whether the current token on the given input stream should be +// treated as a macro invocation, and if so set up state for expanding +// that macro. +static void MaybeBeginMacroExpansion( + Preprocessor* preprocessor ) +{ + // We iterate because the first token in the expansion of one + // macro may be another macro invocation. + for (;;) + { + // Look at the next token ahead of us + Token token = PeekRawToken(preprocessor); + + // Not an identifier? Can't be a macro. + if (token.type != TokenType::Identifier) + return; + + // Look for a macro with the given name. + Name* name = token.getName(); + PreprocessorMacro* macro = LookupMacro(preprocessor, name); + + // Not a macro? Can't be an invocation. + if (!macro) + return; + + // If the macro is busy (already being expanded), + // don't try to trigger recursive expansion + if (IsMacroBusy(macro)) + return; + + // We might already have looked at this token, + // and need to suppress expansion + if (token.flags & TokenFlag::SuppressMacroExpansion) + return; + + // A function-style macro invocation should only match + // if the token *after* the identifier is `(`. This + // requires more lookahead than we usually have/need + if (macro->flavor == PreprocessorMacroFlavor::FunctionLike) + { + // Consume the token that (possibly) triggered macro expansion + AdvanceRawToken(preprocessor); + + // Look at the next token, and see if it is an opening `(` + // that indicates we should actually expand a macro. + if(PeekRawTokenType(preprocessor) != TokenType::LParent) + { + // In this case, we are in a bit of a mess, because we have + // consumed the token that named the macro, but we need to + // make sure that token (and not whatever came after it) + // gets returned to the user. + // + // To work around this we will construct a short-lived input + // stream just to handle that one token, and also set + // a flag on the token to keep us from doing this logic again. + + token.flags |= TokenFlag::SuppressMacroExpansion; + + SimpleTokenInputStream* simpleStream = createSimpleInputStream(preprocessor, token); + PushInputStream(preprocessor, simpleStream); + return; + } + + // Consume the opening `(` + Token leftParen = AdvanceRawToken(preprocessor); + + FunctionLikeMacroExpansion* expansion = new FunctionLikeMacroExpansion(); + InitializeMacroExpansion(preprocessor, expansion, macro); + expansion->argumentEnvironment.parent = &preprocessor->globalEnv; + expansion->environment = &expansion->argumentEnvironment; + + // Try to read any arguments present. + UInt paramCount = macro->params.getCount(); + UInt argIndex = 0; + + switch (PeekRawTokenType(preprocessor)) + { + case TokenType::EndOfFile: + case TokenType::RParent: + // No arguments. + break; + + default: + // At least one argument + while(argIndex < paramCount) + { + // Read an argument + + // Create the argument, represented as a special flavor of macro + PreprocessorMacro* arg = CreateMacro(preprocessor); + arg->flavor = PreprocessorMacroFlavor::FunctionArg; + arg->environment = GetCurrentEnvironment(preprocessor); + + // Associate the new macro with its parameter name + NameLoc paramNameAndLoc = macro->params[argIndex]; + Name* paramName = paramNameAndLoc.name; + arg->nameAndLoc = paramNameAndLoc; + expansion->argumentEnvironment.macros[paramName] = arg; + argIndex++; + + // Read tokens for the argument + + // We track the nesting depth, since we don't break + // arguments on a `,` nested in balanced parentheses + // + int nesting = 0; + for (;;) + { + switch (PeekRawTokenType(preprocessor)) + { + case TokenType::EndOfFile: + // if we reach the end of the file, + // then we have an error, and need to + // bail out + AddEndOfStreamToken(preprocessor, arg); + goto doneWithAllArguments; + + case TokenType::RParent: + // If we see a right paren when we aren't nested + // then we are at the end of an argument + if (nesting == 0) + { + AddEndOfStreamToken(preprocessor, arg); + goto doneWithAllArguments; + } + // Otherwise we decrease our nesting depth, add + // the token, and keep going + nesting--; + break; + + case TokenType::Comma: + // If we see a comma when we aren't nested + // then we are at the end of an argument + if (nesting == 0) + { + AddEndOfStreamToken(preprocessor, arg); + AdvanceRawToken(preprocessor); + goto doneWithArgument; + } + // Otherwise we add it as a normal token + break; + + case TokenType::LParent: + // If we see a left paren then we need to + // increase our tracking of nesting + nesting++; + break; + + default: + break; + } + + // Add the token and continue parsing. + arg->tokens.mTokens.add(AdvanceRawToken(preprocessor)); + } + doneWithArgument: {} + // We've parsed an argument and should move onto + // the next one. + } + break; + } + doneWithAllArguments: + // TODO: handle possible varargs + + // Expect closing right paren + if (PeekRawTokenType(preprocessor) == TokenType::RParent) + { + AdvanceRawToken(preprocessor); + } + else + { + GetSink(preprocessor)->diagnose(PeekLoc(preprocessor), Diagnostics::expectedTokenInMacroArguments, TokenType::RParent, PeekRawTokenType(preprocessor)); + } + + UInt argCount = argIndex; + if (argCount != paramCount) + { + GetSink(preprocessor)->diagnose(PeekLoc(preprocessor), Diagnostics::wrongNumberOfArgumentsToMacro, paramCount, argCount); + } + + // We are ready to expand. + PushMacroExpansion(preprocessor, expansion); + } + else + { + // Consume the token that triggered macro expansion + AdvanceRawToken(preprocessor); + + // Object-like macros are the easy case. + ObjectLikeMacroExpansion* expansion = new ObjectLikeMacroExpansion(); + InitializeMacroExpansion(preprocessor, expansion, macro); + PushMacroExpansion(preprocessor, expansion); + } + } +} + +// Read one token with macro-expansion enabled. +static Token AdvanceToken(Preprocessor* preprocessor) +{ +top: + // Check whether we need to macro expand at the cursor. + MaybeBeginMacroExpansion(preprocessor); + + // Read a raw token (now that expansion has been triggered) + Token token = AdvanceRawToken(preprocessor); + + // Check if we need to perform token pasting + if (PeekRawTokenType(preprocessor) != TokenType::PoundPound) + { + // If we aren't token pasting, then we are done + return token; + } + else + { + // We are pasting tokens, which could get messy + + StringBuilder sb; + sb << token.Content; + + while (PeekRawTokenType(preprocessor) == TokenType::PoundPound) + { + // Consume the `##` + AdvanceRawToken(preprocessor); + + // Possibly macro-expand the next token + MaybeBeginMacroExpansion(preprocessor); + + // Read the next raw token (now that expansion has been triggered) + Token nextToken = AdvanceRawToken(preprocessor); + + sb << nextToken.Content; + } + + // Now re-lex the input + + SourceManager* sourceManager = preprocessor->getSourceManager(); + + // We create a dummy file to represent the token-paste operation + PathInfo pathInfo = PathInfo::makeTokenPaste(); + + SourceFile* sourceFile = sourceManager->createSourceFileWithString(pathInfo, sb.ProduceString()); + SourceView* sourceView = sourceManager->createSourceView(sourceFile, nullptr); + + Lexer lexer; + lexer.initialize(sourceView, GetSink(preprocessor), preprocessor->getNamePool(), sourceManager->getMemoryArena()); + + SimpleTokenInputStream* inputStream = new SimpleTokenInputStream(); + initializeInputStream(preprocessor, inputStream); + + inputStream->lexedTokens = lexer.lexAllTokens(); + inputStream->tokenReader = TokenReader(inputStream->lexedTokens); + + // We expect the reuslt of lexing to be two tokens: one for the actual value, + // and one for the end-of-input marker. + if (inputStream->tokenReader.GetCount() != 2) + { + // We expect a token paste to produce a single token + // TODO(tfoley): emit a diagnostic here + } + + PushInputStream(preprocessor, inputStream); + goto top; + } +} + +// Read one token with macro-expansion enabled. +// +// Note that because triggering macro expansion may +// involve changing the input-stream state, this +// operation *can* have side effects. +static Token PeekToken(Preprocessor* preprocessor) +{ + // Check whether we need to macro expand at the cursor. + MaybeBeginMacroExpansion(preprocessor); + + // Peek a raw token (now that expansion has been triggered) + return PeekRawToken(preprocessor); + + // TODO: need a plan for how to handle token pasting + // here without it being onerous. Would be nice if we + // didn't have to re-do pasting on a "peek"... +} + +// Peek the type of the next token, including macro expansion. +static TokenType PeekTokenType(Preprocessor* preprocessor) +{ + return PeekToken(preprocessor).type; +} + +// +// Preprocessor Directives +// + +// When reading a preprocessor directive, we use a context +// to wrap the direct preprocessor routines defines so far. +// +// One of the most important things the directive context +// does is give us a convenient way to read tokens with +// a guarantee that we won't read past the end of a line. +struct PreprocessorDirectiveContext +{ + // The preprocessor that is parsing the directive. + Preprocessor* preprocessor; + + // The directive token (e.g., the `if` in `#if`). + // Useful for reference in diagnostic messages. + Token directiveToken; + + // Has any kind of parse error been encountered in + // the directive so far? + bool parseError; + + // Have we done the necessary checks at the end + // of the directive already? + bool haveDoneEndOfDirectiveChecks; +}; + +// Get the token for the preprocessor directive being parsed. +inline Token const& GetDirective(PreprocessorDirectiveContext* context) +{ + return context->directiveToken; +} + +// Get the name of the directive being parsed. +inline UnownedStringSlice const& GetDirectiveName(PreprocessorDirectiveContext* context) +{ + return context->directiveToken.Content; +} + +// Get the location of the directive being parsed. +inline SourceLoc const& GetDirectiveLoc(PreprocessorDirectiveContext* context) +{ + return context->directiveToken.loc; +} + +// Wrapper to get the diagnostic sink in the context of a directive. +static inline DiagnosticSink* GetSink(PreprocessorDirectiveContext* context) +{ + return GetSink(context->preprocessor); +} + +// Wrapper to get a "current" location when parsing a directive +static SourceLoc PeekLoc(PreprocessorDirectiveContext* context) +{ + return PeekLoc(context->preprocessor); +} + +// Wrapper to look up a macro in the context of a directive. +static PreprocessorMacro* LookupMacro(PreprocessorDirectiveContext* context, Name* name) +{ + return LookupMacro(context->preprocessor, name); +} + +// Determine if we have read everything on the directive's line. +static bool IsEndOfLine(PreprocessorDirectiveContext* context) +{ + return PeekRawToken(context->preprocessor).type == TokenType::EndOfDirective; +} + +// Peek one raw token in a directive, without going past the end of the line. +static Token PeekRawToken(PreprocessorDirectiveContext* context) +{ + return PeekRawToken(context->preprocessor); +} + +// Read one raw token in a directive, without going past the end of the line. +static Token AdvanceRawToken(PreprocessorDirectiveContext* context, LexerFlags lexerFlags = 0) +{ + if (IsEndOfLine(context)) + return PeekRawToken(context); + return AdvanceRawToken(context->preprocessor, lexerFlags); +} + +// Peek next raw token type, without going past the end of the line. +static TokenType PeekRawTokenType(PreprocessorDirectiveContext* context) +{ + return PeekRawTokenType(context->preprocessor); +} + +// Read one token, with macro-expansion, without going past the end of the line. +static Token AdvanceToken(PreprocessorDirectiveContext* context) +{ + if (IsEndOfLine(context)) + return PeekRawToken(context); + return AdvanceToken(context->preprocessor); +} + +// Peek one token, with macro-expansion, without going past the end of the line. +static Token PeekToken(PreprocessorDirectiveContext* context) +{ + if (IsEndOfLine(context)) + return context->preprocessor->endOfFileToken; + return PeekToken(context->preprocessor); +} + +// Peek next token type, with macro-expansion, without going past the end of the line. +static TokenType PeekTokenType(PreprocessorDirectiveContext* context) +{ + if (IsEndOfLine(context)) + return TokenType::EndOfDirective; + return PeekTokenType(context->preprocessor); +} + +// Skip to the end of the line (useful for recovering from errors in a directive) +static void SkipToEndOfLine(PreprocessorDirectiveContext* context) +{ + while(!IsEndOfLine(context)) + { + AdvanceRawToken(context); + } +} + +static bool ExpectRaw(PreprocessorDirectiveContext* context, TokenType tokenType, DiagnosticInfo const& diagnostic, Token* outToken = NULL) +{ + if (PeekRawTokenType(context) != tokenType) + { + // Only report the first parse error within a directive + if (!context->parseError) + { + GetSink(context)->diagnose(PeekLoc(context), diagnostic, tokenType, GetDirectiveName(context)); + } + context->parseError = true; + return false; + } + Token const& token = AdvanceRawToken(context); + if (outToken) + *outToken = token; + return true; +} + +static bool Expect(PreprocessorDirectiveContext* context, TokenType tokenType, DiagnosticInfo const& diagnostic, Token* outToken = NULL) +{ + if (PeekTokenType(context) != tokenType) + { + // Only report the first parse error within a directive + if (!context->parseError) + { + GetSink(context)->diagnose(PeekLoc(context), diagnostic, tokenType, GetDirectiveName(context)); + context->parseError = true; + } + return false; + } + Token const& token = AdvanceToken(context); + if (outToken) + *outToken = token; + return true; +} + + + +// +// Preprocessor Conditionals +// + +// Determine whether the current preprocessor state means we +// should be skipping tokens. +static bool IsSkipping(Preprocessor* preprocessor) +{ + PreprocessorInputStream* inputStream = preprocessor->inputStream; + if (!inputStream) return false; + + PrimaryInputStream* primaryStream = inputStream->primaryStream; + if(!primaryStream) return false; + + // If we are not inside a preprocessor conditional, then don't skip + PreprocessorConditional* conditional = primaryStream->conditional; + if (!conditional) return false; + + // skip tokens unless the conditional is inside its `true` case + return conditional->state != PreprocessorConditionalState::During; +} + +// Wrapper for use inside directives +static inline bool IsSkipping(PreprocessorDirectiveContext* context) +{ + return IsSkipping(context->preprocessor); +} + +// Create a preprocessor conditional +static PreprocessorConditional* CreateConditional(Preprocessor* /*preprocessor*/) +{ + // TODO(tfoley): allocate these more intelligently (for example, + // pool them on the `Preprocessor`. + return new PreprocessorConditional(); +} + +// Destroy a preprocessor conditional. +static void DestroyConditional(PreprocessorConditional* conditional) +{ + delete conditional; +} + +// Start a preprocessor conditional, with an initial enable/disable state. +static void beginConditional( + PreprocessorDirectiveContext* context, + PreprocessorInputStream* inputStream, + bool enable) +{ + Preprocessor* preprocessor = context->preprocessor; + SLANG_ASSERT(inputStream); + + PreprocessorConditional* conditional = CreateConditional(preprocessor); + + conditional->ifToken = context->directiveToken; + + // Set state of this condition appropriately. + // + // Default to the "haven't yet seen a `true` branch" state. + PreprocessorConditionalState state = PreprocessorConditionalState::Before; + // + // If we are nested inside a `false` branch of another condition, then + // we never want to enable, so we act as if we already *saw* the `true` branch. + // + if (IsSkipping(preprocessor)) state = PreprocessorConditionalState::After; + // + // Similarly, if we ran into any parse errors when dealing with the + // opening directive, then things are probably screwy and we should just + // skip all the branches. + if (IsSkipping(preprocessor)) state = PreprocessorConditionalState::After; + // + // Otherwise, if our condition was true, then set us to be inside the `true` branch + else if (enable) state = PreprocessorConditionalState::During; + + conditional->state = state; + + // Push conditional onto the stack + auto primaryStream = inputStream->primaryStream; + conditional->parent = primaryStream->conditional; + primaryStream->conditional = conditional; +} + +// Start a preprocessor conditional, with an initial enable/disable state. +static void beginConditional( + PreprocessorDirectiveContext* context, + bool enable) +{ + beginConditional(context, context->preprocessor->inputStream, enable); +} + +// +// Preprocessor Conditional Expressions +// + +// Conditional expressions are always of type `int` +typedef int PreprocessorExpressionValue; + +// Forward-declaretion +static PreprocessorExpressionValue ParseAndEvaluateExpression(PreprocessorDirectiveContext* context); + +// Parse a unary (prefix) expression inside of a preprocessor directive. +static PreprocessorExpressionValue ParseAndEvaluateUnaryExpression(PreprocessorDirectiveContext* context) +{ + switch (PeekTokenType(context)) + { + // handle prefix unary ops + case TokenType::OpSub: + AdvanceToken(context); + return -ParseAndEvaluateUnaryExpression(context); + case TokenType::OpNot: + AdvanceToken(context); + return !ParseAndEvaluateUnaryExpression(context); + case TokenType::OpBitNot: + AdvanceToken(context); + return ~ParseAndEvaluateUnaryExpression(context); + + // handle parenthized sub-expression + case TokenType::LParent: + { + Token leftParen = AdvanceToken(context); + PreprocessorExpressionValue value = ParseAndEvaluateExpression(context); + if (!Expect(context, TokenType::RParent, Diagnostics::expectedTokenInPreprocessorExpression)) + { + GetSink(context)->diagnose(leftParen.loc, Diagnostics::seeOpeningToken, leftParen); + } + return value; + } + + case TokenType::IntegerLiteral: + return StringToInt(AdvanceToken(context).Content); + + case TokenType::Identifier: + { + Token token = AdvanceToken(context); + if (token.Content == "defined") + { + // handle `defined(someName)` + + // Possibly parse a `(` + Token leftParen; + if (PeekRawTokenType(context) == TokenType::LParent) + { + leftParen = AdvanceRawToken(context); + } + + // Expect an identifier + Token nameToken; + if (!ExpectRaw(context, TokenType::Identifier, Diagnostics::expectedTokenInDefinedExpression, &nameToken)) + { + return 0; + } + Name* name = nameToken.getName(); + + // If we saw an opening `(`, then expect one to close + if (leftParen.type != TokenType::Unknown) + { + if(!ExpectRaw(context, TokenType::RParent, Diagnostics::expectedTokenInDefinedExpression)) + { + GetSink(context)->diagnose(leftParen.loc, Diagnostics::seeOpeningToken, leftParen); + return 0; + } + } + + return LookupMacro(context, name) != NULL; + } + + // An identifier here means it was not defined as a macro (or + // it is defined, but as a function-like macro. These should + // just evaluate to zero (possibly with a warning) + GetSink(context)->diagnose(token.loc, Diagnostics::undefinedIdentifierInPreprocessorExpression, token.getName()); + return 0; + } + + default: + GetSink(context)->diagnose(PeekLoc(context), Diagnostics::syntaxErrorInPreprocessorExpression); + return 0; + } +} + +// Determine the precedence level of an infix operator +// for use in parsing preprocessor conditionals. +static int GetInfixOpPrecedence(Token const& opToken) +{ + // If token is on another line, it is not part of the + // expression + if (opToken.flags & TokenFlag::AtStartOfLine) + return -1; + + // otherwise we look at the token type to figure + // out what precedence it should be parse with + switch (opToken.type) + { + default: + // tokens that aren't infix operators should + // cause us to stop parsing an expression + return -1; + + case TokenType::OpMul: return 10; + case TokenType::OpDiv: return 10; + case TokenType::OpMod: return 10; + + case TokenType::OpAdd: return 9; + case TokenType::OpSub: return 9; + + case TokenType::OpLsh: return 8; + case TokenType::OpRsh: return 8; + + case TokenType::OpLess: return 7; + case TokenType::OpGreater: return 7; + case TokenType::OpLeq: return 7; + case TokenType::OpGeq: return 7; + + case TokenType::OpEql: return 6; + case TokenType::OpNeq: return 6; + + case TokenType::OpBitAnd: return 5; + case TokenType::OpBitOr: return 4; + case TokenType::OpBitXor: return 3; + case TokenType::OpAnd: return 2; + case TokenType::OpOr: return 1; + } +}; + +// Evaluate one infix operation in a preprocessor +// conditional expression +static PreprocessorExpressionValue EvaluateInfixOp( + PreprocessorDirectiveContext* context, + Token const& opToken, + PreprocessorExpressionValue left, + PreprocessorExpressionValue right) +{ + switch (opToken.type) + { + default: +// SLANG_INTERNAL_ERROR(getSink(preprocessor), opToken); + return 0; + break; + + case TokenType::OpMul: return left * right; + case TokenType::OpDiv: + { + if (right == 0) + { + if (!context->parseError) + { + GetSink(context)->diagnose(opToken.loc, Diagnostics::divideByZeroInPreprocessorExpression); + } + return 0; + } + return left / right; + } + case TokenType::OpMod: + { + if (right == 0) + { + if (!context->parseError) + { + GetSink(context)->diagnose(opToken.loc, Diagnostics::divideByZeroInPreprocessorExpression); + } + return 0; + } + return left % right; + } + case TokenType::OpAdd: return left + right; + case TokenType::OpSub: return left - right; + case TokenType::OpLsh: return left << right; + case TokenType::OpRsh: return left >> right; + case TokenType::OpLess: return left < right ? 1 : 0; + case TokenType::OpGreater: return left > right ? 1 : 0; + case TokenType::OpLeq: return left <= right ? 1 : 0; + case TokenType::OpGeq: return left >= right ? 1 : 0; + case TokenType::OpEql: return left == right ? 1 : 0; + case TokenType::OpNeq: return left != right ? 1 : 0; + case TokenType::OpBitAnd: return left & right; + case TokenType::OpBitOr: return left | right; + case TokenType::OpBitXor: return left ^ right; + case TokenType::OpAnd: return left && right; + case TokenType::OpOr: return left || right; + } +} + +// Parse the rest of an infix preprocessor expression with +// precedence greater than or equal to the given `precedence` argument. +// The value of the left-hand-side expression is provided as +// an argument. +// This is used to form a simple recursive-descent expression parser. +static PreprocessorExpressionValue ParseAndEvaluateInfixExpressionWithPrecedence( + PreprocessorDirectiveContext* context, + PreprocessorExpressionValue left, + int precedence) +{ + for (;;) + { + // Look at the next token, and see if it is an operator of + // high enough precedence to be included in our expression + Token opToken = PeekToken(context); + int opPrecedence = GetInfixOpPrecedence(opToken); + + // If it isn't an operator of high enough precedence, we are done. + if(opPrecedence < precedence) + break; + + // Otherwise we need to consume the operator token. + AdvanceToken(context); + + // Next we parse a right-hand-side expression by starting with + // a unary expression and absorbing and many infix operators + // as possible with strictly higher precedence than the operator + // we found above. + PreprocessorExpressionValue right = ParseAndEvaluateUnaryExpression(context); + for (;;) + { + // Look for an operator token + Token rightOpToken = PeekToken(context); + int rightOpPrecedence = GetInfixOpPrecedence(rightOpToken); + + // If no operator was found, or the operator wasn't high + // enough precedence to fold into the right-hand-side, + // exit this loop. + if (rightOpPrecedence <= opPrecedence) + break; + + // Now invoke the parser recursively, passing in our + // existing right-hand side to form an even larger one. + right = ParseAndEvaluateInfixExpressionWithPrecedence( + context, + right, + rightOpPrecedence); + } + + // Now combine the left- and right-hand sides using + // the operator we found above. + left = EvaluateInfixOp(context, opToken, left, right); + } + return left; +} + +// Parse a complete (infix) preprocessor expression, and return its value +static PreprocessorExpressionValue ParseAndEvaluateExpression(PreprocessorDirectiveContext* context) +{ + // First read in the left-hand side (or the whole expression in the unary case) + PreprocessorExpressionValue value = ParseAndEvaluateUnaryExpression(context); + + // Try to read in trailing infix operators with correct precedence + return ParseAndEvaluateInfixExpressionWithPrecedence(context, value, 0); +} + +// Handle a `#if` directive +static void HandleIfDirective(PreprocessorDirectiveContext* context) +{ + // Record current input stream in case preprocessor expression + // changes the input stream to a macro expansion while we + // are parsing. + auto inputStream = context->preprocessor->inputStream; + + // If we are skipping, we can just consume the expression, and assume true + if (IsSkipping(context->preprocessor)) + { + // Consume everything until the end of the line + SkipToEndOfLine(context); + // Begin a preprocessor block, assume true based on the expression + // (contents will all be ignored because skipping). + beginConditional(context, inputStream, true); + } + else + { + // Parse a preprocessor expression. + PreprocessorExpressionValue value = ParseAndEvaluateExpression(context); + + // Begin a preprocessor block, enabled based on the expression. + beginConditional(context, inputStream, value != 0); + } +} + +// Handle a `#ifdef` directive +static void HandleIfDefDirective(PreprocessorDirectiveContext* context) +{ + // Expect a raw identifier, so we can check if it is defined + Token nameToken; + if(!ExpectRaw(context, TokenType::Identifier, Diagnostics::expectedTokenInPreprocessorDirective, &nameToken)) + return; + Name* name = nameToken.getName(); + + // Check if the name is defined. + beginConditional(context, LookupMacro(context, name) != NULL); +} + +// Handle a `#ifndef` directive +static void HandleIfNDefDirective(PreprocessorDirectiveContext* context) +{ + // Expect a raw identifier, so we can check if it is defined + Token nameToken; + if(!ExpectRaw(context, TokenType::Identifier, Diagnostics::expectedTokenInPreprocessorDirective, &nameToken)) + return; + Name* name = nameToken.getName(); + + // Check if the name is defined. + beginConditional(context, LookupMacro(context, name) == NULL); +} + +// Handle a `#else` directive +static void HandleElseDirective(PreprocessorDirectiveContext* context) +{ + PreprocessorInputStream* inputStream = context->preprocessor->inputStream; + SLANG_ASSERT(inputStream); + + // if we aren't inside a conditional, then error + PreprocessorConditional* conditional = inputStream->primaryStream->conditional; + if (!conditional) + { + GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::directiveWithoutIf, GetDirectiveName(context)); + return; + } + + // if we've already seen a `#else`, then it is an error + if (conditional->elseToken.type != TokenType::Unknown) + { + GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::directiveAfterElse, GetDirectiveName(context)); + GetSink(context)->diagnose(conditional->elseToken.loc, Diagnostics::seeDirective); + return; + } + conditional->elseToken = context->directiveToken; + + switch (conditional->state) + { + case PreprocessorConditionalState::Before: + conditional->state = PreprocessorConditionalState::During; + break; + + case PreprocessorConditionalState::During: + conditional->state = PreprocessorConditionalState::After; + break; + + default: + break; + } +} + +// Handle a `#elif` directive +static void HandleElifDirective(PreprocessorDirectiveContext* context) +{ + // Need to grab current input stream *before* we try to parse + // the conditional expression. + PreprocessorInputStream* inputStream = context->preprocessor->inputStream; + SLANG_ASSERT(inputStream); + + // HACK(tfoley): handle an empty `elif` like an `else` directive + // + // This is the behavior expected by at least one input program. + // We will eventually want to be pedantic about this. + // even if t + if (PeekRawTokenType(context) == TokenType::EndOfDirective) + { + GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::directiveExpectsExpression, GetDirectiveName(context)); + HandleElseDirective(context); + return; + } + + PreprocessorExpressionValue value = ParseAndEvaluateExpression(context); + + // if we aren't inside a conditional, then error + PreprocessorConditional* conditional = inputStream->primaryStream->conditional; + if (!conditional) + { + GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::directiveWithoutIf, GetDirectiveName(context)); + return; + } + + // if we've already seen a `#else`, then it is an error + if (conditional->elseToken.type != TokenType::Unknown) + { + GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::directiveAfterElse, GetDirectiveName(context)); + GetSink(context)->diagnose(conditional->elseToken.loc, Diagnostics::seeDirective); + return; + } + + switch (conditional->state) + { + case PreprocessorConditionalState::Before: + if(value) + conditional->state = PreprocessorConditionalState::During; + break; + + case PreprocessorConditionalState::During: + conditional->state = PreprocessorConditionalState::After; + break; + + default: + break; + } +} + +// Handle a `#endif` directive +static void HandleEndIfDirective(PreprocessorDirectiveContext* context) +{ + PreprocessorInputStream* inputStream = context->preprocessor->inputStream; + SLANG_ASSERT(inputStream); + + // if we aren't inside a conditional, then error + PreprocessorConditional* conditional = inputStream->primaryStream->conditional; + if (!conditional) + { + GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::directiveWithoutIf, GetDirectiveName(context)); + return; + } + + inputStream->primaryStream->conditional = conditional->parent; + DestroyConditional(conditional); +} + +// Helper routine to check that we find the end of a directive where +// we expect it. +// +// Most directives do not need to call this directly, since we have +// a catch-all case in the main `HandleDirective()` function. +// The `#include` case will call it directly to avoid complications +// when it switches the input stream. +static void expectEndOfDirective(PreprocessorDirectiveContext* context) +{ + if(context->haveDoneEndOfDirectiveChecks) + return; + + context->haveDoneEndOfDirectiveChecks = true; + + if (!IsEndOfLine(context)) + { + // If we already saw a previous parse error, then don't + // emit another one for the same directive. + if (!context->parseError) + { + GetSink(context)->diagnose(PeekLoc(context), Diagnostics::unexpectedTokensAfterDirective, GetDirectiveName(context)); + } + SkipToEndOfLine(context); + } + + // Clear out the end-of-directive token + AdvanceRawToken(context->preprocessor); +} + + /// Read a file in the context of handling a preprocessor directive +static SlangResult readFile( + PreprocessorDirectiveContext* context, + String const& path, + ISlangBlob** outBlob) +{ + // The actual file loading will be handled by the file system + // associated with the parent linkage. + // + auto linkage = context->preprocessor->linkage; + auto fileSystemExt = linkage->getFileSystemExt(); + SLANG_RETURN_ON_FAIL(fileSystemExt->loadFile(path.getBuffer(), outBlob)); + + // If we are running the preprocessor as part of compiling a + // specific module, then we must keep track of the file we've + // read as yet another file that the module will depend on. + // + if(auto module = context->preprocessor->parentModule) + { + module->addFilePathDependency(path); + } + + return SLANG_OK; +} + +// Handle a `#include` directive +static void HandleIncludeDirective(PreprocessorDirectiveContext* context) +{ + // Consume the directive, and inform the lexer to process the remainder of the line as a file path. + AdvanceRawToken(context, kLexerFlag_ExpectFileName); + + Token pathToken; + if(!Expect(context, TokenType::StringLiteral, Diagnostics::expectedTokenInPreprocessorDirective, &pathToken)) + return; + + String path = getFileNameTokenValue(pathToken); + + auto directiveLoc = GetDirectiveLoc(context); + + PathInfo includedFromPathInfo = context->preprocessor->getSourceManager()->getPathInfo(directiveLoc, SourceLocType::Actual); + + IncludeHandler* includeHandler = context->preprocessor->includeHandler; + if (!includeHandler) + { + GetSink(context)->diagnose(pathToken.loc, Diagnostics::includeFailed, path); + GetSink(context)->diagnose(pathToken.loc, Diagnostics::noIncludeHandlerSpecified); + return; + } + + /* Find the path relative to the foundPath */ + PathInfo filePathInfo; + if (SLANG_FAILED(includeHandler->findFile(path, includedFromPathInfo.foundPath, filePathInfo))) + { + GetSink(context)->diagnose(pathToken.loc, Diagnostics::includeFailed, path); + return; + } + + // We must have a uniqueIdentity to be compare + if (!filePathInfo.hasUniqueIdentity()) + { + GetSink(context)->diagnose(pathToken.loc, Diagnostics::noUniqueIdentity, path); + return; + } + + // Do all checking related to the end of this directive before we push a new stream, + // just to avoid complications where that check would need to deal with + // a switch of input stream + expectEndOfDirective(context); + + // Check whether we've previously included this file and seen a `#pragma once` directive + if(context->preprocessor->pragmaOnceUniqueIdentities.Contains(filePathInfo.uniqueIdentity)) + { + return; + } + + // Simplify the path + filePathInfo.foundPath = includeHandler->simplifyPath(filePathInfo.foundPath); + + // Push the new file onto our stack of input streams + // TODO(tfoley): check if we have made our include stack too deep + auto sourceManager = context->preprocessor->getSourceManager(); + + // See if this an already loaded source file + SourceFile* sourceFile = sourceManager->findSourceFileRecursively(filePathInfo.uniqueIdentity); + // If not create a new one, and add to the list of known source files + if (!sourceFile) + { + ComPtr foundSourceBlob; + if (SLANG_FAILED(readFile(context, filePathInfo.foundPath, foundSourceBlob.writeRef()))) + { + GetSink(context)->diagnose(pathToken.loc, Diagnostics::includeFailed, path); + return; + } + + + sourceFile = sourceManager->createSourceFileWithBlob(filePathInfo, foundSourceBlob); + sourceManager->addSourceFile(filePathInfo.uniqueIdentity, sourceFile); + } + + // This is a new parse (even if it's a pre-existing source file), so create a new SourceUnit + SourceView* sourceView = sourceManager->createSourceView(sourceFile, &filePathInfo); + + PreprocessorInputStream* inputStream = CreateInputStreamForSource(context->preprocessor, sourceView); + inputStream->parent = context->preprocessor->inputStream; + context->preprocessor->inputStream = inputStream; +} + +// Handle a `#define` directive +static void HandleDefineDirective(PreprocessorDirectiveContext* context) +{ + Token nameToken; + if (!ExpectRaw(context, TokenType::Identifier, Diagnostics::expectedTokenInPreprocessorDirective, &nameToken)) + return; + Name* name = nameToken.getName(); + + PreprocessorMacro* macro = CreateMacro(context->preprocessor); + macro->nameAndLoc = NameLoc(nameToken); + + PreprocessorMacro* oldMacro = LookupMacro(&context->preprocessor->globalEnv, name); + if (oldMacro) + { + GetSink(context)->diagnose(nameToken.loc, Diagnostics::macroRedefinition, name); + GetSink(context)->diagnose(oldMacro->getLoc(), Diagnostics::seePreviousDefinitionOf, name); + + DestroyMacro(context->preprocessor, oldMacro); + } + context->preprocessor->globalEnv.macros[name] = macro; + + // If macro name is immediately followed (with no space) by `(`, + // then we have a function-like macro + if (PeekRawTokenType(context) == TokenType::LParent) + { + if (!(PeekRawToken(context).flags & TokenFlag::AfterWhitespace)) + { + // This is a function-like macro, so we need to remember that + // and start capturing parameters + macro->flavor = PreprocessorMacroFlavor::FunctionLike; + + AdvanceRawToken(context); + + // If there are any parameters, parse them + if (PeekRawTokenType(context) != TokenType::RParent) + { + for (;;) + { + // TODO: handle elipsis (`...`) for varags + + // A macro parameter name should be a raw identifier + Token paramToken; + if (!ExpectRaw(context, TokenType::Identifier, Diagnostics::expectedTokenInMacroParameters, ¶mToken)) + break; + + // TODO(tfoley): some validation on parameter name. + // Certain names (e.g., `defined` and `__VA_ARGS__` + // are not allowed to be used as macros or parameters). + + // Add the parameter to the macro being deifned + macro->params.add(paramToken); + + // If we see `)` then we are done with arguments + if (PeekRawTokenType(context) == TokenType::RParent) + break; + + ExpectRaw(context, TokenType::Comma, Diagnostics::expectedTokenInMacroParameters); + } + } + + ExpectRaw(context, TokenType::RParent, Diagnostics::expectedTokenInMacroParameters); + } + } + + // consume tokens until end-of-line + for(;;) + { + Token token = AdvanceRawToken(context); + if( token.type == TokenType::EndOfDirective ) + { + // Last token on line will be turned into a conceptual end-of-file + // token for the sub-stream that the macro expands into. + token.type = TokenType::EndOfFile; + macro->tokens.mTokens.add(token); + break; + } + + // In the ordinary case, we just add the token to the definition + macro->tokens.mTokens.add(token); + } +} + +// Handle a `#undef` directive +static void HandleUndefDirective(PreprocessorDirectiveContext* context) +{ + Token nameToken; + if (!ExpectRaw(context, TokenType::Identifier, Diagnostics::expectedTokenInPreprocessorDirective, &nameToken)) + return; + Name* name = nameToken.getName(); + + PreprocessorEnvironment* env = &context->preprocessor->globalEnv; + PreprocessorMacro* macro = LookupMacro(env, name); + if (macro != NULL) + { + // name was defined, so remove it + env->macros.Remove(name); + + DestroyMacro(context->preprocessor, macro); + } + else + { + // name wasn't defined + GetSink(context)->diagnose(nameToken.loc, Diagnostics::macroNotDefined, name); + } +} + +// Handle a `#warning` directive +static void HandleWarningDirective(PreprocessorDirectiveContext* context) +{ + // Consume the directive, and inform the lexer to process the remainder of the line as a custom message. + AdvanceRawToken(context, kLexerFlag_ExpectDirectiveMessage); + + // Read the message token. + Token messageToken; + Expect(context, TokenType::DirectiveMessage, Diagnostics::expectedTokenInPreprocessorDirective, &messageToken); + + // Report the custom error. + GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::userDefinedWarning, messageToken.Content); +} + +// Handle a `#error` directive +static void HandleErrorDirective(PreprocessorDirectiveContext* context) +{ + // Consume the directive, and inform the lexer to process the remainder of the line as a custom message. + AdvanceRawToken(context, kLexerFlag_ExpectDirectiveMessage); + + // Read the message token. + Token messageToken; + Expect(context, TokenType::DirectiveMessage, Diagnostics::expectedTokenInPreprocessorDirective, &messageToken); + + // Report the custom error. + GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::userDefinedError, messageToken.Content); +} + +// Handle a `#line` directive +static void HandleLineDirective(PreprocessorDirectiveContext* context) +{ + auto inputStream = context->preprocessor->inputStream; + + int line = 0; + + SourceLoc directiveLoc = GetDirectiveLoc(context); + + // `#line ...` + if (PeekTokenType(context) == TokenType::IntegerLiteral) + { + line = StringToInt(AdvanceToken(context).Content); + } + // `#line` + // `#line default` + else if ( + PeekTokenType(context) == TokenType::EndOfDirective + || (PeekTokenType(context) == TokenType::Identifier + && PeekToken(context).Content == "default")) + { + AdvanceToken(context); + + // Stop overriding source locations. + auto sourceView = inputStream->primaryStream->lexer.sourceView; + sourceView->addDefaultLineDirective(directiveLoc); + return; + } + else + { + GetSink(context)->diagnose(PeekLoc(context), Diagnostics::expected2TokensInPreprocessorDirective, + TokenType::IntegerLiteral, + "default", + GetDirectiveName(context)); + context->parseError = true; + return; + } + + auto sourceManager = context->preprocessor->getSourceManager(); + + String file; + if (PeekTokenType(context) == TokenType::EndOfDirective) + { + file = sourceManager->getPathInfo(directiveLoc).foundPath; + } + else if (PeekTokenType(context) == TokenType::StringLiteral) + { + file = getStringLiteralTokenValue(AdvanceToken(context)); + } + else if (PeekTokenType(context) == TokenType::IntegerLiteral) + { + // Note(tfoley): GLSL allows the "source string" to be indicated by an integer + // TODO(tfoley): Figure out a better way to handle this, if it matters + file = AdvanceToken(context).Content; + } + else + { + Expect(context, TokenType::StringLiteral, Diagnostics::expectedTokenInPreprocessorDirective); + return; + } + + auto sourceView = inputStream->primaryStream->lexer.sourceView; + sourceView->addLineDirective(directiveLoc, file, line); +} + +#define SLANG_PRAGMA_DIRECTIVE_CALLBACK(NAME) \ + void NAME(PreprocessorDirectiveContext* context, Token subDirectiveToken) + +// Callback interface used by `#pragma` directives +typedef SLANG_PRAGMA_DIRECTIVE_CALLBACK((*PragmaDirectiveCallback)); + +SLANG_PRAGMA_DIRECTIVE_CALLBACK(handleUnknownPragmaDirective) +{ + GetSink(context)->diagnose(subDirectiveToken, Diagnostics::unknownPragmaDirectiveIgnored, subDirectiveToken.getName()); + SkipToEndOfLine(context); + return; +} + +SLANG_PRAGMA_DIRECTIVE_CALLBACK(handlePragmaOnceDirective) +{ + // We need to identify the path of the file we are preprocessing, + // so that we can avoid including it again. + // + // We are using the 'uniqueIdentity' as determined by the ISlangFileSystemEx interface to determine file identities. + + auto directiveLoc = GetDirectiveLoc(context); + auto issuedFromPathInfo = context->preprocessor->getSourceManager()->getPathInfo(directiveLoc, SourceLocType::Actual); + + // Must have uniqueIdentity for a #pragma once to work + if (!issuedFromPathInfo.hasUniqueIdentity()) + { + GetSink(context)->diagnose(subDirectiveToken, Diagnostics::pragmaOnceIgnored); + return; + } + + context->preprocessor->pragmaOnceUniqueIdentities.Add(issuedFromPathInfo.uniqueIdentity); +} + +// Information about a specific `#pragma` directive +struct PragmaDirective +{ + // name of the directive + char const* name; + + // Callback to handle the directive + PragmaDirectiveCallback callback; +}; + +// A simple array of all the `#pragma` directives we know how to handle. +static const PragmaDirective kPragmaDirectives[] = +{ + { "once", &handlePragmaOnceDirective }, + + { NULL, NULL }, +}; + +static const PragmaDirective kUnknownPragmaDirective = { + NULL, &handleUnknownPragmaDirective, +}; + +// Look up the `#pragma` directive with the given name. +static PragmaDirective const* findPragmaDirective(String const& name) +{ + char const* nameStr = name.getBuffer(); + for (int ii = 0; kPragmaDirectives[ii].name; ++ii) + { + if (strcmp(kPragmaDirectives[ii].name, nameStr) != 0) + continue; + + return &kPragmaDirectives[ii]; + } + + return &kUnknownPragmaDirective; +} + +// Handle a `#pragma` directive +static void HandlePragmaDirective(PreprocessorDirectiveContext* context) +{ + // Try to read the sub-directive name. + Token subDirectiveToken = PeekRawToken(context); + + // The sub-directive had better be an identifier + if (subDirectiveToken.type != TokenType::Identifier) + { + GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::expectedPragmaDirectiveName); + SkipToEndOfLine(context); + return; + } + AdvanceRawToken(context); + + // Look up the handler for the sub-directive. + PragmaDirective const* subDirective = findPragmaDirective(subDirectiveToken.getName()->text); + + // Apply the sub-directive-specific callback + (subDirective->callback)(context, subDirectiveToken); +} + +// Handle an invalid directive +static void HandleInvalidDirective(PreprocessorDirectiveContext* context) +{ + GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::unknownPreprocessorDirective, GetDirectiveName(context)); + SkipToEndOfLine(context); +} + +// Callback interface used by preprocessor directives +typedef void (*PreprocessorDirectiveCallback)(PreprocessorDirectiveContext* context); + +enum PreprocessorDirectiveFlag : unsigned int +{ + // Should this directive be handled even when skipping disbaled code? + ProcessWhenSkipping = 1 << 0, + + /// Allow the handler for this directive to advance past the + /// directive token itself, so that it can control lexer behavior + /// more closely. + DontConsumeDirectiveAutomatically = 1 << 1, +}; + +// Information about a specific directive +struct PreprocessorDirective +{ + // name of the directive + char const* name; + + // Callback to handle the directive + PreprocessorDirectiveCallback callback; + + unsigned int flags; +}; + +// A simple array of all the directives we know how to handle. +// TODO(tfoley): considering making this into a real hash map, +// and then make it easy-ish for users of the codebase to add +// their own directives as desired. +static const PreprocessorDirective kDirectives[] = +{ + { "if", &HandleIfDirective, ProcessWhenSkipping }, + { "ifdef", &HandleIfDefDirective, ProcessWhenSkipping }, + { "ifndef", &HandleIfNDefDirective, ProcessWhenSkipping }, + { "else", &HandleElseDirective, ProcessWhenSkipping }, + { "elif", &HandleElifDirective, ProcessWhenSkipping }, + { "endif", &HandleEndIfDirective, ProcessWhenSkipping }, + + { "include", &HandleIncludeDirective, DontConsumeDirectiveAutomatically }, + { "define", &HandleDefineDirective, 0 }, + { "undef", &HandleUndefDirective, 0 }, + { "warning", &HandleWarningDirective, DontConsumeDirectiveAutomatically }, + { "error", &HandleErrorDirective, DontConsumeDirectiveAutomatically }, + { "line", &HandleLineDirective, 0 }, + { "pragma", &HandlePragmaDirective, 0 }, + + { nullptr, nullptr, 0 }, +}; + +static const PreprocessorDirective kInvalidDirective = { + nullptr, &HandleInvalidDirective, 0, +}; + +// Look up the directive with the given name. +static PreprocessorDirective const* FindDirective(String const& name) +{ + char const* nameStr = name.getBuffer(); + for (int ii = 0; kDirectives[ii].name; ++ii) + { + if (strcmp(kDirectives[ii].name, nameStr) != 0) + continue; + + return &kDirectives[ii]; + } + + return &kInvalidDirective; +} + +// Process a directive, where the preprocessor has already consumed the +// `#` token that started the directive line. +static void HandleDirective(PreprocessorDirectiveContext* context) +{ + // Try to read the directive name. + context->directiveToken = PeekRawToken(context); + + TokenType directiveTokenType = GetDirective(context).type; + + // An empty directive is allowed, and ignored. + if (directiveTokenType == TokenType::EndOfDirective) + { + return; + } + // Otherwise the directive name had better be an identifier + else if (directiveTokenType != TokenType::Identifier) + { + GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::expectedPreprocessorDirectiveName); + SkipToEndOfLine(context); + return; + } + + // Look up the handler for the directive. + PreprocessorDirective const* directive = FindDirective(GetDirectiveName(context)); + + // If we are skipping disabled code, and the directive is not one + // of the small number that need to run even in that case, skip it. + if (IsSkipping(context) && !(directive->flags & PreprocessorDirectiveFlag::ProcessWhenSkipping)) + { + SkipToEndOfLine(context); + return; + } + + if(!(directive->flags & PreprocessorDirectiveFlag::DontConsumeDirectiveAutomatically)) + { + // Consume the directive name token. + AdvanceRawToken(context); + } + + // Apply the directive-specific callback + (directive->callback)(context); + + // We expect the directive callback to consume the entire line, so if + // it hasn't that is a parse error. + expectEndOfDirective(context); +} + +// Read one token using the full preprocessor, with all its behaviors. +static Token ReadToken(Preprocessor* preprocessor) +{ + for (;;) + { + // Depending on what the lookahead token is, we + // might need to start expanding it. + // + // Note: doing this at the start of this loop + // is important, in case a macro has an empty + // expansion, and we end up looking at a different + // token after applying the expansion. + if(!IsSkipping(preprocessor)) + { + MaybeBeginMacroExpansion(preprocessor); + } + + // Look at the next raw token in the input. + Token const& token = PeekRawToken(preprocessor); + if (token.type == TokenType::EndOfFile) + return token; + + // If we have a directive (`#` at start of line) then handle it + if ((token.type == TokenType::Pound) && (token.flags & TokenFlag::AtStartOfLine)) + { + // Skip the `#` + AdvanceRawToken(preprocessor); + + // Create a context for parsing the directive + PreprocessorDirectiveContext directiveContext; + directiveContext.preprocessor = preprocessor; + directiveContext.parseError = false; + directiveContext.haveDoneEndOfDirectiveChecks = false; + + // Parse and handle the directive + HandleDirective(&directiveContext); + continue; + } + + // otherwise, if we are currently in a skipping mode, then skip tokens + if (IsSkipping(preprocessor)) + { + AdvanceRawToken(preprocessor); + continue; + } + + // otherwise read a token, which may involve macro expansion + return AdvanceToken(preprocessor); + } +} + +// intialize a preprocessor context, using the given sink for errros +static void InitializePreprocessor( + Preprocessor* preprocessor, + DiagnosticSink* sink) +{ + preprocessor->sink = sink; + preprocessor->includeHandler = NULL; + preprocessor->endOfFileToken.type = TokenType::EndOfFile; + preprocessor->endOfFileToken.flags = TokenFlag::AtStartOfLine; +} + +// clean up after an environment +PreprocessorEnvironment::~PreprocessorEnvironment() +{ + for (auto pair : this->macros) + { + DestroyMacro(NULL, pair.Value); + } +} + +// finalize a preprocessor and free any memory still in use +static void FinalizePreprocessor( + Preprocessor* preprocessor) +{ + // Clear out any waiting input streams + PreprocessorInputStream* input = preprocessor->inputStream; + while (input) + { + PreprocessorInputStream* parent = input->parent; + EndInputStream(preprocessor, input); + input = parent; + } + +#if 0 + // clean up any macros that were allocated + for (auto pair : preprocessor->globalEnv.macros) + { + DestroyMacro(preprocessor, pair.Value); + } +#endif +} + +// Add a simple macro definition from a string (e.g., for a +// `-D` option passed on the command line +static void DefineMacro( + Preprocessor* preprocessor, + String const& key, + String const& value) +{ + PathInfo pathInfo = PathInfo::makeCommandLine(); + + PreprocessorMacro* macro = CreateMacro(preprocessor); + + auto sourceManager = preprocessor->getSourceManager(); + + SourceFile* keyFile = sourceManager->createSourceFileWithString(pathInfo, key); + SourceFile* valueFile = sourceManager->createSourceFileWithString(pathInfo, value); + + SourceView* keyView = sourceManager->createSourceView(keyFile, nullptr); + SourceView* valueView = sourceManager->createSourceView(valueFile, nullptr); + + // Use existing `Lexer` to generate a token stream. + Lexer lexer; + lexer.initialize(valueView, GetSink(preprocessor), preprocessor->getNamePool(), sourceManager->getMemoryArena()); + macro->tokens = lexer.lexAllTokens(); + + Name* keyName = preprocessor->getNamePool()->getName(key); + + macro->nameAndLoc.name = keyName; + macro->nameAndLoc.loc = keyView->getRange().begin; + + PreprocessorMacro* oldMacro = NULL; + if (preprocessor->globalEnv.macros.TryGetValue(keyName, oldMacro)) + { + DestroyMacro(preprocessor, oldMacro); + } + + preprocessor->globalEnv.macros[keyName] = macro; +} + +// read the entire input into tokens +static TokenList ReadAllTokens( + Preprocessor* preprocessor) +{ + TokenList tokens; + for (;;) + { + Token token = ReadToken(preprocessor); + + tokens.mTokens.add(token); + + // Note: we include the EOF token in the list, + // since that is expected by the `TokenList` type. + if (token.type == TokenType::EndOfFile) + break; + } + return tokens; +} + +TokenList preprocessSource( + SourceFile* file, + DiagnosticSink* sink, + IncludeHandler* includeHandler, + Dictionary defines, + Linkage* linkage, + Module* parentModule) +{ + Preprocessor preprocessor; + InitializePreprocessor(&preprocessor, sink); + preprocessor.linkage = linkage; + preprocessor.parentModule = parentModule; + + preprocessor.includeHandler = includeHandler; + for (auto p : defines) + { + DefineMacro(&preprocessor, p.Key, p.Value); + } + + SourceManager* sourceManager = linkage->getSourceManager(); + + SourceView* sourceView = sourceManager->createSourceView(file, nullptr); + + // create an initial input stream based on the provided buffer + preprocessor.inputStream = CreateInputStreamForSource(&preprocessor, sourceView); + + TokenList tokens = ReadAllTokens(&preprocessor); + + FinalizePreprocessor(&preprocessor); + + // debugging: build the pre-processed source back together +#if 0 + StringBuilder sb; + for (auto t : tokens) + { + if (t.flags & TokenFlag::AtStartOfLine) + { + sb << "\n"; + } + else if (t.flags & TokenFlag::AfterWhitespace) + { + sb << " "; + } + + sb << t.Content; + } + + String s = sb.ProduceString(); +#endif + + return tokens; +} + +} diff --git a/source/slang/slang-preprocessor.h b/source/slang/slang-preprocessor.h new file mode 100644 index 000000000..191adce88 --- /dev/null +++ b/source/slang/slang-preprocessor.h @@ -0,0 +1,38 @@ +// Preprocessor.h +#ifndef SLANG_PREPROCESSOR_H_INCLUDED +#define SLANG_PREPROCESSOR_H_INCLUDED + +#include "../core/slang-basic.h" +#include "../slang/slang-lexer.h" + +namespace Slang { + +class DiagnosticSink; +class Linkage; +class Module; +class ModuleDecl; + +// Callback interface for the preprocessor to use when looking +// for files in `#include` directives. +struct IncludeHandler +{ + + virtual SlangResult findFile(const String& pathToInclude, + const String& pathIncludedFrom, + PathInfo& pathInfoOut) = 0; + + virtual String simplifyPath(const String& path) = 0; +}; + +// Take a string of source code and preprocess it into a list of tokens. +TokenList preprocessSource( + SourceFile* file, + DiagnosticSink* sink, + IncludeHandler* includeHandler, + Dictionary defines, + Linkage* linkage, + Module* parentModule); + +} // namespace Slang + +#endif diff --git a/source/slang/slang-profile-defs.h b/source/slang/slang-profile-defs.h new file mode 100644 index 000000000..238621084 --- /dev/null +++ b/source/slang/slang-profile-defs.h @@ -0,0 +1,305 @@ +// + +// Define all the various language "profiles" we want to support. + +#ifndef LANGUAGE +#define LANGUAGE(TAG, NAME) /* emptry */ +#endif + +#ifndef LANGUAGE_ALIAS +#define LANGUAGE_ALIAS(TAG, NAME) /* empty */ +#endif + +#ifndef PROFILE_FAMILY +#define PROFILE_FAMILY(TAG) /* empty */ +#endif + +#ifndef PROFILE_VERSION +#define PROFILE_VERSION(TAG, FAMILY) /* empty */ +#endif + +#ifndef PROFILE_STAGE +#define PROFILE_STAGE(TAG, NAME, VAL) /* empty */ +#endif + +#ifndef PROFILE_STAGE_ALIAS +#define PROFILE_STAGE_ALIAS(TAG, NAME, VAL) /* empty */ +#endif + + +#ifndef PROFILE +#define PROFILE(TAG, NAME, STAGE, VERSION) /* empty */ +#endif + +#ifndef PROFILE_ALIAS +#define PROFILE_ALIAS(TAG, DEF, NAME) /* empty */ +#endif + +// Source and destination languages + +LANGUAGE(HLSL, hlsl) +LANGUAGE(DXBytecode, dxbc) +LANGUAGE(DXBytecodeAssembly,dxbc_asm) +LANGUAGE(DXIL, dxil) +LANGUAGE(DXILAssembly, dxil_asm) +LANGUAGE(GLSL, glsl) +LANGUAGE(GLSL_ES, glsl_es) +LANGUAGE(GLSL_VK, glsl_vk) +LANGUAGE(SPIRV, spirv) +LANGUAGE(SPIRV_GL, spirv_gl) + +LANGUAGE_ALIAS(GLSL, glsl_gl) +LANGUAGE_ALIAS(SPIRV, spirv_vk) + + +// Pipeline stages to target +PROFILE_STAGE(Vertex, vertex, SLANG_STAGE_VERTEX) +PROFILE_STAGE(Hull, hull, SLANG_STAGE_HULL) +PROFILE_STAGE(Domain, domain, SLANG_STAGE_DOMAIN) +PROFILE_STAGE(Geometry, geometry, SLANG_STAGE_GEOMETRY) +PROFILE_STAGE(Pixel, pixel, SLANG_STAGE_FRAGMENT) +PROFILE_STAGE(Compute, compute, SLANG_STAGE_COMPUTE) + +PROFILE_STAGE(RayGeneration, raygeneration, SLANG_STAGE_RAY_GENERATION) +PROFILE_STAGE(Intersection, intersection, SLANG_STAGE_INTERSECTION) +PROFILE_STAGE(AnyHit, anyhit, SLANG_STAGE_ANY_HIT) +PROFILE_STAGE(ClosestHit, closesthit, SLANG_STAGE_CLOSEST_HIT) +PROFILE_STAGE(Miss, miss, SLANG_STAGE_MISS) +PROFILE_STAGE(Callable, callable, SLANG_STAGE_CALLABLE) + + +// Note: HLSL and Direct3D convention erroneously uses the term "Pixel Shader" +// for the thing that shades *fragments*. Slang strives to treat the more correct +// term "Fragment Shader" as the primary one, but in order to be compatible with +// existing HLSL conventions, we need to treat `pixel` as the official stage +// name and `fragment` as an alias for it here, because the lower-case stage +// names are used to drive output HLSL generation. +// +PROFILE_STAGE_ALIAS(Fragment, fragment, Pixel) + +// Profile families + +PROFILE_FAMILY(DX) +PROFILE_FAMILY(GLSL) + +// Profile versions + + +PROFILE_VERSION(DX_4_0, DX) +PROFILE_VERSION(DX_4_0_Level_9_0, DX) +PROFILE_VERSION(DX_4_0_Level_9_1, DX) +PROFILE_VERSION(DX_4_0_Level_9_3, DX) +PROFILE_VERSION(DX_4_1, DX) +PROFILE_VERSION(DX_5_0, DX) +PROFILE_VERSION(DX_5_1, DX) +PROFILE_VERSION(DX_6_0, DX) +PROFILE_VERSION(DX_6_1, DX) +PROFILE_VERSION(DX_6_2, DX) +PROFILE_VERSION(DX_6_3, DX) + +PROFILE_VERSION(GLSL_110, GLSL) +PROFILE_VERSION(GLSL_120, GLSL) +PROFILE_VERSION(GLSL_130, GLSL) +PROFILE_VERSION(GLSL_140, GLSL) +PROFILE_VERSION(GLSL_150, GLSL) +PROFILE_VERSION(GLSL_330, GLSL) +PROFILE_VERSION(GLSL_400, GLSL) +PROFILE_VERSION(GLSL_410, GLSL) +PROFILE_VERSION(GLSL_420, GLSL) +PROFILE_VERSION(GLSL_430, GLSL) +PROFILE_VERSION(GLSL_440, GLSL) +PROFILE_VERSION(GLSL_450, GLSL) +PROFILE_VERSION(GLSL_460, GLSL) + + +// Specific profiles + +PROFILE(DX_Compute_4_0, cs_4_0, Compute, DX_4_0) +PROFILE(DX_Compute_4_1, cs_4_1, Compute, DX_4_1) +PROFILE(DX_Compute_5_0, cs_5_0, Compute, DX_5_0) +PROFILE(DX_Compute_5_1, cs_5_1, Compute, DX_5_1) +PROFILE(DX_Compute_6_0, cs_6_0, Compute, DX_6_0) +PROFILE(DX_Compute_6_1, cs_6_1, Compute, DX_6_1) +PROFILE(DX_Compute_6_2, cs_6_2, Compute, DX_6_2) +PROFILE(DX_Compute_6_3, cs_6_3, Compute, DX_6_3) + +PROFILE(DX_Domain_5_0, ds_5_0, Domain, DX_5_0) +PROFILE(DX_Domain_5_1, ds_5_1, Domain, DX_5_1) +PROFILE(DX_Domain_6_0, ds_6_0, Domain, DX_6_0) +PROFILE(DX_Domain_6_1, ds_6_1, Domain, DX_6_1) +PROFILE(DX_Domain_6_2, ds_6_2, Domain, DX_6_2) +PROFILE(DX_Domain_6_3, ds_6_3, Domain, DX_6_3) + +PROFILE(DX_Geometry_4_0, gs_4_0, Geometry, DX_4_0) +PROFILE(DX_Geometry_4_1, gs_4_1, Geometry, DX_4_1) +PROFILE(DX_Geometry_5_0, gs_5_0, Geometry, DX_5_0) +PROFILE(DX_Geometry_5_1, gs_5_1, Geometry, DX_5_1) +PROFILE(DX_Geometry_6_0, gs_6_0, Geometry, DX_6_0) +PROFILE(DX_Geometry_6_1, gs_6_1, Geometry, DX_6_1) +PROFILE(DX_Geometry_6_2, gs_6_2, Geometry, DX_6_2) +PROFILE(DX_Geometry_6_3, gs_6_3, Geometry, DX_6_3) + + +PROFILE(DX_Hull_5_0, hs_5_0, Hull, DX_5_0) +PROFILE(DX_Hull_5_1, hs_5_1, Hull, DX_5_1) +PROFILE(DX_Hull_6_0, hs_6_0, Hull, DX_6_0) +PROFILE(DX_Hull_6_1, hs_6_1, Hull, DX_6_1) +PROFILE(DX_Hull_6_2, hs_6_2, Hull, DX_6_2) +PROFILE(DX_Hull_6_3, hs_6_3, Hull, DX_6_3) + + +PROFILE(DX_Fragment_4_0, ps_4_0, Fragment, DX_4_0) +PROFILE(DX_Fragment_4_0_Level_9_0, ps_4_0_level_9_0, Fragment, DX_4_0_Level_9_0) +PROFILE(DX_Fragment_4_0_Level_9_1, ps_4_0_level_9_1, Fragment, DX_4_0_Level_9_1) +PROFILE(DX_Fragment_4_0_Level_9_3, ps_4_0_level_9_3, Fragment, DX_4_0_Level_9_3) +PROFILE(DX_Fragment_4_1, ps_4_1, Fragment, DX_4_1) +PROFILE(DX_Fragment_5_0, ps_5_0, Fragment, DX_5_0) +PROFILE(DX_Fragment_5_1, ps_5_1, Fragment, DX_5_1) +PROFILE(DX_Fragment_6_0, ps_6_0, Fragment, DX_6_0) +PROFILE(DX_Fragment_6_1, ps_6_1, Fragment, DX_6_1) +PROFILE(DX_Fragment_6_2, ps_6_2, Fragment, DX_6_2) +PROFILE(DX_Fragment_6_3, ps_6_3, Fragment, DX_6_3) + + +PROFILE(DX_Vertex_4_0, vs_4_0, Vertex, DX_4_0) +PROFILE(DX_Vertex_4_0_Level_9_0, vs_4_0_level_9_0, Vertex, DX_4_0_Level_9_0) +PROFILE(DX_Vertex_4_0_Level_9_1, vs_4_0_level_9_1, Vertex, DX_4_0_Level_9_1) +PROFILE(DX_Vertex_4_0_Level_9_3, vs_4_0_level_9_3, Vertex, DX_4_0_Level_9_3) +PROFILE(DX_Vertex_4_1, vs_4_1, Vertex, DX_4_1) +PROFILE(DX_Vertex_5_0, vs_5_0, Vertex, DX_5_0) +PROFILE(DX_Vertex_5_1, vs_5_1, Vertex, DX_5_1) +PROFILE(DX_Vertex_6_0, vs_6_0, Vertex, DX_6_0) +PROFILE(DX_Vertex_6_1, vs_6_1, Vertex, DX_6_1) +PROFILE(DX_Vertex_6_2, vs_6_2, Vertex, DX_6_2) +PROFILE(DX_Vertex_6_3, vs_6_3, Vertex, DX_6_3) + +// TODO: consider making `lib_*_*` alias these... +PROFILE(DX_None_4_0, sm_4_0, Unknown, DX_4_0) +PROFILE(DX_None_4_0_Level_9_0, sm_4_0_level_9_0, Unknown, DX_4_0_Level_9_0) +PROFILE(DX_None_4_0_Level_9_1, sm_4_0_level_9_1, Unknown, DX_4_0_Level_9_1) +PROFILE(DX_None_4_0_Level_9_3, sm_4_0_level_9_3, Unknown, DX_4_0_Level_9_3) +PROFILE(DX_None_4_1, sm_4_1, Unknown, DX_4_1) +PROFILE(DX_None_5_0, sm_5_0, Unknown, DX_5_0) +PROFILE(DX_None_5_1, sm_5_1, Unknown, DX_5_1) +PROFILE(DX_None_6_0, sm_6_0, Unknown, DX_6_0) + +// From Shader Model 6.1 on, the dxc compiler recognizes a `lib` profile +// that can be used to compile multiple entry points. We want that +// `lib` name to be the default for how we render these profiles when +// invoking downstream tools, so we use that instead of the `sm_` +// prefix, and then re-introduce the `sm_` variants as aliases. +// +// TODO: We may eventually want a split between how Slang represents +// profiles and their names to users, vs. how it renders them when +// invoking downstream tools, so that the profile name in any +// error messages can be consistent with our `sm_*` naems above +// +PROFILE(DX_Lib_6_1, lib_6_1, Unknown, DX_6_1) +PROFILE(DX_Lib_6_2, lib_6_2, Unknown, DX_6_2) +PROFILE(DX_Lib_6_3, lib_6_3, Unknown, DX_6_3) + +PROFILE_ALIAS(DX_None_6_1, DX_Lib_6_1, sm_6_1) +PROFILE_ALIAS(DX_None_6_2, DX_Lib_6_2, sm_6_2) +PROFILE_ALIAS(DX_None_6_3, DX_Lib_6_3, sm_6_3) + + +// Define all the GLSL profiles + +PROFILE(GLSL_None_110, glsl_110, Unknown, GLSL_110) +PROFILE(GLSL_None_120, glsl_120, Unknown, GLSL_120) +PROFILE(GLSL_None_130, glsl_130, Unknown, GLSL_130) +PROFILE(GLSL_None_140, glsl_140, Unknown, GLSL_140) +PROFILE(GLSL_None_150, glsl_150, Unknown, GLSL_150) +PROFILE(GLSL_None_330, glsl_330, Unknown, GLSL_330) +PROFILE(GLSL_None_400, glsl_400, Unknown, GLSL_400) +PROFILE(GLSL_None_410, glsl_410, Unknown, GLSL_410) +PROFILE(GLSL_None_420, glsl_420, Unknown, GLSL_420) +PROFILE(GLSL_None_430, glsl_430, Unknown, GLSL_430) +PROFILE(GLSL_None_440, glsl_440, Unknown, GLSL_440) +PROFILE(GLSL_None_450, glsl_450, Unknown, GLSL_450) +PROFILE(GLSL_None_460, glsl_460, Unknown, GLSL_460) + +#define P(UPPER, LOWER, VERSION) \ +PROFILE(GLSL_##UPPER##_##VERSION, glsl_##LOWER##_##VERSION, UPPER, GLSL_##VERSION) + +P(Vertex, vertex, 110) +P(Vertex, vertex, 120) +P(Vertex, vertex, 130) +P(Vertex, vertex, 140) +P(Vertex, vertex, 150) +P(Vertex, vertex, 330) +P(Vertex, vertex, 400) +P(Vertex, vertex, 410) +P(Vertex, vertex, 420) +P(Vertex, vertex, 430) +P(Vertex, vertex, 440) +P(Vertex, vertex, 450) + +P(Fragment, fragment, 110) +P(Fragment, fragment, 120) +P(Fragment, fragment, 130) +P(Fragment, fragment, 140) +P(Fragment, fragment, 150) +P(Fragment, fragment, 330) +P(Fragment, fragment, 400) +P(Fragment, fragment, 410) +P(Fragment, fragment, 420) +P(Fragment, fragment, 430) +P(Fragment, fragment, 440) +P(Fragment, fragment, 450) + +P(Geometry, geometry, 150) +P(Geometry, geometry, 330) +P(Geometry, geometry, 400) +P(Geometry, geometry, 410) +P(Geometry, geometry, 420) +P(Geometry, geometry, 430) +P(Geometry, geometry, 440) +P(Geometry, geometry, 450) + +P(Compute, compute, 430) +P(Compute, compute, 440) +P(Compute, compute, 450) + +#undef P +#define P(UPPER, LOWER, STAGE, VERSION) \ +PROFILE(GLSL_##UPPER##_##VERSION, glsl_##LOWER##_##VERSION, STAGE, GLSL_##VERSION) + +P(TessControl, tess_control, Hull, 400) +P(TessControl, tess_control, Hull, 410) +P(TessControl, tess_control, Hull, 420) +P(TessControl, tess_control, Hull, 430) +P(TessControl, tess_control, Hull, 440) +P(TessControl, tess_control, Hull, 450) + +P(TessEval, tess_eval, Domain, 400) +P(TessEval, tess_eval, Domain, 410) +P(TessEval, tess_eval, Domain, 420) +P(TessEval, tess_eval, Domain, 430) +P(TessEval, tess_eval, Domain, 440) +P(TessEval, tess_eval, Domain, 450) + +#undef P + +// Define a default profile for each GLSL stage that just +// uses the latest language version we know of + +PROFILE_ALIAS(GLSL_Vertex, GLSL_Vertex_450, glsl_vertex) +PROFILE_ALIAS(GLSL_Fragment, GLSL_Fragment_450, glsl_fragment) +PROFILE_ALIAS(GLSL_Geometry, GLSL_Geometry_450, glsl_geometry) +PROFILE_ALIAS(GLSL_TessControl, GLSL_TessControl_450, glsl_tess_control) +PROFILE_ALIAS(GLSL_TessEval, GLSL_TessEval_450, glsl_tess_eval) +PROFILE_ALIAS(GLSL_Compute, GLSL_Compute_450, glsl_compute) + +// TODO: define a profile for each GLSL *version* that we can +// use as a catch-all when the stage can be inferred from +// something else + +#undef LANGUAGE +#undef LANGUAGE_ALIAS +#undef PROFILE_FAMILY +#undef PROFILE_VERSION +#undef PROFILE_STAGE +#undef PROFILE_STAGE_ALIAS +#undef PROFILE +#undef PROFILE_ALIAS diff --git a/source/slang/slang-profile.cpp b/source/slang/slang-profile.cpp new file mode 100644 index 000000000..204b467f9 --- /dev/null +++ b/source/slang/slang-profile.cpp @@ -0,0 +1,34 @@ +// slang-profile.cpp +#include "slang-profile.h" + +namespace Slang { + +ProfileFamily getProfileFamily(ProfileVersion version) +{ + switch( version ) + { + default: return ProfileFamily::Unknown; + +#define PROFILE_VERSION(TAG, FAMILY) case ProfileVersion::TAG: return ProfileFamily::FAMILY; +#include "slang-profile-defs.h" + } +} + +const char* getStageName(Stage stage) +{ + switch(stage) + { +#define PROFILE_STAGE(ID, NAME, ENUM) \ + case Stage::ID: return #NAME; + +#include "slang-profile-defs.h" + + default: + return nullptr; + } + +} + + + +} diff --git a/source/slang/slang-profile.h b/source/slang/slang-profile.h new file mode 100644 index 000000000..de471282d --- /dev/null +++ b/source/slang/slang-profile.h @@ -0,0 +1,106 @@ +#ifndef SLANG_PROFILE_H_INCLUDED +#define SLANG_PROFILE_H_INCLUDED + +#include "../core/slang-basic.h" +#include "../../slang.h" + +namespace Slang +{ + // Flavors of translation unit + enum class SourceLanguage : SlangSourceLanguage + { + Unknown = SLANG_SOURCE_LANGUAGE_UNKNOWN, // should not occur + Slang = SLANG_SOURCE_LANGUAGE_SLANG, + HLSL = SLANG_SOURCE_LANGUAGE_HLSL, + GLSL = SLANG_SOURCE_LANGUAGE_GLSL, + }; + + // TODO(tfoley): This should merge with the above... + enum class Language + { + Unknown, +#define LANGUAGE(TAG, NAME) TAG, +#include "slang-profile-defs.h" + }; + + enum class ProfileFamily + { + Unknown, +#define PROFILE_FAMILY(TAG) TAG, +#include "slang-profile-defs.h" + }; + + enum class ProfileVersion + { + Unknown, +#define PROFILE_VERSION(TAG, FAMILY) TAG, +#include "slang-profile-defs.h" + }; + + enum class Stage : SlangStage + { + Unknown = SLANG_STAGE_NONE, +#define PROFILE_STAGE(TAG, NAME, VAL) TAG = VAL, +#define PROFILE_STAGE_ALIAS(TAG, NAME, VAL) TAG = VAL, +#include "slang-profile-defs.h" + }; + + const char* getStageName(Stage stage); + + ProfileFamily getProfileFamily(ProfileVersion version); + + struct Profile + { + typedef uint32_t RawVal; + enum RawEnum : RawVal + { + Unknown, + +#define PROFILE(TAG, NAME, STAGE, VERSION) TAG = (uint32_t(ProfileVersion::VERSION) << 16) | uint32_t(Stage::STAGE), +#define PROFILE_ALIAS(TAG, DEF, NAME) TAG = DEF, +#include "slang-profile-defs.h" + }; + + Profile() {} + Profile(RawEnum raw) + : raw(raw) + {} + explicit Profile(RawVal raw) + : raw(raw) + {} + explicit Profile(Stage stage) + { + setStage(stage); + } + explicit Profile(ProfileVersion version) + { + setVersion(version); + } + + bool operator==(Profile const& other) const { return raw == other.raw; } + bool operator!=(Profile const& other) const { return raw != other.raw; } + + Stage GetStage() const { return Stage(uint32_t(raw) & 0xFFFF); } + void setStage(Stage stage) + { + raw = (raw & ~0xFFFF) | uint32_t(stage); + } + + ProfileVersion GetVersion() const { return ProfileVersion((uint32_t(raw) >> 16) & 0xFFFF); } + void setVersion(ProfileVersion version) + { + raw = (raw & 0x0000FFFF) | (uint32_t(version) << 16); + } + + ProfileFamily getFamily() const { return getProfileFamily(GetVersion()); } + + static Profile LookUp(char const* name); + char const* getName(); + + RawVal raw = Unknown; + }; + + Stage findStageByName(String const& name); +} + +#endif diff --git a/source/slang/slang-reflection.cpp b/source/slang/slang-reflection.cpp new file mode 100644 index 000000000..c5428cdeb --- /dev/null +++ b/source/slang/slang-reflection.cpp @@ -0,0 +1,1451 @@ +// slang-reflection.cpp +#include "slang-reflection.h" + +#include "slang-compiler.h" +#include "slang-type-layout.h" +#include "slang-syntax.h" +#include + +// Don't signal errors for stuff we don't implement here, +// and instead just try to return things defensively +// +// Slang developers can switch this when debugging. +#define SLANG_REFLECTION_UNEXPECTED() do {} while(0) + +// Implementation to back public-facing reflection API + +using namespace Slang; + + +// Conversion routines to help with strongly-typed reflection API +static inline Session* convert(SlangSession* session) +{ + return (Session*)session; +} + +static inline UserDefinedAttribute* convert(SlangReflectionUserAttribute* attrib) +{ + return (UserDefinedAttribute*)attrib; +} +static inline SlangReflectionUserAttribute* convert(UserDefinedAttribute* attrib) +{ + return (SlangReflectionUserAttribute*)attrib; +} +static inline Type* convert(SlangReflectionType* type) +{ + return (Type*) type; +} + +static inline SlangReflectionType* convert(Type* type) +{ + return (SlangReflectionType*) type; +} + +static inline TypeLayout* convert(SlangReflectionTypeLayout* type) +{ + return (TypeLayout*) type; +} + +static inline SlangReflectionTypeLayout* convert(TypeLayout* type) +{ + return (SlangReflectionTypeLayout*) type; +} + +static inline GenericParamLayout* convert(SlangReflectionTypeParameter * typeParam) +{ + return (GenericParamLayout*)typeParam; +} + +static inline VarDeclBase* convert(SlangReflectionVariable* var) +{ + return (VarDeclBase*) var; +} + +static inline SlangReflectionVariable* convert(VarDeclBase* var) +{ + return (SlangReflectionVariable*) var; +} + +static inline VarLayout* convert(SlangReflectionVariableLayout* var) +{ + return (VarLayout*) var; +} + +static inline SlangReflectionVariableLayout* convert(VarLayout* var) +{ + return (SlangReflectionVariableLayout*) var; +} + +static inline EntryPointLayout* convert(SlangReflectionEntryPoint* entryPoint) +{ + return (EntryPointLayout*) entryPoint; +} + +static inline SlangReflectionEntryPoint* convert(EntryPointLayout* entryPoint) +{ + return (SlangReflectionEntryPoint*) entryPoint; +} + + +static inline ProgramLayout* convert(SlangReflection* program) +{ + return (ProgramLayout*) program; +} + +static inline SlangReflection* convert(ProgramLayout* program) +{ + return (SlangReflection*) program; +} + +// user attaribute + +unsigned int getUserAttributeCount(Decl* decl) +{ + unsigned int count = 0; + for (auto x : decl->GetModifiersOfType()) + { + SLANG_UNUSED(x); + count++; + } + return count; +} + +SlangReflectionUserAttribute* findUserAttributeByName(Session* session, Decl* decl, const char* name) +{ + auto nameObj = session->tryGetNameObj(name); + for (auto x : decl->GetModifiersOfType()) + { + if (x->name == nameObj) + return (SlangReflectionUserAttribute*)(x); + } + return nullptr; +} + +SlangReflectionUserAttribute* getUserAttributeByIndex(Decl* decl, unsigned int index) +{ + unsigned int id = 0; + for (auto x : decl->GetModifiersOfType()) + { + if (id == index) + return convert(x); + id++; + } + return nullptr; +} + +SLANG_API char const* spReflectionUserAttribute_GetName(SlangReflectionUserAttribute* attrib) +{ + auto userAttr = convert(attrib); + if (!userAttr) return nullptr; + return userAttr->getName()->text.getBuffer(); +} +SLANG_API unsigned int spReflectionUserAttribute_GetArgumentCount(SlangReflectionUserAttribute* attrib) +{ + auto userAttr = convert(attrib); + if (!userAttr) return 0; + return (unsigned int)userAttr->args.getCount(); +} +SlangReflectionType* spReflectionUserAttribute_GetArgumentType(SlangReflectionUserAttribute* attrib, unsigned int index) +{ + auto userAttr = convert(attrib); + if (!userAttr) return nullptr; + return convert(userAttr->args[index]->type.type.Ptr()); +} +SLANG_API SlangResult spReflectionUserAttribute_GetArgumentValueInt(SlangReflectionUserAttribute* attrib, unsigned int index, int * rs) +{ + auto userAttr = convert(attrib); + if (!userAttr) return SLANG_ERROR_INVALID_PARAMETER; + if (index >= (unsigned int)userAttr->args.getCount()) return SLANG_ERROR_INVALID_PARAMETER; + RefPtr val; + if (userAttr->intArgVals.TryGetValue(index, val)) + { + *rs = (int)as(val)->value; + return 0; + } + return SLANG_ERROR_INVALID_PARAMETER; +} +SLANG_API SlangResult spReflectionUserAttribute_GetArgumentValueFloat(SlangReflectionUserAttribute* attrib, unsigned int index, float * rs) +{ + auto userAttr = convert(attrib); + if (!userAttr) return SLANG_ERROR_INVALID_PARAMETER; + if (index >= (unsigned int)userAttr->args.getCount()) return SLANG_ERROR_INVALID_PARAMETER; + if (auto cexpr = as(userAttr->args[index])) + { + *rs = (float)cexpr->value; + return 0; + } + return SLANG_ERROR_INVALID_PARAMETER; +} +SLANG_API const char* spReflectionUserAttribute_GetArgumentValueString(SlangReflectionUserAttribute* attrib, unsigned int index, size_t* bufLen) +{ + auto userAttr = convert(attrib); + if (!userAttr) return nullptr; + if (index >= (unsigned int)userAttr->args.getCount()) return nullptr; + if (auto cexpr = as(userAttr->args[index])) + { + if (bufLen) + *bufLen = cexpr->token.Content.size(); + return cexpr->token.Content.begin(); + } + return nullptr; +} + + + +// type Reflection + + +SLANG_API SlangTypeKind spReflectionType_GetKind(SlangReflectionType* inType) +{ + auto type = convert(inType); + if(!type) return SLANG_TYPE_KIND_NONE; + + // TODO(tfoley: Don't emit the same type more than once... + + if (auto basicType = as(type)) + { + return SLANG_TYPE_KIND_SCALAR; + } + else if (auto vectorType = as(type)) + { + return SLANG_TYPE_KIND_VECTOR; + } + else if (auto matrixType = as(type)) + { + return SLANG_TYPE_KIND_MATRIX; + } + else if (auto parameterBlockType = as(type)) + { + return SLANG_TYPE_KIND_PARAMETER_BLOCK; + } + else if (auto constantBufferType = as(type)) + { + return SLANG_TYPE_KIND_CONSTANT_BUFFER; + } + else if( auto streamOutputType = as(type) ) + { + return SLANG_TYPE_KIND_OUTPUT_STREAM; + } + else if (as(type)) + { + return SLANG_TYPE_KIND_TEXTURE_BUFFER; + } + else if (as(type)) + { + return SLANG_TYPE_KIND_SHADER_STORAGE_BUFFER; + } + else if (auto samplerStateType = as(type)) + { + return SLANG_TYPE_KIND_SAMPLER_STATE; + } + else if (auto textureType = as(type)) + { + return SLANG_TYPE_KIND_RESOURCE; + } + // TODO: need a better way to handle this stuff... +#define CASE(TYPE) \ + else if(as(type)) do { \ + return SLANG_TYPE_KIND_RESOURCE; \ + } while(0) + + CASE(HLSLStructuredBufferType); + CASE(HLSLRWStructuredBufferType); + CASE(HLSLRasterizerOrderedStructuredBufferType); + CASE(HLSLAppendStructuredBufferType); + CASE(HLSLConsumeStructuredBufferType); + CASE(HLSLByteAddressBufferType); + CASE(HLSLRWByteAddressBufferType); + CASE(HLSLRasterizerOrderedByteAddressBufferType); + CASE(UntypedBufferResourceType); +#undef CASE + + else if (auto arrayType = as(type)) + { + return SLANG_TYPE_KIND_ARRAY; + } + else if( auto declRefType = as(type) ) + { + const auto& declRef = declRefType->declRef; + if(declRef.is() ) + { + return SLANG_TYPE_KIND_STRUCT; + } + else if (declRef.is()) + { + return SLANG_TYPE_KIND_GENERIC_TYPE_PARAMETER; + } + else if (declRef.is()) + { + return SLANG_TYPE_KIND_INTERFACE; + } + } + else if( auto specializedType = as(type) ) + { + return SLANG_TYPE_KIND_SPECIALIZED; + } + else if (auto errorType = as(type)) + { + // This means we saw a type we didn't understand in the user's code + return SLANG_TYPE_KIND_NONE; + } + + SLANG_REFLECTION_UNEXPECTED(); + return SLANG_TYPE_KIND_NONE; +} + +SLANG_API unsigned int spReflectionType_GetFieldCount(SlangReflectionType* inType) +{ + auto type = convert(inType); + if(!type) return 0; + + // TODO: maybe filter based on kind + + if(auto declRefType = as(type)) + { + auto declRef = declRefType->declRef; + if( auto structDeclRef = declRef.as()) + { + return GetFields(structDeclRef).Count(); + } + } + + return 0; +} + +SLANG_API SlangReflectionVariable* spReflectionType_GetFieldByIndex(SlangReflectionType* inType, unsigned index) +{ + auto type = convert(inType); + if(!type) return nullptr; + + // TODO: maybe filter based on kind + + if(auto declRefType = as(type)) + { + auto declRef = declRefType->declRef; + if( auto structDeclRef = declRef.as()) + { + auto fieldDeclRef = GetFields(structDeclRef).ToArray()[index]; + return (SlangReflectionVariable*) fieldDeclRef.getDecl(); + } + } + + return nullptr; +} + +SLANG_API size_t spReflectionType_GetElementCount(SlangReflectionType* inType) +{ + auto type = convert(inType); + if(!type) return 0; + + if(auto arrayType = as(type)) + { + return arrayType->ArrayLength ? (size_t) GetIntVal(arrayType->ArrayLength) : 0; + } + else if( auto vectorType = as(type)) + { + return (size_t) GetIntVal(vectorType->elementCount); + } + + return 0; +} + +SLANG_API SlangReflectionType* spReflectionType_GetElementType(SlangReflectionType* inType) +{ + auto type = convert(inType); + if(!type) return nullptr; + + if(auto arrayType = as(type)) + { + return (SlangReflectionType*) arrayType->baseType.Ptr(); + } + else if( auto constantBufferType = as(type)) + { + return convert(constantBufferType->elementType.Ptr()); + } + else if( auto vectorType = as(type)) + { + return convert(vectorType->elementType.Ptr()); + } + else if( auto matrixType = as(type)) + { + return convert(matrixType->getElementType()); + } + + return nullptr; +} + +SLANG_API unsigned int spReflectionType_GetRowCount(SlangReflectionType* inType) +{ + auto type = convert(inType); + if(!type) return 0; + + if(auto matrixType = as(type)) + { + return (unsigned int) GetIntVal(matrixType->getRowCount()); + } + else if(auto vectorType = as(type)) + { + return 1; + } + else if( auto basicType = as(type) ) + { + return 1; + } + + return 0; +} + +SLANG_API unsigned int spReflectionType_GetColumnCount(SlangReflectionType* inType) +{ + auto type = convert(inType); + if(!type) return 0; + + if(auto matrixType = as(type)) + { + return (unsigned int) GetIntVal(matrixType->getColumnCount()); + } + else if(auto vectorType = as(type)) + { + return (unsigned int) GetIntVal(vectorType->elementCount); + } + else if( auto basicType = as(type) ) + { + return 1; + } + + return 0; +} + +SLANG_API SlangScalarType spReflectionType_GetScalarType(SlangReflectionType* inType) +{ + auto type = convert(inType); + if(!type) return 0; + + if(auto matrixType = as(type)) + { + type = matrixType->getElementType(); + } + else if(auto vectorType = as(type)) + { + type = vectorType->elementType.Ptr(); + } + + if(auto basicType = as(type)) + { + switch (basicType->baseType) + { +#define CASE(BASE, TAG) \ + case BaseType::BASE: return SLANG_SCALAR_TYPE_##TAG + + CASE(Void, VOID); + CASE(Bool, BOOL); + CASE(Int8, INT8); + CASE(Int16, INT16); + CASE(Int, INT32); + CASE(Int64, INT64); + CASE(UInt8, UINT8); + CASE(UInt16, UINT16); + CASE(UInt, UINT32); + CASE(UInt64, UINT64); + CASE(Half, FLOAT16); + CASE(Float, FLOAT32); + CASE(Double, FLOAT64); + +#undef CASE + + default: + SLANG_REFLECTION_UNEXPECTED(); + return SLANG_SCALAR_TYPE_NONE; + break; + } + } + + return SLANG_SCALAR_TYPE_NONE; +} + +SLANG_API unsigned int spReflectionType_GetUserAttributeCount(SlangReflectionType* inType) +{ + auto type = convert(inType); + if (!type) return 0; + if (auto declRefType = as(type)) + { + return getUserAttributeCount(declRefType->declRef.getDecl()); + } + return 0; +} +SLANG_API SlangReflectionUserAttribute* spReflectionType_GetUserAttribute(SlangReflectionType* inType, unsigned int index) +{ + auto type = convert(inType); + if (!type) return 0; + if (auto declRefType = as(type)) + { + return getUserAttributeByIndex(declRefType->declRef.getDecl(), index); + } + return 0; +} +SLANG_API SlangReflectionUserAttribute* spReflectionType_FindUserAttributeByName(SlangReflectionType* inType, char const* name) +{ + auto type = convert(inType); + if (!type) return 0; + if (auto declRefType = as(type)) + { + return findUserAttributeByName(declRefType->getSession(), declRefType->declRef.getDecl(), name); + } + return 0; +} + +SLANG_API SlangResourceShape spReflectionType_GetResourceShape(SlangReflectionType* inType) +{ + auto type = convert(inType); + if(!type) return 0; + + while(auto arrayType = as(type)) + { + type = arrayType->baseType.Ptr(); + } + + if(auto textureType = as(type)) + { + return textureType->getShape(); + } + + // TODO: need a better way to handle this stuff... +#define CASE(TYPE, SHAPE, ACCESS) \ + else if(as(type)) do { \ + return SHAPE; \ + } while(0) + + CASE(HLSLStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_READ); + CASE(HLSLRWStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE); + CASE(HLSLRasterizerOrderedStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_RASTER_ORDERED); + CASE(HLSLAppendStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_APPEND); + CASE(HLSLConsumeStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_CONSUME); + CASE(HLSLByteAddressBufferType, SLANG_BYTE_ADDRESS_BUFFER, SLANG_RESOURCE_ACCESS_READ); + CASE(HLSLRWByteAddressBufferType, SLANG_BYTE_ADDRESS_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE); + CASE(HLSLRasterizerOrderedByteAddressBufferType, SLANG_BYTE_ADDRESS_BUFFER, SLANG_RESOURCE_ACCESS_RASTER_ORDERED); + CASE(RaytracingAccelerationStructureType, SLANG_ACCELERATION_STRUCTURE, SLANG_RESOURCE_ACCESS_READ); + CASE(UntypedBufferResourceType, SLANG_BYTE_ADDRESS_BUFFER, SLANG_RESOURCE_ACCESS_READ); +#undef CASE + + return SLANG_RESOURCE_NONE; +} + +SLANG_API SlangResourceAccess spReflectionType_GetResourceAccess(SlangReflectionType* inType) +{ + auto type = convert(inType); + if(!type) return 0; + + while(auto arrayType = as(type)) + { + type = arrayType->baseType.Ptr(); + } + + if(auto textureType = as(type)) + { + return textureType->getAccess(); + } + + // TODO: need a better way to handle this stuff... +#define CASE(TYPE, SHAPE, ACCESS) \ + else if(as(type)) do { \ + return ACCESS; \ + } while(0) + + CASE(HLSLStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_READ); + CASE(HLSLRWStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE); + CASE(HLSLRasterizerOrderedStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_RASTER_ORDERED); + CASE(HLSLAppendStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_APPEND); + CASE(HLSLConsumeStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_CONSUME); + CASE(HLSLByteAddressBufferType, SLANG_BYTE_ADDRESS_BUFFER, SLANG_RESOURCE_ACCESS_READ); + CASE(HLSLRWByteAddressBufferType, SLANG_BYTE_ADDRESS_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE); + CASE(HLSLRasterizerOrderedByteAddressBufferType, SLANG_BYTE_ADDRESS_BUFFER, SLANG_RESOURCE_ACCESS_RASTER_ORDERED); + CASE(UntypedBufferResourceType, SLANG_BYTE_ADDRESS_BUFFER, SLANG_RESOURCE_ACCESS_READ); + + // This isn't entirely accurate, but I can live with it for now + CASE(GLSLShaderStorageBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE); +#undef CASE + + return SLANG_RESOURCE_ACCESS_NONE; +} + +SLANG_API char const* spReflectionType_GetName(SlangReflectionType* inType) +{ + auto type = convert(inType); + + if( auto declRefType = as(type) ) + { + auto declRef = declRefType->declRef; + + // Don't return a name for auto-generated anonymous types + // that represent `cbuffer` members, etc. + auto decl = declRef.getDecl(); + if(decl->HasModifier()) + return nullptr; + + return getText(declRef.GetName()).begin(); + } + + return nullptr; +} + +SLANG_API SlangReflectionType * spReflection_FindTypeByName(SlangReflection * reflection, char const * name) +{ + auto programLayout = convert(reflection); + auto program = programLayout->getProgram(); + + // TODO: We should extend this API to support getting error messages + // when type lookup fails. + // + Slang::DiagnosticSink sink; + + sink.sourceManager = programLayout->getTargetReq()->getLinkage()->getSourceManager();; + RefPtr result = program->getTypeFromString(name, &sink); + return (SlangReflectionType*)result.Ptr(); +} + +SLANG_API SlangReflectionTypeLayout* spReflection_GetTypeLayout( + SlangReflection* reflection, + SlangReflectionType* inType, + SlangLayoutRules /*rules*/) +{ + auto context = convert(reflection); + auto type = convert(inType); + auto targetReq = context->getTargetReq(); + auto layoutContext = getInitialLayoutContextForTarget(targetReq, context); + RefPtr result; + if (targetReq->getTypeLayouts().TryGetValue(type, result)) + return (SlangReflectionTypeLayout*)result.Ptr(); + result = createTypeLayout(layoutContext, type); + targetReq->getTypeLayouts()[type] = result; + return (SlangReflectionTypeLayout*)result.Ptr(); +} + +SLANG_API SlangReflectionType* spReflectionType_GetResourceResultType(SlangReflectionType* inType) +{ + auto type = convert(inType); + if(!type) return nullptr; + + while(auto arrayType = as(type)) + { + type = arrayType->baseType.Ptr(); + } + + if (auto textureType = as(type)) + { + return convert(textureType->elementType.Ptr()); + } + + // TODO: need a better way to handle this stuff... +#define CASE(TYPE, SHAPE, ACCESS) \ + else if(as(type)) do { \ + return convert(as(type)->elementType.Ptr()); \ + } while(0) + + // TODO: structured buffer needs to expose type layout! + + CASE(HLSLStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_READ); + CASE(HLSLRWStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE); + CASE(HLSLRasterizerOrderedStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_RASTER_ORDERED); + CASE(HLSLAppendStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_APPEND); + CASE(HLSLConsumeStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_CONSUME); +#undef CASE + + return nullptr; +} + +// type Layout Reflection + +SLANG_API SlangReflectionType* spReflectionTypeLayout_GetType(SlangReflectionTypeLayout* inTypeLayout) +{ + auto typeLayout = convert(inTypeLayout); + if(!typeLayout) return nullptr; + + return (SlangReflectionType*) typeLayout->type.Ptr(); +} + +namespace +{ + static size_t getReflectionSize(LayoutSize size) + { + if(size.isFinite()) + return size.getFiniteValue(); + + return SLANG_UNBOUNDED_SIZE; + } +} + +SLANG_API size_t spReflectionTypeLayout_GetSize(SlangReflectionTypeLayout* inTypeLayout, SlangParameterCategory category) +{ + auto typeLayout = convert(inTypeLayout); + if(!typeLayout) return 0; + + auto info = typeLayout->FindResourceInfo(LayoutResourceKind(category)); + if(!info) return 0; + + return getReflectionSize(info->count); +} + +SLANG_API SlangReflectionVariableLayout* spReflectionTypeLayout_GetFieldByIndex(SlangReflectionTypeLayout* inTypeLayout, unsigned index) +{ + auto typeLayout = convert(inTypeLayout); + if(!typeLayout) return nullptr; + + if(auto structTypeLayout = as(typeLayout)) + { + return (SlangReflectionVariableLayout*) structTypeLayout->fields[index].Ptr(); + } + + return nullptr; +} + +SLANG_API size_t spReflectionTypeLayout_GetElementStride(SlangReflectionTypeLayout* inTypeLayout, SlangParameterCategory category) +{ + auto typeLayout = convert(inTypeLayout); + if(!typeLayout) return 0; + + if( auto arrayTypeLayout = as(typeLayout)) + { + switch (category) + { + // We store the stride explicitly for the uniform case + case SLANG_PARAMETER_CATEGORY_UNIFORM: + return arrayTypeLayout->uniformStride; + + // For most other cases (resource registers), the "stride" + // of an array is simply the number of resources (if any) + // consumed by its element type. + default: + { + auto elementTypeLayout = arrayTypeLayout->elementTypeLayout; + auto info = elementTypeLayout->FindResourceInfo(LayoutResourceKind(category)); + if(!info) return 0; + return getReflectionSize(info->count); + } + + // An important special case, though, is Vulkan descriptor-table slots, + // where an entire array will use a single `binding`, so that the + // effective stride is zero: + case SLANG_PARAMETER_CATEGORY_DESCRIPTOR_TABLE_SLOT: + return 0; + } + } + + return 0; +} + +SLANG_API SlangReflectionTypeLayout* spReflectionTypeLayout_GetElementTypeLayout(SlangReflectionTypeLayout* inTypeLayout) +{ + auto typeLayout = convert(inTypeLayout); + if(!typeLayout) return nullptr; + + if( auto arrayTypeLayout = as(typeLayout)) + { + return (SlangReflectionTypeLayout*) arrayTypeLayout->elementTypeLayout.Ptr(); + } + else if( auto constantBufferTypeLayout = as(typeLayout)) + { + return convert(constantBufferTypeLayout->offsetElementTypeLayout.Ptr()); + } + else if( auto structuredBufferTypeLayout = as(typeLayout)) + { + return convert(structuredBufferTypeLayout->elementTypeLayout.Ptr()); + } + else if( auto specializedTypeLayout = as(typeLayout) ) + { + return convert(specializedTypeLayout->baseTypeLayout.Ptr()); + } + + return nullptr; +} + +SLANG_API SlangReflectionVariableLayout* spReflectionTypeLayout_GetElementVarLayout(SlangReflectionTypeLayout* inTypeLayout) +{ + auto typeLayout = convert(inTypeLayout); + if(!typeLayout) return nullptr; + + if( auto constantBufferTypeLayout = as(typeLayout)) + { + return convert(constantBufferTypeLayout->elementVarLayout.Ptr()); + } + + return nullptr; +} + +static SlangParameterCategory getParameterCategory( + LayoutResourceKind kind) +{ + return SlangParameterCategory(kind); +} + +static SlangParameterCategory getParameterCategory( + TypeLayout* typeLayout) +{ + auto resourceInfoCount = typeLayout->resourceInfos.getCount(); + if(resourceInfoCount == 1) + { + return getParameterCategory(typeLayout->resourceInfos[0].kind); + } + else if(resourceInfoCount == 0) + { + // TODO: can this ever happen? + return SLANG_PARAMETER_CATEGORY_NONE; + } + return SLANG_PARAMETER_CATEGORY_MIXED; +} + +static TypeLayout* maybeGetContainerLayout(TypeLayout* typeLayout) +{ + if (auto parameterGroupTypeLayout = as(typeLayout)) + { + auto containerTypeLayout = parameterGroupTypeLayout->containerVarLayout->typeLayout; + if (containerTypeLayout->resourceInfos.getCount() != 0) + { + return containerTypeLayout; + } + } + + return typeLayout; +} + +SLANG_API SlangParameterCategory spReflectionTypeLayout_GetParameterCategory(SlangReflectionTypeLayout* inTypeLayout) +{ + auto typeLayout = convert(inTypeLayout); + if(!typeLayout) return SLANG_PARAMETER_CATEGORY_NONE; + + typeLayout = maybeGetContainerLayout(typeLayout); + + return getParameterCategory(typeLayout); +} + +SLANG_API unsigned spReflectionTypeLayout_GetCategoryCount(SlangReflectionTypeLayout* inTypeLayout) +{ + auto typeLayout = convert(inTypeLayout); + if(!typeLayout) return 0; + + typeLayout = maybeGetContainerLayout(typeLayout); + + return (unsigned) typeLayout->resourceInfos.getCount(); +} + +SLANG_API SlangParameterCategory spReflectionTypeLayout_GetCategoryByIndex(SlangReflectionTypeLayout* inTypeLayout, unsigned index) +{ + auto typeLayout = convert(inTypeLayout); + if(!typeLayout) return SLANG_PARAMETER_CATEGORY_NONE; + + typeLayout = maybeGetContainerLayout(typeLayout); + + return typeLayout->resourceInfos[index].kind; +} + +SLANG_API SlangMatrixLayoutMode spReflectionTypeLayout_GetMatrixLayoutMode(SlangReflectionTypeLayout* inTypeLayout) +{ + auto typeLayout = convert(inTypeLayout); + if(!typeLayout) return SLANG_MATRIX_LAYOUT_MODE_UNKNOWN; + + if( auto matrixLayout = as(typeLayout) ) + { + return matrixLayout->mode; + } + else + { + return SLANG_MATRIX_LAYOUT_MODE_UNKNOWN; + } + +} + +SLANG_API int spReflectionTypeLayout_getGenericParamIndex(SlangReflectionTypeLayout* inTypeLayout) +{ + auto typeLayout = convert(inTypeLayout); + if(!typeLayout) return -1; + + if(auto genericParamTypeLayout = as(typeLayout)) + { + return genericParamTypeLayout->paramIndex; + } + else + { + return -1; + } +} + +SLANG_API SlangReflectionTypeLayout* spReflectionTypeLayout_getPendingDataTypeLayout(SlangReflectionTypeLayout* inTypeLayout) +{ + auto typeLayout = convert(inTypeLayout); + if(!typeLayout) return nullptr; + + auto pendingDataTypeLayout = typeLayout->pendingDataTypeLayout.Ptr(); + return convert(pendingDataTypeLayout); +} + +SLANG_API SlangReflectionVariableLayout* spReflectionVariableLayout_getPendingDataLayout(SlangReflectionVariableLayout* inVarLayout) +{ + auto varLayout = convert(inVarLayout); + if(!varLayout) return nullptr; + + auto pendingDataLayout = varLayout->pendingVarLayout.Ptr(); + return convert(pendingDataLayout); +} + +SLANG_API SlangReflectionVariableLayout* spReflectionTypeLayout_getSpecializedTypePendingDataVarLayout(SlangReflectionTypeLayout* inTypeLayout) +{ + auto typeLayout = convert(inTypeLayout); + if(!typeLayout) return nullptr; + + if( auto specializedTypeLayout = as(typeLayout) ) + { + auto pendingDataVarLayout = specializedTypeLayout->pendingDataVarLayout.Ptr(); + return convert(pendingDataVarLayout); + } + else + { + return nullptr; + } +} + + +// Variable Reflection + +SLANG_API char const* spReflectionVariable_GetName(SlangReflectionVariable* inVar) +{ + auto var = convert(inVar); + if(!var) return nullptr; + + // If the variable is one that has an "external" name that is supposed + // to be exposed for reflection, then report it here + if(auto reflectionNameMod = var->FindModifier()) + return getText(reflectionNameMod->nameAndLoc.name).getBuffer(); + + return getText(var->getName()).getBuffer(); +} + +SLANG_API SlangReflectionType* spReflectionVariable_GetType(SlangReflectionVariable* inVar) +{ + auto var = convert(inVar); + if(!var) return nullptr; + + return convert(var->getType()); +} + +SLANG_API SlangReflectionModifier* spReflectionVariable_FindModifier(SlangReflectionVariable* inVar, SlangModifierID modifierID) +{ + auto var = convert(inVar); + if(!var) return nullptr; + + Modifier* modifier = nullptr; + switch( modifierID ) + { + case SLANG_MODIFIER_SHARED: + modifier = var->FindModifier(); + break; + + default: + return nullptr; + } + + return (SlangReflectionModifier*) modifier; +} + +SLANG_API unsigned int spReflectionVariable_GetUserAttributeCount(SlangReflectionVariable* inVar) +{ + auto varDecl = convert(inVar); + if (!varDecl) return 0; + return getUserAttributeCount(varDecl); +} +SLANG_API SlangReflectionUserAttribute* spReflectionVariable_GetUserAttribute(SlangReflectionVariable* inVar, unsigned int index) +{ + auto varDecl = convert(inVar); + if (!varDecl) return 0; + return getUserAttributeByIndex(varDecl, index); +} +SLANG_API SlangReflectionUserAttribute* spReflectionVariable_FindUserAttributeByName(SlangReflectionVariable* inVar, SlangSession* session, char const* name) +{ + auto varDecl = convert(inVar); + if (!varDecl) return 0; + return findUserAttributeByName(convert(session), varDecl, name); +} + +// Variable Layout Reflection + +SLANG_API SlangReflectionVariable* spReflectionVariableLayout_GetVariable(SlangReflectionVariableLayout* inVarLayout) +{ + auto varLayout = convert(inVarLayout); + if(!varLayout) return nullptr; + + return convert(varLayout->varDecl.getDecl()); +} + +SLANG_API SlangReflectionTypeLayout* spReflectionVariableLayout_GetTypeLayout(SlangReflectionVariableLayout* inVarLayout) +{ + auto varLayout = convert(inVarLayout); + if(!varLayout) return nullptr; + + return convert(varLayout->getTypeLayout()); +} + +namespace Slang +{ + // Attempt "do what I mean" remapping from the parameter category the user asked about, + // over to a parameter category that they might have meant. + static SlangParameterCategory maybeRemapParameterCategory( + TypeLayout* typeLayout, + SlangParameterCategory category) + { + // Do we have an entry for the category they asked about? Then use that. + if (typeLayout->FindResourceInfo(LayoutResourceKind(category))) + return category; + + // Do we have an entry for the `DescriptorTableSlot` category? + if (typeLayout->FindResourceInfo(LayoutResourceKind::DescriptorTableSlot)) + { + // Is the category they were asking about one that makes sense for the type + // of this variable? + Type* type = typeLayout->getType(); + while (auto arrayType = as(type)) + type = arrayType->baseType; + switch (spReflectionType_GetKind(convert(type))) + { + case SLANG_TYPE_KIND_CONSTANT_BUFFER: + if(category == SLANG_PARAMETER_CATEGORY_CONSTANT_BUFFER) + return SLANG_PARAMETER_CATEGORY_DESCRIPTOR_TABLE_SLOT; + break; + + case SLANG_TYPE_KIND_RESOURCE: + if(category == SLANG_PARAMETER_CATEGORY_SHADER_RESOURCE) + return SLANG_PARAMETER_CATEGORY_DESCRIPTOR_TABLE_SLOT; + break; + + case SLANG_TYPE_KIND_SAMPLER_STATE: + if(category == SLANG_PARAMETER_CATEGORY_SAMPLER_STATE) + return SLANG_PARAMETER_CATEGORY_DESCRIPTOR_TABLE_SLOT; + break; + + // TODO: implement more helpers here + + default: + break; + } + } + + return category; + } +} + +SLANG_API size_t spReflectionVariableLayout_GetOffset(SlangReflectionVariableLayout* inVarLayout, SlangParameterCategory category) +{ + auto varLayout = convert(inVarLayout); + if(!varLayout) return 0; + + auto info = varLayout->FindResourceInfo(LayoutResourceKind(category)); + + if (!info) + { + // No match with requested category? Try again with one they might have meant... + category = maybeRemapParameterCategory(varLayout->getTypeLayout(), category); + info = varLayout->FindResourceInfo(LayoutResourceKind(category)); + } + + if(!info) return 0; + + return info->index; +} + +SLANG_API size_t spReflectionVariableLayout_GetSpace(SlangReflectionVariableLayout* inVarLayout, SlangParameterCategory category) +{ + auto varLayout = convert(inVarLayout); + if(!varLayout) return 0; + + + auto info = varLayout->FindResourceInfo(LayoutResourceKind(category)); + if (!info) + { + // No match with requested category? Try again with one they might have meant... + category = maybeRemapParameterCategory(varLayout->getTypeLayout(), category); + info = varLayout->FindResourceInfo(LayoutResourceKind(category)); + } + + UInt space = 0; + + // First, deal with any offset applied to the specific resource kind specified + if (info) + { + space += info->space; + } + + // Next, deal with any dedicated register-space offset applied to, e.g., a parameter block + if (auto spaceInfo = varLayout->FindResourceInfo(LayoutResourceKind::RegisterSpace)) + { + space += spaceInfo->index; + } + + return space; +} + +SLANG_API char const* spReflectionVariableLayout_GetSemanticName(SlangReflectionVariableLayout* inVarLayout) +{ + auto varLayout = convert(inVarLayout); + if(!varLayout) return 0; + + if (!(varLayout->flags & Slang::VarLayoutFlag::HasSemantic)) + return 0; + + return varLayout->semanticName.getBuffer(); +} + +SLANG_API size_t spReflectionVariableLayout_GetSemanticIndex(SlangReflectionVariableLayout* inVarLayout) +{ + auto varLayout = convert(inVarLayout); + if(!varLayout) return 0; + + if (!(varLayout->flags & Slang::VarLayoutFlag::HasSemantic)) + return 0; + + return varLayout->semanticIndex; +} + +SLANG_API SlangStage spReflectionVariableLayout_getStage( + SlangReflectionVariableLayout* inVarLayout) +{ + auto varLayout = convert(inVarLayout); + if(!varLayout) return SLANG_STAGE_NONE; + + // A parameter that is not a varying input or output is + // not considered to belong to a single stage. + // + // TODO: We might need to reconsider this for, e.g., entry + // point parameters, where they might be stage-specific even + // if they are uniform. + if (!varLayout->FindResourceInfo(Slang::LayoutResourceKind::VaryingInput) + && !varLayout->FindResourceInfo(Slang::LayoutResourceKind::VaryingOutput)) + { + return SLANG_STAGE_NONE; + } + + // TODO: We should find the stage for a variable layout by + // walking up the tree of layout information, until we find + // something that has a definitive stage attached to it (e.g., + // either an entry point or a GLSL translation unit). + // + // We don't currently have parent links in the reflection layout + // information, so doing that walk would be tricky right now, so + // it is easier to just bloat the representation and store yet another + // field on every variable layout. + return (SlangStage) varLayout->stage; +} + + +// Shader Parameter Reflection + +SLANG_API unsigned spReflectionParameter_GetBindingIndex(SlangReflectionParameter* inVarLayout) +{ + SlangReflectionVariableLayout* varLayout = (SlangReflectionVariableLayout*)inVarLayout; + return (unsigned) spReflectionVariableLayout_GetOffset( + varLayout, + spReflectionTypeLayout_GetParameterCategory( + spReflectionVariableLayout_GetTypeLayout(varLayout))); +} + +SLANG_API unsigned spReflectionParameter_GetBindingSpace(SlangReflectionParameter* inVarLayout) +{ + SlangReflectionVariableLayout* varLayout = (SlangReflectionVariableLayout*)inVarLayout; + return (unsigned) spReflectionVariableLayout_GetSpace( + varLayout, + spReflectionTypeLayout_GetParameterCategory( + spReflectionVariableLayout_GetTypeLayout(varLayout))); +} + +// Helpers for getting parameter count + +namespace Slang +{ + static unsigned getParameterCount(RefPtr typeLayout) + { + if(auto parameterGroupLayout = as(typeLayout)) + { + typeLayout = parameterGroupLayout->offsetElementTypeLayout; + } + + if(auto structLayout = as(typeLayout)) + { + return (unsigned) structLayout->fields.getCount(); + } + + return 0; + } + + static VarLayout* getParameterByIndex(RefPtr typeLayout, unsigned index) + { + if(auto parameterGroupLayout = as(typeLayout)) + { + typeLayout = parameterGroupLayout->offsetElementTypeLayout; + } + + if(auto structLayout = as(typeLayout)) + { + return structLayout->fields[index]; + } + + return 0; + } +} + +// Entry Point Reflection + +SLANG_API char const* spReflectionEntryPoint_getName( + SlangReflectionEntryPoint* inEntryPoint) +{ + auto entryPointLayout = convert(inEntryPoint); + if(!entryPointLayout) return 0; + + return getText(entryPointLayout->entryPoint->getName()).begin(); +} + +SLANG_API unsigned spReflectionEntryPoint_getParameterCount( + SlangReflectionEntryPoint* inEntryPoint) +{ + auto entryPointLayout = convert(inEntryPoint); + if(!entryPointLayout) return 0; + + return getParameterCount(entryPointLayout->parametersLayout->typeLayout); +} + +SLANG_API SlangReflectionVariableLayout* spReflectionEntryPoint_getParameterByIndex( + SlangReflectionEntryPoint* inEntryPoint, + unsigned index) +{ + auto entryPointLayout = convert(inEntryPoint); + if(!entryPointLayout) return 0; + + return convert(getParameterByIndex(entryPointLayout->parametersLayout->typeLayout, index)); +} + +SLANG_API SlangStage spReflectionEntryPoint_getStage(SlangReflectionEntryPoint* inEntryPoint) +{ + auto entryPointLayout = convert(inEntryPoint); + + if(!entryPointLayout) return SLANG_STAGE_NONE; + + return SlangStage(entryPointLayout->profile.GetStage()); +} + +SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize( + SlangReflectionEntryPoint* inEntryPoint, + SlangUInt axisCount, + SlangUInt* outSizeAlongAxis) +{ + auto entryPointLayout = convert(inEntryPoint); + + if(!entryPointLayout) return; + if(!axisCount) return; + if(!outSizeAlongAxis) return; + + auto entryPointFunc = entryPointLayout->entryPoint; + if(!entryPointFunc) return; + + SlangUInt sizeAlongAxis[3] = { 1, 1, 1 }; + + // First look for the HLSL case, where we have an attribute attached to the entry point function + auto numThreadsAttribute = entryPointFunc->FindModifier(); + if (numThreadsAttribute) + { + sizeAlongAxis[0] = numThreadsAttribute->x; + sizeAlongAxis[1] = numThreadsAttribute->y; + sizeAlongAxis[2] = numThreadsAttribute->z; + } + else + { + // Fall back to the GLSL case, which requires a search over global-scope declarations + // to look for as with the `local_size_*` qualifier + auto module = as(entryPointFunc->ParentDecl); + if (module) + { + for (auto dd : module->Members) + { + for (auto mod : dd->GetModifiersOfType()) + { + if (auto xMod = as(mod)) + sizeAlongAxis[0] = (SlangUInt) getIntegerLiteralValue(xMod->valToken); + else if (auto yMod = as(mod)) + sizeAlongAxis[1] = (SlangUInt) getIntegerLiteralValue(yMod->valToken); + else if (auto zMod = as(mod)) + sizeAlongAxis[2] = (SlangUInt) getIntegerLiteralValue(zMod->valToken); + } + } + } + } + + // + + if(axisCount > 0) outSizeAlongAxis[0] = sizeAlongAxis[0]; + if(axisCount > 1) outSizeAlongAxis[1] = sizeAlongAxis[1]; + if(axisCount > 2) outSizeAlongAxis[2] = sizeAlongAxis[2]; + for( SlangUInt aa = 3; aa < axisCount; ++aa ) + { + outSizeAlongAxis[aa] = 1; + } +} + +SLANG_API int spReflectionEntryPoint_usesAnySampleRateInput( + SlangReflectionEntryPoint* inEntryPoint) +{ + auto entryPointLayout = convert(inEntryPoint); + if(!entryPointLayout) + return 0; + + if (entryPointLayout->profile.GetStage() != Stage::Fragment) + return 0; + + return (entryPointLayout->flags & EntryPointLayout::Flag::usesAnySampleRateInput) != 0; +} + +// SlangReflectionTypeParameter +SLANG_API char const* spReflectionTypeParameter_GetName(SlangReflectionTypeParameter * inTypeParam) +{ + auto typeParam = convert(inTypeParam); + return typeParam->decl->getName()->text.getBuffer(); +} + +SLANG_API unsigned spReflectionTypeParameter_GetIndex(SlangReflectionTypeParameter * inTypeParam) +{ + auto typeParam = convert(inTypeParam); + return (unsigned)(typeParam->index); +} + +SLANG_API unsigned int spReflectionTypeParameter_GetConstraintCount(SlangReflectionTypeParameter* inTypeParam) +{ + auto typeParam = convert(inTypeParam); + auto constraints = typeParam->decl->getMembersOfType(); + return (unsigned int)constraints.getCount(); +} + +SLANG_API SlangReflectionType* spReflectionTypeParameter_GetConstraintByIndex(SlangReflectionTypeParameter * inTypeParam, unsigned index) +{ + auto typeParam = convert(inTypeParam); + auto constraints = typeParam->decl->getMembersOfType(); + return (SlangReflectionType*)constraints.toArray()[index]->sup.Ptr(); +} + +// Shader Reflection + +SLANG_API unsigned spReflection_GetParameterCount(SlangReflection* inProgram) +{ + auto program = convert(inProgram); + if(!program) return 0; + + auto globalStructLayout = getGlobalStructLayout(program); + if (!globalStructLayout) + return 0; + + return (unsigned) globalStructLayout->fields.getCount(); +} + +SLANG_API SlangReflectionParameter* spReflection_GetParameterByIndex(SlangReflection* inProgram, unsigned index) +{ + auto program = convert(inProgram); + if(!program) return nullptr; + + auto globalStructLayout = getGlobalStructLayout(program); + if (!globalStructLayout) + return 0; + + return convert(globalStructLayout->fields[index].Ptr()); +} + +SLANG_API unsigned int spReflection_GetTypeParameterCount(SlangReflection * reflection) +{ + auto program = convert(reflection); + return (unsigned int)program->globalGenericParams.getCount(); +} + +SLANG_API SlangReflectionTypeParameter* spReflection_GetTypeParameterByIndex(SlangReflection * reflection, unsigned int index) +{ + auto program = convert(reflection); + return (SlangReflectionTypeParameter*)program->globalGenericParams[index].Ptr(); +} + +SLANG_API SlangReflectionTypeParameter * spReflection_FindTypeParameter(SlangReflection * inProgram, char const * name) +{ + auto program = convert(inProgram); + if (!program) return nullptr; + GenericParamLayout * result = nullptr; + program->globalGenericParamsMap.TryGetValue(name, result); + return (SlangReflectionTypeParameter*)result; +} + +SLANG_API SlangUInt spReflection_getEntryPointCount(SlangReflection* inProgram) +{ + auto program = convert(inProgram); + if(!program) return 0; + + return SlangUInt(program->entryPoints.getCount()); +} + +SLANG_API SlangReflectionEntryPoint* spReflection_getEntryPointByIndex(SlangReflection* inProgram, SlangUInt index) +{ + auto program = convert(inProgram); + if(!program) return 0; + + return convert(program->entryPoints[(int) index].Ptr()); +} + +SLANG_API SlangReflectionEntryPoint* spReflection_findEntryPointByName(SlangReflection* inProgram, char const* name) +{ + auto program = convert(inProgram); + if(!program) return 0; + + // TODO: improve on dumb linear search + for(auto ep : program->entryPoints) + { + if(ep->entryPoint->getName()->text == name) + { + return convert(ep); + } + } + + return nullptr; +} + + +SLANG_API SlangUInt spReflection_getGlobalConstantBufferBinding(SlangReflection* inProgram) +{ + auto program = convert(inProgram); + if (!program) return 0; + auto cb = program->parametersLayout->FindResourceInfo(LayoutResourceKind::ConstantBuffer); + if (!cb) return 0; + return cb->index; +} + +SLANG_API size_t spReflection_getGlobalConstantBufferSize(SlangReflection* inProgram) +{ + auto program = convert(inProgram); + if (!program) return 0; + auto structLayout = getGlobalStructLayout(program); + auto uniform = structLayout->FindResourceInfo(LayoutResourceKind::Uniform); + if (!uniform) return 0; + return getReflectionSize(uniform->count); +} + +SLANG_API SlangReflectionType* spReflection_specializeType( + SlangReflection* inProgramLayout, + SlangReflectionType* inType, + SlangInt specializationArgCount, + SlangReflectionType* const* specializationArgs, + ISlangBlob** outDiagnostics) +{ + auto programLayout = convert(inProgramLayout); + if(!programLayout) return nullptr; + + auto unspecializedType = convert(inType); + if(!unspecializedType) return nullptr; + + auto linkage = programLayout->getProgram()->getLinkage(); + + DiagnosticSink sink; + sink.sourceManager = linkage->getSourceManager(); + + auto specializedType = linkage->specializeType(unspecializedType, specializationArgCount, (Type* const*) specializationArgs, &sink); + + sink.getBlobIfNeeded(outDiagnostics); + + return convert(specializedType); +} + diff --git a/source/slang/slang-reflection.h b/source/slang/slang-reflection.h new file mode 100644 index 000000000..ea3021bd6 --- /dev/null +++ b/source/slang/slang-reflection.h @@ -0,0 +1,26 @@ +#ifndef SLANG_REFLECTION_H +#define SLANG_REFLECTION_H + +#include "../core/slang-basic.h" +#include "slang-syntax.h" + +#include "../../slang.h" + +namespace Slang { + +class ProgramLayout; +class TypeLayout; + +// + +SlangTypeKind getReflectionTypeKind(Type* type); + +SlangTypeKind getReflectionParameterCategory(TypeLayout* typeLayout); + +UInt getReflectionFieldCount(Type* type); +UInt getReflectionFieldByIndex(Type* type, UInt index); +UInt getReflectionFieldByIndex(TypeLayout* typeLayout, UInt index); + +} + +#endif // SLANG_REFLECTION_H diff --git a/source/slang/slang-source-loc.cpp b/source/slang/slang-source-loc.cpp new file mode 100644 index 000000000..faa7e77c3 --- /dev/null +++ b/source/slang/slang-source-loc.cpp @@ -0,0 +1,591 @@ +// slang-source-loc.cpp +#include "slang-source-loc.h" + +#include "slang-compiler.h" + +#include "../core/slang-string-util.h" + +namespace Slang { + +/* !!!!!!!!!!!!!!!!!!!!!!!!! SourceView !!!!!!!!!!!!!!!!!!!!!!!!!!!! */ + +const String PathInfo::getMostUniqueIdentity() const +{ + switch (type) + { + case Type::Normal: return uniqueIdentity; + case Type::FoundPath: + case Type::FromString: + { + return foundPath; + } + default: return ""; + } +} + +/* !!!!!!!!!!!!!!!!!!!!!!!!! SourceView !!!!!!!!!!!!!!!!!!!!!!!!!!!! */ + +int SourceView::findEntryIndex(SourceLoc sourceLoc) const +{ + if (!m_range.contains(sourceLoc)) + { + return -1; + } + + const auto rawValue = sourceLoc.getRaw(); + + Index hi = m_entries.getCount(); + // If there are no entries, or it is in front of the first entry, then there is no associated entry + if (hi == 0 || + m_entries[0].m_startLoc.getRaw() > sourceLoc.getRaw()) + { + return -1; + } + + Index lo = 0; + while (lo + 1 < hi) + { + const Index mid = (hi + lo) >> 1; + const Entry& midEntry = m_entries[mid]; + SourceLoc::RawValue midValue = midEntry.m_startLoc.getRaw(); + if (midValue <= rawValue) + { + // The location we seek is at or after this entry + lo = mid; + } + else + { + // The location we seek is before this entry + hi = mid; + } + } + + return int(lo); +} + +void SourceView::addLineDirective(SourceLoc directiveLoc, StringSlicePool::Handle pathHandle, int line) +{ + SLANG_ASSERT(pathHandle != StringSlicePool::Handle(0)); + SLANG_ASSERT(m_range.contains(directiveLoc)); + + // Check that the directiveLoc values are always increasing + SLANG_ASSERT(m_entries.getCount() == 0 || (m_entries.getLast().m_startLoc.getRaw() < directiveLoc.getRaw())); + + // Calculate the offset + const int offset = m_range.getOffset(directiveLoc); + + // Get the line index in the original file + const int lineIndex = m_sourceFile->calcLineIndexFromOffset(offset); + + Entry entry; + entry.m_startLoc = directiveLoc; + entry.m_pathHandle = pathHandle; + + // We also need to make sure that any lookups for line numbers will + // get corrected based on this files location. + // We assume the line number coming from the directive is a line number, NOT an index, so the correction needs + 1 + // There is an additional + 1 because we want the NEXT line - ie the line after the #line directive, to the specified value + // Taking both into account means +2 is correct 'fix' + entry.m_lineAdjust = line - (lineIndex + 2); + + m_entries.add(entry); +} + +void SourceView::addLineDirective(SourceLoc directiveLoc, const String& path, int line) +{ + StringSlicePool::Handle pathHandle = getSourceManager()->getStringSlicePool().add(path.getUnownedSlice()); + return addLineDirective(directiveLoc, pathHandle, line); +} + +void SourceView::addDefaultLineDirective(SourceLoc directiveLoc) +{ + SLANG_ASSERT(m_range.contains(directiveLoc)); + // Check that the directiveLoc values are always increasing + SLANG_ASSERT(m_entries.getCount() == 0 || (m_entries.getLast().m_startLoc.getRaw() < directiveLoc.getRaw())); + + // Well if there are no entries, or the last one puts it in default case, then we don't need to add anything + if (m_entries.getCount() == 0 || (m_entries.getCount() && m_entries.getLast().isDefault())) + { + return; + } + + Entry entry; + entry.m_startLoc = directiveLoc; + entry.m_lineAdjust = 0; // No line adjustment... we are going back to default + entry.m_pathHandle = StringSlicePool::Handle(0); // Mark that there is no path, and that this is a 'default' + + SLANG_ASSERT(entry.isDefault()); + + m_entries.add(entry); +} + +HumaneSourceLoc SourceView::getHumaneLoc(SourceLoc loc, SourceLocType type) +{ + const int offset = m_range.getOffset(loc); + + // We need the line index from the original source file + const int lineIndex = m_sourceFile->calcLineIndexFromOffset(offset); + + // TODO: we should really translate the byte index in the line + // to deal with: + // + // - Non-ASCII characters, while might consume multiple bytes + // + // - Tab characters, which should really adjust how we report + // columns (although how are we supposed to know the setting + // that an IDE expects us to use when reporting locations?) + const int columnIndex = m_sourceFile->calcColumnIndex(lineIndex, offset); + + HumaneSourceLoc humaneLoc; + humaneLoc.column = columnIndex + 1; + humaneLoc.line = lineIndex + 1; + + // Make up a default entry + StringSlicePool::Handle pathHandle = StringSlicePool::Handle(0); + + // Only bother looking up the entry information if we want a 'Normal' lookup + const int entryIndex = (type == SourceLocType::Nominal) ? findEntryIndex(loc) : -1; + if (entryIndex >= 0) + { + const Entry& entry = m_entries[entryIndex]; + // Adjust the line + humaneLoc.line += entry.m_lineAdjust; + // Get the pathHandle.. + pathHandle = entry.m_pathHandle; + } + + humaneLoc.pathInfo = _getPathInfoFromHandle(pathHandle); + return humaneLoc; +} + +PathInfo SourceView::_getPathInfo() const +{ + if (m_viewPath.getLength()) + { + PathInfo pathInfo(m_sourceFile->getPathInfo()); + pathInfo.foundPath = m_viewPath; + return pathInfo; + } + else + { + return m_sourceFile->getPathInfo(); + } +} + +PathInfo SourceView::_getPathInfoFromHandle(StringSlicePool::Handle pathHandle) const +{ + // If there is no override path, then just the source files path + if (pathHandle == StringSlicePool::Handle(0)) + { + return _getPathInfo(); + } + else + { + return PathInfo::makePath(getSourceManager()->getStringSlicePool().getSlice(pathHandle)); + } +} + +PathInfo SourceView::getPathInfo(SourceLoc loc, SourceLocType type) +{ + if (type == SourceLocType::Actual) + { + return _getPathInfo(); + } + + const int entryIndex = findEntryIndex(loc); + return _getPathInfoFromHandle((entryIndex >= 0) ? m_entries[entryIndex].m_pathHandle : StringSlicePool::Handle(0)); +} + +/* !!!!!!!!!!!!!!!!!!!!!!! SourceFile !!!!!!!!!!!!!!!!!!!!!!!!!!!! */ + +void SourceFile::setLineBreakOffsets(const uint32_t* offsets, UInt numOffsets) +{ + m_lineBreakOffsets.clear(); + m_lineBreakOffsets.addRange(offsets, numOffsets); +} + +const List& SourceFile::getLineBreakOffsets() +{ + // We now have a raw input file that we can search for line breaks. + // We obviously don't want to do a linear scan over and over, so we will + // cache an array of line break locations in the file. + if (m_lineBreakOffsets.getCount() == 0) + { + UnownedStringSlice content = getContent(); + + char const* begin = content.begin(); + char const* end = content.end(); + + char const* cursor = begin; + + // Treat the beginning of the file as a line break + m_lineBreakOffsets.add(0); + + while (cursor != end) + { + int c = *cursor++; + switch (c) + { + case '\r': case '\n': + { + // When we see a line-break character we need + // to record the line break, but we also need + // to deal with the annoying issue of encodings, + // where a multi-byte sequence might encode + // the line break. + + // Check to make sure that the EOF hasn't been reached. + if (cursor != end) + { + int d = *cursor; + if ((c ^ d) == ('\r' ^ '\n')) + cursor++; + } + + m_lineBreakOffsets.add(uint32_t(cursor - begin)); + break; + } + default: + break; + } + } + + // Note that we do *not* treat the end of the file as a line + // break, because otherwise we would report errors like + // "end of file inside string literal" with a line number + // that points at a line that doesn't exist. + } + + return m_lineBreakOffsets; +} + +int SourceFile::calcLineIndexFromOffset(int offset) +{ + SLANG_ASSERT(UInt(offset) <= getContentSize()); + + // Make sure we have the line break offsets + const auto& lineBreakOffsets = getLineBreakOffsets(); + + // At this point we can assume the `lineBreakOffsets` array has been filled in. + // We will use a binary search to find the line index that contains our + // chosen offset. + Index lo = 0; + Index hi = lineBreakOffsets.getCount(); + + while (lo + 1 < hi) + { + const Index mid = (hi + lo) >> 1; + const uint32_t midOffset = lineBreakOffsets[mid]; + if (midOffset <= uint32_t(offset)) + { + lo = mid; + } + else + { + hi = mid; + } + } + + return int(lo); +} + +int SourceFile::calcColumnIndex(int lineIndex, int offset) +{ + const auto& lineBreakOffsets = getLineBreakOffsets(); + return offset - lineBreakOffsets[lineIndex]; +} + +/* !!!!!!!!!!!!!!!!!!!!!!!!! SourceFile !!!!!!!!!!!!!!!!!!!!!!!!!!!! */ + +void SourceFile::setContents(ISlangBlob* blob) +{ + const UInt contentSize = blob->getBufferSize(); + + SLANG_ASSERT(contentSize == m_contentSize); + + char const* contentBegin = (char const*)blob->getBufferPointer(); + char const* contentEnd = contentBegin + contentSize; + + m_contentBlob = blob; + m_content = UnownedStringSlice(contentBegin, contentEnd); +} + +void SourceFile::setContents(const String& content) +{ + ComPtr contentBlob = StringUtil::createStringBlob(content); + setContents(contentBlob); +} + +SourceFile::SourceFile(SourceManager* sourceManager, const PathInfo& pathInfo, size_t contentSize) : + m_sourceManager(sourceManager), + m_pathInfo(pathInfo), + m_contentSize(contentSize) +{ +} + +SourceFile::~SourceFile() +{ +} + +String SourceFile::calcVerbosePath() const +{ + ISlangFileSystemExt* fileSystemExt = getSourceManager()->getFileSystemExt(); + + if (fileSystemExt) + { + String canonicalPath; + ComPtr canonicalPathBlob; + if (SLANG_SUCCEEDED(fileSystemExt->getCanonicalPath(m_pathInfo.foundPath.getBuffer(), canonicalPathBlob.writeRef()))) + { + canonicalPath = StringUtil::getString(canonicalPathBlob); + } + if (canonicalPath.getLength() > 0) + { + return canonicalPath; + } + } + + return m_pathInfo.foundPath; +} + +/* !!!!!!!!!!!!!!!!!!!!!!!!! SourceManager !!!!!!!!!!!!!!!!!!!!!!!!!!!! */ + +void SourceManager::initialize( + SourceManager* p, + ISlangFileSystemExt* fileSystemExt) +{ + m_fileSystemExt = fileSystemExt; + + m_parent = p; + + if( p ) + { + // If we have a parent source manager, then we assume that all code at that level + // has already been loaded, and it is safe to start our own source locations + // right after those from the parent. + // + // TODO: more clever allocation in cases where that might not be reasonable + m_startLoc = p->m_nextLoc; + } + else + { + // Location zero is reserved for an invalid location, + // so we need to start reserving locations starting at 1. + m_startLoc = SourceLoc::fromRaw(1); + } + + m_nextLoc = m_startLoc; +} + +SourceManager::~SourceManager() +{ + for (auto item : m_sourceViews) + { + delete item; + } + + for (auto item : m_sourceFiles) + { + delete item; + } +} + +UnownedStringSlice SourceManager::allocateStringSlice(const UnownedStringSlice& slice) +{ + const UInt numChars = slice.size(); + + char* dst = (char*)m_memoryArena.allocate(numChars); + ::memcpy(dst, slice.begin(), numChars); + + return UnownedStringSlice(dst, numChars); +} + +SourceRange SourceManager::allocateSourceRange(UInt size) +{ + // TODO: consider using atomics here + + + SourceLoc beginLoc = m_nextLoc; + SourceLoc endLoc = beginLoc + size; + + // We need to be able to represent the location that is *at* the end of + // the input source, so the next available location for a new file + // must be placed one after the end of this one. + + m_nextLoc = endLoc + 1; + + return SourceRange(beginLoc, endLoc); +} + +SourceFile* SourceManager::createSourceFileWithSize(const PathInfo& pathInfo, size_t contentSize) +{ + SourceFile* sourceFile = new SourceFile(this, pathInfo, contentSize); + m_sourceFiles.add(sourceFile); + return sourceFile; +} + +SourceFile* SourceManager::createSourceFileWithString(const PathInfo& pathInfo, const String& contents) +{ + SourceFile* sourceFile = new SourceFile(this, pathInfo, contents.getLength()); + m_sourceFiles.add(sourceFile); + sourceFile->setContents(contents); + return sourceFile; +} + +SourceFile* SourceManager::createSourceFileWithBlob(const PathInfo& pathInfo, ISlangBlob* blob) +{ + SourceFile* sourceFile = new SourceFile(this, pathInfo, blob->getBufferSize()); + m_sourceFiles.add(sourceFile); + sourceFile->setContents(blob); + return sourceFile; +} + +SourceView* SourceManager::createSourceView(SourceFile* sourceFile, const PathInfo* pathInfo) +{ + SourceRange range = allocateSourceRange(sourceFile->getContentSize()); + + SourceView* sourceView = nullptr; + if (pathInfo && + (pathInfo->foundPath.getLength() && sourceFile->getPathInfo().foundPath != pathInfo->foundPath)) + { + sourceView = new SourceView(sourceFile, range, &pathInfo->foundPath); + } + else + { + sourceView = new SourceView(sourceFile, range, nullptr); + } + + m_sourceViews.add(sourceView); + + return sourceView; +} + +SourceView* SourceManager::findSourceView(SourceLoc loc) const +{ + Index hi = m_sourceViews.getCount(); + // It must be in the range of this manager and have associated views for it to possibly be a hit + if (!getSourceRange().contains(loc) || hi == 0) + { + return nullptr; + } + + // If we don't have very many, we may as well just linearly search + if (hi <= 8) + { + for (int i = 0; i < hi; ++i) + { + SourceView* view = m_sourceViews[i]; + if (view->getRange().contains(loc)) + { + return view; + } + } + return nullptr; + } + + const SourceLoc::RawValue rawLoc = loc.getRaw(); + + // Binary chop to see if we can find the associated SourceUnit + Index lo = 0; + while (lo + 1 < hi) + { + Index mid = (hi + lo) >> 1; + + SourceView* midView = m_sourceViews[mid]; + if (midView->getRange().contains(loc)) + { + return midView; + } + + const SourceLoc::RawValue midValue = midView->getRange().begin.getRaw(); + if (midValue <= rawLoc) + { + // The location we seek is at or after this entry + lo = mid; + } + else + { + // The location we seek is before this entry + hi = mid; + } + } + + // Check if low is actually a hit + SourceView* view = m_sourceViews[lo]; + return (view->getRange().contains(loc)) ? view : nullptr; +} + +SourceView* SourceManager::findSourceViewRecursively(SourceLoc loc) const +{ + // Start with this manager + const SourceManager* manager = this; + do + { + SourceView* sourceView = manager->findSourceView(loc); + // If we found a hit we are done + if (sourceView) + { + return sourceView; + } + // Try the parent + manager = manager->m_parent; + } + while (manager); + // Didn't find it + return nullptr; +} + +SourceFile* SourceManager::findSourceFile(const String& uniqueIdentity) const +{ + SourceFile*const* filePtr = m_sourceFileMap.TryGetValue(uniqueIdentity); + return (filePtr) ? *filePtr : nullptr; +} + +SourceFile* SourceManager::findSourceFileRecursively(const String& uniqueIdentity) const +{ + const SourceManager* manager = this; + do + { + SourceFile* sourceFile = manager->findSourceFile(uniqueIdentity); + if (sourceFile) + { + return sourceFile; + } + manager = manager->m_parent; + } while (manager); + return nullptr; +} + +void SourceManager::addSourceFile(const String& uniqueIdentity, SourceFile* sourceFile) +{ + SLANG_ASSERT(!findSourceFileRecursively(uniqueIdentity)); + m_sourceFileMap.Add(uniqueIdentity, sourceFile); +} + +HumaneSourceLoc SourceManager::getHumaneLoc(SourceLoc loc, SourceLocType type) +{ + SourceView* sourceView = findSourceViewRecursively(loc); + if (sourceView) + { + return sourceView->getHumaneLoc(loc, type); + } + else + { + return HumaneSourceLoc(); + } +} + +PathInfo SourceManager::getPathInfo(SourceLoc loc, SourceLocType type) +{ + SourceView* sourceView = findSourceViewRecursively(loc); + if (sourceView) + { + return sourceView->getPathInfo(loc, type); + } + else + { + return PathInfo::makeUnknown(); + } +} + +} // namespace Slang diff --git a/source/slang/slang-source-loc.h b/source/slang/slang-source-loc.h new file mode 100644 index 000000000..632c05084 --- /dev/null +++ b/source/slang/slang-source-loc.h @@ -0,0 +1,412 @@ +// slang-source-loc.h +#ifndef SLANG_SOURCE_LOC_H_INCLUDED +#define SLANG_SOURCE_LOC_H_INCLUDED + +#include "../core/slang-basic.h" +#include "../core/slang-memory-arena.h" +#include "../core/slang-string-slice-pool.h" + +#include "../../slang-com-ptr.h" +#include "../../slang.h" + +namespace Slang { + +/** Overview: + +There needs to be a mechanism where we can easily and quickly track a specific locations in any source file used during a compilation. +This is important because that original location is meaningful to the user as it relates to their original source. Thus SourceLoc are +used so we can display meaningful and accurate errors/warnings as well as being able to always map generated code locations back to their origins. + +A 'SourceLoc' along with associated structures (SourceView, SourceFile, SourceMangager) this can pinpoint the location down to the byte across the +compilation. This could be achieved by storing for every token and instruction the file, line and column number came from. The SourceLoc is used in +lots of places - every AST node, every Token from the lexer, every IRInst - so we really want to make it small. So for this reason we actually +encode SourceLoc as a single integer and then use the associated structures when needed to determine what the location actually refers to - +the source file, line and column number, or in effect the byte in the original file. + +Unfortunately there is extra complications. When a source is parsed it's interpretation (in terms of how a piece of source maps to an 'original' file etc) +can be overridden - for example by using #line directives. Moreover a single source file can be parsed multiple times. When it's parsed multiple times the +interpretation of the mapping (#line directives for example) can change. This is the purpose of the SourceView - it holds the interpretation of a source file +for a specific Lex/Parse. + +Another complication is that not all 'source' comes from SourceFiles, a macro expansion, may generate new 'source' we need to handle this, but also be able +to have a SourceLoc map to the expansion unambiguously. This is handled by creating a SourceFile and SourceView that holds only the macro generated +specific information. + +SourceFile - Is the immutable text contents of a file (or perhaps some generated source - say from doing a macro substitution) +SourceView - Tracks a single parse of a SourceFile. Each SourceView defines a range of source locations used. If a SourceFile is parsed twice, two +SourceViews are created, with unique SourceRanges. This is so that it is possible to tell which specific parse a SourceLoc is from - and so know the right +interpretation for that lex/parse. +*/ + +struct PathInfo +{ + /// To be more rigorous about where a path comes from, the type identifies what a paths origin is + enum class Type + { + Unknown, ///< The path is not known + Normal, ///< Normal has both path and uniqueIdentity + FoundPath, ///< Just has a found path (uniqueIdentity is unknown, or even 'unknowable') + FromString, ///< Created from a string (so found path might not be defined and should not be taken as to map to a loaded file) + TokenPaste, ///< No paths, just created to do a macro expansion + TypeParse, ///< No path, just created to do a type parse + CommandLine, ///< A macro constructed from the command line + }; + + /// True if has a canonical path + SLANG_FORCE_INLINE bool hasUniqueIdentity() const { return type == Type::Normal && uniqueIdentity.getLength() > 0; } + /// True if has a regular found path + SLANG_FORCE_INLINE bool hasFoundPath() const { return type == Type::Normal || type == Type::FoundPath || (type == Type::FromString && foundPath.getLength() > 0); } + /// True if has a found path that has originated from a file (as opposed to string or some other origin) + SLANG_FORCE_INLINE bool hasFileFoundPath() const { return (type == Type::Normal || type == Type::FoundPath) && foundPath.getLength() > 0; } + + /// Returns the 'most unique' identity for the path. If has a 'uniqueIdentity' returns that, else the foundPath, else "". + const String getMostUniqueIdentity() const; + + // So simplify construction. In normal usage it's safer to use make methods over constructing directly. + static PathInfo makeUnknown() { return PathInfo { Type::Unknown, "unknown", String() }; } + static PathInfo makeTokenPaste() { return PathInfo{ Type::TokenPaste, "token paste", String()}; } + static PathInfo makeNormal(const String& foundPathIn, const String& uniqueIdentity) { SLANG_ASSERT(uniqueIdentity.getLength() > 0 && foundPathIn.getLength() > 0); return PathInfo { Type::Normal, foundPathIn, uniqueIdentity }; } + static PathInfo makePath(const String& pathIn) { SLANG_ASSERT(pathIn.getLength() > 0); return PathInfo { Type::FoundPath, pathIn, String()}; } + static PathInfo makeTypeParse() { return PathInfo { Type::TypeParse, "type string", String() }; } + static PathInfo makeCommandLine() { return PathInfo { Type::CommandLine, "command line", String() }; } + static PathInfo makeFromString(const String& userPath) { return PathInfo{ Type::FromString, userPath, String() }; } + + Type type; ///< The type of path + String foundPath; ///< The path where the file was found (might contain relative elements) + String uniqueIdentity; ///< The unique identity of the file on the path found +}; + +class SourceLoc +{ +public: + typedef uint32_t RawValue; + +private: + RawValue raw; + +public: + SourceLoc() + : raw(0) + {} + + SourceLoc( + SourceLoc const& loc) + : raw(loc.raw) + {} + + RawValue getRaw() const { return raw; } + void setRaw(RawValue value) { raw = value; } + + static SourceLoc fromRaw(RawValue value) + { + SourceLoc result; + result.setRaw(value); + return result; + } + + bool isValid() const + { + return raw != 0; + } +}; + +inline SourceLoc operator+(SourceLoc loc, Int offset) +{ + return SourceLoc::fromRaw(SourceLoc::RawValue(Int(loc.getRaw()) + offset)); +} + +// A range of locations in the input source +struct SourceRange +{ + /// True if the loc is in the range. Range is inclusive on begin to end. + bool contains(SourceLoc loc) const { const auto rawLoc = loc.getRaw(); return rawLoc >= begin.getRaw() && rawLoc <= end.getRaw(); } + /// Get the total size + UInt getSize() const { return UInt(end.getRaw() - begin.getRaw()); } + + /// Get the offset of a loc in this range + int getOffset(SourceLoc loc) const { SLANG_ASSERT(contains(loc)); return int(loc.getRaw() - begin.getRaw()); } + + SourceRange() + {} + + SourceRange(SourceLoc loc) + : begin(loc) + , end(loc) + {} + + SourceRange(SourceLoc begin, SourceLoc end) + : begin(begin) + , end(end) + {} + + SourceLoc begin; + SourceLoc end; +}; + +// Pre-declare +struct SourceManager; + +// A logical or physical storage object for a range of input code +// that has logically contiguous source locations. +class SourceFile +{ +public: + + /// Returns the line break offsets (in bytes from start of content) + /// Note that this is lazily evaluated - the line breaks are only calculated on the first request + const List& getLineBreakOffsets(); + + /// Set the line break offsets + void setLineBreakOffsets(const uint32_t* offsets, UInt numOffsets); + + /// Calculate the line based on the offset + int calcLineIndexFromOffset(int offset); + + /// Calculate the offset for a line + int calcColumnIndex(int line, int offset); + + /// Get the content holding blob + ISlangBlob* getContentBlob() const { return m_contentBlob; } + + /// True if has full set content + bool hasContent() const { return m_contentBlob != nullptr; } + + /// Get the content size + size_t getContentSize() const { return m_contentSize; } + + /// Get the content + const UnownedStringSlice& getContent() const { return m_content; } + + /// Get path info + const PathInfo& getPathInfo() const { return m_pathInfo; } + + /// Set the content as a blob + void setContents(ISlangBlob* blob); + /// Set the content as a string + void setContents(const String& content); + + /// Calculate a display path -> can canonicalize if necessary + String calcVerbosePath() const; + + /// Get the source manager this was created on + SourceManager* getSourceManager() const { return m_sourceManager; } + + /// Ctor + SourceFile(SourceManager* sourceManager, const PathInfo& pathInfo, size_t contentSize); + /// Dtor + ~SourceFile(); + + protected: + + SourceManager* m_sourceManager; ///< The source manager this belongs to + PathInfo m_pathInfo; ///< The path The logical file path to report for locations inside this span. + ComPtr m_contentBlob; ///< A blob that owns the storage for the file contents. If nullptr, there is no contents + UnownedStringSlice m_content; ///< The actual contents of the file. + size_t m_contentSize; ///< The size of the actual contents + + // In order to speed up lookup of line number information, + // we will cache the starting offset of each line break in + // the input file: + List m_lineBreakOffsets; +}; + +enum class SourceLocType +{ + Nominal, ///< The normal interpretation which takes into account #line directives + Actual, ///< Ignores #line directives - and is the location as seen in the actual file +}; + +// A source location in a format a human might like to see +struct HumaneSourceLoc +{ + PathInfo pathInfo; + Int line = 0; + Int column = 0; +}; + + +/* A SourceView maps to a single span of SourceLoc range and is equivalent to a single include or more precisely use of a source file. +It is distinct from a SourceFile - because a SourceFile may be included multiple times, with different interpretations (depending +on #defines for example). +*/ +class SourceView +{ + public: + + // Each entry represents some contiguous span of locations that + // all map to the same logical file. + struct Entry + { + /// True if this resets the line numbering. It is distinct from a m_lineAdjust being 0, because it also means the path returns to the default. + bool isDefault() const { return m_pathHandle == StringSlicePool::Handle(0); } + + SourceLoc m_startLoc; ///< Where does this entry begin? + StringSlicePool::Handle m_pathHandle; ///< What is the presumed path for this entry. If 0 it means there is no path. + int32_t m_lineAdjust; ///< Adjustment to apply to source line numbers when printing presumed locations. Relative to the line number in the underlying file. + }; + + /// Given a sourceLoc finds the entry associated with it. If returns -1 then no entry is + /// associated with this location, and therefore the location should be interpreted as an offset + /// into the underlying sourceFile. + int findEntryIndex(SourceLoc sourceLoc) const; + + /// Add a line directive for this view. The directiveLoc must of course be in this SourceView + /// The path handle, must have been constructed on the SourceManager associated with the view + /// NOTE! Directives are assumed to be added IN ORDER during parsing such that every directiveLoc > previous + void addLineDirective(SourceLoc directiveLoc, StringSlicePool::Handle pathHandle, int line); + void addLineDirective(SourceLoc directiveLoc, const String& path, int line); + + /// Removes any corrections on line numbers and reverts to the source files path + void addDefaultLineDirective(SourceLoc directiveLoc); + + /// Get the range that this view applies to + const SourceRange& getRange() const { return m_range; } + /// Get the entries + const List& getEntries() const { return m_entries; } + /// Set the entries list + void setEntries(const Entry* entries, UInt numEntries) { m_entries.clear(); m_entries.addRange(entries, numEntries); } + + /// Get the source file holds the contents this view + SourceFile* getSourceFile() const { return m_sourceFile; } + /// Get the source manager + SourceManager* getSourceManager() const { return m_sourceFile->getSourceManager(); } + + /// Get the associated 'content' (the source text) + const UnownedStringSlice& getContent() const { return m_sourceFile->getContent(); } + + /// Get the size of the content + size_t getContentSize() const { return m_sourceFile->getContentSize(); } + + /// Get the humane location + /// Type determines if the location wanted is the original, or the 'normal' (which modifys behavior based on #line directives) + HumaneSourceLoc getHumaneLoc(SourceLoc loc, SourceLocType type = SourceLocType::Nominal); + + /// Get the path associated with a location + PathInfo getPathInfo(SourceLoc loc, SourceLocType type = SourceLocType::Nominal); + + /// Ctor + SourceView(SourceFile* sourceFile, SourceRange range, const String* viewPath): + m_range(range), + m_sourceFile(sourceFile) + { + if (viewPath) + { + m_viewPath = *viewPath; + } + } + + protected: + /// Get the pathInfo from a string handle. If it's 0, it will return the _getPathInfo + PathInfo _getPathInfoFromHandle(StringSlicePool::Handle pathHandle) const; + /// Gets the pathInfo for this view. It may be different from the m_sourceFile's if the path has been + /// overridden by m_viewPath + PathInfo _getPathInfo() const; + + String m_viewPath; ///< Path to this view. If empty the path is the path to the SourceView + + SourceRange m_range; ///< The range that this SourceView applies to + SourceFile* m_sourceFile; ///< The source file. Can hold the line breaks + List m_entries; ///< An array entries describing how we should interpret a range, starting from the start location. +}; + +struct SourceManager +{ + // Initialize a source manager, with an optional parent + void initialize(SourceManager* parent, ISlangFileSystemExt* fileSystemExt); + + /// Allocate a range of SourceLoc locations, these can be used to identify a specific location in the source + SourceRange allocateSourceRange(UInt size); + + /// Create a SourceFile defined with the specified path, and content held within a blob + SourceFile* createSourceFileWithSize(const PathInfo& pathInfo, size_t contentSize); + SourceFile* createSourceFileWithString(const PathInfo& pathInfo, const String& contents); + SourceFile* createSourceFileWithBlob(const PathInfo& pathInfo, ISlangBlob* blob); + + /// Get the humane source location + HumaneSourceLoc getHumaneLoc(SourceLoc loc, SourceLocType type = SourceLocType::Nominal); + + /// Get the path associated with a location + PathInfo getPathInfo(SourceLoc loc, SourceLocType type = SourceLocType::Nominal); + + /// Create a new source view from a file + /// @param sourceFile is the source file that contains the source + /// @param pathInfo is path used to read the file from + SourceView* createSourceView(SourceFile* sourceFile, const PathInfo* pathInfo); + + /// Find a view by a source file location. + /// If not found in this manager will look in the parent SourceManager + /// Returns nullptr if not found. + SourceView* findSourceViewRecursively(SourceLoc loc) const; + + /// Find the SourceView associated with this manager for a specified location + /// Returns nullptr if not found. + SourceView* findSourceView(SourceLoc loc) const; + + /// Searches this manager, and then the parent to see if can find a match for path. + /// If not found returns nullptr. + SourceFile* findSourceFileRecursively(const String& uniqueIdentity) const; + /// Find if the source file is defined on this manager. + SourceFile* findSourceFile(const String& uniqueIdentity) const; + + /// Get the file system associated with this source manager + ISlangFileSystemExt* getFileSystemExt() const { return m_fileSystemExt; } + /// Get the file system associated with this source manager + void setFileSystemExt(ISlangFileSystemExt* fileSystemExt) { m_fileSystemExt = fileSystemExt; } + + /// Add a source file, uniqueIdentity must be unique for this manager AND any parents + void addSourceFile(const String& uniqueIdentity, SourceFile* sourceFile); + + /// Get the slice pool + StringSlicePool& getStringSlicePool() { return m_slicePool; } + + /// Get the source range for just this manager + /// Caution - the range will change if allocations are made to this manager. + SourceRange getSourceRange() const { return SourceRange(m_startLoc, m_nextLoc); } + + /// Get the parent manager to this manager. Returns nullptr if there isn't any. + SourceManager* getParent() const { return m_parent; } + + /// A memory arena to hold allocations that are in scope for the same time as SourceManager + MemoryArena* getMemoryArena() { return &m_memoryArena; } + + /// Allocate a string slice + UnownedStringSlice allocateStringSlice(const UnownedStringSlice& slice); + + SourceManager() : + m_memoryArena(2048) + {} + ~SourceManager(); + + protected: + + // The first location available to this source manager + // (may not be the first location of all, because we might + // have a parent source manager) + SourceLoc m_startLoc; + + // The "parent" source manager that owns locations ahead of `startLoc` + SourceManager* m_parent = nullptr; + + // The location to be used by the next source file to be loaded + SourceLoc m_nextLoc; + + // All of the SourceViews constructed on this SourceManager. These are held in increasing order of range, so can find by doing a binary chop. + List m_sourceViews; + // All of the SourceFiles constructed on this SourceManager. This owns the SourceFile. + List m_sourceFiles; + + StringSlicePool m_slicePool; + + // Memory arena that can be used for holding data to held in scope as long as the Source is + // Can be used for storing the decoded contents of Token. Content for example. + MemoryArena m_memoryArena; + + // Maps uniqueIdentities to source files + Dictionary m_sourceFileMap; + + ComPtr m_fileSystemExt; +}; + +} // namespace Slang + +#endif diff --git a/source/slang/slang-source-stream.h b/source/slang/slang-source-stream.h index 8dcd29c8d..e6b7507c0 100644 --- a/source/slang/slang-source-stream.h +++ b/source/slang/slang-source-stream.h @@ -2,9 +2,9 @@ #ifndef SLANG_SOURCE_STREAM_H_INCLUDED #define SLANG_SOURCE_STREAM_H_INCLUDED -#include "../core/basic.h" +#include "../core/slang-basic.h" -#include "compiler.h" +#include "slang-compiler.h" namespace Slang { diff --git a/source/slang/slang-stdlib.cpp b/source/slang/slang-stdlib.cpp index d2d29beae..036a40ac2 100644 --- a/source/slang/slang-stdlib.cpp +++ b/source/slang/slang-stdlib.cpp @@ -1,8 +1,8 @@ // slang-stdlib.cpp -#include "compiler.h" -#include "ir.h" -#include "syntax.h" +#include "slang-compiler.h" +#include "slang-ir.h" +#include "slang-syntax.h" #define STRINGIZE(x) STRINGIZE2(x) #define STRINGIZE2(x) #x diff --git a/source/slang/slang-stmt-defs.h b/source/slang/slang-stmt-defs.h new file mode 100644 index 000000000..bf25f1706 --- /dev/null +++ b/source/slang/slang-stmt-defs.h @@ -0,0 +1,124 @@ +// slang-stmt-defs.h + +// Syntax class definitions for statements. + +ABSTRACT_SYNTAX_CLASS(ScopeStmt, Stmt) + SYNTAX_FIELD(RefPtr, scopeDecl) +END_SYNTAX_CLASS() + +// A sequence of statements, treated as a single statement +SYNTAX_CLASS(SeqStmt, Stmt) + SYNTAX_FIELD(List>, stmts) +END_SYNTAX_CLASS() + +// The simplest kind of scope statement: just a `{...}` block +SYNTAX_CLASS(BlockStmt, ScopeStmt) + SYNTAX_FIELD(RefPtr, body); +END_SYNTAX_CLASS() + +// A statement that we aren't going to parse or check, because +// we want to let a downstream compiler handle any issues +SYNTAX_CLASS(UnparsedStmt, Stmt) + // The tokens that were contained between `{` and `}` + FIELD(List, tokens) +END_SYNTAX_CLASS() + +SIMPLE_SYNTAX_CLASS(EmptyStmt, Stmt) + +SIMPLE_SYNTAX_CLASS(DiscardStmt, Stmt) + +SYNTAX_CLASS(DeclStmt, Stmt) + SYNTAX_FIELD(RefPtr, decl) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(IfStmt, Stmt) + SYNTAX_FIELD(RefPtr, Predicate) + SYNTAX_FIELD(RefPtr, PositiveStatement) + SYNTAX_FIELD(RefPtr, NegativeStatement) +END_SYNTAX_CLASS() + +// A statement that can be escaped with a `break` +ABSTRACT_SYNTAX_CLASS(BreakableStmt, ScopeStmt) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(SwitchStmt, BreakableStmt) + SYNTAX_FIELD(RefPtr, condition) + SYNTAX_FIELD(RefPtr, body) +END_SYNTAX_CLASS() + +// A statement that is expected to appear lexically nested inside +// some other construct, and thus needs to keep track of the +// outer statement that it is associated with... +ABSTRACT_SYNTAX_CLASS(ChildStmt, Stmt) + DECL_FIELD(Stmt*, parentStmt RAW(= nullptr)) +END_SYNTAX_CLASS() + +// a `case` or `default` statement inside a `switch` +// +// Note(tfoley): A correct AST for a C-like language would treat +// these as a labelled statement, and so they would contain a +// sub-statement. I'm leaving that out for now for simplicity. +ABSTRACT_SYNTAX_CLASS(CaseStmtBase, ChildStmt) +END_SYNTAX_CLASS() + +// a `case` statement inside a `switch` +SYNTAX_CLASS(CaseStmt, CaseStmtBase) + SYNTAX_FIELD(RefPtr, expr) +END_SYNTAX_CLASS() + +// a `default` statement inside a `switch` +SIMPLE_SYNTAX_CLASS(DefaultStmt, CaseStmtBase) + +// A statement that represents a loop, and can thus be escaped with a `continue` +ABSTRACT_SYNTAX_CLASS(LoopStmt, BreakableStmt) +END_SYNTAX_CLASS() + +// A `for` statement +SYNTAX_CLASS(ForStmt, LoopStmt) + SYNTAX_FIELD(RefPtr, InitialStatement) + SYNTAX_FIELD(RefPtr, SideEffectExpression) + SYNTAX_FIELD(RefPtr, PredicateExpression) + SYNTAX_FIELD(RefPtr, Statement) +END_SYNTAX_CLASS() + +// A `for` statement in a language that doesn't restrict the scope +// of the loop variable to the body. +SYNTAX_CLASS(UnscopedForStmt, ForStmt); +END_SYNTAX_CLASS() + +SYNTAX_CLASS(WhileStmt, LoopStmt) + SYNTAX_FIELD(RefPtr, Predicate) + SYNTAX_FIELD(RefPtr, Statement) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(DoWhileStmt, LoopStmt) + SYNTAX_FIELD(RefPtr, Statement) + SYNTAX_FIELD(RefPtr, Predicate) +END_SYNTAX_CLASS() + +// A compile-time, range-based `for` loop, which will not appear in the output code +SYNTAX_CLASS(CompileTimeForStmt, ScopeStmt) + SYNTAX_FIELD(RefPtr, varDecl) + SYNTAX_FIELD(RefPtr, rangeBeginExpr) + SYNTAX_FIELD(RefPtr, rangeEndExpr) + SYNTAX_FIELD(RefPtr, body) + SYNTAX_FIELD(RefPtr, rangeBeginVal) + SYNTAX_FIELD(RefPtr, rangeEndVal) +END_SYNTAX_CLASS() + +// The case of child statements that do control flow relative +// to their parent statement. +ABSTRACT_SYNTAX_CLASS(JumpStmt, ChildStmt) +END_SYNTAX_CLASS() + +SIMPLE_SYNTAX_CLASS(BreakStmt, JumpStmt) + +SIMPLE_SYNTAX_CLASS(ContinueStmt, JumpStmt) + +SYNTAX_CLASS(ReturnStmt, Stmt) + SYNTAX_FIELD(RefPtr, Expression) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(ExpressionStmt, Stmt) + SYNTAX_FIELD(RefPtr, Expression) +END_SYNTAX_CLASS() diff --git a/source/slang/slang-syntax-base-defs.h b/source/slang/slang-syntax-base-defs.h new file mode 100644 index 000000000..2f7c8b1fa --- /dev/null +++ b/source/slang/slang-syntax-base-defs.h @@ -0,0 +1,307 @@ +// slang-syntax-base-defs.h + +// This file defines the primary base classes for the hierarchy of +// AST nodes and related objects. For example, this is where the +// basic `Decl`, `Stmt`, `Expr`, `type`, etc. definitions come from. + +ABSTRACT_SYNTAX_CLASS(NodeBase, RefObject) + // A helper to access the corresponding class on a concrete instance + RAW( + virtual SyntaxClass getClass() = 0; + ) +END_SYNTAX_CLASS() + +// Base class for all nodes representing actual syntax +// (thus having a location in the source code) +ABSTRACT_SYNTAX_CLASS(SyntaxNodeBase, NodeBase) + // The primary source location associated with this AST node + FIELD(SourceLoc, loc) +END_SYNTAX_CLASS() + +// Base class for compile-time values (most often a type). +// These are *not* syntax nodes, because they do not have +// a unique location, and any two `Val`s representing +// the same value should be conceptually equal. +ABSTRACT_SYNTAX_CLASS(Val, NodeBase) + RAW(typedef IValVisitor Visitor;) + + RAW(virtual void accept(IValVisitor* visitor, void* extra) = 0;) + + RAW( + // construct a new value by applying a set of parameter + // substitutions to this one + RefPtr Substitute(SubstitutionSet subst); + + // Lower-level interface for substitution. Like the basic + // `Substitute` above, but also takes a by-reference + // integer parameter that should be incremented when + // returning a modified value (this can help the caller + // decide whether they need to do anything). + virtual RefPtr SubstituteImpl(SubstitutionSet subst, int* ioDiff); + + virtual bool EqualsVal(Val* val) = 0; + virtual String ToString() = 0; + virtual int GetHashCode() = 0; + bool operator == (const Val & v) + { + return EqualsVal(const_cast(&v)); + } + ) +END_SYNTAX_CLASS() + +RAW( + class Type; + + template + SLANG_FORCE_INLINE T* as(Type* obj); + template + SLANG_FORCE_INLINE const T* as(const Type* obj); + ) + +// A type, representing a classifier for some term in the AST. +// +// Types can include "sugar" in that they may refer to a +// `typedef` which gives them a good name when printed as +// part of diagnostic messages. +// +// In order to operation on types, though, we often want +// to look past any sugar, and operate on an underlying +// "canonical" type. The representation caches a pointer to +// a canonical type on every type, so we can easily +// operate on the raw representation when needed. +ABSTRACT_SYNTAX_CLASS(Type, Val) + RAW(typedef ITypeVisitor Visitor;) + + RAW(virtual void accept(IValVisitor* visitor, void* extra) override;) + RAW(virtual void accept(ITypeVisitor* visitor, void* extra) = 0;) + +RAW( +public: + Session* getSession() { return this->session; } + void setSession(Session* s) { this->session = s; } + + bool Equals(Type* type); + + Type* GetCanonicalType(); + + virtual RefPtr SubstituteImpl(SubstitutionSet subst, int* ioDiff) override; + + virtual bool EqualsVal(Val* val) override; + + ~Type(); + +protected: + virtual bool EqualsImpl(Type* type) = 0; + + virtual RefPtr CreateCanonicalType() = 0; + Type* canonicalType = nullptr; + + Session* session = nullptr; + ) +END_SYNTAX_CLASS() +RAW( + template + SLANG_FORCE_INLINE T* as(Type* obj) { return obj ? dynamicCast(obj->GetCanonicalType()) : nullptr; } + template + SLANG_FORCE_INLINE const T* as(const Type* obj) { return obj ? dynamicCast(const_cast(obj)->GetCanonicalType()) : nullptr; } +) + +// A substitution represents a binding of certain +// type-level variables to concrete argument values +ABSTRACT_SYNTAX_CLASS(Substitutions, RefObject) + // The next outer that this one refines. + FIELD(RefPtr, outer) + + RAW( + // Apply a set of substitutions to the bindings in this substitution + virtual RefPtr applySubstitutionsShallow(SubstitutionSet substSet, RefPtr substOuter, int* ioDiff) = 0; + + // Check if these are equivalent substitutiosn to another set + virtual bool Equals(Substitutions* subst) = 0; + virtual int GetHashCode() const = 0; + ) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(GenericSubstitution, Substitutions) + // The generic declaration that defines the + // parameters we are binding to arguments + DECL_FIELD(GenericDecl*, genericDecl) + + // The actual values of the arguments + SYNTAX_FIELD(List>, args) + + RAW( + // Apply a set of substitutions to the bindings in this substitution + virtual RefPtr applySubstitutionsShallow(SubstitutionSet substSet, RefPtr substOuter, int* ioDiff) override; + + // Check if these are equivalent substitutiosn to another set + virtual bool Equals(Substitutions* subst) override; + + virtual int GetHashCode() const override + { + int rs = 0; + for (auto && v : args) + { + rs ^= v->GetHashCode(); + rs *= 16777619; + } + return rs; + } + ) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(ThisTypeSubstitution, Substitutions) + // The declaration of the interface that we are specializing + FIELD_INIT(InterfaceDecl*, interfaceDecl, nullptr) + + // A witness that shows that the concrete type used to + // specialize the interface conforms to the interface. + FIELD(RefPtr, witness) + + // The actual type that provides the lookup scope for an associated type + RAW( + // Apply a set of substitutions to the bindings in this substitution + virtual RefPtr applySubstitutionsShallow(SubstitutionSet substSet, RefPtr substOuter, int* ioDiff) override; + + // Check if these are equivalent substitutiosn to another set + virtual bool Equals(Substitutions* subst) override; + + virtual int GetHashCode() const override; + ) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(GlobalGenericParamSubstitution, Substitutions) + // the type_param decl to be substituted + DECL_FIELD(GlobalGenericParamDecl*, paramDecl) + + // the actual type to substitute in + SYNTAX_FIELD(RefPtr, actualType) + + RAW( + struct ConstraintArg + { + RefPtr decl; + RefPtr val; + }; + ) + + // the values that satisfy any constraints on the type parameter + SYNTAX_FIELD(List, constraintArgs) + +RAW( + // Apply a set of substitutions to the bindings in this substitution + virtual RefPtr applySubstitutionsShallow(SubstitutionSet substSet, RefPtr substOuter, int* ioDiff) override; + + // Check if these are equivalent substitutiosn to another set + virtual bool Equals(Substitutions* subst) override; + + virtual int GetHashCode() const override + { + int rs = actualType->GetHashCode(); + for (auto && a : constraintArgs) + { + rs = combineHash(rs, a.val->GetHashCode()); + } + return rs; + } + ) +END_SYNTAX_CLASS() + +ABSTRACT_SYNTAX_CLASS(SyntaxNode, SyntaxNodeBase) +END_SYNTAX_CLASS() + +// +// All modifiers are represented as full-fledged objects in the AST +// (that is, we don't use a bitfield, even for simple/common flags). +// This ensures that we can track source locations for all modifiers. +// +ABSTRACT_SYNTAX_CLASS(Modifier, SyntaxNode) + RAW(typedef IModifierVisitor Visitor;) + + RAW(virtual void accept(IModifierVisitor* visitor, void* extra) = 0;) + + // Next modifier in linked list of modifiers on same piece of syntax + SYNTAX_FIELD(RefPtr, next) + + // The keyword that was used to introduce t that was used to name this modifier. + FIELD(Name*, name) + + RAW( + Name* getName() { return name; } + NameLoc getNameAndLoc() { return NameLoc(name, loc); } + ) +END_SYNTAX_CLASS() + +// A syntax node which can have modifiers applied +ABSTRACT_SYNTAX_CLASS(ModifiableSyntaxNode, SyntaxNode) + + SYNTAX_FIELD(Modifiers, modifiers) + + RAW( + template + FilteredModifierList GetModifiersOfType() { return FilteredModifierList(modifiers.first.Ptr()); } + + // Find the first modifier of a given type, or return `nullptr` if none is found. + template + T* FindModifier() + { + return *GetModifiersOfType().begin(); + } + + template + bool HasModifier() { return FindModifier() != nullptr; } + ) +END_SYNTAX_CLASS() + + +// An intermediate type to represent either a single declaration, or a group of declarations +ABSTRACT_SYNTAX_CLASS(DeclBase, ModifiableSyntaxNode) + RAW(typedef IDeclVisitor Visitor;) + + RAW(virtual void accept(IDeclVisitor* visitor, void* extra) = 0;) + + +END_SYNTAX_CLASS() + +ABSTRACT_SYNTAX_CLASS(Decl, DeclBase) + DECL_FIELD(ContainerDecl*, ParentDecl RAW(=nullptr)) + + FIELD(NameLoc, nameAndLoc) + + RAW( + Name* getName() { return nameAndLoc.name; } + SourceLoc getNameLoc() { return nameAndLoc.loc ; } + NameLoc getNameAndLoc() { return nameAndLoc ; } + ) + + + FIELD_INIT(DeclCheckState, checkState, DeclCheckState::Unchecked) + + // The next declaration defined in the same container with the same name + DECL_FIELD(Decl*, nextInContainerWithSameName RAW(= nullptr)) + + RAW( + bool IsChecked(DeclCheckState state) { return checkState >= state; } + void SetCheckState(DeclCheckState state) + { + SLANG_RELEASE_ASSERT(state >= checkState); + checkState = state; + } + ) +END_SYNTAX_CLASS() + +ABSTRACT_SYNTAX_CLASS(Expr, SyntaxNode) + RAW(typedef IExprVisitor Visitor;) + + FIELD(QualType, type) + + RAW(virtual void accept(IExprVisitor* visitor, void* extra) = 0;) + +END_SYNTAX_CLASS() + +ABSTRACT_SYNTAX_CLASS(Stmt, ModifiableSyntaxNode) + RAW(typedef IStmtVisitor Visitor;) + + RAW(virtual void accept(IStmtVisitor* visitor, void* extra) = 0;) + +END_SYNTAX_CLASS() diff --git a/source/slang/slang-syntax-defs.h b/source/slang/slang-syntax-defs.h new file mode 100644 index 000000000..5a16c3709 --- /dev/null +++ b/source/slang/slang-syntax-defs.h @@ -0,0 +1,10 @@ +// slang-syntax-defs.h + +#include "slang-syntax-base-defs.h" + +#include "slang-expr-defs.h" +#include "slang-decl-defs.h" +#include "slang-modifier-defs.h" +#include "slang-stmt-defs.h" +#include "slang-type-defs.h" +#include "slang-val-defs.h" diff --git a/source/slang/slang-syntax-visitors.h b/source/slang/slang-syntax-visitors.h new file mode 100644 index 000000000..dc230f051 --- /dev/null +++ b/source/slang/slang-syntax-visitors.h @@ -0,0 +1,36 @@ +#ifndef SLANG_SYNTAX_VISITORS_H +#define SLANG_SYNTAX_VISITORS_H + +#include "slang-diagnostics.h" +#include "slang-syntax.h" + +namespace Slang +{ + class DiagnosticSink; + class EntryPoint; + class Linkage; + class Module; + class ShaderCompiler; + class ShaderLinkInfo; + class ShaderSymbol; + + class TranslationUnitRequest; + + void checkTranslationUnit( + TranslationUnitRequest* translationUnit); + + // Look for a module that matches the given name: + // either one we've loaded already, or one we + // can find vai the search paths available to us. + // + // Needed by import declaration checking. + // + // TODO: need a better location to declare this. + RefPtr findOrImportModule( + Linkage* linkage, + Name* name, + SourceLoc const& loc, + DiagnosticSink* sink); +} + +#endif diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp new file mode 100644 index 000000000..08d671241 --- /dev/null +++ b/source/slang/slang-syntax.cpp @@ -0,0 +1,2865 @@ +#include "slang-syntax.h" + +#include "slang-compiler.h" +#include "slang-visitor.h" + +#include +#include + +namespace Slang +{ + // BasicExpressionType + + bool BasicExpressionType::EqualsImpl(Type * type) + { + auto basicType = dynamicCast(type); + return basicType && basicType->baseType == this->baseType; + } + + RefPtr BasicExpressionType::CreateCanonicalType() + { + // A basic type is already canonical, in our setup + return this; + } + + // Generate dispatch logic and other definitions for all syntax classes +#define SYNTAX_CLASS(NAME, BASE) /* empty */ +#include "slang-object-meta-begin.h" + +#include "slang-syntax-base-defs.h" +#undef SYNTAX_CLASS +#undef ABSTRACT_SYNTAX_CLASS + +#define ABSTRACT_SYNTAX_CLASS(NAME, BASE) \ + template<> \ + SyntaxClassBase::ClassInfo const SyntaxClassBase::Impl::kClassInfo = { #NAME, &SyntaxClassBase::Impl::kClassInfo, nullptr }; + +#define SYNTAX_CLASS(NAME, BASE) \ + void NAME::accept(NAME::Visitor* visitor, void* extra) \ + { visitor->dispatch_##NAME(this, extra); } \ + template<> \ + void* SyntaxClassBase::Impl::createFunc() { return new NAME(); } \ + SyntaxClass NAME::getClass() { return Slang::getClass(); } \ + template<> \ + SyntaxClassBase::ClassInfo const SyntaxClassBase::Impl::kClassInfo = { #NAME, &SyntaxClassBase::Impl::kClassInfo, &SyntaxClassBase::Impl::createFunc }; + +template<> +SyntaxClassBase::ClassInfo const SyntaxClassBase::Impl::kClassInfo = { "RefObject", nullptr, nullptr }; + +ABSTRACT_SYNTAX_CLASS(NodeBase, RefObject); +ABSTRACT_SYNTAX_CLASS(SyntaxNodeBase, NodeBase); +ABSTRACT_SYNTAX_CLASS(SyntaxNode, SyntaxNodeBase); +ABSTRACT_SYNTAX_CLASS(ModifiableSyntaxNode, SyntaxNode); +ABSTRACT_SYNTAX_CLASS(DeclBase, ModifiableSyntaxNode); +ABSTRACT_SYNTAX_CLASS(Decl, DeclBase); +ABSTRACT_SYNTAX_CLASS(Stmt, ModifiableSyntaxNode); +ABSTRACT_SYNTAX_CLASS(Val, NodeBase); +ABSTRACT_SYNTAX_CLASS(Type, Val); +ABSTRACT_SYNTAX_CLASS(Modifier, SyntaxNodeBase); +ABSTRACT_SYNTAX_CLASS(Expr, SyntaxNode); + +ABSTRACT_SYNTAX_CLASS(Substitutions, SyntaxNode); +ABSTRACT_SYNTAX_CLASS(GenericSubstitution, Substitutions); +ABSTRACT_SYNTAX_CLASS(ThisTypeSubstitution, Substitutions); +ABSTRACT_SYNTAX_CLASS(GlobalGenericParamSubstitution, Substitutions); + +#include "slang-expr-defs.h" +#include "slang-decl-defs.h" +#include "slang-modifier-defs.h" +#include "slang-stmt-defs.h" +#include "slang-type-defs.h" +#include "slang-val-defs.h" +#include "slang-object-meta-end.h" + +bool SyntaxClassBase::isSubClassOfImpl(SyntaxClassBase const& super) const +{ + SyntaxClassBase::ClassInfo const* info = classInfo; + while (info) + { + if (info == super.classInfo) + return true; + + info = info->baseClass; + } + + return false; +} + +void Type::accept(IValVisitor* visitor, void* extra) +{ + accept((ITypeVisitor*)visitor, extra); +} + + // TypeExp + + bool TypeExp::Equals(Type* other) + { + return type->Equals(other); + } + + bool TypeExp::Equals(RefPtr other) + { + return type->Equals(other.Ptr()); + } + + // BasicExpressionType + + BasicExpressionType* BasicExpressionType::GetScalarType() + { + return this; + } + + // + + Type::~Type() + { + // If the canonicalType !=nullptr AND it is not set to this (ie the canonicalType is another object) + // then it needs to be released because it's owned by this object. + if (canonicalType && canonicalType != this) + { + canonicalType->releaseReference(); + } + } + + bool Type::Equals(Type * type) + { + return GetCanonicalType()->EqualsImpl(type->GetCanonicalType()); + } + + bool Type::EqualsVal(Val* val) + { + if (auto type = dynamicCast(val)) + return const_cast(this)->Equals(type); + return false; + } + + RefPtr Type::SubstituteImpl(SubstitutionSet subst, int* ioDiff) + { + int diff = 0; + auto canSubst = GetCanonicalType()->SubstituteImpl(subst, &diff); + + // If nothing changed, then don't drop any sugar that is applied + if (!diff) + return this; + + // If the canonical type changed, then we return a canonical type, + // rather than try to re-construct any amount of sugar + (*ioDiff)++; + return canSubst; + } + + Type* Type::GetCanonicalType() + { + Type* et = const_cast(this); + if (!et->canonicalType) + { + // TODO(tfoley): worry about thread safety here? + auto canType = et->CreateCanonicalType(); + et->canonicalType = canType; + + // TODO(js): That this detachs when canType == this is a little surprising. It would seem + // as if this would create a circular reference on the object, but in practice there are + // no leaks so appears correct. + // That the dtor only releases if != this, also makes it surprising. + canType.detach(); + + SLANG_ASSERT(et->canonicalType); + } + return et->canonicalType; + } + + void Session::initializeTypes() + { + errorType = new ErrorType(); + errorType->setSession(this); + + initializerListType = new InitializerListType(); + initializerListType->setSession(this); + + overloadedType = new OverloadGroupType(); + overloadedType->setSession(this); + } + + Type* Session::getBoolType() + { + return getBuiltinType(BaseType::Bool); + } + + Type* Session::getHalfType() + { + return getBuiltinType(BaseType::Half); + } + + Type* Session::getFloatType() + { + return getBuiltinType(BaseType::Float); + } + + Type* Session::getDoubleType() + { + return getBuiltinType(BaseType::Double); + } + + Type* Session::getIntType() + { + return getBuiltinType(BaseType::Int); + } + + Type* Session::getInt64Type() + { + return getBuiltinType(BaseType::Int64); + } + + Type* Session::getUIntType() + { + return getBuiltinType(BaseType::UInt); + } + + Type* Session::getUInt64Type() + { + return getBuiltinType(BaseType::UInt64); + } + + Type* Session::getVoidType() + { + return getBuiltinType(BaseType::Void); + } + + Type* Session::getBuiltinType(BaseType flavor) + { + return RefPtr(builtinTypes[(int)flavor]); + } + + Type* Session::getInitializerListType() + { + return initializerListType; + } + + Type* Session::getOverloadedType() + { + return overloadedType; + } + + Type* Session::getErrorType() + { + return errorType; + } + + Type* Session::getStringType() + { + if (stringType == nullptr) + { + auto stringTypeDecl = findMagicDecl(this, "StringType"); + stringType = DeclRefType::Create(this, makeDeclRef(stringTypeDecl)); + } + return stringType; + } + + Type* Session::getEnumTypeType() + { + if (enumTypeType == nullptr) + { + auto enumTypeTypeDecl = findMagicDecl(this, "EnumTypeType"); + enumTypeType = DeclRefType::Create(this, makeDeclRef(enumTypeTypeDecl)); + } + return enumTypeType; + } + + RefPtr Session::getPtrType( + RefPtr valueType) + { + return getPtrType(valueType, "PtrType").dynamicCast(); + } + + // Construct the type `Out` + RefPtr Session::getOutType(RefPtr valueType) + { + return getPtrType(valueType, "OutType").dynamicCast(); + } + + RefPtr Session::getInOutType(RefPtr valueType) + { + return getPtrType(valueType, "InOutType").dynamicCast(); + } + + RefPtr Session::getRefType(RefPtr valueType) + { + return getPtrType(valueType, "RefType").dynamicCast(); + } + + RefPtr Session::getPtrType(RefPtr valueType, char const* ptrTypeName) + { + auto genericDecl = findMagicDecl(this, ptrTypeName).dynamicCast(); + return getPtrType(valueType, genericDecl); + } + + RefPtr Session::getPtrType(RefPtr valueType, GenericDecl* genericDecl) + { + auto typeDecl = genericDecl->inner; + + auto substitutions = new GenericSubstitution(); + substitutions->genericDecl = genericDecl; + substitutions->args.add(valueType); + + auto declRef = DeclRef(typeDecl.Ptr(), substitutions); + auto rsType = DeclRefType::Create( + this, + declRef); + return as( rsType); + } + + RefPtr Session::getArrayType( + Type* elementType, + IntVal* elementCount) + { + RefPtr arrayType = new ArrayExpressionType(); + arrayType->setSession(this); + arrayType->baseType = elementType; + arrayType->ArrayLength = elementCount; + return arrayType; + } + + SyntaxClass Session::findSyntaxClass(Name* name) + { + SyntaxClass syntaxClass; + if (mapNameToSyntaxClass.TryGetValue(name, syntaxClass)) + return syntaxClass; + + return SyntaxClass(); + } + + + + bool ArrayExpressionType::EqualsImpl(Type* type) + { + auto arrType = as(type); + if (!arrType) + return false; + return (areValsEqual(ArrayLength, arrType->ArrayLength) && baseType->Equals(arrType->baseType.Ptr())); + } + + RefPtr ArrayExpressionType::SubstituteImpl(SubstitutionSet subst, int* ioDiff) + { + int diff = 0; + auto elementType = baseType->SubstituteImpl(subst, &diff).as(); + auto arrlen = ArrayLength->SubstituteImpl(subst, &diff).as(); + SLANG_ASSERT(arrlen); + if (diff) + { + *ioDiff = 1; + auto rsType = getArrayType( + elementType, + arrlen); + return rsType; + } + return this; + } + + RefPtr ArrayExpressionType::CreateCanonicalType() + { + auto canonicalElementType = baseType->GetCanonicalType(); + auto canonicalArrayType = getArrayType( + canonicalElementType, + ArrayLength); + return canonicalArrayType; + } + int ArrayExpressionType::GetHashCode() + { + if (ArrayLength) + return (baseType->GetHashCode() * 16777619) ^ ArrayLength->GetHashCode(); + else + return baseType->GetHashCode(); + } + Slang::String ArrayExpressionType::ToString() + { + if (ArrayLength) + return baseType->ToString() + "[" + ArrayLength->ToString() + "]"; + else + return baseType->ToString() + "[]"; + } + + // DeclRefType + + String DeclRefType::ToString() + { + return declRef.toString(); + } + + int DeclRefType::GetHashCode() + { + return (declRef.GetHashCode() * 16777619) ^ (int)(typeid(this).hash_code()); + } + + bool DeclRefType::EqualsImpl(Type * type) + { + if (auto declRefType = as(type)) + { + return declRef.Equals(declRefType->declRef); + } + return false; + } + + RefPtr DeclRefType::CreateCanonicalType() + { + // A declaration reference is already canonical + return this; + } + + // + // RequirementWitness + // + + RequirementWitness::RequirementWitness(RefPtr val) + : m_flavor(Flavor::val) + , m_obj(val) + {} + + + RequirementWitness::RequirementWitness(RefPtr witnessTable) + : m_flavor(Flavor::witnessTable) + , m_obj(witnessTable) + {} + + RefPtr RequirementWitness::getWitnessTable() + { + SLANG_ASSERT(getFlavor() == Flavor::witnessTable); + return m_obj.as(); + } + + + RequirementWitness RequirementWitness::specialize(SubstitutionSet const& subst) + { + switch(getFlavor()) + { + default: + SLANG_UNEXPECTED("unknown requirement witness flavor"); + case RequirementWitness::Flavor::none: + return RequirementWitness(); + + case RequirementWitness::Flavor::declRef: + { + int diff = 0; + return RequirementWitness( + getDeclRef().SubstituteImpl(subst, &diff)); + } + + case RequirementWitness::Flavor::val: + { + auto val = getVal(); + SLANG_ASSERT(val); + + return RequirementWitness( + val->Substitute(subst)); + } + } + } + + RequirementWitness tryLookUpRequirementWitness( + SubtypeWitness* subtypeWitness, + Decl* requirementKey) + { + if(auto declaredSubtypeWitness = as(subtypeWitness)) + { + if(auto inheritanceDeclRef = declaredSubtypeWitness->declRef.as()) + { + // A conformance that was declared as part of an inheritance clause + // will have built up a dictionary of the satisfying declarations + // for each of its requirements. + RequirementWitness requirementWitness; + auto witnessTable = inheritanceDeclRef.getDecl()->witnessTable; + if(witnessTable && witnessTable->requirementDictionary.TryGetValue(requirementKey, requirementWitness)) + { + // The `inheritanceDeclRef` has substitutions applied to it that + // *aren't* present in the `requirementWitness`, because it was + // derived by the front-end when looking at the `InheritanceDecl` alone. + // + // We need to apply these substitutions here for the result to make sense. + // + // E.g., if we have a case like: + // + // interface ISidekick { associatedtype Hero; void follow(Hero hero); } + // struct Sidekick : ISidekick { typedef H Hero; void follow(H hero) {} }; + // + // void followHero(S s, S.Hero h) + // { + // s.follow(h); + // } + // + // Batman batman; + // Sidekick robin; + // followHero>(robin, batman); + // + // The second argument to `followHero` is `batman`, which has type `Batman`. + // The parameter declaration lists the type `S.Hero`, which is a reference + // to an associated type. The front end will expand this into something + // like `S.{S:ISidekick}.Hero` - that is, we'll end up with a declaration + // reference to `ISidekick.Hero` with a this-type substitution that references + // the `{S:ISidekick}` declaration as a witness. + // + // The front-end will expand the generic application `followHero>` + // to `followHero, {Sidekick:ISidekick}[H->Batman]>` + // (that is, the hidden second parameter will reference the inheritance + // clause on `Sidekick`, with a substitution to map `H` to `Batman`. + // + // This step should map the `{S:ISidekick}` declaration over to the + // concrete `{Sidekick:ISidekick}[H->Batman]` inheritance declaration. + // At that point `tryLookupRequirementWitness` might be called, because + // we want to look up the witness for the key `ISidekick.Hero` in the + // inheritance decl-ref that is `{Sidekick:ISidekick}[H->Batman]`. + // + // That lookup will yield us a reference to the typedef `Sidekick.Hero`, + // *without* any substitution for `H` (or rather, with a default one that + // maps `H` to `H`. + // + // So, in order to get the *right* end result, we need to apply + // the substitutions from the inheritance decl-ref to the witness. + // + requirementWitness = requirementWitness.specialize(inheritanceDeclRef.substitutions); + + return requirementWitness; + } + } + } + + // TODO: should handle the transitive case here too + + return RequirementWitness(); + } + + RefPtr DeclRefType::SubstituteImpl(SubstitutionSet subst, int* ioDiff) + { + if (!subst) return this; + + // the case we especially care about is when this type references a declaration + // of a generic parameter, since that is what we might be substituting... + if (auto genericTypeParamDecl = as(declRef.getDecl())) + { + // search for a substitution that might apply to us + for(auto s = subst.substitutions; s; s = s->outer) + { + auto genericSubst = s.as(); + if(!genericSubst) + continue; + + // the generic decl associated with the substitution list must be + // the generic decl that declared this parameter + auto genericDecl = genericSubst->genericDecl; + if (genericDecl != genericTypeParamDecl->ParentDecl) + continue; + + int index = 0; + for (auto m : genericDecl->Members) + { + if (m.Ptr() == genericTypeParamDecl) + { + // We've found it, so return the corresponding specialization argument + (*ioDiff)++; + return genericSubst->args[index]; + } + else if (auto typeParam = as(m)) + { + index++; + } + else if (auto valParam = as(m)) + { + index++; + } + else + { + } + } + } + } + else if (auto globalGenParam = as(declRef.getDecl())) + { + // search for a substitution that might apply to us + for(auto s = subst.substitutions; s; s = s->outer) + { + auto genericSubst = as(s); + if(!genericSubst) + continue; + + if (genericSubst->paramDecl == globalGenParam) + { + (*ioDiff)++; + return genericSubst->actualType; + } + } + } + int diff = 0; + DeclRef substDeclRef = declRef.SubstituteImpl(subst, &diff); + + if (!diff) + return this; + + // Make sure to record the difference! + *ioDiff += diff; + + // If this type is a reference to an associated type declaration, + // and the substitutions provide a "this type" substitution for + // the outer interface, then try to replace the type with the + // actual value of the associated type for the given implementation. + // + if(auto substAssocTypeDecl = as(substDeclRef.decl)) + { + for(auto s = substDeclRef.substitutions.substitutions; s; s = s->outer) + { + auto thisSubst = s.as(); + if(!thisSubst) + continue; + + if(auto interfaceDecl = as(substAssocTypeDecl->ParentDecl)) + { + if(thisSubst->interfaceDecl == interfaceDecl) + { + // We need to look up the declaration that satisfies + // the requirement named by the associated type. + Decl* requirementKey = substAssocTypeDecl; + RequirementWitness requirementWitness = tryLookUpRequirementWitness(thisSubst->witness, requirementKey); + switch(requirementWitness.getFlavor()) + { + default: + // No usable value was found, so there is nothing we can do. + break; + + case RequirementWitness::Flavor::val: + { + auto satisfyingVal = requirementWitness.getVal(); + return satisfyingVal; + } + break; + } + } + } + } + } + + // Re-construct the type in case we are using a specialized sub-class + return DeclRefType::Create(getSession(), substDeclRef); + } + + static RefPtr ExtractGenericArgType(RefPtr val) + { + auto type = val.as(); + SLANG_RELEASE_ASSERT(type.Ptr()); + return type; + } + + static RefPtr ExtractGenericArgInteger(RefPtr val) + { + auto intVal = val.as(); + SLANG_RELEASE_ASSERT(intVal.Ptr()); + return intVal; + } + + DeclRef createDefaultSubstitutionsIfNeeded( + Session* session, + DeclRef declRef) + { + // It is possible that `declRef` refers to a generic type, + // but does not specify arguments for its generic parameters. + // (E.g., this happens when referring to a generic type from + // within its own member functions). To handle this case, + // we will construct a default specialization at the use + // site if needed. + // + // This same logic should also apply to declarations nested + // more than one level inside of a generic (e.g., a `typdef` + // inside of a generic `struct`). + // + // Similarly, it needs to work for multiple levels of + // nested generics. + // + + // We are going to build up a list of substitutions that need + // to be applied to the decl-ref to make it specialized. + RefPtr substsToApply; + RefPtr* link = &substsToApply; + + RefPtr dd = declRef.getDecl(); + for(;;) + { + RefPtr childDecl = dd; + RefPtr parentDecl = dd->ParentDecl; + if(!parentDecl) + break; + + dd = parentDecl; + + if(auto genericParentDecl = parentDecl.as()) + { + // Don't specialize any parameters of a generic. + if(childDecl != genericParentDecl->inner) + break; + + // We have a generic ancestor, but do we have an substitutions for it? + RefPtr foundSubst; + for(auto s = declRef.substitutions.substitutions; s; s = s->outer) + { + auto genSubst = s.as(); + if(!genSubst) + continue; + + if(genSubst->genericDecl != genericParentDecl) + continue; + + // Okay, we found a matching substitution, + // so there is nothing to be done. + foundSubst = genSubst; + break; + } + + if(!foundSubst) + { + RefPtr newSubst = createDefaultSubsitutionsForGeneric( + session, + genericParentDecl, + nullptr); + + *link = newSubst; + link = &newSubst->outer; + } + } + } + + if(!substsToApply) + return declRef; + + int diff = 0; + return declRef.SubstituteImpl(substsToApply, &diff); + } + + // TODO: need to figure out how to unify this with the logic + // in the generic case... + RefPtr DeclRefType::Create( + Session* session, + DeclRef declRef) + { + declRef = createDefaultSubstitutionsIfNeeded(session, declRef); + + if (auto builtinMod = declRef.getDecl()->FindModifier()) + { + auto type = new BasicExpressionType(builtinMod->tag); + type->setSession(session); + type->declRef = declRef; + return type; + } + else if (auto magicMod = declRef.getDecl()->FindModifier()) + { + GenericSubstitution* subst = nullptr; + for(auto s = declRef.substitutions.substitutions; s; s = s->outer) + { + if(auto genericSubst = s.as()) + { + subst = genericSubst; + break; + } + } + + if (magicMod->name == "SamplerState") + { + auto type = new SamplerStateType(); + type->setSession(session); + type->declRef = declRef; + type->flavor = SamplerStateFlavor(magicMod->tag); + return type; + } + else if (magicMod->name == "Vector") + { + SLANG_ASSERT(subst && subst->args.getCount() == 2); + auto vecType = new VectorExpressionType(); + vecType->setSession(session); + vecType->declRef = declRef; + vecType->elementType = ExtractGenericArgType(subst->args[0]); + vecType->elementCount = ExtractGenericArgInteger(subst->args[1]); + return vecType; + } + else if (magicMod->name == "Matrix") + { + SLANG_ASSERT(subst && subst->args.getCount() == 3); + auto matType = new MatrixExpressionType(); + matType->setSession(session); + matType->declRef = declRef; + return matType; + } + else if (magicMod->name == "Texture") + { + SLANG_ASSERT(subst && subst->args.getCount() >= 1); + auto textureType = new TextureType( + TextureFlavor(magicMod->tag), + ExtractGenericArgType(subst->args[0])); + textureType->setSession(session); + textureType->declRef = declRef; + return textureType; + } + else if (magicMod->name == "TextureSampler") + { + SLANG_ASSERT(subst && subst->args.getCount() >= 1); + auto textureType = new TextureSamplerType( + TextureFlavor(magicMod->tag), + ExtractGenericArgType(subst->args[0])); + textureType->setSession(session); + textureType->declRef = declRef; + return textureType; + } + else if (magicMod->name == "GLSLImageType") + { + SLANG_ASSERT(subst && subst->args.getCount() >= 1); + auto textureType = new GLSLImageType( + TextureFlavor(magicMod->tag), + ExtractGenericArgType(subst->args[0])); + textureType->setSession(session); + textureType->declRef = declRef; + return textureType; + } + + // TODO: eventually everything should follow this pattern, + // and we can drive the dispatch with a table instead + // of this ridiculously slow `if` cascade. + + #define CASE(n,T) \ + else if(magicMod->name == #n) { \ + auto type = new T(); \ + type->setSession(session); \ + type->declRef = declRef; \ + return type; \ + } + + CASE(HLSLInputPatchType, HLSLInputPatchType) + CASE(HLSLOutputPatchType, HLSLOutputPatchType) + + #undef CASE + + #define CASE(n,T) \ + else if(magicMod->name == #n) { \ + SLANG_ASSERT(subst && subst->args.getCount() == 1); \ + auto type = new T(); \ + type->setSession(session); \ + type->elementType = ExtractGenericArgType(subst->args[0]); \ + type->declRef = declRef; \ + return type; \ + } + + CASE(ConstantBuffer, ConstantBufferType) + CASE(TextureBuffer, TextureBufferType) + CASE(ParameterBlockType, ParameterBlockType) + CASE(GLSLInputParameterGroupType, GLSLInputParameterGroupType) + CASE(GLSLOutputParameterGroupType, GLSLOutputParameterGroupType) + CASE(GLSLShaderStorageBufferType, GLSLShaderStorageBufferType) + + CASE(HLSLStructuredBufferType, HLSLStructuredBufferType) + CASE(HLSLRWStructuredBufferType, HLSLRWStructuredBufferType) + CASE(HLSLRasterizerOrderedStructuredBufferType, HLSLRasterizerOrderedStructuredBufferType) + CASE(HLSLAppendStructuredBufferType, HLSLAppendStructuredBufferType) + CASE(HLSLConsumeStructuredBufferType, HLSLConsumeStructuredBufferType) + + CASE(HLSLPointStreamType, HLSLPointStreamType) + CASE(HLSLLineStreamType, HLSLLineStreamType) + CASE(HLSLTriangleStreamType, HLSLTriangleStreamType) + + #undef CASE + + // "magic" builtin types which have no generic parameters + #define CASE(n,T) \ + else if(magicMod->name == #n) { \ + auto type = new T(); \ + type->setSession(session); \ + type->declRef = declRef; \ + return type; \ + } + + CASE(HLSLByteAddressBufferType, HLSLByteAddressBufferType) + CASE(HLSLRWByteAddressBufferType, HLSLRWByteAddressBufferType) + CASE(HLSLRasterizerOrderedByteAddressBufferType, HLSLRasterizerOrderedByteAddressBufferType) + CASE(UntypedBufferResourceType, UntypedBufferResourceType) + + CASE(GLSLInputAttachmentType, GLSLInputAttachmentType) + + #undef CASE + + else + { + auto classInfo = session->findSyntaxClass( + session->getNamePool()->getName(magicMod->name)); + if (!classInfo.classInfo) + { + SLANG_UNEXPECTED("unhandled type"); + } + + RefPtr type = classInfo.createInstance(); + if (!type) + { + SLANG_UNEXPECTED("constructor failure"); + } + + auto declRefType = dynamicCast(type); + if (!declRefType) + { + SLANG_UNEXPECTED("expected a declaration reference type"); + } + declRefType->session = session; + declRefType->declRef = declRef; + return declRefType; + } + } + else + { + auto type = new DeclRefType(declRef); + type->setSession(session); + return type; + } + } + + // OverloadGroupType + + String OverloadGroupType::ToString() + { + return "overload group"; + } + + bool OverloadGroupType::EqualsImpl(Type * /*type*/) + { + return false; + } + + RefPtr OverloadGroupType::CreateCanonicalType() + { + return this; + } + + int OverloadGroupType::GetHashCode() + { + return (int)(int64_t)(void*)this; + } + + // InitializerListType + + String InitializerListType::ToString() + { + return "initializer list"; + } + + bool InitializerListType::EqualsImpl(Type * /*type*/) + { + return false; + } + + RefPtr InitializerListType::CreateCanonicalType() + { + return this; + } + + int InitializerListType::GetHashCode() + { + return (int)(int64_t)(void*)this; + } + + // ErrorType + + String ErrorType::ToString() + { + return "error"; + } + + bool ErrorType::EqualsImpl(Type* type) + { + if (auto errorType = as(type)) + return true; + return false; + } + + RefPtr ErrorType::CreateCanonicalType() + { + return this; + } + + RefPtr ErrorType::SubstituteImpl(SubstitutionSet /*subst*/, int* /*ioDiff*/) + { + return this; + } + + int ErrorType::GetHashCode() + { + return (int)(int64_t)(void*)this; + } + + + // NamedExpressionType + + String NamedExpressionType::ToString() + { + return getText(declRef.GetName()); + } + + bool NamedExpressionType::EqualsImpl(Type * /*type*/) + { + SLANG_UNEXPECTED("unreachable"); + UNREACHABLE_RETURN(false); + } + + RefPtr NamedExpressionType::CreateCanonicalType() + { + if (!innerType) + innerType = GetType(declRef); + return innerType->GetCanonicalType(); + } + + int NamedExpressionType::GetHashCode() + { + // Type equality is based on comparing canonical types, + // so the hash code for a type needs to come from the + // canonical version of the type. This really means + // that `Type::GetHashCode()` should dispatch out to + // something like `Type::GetHashCodeImpl()` on the + // canonical version of a type, but it is less invasive + // for now (and hopefully equivalent) to just have any + // named types automaticlaly route hash-code requests + // to their canonical type. + return GetCanonicalType()->GetHashCode(); + } + + // FuncType + + String FuncType::ToString() + { + StringBuilder sb; + sb << "("; + UInt paramCount = getParamCount(); + for (UInt pp = 0; pp < paramCount; ++pp) + { + if (pp != 0) sb << ", "; + sb << getParamType(pp)->ToString(); + } + sb << ") -> "; + sb << getResultType()->ToString(); + return sb.ProduceString(); + } + + bool FuncType::EqualsImpl(Type * type) + { + if (auto funcType = as(type)) + { + auto paramCount = getParamCount(); + auto otherParamCount = funcType->getParamCount(); + if (paramCount != otherParamCount) + return false; + + for (UInt pp = 0; pp < paramCount; ++pp) + { + auto paramType = getParamType(pp); + auto otherParamType = funcType->getParamType(pp); + if (!paramType->Equals(otherParamType)) + return false; + } + + if(!resultType->Equals(funcType->resultType)) + return false; + + // TODO: if we ever introduce other kinds + // of qualification on function types, we'd + // want to consider it here. + return true; + } + return false; + } + + RefPtr FuncType::SubstituteImpl(SubstitutionSet subst, int* ioDiff) + { + int diff = 0; + + // result type + RefPtr substResultType = resultType->SubstituteImpl(subst, &diff).as(); + + // parameter types + List> substParamTypes; + for( auto pp : paramTypes ) + { + substParamTypes.add(pp->SubstituteImpl(subst, &diff).as()); + } + + // early exit for no change... + if(!diff) + return this; + + (*ioDiff)++; + RefPtr substType = new FuncType(); + substType->session = session; + substType->resultType = substResultType; + substType->paramTypes = substParamTypes; + return substType; + } + + RefPtr FuncType::CreateCanonicalType() + { + // result type + RefPtr canResultType = resultType->GetCanonicalType(); + + // parameter types + List> canParamTypes; + for( auto pp : paramTypes ) + { + canParamTypes.add(pp->GetCanonicalType()); + } + + RefPtr canType = new FuncType(); + canType->session = session; + canType->resultType = resultType; + canType->paramTypes = canParamTypes; + + return canType; + } + + int FuncType::GetHashCode() + { + int hashCode = getResultType()->GetHashCode(); + UInt paramCount = getParamCount(); + hashCode = combineHash(hashCode, Slang::GetHashCode(paramCount)); + for (UInt pp = 0; pp < paramCount; ++pp) + { + hashCode = combineHash( + hashCode, + getParamType(pp)->GetHashCode()); + } + return hashCode; + } + + // TypeType + + String TypeType::ToString() + { + StringBuilder sb; + sb << "typeof(" << type->ToString() << ")"; + return sb.ProduceString(); + } + + bool TypeType::EqualsImpl(Type * t) + { + if (auto typeType = as(t)) + { + return t->Equals(typeType->type); + } + return false; + } + + RefPtr TypeType::CreateCanonicalType() + { + auto canType = getTypeType(type->GetCanonicalType()); + return canType; + } + + int TypeType::GetHashCode() + { + SLANG_UNEXPECTED("unreachable"); + UNREACHABLE_RETURN(0); + } + + // GenericDeclRefType + + String GenericDeclRefType::ToString() + { + // TODO: what is appropriate here? + return ">"; + } + + bool GenericDeclRefType::EqualsImpl(Type * type) + { + if (auto genericDeclRefType = as(type)) + { + return declRef.Equals(genericDeclRefType->declRef); + } + return false; + } + + int GenericDeclRefType::GetHashCode() + { + return declRef.GetHashCode(); + } + + RefPtr GenericDeclRefType::CreateCanonicalType() + { + return this; + } + + // ArithmeticExpressionType + + // VectorExpressionType + + String VectorExpressionType::ToString() + { + StringBuilder sb; + sb << "vector<" << elementType->ToString() << "," << elementCount->ToString() << ">"; + return sb.ProduceString(); + } + + BasicExpressionType* VectorExpressionType::GetScalarType() + { + return as(elementType); + } + + // + + RefPtr findInnerMostGenericSubstitution(Substitutions* subst) + { + for(RefPtr s = subst; s; s = s->outer) + { + if(auto genericSubst = as(s)) + return genericSubst; + } + return nullptr; + } + + // MatrixExpressionType + + String MatrixExpressionType::ToString() + { + StringBuilder sb; + sb << "matrix<" << getElementType()->ToString() << "," << getRowCount()->ToString() << "," << getColumnCount()->ToString() << ">"; + return sb.ProduceString(); + } + + BasicExpressionType* MatrixExpressionType::GetScalarType() + { + return as(getElementType()); + } + + Type* MatrixExpressionType::getElementType() + { + return as(findInnerMostGenericSubstitution(declRef.substitutions)->args[0]); + } + + IntVal* MatrixExpressionType::getRowCount() + { + return as(findInnerMostGenericSubstitution(declRef.substitutions)->args[1]); + } + + IntVal* MatrixExpressionType::getColumnCount() + { + return as(findInnerMostGenericSubstitution(declRef.substitutions)->args[2]); + } + + RefPtr MatrixExpressionType::getRowType() + { + if( !mRowType ) + { + mRowType = getSession()->getVectorType(getElementType(), getColumnCount()); + } + return mRowType; + } + + RefPtr Session::getVectorType( + RefPtr elementType, + RefPtr elementCount) + { + auto vectorGenericDecl = findMagicDecl( + this, "Vector").as(); + auto vectorTypeDecl = vectorGenericDecl->inner; + + auto substitutions = new GenericSubstitution(); + substitutions->genericDecl = vectorGenericDecl.Ptr(); + substitutions->args.add(elementType); + substitutions->args.add(elementCount); + + auto declRef = DeclRef(vectorTypeDecl.Ptr(), substitutions); + + return DeclRefType::Create( + this, + declRef).as(); + } + + + // PtrTypeBase + + Type* PtrTypeBase::getValueType() + { + return as(findInnerMostGenericSubstitution(declRef.substitutions)->args[0]); + } + + // GenericParamIntVal + + bool GenericParamIntVal::EqualsVal(Val* val) + { + if (auto genericParamVal = as(val)) + { + return declRef.Equals(genericParamVal->declRef); + } + return false; + } + + String GenericParamIntVal::ToString() + { + return getText(declRef.GetName()); + } + + int GenericParamIntVal::GetHashCode() + { + return declRef.GetHashCode() ^ 0xFFFF; + } + + RefPtr GenericParamIntVal::SubstituteImpl(SubstitutionSet subst, int* ioDiff) + { + // search for a substitution that might apply to us + for(auto s = subst.substitutions; s; s = s->outer) + { + auto genSubst = s.as(); + if(!genSubst) + continue; + + // the generic decl associated with the substitution list must be + // the generic decl that declared this parameter + auto genericDecl = genSubst->genericDecl; + if (genericDecl != declRef.getDecl()->ParentDecl) + continue; + + int index = 0; + for (auto m : genericDecl->Members) + { + if (m.Ptr() == declRef.getDecl()) + { + // We've found it, so return the corresponding specialization argument + (*ioDiff)++; + return genSubst->args[index]; + } + else if (auto typeParam = as(m)) + { + index++; + } + else if (auto valParam = as(m)) + { + index++; + } + else + { + } + } + } + + // Nothing found: don't substitute. + return this; + } + + // Substitutions + + RefPtr GenericSubstitution::applySubstitutionsShallow(SubstitutionSet substSet, RefPtr substOuter, int* ioDiff) + { + int diff = 0; + + if(substOuter != outer) diff++; + + List> substArgs; + for (auto a : args) + { + substArgs.add(a->SubstituteImpl(substSet, &diff)); + } + + if (!diff) return this; + + (*ioDiff)++; + auto substSubst = new GenericSubstitution(); + substSubst->genericDecl = genericDecl; + substSubst->args = substArgs; + substSubst->outer = substOuter; + return substSubst; + } + + bool GenericSubstitution::Equals(Substitutions* subst) + { + // both must be NULL, or non-NULL + if (subst == nullptr) + return false; + if (this == subst) + return true; + + auto genericSubst = as(subst); + if (!genericSubst) + return false; + if (genericDecl != genericSubst->genericDecl) + return false; + + Index argCount = args.getCount(); + SLANG_RELEASE_ASSERT(args.getCount() == genericSubst->args.getCount()); + for (Index aa = 0; aa < argCount; ++aa) + { + if (!args[aa]->EqualsVal(genericSubst->args[aa].Ptr())) + return false; + } + + if (!outer) + return !genericSubst->outer; + + if (!outer->Equals(genericSubst->outer.Ptr())) + return false; + + return true; + } + + RefPtr ThisTypeSubstitution::applySubstitutionsShallow(SubstitutionSet substSet, RefPtr substOuter, int* ioDiff) + { + int diff = 0; + + if(substOuter != outer) diff++; + + // NOTE: Must use .as because we must have a smart pointer here to keep in scope. + auto substWitness = witness->SubstituteImpl(substSet, &diff).as(); + + if (!diff) return this; + + (*ioDiff)++; + auto substSubst = new ThisTypeSubstitution(); + substSubst->interfaceDecl = interfaceDecl; + substSubst->witness = substWitness; + substSubst->outer = substOuter; + return substSubst; + } + + bool ThisTypeSubstitution::Equals(Substitutions* subst) + { + if (!subst) + return false; + if (subst == this) + return true; + + if (auto thisTypeSubst = as(subst)) + { + return witness->EqualsVal(thisTypeSubst->witness); + } + return false; + } + + int ThisTypeSubstitution::GetHashCode() const + { + return witness->GetHashCode(); + } + + RefPtr GlobalGenericParamSubstitution::applySubstitutionsShallow(SubstitutionSet substSet, RefPtr substOuter, int* ioDiff) + { + // if we find a GlobalGenericParamSubstitution in subst that references the same type_param decl + // return a copy of that GlobalGenericParamSubstitution + int diff = 0; + + if(substOuter != outer) diff++; + + auto substActualType = actualType->SubstituteImpl(substSet, &diff).as(); + + List substConstraintArgs; + for(auto constraintArg : constraintArgs) + { + ConstraintArg substConstraintArg; + substConstraintArg.decl = constraintArg.decl; + substConstraintArg.val = constraintArg.val->SubstituteImpl(substSet, &diff); + + substConstraintArgs.add(substConstraintArg); + } + + if(!diff) + return this; + + (*ioDiff)++; + + RefPtr substSubst = new GlobalGenericParamSubstitution(); + substSubst->paramDecl = paramDecl; + substSubst->actualType = substActualType; + substSubst->constraintArgs = substConstraintArgs; + substSubst->outer = substOuter; + return substSubst; + } + + bool GlobalGenericParamSubstitution::Equals(Substitutions* subst) + { + if (!subst) + return false; + if (subst == this) + return true; + + if (auto genSubst = as(subst)) + { + if (paramDecl != genSubst->paramDecl) + return false; + if (!actualType->EqualsVal(genSubst->actualType)) + return false; + if (constraintArgs.getCount() != genSubst->constraintArgs.getCount()) + return false; + for (Index i = 0; i < constraintArgs.getCount(); i++) + { + if (!constraintArgs[i].val->EqualsVal(genSubst->constraintArgs[i].val)) + return false; + } + return true; + } + return false; + } + + + // DeclRefBase + + RefPtr DeclRefBase::Substitute(RefPtr type) const + { + // Note that type can be nullptr, and so this function can return nullptr (although only correctly when no substitutions) + + // No substitutions? Easy. + if (!substitutions) + return type; + + SLANG_ASSERT(type); + + // Otherwise we need to recurse on the type structure + // and apply substitutions where it makes sense + return type->Substitute(substitutions).as(); + } + + DeclRefBase DeclRefBase::Substitute(DeclRefBase declRef) const + { + if(!substitutions) + return declRef; + + int diff = 0; + return declRef.SubstituteImpl(substitutions, &diff); + } + + RefPtr DeclRefBase::Substitute(RefPtr expr) const + { + // No substitutions? Easy. + if (!substitutions) + return expr; + + SLANG_UNIMPLEMENTED_X("generic substitution into expressions"); + + UNREACHABLE_RETURN(expr); + } + + void buildMemberDictionary(ContainerDecl* decl); + + InterfaceDecl* findOuterInterfaceDecl(Decl* decl) + { + Decl* dd = decl; + while(dd) + { + if(auto interfaceDecl = as(dd)) + return interfaceDecl; + + dd = dd->ParentDecl; + } + return nullptr; + } + + RefPtr findGlobalGenericSubst( + RefPtr substs, + GlobalGenericParamDecl* paramDecl) + { + for(auto s = substs; s; s = s->outer) + { + auto gSubst = s.as(); + if(!gSubst) + continue; + + if(gSubst->paramDecl != paramDecl) + continue; + + return gSubst; + } + + return nullptr; + } + + RefPtr specializeSubstitutionsShallow( + RefPtr substToSpecialize, + RefPtr substsToApply, + RefPtr restSubst, + int* ioDiff) + { + SLANG_ASSERT(substToSpecialize); + return substToSpecialize->applySubstitutionsShallow(substsToApply, restSubst, ioDiff); + } + + RefPtr specializeGlobalGenericSubstitutions( + Decl* declToSpecialize, + RefPtr substsToSpecialize, + RefPtr substsToApply, + int* ioDiff, + HashSet& ioParametersFound) + { + // Any existing global-generic substitutions will trigger + // a recursive case that skips the rest of the function. + for(auto specSubst = substsToSpecialize; specSubst; specSubst = specSubst->outer) + { + auto specGlobalGenericSubst = specSubst.as(); + if(!specGlobalGenericSubst) + continue; + + ioParametersFound.Add(specGlobalGenericSubst->paramDecl); + + int diff = 0; + auto restSubst = specializeGlobalGenericSubstitutions( + declToSpecialize, + specSubst->outer, + substsToApply, + &diff, + ioParametersFound); + + auto firstSubst = specializeSubstitutionsShallow( + specGlobalGenericSubst, + substsToApply, + restSubst, + &diff); + + *ioDiff += diff; + return firstSubst; + } + + // No more existing substitutions, so we know we can apply + // our global generic substitutions without any special work. + + // We expect global generic substitutions to come at + // the end of the list in all cases, so lets advance + // until we see them. + RefPtr appGlobalGenericSubsts = substsToApply; + while(appGlobalGenericSubsts && !appGlobalGenericSubsts.as()) + appGlobalGenericSubsts = appGlobalGenericSubsts->outer; + + + // If there is nothing to apply, then we are done + if(!appGlobalGenericSubsts) + return nullptr; + + // Otherwise, it seems like something has to change. + (*ioDiff)++; + + // If there were no parameters bound by the existing substitution, + // then we can safely use the global generics from the to-apply set. + if(ioParametersFound.Count() == 0) + return appGlobalGenericSubsts; + + RefPtr resultSubst; + RefPtr* link = &resultSubst; + for(auto appSubst = appGlobalGenericSubsts; appSubst; appSubst = appSubst->outer) + { + auto appGlobalGenericSubst = appSubst.as(); + if(!appSubst) + continue; + + // Don't include substitutions for parameters already handled. + if(ioParametersFound.Contains(appGlobalGenericSubst->paramDecl)) + continue; + + RefPtr newSubst = new GlobalGenericParamSubstitution(); + newSubst->paramDecl = appGlobalGenericSubst->paramDecl; + newSubst->actualType = appGlobalGenericSubst->actualType; + newSubst->constraintArgs = appGlobalGenericSubst->constraintArgs; + + *link = newSubst; + link = &newSubst->outer; + } + + return resultSubst; + } + + RefPtr specializeGlobalGenericSubstitutions( + Decl* declToSpecialize, + RefPtr substsToSpecialize, + RefPtr substsToApply, + int* ioDiff) + { + // Keep track of any parameters already present in the + // existing substitution. + HashSet parametersFound; + return specializeGlobalGenericSubstitutions(declToSpecialize, substsToSpecialize, substsToApply, ioDiff, parametersFound); + } + + + // Construct new substitutions to apply to a declaration, + // based on a provided substitution set to be applied + RefPtr specializeSubstitutions( + Decl* declToSpecialize, + RefPtr substsToSpecialize, + RefPtr substsToApply, + int* ioDiff) + { + // No declaration? Then nothing to specialize. + if(!declToSpecialize) + return nullptr; + + // No (remaining) substitutions to apply? Then we are done. + if(!substsToApply) + return substsToSpecialize; + + // Walk the hierarchy of the declaration to determine what specializations might apply. + // We assume that the `substsToSpecialize` must be aligned with the ancestor + // hierarchy of `declToSpecialize` such that if, e.g., the `declToSpecialize` is + // nested directly in a generic, then `substToSpecialize` will either start with + // the corresponding `GenericSubstitution` or there will be *no* generic substitutions + // corresponding to that decl. + for(Decl* ancestorDecl = declToSpecialize; ancestorDecl; ancestorDecl = ancestorDecl->ParentDecl) + { + if(auto ancestorGenericDecl = as(ancestorDecl)) + { + // The declaration is nested inside a generic. + // Does it already have a specialization for that generic? + if(auto specGenericSubst = as(substsToSpecialize)) + { + if(specGenericSubst->genericDecl == ancestorGenericDecl) + { + // Yes. We have an existing specialization, so we will + // keep one matching it in place. + int diff = 0; + auto restSubst = specializeSubstitutions( + ancestorGenericDecl->ParentDecl, + specGenericSubst->outer, + substsToApply, + &diff); + + auto firstSubst = specializeSubstitutionsShallow( + specGenericSubst, + substsToApply, + restSubst, + &diff); + + *ioDiff += diff; + return firstSubst; + } + } + + // If the declaration is not already specialized + // for the given generic, then see if we are trying + // to *apply* such specializations to it. + // + // TODO: The way we handle things right now with + // "default" specializations, this case shouldn't + // actually come up. + // + for(auto s = substsToApply; s; s = s->outer) + { + auto appGenericSubst = as(s); + if(!appGenericSubst) + continue; + + if(appGenericSubst->genericDecl != ancestorGenericDecl) + continue; + + // The substitutions we are applying are trying + // to specialize this generic, but we don't already + // have a generic substitution in place. + // We will need to create one. + + int diff = 0; + auto restSubst = specializeSubstitutions( + ancestorGenericDecl->ParentDecl, + substsToSpecialize, + substsToApply, + &diff); + + RefPtr firstSubst = new GenericSubstitution(); + firstSubst->genericDecl = ancestorGenericDecl; + firstSubst->args = appGenericSubst->args; + firstSubst->outer = restSubst; + + (*ioDiff)++; + return firstSubst; + } + } + else if(auto ancestorInterfaceDecl = as(ancestorDecl)) + { + // The task is basically the same as for the generic case: + // We want to see if there is any existing substitution that + // applies to this declaration, and use that if possible. + + // The declaration is nested inside a generic. + // Does it already have a specialization for that generic? + if(auto specThisTypeSubst = as(substsToSpecialize)) + { + if(specThisTypeSubst->interfaceDecl == ancestorInterfaceDecl) + { + // Yes. We have an existing specialization, so we will + // keep one matching it in place. + int diff = 0; + auto restSubst = specializeSubstitutions( + ancestorInterfaceDecl->ParentDecl, + specThisTypeSubst->outer, + substsToApply, + &diff); + + auto firstSubst = specializeSubstitutionsShallow( + specThisTypeSubst, + substsToApply, + restSubst, + &diff); + + *ioDiff += diff; + return firstSubst; + } + } + + // Otherwise, check if we are trying to apply + // a this-type substitution to the given interface + // + for(auto s = substsToApply; s; s = s->outer) + { + auto appThisTypeSubst = s.as(); + if(!appThisTypeSubst) + continue; + + if(appThisTypeSubst->interfaceDecl != ancestorInterfaceDecl) + continue; + + int diff = 0; + auto restSubst = specializeSubstitutions( + ancestorInterfaceDecl->ParentDecl, + substsToSpecialize, + substsToApply, + &diff); + + RefPtr firstSubst = new ThisTypeSubstitution(); + firstSubst->interfaceDecl = ancestorInterfaceDecl; + firstSubst->witness = appThisTypeSubst->witness; + firstSubst->outer = restSubst; + + (*ioDiff)++; + return firstSubst; + } + } + } + + // If we reach here then we've walked the full hierarchy up from + // `declToSpecialize` and either didn't run into an generic/interface + // declarations, or we didn't find any attempt to specialize them + // in either substitution. + // + // As an invariant, there should *not* be any generic or this-type + // substitutions in `substToSpecialize`, because otherwise they + // would be specializations that don't actually apply to the given + // declaration. + // + // The remaining substitutions to apply, if any, should thus be + // global-generic substitutions. And similarly, those are the + // only remaining substitutions we really care about in + // `substsToApply`. + // + // Note: this does *not* mean that `substsToApply` doesn't have + // any generic or this-type substitutions; it just means that none + // of them were applicable. + // + return specializeGlobalGenericSubstitutions( + declToSpecialize, + substsToSpecialize, + substsToApply, + ioDiff); + } + + DeclRefBase DeclRefBase::SubstituteImpl(SubstitutionSet substSet, int* ioDiff) + { + // Nothing to do when we have no declaration. + if(!decl) + return *this; + + // Apply the given substitutions to any specializations + // that have already been applied to this declaration. + int diff = 0; + + auto substSubst = specializeSubstitutions( + decl, + substitutions.substitutions, + substSet.substitutions, + &diff); + + if (!diff) + return *this; + + *ioDiff += diff; + + DeclRefBase substDeclRef; + substDeclRef.decl = decl; + substDeclRef.substitutions = substSubst; + + // TODO: The old code here used to try to translate a decl-ref + // to an associated type in a decl-ref for the concrete type + // in a particular implementation. + // + // I have only kept that logic in `DeclRefType::SubstituteImpl`, + // but it may turn out it is needed here too. + + return substDeclRef; + } + + + // Check if this is an equivalent declaration reference to another + bool DeclRefBase::Equals(DeclRefBase const& declRef) const + { + if (decl != declRef.decl) + return false; + if (!substitutions.Equals(declRef.substitutions)) + return false; + + return true; + } + + // Convenience accessors for common properties of declarations + Name* DeclRefBase::GetName() const + { + return decl->nameAndLoc.name; + } + + SourceLoc DeclRefBase::getLoc() const + { + return decl->loc; + } + + DeclRefBase DeclRefBase::GetParent() const + { + // Want access to the free function (the 'as' method by default gets priority) + // Can access as method with this->as because it removes any ambiguity. + using Slang::as; + + auto parentDecl = decl->ParentDecl; + if (!parentDecl) + return DeclRefBase(); + + // Default is to apply the same set of substitutions/specializations + // to the parent declaration as were applied to the child. + RefPtr substToApply = substitutions.substitutions; + + if(auto interfaceDecl = as(decl)) + { + // The declaration being referenced is an `interface` declaration, + // and there might be a this-type substitution in place. + // A reference to the parent of the interface declaration + // should not include that substitution. + if(auto thisTypeSubst = as(substToApply)) + { + if(thisTypeSubst->interfaceDecl == interfaceDecl) + { + // Strip away that specializations that apply to the interface. + substToApply = thisTypeSubst->outer; + } + } + } + + if (auto parentGenericDecl = as(parentDecl)) + { + // The parent of this declaration is a generic, which means + // that the decl-ref to the current declaration might include + // substitutions that specialize the generic parameters. + // A decl-ref to the parent generic should *not* include + // those substitutions. + // + if(auto genericSubst = as(substToApply)) + { + if(genericSubst->genericDecl == parentGenericDecl) + { + // Strip away the specializations that were applied to the parent. + substToApply = genericSubst->outer; + } + } + } + + return DeclRefBase(parentDecl, substToApply); + } + + int DeclRefBase::GetHashCode() const + { + return combineHash(PointerHash<1>::GetHashCode(decl), substitutions.GetHashCode()); + } + + // Val + + RefPtr Val::Substitute(SubstitutionSet subst) + { + if (!subst) return this; + int diff = 0; + return SubstituteImpl(subst, &diff); + } + + RefPtr Val::SubstituteImpl(SubstitutionSet /*subst*/, int* /*ioDiff*/) + { + // Default behavior is to not substitute at all + return this; + } + + // IntVal + + IntegerLiteralValue GetIntVal(RefPtr val) + { + if (auto constantVal = as(val)) + { + return constantVal->value; + } + SLANG_UNEXPECTED("needed a known integer value"); + return 0; + } + + // ConstantIntVal + + bool ConstantIntVal::EqualsVal(Val* val) + { + if (auto intVal = as(val)) + return value == intVal->value; + return false; + } + + String ConstantIntVal::ToString() + { + return String(value); + } + + int ConstantIntVal::GetHashCode() + { + return (int) value; + } + + // + + void registerBuiltinDecl( + Session* session, + RefPtr decl, + RefPtr modifier) + { + auto type = DeclRefType::Create( + session, + DeclRef(decl.Ptr(), nullptr)); + session->builtinTypes[(int)modifier->tag] = type; + } + + void registerMagicDecl( + Session* session, + RefPtr decl, + RefPtr modifier) + { + session->magicDecls[modifier->name] = decl.Ptr(); + } + + RefPtr findMagicDecl( + Session* session, + String const& name) + { + return session->magicDecls[name].GetValue(); + } + + // + + SyntaxNodeBase* createInstanceOfSyntaxClassByName( + String const& name) + { + if(0) {} + #define CASE(NAME) \ + else if(name == #NAME) return new NAME() + + CASE(GLSLBufferModifier); + CASE(GLSLWriteOnlyModifier); + CASE(GLSLReadOnlyModifier); + CASE(GLSLPatchModifier); + CASE(SimpleModifier); + + #undef CASE + else + { + SLANG_UNEXPECTED("unhandled syntax class name"); + UNREACHABLE_RETURN(nullptr); + } + } + + // + + // HLSLPatchType + + Type* HLSLPatchType::getElementType() + { + return as(findInnerMostGenericSubstitution(declRef.substitutions)->args[0]); + } + + IntVal* HLSLPatchType::getElementCount() + { + return as(findInnerMostGenericSubstitution(declRef.substitutions)->args[1]); + } + + // Constructors for types + + RefPtr getArrayType( + Type* elementType, + IntVal* elementCount) + { + auto session = elementType->getSession(); + auto arrayType = new ArrayExpressionType(); + arrayType->setSession(session); + arrayType->baseType = elementType; + arrayType->ArrayLength = elementCount; + return arrayType; + } + + RefPtr getArrayType( + Type* elementType) + { + auto session = elementType->getSession(); + auto arrayType = new ArrayExpressionType(); + arrayType->setSession(session); + arrayType->baseType = elementType; + return arrayType; + } + + RefPtr getNamedType( + Session* session, + DeclRef const& declRef) + { + DeclRef specializedDeclRef = createDefaultSubstitutionsIfNeeded(session, declRef).as(); + + auto namedType = new NamedExpressionType(specializedDeclRef); + namedType->setSession(session); + return namedType; + } + + RefPtr getTypeType( + Type* type) + { + auto session = type->getSession(); + auto typeType = new TypeType(type); + typeType->setSession(session); + return typeType; + } + + RefPtr getFuncType( + Session* session, + DeclRef const& declRef) + { + RefPtr funcType = new FuncType(); + funcType->setSession(session); + + funcType->resultType = GetResultType(declRef); + for (auto paramDeclRef : GetParameters(declRef)) + { + auto paramDecl = paramDeclRef.getDecl(); + auto paramType = GetType(paramDeclRef); + if( paramDecl->FindModifier() ) + { + paramType = session->getRefType(paramType); + } + else if( paramDecl->FindModifier() ) + { + if(paramDecl->FindModifier() || paramDecl->FindModifier()) + { + paramType = session->getInOutType(paramType); + } + else + { + paramType = session->getOutType(paramType); + } + } + funcType->paramTypes.add(paramType); + } + + return funcType; + } + + RefPtr getGenericDeclRefType( + Session* session, + DeclRef const& declRef) + { + auto genericDeclRefType = new GenericDeclRefType(declRef); + genericDeclRefType->setSession(session); + return genericDeclRefType; + } + + RefPtr getSamplerStateType( + Session* session) + { + auto samplerStateType = new SamplerStateType(); + samplerStateType->setSession(session); + return samplerStateType; + } + + // TODO: should really have a `type.cpp` and a `witness.cpp` + + bool TypeEqualityWitness::EqualsVal(Val* val) + { + auto otherWitness = as(val); + if (!otherWitness) + return false; + return sub->Equals(otherWitness->sub); + } + + RefPtr TypeEqualityWitness::SubstituteImpl(SubstitutionSet subst, int * ioDiff) + { + RefPtr rs = new TypeEqualityWitness(); + rs->sub = sub->SubstituteImpl(subst, ioDiff).as(); + rs->sup = sup->SubstituteImpl(subst, ioDiff).as(); + return rs; + } + + String TypeEqualityWitness::ToString() + { + return "TypeEqualityWitness(" + sub->ToString() + ")"; + } + + int TypeEqualityWitness::GetHashCode() + { + return sub->GetHashCode(); + } + + bool DeclaredSubtypeWitness::EqualsVal(Val* val) + { + auto otherWitness = as(val); + if(!otherWitness) + return false; + + return sub->Equals(otherWitness->sub) + && sup->Equals(otherWitness->sup) + && declRef.Equals(otherWitness->declRef); + } + + RefPtr findThisTypeSubstitution( + Substitutions* substs, + InterfaceDecl* interfaceDecl) + { + for(RefPtr s = substs; s; s = s->outer) + { + auto thisTypeSubst = as(s); + if(!thisTypeSubst) + continue; + + if(thisTypeSubst->interfaceDecl != interfaceDecl) + continue; + + return thisTypeSubst; + } + + return nullptr; + } + + RefPtr DeclaredSubtypeWitness::SubstituteImpl(SubstitutionSet subst, int * ioDiff) + { + if (auto genConstraintDeclRef = declRef.as()) + { + auto genConstraintDecl = genConstraintDeclRef.getDecl(); + + // search for a substitution that might apply to us + for(auto s = subst.substitutions; s; s = s->outer) + { + if(auto genericSubst = as(s)) + { + // the generic decl associated with the substitution list must be + // the generic decl that declared this parameter + auto genericDecl = genericSubst->genericDecl; + if (genericDecl != genConstraintDecl->ParentDecl) + continue; + + bool found = false; + Index index = 0; + for (auto m : genericDecl->Members) + { + if (auto constraintParam = as(m)) + { + if (constraintParam == declRef.getDecl()) + { + found = true; + break; + } + index++; + } + } + if (found) + { + (*ioDiff)++; + auto ordinaryParamCount = genericDecl->getMembersOfType().getCount() + + genericDecl->getMembersOfType().getCount(); + SLANG_ASSERT(index + ordinaryParamCount < genericSubst->args.getCount()); + return genericSubst->args[index + ordinaryParamCount]; + } + } + else if(auto globalGenericSubst = s.as()) + { + // check if the substitution is really about this global generic type parameter + if (globalGenericSubst->paramDecl != genConstraintDecl->ParentDecl) + continue; + + for(auto constraintArg : globalGenericSubst->constraintArgs) + { + if(constraintArg.decl.Ptr() != genConstraintDecl) + continue; + + (*ioDiff)++; + return constraintArg.val; + } + } + } + } + + // Perform substitution on the constituent elements. + int diff = 0; + auto substSub = sub->SubstituteImpl(subst, &diff).as(); + auto substSup = sup->SubstituteImpl(subst, &diff).as(); + auto substDeclRef = declRef.SubstituteImpl(subst, &diff); + if (!diff) + return this; + + (*ioDiff)++; + + // If we have a reference to a type constraint for an + // associated type declaration, then we can replace it + // with the concrete conformance witness for a concrete + // type implementing the outer interface. + // + // TODO: It is a bit gross that we use `GenericTypeConstraintDecl` for + // associated types, when they aren't really generic type *parameters*, + // so we'll need to change this location in the code if we ever clean + // up the hierarchy. + // + if (auto substTypeConstraintDecl = as(substDeclRef.decl)) + { + if (auto substAssocTypeDecl = as(substTypeConstraintDecl->ParentDecl)) + { + if (auto interfaceDecl = as(substAssocTypeDecl->ParentDecl)) + { + // At this point we have a constraint decl for an associated type, + // and we nee to see if we are dealing with a concrete substitution + // for the interface around that associated type. + if(auto thisTypeSubst = findThisTypeSubstitution(substDeclRef.substitutions, interfaceDecl)) + { + // We need to look up the declaration that satisfies + // the requirement named by the associated type. + Decl* requirementKey = substTypeConstraintDecl; + RequirementWitness requirementWitness = tryLookUpRequirementWitness(thisTypeSubst->witness, requirementKey); + switch(requirementWitness.getFlavor()) + { + default: + break; + + case RequirementWitness::Flavor::val: + { + auto satisfyingVal = requirementWitness.getVal(); + return satisfyingVal; + } + } + } + } + } + } + + + + + RefPtr rs = new DeclaredSubtypeWitness(); + rs->sub = substSub; + rs->sup = substSup; + rs->declRef = substDeclRef; + return rs; + } + + String DeclaredSubtypeWitness::ToString() + { + StringBuilder sb; + sb << "DeclaredSubtypeWitness("; + sb << this->sub->ToString(); + sb << ", "; + sb << this->sup->ToString(); + sb << ", "; + sb << this->declRef.toString(); + sb << ")"; + return sb.ProduceString(); + } + + int DeclaredSubtypeWitness::GetHashCode() + { + return declRef.GetHashCode(); + } + + // TransitiveSubtypeWitness + + bool TransitiveSubtypeWitness::EqualsVal(Val* val) + { + auto otherWitness = as(val); + if(!otherWitness) + return false; + + return sub->Equals(otherWitness->sub) + && sup->Equals(otherWitness->sup) + && subToMid->EqualsVal(otherWitness->subToMid) + && midToSup.Equals(otherWitness->midToSup); + } + + RefPtr TransitiveSubtypeWitness::SubstituteImpl(SubstitutionSet subst, int * ioDiff) + { + int diff = 0; + + RefPtr substSub = sub->SubstituteImpl(subst, &diff).as(); + RefPtr substSup = sup->SubstituteImpl(subst, &diff).as(); + RefPtr substSubToMid = subToMid->SubstituteImpl(subst, &diff).as(); + DeclRef substMidToSup = midToSup.SubstituteImpl(subst, &diff); + + // If nothing changed, then we can bail out early. + if (!diff) + return this; + + // Something changes, so let the caller know. + (*ioDiff)++; + + // TODO: are there cases where we can simplify? + // + // In principle, if either `subToMid` or `midToSub` turns into + // a reflexive subtype witness, then we could drop that side, + // and just return the other one (this would imply that `sub == mid` + // or `mid == sup` after substitutions). + // + // In the long run, is it also possible that if `sub` gets resolved + // to a concrete type *and* we decide to flatten out the inheritance + // graph into a linearized "class precedence list" stored in any + // aggregate type, then we could potentially just redirect to point + // to the appropriate inheritance decl in the original type. + // + // For now I'm going to ignore those possibilities and hope for the best. + + // In the simple case, we just construct a new transitive subtype + // witness, and we move on with life. + RefPtr result = new TransitiveSubtypeWitness(); + result->sub = substSub; + result->sup = substSup; + result->subToMid = substSubToMid; + result->midToSup = substMidToSup; + return result; + } + + String TransitiveSubtypeWitness::ToString() + { + // Note: we only print the constituent + // witnesses, and rely on them to print + // the starting and ending types. + StringBuilder sb; + sb << "TransitiveSubtypeWitness("; + sb << this->subToMid->ToString(); + sb << ", "; + sb << this->midToSup.toString(); + sb << ")"; + return sb.ProduceString(); + } + + int TransitiveSubtypeWitness::GetHashCode() + { + auto hash = sub->GetHashCode(); + hash = combineHash(hash, sup->GetHashCode()); + hash = combineHash(hash, subToMid->GetHashCode()); + hash = combineHash(hash, midToSup.GetHashCode()); + return hash; + } + + // + + String DeclRefBase::toString() const + { + if (!decl) return ""; + + auto name = decl->getName(); + if (!name) return ""; + + // TODO: need to print out substitutions too! + return name->text; + } + + bool SubstitutionSet::Equals(const SubstitutionSet& substSet) const + { + if (substitutions == substSet.substitutions) + { + return true; + } + if (substitutions == nullptr || substSet.substitutions == nullptr) + { + return false; + } + return substitutions->Equals(substSet.substitutions); + } + + int SubstitutionSet::GetHashCode() const + { + int rs = 0; + if (substitutions) + rs = combineHash(rs, substitutions->GetHashCode()); + return rs; + } + + // ExtractExistentialType + + String ExtractExistentialType::ToString() + { + String result; + result.append(declRef.toString()); + result.append(".This"); + return result; + } + + bool ExtractExistentialType::EqualsImpl(Type* type) + { + if( auto extractExistential = as(type) ) + { + return declRef.Equals(extractExistential->declRef); + } + return false; + } + + int ExtractExistentialType::GetHashCode() + { + return declRef.GetHashCode(); + } + + RefPtr ExtractExistentialType::CreateCanonicalType() + { + return this; + } + + RefPtr ExtractExistentialType::SubstituteImpl(SubstitutionSet subst, int* ioDiff) + { + int diff = 0; + auto substDeclRef = declRef.SubstituteImpl(subst, &diff); + if(!diff) + return this; + + (*ioDiff)++; + + RefPtr substValue = new ExtractExistentialType(); + substValue->declRef = declRef; + return substValue; + } + + // ExtractExistentialSubtypeWitness + + bool ExtractExistentialSubtypeWitness::EqualsVal(Val* val) + { + if( auto extractWitness = as(val) ) + { + return declRef.Equals(extractWitness->declRef); + } + return false; + } + + String ExtractExistentialSubtypeWitness::ToString() + { + String result; + result.append("extractExistentialValue("); + result.append(declRef.toString()); + result.append(")"); + return result; + } + + int ExtractExistentialSubtypeWitness::GetHashCode() + { + return declRef.GetHashCode(); + } + + RefPtr ExtractExistentialSubtypeWitness::SubstituteImpl(SubstitutionSet subst, int* ioDiff) + { + int diff = 0; + + auto substDeclRef = declRef.SubstituteImpl(subst, &diff); + auto substSub = sub->SubstituteImpl(subst, &diff).as(); + auto substSup = sup->SubstituteImpl(subst, &diff).as(); + + if(!diff) + return this; + + (*ioDiff)++; + + RefPtr substValue = new ExtractExistentialSubtypeWitness(); + substValue->declRef = declRef; + substValue->sub = substSub; + substValue->sup = substSup; + return substValue; + } + + // + // TaggedUnionType + // + + String TaggedUnionType::ToString() + { + String result; + result.append("__TaggedUnion("); + bool first = true; + for( auto caseType : caseTypes ) + { + if(!first) result.append(", "); + first = false; + + result.append(caseType->ToString()); + } + result.append(")"); + return result; + } + + bool TaggedUnionType::EqualsImpl(Type* type) + { + auto taggedUnion = as(type); + if(!taggedUnion) + return false; + + auto caseCount = caseTypes.getCount(); + if(caseCount != taggedUnion->caseTypes.getCount()) + return false; + + for( Index ii = 0; ii < caseCount; ++ii ) + { + if(!caseTypes[ii]->Equals(taggedUnion->caseTypes[ii])) + return false; + } + return true; + } + + int TaggedUnionType::GetHashCode() + { + int hashCode = 0; + for( auto caseType : caseTypes ) + { + hashCode = combineHash(hashCode, caseType->GetHashCode()); + } + return hashCode; + } + + RefPtr TaggedUnionType::CreateCanonicalType() + { + RefPtr canType = new TaggedUnionType(); + canType->setSession(getSession()); + + for( auto caseType : caseTypes ) + { + auto canCaseType = caseType->GetCanonicalType(); + canType->caseTypes.add(canCaseType); + } + + return canType; + } + + RefPtr TaggedUnionType::SubstituteImpl(SubstitutionSet subst, int* ioDiff) + { + int diff = 0; + + List> substCaseTypes; + for( auto caseType : caseTypes ) + { + substCaseTypes.add(caseType->SubstituteImpl(subst, &diff).as()); + } + if(!diff) + return this; + + (*ioDiff)++; + + RefPtr substType = new TaggedUnionType(); + substType->setSession(getSession()); + substType->caseTypes.swapWith(substCaseTypes); + return substType; + } + +// +// TaggedUnionSubtypeWitness +// + + +bool TaggedUnionSubtypeWitness::EqualsVal(Val* val) +{ + auto taggedUnionWitness = as(val); + if(!taggedUnionWitness) + return false; + + auto caseCount = caseWitnesses.getCount(); + if(caseCount != taggedUnionWitness->caseWitnesses.getCount()) + return false; + + for(Index ii = 0; ii < caseCount; ++ii) + { + if(!caseWitnesses[ii]->EqualsVal(taggedUnionWitness->caseWitnesses[ii])) + return false; + } + + return true; +} + +String TaggedUnionSubtypeWitness::ToString() +{ + String result; + result.append("TaggedUnionSubtypeWitness("); + bool first = true; + for( auto caseWitness : caseWitnesses ) + { + if(!first) result.append(", "); + first = false; + + result.append(caseWitness->ToString()); + } + return result; +} + +int TaggedUnionSubtypeWitness::GetHashCode() +{ + int hash = 0; + for( auto caseWitness : caseWitnesses ) + { + hash = combineHash(hash, caseWitness->GetHashCode()); + } + return hash; +} + +RefPtr TaggedUnionSubtypeWitness::SubstituteImpl(SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + + auto substSub = sub->SubstituteImpl(subst, &diff).as(); + auto substSup = sup->SubstituteImpl(subst, &diff).as(); + + List> substCaseWitnesses; + for( auto caseWitness : caseWitnesses ) + { + substCaseWitnesses.add(caseWitness->SubstituteImpl(subst, &diff)); + } + + if(!diff) + return this; + + (*ioDiff)++; + + RefPtr substWitness = new TaggedUnionSubtypeWitness(); + substWitness->sub = substSub; + substWitness->sup = substSup; + substWitness->caseWitnesses.swapWith(substCaseWitnesses); + return substWitness; +} + +Module* getModule(Decl* decl) +{ + for( auto dd = decl; dd; dd = dd->ParentDecl ) + { + if(auto moduleDecl = as(dd)) + return moduleDecl->module; + } + return nullptr; +} + +bool findImageFormatByName(char const* name, ImageFormat* outFormat) +{ + static const struct + { + char const* name; + ImageFormat format; + } kFormats[] = + { +#define FORMAT(NAME) { #NAME, ImageFormat::NAME }, +#include "slang-image-format-defs.h" + }; + + for( auto item : kFormats ) + { + if( strcmp(item.name, name) == 0 ) + { + *outFormat = item.format; + return true; + } + } + + return false; +} + +char const* getGLSLNameForImageFormat(ImageFormat format) +{ + switch( format ) + { + default: return "unhandled"; +#define FORMAT(NAME) case ImageFormat::NAME: return #NAME; +#include "slang-image-format-defs.h" + } +} + +// +// ExistentialSpecializedType +// + +String ExistentialSpecializedType::ToString() +{ + String result; + result.append("__ExistentialSpecializedType("); + result.append(baseType->ToString()); + for( auto arg : slots.args ) + { + result.append(", "); + result.append(arg.type->ToString()); + } + result.append(")"); + return result; +} + +bool ExistentialSpecializedType::EqualsImpl(Type * type) +{ + auto other = as(type); + if(!other) + return false; + + if(!baseType->Equals(other->baseType)) + return false; + + auto argCount = slots.args.getCount(); + if(argCount != other->slots.args.getCount()) + return false; + + for( Index ii = 0; ii < argCount; ++ii ) + { + if(!slots.args[ii].type->Equals(other->slots.args[ii].type)) + return false; + + if(!slots.args[ii].witness->EqualsVal(other->slots.args[ii].witness)) + return false; + } + return true; +} + +int ExistentialSpecializedType::GetHashCode() +{ + Hasher hasher; + hasher.hashObject(baseType); + for(auto arg : slots.args) + { + hasher.hashObject(arg.type); + hasher.hashObject(arg.witness); + } + return hasher.getResult(); +} + +RefPtr ExistentialSpecializedType::CreateCanonicalType() +{ + RefPtr canType = new ExistentialSpecializedType(); + canType->setSession(getSession()); + + canType->baseType = baseType->GetCanonicalType(); + for( auto paramType : slots.paramTypes ) + { + canType->slots.paramTypes.add( paramType->GetCanonicalType() ); + } + for( auto arg : slots.args ) + { + ExistentialTypeSlots::Arg canArg; + canArg.type = arg.type->GetCanonicalType(); + canArg.witness = arg.witness; + canType->slots.args.add(canArg); + } + return canType; +} + +RefPtr ExistentialSpecializedType::SubstituteImpl(SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + + auto substBaseType = baseType->SubstituteImpl(subst, &diff).as(); + + ExistentialTypeSlots substSlots; + for( auto paramType : slots.paramTypes ) + { + substSlots.paramTypes.add( paramType->SubstituteImpl(subst, &diff).as() ); + } + for( auto arg : slots.args ) + { + ExistentialTypeSlots::Arg substArg; + substArg.type = arg.type->SubstituteImpl(subst, &diff).as(); + substArg.witness = arg.witness->SubstituteImpl(subst, &diff); + substSlots.args.add(substArg); + } + + if(!diff) + return this; + + (*ioDiff)++; + + RefPtr substType = new ExistentialSpecializedType(); + substType->setSession(getSession()); + substType->baseType = substBaseType; + substType->slots = substSlots; + return substType; +} + +} // namespace Slang diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h new file mode 100644 index 000000000..049220ef9 --- /dev/null +++ b/source/slang/slang-syntax.h @@ -0,0 +1,1419 @@ +#ifndef SLANG_SYNTAX_H +#define SLANG_SYNTAX_H + +#include "../core/slang-basic.h" +#include "slang-ir.h" +#include "slang-lexer.h" +#include "slang-profile.h" +#include "slang-type-system-shared.h" +#include "../../slang.h" + +#include + +namespace Slang +{ + class Module; + class Name; + class Session; + class Substitutions; + class SyntaxVisitor; + class FuncDecl; + class Layout; + + struct IExprVisitor; + struct IDeclVisitor; + struct IModifierVisitor; + struct IStmtVisitor; + struct ITypeVisitor; + struct IValVisitor; + + class Parser; + class SyntaxNode; + + typedef RefPtr (*SyntaxParseCallback)(Parser* parser, void* userData); + + typedef unsigned int ConversionCost; + enum : ConversionCost + { + // No conversion at all + kConversionCost_None = 0, + + // Conversion from a buffer to the type it carries needs to add a minimal + // extra cost, just so we can distinguish an overload on `ConstantBuffer` + // from one on `Foo` + kConversionCost_ImplicitDereference = 10, + + // Conversions based on explicit sub-typing relationships are the cheapest + // + // TODO(tfoley): We will eventually need a discipline for ranking + // when two up-casts are comparable. + kConversionCost_CastToInterface = 50, + + // Conversion that is lossless and keeps the "kind" of the value the same + kConversionCost_RankPromotion = 150, + + // Conversions that are lossless, but change "kind" + kConversionCost_UnsignedToSignedPromotion = 200, + + // Conversion from signed->unsigned integer of same or greater size + kConversionCost_SignedToUnsignedConversion = 300, + + // Cost of converting an integer to a floating-point type + kConversionCost_IntegerToFloatConversion = 400, + + // Default case (usable for user-defined conversions) + kConversionCost_Default = 500, + + // Catch-all for conversions that should be discouraged + // (i.e., that really shouldn't be made implicitly) + // + // TODO: make these conversions not be allowed implicitly in "Slang mode" + kConversionCost_GeneralConversion = 900, + + // This is the cost of an explicit conversion, which should + // not actually be performed. + kConversionCost_Explicit = 90000, + + // Additional conversion cost to add when promoting from a scalar to + // a vector (this will be added to the cost, if any, of converting + // the element type of the vector) + kConversionCost_ScalarToVector = 1, + + // Conversion is impossible + kConversionCost_Impossible = 0xFFFFFFFF, + }; + + enum class ImageFormat + { +#define FORMAT(NAME) NAME, +#include "slang-image-format-defs.h" + }; + + bool findImageFormatByName(char const* name, ImageFormat* outFormat); + char const* getGLSLNameForImageFormat(ImageFormat format); + + // TODO(tfoley): We should ditch this enumeration + // and just use the IR opcodes that represent these + // types directly. The one major complication there + // is that the order of the enum values currently + // matters, since it determines promotion rank. + // We either need to keep that restriction, or + // look up promotion rank by some other means. + // + + class Decl; + class Val; + + // Forward-declare all syntax classes +#define SYNTAX_CLASS(NAME, BASE, ...) class NAME; +#include "slang-object-meta-begin.h" +#include "slang-syntax-defs.h" +#include "slang-object-meta-end.h" + + // Helper type for pairing up a name and the location where it appeared + struct NameLoc + { + Name* name; + SourceLoc loc; + + NameLoc() + : name(nullptr) + {} + + explicit NameLoc(Name* name) + : name(name) + {} + + + NameLoc(Name* name, SourceLoc loc) + : name(name) + , loc(loc) + {} + + NameLoc(Token const& token) + : name(token.getNameOrNull()) + , loc(token.getLoc()) + {} + }; + + // Helper class for iterating over a list of heap-allocated modifiers + struct ModifierList + { + struct Iterator + { + Modifier* current; + + Modifier* operator*() + { + return current; + } + + void operator++(); +#if 0 + { + current = current->next.Ptr(); + } +#endif + + bool operator!=(Iterator other) + { + return current != other.current; + }; + + Iterator() + : current(nullptr) + {} + + Iterator(Modifier* modifier) + : current(modifier) + {} + }; + + ModifierList() + : modifiers(nullptr) + {} + + ModifierList(Modifier* modifiers) + : modifiers(modifiers) + {} + + Iterator begin() { return Iterator(modifiers); } + Iterator end() { return Iterator(nullptr); } + + Modifier* modifiers; + }; + + // Helper class for iterating over heap-allocated modifiers + // of a specific type. + template + struct FilteredModifierList + { + struct Iterator + { + Modifier* current; + + T* operator*() + { + return (T*)current; + } + + void operator++(); + #if 0 + { + current = Adjust(current->next.Ptr()); + } + #endif + + bool operator!=(Iterator other) + { + return current != other.current; + }; + + Iterator() + : current(nullptr) + {} + + Iterator(Modifier* modifier) + : current(modifier) + {} + }; + + FilteredModifierList() + : modifiers(nullptr) + {} + + FilteredModifierList(Modifier* modifiers) + : modifiers(Adjust(modifiers)) + {} + + Iterator begin() { return Iterator(modifiers); } + Iterator end() { return Iterator(nullptr); } + + static Modifier* Adjust(Modifier* modifier); + #if 0 + { + Modifier* m = modifier; + for (;;) + { + if (!m) return m; + if (dynamicCast(m)) return m; + m = m->next.Ptr(); + } + } + #endif + + Modifier* modifiers; + }; + + // A set of modifiers attached to a syntax node + struct Modifiers + { + // The first modifier in the linked list of heap-allocated modifiers + RefPtr first; + + template + FilteredModifierList getModifiersOfType() { return FilteredModifierList(first.Ptr()); } + + // Find the first modifier of a given type, or return `nullptr` if none is found. + template + T* findModifier() + { + return *getModifiersOfType().begin(); + } + + template + bool hasModifier() { return findModifier() != nullptr; } + + FilteredModifierList::Iterator begin() { return FilteredModifierList::Iterator(first.Ptr()); } + FilteredModifierList::Iterator end() { return FilteredModifierList::Iterator(nullptr); } + }; + + class NamedExpressionType; + class GenericDecl; + class ContainerDecl; + + // Try to extract a simple integer value from an `IntVal`. + // This fill assert-fail if the object doesn't represent a literal value. + IntegerLiteralValue GetIntVal(RefPtr val); + + // Represents how much checking has been applied to a declaration. + enum class DeclCheckState : uint8_t + { + // The declaration has been parsed, but not checked + Unchecked, + + // We are in the process of checking the declaration "header" + // (those parts of the declaration needed in order to + // reference it) + CheckingHeader, + + // We are done checking the declaration header. + CheckedHeader, + + // We have checked the declaration fully. + Checked, + }; + + void addModifier( + RefPtr syntax, + RefPtr modifier); + + struct QualType + { + RefPtr type; + bool IsLeftValue; + + QualType() + : IsLeftValue(false) + {} + + QualType(Type* type) + : type(type) + , IsLeftValue(false) + {} + + Type* Ptr() { return type.Ptr(); } + + operator Type*() { return type; } + operator RefPtr() { return type; } + RefPtr operator->() { return type; } + }; + + // A reference to a class of syntax node, that can be + // used to create instances on the fly + struct SyntaxClassBase + { + typedef void* (*CreateFunc)(); + + // Run-time type representation for syntax nodes + struct ClassInfo + { + // Textual class name, for debugging + char const* name; + + // Base class for runtime queries + ClassInfo const* baseClass; + + // Callback to use when creating instances + CreateFunc createFunc; + }; + + SyntaxClassBase() + {} + + SyntaxClassBase(ClassInfo const* classInfoIn) + : classInfo(classInfoIn) + {} + + void* createInstanceImpl() const + { + auto ci = classInfo; + if (!ci) return nullptr; + + auto cf = ci->createFunc; + if (!cf) return nullptr; + + return cf(); + } + + bool isSubClassOfImpl(SyntaxClassBase const& super) const; + + ClassInfo const* classInfo = nullptr; + + template + struct Impl + { + static void* createFunc(); + static const ClassInfo kClassInfo; + }; + }; + + template + struct SyntaxClass : SyntaxClassBase + { + SyntaxClass() + {} + + template + SyntaxClass(SyntaxClass const& other, + typename EnableIf::Value, void>::type* = 0) + : SyntaxClassBase(other.classInfo) + { + } + + T* createInstance() const + { + return (T*)createInstanceImpl(); + } + + SyntaxClass(const ClassInfo* classInfoIn): + SyntaxClassBase(classInfoIn) + {} + + static SyntaxClass getClass() + { + return SyntaxClass(&SyntaxClassBase::Impl::kClassInfo); + } + + template + bool isSubClassOf(SyntaxClass super) + { + return isSubClassOfImpl(super); + } + + template + bool isSubClassOf() + { + return isSubClassOf(SyntaxClass::getClass()); + } + }; + + template + SyntaxClass getClass() + { + return SyntaxClass::getClass(); + } + + struct SubstitutionSet + { + RefPtr substitutions; + operator Substitutions*() const + { + return substitutions; + } + + SubstitutionSet() {} + SubstitutionSet(RefPtr subst) + : substitutions(subst) + { + } + bool Equals(const SubstitutionSet& substSet) const; + int GetHashCode() const; + }; + + template + struct DeclRef; + + // A reference to a declaration, which may include + // substitutions for generic parameters. + struct DeclRefBase + { + typedef Decl DeclType; + + // The underlying declaration + Decl* decl = nullptr; + Decl* getDecl() const { return decl; } + + // Optionally, a chain of substitutions to perform + SubstitutionSet substitutions; + + DeclRefBase() + {} + + DeclRefBase(Decl* decl) + :decl(decl) + {} + + DeclRefBase(Decl* decl, SubstitutionSet subst) + :decl(decl), + substitutions(subst) + {} + + DeclRefBase(Decl* decl, RefPtr subst) + : decl(decl) + , substitutions(subst) + {} + + // Apply substitutions to a type or declaration + RefPtr Substitute(RefPtr type) const; + + DeclRefBase Substitute(DeclRefBase declRef) const; + + // Apply substitutions to an expression + RefPtr Substitute(RefPtr expr) const; + + // Apply substitutions to this declaration reference + DeclRefBase SubstituteImpl(SubstitutionSet subst, int* ioDiff); + + // Returns true if 'as' will return a valid cast + template + bool is() const { return Slang::as(decl) != nullptr; } + + // "dynamic cast" to a more specific declaration reference type + template + DeclRef as() const; + + // Check if this is an equivalent declaration reference to another + bool Equals(DeclRefBase const& declRef) const; + bool operator == (const DeclRefBase& other) const + { + return Equals(other); + } + + // Convenience accessors for common properties of declarations + Name* GetName() const; + SourceLoc getLoc() const; + DeclRefBase GetParent() const; + + int GetHashCode() const; + + // Debugging: + String toString() const; + }; + + template + struct DeclRef : DeclRefBase + { + typedef T DeclType; + + DeclRef() + {} + + DeclRef(T* decl, SubstitutionSet subst) + : DeclRefBase(decl, subst) + {} + + DeclRef(T* decl, RefPtr subst) + : DeclRefBase(decl, SubstitutionSet(subst)) + {} + + template + DeclRef(DeclRef const& other, + typename EnableIf::Value, void>::type* = 0) + : DeclRefBase(other.decl, other.substitutions) + { + } + + T* getDecl() const + { + return (T*)decl; + } + + operator T*() const + { + return getDecl(); + } + + // + static DeclRef unsafeInit(DeclRefBase const& declRef) + { + return DeclRef((T*) declRef.decl, declRef.substitutions); + } + + RefPtr Substitute(RefPtr type) const + { + return DeclRefBase::Substitute(type); + } + RefPtr Substitute(RefPtr expr) const + { + return DeclRefBase::Substitute(expr); + } + + // Apply substitutions to a type or declaration + template + DeclRef Substitute(DeclRef declRef) const + { + return DeclRef::unsafeInit(DeclRefBase::Substitute(declRef)); + } + + // Apply substitutions to this declaration reference + DeclRef SubstituteImpl(SubstitutionSet subst, int* ioDiff) + { + return DeclRef::unsafeInit(DeclRefBase::SubstituteImpl(subst, ioDiff)); + } + + DeclRef GetParent() const + { + return DeclRef::unsafeInit(DeclRefBase::GetParent()); + } + }; + + template + DeclRef DeclRefBase::as() const + { + DeclRef result; + result.decl = Slang::as(decl); + result.substitutions = substitutions; + return result; + } + + template + inline DeclRef makeDeclRef(T* decl) + { + return DeclRef(decl, nullptr); + } + + template + struct FilteredMemberList + { + typedef RefPtr Element; + + FilteredMemberList() + : m_begin(nullptr) + , m_end(nullptr) + {} + + explicit FilteredMemberList( + List const& list) + : m_begin(adjust(list.begin(), list.end())) + , m_end(list.end()) + {} + + struct Iterator + { + Element* m_cursor; + Element* m_end; + + bool operator!=(Iterator const& other) + { + return m_cursor != other.m_cursor; + } + + void operator++() + { + m_cursor = adjust(m_cursor + 1, m_end); + } + + RefPtr& operator*() + { + return *(RefPtr*)m_cursor; + } + }; + + Iterator begin() + { + Iterator iter = { m_begin, m_end }; + return iter; + } + + Iterator end() + { + Iterator iter = { m_end, m_end }; + return iter; + } + + static Element* adjust(Element* cursor, Element* end) + { + while (cursor != end) + { + if (as(*cursor)) + return cursor; + cursor++; + } + return cursor; + } + + // TODO(tfoley): It is ugly to have these. + // We should probably fix the call sites instead. + RefPtr& getFirst() { return *begin(); } + Index getCount() + { + Index count = 0; + for (auto iter : (*this)) + { + (void)iter; + count++; + } + return count; + } + + List> toArray() + { + List> result; + for (auto element : (*this)) + { + result.add(element); + } + return result; + } + + Element* m_begin; + Element* m_end; + }; + + struct TransparentMemberInfo + { + // The declaration of the transparent member + Decl* decl; + }; + + template + struct FilteredMemberRefList + { + List> const& decls; + SubstitutionSet substitutions; + + FilteredMemberRefList( + List> const& decls, + SubstitutionSet substitutions) + : decls(decls) + , substitutions(substitutions) + {} + + int Count() const + { + int count = 0; + for (auto d : *this) + count++; + return count; + } + + List> ToArray() const + { + List> result; + for (auto d : *this) + result.add(d); + return result; + } + + struct Iterator + { + FilteredMemberRefList const* list; + RefPtr* ptr; + RefPtr* end; + + Iterator() : list(nullptr), ptr(nullptr) {} + Iterator( + FilteredMemberRefList const* list, + RefPtr* ptr, + RefPtr* end) + : list(list) + , ptr(ptr) + , end(end) + {} + + bool operator!=(Iterator other) + { + return ptr != other.ptr; + } + + void operator++() + { + ptr = list->Adjust(ptr + 1, end); + } + + DeclRef operator*() + { + return DeclRef((T*) ptr->Ptr(), list->substitutions); + } + }; + + Iterator begin() const { return Iterator(this, Adjust(decls.begin(), decls.end()), decls.end()); } + Iterator end() const { return Iterator(this, decls.end(), decls.end()); } + + RefPtr* Adjust(RefPtr* ptr, RefPtr* end) const + { + for (; ptr != end; ptr++) + { + if (ptr->is()) + { + return ptr; + } + } + return end; + } + }; + + // + // type Expressions + // + + // A "type expression" is a term that we expect to resolve to a type during checking. + // We store both the original syntax and the resolved type here. + struct TypeExp + { + TypeExp() {} + TypeExp(TypeExp const& other) + : exp(other.exp) + , type(other.type) + {} + explicit TypeExp(RefPtr exp) + : exp(exp) + {} + explicit TypeExp(RefPtr type) + : type(type) + {} + TypeExp(RefPtr exp, RefPtr type) + : exp(exp) + , type(type) + {} + + RefPtr exp; + RefPtr type; + + bool Equals(Type* other); +#if 0 + { + return type->Equals(other); + } +#endif + bool Equals(RefPtr other); +#if 0 + { + return type->Equals(other.Ptr()); + } +#endif + Type* Ptr() { return type.Ptr(); } + operator Type*() + { + return type; + } + Type* operator->() { return Ptr(); } + + TypeExp Accept(SyntaxVisitor* visitor); + }; + + + + struct Scope : public RefObject + { + // The parent of this scope (where lookup should go if nothing is found locally) + RefPtr parent; + + // The next sibling of this scope (a peer for lookup) + RefPtr nextSibling; + + // The container to use for lookup + // + // Note(tfoley): This is kept as an unowned pointer + // so that a scope can't keep parts of the AST alive, + // but the opposite it allowed. + ContainerDecl* containerDecl; + }; + + // Masks to be applied when lookup up declarations + enum class LookupMask : uint8_t + { + type = 0x1, + Function = 0x2, + Value = 0x4, + Attribute = 0x8, + + Default = type | Function | Value, + }; + + // Represents one item found during lookup + struct LookupResultItem + { + // Sometimes lookup finds an item, but there were additional + // "hops" taken to reach it. We need to remember these steps + // so that if/when we consturct a full expression we generate + // appropriate AST nodes for all the steps. + // + // We build up a list of these "breadcrumbs" while doing + // lookup, and store them alongside each item found. + // + // As an example, suppose we have an HLSL `cbuffer` declaration: + // + // cbuffer C { float4 f; } + // + // This is syntax sugar for a global-scope variable of + // type `ConstantBuffer` where `T` is a `struct` containing + // all the members: + // + // struct Anon0 { float4 f; }; + // __transparent ConstantBuffer anon1; + // + // The `__transparent` modifier there captures the fact that + // when somebody writes `f` in their code, they expect it to + // "see through" the `cbuffer` declaration (or the global variable, + // in this case) and find the member inside. + // + // But when the user writes `f` we can't just create a simple + // `VarExpr` that refers directly to that field, because that + // doesn't actually reflect the required steps in a way that + // code generation can use. + // + // Instead we need to construct an expression like `(*anon1).f`, + // where there is are two additional steps in the process: + // + // 1. We needed to dereference the pointer-like type `ConstantBuffer` + // to get at a value of type `Anon0` + // 2. We needed to access a sub-field of the aggregate type `Anon0` + // + // We *could* just create these full-formed expressions during + // lookup, but this might mean creating a large number of + // AST nodes in cases where the user calls an overloaded function. + // At the very least we'd rather not heap-allocate in the common + // case where no "extra" steps need to be performed to get to + // the declarations. + // + // This is where "breadcrumbs" come in. A breadcrumb represents + // an extra "step" that must be performed to turn a declaration + // found by lookup into a valid expression to splice into the + // AST. Most of the time lookup result items don't have any + // breadcrumbs, so that no extra heap allocation takes place. + // When an item does have breadcrumbs, and it is chosen as + // the unique result (perhaps by overload resolution), then + // we can walk the list of breadcrumbs to create a full + // expression. + class Breadcrumb : public RefObject + { + public: + enum class Kind : uint8_t + { + // The lookup process looked "through" an in-scope + // declaration to the fields inside of it, so that + // even if lookup started with a simple name `f`, + // it needs to result in a member expression `obj.f`. + Member, + + // The lookup process took a pointer(-like) value, and then + // proceeded to derefence it and look at the thing(s) + // it points to instead, so that the final expression + // needs to have `(*obj)` + Deref, + + // The lookup process saw a value `obj` of type `T` and + // took into account an in-scope constraint that says + // `T` is a subtype of some other type `U`, so that + // lookup was able to find a member through type `U` + // instead. + Constraint, + + // The lookup process considered a member of an + // enclosing type as being in scope, so that any + // reference to that member needs to use a `this` + // expression as appropriate. + This, + }; + + // The kind of lookup step that was performed + Kind kind; + + // For the `Kind::This` case, is the `this` parameter + // mutable or not? + enum class ThisParameterMode : uint8_t + { + Default, + Mutating, + }; + ThisParameterMode thisParameterMode = ThisParameterMode::Default; + + + // As needed, a reference to the declaration that faciliated + // the lookup step. + // + // For a `Member` lookup step, this is the declaration whose + // members were implicitly pulled into scope. + // + // For a `Constraint` lookup step, this is the `ConstraintDecl` + // that serves to witness the subtype relationship. + // + DeclRef declRef; + + // The next implicit step that the lookup process took to + // arrive at a final value. + RefPtr next; + + Breadcrumb( + Kind kind, + DeclRef declRef, + RefPtr next, + ThisParameterMode thisParameterMode = ThisParameterMode::Default) + : kind(kind) + , thisParameterMode(thisParameterMode) + , declRef(declRef) + , next(next) + {} + }; + + // A properly-specialized reference to the declaration that was found. + DeclRef declRef; + + // Any breadcrumbs needed in order to turn that declaration + // reference into a well-formed expression. + // + // This is unused in the simple case where a declaration + // is being referenced directly (rather than through + // transparent members). + RefPtr breadcrumbs; + + LookupResultItem() = default; + explicit LookupResultItem(DeclRef declRef) + : declRef(declRef) + {} + LookupResultItem(DeclRef declRef, RefPtr breadcrumbs) + : declRef(declRef) + , breadcrumbs(breadcrumbs) + {} + }; + + + // Result of looking up a name in some lexical/semantic environment. + // Can be used to enumerate all the declarations matching that name, + // in the case where the result is overloaded. + struct LookupResult + { + // The one item that was found, in the smple case + LookupResultItem item; + + // All of the items that were found, in the complex case. + // Note: if there was no overloading, then this list isn't + // used at all, to avoid allocation. + List items; + + HashSet> lookedupDecls; + + // Was at least one result found? + bool isValid() const { return item.declRef.getDecl() != nullptr; } + + bool isOverloaded() const { return items.getCount() > 1; } + + Name* getName() const + { + return items.getCount() > 1 ? items[0].declRef.GetName() : item.declRef.GetName(); + } + LookupResultItem* begin() + { + if (isValid()) + { + if (isOverloaded()) + return items.begin(); + else + return &item; + } + else + return nullptr; + } + LookupResultItem* end() + { + if (isValid()) + { + if (isOverloaded()) + return items.end(); + else + return &item + 1; + } + else + return nullptr; + } + }; + + struct SemanticsVisitor; + + struct LookupRequest + { + SemanticsVisitor* semantics = nullptr; + RefPtr scope = nullptr; + RefPtr endScope = nullptr; + + LookupMask mask = LookupMask::Default; + }; + + struct WitnessTable; + + // A value that witnesses the satisfaction of an interface + // requirement by a particular declaration or value. + struct RequirementWitness + { + RequirementWitness() + : m_flavor(Flavor::none) + {} + + RequirementWitness(DeclRef declRef) + : m_flavor(Flavor::declRef) + , m_declRef(declRef) + {} + + RequirementWitness(RefPtr val); + + RequirementWitness(RefPtr witnessTable); + + enum class Flavor + { + none, + declRef, + val, + witnessTable, + }; + + Flavor getFlavor() + { + return m_flavor; + } + + DeclRef getDeclRef() + { + SLANG_ASSERT(getFlavor() == Flavor::declRef); + return m_declRef; + } + + RefPtr getVal() + { + SLANG_ASSERT(getFlavor() == Flavor::val); + return m_obj.as(); + } + + RefPtr getWitnessTable(); + + RequirementWitness specialize(SubstitutionSet const& subst); + + Flavor m_flavor; + DeclRef m_declRef; + RefPtr m_obj; + + }; + + typedef Dictionary RequirementDictionary; + + struct WitnessTable : RefObject + { + RequirementDictionary requirementDictionary; + }; + + typedef Dictionary> AttributeArgumentValueDict; + + /// Collects information about existential type parameters and their arguments. + struct ExistentialTypeSlots + { + /// For each type parameter, holds the interface/existential type that constrains it. + List> paramTypes; + + /// An argument for an existential type parameter. + /// + /// Comprises a concrete type and a witness for its conformance to the desired + /// interface/existential type for the corresponding parameter. + /// + struct Arg + { + RefPtr type; + RefPtr witness; + }; + + /// Any arguments provided for the existential type parameters. + /// + /// It is possible for `args` to be empty even if `paramTypes` is non-empty; + /// that situation represents an unspecialized program or entry point. + /// + List args; + }; + + + // Generate class definition for all syntax classes +#define SYNTAX_FIELD(TYPE, NAME) TYPE NAME; +#define FIELD(TYPE, NAME) TYPE NAME; +#define FIELD_INIT(TYPE, NAME, INIT) TYPE NAME = INIT; +#define RAW(...) __VA_ARGS__ +#define END_SYNTAX_CLASS() }; +#define SYNTAX_CLASS(NAME, BASE, ...) class NAME : public BASE {public: +#include "slang-object-meta-begin.h" + +#include "slang-syntax-base-defs.h" +#undef SYNTAX_CLASS + +#undef ABSTRACT_SYNTAX_CLASS +#define ABSTRACT_SYNTAX_CLASS(NAME, BASE, ...) \ + class NAME : public BASE { \ + public: /* ... */ +#define SYNTAX_CLASS(NAME, BASE, ...) \ + class NAME : public BASE { \ + virtual void accept(NAME::Visitor* visitor, void* extra) override; \ + public: virtual SyntaxClass getClass() override; \ + public: /* ... */ +#include "slang-expr-defs.h" +#include "slang-decl-defs.h" +#include "slang-modifier-defs.h" +#include "slang-stmt-defs.h" +#include "slang-type-defs.h" +#include "slang-val-defs.h" + +#include "slang-object-meta-end.h" + + inline RefPtr GetSub(DeclRef const& declRef) + { + return declRef.Substitute(declRef.getDecl()->sub.Ptr()); + } + + inline RefPtr GetSup(DeclRef const& declRef) + { + return declRef.Substitute(declRef.getDecl()->getSup().type); + } + + // Note(tfoley): These logically belong to `Type`, + // but order-of-declaration stuff makes that tricky + // + // TODO(tfoley): These should really belong to the compilation context! + // + void registerBuiltinDecl( + Session* session, + RefPtr decl, + RefPtr modifier); + void registerMagicDecl( + Session* session, + RefPtr decl, + RefPtr modifier); + + // Look up a magic declaration by its name + RefPtr findMagicDecl( + Session* session, + String const& name); + + // Create an instance of a syntax class by name + SyntaxNodeBase* createInstanceOfSyntaxClassByName( + String const& name); + + // `Val` + + inline bool areValsEqual(Val* left, Val* right) + { + if(!left || !right) return left == right; + return left->EqualsVal(right); + } + + // + + inline BaseType GetVectorBaseType(VectorExpressionType* vecType) + { + auto basicExprType = as(vecType->elementType); + return basicExprType->baseType; + } + + inline int GetVectorSize(VectorExpressionType* vecType) + { + auto constantVal = as(vecType->elementCount); + if (constantVal) + return (int) constantVal->value; + // TODO: what to do in this case? + return 0; + } + + // + // Declarations + // + + inline ExtensionDecl* GetCandidateExtensions(DeclRef const& declRef) + { + return declRef.getDecl()->candidateExtensions; + } + + inline FilteredMemberRefList getMembers(DeclRef const& declRef) + { + return FilteredMemberRefList(declRef.getDecl()->Members, declRef.substitutions); + } + + template + inline FilteredMemberRefList getMembersOfType(DeclRef const& declRef) + { + return FilteredMemberRefList(declRef.getDecl()->Members, declRef.substitutions); + } + + template + inline List> getMembersOfTypeWithExt(DeclRef const& declRef) + { + List> rs; + for (auto d : getMembersOfType(declRef)) + rs.add(d); + if (auto aggDeclRef = declRef.as()) + { + for (auto ext = GetCandidateExtensions(aggDeclRef); ext; ext = ext->nextCandidateExtension) + { + auto extMembers = getMembersOfType(DeclRef(ext, declRef.substitutions)); + for (auto mbr : extMembers) + rs.add(mbr); + } + } + return rs; + } + + /// The the user-level name for a variable that might be a shader parameter. + /// + /// In most cases this is just the name of the variable declaration itself, + /// but in the specific case of a `cbuffer`, the name that the user thinks + /// of is really metadata. For example: + /// + /// cbuffer C { int x; } + /// + /// In this example, error messages relating to the constant buffer should + /// really use the name `C`, but that isn't the name of the declaration + /// (it is in practice anonymous, and `C` can be used for a different + /// declaration in the same file). + /// + Name* getReflectionName(VarDeclBase* varDecl); + + inline RefPtr GetType(DeclRef const& declRef) + { + return declRef.Substitute(declRef.getDecl()->type.Ptr()); + } + + inline RefPtr getInitExpr(DeclRef const& declRef) + { + return declRef.Substitute(declRef.getDecl()->initExpr); + } + + inline RefPtr getType(DeclRef const& declRef) + { + return declRef.Substitute(declRef.getDecl()->type.Ptr()); + } + + inline RefPtr getTagExpr(DeclRef const& declRef) + { + return declRef.Substitute(declRef.getDecl()->tagExpr); + } + + inline RefPtr GetTargetType(DeclRef const& declRef) + { + return declRef.Substitute(declRef.getDecl()->targetType.Ptr()); + } + + inline FilteredMemberRefList GetFields(DeclRef const& declRef) + { + return getMembersOfType(declRef); + } + + inline RefPtr getBaseType(DeclRef const& declRef) + { + return declRef.Substitute(declRef.getDecl()->base.type); + } + + inline RefPtr GetType(DeclRef const& declRef) + { + return declRef.Substitute(declRef.getDecl()->type.Ptr()); + } + + inline RefPtr GetResultType(DeclRef const& declRef) + { + return declRef.Substitute(declRef.getDecl()->ReturnType.type.Ptr()); + } + + inline FilteredMemberRefList GetParameters(DeclRef const& declRef) + { + return getMembersOfType(declRef); + } + + inline Decl* GetInner(DeclRef const& declRef) + { + // TODO: Should really return a `DeclRef` for the inner + // declaration, and not just a raw pointer + return declRef.getDecl()->inner.Ptr(); + } + + + // + + RefPtr getArrayType( + Type* elementType, + IntVal* elementCount); + + RefPtr getArrayType( + Type* elementType); + + RefPtr getNamedType( + Session* session, + DeclRef const& declRef); + + RefPtr getTypeType( + Type* type); + + RefPtr getFuncType( + Session* session, + DeclRef const& declRef); + + RefPtr getGenericDeclRefType( + Session* session, + DeclRef const& declRef); + + RefPtr getSamplerStateType( + Session* session); + + + // Definitions that can't come earlier despite + // being in templates, because gcc/clang get angry. + // + template + void FilteredModifierList::Iterator::operator++() + { + current = Adjust(current->next.Ptr()); + } + // + template + Modifier* FilteredModifierList::Adjust(Modifier* modifier) + { + Modifier* m = modifier; + for (;;) + { + if (!m) return m; + if (as(m)) + { + return m; + } + m = m->next.Ptr(); + } + } + + // TODO: where should this live? + SubstitutionSet createDefaultSubstitutions( + Session* session, + Decl* decl, + SubstitutionSet parentSubst); + + SubstitutionSet createDefaultSubstitutions( + Session* session, + Decl* decl); + + DeclRef createDefaultSubstitutionsIfNeeded( + Session* session, + DeclRef declRef); + + RefPtr createDefaultSubsitutionsForGeneric( + Session* session, + GenericDecl* genericDecl, + RefPtr outerSubst); + + RefPtr findInnerMostGenericSubstitution(Substitutions* subst); + + enum class UserDefinedAttributeTargets + { + None = 0, + Struct = 1, + Var = 2, + Function = 4, + All = 7 + }; + + /// Get the module that a declaration is associated with, if any. + Module* getModule(Decl* decl); + +} // namespace Slang + +#endif diff --git a/source/slang/slang-token-defs.h b/source/slang/slang-token-defs.h new file mode 100644 index 000000000..6cece330e --- /dev/null +++ b/source/slang/slang-token-defs.h @@ -0,0 +1,96 @@ +// slang-token-defs.h + +// This file is meant to be included multiple times, to produce different +// pieces of code related to tokens +// +// Each token is declared here with: +// +// TOKEN(id, desc) +// +// where `id` is the identifier that will be used for the token in +// ordinary code, while `desc` is name we should print when +// referring to this token in diagnostic messages. + + +#ifndef TOKEN +#error Need to define TOKEN(ID, DESC) before including "token-defs.h" +#endif + +TOKEN(Unknown, "") +TOKEN(EndOfFile, "end of file") +TOKEN(EndOfDirective, "end of line") +TOKEN(Invalid, "invalid character") +TOKEN(Identifier, "identifier") +TOKEN(IntegerLiteral, "integer literal") +TOKEN(FloatingPointLiteral, "floating-point literal") +TOKEN(StringLiteral, "string literal") +TOKEN(CharLiteral, "character literal") +TOKEN(WhiteSpace, "whitespace") +TOKEN(NewLine, "newline") +TOKEN(LineComment, "line comment") +TOKEN(BlockComment, "block comment") +TOKEN(DirectiveMessage, "user-defined message") + +#define PUNCTUATION(id, text) \ + TOKEN(id, "'" text "'") + +PUNCTUATION(Semicolon, ";") +PUNCTUATION(Comma, ",") +PUNCTUATION(Dot, ".") + +PUNCTUATION(LBrace, "{") +PUNCTUATION(RBrace, "}") +PUNCTUATION(LBracket, "[") +PUNCTUATION(RBracket, "]") +PUNCTUATION(LParent, "(") +PUNCTUATION(RParent, ")") + +PUNCTUATION(OpAssign, "=") +PUNCTUATION(OpAdd, "+") +PUNCTUATION(OpSub, "-") +PUNCTUATION(OpMul, "*") +PUNCTUATION(OpDiv, "/") +PUNCTUATION(OpMod, "%") +PUNCTUATION(OpNot, "!") +PUNCTUATION(OpBitNot, "~") +PUNCTUATION(OpLsh, "<<") +PUNCTUATION(OpRsh, ">>") +PUNCTUATION(OpEql, "==") +PUNCTUATION(OpNeq, "!=") +PUNCTUATION(OpGreater, ">") +PUNCTUATION(OpLess, "<") +PUNCTUATION(OpGeq, ">=") +PUNCTUATION(OpLeq, "<=") +PUNCTUATION(OpAnd, "&&") +PUNCTUATION(OpOr, "||") +PUNCTUATION(OpBitAnd, "&") +PUNCTUATION(OpBitOr, "|") +PUNCTUATION(OpBitXor, "^") +PUNCTUATION(OpInc, "++") +PUNCTUATION(OpDec, "--") + +PUNCTUATION(OpAddAssign, "+=") +PUNCTUATION(OpSubAssign, "-=") +PUNCTUATION(OpMulAssign, "*=") +PUNCTUATION(OpDivAssign, "/=") +PUNCTUATION(OpModAssign, "%=") +PUNCTUATION(OpShlAssign, "<<=") +PUNCTUATION(OpShrAssign, ">>=") +PUNCTUATION(OpAndAssign, "&=") +PUNCTUATION(OpOrAssign, "|=") +PUNCTUATION(OpXorAssign, "^=") + +PUNCTUATION(QuestionMark, "?") +PUNCTUATION(Colon, ":") +PUNCTUATION(RightArrow, "->") +PUNCTUATION(At, "@") +PUNCTUATION(Dollar, "$") +PUNCTUATION(Pound, "#") +PUNCTUATION(PoundPound, "##") + +PUNCTUATION(Scope, "::") + +#undef PUNCTUATION + +// Un-define the `TOKEN` macro so that client doesn't have to +#undef TOKEN diff --git a/source/slang/slang-token.cpp b/source/slang/slang-token.cpp new file mode 100644 index 000000000..a8238eb6c --- /dev/null +++ b/source/slang/slang-token.cpp @@ -0,0 +1,39 @@ +// slang-token.cpp +#include "slang-token.h" + +#include + +namespace Slang { + + +Name* Token::getName() const +{ + return getNameOrNull(); +} + +Name* Token::getNameOrNull() const +{ + switch (type) + { + default: + return nullptr; + + case TokenType::Identifier: + return (Name*) ptrValue; + } +} + +char const* TokenTypeToString(TokenType type) +{ + switch( type ) + { + default: + SLANG_ASSERT(!"unexpected"); + return ""; + +#define TOKEN(NAME, DESC) case TokenType::NAME: return DESC; +#include "slang-token-defs.h" + } +} + +} // namespace Slang diff --git a/source/slang/slang-token.h b/source/slang/slang-token.h new file mode 100644 index 000000000..193b128fa --- /dev/null +++ b/source/slang/slang-token.h @@ -0,0 +1,67 @@ +// slang-token.h +#ifndef SLANG_TOKEN_H_INCLUDED +#define SLANG_TOKEN_H_INCLUDED + +#include "../core/slang-basic.h" + +#include "slang-source-loc.h" + +namespace Slang { + +class Name; + +enum class TokenType +{ +#define TOKEN(NAME, DESC) NAME, +#include "slang-token-defs.h" +}; + +char const* TokenTypeToString(TokenType type); + +enum TokenFlag : unsigned int +{ + AtStartOfLine = 1 << 0, + AfterWhitespace = 1 << 1, + SuppressMacroExpansion = 1 << 2, + ScrubbingNeeded = 1 << 3, +}; +typedef unsigned int TokenFlags; + +class Token +{ +public: + TokenType type = TokenType::Unknown; + TokenFlags flags = 0; + + SourceLoc loc; + void* ptrValue; + + UnownedStringSlice Content; + + Token() = default; + + Token( + TokenType typeIn, + const UnownedStringSlice & contentIn, + SourceLoc locIn, + TokenFlags flagsIn = 0) + : flags(flagsIn) + { + type = typeIn; + Content = contentIn; + loc = locIn; + ptrValue = nullptr; + } + + Name* getName() const; + + Name* getNameOrNull() const; + + SourceLoc getLoc() const { return loc; } +}; + + + +} // namespace Slang + +#endif diff --git a/source/slang/slang-type-defs.h b/source/slang/slang-type-defs.h new file mode 100644 index 000000000..d9907bafe --- /dev/null +++ b/source/slang/slang-type-defs.h @@ -0,0 +1,490 @@ +// slang-type-defs.h + +// Syntax class definitions for types. + +// The type of a reference to an overloaded name +SYNTAX_CLASS(OverloadGroupType, Type) +RAW( +public: + virtual String ToString() override; + +protected: + virtual bool EqualsImpl(Type * type) override; + virtual RefPtr CreateCanonicalType() override; + virtual int GetHashCode() override; +) +END_SYNTAX_CLASS() + +// The type of an initializer-list expression (before it has +// been coerced to some other type) +SYNTAX_CLASS(InitializerListType, Type) +RAW( + virtual String ToString() override; + +protected: + virtual bool EqualsImpl(Type * type) override; + virtual RefPtr CreateCanonicalType() override; + virtual int GetHashCode() override; +) +END_SYNTAX_CLASS() + +// The type of an expression that was erroneous +SYNTAX_CLASS(ErrorType, Type) +RAW( +public: + virtual String ToString() override; + +protected: + virtual bool EqualsImpl(Type * type) override; + virtual RefPtr SubstituteImpl(SubstitutionSet subst, int* ioDiff) override; + virtual RefPtr CreateCanonicalType() override; + virtual int GetHashCode() override; +) +END_SYNTAX_CLASS() + +// A type that takes the form of a reference to some declaration +SYNTAX_CLASS(DeclRefType, Type) + DECL_FIELD(DeclRef, declRef) + +RAW( + virtual String ToString() override; + virtual RefPtr SubstituteImpl(SubstitutionSet subst, int* ioDiff) override; + + static RefPtr Create( + Session* session, + DeclRef declRef); + + DeclRefType() + {} + DeclRefType( + DeclRef declRef) + : declRef(declRef) + {} +protected: + virtual int GetHashCode() override; + virtual bool EqualsImpl(Type * type) override; + virtual RefPtr CreateCanonicalType() override; +) +END_SYNTAX_CLASS() + +// Base class for types that can be used in arithmetic expressions +ABSTRACT_SYNTAX_CLASS(ArithmeticExpressionType, DeclRefType) +RAW( +public: + virtual BasicExpressionType* GetScalarType() = 0; +) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(BasicExpressionType, ArithmeticExpressionType) + + FIELD(BaseType, baseType) + +RAW( + BasicExpressionType() {} + BasicExpressionType( + Slang::BaseType baseType) + : baseType(baseType) + {} +protected: + virtual BasicExpressionType* GetScalarType() override; + virtual bool EqualsImpl(Type * type) override; + virtual RefPtr CreateCanonicalType() override; +) +END_SYNTAX_CLASS() + +// Base type for things that are built in to the compiler, +// and will usually have special behavior or a custom +// mapping to the IR level. +ABSTRACT_SYNTAX_CLASS(BuiltinType, DeclRefType) +END_SYNTAX_CLASS() + +// Resources that contain "elements" that can be fetched +ABSTRACT_SYNTAX_CLASS(ResourceType, BuiltinType) + // The type that results from fetching an element from this resource + SYNTAX_FIELD(RefPtr, elementType) + + // Shape and access level information for this resource type + FIELD(TextureFlavor, flavor) + + RAW( + TextureFlavor::Shape GetBaseShape() + { + return flavor.GetBaseShape(); + } + bool isMultisample() { return flavor.isMultisample(); } + bool isArray() { return flavor.isArray(); } + SlangResourceShape getShape() const { return flavor.getShape(); } + SlangResourceAccess getAccess() { return flavor.getAccess(); } + + ) +END_SYNTAX_CLASS() + +ABSTRACT_SYNTAX_CLASS(TextureTypeBase, ResourceType) +RAW( + TextureTypeBase() + {} + TextureTypeBase( + TextureFlavor flavor, + RefPtr elementType) + { + this->elementType = elementType; + this->flavor = flavor; + } +) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(TextureType, TextureTypeBase) +RAW( + TextureType() + {} + TextureType( + TextureFlavor flavor, + RefPtr elementType) + : TextureTypeBase(flavor, elementType) + {} +) +END_SYNTAX_CLASS() + +// This is a base type for texture/sampler pairs, +// as they exist in, e.g., GLSL +SYNTAX_CLASS(TextureSamplerType, TextureTypeBase) +RAW( + TextureSamplerType() + {} + TextureSamplerType( + TextureFlavor flavor, + RefPtr elementType) + : TextureTypeBase(flavor, elementType) + {} +) +END_SYNTAX_CLASS() + +// This is a base type for `image*` types, as they exist in GLSL +SYNTAX_CLASS(GLSLImageType, TextureTypeBase) +RAW( + GLSLImageType() + {} + GLSLImageType( + TextureFlavor flavor, + RefPtr elementType) + : TextureTypeBase(flavor, elementType) + {} +) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(SamplerStateType, BuiltinType) + // What flavor of sampler state is this + FIELD(SamplerStateFlavor, flavor) +END_SYNTAX_CLASS() + +// Other cases of generic types known to the compiler +SYNTAX_CLASS(BuiltinGenericType, BuiltinType) + SYNTAX_FIELD(RefPtr, elementType) + + RAW(Type* getElementType() { return elementType; }) +END_SYNTAX_CLASS() + +// Types that behave like pointers, in that they can be +// dereferenced (implicitly) to access members defined +// in the element type. +SIMPLE_SYNTAX_CLASS(PointerLikeType, BuiltinGenericType) + +// HLSL buffer-type resources + +SIMPLE_SYNTAX_CLASS(HLSLStructuredBufferTypeBase, BuiltinGenericType) +SIMPLE_SYNTAX_CLASS(HLSLStructuredBufferType, HLSLStructuredBufferTypeBase) +SIMPLE_SYNTAX_CLASS(HLSLRWStructuredBufferType, HLSLStructuredBufferTypeBase) +SIMPLE_SYNTAX_CLASS(HLSLRasterizerOrderedStructuredBufferType, HLSLStructuredBufferTypeBase) + +SIMPLE_SYNTAX_CLASS(UntypedBufferResourceType, BuiltinType) +SIMPLE_SYNTAX_CLASS(HLSLByteAddressBufferType, UntypedBufferResourceType) +SIMPLE_SYNTAX_CLASS(HLSLRWByteAddressBufferType, UntypedBufferResourceType) +SIMPLE_SYNTAX_CLASS(HLSLRasterizerOrderedByteAddressBufferType, UntypedBufferResourceType) +SIMPLE_SYNTAX_CLASS(RaytracingAccelerationStructureType, UntypedBufferResourceType) + +SIMPLE_SYNTAX_CLASS(HLSLAppendStructuredBufferType, HLSLStructuredBufferTypeBase) +SIMPLE_SYNTAX_CLASS(HLSLConsumeStructuredBufferType, HLSLStructuredBufferTypeBase) + +SYNTAX_CLASS(HLSLPatchType, BuiltinType) +RAW( + Type* getElementType(); + IntVal* getElementCount(); +) +END_SYNTAX_CLASS() + +SIMPLE_SYNTAX_CLASS(HLSLInputPatchType, HLSLPatchType) +SIMPLE_SYNTAX_CLASS(HLSLOutputPatchType, HLSLPatchType) + +// HLSL geometry shader output stream types + +SIMPLE_SYNTAX_CLASS(HLSLStreamOutputType, BuiltinGenericType) +SIMPLE_SYNTAX_CLASS(HLSLPointStreamType, HLSLStreamOutputType) +SIMPLE_SYNTAX_CLASS(HLSLLineStreamType, HLSLStreamOutputType) +SIMPLE_SYNTAX_CLASS(HLSLTriangleStreamType, HLSLStreamOutputType) + +// +SIMPLE_SYNTAX_CLASS(GLSLInputAttachmentType, BuiltinType) + +// Base class for types used when desugaring parameter block +// declarations, includeing HLSL `cbuffer` or GLSL `uniform` blocks. +SIMPLE_SYNTAX_CLASS(ParameterGroupType, PointerLikeType) + +SIMPLE_SYNTAX_CLASS(UniformParameterGroupType, ParameterGroupType) +SIMPLE_SYNTAX_CLASS(VaryingParameterGroupType, ParameterGroupType) + +// type for HLSL `cbuffer` declarations, and `ConstantBuffer` +// ALso used for GLSL `uniform` blocks. +SIMPLE_SYNTAX_CLASS(ConstantBufferType, UniformParameterGroupType) + +// type for HLSL `tbuffer` declarations, and `TextureBuffer` +SIMPLE_SYNTAX_CLASS(TextureBufferType, UniformParameterGroupType) + +// type for GLSL `in` and `out` blocks +SIMPLE_SYNTAX_CLASS(GLSLInputParameterGroupType, VaryingParameterGroupType) +SIMPLE_SYNTAX_CLASS(GLSLOutputParameterGroupType, VaryingParameterGroupType) + +// type for GLLSL `buffer` blocks +SIMPLE_SYNTAX_CLASS(GLSLShaderStorageBufferType, UniformParameterGroupType) + +// type for Slang `ParameterBlock` type +SIMPLE_SYNTAX_CLASS(ParameterBlockType, UniformParameterGroupType) + +SYNTAX_CLASS(ArrayExpressionType, Type) + SYNTAX_FIELD(RefPtr, baseType) + SYNTAX_FIELD(RefPtr, ArrayLength) + +RAW( + virtual Slang::String ToString() override; + +protected: + virtual bool EqualsImpl(Type * type) override; + virtual RefPtr CreateCanonicalType() override; + virtual RefPtr SubstituteImpl(SubstitutionSet subst, int* ioDiff) override; + virtual int GetHashCode() override; + ) +END_SYNTAX_CLASS() + +// The "type" of an expression that resolves to a type. +// For example, in the expression `float(2)` the sub-expression, +// `float` would have the type `TypeType(float)`. +SYNTAX_CLASS(TypeType, Type) + // The type that this is the type of... + SYNTAX_FIELD(RefPtr, type) + +RAW( +public: + TypeType() + {} + TypeType(RefPtr type) + : type(type) + {} + + virtual String ToString() override; + +protected: + virtual bool EqualsImpl(Type * type) override; + virtual RefPtr CreateCanonicalType() override; + virtual int GetHashCode() override; +) +END_SYNTAX_CLASS() + +// A vector type, e.g., `vector` +SYNTAX_CLASS(VectorExpressionType, ArithmeticExpressionType) + + // The type of vector elements. + // As an invariant, this should be a basic type or an alias. + SYNTAX_FIELD(RefPtr, elementType) + + // The number of elements + SYNTAX_FIELD(RefPtr, elementCount) + +RAW( + virtual String ToString() override; + +protected: + virtual BasicExpressionType* GetScalarType() override; +) +END_SYNTAX_CLASS() + +// A matrix type, e.g., `matrix` +SYNTAX_CLASS(MatrixExpressionType, ArithmeticExpressionType) +RAW( + + Type* getElementType(); + IntVal* getRowCount(); + IntVal* getColumnCount(); + + RefPtr getRowType(); + + virtual String ToString() override; + +protected: + virtual BasicExpressionType* GetScalarType() override; + +private: + RefPtr mRowType; +) +END_SYNTAX_CLASS() + +// The built-in `String` type +SIMPLE_SYNTAX_CLASS(StringType, BuiltinType) + +// Type built-in `__EnumType` type +SYNTAX_CLASS(EnumTypeType, BuiltinType) + +// TODO: provide accessors for the declaration, the "tag" type, etc. + +END_SYNTAX_CLASS() + +// Base class for types that map down to +// simple pointers as part of code generation. +SYNTAX_CLASS(PtrTypeBase, BuiltinType) +RAW( + // Get the type of the pointed-to value. + Type* getValueType(); +) +END_SYNTAX_CLASS() + +// A true (user-visible) pointer type, e.g., `T*` +SYNTAX_CLASS(PtrType, PtrTypeBase) +END_SYNTAX_CLASS() + +// A type that represents the behind-the-scenes +// logical pointer that is passed for an `out` +// or `in out` parameter +SYNTAX_CLASS(OutTypeBase, PtrTypeBase) +END_SYNTAX_CLASS() + +// The type for an `out` parameter, e.g., `out T` +SYNTAX_CLASS(OutType, OutTypeBase) +END_SYNTAX_CLASS() + +// The type for an `in out` parameter, e.g., `in out T` +SYNTAX_CLASS(InOutType, OutTypeBase) +END_SYNTAX_CLASS() + +// The type for an `ref` parameter, e.g., `ref T` +SYNTAX_CLASS(RefType, PtrTypeBase) +END_SYNTAX_CLASS() + +// A type alias of some kind (e.g., via `typedef`) +SYNTAX_CLASS(NamedExpressionType, Type) +DECL_FIELD(DeclRef, declRef) + +RAW( + RefPtr innerType; + NamedExpressionType() + {} + NamedExpressionType( + DeclRef declRef) + : declRef(declRef) + {} + + + virtual String ToString() override; + +protected: + virtual bool EqualsImpl(Type * type) override; + virtual RefPtr CreateCanonicalType() override; + virtual int GetHashCode() override; +) +END_SYNTAX_CLASS() + +// A function type is defined by its parameter types +// and its result type. +SYNTAX_CLASS(FuncType, Type) + + // TODO: We may want to preserve parameter names + // in the list here, just so that we can print + // out friendly names when printing a function + // type, even if they don't affect the actual + // semantic type underneath. + + FIELD(List>, paramTypes) + FIELD(RefPtr, resultType) +RAW( + FuncType() + {} + + UInt getParamCount() { return paramTypes.getCount(); } + Type* getParamType(UInt index) { return paramTypes[index]; } + Type* getResultType() { return resultType; } + + virtual String ToString() override; +protected: + virtual RefPtr SubstituteImpl(SubstitutionSet subst, int* ioDiff) override; + virtual bool EqualsImpl(Type * type) override; + virtual RefPtr CreateCanonicalType() override; + virtual int GetHashCode() override; +) +END_SYNTAX_CLASS() + +// The "type" of an expression that names a generic declaration. +SYNTAX_CLASS(GenericDeclRefType, Type) + + DECL_FIELD(DeclRef, declRef) + + RAW( + GenericDeclRefType() + {} + GenericDeclRefType( + DeclRef declRef) + : declRef(declRef) + {} + + + DeclRef const& GetDeclRef() const { return declRef; } + + virtual String ToString() override; + +protected: + virtual bool EqualsImpl(Type * type) override; + virtual int GetHashCode() override; + virtual RefPtr CreateCanonicalType() override; +) +END_SYNTAX_CLASS() + +// The concrete type for a value wrapped in an existential, accessible +// when the existential is "opened" in some context. +SYNTAX_CLASS(ExtractExistentialType, Type) +RAW( + DeclRef declRef; + + virtual String ToString() override; + virtual bool EqualsImpl(Type * type) override; + virtual int GetHashCode() override; + virtual RefPtr CreateCanonicalType() override; + virtual RefPtr SubstituteImpl(SubstitutionSet subst, int* ioDiff) override; +) +END_SYNTAX_CLASS() + + /// A tagged union of zero or more other types. +SYNTAX_CLASS(TaggedUnionType, Type) +RAW( + /// The distinct "cases" the tagged union can store. + /// + /// For each type in this array, the array index is the + /// tag value for that case. + /// + List> caseTypes; + + virtual String ToString() override; + virtual bool EqualsImpl(Type * type) override; + virtual int GetHashCode() override; + virtual RefPtr CreateCanonicalType() override; + virtual RefPtr SubstituteImpl(SubstitutionSet subst, int* ioDiff) override; +) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(ExistentialSpecializedType, Type) +RAW( + RefPtr baseType; + ExistentialTypeSlots slots; + + virtual String ToString() override; + virtual bool EqualsImpl(Type * type) override; + virtual int GetHashCode() override; + virtual RefPtr CreateCanonicalType() override; + virtual RefPtr SubstituteImpl(SubstitutionSet subst, int* ioDiff) override; +) +END_SYNTAX_CLASS() diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp new file mode 100644 index 000000000..30ab53ca6 --- /dev/null +++ b/source/slang/slang-type-layout.cpp @@ -0,0 +1,3209 @@ +// slang-type-layout.cpp +#include "slang-type-layout.h" + +#include "slang-syntax.h" + +#include + +namespace Slang { + +size_t RoundToAlignment(size_t offset, size_t alignment) +{ + size_t remainder = offset % alignment; + if (remainder == 0) + return offset; + else + return offset + (alignment - remainder); +} + +LayoutSize RoundToAlignment(LayoutSize offset, size_t alignment) +{ + // An infinite size is assumed to be maximally aligned. + if(offset.isInfinite()) + return LayoutSize::infinite(); + + return RoundToAlignment(offset.getFiniteValue(), alignment); +} + +static size_t RoundUpToPowerOfTwo( size_t value ) +{ + // TODO(tfoley): I know this isn't a fast approach + size_t result = 1; + while (result < value) + result *= 2; + return result; +} + +// + +struct DefaultLayoutRulesImpl : SimpleLayoutRulesImpl +{ + // Get size and alignment for a single value of base type. + SimpleLayoutInfo GetScalarLayout(BaseType baseType) override + { + switch (baseType) + { + case BaseType::Void: return SimpleLayoutInfo(); + + // Note: By convention, a `bool` in a constant buffer is stored as an `int. + // This default may eventually change, at which point this logic will need + // to be updated. + // + // TODO: We should probably warn in this case, since storing a `bool` in + // a constant buffer seems like a Bad Idea anyway. + // + case BaseType::Bool: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 4, 4 ); + + + case BaseType::Int8: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 1,1); + case BaseType::Int16: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 2,2); + case BaseType::Int: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 4,4); + case BaseType::Int64: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 8,8); + + case BaseType::UInt8: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 1,1); + case BaseType::UInt16: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 2,2); + case BaseType::UInt: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 4,4); + case BaseType::UInt64: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 8,8); + + case BaseType::Half: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 2,2); + case BaseType::Float: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 4,4); + case BaseType::Double: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 8,8); + + default: + SLANG_UNEXPECTED("uhandled scalar type"); + UNREACHABLE_RETURN(SimpleLayoutInfo( LayoutResourceKind::Uniform, 0, 1 )); + } + } + + SimpleArrayLayoutInfo GetArrayLayout( SimpleLayoutInfo elementInfo, LayoutSize elementCount) override + { + SLANG_RELEASE_ASSERT(elementInfo.size.isFinite()); + auto elementSize = elementInfo.size.getFiniteValue(); + auto elementAlignment = elementInfo.alignment; + auto elementStride = RoundToAlignment(elementSize, elementAlignment); + + // An array with no elements will have zero size. + // + LayoutSize arraySize = 0; + // + // Any array with a non-zero number of elements will need + // to have space for N elements of size `elementSize`, with + // the constraints that there must be `elementStride` bytes + // between consecutive elements. + // + if( elementCount > 0 ) + { + // We can think of this as either allocating (N-1) + // chunks of size `elementStride` (for most of the elements) + // and then one final chunk of size `elementSize` for + // the last element, or equivalently as allocating + // N chunks of size `elementStride` and then "giving back" + // the final `elementStride - elementSize` bytes. + // + arraySize = (elementStride * (elementCount-1)) + elementSize; + } + + SimpleArrayLayoutInfo arrayInfo; + arrayInfo.kind = elementInfo.kind; + arrayInfo.size = arraySize; + arrayInfo.alignment = elementAlignment; + arrayInfo.elementStride = elementStride; + return arrayInfo; + } + + SimpleLayoutInfo GetVectorLayout(SimpleLayoutInfo elementInfo, size_t elementCount) override + { + SimpleLayoutInfo vectorInfo; + vectorInfo.kind = elementInfo.kind; + vectorInfo.size = elementInfo.size * elementCount; + vectorInfo.alignment = elementInfo.alignment; + return vectorInfo; + } + + SimpleArrayLayoutInfo GetMatrixLayout(SimpleLayoutInfo elementInfo, size_t rowCount, size_t columnCount) override + { + // The default behavior here is to lay out a matrix + // as an array of row vectors (that is row-major). + // + // In practice, the code that calls `GetMatrixLayout` will + // potentially transpose the row/column counts in order + // to get layouts with a different convention. + // + return GetArrayLayout( + GetVectorLayout(elementInfo, columnCount), + rowCount); + } + + UniformLayoutInfo BeginStructLayout() override + { + UniformLayoutInfo structInfo(0, 1); + return structInfo; + } + + LayoutSize AddStructField(UniformLayoutInfo* ioStructInfo, UniformLayoutInfo fieldInfo) override + { + // Skip zero-size fields + if(fieldInfo.size == 0) + return ioStructInfo->size; + + // A struct type must be at least as aligned as its most-aligned field. + ioStructInfo->alignment = std::max(ioStructInfo->alignment, fieldInfo.alignment); + + // The new field will be added to the end of the struct. + auto fieldBaseOffset = ioStructInfo->size; + + // We need to ensure that the offset for the field will respect its alignment + auto fieldOffset = RoundToAlignment(fieldBaseOffset, fieldInfo.alignment); + + // The size of the struct must be adjusted to cover the bytes consumed + // by this field. + ioStructInfo->size = fieldOffset + fieldInfo.size; + + return fieldOffset; + } + + + void EndStructLayout(UniformLayoutInfo* ioStructInfo) override + { + SLANG_UNUSED(ioStructInfo); + + // Note: A traditional C layout algorithm would adjust the size + // of a struct type so that it is a multiple of the alignment. + // This is a parsimonious design choice because it means that + // `sizeof(T)` can both be used when copying/allocating a single + // value of type `T` or an array of N values, without having to + // consider more details. + // + // Of course the choice also has down-sides in that wrapping things + // into a `struct` can affect layout in ways that waste space. E.g., + // the following two cases don't lay out the same: + // + // struct S0 { double d; float f; float g; }; + // + // struct X { double d; float f; } + // struct S1 { X x; float g; } + // + // Even though `S0::g` and `S1::g` have the same amount of useful + // data in front of them, they will not land at the same offset, + // and the resulting struct sizes will differ (`sizeof(S0)` will be + // 16 while `sizeof(S1)` will be 24). + // + // Slang doesn't get to be opinionated about this stuff because + // there is already precedent in both HLSL and GLSL for types + // that have a size that is not rounded up to their alignment. + // + // Our default layout rules won't implement the C-like policy, + // and instead it will be injected in the concrete implementations + // that require it. + } +}; + + /// Common behavior for GLSL-family layout. +struct GLSLBaseLayoutRulesImpl : DefaultLayoutRulesImpl +{ + typedef DefaultLayoutRulesImpl Super; + + SimpleLayoutInfo GetVectorLayout(SimpleLayoutInfo elementInfo, size_t elementCount) override + { + // The `std140` and `std430` rules require vectors to be aligned to the next power of + // two up from their size (so a `float2` is 8-byte aligned, and a `float3` is + // 16-byte aligned). + // + // Note that in this case we have a type layout where the size is *not* a multiple + // of the alignment, so it should be possible to pack a scalar after a `float3`. + // + SLANG_RELEASE_ASSERT(elementInfo.kind == LayoutResourceKind::Uniform); + SLANG_RELEASE_ASSERT(elementInfo.size.isFinite()); + + auto size = elementInfo.size.getFiniteValue() * elementCount; + SimpleLayoutInfo vectorInfo( + LayoutResourceKind::Uniform, + size, + RoundUpToPowerOfTwo(size)); + return vectorInfo; + } + + SimpleArrayLayoutInfo GetArrayLayout( SimpleLayoutInfo elementInfo, LayoutSize elementCount) override + { + // The size of an array must be rounded up to be a multiple of its alignment. + // + auto info = Super::GetArrayLayout(elementInfo, elementCount); + info.size = RoundToAlignment(info.size, info.alignment); + return info; + } + + void EndStructLayout(UniformLayoutInfo* ioStructInfo) override + { + // The size of a `struct` must be rounded up to be a multiple of its alignment. + // + ioStructInfo->size = RoundToAlignment(ioStructInfo->size, ioStructInfo->alignment); + } +}; + + /// The GLSL `std430` layout rules. +struct Std430LayoutRulesImpl : GLSLBaseLayoutRulesImpl +{ + // These rules don't actually need any differences from our + // base/common GLSL layout rules. +}; + + /// The GLSL `std430` layout rules. +struct Std140LayoutRulesImpl : GLSLBaseLayoutRulesImpl +{ + typedef GLSLBaseLayoutRulesImpl Super; + + SimpleArrayLayoutInfo GetArrayLayout(SimpleLayoutInfo elementInfo, LayoutSize elementCount) override + { + // The `std140` rules require that array elements + // be aligned on 16-byte boundaries. + // + if(elementInfo.kind == LayoutResourceKind::Uniform) + { + if (elementInfo.alignment < 16) + elementInfo.alignment = 16; + } + return Super::GetArrayLayout(elementInfo, elementCount); + } + + UniformLayoutInfo BeginStructLayout() override + { + // The `std140` rules require that a `struct` type + // be at least 16-byte aligned. + // + return UniformLayoutInfo(0, 16); + } +}; + +struct HLSLConstantBufferLayoutRulesImpl : DefaultLayoutRulesImpl +{ + typedef DefaultLayoutRulesImpl Super; + + // Similar to GLSL `std140` rules, an HLSL constant buffer requires that + // `struct` and array types have 16-byte alignement. + // + // Unlike GLSL `std140`, the overall size of an array or `struct` type + // is *not* rounded up to the alignment, so it is possible for later + // fields to sneak into the "tail space" left behind by a preceding + // structure or array. E.g., in this example: + // + // struct S { float3 a[2]; float b; }; + // + // The stride of the array `a` is 16 bytes per element, but the size + // of `a` will only be 28 bytes (not 32), so that `b` can fit into + // the space after the last array element and the overall structure + // will have a size of 32 bytes. + + SimpleArrayLayoutInfo GetArrayLayout(SimpleLayoutInfo elementInfo, LayoutSize elementCount) override + { + if(elementInfo.kind == LayoutResourceKind::Uniform) + { + if (elementInfo.alignment < 16) + elementInfo.alignment = 16; + } + return Super::GetArrayLayout(elementInfo, elementCount); + } + + UniformLayoutInfo BeginStructLayout() override + { + return UniformLayoutInfo(0, 16); + } + + // HLSL layout rules do *not* impose additional alignment + // constraints on vectors (e.g., all of `float`, `float2`, + // `float3`, and `float4` have 4-byte alignment), but instead + // they impose a rule that any `struct` field must not + // "straddle" a 16-byte boundary. + // + // This has the effect of making it *look* like `float4` + // values have 16-byte alignment in practice, but the + // effects on `float2` and `float3` are more nuanched and + // lead to different result than the GLSL rules. + // + LayoutSize AddStructField(UniformLayoutInfo* ioStructInfo, UniformLayoutInfo fieldInfo) override + { + // Skip zero-size fields + if(fieldInfo.size == 0) + return ioStructInfo->size; + + ioStructInfo->alignment = std::max(ioStructInfo->alignment, fieldInfo.alignment); + ioStructInfo->size = RoundToAlignment(ioStructInfo->size, fieldInfo.alignment); + + LayoutSize fieldOffset = ioStructInfo->size; + LayoutSize fieldSize = fieldInfo.size; + + // Would this field cross a 16-byte boundary? + auto registerSize = 16; + auto startRegister = fieldOffset / registerSize; + auto endRegister = (fieldOffset + fieldSize - 1) / registerSize; + if (startRegister != endRegister) + { + ioStructInfo->size = RoundToAlignment(ioStructInfo->size, size_t(registerSize)); + fieldOffset = ioStructInfo->size; + } + + ioStructInfo->size += fieldInfo.size; + return fieldOffset; + } +}; + +struct HLSLStructuredBufferLayoutRulesImpl : DefaultLayoutRulesImpl +{ + // HLSL structured buffers drop the restrictions added for constant buffers, + // but retain the rules around not adjusting the size of an array or + // structure to its alignment. In this way they should match our + // default layout rules. +}; + +struct DefaultVaryingLayoutRulesImpl : DefaultLayoutRulesImpl +{ + LayoutResourceKind kind; + + DefaultVaryingLayoutRulesImpl(LayoutResourceKind kind) + : kind(kind) + {} + + + // hook to allow differentiating for input/output + virtual LayoutResourceKind getKind() + { + return kind; + } + + SimpleLayoutInfo GetScalarLayout(BaseType) override + { + // Assume that all scalars take up one "slot" + return SimpleLayoutInfo( + getKind(), + 1); + } + + SimpleLayoutInfo GetVectorLayout(SimpleLayoutInfo, size_t) override + { + // Vectors take up one slot by default + // + // TODO: some platforms may decide that vectors of `double` need + // special handling + return SimpleLayoutInfo( + getKind(), + 1); + } +}; + +struct GLSLVaryingLayoutRulesImpl : DefaultVaryingLayoutRulesImpl +{ + GLSLVaryingLayoutRulesImpl(LayoutResourceKind kind) + : DefaultVaryingLayoutRulesImpl(kind) + {} +}; + +struct HLSLVaryingLayoutRulesImpl : DefaultVaryingLayoutRulesImpl +{ + HLSLVaryingLayoutRulesImpl(LayoutResourceKind kind) + : DefaultVaryingLayoutRulesImpl(kind) + {} +}; + +// + +struct GLSLSpecializationConstantLayoutRulesImpl : DefaultLayoutRulesImpl +{ + LayoutResourceKind getKind() + { + return LayoutResourceKind::SpecializationConstant; + } + + SimpleLayoutInfo GetScalarLayout(BaseType) override + { + // Assume that all scalars take up one "slot" + return SimpleLayoutInfo( + getKind(), + 1); + } + + SimpleLayoutInfo GetVectorLayout(SimpleLayoutInfo, size_t elementCount) override + { + // GLSL doesn't support vectors of specialization constants, + // but we will assume that, if supported, they would use one slot per element. + return SimpleLayoutInfo( + getKind(), + elementCount); + } +}; + +GLSLSpecializationConstantLayoutRulesImpl kGLSLSpecializationConstantLayoutRulesImpl; + +// + +struct GLSLObjectLayoutRulesImpl : ObjectLayoutRulesImpl +{ + virtual SimpleLayoutInfo GetObjectLayout(ShaderParameterKind) override + { + // In Vulkan GLSL, pretty much every object is just a descriptor-table slot. + // We can refine this method once we support a case where this isn't true. + return SimpleLayoutInfo(LayoutResourceKind::DescriptorTableSlot, 1); + } +}; +GLSLObjectLayoutRulesImpl kGLSLObjectLayoutRulesImpl; + +struct GLSLPushConstantBufferObjectLayoutRulesImpl : GLSLObjectLayoutRulesImpl +{ + virtual SimpleLayoutInfo GetObjectLayout(ShaderParameterKind /*kind*/) override + { + // Special-case the layout for a constant-buffer, because we don't + // want it to allocate a descriptor-table slot + return SimpleLayoutInfo(LayoutResourceKind::PushConstantBuffer, 1); + } +}; +GLSLPushConstantBufferObjectLayoutRulesImpl kGLSLPushConstantBufferObjectLayoutRulesImpl_; + +struct GLSLShaderRecordConstantBufferObjectLayoutRulesImpl : GLSLObjectLayoutRulesImpl +{ + virtual SimpleLayoutInfo GetObjectLayout(ShaderParameterKind /*kind*/) override + { + // Special-case the layout for a constant-buffer, because we don't + // want it to allocate a descriptor-table slot + return SimpleLayoutInfo(LayoutResourceKind::ShaderRecord, 1); + } +}; +GLSLShaderRecordConstantBufferObjectLayoutRulesImpl kGLSLShaderRecordConstantBufferObjectLayoutRulesImpl_; + +struct HLSLObjectLayoutRulesImpl : ObjectLayoutRulesImpl +{ + virtual SimpleLayoutInfo GetObjectLayout(ShaderParameterKind kind) override + { + switch( kind ) + { + case ShaderParameterKind::ConstantBuffer: + return SimpleLayoutInfo(LayoutResourceKind::ConstantBuffer, 1); + + case ShaderParameterKind::TextureUniformBuffer: + case ShaderParameterKind::StructuredBuffer: + case ShaderParameterKind::RawBuffer: + case ShaderParameterKind::Buffer: + case ShaderParameterKind::Texture: + return SimpleLayoutInfo(LayoutResourceKind::ShaderResource, 1); + + case ShaderParameterKind::MutableStructuredBuffer: + case ShaderParameterKind::MutableRawBuffer: + case ShaderParameterKind::MutableBuffer: + case ShaderParameterKind::MutableTexture: + return SimpleLayoutInfo(LayoutResourceKind::UnorderedAccess, 1); + + case ShaderParameterKind::SamplerState: + return SimpleLayoutInfo(LayoutResourceKind::SamplerState, 1); + + case ShaderParameterKind::TextureSampler: + case ShaderParameterKind::MutableTextureSampler: + case ShaderParameterKind::InputRenderTarget: + // TODO: how to handle these? + default: + SLANG_UNEXPECTED("unhandled shader parameter kind"); + UNREACHABLE_RETURN(SimpleLayoutInfo()); + } + } +}; +HLSLObjectLayoutRulesImpl kHLSLObjectLayoutRulesImpl; + +// HACK: Treating ray-tracing input/output as if it was another +// case of varying input/output when it really needs to be +// based on byte storage/layout. +// +struct GLSLRayTracingLayoutRulesImpl : DefaultVaryingLayoutRulesImpl +{ + GLSLRayTracingLayoutRulesImpl(LayoutResourceKind kind) + : DefaultVaryingLayoutRulesImpl(kind) + {} +}; +struct HLSLRayTracingLayoutRulesImpl : DefaultVaryingLayoutRulesImpl +{ + HLSLRayTracingLayoutRulesImpl(LayoutResourceKind kind) + : DefaultVaryingLayoutRulesImpl(kind) + {} +}; + +Std140LayoutRulesImpl kStd140LayoutRulesImpl; +Std430LayoutRulesImpl kStd430LayoutRulesImpl; +HLSLConstantBufferLayoutRulesImpl kHLSLConstantBufferLayoutRulesImpl; +HLSLStructuredBufferLayoutRulesImpl kHLSLStructuredBufferLayoutRulesImpl; + +GLSLVaryingLayoutRulesImpl kGLSLVaryingInputLayoutRulesImpl(LayoutResourceKind::VertexInput); +GLSLVaryingLayoutRulesImpl kGLSLVaryingOutputLayoutRulesImpl(LayoutResourceKind::FragmentOutput); + +GLSLRayTracingLayoutRulesImpl kGLSLRayPayloadParameterLayoutRulesImpl(LayoutResourceKind::RayPayload); +GLSLRayTracingLayoutRulesImpl kGLSLCallablePayloadParameterLayoutRulesImpl(LayoutResourceKind::CallablePayload); +GLSLRayTracingLayoutRulesImpl kGLSLHitAttributesParameterLayoutRulesImpl(LayoutResourceKind::HitAttributes); + +HLSLVaryingLayoutRulesImpl kHLSLVaryingInputLayoutRulesImpl(LayoutResourceKind::VertexInput); +HLSLVaryingLayoutRulesImpl kHLSLVaryingOutputLayoutRulesImpl(LayoutResourceKind::FragmentOutput); + +HLSLRayTracingLayoutRulesImpl kHLSLRayPayloadParameterLayoutRulesImpl(LayoutResourceKind::RayPayload); +HLSLRayTracingLayoutRulesImpl kHLSLCallablePayloadParameterLayoutRulesImpl(LayoutResourceKind::CallablePayload); +HLSLRayTracingLayoutRulesImpl kHLSLHitAttributesParameterLayoutRulesImpl(LayoutResourceKind::HitAttributes); + +// + +struct GLSLLayoutRulesFamilyImpl : LayoutRulesFamilyImpl +{ + virtual LayoutRulesImpl* getConstantBufferRules() override; + virtual LayoutRulesImpl* getPushConstantBufferRules() override; + virtual LayoutRulesImpl* getTextureBufferRules() override; + virtual LayoutRulesImpl* getVaryingInputRules() override; + virtual LayoutRulesImpl* getVaryingOutputRules() override; + virtual LayoutRulesImpl* getSpecializationConstantRules() override; + virtual LayoutRulesImpl* getShaderStorageBufferRules() override; + virtual LayoutRulesImpl* getParameterBlockRules() override; + + LayoutRulesImpl* getRayPayloadParameterRules() override; + LayoutRulesImpl* getCallablePayloadParameterRules() override; + LayoutRulesImpl* getHitAttributesParameterRules() override; + + LayoutRulesImpl* getShaderRecordConstantBufferRules() override; +}; + +struct HLSLLayoutRulesFamilyImpl : LayoutRulesFamilyImpl +{ + virtual LayoutRulesImpl* getConstantBufferRules() override; + virtual LayoutRulesImpl* getPushConstantBufferRules() override; + virtual LayoutRulesImpl* getTextureBufferRules() override; + virtual LayoutRulesImpl* getVaryingInputRules() override; + virtual LayoutRulesImpl* getVaryingOutputRules() override; + virtual LayoutRulesImpl* getSpecializationConstantRules() override; + virtual LayoutRulesImpl* getShaderStorageBufferRules() override; + virtual LayoutRulesImpl* getParameterBlockRules() override; + + LayoutRulesImpl* getRayPayloadParameterRules() override; + LayoutRulesImpl* getCallablePayloadParameterRules() override; + LayoutRulesImpl* getHitAttributesParameterRules() override; + + LayoutRulesImpl* getShaderRecordConstantBufferRules() override; +}; + +GLSLLayoutRulesFamilyImpl kGLSLLayoutRulesFamilyImpl; +HLSLLayoutRulesFamilyImpl kHLSLLayoutRulesFamilyImpl; + + +// GLSL cases + +LayoutRulesImpl kStd140LayoutRulesImpl_ = { + &kGLSLLayoutRulesFamilyImpl, &kStd140LayoutRulesImpl, &kGLSLObjectLayoutRulesImpl, +}; + +LayoutRulesImpl kStd430LayoutRulesImpl_ = { + &kGLSLLayoutRulesFamilyImpl, &kStd430LayoutRulesImpl, &kGLSLObjectLayoutRulesImpl, +}; + +LayoutRulesImpl kGLSLPushConstantLayoutRulesImpl_ = { + &kGLSLLayoutRulesFamilyImpl, &kStd430LayoutRulesImpl, &kGLSLPushConstantBufferObjectLayoutRulesImpl_, +}; + +LayoutRulesImpl kGLSLShaderRecordLayoutRulesImpl_ = { + &kGLSLLayoutRulesFamilyImpl, &kStd430LayoutRulesImpl, &kGLSLShaderRecordConstantBufferObjectLayoutRulesImpl_, +}; + +LayoutRulesImpl kGLSLVaryingInputLayoutRulesImpl_ = { + &kGLSLLayoutRulesFamilyImpl, &kGLSLVaryingInputLayoutRulesImpl, &kGLSLObjectLayoutRulesImpl, +}; + +LayoutRulesImpl kGLSLVaryingOutputLayoutRulesImpl_ = { + &kGLSLLayoutRulesFamilyImpl, &kGLSLVaryingOutputLayoutRulesImpl, &kGLSLObjectLayoutRulesImpl, +}; + +LayoutRulesImpl kGLSLSpecializationConstantLayoutRulesImpl_ = { + &kGLSLLayoutRulesFamilyImpl, &kGLSLSpecializationConstantLayoutRulesImpl, &kGLSLObjectLayoutRulesImpl, +}; + +LayoutRulesImpl kGLSLRayPayloadParameterLayoutRulesImpl_ = { + &kGLSLLayoutRulesFamilyImpl, &kGLSLRayPayloadParameterLayoutRulesImpl, &kGLSLObjectLayoutRulesImpl, +}; + +LayoutRulesImpl kGLSLCallablePayloadParameterLayoutRulesImpl_ = { + &kGLSLLayoutRulesFamilyImpl, &kGLSLCallablePayloadParameterLayoutRulesImpl, &kGLSLObjectLayoutRulesImpl, +}; + +LayoutRulesImpl kGLSLHitAttributesParameterLayoutRulesImpl_ = { + &kGLSLLayoutRulesFamilyImpl, &kGLSLHitAttributesParameterLayoutRulesImpl, &kGLSLObjectLayoutRulesImpl, +}; + +// HLSL cases + +LayoutRulesImpl kHLSLConstantBufferLayoutRulesImpl_ = { + &kHLSLLayoutRulesFamilyImpl, &kHLSLConstantBufferLayoutRulesImpl, &kHLSLObjectLayoutRulesImpl, +}; + +LayoutRulesImpl kHLSLStructuredBufferLayoutRulesImpl_ = { + &kHLSLLayoutRulesFamilyImpl, &kHLSLStructuredBufferLayoutRulesImpl, &kHLSLObjectLayoutRulesImpl, +}; + +LayoutRulesImpl kHLSLVaryingInputLayoutRulesImpl_ = { + &kHLSLLayoutRulesFamilyImpl, &kHLSLVaryingInputLayoutRulesImpl, &kHLSLObjectLayoutRulesImpl, +}; + +LayoutRulesImpl kHLSLVaryingOutputLayoutRulesImpl_ = { + &kHLSLLayoutRulesFamilyImpl, &kHLSLVaryingOutputLayoutRulesImpl, &kHLSLObjectLayoutRulesImpl, +}; + +LayoutRulesImpl kHLSLRayPayloadParameterLayoutRulesImpl_ = { + &kHLSLLayoutRulesFamilyImpl, &kHLSLRayPayloadParameterLayoutRulesImpl, &kHLSLObjectLayoutRulesImpl, +}; + +LayoutRulesImpl kHLSLCallablePayloadParameterLayoutRulesImpl_ = { + &kHLSLLayoutRulesFamilyImpl, &kHLSLCallablePayloadParameterLayoutRulesImpl, &kHLSLObjectLayoutRulesImpl, +}; + +LayoutRulesImpl kHLSLHitAttributesParameterLayoutRulesImpl_ = { + &kHLSLLayoutRulesFamilyImpl, &kHLSLHitAttributesParameterLayoutRulesImpl, &kHLSLObjectLayoutRulesImpl, +}; + +// + +LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getConstantBufferRules() +{ + return &kStd140LayoutRulesImpl_; +} + +LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getParameterBlockRules() +{ + // TODO: actually pick something appropriate + return &kStd140LayoutRulesImpl_; +} + +LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getPushConstantBufferRules() +{ + return &kGLSLPushConstantLayoutRulesImpl_; +} + +LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getShaderRecordConstantBufferRules() +{ + return &kGLSLShaderRecordLayoutRulesImpl_; +} + +LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getTextureBufferRules() +{ + return nullptr; +} + +LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getVaryingInputRules() +{ + return &kGLSLVaryingInputLayoutRulesImpl_; +} + +LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getVaryingOutputRules() +{ + return &kGLSLVaryingOutputLayoutRulesImpl_; +} + +LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getSpecializationConstantRules() +{ + return &kGLSLSpecializationConstantLayoutRulesImpl_; +} + +LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getShaderStorageBufferRules() +{ + return &kStd430LayoutRulesImpl_; +} + +LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getRayPayloadParameterRules() +{ + return &kGLSLRayPayloadParameterLayoutRulesImpl_; +} + +LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getCallablePayloadParameterRules() +{ + return &kGLSLCallablePayloadParameterLayoutRulesImpl_; +} + +LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getHitAttributesParameterRules() +{ + return &kGLSLHitAttributesParameterLayoutRulesImpl_; +} + +// + +LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getConstantBufferRules() +{ + return &kHLSLConstantBufferLayoutRulesImpl_; +} + +LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getParameterBlockRules() +{ + // TODO: actually pick something appropriate... + return &kHLSLConstantBufferLayoutRulesImpl_; +} + + +LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getPushConstantBufferRules() +{ + return &kHLSLConstantBufferLayoutRulesImpl_; +} + +LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getShaderRecordConstantBufferRules() +{ + return &kHLSLConstantBufferLayoutRulesImpl_; +} + +LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getTextureBufferRules() +{ + return nullptr; +} + +LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getVaryingInputRules() +{ + return &kHLSLVaryingInputLayoutRulesImpl_; +} + +LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getVaryingOutputRules() +{ + return &kHLSLVaryingOutputLayoutRulesImpl_; +} + +LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getSpecializationConstantRules() +{ + return nullptr; +} + +LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getShaderStorageBufferRules() +{ + return nullptr; +} + +LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getRayPayloadParameterRules() +{ + return &kHLSLRayPayloadParameterLayoutRulesImpl_; +} + +LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getCallablePayloadParameterRules() +{ + return &kHLSLCallablePayloadParameterLayoutRulesImpl_; +} + +LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getHitAttributesParameterRules() +{ + return &kHLSLHitAttributesParameterLayoutRulesImpl_; +} + + + +// + +LayoutRulesImpl* GetLayoutRulesImpl(LayoutRule rule) +{ + switch (rule) + { + case LayoutRule::Std140: return &kStd140LayoutRulesImpl_; + case LayoutRule::Std430: return &kStd430LayoutRulesImpl_; + case LayoutRule::HLSLConstantBuffer: return &kHLSLConstantBufferLayoutRulesImpl_; + case LayoutRule::HLSLStructuredBuffer: return &kHLSLStructuredBufferLayoutRulesImpl_; + default: + return nullptr; + } +} + +LayoutRulesFamilyImpl* getDefaultLayoutRulesFamilyForTarget(TargetRequest* targetReq) +{ + switch (targetReq->getTarget()) + { + case CodeGenTarget::HLSL: + case CodeGenTarget::DXBytecode: + case CodeGenTarget::DXBytecodeAssembly: + case CodeGenTarget::DXIL: + case CodeGenTarget::DXILAssembly: + return &kHLSLLayoutRulesFamilyImpl; + + case CodeGenTarget::GLSL: + case CodeGenTarget::SPIRV: + case CodeGenTarget::SPIRVAssembly: + return &kGLSLLayoutRulesFamilyImpl; + + + case CodeGenTarget::CPPSource: + case CodeGenTarget::CSource: + { + // We just need to decide here what style of layout is appropriate, in terms of memory + // and binding. That in terms of the actual binding that will be injected into functions + // in the form of a BindContext. For now we'll go with HLSL layout - + // that we may want to rethink that with the use of arrays and binding VK style binding might be + // more appropriate in some ways. + + return &kHLSLLayoutRulesFamilyImpl; + } + + default: + return nullptr; + } +} + +TypeLayoutContext getInitialLayoutContextForTarget(TargetRequest* targetReq, ProgramLayout* programLayout) +{ + LayoutRulesFamilyImpl* rulesFamily = getDefaultLayoutRulesFamilyForTarget(targetReq); + + TypeLayoutContext context; + context.targetReq = targetReq; + context.programLayout = programLayout; + context.rules = nullptr; + context.matrixLayoutMode = targetReq->getDefaultMatrixLayoutMode(); + + if( rulesFamily ) + { + context.rules = rulesFamily->getConstantBufferRules(); + } + + return context; +} + + +static LayoutSize GetElementCount(RefPtr val) +{ + // Lack of a size indicates an unbounded array. + if(!val) + return LayoutSize::infinite(); + + if (auto constantVal = as(val)) + { + return LayoutSize(LayoutSize::RawValue(constantVal->value)); + } + else if( auto varRefVal = as(val) ) + { + // TODO: We want to treat the case where the number of + // elements in an array depends on a generic parameter + // much like the case where the number of elements is + // unbounded, *but* we can't just blindly do that because + // an API might disallow unbounded arrays in various + // cases where a generic bound might work (because + // any concrete specialization will have a finite bound...) + // + return 0; + } + SLANG_UNEXPECTED("unhandled integer literal kind"); + UNREACHABLE_RETURN(LayoutSize(0)); +} + +bool IsResourceKind(LayoutResourceKind kind) +{ + switch (kind) + { + case LayoutResourceKind::None: + case LayoutResourceKind::Uniform: + return false; + + default: + return true; + } + +} + + /// Create a type layout for a type that has simple layout needs. + /// + /// This handles any type that can express its layout in `SimpleLayoutInfo`, + /// and that only needs a `TypeLayout` and not a refined subclass. + /// +static TypeLayoutResult createSimpleTypeLayout( + SimpleLayoutInfo info, + RefPtr type, + LayoutRulesImpl* rules) +{ + RefPtr typeLayout = new TypeLayout(); + + typeLayout->type = type; + typeLayout->rules = rules; + + typeLayout->uniformAlignment = info.alignment; + + typeLayout->addResourceUsage(info.kind, info.size); + + return TypeLayoutResult(typeLayout, info); +} + +static SimpleLayoutInfo getParameterGroupLayoutInfo( + RefPtr type, + LayoutRulesImpl* rules) +{ + if( as(type) ) + { + return rules->GetObjectLayout(ShaderParameterKind::ConstantBuffer); + } + else if( as(type) ) + { + return rules->GetObjectLayout(ShaderParameterKind::TextureUniformBuffer); + } + else if( as(type) ) + { + return rules->GetObjectLayout(ShaderParameterKind::ShaderStorageBuffer); + } + else if (as(type)) + { + // Note: we default to consuming zero register spces here, because + // a parameter block might not contain anything (or all it contains + // is other blocks), and so it won't get a space allocated. + // + // This choice *also* means that in the case where we don't actually + // want to allocate register spaces to blocks at all, we haven't + // committed to that choice here. + // + // TODO: wouldn't it be any different to just allocate this + // as an empty `SimpleLayoutInfo` of any other kind? + return SimpleLayoutInfo(LayoutResourceKind::RegisterSpace, 0); + } + + // TODO: the vertex-input and fragment-output cases should + // only actually apply when we are at the appropriate stage in + // the pipeline... + else if( as(type) ) + { + return SimpleLayoutInfo(LayoutResourceKind::VertexInput, 0); + } + else if( as(type) ) + { + return SimpleLayoutInfo(LayoutResourceKind::FragmentOutput, 0); + } + else + { + SLANG_UNEXPECTED("unhandled parameter block type"); + UNREACHABLE_RETURN(SimpleLayoutInfo()); + } +} + +static bool isOpenGLTarget(TargetRequest*) +{ + // We aren't officially supporting OpenGL right now + return false; +} + +bool isD3DTarget(TargetRequest* targetReq) +{ + switch( targetReq->getTarget() ) + { + case CodeGenTarget::HLSL: + case CodeGenTarget::DXBytecode: + case CodeGenTarget::DXBytecodeAssembly: + case CodeGenTarget::DXIL: + case CodeGenTarget::DXILAssembly: + return true; + + default: + return false; + } +} + +bool isKhronosTarget(TargetRequest* targetReq) +{ + switch( targetReq->getTarget() ) + { + default: + return false; + + case CodeGenTarget::GLSL: + case CodeGenTarget::SPIRV: + case CodeGenTarget::SPIRVAssembly: + return true; + } +} + +static bool isD3D11Target(TargetRequest*) +{ + // We aren't officially supporting D3D11 right now + return false; +} + +static bool isD3D12Target(TargetRequest* targetReq) +{ + // We are currently only officially supporting D3D12 + return isD3DTarget(targetReq); +} + + +static bool isSM5OrEarlier(TargetRequest* targetReq) +{ + if(!isD3DTarget(targetReq)) + return false; + + auto profile = targetReq->getTargetProfile(); + + if(profile.getFamily() == ProfileFamily::DX) + { + if(profile.GetVersion() <= ProfileVersion::DX_5_0) + return true; + } + + return false; +} + +static bool isSM5_1OrLater(TargetRequest* targetReq) +{ + if(!isD3DTarget(targetReq)) + return false; + + auto profile = targetReq->getTargetProfile(); + + if(profile.getFamily() == ProfileFamily::DX) + { + if(profile.GetVersion() >= ProfileVersion::DX_5_1) + return true; + } + + return false; +} + +static bool isVulkanTarget(TargetRequest* targetReq) +{ + // For right now, any Khronos-related target is assumed + // to be a Vulkan target. + return isKhronosTarget(targetReq); +} + +static bool shouldAllocateRegisterSpaceForParameterBlock( + TypeLayoutContext const& context) +{ + auto targetReq = context.targetReq; + + // We *never* want to use register spaces/sets under + // OpenGL, D3D11, or for Shader Model 5.0 or earlier. + if(isOpenGLTarget(targetReq) || isD3D11Target(targetReq) || isSM5OrEarlier(targetReq)) + return false; + + // If we know that we are targetting Vulkan, then + // the only way to effectively use parameter blocks + // is by using descriptor sets. + if(isVulkanTarget(targetReq)) + return true; + + // If none of the above passed, then it seems like we + // are generating code for D3D12, and using SM5.1 or later. + // We will use a register space for parameter blocks *if* + // the target options tell us to: + if( isD3D12Target(targetReq) && isSM5_1OrLater(targetReq) ) + { + return true; + } + + return false; +} + +// Given an existing type layout `oldTypeLayout`, apply offsets +// to any contained fields based on the resource infos in `offsetVarLayout`. +RefPtr applyOffsetToTypeLayout( + RefPtr oldTypeLayout, + RefPtr offsetVarLayout) +{ + // There is no need to apply offsets if the old type and the offset + // don't share any resource infos in common. + bool anyHit = false; + for (auto oldResInfo : oldTypeLayout->resourceInfos) + { + if (auto offsetResInfo = offsetVarLayout->FindResourceInfo(oldResInfo.kind)) + { + anyHit = true; + break; + } + } + + if (!anyHit) + return oldTypeLayout; + + RefPtr newTypeLayout; + if (auto oldStructTypeLayout = oldTypeLayout.as()) + { + RefPtr newStructTypeLayout = new StructTypeLayout(); + newStructTypeLayout->type = oldStructTypeLayout->type; + newStructTypeLayout->uniformAlignment = oldStructTypeLayout->uniformAlignment; + + Dictionary mapOldFieldToNew; + + for (auto oldField : oldStructTypeLayout->fields) + { + RefPtr newField = new VarLayout(); + newField->varDecl = oldField->varDecl; + newField->typeLayout = oldField->typeLayout; + newField->flags = oldField->flags; + newField->semanticIndex = oldField->semanticIndex; + newField->semanticName = oldField->semanticName; + newField->stage = oldField->stage; + newField->systemValueSemantic = oldField->systemValueSemantic; + newField->systemValueSemanticIndex = oldField->systemValueSemanticIndex; + + + for (auto oldResInfo : oldField->resourceInfos) + { + auto newResInfo = newField->findOrAddResourceInfo(oldResInfo.kind); + newResInfo->index = oldResInfo.index; + newResInfo->space = oldResInfo.space; + if (auto offsetResInfo = offsetVarLayout->FindResourceInfo(oldResInfo.kind)) + { + newResInfo->index += offsetResInfo->index; + } + } + + newStructTypeLayout->fields.add(newField); + + mapOldFieldToNew.Add(oldField.Ptr(), newField.Ptr()); + } + + for (auto entry : oldStructTypeLayout->mapVarToLayout) + { + VarLayout* newFieldLayout = nullptr; + if (mapOldFieldToNew.TryGetValue(entry.Value.Ptr(), newFieldLayout)) + { + newStructTypeLayout->mapVarToLayout.Add(entry.Key, newFieldLayout); + } + } + + newTypeLayout = newStructTypeLayout; + } + else + { + // TODO: need to handle other cases here + return oldTypeLayout; + } + + // No matter what replacement we plug in for the element type, we need to copy + // over its resource usage: + for (auto oldResInfo : oldTypeLayout->resourceInfos) + { + auto newResInfo = newTypeLayout->findOrAddResourceInfo(oldResInfo.kind); + newResInfo->count = oldResInfo.count; + } + + return newTypeLayout; +} + +static bool _usesResourceKind(RefPtr typeLayout, LayoutResourceKind kind) +{ + auto resInfo = typeLayout->FindResourceInfo(kind); + return resInfo && resInfo->count != 0; +} + +static bool _usesOrdinaryData(RefPtr typeLayout) +{ + return _usesResourceKind(typeLayout, LayoutResourceKind::Uniform); +} + + /// Add resource usage from `srcTypeLayout` to `dstTypeLayout` unless it would be "masked." + /// + /// This function is appropriate for applying resource usage from an element type + /// to the resource usage of a container like a `ConstantBuffer` or + /// `ParameterBlock`. + /// +static void _addUnmaskedResourceUsage( + TypeLayout* dstTypeLayout, + TypeLayout* srcTypeLayout, + bool haveFullRegisterSpaceOrSet) +{ + for( auto resInfo : srcTypeLayout->resourceInfos ) + { + switch( resInfo.kind ) + { + case LayoutResourceKind::Uniform: + // Ordinary/uniform resource usage will always be masked. + break; + + case LayoutResourceKind::RegisterSpace: + case LayoutResourceKind::ExistentialTypeParam: + // A parameter group will always pay for full registers + // spaces consumed by its element type. + // + // The same is true for existential type parameters, + // since these need to be exposed up through the API. + // + dstTypeLayout->addResourceUsage(resInfo); + break; + + default: + // For all other resource kinds, a parameter group + // will be able to mask them if and only if it + // has a full space/set allocated to it. + // + // Otherwise, the resource usage of the group must + // include the resource usage of the element. + // + if( !haveFullRegisterSpaceOrSet ) + { + dstTypeLayout->addResourceUsage(resInfo); + } + break; + } + } +} + +static RefPtr _createParameterGroupTypeLayout( + TypeLayoutContext const& context, + RefPtr parameterGroupType, + RefPtr rawElementTypeLayout) +{ + // We are being asked to create a layout for a parameter group, + // which is curently either a `ParameterBlock` or a `ConstantBuffer` + // + auto parameterGroupRules = context.rules; + RefPtr typeLayout = new ParameterGroupTypeLayout(); + typeLayout->type = parameterGroupType; + typeLayout->rules = parameterGroupRules; + + // Computing the layout is made tricky by several factors. + // + // A parameter group has to draw a distinction between the element type, + // and the resources it consumes, and the "container," which main + // consume other resources. The type of resource consumed by + // the two can overlap. + // + // Consider: + // + // struct MyMaterial { float2 uvScale; Texture2D albedoMap; } + // ParameterBlock gMaterial; + // + // In this example, `gMaterial` will need both a constant buffer + // binding (to hold the data for `uvScale`) and a texture binding + // (for `albedoMap`). On Vulkan, those two things require the *same* + // `LayoutResourceKind` (representing a GLSL `binding`). We will + // thus track the resource usage of the "container" type and + // element type separately, and then combine these to form + // the overall layout for the parameter group. + + RefPtr containerTypeLayout = new TypeLayout(); + containerTypeLayout->type = parameterGroupType; + containerTypeLayout->rules = parameterGroupRules; + + // Because the container and element types will each be situated + // at some offset relative to the initial register/binding for + // the group as a whole, we allocate a `VarLayout` for both + // the container and the element type, to store that offset + // information (think of `TypeLayout`s as holding size information, + // while `VarLayout`s hold offset information). + + RefPtr containerVarLayout = new VarLayout(); + containerVarLayout->typeLayout = containerTypeLayout; + typeLayout->containerVarLayout = containerVarLayout; + + RefPtr elementVarLayout = new VarLayout(); + elementVarLayout->typeLayout = rawElementTypeLayout; + typeLayout->elementVarLayout = elementVarLayout; + + // It is possible to have a `ConstantBuffer` that doesn't + // actually need a constant buffer register/binding allocated to it, + // because the type `T` doesn't actually contain any ordinary/uniform + // data that needs to go into the constant buffer. For example: + // + // struct MyMaterial { Texture2D t; SamplerState s; }; + // ConstantBuffer gMaterial; + // + // In this example, the `gMaterial` parameter doesn't actually need + // a constant buffer allocated for it. This isn't something that + // comes up often for `ConstantBuffer`, but can happen a lot for + // `ParameterBlock`. + // + // To determine if we actually need a constant-buffer binding, + // we will inspect the element type and see if it contains + // any ordinary/uniform data. + // + bool wantConstantBuffer = _usesOrdinaryData(rawElementTypeLayout); + if( wantConstantBuffer ) + { + // If there is any ordinary data, then we'll need to + // allocate a constant buffer regiser/binding into + // the overall layout, to account for it. + // + auto cbUsage = parameterGroupRules->GetObjectLayout(ShaderParameterKind::ConstantBuffer); + containerTypeLayout->addResourceUsage(cbUsage.kind, cbUsage.size); + } + + // Similarly to how we only need a constant buffer to be allocated + // if the contents of the group actually had ordinary/uniform data, + // we also only want to allocate a `space` or `set` if that is really + // required. + // + // + bool canUseSpaceOrSet = false; + // + // We will only allocate a `space` or `set` if the type is `ParameterBlock` + // and not just `ConstantBuffer`. + // + // Note: `parameterGroupType` is allowed to be null here, if we are allocating + // an anonymous constant buffer for global or entry-point parameters, but that + // is fine because the case will just return null in that case anyway. + // + auto parameterBlockType = as(parameterGroupType); + if( parameterBlockType ) + { + // We also can't allocate a `space` or `set` unless the compilation + // target actually supports them. + // + if( shouldAllocateRegisterSpaceForParameterBlock(context) ) + { + canUseSpaceOrSet = true; + } + } + + // Just knowing that we *can* use a `space` or `set` doesn't tell + // us if we would *like* to. + // + // The basic rule here is that if the element type of the parameter + // block contains anything that isn't itself consuming a full + // register `space` or `set`, then we'll want an umbrella `space`/`set` + // for all such data. + // + bool wantSpaceOrSet = false; + if( canUseSpaceOrSet ) + { + // Note that if we are allocating a constant buffer to hold + // some ordinary/uniform data then we definitely want a space/set, + // but we don't need to special-case that because the loop + // here will also detect the `LayoutResourceKind::Uniform` usage. + + for( auto elementResourceInfo : rawElementTypeLayout->resourceInfos ) + { + if(elementResourceInfo.kind != LayoutResourceKind::RegisterSpace) + { + wantSpaceOrSet = true; + break; + } + } + } + + // If after all that we determine that we want a register space/set, + // then we allocate one as part of the overall resource usage for + // the parameter group type. + // + if( wantSpaceOrSet ) + { + containerTypeLayout->addResourceUsage(LayoutResourceKind::RegisterSpace, 1); + } + + // Now that we've computed basic resource requirements for the container + // part of things (i.e., does it require a constant buffer or not?), + // let's go ahead and assign the container variable a relative offset + // of zero for each of the kinds of resources that it consumes. + // + for( auto typeResInfo : containerTypeLayout->resourceInfos ) + { + containerVarLayout->findOrAddResourceInfo(typeResInfo.kind); + } + + // Because the container's resource allocation is logically coming + // first in the overall group, the element needs to have a layout + // such that it comes *after* the container in the relative order. + // + for( auto elementTypeResInfo : rawElementTypeLayout->resourceInfos ) + { + auto kind = elementTypeResInfo.kind; + auto elementVarResInfo = elementVarLayout->findOrAddResourceInfo(kind); + + // If the container part of things is using the same resource kind + // as the element type, then the element needs to start at an offset + // after the container. + // + if( auto containerTypeResInfo = containerTypeLayout->FindResourceInfo(kind) ) + { + SLANG_RELEASE_ASSERT(containerTypeResInfo->count.isFinite()); + elementVarResInfo->index += containerTypeResInfo->count.getFiniteValue(); + } + } + + // The existing Slang reflection API was created before we really + // understood the wrinkle that the "container" and elements parts + // of a parameter group could collide on some resource kinds, + // so the API doesn't currently expose the nice `VarLayout`s we've + // just computed. + // + // Instead, the API allows the user to query the element type layout + // for the group, and the user just assumes that the offsetting + // is magically applied there. To go back to the earlier example: + // + // struct MyMaterial { Texture2D t; SamplerState s; }; + // ConstantBuffer gMaterial; + // + // A user of the existing reflection API expects to be able to + // query the `binding` of `gMaterial` and get back zero, then + // query the `binding` of the `t` field of the element type + // and get *one*. It is clear that in the abstract, the + // `MyMaterial::t` field should have an offset of zero (as + // the first field in a `struct`), so to meet the user's + // expectations, some cleverness is needed. + // + // We will use a subroutine `applyOffsetToTypeLayout` + // that tries to recursively walk an existing `TypeLayout` + // and apply an offset to its fields. This is currently + // quite ad hoc, but that doesn't matter much as it + // handles `struct` types which are the 99% case for + // parameter blocks. + // + typeLayout->offsetElementTypeLayout = applyOffsetToTypeLayout(rawElementTypeLayout, elementVarLayout); + + // Next, resource usage from the container and element + // types may need to "bleed through" to the overall + // parameter group type. + // + // If the parameter group is a `ConstantBuffer` then + // any ordinary/uniform bytes consumed by `Foo` are masked, + // but any other resources it consumes (e.g. `binding`s) need + // to bleed through and be accounted for in the overall + // layout of the type. + // + // If we have a `ParameterBlock` then any ordinary/uniform + // bytes are masked. Furthermore, *if* a whole `space`/`set` + // was allocated to the block, then any `register`s or + // `binding`s consumed by `Foo` (and by the "container" constant + // buffer if we allocated one) are also masked. Any whole + // spaces/sets consumed by `Foo` need to bleed through. + // + // We can start with the easier case of the container type, + // since it will either be empty or consume a single constant + // buffer. Its resource usage will only bleed through if we + // didn't allocate a full `space` or `set`. + // + _addUnmaskedResourceUsage(typeLayout, containerTypeLayout, wantSpaceOrSet); + + // next we turn to the element type, where the cases are slightly + // more involved (technically we could use this same logic for + // the container, as it is more general, but it was simpler to + // just special-case the container). + // + + _addUnmaskedResourceUsage(typeLayout, rawElementTypeLayout, wantSpaceOrSet); + + // At this point we have handled all the complexities that + // arise for a parameter group that doesn't include interface-type + // fields, or that doesn't include specialization for those fields. + // + // The remaining complexity all arises if we have interface-type + // data in the parameter group, and we are specializing it to + // concrete types, that will have their own layout requirements. + // In those cases there will be "pending data" on the element + // type layout that need to get placed somwhere, but wasn't + // included in the layout computed so far. + // + // All of this is extra work we only have to do if there is + // "pending" data in the element type layout. + // + if( auto pendingElementTypeLayout = rawElementTypeLayout->pendingDataTypeLayout ) + { + auto rules = rawElementTypeLayout->rules; + + // One really annoying complication we need to deal with here + // its that it is possible that the original parameter group + // declaration didn't need a constant buffer or `space`/`set` + // to be allocated, but once we consider the "pending" data + // we need to have a constant buffer and/or space. + // + // We will compute whether the pending data create a demand + // for a constant buffer and/or a space/set, so that we know + // if we are in the tricky case. + // + bool pendingDataWantsConstantBuffer = _usesOrdinaryData(pendingElementTypeLayout); + bool pendingDataWantsSpaceOrSet = false; + if( canUseSpaceOrSet ) + { + for( auto resInfo : pendingElementTypeLayout->resourceInfos ) + { + if( resInfo.kind != LayoutResourceKind::RegisterSpace ) + { + pendingDataWantsSpaceOrSet = true; + break; + } + } + } + + // We will use a few different variables to track resource + // usage for the pending data, with roles similar to the + // umbrella type layout, container layout, and element layout + // that already came up for the main part of the parameter group type. + + + RefPtr pendingContainerTypeLayout = new TypeLayout(); + pendingContainerTypeLayout->type = parameterGroupType; + pendingContainerTypeLayout->rules = parameterGroupRules; + + containerTypeLayout->pendingDataTypeLayout = pendingContainerTypeLayout; + + RefPtr pendingContainerVarLayout = new VarLayout(); + pendingContainerVarLayout->typeLayout = pendingContainerTypeLayout; + + containerVarLayout->pendingVarLayout = pendingContainerVarLayout; + + + RefPtr pendingElementVarLayout = new VarLayout(); + pendingElementVarLayout->typeLayout = pendingElementTypeLayout; + + elementVarLayout->pendingVarLayout = pendingElementVarLayout; + + // If we need a space/set for the pending data, and don't already + // have one, then we will allocate it now, as part of the + // "full" data type. + // + if( pendingDataWantsSpaceOrSet && !wantSpaceOrSet ) + { + pendingContainerTypeLayout->addResourceUsage(LayoutResourceKind::RegisterSpace, 1); + + // From here on, we know we have access to a register space, + // and we can mask any registers/bindings appropriately. + // + wantSpaceOrSet = true; + } + + // If we need a constant buffer for laying out ordinary + // data, and didn't have one allocated before, we will create + // one. + // + if( pendingDataWantsConstantBuffer && !wantConstantBuffer ) + { + auto cbUsage = rules->GetObjectLayout(ShaderParameterKind::ConstantBuffer); + pendingContainerTypeLayout->addResourceUsage(cbUsage.kind, cbUsage.size); + + wantConstantBuffer = true; + } + + for( auto resInfo : pendingContainerTypeLayout->resourceInfos ) + { + pendingContainerVarLayout->findOrAddResourceInfo(resInfo.kind); + } + + // Now that we've added in the resource usage for any CB or set/space + // we needed to allocate just for the pending data, we can safely + // lay out the pending data itself. + // + // The ordinary/uniform part of things wil always be "masked" and + // needs to come after any uniform data from the original element type. + // + // To kick things off we will initialize state for `struct` type layout, + // so that we can lay out the pending data as if it were the second + // field in a structure type, after the original data. + // + UniformLayoutInfo uniformLayout = rules->BeginStructLayout(); + if( auto resInfo = rawElementTypeLayout->FindResourceInfo(LayoutResourceKind::Uniform) ) + { + uniformLayout.alignment = rawElementTypeLayout->uniformAlignment; + uniformLayout.size = resInfo->count; + } + + // Now we can scan through the resources used by the pending data. + // + for( auto resInfo : pendingElementTypeLayout->resourceInfos ) + { + if( resInfo.kind == LayoutResourceKind::Uniform ) + { + // For the ordinary/uniform resource kind, we will add the resource + // usage as a structure field, and then write the resulting offset + // into the variable layout for the pending data. + // + auto offset = rules->AddStructField( + &uniformLayout, + UniformLayoutInfo( + resInfo.count, + pendingElementTypeLayout->uniformAlignment)); + pendingElementVarLayout->findOrAddResourceInfo(resInfo.kind)->index = offset.getFiniteValue(); + } + else + { + // For all other resource kinds, we will set the offset in + // the variable layout based on the total resources of that + // kind seen so far (including the "container" if any), + // and then bump the count for total resource usage. + // + auto elementVarResInfo = pendingElementVarLayout->findOrAddResourceInfo(resInfo.kind); + if( auto containerTypeInfo = pendingContainerTypeLayout->FindResourceInfo(resInfo.kind) ) + { + elementVarResInfo->index = containerTypeInfo->count.getFiniteValue(); + } + } + } + rules->EndStructLayout(&uniformLayout); + + // Okay, now we have a `VarLayout` for the element data, and an overall `TypeLayout` + // for all the data that this parameter group needs allocated for pending + // data. + // + // The next major step is to compute the version of that combined resource usage + // that will "bleed through" and thus needs to be allocated at the next level + // up the hierarchy. + // + RefPtr unmaskedPendingDataTypeLayout = new TypeLayout(); + _addUnmaskedResourceUsage(unmaskedPendingDataTypeLayout, pendingContainerTypeLayout, wantSpaceOrSet); + _addUnmaskedResourceUsage(unmaskedPendingDataTypeLayout, pendingElementTypeLayout, wantSpaceOrSet); + + // TODO: we should probably optimize for the case where there is no unmasked + // usage that needs to be reported out, since it should be a common case. + + // Now we need to update the type layout to what we've done. + // + typeLayout->pendingDataTypeLayout = unmaskedPendingDataTypeLayout; + } + + return typeLayout; +} + + /// Do we need to wrap the given element type in a constant buffer layout? +static bool needsConstantBuffer(RefPtr elementTypeLayout) +{ + // We need a constant buffer if the element type has ordinary/uniform data. + // + if(_usesOrdinaryData(elementTypeLayout)) + return true; + + // We also need a constant buffer if there is any "pending" + // data that need ordinary/uniform data allocated to them. + // + if(auto pendingDataTypeLayout = elementTypeLayout->pendingDataTypeLayout) + { + if(_usesOrdinaryData(pendingDataTypeLayout)) + return true; + } + + return false; +} + +RefPtr createConstantBufferTypeLayoutIfNeeded( + TypeLayoutContext const& context, + RefPtr elementTypeLayout) +{ + // First things first, we need to check whether the element type + // we are trying to lay out even needs a constant buffer allocated + // for it. + // + if(!needsConstantBuffer(elementTypeLayout)) + return elementTypeLayout; + + auto parameterGroupRules = context.getRulesFamily()->getConstantBufferRules(); + + return _createParameterGroupTypeLayout( + context + .with(parameterGroupRules) + .with(context.targetReq->getDefaultMatrixLayoutMode()), + nullptr, + elementTypeLayout); +} + + +static RefPtr _createParameterGroupTypeLayout( + TypeLayoutContext const& context, + RefPtr parameterGroupType, + RefPtr elementType, + LayoutRulesImpl* elementTypeRules) +{ + // We will first compute a layout for the element type of + // the parameter group. + // + auto elementTypeLayout = createTypeLayout( + context.with(elementTypeRules), + elementType); + + // Now we delegate to a routine that does the meat of + // the complicated layout logic. + // + return _createParameterGroupTypeLayout( + context, + parameterGroupType, + elementTypeLayout); +} + +LayoutRulesImpl* getParameterBufferElementTypeLayoutRules( + RefPtr parameterGroupType, + LayoutRulesImpl* rules) +{ + if( as(parameterGroupType) ) + { + return rules->getLayoutRulesFamily()->getConstantBufferRules(); + } + else if( as(parameterGroupType) ) + { + return rules->getLayoutRulesFamily()->getTextureBufferRules(); + } + else if( as(parameterGroupType) ) + { + return rules->getLayoutRulesFamily()->getVaryingInputRules(); + } + else if( as(parameterGroupType) ) + { + return rules->getLayoutRulesFamily()->getVaryingOutputRules(); + } + else if( as(parameterGroupType) ) + { + return rules->getLayoutRulesFamily()->getShaderStorageBufferRules(); + } + else if (as(parameterGroupType)) + { + return rules->getLayoutRulesFamily()->getParameterBlockRules(); + } + else + { + SLANG_UNEXPECTED("uhandled parameter block type"); + return nullptr; + } +} + +RefPtr createParameterGroupTypeLayout( + TypeLayoutContext const& context, + RefPtr parameterGroupType) +{ + auto parameterGroupRules = context.rules; + + // Determine the layout rules to use for the contents of the block + auto elementTypeRules = getParameterBufferElementTypeLayoutRules( + parameterGroupType, + parameterGroupRules); + + auto elementType = parameterGroupType->elementType; + + return _createParameterGroupTypeLayout( + context, + parameterGroupType, + elementType, + elementTypeRules); +} + +// Create a type layout for a structured buffer type. +RefPtr +createStructuredBufferTypeLayout( + TypeLayoutContext const& context, + ShaderParameterKind kind, + RefPtr structuredBufferType, + RefPtr elementTypeLayout) +{ + auto rules = context.rules; + auto info = rules->GetObjectLayout(kind); + + auto typeLayout = new StructuredBufferTypeLayout(); + + typeLayout->type = structuredBufferType; + typeLayout->rules = rules; + + typeLayout->elementTypeLayout = elementTypeLayout; + + typeLayout->uniformAlignment = info.alignment; + SLANG_RELEASE_ASSERT(!typeLayout->FindResourceInfo(LayoutResourceKind::Uniform)); + SLANG_RELEASE_ASSERT(typeLayout->uniformAlignment == 1); + + if( info.size != 0 ) + { + typeLayout->addResourceUsage(info.kind, info.size); + } + + // Note: for now we don't deal with the case of a structured + // buffer that might contain anything other than "uniform" data, + // because there really isn't a way to implement that. + + return typeLayout; +} + +// Create a type layout for a structured buffer type. +RefPtr +createStructuredBufferTypeLayout( + TypeLayoutContext const& context, + ShaderParameterKind kind, + RefPtr structuredBufferType, + RefPtr elementType) +{ + // TODO(tfoley): we should be looking up the appropriate rules + // via the `LayoutRulesFamily` in use here... + auto structuredBufferLayoutRules = GetLayoutRulesImpl( + LayoutRule::HLSLStructuredBuffer); + + // Create and save type layout for the buffer contents. + auto elementTypeLayout = createTypeLayout( + context.with(structuredBufferLayoutRules), + elementType.Ptr()); + + return createStructuredBufferTypeLayout( + context, + kind, + structuredBufferType, + elementTypeLayout); + +} + + /// Create layout information for the given `type`. + /// + /// This internal routine returns both the constructed type + /// layout object and the simple layout info, encapsulated + /// together as a `TypeLayoutResult`. + /// +static TypeLayoutResult _createTypeLayout( + TypeLayoutContext const& context, + Type* type); + + /// Create layout information for the given `type`, obeying any layout modifiers on the given declaration. + /// + /// If `declForModifiers` has any matrix layout modifiers associated with it, then + /// the resulting type layout will respect those modifiers. + /// +static TypeLayoutResult _createTypeLayout( + TypeLayoutContext const& context, + Type* type, + Decl* declForModifiers) +{ + TypeLayoutContext subContext = context; + + if (declForModifiers) + { + // TODO: The approach implemented here has a row/column-major + // layout model recursively affect any sub-fields (so that + // the layout of a nested struct depends on the context where + // it is nested). This is consistent with the GLSL behavior + // for these modifiers, but it is *not* how HLSL is supposed + // to work. + // + // In the trivial case where `row_major` and `column_major` + // are only applied to leaf fields/variables of matrix type + // the difference should be immaterial. + + if (declForModifiers->HasModifier()) + subContext.matrixLayoutMode = kMatrixLayoutMode_RowMajor; + + if (declForModifiers->HasModifier()) + subContext.matrixLayoutMode = kMatrixLayoutMode_ColumnMajor; + + // TODO: really need to look for other modifiers that affect + // layout, such as GLSL `std140`. + } + + return _createTypeLayout(subContext, type); +} + +int findGenericParam(List> & genericParameters, GlobalGenericParamDecl * decl) +{ + return (int)genericParameters.findFirstIndex([=](RefPtr & x) {return x->decl.Ptr() == decl; }); +} + +// When constructing a new var layout from an existing one, +// copy fields to the new var from the old. +void copyVarLayoutFields( + VarLayout* dstVarLayout, + VarLayout* srcVarLayout) +{ + dstVarLayout->varDecl = srcVarLayout->varDecl; + dstVarLayout->typeLayout = srcVarLayout->typeLayout; + dstVarLayout->flags = srcVarLayout->flags; + dstVarLayout->systemValueSemantic = srcVarLayout->systemValueSemantic; + dstVarLayout->systemValueSemanticIndex = srcVarLayout->systemValueSemanticIndex; + dstVarLayout->semanticName = srcVarLayout->semanticName; + dstVarLayout->semanticIndex = srcVarLayout->semanticIndex; + dstVarLayout->stage = srcVarLayout->stage; + dstVarLayout->resourceInfos = srcVarLayout->resourceInfos; +} + +// When constructing a new type layout from an existing one, +// copy fields to the new type from the old. +void copyTypeLayoutFields( + TypeLayout* dstTypeLayout, + TypeLayout* srcTypeLayout) +{ + dstTypeLayout->type = srcTypeLayout->type; + dstTypeLayout->rules = srcTypeLayout->rules; + dstTypeLayout->uniformAlignment = srcTypeLayout->uniformAlignment; + dstTypeLayout->resourceInfos = srcTypeLayout->resourceInfos; +} + +// Does this layout resource kind require adjustment when used in +// an array-of-structs fashion? +bool doesResourceRequireAdjustmentForArrayOfStructs(LayoutResourceKind kind) +{ + switch( kind ) + { + case LayoutResourceKind::ConstantBuffer: + case LayoutResourceKind::ShaderResource: + case LayoutResourceKind::UnorderedAccess: + case LayoutResourceKind::SamplerState: + return true; + + default: + return false; + } +} + +// Given the type layout for an element of an array, apply any adjustments required +// based on the element count of the array. +// +// The particular case where this matters is when we have an array of an aggregate +// type that contains resources, since each resource field might need to be at +// a different offset than we would otherwise expect. +// +// For example, given: +// +// struct Foo { Texture2D a; Texture2D b; } +// +// if we just write: +// +// Foo foo; +// +// it gets split into: +// +// Texture2D foo_a; +// Texture2D foo_b; +// +// we expect `foo_a` to get `register(t0)` and +// `foo_b` to get `register(t1)`. However, if we instead have an array: +// +// Foo foo[10]; +// +// then we expect it to be split into: +// +// Texture2D foo_a[8]; +// Texture2D foo_b[8]; +// +// and then we expect `foo_b` to get `register(t8)`, rather +// than `register(t1)`. +// +static RefPtr maybeAdjustLayoutForArrayElementType( + RefPtr originalTypeLayout, + LayoutSize elementCount, + UInt& ioAdditionalSpacesNeeded) +{ + // We will start by looking for cases that we can reject out + // of hand. + + // If the original element type layout doesn't use any + // resource registers, then we are fine. + bool anyResource = false; + for( auto resInfo : originalTypeLayout->resourceInfos ) + { + if( doesResourceRequireAdjustmentForArrayOfStructs(resInfo.kind) ) + { + anyResource = true; + break; + } + } + if(!anyResource) + return originalTypeLayout; + + // Let's look at the type layout we have, and see if there is anything + // that we need to do with it. + // + if( auto originalArrayTypeLayout = originalTypeLayout.as() ) + { + // The element type is itself an array, so we'll need to adjust + // *its* element type accordingly. + // + // We adjust the already-adjusted element type of the inner + // array type, so that we pick up adjustments already made: + auto originalInnerElementTypeLayout = originalArrayTypeLayout->elementTypeLayout; + auto adjustedInnerElementTypeLayout = maybeAdjustLayoutForArrayElementType( + originalInnerElementTypeLayout, + elementCount, + ioAdditionalSpacesNeeded); + + // If nothing needed to be changed on the inner element type, + // then we are done. + if(adjustedInnerElementTypeLayout == originalInnerElementTypeLayout) + return originalTypeLayout; + + // Otherwise, we need to construct a new array type layout + RefPtr adjustedArrayTypeLayout = new ArrayTypeLayout(); + adjustedArrayTypeLayout->originalElementTypeLayout = originalInnerElementTypeLayout; + adjustedArrayTypeLayout->elementTypeLayout = adjustedInnerElementTypeLayout; + adjustedArrayTypeLayout->uniformStride = originalArrayTypeLayout->uniformStride; + + copyTypeLayoutFields(adjustedArrayTypeLayout, originalArrayTypeLayout); + + return adjustedArrayTypeLayout; + } + else if(auto originalParameterGroupTypeLayout = originalTypeLayout.as() ) + { + auto originalInnerElementTypeLayout = originalParameterGroupTypeLayout->elementVarLayout->typeLayout; + auto adjustedInnerElementTypeLayout = maybeAdjustLayoutForArrayElementType( + originalInnerElementTypeLayout, + elementCount, + ioAdditionalSpacesNeeded); + + // If nothing needed to be changed on the inner element type, + // then we are done. + if(adjustedInnerElementTypeLayout == originalInnerElementTypeLayout) + return originalTypeLayout; + + // TODO: actually adjust the element type, and create all the required bits and + // pieces of layout. + + SLANG_UNIMPLEMENTED_X("array of parameter group"); + UNREACHABLE_RETURN(originalTypeLayout); + } + else if(auto originalStructTypeLayout = originalTypeLayout.as() ) + { + Index fieldCount = originalStructTypeLayout->fields.getCount(); + + // Empty struct? Bail out. + if(fieldCount == 0) + return originalTypeLayout; + + RefPtr adjustedStructTypeLayout = new StructTypeLayout(); + copyTypeLayoutFields(adjustedStructTypeLayout, originalStructTypeLayout); + + // If the array type adjustment forces us to give a whole space to + // one or more fields, then we'll need to carefully compute the space + // index for each field as we go. + // + LayoutSize nextSpaceIndex = 0; + + Dictionary, RefPtr> mapOriginalFieldToAdjusted; + for( auto originalField : originalStructTypeLayout->fields ) + { + auto originalFieldTypeLayout = originalField->typeLayout; + + LayoutSize originalFieldSpaceCount = 0; + if(auto resInfo = originalFieldTypeLayout->FindResourceInfo(LayoutResourceKind::RegisterSpace)) + originalFieldSpaceCount = resInfo->count; + + // Compute the adjusted type for the field + UInt fieldAdditionalSpaces = 0; + auto adjustedFieldTypeLayout = maybeAdjustLayoutForArrayElementType( + originalFieldTypeLayout, + elementCount, + fieldAdditionalSpaces); + + LayoutSize adjustedFieldSpaceCount = originalFieldSpaceCount + fieldAdditionalSpaces; + + LayoutSize spaceOffsetForField = nextSpaceIndex; + nextSpaceIndex += adjustedFieldSpaceCount; + + ioAdditionalSpacesNeeded += fieldAdditionalSpaces; + + // Create an adjusted field variable, that is mostly + // a clone of the original field (just with our + // adjusted type in place). + RefPtr adjustedField = new VarLayout(); + copyVarLayoutFields(adjustedField, originalField); + adjustedField->typeLayout = adjustedFieldTypeLayout; + + // We will now walk through the resource usage for + // the adjusted field, and try to figure out what + // to do with it all. + // + for(auto& resInfo : adjustedField->resourceInfos ) + { + if( doesResourceRequireAdjustmentForArrayOfStructs(resInfo.kind) ) + { + if(elementCount.isFinite()) + { + // If the array size is finite, then the field's index/offset + // is just going to be strided by the array size since we + // are effectively doing AoS to SoA conversion. + // + resInfo.index *= elementCount.getFiniteValue(); + } + else + { + // If we are making an unbounded array, then a `struct` + // field with resource type will turn into its own space, + // and it will start at register zero in that space. + // + resInfo.index = 0; + resInfo.space = spaceOffsetForField.getFiniteValue(); + } + } + } + + adjustedStructTypeLayout->fields.add(adjustedField); + + mapOriginalFieldToAdjusted.Add(originalField, adjustedField); + } + + for( auto p : originalStructTypeLayout->mapVarToLayout ) + { + Decl* key = p.Key; + RefPtr originalVal = p.Value; + RefPtr adjustedVal; + if( mapOriginalFieldToAdjusted.TryGetValue(originalVal, adjustedVal) ) + { + adjustedStructTypeLayout->mapVarToLayout.Add(key, adjustedVal); + } + } + + return adjustedStructTypeLayout; + } + else + { + // In the leaf case, we must have a field that used up some resource + // that requires adjustment. Because there is no sub-structure to work + // with, we can just return the type layout as-is, but we also want + // to make a note that this value should consume an additional register + // space *if* the element count is unbounded. + if( elementCount.isInfinite() ) + { + ioAdditionalSpacesNeeded++; + } + + return originalTypeLayout; + } +} + + /// Convert a `TypeLayout` to a `TypeLayoutResult` + /// + /// A `TypeLayout` holds all the data needed to make a `TypeLayoutResult` in practice, + /// but sometimes it is more convenient to have the data split out. + /// +TypeLayoutResult makeTypeLayoutResult(RefPtr typeLayout) +{ + TypeLayoutResult result; + result.layout = typeLayout; + + // If the type only consumes a single kind of non-uniform resource, + // we can fill in the `info` field directly. + // + if( typeLayout->resourceInfos.getCount() == 1 ) + { + auto resInfo = typeLayout->resourceInfos[0]; + if( resInfo.kind != LayoutResourceKind::Uniform ) + { + result.info.kind = resInfo.kind; + result.info.size = resInfo.count; + return result; + } + } + + // Otherwise, we will fill out the info based on the uniform + // resources consumed, if any. + // + if( auto resInfo = typeLayout->FindResourceInfo(LayoutResourceKind::Uniform) ) + { + result.info.kind = LayoutResourceKind::Uniform; + result.info.alignment = typeLayout->uniformAlignment; + result.info.size = resInfo->count; + } + + // If there was no ordinary/uniform resource usage, then we + // will leave the `info` field in its default state (which + // shows no resources consumed). + // + // The type layout might have more detailed information, but + // at this point it must contain either zero, or more than one + // `ResourceInfo`, so there is nothing unambiguous we can + // store into `info`. + + return result; +} + +// +// StructTypeLayoutBuilder +// + +void StructTypeLayoutBuilder::beginLayout( + Type* type, + LayoutRulesImpl* rules) +{ + m_rules = rules; + + m_typeLayout = new StructTypeLayout(); + m_typeLayout->type = type; + m_typeLayout->rules = m_rules; + + m_info = m_rules->BeginStructLayout(); +} + +void StructTypeLayoutBuilder::beginLayoutIfNeeded( + Type* type, + LayoutRulesImpl* rules) +{ + if( !m_typeLayout ) + { + beginLayout(type, rules); + } +} + +RefPtr StructTypeLayoutBuilder::addField( + DeclRef field, + TypeLayoutResult fieldResult) +{ + SLANG_ASSERT(m_typeLayout); + + RefPtr fieldTypeLayout = fieldResult.layout; + UniformLayoutInfo fieldInfo = fieldResult.info.getUniformLayout(); + + // Note: we don't add any zero-size fields + // when computing structure layout, just + // to avoid having a resource type impact + // the final layout. + // + // This means that the code to generate final + // declarations needs to *also* eliminate zero-size + // fields to be safe... + // + LayoutSize uniformOffset = m_info.size; + if(fieldInfo.size != 0) + { + uniformOffset = m_rules->AddStructField(&m_info, fieldInfo); + } + + + // We need to create variable layouts + // for each field of the structure. + RefPtr fieldLayout = new VarLayout(); + fieldLayout->varDecl = field; + fieldLayout->typeLayout = fieldTypeLayout; + m_typeLayout->fields.add(fieldLayout); + + if( field ) + { + m_typeLayout->mapVarToLayout.Add(field.getDecl(), fieldLayout); + } + + // Set up uniform offset information, if there is any uniform data in the field + if( fieldTypeLayout->FindResourceInfo(LayoutResourceKind::Uniform) ) + { + fieldLayout->AddResourceInfo(LayoutResourceKind::Uniform)->index = uniformOffset.getFiniteValue(); + } + + // Add offset information for any other resource kinds + for( auto fieldTypeResourceInfo : fieldTypeLayout->resourceInfos ) + { + // Uniforms were dealt with above + if(fieldTypeResourceInfo.kind == LayoutResourceKind::Uniform) + continue; + + // We should not have already processed this resource type + SLANG_RELEASE_ASSERT(!fieldLayout->FindResourceInfo(fieldTypeResourceInfo.kind)); + + // The field will need offset information for this kind + auto fieldResourceInfo = fieldLayout->AddResourceInfo(fieldTypeResourceInfo.kind); + + // It is possible for a `struct` field to use an unbounded array + // type, and in the D3D case that would consume an unbounded number + // of registers. What is more, a single `struct` could have multiple + // such fields, or ordinary resource fields after an unbounded field. + // + // We handle this case by allocating a distinct register space for + // any field that consumes an unbounded amount of registers. + // + if( fieldTypeResourceInfo.count.isInfinite() ) + { + // We need to add one register space to own the storage for this field. + // + auto structTypeSpaceResourceInfo = m_typeLayout->findOrAddResourceInfo(LayoutResourceKind::RegisterSpace); + auto spaceOffset = structTypeSpaceResourceInfo->count; + structTypeSpaceResourceInfo->count += 1; + + // The field itself will record itself as having a zero offset into + // the chosen space. + // + fieldResourceInfo->space = spaceOffset.getFiniteValue(); + fieldResourceInfo->index = 0; + } + else + { + // In the case where the field consumes a finite number of slots, we + // can simply set its offset/index to the number of such slots consumed + // so far, and then increment the number of slots consumed by the + // `struct` type itself. + // + auto structTypeResourceInfo = m_typeLayout->findOrAddResourceInfo(fieldTypeResourceInfo.kind); + fieldResourceInfo->index = structTypeResourceInfo->count.getFiniteValue(); + structTypeResourceInfo->count += fieldTypeResourceInfo.count; + } + } + + return fieldLayout; +} + +RefPtr StructTypeLayoutBuilder::addField( + DeclRef field, + RefPtr fieldTypeLayout) +{ + TypeLayoutResult fieldResult = makeTypeLayoutResult(fieldTypeLayout); + return addField(field, fieldResult); +} + +void StructTypeLayoutBuilder::endLayout() +{ + if(!m_typeLayout) return; + + m_rules->EndStructLayout(&m_info); + + m_typeLayout->uniformAlignment = m_info.alignment; + m_typeLayout->addResourceUsage(LayoutResourceKind::Uniform, m_info.size); +} + +RefPtr StructTypeLayoutBuilder::getTypeLayout() +{ + return m_typeLayout; +} + +TypeLayoutResult StructTypeLayoutBuilder::getTypeLayoutResult() +{ + return TypeLayoutResult(m_typeLayout, m_info); +} + +static TypeLayoutResult _createTypeLayout( + TypeLayoutContext const& context, + Type* type) +{ + auto rules = context.rules; + + if (auto parameterGroupType = as(type)) + { + // If the user is just interested in uniform layout info, + // then this is easy: a `ConstantBuffer` is really no + // different from a `Texture2D` in terms of how it + // should be handled as a member of a container. + // + auto info = getParameterGroupLayoutInfo(parameterGroupType, rules); + + // The more interesting case, though, is when the user + // is requesting us to actually create a `TypeLayout`, + // since in that case we need to: + // + // 1. Compute a layout for the data inside the constant + // buffer, including offsets, etc. + // + // 2. Compute information about any object types inside + // the constant buffer, which need to be surfaces out + // to the top level. + // + auto typeLayout = createParameterGroupTypeLayout( + context, + parameterGroupType); + + return TypeLayoutResult(typeLayout, info); + } + else if (auto samplerStateType = as(type)) + { + return createSimpleTypeLayout( + rules->GetObjectLayout(ShaderParameterKind::SamplerState), + type, + rules); + } + else if (auto textureType = as(type)) + { + // TODO: the logic here should really be defined by the rules, + // and not at this top level... + ShaderParameterKind kind; + switch( textureType->getAccess() ) + { + default: + kind = ShaderParameterKind::MutableTexture; + break; + + case SLANG_RESOURCE_ACCESS_READ: + kind = ShaderParameterKind::Texture; + break; + } + + return createSimpleTypeLayout( + rules->GetObjectLayout(kind), + type, + rules); + } + else if (auto imageType = as(type)) + { + // TODO: the logic here should really be defined by the rules, + // and not at this top level... + ShaderParameterKind kind; + switch( imageType->getAccess() ) + { + default: + kind = ShaderParameterKind::MutableImage; + break; + + case SLANG_RESOURCE_ACCESS_READ: + kind = ShaderParameterKind::Image; + break; + } + + return createSimpleTypeLayout( + rules->GetObjectLayout(kind), + type, + rules); + } + else if (auto textureSamplerType = as(type)) + { + // TODO: the logic here should really be defined by the rules, + // and not at this top level... + ShaderParameterKind kind; + switch( textureSamplerType->getAccess() ) + { + default: + kind = ShaderParameterKind::MutableTextureSampler; + break; + + case SLANG_RESOURCE_ACCESS_READ: + kind = ShaderParameterKind::TextureSampler; + break; + } + + return createSimpleTypeLayout( + rules->GetObjectLayout(kind), + type, + rules); + } + + // TODO: need a better way to handle this stuff... +#define CASE(TYPE, KIND) \ + else if(auto type_##TYPE = as(type)) do { \ + auto info = rules->GetObjectLayout(ShaderParameterKind::KIND); \ + auto typeLayout = createStructuredBufferTypeLayout( \ + context, \ + ShaderParameterKind::KIND, \ + type_##TYPE, \ + type_##TYPE->elementType.Ptr()); \ + return TypeLayoutResult(typeLayout, info); \ + } while(0) + + CASE(HLSLStructuredBufferType, StructuredBuffer); + CASE(HLSLRWStructuredBufferType, MutableStructuredBuffer); + CASE(HLSLRasterizerOrderedStructuredBufferType, MutableStructuredBuffer); + CASE(HLSLAppendStructuredBufferType, MutableStructuredBuffer); + CASE(HLSLConsumeStructuredBufferType, MutableStructuredBuffer); + +#undef CASE + + + // TODO: need a better way to handle this stuff... +#define CASE(TYPE, KIND) \ + else if(as(type)) do { \ + return createSimpleTypeLayout( \ + rules->GetObjectLayout(ShaderParameterKind::KIND), \ + type, rules); \ + } while(0) + + CASE(HLSLByteAddressBufferType, RawBuffer); + CASE(HLSLRWByteAddressBufferType, MutableRawBuffer); + CASE(HLSLRasterizerOrderedByteAddressBufferType, MutableRawBuffer); + + CASE(GLSLInputAttachmentType, InputRenderTarget); + + // This case is mostly to allow users to add new resource types... + CASE(UntypedBufferResourceType, RawBuffer); + +#undef CASE + + else if(auto basicType = as(type)) + { + return createSimpleTypeLayout( + rules->GetScalarLayout(basicType->baseType), + type, + rules); + } + else if(auto vecType = as(type)) + { + auto elementType = vecType->elementType; + size_t elementCount = (size_t) GetIntVal(vecType->elementCount); + + auto element = _createTypeLayout( + context, + elementType); + + auto info = rules->GetVectorLayout(element.info, elementCount); + + RefPtr typeLayout = new VectorTypeLayout(); + typeLayout->type = type; + typeLayout->rules = rules; + typeLayout->uniformAlignment = info.alignment; + + typeLayout->elementTypeLayout = element.layout; + typeLayout->uniformStride = element.info.getUniformLayout().size.getFiniteValue(); + + typeLayout->addResourceUsage(info.kind, info.size); + + return TypeLayoutResult(typeLayout, info); + } + else if(auto matType = as(type)) + { + size_t rowCount = (size_t) GetIntVal(matType->getRowCount()); + size_t colCount = (size_t) GetIntVal(matType->getColumnCount()); + + auto elementType = matType->getElementType(); + auto elementResult = _createTypeLayout( + context, + elementType); + auto elementTypeLayout = elementResult.layout; + auto elementInfo = elementResult.info; + + // The `GetMatrixLayout` implementation in the layout rules + // currently defaults to assuming row-major layout, + // so if we want column-major layout we achieve it here by + // transposing the major/minor axes counts. + // + size_t layoutMajorCount = rowCount; + size_t layoutMinorCount = colCount; + if (context.matrixLayoutMode == kMatrixLayoutMode_ColumnMajor) + { + size_t tmp = layoutMajorCount; + layoutMajorCount = layoutMinorCount; + layoutMinorCount = tmp; + } + auto info = rules->GetMatrixLayout( + elementInfo, + layoutMajorCount, + layoutMinorCount); + + auto rowType = matType->getRowType(); + RefPtr rowTypeLayout = new VectorTypeLayout(); + + auto rowInfo = rules->GetVectorLayout( + elementInfo, + colCount); + + size_t majorStride = info.elementStride; + size_t minorStride = elementInfo.getUniformLayout().size.getFiniteValue(); + + size_t rowStride = 0; + size_t colStride = 0; + if(context.matrixLayoutMode == kMatrixLayoutMode_ColumnMajor) + { + colStride = majorStride; + rowStride = minorStride; + } + else + { + rowStride = majorStride; + colStride = minorStride; + } + + rowTypeLayout->type = type; + rowTypeLayout->rules = rules; + rowTypeLayout->uniformAlignment = elementInfo.getUniformLayout().alignment; + + rowTypeLayout->uniformStride = colStride; + rowTypeLayout->elementTypeLayout = elementTypeLayout; + rowTypeLayout->addResourceUsage(rowInfo.kind, rowInfo.size); + + RefPtr typeLayout = new MatrixTypeLayout(); + + typeLayout->type = type; + typeLayout->rules = rules; + typeLayout->uniformAlignment = info.alignment; + + typeLayout->elementTypeLayout = rowTypeLayout; + typeLayout->uniformStride = rowStride; + typeLayout->mode = context.matrixLayoutMode; + + typeLayout->addResourceUsage(info.kind, info.size); + + return TypeLayoutResult(typeLayout, info); + } + else if (auto arrayType = as(type)) + { + auto elementResult = _createTypeLayout( + context, + arrayType->baseType.Ptr()); + auto elementInfo = elementResult.info; + auto elementTypeLayout = elementResult.layout; + + // To a first approximation, an array will usually be laid out + // by taking the element's type layout and laying out `elementCount` + // copies of it. There are of course many details that make + // this simplistic version of things not quite work. + // + // An important complication to deal with is the possibility of + // having "unbounded" arrays, which don't specify a size.' + // The layout rules for these vary heavily by resource kind and API. + // + + auto elementCount = GetElementCount(arrayType->ArrayLength); + + // + // We can compute the uniform storage layout of an array using + // the rules for the target API. + // + // TODO: ensure that this does something reasonable with the unbounded + // case, or else issue an error message that the target doesn't + // support unbounded types. + // + + auto arrayUniformInfo = rules->GetArrayLayout( + elementInfo, + elementCount).getUniformLayout(); + + RefPtr typeLayout = new ArrayTypeLayout(); + + // Some parts of the array type layout object are easy to fill in: + typeLayout->type = type; + typeLayout->rules = rules; + typeLayout->originalElementTypeLayout = elementTypeLayout; + typeLayout->uniformAlignment = arrayUniformInfo.alignment; + typeLayout->uniformStride = arrayUniformInfo.elementStride; + + typeLayout->addResourceUsage(LayoutResourceKind::Uniform, arrayUniformInfo.size); + + // + // The tricky part in constructing an array type layout comes when + // the element type is (or nests) a structure with resource-type + // fields, because in that case we need to perform AoS-to-SoA + // conversion as part of computing the final type layout, and + // we also need to pre-compute an "adjusted" element type + // layout that accounts for the striding that happens with + // resource-type contents. + // + // This complication is only made worse when we have to deal with + // unbounded-size arrays over such element types, since those + // resource-type fields will each end up consuming a full space + // in the resulting layout. + // + // The `maybeAdjustLayoutForArrayElementType` computes an "adjusted" + // type layout for the element type which takes the array stride into + // account. If it returns the same type layout that was passed in, + // then that means no adjustement took place. + // + // The `additionalSpacesNeededForAdjustedElementType` variable counts + // the number of additional register spaces that were consumed, + // in the case of an unbounded array. + // + UInt additionalSpacesNeededForAdjustedElementType = 0; + RefPtr adjustedElementTypeLayout = maybeAdjustLayoutForArrayElementType( + elementTypeLayout, + elementCount, + additionalSpacesNeededForAdjustedElementType); + + typeLayout->elementTypeLayout = adjustedElementTypeLayout; + + // We will now iterate over the resources consumed by the element + // type to compute how they contribute to the resource usage + // of the overall array type. + // + for( auto elementResourceInfo : elementTypeLayout->resourceInfos ) + { + // The uniform case was already handled above + if( elementResourceInfo.kind == LayoutResourceKind::Uniform ) + continue; + + LayoutSize arrayResourceCount = 0; + + // In almost all cases, the resources consumed by an array + // will be its element count times the resources consumed + // by its element type. + // + // The first exception to this is arrays of resources when + // compiling to GLSL for Vulkan, where an entire array + // only consumes a single descriptor-table slot. + // + if (elementResourceInfo.kind == LayoutResourceKind::DescriptorTableSlot) + { + arrayResourceCount = elementResourceInfo.count; + } + // + // The next big exception is when we are forming an unbounded-size + // array and the element type got "adjusted," because that means + // the array type will need to allocate full spaces for any resource-type + // fields in the element type. + // + // Note: we carefully carve things out so that the case of a simple + // array of resources does *not* lead to the element type being adjusted, + // so that this logic doesn't trigger and we instead handle it with + // the default logic below. + // + else if( + elementCount.isInfinite() + && adjustedElementTypeLayout != elementTypeLayout + && doesResourceRequireAdjustmentForArrayOfStructs(elementResourceInfo.kind) ) + { + // We want to ignore resource types consumed by the element type + // that need adjustement if the array size is infinite, since + // we will be allocating whole spaces for that part of the + // element's resource usage. + } + else + { + arrayResourceCount = elementResourceInfo.count * elementCount; + } + + // Now that we've computed how the resource usage of the element type + // should contribute to the resource usage of the array, we can + // add in that resource usage. + // + typeLayout->addResourceUsage( + elementResourceInfo.kind, + arrayResourceCount); + } + + // The loop above to compute the resource usage of the array from its + // element type ignored any resource-type fields in an unbounded-size + // array if they would have been allocated as full register spaces. + // Those same fields were counted in `additionalSpacesNeededForAdjustedElementType`, + // and need to be added into the total resource usage for the array + // if we skipped them as part of the loop (which happens when + // we detect that the element type layout had been "adjusted"). + // + if( adjustedElementTypeLayout != elementTypeLayout ) + { + typeLayout->addResourceUsage(LayoutResourceKind::RegisterSpace, additionalSpacesNeededForAdjustedElementType); + } + + return TypeLayoutResult(typeLayout, arrayUniformInfo); + } + else if (auto declRefType = as(type)) + { + auto declRef = declRefType->declRef; + + if (auto structDeclRef = declRef.as()) + { + StructTypeLayoutBuilder typeLayoutBuilder; + StructTypeLayoutBuilder pendingDataTypeLayoutBuilder; + + typeLayoutBuilder.beginLayout(type, rules); + auto typeLayout = typeLayoutBuilder.getTypeLayout(); + for (auto field : GetFields(structDeclRef)) + { + // Static fields shouldn't take part in layout. + if(field.getDecl()->HasModifier()) + continue; + + // The fields of a `struct` type may include existential (interface) + // types (including as nested sub-fields), and any types present + // in those fields will need to be specialized based on the + // input arguments being passed to `_createTypeLayout`. + // + // We won't know how many type slots each field consumes until + // we process it, but we can figure out the starting index for + // the slots its will consume by looking at the layout we've + // computed so far. + // + Int baseExistentialSlotIndex = 0; + if(auto resInfo = typeLayout->FindResourceInfo(LayoutResourceKind::ExistentialTypeParam)) + baseExistentialSlotIndex = Int(resInfo->count.getFiniteValue()); + // + // When computing the layout for the field, we will give it access + // to all the incoming specialized type slots that haven't already + // been consumed/claimed by preceding fields. + // + auto fieldLayoutContext = context.withExistentialTypeSlotsOffsetBy(baseExistentialSlotIndex); + + TypeLayoutResult fieldResult = _createTypeLayout( + fieldLayoutContext, + GetType(field).Ptr(), + field.getDecl()); + auto fieldTypeLayout = fieldResult.layout; + + auto fieldVarLayout = typeLayoutBuilder.addField(field, fieldResult); + + // If any of the fields of the `struct` type had existential/interface + // type, then we need to compute a second `StructTypeLayout` that + // represents the layout and resource using for the "pending data" + // that this type needs to have stored somewhere, but which can't + // be laid out in the layout of the type itself. + // + if(auto fieldPendingDataTypeLayout = fieldTypeLayout->pendingDataTypeLayout) + { + // We only create this secondary layout on-demand, so that + // we don't end up with a bunch of empty structure type layouts + // created for no reason. + // + pendingDataTypeLayoutBuilder.beginLayoutIfNeeded(type, rules); + auto fieldPendingVarLayout = pendingDataTypeLayoutBuilder.addField(field, fieldPendingDataTypeLayout); + fieldVarLayout->pendingVarLayout = fieldPendingVarLayout; + } + } + + typeLayoutBuilder.endLayout(); + pendingDataTypeLayoutBuilder.endLayout(); + + if( auto pendingDataTypeLayout = pendingDataTypeLayoutBuilder.getTypeLayout() ) + { + typeLayout->pendingDataTypeLayout = pendingDataTypeLayout; + } + + return typeLayoutBuilder.getTypeLayoutResult(); + } + else if (auto globalGenParam = declRef.as()) + { + SimpleLayoutInfo info; + info.alignment = 0; + info.size = 0; + info.kind = LayoutResourceKind::GenericResource; + + auto genParamTypeLayout = new GenericParamTypeLayout(); + // we should have already populated ProgramLayout::genericEntryPointParams list at this point, + // so we can find the index of this generic param decl in the list + genParamTypeLayout->type = type; + genParamTypeLayout->paramIndex = findGenericParam(context.programLayout->globalGenericParams, genParamTypeLayout->getGlobalGenericParamDecl()); + genParamTypeLayout->rules = rules; + genParamTypeLayout->findOrAddResourceInfo(LayoutResourceKind::GenericResource)->count += 1; + + return TypeLayoutResult(genParamTypeLayout, info); + } + else if (auto assocTypeParam = declRef.as()) + { + return createSimpleTypeLayout( + SimpleLayoutInfo(), + type, + rules); + } + else if( auto simpleGenericParam = declRef.as() ) + { + // A bare generic type parameter can come up during layout + // of a generic entry point (or an entry point nested in + // a generic type). For now we will just pretend like + // the fields of generic parameter type take no space, + // since there is no reasonable way to account for them + // in the resulting layout. + // + // TODO: It might be better to completely ignore generic + // entry points during initial layout, but doing so would + // mean that users couldn't get layout information on + // any parameters, even those that don't depend on + // generics. + // + return createSimpleTypeLayout( + SimpleLayoutInfo(), + type, + rules); + } + else if( auto interfaceDeclRef = declRef.as() ) + { + // When laying out a type that includes interface-type fields, + // we cannot know how much space the concrete type that + // gets stored into the field consumes. + // + // If we were doing layout for a typical CPU target, then + // we could just say that each interface-type field consumes + // some fixed number of pointers (e.g., a data pointer plus a witness + // table pointer). + // + // We will borrow the intuition from that and invent a new + // resource kind for "existential slots" which conceptually + // represents the indirections needed to reference the + // data to be referenced by this field. + // + + RefPtr typeLayout = new TypeLayout(); + typeLayout->type = type; + typeLayout->rules = rules; + + typeLayout->addResourceUsage(LayoutResourceKind::ExistentialTypeParam, 1); + typeLayout->addResourceUsage(LayoutResourceKind::ExistentialObjectParam, 1); + + // If there are any concrete types available, the first one will be + // the value that should be plugged into the slot we just introduced. + // + if( context.existentialTypeArgCount ) + { + RefPtr concreteType = context.existentialTypeArgs[0].type; + + RefPtr concreteTypeLayout = createTypeLayout(context, concreteType); + + // Layout for this specialized interface type then results + // in a type layout that tracks both the resource usage of the + // interface type itself (just the type + value slots introduced + // above), plus a "pending data" type that represents the value + // conceptually pointed to by the interface-type field/variable at runtime. + // + typeLayout->pendingDataTypeLayout = concreteTypeLayout; + } + + return TypeLayoutResult(typeLayout, SimpleLayoutInfo()); + } + } + else if (auto errorType = as(type)) + { + // An error type means that we encountered something we don't understand. + // + // We should probably inform the user with an error message here. + + return createSimpleTypeLayout( + SimpleLayoutInfo(), + type, + rules); + } + else if( auto taggedUnionType = as(type) ) + { + // A tagged union type needs to be laid out as the maximum + // size of any constituent type. + // + // In practice, only a tagged union of uniform data will + // work, but for now we will compute the maximum usage + // for each resource kind for generality. + // + // For the uniform data we will start with a size + // of zero and an alignment of one for our base case + // (this is what a tagged union of no cases would consume). + // + UniformLayoutInfo info(0, 1); + + RefPtr taggedUnionLayout = new TaggedUnionTypeLayout(); + taggedUnionLayout->type = type; + taggedUnionLayout->rules = rules; + + // Now we iterate over the case types and see if they + // change our computed maximum size/alignement. + // + for( auto caseType : taggedUnionType->caseTypes ) + { + // Note: A tagged union type is not expected to have any existential/interface type + // slots; the case types that are provided must be fully specialized before the union is + // formed. Thus we don't need to mess around with existential type slots here the + // way we do for the `struct` case. + + auto caseTypeResult = _createTypeLayout(context, caseType); + RefPtr caseTypeLayout = caseTypeResult.layout; + UniformLayoutInfo caseTypeInfo = caseTypeResult.info.getUniformLayout(); + + info.size = maximum(info.size, caseTypeInfo.size); + info.alignment = std::max(info.alignment, caseTypeInfo.alignment); + + // We need to remember the layout of the case type + // on the final `TaggedUnionTypeLayout`. + // + taggedUnionLayout->caseTypeLayouts.add(caseTypeLayout); + + // We also need to consider contributions for other + // resource kinds beyond uniform data. + // + for( auto caseResInfo : caseTypeLayout->resourceInfos ) + { + auto unionResInfo = taggedUnionLayout->findOrAddResourceInfo(caseResInfo.kind); + unionResInfo->count = maximum(unionResInfo->count, caseResInfo.count); + } + } + + // After we've computed the size required to hold all the + // case types, we will allocate space for the tag field. + // + // TODO: This assumes the tag will always be allocated out + // of uniform storage, which means we can't support a tagged + // union as part of a varying input/output signature. That is + // probably a valid limitation, but it should get enforced + // somewhere along the way. + // + { + // The tag is always a `uint` for now. + // + auto tagInfo = context.rules->GetScalarLayout(BaseType::UInt); + info.size = RoundToAlignment(info.size, tagInfo.alignment); + + taggedUnionLayout->tagOffset = info.size; + + info.size += tagInfo.size; + info.alignment = std::max(info.alignment, tagInfo.alignment); + } + + // As a final step, if we are computing a full `TypeLayout` + // we will make sure that its information on uniform layout + // matches what we've computed in the `UniformLayoutInfo` we return. + // + taggedUnionLayout->findOrAddResourceInfo(LayoutResourceKind::Uniform)->count = info.size; + taggedUnionLayout->uniformAlignment = info.alignment; + + return TypeLayoutResult(taggedUnionLayout, info); + } + else if( auto existentialSpecializedType = as(type) ) + { + TypeLayoutContext subContext = context.withExistentialTypeArgs( + existentialSpecializedType->slots.args.getCount(), + existentialSpecializedType->slots.args.getBuffer()); + + auto baseTypeLayoutResult = _createTypeLayout( + subContext, + existentialSpecializedType->baseType); + + UniformLayoutInfo info = rules->BeginStructLayout(); + rules->AddStructField(&info, baseTypeLayoutResult.info.getUniformLayout()); + + RefPtr typeLayout = new ExistentialSpecializedTypeLayout(); + typeLayout->type = type; + typeLayout->rules = rules; + + RefPtr pendingDataVarLayout = new VarLayout(); + if(auto pendingDataTypeLayout = baseTypeLayoutResult.layout->pendingDataTypeLayout) + { + for( auto pendingResInfo : pendingDataTypeLayout->resourceInfos ) + { + auto kind = pendingResInfo.kind; + UInt index = 0; + if( kind == LayoutResourceKind::Uniform ) + { + LayoutSize uniformOffset = rules->AddStructField( + &info, + makeTypeLayoutResult(pendingDataTypeLayout).info.getUniformLayout()); + + index = uniformOffset.getFiniteValue(); + } + else + { + if(auto primaryResInfo = baseTypeLayoutResult.layout->FindResourceInfo(kind)) + index = primaryResInfo->count.getFiniteValue(); + } + pendingDataVarLayout->AddResourceInfo(kind)->index = index; + } + } + + typeLayout->baseTypeLayout = baseTypeLayoutResult.layout; + typeLayout->pendingDataVarLayout = pendingDataVarLayout; + + return makeTypeLayoutResult(typeLayout); + } + + // catch-all case in case nothing matched + SLANG_ASSERT(!"unimplemented case in type layout"); + return createSimpleTypeLayout( + SimpleLayoutInfo(), + type, + rules); +} + +RefPtr getSimpleVaryingParameterTypeLayout( + TypeLayoutContext const& context, + Type* type, + EntryPointParameterDirectionMask directionMask) +{ + auto rules = context.rules; + + // TODO: This logic should ideally share as much + // as possible with the `_createTypeLayout` function, + // to avoid duplication, but we also have to deal + // with the many ways in which varying parameter + // layout differs from non-varying layout. + + // We will compute resource consumption for the type + // as a varying input, output, or both/neither. + // To avoid duplication, we'll build an array that + // includes all the layout rules we need to apply. + // + int varyingRulesCount = 0; + LayoutRulesImpl* varyingRules[2]; + + if( directionMask & kEntryPointParameterDirection_Input ) + { + varyingRules[varyingRulesCount++] = context.getRulesFamily()->getVaryingInputRules(); + } + if( directionMask & kEntryPointParameterDirection_Output ) + { + varyingRules[varyingRulesCount++] = context.getRulesFamily()->getVaryingOutputRules(); + } + + if(auto basicType = as(type)) + { + auto baseType = basicType->baseType; + + RefPtr typeLayout = new TypeLayout(); + typeLayout->type = type; + typeLayout->rules = rules; + + for( int rr = 0; rr < varyingRulesCount; ++rr ) + { + auto info = varyingRules[rr]->GetScalarLayout(baseType); + typeLayout->addResourceUsage(info.kind, info.size); + } + + return typeLayout; + } + else if(auto vecType = as(type)) + { + auto elementType = vecType->elementType; + size_t elementCount = (size_t) GetIntVal(vecType->elementCount); + + BaseType elementBaseType = BaseType::Void; + if( auto elementBasicType = as(elementType) ) + { + elementBaseType = elementBasicType->baseType; + } + + // Note that we do *not* add any resource usage to the type + // layout for the element type, because we currently cannot count + // varying parameter usage at a granularity finer than + // individual "locations." + // + RefPtr elementTypeLayout = new TypeLayout(); + elementTypeLayout->type = elementType; + elementTypeLayout->rules = rules; + + RefPtr typeLayout = new VectorTypeLayout(); + typeLayout->type = vecType; + typeLayout->rules = rules; + typeLayout->elementTypeLayout = elementTypeLayout; + + for( int rr = 0; rr < varyingRulesCount; ++rr ) + { + auto varyingRuleSet = varyingRules[rr]; + auto elementInfo = varyingRuleSet->GetScalarLayout(elementBaseType); + auto info = varyingRuleSet->GetVectorLayout(elementInfo, elementCount); + typeLayout->addResourceUsage(info.kind, info.size); + } + + return typeLayout; + } + else if(auto matType = as(type)) + { + size_t rowCount = (size_t) GetIntVal(matType->getRowCount()); + size_t colCount = (size_t) GetIntVal(matType->getColumnCount()); + auto elementType = matType->getElementType(); + + BaseType elementBaseType = BaseType::Void; + if( auto elementBasicType = as(elementType) ) + { + elementBaseType = elementBasicType->baseType; + } + + // Just as for `_createTypeLayout`, we need to handle row- and + // column-major matrices differently, to ensure we get + // the expected layout. + // + // A varying parameter with row-major layout is effectively + // just an array of row vectors, while a column-major one + // is just an array of column vectors. + // + size_t layoutMajorCount = rowCount; + size_t layoutMinorCount = colCount; + if (context.matrixLayoutMode == kMatrixLayoutMode_ColumnMajor) + { + size_t tmp = layoutMajorCount; + layoutMajorCount = layoutMinorCount; + layoutMinorCount = tmp; + } + + RefPtr elementTypeLayout = new TypeLayout(); + elementTypeLayout->type = elementType; + elementTypeLayout->rules = rules; + + RefPtr rowTypeLayout = new VectorTypeLayout(); + rowTypeLayout->type = matType->getRowType(); + rowTypeLayout->rules = rules; + rowTypeLayout->elementTypeLayout = elementTypeLayout; + + RefPtr typeLayout = new MatrixTypeLayout(); + typeLayout->type = type; + typeLayout->rules = rules; + typeLayout->elementTypeLayout = rowTypeLayout; + typeLayout->mode = context.matrixLayoutMode; + + for( int rr = 0; rr < varyingRulesCount; ++rr ) + { + auto varyingRuleSet = varyingRules[rr]; + auto elementInfo = varyingRuleSet->GetScalarLayout(elementBaseType); + + auto info = varyingRuleSet->GetMatrixLayout(elementInfo, layoutMajorCount, layoutMinorCount); + typeLayout->addResourceUsage(info.kind, info.size); + + if(context.matrixLayoutMode == kMatrixLayoutMode_RowMajor) + { + // For row-major matrices only, we can compute an effective + // resource usage for the row type. + auto rowInfo = varyingRuleSet->GetVectorLayout(elementInfo, colCount); + rowTypeLayout->addResourceUsage(rowInfo.kind, rowInfo.size); + } + } + + return typeLayout; + } + + // catch-all case in case nothing matched + SLANG_ASSERT(!"unimplemented case for varying parameter layout"); + return createSimpleTypeLayout( + SimpleLayoutInfo(), + type, + rules).layout; +} + +RefPtr createTypeLayout( + TypeLayoutContext const& context, + Type* type) +{ + return _createTypeLayout(context, type).layout; +} + +void TypeLayout::addResourceUsageFrom(TypeLayout* otherTypeLayout) +{ + for(auto resInfo : otherTypeLayout->resourceInfos) + addResourceUsage(resInfo); +} + + +RefPtr TypeLayout::unwrapArray() +{ + TypeLayout* typeLayout = this; + + while(auto arrayTypeLayout = as(typeLayout)) + typeLayout = arrayTypeLayout->elementTypeLayout; + + return typeLayout; +} + + +RefPtr GenericParamTypeLayout::getGlobalGenericParamDecl() +{ + auto declRefType = as(type); + SLANG_ASSERT(declRefType); + auto rsDeclRef = declRefType->declRef.as(); + return rsDeclRef.getDecl(); +} + +} // namespace Slang diff --git a/source/slang/slang-type-layout.h b/source/slang/slang-type-layout.h new file mode 100644 index 000000000..97113c77f --- /dev/null +++ b/source/slang/slang-type-layout.h @@ -0,0 +1,1118 @@ +#ifndef SLANG_TYPE_LAYOUT_H +#define SLANG_TYPE_LAYOUT_H + +#include "../core/slang-basic.h" +#include "slang-compiler.h" +#include "slang-profile.h" +#include "slang-syntax.h" + +#include "../../slang.h" + +namespace Slang { + +// Forward declarations + +enum class BaseType; +class Type; + +// + +enum class LayoutRule +{ + Std140, + Std430, + HLSLConstantBuffer, + HLSLStructuredBuffer, +}; + +#if 0 +enum class LayoutRulesFamily +{ + HLSL, + GLSL, +}; +#endif + +// A "size" that can either be a simple finite size or +// the special case of an infinite/unbounded size. +// +struct LayoutSize +{ + typedef size_t RawValue; + + LayoutSize() + : raw(0) + {} + + LayoutSize(RawValue size) + : raw(size) + { + SLANG_ASSERT(size != RawValue(-1)); + } + + static LayoutSize infinite() + { + LayoutSize result; + result.raw = RawValue(-1); + return result; + } + + bool isInfinite() const { return raw == RawValue(-1); } + + bool isFinite() const { return raw != RawValue(-1); } + RawValue getFiniteValue() const { SLANG_ASSERT(isFinite()); return raw; } + + bool operator==(LayoutSize that) const + { + return raw == that.raw; + } + + bool operator!=(LayoutSize that) const + { + return raw != that.raw; + } + + void operator+=(LayoutSize right) + { + if( isInfinite() ) {} + else if( right.isInfinite() ) + { + *this = LayoutSize::infinite(); + } + else + { + *this = LayoutSize(raw + right.raw); + } + } + + void operator*=(LayoutSize right) + { + // Deal with zero first, so that anything (even the "infinite" value) times zero is zero. + if( raw == 0 ) + { + return; + } + + if( right.raw == 0 ) + { + raw = 0; + return; + } + + // Next we deal with infinite cases, so that infinite times anything non-zero is infinite + if( isInfinite() ) + { + return; + } + + if( right.isInfinite() ) + { + *this = LayoutSize::infinite(); + return; + } + + // Finally deal with the case where both sides are finite + *this = LayoutSize(raw * right.raw); + } + + void operator-=(RawValue right) + { + if( isInfinite() ) {} + else + { + *this = LayoutSize(raw - right); + } + } + + void operator/=(RawValue right) + { + if( isInfinite() ) {} + else + { + *this = LayoutSize(raw / right); + } + } + RawValue raw; +}; + +inline LayoutSize operator+(LayoutSize left, LayoutSize right) +{ + LayoutSize result(left); + result += right; + return result; +} + +inline LayoutSize operator*(LayoutSize left, LayoutSize right) +{ + LayoutSize result(left); + result *= right; + return result; +} + +inline LayoutSize operator-(LayoutSize left, LayoutSize::RawValue right) +{ + LayoutSize result(left); + result -= right; + return result; +} + +inline LayoutSize operator/(LayoutSize left, LayoutSize::RawValue right) +{ + LayoutSize result(left); + result /= right; + return result; +} + +inline LayoutSize maximum(LayoutSize left, LayoutSize right) +{ + if(left.isInfinite() || right.isInfinite()) + return LayoutSize::infinite(); + + return LayoutSize(Math::Max( + left.getFiniteValue(), + right.getFiniteValue())); +} + +inline bool operator>(LayoutSize left, LayoutSize::RawValue right) +{ + return left.isInfinite() || (left.getFiniteValue() > right); +} + +inline bool operator<=(LayoutSize left, LayoutSize::RawValue right) +{ + return left.isFinite() && (left.getFiniteValue() <= right); +} + +// Layout appropriate to "just memory" scenarios, +// such as laying out the members of a constant buffer. +struct UniformLayoutInfo +{ + LayoutSize size; + size_t alignment; + + UniformLayoutInfo() + : size(0) + , alignment(1) + {} + + UniformLayoutInfo( + LayoutSize size, + size_t alignment) + : size(size) + , alignment(alignment) + {} +}; + +// Extended information required for an array of uniform data, +// including the "stride" of the array (the space between +// consecutive elements). +struct UniformArrayLayoutInfo : UniformLayoutInfo +{ + size_t elementStride; + + UniformArrayLayoutInfo() + : elementStride(0) + {} + + UniformArrayLayoutInfo( + LayoutSize size, + size_t alignment, + size_t elementStride) + : UniformLayoutInfo(size, alignment) + , elementStride(elementStride) + {} +}; + +typedef slang::ParameterCategory LayoutResourceKind; + +// Layout information for a value that only consumes +// a single resource kind. +struct SimpleLayoutInfo +{ + // What kind of resource should we consume? + LayoutResourceKind kind; + + // How many resources of that kind? + LayoutSize size; + + // only useful in the uniform case + size_t alignment; + + SimpleLayoutInfo() + : kind(LayoutResourceKind::None) + , size(0) + , alignment(1) + {} + + SimpleLayoutInfo( + UniformLayoutInfo uniformInfo) + : kind(LayoutResourceKind::Uniform) + , size(uniformInfo.size) + , alignment(uniformInfo.alignment) + {} + + SimpleLayoutInfo(LayoutResourceKind kind, LayoutSize size, size_t alignment=1) + : kind(kind) + , size(size) + , alignment(alignment) + {} + + // Convert to layout for uniform data + UniformLayoutInfo getUniformLayout() + { + if(kind == LayoutResourceKind::Uniform) + { + return UniformLayoutInfo(size, alignment); + } + else + { + return UniformLayoutInfo(0, 1); + } + } +}; + +// Only useful in the case of a homogeneous array +struct SimpleArrayLayoutInfo : SimpleLayoutInfo +{ + // This field is only useful in the uniform case + size_t elementStride; + + // Convert to layout for uniform data + UniformArrayLayoutInfo getUniformLayout() + { + if(kind == LayoutResourceKind::Uniform) + { + return UniformArrayLayoutInfo(size, alignment, elementStride); + } + else + { + return UniformArrayLayoutInfo(0, 1, 0); + } + } +}; + +struct LayoutRulesImpl; + +// Base class for things that store layout info +class Layout : public RefObject +{ +}; + +// A reified representation of a particular laid-out type +class TypeLayout : public Layout +{ +public: + // The type that was laid out + RefPtr type; + Type* getType() { return type.Ptr(); } + + // The layout rules that were used to produce this type + LayoutRulesImpl* rules; + + struct ResourceInfo + { + // What kind of register was it? + LayoutResourceKind kind = LayoutResourceKind::None; + + // How many registers of the above kind did we use? + LayoutSize count; + }; + + List resourceInfos; + + // For uniform data, alignment matters, but not for + // any other resource category, so we don't waste + // the space storing it in the above array + UInt uniformAlignment = 1; + + + /// The layout for data that is conceptually owned by this type, but which is pending layout. + /// + /// When a type contains interface/existential fields (recursively), the + /// actual data referenced by these fields needs to get allocated somewhere, + /// but it cannot go inline at the point where the interface/existential + /// type appears, or else the layout of a composite object would change + /// when the concrete type(s) we plug in change. + /// + /// We solve this problem by tracking this data that is "pending" layout, + /// and then "flushing" the pending data at appropriate places during + /// the layout process. + /// + RefPtr pendingDataTypeLayout; + + ResourceInfo* FindResourceInfo(LayoutResourceKind kind) + { + for(auto& rr : resourceInfos) + { + if(rr.kind == kind) + return &rr; + } + return nullptr; + } + + ResourceInfo* findOrAddResourceInfo(LayoutResourceKind kind) + { + auto existing = FindResourceInfo(kind); + if(existing) return existing; + + ResourceInfo info; + info.kind = kind; + info.count = 0; + resourceInfos.add(info); + return &resourceInfos.getLast(); + } + + void addResourceUsage(ResourceInfo info) + { + if(info.count == 0) return; + + findOrAddResourceInfo(info.kind)->count += info.count; + } + + void addResourceUsage(LayoutResourceKind kind, LayoutSize count) + { + ResourceInfo info; + info.kind = kind; + info.count = count; + addResourceUsage(info); + } + + void addResourceUsageFrom(TypeLayout* otherTypeLayout); + + /// "Unwrap" any layers of array-ness from this type layout. + /// + /// If this is an `ArrayTypeLayout`, returns the result of unwrapping the element type layout. + /// Otherwise, returns this type layout. + /// + RefPtr unwrapArray(); +}; + +typedef unsigned int VarLayoutFlags; +enum VarLayoutFlag : VarLayoutFlags +{ + HasSemantic = 1 << 1 +}; + +// A reified layout for a particular variable, field, etc. +class VarLayout : public Layout +{ +public: + // The variable we are laying out + DeclRef varDecl; + VarDeclBase* getVariable() { return varDecl.getDecl(); } + + Name* getName() { return getVariable()->getName(); } + + // The result of laying out the variable's type + RefPtr typeLayout; + TypeLayout* getTypeLayout() { return typeLayout.Ptr(); } + + // Additional flags + VarLayoutFlags flags = 0; + + // System-value semantic (and index) if this is a system value + String systemValueSemantic; + int systemValueSemanticIndex; + + // General case semantic name and index + // TODO: this and the system-value field are redundant + // TODO: the `VarLayout` type is getting bloated; we need to not store this + // information unless actually required. + String semanticName; + int semanticIndex; + + // The stage this variable belongs to, in case it is + // stage-specific. + // TODO: This is wasteful to be storing on every single + // variable layout. + Stage stage = Stage::Unknown; + + // The start register(s) for any resources + struct ResourceInfo + { + // What kind of register was it? + LayoutResourceKind kind = LayoutResourceKind::None; + + // What binding space (HLSL) or set (Vulkan) are we placed in? + UInt space; + + // What is our starting register in that space? + // + // (In the case of uniform data, this is a byte offset) + UInt index; + }; + List resourceInfos; + + ResourceInfo* FindResourceInfo(LayoutResourceKind kind) + { + for(auto& rr : resourceInfos) + { + if(rr.kind == kind) + return &rr; + } + return nullptr; + } + + ResourceInfo* AddResourceInfo(LayoutResourceKind kind) + { + ResourceInfo info; + info.kind = kind; + info.space = 0; + info.index = 0; + + resourceInfos.add(info); + return &resourceInfos.getLast(); + } + + ResourceInfo* findOrAddResourceInfo(LayoutResourceKind kind) + { + auto existing = FindResourceInfo(kind); + if(existing) return existing; + + return AddResourceInfo(kind); + } + + RefPtr pendingVarLayout; +}; + +// type layout for a variable that has a constant-buffer type +class ParameterGroupTypeLayout : public TypeLayout +{ +public: + // The layout of the "container" part itself. + // E.g., for a constant buffer, this would reflect + // the resource usage of the container, without + // the element type factored in. All of the offsets + // for this variable should be zero, but it is included + // for completeness. + RefPtr containerVarLayout; + + // A variable layout for the element of the container. + // The offsets of the variable layout will reflect + // the offsets that need to applied to get past the + // container types resource usage, while the actual + // type layout won't have offsets applied (unlike + // `offsetElementTypeLayout` below). + RefPtr elementVarLayout; + + // The layout of the element type, with offsets applied + // so that any fields (if the element type is a `struct`) + // will be offset by the resource usage of the container. + RefPtr offsetElementTypeLayout; + + // If the element type layout had any "pending" data, then + // as much of that data as possible will be flushed to + // fit into the overall layout of the parameter group. + // + // This field stores the offset information for where + // the pending data got stored relative to the start of + // the group. + // +// RefPtr flushedDataVarLayout; +}; + +// type layout for a variable that has a constant-buffer type +class StructuredBufferTypeLayout : public TypeLayout +{ +public: + RefPtr elementTypeLayout; +}; + + /// Type layout for a logical sequence type +class SequenceTypeLayout : public TypeLayout +{ +public: + /// The layout of the element type. + /// + /// This layout may include adjustments to make lookups in elements + /// of the array Just Work, and may not be the same as the layout + /// of the element type when used in a non-array context. + /// + RefPtr elementTypeLayout; + + /// The stride in bytes between elements. + size_t uniformStride = 0; +}; + + /// Type layout for an array type +class ArrayTypeLayout : public SequenceTypeLayout +{ +public: + /// The original layout of the element type. + /// + /// This layout does not include any adjustments that + /// were made to the element type in order to make + /// lookup into array elements Just Work. + /// + RefPtr originalElementTypeLayout; +}; + +// type layout for a variable with stream-output type +class StreamOutputTypeLayout : public TypeLayout +{ +public: + RefPtr elementTypeLayout; +}; + +class VectorTypeLayout : public SequenceTypeLayout +{ +public: +}; + + +class MatrixTypeLayout : public SequenceTypeLayout +{ +public: + /// Is this matrix laid out as row-major or column-major? + /// + /// Note that this does *not* affect the interpretation + /// of the `elementTypeLayout` field, which always represents + /// the logical elements of the matrix type, which are its + /// rows. + /// + MatrixLayoutMode mode; +}; + +// Specific case of type layout for a struct +class StructTypeLayout : public TypeLayout +{ +public: + // An ordered list of layouts for the known fields + List> fields; + + // Map a variable to its layout directly. + // + // Note that in the general case, there may be entries + // in the `fields` array that came from multiple + // translation units, and in cases where there are + // multiple declarations of the same parameter, only + // one will appear in `fields`, while all of + // them will be reflected in `mapVarToLayout`. + // + // TODO: This should map from a declaration to the *index* + // in the array above, rather than to the actual pointer, + // so that we + Dictionary> mapVarToLayout; + + // As an accellerator for type layouts created at the + // IR layer, we include a second map that use IR "key" + // instructions to map to fields. + // + Dictionary> mapKeyToLayout; +}; + +class GenericParamTypeLayout : public TypeLayout +{ +public: + RefPtr getGlobalGenericParamDecl(); + int paramIndex = 0; +}; + + /// Layout information for a tagged union type. +class TaggedUnionTypeLayout : public TypeLayout +{ +public: + /// The layouts of each of the case types. + /// + /// The order of entries in this array matches + /// the order of case types on the original + /// `TaggedUnionType`, and the index of a case + /// type is also the tag value for that case. + /// + List> caseTypeLayouts; + + /// The byte offset for the tag field. + /// + /// The tag field will always be allocated as + /// a `uint`, so we don't store a separate layout + /// for it. + /// + LayoutSize tagOffset; +}; + + /// Layout information for a type with existential (sub-)field types specialized. +class ExistentialSpecializedTypeLayout : public TypeLayout +{ +public: + RefPtr baseTypeLayout; + RefPtr pendingDataVarLayout; +}; + + /// Layout for a scoped entity like a program, module, or entry point +class ScopeLayout : public Layout +{ +public: + // The layout for the parameters of this entity. + // + RefPtr parametersLayout; +}; + +StructTypeLayout* getScopeStructLayout( + ScopeLayout* programLayout); + +// Layout information for a single shader entry point +// within a program +// +// Treated as a subclass of `StructTypeLayout` because +// it needs to include computed layout information +// for the parameters of the entry point. +// +// TODO: where to store layout info for the return +// type of the function? +class EntryPointLayout : public ScopeLayout +{ +public: + // The corresponding function declaration + RefPtr entryPoint; + + // The shader profile that was used to compile the entry point + Profile profile; + + // Layout for any results of the entry point + RefPtr resultLayout; + + enum Flag : unsigned + { + usesAnySampleRateInput = 0x1, + }; + unsigned flags = 0; + + /// Layouts for all tagged union types required by this entry point. + /// + /// These are any tagged union types used by the generic + /// arguments that this entry point is being compiled with. + List> taggedUnionTypeLayouts; +}; + +class GenericParamLayout : public Layout +{ +public: + RefPtr decl; + int index; +}; + +// Layout information for the global scope of a program +class ProgramLayout : public ScopeLayout +{ +public: + /* + // We store a layout for the declarations at the global + // scope. Note that this will *either* be a single + // `StructTypeLayout` with the fields stored directly, + // or it will be a single `ParameterGroupTypeLayout`, + // where the global-scope fields are the members of + // that constant buffer. + // + // The `struct` case will be used if there are no + // "naked" global-scope uniform variables, and the + // constant-buffer case will be used if there are + // (since a constant buffer will have to be allocated + // to store them). + // + RefPtr globalScopeLayout; + */ + + /// The target and program for which layout was computed + TargetProgram* targetProgram; + + TargetProgram* getTargetProgram() { return targetProgram; } + TargetRequest* getTargetReq() { return targetProgram->getTargetReq(); } + Program* getProgram() { return targetProgram->getProgram(); } + + + // We catalog the requested entry points here, + // and any entry-point-specific parameter data + // will (eventually) belong there... + List> entryPoints; + + List> globalGenericParams; + Dictionary globalGenericParamsMap; +}; + +StructTypeLayout* getGlobalStructLayout( + ProgramLayout* programLayout); + +struct LayoutRulesFamilyImpl; + +// A delineation of shader parameter types into fine-grained +// categories that can then be mapped down to actual resources +// by a given set of rules. +// +// TODO(tfoley): `SlangParameterCategory` and `slang::ParameterCategory` +// are badly named, and need to be revised so they can't be confused +// with this concept. +enum class ShaderParameterKind +{ + ConstantBuffer, + TextureUniformBuffer, + ShaderStorageBuffer, + + StructuredBuffer, + MutableStructuredBuffer, + + RawBuffer, + MutableRawBuffer, + + Buffer, + MutableBuffer, + + Texture, + MutableTexture, + + TextureSampler, + MutableTextureSampler, + + InputRenderTarget, + + SamplerState, + + Image, + MutableImage, + + RegisterSpace, +}; + +struct SimpleLayoutRulesImpl +{ + // Get size and alignment for a single value of base type. + virtual SimpleLayoutInfo GetScalarLayout(BaseType baseType) = 0; + + // Get size and alignment for an array of elements + virtual SimpleArrayLayoutInfo GetArrayLayout(SimpleLayoutInfo elementInfo, LayoutSize elementCount) = 0; + + // Get layout for a vector or matrix type + virtual SimpleLayoutInfo GetVectorLayout(SimpleLayoutInfo elementInfo, size_t elementCount) = 0; + virtual SimpleArrayLayoutInfo GetMatrixLayout(SimpleLayoutInfo elementInfo, size_t rowCount, size_t columnCount) = 0; + + // Begin doing layout on a `struct` type + virtual UniformLayoutInfo BeginStructLayout() = 0; + + // Add a field to a `struct` type, and return the offset for the field + virtual LayoutSize AddStructField(UniformLayoutInfo* ioStructInfo, UniformLayoutInfo fieldInfo) = 0; + + // End layout for a struct, and finalize its size/alignment. + virtual void EndStructLayout(UniformLayoutInfo* ioStructInfo) = 0; +}; + +struct ObjectLayoutRulesImpl +{ + // Compute layout info for an object type + virtual SimpleLayoutInfo GetObjectLayout(ShaderParameterKind kind) = 0; +}; + +struct LayoutRulesImpl +{ + LayoutRulesFamilyImpl* family; + SimpleLayoutRulesImpl* simpleRules; + ObjectLayoutRulesImpl* objectRules; + + // Forward `SimpleLayoutRulesImpl` interface + + SimpleLayoutInfo GetScalarLayout(BaseType baseType) + { + return simpleRules->GetScalarLayout(baseType); + } + + SimpleArrayLayoutInfo GetArrayLayout(SimpleLayoutInfo elementInfo, LayoutSize elementCount) + { + return simpleRules->GetArrayLayout(elementInfo, elementCount); + } + + SimpleLayoutInfo GetVectorLayout(SimpleLayoutInfo elementInfo, size_t elementCount) + { + return simpleRules->GetVectorLayout(elementInfo, elementCount); + } + + SimpleArrayLayoutInfo GetMatrixLayout(SimpleLayoutInfo elementInfo, size_t rowCount, size_t columnCount) + { + return simpleRules->GetMatrixLayout(elementInfo, rowCount, columnCount); + } + + UniformLayoutInfo BeginStructLayout() + { + return simpleRules->BeginStructLayout(); + } + + LayoutSize AddStructField(UniformLayoutInfo* ioStructInfo, UniformLayoutInfo fieldInfo) + { + return simpleRules->AddStructField(ioStructInfo, fieldInfo); + } + + void EndStructLayout(UniformLayoutInfo* ioStructInfo) + { + return simpleRules->EndStructLayout(ioStructInfo); + } + + // Forward `ObjectLayoutRulesImpl` interface + + SimpleLayoutInfo GetObjectLayout(ShaderParameterKind kind) + { + return objectRules->GetObjectLayout(kind); + } + + // + + LayoutRulesFamilyImpl* getLayoutRulesFamily() { return family; } +}; + +struct LayoutRulesFamilyImpl +{ + virtual LayoutRulesImpl* getConstantBufferRules() = 0; + virtual LayoutRulesImpl* getPushConstantBufferRules() = 0; + virtual LayoutRulesImpl* getTextureBufferRules() = 0; + virtual LayoutRulesImpl* getVaryingInputRules() = 0; + virtual LayoutRulesImpl* getVaryingOutputRules() = 0; + virtual LayoutRulesImpl* getSpecializationConstantRules()= 0; + virtual LayoutRulesImpl* getShaderStorageBufferRules() = 0; + virtual LayoutRulesImpl* getParameterBlockRules() = 0; + + virtual LayoutRulesImpl* getRayPayloadParameterRules() = 0; + virtual LayoutRulesImpl* getCallablePayloadParameterRules() = 0; + virtual LayoutRulesImpl* getHitAttributesParameterRules()= 0; + + virtual LayoutRulesImpl* getShaderRecordConstantBufferRules() = 0; +}; + +typedef List> GenericParamLayouts; + +struct TypeLayoutContext +{ + // The layout rules to use (e.g., we compute + // layout differently in a `cbuffer` vs. the + // parameter list of a fragment shader). + LayoutRulesImpl* rules; + + // The target request that is triggering layout + TargetRequest* targetReq; + + // A parent program layout that will establish the ordering + // of all global generic type parameters. + // + ProgramLayout* programLayout; + + // Whether to lay out matrices column-major + // or row-major. + MatrixLayoutMode matrixLayoutMode; + + // The concrete types (if any) to plug into the currently in-scope + // existential type slots. + // + Int existentialTypeArgCount = 0; + ExistentialTypeSlots::Arg const* existentialTypeArgs = nullptr; + + LayoutRulesImpl* getRules() { return rules; } + LayoutRulesFamilyImpl* getRulesFamily() const { return rules->getLayoutRulesFamily(); } + + TypeLayoutContext with(LayoutRulesImpl* inRules) const + { + TypeLayoutContext result = *this; + result.rules = inRules; + return result; + } + + TypeLayoutContext with(MatrixLayoutMode inMatrixLayoutMode) const + { + TypeLayoutContext result = *this; + result.matrixLayoutMode = inMatrixLayoutMode; + return result; + } + + TypeLayoutContext withExistentialTypeArgs( + Int argCount, + ExistentialTypeSlots::Arg const* args) const + { + TypeLayoutContext result = *this; + result.existentialTypeArgCount = argCount; + result.existentialTypeArgs = args; + return result; + } + + TypeLayoutContext withExistentialTypeSlotsOffsetBy( + Int offset) const + { + TypeLayoutContext result = *this; + if( existentialTypeArgCount > offset ) + { + result.existentialTypeArgCount = existentialTypeArgCount - offset; + result.existentialTypeArgs = existentialTypeArgs + offset; + } + else + { + result.existentialTypeArgCount = 0; + result.existentialTypeArgs = nullptr; + } + return result; + + } +}; + +// + + /// A custom tuple to capture the outputs of type layout +struct TypeLayoutResult +{ + /// The actual heap-allocated layout object with all the details + RefPtr layout; + + /// A simplified representation of layout information. + /// + /// This information is suitable for the case where a type only + /// consumes a single resource. + /// + SimpleLayoutInfo info; + + /// Default constructor. + TypeLayoutResult() + {} + + /// Construct a result from the given layout object and simple layout info. + TypeLayoutResult(RefPtr inLayout, SimpleLayoutInfo const& inInfo) + : layout(inLayout) + , info(inInfo) + {} +}; + + /// Helper type for building `struct` type layouts +struct StructTypeLayoutBuilder +{ +public: + /// Begin the layout process for `type`, using `rules` + void beginLayout( + Type* type, + LayoutRulesImpl* rules); + + /// Begin the layout process for `type`, using `rules`, if it hasn't already been begun. + /// + /// This functions allows for a `StructTypeLayoutBuilder` to be use lazily, + /// only allocating a type layout object if it is actaully needed. + /// + void beginLayoutIfNeeded( + Type* type, + LayoutRulesImpl* rules); + + /// Add a field to the struct type layout. + /// + /// One of the `beginLayout*()` functions must have been called previously. + /// + RefPtr addField( + DeclRef field, + TypeLayoutResult fieldResult); + + /// Add a field to the struct type layout. + /// + /// One of the `beginLayout*()` functions must have been called previously. + /// + RefPtr addField( + DeclRef field, + RefPtr fieldTypeLayout); + + /// Complete layout. + /// + /// If layout was begun, ensures that the result of `getTypeLayout()` is usable. + /// If layout was never begin, does nothing. + /// + void endLayout(); + + /// Get the type layout. + /// + /// This can be called any time after `beginLayout*()`. + /// In particular, it can be called before `endLayout`. + /// + RefPtr getTypeLayout(); + + /// The the type layout result. + /// + /// This is primarily useful for implementation code in `_createTypeLayout`. + /// + TypeLayoutResult getTypeLayoutResult(); + +private: + /// The layout rules being used, if layout has begun. + LayoutRulesImpl* m_rules = nullptr; + + /// The type layout being computed, if layout has begun. + RefPtr m_typeLayout; + + /// Uniform offset/alignment statte used when computing offset for uniform fields. + UniformLayoutInfo m_info; +}; + +// + +// Get an appropriate set of layout rules (packaged up +// as a `TypeLayoutContext`) to perform type layout +// for the given target. +// +// The provided `programLayout` is used to establish +// the ordering of all global generic type paramters. +// +TypeLayoutContext getInitialLayoutContextForTarget( + TargetRequest* targetReq, + ProgramLayout* programLayout); + + /// Direction(s) of a varying shader parameter +typedef unsigned int EntryPointParameterDirectionMask; +enum +{ + kEntryPointParameterDirection_Input = 0x1, + kEntryPointParameterDirection_Output = 0x2, +}; + + + /// Get layout information for a simple varying parameter type. + /// + /// A simple varying parameter is a scalar, vector, or matrix. + /// +RefPtr getSimpleVaryingParameterTypeLayout( + TypeLayoutContext const& context, + Type* type, + EntryPointParameterDirectionMask directionMask); + +// Create a full type-layout object for a type, +// according to the layout rules in `context`. +RefPtr createTypeLayout( + TypeLayoutContext const& context, + Type* type); + +// + + /// Create a layout for a parameter-group type (a `ConstantBuffer` or `ParameterBlock`). +RefPtr createParameterGroupTypeLayout( + TypeLayoutContext const& context, + RefPtr parameterGroupType); + + /// Create a wrapper constant buffer type layout, if needed. + /// + /// When dealing with entry-point `uniform` and global-scope parameters, + /// we want to create a wrapper constant buffer for all the parameters + /// if and only if there exist some parameters that use "ordinary" data + /// (`LayoutResourceKind::Uniform`). + /// + /// This function determines whether such a wrapper is needed, based + /// on the `elementTypeLayout` given, and either creates and returns + /// the layout for the wrapper, or the unmodified `elementTypeLayout`. + /// +RefPtr createConstantBufferTypeLayoutIfNeeded( + TypeLayoutContext const& context, + RefPtr elementTypeLayout); + +// Create a type layout for a structured buffer type. +RefPtr +createStructuredBufferTypeLayout( + TypeLayoutContext const& context, + ShaderParameterKind kind, + RefPtr structuredBufferType, + RefPtr elementType); + +int findGenericParam(List> & genericParameters, GlobalGenericParamDecl * decl); +// + +// Given an existing type layout `oldTypeLayout`, apply offsets +// to any contained fields based on the resource infos in `offsetVarLayout`. +RefPtr applyOffsetToTypeLayout( + RefPtr oldTypeLayout, + RefPtr offsetVarLayout); + +} + +#endif diff --git a/source/slang/slang-type-system-shared.cpp b/source/slang/slang-type-system-shared.cpp new file mode 100644 index 000000000..7ccde5bcd --- /dev/null +++ b/source/slang/slang-type-system-shared.cpp @@ -0,0 +1,11 @@ +#include "slang-type-system-shared.h" + +namespace Slang +{ + TextureFlavor TextureFlavor::create(SlangResourceShape shape, SlangResourceAccess access) + { + TextureFlavor rs; + rs.flavor = uint16_t(shape | (access << 8)); + return rs; + } +} diff --git a/source/slang/slang-type-system-shared.h b/source/slang/slang-type-system-shared.h new file mode 100644 index 000000000..95840e701 --- /dev/null +++ b/source/slang/slang-type-system-shared.h @@ -0,0 +1,102 @@ +#ifndef SLANG_TYPE_SYSTEM_SHARED_H +#define SLANG_TYPE_SYSTEM_SHARED_H + +#include "../../slang.h" + +namespace Slang +{ +#define FOREACH_BASE_TYPE(X) \ + X(Void) \ + X(Bool) \ + X(Int8) \ + X(Int16) \ + X(Int) \ + X(Int64) \ + X(UInt8) \ + X(UInt16) \ + X(UInt) \ + X(UInt64) \ + X(Half) \ + X(Float) \ + X(Double) \ +/* end */ + + enum class BaseType + { +#define DEFINE_BASE_TYPE(NAME) NAME, +FOREACH_BASE_TYPE(DEFINE_BASE_TYPE) +#undef DEFINE_BASE_TYPE + + CountOf, + }; + + struct TextureFlavor + { + typedef TextureFlavor ThisType; + enum + { + // Mask for the overall "shape" of the texture + BaseShapeMask = SLANG_RESOURCE_BASE_SHAPE_MASK, + + // Flag for whether the shape has "array-ness" + ArrayFlag = SLANG_TEXTURE_ARRAY_FLAG, + + // Whether or not the texture stores multiple samples per pixel + MultisampleFlag = SLANG_TEXTURE_MULTISAMPLE_FLAG, + + // Whether or not this is a shadow texture + // + // TODO(tfoley): is this even meaningful/used? + // ShadowFlag = 0x80, + }; + + enum Shape : uint8_t + { + Shape1D = SLANG_TEXTURE_1D, + Shape2D = SLANG_TEXTURE_2D, + Shape3D = SLANG_TEXTURE_3D, + ShapeCube = SLANG_TEXTURE_CUBE, + ShapeBuffer = SLANG_TEXTURE_BUFFER, + + Shape1DArray = Shape1D | ArrayFlag, + Shape2DArray = Shape2D | ArrayFlag, + // No Shape3DArray + ShapeCubeArray = ShapeCube | ArrayFlag, + }; + + enum + { + // This the total number of expressible flavors, + // which is *not* to say that every expressible + // flavor is actual valid. + Count = 0x10000, + }; + + uint16_t flavor; + + Shape GetBaseShape() const { return Shape(flavor & BaseShapeMask); } + bool isArray() const { return (flavor & ArrayFlag) != 0; } + bool isMultisample() const { return (flavor & MultisampleFlag) != 0; } + // bool isShadow() const { return (flavor & ShadowFlag) != 0; } + + SLANG_FORCE_INLINE bool operator==(const ThisType& rhs) const { return flavor == rhs.flavor; } + SLANG_FORCE_INLINE bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } + + SlangResourceShape getShape() const { return flavor & 0xFF; } + SlangResourceAccess getAccess() const { return (flavor >> 8) & 0xFF; } + + TextureFlavor() = default; + TextureFlavor(uint32_t tag) { flavor = (uint16_t)tag; } + + static TextureFlavor create(SlangResourceShape shape, SlangResourceAccess access); + }; + + enum class SamplerStateFlavor : uint8_t + { + SamplerState, + SamplerComparisonState, + }; + +} + +#endif diff --git a/source/slang/slang-val-defs.h b/source/slang/slang-val-defs.h new file mode 100644 index 000000000..b9d5188ed --- /dev/null +++ b/source/slang/slang-val-defs.h @@ -0,0 +1,155 @@ +// slang-val-defs.h + +// Syntax class definitions for compile-time values. + +// A compile-time integer (may not have a specific concrete value) +ABSTRACT_SYNTAX_CLASS(IntVal, Val) +END_SYNTAX_CLASS() + +// Trivial case of a value that is just a constant integer +SYNTAX_CLASS(ConstantIntVal, IntVal) + FIELD(IntegerLiteralValue, value) + + RAW( + ConstantIntVal() + {} + ConstantIntVal(IntegerLiteralValue value) + : value(value) + {} + + virtual bool EqualsVal(Val* val) override; + virtual String ToString() override; + virtual int GetHashCode() override; + ) +END_SYNTAX_CLASS() + +// The logical "value" of a rererence to a generic value parameter +SYNTAX_CLASS(GenericParamIntVal, IntVal) + DECL_FIELD(DeclRef, declRef) + + RAW( + GenericParamIntVal() + {} + GenericParamIntVal(DeclRef declRef) + : declRef(declRef) + {} + + virtual bool EqualsVal(Val* val) override; + virtual String ToString() override; + virtual int GetHashCode() override; + virtual RefPtr SubstituteImpl(SubstitutionSet subst, int* ioDiff) override; +) +END_SYNTAX_CLASS() + +// A witness to the fact that some proposition is true, encoded +// at the level of the type system. +// +// Given a generic like: +// +// void example(L light) +// where L : ILight +// { ... } +// +// a call to `example()` needs two things for us to be sure +// it is valid: +// +// 1. We need a type `X` to use as the argument for the +// parameter `L`. We might supply this explicitly, or +// via inference. +// +// 2. We need a *proof* that whatever `X` we chose conforms +// to the `ILight` interface. +// +// The easiest way to make such a proof is by construction, +// and a `Witness` represents such a constructive proof. +// Conceptually a proposition like `X : ILight` can be +// seen as a type, and witness prooving that proposition +// is a value of that type. +// +// We construct and store witnesses explicitly during +// semantic checking because they can help us with +// generating downstream code. By following the structure +// of a witness (the structure of a proof) we can, e.g., +// navigate from the knowledge that `X : ILight` to +// the concrete declarations that provide the implementation +// of `ILight` for `X`. +// +ABSTRACT_SYNTAX_CLASS(Witness, Val) +END_SYNTAX_CLASS() + +// A witness that one type is a subtype of another +// (where by "subtype" we include both inheritance +// relationships and type-conforms-to-interface relationships) +// +// TODO: we may need to tease those apart. +ABSTRACT_SYNTAX_CLASS(SubtypeWitness, Witness) + FIELD(RefPtr, sub) + FIELD(RefPtr, sup) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(TypeEqualityWitness, SubtypeWitness) +RAW( + virtual bool EqualsVal(Val* val) override; + virtual String ToString() override; + virtual int GetHashCode() override; + virtual RefPtr SubstituteImpl(SubstitutionSet subst, int * ioDiff) override; +) +END_SYNTAX_CLASS() +// A witness that one type is a subtype of another +// because some in-scope declaration says so +SYNTAX_CLASS(DeclaredSubtypeWitness, SubtypeWitness) + FIELD(DeclRef, declRef); +RAW( + virtual bool EqualsVal(Val* val) override; + virtual String ToString() override; + virtual int GetHashCode() override; + virtual RefPtr SubstituteImpl(SubstitutionSet subst, int * ioDiff) override; +) +END_SYNTAX_CLASS() + +// A witness that `sub : sup` because `sub : mid` and `mid : sup` +SYNTAX_CLASS(TransitiveSubtypeWitness, SubtypeWitness) + // Witness that `sub : mid` + FIELD(RefPtr, subToMid); + + // Witness that `mid : sup` + FIELD(DeclRef, midToSup); +RAW( + virtual bool EqualsVal(Val* val) override; + virtual String ToString() override; + virtual int GetHashCode() override; + virtual RefPtr SubstituteImpl(SubstitutionSet subst, int * ioDiff) override; +) +END_SYNTAX_CLASS() + +// A witness taht `sub : sup` because `sub` was wrapped into +// an existential of type `sup`. +SYNTAX_CLASS(ExtractExistentialSubtypeWitness, SubtypeWitness) +RAW( + // The declaration of the existential value that has been opened + DeclRef declRef; + + virtual bool EqualsVal(Val* val) override; + virtual String ToString() override; + virtual int GetHashCode() override; + virtual RefPtr SubstituteImpl(SubstitutionSet subst, int * ioDiff) override; +) +END_SYNTAX_CLASS() + +// A witness that `sub : sup`, because `sub` is a tagged union +// of the form `A | B | C | ...` and each of `A : sup`, +// `B : sup`, `C : sup`, etc. +// +SYNTAX_CLASS(TaggedUnionSubtypeWitness, SubtypeWitness) +RAW( + // Witnesses that each of the "case" types in the union + // is a subtype of `sup`. + // + List> caseWitnesses; + + virtual bool EqualsVal(Val* val) override; + virtual String ToString() override; + virtual int GetHashCode() override; + virtual RefPtr SubstituteImpl(SubstitutionSet subst, int * ioDiff) override; +) +END_SYNTAX_CLASS() diff --git a/source/slang/slang-visitor.h b/source/slang/slang-visitor.h new file mode 100644 index 000000000..c6d63cd40 --- /dev/null +++ b/source/slang/slang-visitor.h @@ -0,0 +1,535 @@ +// slang-visitor.h +#ifndef SLANG_VISITOR_H_INCLUDED +#define SLANG_VISITOR_H_INCLUDED + +// This file defines the basic "Visitor" pattern for doing dispatch +// over the various categories of syntax node. + +#include "slang-syntax.h" + +namespace Slang { + +// +// type Visitors +// + +struct ITypeVisitor +{ +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ +#define SYNTAX_CLASS(NAME, BASE) \ + virtual void dispatch_##NAME(NAME* obj, void* extra) = 0; + +#include "slang-object-meta-begin.h" +#include "slang-type-defs.h" +#include "slang-object-meta-end.h" +}; + +template +struct TypeVisitor : Base +{ + Result dispatch(Type* type) + { + Result result; + type->accept(this, &result); + return result; + } + + Result dispatchType(Type* type) + { + Result result; + type->accept(this, &result); + return result; + } + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ +#define SYNTAX_CLASS(NAME, BASE) \ + virtual void dispatch_##NAME(NAME* obj, void* extra) override \ + { *(Result*)extra = ((Derived*) this)->visit##NAME(obj); } + +#include "slang-object-meta-begin.h" +#include "slang-type-defs.h" +#include "slang-object-meta-end.h" + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) +#define SYNTAX_CLASS(NAME, BASE) \ + Result visit##NAME(NAME* obj) \ + { return ((Derived*) this)->visit##BASE(obj); } + +#include "slang-object-meta-begin.h" +#include "slang-type-defs.h" +#include "slang-object-meta-end.h" +}; + +template +struct TypeVisitor : Base +{ + void dispatch(Type* type) + { + type->accept(this, 0); + } + + void dispatchType(Type* type) + { + type->accept(this, 0); + } + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ +#define SYNTAX_CLASS(NAME, BASE) \ + virtual void dispatch_##NAME(NAME* obj, void*) override \ + { ((Derived*) this)->visit##NAME(obj); } + +#include "slang-object-meta-begin.h" +#include "slang-type-defs.h" +#include "slang-object-meta-end.h" + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) +#define SYNTAX_CLASS(NAME, BASE) \ + void visit##NAME(NAME* obj) \ + { ((Derived*) this)->visit##BASE(obj); } + +#include "slang-object-meta-begin.h" +#include "slang-type-defs.h" +#include "slang-object-meta-end.h" +}; + +template +struct TypeVisitorWithArg : Base +{ + void dispatch(Type* type, Arg const& arg) + { + type->accept(this, (void*)&arg); + } + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ +#define SYNTAX_CLASS(NAME, BASE) \ + virtual void dispatch_##NAME(NAME* obj, void* arg) override \ + { ((Derived*) this)->visit##NAME(obj, *(Arg*)arg); } + +#include "slang-object-meta-begin.h" +#include "slang-type-defs.h" +#include "slang-object-meta-end.h" + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) +#define SYNTAX_CLASS(NAME, BASE) \ + void visit##NAME(NAME* obj, Arg const& arg) \ + { ((Derived*) this)->visit##BASE(obj, arg); } + +#include "slang-object-meta-begin.h" +#include "slang-type-defs.h" +#include "slang-object-meta-end.h" +}; + +// +// Expression Visitors +// + +struct IExprVisitor +{ +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ +#define SYNTAX_CLASS(NAME, BASE) \ + virtual void dispatch_##NAME(NAME* obj, void* extra) = 0; + +#include "slang-object-meta-begin.h" +#include "slang-expr-defs.h" +#include "slang-object-meta-end.h" +}; + +template +struct ExprVisitor : IExprVisitor +{ + Result dispatch(Expr* expr) + { + Result result; + expr->accept(this, &result); + return result; + } + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ +#define SYNTAX_CLASS(NAME, BASE) \ + virtual void dispatch_##NAME(NAME* obj, void* extra) override \ + { *(Result*)extra = ((Derived*) this)->visit##NAME(obj); } + +#include "slang-object-meta-begin.h" +#include "slang-expr-defs.h" +#include "slang-object-meta-end.h" + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) +#define SYNTAX_CLASS(NAME, BASE) \ + Result visit##NAME(NAME* obj) \ + { return ((Derived*) this)->visit##BASE(obj); } + +#include "slang-object-meta-begin.h" +#include "slang-expr-defs.h" +#include "slang-object-meta-end.h" +}; + +template +struct ExprVisitor : IExprVisitor +{ + void dispatch(Expr* expr) + { + expr->accept(this, 0); + } + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ +#define SYNTAX_CLASS(NAME, BASE) \ + virtual void dispatch_##NAME(NAME* obj, void*) override \ + { ((Derived*) this)->visit##NAME(obj); } + +#include "slang-object-meta-begin.h" +#include "slang-expr-defs.h" +#include "slang-object-meta-end.h" + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) +#define SYNTAX_CLASS(NAME, BASE) \ + void visit##NAME(NAME* obj) \ + { ((Derived*) this)->visit##BASE(obj); } + +#include "slang-object-meta-begin.h" +#include "slang-expr-defs.h" +#include "slang-object-meta-end.h" +}; + +template +struct ExprVisitorWithArg : IExprVisitor +{ + void dispatch(Expr* obj, Arg const& arg) + { + obj->accept(this, (void*)&arg); + } + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ +#define SYNTAX_CLASS(NAME, BASE) \ + virtual void dispatch_##NAME(NAME* obj, void* arg) override \ + { ((Derived*) this)->visit##NAME(obj, *(Arg*)arg); } + +#include "slang-object-meta-begin.h" +#include "slang-expr-defs.h" +#include "slang-object-meta-end.h" + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) +#define SYNTAX_CLASS(NAME, BASE) \ + void visit##NAME(NAME* obj, Arg const& arg) \ + { ((Derived*) this)->visit##BASE(obj, arg); } + +#include "slang-object-meta-begin.h" +#include "slang-expr-defs.h" +#include "slang-object-meta-end.h" +}; + +// +// Statement Visitors +// + +struct IStmtVisitor +{ +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ +#define SYNTAX_CLASS(NAME, BASE) \ + virtual void dispatch_##NAME(NAME* obj, void* extra) = 0; + +#include "slang-object-meta-begin.h" +#include "slang-stmt-defs.h" +#include "slang-object-meta-end.h" +}; + +template +struct StmtVisitor : IStmtVisitor +{ + Result dispatch(Stmt* stmt) + { + Result result; + stmt->accept(this, &result); + return result; + } + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ +#define SYNTAX_CLASS(NAME, BASE) \ + virtual void dispatch_##NAME(NAME* obj, void* extra) override \ + { *(Result*)extra = ((Derived*) this)->visit##NAME(obj); } + +#include "slang-object-meta-begin.h" +#include "slang-stmt-defs.h" +#include "slang-object-meta-end.h" + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) +#define SYNTAX_CLASS(NAME, BASE) \ + Result visit##NAME(NAME* obj) \ + { return ((Derived*) this)->visit##BASE(obj); } + +#include "slang-object-meta-begin.h" +#include "slang-stmt-defs.h" +#include "slang-object-meta-end.h" +}; + +template +struct StmtVisitor : IStmtVisitor +{ + void dispatch(Stmt* stmt) + { + stmt->accept(this, 0); + } + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ +#define SYNTAX_CLASS(NAME, BASE) \ + virtual void dispatch_##NAME(NAME* obj, void*) override \ + { ((Derived*) this)->visit##NAME(obj); } + +#include "slang-object-meta-begin.h" +#include "slang-stmt-defs.h" +#include "slang-object-meta-end.h" + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) +#define SYNTAX_CLASS(NAME, BASE) \ + void visit##NAME(NAME* obj) \ + { ((Derived*) this)->visit##BASE(obj); } + +#include "slang-object-meta-begin.h" +#include "slang-stmt-defs.h" +#include "slang-object-meta-end.h" +}; + +// +// Declaration Visitors +// + +struct IDeclVisitor +{ +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ +#define SYNTAX_CLASS(NAME, BASE) \ + virtual void dispatch_##NAME(NAME* obj, void* extra) = 0; + +#include "slang-object-meta-begin.h" +#include "slang-decl-defs.h" +#include "slang-object-meta-end.h" +}; + +template +struct DeclVisitor : IDeclVisitor +{ + Result dispatch(DeclBase* decl) + { + Result result; + decl->accept(this, &result); + return result; + } + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ +#define SYNTAX_CLASS(NAME, BASE) \ + virtual void dispatch_##NAME(NAME* obj, void* extra) override \ + { *(Result*)extra = ((Derived*) this)->visit##NAME(obj); } + +#include "slang-object-meta-begin.h" +#include "slang-decl-defs.h" +#include "slang-object-meta-end.h" + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) +#define SYNTAX_CLASS(NAME, BASE) \ + Result visit##NAME(NAME* obj) \ + { return ((Derived*) this)->visit##BASE(obj); } + +#include "slang-object-meta-begin.h" +#include "slang-decl-defs.h" +#include "slang-object-meta-end.h" +}; + +template +struct DeclVisitor : IDeclVisitor +{ + void dispatch(DeclBase* decl) + { + decl->accept(this, 0); + } + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ +#define SYNTAX_CLASS(NAME, BASE) \ + virtual void dispatch_##NAME(NAME* obj, void*) override \ + { ((Derived*) this)->visit##NAME(obj); } + +#include "slang-object-meta-begin.h" +#include "slang-decl-defs.h" +#include "slang-object-meta-end.h" + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) +#define SYNTAX_CLASS(NAME, BASE) \ + void visit##NAME(NAME* obj) \ + { ((Derived*) this)->visit##BASE(obj); } + +#include "slang-object-meta-begin.h" +#include "slang-decl-defs.h" +#include "slang-object-meta-end.h" +}; + +template +struct DeclVisitorWithArg : IDeclVisitor +{ + void dispatch(DeclBase* obj, Arg const& arg) + { + obj->accept(this, (void*)&arg); + } + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ +#define SYNTAX_CLASS(NAME, BASE) \ + virtual void dispatch_##NAME(NAME* obj, void* arg) override \ + { ((Derived*) this)->visit##NAME(obj, *(Arg*)arg); } + +#include "slang-object-meta-begin.h" +#include "slang-decl-defs.h" +#include "slang-object-meta-end.h" + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) +#define SYNTAX_CLASS(NAME, BASE) \ + void visit##NAME(NAME* obj, Arg const& arg) \ + { ((Derived*) this)->visit##BASE(obj, arg); } + +#include "slang-object-meta-begin.h" +#include "slang-decl-defs.h" +#include "slang-object-meta-end.h" +}; + + +// +// Modifier Visitors +// + +struct IModifierVisitor +{ +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ +#define SYNTAX_CLASS(NAME, BASE) \ + virtual void dispatch_##NAME(NAME* obj, void* extra) = 0; + +#include "slang-object-meta-begin.h" +#include "slang-modifier-defs.h" +#include "slang-object-meta-end.h" +}; + +template +struct ModifierVisitor : IModifierVisitor +{ + Result dispatch(Modifier* modifier) + { + Result result; + modifier->accept(this, &result); + return result; + } + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ +#define SYNTAX_CLASS(NAME, BASE) \ + virtual void dispatch_##NAME(NAME* obj, void* extra) override \ + { *(Result*)extra = ((Derived*) this)->visit##NAME(obj); } + +#include "slang-object-meta-begin.h" +#include "slang-modifier-defs.h" +#include "slang-object-meta-end.h" + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) +#define SYNTAX_CLASS(NAME, BASE) \ + Result visit##NAME(NAME* obj) \ + { return ((Derived*) this)->visit##BASE(obj); } + +#include "slang-object-meta-begin.h" +#include "slang-modifier-defs.h" +#include "slang-object-meta-end.h" +}; + +template +struct ModifierVisitor : IModifierVisitor +{ + void dispatch(Modifier* modifier) + { + modifier->accept(this, 0); + } + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ +#define SYNTAX_CLASS(NAME, BASE) \ + virtual void dispatch_##NAME(NAME* obj, void*) override \ + { ((Derived*) this)->visit##NAME(obj); } + +#include "slang-object-meta-begin.h" +#include "slang-modifier-defs.h" +#include "slang-object-meta-end.h" + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) +#define SYNTAX_CLASS(NAME, BASE) \ + void visit##NAME(NAME* obj) \ + { ((Derived*) this)->visit##BASE(obj); } + +#include "slang-object-meta-begin.h" +#include "slang-modifier-defs.h" +#include "slang-object-meta-end.h" +}; + +// +// Val Visitors +// + +struct IValVisitor : ITypeVisitor +{ +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ +#define SYNTAX_CLASS(NAME, BASE) \ + virtual void dispatch_##NAME(NAME* obj, void* extra) = 0; + +#include "slang-object-meta-begin.h" +#include "slang-val-defs.h" +#include "slang-object-meta-end.h" +}; + +template +struct ValVisitor : TypeVisitor +{ + Result dispatch(Val* val) + { + Result result; + val->accept(this, &result); + return result; + } + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ +#define SYNTAX_CLASS(NAME, BASE) \ + virtual void dispatch_##NAME(NAME* obj, void* extra) override \ + { *(Result*)extra = ((Derived*) this)->visit##NAME(obj); } + +#include "slang-object-meta-begin.h" +#include "slang-val-defs.h" +#include "slang-object-meta-end.h" + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) +#define SYNTAX_CLASS(NAME, BASE) \ + Result visit##NAME(NAME* obj) \ + { return ((Derived*) this)->visit##BASE(obj); } + +#include "slang-object-meta-begin.h" +#include "slang-val-defs.h" +#include "slang-object-meta-end.h" +}; + +template +struct ValVisitor : TypeVisitor +{ + void dispatch(Val* val) + { + val->accept(this, 0); + } + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ +#define SYNTAX_CLASS(NAME, BASE) \ + virtual void dispatch_##NAME(NAME* obj, void*) override \ + { ((Derived*) this)->visit##NAME(obj); } + +#include "slang-object-meta-begin.h" +#include "slang-val-defs.h" +#include "slang-object-meta-end.h" + +#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) +#define SYNTAX_CLASS(NAME, BASE) \ + void visit##NAME(NAME* obj) \ + { ((Derived*) this)->visit##BASE(obj); } + +#include "slang-object-meta-begin.h" +#include "slang-val-defs.h" +#include "slang-object-meta-end.h" + +}; + +} + +#endif diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index c78a27f54..a3875ef62 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -4,20 +4,21 @@ #include "../core/slang-string-util.h" #include "../core/slang-shared-library.h" -#include "parameter-binding.h" -#include "lower-to-ir.h" -#include "../slang/parser.h" -#include "../slang/preprocessor.h" -#include "../slang/reflection.h" -#include "syntax-visitors.h" -#include "../slang/type-layout.h" +#include "slang-parameter-binding.h" +#include "slang-lower-to-ir.h" +#include "slang-parser.h" +#include "slang-preprocessor.h" +#include "slang-reflection.h" +#include "slang-syntax-visitors.h" +#include "slang-type-layout.h" #include "slang-file-system.h" + #include "../core/slang-writer.h" -#include "source-loc.h" +#include "slang-source-loc.h" -#include "ir-serialize.h" +#include "slang-ir-serialize.h" // Used to print exception type names in internal-compiler-error messages #include @@ -47,15 +48,15 @@ Session::Session() #define SYNTAX_CLASS(NAME, BASE) \ mapNameToSyntaxClass.Add(getNamePool()->getName(#NAME), getClass()); -#include "object-meta-begin.h" -#include "syntax-base-defs.h" -#include "expr-defs.h" -#include "decl-defs.h" -#include "modifier-defs.h" -#include "stmt-defs.h" -#include "type-defs.h" -#include "val-defs.h" -#include "object-meta-end.h" +#include "slang-object-meta-begin.h" +#include "slang-syntax-base-defs.h" +#include "slang-expr-defs.h" +#include "slang-decl-defs.h" +#include "slang-modifier-defs.h" +#include "slang-stmt-defs.h" +#include "slang-type-defs.h" +#include "slang-val-defs.h" +#include "slang-object-meta-end.h" // Make sure our source manager is initialized builtinSourceManager.initialize(nullptr, nullptr); diff --git a/source/slang/slang.vcxproj b/source/slang/slang.vcxproj index f85aad7e0..3f10a36e1 100644 --- a/source/slang/slang.vcxproj +++ b/source/slang/slang.vcxproj @@ -171,126 +171,126 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - - - - - - - - - - - + + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - - - diff --git a/source/slang/slang.vcxproj.filters b/source/slang/slang.vcxproj.filters index f08395da7..7103749c0 100644 --- a/source/slang/slang.vcxproj.filters +++ b/source/slang/slang.vcxproj.filters @@ -12,360 +12,360 @@ Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Header Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files - + Source Files diff --git a/source/slang/source-loc.cpp b/source/slang/source-loc.cpp deleted file mode 100644 index c66fdedb5..000000000 --- a/source/slang/source-loc.cpp +++ /dev/null @@ -1,591 +0,0 @@ -// source-loc.cpp -#include "source-loc.h" - -#include "compiler.h" - -#include "../core/slang-string-util.h" - -namespace Slang { - -/* !!!!!!!!!!!!!!!!!!!!!!!!! SourceView !!!!!!!!!!!!!!!!!!!!!!!!!!!! */ - -const String PathInfo::getMostUniqueIdentity() const -{ - switch (type) - { - case Type::Normal: return uniqueIdentity; - case Type::FoundPath: - case Type::FromString: - { - return foundPath; - } - default: return ""; - } -} - -/* !!!!!!!!!!!!!!!!!!!!!!!!! SourceView !!!!!!!!!!!!!!!!!!!!!!!!!!!! */ - -int SourceView::findEntryIndex(SourceLoc sourceLoc) const -{ - if (!m_range.contains(sourceLoc)) - { - return -1; - } - - const auto rawValue = sourceLoc.getRaw(); - - Index hi = m_entries.getCount(); - // If there are no entries, or it is in front of the first entry, then there is no associated entry - if (hi == 0 || - m_entries[0].m_startLoc.getRaw() > sourceLoc.getRaw()) - { - return -1; - } - - Index lo = 0; - while (lo + 1 < hi) - { - const Index mid = (hi + lo) >> 1; - const Entry& midEntry = m_entries[mid]; - SourceLoc::RawValue midValue = midEntry.m_startLoc.getRaw(); - if (midValue <= rawValue) - { - // The location we seek is at or after this entry - lo = mid; - } - else - { - // The location we seek is before this entry - hi = mid; - } - } - - return int(lo); -} - -void SourceView::addLineDirective(SourceLoc directiveLoc, StringSlicePool::Handle pathHandle, int line) -{ - SLANG_ASSERT(pathHandle != StringSlicePool::Handle(0)); - SLANG_ASSERT(m_range.contains(directiveLoc)); - - // Check that the directiveLoc values are always increasing - SLANG_ASSERT(m_entries.getCount() == 0 || (m_entries.getLast().m_startLoc.getRaw() < directiveLoc.getRaw())); - - // Calculate the offset - const int offset = m_range.getOffset(directiveLoc); - - // Get the line index in the original file - const int lineIndex = m_sourceFile->calcLineIndexFromOffset(offset); - - Entry entry; - entry.m_startLoc = directiveLoc; - entry.m_pathHandle = pathHandle; - - // We also need to make sure that any lookups for line numbers will - // get corrected based on this files location. - // We assume the line number coming from the directive is a line number, NOT an index, so the correction needs + 1 - // There is an additional + 1 because we want the NEXT line - ie the line after the #line directive, to the specified value - // Taking both into account means +2 is correct 'fix' - entry.m_lineAdjust = line - (lineIndex + 2); - - m_entries.add(entry); -} - -void SourceView::addLineDirective(SourceLoc directiveLoc, const String& path, int line) -{ - StringSlicePool::Handle pathHandle = getSourceManager()->getStringSlicePool().add(path.getUnownedSlice()); - return addLineDirective(directiveLoc, pathHandle, line); -} - -void SourceView::addDefaultLineDirective(SourceLoc directiveLoc) -{ - SLANG_ASSERT(m_range.contains(directiveLoc)); - // Check that the directiveLoc values are always increasing - SLANG_ASSERT(m_entries.getCount() == 0 || (m_entries.getLast().m_startLoc.getRaw() < directiveLoc.getRaw())); - - // Well if there are no entries, or the last one puts it in default case, then we don't need to add anything - if (m_entries.getCount() == 0 || (m_entries.getCount() && m_entries.getLast().isDefault())) - { - return; - } - - Entry entry; - entry.m_startLoc = directiveLoc; - entry.m_lineAdjust = 0; // No line adjustment... we are going back to default - entry.m_pathHandle = StringSlicePool::Handle(0); // Mark that there is no path, and that this is a 'default' - - SLANG_ASSERT(entry.isDefault()); - - m_entries.add(entry); -} - -HumaneSourceLoc SourceView::getHumaneLoc(SourceLoc loc, SourceLocType type) -{ - const int offset = m_range.getOffset(loc); - - // We need the line index from the original source file - const int lineIndex = m_sourceFile->calcLineIndexFromOffset(offset); - - // TODO: we should really translate the byte index in the line - // to deal with: - // - // - Non-ASCII characters, while might consume multiple bytes - // - // - Tab characters, which should really adjust how we report - // columns (although how are we supposed to know the setting - // that an IDE expects us to use when reporting locations?) - const int columnIndex = m_sourceFile->calcColumnIndex(lineIndex, offset); - - HumaneSourceLoc humaneLoc; - humaneLoc.column = columnIndex + 1; - humaneLoc.line = lineIndex + 1; - - // Make up a default entry - StringSlicePool::Handle pathHandle = StringSlicePool::Handle(0); - - // Only bother looking up the entry information if we want a 'Normal' lookup - const int entryIndex = (type == SourceLocType::Nominal) ? findEntryIndex(loc) : -1; - if (entryIndex >= 0) - { - const Entry& entry = m_entries[entryIndex]; - // Adjust the line - humaneLoc.line += entry.m_lineAdjust; - // Get the pathHandle.. - pathHandle = entry.m_pathHandle; - } - - humaneLoc.pathInfo = _getPathInfoFromHandle(pathHandle); - return humaneLoc; -} - -PathInfo SourceView::_getPathInfo() const -{ - if (m_viewPath.getLength()) - { - PathInfo pathInfo(m_sourceFile->getPathInfo()); - pathInfo.foundPath = m_viewPath; - return pathInfo; - } - else - { - return m_sourceFile->getPathInfo(); - } -} - -PathInfo SourceView::_getPathInfoFromHandle(StringSlicePool::Handle pathHandle) const -{ - // If there is no override path, then just the source files path - if (pathHandle == StringSlicePool::Handle(0)) - { - return _getPathInfo(); - } - else - { - return PathInfo::makePath(getSourceManager()->getStringSlicePool().getSlice(pathHandle)); - } -} - -PathInfo SourceView::getPathInfo(SourceLoc loc, SourceLocType type) -{ - if (type == SourceLocType::Actual) - { - return _getPathInfo(); - } - - const int entryIndex = findEntryIndex(loc); - return _getPathInfoFromHandle((entryIndex >= 0) ? m_entries[entryIndex].m_pathHandle : StringSlicePool::Handle(0)); -} - -/* !!!!!!!!!!!!!!!!!!!!!!! SourceFile !!!!!!!!!!!!!!!!!!!!!!!!!!!! */ - -void SourceFile::setLineBreakOffsets(const uint32_t* offsets, UInt numOffsets) -{ - m_lineBreakOffsets.clear(); - m_lineBreakOffsets.addRange(offsets, numOffsets); -} - -const List& SourceFile::getLineBreakOffsets() -{ - // We now have a raw input file that we can search for line breaks. - // We obviously don't want to do a linear scan over and over, so we will - // cache an array of line break locations in the file. - if (m_lineBreakOffsets.getCount() == 0) - { - UnownedStringSlice content = getContent(); - - char const* begin = content.begin(); - char const* end = content.end(); - - char const* cursor = begin; - - // Treat the beginning of the file as a line break - m_lineBreakOffsets.add(0); - - while (cursor != end) - { - int c = *cursor++; - switch (c) - { - case '\r': case '\n': - { - // When we see a line-break character we need - // to record the line break, but we also need - // to deal with the annoying issue of encodings, - // where a multi-byte sequence might encode - // the line break. - - // Check to make sure that the EOF hasn't been reached. - if (cursor != end) - { - int d = *cursor; - if ((c ^ d) == ('\r' ^ '\n')) - cursor++; - } - - m_lineBreakOffsets.add(uint32_t(cursor - begin)); - break; - } - default: - break; - } - } - - // Note that we do *not* treat the end of the file as a line - // break, because otherwise we would report errors like - // "end of file inside string literal" with a line number - // that points at a line that doesn't exist. - } - - return m_lineBreakOffsets; -} - -int SourceFile::calcLineIndexFromOffset(int offset) -{ - SLANG_ASSERT(UInt(offset) <= getContentSize()); - - // Make sure we have the line break offsets - const auto& lineBreakOffsets = getLineBreakOffsets(); - - // At this point we can assume the `lineBreakOffsets` array has been filled in. - // We will use a binary search to find the line index that contains our - // chosen offset. - Index lo = 0; - Index hi = lineBreakOffsets.getCount(); - - while (lo + 1 < hi) - { - const Index mid = (hi + lo) >> 1; - const uint32_t midOffset = lineBreakOffsets[mid]; - if (midOffset <= uint32_t(offset)) - { - lo = mid; - } - else - { - hi = mid; - } - } - - return int(lo); -} - -int SourceFile::calcColumnIndex(int lineIndex, int offset) -{ - const auto& lineBreakOffsets = getLineBreakOffsets(); - return offset - lineBreakOffsets[lineIndex]; -} - -/* !!!!!!!!!!!!!!!!!!!!!!!!! SourceFile !!!!!!!!!!!!!!!!!!!!!!!!!!!! */ - -void SourceFile::setContents(ISlangBlob* blob) -{ - const UInt contentSize = blob->getBufferSize(); - - SLANG_ASSERT(contentSize == m_contentSize); - - char const* contentBegin = (char const*)blob->getBufferPointer(); - char const* contentEnd = contentBegin + contentSize; - - m_contentBlob = blob; - m_content = UnownedStringSlice(contentBegin, contentEnd); -} - -void SourceFile::setContents(const String& content) -{ - ComPtr contentBlob = StringUtil::createStringBlob(content); - setContents(contentBlob); -} - -SourceFile::SourceFile(SourceManager* sourceManager, const PathInfo& pathInfo, size_t contentSize) : - m_sourceManager(sourceManager), - m_pathInfo(pathInfo), - m_contentSize(contentSize) -{ -} - -SourceFile::~SourceFile() -{ -} - -String SourceFile::calcVerbosePath() const -{ - ISlangFileSystemExt* fileSystemExt = getSourceManager()->getFileSystemExt(); - - if (fileSystemExt) - { - String canonicalPath; - ComPtr canonicalPathBlob; - if (SLANG_SUCCEEDED(fileSystemExt->getCanonicalPath(m_pathInfo.foundPath.getBuffer(), canonicalPathBlob.writeRef()))) - { - canonicalPath = StringUtil::getString(canonicalPathBlob); - } - if (canonicalPath.getLength() > 0) - { - return canonicalPath; - } - } - - return m_pathInfo.foundPath; -} - -/* !!!!!!!!!!!!!!!!!!!!!!!!! SourceManager !!!!!!!!!!!!!!!!!!!!!!!!!!!! */ - -void SourceManager::initialize( - SourceManager* p, - ISlangFileSystemExt* fileSystemExt) -{ - m_fileSystemExt = fileSystemExt; - - m_parent = p; - - if( p ) - { - // If we have a parent source manager, then we assume that all code at that level - // has already been loaded, and it is safe to start our own source locations - // right after those from the parent. - // - // TODO: more clever allocation in cases where that might not be reasonable - m_startLoc = p->m_nextLoc; - } - else - { - // Location zero is reserved for an invalid location, - // so we need to start reserving locations starting at 1. - m_startLoc = SourceLoc::fromRaw(1); - } - - m_nextLoc = m_startLoc; -} - -SourceManager::~SourceManager() -{ - for (auto item : m_sourceViews) - { - delete item; - } - - for (auto item : m_sourceFiles) - { - delete item; - } -} - -UnownedStringSlice SourceManager::allocateStringSlice(const UnownedStringSlice& slice) -{ - const UInt numChars = slice.size(); - - char* dst = (char*)m_memoryArena.allocate(numChars); - ::memcpy(dst, slice.begin(), numChars); - - return UnownedStringSlice(dst, numChars); -} - -SourceRange SourceManager::allocateSourceRange(UInt size) -{ - // TODO: consider using atomics here - - - SourceLoc beginLoc = m_nextLoc; - SourceLoc endLoc = beginLoc + size; - - // We need to be able to represent the location that is *at* the end of - // the input source, so the next available location for a new file - // must be placed one after the end of this one. - - m_nextLoc = endLoc + 1; - - return SourceRange(beginLoc, endLoc); -} - -SourceFile* SourceManager::createSourceFileWithSize(const PathInfo& pathInfo, size_t contentSize) -{ - SourceFile* sourceFile = new SourceFile(this, pathInfo, contentSize); - m_sourceFiles.add(sourceFile); - return sourceFile; -} - -SourceFile* SourceManager::createSourceFileWithString(const PathInfo& pathInfo, const String& contents) -{ - SourceFile* sourceFile = new SourceFile(this, pathInfo, contents.getLength()); - m_sourceFiles.add(sourceFile); - sourceFile->setContents(contents); - return sourceFile; -} - -SourceFile* SourceManager::createSourceFileWithBlob(const PathInfo& pathInfo, ISlangBlob* blob) -{ - SourceFile* sourceFile = new SourceFile(this, pathInfo, blob->getBufferSize()); - m_sourceFiles.add(sourceFile); - sourceFile->setContents(blob); - return sourceFile; -} - -SourceView* SourceManager::createSourceView(SourceFile* sourceFile, const PathInfo* pathInfo) -{ - SourceRange range = allocateSourceRange(sourceFile->getContentSize()); - - SourceView* sourceView = nullptr; - if (pathInfo && - (pathInfo->foundPath.getLength() && sourceFile->getPathInfo().foundPath != pathInfo->foundPath)) - { - sourceView = new SourceView(sourceFile, range, &pathInfo->foundPath); - } - else - { - sourceView = new SourceView(sourceFile, range, nullptr); - } - - m_sourceViews.add(sourceView); - - return sourceView; -} - -SourceView* SourceManager::findSourceView(SourceLoc loc) const -{ - Index hi = m_sourceViews.getCount(); - // It must be in the range of this manager and have associated views for it to possibly be a hit - if (!getSourceRange().contains(loc) || hi == 0) - { - return nullptr; - } - - // If we don't have very many, we may as well just linearly search - if (hi <= 8) - { - for (int i = 0; i < hi; ++i) - { - SourceView* view = m_sourceViews[i]; - if (view->getRange().contains(loc)) - { - return view; - } - } - return nullptr; - } - - const SourceLoc::RawValue rawLoc = loc.getRaw(); - - // Binary chop to see if we can find the associated SourceUnit - Index lo = 0; - while (lo + 1 < hi) - { - Index mid = (hi + lo) >> 1; - - SourceView* midView = m_sourceViews[mid]; - if (midView->getRange().contains(loc)) - { - return midView; - } - - const SourceLoc::RawValue midValue = midView->getRange().begin.getRaw(); - if (midValue <= rawLoc) - { - // The location we seek is at or after this entry - lo = mid; - } - else - { - // The location we seek is before this entry - hi = mid; - } - } - - // Check if low is actually a hit - SourceView* view = m_sourceViews[lo]; - return (view->getRange().contains(loc)) ? view : nullptr; -} - -SourceView* SourceManager::findSourceViewRecursively(SourceLoc loc) const -{ - // Start with this manager - const SourceManager* manager = this; - do - { - SourceView* sourceView = manager->findSourceView(loc); - // If we found a hit we are done - if (sourceView) - { - return sourceView; - } - // Try the parent - manager = manager->m_parent; - } - while (manager); - // Didn't find it - return nullptr; -} - -SourceFile* SourceManager::findSourceFile(const String& uniqueIdentity) const -{ - SourceFile*const* filePtr = m_sourceFileMap.TryGetValue(uniqueIdentity); - return (filePtr) ? *filePtr : nullptr; -} - -SourceFile* SourceManager::findSourceFileRecursively(const String& uniqueIdentity) const -{ - const SourceManager* manager = this; - do - { - SourceFile* sourceFile = manager->findSourceFile(uniqueIdentity); - if (sourceFile) - { - return sourceFile; - } - manager = manager->m_parent; - } while (manager); - return nullptr; -} - -void SourceManager::addSourceFile(const String& uniqueIdentity, SourceFile* sourceFile) -{ - SLANG_ASSERT(!findSourceFileRecursively(uniqueIdentity)); - m_sourceFileMap.Add(uniqueIdentity, sourceFile); -} - -HumaneSourceLoc SourceManager::getHumaneLoc(SourceLoc loc, SourceLocType type) -{ - SourceView* sourceView = findSourceViewRecursively(loc); - if (sourceView) - { - return sourceView->getHumaneLoc(loc, type); - } - else - { - return HumaneSourceLoc(); - } -} - -PathInfo SourceManager::getPathInfo(SourceLoc loc, SourceLocType type) -{ - SourceView* sourceView = findSourceViewRecursively(loc); - if (sourceView) - { - return sourceView->getPathInfo(loc, type); - } - else - { - return PathInfo::makeUnknown(); - } -} - -} // namespace Slang diff --git a/source/slang/source-loc.h b/source/slang/source-loc.h deleted file mode 100644 index 95db7a50e..000000000 --- a/source/slang/source-loc.h +++ /dev/null @@ -1,412 +0,0 @@ -// source-loc.h -#ifndef SLANG_SOURCE_LOC_H_INCLUDED -#define SLANG_SOURCE_LOC_H_INCLUDED - -#include "../core/basic.h" -#include "../core/slang-memory-arena.h" -#include "../core/slang-string-slice-pool.h" - -#include "../../slang-com-ptr.h" -#include "../../slang.h" - -namespace Slang { - -/** Overview: - -There needs to be a mechanism where we can easily and quickly track a specific locations in any source file used during a compilation. -This is important because that original location is meaningful to the user as it relates to their original source. Thus SourceLoc are -used so we can display meaningful and accurate errors/warnings as well as being able to always map generated code locations back to their origins. - -A 'SourceLoc' along with associated structures (SourceView, SourceFile, SourceMangager) this can pinpoint the location down to the byte across the -compilation. This could be achieved by storing for every token and instruction the file, line and column number came from. The SourceLoc is used in -lots of places - every AST node, every Token from the lexer, every IRInst - so we really want to make it small. So for this reason we actually -encode SourceLoc as a single integer and then use the associated structures when needed to determine what the location actually refers to - -the source file, line and column number, or in effect the byte in the original file. - -Unfortunately there is extra complications. When a source is parsed it's interpretation (in terms of how a piece of source maps to an 'original' file etc) -can be overridden - for example by using #line directives. Moreover a single source file can be parsed multiple times. When it's parsed multiple times the -interpretation of the mapping (#line directives for example) can change. This is the purpose of the SourceView - it holds the interpretation of a source file -for a specific Lex/Parse. - -Another complication is that not all 'source' comes from SourceFiles, a macro expansion, may generate new 'source' we need to handle this, but also be able -to have a SourceLoc map to the expansion unambiguously. This is handled by creating a SourceFile and SourceView that holds only the macro generated -specific information. - -SourceFile - Is the immutable text contents of a file (or perhaps some generated source - say from doing a macro substitution) -SourceView - Tracks a single parse of a SourceFile. Each SourceView defines a range of source locations used. If a SourceFile is parsed twice, two -SourceViews are created, with unique SourceRanges. This is so that it is possible to tell which specific parse a SourceLoc is from - and so know the right -interpretation for that lex/parse. -*/ - -struct PathInfo -{ - /// To be more rigorous about where a path comes from, the type identifies what a paths origin is - enum class Type - { - Unknown, ///< The path is not known - Normal, ///< Normal has both path and uniqueIdentity - FoundPath, ///< Just has a found path (uniqueIdentity is unknown, or even 'unknowable') - FromString, ///< Created from a string (so found path might not be defined and should not be taken as to map to a loaded file) - TokenPaste, ///< No paths, just created to do a macro expansion - TypeParse, ///< No path, just created to do a type parse - CommandLine, ///< A macro constructed from the command line - }; - - /// True if has a canonical path - SLANG_FORCE_INLINE bool hasUniqueIdentity() const { return type == Type::Normal && uniqueIdentity.getLength() > 0; } - /// True if has a regular found path - SLANG_FORCE_INLINE bool hasFoundPath() const { return type == Type::Normal || type == Type::FoundPath || (type == Type::FromString && foundPath.getLength() > 0); } - /// True if has a found path that has originated from a file (as opposed to string or some other origin) - SLANG_FORCE_INLINE bool hasFileFoundPath() const { return (type == Type::Normal || type == Type::FoundPath) && foundPath.getLength() > 0; } - - /// Returns the 'most unique' identity for the path. If has a 'uniqueIdentity' returns that, else the foundPath, else "". - const String getMostUniqueIdentity() const; - - // So simplify construction. In normal usage it's safer to use make methods over constructing directly. - static PathInfo makeUnknown() { return PathInfo { Type::Unknown, "unknown", String() }; } - static PathInfo makeTokenPaste() { return PathInfo{ Type::TokenPaste, "token paste", String()}; } - static PathInfo makeNormal(const String& foundPathIn, const String& uniqueIdentity) { SLANG_ASSERT(uniqueIdentity.getLength() > 0 && foundPathIn.getLength() > 0); return PathInfo { Type::Normal, foundPathIn, uniqueIdentity }; } - static PathInfo makePath(const String& pathIn) { SLANG_ASSERT(pathIn.getLength() > 0); return PathInfo { Type::FoundPath, pathIn, String()}; } - static PathInfo makeTypeParse() { return PathInfo { Type::TypeParse, "type string", String() }; } - static PathInfo makeCommandLine() { return PathInfo { Type::CommandLine, "command line", String() }; } - static PathInfo makeFromString(const String& userPath) { return PathInfo{ Type::FromString, userPath, String() }; } - - Type type; ///< The type of path - String foundPath; ///< The path where the file was found (might contain relative elements) - String uniqueIdentity; ///< The unique identity of the file on the path found -}; - -class SourceLoc -{ -public: - typedef uint32_t RawValue; - -private: - RawValue raw; - -public: - SourceLoc() - : raw(0) - {} - - SourceLoc( - SourceLoc const& loc) - : raw(loc.raw) - {} - - RawValue getRaw() const { return raw; } - void setRaw(RawValue value) { raw = value; } - - static SourceLoc fromRaw(RawValue value) - { - SourceLoc result; - result.setRaw(value); - return result; - } - - bool isValid() const - { - return raw != 0; - } -}; - -inline SourceLoc operator+(SourceLoc loc, Int offset) -{ - return SourceLoc::fromRaw(SourceLoc::RawValue(Int(loc.getRaw()) + offset)); -} - -// A range of locations in the input source -struct SourceRange -{ - /// True if the loc is in the range. Range is inclusive on begin to end. - bool contains(SourceLoc loc) const { const auto rawLoc = loc.getRaw(); return rawLoc >= begin.getRaw() && rawLoc <= end.getRaw(); } - /// Get the total size - UInt getSize() const { return UInt(end.getRaw() - begin.getRaw()); } - - /// Get the offset of a loc in this range - int getOffset(SourceLoc loc) const { SLANG_ASSERT(contains(loc)); return int(loc.getRaw() - begin.getRaw()); } - - SourceRange() - {} - - SourceRange(SourceLoc loc) - : begin(loc) - , end(loc) - {} - - SourceRange(SourceLoc begin, SourceLoc end) - : begin(begin) - , end(end) - {} - - SourceLoc begin; - SourceLoc end; -}; - -// Pre-declare -struct SourceManager; - -// A logical or physical storage object for a range of input code -// that has logically contiguous source locations. -class SourceFile -{ -public: - - /// Returns the line break offsets (in bytes from start of content) - /// Note that this is lazily evaluated - the line breaks are only calculated on the first request - const List& getLineBreakOffsets(); - - /// Set the line break offsets - void setLineBreakOffsets(const uint32_t* offsets, UInt numOffsets); - - /// Calculate the line based on the offset - int calcLineIndexFromOffset(int offset); - - /// Calculate the offset for a line - int calcColumnIndex(int line, int offset); - - /// Get the content holding blob - ISlangBlob* getContentBlob() const { return m_contentBlob; } - - /// True if has full set content - bool hasContent() const { return m_contentBlob != nullptr; } - - /// Get the content size - size_t getContentSize() const { return m_contentSize; } - - /// Get the content - const UnownedStringSlice& getContent() const { return m_content; } - - /// Get path info - const PathInfo& getPathInfo() const { return m_pathInfo; } - - /// Set the content as a blob - void setContents(ISlangBlob* blob); - /// Set the content as a string - void setContents(const String& content); - - /// Calculate a display path -> can canonicalize if necessary - String calcVerbosePath() const; - - /// Get the source manager this was created on - SourceManager* getSourceManager() const { return m_sourceManager; } - - /// Ctor - SourceFile(SourceManager* sourceManager, const PathInfo& pathInfo, size_t contentSize); - /// Dtor - ~SourceFile(); - - protected: - - SourceManager* m_sourceManager; ///< The source manager this belongs to - PathInfo m_pathInfo; ///< The path The logical file path to report for locations inside this span. - ComPtr m_contentBlob; ///< A blob that owns the storage for the file contents. If nullptr, there is no contents - UnownedStringSlice m_content; ///< The actual contents of the file. - size_t m_contentSize; ///< The size of the actual contents - - // In order to speed up lookup of line number information, - // we will cache the starting offset of each line break in - // the input file: - List m_lineBreakOffsets; -}; - -enum class SourceLocType -{ - Nominal, ///< The normal interpretation which takes into account #line directives - Actual, ///< Ignores #line directives - and is the location as seen in the actual file -}; - -// A source location in a format a human might like to see -struct HumaneSourceLoc -{ - PathInfo pathInfo; - Int line = 0; - Int column = 0; -}; - - -/* A SourceView maps to a single span of SourceLoc range and is equivalent to a single include or more precisely use of a source file. -It is distinct from a SourceFile - because a SourceFile may be included multiple times, with different interpretations (depending -on #defines for example). -*/ -class SourceView -{ - public: - - // Each entry represents some contiguous span of locations that - // all map to the same logical file. - struct Entry - { - /// True if this resets the line numbering. It is distinct from a m_lineAdjust being 0, because it also means the path returns to the default. - bool isDefault() const { return m_pathHandle == StringSlicePool::Handle(0); } - - SourceLoc m_startLoc; ///< Where does this entry begin? - StringSlicePool::Handle m_pathHandle; ///< What is the presumed path for this entry. If 0 it means there is no path. - int32_t m_lineAdjust; ///< Adjustment to apply to source line numbers when printing presumed locations. Relative to the line number in the underlying file. - }; - - /// Given a sourceLoc finds the entry associated with it. If returns -1 then no entry is - /// associated with this location, and therefore the location should be interpreted as an offset - /// into the underlying sourceFile. - int findEntryIndex(SourceLoc sourceLoc) const; - - /// Add a line directive for this view. The directiveLoc must of course be in this SourceView - /// The path handle, must have been constructed on the SourceManager associated with the view - /// NOTE! Directives are assumed to be added IN ORDER during parsing such that every directiveLoc > previous - void addLineDirective(SourceLoc directiveLoc, StringSlicePool::Handle pathHandle, int line); - void addLineDirective(SourceLoc directiveLoc, const String& path, int line); - - /// Removes any corrections on line numbers and reverts to the source files path - void addDefaultLineDirective(SourceLoc directiveLoc); - - /// Get the range that this view applies to - const SourceRange& getRange() const { return m_range; } - /// Get the entries - const List& getEntries() const { return m_entries; } - /// Set the entries list - void setEntries(const Entry* entries, UInt numEntries) { m_entries.clear(); m_entries.addRange(entries, numEntries); } - - /// Get the source file holds the contents this view - SourceFile* getSourceFile() const { return m_sourceFile; } - /// Get the source manager - SourceManager* getSourceManager() const { return m_sourceFile->getSourceManager(); } - - /// Get the associated 'content' (the source text) - const UnownedStringSlice& getContent() const { return m_sourceFile->getContent(); } - - /// Get the size of the content - size_t getContentSize() const { return m_sourceFile->getContentSize(); } - - /// Get the humane location - /// Type determines if the location wanted is the original, or the 'normal' (which modifys behavior based on #line directives) - HumaneSourceLoc getHumaneLoc(SourceLoc loc, SourceLocType type = SourceLocType::Nominal); - - /// Get the path associated with a location - PathInfo getPathInfo(SourceLoc loc, SourceLocType type = SourceLocType::Nominal); - - /// Ctor - SourceView(SourceFile* sourceFile, SourceRange range, const String* viewPath): - m_range(range), - m_sourceFile(sourceFile) - { - if (viewPath) - { - m_viewPath = *viewPath; - } - } - - protected: - /// Get the pathInfo from a string handle. If it's 0, it will return the _getPathInfo - PathInfo _getPathInfoFromHandle(StringSlicePool::Handle pathHandle) const; - /// Gets the pathInfo for this view. It may be different from the m_sourceFile's if the path has been - /// overridden by m_viewPath - PathInfo _getPathInfo() const; - - String m_viewPath; ///< Path to this view. If empty the path is the path to the SourceView - - SourceRange m_range; ///< The range that this SourceView applies to - SourceFile* m_sourceFile; ///< The source file. Can hold the line breaks - List m_entries; ///< An array entries describing how we should interpret a range, starting from the start location. -}; - -struct SourceManager -{ - // Initialize a source manager, with an optional parent - void initialize(SourceManager* parent, ISlangFileSystemExt* fileSystemExt); - - /// Allocate a range of SourceLoc locations, these can be used to identify a specific location in the source - SourceRange allocateSourceRange(UInt size); - - /// Create a SourceFile defined with the specified path, and content held within a blob - SourceFile* createSourceFileWithSize(const PathInfo& pathInfo, size_t contentSize); - SourceFile* createSourceFileWithString(const PathInfo& pathInfo, const String& contents); - SourceFile* createSourceFileWithBlob(const PathInfo& pathInfo, ISlangBlob* blob); - - /// Get the humane source location - HumaneSourceLoc getHumaneLoc(SourceLoc loc, SourceLocType type = SourceLocType::Nominal); - - /// Get the path associated with a location - PathInfo getPathInfo(SourceLoc loc, SourceLocType type = SourceLocType::Nominal); - - /// Create a new source view from a file - /// @param sourceFile is the source file that contains the source - /// @param pathInfo is path used to read the file from - SourceView* createSourceView(SourceFile* sourceFile, const PathInfo* pathInfo); - - /// Find a view by a source file location. - /// If not found in this manager will look in the parent SourceManager - /// Returns nullptr if not found. - SourceView* findSourceViewRecursively(SourceLoc loc) const; - - /// Find the SourceView associated with this manager for a specified location - /// Returns nullptr if not found. - SourceView* findSourceView(SourceLoc loc) const; - - /// Searches this manager, and then the parent to see if can find a match for path. - /// If not found returns nullptr. - SourceFile* findSourceFileRecursively(const String& uniqueIdentity) const; - /// Find if the source file is defined on this manager. - SourceFile* findSourceFile(const String& uniqueIdentity) const; - - /// Get the file system associated with this source manager - ISlangFileSystemExt* getFileSystemExt() const { return m_fileSystemExt; } - /// Get the file system associated with this source manager - void setFileSystemExt(ISlangFileSystemExt* fileSystemExt) { m_fileSystemExt = fileSystemExt; } - - /// Add a source file, uniqueIdentity must be unique for this manager AND any parents - void addSourceFile(const String& uniqueIdentity, SourceFile* sourceFile); - - /// Get the slice pool - StringSlicePool& getStringSlicePool() { return m_slicePool; } - - /// Get the source range for just this manager - /// Caution - the range will change if allocations are made to this manager. - SourceRange getSourceRange() const { return SourceRange(m_startLoc, m_nextLoc); } - - /// Get the parent manager to this manager. Returns nullptr if there isn't any. - SourceManager* getParent() const { return m_parent; } - - /// A memory arena to hold allocations that are in scope for the same time as SourceManager - MemoryArena* getMemoryArena() { return &m_memoryArena; } - - /// Allocate a string slice - UnownedStringSlice allocateStringSlice(const UnownedStringSlice& slice); - - SourceManager() : - m_memoryArena(2048) - {} - ~SourceManager(); - - protected: - - // The first location available to this source manager - // (may not be the first location of all, because we might - // have a parent source manager) - SourceLoc m_startLoc; - - // The "parent" source manager that owns locations ahead of `startLoc` - SourceManager* m_parent = nullptr; - - // The location to be used by the next source file to be loaded - SourceLoc m_nextLoc; - - // All of the SourceViews constructed on this SourceManager. These are held in increasing order of range, so can find by doing a binary chop. - List m_sourceViews; - // All of the SourceFiles constructed on this SourceManager. This owns the SourceFile. - List m_sourceFiles; - - StringSlicePool m_slicePool; - - // Memory arena that can be used for holding data to held in scope as long as the Source is - // Can be used for storing the decoded contents of Token. Content for example. - MemoryArena m_memoryArena; - - // Maps uniqueIdentities to source files - Dictionary m_sourceFileMap; - - ComPtr m_fileSystemExt; -}; - -} // namespace Slang - -#endif diff --git a/source/slang/stmt-defs.h b/source/slang/stmt-defs.h deleted file mode 100644 index 01dbcc4ca..000000000 --- a/source/slang/stmt-defs.h +++ /dev/null @@ -1,124 +0,0 @@ -// stmt-defs.h - -// Syntax class definitions for statements. - -ABSTRACT_SYNTAX_CLASS(ScopeStmt, Stmt) - SYNTAX_FIELD(RefPtr, scopeDecl) -END_SYNTAX_CLASS() - -// A sequence of statements, treated as a single statement -SYNTAX_CLASS(SeqStmt, Stmt) - SYNTAX_FIELD(List>, stmts) -END_SYNTAX_CLASS() - -// The simplest kind of scope statement: just a `{...}` block -SYNTAX_CLASS(BlockStmt, ScopeStmt) - SYNTAX_FIELD(RefPtr, body); -END_SYNTAX_CLASS() - -// A statement that we aren't going to parse or check, because -// we want to let a downstream compiler handle any issues -SYNTAX_CLASS(UnparsedStmt, Stmt) - // The tokens that were contained between `{` and `}` - FIELD(List, tokens) -END_SYNTAX_CLASS() - -SIMPLE_SYNTAX_CLASS(EmptyStmt, Stmt) - -SIMPLE_SYNTAX_CLASS(DiscardStmt, Stmt) - -SYNTAX_CLASS(DeclStmt, Stmt) - SYNTAX_FIELD(RefPtr, decl) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(IfStmt, Stmt) - SYNTAX_FIELD(RefPtr, Predicate) - SYNTAX_FIELD(RefPtr, PositiveStatement) - SYNTAX_FIELD(RefPtr, NegativeStatement) -END_SYNTAX_CLASS() - -// A statement that can be escaped with a `break` -ABSTRACT_SYNTAX_CLASS(BreakableStmt, ScopeStmt) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(SwitchStmt, BreakableStmt) - SYNTAX_FIELD(RefPtr, condition) - SYNTAX_FIELD(RefPtr, body) -END_SYNTAX_CLASS() - -// A statement that is expected to appear lexically nested inside -// some other construct, and thus needs to keep track of the -// outer statement that it is associated with... -ABSTRACT_SYNTAX_CLASS(ChildStmt, Stmt) - DECL_FIELD(Stmt*, parentStmt RAW(= nullptr)) -END_SYNTAX_CLASS() - -// a `case` or `default` statement inside a `switch` -// -// Note(tfoley): A correct AST for a C-like language would treat -// these as a labelled statement, and so they would contain a -// sub-statement. I'm leaving that out for now for simplicity. -ABSTRACT_SYNTAX_CLASS(CaseStmtBase, ChildStmt) -END_SYNTAX_CLASS() - -// a `case` statement inside a `switch` -SYNTAX_CLASS(CaseStmt, CaseStmtBase) - SYNTAX_FIELD(RefPtr, expr) -END_SYNTAX_CLASS() - -// a `default` statement inside a `switch` -SIMPLE_SYNTAX_CLASS(DefaultStmt, CaseStmtBase) - -// A statement that represents a loop, and can thus be escaped with a `continue` -ABSTRACT_SYNTAX_CLASS(LoopStmt, BreakableStmt) -END_SYNTAX_CLASS() - -// A `for` statement -SYNTAX_CLASS(ForStmt, LoopStmt) - SYNTAX_FIELD(RefPtr, InitialStatement) - SYNTAX_FIELD(RefPtr, SideEffectExpression) - SYNTAX_FIELD(RefPtr, PredicateExpression) - SYNTAX_FIELD(RefPtr, Statement) -END_SYNTAX_CLASS() - -// A `for` statement in a language that doesn't restrict the scope -// of the loop variable to the body. -SYNTAX_CLASS(UnscopedForStmt, ForStmt); -END_SYNTAX_CLASS() - -SYNTAX_CLASS(WhileStmt, LoopStmt) - SYNTAX_FIELD(RefPtr, Predicate) - SYNTAX_FIELD(RefPtr, Statement) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(DoWhileStmt, LoopStmt) - SYNTAX_FIELD(RefPtr, Statement) - SYNTAX_FIELD(RefPtr, Predicate) -END_SYNTAX_CLASS() - -// A compile-time, range-based `for` loop, which will not appear in the output code -SYNTAX_CLASS(CompileTimeForStmt, ScopeStmt) - SYNTAX_FIELD(RefPtr, varDecl) - SYNTAX_FIELD(RefPtr, rangeBeginExpr) - SYNTAX_FIELD(RefPtr, rangeEndExpr) - SYNTAX_FIELD(RefPtr, body) - SYNTAX_FIELD(RefPtr, rangeBeginVal) - SYNTAX_FIELD(RefPtr, rangeEndVal) -END_SYNTAX_CLASS() - -// The case of child statements that do control flow relative -// to their parent statement. -ABSTRACT_SYNTAX_CLASS(JumpStmt, ChildStmt) -END_SYNTAX_CLASS() - -SIMPLE_SYNTAX_CLASS(BreakStmt, JumpStmt) - -SIMPLE_SYNTAX_CLASS(ContinueStmt, JumpStmt) - -SYNTAX_CLASS(ReturnStmt, Stmt) - SYNTAX_FIELD(RefPtr, Expression) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(ExpressionStmt, Stmt) - SYNTAX_FIELD(RefPtr, Expression) -END_SYNTAX_CLASS() diff --git a/source/slang/syntax-base-defs.h b/source/slang/syntax-base-defs.h deleted file mode 100644 index b0da3f57e..000000000 --- a/source/slang/syntax-base-defs.h +++ /dev/null @@ -1,307 +0,0 @@ -// syntax-base-defs.h - -// This file defines the primary base classes for the hierarchy of -// AST nodes and related objects. For example, this is where the -// basic `Decl`, `Stmt`, `Expr`, `type`, etc. definitions come from. - -ABSTRACT_SYNTAX_CLASS(NodeBase, RefObject) - // A helper to access the corresponding class on a concrete instance - RAW( - virtual SyntaxClass getClass() = 0; - ) -END_SYNTAX_CLASS() - -// Base class for all nodes representing actual syntax -// (thus having a location in the source code) -ABSTRACT_SYNTAX_CLASS(SyntaxNodeBase, NodeBase) - // The primary source location associated with this AST node - FIELD(SourceLoc, loc) -END_SYNTAX_CLASS() - -// Base class for compile-time values (most often a type). -// These are *not* syntax nodes, because they do not have -// a unique location, and any two `Val`s representing -// the same value should be conceptually equal. -ABSTRACT_SYNTAX_CLASS(Val, NodeBase) - RAW(typedef IValVisitor Visitor;) - - RAW(virtual void accept(IValVisitor* visitor, void* extra) = 0;) - - RAW( - // construct a new value by applying a set of parameter - // substitutions to this one - RefPtr Substitute(SubstitutionSet subst); - - // Lower-level interface for substitution. Like the basic - // `Substitute` above, but also takes a by-reference - // integer parameter that should be incremented when - // returning a modified value (this can help the caller - // decide whether they need to do anything). - virtual RefPtr SubstituteImpl(SubstitutionSet subst, int* ioDiff); - - virtual bool EqualsVal(Val* val) = 0; - virtual String ToString() = 0; - virtual int GetHashCode() = 0; - bool operator == (const Val & v) - { - return EqualsVal(const_cast(&v)); - } - ) -END_SYNTAX_CLASS() - -RAW( - class Type; - - template - SLANG_FORCE_INLINE T* as(Type* obj); - template - SLANG_FORCE_INLINE const T* as(const Type* obj); - ) - -// A type, representing a classifier for some term in the AST. -// -// Types can include "sugar" in that they may refer to a -// `typedef` which gives them a good name when printed as -// part of diagnostic messages. -// -// In order to operation on types, though, we often want -// to look past any sugar, and operate on an underlying -// "canonical" type. The representation caches a pointer to -// a canonical type on every type, so we can easily -// operate on the raw representation when needed. -ABSTRACT_SYNTAX_CLASS(Type, Val) - RAW(typedef ITypeVisitor Visitor;) - - RAW(virtual void accept(IValVisitor* visitor, void* extra) override;) - RAW(virtual void accept(ITypeVisitor* visitor, void* extra) = 0;) - -RAW( -public: - Session* getSession() { return this->session; } - void setSession(Session* s) { this->session = s; } - - bool Equals(Type* type); - - Type* GetCanonicalType(); - - virtual RefPtr SubstituteImpl(SubstitutionSet subst, int* ioDiff) override; - - virtual bool EqualsVal(Val* val) override; - - ~Type(); - -protected: - virtual bool EqualsImpl(Type* type) = 0; - - virtual RefPtr CreateCanonicalType() = 0; - Type* canonicalType = nullptr; - - Session* session = nullptr; - ) -END_SYNTAX_CLASS() -RAW( - template - SLANG_FORCE_INLINE T* as(Type* obj) { return obj ? dynamicCast(obj->GetCanonicalType()) : nullptr; } - template - SLANG_FORCE_INLINE const T* as(const Type* obj) { return obj ? dynamicCast(const_cast(obj)->GetCanonicalType()) : nullptr; } -) - -// A substitution represents a binding of certain -// type-level variables to concrete argument values -ABSTRACT_SYNTAX_CLASS(Substitutions, RefObject) - // The next outer that this one refines. - FIELD(RefPtr, outer) - - RAW( - // Apply a set of substitutions to the bindings in this substitution - virtual RefPtr applySubstitutionsShallow(SubstitutionSet substSet, RefPtr substOuter, int* ioDiff) = 0; - - // Check if these are equivalent substitutiosn to another set - virtual bool Equals(Substitutions* subst) = 0; - virtual int GetHashCode() const = 0; - ) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(GenericSubstitution, Substitutions) - // The generic declaration that defines the - // parameters we are binding to arguments - DECL_FIELD(GenericDecl*, genericDecl) - - // The actual values of the arguments - SYNTAX_FIELD(List>, args) - - RAW( - // Apply a set of substitutions to the bindings in this substitution - virtual RefPtr applySubstitutionsShallow(SubstitutionSet substSet, RefPtr substOuter, int* ioDiff) override; - - // Check if these are equivalent substitutiosn to another set - virtual bool Equals(Substitutions* subst) override; - - virtual int GetHashCode() const override - { - int rs = 0; - for (auto && v : args) - { - rs ^= v->GetHashCode(); - rs *= 16777619; - } - return rs; - } - ) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(ThisTypeSubstitution, Substitutions) - // The declaration of the interface that we are specializing - FIELD_INIT(InterfaceDecl*, interfaceDecl, nullptr) - - // A witness that shows that the concrete type used to - // specialize the interface conforms to the interface. - FIELD(RefPtr, witness) - - // The actual type that provides the lookup scope for an associated type - RAW( - // Apply a set of substitutions to the bindings in this substitution - virtual RefPtr applySubstitutionsShallow(SubstitutionSet substSet, RefPtr substOuter, int* ioDiff) override; - - // Check if these are equivalent substitutiosn to another set - virtual bool Equals(Substitutions* subst) override; - - virtual int GetHashCode() const override; - ) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(GlobalGenericParamSubstitution, Substitutions) - // the type_param decl to be substituted - DECL_FIELD(GlobalGenericParamDecl*, paramDecl) - - // the actual type to substitute in - SYNTAX_FIELD(RefPtr, actualType) - - RAW( - struct ConstraintArg - { - RefPtr decl; - RefPtr val; - }; - ) - - // the values that satisfy any constraints on the type parameter - SYNTAX_FIELD(List, constraintArgs) - -RAW( - // Apply a set of substitutions to the bindings in this substitution - virtual RefPtr applySubstitutionsShallow(SubstitutionSet substSet, RefPtr substOuter, int* ioDiff) override; - - // Check if these are equivalent substitutiosn to another set - virtual bool Equals(Substitutions* subst) override; - - virtual int GetHashCode() const override - { - int rs = actualType->GetHashCode(); - for (auto && a : constraintArgs) - { - rs = combineHash(rs, a.val->GetHashCode()); - } - return rs; - } - ) -END_SYNTAX_CLASS() - -ABSTRACT_SYNTAX_CLASS(SyntaxNode, SyntaxNodeBase) -END_SYNTAX_CLASS() - -// -// All modifiers are represented as full-fledged objects in the AST -// (that is, we don't use a bitfield, even for simple/common flags). -// This ensures that we can track source locations for all modifiers. -// -ABSTRACT_SYNTAX_CLASS(Modifier, SyntaxNode) - RAW(typedef IModifierVisitor Visitor;) - - RAW(virtual void accept(IModifierVisitor* visitor, void* extra) = 0;) - - // Next modifier in linked list of modifiers on same piece of syntax - SYNTAX_FIELD(RefPtr, next) - - // The keyword that was used to introduce t that was used to name this modifier. - FIELD(Name*, name) - - RAW( - Name* getName() { return name; } - NameLoc getNameAndLoc() { return NameLoc(name, loc); } - ) -END_SYNTAX_CLASS() - -// A syntax node which can have modifiers applied -ABSTRACT_SYNTAX_CLASS(ModifiableSyntaxNode, SyntaxNode) - - SYNTAX_FIELD(Modifiers, modifiers) - - RAW( - template - FilteredModifierList GetModifiersOfType() { return FilteredModifierList(modifiers.first.Ptr()); } - - // Find the first modifier of a given type, or return `nullptr` if none is found. - template - T* FindModifier() - { - return *GetModifiersOfType().begin(); - } - - template - bool HasModifier() { return FindModifier() != nullptr; } - ) -END_SYNTAX_CLASS() - - -// An intermediate type to represent either a single declaration, or a group of declarations -ABSTRACT_SYNTAX_CLASS(DeclBase, ModifiableSyntaxNode) - RAW(typedef IDeclVisitor Visitor;) - - RAW(virtual void accept(IDeclVisitor* visitor, void* extra) = 0;) - - -END_SYNTAX_CLASS() - -ABSTRACT_SYNTAX_CLASS(Decl, DeclBase) - DECL_FIELD(ContainerDecl*, ParentDecl RAW(=nullptr)) - - FIELD(NameLoc, nameAndLoc) - - RAW( - Name* getName() { return nameAndLoc.name; } - SourceLoc getNameLoc() { return nameAndLoc.loc ; } - NameLoc getNameAndLoc() { return nameAndLoc ; } - ) - - - FIELD_INIT(DeclCheckState, checkState, DeclCheckState::Unchecked) - - // The next declaration defined in the same container with the same name - DECL_FIELD(Decl*, nextInContainerWithSameName RAW(= nullptr)) - - RAW( - bool IsChecked(DeclCheckState state) { return checkState >= state; } - void SetCheckState(DeclCheckState state) - { - SLANG_RELEASE_ASSERT(state >= checkState); - checkState = state; - } - ) -END_SYNTAX_CLASS() - -ABSTRACT_SYNTAX_CLASS(Expr, SyntaxNode) - RAW(typedef IExprVisitor Visitor;) - - FIELD(QualType, type) - - RAW(virtual void accept(IExprVisitor* visitor, void* extra) = 0;) - -END_SYNTAX_CLASS() - -ABSTRACT_SYNTAX_CLASS(Stmt, ModifiableSyntaxNode) - RAW(typedef IStmtVisitor Visitor;) - - RAW(virtual void accept(IStmtVisitor* visitor, void* extra) = 0;) - -END_SYNTAX_CLASS() diff --git a/source/slang/syntax-defs.h b/source/slang/syntax-defs.h deleted file mode 100644 index 4ff4a55a6..000000000 --- a/source/slang/syntax-defs.h +++ /dev/null @@ -1,10 +0,0 @@ -// syntax-defs.h - -#include "syntax-base-defs.h" - -#include "expr-defs.h" -#include "decl-defs.h" -#include "modifier-defs.h" -#include "stmt-defs.h" -#include "type-defs.h" -#include "val-defs.h" diff --git a/source/slang/syntax-visitors.h b/source/slang/syntax-visitors.h deleted file mode 100644 index 3fca323e8..000000000 --- a/source/slang/syntax-visitors.h +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef RASTER_RENDERER_SYNTAX_PRINTER_H -#define RASTER_RENDERER_SYNTAX_PRINTER_H - -#include "diagnostics.h" -#include "syntax.h" - -namespace Slang -{ - class DiagnosticSink; - class EntryPoint; - class Linkage; - class Module; - class ShaderCompiler; - class ShaderLinkInfo; - class ShaderSymbol; - - class TranslationUnitRequest; - - void checkTranslationUnit( - TranslationUnitRequest* translationUnit); - - // Look for a module that matches the given name: - // either one we've loaded already, or one we - // can find vai the search paths available to us. - // - // Needed by import declaration checking. - // - // TODO: need a better location to declare this. - RefPtr findOrImportModule( - Linkage* linkage, - Name* name, - SourceLoc const& loc, - DiagnosticSink* sink); -} - -#endif \ No newline at end of file diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp deleted file mode 100644 index 17c85175d..000000000 --- a/source/slang/syntax.cpp +++ /dev/null @@ -1,2865 +0,0 @@ -#include "syntax.h" - -#include "compiler.h" -#include "visitor.h" - -#include -#include - -namespace Slang -{ - // BasicExpressionType - - bool BasicExpressionType::EqualsImpl(Type * type) - { - auto basicType = dynamicCast(type); - return basicType && basicType->baseType == this->baseType; - } - - RefPtr BasicExpressionType::CreateCanonicalType() - { - // A basic type is already canonical, in our setup - return this; - } - - // Generate dispatch logic and other definitions for all syntax classes -#define SYNTAX_CLASS(NAME, BASE) /* empty */ -#include "object-meta-begin.h" - -#include "syntax-base-defs.h" -#undef SYNTAX_CLASS -#undef ABSTRACT_SYNTAX_CLASS - -#define ABSTRACT_SYNTAX_CLASS(NAME, BASE) \ - template<> \ - SyntaxClassBase::ClassInfo const SyntaxClassBase::Impl::kClassInfo = { #NAME, &SyntaxClassBase::Impl::kClassInfo, nullptr }; - -#define SYNTAX_CLASS(NAME, BASE) \ - void NAME::accept(NAME::Visitor* visitor, void* extra) \ - { visitor->dispatch_##NAME(this, extra); } \ - template<> \ - void* SyntaxClassBase::Impl::createFunc() { return new NAME(); } \ - SyntaxClass NAME::getClass() { return Slang::getClass(); } \ - template<> \ - SyntaxClassBase::ClassInfo const SyntaxClassBase::Impl::kClassInfo = { #NAME, &SyntaxClassBase::Impl::kClassInfo, &SyntaxClassBase::Impl::createFunc }; - -template<> -SyntaxClassBase::ClassInfo const SyntaxClassBase::Impl::kClassInfo = { "RefObject", nullptr, nullptr }; - -ABSTRACT_SYNTAX_CLASS(NodeBase, RefObject); -ABSTRACT_SYNTAX_CLASS(SyntaxNodeBase, NodeBase); -ABSTRACT_SYNTAX_CLASS(SyntaxNode, SyntaxNodeBase); -ABSTRACT_SYNTAX_CLASS(ModifiableSyntaxNode, SyntaxNode); -ABSTRACT_SYNTAX_CLASS(DeclBase, ModifiableSyntaxNode); -ABSTRACT_SYNTAX_CLASS(Decl, DeclBase); -ABSTRACT_SYNTAX_CLASS(Stmt, ModifiableSyntaxNode); -ABSTRACT_SYNTAX_CLASS(Val, NodeBase); -ABSTRACT_SYNTAX_CLASS(Type, Val); -ABSTRACT_SYNTAX_CLASS(Modifier, SyntaxNodeBase); -ABSTRACT_SYNTAX_CLASS(Expr, SyntaxNode); - -ABSTRACT_SYNTAX_CLASS(Substitutions, SyntaxNode); -ABSTRACT_SYNTAX_CLASS(GenericSubstitution, Substitutions); -ABSTRACT_SYNTAX_CLASS(ThisTypeSubstitution, Substitutions); -ABSTRACT_SYNTAX_CLASS(GlobalGenericParamSubstitution, Substitutions); - -#include "expr-defs.h" -#include "decl-defs.h" -#include "modifier-defs.h" -#include "stmt-defs.h" -#include "type-defs.h" -#include "val-defs.h" -#include "object-meta-end.h" - -bool SyntaxClassBase::isSubClassOfImpl(SyntaxClassBase const& super) const -{ - SyntaxClassBase::ClassInfo const* info = classInfo; - while (info) - { - if (info == super.classInfo) - return true; - - info = info->baseClass; - } - - return false; -} - -void Type::accept(IValVisitor* visitor, void* extra) -{ - accept((ITypeVisitor*)visitor, extra); -} - - // TypeExp - - bool TypeExp::Equals(Type* other) - { - return type->Equals(other); - } - - bool TypeExp::Equals(RefPtr other) - { - return type->Equals(other.Ptr()); - } - - // BasicExpressionType - - BasicExpressionType* BasicExpressionType::GetScalarType() - { - return this; - } - - // - - Type::~Type() - { - // If the canonicalType !=nullptr AND it is not set to this (ie the canonicalType is another object) - // then it needs to be released because it's owned by this object. - if (canonicalType && canonicalType != this) - { - canonicalType->releaseReference(); - } - } - - bool Type::Equals(Type * type) - { - return GetCanonicalType()->EqualsImpl(type->GetCanonicalType()); - } - - bool Type::EqualsVal(Val* val) - { - if (auto type = dynamicCast(val)) - return const_cast(this)->Equals(type); - return false; - } - - RefPtr Type::SubstituteImpl(SubstitutionSet subst, int* ioDiff) - { - int diff = 0; - auto canSubst = GetCanonicalType()->SubstituteImpl(subst, &diff); - - // If nothing changed, then don't drop any sugar that is applied - if (!diff) - return this; - - // If the canonical type changed, then we return a canonical type, - // rather than try to re-construct any amount of sugar - (*ioDiff)++; - return canSubst; - } - - Type* Type::GetCanonicalType() - { - Type* et = const_cast(this); - if (!et->canonicalType) - { - // TODO(tfoley): worry about thread safety here? - auto canType = et->CreateCanonicalType(); - et->canonicalType = canType; - - // TODO(js): That this detachs when canType == this is a little surprising. It would seem - // as if this would create a circular reference on the object, but in practice there are - // no leaks so appears correct. - // That the dtor only releases if != this, also makes it surprising. - canType.detach(); - - SLANG_ASSERT(et->canonicalType); - } - return et->canonicalType; - } - - void Session::initializeTypes() - { - errorType = new ErrorType(); - errorType->setSession(this); - - initializerListType = new InitializerListType(); - initializerListType->setSession(this); - - overloadedType = new OverloadGroupType(); - overloadedType->setSession(this); - } - - Type* Session::getBoolType() - { - return getBuiltinType(BaseType::Bool); - } - - Type* Session::getHalfType() - { - return getBuiltinType(BaseType::Half); - } - - Type* Session::getFloatType() - { - return getBuiltinType(BaseType::Float); - } - - Type* Session::getDoubleType() - { - return getBuiltinType(BaseType::Double); - } - - Type* Session::getIntType() - { - return getBuiltinType(BaseType::Int); - } - - Type* Session::getInt64Type() - { - return getBuiltinType(BaseType::Int64); - } - - Type* Session::getUIntType() - { - return getBuiltinType(BaseType::UInt); - } - - Type* Session::getUInt64Type() - { - return getBuiltinType(BaseType::UInt64); - } - - Type* Session::getVoidType() - { - return getBuiltinType(BaseType::Void); - } - - Type* Session::getBuiltinType(BaseType flavor) - { - return RefPtr(builtinTypes[(int)flavor]); - } - - Type* Session::getInitializerListType() - { - return initializerListType; - } - - Type* Session::getOverloadedType() - { - return overloadedType; - } - - Type* Session::getErrorType() - { - return errorType; - } - - Type* Session::getStringType() - { - if (stringType == nullptr) - { - auto stringTypeDecl = findMagicDecl(this, "StringType"); - stringType = DeclRefType::Create(this, makeDeclRef(stringTypeDecl)); - } - return stringType; - } - - Type* Session::getEnumTypeType() - { - if (enumTypeType == nullptr) - { - auto enumTypeTypeDecl = findMagicDecl(this, "EnumTypeType"); - enumTypeType = DeclRefType::Create(this, makeDeclRef(enumTypeTypeDecl)); - } - return enumTypeType; - } - - RefPtr Session::getPtrType( - RefPtr valueType) - { - return getPtrType(valueType, "PtrType").dynamicCast(); - } - - // Construct the type `Out` - RefPtr Session::getOutType(RefPtr valueType) - { - return getPtrType(valueType, "OutType").dynamicCast(); - } - - RefPtr Session::getInOutType(RefPtr valueType) - { - return getPtrType(valueType, "InOutType").dynamicCast(); - } - - RefPtr Session::getRefType(RefPtr valueType) - { - return getPtrType(valueType, "RefType").dynamicCast(); - } - - RefPtr Session::getPtrType(RefPtr valueType, char const* ptrTypeName) - { - auto genericDecl = findMagicDecl(this, ptrTypeName).dynamicCast(); - return getPtrType(valueType, genericDecl); - } - - RefPtr Session::getPtrType(RefPtr valueType, GenericDecl* genericDecl) - { - auto typeDecl = genericDecl->inner; - - auto substitutions = new GenericSubstitution(); - substitutions->genericDecl = genericDecl; - substitutions->args.add(valueType); - - auto declRef = DeclRef(typeDecl.Ptr(), substitutions); - auto rsType = DeclRefType::Create( - this, - declRef); - return as( rsType); - } - - RefPtr Session::getArrayType( - Type* elementType, - IntVal* elementCount) - { - RefPtr arrayType = new ArrayExpressionType(); - arrayType->setSession(this); - arrayType->baseType = elementType; - arrayType->ArrayLength = elementCount; - return arrayType; - } - - SyntaxClass Session::findSyntaxClass(Name* name) - { - SyntaxClass syntaxClass; - if (mapNameToSyntaxClass.TryGetValue(name, syntaxClass)) - return syntaxClass; - - return SyntaxClass(); - } - - - - bool ArrayExpressionType::EqualsImpl(Type* type) - { - auto arrType = as(type); - if (!arrType) - return false; - return (areValsEqual(ArrayLength, arrType->ArrayLength) && baseType->Equals(arrType->baseType.Ptr())); - } - - RefPtr ArrayExpressionType::SubstituteImpl(SubstitutionSet subst, int* ioDiff) - { - int diff = 0; - auto elementType = baseType->SubstituteImpl(subst, &diff).as(); - auto arrlen = ArrayLength->SubstituteImpl(subst, &diff).as(); - SLANG_ASSERT(arrlen); - if (diff) - { - *ioDiff = 1; - auto rsType = getArrayType( - elementType, - arrlen); - return rsType; - } - return this; - } - - RefPtr ArrayExpressionType::CreateCanonicalType() - { - auto canonicalElementType = baseType->GetCanonicalType(); - auto canonicalArrayType = getArrayType( - canonicalElementType, - ArrayLength); - return canonicalArrayType; - } - int ArrayExpressionType::GetHashCode() - { - if (ArrayLength) - return (baseType->GetHashCode() * 16777619) ^ ArrayLength->GetHashCode(); - else - return baseType->GetHashCode(); - } - Slang::String ArrayExpressionType::ToString() - { - if (ArrayLength) - return baseType->ToString() + "[" + ArrayLength->ToString() + "]"; - else - return baseType->ToString() + "[]"; - } - - // DeclRefType - - String DeclRefType::ToString() - { - return declRef.toString(); - } - - int DeclRefType::GetHashCode() - { - return (declRef.GetHashCode() * 16777619) ^ (int)(typeid(this).hash_code()); - } - - bool DeclRefType::EqualsImpl(Type * type) - { - if (auto declRefType = as(type)) - { - return declRef.Equals(declRefType->declRef); - } - return false; - } - - RefPtr DeclRefType::CreateCanonicalType() - { - // A declaration reference is already canonical - return this; - } - - // - // RequirementWitness - // - - RequirementWitness::RequirementWitness(RefPtr val) - : m_flavor(Flavor::val) - , m_obj(val) - {} - - - RequirementWitness::RequirementWitness(RefPtr witnessTable) - : m_flavor(Flavor::witnessTable) - , m_obj(witnessTable) - {} - - RefPtr RequirementWitness::getWitnessTable() - { - SLANG_ASSERT(getFlavor() == Flavor::witnessTable); - return m_obj.as(); - } - - - RequirementWitness RequirementWitness::specialize(SubstitutionSet const& subst) - { - switch(getFlavor()) - { - default: - SLANG_UNEXPECTED("unknown requirement witness flavor"); - case RequirementWitness::Flavor::none: - return RequirementWitness(); - - case RequirementWitness::Flavor::declRef: - { - int diff = 0; - return RequirementWitness( - getDeclRef().SubstituteImpl(subst, &diff)); - } - - case RequirementWitness::Flavor::val: - { - auto val = getVal(); - SLANG_ASSERT(val); - - return RequirementWitness( - val->Substitute(subst)); - } - } - } - - RequirementWitness tryLookUpRequirementWitness( - SubtypeWitness* subtypeWitness, - Decl* requirementKey) - { - if(auto declaredSubtypeWitness = as(subtypeWitness)) - { - if(auto inheritanceDeclRef = declaredSubtypeWitness->declRef.as()) - { - // A conformance that was declared as part of an inheritance clause - // will have built up a dictionary of the satisfying declarations - // for each of its requirements. - RequirementWitness requirementWitness; - auto witnessTable = inheritanceDeclRef.getDecl()->witnessTable; - if(witnessTable && witnessTable->requirementDictionary.TryGetValue(requirementKey, requirementWitness)) - { - // The `inheritanceDeclRef` has substitutions applied to it that - // *aren't* present in the `requirementWitness`, because it was - // derived by the front-end when looking at the `InheritanceDecl` alone. - // - // We need to apply these substitutions here for the result to make sense. - // - // E.g., if we have a case like: - // - // interface ISidekick { associatedtype Hero; void follow(Hero hero); } - // struct Sidekick : ISidekick { typedef H Hero; void follow(H hero) {} }; - // - // void followHero(S s, S.Hero h) - // { - // s.follow(h); - // } - // - // Batman batman; - // Sidekick robin; - // followHero>(robin, batman); - // - // The second argument to `followHero` is `batman`, which has type `Batman`. - // The parameter declaration lists the type `S.Hero`, which is a reference - // to an associated type. The front end will expand this into something - // like `S.{S:ISidekick}.Hero` - that is, we'll end up with a declaration - // reference to `ISidekick.Hero` with a this-type substitution that references - // the `{S:ISidekick}` declaration as a witness. - // - // The front-end will expand the generic application `followHero>` - // to `followHero, {Sidekick:ISidekick}[H->Batman]>` - // (that is, the hidden second parameter will reference the inheritance - // clause on `Sidekick`, with a substitution to map `H` to `Batman`. - // - // This step should map the `{S:ISidekick}` declaration over to the - // concrete `{Sidekick:ISidekick}[H->Batman]` inheritance declaration. - // At that point `tryLookupRequirementWitness` might be called, because - // we want to look up the witness for the key `ISidekick.Hero` in the - // inheritance decl-ref that is `{Sidekick:ISidekick}[H->Batman]`. - // - // That lookup will yield us a reference to the typedef `Sidekick.Hero`, - // *without* any substitution for `H` (or rather, with a default one that - // maps `H` to `H`. - // - // So, in order to get the *right* end result, we need to apply - // the substitutions from the inheritance decl-ref to the witness. - // - requirementWitness = requirementWitness.specialize(inheritanceDeclRef.substitutions); - - return requirementWitness; - } - } - } - - // TODO: should handle the transitive case here too - - return RequirementWitness(); - } - - RefPtr DeclRefType::SubstituteImpl(SubstitutionSet subst, int* ioDiff) - { - if (!subst) return this; - - // the case we especially care about is when this type references a declaration - // of a generic parameter, since that is what we might be substituting... - if (auto genericTypeParamDecl = as(declRef.getDecl())) - { - // search for a substitution that might apply to us - for(auto s = subst.substitutions; s; s = s->outer) - { - auto genericSubst = s.as(); - if(!genericSubst) - continue; - - // the generic decl associated with the substitution list must be - // the generic decl that declared this parameter - auto genericDecl = genericSubst->genericDecl; - if (genericDecl != genericTypeParamDecl->ParentDecl) - continue; - - int index = 0; - for (auto m : genericDecl->Members) - { - if (m.Ptr() == genericTypeParamDecl) - { - // We've found it, so return the corresponding specialization argument - (*ioDiff)++; - return genericSubst->args[index]; - } - else if (auto typeParam = as(m)) - { - index++; - } - else if (auto valParam = as(m)) - { - index++; - } - else - { - } - } - } - } - else if (auto globalGenParam = as(declRef.getDecl())) - { - // search for a substitution that might apply to us - for(auto s = subst.substitutions; s; s = s->outer) - { - auto genericSubst = as(s); - if(!genericSubst) - continue; - - if (genericSubst->paramDecl == globalGenParam) - { - (*ioDiff)++; - return genericSubst->actualType; - } - } - } - int diff = 0; - DeclRef substDeclRef = declRef.SubstituteImpl(subst, &diff); - - if (!diff) - return this; - - // Make sure to record the difference! - *ioDiff += diff; - - // If this type is a reference to an associated type declaration, - // and the substitutions provide a "this type" substitution for - // the outer interface, then try to replace the type with the - // actual value of the associated type for the given implementation. - // - if(auto substAssocTypeDecl = as(substDeclRef.decl)) - { - for(auto s = substDeclRef.substitutions.substitutions; s; s = s->outer) - { - auto thisSubst = s.as(); - if(!thisSubst) - continue; - - if(auto interfaceDecl = as(substAssocTypeDecl->ParentDecl)) - { - if(thisSubst->interfaceDecl == interfaceDecl) - { - // We need to look up the declaration that satisfies - // the requirement named by the associated type. - Decl* requirementKey = substAssocTypeDecl; - RequirementWitness requirementWitness = tryLookUpRequirementWitness(thisSubst->witness, requirementKey); - switch(requirementWitness.getFlavor()) - { - default: - // No usable value was found, so there is nothing we can do. - break; - - case RequirementWitness::Flavor::val: - { - auto satisfyingVal = requirementWitness.getVal(); - return satisfyingVal; - } - break; - } - } - } - } - } - - // Re-construct the type in case we are using a specialized sub-class - return DeclRefType::Create(getSession(), substDeclRef); - } - - static RefPtr ExtractGenericArgType(RefPtr val) - { - auto type = val.as(); - SLANG_RELEASE_ASSERT(type.Ptr()); - return type; - } - - static RefPtr ExtractGenericArgInteger(RefPtr val) - { - auto intVal = val.as(); - SLANG_RELEASE_ASSERT(intVal.Ptr()); - return intVal; - } - - DeclRef createDefaultSubstitutionsIfNeeded( - Session* session, - DeclRef declRef) - { - // It is possible that `declRef` refers to a generic type, - // but does not specify arguments for its generic parameters. - // (E.g., this happens when referring to a generic type from - // within its own member functions). To handle this case, - // we will construct a default specialization at the use - // site if needed. - // - // This same logic should also apply to declarations nested - // more than one level inside of a generic (e.g., a `typdef` - // inside of a generic `struct`). - // - // Similarly, it needs to work for multiple levels of - // nested generics. - // - - // We are going to build up a list of substitutions that need - // to be applied to the decl-ref to make it specialized. - RefPtr substsToApply; - RefPtr* link = &substsToApply; - - RefPtr dd = declRef.getDecl(); - for(;;) - { - RefPtr childDecl = dd; - RefPtr parentDecl = dd->ParentDecl; - if(!parentDecl) - break; - - dd = parentDecl; - - if(auto genericParentDecl = parentDecl.as()) - { - // Don't specialize any parameters of a generic. - if(childDecl != genericParentDecl->inner) - break; - - // We have a generic ancestor, but do we have an substitutions for it? - RefPtr foundSubst; - for(auto s = declRef.substitutions.substitutions; s; s = s->outer) - { - auto genSubst = s.as(); - if(!genSubst) - continue; - - if(genSubst->genericDecl != genericParentDecl) - continue; - - // Okay, we found a matching substitution, - // so there is nothing to be done. - foundSubst = genSubst; - break; - } - - if(!foundSubst) - { - RefPtr newSubst = createDefaultSubsitutionsForGeneric( - session, - genericParentDecl, - nullptr); - - *link = newSubst; - link = &newSubst->outer; - } - } - } - - if(!substsToApply) - return declRef; - - int diff = 0; - return declRef.SubstituteImpl(substsToApply, &diff); - } - - // TODO: need to figure out how to unify this with the logic - // in the generic case... - RefPtr DeclRefType::Create( - Session* session, - DeclRef declRef) - { - declRef = createDefaultSubstitutionsIfNeeded(session, declRef); - - if (auto builtinMod = declRef.getDecl()->FindModifier()) - { - auto type = new BasicExpressionType(builtinMod->tag); - type->setSession(session); - type->declRef = declRef; - return type; - } - else if (auto magicMod = declRef.getDecl()->FindModifier()) - { - GenericSubstitution* subst = nullptr; - for(auto s = declRef.substitutions.substitutions; s; s = s->outer) - { - if(auto genericSubst = s.as()) - { - subst = genericSubst; - break; - } - } - - if (magicMod->name == "SamplerState") - { - auto type = new SamplerStateType(); - type->setSession(session); - type->declRef = declRef; - type->flavor = SamplerStateFlavor(magicMod->tag); - return type; - } - else if (magicMod->name == "Vector") - { - SLANG_ASSERT(subst && subst->args.getCount() == 2); - auto vecType = new VectorExpressionType(); - vecType->setSession(session); - vecType->declRef = declRef; - vecType->elementType = ExtractGenericArgType(subst->args[0]); - vecType->elementCount = ExtractGenericArgInteger(subst->args[1]); - return vecType; - } - else if (magicMod->name == "Matrix") - { - SLANG_ASSERT(subst && subst->args.getCount() == 3); - auto matType = new MatrixExpressionType(); - matType->setSession(session); - matType->declRef = declRef; - return matType; - } - else if (magicMod->name == "Texture") - { - SLANG_ASSERT(subst && subst->args.getCount() >= 1); - auto textureType = new TextureType( - TextureFlavor(magicMod->tag), - ExtractGenericArgType(subst->args[0])); - textureType->setSession(session); - textureType->declRef = declRef; - return textureType; - } - else if (magicMod->name == "TextureSampler") - { - SLANG_ASSERT(subst && subst->args.getCount() >= 1); - auto textureType = new TextureSamplerType( - TextureFlavor(magicMod->tag), - ExtractGenericArgType(subst->args[0])); - textureType->setSession(session); - textureType->declRef = declRef; - return textureType; - } - else if (magicMod->name == "GLSLImageType") - { - SLANG_ASSERT(subst && subst->args.getCount() >= 1); - auto textureType = new GLSLImageType( - TextureFlavor(magicMod->tag), - ExtractGenericArgType(subst->args[0])); - textureType->setSession(session); - textureType->declRef = declRef; - return textureType; - } - - // TODO: eventually everything should follow this pattern, - // and we can drive the dispatch with a table instead - // of this ridiculously slow `if` cascade. - - #define CASE(n,T) \ - else if(magicMod->name == #n) { \ - auto type = new T(); \ - type->setSession(session); \ - type->declRef = declRef; \ - return type; \ - } - - CASE(HLSLInputPatchType, HLSLInputPatchType) - CASE(HLSLOutputPatchType, HLSLOutputPatchType) - - #undef CASE - - #define CASE(n,T) \ - else if(magicMod->name == #n) { \ - SLANG_ASSERT(subst && subst->args.getCount() == 1); \ - auto type = new T(); \ - type->setSession(session); \ - type->elementType = ExtractGenericArgType(subst->args[0]); \ - type->declRef = declRef; \ - return type; \ - } - - CASE(ConstantBuffer, ConstantBufferType) - CASE(TextureBuffer, TextureBufferType) - CASE(ParameterBlockType, ParameterBlockType) - CASE(GLSLInputParameterGroupType, GLSLInputParameterGroupType) - CASE(GLSLOutputParameterGroupType, GLSLOutputParameterGroupType) - CASE(GLSLShaderStorageBufferType, GLSLShaderStorageBufferType) - - CASE(HLSLStructuredBufferType, HLSLStructuredBufferType) - CASE(HLSLRWStructuredBufferType, HLSLRWStructuredBufferType) - CASE(HLSLRasterizerOrderedStructuredBufferType, HLSLRasterizerOrderedStructuredBufferType) - CASE(HLSLAppendStructuredBufferType, HLSLAppendStructuredBufferType) - CASE(HLSLConsumeStructuredBufferType, HLSLConsumeStructuredBufferType) - - CASE(HLSLPointStreamType, HLSLPointStreamType) - CASE(HLSLLineStreamType, HLSLLineStreamType) - CASE(HLSLTriangleStreamType, HLSLTriangleStreamType) - - #undef CASE - - // "magic" builtin types which have no generic parameters - #define CASE(n,T) \ - else if(magicMod->name == #n) { \ - auto type = new T(); \ - type->setSession(session); \ - type->declRef = declRef; \ - return type; \ - } - - CASE(HLSLByteAddressBufferType, HLSLByteAddressBufferType) - CASE(HLSLRWByteAddressBufferType, HLSLRWByteAddressBufferType) - CASE(HLSLRasterizerOrderedByteAddressBufferType, HLSLRasterizerOrderedByteAddressBufferType) - CASE(UntypedBufferResourceType, UntypedBufferResourceType) - - CASE(GLSLInputAttachmentType, GLSLInputAttachmentType) - - #undef CASE - - else - { - auto classInfo = session->findSyntaxClass( - session->getNamePool()->getName(magicMod->name)); - if (!classInfo.classInfo) - { - SLANG_UNEXPECTED("unhandled type"); - } - - RefPtr type = classInfo.createInstance(); - if (!type) - { - SLANG_UNEXPECTED("constructor failure"); - } - - auto declRefType = dynamicCast(type); - if (!declRefType) - { - SLANG_UNEXPECTED("expected a declaration reference type"); - } - declRefType->session = session; - declRefType->declRef = declRef; - return declRefType; - } - } - else - { - auto type = new DeclRefType(declRef); - type->setSession(session); - return type; - } - } - - // OverloadGroupType - - String OverloadGroupType::ToString() - { - return "overload group"; - } - - bool OverloadGroupType::EqualsImpl(Type * /*type*/) - { - return false; - } - - RefPtr OverloadGroupType::CreateCanonicalType() - { - return this; - } - - int OverloadGroupType::GetHashCode() - { - return (int)(int64_t)(void*)this; - } - - // InitializerListType - - String InitializerListType::ToString() - { - return "initializer list"; - } - - bool InitializerListType::EqualsImpl(Type * /*type*/) - { - return false; - } - - RefPtr InitializerListType::CreateCanonicalType() - { - return this; - } - - int InitializerListType::GetHashCode() - { - return (int)(int64_t)(void*)this; - } - - // ErrorType - - String ErrorType::ToString() - { - return "error"; - } - - bool ErrorType::EqualsImpl(Type* type) - { - if (auto errorType = as(type)) - return true; - return false; - } - - RefPtr ErrorType::CreateCanonicalType() - { - return this; - } - - RefPtr ErrorType::SubstituteImpl(SubstitutionSet /*subst*/, int* /*ioDiff*/) - { - return this; - } - - int ErrorType::GetHashCode() - { - return (int)(int64_t)(void*)this; - } - - - // NamedExpressionType - - String NamedExpressionType::ToString() - { - return getText(declRef.GetName()); - } - - bool NamedExpressionType::EqualsImpl(Type * /*type*/) - { - SLANG_UNEXPECTED("unreachable"); - UNREACHABLE_RETURN(false); - } - - RefPtr NamedExpressionType::CreateCanonicalType() - { - if (!innerType) - innerType = GetType(declRef); - return innerType->GetCanonicalType(); - } - - int NamedExpressionType::GetHashCode() - { - // Type equality is based on comparing canonical types, - // so the hash code for a type needs to come from the - // canonical version of the type. This really means - // that `Type::GetHashCode()` should dispatch out to - // something like `Type::GetHashCodeImpl()` on the - // canonical version of a type, but it is less invasive - // for now (and hopefully equivalent) to just have any - // named types automaticlaly route hash-code requests - // to their canonical type. - return GetCanonicalType()->GetHashCode(); - } - - // FuncType - - String FuncType::ToString() - { - StringBuilder sb; - sb << "("; - UInt paramCount = getParamCount(); - for (UInt pp = 0; pp < paramCount; ++pp) - { - if (pp != 0) sb << ", "; - sb << getParamType(pp)->ToString(); - } - sb << ") -> "; - sb << getResultType()->ToString(); - return sb.ProduceString(); - } - - bool FuncType::EqualsImpl(Type * type) - { - if (auto funcType = as(type)) - { - auto paramCount = getParamCount(); - auto otherParamCount = funcType->getParamCount(); - if (paramCount != otherParamCount) - return false; - - for (UInt pp = 0; pp < paramCount; ++pp) - { - auto paramType = getParamType(pp); - auto otherParamType = funcType->getParamType(pp); - if (!paramType->Equals(otherParamType)) - return false; - } - - if(!resultType->Equals(funcType->resultType)) - return false; - - // TODO: if we ever introduce other kinds - // of qualification on function types, we'd - // want to consider it here. - return true; - } - return false; - } - - RefPtr FuncType::SubstituteImpl(SubstitutionSet subst, int* ioDiff) - { - int diff = 0; - - // result type - RefPtr substResultType = resultType->SubstituteImpl(subst, &diff).as(); - - // parameter types - List> substParamTypes; - for( auto pp : paramTypes ) - { - substParamTypes.add(pp->SubstituteImpl(subst, &diff).as()); - } - - // early exit for no change... - if(!diff) - return this; - - (*ioDiff)++; - RefPtr substType = new FuncType(); - substType->session = session; - substType->resultType = substResultType; - substType->paramTypes = substParamTypes; - return substType; - } - - RefPtr FuncType::CreateCanonicalType() - { - // result type - RefPtr canResultType = resultType->GetCanonicalType(); - - // parameter types - List> canParamTypes; - for( auto pp : paramTypes ) - { - canParamTypes.add(pp->GetCanonicalType()); - } - - RefPtr canType = new FuncType(); - canType->session = session; - canType->resultType = resultType; - canType->paramTypes = canParamTypes; - - return canType; - } - - int FuncType::GetHashCode() - { - int hashCode = getResultType()->GetHashCode(); - UInt paramCount = getParamCount(); - hashCode = combineHash(hashCode, Slang::GetHashCode(paramCount)); - for (UInt pp = 0; pp < paramCount; ++pp) - { - hashCode = combineHash( - hashCode, - getParamType(pp)->GetHashCode()); - } - return hashCode; - } - - // TypeType - - String TypeType::ToString() - { - StringBuilder sb; - sb << "typeof(" << type->ToString() << ")"; - return sb.ProduceString(); - } - - bool TypeType::EqualsImpl(Type * t) - { - if (auto typeType = as(t)) - { - return t->Equals(typeType->type); - } - return false; - } - - RefPtr TypeType::CreateCanonicalType() - { - auto canType = getTypeType(type->GetCanonicalType()); - return canType; - } - - int TypeType::GetHashCode() - { - SLANG_UNEXPECTED("unreachable"); - UNREACHABLE_RETURN(0); - } - - // GenericDeclRefType - - String GenericDeclRefType::ToString() - { - // TODO: what is appropriate here? - return ">"; - } - - bool GenericDeclRefType::EqualsImpl(Type * type) - { - if (auto genericDeclRefType = as(type)) - { - return declRef.Equals(genericDeclRefType->declRef); - } - return false; - } - - int GenericDeclRefType::GetHashCode() - { - return declRef.GetHashCode(); - } - - RefPtr GenericDeclRefType::CreateCanonicalType() - { - return this; - } - - // ArithmeticExpressionType - - // VectorExpressionType - - String VectorExpressionType::ToString() - { - StringBuilder sb; - sb << "vector<" << elementType->ToString() << "," << elementCount->ToString() << ">"; - return sb.ProduceString(); - } - - BasicExpressionType* VectorExpressionType::GetScalarType() - { - return as(elementType); - } - - // - - RefPtr findInnerMostGenericSubstitution(Substitutions* subst) - { - for(RefPtr s = subst; s; s = s->outer) - { - if(auto genericSubst = as(s)) - return genericSubst; - } - return nullptr; - } - - // MatrixExpressionType - - String MatrixExpressionType::ToString() - { - StringBuilder sb; - sb << "matrix<" << getElementType()->ToString() << "," << getRowCount()->ToString() << "," << getColumnCount()->ToString() << ">"; - return sb.ProduceString(); - } - - BasicExpressionType* MatrixExpressionType::GetScalarType() - { - return as(getElementType()); - } - - Type* MatrixExpressionType::getElementType() - { - return as(findInnerMostGenericSubstitution(declRef.substitutions)->args[0]); - } - - IntVal* MatrixExpressionType::getRowCount() - { - return as(findInnerMostGenericSubstitution(declRef.substitutions)->args[1]); - } - - IntVal* MatrixExpressionType::getColumnCount() - { - return as(findInnerMostGenericSubstitution(declRef.substitutions)->args[2]); - } - - RefPtr MatrixExpressionType::getRowType() - { - if( !mRowType ) - { - mRowType = getSession()->getVectorType(getElementType(), getColumnCount()); - } - return mRowType; - } - - RefPtr Session::getVectorType( - RefPtr elementType, - RefPtr elementCount) - { - auto vectorGenericDecl = findMagicDecl( - this, "Vector").as(); - auto vectorTypeDecl = vectorGenericDecl->inner; - - auto substitutions = new GenericSubstitution(); - substitutions->genericDecl = vectorGenericDecl.Ptr(); - substitutions->args.add(elementType); - substitutions->args.add(elementCount); - - auto declRef = DeclRef(vectorTypeDecl.Ptr(), substitutions); - - return DeclRefType::Create( - this, - declRef).as(); - } - - - // PtrTypeBase - - Type* PtrTypeBase::getValueType() - { - return as(findInnerMostGenericSubstitution(declRef.substitutions)->args[0]); - } - - // GenericParamIntVal - - bool GenericParamIntVal::EqualsVal(Val* val) - { - if (auto genericParamVal = as(val)) - { - return declRef.Equals(genericParamVal->declRef); - } - return false; - } - - String GenericParamIntVal::ToString() - { - return getText(declRef.GetName()); - } - - int GenericParamIntVal::GetHashCode() - { - return declRef.GetHashCode() ^ 0xFFFF; - } - - RefPtr GenericParamIntVal::SubstituteImpl(SubstitutionSet subst, int* ioDiff) - { - // search for a substitution that might apply to us - for(auto s = subst.substitutions; s; s = s->outer) - { - auto genSubst = s.as(); - if(!genSubst) - continue; - - // the generic decl associated with the substitution list must be - // the generic decl that declared this parameter - auto genericDecl = genSubst->genericDecl; - if (genericDecl != declRef.getDecl()->ParentDecl) - continue; - - int index = 0; - for (auto m : genericDecl->Members) - { - if (m.Ptr() == declRef.getDecl()) - { - // We've found it, so return the corresponding specialization argument - (*ioDiff)++; - return genSubst->args[index]; - } - else if (auto typeParam = as(m)) - { - index++; - } - else if (auto valParam = as(m)) - { - index++; - } - else - { - } - } - } - - // Nothing found: don't substitute. - return this; - } - - // Substitutions - - RefPtr GenericSubstitution::applySubstitutionsShallow(SubstitutionSet substSet, RefPtr substOuter, int* ioDiff) - { - int diff = 0; - - if(substOuter != outer) diff++; - - List> substArgs; - for (auto a : args) - { - substArgs.add(a->SubstituteImpl(substSet, &diff)); - } - - if (!diff) return this; - - (*ioDiff)++; - auto substSubst = new GenericSubstitution(); - substSubst->genericDecl = genericDecl; - substSubst->args = substArgs; - substSubst->outer = substOuter; - return substSubst; - } - - bool GenericSubstitution::Equals(Substitutions* subst) - { - // both must be NULL, or non-NULL - if (subst == nullptr) - return false; - if (this == subst) - return true; - - auto genericSubst = as(subst); - if (!genericSubst) - return false; - if (genericDecl != genericSubst->genericDecl) - return false; - - Index argCount = args.getCount(); - SLANG_RELEASE_ASSERT(args.getCount() == genericSubst->args.getCount()); - for (Index aa = 0; aa < argCount; ++aa) - { - if (!args[aa]->EqualsVal(genericSubst->args[aa].Ptr())) - return false; - } - - if (!outer) - return !genericSubst->outer; - - if (!outer->Equals(genericSubst->outer.Ptr())) - return false; - - return true; - } - - RefPtr ThisTypeSubstitution::applySubstitutionsShallow(SubstitutionSet substSet, RefPtr substOuter, int* ioDiff) - { - int diff = 0; - - if(substOuter != outer) diff++; - - // NOTE: Must use .as because we must have a smart pointer here to keep in scope. - auto substWitness = witness->SubstituteImpl(substSet, &diff).as(); - - if (!diff) return this; - - (*ioDiff)++; - auto substSubst = new ThisTypeSubstitution(); - substSubst->interfaceDecl = interfaceDecl; - substSubst->witness = substWitness; - substSubst->outer = substOuter; - return substSubst; - } - - bool ThisTypeSubstitution::Equals(Substitutions* subst) - { - if (!subst) - return false; - if (subst == this) - return true; - - if (auto thisTypeSubst = as(subst)) - { - return witness->EqualsVal(thisTypeSubst->witness); - } - return false; - } - - int ThisTypeSubstitution::GetHashCode() const - { - return witness->GetHashCode(); - } - - RefPtr GlobalGenericParamSubstitution::applySubstitutionsShallow(SubstitutionSet substSet, RefPtr substOuter, int* ioDiff) - { - // if we find a GlobalGenericParamSubstitution in subst that references the same type_param decl - // return a copy of that GlobalGenericParamSubstitution - int diff = 0; - - if(substOuter != outer) diff++; - - auto substActualType = actualType->SubstituteImpl(substSet, &diff).as(); - - List substConstraintArgs; - for(auto constraintArg : constraintArgs) - { - ConstraintArg substConstraintArg; - substConstraintArg.decl = constraintArg.decl; - substConstraintArg.val = constraintArg.val->SubstituteImpl(substSet, &diff); - - substConstraintArgs.add(substConstraintArg); - } - - if(!diff) - return this; - - (*ioDiff)++; - - RefPtr substSubst = new GlobalGenericParamSubstitution(); - substSubst->paramDecl = paramDecl; - substSubst->actualType = substActualType; - substSubst->constraintArgs = substConstraintArgs; - substSubst->outer = substOuter; - return substSubst; - } - - bool GlobalGenericParamSubstitution::Equals(Substitutions* subst) - { - if (!subst) - return false; - if (subst == this) - return true; - - if (auto genSubst = as(subst)) - { - if (paramDecl != genSubst->paramDecl) - return false; - if (!actualType->EqualsVal(genSubst->actualType)) - return false; - if (constraintArgs.getCount() != genSubst->constraintArgs.getCount()) - return false; - for (Index i = 0; i < constraintArgs.getCount(); i++) - { - if (!constraintArgs[i].val->EqualsVal(genSubst->constraintArgs[i].val)) - return false; - } - return true; - } - return false; - } - - - // DeclRefBase - - RefPtr DeclRefBase::Substitute(RefPtr type) const - { - // Note that type can be nullptr, and so this function can return nullptr (although only correctly when no substitutions) - - // No substitutions? Easy. - if (!substitutions) - return type; - - SLANG_ASSERT(type); - - // Otherwise we need to recurse on the type structure - // and apply substitutions where it makes sense - return type->Substitute(substitutions).as(); - } - - DeclRefBase DeclRefBase::Substitute(DeclRefBase declRef) const - { - if(!substitutions) - return declRef; - - int diff = 0; - return declRef.SubstituteImpl(substitutions, &diff); - } - - RefPtr DeclRefBase::Substitute(RefPtr expr) const - { - // No substitutions? Easy. - if (!substitutions) - return expr; - - SLANG_UNIMPLEMENTED_X("generic substitution into expressions"); - - UNREACHABLE_RETURN(expr); - } - - void buildMemberDictionary(ContainerDecl* decl); - - InterfaceDecl* findOuterInterfaceDecl(Decl* decl) - { - Decl* dd = decl; - while(dd) - { - if(auto interfaceDecl = as(dd)) - return interfaceDecl; - - dd = dd->ParentDecl; - } - return nullptr; - } - - RefPtr findGlobalGenericSubst( - RefPtr substs, - GlobalGenericParamDecl* paramDecl) - { - for(auto s = substs; s; s = s->outer) - { - auto gSubst = s.as(); - if(!gSubst) - continue; - - if(gSubst->paramDecl != paramDecl) - continue; - - return gSubst; - } - - return nullptr; - } - - RefPtr specializeSubstitutionsShallow( - RefPtr substToSpecialize, - RefPtr substsToApply, - RefPtr restSubst, - int* ioDiff) - { - SLANG_ASSERT(substToSpecialize); - return substToSpecialize->applySubstitutionsShallow(substsToApply, restSubst, ioDiff); - } - - RefPtr specializeGlobalGenericSubstitutions( - Decl* declToSpecialize, - RefPtr substsToSpecialize, - RefPtr substsToApply, - int* ioDiff, - HashSet& ioParametersFound) - { - // Any existing global-generic substitutions will trigger - // a recursive case that skips the rest of the function. - for(auto specSubst = substsToSpecialize; specSubst; specSubst = specSubst->outer) - { - auto specGlobalGenericSubst = specSubst.as(); - if(!specGlobalGenericSubst) - continue; - - ioParametersFound.Add(specGlobalGenericSubst->paramDecl); - - int diff = 0; - auto restSubst = specializeGlobalGenericSubstitutions( - declToSpecialize, - specSubst->outer, - substsToApply, - &diff, - ioParametersFound); - - auto firstSubst = specializeSubstitutionsShallow( - specGlobalGenericSubst, - substsToApply, - restSubst, - &diff); - - *ioDiff += diff; - return firstSubst; - } - - // No more existing substitutions, so we know we can apply - // our global generic substitutions without any special work. - - // We expect global generic substitutions to come at - // the end of the list in all cases, so lets advance - // until we see them. - RefPtr appGlobalGenericSubsts = substsToApply; - while(appGlobalGenericSubsts && !appGlobalGenericSubsts.as()) - appGlobalGenericSubsts = appGlobalGenericSubsts->outer; - - - // If there is nothing to apply, then we are done - if(!appGlobalGenericSubsts) - return nullptr; - - // Otherwise, it seems like something has to change. - (*ioDiff)++; - - // If there were no parameters bound by the existing substitution, - // then we can safely use the global generics from the to-apply set. - if(ioParametersFound.Count() == 0) - return appGlobalGenericSubsts; - - RefPtr resultSubst; - RefPtr* link = &resultSubst; - for(auto appSubst = appGlobalGenericSubsts; appSubst; appSubst = appSubst->outer) - { - auto appGlobalGenericSubst = appSubst.as(); - if(!appSubst) - continue; - - // Don't include substitutions for parameters already handled. - if(ioParametersFound.Contains(appGlobalGenericSubst->paramDecl)) - continue; - - RefPtr newSubst = new GlobalGenericParamSubstitution(); - newSubst->paramDecl = appGlobalGenericSubst->paramDecl; - newSubst->actualType = appGlobalGenericSubst->actualType; - newSubst->constraintArgs = appGlobalGenericSubst->constraintArgs; - - *link = newSubst; - link = &newSubst->outer; - } - - return resultSubst; - } - - RefPtr specializeGlobalGenericSubstitutions( - Decl* declToSpecialize, - RefPtr substsToSpecialize, - RefPtr substsToApply, - int* ioDiff) - { - // Keep track of any parameters already present in the - // existing substitution. - HashSet parametersFound; - return specializeGlobalGenericSubstitutions(declToSpecialize, substsToSpecialize, substsToApply, ioDiff, parametersFound); - } - - - // Construct new substitutions to apply to a declaration, - // based on a provided substitution set to be applied - RefPtr specializeSubstitutions( - Decl* declToSpecialize, - RefPtr substsToSpecialize, - RefPtr substsToApply, - int* ioDiff) - { - // No declaration? Then nothing to specialize. - if(!declToSpecialize) - return nullptr; - - // No (remaining) substitutions to apply? Then we are done. - if(!substsToApply) - return substsToSpecialize; - - // Walk the hierarchy of the declaration to determine what specializations might apply. - // We assume that the `substsToSpecialize` must be aligned with the ancestor - // hierarchy of `declToSpecialize` such that if, e.g., the `declToSpecialize` is - // nested directly in a generic, then `substToSpecialize` will either start with - // the corresponding `GenericSubstitution` or there will be *no* generic substitutions - // corresponding to that decl. - for(Decl* ancestorDecl = declToSpecialize; ancestorDecl; ancestorDecl = ancestorDecl->ParentDecl) - { - if(auto ancestorGenericDecl = as(ancestorDecl)) - { - // The declaration is nested inside a generic. - // Does it already have a specialization for that generic? - if(auto specGenericSubst = as(substsToSpecialize)) - { - if(specGenericSubst->genericDecl == ancestorGenericDecl) - { - // Yes. We have an existing specialization, so we will - // keep one matching it in place. - int diff = 0; - auto restSubst = specializeSubstitutions( - ancestorGenericDecl->ParentDecl, - specGenericSubst->outer, - substsToApply, - &diff); - - auto firstSubst = specializeSubstitutionsShallow( - specGenericSubst, - substsToApply, - restSubst, - &diff); - - *ioDiff += diff; - return firstSubst; - } - } - - // If the declaration is not already specialized - // for the given generic, then see if we are trying - // to *apply* such specializations to it. - // - // TODO: The way we handle things right now with - // "default" specializations, this case shouldn't - // actually come up. - // - for(auto s = substsToApply; s; s = s->outer) - { - auto appGenericSubst = as(s); - if(!appGenericSubst) - continue; - - if(appGenericSubst->genericDecl != ancestorGenericDecl) - continue; - - // The substitutions we are applying are trying - // to specialize this generic, but we don't already - // have a generic substitution in place. - // We will need to create one. - - int diff = 0; - auto restSubst = specializeSubstitutions( - ancestorGenericDecl->ParentDecl, - substsToSpecialize, - substsToApply, - &diff); - - RefPtr firstSubst = new GenericSubstitution(); - firstSubst->genericDecl = ancestorGenericDecl; - firstSubst->args = appGenericSubst->args; - firstSubst->outer = restSubst; - - (*ioDiff)++; - return firstSubst; - } - } - else if(auto ancestorInterfaceDecl = as(ancestorDecl)) - { - // The task is basically the same as for the generic case: - // We want to see if there is any existing substitution that - // applies to this declaration, and use that if possible. - - // The declaration is nested inside a generic. - // Does it already have a specialization for that generic? - if(auto specThisTypeSubst = as(substsToSpecialize)) - { - if(specThisTypeSubst->interfaceDecl == ancestorInterfaceDecl) - { - // Yes. We have an existing specialization, so we will - // keep one matching it in place. - int diff = 0; - auto restSubst = specializeSubstitutions( - ancestorInterfaceDecl->ParentDecl, - specThisTypeSubst->outer, - substsToApply, - &diff); - - auto firstSubst = specializeSubstitutionsShallow( - specThisTypeSubst, - substsToApply, - restSubst, - &diff); - - *ioDiff += diff; - return firstSubst; - } - } - - // Otherwise, check if we are trying to apply - // a this-type substitution to the given interface - // - for(auto s = substsToApply; s; s = s->outer) - { - auto appThisTypeSubst = s.as(); - if(!appThisTypeSubst) - continue; - - if(appThisTypeSubst->interfaceDecl != ancestorInterfaceDecl) - continue; - - int diff = 0; - auto restSubst = specializeSubstitutions( - ancestorInterfaceDecl->ParentDecl, - substsToSpecialize, - substsToApply, - &diff); - - RefPtr firstSubst = new ThisTypeSubstitution(); - firstSubst->interfaceDecl = ancestorInterfaceDecl; - firstSubst->witness = appThisTypeSubst->witness; - firstSubst->outer = restSubst; - - (*ioDiff)++; - return firstSubst; - } - } - } - - // If we reach here then we've walked the full hierarchy up from - // `declToSpecialize` and either didn't run into an generic/interface - // declarations, or we didn't find any attempt to specialize them - // in either substitution. - // - // As an invariant, there should *not* be any generic or this-type - // substitutions in `substToSpecialize`, because otherwise they - // would be specializations that don't actually apply to the given - // declaration. - // - // The remaining substitutions to apply, if any, should thus be - // global-generic substitutions. And similarly, those are the - // only remaining substitutions we really care about in - // `substsToApply`. - // - // Note: this does *not* mean that `substsToApply` doesn't have - // any generic or this-type substitutions; it just means that none - // of them were applicable. - // - return specializeGlobalGenericSubstitutions( - declToSpecialize, - substsToSpecialize, - substsToApply, - ioDiff); - } - - DeclRefBase DeclRefBase::SubstituteImpl(SubstitutionSet substSet, int* ioDiff) - { - // Nothing to do when we have no declaration. - if(!decl) - return *this; - - // Apply the given substitutions to any specializations - // that have already been applied to this declaration. - int diff = 0; - - auto substSubst = specializeSubstitutions( - decl, - substitutions.substitutions, - substSet.substitutions, - &diff); - - if (!diff) - return *this; - - *ioDiff += diff; - - DeclRefBase substDeclRef; - substDeclRef.decl = decl; - substDeclRef.substitutions = substSubst; - - // TODO: The old code here used to try to translate a decl-ref - // to an associated type in a decl-ref for the concrete type - // in a particular implementation. - // - // I have only kept that logic in `DeclRefType::SubstituteImpl`, - // but it may turn out it is needed here too. - - return substDeclRef; - } - - - // Check if this is an equivalent declaration reference to another - bool DeclRefBase::Equals(DeclRefBase const& declRef) const - { - if (decl != declRef.decl) - return false; - if (!substitutions.Equals(declRef.substitutions)) - return false; - - return true; - } - - // Convenience accessors for common properties of declarations - Name* DeclRefBase::GetName() const - { - return decl->nameAndLoc.name; - } - - SourceLoc DeclRefBase::getLoc() const - { - return decl->loc; - } - - DeclRefBase DeclRefBase::GetParent() const - { - // Want access to the free function (the 'as' method by default gets priority) - // Can access as method with this->as because it removes any ambiguity. - using Slang::as; - - auto parentDecl = decl->ParentDecl; - if (!parentDecl) - return DeclRefBase(); - - // Default is to apply the same set of substitutions/specializations - // to the parent declaration as were applied to the child. - RefPtr substToApply = substitutions.substitutions; - - if(auto interfaceDecl = as(decl)) - { - // The declaration being referenced is an `interface` declaration, - // and there might be a this-type substitution in place. - // A reference to the parent of the interface declaration - // should not include that substitution. - if(auto thisTypeSubst = as(substToApply)) - { - if(thisTypeSubst->interfaceDecl == interfaceDecl) - { - // Strip away that specializations that apply to the interface. - substToApply = thisTypeSubst->outer; - } - } - } - - if (auto parentGenericDecl = as(parentDecl)) - { - // The parent of this declaration is a generic, which means - // that the decl-ref to the current declaration might include - // substitutions that specialize the generic parameters. - // A decl-ref to the parent generic should *not* include - // those substitutions. - // - if(auto genericSubst = as(substToApply)) - { - if(genericSubst->genericDecl == parentGenericDecl) - { - // Strip away the specializations that were applied to the parent. - substToApply = genericSubst->outer; - } - } - } - - return DeclRefBase(parentDecl, substToApply); - } - - int DeclRefBase::GetHashCode() const - { - return combineHash(PointerHash<1>::GetHashCode(decl), substitutions.GetHashCode()); - } - - // Val - - RefPtr Val::Substitute(SubstitutionSet subst) - { - if (!subst) return this; - int diff = 0; - return SubstituteImpl(subst, &diff); - } - - RefPtr Val::SubstituteImpl(SubstitutionSet /*subst*/, int* /*ioDiff*/) - { - // Default behavior is to not substitute at all - return this; - } - - // IntVal - - IntegerLiteralValue GetIntVal(RefPtr val) - { - if (auto constantVal = as(val)) - { - return constantVal->value; - } - SLANG_UNEXPECTED("needed a known integer value"); - return 0; - } - - // ConstantIntVal - - bool ConstantIntVal::EqualsVal(Val* val) - { - if (auto intVal = as(val)) - return value == intVal->value; - return false; - } - - String ConstantIntVal::ToString() - { - return String(value); - } - - int ConstantIntVal::GetHashCode() - { - return (int) value; - } - - // - - void registerBuiltinDecl( - Session* session, - RefPtr decl, - RefPtr modifier) - { - auto type = DeclRefType::Create( - session, - DeclRef(decl.Ptr(), nullptr)); - session->builtinTypes[(int)modifier->tag] = type; - } - - void registerMagicDecl( - Session* session, - RefPtr decl, - RefPtr modifier) - { - session->magicDecls[modifier->name] = decl.Ptr(); - } - - RefPtr findMagicDecl( - Session* session, - String const& name) - { - return session->magicDecls[name].GetValue(); - } - - // - - SyntaxNodeBase* createInstanceOfSyntaxClassByName( - String const& name) - { - if(0) {} - #define CASE(NAME) \ - else if(name == #NAME) return new NAME() - - CASE(GLSLBufferModifier); - CASE(GLSLWriteOnlyModifier); - CASE(GLSLReadOnlyModifier); - CASE(GLSLPatchModifier); - CASE(SimpleModifier); - - #undef CASE - else - { - SLANG_UNEXPECTED("unhandled syntax class name"); - UNREACHABLE_RETURN(nullptr); - } - } - - // - - // HLSLPatchType - - Type* HLSLPatchType::getElementType() - { - return as(findInnerMostGenericSubstitution(declRef.substitutions)->args[0]); - } - - IntVal* HLSLPatchType::getElementCount() - { - return as(findInnerMostGenericSubstitution(declRef.substitutions)->args[1]); - } - - // Constructors for types - - RefPtr getArrayType( - Type* elementType, - IntVal* elementCount) - { - auto session = elementType->getSession(); - auto arrayType = new ArrayExpressionType(); - arrayType->setSession(session); - arrayType->baseType = elementType; - arrayType->ArrayLength = elementCount; - return arrayType; - } - - RefPtr getArrayType( - Type* elementType) - { - auto session = elementType->getSession(); - auto arrayType = new ArrayExpressionType(); - arrayType->setSession(session); - arrayType->baseType = elementType; - return arrayType; - } - - RefPtr getNamedType( - Session* session, - DeclRef const& declRef) - { - DeclRef specializedDeclRef = createDefaultSubstitutionsIfNeeded(session, declRef).as(); - - auto namedType = new NamedExpressionType(specializedDeclRef); - namedType->setSession(session); - return namedType; - } - - RefPtr getTypeType( - Type* type) - { - auto session = type->getSession(); - auto typeType = new TypeType(type); - typeType->setSession(session); - return typeType; - } - - RefPtr getFuncType( - Session* session, - DeclRef const& declRef) - { - RefPtr funcType = new FuncType(); - funcType->setSession(session); - - funcType->resultType = GetResultType(declRef); - for (auto paramDeclRef : GetParameters(declRef)) - { - auto paramDecl = paramDeclRef.getDecl(); - auto paramType = GetType(paramDeclRef); - if( paramDecl->FindModifier() ) - { - paramType = session->getRefType(paramType); - } - else if( paramDecl->FindModifier() ) - { - if(paramDecl->FindModifier() || paramDecl->FindModifier()) - { - paramType = session->getInOutType(paramType); - } - else - { - paramType = session->getOutType(paramType); - } - } - funcType->paramTypes.add(paramType); - } - - return funcType; - } - - RefPtr getGenericDeclRefType( - Session* session, - DeclRef const& declRef) - { - auto genericDeclRefType = new GenericDeclRefType(declRef); - genericDeclRefType->setSession(session); - return genericDeclRefType; - } - - RefPtr getSamplerStateType( - Session* session) - { - auto samplerStateType = new SamplerStateType(); - samplerStateType->setSession(session); - return samplerStateType; - } - - // TODO: should really have a `type.cpp` and a `witness.cpp` - - bool TypeEqualityWitness::EqualsVal(Val* val) - { - auto otherWitness = as(val); - if (!otherWitness) - return false; - return sub->Equals(otherWitness->sub); - } - - RefPtr TypeEqualityWitness::SubstituteImpl(SubstitutionSet subst, int * ioDiff) - { - RefPtr rs = new TypeEqualityWitness(); - rs->sub = sub->SubstituteImpl(subst, ioDiff).as(); - rs->sup = sup->SubstituteImpl(subst, ioDiff).as(); - return rs; - } - - String TypeEqualityWitness::ToString() - { - return "TypeEqualityWitness(" + sub->ToString() + ")"; - } - - int TypeEqualityWitness::GetHashCode() - { - return sub->GetHashCode(); - } - - bool DeclaredSubtypeWitness::EqualsVal(Val* val) - { - auto otherWitness = as(val); - if(!otherWitness) - return false; - - return sub->Equals(otherWitness->sub) - && sup->Equals(otherWitness->sup) - && declRef.Equals(otherWitness->declRef); - } - - RefPtr findThisTypeSubstitution( - Substitutions* substs, - InterfaceDecl* interfaceDecl) - { - for(RefPtr s = substs; s; s = s->outer) - { - auto thisTypeSubst = as(s); - if(!thisTypeSubst) - continue; - - if(thisTypeSubst->interfaceDecl != interfaceDecl) - continue; - - return thisTypeSubst; - } - - return nullptr; - } - - RefPtr DeclaredSubtypeWitness::SubstituteImpl(SubstitutionSet subst, int * ioDiff) - { - if (auto genConstraintDeclRef = declRef.as()) - { - auto genConstraintDecl = genConstraintDeclRef.getDecl(); - - // search for a substitution that might apply to us - for(auto s = subst.substitutions; s; s = s->outer) - { - if(auto genericSubst = as(s)) - { - // the generic decl associated with the substitution list must be - // the generic decl that declared this parameter - auto genericDecl = genericSubst->genericDecl; - if (genericDecl != genConstraintDecl->ParentDecl) - continue; - - bool found = false; - Index index = 0; - for (auto m : genericDecl->Members) - { - if (auto constraintParam = as(m)) - { - if (constraintParam == declRef.getDecl()) - { - found = true; - break; - } - index++; - } - } - if (found) - { - (*ioDiff)++; - auto ordinaryParamCount = genericDecl->getMembersOfType().getCount() + - genericDecl->getMembersOfType().getCount(); - SLANG_ASSERT(index + ordinaryParamCount < genericSubst->args.getCount()); - return genericSubst->args[index + ordinaryParamCount]; - } - } - else if(auto globalGenericSubst = s.as()) - { - // check if the substitution is really about this global generic type parameter - if (globalGenericSubst->paramDecl != genConstraintDecl->ParentDecl) - continue; - - for(auto constraintArg : globalGenericSubst->constraintArgs) - { - if(constraintArg.decl.Ptr() != genConstraintDecl) - continue; - - (*ioDiff)++; - return constraintArg.val; - } - } - } - } - - // Perform substitution on the constituent elements. - int diff = 0; - auto substSub = sub->SubstituteImpl(subst, &diff).as(); - auto substSup = sup->SubstituteImpl(subst, &diff).as(); - auto substDeclRef = declRef.SubstituteImpl(subst, &diff); - if (!diff) - return this; - - (*ioDiff)++; - - // If we have a reference to a type constraint for an - // associated type declaration, then we can replace it - // with the concrete conformance witness for a concrete - // type implementing the outer interface. - // - // TODO: It is a bit gross that we use `GenericTypeConstraintDecl` for - // associated types, when they aren't really generic type *parameters*, - // so we'll need to change this location in the code if we ever clean - // up the hierarchy. - // - if (auto substTypeConstraintDecl = as(substDeclRef.decl)) - { - if (auto substAssocTypeDecl = as(substTypeConstraintDecl->ParentDecl)) - { - if (auto interfaceDecl = as(substAssocTypeDecl->ParentDecl)) - { - // At this point we have a constraint decl for an associated type, - // and we nee to see if we are dealing with a concrete substitution - // for the interface around that associated type. - if(auto thisTypeSubst = findThisTypeSubstitution(substDeclRef.substitutions, interfaceDecl)) - { - // We need to look up the declaration that satisfies - // the requirement named by the associated type. - Decl* requirementKey = substTypeConstraintDecl; - RequirementWitness requirementWitness = tryLookUpRequirementWitness(thisTypeSubst->witness, requirementKey); - switch(requirementWitness.getFlavor()) - { - default: - break; - - case RequirementWitness::Flavor::val: - { - auto satisfyingVal = requirementWitness.getVal(); - return satisfyingVal; - } - } - } - } - } - } - - - - - RefPtr rs = new DeclaredSubtypeWitness(); - rs->sub = substSub; - rs->sup = substSup; - rs->declRef = substDeclRef; - return rs; - } - - String DeclaredSubtypeWitness::ToString() - { - StringBuilder sb; - sb << "DeclaredSubtypeWitness("; - sb << this->sub->ToString(); - sb << ", "; - sb << this->sup->ToString(); - sb << ", "; - sb << this->declRef.toString(); - sb << ")"; - return sb.ProduceString(); - } - - int DeclaredSubtypeWitness::GetHashCode() - { - return declRef.GetHashCode(); - } - - // TransitiveSubtypeWitness - - bool TransitiveSubtypeWitness::EqualsVal(Val* val) - { - auto otherWitness = as(val); - if(!otherWitness) - return false; - - return sub->Equals(otherWitness->sub) - && sup->Equals(otherWitness->sup) - && subToMid->EqualsVal(otherWitness->subToMid) - && midToSup.Equals(otherWitness->midToSup); - } - - RefPtr TransitiveSubtypeWitness::SubstituteImpl(SubstitutionSet subst, int * ioDiff) - { - int diff = 0; - - RefPtr substSub = sub->SubstituteImpl(subst, &diff).as(); - RefPtr substSup = sup->SubstituteImpl(subst, &diff).as(); - RefPtr substSubToMid = subToMid->SubstituteImpl(subst, &diff).as(); - DeclRef substMidToSup = midToSup.SubstituteImpl(subst, &diff); - - // If nothing changed, then we can bail out early. - if (!diff) - return this; - - // Something changes, so let the caller know. - (*ioDiff)++; - - // TODO: are there cases where we can simplify? - // - // In principle, if either `subToMid` or `midToSub` turns into - // a reflexive subtype witness, then we could drop that side, - // and just return the other one (this would imply that `sub == mid` - // or `mid == sup` after substitutions). - // - // In the long run, is it also possible that if `sub` gets resolved - // to a concrete type *and* we decide to flatten out the inheritance - // graph into a linearized "class precedence list" stored in any - // aggregate type, then we could potentially just redirect to point - // to the appropriate inheritance decl in the original type. - // - // For now I'm going to ignore those possibilities and hope for the best. - - // In the simple case, we just construct a new transitive subtype - // witness, and we move on with life. - RefPtr result = new TransitiveSubtypeWitness(); - result->sub = substSub; - result->sup = substSup; - result->subToMid = substSubToMid; - result->midToSup = substMidToSup; - return result; - } - - String TransitiveSubtypeWitness::ToString() - { - // Note: we only print the constituent - // witnesses, and rely on them to print - // the starting and ending types. - StringBuilder sb; - sb << "TransitiveSubtypeWitness("; - sb << this->subToMid->ToString(); - sb << ", "; - sb << this->midToSup.toString(); - sb << ")"; - return sb.ProduceString(); - } - - int TransitiveSubtypeWitness::GetHashCode() - { - auto hash = sub->GetHashCode(); - hash = combineHash(hash, sup->GetHashCode()); - hash = combineHash(hash, subToMid->GetHashCode()); - hash = combineHash(hash, midToSup.GetHashCode()); - return hash; - } - - // - - String DeclRefBase::toString() const - { - if (!decl) return ""; - - auto name = decl->getName(); - if (!name) return ""; - - // TODO: need to print out substitutions too! - return name->text; - } - - bool SubstitutionSet::Equals(const SubstitutionSet& substSet) const - { - if (substitutions == substSet.substitutions) - { - return true; - } - if (substitutions == nullptr || substSet.substitutions == nullptr) - { - return false; - } - return substitutions->Equals(substSet.substitutions); - } - - int SubstitutionSet::GetHashCode() const - { - int rs = 0; - if (substitutions) - rs = combineHash(rs, substitutions->GetHashCode()); - return rs; - } - - // ExtractExistentialType - - String ExtractExistentialType::ToString() - { - String result; - result.append(declRef.toString()); - result.append(".This"); - return result; - } - - bool ExtractExistentialType::EqualsImpl(Type* type) - { - if( auto extractExistential = as(type) ) - { - return declRef.Equals(extractExistential->declRef); - } - return false; - } - - int ExtractExistentialType::GetHashCode() - { - return declRef.GetHashCode(); - } - - RefPtr ExtractExistentialType::CreateCanonicalType() - { - return this; - } - - RefPtr ExtractExistentialType::SubstituteImpl(SubstitutionSet subst, int* ioDiff) - { - int diff = 0; - auto substDeclRef = declRef.SubstituteImpl(subst, &diff); - if(!diff) - return this; - - (*ioDiff)++; - - RefPtr substValue = new ExtractExistentialType(); - substValue->declRef = declRef; - return substValue; - } - - // ExtractExistentialSubtypeWitness - - bool ExtractExistentialSubtypeWitness::EqualsVal(Val* val) - { - if( auto extractWitness = as(val) ) - { - return declRef.Equals(extractWitness->declRef); - } - return false; - } - - String ExtractExistentialSubtypeWitness::ToString() - { - String result; - result.append("extractExistentialValue("); - result.append(declRef.toString()); - result.append(")"); - return result; - } - - int ExtractExistentialSubtypeWitness::GetHashCode() - { - return declRef.GetHashCode(); - } - - RefPtr ExtractExistentialSubtypeWitness::SubstituteImpl(SubstitutionSet subst, int* ioDiff) - { - int diff = 0; - - auto substDeclRef = declRef.SubstituteImpl(subst, &diff); - auto substSub = sub->SubstituteImpl(subst, &diff).as(); - auto substSup = sup->SubstituteImpl(subst, &diff).as(); - - if(!diff) - return this; - - (*ioDiff)++; - - RefPtr substValue = new ExtractExistentialSubtypeWitness(); - substValue->declRef = declRef; - substValue->sub = substSub; - substValue->sup = substSup; - return substValue; - } - - // - // TaggedUnionType - // - - String TaggedUnionType::ToString() - { - String result; - result.append("__TaggedUnion("); - bool first = true; - for( auto caseType : caseTypes ) - { - if(!first) result.append(", "); - first = false; - - result.append(caseType->ToString()); - } - result.append(")"); - return result; - } - - bool TaggedUnionType::EqualsImpl(Type* type) - { - auto taggedUnion = as(type); - if(!taggedUnion) - return false; - - auto caseCount = caseTypes.getCount(); - if(caseCount != taggedUnion->caseTypes.getCount()) - return false; - - for( Index ii = 0; ii < caseCount; ++ii ) - { - if(!caseTypes[ii]->Equals(taggedUnion->caseTypes[ii])) - return false; - } - return true; - } - - int TaggedUnionType::GetHashCode() - { - int hashCode = 0; - for( auto caseType : caseTypes ) - { - hashCode = combineHash(hashCode, caseType->GetHashCode()); - } - return hashCode; - } - - RefPtr TaggedUnionType::CreateCanonicalType() - { - RefPtr canType = new TaggedUnionType(); - canType->setSession(getSession()); - - for( auto caseType : caseTypes ) - { - auto canCaseType = caseType->GetCanonicalType(); - canType->caseTypes.add(canCaseType); - } - - return canType; - } - - RefPtr TaggedUnionType::SubstituteImpl(SubstitutionSet subst, int* ioDiff) - { - int diff = 0; - - List> substCaseTypes; - for( auto caseType : caseTypes ) - { - substCaseTypes.add(caseType->SubstituteImpl(subst, &diff).as()); - } - if(!diff) - return this; - - (*ioDiff)++; - - RefPtr substType = new TaggedUnionType(); - substType->setSession(getSession()); - substType->caseTypes.swapWith(substCaseTypes); - return substType; - } - -// -// TaggedUnionSubtypeWitness -// - - -bool TaggedUnionSubtypeWitness::EqualsVal(Val* val) -{ - auto taggedUnionWitness = as(val); - if(!taggedUnionWitness) - return false; - - auto caseCount = caseWitnesses.getCount(); - if(caseCount != taggedUnionWitness->caseWitnesses.getCount()) - return false; - - for(Index ii = 0; ii < caseCount; ++ii) - { - if(!caseWitnesses[ii]->EqualsVal(taggedUnionWitness->caseWitnesses[ii])) - return false; - } - - return true; -} - -String TaggedUnionSubtypeWitness::ToString() -{ - String result; - result.append("TaggedUnionSubtypeWitness("); - bool first = true; - for( auto caseWitness : caseWitnesses ) - { - if(!first) result.append(", "); - first = false; - - result.append(caseWitness->ToString()); - } - return result; -} - -int TaggedUnionSubtypeWitness::GetHashCode() -{ - int hash = 0; - for( auto caseWitness : caseWitnesses ) - { - hash = combineHash(hash, caseWitness->GetHashCode()); - } - return hash; -} - -RefPtr TaggedUnionSubtypeWitness::SubstituteImpl(SubstitutionSet subst, int* ioDiff) -{ - int diff = 0; - - auto substSub = sub->SubstituteImpl(subst, &diff).as(); - auto substSup = sup->SubstituteImpl(subst, &diff).as(); - - List> substCaseWitnesses; - for( auto caseWitness : caseWitnesses ) - { - substCaseWitnesses.add(caseWitness->SubstituteImpl(subst, &diff)); - } - - if(!diff) - return this; - - (*ioDiff)++; - - RefPtr substWitness = new TaggedUnionSubtypeWitness(); - substWitness->sub = substSub; - substWitness->sup = substSup; - substWitness->caseWitnesses.swapWith(substCaseWitnesses); - return substWitness; -} - -Module* getModule(Decl* decl) -{ - for( auto dd = decl; dd; dd = dd->ParentDecl ) - { - if(auto moduleDecl = as(dd)) - return moduleDecl->module; - } - return nullptr; -} - -bool findImageFormatByName(char const* name, ImageFormat* outFormat) -{ - static const struct - { - char const* name; - ImageFormat format; - } kFormats[] = - { -#define FORMAT(NAME) { #NAME, ImageFormat::NAME }, -#include "image-format-defs.h" - }; - - for( auto item : kFormats ) - { - if( strcmp(item.name, name) == 0 ) - { - *outFormat = item.format; - return true; - } - } - - return false; -} - -char const* getGLSLNameForImageFormat(ImageFormat format) -{ - switch( format ) - { - default: return "unhandled"; -#define FORMAT(NAME) case ImageFormat::NAME: return #NAME; -#include "image-format-defs.h" - } -} - -// -// ExistentialSpecializedType -// - -String ExistentialSpecializedType::ToString() -{ - String result; - result.append("__ExistentialSpecializedType("); - result.append(baseType->ToString()); - for( auto arg : slots.args ) - { - result.append(", "); - result.append(arg.type->ToString()); - } - result.append(")"); - return result; -} - -bool ExistentialSpecializedType::EqualsImpl(Type * type) -{ - auto other = as(type); - if(!other) - return false; - - if(!baseType->Equals(other->baseType)) - return false; - - auto argCount = slots.args.getCount(); - if(argCount != other->slots.args.getCount()) - return false; - - for( Index ii = 0; ii < argCount; ++ii ) - { - if(!slots.args[ii].type->Equals(other->slots.args[ii].type)) - return false; - - if(!slots.args[ii].witness->EqualsVal(other->slots.args[ii].witness)) - return false; - } - return true; -} - -int ExistentialSpecializedType::GetHashCode() -{ - Hasher hasher; - hasher.hashObject(baseType); - for(auto arg : slots.args) - { - hasher.hashObject(arg.type); - hasher.hashObject(arg.witness); - } - return hasher.getResult(); -} - -RefPtr ExistentialSpecializedType::CreateCanonicalType() -{ - RefPtr canType = new ExistentialSpecializedType(); - canType->setSession(getSession()); - - canType->baseType = baseType->GetCanonicalType(); - for( auto paramType : slots.paramTypes ) - { - canType->slots.paramTypes.add( paramType->GetCanonicalType() ); - } - for( auto arg : slots.args ) - { - ExistentialTypeSlots::Arg canArg; - canArg.type = arg.type->GetCanonicalType(); - canArg.witness = arg.witness; - canType->slots.args.add(canArg); - } - return canType; -} - -RefPtr ExistentialSpecializedType::SubstituteImpl(SubstitutionSet subst, int* ioDiff) -{ - int diff = 0; - - auto substBaseType = baseType->SubstituteImpl(subst, &diff).as(); - - ExistentialTypeSlots substSlots; - for( auto paramType : slots.paramTypes ) - { - substSlots.paramTypes.add( paramType->SubstituteImpl(subst, &diff).as() ); - } - for( auto arg : slots.args ) - { - ExistentialTypeSlots::Arg substArg; - substArg.type = arg.type->SubstituteImpl(subst, &diff).as(); - substArg.witness = arg.witness->SubstituteImpl(subst, &diff); - substSlots.args.add(substArg); - } - - if(!diff) - return this; - - (*ioDiff)++; - - RefPtr substType = new ExistentialSpecializedType(); - substType->setSession(getSession()); - substType->baseType = substBaseType; - substType->slots = substSlots; - return substType; -} - -} // namespace Slang diff --git a/source/slang/syntax.h b/source/slang/syntax.h deleted file mode 100644 index aa3944d0a..000000000 --- a/source/slang/syntax.h +++ /dev/null @@ -1,1419 +0,0 @@ -#ifndef SLANG_SYNTAX_H -#define SLANG_SYNTAX_H - -#include "../core/basic.h" -#include "ir.h" -#include "lexer.h" -#include "profile.h" -#include "type-system-shared.h" -#include "../../slang.h" - -#include - -namespace Slang -{ - class Module; - class Name; - class Session; - class Substitutions; - class SyntaxVisitor; - class FuncDecl; - class Layout; - - struct IExprVisitor; - struct IDeclVisitor; - struct IModifierVisitor; - struct IStmtVisitor; - struct ITypeVisitor; - struct IValVisitor; - - class Parser; - class SyntaxNode; - - typedef RefPtr (*SyntaxParseCallback)(Parser* parser, void* userData); - - typedef unsigned int ConversionCost; - enum : ConversionCost - { - // No conversion at all - kConversionCost_None = 0, - - // Conversion from a buffer to the type it carries needs to add a minimal - // extra cost, just so we can distinguish an overload on `ConstantBuffer` - // from one on `Foo` - kConversionCost_ImplicitDereference = 10, - - // Conversions based on explicit sub-typing relationships are the cheapest - // - // TODO(tfoley): We will eventually need a discipline for ranking - // when two up-casts are comparable. - kConversionCost_CastToInterface = 50, - - // Conversion that is lossless and keeps the "kind" of the value the same - kConversionCost_RankPromotion = 150, - - // Conversions that are lossless, but change "kind" - kConversionCost_UnsignedToSignedPromotion = 200, - - // Conversion from signed->unsigned integer of same or greater size - kConversionCost_SignedToUnsignedConversion = 300, - - // Cost of converting an integer to a floating-point type - kConversionCost_IntegerToFloatConversion = 400, - - // Default case (usable for user-defined conversions) - kConversionCost_Default = 500, - - // Catch-all for conversions that should be discouraged - // (i.e., that really shouldn't be made implicitly) - // - // TODO: make these conversions not be allowed implicitly in "Slang mode" - kConversionCost_GeneralConversion = 900, - - // This is the cost of an explicit conversion, which should - // not actually be performed. - kConversionCost_Explicit = 90000, - - // Additional conversion cost to add when promoting from a scalar to - // a vector (this will be added to the cost, if any, of converting - // the element type of the vector) - kConversionCost_ScalarToVector = 1, - - // Conversion is impossible - kConversionCost_Impossible = 0xFFFFFFFF, - }; - - enum class ImageFormat - { -#define FORMAT(NAME) NAME, -#include "image-format-defs.h" - }; - - bool findImageFormatByName(char const* name, ImageFormat* outFormat); - char const* getGLSLNameForImageFormat(ImageFormat format); - - // TODO(tfoley): We should ditch this enumeration - // and just use the IR opcodes that represent these - // types directly. The one major complication there - // is that the order of the enum values currently - // matters, since it determines promotion rank. - // We either need to keep that restriction, or - // look up promotion rank by some other means. - // - - class Decl; - class Val; - - // Forward-declare all syntax classes -#define SYNTAX_CLASS(NAME, BASE, ...) class NAME; -#include "object-meta-begin.h" -#include "syntax-defs.h" -#include "object-meta-end.h" - - // Helper type for pairing up a name and the location where it appeared - struct NameLoc - { - Name* name; - SourceLoc loc; - - NameLoc() - : name(nullptr) - {} - - explicit NameLoc(Name* name) - : name(name) - {} - - - NameLoc(Name* name, SourceLoc loc) - : name(name) - , loc(loc) - {} - - NameLoc(Token const& token) - : name(token.getNameOrNull()) - , loc(token.getLoc()) - {} - }; - - // Helper class for iterating over a list of heap-allocated modifiers - struct ModifierList - { - struct Iterator - { - Modifier* current; - - Modifier* operator*() - { - return current; - } - - void operator++(); -#if 0 - { - current = current->next.Ptr(); - } -#endif - - bool operator!=(Iterator other) - { - return current != other.current; - }; - - Iterator() - : current(nullptr) - {} - - Iterator(Modifier* modifier) - : current(modifier) - {} - }; - - ModifierList() - : modifiers(nullptr) - {} - - ModifierList(Modifier* modifiers) - : modifiers(modifiers) - {} - - Iterator begin() { return Iterator(modifiers); } - Iterator end() { return Iterator(nullptr); } - - Modifier* modifiers; - }; - - // Helper class for iterating over heap-allocated modifiers - // of a specific type. - template - struct FilteredModifierList - { - struct Iterator - { - Modifier* current; - - T* operator*() - { - return (T*)current; - } - - void operator++(); - #if 0 - { - current = Adjust(current->next.Ptr()); - } - #endif - - bool operator!=(Iterator other) - { - return current != other.current; - }; - - Iterator() - : current(nullptr) - {} - - Iterator(Modifier* modifier) - : current(modifier) - {} - }; - - FilteredModifierList() - : modifiers(nullptr) - {} - - FilteredModifierList(Modifier* modifiers) - : modifiers(Adjust(modifiers)) - {} - - Iterator begin() { return Iterator(modifiers); } - Iterator end() { return Iterator(nullptr); } - - static Modifier* Adjust(Modifier* modifier); - #if 0 - { - Modifier* m = modifier; - for (;;) - { - if (!m) return m; - if (dynamicCast(m)) return m; - m = m->next.Ptr(); - } - } - #endif - - Modifier* modifiers; - }; - - // A set of modifiers attached to a syntax node - struct Modifiers - { - // The first modifier in the linked list of heap-allocated modifiers - RefPtr first; - - template - FilteredModifierList getModifiersOfType() { return FilteredModifierList(first.Ptr()); } - - // Find the first modifier of a given type, or return `nullptr` if none is found. - template - T* findModifier() - { - return *getModifiersOfType().begin(); - } - - template - bool hasModifier() { return findModifier() != nullptr; } - - FilteredModifierList::Iterator begin() { return FilteredModifierList::Iterator(first.Ptr()); } - FilteredModifierList::Iterator end() { return FilteredModifierList::Iterator(nullptr); } - }; - - class NamedExpressionType; - class GenericDecl; - class ContainerDecl; - - // Try to extract a simple integer value from an `IntVal`. - // This fill assert-fail if the object doesn't represent a literal value. - IntegerLiteralValue GetIntVal(RefPtr val); - - // Represents how much checking has been applied to a declaration. - enum class DeclCheckState : uint8_t - { - // The declaration has been parsed, but not checked - Unchecked, - - // We are in the process of checking the declaration "header" - // (those parts of the declaration needed in order to - // reference it) - CheckingHeader, - - // We are done checking the declaration header. - CheckedHeader, - - // We have checked the declaration fully. - Checked, - }; - - void addModifier( - RefPtr syntax, - RefPtr modifier); - - struct QualType - { - RefPtr type; - bool IsLeftValue; - - QualType() - : IsLeftValue(false) - {} - - QualType(Type* type) - : type(type) - , IsLeftValue(false) - {} - - Type* Ptr() { return type.Ptr(); } - - operator Type*() { return type; } - operator RefPtr() { return type; } - RefPtr operator->() { return type; } - }; - - // A reference to a class of syntax node, that can be - // used to create instances on the fly - struct SyntaxClassBase - { - typedef void* (*CreateFunc)(); - - // Run-time type representation for syntax nodes - struct ClassInfo - { - // Textual class name, for debugging - char const* name; - - // Base class for runtime queries - ClassInfo const* baseClass; - - // Callback to use when creating instances - CreateFunc createFunc; - }; - - SyntaxClassBase() - {} - - SyntaxClassBase(ClassInfo const* classInfoIn) - : classInfo(classInfoIn) - {} - - void* createInstanceImpl() const - { - auto ci = classInfo; - if (!ci) return nullptr; - - auto cf = ci->createFunc; - if (!cf) return nullptr; - - return cf(); - } - - bool isSubClassOfImpl(SyntaxClassBase const& super) const; - - ClassInfo const* classInfo = nullptr; - - template - struct Impl - { - static void* createFunc(); - static const ClassInfo kClassInfo; - }; - }; - - template - struct SyntaxClass : SyntaxClassBase - { - SyntaxClass() - {} - - template - SyntaxClass(SyntaxClass const& other, - typename EnableIf::Value, void>::type* = 0) - : SyntaxClassBase(other.classInfo) - { - } - - T* createInstance() const - { - return (T*)createInstanceImpl(); - } - - SyntaxClass(const ClassInfo* classInfoIn): - SyntaxClassBase(classInfoIn) - {} - - static SyntaxClass getClass() - { - return SyntaxClass(&SyntaxClassBase::Impl::kClassInfo); - } - - template - bool isSubClassOf(SyntaxClass super) - { - return isSubClassOfImpl(super); - } - - template - bool isSubClassOf() - { - return isSubClassOf(SyntaxClass::getClass()); - } - }; - - template - SyntaxClass getClass() - { - return SyntaxClass::getClass(); - } - - struct SubstitutionSet - { - RefPtr substitutions; - operator Substitutions*() const - { - return substitutions; - } - - SubstitutionSet() {} - SubstitutionSet(RefPtr subst) - : substitutions(subst) - { - } - bool Equals(const SubstitutionSet& substSet) const; - int GetHashCode() const; - }; - - template - struct DeclRef; - - // A reference to a declaration, which may include - // substitutions for generic parameters. - struct DeclRefBase - { - typedef Decl DeclType; - - // The underlying declaration - Decl* decl = nullptr; - Decl* getDecl() const { return decl; } - - // Optionally, a chain of substitutions to perform - SubstitutionSet substitutions; - - DeclRefBase() - {} - - DeclRefBase(Decl* decl) - :decl(decl) - {} - - DeclRefBase(Decl* decl, SubstitutionSet subst) - :decl(decl), - substitutions(subst) - {} - - DeclRefBase(Decl* decl, RefPtr subst) - : decl(decl) - , substitutions(subst) - {} - - // Apply substitutions to a type or declaration - RefPtr Substitute(RefPtr type) const; - - DeclRefBase Substitute(DeclRefBase declRef) const; - - // Apply substitutions to an expression - RefPtr Substitute(RefPtr expr) const; - - // Apply substitutions to this declaration reference - DeclRefBase SubstituteImpl(SubstitutionSet subst, int* ioDiff); - - // Returns true if 'as' will return a valid cast - template - bool is() const { return Slang::as(decl) != nullptr; } - - // "dynamic cast" to a more specific declaration reference type - template - DeclRef as() const; - - // Check if this is an equivalent declaration reference to another - bool Equals(DeclRefBase const& declRef) const; - bool operator == (const DeclRefBase& other) const - { - return Equals(other); - } - - // Convenience accessors for common properties of declarations - Name* GetName() const; - SourceLoc getLoc() const; - DeclRefBase GetParent() const; - - int GetHashCode() const; - - // Debugging: - String toString() const; - }; - - template - struct DeclRef : DeclRefBase - { - typedef T DeclType; - - DeclRef() - {} - - DeclRef(T* decl, SubstitutionSet subst) - : DeclRefBase(decl, subst) - {} - - DeclRef(T* decl, RefPtr subst) - : DeclRefBase(decl, SubstitutionSet(subst)) - {} - - template - DeclRef(DeclRef const& other, - typename EnableIf::Value, void>::type* = 0) - : DeclRefBase(other.decl, other.substitutions) - { - } - - T* getDecl() const - { - return (T*)decl; - } - - operator T*() const - { - return getDecl(); - } - - // - static DeclRef unsafeInit(DeclRefBase const& declRef) - { - return DeclRef((T*) declRef.decl, declRef.substitutions); - } - - RefPtr Substitute(RefPtr type) const - { - return DeclRefBase::Substitute(type); - } - RefPtr Substitute(RefPtr expr) const - { - return DeclRefBase::Substitute(expr); - } - - // Apply substitutions to a type or declaration - template - DeclRef Substitute(DeclRef declRef) const - { - return DeclRef::unsafeInit(DeclRefBase::Substitute(declRef)); - } - - // Apply substitutions to this declaration reference - DeclRef SubstituteImpl(SubstitutionSet subst, int* ioDiff) - { - return DeclRef::unsafeInit(DeclRefBase::SubstituteImpl(subst, ioDiff)); - } - - DeclRef GetParent() const - { - return DeclRef::unsafeInit(DeclRefBase::GetParent()); - } - }; - - template - DeclRef DeclRefBase::as() const - { - DeclRef result; - result.decl = Slang::as(decl); - result.substitutions = substitutions; - return result; - } - - template - inline DeclRef makeDeclRef(T* decl) - { - return DeclRef(decl, nullptr); - } - - template - struct FilteredMemberList - { - typedef RefPtr Element; - - FilteredMemberList() - : m_begin(nullptr) - , m_end(nullptr) - {} - - explicit FilteredMemberList( - List const& list) - : m_begin(adjust(list.begin(), list.end())) - , m_end(list.end()) - {} - - struct Iterator - { - Element* m_cursor; - Element* m_end; - - bool operator!=(Iterator const& other) - { - return m_cursor != other.m_cursor; - } - - void operator++() - { - m_cursor = adjust(m_cursor + 1, m_end); - } - - RefPtr& operator*() - { - return *(RefPtr*)m_cursor; - } - }; - - Iterator begin() - { - Iterator iter = { m_begin, m_end }; - return iter; - } - - Iterator end() - { - Iterator iter = { m_end, m_end }; - return iter; - } - - static Element* adjust(Element* cursor, Element* end) - { - while (cursor != end) - { - if (as(*cursor)) - return cursor; - cursor++; - } - return cursor; - } - - // TODO(tfoley): It is ugly to have these. - // We should probably fix the call sites instead. - RefPtr& getFirst() { return *begin(); } - Index getCount() - { - Index count = 0; - for (auto iter : (*this)) - { - (void)iter; - count++; - } - return count; - } - - List> toArray() - { - List> result; - for (auto element : (*this)) - { - result.add(element); - } - return result; - } - - Element* m_begin; - Element* m_end; - }; - - struct TransparentMemberInfo - { - // The declaration of the transparent member - Decl* decl; - }; - - template - struct FilteredMemberRefList - { - List> const& decls; - SubstitutionSet substitutions; - - FilteredMemberRefList( - List> const& decls, - SubstitutionSet substitutions) - : decls(decls) - , substitutions(substitutions) - {} - - int Count() const - { - int count = 0; - for (auto d : *this) - count++; - return count; - } - - List> ToArray() const - { - List> result; - for (auto d : *this) - result.add(d); - return result; - } - - struct Iterator - { - FilteredMemberRefList const* list; - RefPtr* ptr; - RefPtr* end; - - Iterator() : list(nullptr), ptr(nullptr) {} - Iterator( - FilteredMemberRefList const* list, - RefPtr* ptr, - RefPtr* end) - : list(list) - , ptr(ptr) - , end(end) - {} - - bool operator!=(Iterator other) - { - return ptr != other.ptr; - } - - void operator++() - { - ptr = list->Adjust(ptr + 1, end); - } - - DeclRef operator*() - { - return DeclRef((T*) ptr->Ptr(), list->substitutions); - } - }; - - Iterator begin() const { return Iterator(this, Adjust(decls.begin(), decls.end()), decls.end()); } - Iterator end() const { return Iterator(this, decls.end(), decls.end()); } - - RefPtr* Adjust(RefPtr* ptr, RefPtr* end) const - { - for (; ptr != end; ptr++) - { - if (ptr->is()) - { - return ptr; - } - } - return end; - } - }; - - // - // type Expressions - // - - // A "type expression" is a term that we expect to resolve to a type during checking. - // We store both the original syntax and the resolved type here. - struct TypeExp - { - TypeExp() {} - TypeExp(TypeExp const& other) - : exp(other.exp) - , type(other.type) - {} - explicit TypeExp(RefPtr exp) - : exp(exp) - {} - explicit TypeExp(RefPtr type) - : type(type) - {} - TypeExp(RefPtr exp, RefPtr type) - : exp(exp) - , type(type) - {} - - RefPtr exp; - RefPtr type; - - bool Equals(Type* other); -#if 0 - { - return type->Equals(other); - } -#endif - bool Equals(RefPtr other); -#if 0 - { - return type->Equals(other.Ptr()); - } -#endif - Type* Ptr() { return type.Ptr(); } - operator Type*() - { - return type; - } - Type* operator->() { return Ptr(); } - - TypeExp Accept(SyntaxVisitor* visitor); - }; - - - - struct Scope : public RefObject - { - // The parent of this scope (where lookup should go if nothing is found locally) - RefPtr parent; - - // The next sibling of this scope (a peer for lookup) - RefPtr nextSibling; - - // The container to use for lookup - // - // Note(tfoley): This is kept as an unowned pointer - // so that a scope can't keep parts of the AST alive, - // but the opposite it allowed. - ContainerDecl* containerDecl; - }; - - // Masks to be applied when lookup up declarations - enum class LookupMask : uint8_t - { - type = 0x1, - Function = 0x2, - Value = 0x4, - Attribute = 0x8, - - Default = type | Function | Value, - }; - - // Represents one item found during lookup - struct LookupResultItem - { - // Sometimes lookup finds an item, but there were additional - // "hops" taken to reach it. We need to remember these steps - // so that if/when we consturct a full expression we generate - // appropriate AST nodes for all the steps. - // - // We build up a list of these "breadcrumbs" while doing - // lookup, and store them alongside each item found. - // - // As an example, suppose we have an HLSL `cbuffer` declaration: - // - // cbuffer C { float4 f; } - // - // This is syntax sugar for a global-scope variable of - // type `ConstantBuffer` where `T` is a `struct` containing - // all the members: - // - // struct Anon0 { float4 f; }; - // __transparent ConstantBuffer anon1; - // - // The `__transparent` modifier there captures the fact that - // when somebody writes `f` in their code, they expect it to - // "see through" the `cbuffer` declaration (or the global variable, - // in this case) and find the member inside. - // - // But when the user writes `f` we can't just create a simple - // `VarExpr` that refers directly to that field, because that - // doesn't actually reflect the required steps in a way that - // code generation can use. - // - // Instead we need to construct an expression like `(*anon1).f`, - // where there is are two additional steps in the process: - // - // 1. We needed to dereference the pointer-like type `ConstantBuffer` - // to get at a value of type `Anon0` - // 2. We needed to access a sub-field of the aggregate type `Anon0` - // - // We *could* just create these full-formed expressions during - // lookup, but this might mean creating a large number of - // AST nodes in cases where the user calls an overloaded function. - // At the very least we'd rather not heap-allocate in the common - // case where no "extra" steps need to be performed to get to - // the declarations. - // - // This is where "breadcrumbs" come in. A breadcrumb represents - // an extra "step" that must be performed to turn a declaration - // found by lookup into a valid expression to splice into the - // AST. Most of the time lookup result items don't have any - // breadcrumbs, so that no extra heap allocation takes place. - // When an item does have breadcrumbs, and it is chosen as - // the unique result (perhaps by overload resolution), then - // we can walk the list of breadcrumbs to create a full - // expression. - class Breadcrumb : public RefObject - { - public: - enum class Kind : uint8_t - { - // The lookup process looked "through" an in-scope - // declaration to the fields inside of it, so that - // even if lookup started with a simple name `f`, - // it needs to result in a member expression `obj.f`. - Member, - - // The lookup process took a pointer(-like) value, and then - // proceeded to derefence it and look at the thing(s) - // it points to instead, so that the final expression - // needs to have `(*obj)` - Deref, - - // The lookup process saw a value `obj` of type `T` and - // took into account an in-scope constraint that says - // `T` is a subtype of some other type `U`, so that - // lookup was able to find a member through type `U` - // instead. - Constraint, - - // The lookup process considered a member of an - // enclosing type as being in scope, so that any - // reference to that member needs to use a `this` - // expression as appropriate. - This, - }; - - // The kind of lookup step that was performed - Kind kind; - - // For the `Kind::This` case, is the `this` parameter - // mutable or not? - enum class ThisParameterMode : uint8_t - { - Default, - Mutating, - }; - ThisParameterMode thisParameterMode = ThisParameterMode::Default; - - - // As needed, a reference to the declaration that faciliated - // the lookup step. - // - // For a `Member` lookup step, this is the declaration whose - // members were implicitly pulled into scope. - // - // For a `Constraint` lookup step, this is the `ConstraintDecl` - // that serves to witness the subtype relationship. - // - DeclRef declRef; - - // The next implicit step that the lookup process took to - // arrive at a final value. - RefPtr next; - - Breadcrumb( - Kind kind, - DeclRef declRef, - RefPtr next, - ThisParameterMode thisParameterMode = ThisParameterMode::Default) - : kind(kind) - , thisParameterMode(thisParameterMode) - , declRef(declRef) - , next(next) - {} - }; - - // A properly-specialized reference to the declaration that was found. - DeclRef declRef; - - // Any breadcrumbs needed in order to turn that declaration - // reference into a well-formed expression. - // - // This is unused in the simple case where a declaration - // is being referenced directly (rather than through - // transparent members). - RefPtr breadcrumbs; - - LookupResultItem() = default; - explicit LookupResultItem(DeclRef declRef) - : declRef(declRef) - {} - LookupResultItem(DeclRef declRef, RefPtr breadcrumbs) - : declRef(declRef) - , breadcrumbs(breadcrumbs) - {} - }; - - - // Result of looking up a name in some lexical/semantic environment. - // Can be used to enumerate all the declarations matching that name, - // in the case where the result is overloaded. - struct LookupResult - { - // The one item that was found, in the smple case - LookupResultItem item; - - // All of the items that were found, in the complex case. - // Note: if there was no overloading, then this list isn't - // used at all, to avoid allocation. - List items; - - HashSet> lookedupDecls; - - // Was at least one result found? - bool isValid() const { return item.declRef.getDecl() != nullptr; } - - bool isOverloaded() const { return items.getCount() > 1; } - - Name* getName() const - { - return items.getCount() > 1 ? items[0].declRef.GetName() : item.declRef.GetName(); - } - LookupResultItem* begin() - { - if (isValid()) - { - if (isOverloaded()) - return items.begin(); - else - return &item; - } - else - return nullptr; - } - LookupResultItem* end() - { - if (isValid()) - { - if (isOverloaded()) - return items.end(); - else - return &item + 1; - } - else - return nullptr; - } - }; - - struct SemanticsVisitor; - - struct LookupRequest - { - SemanticsVisitor* semantics = nullptr; - RefPtr scope = nullptr; - RefPtr endScope = nullptr; - - LookupMask mask = LookupMask::Default; - }; - - struct WitnessTable; - - // A value that witnesses the satisfaction of an interface - // requirement by a particular declaration or value. - struct RequirementWitness - { - RequirementWitness() - : m_flavor(Flavor::none) - {} - - RequirementWitness(DeclRef declRef) - : m_flavor(Flavor::declRef) - , m_declRef(declRef) - {} - - RequirementWitness(RefPtr val); - - RequirementWitness(RefPtr witnessTable); - - enum class Flavor - { - none, - declRef, - val, - witnessTable, - }; - - Flavor getFlavor() - { - return m_flavor; - } - - DeclRef getDeclRef() - { - SLANG_ASSERT(getFlavor() == Flavor::declRef); - return m_declRef; - } - - RefPtr getVal() - { - SLANG_ASSERT(getFlavor() == Flavor::val); - return m_obj.as(); - } - - RefPtr getWitnessTable(); - - RequirementWitness specialize(SubstitutionSet const& subst); - - Flavor m_flavor; - DeclRef m_declRef; - RefPtr m_obj; - - }; - - typedef Dictionary RequirementDictionary; - - struct WitnessTable : RefObject - { - RequirementDictionary requirementDictionary; - }; - - typedef Dictionary> AttributeArgumentValueDict; - - /// Collects information about existential type parameters and their arguments. - struct ExistentialTypeSlots - { - /// For each type parameter, holds the interface/existential type that constrains it. - List> paramTypes; - - /// An argument for an existential type parameter. - /// - /// Comprises a concrete type and a witness for its conformance to the desired - /// interface/existential type for the corresponding parameter. - /// - struct Arg - { - RefPtr type; - RefPtr witness; - }; - - /// Any arguments provided for the existential type parameters. - /// - /// It is possible for `args` to be empty even if `paramTypes` is non-empty; - /// that situation represents an unspecialized program or entry point. - /// - List args; - }; - - - // Generate class definition for all syntax classes -#define SYNTAX_FIELD(TYPE, NAME) TYPE NAME; -#define FIELD(TYPE, NAME) TYPE NAME; -#define FIELD_INIT(TYPE, NAME, INIT) TYPE NAME = INIT; -#define RAW(...) __VA_ARGS__ -#define END_SYNTAX_CLASS() }; -#define SYNTAX_CLASS(NAME, BASE, ...) class NAME : public BASE {public: -#include "object-meta-begin.h" - -#include "syntax-base-defs.h" -#undef SYNTAX_CLASS - -#undef ABSTRACT_SYNTAX_CLASS -#define ABSTRACT_SYNTAX_CLASS(NAME, BASE, ...) \ - class NAME : public BASE { \ - public: /* ... */ -#define SYNTAX_CLASS(NAME, BASE, ...) \ - class NAME : public BASE { \ - virtual void accept(NAME::Visitor* visitor, void* extra) override; \ - public: virtual SyntaxClass getClass() override; \ - public: /* ... */ -#include "expr-defs.h" -#include "decl-defs.h" -#include "modifier-defs.h" -#include "stmt-defs.h" -#include "type-defs.h" -#include "val-defs.h" - -#include "object-meta-end.h" - - inline RefPtr GetSub(DeclRef const& declRef) - { - return declRef.Substitute(declRef.getDecl()->sub.Ptr()); - } - - inline RefPtr GetSup(DeclRef const& declRef) - { - return declRef.Substitute(declRef.getDecl()->getSup().type); - } - - // Note(tfoley): These logically belong to `Type`, - // but order-of-declaration stuff makes that tricky - // - // TODO(tfoley): These should really belong to the compilation context! - // - void registerBuiltinDecl( - Session* session, - RefPtr decl, - RefPtr modifier); - void registerMagicDecl( - Session* session, - RefPtr decl, - RefPtr modifier); - - // Look up a magic declaration by its name - RefPtr findMagicDecl( - Session* session, - String const& name); - - // Create an instance of a syntax class by name - SyntaxNodeBase* createInstanceOfSyntaxClassByName( - String const& name); - - // `Val` - - inline bool areValsEqual(Val* left, Val* right) - { - if(!left || !right) return left == right; - return left->EqualsVal(right); - } - - // - - inline BaseType GetVectorBaseType(VectorExpressionType* vecType) - { - auto basicExprType = as(vecType->elementType); - return basicExprType->baseType; - } - - inline int GetVectorSize(VectorExpressionType* vecType) - { - auto constantVal = as(vecType->elementCount); - if (constantVal) - return (int) constantVal->value; - // TODO: what to do in this case? - return 0; - } - - // - // Declarations - // - - inline ExtensionDecl* GetCandidateExtensions(DeclRef const& declRef) - { - return declRef.getDecl()->candidateExtensions; - } - - inline FilteredMemberRefList getMembers(DeclRef const& declRef) - { - return FilteredMemberRefList(declRef.getDecl()->Members, declRef.substitutions); - } - - template - inline FilteredMemberRefList getMembersOfType(DeclRef const& declRef) - { - return FilteredMemberRefList(declRef.getDecl()->Members, declRef.substitutions); - } - - template - inline List> getMembersOfTypeWithExt(DeclRef const& declRef) - { - List> rs; - for (auto d : getMembersOfType(declRef)) - rs.add(d); - if (auto aggDeclRef = declRef.as()) - { - for (auto ext = GetCandidateExtensions(aggDeclRef); ext; ext = ext->nextCandidateExtension) - { - auto extMembers = getMembersOfType(DeclRef(ext, declRef.substitutions)); - for (auto mbr : extMembers) - rs.add(mbr); - } - } - return rs; - } - - /// The the user-level name for a variable that might be a shader parameter. - /// - /// In most cases this is just the name of the variable declaration itself, - /// but in the specific case of a `cbuffer`, the name that the user thinks - /// of is really metadata. For example: - /// - /// cbuffer C { int x; } - /// - /// In this example, error messages relating to the constant buffer should - /// really use the name `C`, but that isn't the name of the declaration - /// (it is in practice anonymous, and `C` can be used for a different - /// declaration in the same file). - /// - Name* getReflectionName(VarDeclBase* varDecl); - - inline RefPtr GetType(DeclRef const& declRef) - { - return declRef.Substitute(declRef.getDecl()->type.Ptr()); - } - - inline RefPtr getInitExpr(DeclRef const& declRef) - { - return declRef.Substitute(declRef.getDecl()->initExpr); - } - - inline RefPtr getType(DeclRef const& declRef) - { - return declRef.Substitute(declRef.getDecl()->type.Ptr()); - } - - inline RefPtr getTagExpr(DeclRef const& declRef) - { - return declRef.Substitute(declRef.getDecl()->tagExpr); - } - - inline RefPtr GetTargetType(DeclRef const& declRef) - { - return declRef.Substitute(declRef.getDecl()->targetType.Ptr()); - } - - inline FilteredMemberRefList GetFields(DeclRef const& declRef) - { - return getMembersOfType(declRef); - } - - inline RefPtr getBaseType(DeclRef const& declRef) - { - return declRef.Substitute(declRef.getDecl()->base.type); - } - - inline RefPtr GetType(DeclRef const& declRef) - { - return declRef.Substitute(declRef.getDecl()->type.Ptr()); - } - - inline RefPtr GetResultType(DeclRef const& declRef) - { - return declRef.Substitute(declRef.getDecl()->ReturnType.type.Ptr()); - } - - inline FilteredMemberRefList GetParameters(DeclRef const& declRef) - { - return getMembersOfType(declRef); - } - - inline Decl* GetInner(DeclRef const& declRef) - { - // TODO: Should really return a `DeclRef` for the inner - // declaration, and not just a raw pointer - return declRef.getDecl()->inner.Ptr(); - } - - - // - - RefPtr getArrayType( - Type* elementType, - IntVal* elementCount); - - RefPtr getArrayType( - Type* elementType); - - RefPtr getNamedType( - Session* session, - DeclRef const& declRef); - - RefPtr getTypeType( - Type* type); - - RefPtr getFuncType( - Session* session, - DeclRef const& declRef); - - RefPtr getGenericDeclRefType( - Session* session, - DeclRef const& declRef); - - RefPtr getSamplerStateType( - Session* session); - - - // Definitions that can't come earlier despite - // being in templates, because gcc/clang get angry. - // - template - void FilteredModifierList::Iterator::operator++() - { - current = Adjust(current->next.Ptr()); - } - // - template - Modifier* FilteredModifierList::Adjust(Modifier* modifier) - { - Modifier* m = modifier; - for (;;) - { - if (!m) return m; - if (as(m)) - { - return m; - } - m = m->next.Ptr(); - } - } - - // TODO: where should this live? - SubstitutionSet createDefaultSubstitutions( - Session* session, - Decl* decl, - SubstitutionSet parentSubst); - - SubstitutionSet createDefaultSubstitutions( - Session* session, - Decl* decl); - - DeclRef createDefaultSubstitutionsIfNeeded( - Session* session, - DeclRef declRef); - - RefPtr createDefaultSubsitutionsForGeneric( - Session* session, - GenericDecl* genericDecl, - RefPtr outerSubst); - - RefPtr findInnerMostGenericSubstitution(Substitutions* subst); - - enum class UserDefinedAttributeTargets - { - None = 0, - Struct = 1, - Var = 2, - Function = 4, - All = 7 - }; - - /// Get the module that a declaration is associated with, if any. - Module* getModule(Decl* decl); - -} // namespace Slang - -#endif diff --git a/source/slang/token-defs.h b/source/slang/token-defs.h deleted file mode 100644 index 873a252b4..000000000 --- a/source/slang/token-defs.h +++ /dev/null @@ -1,96 +0,0 @@ -// token-defs.h - -// This file is meant to be included multiple times, to produce different -// pieces of code related to tokens -// -// Each token is declared here with: -// -// TOKEN(id, desc) -// -// where `id` is the identifier that will be used for the token in -// ordinary code, while `desc` is name we should print when -// referring to this token in diagnostic messages. - - -#ifndef TOKEN -#error Need to define TOKEN(ID, DESC) before including "token-defs.h" -#endif - -TOKEN(Unknown, "") -TOKEN(EndOfFile, "end of file") -TOKEN(EndOfDirective, "end of line") -TOKEN(Invalid, "invalid character") -TOKEN(Identifier, "identifier") -TOKEN(IntegerLiteral, "integer literal") -TOKEN(FloatingPointLiteral, "floating-point literal") -TOKEN(StringLiteral, "string literal") -TOKEN(CharLiteral, "character literal") -TOKEN(WhiteSpace, "whitespace") -TOKEN(NewLine, "newline") -TOKEN(LineComment, "line comment") -TOKEN(BlockComment, "block comment") -TOKEN(DirectiveMessage, "user-defined message") - -#define PUNCTUATION(id, text) \ - TOKEN(id, "'" text "'") - -PUNCTUATION(Semicolon, ";") -PUNCTUATION(Comma, ",") -PUNCTUATION(Dot, ".") - -PUNCTUATION(LBrace, "{") -PUNCTUATION(RBrace, "}") -PUNCTUATION(LBracket, "[") -PUNCTUATION(RBracket, "]") -PUNCTUATION(LParent, "(") -PUNCTUATION(RParent, ")") - -PUNCTUATION(OpAssign, "=") -PUNCTUATION(OpAdd, "+") -PUNCTUATION(OpSub, "-") -PUNCTUATION(OpMul, "*") -PUNCTUATION(OpDiv, "/") -PUNCTUATION(OpMod, "%") -PUNCTUATION(OpNot, "!") -PUNCTUATION(OpBitNot, "~") -PUNCTUATION(OpLsh, "<<") -PUNCTUATION(OpRsh, ">>") -PUNCTUATION(OpEql, "==") -PUNCTUATION(OpNeq, "!=") -PUNCTUATION(OpGreater, ">") -PUNCTUATION(OpLess, "<") -PUNCTUATION(OpGeq, ">=") -PUNCTUATION(OpLeq, "<=") -PUNCTUATION(OpAnd, "&&") -PUNCTUATION(OpOr, "||") -PUNCTUATION(OpBitAnd, "&") -PUNCTUATION(OpBitOr, "|") -PUNCTUATION(OpBitXor, "^") -PUNCTUATION(OpInc, "++") -PUNCTUATION(OpDec, "--") - -PUNCTUATION(OpAddAssign, "+=") -PUNCTUATION(OpSubAssign, "-=") -PUNCTUATION(OpMulAssign, "*=") -PUNCTUATION(OpDivAssign, "/=") -PUNCTUATION(OpModAssign, "%=") -PUNCTUATION(OpShlAssign, "<<=") -PUNCTUATION(OpShrAssign, ">>=") -PUNCTUATION(OpAndAssign, "&=") -PUNCTUATION(OpOrAssign, "|=") -PUNCTUATION(OpXorAssign, "^=") - -PUNCTUATION(QuestionMark, "?") -PUNCTUATION(Colon, ":") -PUNCTUATION(RightArrow, "->") -PUNCTUATION(At, "@") -PUNCTUATION(Dollar, "$") -PUNCTUATION(Pound, "#") -PUNCTUATION(PoundPound, "##") - -PUNCTUATION(Scope, "::") - -#undef PUNCTUATION - -// Un-define the `TOKEN` macro so that client doesn't have to -#undef TOKEN diff --git a/source/slang/token.cpp b/source/slang/token.cpp deleted file mode 100644 index 42f7ab55f..000000000 --- a/source/slang/token.cpp +++ /dev/null @@ -1,39 +0,0 @@ -// token.cpp -#include "token.h" - -#include - -namespace Slang { - - -Name* Token::getName() const -{ - return getNameOrNull(); -} - -Name* Token::getNameOrNull() const -{ - switch (type) - { - default: - return nullptr; - - case TokenType::Identifier: - return (Name*) ptrValue; - } -} - -char const* TokenTypeToString(TokenType type) -{ - switch( type ) - { - default: - SLANG_ASSERT(!"unexpected"); - return ""; - -#define TOKEN(NAME, DESC) case TokenType::NAME: return DESC; -#include "token-defs.h" - } -} - -} // namespace Slang diff --git a/source/slang/token.h b/source/slang/token.h deleted file mode 100644 index d7d45882d..000000000 --- a/source/slang/token.h +++ /dev/null @@ -1,67 +0,0 @@ -// token.h -#ifndef SLANG_TOKEN_H_INCLUDED -#define SLANG_TOKEN_H_INCLUDED - -#include "../core/basic.h" - -#include "source-loc.h" - -namespace Slang { - -class Name; - -enum class TokenType -{ -#define TOKEN(NAME, DESC) NAME, -#include "token-defs.h" -}; - -char const* TokenTypeToString(TokenType type); - -enum TokenFlag : unsigned int -{ - AtStartOfLine = 1 << 0, - AfterWhitespace = 1 << 1, - SuppressMacroExpansion = 1 << 2, - ScrubbingNeeded = 1 << 3, -}; -typedef unsigned int TokenFlags; - -class Token -{ -public: - TokenType type = TokenType::Unknown; - TokenFlags flags = 0; - - SourceLoc loc; - void* ptrValue; - - UnownedStringSlice Content; - - Token() = default; - - Token( - TokenType typeIn, - const UnownedStringSlice & contentIn, - SourceLoc locIn, - TokenFlags flagsIn = 0) - : flags(flagsIn) - { - type = typeIn; - Content = contentIn; - loc = locIn; - ptrValue = nullptr; - } - - Name* getName() const; - - Name* getNameOrNull() const; - - SourceLoc getLoc() const { return loc; } -}; - - - -} // namespace Slang - -#endif diff --git a/source/slang/type-defs.h b/source/slang/type-defs.h deleted file mode 100644 index d0c00c73a..000000000 --- a/source/slang/type-defs.h +++ /dev/null @@ -1,490 +0,0 @@ -// type-defs.h - -// Syntax class definitions for types. - -// The type of a reference to an overloaded name -SYNTAX_CLASS(OverloadGroupType, Type) -RAW( -public: - virtual String ToString() override; - -protected: - virtual bool EqualsImpl(Type * type) override; - virtual RefPtr CreateCanonicalType() override; - virtual int GetHashCode() override; -) -END_SYNTAX_CLASS() - -// The type of an initializer-list expression (before it has -// been coerced to some other type) -SYNTAX_CLASS(InitializerListType, Type) -RAW( - virtual String ToString() override; - -protected: - virtual bool EqualsImpl(Type * type) override; - virtual RefPtr CreateCanonicalType() override; - virtual int GetHashCode() override; -) -END_SYNTAX_CLASS() - -// The type of an expression that was erroneous -SYNTAX_CLASS(ErrorType, Type) -RAW( -public: - virtual String ToString() override; - -protected: - virtual bool EqualsImpl(Type * type) override; - virtual RefPtr SubstituteImpl(SubstitutionSet subst, int* ioDiff) override; - virtual RefPtr CreateCanonicalType() override; - virtual int GetHashCode() override; -) -END_SYNTAX_CLASS() - -// A type that takes the form of a reference to some declaration -SYNTAX_CLASS(DeclRefType, Type) - DECL_FIELD(DeclRef, declRef) - -RAW( - virtual String ToString() override; - virtual RefPtr SubstituteImpl(SubstitutionSet subst, int* ioDiff) override; - - static RefPtr Create( - Session* session, - DeclRef declRef); - - DeclRefType() - {} - DeclRefType( - DeclRef declRef) - : declRef(declRef) - {} -protected: - virtual int GetHashCode() override; - virtual bool EqualsImpl(Type * type) override; - virtual RefPtr CreateCanonicalType() override; -) -END_SYNTAX_CLASS() - -// Base class for types that can be used in arithmetic expressions -ABSTRACT_SYNTAX_CLASS(ArithmeticExpressionType, DeclRefType) -RAW( -public: - virtual BasicExpressionType* GetScalarType() = 0; -) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(BasicExpressionType, ArithmeticExpressionType) - - FIELD(BaseType, baseType) - -RAW( - BasicExpressionType() {} - BasicExpressionType( - Slang::BaseType baseType) - : baseType(baseType) - {} -protected: - virtual BasicExpressionType* GetScalarType() override; - virtual bool EqualsImpl(Type * type) override; - virtual RefPtr CreateCanonicalType() override; -) -END_SYNTAX_CLASS() - -// Base type for things that are built in to the compiler, -// and will usually have special behavior or a custom -// mapping to the IR level. -ABSTRACT_SYNTAX_CLASS(BuiltinType, DeclRefType) -END_SYNTAX_CLASS() - -// Resources that contain "elements" that can be fetched -ABSTRACT_SYNTAX_CLASS(ResourceType, BuiltinType) - // The type that results from fetching an element from this resource - SYNTAX_FIELD(RefPtr, elementType) - - // Shape and access level information for this resource type - FIELD(TextureFlavor, flavor) - - RAW( - TextureFlavor::Shape GetBaseShape() - { - return flavor.GetBaseShape(); - } - bool isMultisample() { return flavor.isMultisample(); } - bool isArray() { return flavor.isArray(); } - SlangResourceShape getShape() const { return flavor.getShape(); } - SlangResourceAccess getAccess() { return flavor.getAccess(); } - - ) -END_SYNTAX_CLASS() - -ABSTRACT_SYNTAX_CLASS(TextureTypeBase, ResourceType) -RAW( - TextureTypeBase() - {} - TextureTypeBase( - TextureFlavor flavor, - RefPtr elementType) - { - this->elementType = elementType; - this->flavor = flavor; - } -) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(TextureType, TextureTypeBase) -RAW( - TextureType() - {} - TextureType( - TextureFlavor flavor, - RefPtr elementType) - : TextureTypeBase(flavor, elementType) - {} -) -END_SYNTAX_CLASS() - -// This is a base type for texture/sampler pairs, -// as they exist in, e.g., GLSL -SYNTAX_CLASS(TextureSamplerType, TextureTypeBase) -RAW( - TextureSamplerType() - {} - TextureSamplerType( - TextureFlavor flavor, - RefPtr elementType) - : TextureTypeBase(flavor, elementType) - {} -) -END_SYNTAX_CLASS() - -// This is a base type for `image*` types, as they exist in GLSL -SYNTAX_CLASS(GLSLImageType, TextureTypeBase) -RAW( - GLSLImageType() - {} - GLSLImageType( - TextureFlavor flavor, - RefPtr elementType) - : TextureTypeBase(flavor, elementType) - {} -) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(SamplerStateType, BuiltinType) - // What flavor of sampler state is this - FIELD(SamplerStateFlavor, flavor) -END_SYNTAX_CLASS() - -// Other cases of generic types known to the compiler -SYNTAX_CLASS(BuiltinGenericType, BuiltinType) - SYNTAX_FIELD(RefPtr, elementType) - - RAW(Type* getElementType() { return elementType; }) -END_SYNTAX_CLASS() - -// Types that behave like pointers, in that they can be -// dereferenced (implicitly) to access members defined -// in the element type. -SIMPLE_SYNTAX_CLASS(PointerLikeType, BuiltinGenericType) - -// HLSL buffer-type resources - -SIMPLE_SYNTAX_CLASS(HLSLStructuredBufferTypeBase, BuiltinGenericType) -SIMPLE_SYNTAX_CLASS(HLSLStructuredBufferType, HLSLStructuredBufferTypeBase) -SIMPLE_SYNTAX_CLASS(HLSLRWStructuredBufferType, HLSLStructuredBufferTypeBase) -SIMPLE_SYNTAX_CLASS(HLSLRasterizerOrderedStructuredBufferType, HLSLStructuredBufferTypeBase) - -SIMPLE_SYNTAX_CLASS(UntypedBufferResourceType, BuiltinType) -SIMPLE_SYNTAX_CLASS(HLSLByteAddressBufferType, UntypedBufferResourceType) -SIMPLE_SYNTAX_CLASS(HLSLRWByteAddressBufferType, UntypedBufferResourceType) -SIMPLE_SYNTAX_CLASS(HLSLRasterizerOrderedByteAddressBufferType, UntypedBufferResourceType) -SIMPLE_SYNTAX_CLASS(RaytracingAccelerationStructureType, UntypedBufferResourceType) - -SIMPLE_SYNTAX_CLASS(HLSLAppendStructuredBufferType, HLSLStructuredBufferTypeBase) -SIMPLE_SYNTAX_CLASS(HLSLConsumeStructuredBufferType, HLSLStructuredBufferTypeBase) - -SYNTAX_CLASS(HLSLPatchType, BuiltinType) -RAW( - Type* getElementType(); - IntVal* getElementCount(); -) -END_SYNTAX_CLASS() - -SIMPLE_SYNTAX_CLASS(HLSLInputPatchType, HLSLPatchType) -SIMPLE_SYNTAX_CLASS(HLSLOutputPatchType, HLSLPatchType) - -// HLSL geometry shader output stream types - -SIMPLE_SYNTAX_CLASS(HLSLStreamOutputType, BuiltinGenericType) -SIMPLE_SYNTAX_CLASS(HLSLPointStreamType, HLSLStreamOutputType) -SIMPLE_SYNTAX_CLASS(HLSLLineStreamType, HLSLStreamOutputType) -SIMPLE_SYNTAX_CLASS(HLSLTriangleStreamType, HLSLStreamOutputType) - -// -SIMPLE_SYNTAX_CLASS(GLSLInputAttachmentType, BuiltinType) - -// Base class for types used when desugaring parameter block -// declarations, includeing HLSL `cbuffer` or GLSL `uniform` blocks. -SIMPLE_SYNTAX_CLASS(ParameterGroupType, PointerLikeType) - -SIMPLE_SYNTAX_CLASS(UniformParameterGroupType, ParameterGroupType) -SIMPLE_SYNTAX_CLASS(VaryingParameterGroupType, ParameterGroupType) - -// type for HLSL `cbuffer` declarations, and `ConstantBuffer` -// ALso used for GLSL `uniform` blocks. -SIMPLE_SYNTAX_CLASS(ConstantBufferType, UniformParameterGroupType) - -// type for HLSL `tbuffer` declarations, and `TextureBuffer` -SIMPLE_SYNTAX_CLASS(TextureBufferType, UniformParameterGroupType) - -// type for GLSL `in` and `out` blocks -SIMPLE_SYNTAX_CLASS(GLSLInputParameterGroupType, VaryingParameterGroupType) -SIMPLE_SYNTAX_CLASS(GLSLOutputParameterGroupType, VaryingParameterGroupType) - -// type for GLLSL `buffer` blocks -SIMPLE_SYNTAX_CLASS(GLSLShaderStorageBufferType, UniformParameterGroupType) - -// type for Slang `ParameterBlock` type -SIMPLE_SYNTAX_CLASS(ParameterBlockType, UniformParameterGroupType) - -SYNTAX_CLASS(ArrayExpressionType, Type) - SYNTAX_FIELD(RefPtr, baseType) - SYNTAX_FIELD(RefPtr, ArrayLength) - -RAW( - virtual Slang::String ToString() override; - -protected: - virtual bool EqualsImpl(Type * type) override; - virtual RefPtr CreateCanonicalType() override; - virtual RefPtr SubstituteImpl(SubstitutionSet subst, int* ioDiff) override; - virtual int GetHashCode() override; - ) -END_SYNTAX_CLASS() - -// The "type" of an expression that resolves to a type. -// For example, in the expression `float(2)` the sub-expression, -// `float` would have the type `TypeType(float)`. -SYNTAX_CLASS(TypeType, Type) - // The type that this is the type of... - SYNTAX_FIELD(RefPtr, type) - -RAW( -public: - TypeType() - {} - TypeType(RefPtr type) - : type(type) - {} - - virtual String ToString() override; - -protected: - virtual bool EqualsImpl(Type * type) override; - virtual RefPtr CreateCanonicalType() override; - virtual int GetHashCode() override; -) -END_SYNTAX_CLASS() - -// A vector type, e.g., `vector` -SYNTAX_CLASS(VectorExpressionType, ArithmeticExpressionType) - - // The type of vector elements. - // As an invariant, this should be a basic type or an alias. - SYNTAX_FIELD(RefPtr, elementType) - - // The number of elements - SYNTAX_FIELD(RefPtr, elementCount) - -RAW( - virtual String ToString() override; - -protected: - virtual BasicExpressionType* GetScalarType() override; -) -END_SYNTAX_CLASS() - -// A matrix type, e.g., `matrix` -SYNTAX_CLASS(MatrixExpressionType, ArithmeticExpressionType) -RAW( - - Type* getElementType(); - IntVal* getRowCount(); - IntVal* getColumnCount(); - - RefPtr getRowType(); - - virtual String ToString() override; - -protected: - virtual BasicExpressionType* GetScalarType() override; - -private: - RefPtr mRowType; -) -END_SYNTAX_CLASS() - -// The built-in `String` type -SIMPLE_SYNTAX_CLASS(StringType, BuiltinType) - -// Type built-in `__EnumType` type -SYNTAX_CLASS(EnumTypeType, BuiltinType) - -// TODO: provide accessors for the declaration, the "tag" type, etc. - -END_SYNTAX_CLASS() - -// Base class for types that map down to -// simple pointers as part of code generation. -SYNTAX_CLASS(PtrTypeBase, BuiltinType) -RAW( - // Get the type of the pointed-to value. - Type* getValueType(); -) -END_SYNTAX_CLASS() - -// A true (user-visible) pointer type, e.g., `T*` -SYNTAX_CLASS(PtrType, PtrTypeBase) -END_SYNTAX_CLASS() - -// A type that represents the behind-the-scenes -// logical pointer that is passed for an `out` -// or `in out` parameter -SYNTAX_CLASS(OutTypeBase, PtrTypeBase) -END_SYNTAX_CLASS() - -// The type for an `out` parameter, e.g., `out T` -SYNTAX_CLASS(OutType, OutTypeBase) -END_SYNTAX_CLASS() - -// The type for an `in out` parameter, e.g., `in out T` -SYNTAX_CLASS(InOutType, OutTypeBase) -END_SYNTAX_CLASS() - -// The type for an `ref` parameter, e.g., `ref T` -SYNTAX_CLASS(RefType, PtrTypeBase) -END_SYNTAX_CLASS() - -// A type alias of some kind (e.g., via `typedef`) -SYNTAX_CLASS(NamedExpressionType, Type) -DECL_FIELD(DeclRef, declRef) - -RAW( - RefPtr innerType; - NamedExpressionType() - {} - NamedExpressionType( - DeclRef declRef) - : declRef(declRef) - {} - - - virtual String ToString() override; - -protected: - virtual bool EqualsImpl(Type * type) override; - virtual RefPtr CreateCanonicalType() override; - virtual int GetHashCode() override; -) -END_SYNTAX_CLASS() - -// A function type is defined by its parameter types -// and its result type. -SYNTAX_CLASS(FuncType, Type) - - // TODO: We may want to preserve parameter names - // in the list here, just so that we can print - // out friendly names when printing a function - // type, even if they don't affect the actual - // semantic type underneath. - - FIELD(List>, paramTypes) - FIELD(RefPtr, resultType) -RAW( - FuncType() - {} - - UInt getParamCount() { return paramTypes.getCount(); } - Type* getParamType(UInt index) { return paramTypes[index]; } - Type* getResultType() { return resultType; } - - virtual String ToString() override; -protected: - virtual RefPtr SubstituteImpl(SubstitutionSet subst, int* ioDiff) override; - virtual bool EqualsImpl(Type * type) override; - virtual RefPtr CreateCanonicalType() override; - virtual int GetHashCode() override; -) -END_SYNTAX_CLASS() - -// The "type" of an expression that names a generic declaration. -SYNTAX_CLASS(GenericDeclRefType, Type) - - DECL_FIELD(DeclRef, declRef) - - RAW( - GenericDeclRefType() - {} - GenericDeclRefType( - DeclRef declRef) - : declRef(declRef) - {} - - - DeclRef const& GetDeclRef() const { return declRef; } - - virtual String ToString() override; - -protected: - virtual bool EqualsImpl(Type * type) override; - virtual int GetHashCode() override; - virtual RefPtr CreateCanonicalType() override; -) -END_SYNTAX_CLASS() - -// The concrete type for a value wrapped in an existential, accessible -// when the existential is "opened" in some context. -SYNTAX_CLASS(ExtractExistentialType, Type) -RAW( - DeclRef declRef; - - virtual String ToString() override; - virtual bool EqualsImpl(Type * type) override; - virtual int GetHashCode() override; - virtual RefPtr CreateCanonicalType() override; - virtual RefPtr SubstituteImpl(SubstitutionSet subst, int* ioDiff) override; -) -END_SYNTAX_CLASS() - - /// A tagged union of zero or more other types. -SYNTAX_CLASS(TaggedUnionType, Type) -RAW( - /// The distinct "cases" the tagged union can store. - /// - /// For each type in this array, the array index is the - /// tag value for that case. - /// - List> caseTypes; - - virtual String ToString() override; - virtual bool EqualsImpl(Type * type) override; - virtual int GetHashCode() override; - virtual RefPtr CreateCanonicalType() override; - virtual RefPtr SubstituteImpl(SubstitutionSet subst, int* ioDiff) override; -) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(ExistentialSpecializedType, Type) -RAW( - RefPtr baseType; - ExistentialTypeSlots slots; - - virtual String ToString() override; - virtual bool EqualsImpl(Type * type) override; - virtual int GetHashCode() override; - virtual RefPtr CreateCanonicalType() override; - virtual RefPtr SubstituteImpl(SubstitutionSet subst, int* ioDiff) override; -) -END_SYNTAX_CLASS() \ No newline at end of file diff --git a/source/slang/type-layout.cpp b/source/slang/type-layout.cpp deleted file mode 100644 index 92f7d6af3..000000000 --- a/source/slang/type-layout.cpp +++ /dev/null @@ -1,3209 +0,0 @@ -// TypeLayout.cpp -#include "type-layout.h" - -#include "syntax.h" - -#include - -namespace Slang { - -size_t RoundToAlignment(size_t offset, size_t alignment) -{ - size_t remainder = offset % alignment; - if (remainder == 0) - return offset; - else - return offset + (alignment - remainder); -} - -LayoutSize RoundToAlignment(LayoutSize offset, size_t alignment) -{ - // An infinite size is assumed to be maximally aligned. - if(offset.isInfinite()) - return LayoutSize::infinite(); - - return RoundToAlignment(offset.getFiniteValue(), alignment); -} - -static size_t RoundUpToPowerOfTwo( size_t value ) -{ - // TODO(tfoley): I know this isn't a fast approach - size_t result = 1; - while (result < value) - result *= 2; - return result; -} - -// - -struct DefaultLayoutRulesImpl : SimpleLayoutRulesImpl -{ - // Get size and alignment for a single value of base type. - SimpleLayoutInfo GetScalarLayout(BaseType baseType) override - { - switch (baseType) - { - case BaseType::Void: return SimpleLayoutInfo(); - - // Note: By convention, a `bool` in a constant buffer is stored as an `int. - // This default may eventually change, at which point this logic will need - // to be updated. - // - // TODO: We should probably warn in this case, since storing a `bool` in - // a constant buffer seems like a Bad Idea anyway. - // - case BaseType::Bool: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 4, 4 ); - - - case BaseType::Int8: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 1,1); - case BaseType::Int16: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 2,2); - case BaseType::Int: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 4,4); - case BaseType::Int64: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 8,8); - - case BaseType::UInt8: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 1,1); - case BaseType::UInt16: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 2,2); - case BaseType::UInt: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 4,4); - case BaseType::UInt64: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 8,8); - - case BaseType::Half: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 2,2); - case BaseType::Float: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 4,4); - case BaseType::Double: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 8,8); - - default: - SLANG_UNEXPECTED("uhandled scalar type"); - UNREACHABLE_RETURN(SimpleLayoutInfo( LayoutResourceKind::Uniform, 0, 1 )); - } - } - - SimpleArrayLayoutInfo GetArrayLayout( SimpleLayoutInfo elementInfo, LayoutSize elementCount) override - { - SLANG_RELEASE_ASSERT(elementInfo.size.isFinite()); - auto elementSize = elementInfo.size.getFiniteValue(); - auto elementAlignment = elementInfo.alignment; - auto elementStride = RoundToAlignment(elementSize, elementAlignment); - - // An array with no elements will have zero size. - // - LayoutSize arraySize = 0; - // - // Any array with a non-zero number of elements will need - // to have space for N elements of size `elementSize`, with - // the constraints that there must be `elementStride` bytes - // between consecutive elements. - // - if( elementCount > 0 ) - { - // We can think of this as either allocating (N-1) - // chunks of size `elementStride` (for most of the elements) - // and then one final chunk of size `elementSize` for - // the last element, or equivalently as allocating - // N chunks of size `elementStride` and then "giving back" - // the final `elementStride - elementSize` bytes. - // - arraySize = (elementStride * (elementCount-1)) + elementSize; - } - - SimpleArrayLayoutInfo arrayInfo; - arrayInfo.kind = elementInfo.kind; - arrayInfo.size = arraySize; - arrayInfo.alignment = elementAlignment; - arrayInfo.elementStride = elementStride; - return arrayInfo; - } - - SimpleLayoutInfo GetVectorLayout(SimpleLayoutInfo elementInfo, size_t elementCount) override - { - SimpleLayoutInfo vectorInfo; - vectorInfo.kind = elementInfo.kind; - vectorInfo.size = elementInfo.size * elementCount; - vectorInfo.alignment = elementInfo.alignment; - return vectorInfo; - } - - SimpleArrayLayoutInfo GetMatrixLayout(SimpleLayoutInfo elementInfo, size_t rowCount, size_t columnCount) override - { - // The default behavior here is to lay out a matrix - // as an array of row vectors (that is row-major). - // - // In practice, the code that calls `GetMatrixLayout` will - // potentially transpose the row/column counts in order - // to get layouts with a different convention. - // - return GetArrayLayout( - GetVectorLayout(elementInfo, columnCount), - rowCount); - } - - UniformLayoutInfo BeginStructLayout() override - { - UniformLayoutInfo structInfo(0, 1); - return structInfo; - } - - LayoutSize AddStructField(UniformLayoutInfo* ioStructInfo, UniformLayoutInfo fieldInfo) override - { - // Skip zero-size fields - if(fieldInfo.size == 0) - return ioStructInfo->size; - - // A struct type must be at least as aligned as its most-aligned field. - ioStructInfo->alignment = std::max(ioStructInfo->alignment, fieldInfo.alignment); - - // The new field will be added to the end of the struct. - auto fieldBaseOffset = ioStructInfo->size; - - // We need to ensure that the offset for the field will respect its alignment - auto fieldOffset = RoundToAlignment(fieldBaseOffset, fieldInfo.alignment); - - // The size of the struct must be adjusted to cover the bytes consumed - // by this field. - ioStructInfo->size = fieldOffset + fieldInfo.size; - - return fieldOffset; - } - - - void EndStructLayout(UniformLayoutInfo* ioStructInfo) override - { - SLANG_UNUSED(ioStructInfo); - - // Note: A traditional C layout algorithm would adjust the size - // of a struct type so that it is a multiple of the alignment. - // This is a parsimonious design choice because it means that - // `sizeof(T)` can both be used when copying/allocating a single - // value of type `T` or an array of N values, without having to - // consider more details. - // - // Of course the choice also has down-sides in that wrapping things - // into a `struct` can affect layout in ways that waste space. E.g., - // the following two cases don't lay out the same: - // - // struct S0 { double d; float f; float g; }; - // - // struct X { double d; float f; } - // struct S1 { X x; float g; } - // - // Even though `S0::g` and `S1::g` have the same amount of useful - // data in front of them, they will not land at the same offset, - // and the resulting struct sizes will differ (`sizeof(S0)` will be - // 16 while `sizeof(S1)` will be 24). - // - // Slang doesn't get to be opinionated about this stuff because - // there is already precedent in both HLSL and GLSL for types - // that have a size that is not rounded up to their alignment. - // - // Our default layout rules won't implement the C-like policy, - // and instead it will be injected in the concrete implementations - // that require it. - } -}; - - /// Common behavior for GLSL-family layout. -struct GLSLBaseLayoutRulesImpl : DefaultLayoutRulesImpl -{ - typedef DefaultLayoutRulesImpl Super; - - SimpleLayoutInfo GetVectorLayout(SimpleLayoutInfo elementInfo, size_t elementCount) override - { - // The `std140` and `std430` rules require vectors to be aligned to the next power of - // two up from their size (so a `float2` is 8-byte aligned, and a `float3` is - // 16-byte aligned). - // - // Note that in this case we have a type layout where the size is *not* a multiple - // of the alignment, so it should be possible to pack a scalar after a `float3`. - // - SLANG_RELEASE_ASSERT(elementInfo.kind == LayoutResourceKind::Uniform); - SLANG_RELEASE_ASSERT(elementInfo.size.isFinite()); - - auto size = elementInfo.size.getFiniteValue() * elementCount; - SimpleLayoutInfo vectorInfo( - LayoutResourceKind::Uniform, - size, - RoundUpToPowerOfTwo(size)); - return vectorInfo; - } - - SimpleArrayLayoutInfo GetArrayLayout( SimpleLayoutInfo elementInfo, LayoutSize elementCount) override - { - // The size of an array must be rounded up to be a multiple of its alignment. - // - auto info = Super::GetArrayLayout(elementInfo, elementCount); - info.size = RoundToAlignment(info.size, info.alignment); - return info; - } - - void EndStructLayout(UniformLayoutInfo* ioStructInfo) override - { - // The size of a `struct` must be rounded up to be a multiple of its alignment. - // - ioStructInfo->size = RoundToAlignment(ioStructInfo->size, ioStructInfo->alignment); - } -}; - - /// The GLSL `std430` layout rules. -struct Std430LayoutRulesImpl : GLSLBaseLayoutRulesImpl -{ - // These rules don't actually need any differences from our - // base/common GLSL layout rules. -}; - - /// The GLSL `std430` layout rules. -struct Std140LayoutRulesImpl : GLSLBaseLayoutRulesImpl -{ - typedef GLSLBaseLayoutRulesImpl Super; - - SimpleArrayLayoutInfo GetArrayLayout(SimpleLayoutInfo elementInfo, LayoutSize elementCount) override - { - // The `std140` rules require that array elements - // be aligned on 16-byte boundaries. - // - if(elementInfo.kind == LayoutResourceKind::Uniform) - { - if (elementInfo.alignment < 16) - elementInfo.alignment = 16; - } - return Super::GetArrayLayout(elementInfo, elementCount); - } - - UniformLayoutInfo BeginStructLayout() override - { - // The `std140` rules require that a `struct` type - // be at least 16-byte aligned. - // - return UniformLayoutInfo(0, 16); - } -}; - -struct HLSLConstantBufferLayoutRulesImpl : DefaultLayoutRulesImpl -{ - typedef DefaultLayoutRulesImpl Super; - - // Similar to GLSL `std140` rules, an HLSL constant buffer requires that - // `struct` and array types have 16-byte alignement. - // - // Unlike GLSL `std140`, the overall size of an array or `struct` type - // is *not* rounded up to the alignment, so it is possible for later - // fields to sneak into the "tail space" left behind by a preceding - // structure or array. E.g., in this example: - // - // struct S { float3 a[2]; float b; }; - // - // The stride of the array `a` is 16 bytes per element, but the size - // of `a` will only be 28 bytes (not 32), so that `b` can fit into - // the space after the last array element and the overall structure - // will have a size of 32 bytes. - - SimpleArrayLayoutInfo GetArrayLayout(SimpleLayoutInfo elementInfo, LayoutSize elementCount) override - { - if(elementInfo.kind == LayoutResourceKind::Uniform) - { - if (elementInfo.alignment < 16) - elementInfo.alignment = 16; - } - return Super::GetArrayLayout(elementInfo, elementCount); - } - - UniformLayoutInfo BeginStructLayout() override - { - return UniformLayoutInfo(0, 16); - } - - // HLSL layout rules do *not* impose additional alignment - // constraints on vectors (e.g., all of `float`, `float2`, - // `float3`, and `float4` have 4-byte alignment), but instead - // they impose a rule that any `struct` field must not - // "straddle" a 16-byte boundary. - // - // This has the effect of making it *look* like `float4` - // values have 16-byte alignment in practice, but the - // effects on `float2` and `float3` are more nuanched and - // lead to different result than the GLSL rules. - // - LayoutSize AddStructField(UniformLayoutInfo* ioStructInfo, UniformLayoutInfo fieldInfo) override - { - // Skip zero-size fields - if(fieldInfo.size == 0) - return ioStructInfo->size; - - ioStructInfo->alignment = std::max(ioStructInfo->alignment, fieldInfo.alignment); - ioStructInfo->size = RoundToAlignment(ioStructInfo->size, fieldInfo.alignment); - - LayoutSize fieldOffset = ioStructInfo->size; - LayoutSize fieldSize = fieldInfo.size; - - // Would this field cross a 16-byte boundary? - auto registerSize = 16; - auto startRegister = fieldOffset / registerSize; - auto endRegister = (fieldOffset + fieldSize - 1) / registerSize; - if (startRegister != endRegister) - { - ioStructInfo->size = RoundToAlignment(ioStructInfo->size, size_t(registerSize)); - fieldOffset = ioStructInfo->size; - } - - ioStructInfo->size += fieldInfo.size; - return fieldOffset; - } -}; - -struct HLSLStructuredBufferLayoutRulesImpl : DefaultLayoutRulesImpl -{ - // HLSL structured buffers drop the restrictions added for constant buffers, - // but retain the rules around not adjusting the size of an array or - // structure to its alignment. In this way they should match our - // default layout rules. -}; - -struct DefaultVaryingLayoutRulesImpl : DefaultLayoutRulesImpl -{ - LayoutResourceKind kind; - - DefaultVaryingLayoutRulesImpl(LayoutResourceKind kind) - : kind(kind) - {} - - - // hook to allow differentiating for input/output - virtual LayoutResourceKind getKind() - { - return kind; - } - - SimpleLayoutInfo GetScalarLayout(BaseType) override - { - // Assume that all scalars take up one "slot" - return SimpleLayoutInfo( - getKind(), - 1); - } - - SimpleLayoutInfo GetVectorLayout(SimpleLayoutInfo, size_t) override - { - // Vectors take up one slot by default - // - // TODO: some platforms may decide that vectors of `double` need - // special handling - return SimpleLayoutInfo( - getKind(), - 1); - } -}; - -struct GLSLVaryingLayoutRulesImpl : DefaultVaryingLayoutRulesImpl -{ - GLSLVaryingLayoutRulesImpl(LayoutResourceKind kind) - : DefaultVaryingLayoutRulesImpl(kind) - {} -}; - -struct HLSLVaryingLayoutRulesImpl : DefaultVaryingLayoutRulesImpl -{ - HLSLVaryingLayoutRulesImpl(LayoutResourceKind kind) - : DefaultVaryingLayoutRulesImpl(kind) - {} -}; - -// - -struct GLSLSpecializationConstantLayoutRulesImpl : DefaultLayoutRulesImpl -{ - LayoutResourceKind getKind() - { - return LayoutResourceKind::SpecializationConstant; - } - - SimpleLayoutInfo GetScalarLayout(BaseType) override - { - // Assume that all scalars take up one "slot" - return SimpleLayoutInfo( - getKind(), - 1); - } - - SimpleLayoutInfo GetVectorLayout(SimpleLayoutInfo, size_t elementCount) override - { - // GLSL doesn't support vectors of specialization constants, - // but we will assume that, if supported, they would use one slot per element. - return SimpleLayoutInfo( - getKind(), - elementCount); - } -}; - -GLSLSpecializationConstantLayoutRulesImpl kGLSLSpecializationConstantLayoutRulesImpl; - -// - -struct GLSLObjectLayoutRulesImpl : ObjectLayoutRulesImpl -{ - virtual SimpleLayoutInfo GetObjectLayout(ShaderParameterKind) override - { - // In Vulkan GLSL, pretty much every object is just a descriptor-table slot. - // We can refine this method once we support a case where this isn't true. - return SimpleLayoutInfo(LayoutResourceKind::DescriptorTableSlot, 1); - } -}; -GLSLObjectLayoutRulesImpl kGLSLObjectLayoutRulesImpl; - -struct GLSLPushConstantBufferObjectLayoutRulesImpl : GLSLObjectLayoutRulesImpl -{ - virtual SimpleLayoutInfo GetObjectLayout(ShaderParameterKind /*kind*/) override - { - // Special-case the layout for a constant-buffer, because we don't - // want it to allocate a descriptor-table slot - return SimpleLayoutInfo(LayoutResourceKind::PushConstantBuffer, 1); - } -}; -GLSLPushConstantBufferObjectLayoutRulesImpl kGLSLPushConstantBufferObjectLayoutRulesImpl_; - -struct GLSLShaderRecordConstantBufferObjectLayoutRulesImpl : GLSLObjectLayoutRulesImpl -{ - virtual SimpleLayoutInfo GetObjectLayout(ShaderParameterKind /*kind*/) override - { - // Special-case the layout for a constant-buffer, because we don't - // want it to allocate a descriptor-table slot - return SimpleLayoutInfo(LayoutResourceKind::ShaderRecord, 1); - } -}; -GLSLShaderRecordConstantBufferObjectLayoutRulesImpl kGLSLShaderRecordConstantBufferObjectLayoutRulesImpl_; - -struct HLSLObjectLayoutRulesImpl : ObjectLayoutRulesImpl -{ - virtual SimpleLayoutInfo GetObjectLayout(ShaderParameterKind kind) override - { - switch( kind ) - { - case ShaderParameterKind::ConstantBuffer: - return SimpleLayoutInfo(LayoutResourceKind::ConstantBuffer, 1); - - case ShaderParameterKind::TextureUniformBuffer: - case ShaderParameterKind::StructuredBuffer: - case ShaderParameterKind::RawBuffer: - case ShaderParameterKind::Buffer: - case ShaderParameterKind::Texture: - return SimpleLayoutInfo(LayoutResourceKind::ShaderResource, 1); - - case ShaderParameterKind::MutableStructuredBuffer: - case ShaderParameterKind::MutableRawBuffer: - case ShaderParameterKind::MutableBuffer: - case ShaderParameterKind::MutableTexture: - return SimpleLayoutInfo(LayoutResourceKind::UnorderedAccess, 1); - - case ShaderParameterKind::SamplerState: - return SimpleLayoutInfo(LayoutResourceKind::SamplerState, 1); - - case ShaderParameterKind::TextureSampler: - case ShaderParameterKind::MutableTextureSampler: - case ShaderParameterKind::InputRenderTarget: - // TODO: how to handle these? - default: - SLANG_UNEXPECTED("unhandled shader parameter kind"); - UNREACHABLE_RETURN(SimpleLayoutInfo()); - } - } -}; -HLSLObjectLayoutRulesImpl kHLSLObjectLayoutRulesImpl; - -// HACK: Treating ray-tracing input/output as if it was another -// case of varying input/output when it really needs to be -// based on byte storage/layout. -// -struct GLSLRayTracingLayoutRulesImpl : DefaultVaryingLayoutRulesImpl -{ - GLSLRayTracingLayoutRulesImpl(LayoutResourceKind kind) - : DefaultVaryingLayoutRulesImpl(kind) - {} -}; -struct HLSLRayTracingLayoutRulesImpl : DefaultVaryingLayoutRulesImpl -{ - HLSLRayTracingLayoutRulesImpl(LayoutResourceKind kind) - : DefaultVaryingLayoutRulesImpl(kind) - {} -}; - -Std140LayoutRulesImpl kStd140LayoutRulesImpl; -Std430LayoutRulesImpl kStd430LayoutRulesImpl; -HLSLConstantBufferLayoutRulesImpl kHLSLConstantBufferLayoutRulesImpl; -HLSLStructuredBufferLayoutRulesImpl kHLSLStructuredBufferLayoutRulesImpl; - -GLSLVaryingLayoutRulesImpl kGLSLVaryingInputLayoutRulesImpl(LayoutResourceKind::VertexInput); -GLSLVaryingLayoutRulesImpl kGLSLVaryingOutputLayoutRulesImpl(LayoutResourceKind::FragmentOutput); - -GLSLRayTracingLayoutRulesImpl kGLSLRayPayloadParameterLayoutRulesImpl(LayoutResourceKind::RayPayload); -GLSLRayTracingLayoutRulesImpl kGLSLCallablePayloadParameterLayoutRulesImpl(LayoutResourceKind::CallablePayload); -GLSLRayTracingLayoutRulesImpl kGLSLHitAttributesParameterLayoutRulesImpl(LayoutResourceKind::HitAttributes); - -HLSLVaryingLayoutRulesImpl kHLSLVaryingInputLayoutRulesImpl(LayoutResourceKind::VertexInput); -HLSLVaryingLayoutRulesImpl kHLSLVaryingOutputLayoutRulesImpl(LayoutResourceKind::FragmentOutput); - -HLSLRayTracingLayoutRulesImpl kHLSLRayPayloadParameterLayoutRulesImpl(LayoutResourceKind::RayPayload); -HLSLRayTracingLayoutRulesImpl kHLSLCallablePayloadParameterLayoutRulesImpl(LayoutResourceKind::CallablePayload); -HLSLRayTracingLayoutRulesImpl kHLSLHitAttributesParameterLayoutRulesImpl(LayoutResourceKind::HitAttributes); - -// - -struct GLSLLayoutRulesFamilyImpl : LayoutRulesFamilyImpl -{ - virtual LayoutRulesImpl* getConstantBufferRules() override; - virtual LayoutRulesImpl* getPushConstantBufferRules() override; - virtual LayoutRulesImpl* getTextureBufferRules() override; - virtual LayoutRulesImpl* getVaryingInputRules() override; - virtual LayoutRulesImpl* getVaryingOutputRules() override; - virtual LayoutRulesImpl* getSpecializationConstantRules() override; - virtual LayoutRulesImpl* getShaderStorageBufferRules() override; - virtual LayoutRulesImpl* getParameterBlockRules() override; - - LayoutRulesImpl* getRayPayloadParameterRules() override; - LayoutRulesImpl* getCallablePayloadParameterRules() override; - LayoutRulesImpl* getHitAttributesParameterRules() override; - - LayoutRulesImpl* getShaderRecordConstantBufferRules() override; -}; - -struct HLSLLayoutRulesFamilyImpl : LayoutRulesFamilyImpl -{ - virtual LayoutRulesImpl* getConstantBufferRules() override; - virtual LayoutRulesImpl* getPushConstantBufferRules() override; - virtual LayoutRulesImpl* getTextureBufferRules() override; - virtual LayoutRulesImpl* getVaryingInputRules() override; - virtual LayoutRulesImpl* getVaryingOutputRules() override; - virtual LayoutRulesImpl* getSpecializationConstantRules() override; - virtual LayoutRulesImpl* getShaderStorageBufferRules() override; - virtual LayoutRulesImpl* getParameterBlockRules() override; - - LayoutRulesImpl* getRayPayloadParameterRules() override; - LayoutRulesImpl* getCallablePayloadParameterRules() override; - LayoutRulesImpl* getHitAttributesParameterRules() override; - - LayoutRulesImpl* getShaderRecordConstantBufferRules() override; -}; - -GLSLLayoutRulesFamilyImpl kGLSLLayoutRulesFamilyImpl; -HLSLLayoutRulesFamilyImpl kHLSLLayoutRulesFamilyImpl; - - -// GLSL cases - -LayoutRulesImpl kStd140LayoutRulesImpl_ = { - &kGLSLLayoutRulesFamilyImpl, &kStd140LayoutRulesImpl, &kGLSLObjectLayoutRulesImpl, -}; - -LayoutRulesImpl kStd430LayoutRulesImpl_ = { - &kGLSLLayoutRulesFamilyImpl, &kStd430LayoutRulesImpl, &kGLSLObjectLayoutRulesImpl, -}; - -LayoutRulesImpl kGLSLPushConstantLayoutRulesImpl_ = { - &kGLSLLayoutRulesFamilyImpl, &kStd430LayoutRulesImpl, &kGLSLPushConstantBufferObjectLayoutRulesImpl_, -}; - -LayoutRulesImpl kGLSLShaderRecordLayoutRulesImpl_ = { - &kGLSLLayoutRulesFamilyImpl, &kStd430LayoutRulesImpl, &kGLSLShaderRecordConstantBufferObjectLayoutRulesImpl_, -}; - -LayoutRulesImpl kGLSLVaryingInputLayoutRulesImpl_ = { - &kGLSLLayoutRulesFamilyImpl, &kGLSLVaryingInputLayoutRulesImpl, &kGLSLObjectLayoutRulesImpl, -}; - -LayoutRulesImpl kGLSLVaryingOutputLayoutRulesImpl_ = { - &kGLSLLayoutRulesFamilyImpl, &kGLSLVaryingOutputLayoutRulesImpl, &kGLSLObjectLayoutRulesImpl, -}; - -LayoutRulesImpl kGLSLSpecializationConstantLayoutRulesImpl_ = { - &kGLSLLayoutRulesFamilyImpl, &kGLSLSpecializationConstantLayoutRulesImpl, &kGLSLObjectLayoutRulesImpl, -}; - -LayoutRulesImpl kGLSLRayPayloadParameterLayoutRulesImpl_ = { - &kGLSLLayoutRulesFamilyImpl, &kGLSLRayPayloadParameterLayoutRulesImpl, &kGLSLObjectLayoutRulesImpl, -}; - -LayoutRulesImpl kGLSLCallablePayloadParameterLayoutRulesImpl_ = { - &kGLSLLayoutRulesFamilyImpl, &kGLSLCallablePayloadParameterLayoutRulesImpl, &kGLSLObjectLayoutRulesImpl, -}; - -LayoutRulesImpl kGLSLHitAttributesParameterLayoutRulesImpl_ = { - &kGLSLLayoutRulesFamilyImpl, &kGLSLHitAttributesParameterLayoutRulesImpl, &kGLSLObjectLayoutRulesImpl, -}; - -// HLSL cases - -LayoutRulesImpl kHLSLConstantBufferLayoutRulesImpl_ = { - &kHLSLLayoutRulesFamilyImpl, &kHLSLConstantBufferLayoutRulesImpl, &kHLSLObjectLayoutRulesImpl, -}; - -LayoutRulesImpl kHLSLStructuredBufferLayoutRulesImpl_ = { - &kHLSLLayoutRulesFamilyImpl, &kHLSLStructuredBufferLayoutRulesImpl, &kHLSLObjectLayoutRulesImpl, -}; - -LayoutRulesImpl kHLSLVaryingInputLayoutRulesImpl_ = { - &kHLSLLayoutRulesFamilyImpl, &kHLSLVaryingInputLayoutRulesImpl, &kHLSLObjectLayoutRulesImpl, -}; - -LayoutRulesImpl kHLSLVaryingOutputLayoutRulesImpl_ = { - &kHLSLLayoutRulesFamilyImpl, &kHLSLVaryingOutputLayoutRulesImpl, &kHLSLObjectLayoutRulesImpl, -}; - -LayoutRulesImpl kHLSLRayPayloadParameterLayoutRulesImpl_ = { - &kHLSLLayoutRulesFamilyImpl, &kHLSLRayPayloadParameterLayoutRulesImpl, &kHLSLObjectLayoutRulesImpl, -}; - -LayoutRulesImpl kHLSLCallablePayloadParameterLayoutRulesImpl_ = { - &kHLSLLayoutRulesFamilyImpl, &kHLSLCallablePayloadParameterLayoutRulesImpl, &kHLSLObjectLayoutRulesImpl, -}; - -LayoutRulesImpl kHLSLHitAttributesParameterLayoutRulesImpl_ = { - &kHLSLLayoutRulesFamilyImpl, &kHLSLHitAttributesParameterLayoutRulesImpl, &kHLSLObjectLayoutRulesImpl, -}; - -// - -LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getConstantBufferRules() -{ - return &kStd140LayoutRulesImpl_; -} - -LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getParameterBlockRules() -{ - // TODO: actually pick something appropriate - return &kStd140LayoutRulesImpl_; -} - -LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getPushConstantBufferRules() -{ - return &kGLSLPushConstantLayoutRulesImpl_; -} - -LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getShaderRecordConstantBufferRules() -{ - return &kGLSLShaderRecordLayoutRulesImpl_; -} - -LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getTextureBufferRules() -{ - return nullptr; -} - -LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getVaryingInputRules() -{ - return &kGLSLVaryingInputLayoutRulesImpl_; -} - -LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getVaryingOutputRules() -{ - return &kGLSLVaryingOutputLayoutRulesImpl_; -} - -LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getSpecializationConstantRules() -{ - return &kGLSLSpecializationConstantLayoutRulesImpl_; -} - -LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getShaderStorageBufferRules() -{ - return &kStd430LayoutRulesImpl_; -} - -LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getRayPayloadParameterRules() -{ - return &kGLSLRayPayloadParameterLayoutRulesImpl_; -} - -LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getCallablePayloadParameterRules() -{ - return &kGLSLCallablePayloadParameterLayoutRulesImpl_; -} - -LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getHitAttributesParameterRules() -{ - return &kGLSLHitAttributesParameterLayoutRulesImpl_; -} - -// - -LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getConstantBufferRules() -{ - return &kHLSLConstantBufferLayoutRulesImpl_; -} - -LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getParameterBlockRules() -{ - // TODO: actually pick something appropriate... - return &kHLSLConstantBufferLayoutRulesImpl_; -} - - -LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getPushConstantBufferRules() -{ - return &kHLSLConstantBufferLayoutRulesImpl_; -} - -LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getShaderRecordConstantBufferRules() -{ - return &kHLSLConstantBufferLayoutRulesImpl_; -} - -LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getTextureBufferRules() -{ - return nullptr; -} - -LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getVaryingInputRules() -{ - return &kHLSLVaryingInputLayoutRulesImpl_; -} - -LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getVaryingOutputRules() -{ - return &kHLSLVaryingOutputLayoutRulesImpl_; -} - -LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getSpecializationConstantRules() -{ - return nullptr; -} - -LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getShaderStorageBufferRules() -{ - return nullptr; -} - -LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getRayPayloadParameterRules() -{ - return &kHLSLRayPayloadParameterLayoutRulesImpl_; -} - -LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getCallablePayloadParameterRules() -{ - return &kHLSLCallablePayloadParameterLayoutRulesImpl_; -} - -LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getHitAttributesParameterRules() -{ - return &kHLSLHitAttributesParameterLayoutRulesImpl_; -} - - - -// - -LayoutRulesImpl* GetLayoutRulesImpl(LayoutRule rule) -{ - switch (rule) - { - case LayoutRule::Std140: return &kStd140LayoutRulesImpl_; - case LayoutRule::Std430: return &kStd430LayoutRulesImpl_; - case LayoutRule::HLSLConstantBuffer: return &kHLSLConstantBufferLayoutRulesImpl_; - case LayoutRule::HLSLStructuredBuffer: return &kHLSLStructuredBufferLayoutRulesImpl_; - default: - return nullptr; - } -} - -LayoutRulesFamilyImpl* getDefaultLayoutRulesFamilyForTarget(TargetRequest* targetReq) -{ - switch (targetReq->getTarget()) - { - case CodeGenTarget::HLSL: - case CodeGenTarget::DXBytecode: - case CodeGenTarget::DXBytecodeAssembly: - case CodeGenTarget::DXIL: - case CodeGenTarget::DXILAssembly: - return &kHLSLLayoutRulesFamilyImpl; - - case CodeGenTarget::GLSL: - case CodeGenTarget::SPIRV: - case CodeGenTarget::SPIRVAssembly: - return &kGLSLLayoutRulesFamilyImpl; - - - case CodeGenTarget::CPPSource: - case CodeGenTarget::CSource: - { - // We just need to decide here what style of layout is appropriate, in terms of memory - // and binding. That in terms of the actual binding that will be injected into functions - // in the form of a BindContext. For now we'll go with HLSL layout - - // that we may want to rethink that with the use of arrays and binding VK style binding might be - // more appropriate in some ways. - - return &kHLSLLayoutRulesFamilyImpl; - } - - default: - return nullptr; - } -} - -TypeLayoutContext getInitialLayoutContextForTarget(TargetRequest* targetReq, ProgramLayout* programLayout) -{ - LayoutRulesFamilyImpl* rulesFamily = getDefaultLayoutRulesFamilyForTarget(targetReq); - - TypeLayoutContext context; - context.targetReq = targetReq; - context.programLayout = programLayout; - context.rules = nullptr; - context.matrixLayoutMode = targetReq->getDefaultMatrixLayoutMode(); - - if( rulesFamily ) - { - context.rules = rulesFamily->getConstantBufferRules(); - } - - return context; -} - - -static LayoutSize GetElementCount(RefPtr val) -{ - // Lack of a size indicates an unbounded array. - if(!val) - return LayoutSize::infinite(); - - if (auto constantVal = as(val)) - { - return LayoutSize(LayoutSize::RawValue(constantVal->value)); - } - else if( auto varRefVal = as(val) ) - { - // TODO: We want to treat the case where the number of - // elements in an array depends on a generic parameter - // much like the case where the number of elements is - // unbounded, *but* we can't just blindly do that because - // an API might disallow unbounded arrays in various - // cases where a generic bound might work (because - // any concrete specialization will have a finite bound...) - // - return 0; - } - SLANG_UNEXPECTED("unhandled integer literal kind"); - UNREACHABLE_RETURN(LayoutSize(0)); -} - -bool IsResourceKind(LayoutResourceKind kind) -{ - switch (kind) - { - case LayoutResourceKind::None: - case LayoutResourceKind::Uniform: - return false; - - default: - return true; - } - -} - - /// Create a type layout for a type that has simple layout needs. - /// - /// This handles any type that can express its layout in `SimpleLayoutInfo`, - /// and that only needs a `TypeLayout` and not a refined subclass. - /// -static TypeLayoutResult createSimpleTypeLayout( - SimpleLayoutInfo info, - RefPtr type, - LayoutRulesImpl* rules) -{ - RefPtr typeLayout = new TypeLayout(); - - typeLayout->type = type; - typeLayout->rules = rules; - - typeLayout->uniformAlignment = info.alignment; - - typeLayout->addResourceUsage(info.kind, info.size); - - return TypeLayoutResult(typeLayout, info); -} - -static SimpleLayoutInfo getParameterGroupLayoutInfo( - RefPtr type, - LayoutRulesImpl* rules) -{ - if( as(type) ) - { - return rules->GetObjectLayout(ShaderParameterKind::ConstantBuffer); - } - else if( as(type) ) - { - return rules->GetObjectLayout(ShaderParameterKind::TextureUniformBuffer); - } - else if( as(type) ) - { - return rules->GetObjectLayout(ShaderParameterKind::ShaderStorageBuffer); - } - else if (as(type)) - { - // Note: we default to consuming zero register spces here, because - // a parameter block might not contain anything (or all it contains - // is other blocks), and so it won't get a space allocated. - // - // This choice *also* means that in the case where we don't actually - // want to allocate register spaces to blocks at all, we haven't - // committed to that choice here. - // - // TODO: wouldn't it be any different to just allocate this - // as an empty `SimpleLayoutInfo` of any other kind? - return SimpleLayoutInfo(LayoutResourceKind::RegisterSpace, 0); - } - - // TODO: the vertex-input and fragment-output cases should - // only actually apply when we are at the appropriate stage in - // the pipeline... - else if( as(type) ) - { - return SimpleLayoutInfo(LayoutResourceKind::VertexInput, 0); - } - else if( as(type) ) - { - return SimpleLayoutInfo(LayoutResourceKind::FragmentOutput, 0); - } - else - { - SLANG_UNEXPECTED("unhandled parameter block type"); - UNREACHABLE_RETURN(SimpleLayoutInfo()); - } -} - -static bool isOpenGLTarget(TargetRequest*) -{ - // We aren't officially supporting OpenGL right now - return false; -} - -bool isD3DTarget(TargetRequest* targetReq) -{ - switch( targetReq->getTarget() ) - { - case CodeGenTarget::HLSL: - case CodeGenTarget::DXBytecode: - case CodeGenTarget::DXBytecodeAssembly: - case CodeGenTarget::DXIL: - case CodeGenTarget::DXILAssembly: - return true; - - default: - return false; - } -} - -bool isKhronosTarget(TargetRequest* targetReq) -{ - switch( targetReq->getTarget() ) - { - default: - return false; - - case CodeGenTarget::GLSL: - case CodeGenTarget::SPIRV: - case CodeGenTarget::SPIRVAssembly: - return true; - } -} - -static bool isD3D11Target(TargetRequest*) -{ - // We aren't officially supporting D3D11 right now - return false; -} - -static bool isD3D12Target(TargetRequest* targetReq) -{ - // We are currently only officially supporting D3D12 - return isD3DTarget(targetReq); -} - - -static bool isSM5OrEarlier(TargetRequest* targetReq) -{ - if(!isD3DTarget(targetReq)) - return false; - - auto profile = targetReq->getTargetProfile(); - - if(profile.getFamily() == ProfileFamily::DX) - { - if(profile.GetVersion() <= ProfileVersion::DX_5_0) - return true; - } - - return false; -} - -static bool isSM5_1OrLater(TargetRequest* targetReq) -{ - if(!isD3DTarget(targetReq)) - return false; - - auto profile = targetReq->getTargetProfile(); - - if(profile.getFamily() == ProfileFamily::DX) - { - if(profile.GetVersion() >= ProfileVersion::DX_5_1) - return true; - } - - return false; -} - -static bool isVulkanTarget(TargetRequest* targetReq) -{ - // For right now, any Khronos-related target is assumed - // to be a Vulkan target. - return isKhronosTarget(targetReq); -} - -static bool shouldAllocateRegisterSpaceForParameterBlock( - TypeLayoutContext const& context) -{ - auto targetReq = context.targetReq; - - // We *never* want to use register spaces/sets under - // OpenGL, D3D11, or for Shader Model 5.0 or earlier. - if(isOpenGLTarget(targetReq) || isD3D11Target(targetReq) || isSM5OrEarlier(targetReq)) - return false; - - // If we know that we are targetting Vulkan, then - // the only way to effectively use parameter blocks - // is by using descriptor sets. - if(isVulkanTarget(targetReq)) - return true; - - // If none of the above passed, then it seems like we - // are generating code for D3D12, and using SM5.1 or later. - // We will use a register space for parameter blocks *if* - // the target options tell us to: - if( isD3D12Target(targetReq) && isSM5_1OrLater(targetReq) ) - { - return true; - } - - return false; -} - -// Given an existing type layout `oldTypeLayout`, apply offsets -// to any contained fields based on the resource infos in `offsetVarLayout`. -RefPtr applyOffsetToTypeLayout( - RefPtr oldTypeLayout, - RefPtr offsetVarLayout) -{ - // There is no need to apply offsets if the old type and the offset - // don't share any resource infos in common. - bool anyHit = false; - for (auto oldResInfo : oldTypeLayout->resourceInfos) - { - if (auto offsetResInfo = offsetVarLayout->FindResourceInfo(oldResInfo.kind)) - { - anyHit = true; - break; - } - } - - if (!anyHit) - return oldTypeLayout; - - RefPtr newTypeLayout; - if (auto oldStructTypeLayout = oldTypeLayout.as()) - { - RefPtr newStructTypeLayout = new StructTypeLayout(); - newStructTypeLayout->type = oldStructTypeLayout->type; - newStructTypeLayout->uniformAlignment = oldStructTypeLayout->uniformAlignment; - - Dictionary mapOldFieldToNew; - - for (auto oldField : oldStructTypeLayout->fields) - { - RefPtr newField = new VarLayout(); - newField->varDecl = oldField->varDecl; - newField->typeLayout = oldField->typeLayout; - newField->flags = oldField->flags; - newField->semanticIndex = oldField->semanticIndex; - newField->semanticName = oldField->semanticName; - newField->stage = oldField->stage; - newField->systemValueSemantic = oldField->systemValueSemantic; - newField->systemValueSemanticIndex = oldField->systemValueSemanticIndex; - - - for (auto oldResInfo : oldField->resourceInfos) - { - auto newResInfo = newField->findOrAddResourceInfo(oldResInfo.kind); - newResInfo->index = oldResInfo.index; - newResInfo->space = oldResInfo.space; - if (auto offsetResInfo = offsetVarLayout->FindResourceInfo(oldResInfo.kind)) - { - newResInfo->index += offsetResInfo->index; - } - } - - newStructTypeLayout->fields.add(newField); - - mapOldFieldToNew.Add(oldField.Ptr(), newField.Ptr()); - } - - for (auto entry : oldStructTypeLayout->mapVarToLayout) - { - VarLayout* newFieldLayout = nullptr; - if (mapOldFieldToNew.TryGetValue(entry.Value.Ptr(), newFieldLayout)) - { - newStructTypeLayout->mapVarToLayout.Add(entry.Key, newFieldLayout); - } - } - - newTypeLayout = newStructTypeLayout; - } - else - { - // TODO: need to handle other cases here - return oldTypeLayout; - } - - // No matter what replacement we plug in for the element type, we need to copy - // over its resource usage: - for (auto oldResInfo : oldTypeLayout->resourceInfos) - { - auto newResInfo = newTypeLayout->findOrAddResourceInfo(oldResInfo.kind); - newResInfo->count = oldResInfo.count; - } - - return newTypeLayout; -} - -static bool _usesResourceKind(RefPtr typeLayout, LayoutResourceKind kind) -{ - auto resInfo = typeLayout->FindResourceInfo(kind); - return resInfo && resInfo->count != 0; -} - -static bool _usesOrdinaryData(RefPtr typeLayout) -{ - return _usesResourceKind(typeLayout, LayoutResourceKind::Uniform); -} - - /// Add resource usage from `srcTypeLayout` to `dstTypeLayout` unless it would be "masked." - /// - /// This function is appropriate for applying resource usage from an element type - /// to the resource usage of a container like a `ConstantBuffer` or - /// `ParameterBlock`. - /// -static void _addUnmaskedResourceUsage( - TypeLayout* dstTypeLayout, - TypeLayout* srcTypeLayout, - bool haveFullRegisterSpaceOrSet) -{ - for( auto resInfo : srcTypeLayout->resourceInfos ) - { - switch( resInfo.kind ) - { - case LayoutResourceKind::Uniform: - // Ordinary/uniform resource usage will always be masked. - break; - - case LayoutResourceKind::RegisterSpace: - case LayoutResourceKind::ExistentialTypeParam: - // A parameter group will always pay for full registers - // spaces consumed by its element type. - // - // The same is true for existential type parameters, - // since these need to be exposed up through the API. - // - dstTypeLayout->addResourceUsage(resInfo); - break; - - default: - // For all other resource kinds, a parameter group - // will be able to mask them if and only if it - // has a full space/set allocated to it. - // - // Otherwise, the resource usage of the group must - // include the resource usage of the element. - // - if( !haveFullRegisterSpaceOrSet ) - { - dstTypeLayout->addResourceUsage(resInfo); - } - break; - } - } -} - -static RefPtr _createParameterGroupTypeLayout( - TypeLayoutContext const& context, - RefPtr parameterGroupType, - RefPtr rawElementTypeLayout) -{ - // We are being asked to create a layout for a parameter group, - // which is curently either a `ParameterBlock` or a `ConstantBuffer` - // - auto parameterGroupRules = context.rules; - RefPtr typeLayout = new ParameterGroupTypeLayout(); - typeLayout->type = parameterGroupType; - typeLayout->rules = parameterGroupRules; - - // Computing the layout is made tricky by several factors. - // - // A parameter group has to draw a distinction between the element type, - // and the resources it consumes, and the "container," which main - // consume other resources. The type of resource consumed by - // the two can overlap. - // - // Consider: - // - // struct MyMaterial { float2 uvScale; Texture2D albedoMap; } - // ParameterBlock gMaterial; - // - // In this example, `gMaterial` will need both a constant buffer - // binding (to hold the data for `uvScale`) and a texture binding - // (for `albedoMap`). On Vulkan, those two things require the *same* - // `LayoutResourceKind` (representing a GLSL `binding`). We will - // thus track the resource usage of the "container" type and - // element type separately, and then combine these to form - // the overall layout for the parameter group. - - RefPtr containerTypeLayout = new TypeLayout(); - containerTypeLayout->type = parameterGroupType; - containerTypeLayout->rules = parameterGroupRules; - - // Because the container and element types will each be situated - // at some offset relative to the initial register/binding for - // the group as a whole, we allocate a `VarLayout` for both - // the container and the element type, to store that offset - // information (think of `TypeLayout`s as holding size information, - // while `VarLayout`s hold offset information). - - RefPtr containerVarLayout = new VarLayout(); - containerVarLayout->typeLayout = containerTypeLayout; - typeLayout->containerVarLayout = containerVarLayout; - - RefPtr elementVarLayout = new VarLayout(); - elementVarLayout->typeLayout = rawElementTypeLayout; - typeLayout->elementVarLayout = elementVarLayout; - - // It is possible to have a `ConstantBuffer` that doesn't - // actually need a constant buffer register/binding allocated to it, - // because the type `T` doesn't actually contain any ordinary/uniform - // data that needs to go into the constant buffer. For example: - // - // struct MyMaterial { Texture2D t; SamplerState s; }; - // ConstantBuffer gMaterial; - // - // In this example, the `gMaterial` parameter doesn't actually need - // a constant buffer allocated for it. This isn't something that - // comes up often for `ConstantBuffer`, but can happen a lot for - // `ParameterBlock`. - // - // To determine if we actually need a constant-buffer binding, - // we will inspect the element type and see if it contains - // any ordinary/uniform data. - // - bool wantConstantBuffer = _usesOrdinaryData(rawElementTypeLayout); - if( wantConstantBuffer ) - { - // If there is any ordinary data, then we'll need to - // allocate a constant buffer regiser/binding into - // the overall layout, to account for it. - // - auto cbUsage = parameterGroupRules->GetObjectLayout(ShaderParameterKind::ConstantBuffer); - containerTypeLayout->addResourceUsage(cbUsage.kind, cbUsage.size); - } - - // Similarly to how we only need a constant buffer to be allocated - // if the contents of the group actually had ordinary/uniform data, - // we also only want to allocate a `space` or `set` if that is really - // required. - // - // - bool canUseSpaceOrSet = false; - // - // We will only allocate a `space` or `set` if the type is `ParameterBlock` - // and not just `ConstantBuffer`. - // - // Note: `parameterGroupType` is allowed to be null here, if we are allocating - // an anonymous constant buffer for global or entry-point parameters, but that - // is fine because the case will just return null in that case anyway. - // - auto parameterBlockType = as(parameterGroupType); - if( parameterBlockType ) - { - // We also can't allocate a `space` or `set` unless the compilation - // target actually supports them. - // - if( shouldAllocateRegisterSpaceForParameterBlock(context) ) - { - canUseSpaceOrSet = true; - } - } - - // Just knowing that we *can* use a `space` or `set` doesn't tell - // us if we would *like* to. - // - // The basic rule here is that if the element type of the parameter - // block contains anything that isn't itself consuming a full - // register `space` or `set`, then we'll want an umbrella `space`/`set` - // for all such data. - // - bool wantSpaceOrSet = false; - if( canUseSpaceOrSet ) - { - // Note that if we are allocating a constant buffer to hold - // some ordinary/uniform data then we definitely want a space/set, - // but we don't need to special-case that because the loop - // here will also detect the `LayoutResourceKind::Uniform` usage. - - for( auto elementResourceInfo : rawElementTypeLayout->resourceInfos ) - { - if(elementResourceInfo.kind != LayoutResourceKind::RegisterSpace) - { - wantSpaceOrSet = true; - break; - } - } - } - - // If after all that we determine that we want a register space/set, - // then we allocate one as part of the overall resource usage for - // the parameter group type. - // - if( wantSpaceOrSet ) - { - containerTypeLayout->addResourceUsage(LayoutResourceKind::RegisterSpace, 1); - } - - // Now that we've computed basic resource requirements for the container - // part of things (i.e., does it require a constant buffer or not?), - // let's go ahead and assign the container variable a relative offset - // of zero for each of the kinds of resources that it consumes. - // - for( auto typeResInfo : containerTypeLayout->resourceInfos ) - { - containerVarLayout->findOrAddResourceInfo(typeResInfo.kind); - } - - // Because the container's resource allocation is logically coming - // first in the overall group, the element needs to have a layout - // such that it comes *after* the container in the relative order. - // - for( auto elementTypeResInfo : rawElementTypeLayout->resourceInfos ) - { - auto kind = elementTypeResInfo.kind; - auto elementVarResInfo = elementVarLayout->findOrAddResourceInfo(kind); - - // If the container part of things is using the same resource kind - // as the element type, then the element needs to start at an offset - // after the container. - // - if( auto containerTypeResInfo = containerTypeLayout->FindResourceInfo(kind) ) - { - SLANG_RELEASE_ASSERT(containerTypeResInfo->count.isFinite()); - elementVarResInfo->index += containerTypeResInfo->count.getFiniteValue(); - } - } - - // The existing Slang reflection API was created before we really - // understood the wrinkle that the "container" and elements parts - // of a parameter group could collide on some resource kinds, - // so the API doesn't currently expose the nice `VarLayout`s we've - // just computed. - // - // Instead, the API allows the user to query the element type layout - // for the group, and the user just assumes that the offsetting - // is magically applied there. To go back to the earlier example: - // - // struct MyMaterial { Texture2D t; SamplerState s; }; - // ConstantBuffer gMaterial; - // - // A user of the existing reflection API expects to be able to - // query the `binding` of `gMaterial` and get back zero, then - // query the `binding` of the `t` field of the element type - // and get *one*. It is clear that in the abstract, the - // `MyMaterial::t` field should have an offset of zero (as - // the first field in a `struct`), so to meet the user's - // expectations, some cleverness is needed. - // - // We will use a subroutine `applyOffsetToTypeLayout` - // that tries to recursively walk an existing `TypeLayout` - // and apply an offset to its fields. This is currently - // quite ad hoc, but that doesn't matter much as it - // handles `struct` types which are the 99% case for - // parameter blocks. - // - typeLayout->offsetElementTypeLayout = applyOffsetToTypeLayout(rawElementTypeLayout, elementVarLayout); - - // Next, resource usage from the container and element - // types may need to "bleed through" to the overall - // parameter group type. - // - // If the parameter group is a `ConstantBuffer` then - // any ordinary/uniform bytes consumed by `Foo` are masked, - // but any other resources it consumes (e.g. `binding`s) need - // to bleed through and be accounted for in the overall - // layout of the type. - // - // If we have a `ParameterBlock` then any ordinary/uniform - // bytes are masked. Furthermore, *if* a whole `space`/`set` - // was allocated to the block, then any `register`s or - // `binding`s consumed by `Foo` (and by the "container" constant - // buffer if we allocated one) are also masked. Any whole - // spaces/sets consumed by `Foo` need to bleed through. - // - // We can start with the easier case of the container type, - // since it will either be empty or consume a single constant - // buffer. Its resource usage will only bleed through if we - // didn't allocate a full `space` or `set`. - // - _addUnmaskedResourceUsage(typeLayout, containerTypeLayout, wantSpaceOrSet); - - // next we turn to the element type, where the cases are slightly - // more involved (technically we could use this same logic for - // the container, as it is more general, but it was simpler to - // just special-case the container). - // - - _addUnmaskedResourceUsage(typeLayout, rawElementTypeLayout, wantSpaceOrSet); - - // At this point we have handled all the complexities that - // arise for a parameter group that doesn't include interface-type - // fields, or that doesn't include specialization for those fields. - // - // The remaining complexity all arises if we have interface-type - // data in the parameter group, and we are specializing it to - // concrete types, that will have their own layout requirements. - // In those cases there will be "pending data" on the element - // type layout that need to get placed somwhere, but wasn't - // included in the layout computed so far. - // - // All of this is extra work we only have to do if there is - // "pending" data in the element type layout. - // - if( auto pendingElementTypeLayout = rawElementTypeLayout->pendingDataTypeLayout ) - { - auto rules = rawElementTypeLayout->rules; - - // One really annoying complication we need to deal with here - // its that it is possible that the original parameter group - // declaration didn't need a constant buffer or `space`/`set` - // to be allocated, but once we consider the "pending" data - // we need to have a constant buffer and/or space. - // - // We will compute whether the pending data create a demand - // for a constant buffer and/or a space/set, so that we know - // if we are in the tricky case. - // - bool pendingDataWantsConstantBuffer = _usesOrdinaryData(pendingElementTypeLayout); - bool pendingDataWantsSpaceOrSet = false; - if( canUseSpaceOrSet ) - { - for( auto resInfo : pendingElementTypeLayout->resourceInfos ) - { - if( resInfo.kind != LayoutResourceKind::RegisterSpace ) - { - pendingDataWantsSpaceOrSet = true; - break; - } - } - } - - // We will use a few different variables to track resource - // usage for the pending data, with roles similar to the - // umbrella type layout, container layout, and element layout - // that already came up for the main part of the parameter group type. - - - RefPtr pendingContainerTypeLayout = new TypeLayout(); - pendingContainerTypeLayout->type = parameterGroupType; - pendingContainerTypeLayout->rules = parameterGroupRules; - - containerTypeLayout->pendingDataTypeLayout = pendingContainerTypeLayout; - - RefPtr pendingContainerVarLayout = new VarLayout(); - pendingContainerVarLayout->typeLayout = pendingContainerTypeLayout; - - containerVarLayout->pendingVarLayout = pendingContainerVarLayout; - - - RefPtr pendingElementVarLayout = new VarLayout(); - pendingElementVarLayout->typeLayout = pendingElementTypeLayout; - - elementVarLayout->pendingVarLayout = pendingElementVarLayout; - - // If we need a space/set for the pending data, and don't already - // have one, then we will allocate it now, as part of the - // "full" data type. - // - if( pendingDataWantsSpaceOrSet && !wantSpaceOrSet ) - { - pendingContainerTypeLayout->addResourceUsage(LayoutResourceKind::RegisterSpace, 1); - - // From here on, we know we have access to a register space, - // and we can mask any registers/bindings appropriately. - // - wantSpaceOrSet = true; - } - - // If we need a constant buffer for laying out ordinary - // data, and didn't have one allocated before, we will create - // one. - // - if( pendingDataWantsConstantBuffer && !wantConstantBuffer ) - { - auto cbUsage = rules->GetObjectLayout(ShaderParameterKind::ConstantBuffer); - pendingContainerTypeLayout->addResourceUsage(cbUsage.kind, cbUsage.size); - - wantConstantBuffer = true; - } - - for( auto resInfo : pendingContainerTypeLayout->resourceInfos ) - { - pendingContainerVarLayout->findOrAddResourceInfo(resInfo.kind); - } - - // Now that we've added in the resource usage for any CB or set/space - // we needed to allocate just for the pending data, we can safely - // lay out the pending data itself. - // - // The ordinary/uniform part of things wil always be "masked" and - // needs to come after any uniform data from the original element type. - // - // To kick things off we will initialize state for `struct` type layout, - // so that we can lay out the pending data as if it were the second - // field in a structure type, after the original data. - // - UniformLayoutInfo uniformLayout = rules->BeginStructLayout(); - if( auto resInfo = rawElementTypeLayout->FindResourceInfo(LayoutResourceKind::Uniform) ) - { - uniformLayout.alignment = rawElementTypeLayout->uniformAlignment; - uniformLayout.size = resInfo->count; - } - - // Now we can scan through the resources used by the pending data. - // - for( auto resInfo : pendingElementTypeLayout->resourceInfos ) - { - if( resInfo.kind == LayoutResourceKind::Uniform ) - { - // For the ordinary/uniform resource kind, we will add the resource - // usage as a structure field, and then write the resulting offset - // into the variable layout for the pending data. - // - auto offset = rules->AddStructField( - &uniformLayout, - UniformLayoutInfo( - resInfo.count, - pendingElementTypeLayout->uniformAlignment)); - pendingElementVarLayout->findOrAddResourceInfo(resInfo.kind)->index = offset.getFiniteValue(); - } - else - { - // For all other resource kinds, we will set the offset in - // the variable layout based on the total resources of that - // kind seen so far (including the "container" if any), - // and then bump the count for total resource usage. - // - auto elementVarResInfo = pendingElementVarLayout->findOrAddResourceInfo(resInfo.kind); - if( auto containerTypeInfo = pendingContainerTypeLayout->FindResourceInfo(resInfo.kind) ) - { - elementVarResInfo->index = containerTypeInfo->count.getFiniteValue(); - } - } - } - rules->EndStructLayout(&uniformLayout); - - // Okay, now we have a `VarLayout` for the element data, and an overall `TypeLayout` - // for all the data that this parameter group needs allocated for pending - // data. - // - // The next major step is to compute the version of that combined resource usage - // that will "bleed through" and thus needs to be allocated at the next level - // up the hierarchy. - // - RefPtr unmaskedPendingDataTypeLayout = new TypeLayout(); - _addUnmaskedResourceUsage(unmaskedPendingDataTypeLayout, pendingContainerTypeLayout, wantSpaceOrSet); - _addUnmaskedResourceUsage(unmaskedPendingDataTypeLayout, pendingElementTypeLayout, wantSpaceOrSet); - - // TODO: we should probably optimize for the case where there is no unmasked - // usage that needs to be reported out, since it should be a common case. - - // Now we need to update the type layout to what we've done. - // - typeLayout->pendingDataTypeLayout = unmaskedPendingDataTypeLayout; - } - - return typeLayout; -} - - /// Do we need to wrap the given element type in a constant buffer layout? -static bool needsConstantBuffer(RefPtr elementTypeLayout) -{ - // We need a constant buffer if the element type has ordinary/uniform data. - // - if(_usesOrdinaryData(elementTypeLayout)) - return true; - - // We also need a constant buffer if there is any "pending" - // data that need ordinary/uniform data allocated to them. - // - if(auto pendingDataTypeLayout = elementTypeLayout->pendingDataTypeLayout) - { - if(_usesOrdinaryData(pendingDataTypeLayout)) - return true; - } - - return false; -} - -RefPtr createConstantBufferTypeLayoutIfNeeded( - TypeLayoutContext const& context, - RefPtr elementTypeLayout) -{ - // First things first, we need to check whether the element type - // we are trying to lay out even needs a constant buffer allocated - // for it. - // - if(!needsConstantBuffer(elementTypeLayout)) - return elementTypeLayout; - - auto parameterGroupRules = context.getRulesFamily()->getConstantBufferRules(); - - return _createParameterGroupTypeLayout( - context - .with(parameterGroupRules) - .with(context.targetReq->getDefaultMatrixLayoutMode()), - nullptr, - elementTypeLayout); -} - - -static RefPtr _createParameterGroupTypeLayout( - TypeLayoutContext const& context, - RefPtr parameterGroupType, - RefPtr elementType, - LayoutRulesImpl* elementTypeRules) -{ - // We will first compute a layout for the element type of - // the parameter group. - // - auto elementTypeLayout = createTypeLayout( - context.with(elementTypeRules), - elementType); - - // Now we delegate to a routine that does the meat of - // the complicated layout logic. - // - return _createParameterGroupTypeLayout( - context, - parameterGroupType, - elementTypeLayout); -} - -LayoutRulesImpl* getParameterBufferElementTypeLayoutRules( - RefPtr parameterGroupType, - LayoutRulesImpl* rules) -{ - if( as(parameterGroupType) ) - { - return rules->getLayoutRulesFamily()->getConstantBufferRules(); - } - else if( as(parameterGroupType) ) - { - return rules->getLayoutRulesFamily()->getTextureBufferRules(); - } - else if( as(parameterGroupType) ) - { - return rules->getLayoutRulesFamily()->getVaryingInputRules(); - } - else if( as(parameterGroupType) ) - { - return rules->getLayoutRulesFamily()->getVaryingOutputRules(); - } - else if( as(parameterGroupType) ) - { - return rules->getLayoutRulesFamily()->getShaderStorageBufferRules(); - } - else if (as(parameterGroupType)) - { - return rules->getLayoutRulesFamily()->getParameterBlockRules(); - } - else - { - SLANG_UNEXPECTED("uhandled parameter block type"); - return nullptr; - } -} - -RefPtr createParameterGroupTypeLayout( - TypeLayoutContext const& context, - RefPtr parameterGroupType) -{ - auto parameterGroupRules = context.rules; - - // Determine the layout rules to use for the contents of the block - auto elementTypeRules = getParameterBufferElementTypeLayoutRules( - parameterGroupType, - parameterGroupRules); - - auto elementType = parameterGroupType->elementType; - - return _createParameterGroupTypeLayout( - context, - parameterGroupType, - elementType, - elementTypeRules); -} - -// Create a type layout for a structured buffer type. -RefPtr -createStructuredBufferTypeLayout( - TypeLayoutContext const& context, - ShaderParameterKind kind, - RefPtr structuredBufferType, - RefPtr elementTypeLayout) -{ - auto rules = context.rules; - auto info = rules->GetObjectLayout(kind); - - auto typeLayout = new StructuredBufferTypeLayout(); - - typeLayout->type = structuredBufferType; - typeLayout->rules = rules; - - typeLayout->elementTypeLayout = elementTypeLayout; - - typeLayout->uniformAlignment = info.alignment; - SLANG_RELEASE_ASSERT(!typeLayout->FindResourceInfo(LayoutResourceKind::Uniform)); - SLANG_RELEASE_ASSERT(typeLayout->uniformAlignment == 1); - - if( info.size != 0 ) - { - typeLayout->addResourceUsage(info.kind, info.size); - } - - // Note: for now we don't deal with the case of a structured - // buffer that might contain anything other than "uniform" data, - // because there really isn't a way to implement that. - - return typeLayout; -} - -// Create a type layout for a structured buffer type. -RefPtr -createStructuredBufferTypeLayout( - TypeLayoutContext const& context, - ShaderParameterKind kind, - RefPtr structuredBufferType, - RefPtr elementType) -{ - // TODO(tfoley): we should be looking up the appropriate rules - // via the `LayoutRulesFamily` in use here... - auto structuredBufferLayoutRules = GetLayoutRulesImpl( - LayoutRule::HLSLStructuredBuffer); - - // Create and save type layout for the buffer contents. - auto elementTypeLayout = createTypeLayout( - context.with(structuredBufferLayoutRules), - elementType.Ptr()); - - return createStructuredBufferTypeLayout( - context, - kind, - structuredBufferType, - elementTypeLayout); - -} - - /// Create layout information for the given `type`. - /// - /// This internal routine returns both the constructed type - /// layout object and the simple layout info, encapsulated - /// together as a `TypeLayoutResult`. - /// -static TypeLayoutResult _createTypeLayout( - TypeLayoutContext const& context, - Type* type); - - /// Create layout information for the given `type`, obeying any layout modifiers on the given declaration. - /// - /// If `declForModifiers` has any matrix layout modifiers associated with it, then - /// the resulting type layout will respect those modifiers. - /// -static TypeLayoutResult _createTypeLayout( - TypeLayoutContext const& context, - Type* type, - Decl* declForModifiers) -{ - TypeLayoutContext subContext = context; - - if (declForModifiers) - { - // TODO: The approach implemented here has a row/column-major - // layout model recursively affect any sub-fields (so that - // the layout of a nested struct depends on the context where - // it is nested). This is consistent with the GLSL behavior - // for these modifiers, but it is *not* how HLSL is supposed - // to work. - // - // In the trivial case where `row_major` and `column_major` - // are only applied to leaf fields/variables of matrix type - // the difference should be immaterial. - - if (declForModifiers->HasModifier()) - subContext.matrixLayoutMode = kMatrixLayoutMode_RowMajor; - - if (declForModifiers->HasModifier()) - subContext.matrixLayoutMode = kMatrixLayoutMode_ColumnMajor; - - // TODO: really need to look for other modifiers that affect - // layout, such as GLSL `std140`. - } - - return _createTypeLayout(subContext, type); -} - -int findGenericParam(List> & genericParameters, GlobalGenericParamDecl * decl) -{ - return (int)genericParameters.findFirstIndex([=](RefPtr & x) {return x->decl.Ptr() == decl; }); -} - -// When constructing a new var layout from an existing one, -// copy fields to the new var from the old. -void copyVarLayoutFields( - VarLayout* dstVarLayout, - VarLayout* srcVarLayout) -{ - dstVarLayout->varDecl = srcVarLayout->varDecl; - dstVarLayout->typeLayout = srcVarLayout->typeLayout; - dstVarLayout->flags = srcVarLayout->flags; - dstVarLayout->systemValueSemantic = srcVarLayout->systemValueSemantic; - dstVarLayout->systemValueSemanticIndex = srcVarLayout->systemValueSemanticIndex; - dstVarLayout->semanticName = srcVarLayout->semanticName; - dstVarLayout->semanticIndex = srcVarLayout->semanticIndex; - dstVarLayout->stage = srcVarLayout->stage; - dstVarLayout->resourceInfos = srcVarLayout->resourceInfos; -} - -// When constructing a new type layout from an existing one, -// copy fields to the new type from the old. -void copyTypeLayoutFields( - TypeLayout* dstTypeLayout, - TypeLayout* srcTypeLayout) -{ - dstTypeLayout->type = srcTypeLayout->type; - dstTypeLayout->rules = srcTypeLayout->rules; - dstTypeLayout->uniformAlignment = srcTypeLayout->uniformAlignment; - dstTypeLayout->resourceInfos = srcTypeLayout->resourceInfos; -} - -// Does this layout resource kind require adjustment when used in -// an array-of-structs fashion? -bool doesResourceRequireAdjustmentForArrayOfStructs(LayoutResourceKind kind) -{ - switch( kind ) - { - case LayoutResourceKind::ConstantBuffer: - case LayoutResourceKind::ShaderResource: - case LayoutResourceKind::UnorderedAccess: - case LayoutResourceKind::SamplerState: - return true; - - default: - return false; - } -} - -// Given the type layout for an element of an array, apply any adjustments required -// based on the element count of the array. -// -// The particular case where this matters is when we have an array of an aggregate -// type that contains resources, since each resource field might need to be at -// a different offset than we would otherwise expect. -// -// For example, given: -// -// struct Foo { Texture2D a; Texture2D b; } -// -// if we just write: -// -// Foo foo; -// -// it gets split into: -// -// Texture2D foo_a; -// Texture2D foo_b; -// -// we expect `foo_a` to get `register(t0)` and -// `foo_b` to get `register(t1)`. However, if we instead have an array: -// -// Foo foo[10]; -// -// then we expect it to be split into: -// -// Texture2D foo_a[8]; -// Texture2D foo_b[8]; -// -// and then we expect `foo_b` to get `register(t8)`, rather -// than `register(t1)`. -// -static RefPtr maybeAdjustLayoutForArrayElementType( - RefPtr originalTypeLayout, - LayoutSize elementCount, - UInt& ioAdditionalSpacesNeeded) -{ - // We will start by looking for cases that we can reject out - // of hand. - - // If the original element type layout doesn't use any - // resource registers, then we are fine. - bool anyResource = false; - for( auto resInfo : originalTypeLayout->resourceInfos ) - { - if( doesResourceRequireAdjustmentForArrayOfStructs(resInfo.kind) ) - { - anyResource = true; - break; - } - } - if(!anyResource) - return originalTypeLayout; - - // Let's look at the type layout we have, and see if there is anything - // that we need to do with it. - // - if( auto originalArrayTypeLayout = originalTypeLayout.as() ) - { - // The element type is itself an array, so we'll need to adjust - // *its* element type accordingly. - // - // We adjust the already-adjusted element type of the inner - // array type, so that we pick up adjustments already made: - auto originalInnerElementTypeLayout = originalArrayTypeLayout->elementTypeLayout; - auto adjustedInnerElementTypeLayout = maybeAdjustLayoutForArrayElementType( - originalInnerElementTypeLayout, - elementCount, - ioAdditionalSpacesNeeded); - - // If nothing needed to be changed on the inner element type, - // then we are done. - if(adjustedInnerElementTypeLayout == originalInnerElementTypeLayout) - return originalTypeLayout; - - // Otherwise, we need to construct a new array type layout - RefPtr adjustedArrayTypeLayout = new ArrayTypeLayout(); - adjustedArrayTypeLayout->originalElementTypeLayout = originalInnerElementTypeLayout; - adjustedArrayTypeLayout->elementTypeLayout = adjustedInnerElementTypeLayout; - adjustedArrayTypeLayout->uniformStride = originalArrayTypeLayout->uniformStride; - - copyTypeLayoutFields(adjustedArrayTypeLayout, originalArrayTypeLayout); - - return adjustedArrayTypeLayout; - } - else if(auto originalParameterGroupTypeLayout = originalTypeLayout.as() ) - { - auto originalInnerElementTypeLayout = originalParameterGroupTypeLayout->elementVarLayout->typeLayout; - auto adjustedInnerElementTypeLayout = maybeAdjustLayoutForArrayElementType( - originalInnerElementTypeLayout, - elementCount, - ioAdditionalSpacesNeeded); - - // If nothing needed to be changed on the inner element type, - // then we are done. - if(adjustedInnerElementTypeLayout == originalInnerElementTypeLayout) - return originalTypeLayout; - - // TODO: actually adjust the element type, and create all the required bits and - // pieces of layout. - - SLANG_UNIMPLEMENTED_X("array of parameter group"); - UNREACHABLE_RETURN(originalTypeLayout); - } - else if(auto originalStructTypeLayout = originalTypeLayout.as() ) - { - Index fieldCount = originalStructTypeLayout->fields.getCount(); - - // Empty struct? Bail out. - if(fieldCount == 0) - return originalTypeLayout; - - RefPtr adjustedStructTypeLayout = new StructTypeLayout(); - copyTypeLayoutFields(adjustedStructTypeLayout, originalStructTypeLayout); - - // If the array type adjustment forces us to give a whole space to - // one or more fields, then we'll need to carefully compute the space - // index for each field as we go. - // - LayoutSize nextSpaceIndex = 0; - - Dictionary, RefPtr> mapOriginalFieldToAdjusted; - for( auto originalField : originalStructTypeLayout->fields ) - { - auto originalFieldTypeLayout = originalField->typeLayout; - - LayoutSize originalFieldSpaceCount = 0; - if(auto resInfo = originalFieldTypeLayout->FindResourceInfo(LayoutResourceKind::RegisterSpace)) - originalFieldSpaceCount = resInfo->count; - - // Compute the adjusted type for the field - UInt fieldAdditionalSpaces = 0; - auto adjustedFieldTypeLayout = maybeAdjustLayoutForArrayElementType( - originalFieldTypeLayout, - elementCount, - fieldAdditionalSpaces); - - LayoutSize adjustedFieldSpaceCount = originalFieldSpaceCount + fieldAdditionalSpaces; - - LayoutSize spaceOffsetForField = nextSpaceIndex; - nextSpaceIndex += adjustedFieldSpaceCount; - - ioAdditionalSpacesNeeded += fieldAdditionalSpaces; - - // Create an adjusted field variable, that is mostly - // a clone of the original field (just with our - // adjusted type in place). - RefPtr adjustedField = new VarLayout(); - copyVarLayoutFields(adjustedField, originalField); - adjustedField->typeLayout = adjustedFieldTypeLayout; - - // We will now walk through the resource usage for - // the adjusted field, and try to figure out what - // to do with it all. - // - for(auto& resInfo : adjustedField->resourceInfos ) - { - if( doesResourceRequireAdjustmentForArrayOfStructs(resInfo.kind) ) - { - if(elementCount.isFinite()) - { - // If the array size is finite, then the field's index/offset - // is just going to be strided by the array size since we - // are effectively doing AoS to SoA conversion. - // - resInfo.index *= elementCount.getFiniteValue(); - } - else - { - // If we are making an unbounded array, then a `struct` - // field with resource type will turn into its own space, - // and it will start at register zero in that space. - // - resInfo.index = 0; - resInfo.space = spaceOffsetForField.getFiniteValue(); - } - } - } - - adjustedStructTypeLayout->fields.add(adjustedField); - - mapOriginalFieldToAdjusted.Add(originalField, adjustedField); - } - - for( auto p : originalStructTypeLayout->mapVarToLayout ) - { - Decl* key = p.Key; - RefPtr originalVal = p.Value; - RefPtr adjustedVal; - if( mapOriginalFieldToAdjusted.TryGetValue(originalVal, adjustedVal) ) - { - adjustedStructTypeLayout->mapVarToLayout.Add(key, adjustedVal); - } - } - - return adjustedStructTypeLayout; - } - else - { - // In the leaf case, we must have a field that used up some resource - // that requires adjustment. Because there is no sub-structure to work - // with, we can just return the type layout as-is, but we also want - // to make a note that this value should consume an additional register - // space *if* the element count is unbounded. - if( elementCount.isInfinite() ) - { - ioAdditionalSpacesNeeded++; - } - - return originalTypeLayout; - } -} - - /// Convert a `TypeLayout` to a `TypeLayoutResult` - /// - /// A `TypeLayout` holds all the data needed to make a `TypeLayoutResult` in practice, - /// but sometimes it is more convenient to have the data split out. - /// -TypeLayoutResult makeTypeLayoutResult(RefPtr typeLayout) -{ - TypeLayoutResult result; - result.layout = typeLayout; - - // If the type only consumes a single kind of non-uniform resource, - // we can fill in the `info` field directly. - // - if( typeLayout->resourceInfos.getCount() == 1 ) - { - auto resInfo = typeLayout->resourceInfos[0]; - if( resInfo.kind != LayoutResourceKind::Uniform ) - { - result.info.kind = resInfo.kind; - result.info.size = resInfo.count; - return result; - } - } - - // Otherwise, we will fill out the info based on the uniform - // resources consumed, if any. - // - if( auto resInfo = typeLayout->FindResourceInfo(LayoutResourceKind::Uniform) ) - { - result.info.kind = LayoutResourceKind::Uniform; - result.info.alignment = typeLayout->uniformAlignment; - result.info.size = resInfo->count; - } - - // If there was no ordinary/uniform resource usage, then we - // will leave the `info` field in its default state (which - // shows no resources consumed). - // - // The type layout might have more detailed information, but - // at this point it must contain either zero, or more than one - // `ResourceInfo`, so there is nothing unambiguous we can - // store into `info`. - - return result; -} - -// -// StructTypeLayoutBuilder -// - -void StructTypeLayoutBuilder::beginLayout( - Type* type, - LayoutRulesImpl* rules) -{ - m_rules = rules; - - m_typeLayout = new StructTypeLayout(); - m_typeLayout->type = type; - m_typeLayout->rules = m_rules; - - m_info = m_rules->BeginStructLayout(); -} - -void StructTypeLayoutBuilder::beginLayoutIfNeeded( - Type* type, - LayoutRulesImpl* rules) -{ - if( !m_typeLayout ) - { - beginLayout(type, rules); - } -} - -RefPtr StructTypeLayoutBuilder::addField( - DeclRef field, - TypeLayoutResult fieldResult) -{ - SLANG_ASSERT(m_typeLayout); - - RefPtr fieldTypeLayout = fieldResult.layout; - UniformLayoutInfo fieldInfo = fieldResult.info.getUniformLayout(); - - // Note: we don't add any zero-size fields - // when computing structure layout, just - // to avoid having a resource type impact - // the final layout. - // - // This means that the code to generate final - // declarations needs to *also* eliminate zero-size - // fields to be safe... - // - LayoutSize uniformOffset = m_info.size; - if(fieldInfo.size != 0) - { - uniformOffset = m_rules->AddStructField(&m_info, fieldInfo); - } - - - // We need to create variable layouts - // for each field of the structure. - RefPtr fieldLayout = new VarLayout(); - fieldLayout->varDecl = field; - fieldLayout->typeLayout = fieldTypeLayout; - m_typeLayout->fields.add(fieldLayout); - - if( field ) - { - m_typeLayout->mapVarToLayout.Add(field.getDecl(), fieldLayout); - } - - // Set up uniform offset information, if there is any uniform data in the field - if( fieldTypeLayout->FindResourceInfo(LayoutResourceKind::Uniform) ) - { - fieldLayout->AddResourceInfo(LayoutResourceKind::Uniform)->index = uniformOffset.getFiniteValue(); - } - - // Add offset information for any other resource kinds - for( auto fieldTypeResourceInfo : fieldTypeLayout->resourceInfos ) - { - // Uniforms were dealt with above - if(fieldTypeResourceInfo.kind == LayoutResourceKind::Uniform) - continue; - - // We should not have already processed this resource type - SLANG_RELEASE_ASSERT(!fieldLayout->FindResourceInfo(fieldTypeResourceInfo.kind)); - - // The field will need offset information for this kind - auto fieldResourceInfo = fieldLayout->AddResourceInfo(fieldTypeResourceInfo.kind); - - // It is possible for a `struct` field to use an unbounded array - // type, and in the D3D case that would consume an unbounded number - // of registers. What is more, a single `struct` could have multiple - // such fields, or ordinary resource fields after an unbounded field. - // - // We handle this case by allocating a distinct register space for - // any field that consumes an unbounded amount of registers. - // - if( fieldTypeResourceInfo.count.isInfinite() ) - { - // We need to add one register space to own the storage for this field. - // - auto structTypeSpaceResourceInfo = m_typeLayout->findOrAddResourceInfo(LayoutResourceKind::RegisterSpace); - auto spaceOffset = structTypeSpaceResourceInfo->count; - structTypeSpaceResourceInfo->count += 1; - - // The field itself will record itself as having a zero offset into - // the chosen space. - // - fieldResourceInfo->space = spaceOffset.getFiniteValue(); - fieldResourceInfo->index = 0; - } - else - { - // In the case where the field consumes a finite number of slots, we - // can simply set its offset/index to the number of such slots consumed - // so far, and then increment the number of slots consumed by the - // `struct` type itself. - // - auto structTypeResourceInfo = m_typeLayout->findOrAddResourceInfo(fieldTypeResourceInfo.kind); - fieldResourceInfo->index = structTypeResourceInfo->count.getFiniteValue(); - structTypeResourceInfo->count += fieldTypeResourceInfo.count; - } - } - - return fieldLayout; -} - -RefPtr StructTypeLayoutBuilder::addField( - DeclRef field, - RefPtr fieldTypeLayout) -{ - TypeLayoutResult fieldResult = makeTypeLayoutResult(fieldTypeLayout); - return addField(field, fieldResult); -} - -void StructTypeLayoutBuilder::endLayout() -{ - if(!m_typeLayout) return; - - m_rules->EndStructLayout(&m_info); - - m_typeLayout->uniformAlignment = m_info.alignment; - m_typeLayout->addResourceUsage(LayoutResourceKind::Uniform, m_info.size); -} - -RefPtr StructTypeLayoutBuilder::getTypeLayout() -{ - return m_typeLayout; -} - -TypeLayoutResult StructTypeLayoutBuilder::getTypeLayoutResult() -{ - return TypeLayoutResult(m_typeLayout, m_info); -} - -static TypeLayoutResult _createTypeLayout( - TypeLayoutContext const& context, - Type* type) -{ - auto rules = context.rules; - - if (auto parameterGroupType = as(type)) - { - // If the user is just interested in uniform layout info, - // then this is easy: a `ConstantBuffer` is really no - // different from a `Texture2D` in terms of how it - // should be handled as a member of a container. - // - auto info = getParameterGroupLayoutInfo(parameterGroupType, rules); - - // The more interesting case, though, is when the user - // is requesting us to actually create a `TypeLayout`, - // since in that case we need to: - // - // 1. Compute a layout for the data inside the constant - // buffer, including offsets, etc. - // - // 2. Compute information about any object types inside - // the constant buffer, which need to be surfaces out - // to the top level. - // - auto typeLayout = createParameterGroupTypeLayout( - context, - parameterGroupType); - - return TypeLayoutResult(typeLayout, info); - } - else if (auto samplerStateType = as(type)) - { - return createSimpleTypeLayout( - rules->GetObjectLayout(ShaderParameterKind::SamplerState), - type, - rules); - } - else if (auto textureType = as(type)) - { - // TODO: the logic here should really be defined by the rules, - // and not at this top level... - ShaderParameterKind kind; - switch( textureType->getAccess() ) - { - default: - kind = ShaderParameterKind::MutableTexture; - break; - - case SLANG_RESOURCE_ACCESS_READ: - kind = ShaderParameterKind::Texture; - break; - } - - return createSimpleTypeLayout( - rules->GetObjectLayout(kind), - type, - rules); - } - else if (auto imageType = as(type)) - { - // TODO: the logic here should really be defined by the rules, - // and not at this top level... - ShaderParameterKind kind; - switch( imageType->getAccess() ) - { - default: - kind = ShaderParameterKind::MutableImage; - break; - - case SLANG_RESOURCE_ACCESS_READ: - kind = ShaderParameterKind::Image; - break; - } - - return createSimpleTypeLayout( - rules->GetObjectLayout(kind), - type, - rules); - } - else if (auto textureSamplerType = as(type)) - { - // TODO: the logic here should really be defined by the rules, - // and not at this top level... - ShaderParameterKind kind; - switch( textureSamplerType->getAccess() ) - { - default: - kind = ShaderParameterKind::MutableTextureSampler; - break; - - case SLANG_RESOURCE_ACCESS_READ: - kind = ShaderParameterKind::TextureSampler; - break; - } - - return createSimpleTypeLayout( - rules->GetObjectLayout(kind), - type, - rules); - } - - // TODO: need a better way to handle this stuff... -#define CASE(TYPE, KIND) \ - else if(auto type_##TYPE = as(type)) do { \ - auto info = rules->GetObjectLayout(ShaderParameterKind::KIND); \ - auto typeLayout = createStructuredBufferTypeLayout( \ - context, \ - ShaderParameterKind::KIND, \ - type_##TYPE, \ - type_##TYPE->elementType.Ptr()); \ - return TypeLayoutResult(typeLayout, info); \ - } while(0) - - CASE(HLSLStructuredBufferType, StructuredBuffer); - CASE(HLSLRWStructuredBufferType, MutableStructuredBuffer); - CASE(HLSLRasterizerOrderedStructuredBufferType, MutableStructuredBuffer); - CASE(HLSLAppendStructuredBufferType, MutableStructuredBuffer); - CASE(HLSLConsumeStructuredBufferType, MutableStructuredBuffer); - -#undef CASE - - - // TODO: need a better way to handle this stuff... -#define CASE(TYPE, KIND) \ - else if(as(type)) do { \ - return createSimpleTypeLayout( \ - rules->GetObjectLayout(ShaderParameterKind::KIND), \ - type, rules); \ - } while(0) - - CASE(HLSLByteAddressBufferType, RawBuffer); - CASE(HLSLRWByteAddressBufferType, MutableRawBuffer); - CASE(HLSLRasterizerOrderedByteAddressBufferType, MutableRawBuffer); - - CASE(GLSLInputAttachmentType, InputRenderTarget); - - // This case is mostly to allow users to add new resource types... - CASE(UntypedBufferResourceType, RawBuffer); - -#undef CASE - - else if(auto basicType = as(type)) - { - return createSimpleTypeLayout( - rules->GetScalarLayout(basicType->baseType), - type, - rules); - } - else if(auto vecType = as(type)) - { - auto elementType = vecType->elementType; - size_t elementCount = (size_t) GetIntVal(vecType->elementCount); - - auto element = _createTypeLayout( - context, - elementType); - - auto info = rules->GetVectorLayout(element.info, elementCount); - - RefPtr typeLayout = new VectorTypeLayout(); - typeLayout->type = type; - typeLayout->rules = rules; - typeLayout->uniformAlignment = info.alignment; - - typeLayout->elementTypeLayout = element.layout; - typeLayout->uniformStride = element.info.getUniformLayout().size.getFiniteValue(); - - typeLayout->addResourceUsage(info.kind, info.size); - - return TypeLayoutResult(typeLayout, info); - } - else if(auto matType = as(type)) - { - size_t rowCount = (size_t) GetIntVal(matType->getRowCount()); - size_t colCount = (size_t) GetIntVal(matType->getColumnCount()); - - auto elementType = matType->getElementType(); - auto elementResult = _createTypeLayout( - context, - elementType); - auto elementTypeLayout = elementResult.layout; - auto elementInfo = elementResult.info; - - // The `GetMatrixLayout` implementation in the layout rules - // currently defaults to assuming row-major layout, - // so if we want column-major layout we achieve it here by - // transposing the major/minor axes counts. - // - size_t layoutMajorCount = rowCount; - size_t layoutMinorCount = colCount; - if (context.matrixLayoutMode == kMatrixLayoutMode_ColumnMajor) - { - size_t tmp = layoutMajorCount; - layoutMajorCount = layoutMinorCount; - layoutMinorCount = tmp; - } - auto info = rules->GetMatrixLayout( - elementInfo, - layoutMajorCount, - layoutMinorCount); - - auto rowType = matType->getRowType(); - RefPtr rowTypeLayout = new VectorTypeLayout(); - - auto rowInfo = rules->GetVectorLayout( - elementInfo, - colCount); - - size_t majorStride = info.elementStride; - size_t minorStride = elementInfo.getUniformLayout().size.getFiniteValue(); - - size_t rowStride = 0; - size_t colStride = 0; - if(context.matrixLayoutMode == kMatrixLayoutMode_ColumnMajor) - { - colStride = majorStride; - rowStride = minorStride; - } - else - { - rowStride = majorStride; - colStride = minorStride; - } - - rowTypeLayout->type = type; - rowTypeLayout->rules = rules; - rowTypeLayout->uniformAlignment = elementInfo.getUniformLayout().alignment; - - rowTypeLayout->uniformStride = colStride; - rowTypeLayout->elementTypeLayout = elementTypeLayout; - rowTypeLayout->addResourceUsage(rowInfo.kind, rowInfo.size); - - RefPtr typeLayout = new MatrixTypeLayout(); - - typeLayout->type = type; - typeLayout->rules = rules; - typeLayout->uniformAlignment = info.alignment; - - typeLayout->elementTypeLayout = rowTypeLayout; - typeLayout->uniformStride = rowStride; - typeLayout->mode = context.matrixLayoutMode; - - typeLayout->addResourceUsage(info.kind, info.size); - - return TypeLayoutResult(typeLayout, info); - } - else if (auto arrayType = as(type)) - { - auto elementResult = _createTypeLayout( - context, - arrayType->baseType.Ptr()); - auto elementInfo = elementResult.info; - auto elementTypeLayout = elementResult.layout; - - // To a first approximation, an array will usually be laid out - // by taking the element's type layout and laying out `elementCount` - // copies of it. There are of course many details that make - // this simplistic version of things not quite work. - // - // An important complication to deal with is the possibility of - // having "unbounded" arrays, which don't specify a size.' - // The layout rules for these vary heavily by resource kind and API. - // - - auto elementCount = GetElementCount(arrayType->ArrayLength); - - // - // We can compute the uniform storage layout of an array using - // the rules for the target API. - // - // TODO: ensure that this does something reasonable with the unbounded - // case, or else issue an error message that the target doesn't - // support unbounded types. - // - - auto arrayUniformInfo = rules->GetArrayLayout( - elementInfo, - elementCount).getUniformLayout(); - - RefPtr typeLayout = new ArrayTypeLayout(); - - // Some parts of the array type layout object are easy to fill in: - typeLayout->type = type; - typeLayout->rules = rules; - typeLayout->originalElementTypeLayout = elementTypeLayout; - typeLayout->uniformAlignment = arrayUniformInfo.alignment; - typeLayout->uniformStride = arrayUniformInfo.elementStride; - - typeLayout->addResourceUsage(LayoutResourceKind::Uniform, arrayUniformInfo.size); - - // - // The tricky part in constructing an array type layout comes when - // the element type is (or nests) a structure with resource-type - // fields, because in that case we need to perform AoS-to-SoA - // conversion as part of computing the final type layout, and - // we also need to pre-compute an "adjusted" element type - // layout that accounts for the striding that happens with - // resource-type contents. - // - // This complication is only made worse when we have to deal with - // unbounded-size arrays over such element types, since those - // resource-type fields will each end up consuming a full space - // in the resulting layout. - // - // The `maybeAdjustLayoutForArrayElementType` computes an "adjusted" - // type layout for the element type which takes the array stride into - // account. If it returns the same type layout that was passed in, - // then that means no adjustement took place. - // - // The `additionalSpacesNeededForAdjustedElementType` variable counts - // the number of additional register spaces that were consumed, - // in the case of an unbounded array. - // - UInt additionalSpacesNeededForAdjustedElementType = 0; - RefPtr adjustedElementTypeLayout = maybeAdjustLayoutForArrayElementType( - elementTypeLayout, - elementCount, - additionalSpacesNeededForAdjustedElementType); - - typeLayout->elementTypeLayout = adjustedElementTypeLayout; - - // We will now iterate over the resources consumed by the element - // type to compute how they contribute to the resource usage - // of the overall array type. - // - for( auto elementResourceInfo : elementTypeLayout->resourceInfos ) - { - // The uniform case was already handled above - if( elementResourceInfo.kind == LayoutResourceKind::Uniform ) - continue; - - LayoutSize arrayResourceCount = 0; - - // In almost all cases, the resources consumed by an array - // will be its element count times the resources consumed - // by its element type. - // - // The first exception to this is arrays of resources when - // compiling to GLSL for Vulkan, where an entire array - // only consumes a single descriptor-table slot. - // - if (elementResourceInfo.kind == LayoutResourceKind::DescriptorTableSlot) - { - arrayResourceCount = elementResourceInfo.count; - } - // - // The next big exception is when we are forming an unbounded-size - // array and the element type got "adjusted," because that means - // the array type will need to allocate full spaces for any resource-type - // fields in the element type. - // - // Note: we carefully carve things out so that the case of a simple - // array of resources does *not* lead to the element type being adjusted, - // so that this logic doesn't trigger and we instead handle it with - // the default logic below. - // - else if( - elementCount.isInfinite() - && adjustedElementTypeLayout != elementTypeLayout - && doesResourceRequireAdjustmentForArrayOfStructs(elementResourceInfo.kind) ) - { - // We want to ignore resource types consumed by the element type - // that need adjustement if the array size is infinite, since - // we will be allocating whole spaces for that part of the - // element's resource usage. - } - else - { - arrayResourceCount = elementResourceInfo.count * elementCount; - } - - // Now that we've computed how the resource usage of the element type - // should contribute to the resource usage of the array, we can - // add in that resource usage. - // - typeLayout->addResourceUsage( - elementResourceInfo.kind, - arrayResourceCount); - } - - // The loop above to compute the resource usage of the array from its - // element type ignored any resource-type fields in an unbounded-size - // array if they would have been allocated as full register spaces. - // Those same fields were counted in `additionalSpacesNeededForAdjustedElementType`, - // and need to be added into the total resource usage for the array - // if we skipped them as part of the loop (which happens when - // we detect that the element type layout had been "adjusted"). - // - if( adjustedElementTypeLayout != elementTypeLayout ) - { - typeLayout->addResourceUsage(LayoutResourceKind::RegisterSpace, additionalSpacesNeededForAdjustedElementType); - } - - return TypeLayoutResult(typeLayout, arrayUniformInfo); - } - else if (auto declRefType = as(type)) - { - auto declRef = declRefType->declRef; - - if (auto structDeclRef = declRef.as()) - { - StructTypeLayoutBuilder typeLayoutBuilder; - StructTypeLayoutBuilder pendingDataTypeLayoutBuilder; - - typeLayoutBuilder.beginLayout(type, rules); - auto typeLayout = typeLayoutBuilder.getTypeLayout(); - for (auto field : GetFields(structDeclRef)) - { - // Static fields shouldn't take part in layout. - if(field.getDecl()->HasModifier()) - continue; - - // The fields of a `struct` type may include existential (interface) - // types (including as nested sub-fields), and any types present - // in those fields will need to be specialized based on the - // input arguments being passed to `_createTypeLayout`. - // - // We won't know how many type slots each field consumes until - // we process it, but we can figure out the starting index for - // the slots its will consume by looking at the layout we've - // computed so far. - // - Int baseExistentialSlotIndex = 0; - if(auto resInfo = typeLayout->FindResourceInfo(LayoutResourceKind::ExistentialTypeParam)) - baseExistentialSlotIndex = Int(resInfo->count.getFiniteValue()); - // - // When computing the layout for the field, we will give it access - // to all the incoming specialized type slots that haven't already - // been consumed/claimed by preceding fields. - // - auto fieldLayoutContext = context.withExistentialTypeSlotsOffsetBy(baseExistentialSlotIndex); - - TypeLayoutResult fieldResult = _createTypeLayout( - fieldLayoutContext, - GetType(field).Ptr(), - field.getDecl()); - auto fieldTypeLayout = fieldResult.layout; - - auto fieldVarLayout = typeLayoutBuilder.addField(field, fieldResult); - - // If any of the fields of the `struct` type had existential/interface - // type, then we need to compute a second `StructTypeLayout` that - // represents the layout and resource using for the "pending data" - // that this type needs to have stored somewhere, but which can't - // be laid out in the layout of the type itself. - // - if(auto fieldPendingDataTypeLayout = fieldTypeLayout->pendingDataTypeLayout) - { - // We only create this secondary layout on-demand, so that - // we don't end up with a bunch of empty structure type layouts - // created for no reason. - // - pendingDataTypeLayoutBuilder.beginLayoutIfNeeded(type, rules); - auto fieldPendingVarLayout = pendingDataTypeLayoutBuilder.addField(field, fieldPendingDataTypeLayout); - fieldVarLayout->pendingVarLayout = fieldPendingVarLayout; - } - } - - typeLayoutBuilder.endLayout(); - pendingDataTypeLayoutBuilder.endLayout(); - - if( auto pendingDataTypeLayout = pendingDataTypeLayoutBuilder.getTypeLayout() ) - { - typeLayout->pendingDataTypeLayout = pendingDataTypeLayout; - } - - return typeLayoutBuilder.getTypeLayoutResult(); - } - else if (auto globalGenParam = declRef.as()) - { - SimpleLayoutInfo info; - info.alignment = 0; - info.size = 0; - info.kind = LayoutResourceKind::GenericResource; - - auto genParamTypeLayout = new GenericParamTypeLayout(); - // we should have already populated ProgramLayout::genericEntryPointParams list at this point, - // so we can find the index of this generic param decl in the list - genParamTypeLayout->type = type; - genParamTypeLayout->paramIndex = findGenericParam(context.programLayout->globalGenericParams, genParamTypeLayout->getGlobalGenericParamDecl()); - genParamTypeLayout->rules = rules; - genParamTypeLayout->findOrAddResourceInfo(LayoutResourceKind::GenericResource)->count += 1; - - return TypeLayoutResult(genParamTypeLayout, info); - } - else if (auto assocTypeParam = declRef.as()) - { - return createSimpleTypeLayout( - SimpleLayoutInfo(), - type, - rules); - } - else if( auto simpleGenericParam = declRef.as() ) - { - // A bare generic type parameter can come up during layout - // of a generic entry point (or an entry point nested in - // a generic type). For now we will just pretend like - // the fields of generic parameter type take no space, - // since there is no reasonable way to account for them - // in the resulting layout. - // - // TODO: It might be better to completely ignore generic - // entry points during initial layout, but doing so would - // mean that users couldn't get layout information on - // any parameters, even those that don't depend on - // generics. - // - return createSimpleTypeLayout( - SimpleLayoutInfo(), - type, - rules); - } - else if( auto interfaceDeclRef = declRef.as() ) - { - // When laying out a type that includes interface-type fields, - // we cannot know how much space the concrete type that - // gets stored into the field consumes. - // - // If we were doing layout for a typical CPU target, then - // we could just say that each interface-type field consumes - // some fixed number of pointers (e.g., a data pointer plus a witness - // table pointer). - // - // We will borrow the intuition from that and invent a new - // resource kind for "existential slots" which conceptually - // represents the indirections needed to reference the - // data to be referenced by this field. - // - - RefPtr typeLayout = new TypeLayout(); - typeLayout->type = type; - typeLayout->rules = rules; - - typeLayout->addResourceUsage(LayoutResourceKind::ExistentialTypeParam, 1); - typeLayout->addResourceUsage(LayoutResourceKind::ExistentialObjectParam, 1); - - // If there are any concrete types available, the first one will be - // the value that should be plugged into the slot we just introduced. - // - if( context.existentialTypeArgCount ) - { - RefPtr concreteType = context.existentialTypeArgs[0].type; - - RefPtr concreteTypeLayout = createTypeLayout(context, concreteType); - - // Layout for this specialized interface type then results - // in a type layout that tracks both the resource usage of the - // interface type itself (just the type + value slots introduced - // above), plus a "pending data" type that represents the value - // conceptually pointed to by the interface-type field/variable at runtime. - // - typeLayout->pendingDataTypeLayout = concreteTypeLayout; - } - - return TypeLayoutResult(typeLayout, SimpleLayoutInfo()); - } - } - else if (auto errorType = as(type)) - { - // An error type means that we encountered something we don't understand. - // - // We should probably inform the user with an error message here. - - return createSimpleTypeLayout( - SimpleLayoutInfo(), - type, - rules); - } - else if( auto taggedUnionType = as(type) ) - { - // A tagged union type needs to be laid out as the maximum - // size of any constituent type. - // - // In practice, only a tagged union of uniform data will - // work, but for now we will compute the maximum usage - // for each resource kind for generality. - // - // For the uniform data we will start with a size - // of zero and an alignment of one for our base case - // (this is what a tagged union of no cases would consume). - // - UniformLayoutInfo info(0, 1); - - RefPtr taggedUnionLayout = new TaggedUnionTypeLayout(); - taggedUnionLayout->type = type; - taggedUnionLayout->rules = rules; - - // Now we iterate over the case types and see if they - // change our computed maximum size/alignement. - // - for( auto caseType : taggedUnionType->caseTypes ) - { - // Note: A tagged union type is not expected to have any existential/interface type - // slots; the case types that are provided must be fully specialized before the union is - // formed. Thus we don't need to mess around with existential type slots here the - // way we do for the `struct` case. - - auto caseTypeResult = _createTypeLayout(context, caseType); - RefPtr caseTypeLayout = caseTypeResult.layout; - UniformLayoutInfo caseTypeInfo = caseTypeResult.info.getUniformLayout(); - - info.size = maximum(info.size, caseTypeInfo.size); - info.alignment = std::max(info.alignment, caseTypeInfo.alignment); - - // We need to remember the layout of the case type - // on the final `TaggedUnionTypeLayout`. - // - taggedUnionLayout->caseTypeLayouts.add(caseTypeLayout); - - // We also need to consider contributions for other - // resource kinds beyond uniform data. - // - for( auto caseResInfo : caseTypeLayout->resourceInfos ) - { - auto unionResInfo = taggedUnionLayout->findOrAddResourceInfo(caseResInfo.kind); - unionResInfo->count = maximum(unionResInfo->count, caseResInfo.count); - } - } - - // After we've computed the size required to hold all the - // case types, we will allocate space for the tag field. - // - // TODO: This assumes the tag will always be allocated out - // of uniform storage, which means we can't support a tagged - // union as part of a varying input/output signature. That is - // probably a valid limitation, but it should get enforced - // somewhere along the way. - // - { - // The tag is always a `uint` for now. - // - auto tagInfo = context.rules->GetScalarLayout(BaseType::UInt); - info.size = RoundToAlignment(info.size, tagInfo.alignment); - - taggedUnionLayout->tagOffset = info.size; - - info.size += tagInfo.size; - info.alignment = std::max(info.alignment, tagInfo.alignment); - } - - // As a final step, if we are computing a full `TypeLayout` - // we will make sure that its information on uniform layout - // matches what we've computed in the `UniformLayoutInfo` we return. - // - taggedUnionLayout->findOrAddResourceInfo(LayoutResourceKind::Uniform)->count = info.size; - taggedUnionLayout->uniformAlignment = info.alignment; - - return TypeLayoutResult(taggedUnionLayout, info); - } - else if( auto existentialSpecializedType = as(type) ) - { - TypeLayoutContext subContext = context.withExistentialTypeArgs( - existentialSpecializedType->slots.args.getCount(), - existentialSpecializedType->slots.args.getBuffer()); - - auto baseTypeLayoutResult = _createTypeLayout( - subContext, - existentialSpecializedType->baseType); - - UniformLayoutInfo info = rules->BeginStructLayout(); - rules->AddStructField(&info, baseTypeLayoutResult.info.getUniformLayout()); - - RefPtr typeLayout = new ExistentialSpecializedTypeLayout(); - typeLayout->type = type; - typeLayout->rules = rules; - - RefPtr pendingDataVarLayout = new VarLayout(); - if(auto pendingDataTypeLayout = baseTypeLayoutResult.layout->pendingDataTypeLayout) - { - for( auto pendingResInfo : pendingDataTypeLayout->resourceInfos ) - { - auto kind = pendingResInfo.kind; - UInt index = 0; - if( kind == LayoutResourceKind::Uniform ) - { - LayoutSize uniformOffset = rules->AddStructField( - &info, - makeTypeLayoutResult(pendingDataTypeLayout).info.getUniformLayout()); - - index = uniformOffset.getFiniteValue(); - } - else - { - if(auto primaryResInfo = baseTypeLayoutResult.layout->FindResourceInfo(kind)) - index = primaryResInfo->count.getFiniteValue(); - } - pendingDataVarLayout->AddResourceInfo(kind)->index = index; - } - } - - typeLayout->baseTypeLayout = baseTypeLayoutResult.layout; - typeLayout->pendingDataVarLayout = pendingDataVarLayout; - - return makeTypeLayoutResult(typeLayout); - } - - // catch-all case in case nothing matched - SLANG_ASSERT(!"unimplemented case in type layout"); - return createSimpleTypeLayout( - SimpleLayoutInfo(), - type, - rules); -} - -RefPtr getSimpleVaryingParameterTypeLayout( - TypeLayoutContext const& context, - Type* type, - EntryPointParameterDirectionMask directionMask) -{ - auto rules = context.rules; - - // TODO: This logic should ideally share as much - // as possible with the `_createTypeLayout` function, - // to avoid duplication, but we also have to deal - // with the many ways in which varying parameter - // layout differs from non-varying layout. - - // We will compute resource consumption for the type - // as a varying input, output, or both/neither. - // To avoid duplication, we'll build an array that - // includes all the layout rules we need to apply. - // - int varyingRulesCount = 0; - LayoutRulesImpl* varyingRules[2]; - - if( directionMask & kEntryPointParameterDirection_Input ) - { - varyingRules[varyingRulesCount++] = context.getRulesFamily()->getVaryingInputRules(); - } - if( directionMask & kEntryPointParameterDirection_Output ) - { - varyingRules[varyingRulesCount++] = context.getRulesFamily()->getVaryingOutputRules(); - } - - if(auto basicType = as(type)) - { - auto baseType = basicType->baseType; - - RefPtr typeLayout = new TypeLayout(); - typeLayout->type = type; - typeLayout->rules = rules; - - for( int rr = 0; rr < varyingRulesCount; ++rr ) - { - auto info = varyingRules[rr]->GetScalarLayout(baseType); - typeLayout->addResourceUsage(info.kind, info.size); - } - - return typeLayout; - } - else if(auto vecType = as(type)) - { - auto elementType = vecType->elementType; - size_t elementCount = (size_t) GetIntVal(vecType->elementCount); - - BaseType elementBaseType = BaseType::Void; - if( auto elementBasicType = as(elementType) ) - { - elementBaseType = elementBasicType->baseType; - } - - // Note that we do *not* add any resource usage to the type - // layout for the element type, because we currently cannot count - // varying parameter usage at a granularity finer than - // individual "locations." - // - RefPtr elementTypeLayout = new TypeLayout(); - elementTypeLayout->type = elementType; - elementTypeLayout->rules = rules; - - RefPtr typeLayout = new VectorTypeLayout(); - typeLayout->type = vecType; - typeLayout->rules = rules; - typeLayout->elementTypeLayout = elementTypeLayout; - - for( int rr = 0; rr < varyingRulesCount; ++rr ) - { - auto varyingRuleSet = varyingRules[rr]; - auto elementInfo = varyingRuleSet->GetScalarLayout(elementBaseType); - auto info = varyingRuleSet->GetVectorLayout(elementInfo, elementCount); - typeLayout->addResourceUsage(info.kind, info.size); - } - - return typeLayout; - } - else if(auto matType = as(type)) - { - size_t rowCount = (size_t) GetIntVal(matType->getRowCount()); - size_t colCount = (size_t) GetIntVal(matType->getColumnCount()); - auto elementType = matType->getElementType(); - - BaseType elementBaseType = BaseType::Void; - if( auto elementBasicType = as(elementType) ) - { - elementBaseType = elementBasicType->baseType; - } - - // Just as for `_createTypeLayout`, we need to handle row- and - // column-major matrices differently, to ensure we get - // the expected layout. - // - // A varying parameter with row-major layout is effectively - // just an array of row vectors, while a column-major one - // is just an array of column vectors. - // - size_t layoutMajorCount = rowCount; - size_t layoutMinorCount = colCount; - if (context.matrixLayoutMode == kMatrixLayoutMode_ColumnMajor) - { - size_t tmp = layoutMajorCount; - layoutMajorCount = layoutMinorCount; - layoutMinorCount = tmp; - } - - RefPtr elementTypeLayout = new TypeLayout(); - elementTypeLayout->type = elementType; - elementTypeLayout->rules = rules; - - RefPtr rowTypeLayout = new VectorTypeLayout(); - rowTypeLayout->type = matType->getRowType(); - rowTypeLayout->rules = rules; - rowTypeLayout->elementTypeLayout = elementTypeLayout; - - RefPtr typeLayout = new MatrixTypeLayout(); - typeLayout->type = type; - typeLayout->rules = rules; - typeLayout->elementTypeLayout = rowTypeLayout; - typeLayout->mode = context.matrixLayoutMode; - - for( int rr = 0; rr < varyingRulesCount; ++rr ) - { - auto varyingRuleSet = varyingRules[rr]; - auto elementInfo = varyingRuleSet->GetScalarLayout(elementBaseType); - - auto info = varyingRuleSet->GetMatrixLayout(elementInfo, layoutMajorCount, layoutMinorCount); - typeLayout->addResourceUsage(info.kind, info.size); - - if(context.matrixLayoutMode == kMatrixLayoutMode_RowMajor) - { - // For row-major matrices only, we can compute an effective - // resource usage for the row type. - auto rowInfo = varyingRuleSet->GetVectorLayout(elementInfo, colCount); - rowTypeLayout->addResourceUsage(rowInfo.kind, rowInfo.size); - } - } - - return typeLayout; - } - - // catch-all case in case nothing matched - SLANG_ASSERT(!"unimplemented case for varying parameter layout"); - return createSimpleTypeLayout( - SimpleLayoutInfo(), - type, - rules).layout; -} - -RefPtr createTypeLayout( - TypeLayoutContext const& context, - Type* type) -{ - return _createTypeLayout(context, type).layout; -} - -void TypeLayout::addResourceUsageFrom(TypeLayout* otherTypeLayout) -{ - for(auto resInfo : otherTypeLayout->resourceInfos) - addResourceUsage(resInfo); -} - - -RefPtr TypeLayout::unwrapArray() -{ - TypeLayout* typeLayout = this; - - while(auto arrayTypeLayout = as(typeLayout)) - typeLayout = arrayTypeLayout->elementTypeLayout; - - return typeLayout; -} - - -RefPtr GenericParamTypeLayout::getGlobalGenericParamDecl() -{ - auto declRefType = as(type); - SLANG_ASSERT(declRefType); - auto rsDeclRef = declRefType->declRef.as(); - return rsDeclRef.getDecl(); -} - -} // namespace Slang diff --git a/source/slang/type-layout.h b/source/slang/type-layout.h deleted file mode 100644 index c58f92cfb..000000000 --- a/source/slang/type-layout.h +++ /dev/null @@ -1,1118 +0,0 @@ -#ifndef SLANG_TYPE_LAYOUT_H -#define SLANG_TYPE_LAYOUT_H - -#include "../core/basic.h" -#include "compiler.h" -#include "profile.h" -#include "syntax.h" - -#include "../../slang.h" - -namespace Slang { - -// Forward declarations - -enum class BaseType; -class Type; - -// - -enum class LayoutRule -{ - Std140, - Std430, - HLSLConstantBuffer, - HLSLStructuredBuffer, -}; - -#if 0 -enum class LayoutRulesFamily -{ - HLSL, - GLSL, -}; -#endif - -// A "size" that can either be a simple finite size or -// the special case of an infinite/unbounded size. -// -struct LayoutSize -{ - typedef size_t RawValue; - - LayoutSize() - : raw(0) - {} - - LayoutSize(RawValue size) - : raw(size) - { - SLANG_ASSERT(size != RawValue(-1)); - } - - static LayoutSize infinite() - { - LayoutSize result; - result.raw = RawValue(-1); - return result; - } - - bool isInfinite() const { return raw == RawValue(-1); } - - bool isFinite() const { return raw != RawValue(-1); } - RawValue getFiniteValue() const { SLANG_ASSERT(isFinite()); return raw; } - - bool operator==(LayoutSize that) const - { - return raw == that.raw; - } - - bool operator!=(LayoutSize that) const - { - return raw != that.raw; - } - - void operator+=(LayoutSize right) - { - if( isInfinite() ) {} - else if( right.isInfinite() ) - { - *this = LayoutSize::infinite(); - } - else - { - *this = LayoutSize(raw + right.raw); - } - } - - void operator*=(LayoutSize right) - { - // Deal with zero first, so that anything (even the "infinite" value) times zero is zero. - if( raw == 0 ) - { - return; - } - - if( right.raw == 0 ) - { - raw = 0; - return; - } - - // Next we deal with infinite cases, so that infinite times anything non-zero is infinite - if( isInfinite() ) - { - return; - } - - if( right.isInfinite() ) - { - *this = LayoutSize::infinite(); - return; - } - - // Finally deal with the case where both sides are finite - *this = LayoutSize(raw * right.raw); - } - - void operator-=(RawValue right) - { - if( isInfinite() ) {} - else - { - *this = LayoutSize(raw - right); - } - } - - void operator/=(RawValue right) - { - if( isInfinite() ) {} - else - { - *this = LayoutSize(raw / right); - } - } - RawValue raw; -}; - -inline LayoutSize operator+(LayoutSize left, LayoutSize right) -{ - LayoutSize result(left); - result += right; - return result; -} - -inline LayoutSize operator*(LayoutSize left, LayoutSize right) -{ - LayoutSize result(left); - result *= right; - return result; -} - -inline LayoutSize operator-(LayoutSize left, LayoutSize::RawValue right) -{ - LayoutSize result(left); - result -= right; - return result; -} - -inline LayoutSize operator/(LayoutSize left, LayoutSize::RawValue right) -{ - LayoutSize result(left); - result /= right; - return result; -} - -inline LayoutSize maximum(LayoutSize left, LayoutSize right) -{ - if(left.isInfinite() || right.isInfinite()) - return LayoutSize::infinite(); - - return LayoutSize(Math::Max( - left.getFiniteValue(), - right.getFiniteValue())); -} - -inline bool operator>(LayoutSize left, LayoutSize::RawValue right) -{ - return left.isInfinite() || (left.getFiniteValue() > right); -} - -inline bool operator<=(LayoutSize left, LayoutSize::RawValue right) -{ - return left.isFinite() && (left.getFiniteValue() <= right); -} - -// Layout appropriate to "just memory" scenarios, -// such as laying out the members of a constant buffer. -struct UniformLayoutInfo -{ - LayoutSize size; - size_t alignment; - - UniformLayoutInfo() - : size(0) - , alignment(1) - {} - - UniformLayoutInfo( - LayoutSize size, - size_t alignment) - : size(size) - , alignment(alignment) - {} -}; - -// Extended information required for an array of uniform data, -// including the "stride" of the array (the space between -// consecutive elements). -struct UniformArrayLayoutInfo : UniformLayoutInfo -{ - size_t elementStride; - - UniformArrayLayoutInfo() - : elementStride(0) - {} - - UniformArrayLayoutInfo( - LayoutSize size, - size_t alignment, - size_t elementStride) - : UniformLayoutInfo(size, alignment) - , elementStride(elementStride) - {} -}; - -typedef slang::ParameterCategory LayoutResourceKind; - -// Layout information for a value that only consumes -// a single resource kind. -struct SimpleLayoutInfo -{ - // What kind of resource should we consume? - LayoutResourceKind kind; - - // How many resources of that kind? - LayoutSize size; - - // only useful in the uniform case - size_t alignment; - - SimpleLayoutInfo() - : kind(LayoutResourceKind::None) - , size(0) - , alignment(1) - {} - - SimpleLayoutInfo( - UniformLayoutInfo uniformInfo) - : kind(LayoutResourceKind::Uniform) - , size(uniformInfo.size) - , alignment(uniformInfo.alignment) - {} - - SimpleLayoutInfo(LayoutResourceKind kind, LayoutSize size, size_t alignment=1) - : kind(kind) - , size(size) - , alignment(alignment) - {} - - // Convert to layout for uniform data - UniformLayoutInfo getUniformLayout() - { - if(kind == LayoutResourceKind::Uniform) - { - return UniformLayoutInfo(size, alignment); - } - else - { - return UniformLayoutInfo(0, 1); - } - } -}; - -// Only useful in the case of a homogeneous array -struct SimpleArrayLayoutInfo : SimpleLayoutInfo -{ - // This field is only useful in the uniform case - size_t elementStride; - - // Convert to layout for uniform data - UniformArrayLayoutInfo getUniformLayout() - { - if(kind == LayoutResourceKind::Uniform) - { - return UniformArrayLayoutInfo(size, alignment, elementStride); - } - else - { - return UniformArrayLayoutInfo(0, 1, 0); - } - } -}; - -struct LayoutRulesImpl; - -// Base class for things that store layout info -class Layout : public RefObject -{ -}; - -// A reified representation of a particular laid-out type -class TypeLayout : public Layout -{ -public: - // The type that was laid out - RefPtr type; - Type* getType() { return type.Ptr(); } - - // The layout rules that were used to produce this type - LayoutRulesImpl* rules; - - struct ResourceInfo - { - // What kind of register was it? - LayoutResourceKind kind = LayoutResourceKind::None; - - // How many registers of the above kind did we use? - LayoutSize count; - }; - - List resourceInfos; - - // For uniform data, alignment matters, but not for - // any other resource category, so we don't waste - // the space storing it in the above array - UInt uniformAlignment = 1; - - - /// The layout for data that is conceptually owned by this type, but which is pending layout. - /// - /// When a type contains interface/existential fields (recursively), the - /// actual data referenced by these fields needs to get allocated somewhere, - /// but it cannot go inline at the point where the interface/existential - /// type appears, or else the layout of a composite object would change - /// when the concrete type(s) we plug in change. - /// - /// We solve this problem by tracking this data that is "pending" layout, - /// and then "flushing" the pending data at appropriate places during - /// the layout process. - /// - RefPtr pendingDataTypeLayout; - - ResourceInfo* FindResourceInfo(LayoutResourceKind kind) - { - for(auto& rr : resourceInfos) - { - if(rr.kind == kind) - return &rr; - } - return nullptr; - } - - ResourceInfo* findOrAddResourceInfo(LayoutResourceKind kind) - { - auto existing = FindResourceInfo(kind); - if(existing) return existing; - - ResourceInfo info; - info.kind = kind; - info.count = 0; - resourceInfos.add(info); - return &resourceInfos.getLast(); - } - - void addResourceUsage(ResourceInfo info) - { - if(info.count == 0) return; - - findOrAddResourceInfo(info.kind)->count += info.count; - } - - void addResourceUsage(LayoutResourceKind kind, LayoutSize count) - { - ResourceInfo info; - info.kind = kind; - info.count = count; - addResourceUsage(info); - } - - void addResourceUsageFrom(TypeLayout* otherTypeLayout); - - /// "Unwrap" any layers of array-ness from this type layout. - /// - /// If this is an `ArrayTypeLayout`, returns the result of unwrapping the element type layout. - /// Otherwise, returns this type layout. - /// - RefPtr unwrapArray(); -}; - -typedef unsigned int VarLayoutFlags; -enum VarLayoutFlag : VarLayoutFlags -{ - HasSemantic = 1 << 1 -}; - -// A reified layout for a particular variable, field, etc. -class VarLayout : public Layout -{ -public: - // The variable we are laying out - DeclRef varDecl; - VarDeclBase* getVariable() { return varDecl.getDecl(); } - - Name* getName() { return getVariable()->getName(); } - - // The result of laying out the variable's type - RefPtr typeLayout; - TypeLayout* getTypeLayout() { return typeLayout.Ptr(); } - - // Additional flags - VarLayoutFlags flags = 0; - - // System-value semantic (and index) if this is a system value - String systemValueSemantic; - int systemValueSemanticIndex; - - // General case semantic name and index - // TODO: this and the system-value field are redundant - // TODO: the `VarLayout` type is getting bloated; we need to not store this - // information unless actually required. - String semanticName; - int semanticIndex; - - // The stage this variable belongs to, in case it is - // stage-specific. - // TODO: This is wasteful to be storing on every single - // variable layout. - Stage stage = Stage::Unknown; - - // The start register(s) for any resources - struct ResourceInfo - { - // What kind of register was it? - LayoutResourceKind kind = LayoutResourceKind::None; - - // What binding space (HLSL) or set (Vulkan) are we placed in? - UInt space; - - // What is our starting register in that space? - // - // (In the case of uniform data, this is a byte offset) - UInt index; - }; - List resourceInfos; - - ResourceInfo* FindResourceInfo(LayoutResourceKind kind) - { - for(auto& rr : resourceInfos) - { - if(rr.kind == kind) - return &rr; - } - return nullptr; - } - - ResourceInfo* AddResourceInfo(LayoutResourceKind kind) - { - ResourceInfo info; - info.kind = kind; - info.space = 0; - info.index = 0; - - resourceInfos.add(info); - return &resourceInfos.getLast(); - } - - ResourceInfo* findOrAddResourceInfo(LayoutResourceKind kind) - { - auto existing = FindResourceInfo(kind); - if(existing) return existing; - - return AddResourceInfo(kind); - } - - RefPtr pendingVarLayout; -}; - -// type layout for a variable that has a constant-buffer type -class ParameterGroupTypeLayout : public TypeLayout -{ -public: - // The layout of the "container" part itself. - // E.g., for a constant buffer, this would reflect - // the resource usage of the container, without - // the element type factored in. All of the offsets - // for this variable should be zero, but it is included - // for completeness. - RefPtr containerVarLayout; - - // A variable layout for the element of the container. - // The offsets of the variable layout will reflect - // the offsets that need to applied to get past the - // container types resource usage, while the actual - // type layout won't have offsets applied (unlike - // `offsetElementTypeLayout` below). - RefPtr elementVarLayout; - - // The layout of the element type, with offsets applied - // so that any fields (if the element type is a `struct`) - // will be offset by the resource usage of the container. - RefPtr offsetElementTypeLayout; - - // If the element type layout had any "pending" data, then - // as much of that data as possible will be flushed to - // fit into the overall layout of the parameter group. - // - // This field stores the offset information for where - // the pending data got stored relative to the start of - // the group. - // -// RefPtr flushedDataVarLayout; -}; - -// type layout for a variable that has a constant-buffer type -class StructuredBufferTypeLayout : public TypeLayout -{ -public: - RefPtr elementTypeLayout; -}; - - /// Type layout for a logical sequence type -class SequenceTypeLayout : public TypeLayout -{ -public: - /// The layout of the element type. - /// - /// This layout may include adjustments to make lookups in elements - /// of the array Just Work, and may not be the same as the layout - /// of the element type when used in a non-array context. - /// - RefPtr elementTypeLayout; - - /// The stride in bytes between elements. - size_t uniformStride = 0; -}; - - /// Type layout for an array type -class ArrayTypeLayout : public SequenceTypeLayout -{ -public: - /// The original layout of the element type. - /// - /// This layout does not include any adjustments that - /// were made to the element type in order to make - /// lookup into array elements Just Work. - /// - RefPtr originalElementTypeLayout; -}; - -// type layout for a variable with stream-output type -class StreamOutputTypeLayout : public TypeLayout -{ -public: - RefPtr elementTypeLayout; -}; - -class VectorTypeLayout : public SequenceTypeLayout -{ -public: -}; - - -class MatrixTypeLayout : public SequenceTypeLayout -{ -public: - /// Is this matrix laid out as row-major or column-major? - /// - /// Note that this does *not* affect the interpretation - /// of the `elementTypeLayout` field, which always represents - /// the logical elements of the matrix type, which are its - /// rows. - /// - MatrixLayoutMode mode; -}; - -// Specific case of type layout for a struct -class StructTypeLayout : public TypeLayout -{ -public: - // An ordered list of layouts for the known fields - List> fields; - - // Map a variable to its layout directly. - // - // Note that in the general case, there may be entries - // in the `fields` array that came from multiple - // translation units, and in cases where there are - // multiple declarations of the same parameter, only - // one will appear in `fields`, while all of - // them will be reflected in `mapVarToLayout`. - // - // TODO: This should map from a declaration to the *index* - // in the array above, rather than to the actual pointer, - // so that we - Dictionary> mapVarToLayout; - - // As an accellerator for type layouts created at the - // IR layer, we include a second map that use IR "key" - // instructions to map to fields. - // - Dictionary> mapKeyToLayout; -}; - -class GenericParamTypeLayout : public TypeLayout -{ -public: - RefPtr getGlobalGenericParamDecl(); - int paramIndex = 0; -}; - - /// Layout information for a tagged union type. -class TaggedUnionTypeLayout : public TypeLayout -{ -public: - /// The layouts of each of the case types. - /// - /// The order of entries in this array matches - /// the order of case types on the original - /// `TaggedUnionType`, and the index of a case - /// type is also the tag value for that case. - /// - List> caseTypeLayouts; - - /// The byte offset for the tag field. - /// - /// The tag field will always be allocated as - /// a `uint`, so we don't store a separate layout - /// for it. - /// - LayoutSize tagOffset; -}; - - /// Layout information for a type with existential (sub-)field types specialized. -class ExistentialSpecializedTypeLayout : public TypeLayout -{ -public: - RefPtr baseTypeLayout; - RefPtr pendingDataVarLayout; -}; - - /// Layout for a scoped entity like a program, module, or entry point -class ScopeLayout : public Layout -{ -public: - // The layout for the parameters of this entity. - // - RefPtr parametersLayout; -}; - -StructTypeLayout* getScopeStructLayout( - ScopeLayout* programLayout); - -// Layout information for a single shader entry point -// within a program -// -// Treated as a subclass of `StructTypeLayout` because -// it needs to include computed layout information -// for the parameters of the entry point. -// -// TODO: where to store layout info for the return -// type of the function? -class EntryPointLayout : public ScopeLayout -{ -public: - // The corresponding function declaration - RefPtr entryPoint; - - // The shader profile that was used to compile the entry point - Profile profile; - - // Layout for any results of the entry point - RefPtr resultLayout; - - enum Flag : unsigned - { - usesAnySampleRateInput = 0x1, - }; - unsigned flags = 0; - - /// Layouts for all tagged union types required by this entry point. - /// - /// These are any tagged union types used by the generic - /// arguments that this entry point is being compiled with. - List> taggedUnionTypeLayouts; -}; - -class GenericParamLayout : public Layout -{ -public: - RefPtr decl; - int index; -}; - -// Layout information for the global scope of a program -class ProgramLayout : public ScopeLayout -{ -public: - /* - // We store a layout for the declarations at the global - // scope. Note that this will *either* be a single - // `StructTypeLayout` with the fields stored directly, - // or it will be a single `ParameterGroupTypeLayout`, - // where the global-scope fields are the members of - // that constant buffer. - // - // The `struct` case will be used if there are no - // "naked" global-scope uniform variables, and the - // constant-buffer case will be used if there are - // (since a constant buffer will have to be allocated - // to store them). - // - RefPtr globalScopeLayout; - */ - - /// The target and program for which layout was computed - TargetProgram* targetProgram; - - TargetProgram* getTargetProgram() { return targetProgram; } - TargetRequest* getTargetReq() { return targetProgram->getTargetReq(); } - Program* getProgram() { return targetProgram->getProgram(); } - - - // We catalog the requested entry points here, - // and any entry-point-specific parameter data - // will (eventually) belong there... - List> entryPoints; - - List> globalGenericParams; - Dictionary globalGenericParamsMap; -}; - -StructTypeLayout* getGlobalStructLayout( - ProgramLayout* programLayout); - -struct LayoutRulesFamilyImpl; - -// A delineation of shader parameter types into fine-grained -// categories that can then be mapped down to actual resources -// by a given set of rules. -// -// TODO(tfoley): `SlangParameterCategory` and `slang::ParameterCategory` -// are badly named, and need to be revised so they can't be confused -// with this concept. -enum class ShaderParameterKind -{ - ConstantBuffer, - TextureUniformBuffer, - ShaderStorageBuffer, - - StructuredBuffer, - MutableStructuredBuffer, - - RawBuffer, - MutableRawBuffer, - - Buffer, - MutableBuffer, - - Texture, - MutableTexture, - - TextureSampler, - MutableTextureSampler, - - InputRenderTarget, - - SamplerState, - - Image, - MutableImage, - - RegisterSpace, -}; - -struct SimpleLayoutRulesImpl -{ - // Get size and alignment for a single value of base type. - virtual SimpleLayoutInfo GetScalarLayout(BaseType baseType) = 0; - - // Get size and alignment for an array of elements - virtual SimpleArrayLayoutInfo GetArrayLayout(SimpleLayoutInfo elementInfo, LayoutSize elementCount) = 0; - - // Get layout for a vector or matrix type - virtual SimpleLayoutInfo GetVectorLayout(SimpleLayoutInfo elementInfo, size_t elementCount) = 0; - virtual SimpleArrayLayoutInfo GetMatrixLayout(SimpleLayoutInfo elementInfo, size_t rowCount, size_t columnCount) = 0; - - // Begin doing layout on a `struct` type - virtual UniformLayoutInfo BeginStructLayout() = 0; - - // Add a field to a `struct` type, and return the offset for the field - virtual LayoutSize AddStructField(UniformLayoutInfo* ioStructInfo, UniformLayoutInfo fieldInfo) = 0; - - // End layout for a struct, and finalize its size/alignment. - virtual void EndStructLayout(UniformLayoutInfo* ioStructInfo) = 0; -}; - -struct ObjectLayoutRulesImpl -{ - // Compute layout info for an object type - virtual SimpleLayoutInfo GetObjectLayout(ShaderParameterKind kind) = 0; -}; - -struct LayoutRulesImpl -{ - LayoutRulesFamilyImpl* family; - SimpleLayoutRulesImpl* simpleRules; - ObjectLayoutRulesImpl* objectRules; - - // Forward `SimpleLayoutRulesImpl` interface - - SimpleLayoutInfo GetScalarLayout(BaseType baseType) - { - return simpleRules->GetScalarLayout(baseType); - } - - SimpleArrayLayoutInfo GetArrayLayout(SimpleLayoutInfo elementInfo, LayoutSize elementCount) - { - return simpleRules->GetArrayLayout(elementInfo, elementCount); - } - - SimpleLayoutInfo GetVectorLayout(SimpleLayoutInfo elementInfo, size_t elementCount) - { - return simpleRules->GetVectorLayout(elementInfo, elementCount); - } - - SimpleArrayLayoutInfo GetMatrixLayout(SimpleLayoutInfo elementInfo, size_t rowCount, size_t columnCount) - { - return simpleRules->GetMatrixLayout(elementInfo, rowCount, columnCount); - } - - UniformLayoutInfo BeginStructLayout() - { - return simpleRules->BeginStructLayout(); - } - - LayoutSize AddStructField(UniformLayoutInfo* ioStructInfo, UniformLayoutInfo fieldInfo) - { - return simpleRules->AddStructField(ioStructInfo, fieldInfo); - } - - void EndStructLayout(UniformLayoutInfo* ioStructInfo) - { - return simpleRules->EndStructLayout(ioStructInfo); - } - - // Forward `ObjectLayoutRulesImpl` interface - - SimpleLayoutInfo GetObjectLayout(ShaderParameterKind kind) - { - return objectRules->GetObjectLayout(kind); - } - - // - - LayoutRulesFamilyImpl* getLayoutRulesFamily() { return family; } -}; - -struct LayoutRulesFamilyImpl -{ - virtual LayoutRulesImpl* getConstantBufferRules() = 0; - virtual LayoutRulesImpl* getPushConstantBufferRules() = 0; - virtual LayoutRulesImpl* getTextureBufferRules() = 0; - virtual LayoutRulesImpl* getVaryingInputRules() = 0; - virtual LayoutRulesImpl* getVaryingOutputRules() = 0; - virtual LayoutRulesImpl* getSpecializationConstantRules()= 0; - virtual LayoutRulesImpl* getShaderStorageBufferRules() = 0; - virtual LayoutRulesImpl* getParameterBlockRules() = 0; - - virtual LayoutRulesImpl* getRayPayloadParameterRules() = 0; - virtual LayoutRulesImpl* getCallablePayloadParameterRules() = 0; - virtual LayoutRulesImpl* getHitAttributesParameterRules()= 0; - - virtual LayoutRulesImpl* getShaderRecordConstantBufferRules() = 0; -}; - -typedef List> GenericParamLayouts; - -struct TypeLayoutContext -{ - // The layout rules to use (e.g., we compute - // layout differently in a `cbuffer` vs. the - // parameter list of a fragment shader). - LayoutRulesImpl* rules; - - // The target request that is triggering layout - TargetRequest* targetReq; - - // A parent program layout that will establish the ordering - // of all global generic type parameters. - // - ProgramLayout* programLayout; - - // Whether to lay out matrices column-major - // or row-major. - MatrixLayoutMode matrixLayoutMode; - - // The concrete types (if any) to plug into the currently in-scope - // existential type slots. - // - Int existentialTypeArgCount = 0; - ExistentialTypeSlots::Arg const* existentialTypeArgs = nullptr; - - LayoutRulesImpl* getRules() { return rules; } - LayoutRulesFamilyImpl* getRulesFamily() const { return rules->getLayoutRulesFamily(); } - - TypeLayoutContext with(LayoutRulesImpl* inRules) const - { - TypeLayoutContext result = *this; - result.rules = inRules; - return result; - } - - TypeLayoutContext with(MatrixLayoutMode inMatrixLayoutMode) const - { - TypeLayoutContext result = *this; - result.matrixLayoutMode = inMatrixLayoutMode; - return result; - } - - TypeLayoutContext withExistentialTypeArgs( - Int argCount, - ExistentialTypeSlots::Arg const* args) const - { - TypeLayoutContext result = *this; - result.existentialTypeArgCount = argCount; - result.existentialTypeArgs = args; - return result; - } - - TypeLayoutContext withExistentialTypeSlotsOffsetBy( - Int offset) const - { - TypeLayoutContext result = *this; - if( existentialTypeArgCount > offset ) - { - result.existentialTypeArgCount = existentialTypeArgCount - offset; - result.existentialTypeArgs = existentialTypeArgs + offset; - } - else - { - result.existentialTypeArgCount = 0; - result.existentialTypeArgs = nullptr; - } - return result; - - } -}; - -// - - /// A custom tuple to capture the outputs of type layout -struct TypeLayoutResult -{ - /// The actual heap-allocated layout object with all the details - RefPtr layout; - - /// A simplified representation of layout information. - /// - /// This information is suitable for the case where a type only - /// consumes a single resource. - /// - SimpleLayoutInfo info; - - /// Default constructor. - TypeLayoutResult() - {} - - /// Construct a result from the given layout object and simple layout info. - TypeLayoutResult(RefPtr inLayout, SimpleLayoutInfo const& inInfo) - : layout(inLayout) - , info(inInfo) - {} -}; - - /// Helper type for building `struct` type layouts -struct StructTypeLayoutBuilder -{ -public: - /// Begin the layout process for `type`, using `rules` - void beginLayout( - Type* type, - LayoutRulesImpl* rules); - - /// Begin the layout process for `type`, using `rules`, if it hasn't already been begun. - /// - /// This functions allows for a `StructTypeLayoutBuilder` to be use lazily, - /// only allocating a type layout object if it is actaully needed. - /// - void beginLayoutIfNeeded( - Type* type, - LayoutRulesImpl* rules); - - /// Add a field to the struct type layout. - /// - /// One of the `beginLayout*()` functions must have been called previously. - /// - RefPtr addField( - DeclRef field, - TypeLayoutResult fieldResult); - - /// Add a field to the struct type layout. - /// - /// One of the `beginLayout*()` functions must have been called previously. - /// - RefPtr addField( - DeclRef field, - RefPtr fieldTypeLayout); - - /// Complete layout. - /// - /// If layout was begun, ensures that the result of `getTypeLayout()` is usable. - /// If layout was never begin, does nothing. - /// - void endLayout(); - - /// Get the type layout. - /// - /// This can be called any time after `beginLayout*()`. - /// In particular, it can be called before `endLayout`. - /// - RefPtr getTypeLayout(); - - /// The the type layout result. - /// - /// This is primarily useful for implementation code in `_createTypeLayout`. - /// - TypeLayoutResult getTypeLayoutResult(); - -private: - /// The layout rules being used, if layout has begun. - LayoutRulesImpl* m_rules = nullptr; - - /// The type layout being computed, if layout has begun. - RefPtr m_typeLayout; - - /// Uniform offset/alignment statte used when computing offset for uniform fields. - UniformLayoutInfo m_info; -}; - -// - -// Get an appropriate set of layout rules (packaged up -// as a `TypeLayoutContext`) to perform type layout -// for the given target. -// -// The provided `programLayout` is used to establish -// the ordering of all global generic type paramters. -// -TypeLayoutContext getInitialLayoutContextForTarget( - TargetRequest* targetReq, - ProgramLayout* programLayout); - - /// Direction(s) of a varying shader parameter -typedef unsigned int EntryPointParameterDirectionMask; -enum -{ - kEntryPointParameterDirection_Input = 0x1, - kEntryPointParameterDirection_Output = 0x2, -}; - - - /// Get layout information for a simple varying parameter type. - /// - /// A simple varying parameter is a scalar, vector, or matrix. - /// -RefPtr getSimpleVaryingParameterTypeLayout( - TypeLayoutContext const& context, - Type* type, - EntryPointParameterDirectionMask directionMask); - -// Create a full type-layout object for a type, -// according to the layout rules in `context`. -RefPtr createTypeLayout( - TypeLayoutContext const& context, - Type* type); - -// - - /// Create a layout for a parameter-group type (a `ConstantBuffer` or `ParameterBlock`). -RefPtr createParameterGroupTypeLayout( - TypeLayoutContext const& context, - RefPtr parameterGroupType); - - /// Create a wrapper constant buffer type layout, if needed. - /// - /// When dealing with entry-point `uniform` and global-scope parameters, - /// we want to create a wrapper constant buffer for all the parameters - /// if and only if there exist some parameters that use "ordinary" data - /// (`LayoutResourceKind::Uniform`). - /// - /// This function determines whether such a wrapper is needed, based - /// on the `elementTypeLayout` given, and either creates and returns - /// the layout for the wrapper, or the unmodified `elementTypeLayout`. - /// -RefPtr createConstantBufferTypeLayoutIfNeeded( - TypeLayoutContext const& context, - RefPtr elementTypeLayout); - -// Create a type layout for a structured buffer type. -RefPtr -createStructuredBufferTypeLayout( - TypeLayoutContext const& context, - ShaderParameterKind kind, - RefPtr structuredBufferType, - RefPtr elementType); - -int findGenericParam(List> & genericParameters, GlobalGenericParamDecl * decl); -// - -// Given an existing type layout `oldTypeLayout`, apply offsets -// to any contained fields based on the resource infos in `offsetVarLayout`. -RefPtr applyOffsetToTypeLayout( - RefPtr oldTypeLayout, - RefPtr offsetVarLayout); - -} - -#endif diff --git a/source/slang/type-system-shared.cpp b/source/slang/type-system-shared.cpp deleted file mode 100644 index 10ebaee24..000000000 --- a/source/slang/type-system-shared.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "type-system-shared.h" - -namespace Slang -{ - TextureFlavor TextureFlavor::create(SlangResourceShape shape, SlangResourceAccess access) - { - TextureFlavor rs; - rs.flavor = uint16_t(shape | (access << 8)); - return rs; - } -} diff --git a/source/slang/type-system-shared.h b/source/slang/type-system-shared.h deleted file mode 100644 index 95840e701..000000000 --- a/source/slang/type-system-shared.h +++ /dev/null @@ -1,102 +0,0 @@ -#ifndef SLANG_TYPE_SYSTEM_SHARED_H -#define SLANG_TYPE_SYSTEM_SHARED_H - -#include "../../slang.h" - -namespace Slang -{ -#define FOREACH_BASE_TYPE(X) \ - X(Void) \ - X(Bool) \ - X(Int8) \ - X(Int16) \ - X(Int) \ - X(Int64) \ - X(UInt8) \ - X(UInt16) \ - X(UInt) \ - X(UInt64) \ - X(Half) \ - X(Float) \ - X(Double) \ -/* end */ - - enum class BaseType - { -#define DEFINE_BASE_TYPE(NAME) NAME, -FOREACH_BASE_TYPE(DEFINE_BASE_TYPE) -#undef DEFINE_BASE_TYPE - - CountOf, - }; - - struct TextureFlavor - { - typedef TextureFlavor ThisType; - enum - { - // Mask for the overall "shape" of the texture - BaseShapeMask = SLANG_RESOURCE_BASE_SHAPE_MASK, - - // Flag for whether the shape has "array-ness" - ArrayFlag = SLANG_TEXTURE_ARRAY_FLAG, - - // Whether or not the texture stores multiple samples per pixel - MultisampleFlag = SLANG_TEXTURE_MULTISAMPLE_FLAG, - - // Whether or not this is a shadow texture - // - // TODO(tfoley): is this even meaningful/used? - // ShadowFlag = 0x80, - }; - - enum Shape : uint8_t - { - Shape1D = SLANG_TEXTURE_1D, - Shape2D = SLANG_TEXTURE_2D, - Shape3D = SLANG_TEXTURE_3D, - ShapeCube = SLANG_TEXTURE_CUBE, - ShapeBuffer = SLANG_TEXTURE_BUFFER, - - Shape1DArray = Shape1D | ArrayFlag, - Shape2DArray = Shape2D | ArrayFlag, - // No Shape3DArray - ShapeCubeArray = ShapeCube | ArrayFlag, - }; - - enum - { - // This the total number of expressible flavors, - // which is *not* to say that every expressible - // flavor is actual valid. - Count = 0x10000, - }; - - uint16_t flavor; - - Shape GetBaseShape() const { return Shape(flavor & BaseShapeMask); } - bool isArray() const { return (flavor & ArrayFlag) != 0; } - bool isMultisample() const { return (flavor & MultisampleFlag) != 0; } - // bool isShadow() const { return (flavor & ShadowFlag) != 0; } - - SLANG_FORCE_INLINE bool operator==(const ThisType& rhs) const { return flavor == rhs.flavor; } - SLANG_FORCE_INLINE bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } - - SlangResourceShape getShape() const { return flavor & 0xFF; } - SlangResourceAccess getAccess() const { return (flavor >> 8) & 0xFF; } - - TextureFlavor() = default; - TextureFlavor(uint32_t tag) { flavor = (uint16_t)tag; } - - static TextureFlavor create(SlangResourceShape shape, SlangResourceAccess access); - }; - - enum class SamplerStateFlavor : uint8_t - { - SamplerState, - SamplerComparisonState, - }; - -} - -#endif diff --git a/source/slang/val-defs.h b/source/slang/val-defs.h deleted file mode 100644 index f5b099079..000000000 --- a/source/slang/val-defs.h +++ /dev/null @@ -1,155 +0,0 @@ -// val-defs.h - -// Syntax class definitions for compile-time values. - -// A compile-time integer (may not have a specific concrete value) -ABSTRACT_SYNTAX_CLASS(IntVal, Val) -END_SYNTAX_CLASS() - -// Trivial case of a value that is just a constant integer -SYNTAX_CLASS(ConstantIntVal, IntVal) - FIELD(IntegerLiteralValue, value) - - RAW( - ConstantIntVal() - {} - ConstantIntVal(IntegerLiteralValue value) - : value(value) - {} - - virtual bool EqualsVal(Val* val) override; - virtual String ToString() override; - virtual int GetHashCode() override; - ) -END_SYNTAX_CLASS() - -// The logical "value" of a rererence to a generic value parameter -SYNTAX_CLASS(GenericParamIntVal, IntVal) - DECL_FIELD(DeclRef, declRef) - - RAW( - GenericParamIntVal() - {} - GenericParamIntVal(DeclRef declRef) - : declRef(declRef) - {} - - virtual bool EqualsVal(Val* val) override; - virtual String ToString() override; - virtual int GetHashCode() override; - virtual RefPtr SubstituteImpl(SubstitutionSet subst, int* ioDiff) override; -) -END_SYNTAX_CLASS() - -// A witness to the fact that some proposition is true, encoded -// at the level of the type system. -// -// Given a generic like: -// -// void example(L light) -// where L : ILight -// { ... } -// -// a call to `example()` needs two things for us to be sure -// it is valid: -// -// 1. We need a type `X` to use as the argument for the -// parameter `L`. We might supply this explicitly, or -// via inference. -// -// 2. We need a *proof* that whatever `X` we chose conforms -// to the `ILight` interface. -// -// The easiest way to make such a proof is by construction, -// and a `Witness` represents such a constructive proof. -// Conceptually a proposition like `X : ILight` can be -// seen as a type, and witness prooving that proposition -// is a value of that type. -// -// We construct and store witnesses explicitly during -// semantic checking because they can help us with -// generating downstream code. By following the structure -// of a witness (the structure of a proof) we can, e.g., -// navigate from the knowledge that `X : ILight` to -// the concrete declarations that provide the implementation -// of `ILight` for `X`. -// -ABSTRACT_SYNTAX_CLASS(Witness, Val) -END_SYNTAX_CLASS() - -// A witness that one type is a subtype of another -// (where by "subtype" we include both inheritance -// relationships and type-conforms-to-interface relationships) -// -// TODO: we may need to tease those apart. -ABSTRACT_SYNTAX_CLASS(SubtypeWitness, Witness) - FIELD(RefPtr, sub) - FIELD(RefPtr, sup) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(TypeEqualityWitness, SubtypeWitness) -RAW( - virtual bool EqualsVal(Val* val) override; - virtual String ToString() override; - virtual int GetHashCode() override; - virtual RefPtr SubstituteImpl(SubstitutionSet subst, int * ioDiff) override; -) -END_SYNTAX_CLASS() -// A witness that one type is a subtype of another -// because some in-scope declaration says so -SYNTAX_CLASS(DeclaredSubtypeWitness, SubtypeWitness) - FIELD(DeclRef, declRef); -RAW( - virtual bool EqualsVal(Val* val) override; - virtual String ToString() override; - virtual int GetHashCode() override; - virtual RefPtr SubstituteImpl(SubstitutionSet subst, int * ioDiff) override; -) -END_SYNTAX_CLASS() - -// A witness that `sub : sup` because `sub : mid` and `mid : sup` -SYNTAX_CLASS(TransitiveSubtypeWitness, SubtypeWitness) - // Witness that `sub : mid` - FIELD(RefPtr, subToMid); - - // Witness that `mid : sup` - FIELD(DeclRef, midToSup); -RAW( - virtual bool EqualsVal(Val* val) override; - virtual String ToString() override; - virtual int GetHashCode() override; - virtual RefPtr SubstituteImpl(SubstitutionSet subst, int * ioDiff) override; -) -END_SYNTAX_CLASS() - -// A witness taht `sub : sup` because `sub` was wrapped into -// an existential of type `sup`. -SYNTAX_CLASS(ExtractExistentialSubtypeWitness, SubtypeWitness) -RAW( - // The declaration of the existential value that has been opened - DeclRef declRef; - - virtual bool EqualsVal(Val* val) override; - virtual String ToString() override; - virtual int GetHashCode() override; - virtual RefPtr SubstituteImpl(SubstitutionSet subst, int * ioDiff) override; -) -END_SYNTAX_CLASS() - -// A witness that `sub : sup`, because `sub` is a tagged union -// of the form `A | B | C | ...` and each of `A : sup`, -// `B : sup`, `C : sup`, etc. -// -SYNTAX_CLASS(TaggedUnionSubtypeWitness, SubtypeWitness) -RAW( - // Witnesses that each of the "case" types in the union - // is a subtype of `sup`. - // - List> caseWitnesses; - - virtual bool EqualsVal(Val* val) override; - virtual String ToString() override; - virtual int GetHashCode() override; - virtual RefPtr SubstituteImpl(SubstitutionSet subst, int * ioDiff) override; -) -END_SYNTAX_CLASS() diff --git a/source/slang/visitor.h b/source/slang/visitor.h deleted file mode 100644 index 8a0301782..000000000 --- a/source/slang/visitor.h +++ /dev/null @@ -1,535 +0,0 @@ -// visitor.h -#ifndef SLANG_VISITOR_H_INCLUDED -#define SLANG_VISITOR_H_INCLUDED - -// This file defines the basic "Visitor" pattern for doing dispatch -// over the various categories of syntax node. - -#include "syntax.h" - -namespace Slang { - -// -// type Visitors -// - -struct ITypeVisitor -{ -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ -#define SYNTAX_CLASS(NAME, BASE) \ - virtual void dispatch_##NAME(NAME* obj, void* extra) = 0; - -#include "object-meta-begin.h" -#include "type-defs.h" -#include "object-meta-end.h" -}; - -template -struct TypeVisitor : Base -{ - Result dispatch(Type* type) - { - Result result; - type->accept(this, &result); - return result; - } - - Result dispatchType(Type* type) - { - Result result; - type->accept(this, &result); - return result; - } - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ -#define SYNTAX_CLASS(NAME, BASE) \ - virtual void dispatch_##NAME(NAME* obj, void* extra) override \ - { *(Result*)extra = ((Derived*) this)->visit##NAME(obj); } - -#include "object-meta-begin.h" -#include "type-defs.h" -#include "object-meta-end.h" - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) -#define SYNTAX_CLASS(NAME, BASE) \ - Result visit##NAME(NAME* obj) \ - { return ((Derived*) this)->visit##BASE(obj); } - -#include "object-meta-begin.h" -#include "type-defs.h" -#include "object-meta-end.h" -}; - -template -struct TypeVisitor : Base -{ - void dispatch(Type* type) - { - type->accept(this, 0); - } - - void dispatchType(Type* type) - { - type->accept(this, 0); - } - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ -#define SYNTAX_CLASS(NAME, BASE) \ - virtual void dispatch_##NAME(NAME* obj, void*) override \ - { ((Derived*) this)->visit##NAME(obj); } - -#include "object-meta-begin.h" -#include "type-defs.h" -#include "object-meta-end.h" - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) -#define SYNTAX_CLASS(NAME, BASE) \ - void visit##NAME(NAME* obj) \ - { ((Derived*) this)->visit##BASE(obj); } - -#include "object-meta-begin.h" -#include "type-defs.h" -#include "object-meta-end.h" -}; - -template -struct TypeVisitorWithArg : Base -{ - void dispatch(Type* type, Arg const& arg) - { - type->accept(this, (void*)&arg); - } - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ -#define SYNTAX_CLASS(NAME, BASE) \ - virtual void dispatch_##NAME(NAME* obj, void* arg) override \ - { ((Derived*) this)->visit##NAME(obj, *(Arg*)arg); } - -#include "object-meta-begin.h" -#include "type-defs.h" -#include "object-meta-end.h" - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) -#define SYNTAX_CLASS(NAME, BASE) \ - void visit##NAME(NAME* obj, Arg const& arg) \ - { ((Derived*) this)->visit##BASE(obj, arg); } - -#include "object-meta-begin.h" -#include "type-defs.h" -#include "object-meta-end.h" -}; - -// -// Expression Visitors -// - -struct IExprVisitor -{ -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ -#define SYNTAX_CLASS(NAME, BASE) \ - virtual void dispatch_##NAME(NAME* obj, void* extra) = 0; - -#include "object-meta-begin.h" -#include "expr-defs.h" -#include "object-meta-end.h" -}; - -template -struct ExprVisitor : IExprVisitor -{ - Result dispatch(Expr* expr) - { - Result result; - expr->accept(this, &result); - return result; - } - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ -#define SYNTAX_CLASS(NAME, BASE) \ - virtual void dispatch_##NAME(NAME* obj, void* extra) override \ - { *(Result*)extra = ((Derived*) this)->visit##NAME(obj); } - -#include "object-meta-begin.h" -#include "expr-defs.h" -#include "object-meta-end.h" - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) -#define SYNTAX_CLASS(NAME, BASE) \ - Result visit##NAME(NAME* obj) \ - { return ((Derived*) this)->visit##BASE(obj); } - -#include "object-meta-begin.h" -#include "expr-defs.h" -#include "object-meta-end.h" -}; - -template -struct ExprVisitor : IExprVisitor -{ - void dispatch(Expr* expr) - { - expr->accept(this, 0); - } - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ -#define SYNTAX_CLASS(NAME, BASE) \ - virtual void dispatch_##NAME(NAME* obj, void*) override \ - { ((Derived*) this)->visit##NAME(obj); } - -#include "object-meta-begin.h" -#include "expr-defs.h" -#include "object-meta-end.h" - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) -#define SYNTAX_CLASS(NAME, BASE) \ - void visit##NAME(NAME* obj) \ - { ((Derived*) this)->visit##BASE(obj); } - -#include "object-meta-begin.h" -#include "expr-defs.h" -#include "object-meta-end.h" -}; - -template -struct ExprVisitorWithArg : IExprVisitor -{ - void dispatch(Expr* obj, Arg const& arg) - { - obj->accept(this, (void*)&arg); - } - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ -#define SYNTAX_CLASS(NAME, BASE) \ - virtual void dispatch_##NAME(NAME* obj, void* arg) override \ - { ((Derived*) this)->visit##NAME(obj, *(Arg*)arg); } - -#include "object-meta-begin.h" -#include "expr-defs.h" -#include "object-meta-end.h" - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) -#define SYNTAX_CLASS(NAME, BASE) \ - void visit##NAME(NAME* obj, Arg const& arg) \ - { ((Derived*) this)->visit##BASE(obj, arg); } - -#include "object-meta-begin.h" -#include "expr-defs.h" -#include "object-meta-end.h" -}; - -// -// Statement Visitors -// - -struct IStmtVisitor -{ -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ -#define SYNTAX_CLASS(NAME, BASE) \ - virtual void dispatch_##NAME(NAME* obj, void* extra) = 0; - -#include "object-meta-begin.h" -#include "stmt-defs.h" -#include "object-meta-end.h" -}; - -template -struct StmtVisitor : IStmtVisitor -{ - Result dispatch(Stmt* stmt) - { - Result result; - stmt->accept(this, &result); - return result; - } - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ -#define SYNTAX_CLASS(NAME, BASE) \ - virtual void dispatch_##NAME(NAME* obj, void* extra) override \ - { *(Result*)extra = ((Derived*) this)->visit##NAME(obj); } - -#include "object-meta-begin.h" -#include "stmt-defs.h" -#include "object-meta-end.h" - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) -#define SYNTAX_CLASS(NAME, BASE) \ - Result visit##NAME(NAME* obj) \ - { return ((Derived*) this)->visit##BASE(obj); } - -#include "object-meta-begin.h" -#include "stmt-defs.h" -#include "object-meta-end.h" -}; - -template -struct StmtVisitor : IStmtVisitor -{ - void dispatch(Stmt* stmt) - { - stmt->accept(this, 0); - } - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ -#define SYNTAX_CLASS(NAME, BASE) \ - virtual void dispatch_##NAME(NAME* obj, void*) override \ - { ((Derived*) this)->visit##NAME(obj); } - -#include "object-meta-begin.h" -#include "stmt-defs.h" -#include "object-meta-end.h" - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) -#define SYNTAX_CLASS(NAME, BASE) \ - void visit##NAME(NAME* obj) \ - { ((Derived*) this)->visit##BASE(obj); } - -#include "object-meta-begin.h" -#include "stmt-defs.h" -#include "object-meta-end.h" -}; - -// -// Declaration Visitors -// - -struct IDeclVisitor -{ -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ -#define SYNTAX_CLASS(NAME, BASE) \ - virtual void dispatch_##NAME(NAME* obj, void* extra) = 0; - -#include "object-meta-begin.h" -#include "decl-defs.h" -#include "object-meta-end.h" -}; - -template -struct DeclVisitor : IDeclVisitor -{ - Result dispatch(DeclBase* decl) - { - Result result; - decl->accept(this, &result); - return result; - } - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ -#define SYNTAX_CLASS(NAME, BASE) \ - virtual void dispatch_##NAME(NAME* obj, void* extra) override \ - { *(Result*)extra = ((Derived*) this)->visit##NAME(obj); } - -#include "object-meta-begin.h" -#include "decl-defs.h" -#include "object-meta-end.h" - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) -#define SYNTAX_CLASS(NAME, BASE) \ - Result visit##NAME(NAME* obj) \ - { return ((Derived*) this)->visit##BASE(obj); } - -#include "object-meta-begin.h" -#include "decl-defs.h" -#include "object-meta-end.h" -}; - -template -struct DeclVisitor : IDeclVisitor -{ - void dispatch(DeclBase* decl) - { - decl->accept(this, 0); - } - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ -#define SYNTAX_CLASS(NAME, BASE) \ - virtual void dispatch_##NAME(NAME* obj, void*) override \ - { ((Derived*) this)->visit##NAME(obj); } - -#include "object-meta-begin.h" -#include "decl-defs.h" -#include "object-meta-end.h" - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) -#define SYNTAX_CLASS(NAME, BASE) \ - void visit##NAME(NAME* obj) \ - { ((Derived*) this)->visit##BASE(obj); } - -#include "object-meta-begin.h" -#include "decl-defs.h" -#include "object-meta-end.h" -}; - -template -struct DeclVisitorWithArg : IDeclVisitor -{ - void dispatch(DeclBase* obj, Arg const& arg) - { - obj->accept(this, (void*)&arg); - } - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ -#define SYNTAX_CLASS(NAME, BASE) \ - virtual void dispatch_##NAME(NAME* obj, void* arg) override \ - { ((Derived*) this)->visit##NAME(obj, *(Arg*)arg); } - -#include "object-meta-begin.h" -#include "decl-defs.h" -#include "object-meta-end.h" - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) -#define SYNTAX_CLASS(NAME, BASE) \ - void visit##NAME(NAME* obj, Arg const& arg) \ - { ((Derived*) this)->visit##BASE(obj, arg); } - -#include "object-meta-begin.h" -#include "decl-defs.h" -#include "object-meta-end.h" -}; - - -// -// Modifier Visitors -// - -struct IModifierVisitor -{ -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ -#define SYNTAX_CLASS(NAME, BASE) \ - virtual void dispatch_##NAME(NAME* obj, void* extra) = 0; - -#include "object-meta-begin.h" -#include "modifier-defs.h" -#include "object-meta-end.h" -}; - -template -struct ModifierVisitor : IModifierVisitor -{ - Result dispatch(Modifier* modifier) - { - Result result; - modifier->accept(this, &result); - return result; - } - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ -#define SYNTAX_CLASS(NAME, BASE) \ - virtual void dispatch_##NAME(NAME* obj, void* extra) override \ - { *(Result*)extra = ((Derived*) this)->visit##NAME(obj); } - -#include "object-meta-begin.h" -#include "modifier-defs.h" -#include "object-meta-end.h" - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) -#define SYNTAX_CLASS(NAME, BASE) \ - Result visit##NAME(NAME* obj) \ - { return ((Derived*) this)->visit##BASE(obj); } - -#include "object-meta-begin.h" -#include "modifier-defs.h" -#include "object-meta-end.h" -}; - -template -struct ModifierVisitor : IModifierVisitor -{ - void dispatch(Modifier* modifier) - { - modifier->accept(this, 0); - } - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ -#define SYNTAX_CLASS(NAME, BASE) \ - virtual void dispatch_##NAME(NAME* obj, void*) override \ - { ((Derived*) this)->visit##NAME(obj); } - -#include "object-meta-begin.h" -#include "modifier-defs.h" -#include "object-meta-end.h" - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) -#define SYNTAX_CLASS(NAME, BASE) \ - void visit##NAME(NAME* obj) \ - { ((Derived*) this)->visit##BASE(obj); } - -#include "object-meta-begin.h" -#include "modifier-defs.h" -#include "object-meta-end.h" -}; - -// -// Val Visitors -// - -struct IValVisitor : ITypeVisitor -{ -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ -#define SYNTAX_CLASS(NAME, BASE) \ - virtual void dispatch_##NAME(NAME* obj, void* extra) = 0; - -#include "object-meta-begin.h" -#include "val-defs.h" -#include "object-meta-end.h" -}; - -template -struct ValVisitor : TypeVisitor -{ - Result dispatch(Val* val) - { - Result result; - val->accept(this, &result); - return result; - } - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ -#define SYNTAX_CLASS(NAME, BASE) \ - virtual void dispatch_##NAME(NAME* obj, void* extra) override \ - { *(Result*)extra = ((Derived*) this)->visit##NAME(obj); } - -#include "object-meta-begin.h" -#include "val-defs.h" -#include "object-meta-end.h" - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) -#define SYNTAX_CLASS(NAME, BASE) \ - Result visit##NAME(NAME* obj) \ - { return ((Derived*) this)->visit##BASE(obj); } - -#include "object-meta-begin.h" -#include "val-defs.h" -#include "object-meta-end.h" -}; - -template -struct ValVisitor : TypeVisitor -{ - void dispatch(Val* val) - { - val->accept(this, 0); - } - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) /* empty */ -#define SYNTAX_CLASS(NAME, BASE) \ - virtual void dispatch_##NAME(NAME* obj, void*) override \ - { ((Derived*) this)->visit##NAME(obj); } - -#include "object-meta-begin.h" -#include "val-defs.h" -#include "object-meta-end.h" - -#define ABSTRACT_SYNTAX_CLASS(NAME,BASE) SYNTAX_CLASS(NAME, BASE) -#define SYNTAX_CLASS(NAME, BASE) \ - void visit##NAME(NAME* obj) \ - { ((Derived*) this)->visit##BASE(obj); } - -#include "object-meta-begin.h" -#include "val-defs.h" -#include "object-meta-end.h" - -}; - -} - -#endif \ No newline at end of file diff --git a/tools/gfx/circular-resource-heap-d3d12.h b/tools/gfx/circular-resource-heap-d3d12.h index cca981601..bf9f412cf 100644 --- a/tools/gfx/circular-resource-heap-d3d12.h +++ b/tools/gfx/circular-resource-heap-d3d12.h @@ -1,7 +1,7 @@ -#pragma once +#pragma once #include "../../slang-com-ptr.h" -#include "../../source/core/list.h" +#include "../../source/core/slang-list.h" #include "../../source/core/slang-free-list.h" #include "resource-d3d12.h" diff --git a/tools/gfx/d3d-util.h b/tools/gfx/d3d-util.h index 0c05bed46..6bcee054c 100644 --- a/tools/gfx/d3d-util.h +++ b/tools/gfx/d3d-util.h @@ -1,4 +1,4 @@ -// d3d-util.h +// d3d-util.h #pragma once #include @@ -6,7 +6,7 @@ #include "../../slang-com-helper.h" #include "../../slang-com-ptr.h" -#include "../../source/core/list.h" +#include "../../source/core/slang-list.h" #include "flag-combiner.h" diff --git a/tools/gfx/descriptor-heap-d3d12.h b/tools/gfx/descriptor-heap-d3d12.h index 638c1f752..a546395d8 100644 --- a/tools/gfx/descriptor-heap-d3d12.h +++ b/tools/gfx/descriptor-heap-d3d12.h @@ -5,7 +5,7 @@ #include #include "../../slang-com-ptr.h" -#include "../../source/core/list.h" +#include "../../source/core/slang-list.h" namespace gfx { diff --git a/tools/gfx/flag-combiner.h b/tools/gfx/flag-combiner.h index 83962d2dd..db8c6863b 100644 --- a/tools/gfx/flag-combiner.h +++ b/tools/gfx/flag-combiner.h @@ -1,7 +1,7 @@ #ifndef GFX_FLAG_COMBINER_H #define GFX_FLAG_COMBINER_H -#include "../../source/core/list.h" +#include "../../source/core/slang-list.h" namespace gfx { diff --git a/tools/gfx/render-gl.cpp b/tools/gfx/render-gl.cpp index c20eb7e6d..3249b8620 100644 --- a/tools/gfx/render-gl.cpp +++ b/tools/gfx/render-gl.cpp @@ -6,8 +6,8 @@ #include #include -#include "core/basic.h" -#include "core/secure-crt.h" +#include "core/slang-basic.h" +#include "core/slang-secure-crt.h" #include "external/stb/stb_image_write.h" #include "surface.h" diff --git a/tools/gfx/render-vk.cpp b/tools/gfx/render-vk.cpp index c76a8e42d..77b593565 100644 --- a/tools/gfx/render-vk.cpp +++ b/tools/gfx/render-vk.cpp @@ -4,7 +4,7 @@ //WORKING:#include "options.h" #include "render.h" -#include "../../source/core/smart-pointer.h" +#include "../../source/core/slang-smart-pointer.h" #include "vk-api.h" #include "vk-util.h" diff --git a/tools/gfx/render.h b/tools/gfx/render.h index 292b4f8f8..247932bd5 100644 --- a/tools/gfx/render.h +++ b/tools/gfx/render.h @@ -9,9 +9,9 @@ #include "../../slang-com-helper.h" -#include "../../source/core/smart-pointer.h" -#include "../../source/core/list.h" -#include "../../source/core/dictionary.h" +#include "../../source/core/slang-smart-pointer.h" +#include "../../source/core/slang-list.h" +#include "../../source/core/slang-dictionary.h" #include "../../slang.h" diff --git a/tools/gfx/surface.cpp b/tools/gfx/surface.cpp index 63fd7087c..28fe744de 100644 --- a/tools/gfx/surface.cpp +++ b/tools/gfx/surface.cpp @@ -4,7 +4,7 @@ #include #include -#include "../../source/core/list.h" +#include "../../source/core/slang-list.h" namespace gfx { using namespace Slang; diff --git a/tools/gfx/vk-api.cpp b/tools/gfx/vk-api.cpp index 304513d24..50f80aa26 100644 --- a/tools/gfx/vk-api.cpp +++ b/tools/gfx/vk-api.cpp @@ -1,7 +1,7 @@ // vk-api.cpp #include "vk-api.h" -#include "../../source/core/list.h" +#include "../../source/core/slang-list.h" namespace gfx { using namespace Slang; diff --git a/tools/gfx/vk-swap-chain.cpp b/tools/gfx/vk-swap-chain.cpp index bde68c413..5cf2e96ae 100644 --- a/tools/gfx/vk-swap-chain.cpp +++ b/tools/gfx/vk-swap-chain.cpp @@ -3,7 +3,7 @@ #include "vk-util.h" -#include "../../source/core/list.h" +#include "../../source/core/slang-list.h" #include #include diff --git a/tools/gfx/vk-swap-chain.h b/tools/gfx/vk-swap-chain.h index ad2357315..f8ad98a83 100644 --- a/tools/gfx/vk-swap-chain.h +++ b/tools/gfx/vk-swap-chain.h @@ -6,7 +6,7 @@ #include "render.h" -#include "../../source/core/list.h" +#include "../../source/core/slang-list.h" namespace gfx { diff --git a/tools/render-test/options.cpp b/tools/render-test/options.cpp index 17ba864f7..9423b5b6e 100644 --- a/tools/render-test/options.cpp +++ b/tools/render-test/options.cpp @@ -9,7 +9,7 @@ #include "../../source/core/slang-writer.h" #include "../../source/core/slang-render-api-util.h" -#include "../../source/core/list.h" +#include "../../source/core/slang-list.h" #include "../../source/core/slang-string-util.h" namespace renderer_test { diff --git a/tools/render-test/shader-input-layout.cpp b/tools/render-test/shader-input-layout.cpp index 644b0889e..8205c979e 100644 --- a/tools/render-test/shader-input-layout.cpp +++ b/tools/render-test/shader-input-layout.cpp @@ -1,5 +1,5 @@ #include "shader-input-layout.h" -#include "core/token-reader.h" +#include "core/slang-token-reader.h" #include "render.h" diff --git a/tools/render-test/shader-input-layout.h b/tools/render-test/shader-input-layout.h index d9188fadd..d5a1b6fd5 100644 --- a/tools/render-test/shader-input-layout.h +++ b/tools/render-test/shader-input-layout.h @@ -1,7 +1,7 @@ #ifndef SLANG_TEST_SHADER_INPUT_LAYOUT_H #define SLANG_TEST_SHADER_INPUT_LAYOUT_H -#include "core/basic.h" +#include "core/slang-basic.h" #include "render.h" diff --git a/tools/slang-generate/main.cpp b/tools/slang-generate/main.cpp index ed5af370b..ddd087072 100644 --- a/tools/slang-generate/main.cpp +++ b/tools/slang-generate/main.cpp @@ -3,9 +3,9 @@ #include #include #include -#include "../../source/core/secure-crt.h" +#include "../../source/core/slang-secure-crt.h" -#include "../../source/core/list.h" +#include "../../source/core/slang-list.h" #include "../../source/core/slang-string.h" using namespace Slang; diff --git a/tools/slang-test/options.h b/tools/slang-test/options.h index 48c06463c..c78952625 100644 --- a/tools/slang-test/options.h +++ b/tools/slang-test/options.h @@ -3,11 +3,11 @@ #ifndef OPTIONS_H_INCLUDED #define OPTIONS_H_INCLUDED -#include "../../source/core/dictionary.h" +#include "../../source/core/slang-dictionary.h" #include "test-reporter.h" #include "../../source/core/slang-render-api-util.h" -#include "../../source/core/smart-pointer.h" +#include "../../source/core/slang-smart-pointer.h" // A category that a test can be tagged with struct TestCategory: public Slang::RefObject diff --git a/tools/slang-test/slang-test-main.cpp b/tools/slang-test/slang-test-main.cpp index b7de3082f..f030dc8aa 100644 --- a/tools/slang-test/slang-test-main.cpp +++ b/tools/slang-test/slang-test-main.cpp @@ -1,7 +1,7 @@ // slang-test-main.cpp #include "../../source/core/slang-io.h" -#include "../../source/core/token-reader.h" +#include "../../source/core/slang-token-reader.h" #include "../../source/core/slang-std-writers.h" #include "../../slang-com-helper.h" diff --git a/tools/slang-test/slangc-tool.cpp b/tools/slang-test/slangc-tool.cpp index 40fdf05dd..2a30b10e9 100644 --- a/tools/slang-test/slangc-tool.cpp +++ b/tools/slang-test/slangc-tool.cpp @@ -1,7 +1,7 @@ // test-context.cpp #include "slangc-tool.h" -#include "../../source/core/exception.h" +#include "../../source/core/slang-exception.h" using namespace Slang; diff --git a/tools/slang-test/test-context.h b/tools/slang-test/test-context.h index 95b46fe6e..afc5bb427 100644 --- a/tools/slang-test/test-context.h +++ b/tools/slang-test/test-context.h @@ -4,9 +4,9 @@ #define TEST_CONTEXT_H_INCLUDED #include "../../source/core/slang-string-util.h" -#include "../../source/core/platform.h" +#include "../../source/core/slang-platform.h" #include "../../source/core/slang-std-writers.h" -#include "../../source/core/dictionary.h" +#include "../../source/core/slang-dictionary.h" #include "../../source/core/slang-test-tool-util.h" #include "../../source/core/slang-render-api-util.h" diff --git a/tools/slang-test/test-reporter.h b/tools/slang-test/test-reporter.h index 38e0d2cc3..95f9950e8 100644 --- a/tools/slang-test/test-reporter.h +++ b/tools/slang-test/test-reporter.h @@ -4,9 +4,9 @@ #define TEST_REPORTER_H_INCLUDED #include "../../source/core/slang-string-util.h" -#include "../../source/core/platform.h" +#include "../../source/core/slang-platform.h" #include "../../source/core/slang-std-writers.h" -#include "../../source/core/dictionary.h" +#include "../../source/core/slang-dictionary.h" #define SLANG_CHECK(x) TestReporter::get()->addResultWithLocation((x), #x, __FILE__, __LINE__); diff --git a/tools/slang-test/unit-test-byte-encode.cpp b/tools/slang-test/unit-test-byte-encode.cpp index 83b20d4e0..8ffac3ee3 100644 --- a/tools/slang-test/unit-test-byte-encode.cpp +++ b/tools/slang-test/unit-test-byte-encode.cpp @@ -8,7 +8,7 @@ #include "test-context.h" #include "../../source/core/slang-random-generator.h" -#include "../../source/core/list.h" +#include "../../source/core/slang-list.h" using namespace Slang; diff --git a/tools/slang-test/unit-test-free-list.cpp b/tools/slang-test/unit-test-free-list.cpp index fd7b4844f..28973d9e5 100644 --- a/tools/slang-test/unit-test-free-list.cpp +++ b/tools/slang-test/unit-test-free-list.cpp @@ -8,7 +8,7 @@ #include "test-context.h" #include "../../source/core/slang-random-generator.h" -#include "../../source/core/list.h" +#include "../../source/core/slang-list.h" using namespace Slang; diff --git a/tools/slang-test/unit-test-memory-arena.cpp b/tools/slang-test/unit-test-memory-arena.cpp index 5aa0262e5..2aa898c9d 100644 --- a/tools/slang-test/unit-test-memory-arena.cpp +++ b/tools/slang-test/unit-test-memory-arena.cpp @@ -8,7 +8,7 @@ #include "test-context.h" #include "../../source/core/slang-random-generator.h" -#include "../../source/core/list.h" +#include "../../source/core/slang-list.h" using namespace Slang; -- cgit v1.2.3