﻿# coding=utf-8

#Python modules
try:
	import visual as v
except:
	print 'Error importing Visual Python. Please ensure it is installed.'
	exit()

import numpy
import difn
import math

#initialise the scene
def initialise():
	"""Create and return the VPython display window used for all drawing.

	The window is titled 'MuCalc' (with a micro sign), is 800x800 at
	x=500, y=500, is centred on the origin and has a dark-grey background.
	exit=False is passed through to v.display.
	"""
	window_options = dict(
		title='MµCalc',
		width=800,
		height=800,
		x=500,
		y=500,
		center=(0,0,0),
		background=(0.1,0.1,0.1),
		exit=False,
	)
	# optional extra, left disabled: stereo='redcyan'
	return v.display(**window_options)

def hide(a):
	"""Make a VPython object, or every object nested in containers, invisible.

	Dicts are walked over their values and lists/tuples over their items,
	recursively. Anything else is assumed to be a VPython object and gets
	its `visible` attribute set to False.

	Returns a (the same object that was passed in).
	"""
	# isinstance (rather than comparing class-name strings) also covers
	# dict/list subclasses; tuples are walked too instead of failing on
	# the attribute assignment
	if isinstance(a, dict):
		for item in a.values():
			hide(item)
	elif isinstance(a, (list, tuple)):
		for item in a:
			hide(item)
	else:
		a.visible = False
	return a

# Element symbol -> (R, G, B) on a 0-255 scale.
# These values appear to follow the Jmol element colour table (confirm).
_ATOM_RGB = {
	'H': (255, 255, 255), 'He': (217, 255, 255), 'Li': (204, 128, 255),
	'Be': (194, 255, 0), 'B': (255, 181, 181), 'C': (144, 144, 144),
	'N': (48, 80, 248), 'O': (255, 13, 13), 'F': (144, 224, 80),
	'Ne': (179, 227, 245), 'Na': (171, 92, 242), 'Mg': (138, 255, 0),
	'Al': (191, 166, 166), 'Si': (240, 200, 160), 'P': (255, 128, 0),
	'S': (255, 255, 48), 'Cl': (31, 240, 31), 'Ar': (128, 209, 227),
	'K': (143, 64, 212), 'Ca': (61, 255, 0), 'Sc': (230, 230, 230),
	'Ti': (191, 194, 199), 'V': (166, 166, 171), 'Cr': (138, 153, 199),
	'Mn': (156, 122, 199), 'Fe': (224, 102, 51), 'Co': (240, 144, 160),
	'Ni': (80, 208, 80), 'Cu': (200, 128, 51), 'Zn': (125, 128, 176),
	'Ga': (194, 143, 143), 'Ge': (102, 143, 143), 'As': (189, 128, 227),
	'Se': (255, 161, 0), 'Br': (166, 41, 41), 'Kr': (92, 184, 209),
	'Rb': (112, 46, 176), 'Sr': (0, 255, 0), 'Y': (148, 255, 255),
	'Zr': (148, 224, 224), 'Nb': (115, 194, 201), 'Mo': (84, 181, 181),
	'Tc': (59, 158, 158), 'Ru': (36, 143, 143), 'Rh': (10, 125, 140),
	'Pd': (0, 105, 133), 'Ag': (192, 192, 192), 'Cd': (255, 217, 143),
	'In': (166, 117, 115), 'Sn': (102, 128, 128), 'Sb': (158, 99, 181),
	'Te': (212, 122, 0), 'I': (148, 0, 148), 'Xe': (66, 158, 176),
	'Cs': (87, 23, 143), 'Ba': (0, 201, 0), 'La': (112, 212, 255),
	'Ce': (255, 255, 199), 'Pr': (217, 255, 199), 'Nd': (199, 255, 199),
	'Pm': (163, 255, 199), 'Sm': (143, 255, 199), 'Eu': (97, 255, 199),
	'Gd': (69, 255, 199), 'Tb': (48, 255, 199), 'Dy': (31, 255, 199),
	'Ho': (0, 255, 156), 'Er': (0, 230, 117), 'Tm': (0, 212, 82),
	'Yb': (0, 191, 56), 'Lu': (0, 171, 36), 'Hf': (77, 194, 255),
	'Ta': (77, 166, 255), 'W': (33, 148, 214), 'Re': (38, 125, 171),
	'Os': (38, 102, 150), 'Ir': (23, 84, 135), 'Pt': (208, 208, 224),
	'Au': (255, 209, 35), 'Hg': (184, 184, 208), 'Tl': (166, 84, 77),
	'Pb': (87, 89, 97), 'Bi': (158, 79, 181), 'Po': (171, 92, 0),
	'At': (117, 79, 69), 'Rn': (66, 130, 150), 'Fr': (66, 0, 102),
	'Ra': (0, 125, 0), 'Ac': (112, 171, 250), 'Th': (0, 186, 255),
	'Pa': (0, 161, 255), 'U': (0, 143, 255), 'Np': (0, 128, 255),
	'Pu': (0, 107, 255), 'Am': (84, 92, 242), 'Cm': (120, 92, 227),
	'Bk': (138, 79, 227), 'Cf': (161, 54, 212), 'Es': (179, 31, 212),
	'Fm': (179, 31, 186), 'Md': (179, 13, 166), 'No': (189, 13, 135),
	'Lr': (199, 0, 102), 'Rf': (204, 0, 89), 'Db': (209, 0, 79),
	'Sg': (217, 0, 69), 'Bh': (224, 0, 56), 'Hs': (230, 0, 46),
	'Mt': (235, 0, 38),
}

