Deploying Stable Diffusion locally with the Hugging Face diffusers library
Preface
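Stable Diffusion is a latent diffusion model trained on 512×512 images from a subset of LAION-5B. It uses a frozen CLIP ViT-L/14 text encoder to encode the prompt text. With an 860M-parameter UNet and a 123M-parameter text encoder, it runs on a GPU with at least 10 GB of VRAM.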
Hugging Face page: https://huggingface.co/CompVis/stable-diffusion
Colab: https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb
diffusers documentation: https://huggingface.co/docs/diffusers
Now let's walk through a local deployment.
Local deployment
1. Set up the environment
```shell
conda create -n diffenv python=3.8
conda activate diffenv
pip install diffusers==0.4.0
pip install transformers scipy ftfy
# pip install "ipywidgets>=7,<8"  # only needed for the interactive input widgets in Colab
```
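A quick way to confirm the installation, including that `diffusers` really resolved to 0.4.0, is to read the installed versions back. The sketch below uses only `importlib.metadata` (in the standard library since Python 3.8, matching the environment created above). The package list is an assumption taken from the commands in this section plus `torch`, and the helper name `installed_versions` is hypothetical.

```python
# Sketch: report installed versions of the packages this tutorial uses,
# relying only on the standard library (Python 3.8+).
from importlib.metadata import PackageNotFoundError, version

# PyPI distribution names (assumed list; torch is installed separately below).
REQUIRED = ("diffusers", "transformers", "scipy", "ftfy", "torch")


def installed_versions(packages=REQUIRED):
    """Return {package: version string, or None if it is not installed}."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name:<14} {ver or 'NOT INSTALLED'}")
```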
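If running the code later fails with `RuntimeError: CUDA error: no kernel image is available for execution on the device`, the installed PyTorch build does not match the machine's CUDA setup. Reinstall a build that matches your CUDA version, for example for CUDA 11.3:

```shell
pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 -f https://download.pytorch.org/whl/torch_stable.html
```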
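Before reinstalling, it can help to see what the current build supports. The sketch below (the helper name `describe_torch_cuda` is hypothetical) reads `torch.version.cuda`, `torch.cuda.get_device_capability` and `torch.cuda.get_arch_list`. If the GPU's `sm_XY` value is missing from `compiled_archs`, this error is the likely result; treat that as a heuristic, since `compute_XY` (PTX) entries in the list can also be JIT-compiled for newer GPUs.

```python
# Diagnostic sketch for "no kernel image is available for execution on the device".
def describe_torch_cuda():
    """Collect facts about the installed PyTorch build and the first visible GPU."""
    try:
        import torch
    except ImportError:
        return {"torch": None}

    info = {
        "torch": torch.__version__,             # e.g. "1.12.1+cu113"
        "built_with_cuda": torch.version.cuda,  # None for CPU-only builds
        "cuda_available": torch.cuda.is_available(),
    }
    if info["cuda_available"]:
        major, minor = torch.cuda.get_device_capability(0)
        info["gpu"] = torch.cuda.get_device_name(0)
        info["gpu_arch"] = f"sm_{major}{minor}"
        # GPU architectures this PyTorch binary was compiled for.
        info["compiled_archs"] = torch.cuda.get_arch_list()
    return info


if __name__ == "__main__":
    for key, value in describe_torch_cuda().items():
        print(f"{key}: {value}")
```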
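Getting the model:
First, you must accept the model's license agreement.
If you use the official Colab, you have to enter a Hugging Face access token so it can verify online that you have accepted the agreement. To avoid entering the token, first download the model weights and other files locally: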
```shell
git lfs install
git clone https://huggingface.co/CompVis/stable-diffusion-v1-4
```
You can then load the model directly with `DiffusionPipeline.from_pretrained("./MODEL_PATH/stable-diffusion-v1-4")`, without passing the `use_auth_token=AUTH_TOKEN` argument.
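A partial clone (for example, one made before `git lfs install`) only fails later, at load time. The sketch below checks a local directory for the entries a diffusers pipeline repo normally contains and otherwise falls back to the Hub id. The expected entry list is an assumption based on the standard `stable-diffusion-v1-4` layout, and `missing_entries` / `resolve_model_source` are hypothetical helpers. It cannot detect Git LFS pointer files left in place of real weights.

```python
from pathlib import Path

# Entries a diffusers pipeline repo such as stable-diffusion-v1-4 normally has
# (assumed layout; adjust if your clone differs).
EXPECTED_ENTRIES = ("model_index.json", "scheduler", "text_encoder", "tokenizer", "unet", "vae")


def missing_entries(model_dir):
    """Return the expected entries that are absent from model_dir."""
    root = Path(model_dir)
    return [name for name in EXPECTED_ENTRIES if not (root / name).exists()]


def resolve_model_source(local_dir, hub_id="CompVis/stable-diffusion-v1-4"):
    """Prefer a complete local clone; otherwise fall back to the Hub model id."""
    if Path(local_dir).is_dir() and not missing_entries(local_dir):
        return str(local_dir)
    return hub_id
```

Usage, assuming the pipeline class from the next step: `StableDiffusionPipeline.from_pretrained(resolve_model_source("./MODEL_PATH/stable-diffusion-v1-4"))`. When it falls back to the Hub id, the `use_auth_token` argument mentioned above is needed again.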
2. Load the model
To keep full precision (at the cost of higher VRAM use), remove `revision="fp16"` and `torch_dtype=torch.float16`.
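```python
import os

import torch
from diffusers import StableDiffusionPipeline

# Use GPU index 2 on this machine; change as needed. Set this before any CUDA call.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",  # a local path also works
    revision="fp16",                  # for full precision, delete this line and the next
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
```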
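Why half precision saves memory can be checked with simple arithmetic: weights take 4 bytes per parameter in float32 and 2 in float16. The sketch below uses only the 860M UNet and 123M text encoder figures from the preface; it ignores the VAE, activations and CUDA overhead, so real VRAM use is higher than these numbers. `weight_gib` is a hypothetical helper.

```python
# Parameter counts quoted in the preface (the VAE is not included).
PARAMS = {"unet": 860e6, "text_encoder": 123e6}
BYTES_PER_PARAM = {"float32": 4, "float16": 2}


def weight_gib(dtype):
    """Approximate memory for the UNet + text encoder weights, in GiB."""
    return sum(PARAMS.values()) * BYTES_PER_PARAM[dtype] / 1024**3


for dtype in ("float32", "float16"):
    print(f"{dtype}: ~{weight_gib(dtype):.2f} GiB")
# float32: ~3.66 GiB
# float16: ~1.83 GiB
```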
3. Generate images
3.1 Basic generation
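Height and width both default to 512 pixels; pass them explicitly, e.g. `pipe(prompt, height=512, width=768)`, to control the size. Note that:
- Height and width must both be multiples of 8.
- Going below 512 may lower image quality.
- Going above 512 in both dimensions produces repeated image regions (global coherence is lost).
- For non-square images, the best approach is to keep one dimension at 512 px and make the other larger.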
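These rules can be checked before calling the pipeline. The sketch below (hypothetical helper `check_size`) rejects sizes that are not multiples of 8 and flags the two quality caveats above. It also returns the latent shape, assuming the Stable Diffusion v1 VAE, which downsamples each side by 8 into 4 latent channels; that factor is why the sizes must be multiples of 8.

```python
def check_size(height=512, width=512):
    """Validate an output size and return (latent_shape, notes).

    Assumes the Stable Diffusion v1 VAE: 8x downsampling per side, 4 latent channels.
    """
    if height % 8 or width % 8:
        raise ValueError(f"height and width must be multiples of 8, got {height}x{width}")
    notes = []
    if min(height, width) < 512:
        notes.append("a side below 512 px may lower image quality")
    if height > 512 and width > 512:
        notes.append("both sides above 512 px may cause repeated regions")
    latent_shape = (4, height // 8, width // 8)
    return latent_shape, notes


print(check_size(512, 768))  # ((4, 64, 96), [])
```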
```python
prompt = "a photograph of an astronaut swimming in the river"
image = pipe(prompt).images[0]  # a PIL image (https://pillow.readthedocs.io/en/stable/)
image.save("astronaut_swimming.png")
image  # in a notebook, this displays the image
```
The output is shown below; it captures the idea reasonably well.
3.2 Reproducible generation
Each run in 3.1 produces a different image. For reproducible output, fix the random seed by passing a seeded `generator` argument to `pipe()`.
```python
import torch

generator = torch.Generator("cuda").manual_seed(1024)
image = pipe(prompt, generator=generator).images[0]
image
```
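Because a fixed seed reproduces an image, a common workflow is to try several seeds, keep the one you like, and regenerate it later. The sketch below assumes the `pipe` and `prompt` from above; `generate_per_seed` is a hypothetical helper, and the generator factory is passed in so the function itself does not depend on CUDA.

```python
def generate_per_seed(pipe, prompt, seeds, make_generator):
    """Render one image per seed and return {seed: PIL image}.

    make_generator(seed) must return a freshly seeded generator, for example
    lambda s: torch.Generator("cuda").manual_seed(s)
    """
    images = {}
    for seed in seeds:
        output = pipe(prompt, generator=make_generator(seed))
        images[seed] = output.images[0]
    return images


# Usage (assumes `pipe` and `prompt` from above):
# import torch
# images = generate_per_seed(pipe, prompt, [1024, 2048, 4096],
#                            lambda s: torch.Generator("cuda").manual_seed(s))
# images[1024].save("seed_1024.png")
```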
3.3 Controlling image quality with the number of inference steps
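Use the `num_inference_steps` parameter to change the number of inference steps. More steps usually give better results but make inference slower. Stable Diffusion already produces good results with relatively few steps, so the default of 50 is recommended. In the example below, `num_inference_steps` is set to 100 with the same random seed, and the result does not look much different.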
```python
import torch

generator = torch.Generator("cuda").manual_seed(1024)
image = pipe(prompt, num_inference_steps=100, generator=generator).images[0]
image
```
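To compare step counts fairly, every run must start from the same noise. A `torch.Generator` advances its state as it is consumed, so the sketch below builds a fresh seeded generator for each run instead of reusing one object; it also records wall-clock time to show the speed cost. `compare_steps` is a hypothetical helper, and `pipe` and `prompt` are assumed from above.

```python
import time


def compare_steps(pipe, prompt, step_counts, seed, make_generator):
    """Render one prompt at several step counts from identical starting noise.

    A fresh generator is built per run because a torch.Generator's state advances
    as it is used; reusing one object would change the noise between runs.
    Returns {steps: (image, seconds)}.
    """
    results = {}
    for steps in step_counts:
        generator = make_generator(seed)
        start = time.perf_counter()
        output = pipe(prompt, num_inference_steps=steps, generator=generator)
        results[steps] = (output.images[0], time.perf_counter() - start)
    return results


# Usage (assumes `pipe` and `prompt` from above):
# import torch
# runs = compare_steps(pipe, prompt, [25, 50, 100], 1024,
#                      lambda s: torch.Generator("cuda").manual_seed(s))
# for steps, (image, seconds) in runs.items():
#     print(steps, f"{seconds:.1f}s")
```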
3.4 Generating multiple images
First, a helper function that tiles images into a grid:
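```python
from PIL import Image


def image_grid(imgs, rows, cols):
    """Paste rows * cols equally sized PIL images into a single grid image."""
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        # Row-major order: image i goes to column i % cols, row i // cols.
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
```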
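The paste offsets in `image_grid` follow row-major order: image `i` lands in column `i % cols` and row `i // cols`. The standard-library sketch below (hypothetical helper `grid_layout`) computes just that geometry and, unlike `image_grid`, allows a partially filled last row by rounding the row count up.

```python
import math


def grid_layout(num_images, cols, tile_w, tile_h):
    """Return (canvas_size, offsets) for a row-major grid of equal-size tiles."""
    rows = math.ceil(num_images / cols)  # round up: the last row may be partial
    offsets = [((i % cols) * tile_w, (i // cols) * tile_h) for i in range(num_images)]
    return (cols * tile_w, rows * tile_h), offsets


print(grid_layout(3, 2, 512, 512))
# ((1024, 1024), [(0, 0), (512, 0), (0, 512)])
```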
Generate 3 images at once; here `prompt` is a list rather than a str.
```python
num_images = 3
prompt = ["a traditional Chinese painting of a squirrel eating a banana"] * num_images
images = pipe(prompt).images
grid = image_grid(images, rows=1, cols=3)
grid
```
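Passing a list runs all prompts through the model as one batch, so memory use grows with the list length. If three 512×512 images at once exceed your VRAM, one option (an assumption, not part of the original walkthrough) is to split the list into smaller batches and concatenate the results, as in this sketch with hypothetical helpers `chunked` and `generate_in_batches`.

```python
def chunked(items, batch_size):
    """Split a list into consecutive batches of at most batch_size items."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


def generate_in_batches(pipe, prompts, batch_size, **pipe_kwargs):
    """Run the pipeline batch by batch and return all images in prompt order."""
    images = []
    for batch in chunked(prompts, batch_size):
        images.extend(pipe(batch, **pipe_kwargs).images)
    return images


# Usage (assumes `pipe` and `image_grid` from above):
# prompts = ["a traditional Chinese painting of a squirrel eating a banana"] * 3
# images = generate_in_batches(pipe, prompts, batch_size=1)
# grid = image_grid(images, rows=1, cols=3)
```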
Appendix
The environment's installed packages (`pip list` output) are listed below. The ones this tutorial relies on are diffusers 0.4.0, transformers 4.22.2, huggingface-hub 0.10.0, torch 1.12.1+cu113, scipy 1.7.3 and ftfy 6.1.1.

```text
Package Version
------- -------
absl-py 1.2.0
aeppl 0.0.33
aesara 2.7.9
aiohttp 3.8.3
aiosignal 1.2.0
alabaster 0.7.12
albumentations 1.2.1
altair 4.2.0
appdirs 1.4.4
arviz 0.12.1
astor 0.8.1
astropy 4.3.1
astunparse 1.6.3
async-timeout 4.0.2
asynctest 0.13.0
atari-py 0.2.9
atomicwrites 1.4.1
attrs 22.1.0
audioread 3.0.0
autograd 1.5
Babel 2.10.3
backcall 0.2.0
beautifulsoup4 4.6.3
bleach 5.0.1
blis 0.7.8
bokeh 2.3.3
branca 0.5.0
bs4 0.0.1
CacheControl 0.12.11
cached-property 1.5.2
cachetools 4.2.4
catalogue 2.0.8
certifi 2022.9.24
cffi 1.15.1
cftime 1.6.2
chardet 3.0.4
charset-normalizer 2.1.1
click 7.1.2
clikit 0.6.2
cloudpickle 1.5.0
cmake 3.22.6
cmdstanpy 1.0.7
colorcet 3.0.1
colorlover 0.3.0
community 1.0.0b1
confection 0.0.2
cons 0.4.5
contextlib2 0.5.5
convertdate 2.4.0
crashtest 0.3.1
crcmod 1.7
cufflinks 0.17.3
cupy-cuda11x 11.0.0
cvxopt 1.3.0
cvxpy 1.2.1
cycler 0.11.0
cymem 2.0.6
Cython 0.29.32
daft 0.0.4
dask 2022.2.0
datascience 0.17.5
debugpy 1.0.0
decorator 4.4.2
defusedxml 0.7.1
descartes 1.1.0
diffusers 0.4.0
dill 0.3.5.1
distributed 2022.2.0
dlib 19.24.0
dm-tree 0.1.7
docutils 0.17.1
dopamine-rl 1.0.5
earthengine-api 0.1.326
easydict 1.10
ecos 2.0.10
editdistance 0.5.3
en-core-web-sm 3.4.0
entrypoints 0.4
ephem 4.1.3
et-xmlfile 1.1.0
etils 0.8.0
etuples 0.3.8
fa2 0.3.5
fastai 2.7.9
fastcore 1.5.27
fastdownload 0.0.7
fastdtw 0.3.4
fastjsonschema 2.16.2
fastprogress 1.0.3
fastrlock 0.8
feather-format 0.4.1
filelock 3.8.0
firebase-admin 4.4.0
fix-yahoo-finance 0.0.22
Flask 1.1.4
flatbuffers 22.9.24
folium 0.12.1.post1
frozenlist 1.3.1
fsspec 2022.8.2
ftfy 6.1.1
future 0.16.0
gast 0.5.3
GDAL 2.2.2
gdown 4.4.0
gensim 3.6.0
geographiclib 1.52
geopy 1.17.0
gin-config 0.5.0
glob2 0.7
google 2.0.3
google-api-core 1.31.6
google-api-python-client 1.12.11
google-auth 1.35.0
google-auth-httplib2 0.0.4
google-auth-oauthlib 0.4.6
google-cloud-bigquery 1.21.0
google-cloud-bigquery-storage 1.1.2
google-cloud-core 1.0.3
google-cloud-datastore 1.8.0
google-cloud-firestore 1.7.0
google-cloud-language 1.2.0
google-cloud-storage 1.18.1
google-cloud-translate 1.5.0
google-colab 1.0.0
google-pasta 0.2.0
google-resumable-media 0.4.1
googleapis-common-protos 1.56.4
googledrivedownloader 0.4
graphviz 0.10.1
greenlet 1.1.3
grpcio 1.49.1
gspread 3.4.2
gspread-dataframe 3.0.8
gym 0.25.2
gym-notices 0.0.8
h5py 3.1.0
HeapDict 1.0.1
hijri-converter 2.2.4
holidays 0.16
holoviews 1.14.9
html5lib 1.0.1
httpimport 0.5.18
httplib2 0.17.4
httplib2shim 0.0.3
httpstan 4.6.1
huggingface-hub 0.10.0
humanize 0.5.1
hyperopt 0.1.2
idna 2.10
imageio 2.9.0
imagesize 1.4.1
imbalanced-learn 0.8.1
imblearn 0.0
imgaug 0.4.0
importlib-metadata 5.0.0
importlib-resources 5.9.0
imutils 0.5.4
inflect 2.1.0
intel-openmp 2022.2.0
intervaltree 2.1.0
ipykernel 5.3.4
ipython 7.9.0
ipython-genutils 0.2.0
ipython-sql 0.3.9
ipywidgets 7.7.1
itsdangerous 1.1.0
jax 0.3.21
jaxlib 0.3.20+cuda11.cudnn805
jedi 0.18.1
jieba 0.42.1
Jinja2 2.11.3
joblib 1.2.0
jpeg4py 0.1.4
jsonschema 4.3.3
jupyter-client 6.1.12
jupyter-console 6.1.0
jupyter-core 4.11.1
jupyterlab-widgets 3.0.3
kaggle 1.5.12
kapre 0.3.7
keras 2.8.0
Keras-Preprocessing 1.1.2
keras-vis 0.4.1
kiwisolver 1.4.4
korean-lunar-calendar 0.3.1
langcodes 3.3.0
libclang 14.0.6
librosa 0.8.1
lightgbm 2.2.3
llvmlite 0.39.1
lmdb 0.99
locket 1.0.0
logical-unification 0.4.5
LunarCalendar 0.0.9
lxml 4.9.1
Markdown 3.4.1
MarkupSafe 2.0.1
marshmallow 3.18.0
matplotlib 3.2.2
matplotlib-venn 0.11.7
miniKanren 1.0.3
missingno 0.5.1
mistune 0.8.4
mizani 0.7.3
mkl 2019.0
mlxtend 0.14.0
more-itertools 8.14.0
moviepy 0.2.3.5
mpmath 1.2.1
msgpack 1.0.4
multidict 6.0.2
multipledispatch 0.6.0
multitasking 0.0.11
murmurhash 1.0.8
music21 5.5.0
natsort 5.5.0
nbconvert 5.6.1
nbformat 5.6.1
netCDF4 1.6.1
networkx 2.6.3
nibabel 3.0.2
nltk 3.7
notebook 5.3.1
numba 0.56.2
numexpr 2.8.3
numpy 1.21.6
oauth2client 4.1.3
oauthlib 3.2.1
okgrade 0.4.3
opencv-contrib-python 4.6.0.66
opencv-python 4.6.0.66
opencv-python-headless 4.6.0.66
openpyxl 3.0.10
opt-einsum 3.3.0
osqp 0.6.2.post0
packaging 21.3
palettable 3.3.0
pandas 1.3.5
pandas-datareader 0.9.0
pandas-gbq 0.13.3
pandas-profiling 1.4.1
pandocfilters 1.5.0
panel 0.12.1
param 1.12.2
parso 0.8.3
partd 1.3.0
pastel 0.2.1
pathlib 1.0.1
pathy 0.6.2
patsy 0.5.2
pep517 0.13.0
pexpect 4.8.0
pickleshare 0.7.5
Pillow 7.1.2
pip 21.1.3
pip-tools 6.2.0
plotly 5.5.0
plotnine 0.8.0
pluggy 0.7.1
pooch 1.6.0
portpicker 1.3.9
prefetch-generator 1.0.1
preshed 3.0.7
prettytable 3.4.1
progressbar2 3.38.0
promise 2.3
prompt-toolkit 2.0.10
prophet 1.1.1
protobuf 3.17.3
psutil 5.4.8
psycopg2 2.9.3
ptyprocess 0.7.0
py 1.11.0
pyarrow 6.0.1
pyasn1 0.4.8
pyasn1-modules 0.2.8
pycocotools 2.0.5
pycparser 2.21
pyct 0.4.8
pydantic 1.9.2
pydata-google-auth 1.4.0
pydot 1.3.0
pydot-ng 2.0.0
pydotplus 2.0.2
PyDrive 1.3.1
pyemd 0.5.1
pyerfa 2.0.0.1
Pygments 2.6.1
pygobject 3.26.1
pylev 1.4.0
pymc 4.1.4
PyMeeus 0.5.11
pymongo 4.2.0
pymystem3 0.2.0
PyOpenGL 3.1.6
pyparsing 3.0.9
pyrsistent 0.18.1
pysimdjson 3.2.0
pysndfile 1.3.8
PySocks 1.7.1
pystan 3.3.0
pytest 3.6.4
python-apt 0.0.0
python-chess 0.23.11
python-dateutil 2.8.2
python-louvain 0.16
python-slugify 6.1.2
python-utils 3.3.3
pytz 2022.4
pyviz-comms 2.2.1
PyWavelets 1.3.0
PyYAML 6.0
pyzmq 23.2.1
qdldl 0.1.5.post2
qudida 0.0.4
regex 2022.6.2
requests 2.23.0
requests-oauthlib 1.3.1
resampy 0.4.2
rpy2 3.4.5
rsa 4.9
scikit-image 0.18.3
scikit-learn 1.0.2
scipy 1.7.3
screen-resolution-extra 0.0.0
scs 3.2.0
seaborn 0.11.2
Send2Trash 1.8.0
setuptools 57.4.0
setuptools-git 1.2
Shapely 1.8.4
six 1.15.0
sklearn-pandas 1.8.0
smart-open 5.2.1
snowballstemmer 2.2.0
sortedcontainers 2.4.0
soundfile 0.11.0
spacy 3.4.1
spacy-legacy 3.0.10
spacy-loggers 1.0.3
Sphinx 1.8.6
sphinxcontrib-serializinghtml 1.1.5
sphinxcontrib-websupport 1.2.4
SQLAlchemy 1.4.41
sqlparse 0.4.3
srsly 2.4.4
statsmodels 0.12.2
sympy 1.7.1
tables 3.7.0
tabulate 0.8.10
tblib 1.7.0
tenacity 8.1.0
tensorboard 2.8.0
tensorboard-data-server 0.6.1
tensorboard-plugin-wit 1.8.1
tensorflow 2.8.2+zzzcolab20220929150707
tensorflow-datasets 4.6.0
tensorflow-estimator 2.8.0
tensorflow-gcs-config 2.8.0
tensorflow-hub 0.12.0
tensorflow-io-gcs-filesystem 0.27.0
tensorflow-metadata 1.10.0
tensorflow-probability 0.16.0
termcolor 2.0.1
terminado 0.13.3
testpath 0.6.0
text-unidecode 1.3
textblob 0.15.3
thinc 8.1.2
threadpoolctl 3.1.0
tifffile 2021.11.2
tokenizers 0.12.1
toml 0.10.2
tomli 2.0.1
toolz 0.12.0
torch 1.12.1+cu113
torchaudio 0.12.1+cu113
torchsummary 1.5.1
torchtext 0.13.1
torchvision 0.13.1+cu113
tornado 5.1.1
tqdm 4.64.1
traitlets 5.1.1
transformers 4.22.2
tweepy 3.10.0
typeguard 2.7.1
typer 0.4.2
typing-extensions 4.1.1
tzlocal 1.5.1
ujson 5.5.0
uritemplate 3.0.1
urllib3 1.24.3
vega-datasets 0.9.0
wasabi 0.10.1
wcwidth 0.2.5
webargs 8.2.0
webencodings 0.5.1
Werkzeug 1.0.1
wheel 0.37.1
widgetsnbextension 3.6.1
wordcloud 1.8.2.2
wrapt 1.14.1
xarray 0.20.2
xarray-einstats 0.2.2
xgboost 0.90
xkit 0.0.0
xlrd 1.1.0
xlwt 1.3.0
yarl 1.8.1
yellowbrick 1.5
zict 2.2.0
zipp 3.8.1
```
Source: https://blog.csdn.net/muyao987/article/details/127230089