|
#!/usr/bin/env -S uv run --script |
|
# /// script |
|
# requires-python = ">=3.11" |
|
# dependencies = [ |
|
# "click", |
|
# "pyyaml", |
|
# "requests", |
|
# ] |
|
# /// |
|
|
|
from __future__ import annotations

import fnmatch
import json
import os
import secrets
import socket
import tempfile
import time
import webbrowser
from datetime import datetime, timedelta, timezone
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from threading import Event, Thread
from urllib.parse import parse_qs, quote, urlencode, urlparse

import click
import requests
import yaml
|
|
|
# --------------------------------------------------------------------------- |
|
# Paths |
|
# --------------------------------------------------------------------------- |
|
|
|
# Per-user directory holding the CLI's settings and its OAuth tokens.
CONFIG_DIR = Path.home() / ".config" / "ha-cli"
# Default locations of the two files kept in CONFIG_DIR.
CONFIG_FILE = CONFIG_DIR.joinpath("config.yaml")
TOKENS_FILE = CONFIG_DIR.joinpath("tokens.json")
|
|
|
# --------------------------------------------------------------------------- |
|
# Config |
|
# --------------------------------------------------------------------------- |
|
|
|
|
|
class Config:
    """Persistent CLI settings stored as YAML in ``<config_dir>/config.yaml``."""

    def __init__(self, config_dir: Path = CONFIG_DIR):
        self._config_file = config_dir / "config.yaml"

    def load(self) -> dict:
        """Return the parsed settings; an absent or empty file yields ``{}``."""
        if not self._config_file.exists():
            return {}
        with self._config_file.open() as stream:
            parsed = yaml.safe_load(stream)
        return parsed or {}

    def save_url(self, url: str) -> None:
        """Store *url* under the ``url`` key, preserving any other settings."""
        self._config_file.parent.mkdir(parents=True, exist_ok=True)
        settings = self.load()
        settings["url"] = url
        with self._config_file.open("w") as stream:
            yaml.dump(settings, stream)

    def get_url(self, override: str | None = None) -> str:
        """Resolve the Home Assistant base URL.

        Precedence: a truthy *override* (e.g. ``--url``), then the ``HA_URL``
        environment variable, then the ``url`` stored in the config file.

        Raises:
            click.UsageError: if no source yields a non-empty URL.
        """
        # The file is always read first, even when an override is given.
        stored = self.load().get("url")
        candidate = os.environ.get("HA_URL", stored)
        if override:
            candidate = override
        if not candidate:
            raise click.UsageError(
                "Home Assistant URL is not configured. "
                "Set it via --url, the HA_URL environment variable, "
                "or run: ha_cli.py auth --url <url>"
            )
        return candidate
|
|
|
|
|
# --------------------------------------------------------------------------- |
|
# TokenStore |
|
# --------------------------------------------------------------------------- |
|
|
|
|
|
class TokenStore:
    """Persist OAuth tokens as JSON in ``<config_dir>/tokens.json``."""

    # Tokens count as expired this many seconds before ``expires_at`` so an
    # access token cannot lapse between the expiry check and the request
    # that uses it.
    EXPIRY_LEEWAY = 60.0

    def __init__(self, config_dir: Path = CONFIG_DIR):
        self._tokens_file = config_dir / "tokens.json"

    def load(self) -> dict | None:
        """Return the saved token dict, or ``None`` if nothing has been saved."""
        if not self._tokens_file.exists():
            return None
        with self._tokens_file.open(encoding="utf-8") as f:
            return json.load(f)

    def save(self, tokens: dict) -> None:
        """Atomically write *tokens*, readable and writable by the owner only.

        The file holds live credentials. It is first written to a temp file
        created by ``tempfile.mkstemp`` (owner-only permissions) in the same
        directory. It is then moved over the target with ``os.replace``, so a
        crash mid-write cannot leave a truncated token file behind.
        """
        directory = self._tokens_file.parent
        # ``mode`` only applies to a newly created leaf directory.
        directory.mkdir(mode=0o700, parents=True, exist_ok=True)
        fd, tmp_name = tempfile.mkstemp(dir=directory, prefix=".tokens-", suffix=".tmp")
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(tokens, f)
            os.replace(tmp_name, self._tokens_file)
        except BaseException:
            # Don't leave a stray credentials temp file behind on failure.
            try:
                os.unlink(tmp_name)
            except OSError:
                pass
            raise

    def is_expired(self, tokens: dict) -> bool:
        """Return True if the access token expires within ``EXPIRY_LEEWAY`` seconds.

        *tokens* must contain ``expires_at`` as a Unix timestamp.
        """
        return time.time() >= tokens["expires_at"] - self.EXPIRY_LEEWAY