# Relative sphere size for elements that are not drawn at the default size of 1.
_ATOM_SIZE = {'H': 0.5, 'O': 0.5}

def atom_colours(name, default=None):
	"""Return the display colour and relative sphere size for an element.

	name: element symbol, case-sensitive ('H' through 'Mt').
	default: returned unchanged when name is not a known symbol; None
		unless the caller supplies something else.

	Returns ((R, G, B), size). Each channel is a float in [0, 1]. size is
	0.5 for H and O and 1 for every other element.
	"""
	rgb = _ATOM_RGB.get(name)
	if rgb is None:
		return default
	return (rgb[0]/255.0, rgb[1]/255.0, rgb[2]/255.0), _ATOM_SIZE.get(name, 1)

def col_rainbow(x):
	"""Map a number in [0, 1] onto a blue-to-red rainbow colour bar.

	Colour stops are 0 -> blue, 0.25 -> cyan, 0.5 -> green, 0.75 -> yellow
	and 1 -> red, with linear interpolation in between.

	Out-of-range input gives a cautionary colour instead: black for x < 0,
	white for x > 1, and magenta for a non-numerical value such as NaN
	(which fails every comparison).

	Returns a numpy array [R, G, B].
	"""
	if(0 <= x <= 1):
		a = math.floor(x/0.25)%4
		s = (x - a*0.25)*4.0
		# x == 1.0 wraps to band 0 with s == 4.0 because of the % 4;
		# put it at the end of the last band instead
		if(s==4.0): a,s=3,1.0
		# blue - cyan
		if(a==0):
			return numpy.array([0,s,1])
		# cyan - green
		elif(a==1):
			return numpy.array([0,1,1-s])
		# green - yellow
		elif(a==2):
			return numpy.array([s,1,0])
		# yellow - red
		elif(a==3):
			return numpy.array([1,1-s,0])
	#if x < 0 was passed, return black
	elif(x < 0):
		return  numpy.array([0,0,0])
	#if x > 1 was passed, return white
	elif(x > 1):
		return  numpy.array([1,1,1])
	#if x is not numerical (e.g. NaN), return magenta
	else:
		return  numpy.array([1,0,1])

def col_rainbow_complex(r, theta,degrees=False):
	"""Map a complex value in polar form (r, theta) to an RGB colour.

	The hue follows theta round the colour wheel: red at 0, yellow at 60,
	green at 120, cyan at 180, blue at 240 and magenta at 300 degrees. The
	brightness scales linearly with r.

	r: magnitude, expected in [0, 1].
	theta: phase angle, in radians unless degrees=True. It is wrapped into
		[0, 360) degrees first.
	degrees: if True, theta is already in degrees.

	Returns a numpy array [R, G, B]. If r is out of range the result is
	grey (r < 0) or white (r > 1). A non-numerical r such as NaN gives
	light magenta.
	"""
	if(0 <= r <= 1):
		if degrees:
			# already in degrees: just wrap into [0, 360)
			deg = theta % 360
		else:
			#turn the angles into degrees
			deg = (theta * 180 / math.pi) % 360 #mod 360 to catch rounding errors
		a = math.floor(deg/60.0)%6
		s = (deg - a*60.0)/60.0
		# float % can return exactly 360.0 for tiny negative angles, which
		# gives a == 0 and s == 6.0; treat that as the end of the last sextant
		if(s==6.0): a,s=5,1.0
		s = r*s #scale s with r
		# red - yellow
		if(a==0):
			return numpy.array([r,s,0])
		# yellow - green
		elif(a==1):
			return numpy.array([r-s,r,0])
		# green - cyan
		elif(a==2):
			return numpy.array([0,r,s])
		# cyan - blue
		elif(a==3):
			return numpy.array([0,r-s,r])
		# blue - magenta
		elif(a==4):
			return numpy.array([s,0,r])
		# magenta - red
		elif(a==5):
			return numpy.array([r,0,r-s])
	#if r < 0 was passed, return grey
	elif(r < 0):
		return  numpy.array([0.5,0.5,0.5])
	#if r > 1 was passed, return white
	elif(r > 1):
		return  numpy.array([1,1,1])
	#if r is not numerical, return light magenta
	else:
		return  numpy.array([1,0.9,1])
		
def col_rainbow_theta(theta,degrees=False):
	"""Map an angle onto a full-brightness colour wheel.

	Red at 0, yellow at 60, green at 120, cyan at 180, blue at 240 and
	magenta at 300 degrees, interpolating linearly in between.

	theta: the angle, in radians unless degrees=True. It is wrapped into
		[0, 360) degrees first.
	degrees: if True, theta is already in degrees.

	Returns a numpy array [R, G, B].
	"""
	if degrees:
		# already in degrees: just wrap into [0, 360)
		deg = theta % 360
	else:
		#turn the angles into degrees
		deg = (theta * 180 / math.pi) % 360 #mod 360 to catch rounding errors/angles outside 0-360
	a = math.floor(deg/60.0)%6
	s = (deg - a*60.0)/60.0
	# float % can return exactly 360.0 for tiny negative angles, which gives
	# a == 0 and s == 6.0; treat that as the end of the last sextant
	if(s==6.0): a,s=5,1.0
	# red - yellow
	if(a==0):
		return numpy.array([1,s,0])
	# yellow - green
	elif(a==1):
		return numpy.array([1-s,1,0])
	# green - cyan
	elif(a==2):
		return numpy.array([0,1,s])
	# cyan - blue
	elif(a==3):
		return numpy.array([0,1-s,1])
	# blue - magenta
	elif(a==4):
		return numpy.array([s,0,1])
	# magenta - red
	elif(a==5):
		return numpy.array([1,0,1-s])

def draw_crystal(r, attr, types):
	"""Draw each atom as a sphere with an arrow through it showing its spin.

	r: atom positions.
	attr: spin vectors (three components each), indexed like r.
	types: element symbols, indexed like r. Each is passed to
		atom_colours() for the sphere colour and relative size.

	The arrow colour encodes the spin direction. The normalised spin vector
	is used as an RGB triple, and each negative component is folded into
	the other two channels (its complementary colour). A zero spin gives a
	black arrow. The function creates VPython objects and returns None.
	"""
	# scale factor from spin-vector length to displayed arrow length
	spingro = 0.2
	for i in range(len(r)):
		xyz = numpy.array(r[i])
		s = numpy.array(attr[i])
		#choose colour depending on spin direction (make the col vector the unit vector)
		if (s[0]==0 and s[1]==0 and s[2]==0):
			col = numpy.array((0,0,0))
		else:
			col = s/numpy.sqrt(numpy.dot(s,s))
			# for each negative component, add its magnitude to the other
			# two channels and zero it (components processed in order 0, 1, 2)
			for k in range(3):
				if col[k] < 0:
					col[(k+1)%3] -= col[k]
					col[(k+2)%3] -= col[k]
					col[k] = 0
		# arrow centred on the atom, length proportional to the spin
		v.arrow(pos=xyz-s*spingro/2, axis=s*spingro, color=col)
		#draw spheres on the atom sites
		colour,size = atom_colours(types[i])
		v.sphere(pos=xyz, color=colour, radius=0.1*size)

	#draw a dot at the origin
	#pointer = v.sphere(pos=(0,0,0), color=(1,1,1), radius=0.35e-10)

	#draw a dot at the muon position
	#pointer = v.sphere(pos=a_cart[0]*mu_frac[0]+a_cart[1]*mu_frac[1]+a_cart[2]*mu_frac[2], color=(1,0.8,0), radius=0.15e-10)

def unitcell_init(a,scale):
	"""Draw the 12 edges of the unit cell spanned by a[0], a[1] and a[2].

	Each lattice vector a[i] is drawn as four parallel cylinders. They start
	at the origin, at a[i+1], at a[i+2] and at a[i+1]+a[i+2] (indices mod 3).

	a: three lattice vectors, each indexed by [0], [1] and [2].
	scale: the cylinder radius is 0.01*scale.

	Returns the list of the 12 cylinder objects.
	"""
	thickness = 0.01*scale
	edges = []
	for i in range(3):
		direction = (a[i][0], a[i][1], a[i][2])
		p = a[(i+1)%3]
		q = a[(i+2)%3]
		starts = [
			(0,0,0),
			(p[0], p[1], p[2]),
			(q[0], q[1], q[2]),
			(p[0]+q[0], p[1]+q[1], p[2]+q[2]),
		]
		for start in starts:
			edges.append(v.cylinder(pos=start, axis=direction, radius=thickness))
	return edges

def draw_atoms(r,names,scale):
	"""Draw one sphere per atom, coloured and sized by element.

	r: atom positions.
	names: element symbols indexed like r, looked up with atom_colours().
	scale: overall size multiplier. Radius = relative size * scale * 0.06.

	Returns the list of created sphere objects, in the order of r.
	"""
	spheres = []
	for idx, position in enumerate(r):
		rgb, rel_size = atom_colours(names[idx])
		spheres.append(v.sphere(pos=position, color=rgb, radius=rel_size*scale*0.06))
	return spheres

#~ def draw_unit_cell_atoms(r,names,scale,offset=[0,0,0],fenceposts=True):
	#~ #if fenceposts is true, then get adjacent atoms
	#~ if fenceposts:
		#~ atoms_r,atoms_names = difn.unit_cell_shared_atoms(r,names)
	#~ #otherwise, just use atoms already generated
	#~ else:
		#~ atoms_r = []
		#~ atoms_names = []
		#~ for i in range(len(r)):
			#~ atoms_r.append(r[i])
			#~ atoms_names.append(names[i])
	#~ return draw_atoms(atoms_r,atoms_names,scale)

def scalar_field(r,phi,phimin=0,phimax=0,colourtype='rainbow',scale=1):
	"""Draw a (real or complex) scalar field as a set of coloured spheres.

	r: sample positions.
	phi: field values, indexed like r.
	phimin, phimax: limits of the colour scale. If both are 0 they are
		taken from the minimum and maximum of abs(phi). If those are equal,
		phimin is reset to 0.
	colourtype: one of
		'rainbow'                      - col_rainbow of the scaled value
		'bw'                           - grey level equal to the scaled value
		'rainbow_complex'              - hue from the phase, brightness from
		                                 the scaled value
		'rainbow_complex_transparency' - hue from the phase, opacity from
		                                 the scaled value (minimum 0.05)
	scale: sphere radius is 0.1*scale.

	Returns the list of created sphere objects (empty for empty input).

	Raises ValueError for an unrecognised colourtype.
	"""
	if colourtype not in ('rainbow', 'bw', 'rainbow_complex', 'rainbow_complex_transparency'):
		raise ValueError('unknown colourtype: %r' % (colourtype,))
	field = []
	if len(r) == 0:
		return field

	#work out limits of phi if not provided
	if(phimin==phimax==0):
		phimin = numpy.min(numpy.abs(phi))
		phimax = numpy.max(numpy.abs(phi))
		#if they're the same because all passed field values are identical
		if(phimin==phimax):
			phimin = 0
	# float() avoids Python 2 integer division for integer fields.
	# A zero-width range (e.g. an all-zero field) would otherwise give NaN.
	span = float(phimax - phimin)
	if span == 0:
		span = 1.0

	for i in range(len(r)):
		# NOTE(review): this is |phi - phimin|, not |phi| - phimin; they
		# differ for negative or complex phi when phimin != 0 -- confirm intent
		val = (numpy.abs(phi[i]-phimin))/span
		opacity = 1.0
		if(colourtype=='rainbow'):
			colour = col_rainbow(val)
		elif(colourtype=='bw'):
			colour = (val,val,val)
		elif(colourtype=='rainbow_complex'):
			colour = col_rainbow_complex(val,numpy.angle(phi[i]))
		else: # 'rainbow_complex_transparency'
			colour = col_rainbow_theta(numpy.angle(phi[i]))
			opacity = val*0.95 + 0.05 #make it such that the minimum opacity is not zero
		field.append(v.sphere(pos=r[i], color=colour, radius=0.1*scale, opacity=opacity))
	return field

def vector_field(r,vec,vmin,vmax,colourtype,lengthtype,scale):
	"""Draw a vector field as arrows centred on the sample positions.

	r: sample positions.
	vec: vectors at those positions, indexed like r.
	vmin, vmax: limits of the scale applied to |vec|. If both are 0 they
		come from difn.vector_min_max(vec) (presumably the smallest and
		largest magnitudes -- confirm in difn). If those are equal, vmin is
		reset to 0.
	colourtype: 'fadetoblack' (grey level = scaled magnitude) or 'rainbow'
		(col_rainbow of the scaled magnitude).
	lengthtype: any real number (including numpy scalars) gives every arrow
		that relative length. 'proportional' makes the length equal to the
		scaled magnitude. Anything else gives length 1.
	scale: overall size multiplier. Arrow length = length*scale*0.3.

	Returns the list of created arrow objects. Zero-length arrows are
	skipped; empty input gives an empty list.

	Raises ValueError for an unrecognised colourtype.
	"""
	import numbers # local import: only this function needs it

	if colourtype not in ('fadetoblack', 'rainbow'):
		raise ValueError('unknown colourtype: %r' % (colourtype,))
	field = []
	if len(r) == 0:
		return field

	#work out limits of |vec| if not provided
	if(vmin==vmax==0):
		vmin,vmax = difn.vector_min_max(vec)
		#if they're the same because all passed field values are identical
		if(vmin==vmax):
			vmin = 0
	# guard against a zero-width range (e.g. all-zero vectors)
	span = float(vmax - vmin)
	if span == 0:
		span = 1.0

	# Fixed arrow length, or None when it should follow the magnitude.
	# bool is excluded so True/False still fall through to length 1.
	if isinstance(lengthtype, numbers.Real) and not isinstance(lengthtype, bool):
		fixed_length = float(lengthtype)
	elif lengthtype == 'proportional':
		fixed_length = None
	else:
		fixed_length = 1

	v_unit = difn.unit_vectors(vec)
	scalefactor = 0.3 # TODO: arrow length should be determined automatically
	for i in range(len(r)):
		modv = numpy.sqrt(numpy.dot(vec[i],vec[i]))
		val = (modv-vmin)/span
		if colourtype == 'fadetoblack':
			colour = (val,val,val)
		else:
			colour = col_rainbow(val)
		length = val if fixed_length is None else fixed_length
		if length != 0:
			arrow_axis = length*scale*scalefactor*v_unit[i]
			# pos is shifted back by half the axis so the arrow is centred on r[i]
			field.append(v.arrow(pos=r[i]-0.5*arrow_axis, axis=arrow_axis, color=colour))
	return field
	
def points(r,scale):
	"""Mark every position in r with a small cyan cube.

	scale: the cube edge length is scale*0.07.

	Returns the list of created box objects, in the order of r.
	"""
	boxes = []
	for position in r:
		edge = scale*0.07
		cube = v.box(pos=position, length=edge, height=edge, width=edge, color=(0,1,1))
		boxes.append(cube)
	return boxes

def freq_limits(r,f,fsmall,fbig,colour):
	"""Draw a small box at r only when fsmall < f < fbig (both bounds exclusive).

	r: position of the box.
	f: value tested against the limits.
	colour: colour of the box.

	The box edge length is fixed at 0.1e-10. Always returns None.
	"""
	if not (fsmall < f < fbig):
		return
	v.box(pos=r, color=colour, length=0.1e-10,height=0.1e-10,width=0.1e-10)

def click_catcher(scene):
	"""Fetch a mouse event from the display via scene.mouse.getevent() and return it."""
	return scene.mouse.getevent()
