Skip to content

Instantly share code, notes, and snippets.

@CVxTz
Last active August 18, 2024 11:05
Show Gist options
  • Save CVxTz/8eace07d9bd2c5123a89bf790b5cc39e to your computer and use it in GitHub Desktop.
Save CVxTz/8eace07d9bd2c5123a89bf790b5cc39e to your computer and use it in GitHub Desktop.
JSON Mode using Gemini Flash
from dotenv import load_dotenv
import os
load_dotenv()
import google.generativeai as genai
from data_model import (
TouristLocation,
ClimateType,
ActivityType,
Attraction,
LocationType,
)
from google.generativeai.types.content_types import ContentDict
from pydantic import BaseModel, Field
from typing import List, Optional, Tuple
from enum import Enum
class ClimateType(str, Enum):
    """Type of climate at the location"""
    # NOTE: the docstring above is runtime data, not only documentation:
    # pydantic copies it into the JSON schema "description" that is sent to
    # the model, so edit it with that in mind. Member declaration order is
    # the order of the schema's "enum" list.

    tropical = "tropical"
    desert = "desert"
    temperate = "temperate"
    continental = "continental"
    polar = "polar"
    # Catch-all member returned by _missing_ below.
    unknown = "unknown"

    @classmethod
    def _missing_(cls, value):
        # Enum hook invoked when lookup by value fails (e.g. ClimateType("arid")):
        # map any unrecognised value to `unknown` instead of raising ValueError.
        return cls.unknown
class ActivityType(str, Enum):
    """Type of activities and attractions available at the location"""
    # NOTE: the docstring above is runtime data: pydantic copies it into the
    # JSON schema "description" sent to the model. Member declaration order is
    # the order of the schema's "enum" list, so reordering members changes the
    # generated schema. Values are human-readable phrases (with spaces) because
    # they are what the model is asked to emit.

    # Social and Nightlife
    partying = "partying"
    clubbing = "clubbing"
    bars_and_pubs = "bars and pubs"
    live_music = "live music"
    comedy_clubs = "comedy clubs"
    # Nature and Outdoors
    nature = "nature"
    hiking = "hiking"
    camping = "camping"
    birdwatching = "birdwatching"
    stargazing = "stargazing"
    outdoor_activities = "outdoor activities"
    scenic_drives = "scenic drives"
    wildlife_viewing = "wildlife viewing"
    # Water and Aquatic
    aquatic_activities = "aquatic activities"
    snorkeling = "snorkeling"
    scuba_diving = "scuba diving"
    surfing = "surfing"
    kayaking = "kayaking"
    paddleboarding = "paddleboarding"
    boat_tours = "boat tours"
    fishing = "fishing"
    # Family and Children
    children_activities = "children activities"
    amusement_parks = "amusement parks"
    water_parks = "water parks"
    zoos_and_aquariums = "zoos and aquariums"
    museums_for_kids = "museums for kids"
    playgrounds = "playgrounds"
    # Sports and Fitness
    sports = "sports"
    running = "running"
    cycling = "cycling"
    swimming = "swimming"
    golfing = "golfing"
    skiing = "skiing"
    snowboarding = "snowboarding"
    team_sports = "team sports"
    # Wellness and Relaxation
    wellness = "wellness"
    spas = "spas"
    yoga = "yoga"
    meditation = "meditation"
    saunas_and_hot_tubs = "saunas and hot tubs"
    massage = "massage"
    # Culture and History
    cultural = "cultural"
    historical = "historical"
    museums = "museums"
    art_galleries = "art galleries"
    landmarks = "landmarks"
    historical_sites = "historical sites"
    cultural_events = "cultural events"
    # Food and Drink
    food_and_wine = "food and wine"
    wine_tasting = "wine tasting"
    brewery_tours = "brewery tours"
    coffee_shops = "coffee shops"
    restaurants = "restaurants"
    cooking_classes = "cooking classes"
    food_tours = "food tours"
    # Shopping and Retail
    shopping = "shopping"
    malls = "malls"
    markets = "markets"
    boutiques = "boutiques"
    souvenir_shops = "souvenir shops"
    antique_shops = "antique shops"
    # Entertainment and Performance
    entertainment = "entertainment"
    movies = "movies"
    theater = "theater"
    music_venues = "music venues"
    comedy_shows = "comedy shows"
    magic_shows = "magic shows"
    # Education and Learning
    educational = "educational"
    workshops = "workshops"
    conferences = "conferences"
    seminars = "seminars"
    classes = "classes"
    lectures = "lectures"
    # Spirituality and Faith
    spiritual = "spiritual"
    churches = "churches"
    temples = "temples"
    mosques = "mosques"
    synagogues = "synagogues"
    meditation_retreats = "meditation retreats"
    # Games and Recreation
    games = "games"
    escape_rooms = "escape rooms"
    game_centers = "game centers"
    bowling = "bowling"
    laser_tag = "laser tag"
    mini_golf = "mini golf"
    # Adventure and Thrills
    adventure = "adventure"
    skydiving = "skydiving"
    bungee_jumping = "bungee jumping"
    rock_climbing = "rock climbing"
    zip_lining = "zip lining"
    white_water_rafting = "white water rafting"
    # Catch-all member returned by _missing_ below.
    unknown = "unknown"

    @classmethod
    def _missing_(cls, value):
        # Enum hook invoked when lookup by value fails: map any unrecognised
        # activity string to `unknown` instead of raising ValueError.
        return cls.unknown
class LocationType(str, Enum):
    """Type of location (city, country, establishment, etc.)"""
    # NOTE: the docstring above is runtime data: pydantic copies it into the
    # JSON schema "description" sent to the model. Member declaration order is
    # the order of the schema's "enum" list.

    city = "city"
    country = "country"
    establishment = "establishment"
    landmark = "landmark"
    national_park = "national park"
    island = "island"
    region = "region"
    continent = "continent"
    restaurant = "restaurant"
    stadium = "stadium"
    museum = "museum"
    theater = "theater"
    hotel = "hotel"
    airport = "airport"
    train_station = "train station"
    bus_station = "bus station"
    port = "port"
    shopping_mall = "shopping mall"
    market = "market"
    # Catch-all member returned by _missing_ below.
    unknown = "unknown"

    @classmethod
    def _missing_(cls, value):
        # Enum hook invoked when lookup by value fails: map any unrecognised
        # location type to `unknown` instead of raising ValueError.
        return cls.unknown
class Attraction(BaseModel):
    """Model for an attraction"""

    # A named point of interest inside a TouristLocation. Both fields are
    # required: Field() is given no default. The class docstring and the field
    # descriptions end up in the JSON schema shown to the model, so their
    # wording is model-facing.
    name: str = Field(description="Name of the attraction")
    description: str = Field(description="Description of the attraction")
