Last active
August 18, 2024 11:05
-
-
Save CVxTz/8eace07d9bd2c5123a89bf790b5cc39e to your computer and use it in GitHub Desktop.
JSON Mode using Gemini Flash
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dotenv import load_dotenv | |
import os | |
load_dotenv() | |
import google.generativeai as genai | |
from data_model import ( | |
TouristLocation, | |
ClimateType, | |
ActivityType, | |
Attraction, | |
LocationType, | |
) | |
from google.generativeai.types.content_types import ContentDict | |
from pydantic import BaseModel, Field | |
from typing import List, Optional, Tuple | |
from enum import Enum | |
class ClimateType(str, Enum):
    """Type of climate at the location"""

    # NOTE: pydantic copies the docstring above into this enum's JSON-schema
    # "description", so its wording is kept verbatim.
    tropical = "tropical"
    desert = "desert"
    temperate = "temperate"
    continental = "continental"
    polar = "polar"
    unknown = "unknown"  # fallback for any value that cannot be matched

    @classmethod
    def _missing_(cls, value):
        """Resolve a value that is not an exact member value.

        ``Enum`` calls this when ``ClimateType(value)`` finds no exact match.
        String input is normalised (underscores read as spaces, runs of
        whitespace collapsed, lower-cased) and matched again, so ``"Tropical"``
        maps to ``tropical`` instead of being discarded. Anything that still
        does not match -- including non-string input -- becomes ``unknown``.
        """
        if isinstance(value, str):
            normalized = " ".join(value.replace("_", " ").split()).lower()
            for member in cls:
                if member.value == normalized:
                    return member
        return cls.unknown
class ActivityType(str, Enum):
    """Type of activities and attractions available at the location"""

    # NOTE: pydantic copies the docstring above into this enum's JSON-schema
    # "description", so its wording is kept verbatim. Member order is also
    # preserved because it determines the order of the schema's "enum" list.

    # Social and Nightlife
    partying = "partying"
    clubbing = "clubbing"
    bars_and_pubs = "bars and pubs"
    live_music = "live music"
    comedy_clubs = "comedy clubs"

    # Nature and Outdoors
    nature = "nature"
    hiking = "hiking"
    camping = "camping"
    birdwatching = "birdwatching"
    stargazing = "stargazing"
    outdoor_activities = "outdoor activities"
    scenic_drives = "scenic drives"
    wildlife_viewing = "wildlife viewing"

    # Water and Aquatic
    aquatic_activities = "aquatic activities"
    snorkeling = "snorkeling"
    scuba_diving = "scuba diving"
    surfing = "surfing"
    kayaking = "kayaking"
    paddleboarding = "paddleboarding"
    boat_tours = "boat tours"
    fishing = "fishing"

    # Family and Children
    children_activities = "children activities"
    amusement_parks = "amusement parks"
    water_parks = "water parks"
    zoos_and_aquariums = "zoos and aquariums"
    museums_for_kids = "museums for kids"
    playgrounds = "playgrounds"

    # Sports and Fitness
    sports = "sports"
    running = "running"
    cycling = "cycling"
    swimming = "swimming"
    golfing = "golfing"
    skiing = "skiing"
    snowboarding = "snowboarding"
    team_sports = "team sports"

    # Wellness and Relaxation
    wellness = "wellness"
    spas = "spas"
    yoga = "yoga"
    meditation = "meditation"
    saunas_and_hot_tubs = "saunas and hot tubs"
    massage = "massage"

    # Culture and History
    cultural = "cultural"
    historical = "historical"
    museums = "museums"
    art_galleries = "art galleries"
    landmarks = "landmarks"
    historical_sites = "historical sites"
    cultural_events = "cultural events"

    # Food and Drink
    food_and_wine = "food and wine"
    wine_tasting = "wine tasting"
    brewery_tours = "brewery tours"
    coffee_shops = "coffee shops"
    restaurants = "restaurants"
    cooking_classes = "cooking classes"
    food_tours = "food tours"

    # Shopping and Retail
    shopping = "shopping"
    malls = "malls"
    markets = "markets"
    boutiques = "boutiques"
    souvenir_shops = "souvenir shops"
    antique_shops = "antique shops"

    # Entertainment and Performance
    entertainment = "entertainment"
    movies = "movies"
    theater = "theater"
    music_venues = "music venues"
    comedy_shows = "comedy shows"
    magic_shows = "magic shows"

    # Education and Learning
    educational = "educational"
    workshops = "workshops"
    conferences = "conferences"
    seminars = "seminars"
    classes = "classes"
    lectures = "lectures"

    # Spirituality and Faith
    spiritual = "spiritual"
    churches = "churches"
    temples = "temples"
    mosques = "mosques"
    synagogues = "synagogues"
    meditation_retreats = "meditation retreats"

    # Games and Recreation
    games = "games"
    escape_rooms = "escape rooms"
    game_centers = "game centers"
    bowling = "bowling"
    laser_tag = "laser tag"
    mini_golf = "mini golf"

    # Adventure and Thrills
    adventure = "adventure"
    skydiving = "skydiving"
    bungee_jumping = "bungee jumping"
    rock_climbing = "rock climbing"
    zip_lining = "zip lining"
    white_water_rafting = "white water rafting"

    # Fallback for any value that cannot be matched
    unknown = "unknown"

    @classmethod
    def _missing_(cls, value):
        """Resolve a value that is not an exact member value.

        ``Enum`` calls this when ``ActivityType(value)`` finds no exact match.
        String input is normalised (underscores read as spaces, runs of
        whitespace collapsed, lower-cased) and matched again, so
        ``"Bars_And_Pubs"`` maps to ``bars_and_pubs`` instead of being
        discarded. Anything that still does not match -- including non-string
        input -- becomes ``unknown``.
        """
        if isinstance(value, str):
            normalized = " ".join(value.replace("_", " ").split()).lower()
            for member in cls:
                if member.value == normalized:
                    return member
        return cls.unknown
