#!/usr/bin/python3

'''Driver package query/installation tool for Ubuntu'''

# (C) 2012 Canonical Ltd.
# Author: Martin Pitt <martin.pitt@ubuntu.com>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.

import argparse
import subprocess
import fnmatch
import sys
import os
import logging
import apt

import UbuntuDrivers.detect

# Optional override read from the environment; it is handed to the
# UbuntuDrivers.detect helpers as their sys_path argument (None when unset).
sys_path = os.environ.get('UBUNTU_DRIVERS_SYS_DIR')

# The root logger reports warnings and above by default.
logger = logging.getLogger()
logger.setLevel(logging.WARNING)

# Make sure that the PATH environment variable is set, so that the external
# tools started later on can be found even in a stripped-down environment.
# See LP: #1854472
if os.environ.get('PATH', '') == '':
    os.environ['PATH'] = '/sbin:/usr/sbin:/bin:/usr/bin'

def create_parser():
    '''Build the argparse parser for the command line interface.

    Every module-level name starting with "command_" becomes a valid value
    for the positional <command> argument, with underscores turned into
    dashes. The docstring of each such function is listed in the
    "Available commands" epilog, except for deprecated commands, which are
    still accepted but no longer advertised.

    Returns the configured argparse.ArgumentParser.
    '''
    deprecated = ('autoinstall',)
    prefix = 'command_'

    # collect the command names and one help line per advertised command
    commands = []
    help_lines = ['Available commands:']
    for symbol, obj in globals().items():
        if not symbol.startswith(prefix):
            continue
        name = symbol[len(prefix):].replace('_', '-')
        commands.append(name)
        if fnmatch.filter(deprecated, name):
            # deprecated: selectable, but left out of the help text
            continue
        help_lines.append('   %s: %s' % (name, obj.__doc__))
    command_help = '\n'.join(help_lines)

    parser = argparse.ArgumentParser(
        description='List/install driver packages for Ubuntu.',
        epilog=command_help,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument(
        'command', metavar='<command>', nargs='?', choices=commands,
        help='See the "Available commands" section below')
    parser.add_argument(
        '--package-list', metavar='PATH',
        help='Create file with list of installed packages (in install mode)')

    # Python 3.5.2 in Xenial will crash when matching square brackets, so
    # the bracketed metavar is only used from Python 3.6 on.
    if sys.version_info.major == 3 and sys.version_info.minor >= 6:
        gpgpu_metavar = 'driver[:version][,driver[:version]]'
    else:
        gpgpu_metavar = 'driver{:version}{,driver{:version}}'
    parser.add_argument(
        '--gpgpu', const='default', nargs='?', metavar=gpgpu_metavar,
        help='Install GPGPU drivers')

    parser.add_argument(
        '--free-only', action='store_true', default=False,
        help='Only consider free packages')
    parser.add_argument(
        '--no-oem', action='store_false', default=True,
        dest='install_oem_meta',
        help='Do not include OEM enablement packages (these enable an external archive)')
    return parser

def command_list(args):
    '''Show all driver packages which apply to the current system.'''
    # NOTE: the docstring above is reused verbatim as this command's help
    # text by create_parser(), so further details are kept in comments.
    #
    # Reads args.free_only and args.install_oem_meta; always returns 0.
    detect = UbuntuDrivers.detect
    packages = detect.system_driver_packages(
        sys_path=sys_path, freeonly=args.free_only,
        include_oem=args.install_oem_meta)

    cache = apt.Cache()
    for package in packages:
        try:
            provider = detect.get_linux_modules_metapackage(cache, package)
        except KeyError:
            # the lookup failed: fall back to the bare name
            print(package)
            continue

        # a dkms package without a modules metapackage provides the kernel
        # modules itself
        if not provider and 'dkms' in package:
            provider = package

        if provider:
            print('%s, (kernel modules provided by %s)' % (package, provider))
        else:
            print(package)

    return 0

def command_list_oem(args):
    '''Show all OEM enablement packages which apply to this system'''
    # NOTE: the docstring above is reused verbatim as this command's help
    # text by create_parser(), so further details are kept in comments.
    #
    # Reads args.install_oem_meta (False when --no-oem was given) and
    # args.package_list; always returns 0.
    if not args.install_oem_meta:
        # OEM packages were explicitly excluded: nothing to show
        return 0

    oem_packages = UbuntuDrivers.detect.system_device_specific_metapackages(
        sys_path=sys_path, include_oem=args.install_oem_meta)
    if not oem_packages:
        return 0

    listing = '\n'.join(oem_packages)
    print(listing)

    # optionally append the same names, one per line, to the package list
    if args.package_list:
        with open(args.package_list, 'a') as f:
            f.write(listing + '\n')

    return 0


def list_gpgpu(args):
    '''Show all GPGPU driver packages which apply to the current system.

    For every detected GPGPU driver that has a metapackage, print the
    metapackage name. When a matching linux modules metapackage is found,
    it is named on the same line as the provider of the kernel modules;
    otherwise only the metapackage name is printed, as command_list() does.

    args is the parsed command line namespace; it is not consulted here.
    Always returns 0.
    '''
    cache = apt.Cache()
    packages = UbuntuDrivers.detect.system_gpgpu_driver_packages(cache, sys_path)
    for info in packages.values():
        candidate = info['metapackage']
        if not candidate:
            continue

        # Same tolerance as command_list(): a failed lookup must not abort
        # the whole listing.
        try:
            linux_modules = UbuntuDrivers.detect.get_linux_modules_metapackage(
                cache, candidate)
        except KeyError:
            linux_modules = None

        if linux_modules:
            print('%s, (kernel modules provided by %s)' % (candidate, linux_modules))
        else:
            # avoid the misleading "provided by None"
            print(candidate)

    return 0

def command_devices(args):
    '''Show all devices which need drivers, and which packages apply to them.'''
    # NOTE: the docstring above is reused verbatim as this command's help
    # text by create_parser(), so further details are kept in comments.
    #
    # Reads args.free_only. Prints one section per device: first every
    # attribute except 'drivers', then one "driver" line per package with
    # its origin/licence flags. Returns None.
    devices = UbuntuDrivers.detect.system_device_drivers(
        sys_path=sys_path, freeonly=args.free_only)

    for device, info in devices.items():
        print('== %s ==' % device)

        for key, value in info.items():
            if key != 'drivers':
                print('%-9s: %s' % (key, value))

        for pkg, pkginfo in info['drivers'].items():
            # origin and licence are always shown, the rest only when set
            flags = [
                'distro' if pkginfo['from_distro'] else 'third-party',
                'free' if pkginfo['free'] else 'non-free',
            ]
            if pkginfo.get('builtin'):
                flags.append('builtin')
            if pkginfo.get('recommended'):
                flags.append('recommended')
            info_str = ''.join(' ' + flag for flag in flags)
            print('%-9s: %s -%s' % ('driver', pkg, info_str))

        print('')

def command_install(args):
    '''Install drivers that are appropriate for your hardware.'''
    # NOTE: the docstring above is reused verbatim as this command's help
    # text by create_parser(), so further details are kept in comments.
    #
    # Reads args.free_only, args.install_oem_meta and args.package_list.
    # Returns None when there is nothing to do (no matching drivers, or all
    # of them are already installed); otherwise the exit status of the last
    # apt invocation that ran (non-zero as soon as one of them fails).

    cache = apt.Cache()

    packages = UbuntuDrivers.detect.system_driver_packages(
        cache, sys_path, freeonly=args.free_only,
        include_oem=args.install_oem_meta)
    packages = UbuntuDrivers.detect.auto_install_filter(packages)
    if not packages:
        print('No drivers found for installation.')
        return

    # ignore packages which are already installed
    to_install = []
    for p in packages:
        if cache[p].installed:
            continue
        to_install.append(p)
        # Add the matching linux modules package when available
        try:
            modules_package = UbuntuDrivers.detect.get_linux_modules_metapackage(cache, p)
            if modules_package and not cache[modules_package].installed:
                to_install.append(modules_package)
        except KeyError:
            pass

    # Test the filtered list here: `packages` is known to be non-empty at
    # this point, so checking it again would never trigger and apt-get
    # would be started without any package name.
    if not to_install:
        print('All the available drivers are already installed.')
        return

    ret = subprocess.call(['apt-get', 'install', '-o',
        'DPkg::options::=--force-confnew', '-y'] + to_install)
    if ret != 0:
        return ret

    # create package list
    if args.package_list:
        with open(args.package_list, 'a') as f:
            f.write('\n'.join(to_install))
            f.write('\n')

    # For each OEM meta package just installed, run "apt update" restricted
    # to the sources list named after that package.
    oem_meta_to_install = fnmatch.filter(to_install, 'oem-*-meta')
    for package_to_install in oem_meta_to_install:
        sources_list_path = os.path.join(os.path.sep,
                                         'etc',
                                         'apt',
                                         'sources.list.d',
                                         f'{package_to_install}.list')

        update_ret = subprocess.call(['apt',
                                      '-o', f'Dir::Etc::SourceList={sources_list_path}',
                                      '-o', 'Dir::Etc::SourceParts=/dev/null',
                                      '--no-list-cleanup',
                                      'update'])

        if update_ret != 0:
            return update_ret

    # All updates completed successfully, now let's upgrade the packages
    if oem_meta_to_install:
        ret = subprocess.call(['apt',
                               'install',
                               '-o', 'DPkg::Options::=--force-confnew',
                               '-y'] + oem_meta_to_install)

    return ret


def command_autoinstall(args):
    '''Install drivers that are appropriate for automatic installation. [DEPRECATED]'''
    # Deprecated alias: everything is delegated to command_install(), whose
    # result is passed through unchanged.
    status = command_install(args)
    return status

def install_gpgpu(args):
    '''Install GPGPU drivers that are appropriate for your hardware.

    args.gpgpu selects the drivers:
      * 'default' (plain --gpgpu without a value);
      * a single driver, e.g. "390" or "nvidia:390";
      * several drivers, e.g. "nvidia:390,amdgpu".

    Returns 0 when every selected driver is already installed. When no
    driver matches, returns 0 for the 'default' selection and 1 for an
    explicit one. Otherwise returns the exit status of apt-get. On success
    the installed package names are appended to args.package_list, if set.
    '''
    # Finding nothing is only an error when specific drivers were requested.
    not_found_exit_status = 0 if args.gpgpu == 'default' else 1

    cache = apt.Cache()

    packages = UbuntuDrivers.detect.system_gpgpu_driver_packages(cache, sys_path)
    packages = UbuntuDrivers.detect.gpgpu_install_filter(packages, args.gpgpu)
    if not packages:
        print('No drivers found for installation.')
        return not_found_exit_status

    # ignore packages which are already installed
    to_install = []
    for p in packages:
        candidate = packages[p].get('metapackage')
        print(candidate)
        if not candidate:
            continue
        if not cache[candidate].installed:
            to_install.append(candidate)

        # Add the matching linux modules package. This is done for every
        # driver, not only for the one that happens to be iterated last.
        modules_package = UbuntuDrivers.detect.get_linux_modules_metapackage(cache, candidate)
        print(modules_package)
        if (modules_package and modules_package not in to_install
                and not cache[modules_package].installed):
            to_install.append(modules_package)

    if not to_install:
        print('All the available drivers are already installed.')
        return 0

    ret = subprocess.call(['apt-get', 'install', '-o',
        'DPkg::options::=--force-confnew', '-y'] + to_install)

    # create package list
    if ret == 0 and args.package_list:
        with open(args.package_list, 'a') as f:
            f.write('\n'.join(to_install))
            f.write('\n')

    return ret

def command_debug(args):
    '''Print all available information and debug data about drivers.'''
    # NOTE: the docstring above is reused verbatim as this command's help
    # text by create_parser(), so further details are kept in comments.
    #
    # Reads args.free_only and args.install_oem_meta. Returns None.

    # basicConfig() returns None, so its result is deliberately not bound
    # to a name (a local "logger" would only shadow the module-level one).
    logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)

    print('=== log messages from detection ===')
    aliases = UbuntuDrivers.detect.system_modaliases()
    cache = apt.Cache()
    packages = UbuntuDrivers.detect.system_driver_packages(
        cache, sys_path, freeonly=args.free_only, include_oem=args.install_oem_meta)
    auto_packages = UbuntuDrivers.detect.auto_install_filter(packages)

    print('=== modaliases in the system ===')
    for alias in aliases:
        print(alias)

    print('=== matching driver packages ===')
    for package, info in packages.items():
        p = cache[package]
        # installed/candidate have no .version when the package is not
        # installed or has no candidate
        try:
            inst = p.installed.version
        except AttributeError:
            inst = '<none>'
        try:
            cand = p.candidate.version
        except AttributeError:
            cand = '<none>'
        auto = ' (auto-install)' if package in auto_packages else ''

        info_str = '  [distro]' if info['from_distro'] else '  [third party]'
        info_str += '  free' if info['free'] else '  non-free'
        # optional details, shown under these labels when present
        for key, label in (('modalias', 'modalias'), ('syspath', 'path'),
                           ('vendor', 'vendor'), ('model', 'model')):
            if key in info:
                info_str += '  ' + label + ': ' + info[key]

        print('%s: installed: %s   available: %s%s%s ' % (package, inst, cand, auto,  info_str))

#
# main
#
parser = create_parser()
args = parser.parse_args()

# "install" and "list" switch to their GPGPU variants when --gpgpu is given;
# every sys.exit() below ends the script, so the checks need no else branch.
if args.gpgpu is not None and args.command == 'install':
    sys.exit(install_gpgpu(args))

if args.gpgpu is not None and args.command == 'list':
    sys.exit(list_gpgpu(args))

if not args.command:
    # no command given: show the usage and signal the error
    parser.print_help()
    sys.exit(1)

# run the function corresponding to the specified command
command = globals()['command_' + args.command.replace('-', '_')]
sys.exit(command(args))