class TouristLocation(BaseModel):
    """Model for a tourist location"""

    # The class docstring and every Field description below are runtime data:
    # pydantic copies them into model_json_schema(), which is sent to the model.
    # Field declaration order is the property order in that schema.

    # use_enum_values: validated enum fields hold the member's `.value` (a
    # plain str) rather than the enum member itself. Configured through the
    # pydantic v2 `model_config` mapping. The nested `class Config` form is
    # deprecated in v2, and this file already relies on v2-only APIs
    # (model_json_schema / model_dump_json).
    model_config = {"use_enum_values": True}

    name: str = Field(..., description="Name of the location")
    # Optional: defaults to None when coordinates are unknown.
    location_long_lat: Optional[Tuple[float, float]] = Field(
        None,
        description="Geographic coordinates (longitude, latitude) of the location, null if unknown",
    )
    climate_type: ClimateType = Field(
        ..., description="Type of climate at the location"
    )
    # The `[]` defaults are kept rather than default_factory so the schema
    # still advertises "default": []. pydantic copies mutable defaults per
    # instance, so they are not shared between models.
    high_season_months: List[int] = Field(
        [], description="List of months (1-12) when the location is most visited"
    )
    activity_types: List[ActivityType] = Field(
        ...,
        description="List of activity types and attractions available at the location",
    )
    attraction_list: List[Attraction] = Field(
        [], description="List of specific attractions at the location"
    )
    # min_length=1: at least one tag is required.
    tags: List[str] = Field(
        ...,
        description="List of tags describing the location (e.g. accessible, sustainable, sunny, cheap, pricey)",
        min_length=1,
    )
    description: str = Field(..., description="Text description of the location")
    most_notably_known_for: str = Field(
        ..., description="What the location is most notably known for"
    )
    location_type: LocationType = Field(
        ..., description="Type of location (city, country, establishment, etc.)"
    )
    parents: List[str] = Field(
        [], description="List of parent locations (country, continent, city, etc.)"
    )
class TouristLocations(BaseModel):
    """Model for a list of tourist locations"""

    # Container model holding several TouristLocation entries (not used by the
    # visible __main__ script). The docstring is emitted as the schema
    # "description", so its wording is model-facing.
    # NOTE(review): the field name is singular although it holds a list; it is
    # kept unchanged for compatibility with existing callers and outputs.
    tourist_location: List[TouristLocation] = Field(
        ..., description="List of Locations"
    )
def replace_value_in_dict(item, original_schema):
    """Return a copy of ``item`` with every local JSON-schema ``$ref`` inlined.

    Adapted from https://github.com/pydantic/pydantic/issues/889.

    Args:
        item: A (sub-)schema — dict, list or scalar — to process.
        original_schema: Root schema against which local ``$ref`` paths such
            as ``"#/$defs/Attraction"`` are resolved.

    Returns:
        A new structure in which each mapping carrying a string ``$ref`` is
        replaced by the referenced definition. That definition is itself fully
        resolved, so nested references do not dangle once ``$defs`` is
        removed. Keys written next to ``$ref`` (e.g. a field ``description``)
        are kept and take precedence over the definition's keys. Inputs are
        never mutated and the result shares no dicts or lists with them.
        Scalars are returned as-is.

        A reference that (directly or indirectly) points back to a definition
        currently being expanded cannot be inlined finitely. It is left as a
        ``{"$ref": ...}`` mapping, best-effort, instead of recursing forever.

    Raises:
        KeyError: If a ``$ref`` path does not exist in ``original_schema``.
    """

    def _resolve(node, expanding):
        # `expanding` holds the refs on the current expansion path (cycle guard).
        if isinstance(node, list):
            return [_resolve(entry, expanding) for entry in node]
        if not isinstance(node, dict):
            return node
        ref = node.get("$ref")
        if not isinstance(ref, str) or ref in expanding:
            # Ordinary mapping, or a recursive ref that must stay unresolved.
            return {key: _resolve(value, expanding) for key, value in node.items()}
        # Assumes a local "#/..." pointer: "#/$defs/Name" -> ["$defs", "Name"].
        target = original_schema
        for part in ref[2:].split("/"):
            target = target[part]
        resolved = _resolve(target, expanding | {ref})
        siblings = {
            key: _resolve(value, expanding)
            for key, value in node.items()
            if key != "$ref"
        }
        if siblings and isinstance(resolved, dict):
            return {**resolved, **siblings}
        return resolved

    return _resolve(item, frozenset())
