Created
June 26, 2023 19:18
-
-
Save cpcloud/84a4e3eb4df56812db1fc488ff120cb8 to your computer and use it in GitHub Desktop.
Ibis, DuckDB and PyTorch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ibis.expr.datatypes as dt | |
import torch | |
import torch.nn as nn | |
import tqdm | |
import pyarrow as pa | |
class LinearRegression(nn.Module):
    """Linear regression expressed as a single ``nn.Linear`` layer."""

    def __init__(self, input_dim: int, output_dim: int) -> None:
        """Create the underlying linear layer.

        Args:
            input_dim: Number of input features per sample.
            output_dim: Number of predicted values per sample.
        """
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, distances: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``distances`` and return the result."""
        return self.linear(distances)
class PredictCabFare:
    """Fit a one-feature linear model of cab fare against trip distance.

    ``data`` is indexed with the keys ``"trip_distance"`` and
    ``"fare_amount"``. Each value must support ``reshape(-1, 1)`` and
    ``.to(dtype=...)`` (presumably a ``dict[str, torch.Tensor]`` -- confirm
    against the caller).
    """

    def __init__(self, data, learning_rate: float = 0.01, epochs: int = 100) -> None:
        """Store the training data and hyperparameters and build the model.

        Args:
            data: Mapping holding the ``"trip_distance"`` and
                ``"fare_amount"`` columns used by :meth:`train`.
            learning_rate: Step size handed to the SGD optimizer.
            epochs: Number of optimisation steps :meth:`train` performs.
        """
        # One feature in (distance), one target out (fare).
        input_dim = 1
        output_dim = 1

        self.data = data
        self.model = LinearRegression(input_dim, output_dim)
        self.learning_rate = learning_rate
        self.epochs = epochs

    def _model_dtype(self):
        """Return the dtype of the model's parameters."""
        return next(self.model.parameters()).dtype

    def train(self):
        """Run ``self.epochs`` full-batch SGD steps minimising the MSE loss."""
        # Cast to the parameter dtype so that e.g. float64 columns do not
        # trigger a dtype-mismatch error inside the linear layer.
        dtype = self._model_dtype()
        distances = self.data["trip_distance"].reshape(-1, 1).to(dtype=dtype)
        fares = self.data["fare_amount"].reshape(-1, 1).to(dtype=dtype)

        criterion = nn.MSELoss()
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate)

        for _ in tqdm.trange(self.epochs):  # noqa: F402
            # Forward pass and loss on the whole data set.
            y_pred = self.model(distances)
            loss = criterion(y_pred, fares)

            # Clear stale gradients, backpropagate, then update parameters.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    def predict(self, input):
        """Return the model output for ``input`` without tracking gradients.

        ``input`` is cast to the dtype of the model parameters before the
        forward pass, so tensors of another dtype (such as float64) are
        accepted.
        """
        with torch.no_grad():
            return self.model(input.to(dtype=self._model_dtype()))

    def __call__(self, input: "pa.ChunkedArray"):
        """Predict fares for a pyarrow array of distances.

        Returns:
            The flattened predictions wrapped with ``pa.array``.
        """
        # Convert the input to numpy so it can be fed to the model.
        #
        # .copy() to avoid the warning about undefined behavior from torch
        features = torch.from_numpy(input.to_numpy().copy())[:, None]
        predicted = self.predict(features).ravel()
        return pa.array(predicted.numpy())
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "60be381b-30ae-41ab-81f0-751815e60559", | |
"metadata": {}, | |
"source": [ | |
"# Using Ibis with Torch" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "73193571-f288-479e-85ba-dd729babe415", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import ibis\n", | |
"import ibis.expr.datatypes as dt\n", | |
"\n", | |
"from plotnine import aes, ggtitle, ggplot, geom_point, xlab, ylab\n", | |
"\n", | |
"from ibis.expr.operations import udf\n", | |
"from ibis import _, selectors as s\n", | |
"\n", | |
"ibis.options.interactive = True" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "49bed195-1426-409a-98cb-2c972408e06b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def clean_input(path):\n", | |
" return (\n", | |
" # load parquet\n", | |
" ibis.read_parquet(path)\n", | |
" # compute fare_amount_zscore and trip_distance_zscore\n", | |
" .mutate(s.across([\"fare_amount\", \"trip_distance\"], dict(zscore=(_ - _.mean()) / _.std())))\n", | |
" # filter out non-positive trip distances and negative fares\n", | |
" .filter([_.trip_distance > 0.0, _.fare_amount >= 0.0])\n", | |
" # keep values that are within 2 standard deviations\n", | |
" .filter(s.if_all(s.endswith(\"_zscore\"), _.abs() <= 2))\n", | |
" # drop columns that aren't necessary for further analysis\n", | |
" .drop(s.endswith(\"_zscore\"))\n", | |
" # select the columns we care about\n", | |
" .select(s.across([\"fare_amount\", \"trip_distance\"], _.cast(\"float32\")))\n", | |
" )\n", | |
"\n", | |
"training_data = clean_input(\"/home/cloud/data/trip-data/yellow_tripdata_2016-01.parquet\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "f9bcfae7-8de7-40d2-9f52-021ab6373d79", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"training_data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "1c1c1e4d-abc6-4e93-b3ba-ed5cee46107f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from model import PredictCabFare\n", | |
"\n", | |
"# def __init__(data: dict[str, tensor]): ...\n", | |
"# def train(self): ... # black box\n", | |
"# def __call__(self, input: pyarrow.ChunkedArray): ...\n", | |
"\n", | |
"torch_training_data = training_data.to_torch() # dict[str, Tensor]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "faede14c-0678-4485-9844-92a1c4b8b1e3", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"torch_training_data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "11c1795b-d075-4bff-83ae-fa2d5047779c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model = PredictCabFare(torch_training_data)\n", | |
"model.train()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "a2282161-bd97-4572-9bff-75b23a74d7a0", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@udf.scalar.pyarrow\n", | |
"def predict_fare(distance: dt.float64) -> dt.float32:\n", | |
" return model(distance)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "0c0d528c-fe55-463b-a0ae-b96b4ba7bfae", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"## Visualize the comparison between the predicted cab fares and the actual cab fares\n", | |
"prediction = clean_input(\n", | |
" \"/home/cloud/data/trip-data/yellow_tripdata_2016-02.parquet\"\n", | |
").limit(10_000).mutate(\n", | |
" predicted_fare=lambda t: predict_fare(t.trip_distance.cast(\"float32\")),\n", | |
")\n", | |
"prediction" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "9f8d93a4-c13b-43c3-a081-0b4c8a1ed028", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"pivoted_prediction = prediction.pivot_longer(~s.c(\"trip_distance\"), values_to=\"fare\", names_to=\"metric\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "eba15ec7-7a17-4b97-a2c3-22353eb20441", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"pivoted_prediction" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "f0737502-be3b-4aac-b831-58e58f17a679", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"(\n", | |
" ggplot(pivoted_prediction, aes(x=\"trip_distance\", y=\"fare\", color=\"metric\"))\n", | |
" + geom_point()\n", | |
" + xlab(\"Trip Distance\")\n", | |
" + ylab(\"Fare\")\n", | |
" + ggtitle(\"Predicted Fare vs Actual Fare by Trip Distance\")\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "9bce2011-61ff-45d1-b7b4-279c6ccc84ed", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.12" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment