Implementing an efficient AI inference engine in C#
- Stage 2 -
This page provides information on the current development of my AI inference engine for Large Language Models. The first implementation is documented on the following webpage:
How to build an AI Inference Engine for Large Language Models in C#
Key Changes
I achieved the first working version of the inference engine around a year ago, but I was nowhere near content with its state. The main problem was that the engine was very slow: I had only completed the first stage of my goals, and it was now time to deliver on the remaining optimizations, which would make this project truly impressive. Progress was slower than before, since I had to spend more time studying for my A-Levels and preparing for various university entrance exams.
There were two key changes I had to make:
1) I had to move my memory model away from C# primitives, which were not designed for the large batches I would be using them in, towards native memory management across multiple devices and threads.
2) I had to change my execution model to accommodate more devices and more parallelism.
How I plan to achieve these two changes is detailed in the next section.
Architectural Changes
Memory Model
For the memory model, I plan to stick with my vector/matrix decomposition approach; however, I want the data stored by a vector to be able to move from device to device. I also plan to stop storing data as arrays of specific data types. Instead, the data will be a pointer to a sequence of bytes, and the operation itself will determine which data type the bytes are interpreted as. This removes the burden of reimplementing the vector class for each data type.
The way I will implement this is by creating an abstraction between the data and the vector itself, which I call a data unit. A data unit will be a memory block of less than or equal to 2MB (for caching convenience), and a vector can be one or more of these chained together. The data unit itself is abstracted away from where it is stored through composition. Each data unit will contain one (if the data is variable) or more (if the data is constant) data storage classes, which represent the storage of a data unit on one device. The data storages will be managed by a class representing their respective storage devices.
This model also means that larger matrices and vectors will be tiled into smaller segments (this is good for performance, because smaller memory blocks are easier to manage than monolithic arrays). This is trivial to implement for vectors. For matrices, however, it will not be implemented by cutting up one large row-major array; instead, the matrix will be tiled into square sub-matrices (imagine a grid laid over the matrix that cuts out the tiles), each of which fits into a single data unit and is stored in row-major order. This will make it possible to work conveniently with vectors and matrices spread across different devices.
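To make the grid tiling concrete, here is a small sketch (not the engine's code) that maps an element (row, col) of a Width × Height matrix to the tile holding it and to its offset inside that tile. It assumes tiles are numbered row by row across the grid (tileX + tileY * tilesPerRow, the same ordering OzAIMatrix.Allocate uses later on this page) and that each tile is stored row-major with its own width as the stride, so tiles in the last grid column are narrower. TileGrid and Locate are hypothetical names.

```csharp
using System;

// Hypothetical helper illustrating square-grid tiling of a Width x Height matrix.
public readonly struct TileGrid
{
    public readonly long Width, Height, Side;   // matrix size and tile side length, in elements
    public readonly long TilesPerRow, TilesPerCol;

    public TileGrid(long width, long height, long side)
    {
        Width = width; Height = height; Side = side;
        TilesPerRow = (width + side - 1) / side;   // round up: the last tile column may be narrower
        TilesPerCol = (height + side - 1) / side;  // round up: the last tile row may be shorter
    }

    // Width (in elements) of the tiles in grid column tileX; only the last column can be narrower.
    public long TileWidth(long tileX) => Math.Min(Side, Width - tileX * Side);

    // Returns the tile index (numbered row by row over the grid) and the element offset inside that tile.
    public (long Tile, long Offset) Locate(long row, long col)
    {
        if (row < 0 || row >= Height || col < 0 || col >= Width)
            throw new ArgumentOutOfRangeException(nameof(row), "element outside the matrix");
        long tileX = col / Side, tileY = row / Side;
        long localRow = row % Side, localCol = col % Side;
        long tile = tileX + tileY * TilesPerRow;
        // Inside a tile the data is row-major, using the tile's own width as the stride.
        long offset = localRow * TileWidth(tileX) + localCol;
        return (tile, offset);
    }
}
```

For example, with a 10 × 10 matrix and a side of 4 there are 3 × 3 tiles, the last tile column is only 2 elements wide, and element (5, 9) lands in tile 5 at offset 3.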
Execution Model
My current approach of executing architecture components sequentially misses out on a large amount of parallelism. Moreover, threads having to wait for each other means that a lot of potential processing time is wasted. I will therefore convert my execution model to be based on a compute graph instead (as I did in my neural network simulation project).
The way this will work is that the execution manager will simply accept a root node of the compute graph, and on each inference it will traverse the graph from there, with each OzAIOperation constituting a vertex of the graph. All OzAIOperations whose dependencies have been met will be stored in a single linked list, access to which will be synchronized (no concurrent modifications are allowed). A thread first traverses this list and locks it only once it has found a task it can execute. It then reserves all the relevant data units and releases the lock. Finally, it moves the relevant data units to the correct device and performs the operation. Completing the operation yields the operations it is connected to in the compute graph whose dependencies are now all met, and these are appended to the end of the list.
This way, the maximum number of threads would be working at all times. An OzAIOperation takes a range over a single data unit for each operand, which naturally yields sub-vector (or, to use the tensor equivalent, sub-row) level parallelism. Furthermore, to make this convenient, I will replace the architecture component model with one where these primitive AI operations are organized by OzAITasks. This way, level by level, I can rebuild my code into a maximally parallelizable architecture that delivers the best performance possible, even on low-end hardware.
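The scheduling scheme above can be sketched as a small worker loop. This is a simplified illustration, not the engine's executor: GraphNode and GraphScheduler are hypothetical names, a worker simply takes the first ready node under a lock instead of traversing the list for a node whose data units are available, and data-unit reservation and device transfers are left out. Each node counts its unfinished dependencies, and the worker that finishes the last dependency appends the successor to the ready list.

```csharp
using System;
using System.Collections.Generic;
using System.Threading;

// Hypothetical vertex of the compute graph with a countdown of unfinished dependencies.
public sealed class GraphNode
{
    public readonly string Name;
    public readonly Action Work;
    public readonly List<GraphNode> Next = new();
    internal int Remaining;   // dependencies that have not finished yet

    public GraphNode(string name, Action work) { Name = name; Work = work; }

    // Graph construction is assumed to be single-threaded.
    public void LinkTo(GraphNode next) { Next.Add(next); next.Remaining++; }
}

public static class GraphScheduler
{
    // Runs all nodes reachable from `roots` on `threads` workers.
    // `totalNodes` must equal the number of reachable nodes, otherwise workers wait forever.
    public static void Run(IEnumerable<GraphNode> roots, int totalNodes, int threads)
    {
        var ready = new LinkedList<GraphNode>(roots);  // nodes whose dependencies are all met
        var gate = new object();                        // guards `ready` and `finished`
        int finished = 0;

        void Worker()
        {
            while (true)
            {
                GraphNode node;
                lock (gate)
                {
                    while (ready.Count == 0 && finished < totalNodes)
                        Monitor.Wait(gate);             // sleep until work appears or the graph is done
                    if (ready.Count == 0) return;       // every node has finished
                    node = ready.First.Value;
                    ready.RemoveFirst();
                }
                node.Work();                            // execute outside the lock
                var unlocked = new List<GraphNode>();
                foreach (var next in node.Next)
                    if (Interlocked.Decrement(ref next.Remaining) == 0)
                        unlocked.Add(next);             // this node was the last missing dependency
                lock (gate)
                {
                    foreach (var n in unlocked) ready.AddLast(n);
                    finished++;
                    Monitor.PulseAll(gate);             // wake idle workers
                }
            }
        }

        var pool = new Thread[threads];
        for (int i = 0; i < threads; i++) { pool[i] = new Thread(Worker); pool[i].Start(); }
        foreach (var t in pool) t.Join();
    }
}
```

For a diamond graph A → {B, C} → D, B and C can run concurrently on different workers, while D only becomes ready once both have finished.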
Technical Implementation Preview
The New OzAIOperation
I decided to make the list of 'primitive' vector operations even more fundamental, removing operations like Softmax and RMS Norm from the list of supported operations. To promote sub-vector level parallelism, these will be implemented by combining other operations that are parallelizable, because they do not contain any summation reductions across data units. Moreover, the data type used is now part of the operation itself (the F and H suffixes denote float and half precision). Finally, I have introduced a wildcard operation called Exec, which executes a function (calls execute on an OzAIExecutable), enabling any operation that was possible in the previous approach as well as dynamic changes to the graph. The new list of possible operations is as follows:
public enum OzAIOperationType
{
    Exec,
    AdditionF,
    AdditionH,
    DivF,
    DivH,
    HadF,
    HadH,
    MatMulF,
    MatMulH,
    MaxF,
    MaxH,
    RoPEF,
    ScaleF,
    ScaleH,
    SumF,
    SumH,
    SquaredSumF,
    SquaredSumH,
    Swish1F,
    ExpAndSumF
}
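As an illustration of how a removed operation such as Softmax can be rebuilt from chunk-local steps, here is a sketch over plain float arrays split into chunks that stand in for data units. The mapping to MaxF, ExpAndSumF, ScaleF and SquaredSumF is an assumption made for illustration (in particular, ExpAndSumF is assumed to compute exp(x - max) in place and return the chunk's partial sum); ChunkedOps is a hypothetical name, and the RMS norm omits the learned weight. Each per-chunk loop could run as an independent operation, and only the combine steps, which reduce one value per chunk, are serial.

```csharp
using System;
using System.Linq;

// Sketch: Softmax and RMS norm assembled from per-chunk steps plus tiny combine steps.
public static class ChunkedOps
{
    public static float[] Softmax(float[] x, int chunk)
    {
        var y = (float[])x.Clone();
        int chunks = (y.Length + chunk - 1) / chunk;
        Range R(int i) => (i * chunk)..Math.Min(y.Length, (i + 1) * chunk);

        // Step 1 (MaxF-like, one op per chunk): partial maxima, then combine them.
        float max = Enumerable.Range(0, chunks).Select(i => y[R(i)].Max()).Max();

        // Step 2 (ExpAndSumF-like, per chunk): y = exp(y - max) in place, plus a partial sum.
        float sum = 0f;
        for (int i = 0; i < chunks; i++)
        {
            var span = y.AsSpan(R(i));
            float partial = 0f;
            for (int j = 0; j < span.Length; j++) { span[j] = MathF.Exp(span[j] - max); partial += span[j]; }
            sum += partial;   // combine step: one value per chunk
        }

        // Step 3 (ScaleF-like, per chunk): multiply by 1 / sum.
        float inv = 1f / sum;
        for (int i = 0; i < chunks; i++)
        {
            var span = y.AsSpan(R(i));
            for (int j = 0; j < span.Length; j++) span[j] *= inv;
        }
        return y;
    }

    // RMS norm: SquaredSumF-like partial sums per chunk, then a ScaleF-like pass.
    public static float[] RmsNorm(float[] x, int chunk, float eps = 1e-5f)
    {
        var y = (float[])x.Clone();
        int chunks = (y.Length + chunk - 1) / chunk;
        float sq = 0f;
        for (int i = 0; i < chunks; i++)
        {
            float partial = 0f;
            int start = i * chunk, end = Math.Min(y.Length, start + chunk);
            for (int j = start; j < end; j++) partial += y[j] * y[j];
            sq += partial;    // combine step
        }
        float scale = 1f / MathF.Sqrt(sq / y.Length + eps);
        for (int j = 0; j < y.Length; j++) y[j] *= scale;
        return y;
    }
}
```

Subtracting the global maximum before exponentiating keeps the exponentials from overflowing, which is why the max step has to finish before the exp step starts.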
The operation class itself has changed to accommodate the parallelized compute-graph approach. Note also that I have made a greater effort to document my code: I had AI write the comments and summaries. I tried to have AI write other code as well, but it was faster to write it myself, since the AI has seen no examples of the novel solution I intend to create. The code for OzAIOperation is below:
public abstract class OzAIOperation : OzAICheckable
{
    public abstract OzAIOperationType Type { get; }

    private int _taken; // 0 = false, 1 = true
    public bool Taken => Volatile.Read(ref _taken) == 1;

    // Thread-safe attempt to take the operation. Returns true if this call set Taken to true.
    public bool Take()
    {
        // Atomically set _taken to 1 and return true if it was 0 before.
        return Interlocked.Exchange(ref _taken, 1) == 0;
    }

    private uint _doneCount;
    public uint DoneCount => _doneCount;
    public uint DepCount { get; private set; }

    /// <summary>
    /// Marks this operation as done, processes all forward-linked operations and returns
    /// a list of operations whose dependencies are now satisfied.
    /// The method is thread-safe: the forward-link queue is a lock-free
    /// <see cref="ConcurrentQueue{T}"/> and dependency counts are updated atomically,
    /// so each successor is reported as ready by exactly one of its predecessors.
    /// </summary>
    /// <returns>
    /// A <see cref="LinkedList{OzAIOperation}"/> containing the next operations
    /// that have become ready (i.e. <c>incAndCheckDoneCount()</c> returned true).
    /// </returns>
    public LinkedList<OzAIOperation> Done()
    {
        var readyOps = new LinkedList<OzAIOperation>();
        // Enumerate the forward-link queue (snapshot semantics are fine here).
        foreach (var nextOp in _nextQueue)
        {
            // Count this operation as one finished dependency of nextOp; the call that
            // completes the last dependency adds nextOp to the ready list.
            if (nextOp.incAndCheckDoneCount())
            {
                readyOps.AddLast(nextOp);
            }
        }
        // Reset this operation's state for potential reuse.
        Reset();
        return readyOps;
    }

    // Atomically increments DoneCount and returns true only for the call that
    // brings it up to DepCount, i.e. when the last dependency has just finished.
    private bool incAndCheckDoneCount()
    {
        return Interlocked.Increment(ref _doneCount) == DepCount;
    }

    private readonly ConcurrentQueue<OzAIOperation> _nextQueue = new();

    /// <summary>
    /// Adds the specified operation to the forward-link queue and increments the
    /// dependency count of the added operation. The enqueue is thread-safe, but the
    /// DepCount increment is not atomic, so the graph should be built on one thread.
    /// </summary>
    /// <param name="op">The operation to link as a successor.</param>
    public void ForwardLink(OzAIOperation op)
    {
        _nextQueue.Enqueue(op);
        op.DepCount++;
    }

    /// <summary>
    /// Resets the operation state in a lock-free, thread-safe manner.
    /// Sets <see cref="DoneCount"/> to 0 and marks the operation as not taken.
    /// The <see cref="Taken"/> flag is cleared last to ensure any thread that
    /// observes the reset will see <c>Taken == false</c>.
    /// </summary>
    public void Reset()
    {
        // Reset the done count (non-atomic write is acceptable here).
        _doneCount = 0;
        // Reset the taken flag atomically; this write happens last.
        Interlocked.Exchange(ref _taken, 0);
    }

    protected bool CheckRanges(List<OzAIDataRange> ranges, List<string> names, out string error)
    {
        if (ranges == null)
        {
            error = $"Could not check if {Type} is possible, because no ranges were provided to check.";
            return false;
        }
        if (names == null)
        {
            error = $"Could not check if {Type} is possible, because no names were provided for the ranges.";
            return false;
        }
        if (ranges.Count != names.Count)
        {
            error = $"Could not check if {Type} is possible, because not all ranges were given a name.";
            return false;
        }
        for (int i = 0; i < ranges.Count; i++)
        {
            var range = ranges[i].MemRange;
            var name = names[i];
            if (!CheckAddr(range.Addr, name, out error))
                return false;
            if (!CheckSize(range.Size, name, out error))
                return false;
        }
        error = null;
        return true;
    }

    public bool CheckAddr(nint addr, string name, out string error)
    {
        // Note: inside an interpolated string the placeholder is {name}, not ${name}.
        if (addr < 0)
        {
            error = $"{Type} is not possible, because the address provided for {name} is negative.";
            return false;
        }
        if (addr == 0)
        {
            error = $"{Type} is not possible, because the address provided for {name} is 0.";
            return false;
        }
        if (addr == nint.MaxValue)
        {
            error = $"{Type} is not possible, because the address provided for {name} is the max possible value for an nint.";
            return false;
        }
        error = null;
        return true;
    }

    public bool CheckSize(nuint size, string name, out string error)
    {
        if (size == 0)
        {
            error = $"{Type} is not possible, because the size provided for {name} is 0.";
            return false;
        }
        if (size == nuint.MaxValue)
        {
            error = $"{Type} is not possible, because the size provided for {name} is the max possible value for an nuint.";
            return false;
        }
        error = null;
        return true;
    }

    public abstract bool IsPossible(out string error);
}
The Take method is used to claim a task atomically; only then is the linked list storing the ready tasks locked to reserve the corresponding data units. The Done method increments the number of completed dependencies for all forward connections of this task and returns all the tasks that have newly become available. Finally, ForwardLink makes the current task a dependency of the next one. Here is an implementation of this abstract class as an example:
public abstract class OzAIBinaryOp : OzAIOperation
{
    public OzAIDataRange Source1, Source2, Destination;

    public override bool IsPossible(out string error)
    {
        if (!CheckRanges([Source1, Source2, Destination], ["Source 1", "Source 2", "Destination"], out error))
            return false;
        error = null;
        return true;
    }
}
OzAIDataRange is a struct (a range over a data unit). I also learnt some new syntax along the way, such as the collection expressions passed to CheckRanges above, and most of the small data types in the new classes are structs like this one. The small amount of code that each inheritor of OzAIOperation needs already shows that this is a better approach. The BinaryOp class is in turn inherited by other classes, which makes the implementation of every binary operation as simple as this:
public class OzAIAdditionF : OzAIBinaryOp
{
    public override OzAIOperationType Type => OzAIOperationType.AdditionF;
}
A similar model is followed for all other implementations. These will then be ForwardLinked together in OzAITask.
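To show the Take/Done/ForwardLink protocol in isolation, here is a stripped-down stand-in for OzAIOperation (MiniOp is a hypothetical name; it keeps only the scheduling state and uses a plain List instead of a ConcurrentQueue, assuming the graph is built on one thread). The key property is that the dependency counter is incremented atomically and the successor is reported as ready only by the call that delivers its last missing dependency, while Interlocked.Exchange guarantees that exactly one worker wins the claim on a task.

```csharp
using System.Collections.Generic;
using System.Threading;

// Hypothetical minimal operation keeping only the scheduling state of OzAIOperation.
public sealed class MiniOp
{
    private int _taken;        // 0 = free, 1 = claimed by a worker
    private int _doneCount;    // how many predecessors have finished
    public int DepCount { get; private set; }
    private readonly List<MiniOp> _next = new();

    // Exactly one caller wins the claim, however many threads race for it.
    public bool Take() => Interlocked.Exchange(ref _taken, 1) == 0;

    // Single-threaded graph construction: a plain increment is enough here.
    public void ForwardLink(MiniOp op) { _next.Add(op); op.DepCount++; }

    // Called by the worker that finished this op; returns the successors that just became ready.
    public List<MiniOp> Done()
    {
        var ready = new List<MiniOp>();
        foreach (var op in _next)
            if (Interlocked.Increment(ref op._doneCount) == op.DepCount)
                ready.Add(op);   // this call delivered the last missing dependency
        return ready;
    }

    // Prepares the op for the next inference pass.
    public void Reset() { Volatile.Write(ref _doneCount, 0); Volatile.Write(ref _taken, 0); }
}
```

With two predecessors linked to one successor, the first Done call returns nothing and the second returns the successor, no matter which thread finishes first.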
New Executor
This means my executor will be implemented with pointers in mind. This also makes iteration a lot more efficient, because I can simply index into pointers or increment them instead of recalculating indices all the time. This approach shows its benefits particularly in matrix multiplication, where row boundaries are crossed automatically by incrementing the pointer instead of computing x + y*width:
public class OzAICPUSharpExec : OzAICPUExecutor
{
    // Element-wise float32 addition: dst[i] = src1[i] + src2[i].
    public override unsafe void AddF(OzAIMemRange src1, OzAIMemRange src2, OzAIMemRange dst)
    {
        var pSrc1 = (float*)src1.Addr;
        var pSrc2 = (float*)src2.Addr;
        var pDst = (float*)dst.Addr;
        nuint count = dst.Size / 4; // 4 bytes per float
        for (nuint i = 0; i < count; i++)
        {
            pDst[i] = pSrc1[i] + pSrc2[i];
        }
    }

    // Element-wise half-precision addition.
    public override unsafe void AddH(OzAIMemRange src1, OzAIMemRange src2, OzAIMemRange dst)
    {
        var pSrc1 = (Half*)src1.Addr;
        var pSrc2 = (Half*)src2.Addr;
        var pDst = (Half*)dst.Addr;
        nuint count = dst.Size / 2; // 2 bytes per Half
        for (nuint i = 0; i < count; i++)
        {
            pDst[i] = pSrc1[i] + pSrc2[i];
        }
    }
    …
    // Matrix-vector product for one row-major tile: dst[i] += sum over j of mat[i, j] * src[j].
    // Results are accumulated into dst (+=), so dst must be initialized before the first call.
    public override unsafe void MatMulF(OzAIMemRange mat, OzAIMemRange src, OzAIMemRange dst)
    {
        var pMat = (float*)mat.Addr;
        var pSrc = (float*)src.Addr;
        var pDst = (float*)dst.Addr;
        nuint height = dst.Size / 4;
        nuint width = src.Size / 4;
        for (nuint i = 0; i < height; i++)
        {
            for (nuint j = 0; j < width; j++)
            {
                // pMat walks the tile linearly, so crossing a row boundary needs no index arithmetic.
                pDst[i] += *pMat++ * pSrc[j];
            }
        }
    }
    …
    public override void Swish1F(OzAIMemRange src, OzAIMemRange dst)
    {
        throw new NotImplementedException();
    }
}
Note that not all operations have been implemented yet; however, a single class can contain all the required operations (perhaps split into partial classes if need be).
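Because MatMulF accumulates into its destination, a matrix tiled as described earlier can be multiplied by a vector one tile at a time, with the partial products of all tiles in the same tile row summed into the same slice of the output. The sketch below checks this with safe code: an index k that only ever increments plays the role of pMat++, and tiles are copied out of a row-major array purely for illustration (in the engine each tile already lives in its own data unit). TiledMatVec and its methods are hypothetical names.

```csharp
using System;

public static class TiledMatVec
{
    // Safe-code analogue of MatMulF: k walks the tile linearly like pMat++,
    // so no x + y * width arithmetic is needed. Results are ADDED to dst.
    public static void MatVecAccumulate(ReadOnlySpan<float> mat, ReadOnlySpan<float> src, Span<float> dst)
    {
        int k = 0;
        for (int i = 0; i < dst.Length; i++)
            for (int j = 0; j < src.Length; j++)
                dst[i] += mat[k++] * src[j];
    }

    // Multiplies a row-major width x height matrix by vec, one side x side tile at a time.
    public static float[] Multiply(float[] matrix, int width, int height, float[] vec, int side)
    {
        var dst = new float[height];   // starts at zero because tiles accumulate into it
        for (int ty = 0; ty < height; ty += side)
            for (int tx = 0; tx < width; tx += side)
            {
                int h = Math.Min(side, height - ty), w = Math.Min(side, width - tx);
                var tile = new float[h * w];   // copy the tile out in row-major order
                for (int r = 0; r < h; r++)
                    Array.Copy(matrix, (ty + r) * width + tx, tile, r * w, w);
                MatVecAccumulate(tile, vec.AsSpan(tx, w), dst.AsSpan(ty, h));
            }
        return dst;
    }
}
```

The result matches an untiled matrix-vector product, including for edge tiles that are narrower or shorter than the tile side.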
New Tensor Classes
Finally, new tensor classes will be used, based on the memory model discussed earlier. Here is the simpler of the two, OzAIVector:
public class OzAIVector : OzAIData
{
    public OzAIVector(nuint count, OzAINumType type)
    {
        _count = count;
        _type = type;
        _units = new();
        var size = Size;
        OzAIDataUnit next;
        // Every unit except the last holds a full tile.
        while (size > OzAIMemManager.TileSize)
        {
            size -= OzAIMemManager.TileSize;
            next = new OzAIDataUnit();
            _units.Add(next);
        }
        // The last unit holds whatever is left (up to a full tile).
        _remainder = size;
        if (size == 0) return;
        next = new OzAIDataUnit();
        _units.Add(next);
    }

    List<OzAIDataUnit> _units;
    nuint _count;
    nuint _remainder;
    public override nuint Count => _count;
    OzAINumType _type;
    public override OzAINumType Type => _type;
    public nuint Size => Type.GetSize(Count);

    public override void Allocate()
    {
        if (_units.Count == 0) return; // empty vector: nothing to allocate
        var end = _units.Count - 1;
        // The last unit may be smaller than a full tile.
        var unit = _units[end];
        if (_remainder == 0)
            OzAIMemManager.AllocUnit(unit, OzAIMemManager.TileSize);
        else
            OzAIMemManager.AllocUnit(unit, _remainder);
        // Every other unit is a full tile.
        for (int i = 0; i < end; i++)
        {
            OzAIMemManager.AllocUnit(_units[i]);
        }
    }

    public override void Free()
    {
        OzAIMemManager.Free(_units);
    }
}
As you can see, memory management will be improved further with separate Allocate and Free functions. These will be called from Exec operations. The OzAIMemManager class is discussed in the next section. The tiling into data units for the matrix class, however, is more complex, following the approach discussed earlier:
public abstract class OzAIMatrix : OzAIData
{
    public OzAIMatrix(nuint width, nuint height, OzAINumType type)
    {
        Width = width;
        Height = height;
        _type = type;
        _units = new();
        pack();
    }

    // Splits the matrix into square tiles, each of which fits into one data unit.
    void pack()
    {
        if (packOne()) return;
        var sideLen = Type.GetOptimalSquare(OzAIMemManager.TileSize);
        _tileWidth = Math.Min(Width, sideLen);
        _tileHeight = Math.Min(Height, sideLen);
        // Number of tiles per row and per column; the last ones may be partial.
        _tileRowSize = DivAndRoundUp(Width, _tileWidth, out _rowRemainder);
        _tileColSize = DivAndRoundUp(Height, _tileHeight, out _colRemainder);
        var tileCount = _tileRowSize * _tileColSize;
        for (nuint i = 0; i < tileCount; i++)
        {
            var next = new OzAIDataUnit();
            _units.Add(next);
        }
    }

    // Matrices that fit into a single data unit are not tiled.
    bool packOne()
    {
        if (Size > OzAIMemManager.TileSize) return false;
        _colRemainder = Height;
        _rowRemainder = Width;
        var next = new OzAIDataUnit();
        _units.Add(next);
        return true;
    }

    // Returns ceil(a / b); remainder is a mod b (0 if b divides a).
    nuint DivAndRoundUp(nuint a, nuint b, out nuint remainder)
    {
        var res = a / b;
        if (res * b < a)
        {
            remainder = a - res * b;
            return 1 + res;
        }
        remainder = 0;
        return res;
    }

    List<OzAIDataUnit> _units;
    public readonly nuint Width;
    nuint _rowRemainder;
    nuint _tileRowSize;
    nuint _tileWidth;
    public readonly nuint Height;
    nuint _colRemainder;
    nuint _tileColSize;
    nuint _tileHeight;
    public override nuint Count => Width * Height;
    OzAINumType _type;
    public override OzAINumType Type => _type;
    public nuint Size => Type.GetSize(Count);

    public override void Allocate()
    {
        OzAIDataUnit unit;
        if (_units.Count == 1)
        {
            unit = _units[0];
            OzAIMemManager.AllocUnit(unit, Size);
            return;
        }
        // Bottom-right corner tile: possibly narrower and shorter than the others.
        unit = _units[^1];
        var lastHeight = _colRemainder == 0 ? _tileHeight : _colRemainder;
        var lastWidth = _rowRemainder == 0 ? _tileWidth : _rowRemainder;
        var lastSize = lastHeight * Type.GetSize(lastWidth);
        OzAIMemManager.AllocUnit(unit, lastSize);
        nuint xEnd = _tileRowSize - 1;
        nuint yEnd = _tileColSize - 1;
        // Core tiles: full width and full height.
        var coreSize = _tileHeight * Type.GetSize(_tileWidth);
        for (nuint x = 0; x < xEnd; x++)
        {
            for (nuint y = 0; y < yEnd; y++)
            {
                unit = _units[(int)(x + y * _tileRowSize)];
                OzAIMemManager.AllocUnit(unit, coreSize);
            }
        }
        // Last row of tiles: full width, possibly shorter.
        var lastRowSize = lastHeight * Type.GetSize(_tileWidth);
        for (nuint x = 0; x < xEnd; x++)
        {
            unit = _units[(int)(x + yEnd * _tileRowSize)];
            OzAIMemManager.AllocUnit(unit, lastRowSize);
        }
        // Last column of tiles: full height, possibly narrower.
        var lastColSize = _tileHeight * Type.GetSize(lastWidth);
        for (nuint y = 0; y < yEnd; y++)
        {
            unit = _units[(int)(xEnd + y * _tileRowSize)];
            OzAIMemManager.AllocUnit(unit, lastColSize);
        }
    }

    public override void Free()
    {
        OzAIMemManager.Free(_units);
    }
}
Both inherit from the OzAIData class, which will form the backbone of data management in place of the former memory node class.
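GetOptimalSquare is not shown above, so here is a worked sketch under the assumption that it returns the largest side s with s × s × elementSize ≤ TileSize. For float (4 bytes) that gives s = 724, since 724² = 524,176 ≤ 524,288 < 725²; for Half (2 bytes) it gives exactly 1024. TilePlan is a hypothetical helper that also computes the tile grid and the element counts of the four tile kinds handled in OzAIMatrix.Allocate.

```csharp
using System;

public static class TilePlan
{
    public const long TileSize = 2L * 1024 * 1024;   // 2 MiB, matching OzAIMemManager.TileSize

    // Assumed behaviour of GetOptimalSquare: largest side s with s * s * elementSize <= TileSize.
    public static long OptimalSide(long elementSize)
    {
        long s = (long)Math.Sqrt(TileSize / elementSize);
        while (s * s * elementSize > TileSize) s--;               // correct a sqrt that rounded up
        while ((s + 1) * (s + 1) * elementSize <= TileSize) s++;  // or one that rounded down
        return s;
    }

    // Tiles per row/column and the element count of each tile kind in OzAIMatrix.Allocate.
    public static (long TilesX, long TilesY, long Core, long LastRow, long LastCol, long Corner)
        Plan(long width, long height, long elementSize)
    {
        long side = OptimalSide(elementSize);
        long tw = Math.Min(width, side), th = Math.Min(height, side);
        long tilesX = (width + tw - 1) / tw, tilesY = (height + th - 1) / th;
        long lastW = width - (tilesX - 1) * tw;    // width of the last tile column
        long lastH = height - (tilesY - 1) * th;   // height of the last tile row
        return (tilesX, tilesY, tw * th, tw * lastH, lastW * th, lastW * lastH);
    }
}
```

For a 4096 × 4096 float matrix this yields a 6 × 6 grid: 25 full 724 × 724 tiles, 5 tiles in the last row, 5 in the last column (476 elements on their short side, since 4096 - 5 × 724 = 476) and one 476 × 476 corner tile, which together account for all 4096² elements.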
New Memory Management
The final change that I can present in this form is the new set of memory management classes, namely the memory manager, which helps manage data units:
public partial class OzAIMemManager
{
    /// <summary>
    /// 64 for 512-bit AVX-512 alignment.
    /// </summary>
    public static nuint AlignmentBytes = 64;

    /// <summary>
    /// Maximum size of a single data unit: 2 MiB.
    /// </summary>
    public const nuint TileSize = 1024 * 1024 * 2;

    public static OzAIRAM RAM = new OzAIRAM();

    // Registry of allocated units. It must be initialized before use, and it is guarded
    // by a lock because Exec operations may allocate and free from several threads.
    static readonly HashSet<OzAIDataUnit> _units = new();

    public static void AllocUnit(OzAIDataUnit unit, nuint size = TileSize)
    {
        unit.StorageDevice = RAM;
        unit.StorageID = RAM.Allocate(size);
        lock (_units) _units.Add(unit);
    }

    public static void FreeUnit(OzAIDataUnit unit)
    {
        if (!unit.IsAlloced) return; // ignore units that are not allocated (e.g. double frees)
        unit.StorageDevice.Free(unit.StorageID);
        unit.StorageDevice = null;
        unit.StorageID = -1;
        lock (_units) _units.Remove(unit);
    }

    public static OzAIDataUnit AllocOne(nuint size = TileSize)
    {
        var res = new OzAIDataUnit();
        AllocUnit(res, size);
        return res;
    }

    // Allocates `size` bytes as a chain of full tiles followed by one (possibly partial) tile.
    public static List<OzAIDataUnit> Alloc(nuint size)
    {
        List<OzAIDataUnit> res = new List<OzAIDataUnit>();
        OzAIDataUnit unit;
        while (size > 0)
        {
            if (size > TileSize)
            {
                unit = AllocOne();
                res.Add(unit);
                size -= TileSize;
                continue;
            }
            unit = AllocOne(size);
            res.Add(unit);
            break;
        }
        return res;
    }

    public static void Free(List<OzAIDataUnit> units)
    {
        foreach (var unit in units)
        {
            FreeUnit(unit);
        }
    }
}
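Since data units will be allocated and freed from Exec operations running on several worker threads, the bookkeeping of live units has to be thread-safe. The sketch below shows one way to do that with a ConcurrentDictionary used as a concurrent set, together with the same size-splitting rule as Alloc. MemManagerSketch and UnitSketch are hypothetical names, and Marshal.AllocHGlobal stands in for NativeMemory.AlignedAlloc only so that the sketch needs no unsafe code (it does not give the 64-byte alignment).

```csharp
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Runtime.InteropServices;

// Hypothetical stand-in for a data unit: a native pointer and a size.
public sealed class UnitSketch
{
    public IntPtr Addr;
    public nuint Size;
}

public static class MemManagerSketch
{
    public const nuint TileSize = 2 * 1024 * 1024;

    // Concurrent set of live units (the byte value is unused); safe to update from many workers.
    static readonly ConcurrentDictionary<UnitSketch, byte> _live = new();
    public static int LiveCount => _live.Count;

    // Same rule as Alloc: full tiles first, then one final (possibly partial) tile.
    public static List<nuint> SplitSizes(nuint size)
    {
        var sizes = new List<nuint>();
        while (size > TileSize) { sizes.Add(TileSize); size -= TileSize; }
        if (size > 0) sizes.Add(size);
        return sizes;
    }

    public static UnitSketch AllocOne(nuint size)
    {
        var unit = new UnitSketch { Addr = Marshal.AllocHGlobal((nint)size), Size = size };
        _live.TryAdd(unit, 0);
        return unit;
    }

    public static void FreeUnit(UnitSketch unit)
    {
        if (!_live.TryRemove(unit, out _)) return;   // not live: makes double frees harmless
        Marshal.FreeHGlobal(unit.Addr);
        unit.Addr = IntPtr.Zero;
    }
}
```

For example, 5 MiB splits into units of 2 MiB, 2 MiB and 1 MiB, and exactly 4 MiB splits into two full tiles.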
Larger manager classes will also replace most of OzAIProcMode: instead of statically storing information, each of these classes will actively change the program state and adapt to it (I plan to add adaptive memory management). Finally, the DataUnit and OzAIRAM classes are needed to make sense of my new data storage system.
Here is the data unit class:
public class OzAIDataUnit
{
    public int StorageID = -1;
    public OzAIStorageDevice StorageDevice = null;

    public bool IsAlloced
    {
        get
        {
            return StorageID != -1;
        }
    }

    public OzAIMemRange GetMemRange()
    {
        var ds = StorageDevice.GetStorage(StorageID);
        var res = ds.GetMemRange();
        return res;
    }
}
And here is the class managing RAM allocation, together with the corresponding data storage classes:
public class OzAIRAM : OzAIStorageDevice
{
    public override OzAIRAMDStorage InnerAllocate(nuint size)
    {
        var res = new OzAIRAMDStorage(size);
        res.Allocate();
        return res;
    }

    public override void InnerFree(OzAIDataStorage ds)
    {
        ds.Free();
    }
}
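The base class OzAIStorageDevice is not shown, but OzAIDataUnit only stores a device and an integer StorageID, which implies that the device maps IDs to storages. The sketch below is a hypothetical illustration of that indirection (StorageDeviceSketch and StorageSketch are invented names, and Marshal.AllocHGlobal again stands in for the aligned allocation): IDs start at 0, so they never collide with the -1 "not allocated" sentinel used by OzAIDataUnit, and a concurrent dictionary lets several threads allocate and free at once.

```csharp
using System;
using System.Collections.Concurrent;
using System.Runtime.InteropServices;
using System.Threading;

// Hypothetical storage owning one native block.
public sealed class StorageSketch
{
    public IntPtr Addr { get; private set; }
    public nuint Size { get; }
    public StorageSketch(nuint size) { Size = size; Addr = Marshal.AllocHGlobal((nint)size); }
    public void Free() { if (Addr == IntPtr.Zero) return; Marshal.FreeHGlobal(Addr); Addr = IntPtr.Zero; }
}

// Hypothetical device: maps integer storage IDs to storages, so a data unit only needs (device, ID).
public class StorageDeviceSketch
{
    private readonly ConcurrentDictionary<int, StorageSketch> _storages = new();
    private int _nextId = -1;

    public int Allocate(nuint size)
    {
        int id = Interlocked.Increment(ref _nextId);   // 0, 1, 2, ... never the -1 sentinel
        _storages[id] = new StorageSketch(size);
        return id;
    }

    // Returns null for IDs that were never allocated or have been freed.
    public StorageSketch GetStorage(int id) => _storages.TryGetValue(id, out var s) ? s : null;

    public void Free(int id)
    {
        if (_storages.TryRemove(id, out var s)) s.Free();
    }
}
```

Keeping the ID indirection means a data unit can later be moved to another device by swapping its (device, ID) pair without the vector or matrix that owns it noticing.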
Here is the abstract data storage class, which uses the native-sized integer type nint (introduced in C# 9) to store memory addresses:
public abstract class OzAIDataStorage
{
    public nint Addr { get; protected set; }
    public nuint Size { get; protected set; }

    public void Allocate()
    {
        if (Addr != 0) return;
        Addr = InnerAllocate();
    }

    protected abstract nint InnerAllocate();

    public void Free()
    {
        if (Addr == 0) return;
        InnerFree();
        Addr = 0;
    }

    public OzAIMemRange GetMemRange()
    {
        return new OzAIMemRange { Addr = this.Addr, Size = this.Size };
    }

    protected abstract void InnerFree();
}
Finally, here is the RAM implementation:
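using System.Runtime.InteropServices; // NativeMemory

public class OzAIRAMDStorage : OzAIDataStorage
{
    public OzAIRAMDStorage(nuint size)
    {
        Size = size;
    }

    // Allocates Size bytes of native memory aligned to OzAIMemManager.AlignmentBytes.
    protected override nint InnerAllocate()
    {
        unsafe
        {
            return (nint)NativeMemory.AlignedAlloc(Size, OzAIMemManager.AlignmentBytes);
        }
    }

    // Memory from AlignedAlloc must be released with AlignedFree.
    protected override void InnerFree()
    {
        unsafe
        {
            NativeMemory.AlignedFree((void*)Addr);
        }
    }
}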