def delete_keys_recursive(d, key_to_delete):
    """Remove every occurrence of ``key_to_delete`` from a nested structure, in place.

    Walks dicts and lists depth-first. In each dict the key is dropped first
    (if present), then the remaining values are visited. Key names are matched
    at any level, including property names inside a schema's "properties".
    Tuples, strings and other objects are not traversed. Returns None.
    """
    if isinstance(d, dict):
        # Drop the key before descending so its value is never visited.
        d.pop(key_to_delete, None)
        children = d.values()
    elif isinstance(d, list):
        children = d
    else:
        return
    for child in children:
        delete_keys_recursive(child, key_to_delete)
# One-shot example answer: the __main__ prompt shows this object to the model
# as the expected reply for Hawaii. Every TouristLocation field is filled in.
location = TouristLocation(
    name="Hawaii",
    location_type=LocationType.island,
    parents=["USA", "North America"],
    description="A tropical paradise with stunning beaches and lush scenery",
    most_notably_known_for="Its beautiful beaches and active volcanoes",
    # (longitude, latitude), matching the field's documented order.
    location_long_lat=(-155.0, 19.9),
    climate_type=ClimateType.tropical,
    high_season_months=[6, 7, 8, 12],
    activity_types=[ActivityType.surfing, ActivityType.snorkeling, ActivityType.hiking],
    attraction_list=[
        Attraction(name=attraction_name, description=attraction_blurb)
        for attraction_name, attraction_blurb in (
            ("Waikiki Beach", "One of the most iconic and popular beaches in the world"),
            ("Haleakala National Park", "A dormant volcano with stunning sunrises and sunsets"),
        )
    ],
    tags=["beach", "tropical", "paradise", "expensive"],
)
# Debug aid: dumps the raw pydantic schema. This runs at module level, i.e.
# on import as well as under __main__.
print(TouristLocation.model_json_schema())
if __name__ == "__main__":
    # Build a self-contained response schema for Gemini: inline every $ref,
    # then strip keys this script does not want to send.
    schema = TouristLocation.model_json_schema()
    # replace_value_in_dict does not mutate its inputs, so the former
    # (shallow, hence ineffective) .copy() calls were unnecessary.
    schema = replace_value_in_dict(schema, schema)
    # Definitions are inlined now. pop() rather than del, so a model without
    # nested types (and therefore without "$defs") does not raise KeyError.
    schema.pop("$defs", None)
    delete_keys_recursive(schema, key_to_delete="title")
    # Removes the whole location_long_lat property from the response schema.
    delete_keys_recursive(schema, key_to_delete="location_long_lat")
    delete_keys_recursive(schema, key_to_delete="default")
    delete_keys_recursive(schema, key_to_delete="minItems")
    print(schema)

    # NOTE(review): the prompt embeds the full, unprocessed schema (including
    # location_long_lat), and the example answer also contains that field,
    # while response_schema above has it removed. Confirm this is intended.
    messages = [
        ContentDict(
            role="user",
            parts=[
                # Adjacent literals are concatenated implicitly; the trailing
                # space keeps "JSON." and "Follow" from running together.
                "You are a helpful assistant that outputs in JSON. "
                f"Follow this schema {TouristLocation.model_json_schema()}"
            ],
        ),
        # One-shot example: a user request followed by the expected answer.
        ContentDict(role="user", parts=["Generate information about Hawaii, US."]),
        ContentDict(role="model", parts=[location.model_dump_json()]),
        ContentDict(role="user", parts=["Generate information about Casablanca"]),
    ]

    # Raises KeyError if GOOGLE_API_KEY is unset; load_dotenv() at the top of
    # the file may populate it from a .env file.
    genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
    # NOTE(review): an earlier comment claimed `response_schema` requires a
    # Gemini 1.5 Pro model, yet gemini-1.5-flash is used here. Confirm the
    # chosen model supports response_schema.
    model = genai.GenerativeModel(
        "gemini-1.5-flash",
        # JSON mode: constrain the output MIME type and pass the cleaned schema.
        generation_config={
            "response_mime_type": "application/json",
            "response_schema": schema,
        },
    )
    response = model.generate_content(messages)
    print(response.text)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment