This guide covers how State is handled in Gradio. Learn the difference between Global and Session states, and how to use both.
Your function may use data that persists beyond a single function call. If the data should be accessible to every function call and every user, you can create a variable outside the function and access it inside the function. For example, you may load a large model outside the function and use it inside the function, so that the model does not have to be reloaded on every call.
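A minimal sketch of the load-once pattern follows. The hypothetical `load_model` stands in for an expensive loader (for example, reading model weights) and returns a toy "model" that uppercases text. Because it runs once at module level, every call to `predict` reuses the same object. `load_model`, `predict` and the counter are illustrative assumptions, not Gradio APIs.

```python
# Counts how many times the (hypothetical) expensive loader runs.
load_calls = 0


def load_model():
    """Hypothetical stand-in for an expensive load, e.g. reading model weights."""
    global load_calls
    load_calls += 1
    # A toy "model": uppercases its input.
    return lambda text: text.upper()


# Loaded once, outside the function, when the script starts.
MODEL = load_model()


def predict(text):
    # Every call reuses the module-level MODEL instead of reloading it.
    return MODEL(text)


print(predict("hi"), predict("there"), load_calls)  # HI THERE 1
```

Passing `predict` to `gr.Interface` keeps this behavior: the load happens once when the script starts, not once per request.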
```python
import gradio as gr

scores = []


def track_score(score):
    scores.append(score)
    top_scores = sorted(scores, reverse=True)[:3]
    return top_scores


demo = gr.Interface(
    track_score,
    gr.Number(label="Score"),
    gr.JSON(label="Top Scores")
)
demo.launch()
```
In the code above, the `scores` list is shared between all users. If multiple users access this demo, their scores are all added to the same list, and the returned top 3 scores are computed from this shared list.
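To see the sharing concretely, the sketch below calls `track_score` directly, the way two users would through the web UI. The function is repeated from the example above so the snippet runs on its own; the "User A" and "User B" labels are illustrative.

```python
scores = []  # module-level: shared by every caller


def track_score(score):
    scores.append(score)
    top_scores = sorted(scores, reverse=True)[:3]
    return top_scores


# "User A" submits two scores.
track_score(10)
track_score(50)

# "User B" submits one low score but sees A's scores in the result,
# because both users' calls append to the same `scores` list.
user_b_top = track_score(5)
print(user_b_top)  # [50, 10, 5]
```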
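Another type of data persistence Gradio supports is session state, where data persists across multiple submits within a page session. However, data is not shared between different users of your model. To store data in a session state, you need to do three things:

1. Pass an extra parameter into your function, which represents the state of the interface.
2. At the end of the function, return the updated value of the state as an extra return value.
3. Add the `'state'` input and `'state'` output components when creating your `Interface`.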
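The three steps can be sketched as follows, assuming the `"state"`, `"text"` and `"json"` string shortcuts for `gr.Interface` components. `store_message` and `build_demo` are hypothetical names; Gradio is imported inside `build_demo` so that the state logic can be exercised without launching a server.

```python
def store_message(message, history):
    """Return the current message, the earlier ones, and the updated state."""
    # Step 1: `history` is the extra parameter holding this session's state.
    # A bare "state" component starts as None, so fall back to an empty list.
    history = [] if history is None else list(history)
    previous = list(history)
    history.append(message)
    # Step 2: return the updated state as an extra return value.
    return {"current": message, "previous": previous}, history


def build_demo():
    # Imported here so store_message can be used without Gradio installed.
    import gradio as gr

    # Step 3: add a "state" input and a "state" output to the Interface.
    return gr.Interface(
        store_message,
        inputs=["text", "state"],
        outputs=["json", "state"],
    )
```

Calling `build_demo().launch()` starts the app. Each browser session then threads its own `history` through successive calls, much like passing the returned state back in by hand: `_, s = store_message("hi", None)` followed by `store_message("there", s)` reports `["hi"]` as the previous messages.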
A chatbot is an example where you would need session state: you want access to a user's previous submissions, but you cannot store chat history in a global variable, because then the chat history would get jumbled between different users.
```python
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the model once, outside the event handlers, so every call reuses it.
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")


def user(message, history):
    # Clear the textbox and append the user's message with no bot reply yet.
    return "", (history or []) + [[message, None]]


def bot(history):
    # Rebuild the conversation as "turn<eos>turn<eos>..." so the model sees
    # the earlier turns stored in this session's chat history.
    text = ""
    for user_msg, bot_msg in history[:-1]:
        text += user_msg + tokenizer.eos_token + bot_msg + tokenizer.eos_token
    text += history[-1][0] + tokenizer.eos_token
    bot_input_ids = tokenizer.encode(text, return_tensors="pt")
    # generate a response
    output_ids = model.generate(
        bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id
    )
    # decode only the newly generated tokens, dropping the prompt
    reply = tokenizer.decode(
        output_ids[0, bot_input_ids.shape[-1]:], skip_special_tokens=True
    )
    history[-1] = [history[-1][0], reply]
    return history


with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

demo.launch()
```
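Notice how the state persists across submits within each page, but if you load this demo in another tab (or refresh the page), the two pages do not share chat history. Here the session state is the `Chatbot` component's own value: it is passed into `user` and `bot` as `history` and returned updated, so each page keeps its own conversation.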
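One way to picture the difference between global and session state is a toy model in which each page load gets its own session key. This sketch is illustrative only: the `SESSIONS` dict, the session IDs and `handle_submit` are hypothetical and do not describe how Gradio stores state internally.

```python
# Global state: one list shared by every session.
all_messages = []

# Hypothetical per-session store: session ID -> that session's history.
SESSIONS = {}


def handle_submit(session_id, message):
    # Global state sees every message from every session.
    all_messages.append(message)
    # Session state: look up (or create) this session's own history.
    history = SESSIONS.setdefault(session_id, [])
    history.append(message)
    return list(history)


handle_submit("tab-1", "hello")
tab1 = handle_submit("tab-1", "how are you?")
# Another tab (or a refresh) is a new session that starts with empty history.
tab2 = handle_submit("tab-2", "hi")

print(tab1)          # ['hello', 'how are you?']
print(tab2)          # ['hi']
print(all_messages)  # ['hello', 'how are you?', 'hi']
```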
The default value of `state` is `None`. If you pass a default value to the state parameter of the function, it is used as the default value of the state instead. The `Interface` class only supports a single input and output state variable, though it can be a list with multiple elements. For more complex use cases, you can use Blocks, which supports multiple `State` variables.
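As a sketch of multiple `State` variables in Blocks, the hypothetical app below keeps a message counter and a message log as two separate `gr.State` objects, each created with its own default value (`0` and `[]`). States are listed among an event's inputs and outputs like any other component. `record` and `build_demo` are illustrative names, and Gradio is imported inside `build_demo` so `record` can be tried on its own.

```python
def record(message, count, log):
    # count and log arrive as this session's current State values.
    count = count + 1
    log = log + [message]  # build a new list rather than mutating in place
    summary = f"{count} message(s) so far: {', '.join(log)}"
    return summary, count, log


def build_demo():
    # Imported here so record() can be used without Gradio installed.
    import gradio as gr

    with gr.Blocks() as demo:
        # Two independent State variables, each with its own default value.
        count = gr.State(0)
        log = gr.State([])
        msg = gr.Textbox(label="Message")
        out = gr.Textbox(label="Summary")
        # Both states are read as inputs and written back as outputs.
        msg.submit(record, [msg, count, log], [out, count, log])
    return demo
```

Calling `build_demo().launch()` runs the app. Keeping the counter and the log as separate states, rather than packing them into one list as `Interface` requires, lets each event read or update only the states it needs.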