|
|
|
|
|
# --------------------------------------------------------------------------- |
|
# AuthManager |
|
# --------------------------------------------------------------------------- |
|
|
|
|
|
class AuthManager: |
|
def __init__(self, token_store: TokenStore): |
|
self._store = token_store |
|
|
|
@staticmethod |
|
def _find_free_port() -> int: |
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: |
|
s.bind(("localhost", 0)) |
|
return s.getsockname()[1] |
|
|
|
def _wait_for_callback(self, port: int) -> tuple[str, str]: |
|
result: dict = {} |
|
ready = Event() |
|
|
|
class _Handler(BaseHTTPRequestHandler): |
|
def do_GET(self) -> None: |
|
parsed = urlparse(self.path) |
|
params = parse_qs(parsed.query) |
|
result["code"] = params.get("code", [""])[0] |
|
result["state"] = params.get("state", [""])[0] |
|
self.send_response(200) |
|
self.end_headers() |
|
self.wfile.write(b"Login successful. You may close this window.") |
|
ready.set() |
|
|
|
def log_message(self, *args: object) -> None: |
|
pass |
|
|
|
server = HTTPServer(("localhost", port), _Handler) |
|
server.timeout = 1.0 |
|
|
|
def _serve() -> None: |
|
while not ready.is_set(): |
|
server.handle_request() |
|
server.server_close() |
|
|
|
t = Thread(target=_serve, daemon=True) |
|
t.start() |
|
t.join(timeout=120) |
|
if not ready.is_set(): |
|
raise click.ClickException("Timed out waiting for OAuth callback.") |
|
return result["code"], result["state"] |
|
|
|
def refresh_tokens(self, ha_url: str, tokens: dict) -> dict: |
|
resp = requests.post( |
|
f"{ha_url.rstrip('/')}/auth/token", |
|
data={ |
|
"grant_type": "refresh_token", |
|
"refresh_token": tokens["refresh_token"], |
|
"client_id": tokens["client_id"], |
|
}, |
|
) |
|
resp.raise_for_status() |
|
data = resp.json() |
|
return { |
|
**tokens, |
|
"access_token": data["access_token"], |
|
"refresh_token": data.get("refresh_token", tokens["refresh_token"]), |
|
"expires_at": time.time() + data["expires_in"], |
|
} |
|
|
|
def run_oauth_flow(self, ha_url: str) -> dict: |
|
ha_url = ha_url.rstrip("/") |
|
port = self._find_free_port() |
|
redirect_uri = f"http://localhost:{port}/" |
|
state = secrets.token_urlsafe(16) |
|
|
|
auth_params = urlencode({ |
|
"response_type": "code", |
|
"client_id": redirect_uri, |
|
"redirect_uri": redirect_uri, |
|
"state": state, |
|
}) |
|
auth_url = f"{ha_url}/auth/authorize?{auth_params}" |
|
click.echo(f"Opening browser for Home Assistant login...\n{auth_url}") |
|
webbrowser.open(auth_url) |
|
|
|
code, returned_state = self._wait_for_callback(port) |
|
if returned_state != state: |
|
raise click.ClickException("OAuth state mismatch — possible CSRF. Aborting.") |
|
|
|
resp = requests.post( |
|
f"{ha_url}/auth/token", |
|
data={ |
|
"grant_type": "authorization_code", |
|
"code": code, |
|
"client_id": redirect_uri, |
|
"redirect_uri": redirect_uri, |
|
}, |
|
) |
|
resp.raise_for_status() |
|
data = resp.json() |
|
return { |
|
"access_token": data["access_token"], |
|
"refresh_token": data["refresh_token"], |
|
"expires_at": time.time() + data["expires_in"], |
|
"client_id": redirect_uri, |
|
} |
|
|
|
def get_valid_tokens(self, ha_url: str) -> str: |
|
tokens = self._store.load() |
|
if tokens is None: |
|
click.echo( |
|
"No tokens found. Run: ha_cli.py auth --url <url>", |
|
err=True, |
|
) |
|
raise SystemExit(1) |
|
if self._store.is_expired(tokens): |
|
tokens = self.refresh_tokens(ha_url, tokens) |
|
self._store.save(tokens) |
|
return tokens["access_token"] |
|
|
|
|
|
# --------------------------------------------------------------------------- |
|
# HAClient |
|
# --------------------------------------------------------------------------- |
|
|
|
|
|
class HAClient:
    """Small wrapper around the Home Assistant REST API.

    One ``requests.Session`` carries the bearer token for every call. Each
    call is sent with a timeout so an unresponsive server cannot hang the
    CLI. Unless noted otherwise, methods raise ``requests.HTTPError`` for
    non-2xx responses.
    """

    def __init__(self, ha_url: str, access_token: str, timeout: float = 30.0):
        # Base URL without trailing slash; every request path starts with "/".
        self._url = ha_url.rstrip("/")
        # Seconds passed as ``timeout=`` to each request.
        self._timeout = timeout
        self._session = requests.Session()
        self._session.headers["Authorization"] = f"Bearer {access_token}"

    def _request(self, method: str, path: str, **kwargs) -> requests.Response:
        """Send *method* to ``<base url><path>``, applying the default timeout."""
        kwargs.setdefault("timeout", self._timeout)
        return self._session.request(method, f"{self._url}{path}", **kwargs)

    def _json(self, method: str, path: str, **kwargs):
        """Like ``_request``, but raise on HTTP errors and decode the JSON body."""
        resp = self._request(method, path, **kwargs)
        resp.raise_for_status()
        return resp.json()

    def get_states(self) -> list[dict]:
        """Return the decoded ``/api/states`` response."""
        return self._json("GET", "/api/states")

    def get_state(self, entity_id: str) -> dict:
        """Return the decoded ``/api/states/<entity_id>`` response."""
        # Quote the ID so unexpected characters cannot alter the URL path.
        return self._json("GET", f"/api/states/{quote(entity_id, safe='')}")

    def get_automation_config(self, automation_id: str) -> dict | None:
        """Return the config of automation *automation_id*, or None on HTTP 404."""
        resp = self._request(
            "GET", f"/api/config/automation/config/{quote(str(automation_id), safe='')}"
        )
        if resp.status_code == 404:
            return None
        resp.raise_for_status()
        return resp.json()

    def create_automation(self, config: dict) -> dict:
        """POST *config* to ``/api/config/automation/config``; return the decoded reply."""
        # NOTE(review): HA's automation config view appears to be routed only
        # as ``.../config/<id>``; confirm this id-less POST is accepted.
        return self._json("POST", "/api/config/automation/config", json=config)

    def update_automation_config(self, automation_id: str, config: dict) -> dict:
        """POST *config* for automation *automation_id*; return the decoded reply."""
        return self._json(
            "POST",
            f"/api/config/automation/config/{quote(str(automation_id), safe='')}",
            json=config,
        )

    def get_history(
        self,
        entity_ids: list[str],
        start: datetime,
        end: datetime,
        minimal: bool = False,
    ) -> list[list[dict]]:
        """Return state history for *entity_ids* between *start* and *end*.

        ``significant_changes_only`` is always disabled. *minimal* adds
        ``minimal_response=true`` to the request.
        """
        params: dict[str, str] = {
            "end_time": end.isoformat(),
            "filter_entity_id": ",".join(entity_ids),
            "significant_changes_only": "false",
        }
        if minimal:
            params["minimal_response"] = "true"
        return self._json("GET", f"/api/history/period/{start.isoformat()}", params=params)

    def get_logbook(
        self,
        entity_ids: list[str],
        start: datetime,
        end: datetime,
    ) -> list[dict]:
        """Return logbook entries between *start* and *end*.

        An empty *entity_ids* sends no entity filter. HA's logbook endpoint
        reads the filter from the ``entity`` query parameter (comma-separated
        IDs); an ``entity_id`` parameter is not recognised by it.
        """
        params: dict[str, str] = {"end_time": end.isoformat()}
        if entity_ids:
            params["entity"] = ",".join(entity_ids)
        return self._json("GET", f"/api/logbook/{start.isoformat()}", params=params)

    def get_services(self) -> list[dict]:
        """Return the decoded ``/api/services`` response."""
        return self._json("GET", "/api/services")

    def render_template(self, template: str) -> str:
        """Render *template* server-side and return the raw response text."""
        resp = self._request("POST", "/api/template", json={"template": template})
        resp.raise_for_status()
        return resp.text
