-
Notifications
You must be signed in to change notification settings - Fork 0
/
spline1.jl
275 lines (213 loc) · 6.81 KB
/
spline1.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
# This file contains code, adapted from Numerical Recipes, for computing cubic spline
# interpolation
include("io1.jl")
# The spline type: n is the number of grid points, x is the set of grid points,
# y is the set of values of the function on the grid points, and ydp is the set of
# second derivatives on the grid points, constructed as part of the cubic spline.
#
# The array fields are concrete one-dimensional `Vector{Float64}`s rather than the
# non-concrete `Array{Float64}` (whose dimension is left unspecified), so that reading
# the fields (e.g. in `interp`) is type-stable. The default constructor still converts
# any vector-like input (a range, a Vector{Int}, ...) to Vector{Float64}.
struct spline
    n::Int64
    x::Vector{Float64}
    y::Vector{Float64}
    ydp::Vector{Float64}
end
# Create an equally-spaced grid of npts points between xlow and xhigh, both endpoints
# included. The result is a Vector of floating-point values.
#
# `range(lo, stop=hi, length=npts)` is used rather than `xlow:xinc:xhigh` with
# `xinc = (xhigh-xlow)/(npts-1)`. The length-based form always produces exactly `npts`
# points whose first and last values are exactly `xlow` and `xhigh`. The step-based form
# derives its length and last point from a rounded increment. The endpoints are converted
# with `float` (and promoted to a common type) so integer bounds still give a
# floating-point grid.
function creategrid(xlow, xhigh, npts)
    lo, hi = promote(float(xlow), float(xhigh))
    # A one-point grid is just the lower endpoint. range() rejects a one-point request
    # whose endpoints differ.
    npts == 1 && return [lo]
    return collect(range(lo, stop=hi, length=npts))
end
# Expand a writearrays format specification into a flat tuple with one entry per output
# column. A plain entry is copied as-is. A 2-tuple (count, fmt) is replaced by `count`
# copies of `fmt`. Any other tuple is reported through `wait` and contributes nothing.
function _writearrays_expandformat(format)
    expanded = ()
    for spec in format
        if isa(spec, Tuple)
            if length(spec) != 2
                wait("Invalid format in writearrays.")
            else
                for _ in 1:spec[1]
                    expanded = (expanded..., spec[2])
                end
            end
        else
            expanded = (expanded..., spec)
        end
    end
    return expanded
end

# Function for writing a set of arrays according to a specified format to output io.
#
# Each row i (i = 1..length(arraylist[1])) is written as one line, with one column per
# array: column j is written as `writeio(io, fmt_j, arraylist[j][i], cr=...)`, and the
# line is ended (cr=true) only after the last array. When `writeindex` is true (the
# default), each row is prefixed by its index i, written with format[1]. format[1] must
# then be an Int64, and the remaining entries describe the array columns.
# Format entries may be abbreviated as (count, fmt), meaning `count` consecutive columns
# share `fmt`. After expansion there must be exactly one format per written column.
# What each format value means is up to `writeio` (defined outside this file, presumably
# in io1.jl).
#
# Problems are reported through `wait(...)`. NOTE(review): `wait` is presumably the
# message/pause helper from io1.jl; whether execution stops there should be confirmed.
# Returns nothing.
function writearrays(io, format, arraylist...; writeindex=true)
    nvars = length(arraylist)
    if writeindex
        if !isa(format[1], Int64)
            wait("First element of format is not an integer in writearrays.")
        end
        colformats = (format[1], _writearrays_expandformat(format[2:end])...)
    else
        colformats = _writearrays_expandformat(format)
    end
    # Position of the first array column inside colformats.
    offset = writeindex ? 1 : 0
    if length(colformats) != nvars + offset
        wait("Format length not correct in writearrays.")
    end
    # With no arrays there are no rows to write.
    nvars == 0 && return nothing
    for i in 1:length(arraylist[1])
        if writeindex
            writeio(io, colformats[1], i, cr=false)
        end
        for j in 1:nvars
            writeio(io, colformats[j + offset], arraylist[j][i], cr=(j == nvars))
        end
    end
    return nothing
end
# Function for solving a set of tridiagonal linear equations, with coefficients a, b, c,
# and r, as outlined in lecture notes from March 17.
#
# Solves the n×n system (n = length(a))
#     a[j]*u[j-1] + b[j]*u[j] + c[j]*u[j+1] = r[j],   j = 1..n,
# where a[1] and c[n] are never read. It uses forward elimination followed by back
# substitution (the Thomas algorithm, as in the Numerical Recipes routine `tridag`).
# No pivoting is done. Any pivot with |pivot| <= toler, the first one b[1] included,
# is reported through `wait`. NOTE(review): `wait` is presumably the message/pause helper
# from io1.jl, and execution may continue after it; confirm.
#
# Returns the solution u as a Vector{Float64} of length n.
function tridag(a, b, c, r)
    toler = 1.0e-12
    n = length(a)
    u = zeros(n)
    gam = zeros(n)
    bet = b[1]
    # The first pivot is b[1]; check it like every later pivot instead of silently
    # dividing by (nearly) zero.
    if abs(bet) <= toler
        wait("Failure in subroutine tridag.")
    end
    u[1] = r[1] / bet
    # Forward elimination: gam stores the eliminated super-diagonal terms needed later.
    for j in 2:n
        gam[j] = c[j-1] / bet
        bet = b[j] - a[j] * gam[j]
        if abs(bet) <= toler
            wait("Failure in subroutine tridag.")
        end
        u[j] = (r[j] - a[j] * u[j-1]) / bet
    end
    # Back substitution.
    for j in (n-1):-1:1
        u[j] -= gam[j+1] * u[j+1]
    end
    return u
end
# Function to create a cubic spline. fpts is the set of grid points and flevel is the
# set of values of the function on the grid points. makespline returns an object of type
# spline.
#
# The second derivatives ydp at the grid points solve the tridiagonal system built
# below (solved with `tridag`):
#   * interior rows (i = 2..npts-1) are the usual cubic-spline continuity equations
#       (x[i]-x[i-1])/6*ydp[i-1] + (x[i+1]-x[i-1])/3*ydp[i] + (x[i+1]-x[i])/6*ydp[i+1]
#         = (y[i+1]-y[i])/(x[i+1]-x[i]) - (y[i]-y[i-1])/(x[i]-x[i-1]);
#   * the end rows impose ydp[1] = ydp[2] and ydp[npts] = ydp[npts-1], i.e. the second
#     derivative is constant over the first and the last interval.
# With exactly two points those end conditions are the same equation (the system is
# singular), so ydp = 0 is used, which gives straight-line interpolation.
#
# fpts is assumed to be strictly increasing (not checked).
# Throws DimensionMismatch if fpts and flevel have different lengths.
function makespline(fpts, flevel)
    npts = length(fpts)
    if length(flevel) != npts
        throw(DimensionMismatch("makespline: fpts has $npts points but flevel has " *
                                "$(length(flevel)) values"))
    end
    if npts == 2
        return spline(npts, fpts, flevel, zeros(npts))
    end
    a = zeros(npts)
    b = zeros(npts)
    c = zeros(npts)
    r = zeros(npts)
    # End condition ydp[1] - ydp[2] = 0.
    a[1] = 0.0
    b[1] = 1.0
    c[1] = -1.0
    r[1] = 0.0
    # End condition -ydp[npts-1] + ydp[npts] = 0.
    a[npts] = -1.0
    b[npts] = 1.0
    c[npts] = 0.0
    r[npts] = 0.0
    for i in 2:npts-1
        a[i] = (fpts[i] - fpts[i-1]) / 6
        b[i] = (fpts[i+1] - fpts[i-1]) / 3
        c[i] = (fpts[i+1] - fpts[i]) / 6
        r[i] = (flevel[i+1] - flevel[i]) / (fpts[i+1] - fpts[i]) -
               (flevel[i] - flevel[i-1]) / (fpts[i] - fpts[i-1])
    end
    ydp = tridag(a, b, c, r)
    return spline(npts, fpts, flevel, ydp)
