diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs index f451602db..b2b856e06 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs @@ -13,6 +13,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv partial class CodeGenContext : Module { + private readonly StructuredProgramInfo _info; + public ShaderConfig Config { get; } public int InputVertices { get; } @@ -65,8 +67,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv public SpirvDelegates Delegates { get; } - public CodeGenContext(ShaderConfig config, GeneratorPool instPool, GeneratorPool integerPool) : base(0x00010300, instPool, integerPool) + public CodeGenContext( + StructuredProgramInfo info, + ShaderConfig config, + GeneratorPool instPool, + GeneratorPool integerPool) : base(0x00010300, instPool, integerPool) { + _info = info; Config = config; if (config.Stage == ShaderStage.Geometry) @@ -217,17 +224,28 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv elemType = attrInfo.Type & AggregateType.ElementTypeMask; - var storageClass = isOutAttr ? StorageClass.Output : StorageClass.Input; + int attrOffset = attrInfo.BaseValue; + AggregateType type = attrInfo.Type; - var elemIndex = Constant(TypeU32(), attrInfo.GetInnermostIndex()); + bool isUserAttr = attr >= AttributeConsts.UserAttributeBase && attr < AttributeConsts.UserAttributeEnd; + if (isUserAttr && Config.TransformFeedbackEnabled && + ((isOutAttr && Config.Stage != ShaderStage.Fragment) || + (!isOutAttr && Config.Stage != ShaderStage.Vertex))) + { + attrOffset = attr; + type = elemType; + } - var ioVariable = isOutAttr ? Outputs[attrInfo.BaseValue] : Inputs[attrInfo.BaseValue]; + var ioVariable = isOutAttr ? Outputs[attrOffset] : Inputs[attrOffset]; - if ((attrInfo.Type & (AggregateType.Array | AggregateType.Vector)) == 0) + if ((type & (AggregateType.Array | AggregateType.Vector)) == 0) { return ioVariable; } + var storageClass = isOutAttr ? StorageClass.Output : StorageClass.Input; + var elemIndex = Constant(TypeU32(), attrInfo.GetInnermostIndex()); + if (Config.Stage == ShaderStage.Geometry && !isOutAttr && (!attrInfo.IsBuiltin || AttributeInfo.IsArrayBuiltIn(attr))) { return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, index, elemIndex); @@ -300,6 +318,18 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv return _functions[funcIndex]; } + public TransformFeedbackOutput GetTransformFeedbackOutput(int location, int component) + { + int index = (AttributeConsts.UserAttributeBase / 4) + location * 4 + component; + return _info.TransformFeedbackOutputs[index]; + } + + public TransformFeedbackOutput GetTransformFeedbackOutput(int location) + { + int index = location / 4; + return _info.TransformFeedbackOutputs[index]; + } + public Instruction GetType(AggregateType type, int length = 1) { if (type.HasFlag(AggregateType.Array)) diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs index 106c1d352..23f64ea5d 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs @@ -4,6 +4,7 @@ using Ryujinx.Graphics.Shader.Translation; using Spv.Generator; using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using static Spv.Specification; @@ -335,14 +336,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv iq = context.Config.ImapTypes[(attr - AttributeConsts.UserAttributeBase) / 16].GetFirstUsedType(); } - if (context.Config.TransformFeedbackEnabled) - { - throw new NotImplementedException(); - } - else - { - DeclareInputOrOutput(context, attr, false, iq); - } + DeclareInputOrOutput(context, attr, false, iq); } } @@ -371,14 +365,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv private static void DeclareOutputAttribute(CodeGenContext context, int attr) { - if (context.Config.TransformFeedbackEnabled) - { - throw new NotImplementedException(); - } - else - { - DeclareInputOrOutput(context, attr, true); - } + DeclareInputOrOutput(context, attr, true); } public static void DeclareInvocationId(CodeGenContext context) @@ -388,6 +375,15 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv private static void DeclareInputOrOutput(CodeGenContext context, int attr, bool isOutAttr, PixelImap iq = PixelImap.Unused) { + bool isUserAttr = attr >= AttributeConsts.UserAttributeBase && attr < AttributeConsts.UserAttributeEnd; + if (isUserAttr && context.Config.TransformFeedbackEnabled && + ((isOutAttr && context.Config.Stage != ShaderStage.Fragment) || + (!isOutAttr && context.Config.Stage != ShaderStage.Vertex))) + { + DeclareInputOrOutput(context, attr, (attr >> 2) & 3, isOutAttr, iq); + return; + } + var dict = isOutAttr ? context.Outputs : context.Inputs; var attrInfo = AttributeInfo.From(context.Config, attr); @@ -410,8 +406,19 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv if (attrInfo.IsBuiltin) { context.Decorate(spvVar, Decoration.BuiltIn, (LiteralInteger)GetBuiltIn(context, attrInfo.BaseValue)); + + if (context.Config.TransformFeedbackEnabled && isOutAttr) + { + var tfOutput = context.GetTransformFeedbackOutput(attrInfo.BaseValue); + if (tfOutput.Valid) + { + context.Decorate(spvVar, Decoration.XfbBuffer, (LiteralInteger)tfOutput.Buffer); + context.Decorate(spvVar, Decoration.XfbStride, (LiteralInteger)tfOutput.Stride); + context.Decorate(spvVar, Decoration.Offset, (LiteralInteger)tfOutput.Offset); + } + } } - else if (attr >= AttributeConsts.UserAttributeBase && attr < AttributeConsts.UserAttributeEnd) + else if (isUserAttr) { int location = (attr - AttributeConsts.UserAttributeBase) / 16; context.Decorate(spvVar, Decoration.Location, (LiteralInteger)location); @@ -439,6 +446,60 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv dict.Add(attrInfo.BaseValue, spvVar); } + private static void DeclareInputOrOutput(CodeGenContext context, int attr, int component, bool isOutAttr, PixelImap iq = PixelImap.Unused) + { + var dict = isOutAttr ? context.Outputs : context.Inputs; + var attrInfo = AttributeInfo.From(context.Config, attr); + + if (dict.ContainsKey(attr)) + { + return; + } + + var storageClass = isOutAttr ? StorageClass.Output : StorageClass.Input; + var attrType = context.GetType(attrInfo.Type & AggregateType.ElementTypeMask); + + if (context.Config.Stage == ShaderStage.Geometry && !isOutAttr && (!attrInfo.IsBuiltin || AttributeInfo.IsArrayBuiltIn(attr))) + { + attrType = context.TypeArray(attrType, context.Constant(context.TypeU32(), (LiteralInteger)context.InputVertices)); + } + + var spvType = context.TypePointer(storageClass, attrType); + var spvVar = context.Variable(spvType, storageClass); + + Debug.Assert(attr >= AttributeConsts.UserAttributeBase && attr < AttributeConsts.UserAttributeEnd); + int location = (attr - AttributeConsts.UserAttributeBase) / 16; + + context.Decorate(spvVar, Decoration.Location, (LiteralInteger)location); + context.Decorate(spvVar, Decoration.Component, (LiteralInteger)component); + + if (isOutAttr) + { + var tfOutput = context.GetTransformFeedbackOutput(location, component); + if (tfOutput.Valid) + { + context.Decorate(spvVar, Decoration.XfbBuffer, (LiteralInteger)tfOutput.Buffer); + context.Decorate(spvVar, Decoration.XfbStride, (LiteralInteger)tfOutput.Stride); + context.Decorate(spvVar, Decoration.Offset, (LiteralInteger)tfOutput.Offset); + } + } + else + { + switch (iq) + { + case PixelImap.Constant: + context.Decorate(spvVar, Decoration.Flat); + break; + case PixelImap.ScreenLinear: + context.Decorate(spvVar, Decoration.NoPerspective); + break; + } + } + + context.AddGlobalVariable(spvVar); + dict.Add(attr, spvVar); + } + private static BuiltIn GetBuiltIn(CodeGenContext context, int attr) { return attr switch diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs index 4e5dddce2..20b8fa0cb 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs @@ -47,7 +47,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv integerPool = IntegerPool.Allocate(); } - CodeGenContext context = new CodeGenContext(config, instPool, integerPool); + CodeGenContext context = new CodeGenContext(info, config, instPool, integerPool); context.AddCapability(Capability.GroupNonUniformBallot); context.AddCapability(Capability.ImageBuffer); @@ -56,6 +56,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv context.AddCapability(Capability.SubgroupBallotKHR); context.AddCapability(Capability.SubgroupVoteKHR); + if (config.TransformFeedbackEnabled && config.Stage != ShaderStage.Fragment) + { + context.AddCapability(Capability.TransformFeedback); + } + if (config.Stage == ShaderStage.Geometry) { context.AddCapability(Capability.Geometry); @@ -193,6 +198,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv localSizeY, localSizeZ); } + + if (context.Config.TransformFeedbackEnabled && context.Config.Stage != ShaderStage.Fragment) + { + context.AddExecutionMode(spvFunc, ExecutionMode.Xfb); + } } }