#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
    Twitr aims to be a quick test of a denormalized data set and of how to do
    that with SQLAlchemy. The goal is to make read operations cheap.

    Users' information is stored in the main db (called "base").

    Relationships and messages are stored in different dbs depending on
    the user. Each user has all the information they need in a single
    place. It means a lot of copies, but joins are done at write time only.

    And some of the write operations can be done in a background process,
    see lwqueue.

    In fact SQLAlchemy isn't that handy here, because relations are broken
    by design: the lovely ORM only saves you from writing SQL, and you
    still need a deep understanding of what happens behind it.
    
    init:
        CREATE DATABASE twitr_base CHARACTER SET utf8;
        CREATE DATABASE twitr_a CHARACTER SET utf8;
        CREATE DATABASE twitr_b CHARACTER SET utf8;
        CREATE USER 'twitr'@'localhost' IDENTIFIED BY 'twitr';
        SET PASSWORD FOR 'twitr'@'localhost' = PASSWORD('twitr');
        GRANT ALL PRIVILEGES ON twitr_base.* TO 'twitr'@'localhost';
        GRANT ALL PRIVILEGES ON twitr_a.* TO 'twitr'@'localhost';
        GRANT ALL PRIVILEGES ON twitr_b.* TO 'twitr'@'localhost';
    
    clean up:
        DROP TABLE twitr_base.users;
        DROP TABLE twitr_a.relations;
        DROP TABLE twitr_a.twits;
        DROP TABLE twitr_b.relations;
        DROP TABLE twitr_b.twits;

    test:
        python twitr.py
"""

__author__ = "Yoan Blanc <yoan @ dosimple.ch>"

import sys
import random

import sqlalchemy as sa
from sqlalchemy import orm
from sqlalchemy import schema

from datetime import datetime

# Config
# SQLAlchemy URL template; %(dbname)s is filled with "<prefix>_<suffix>"
# by Twitr (prefix "twitr").
CONN = r"mysql://twitr:twitr@localhost:3306/%(dbname)s"
# Suffix of the database holding the users table.
BASE = "base"
# Suffixes of the data databases (relations and twits).
DB = ("a", "b")

# Model
class User(object):
    """
    A registered user, stored in the base database.

    ``database`` is the suffix of the data database that holds this user's
    relations and twits (``uid`` is filled in by the ORM mapping).
    """
    def __init__(self, name, database):
        self.name, self.database = name, database

class Relation(object):
    """
    A subscription edge: ``subscriber`` follows ``to``.

    Only the two user ids are kept, since the users live in another database.
    """
    def __init__(self, subscriber, to):
        self.from_id, self.to_id = subscriber.uid, to.uid

# fail! One Relation subclass per data database: DataSchema maps RelationA
# onto the "relations" table of database "a" and RelationB onto that of any
# other database, so each class gets its own mapper.
# NOTE(review): a third database would reuse RelationB; presumably the second
# mapper() call on it then fails.
class RelationA(Relation):
    pass


class RelationB(Relation):
    pass

class Twit(object):
    """
    A message as stored in one user's timeline.

    ``user_id`` is the timeline owner: the author for their own copy, or the
    subscriber ``to`` for a fanned-out copy. ``from_id``/``from_name`` always
    describe the author, denormalized so reads need no join.
    """
    def __init__(self, user, message, to=None):
        owner = user if to is None else to
        self.user_id = owner.uid

        self.from_id = user.uid
        self.from_name = user.name
        self.message = message
        self.created_at = datetime.now()

# epic fail! One Twit subclass per data database: DataSchema maps TwitA onto
# the "twits" table of database "a" and TwitB onto that of any other
# database, so each class gets its own mapper.
# NOTE(review): a third database would reuse TwitB; presumably the second
# mapper() call on it then fails.
class TwitA(Twit):
    pass


class TwitB(Twit):
    pass

# DB
class Connection(object):
    """
    Handling a connection to a database.

    Creates an engine for ``uri``, exposes a ``scoped_session`` registry as
    ``self.session`` and creates the tables of ``schema`` if they are missing.
    """
    def __init__(self, uri, schema):
        """
        :param uri: SQLAlchemy database URL.
        :param schema: a Schema whose ``metadata`` tables are created here.
        """
        engine = sa.create_engine(uri)
        Session = orm.sessionmaker(bind=engine, autoflush=False, transactional=True)
        self.session = orm.scoped_session(Session)
        schema.metadata.create_all(bind=engine, checkfirst=True)

    def __del__(self):
        # __del__ also runs when __init__ raised before ``session`` was
        # assigned (e.g. create_engine failing); don't add a spurious
        # AttributeError on top of the real error in that case.
        session = getattr(self, "session", None)
        if session is not None:
            session.close()

class Schema(object):
    """
    Base class for a database schema.

    Subclasses register their tables on ``metadata`` and expose them by name
    in the ``tables`` dict.
    """
    def __init__(self):
        self.metadata, self.tables = sa.MetaData(), {}

class UserSchema(Schema):
    """
    Schema of the base database: a single ``users`` table mapped onto User.
    """
    def __init__(self):
        super(UserSchema, self).__init__()

        users_table = sa.Table("users", self.metadata,
            sa.Column("uid", sa.Integer, primary_key=True),
            sa.Column("name", sa.types.Unicode(25), index=True, unique=True,
                      nullable=False),
            sa.Column("database", sa.types.String(1), nullable=False))

        # Classical mapping of the module-level User class itself.
        orm.mapper(User, users_table)
        self.User = User

        self.tables = dict(users=users_table)

class DataSchema(Schema):
    """
    Schema of a data database: the ``relations`` and ``twits`` tables.

    :param name: data database suffix; "a" maps RelationA/TwitA, any other
                 value maps RelationB/TwitB.
    """
    def __init__(self, name="a"):
        super(DataSchema, self).__init__()
        # Replace the plain MetaData built by Schema.__init__.
        self.metadata = schema.ThreadLocalMetaData()

        relations_table = sa.Table("relations", self.metadata,
            sa.Column("rid", sa.Integer, primary_key=True),
            sa.Column("from_id", sa.Integer, nullable=False),
            sa.Column("to_id", sa.Integer, nullable=False))

        twits_table = sa.Table("twits", self.metadata,
            sa.Column("tid", sa.Integer, primary_key=True),
            sa.Column("user_id", sa.Integer, nullable=False),
            sa.Column("from_id", sa.Integer, nullable=False),
            sa.Column("from_name", sa.Unicode(25)),
            sa.Column("message", sa.Unicode(255), nullable=False),
            sa.Column("created_at", sa.DateTime(), nullable=False))

        # fail! Each database needs its own pair of mapped classes.
        use_a = name == "a"
        self.Relation = RelationA if use_a else RelationB
        self.Twit = TwitA if use_a else TwitB

        orm.mapper(self.Relation, relations_table)
        # Newest twits first when querying.
        orm.mapper(self.Twit, twits_table, order_by=[twits_table.c.tid.desc()])

        self.tables = dict(relations=relations_table, twits=twits_table)

# Controller (kind of)
class Twitr(object):
    """
    Entry point of the application (a controller, of sorts).

    Holds one Connection to the base database (users) and one per data
    database (relations and twits). Writes are fanned out so that every user
    finds all of their data in their own database.
    """
    prefix = "twitr"

    def __init__(self, base, databases):
        """
        :param base: suffix of the users database ("base" -> twitr_base).
        :param databases: suffixes of the data databases, e.g. ("a", "b").
        """
        # Materialize once: iterated below and sampled by create_user().
        self.databases = tuple(databases)

        self.user = UserSchema()
        self.connections = {
            "base": Connection(self._uri(base), self.user),
        }

        self.data = {}
        for db in self.databases:
            self.data[db] = DataSchema(db)
            self.connections[db] = Connection(self._uri(db), self.data[db])

    def _uri(self, name):
        """Return the connection URI of the ``<prefix>_<name>`` database."""
        return CONN % {"dbname": "%s_%s" % (self.prefix, name)}

    def create_user(self, name):
        """
        Create, persist and return a user assigned to a random data database.

        The database is drawn from ``self.databases``, so the user always
        lands in a database this instance holds a connection for.

        :raises IndexError: if the instance was built with no data databases.
        """
        session = self.connections["base"].session
        user = self.user.User(name, random.choice(self.databases))

        session.save(user)
        session.commit()

        return user

    def subscribe(self, subscriber, to, db=None):
        """
        Record that ``subscriber`` follows ``to``.

        The relation is written to ``db`` (default: the subscriber's
        database) and, when that differs, copied to ``to``'s database as well,
        because post() looks up followers in the author's database.
        """
        if db is None:
            db = subscriber.database

        session = self.connections[db].session
        session.save(self.data[db].Relation(subscriber, to))
        session.commit()

        # TODO: do this in a background process
        if db != to.database:
            self.subscribe(subscriber, to, to.database)

    def post(self, user, message):
        """
        Store ``message`` by ``user`` and fan a copy out to every follower.

        The author's copy goes into their own database; each follower gets a
        copy (owned by the follower) in the follower's database. Relations
        whose subscriber no longer exists in the base database are skipped.
        """
        home_db = user.database
        home_session = self.connections[home_db].session

        home_session.save(self.data[home_db].Twit(user, message))
        home_session.commit()

        # TODO: do this in a background process
        users = self.connections["base"].session.query(self.user.User)
        followers = home_session.query(self.data[home_db].Relation) \
            .filter_by(to_id=user.uid).all()

        for relation in followers:
            subscriber = users.get(relation.from_id)
            if subscriber is None:
                # Orphaned relation: nobody to deliver this copy to.
                continue

            target_db = subscriber.database
            target_session = self.connections[target_db].session

            target_session.save(self.data[target_db].Twit(user, message, subscriber))
            target_session.commit()

    def list(self, user):
        """
        Return the twits stored for ``user`` in their own database: their own
        posts plus copies of posts from users they follow, in the order
        defined by the Twit mapper.
        """
        db = user.database
        query = self.connections[db].session.query(self.data[db].Twit)
        return query.filter_by(user_id=user.uid).all()

def main(argv):
    twitr = Twitr(BASE, DB)
    
    yoan = twitr.create_user(u"yoan")
    batiste = twitr.create_user(u"batiste")
    jon = twitr.create_user(u"jonathan")
    bot = twitr.create_user(u"bot")

    twitr.subscribe(yoan, batiste)
    twitr.subscribe(yoan, jon)
    twitr.subscribe(batiste, jon)
    twitr.subscribe(jon, yoan)
    twitr.subscribe(bot, yoan)
    twitr.subscribe(bot, batiste)
    twitr.subscribe(bot, jon)

    twitr.post(yoan, u"Hello World!")
    twitr.post(yoan, u"This is a test.")
    twitr.post(batiste, u"I'm on holidays, stop bothering me!")
    twitr.post(jon, u"Yawn!!1!")
    twitr.post(yoan, u"~ the end ~")

    print " Twitr "
    print "======="
    print

    for user in (yoan, batiste, jon, bot):
        print "", user.name
        print "-"*(len(user.name)+2)
        for twit in twitr.list(user):
            print u'%10s: "%s"' % (twit.from_name, twit.message)
        print
    print

if __name__ == "__main__":
    # Run the demo when executed as a script.
    main(argv=sys.argv)

