import pattern, synt, copy, search
from getpath import getpath
from unify import schemsubst

# Operator-property flags consulted by cmop(): maps an operator token to an
# integer bit mask; cmop() treats an operator as flattenable only when both
# the 8 and the 16 bits of its mask are set.
# NOTE(review): presumably a stand-in for synt.mathdb[synt.MD_NRYPROP].  Its
# only key is the int 2, while cmop() looks up str tokens, so as written
# cmop() never finds a flag and always returns '' — confirm this is intended.
propflags = {2:3}

# Reduction rules handed to reduce() by notariancondense().  Empty here;
# presumably filled in elsewhere — verify against callers.
nreductionlist = []

def cmop(form):
	"""Return the operator token of *form* when it is a flattenable operator.

	*form* must be a list of at least three entries whose third entry,
	form[2], is a string.  That string is looked up in the module-level
	``propflags`` table and returned only when its flag value has both the
	8 and the 16 bits set.  In every other case '' is returned.

	deep() and makeand() use the result to splice nested applications of
	the same operator into one flat node.  (An earlier version read the
	flags from synt.mathdb[synt.MD_NRYPROP].)
	"""
	if type(form) is list and len(form) >= 3 and type(form[2]) is str:
		flag = propflags.get(form[2])
		if flag and (flag & 24) == 24:
			return form[2]
	return ''

def newvarlist(list):
	"""Return one fresh variable name per element of *list*.

	Each name has the form 'v_{N}', where N is the value of the shared
	counter synt.newvarnum after it is incremented once per element.  Only
	the length of *list* matters, not its contents.

	This is only called by notariancondense, when some variable occurs in
	both indicial and accepted positions.
	"""
	# The parameter name shadows the builtin but is kept for compatibility.
	fresh = []
	for _ in list:
		synt.newvarnum += 1
		fresh.append('v_{%d}' % synt.newvarnum)
	return fresh


def deep(exp, raw=False):
	"""Convert a shallow parse tree into its deep form.

	Strings are returned unchanged.  A list node is expected to start with
	a header list whose entry [0] is a positive type code.  It is handled
	as follows:

	* a parenthesised node (exp[1] == '(') is replaced by the deep form of
	  its interior exp[2];
	* nodes of type 14 or 15 are replaced by their token exp[1];
	* if cmop() recognises the node's operator, children that are
	  applications of that same operator are spliced in (flattening);
	* type-45 nodes whose header[1] has a handler name registered in
	  synt.mathdb[synt.MD_CTHNDL] are rebuilt from their deep children and
	  passed to the synt function of that name;
	* unless *raw* is set, type-40/41 nodes with header[1] in 3..7 are
	  rewritten by notariancondense() and then deepened with raw=True,
	  which disables further condensing in that subtree;
	* any other node is rebuilt with each child deepened.

	Rebuilt nodes reuse the header list exp[0] rather than copying it.

	Args:
		exp: a token string or a parse-tree node.
		raw: if true, skip notarian condensing.

	Returns:
		The deep form of *exp*.
	"""
	assert type(exp) is str or exp[0][0] > 0
	if type(exp) is str:
		return exp
	if exp[1] == '(':
		return deep(exp[2], raw)
	if exp[0][0] in [14, 15]:
		return exp[1]
	# cmop() is a pure lookup, so evaluating it only now gives the same result.
	cmopp = cmop(exp)
	if cmopp:
		retlist = [exp[0]]
		for r in exp[1:]:
			d = deep(r, raw)
			if cmop(d) == cmopp:
				# Same operator nested inside itself: splice its operands in.
				retlist.extend(d[1:])
			else:
				retlist.append(d)
		return retlist
	if exp[0][0] == 45 and exp[0][1] in synt.mathdb[synt.MD_CTHNDL]:
		# Resolve the handler by attribute name instead of eval()-ing a
		# code string assembled from database contents.
		handler = getattr(synt, synt.mathdb[synt.MD_CTHNDL][exp[0][1]])
		recurselist = [exp[0]] + [deep(r, raw) for r in exp[1:]]
		return handler(recurselist)
	if (len(exp[0]) > 1 and not raw
			and exp[0][0] in [40, 41] and exp[0][1] in [3, 4, 5, 6, 7]):
		return deep(notariancondense(exp), raw=True)
	return [exp[0]] + [deep(x, raw) for x in exp[1:]]

def notariancondense(pexp):
	"""Rewrite a notarian (variable-binding) node into condensed deep form.

	*pexp* is a type-40/41 parse node (see deep()).  pexp[0][1] is its
	notarian kind (3..7), and the synt.symtype() of pexp[1] selects the
	output shape; symtypes 8 and 9 are treated specially.  If any indicial
	variable also occurs as an accepted variable, the indicial variables
	are first replaced (via synt.indvsubst) by fresh names from
	newvarlist().

	The result carries a scope node [[48, []], v1, ',', v2, ...] built from
	the indicial variables, the scope condition, and the indexed form.
	Depending on the kind, the scope condition is conjoined with makeand().

	* symtype 9 with kind 6 or 7: the node is kept and becomes kind 7.
	* symtype 8: the node is rebuilt as kind 5.
	* otherwise: a copy is rebuilt as kind 5, or kind 3 when there is no
	  scope condition, and then simplified by reduce() against
	  nreductionlist.

	NOTE(review): the first two shapes mutate the node being built in
	place, including its header.  When no renaming happened, that node is
	a deep copy of *pexp*.  Otherwise, whether it shares structure with
	*pexp* depends on synt.indvsubst.
	"""
	styp = synt.symtype(pexp[1])
	ntyp = pexp[0][1]
	indvs = synt.indvlist(pexp)
	accvs = synt.accvlist(pexp)
	for x in indvs:
		if x in accvs:
			# Indicial/accepted clash: switch the indicial variables to fresh ones.
			nlist = newvarlist(indvs)
			newform = synt.indvsubst(nlist, indvs, pexp)
			indvs = nlist
			break
	else:
		newform = copy.deepcopy(pexp)
	newscope = [[48,[]], indvs[0]]
	for x in indvs[1:]:
		newscope.append(',')
		newscope.append(x)

	# Condition read from the (last) scope child; [] means "no condition",
	# matching what scopecondition() returns.  Initialised here so a node
	# with no scope child no longer raises UnboundLocalError below.
	scopecond = []
	for x in newform[1:]:
		if type(x) is not str and x[0][0] == 48:
			scopecond = scopecondition(x)
	if ntyp == 4:
		if scopecond:
			scopecond = makeand(scopecond, deep(newform[-1]))
		else:
			scopecond = newform[-1]
	elif ntyp in [5,7]:
		if scopecond:
			scopecond = makeand(scopecond, deep(newform[-2]))
		else:
			scopecond = newform[-2]
	if scopecond == []:
		if styp == 8 or ntyp in [6,7]:
			scopecond = '\\true'

	if ntyp in [3,5]:
		indform = newform[-1]
	elif ntyp == 4:
		if len(indvs) == 1:
			indform = indvs[0]
		else: # We may never use this:
			indform = [[45,6]] + newscope[1:]
	else:
		indform = newform[2]

	if ntyp in [6,7] and styp == 9:
		if ntyp == 7:
			newform[2] = deep(newform[2])
			newform[4] = newscope
			newform[6] = deep(scopecond)
		elif ntyp == 6:
			newform[0][1] = 7
			newform[2] = deep(newform[2])
			newform[4] = newscope
			newform[5:5] = [';',deep(scopecond)]
		newform3 = newform
	elif styp == 8:
		newform[2:] = []
		newform[0][1] = 5
		newform.append(newscope)
		newform.append(';')
		newform.append(deep(scopecond))
		newform.append(deep(indform))
		newform3 = newform
	else:
		newform2 = copy.deepcopy(newform)
		newform2[2:] = []
		newform2[0][1] = 5
		newform2.append(newscope)
		scopecond2 = copy.deepcopy(scopecond)
		if scopecond2:
			newform2.append(';')
			newform2.append(deep(scopecond2))
		else:
			# No condition: use the condition-free kind 3.
			newform2[0][1] = 3
		indform2 = deep(indform,raw=False)
		newform2.append(deep(indform2))
		newform3 = reduce(newform2,nreductionlist)
	return newform3
	