end
# Calculate interpolated values at x using the cubic spline embodied in yspline (of type
# spline). calcy, calcyp, and calcydp are optional arguments: if calcy = true, the level
# (y) at x is computed; if calcyp = true, the first derivative (yp) at x is computed; if
# calcydp = true, the second derivative (ydp) at x is computed. interp returns the tuple
# (y, yp, ydp); any quantity that was not requested is returned as 0.0.
#
# yspline.x is assumed to be sorted in increasing order. For x outside the grid, the
# cubic of the first or last interval is used, i.e. the spline is extrapolated.
function interp(x, yspline; calcy=true, calcyp=false, calcydp=false)
    xs = yspline.x
    ys = yspline.y
    ydps = yspline.ydp
    # Bisection for the bracketing interval [klo, khi] with khi = klo + 1. For x inside
    # the grid this gives xs[klo] <= x <= xs[khi].
    klo = 1
    khi = yspline.n
    while khi - klo > 1
        k = (khi + klo) ÷ 2
        if xs[k] > x
            khi = k
        else
            klo = k
        end
    end
    h = xs[khi] - xs[klo]
    # a and b are the linear-interpolation weights of the two bracketing grid points.
    a = (xs[khi] - x) / h
    b = (x - xs[klo]) / h
    asq = a * a
    bsq = b * b
    y = if calcy
        a * ys[klo] + b * ys[khi] +
            ((asq * a - a) * ydps[klo] + (bsq * b - b) * ydps[khi]) * (h * h) / 6
    else
        0.0
    end
    yp = if calcyp
        (ys[khi] - ys[klo]) / h -
            (3 * asq - 1.0) / 6 * h * ydps[klo] +
            (3 * bsq - 1.0) / 6 * h * ydps[khi]
    else
        0.0
    end
    ydp = calcydp ? a * ydps[klo] + b * ydps[khi] : 0.0
    return (y, yp, ydp)
end
# The rest of the code illustrates how to compute a cubic spline and how to calculate
# interpolated values using it. The example function f(x) is log(x). fp(x) is the first
# derivative of f at x and fdp(x) is the second derivative of f at x. When
# including spline1.jl in another program, you should delete all the code below.
# Example function f(x) = log(x) together with its exact first derivative 1/x and
# second derivative -1/x^2, used below to check the spline's accuracy.
f(x) = log(x)
fp(x) = 1.0 / x
fdp(x) = -1.0 / (x * x)
nx = 10
a = 1.0
b = 3.0
# Create the grid.
x = creategrid(a, b, nx)
# Calculate the function values on the grid.
y = f.(x)
# Compute the cubic spline using x and y.
yspline = makespline(x, y)
# Write one row per grid point: the index (format 5) followed by x, y and ydp (three
# columns sharing format 15.8). NOTE(review): the meaning of these format values is
# defined by writeio in io1.jl.
writearrays(stdout, (5, (3, 15.8)), yspline.x, yspline.y, yspline.ydp)
# NOTE(review): presumably the pause helper from io1.jl; confirm.
wait()
# Loop to calculate interpolated levels and first and second derivatives and compare them
# to the correct values. Entering a negative x ends the loop.
while true
    writeio(stdout, ("Enter x: ",), cr=false)
    xval = readio(stdin, 1)
    if xval < 0
        break
    end
    (yval, yvalp, yvaldp) = interp(xval, yspline, calcy=true, calcyp=true, calcydp=true)
    writeio(stdout, ((7, 15.8),), xval, yval, f(xval), yvalp, fp(xval), yvaldp, fdp(xval))
end
# Set up a fine grid and calculate interpolated values and actual (correct) values on
# this grid. ydiff is the difference between the interpolated and correct values. It is
# deliberately not called `diff`, which would shadow Base.diff for the rest of the
# session.
nxfine = 200
xfine = creategrid(a, b, nxfine)
ysplinefine = [interp(xv, yspline)[1] for xv in xfine]
yfine = f.(xfine)
ydiff = ysplinefine - yfine