import sys
import string
from main import *
from cmdint import *

class FindK:
    """Fit the model constant 'k' of 'massspring.sim' to measured data.

    The constant 'k' is swept over a uniform grid (presumably the spring
    stiffness -- confirm against the model). For each value the model is
    simulated, and the output of block 'Add_1' is compared with the
    measured samples by sum of squared errors.

    Attributes:
        cmdint: command interpreter driving the simulation.
        kmin, kmax, kstepsize: the k grid kmin, kmin+kstepsize, ..., kmax.
        measureddata: measured samples (second tab-separated column of the file).
        totaldata: per-k lists of simulated 'Add_1' outputs (filled by run()).
        squareerror: per-k sum of squared errors (filled by run()).
    """

    def __init__(self, measuredfile='MassSpring.measured.x.txt'):
        """Load and configure the model, then read the measured samples.

        Args:
            measuredfile: tab-separated text file. Its first two lines are
                headers and the second column of every other non-blank
                line is a measured sample.
        """
        self.cmdint = CmdInt()
        self.cmdint._user_loadmodel('massspring.sim')

        self.cmdint._user_setsimparam('step_size', 0.01)
        self.cmdint._user_setsimparam('comm_int', 0.06)
        self.cmdint._user_setsimparam('end_t', 4.0)

        self.cmdint._user_setconstval('RestLength', 0.2)
        self.cmdint._user_setconstval('mass', 0.23)
        self.cmdint._user_setconstval('Const_2', 0.1)

        self.kmin = 1.0
        self.kmax = 10.0
        self.kstepsize = 0.01
        self.totaldata = []
        self.squareerror = []

        with open(measuredfile, 'r') as f:
            self.measureddata = self._parsemeasured(f)

    @staticmethod
    def _parsemeasured(lines, skip=2):
        """Return the second tab-separated field of each data line as a float.

        The first *skip* lines are headers. Blank (whitespace-only) lines are
        ignored instead of re-using the previous line's value.
        """
        values = []
        for lineno, line in enumerate(lines):
            if lineno < skip or not line.strip():
                continue
            # maxsplit=2: only the first two fields matter.
            values.append(float(line.split('\t', 2)[1]))
        return values

    def _kvalues(self):
        """Return the k grid kmin, kmin+kstepsize, ..., up to kmax inclusive."""
        span = (self.kmax - self.kmin) / self.kstepsize
        if span < 0:
            return []
        # Derive each value from its index rather than by repeated addition,
        # which accumulates float error and can drop the kmax endpoint. The
        # small epsilon keeps kmax when span is e.g. 899.9999999999999.
        nsteps = int(span + 1e-9)
        return [self.kmin + i * self.kstepsize for i in range(nsteps + 1)]

    def run(self):
        """Sweep k, score every run, write the 'ktest' report, print the best k.

        Raises:
            ValueError: if the k grid is empty, or if a simulated run has more
                samples than the measured data.
        """
        # Reset results so index i always corresponds to kvalues[i], even
        # when run() is called more than once.
        self.totaldata = []
        self.squareerror = []

        kvalues = self._kvalues()
        if not kvalues:
            raise ValueError('empty k sweep: kmax is below kmin')

        index = self.cmdint.solver.blocknameindexdictionary['Add_1']
        for k in kvalues:
            self.cmdint._user_setconstval('k', k)
            self.cmdint._user_go()
            # Keep only the 'Add_1' output of every recorded step of this run.
            self.totaldata.append(
                [stepdata[index] for stepdata in self.cmdint.solver.totaloutputdata])

        self.squareerror = [self.calsqrerr(each) for each in self.totaldata]

        self._writereport(kvalues)

        # min() returns the first minimum, so ties resolve to the smallest k.
        best = min(range(len(self.squareerror)), key=self.squareerror.__getitem__)
        sys.stdout.write('the optimal k is  %s\n' % kvalues[best])
        sys.stdout.write('the optimal error is %s\n' % self.squareerror[best])

    def _writereport(self, kvalues, path='ktest'):
        """Write a header plus one 'k<TAB>error' row per swept k to *path*."""
        with open(path, 'w') as f:
            f.write('k value\tsum of square error\n')
            for k, err in zip(kvalues, self.squareerror):
                f.write('%s\t%s\n' % (k, err))

    def calsqrerr(self, onepassdata):
        """Return the sum of squared differences from the measured data.

        Sample i of *onepassdata* is compared with ``self.measureddata[i]``.

        Raises:
            ValueError: if *onepassdata* has more samples than were measured.
        """
        if len(onepassdata) > len(self.measureddata):
            raise ValueError(
                'simulated run has %d samples but only %d measured samples'
                % (len(onepassdata), len(self.measureddata)))
        return sum((sim - meas) * (sim - meas)
                   for sim, meas in zip(onepassdata, self.measureddata))
		
if '__main__' == __name__:
    # Script entry point: build the fitter (loads the model and the measured
    # samples), then sweep k and report the best fit.
    findoptk = FindK()
    findoptk.run()
	
	