def close_off(substitution):
	# Resolve a substitution whose values may mention other keys of the same
	# substitution.  It repeatedly pushes the values of "finished" keys into
	# the unfinished ones until every considered key is finished.
	#
	# substitution: dict keyed by names.  Keys whose synt.symtype() is 10 or
	#   11 map directly to a term.  Keys with symtype 12 or 13 (presumably
	#   schemators, since they are applied with unify.schemsubst) map to a
	#   pair whose element [1] is the term that gets rewritten.  Other keys
	#   are left untouched.
	# Returns the same dict object, modified in place.
	#
	# NOTE(review): a key becomes finished only when synt.nblist() of its
	# value mentions no key.  If values are cyclic (e.g. a key whose value
	# mentions itself, assuming nblist reports it), no further key is ever
	# finished and the while loop never terminates.  Confirm that callers
	# cannot produce such substitutions.
	kyv = set([v for v in substitution if synt.symtype(v) in [10,11]]) 
	kys = set([v for v in substitution if synt.symtype(v) in [12,13]]) 
	# Reset on every pass below; these initial values are never read.
	inlist = []
	outlist= []
	done = set()
	while (kys | kyv) - done:
		# One pass: first mark newly finished keys (values free of any key).
		for s in kys - done:
			if not (kys | kyv)  & set(synt.nblist(substitution[s][1])):
				done.add(s)
		for v in kyv - done:
			if not (kys | kyv) & set(synt.nblist(substitution[v])):
				done.add(v)
		# Pair every finished variable key with its value for synt.subst.
		inlist = []
		outlist= [] 
		for v in kyv & done:
			inlist.append(substitution[v])
			outlist.append(v)
		# Rewrite unfinished schemator values (element [1] only): apply each
		# finished schemator, then the finished variables.
		for s in kys - done:
			for vs in kys & done:
				substitution[s] = [substitution[s][0],
								 schemsubst(substitution[vs],substitution[s][1])]
			substitution[s] =  [substitution[s][0],
								 synt.subst(inlist, outlist, substitution[s][1])]
		# Same for unfinished variable values.
		for v in kyv - done:
			for vs in kys & done:
				substitution[v] = schemsubst(substitution[vs],substitution[v]) 
			substitution[v] = synt.subst(inlist,outlist,substitution[v])
	return substitution

def reduce(form, reductionslist):
	"""Rewrite *form* with the rules in *reductionslist* until none applies.

	The match variable is 'xvar' when *form* is a string of synt.symtype
	10 or 14, or a node of type 40, 42 or 44.  Otherwise it is 'pvar'.
	search.search_rules() is asked for a rule matching [form] against the
	signature ['$'].  While one is found, its substitution is completed
	with close_off() and *form* is replaced by the value bound to the match
	variable.

	Returns:
		The final form.  The loop ends only when no rule matches, so a rule
		set that can rewrite a form back into itself does not terminate.
	"""
	if type(form) is str:
		# Was the undefined bare name symtype() (NameError for every string
		# form); the classifier lives in synt, as everywhere else here.
		if synt.symtype(form) in [10, 14]:
			red_var = 'xvar'
		else:
			red_var = 'pvar'
	else:
		assert type(form) is list
		if form[0][0] in [40,42,44]:
			red_var = 'xvar'
		else:
			red_var = 'pvar'
	# Two separate (non-aliased) lists, both computed from the original form
	# and reused for every search.
	takenvars = synt.nblist(form) + synt.varlist(form)
	localvars = synt.nblist(form) + synt.varlist(form)
	match = search.search_rules(red_var, [form], ['$'], takenvars, localvars,
						 rules=reductionslist)
	while match:
		subst = close_off(match.subst)
		form = subst[red_var]
		match = search.search_rules(red_var, [form], ['$'], takenvars,
						 localvars, rules=reductionslist)
	return form
		

def makeand(exp1, exp2):
	"""Return the conjunction node  exp1 \\And exp2.

	The node header is the first element of the first entry stored for
	'\\And' in synt.mathdb[synt.MD_CTDEFS]; that element is not copied.
	An operand that cmop() reports as an '\\And' application contributes
	its own entries instead of being nested, so the result stays flat.

	Raises:
		SystemExit: if '\\And' is not defined in synt.mathdb[synt.MD_CTDEFS].
	"""
	ctdefs = synt.mathdb[synt.MD_CTDEFS]
	# Check first.  The old code began with a dead MD_PRECED lookup that was
	# always overwritten and could raise KeyError before this message.
	if '\\And' not in ctdefs:
		raise SystemExit("Use of notarians requires \\And")
	andback = ctdefs['\\And'][0][:1]
	if cmop(exp1) == '\\And':
		andback.extend(exp1[1:])
	else:
		andback.append(exp1)
	andback.append('\\And')
	if cmop(exp2) == '\\And':
		andback.extend(exp2[1:])
	else:
		andback.append(exp2)
	return andback

def scopecondition(scopenode):
	"""Build the condition expressed by a scope node.

	The entries of *scopenode* after its header may be sub-nodes (lists),
	tokens of synt.symtype 14 or 10, commas, or verbs (symtype 3).  They
	are fed through synt.addnode()/synt.addtoken() into a fresh tree that
	was opened with '('.  A comma that follows a verb and at least one of
	its operands closes that clause and is emitted as '\\And'.  Any other
	comma is emitted unchanged.

	Returns:
		tree[0][2] of the tree after ')' is added, when at least one verb
		was seen; otherwise [] (no condition).

	Raises:
		RuntimeError: for any other kind of entry.
	"""
	tree = []
	synt.addtoken(tree, '(')
	# 0: no open clause, 1: verb just seen, 2: verb followed by an operand.
	state = 0
	verbfound = 0
	for item in scopenode[1:]:
		if type(item) is list:
			if state == 1:
				state = 2
			synt.addnode(tree, item)
			continue
		kind = synt.symtype(item)
		if kind == 14 or kind == 10:
			if state == 1:
				state = 2
			synt.addtoken(tree, item)
		elif item == ',':
			if state == 2:
				state = 0
				synt.addtoken(tree, '\\And')
			else:
				synt.addtoken(tree, item)
		elif kind == 3:
			if state == 0:
				state = 1
				verbfound = 1
			synt.addtoken(tree, item)
		else:
			raise RuntimeError("Scope error")
	if not verbfound:
		return []
	synt.addtoken(tree, ')')
	return tree[0][2]

############################################################
#
#  Get transitive properties from file and store in db
#
############################################################

def readprops(propfilename, entry, db):
	"""Record the transitivity rules found in *propfilename* into *db*.

	Lines first_line..last_line of the file (1-based, inclusive) are
	scanned.  This is the whole file when *entry* has length 1, and
	entry[1]..entry[2] otherwise.  A line is recorded when all of these
	hold:

	* it parses as an inference rule with signature "$ - $ \\C $";
	* its three formulas are type 44/45 nodes [header, a, op, b] whose
	  parts are all strings;
	* it has the shape  x R y, y S z  |-  x T z  with x != z.

	Recording a rule means:

	* db[synt.MD_TRMUL][(R, S)] = T;
	* the precedence levels db[synt.MD_PRECED][R] and [S] are appended to
	  db[synt.MD_TRPRS] if they are not already present.

	Raises:
		SystemExit: with status 1, after printing a message, if *db* is
			shorter than synt.MD_LEN.
	"""
	if len(db) < synt.MD_LEN:
		print("dfs file obsolete")
		# Non-zero status: a bare SystemExit would report success.
		raise SystemExit(1)
	with open(propfilename, "r") as f:
		line_list = f.readlines()
	if len(entry) == 1:
		first_line = 1
		last_line = len(line_list)
	else:
		first_line = entry[1]
		last_line = entry[2]
	for i, r in enumerate(line_list):
		if i < first_line - 1 or i >= last_line:
			continue
		if not pattern.inference_rule.match(r):
			continue
		infrule = Rule(r)
		if not infrule.signature == ['$', '-', '$', '\\C', '$']:
			continue
		xy_form = infrule.body[0]
		yz_form = infrule.body[2]
		xz_form = infrule.body[4]
		if not (xy_form[0][0] == yz_form[0][0] == xz_form[0][0] in [44, 45]):
			continue
		if not (len(xy_form) == len(yz_form) == len(xz_form) == 4):
			continue
		if not all(type(part) is str
				for form in (xy_form, yz_form, xz_form) for part in form[1:]):
			continue
		x, rel_xy, y = xy_form[1:]
		y2, rel_yz, z = yz_form[1:]
		x2, rel_xz, z2 = xz_form[1:]
		# Require x R y, y S z |- x T z with x != z.  The middle-operand test
		# (y == y2) was missing, so non-transitive rules were recorded too.
		if x == z2 or x != x2 or z != z2 or y != y2:
			continue
		db[synt.MD_TRMUL][(rel_xy, rel_yz)] = rel_xz
		for rel in (rel_xy, rel_yz):
			level = db[synt.MD_PRECED][rel]
			if level not in db[synt.MD_TRPRS]:
				db[synt.MD_TRPRS].append(level)
			

class Rule:
	"""An inference rule parsed from one line of a rules file.

	Attributes:
		body: the rule's signature in which every '$' is replaced by the
			deep() form of the corresponding formula.  When that deep form
			is a node of type 10 or 11, its entry [1] is stored instead.
		allvars: the names reported by synt.varlist() over the deep
			formulas, deduplicated, in order of first appearance.
		signature: a copy of the '$'/punctuator list from ruleparse().
		filename, linenum: initialised to '' and 0; getrules() fills them in.
		tracefun: initialised to ''.
	"""

	def __init__(self, textline, raw=False):
		"""Parse *textline* with ruleparse(); *raw* is passed on to deep()."""
		formulas, shallow_signature = ruleparse(textline)
		signature = list(shallow_signature)
		body = []
		allvars = []
		position = 0
		for token in signature:
			if token != '$':
				body.append(token)
				continue
			d = deep(formulas[position], raw)
			position += 1
			is_wrapped = type(d) is list and d[0][0] in [10, 11]
			body.append(d[1] if is_wrapped else d)
			for name in synt.varlist(d):
				if name not in allvars:
					allvars.append(name)

		self.body = body
		self.allvars = allvars
		self.signature = signature
		self.filename = ''
		self.linenum = 0
		self.tracefun = ''

	def premise_list(self):
		"""Return the body entries that precede the '\\C' marker."""
		cut = self.body.index('\\C')
		return self.body[:cut]

def ruleparse(textline):
	"""Split a rule line into its TeX formulas and its signature.

	Each $...$ segment matched by pattern.TeXmath is parsed with
	synt.stringparse() and contributes '$' to the signature.  The text
	between segments is split into punctuators with pattern.puncts and
	pattern.findsingle, and each punctuator is appended to the signature.
	Except for a single character split off at position 0, punctuators
	must appear in synt.reference_punctuator_list.

	Returns:
		(rule, rulesignature): the parsed formulas in order, and the
		signature list, which holds one '$' per formula.

	Raises:
		ValueError: on an unmatched '$', a formula that stringparse()
			rejects (returns 0), or text that is not a valid punctuator.
	"""
	rule = []
	rulesignature = []
	t = textline
	while t:
		t = t.lstrip()
		if not t:
			break
		if t[0] == '$':
			TeXmatch = pattern.TeXmath.match(t)
			if not TeXmatch:
				raise ValueError(" Unmatched Tex dollar sign")
			rvar = synt.stringparse(TeXmatch.group(1))
			if rvar == 0:
				raise ValueError(" Bad rule: " + TeXmatch.group(1) + " in " + textline)
			rule.append(rvar)
			rulesignature.append('$')
			t = TeXmatch.group(2)
			continue
		punctsmatch = pattern.puncts.match(t)
		if not punctsmatch:
			# Previously crashed with AttributeError on None.group().
			raise ValueError(t + " invalid punctuator")
		puncts = punctsmatch.group(1)
		findsinglematch = pattern.findsingle.match(puncts)
		if findsinglematch:
			if findsinglematch.start(1) == 0:
				rulesignature.append(puncts[0])
				t = t[1:]
			elif findsinglematch.group(1) in synt.reference_punctuator_list:
				rulesignature.append(findsinglematch.group(1))
				t = t[findsinglematch.start(2):]
			else:
				raise ValueError(findsinglematch.group(1) + " invalid punctuator")
		else:
			u = punctsmatch.group(2)
			if puncts in synt.reference_punctuator_list:
				rulesignature.append(puncts)
				t = u
			else:
				raise ValueError(puncts + " invalid punctuator")
	assert len(rule) == rulesignature.count('$')
	return (rule, rulesignature)


# Module-level rule list.
# NOTE(review): getrules() builds and returns its own local list and does not
# update this one, and nothing in this module assigns to it afterwards.
# Confirm whether other modules rely on it.
rulelist = []

def getrules(filelist, line_num = None, raw=False):
	"""Read and parse the inference rules in the files of *filelist*.

	Each file name is resolved with getpath().  A line is parsed into a
	Rule when it contains '$' and matches pattern.inference_rule; *raw* is
	passed on to deep().  Each Rule is tagged with its file name and 1-based
	line number.

	If *line_num* is given, only rules on that line number (in any file)
	are returned.  Every rule line is still parsed, so a parse error on
	another line still aborts.

	Returns:
		A new list of Rule objects in file and line order.

	Raises:
		SystemExit: with status 1, after printing a message, when a file
			cannot be found or a rule fails to parse.
	"""
	rulelist = []

	for rulefilename in filelist:
		rulepathname = getpath(rulefilename)
		if not rulepathname:
			print("Rules file:",rulefilename,"not found.")
			# Non-zero status: a bare SystemExit would report success.
			raise SystemExit(1)
		with open(rulepathname, "r") as f:
			rule_file_lines = f.readlines()
		for linenum, r in enumerate(rule_file_lines, start=1):
			if '$' not in r:
				continue
			if not pattern.inference_rule.match(r):
				continue
			try:
				parsed_rule = Rule(r, raw)
			except Exception as e:
				print(e)
				print("Error in ", rulefilename, ", line: ", linenum)
				raise SystemExit(1) from e

			parsed_rule.filename = rulefilename
			parsed_rule.linenum = linenum
			if not line_num or linenum == line_num:
				rulelist.append(parsed_rule)

	return rulelist