class LocationType(str, Enum):
    """Type of location (city, country, establishment, etc.)"""

    # NOTE: pydantic copies the docstring above into this enum's JSON-schema
    # "description", so its wording is kept verbatim.
    city = "city"
    country = "country"
    establishment = "establishment"
    landmark = "landmark"
    national_park = "national park"
    island = "island"
    region = "region"
    continent = "continent"
    restaurant = "restaurant"
    stadium = "stadium"
    museum = "museum"
    theater = "theater"
    hotel = "hotel"
    airport = "airport"
    train_station = "train station"
    bus_station = "bus station"
    port = "port"
    shopping_mall = "shopping mall"
    market = "market"
    unknown = "unknown"  # fallback for any value that cannot be matched

    @classmethod
    def _missing_(cls, value):
        """Resolve a value that is not an exact member value.

        ``Enum`` calls this when ``LocationType(value)`` finds no exact match.
        String input is normalised (underscores read as spaces, runs of
        whitespace collapsed, lower-cased) and matched again, so
        ``"National Park"`` maps to ``national_park`` instead of being
        discarded. Anything that still does not match -- including non-string
        input -- becomes ``unknown``.
        """
        if isinstance(value, str):
            normalized = " ".join(value.replace("_", " ").split()).lower()
            for member in cls:
                if member.value == normalized:
                    return member
        return cls.unknown
class Attraction(BaseModel):
    """Model for an attraction"""

    # Both fields are required: no default is given to Field().
    # pydantic publishes the class docstring and these descriptions in the
    # model's JSON schema, so their wording is part of the LLM prompt.
    name: str = Field(description="Name of the attraction")
    description: str = Field(description="Description of the attraction")
