Commit 0f891992 authored by Ondřej Kuzník's avatar Ondřej Kuzník
Browse files

Handle errors during startup

parent e89f7766
......@@ -16,7 +16,7 @@ setup(
url='https://git.openldap.org/openldap/syncmonitor',
license=LICENSE,
packages=find_packages(exclude=('tests', 'docs')),
python_requires=">= 3.7", # Modern asyncio
python_requires=">= 3.7", # Modern asyncio
entry_points={
'console_scripts': ['syncmonitor = syncmonitor.ui:main'],
},
......@@ -24,5 +24,6 @@ setup(
'ldap0 >= 1.1.0',
'urwid',
'pyyaml',
'tenacity',
],
)
......@@ -20,19 +20,23 @@ Connection set up
"""
import logging
logger = logging.getLogger(__name__)
from .ldap_wrapper import AsyncClient
import ldap0 as ldap
from ldap0.ldapurl import LDAPUrl
logger = logging.getLogger(__name__)
def connect_and_setup(uri, config):
async def connect_and_setup(uri, config):
config = config or {}
uri = LDAPUrl(uri)
conn = AsyncClient(uri)
try:
conn = AsyncClient(uri)
except ldap.LDAPError:
logger.exception("During connect_and_setup for %s", uri)
raise
conn.protocol_version = ldap.VERSION3
reset_tls_ctx = False
......@@ -48,11 +52,11 @@ def connect_and_setup(uri, config):
conn.set_option(ldap.OPT_X_TLS_NEWCTX, 0)
if uri.urlscheme == 'ldap':
starttls = config.get('starttls', 'no')
if starttls:
starttls = config.get('starttls', 'yes')
if not starttls or starttls.lower() != 'no':
try:
conn.start_tls_s()
except ldap.SERVER_DOWN:
except ldap.PROTOCOL_ERROR:
if starttls == 'hard':
logger.exception("Cannot set up a TLS layer")
raise
......@@ -66,10 +70,13 @@ def connect_and_setup(uri, config):
if auth.get('authz_id'):
raise NotImplementedError
logger.debug("binding as %s on %s", auth['binddn'], uri)
conn.bind_s(auth['binddn'], auth['password'])
await conn.bind(auth['binddn'], auth['password'])
elif mechanism in ('EXTERNAL', 'GSSAPI'):
conn.sasl_non_interactive_bind_s(mechanism, auth.get('authz_id'))
else:
raise NotImplementedError
else:
# make sure we do something to confirm the connection is live
await conn.bind()
return conn
# -*- coding: utf-8 -*-
# This work is part of OpenLDAP Software <http://www.openldap.org/>.
#
# Copyright 2020 The OpenLDAP Foundation.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted only as authorized by the OpenLDAP
# Public License.
#
# A copy of this license is available in the file LICENSE in the
# top-level directory of the distribution or, alternatively, at
# <http://www.OpenLDAP.org/license.html>.
#
# ACKNOWLEDGEMENTS:
# This work was initially developed by Ondřej Kuzník
# for inclusion in OpenLDAP Software.
"""
asyncio-based component to rate-limit a function
"""
import asyncio
from types import MethodType
def limited(t):
def wrapper(f):
return RateLimiter(t, f)
return wrapper
class RateLimiter:
class _Wrapper:
def __init__(self, time, wrapped):
self.time = time
self.wrapped = wrapped
self.timer = None
self.last_called = None
self.args = None
def __call__(self, *args, **kwargs):
if self.timer:
self.args = args, kwargs
return
loop = asyncio.get_running_loop()
now = loop.time()
if self.last_called:
next_call = self.last_called + self.time
else:
next_call = now
if next_call > now:
self.timer = loop.call_at(next_call, self.delayed)
self.args = args, kwargs
return
self.last_called = now
return self.wrapped(*args, **kwargs)
def delayed(self):
self.timer = None
args, kwargs = self.args
self.args = None
return self.wrapped(*args, **kwargs)
def __init__(self, time, f):
self.time = time
self.wrapped = f
self.name = None
def __set_name__(self, owner, name):
self.name = name
def __get__(self, instance, owner=None):
instance.__dict__[self.name] = RateLimiter._Wrapper(self.time, MethodType(self.wrapped, instance))
return instance.__dict__[self.name]
......@@ -22,6 +22,7 @@ OpenLDAP multiprovider replication monitor
import asyncio
import enum
import logging
from tenacity import retry, stop_after_delay, wait_exponential
import ldap0
import ldap0.controls
......@@ -102,18 +103,19 @@ class Provider:
self.catching_up_to = None
self.catching_up_set = None
self.set_up()
def set_up(self):
@retry(wait=wait_exponential(multiplier=0.1, max=5),
stop=stop_after_delay(30))
async def set_up(self):
self.state = ReplicaState.CONNECTING
self.state_changed(self.state)
self.client = connect_and_setup(self.uri, self.config)
self.up_to_date.clear()
self.client = await connect_and_setup(self.uri, self.config)
control = ldap0.controls.syncrepl.SyncRequestControl(
cookie=self.cookie, mode=self.mode)
self.search = self.client.search(self.base, scope=self.scope, filterstr="(|)",
req_ctrls=[control])
self.search = self.client.search(self.base, scope=self.scope,
filterstr="(|)", req_ctrls=[control])
self.observer = SyncreplObserver(self.search, self.cookie)
self.observer.cookie_updated.connect(self._update_cookie)
......@@ -123,6 +125,7 @@ class Provider:
self.state = ReplicaState.STOPPED
self.observer.finished.disconnect(self.search_finished)
self.observer.cookie_updated.disconnect(self._update_cookie)
self.client.unbind()
def _update_cookie(self, new_cookie):
logger.debug("Update from ourselves, cookie %s", new_cookie)
......@@ -172,9 +175,12 @@ class Provider:
self.catching_up_set = None
elif not self.behind and cookie not in self.cookie:
if not self.catching_up:
self.state = ReplicaState.CATCHING_UP
self.catching_up_from = self.cookie.copy()
self.catching_up_to = cookie.copy()
self.catching_up_set = self.catching_up_to - self.catching_up_from
self.catching_up_set = self.catching_up_to - \
self.catching_up_from
assert self.catching_up_set
self.catching_up.rearm()
self.behind.rearm()
......@@ -193,12 +199,13 @@ class Provider:
ldap0.TIMELIMIT_EXCEEDED,
ldap0.SIZELIMIT_EXCEEDED,
0x1000):
logger.error("Search to %r finished with %r, waiting before reconnection", self.uri, e)
await asyncio.sleep(1)
logger.warning("Search to %r finished with %r, "
"waiting before reconnection", self.uri, e)
await asyncio.sleep(0.1)
if result == 0x1000:
self.cookie = SyncreplCookie()
self.cookie_updated(None)
self.set_up()
await self.set_up()
class SyncreplEnvironment:
......@@ -222,10 +229,32 @@ class SyncreplEnvironment:
self.providers[uri] = provider
async def set_up(self):
await asyncio.gather(*(provider.set_up()
for _, provider in self.providers.items()))
async def add_server(self, uri, config=None):
if uri in self.providers:
raise ValueError(f"Provider at {uri!r} already set up")
provider_config = None
if config:
provider_config = config.get(uri, config)
provider = Provider(uri, provider_config, self.base, self.scope,
self.cookie, mode='refreshAndPersist')
provider.cookie_updated.connect(self.update_cookie)
self.cookie_updated.connect(provider.environment_update)
self.providers[uri] = provider
await provider.set_up()
def stop(self):
for uri, provider in self.providers.items():
provider.cookie_updated.disconnect(self.update_cookie)
self.cookie_updated.disconnect(provider.environment_update)
provider.stop()
def update_cookie(self, new_cookie):
if new_cookie:
......
......@@ -20,8 +20,6 @@ python-ldap0 asyncio wrapper
"""
import asyncio
import collections
import collections.abc
import logging
import ldap0 as ldap
......@@ -56,8 +54,7 @@ class LDAPRequest(asyncio.Future):
self._queue.put_nowait(message)
else:
self.full_result = message
# unfortunately, we have to assume LDAP_SUCCESS
self.set_result(ldap0.SUCCESS)
self.set_result(self.full_result)
def cancel(self, abandon=False):
if not abandon:
......@@ -95,14 +92,19 @@ class AsyncClient(ldap.ldapobject.LDAPObject):
self._have_reader = False
while self._in_progress:
self._in_progress.popitem()[1].cancel()
try:
self._in_progress.popitem()[1].cancel()
except ldap.SERVER_DOWN:
pass
def _read(self):
#print(self, "ready to read")
while True:
try:
message = self.result(msgid=ldap.RES_ANY,
all_results=ldap.MSG_ONE, timeout=0, add_intermediates=1)
except ldap.SERVER_DOWN as e:
all_results=ldap.MSG_ONE,
timeout=0, add_intermediates=1)
except ldap.SERVER_DOWN:
self._shutdown()
break
except ldap.LDAPError as e:
......@@ -131,17 +133,20 @@ class AsyncClient(ldap.ldapobject.LDAPObject):
if request:
request.cancel(abandon=True)
return super().abandon(msgid, *args, **kwargs)
try:
return super().abandon(msgid, *args, **kwargs)
except ldap.NO_SUCH_OPERATION:
return None
def unbind(self, *args, **kwargs):
"Send unbind and cancel all tasks"
self._shutdown()
return super().unbind_ext(*args, **kwargs)
return super().unbind(*args, **kwargs)
def __send_request(self, name, *args, **kwargs):
"Send a request and return the awaitable+iterable object"
method_name = name #+ '_ext'
method_name = name # + '_ext'
method = getattr(super(), method_name)
msgid = method(*args, **kwargs)
......@@ -161,7 +166,7 @@ class AsyncClient(ldap.ldapobject.LDAPObject):
def bind(self, *args, **kwargs):
"Bind operation"
return self.__send_request('bind', *args, **kwargs)
return self.__send_request('simple_bind', *args, **kwargs)
def delete(self, *args, **kwargs):
"Delete operation"
......
......@@ -40,6 +40,7 @@ logger = logging.getLogger(__name__)
class SyncreplState(Enum):
UNKNOWN = "unknown"
NOT_STARTED = "waiting for first response"
PRESENT = "present phase"
DELETE = "delete phase"
PERSIST = "persist stage"
......@@ -55,7 +56,7 @@ class SyncreplObserver:
def __init__(self, request, cookie=None):
self.request = request
self.cookie = SyncreplCookie(cookie)
self.state: SyncreplState = SyncreplState.PRESENT
self.state: SyncreplState = SyncreplState.NOT_STARTED
self.task = asyncio.create_task(self._listen())
......@@ -65,6 +66,8 @@ class SyncreplObserver:
self.message_received(message)
entry = message.rdata[0]
if isinstance(entry, (SearchResultEntry, SearchReference)):
if self.state == SyncreplState.NOT_STARTED:
self.state = SyncreplState.PRESENT
for control in entry.ctrls:
if isinstance(control, SyncStateControl):
cookie = getattr(control, 'cookie', None)
......@@ -82,7 +85,7 @@ class SyncreplObserver:
self.state = SyncreplState.PERSIST
else:
self.state = SyncreplState.UNKNOWN
if entry.cookie:
if entry.refreshDone or entry.cookie:
self._update_cookie(message, entry.cookie)
elif isinstance(entry, SyncInfoMessage):
self._update_cookie(message, entry.cookie)
......@@ -111,7 +114,7 @@ class SyncreplObserver:
self.finished()
def _update_cookie(self, message, cookie):
logger.debug("New cookie: %r", cookie.decode())
logger.debug("New cookie: %r", cookie and cookie.decode())
logger.debug("Transforming from state: %s", self.cookie)
self.cookie.update(cookie)
logger.debug("New state: %s", self.cookie)
......
......@@ -28,6 +28,7 @@ import sys
import urwid
from .environment import SyncreplEnvironment
from .debounce import limited
logger = logging.getLogger(__name__)
......@@ -60,8 +61,10 @@ class ProviderEntry(urwid.LineBox):
self.provider.state_changed.connect(self.state_changed)
self.provider.sid_discovered.connect(self.new_sid)
@limited(0.1)
def cookie_updated(self, cookie):
self.body.contents[1:] = [(urwid.Text(csn), self.body.options()) for csn in cookie]
self.body.contents[1:] = [(urwid.Text(csn), self.body.options())
for csn in cookie]
def state_changed(self, new_state):
self.state.set_text(new_state.value)
......@@ -124,6 +127,8 @@ async def run(args=None):
evl = urwid.AsyncioEventLoop(loop=loop)
app = App(options)
await app.environment.set_up()
global urwid_loop
urwid_loop = urwid.MainLoop(app, event_loop=evl,
unhandled_input=app.unhandled_input)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment