#!/usr/bin/python3
#
#       prime-select
#
#       Copyright 2013 Canonical Ltd.
#       Author: Alberto Milone <alberto.milone@canonical.com>
#
#       Script to switch between NVIDIA and Intel graphics driver libraries.
#
#       Usage:
#           prime-select   nvidia|intel|on-demand|query
#           nvidia:    switches to NVIDIA's version of libGL.so
#           on-demand: load NVIDIA driver, and on-demand for others
#           intel: switches to the open-source version of libGL.so
#           query: checks which version is currently active and writes
#                  "nvidia", "intel", "on-demand" or "unknown" to the
#                  standard output
#
#       Permission is hereby granted, free of charge, to any person
#       obtaining a copy of this software and associated documentation
#       files (the "Software"), to deal in the Software without
#       restriction, including without limitation the rights to use,
#       copy, modify, merge, publish, distribute, sublicense, and/or sell
#       copies of the Software, and to permit persons to whom the
#       Software is furnished to do so, subject to the following
#       conditions:
#
#       The above copyright notice and this permission notice shall be
#       included in all copies or substantial portions of the Software.
#
#       THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
#       EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
#       OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
#       NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
#       HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
#       WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
#       FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
#       OTHER DEALINGS IN THE SOFTWARE.

import glob
import os
import sys
import re
import subprocess
import shutil

from copy import deepcopy
from subprocess import Popen, PIPE, CalledProcessError