class TouristLocation(BaseModel):
    """Model for a tourist location"""

    # The docstring and every Field description below end up in
    # TouristLocation.model_json_schema(), which this script uses to build the
    # Gemini prompt and response schema; treat their wording as prompt text.

    name: str = Field(..., description="Name of the location")
    # Optional coordinate pair in (longitude, latitude) order.
    location_long_lat: Optional[Tuple[float, float]] = Field(
        None,
        description="Geographic coordinates (longitude, latitude) of the location, null if unknown",
    )
    climate_type: ClimateType = Field(
        ..., description="Type of climate at the location"
    )
    # The 1-12 month range is only stated in the description, not validated.
    # The [] default is safe: pydantic copies non-hashable defaults per instance.
    high_season_months: List[int] = Field(
        [], description="List of months (1-12) when the location is most visited"
    )
    activity_types: List[ActivityType] = Field(
        ...,
        description="List of activity types and attractions available at the location",
    )
    attraction_list: List[Attraction] = Field(
        [], description="List of specific attractions at the location"
    )
    # At least one tag is required (min_length=1).
    tags: List[str] = Field(
        ...,
        description="List of tags describing the location (e.g. accessible, sustainable, sunny, cheap, pricey)",
        min_length=1,
    )
    description: str = Field(..., description="Text description of the location")
    most_notably_known_for: str = Field(
        ..., description="What the location is most notably known for"
    )
    location_type: LocationType = Field(
        ..., description="Type of location (city, country, establishment, etc.)"
    )
    parents: List[str] = Field(
        [], description="List of parent locations (country, continent, city, etc.)"
    )

    # Pydantic v2 configuration: store enum fields as their plain values
    # (e.g. "tropical") rather than enum members. This replaces the
    # v1-style nested `class Config`, which pydantic v2 deprecates.
    model_config = {"use_enum_values": True}
class TouristLocations(BaseModel):
    """Model for a list of tourist locations"""

    # Container holding several TouristLocation entries in one JSON object
    # (not referenced elsewhere in the visible file). The singular field name
    # is kept unchanged so existing payloads still validate.
    tourist_location: List[TouristLocation] = Field(
        ..., description="List of Locations"
    )
def replace_value_in_dict(item, original_schema):
    """Return a copy of ``item`` with every local JSON-schema ``$ref`` inlined.

    Adapted from https://github.com/pydantic/pydantic/issues/889.

    Args:
        item: A JSON-schema fragment (dict, list or scalar) to process.
        original_schema: Root schema that ``"#/..."`` pointers are resolved
            against (normally the dict that holds ``"$defs"``).

    Returns:
        A new structure in which each ``{"$ref": ...}`` node is replaced by
        the definition it points to. That definition is itself inlined
        recursively, so refs nested inside definitions are resolved too.
        Keys sitting next to ``"$ref"`` (e.g. a field ``description``) are
        kept and take precedence over the definition's own values. A ref that
        would re-enter a definition already being inlined (a self-referential
        model) is left as ``$ref`` so the recursion terminates.
        Neither argument is mutated.

    Raises:
        ValueError: if a ``$ref`` is not a local ``"#..."`` pointer.
        KeyError / IndexError: if a local pointer names a missing location.
    """
    return _inline_refs(item, original_schema, ())


def _inline_refs(item, root, active_refs):
    """Recursive worker for replace_value_in_dict; ``active_refs`` holds refs being inlined on this path."""
    if isinstance(item, list):
        return [_inline_refs(element, root, active_refs) for element in item]
    if not isinstance(item, dict):
        return item

    ref = item.get("$ref")
    if not isinstance(ref, str) or ref in active_refs:
        # Plain node, or a cyclic ref that must stay symbolic: copy and recurse.
        return {key: _inline_refs(value, root, active_refs) for key, value in item.items()}

    resolved = _inline_refs(_resolve_pointer(ref, root), root, active_refs + (ref,))
    if not isinstance(resolved, dict):
        return resolved
    siblings = {
        key: _inline_refs(value, root, active_refs)
        for key, value in item.items()
        if key != "$ref"
    }
    # Sibling keywords (e.g. the field's own description) override the definition's.
    return {**resolved, **siblings}


def _resolve_pointer(ref, root):
    """Follow a local JSON pointer such as ``"#/$defs/Attraction"`` inside ``root``."""
    if ref == "#":
        return root
    if not ref.startswith("#/"):
        raise ValueError(f"only local '#/...' references are supported, got {ref!r}")
    node = root
    for token in ref[2:].split("/"):
        # RFC 6901 escaping: "~1" encodes "/" and "~0" encodes "~".
        token = token.replace("~1", "/").replace("~0", "~")
        node = node[int(token)] if isinstance(node, list) else node[token]
    return node
