Created April 24, 2025 12:04
Save ehzawad/ccb9ec669a2cb7d9f12c3978b88d4882 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import asyncio | |
import aiohttp | |
import json | |
import time | |
import argparse | |
import os | |
import sys | |
import logging | |
from typing import Dict, Any, Optional, List, Text | |
# Module logging: the root logger is configured for INFO and above, with each
# record formatted as "<time> - <logger name> - <level> - <message>".
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
class RasaCLI:
    """Interactive command-line client for a local Rasa server and action server.

    Each user message is sent to the Rasa HTTP API. The client collects the NLU
    parse result, the bot's reply, the conversation tracker and the next-action
    prediction, then prints them to the terminal.
    """

    def __init__(
        self,
        rasa_port: Optional[int] = None,
        action_port: Optional[int] = None
    ):
        """Initialize Rasa client for interactive CLI.

        Args:
            rasa_port: Rasa server port on localhost (falsy values -> 5005).
            action_port: Action server port on localhost (falsy values -> 5054).
        """
        # `or` maps both None and 0 to the default port
        self.rasa_port = rasa_port or 5005
        self.action_port = action_port or 5054
        # Configure base URLs with appropriate ports
        self.rasa_base_url = f"http://localhost:{self.rasa_port}"
        self.action_base_url = f"http://localhost:{self.action_port}"
        self.via_number = "8809611888444"  # Fixed via number, embedded in every sender_id
        self.sender_id = None  # Assigned by initialize_session()
        # Phone numbers whose last 11 characters match one of these are
        # replaced by a fixed masked number in the sender_id
        self.test_numbers = [
            '09696387582', '09638372914', '01924560627', '01518472623',
            '01580582654', '01833626976', '01571321136', '01764655648',
            '09638317055', '09638080760', '09696173224', '09611888444',
            '01911310316', '19723182900', '01558666739', '01714007806',
            '01714020387'
        ]
        self.session = None  # aiohttp.ClientSession, created by connect()

    async def connect(self):
        """Create a hardened aiohttp session that resists broken pipes.

        Any existing session is closed first, so this also serves as "reconnect".
        """
        if self.session is not None:
            await self.close()
        # 10 s to connect; up to 120 s per socket read and for the whole request
        timeout = aiohttp.ClientTimeout(sock_connect=10, sock_read=120, total=120)
        connector = aiohttp.TCPConnector(
            limit=20,  # Sane concurrency
            force_close=True,  # Never reuse a keep-alive connection
            enable_cleanup_closed=True,  # Reap closed transports quickly
        )
        self.session = aiohttp.ClientSession(
            timeout=timeout,
            connector=connector,
        )

    async def close(self):
        """Close the aiohttp session, if any. Safe to call repeatedly."""
        if self.session:
            await self.session.close()
            self.session = None

    async def _request_with_retry(self, method: str, url: str, **kwargs):
        """Send an HTTP request on ``self.session`` and return the decoded JSON body.

        If the first attempt fails with a broken pipe (errno 32), the session
        is rebuilt and the request is retried once. HTTP error statuses
        (>= 400) are logged, but the body is still decoded and returned.

        Raises:
            aiohttp.ClientOSError: a non-broken-pipe OS error, or a second failure.
            Other aiohttp client errors from the request or JSON decoding
            propagate unchanged.
        """
        if self.session is None:
            await self.connect()
        for attempt in (1, 2):  # At most one retry
            try:
                async with self.session.request(method, url, **kwargs) as resp:
                    if resp.status >= 400:
                        logger.warning("HTTP %s from %s %s", resp.status, method, url)
                    return await resp.json()
            except aiohttp.ClientOSError as e:
                if e.errno == 32 and attempt == 1:  # Broken pipe, first attempt
                    logger.warning("Broken pipe → refreshing session & retrying %s %s...", method, url)
                    await self.connect()  # Reconnect and retry
                    continue
                raise  # Re-raise on second failure or other error types

    @staticmethod
    def _empty_result() -> Dict[str, Any]:
        """Return a fresh placeholder result for when talking to Rasa fails."""
        return {
            "response": [],
            "intent": {"name": "unknown", "confidence": 0.0},
            "entities": [],
            "tracker_data": {},
            "next_action": {"name": "None", "confidence": 0.0},
        }

    @staticmethod
    def _top_prediction(prediction: Any) -> Dict[str, Any]:
        """Return ``{"name", "confidence"}`` for the highest-scoring action in a /predict reply.

        Uses max() instead of taking the first entry, because the order of
        ``scores`` is not relied upon. A missing or empty ``scores`` list yields
        the ``"None"`` placeholder instead of raising.
        """
        scores = prediction.get("scores") if isinstance(prediction, dict) else None
        if not scores:
            return {"name": "None", "confidence": 0.0}
        best = max(scores, key=lambda entry: entry.get("score") or 0.0)
        return {"name": best.get("action", "None"), "confidence": best.get("score", 0.0)}

    async def send_message(self, message_text: str) -> Dict[Any, Any]:
        """Send message to Rasa server and get response.

        Makes four calls in order: NLU parse, REST webhook (the call that
        actually delivers the message), tracker fetch and next-action prediction.

        Returns:
            A dict with keys ``response`` (webhook reply list), ``intent``,
            ``entities``, ``tracker_data`` and ``next_action``. On any error
            the failure is logged and a placeholder dict with the same keys
            is returned.
        """
        try:
            # Make sure we have a session
            if self.session is None:
                await self.connect()
            # First get NLU parse result
            parse_result = await self._request_with_retry(
                "POST",
                f"{self.rasa_base_url}/model/parse",
                json={"text": message_text}
            )
            intent_info = parse_result.get('intent') or {}
            intent = {
                "name": intent_info.get('name', 'unknown'),
                "confidence": intent_info.get('confidence', 0.0),
            }
            entities = parse_result.get('entities', [])
            # Send message to webhook
            bot_response = await self._request_with_retry(
                "POST",
                f"{self.rasa_base_url}/webhooks/rest/webhook",
                json={"sender": self.sender_id, "message": message_text}
            )
            # Get tracker state
            tracker_data = await self._request_with_retry(
                "GET",
                f"{self.rasa_base_url}/conversations/{self.sender_id}/tracker"
            )
            # Get next action prediction
            prediction = await self._request_with_retry(
                "POST",
                f"{self.rasa_base_url}/conversations/{self.sender_id}/predict"
            )
            return {
                "response": bot_response,
                "intent": intent,
                "entities": entities,
                "tracker_data": tracker_data,
                "next_action": self._top_prediction(prediction),
            }
        except aiohttp.ClientError as e:
            logger.error("Error communicating with Rasa server: %s", e)
            return self._empty_result()
        except Exception as e:
            logger.error("Unexpected error: %s", e)
            return self._empty_result()

    def get_bot_response_text(self, response: Dict[Any, Any]) -> str:
        """Return the first text message from a send_message() result.

        The literal marker ``[service_response]`` is removed and surrounding
        whitespace is stripped. Later text messages are not returned. Messages
        without a string ``text`` field (e.g. ``None`` or image-only) are skipped.
        """
        messages = response.get("response")
        if not messages:
            return "No response from bot"
        for message in messages:
            text = message.get("text") if isinstance(message, dict) else None
            if isinstance(text, str):
                return text.replace('[service_response]', '').strip()
        return "No text response from bot"

    @staticmethod
    def _print_recent_events(events: List[Dict[str, Any]]) -> None:
        """Print tracker events in order, grouping bot utterances.

        Bot utterances are buffered and printed when the next user message or
        ``action_listen`` event arrives, or at the end. Consecutive action
        events with the same name are printed once.
        """
        current_action = None
        bot_responses: List[Any] = []

        def flush_bot_responses():
            # Print buffered bot messages, then empty the (shared) buffer in place
            for resp in bot_responses:
                print(f" Bot: {repr(resp)}")
            bot_responses.clear()

        for event in events:
            event_type = event.get("event")
            if event_type == "user":
                # Print any pending action and bot responses before the new user message
                if current_action:
                    print(f" Action Executed: {current_action}")
                    current_action = None
                flush_bot_responses()
                intent = ((event.get("parse_data") or {}).get("intent") or {}).get("name", "None")
                print(f" User: {event.get('text')} → Intent: {intent}")
            elif event_type == "bot":
                bot_text = event.get('text')
                if bot_text:  # Only keep non-empty messages
                    bot_responses.append(bot_text)
            elif event_type == "action":
                action_name = event.get('name')
                # A different action ends the previous one: report it first
                if current_action and current_action != action_name:
                    print(f" Action Executed: {current_action}")
                current_action = action_name
                if current_action == "action_listen":
                    flush_bot_responses()
            elif event_type == "slot":
                print(f" Slot Set: {event.get('name')} = {repr(event.get('value'))}")
            elif event_type == "active_loop":
                status = "Started" if event.get("name") else "Stopped"
                print(f" Active Loop: {event.get('name', 'None')} ({status})")
        # Print any remaining action or bot responses
        if current_action:
            print(f" Action Executed: {current_action}")
        flush_bot_responses()

    def print_tracker_state(self, response_data: Dict[str, Any]):
        """Print detailed tracker state information.

        Prints the active form, intent, non-empty slots, latest and predicted
        actions, and a grouped view of the last 10 tracker events. Missing or
        ``null`` tracker fields are treated as empty.
        """
        print("\n" + "="*40)
        print("TRACKER STATE".center(40))
        tracker_data = response_data.get("tracker_data") or {}
        # Active form
        active_form = (tracker_data.get("active_loop") or {}).get("name")
        print(f"Active Form: {active_form or 'None'}")
        # Intent
        intent_name = response_data["intent"]["name"]
        intent_confidence = response_data["intent"]["confidence"]
        print(f"Intent Name: {intent_name}")
        print(f"Intent Confidence: {intent_confidence:.4f}")
        # Slots (only non-empty values are shown)
        print("\nCurrent Slots:")
        filled_slots = [
            (slot, value)
            for slot, value in (tracker_data.get("slots") or {}).items()
            if value
        ]
        for slot, value in filled_slots:
            print(f" - {slot}: {repr(value)}")
        if not filled_slots:
            print(" (No filled slots)")
        # Actions
        latest_action = tracker_data.get("latest_action_name", "None")
        next_action = response_data.get("next_action") or {}
        print(f"\nLatest Action: {latest_action}")
        print(f"Next Predicted Action: {next_action.get('name')} (confidence: {next_action.get('confidence') or 0.0:.4f})")
        # Recent events: the last 10 are enough to show a complete exchange
        print("\nRecent Events (Chronological):")
        self._print_recent_events((tracker_data.get("events") or [])[-10:])
        print("="*40 + "\n")

    async def check_server_health(self):
        """Check whether both the Rasa server and the action server respond.

        Uses its own short-lived session (10 s total timeout), separate from
        ``self.session``.

        Returns:
            True only if ``GET {rasa}/`` and ``GET {action}/health`` both return
            HTTP 200. Otherwise returns False; failures are logged, not raised.
        """
        timeout = aiohttp.ClientTimeout(total=10)
        connector = aiohttp.TCPConnector(force_close=True)
        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
            try:
                # Check Rasa server
                try:
                    async with session.get(f"{self.rasa_base_url}/") as response:
                        rasa_status = response.status == 200
                        if not rasa_status:
                            logger.error("Rasa server health check failed: %s", response.status)
                except Exception as e:
                    logger.error("Rasa server check failed: %s", e)
                    return False
                # Check Action server
                try:
                    async with session.get(f"{self.action_base_url}/health") as response:
                        action_status = response.status == 200
                        if not action_status:
                            logger.error("Action server health check failed: %s", response.status)
                except Exception as e:
                    logger.error("Action server check failed: %s", e)
                    return False
                return rasa_status and action_status
            except Exception as e:
                logger.error("Error checking server health: %s", e)
                return False

    async def initialize_session(self):
        """Prompt for session details and derive ``self.sender_id``.

        Expects ``(session_id, phone_number, alternate_number)``; the
        parentheses are optional. Only the phone number is used: if its last
        11 characters are in ``self.test_numbers`` it is replaced by a fixed
        masked number. The sender_id is ``<epoch-ms>_<via_number>_<number>``.

        Returns:
            True on success (and prints the first "User: " prompt), False if
            the input could not be parsed.
        """
        print("\n" + "="*50)
        print("Welcome to Rasa Interactive CLI!".center(50))
        print("="*50)
        # Get session tuple from user
        print("\nPlease enter your session information in this format:")
        print("(session_id, phone_number, alternate_number)")
        session_input = input("\n> ")
        try:
            session_input = session_input.strip()
            # Remove parentheses if present
            if session_input.startswith("(") and session_input.endswith(")"):
                session_input = session_input[1:-1]
            parts = [part.strip() for part in session_input.split(",")]
            if len(parts) < 3:
                logger.error("Failed to parse session information - incorrect format")
                print("\nError: Could not parse session information.")
                print("Please use format: (session_id, phone_number, alternate_number)")
                return False
            # parts[0] (session id) and parts[2] (alternate number) are accepted
            # for format compatibility but are not used
            phone_number = parts[1]
            masked_number = phone_number
            if phone_number[-11:] in self.test_numbers:
                masked_number = '01568725958'
            # Use the current time (ms) so every run gets a fresh conversation
            current_time = int(time.time() * 1000)
            self.sender_id = f"{current_time}_{self.via_number}_{masked_number}"
            logger.info("Session initialized with sender_id: %s", self.sender_id)
            print("\nSession initialized successfully.")
            print("Type your messages below. Type 'exit' to quit.")
            print("User: ", end="", flush=True)
            return True
        except Exception as e:
            logger.error("Error initializing session: %s", e)
            print(f"\nError initializing session: {e}")
            return False

    async def start_interactive_session(self):
        """Run the interactive read-send-print loop until the user exits.

        Runs a health check and asks for session details. It then repeatedly
        reads a line from stdin, sends it to Rasa, and prints the reply and the
        tracker state. The loop stops on "exit"/"quit"/"bye" (case-insensitive),
        KeyboardInterrupt, or end of input. The HTTP session is always closed
        on return.
        """
        try:
            # Check if servers are running
            is_healthy = await self.check_server_health()
            if not is_healthy:
                print(f"Error: Cannot connect to Rasa (port {self.rasa_port}) or Action (port {self.action_port}) server")
                print("Please ensure both servers are running")
                return
            # Initialize session with user information
            session_initialized = await self.initialize_session()
            if not session_initialized:
                print("Failed to initialize session. Exiting.")
                return
            # Initialize the connection once at the start
            await self.connect()
            # Main conversation loop
            while True:
                try:
                    # Blocking read; the "User: " prompt has already been printed
                    user_input = input()
                    if user_input.lower() in ["exit", "quit", "bye"]:
                        print("Exiting Rasa Interactive CLI. Goodbye!")
                        break
                    # Send message and get response with full tracker info
                    response_data = await self.send_message(user_input)
                    bot_text = self.get_bot_response_text(response_data)
                    print(f"Bot: {bot_text}")
                    self.print_tracker_state(response_data)
                    print("User: ", end="", flush=True)
                except KeyboardInterrupt:
                    print("\nExiting due to keyboard interrupt.")
                    break
                except EOFError:
                    # stdin is exhausted (Ctrl-D / end of piped input). Every
                    # later input() call would raise again, so stop here
                    # instead of looping forever.
                    print("\nEnd of input. Exiting Rasa Interactive CLI. Goodbye!")
                    break
                except aiohttp.ClientOSError as e:
                    # Defensive: send_message() normally converts client errors
                    # into a placeholder result before they reach this point.
                    if e.errno == 32:  # Broken pipe
                        logger.warning("Connection reset (broken pipe). Reconnecting...")
                        await self.connect()  # Force reconnection
                    else:
                        logger.error("Connection error: %s", e)
                        print(f"Connection error: {e}")
                    print("User: ", end="", flush=True)
                except Exception as e:
                    logger.error("Error in message loop: %s", e)
                    print(f"Error processing message: {e}")
                    print("User: ", end="", flush=True)
        except Exception as e:
            logger.error("Critical error in interactive session: %s", e)
            print(f"Critical error: {e}")
        finally:
            # Ensure we close the session
            await self.close()
