Skip to content

Instantly share code, notes, and snippets.

@18182324
Created September 29, 2025 12:19
Show Gist options
  • Save 18182324/613f194e7b846731c3ada511d3bf82e6 to your computer and use it in GitHub Desktop.
Save 18182324/613f194e7b846731c3ada511d3bf82e6 to your computer and use it in GitHub Desktop.
Art of Pairs Trading
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from statsmodels.tsa.stattools import coint
import statsmodels.api as sm
np.random.seed(107)
# -----------------------------
# 1) Simulate a cointegrated pair
# -----------------------------
# X is a random walk; Y = alpha + beta*X + AR(1) noise, so Y - beta*X is
# stationary and the two series are cointegrated by construction.
n = 500  # observations
drift = 0.1  # total expected drift of X over all n steps (per step: drift/n)
sigma = 1.0  # total std of X's cumulative move over all n steps (per step: sigma/sqrt(n))
X_ret = np.random.normal(drift / n, sigma / np.sqrt(n), n)
X = 50 + np.cumsum(X_ret)  # random walk base around 50
X = pd.Series(X, name="X")
# Create Y as Y = alpha + beta*X + stationary noise
beta_true = 1.05
alpha_true = 5.0
eps = np.zeros(n)  # eps[0] = 0: the AR(1) noise starts at its mean
rho = 0.8  # AR(1) coefficient; |rho| < 1 keeps the spread stationary
noise = np.random.normal(0, 0.5, n)
for t in range(1, n):
    eps[t] = rho * eps[t - 1] + noise[t]
Y = alpha_true + beta_true * X.values + eps
Y = pd.Series(Y, name="Y")
df = pd.concat([X, Y], axis=1)
# -----------------------------
# 2) Cointegration test
# -----------------------------
# coint(y0, y1) runs the Engle-Granger test on the regression of y0 on y1,
# so pass Y first to test the same Y-on-X relation the hedge ratio uses.
score, pvalue, crit = coint(df["Y"], df["X"])
print(f"Engle-Granger coint p-value: {pvalue:.4f}")
# Hedge ratio via OLS (Y on X); fit once and read both coefficients.
# NOTE(review): fitted on the full sample, so the backtest below uses
# future information (in-sample / look-ahead bias). A rolling or
# train/test-split estimate would be needed for an out-of-sample result.
X_ = sm.add_constant(df["X"])
ols_fit = sm.OLS(df["Y"], X_).fit()
beta_hat = ols_fit.params["X"]
alpha_hat = ols_fit.params["const"]
# -----------------------------
# 3) Spread and z-score
# -----------------------------
# Residual of Y around the fitted hedge line alpha_hat + beta_hat*X.
fitted_Y = alpha_hat + beta_hat * df["X"]
spread = df["Y"] - fitted_Y
lookback = 60
# Standardize against a trailing window that includes the current bar;
# the first lookback-1 bars have no full window and are dropped.
window = spread.rolling(lookback)
z = ((spread - window.mean()) / window.std()).dropna()
# -----------------------------
# 4) Trading rules (classic long-short mean reversion)
# If z > +z_entry: short Y, long X
# If z < -z_entry: long Y, short X
# Exit when |z| < z_exit
# -----------------------------
z_entry = 2.0
z_exit = 0.5
side = 0  # +1 means long Y short X, -1 means short Y long X, 0 means flat
sides = []
for zi in z:
    if side == 0:
        # Flat: open a position only when the z-score breaches a band.
        if zi > z_entry:
            side = -1  # spread rich -> short Y, long X
        elif zi < -z_entry:
            side = +1  # spread cheap -> long Y, short X
    elif abs(zi) < z_exit:
        side = 0  # spread has reverted -> flatten
    sides.append(side)
# pos_Y / pos_X are notional weights per unit of capital; the backtest
# multiplies them by percentage returns. beta_hat is a price-level (share)
# ratio, so holding the spread Y - beta_hat*X means beta_hat shares of X
# per share of Y, i.e. a dollar weight of beta_hat * X/Y on X per dollar of
# Y. Using beta_hat itself as the dollar weight would leave residual X
# exposure. The ratio is recomputed each bar, so the legs stay in share
# ratio beta_hat; the small rebalances show up as turnover.
pos_Y = pd.Series(sides, index=z.index, dtype=float)
hedge_weight = beta_hat * df["X"].loc[z.index] / df["Y"].loc[z.index]
pos_X = -pos_Y * hedge_weight
# Restrict prices to the bars that have a z-score, then take simple
# bar-to-bar returns per leg (the first bar has no prior price -> 0).
px = df.loc[z.index]
bar_returns = px.pct_change().fillna(0.0)
ret_Y = bar_returns["Y"]
ret_X = bar_returns["X"]
# -----------------------------
# 5) Backtest with costs
# -----------------------------
# Transaction cost per turnover (both legs). Use bps on notional.
cost_bps = 1.0  # 1 bp per leg is common for liquid names in sims
# Absolute weight change per bar, summed over both legs. The first bar's
# change is measured from flat, so a position opened on that bar is
# charged (a bare diff() is NaN there and would silently become 0).
turnover = pos_Y.diff().fillna(pos_Y).abs() + pos_X.diff().fillna(pos_X).abs()
# Previous bar's weights earn this bar's returns. The first bar has no
# prior weights, so its gross return is 0. Filling before subtracting
# costs keeps that bar's trading cost instead of discarding it.
gross_ret = (pos_Y.shift(1) * ret_Y + pos_X.shift(1) * ret_X).fillna(0.0)
costs = turnover * (cost_bps / 10000.0)
net_ret = gross_ret - costs
equity = (1 + net_ret).cumprod()
cum_ret = equity.iloc[-1] - 1
ann_factor = 252
ann_ret = (equity.iloc[-1]) ** (ann_factor / len(equity)) - 1
ann_vol = net_ret.std() * np.sqrt(ann_factor)
# Sharpe ratio: annualized mean per-bar return over annualized volatility
# (risk-free rate taken as 0), not compounded return over volatility.
sharpe = 0 if ann_vol == 0 else net_ret.mean() * ann_factor / ann_vol
max_dd = (equity / equity.cummax() - 1).min()
print(f"Cum Return: {cum_ret:.2%}")
print(f"Ann Return: {ann_ret:.2%}")
print(f"Ann Vol: {ann_vol:.2%}")
print(f"Sharpe: {sharpe:.2f}")
print(f"Max Drawdown: {max_dd:.2%}")
# -----------------------------
# 6) Plots
# -----------------------------
plt.figure(figsize=(12, 5))
df.plot(ax=plt.gca(), title="Simulated Cointegrated Pair")
plt.tight_layout()
plt.show()
plt.figure(figsize=(12, 4))
spread.plot(title="Spread")
plt.tight_layout()
plt.show()
plt.figure(figsize=(12, 4))
# Title follows the configured lookback instead of a hard-coded "60-day".
z.plot(title=f"Z-score of Spread ({lookback}-day)")
# Entry bands (dashed), exit bands (dash-dot), and the zero line (dotted),
# matching the thresholds the trading rules actually use.
for level in (z_entry, -z_entry):
    plt.axhline(level, linestyle="--")
for level in (z_exit, -z_exit):
    plt.axhline(level, linestyle="-.")
plt.axhline(0, linestyle=":")
plt.tight_layout()
plt.show()
plt.figure(figsize=(12, 4))
equity.plot(title="Equity Curve (Net of Costs)")
plt.tight_layout()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment