#!/usr/bin/env python3
import ctypes
import ctypes.util
import errno
import sys
import os
import fcntl
import termios
import time

# Namespace flags passed to unshare(), values taken from <sched.h>
CLONE_NEWNS = 0x00020000
CLONE_NEWPID = 0x20000000

# Flags passed to mount(), values taken from <sys/mount.h>
MS_REC = 16384
MS_PRIVATE = 1 << 18
MS_SLAVE = 1 << 19

# Handle to the C library; use_errno=True lets ctypes.get_errno() report
# the errno left behind by a failed call.
_libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)


def _bind_libc(name, argtypes):
    """
    Look up the libc function `name`, declare its argument types and an
    int return type, and return the configured foreign function.

    Raises OSError(EINVAL) when libc does not export the symbol.
    """
    try:
        func = getattr(_libc, name)
    except AttributeError:
        raise OSError(errno.EINVAL,
                      name + " is not supported on this platform")
    func.argtypes = argtypes
    func.restype = ctypes.c_int
    return func


_unshare = _bind_libc("unshare", [ctypes.c_int])
_setns = _bind_libc("setns", [ctypes.c_int, ctypes.c_int])
_mount = _bind_libc("mount", [
    ctypes.c_char_p,
    ctypes.c_char_p,
    ctypes.c_char_p,
    ctypes.c_ulong,
    ctypes.c_void_p,
])


def _mount_new_proc():
    """
    Remount / with MS_SLAVE | MS_REC and then mount a proc filesystem
    on /proc.

    Raises OSError carrying the libc errno (with its symbolic name as
    the message) as soon as either mount() call fails.
    """
    mount_calls = (
        # Change the propagation type of / first ...
        (None, b"/", None, MS_SLAVE | MS_REC, None),
        # ... then mount the new procfs instance.
        (b'proc', b'/proc', b'proc', 0, None),
    )
    for call_args in mount_calls:
        if _mount(*call_args):
            err = ctypes.get_errno()
            raise OSError(err, errno.errorcode[err])


def _wait_for_process_status(criu_pid):
    """
    Reap children until the one with pid `criu_pid` is collected and
    translate its wait status into a single integer.

    Returns the terminating signal number, the exit status or the stop
    signal number, depending on how the child ended. Returns -251 when
    os.wait() fails with OSError (for example, no children are left).
    """
    while True:
        try:
            reaped_pid, wait_status = os.wait()
        except OSError:
            return -251

        if reaped_pid != criu_pid:
            # Some other child was reaped; keep waiting for CRIU.
            continue

        # Modelled on os.waitstatus_to_exitcode() (Python 3.9+),
        # spelled out by hand for compatibility with older versions
        # of Python.
        if os.WIFSIGNALED(wait_status):
            return os.WTERMSIG(wait_status)
        if os.WIFEXITED(wait_status):
            return os.WEXITSTATUS(wait_status)
        if os.WIFSTOPPED(wait_status):
            return os.WSTOPSIG(wait_status)
        raise Exception("CRIU was terminated by an "
                        "unidentified reason")


def run_criu(args):
    """
    Replace the current process with the CRIU binary.

    `args` is the CRIU command line without argv[0]. When it contains
    "--criu-binary PATH", that pair is stripped and PATH is executed
    directly; otherwise "criu" is looked up through PATH. In both cases
    argv[0] of the new program is "criu". The caller's list is left
    untouched.

    On success this function never returns. Raises OSError(ENOENT) when
    --criu-binary has no value or when the binary cannot be found; any
    other exec failure propagates as raised by os.execv()/os.execvp().
    """
    print(sys.argv)

    # Work on a copy so the caller's list is not modified.
    argv = list(args)

    if "--criu-binary" in argv:
        opt_index = argv.index("--criu-binary")
        try:
            path = argv[opt_index + 1]
        except IndexError:
            raise OSError(errno.ENOENT,
                          "--criu-binary missing argument") from None
        del argv[opt_index:opt_index + 2]
        argv.insert(0, "criu")
        try:
            os.execv(path, argv)
        except FileNotFoundError:
            # Report the missing binary itself rather than blaming the
            # option syntax.
            raise OSError(errno.ENOENT, "No such command", path)
    else:
        argv.insert(0, "criu")
        try:
            os.execvp("criu", argv)
        except FileNotFoundError:
            raise OSError(errno.ENOENT, "No such command", "criu")


# pidns_holder creates a process that is reparented to init.
#
# The init process of a pid namespace can exit once it has no child
# processes left, and in that case its pidns is destroyed. CRIU dump runs
# inside the target pid namespace and kills the dumped processes at the end.
# We therefore create a holder process to make sure that the pid namespace
# is not destroyed before criu exits.
def pidns_holder():
    """
    Double-fork a process that blocks on a pipe for as long as the
    calling process is alive.

    The intermediate child exits right away and is reaped here, so the
    grandchild is left without its original parent. The grandchild
    closes its copy of the write end and then blocks reading the pipe.
    """
    read_end, write_end = os.pipe()

    first_child = os.fork()
    if first_child != 0:
        # Original process: collect the short-lived intermediate child.
        os.waitpid(first_child, 0)
        return

    if os.fork() == 0:
        os.close(write_end)
        # The write end stays open in the original process; the kernel
        # closes it when that process exits, which ends this read.
        os.read(read_end, 1)
    sys.exit(0)


def wrap_restore():
    """
    Run "criu restore" in a child that lives in freshly unshared pid and
    mount namespaces.

    The CRIU command line is taken from sys.argv[1:]. The options -d,
    --restore-detached and --pidfile are removed from it before CRIU is
    started and are handled by this function instead.

    Returns 0 right away when a detached restore was requested, otherwise
    the value of _wait_for_process_status() for the CRIU child.

    Raises OSError for --restore-sibling, for options that lack their
    value, for failing unshare() calls and, in the child, when
    --shell-job is used while stdin is not a tty.
    """
    restore_args = sys.argv[1:]
    if '--restore-sibling' in restore_args:
        raise OSError(errno.EINVAL, "--restore-sibling is not supported")

    # Unshare pid namespace
    # (per unshare(2) this affects children forked from now on, not the
    # calling process itself)
    if _unshare(CLONE_NEWPID) != 0:
        _errno = ctypes.get_errno()
        raise OSError(_errno, errno.errorcode[_errno])

    # Remember whether a detached restore was requested and drop the
    # option so that it is not passed on to CRIU.
    restore_detached = False
    if '-d' in restore_args:
        restore_detached = True
        restore_args.remove('-d')
    if '--restore-detached' in restore_args:
        restore_detached = True
        restore_args.remove('--restore-detached')

    # --pidfile is also consumed here; the file is written by the parent
    # branch further down.
    restore_pidfile = None
    if '--pidfile' in restore_args:
        try:
            opt_index = restore_args.index('--pidfile')
            restore_pidfile = restore_args[opt_index + 1]
            del restore_args[opt_index:opt_index + 2]
        except (ValueError, IndexError, FileNotFoundError):
            raise OSError(errno.ENOENT, "--pidfile missing argument")

        # A relative pidfile path is resolved against the value of the
        # first of these directory options found on the command line.
        if not restore_pidfile.startswith('/'):
            for base_dir_opt in ['--work-dir', '-W', '--images-dir', '-D']:
                if base_dir_opt in restore_args:
                    try:
                        opt_index = restore_args.index(base_dir_opt)
                        restore_pidfile = os.path.join(restore_args[opt_index + 1], restore_pidfile)
                        break
                    except (ValueError, IndexError, FileNotFoundError):
                        raise OSError(errno.ENOENT, base_dir_opt + " missing argument")

    criu_pid = os.fork()
    if criu_pid == 0:
        # Child: set up the mount namespace, session and /proc, then
        # exec CRIU.
        # Unshare mount namespace
        if _unshare(CLONE_NEWNS) != 0:
            _errno = ctypes.get_errno()
            raise OSError(_errno, errno.errorcode[_errno])

        os.setsid()
        # Set stdin tty to be a controlling tty of our new session, this is
        # required by --shell-job option, as for it CRIU would try to set a
        # process group of restored root task to be a foreground group on the
        # terminal.
        if '--shell-job' in restore_args or '-j' in restore_args:
            if os.isatty(sys.stdin.fileno()):
                fcntl.ioctl(sys.stdin.fileno(), termios.TIOCSCTTY, 1)
            else:
                raise OSError(errno.EINVAL, 'The stdin is not a tty for a --shell-job')

        _mount_new_proc()
        # run_criu() execs CRIU or raises.
        run_criu(restore_args)

    if restore_pidfile:
        restored_pid = None
        retry = 5

        # Poll the children list of the CRIU process, at most 5 times
        # with a one second pause after each empty read.
        while not restored_pid and retry:
            with open('/proc/%d/task/%d/children' % (criu_pid, criu_pid)) as f:
                line = f.readline().strip()
                if len(line):
                    # The whole stripped line is kept as the pid string.
                    restored_pid = line
                    break
            retry -= 1
            time.sleep(1)

        if restored_pid:
            with open(restore_pidfile, 'w+') as f:
                f.write(restored_pid)
        else:
            print("Warn: Search of restored pid for --pidfile option timeouted")

    # Detached restore: report success without waiting for CRIU.
    if restore_detached:
        return 0

    return _wait_for_process_status(criu_pid)


def get_varg(args):
    """
    Find the first element of sys.argv (program name excluded) that is
    one of the option names in `args` and return its value.

    Returns a tuple (value, index_of_value_in_sys_argv), or (None, None)
    when no such option is present or it is the last word on the command
    line and therefore has no value.
    """
    argv = sys.argv
    # The final word can never be an option followed by a value, so it is
    # left out of the scan.
    for position, word in enumerate(argv[1:-1], start=1):
        if word in args:
            value_idx = position + 1
            return (argv[value_idx], value_idx)

    return (None, None)


def _set_namespace(fd):
    """
    Move the calling process into the namespace that the descriptor `fd`
    refers to, using setns() with a namespace type of 0.

    Raises OSError with the libc errno when setns() fails.
    """
    if _setns(fd, 0) == 0:
        return
    err = ctypes.get_errno()
    raise OSError(err, errno.errorcode[err])


def is_my_namespace(fd, ns):
    """
    Tell whether the descriptor `fd` has the same inode number as
    /proc/self/ns/<ns>, i.e. whether it refers to the namespace of type
    `ns` that the calling process is currently in.
    """
    own_inode = os.stat('/proc/self/ns/%s' % ns).st_ino
    other_inode = os.fstat(fd).st_ino
    return own_inode == other_inode


def set_pidns(tpid, pid_idx):
    """
    Join the pid namespace of the task `tpid`.

    `tpid` is the pid as a string, as given on the command line, and
    `pid_idx` is its position in sys.argv. When the task lives in a
    different pid namespace, sys.argv[pid_idx] is rewritten to the pid
    the task has inside that namespace before the namespace is joined.
    Nothing is changed when the task already shares our pid namespace.

    Raises OSError(ESRCH) when the NSpid entry of /proc/<tpid>/status
    does not start with `tpid`, OSError(ENOENT) when there is no NSpid
    entry, and whatever os.open()/setns() raise on failure.
    """
    ns_fd = os.open('/proc/%s/ns/pid' % tpid, os.O_RDONLY)
    try:
        if is_my_namespace(ns_fd, "pid"):
            return

        with open('/proc/%s/status' % tpid) as status:
            for line in status:
                if not line.startswith('NSpid:'):
                    continue
                ls = line.split()
                if ls[1] != tpid:
                    raise OSError(errno.ESRCH, 'No such pid')

                # NSpid lists the pid in each nested namespace, outermost
                # first; the last field is the pid inside the task's own
                # (innermost) namespace, which is the one joined below.
                inner_pid = ls[-1]
                print('Replace pid {} with {}'.format(tpid, inner_pid))
                sys.argv[pid_idx] = inner_pid
                break
            else:
                raise OSError(errno.ENOENT, 'Cannot find NSpid field in proc')

        _set_namespace(ns_fd)
    finally:
        # Release the namespace descriptor on every path, including errors.
        os.close(ns_fd)


def set_mntns(tpid):
    """
    Join the mount namespace of the task `tpid`.

    Nothing is done when the task already shares our mount namespace.
    Otherwise the namespace is joined, the previous working directory is
    re-entered by path, and / and . are verified to be the very same
    files (device and inode) as before the switch.

    Raises OSError(EXDEV) when / or . differ in the target namespace,
    and whatever os.open()/os.stat()/os.chdir()/setns() raise on failure.
    """
    def steq(st, nst):
        """Return True if both stat results name the same file."""
        return (st.st_dev, st.st_ino) == (nst.st_dev, nst.st_ino)

    ns_fd = os.open('/proc/%s/ns/mnt' % tpid, os.O_RDONLY)
    try:
        if is_my_namespace(ns_fd, "mnt"):
            return

        # Snapshot the current view before switching namespaces.
        root_st = os.stat('/')
        cwd_st = os.stat('.')
        cwd_path = os.path.realpath('.')

        _set_namespace(ns_fd)

        # Re-enter the working directory by path inside the new namespace
        # and compare what we see there with the snapshot.
        os.chdir(cwd_path)
        root_nst = os.stat('/')
        cwd_nst = os.stat('.')

        if not steq(root_st, root_nst):
            raise OSError(errno.EXDEV, 'Target ns / is not as current')
        if not steq(cwd_st, cwd_nst):
            raise OSError(errno.EXDEV, 'Target ns . is not as current')
    finally:
        # Release the namespace descriptor on every path, including errors.
        os.close(ns_fd)


def wrap_dump():
    """
    Run a CRIU dump/pre-dump from inside the pid and mount namespaces of
    the task named by -t/--tree.

    Joins the target's namespaces, starts the pid namespace holder,
    forks a child that execs CRIU with sys.argv[1:], and returns the
    value of _wait_for_process_status() for that child.

    Raises OSError(EINVAL) when no -t/--tree value is present.
    """
    target_pid, target_idx = get_varg(('-t', '--tree'))
    if target_pid is None:
        raise OSError(errno.EINVAL, 'No --tree option given')

    # Namespace switches happen in this process, before CRIU is forked.
    set_pidns(target_pid, target_idx)
    set_mntns(target_pid)

    pidns_holder()

    child = os.fork()
    if child == 0:
        # sys.argv is read here, after set_pidns() may have rewritten it.
        run_criu(sys.argv[1:])
    return _wait_for_process_status(child)


def show_usage():
    """
    Print the command-line synopsis and the list of supported commands,
    using sys.argv[0] as the program name.
    """
    prog = sys.argv[0]
    usage_lines = (
        "",
        "Usage:",
        "  {0} dump|pre-dump -t PID [<options>]".format(prog),
        "  {0} restore [<options>]".format(prog),
        "",
        "Commands:",
        "  dump           checkpoint a process/tree identified by pid",
        "  pre-dump       pre-dump task(s) minimizing their frozen time",
        "  restore        restore a process/tree",
        "  check          checks whether the kernel support is up-to-date",
        "",
    )
    print("\n".join(usage_lines))


# Script entry point: dispatch on the first command-line word and use the
# handler's result as the process exit status.
if __name__ == "__main__":
    if len(sys.argv) == 1:
        show_usage()
        # sys.exit() instead of the site-provided exit() helper, which is
        # meant for interactive use and may be absent (e.g. python -S).
        sys.exit(1)

    action = sys.argv[1]
    if action == 'restore':
        res = wrap_restore()
    elif action in ['dump', 'pre-dump']:
        res = wrap_dump()
    elif action == 'check':
        # run_criu() replaces this process with CRIU or raises, so control
        # does not reach the sys.exit() below on this branch.
        run_criu(sys.argv[1:])
    else:
        print('Unsupported action {} for criu-ns'.format(action))
        res = -1

    sys.exit(res)
