Skip to content

Instantly share code, notes, and snippets.

@aziflaj
Created November 12, 2024 07:46
Show Gist options
  • Save aziflaj/bc1b8e7aedd148f5fbb967cb7cdf8a7a to your computer and use it in GitHub Desktop.
Save aziflaj/bc1b8e7aedd148f5fbb967cb7cdf8a7a to your computer and use it in GitHub Desktop.
package order
import (
	"database/sql"
	"fmt"
	"time"
)
// Order mirrors one row of the orders table as written by
// OrderService.CreateOrder and read back by the integration test.
type Order struct {
	ID, UserID int64     // order primary key and the id of the user it belongs to
	Amount     float64   // order total as passed to CreateOrder
	Status     string    // order state; CreateOrder always writes "pending"
	CreatedAt  time.Time // value stored in the orders.created_at column
}
// EmailService delivers order confirmations on behalf of OrderService.
// A non-nil error means the confirmation could not be sent; CreateOrder
// returns that error without committing the order.
type EmailService interface {
	// SendOrderConfirmation confirms order id, for the given total, to the
	// user reachable at address.
	SendOrderConfirmation(address string, id int64, total float64) error
}
// OrderService creates orders on top of a SQL database and uses an
// EmailService to confirm each one. Build it with NewOrderService.
type OrderService struct {
	db    *sql.DB      // pool every CreateOrder transaction is started from
	email EmailService // receives one confirmation request per CreateOrder call that reaches it
}

// NewOrderService returns an *OrderService that keeps the given database
// handle and email sender as-is; neither argument is validated or copied.
func NewOrderService(db *sql.DB, email EmailService) *OrderService {
	svc := new(OrderService)
	svc.db, svc.email = db, email
	return svc
}
// CreateOrder records a new "pending" order for userID and sends the user an
// order confirmation, all within one database transaction.
//
// The confirmation is sent before the transaction commits, so an email
// failure rolls the order back. The trade-off is that if Commit itself fails
// after a successful send, the user has been told about an order that was
// never persisted. Callers that cannot accept that should send after commit
// (for example via an outbox table).
//
// On success every field of the returned Order is populated. CreatedAt is the
// timestamp handed to the INSERT; the database may store it at coarser
// precision. On failure the returned *Order is nil and the error wraps the
// underlying database or EmailService error. errors.Is and errors.As still
// match it; for example, errors.Is(err, sql.ErrNoRows) is true for an
// unknown user.
func (s *OrderService) CreateOrder(userID int64, amount float64) (*Order, error) {
	tx, err := s.db.Begin()
	if err != nil {
		return nil, fmt.Errorf("begin order transaction: %w", err)
	}
	// Rollback after a successful Commit is a harmless no-op (it returns
	// sql.ErrTxDone), so deferring it undoes the insert on every early return.
	defer tx.Rollback()

	order := Order{
		UserID:    userID,
		Amount:    amount,
		Status:    "pending",
		CreatedAt: time.Now(),
	}

	// MySQL has no INSERT ... RETURNING clause. Read the generated key through
	// LastInsertId instead, which the '?'-placeholder drivers this query is
	// written for (MySQL/MariaDB, SQLite) implement.
	res, err := tx.Exec(
		`INSERT INTO orders (user_id, amount, status, created_at) VALUES (?, ?, ?, ?)`,
		order.UserID, order.Amount, order.Status, order.CreatedAt,
	)
	if err != nil {
		return nil, fmt.Errorf("insert order for user %d: %w", userID, err)
	}
	if order.ID, err = res.LastInsertId(); err != nil {
		return nil, fmt.Errorf("read id of new order: %w", err)
	}

	var email string
	err = tx.QueryRow("SELECT email FROM users WHERE id = ?", userID).Scan(&email)
	if err != nil {
		return nil, fmt.Errorf("look up email for user %d: %w", userID, err)
	}

	if err = s.email.SendOrderConfirmation(email, order.ID, amount); err != nil {
		return nil, fmt.Errorf("send confirmation for order %d: %w", order.ID, err)
	}

	if err = tx.Commit(); err != nil {
		return nil, fmt.Errorf("commit order %d: %w", order.ID, err)
	}
	return &order, nil
}
package order
import (
"database/sql"
"testing"
"time"
_ "github.com/go-sql-driver/mysql"
)
// confirmationRecord captures the arguments of one SendOrderConfirmation call
// made against a TestEmailService.
type confirmationRecord struct {
	email   string
	orderID int64
	amount  float64
}

// TestEmailService is an in-memory EmailService for tests. Instead of
// delivering anything, it appends each call's arguments to sentEmails and
// always reports success.
type TestEmailService struct {
	sentEmails []confirmationRecord
}

