MLX Tutorial: Apple's Machine Learning Framework for Apple Silicon
MLX is an open-source machine learning framework from Apple, designed specifically for Apple Silicon chips (M1, M2, M3, M4). Its API will feel familiar to NumPy and PyTorch users, and its performance is optimized for Apple hardware. MLX has become a popular choice for developers who want to run ML models locally on a Mac without an NVIDIA GPU.
In this tutorial, we will cover installation, basic usage, LLM fine-tuning, model inference, and best practices to maximize MLX performance on Apple Silicon devices.
Why MLX?
Before diving in, let's understand why MLX is worth learning:
Installation
Prerequisites
- macOS 13.5 or later
- Apple Silicon (M1/M2/M3/M4)
- Python 3.9 or later
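The prerequisites above can be checked programmatically before installing anything. The sketch below uses only the Python standard library. The helper name `check_prerequisites` and its messages are illustrative and not part of MLX, and the thresholds are simply the ones listed above.

```python
import platform
import sys


def check_prerequisites(system, machine, macos_version, python_version):
    """Return a list of unmet prerequisites (empty if all are met).

    Hypothetical helper for illustration; not part of MLX.
    """
    problems = []
    if system != "Darwin":
        problems.append("macOS is required")
    else:
        # Compare (major, minor) numerically, e.g. "13.5.2" -> (13, 5).
        parts = [int(p) for p in macos_version.split(".")[:2] if p.isdigit()]
        while len(parts) < 2:
            parts.append(0)
        if tuple(parts) < (13, 5):
            problems.append("macOS 13.5 or later is required")
    if machine != "arm64":
        problems.append("Apple Silicon (arm64) is required")
    if tuple(python_version[:2]) < (3, 9):
        problems.append("Python 3.9 or later is required")
    return problems


if __name__ == "__main__":
    issues = check_prerequisites(
        platform.system(),
        platform.machine(),
        platform.mac_ver()[0],
        sys.version_info,
    )
    print("OK" if not issues else "\n".join(issues))
```

One caveat: a Python interpreter running under Rosetta reports `x86_64` even on Apple Silicon, so a failed architecture check may point to the interpreter build rather than the hardware.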
Installing MLX Core
pip install mlx
Installing MLX-LM (for Language Models)
pip install mlx-lm
Installing MLX-VLM (for Vision-Language Models)
pip install mlx-vlm
Installing from Source (Optional)
git clone https://github.com/ml-explore/mlx.git
cd mlx
pip install -e .
Verify Installation
import mlx.core as mx
import mlx.nn as nn
print(f"MLX version: {mx.__version__}")
print(f"Default device: {mx.default_device()}")
# Simple test
a = mx.array([1, 2, 3, 4, 5])
print(f"Array: {a}")
print(f"Sum: {mx.sum(a)}")
Expected output:
MLX version: 0.x.x
Default device: Device(gpu, 0)
Array: array([1, 2, 3, 4, 5], dtype=int32)
Sum: array(15, dtype=int32)
Basic Usage
Array Operations
MLX arrays are very similar to NumPy but optimized for Apple Silicon:
import mlx.core as mx
# Creating arrays
a = mx.array([1.0, 2.0, 3.0, 4.0])
b = mx.ones((3, 4))
c = mx.zeros((2, 3))
d = mx.random.normal((5, 5))
print(f"Shape of b: {b.shape}")
print(f"Dtype of a: {a.dtype}")
# Mathematical operations
x = mx.array([[1, 2], [3, 4]], dtype=mx.float32)
y = mx.array([[5, 6], [7, 8]], dtype=mx.float32)
# Element-wise operations
print(f"Addition: {x + y}")
print(f"Multiplication: {x * y}")
# Matrix multiplication
print(f"MatMul: {x @ y}")
# Reduction operations
print(f"Sum: {mx.sum(x)}")
print(f"Mean: {mx.mean(x)}")
print(f"Max: {mx.max(x)}")
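Element-wise multiplication and matrix multiplication are easy to confuse, so it may help to see what each computes for the same 2x2 inputs. The sketch below is a plain-Python reference with no MLX involved. The helper names `elementwise_mul` and `matmul` are illustrative only.

```python
# Plain-Python reference for the 2x2 inputs used above.
x = [[1.0, 2.0], [3.0, 4.0]]
y = [[5.0, 6.0], [7.0, 8.0]]


def elementwise_mul(a, b):
    """Multiply matching entries: out[i][j] = a[i][j] * b[i][j]."""
    return [[p * q for p, q in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]


def matmul(a, b):
    """Row-by-column product: out[i][j] = sum over k of a[i][k] * b[k][j]."""
    cols_b = list(zip(*b))
    return [[sum(p * q for p, q in zip(row, col)) for col in cols_b] for row in a]


print(elementwise_mul(x, y))  # [[5.0, 12.0], [21.0, 32.0]]
print(matmul(x, y))           # [[19.0, 22.0], [43.0, 50.0]]
```

The element-wise and matrix-product results from the MLX snippet should contain these same values.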
Lazy Evaluation
One of MLX's unique features is lazy evaluation. Computations are not executed until their results are actually needed:
import mlx.core as mx
a = mx.ones((1000, 1000))
b = mx.ones((1000, 1000))
# These operations are not yet executed
c = a + b
d = c * 2
# Evaluation happens when we need the result
mx.eval(d)
print(d)
# Or when converting to Python/NumPy
result = d.tolist()
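To build intuition for what "not yet executed" means, here is a toy deferred-value class in plain Python. It is a conceptual sketch only: the `Lazy` class is hypothetical and says nothing about how MLX is implemented internally. It just shows the general pattern of recording operations first and computing them on demand.

```python
class Lazy:
    """Toy deferred value: records an operation, runs it only on demand.

    Conceptual illustration only; this is not how MLX is implemented.
    """

    def __init__(self, compute):
        self._compute = compute  # zero-argument function producing the value
        self._value = None
        self._done = False

    @classmethod
    def constant(cls, value):
        return cls(lambda: value)

    def __add__(self, other):
        # Build a new deferred node; nothing is computed here.
        return Lazy(lambda: self.eval() + other.eval())

    def __mul__(self, other):
        return Lazy(lambda: self.eval() * other.eval())

    def eval(self):
        """Compute the value once and cache it."""
        if not self._done:
            self._value = self._compute()
            self._done = True
        return self._value

    @property
    def evaluated(self):
        return self._done


a = Lazy.constant(1.0)
b = Lazy.constant(1.0)
c = a + b                        # only records the addition
d = c * Lazy.constant(2.0)       # only records the multiplication
print(d.evaluated)               # False
print(d.eval())                  # 4.0
print(d.evaluated, c.evaluated)  # True True
```

Asking for `d` forces `c` as well, because `d` depends on it. Anything that `d` does not depend on would stay uncomputed.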
Device Control
import mlx.core as mx
# Check default device
print(f"Default device: {mx.default_device()}")
# Run on CPU
mx.set_default_device(mx.cpu)
a = mx.ones((100, 100))
print(f"Device: {mx.default_device()}")
# Switch back to GPU
mx.set_default_device(mx.gpu)
b = mx.ones((100, 100))
print(f"Device: {mx.default_device()}")
Building Neural Networks with MLX
Simple Model
MLX provides the mlx.nn module, similar to PyTorch: