Hi,
Would it be possible to pass through the GPU id when running RoseTTAFold2, so that on setups with more than one GPU it can run in parallel? Thx.
Example where it is currently hard-coded:
network/predict.py
```python
class Predictor():
    def __init__(self, model_weights, device="cuda:0"):
        # define model name
        self.model_weights = model_weights
        self.device = device
        self.active_fn = nn.Softmax(dim=1)

        # define model & load model
        self.model = RoseTTAFoldModule(
            **MODEL_PARAM
        ).to(self.device)
        could_load = self.load_model(self.model_weights)
        if not could_load:
            print ("ERROR: failed to load model")
            sys.exit()
```
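One way to expose this would be a command-line flag forwarded to `Predictor`. A minimal sketch, assuming an argparse-based entry point; the `--device` flag and the `build_parser` helper are hypothetical, not actual RoseTTAFold2 code:

```python
import argparse

def build_parser():
    # Hypothetical CLI: lets the caller pick the GPU instead of
    # relying on the hard-coded "cuda:0" default.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--device",
        default="cuda:0",
        help='Torch device string, e.g. "cuda:1" to run on the second GPU',
    )
    return parser

if __name__ == "__main__":
    args = build_parser().parse_args()
    # The device string is then passed through to the predictor, e.g.:
    # pred = Predictor(model_weights, device=args.device)
    print(args.device)
```

With something like this, two instances could be launched in parallel with `--device cuda:0` and `--device cuda:1`; setting `CUDA_VISIBLE_DEVICES` per process is another common workaround that needs no code change.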