Merge branch 'main' into sdxl_v1
monorimet authored Nov 13, 2023
2 parents d5823a6 + 91df5f0 commit 428bec0
Showing 16 changed files with 258 additions and 76 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -199,3 +199,6 @@ apps/stable_diffusion/web/EBWebView/

# Llama2 tokenizer configs
llama2_tokenizer_configs/

# Webview2 runtime artefacts
EBWebView/
4 changes: 2 additions & 2 deletions apps/language_models/langchain/langchain_requirements.txt
@@ -65,8 +65,8 @@ tiktoken==0.4.0
openai==0.27.8

# optional for chat with PDF
langchain==0.0.202
pypdf==3.12.2
langchain==0.0.329
pypdf==3.17.0
# avoid textract, requires old six
#textract==1.6.5

2 changes: 2 additions & 0 deletions apps/language_models/scripts/vicuna.py
@@ -1309,6 +1309,7 @@ def __init__(

def get_model_path(self, suffix="mlir"):
safe_device = self.device.split("-")[0]
safe_device = safe_device.split("://")[0]
if suffix in ["mlirbc", "mlir"]:
return Path(f"{self.model_name}_{self.precision}.{suffix}")

@@ -1973,6 +1974,7 @@ def create_prompt(model_name, history):
max_num_tokens=max_tokens,
min_num_tokens=min_tokens,
device=args.device,
vulkan_target_triple=vulkan_target_triple,
precision=args.precision,
vicuna_mlir_path=vic_mlir_path,
vicuna_vmfb_path=vic_vmfb_path,
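The added line strips the `://<index>` part of a device string, on top of the existing `-<arch>` suffix handling, so `safe_device` stays filesystem-safe wherever it feeds into generated file names outside the visible hunk. A quick sketch of the effect, with made-up device strings:

```python
# Hypothetical device strings; shows what the two split() calls in
# get_model_path do to them.
for device in ["vulkan://1", "vulkan-rdna3://0", "cpu-task", "rocm"]:
    safe_device = device.split("-")[0]         # drop any "-<arch>" suffix
    safe_device = safe_device.split("://")[0]  # drop any "://<index>" suffix
    print(f"{device} -> {safe_device}")
# vulkan://1 -> vulkan
# vulkan-rdna3://0 -> vulkan
# cpu-task -> cpu
# rocm -> rocm
```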
9 changes: 8 additions & 1 deletion apps/stable_diffusion/src/utils/utils.py
@@ -477,7 +477,14 @@ def get_devices_by_name(driver_name):
f"{device_name} => {driver_name.replace('local', 'cpu')}"
)
else:
device_list.append(f"{device_name} => {driver_name}://{i}")
# for drivers with single devices
# let the default device be selected without any indexing
if len(device_list_dict) == 1:
device_list.append(f"{device_name} => {driver_name}")
else:
device_list.append(
f"{device_name} => {driver_name}://{i}"
)
return device_list

set_iree_runtime_flags()
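With this change, a driver that exposes exactly one device is listed without a `://<index>` suffix, so selecting it hands IREE the bare driver name and lets the driver pick its default device. A self-contained sketch of the rule (the device names and the shape of `device_list_dict` are illustrative, not taken from the source):

```python
# A minimal sketch of the indexing rule above; device_list_dict is assumed
# to map device indices to display names for a single driver.
def format_devices(driver_name: str, device_list_dict: dict[int, str]) -> list[str]:
    if len(device_list_dict) == 1:
        # single device: let the driver's default be selected, no index needed
        return [f"{name} => {driver_name}" for name in device_list_dict.values()]
    # multiple devices: keep the index so they stay distinguishable
    return [f"{name} => {driver_name}://{i}" for i, name in device_list_dict.items()]

print(format_devices("rocm", {0: "AMD Radeon RX 7900 XTX"}))
# ['AMD Radeon RX 7900 XTX => rocm']
print(format_devices("vulkan", {0: "Radeon RX 7900 XTX", 1: "GeForce RTX 4090"}))
# ['Radeon RX 7900 XTX => vulkan://0', 'GeForce RTX 4090 => vulkan://1']
```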
69 changes: 29 additions & 40 deletions apps/stable_diffusion/web/index.py
@@ -1,7 +1,8 @@
from multiprocessing import Process, freeze_support
from multiprocessing import freeze_support
import os
import sys
import logging
import apps.stable_diffusion.web.utils.app as app

if sys.platform == "darwin":
# import before IREE to avoid torch-MLIR library issues
@@ -21,26 +22,6 @@
clear_all()


def launch_app(address):
from tkinter import Tk
import webview

window = Tk()

# get screen width and height of display and make it more reasonably
# sized as we aren't making it full-screen or maximized
width = int(window.winfo_screenwidth() * 0.81)
height = int(window.winfo_screenheight() * 0.91)
webview.create_window(
"SHARK AI Studio",
url=address,
width=width,
height=height,
text_select=True,
)
webview.start(private_mode=False, storage_path=os.getcwd())


if __name__ == "__main__":
if args.debug:
logging.basicConfig(level=logging.DEBUG)
@@ -59,27 +40,27 @@ def launch_app(address):
# init global sd pipeline and config
global_obj._init()

app = FastAPI()
app.mount("/sdapi/", sdapi)
api = FastAPI()
api.mount("/sdapi/", sdapi)

# chat APIs needed for compatibility with multiple extensions using OpenAI API
app.add_api_route(
api.add_api_route(
"/v1/chat/completions", llm_chat_api, methods=["post"]
)
app.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
app.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
app.add_api_route("/completions", llm_chat_api, methods=["post"])
app.add_api_route(
api.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
api.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
api.add_api_route("/completions", llm_chat_api, methods=["post"])
api.add_api_route(
"/v1/engines/codegen/completions", llm_chat_api, methods=["post"]
)
app.include_router(APIRouter())
api.include_router(APIRouter())

# deal with CORS requests if CORS accept origins are set
if args.api_accept_origin:
print(
f"API Configured for CORS. Accepting origins: { args.api_accept_origin }"
)
app.add_middleware(
api.add_middleware(
CORSMiddleware,
allow_origins=args.api_accept_origin,
allow_methods=["GET", "POST"],
@@ -88,7 +69,7 @@ def launch_app(address):
else:
print("API not configured for CORS")

uvicorn.run(app, host="0.0.0.0", port=args.server_port)
uvicorn.run(api, host="0.0.0.0", port=args.server_port)
sys.exit(0)

# Setup to use shark_tmp for gradio's temporary image files and clear any
@@ -102,7 +83,10 @@ def launch_app(address):
import gradio as gr

# Create custom models folders if they don't exist
from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
from apps.stable_diffusion.web.ui.utils import (
create_custom_models_folders,
nodicon_loc,
)

create_custom_models_folders()

@@ -222,7 +206,7 @@ def register_outputgallery_button(button, selectedid, inputs, outputs):
)

with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
css=dark_theme, analytics_enabled=False, title="SHARK AI Studio"
) as sd_web:
with gr.Tabs() as tabs:
# NOTE: If adding, removing, or re-ordering tabs, make sure that they
@@ -278,6 +262,15 @@ def register_outputgallery_button(button, selectedid, inputs, outputs):
with gr.TabItem(label="Text-to-Image-SDXL (Experimental)", id=13):
txt2img_sdxl_web.render()

actual_port = app.usable_port()
if actual_port != args.server_port:
sd_web.load(
fn=lambda: gr.Info(
f"Port {args.server_port} is in use by another application. "
f"Shark is running on port {actual_port} instead."
)
)

# send to buttons
register_button_click(
txt2img_sendto_img2img,
@@ -438,14 +431,10 @@ def register_outputgallery_button(button, selectedid, inputs, outputs):
)

sd_web.queue()
if args.ui == "app":
t = Process(
target=launch_app, args=[f"http://localhost:{args.server_port}"]
)
t.start()
sd_web.launch(
share=args.share,
inbrowser=args.ui == "web",
inbrowser=not app.launch(actual_port),
server_name="0.0.0.0",
server_port=args.server_port,
server_port=actual_port,
favicon_path=nodicon_loc,
)
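Two related changes land in this file: the FastAPI instance is renamed from `app` to `api` so it no longer shadows the newly imported `apps.stable_diffusion.web.utils.app` module, and the desktop-window logic moves into that module. The browser fallback now hinges on the boolean returned by `app.launch()`; here is a self-contained sketch of that decision, where `launch_window` is only a stand-in for `app.launch` and the port value is illustrative:

```python
def launch_window(port: int) -> bool:
    """Stand-in for app.launch(): try to start a desktop window,
    answering whether it succeeded."""
    # SHARK spawns a pywebview process here; pretend it is unavailable
    return False

def ui_launch_kwargs(port: int) -> dict:
    # mirrors sd_web.launch(inbrowser=not app.launch(actual_port), ...):
    # only open a browser tab when no desktop window could be started
    return {"server_port": port, "inbrowser": not launch_window(port)}

print(ui_launch_kwargs(8080))  # {'server_port': 8080, 'inbrowser': True}
```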
Binary file added apps/stable_diffusion/web/ui/logos/nod-icon.png
5 changes: 5 additions & 0 deletions apps/stable_diffusion/web/ui/stablelm_ui.py
@@ -153,6 +153,7 @@ def chat(

device_id = None
model_name, model_path = list(map(str.strip, model.split("=>")))
device = device if "=>" not in device else device.split("=>")[1].strip()
if "cuda" in device:
device = "cuda"
elif "sync" in device:
@@ -164,6 +165,8 @@
device = "vulkan"
elif "rocm" in device:
device = "rocm"
elif "metal" in device:
device = "metal"
else:
print("unrecognized device")

@@ -331,6 +334,8 @@ def llm_chat_api(InputData: dict):
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
elif "metal" in device:
device = "metal"
else:
print("unrecognized device")

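The `chat()` change handles device choices coming from the UI dropdown, which arrive as display strings such as `"Apple M2 Pro => metal"`, and both functions gain a `metal` branch. A small sketch of the normalization (the elided branches in the diff are omitted, and the inputs are made up):

```python
# A minimal sketch of the "=>" stripping and driver matching above.
def normalize_device(device: str) -> str:
    # drop the human-readable name before "=>", keep the driver part
    if "=>" in device:
        device = device.split("=>")[1].strip()
    for driver in ("cuda", "vulkan", "rocm", "metal"):
        if driver in device:
            return driver
    print("unrecognized device")
    return device

print(normalize_device("Apple M2 Pro => metal"))   # metal
print(normalize_device("AMD GPU => vulkan://1"))   # vulkan
```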
1 change: 1 addition & 0 deletions apps/stable_diffusion/web/ui/utils.py
@@ -170,4 +170,5 @@ def cancel_sd():


nodlogo_loc = resource_path("logos/nod-logo.png")
nodicon_loc = resource_path("logos/nod-icon.png")
available_devices = get_available_devices()
105 changes: 105 additions & 0 deletions apps/stable_diffusion/web/utils/app.py
@@ -0,0 +1,105 @@
import os
import sys
import webview
import webview.util
import socket

from contextlib import closing
from multiprocessing import Process

from apps.stable_diffusion.src import args


def webview2_installed():
if sys.platform != "win32":
return False

    # On Windows we want to ensure we have MS webview2 available so we don't fall back
    # to MSHTML (aka ye olde Internet Explorer), which is deprecated by pywebview and
    # apparently causes SHARK not to load properly.

# Checking these registry entries is how Microsoft says to detect a webview2 installation:
# https://learn.microsoft.com/en-us/microsoft-edge/webview2/concepts/distribution
import winreg

path = r"SOFTWARE\WOW6432Node\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}"

    # the only way to find out whether a registry entry exists is to try to open it
try:
# check for an all user install
with winreg.OpenKey(
winreg.HKEY_LOCAL_MACHINE,
path,
0,
winreg.KEY_QUERY_VALUE | winreg.KEY_WOW64_64KEY,
) as registry_key:
value, type = winreg.QueryValueEx(registry_key, "pv")

# if it didn't exist, we want to continue on...
except WindowsError:
try:
# ...to check for a current user install
with winreg.OpenKey(
winreg.HKEY_CURRENT_USER,
path,
0,
winreg.KEY_QUERY_VALUE | winreg.KEY_WOW64_64KEY,
) as registry_key:
value, type = winreg.QueryValueEx(registry_key, "pv")
except WindowsError:
value = None
finally:
return (value is not None) and value != "" and value != "0.0.0.0"


def window(address):
from tkinter import Tk

window = Tk()

# get screen width and height of display and make it more reasonably
# sized as we aren't making it full-screen or maximized
width = int(window.winfo_screenwidth() * 0.81)
height = int(window.winfo_screenheight() * 0.91)
webview.create_window(
"SHARK AI Studio",
url=address,
width=width,
height=height,
text_select=True,
)
webview.start(private_mode=False, storage_path=os.getcwd())


def usable_port():
    # Make sure we can actually use the port given in args.server_port. If not,
    # ask the OS for a free port and return that as our port to use.

port = args.server_port

with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
try:
sock.bind(("0.0.0.0", port))
except OSError:
with closing(
socket.socket(socket.AF_INET, socket.SOCK_STREAM)
) as sock:
sock.bind(("0.0.0.0", 0))
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return sock.getsockname()[1]

return port


def launch(port):
# setup to launch as an app if app mode has been requested and we're able
# to do it, answering whether we succeeded.
if args.ui == "app" and (sys.platform != "win32" or webview2_installed()):
try:
t = Process(target=window, args=[f"http://localhost:{port}"])
t.start()
return True
except webview.util.WebViewException:
return False
else:
return False
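The probe in `usable_port()` relies on a standard trick: binding to port 0 makes the OS hand back an arbitrary free port. A self-contained sketch of the same pattern, independent of SHARK's `args`:

```python
import socket
from contextlib import closing

def usable_port(preferred: int) -> int:
    """Return preferred if it is free to bind, else an OS-assigned free port."""
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        try:
            sock.bind(("0.0.0.0", preferred))
            return preferred
        except OSError:
            pass  # preferred port is in use; fall through to the OS pick
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        sock.bind(("0.0.0.0", 0))  # port 0: the OS picks any free port
        return sock.getsockname()[1]

print(usable_port(8080))  # 8080 if free, otherwise an OS-assigned port
```

Note that the returned port is released before the caller rebinds it, so another process could grab it in between; for a local UI launcher that small race is an acceptable trade-off.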
4 changes: 2 additions & 2 deletions docs/shark_sd_koboldcpp.md
@@ -83,11 +83,11 @@ SHARK should start in server mode, and you should see something like this:

* There is one final piece of image generation configuration within Koboldcpp you might want to do. This is also in the generate images section of advanced settings. Here there is, not very obviously, a 'style' button:

![Selecting the 'styles' button](https://user-images.githubusercontent.com/121311569/280556172-4aab9794-7a77-46d7-bdda-43df570ad19a.png)
![Selecting the 'styles' button](https://user-images.githubusercontent.com/121311569/280556694-55cd1c55-a059-4b54-9293-63d66a32368e.png)

This will bring up a dialog box where you can enter a short text that will be sent as a prefix to the prompt sent to SHARK:

![Entering extra image styles](https://github.com/one-lithe-rune/SHARK/assets/121311569/4aab9794-7a77-46d7-bdda-43df570ad19a)
![Entering extra image styles](https://user-images.githubusercontent.com/121311569/280556172-4aab9794-7a77-46d7-bdda-43df570ad19a.png)


## Connecting to SHARK on a different address or port
1 change: 1 addition & 0 deletions requirements.txt
@@ -42,6 +42,7 @@ joblib # for langchain
timm # for MiniGPT4
langchain
einops # for zoedepth
pydantic==2.4.1 # pin until pyinstaller-hooks-contrib works with beta versions

# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile
14 changes: 10 additions & 4 deletions shark/iree_utils/_common.py
@@ -19,9 +19,12 @@
import subprocess


def run_cmd(cmd, debug=False):
def run_cmd(cmd, debug=False, raise_err=False):
"""
Inputs: cli command string.
Inputs:
cmd : cli command string.
debug : if True, prints debug info
raise_err : if True, raise exception to caller
"""
if debug:
print("IREE run command: \n\n")
@@ -39,8 +42,11 @@ def run_cmd(cmd, debug=False):
stderr = result.stderr.decode()
return stdout, stderr
except subprocess.CalledProcessError as e:
print(e.output)
sys.exit(f"Exiting program due to error running {cmd}")
if raise_err:
raise Exception from e
else:
print(e.output)
sys.exit(f"Exiting program due to error running {cmd}")


def iree_device_map(device):
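With `raise_err=True` a caller can recover from a failed subprocess instead of having `run_cmd` terminate the whole program. A hypothetical call site (the command string is only an example):

```python
from shark.iree_utils._common import run_cmd

try:
    stdout, stderr = run_cmd("iree-compile --version", raise_err=True)
    print(stdout)
except Exception as err:
    # the CalledProcessError is chained as err.__cause__; handle and move on
    print(f"command failed, falling back: {err.__cause__}")
```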