#! /usr/bin/env python
import sys
import os

def extract_literals(body):
	"""Split a rule head or body into its top-level literals.

	The text is cut at every comma that is nested neither inside
	parentheses nor inside curly braces, so p(a,b) and cardinality
	constraints such as 1{p(a),p(b)}1 stay in one piece.  Any trailing
	full stops and whitespace (including CR/LF line endings) are removed
	first, and every literal is returned stripped of surrounding blanks.

	A trailing segment whose brackets do not balance is still returned,
	so that no program text is silently lost.  A blank body yields [""].
	"""
	# Stripping whitespace as well as "." matters for lines such as
	# "p(a). \n" or "p(a).\r\n"; otherwise "p(a)." would survive here.
	body=body.rstrip(". \t\r\n")
	atoms=[]
	paren=0
	brace=0
	start=0
	for pos,ch in enumerate(body):
		if ch=="(":
			paren+=1
		elif ch==")":
			paren-=1
		elif ch=="{":
			brace+=1
		elif ch=="}":
			brace-=1
		elif ch=="," and paren==0 and brace==0:
			atoms.append(body[start:pos].strip())
			start=pos+1
	atoms.append(body[start:].strip())
	return atoms

def extract_atoms(body):
	"""Return the atoms occurring in a rule head or body.

	body is cut into top-level pieces at commas not nested in parentheses.
	For every piece, only the text after the first "{" and before the
	first "}" is kept.  The piece is then split on both ':' (conditions)
	and '|' (disjunction).  Each resulting part has surrounding whitespace,
	a leading default negation "not " and any leading classical negation
	'-' removed before it is returned.

	A trailing piece whose parentheses do not balance is dropped.
	"""
	body=body.rstrip(". \t\r\n")
	# Braces are deliberately not tracked here: 1{p(a),p(b)}1 is cut into
	# "1{p(a)" and "p(b)}1", and the brace trimming below recovers
	# p(a) and p(b).
	pieces=[]
	depth=0
	start=0
	for pos,ch in enumerate(body):
		if ch=="(":
			depth+=1
		elif ch==")":
			depth-=1
		elif ch=="," and depth==0:
			pieces.append(body[start:pos])
			start=pos+1
	if depth==0:
		pieces.append(body[start:])

	atoms=[]
	for lit in pieces:
		if lit.find("{")!=-1:
			lit=lit.split("{")[1]
		if lit.find("}")!=-1:
			lit=lit.split("}")[0]
		# ':' and '|' both separate atoms.  A piece with neither separator
		# yields itself.
		for part in lit.replace("|",":").split(":"):
			# Strip first: body literals usually follow ", " or ":- ".
			part=part.strip()
			if part.startswith("not "):
				part=part[4:].strip()
			atoms.append(part.lstrip("-"))
	return atoms

def extract_constants(literal):
	"""Collect the argument terms of a list of literal strings.

	When literal does not have exactly one entry, the terms of every entry
	are gathered one entry at a time and concatenated in order.  For a
	single entry lit:
	  * lit contains "==" (checked first) or "!=": the two sides, unstripped;
	  * lit has no "(": literal itself (the very list that was passed in);
	  * lit has "{": the terms of the atoms between the first "{" and the
	    last "}";
	  * otherwise: the terms of the atoms between the first "(" and the
	    last ")".
	For every argument that itself contains "(", a "%error: ..." comment
	line is printed before descending into it.
	"""
	if len(literal)!=1:
		collected=[]
		for item in literal:
			collected=collected+extract_constants([item])
		return collected

	lit=literal[0]
	for op in ("==","!="):
		if op in lit:
			return lit.split(op)
	if "(" not in lit:
		return literal

	if "{" in lit:
		inner=lit[lit.find("{")+1:lit.rfind("}")]
		return extract_constants(extract_atoms(inner))

	inner=lit[lit.find("(")+1:lit.rfind(")")]
	args=extract_atoms(inner)
	for arg in args:
		if "(" in arg:
			print("%error: find constants "+arg+", leading to infinite domain! Only non-nested constants will be extracted.")
	return extract_constants(args)

def get_pred(head):
	"""Return [name, arity] for the atom head, or [] when there is none.

	[] is returned when head has no "(", or when the stripped name contains
	"==" or "!=" or equals "true"/"false".  Only the first alternative of a
	pooled argument list is inspected: ";;" is checked before ";".  The
	arity is the number of argument separators at parenthesis depth zero,
	where the end of the list also counts as one, so "p()" has arity 1.
	"""
	open_at=head.find("(")
	if open_at==-1:
		return []
	name=head[:open_at].strip()
	inner=head[open_at+1:head.rfind(")")]
	for sep in (";;",";"):
		if sep in inner:
			inner=inner.split(sep)[0]
			break

	arity=0
	depth=0
	for ch in inner+",":
		if ch=="(":
			depth+=1
		elif ch==")":
			depth-=1
		elif ch=="," and depth==0:
			arity+=1

	if ("==" in name) or ("!=" in name) or (name in ("true","false")):
		return []
	return [name,arity]


