#
# Copyright BitBake Contributors
#
# SPDX-License-Identifier: GPL-2.0-only
#

import abc
import asyncio
import json
import multiprocessing
import os
import signal
import socket
import sys
import traceback

from . import chunkify, DEFAULT_MAX_CHUNK


class ClientError(Exception):
    """
    Error caused by an unacceptable client request.

    Raised in this module for unrecognized commands and for nested chunk
    streams; AsyncServerConnection.process_requests() catches it, logs the
    message and closes the connection.
    """


class ServerError(Exception):
    """
    Exception type for server-side failures.

    Nothing in this module raises it; it is presumably provided for code that
    imports this module (NOTE(review): confirm against callers).
    """


class AsyncServerConnection(object):
|
||
|
def __init__(self, reader, writer, proto_name, logger):
|
||
|
self.reader = reader
|
||
|
self.writer = writer
|
||
|
self.proto_name = proto_name
|
||
|
self.max_chunk = DEFAULT_MAX_CHUNK
|
||
|
self.handlers = {
|
||
|
'chunk-stream': self.handle_chunk,
|
||
|
'ping': self.handle_ping,
|
||
|
}
|
||
|
self.logger = logger
|
||
|
|
||
|
async def process_requests(self):
|
||
|
try:
|
||
|
self.addr = self.writer.get_extra_info('peername')
|
||
|
self.logger.debug('Client %r connected' % (self.addr,))
|
||
|
|
||
|
# Read protocol and version
|
||
|
client_protocol = await self.reader.readline()
|
||
|
if client_protocol is None:
|
||
|
return
|
||
|
|
||
|
(client_proto_name, client_proto_version) = client_protocol.decode('utf-8').rstrip().split()
|
||
|
if client_proto_name != self.proto_name:
|
||
|
self.logger.debug('Rejecting invalid protocol %s' % (self.proto_name))
|
||
|
return
|
||
|
|
||
|
self.proto_version = tuple(int(v) for v in client_proto_version.split('.'))
|
||
|
if not self.validate_proto_version():
|
||
|
self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version))
|
||
|
return
|
||
|
|
||
|
# Read headers. Currently, no headers are implemented, so look for
|
||
|
# an empty line to signal the end of the headers
|
||
|
while True:
|
||
|
line = await self.reader.readline()
|
||
|
if line is None:
|
||
|
return
|
||
|
|
||
|
line = line.decode('utf-8').rstrip()
|
||
|
if not line:
|
||
|
break
|
||
|
|
||
|
# Handle messages
|
||
|
while True:
|
||
|
d = await self.read_message()
|
||
|
if d is None:
|
||
|
break
|
||
|
await self.dispatch_message(d)
|
||
|
await self.writer.drain()
|
||
|
except ClientError as e:
|
||
|
self.logger.error(str(e))
|
||
|
finally:
|
||
|
self.writer.close()
|
||
|
|
||
|
async def dispatch_message(self, msg):
|
||
|
for k in self.handlers.keys():
|
||
|
if k in msg:
|
||
|
self.logger.debug('Handling %s' % k)
|
||
|
await self.handlers[k](msg[k])
|
||
|
return
|
||
|
|
||
|
raise ClientError("Unrecognized command %r" % msg)
|
||
|
|
||
|
def write_message(self, msg):
|
||
|
for c in chunkify(json.dumps(msg), self.max_chunk):
|
||
|
self.writer.write(c.encode('utf-8'))
|
||
|
|
||
|
async def read_message(self):
|
||
|
l = await self.reader.readline()
|
||
|
if not l:
|
||
|
return None
|
||
|
|
||
|
try:
|
||
|
message = l.decode('utf-8')
|
||
|
|
||
|
if not message.endswith('\n'):
|
||
|
return None
|
||
|
|
||
|
return json.loads(message)
|
||
|
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
||
|
self.logger.error('Bad message from client: %r' % message)
|
||
|
raise e
|
||
|
|
||
|
async def handle_chunk(self, request):
|
||
|
lines = []
|
||
|
try:
|
||
|
while True:
|
||
|
l = await self.reader.readline()
|
||
|
l = l.rstrip(b"\n").decode("utf-8")
|
||
|
if not l:
|
||
|
break
|
||
|
lines.append(l)
|
||
|
|
||
|
msg = json.loads(''.join(lines))
|
||
|
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
||
|
self.logger.error('Bad message from client: %r' % lines)
|
||
|
raise e
|
||
|
|
||
|
if 'chunk-stream' in msg:
|
||
|
raise ClientError("Nested chunks are not allowed")
|
||
|
|
||
|
await self.dispatch_message(msg)
|
||
|
|
||
|
async def handle_ping(self, request):
|
||
|
response = {'alive': True}
|
||
|
self.write_message(response)
|
||
|
|
||
|
|
||
|
class AsyncServer(object):
    """
    Base class for an asyncio stream server listening on TCP or on a Unix
    domain socket.

    Usage: call start_tcp_server() or start_unix_server() to select the
    transport (this only records a start function), then serve_forever() to
    serve in the current process or serve_as_process() to serve in a child
    process. Subclasses implement accept_client().
    """

    def __init__(self, logger):
        """
        logger: logger used for debug and error messages
        """
        self._cleanup_socket = None
        self.logger = logger
        # Set by start_tcp_server() / start_unix_server(); called once the
        # event loop exists
        self.start = None
        # Filled in when the server has actually started listening
        self.address = None
        self.loop = None

    def start_tcp_server(self, host, port):
        """
        Select a TCP server on host:port. Listening starts when self.start()
        is called; self.address is then set to "host:port" ("[host]:port"
        for IPv6), taken from the first listening socket.
        """
        def start_tcp():
            self.server = self.loop.run_until_complete(
                asyncio.start_server(self.handle_client, host, port)
            )

            for s in self.server.sockets:
                self.logger.debug('Listening on %r' % (s.getsockname(),))
                # Newer python does this automatically. Do it manually here for
                # maximum compatibility
                s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
                # TCP_QUICKACK only exists on some platforms (e.g. Linux)
                if hasattr(socket, 'TCP_QUICKACK'):
                    s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)

            name = self.server.sockets[0].getsockname()
            if self.server.sockets[0].family == socket.AF_INET6:
                self.address = "[%s]:%d" % (name[0], name[1])
            else:
                self.address = "%s:%d" % (name[0], name[1])

        self.start = start_tcp

    def start_unix_server(self, path):
        """
        Select a Unix domain socket server at path. Listening starts when
        self.start() is called; self.address is then set to
        "unix://<absolute path>" and the socket file is removed again when
        _serve_forever() finishes.
        """
        def cleanup():
            # asyncio on Python 3.13+ removes the socket file itself when the
            # server is closed, so it may already be gone
            try:
                os.unlink(path)
            except FileNotFoundError:
                pass

        def start_unix():
            cwd = os.getcwd()
            try:
                # Work around path length limits in AF_UNIX. dirname() is
                # empty for a bare file name and chdir('') would fail, so
                # stay in the current directory in that case.
                os.chdir(os.path.dirname(path) or '.')
                self.server = self.loop.run_until_complete(
                    asyncio.start_unix_server(self.handle_client, os.path.basename(path))
                )
            finally:
                os.chdir(cwd)

            self.logger.debug('Listening on %r' % path)

            self._cleanup_socket = cleanup
            self.address = "unix://%s" % os.path.abspath(path)

        self.start = start_unix

    # NOTE: this class does not use abc.ABCMeta, so the decorator documents
    # the intent but does not prevent instantiation
    @abc.abstractmethod
    def accept_client(self, reader, writer):
        """
        Return the connection object for a new client. handle_client() awaits
        the returned object's process_requests().
        """
        pass

    async def handle_client(self, reader, writer):
        """
        Connection callback passed to asyncio: create the connection object
        and run it. Exceptions are logged and printed, never propagated, and
        the writer is always closed.
        """
        # writer.transport.set_write_buffer_limits(0)
        try:
            client = self.accept_client(reader, writer)
            await client.process_requests()
        except Exception as e:
            self.logger.error('Error from client: %s' % str(e), exc_info=True)
            traceback.print_exc()
        finally:
            writer.close()
        self.logger.debug('Client disconnected')

    def run_loop_forever(self):
        """
        Run the event loop until it is stopped; KeyboardInterrupt ends the
        loop silently.
        """
        try:
            self.loop.run_forever()
        except KeyboardInterrupt:
            pass

    def signal_handler(self):
        """
        SIGTERM handler: stop the event loop.
        """
        self.logger.debug("Got exit signal")
        self.loop.stop()

    def _serve_forever(self):
        """
        Install the SIGTERM handler, run the loop until it stops, then close
        the server. The Unix socket cleanup (if any) runs in every case.
        """
        try:
            self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
            # SIGTERM may have been blocked by serve_as_process(); the handler
            # is in place now, so let it through
            signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])

            self.run_loop_forever()
            self.server.close()

            self.loop.run_until_complete(self.server.wait_closed())
            self.logger.debug('Server shutting down')
        finally:
            if self._cleanup_socket is not None:
                self._cleanup_socket()

    def serve_forever(self):
        """
        Serve requests in the current process
        """
        # Create loop and override any loop that may have existed in
        # a parent process. It is possible that the use cases of
        # serve_forever might be constrained enough to allow using
        # get_event_loop here, but better safe than sorry for now.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.start()
        self._serve_forever()

    def serve_as_process(self, *, prefunc=None, args=()):
        """
        Serve requests in a child process

        prefunc: optional callable invoked in the child as
            prefunc(self, *args) after the server has started and before
            requests are served
        args: extra positional arguments for prefunc

        Returns the multiprocessing.Process object. self.address in the
        calling process is set to the address reported by the child (None if
        the child failed to start the server).
        """
        def run(queue):
            # Create loop and override any loop that may have existed
            # in a parent process. Without doing this and instead
            # using get_event_loop, at the very minimum the hashserv
            # unit tests will hang when running the second test.
            # This happens since get_event_loop in the spawned server
            # process for the second testcase ends up with the loop
            # from the hashserv client created in the unit test process
            # when running the first testcase. The problem is somewhat
            # more general, though, as any potential use of asyncio in
            # Cooker could create a loop that needs to be replaced in this
            # new process.
            self.loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.loop)
            try:
                self.start()
            finally:
                # Always report back, even on failure, so that the parent
                # does not block forever on queue.get()
                queue.put(self.address)
                queue.close()

            if prefunc is not None:
                prefunc(self, *args)

            self._serve_forever()

            if sys.version_info >= (3, 6):
                self.loop.run_until_complete(self.loop.shutdown_asyncgens())
            self.loop.close()

        queue = multiprocessing.Queue()

        # Temporarily block SIGTERM. The server process will inherit this
        # block which will ensure it doesn't receive the SIGTERM until the
        # handler is ready for it
        mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
        try:
            self.process = multiprocessing.Process(target=run, args=(queue,))
            self.process.start()

            self.address = queue.get()
            queue.close()
            queue.join_thread()

            return self.process
        finally:
            signal.pthread_sigmask(signal.SIG_SETMASK, mask)