|
|
|
|
|
# --------------------------------------------------------------------------- |
|
# Helpers |
|
# --------------------------------------------------------------------------- |
|
|
|
|
|
# Suffix -> ``timedelta`` keyword for relative times such as 30m, 2h, 7d, 1w.
_RELATIVE_UNITS = {"m": "minutes", "h": "hours", "d": "days", "w": "weeks"}


def parse_datetime(s: str) -> datetime:
    """Parse a CLI time argument into a timezone-aware datetime.

    Accepted forms (surrounding whitespace is ignored):

    * ``now`` (any case): the current UTC time.
    * ``<N><unit>``: N minutes/hours/days/weeks before now, with unit ``m``,
      ``h``, ``d`` or ``w`` (e.g. ``30m``, ``2h``, ``7d``).
    * An ISO 8601 string accepted by ``datetime.fromisoformat``. Naive values
      are taken to be UTC.

    Raises:
        click.BadParameter: if *s* matches none of the forms or is out of range.
    """
    s = s.strip()
    if s.lower() == "now":
        return datetime.now(timezone.utc)

    amount, unit = s[:-1], s[-1:]
    # isdecimal rather than isdigit: isdigit also accepts characters such as
    # "²" that int() rejects, which would escape as a raw ValueError.
    if unit in _RELATIVE_UNITS and amount.isdecimal():
        try:
            return datetime.now(timezone.utc) - timedelta(**{_RELATIVE_UNITS[unit]: int(amount)})
        except OverflowError:
            raise click.BadParameter(f"Relative time out of range: {s!r}") from None

    try:
        dt = datetime.fromisoformat(s)
    except ValueError:
        raise click.BadParameter(f"Cannot parse datetime: {s!r}") from None
    # Make naive input aware (UTC) so it compares safely with aware values.
    return dt if dt.tzinfo is not None else dt.replace(tzinfo=timezone.utc)
|
|
|
|
|
def match_entities( |
|
states: list[dict], |
|
domains: tuple[str, ...], |
|
pattern: str | None, |
|
) -> list[str]: |
|
results = [] |
|
for state in states: |
|
eid = state["entity_id"] |
|
if domains and eid.split(".")[0] not in domains: |
|
continue |
|
if pattern and not fnmatch.fnmatch(eid, pattern): |
|
continue |
|
results.append(eid) |
|
return results |
|
|
|
|
|
# --------------------------------------------------------------------------- |
|
# CLI |
|
# --------------------------------------------------------------------------- |
|
|
|
|
|
@click.group()
@click.option("--url", default=None, help="Home Assistant base URL")
@click.pass_context
def cli(ctx: click.Context, url: str | None) -> None:
    """Command-line access to a Home Assistant instance.

    Run `auth` once to log in through the browser; the other commands reuse
    the saved tokens. A --url given here applies to any subcommand that does
    not set its own --url.
    """
    # Subcommands read the group-level --url back from ctx.obj["url"].
    ctx.ensure_object(dict)
    ctx.obj["url"] = url
|
|
|
|
|
@cli.command()
@click.option("--url", default=None, help="Home Assistant base URL")
@click.pass_context
def auth(ctx: click.Context, url: str | None) -> None:
    """Authenticate with Home Assistant via OAuth2."""
    # Subcommand --url wins over the group-level one.
    cfg = Config()
    ha_url = cfg.get_url(override=url or ctx.obj.get("url"))
    token_store = TokenStore()
    token_store.save(AuthManager(token_store).run_oauth_flow(ha_url))
    # Remember the URL so later commands can omit --url.
    cfg.save_url(ha_url)
    click.echo("Authentication successful. Tokens saved.")