def extract_objects(cons_list):
	"""Return the distinct object constants found in cons_list.

	cons_list is a list of lists of term strings.  Pooled terms are split on
	";;" if present, otherwise on ";".  Each piece is stripped.  A piece is
	kept, in first-seen order, when it is neither empty nor "true"/"false"
	and it either starts with a lowercase letter or consists only of digits.
	"""
	flat=sum(cons_list,[])
	pieces=[]
	for cons in flat:
		if ";;" in cons:
			pieces.extend(cons.split(";;"))
		elif ";" in cons:
			pieces.extend(cons.split(";"))
		else:
			pieces.append(cons)

	obj_list=[]
	for piece in pieces:
		piece=piece.strip()
		if piece in ("true","false") or not piece:
			continue
		if (piece[0].islower() or piece.isdigit()) and piece not in obj_list:
			obj_list.append(piece)
	return obj_list
				
def print_objects(obj_list):
	"""Print the fact u(<obj>). for every object, then the #hide u/1. directive."""
	for name in obj_list:
		print("u("+name+").")
	print("#hide u/1.")


def print_domain_variables(m):
	"""Print #domain u(V_EQ_Xi). and #domain u(V_EQ_Yi). for i = 1..m.

	Nothing is printed when m < 1.  No V_EQ_Zi declarations are emitted.
	"""
	for idx in range(1,m+1):
		suffix=str(idx)
		for var in ("V_EQ_X","V_EQ_Y"):
			print("#domain u("+var+suffix+").")

def remove_dup(pred_list):
	"""Return pred_list without empty entries, nameless entries or repeats.

	First-seen order is kept.  An empty pred_list is returned as the very
	same object.
	"""
	if not len(pred_list):
		return pred_list
	unique=[]
	for pred in pred_list:
		if pred and len(pred[0])>0 and pred not in unique:
			unique.append(pred)
	return unique


def print_eq_rules(pred_list):
	"""Print the eq/2 substitution rules.

	For every [name, arity] entry, one rule is printed.  It derives
	name(V_EQ_Y1..) from name(V_EQ_X1..) together with eq(V_EQ_Xi,V_EQ_Yi)
	for every argument position.  Its tokens are separated by single
	spaces.  Four fixed eq/2 rules follow: reflexivity, symmetry,
	transitivity, and an unrestricted choice over eq/2.
	"""
	for pred in pred_list:
		name=pred[0]
		arity=pred[1]
		lead=range(1,arity)
		words=[name+"("]
		words.extend(["V_EQ_Y"+str(i)+"," for i in lead])
		words.append("V_EQ_Y"+str(arity)+"):-"+name+"(")
		words.extend(["V_EQ_X"+str(i)+"," for i in lead])
		words.append("V_EQ_X"+str(arity)+"),")
		words.extend(["eq(V_EQ_X"+str(i)+",V_EQ_Y"+str(i)+")," for i in lead])
		words.append("eq(V_EQ_X"+str(arity)+",V_EQ_Y"+str(arity)+").")
		print(" ".join(words))
	for rule in ("eq(V_EQ_X1,V_EQ_X1).",
			"eq(V_EQ_X1,V_EQ_Y1):- eq(V_EQ_Y1,V_EQ_X1).",
			"eq(V_EQ_X1,V_EQ_Z1):- eq(V_EQ_X1,V_EQ_Y1), eq(V_EQ_Y1, V_EQ_Z1).",
			"{eq(V_EQ_X1,V_EQ_Y1)}."):
		print(rule)

def extract_db_objects(cons_list,null_list):
	"""Return, in order, the items of cons_list that do not occur in null_list."""
	return [item for item in cons_list if item not in null_list]

def print_db_objects(db_objects):
	"""Print a _db(<obj>). fact per object, then the _db domain directives.

	The constraint that is printed rules out eq/2 between two distinct _db
	objects.
	"""
	for obj in db_objects:
		print("_db("+obj+").")
	for directive in ("#domain _db(V_DB_X1).",
			"#domain _db(V_DB_X2).",
			":- eq(V_DB_X1,V_DB_X2),V_DB_X1!=V_DB_X2.",
			"#hide _db/1.\t"):
		print(directive.rstrip("\t"))