// SendOrderConfirmation records the call in sentEmails and returns nil.
func (svc *TestEmailService) SendOrderConfirmation(email string, orderID int64, amount float64) error {
	rec := confirmationRecord{email: email, orderID: orderID, amount: amount}
	svc.sentEmails = append(svc.sentEmails, rec)
	return nil
}
// setupTestDB opens the MySQL integration-test database and empties the
// orders and users tables so each test starts from a clean slate. It fails
// the test immediately if the server is unreachable or a table cannot be
// truncated.
//
// It registers a cleanup that closes the handle when the test ends.
// (*sql.DB).Close is idempotent, so callers may still defer db.Close()
// themselves.
func setupTestDB(t *testing.T) *sql.DB {
	t.Helper()

	// parseTime=true makes go-sql-driver/mysql return DATE/DATETIME columns
	// as time.Time. Without it they arrive as []byte, and scanning
	// created_at into a time.Time fails.
	db, err := sql.Open("mysql", "test_user:test_pass@/test_db?parseTime=true")
	if err != nil {
		t.Fatalf("failed to connect to test database: %v", err)
	}
	t.Cleanup(func() { _ = db.Close() })

	// sql.Open only validates its arguments. Ping actually dials the server,
	// so an unreachable database is reported here rather than at the first query.
	if err := db.Ping(); err != nil {
		t.Fatalf("failed to connect to test database: %v", err)
	}

	// Same tables, same order and same failure messages as before.
	for _, table := range []string{"orders", "users"} {
		if _, err := db.Exec("TRUNCATE TABLE " + table); err != nil {
			t.Fatalf("failed to truncate %s table: %v", table, err)
		}
	}
	return db
}
// simulatedEmailError is the error returned by failingEmailService.
type simulatedEmailError struct{}

func (simulatedEmailError) Error() string { return "simulated email delivery failure" }

// failingEmailService is an EmailService whose sends always fail. It counts
// attempts so a test can confirm that a CreateOrder failure came from the
// email step rather than from an earlier database error.
type failingEmailService struct {
	attempts int
}

// SendOrderConfirmation records the attempt and always returns simulatedEmailError.
func (f *failingEmailService) SendOrderConfirmation(string, int64, float64) error {
	f.attempts++
	return simulatedEmailError{}
}

// TestOrderService_Integration runs OrderService.CreateOrder against the test
// database returned by setupTestDB. It checks two things: a successful order
// is persisted and confirmed by email exactly once, and an email failure
// leaves the orders table unchanged.
func TestOrderService_Integration(t *testing.T) {
	// Setup
	db := setupTestDB(t)
	defer db.Close()
	emailService := &TestEmailService{}
	orderService := NewOrderService(db, emailService)

	// Insert test user
	userID := int64(1)
	userEmail := "[email protected]"
	_, err := db.Exec(
		"INSERT INTO users (id, email) VALUES (?, ?)",
		userID, userEmail,
	)
	if err != nil {
		t.Fatalf("failed to insert test user: %v", err)
	}

	t.Run("successful order creation", func(t *testing.T) {
		amount := 99.99
		order, err := orderService.CreateOrder(userID, amount)
		if err != nil {
			t.Fatalf("failed to create order: %v", err)
		}

		// Verify order in database
		var dbOrder Order
		err = db.QueryRow(
			`SELECT id, user_id, amount, status, created_at
FROM orders WHERE id = ?`,
			order.ID,
		).Scan(
			&dbOrder.ID,
			&dbOrder.UserID,
			&dbOrder.Amount,
			&dbOrder.Status,
			&dbOrder.CreatedAt,
		)
		if err != nil {
			t.Fatalf("failed to fetch order from db: %v", err)
		}

		// Verify order details
		if dbOrder.UserID != userID {
			t.Errorf("got user_id %d, want %d", dbOrder.UserID, userID)
		}
		if dbOrder.Amount != amount {
			t.Errorf("got amount %f, want %f", dbOrder.Amount, amount)
		}
		if dbOrder.Status != "pending" {
			t.Errorf("got status %s, want pending", dbOrder.Status)
		}
		if time.Since(dbOrder.CreatedAt) > time.Minute {
			t.Error("created_at timestamp is too old")
		}

		// Verify email was sent
		if len(emailService.sentEmails) != 1 {
			t.Fatalf("got %d emails sent, want 1", len(emailService.sentEmails))
		}
		sentEmail := emailService.sentEmails[0]
		if sentEmail.email != userEmail {
			t.Errorf("email sent to %s, want %s", sentEmail.email, userEmail)
		}
		if sentEmail.orderID != order.ID {
			t.Errorf("email for order %d, want %d", sentEmail.orderID, order.ID)
		}
		if sentEmail.amount != amount {
			t.Errorf("email amount %f, want %f", sentEmail.amount, amount)
		}
	})

	t.Run("rollback on email failure", func(t *testing.T) {
		// An email service that actually fails. TestEmailService always
		// succeeds, so it cannot exercise the rollback path.
		failingEmail := &failingEmailService{}
		failingOrderService := NewOrderService(db, failingEmail)

		// Count initial orders
		var initialCount int
		err := db.QueryRow("SELECT COUNT(*) FROM orders").Scan(&initialCount)
		if err != nil {
			t.Fatalf("failed to count orders: %v", err)
		}

		// Attempt to create order
		_, err = failingOrderService.CreateOrder(userID, 59.99)
		if err == nil {
			t.Fatal("expected error, got nil")
		}
		// Without this check, a database error before the email step would
		// make the subtest pass for the wrong reason.
		if failingEmail.attempts != 1 {
			t.Fatalf("email service called %d times, want 1 (CreateOrder failed earlier: %v)",
				failingEmail.attempts, err)
		}

		// Verify no order was created (transaction rolled back)
		var finalCount int
		err = db.QueryRow("SELECT COUNT(*) FROM orders").Scan(&finalCount)
		if err != nil {
			t.Fatalf("failed to count orders: %v", err)
		}
		if finalCount != initialCount {
			t.Errorf("order count changed from %d to %d, expected no change",
				initialCount, finalCount)
		}
	})
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment