Skip to content

Instantly share code, notes, and snippets.

@cpcloud
Created June 26, 2023 19:18
Show Gist options
  • Save cpcloud/84a4e3eb4df56812db1fc488ff120cb8 to your computer and use it in GitHub Desktop.
Save cpcloud/84a4e3eb4df56812db1fc488ff120cb8 to your computer and use it in GitHub Desktop.
Ibis, DuckDB and PyTorch
import ibis.expr.datatypes as dt
import torch
import torch.nn as nn
import tqdm
import pyarrow as pa
class LinearRegression(nn.Module):
    """A single affine layer mapping ``input_dim`` features to ``output_dim`` outputs."""

    def __init__(self, input_dim, output_dim):
        """Build the wrapped ``nn.Linear`` layer and store it as ``self.linear``."""
        super().__init__()
        layer = nn.Linear(in_features=input_dim, out_features=output_dim)
        self.linear = layer

    def forward(self, distances):
        """Apply the linear layer to ``distances`` and return its output."""
        prediction = self.linear(distances)
        return prediction
class PredictCabFare:
    """Fit and apply a one-feature linear model from trip distance to fare.

    The model is a ``LinearRegression`` with one input and one output
    feature, trained with full-batch SGD on a mean-squared-error loss.
    Instances are callable on a PyArrow array of distances and return a
    PyArrow array of predictions.
    """

    def __init__(self, data: dict[str, torch.Tensor], learning_rate: float = 0.01, epochs: int = 100) -> None:
        """Store the training data and hyperparameters and create the model.

        Args:
            data: Mapping that ``train`` indexes with ``"trip_distance"`` and
                ``"fare_amount"``; each value is reshaped to a column and fed
                to the model, so it must be a tensor.
            learning_rate: Step size passed to ``torch.optim.SGD``.
            epochs: Number of full-batch optimisation steps run by ``train``.
        """
        # One feature in (distance), one value out (fare).
        input_dim = 1
        output_dim = 1

        self.data = data
        self.model = LinearRegression(input_dim, output_dim)
        self.learning_rate = learning_rate
        self.epochs = epochs

    def _model_dtype(self) -> torch.dtype:
        """Return the dtype of the model's parameters."""
        return next(self.model.parameters()).dtype

    def train(self):
        """Fit the model to ``self.data`` in place.

        Runs ``self.epochs`` full-batch SGD steps minimising the MSE between
        the predicted and actual fares.

        Raises:
            KeyError: If ``self.data`` lacks ``"trip_distance"`` or
                ``"fare_amount"``.
        """
        # Cast to the parameter dtype: nn.Linear rejects inputs whose dtype
        # differs from its weights (e.g. float64 columns vs float32 weights).
        dtype = self._model_dtype()
        distances = self.data["trip_distance"].reshape(-1, 1).to(dtype=dtype)
        fares = self.data["fare_amount"].reshape(-1, 1).to(dtype=dtype)

        criterion = nn.MSELoss()
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate)

        for _ in tqdm.trange(self.epochs):
            # Forward pass and loss on the whole data set
            y_pred = self.model(distances)
            loss = criterion(y_pred, fares)

            # Backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    def predict(self, input):
        """Return the model output for ``input`` without tracking gradients.

        Args:
            input: Tensor accepted by the model; it is cast to the model's
                parameter dtype before the forward pass, its shape is left
                untouched.
        """
        with torch.no_grad():
            return self.model(input.to(dtype=self._model_dtype()))

    def __call__(self, input: pa.ChunkedArray):
        """Predict fares for a PyArrow array of distances.

        Args:
            input: One-dimensional PyArrow array of trip distances.

        Returns:
            A PyArrow array holding one prediction per input element.
        """
        # Convert the input to numpy so it can be fed to the model
        #
        # .copy() to avoid the warning about undefined behavior from torch
        distances = torch.from_numpy(input.to_numpy().copy())[:, None]
        predicted = self.predict(distances).ravel()
        return pa.array(predicted.numpy())
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "60be381b-30ae-41ab-81f0-751815e60559",
"metadata": {},
"source": [
"# Using Ibis with Torch"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "73193571-f288-479e-85ba-dd729babe415",
"metadata": {},
"outputs": [],
"source": [
"import ibis\n",
"import ibis.expr.datatypes as dt\n",
"\n",
"from plotnine import aes, ggtitle, ggplot, geom_point, xlab, ylab\n",
"\n",
"from ibis.expr.operations import udf\n",
"from ibis import _, selectors as s\n",
"\n",
"ibis.options.interactive = True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "49bed195-1426-409a-98cb-2c972408e06b",
"metadata": {},
"outputs": [],
"source": [
"def clean_input(path):\n",
" return (\n",
" # load parquet\n",
" ibis.read_parquet(path)\n",
" # compute fare_amount_zscore and trip_distance_zscore\n",
" .mutate(s.across([\"fare_amount\", \"trip_distance\"], dict(zscore=(_ - _.mean()) / _.std())))\n",
" # filter out negative trip distance and bizarre transactions\n",
" .filter([_.trip_distance > 0.0, _.fare_amount >= 0.0])\n",
" # keep values that within 2 standard deviations\n",
" .filter(s.if_all(s.endswith(\"_zscore\"), _.abs() <= 2))\n",
" # drop columns that aren't necessary for further analysis\n",
" .drop(s.endswith(\"_zscore\"))\n",
" # select the columns we care about\n",
" .select(s.across([\"fare_amount\", \"trip_distance\"], _.cast(\"float32\")))\n",
" )\n",
"\n",
"training_data = clean_input(\"/home/cloud/data/trip-data/yellow_tripdata_2016-01.parquet\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f9bcfae7-8de7-40d2-9f52-021ab6373d79",
"metadata": {},
"outputs": [],
"source": [
"training_data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1c1c1e4d-abc6-4e93-b3ba-ed5cee46107f",
"metadata": {},
"outputs": [],
"source": [
"from model import PredictCabFare\n",
"\n",
"# def __init__(data: dict[str, tensor]): ...\n",
"# def train(self): ... # black box\n",
"# def __call__(self, input: pyarrow.ChunkedArray): ...\n",
"\n",
"torch_training_data = training_data.to_torch() # dict[str, Tensor]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "faede14c-0678-4485-9844-92a1c4b8b1e3",
"metadata": {},
"outputs": [],
"source": [
"torch_training_data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11c1795b-d075-4bff-83ae-fa2d5047779c",
"metadata": {},
"outputs": [],
"source": [
"model = PredictCabFare(torch_training_data)\n",
"model.train()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a2282161-bd97-4572-9bff-75b23a74d7a0",
"metadata": {},
"outputs": [],
"source": [
"@udf.scalar.pyarrow\n",
"def predict_fare(distance: dt.float64) -> dt.float32:\n",
" return model(distance)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0c0d528c-fe55-463b-a0ae-b96b4ba7bfae",
"metadata": {},
"outputs": [],
"source": [
"## Visualize the comparison between the predicted cab fares and the actual cab fares\n",
"prediction = clean_input(\n",
" \"/home/cloud/data/trip-data/yellow_tripdata_2016-02.parquet\"\n",
").limit(10_000).mutate(\n",
" predicted_fare=lambda t: predict_fare(t.trip_distance.cast(\"float32\")),\n",
")\n",
"prediction"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9f8d93a4-c13b-43c3-a081-0b4c8a1ed028",
"metadata": {},
"outputs": [],
"source": [
"pivoted_prediction = prediction.pivot_longer(~s.c(\"trip_distance\"), values_to=\"fare\", names_to=\"metric\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eba15ec7-7a17-4b97-a2c3-22353eb20441",
"metadata": {},
"outputs": [],
"source": [
"pivoted_prediction"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f0737502-be3b-4aac-b831-58e58f17a679",
"metadata": {},
"outputs": [],
"source": [
"(\n",
" ggplot(pivoted_prediction, aes(x=\"trip_distance\", y=\"fare\", color=\"metric\"))\n",
" + geom_point()\n",
" + xlab(\"Trip Distance\")\n",
" + ylab(\"Fare\")\n",
" + ggtitle(\"Predicted Fare vs Actual Fare by Trip Distance\")\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9bce2011-61ff-45d1-b7b4-279c6ccc84ed",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment