diff options
Diffstat (limited to 'lib/prserv')
-rw-r--r-- | lib/prserv/__init__.py | 12 | ||||
-rw-r--r-- | lib/prserv/client.py | 71 | ||||
-rw-r--r-- | lib/prserv/db.py | 211 | ||||
-rw-r--r-- | lib/prserv/serv.py | 616 |
4 files changed, 480 insertions, 430 deletions
diff --git a/lib/prserv/__init__.py b/lib/prserv/__init__.py index c3cb73ad9..0e0aa34d0 100644 --- a/lib/prserv/__init__.py +++ b/lib/prserv/__init__.py @@ -1,13 +1,19 @@ +# +# Copyright BitBake Contributors +# +# SPDX-License-Identifier: GPL-2.0-only +# + __version__ = "1.0.0" import os, time -import sys,logging +import sys, logging def init_logger(logfile, loglevel): numeric_level = getattr(logging, loglevel.upper(), None) if not isinstance(numeric_level, int): - raise ValueError('Invalid log level: %s' % loglevel) - FORMAT = '%(asctime)-15s %(message)s' + raise ValueError("Invalid log level: %s" % loglevel) + FORMAT = "%(asctime)-15s %(message)s" logging.basicConfig(level=numeric_level, filename=logfile, format=FORMAT) class NotFoundError(Exception): diff --git a/lib/prserv/client.py b/lib/prserv/client.py new file mode 100644 index 000000000..8471ee304 --- /dev/null +++ b/lib/prserv/client.py @@ -0,0 +1,71 @@ +# +# Copyright BitBake Contributors +# +# SPDX-License-Identifier: GPL-2.0-only +# + +import logging +import bb.asyncrpc + +logger = logging.getLogger("BitBake.PRserv") + +class PRAsyncClient(bb.asyncrpc.AsyncClient): + def __init__(self): + super().__init__("PRSERVICE", "1.0", logger) + + async def getPR(self, version, pkgarch, checksum): + response = await self.invoke( + {"get-pr": {"version": version, "pkgarch": pkgarch, "checksum": checksum}} + ) + if response: + return response["value"] + + async def test_pr(self, version, pkgarch, checksum): + response = await self.invoke( + {"test-pr": {"version": version, "pkgarch": pkgarch, "checksum": checksum}} + ) + if response: + return response["value"] + + async def test_package(self, version, pkgarch): + response = await self.invoke( + {"test-package": {"version": version, "pkgarch": pkgarch}} + ) + if response: + return response["value"] + + async def max_package_pr(self, version, pkgarch): + response = await self.invoke( + {"max-package-pr": {"version": version, "pkgarch": pkgarch}} + ) + if response: + return response["value"] + + async def importone(self, version, pkgarch, checksum, value): + response = await self.invoke( + {"import-one": {"version": version, "pkgarch": pkgarch, "checksum": checksum, "value": value}} + ) + if response: + return response["value"] + + async def export(self, version, pkgarch, checksum, colinfo): + response = await self.invoke( + {"export": {"version": version, "pkgarch": pkgarch, "checksum": checksum, "colinfo": colinfo}} + ) + if response: + return (response["metainfo"], response["datainfo"]) + + async def is_readonly(self): + response = await self.invoke( + {"is-readonly": {}} + ) + if response: + return response["readonly"] + +class PRClient(bb.asyncrpc.Client): + def __init__(self): + super().__init__() + self._add_methods("getPR", "test_pr", "test_package", "importone", "export", "is_readonly") + + def _get_async_client(self): + return PRAsyncClient() diff --git a/lib/prserv/db.py b/lib/prserv/db.py index 495d09f39..eb4150819 100644 --- a/lib/prserv/db.py +++ b/lib/prserv/db.py @@ -1,3 +1,9 @@ +# +# Copyright BitBake Contributors +# +# SPDX-License-Identifier: GPL-2.0-only +# + import logging import os.path import errno @@ -26,21 +32,29 @@ if sqlversion[0] < 3 or (sqlversion[0] == 3 and sqlversion[1] < 3): # class PRTable(object): - def __init__(self, conn, table, nohist): + def __init__(self, conn, table, nohist, read_only): self.conn = conn self.nohist = nohist + self.read_only = read_only self.dirty = False if nohist: - self.table = "%s_nohist" % table + self.table = "%s_nohist" % table else: - self.table = "%s_hist" % table + self.table = "%s_hist" % table - self._execute("CREATE TABLE IF NOT EXISTS %s \ - (version TEXT NOT NULL, \ - pkgarch TEXT NOT NULL, \ - checksum TEXT NOT NULL, \ - value INTEGER, \ - PRIMARY KEY (version, pkgarch, checksum));" % self.table) + if self.read_only: + table_exists = self._execute( + "SELECT count(*) FROM sqlite_master \ + WHERE type='table' AND name='%s'" % (self.table)) + if not table_exists: + raise prserv.NotFoundError + else: + self._execute("CREATE TABLE IF NOT EXISTS %s \ + (version TEXT NOT NULL, \ + pkgarch TEXT NOT NULL, \ + checksum TEXT NOT NULL, \ + value INTEGER, \ + PRIMARY KEY (version, pkgarch, checksum));" % self.table) def _execute(self, *query): """Execute a query, waiting to acquire a lock if necessary""" @@ -50,31 +64,87 @@ class PRTable(object): try: return self.conn.execute(*query) except sqlite3.OperationalError as exc: - if 'is locked' in str(exc) and end > time.time(): + if "is locked" in str(exc) and end > time.time(): continue raise exc def sync(self): - self.conn.commit() - self._execute("BEGIN EXCLUSIVE TRANSACTION") + if not self.read_only: + self.conn.commit() + self._execute("BEGIN EXCLUSIVE TRANSACTION") def sync_if_dirty(self): if self.dirty: self.sync() self.dirty = False - def _getValueHist(self, version, pkgarch, checksum): + def test_package(self, version, pkgarch): + """Returns whether the specified package version is found in the database for the specified architecture""" + + # Just returns the value if found or None otherwise + data=self._execute("SELECT value FROM %s WHERE version=? AND pkgarch=?;" % self.table, + (version, pkgarch)) + row=data.fetchone() + if row is not None: + return True + else: + return False + + def test_value(self, version, pkgarch, value): + """Returns whether the specified value is found in the database for the specified package and architecture""" + + # Just returns the value if found or None otherwise + data=self._execute("SELECT value FROM %s WHERE version=? AND pkgarch=? and value=?;" % self.table, + (version, pkgarch, value)) + row=data.fetchone() + if row is not None: + return True + else: + return False + + def find_value(self, version, pkgarch, checksum): + """Returns the value for the specified checksum if found or None otherwise.""" + data=self._execute("SELECT value FROM %s WHERE version=? AND pkgarch=? AND checksum=?;" % self.table, (version, pkgarch, checksum)) row=data.fetchone() - if row != None: + if row is not None: + return row[0] + else: + return None + + def find_max_value(self, version, pkgarch): + """Returns the greatest value for (version, pkgarch), or None if not found. Doesn't create a new value""" + + data = self._execute("SELECT max(value) FROM %s where version=? AND pkgarch=?;" % (self.table), + (version, pkgarch)) + row = data.fetchone() + if row is not None: + return row[0] + else: + return None + + def _get_value_hist(self, version, pkgarch, checksum): + data=self._execute("SELECT value FROM %s WHERE version=? AND pkgarch=? AND checksum=?;" % self.table, + (version, pkgarch, checksum)) + row=data.fetchone() + if row is not None: return row[0] else: #no value found, try to insert + if self.read_only: + data = self._execute("SELECT ifnull(max(value)+1, 0) FROM %s where version=? AND pkgarch=?;" % (self.table), + (version, pkgarch)) + row = data.fetchone() + if row is not None: + return row[0] + else: + return 0 + try: - self._execute("INSERT INTO %s VALUES (?, ?, ?, (select ifnull(max(value)+1,0) from %s where version=? AND pkgarch=?));" - % (self.table,self.table), - (version,pkgarch, checksum,version, pkgarch)) + self._execute("INSERT INTO %s VALUES (?, ?, ?, (select ifnull(max(value)+1, 0) from %s where version=? AND pkgarch=?));" + % (self.table, self.table), + (version, pkgarch, checksum, version, pkgarch)) except sqlite3.IntegrityError as exc: logger.error(str(exc)) @@ -83,25 +153,30 @@ class PRTable(object): data=self._execute("SELECT value FROM %s WHERE version=? AND pkgarch=? AND checksum=?;" % self.table, (version, pkgarch, checksum)) row=data.fetchone() - if row != None: + if row is not None: return row[0] else: raise prserv.NotFoundError - def _getValueNohist(self, version, pkgarch, checksum): + def _get_value_no_hist(self, version, pkgarch, checksum): data=self._execute("SELECT value FROM %s \ WHERE version=? AND pkgarch=? AND checksum=? AND \ - value >= (select max(value) from %s where version=? AND pkgarch=?);" + value >= (select max(value) from %s where version=? AND pkgarch=?);" % (self.table, self.table), (version, pkgarch, checksum, version, pkgarch)) row=data.fetchone() - if row != None: + if row is not None: return row[0] else: #no value found, try to insert + if self.read_only: + data = self._execute("SELECT ifnull(max(value)+1, 0) FROM %s where version=? AND pkgarch=?;" % (self.table), + (version, pkgarch)) + return data.fetchone()[0] + try: - self._execute("INSERT OR REPLACE INTO %s VALUES (?, ?, ?, (select ifnull(max(value)+1,0) from %s where version=? AND pkgarch=?));" - % (self.table,self.table), + self._execute("INSERT OR REPLACE INTO %s VALUES (?, ?, ?, (select ifnull(max(value)+1, 0) from %s where version=? AND pkgarch=?));" + % (self.table, self.table), (version, pkgarch, checksum, version, pkgarch)) except sqlite3.IntegrityError as exc: logger.error(str(exc)) @@ -112,23 +187,26 @@ class PRTable(object): data=self._execute("SELECT value FROM %s WHERE version=? AND pkgarch=? AND checksum=?;" % self.table, (version, pkgarch, checksum)) row=data.fetchone() - if row != None: + if row is not None: return row[0] else: raise prserv.NotFoundError - def getValue(self, version, pkgarch, checksum): + def get_value(self, version, pkgarch, checksum): if self.nohist: - return self._getValueNohist(version, pkgarch, checksum) + return self._get_value_no_hist(version, pkgarch, checksum) else: - return self._getValueHist(version, pkgarch, checksum) + return self._get_value_hist(version, pkgarch, checksum) + + def _import_hist(self, version, pkgarch, checksum, value): + if self.read_only: + return None - def _importHist(self, version, pkgarch, checksum, value): - val = None + val = None data = self._execute("SELECT value FROM %s WHERE version=? AND pkgarch=? AND checksum=?;" % self.table, (version, pkgarch, checksum)) row = data.fetchone() - if row != None: + if row is not None: val=row[0] else: #no value found, try to insert @@ -143,63 +221,66 @@ class PRTable(object): data = self._execute("SELECT value FROM %s WHERE version=? AND pkgarch=? AND checksum=?;" % self.table, (version, pkgarch, checksum)) row = data.fetchone() - if row != None: + if row is not None: val = row[0] return val - def _importNohist(self, version, pkgarch, checksum, value): + def _import_no_hist(self, version, pkgarch, checksum, value): + if self.read_only: + return None + try: #try to insert self._execute("INSERT INTO %s VALUES (?, ?, ?, ?);" % (self.table), - (version, pkgarch, checksum,value)) + (version, pkgarch, checksum, value)) except sqlite3.IntegrityError as exc: #already have the record, try to update try: - self._execute("UPDATE %s SET value=? WHERE version=? AND pkgarch=? AND checksum=? AND value<?" + self._execute("UPDATE %s SET value=? WHERE version=? AND pkgarch=? AND checksum=? AND value<?" % (self.table), - (value,version,pkgarch,checksum,value)) + (value, version, pkgarch, checksum, value)) except sqlite3.IntegrityError as exc: logger.error(str(exc)) self.dirty = True data = self._execute("SELECT value FROM %s WHERE version=? AND pkgarch=? AND checksum=? AND value>=?;" % self.table, - (version,pkgarch,checksum,value)) + (version, pkgarch, checksum, value)) row=data.fetchone() - if row != None: + if row is not None: return row[0] else: return None def importone(self, version, pkgarch, checksum, value): if self.nohist: - return self._importNohist(version, pkgarch, checksum, value) + return self._import_no_hist(version, pkgarch, checksum, value) else: - return self._importHist(version, pkgarch, checksum, value) + return self._import_hist(version, pkgarch, checksum, value) def export(self, version, pkgarch, checksum, colinfo): metainfo = {} - #column info + #column info if colinfo: - metainfo['tbl_name'] = self.table - metainfo['core_ver'] = prserv.__version__ - metainfo['col_info'] = [] + metainfo["tbl_name"] = self.table + metainfo["core_ver"] = prserv.__version__ + metainfo["col_info"] = [] data = self._execute("PRAGMA table_info(%s);" % self.table) for row in data: col = {} - col['name'] = row['name'] - col['type'] = row['type'] - col['notnull'] = row['notnull'] - col['dflt_value'] = row['dflt_value'] - col['pk'] = row['pk'] - metainfo['col_info'].append(col) + col["name"] = row["name"] + col["type"] = row["type"] + col["notnull"] = row["notnull"] + col["dflt_value"] = row["dflt_value"] + col["pk"] = row["pk"] + metainfo["col_info"].append(col) #data info datainfo = [] if self.nohist: sqlstmt = "SELECT T1.version, T1.pkgarch, T1.checksum, T1.value FROM %s as T1, \ - (SELECT version,pkgarch,max(value) as maxvalue FROM %s GROUP BY version,pkgarch) as T2 \ + (SELECT version, pkgarch, max(value) as maxvalue FROM %s GROUP BY version, pkgarch) as T2 \ WHERE T1.version=T2.version AND T1.pkgarch=T2.pkgarch AND T1.value=T2.maxvalue " % (self.table, self.table) else: sqlstmt = "SELECT * FROM %s as T1 WHERE 1=1 " % self.table @@ -222,12 +303,12 @@ class PRTable(object): else: data = self._execute(sqlstmt) for row in data: - if row['version']: + if row["version"]: col = {} - col['version'] = row['version'] - col['pkgarch'] = row['pkgarch'] - col['checksum'] = row['checksum'] - col['value'] = row['value'] + col["version"] = row["version"] + col["pkgarch"] = row["pkgarch"] + col["checksum"] = row["checksum"] + col["value"] = row["value"] datainfo.append(col) return (metainfo, datainfo) @@ -236,41 +317,45 @@ class PRTable(object): for line in self.conn.iterdump(): writeCount = writeCount + len(line) + 1 fd.write(line) - fd.write('\n') + fd.write("\n") return writeCount class PRData(object): """Object representing the PR database""" - def __init__(self, filename, nohist=True): + def __init__(self, filename, nohist=True, read_only=False): self.filename=os.path.abspath(filename) self.nohist=nohist + self.read_only = read_only #build directory hierarchy try: os.makedirs(os.path.dirname(self.filename)) except OSError as e: if e.errno != errno.EEXIST: raise e - self.connection=sqlite3.connect(self.filename, isolation_level="EXCLUSIVE", check_same_thread = False) + uri = "file:%s%s" % (self.filename, "?mode=ro" if self.read_only else "") + logger.debug("Opening PRServ database '%s'" % (uri)) + self.connection=sqlite3.connect(uri, uri=True, isolation_level="EXCLUSIVE", check_same_thread = False) self.connection.row_factory=sqlite3.Row - self.connection.execute("pragma synchronous = off;") - self.connection.execute("PRAGMA journal_mode = WAL;") + if not self.read_only: + self.connection.execute("pragma synchronous = off;") + self.connection.execute("PRAGMA journal_mode = MEMORY;") self._tables={} def disconnect(self): self.connection.close() - def __getitem__(self,tblname): + def __getitem__(self, tblname): if not isinstance(tblname, str): raise TypeError("tblname argument must be a string, not '%s'" % type(tblname)) if tblname in self._tables: return self._tables[tblname] else: - tableobj = self._tables[tblname] = PRTable(self.connection, tblname, self.nohist) + tableobj = self._tables[tblname] = PRTable(self.connection, tblname, self.nohist, self.read_only) return tableobj def __delitem__(self, tblname): if tblname in self._tables: del self._tables[tblname] logger.info("drop table %s" % (tblname)) - self.connection.execute("DROP TABLE IF EXISTS %s;" % tblname) + self.connection.execute("DROP TABLE IF EXISTS %s;" % tblname) diff --git a/lib/prserv/serv.py b/lib/prserv/serv.py index 6a99728c4..dc4be5b62 100644 --- a/lib/prserv/serv.py +++ b/lib/prserv/serv.py @@ -1,356 +1,244 @@ +# +# Copyright BitBake Contributors +# +# SPDX-License-Identifier: GPL-2.0-only +# + import os,sys,logging import signal, time -from xmlrpc.server import SimpleXMLRPCServer, SimpleXMLRPCRequestHandler -import threading -import queue import socket import io import sqlite3 -import bb.server.xmlrpcclient import prserv import prserv.db import errno -import select +import bb.asyncrpc logger = logging.getLogger("BitBake.PRserv") -if sys.hexversion < 0x020600F0: - print("Sorry, python 2.6 or later is required.") - sys.exit(1) +PIDPREFIX = "/tmp/PRServer_%s_%s.pid" +singleton = None -class Handler(SimpleXMLRPCRequestHandler): - def _dispatch(self,method,params): +class PRServerClient(bb.asyncrpc.AsyncServerConnection): + def __init__(self, socket, server): + super().__init__(socket, "PRSERVICE", server.logger) + self.server = server + + self.handlers.update({ + "get-pr": self.handle_get_pr, + "test-pr": self.handle_test_pr, + "test-package": self.handle_test_package, + "max-package-pr": self.handle_max_package_pr, + "import-one": self.handle_import_one, + "export": self.handle_export, + "is-readonly": self.handle_is_readonly, + }) + + def validate_proto_version(self): + return (self.proto_version == (1, 0)) + + async def dispatch_message(self, msg): try: - value=self.server.funcs[method](*params) + return await super().dispatch_message(msg) except: - import traceback - traceback.print_exc() + self.server.table.sync() raise - return value + else: + self.server.table.sync_if_dirty() -PIDPREFIX = "/tmp/PRServer_%s_%s.pid" -singleton = None + async def handle_test_pr(self, request): + '''Finds the PR value corresponding to the request. If not found, returns None and doesn't insert a new value''' + version = request["version"] + pkgarch = request["pkgarch"] + checksum = request["checksum"] + value = self.server.table.find_value(version, pkgarch, checksum) + return {"value": value} -class PRServer(SimpleXMLRPCServer): - def __init__(self, dbfile, logfile, interface, daemon=True): - ''' constructor ''' - try: - SimpleXMLRPCServer.__init__(self, interface, - logRequests=False, allow_none=True) - except socket.error: - ip=socket.gethostbyname(interface[0]) - port=interface[1] - msg="PR Server unable to bind to %s:%s\n" % (ip, port) - sys.stderr.write(msg) - raise PRServiceConfigError + async def handle_test_package(self, request): + '''Tells whether there are entries for (version, pkgarch) in the db. Returns True or False''' + version = request["version"] + pkgarch = request["pkgarch"] - self.dbfile=dbfile - self.daemon=daemon - self.logfile=logfile - self.working_thread=None - self.host, self.port = self.socket.getsockname() - self.pidfile=PIDPREFIX % (self.host, self.port) - - self.register_function(self.getPR, "getPR") - self.register_function(self.quit, "quit") - self.register_function(self.ping, "ping") - self.register_function(self.export, "export") - self.register_function(self.dump_db, "dump_db") - self.register_function(self.importone, "importone") - self.register_introspection_functions() - - self.quitpipein, self.quitpipeout = os.pipe() - - self.requestqueue = queue.Queue() - self.handlerthread = threading.Thread(target = self.process_request_thread) - self.handlerthread.daemon = False - - def process_request_thread(self): - """Same as in BaseServer but as a thread. - - In addition, exception handling is done here. - - """ - iter_count = 1 - # 60 iterations between syncs or sync if dirty every ~30 seconds - iterations_between_sync = 60 - - bb.utils.set_process_name("PRServ Handler") - - while not self.quitflag: - try: - (request, client_address) = self.requestqueue.get(True, 30) - except queue.Empty: - self.table.sync_if_dirty() - continue - if request is None: - continue - try: - self.finish_request(request, client_address) - self.shutdown_request(request) - iter_count = (iter_count + 1) % iterations_between_sync - if iter_count == 0: - self.table.sync_if_dirty() - except: - self.handle_error(request, client_address) - self.shutdown_request(request) - self.table.sync() - self.table.sync_if_dirty() - - def sigint_handler(self, signum, stack): - if self.table: - self.table.sync() + value = self.server.table.test_package(version, pkgarch) + return {"value": value} - def sigterm_handler(self, signum, stack): - if self.table: - self.table.sync() - self.quit() - self.requestqueue.put((None, None)) + async def handle_max_package_pr(self, request): + '''Finds the greatest PR value for (version, pkgarch) in the db. Returns None if no entry was found''' + version = request["version"] + pkgarch = request["pkgarch"] - def process_request(self, request, client_address): - self.requestqueue.put((request, client_address)) + value = self.server.table.find_max_value(version, pkgarch) + return {"value": value} - def export(self, version=None, pkgarch=None, checksum=None, colinfo=True): - try: - return self.table.export(version, pkgarch, checksum, colinfo) - except sqlite3.Error as exc: - logger.error(str(exc)) - return None - - def dump_db(self): - """ - Returns a script (string) that reconstructs the state of the - entire database at the time this function is called. The script - language is defined by the backing database engine, which is a - function of server configuration. - Returns None if the database engine does not support dumping to - script or if some other error is encountered in processing. - """ - buff = io.StringIO() + async def handle_get_pr(self, request): + version = request["version"] + pkgarch = request["pkgarch"] + checksum = request["checksum"] + + response = None try: - self.table.sync() - self.table.dump_db(buff) - return buff.getvalue() - except Exception as exc: - logger.error(str(exc)) - return None - finally: - buff.close() + value = self.server.table.get_value(version, pkgarch, checksum) + response = {"value": value} + except prserv.NotFoundError: + self.logger.error("failure storing value in database for (%s, %s)",version, checksum) + + return response + + async def handle_import_one(self, request): + response = None + if not self.server.read_only: + version = request["version"] + pkgarch = request["pkgarch"] + checksum = request["checksum"] + value = request["value"] - def importone(self, version, pkgarch, checksum, value): - return self.table.importone(version, pkgarch, checksum, value) + value = self.server.table.importone(version, pkgarch, checksum, value) + if value is not None: + response = {"value": value} - def ping(self): - return not self.quitflag + return response - def getinfo(self): - return (self.host, self.port) + async def handle_export(self, request): + version = request["version"] + pkgarch = request["pkgarch"] + checksum = request["checksum"] + colinfo = request["colinfo"] - def getPR(self, version, pkgarch, checksum): try: - return self.table.getValue(version, pkgarch, checksum) - except prserv.NotFoundError: - logger.error("can not find value for (%s, %s)",version, checksum) - return None + (metainfo, datainfo) = self.server.table.export(version, pkgarch, checksum, colinfo) except sqlite3.Error as exc: - logger.error(str(exc)) - return None - - def quit(self): - self.quitflag=True - os.write(self.quitpipeout, b"q") - os.close(self.quitpipeout) - return - - def work_forever(self,): - self.quitflag = False - # This timeout applies to the poll in TCPServer, we need the select - # below to wake on our quit pipe closing. We only ever call into handle_request - # if there is data there. - self.timeout = 0.01 - - bb.utils.set_process_name("PRServ") - - # DB connection must be created after all forks - self.db = prserv.db.PRData(self.dbfile) - self.table = self.db["PRMAIN"] + self.logger.error(str(exc)) + metainfo = datainfo = None - logger.info("Started PRServer with DBfile: %s, IP: %s, PORT: %s, PID: %s" % - (self.dbfile, self.host, self.port, str(os.getpid()))) - - self.handlerthread.start() - while not self.quitflag: - ready = select.select([self.fileno(), self.quitpipein], [], [], 30) - if self.quitflag: - break - if self.fileno() in ready[0]: - self.handle_request() - self.handlerthread.join() - self.db.disconnect() - logger.info("PRServer: stopping...") - self.server_close() - os.close(self.quitpipein) - return + return {"metainfo": metainfo, "datainfo": datainfo} - def start(self): - if self.daemon: - pid = self.daemonize() - else: - pid = self.fork() - self.pid = pid + async def handle_is_readonly(self, request): + return {"readonly": self.server.read_only} - # Ensure both the parent sees this and the child from the work_forever log entry above - logger.info("Started PRServer with DBfile: %s, IP: %s, PORT: %s, PID: %s" % - (self.dbfile, self.host, self.port, str(pid))) +class PRServer(bb.asyncrpc.AsyncServer): + def __init__(self, dbfile, read_only=False): + super().__init__(logger) + self.dbfile = dbfile + self.table = None + self.read_only = read_only - def delpid(self): - os.remove(self.pidfile) + def accept_client(self, socket): + return PRServerClient(socket, self) - def daemonize(self): - """ - See Advanced Programming in the UNIX, Sec 13.3 - """ - try: - pid = os.fork() - if pid > 0: - os.waitpid(pid, 0) - #parent return instead of exit to give control - return pid - except OSError as e: - raise Exception("%s [%d]" % (e.strerror, e.errno)) - - os.setsid() - """ - fork again to make sure the daemon is not session leader, - which prevents it from acquiring controlling terminal - """ - try: - pid = os.fork() - if pid > 0: #parent - os._exit(0) - except OSError as e: - raise Exception("%s [%d]" % (e.strerror, e.errno)) + def start(self): + tasks = super().start() + self.db = prserv.db.PRData(self.dbfile, read_only=self.read_only) + self.table = self.db["PRMAIN"] - self.cleanup_handles() - os._exit(0) + self.logger.info("Started PRServer with DBfile: %s, Address: %s, PID: %s" % + (self.dbfile, self.address, str(os.getpid()))) - def fork(self): - try: - pid = os.fork() - if pid > 0: - return pid - except OSError as e: - raise Exception("%s [%d]" % (e.strerror, e.errno)) - - bb.utils.signal_on_parent_exit("SIGTERM") - self.cleanup_handles() - os._exit(0) - - def cleanup_handles(self): - signal.signal(signal.SIGINT, self.sigint_handler) - signal.signal(signal.SIGTERM, self.sigterm_handler) - os.chdir("/") - - sys.stdout.flush() - sys.stderr.flush() - - # We could be called from a python thread with io.StringIO as - # stdout/stderr or it could be 'real' unix fd forking where we need - # to physically close the fds to prevent the program launching us from - # potentially hanging on a pipe. Handle both cases. - si = open('/dev/null', 'r') - try: - os.dup2(si.fileno(),sys.stdin.fileno()) - except (AttributeError, io.UnsupportedOperation): - sys.stdin = si - so = open(self.logfile, 'a+') - try: - os.dup2(so.fileno(),sys.stdout.fileno()) - except (AttributeError, io.UnsupportedOperation): - sys.stdout = so - try: - os.dup2(so.fileno(),sys.stderr.fileno()) - except (AttributeError, io.UnsupportedOperation): - sys.stderr = so - - # Clear out all log handlers prior to the fork() to avoid calling - # event handlers not part of the PRserver - for logger_iter in logging.Logger.manager.loggerDict.keys(): - logging.getLogger(logger_iter).handlers = [] - - # Ensure logging makes it to the logfile - streamhandler = logging.StreamHandler() - streamhandler.setLevel(logging.DEBUG) - formatter = bb.msg.BBLogFormatter("%(levelname)s: %(message)s") - streamhandler.setFormatter(formatter) - logger.addHandler(streamhandler) - - # write pidfile - pid = str(os.getpid()) - pf = open(self.pidfile, 'w') - pf.write("%s\n" % pid) - pf.close() + return tasks + + async def stop(self): + self.table.sync_if_dirty() + self.db.disconnect() + await super().stop() - self.work_forever() - self.delpid() + def signal_handler(self): + super().signal_handler() + if self.table: + self.table.sync() class PRServSingleton(object): - def __init__(self, dbfile, logfile, interface): + def __init__(self, dbfile, logfile, host, port): self.dbfile = dbfile self.logfile = logfile - self.interface = interface - self.host = None - self.port = None - - def start(self): - self.prserv = PRServer(self.dbfile, self.logfile, self.interface, daemon=False) - self.prserv.start() - self.host, self.port = self.prserv.getinfo() - - def getinfo(self): - return (self.host, self.port) - -class PRServerConnection(object): - def __init__(self, host, port): - if is_local_special(host, port): - host, port = singleton.getinfo() self.host = host self.port = port - self.connection, self.transport = bb.server.xmlrpcclient._create_server(self.host, self.port) - def terminate(self): - try: - logger.info("Terminating PRServer...") - self.connection.quit() - except Exception as exc: - sys.stderr.write("%s\n" % str(exc)) + def start(self): + self.prserv = PRServer(self.dbfile) + self.prserv.start_tcp_server(socket.gethostbyname(self.host), self.port) + self.process = self.prserv.serve_as_process(log_level=logging.WARNING) - def getPR(self, version, pkgarch, checksum): - return self.connection.getPR(version, pkgarch, checksum) + if not self.prserv.address: + raise PRServiceConfigError + if not self.port: + self.port = int(self.prserv.address.rsplit(":", 1)[1]) - def ping(self): - return self.connection.ping() +def run_as_daemon(func, pidfile, logfile): + """ + See Advanced Programming in the UNIX, Sec 13.3 + """ + try: + pid = os.fork() + if pid > 0: + os.waitpid(pid, 0) + #parent return instead of exit to give control + return pid + except OSError as e: + raise Exception("%s [%d]" % (e.strerror, e.errno)) - def export(self,version=None, pkgarch=None, checksum=None, colinfo=True): - return self.connection.export(version, pkgarch, checksum, colinfo) + os.setsid() + """ + fork again to make sure the daemon is not session leader, + which prevents it from acquiring controlling terminal + """ + try: + pid = os.fork() + if pid > 0: #parent + os._exit(0) + except OSError as e: + raise Exception("%s [%d]" % (e.strerror, e.errno)) - def dump_db(self): - return self.connection.dump_db() + os.chdir("/") - def importone(self, version, pkgarch, checksum, value): - return self.connection.importone(version, pkgarch, checksum, value) + sys.stdout.flush() + sys.stderr.flush() - def getinfo(self): - return self.host, self.port + # We could be called from a python thread with io.StringIO as + # stdout/stderr or it could be 'real' unix fd forking where we need + # to physically close the fds to prevent the program launching us from + # potentially hanging on a pipe. Handle both cases. + si = open("/dev/null", "r") + try: + os.dup2(si.fileno(), sys.stdin.fileno()) + except (AttributeError, io.UnsupportedOperation): + sys.stdin = si + so = open(logfile, "a+") + try: + os.dup2(so.fileno(), sys.stdout.fileno()) + except (AttributeError, io.UnsupportedOperation): + sys.stdout = so + try: + os.dup2(so.fileno(), sys.stderr.fileno()) + except (AttributeError, io.UnsupportedOperation): + sys.stderr = so + + # Clear out all log handlers prior to the fork() to avoid calling + # event handlers not part of the PRserver + for logger_iter in logging.Logger.manager.loggerDict.keys(): + logging.getLogger(logger_iter).handlers = [] + + # Ensure logging makes it to the logfile + streamhandler = logging.StreamHandler() + streamhandler.setLevel(logging.DEBUG) + formatter = bb.msg.BBLogFormatter("%(levelname)s: %(message)s") + streamhandler.setFormatter(formatter) + logger.addHandler(streamhandler) + + # write pidfile + pid = str(os.getpid()) + with open(pidfile, "w") as pf: + pf.write("%s\n" % pid) + + func() + os.remove(pidfile) + os._exit(0) -def start_daemon(dbfile, host, port, logfile): +def start_daemon(dbfile, host, port, logfile, read_only=False): ip = socket.gethostbyname(host) pidfile = PIDPREFIX % (ip, port) try: - pf = open(pidfile,'r') - pid = int(pf.readline().strip()) - pf.close() + with open(pidfile) as pf: + pid = int(pf.readline().strip()) except IOError: pid = None @@ -359,15 +247,13 @@ def start_daemon(dbfile, host, port, logfile): % pidfile) return 1 - server = PRServer(os.path.abspath(dbfile), os.path.abspath(logfile), (ip,port)) - server.start() + dbfile = os.path.abspath(dbfile) + def daemon_main(): + server = PRServer(dbfile, read_only=read_only) + server.start_tcp_server(ip, port) + server.serve_forever() - # Sometimes, the port (i.e. localhost:0) indicated by the user does not match with - # the one the server actually is listening, so at least warn the user about it - _,rport = server.getinfo() - if port != rport: - sys.stdout.write("Server is listening at port %s instead of %s\n" - % (rport,port)) + run_as_daemon(daemon_main, pidfile, os.path.abspath(logfile)) return 0 def stop_daemon(host, port): @@ -375,9 +261,8 @@ def stop_daemon(host, port): ip = socket.gethostbyname(host) pidfile = PIDPREFIX % (ip, port) try: - pf = open(pidfile,'r') - pid = int(pf.readline().strip()) - pf.close() + with open(pidfile) as pf: + pid = int(pf.readline().strip()) except IOError: pid = None @@ -386,37 +271,28 @@ def stop_daemon(host, port): # so at least advise the user which ports the corresponding server is listening ports = [] portstr = "" - for pf in glob.glob(PIDPREFIX % (ip,'*')): + for pf in glob.glob(PIDPREFIX % (ip, "*")): bn = os.path.basename(pf) root, _ = os.path.splitext(bn) - ports.append(root.split('_')[-1]) + ports.append(root.split("_")[-1]) if len(ports): - portstr = "Wrong port? Other ports listening at %s: %s" % (host, ' '.join(ports)) + portstr = "Wrong port? Other ports listening at %s: %s" % (host, " ".join(ports)) sys.stderr.write("pidfile %s does not exist. Daemon not running? %s\n" - % (pidfile,portstr)) + % (pidfile, portstr)) return 1 try: - PRServerConnection(ip, port).terminate() - except: - logger.critical("Stop PRService %s:%d failed" % (host,port)) - - try: - if pid: - wait_timeout = 0 - print("Waiting for pr-server to exit.") - while is_running(pid) and wait_timeout < 50: - time.sleep(0.1) - wait_timeout += 1 + if is_running(pid): + print("Sending SIGTERM to pr-server.") + os.kill(pid, signal.SIGTERM) + time.sleep(0.1) - if is_running(pid): - print("Sending SIGTERM to pr-server.") - os.kill(pid,signal.SIGTERM) - time.sleep(0.1) - - if os.path.exists(pidfile): - os.remove(pidfile) + try: + os.remove(pidfile) + except FileNotFoundError: + # The PID file might have been removed by the exiting process + pass except OSError as e: err = str(e) @@ -434,7 +310,7 @@ def is_running(pid): return True def is_local_special(host, port): - if host.strip().upper() == 'localhost'.upper() and (not port): + if (host == "localhost" or host == "127.0.0.1") and not port: return True else: return False @@ -445,60 +321,72 @@ class PRServiceConfigError(Exception): def auto_start(d): global singleton - # Shutdown any existing PR Server - auto_shutdown() - - host_params = list(filter(None, (d.getVar('PRSERV_HOST') or '').split(':'))) + host_params = list(filter(None, (d.getVar("PRSERV_HOST") or "").split(":"))) if not host_params: + # Shutdown any existing PR Server + auto_shutdown() return None if len(host_params) != 2: - logger.critical('\n'.join(['PRSERV_HOST: incorrect format', + # Shutdown any existing PR Server + auto_shutdown() + logger.critical("\n".join(["PRSERV_HOST: incorrect format", 'Usage: PRSERV_HOST = "<hostname>:<port>"'])) raise PRServiceConfigError - if is_local_special(host_params[0], int(host_params[1])) and not singleton: + host = host_params[0].strip().lower() + port = int(host_params[1]) + if is_local_special(host, port): import bb.utils cachedir = (d.getVar("PERSISTENT_DIR") or d.getVar("CACHE")) if not cachedir: logger.critical("Please set the 'PERSISTENT_DIR' or 'CACHE' variable") raise PRServiceConfigError - bb.utils.mkdirhier(cachedir) dbfile = os.path.join(cachedir, "prserv.sqlite3") logfile = os.path.join(cachedir, "prserv.log") - singleton = PRServSingleton(os.path.abspath(dbfile), os.path.abspath(logfile), ("localhost",0)) - singleton.start() + if singleton: + if singleton.dbfile != dbfile: + # Shutdown any existing PR Server as doesn't match config + auto_shutdown() + if not singleton: + bb.utils.mkdirhier(cachedir) + singleton = PRServSingleton(os.path.abspath(dbfile), os.path.abspath(logfile), host, port) + singleton.start() if singleton: - host, port = singleton.getinfo() - else: - host = host_params[0] - port = int(host_params[1]) + host = singleton.host + port = singleton.port try: - connection = PRServerConnection(host,port) - connection.ping() - realhost, realport = connection.getinfo() - return str(realhost) + ":" + str(realport) - + ping(host, port) + return str(host) + ":" + str(port) + except Exception: logger.critical("PRservice %s:%d not available" % (host, port)) raise PRServiceConfigError def auto_shutdown(): global singleton - if singleton: - host, port = singleton.getinfo() - try: - PRServerConnection(host, port).terminate() - except: - logger.critical("Stop PRService %s:%d failed" % (host,port)) - - try: - os.waitpid(singleton.prserv.pid, 0) - except ChildProcessError: - pass + if singleton and singleton.process: + singleton.process.terminate() + singleton.process.join() singleton = None def ping(host, port): - conn=PRServerConnection(host, port) - return conn.ping() + from . import client + + with client.PRClient() as conn: + conn.connect_tcp(host, port) + return conn.ping() + +def connect(host, port): + from . import client + + global singleton + + if host.strip().lower() == "localhost" and not port: + host = "localhost" + port = singleton.port + + conn = client.PRClient() + conn.connect_tcp(host, port) + return conn |