"""
CC-BY 4.0, Mickaël Péchaud, 2021

Implementation of a priority queue for elements of arbitrary type, using a standard binary heap structure together with a dictionary mechanism that gives direct access to an element in the heap in order to decrease its priority. Insertion order is preserved when priorities tie. It can thus be used to implement algorithms such as Dijkstra or A*.

Assuming O(1) time complexity for list access, append, pop and dict access, addition, removal, this class provides the following complexities (n being the number of elements currently in the queue) : 

* q.empty() (empty test in O(1))
* q.add(elt, p) (add element elt with priority p, in O(log(n)))
* q.popmin() (retrieve and remove the element with the smallest priority, in O(log(n)))
* q.decreasepriority(elt, p) (decreases the priority of elt to p, in O(log(n)))

If the elements to be added are mutable, you should provide the constructor with a one-to-one code function mapping the type of the elements to an immutable (hashable) one.

------
Basic example
------

>>> def code(l):
...     return tuple(l)

>>> p = PriorityQueue(code)

>>> p.add([1, 2, 3], 5)
>>> p.add([0, 3, 2], 4)
>>> p.add([2, 4, 0], 6)
>>> p.add([3, 5, 1], 8)

>>> p.decreasepriority([3, 5, 1], 5)

>>> while not p.empty():
...     print(p.popmin())
[0, 3, 2]
[1, 2, 3]
[3, 5, 1]
[2, 4, 0]
"""

import itertools


class PriorityQueue:
    '''Min-priority queue for arbitrary elements, with decrease-priority support.

    The queue is a binary heap stored in a list (``heap``) plus a dict
    (``elements_tab``) mapping the code of every queued element to its
    wrapper, so an element can be located in the heap without scanning it.
    Elements with equal priorities are popped in insertion order.

    If the element type is mutable (unhashable), provide a one-to-one ``code``
    function mapping elements to an immutable, hashable value. An element (as
    identified by its code) can be present in the queue at most once.
    '''

    class PriorityQueueElement:
        '''Wrapper pairing an element with its priority and its heap position.'''
        # Class-level counter, shared by every queue: each wrapper gets a
        # strictly increasing serial number used to break priority ties.
        counter = itertools.count()

        def __init__(self, data, priority):
            self.data = data
            # queueindex is the current position of this wrapper in the heap
            # list; it is what allows updating a priority without a search.
            self.queueindex = -1
            self.priority = priority
            self.num = next(PriorityQueue.PriorityQueueElement.counter)

        def __lt__(self, other):
            '''Order by priority, then by creation order when priorities tie.'''
            if self.priority == other.priority:
                return self.num < other.num
            return self.priority < other.priority

    def __init__(self, code=lambda x: x):
        '''Create an empty queue.

        code -- function mapping an element to the hashable key under which
                it is indexed (identity by default).
        '''
        self.elements_tab = {}  # key -> PriorityQueueElement currently in the heap
        self.heap = []
        self.code = code

    def __contains__(self, elt):
        '''Return True if elt (identified by its code) is currently queued.'''
        return self.code(elt) in self.elements_tab

    def add(self, elt, p):
        '''Add element elt to the queue with priority p.

        Raises ValueError if an element with the same code is already queued:
        a second copy would overwrite the index entry of the first one and
        leave an unreachable wrapper in the heap.
        '''
        # Compute and validate the key before touching the heap, so a failure
        # (duplicate or unhashable key) leaves the queue unchanged.
        key = self.code(elt)
        if key in self.elements_tab:
            raise ValueError('Element is already in the queue')
        pqe = PriorityQueue.PriorityQueueElement(elt, p)
        pqe.queueindex = len(self.heap)
        self.heap.append(pqe)
        self.elements_tab[key] = pqe
        self.__traversedown(pqe.queueindex)

    def popmin(self):
        '''Remove and return the element with the smallest priority.

        Raises IndexError if the queue is empty.
        '''
        if not self.heap:
            raise IndexError('popmin from an empty priority queue')
        m = self.heap[0].data
        # Drop the index entry first: if the lookup fails, the heap is intact.
        del self.elements_tab[self.code(m)]
        # Move the root to the end, remove it, then restore the heap order
        # for the element that took its place.
        self.__switch(0, len(self.heap) - 1)
        self.heap.pop()
        self.__traverseup(0)
        return m

    def decreasepriority(self, elt, p):
        '''Decrease the priority of elt to p.

        Raises KeyError if elt is not in the queue and ValueError if p is
        greater than the current priority of elt.
        '''
        e = self.elements_tab[self.code(elt)]
        if e.priority < p:
            raise ValueError('Cannot increase priority value')
        e.priority = p
        # A smaller priority can only move the element towards the root.
        self.__traversedown(e.queueindex)

    def empty(self):
        '''Test if the queue is empty'''
        return not self.heap

    def __switch(self, i, j):
        '''Swap the heap entries at indexes i and j and refresh their queueindex.'''
        heap = self.heap
        heap[i], heap[j] = heap[j], heap[i]
        heap[i].queueindex = i
        heap[j].queueindex = j

    def __traversedown(self, i):
        '''Put the ith element in place by moving it towards the root.'''
        heap = self.heap
        while i > 0:
            parent = (i - 1) // 2
            if not heap[i] < heap[parent]:
                break
            self.__switch(i, parent)
            i = parent

    def __traverseup(self, i):
        '''Put the ith element in place by moving it towards the leaves.'''
        heap = self.heap
        size = len(heap)
        while True:
            left = 2 * i + 1
            if left >= size:
                return  # no child: the element is a leaf
            right = left + 1
            # Pick the smaller of the (one or two) children.
            child = left
            if right < size and heap[right] < heap[left]:
                child = right
            if not heap[child] < heap[i]:
                return  # heap order already holds at this node
            self.__switch(i, child)
            i = child

    def _check_coherency(self):
        '''Check the internal invariants. For debugging purposes only.

        Asserts that every parent compares lower than its children, that each
        wrapper records its own heap position, and that the index dict maps
        exactly the queued elements to their wrappers.
        '''
        heap = self.heap
        for i in range(1, len(heap)):
            assert heap[(i - 1) // 2] < heap[i]
        for i, e in enumerate(heap):
            assert e.queueindex == i
            assert self.elements_tab[self.code(e.data)] is e
        assert len(self.elements_tab) == len(heap)


if __name__ == "__main__":
    # Smoke test: interleave insertions and priority decreases, verify the
    # heap invariants, then drain the queue while re-checking after each pop.
    pq = PriorityQueue()

    # Initial insertions, described as (first index, name prefix, priority offset).
    initial_batches = (
        (0, 'a', 0),
        (1, 'a', 0),
        (0, 'b', 5),
        (1, 'b', 5),
    )
    for start, prefix, offset in initial_batches:
        for i in range(start, 20, 3):
            pq.add(prefix + str(i), i + offset)

    # Bring part of the 'b' elements down to the priority of their 'a' twin.
    for i in range(1, 20, 3):
        pq.decreasepriority('b' + str(i), i)

    for i in range(2, 20, 3):
        pq.add('a' + str(i), i)

    for i in range(0, 20, 3):
        pq.decreasepriority('b' + str(i), i)

    # Each wrapper must report its own position: this prints 0, 1, 2, ...
    print([pq.elements_tab[entry.data].queueindex for entry in pq.heap])

    pq._check_coherency()

    print('--------------')

    while not pq.empty():
        # Read the serial number of the root before it is removed.
        root_num = pq.heap[0].num
        print(root_num, pq.popmin())
        pq._check_coherency()