|
|
|
|
|
@cli.command()
@click.argument("entities", nargs=-1)
@click.option("-d", "--domain", "domains", multiple=True, help="HA domain(s) to query")
@click.option("-f", "--filter", "pattern", default=None, help="Glob filter on entity IDs")
@click.option("-s", "--start", "start_str", required=True, help="Start datetime (ISO or 2h/1d/7d)")
@click.option("-e", "--end", "end_str", default="now", help="End datetime (ISO or 'now')")
@click.option("-o", "--output", "output_path", default=None, help="Output file path")
@click.option("--minimal", is_flag=True, default=False, help="Omit state attributes")
@click.option("--url", default=None, help="Home Assistant base URL")
@click.pass_context
def query(
    ctx: click.Context,
    entities: tuple[str, ...],
    domains: tuple[str, ...],
    pattern: str | None,
    start_str: str,
    end_str: str,
    output_path: str | None,
    minimal: bool,
    url: str | None,
) -> None:
    """Query HA state history and write to JSON."""
    if not entities and not domains:
        raise click.UsageError("Specify at least one entity ID or use --domain.")

    # Validate the time window before authenticating, so bad input fails fast
    # without a network round-trip or token refresh.
    start = parse_datetime(start_str)
    end = parse_datetime(end_str)
    if start >= end:
        raise click.UsageError("--start must be earlier than --end.")

    config = Config()
    ha_url = config.get_url(override=url or ctx.obj.get("url"))
    store = TokenStore()
    mgr = AuthManager(store)
    access_token = mgr.get_valid_tokens(ha_url)
    client = HAClient(ha_url, access_token)

    # Positional IDs (narrowed by --filter) and --domain matches are combined,
    # so naming extra entities alongside --domain does not drop them.
    entity_ids = [e for e in entities if not pattern or fnmatch.fnmatch(e, pattern)]
    if domains:
        entity_ids += match_entities(client.get_states(), domains, pattern)
    # An ID may be both listed and matched: de-duplicate, keeping first order.
    entity_ids = list(dict.fromkeys(entity_ids))

    if not entity_ids:
        raise click.UsageError("No entities matched the given filters.")

    history = client.get_history(entity_ids, start, end, minimal)

    if output_path is None:
        ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
        output_path = f"ha_query_{ts}.json"

    output = {
        "query": {
            "start": start.isoformat(),
            "end": end.isoformat(),
            "entities_requested": entity_ids,
            "ha_url": ha_url,
        },
        "history": history,
    }
    with open(output_path, "w") as f:
        json.dump(output, f, indent=2)
    click.echo(f"Written to {output_path}")
|
|
|
|
|
@cli.command("entities")
@click.option("-d", "--domain", "domains", multiple=True, help="Filter by domain")
@click.option("-f", "--filter", "pattern", default=None, help="Glob filter on entity IDs")
@click.option("--url", default=None, help="Home Assistant base URL")
@click.pass_context
def list_entities(
    ctx: click.Context,
    domains: tuple[str, ...],
    pattern: str | None,
    url: str | None,
) -> None:
    """List entities matching domain/pattern filters."""
    config = Config()
    ha_url = config.get_url(override=url or ctx.obj.get("url"))
    store = TokenStore()
    mgr = AuthManager(store)
    access_token = mgr.get_valid_tokens(ha_url)
    client = HAClient(ha_url, access_token)

    states = client.get_states()
    # Index states by ID once. Rescanning the full list for every match made
    # listing quadratic in the number of entities.
    state_by_id: dict[str, object] = {}
    for entry in states:
        # setdefault keeps the first entry if an ID ever appears twice.
        state_by_id.setdefault(entry["entity_id"], entry["state"])

    for eid in match_entities(states, domains, pattern):
        click.echo(f"{eid}: {state_by_id[eid]}")
|
|
|
|
|
@cli.command("automation")
@click.argument("entity_id")
@click.option("-s", "--start", "start_str", default=None, help="Include trigger history from this time (ISO or 2h/1d/7d)")
@click.option("-o", "--output", "output_path", default=None, help="Output file path")
@click.option("--url", default=None, help="Home Assistant base URL")
@click.pass_context
def show_automation(
    ctx: click.Context,
    entity_id: str,
    start_str: str | None,
    output_path: str | None,
    url: str | None,
) -> None:
    """Fetch automation config and optionally its trigger history."""
    ha_url = Config().get_url(override=url or ctx.obj.get("url"))
    token = AuthManager(TokenStore()).get_valid_tokens(ha_url)
    client = HAClient(ha_url, token)

    state = client.get_state(entity_id)
    # The state's ``attributes.id`` is used as the key for the config lookup.
    attributes = state.get("attributes", {})
    automation_id = attributes.get("id")
    if automation_id:
        auto_config = client.get_automation_config(automation_id)
    else:
        auto_config = None

    output: dict = {"entity_id": entity_id, "state": state, "config": auto_config}

    if start_str is not None:
        window_start = parse_datetime(start_str)
        output["history"] = client.get_history(
            [entity_id], window_start, datetime.now(timezone.utc), minimal=False
        )

    destination = output_path
    if destination is None:
        stamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
        destination = f"ha_automation_{stamp}.json"

    with open(destination, "w") as fh:
        json.dump(output, fh, indent=2)
    click.echo(f"Written to {destination}")
