How to Implement Grouped Multi-headed Attention in C#.Net
by: Gyula Rábai
Grouped multi-headed attention is often described abstractly, hidden behind tensor libraries and opaque framework calls, but at its core, it is a carefully structured composition of very concrete operations. In this article, we walk through a full, ground-up implementation of grouped multi-headed self-attention in C#, showing how attention can be decomposed into reusable components: attention heads, attention groups, and a final aggregation layer. Rather than relying on high-rank tensor tricks, the design emphasizes explicit vector and matrix operations, clear memory ownership, and strict parameter validation. The result is a transparent, extensible architecture that exposes how queries, keys, values, RoPE, and scaled dot-product attention actually fit together, and how grouping heads can be implemented cleanly without sacrificing performance or conceptual clarity.
Attention
The attention component consists of multiple sub-components. Below is the CompParams for the whole attention component.
partial class OzAIMultiHeadAttn
{
public OzAICompIOMem_Unary Mem;
public CompIParams IParams;
public class CompIParams : OzAICompIParams
{
public OzAIAttnGroup.CompIParams[] GroupParams;
public OzAIMatrix OutputWeights;
public override bool IsPossible(out string error)
{
List<object> objs = [ExecManager, GroupParams, OutputWeights];
List<string> names = ["ExecManager", "GroupParams", "OutputWeights"];
if (!CheckIfNull(objs, names, out error))
return false;
for (int i = 0; i < GroupParams.Length; i++)
{
var group = GroupParams[i];
if (group == null)
{
error = $"No group parameters provided for group number {i}.";
return false;
}
if (!group.IsPossible(out error))
{
error = $"Group parameters number {i} not possible: " + error;
return false;
}
}
return true;
}
}
public CompHParams HParams;
public class CompHParams : OzAICompHParams
{
public uint GroupCount;
public OzAIAttnGroup.CompHParams GroupParams;
public override bool SetDefaults(OzAIProcMode mode, out string error)
{
return GroupParams.SetDefaults(mode, out error);
}
public override bool IsPossible(out string error)
{
if (GroupCount == 0)
{
error = "Group count needs to be non-zero.";
return false;
}
if (!CheckIfNull(GroupParams, "GroupParams", out error)) return false;
if (!GroupParams.IsPossible(out error)) return false;
return true;
}
}
public override bool IsPossible(out string error)
{
if (Params.Mem is not OzAICompIOMem_Unary)
{
error = "IO memory required to be in a format for a unary op OzAICompIOMem_Unary.";
return false;
}
if (Params.IParams is not CompIParams)
{
error = "Instance parameters required to be in a format for a OzAIMultiHeadAttn.CompIParams.";
return false;
}
if (Params.HParams is not CompHParams)
{
error = "Hyperparameters required to be in a format for a OzAIMultiHeadAttn.CompHParams.";
return false;
}
Mem = Params.Mem as OzAICompIOMem_Unary;
IParams = Params.IParams as CompIParams;
HParams = Params.HParams as CompHParams;
if (IParams.GroupParams.Length != HParams.GroupCount)
{
error = "Number of group parameters provided does not match the group count hparam.";
return false;
}
error = null;
return true;
}
}
To summarize, from the outside, attention is a unary operation performed on the tokens. The IParams are the IParams for the groups (discussed later) and the output transformation matrix. The HParams are the HParams of the groups and the number of groups. The rest of the class is down below:
///<summary>
/// Applies self-attention to some number of vectors across multiple heads.
///</summary>
public partial class OzAIMultiHeadAttn : OzAIArchComp
{
List<OzAIAttnGroup> Groups;
public override string Name => "OzAIMultiHeadAttn";
protected override bool InitInner(out string error)
{
Groups = new List<OzAIAttnGroup>();
for (int i = 0; i < HParams.GroupCount; i++)
{
if (!createGroup(i, out error))
return false;
}
error = null;
return true;
}
bool createGroup(int i, out string error)
{
var group = new OzAIAttnGroup();
var groupMem = createGroupMem(i);
var groupParams = new OzAICompParams()
{
Mem = groupMem,
IParams = IParams.GroupParams[i],
HParams = HParams.GroupParams,
};
if (!group.Init(groupParams, out error))
return false;
Groups.Add(group);
return true;
}
OzAIAttnGroup.CompMem createGroupMem(int i)
{
var groupOuts = new OzAIMemNode[HParams.GroupParams.HeadCount];
for (long j = 0; j < groupOuts.LongLength; j++)
{
groupOuts[j] = new OzAIMemNode();
}
return new OzAIAttnGroup.CompMem()
{
Inputs = Mem.Inputs,
Outputs = groupOuts
};
}
public override bool Forward(out string error)
{
var exec = IParams.ExecManager;
if (!exec.GetProcMode(out var mode, out error))
return false;
for (int i = 0; i < Groups.Count; i++)
{
var group = Groups[i];
for (long j = 0; j < group.Mem.Outputs.LongLength; j++)
{
group.Mem.Outputs[j].Clear();
if (!group.Mem.Outputs[j].AddVecs(mode, HParams.GroupParams.HeadParams.ValLen, Mem.Inputs.Count, out error))
return false;
}
if (!group.Forward(out error))
return false;
}
if (!mergeGroupOuts(mode, out var res, out error))
return false;
var vecs = res.ToArray();
var outs = Mem.Outputs.GetArray();
if (!exec.MatMul(vecs, IParams.OutputWeights, outs, out error))
return false;
error = null;
return true;
}
bool mergeGroupOuts(OzAIProcMode mode, out List<OzAIVector> res, out string error)
{
res = new List<OzAIVector>();
var valLen = HParams.GroupParams.HeadParams.ValLen;
var vecLen = valLen * HParams.GroupParams.HeadCount * HParams.GroupCount;
var count = Mem.Inputs.GetList().Count;
for (int i = 0; i < count; i++)
{
if (!CombineVec(i, (int)valLen, vecLen, mode, res, out error))
return false;
}
error = null;
return true;
}
bool CombineVec(int i, int valLen, ulong vecLen, OzAIProcMode mode, List<OzAIVector> res, out string error)
{
if (!OzAIVector.Create(mode, out var resVec, out error))
return false;
if (!resVec.GetBytesPerBlock(out var bpb, out error))
return false;
if (!resVec.GetNumsPerBlock(out var npb, out error))
return false;
var size = (vecLen / npb) * bpb;
var vals = new byte[size];
var valSize = (valLen / (int)npb) * (int)bpb;
var offset = 0;
for (int j = 0; j < Groups.Count; j++)
{
for (int k = 0; k < Groups[j].Mem.Outputs.Length; k++)
{
if (!Groups[j].Mem.Outputs[k].GetList()[i].ToBytes(out var headFloats, out error))
return false;
Buffer.BlockCopy(headFloats, 0, vals, offset, valSize);
offset += valSize;
}
}
if (!resVec.Init(vals, 0, size, out error))
return false;
res.Add(resVec);
return true;
}
}
Breaking it down, this class contains a list of attention groups and initializes them with their respective parameters. The forward pass simply calls forward on all of the groups and concatenates the results, then passes the result to the execution manager along with the output weights matrix for a final transformation.
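Outside the OzAI framework, the orchestration that Forward performs can be sketched in a few lines of NumPy. The function signature and shapes below are illustrative assumptions, not part of the library:

```python
import numpy as np

def multi_head_forward(inputs, group_fns, output_weights):
    """Sketch of what OzAIMultiHeadAttn.Forward orchestrates: run each group,
    concatenate the per-token head outputs, then apply the output projection.
    inputs: (n, d) token embeddings; each group_fn maps (n, d) -> (n, heads*v);
    output_weights: (g*heads*v, d) final linear projection."""
    group_outs = [fn(inputs) for fn in group_fns]   # each (n, heads*v)
    merged = np.concatenate(group_outs, axis=1)     # (n, g*heads*v), like CombineVec
    return merged @ output_weights                  # (n, d), like the final MatMul
```

The per-token byte concatenation in CombineVec corresponds to the axis-1 concatenation here; the framework just performs it on raw buffers instead of arrays.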
Attention Group
Let us take a look at the CompParams:
partial class OzAIAttnGroup
{
public CompMem Mem;
public class CompMem : OzAICompIOMem
{
public OzAIMemNode Inputs;
public OzAIMemNode[] Outputs;
public override bool IsPossible(out string error)
{
List<object> objs = [Inputs, Outputs];
List<string> names = ["Inputs", "Outputs"];
return CheckIfNull(objs, names, out error);
}
}
public CompIParams IParams;
public class CompIParams : OzAICompIParams
{
public OzAIMatrix ValueWeights;
public OzAIMatrix KeyWeights;
public OzAIAttnHead.CompIParams[] HeadParams;
public override bool IsPossible(out string error)
{
List<object> objs = [ExecManager, ValueWeights, KeyWeights, HeadParams];
List<string> names = ["ExecManager", "ValueWeights", "KeyWeights", "HeadParams"];
if (!CheckIfNull(objs, names, out error))
return false;
for (int i = 0; i < HeadParams.Length; i++)
{
var head = HeadParams[i];
if (head == null)
{
error = $"No attention parameters provided for head number {i}.";
return false;
}
if (!head.IsPossible(out error))
{
error = $"Attention parameters for head number {i} not possible: " + error;
return false;
}
}
error = null;
return true;
}
}
public CompHParams HParams;
public class CompHParams : OzAICompHParams
{
public uint HeadCount = 4;
public OzAIAttnHead.CompHParams HeadParams;
public override bool IsPossible(out string error)
{
if (!CheckIfNull(HeadParams, "HeadParams", out error))
return false;
return HeadParams.IsPossible(out error);
}
}
public override bool IsPossible(out string error)
{
if (Params.Mem is not CompMem)
{
error = "IO memory required to be in a format for a OzAIAttnGroup.CompMem.";
return false;
}
if (Params.IParams is not CompIParams)
{
error = "Instance parameters required to be in a format for a OzAIAttnGroup.CompIParams.";
return false;
}
if (Params.HParams is not CompHParams)
{
error = "Hyperparameters required to be in a format for a OzAIAttnGroup.CompHParams.";
return false;
}
Mem = Params.Mem as CompMem;
IParams = Params.IParams as CompIParams;
HParams = Params.HParams as CompHParams;
if (Mem.Outputs.Length != HParams.HeadCount)
{
error = "Number of outputs provided does not match the number of heads";
return false;
}
if (IParams.HeadParams.Length != HParams.HeadCount)
{
error = "Number of head parameters provided does not match the head count hparam.";
return false;
}
error = null;
return true;
}
}
It takes the same mem node as input as the attention component does. For outputs, it outputs all the vectors of the respective heads.
One must recall that, where n is the number of tokens, d is the size of an embedding vector, k is the size of a key vector, v is the size of a value vector, and g is the number of groups, we perform the following operations (in einsum) on the value matrices (g, v, d), key matrices (g, k, d), and input matrix (n, d):
einsum("nd, gvd -> ngv", inputs, valueMatrices, values)
einsum("nd, gkd -> ngk", inputs, keyMatrices, keys)
One should note that only one contraction was made along the d dimension. This is a common theme for all components of the architecture, since most AI engineers have yet to understand that the real power in einsum lies in its generality to perform many outer products and contractions in one go. This negates any benefit gained from most tensor operations.
Also note that most tensor operations happen on rank 3 or 4 tensors in PyTorch implementations, perhaps due to the inherent limitation on the number of dimensions our 3D brains can visualise. This means most operations can easily be considered as lists of matrices instead of an einsum operation. This is why I made the decision to break up all operations into their constituent vector/matrix operations, and from here on we will make no more reference to the underlying tensor operations.
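To make this equivalence concrete, here is a small NumPy check (with illustrative shapes) showing that the single-contraction einsum above is exactly the same computation as a list of per-group matrix products:

```python
import numpy as np

# Illustrative sizes: n tokens, d embedding dims, g groups, v value dims.
n, d, g, v = 5, 8, 2, 3
rng = np.random.default_rng(0)
inputs = rng.normal(size=(n, d))
value_mats = rng.normal(size=(g, v, d))

# One einsum: a single contraction over d, batched over the group axis g.
values = np.einsum("nd,gvd->ngv", inputs, value_mats)

# Equivalent formulation as a list of ordinary per-group matrix products.
values_loop = np.stack([inputs @ value_mats[i].T for i in range(g)], axis=1)
assert np.allclose(values, values_loop)
```

The key projection works identically with k in place of v, which is why the implementation can dispatch plain MatMul calls per group.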
The IParams for this operation are the value matrix, the key matrix, and the IParams for the heads in this group. The HParams are the head count per group and, indirectly, the RoPE HParams. The actual forward function is below:
/// <summary>
/// This applies scaled dot product attention pooling.
/// Inputs: Originals, keys, values
/// </summary>
public partial class OzAIAttnGroup : OzAIArchComp
{
public OzAIMemNode Values;
public OzAIMemNode Keys;
public OzAIRoPE_Original RoPE;
public List<OzAIAttnHead> Heads;
public override string Name => "OzAIAttnHeadGroup";
protected override bool InitInner(out string error)
{
Values = new OzAIMemNode();
Keys = new OzAIMemNode();
if (!createRoPE(out error)) return false;
Heads = new List<OzAIAttnHead>((int)HParams.HeadCount);
for (int i = 0; i < HParams.HeadCount; i++)
{
if (!createHead(i, out error)) return false;
}
error = null;
return true;
}
bool createRoPE(out string error)
{
RoPE = new OzAIRoPE_Original();
var ropeMem = new OzAICompIOMem_Unary()
{
Inputs = Keys,
Outputs = Keys
};
var ropeIParams = new OzAICompIParams_ExecOnly()
{
ExecManager = IParams.ExecManager
};
var ropeParams = new OzAICompParams()
{
Mem = ropeMem,
IParams = ropeIParams,
HParams = HParams.HeadParams.RoPEParams
};
if (!RoPE.Init(ropeParams, out error))
return false;
return true;
}
bool createHead(int i, out string error)
{
var head = new OzAIAttnHead();
var headMem = new OzAIAttnHead.CompMem()
{
Inputs = Mem.Inputs,
Keys = Keys,
Values = Values,
Outputs = Mem.Outputs[i],
};
var headParams = new OzAICompParams()
{
Mem = headMem,
IParams = IParams.HeadParams[i],
HParams = HParams.HeadParams
};
if (!head.Init(headParams, out error))
return false;
Heads.Add(head);
return true;
}
public override bool Forward(out string error)
{
if (!initMem(out error)) return false;
if (!getKeys(out error)) return false;
if (!getVals(out error)) return false;
for (int i = 0; i < Heads.Count; i++)
{
var head = Heads[i];
if (!head.Forward(out error))
return false;
}
error = null;
return true;
}
bool initMem(out string error)
{
var exec = IParams.ExecManager;
if (!exec.GetProcMode(out var mode, out error))
return false;
Values.Clear();
if (!Values.AddVecs(mode, HParams.HeadParams.ValLen, Mem.Inputs.Count, out error))
return false;
Keys.Clear();
if (!Keys.AddVecs(mode, HParams.HeadParams.KeyLen, Mem.Inputs.Count, out error))
return false;
return true;
}
bool getVals(out string error)
{
var exec = IParams.ExecManager;
var inps = Mem.Inputs.GetArray();
var vals = Values.GetArray();
if (!exec.MatMul(inps, IParams.ValueWeights, vals, out error))
return false;
return true;
}
bool getKeys(out string error)
{
var exec = IParams.ExecManager;
var inps = Mem.Inputs.GetArray();
var keys = Keys.GetArray();
if (!exec.MatMul(inps, IParams.KeyWeights, keys, out error))
return false;
if (!RoPE.Forward(out error))
return false;
return true;
}
}
During initialization we create the respective head and RoPE components. In the actual forward function, we tell the exec manager to calculate the values and keys, then call forward on the necessary RoPE and head components.
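The RoPE component is used as a black box here, but for intuition, a minimal NumPy sketch of rotary position embeddings follows. The base frequency of 10000 and the pairing of consecutive dimensions are one common convention, assumed for illustration rather than taken from OzAIRoPE_Original:

```python
import numpy as np

def rope(x):
    """Minimal sketch of rotary position embeddings (RoPE): rotate each
    consecutive pair of dimensions of every vector by an angle that depends
    on the token position and the pair's frequency. x: (n, k), k even."""
    n, k = x.shape
    half = k // 2
    freqs = 1.0 / (10000.0 ** (np.arange(half) / half))  # per-pair frequency
    angles = np.outer(np.arange(n), freqs)               # (n, half)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, 0::2], x[:, 1::2]                      # pair components
    out = np.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin                   # 2D rotation per pair
    out[:, 1::2] = x1 * sin + x2 * cos
    return out
```

Because each pair undergoes a pure rotation, vector norms are preserved and the vector at position 0 is left unchanged, which is a quick sanity check for any RoPE implementation.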
Attention Head
This component actually performs the scaled dot product attention. Here are its CompParams:
partial class OzAIAttnHead
{
public CompMem Mem;
public class CompMem : OzAICompIOMem
{
public OzAIMemNode Inputs;
public OzAIMemNode Values;
public OzAIMemNode Keys;
public OzAIMemNode Outputs;
public override bool IsPossible(out string error)
{
List<object> objs = [Inputs, Values, Keys, Outputs];
List<string> names = ["Inputs", "Values", "Keys", "Outputs"];
return CheckIfNull(objs, names, out error);
}
}
public CompIParams IParams;
public class CompIParams : OzAICompIParams
{
public OzAIMatrix QueryMat;
public override bool IsPossible(out string error)
{
List<object> objs = [QueryMat, ExecManager];
List<string> names = ["QueryMat", "ExecManager"];
return CheckIfNull(objs, names, out error);
}
}
public CompHParams HParams;
public class CompHParams : OzAICompHParams
{
public ulong ValLen = 64;
public ulong KeyLen = 64;
public OzAIScalar Scale;
public OzAIRoPE.CompHParams RoPEParams;
public override bool SetDefaults(OzAIProcMode mode, out string error)
{
if (!mode.GetCPUSettings(out var cpu, out error))
return false;
if (!OzAIScalar.CreateHalf((Half)1 / Half.Sqrt((Half)KeyLen), mode, out Scale, out error))
return false;
return true;
}
public override bool IsPossible(out string error)
{
if (!CheckIfNull([RoPEParams, Scale], ["RoPEParams", "Scale"], out error))
return false;
return RoPEParams.IsPossible(out error);
}
}
public override bool IsPossible(out string error)
{
if (Params.Mem is not CompMem)
{
error = "IO memory required to be in a format for a unary op OzAIAttnHead.CompMem.";
return false;
}
if (Params.IParams is not CompIParams)
{
error = "Instance parameters required to be in a format for a OzAIAttnHead.CompIParams.";
return false;
}
if (Params.HParams is not CompHParams)
{
error = "Hyperparameters required to be in a format for a OzAIAttnHead.CompHParams.";
return false;
}
Mem = Params.Mem as CompMem;
IParams = Params.IParams as CompIParams;
HParams = Params.HParams as CompHParams;
error = null;
return true;
}
}
For IO, it takes inputs and outputs memory nodes, along with the keys and values. Its IParams consist of the query matrix, and the HParams are the scalar for the scaled dot-product attention, the RoPE HParams, the lengths of the value vectors, and the lengths of the key (and thus the query) vectors. Here is the rest of the class:
/// <summary>
/// This applies scaled dot product attention pooling.
/// Inputs: Originals, keys, values
/// </summary>
public partial class OzAIAttnHead : OzAIArchComp
{
public override string Name => "OzAIAttnHead";
public OzAIMemNode Queries;
public OzAIMemNode MyValues;
OzAIRoPE_Original RoPE;
protected override bool InitInner(out string error)
{
MyValues = new OzAIMemNode();
Queries = new OzAIMemNode();
RoPE = new OzAIRoPE_Original();
var ropeMem = new OzAICompIOMem_Unary()
{
Inputs = Queries,
Outputs = Queries
};
var ropeIParams = new OzAICompIParams_ExecOnly()
{
ExecManager = IParams.ExecManager,
};
var ropeParams = new OzAICompParams()
{
Mem = ropeMem,
IParams = ropeIParams,
HParams = HParams.RoPEParams
};
if (!RoPE.Init(ropeParams, out error))
return false;
error = null;
return true;
}
public override bool Forward(out string error)
{
var exec = IParams.ExecManager;
if (!exec.GetProcMode(out var mode, out error))
return false;
if (!initQKV(mode, out error))
return false;
// Make a destination vector for scores
if (!OzAIVector.Create(mode, out var scoresVec, out error))
return false;
Mem.Outputs.Clear();
var count = (ulong)Queries.GetArray().LongLength;
for (ulong i = 0; i < count; i++)
{
if (!getQKV(mode, i, out var query, out var keys, out var vals, out error))
return false;
if (!processQuery(i, query, keys, vals, scoresVec, out error))
return false;
}
error = null;
return true;
}
OzAIMatrixRange _keyMatRange;
bool initQKV(OzAIProcMode mode, out string error)
{
if (!getQueries(mode, out error))
return false;
if (!getKeyMatRange(mode, out _keyMatRange, out error))
return false;
error = null;
return true;
}
bool getQKV(OzAIProcMode mode, ulong i, out OzAIVector query, out OzAIMatrixRange keys, out OzAIVector[] vals, out string error)
{
query = null;
keys = null;
vals = null;
query = Queries.GetArray()[i];
_keyMatRange.Counts = new Tuple<ulong, ulong>(_keyMatRange.Counts.Item1, i + 1);
keys = _keyMatRange;
if (!MyValues.Clone(Mem.Values, out error))
return false;
vals = MyValues.GetArray();
error = null;
return true;
}
bool getQueries(OzAIProcMode mode, out string error)
{
Queries.Clear();
if (!Queries.AddVecs(mode, HParams.KeyLen, Mem.Inputs.Count, out error))
return false;
var inps = Mem.Inputs.GetArray();
var outs = Queries.GetArray();
var exec = IParams.ExecManager;
if (!exec.MatMul(inps, IParams.QueryMat, outs, out error))
return false;
if (!RoPE.Forward(out error))
return false;
error = null;
return true;
}
bool getKeyMatRange(OzAIProcMode mode, out OzAIMatrixRange res, out string error)
{
res = null;
var keys = Mem.Keys.GetArray();
if (!OzAIMatrix.Create(mode, out var keyMat, out error))
return false;
if (!keyMat.Init(keys, out error))
return false;
if (!OzAIMatrixRange.ToFull(keyMat, out res, out error))
return false;
error = null;
return true;
}
bool processQuery(ulong i, OzAIVector query, OzAIMatrixRange keyMatRange, OzAIVector[] values, OzAIVector scores, out string error)
{
var exec = IParams.ExecManager;
ulong keyCount = i + 1;
if (!scores.Init(keyCount, out error))
return false;
if (!OzAIVectorRange.ToFull(scores, out var scoresRange, out error))
return false;
if (!OzAIVectorRange.ToFull(query, out var queryRange, out error))
return false;
OzAIVectorRange[] src = [queryRange];
OzAIVectorRange[] dst = [scoresRange];
if (!exec.MatMul(src, keyMatRange, dst, out error))
return false;
if (!exec.Scale(dst, HParams.Scale, dst, out error))
return false;
if (!scores.ToDType(OzAINumType.Float32, out var fScores, out error))
return false;
if (!OzAIVectorRange.ToFull(fScores, out var fScoresRange, out error))
return false;
if (!exec.SoftMax([fScores], [fScoresRange], out error))
return false;
if (!fScores.ToDType(scores.GetNumType(), out scores, out error))
return false;
scoresRange.Vector = scores;
var res = values[0];
var scalar = scoresRange.GetNth(0);
if (!exec.Scale([res], scalar, [res], out error))
return false;
for (ulong j = 1; j < keyCount; j++)
{
var val = values[j];
scalar = scoresRange.GetNth(j);
if (!exec.Scale([val], scalar, [val], out error))
return false;
if (!exec.Add([val], [res], [res], out error))
return false;
}
Mem.Outputs.Add(res);
error = null;
return true;
}
}
This component obtains the queries and applies RoPE to them, then performs the attention itself, which I will not go into detail over. Suffice it to say, the softmax operation requires conversion to float32 because it uses the exponentiation function, which readily overflows in float16.
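The per-query loop in processQuery, including the causal key range (query i only sees keys 0..i) and the float32 softmax, can be sketched in NumPy as follows; names and shapes are illustrative:

```python
import numpy as np

def causal_attention(queries, keys, values):
    """Sketch of OzAIAttnHead.processQuery: for each query i, score against
    keys 0..i (causal), scale by 1/sqrt(k), take a softmax in float32 for
    numerical safety, and accumulate the weighted sum of values 0..i.
    queries, keys: (n, k); values: (n, v); may be float16."""
    n, k = queries.shape
    scale = 1.0 / np.sqrt(k)
    outs = []
    for i in range(n):
        scores = (keys[: i + 1] @ queries[i]) * scale   # (i+1,) raw scores
        s = scores.astype(np.float32)                   # widen before exp
        s = np.exp(s - s.max())                         # stable softmax
        weights = s / s.sum()
        outs.append(weights.astype(values.dtype) @ values[: i + 1])
    return np.stack(outs)                               # (n, v)
```

Note that subtracting the row maximum before exponentiating is an additional stabilization on top of widening to float32; the C# code above relies on the float32 conversion alone.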
Summary
In this implementation, grouped multi-headed attention is not treated as a single monolithic tensor operation, but as a hierarchy of well-defined components, each with a clear responsibility.
At the top level, OzAIMultiHeadAttn acts as a unary transformation over a sequence of token embeddings. Its role is orchestration: it instantiates multiple attention groups, dispatches the forward pass to each group, concatenates their outputs, and applies a final linear projection through the output weight matrix. From the outside, this component behaves exactly like a standard self-attention layer, while internally exposing how grouping and output aggregation are handled.
Each OzAIAttnGroup represents a group of attention heads that share the same key and value projections. The group is responsible for computing keys and values once per group, applying RoPE to the keys, and feeding those shared representations into each head. This mirrors the mathematical idea of grouped attention, where heads are partitioned to reduce parameter count and improve cache locality while still allowing multiple attention subspaces.
Within a group, OzAIAttnHead performs the actual scaled dot-product attention. A head computes its own queries, applies RoPE to them, calculates attention scores against the group’s keys, applies scaling and softmax, and finally produces a weighted sum of the values. This component is where causal masking, numerical stability (via float32 softmax), and per-token accumulation are explicitly handled.
Supporting these core components are the memory abstractions (OzAIMemNode, vector and matrix ranges) and the execution manager, which together decouple algorithmic intent from execution details such as data layout, precision, and device backend. This separation makes the attention logic explicit while still allowing optimized implementations underneath.
Taken together, this design shows that modern attention mechanisms do not require opaque tensor gymnastics to be expressive or efficient. By breaking attention into groups, heads, and well-scoped matrix and vector operations, the implementation makes the data flow and mathematics of grouped multi-headed attention transparent—while remaining faithful to the behavior expected from contemporary transformer architectures.
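As a closing illustration, the whole data flow described above can be condensed into a short NumPy reference. RoPE is omitted for brevity, and the weight shapes are hypothetical stand-ins for the OzAI matrices:

```python
import numpy as np

def grouped_attention(x, Wq, Wk, Wv, Wo):
    """Compact sketch of grouped multi-headed attention: keys and values are
    projected once per group and shared by that group's heads, while each
    head has its own query projection. x: (n, d); Wq: (g, h, k, d);
    Wk: (g, k, d); Wv: (g, v, d); Wo: (g*h*v, d)."""
    g, h, k, d = Wq.shape
    n = x.shape[0]
    K = np.einsum("nd,gkd->gnk", x, Wk)              # shared keys per group
    V = np.einsum("nd,gvd->gnv", x, Wv)              # shared values per group
    Q = np.einsum("nd,ghkd->ghnk", x, Wq)            # per-head queries
    heads = []
    for gi in range(g):
        for hi in range(h):
            scores = Q[gi, hi] @ K[gi].T / np.sqrt(k)            # (n, n)
            mask = np.tril(np.ones((n, n), dtype=bool))          # causal mask
            scores = np.where(mask, scores, -np.inf)
            w = np.exp(scores - scores.max(axis=1, keepdims=True))
            w /= w.sum(axis=1, keepdims=True)                    # softmax rows
            heads.append(w @ V[gi])                              # (n, v)
    return np.concatenate(heads, axis=1) @ Wo                    # (n, d)
```

Each line here corresponds to a component above: the K/V einsums to OzAIAttnGroup, the inner loop body to OzAIAttnHead, and the final concatenation and projection to OzAIMultiHeadAttn.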