#main function begins here
#
# Usage: nonH.py [inputfile] [-una|-no-una] [list of object constants]
# The program is read from inputfile, or from stdin when the first argument
# is not an existing file.  It is written to stdout with every literal A==B
# rewritten to eq(A,B), preceded by the u/1, _db/1, #domain and eq/2 rules.
#   -una c1 c2 ...     the listed constants become _db objects;
#   -no-una c1 c2 ...  every collected constant except the listed ones
#                      becomes a _db object;
#   -una (no list)     the program is echoed unchanged;
#   no option          no _db objects.
# Command-line errors are reported on stderr with exit status 1.

USAGE="Command line format: nonH.py inputfile [-una|-no-una] [list of object constants]."

pred_list=[]
cons_list=[]
null_list=[]
db_objects=[]
obj_list=[]

def _fail(message):
	"""Report a command-line error on stderr and exit with status 1."""
	sys.stderr.write(message+"\n")
	sys.exit(1)

def _rewrite_equalities(literals):
	"""Return literals with every `A==B` literal replaced by `eq(A,B)`."""
	result=[]
	for literal in literals:
		if literal.find("==")!=-1:
			part=literal.split("==")
			literal="eq("+part[0].strip()+","+part[1].strip()+")"
		result.append(literal)
	return result

# The option and its constants follow the input file when one is given,
# otherwise they start at argv[1].  The old code read them from argv[2]
# in the stdin case as well.
if len(sys.argv)>1 and os.path.isfile(sys.argv[1]):
	inputfile=sys.argv[1]
	options=sys.argv[2:]
else:
	inputfile=None
	options=sys.argv[1:]

option=None
constants=[]
if options:
	option=options[0]
	constants=options[1:]
	if (option!="-una") and (option!="-no-una"):
		_fail("Option "+option+" is not valid.\n"+USAGE)
	if (option=="-no-una") and not constants:
		_fail("No constant list specified for which UNA doesn't hold!")

# The program is scanned twice, so read it once into memory.  This replaces
# the former '_tmp' copy of stdin in the working directory.
if inputfile is None:
	lines=sys.stdin.readlines()
else:
	f=open(inputfile,'r')
	try:
		lines=f.readlines()
	finally:
		f.close()

if (option=="-una") and not constants:
	print("%Unique names assumed for all object constants. Nothing to be done")
	for line in lines:
		print(line.rstrip("\n"))	# the line already carries its newline
	sys.exit()

# Pass 1: collect predicates with their arities, and the constants they use.
for line in lines:
	line=line.strip()
	if (len(line)==0) or (line[0]=="#") or (line[0]=="%"):
		continue
	if (line=="true.") or (line==":-false."):	# line is already stripped
		continue
	parts=line.split(":-")
	atoms=[]
	if len(parts[0])>0:
		atoms=extract_atoms(parts[0].lstrip("-"))	# drop classical negation
	if len(parts)>1:
		atoms=atoms+extract_atoms(parts[1])
	for atom in atoms:
		pred_arity=get_pred(atom)
		if len(pred_arity)!=0:
			pred_list.append(pred_arity)
			cons_list.append(extract_constants([atom]))

obj_list=extract_objects(cons_list)
if option=="-una":
	db_objects=constants
elif option=="-no-una":
	null_list=constants
	db_objects=extract_db_objects(obj_list,null_list)
else:
	db_objects=[]

print_objects(obj_list)
print_db_objects(db_objects)

# Highest arity among the collected predicates, or -1 when there are none.
m=max([-1]+[t[1] for t in pred_list if len(t)!=0])
print_domain_variables(m)
pred_list=remove_dup(pred_list)
print_eq_rules(pred_list)

# Pass 2: echo the program, rewriting every literal A==B to eq(A,B).
for line in lines:
	text=line.strip()
	if len(text)==0:
		# Blank or whitespace-only line; formerly this printed a stray ".".
		print("")
		continue
	if (text[0]=='#') or (text[0]=="%"):
		print(line.rstrip("\n"))
		continue
	ar=line.split(":-")
	if len(ar[0])!=0:
		head_literals=_rewrite_equalities(extract_literals(ar[0]))
	else:
		head_literals=[]
	if len(ar)>1:
		body_literals=_rewrite_equalities(extract_literals(ar[1]))
	else:
		body_literals=[]
	if len(body_literals)!=0:
		print("|".join(head_literals)+":-"+",".join(body_literals)+".")
	elif len(head_literals)!=0:
		print("|".join(head_literals)+".")

