-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #28 from kundajelab/hotfix
Bugfix for tensorflow backend
- Loading branch information
Showing
7 changed files
with
1,295 additions
and
931 deletions.
There are no files selected for viewing
1,282 changes: 690 additions & 592 deletions
1,282
examples/simulated_tf_binding/TF MoDISco TAL GATA.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
708 changes: 373 additions & 335 deletions
708
examples/simulated_tf_binding/With Hit Scoring TF MoDISco TAL GATA.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,80 @@ | ||
from __future__ import division, absolute_import, print_function | ||
from .tensorflow_backend import * | ||
import os | ||
import json | ||
import sys | ||
|
||
#This code is based on the keras backend code | ||
|
||
# Resolve the TF-MoDISco base directory. The TFMODISCO_HOME environment
# variable wins when it is set; otherwise use ~/.tfmodisco, falling back to
# /tmp/.tfmodisco when the home directory is not writable.
if 'TFMODISCO_HOME' in os.environ:
    _tfmodisco_dir = os.environ.get('TFMODISCO_HOME')
else:
    _tfmodisco_base_dir = os.path.expanduser('~')
    if not os.access(_tfmodisco_base_dir, os.W_OK):
        _tfmodisco_base_dir = '/tmp'
    _tfmodisco_dir = os.path.join(_tfmodisco_base_dir, '.tfmodisco')

# Default backend: TensorFlow.
_BACKEND = 'tensorflow'

# Attempt to read the tfmodisco config file. Any problem with the file
# (unreadable, malformed JSON, or JSON that is not an object) falls back to
# the default backend instead of breaking the package import.
_config_path = os.path.expanduser(
    os.path.join(_tfmodisco_dir, 'tfmodisco.json'))
if os.path.exists(_config_path):
    try:
        with open(_config_path) as f:
            _config = json.load(f)
    except (ValueError, IOError):
        _config = {}
    if not isinstance(_config, dict):
        # e.g. the file contains a JSON list or a bare string.
        _config = {}
    _backend = _config.get('backend', _BACKEND)
    _BACKEND = _backend

# Save config file, if possible.
if not os.path.exists(_tfmodisco_dir):
    try:
        os.makedirs(_tfmodisco_dir)
    except OSError:
        # Except permission denied and potential race conditions
        # in multi-threaded environments.
        pass

if not os.path.exists(_config_path):
    _config = {
        'backend': _BACKEND,
    }
    try:
        with open(_config_path, 'w') as f:
            f.write(json.dumps(_config, indent=4))
    except IOError:
        # Except permission denied.
        pass

# The TFMODISCO_BACKEND environment variable, when set to a non-empty
# value, overrides whatever the config file selected.
if 'TFMODISCO_BACKEND' in os.environ:
    _backend = os.environ['TFMODISCO_BACKEND']
    if _backend:
        _BACKEND = _backend

# Import backend functions.
if _BACKEND == 'theano':
    sys.stderr.write('TF-MoDISco is using the Theano backend.\n')
    from .theano_backend import *
elif _BACKEND == 'tensorflow':
    sys.stderr.write('TF-MoDISco is using the TensorFlow backend.\n')
    from .tensorflow_backend import *
else:
    raise ValueError('Unable to import backend : ' + str(_BACKEND))