def delete_keys_recursive(d, key_to_delete):
    """Strip every occurrence of ``key_to_delete`` from a nested structure, in place.

    Walks dicts and lists to any depth; when a dict contains the key, that
    entry is removed before the dict's remaining values are visited. Values
    that are neither dicts nor lists are left alone. Always returns ``None``.
    """
    if isinstance(d, list):
        for element in d:
            delete_keys_recursive(element, key_to_delete)
        return
    if not isinstance(d, dict):
        return
    d.pop(key_to_delete, None)
    for child in d.values():
        delete_keys_recursive(child, key_to_delete)
# Hand-written example record for Hawaii. The __main__ block serialises it
# with model_dump_json() and replays it as the model's sample answer.
location = TouristLocation(
    name="Hawaii",
    location_type=LocationType.island,
    parents=["USA", "North America"],
    description="A tropical paradise with stunning beaches and lush scenery",
    most_notably_known_for="Its beautiful beaches and active volcanoes",
    climate_type=ClimateType.tropical,
    high_season_months=[6, 7, 8, 12],
    # (longitude, latitude), matching the field's documented order.
    location_long_lat=(-155.0, 19.9),
    activity_types=[ActivityType.surfing, ActivityType.snorkeling, ActivityType.hiking],
    attraction_list=[
        Attraction(
            name="Waikiki Beach",
            description="One of the most iconic and popular beaches in the world",
        ),
        Attraction(
            name="Haleakala National Park",
            description="A dormant volcano with stunning sunrises and sunsets",
        ),
    ],
    tags=["beach", "tropical", "paradise", "expensive"],
)

# Debug output: prints the raw (un-inlined) JSON schema every time this
# module is executed, including on import.
print(TouristLocation.model_json_schema())
if __name__ == "__main__":
    # Derive the Gemini response schema from the pydantic model. Every "$ref"
    # is inlined first, because the schema is sent without its "$defs";
    # then unwanted keywords are stripped.
    raw_schema = TouristLocation.model_json_schema()
    schema = replace_value_in_dict(raw_schema, raw_schema)
    # pop() rather than del: a model with no nested models/enums has no "$defs".
    schema.pop("$defs", None)
    # location_long_lat is removed from the response schema entirely,
    # presumably because its tuple type does not map onto Gemini's schema
    # format -- TODO confirm.
    for unwanted_key in ("title", "location_long_lat", "default", "minItems"):
        delete_keys_recursive(schema, key_to_delete=unwanted_key)
    print(schema)

    # Few-shot conversation:
    # 1. instructions with the full pydantic schema,
    # 2. one worked example (Hawaii) answered as the model,
    # 3. the real request.
    # NOTE(review): the prompt schema and the Hawaii example still include
    # location_long_lat, which the response schema above no longer contains.
    messages = [
        ContentDict(
            role="user",
            parts=[
                # Trailing space so the two sentences do not run together.
                "You are a helpful assistant that outputs in JSON. "
                f"Follow this schema {TouristLocation.model_json_schema()}"
            ],
        ),
        ContentDict(role="user", parts=["Generate information about Hawaii, US."]),
        ContentDict(role="model", parts=[f"{location.model_dump_json()}"]),
        ContentDict(role="user", parts=["Generate information about Casablanca"]),
    ]

    genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
    # NOTE(review): a previous comment claimed `response_schema` requires a
    # Gemini 1.5 Pro model, yet this script uses gemini-1.5-flash; check the
    # Gemini API docs for which models currently support `response_schema`.
    model = genai.GenerativeModel(
        "gemini-1.5-flash",
        # Request JSON output constrained to the cleaned schema built above.
        generation_config={
            "response_mime_type": "application/json",
            "response_schema": schema,
        },
    )
    response = model.generate_content(messages)
    print(response.text)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment