Data Science, Machine Learning, and AI

Have you ever wanted to make your machine learning model available to other people but didn’t know how? Or maybe you have just heard the term API and want to know what’s behind it? Then this post is for you!

Here at STATWORX, we use and write APIs daily. In this article, I explain how you can build your own API for a machine learning model you have created, along with the meaning of some of the most important concepts, such as REST. After reading this short article, you will also know how to make requests to your API from within a Python program. So have fun reading and learning!

What is an API?

API is short for Application Programming Interface. An API lets users interact with the underlying functionality of some piece of code through a defined interface, without touching the code itself. There are many kinds of APIs, and chances are good that you have already heard of the type we are going to talk about in this blog post: the web API.

This specific type of API allows users to interact with functionality over the internet. In this example, we are building an API that will provide predictions through our trained machine learning model. In a real-world setting, this kind of API could be embedded in some type of application, where a user enters new data and receives a prediction in return. APIs are very flexible and easy to maintain, making them a handy tool in the daily work of a Data Scientist or Data Engineer.

An example of a publicly available machine learning API is Time Door. It provides Time Series tools that you can integrate into your applications. APIs can also be used to make data available, not only machine learning models.

API Illustration

And what is REST?

Representational State Transfer (REST) is an architectural style for communication between web services. An API that follows REST best practices is called a “REST API”. There are other approaches to web communication, too (such as the Simple Object Access Protocol, SOAP), but REST generally needs less bandwidth, which makes it a popular choice for serving machine learning models.

In a REST API, the four most important types of requests are:

  • GET
  • PUT
  • POST
  • DELETE

For our little machine learning application, we will mostly focus on the POST method. It is very versatile: it lets us send the input data in the request body, and many clients can’t send a body with a GET request.
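To make the difference concrete: with POST, the input travels in the request body instead of the URL. The following is a minimal sketch using only Python's standard library (`urllib.request`). It assumes a hypothetical local endpoint `http://127.0.0.1:1080/predict`, and the request is only built and inspected here, not sent.

```python
import json
from urllib.request import Request

# Hypothetical endpoint; nothing is sent in this sketch
url = "http://127.0.0.1:1080/predict"

# The feature values travel in the JSON body, not in the URL
payload = {"sepal_length": 5.1, "sepal_width": 3.5,
           "petal_length": 1.4, "petal_width": 0.2}

request = Request(
    url,
    data=json.dumps(payload).encode("utf-8"),  # the body must be bytes
    headers={"Content-Type": "application/json"},
    method="POST",
)

print(request.get_method())   # POST
print(request.full_url)       # no query string: the data is not in the URL
print(request.data.decode())  # the JSON body
```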

It’s important to mention that REST APIs are stateless. This means that the server does not store anything about previous calls, so every request has to carry all the information needed to process it. That’s significant because it allows multiple users and applications to use the API at the same time without one user’s request interfering with another’s.

The Model

For this how-to article, I decided to serve a machine learning model trained on the famous iris dataset. If you don’t know the dataset, you can check it out here. When making predictions, we will have four input parameters: sepal length, sepal width, petal length, and petal width. Based on these, the model decides which species of iris the flower belongs to.

For this example, I used the scikit-learn implementation of a simple k-nearest neighbors (KNN) algorithm to predict the species of iris:

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
import joblib  # sklearn.externals.joblib was removed in scikit-learn 0.23
import numpy as np

def train(X,y):

    # train test split
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)

    knn = KNeighborsClassifier(n_neighbors=1)

    # fit the model
    knn.fit(X_train, y_train)
    preds = knn.predict(X_test)
    acc = accuracy_score(y_test, preds)
    print(f'Successfully trained model with an accuracy of {acc:.2f}')

    return knn

if __name__ == '__main__':

    iris_data = datasets.load_iris()
    X = iris_data['data']
    y = iris_data['target']

    labels = {0 : 'iris-setosa',
              1 : 'iris-versicolor',
              2 : 'iris-virginica'}

    # rename integer labels to actual flower names
    y = np.vectorize(labels.__getitem__)(y)

    mdl = train(X,y)

    # serialize model
    joblib.dump(mdl, 'iris.mdl')

As you can see, I trained the model on 70% of the data and then validated it on the remaining 30% as out-of-sample test data. After training, I serialize the model with the joblib library. Joblib is an alternative to pickle that handles objects containing large NumPy arrays efficiently. This is the case for many scikit-learn estimators: the KNN model, for instance, stores all of its training data. Once saved, the file can be loaded again later in our application. By the way, the file extension does not matter, so don’t be confused if some people use .model or .joblib instead.
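Joblib is not part of the standard library, so here is a dependency-free sketch of the same save-and-restore idea using `pickle`. The hypothetical `fit_1nn` and `predict_1nn` functions are a toy 1-nearest-neighbor classifier, not scikit-learn's `KNeighborsClassifier`. Like KNN, the "fitted model" is nothing more than the stored training data. The three data rows are made up for illustration.

```python
import math
import pickle


def fit_1nn(X, y):
    """Toy 1-NN 'training': the fitted model is just a copy of the data."""
    return {"X": [list(row) for row in X], "y": list(y)}


def predict_1nn(model, rows):
    """Label each row with the label of its closest training row."""
    labels = []
    for row in rows:
        distances = [math.dist(train_row, row) for train_row in model["X"]]
        labels.append(model["y"][distances.index(min(distances))])
    return labels


# Made-up rows: sepal length, sepal width, petal length, petal width
X = [[5.1, 3.5, 1.4, 0.2], [6.0, 2.9, 4.5, 1.5], [6.7, 3.1, 5.6, 2.4]]
y = ["iris-setosa", "iris-versicolor", "iris-virginica"]

model = fit_1nn(X, y)

# Serialize to bytes (pickle.dump(model, f) would write to a file instead) ...
blob = pickle.dumps(model)
# ... and restore an equal object later, e.g. in the API process
restored = pickle.loads(blob)

print(predict_1nn(restored, [[5.0, 3.4, 1.5, 0.2]]))  # ['iris-setosa']
```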

The API with Python and Flask

To build an API from our trained model, we will use the popular web framework Flask together with its extension Flask-RESTful. Further, we import joblib to load our model and NumPy to handle the input and output data.

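In a new script (for example, app.py), we can now set up an instance of a Flask app and an API and load the trained model. Because the model is loaded via a relative path, iris.mdl has to be in the directory from which you run the script; the simplest option is to save both in the same directory: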

from flask import Flask
from flask_restful import Api, Resource, reqparse
import joblib
import numpy as np

APP = Flask(__name__)
API = Api(APP)

IRIS_MODEL = joblib.load('iris.mdl')

The second step now is to create a class, which is responsible for our prediction. This class will be a child class of the Flask-RESTful class Resource. This lets our class inherit the respective class methods and allows Flask to do the work behind your API without needing to implement everything.

In this class, we can also define the methods (REST requests) that we talked about before. So now we implement a Predict class with a .post() method we talked about earlier.

The POST method allows the user to send a body along with the request. Usually, we want this body to be in JSON format. Since the body is not delivered in the URL but as text, we have to parse this text and fetch the arguments. The flask_restful package offers the RequestParser class for that. We simply add all the arguments we expect to find in the JSON input with the .add_argument() method and parse them into a dictionary. We then convert the values into an array, in the same order as the features used for training, and return the model’s prediction as JSON.

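class Predict(Resource):

    def post(self):
        parser = reqparse.RequestParser()
        # declare the expected inputs in the same order as the training features
        parser.add_argument('sepal_length', type=float, required=True)
        parser.add_argument('sepal_width', type=float, required=True)
        parser.add_argument('petal_length', type=float, required=True)
        parser.add_argument('petal_width', type=float, required=True)

        args = parser.parse_args()  # creates dict

        X_new = np.fromiter(args.values(), dtype=float)  # convert input to array

        out = {'Prediction': IRIS_MODEL.predict([X_new])[0]}

        return out, 200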
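The order of the arguments matters. The model expects its four features in the column order of the training data (sepal length, sepal width, petal length, petal width), but the keys in a JSON body can arrive in any order. Here is a framework-free sketch of that parsing step using only the standard-library `json` module. The `parse_features` helper is hypothetical and is not how Flask-RESTful's `reqparse` is implemented; it only illustrates looking values up by name and rejecting unusable input.

```python
import json

# Column order of the iris training data
FEATURES = ("sepal_length", "sepal_width", "petal_length", "petal_width")


def parse_features(raw_body):
    """Turn a JSON request body into a feature list in training order.

    Raises ValueError if the body is not a JSON object or if a feature
    is missing or not numeric.
    """
    data = json.loads(raw_body)  # invalid JSON raises JSONDecodeError, a ValueError
    if not isinstance(data, dict):
        raise ValueError("body must be a JSON object")
    missing = [name for name in FEATURES if name not in data]
    if missing:
        raise ValueError(f"missing features: {missing}")
    try:
        # Look values up by name, so the key order in the body is irrelevant
        return [float(data[name]) for name in FEATURES]
    except (TypeError, ValueError) as exc:
        raise ValueError(f"non-numeric feature value: {exc}") from exc


# Same keys as the request body used later in this post, deliberately out of order
body = '{"petal_length": 2, "sepal_length": 2, "petal_width": 0.5, "sepal_width": 3}'
print(parse_features(body))  # [2.0, 3.0, 2.0, 0.5]
```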

You might be wondering what the 200 is that we return at the end. Every HTTP response carries a status code that tells the client how its request went. You are probably familiar with the famous 404 (Not Found). 200 (OK) means that the request was received and processed successfully: you basically let the user know that everything went according to plan.
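Status codes are standardized, and Python's standard library ships the table as `http.HTTPStatus`. Here is a sketch of how a handler could pick a code. The hypothetical `predict_response` returns `(payload, status)` pairs in the same shape as `return out, 200`. It uses 200 for success and 400 (Bad Request) for unusable input; the 400 branch is an illustrative addition, not part of the API built in this post. `stub_model` stands in for the trained model.

```python
from http import HTTPStatus


def stub_model(features):
    # Stand-in for IRIS_MODEL.predict: always returns the same class
    return "iris-setosa"


def predict_response(features, model):
    """Return a (payload, status) pair, like `return out, 200` in Predict.post."""
    if len(features) != 4:
        # 400 Bad Request: the client sent unusable input
        return {"error": "expected 4 feature values"}, HTTPStatus.BAD_REQUEST
    # 200 OK: everything went according to plan
    return {"Prediction": model(features)}, HTTPStatus.OK


payload, status = predict_response([5.1, 3.5, 1.4, 0.2], stub_model)
print(int(status), status.phrase, payload)  # 200 OK {'Prediction': 'iris-setosa'}
print(HTTPStatus(404).phrase)               # Not Found
print(HTTPStatus.OK == 200)                 # True, HTTPStatus is an IntEnum
```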

In the end, you just have to add the Predict class as a resource to the API, and write the main function:

API.add_resource(Predict, '/predict')

if __name__ == '__main__':
    APP.run(port=1080)

The '/predict' you see in the .add_resource() call is the so-called API endpoint. Through this endpoint, users of your API will be able to send (in this case) POST requests. If you don’t define a port, Flask uses port 5000 by default.

You can see the whole code for the app again here:

from flask import Flask
from flask_restful import Api, Resource, reqparse
import joblib
import numpy as np

APP = Flask(__name__)
API = Api(APP)

IRIS_MODEL = joblib.load('iris.mdl')

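class Predict(Resource):

    def post(self):
        parser = reqparse.RequestParser()
        # declare the expected inputs in the same order as the training features
        parser.add_argument('sepal_length', type=float, required=True)
        parser.add_argument('sepal_width', type=float, required=True)
        parser.add_argument('petal_length', type=float, required=True)
        parser.add_argument('petal_width', type=float, required=True)

        args = parser.parse_args()  # creates dict

        X_new = np.fromiter(args.values(), dtype=float)  # convert input to array

        out = {'Prediction': IRIS_MODEL.predict([X_new])[0]}

        return out, 200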

API.add_resource(Predict, '/predict')

if __name__ == '__main__':
    APP.run(port=1080)

Run the API

Now it’s time to run and test our API!

To run the app, simply open a terminal in the same directory as your script and run the following command (assuming you saved the script as app.py):

python app.py

You should now see a message saying that the API is running on localhost at the port you defined. There are several ways of accessing the API once it is running. For debugging and testing purposes, I usually use tools like Postman. We can also access the API from within a Python application, just like another user might do to use your model in their code.

We use the requests module, by first defining the URL to access and the body to send along with our HTTP request:

import requests

url = 'http://127.0.0.1:1080/predict'  # localhost and the defined port + endpoint
body = {
    "petal_length": 2,
    "sepal_length": 2,
    "petal_width": 0.5,
    "sepal_width": 3
}
response = requests.post(url, json=body)  # send the body as JSON
response.json()

The output should look something like this (the predicted species depends on your input values and on the random train/test split):

Out[1]: {'Prediction': 'iris-versicolor'}

That’s how easy it is to include an API call in your Python code! Please note that this API is just running on your localhost. You would have to deploy the API to a live server (e.g., on AWS) for others to access it.
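If you want to see the whole request/response cycle without installing Flask, the sketch below reproduces it with the standard library only. `http.server` serves a hypothetical `/predict` endpoint on a free port in a background thread, and `urllib.request` plays the client. The "model" is a hard-coded stub rule, not the trained KNN. Note that `http.server` is meant for experiments like this one, not for production.

```python
import json
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer
from urllib.request import Request, urlopen


def stub_predict(features):
    # Placeholder rule instead of the trained model: short petals -> setosa
    return "iris-setosa" if features["petal_length"] < 2.5 else "iris-other"


class PredictHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        if self.path != "/predict":
            self.send_error(404)
            return
        length = int(self.headers.get("Content-Length", 0))
        features = json.loads(self.rfile.read(length))
        body = json.dumps({"Prediction": stub_predict(features)}).encode()
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, *args):
        pass  # keep the console quiet


# Port 0 lets the operating system pick a free port
server = HTTPServer(("127.0.0.1", 0), PredictHandler)
threading.Thread(target=server.serve_forever, daemon=True).start()

url = f"http://127.0.0.1:{server.server_address[1]}/predict"
payload = {"sepal_length": 5.1, "sepal_width": 3.5,
           "petal_length": 1.4, "petal_width": 0.2}
request = Request(url, data=json.dumps(payload).encode(),
                  headers={"Content-Type": "application/json"}, method="POST")
with urlopen(request) as response:
    status = response.status
    result = json.loads(response.read())

server.shutdown()
server.server_close()
print(status, result)  # 200 {'Prediction': 'iris-setosa'}
```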


In this blog article, you got a brief overview of how to build a REST API to serve your machine learning model with a web interface. Further, you now understand how to integrate simple API requests into your Python code. For the next step, maybe try securing your APIs? If you are interested in learning how to build an API with R, you should check out this post. I hope that this gave you a solid introduction to the concept and that you will be building your own APIs immediately. Happy coding!


Jannik Klauke


When working on data science projects in R, exporting internal R objects as files on your hard drive is often necessary to facilitate collaboration. Here at STATWORX, we regularly export R objects (such as outputs of a machine learning model) as .RDS files and put them on our internal file server. Our co-workers can then pick them up for further usage down the line of the data science workflow (such as visualizing them in a dashboard together with inputs from other colleagues).

Over the last couple of months, I have worked a lot with RDS files and noticed a crucial shortcoming: the base R saveRDS function offers no way to archive an existing file of the same name on your hard drive. In this blog post, I will explain why such archiving can be very useful by first introducing the basics of serialization and then showcasing my proposed solution: a wrapper function around the existing base R serialization framework.

Be wary of silent file replacements!

In base R, you can easily export any object from the environment to an RDS file with:

saveRDS(object = my_object, file = "path/to/dir/my_object.RDS")

However, including such a line somewhere in your script can carry unintended consequences: When calling saveRDS multiple times with identical file names, R silently overwrites existing, identically named .RDS files in the specified directory. If the object you are exporting is not what you expect it to be — for example due to some bug in newly edited code — your working copy of the RDS file is simply overwritten in-place. Needless to say, this can prove undesirable.

If you are familiar with this pitfall, you have probably tried to forestall such troublesome side effects by commenting out the respective lines, carefully checking each time whether the R object looked fine, and then executing the line manually. But even when there is nothing wrong with the R object you seek to export, it can make sense to retain an archived copy of previous RDS files. Think of a dataset you run through a data prep script, and then you get an update of the raw data, or you decide to change something in the data prep (like removing a variable). You may wish to archive an existing copy in such cases, especially with complex data prep pipelines with long execution times.

Don’t get tangled up in manual renaming

You could manually move or rename the existing file each time you plan to create a new one, but that’s tedious, error-prone, and does not allow for unattended execution and scalability. For this reason, I set out to write a carefully designed wrapper function around the existing saveRDS call, which is pretty straightforward: As a first step, it checks if the file you attempt to save already exists in the specified location. If it does, the existing file is renamed/archived (with customizable options), and the “updated” file will be saved under the originally specified name.

This approach has the crucial advantage that existing code that depends on the file name remaining identical (such as readRDS calls in other scripts) will continue to work with the latest version without any need for adjustment! No more saving your objects as “models_2020-07-12.RDS”, then combing through the other scripts to replace the file name, only to repeat this process the next day. At the same time, an archived copy of the otherwise overwritten file will be kept.

What are RDS files anyways?

Before I walk you through my proposed solution, let’s first examine the basics of serialization, the underlying process behind high-level functions like saveRDS.

Simply speaking, serialization is the “process of converting an object into a stream of bytes so that it can be transferred over a network or stored in a persistent storage.” Stack Overflow: What is serialization?

There is also a low-level R interface, serialize, which you can use to explore (un-)serialization first-hand: Simply fire up R and run something like serialize(object = c(1, 2, 3), connection = NULL). This call serializes the specified vector and prints the output right to the console. The result is an odd-looking raw vector, with each byte separately represented as a pair of hex digits. Now let’s see what happens if we revert this process:

s <- serialize(object = c(1, 2, 3), connection = NULL)
print(s)
# >  [1] 58 0a 00 00 00 03 00 03 06 00 00 03 05 00 00 00 00 05 55 54 46 2d 38 00 00 00 0e 00
# > [29] 00 00 03 3f f0 00 00 00 00 00 00 40 00 00 00 00 00 00 00 40 08 00 00 00 00 00 00

unserialize(connection = s)
# > [1] 1 2 3

The length of this raw vector increases rapidly with the complexity of the stored information: For instance, serializing the famous, although not too large, iris dataset results in a raw vector consisting of 5959 pairs of hex digits!
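The same round trip exists in Python's standard library: `pickle.dumps` and `pickle.loads` play the roles of `serialize` and `unserialize`. Here is a minimal sketch for comparison. The byte layout is pickle's own format, not R's, so only the idea carries over.

```python
import pickle

# Serialize a small list into a raw byte string ...
s = pickle.dumps([1.0, 2.0, 3.0])
print(s.hex(" "))  # each byte shown as a pair of hex digits

# ... and revert the process
print(pickle.loads(s))  # [1.0, 2.0, 3.0]

# More complex objects produce longer byte strings
print(len(s), len(pickle.dumps(list(range(1000)))))
```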

Besides the already mentioned saveRDS function, there is also the more generic save function. The former saves a single R object to a file. It allows us to restore the object from that file (with the counterpart readRDS), possibly under a different variable name: That is, you can assign the contents of a call to readRDS to another variable. By contrast, save allows for saving multiple R objects, but when reading back in (with load), they are simply restored in the environment under the object names they were saved with. (That’s also what happens automatically when you answer “Yes” to the notorious question of whether to “save the workspace image to ~/.RData” when quitting RStudio.)

Creating the archives

Obviously, it’s great to have the possibility to save internal R objects to a file and then be able to re-import them in a clean session or on a different machine. This is especially true for the results of long and computationally heavy operations such as fitting machine learning models. But as we learned earlier, one wrong keystroke can potentially erase that one precious 3-hour-fit fine-tuned XGBoost model you ran and carefully saved to an RDS file yesterday.

Digging into the wrapper

So, how did I go about fixing this? Let’s take a look at the code. First, I define the arguments and their defaults: The object and file arguments are taken directly from the wrapped function, the remaining arguments allow the user to customize the archiving process: Append the archive file name with either the date the original file was archived or last modified, add an additional timestamp (not just the calendar date), or save the file to a dedicated archive directory. For more details, please check the documentation here. I also include the ellipsis ... for additional arguments to be passed down to saveRDS. Additionally, I do some basic input handling (not included here).

save_rds_archive <- function(object,
                             file = "",
                             archive = TRUE,
                             last_modified = FALSE,
                             with_time = FALSE,
                             archive_dir_path = NULL,
                             ...) {

The main body of the function is basically a series of if/else statements. I first check if the archive argument (which controls whether the file should be archived in the first place) is set to TRUE, and then if the file we are trying to save already exists (note that “file” here actually refers to the whole file path). If it does, I call the internal helper function create_archived_file, which eliminates redundancy and allows for concise code.

if (archive) {

    # check if file exists
    if (file.exists(file)) {

      archived_file <- create_archived_file(file = file,
                                            last_modified = last_modified,
                                            with_time = with_time)

Composing the new file name

In this function, I create the new name for the file which is to be archived, depending on user input: If last_modified is set, then the mtime of the file is accessed. Otherwise, the current system date/time (= the date of archiving) is taken instead. Then the spaces and special characters are replaced with underscores, and, depending on the value of the with_time argument, the actual time information (not just the calendar date) is kept or not.

To make it easy to tell directly from the file name what the indicated date/time refers to (date of archiving vs. date of modification), I also add appropriate information to the file name. Then I save the file extension for easier replacement (note that “.RDS”, “.Rds”, and “.rds” are all valid file extensions for RDS files). Lastly, I replace the current file extension with a concatenated string containing the type info, the new date/time suffix, and the original file extension. Note that I add a “$” sign to the regex matched by gsub so that it only matches the end of the string: if I did not, and the file name were something like “my_RDS.RDS”, then both matches would be replaced.

# create_archived_file.R

create_archived_file <- function(file, last_modified, with_time) {

  # create main suffix depending on type: date of last modification
  # or current system time (= date of archiving)
  suffix_main <- if (last_modified) {
    format(file.info(file)$mtime, "%Y-%m-%d %H:%M:%S")
  } else {
    format(Sys.time(), "%Y-%m-%d %H:%M:%S")
  }

  if (with_time) {

    # create clean date-time suffix
    suffix <- gsub(pattern = " ", replacement = "_", x = suffix_main)
    suffix <- gsub(pattern = ":", replacement = "-", x = suffix)

    # add "at" between date and time
    suffix <- paste0(substr(suffix, 1, 10), "_at_", substr(suffix, 12, 19))

  } else {

    # create date suffix
    suffix <- substr(suffix_main, 1, 10)

  }

  # create info to paste depending on type
  # (wording of the "modified" label assumed here)
  type_info <- if (last_modified) "_MODIFIED_on_" else "_ARCHIVED_on_"

  # get file extension (could be any of "RDS", "Rds", "rds", etc.)
  ext <- paste0(".", tools::file_ext(file))

  # replace extension with suffix
  archived_file <- gsub(pattern = paste0(ext, "$"),
                        replacement = paste0(type_info, suffix, ext),
                        x = file)

  archived_file
}



Archiving the archives?

By way of example, with last_modified = FALSE and with_time = TRUE, this function would turn the character file name “models.RDS” into “models_ARCHIVED_on_2020-07-12_at_11-31-43.RDS”. However, this is just a character vector for now — the file itself is not renamed yet. For this, we need to call the base R file.rename function, which provides a direct interface to your machine’s file system. I first check, however, whether a file with the same name as the newly created archived file string already exists: This could well be the case if one appends only the date (with_time = FALSE) and calls this function several times per day (or potentially on the same file if last_modified = TRUE).

Somehow, we are back to the old problem in this case. However, I decided that it was not a good idea to archive files that are themselves archived versions of another file since this would lead to too much confusion (and potentially too much disk space being occupied). Therefore, only the most recent archived version will be kept. (Note that if you still want to keep multiple archived versions of a single file, you can set with_time = TRUE. This will append a timestamp to the archived file name up to the second, virtually eliminating the possibility of duplicated file names.) A warning is issued, and then the already existing archived file will be overwritten with the current archived version.
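To check the naming scheme in isolation, here is a Python transcription of the logic described for create_archived_file. The hypothetical `archived_name` function is not part of helfRlein. The timestamp is passed in explicitly so the output is deterministic, and `re.escape` plus the `$` anchor ensure that only the trailing extension is replaced. The `_MODIFIED_on_` label for the last-modified case is an assumption; only `_ARCHIVED_on_` is confirmed by the example file name given earlier.

```python
import os
import re
from datetime import datetime


def archived_name(path, stamp, last_modified=False, with_time=False):
    """Return the archive file name for `path`, given a datetime `stamp`."""
    # Date only, or date plus time with "-" instead of ":" (file-name safe)
    if with_time:
        suffix = stamp.strftime("%Y-%m-%d_at_%H-%M-%S")
    else:
        suffix = stamp.strftime("%Y-%m-%d")
    # "_MODIFIED_on_" wording is assumed
    type_info = "_MODIFIED_on_" if last_modified else "_ARCHIVED_on_"
    # Keep the original extension (.RDS, .Rds, .rds, ...)
    ext = os.path.splitext(path)[1]
    # Replace only the extension at the very end of the name ("$" anchor)
    return re.sub(re.escape(ext) + "$", type_info + suffix + ext, path)


stamp = datetime(2020, 7, 12, 11, 31, 43)
print(archived_name("models.RDS", stamp, with_time=True))
# models_ARCHIVED_on_2020-07-12_at_11-31-43.RDS
print(archived_name("my_RDS.RDS", stamp))
# my_RDS_ARCHIVED_on_2020-07-12.RDS
```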

The last puzzle piece: Renaming the original file

To do this, I call the file.rename function, renaming the “file” originally passed by the user call to the string returned by the helper function. The file.rename function always returns a boolean indicating if the operation succeeded, which I save to a variable temp to inspect later. Under some circumstances, the renaming process may fail, for instance due to missing permissions or OS-specific restrictions. We did set up a CI pipeline with GitHub Actions and continuously test our code on Windows, Linux, and MacOS machines with different versions of R. So far, we didn’t run into any problems. Still, it’s better to provide in-built checks.

It’s an error! Or is it?

The problem here is that, when renaming the file on disk failed, file.rename raises merely a warning, not an error. Since any causes of these warnings most likely originate from the local file system, there is no sense in continuing the function if the renaming failed. That’s why I wrapped it into a tryCatch call that captures the warning message and passes it to the stop call, which then terminates the function with the appropriate message.

Just to be on the safe side, I check the value of the temp variable, which should be TRUE if the renaming succeeded, and also check if the archived version of the file (that is, the result of our renaming operation) exists. If both of these conditions hold, I simply call saveRDS with the original specifications (now that our existing copy has been renamed, nothing will be overwritten if we save the new file with the original name), passing along further arguments with ....

        if (file.exists(archived_file)) {
          warning("Archived copy already exists - will overwrite!")
        }

        # rename existing file with the new name
        # save return value of the file.rename function
        # (returns TRUE if successful) and wrap in tryCatch
        temp <- tryCatch({file.rename(from = file,
                                      to = archived_file)},
                         warning = function(e) {
                           # turn the warning into an error that stops the function
                           stop(conditionMessage(e))
                         })

      # check return value and if archived file exists
      if (temp && file.exists(archived_file)) {
        # then save new file under specified name
        saveRDS(object = object, file = file, ...)
      }

These code snippets represent the cornerstones of my function. I also skipped some portions of the source code for reasons of brevity, chiefly the creation of the “archive directory” (if one is specified) and the process of copying the archived file into it. Please refer to our GitHub for the complete source code of the main and the helper function.

Finally, to illustrate, let’s see what this looks like in action:

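x <- 5
y <- 10
z <- 20

## save to RDS
saveRDS(x, "temp.RDS")
saveRDS(y, "temp.RDS")

## "temp.RDS" is silently overwritten with y
## previous version is lost
readRDS("temp.RDS")
#> [1] 10

save_rds_archive(z, "temp.RDS")
## current version is updated
readRDS("temp.RDS")
#> [1] 20

## previous version is archived
## (with the default settings, the archive name carries today's date)
readRDS(paste0("temp_ARCHIVED_on_", Sys.Date(), ".RDS"))
#> [1] 10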
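For comparison, here is the same archive-then-save behavior as a compact Python sketch. It uses `pickle` in place of `saveRDS`, and the hypothetical `save_pickle_archive` and `load_pickle` are not helfRlein functions. It mirrors only the defaults (archiving date, date without time). One difference from R is worth noting: `os.replace` raises an `OSError` on failure instead of returning `FALSE` with a warning. There is therefore nothing to escalate, and the exception stops the function before the new object is written.

```python
import os
import pickle
import tempfile
import warnings
from datetime import datetime


def save_pickle_archive(obj, path, now=None):
    """Pickle obj to path; move an existing file aside to a dated copy first."""
    if os.path.exists(path):
        root, ext = os.path.splitext(path)
        stamp = (now or datetime.now()).strftime("%Y-%m-%d")
        archived = f"{root}_ARCHIVED_on_{stamp}{ext}"
        if os.path.exists(archived):
            warnings.warn("Archived copy already exists - will overwrite!")
        # Raises OSError (e.g. missing permissions) instead of returning False,
        # so the new object is never written if archiving failed
        os.replace(path, archived)
    with open(path, "wb") as fh:
        pickle.dump(obj, fh)


def load_pickle(path):
    with open(path, "rb") as fh:
        return pickle.load(fh)


day = datetime(2020, 7, 12)  # fixed date so the archive name is predictable
with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "temp.pkl")
    save_pickle_archive(10, target, now=day)
    save_pickle_archive(20, target, now=day)
    current = load_pickle(target)
    previous = load_pickle(os.path.join(tmp, "temp_ARCHIVED_on_2020-07-12.pkl"))

print(current, previous)  # 20 10
```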

Great, how can I get this?

The function save_rds_archive is now included in the newly refactored helfRlein package (now available in version 1.0.0!) which you can install directly from GitHub:

# install.packages("devtools")
devtools::install_github("STATWORX/helfRlein")

Feel free to check out additional documentation and the source code there. If you have any inputs or feedback on how the function could be improved, please do not hesitate to contact me or raise an issue on our GitHub.


That’s it! No more manually renaming your precious RDS files — with this function in place, you can automate this tedious task and easily keep a comprehensive archive of previous versions. You will be able to take another look at that one model you ran last week (and then discarded again) in the blink of an eye. I hope you enjoyed reading my post — maybe the function will come in handy for you someday!

Lukas Feick

At STATWORX, coding is our bread and butter. Because our projects involve many different people in several organizations across multiple generations of programmers, writing clean code is essential. Comments and sections are the key ingredients of well-structured, readable code. In RStudio, sections are defined by comments that end with at least four dashes ---- (you can also use trailing equal signs ==== or hashes ####). In my opinion, the code is even clearer if the dashes fill the line up to 80 characters (there are good reasons not to exceed the 80-character limit). This is what my code usually looks like:

# loading packages -------------------------------------------------------------
library(dplyr)

# load data --------------------------------------------------------------------
my_iris <- as_tibble(iris)

# prepare data -----------------------------------------------------------------
my_iris_preped <- my_iris %>% 
  filter(Species == "virginica") %>% 
  mutate_if(is.numeric, list(squared = ~ .^2))

# ...

Clean, huh? Well, yes, but none of the three options available to achieve this is as neat as I would like:

  • Press - for some time.
  • Copy a certain number of dashes and paste them repeatedly. Both of these options often result in too many dashes, so I have to remove the redundant ones.
  • Use the shortcut to insert a new section (Cmd/Ctrl + Shift + R). However, it cannot be applied neatly to a comment you have already written.

Wouldn’t it be nice to have a keyboard shortcut that inserts the right number of dashes from the cursor position onward? “Easy as can be,” I thought before trying to define a custom shortcut in RStudio.

Unfortunately, it turned out not to be that easy. RStudio’s documentation does cover how to create your own shortcut, but it requires you to put it in a package first. Since I was not an expert in R package development myself, I decided to go the full distance in this blog post. By following it step by step, you should be able to define your own shortcuts within a few minutes.

Note: This article is not about creating a CRAN-worthy package, but covers what is necessary to define your own shortcuts. If you have already created packages before, you can skip the parts about package development and jump directly to what is new to you.

Setting up an R package

First of all, open RStudio and create an R package directory. For this, please do the following steps:

  1. Go to “New Project…”
  2. “New Directory”
  3. “R Package”
  4. Select an awesome package name of your choice. In this example, I named my package shoRtcut
  5. In “Create project as subdirectory of:” select a directory of your choice. A new folder with your package name will be created in this directory.

Tada, everything necessary for a powerful R package has been set up. RStudio also automatically provides a dummy function hello(). Since we do not want this function in our own package, go to the “R” folder in your project and delete the hello.R file. Do the same in the “man” folder and delete hello.Rd.

Creating an Addin Function

Now we can start and define our function. For this, we need the wonderful packages usethis and devtools. These provide all the functionality we need for the next steps.

Defining the Addin Function

Via the use_r() function, we define a new R script file with the given name. That should correspond to the name of the function we are about to create. In my case, I call it set_new_chapter.

# use this function to automatically create a new r script for your function
usethis::use_r("set_new_chapter")

You are directly forwarded to the created file. Now the tricky part begins, defining a function that does what you want. When defining shortcuts that interact with an R script in RStudio, you will soon discover the package rstudioapi. With its functions, you can grab all information from RStudio and make it available within R. Let me guide you through it step by step.

  1. As per usual, I set up a regular R function named set_new_chapter. Next, I define the limit up to which I want to include the dashes. Note that I set nchars to 81 rather than 80, because the number corresponds to the cursor position after the dashes have been inserted. When you type, the cursor automatically jumps to the position right after the newly typed character, so after your 80th character, the cursor is at position 81.
  2. Now we have to find out where the cursor is currently located. This information can be obtained with the getActiveDocumentContext() function. The returned object contains quite a bit of information, but we are only interested in the column of the cursor position. Why the column? You can think of the script as a matrix: hitting return brings you to a new row, and typing a character moves you to a new column. A monospaced font, which is RStudio’s default, makes this concept easy to see.
  3. By sneaking into the nested list, we find the information we are looking for and store it in context_col. Now we check whether the cursor is already at “column” 81. If not, there is space in which we insert the dashes. For this final step, we can use another function: insertText.
  4. As its name implies, it inserts text in an R script or console. You can either specify a specific position in the document or, by leaving it empty, insert text at the current cursor position, which is exactly what I want right now. As the final step, I need to find out the number of dashes that should be inserted. That’s the difference between the current cursor location and its target position. For example, if the cursor blinks at column 51, meaning I already have typed 50 characters, I want to insert 30 dashes.
  5. To document the function, I use the “Code” > “Insert Roxygen Skeleton” feature and fill it out appropriately.

This is what my final function looks like.

#' Insert dashes from cursor position up to 80 characters
#' @return dashes inside RStudio
set_new_chapter <- function(){
  # set limit to which position dashes should be included
  nchars <- 81

  # grab current document information
  context <- rstudioapi::getActiveDocumentContext()
  # extract horizontal cursor position in document
  context_col <- context$selection[[1]]$range$end["column"]

  # if the cursor is before column 81, insert hyphens at the current line
  # up to 80 characters
  if (nchars > context_col) {
    rstudioapi::insertText(strrep("-", nchars - context_col))
  }
}
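The column arithmetic behind set_new_chapter is easy to check outside RStudio. Below is a Python transcription for illustration. The hypothetical `dashes_to_insert` stands in for the R logic, and columns are 1-based as in RStudio. Clamping with `max(..., 0)` plays the role of the `if (nchars > context_col)` check.

```python
def dashes_to_insert(cursor_col, nchars=81):
    """Number of dashes needed so the line ends at column nchars - 1 (= 80).

    cursor_col is 1-based: after typing 50 characters the cursor sits at
    column 51, and after the 80th character at column 81.
    """
    return max(nchars - cursor_col, 0)


line = "# loading packages "
n = dashes_to_insert(len(line) + 1)  # the cursor sits right after the text
header = line + "-" * n

print(dashes_to_insert(51))  # 30
print(len(header))           # 80
```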

Defining the Function as an Addin

Now we must somehow tell RStudio that this particular function should be used as an addin rather than a regular function. For this, go to “File” > “New File” > “Text File” and include the following text:

Name: Insert Dashes (---)
Description: Inserts `---` at the cursor position up to 80 characters.
Binding: set_new_chapter
Interactive: false
  • Name is a short description of what the addin does. This will be displayed when you want to set the shortcut later.
  • Description is a longer description of its functionality.
  • Binding sets the name of the function that should be called by the shortcut.
  • Interactive defines whether this addin is interactive (e.g., runs a Shiny application) or not.

Now save this file as “addins.dcf” in your project under the path “inst” > “rstudio”, i.e., as inst/rstudio/addins.dcf. The result should look like this:

Finalize the Package

To wrap everything up and make the shortcut available to you and your colleagues, we only have to call a few more functions. Not all these steps are necessary, yet it is good practice to create a proper package.

# OPTIONAL: define the license of your package
usethis::use_mit_license(name = "Matthias Nistler")

# define dependencies you use in your package
usethis::use_package("rstudioapi")

# OPTIONAL: include your function documentation in the manual
devtools::document()

# check for errors
devtools::check()

# update/create your package
devtools::build()

> ✓  checking for file ‘/Users/matthiasnistler/Projekte/2020/blog_shoRtcut/DESCRIPTION’ ...
> ─  preparing ‘shoRtcut’:
> ✓  checking DESCRIPTION meta-information ...
> ─  checking for LF line-endings in source and make files and shell scripts
> ─  checking for empty or unneeded directories
> ─  building ‘shoRtcut_0.0.0.9000.tar.gz’
> [1] "/Users/matthiasnistler/Projekte/2020/shoRtcut_0.0.0.9000.tar.gz"

There you go! You just created an awesome package that you can now distribute to your friends and colleagues.

Make the shortcut available

For the last step, you have to install your package and set a keyboard combination for your shortcut. For this, use the following specification of install.packages:

install.packages(
  "/Users/matthiasnistler/Projekte/2020/shoRtcut_0.0.0.9000.tar.gz", # same path as above
  # indicate it is a local file
  repos = NULL)

# check if everything works

Now go to “Tools” > “Modify Keyboard Shortcuts…” and search for “dashes”. Here you can define the keyboard combination by clicking inside the empty “Shortcut” field and pressing the desired key combination on your keyboard. Click “Apply”, and that’s it!

In case you are just here to use my shortcut, you can install it via remotes::install_github("mnist91/shoRtcut").


You made it! Now you can use your own RStudio shortcut. Exciting, isn’t it?

But that’s not all there is – next week, I will give you an introduction to the wonderful world of R package naming. So stay tuned and happy coding!

Matthias Nistler

Data operations is an increasingly important part of data science because it enables companies to feed large business data back into production effectively. We at STATWORX, therefore, operationalize our models and algorithms by translating them into Application Programming Interfaces (APIs). Representational State Transfer (REST) APIs are well suited to be implemented as part of a modern micro-services infrastructure. They are flexible, easy to deploy, scale, and maintain, and they can be accessed by multiple clients and client types at the same time. Their primary purpose is to simplify programming by abstracting the underlying implementation and only exposing the objects and actions needed for further development and interaction.

An additional advantage of APIs is that they make it easy to combine code written in different programming languages or by different development teams. This is because APIs are naturally separated from each other, and communication with and between APIs is handled over HTTP via IP addresses or URLs, typically using JSON or XML format. Imagine, for example, an infrastructure where an API written in Python and one written in R communicate with each other and serve an application written in JavaScript.

In this blog post, I will show you how to translate a simple R script, which transforms tables from wide to long format, into a REST API with the R package Plumber and how to run it locally or with Docker. I have created this example API for our trainee program, and it serves our new data scientists and engineers as a starting point to familiarize themselves with the subject.

Translate the R Script

Transforming an R script into a REST API is quite easy. All you need, in addition to R and RStudio, is the package Plumber and, optionally, Docker. Clients interact with a REST API by sending REST requests; the most commonly used request methods are probably GET, PUT, POST, and DELETE. Here is the code of the example API, which transforms tables from wide to long or from long to wide format:

## transform wide to long and long to wide format
#' @post /widelong
#' @get /widelong
function(req) {
  # libraries
  library(tidyr)
  library(dplyr)
  library(magrittr)

  # post body
  body <- jsonlite::fromJSON(req$postBody)

  .data <- body$.data
  .trans <- body$.trans
  .key <- body$.key
  .value <- body$.value
  .select <- body$.select

  # wide or long transformation
  if(.trans == 'l' || .trans == 'long') {
    .data %<>% gather(key = !!.key, value = !!.value, !!.select)
  } else if(.trans == 'w' || .trans == 'wide') {
    .data %<>% spread(key = !!.key, value = !!.value)
  } else {
    print('Please specify the transformation')
  }

  # return the (transformed) table; plumber serializes it to JSON
  .data
}

As you can see, it is a standard R function that is extended by the special Plumber comments @post and @get, which enable the API to respond to those types of requests. Every incoming request must include the path /widelong. This is because several API functions, each responding to a different path, can be stacked in one API. We could, e.g., add another function with the path /naremove to our API, which removes NAs from tables.

The R function itself has one function argument, req, which is used to receive a (POST) request body. In general, there are two ways to send additional arguments and objects to a REST API: the header and the body. I decided to use only a body and no header at all, which makes the API cleaner and safer and allows us to send larger objects. A header could, e.g., be used to set some optional function arguments, but should otherwise be used sparingly.

Using a body is also the reason why the API accepts GET and POST requests (@get, @post) at the same time. While some clients prefer to send a body with a GET request when they do not permanently post anything to the server, many other clients cannot send a body with a GET request at all. For them, a POST endpoint is mandatory. Typical clients are applications, integrated development environments (IDEs), and other APIs. By accepting both request types, our API can therefore serve a wider range of clients.

For the request-response format of the API, I decided to stick with JavaScript Object Notation (JSON), which is probably the most common format. It would also be possible to use Extensible Markup Language (XML) with Plumber instead. The decision for one or the other will most likely depend on which additional R packages you want to use or on which format the API’s clients predominantly use. The R packages used to handle REST requests in my example API are jsonlite and httr. The three tidyverse packages (tidyr, dplyr, and magrittr) are used to do the table transformation to wide or long.
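To make the two transformations concrete without R, here is a minimal pure-Python sketch of what gather() and spread() do, assuming a table is represented as a list of row dictionaries. The helper names to_long and to_wide and the toy data are hypothetical and not part of the example API; to_wide keeps rows in order of first appearance, which may differ from tidyr's ordering.

```python
from typing import Any, Dict, List, Tuple

Row = Dict[str, Any]


def to_long(rows: List[Row], key: str, value: str, select: List[str]) -> List[Row]:
    """Wide -> long, similar in spirit to tidyr::gather()."""
    long_rows: List[Row] = []
    for col in select:  # stack the selected columns one after another
        for row in rows:
            # keep all non-selected (identifier) columns unchanged
            new_row = {k: v for k, v in row.items() if k not in select}
            new_row[key] = col          # former column name goes into the key column
            new_row[value] = row[col]   # cell content goes into the value column
            long_rows.append(new_row)
    return long_rows


def to_wide(rows: List[Row], key: str, value: str) -> List[Row]:
    """Long -> wide, similar in spirit to tidyr::spread()."""
    groups: Dict[Tuple[Tuple[str, Any], ...], Row] = {}
    for row in rows:
        # all columns except key and value identify one output row
        ident = tuple((k, v) for k, v in row.items() if k not in (key, value))
        target = groups.setdefault(ident, dict(ident))
        target[row[key]] = row[value]  # each key becomes its own column
    return list(groups.values())


# hypothetical toy data in the spirit of the tidyr stocks example
stocks = [
    {"time": 1, "X": 1.0, "Y": 2.0, "Z": 3.0},
    {"time": 2, "X": 1.5, "Y": 2.5, "Z": 3.5},
]
long_stocks = to_long(stocks, key="stock", value="price", select=["X", "Y", "Z"])
print(long_stocks[0])  # {'time': 1, 'stock': 'X', 'price': 1.0}
print(to_wide(long_stocks, key="stock", value="price") == stocks)  # True
```

Going wide to long and back again returns the original table, which is the round trip the API offers via .trans = "l" and .trans = "w".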


The finished REST API can be run locally with R or RStudio as follows:


widelong_api <- plumber::plumb("./path/to/directory/widelongwide.R")
widelong_api$run(host = '', port = 8000)

Upon starting the API, Plumber provides us with an IP address and a port, and a client, e.g., another R instance, can then begin to send REST requests. It also opens a browser tool called Swagger, which can be useful to check whether your API works as intended. Once the development of an API is finished, I would suggest building a Docker image and running it in a container. That makes the API highly portable and independent of its host system. This is especially important because we want to use most APIs in production and deploy them to, e.g., a company server or the cloud. Here is the Dockerfile to build the Docker image of the example API:

FROM trestletech/plumber

# Install dependencies
RUN apt-get update --allow-releaseinfo-change && apt-get install -y 

# Install R packages
RUN R -e "install.packages(c('tidyr', 'dplyr', 'magrittr', 'httr', 'jsonlite'), \
repos = '')"

# Add API
COPY ./path/to/directory/widelongwide.R /widelongwide.R

# Make port available
EXPOSE 8000

# Entrypoint
ENTRYPOINT ["R", "-e", \
"widelong <- plumber::plumb('/widelongwide.R'); \
widelong$run(host = '', port= 8000)"]

CMD ["/widelongwide.R"]

Send a REST Request

The wide-long example API can respond to any client sending a POST or GET request with a JSON body that contains a table (e.g., read from a CSV file) and all the information needed to transform it. Here is an example of a web application, which I have written for our trainee program to supplement the wide-long API:

The application is written in R Shiny, which is a great R package to transform your static plots and outputs into an interactive dashboard. If you are interested in how to create dashboards in R, check out other posts on our STATWORX Blog.

Last but not least, here is an example of how to send a REST request from R or RStudio:

library(httr)
library(jsonlite)

options(stringsAsFactors = FALSE)

# url for local testing
url <- ""

# url for docker container
url <- ""

# read example stock data
.data <- read.csv('./path/to/data/stocks.csv')

# create example body
body <- list(
  .data = .data,
  .trans = "w",
  .key = "stock",
  .value = "price",
  .select = c("X","Y","Z")
)

# set API path
path <- 'widelong'

# send POST Request to API
raw.result <- POST(url = url, path = path, body = body, encode = 'json')

# check status code
raw.result$status_code

# retrieve transformed example stock data
.t_data <- fromJSON(rawToChar(raw.result$content))

As you can see, it is quite easy to make REST Requests in R. If you need some test data, you could use the stocks data example from the Tidyverse.
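Since the API speaks plain HTTP and JSON, a client does not have to be written in R. Here is a minimal sketch using only Python's standard library that builds the same kind of POST request as the R example; the base URL http://localhost:8000 is an assumption (the post leaves the URL open), and the helper build_widelong_request is hypothetical. The table is sent as a list of row objects, which is how httr encodes a data frame with encode = 'json'.

```python
import json
import urllib.request
from typing import Any, Dict, List


def build_widelong_request(base_url: str, data: List[Dict[str, Any]], trans: str,
                           key: str, value: str, select: List[str]) -> urllib.request.Request:
    """Build (but do not send) a POST request for the /widelong endpoint."""
    body = {
        ".data": data,      # table as a list of row objects
        ".trans": trans,    # "w"/"wide" or "l"/"long"
        ".key": key,
        ".value": value,
        ".select": select,
    }
    return urllib.request.Request(
        url=base_url.rstrip("/") + "/widelong",
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


# hypothetical long-format rows
long_rows = [
    {"time": 1, "stock": "X", "price": 1.0},
    {"time": 1, "stock": "Y", "price": 2.0},
]
req = build_widelong_request("http://localhost:8000", long_rows, "w",
                             "stock", "price", ["X", "Y", "Z"])

# Sending the request only works while the API is running:
# with urllib.request.urlopen(req) as resp:
#     print(resp.status, json.loads(resp.read()))
```

Building the request separately from sending it keeps the payload easy to inspect before the API is even started.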


In this blog post, I showed you how to translate a simple R script, which transforms tables from wide to long format, into a REST API with the R package Plumber and how to run it locally or with Docker. I hope you enjoyed the read and learned something about operationalizing R scripts. You are welcome to copy and use any code from this blog post to start creating your own REST APIs with R.

Until then, stay tuned and visit our STATWORX Blog again soon.

We’re hiring!

Data Engineering is your jam and you’re looking for a job? We’re currently looking for Junior Consultants and Consultants in Data Engineering. Check the requirements and benefits of working with us on our career site. We’re looking forward to your application!

Stephan Emmer

Nearly one year ago, I analyzed how we use emojis in our Slack messages. Since then, STATWORX has grown, and there are a lot more of us now! So, I wanted to check whether anything has changed.

Last time, I did not show our custom emojis, since they are, of course, not available in the fonts I used. This time, I will incorporate them with geom_image(). It is part of the ggimage package by Guangchuang Yu, which you can find here on his GitHub. With geom_image(), you can include images such as .png files in your ggplot.

What changed since last year?

Let’s first have a look at the number of emojis we are using. In the plot below, you can see that since my last analysis in October 2018 (red line), the number of emojis has been rising. Not as much as I thought it would, but compared to the previous period, we now have more days with over 100 emojis per day!

Like last time, our top emoji is 👍, followed by 😂 and 😄. But sneaking in at number ten is one of our custom emojis: party_hat_parrot!


How to include custom images?

In my previous blog post, I hid all our custom emojis behind ❓ since they were not part of the font. It did not occur to me to use their images, even though the package is from the same creator! So, to make up for my ignorance, I grabbed the top 30 custom emojis, downloaded their images from our Slack servers, saved them as .png files, and made sure they are all roughly the same size.

To use geom_image(), I just added the path of each image to my data (the “...” is just an abbreviation for the complete path).

1:          alnatura    25       63 .../custom/alnatura.png
2:              blog    19       20 .../custom/blog.png
3:           dataiku    15       22 .../custom/dataiku.png
4: dealwithit_parrot     3      100 .../custom/dealwithit_parrot.png
5:      deananddavid    31       18 .../custom/deananddavid.png

This would have been enough to add the images, but since I wanted the NAME attribute as a label, I also included geom_text_repel() from the ggrepel package. It makes handling non-overlapping labels much simpler!

library(ggplot2)
library(ggimage)
library(ggrepel)

ggplot(custom_dt, aes(x = REACTION, y = COUNT, label = NAME)) +
  geom_image(aes(image = IMAGE), size = 0.04) +
  geom_text_repel(point.padding = 0.9, segment.alpha = 0) +
  xlab("as reaction") +
  ylab("within message")

Usually, if a label is “too far” away from the marker, geom_text_repel includes a line to indicate where the labels belong. Since these lines would overlap the images, I used segment.alpha = 0 to make them invisible. With point.padding = 0.9 I gave the labels a bit more space, so it looks nicer. Depending on the size of the plot, this needs to be adjusted. In the plot, one can see our usage of emojis within a message (y-axis) and as a reaction (x-axis).

To combine the emoji font and the custom emojis, I used the following data and code. Really, why did I not do this last time? Since UNICODE is NA whenever I want to use IMAGE, nothing gets plotted twice.

 1:                    :+1:     1090     0 1090     1 U0001f44d
 2:                   :joy:      609   152  761     2 U0001f602
 3:                 :smile:       91   496  587     3 U0001f604
 4:                    :-1:      434     9  443     4 U0001f44e
 5:                  :tada:      346    38  384     5 U0001f389
 6:                  :fire:      274    17  291     6 U0001f525
 7: :slightly_smiling_face:        1   250  251     7 U0001f642
 8:                  :wink:       27   191  218     8 U0001f609
 9:                  :clap:      201    13  214     9 U0001f44f
10:      :party_hat_parrot:      192     9  201    10       <NA>  .../custom/party_hat_parrot.png
ggplot(plotdata2, aes(x = PLACE, y = SUM, label = UNICODE)) +
  geom_bar(stat = "identity", fill = "steelblue") +
  geom_text(family="EmojiOne") +
  xlab("Most popular emojis") +
  ylab("Number of usage") +
  scale_fill_brewer(palette = "Paired") +
  geom_image(aes(image = IMAGE), size = 0.04)

# export the drawn plot as SVG (grid.export() is from the gridSVG package)
ps = grid.export(paste0(main_path, "plots/top-10-used-emojis.svg"), addClass=T)

The meaning behind emojis

Now we know what our top emojis are. But what is the rest of the world doing? Thanks to Emojimore for providing me with this overview! On their site, you can find meanings for a lot more emojis.

Behind each of our custom emojis is a story as well. For example, all the food emojis are helping us every day to decide where to eat and provide information on what everyone is planning for lunch! And if you do not agree with the decision, just react with sadphan to let the others know about your feelings. If you want to know the whole stories behind all custom emojis or even help create new ones, then maybe you should join our team — check out our available job offers here!


Jakob Gepp

The tidyverse makes the life of a data scientist a lot easier. That’s why we at STATWORX love to do our analytics and data science with the tidyverse. Its user-centered approach has many advantages. Instead of the base R version df_test[df_test$x > 10, ], we can write df_test %>% filter(x > 10), which is a lot more readable, especially as our data pipeline gets more complex and nested. Also, as you might have noticed, we can use the column names directly instead of prefixing them with the data frame. Because of these advantages, we want to use dplyr verbs when writing our own functions. Imagine we want to write our own summary function my_summary(), which takes a grouping variable and calculates some descriptive statistics. Let’s see what happens when we wrap a dplyr pipeline into a function:

my_summary <- function(df, grouping_var){
  df %>%
    group_by(grouping_var) %>% 
    summarise(
      avg = mean(air_time),
      sum = sum(air_time),
      min = min(air_time),
      max = max(air_time),
      obs = n()
    )
}
my_summary(airline_df, origin)
Error in grouped_df_impl(data, unname(vars), drop) : 
 Column `grouping_var` is unknown 

Our new function uses group_by(), which looks for a grouping variable called grouping_var instead of origin, as we intended. So, what happened here? group_by() quotes its arguments, grouping_var in our example, and then searches for it within its scope, where it does not find it. Quoting is what allows dplyr to implement custom ways of handling its operations. Because tidy evaluation is used throughout the tidyverse, we can refer to column names as if they were variables. However, our data frame has no column grouping_var.

Non-Standard Evaluation

Talking about whether an argument is quoted or evaluated is a more precise way of stating whether or not a function uses non-standard evaluation (NSE). – Hadley Wickham

The quoting used by group_by() means that it uses non-standard evaluation, like most verbs in dplyr. However, non-standard evaluation is not limited to dplyr and the tidyverse; base R uses it too, e.g., library(dplyr) works with an unquoted package name.


Because dplyr quotes its arguments, we have to do two things to use it in our function:

  • First, we have to quote our argument
  • Second, we have to tell dplyr that we have already quoted the argument, which we do by unquoting it

We will see this quote-and-unquote pattern consistently throughout functions that use tidy evaluation.

my_summary <- function(df, grouping_var){
  df %>%
    group_by(!!grouping_var) %>% 
    summarise(
      avg = mean(air_time),
      sum = sum(air_time),
      min = min(air_time),
      max = max(air_time),
      obs = n()
    )
}
my_summary(airline_df, quo(origin))

Therefore, we quote the origin variable when passing it to our function, which means that R does not search for the symbol origin in the global environment but holds off on evaluating it. The quotation takes place with the quo() function. To tell group_by() that the variable has already been quoted, we need the !! operator, pronounced bang-bang (in case you wondered about the title).

Without !!, group_by() first searches for the variable within its own scope, i.e., the columns of the given data frame. As mentioned before, tidy evaluation, implemented by the eval_tidy() function, is used throughout the tidyverse. It also introduces the concept of a data mask, which makes data a first-class object in R.

Data Mask


Generally speaking, the data mask approach is much more convenient. However, on the programming side, we have to pay attention to a few things, like the quote-and-unquote pattern from before.

As a next step, we want the quotation to take place inside of the function, so the user of the function does not have to do it. Sadly, using quo() inside the function does not work.

my_summary <- function(df, grouping_var){
  df %>%
    group_by(!!grouping_var) %>% 
    summarise(
      avg = mean(air_time),
      sum = sum(air_time),
      min = min(air_time),
      max = max(air_time),
      obs = n()
    )
}
my_summary(airline_df, origin)
Error in quos(...) : object 'origin' not found 

We get an error message because quo() takes its argument too literally: it quotes grouping_var directly instead of substituting it with origin as desired. That’s why we use the function enquo() for enriched quotation, which creates a quosure. A quosure is an object that contains an expression and an environment. Quosures turn R’s internal promise object into something that can be used for programming. Thus, the following code works, and we see the quote-and-unquote pattern again.

my_summary <- function(df, grouping_var){
  grouping_var <- enquo(grouping_var)
  df %>%
    group_by(!!grouping_var) %>% 
    summarise(
      avg = mean(air_time),
      sum = sum(air_time),
      min = min(air_time),
      max = max(air_time),
      obs = n()
    )
}
my_summary(airline_df, origin)
# A tibble: 2 x 6
  origin   avg    sum   min   max   obs
  <fct>  <dbl>  <int> <int> <int> <int>
1 JFK     166. 587966    19   415  3539
2 LAX     132. 850259     1   381  6461

All R code is a tree

To better understand what’s happening, it is useful to know that all R code can be represented by an Abstract Syntax Tree (AST), because the structure of the code is strictly hierarchical. The leaves of an AST are either symbols or constants. The more complex a function call gets, the deeper the AST becomes. In the diagram, symbols are drawn in dark blue with rounded corners, whereas constants have green borders and square corners. Strings are surrounded by quotes so that they won’t be confused with symbols. The branches are function calls and are depicted as orange rectangles.

a(b(4, "s"), c(3, 4, d()))

To understand how an expression is represented as an AST, it helps to write it in its prefix form.

y <- x * 10
`<-`(y, `*`(x, 10))

The R package lobstr contains the function ast(), which creates an AST from R code.

The code from the first example lobstr::ast(a(b(4, "s"), c(3, 4, d()))) results in this:


It looks as expected, just like our hand-drawn AST. The concept of ASTs helps us understand what is happening in our function. In the following simple function call, !! introduces a placeholder (promise) for x.

x <- expr(-1)
f(!!x, y)

Due to R’s lazy evaluation, the function f() is not evaluated immediately, but only once we call it. At the moment of the function call, the placeholder x is replaced by another AST, which can be arbitrarily complex. Furthermore, this keeps the order of the operators correct, which is not the case when we use parse() and paste() with strings. So the resulting AST of our code snippet is the following:


Furthermore, !! also works with symbols, functions, and constants.

Perfecting our function

Now, we want to add an argument for the variable we are summarizing to refine our function. At the moment we have air_time hardcoded into it. Thus, we want to replace it with a general summary_var as an argument in our function. Additionally, we want the column names of the final output data frame to be adjusted dynamically, depending on the input variable. For adding summary_var, we follow the quote and unquote pattern from above. However, for the column-naming, we need two additional functions.

The first is quo_name(), which converts a quoted symbol into a string. Therefore, we can use normal string operations on it, e.g., base paste() to manipulate it. However, we also need to unquote the resulting name, which sits on the left-hand side of an argument, where R does not allow any computation. Thus, we need the second function, the vestigial operator :=, instead of the normal =.

my_summary <- function(df, grouping_var, summary_var){
  grouping_var <- enquo(grouping_var)
  summary_var <- enquo(summary_var)
  summary_nm <- quo_name(summary_var)
  summary_nm_avg <- paste0("avg_",summary_nm)
  summary_nm_sum <- paste0("sum_",summary_nm)
  summary_nm_obs <- paste0("obs_",summary_nm)

  df %>%
    group_by(!!grouping_var) %>% 
    summarise(
      !!summary_nm_avg := mean(!!summary_var),
      !!summary_nm_sum := sum(!!summary_var),
      !!summary_nm_obs := n()
    )
}
my_summary(airline_df, origin, air_time)
# A tibble: 2 x 4
  origin avg_air_time sum_air_time obs_air_time
  <fct>         <dbl>        <int>        <int>
1 JFK            166.       587966         3539
2 LAX            132.       850259         6461

Tidy Dots

In the next step, we want to add the possibility to summarize an arbitrary number of variables. For this, we need tidy dots (...), also called dot-dot-dot. For example, the documentation of select() reads:

Usage: select(.data, ...)
Arguments: ... One or more unquoted expressions separated by commas.

In select() we can use any number of variables we want to select. We will use tidy dots ... in our function. However, there are some things we have to account for.

Within the function, ... is treated as a list. So we cannot use !! or enquo(), because these are made for single variables. However, there are counterparts for the case of .... To quote several arguments at once, we can use enquos(), which returns a list of quoted arguments. To unquote several arguments, we need !!!, also called the big-bang operator. !!! replaces one argument with many, which is called unquote-splicing, and respects hierarchical order.


Using purrr, we can neatly handle the computation over the list entries provided by ... (for more information, ask your Purrr-Macist). Putting everything together, we finally arrive at our final function.

my_summary <- function(df, grouping_var, ...) {
  grouping_var <- enquo(grouping_var)

  smry_vars <- enquos(..., .named = TRUE)

  smry_avg <- purrr::map(smry_vars, function(var) {
    expr(mean(!!var, na.rm = TRUE))
  })
  names(smry_avg) <- paste0("avg_", names(smry_avg))

  smry_sum <- purrr::map(smry_vars, function(var) {
    expr(sum(!!var, na.rm = TRUE))
  })
  names(smry_sum) <- paste0("sum_", names(smry_sum))

  df %>%
    group_by(!!grouping_var) %>%
    summarise(!!!smry_avg, !!!smry_sum, obs = n())
}

my_summary(airline_df, origin, dep_delay, arr_delay)
# A tibble: 2 x 6
  origin avg_dep_delay avg_arr_delay sum_dep_delay sum_arr_delay   obs
  <fct>          <dbl>         <dbl>         <int>         <int> <int>
1 JFK            12.9          11.8          45792         41625  3539
2 LAX             8.64          5.13         55816         33117  6461

And the tidy evaluation goes on and on

As mentioned at the beginning, tidy evaluation is not only used within dplyr but in most packages of the tidyverse. Knowing how tidy evaluation works is therefore also helpful when using ggplot2, e.g., to write a function that creates a styled, grouped scatter plot. In this example, the function takes the data, the values for the x- and y-axes, and the grouping variable as inputs:

scatter_by <- function(.data, x, y, z=NULL) {
  x <- enquo(x)
  y <- enquo(y)
  z <- enquo(z)

  ggplot(.data) + 
    geom_point(aes(!!x, !!y, color = !!z))
}
scatter_by(airline_df, distance, air_time, origin) 

Another example is using R Shiny inputs in a sparklyr pipeline. input$ cannot be used directly within sparklyr, because sparklyr would try to resolve the input list object on the Spark side. Unquoting it with !! evaluates the value in R first, so only the number itself is sent to Spark.



library(shiny)
library(sparklyr)
library(dplyr)

# Define server logic required to filter numbers
shinyServer(function(input, output) {
    tbl_1 <- tibble(a = 1:5, b = 6:10)
    sc <- spark_connect(master = "local")

    tbl_1_sp <-
        copy_to(
            dest = sc,
            df = tbl_1,
            name = "tbl_1_sp",
            overwrite = TRUE
        )

    observeEvent(input$select_a, {

        # !! evaluates input$select_a in R before the query goes to Spark
        number_b <- tbl_1_sp %>%
            filter(a == !!input$select_a) %>%
            collect() %>%
            pull(b)

        output$text_b <- renderText({
            paste0("Selected number : ", number_b)
        })
    })
})


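# Define UI for application
shinyUI(fluidPage(
    # Application title
    titlePanel("Select Number Example"),

    # Sidebar with a slider input for number
    sidebarLayout(
        sidebarPanel(
            sliderInput("select_a",
                "Number for 1:",
                min = 1,
                max = 5,
                value = 1)
        ),

        # Show a text as output
        mainPanel(
            textOutput("text_b")
        )
    )
))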


There are many use cases for tidy evaluation, especially for advanced programmers. With the tidyverse growing by the day, knowing tidy evaluation becomes more and more useful. For more information about metaprogramming in R and other advanced topics, I can recommend the book Advanced R by Hadley Wickham.

Markus Berroth

Reinforcement learning is currently one of the hottest topics in machine learning. For a recent conference we attended (the awesome Data Festival in Munich), we developed a reinforcement learning model that learns to play Super Mario Bros on the NES, so that visitors to our booth could compete against the agent in terms of level completion time.

The promotion was a great success, and people enjoyed the “human vs. machine” competition. Only one contestant was able to beat the AI, by taking a secret shortcut that the AI wasn’t aware of. Also, developing the model in Python was a lot of fun. So, I decided to write a blog post that covers some of the fundamental concepts of reinforcement learning as well as the actual implementation of our Super Mario agent in TensorFlow (beware: I used TensorFlow 1.13.1, since TensorFlow 2.0 had not been released at the time of writing this article).

Recap: reinforcement learning

Most machine learning models have an explicit connection between inputs and outputs that does not change during training. Therefore, it can be difficult to model or predict systems where the inputs or targets themselves depend on previous predictions. However, the world around a model often updates itself with every prediction made. What sounds quite abstract is actually a very common situation in the real world: autonomous driving, machine control, process automation, etc. In many situations, decisions made by models have an impact on their surroundings and consequently on the next actions to be taken. Classical supervised learning approaches can only be used to a limited extent in such situations. To handle them, we need machine learning models that can cope with inputs and outputs that vary over time and depend on each other. This is where reinforcement learning comes into play.

In reinforcement learning, the model (called agent) interacts with its environment by choosing, in each state of the environment, from a set of possible actions (the action space). Each action results in a positive or negative reward from the environment. Think of rewards as an abstract way of signaling that the action taken was good or bad. The reward issued by the environment can be immediate or delayed into the future. By learning from combinations of environment states, actions, and corresponding rewards (so-called transitions), the agent tries to reach an optimal set of decision rules (the policy) that maximizes the total reward gathered by the agent in each state.
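The following minimal Python sketch shows what a transition can look like as a data structure and how delayed rewards are commonly weighed: a reward k steps in the future is multiplied by a discount factor gamma raised to the power k. The names Transition and discounted_return, the toy episode, and the value of gamma are illustrative assumptions, not code from our Super Mario agent.

```python
from typing import Any, List, NamedTuple


class Transition(NamedTuple):
    """One step of experience: what the agent saw, did, and got back."""
    state: Any
    action: int
    reward: float
    next_state: Any
    done: bool  # True if the episode ended after this step


def discounted_return(rewards: List[float], gamma: float) -> float:
    """Sum of rewards where a reward k steps in the future is weighted by gamma**k."""
    total = 0.0
    for r in reversed(rewards):  # accumulate from the last step backwards
        total = r + gamma * total
    return total


# hypothetical three-step episode
episode = [
    Transition(state=0, action=1, reward=1.0, next_state=1, done=False),
    Transition(state=1, action=1, reward=0.0, next_state=2, done=False),
    Transition(state=2, action=0, reward=2.0, next_state=3, done=True),
]
print(discounted_return([t.reward for t in episode], gamma=0.5))  # 1.5
```

With gamma = 0.5, the delayed reward of 2.0 only counts as 0.5, so the agent is nudged towards collecting rewards sooner rather than later.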

Q-learning and Deep Q-learning

In reinforcement learning, we often use a learning concept called Q-learning. Q-learning is based on so-called Q-values, which help the agent determine the optimal action, given the current state of the environment. Q-values are “discounted” future rewards that our agent collects during training by taking actions and moving through the different states of the environment. During training, the Q-values are approximated, either by simply exploring the environment or by using a function approximator, such as a deep neural network (as in our case here). In each state, we usually select the action with the highest Q-value, i.e., the highest discounted future reward, given the current state of the environment.
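Before a neural network enters the picture, the idea can be sketched with a plain lookup table of Q-values. The following Python sketch implements the standard tabular Q-learning update, Q(s, a) <- Q(s, a) + alpha * (r + gamma * max over a' of Q(s', a') - Q(s, a)), on a hypothetical two-action toy problem; the learning rate alpha, the discount gamma, and all names are illustrative assumptions.

```python
from collections import defaultdict
from typing import DefaultDict, Hashable, List, Tuple

QTable = DefaultDict[Tuple[Hashable, Hashable], float]


def q_update(q: QTable, state, action, reward: float, next_state,
             actions: List[Hashable], alpha: float = 0.5, gamma: float = 0.9,
             done: bool = False) -> float:
    """Move Q(state, action) towards reward + gamma * best Q-value of the next state."""
    best_next = 0.0 if done else max(q[(next_state, a)] for a in actions)
    td_target = reward + gamma * best_next
    q[(state, action)] += alpha * (td_target - q[(state, action)])
    return q[(state, action)]


def greedy_action(q: QTable, state, actions: List[Hashable]):
    """Pick the action with the highest Q-value in the given state."""
    return max(actions, key=lambda a: q[(state, a)])


# toy problem: in state 0, "right" ends the episode with reward 1, "left" with reward 0
q: QTable = defaultdict(float)
actions = ["left", "right"]
for _ in range(10):
    q_update(q, 0, "right", 1.0, None, actions, done=True)
    q_update(q, 0, "left", 0.0, None, actions, done=True)

print(greedy_action(q, 0, actions))  # right
```

A table only works for small, discrete state spaces; for raw game frames, as in Super Mario, a function approximator takes the place of the dictionary.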

When using a neural network as a Q-function approximator, we learn by computing the difference between the predicted Q-values and the “true” Q-values, i.e., the representation of the optimal decision in the current state. Based on the computed loss, we update the network’s parameters using gradient descent, just like in any other neural network model. By repeating this many times, the network converges to a state where it can approximate the Q-values of the next state, given the current state of the environment. If the approximation is good enough, we simply select the action that has the highest Q-value. This way, the agent can decide in each situation which action generates the best outcome in terms of reward collection.
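To illustrate learning by gradient descent on the difference between predicted and target Q-values, the sketch below swaps the deep network for the simplest possible function approximator: a linear model with one weight vector per action, written in plain Python. This linear stand-in, the squared-error loss, the fixed targets, and all names are assumptions for illustration only; our agent used a deep network in TensorFlow.

```python
from typing import List


def q_values(weights: List[List[float]], state: List[float]) -> List[float]:
    """Linear Q-function: Q(s, a) = w_a . s for every action a."""
    return [sum(w_i * s_i for w_i, s_i in zip(w_a, state)) for w_a in weights]


def sgd_step(weights: List[List[float]], state: List[float], action: int,
             target: float, lr: float = 0.1) -> float:
    """One gradient-descent step on the loss 0.5 * (Q(s, a) - target)**2."""
    error = q_values(weights, state)[action] - target
    # d(loss)/d(w_a) = error * state; weights of the other actions are untouched
    weights[action] = [w_i - lr * error * s_i for w_i, s_i in zip(weights[action], state)]
    return 0.5 * error ** 2


def best_action(weights: List[List[float]], state: List[float]) -> int:
    """Greedy choice: index of the highest predicted Q-value."""
    qs = q_values(weights, state)
    return max(range(len(qs)), key=qs.__getitem__)


weights = [[0.0, 0.0], [0.0, 0.0]]  # 2 actions, 2 state features
state = [1.0, 0.0]
for _ in range(100):
    # hypothetical targets: action 1 is worth 1.0, action 0 is worth 0.2
    sgd_step(weights, state, action=1, target=1.0)
    sgd_step(weights, state, action=0, target=0.2)

print(best_action(weights, state))  # 1
```

In real deep Q-learning, the targets are not fixed numbers but are themselves computed from Q-value estimates, which is exactly what motivates the target network described next.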

In most deep reinforcement learning models, there are actually two deep neural networks involved: the online network and the target network. This is because, during training, the loss function of a single neural network is computed against steadily changing targets (Q-values) that are based on the network’s own weights. This makes the optimization problem harder and might even prevent convergence altogether. The target network is basically a copy of the online network with frozen weights that are not trained directly. Instead, the target network’s weights are synchronized with the online network after a certain number of training steps. Enforcing “stable outputs” of the target network that do not change after each training step ensures that the target Q-values needed to compute the loss do not change constantly, which supports convergence of the optimization problem.
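The synchronization schedule can be sketched in a few lines of Python. Here, the "networks" are just lists of weights and the training step is a hypothetical callback, so this only illustrates the bookkeeping: the online weights change every step, while the target weights are a frozen copy refreshed every sync_every steps (an assumed parameter name).

```python
import copy
from typing import Callable, List


class OnlineTargetPair:
    """Keeps an online and a target copy of the same weights."""

    def __init__(self, weights: List[float], sync_every: int):
        self.online = weights
        self.target = copy.deepcopy(weights)  # frozen copy, never trained directly
        self.sync_every = sync_every
        self.steps = 0

    def train_step(self, update: Callable[[List[float]], None]) -> None:
        update(self.online)  # only the online weights are trained
        self.steps += 1
        if self.steps % self.sync_every == 0:
            # periodically copy the online weights into the target network
            self.target = copy.deepcopy(self.online)


def fake_update(w: List[float]) -> None:
    """Hypothetical stand-in for one gradient step."""
    w[0] += 1.0


pair = OnlineTargetPair([0.0], sync_every=3)
for _ in range(2):
    pair.train_step(fake_update)
print(pair.online, pair.target)  # [2.0] [0.0]
pair.train_step(fake_update)
print(pair.online, pair.target)  # [3.0] [3.0]
```

Using deepcopy matters: assigning the same list object would make the target change together with the online weights and defeat the purpose.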

Deep Double Q-learning

Another possible issue with Q-learning is that, due to the selection of the maximum Q-value for determining the best action, the model sometimes produces extraordinarily high Q-values during training. This is not always a problem, but it might turn into one if there is a strong concentration on certain actions, which in turn leads to the neglect of less favorable but “worth-to-try” actions. If the latter are neglected all the time, the model might run into a locally optimal solution or, even worse, select the same actions all the time. One way to deal with this problem is to use an updated version of Q-learning called double Q-learning.

In double Q-learning, the target Q-value is not simply computed from the action with the maximum Q-value of the target network. Instead, the computation is split into three distinct steps: (1) first, the target network computes the target Q-values of the state after taking the action. Then, (2) the online network computes the Q-values of the same state and selects the best action by finding the maximum Q-value. Finally, (3) the target Q-values are calculated using the Q-values of the target network, but at the action indices selected by the online network. Because one network selects the action and the other evaluates it, the Q-values are no longer updated based on their own maximum, which reduces the overestimation of Q-values.
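The three steps can be written down with NumPy as a sketch. The function names are hypothetical; the formulas follow the target_q computation in the DQNAgent.learn() method shown later, where done is 1 for terminal transitions and 0 otherwise.

```python
import numpy as np


def dqn_targets(reward, done, next_q_target, gamma=0.9):
    """Vanilla DQN target: max over the target network's next-state Q-values."""
    return reward + (1.0 - done) * gamma * np.max(next_q_target, axis=1)


def double_dqn_targets(reward, done, next_q_target, next_q_online, gamma=0.9):
    """Double DQN target: the online network selects, the target network evaluates."""
    # Step 2: the online network picks the best next action per transition
    best_actions = np.argmax(next_q_online, axis=1)
    # Step 3: look up the target network's Q-values at those action indices
    rows = np.arange(len(best_actions))
    evaluated = next_q_target[rows, best_actions]
    return reward + (1.0 - done) * gamma * evaluated


# Two transitions; the second one is terminal (done = 1), so only its reward counts
reward = np.array([1.0, 0.0])
done = np.array([0.0, 1.0])
next_q_target = np.array([[1.0, 5.0], [2.0, 3.0]])   # step 1: target network
next_q_online = np.array([[4.0, 0.5], [0.0, 1.0]])   # step 2: online network

# First transition: vanilla DQN uses Q = 5.0, double DQN uses Q = 1.0
print(dqn_targets(reward, done, next_q_target))
print(double_dqn_targets(reward, done, next_q_target, next_q_online))
```

In the first transition, the target network alone would bootstrap from its largest value (5.0), while double Q-learning evaluates the online network’s choice (action 0) with the target network and gets a more conservative 1.0.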

Gym environments

In order to build a reinforcement learning application, we need two things: (1) an environment that the agent can interact with and learn from and (2) the agent, which observes the state(s) of the environment and chooses appropriate actions using Q-values that (ideally) result in high rewards for the agent. An environment is typically provided as a so-called gym, a class that contains the necessary code to emulate the states and rewards of the environment as a function of the agent’s actions, as well as further information, e.g. about the possible action space. Here is an example of a simple environment class in Python:

class Environment:
    """ A simple environment skeleton """
    def __init__(self):
        # Initializes the environment
        pass

    def step(self, action):
        # Changes the environment based on the agent's action and
        # returns next_state, reward, done, info
        raise NotImplementedError

    def reset(self):
        # Resets the environment to its initial state and returns that state
        raise NotImplementedError

    def render(self):
        # Shows the state of the environment on screen
        raise NotImplementedError

The environment has three major methods: (1) step() executes the environment code as a function of the action selected by the agent and returns the next state of the environment, the reward with respect to the action, a done flag indicating whether the environment has reached its terminal state, and a dictionary of additional information about the environment and its state, (2) reset() resets the environment to its original state and (3) render() prints the current state to the screen (for example, showing the current frame of the Super Mario Bros game).
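To see the step()/reset()/render() contract in action, here is a tiny, fully runnable environment. CorridorEnv and run_episode are hypothetical examples made up for illustration: the agent walks along a corridor, pays a small penalty per step and gets a bonus when it reaches the goal.

```python
class CorridorEnv:
    """Hypothetical toy environment: walk right along a corridor to the goal."""

    def __init__(self, length=5):
        self.length = length
        self.position = 0

    def step(self, action):
        # Action 1 moves right; any other action moves left (never below 0)
        if action == 1:
            self.position += 1
        else:
            self.position = max(0, self.position - 1)
        done = self.position >= self.length
        # Small penalty per step, bonus for reaching the goal
        reward = 1.0 if done else -0.1
        info = {"position": self.position}
        return self.position, reward, done, info

    def reset(self):
        self.position = 0
        return self.position

    def render(self):
        # Print the corridor: "A" marks the agent, "G" the goal
        cells = ["."] * (self.length + 1)
        cells[-1] = "G"
        cells[self.position] = "A"
        print("".join(cells))


def run_episode(env, policy):
    """Play one episode with policy(state) -> action; return (final state, total reward)."""
    state = env.reset()
    total_reward = 0.0
    done = False
    while not done:
        state, reward, done, _ = env.step(policy(state))
        total_reward += reward
    return state, total_reward


env = CorridorEnv(length=3)
final_state, total = run_episode(env, policy=lambda state: 1)  # always move right
env.render()
```

The loop in run_episode has the same shape as the training loop used later for Super Mario: reset, then repeatedly call step() until the done flag is set.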

For Python, a go-to place for finding gyms is OpenAI Gym. It contains lots of different games and problems that are well suited to being solved with reinforcement learning. Furthermore, there is an OpenAI project called Gym Retro that contains hundreds of Sega and SNES games, ready to be tackled by reinforcement learning algorithms.


The agent consumes the current state of the environment and selects an appropriate action based on its selection policy. The policy maps the state of the environment to the action to be taken by the agent. Finding the right policy is a key question in reinforcement learning and often involves the use of deep neural networks. The following agent simply observes the state of the environment and returns action = 1 if the state is larger than 0 and action = 0 otherwise.

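class Agent:
    """ A simple agent """
    def __init__(self):
        # This simple agent has no internal state
        pass

    def action(self, state):
        # Return action 1 for positive states, action 0 otherwise
        if state > 0:
            return 1
        return 0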

This is of course a very simplistic policy. In practical reinforcement learning applications, the state of the environment can be very complex and high-dimensional. One example is video games: the state of the environment is determined by the pixels on the screen and the previous actions of the player. Our agent needs to find a policy that maps the screen pixels to actions that generate rewards from the environment.

Environment wrappers

Gym environments contain most of the functionality needed to use them in a reinforcement learning scenario. However, certain features do not come prebuilt into the gym, such as image downscaling, frame skipping and stacking, reward clipping and so on. Luckily, there are so-called gym wrappers that provide these kinds of utility functions. An example that can be used for many video games such as Atari or NES can be found here. For video game gyms, it is very common to use wrapper functions in order to achieve good performance of the agent. The example below shows a simple reward clipping wrapper.

import gym
import numpy as np

class ClipRewardEnv(gym.RewardWrapper):
    """ Example wrapper for reward clipping """
    def __init__(self, env):
        super().__init__(env)

    def reward(self, reward):
        # Clip reward to {1, 0, -1} by its sign
        return np.sign(reward)

From the example above, you can see that it is possible to change the default behavior of the environment by overriding its core methods. Here, the rewards of the environment are clipped to {-1, 0, 1} by np.sign(), based on the sign of the reward.
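To show what such a wrapper does without installing gym, here is a framework-free sketch: a wrapper object that delegates to an inner environment and post-processes the reward by its sign, which is roughly the pattern gym.RewardWrapper follows when it passes each reward through reward(). ClipRewardWrapper and ConstantRewardEnv are hypothetical stand-ins, not gym classes.

```python
import numpy as np


class ConstantRewardEnv:
    """Hypothetical stand-in environment that always pays the same raw reward."""

    def __init__(self, raw_reward):
        self.raw_reward = raw_reward

    def reset(self):
        return 0

    def step(self, action):
        return 0, self.raw_reward, False, {}


class ClipRewardWrapper:
    """Delegates to the wrapped env and clips each reward to {-1, 0, 1}."""

    def __init__(self, env):
        self.env = env

    def reset(self):
        return self.env.reset()

    def step(self, action):
        state, reward, done, info = self.env.step(action)
        return state, self.reward(reward), done, info

    def reward(self, reward):
        # Same sign-based rule as ClipRewardEnv.reward
        return float(np.sign(reward))


for raw in (15.0, 0.0, -3.0):
    env = ClipRewardWrapper(ConstantRewardEnv(raw))
    env.reset()
    _, clipped, _, _ = env.step(action=0)
    print(raw, "->", clipped)
```

Because the wrapper exposes the same step()/reset() interface as the environment it wraps, several wrappers can be stacked on top of each other.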

The Super Mario Bros NES environment

For our Super Mario Bros reinforcement learning experiment, I’ve used gym-super-mario-bros. The API is straightforward and very similar to the OpenAI Gym API. The following code shows a random agent playing Super Mario. This causes Mario to wiggle around on the screen and, of course, does not lead to a successful completion of the game.

from nes_py.wrappers import BinarySpaceToDiscreteSpaceEnv
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT

# Make gym environment
env = gym_super_mario_bros.make('SuperMarioBros-v0')
env = BinarySpaceToDiscreteSpaceEnv(env, SIMPLE_MOVEMENT)

# Play random
done = True
for step in range(5000):
    if done:
        state = env.reset()
    state, reward, done, info = env.step(env.action_space.sample())

# Close the environment
env.close()

The agent interacts with the environment by choosing random actions from the environment’s action space. The action space of a video game is actually quite large, since you can press multiple buttons at the same time. Here, the action space is reduced to SIMPLE_MOVEMENT, which covers basic game actions such as moving right or left, running and jumping. BinarySpaceToDiscreteSpaceEnv transforms the binary action space (dummy indicator variables for all buttons and directions) into a single integer that indexes one of these button combinations. So, for example, one integer action corresponds to pressing right and A together (a jump to the right).
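The mapping from a single integer to a button combination can be sketched as a list lookup. The list below is a local copy that I believe matches SIMPLE_MOVEMENT in gym-super-mario-bros, but treat it as an assumption and check it against your installed version. The button order used for the binary encoding (BUTTONS) is hypothetical and not necessarily the layout nes_py uses internally.

```python
# Local copy; assumed to mirror gym_super_mario_bros.actions.SIMPLE_MOVEMENT
SIMPLE_MOVEMENT = [
    ['NOOP'],
    ['right'],
    ['right', 'A'],
    ['right', 'B'],
    ['right', 'A', 'B'],
    ['A'],
    ['left'],
]

# Hypothetical button order for the binary (one indicator per button) encoding
BUTTONS = ['right', 'left', 'down', 'up', 'start', 'select', 'B', 'A']


def to_buttons(action):
    """Map a discrete action index to its list of pressed buttons."""
    return SIMPLE_MOVEMENT[action]


def to_binary(action):
    """Map a discrete action index to 0/1 indicators, one per button."""
    pressed = set(to_buttons(action))
    return [1 if button in pressed else 0 for button in BUTTONS]


for action in range(len(SIMPLE_MOVEMENT)):
    print(action, to_buttons(action), to_binary(action))
```

The agent only ever sees the integer index; the wrapper translates it back into button presses before passing it to the emulator.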

Using a deep learning model as an agent

When playing Super Mario Bros on the NES, humans see the game screen, or, more precisely, consecutive frames of pixels displayed at high speed on the screen. Our brains are capable of transforming the raw sensory input from our eyes into electrical signals that are processed and trigger corresponding actions (pressing buttons on the controller) that (hopefully) lead Mario to the finish line.

When training the agent, the gym renders each game frame as a matrix of pixels, according to the respective action taken by the agent. Basically, those pixels can be used as input to any machine learning model. However, in reinforcement learning we often use convolutional neural networks (CNNs), which excel at image recognition problems compared to other ML models. I won’t go into technical detail about CNNs here; there’s a plethora of great introductory articles on CNNs, like this one.

Instead of using only the current game screen as an input to the model, it is common to feed multiple stacked frames into the CNN. By doing so, the model can process changes and “movements” on the screen between consecutive frames, which would not be possible when using only a single game frame. Here, the input tensor of our model has size [84, 84, 4], which corresponds to a stack of 4 grayscale frames, each of size 84×84 pixels. This [height, width, channels] layout matches the channels-last input that 2-dimensional convolution layers such as tf.layers.conv2d expect by default, with the 4 stacked frames playing the role of the channels.
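Frame stacking itself can be sketched in a few lines. FrameStack is a hypothetical class; the wrapper used in this post (imported later from a wrappers module) is not shown, and such wrappers typically also resize the frames and convert them to grayscale before stacking.

```python
from collections import deque

import numpy as np


class FrameStack:
    """Keeps the last k grayscale frames and stacks them along the last axis."""

    def __init__(self, k=4):
        self.k = k
        self.frames = deque(maxlen=k)

    def reset(self, first_frame):
        # Fill the stack with copies of the first frame of the episode
        for _ in range(self.k):
            self.frames.append(first_frame)
        return self.observation()

    def push(self, frame):
        # The oldest frame drops out automatically because of maxlen
        self.frames.append(frame)
        return self.observation()

    def observation(self):
        # Shape (height, width, k), i.e. channels last
        return np.stack(list(self.frames), axis=-1)


stack = FrameStack(k=4)
obs = stack.reset(np.zeros((84, 84), dtype=np.uint8))
obs = stack.push(np.full((84, 84), 255, dtype=np.uint8))
print(obs.shape)
```

The newest frame always sits in the last channel, so the CNN can compare it with the three older frames to pick up motion.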

The architecture of the deep learning model consists of three convolutional layers, followed by a flatten layer and one fully connected layer with 512 neurons, as well as an output layer consisting of actions = 6 neurons, which corresponds to the action space of the game (in this case RIGHT_ONLY, i.e. actions that move Mario to the right; enlarging the action space usually increases problem complexity and training time).
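The shapes flowing through this architecture can be checked by hand. tf.layers.conv2d uses “valid” padding by default, so each layer shrinks the image to floor((size - kernel_size) / stride) + 1. The layer settings below are the ones from the DQNAgent code (kernel sizes and strides of 8/4, 4/2 and 3/1); the helper functions are hypothetical.

```python
def conv_output_size(size, kernel_size, stride):
    """Spatial output size of a convolution with 'valid' padding (no zero padding)."""
    return (size - kernel_size) // stride + 1


# (filters, kernel_size, strides) of the three convolutional layers
CONV_LAYERS = [(32, 8, 4), (64, 4, 2), (64, 3, 1)]


def feature_map_sizes(input_size=84, layers=CONV_LAYERS):
    """Height (= width) of the feature maps after each convolutional layer."""
    sizes = []
    size = input_size
    for _, kernel_size, stride in layers:
        size = conv_output_size(size, kernel_size, stride)
        sizes.append(size)
    return sizes


sizes = feature_map_sizes()
# Number of values the flatten layer passes to the 512-unit dense layer
flattened = sizes[-1] * sizes[-1] * CONV_LAYERS[-1][0]
print(sizes, flattened)
```

The 84×84 input shrinks to 20×20, 9×9 and finally 7×7 feature maps, so the flatten layer hands 7 · 7 · 64 = 3136 values to the fully connected layer.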

If you take a closer look at the TensorBoard image below, you’ll notice that the model actually consists of not one but two identical convolutional branches. One is the online network branch, the other is the target network branch. The online network is actually trained using gradient descent. The target network is not trained directly but is periodically synchronized every copy = 10000 steps by copying the weights from the online branch to the target branch of the network. The target network branch is excluded from gradient descent training by wrapping the output layer of the branch in the tf.stop_gradient() function. This stops the flow of gradients at the output layer, so that they cannot propagate back along the branch and its weights are not updated.

The agent learns by (1) taking random samples of historical transitions, (2) computing the “true” Q-values based on the states of the environment after the action, next_state, using the target network branch and the double Q-learning rule, (3) discounting the target Q-values using gamma = 0.9 and (4) running a batch gradient descent step based on the network’s internal Q-prediction and the true Q-values, supplied by target_q. In order to speed up the training process, the agent is not trained after each action but only when its learn_step counter has reached learn_each = 3, which corresponds to one training step every 4 frames. In addition, not every frame is stored in the replay buffer, but only every 4th frame. This is called frame skipping. More specifically, a max pooling operation is performed that aggregates the information of the last 4 consecutive frames. This is motivated by the fact that consecutive frames contain nearly the same information, which does not add anything new to the learning problem and might introduce strongly autocorrelated datapoints.
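Why learn_each = 3 means “train every 4 frames” becomes clear when the counter logic from DQNAgent.learn() is isolated. LearnSchedule is a hypothetical class that copies only that counter; it ignores the burn-in check and the other parts of learn().

```python
class LearnSchedule:
    """Copies the learn_step / learn_each counter logic of DQNAgent.learn()."""

    def __init__(self, learn_each=3):
        self.learn_each = learn_each
        self.learn_step = 0

    def should_train(self):
        # Skip (and count) until the counter reaches learn_each ...
        if self.learn_step < self.learn_each:
            self.learn_step += 1
            return False
        # ... then train once and reset the counter
        self.learn_step = 0
        return True


schedule = LearnSchedule(learn_each=3)
calls = [schedule.should_train() for _ in range(12)]
# 1-based indices of the calls that trigger a training step
print([i + 1 for i, trained in enumerate(calls) if trained])
```

Three calls are skipped and the fourth one trains, so training happens on calls 4, 8, 12 and so on.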

Speaking of correlated data: our network is trained with Adam (adaptive moment estimation), a variant of gradient descent, at learning_rate = 0.00025. Like other stochastic gradient methods, it works best with i.i.d. datapoints. This means that we cannot simply use new transition tuples for training in the order they occur, since consecutive transitions are highly correlated. To solve this issue, we use a concept called an experience replay buffer. Here, we store every transition of our game in a ring buffer (in Python, a collections.deque with a maxlen), which is then randomly sampled from when we acquire our training data of batch_size = 32. By using a random sampling strategy and a large enough replay buffer, we can assume that the resulting datapoints are (hopefully) not strongly correlated. The following code box shows the DQNAgent class.
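A stripped-down replay buffer, separated from the TensorFlow agent, might look like the sketch below. ReplayBuffer is a hypothetical class; its storage and sampling follow the add() and learn() methods of the DQNAgent class (deque with maxlen, random.sample, regrouping with zip), and the transitions are toy tuples with scalar states.

```python
import random
from collections import deque

import numpy as np


class ReplayBuffer:
    """Ring buffer of transitions; the oldest entries drop out when it is full."""

    def __init__(self, max_memory):
        self.memory = deque(maxlen=max_memory)

    def add(self, experience):
        self.memory.append(experience)

    def sample(self, batch_size):
        # Uniform sampling without replacement breaks up temporal correlation
        batch = random.sample(self.memory, batch_size)
        # Regroup the tuples into one array per field, as DQNAgent.learn() does
        state, next_state, action, reward, done = map(np.array, zip(*batch))
        return state, next_state, action, reward, done

    def __len__(self):
        return len(self.memory)


buffer = ReplayBuffer(max_memory=3)
for t in range(5):
    # (state, next_state, action, reward, done) with scalar toy states
    buffer.add((t, t + 1, t % 2, 1.0, False))
# Only the 3 most recent transitions (states 2, 3, 4) are kept
print(len(buffer), [e[0] for e in buffer.memory])
```

Because deque(maxlen=...) discards the oldest element on every append once it is full, memory use stays bounded no matter how long training runs.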

import random
import numpy as np
from collections import deque
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.placeholder, tf.layers)

class DQNAgent:
    """ DQN agent """
    def __init__(self, states, actions, max_memory, double_q):
        self.states = states
        self.actions = actions
        self.session = tf.Session()
        # The graph must exist before the saver and the initializer are created
        self.build_model()
        self.saver = tf.train.Saver(max_to_keep=10)
        self.session.run(tf.global_variables_initializer())
        # Start with identical online and target weights
        self.copy_model()
        self.memory = deque(maxlen=max_memory)
        self.eps = 1
        self.eps_decay = 0.99999975
        self.eps_min = 0.1
        self.gamma = 0.90
        self.batch_size = 32
        self.burnin = 100000
        self.copy = 10000
        self.step = 0
        self.learn_each = 3
        self.learn_step = 0
        self.save_each = 500000
        self.double_q = double_q

    def build_model(self):
        """ Model builder function """
        self.input = tf.placeholder(dtype=tf.float32, shape=(None, ) + self.states, name='input')
        self.q_true = tf.placeholder(dtype=tf.float32, shape=[None], name='labels')
        self.a_true = tf.placeholder(dtype=tf.int32, shape=[None], name='actions')
        self.reward = tf.placeholder(dtype=tf.float32, shape=[], name='reward')
        self.input_float = tf.to_float(self.input) / 255.
        # Online network
        with tf.variable_scope('online'):
            self.conv_1 = tf.layers.conv2d(inputs=self.input_float, filters=32, kernel_size=8, strides=4, activation=tf.nn.relu)
            self.conv_2 = tf.layers.conv2d(inputs=self.conv_1, filters=64, kernel_size=4, strides=2, activation=tf.nn.relu)
            self.conv_3 = tf.layers.conv2d(inputs=self.conv_2, filters=64, kernel_size=3, strides=1, activation=tf.nn.relu)
            self.flatten = tf.layers.flatten(inputs=self.conv_3)
            self.dense = tf.layers.dense(inputs=self.flatten, units=512, activation=tf.nn.relu)
            self.output = tf.layers.dense(inputs=self.dense, units=self.actions, name='output')
        # Target network
        with tf.variable_scope('target'):
            self.conv_1_target = tf.layers.conv2d(inputs=self.input_float, filters=32, kernel_size=8, strides=4, activation=tf.nn.relu)
            self.conv_2_target = tf.layers.conv2d(inputs=self.conv_1_target, filters=64, kernel_size=4, strides=2, activation=tf.nn.relu)
            self.conv_3_target = tf.layers.conv2d(inputs=self.conv_2_target, filters=64, kernel_size=3, strides=1, activation=tf.nn.relu)
            self.flatten_target = tf.layers.flatten(inputs=self.conv_3_target)
            self.dense_target = tf.layers.dense(inputs=self.flatten_target, units=512, activation=tf.nn.relu)
            self.output_target = tf.stop_gradient(tf.layers.dense(inputs=self.dense_target, units=self.actions, name='output_target'))
        # Optimizer
        self.action = tf.argmax(input=self.output, axis=1)
        self.q_pred = tf.gather_nd(params=self.output, indices=tf.stack([tf.range(tf.shape(self.a_true)[0]), self.a_true], axis=1))
        self.loss = tf.losses.huber_loss(labels=self.q_true, predictions=self.q_pred)
        self.train = tf.train.AdamOptimizer(learning_rate=0.00025).minimize(self.loss)
        # Ops that copy the online weights to the target network (built once,
        # so the graph does not grow every time the weights are synchronized)
        self.copy_ops = [tf.assign(new, old) for (new, old) in zip(tf.trainable_variables('target'), tf.trainable_variables('online'))]
        # Summaries
        self.summaries = tf.summary.merge([
            tf.summary.scalar('reward', self.reward),
            tf.summary.scalar('loss', self.loss),
            tf.summary.scalar('max_q', tf.reduce_max(self.output))
        ])
        self.writer = tf.summary.FileWriter(logdir='./logs', graph=self.session.graph)

    def copy_model(self):
        """ Copy weights to target network """
        self.session.run(self.copy_ops)

    def save_model(self):
        """ Saves current model to disk """
        self.saver.save(sess=self.session, save_path='./models/model', global_step=self.step)

    def add(self, experience):
        """ Add observation to experience """
        self.memory.append(experience)

    def predict(self, model, state):
        """ Prediction """
        if model == 'online':
            return self.session.run(fetches=self.output, feed_dict={self.input: np.array(state)})
        if model == 'target':
            return self.session.run(fetches=self.output_target, feed_dict={self.input: np.array(state)})

    def run(self, state):
        """ Perform action """
        if np.random.rand() < self.eps:
            # Random action
            action = np.random.randint(low=0, high=self.actions)
        else:
            # Policy action
            q = self.predict('online', np.expand_dims(state, 0))
            action = np.argmax(q)
        # Decrease eps
        self.eps *= self.eps_decay
        self.eps = max(self.eps_min, self.eps)
        # Increment step
        self.step += 1
        return action

    def learn(self):
        """ Gradient descent """
        # Sync target network
        if self.step % self.copy == 0:
            self.copy_model()
        # Checkpoint model
        if self.step % self.save_each == 0:
            self.save_model()
        # Break if burn-in
        if self.step < self.burnin:
            return
        # Break if no training
        if self.learn_step < self.learn_each:
            self.learn_step += 1
            return
        # Sample batch
        batch = random.sample(self.memory, self.batch_size)
        state, next_state, action, reward, done = map(np.array, zip(*batch))
        # Get next q values from target network
        next_q = self.predict('target', next_state)
        # Calculate discounted future reward
        if self.double_q:
            # Online network selects the action, target network evaluates it
            q = self.predict('online', next_state)
            a = np.argmax(q, axis=1)
            target_q = reward + (1. - done) * self.gamma * next_q[np.arange(0, self.batch_size), a]
        else:
            target_q = reward + (1. - done) * self.gamma * np.amax(next_q, axis=1)
        # Update model
        summary, _ = self.session.run(fetches=[self.summaries, self.train],
                                      feed_dict={self.input: state,
                                                 self.q_true: np.array(target_q),
                                                 self.a_true: np.array(action),
                                                 self.reward: np.mean(reward)})
        # Reset learn step
        self.learn_step = 0
        # Write
        self.writer.add_summary(summary, self.step)

Training the agent to play

First, we need to instantiate the environment. Here, we use the first level of Super Mario Bros, SuperMarioBros-1-1-v0, together with a discrete action space based on RIGHT_ONLY. Additionally, we use a wrapper that applies frame resizing, stacking and max pooling, reward clipping as well as lazy frame loading to the environment.

When the training starts, the agent begins to explore the environment by taking random actions. This is done in order to build up initial experience that serves as a starting point for the actual learning process. After burnin = 100000 game frames, the agent starts learning and, as epsilon shrinks, increasingly replaces random actions with actions determined by the CNN policy. This is called an epsilon-greedy policy. Epsilon-greedy means that the agent takes a random action with probability epsilon or a policy-based action with probability (1 - epsilon). Here, epsilon decays exponentially during training: it is multiplied by eps_decay = 0.99999975 at every step until it reaches eps_min = 0.1, where it remains constant for the rest of the training process. It is important not to eliminate random actions from the training process completely, in order to avoid getting stuck in locally optimal solutions.
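The epsilon-greedy rule and the decay schedule can be checked with a few lines of arithmetic. The function names are hypothetical; the defaults are the eps_decay = 0.99999975 and eps_min = 0.1 values of the agent, which multiplies epsilon by eps_decay on every call to run(), including during burn-in.

```python
import math
import random


def epsilon_greedy(q_values, eps, rng=random):
    """With probability eps pick a random action, otherwise the greedy one."""
    if rng.random() < eps:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])


def epsilon_after(steps, eps_start=1.0, eps_decay=0.99999975, eps_min=0.1):
    """Epsilon after multiplying by eps_decay once per step, floored at eps_min."""
    return max(eps_min, eps_start * eps_decay ** steps)


def steps_until_min(eps_start=1.0, eps_decay=0.99999975, eps_min=0.1):
    """Number of decay steps until eps_start * eps_decay**n <= eps_min."""
    return math.ceil(math.log(eps_min / eps_start) / math.log(eps_decay))


print(epsilon_after(100000))   # still close to 1 when learning starts after burn-in
print(steps_until_min())       # roughly 9.2 million steps
```

With this decay rate, epsilon is still about 0.975 when learning starts and only reaches its floor of 0.1 after roughly 9.2 million agent steps, so exploration fades out very slowly.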

For each action taken, the environment returns four objects: (1) the next game state, (2) the reward for taking the action, (3) a flag indicating whether the episode is done and (4) an info dictionary containing additional information from the environment. After taking the action, a tuple of the returned objects is added to the replay buffer and the agent performs a learning step. After learning, the current state is updated with next_state and the loop continues. The while loop breaks if the done flag is True or if info['flag_get'] indicates that Mario has reached the flag. This corresponds to either the death of Mario or a successful completion of the level. Here, the agent is trained for 10000 episodes.

import time
import numpy as np
from nes_py.wrappers import BinarySpaceToDiscreteSpaceEnv
import gym_super_mario_bros
from gym_super_mario_bros.actions import RIGHT_ONLY
from agent import DQNAgent
from wrappers import wrapper

# Build env (first level, right only)
env = gym_super_mario_bros.make('SuperMarioBros-1-1-v0')
env = BinarySpaceToDiscreteSpaceEnv(env, RIGHT_ONLY)
env = wrapper(env)

# Parameters
states = (84, 84, 4)
actions = env.action_space.n

# Agent
agent = DQNAgent(states=states, actions=actions, max_memory=100000, double_q=True)

# Episodes
episodes = 10000
rewards = []

# Timing
start = time.time()
step = 0

# Main loop
for e in range(episodes):

    # Reset env
    state = env.reset()

    # Reward
    total_reward = 0
    n_steps = 0

    # Play
    while True:

        # Show env (disabled)
        # env.render()

        # Run agent
        action = agent.run(state=state)

        # Perform action
        next_state, reward, done, info = env.step(action=action)

        # Remember transition
        agent.add(experience=(state, next_state, action, reward, done))

        # Update agent
        agent.learn()

        # Total reward
        total_reward += reward

        # Update state
        state = next_state

        # Increment
        n_steps += 1

        # If done break loop
        if done or info['flag_get']:
            break

    # Rewards (average reward per step in this episode)
    rewards.append(total_reward / n_steps)

    # Print
    if e % 100 == 0:
        print('Episode {e} - '
              'Frame {f} - '
              'Frames/sec {fs} - '
              'Epsilon {eps} - '
              'Mean Reward {r}'.format(e=e,
                                       f=agent.step,
                                       fs=np.round((agent.step - step) / (time.time() - start)),
                                       eps=np.round(agent.eps, 4),
                                       r=np.mean(rewards[-100:])))  # mean over the last 100 episodes
        start = time.time()
        step = agent.step

# Save rewards
np.save('rewards.npy', rewards)

After each game episode, the average reward per step in this episode is appended to the rewards list. Furthermore, stats such as the current frame count, frames per second and the current epsilon are printed every 100 episodes.


During training, the program checkpoints the current network every save_each = 500000 frames and keeps the 10 latest models on disk. I’ve downloaded several model versions during training to my local machine and produced the following video.

It is so awesome to see the learning progress of the agent! The training process took approximately 20 hours on a GPU accelerated VM on Google Cloud.

Summary and outlook

Reinforcement learning is an exciting field in machine learning that offers a wide range of possible applications in science and business alike. However, training reinforcement learning agents is still quite cumbersome and often requires tedious tuning of hyperparameters and network architecture in order to work well. There have been recent advances, such as Rainbow (a combination of multiple RL learning strategies), that aim at a more robust framework for training reinforcement learning agents, but the field is still an area of active research. Besides Q-learning, many other interesting training concepts have been developed in reinforcement learning. If you want to try different RL agents and training approaches, I suggest you check out Stable Baselines, a great way to easily use state-of-the-art RL agents and training concepts.

If you are a deep learning beginner and want to learn more, you should check out our brand-new STATWORX Deep Learning Bootcamp, a 5-day in-person introduction to the field that covers everything you need to know in order to develop your first deep learning models: neural net theory, backpropagation and gradient descent, programming models in Python, TensorFlow and Keras, CNNs and other image recognition models, recurrent networks and LSTMs for time series data and NLP, as well as advanced topics such as deep reinforcement learning and GANs.

If you have any comments or questions on my post, feel free to contact me!  Also, feel free to use my code (link to GitHub repo) or share this post with your peers on social platforms of your choice.

If you’re interested in more content like this, join our mailing list, constantly bringing you fresh data science, machine learning and AI reads and treats from me and my team at STATWORX right into your inbox!

Lastly, follow me on LinkedIn or my company STATWORX on Twitter, if you’re interested in more!

Sebastian Heinz

It’s Valentine’s Day, making this the most romantic time of the year. But actually, 2018 was already a year full of love here at STATWORX: many of my STATWORX colleagues got engaged. And so we began to wonder, some fearfully, some hopefully: who will be next? Therefore, today we’re going to tackle this question in the only true way: with data science!

Gathering the Data

To get my data, I surveyed my colleagues. I asked my (to be) married colleagues to answer my questions based on the very day they got engaged. My single colleagues answered my questions with respect to their current situation. I asked them about some factors that I’ve always suspected to influence someone’s likelihood of getting married. For example, I’m sure that, in comparison to Python users, R users are much more romantic. The indiscreet questions I badgered my coworkers with were:

  • Are you married or engaged?
  • How long have you been in your relationship?
  • Is your employment permanent?
  • How long have you been working at STATWORX?
  • What’s your age?
  • Are you living together with your partner?
  • Are you co-owning a pet with your partner?
  • What’s your preferred programming language: R, Python or neither?

I’m going to treat the relationship status as a dichotomous variable: married or engaged vs. single or “only” dating. To maintain some of the privacy of my colleagues, I gave them all some randomly (!!) chosen pet names. (Side note: There really is a subreddit for everything.)

Descriptive Exploration

Since the first step in generating data-driven answers should always be a descriptive exploration of the data at hand, I made some plots.

First, I took a look at the absolute frequencies of preferred programming languages in the groups of single vs. married or engaged STATWORX employees. I fear the romantic nature of R users is not the explanation we’re looking for:

library(dplyr)
library(ggplot2)
library(ggimage)  # provides geom_image()

# reformatting the target variable
df1 <- df %>%
  dplyr::mutate(engaged = ifelse(engaged == "yes", 
                                 "Engaged or Married", 
                                 "Single")) %>%
  dplyr::group_by(`primary programming language`, engaged) %>%
  dplyr::summarise(freq = n(),
                   image = "~/Desktop/heart_red.png") 

# since in geom_image size cannot be mapped to a variable,
# we draw multiple layers of data subsets  
ggplot() +
  geom_image(data = dplyr::filter(df1, freq == 1), 
             aes(y = `primary programming language`,
                 x = engaged, 
                 image = image), 
             size = 0.1) + 
  geom_image(data = dplyr::filter(df1, freq > 1 & freq <= 5), 
             aes(y = `primary programming language`, 
                 x = engaged, 
                 image = image),
             size = 0.2) +
  geom_image(data = dplyr::filter(df1, freq > 5),  # all larger counts
             aes(y = `primary programming language`, 
                 x = engaged, 
                 image = image),
             size = 0.3) +
  geom_text(data = df1, 
            aes(y = `primary programming language`, 
                x = engaged, 
                label = freq), 
            color = "white", size = 4) +
  ylab("Preferred programming language") +
  xlab("\nAbsolute frequencies")

programming languages frequencies

I also explored the association between relationship status and the more conventional factors of age and relationship duration. And indeed, those of my colleagues who are in their late twenties or older and have been partnered for a while now are mostly married or engaged.

# plotting age and relationship duration vs. relationship status

ggplot() +
  # doing this only to get a legend:
  geom_point(data = df,
             aes(x = age, y = `relationship duration`,
                 color = engaged), shape = 15) + 
  geom_image(data = dplyr::filter(df, engaged == "yes"), 
             aes(x = age, y = `relationship duration`,
                 image = "~/Desktop/heart_red.png")) +
  geom_image(data = dplyr::filter(df, engaged == "no"), 
             aes(x = age, y = `relationship duration`,
                 image = "~/Desktop/heart_black.png")) +
  ylab("Relationship duration\n") +
  xlab("\nAge") +
  scale_color_manual(name = "Married or engaged?",
                     values = c("#000000", "#D00B0B")) +
  scale_x_continuous(breaks = scales::pretty_breaks()) +
  theme_minimal() +
  theme(legend.position = "bottom")
age relationship duration

Statistical Models

I’ll employ some statistical models, but the dataset is rather small. Therefore, our model options are somewhat limited (and of course only suitable for entertainment). But it’s still possible to fit a decision tree, which might help to pinpoint the circumstances due to which some of us are still waiting for that special someone to put a ring on (it).

library(rpart)
library(rpart.plot)

# recoding target to get more understandable labels
df <- df %>%
  dplyr::mutate(engaged = ifelse(engaged == "yes", 
                                 "(to be) married", 
                                 "single"))  # assumed label for everyone else

# growing a decision tree with a ridiculously low minsplit
fit <- rpart(engaged ~ `relationship duration` + age + 
             `shared pet` + `permanent employment` +
             cohabitating + `years at STATWORX`,
             control = rpart.control(minsplit = 2), # overfitting ftw
             method = "class", data = df)

# plotting the tree
rpart.plot(fit, type = 3, extra = 2, 
           box.palette = c("#D00B0B", "#fae6e6"))

relationship decision tree

Our decision tree implies that the unintentionally unmarried among us should perhaps consider moving in with their partner, since cohabiting seems to be the most important factor.

But that still doesn’t exactly answer the question, who of us will be next. To predict our chances to get engaged, I estimated a logistic regression.
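For readers less familiar with the model: a logistic regression links the probability of being married or engaged to a linear combination of the predictors through the logistic function. Written out for the six predictors in the glm() call further down (the short symbol names are my own shorthand, not column names; for yes/no variables, read the symbol as a 0/1 indicator), the model is:

```latex
P(\text{engaged}_i = 1 \mid x_i) = \frac{1}{1 + \exp(-\eta_i)}, \qquad
\eta_i = \beta_0 + \beta_1\,\text{duration}_i + \beta_2\,\text{age}_i + \beta_3\,\text{pet}_i
       + \beta_4\,\text{permanent}_i + \beta_5\,\text{cohabiting}_i + \beta_6\,\text{years}_i
```

With type = "response", predict() returns exactly this probability for each person, and a positive coefficient beta_j pushes it upwards.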

We see that cohabiting, one’s age and the time we’ve been working at STATWORX are associated with a higher probability of being (soon to be) married. However, simply having been together for a long time or owning a pet together with our partner does not help. (Although I assume that this rather unintuitive relationship is caused by a certain outlier in the data. “Honey”, I’m looking at you!)

Finally, I got the logistic regression’s predicted probabilities for all of us to be married or engaged. As you can see down below, the single days of “Teddy Bear”, “Honey”, “Sweet Pea” and “Babe” seem to be numbered.

# reformatting the target variable
df <- df %>%
  dplyr::mutate(engaged = ifelse(engaged == "(to be) married", 1, 0))

# in-sample fitting: estimating the model 
log_reg <- glm(engaged ~ `relationship duration` + age +
               `shared pet` + `permanent employment` + 
               cohabitating + `years at STATWORX`,
               family = binomial, data = df)

df$probability <- predict(log_reg, df, type = "response")

# plotting the predicted probabilities
ggplot() +
  # again, doing this only to get a legend:
  geom_point(data = df,
             aes(x = probability, y = nickname,
                 color = as.factor(engaged)), shape = 15) + 
  geom_image(data = dplyr::filter(df, engaged == 1), 
             aes(x = probability, y = nickname,
                 image = "~/Desktop/heart_red.png")) +
  geom_image(data = dplyr::filter(df, engaged == 0), 
             aes(x = probability, y = nickname,
                 image = "~/Desktop/heart_black.png")) +
  ylab(" ") +
  xlab("Predicted Probability") +
  scale_color_manual(name = "Married or engaged?",
                     values = c("#000000", "#D00B0B"),
                     labels = c("no", "yes")) +
  scale_x_continuous(breaks = scales::pretty_breaks()) +
  theme_minimal() +
  theme(legend.position = "bottom")
predicted probabilities for marriage

I hope this was as insightful for you as it was for me. And to all of us, whose hopes have been shattered by cold, hard facts, let’s remember: there are tons of discounted chocolates waiting for us on February 15th.

Lea Waniek


RStudio is a powerful IDE that has helped me many times to conveniently debug large and complex R programs as well as Shiny apps, but there is one thing that bugs me about it: there is no easy option or interface in RStudio that lets you customize your theme, as you can in more developed text editors such as Atom or Sublime. This is sad, especially if you like working with RStudio but are not satisfied with its appearance. That being said, I did some research on how to customize its themes by hand with a little hack! Here you go!

Customizing your theme

Well, before we get started, I’ve got some sad news for you. First, I use a Mac, so all instructions pertain to that platform; however, I think things should work similarly on Windows. Second, you have to sacrifice one of RStudio’s built-in editor themes, so choose wisely: in the end, we will overwrite the settings of one theme with our own preferences. For the sake of demonstration, I will do that with the Dawn editor theme in RStudio version 1.1.419. To understand what is going on, be aware that the RStudio IDE in fact works like a browser, and the styles and themes at your disposal are essentially CSS files lying somewhere in your file system. We will eventually access those files and change them.

First step: Change a theme until you like it

Let’s now take a look at the Dawn theme:

[Figure: the Dawn editor theme]

Now, if you want to experiment with changing it, remember that RStudio is essentially a browser: right-clicking somewhere in the text editor and selecting “Inspect Element” opens the HTML the editor is built from.

[Figure: the <link> tag whose href points to the theme’s CSS file]
Scroll down until you find the <link> tag that references a CSS file. Its href attribute contains a path. Make a note of the CSS file name, because we will edit that file on the file system later. Click on the path to view the file’s content.

[Figure: overview of the original theme’s CSS]

Perfect! Now you can play around with the properties of the CSS selectors and see the effect on the theme in real time (unfortunately, you cannot rearrange the selectors). As an example, I will change the .ace_comment selector, which defines how comments in the code look. Let’s say I don’t like them being italic and I want their color to be a bit more … noticeable. So I make the font bold and change the color to red. Furthermore, I add a font-size property to make the text larger.

This is what it looked like before …

[Figure: comment styling before the change]

… and this is what it looks like now:

[Figure: comment styling after the change]

Second step: Overwrite the theme on your file system

So far, our changes are only temporary: once you restart RStudio, they are gone. This can be handy if you just want to play around. However, we want to make things permanent, which is our next step.

As mentioned before, an editor theme in RStudio is essentially a CSS file lying somewhere on your computer, so all themes can be accessed through your file system. Navigate to the Applications folder in Finder, right-click the RStudio icon and select “Show Package Contents”.

[Figure: “Show Package Contents” in the Finder context menu for RStudio]

You should now be in the Contents directory. From there, navigate to the folder that contains the theme CSS files:

[Figure: path to the theme CSS files inside Contents]

If you edit the file corresponding to the Dawn theme (97D3C…9C49C5.cache.css), the change to that theme becomes permanent.
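If you prefer to script the permanent change instead of editing the file by hand, here is a minimal sketch in R. It assumes three things. First, you have already located the theme’s .cache.css file as described above. Second, you have write access to it, since files inside an app bundle may need extra permissions. Third, an RStudio update will likely replace the file and undo the change. The function name patch_theme_css() is hypothetical. The sketch relies on the CSS cascade: when two rules have equal specificity, the later one wins. Appending therefore works as long as you reuse the exact selector you saw in Inspect Element.

```r
# Hypothetical helper: append an overriding rule to a theme CSS file,
# keeping a one-time backup so the original theme can be restored.
patch_theme_css <- function(css_path, selector, declarations) {
  if (!file.exists(css_path)) stop("CSS file not found: ", css_path)
  backup <- paste0(css_path, ".bak")
  if (!file.exists(backup)) file.copy(css_path, backup)
  # build e.g. ".ace_comment { font-weight: bold; color: red; }"
  decl <- paste0(names(declarations), ": ", declarations, ";", collapse = " ")
  rule <- paste0(selector, " { ", decl, " }")
  # appended at the end, so it wins over an earlier rule with the same selector
  cat("\n", rule, "\n", file = css_path, append = TRUE, sep = "")
  invisible(rule)
}

# Illustration on a temporary stand-in file; in practice, pass the path of
# the theme's .cache.css file. The font-size value is an arbitrary example.
css <- tempfile(fileext = ".cache.css")
writeLines(".ace_comment { font-style: italic; color: gray; }", css)
patch_theme_css(css, ".ace_comment",
                c("font-weight" = "bold", "font-style" = "normal",
                  "color" = "red", "font-size" = "14px"))
readLines(css)
```

To undo the change, copy the .bak file back over the theme file.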


Customizing RStudio themes requires a few tricks, but it is definitely doable. Keep in mind that any changes you make while inspecting the editor theme are only temporary. To make them permanent, decide which theme you want to overwrite, find the name of its CSS file, and enter your changes there.

If you like the Atom theme “One Dark” and would like to use it in RStudio, you can find an RStudio version of it on my GitHub account. Simply rename it to the CSS file you want to replace and put it in RStudio’s theme folder. As a little teaser, this is what it looks like:

[Figure: the One Dark theme in RStudio]
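The rename-and-replace step can also be scripted. In the sketch below, copying the new file over the old one has the same effect as renaming it and moving it into the theme folder, because the target file name stays the same. The function name replace_theme_css() is hypothetical. The files in the usage example are temporary stand-ins; in practice, new_css would be the downloaded One Dark CSS and target_css the theme file you chose to sacrifice.

```r
# Hypothetical helper: overwrite a built-in theme file with a custom theme,
# keeping a one-time backup of the original.
replace_theme_css <- function(new_css, target_css) {
  stopifnot(file.exists(new_css), file.exists(target_css))
  backup <- paste0(target_css, ".bak")
  if (!file.exists(backup)) file.copy(target_css, backup)
  # the target keeps its name, so RStudio keeps loading it under that name
  if (!file.copy(new_css, target_css, overwrite = TRUE)) {
    stop("could not overwrite ", target_css, "; check write permissions")
  }
  invisible(target_css)
}

# Stand-in files for illustration only
old <- tempfile(fileext = ".cache.css"); writeLines("/* Dawn */", old)
new <- tempfile(fileext = ".css");       writeLines("/* One Dark */", new)
replace_theme_css(new, old)
readLines(old)
```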

Tobias Krabel

One of the functions I use the most is strsplit. It is quite useful if you want to separate a string by a specific character. Even if you have some complex rules for the split, most of the time you can solve this with a regular expression. However, recently I came across a problem I could not get my head around. I wanted to split the string but also keep the delimiter.

Basic Regular Expressions

Let’s start at the beginning. If you do not know what regular expressions are, I will give you a short introduction. With regular expressions, you can describe patterns in a string and then use them in functions like grep, gsub or strsplit.

As the R (3.4.1) help file for regex states:

A regular expression is a pattern that describes a set of strings. Two types of regular expressions are used in R, extended regular expressions (the default) and Perl-like regular expressions used by perl = TRUE. There is also fixed = TRUE which can be considered to use a literal regular expression.
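To make the three modes concrete, here is a small sketch. The example strings are made up for illustration. The same pattern gives different results depending on whether it is read as an extended regular expression, as a literal string (fixed = TRUE), or as a Perl-like expression (perl = TRUE). Features such as lookarounds need perl = TRUE.

```r
x <- c("a.b", "axb")

# extended regex (default): "." matches any single character
grepl("a.b", x)                    # TRUE TRUE

# fixed = TRUE: the pattern is taken literally, so "." is just a dot
grepl("a.b", x, fixed = TRUE)      # TRUE FALSE

# perl = TRUE: Perl-like syntax, here a lookahead for a literal dot
grepl("a(?=\\.)", x, perl = TRUE)  # TRUE FALSE
```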

If you are looking for a specific pattern in a string – let’s say "3D" – you can just use those characters:

x <- c("3D", "4D", "3a")
grep("3D", x)
[1] 1

If you instead want every digit that is followed by an uppercase letter, you need character classes:

x <- c("3D", "4D", "3a")
grep("[0-9][A-Z]", x)
[1] 1 2

Since regular expressions can get complicated really fast, I will stop here and refer you to a cheat sheet for more info. In the cheat sheet, you can also find the part that gave me trouble: lookarounds.
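In short, a lookahead (?=...) or lookbehind (?<=...) checks what comes after or before the current position without consuming any characters. In R, lookarounds require perl = TRUE. A small sketch using the vector from above:

```r
x <- c("3D", "4D", "3a")

# lookahead: a digit, but only if it is followed by "D" ("D" is not part of the match)
regmatches(x, regexpr("[0-9](?=D)", x, perl = TRUE))      # "3" "4"

# lookbehind: a letter, but only if it is preceded by "3"
regmatches(x, regexpr("(?<=3)[A-Za-z]", x, perl = TRUE))  # "D" "a"

# a lookaround on its own matches an empty string: the match length is 0
attr(regexpr("(?=D)", "3D", perl = TRUE), "match.length") # 0
```

That zero match length is exactly what causes the surprise below.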


Back to my problem. I had a string like "3D/MON&SUN" and wanted to split it at / and &.

x <- c("3D/MON&SUN")
strsplit(x, "[/&]", perl = TRUE)
[[1]]
[1] "3D"  "MON" "SUN"

Since I still needed the delimiters, because they carried useful information, I turned to lookaround regular expressions. First up is the lookbehind, which works just fine:

strsplit(x, "(?<=[/&])", perl = TRUE)
[[1]]
[1] "3D/"  "MON&" "SUN"

However, when I used the lookahead, it did not work as I expected:

strsplit(x, "(?=[/&])", perl = TRUE)
[[1]]
[1] "3D"  "/"   "MON" "&"   "SUN"

In my search for a solution, I finally found this post on Stack Overflow, which explains the seemingly strange behaviour of strsplit. After reading the post and the help file, it is not strange anymore: strsplit does exactly what its algorithm says, as described in the help file of strsplit:

repeat {
    if the string is empty
        break.
    if there is a match
        add the string to the left of the match to the output.
        remove the match and all to the left of it.
    else
        add the string to the output.
        break.
}
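To make the loop concrete, here is an illustrative re-implementation in R for a single string. It is a sketch, not R’s actual C code. The function name split_like_strsplit() is hypothetical. One detail is an assumption about strsplit()’s internals: when an empty match occurs at the very start of the remaining string, one character is emitted before the loop continues. That assumption is consistent with the outputs shown above.

```r
# Hypothetical helper mirroring the documented strsplit loop for one string
split_like_strsplit <- function(s, pattern) {
  out <- character(0)
  repeat {
    if (nchar(s) == 0) break                 # the string is empty: stop
    m <- regexpr(pattern, s, perl = TRUE)
    start <- as.integer(m)
    if (start == -1) {                       # no match: keep the rest and stop
      out <- c(out, s)
      break
    }
    len <- attr(m, "match.length")
    end <- start + len - 1                   # last matched position (start - 1 if empty)
    if (end > 0) {
      # keep the part left of the match, then drop the match and all to its left
      out <- c(out, substr(s, 1, start - 1))
      s <- substr(s, end + 1, nchar(s))
    } else {
      # assumption: an empty match at the very start splits off one character
      out <- c(out, substr(s, 1, 1))
      s <- substr(s, 2, nchar(s))
    }
  }
  out
}

x <- "3D/MON&SUN"
split_like_strsplit(x, "(?<=[/&])")  # "3D/" "MON&" "SUN"
split_like_strsplit(x, "(?=[/&])")   # "3D" "/" "MON" "&" "SUN"
```

Stepping through the second call shows the problem. After "3D" is cut off, the remaining string starts with "/". The lookahead then matches at the very first position with length 0, so only a single character is split off.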

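Because lookarounds match zero characters, they interfere with the “remove the match and all to the left of it” step. With the lookahead, the first match sits right before "/", so only "3D" is removed. The next search starts at "/", where the lookahead matches again at the very beginning of the remaining string, which is why "/" and "&" end up as separate elements. Luckily, the post also offered a solution that involves some regular expression magic: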

strsplit(x, "(?<=.)(?=[/&])", perl = TRUE)
[[1]]
[1] "3D"   "/MON" "&SUN"

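The added lookbehind (?<=.) requires at least one character in front of the split point. That is never the case at the start of the remaining string, so the lookahead can no longer match there. So my problem is solved, but I would have to remember this regular expression … uurrghhh!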
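If remembering the lookaround trick is a pain, a different technique gives the same result for this example. Instead of describing where to split, you describe the pieces you want to keep and extract them with gregexpr() and regmatches(). This swaps in a different approach than the Stack Overflow solution. The two can disagree on edge cases such as consecutive delimiters, so check against your own data before relying on it.

```r
x <- "3D/MON&SUN"

# keep each delimiter in front: an optional delimiter plus one or more non-delimiters
regmatches(x, gregexpr("[/&]?[^/&]+", x))[[1]]  # "3D" "/MON" "&SUN"

# keep each delimiter at the end: non-delimiters plus an optional delimiter
regmatches(x, gregexpr("[^/&]+[/&]?", x))[[1]]  # "3D/" "MON&" "SUN"
```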

A New Function: strsplit 2.0

If I have the chance to write a function that eases my work, I will do it! So I wrote my own strsplit with a new argument type = c("remove", "before", "after"). Basically, I took the regular expressions from above and wrapped them in an if-else chain.
To sum it all up: Regular expressions are a powerful tool and you should try to learn and understand how they work!

strsplit <- function(x,
                     split,
                     type = "remove",
                     perl = FALSE,
                     ...) {
  if (type == "remove") {
    # use base::strsplit
    out <- base::strsplit(x = x, split = split, perl = perl, ...)
  } else if (type == "before") {
    # split before the delimiter and keep it
    out <- base::strsplit(x = x,
                          split = paste0("(?<=.)(?=", split, ")"),
                          perl = TRUE,
                          ...)
  } else if (type == "after") {
    # split after the delimiter and keep it
    out <- base::strsplit(x = x,
                          split = paste0("(?<=", split, ")"),
                          perl = TRUE,
                          ...)
  } else {
    # wrong type input
    stop("type must be remove, after or before!")
  }
  out
}

x <- "3D/MON&SUN"
strsplit(x, "[/&]", type = "before")
[[1]]
[1] "3D"   "/MON" "&SUN"
strsplit(x, "[/&]", type = "after")
[[1]]
[1] "3D/"  "MON&" "SUN"

Jakob Gepp