|
||
|
||
def backend():
    """Report which backend TF-MoDISco is currently using.

    # Returns
        String, the value of the module-level `_BACKEND` setting.

    # Example
    ```python
        >>> tfmodisco.backend.backend()
        'tensorflow'
    ```
    """
    active_backend = _BACKEND
    return active_backend
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
from __future__ import division, print_function | ||
import theano | ||
from theano import tensor as T | ||
from theano.tensor import signal | ||
from theano.tensor.signal import pool | ||
import numpy as np | ||
import sys | ||
|
||
|
||
def run_function_in_batches(func,
                            input_data_list,
                            learning_phase=None,
                            batch_size=10,
                            progress_update=1000,
                            multimodal_output=False):
    """Run `func` over the inputs in batches and concatenate the results.

    Arguments:
        func: callable invoked once per batch. It receives one positional
            argument per entry of `input_data_list` (each sliced to the
            current batch), followed by `learning_phase` when that is not
            None. Its return value must be indexable by batch first, or,
            when `multimodal_output` is True, be a list whose first index
            is the mode and whose second index is the batch.
        input_data_list: list of the different input modes. The number of
            examples is taken from the length of the first entry.
        learning_phase: optional extra trailing argument passed to `func`.
        batch_size: number of examples per call to `func`; must be > 0.
        progress_update: print "Done <i>" whenever the batch start index
            `i` is a multiple of this value. None or 0 disables printing.
        multimodal_output: whether `func` returns one result list per mode.

    Returns:
        A list with the per-batch outputs extended together, or a list of
        such lists (one per mode) when `multimodal_output` is True.

    Raises:
        ValueError: if `batch_size` is not positive (previously this made
            the loop spin forever).
    """
    assert isinstance(input_data_list, list), "input_data_list must be a list"
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer, got "
                         + str(batch_size))
    #input_data_list holds the different input_data modes.
    extra_args = [] if learning_phase is None else [learning_phase]
    to_return = []
    num_examples = len(input_data_list[0])
    for i in range(0, num_examples, batch_size):
        #a falsy progress_update (None or 0) means "no progress output";
        #this also avoids a modulo-by-zero
        if progress_update and i % progress_update == 0:
            print("Done", i)
            sys.stdout.flush()
        batch_inputs = [x[i:i + batch_size] for x in input_data_list]
        func_output = func(*(batch_inputs + extra_args))
        if multimodal_output:
            assert isinstance(func_output, list),\
             "multimodal_output=True yet function return value is not a list"
            if not to_return:
                #first batch: create one accumulator per output mode
                to_return = [[] for _ in func_output]
            for to_extend, batch_results in zip(to_return, func_output):
                to_extend.extend(batch_results)
        else:
            to_return.extend(func_output)
    return to_return
|
||
|
||
def get_gapped_kmer_embedding_func(filters, biases, require_onehot_match):
    """Compile a batched Theano function that embeds inputs with `filters`.

    The compiled graph scans the input to embed with every filter using
    `theano.tensor.nnet.conv2d` (border_mode='valid') and sums the result
    over axis 2 of the scan output. When `require_onehot_match` is true,
    the scan of the input to embed is first multiplied by a 0/1 mask that
    is 1.0 wherever the same scan applied to the one-hot input, plus the
    per-filter bias, is greater than 0.

    Arguments:
        filters: array, cast to float32.
            filters should be: out_channels, rows, ACGT
        biases: array with one value per filter, cast to float32. Only
            used when `require_onehot_match` is true.
        require_onehot_match: whether to build the one-hot mask.

    Returns:
        A function that runs the compiled graph in batches via
        `run_function_in_batches` and returns a numpy array. Its signature
        is `(onehot, to_embed, batch_size, progress_update)` when
        `require_onehot_match` is true and
        `(to_embed, batch_size, progress_update)` otherwise.
    """
    filters = filters.astype("float32")
    biases = biases.astype("float32")

    tensor3 = theano.tensor.TensorType(dtype=theano.config.floatX,
                                       broadcastable=[False]*3)
    toembed_var = tensor3("toembed")
    theano_filters = theano.tensor.as_tensor_variable(
                        x=filters, name="filters")
    #both trailing filter axes are reversed before being handed to conv2d
    #NOTE(review): presumably so that conv2d acts as a cross-correlation
    #- confirm against conv2d's filter-flipping behaviour
    conv_filters = theano_filters[:,None,::-1,::-1]

    def _scan(input_var):
        #insert a singleton channel axis, run the conv, then drop the
        #trailing width axis of the conv output
        return theano.tensor.nnet.conv2d(
                input=input_var[:,None,:,:],
                filters=conv_filters,
                border_mode='valid')[:,:,:,0]

    if (require_onehot_match):
        onehot_var = tensor3("onehot")
        #use the Theano constant for the biases (previously it was built
        #but the raw numpy array was added instead)
        theano_biases = theano.tensor.as_tensor_variable(
                            x=biases, name="biases")
        onehot_out = 1.0*((_scan(onehot_var)
                           + theano_biases[None,:,None]) > 0.0)
    match_mask = (onehot_out if require_onehot_match else 1.0)
    embedding_out = theano.tensor.sum(_scan(toembed_var)*match_mask, axis=2)

    if (require_onehot_match):
        func = theano.function([onehot_var, toembed_var], embedding_out,
                                allow_input_downcast=True)
        def batchwise_func(onehot, to_embed, batch_size, progress_update):
            return np.array(run_function_in_batches(
                                func=func,
                                input_data_list=[onehot, to_embed],
                                batch_size=batch_size,
                                progress_update=progress_update))
    else:
        func = theano.function([toembed_var], embedding_out,
                                allow_input_downcast=True)
        def batchwise_func(to_embed, batch_size, progress_update):
            return np.array(run_function_in_batches(
                                func=func,
                                input_data_list=[to_embed],
                                batch_size=batch_size,
                                progress_update=progress_update))
    return batchwise_func
