#!/usr/bin/python3
# -*- coding: utf-8 -*-
#
# Copyright (C) 2011-2014 Julien Muchembled <jm@jmuchemb.eu>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import errno, logging, os, random, shutil, stat
import threading, tempfile, types, unittest
from collections import defaultdict, deque
# Load the 'fssync' script (looked up in the current directory) as a module.
# It has no '.py' suffix, so instead of a regular import, its source is
# compiled and executed in the namespace of a fresh module object.
fssync = types.ModuleType('fssync')
# The source is read in binary mode so that compile() decodes it itself
# (honouring a coding declaration) instead of depending on the locale, and
# through a context manager so that the file is closed immediately.
with open('fssync', 'rb') as _source:
  exec(compile(_source.read(), os.path.realpath('fssync'), 'exec'),
       fssync.__dict__)
del _source

logger = fssync.logger
def setLevel(level=logging.DEBUG):
  """Set `level` on logging's last-resort handler and on the fssync logger."""
  logging.lastResort.setLevel(level)
  logger.setLevel(level)
#setLevel()  # uncomment to enable all logs

# XXX: genbackupdata (http://braawi.org/genbackupdata/, packaged by Debian)
#      may help to generate test data.

class Stat(fssync.Stat):
  """fssync.Stat subclass whose device and inode numbers can be faked.

  Tests use the class methods below to make a path look like it lives on
  another device, or to make an inode look like a different one. The faking
  state is stored on the class, so it is shared by all instances.
  """

  # Keys, as computed by the unpatched base class (presumably (dev, ino)
  # pairs, given how _xstat tests membership), of the paths that must appear
  # to be on another device.
  _fake_dev = set()
  # Maps real inode numbers to the fake numbers that replace them.
  _fake_ino = {}
  # Last fake inode number allocated automatically (they count down from -1).
  _last_ino = 0

  def __init__(self, path):
    super().__init__(path)
    if self.key in self._fake_dev:
      self.dev += 1
    try:
      replacement = self._fake_ino[self.ino]
    except KeyError:
      return
    self.ino = replacement

  @classmethod
  def _real_key(cls, path):
    # Key of `path` as computed by the first base class, i.e. without any
    # of the faking done by this class.
    unpatched = cls.__bases__[0]
    return unpatched(path).key

  @classmethod
  def fake_dev(cls, path):
    """Make `path` look like it is on a different device."""
    cls._fake_dev.add(cls._real_key(path))

  @classmethod
  def reset_dev(cls, path=None):
    """Stop faking the device of `path`, or of every path if it is empty."""
    if not path:
      cls._fake_dev.clear()
      return
    cls._fake_dev.remove(cls._real_key(path))

  @classmethod
  def set_ino(cls, path, ino=None):
    """Replace the inode number of `path` by `ino` and return the latter.

    Without `ino`, a new number is allocated.
    """
    if ino is None:
      # negative number to not conflict with real inodes
      ino = cls._last_ino - 1
      cls._last_ino = ino
    real_ino = os.lstat(path).st_ino
    cls._fake_ino[real_ino] = ino
    return ino

  @classmethod
  def reset_ino(cls):
    """Forget all fake inode numbers and restart the allocation counter."""
    cls._fake_ino.clear()
    cls._last_ino = 0

  @classmethod
  def _xstat(cls, s):
    """Return stat result `s` with the configured faking applied.

    If (st_dev, st_ino) is registered in _fake_dev, field 2 (st_dev) is
    replaced by -1. Otherwise, if st_ino has a fake counterpart, field 1
    (st_ino) is replaced by it. In all other cases `s` is returned as is.
    """
    result_type = type(s)
    if (s.st_dev, s.st_ino) in cls._fake_dev:
      return result_type(s[:2] + (-1,) + s[3:])
    try:
      fake_ino = cls._fake_ino[s.st_ino]
    except KeyError:
      return s
    return result_type(s[:1] + (fake_ino,) + s[2:])

  @classmethod
  def _lstat(cls, name, **kw):
    """os.lstat() whose result goes through _xstat()."""
    raw = os.lstat(name, **kw)
    return cls._xstat(raw)

# Rebind the name `Stat` in the fssync module namespace to the faking
# subclass, so that code there that looks `Stat` up at call time uses it.
fssync.Stat = Stat


class DummyRpcClient(object):
  """In-process replacement for the RPC client, bound to a Remote object.

  Any attribute that is not defined otherwise becomes a stub that queues the
  call. wait() then executes the oldest queued call directly on `remote` and
  counts it, by name, in `called`.
  """

  def __init__(self, remote):
    self.called = defaultdict(int)  # RPC name -> number of executed calls
    self._rpc = deque()             # pending (name, args, kw), oldest first
    self.remote = remote

  def wait(self):
    """Execute the oldest queued call on the remote and return its result."""
    name, args, kw = self._rpc.popleft()
    self.called[name] += 1
    first = args[0]
    if isinstance(first, bytes):
      # A bytes first argument is a path relative to the root of the remote.
      first = os.path.join(self.remote.root, first)
      args = (first,) + args[1:]
    # The second argument is only logged when it is None, an int or bytes.
    log_second = len(args) > 1 and (
      args[1] is None or isinstance(args[1], (int, bytes)))
    if not log_second:
      logger.debug('%s(%r)', name, first)
    else:
      logger.debug('%s(%r, %r)', name, first, args[1])
    method = getattr(self.remote, name)
    return method(*args, **kw)

  def __getattr__(self, name):
    # Only reached for attributes that do not exist yet: build a stub that
    # queues the call, and cache it on the instance for later lookups.
    enqueue = self._rpc.append
    def stub(*args, **kw):
      enqueue((name, args, kw))
    setattr(self, name, stub)
    return stub

  def check(self, path):
    """Start remote._check(path) in a daemon thread, connected by 2 pipes.

    `stdin`/`stdout` become the local ends of the pipes, and wait() is
    replaced on this instance by fssync.RpcClient.wait.
    """
    to_remote_r, to_remote_w = os.pipe()
    from_remote_r, from_remote_w = os.pipe()
    self.stdin = open(from_remote_r, "rb")
    self.stdout = open(to_remote_w, "wb", 0)
    def serve():
      with open(to_remote_r, "rb") as stdin, \
           open(from_remote_w, "wb", 0) as stdout:
        result = self.remote._check(stdin, stdout, path)
        fssync.write_rpc(stdout, fssync.dumps(result))
    worker = threading.Thread(target=serve, daemon=True)
    worker.start()
    self.wait = types.MethodType(fssync.RpcClient.wait, self)

  def send(self, value):
    """Send `value` with fssync.RpcClient.send.

    After sending None, the writing end is closed and the next wait() also
    closes the reading end and removes the attributes set by check(), which
    makes the class-level wait() visible again.
    """
    fssync.RpcClient.send(self, value)
    if value is not None:
      return
    self.stdout.close()
    def final_wait():
      try:
        return fssync.RpcClient.wait(self)
      finally:
        self.stdin.close()
        del self.stdin, self.stdout, self.wait
    self.wait = final_wait


class Local(fssync.Local):
  """fssync.Local working on a temporary directory, with a dummy RPC client.

  The tree lives in a directory created by tempfile.mkdtemp(), which is
  removed when the object is deleted. RPCs go through a DummyRpcClient bound
  to a fresh Remote.
  """

  force = False
  prealloc = True

  def __init__(self):
    root = fssync.encode(tempfile.mkdtemp())
    rpc = DummyRpcClient(Remote())
    # ':memory:' is presumably the database path -- confirm in fssync.Local.
    super().__init__(root, ':memory:', rpc)

  def __del__(self):
    # Remove the temporary tree, let the base class clean up, and drop the
    # fake inode numbers that were registered for files of this tree.
    shutil.rmtree(self.root)
    super().__del__()
    Stat.reset_ino()

  def is_masked(self, path):
    """Same as the base method, using the faking Stat._lstat to stat."""
    lstat = Stat._lstat
    return super().is_masked(path, lstat)

  @property
  def remote(self):
    """The Remote object that the dummy RPC client calls."""
    client = self.rpc
    return client.remote


class Remote(fssync.Remote):
  """fssync.Remote rooted at a temporary directory.

  The directory is created by tempfile.mkdtemp() and removed when the object
  is deleted.
  """

  def __init__(self):
    root = fssync.encode(tempfile.mkdtemp())
    super().__init__(root)

  def __del__(self):
    shutil.rmtree(self.root)


def gen_data(size, __blob=bytes(int(random.gauss(0, .8)) % 256
                                for x in range(100000))):
  """Return `size` consecutive bytes of a random blob, at a random offset.

  The default blob (100000 bytes) is generated only once, when the function
  is defined. Its values come from a truncated gaussian taken modulo 256, so
  most bytes are 0 and the others are close to 0 or to 255.

  Every window of `size` bytes can be returned, including the one that ends
  at the end of the blob, so `size` may be as large as the blob itself.
  random.randrange() raises ValueError if `size` is larger than the blob.
  """
  # There are len(__blob) - size + 1 possible start offsets.
  i = random.randrange(len(__blob) - size + 1)
  return __blob[i:i+size]


class Test(unittest.TestCase):

  # References to the real functions: setUp replaces os.fstat and os.listdir
  # by wrappers built on these, and tearDown puts them back. staticmethod()
  # makes sure they are returned unbound when read through an instance.
  os_fstat = staticmethod(os.fstat)
  os_listdir = staticmethod(os.listdir)

  def setUp(self):
    """Create a fresh Local/Remote pair and enter the local root.

    os.fstat is wrapped so that its results go through Stat._xstat, and
    os.listdir so that it returns sorted names; tearDown undoes both.
    """
    # Go back to the current directory once the test is over (cleanups run
    # after tearDown): the temporary root entered below is removed together
    # with the Local instance, which would otherwise leave the process in a
    # deleted working directory.
    self.addCleanup(os.chdir, os.getcwd())
    self.fssync = Local()
    os.chdir(self.fssync.root)
    os.fstat = lambda fd: Stat._xstat(self.os_fstat(fd))
    os.listdir = lambda p: sorted(self.os_listdir(p))

  def tearDown(self):
    os.fstat = self.os_fstat
    os.listdir = self.os_listdir
    del self.fssync

  def mkreg(self, path, size=0, sparse_map=0, mode=None):
    d = os.path.dirname(path)
    if d and not os.path.exists(d):
      os.makedirs(d)
    with open(path, 'wb') as f:
      if size:
        f.write(gen_data(size) if isinstance(size, int) else size)
    if mode is not None:
      os.chmod(path, mode)

  def assertSynced(self):
    """Fail unless the local tree and the remote tree are identical.

    Both trees are walked (the local one from b'.', i.e. the current
    directory) and one record is built per inode, identified by `Stat.key`:
    `(Stat.value, data, path, ...)`, where `data` is the target of a symlink,
    the content of a regular file or None for anything else, followed by the
    root-relative paths of all the hard links found for this inode. The two
    lists of records, sorted by paths, must be equal.
    """
    def snapshot(root):
      inodes = {}
      n = len(root) + 1  # length of the '<root>/' prefix to strip from paths
      for r, dirs, files in os.walk(root):
        for name in dirs + files:
          p = os.path.join(r, name)
          s = Stat(p)
          key = s.key
          if key in inodes:
            # Another hard link to a known inode: only record its path.
            inodes[key] += p[n:],
            continue
          fmt = stat.S_IFMT(s.mode)
          if fmt == stat.S_IFLNK:
            data = os.readlink(p)
          elif fmt == stat.S_IFREG:
            with open(p, 'rb') as f:
              data = f.read()
          else:
            data = None
          inodes[key] = s.value, data, p[n:]
      # Inode keys differ between the 2 trees, so only records are compared,
      # in an order that depends on paths only.
      return sorted(inodes.values(), key=lambda x: x[2:])
    self.assertListEqual(snapshot(b'.'), snapshot(self.fssync.remote.root))

  def assertNotSynced(self):
    self.assertRaises(self.failureException, self.assertSynced)

  @property
  def called(self):
    return self.fssync.rpc.called

  def assertCalled(self, **kw):
    self.assertDictEqual(self.called, kw)

  def sync(self, clean=True, synced=True):
    self.fssync.sync(b'')
    if clean:
      self.fssync.clean(b'')
    del self.fssync.masked[:]
    (self.assertSynced if synced else self.assertNotSynced)()

  def test1(self):
    # Initial transfer of a whole tree, followed by renames of its top
    # directory.

    # The 'xml' package of the standard library is used as a ready-made
    # sample tree, copied to 'a' in the local root.
    import xml
    shutil.copytree(fssync.encode(os.path.dirname(xml.__file__)),
                    b'a', symlinks=True)
    self.assertNotSynced()
    self.sync()
    # Only the set of RPC names is checked here, not the counts.
    self.assertEqual(sorted(self.called), ['check_data', 'sync_data',
                                           'sync_meta', 'truncate'])

    # Syncing again without any change must not issue any RPC.
    self.called.clear()
    self.sync()
    self.assertCalled()

    # A plain rename of the directory: exactly 1 'rename' and 1 'sync_meta'
    # are expected ('called' was emptied above and the previous sync added
    # nothing).
    os.rename('a', 'b')
    self.assertNotSynced()
    self.sync()
    self.assertCalled(rename=1, sync_meta=1)

    # Rename again, but also give the directory a new fake inode number,
    # presumably so that fssync can not recognize it as the former 'b'.
    self.called.clear()
    os.rename('b', 'c')
    Stat.set_ino('c')
    self.sync()
    self.assertEqual(sorted(self.called), ['link', 'removemany',
                                           'rename', 'sync_meta'])

  def test2(self):
    # Long scenario about hard links, symlinks, faked devices/inodes, the
    # 'force' flag, and recovery after the remote tree was modified behind
    # fssync's back. Each step checks the exact RPC counters since the last
    # self.called.clear().

    # A regular file with 2 hard links ('a' and 'b/c').
    self.mkreg('a')
    os.mkdir('b')
    os.link('a', 'b/c')
    self.sync()
    self.assertCalled(link=1, sync_meta=3)

    # Make 'b' look like it is on another device.
    self.called.clear()
    Stat.fake_dev('b')
    self.sync()

    # With 'b' still faked, a sync after this rename must exit, and no RPC
    # must have been issued, neither by this sync nor by the previous one.
    os.rename('a', 'c')
    self.assertRaises(SystemExit, self.sync)
    self.assertCalled()

    # Once the device is no longer faked, the rename gets synchronized.
    Stat.reset_dev('b')
    self.sync()
    self.assertCalled(link=1, removemany=1, sync_meta=1)

    # Add a local hard link and call removemany() directly on the remote
    # object (this call is not recorded in 'called'). The first sync is
    # expected to leave the trees different, the second one to fix them.
    self.called.clear()
    os.link('c', 'd')
    self.fssync.remote.removemany([b'b/c'])
    self.sync(synced=False)
    self.sync()
    self.assertCalled(link=3, sync_meta=2)

    # Rename 'b' locally while directory 'b' is deleted from the remote
    # tree.
    self.called.clear()
    os.rename('b', 'a')
    shutil.rmtree(self.fssync.remote.root + b'/b')
    self.sync()
    self.assertCalled(link=1, rename=1, sync_meta=2)

    # Give 'c' a fake inode number, kept in 'i' to be reused below.
    i = Stat.set_ino('c')
    self.sync()

    # Replace the regular file and its links by a symlink 'c' with a hard
    # link 'd', reporting the same fake inode number as the former file.
    self.called.clear()
    for x in 'a/c', 'c', 'd':
      os.remove(x)
    os.symlink('.', 'c')
    os.link('c', 'd')
    Stat.set_ino('c', i)
    self.sync()
    self.assertCalled(link=1, remove=1, removemany=1, symlink=1, sync_meta=3)

    # Same inode number again but another symlink target, plus a new symlink
    # 'a/a' that gets its own fake inode number ('i' is reassigned).
    self.called.clear()
    os.remove('c')
    os.remove('d')
    os.symlink('..', 'c')
    os.link('c', 'd')
    Stat.set_ino('c', i)
    os.symlink('.', 'a/a')
    i = Stat.set_ino('a/a')
    self.sync()
    self.assertCalled(link=1, readlink=1, remove=2, symlink=2, sync_meta=4)

    # Move a symlink and set the timestamps of the touched entries to 1,
    # without following symlinks.
    self.called.clear()
    os.rename('c', 'a/c')
    for x in '.', 'a', 'a/a', 'a/c':
      os.utime(x, (1, 1), follow_symlinks=False)
    self.sync()
    self.assertCalled(link=1, readlink=2, removemany=1, sync_meta=4)

    # New regular file with some content.
    self.called.clear()
    self.mkreg('a/b', b'foo')
    for x in 'a', 'a/b':
      os.utime(x, (1, 1))
    self.sync()
    self.assertCalled(check_data=1, sync_data=1, sync_meta=2)

    # Check --force for regular files ...
    # The content changes but neither the size nor the timestamps: a normal
    # sync is expected to issue no RPC and to leave the trees different.
    self.called.clear()
    with open('a/b', 'wb') as f:
      f.write(b'bar')
    os.utime('a/b', (1, 1))
    self.sync(synced=False)
    self.assertCalled()

    # 'force' is set on the instance; deleting it afterwards makes the class
    # default (False) visible again.
    self.fssync.force = True
    self.sync()
    del self.fssync.force
    self.assertCalled(check_data=1, readlink=3, sync_data=1, sync_meta=1)

    # ... and symlinks
    # Same idea: new target of the same length, same timestamps and same
    # fake inode number as before.
    os.remove('a/a')
    os.symlink('b', 'a/a')
    os.utime('a/a', (1, 1), follow_symlinks=False)
    Stat.set_ino('a/a', i)
    self.sync(synced=False)
    self.called.clear()
    self.fssync.force = True
    self.sync()
    del self.fssync.force
    self.assertCalled(check_data=1, readlink=3,
                      remove=1, symlink=1, sync_meta=3)

    # Rename an entry directly in the remote tree: sync alone does not
    # repair it. The first check() reports the paths below and the next one
    # reports nothing; after that, sync succeeds.
    self.called.clear()
    x = self.fssync.remote.root + b'/a/'
    os.rename(x + b'c', x + b'd')
    self.sync(synced=False)
    self.sync(synced=False) # retrying obviously does not help
    self.assertListEqual([b'a', b'a/c'], list(self.fssync.check(b'')))
    self.assertListEqual([], list(self.fssync.check(b'')))
    self.sync()
    self.assertCalled(link=1, removemany=1, sync_meta=1)

    # Extra directory, created with mode 0, in the remote tree.
    os.mkdir(x + b'd', 0)
    self.assertListEqual([b'a'], list(self.fssync.check(b'')))
    self.sync()

    # Make a sync fail in the middle: with sync_meta set to None on the
    # remote instance, calling it raises TypeError. The instance attribute
    # is deleted in all cases so that the real method is visible again.
    self.mkreg('a/b')
    self.sync()
    with open('a/b', 'wb') as f:
        f.write(b'foo')
    try:
        self.fssync.remote.sync_meta = None
        self.assertRaises(TypeError, self.sync)
        self.fssync.remote._close()
    finally:
        del self.fssync.remote.sync_meta
    # After the failure, a changed file must be synchronized normally.
    with open('a/b', 'wb') as f:
        f.write(b'bar')
    os.utime('a/b', (1, 1))
    self.called.clear()
    self.sync()
    self.assertCalled(check_data=1, sync_data=1, sync_meta=1)


# Run all the tests when this file is executed as a script.
if __name__ == '__main__':
  unittest.main()
