Skip to content

Commit

Permalink
Merge pull request #78 from rmodrak/devel
Browse files Browse the repository at this point in the history
adjustments for latest version of SLURM
  • Loading branch information
rmodrak authored Sep 25, 2017
2 parents 61ea626 + f82fde8 commit 91c2b2b
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 69 deletions.
5 changes: 1 addition & 4 deletions scripts/visualize/specfem2d/quickplot
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,7 @@ if __name__ == '__main__':
""" USAGE: quickplot DIR [ARGS]
Plots all SPECFEM2D model or kernel files in given directory DIR.
Any additional arguments ARGS are simply passed to PLOTGLL and not used
directly by QUICKPLOT
Any additional ARGS are simply passed to PLOTGLL.
"""

exe = 'python -W ignore '+which('plotgll')
Expand All @@ -95,7 +93,6 @@ if __name__ == '__main__':
"usage: quickplot DIR [ARGS]")

dir = parse_args()

if not exists(dir):
raise IOError(
"directory does not exit: %s" % dir)
Expand Down
24 changes: 15 additions & 9 deletions seisflows/postprocess/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@


class base(object):
""" Gradient postprocessing class
""" Postprocessing base class
Postprocesing refers to image processing and regularization operations on
models or gradients
"""

def check(self):
Expand Down Expand Up @@ -95,19 +98,22 @@ def process_kernels(self, path='', parameters=[]):
parameters = solver.parameters

if PAR.SMOOTH > 0:
suffix = '_nosmooth'

solver.combine(
input_path=path,
output_path=path+'/'+'sum'+suffix,
parameters=parameters)
solver.combine(
input_path=path,
output_path=path+'/'+'sum_nosmooth',
parameters=parameters)

if PAR.SMOOTH > 0.:
solver.smooth(
input_path=path+'/'+'sum'+suffix,
input_path=path+'/'+'sum_nosmooth',
output_path=path+'/'+'sum',
parameters=parameters,
span=PAR.SMOOTH)
else:
solver.combine(
input_path=path,
output_path=path+'/'+'sum',
parameters=parameters)



def save(self, g, path='', parameters=[], backup=None):
Expand Down
17 changes: 17 additions & 0 deletions seisflows/postprocess/default.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@

import sys

from seisflows.config import custom_import

PAR = sys.modules['seisflows_parameters']
PATH = sys.modules['seisflows_paths']


class default(custom_import('postprocess', 'base')):
    """ Default postprocessing option

    Provides default image processing and regularization functions for models
    or gradients. Currently identical to the base class: it exists so that
    'default' can be selected by name through custom_import, and as a hook
    for future default-specific behavior.
    """
    # currently identical to base class
    pass
3 changes: 3 additions & 0 deletions seisflows/preprocess/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

class base(object):
""" Data preprocessing class
Provides data processing functions for seismic traces, with options for
data misfit, filtering, normalization and muting
"""

def check(self):
Expand Down
18 changes: 18 additions & 0 deletions seisflows/preprocess/default.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@

import sys

from seisflows.config import custom_import

PAR = sys.modules['seisflows_parameters']
PATH = sys.modules['seisflows_paths']


class default(custom_import('preprocess', 'base')):
    """ Default preprocessing class

    Provides data processing functions for seismic traces, with options for
    data misfit, filtering, normalization and muting. Currently identical to
    the base class: it exists so that 'default' can be selected by name
    through custom_import, and as a hook for future default-specific behavior.
    """
    # currently identical to base class
    pass

15 changes: 8 additions & 7 deletions seisflows/preprocess/double_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@


class double_difference(custom_import('preprocess', 'base')):
""" Data preprocessing class
""" Double-difference data processing class
Adds double-difference data misfit functions to base class
"""

def check(self):
Expand Down Expand Up @@ -48,22 +50,21 @@ def write_residuals(self, path, syn, dat):
nr, _ = self.get_network_size(syn)
rx, ry, rz = self.get_receiver_coords(syn)

# calculate distances between stations
dist = np.zeros((nr,nr))
count = np.zeros(nr)
delta_syn = np.zeros((nr,nr))
delta_obs = np.zeros((nr,nr))

# calculate distances between stations
for i in range(nr):
for j in range(i):
dist[i,j] = self.distance(rx[i], rx[j], ry[i], ry[j])

# calculate traveltime differences between stations
delta_syn = np.zeros((nr,nr))
delta_obs = np.zeros((nr,nr))

# calculate traveltime lags between stations pairs
for i in range(nr):
for j in range(i):
if dist[i,j] > PAR.DISTMAX:
continue

delta_syn[i,j] = self.misfit(syn[i].data, syn[j].data, nt, dt)
delta_obs[i,j] = self.misfit(dat[i].data, dat[j].data, nt, dt)
delta_syn[j,i] = -delta_syn[i,j]
Expand Down
105 changes: 56 additions & 49 deletions seisflows/system/slurm_lg.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,22 @@ def submit(self, workflow):

def run(self, classname, method, hosts='all', **kwargs):
""" Executes the following task:
classname.method(*args, **kwargs)
classname.method(**kwargs)
"""
self.checkpoint()

self.save_kwargs(classname, method, kwargs)
jobs = self.submit_job_array(classname, method, hosts)

# submit job array
stdout = check_output(
self.job_array_cmd(classname, method, hosts),
shell=True)

# keep track of job ids
jobs = self.job_id_list(stdout, hosts)

# check job array completion status
while True:
# wait a few seconds before checking again
# wait a few seconds between queries
time.sleep(5)

isdone, jobs = self.job_array_status(classname, method, jobs)
Expand All @@ -153,67 +161,55 @@ def taskid(self):
""" Provides a unique identifier for each running task
"""
try:
return int(os.getenv('SEISFLOWS_TASKID'))
except:
return int(os.getenv('SLURM_ARRAY_TASK_ID'))
except:
raise Exception("TASK_ID environment variable not defined.")


### job array methods

def submit_job_array(self, classname, method, hosts='all'):
""" Submits job array and returns associated job ids
"""
# submit job array
cmd = self.job_array_cmd(classname, method, hosts)
stdout = check_output(cmd, shell=True)

# construct job id list
id = stdout.split()[-1].strip()
if hosts=='all':
tasks = range(PAR.NTASK)
jobs = [id+'_'+str(task) for task in tasks]
else:
jobs = [id+'_0']
return jobs


def job_array_cmd(self, classname, method, hosts):
return ('sbatch '
+ '%s ' % PAR.SLURMARGS
+ '--job-name=%s ' % PAR.TITLE
+ '--nodes=%d ' % math.ceil(PAR.NPROC/float(PAR.NODESIZE))
+ '--ntasks-per-node=%d ' % PAR.NODESIZE
+ '--ntasks=%d ' % PAR.NPROC
+ '--time=%d ' % PAR.TASKTIME
+ self.job_array_args(hosts)
+ findpath('seisflows.system') +'/'+ 'wrappers/run '
+ PATH.OUTPUT + ' '
+ classname + ' '
+ method + ' '
+ PAR.ENVIRONS)


def job_array_args(self, hosts):
if hosts == 'all':
args = ('--array=%d-%d ' % (0,PAR.NTASK-1%PAR.NTASKMAX)
+'--output %s ' % (PATH.WORKDIR+'/'+'output.slurm/'+'%A_%a'))
return ('sbatch %s ' % PAR.SLURMARGS
+ '--job-name=%s ' % PAR.TITLE
+ '--nodes=%d ' % math.ceil(PAR.NPROC/float(PAR.NODESIZE))
+ '--ntasks-per-node=%d ' % PAR.NODESIZE
+ '--ntasks=%d ' % PAR.NPROC
+ '--time=%d ' % PAR.TASKTIME
+ '--array=%d-%d ' % (0,PAR.NTASK-1%PAR.NTASKMAX)
+ '--output %s ' % (PATH.WORKDIR+'/'+'output.slurm/'+'%A_%a')
+ '%s ' % (findpath('seisflows.system') +'/'+ 'wrappers/run')
+ '%s ' % PATH.OUTPUT
+ '%s ' % classname
+ '%s ' % method
+ '%s ' % PAR.ENVIRONS)

elif hosts == 'head':
args = ('--array=%d-%d ' % (0,0)
+'--output=%s ' % (PATH.WORKDIR+'/'+'output.slurm/'+'%j'))
return ('sbatch %s ' % PAR.SLURMARGS
+ '--job-name=%s ' % PAR.TITLE
+ '--nodes=%d ' % math.ceil(PAR.NPROC/float(PAR.NODESIZE))
+ '--ntasks-per-node=%d ' % PAR.NODESIZE
+ '--ntasks=%d ' % PAR.NPROC
+ '--time=%d ' % PAR.TASKTIME
+ '--array=%d-%d ' % (0,0)
+ '--output %s ' % (PATH.WORKDIR+'/'+'output.slurm/'+'%A_%a')
+ '%s ' % (findpath('seisflows.system') +'/'+ 'wrappers/run')
+ '%s ' % PATH.OUTPUT
+ '%s ' % classname
+ '%s ' % method
+ '%s ' % PAR.ENVIRONS
+ '%s ' % 'SEISFLOWS_TASKID=0')

else:
raise KeyError('Bad keyword argument: system.run: hosts')

return args


def job_array_status(self, classname, method, jobs):
""" Determines completion status of one or more jobs
""" Determines completion status of job array
"""
states = []
for job in jobs:
state = self._query(job)
state = self.job_status(job)
if state in ['TIMEOUT']:
print msg.TimoutError % (classname, method, job, PAR.TASKTIME)
sys.exit(-1)
Expand All @@ -230,8 +226,18 @@ def job_array_status(self, classname, method, jobs):
return isdone, jobs


def _query(self, job):
""" Queries job state from SLURM database
def job_id_list(self, stdout, hosts):
    """ Parses job id list from sbatch standard output

    sbatch reports something like "Submitted batch job <id>"; the job id is
    the final whitespace-separated token. For an array submission
    (hosts == 'all') one id per array element is returned; otherwise only
    the single head-node element.
    """
    base_id = stdout.split()[-1].strip()
    if hosts == 'all':
        # one SLURM array element per task: <jobid>_0 .. <jobid>_NTASK-1
        return ['%s_%d' % (base_id, task) for task in range(PAR.NTASK)]
    # single job running on the head node only
    return [base_id + '_0']


def job_status(self, job):
""" Queries completion status of a single job
"""
stdout = check_output(
'sacct -n -o jobid,state -j '+ job.split('_')[0],
Expand All @@ -251,3 +257,4 @@ def save_kwargs(self, classname, method, kwargs):
unix.mkdir(kwargspath)
saveobj(kwargsfile, kwargs)


1 change: 1 addition & 0 deletions tests/setup/test_import
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ if not matplotlib:

if not obspy:
excluded += ['seisflows.preprocess.base']
excluded += ['seisflows.preprocess.default']
excluded += ['seisflows.preprocess.double_difference']


Expand Down

0 comments on commit 91c2b2b

Please sign in to comment.