Created June 26, 2023 19:18
-
-
Save cpcloud/84a4e3eb4df56812db1fc488ff120cb8 to your computer and use it in GitHub Desktop.
Ibis, DuckDB and PyTorch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ibis.expr.datatypes as dt | |
import torch | |
import torch.nn as nn | |
import tqdm | |
import pyarrow as pa | |
class LinearRegression(nn.Module):
    """A single affine layer (``y = x @ W.T + b``) packaged as a module.

    Args:
        input_dim: Number of features in each input row.
        output_dim: Number of values produced for each input row.
    """

    def __init__(self, input_dim, output_dim):
        super(LinearRegression, self).__init__()
        # The submodule attribute name determines the state_dict keys
        # (``linear.weight`` / ``linear.bias``), so it must stay ``linear``.
        self.linear = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, distances):
        """Apply the affine layer to *distances* and return the result."""
        prediction = self.linear(distances)
        return prediction
class PredictCabFare:
    """Fit and apply a one-feature linear model that maps trip distance to fare.

    Instances are callable on a ``pyarrow.ChunkedArray`` of distances and
    return a ``pyarrow.Array`` of predicted fares.
    """

    def __init__(self, data, learning_rate: float = 0.01, epochs: int = 100) -> None:
        """Store the training data and hyperparameters and build an untrained model.

        Args:
            data: Mapping indexed by ``"trip_distance"`` and ``"fare_amount"``,
                read by :meth:`train`. Each column must be accepted by
                ``torch.as_tensor`` (presumably tensors or NumPy arrays --
                confirm against the caller).
            learning_rate: Step size for the SGD optimizer.
            epochs: Number of full-batch optimizer steps taken by :meth:`train`.
        """
        # One input feature (distance) maps to one output value (fare).
        input_dim = 1
        output_dim = 1
        self.data = data
        # Create a linear regression model instance.
        self.model = LinearRegression(input_dim, output_dim)
        self.learning_rate = learning_rate
        self.epochs = epochs

    def _to_model_tensor(self, values):
        """Convert *values* to a tensor with the model's parameter dtype and device."""
        # nn.Linear keeps float32 weights by default; float64 data (e.g. Arrow
        # doubles) would otherwise fail with a dtype-mismatch RuntimeError.
        param = next(self.model.parameters())
        return torch.as_tensor(values, dtype=param.dtype, device=param.device)

    def train(self):
        """Fit ``self.model`` to ``self.data`` with full-batch SGD on mean-squared error.

        Takes ``self.epochs`` optimizer steps, each over the entire dataset,
        and shows a ``tqdm`` progress bar.
        """
        # Both columns become (n, 1) column vectors: one sample per row.
        distances = self._to_model_tensor(self.data["trip_distance"]).reshape(-1, 1)
        fares = self._to_model_tensor(self.data["fare_amount"]).reshape(-1, 1)

        criterion = nn.MSELoss()
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate)

        for _ in tqdm.trange(self.epochs):
            y_pred = self.model(distances)
            loss = criterion(y_pred, fares)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    def predict(self, input):
        """Run the model on *input* without tracking gradients.

        *input* is first converted to the model's dtype and device, so float64
        data is accepted by the model's float32 weights.
        """
        with torch.no_grad():
            return self.model(self._to_model_tensor(input))

    def __call__(self, input: "pa.ChunkedArray"):
        """Predict fares for a ``pyarrow.ChunkedArray`` of distances.

        Returns a flat ``pyarrow.Array`` with one prediction per input element.
        """
        # .copy(): to_numpy() may return a read-only view of Arrow memory, and
        # torch.from_numpy warns that writing through such a tensor is undefined.
        distances = torch.from_numpy(input.to_numpy().copy())[:, None]
        predicted = self.predict(distances).ravel()
        # .cpu() is a no-op on CPU and lets .numpy() work if the model is on a GPU.
        return pa.array(predicted.cpu().numpy())
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment