Created
March 11, 2018 22:59
-
-
Save zzeleznick/8e757ac8a211744856296b50a8313d7e to your computer and use it in GitHub Desktop.
Unix Socket Server
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import click | |
# Add this to avoid the annoying warning: http://click.pocoo.org/5/python3/ | |
click.disable_unicode_literals_warning = True | |
import sys | |
from select import select | |
# Internal modules | |
from sockets import SocketServer, ThreadedSocketServer, SocketClient | |
def main(): | |
@click.group(help="Test sockets") | |
@click.pass_context | |
def cli(ctx): | |
pass | |
@cli.command(help="Launch Server") | |
def server(): | |
sock_server = SocketServer() | |
while True: | |
try: | |
sock_server.read() | |
except Exception as e: | |
print e | |
break | |
del sock_server | |
@cli.command(help="Launch Threaded Server") | |
def threaded_server(): | |
sock_server = ThreadedSocketServer() | |
sock_server.listen_for_connections() | |
@cli.command(help="Launch Client") | |
@click.argument('cmd_args', nargs=-1) | |
def client(cmd_args): | |
if cmd_args: | |
click.echo("args: {}".format(cmd_args)) | |
sock_client = SocketClient() | |
sock_client.launch_reader_thread() | |
if cmd_args: | |
msg = " ".join(cmd_args) | |
print "SENDING:", msg | |
sock_client.write(msg) | |
print "Shutting down." | |
return | |
timeout = 5 | |
click.echo("> ", nl=False) | |
while True: | |
try: | |
# Enable exit triggered by server shutdown at max delay 5s | |
rlist, _, _ = select([sys.stdin], [], [], timeout) | |
if rlist: | |
s = sys.stdin.readline().strip() | |
if not s: | |
click.echo("> ", nl=False) | |
continue | |
print "SENDING:", s | |
sock_client.write(s) | |
if "DONE" == s: | |
break | |
click.echo("> ", nl=False) | |
except KeyboardInterrupt: | |
print "Shutting down." | |
sock_client.close() | |
break | |
else: | |
print "Couldn't Connect!" | |
print "Done" | |
cli(obj={}, standalone_mode=False) | |
try: | |
main() | |
except click.exceptions.Abort: | |
print 'Aborted' |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
click==6.7
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from abc import ABCMeta, abstractmethod | |
import os | |
import signal | |
import socket | |
import sys | |
import threading | |
import time | |
from itertools import count | |
from collections import OrderedDict | |
from thread import interrupt_main | |
from Queue import Queue | |
from os.path import abspath, dirname, join | |
class Socket(object):
    """Abstract base for this module's Unix-domain socket endpoints.

    Stores the socket file path and installs :meth:`sig_handler` as the
    process-wide SIGINT handler, so Ctrl-C tears the endpoint down before
    ``KeyboardInterrupt`` is raised.

    NOTE: ``signal.signal`` may only be called from the main thread, so
    instances must be created there.  The handler is a bound method, so the
    signal module keeps the most recently created instance alive.
    """
    __metaclass__ = ABCMeta

    # Default socket path: ``test.socket`` in the directory of this file.
    DEFAULT_NAME = join(abspath(dirname(__file__)), 'test.socket')

    def __init__(self, name=""):
        """Use ``name`` as the socket path (DEFAULT_NAME if empty); hook SIGINT."""
        self.socket_file = name or self.DEFAULT_NAME
        signal.signal(signal.SIGINT, self.sig_handler)

    @abstractmethod
    def read(self):
        """Receive data (subclasses choose the exact signature)."""
        pass

    @abstractmethod
    def close(self):
        """Release the sockets held by this endpoint."""
        pass

    @abstractmethod
    def write(self):
        """Send data (subclasses add the payload / connection parameters)."""
        pass

    def tearDown(self):
        """Cleanup hook used by :meth:`sig_handler` and :meth:`__del__`."""
        self.close()

    def _remove_socket_file(self):
        """Delete the socket file; a file that is already gone is not an error."""
        # EAFP instead of exists() + remove(): another process or cleanup path
        # may delete the file between the check and the removal.
        try:
            os.remove(self.socket_file)
        except OSError:
            if os.path.exists(self.socket_file):
                raise

    def sig_handler(self, sig, frame):
        """SIGINT handler: tear the endpoint down, then raise KeyboardInterrupt."""
        print >>sys.stderr, "** Signal Received **"
        self.tearDown()
        raise KeyboardInterrupt()

    def __del__(self):
        print >>sys.stderr, "=== Delete Called ==="
        try:
            self.tearDown()
        except Exception as e:
            # Best effort: exceptions cannot propagate out of __del__, and a
            # subclass whose __init__ failed part-way may lack attributes
            # that its close() uses.
            print >>sys.stderr, e
        print >>sys.stderr, "=== Delete Finished ==="