async def main():
    """Parse command-line flags and run one interactive Rasa CLI session.

    Exits with status 1 if the session raises an unhandled exception. The
    client's HTTP session is closed in every case.
    """
    cli_parser = argparse.ArgumentParser(description="Interactive CLI for Rasa")
    port_flags = (
        ('--rasa-port', 5005, 'Rasa server port (default: 5005)'),
        ('--action-port', 5054, 'Action server port (default: 5054)'),
    )
    for flag, default_port, help_text in port_flags:
        cli_parser.add_argument(flag, type=int, default=default_port, help=help_text)
    cli_parser.add_argument('--debug', action='store_true', help='Enable debug output')
    opts = cli_parser.parse_args()

    # --debug lowers the root logger threshold so DEBUG records are emitted
    if opts.debug:
        logging.getLogger().setLevel(logging.DEBUG)

    client = RasaCLI(rasa_port=opts.rasa_port, action_port=opts.action_port)
    try:
        await client.start_interactive_session()
    except Exception as exc:
        logger.error(f"Unhandled exception: {exc}")
        sys.exit(1)
    finally:
        await client.close()
# Script entry point: run the async CLI. Ctrl-C is reported briefly, and any
# other fatal error goes to stderr with a non-zero exit status so shell
# scripts and CI can detect the failure.
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\nProgram terminated by user")
    except Exception as e:
        print(f"Fatal error: {str(e)}", file=sys.stderr)
        sys.exit(1)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.