|
||
|
||
def max_cross_corrs(filters, things_to_scan, min_overlap,
                    batch_size=50,
                    func_params_size=1000000,
                    progress_update=1000):
    """Scan every input with every filter and keep the maximum response.

    Arguments:
        filters: 3-d array of filters; its last axis must match the last
            axis of `things_to_scan`.
        things_to_scan: array of inputs to scan. Each element is
            zero-padded on both ends of its first axis before scanning.
        min_overlap: controls the padding, which is
            int(filter_length*(1-min_overlap)) on each side.
        batch_size: number of inputs per call of the compiled function.
        func_params_size: target number of filter parameters compiled into
            a single Theano function. Filters are processed in groups of
            int(func_params_size/params_per_filter), but always at least
            one filter per group.
        progress_update: forwarded to `run_function_in_batches`; when not
            None, the filter range being processed is also printed.

    Returns:
        numpy array of shape (num_filters, len(things_to_scan)) holding,
        for each filter and input, the maximum of the conv output over
        positions.
    """
    assert len(filters.shape)==3,"Did you pass in filters of unequal len?"
    assert filters.shape[-1]==things_to_scan.shape[-1]
    #reverse the patterns as the func is a conv not a cross corr
    #NOTE(review): the filters are reversed here and reversed again when
    #the conv2d graph is built below - confirm that the net orientation
    #is the intended one
    filters = filters.astype("float32")[:,::-1,::-1]
    num_filters = filters.shape[0]
    to_return = np.zeros((num_filters, len(things_to_scan)))
    if num_filters == 0 or len(things_to_scan) == 0:
        #nothing to scan; avoids indexing filters[0] / asserting on an
        #empty batch result
        return to_return

    #compile the number of filters that result in a function with
    #params equal to func_params_size; never fewer than one filter,
    #otherwise the loop below would make no progress
    params_per_filter = np.prod(filters.shape[1:])
    filter_batch_size = max(1, int(func_params_size/params_per_filter))
    filter_length = filters.shape[1]

    #the padding does not depend on the filter batch, so do it once
    padding_amount = int((filter_length)*(1-min_overlap))
    padded_input = [np.pad(array=x,
                           pad_width=((padding_amount, padding_amount),
                                      (0,0)),
                           mode="constant") for x in things_to_scan]

    for filter_idx in range(0, num_filters, filter_batch_size):
        batch_end = min((filter_idx+filter_batch_size), num_filters)
        if (progress_update is not None):
            print("On filters",filter_idx,"to",batch_end)
            sys.stdout.flush()

        filter_batch = filters[filter_idx:batch_end]

        input_var = theano.tensor.TensorType(dtype=theano.config.floatX,
                                             broadcastable=[False]*3)("input")
        theano_filters = theano.tensor.as_tensor_variable(
                            x=filter_batch, name="filters")
        conv_out = theano.tensor.nnet.conv2d(
                    input=input_var[:,None,:,:],
                    filters=theano_filters[:,None,::-1,::-1],
                    border_mode='valid')[:,:,:,0]

        #best response over all positions, per (input, filter) pair
        max_out = T.max(conv_out, axis=-1)

        max_cross_corr_func = theano.function([input_var], max_out,
                                allow_input_downcast=True)

        #named differently from the enclosing function to avoid shadowing
        batch_max = np.array(run_function_in_batches(
                            func=max_cross_corr_func,
                            input_data_list=[padded_input],
                            batch_size=batch_size,
                            progress_update=progress_update))
        assert len(batch_max.shape)==2, batch_max.shape
        #batch_max is (input, filter); the result is (filter, input)
        to_return[filter_idx:batch_end,:] = np.transpose(batch_max)

    return to_return
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters