156 lines
4.3 KiB
Python
156 lines
4.3 KiB
Python
# Copyright (C) 2018-2019 Garmin Ltd.
|
|
#
|
|
# SPDX-License-Identifier: GPL-2.0-only
|
|
#
|
|
|
|
import asyncio
|
|
from contextlib import closing
|
|
import re
|
|
import sqlite3
|
|
import itertools
|
|
import json
|
|
|
|
UNIX_PREFIX = "unix://"
|
|
|
|
ADDR_TYPE_UNIX = 0
|
|
ADDR_TYPE_TCP = 1
|
|
|
|
# The Python async server defaults to a 64K receive buffer, so we hardcode our
|
|
# maximum chunk size. It would be better if the client and server reported to
|
|
# each other what the maximum chunk sizes were, but that will slow down the
|
|
# connection setup with a round trip delay so I'd rather not do that unless it
|
|
# is necessary
|
|
DEFAULT_MAX_CHUNK = 32 * 1024
|
|
|
|
UNIHASH_TABLE_DEFINITION = (
|
|
("method", "TEXT NOT NULL", "UNIQUE"),
|
|
("taskhash", "TEXT NOT NULL", "UNIQUE"),
|
|
("unihash", "TEXT NOT NULL", ""),
|
|
)
|
|
|
|
UNIHASH_TABLE_COLUMNS = tuple(name for name, _, _ in UNIHASH_TABLE_DEFINITION)
|
|
|
|
OUTHASH_TABLE_DEFINITION = (
|
|
("method", "TEXT NOT NULL", "UNIQUE"),
|
|
("taskhash", "TEXT NOT NULL", "UNIQUE"),
|
|
("outhash", "TEXT NOT NULL", "UNIQUE"),
|
|
("created", "DATETIME", ""),
|
|
|
|
# Optional fields
|
|
("owner", "TEXT", ""),
|
|
("PN", "TEXT", ""),
|
|
("PV", "TEXT", ""),
|
|
("PR", "TEXT", ""),
|
|
("task", "TEXT", ""),
|
|
("outhash_siginfo", "TEXT", ""),
|
|
)
|
|
|
|
OUTHASH_TABLE_COLUMNS = tuple(name for name, _, _ in OUTHASH_TABLE_DEFINITION)
|
|
|
|
def _make_table(cursor, name, definition):
|
|
cursor.execute('''
|
|
CREATE TABLE IF NOT EXISTS {name} (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
{fields}
|
|
UNIQUE({unique})
|
|
)
|
|
'''.format(
|
|
name=name,
|
|
fields=" ".join("%s %s," % (name, typ) for name, typ, _ in definition),
|
|
unique=", ".join(name for name, _, flags in definition if "UNIQUE" in flags)
|
|
))
|
|
|
|
|
|
def setup_database(database, sync=True):
|
|
db = sqlite3.connect(database)
|
|
db.row_factory = sqlite3.Row
|
|
|
|
with closing(db.cursor()) as cursor:
|
|
_make_table(cursor, "unihashes_v2", UNIHASH_TABLE_DEFINITION)
|
|
_make_table(cursor, "outhashes_v2", OUTHASH_TABLE_DEFINITION)
|
|
|
|
cursor.execute('PRAGMA journal_mode = WAL')
|
|
cursor.execute('PRAGMA synchronous = %s' % ('NORMAL' if sync else 'OFF'))
|
|
|
|
# Drop old indexes
|
|
cursor.execute('DROP INDEX IF EXISTS taskhash_lookup')
|
|
cursor.execute('DROP INDEX IF EXISTS outhash_lookup')
|
|
cursor.execute('DROP INDEX IF EXISTS taskhash_lookup_v2')
|
|
cursor.execute('DROP INDEX IF EXISTS outhash_lookup_v2')
|
|
|
|
# TODO: Upgrade from tasks_v2?
|
|
cursor.execute('DROP TABLE IF EXISTS tasks_v2')
|
|
|
|
# Create new indexes
|
|
cursor.execute('CREATE INDEX IF NOT EXISTS taskhash_lookup_v3 ON unihashes_v2 (method, taskhash)')
|
|
cursor.execute('CREATE INDEX IF NOT EXISTS outhash_lookup_v3 ON outhashes_v2 (method, outhash)')
|
|
|
|
return db
|
|
|
|
|
|
def parse_address(addr):
|
|
if addr.startswith(UNIX_PREFIX):
|
|
return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],))
|
|
else:
|
|
m = re.match(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$', addr)
|
|
if m is not None:
|
|
host = m.group('host')
|
|
port = m.group('port')
|
|
else:
|
|
host, port = addr.split(':')
|
|
|
|
return (ADDR_TYPE_TCP, (host, int(port)))
|
|
|
|
|
|
def chunkify(msg, max_chunk):
|
|
if len(msg) < max_chunk - 1:
|
|
yield ''.join((msg, "\n"))
|
|
else:
|
|
yield ''.join((json.dumps({
|
|
'chunk-stream': None
|
|
}), "\n"))
|
|
|
|
args = [iter(msg)] * (max_chunk - 1)
|
|
for m in map(''.join, itertools.zip_longest(*args, fillvalue='')):
|
|
yield ''.join(itertools.chain(m, "\n"))
|
|
yield "\n"
|
|
|
|
|
|
def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
|
|
from . import server
|
|
db = setup_database(dbname, sync=sync)
|
|
s = server.Server(db, upstream=upstream, read_only=read_only)
|
|
|
|
(typ, a) = parse_address(addr)
|
|
if typ == ADDR_TYPE_UNIX:
|
|
s.start_unix_server(*a)
|
|
else:
|
|
s.start_tcp_server(*a)
|
|
|
|
return s
|
|
|
|
|
|
def create_client(addr):
|
|
from . import client
|
|
c = client.Client()
|
|
|
|
(typ, a) = parse_address(addr)
|
|
if typ == ADDR_TYPE_UNIX:
|
|
c.connect_unix(*a)
|
|
else:
|
|
c.connect_tcp(*a)
|
|
|
|
return c
|
|
|
|
async def create_async_client(addr):
|
|
from . import client
|
|
c = client.AsyncClient()
|
|
|
|
(typ, a) = parse_address(addr)
|
|
if typ == ADDR_TYPE_UNIX:
|
|
await c.connect_unix(*a)
|
|
else:
|
|
await c.connect_tcp(*a)
|
|
|
|
return c
|