class Switcher(object):

    def __init__(self):
        """Record the filesystem locations used by the switcher.

        No file is opened or modified here; the attributes are only paths
        (and one line prefix) consumed by the other methods.
        """
        # GRUB defaults file and the prefix of its kernel command line entry
        self._grub_path = '/etc/default/grub'
        self._grub_cmdline_start = 'GRUB_CMDLINE_LINUX_DEFAULT='

        # File holding the currently selected power profile
        self._power_profile_path = '/etc/prime-discrete'

        # modprobe configuration files
        self._blacklist_file = '/lib/modprobe.d/blacklist-nvidia.conf'
        self._old_blacklist_file = '/etc/modprobe.d/blacklist-nvidia.conf'
        self._nvidia_kms_file = '/lib/modprobe.d/nvidia-kms.conf'

        # udev power management rule files
        self._udev_rule_file = '/lib/udev/rules.d/50-pm-nvidia.rules'
        self._old_udev_rule_file = '/lib/udev/rules.d/80-pm-nvidia.rules'

        # GDM configuration file
        self._gdm_conf_file = '/etc/gdm3/custom.conf'

    def _get_profile(self):
        """Return the profile recorded in the power profile file.

        Returns:
            'nvidia' if the file contains 'on', 'on-demand' if it contains
            'on-demand', 'intel' for any other content, and 'unknown' if
            the file cannot be opened.
        """
        try:
            settings = open(self._power_profile_path, 'r')
        except OSError:
            # Missing or unreadable file: no profile has been recorded
            return 'unknown'

        # The context manager guarantees the handle is closed
        with settings:
            config = settings.read().strip()

        if config == 'on':
            return 'nvidia'
        if config == 'on-demand':
            return 'on-demand'
        return 'intel'

    def print_profile(self):
        """Print the current profile to standard output.

        Returns:
            True when a profile was printed, False when the current
            profile is 'unknown' (nothing is printed in that case).
        """
        profile = self._get_profile()
        known = profile != 'unknown'
        if known:
            print('%s' % profile)
        return known

    def _write_profile(self, profile):
        """Persist *profile* to the power profile file.

        The file stores the NVIDIA power state rather than the profile
        name: 'intel' -> 'off', 'on-demand' -> 'on-demand',
        'nvidia' -> 'on'.

        Args:
            profile: one of 'intel', 'on-demand' or 'nvidia'.

        Returns:
            True if the setting was written, False if *profile* is not
            recognised (the file is left untouched in that case).
        """
        power_states = {
            'intel': 'off',
            'on-demand': 'on-demand',
            'nvidia': 'on',
        }
        nvidia_power = power_states.get(profile)
        if nvidia_power is None:
            return False

        # Write the settings to the file; "with" closes it even on error
        with open(self._power_profile_path, 'w') as settings:
            settings.write('%s\n' % nvidia_power)
        return True

    def _has_intel_gpu(self, path='/var/lib/ubuntu-drivers-common/last_gfx_boot'):
        """Tell whether the recorded graphics boot data mentions '8086'.

        '8086' is the PCI vendor ID of Intel; its presence anywhere in the
        file is taken as evidence of an Intel GPU.

        Args:
            path: file to inspect. Defaults to the record kept under
                /var/lib/ubuntu-drivers-common.

        Returns:
            True if the file can be read and contains '8086', otherwise
            False (including when the file does not exist).
        """
        try:
            with open(path, 'r') as f:
                return '8086' in f.read()
        except OSError:
            # No readable record: treat as "no Intel GPU detected"
            return False

    def enable_profile(self, profile):
        """Switch the system to *profile*.

        Args:
            profile: 'nvidia', 'on-demand' or 'intel'.

        Returns:
            True if the profile is (or already was) selected, False if
            *profile* is not a recognised profile name. Nothing on disk
            is touched when False is returned.
        """
        # Reject unknown names before any file is modified; otherwise they
        # would be handled like 'intel' without a profile ever being written.
        if profile not in ('nvidia', 'on-demand', 'intel'):
            sys.stderr.write('Error: unknown profile "%s"\n' % (profile))
            return False

        current_profile = self._get_profile()

        if profile == current_profile:
            # No need to do anything if we're already using the desired
            # profile
            sys.stdout.write('Info: the %s profile is already set\n' % (profile))
            return True

        sys.stdout.write('Info: selecting the %s profile\n' % (profile))

        self._backup_grub_config()

        if profile == 'nvidia':
            # Always allow enabling nvidia
            # (No need to check if nvidia is available)
            self._enable_nvidia()
        elif profile == 'on-demand':
            # Keep the NVIDIA kernel modules loadable
            self._disable_nvidia(keep_nvidia_modules=True)
        else:
            # NOTE: a check that the installed packages support PRIME used
            # to be performed here; it is currently disabled.
            self._disable_nvidia()

        # Write the settings to the config file
        self._write_profile(profile)

        return True

    def _create_pm_udev_rule(self, keep_nvidia=True):
        """Write the udev power management rule file for NVIDIA devices.

        The rules remove the NVIDIA USB xHCI, USB Type-C UCSI and audio
        PCI functions and toggle runtime PM on driver bind/unbind. Rules
        that remove the NVIDIA VGA/3D controller itself are added only if
        *keep_nvidia* is false and an Intel GPU is detected.

        Args:
            keep_nvidia: when true, never emit the rules that remove the
                NVIDIA VGA/3D controller.
        """
        udev_rule_stub = '''# Remove NVIDIA USB xHCI Host Controller devices, if present
ACTION=="add", SUBSYSTEM=="pci", ATTR{vendor}=="0x10de", ATTR{class}=="0x0c0330", ATTR{remove}="1"

# Remove NVIDIA USB Type-C UCSI devices, if present
ACTION=="add", SUBSYSTEM=="pci", ATTR{vendor}=="0x10de", ATTR{class}=="0x0c8000", ATTR{remove}="1"

# Remove NVIDIA Audio devices, if present
ACTION=="add", SUBSYSTEM=="pci", ATTR{vendor}=="0x10de", ATTR{class}=="0x040300", ATTR{remove}="1"
%s
# Enable runtime PM for NVIDIA VGA/3D controller devices on driver bind
ACTION=="bind", SUBSYSTEM=="pci", ATTR{vendor}=="0x10de", ATTR{class}=="0x030000", TEST=="power/control", ATTR{power/control}="auto"
ACTION=="bind", SUBSYSTEM=="pci", ATTR{vendor}=="0x10de", ATTR{class}=="0x030200", TEST=="power/control", ATTR{power/control}="auto"

# Disable runtime PM for NVIDIA VGA/3D controller devices on driver unbind
ACTION=="unbind", SUBSYSTEM=="pci", ATTR{vendor}=="0x10de", ATTR{class}=="0x030000", TEST=="power/control", ATTR{power/control}="on"
ACTION=="unbind", SUBSYSTEM=="pci", ATTR{vendor}=="0x10de", ATTR{class}=="0x030200", TEST=="power/control", ATTR{power/control}="on"'''

        disable_nvidia_stub = '''
# Remove NVIDIA VGA/3D controller
ACTION=="add", SUBSYSTEM=="pci", ATTR{vendor}=="0x10de", ATTR{class}=="0x030000", ATTR{remove}="1"
ACTION=="add", SUBSYSTEM=="pci", ATTR{vendor}=="0x10de", ATTR{class}=="0x038000", ATTR{remove}="1"
'''
        # Only remove the NVIDIA GPU when asked to AND an Intel GPU is
        # detected.
        if keep_nvidia or not self._has_intel_gpu():
            complete_stub = ''
        else:
            complete_stub = disable_nvidia_stub
        udev_rule = udev_rule_stub % (complete_stub)

        # Best effort: the old rule file usually does not exist
        try:
            os.unlink(self._old_udev_rule_file)
        except OSError:
            pass

        with open(self._udev_rule_file, 'w') as rule_fd:
            rule_fd.write('%s\n' % (udev_rule))

    def _disable_nvidia(self, keep_nvidia_modules=False):
        """Apply the configuration used when NVIDIA is not the primary GPU.

        Removes the old blacklist and old udev rule files, then either
        removes the blacklist file (*keep_nvidia_modules* true) or writes
        it, and finally regenerates the udev power management rule.

        Args:
            keep_nvidia_modules: when true, the NVIDIA kernel modules are
                not blacklisted.
        """
        stale_files = [self._old_blacklist_file, self._old_udev_rule_file]
        if keep_nvidia_modules:
            stale_files.append(self._blacklist_file)

        for path in stale_files:
            # Best effort: these files are frequently absent already
            try:
                os.unlink(path)
            except OSError:
                pass

        if not keep_nvidia_modules:
            self._blacklist_nvidia()

        self._create_pm_udev_rule(keep_nvidia_modules)

    def _enable_nvidia(self):
        """Apply the configuration used when NVIDIA is the primary GPU.

        Removes the old and current blacklist files and the udev power
        management rule, then creates the KMS configuration file unless
        one already exists.
        """
        removable = (
            self._old_blacklist_file,
            self._blacklist_file,
            self._udev_rule_file,
        )
        for path in removable:
            # Best effort: these files are frequently absent already
            try:
                os.unlink(path)
            except OSError:
                pass

        # Create configuration file so that users can enable
        # KMS easily.
        # modeset is off by default.
        if not os.path.isfile(self._nvidia_kms_file):
            self._enable_kms()

    def _blacklist_nvidia(self):
        """Write the modprobe file that blacklists the NVIDIA modules.

        The file blacklists nvidia, nvidia-drm and nvidia-modeset and
        aliases each of them to "off". Any existing file is overwritten.
        """
        blacklist_text = '''# Do not modify
# This file was generated by nvidia-prime
blacklist nvidia
blacklist nvidia-drm
blacklist nvidia-modeset
alias nvidia off
alias nvidia-drm off
alias nvidia-modeset off'''
        # "with" closes the file even if the write fails
        with open(self._blacklist_file, 'w') as blacklist_fd:
            blacklist_fd.write(blacklist_text)

    def _enable_kms(self):
        """Write the nvidia-drm modeset configuration file.

        Despite the method name, the generated file sets ``modeset=0``:
        modesetting is left disabled, and the file exists so that users
        can enable it by changing the value to 1.
        """
        # This is actually disabled now, but it can be enabled
        # by users with a simple change.
        kms_text = '''# This file was generated by nvidia-prime
# Set value to 1 to enable modesetting
options nvidia-drm modeset=0'''
        # "with" closes the file even if the write fails
        with open(self._nvidia_kms_file, 'w') as kms_fd:
            kms_fd.write(kms_text)

    def _add_boot_params(self, pattern, path, params):
        """Set kernel boot parameters on the matching lines of a file.

        Every line of *path* that starts with *pattern* is read as
        ``<pattern>"arg arg ..."``. For each ``name: value`` pair in
        *params*, an existing argument called ``name`` (either bare or in
        ``name=...`` form) is replaced by ``name=value``; if there is none,
        ``name=value`` is appended. All other lines are written back
        unchanged. The file is rewritten in place.

        Args:
            pattern: line prefix to look for, e.g.
                'GRUB_CMDLINE_LINUX_DEFAULT='.
            path: path of the file to edit.
            params: mapping of parameter name to value.
        """
        with open(path, 'r+') as f:
            lines = f.read().split('\n')

            for index, line in enumerate(lines):
                if not line.startswith(pattern):
                    continue

                # Strip the prefix and the quoting. split() without an
                # argument also drops empty items, so an empty command
                # line does not produce stray spaces.
                boot_args = line[len(pattern):].replace('"', '').split()

                for key, value in params.items():
                    target_param = '%s=%s' % (key, value)
                    name = '%s' % (key,)
                    assignment = '%s=' % (name,)
                    # Match on the parameter name only; a substring test
                    # would also hit unrelated arguments ("splash" is
                    # contained in "nosplash").
                    matches = [i for i, arg in enumerate(boot_args)
                               if arg == name or arg.startswith(assignment)]
                    if matches:
                        for i in matches:
                            boot_args[i] = target_param
                    else:
                        boot_args.append(target_param)

                lines[index] = '%s"%s"' % (pattern, ' '.join(boot_args))

            f.seek(0)
            f.write('\n'.join(lines))
            # The new content may be shorter than the old one
            f.truncate()

    def _remove_boot_params(self, pattern, path, params):
        """Remove kernel boot parameters from the matching lines of a file.

        Every line of *path* that starts with *pattern* is read as
        ``<pattern>"arg arg ..."``. Arguments whose name is listed in
        *params* (either bare or in ``name=...`` form) are dropped. All
        other lines are written back unchanged. The file is rewritten in
        place.

        Args:
            pattern: line prefix to look for, e.g.
                'GRUB_CMDLINE_LINUX_DEFAULT='.
            path: path of the file to edit.
            params: iterable of parameter names to remove.
        """
        # Materialise once so the names can be reused for every line
        names = ['%s' % (key,) for key in params]

        def is_unwanted(arg):
            # Match on the parameter name only; a substring test would
            # also remove unrelated arguments ("splash" vs "nosplash").
            return any(arg == name or arg.startswith('%s=' % (name,))
                       for name in names)

        with open(path, 'r+') as f:
            lines = f.read().split('\n')

            for index, line in enumerate(lines):
                if not line.startswith(pattern):
                    continue

                boot_args = line[len(pattern):].replace('"', '').split()
                kept_args = [arg for arg in boot_args if not is_unwanted(arg)]
                lines[index] = '%s"%s"' % (pattern, ' '.join(kept_args))

            f.seek(0)
            f.write('\n'.join(lines))
            # The new content is usually shorter than the old one
            f.truncate()


    def _find_connected_connectors(self, card):
        """List the DRM connectors of *card* whose status is 'connected'.

        Args:
            card: DRM card name as used under /sys/class/drm, e.g. 'card1'.

        Returns:
            The /sys/class/drm/<card>-* paths whose ``status`` file reads
            'connected' (surrounding whitespace ignored), in glob order.
        """
        def reports_connected(connector):
            with open('%s/status' % connector, 'r') as status_file:
                return status_file.read().strip() == 'connected'

        candidates = glob.glob('/sys/class/drm/%s-*' % (card))
        return [connector for connector in candidates
                if reports_connected(connector)]

    def _get_boot_params_from_phantom_vga_connectors(self):
        """Build ``video=<connector>:d`` parameters for card1 VGA outputs.

        Returns:
            One 'video=<name>:d' string per connected connector of
            'card1' whose path contains 'vga' (case-insensitive). <name>
            is the connector path without the '/sys/class/drm/card1-'
            prefix and without dashes.
        """
        sysfs_prefix = '/sys/class/drm/card1-'
        return [
            'video=%s:d' % connector.replace(sysfs_prefix, '').replace('-', '')
            for connector in self._find_connected_connectors('card1')
            if 'vga' in connector.lower()
        ]


    def _update_grub(self):
        """Run the ``update-grub`` command; its exit status is discarded."""
        command = ['update-grub']
        subprocess.call(command)

    def _backup_grub_config(self):
        """Copy the GRUB configuration to '<grub path>.prime-backup' once.

        An existing backup is never overwritten. When there is no GRUB
        configuration file at all, nothing is copied and no error is
        raised.
        """
        destination = '%s.prime-backup' % self._grub_path
        if os.path.isfile(destination):
            # Keep the first backup that was ever made
            return

        if not os.path.isfile(self._grub_path):
            # Nothing to back up; a missing GRUB config must not abort the
            # profile switch.
            return

        shutil.copyfile(self._grub_path, destination)


