From 84605a13104c388fa28c9ab09787d953251602eb Mon Sep 17 00:00:00 2001
From: Philip Rebohle <philip.rebohle@tu-dortmund.de>
Date: Fri, 8 Dec 2017 00:44:58 +0100
Subject: [PATCH] [dxvk] Refactored input layout state

---
 src/d3d11/d3d11_context.cpp       |  7 ++--
 src/d3d11/d3d11_device.cpp        | 18 ++++-----
 src/d3d11/d3d11_input_layout.cpp  | 23 ++++++++++-
 src/d3d11/d3d11_input_layout.h    | 14 ++++---
 src/dxgi/dxgi_presenter.cpp       |  3 +-
 src/dxvk/dxvk_constant_state.cpp  | 26 -------------
 src/dxvk/dxvk_constant_state.h    | 65 ++++++++++++++++---------------
 src/dxvk/dxvk_context.cpp         | 39 ++++++++++++-------
 src/dxvk/dxvk_context.h           | 13 +++++--
 src/dxvk/dxvk_context_state.h     |  1 +
 tests/dxvk/test_dxvk_triangle.cpp |  3 +-
 11 files changed, 114 insertions(+), 98 deletions(-)

diff --git a/src/d3d11/d3d11_context.cpp b/src/d3d11/d3d11_context.cpp
index 5be7d014b..17ad6bfaf 100644
--- a/src/d3d11/d3d11_context.cpp
+++ b/src/d3d11/d3d11_context.cpp
@@ -463,9 +463,10 @@ namespace dxvk {
     if (m_state.ia.inputLayout != inputLayout) {
       m_state.ia.inputLayout = inputLayout;
       
-      m_context->setInputLayout(inputLayout != nullptr
-        ? inputLayout->GetDXVKInputLayout()
-        : nullptr);
+      if (inputLayout != nullptr)
+        inputLayout->BindToContext(m_context);
+      else
+        m_context->setInputLayout(0, nullptr, 0, nullptr);
     }
   }
   
