Skip to content
Snippets Groups Projects
common.py 3.14 KiB
import json
import logging

import jsonschema
import redis
import redis.sentinel

logger = logging.getLogger(__name__)

# Field names of a latch record; each must be present and hold an integer.
_LATCH_FIELDS = ("current", "next", "this")

# JSON schema for the value stored under the ``db:latch`` key: an object with
# exactly the integer fields listed in _LATCH_FIELDS and nothing else.
DB_LATCH_SCHEMA = {
    "$schema": "http://json-schema.org/draft-07/schema#",
    "type": "object",
    "properties": {field: {"type": "integer"} for field in _LATCH_FIELDS},
    "required": list(_LATCH_FIELDS),
    "additionalProperties": False,
}


def get_latch(r):
    """Read and validate the latch record stored under ``db:latch``.

    :param r: redis client (anything with a ``get(key)`` method) to read from
    :return: the decoded latch dict (integer ``current``, ``next`` and
        ``this`` fields, as enforced by ``DB_LATCH_SCHEMA``), or None when the
        key is missing or its value is not a valid latch
    """
    raw = r.get('db:latch')
    if raw is None:
        logger.error('no latch key found in db')
        return None
    try:
        # Clients created with decode_responses=True hand back str, not bytes.
        if isinstance(raw, bytes):
            raw = raw.decode('utf-8')
        latch = json.loads(raw)
        jsonschema.validate(latch, DB_LATCH_SCHEMA)
    except (jsonschema.ValidationError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError, so a
        # corrupt value degrades to "no latch" instead of crashing callers.
        # Use the module logger: logging.exception() on the root logger would
        # also call basicConfig() when no handlers are configured.
        logger.exception('error validating latch value')
        return None

    return latch


def set_latch(config, new_current, new_next):
    """Write a latch record into every configured database.

    Every copy shares the same ``current`` and ``next`` ids; ``this`` is set
    to the id of the database that particular copy is stored in.

    :param config: configuration dict; ``config['redis-databases']`` lists
        the database ids to write to
    :param new_current: db id to record as ``current``
    :param new_next: db id to record as ``next``
    """
    for dbid in config['redis-databases']:
        record = dict(current=new_current, next=new_next, this=dbid)
        client = _get_redis(config, dbid=dbid)
        client.set('db:latch', json.dumps(record))


def latch_db(config):
    """Advance the latch one step through the configured databases.

    The database currently recorded as ``next`` becomes ``current``. The
    database following it in ascending id order (wrapping around) becomes the
    new ``next``. The resulting latch is written to every configured database
    via ``set_latch``.

    If no valid latch can be read, or the stored ``next`` id is no longer in
    ``config['redis-databases']``, the rotation restarts from the lowest
    configured id.

    :param config: configuration dict with a ``redis-databases`` list of db
        ids plus the connection settings used by ``_get_redis``
    """
    db_ids = sorted(set(config['redis-databases']))

    latch = get_latch(get_next_redis(config))
    if latch and latch['next'] not in db_ids:
        # Without this guard db_ids.index() below raises ValueError and the
        # rotation can never advance again until the latch is fixed by hand.
        logger.warning(
            'latched next db %s is not configured, restarting rotation',
            latch['next'])
        latch = None
    if not latch:
        latch = {
            'current': db_ids[0],
            'next': db_ids[0]
        }

    next_idx = (db_ids.index(latch['next']) + 1) % len(db_ids)

    set_latch(config, new_current=latch['next'], new_next=db_ids[next_idx])


def _get_redis(config, dbid=None):
    """Build a redis client for database ``dbid``.

    A ``dbid`` of None defaults to the lowest configured id. An id that is not
    listed in ``config['redis-databases']`` is logged as an error and is also
    replaced by the lowest configured id.

    When ``config`` has a ``sentinel`` section, the client comes from the
    sentinel's master for ``config['sentinel']['name']``. Otherwise a direct
    ``redis.StrictRedis`` connection to ``config['redis']`` is returned.
    Both paths use a 0.1 second socket timeout.
    """
    known_ids = config['redis-databases']

    if dbid is None:
        logger.debug('no db specified, using minimum as first guess')
        dbid = min(known_ids)

    if dbid not in known_ids:
        logger.error('tried to connect to unknown db id: {}'.format(dbid))
        dbid = min(known_ids)

    conn_opts = dict(db=dbid, socket_timeout=0.1)

    if 'sentinel' not in config:
        redis_cfg = config['redis']
        return redis.StrictRedis(
            host=redis_cfg['hostname'],
            port=redis_cfg['port'],
            **conn_opts)

    sentinel_cfg = config['sentinel']
    endpoints = [(sentinel_cfg['hostname'], sentinel_cfg['port'])]
    sentinel = redis.sentinel.Sentinel(endpoints, **conn_opts)
    return sentinel.master_for(sentinel_cfg['name'], socket_timeout=0.1)


def get_current_redis(config):
    """Return a client for the database the latch marks as ``current``.

    The latch is read from the default (lowest configured id) database. If it
    cannot be read, a warning is logged and that default client is returned
    as a best guess.
    """
    default_client = _get_redis(config)
    latch = get_latch(default_client)

    if not latch:
        logger.warning("can't determine current db")
        return default_client

    # Reuse the already-open client when it is the current db itself.
    if latch['this'] != latch['current']:
        return _get_redis(config, latch['current'])
    return default_client


def get_next_redis(config):
    """Return a client for the database the latch marks as ``next``.

    The latch is read from the default (lowest configured id) database. If the
    latch is missing, or names a ``next`` id that is not configured, a warning
    is logged and the second-lowest configured id is used instead (or the only
    id, when just one is configured).
    """
    default_client = _get_redis(config)
    latch = get_latch(default_client)

    if latch:
        if latch['this'] == latch['next']:
            return default_client
        if latch['next'] in config['redis-databases']:
            return _get_redis(config, latch['next'])

    logger.warning("next db not configured, deriving default value")
    ordered_ids = sorted(set(config['redis-databases']))
    if len(ordered_ids) == 1:
        fallback_id = ordered_ids[0]
    else:
        fallback_id = ordered_ids[1]
    return _get_redis(config, fallback_id)