Skip to content

Commit

Permalink
Merge branch 'main' into fix-vlad-update
Browse files Browse the repository at this point in the history
  • Loading branch information
ionite34 authored Jul 5, 2023
2 parents eb9d4bb + 3dedd86 commit c4898ef
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 3 deletions.
13 changes: 12 additions & 1 deletion StabilityMatrix/Helper/HardwareHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,15 @@ public static IEnumerable<GpuInfo> IterGpuInfo()
/// </summary>
public static bool HasNvidiaGpu()
{
return IterGpuInfo().Any(gpu => gpu.Name?.ToLowerInvariant().Contains("nvidia") ?? false);
return IterGpuInfo().Any(gpu => gpu.IsNvidia);
}

/// <summary>
/// Return true if the system has at least one AMD GPU.
/// </summary>
public static bool HasAmdGpu()
{
return IterGpuInfo().Any(gpu => gpu.IsAmd);
}
}

Expand Down Expand Up @@ -112,4 +120,7 @@ public record GpuInfo
< 8 * Size.GiB => Level.Medium,
_ => Level.High
};

public bool IsNvidia => Name?.ToLowerInvariant().Contains("nvidia") ?? false;
public bool IsAmd => Name?.ToLowerInvariant().Contains("amd") ?? false;
}
2 changes: 2 additions & 0 deletions StabilityMatrix/Models/Packages/A3WebUI.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Text.RegularExpressions;
Expand Down Expand Up @@ -47,6 +48,7 @@ public A3WebUI(IGithubApiCache githubApi, ISettingsManager settingsManager, IDow
[SharedFolderType.ControlNet] = "models/ControlNet"
};

[SuppressMessage("ReSharper", "ArrangeObjectCreationWhenTypeNotEvident")]
public override List<LaunchOptionDefinition> LaunchOptions => new()
{
new()
Expand Down
85 changes: 83 additions & 2 deletions StabilityMatrix/Models/Packages/VladAutomatic.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Text.RegularExpressions;
Expand All @@ -9,6 +10,7 @@
using StabilityMatrix.Helper;
using StabilityMatrix.Helper.Cache;
using StabilityMatrix.Models.Progress;
using StabilityMatrix.Python;
using StabilityMatrix.Services;

namespace StabilityMatrix.Models.Packages;
Expand Down Expand Up @@ -51,8 +53,23 @@ public VladAutomatic(IGithubApiCache githubApi, ISettingsManager settingsManager
[SharedFolderType.LyCORIS] = "models/LyCORIS",
};

[SuppressMessage("ReSharper", "ArrangeObjectCreationWhenTypeNotEvident")]
public override List<LaunchOptionDefinition> LaunchOptions => new()
{
new()
{
Name = "Host",
Type = LaunchOptionType.String,
DefaultValue = "localhost",
Options = new() {"--server-name"}
},
new()
{
Name = "Port",
Type = LaunchOptionType.String,
DefaultValue = "7860",
Options = new() {"--port"}
},
new()
{
Name = "VRAM",
Expand All @@ -65,10 +82,42 @@ public VladAutomatic(IGithubApiCache githubApi, ISettingsManager settingsManager
Options = new() { "--lowvram", "--medvram" }
},
new()
{
Name = "Force use of Intel OneAPI XPU backend",
Options = new() { "--use-ipex" }
},
new()
{
Name = "Use DirectML if no compatible GPU is detected",
InitialValue = !HardwareHelper.HasNvidiaGpu() && HardwareHelper.HasAmdGpu(),
Options = new() { "--use-directml" }
},
new()
{
Name = "Force use of Nvidia CUDA backend",
Options = new() { "--use-cuda" }
},
new()
{
Name = "Force use of AMD ROCm backend",
Options = new() { "--use-rocm" }
},
new()
{
Name = "CUDA Device ID",
Type = LaunchOptionType.String,
Options = new() { "--device-id" }
},
new()
{
Name = "API",
Options = new() { "--api" }
},
new()
{
Name = "Debug Logging",
Options = new() { "--debug" }
},
LaunchOptionDefinition.Extras
};

Expand All @@ -88,8 +137,40 @@ public override async Task<IEnumerable<PackageVersion>> GetAllVersions(bool isRe

public override async Task InstallPackage(IProgress<ProgressReport>? progress = null)
{
await PrerequisiteHelper.SetupPythonDependencies(InstallLocation, "requirements.txt", progress,
OnConsoleOutput);
progress?.Report(new ProgressReport(-1, isIndeterminate: true));
// Setup venv
var venvRunner = new PyVenvRunner(Path.Combine(InstallLocation, "venv"));
if (!venvRunner.Exists())
{
await venvRunner.Setup();
}

// Install torch / xformers based on gpu info
var gpus = HardwareHelper.IterGpuInfo().ToList();
if (gpus.Any(g => g.IsNvidia))
{
Logger.Info("Starting torch install (CUDA)...");
await venvRunner.PipInstall(PyVenvRunner.TorchPipInstallArgsCuda,
InstallLocation, OnConsoleOutput);
Logger.Info("Installing xformers...");
await venvRunner.PipInstall("xformers", InstallLocation, OnConsoleOutput);
}
else if (gpus.Any(g => g.IsAmd))
{
Logger.Info("Starting torch install (DirectML)...");
await venvRunner.PipInstall(PyVenvRunner.TorchPipInstallArgsDirectML);
}
else
{
Logger.Info("Starting torch install (CPU)...");
await venvRunner.PipInstall(PyVenvRunner.TorchPipInstallArgsCpu);
}

// Install requirements file
Logger.Info("Installing requirements.txt");
await venvRunner.PipInstall($"-r requirements.txt", InstallLocation, OnConsoleOutput);

progress?.Report(new ProgressReport(1, isIndeterminate: false));
}

public override async Task<string?> DownloadPackage(string version, bool isCommitHash, IProgress<ProgressReport>? progress = null)
Expand Down
7 changes: 7 additions & 0 deletions StabilityMatrix/Python/PyVenvRunner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ public class PyVenvRunner : IDisposable
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();

public const string TorchPipInstallArgsCuda =
"torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118";
public const string TorchPipInstallArgsCpu =
"torch torchvision torchaudio";
public const string TorchPipInstallArgsDirectML =
"torch-directml";

/// <summary>
/// The process running the python executable.
/// </summary>
Expand Down

0 comments on commit c4898ef

Please sign in to comment.