#!/usr/bin/python3
# -*- encoding: utf-8 -*-
#
# The Qubes OS Project, http://www.qubes-os.org
#
# Copyright (C) 2015 Marek Marczykowski-Górecki
#                              <marmarek@invisiblethingslab.com>
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
#
#

# Tool for sending window icons over qrexec to dom0/GUI domain. Icons will be
# tinted by dom0 and then attached to the appropriate windows.

import sys
import struct
import asyncio
import logging
import time

import xcffib
from xcffib import xproto

# Largest icon width/height (in pixels) that is accepted; icons exceeding it
# in either dimension are not used.
ICON_MAX_SIZE = 1024
# Preferred icon sizes, most preferred first (based on certified HW and
# defaults). The first size a window actually provides is the one sent.
ICON_PREFERRED_SIZES = (
    (128, 128),
    (96, 96),
    (64, 64),
    (48, 48),
    (256, 256),
    (512, 512),
    (32, 32),
)

# Bits tested against the first word ("flags") of the WM_HINTS property to
# find out whether the icon pixmap / icon mask fields are set.
IconPixmapHint = 0b1 << 2
IconMaskHint = 0b1 << 5


# Logger shared by everything in this file.
log = logging.getLogger('icon-sender')


class NoIconError(KeyError):
    '''Raised when no usable icon could be retrieved for a window.'''


