diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs index 514722355..7265fac10 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs @@ -125,7 +125,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv Add(Instruction.StoreAttribute, GenerateStoreAttribute); Add(Instruction.StoreLocal, GenerateStoreLocal); Add(Instruction.StoreShared, GenerateStoreShared); + Add(Instruction.StoreShared16, GenerateStoreShared16); + Add(Instruction.StoreShared8, GenerateStoreShared8); Add(Instruction.StoreStorage, GenerateStoreStorage); + Add(Instruction.StoreStorage16, GenerateStoreStorage16); + Add(Instruction.StoreStorage8, GenerateStoreStorage8); Add(Instruction.Subtract, GenerateSubtract); Add(Instruction.SwizzleAdd, GenerateSwizzleAdd); Add(Instruction.TextureSample, GenerateTextureSample); @@ -1322,6 +1326,20 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv return OperationResult.Invalid; } + private static OperationResult GenerateStoreShared16(CodeGenContext context, AstOperation operation) + { + GenerateStoreSharedSmallInt(context, operation, 16); + + return OperationResult.Invalid; + } + + private static OperationResult GenerateStoreShared8(CodeGenContext context, AstOperation operation) + { + GenerateStoreSharedSmallInt(context, operation, 8); + + return OperationResult.Invalid; + } + private static OperationResult GenerateStoreStorage(CodeGenContext context, AstOperation operation) { var elemPointer = GetStorageElemPointer(context, operation); @@ -1330,6 +1348,20 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv return OperationResult.Invalid; } + private static OperationResult GenerateStoreStorage16(CodeGenContext context, AstOperation operation) + { + GenerateStoreStorageSmallInt(context, operation, 16); + + return OperationResult.Invalid; + } + + private static OperationResult GenerateStoreStorage8(CodeGenContext context, AstOperation operation) + { + GenerateStoreStorageSmallInt(context, operation, 8); + + return OperationResult.Invalid; + } + private static OperationResult GenerateSubtract(CodeGenContext context, AstOperation operation) { return GenerateBinary(context, operation, context.Delegates.FSub, context.Delegates.ISub); @@ -1862,6 +1894,69 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv return new OperationResult(AggregateType.U32, context.AtomicCompareExchange(context.TypeU32(), elemPointer, one, zero, zero, value1, value0)); } + private static void GenerateStoreSharedSmallInt(CodeGenContext context, AstOperation operation, int bitSize) + { + var offset = context.Get(AggregateType.U32, operation.GetSource(0)); + var value = context.Get(AggregateType.U32, operation.GetSource(1)); + + var wordOffset = context.ShiftRightLogical(context.TypeU32(), offset, context.Constant(context.TypeU32(), 2)); + var bitOffset = context.BitwiseAnd(context.TypeU32(), offset, context.Constant(context.TypeU32(), 3)); + bitOffset = context.ShiftLeftLogical(context.TypeU32(), bitOffset, context.Constant(context.TypeU32(), 3)); + + var memory = context.SharedMemory; + + var elemPointer = context.AccessChain(context.TypePointer(StorageClass.Workgroup, context.TypeU32()), memory, wordOffset); + + GenerateStoreSmallInt(context, elemPointer, bitOffset, value, bitSize); + } + + private static void GenerateStoreStorageSmallInt(CodeGenContext context, AstOperation operation, int bitSize) + { + var i0 = context.Get(AggregateType.S32, operation.GetSource(0)); + var offset = context.Get(AggregateType.U32, operation.GetSource(1)); + var value = context.Get(AggregateType.U32, operation.GetSource(2)); + + var wordOffset = context.ShiftRightLogical(context.TypeU32(), offset, context.Constant(context.TypeU32(), 2)); + var bitOffset = context.BitwiseAnd(context.TypeU32(), offset, context.Constant(context.TypeU32(), 3)); + bitOffset = context.ShiftLeftLogical(context.TypeU32(), bitOffset, context.Constant(context.TypeU32(), 3)); + + var sbVariable = context.StorageBuffersArray; + + var i1 = context.Constant(context.TypeS32(), 0); + + var elemPointer = context.AccessChain(context.TypePointer(StorageClass.Uniform, context.TypeU32()), sbVariable, i0, i1, wordOffset); + + GenerateStoreSmallInt(context, elemPointer, bitOffset, value, bitSize); + } + + private static void GenerateStoreSmallInt( + CodeGenContext context, + SpvInstruction elemPointer, + SpvInstruction bitOffset, + SpvInstruction value, + int bitSize) + { + var loopStart = context.Label(); + var loopEnd = context.Label(); + + context.Branch(loopStart); + context.AddLabel(loopStart); + + var oldValue = context.Load(context.TypeU32(), elemPointer); + var newValue = context.BitFieldInsert(context.TypeU32(), oldValue, value, bitOffset, context.Constant(context.TypeU32(), bitSize)); + + var one = context.Constant(context.TypeU32(), 1); + var zero = context.Constant(context.TypeU32(), 0); + + var result = context.AtomicCompareExchange(context.TypeU32(), elemPointer, one, zero, zero, newValue, oldValue); + var failed = context.INotEqual(context.TypeBool(), result, oldValue); + + context.LoopMerge(loopEnd, loopStart, LoopControlMask.MaskNone); + context.BranchConditional(failed, loopStart, loopEnd); + + context.AddLabel(loopEnd); + } + private static SpvInstruction GetStorageElemPointer(CodeGenContext context, AstOperation operation) { var sbVariable = context.StorageBuffersArray;