def check_root():
    """Exit with status 1 unless the effective user is root.

    Writes an explanatory message to standard error before exiting.
    """
    if os.geteuid() != 0:
        sys.stderr.write("This operation requires root privileges\n")
        # sys.exit rather than the exit() helper, which is only injected
        # by the site module.
        sys.exit(1)

def handle_query_error():
    """Report that no profile could be found and exit with status 1."""
    sys.stderr.write("Error: no profile can be found\n")
    # sys.exit rather than the exit() helper, which is only injected by
    # the site module.
    sys.exit(1)

def usage():
    """Write the command line synopsis to standard error."""
    program = sys.argv[0]
    message = "Usage: %s nvidia|intel|on-demand|query\n" % (program)
    sys.stderr.write(message)

if __name__ == '__main__':
    # Exactly one argument is expected: the profile name or "query"
    if len(sys.argv) != 2:
        usage()
        sys.exit(1)

    arg = sys.argv[1]
    switcher = Switcher()

    if arg in ('intel', 'nvidia', 'on-demand'):
        # Switching profiles requires root privileges
        check_root()
        # Report a failed switch through the exit status
        if not switcher.enable_profile(arg):
            sys.exit(1)
    elif arg == 'query':
        if not switcher.print_profile():
            handle_query_error()
    else:
        usage()
        sys.exit(1)

    sys.exit(0)