class IconRetriever(object):
    def __init__(self):
        '''
        Connect to the X server and cache the root window and the
        _NET_WM_ICON atom used by the other methods.
        '''
        connection = xcffib.connect()
        self.conn = connection
        self.setup = connection.get_setup()
        self.root = self.setup.roots[0].root

        # Windows that were just created and whose icon has not been sent
        # yet - it should be sent on MapNotifyEvent.
        self.window_queue = set()

        atom_name = "_NET_WM_ICON"
        atom_cookie = connection.core.InternAtom(
            False, len(atom_name), atom_name)
        self.atom_net_wm_icon = atom_cookie.reply().atom

    def disconnect(self):
        '''Log the disconnect and close the connection to the X server.'''
        log.info('disconnecting from X')
        connection = self.conn
        connection.disconnect()

    def watch_window(self, w):
        '''
        Select PropertyChange events on window *w*.

        The cookie of the request is dropped without being checked.
        '''
        event_mask = [xproto.EventMask.PropertyChange]
        self.conn.core.ChangeWindowAttributesChecked(
            w, xproto.CW.EventMask, event_mask)

    def get_icons(self, w):
        '''
        Return the icons of window *w* as a ``{(width, height): pixels}``
        dict.

        The _NET_WM_ICON property is tried first. When the window does not
        have it, the legacy icon pixmap (and optional mask) named in WM_HINTS
        is used instead. ``pixels`` holds ``width * height`` 32-bit integers
        with the alpha channel in the top byte.

        :param w: X window id
        :raises NoIconError: when the window is gone or has no usable icon
        '''
        # check for initial icon now:
        prop_cookie = self.conn.core.GetProperty(
            False,  # delete
            w,  # window
            self.atom_net_wm_icon,
            xproto.Atom.CARDINAL,
            0,  # long_offset
            512 * 1024  # long_length
        )
        try:
            icon = prop_cookie.reply()
        except xproto.BadWindow:
            # Window disappeared in the meantime
            raise NoIconError() from None

        if icon.format == 0:
            # Legacy case
            # Ancient X11 applications (Xterm, Xcalc, Xlogo, Xeyes, Xclock, ...)
            return self._get_legacy_icon(w)
        if icon.format != 32:
            # The data is decoded below as 32-bit words; any other format
            # cannot be a valid icon.
            raise NoIconError()

        # We have sane _NET_WM_ICON Atom for modern programs
        icon_data = icon.value.buf()
        if icon.bytes_after:
            prop_cookie = self.conn.core.GetProperty(
                False,  # delete
                w,  # window
                self.atom_net_wm_icon,
                xproto.Atom.CARDINAL,
                icon.value_len,  # long_offset
                icon.bytes_after  # long_length
            )
            try:
                icon_cont = prop_cookie.reply()
            except xproto.BadWindow:
                # Window disappeared between the two requests
                raise NoIconError() from None
            icon_data += icon_cont.value.buf()
        return self._split_net_wm_icon(icon_data)

    @staticmethod
    def _split_net_wm_icon(raw):
        '''
        Split raw _NET_WM_ICON bytes into a ``{(width, height): pixels}``
        dict.

        The data is a sequence of icons, each being a width word, a height
        word and ``width * height`` pixel words. Icons larger than
        ICON_MAX_SIZE are skipped. Parsing stops at the first icon whose
        header or pixel data is incomplete, so every returned icon has
        exactly as many pixels as its size announces.

        :raises NoIconError: when no acceptable icon is found
        '''
        # join each 4 bytes into a single int (trailing odd bytes are ignored)
        words = struct.unpack_from("%dI" % (len(raw) // 4), raw)
        total = len(words)
        icons = {}
        index = 0
        # split the array into icons
        while index + 2 <= total:
            width, height = words[index], words[index + 1]
            end = index + 2 + width * height
            if end > total:
                # truncated pixel data
                break
            if width <= ICON_MAX_SIZE and height <= ICON_MAX_SIZE:
                icons[(width, height)] = words[index + 2:end]
            index = end
        if not icons:
            # no icon with acceptable size
            raise NoIconError()
        return icons

    def _get_legacy_icon(self, w):
        '''
        Build the icon of window *w* from the pixmap (and optional mask)
        referenced by its WM_HINTS property.

        Only pixmaps of depth 1 and 24 and masks of depth 1 and 8 are
        handled; a mask of any other depth is ignored.

        :returns: dict with a single ``(width, height): pixels`` entry
        :raises NoIconError: when no usable icon pixmap is available
        '''
        core = self.conn.core
        mask_data = None
        mask_depth = None
        mask_width = mask_height = 0
        try:
            reply = core.GetProperty(
                False,  # delete
                w,  # window
                xproto.Atom.WM_HINTS,
                xproto.GetPropertyType.Any,
                0,  # long_offset
                48  # long_length -> XWMHints struct length on x86_64
            ).reply()
            if not reply.value_len:
                raise NoIconError()
            # Words used below: [0] flags, [3] icon pixmap, [7] icon mask
            hints = reply.value.to_atoms()
            if len(hints) < 4 or not hints[0] & IconPixmapHint:
                # A menu, pop-up or similar icon-less window
                raise NoIconError()
            flags = hints[0]
            pixmap = hints[3]
            geometry = core.GetGeometry(pixmap).reply()
            width, height = geometry.width, geometry.height
            if (
                    # Only 1bit/pixel & 24bit/pixel icons are tested
                    geometry.depth not in (1, 24)
                    or not width
                    or not height
                    or width > ICON_MAX_SIZE
                    or height > ICON_MAX_SIZE
            ):
                raise NoIconError()
            pixmap_data = core.GetImage(
                xproto.ImageFormat.ZPixmap,
                pixmap,
                0,
                0,
                width,
                height,
                0xFFFFFFFF
            ).reply().data.raw
            # Without IconMaskHint the icon has no transparency (mask)
            if flags & IconMaskHint and len(hints) > 7:
                mask = hints[7]
                mask_geometry = core.GetGeometry(mask).reply()
                mask_depth = mask_geometry.depth
                # Only the part of the mask that overlaps the icon matters
                mask_width = min(mask_geometry.width, width)
                mask_height = min(mask_geometry.height, height)
                mask_data = core.GetImage(
                    xproto.ImageFormat.ZPixmap,
                    mask,
                    0,
                    0,
                    mask_width,
                    mask_height,
                    0xFFFFFFFF
                ).reply().data.raw
        except (
                xproto.BadWindow,
                xproto.WindowError,
                xproto.AccessError,
                xproto.DrawableError
        ):
            raise NoIconError() from None

        # Finally we have the required data to construct icon
        if geometry.depth == 1:
            # 1 bit per pixel icons (i.e. xlogo, xeyes, xcalc, ...)
            # There might be trailing bytes at the end of each row, so the
            # row stride is derived from the amount of data received
            stride = len(pixmap_data) // height
            if stride * 8 < width:
                # less data than the geometry announces
                raise NoIconError()
            pixels = []
            for y in range(height):
                row = y * stride
                for x in range(width):
                    bit = (int(pixmap_data[row + x // 8]) >> (x % 8)) & 0x1
                    # Can not decide the light/dark theme from vmside :/
                    pixels.append(0xff7f7f7f if bit else 0x0)
        else:
            # 24 bit per pixel icons (i.e. Xterm)
            # Technically this could handle modern programs as well
            # However, _NET_WM_ICON is faster
            pixels = [
                value | 0xff000000
                for value in struct.unpack_from(
                    "%dI" % (len(pixmap_data) // 4), pixmap_data)
            ]
            if len(pixels) != width * height:
                # pixel data does not match the announced geometry
                raise NoIconError()

        if mask_data and mask_height and mask_depth in (1, 8):
            # Even Xterm uses 1 bit/pixel mask.
            # The 8 bit/pixel variant is not tested (no X prog found using it)
            one_bit = mask_depth == 1
            mask_stride = len(mask_data) // mask_height
            needed = (mask_width + 7) // 8 if one_bit else mask_width
            if mask_stride >= needed:
                for y in range(mask_height):
                    row = y * mask_stride
                    for x in range(mask_width):
                        # pixels are stored row by row, `width` per row
                        pixel = x + y * width
                        if one_bit:
                            byte = int(mask_data[row + x // 8])
                            if not (byte >> (x % 8)) & 0x1:
                                # masked out: make fully transparent
                                pixels[pixel] &= 0x00ffffff
                        else:
                            alpha = int(mask_data[row + x])
                            pixels[pixel] &= (alpha << 24) | 0x00ffffff
        return {(width, height): pixels}

    def describe_icon(self, w):
        '''
        Serialize the best icon of window *w*.

        The result is the window id on one line, "width height" on the next
        line, followed by one big-endian 32-bit word per pixel in which the
        top byte of the source pixel has been rotated to the bottom.
        The first size from ICON_PREFERRED_SIZES that the window provides is
        used, otherwise the largest available one.

        :returns: the serialized icon as bytes, or None if there is no icon
        '''
        try:
            icons = self.get_icons(w)
        except NoIconError:
            return None

        # Override the largest icon with our preferred sizes (if possible)
        chosen_size = next(
            (size for size in ICON_PREFERRED_SIZES if size in icons), None)
        if chosen_size is None:
            chosen_size = max(icons)

        header = "{}\n{} {}\n".format(w, chosen_size[0], chosen_size[1])
        pack_word = struct.Struct('>I').pack
        body = b''.join(
            pack_word(((pixel << 8) & 0xffffff00) | (pixel >> 24))
            for pixel in icons[chosen_size])
        return header.encode('ascii') + body

    def initial_sync(self):
        '''
        Select SubstructureNotify events on the root window, then start
        watching every current child of the root window and yield its icon
        description (None for windows without an icon).
        '''
        core = self.conn.core
        core.ChangeWindowAttributesChecked(
            self.root, xproto.CW.EventMask,
            [xproto.EventMask.SubstructureNotify])
        self.conn.flush()

        children = core.QueryTree(self.root).reply().children
        for window in children:
            self.watch_window(window)
            yield self.describe_icon(window)

    async def watch_and_send_icons(self):
        '''
        Yield data for all icons we receive.
        This is an asynchronous generator, so that we can handle reconnections
        during waiting for X events.
        '''
        for icon in self.initial_sync():
            yield icon

        # Emulate select()
        event = asyncio.Event()
        asyncio.get_event_loop().add_reader(
            self.conn.get_file_descriptor(), event.set)

        while True:
            await event.wait()
            event.clear()
            for icon in self.handle_pending_events():
                yield icon

    def handle_pending_events(self):
        '''
        Process all X events currently available without blocking and yield
        an icon description for each window that needs one.

        - CreateNotify: start watching the window and queue it
        - MapNotify: yield the icon of a queued window (once)
        - DestroyNotify: forget the window
        - PropertyNotify on _NET_WM_ICON: yield the updated icon

        Yielded items may be None (window without an icon).
        '''
        for ev in iter(self.conn.poll_for_event, None):
            if isinstance(ev, xproto.CreateNotifyEvent):
                self.window_queue.add(ev.window)
                self.watch_window(ev.window)
            elif isinstance(ev, xproto.MapNotifyEvent):
                if ev.window in self.window_queue:
                    self.window_queue.discard(ev.window)
                    yield self.describe_icon(ev.window)
            elif isinstance(ev, xproto.DestroyNotifyEvent):
                # A window destroyed before it was ever mapped must not stay
                # queued: the set would grow forever and a reused window id
                # would wrongly be treated as pending.
                self.window_queue.discard(ev.window)
            elif isinstance(ev, xproto.PropertyNotifyEvent):
                if ev.atom == self.atom_net_wm_icon:
                    yield self.describe_icon(ev.window)


class IconSender:
    async def run(self):
        '''
        Keep running the client, restarting it whenever it exits.

        A run shorter than 5 seconds counts as a failure; after a failure the
        next attempt is delayed by 5 seconds, and after 5 failures in a row
        the method gives up and returns. A run of 5 seconds or more resets
        the failure counter and is followed by an immediate restart.
        '''
        min_good_runtime = 5
        retry_delay = 5
        max_failures = 5

        failures = 0
        while True:
            elapsed = await self.run_client()
            if elapsed >= min_good_runtime:
                log.info('process exited, trying to reconnect')
                failures = 0
                continue

            failures += 1
            if failures == max_failures:
                log.error('giving up after %d tries', failures)
                return
            log.error('process exited too soon, waiting %d seconds and retrying',
                      retry_delay)
            await asyncio.sleep(retry_delay)

    async def run_client(self):
        '''
        Run one qrexec client process, feeding it icon data on stdin until
        it exits.

        :returns: number of seconds between starting the process and its exit
        '''
        cmd = ['qrexec-client-vm', 'dom0', 'qubes.WindowIconUpdater']
        log.info('running: %s', cmd)

        # Monotonic clock: the caller compares the duration to a threshold,
        # so it must not be skewed by wall-clock adjustments.
        start_time = time.monotonic()
        proc = await asyncio.create_subprocess_exec(
            *cmd,
            stdin=asyncio.subprocess.PIPE
        )
        send_icons_task = asyncio.create_task(self.send_icons(proc.stdin))
        try:
            await proc.wait()
            if proc.returncode:
                log.error('process failed with status %d', proc.returncode)
            return time.monotonic() - start_time
        finally:
            send_icons_task.cancel()
            # Wait for the task, so that its cleanup is finished before the
            # caller reconnects and a failure inside it is not lost.
            outcome, = await asyncio.gather(
                send_icons_task, return_exceptions=True)
            if (isinstance(outcome, BaseException)
                    and not isinstance(outcome, asyncio.CancelledError)):
                log.error('sending icons failed: %r', outcome)

    async def send_icons(self, writer):
        '''
        Write the icon data of all windows to *writer* until cancelled or
        until writing fails.

        :param writer: stream to write icon data to; it is always closed
            before this coroutine finishes
        '''
        retriever = None
        try:
            # Connect inside the try block, so that the writer gets closed
            # even if connecting to X fails.
            retriever = IconRetriever()
            async for data in retriever.watch_and_send_icons():
                if data:
                    writer.write(data)
                await writer.drain()
        except IOError:
            log.exception('failed to send icon data')
        except asyncio.CancelledError:
            # Cancellation is the regular way of stopping this coroutine
            pass
        finally:
            try:
                if retriever is not None:
                    retriever.disconnect()
            finally:
                writer.close()


def main():
    '''Set up logging to stderr and run the icon sender until it returns.'''
    logging.basicConfig(
        format='%(asctime)s %(name)s: %(message)s',
        level=logging.INFO,
        stream=sys.stderr)

    asyncio.run(IconSender().run())


# Start the sender only when this file is executed as a script.
if __name__ == '__main__':
    main()