|
|
|
|
|
@cli.command("logbook")
@click.argument("entities", nargs=-1)
@click.option("-d", "--domain", "domains", multiple=True, help="HA domain(s) to query")
@click.option("-f", "--filter", "pattern", default=None, help="Glob filter on entity IDs")
@click.option("-s", "--start", "start_str", required=True, help="Start datetime (ISO or 2h/1d/7d)")
@click.option("-e", "--end", "end_str", default="now", help="End datetime (ISO or 'now')")
@click.option("-o", "--output", "output_path", default=None, help="Output file path")
@click.option("--url", default=None, help="Home Assistant base URL")
@click.pass_context
def logbook_cmd(
    ctx: click.Context,
    entities: tuple[str, ...],
    domains: tuple[str, ...],
    pattern: str | None,
    start_str: str,
    end_str: str,
    output_path: str | None,
    url: str | None,
) -> None:
    """Fetch logbook entries and write to JSON."""
    # Validate the time window before authenticating, so bad input fails fast.
    start = parse_datetime(start_str)
    end = parse_datetime(end_str)
    if start >= end:
        raise click.UsageError("--start must be earlier than --end.")

    config = Config()
    ha_url = config.get_url(override=url or ctx.obj.get("url"))
    store = TokenStore()
    mgr = AuthManager(store)
    access_token = mgr.get_valid_tokens(ha_url)
    client = HAClient(ha_url, access_token)

    # Positional IDs (narrowed by --filter) are combined with --domain matches.
    # A bare --filter, with no IDs and no --domain, is matched against all
    # entities.
    entity_ids = [e for e in entities if not pattern or fnmatch.fnmatch(e, pattern)]
    if domains or (pattern and not entities):
        entity_ids += match_entities(client.get_states(), domains, pattern)
    entity_ids = list(dict.fromkeys(entity_ids))  # de-duplicate, keep order

    # get_logbook treats an empty list as "no entity filter" and returns every
    # entry. Filters that match nothing must be an error, not a full dump.
    if (entities or domains or pattern) and not entity_ids:
        raise click.UsageError("No entities matched the given filters.")

    entries = client.get_logbook(entity_ids, start, end)

    if output_path is None:
        ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
        output_path = f"ha_logbook_{ts}.json"

    output = {
        "query": {
            "start": start.isoformat(),
            "end": end.isoformat(),
            "entities_requested": entity_ids,
            "ha_url": ha_url,
        },
        "entries": entries,
    }
    with open(output_path, "w") as f:
        json.dump(output, f, indent=2)
    click.echo(f"Written to {output_path}")
|
|
|
|
|
@cli.command("services")
@click.option("-d", "--domain", "domain", default=None, help="Filter by domain")
@click.option("-o", "--output", "output_path", default=None, help="Output file path")
@click.option("--url", default=None, help="Home Assistant base URL")
@click.pass_context
def services_cmd(
    ctx: click.Context,
    domain: str | None,
    output_path: str | None,
    url: str | None,
) -> None:
    """List available HA services, optionally filtered by domain."""
    ha_url = Config().get_url(override=url or ctx.obj.get("url"))
    client = HAClient(ha_url, AuthManager(TokenStore()).get_valid_tokens(ha_url))

    services = client.get_services()
    if domain:
        services = [entry for entry in services if entry.get("domain") == domain]

    # With -o, write to the file and report it; otherwise print JSON to stdout.
    if output_path is not None:
        with open(output_path, "w") as fh:
            json.dump(services, fh, indent=2)
        click.echo(f"Written to {output_path}")
        return
    click.echo(json.dumps(services, indent=2))
|
|
|
|
|
@cli.command("template")
@click.argument("template_str")
@click.option("--url", default=None, help="Home Assistant base URL")
@click.pass_context
def template_cmd(
    ctx: click.Context,
    template_str: str,
    url: str | None,
) -> None:
    """Render a Jinja2 template against live HA state."""
    ha_url = Config().get_url(override=url or ctx.obj.get("url"))
    access_token = AuthManager(TokenStore()).get_valid_tokens(ha_url)
    # The rendered text from the /api/template call is echoed unchanged.
    rendered = HAClient(ha_url, access_token).render_template(template_str)
    click.echo(rendered)
|
|
|
|
|
if __name__ == "__main__":  # pragma: no cover
    # Seed the shared context dict up front; the group's ensure_object(dict)
    # keeps an existing dict and would otherwise create an empty one itself.
    cli(obj={})