"""
Example Python code for drawing argument (phase) plots of complex
polynomials: colour maps of arg p(z) over a square region of the complex plane.
"""
from pylab import *

## boilerplate start #######################

def imagplot(f, center, radius, n=20):
    """Colour-plot a real-valued function of a complex variable over a square.

    The square is centred at ``center`` with half side length ``radius`` and
    is sampled on an ``n`` x ``n`` grid. ``f`` is evaluated at every grid
    point ``x + 1j*y`` and the result is drawn with ``pcolor`` using the
    ``hsv`` colormap; a colorbar is added for the plot.

    Parameters
    ----------
    f : callable
        Function accepting a complex ndarray. ``pcolor`` needs a real result
        of the same shape (e.g. ``angle`` of something).
    center : complex or real
        Centre of the plotted square.
    radius : real
        Half the side length of the square.
    n : int
        Number of samples along each axis.

    Returns
    -------
    The artist returned by ``pcolor``.
    """
    center = center + 0j
    # linspace guarantees exactly n samples per axis. arange with a float
    # step can yield n+1 points through rounding error. endpoint=False keeps
    # the original sampling: left/bottom edge included, right/top excluded.
    x = linspace(center.real - radius, center.real + radius, n, endpoint=False)
    y = linspace(center.imag - radius, center.imag + radius, n, endpoint=False)
    X, Y = meshgrid(x, y)
    Z = f(X + 1j * Y)
    cs = pcolor(X, Y, Z, cmap='hsv')
    colorbar(cs)
    return cs

def stexpr2tex(y):
    """Convert a Python arithmetic expression string to inline LaTeX.

    ``**`` becomes ``^`` and multiplication signs are dropped. Exponents are
    wrapped in braces so multi-character powers render correctly, e.g.
    ``'3*z**10-1'`` -> ``'$3z^{10}-1$'`` (plain ``z^10`` would typeset as
    z^1 followed by 0). The result is wrapped in ``$...$``.
    """
    import re
    # Parenthesised exponent: z**(n+1) -> z^{n+1}
    y = re.sub(r'\*\*\s*\(([^()]*)\)', r'^{\1}', y)
    # Simple (optionally negative) exponent: z**10 -> z^{10}, z**-1 -> z^{-1}
    y = re.sub(r'\*\*\s*(-?[\w.]+)', r'^{\1}', y)
    # Any remaining '**' falls back to '^'; then drop multiplication signs.
    y = y.replace('**', '^').replace('*', '')
    return f'${y}$'

def argplotpoly(apoly, center=0, radius=6, n=40):
    """Plot the argument (phase angle) of a complex function over a square.

    Parameters
    ----------
    apoly : str or callable
        Either an expression in ``z`` (e.g. ``'z**2-1'``), turned into a
        function with ``eval``, or a callable accepting a complex array.
    center, radius, n :
        Forwarded to ``imagplot``: centre, half side length and number of
        samples per axis.

    Returns
    -------
    Whatever ``imagplot`` returns (the ``pcolor`` artist).

    Raises
    ------
    TypeError
        If ``apoly`` is neither a string nor callable.
    """
    # The parameter is not named ``poly`` because the pylab star import
    # already binds that name (numpy's ``poly``).
    if isinstance(apoly, str):
        # NOTE(security): eval runs arbitrary code; only pass trusted strings.
        p = eval('lambda z :' + apoly)
    elif callable(apoly):
        p = apoly
    else:
        raise TypeError(
            f'apoly must be a str or callable, got {type(apoly).__name__}')

    artist = imagplot(lambda z: angle(p(z)), center, radius, n)
    # Title the plot only when we have the source expression to show.
    if isinstance(apoly, str):
        title(stexpr2tex(apoly))
    return artist

## boilerplate end #######################

def genrandpoly(n=5):
    """Return a random degree-``n`` polynomial as a function of ``z``.

    The ``n`` roots are complex numbers whose real parts and then imaginary
    parts are drawn with ``randn``. The returned function evaluates the
    product of ``(1 - z/root)`` over all roots, so it equals 1 at ``z = 0``
    and vanishes at each root. It also works element-wise on arrays.
    """
    roots = randn(n) + 1j * randn(n)

    def p(z):
        value = 1
        for root in roots:
            value = value * (1 - z / root)
        return value

    return p

def main():
    """Draw the argument plot of a random polynomial and show it."""
    # A string expression such as 'z**2-1' could be passed instead.
    poly_fn = genrandpoly()
    argplotpoly(poly_fn, n=199)
    show()

if __name__ == '__main__':
    # Run the demo only when executed as a script, not when imported.
    main()