diff --git a/src/d3d11/d3d11_device.cpp b/src/d3d11/d3d11_device.cpp
index d42e96577..bab8af5e9 100644
--- a/src/d3d11/d3d11_device.cpp
+++ b/src/d3d11/d3d11_device.cpp
@@ -342,8 +342,8 @@ namespace dxvk {
       
       Rc<DxbcIsgn> inputSignature = dxbcModule.isgn();
       
-      std::vector<VkVertexInputAttributeDescription> attributes;
-      std::vector<VkVertexInputBindingDescription>   bindings;
+      std::vector<DxvkVertexAttribute> attributes;
+      std::vector<DxvkVertexBinding>   bindings;
       
       for (uint32_t i = 0; i < NumElements; i++) {
         const DxbcSgnEntry* entry = inputSignature->find(
@@ -359,7 +359,7 @@ namespace dxvk {
         }
         
         // Create vertex input attribute description
-        VkVertexInputAttributeDescription attrib;
+        DxvkVertexAttribute attrib;
         attrib.location = entry->registerId;
         attrib.binding  = pInputElementDescs[i].InputSlot;
         attrib.format   = m_dxgiAdapter->LookupFormat(
@@ -376,9 +376,8 @@ namespace dxvk {
         // Create vertex input binding description. The
         // stride is dynamic state in D3D11 and will be
         // set by D3D11DeviceContext::IASetVertexBuffers.
-        VkVertexInputBindingDescription binding;
+        DxvkVertexBinding binding;
         binding.binding   = pInputElementDescs[i].InputSlot;
-        binding.stride    = 0;
         binding.inputRate = VK_VERTEX_INPUT_RATE_VERTEX;
         
         if (pInputElementDescs[i].InputSlotClass == D3D11_INPUT_PER_INSTANCE_DATA) {
@@ -417,11 +416,10 @@ namespace dxvk {
       if (ppInputLayout != nullptr) {
         *ppInputLayout = ref(
           new D3D11InputLayout(this,
-            new DxvkInputLayout(
-              attributes.size(),
-              attributes.data(),
-              bindings.size(),
-              bindings.data())));
+            attributes.size(),
+            attributes.data(),
+            bindings.size(),
+            bindings.data()));
       }
       
       return S_OK;
diff --git a/src/d3d11/d3d11_input_layout.cpp b/src/d3d11/d3d11_input_layout.cpp
index eb6623c4e..9652a04ed 100644
--- a/src/d3d11/d3d11_input_layout.cpp
+++ b/src/d3d11/d3d11_input_layout.cpp
@@ -5,9 +5,19 @@ namespace dxvk {
   
   D3D11InputLayout::D3D11InputLayout(
           D3D11Device*                pDevice,
-    const Rc<DxvkInputLayout>&        inputLayout)
-  : m_device(pDevice), m_inputLayout(inputLayout) {
+          uint32_t                    numAttributes,
+    const DxvkVertexAttribute*        pAttributes,
+          uint32_t                    numBindings,
+    const DxvkVertexBinding*          pBindings)
+  : m_device(pDevice) {
+    m_attributes.resize(numAttributes);
+    m_bindings.resize(numBindings);
     
+    for (uint32_t i = 0; i < numAttributes; i++)
+      m_attributes.at(i) = pAttributes[i];
+    
+    for (uint32_t i = 0; i < numBindings; i++)
+      m_bindings.at(i) = pBindings[i];
   }
   
   
@@ -30,4 +40,13 @@ namespace dxvk {
     *ppDevice = ref(m_device);
   }
   
+  
+  void D3D11InputLayout::BindToContext(const Rc<DxvkContext>& ctx) {
+    ctx->setInputLayout(
+      m_attributes.size(),
+      m_attributes.data(),
+      m_bindings.size(),
+      m_bindings.data());
+  }
+  
 }
diff --git a/src/d3d11/d3d11_input_layout.h b/src/d3d11/d3d11_input_layout.h
index 3748b593d..2b3ee5236 100644
--- a/src/d3d11/d3d11_input_layout.h
+++ b/src/d3d11/d3d11_input_layout.h
@@ -12,7 +12,10 @@ namespace dxvk {
     
     D3D11InputLayout(
             D3D11Device*                pDevice,
-      const Rc<DxvkInputLayout>&        inputLayout);
+            uint32_t                    numAttributes,
+      const DxvkVertexAttribute*        pAttributes,
+            uint32_t                    numBindings,
+      const DxvkVertexBinding*          pBindings);
     
     ~D3D11InputLayout();
     
@@ -23,14 +26,15 @@ namespace dxvk {
     void GetDevice(
             ID3D11Device **ppDevice) final;
     
-    Rc<DxvkInputLayout> GetDXVKInputLayout() const {
-      return m_inputLayout;
-    }
+    void BindToContext(
+      const Rc<DxvkContext>& ctx);
     
   private:
     
     D3D11Device* const  m_device;
-    Rc<DxvkInputLayout> m_inputLayout;
+    
+    std::vector<DxvkVertexAttribute> m_attributes;
+    std::vector<DxvkVertexBinding>   m_bindings;
     
   };
   
diff --git a/src/dxgi/dxgi_presenter.cpp b/src/dxgi/dxgi_presenter.cpp
index 1d5d65da4..c2c026319 100644
--- a/src/dxgi/dxgi_presenter.cpp
+++ b/src/dxgi/dxgi_presenter.cpp
@@ -61,8 +61,7 @@ namespace dxvk {
     m_context->setInputAssemblyState(iaState);
     
     m_context->setInputLayout(
-      new DxvkInputLayout(
-        0, nullptr, 0, nullptr));
+      0, nullptr, 0, nullptr);
     
     DxvkRasterizerState rsState;
     rsState.enableDepthClamp   = VK_FALSE;
diff --git a/src/dxvk/dxvk_constant_state.cpp b/src/dxvk/dxvk_constant_state.cpp
index 6294152cc..88ce00177 100644
--- a/src/dxvk/dxvk_constant_state.cpp
+++ b/src/dxvk/dxvk_constant_state.cpp
@@ -41,30 +41,4 @@ namespace dxvk {
       m_info.blendConstants[i]    = 0.0f;
   }
   
-  
-  DxvkInputLayout::DxvkInputLayout(
-          uint32_t                           attributeCount,
-    const VkVertexInputAttributeDescription* attributeInfo,
-          uint32_t                           bindingCount,
-    const VkVertexInputBindingDescription*   bindingInfo) {
-    // Copy vertex attribute info to a persistent array
-    m_attributes.resize(attributeCount);
-    for (uint32_t i = 0; i < attributeCount; i++)
-      m_attributes.at(i) = attributeInfo[i];
-    
-    // Copy vertex binding info to a persistent array
-    m_bindings.resize(bindingCount);
-    for (uint32_t i = 0; i < bindingCount; i++)
-      m_bindings.at(i) = bindingInfo[i];
-    
-    // Create info structure referencing those arrays
-    m_info.sType                            = VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO;
-    m_info.pNext                            = nullptr;
-    m_info.flags                            = 0;
-    m_info.vertexBindingDescriptionCount    = m_bindings.size();
-    m_info.pVertexBindingDescriptions       = m_bindings.data();
-    m_info.vertexAttributeDescriptionCount  = m_attributes.size();
-    m_info.pVertexAttributeDescriptions     = m_attributes.data();
-  }
-  
 }
\ No newline at end of file
diff --git a/src/dxvk/dxvk_constant_state.h b/src/dxvk/dxvk_constant_state.h
index f12d45eb4..c0b98c4c9 100644
--- a/src/dxvk/dxvk_constant_state.h
+++ b/src/dxvk/dxvk_constant_state.h
@@ -103,47 +103,48 @@ namespace dxvk {
   };
   
   
+  /**
+   * \brief Vertex attribute description
+   * 
+   * Stores information about a
+   * single vertex attribute.
+   */
+  struct DxvkVertexAttribute {
+    uint32_t location;
+    uint32_t binding;
+    VkFormat format;
+    uint32_t offset;
+  };
+  
+  
+  /**
+   * \brief Vertex binding description
+   * 
+   * Stores information about a
+   * single vertex binding slot.
+   */
+  struct DxvkVertexBinding {
+    uint32_t          binding;
+    VkVertexInputRate inputRate;
+  };
+  
+  
   /**
    * \brief Input layout
    * 
-   * Stores the attributes and vertex buffer binding
-   * descriptions that the vertex shader will take
-   * its input values from.
+   * Stores the description of all active
+   * vertex attributes and vertex bindings.
    */
-  class DxvkInputLayout : public RcObject {
-    
-  public:
-    
-    DxvkInputLayout(
-            uint32_t                           attributeCount,
-      const VkVertexInputAttributeDescription* attributeInfo,
-            uint32_t                           bindingCount,
-      const VkVertexInputBindingDescription*   bindingInfo);
-    
-    uint32_t vertexAttributeCount() const {
-      return m_attributes.size();
-    }
-    
-    uint32_t vertexBindingCount() const {
-      return m_bindings.size();
-    }
-    
-    const VkPipelineVertexInputStateCreateInfo& info() const {
-      return m_info;
-    }
-    
-  private:
-    
-    std::vector<VkVertexInputAttributeDescription> m_attributes;
-    std::vector<VkVertexInputBindingDescription>   m_bindings;
-    
-    VkPipelineVertexInputStateCreateInfo m_info;
+  struct DxvkInputLayout {
+    uint32_t numAttributes;
+    uint32_t numBindings;
     
+    std::array<DxvkVertexAttribute, DxvkLimits::MaxNumVertexAttributes> attributes;
+    std::array<DxvkVertexBinding,   DxvkLimits::MaxNumVertexBindings>   bindings;
   };
   
   
   struct DxvkConstantStateObjects {
-    Rc<DxvkInputLayout>         inputLayout;
     Rc<DxvkBlendState>          blendState;
   };
   
diff --git a/src/dxvk/dxvk_context.cpp b/src/dxvk/dxvk_context.cpp
index 372d5b31b..0504eceda 100644
--- a/src/dxvk/dxvk_context.cpp
+++ b/src/dxvk/dxvk_context.cpp
@@ -389,11 +389,20 @@ namespace dxvk {
   
   
   void DxvkContext::setInputLayout(
-    const Rc<DxvkInputLayout>& state) {
-    if (m_state.co.inputLayout != state) {
-      m_state.co.inputLayout = state;
-      m_flags.set(DxvkContextFlag::GpDirtyPipelineState);
-    }
+          uint32_t             attributeCount,
+    const DxvkVertexAttribute* attributes,
+          uint32_t             bindingCount,
+    const DxvkVertexBinding*   bindings) {
+    m_flags.set(DxvkContextFlag::GpDirtyPipelineState);
+    
+    m_state.il.numAttributes = attributeCount;
+    m_state.il.numBindings   = bindingCount;
+    
+    for (uint32_t i = 0; i < attributeCount; i++)
+      m_state.il.attributes.at(i) = attributes[i];
+    
+    for (uint32_t i = 0; i < bindingCount; i++)
+      m_state.il.bindings.at(i) = bindings[i];
   }
   
   
@@ -501,16 +510,20 @@ namespace dxvk {
       gpState.iaPrimitiveTopology      = m_state.ia.primitiveTopology;
       gpState.iaPrimitiveRestart       = m_state.ia.primitiveRestart;
       
-      const auto& il = m_state.co.inputLayout->info();
-      gpState.ilAttributeCount         = il.vertexAttributeDescriptionCount;
-      gpState.ilBindingCount           = il.vertexBindingDescriptionCount;
+      gpState.ilAttributeCount         = m_state.il.numAttributes;
+      gpState.ilBindingCount           = m_state.il.numBindings;
       
-      for (uint32_t i = 0; i < gpState.ilAttributeCount; i++)
-        gpState.ilAttributes[i] = il.pVertexAttributeDescriptions[i];
+      for (uint32_t i = 0; i < m_state.il.numAttributes; i++) {
+        gpState.ilAttributes[i].location = m_state.il.attributes[i].location;
+        gpState.ilAttributes[i].binding  = m_state.il.attributes[i].binding;
+        gpState.ilAttributes[i].format   = m_state.il.attributes[i].format;
+        gpState.ilAttributes[i].offset   = m_state.il.attributes[i].offset;
+      }
       
-      for (uint32_t i = 0; i < gpState.ilBindingCount; i++) {
-        gpState.ilBindings[i] = il.pVertexBindingDescriptions[i];
-        gpState.ilBindings[i].stride = m_state.vi.vertexStrides.at(i);
+      for (uint32_t i = 0; i < m_state.il.numBindings; i++) {
+        gpState.ilBindings[i].binding    = m_state.il.bindings[i].binding;
+        gpState.ilBindings[i].inputRate  = m_state.il.bindings[i].inputRate;
+        gpState.ilBindings[i].stride     = m_state.vi.vertexStrides.at(i);
       }
       
       gpState.rsEnableDepthClamp       = m_state.rs.enableDepthClamp;
diff --git a/src/dxvk/dxvk_context.h b/src/dxvk/dxvk_context.h
index 0d04befad..f87fe6510 100644
--- a/src/dxvk/dxvk_context.h
+++ b/src/dxvk/dxvk_context.h
@@ -270,11 +270,18 @@ namespace dxvk {
       const DxvkInputAssemblyState& state);
     
     /**
-     * \brief Sets input layout state
-     * \param [in] state New state object
+     * \brief Sets input layout
+     * 
+     * \param [in] attributeCount Number of vertex attributes
+     * \param [in] attributes The vertex attributes
+     * \param [in] bindingCount Number of buffer bindings
+     * \param [in] bindings Vertex buffer bindigs
      */
     void setInputLayout(
-      const Rc<DxvkInputLayout>& state);
+            uint32_t             attributeCount,
+      const DxvkVertexAttribute* attributes,
+            uint32_t             bindingCount,
+      const DxvkVertexBinding*   bindings);
     
     /**
      * \brief Sets rasterizer state
diff --git a/src/dxvk/dxvk_context_state.h b/src/dxvk/dxvk_context_state.h
index 17a93e9cc..8fa12c386 100644
--- a/src/dxvk/dxvk_context_state.h
+++ b/src/dxvk/dxvk_context_state.h
@@ -90,6 +90,7 @@ namespace dxvk {
    */
   struct DxvkContextState {
     DxvkInputAssemblyState    ia;
+    DxvkInputLayout           il;
     DxvkRasterizerState       rs;
     DxvkMultisampleState      ms;
     DxvkDepthStencilState     ds;
diff --git a/tests/dxvk/test_dxvk_triangle.cpp b/tests/dxvk/test_dxvk_triangle.cpp
index 977f43156..2ffe42dc7 100644
--- a/tests/dxvk/test_dxvk_triangle.cpp
+++ b/tests/dxvk/test_dxvk_triangle.cpp
@@ -87,8 +87,7 @@ public:
     iaState.primitiveRestart  = VK_FALSE;
     m_dxvkContext->setInputAssemblyState(iaState);
     
-    m_dxvkContext->setInputLayout(
-      new DxvkInputLayout(0, nullptr, 0, nullptr));
+    m_dxvkContext->setInputLayout(0, nullptr, 0, nullptr);
     
     DxvkRasterizerState rsState;
     rsState.enableDepthClamp   = VK_FALSE;