Understanding Retention with Gradio | by Mariya Mansurova | Oct, 2023

I remember a moment when I built my first web application. It was around eight years ago, and I was a rather junior analyst and was convinced that BI tools could solve all the problems.

The engineering team built a prototype of a new SDK and wanted to learn whether it gathers data better. They were testing it on a set of devices, looking at the data and comparing it to the old version. However, the set of devices was constantly changing, so keeping it up-to-date in BI tools would require quite a lot of work. So, I decided to build a web application.

I found a set of articles (ten or eleven if I remember correctly), read them all and tried to use this knowledge for my task. It took me around a week to finish the first prototype. I had to write both the back-end and front-end sides, so now I could consider myself at least a junior full-stack developer. For the back-end, I used Flask (I was lucky not to bump into Django, or I would have spent the whole month), and for front-end — Bootstrap and Leaflet.

Overall, it was a challenging task that required much effort to upskill in engineering. I believe it’s always worth having a deeper understanding of the other spheres next to your primary domain of expertise.

However, I’m delighted that nowadays, there are many tools that allow analysts and data scientists to build prototypes in less than an hour. In many cases, such prototypes can bring your analytics to the next level. Here are some examples:

  • Revenue and audience forecast depending on the input parameters (like marketing budget or markets where we will launch a new feature),
  • Tools that will speed up your team’s work or reduce ad-hoc workload, like an A/B testing calculator or automatic root cause analysis,
  • MVP solutions, for example, if you want to use LLMs to automate some internal processes, it’s worth testing a prototype before spending time on a production version. I shared such an ML prototype in one of my previous articles, “Build your first Deep Learning app within an hour”.

In this article, I would like to tell you about one such framework that can help you quickly and almost effortlessly create nice-looking web applications without bothering with JavaScript and CSS. We will learn the basics of Gradio, develop a couple of web applications, and publish them to HuggingFace Spaces so anyone can access them.

Gradio is not the only framework of that kind. There are a few other open-source Python alternatives:

  • Streamlit is another popular and powerful library for building data apps with little code. It is also supported by HuggingFace Spaces so that you can host such apps.
  • Dash could be convenient if you are already used to Plotly, and it provides more capabilities for customization.
  • However, if you want to build something custom and complex, your last resort would be Flask or even Django.

You can find more details regarding the main features of the different frameworks in this article.

Gradio is an open-source Python library that is used to build interactive applications.

The main advantages of Gradio are:

  • you can build applications using only Python, which also means that you can use all Python libraries in your app,
  • you can run it in Jupyter Notebook or as a separate webpage,
  • you can host Gradio apps permanently on HuggingFace Spaces.

There’s no silver bullet, so Gradio has its limitations:

  • It’s explicitly designed for ML applications. So, if you’re using it for other use cases, you might have to change defaults (for example, switching off flagging with allow_flagging="never").
  • Customization is limited, especially if we are talking about design.
  • I would bear in mind that Gradio is a framework primarily for quick prototyping. It mostly works well, but from time to time, I face some strange behaviour. For example, table editing in Safari works counterintuitively, or sometimes you need to restart Jupyter Notebook to make the interface load.

To start using Gradio, we need to install the Python package.

pip install gradio

Following the old programmers’ tradition, let’s start with “Hello, World!”.

We can use the gr.Interface class to define the interface (documentation). It’s one of the core Gradio classes that helps you to create a web application based on any Python function.

We need to specify the following parameters:

  • inputs: input components of the interface (in our case, just a text field),
  • outputs: output components of the interface (in our case, also just a text field),
  • fn: core functionality (a function that gets inputs and returns outputs, in our case, gets name from the input and returns “Hello, <name>!”),
  • title & description: a bit of markdown to make our app more user-friendly.

import gradio as gr

demo = gr.Interface(
    inputs=[gr.Textbox(label="Name", lines=1)],
    outputs=[gr.Textbox(label="Result", lines=1)],
    fn=lambda x: 'Hello, %s!' % x,
    title="Hello, World!",
    description="Your first app using Gradio",
    allow_flagging='never')

demo.launch()

You can run this code in your Jupyter Notebook and see the results. It’s pretty handy for debugging. Later, we will discuss how to make your web application available to others.

Image by author

That’s it: just a few lines of code, and your first Gradio app is running. Also, I must note that it looks pretty nice, and we didn’t have to use any front-end magic for it.

Gradio launches a lot of processes in the background when you’re working from Jupyter Notebook, so it’s worth closing connections from time to time using gr.close_all().

We looked at the most basic example and saw the building blocks of Gradio: inputs, outputs and functions. Now, we are ready to move on to real-life analytical tasks.

As the first example, we will look at the impact of retention on the users’ growth for the product.

Retention as the basis for growth

Two parameters define the growth of the product:

  • acquisition (number of new users each period),
  • retention (ability to retain customers in the product).

Let’s model how the user base will grow depending on the retention curve.

We can describe any retention curve using the following function with a set of parameters (a, b, c and d):

retention(x) = a + 1 / (b + c * x^d)

Let’s talk about the most common case of retention: the cohort is defined by the first action in the product, and all actions are counted into the retention. In that case, retention for period 0 must equal 1 (because the cohort entry and retention events are the same). So, we can define one of the parameters automatically:

b = 1 / (1 - a)

The main factor for growth is long-term retention. It defines whether customers stick to the product for a long time and your product grows sustainably, or customers churn in a month and you need to acquire more and more new users for growth. In our formula, the parameter a is in charge of long-term retention: for positive c and d, the fraction term vanishes as x grows, so retention approaches a.

We can use this formula to define the retention curve. So we have everything we need to move on to the development.
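To make the two properties of the curve concrete, here is a minimal sketch in plain Python. The helper name `retention` is only illustrative (the article's own functions follow in the next section), and the parameter values are the slider defaults used later. It checks numerically that retention starts at 1 for period 0 and approaches a for distant periods.

```python
def retention(x, a, c, d):
    """Retention for period x, with b fixed so that retention(0) == 1."""
    b = 1.0 / (1.0 - a)  # requires a < 1
    return a + 1.0 / (b + c * x ** d)


# Illustrative values: the slider defaults used later in the article.
a, c, d = 0.03, 0.55, 1.5

# Retention sampled every 30 periods.
curve = [retention(x, a, c, d) for x in range(0, 361, 30)]

print("period 0:", retention(0, a, c, d))
print("far future:", retention(10**6, a, c, d), "vs a =", a)
```

With positive c and d, the sampled values only decrease, so a acts as the floor that the curve flattens towards.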

Visualising retention graph

Let’s start simple and make an application that will take the retention curve parameters and show the relation as a graph.

Similarly to our “Hello, World” example, we need to use gr.Interface class and pass inputs, outputs and fn to map them.

  • We now need more input parameters. So, inputs will be a list of controls. We will use gr.Slider and gr.Dropdown controls.
    For gr.Slider, we need to pass min, max, default values and a label that will be shown in the interface. The values are passed to the function in the order of the controls.
    For gr.Dropdown, we need to define a list of possible values, default value, and a label.
  • We will still have only one output — a plot so that outputs will be gr.Plot without any parameters.
  • Function fn will map inputs to outputs, so it will get input arguments and return a plotly Figure object that will be visualised.

import gradio as gr
import pandas as pd
import plotly.express as px

# functions to calculate retention

def get_retention(a, b, c, d, periods):
    return a + 1./(b + c * (periods ** d))

def get_retention_same_event(a, c, d, periods):
    # b is fixed so that retention for period 0 equals 1 (requires a < 1)
    b = 1./(1 - a)
    return get_retention(a, b, c, d, periods)

# define function - return plot depending on input parameters

def get_retention_plot(a, c, d, num_periods):
    df = pd.DataFrame({'x': range(num_periods + 1)})
    df['retention'] = df.x.map(lambda x: get_retention_same_event(a, c, d, x))

    return px.line(df, x = 'x', y = 'retention',
                   color_discrete_sequence = px.colors.qualitative.Prism,
                   title = 'Retention curve', labels = {'x': 'period'})

# define inputs: one control per argument of get_retention_plot, in the same order
inputs = [
    gr.Slider(0, 1, 0.03, label="a"),
    gr.Slider(0, 5, 0.55, label="c"),
    gr.Slider(0, 5, 1.5, label="d"),
    gr.Dropdown([10, 30, 60, 90], value = 30, label="Number of Periods")
]

# define outputs
outputs = gr.Plot()

# define interface
demo = gr.Interface(
    fn=get_retention_plot,
    inputs=inputs,
    outputs=outputs,
    cache_examples=True,
    allow_flagging = 'never' # hiding default flag functionality in the interface
)

# launch
demo.launch(debug = True)

Let’s try to run this app. It’s working — we can see a graph that changes if we submit new parameters.

Adding more graphs

Our goal was to look at the impact of retention on growth, so we need to add graphs showing not only retention but also audience over time. Let’s change our interface.

For simplicity, we will consider that in each period, the same number of new users start using our product (cohort_size parameter).
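Under this assumption, the audience in period t is simply the cohort size multiplied by the running sum of the retention curve up to t. The sketch below illustrates that with the standard library only. The names `retention` and `active_users` are hypothetical helpers, not part of the app, and the result is not rounded per cohort, so it can differ slightly from a simulation that rounds users to integers.

```python
from itertools import accumulate


def retention(x, a, c, d):
    """Retention for period x, with b fixed so that retention(0) == 1."""
    b = 1.0 / (1.0 - a)
    return a + 1.0 / (b + c * x ** d)


def active_users(a, c, d, num_periods, cohort_size):
    """Total active users per period when each period adds one equal-sized cohort."""
    curve = [retention(k, a, c, d) for k in range(num_periods + 1)]
    # users(t) = cohort_size * (retention(0) + ... + retention(t))
    return [cohort_size * total for total in accumulate(curve)]


with_floor = active_users(0.03, 0.55, 1.5, 90, 10000)  # long-term retention of 3%
no_floor = active_users(0.0, 0.55, 1.5, 90, 10000)     # no long-term retention

# Audience added in the last simulated period for each scenario.
print(with_floor[-1] - with_floor[-2], no_floor[-1] - no_floor[-2])
```

The difference between consecutive periods equals cohort_size times the retention of the oldest cohort, so with a > 0 the audience keeps growing by at least a * cohort_size per period, while with a = 0 the growth fades out.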

We need to make just a couple of changes to our implementation:

  • Change get_retention_plot function so that it gets one more parameter for cohort size, calculates the number of users over time and returns three Figures.
  • Parameter outputs is now equal to the list of three gr.Plot() objects.

def get_retention_plot(a, c, d, num_periods, cohort_size):
    ret_df = pd.DataFrame({'x': range(num_periods + 1)})
    ret_df['retention'] = ret_df.x.map(lambda x: get_retention_same_event(a, c, d, x))

    ret_fig = px.line(ret_df.iloc[1:], x = 'x', y = 'retention',
                      color_discrete_sequence = px.colors.qualitative.Prism,
                      title = 'Retention curve')

    # simulation

    tmp_data = []

    for cohort in range(num_periods + 1):
        for cohort_period in range(num_periods + 1):
            period = cohort_period + cohort
            if period > num_periods:
                continue
            retention = get_retention_same_event(a, c, d, cohort_period)
            tmp_data.append(
                {
                    'cohort': 'cohort %s' % str(cohort).rjust(3, '0'),
                    'cohort_period': cohort_period,
                    'period': period,
                    'retention': retention,
                    'users': int(round(retention * cohort_size))
                }
            )
    users_df = pd.DataFrame(tmp_data)

    users_fig = px.area(users_df.groupby('period').users.sum(),
                        color_discrete_sequence = px.colors.qualitative.Prism,
                        title = 'Active users')

    cohorts_fig = px.area(users_df.pivot_table(index = 'period', columns = 'cohort', values = 'users',
                                               aggfunc = 'sum'),
                          color_discrete_sequence = px.colors.qualitative.Prism,
                          title = 'Active users by cohorts')

    return ret_fig, users_fig, cohorts_fig

inputs = [
    gr.Slider(0, 1, 0.03, label="a"),
    gr.Slider(0, 5, 0.55, label="c"),
    gr.Slider(0, 5, 1.5, label="d"),
    gr.Dropdown([10, 30, 60, 90], value = 30, label="Number of Periods"),
    gr.Dropdown([10, 100, 1000, 10000], value = 10000, label="Number of new users each period")
]

outputs = [gr.Plot(), gr.Plot(), gr.Plot()]

demo = gr.Interface(
    fn=get_retention_plot,
    inputs=inputs,
    outputs=outputs,
    allow_flagging = 'never',
    cache_examples=True,
)

demo.launch(debug = True)

Fantastic, now we can see the complete picture and analyse the relationships. However, there’s room for improvement — we can add formatting to make our app more convenient for users.

Image by author

Adding a bit of style

We can tweak our interface a bit to make it more user-friendly and straightforward.

For that, we will be using gr.Blocks() as a context manager. This functionality allows you to create more custom web applications and define layouts and data flows (events that trigger functions and consequent execution).

Blocks will open new opportunities for us:

  • With gr.Blocks() we can use gr.Row() and gr.Column() to organize a layout.
  • gr.Markdown allows you to add markdown elements, for example, a title or even LaTeX formulas (by default, you need to put them inside $).
  • gr.Accordion can help you hide some parameters you don’t want to show the user by default.
  • Also, this approach allows you to define more complex relations between inputs and outputs. For example, you can update plots not only on the submit button but on the change of any input parameter. We will use this functionality in the following example.

When working with Blocks, we need to define each input and output as variables, for example, a = gr.Slider(0, 1, 0.03, label="a").

Also, there are no default controls, so we have to define buttons ourselves: btn_caption = gr.Button("Submit").

The action on button click also must be specified by setting the already familiar parameters: inputs, outputs and fn.

btn_caption.click(fn=get_retention_plot,
                  inputs=[a, c, d, num_periods, cohort_size],
                  outputs=[plot1, plot2, plot3])

Here is the full version of code.

with gr.Blocks() as demo:
    gr.Markdown("# Understanding Growth 🚀")
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Retention curve parameters 📈")
            gr.Markdown(r"$\textbf{retention}(\textsf{x}) = \textsf{a} + \frac{\textsf{1}}{\textsf{b} + \textsf{c} * \textsf{x}^{\textsf{d}}}\ where\ \textsf{b} = \frac{\textsf{1}}{\textsf{1}-\textsf{a}}$")
            with gr.Row():
                a = gr.Slider(0, 1, 0.03, label="a")
                c = gr.Slider(0, 5, 0.55, label="c")
                d = gr.Slider(0, 5, 1.5, label="d")
            with gr.Accordion("More options", open=False):
                with gr.Row():
                    num_periods = gr.Dropdown([10, 30, 60, 90], value = 30, label="Number of Periods")
                    cohort_size = gr.Dropdown([10, 100, 1000, 10000], value = 10000, label="Number of new users each period")
            btn_caption = gr.Button("Submit")
        with gr.Column():
            plot1 = gr.Plot()
            with gr.Row():
                plot2 = gr.Plot()
                plot3 = gr.Plot()

    btn_caption.click(fn=get_retention_plot,
                      inputs=[a, c, d, num_periods, cohort_size],
                      outputs=[plot1, plot2, plot3])

demo.launch()

Hosting your application

Also, we can use HuggingFace Spaces to host our web applications and share them easily with others.

To start using Spaces, you need to have an account. Follow this link if you haven’t registered yet. It won’t take more than a couple of minutes.

The next step is to create a new Space. You can find instructions with more details in the documentation.

Image by author

For a new Space, you must fill in the following parameters: name, license and SDK (choose Gradio).

Image by author

Then, you need to commit your code to the Git repository of your HuggingFace Space. First of all, we need to clone the repository.

# cloning repo
git clone https://huggingface.co/spaces/<your_login>/<your_app_name>
cd <your_app_name>

Recently, HuggingFace has changed the Git authentication process, so we need to create a token first and then set it for the Git repo.

git remote set-url origin https://<your_login>:<token>@huggingface.co/spaces/<your_login>/<your_app_name>
git pull origin

Now, it’s time to commit the files related to our application. We need to have at least the following files:

  • app.py with the Python code that launches the Gradio app
  • requirements.txt with the list of Python packages you need for your application. In our case, only pandas and plotly.
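
As a rough illustration, the requirements.txt for this first app could be as short as the fragment below. The package names are left unpinned here; pinning versions is a choice left to you. This assumes the Space was created with the Gradio SDK, in which case gradio itself should not need to be listed.

```
pandas
plotly
```

An app that imports further libraries (the second app in this article uses scipy, for instance) would presumably need them added to the same file.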

Then, the basic steps with git: add, commit and push to HuggingFace.

git add app.py
git add requirements.txt
git commit -m 'First version of retention simulator app'
git push

It took a couple of minutes to build the app, and it’s done. Now our web application is up and running on HuggingFace Spaces. You can try it here.

Image by author

It looks much nicer than our initial version since the layout doesn’t require scrolling, and users don’t have to guess what parameters a, c and d mean.

We’ve learned how to generate graphs based on a bunch of parameters in a web application. But in real life, we usually have to input quite a lot of data, so let’s find out how to use data from .csv files in apps.

As an example, we will look at actual retention data for the first few periods and try to predict retention for the following periods. It’s quite a common task since we usually don’t want to wait three months to compare third-month retention for the new cohort. We will upload factual data as a .csv file.

Let’s not waste our time and jump to the implementation.

Getting data from files

Here is the code to generate the whole interface and business logic. It might look a bit complex. Don’t worry. We will discuss the core points later.

import random
from io import StringIO

import gradio as gr
import pandas as pd
import plotly
import plotly.graph_objects as go
import scipy.optimize

# parses file or string and returns dataframe
def parse_file(input_text_or_file, num_periods):
    if isinstance(input_text_or_file, str):
        df = pd.read_csv(StringIO(input_text_or_file), sep = '\t')
    else:
        df = pd.read_csv(input_text_or_file.name, sep = '\t')
    return df

# takes dataframe and returns plot
def show_graph_for_df(df, num_periods):
    df['period'] = df.period.map(int)
    df['retention_fact'] = df.retention_fact.map(float)
    # fit a, c and d to the factual data, starting from random initial values
    result = scipy.optimize.minimize(lambda x: get_mse_for_retention(x, df), [random.random(), random.random(), random.random()])
    a, c, d = result.x

    pred_df = pd.DataFrame({'period': range(num_periods + 1)})
    pred_df['retention_pred'] = pred_df.period.map(lambda x: get_retention_same_event(a, c, d, x))
    pred_df = pred_df.merge(df, how = 'left')

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=pred_df.period, y=pred_df.retention_fact, name='fact',
                             line=dict(color=plotly.colors.qualitative.Prism[0], width=3)))

    fig.add_trace(go.Scatter(x=pred_df.period, y=pred_df.retention_pred, name='prediction',
                             line=dict(color=plotly.colors.qualitative.Prism[0], width=3, dash='dot')))

    fig.update_layout(title=' retention model (a = %.2f, c = %.2f, d = %.2f)' % (a, c, d),
                      yaxis_title='retention',
                      xaxis_title='period')
    return fig

# takes file and returns plot
def show_graph_for_file(temp_file, num_periods):
    df = parse_file(temp_file, num_periods)
    return show_graph_for_df(df, num_periods)

# hard-coded example of data
default_csv = 'period\tretention_fact\n0\t1\n1\t0.55\n2\t0.4\n3\t0.35\n4\t0.3\n'

# interface
with gr.Blocks() as demo:
    gr.Markdown('# Predicting retention curve 📊')
    periods = gr.Dropdown([10, 30, 90, 180], label="Number of Periods", value = 30)
    gr.Markdown('Upload .csv file with data, use default data as an example or put in numbers manually in the Uploaded data section.')
    gr.Markdown('''__File format:__ 2 columns (`period` and `retention_fact`)''')

    with gr.Row():
        upload_button = gr.UploadButton(label="Upload file", file_types = ['.csv'], live=True, file_count = "single")
        default_button = gr.Button('Show example')

    with gr.Row():
        with gr.Accordion("Uploaded data", open=False):
            gr.Markdown('You can change values in the table')
            table = gr.Dataframe(type="pandas", col_count=2, interactive = True, headers = ['period', 'retention_fact'])

    with gr.Row():
        image = gr.Plot()

    # business logic of triggers and events
    upload_button.upload(fn=show_graph_for_file, inputs=[upload_button, periods], outputs=image, api_name="upload_graph")
    upload_button.upload(fn=parse_file, inputs=[upload_button, periods], outputs=table, api_name="upload_csv")
    default_button.click(fn=lambda x: show_graph_for_file(default_csv, x), inputs=[periods], outputs=image, api_name="upload_example_graph")
    default_button.click(fn=lambda x: parse_file(default_csv, x), inputs=[periods], outputs=table, api_name="upload_example_csv")
    table.change(fn=show_graph_for_df, inputs=[table, periods], outputs=image, api_name="upload_table_graph")
    periods.change(fn=show_graph_for_df, inputs=[table, periods], outputs=image, api_name="upload_periods_graph")

demo.launch(debug=True)

Let’s look at it closer. We have the following elements in the interface:

  • periods — input parameter,
  • upload_button — input parameter that allows you to load data from a .csv file,
  • default_button — allows you to update table and graph with pre-defined values as an example,
  • table shows the data frame from uploaded data (either from .csv file or example); also, you could change the numbers in the table in place, and the graph will be updated — so it’s an input parameter as well,
  • image — output parameter, that shows a plot.
Image by author

Function parse_file gets either file from upload_button or string from the default example and returns a pandas data frame we could use further. So, using data from files is pretty straightforward.
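One detail worth spelling out is the expected file layout: the example data is tab-separated with a header row, even though the file extension is .csv. A comma-separated file read with a tab separator would end up as a single column. The sketch below is a dependency-free illustration of that format check using only the standard library. The function `parse_retention_text` and the constant `REQUIRED_COLUMNS` are hypothetical names, not part of the app.

```python
import csv
from io import StringIO

REQUIRED_COLUMNS = ("period", "retention_fact")


def parse_retention_text(text, delimiter="\t"):
    """Parse delimited text into a list of (period, retention) tuples."""
    reader = csv.DictReader(StringIO(text), delimiter=delimiter)
    header = reader.fieldnames or []
    missing = [col for col in REQUIRED_COLUMNS if col not in header]
    if missing:
        # Typical cause: the file uses a different separator than expected.
        raise ValueError("missing columns: %s" % ", ".join(missing))
    return [(int(row["period"]), float(row["retention_fact"])) for row in reader]


# Same example data as in the app.
default_csv = 'period\tretention_fact\n0\t1\n1\t0.55\n2\t0.4\n3\t0.35\n4\t0.3\n'
rows = parse_retention_text(default_csv)
print(rows)
```

A check like this could run before the data reaches the fitting step, so that a wrongly formatted upload produces a readable message instead of a failure deeper in the code.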

The crucial business logic is defined in the code snippet below. It defines actions for all interface elements:

  • for uploading .csv file — the table and the plot are updated,
  • for click on the button “Show example” — the table and the plot are updated,
  • for changing data in the table — only the plot is updated,
  • for changing the number of periods — only the plot is updated.
upload_button.upload(fn=show_graph_for_file, inputs=[upload_button, periods], outputs=image, api_name="upload_graph")
upload_button.upload(fn=parse_file, inputs=[upload_button, periods], outputs=table, api_name="upload_csv")

default_button.click(fn=lambda x: show_graph_for_file(default_csv, x), inputs=[periods], outputs=image, api_name="upload_example_graph")
default_button.click(fn=lambda x: parse_file(default_csv, x), inputs=[periods], outputs=table, api_name="upload_example_csv")

table.change(fn=show_graph_for_df, inputs=[table, periods], outputs=image, api_name="upload_table_graph")
periods.change(fn=show_graph_for_df, inputs=[table, periods], outputs=image, api_name="upload_periods_graph")

Defining the best-fit function

The essential part of our solution is finding the best-fit function for our factual data. Let’s look at how to do it.

  • First, we define the function get_mse_for_retention that returns the error for the set of parameters (a, c and d). It also takes the data frame as an input.
  • We start from the standard Mean Squared Error (MSE). The function returns its square root (RMSE), which reaches its minimum at the same parameters.
  • Then, we will use the scipy.optimize.minimize function for optimization. We need to pass just two parameters: the function to optimize (we pass a lambda function with the data frame fixed since we are optimizing only params) and the initial values for parameters (just a list of random values).
  • After optimisation, we can access the optimal params using result.x.

def get_mse_for_retention(params, df):
    tmp_df = df.copy()
    # predicted retention for each observed period
    tmp_df['retention_pred'] = tmp_df.period.map(
        lambda x: get_retention_same_event(params[0], params[1], params[2], x)
    )

    # squared error per period
    tmp_df['se'] = (tmp_df.retention_fact - tmp_df.retention_pred)
    tmp_df['se'] = tmp_df['se']**2

    # square root of the mean squared error
    return tmp_df.se.mean() ** 0.5

result = scipy.optimize.minimize(lambda x: get_mse_for_retention(x, df), [random.random(), random.random(), random.random()])
a, c, d = result.x
print(a, c, d)

That’s it, now we know the theoretical retention curve for our factual data and can use it in our app for prediction.
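Since the optimizer starts from random values, its result can vary between runs, so a simple cross-check may be useful. The sketch below is not the method used in the app: it swaps scipy.optimize.minimize for a coarse brute-force grid search in plain Python, and the grid ranges and the helper names (`retention`, `rmse`, `grid_search`) are assumptions made for illustration only.

```python
from itertools import product


def retention(x, a, c, d):
    """Retention for period x, with b fixed so that retention(0) == 1."""
    b = 1.0 / (1.0 - a)
    return a + 1.0 / (b + c * x ** d)


def rmse(params, data):
    """Root mean squared error between factual and predicted retention."""
    a, c, d = params
    errors = [(fact - retention(period, a, c, d)) ** 2 for period, fact in data]
    return (sum(errors) / len(errors)) ** 0.5


def grid_search(data):
    """Return the (a, c, d) grid point with the lowest error."""
    a_grid = [i / 20 for i in range(0, 11)]  # 0.00 .. 0.50, step 0.05
    c_grid = [i / 4 for i in range(1, 21)]   # 0.25 .. 5.00, step 0.25
    d_grid = [i / 4 for i in range(1, 13)]   # 0.25 .. 3.00, step 0.25
    return min(product(a_grid, c_grid, d_grid), key=lambda p: rmse(p, data))


# Same example data as in the app: (period, retention_fact).
data = [(0, 1.0), (1, 0.55), (2, 0.4), (3, 0.35), (4, 0.3)]
best = grid_search(data)
print(best, rmse(best, data))
```

If the optimizer's parameters give a clearly higher error than the best grid point, that is a hint to rerun it with different initial values.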

Last step

I followed the same instructions and posted this app to HuggingFace Spaces as well. So you could try to play with it here.

You can find the whole code for both apps on GitHub.

In this article, we’ve gone through the basics of the Gradio library and learned how to build pleasant web applications with only Python.

We’ve learned a couple of approaches:

  • High-level gr.Interface class that allows you to get a working prototype quickly,
  • More customizable way of using gr.Blocks when you can specify the exact layout you need and define complex relations between inputs and outputs.

Thank you a lot for reading this article. I hope it was insightful to you. If you have any follow-up questions or comments, please leave them in the comments section.

This article is inspired by “Building Applications with Gradio” course.
