microsoft-promptist
Version: 6
Promptist: reinforcement learning for automatic prompt optimization
News
- [Demo Release] Dec, 2022: Demo at HuggingFace Space
- [Model Release] Dec, 2022: link
- [Paper Release] Dec, 2022: Optimizing Prompts for Text-to-Image Generation
- Language models serve as a prompt interface that optimizes user input into model-preferred prompts.
- Learn a language model for automatic prompt optimization via reinforcement learning.
Load Pretrained Model for Stable Diffusion v1.4
You can try the online demo at https://huggingface.co/spaces/microsoft/Promptist. [Note] The online demo on HuggingFace Space runs on CPU, so generation is slow. Load the model locally on a GPU for faster generation.
import gradio as grad
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_prompter():
    # Promptist weights are paired with the stock GPT-2 tokenizer.
    prompter_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    return prompter_model, tokenizer

prompter_model, prompter_tokenizer = load_prompter()

def generate(plain_text):
    # The model expects the user prompt followed by " Rephrase:".
    prompt = plain_text.strip() + " Rephrase:"
    input_ids = prompter_tokenizer(prompt, return_tensors="pt").input_ids
    eos_id = prompter_tokenizer.eos_token_id
    outputs = prompter_model.generate(input_ids, do_sample=False, max_new_tokens=75, num_beams=8, num_return_sequences=8, eos_token_id=eos_id, pad_token_id=eos_id, length_penalty=-1.0)
    output_texts = prompter_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    # Remove the echoed input (stripped, to match what was actually sent) from the top beam.
    res = output_texts[0].replace(prompt, "").strip()
    return res

txt = grad.Textbox(lines=1, label="Initial Text", placeholder="Input Prompt")
out = grad.Textbox(lines=1, label="Optimized Prompt")
examples = ["A rabbit is wearing a space suit", "Several railroad tracks with one train passing by", "The roof is wet from the rain", "Cats dancing in a space club"]

grad.Interface(fn=generate,
               inputs=txt,
               outputs=out,
               title="Promptist Demo",
               description="Promptist is a prompt interface for Stable Diffusion v1-4 (https://huggingface.co/CompVis/stable-diffusion-v1-4) that optimizes user input into model-preferred prompts.",
               examples=examples,
               allow_flagging='never',
               cache_examples=False,
               theme="default").launch(enable_queue=True, debug=True)
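The prompt handling inside generate can be looked at separately from the model call. The following is a minimal sketch, not code from the Promptist repository: the helper names build_prompter_input and strip_prompt_prefix are hypothetical, and the decoded string in the demo is made up for illustration. It shows the " Rephrase:" suffix convention and how the echoed input is removed from a decoded sequence.

```python
SUFFIX = " Rephrase:"  # marker the snippet above appends to the user prompt


def build_prompter_input(plain_text: str) -> str:
    """Hypothetical helper: format user text the way generate() does."""
    return plain_text.strip() + SUFFIX


def strip_prompt_prefix(decoded: str, plain_text: str) -> str:
    """Hypothetical helper: drop the echoed input from a decoded sequence."""
    prefix = build_prompter_input(plain_text)
    if decoded.startswith(prefix):
        decoded = decoded[len(prefix):]
    return decoded.strip()


if __name__ == "__main__":
    text = "  A rabbit is wearing a space suit "
    model_input = build_prompter_input(text)
    print(model_input)
    # Made-up continuation, standing in for real model output.
    print(strip_prompt_prefix(model_input + " a rabbit, highly detailed", text))
```

Stripping the input before building the prefix keeps the removal step consistent with what was actually sent to the model, even when the user text has leading or trailing spaces.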
microsoft/Promptist powered by Text Generation Inference (TGI)
Send Request
You can use cURL or any REST client to send a request to the AzureML endpoint with your AzureML token.
curl <AZUREML_ENDPOINT_URL> \
    -X POST \
    -d '{"inputs":"What is Deep Learning?"}' \
    -H "Authorization: Bearer <AZUREML_TOKEN>" \
    -H "Content-Type: application/json"
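The same request can be prepared from Python with only the standard library. This is a sketch under assumptions: build_generate_request is a hypothetical helper name, and https://example.invalid/score is a placeholder standing in for <AZUREML_ENDPOINT_URL>. The request is built but not sent, so the snippet runs without a live endpoint.

```python
import json
import urllib.request


def build_generate_request(endpoint_url, token, inputs, parameters=None):
    """Hypothetical helper mirroring the cURL call: POST a JSON body with a bearer token."""
    payload = {"inputs": inputs}
    if parameters:
        payload["parameters"] = parameters
    return urllib.request.Request(
        endpoint_url,
        data=json.dumps(payload).encode("utf-8"),
        headers={
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        },
        method="POST",
    )


if __name__ == "__main__":
    # Placeholder URL and token; substitute your own endpoint values.
    req = build_generate_request(
        "https://example.invalid/score", "<AZUREML_TOKEN>", "What is Deep Learning?"
    )
    print(req.get_method(), req.full_url)
    # To actually send it (requires a real endpoint and token):
    # with urllib.request.urlopen(req) as resp:
    #     print(json.load(resp))
```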
Supported Parameters
- inputs (string): Input prompt.
- parameters (object):
- best_of (integer): Generate best_of sequences and return the one with the highest token logprobs.
- decoder_input_details (boolean): Whether to return decoder input token logprobs and ids.
- details (boolean): Whether to return generation details.
- do_sample (boolean): Activate logits sampling.
- frequency_penalty (float): The parameter for frequency penalty. 1.0 means no penalty. Penalizes new tokens based on their existing frequency in the text so far, decreasing the model’s likelihood to repeat the same line verbatim.
- grammar (object): One of the following:
  - #1 (object):
    - type (enum): Possible values: json.
    - value (string): A string that represents a JSON Schema. JSON Schema is a declarative language for annotating JSON documents with types and descriptions.
  - #2 (object):
    - type (enum): Possible values: regex.
    - value (string): The regular expression.
  - #3 (object):
    - type (enum): Possible values: json_schema.
    - value (object):
      - name (string): Optional name identifier for the schema.
      - schema (object): The actual JSON schema definition.
- max_new_tokens (integer): Maximum number of tokens to generate.
- repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty. See this paper for more details.
- return_full_text (boolean): Whether to prepend the prompt to the generated text
- seed (integer): Random sampling seed.
- stop (string[]): Stop generating tokens if a member of stop is generated.
- temperature (float): The value used to modulate the logits distribution.
- top_k (integer): The number of highest probability vocabulary tokens to keep for top-k-filtering.
- top_n_tokens (integer): The number of highest probability vocabulary tokens to keep for top-n-filtering.
- top_p (float): Top-p value for nucleus sampling.
- truncate (integer): Truncate input tokens to the given size.
- typical_p (float): Typical decoding mass. See Typical Decoding for Natural Language Generation for more information.
- watermark (boolean): Watermarking with A Watermark for Large Language Models.
- stream (boolean): Whether to stream the output tokens or not. Defaults to false.
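The grammar variants listed above can be assembled programmatically. The sketch below uses hypothetical helper names (regex_grammar, json_grammar, build_payload) and follows the list as written: the json variant carries its schema as a string, and the regex variant carries the pattern directly. It only builds the request body; whether a given deployment honors grammar constraints is not confirmed here.

```python
import json


def regex_grammar(pattern: str) -> dict:
    """Hypothetical helper: grammar option #2, constrain output to a regular expression."""
    return {"type": "regex", "value": pattern}


def json_grammar(schema: dict) -> dict:
    """Hypothetical helper: grammar option #1, where `value` is a string holding a JSON Schema."""
    return {"type": "json", "value": json.dumps(schema)}


def build_payload(inputs: str, **parameters) -> dict:
    """Hypothetical helper: wrap the prompt and any listed parameters into a request body."""
    payload = {"inputs": inputs}
    if parameters:
        payload["parameters"] = parameters
    return payload


if __name__ == "__main__":
    body = build_payload(
        "A rabbit is wearing a space suit Rephrase:",
        max_new_tokens=75,
        do_sample=False,
        grammar=regex_grammar("[a-z ,]+"),
    )
    print(json.dumps(body, indent=2))
```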
Example payload
{
  "inputs": "What is Deep Learning?",
  "parameters": {
    "do_sample": true,
    "top_p": 0.95,
    "temperature": 0.2,
    "top_k": 50,
    "max_new_tokens": 256,
    "repetition_penalty": 1.03,
    "stop": ["\nUser:", "<|endoftext|>", "</s>"]
  }
}
OpenAI Chat Completion API compatibility
Additionally, Text Generation Inference (TGI) offers an OpenAI Chat Completion API compatible layer under the endpoint /v1/chat/completions. Check the full specification in the OpenAI Chat Completion Create documentation.
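A request for the chat completions route could be prepared as follows. This is a sketch under several assumptions: build_chat_request is a hypothetical helper, https://example.invalid is a placeholder base URL, the "model" value is a placeholder string, and the bearer-token header is carried over from the cURL example above rather than confirmed for this route. Whether this particular model ships a chat template is also not stated here. The request is built but not sent.

```python
import json
import urllib.request


def build_chat_request(base_url, token, user_message, max_tokens=64):
    """Hypothetical helper: build an OpenAI-style chat completion request."""
    body = {
        "model": "tgi",  # placeholder value, an assumption for illustration
        "messages": [{"role": "user", "content": user_message}],
        "max_tokens": max_tokens,
    }
    # Path taken from the text above; the base URL is deployment-specific.
    url = base_url.rstrip("/") + "/v1/chat/completions"
    return urllib.request.Request(
        url,
        data=json.dumps(body).encode("utf-8"),
        headers={
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        },
        method="POST",
    )


if __name__ == "__main__":
    req = build_chat_request("https://example.invalid/", "<AZUREML_TOKEN>", "What is Deep Learning?")
    print(req.get_method(), req.full_url)
```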
Model Specifications
License: Unknown
Last Updated: July 2025
Provider: HuggingFace