class SocketServer(Socket):
    """Single-client Unix-domain stream server.

    The listening socket is bound in ``__init__``.  The first :meth:`read`
    or :meth:`write` blocks in ``accept()`` until a client connects.  When
    that client disconnects, the connection is dropped and the next
    ``read``/``write`` waits for a new client.
    """

    def __init__(self, name=""):
        # Initialise state before anything can fail, so close() (reached via
        # the SIGINT handler or __del__) works even if socket()/bind() raise.
        self.server = None
        self.connection = None
        self._listening = False  # True while a client connection is open
        super(SocketServer, self).__init__(name)
        # A stale socket file from a previous run would make bind() fail.
        self._remove_socket_file()
        print >>sys.stderr, "=== Opening socket ==="
        self.server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.server.bind(self.socket_file)

    def listen(self):
        """Block until one client connects; store it as ``self.connection``."""
        print >>sys.stderr, "=== Listening ==="
        # Listen for incoming connections
        self.server.listen(1)
        print >>sys.stderr, "=== Waiting for a Connection ==="
        self.connection, _ = self.server.accept()
        print >>sys.stderr, '*** Accepted connection ***'
        self._listening = True

    def read(self):
        """Return up to 1024 bytes from the client, or '' once it disconnects.

        Accepts a client first if none is connected.  On end of stream the
        connection is closed, so the next call waits for a new client
        instead of spinning on a dead socket.
        """
        if not self._listening:
            self.listen()
        print >>sys.stderr, "=== Waiting for input === "
        msg = self.connection.recv(1024)
        print >>sys.stderr, 'Received', msg
        if not msg:
            print >>sys.stderr, 'Empty message'
            # On a stream socket recv() returns '' only after the peer has
            # closed; every further recv() would return '' immediately.
            self._drop_connection()
        return msg

    def _drop_connection(self):
        """Close the current client connection (if any) and mark it gone."""
        if self.connection:
            self.connection.close()
        self.connection = None
        self._listening = False

    def close(self):
        """Close the client and listening sockets and remove the socket file."""
        print >>sys.stderr, "=== Cleaning up the connection ==="
        self._drop_connection()
        if self.server:
            self.server.close()
            # Don't leave a stale socket file behind for the next run.
            self._remove_socket_file()

    def write(self, res):
        """Send ``str(res)`` to the client, accepting one first if needed.

        Errors are printed to stderr and swallowed (best effort).
        """
        try:
            if not self._listening:
                self.listen()
            # send() may transmit only part of the data; sendall() loops.
            self.connection.sendall("{}".format(res))
        except Exception as e:
            print >>sys.stderr, e
class ThreadedSocketServer(Socket): | |
def __init__(self, name=""): | |
super(ThreadedSocketServer, self).__init__(name) | |
self._remove_socket_file() | |
print >>sys.stderr, "=== Opening socket ===" | |
self.server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) | |
self.server.bind(self.socket_file) | |
self.counter = count(1) | |
self.connections = OrderedDict() | |
self.threads = OrderedDict() | |
self._listening = False | |
self.queue = Queue() | |
def read(self, connection, conn_idx=None): | |
"""A threaded read""" | |
current_thread = threading.current_thread() | |
retries = 2 | |
ctr = count(1) | |
while self._listening: | |
print >>sys.stderr, "=== {} Waiting for input === ".format(current_thread) | |
msg = connection.recv(1024) | |
print >>sys.stderr, '** {} Received {} **'.format(current_thread, msg) | |
if not msg: | |
print >>sys.stderr, '** {} Empty message **'.format(current_thread) | |
if retries == 0: # Client did not send connection close | |
print >>sys.stderr, '** Closing {} due to empty messages **'.format(current_thread) | |
break | |
retries -= 1 | |
else: | |
retries = 2 | |
if 'echo' in msg: | |
print >>sys.stderr, '** {} Writing back message {} **'.format(current_thread, msg) | |
self.write(connection, msg) | |
elif 'task' in msg: | |
task = " ".join(msg.split('task')).strip() | |
task_id = "{}-{}".format(conn_idx, next(ctr)) | |
resp = "Added task ({}): '{}' - Qsize: {}".format(task_id, task, self.queue.qsize()) | |
print >>sys.stderr, '** {} {} **'.format(current_thread, resp) | |
self.queue.put((task_id, task)) | |
self.write(connection, resp) | |
print "** {} exits **".format(current_thread) | |
if conn_idx in self.connections: | |
print "** Removing connection {} **".format(conn_idx) | |
self.connections[conn_idx].close() | |
del self.connections[conn_idx] | |
if conn_idx in self.threads: | |
# Cannot join current thread (https://github.com/python/cpython/blob/2.7/Lib/threading.py#L931) | |
# Just remove reference to it | |
print "** Releasing thread {} **".format(conn_idx) | |
del self.threads[conn_idx] | |
def process_queue(self): | |
print >>sys.stderr, "=== Starting process_queue ===" | |
while self._listening: | |
task_id, task = self.queue.get() | |
print >>sys.stderr, '** Executing task ({}): "{}" **'.format(task_id, task) | |
time.sleep(3) # add fake delay | |
self.queue.task_done() | |
conn_idx = int(task_id.split("-")[0]) | |
resp = "Completed task ({}): '{}'".format(task_id, task) | |
if conn_idx in self.connections: | |
connection = self.connections[conn_idx] | |
self.write(connection, resp) | |
else: | |
# Possibly client disconnected while we were working | |
print >>sys.stderr, "Could not send client {} done status".format(conn_idx) | |
print >>sys.stderr, "=== End of process_queue ===" | |
def do_work(self): | |
thread = threading.Thread(target=self.process_queue) | |
thread.daemon = True | |
thread.start() | |
def listen_for_connections(self): | |
print >>sys.stderr, "=== Listening ===" | |
self._listening = True | |
self.do_work() | |
while self._listening: | |
try: | |
self.server.listen(1) | |
print >>sys.stderr, "=== Waiting for a Connection ===" | |
connection, _ = self.server.accept() | |
conn_idx = next(self.counter) | |
print >>sys.stderr, '*** Accepted connection {} ***'.format(conn_idx) | |
self.connections[conn_idx] = connection | |
thread = threading.Thread(target=self.read, args=(connection, conn_idx,)) | |
self.threads[conn_idx] = thread | |
thread.daemon = True | |
thread.start() | |
except Exception as e: | |
print >>sys.stderr, e | |
break | |
self._listening = False | |
print >>sys.stderr, "=== End of listen_for_connections ===" | |
def close(self): | |
print >>sys.stderr, "=== Cleaning up the connection ===" | |
print self.connections | |
for (idx, connection) in self.connections.iteritems(): | |
print >>sys.stderr, "** Closing connecton {} **".format(idx) | |
connection.close() | |
print self.threads | |
for (idx, thread) in self.threads.iteritems(): | |
thread.join(1) | |
self._remove_socket_file() | |
print >>sys.stderr, "=== Finished close ===" | |
def tearDown(self): | |
self.close() | |
self._remove_socket_file() | |
def write(self, connection, res): | |
try: | |
connection.send("{}".format(res)) | |
except Exception as e: | |
print >>sys.stderr, e | |
class SocketClient(Socket):
    """Unix-domain stream client with an optional background reader thread.

    ``__init__`` connects immediately, so it raises ``socket.error`` when no
    server is listening on the socket file.  The reader thread started by
    :meth:`launch_reader_thread` logs incoming data.  If the server ends the
    connection, the thread calls ``interrupt_main()``.
    """

    def __init__(self, name=""):
        # Set state first so close() (via the SIGINT handler or __del__) is
        # safe even if socket creation or connect() fails.
        self._listening = False
        self.reader_thread = None
        self.client = None
        super(SocketClient, self).__init__(name)
        self.client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.client.connect(self.socket_file)
        self._listening = True

    def read_loop(self):
        """Reader-thread body: log server messages until the stream ends.

        The main thread is interrupted only when the loop ended while still
        listening (i.e. the server side went away).  A local :meth:`close`
        clears ``_listening`` first, so it ends the loop quietly.
        """
        while self._listening:
            try:
                msg = self.read()
            except socket.error as e:
                # e.g. connection reset by the server.
                print >>sys.stderr, e
                break
            if not msg:
                # '' from recv() means end of stream (server closed, or our
                # own close() shut the socket down); retrying won't help.
                print >>sys.stderr, 'Empty message'
                print >>sys.stderr, 'Closing reader thread'
                break
        if self._listening:
            print >>sys.stderr, 'read_loop completed -- server exited?'
            interrupt_main()

    def read(self):
        """Block for up to 1024 bytes from the server ('' at end of stream)."""
        print >>sys.stderr, 'Waiting for input'
        msg = self.client.recv(1024)
        print >>sys.stderr, '** Received', msg, '**'
        return msg

    def launch_reader_thread(self):
        """Run :meth:`read_loop` on a daemon thread."""
        print >>sys.stderr, 'Launching reader_thread'
        if not self.client:
            raise Exception("Not connected")
        self.reader_thread = threading.Thread(target=self.read_loop)
        self.reader_thread.daemon = True
        self.reader_thread.start()

    def write(self, res):
        """Send ``str(res)`` to the server; errors are printed and re-raised."""
        try:
            # send() may transmit only part of the data; sendall() loops.
            self.client.sendall("{}".format(res))
        except Exception as e:
            print >>sys.stderr, e
            raise

    def close(self):
        """Stop the reader thread (waiting up to 1 s) and close the socket."""
        self._listening = False
        if self.client:
            try:
                # Wake a reader blocked in recv(): after shutdown() it returns
                # '' at once instead of waiting out the 1 s join below.
                self.client.shutdown(socket.SHUT_RDWR)
            except socket.error:
                pass  # never connected, or already shut down / closed
        if self.reader_thread:
            self.reader_thread.join(1)
        if self.client:
            self.client.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment